dlc2action.transformer.heatmap
Heatmap transformer
#
# Copyright 2020-present by A. Mathis Group and contributors. All rights reserved.
#
# This project and all its files are licensed under GNU AGPLv3 or later version.
# A copy is included in dlc2action/LICENSE.AGPL.
#
"""
Heatmap transformer
"""

from typing import Dict, Tuple, List
from torchvision import transforms as tr
from dlc2action.transformer.base_transformer import Transformer


class HeatmapTransformer(Transformer):
    """
    A transformer that augments the output of the Heatmap feature extractor

    The available augmentations are `'rotate'`, `'horizontal_flip'`, `'vertical_flip'`.
    """

    # prefixes of the feature keys that are treated as heatmaps
    _HEATMAP_KEY_BASES = ("coords_heatmap", "motion_heatmap")

    def __init__(
        self,
        model_name: str,
        augmentations: List = None,
        use_default_augmentations: bool = False,
        rotation_degree_limits: List = None,
        **kwargs
    ) -> None:
        """
        Parameters
        ----------
        model_name : str
            the name of the model used
        augmentations : list, optional
            list of augmentation names to use ("rotate", "horizontal_flip", "vertical_flip");
            an empty list is passed to the base class when this is `None`
        use_default_augmentations : bool, default False
            passed on to the base transformer class together with `augmentations`
        rotation_degree_limits : list, default [-90, 90]
            list of float rotation angle limits (`[low, high]`)
        **kwargs : dict
            additional keyword arguments; they are accepted but not passed on to the
            base transformer class
        """

        if augmentations is None:
            augmentations = []
        if rotation_degree_limits is None:
            rotation_degree_limits = [-90, 90]
        super().__init__(
            model_name,
            augmentations,
            use_default_augmentations,
            graph_features=False,
            bodyparts_order=None,
        )
        self.rotation_limits = rotation_degree_limits

    @staticmethod
    def _present_dicts(main_input: Dict, ssl_inputs: List, ssl_targets: List) -> List:
        """
        Collect the feature dictionaries that are not `None`

        `ssl_inputs` and `ssl_targets` may themselves be `None`, in which case they contribute nothing.
        """

        candidates = [main_input, *(ssl_inputs or []), *(ssl_targets or [])]
        return [x for x in candidates if x is not None]

    def _apply_transform(
        self, transformation, main_input: Dict, ssl_inputs: List, ssl_targets: List
    ) -> Tuple[Dict, List, List]:
        """
        Apply a `torchvision.transforms` transformation to the data

        Every value stored under a heatmap key is replaced in place with the transformed value.
        """

        # NOTE(review): the transformation is called once per tensor, so a random transform
        # presumably draws its parameters separately for each key -- confirm this is intended
        for x in self._present_dicts(main_input, ssl_inputs, ssl_targets):
            for key in self._get_keys(self._HEATMAP_KEY_BASES, x):
                x[key] = transformation(x[key])
        return main_input, ssl_inputs, ssl_targets

    def _horizontal_flip(
        self, main_input: Dict, ssl_inputs: List, ssl_targets: List
    ) -> Tuple[Dict, List, List]:
        """
        Apply a random horizontal flip
        """

        transform = tr.RandomHorizontalFlip()
        return self._apply_transform(transform, main_input, ssl_inputs, ssl_targets)

    def _vertical_flip(
        self, main_input: Dict, ssl_inputs: List, ssl_targets: List
    ) -> Tuple[Dict, List, List]:
        """
        Apply a random vertical flip
        """

        transform = tr.RandomVerticalFlip()
        return self._apply_transform(transform, main_input, ssl_inputs, ssl_targets)

    def _rotate(
        self, main_input: Dict, ssl_inputs: List, ssl_targets: List
    ) -> Tuple[Dict, List, List]:
        """
        Apply a random rotation within `self.rotation_limits`
        """

        transform = tr.RandomRotation(self.rotation_limits)
        return self._apply_transform(transform, main_input, ssl_inputs, ssl_targets)

    def _augmentations_dict(self) -> Dict:
        """
        Get the mapping from augmentation names to functions
        """

        return {
            "rotate": self._rotate,
            "horizontal_flip": self._horizontal_flip,
            "vertical_flip": self._vertical_flip,
        }

    def _default_augmentations(self) -> List:
        """
        Get the list of default augmentation names
        """

        return ["horizontal_flip"]

    def _get_bodyparts(self, shape: Tuple) -> None:
        """
        Set the number of bodyparts from the data if it is not known

        `shape` has to contain exactly three entries; the number of bodyparts is set to
        half of the second one (integer division).
        """

        if self.n_bodyparts is None:
            _, B, _ = shape
            self.n_bodyparts = B // 2

    def _get_keys(self, key_bases: Tuple, x: Dict) -> List:
        """
        Get the keys of x that start with one of the strings from key_bases
        """

        return [key for key in x if key.startswith(key_bases)]

    def _apply_augmentations(
        self,
        main_input: Dict,
        ssl_inputs: List = None,
        ssl_targets: List = None,
        augment: bool = False,
        ssl: bool = False,
    ) -> Tuple:
        """
        Apply the augmentations to heatmap data

        All heatmap values are reshaped to `(-1, H, W)` (keeping their last two dimensions)
        before the base class augmentations run and are reshaped to the original shape of the
        first heatmap feature of `main_input` afterwards.

        Parameters
        ----------
        main_input : dict
            the main feature dictionary; it has to contain at least one heatmap key
        ssl_inputs : list, optional
            a list of SSL input feature dictionaries (entries can be `None`)
        ssl_targets : list, optional
            a list of SSL target feature dictionaries (entries can be `None`)
        augment : bool, default False
            passed on to the base class implementation
        ssl : bool, default False
            passed on to the base class implementation

        Returns
        -------
        main_input, ssl_inputs, ssl_targets
            the values returned by the base class implementation, with heatmaps restored to
            the original shape

        Raises
        ------
        IndexError
            if `main_input` does not contain any heatmap keys
        """

        heatmap_keys = self._get_keys(self._HEATMAP_KEY_BASES, main_input)
        if not heatmap_keys:
            raise IndexError(
                "main_input does not contain any keys starting with "
                f"{self._HEATMAP_KEY_BASES}"
            )
        self.original_shape = main_input[heatmap_keys[0]].shape
        # ssl_inputs / ssl_targets default to None, so they cannot be concatenated directly
        for x in self._present_dicts(main_input, ssl_inputs, ssl_targets):
            for key in self._get_keys(self._HEATMAP_KEY_BASES, x):
                shape = x[key].shape
                x[key] = x[key].reshape((-1, shape[-2], shape[-1]))
        main_input, ssl_inputs, ssl_targets = super()._apply_augmentations(
            main_input, ssl_inputs, ssl_targets, augment=augment, ssl=ssl
        )
        for x in self._present_dicts(main_input, ssl_inputs, ssl_targets):
            for key in self._get_keys(self._HEATMAP_KEY_BASES, x):
                x[key] = x[key].reshape(self.original_shape)
        return main_input, ssl_inputs, ssl_targets
