dlc2action.transformer.heatmap

Heatmap transformer

  1#
  2# Copyright 2020-present by A. Mathis Group and contributors. All rights reserved.
  3#
  4# This project and all its files are licensed under GNU AGPLv3 or later version. 
  5# A copy is included in dlc2action/LICENSE.AGPL.
  6#
  7"""
  8Heatmap transformer
  9"""
 10
 11from typing import Dict, Tuple, List
 12from torchvision import transforms as tr
 13from dlc2action.transformer.base_transformer import Transformer
 14
 15
 16class HeatmapTransformer(Transformer):
 17    """
 18    A transformer that augments the output of the Heatmap feature extractor
 19
 20    The available augmentations are `'rotate'`, `'horizontal_flip'`, `'vertical_flip'`.
 21    """
 22
 23    def __init__(
 24        self,
 25        model_name: str,
 26        augmentations: List = None,
 27        use_default_augmentations: bool = False,
 28        rotation_degree_limits: List = None,
 29        **kwargs
 30    ) -> None:
 31        """
 32        Parameters
 33        ----------
 34        model_name : str
 35            the name of the model used
 36        augmentations : list, optional
 37            list of augmentation names to use ("rotate", "mirror", "shift")
 38        use_default_augmentations : bool, default False
 39            if `True` and augmentations are not passed, default augmentations will be applied; otherwise no augmentations
 40        rotation_degree_limits : list, default [-90, 90]
 41            list of float rotation angle limits (`[low, high]`)
 42        **kwargs : dict
 43            other parameters for the base transformer class
 44        """
 45
 46        if augmentations is None:
 47            augmentations = []
 48        if rotation_degree_limits is None:
 49            rotation_degree_limits = [-90, 90]
 50        super().__init__(
 51            model_name,
 52            augmentations,
 53            use_default_augmentations,
 54            graph_features=False,
 55            bodyparts_order=None,
 56        )
 57        self.rotation_limits = rotation_degree_limits
 58
 59    def _apply_transform(
 60        self, transformation, main_input: Dict, ssl_inputs: List, ssl_targets: List
 61    ) -> (Dict, List, List):
 62        """
 63        Apply a `torchvision.transforms` transformation to the data
 64        """
 65
 66        dicts = [main_input] + ssl_inputs + ssl_targets
 67        for x in dicts:
 68            if x is not None:
 69                keys = self._get_keys(
 70                    ("coords_heatmap", "motion_heatmap"),
 71                    x,
 72                )
 73                for key in keys:
 74                    if key in x:
 75                        x[key] = transformation(x[key])
 76        return main_input, ssl_inputs, ssl_targets
 77
 78    def _horizontal_flip(
 79        self, main_input: Dict, ssl_inputs: List, ssl_targets: List
 80    ) -> (Dict, List, List):
 81        """
 82        Apply a random horizontal flip
 83        """
 84
 85        transform = tr.RandomHorizontalFlip()
 86        return self._apply_transform(transform, main_input, ssl_inputs, ssl_targets)
 87
 88    def _vertical_flip(
 89        self, main_input: Dict, ssl_inputs: List, ssl_targets: List
 90    ) -> (Dict, List, List):
 91        """
 92        Apply a random vertical flip
 93        """
 94
 95        transform = tr.RandomVerticalFlip()
 96        return self._apply_transform(transform, main_input, ssl_inputs, ssl_targets)
 97
 98    def _rotate(
 99        self, main_input: Dict, ssl_inputs: List, ssl_targets: List
100    ) -> (Dict, List, List):
101        """
102        Apply a random rotation
103        """
104
105        transform = tr.RandomRotation(self.rotation_limits)
106        return self._apply_transform(transform, main_input, ssl_inputs, ssl_targets)
107
108    def _augmentations_dict(self) -> Dict:
109        """
110        Get the mapping from augmentation names to functions
111        """
112
113        return {
114            "rotate": self._rotate,
115            "horizontal_flip": self._horizontal_flip,
116            "vertical_flip": self._vertical_flip,
117        }
118
119    def _default_augmentations(self) -> List:
120        """
121        Get the list of default augmentation names
122        """
123
124        return ["horizontal_flip"]
125
126    def _get_bodyparts(self, shape: Tuple) -> None:
127        """
128        Set the number of bodyparts from the data if it is not known
129        """
130
131        if self.n_bodyparts is None:
132            N, B, F = shape
133            self.n_bodyparts = B // 2
134
135    def _get_keys(self, key_bases: Tuple, x: Dict) -> List:
136        """
137        Get the keys of x that start with one of the strings from key_bases
138        """
139
140        keys = []
141        for key in x:
142            if key.startswith(key_bases):
143                keys.append(key)
144        return keys
145
146    def _apply_augmentations(
147        self,
148        main_input: Dict,
149        ssl_inputs: List = None,
150        ssl_targets: List = None,
151        augment: bool = False,
152        ssl: bool = False,
153    ) -> Tuple:
154        dicts = [main_input] + ssl_inputs + ssl_targets
155        key = self._get_keys(
156            ("coords_heatmap", "motion_heatmap"),
157            main_input,
158        )[0]
159        self.original_shape = main_input[key].shape
160        for x in dicts:
161            if x is not None:
162                keys = self._get_keys(
163                    ("coords_heatmap", "motion_heatmap"),
164                    x,
165                )
166                for key in keys:
167                    x[key] = x[key].reshape((-1, x[key].shape[-2], x[key].shape[-1]))
168        main_input, ssl_inputs, ssl_targets = super()._apply_augmentations(
169            main_input, ssl_inputs, ssl_targets, augment=augment, ssl=ssl
170        )
171        dicts = [main_input] + ssl_inputs + ssl_targets
172        for x in dicts:
173            if x is not None:
174                keys = self._get_keys(
175                    ("coords_heatmap", "motion_heatmap"),
176                    x,
177                )
178                for key in keys:
179                    x[key] = x[key].reshape(self.original_shape)
180        return main_input, ssl_inputs, ssl_targets
class HeatmapTransformer(dlc2action.transformer.base_transformer.Transformer):
 17class HeatmapTransformer(Transformer):
 18    """
 19    A transformer that augments the output of the Heatmap feature extractor
 20
 21    The available augmentations are `'rotate'`, `'horizontal_flip'`, `'vertical_flip'`.
 22    """
 23
 24    def __init__(
 25        self,
 26        model_name: str,
 27        augmentations: List = None,
 28        use_default_augmentations: bool = False,
 29        rotation_degree_limits: List = None,
 30        **kwargs
 31    ) -> None:
 32        """
 33        Parameters
 34        ----------
 35        model_name : str
 36            the name of the model used
 37        augmentations : list, optional
 38            list of augmentation names to use ("rotate", "mirror", "shift")
 39        use_default_augmentations : bool, default False
 40            if `True` and augmentations are not passed, default augmentations will be applied; otherwise no augmentations
 41        rotation_degree_limits : list, default [-90, 90]
 42            list of float rotation angle limits (`[low, high]`)
 43        **kwargs : dict
 44            other parameters for the base transformer class
 45        """
 46
 47        if augmentations is None:
 48            augmentations = []
 49        if rotation_degree_limits is None:
 50            rotation_degree_limits = [-90, 90]
 51        super().__init__(
 52            model_name,
 53            augmentations,
 54            use_default_augmentations,
 55            graph_features=False,
 56            bodyparts_order=None,
 57        )
 58        self.rotation_limits = rotation_degree_limits
 59
 60    def _apply_transform(
 61        self, transformation, main_input: Dict, ssl_inputs: List, ssl_targets: List
 62    ) -> (Dict, List, List):
 63        """
 64        Apply a `torchvision.transforms` transformation to the data
 65        """
 66
 67        dicts = [main_input] + ssl_inputs + ssl_targets
 68        for x in dicts:
 69            if x is not None:
 70                keys = self._get_keys(
 71                    ("coords_heatmap", "motion_heatmap"),
 72                    x,
 73                )
 74                for key in keys:
 75                    if key in x:
 76                        x[key] = transformation(x[key])
 77        return main_input, ssl_inputs, ssl_targets
 78
 79    def _horizontal_flip(
 80        self, main_input: Dict, ssl_inputs: List, ssl_targets: List
 81    ) -> (Dict, List, List):
 82        """
 83        Apply a random horizontal flip
 84        """
 85
 86        transform = tr.RandomHorizontalFlip()
 87        return self._apply_transform(transform, main_input, ssl_inputs, ssl_targets)
 88
 89    def _vertical_flip(
 90        self, main_input: Dict, ssl_inputs: List, ssl_targets: List
 91    ) -> (Dict, List, List):
 92        """
 93        Apply a random vertical flip
 94        """
 95
 96        transform = tr.RandomVerticalFlip()
 97        return self._apply_transform(transform, main_input, ssl_inputs, ssl_targets)
 98
 99    def _rotate(
100        self, main_input: Dict, ssl_inputs: List, ssl_targets: List
101    ) -> (Dict, List, List):
102        """
103        Apply a random rotation
104        """
105
106        transform = tr.RandomRotation(self.rotation_limits)
107        return self._apply_transform(transform, main_input, ssl_inputs, ssl_targets)
108
109    def _augmentations_dict(self) -> Dict:
110        """
111        Get the mapping from augmentation names to functions
112        """
113
114        return {
115            "rotate": self._rotate,
116            "horizontal_flip": self._horizontal_flip,
117            "vertical_flip": self._vertical_flip,
118        }
119
120    def _default_augmentations(self) -> List:
121        """
122        Get the list of default augmentation names
123        """
124
125        return ["horizontal_flip"]
126
127    def _get_bodyparts(self, shape: Tuple) -> None:
128        """
129        Set the number of bodyparts from the data if it is not known
130        """
131
132        if self.n_bodyparts is None:
133            N, B, F = shape
134            self.n_bodyparts = B // 2
135
136    def _get_keys(self, key_bases: Tuple, x: Dict) -> List:
137        """
138        Get the keys of x that start with one of the strings from key_bases
139        """
140
141        keys = []
142        for key in x:
143            if key.startswith(key_bases):
144                keys.append(key)
145        return keys
146
147    def _apply_augmentations(
148        self,
149        main_input: Dict,
150        ssl_inputs: List = None,
151        ssl_targets: List = None,
152        augment: bool = False,
153        ssl: bool = False,
154    ) -> Tuple:
155        dicts = [main_input] + ssl_inputs + ssl_targets
156        key = self._get_keys(
157            ("coords_heatmap", "motion_heatmap"),
158            main_input,
159        )[0]
160        self.original_shape = main_input[key].shape
161        for x in dicts:
162            if x is not None:
163                keys = self._get_keys(
164                    ("coords_heatmap", "motion_heatmap"),
165                    x,
166                )
167                for key in keys:
168                    x[key] = x[key].reshape((-1, x[key].shape[-2], x[key].shape[-1]))
169        main_input, ssl_inputs, ssl_targets = super()._apply_augmentations(
170            main_input, ssl_inputs, ssl_targets, augment=augment, ssl=ssl
171        )
172        dicts = [main_input] + ssl_inputs + ssl_targets
173        for x in dicts:
174            if x is not None:
175                keys = self._get_keys(
176                    ("coords_heatmap", "motion_heatmap"),
177                    x,
178                )
179                for key in keys:
180                    x[key] = x[key].reshape(self.original_shape)
181        return main_input, ssl_inputs, ssl_targets