class HeatmapTransformer(Transformer):
    """
    A transformer that augments the output of the Heatmap feature extractor

    The available augmentations are `'rotate'`, `'horizontal_flip'`, `'vertical_flip'`.
    """

    # prefixes of the feature keys that are treated as heatmaps
    _HEATMAP_KEY_BASES = ("coords_heatmap", "motion_heatmap")

    def __init__(
        self,
        model_name: str,
        augmentations: List = None,
        use_default_augmentations: bool = False,
        rotation_degree_limits: List = None,
        **kwargs
    ) -> None:
        """
        Parameters
        ----------
        model_name : str
            the name of the model used
        augmentations : list, optional
            list of augmentation names to use ("rotate", "horizontal_flip", "vertical_flip");
            an empty list is passed to the base class when this is `None`
        use_default_augmentations : bool, default False
            passed on to the base transformer class together with `augmentations`
        rotation_degree_limits : list, default [-90, 90]
            list of float rotation angle limits (`[low, high]`)
        **kwargs : dict
            additional keyword arguments; they are accepted but not passed on to the
            base transformer class
        """

        if augmentations is None:
            augmentations = []
        if rotation_degree_limits is None:
            rotation_degree_limits = [-90, 90]
        super().__init__(
            model_name,
            augmentations,
            use_default_augmentations,
            graph_features=False,
            bodyparts_order=None,
        )
        self.rotation_limits = rotation_degree_limits

    @staticmethod
    def _present_dicts(main_input: Dict, ssl_inputs: List, ssl_targets: List) -> List:
        """
        Collect the feature dictionaries that are not `None`

        `ssl_inputs` and `ssl_targets` may themselves be `None`, in which case they contribute nothing.
        """

        candidates = [main_input, *(ssl_inputs or []), *(ssl_targets or [])]
        return [x for x in candidates if x is not None]

    def _apply_transform(
        self, transformation, main_input: Dict, ssl_inputs: List, ssl_targets: List
    ) -> Tuple[Dict, List, List]:
        """
        Apply a `torchvision.transforms` transformation to the data

        Every value stored under a heatmap key is replaced in place with the transformed value.
        """

        # NOTE(review): the transformation is called once per tensor, so a random transform
        # presumably draws its parameters separately for each key -- confirm this is intended
        for x in self._present_dicts(main_input, ssl_inputs, ssl_targets):
            for key in self._get_keys(self._HEATMAP_KEY_BASES, x):
                x[key] = transformation(x[key])
        return main_input, ssl_inputs, ssl_targets

    def _horizontal_flip(
        self, main_input: Dict, ssl_inputs: List, ssl_targets: List
    ) -> Tuple[Dict, List, List]:
        """
        Apply a random horizontal flip
        """

        transform = tr.RandomHorizontalFlip()
        return self._apply_transform(transform, main_input, ssl_inputs, ssl_targets)

    def _vertical_flip(
        self, main_input: Dict, ssl_inputs: List, ssl_targets: List
    ) -> Tuple[Dict, List, List]:
        """
        Apply a random vertical flip
        """

        transform = tr.RandomVerticalFlip()
        return self._apply_transform(transform, main_input, ssl_inputs, ssl_targets)

    def _rotate(
        self, main_input: Dict, ssl_inputs: List, ssl_targets: List
    ) -> Tuple[Dict, List, List]:
        """
        Apply a random rotation within `self.rotation_limits`
        """

        transform = tr.RandomRotation(self.rotation_limits)
        return self._apply_transform(transform, main_input, ssl_inputs, ssl_targets)

    def _augmentations_dict(self) -> Dict:
        """
        Get the mapping from augmentation names to functions
        """

        return {
            "rotate": self._rotate,
            "horizontal_flip": self._horizontal_flip,
            "vertical_flip": self._vertical_flip,
        }

    def _default_augmentations(self) -> List:
        """
        Get the list of default augmentation names
        """

        return ["horizontal_flip"]

    def _get_bodyparts(self, shape: Tuple) -> None:
        """
        Set the number of bodyparts from the data if it is not known

        `shape` has to contain exactly three entries; the number of bodyparts is set to
        half of the second one (integer division).
        """

        if self.n_bodyparts is None:
            _, B, _ = shape
            self.n_bodyparts = B // 2

    def _get_keys(self, key_bases: Tuple, x: Dict) -> List:
        """
        Get the keys of x that start with one of the strings from key_bases
        """

        return [key for key in x if key.startswith(key_bases)]

    def _apply_augmentations(
        self,
        main_input: Dict,
        ssl_inputs: List = None,
        ssl_targets: List = None,
        augment: bool = False,
        ssl: bool = False,
    ) -> Tuple:
        """
        Apply the augmentations to heatmap data

        All heatmap values are reshaped to `(-1, H, W)` (keeping their last two dimensions)
        before the base class augmentations run and are reshaped to the original shape of the
        first heatmap feature of `main_input` afterwards.

        Parameters
        ----------
        main_input : dict
            the main feature dictionary; it has to contain at least one heatmap key
        ssl_inputs : list, optional
            a list of SSL input feature dictionaries (entries can be `None`)
        ssl_targets : list, optional
            a list of SSL target feature dictionaries (entries can be `None`)
        augment : bool, default False
            passed on to the base class implementation
        ssl : bool, default False
            passed on to the base class implementation

        Returns
        -------
        main_input, ssl_inputs, ssl_targets
            the values returned by the base class implementation, with heatmaps restored to
            the original shape

        Raises
        ------
        IndexError
            if `main_input` does not contain any heatmap keys
        """

        heatmap_keys = self._get_keys(self._HEATMAP_KEY_BASES, main_input)
        if not heatmap_keys:
            raise IndexError(
                "main_input does not contain any keys starting with "
                f"{self._HEATMAP_KEY_BASES}"
            )
        self.original_shape = main_input[heatmap_keys[0]].shape
        # ssl_inputs / ssl_targets default to None, so they cannot be concatenated directly
        for x in self._present_dicts(main_input, ssl_inputs, ssl_targets):
            for key in self._get_keys(self._HEATMAP_KEY_BASES, x):
                shape = x[key].shape
                x[key] = x[key].reshape((-1, shape[-2], shape[-1]))
        main_input, ssl_inputs, ssl_targets = super()._apply_augmentations(
            main_input, ssl_inputs, ssl_targets, augment=augment, ssl=ssl
        )
        for x in self._present_dicts(main_input, ssl_inputs, ssl_targets):
            for key in self._get_keys(self._HEATMAP_KEY_BASES, x):
                x[key] = x[key].reshape(self.original_shape)
        return main_input, ssl_inputs, ssl_targets