A transformer that augments the output of the Heatmap feature extractor

The available augmentations are 'rotate', 'horizontal_flip', 'vertical_flip'.

HeatmapTransformer( model_name: str, augmentations: List = None, use_default_augmentations: bool = False, rotation_degree_limits: List = None, **kwargs)
24    def __init__(
25        self,
26        model_name: str,
27        augmentations: List = None,
28        use_default_augmentations: bool = False,
29        rotation_degree_limits: List = None,
30        **kwargs
31    ) -> None:
32        """
33        Parameters
34        ----------
35        model_name : str
36            the name of the model used
37        augmentations : list, optional
38            list of augmentation names to use ("rotate", "mirror", "shift")
39        use_default_augmentations : bool, default False
40            if `True` and augmentations are not passed, default augmentations will be applied; otherwise no augmentations
41        rotation_degree_limits : list, default [-90, 90]
42            list of float rotation angle limits (`[low, high]`)
43        **kwargs : dict
44            other parameters for the base transformer class
45        """
46
47        if augmentations is None:
48            augmentations = []
49        if rotation_degree_limits is None:
50            rotation_degree_limits = [-90, 90]
51        super().__init__(
52            model_name,
53            augmentations,
54            use_default_augmentations,
55            graph_features=False,
56            bodyparts_order=None,
57        )
58        self.rotation_limits = rotation_degree_limits

Parameters

model_name : str
    the name of the model used
augmentations : list, optional
    list of augmentation names to use ("rotate", "horizontal_flip", "vertical_flip")
use_default_augmentations : bool, default False
    if True and augmentations are not passed, default augmentations will be applied; otherwise no augmentations
rotation_degree_limits : list, default [-90, 90]
    list of float rotation angle limits ([low, high])
**kwargs : dict
    additional keyword arguments (accepted, but not forwarded to the base transformer class constructor)

rotation_limits