A transformer that augments the output of the Heatmap feature extractor
The available augmentations are 'rotate', 'horizontal_flip', 'vertical_flip'.
HeatmapTransformer( model_name: str, augmentations: List = None, use_default_augmentations: bool = False, rotation_degree_limits: List = None, **kwargs)
def __init__(
    self,
    model_name: str,
    augmentations: List = None,
    use_default_augmentations: bool = False,
    rotation_degree_limits: List = None,
    **kwargs
) -> None:
    """
    Parameters
    ----------
    model_name : str
        the name of the model used
    augmentations : list, optional
        list of augmentation names to use ("rotate", "horizontal_flip", "vertical_flip");
        an empty list is passed to the base class when this is `None`
    use_default_augmentations : bool, default False
        passed on to the base transformer class together with `augmentations`
    rotation_degree_limits : list, default [-90, 90]
        list of float rotation angle limits (`[low, high]`)
    **kwargs : dict
        additional keyword arguments; they are accepted but not passed on to the
        base transformer class
    """

    if augmentations is None:
        augmentations = []
    if rotation_degree_limits is None:
        rotation_degree_limits = [-90, 90]
    # graph features and a bodypart order are always disabled for this transformer
    super().__init__(
        model_name,
        augmentations,
        use_default_augmentations,
        graph_features=False,
        bodyparts_order=None,
    )
    self.rotation_limits = rotation_degree_limits
Parameters
model_name : str
the name of the model used
augmentations : list, optional
list of augmentation names to use ("rotate", "horizontal_flip", "vertical_flip")
use_default_augmentations : bool, default False
if True and augmentations are not passed, default augmentations will be applied; otherwise no augmentations
rotation_degree_limits : list, default [-90, 90]
list of float rotation angle limits ([low, high])
**kwargs : dict
additional keyword arguments (accepted but not passed on to the base transformer class)