dlc2action.transformer.base_transformer
Abstract parent class for transformers
#
# Copyright 2020-present by A. Mathis Group and contributors. All rights reserved.
#
# This project and all its files are licensed under GNU AGPLv3 or later version.
# A copy is included in dlc2action/LICENSE.AGPL.
#
"""
Abstract parent class for transformers
"""

from typing import Dict, List, Callable, Union, Tuple
import torch
from abc import ABC, abstractmethod
from dlc2action.utils import TensorList
from copy import deepcopy
from matplotlib import pyplot as plt


class Transformer(ABC):
    """
    A base class for all transformers

    A transformer should apply augmentations and generate model input and training target tensors.

    All augmentation functions need to take `(main_input: dict, ssl_inputs: list, ssl_targets: list)`
    as input and return an output of the same format. Here `main_input` is a feature dictionary of the sample
    data, `ssl_inputs` is a list of SSL input feature dictionaries and `ssl_targets` is a list of SSL target
    feature dictionaries. The same augmentations are applied to all inputs and then `None` values are replaced
    according to the rules set by `keep_target_none` and `generate_ssl_input` parameters and the feature
    dictionaries are compiled into tensors.
    """

    def __init__(
        self,
        model_name: str,
        augmentations: List = None,
        use_default_augmentations: bool = False,
        generate_ssl_input: List = None,
        keep_target_none: List = None,
        ssl_augmentations: List = None,
        graph_features: bool = False,
        bodyparts_order: List = None,
    ) -> None:
        """
        Parameters
        ----------
        augmentations : list, optional
            a list of string names of augmentations to use (if not provided, either no augmentations are applied or,
            if use_default_augmentations is True, a default list is used)
        use_default_augmentations : bool, default False
            if True and augmentations are not passed, default augmentations will be applied; otherwise no augmentations
        generate_ssl_input : list, optional
            a list of bool values of the length of the number of SSL modules being used; if the corresponding bool value
            is `True`, the ssl input will be generated as a new augmentation of main input (if not provided defaults to
            `False` for each module)
        keep_target_none : list, optional
            a list of bool values of the length of the number of SSL modules being used; if the corresponding bool value
            is `False` and the SSL target is `None`, the target is set to augmented main input (if not provided defaults
            to `True` for each module)
        ssl_augmentations : list, optional
            a list of augmentation names to be applied with generating SSL input (when `generate_ssl_input` is True)
            (if not provided, defaults to the main augmentations list)
        graph_features : bool, default False
            if `True`, all features in each frame can be meaningfully reshaped to `(#bodyparts, #features)`
        bodyparts_order : list, optional
            a list of bodypart names, optional
        """

        if augmentations is None:
            augmentations = []
        if generate_ssl_input is None:
            generate_ssl_input = [None]
        if keep_target_none is None:
            keep_target_none = [None]
        self.model_name = model_name
        self.augmentations = augmentations
        self.generate_ssl_input = generate_ssl_input
        self.keep_target_none = keep_target_none
        if len(self.augmentations) == 0 and use_default_augmentations:
            self.augmentations = self._default_augmentations()
        if ssl_augmentations is None:
            ssl_augmentations = self.augmentations
        self.ssl_augmentations = ssl_augmentations
        self.graph_features = graph_features
        self.num_graph_nodes = (
            len(bodyparts_order) if bodyparts_order is not None else None
        )
        self._check_augmentations(self.augmentations)
        self._check_augmentations(self.ssl_augmentations)

    def transform(
        self,
        main_input: Dict,
        ssl_inputs: List = None,
        ssl_targets: List = None,
        augment: bool = False,
        subsample: List = None,
    ) -> Tuple:
        """
        Apply augmentations and generate tensors from feature dictionaries.

        The same augmentations are applied to all the inputs (if they have the features known to the transformer).

        If `generate_ssl_input` is set to True for some of the SSL pairs, those SSL inputs will be generated as
        another augmentation of main_input.
        Unless `keep_target_none` is set to True, `None` SSL targets will be replaced with augmented `main_input`.
        All features are stacked together to form a tensor of shape `(#features, #frames)` that can be passed to
        a model.

        Parameters
        ----------
        main_input : dict
            the feature dictionary of the main input
        ssl_inputs : list, optional
            a list of feature dictionaries of SSL inputs (some or all can be None)
        ssl_targets : list, optional
            a list of feature dictionaries of SSL targets (some or all can be None)
        augment : bool, default False
            if True, augmentations are applied
        subsample : list, optional
            a list of indices to subsample the input tensors (if not provided, no subsampling is applied)

        Returns
        -------
        main_input : torch.Tensor
            the augmented tensor of the main input
        ssl_inputs : list, optional
            a list of augmented tensors of SSL inputs (some or all can be None)
        ssl_targets : list, optional
            a list of augmented tensors of SSL targets (some or all can be None)

        """
        if subsample is not None:
            original_len = list(main_input.values())[0].shape[-1]
            for key in main_input:
                main_input[key] = main_input[key][..., subsample]
            # subsample_ssl = sorted(random.sample(range(original_len), len(subsample)))
            for x in ssl_inputs + ssl_targets:
                if x is not None:
                    for key in x:
                        if len(x[key].shape) == 3 and x[key].shape[-1] == original_len:
                            x[key] = x[key][..., subsample]
        main_input, ssl_inputs, ssl_targets = self._apply_augmentations(
            main_input, ssl_inputs, ssl_targets, augment
        )
        meta = [None for _ in ssl_inputs]
        for i, x in enumerate(ssl_inputs):
            if type(x) is tuple:
                x, meta_x = x
                meta[i] = meta_x
                ssl_inputs[i] = x
        for (i, ssl_x), generate in zip(enumerate(ssl_inputs), self.generate_ssl_input):
            if ssl_x is None and generate:
                ssl_inputs[i] = self._apply_augmentations(
                    deepcopy(main_input), None, None, augment=True, ssl=True
                )[0]
        output = []
        num_ssl = len(ssl_inputs)
        dicts = [main_input] + ssl_inputs + ssl_targets

        for x in dicts:
            if x is None:
                output.append(None)
            else:
                output.append(self._make_tensor(x, self.model_name))
                # output.append(
                #     torch.cat([x[key] for key in sorted(list(x.keys()))], dim=1)
                # )
        main_input, ssl_inputs, ssl_targets = (
            output[0],
            output[1 : num_ssl + 1],
            output[num_ssl + 1 :],
        )
        for (i, ssl_x), keep in zip(enumerate(ssl_targets), self.keep_target_none):
            if not keep and ssl_x is None:
                ssl_targets[i] = main_input
        for i, meta_x in enumerate(meta):
            if meta_x is not None:
                ssl_inputs[i] = (ssl_inputs[i], meta_x)
        return main_input, ssl_inputs, ssl_targets

    @abstractmethod
    def _augmentations_dict(self) -> Dict:
        """
        Return a dictionary of possible augmentations

        The keys are augmentation names and the values are the corresponding functions

        Returns
        -------
        augmentations_dict : dict
            a dictionary of augmentation functions (each function needs to take
            `(main_input: dict, ssl_inputs: list, ssl_targets: list)` as input and return an output
            of the same format)
        """

    @abstractmethod
    def _default_augmentations(self) -> List:
        """
        Return a list of default augmentation names

        In case an augmentation list is not provided to the class constructor and use_default_augmentations is `True`,
        this function is called to set the augmentations parameter. The elements of the list have to be keys of the
        dictionary returned by `self._augmentations_dict()`

        Returns
        -------
        default_augmentations : list
            a list of string names of the default augmentation functions
        """

    def _check_augmentations(self, augmentations: List) -> None:
        """
        Check the validity of an augmentations list
        """

        for aug_name in augmentations:
            if aug_name not in self._augmentations_dict():
                raise ValueError(
                    f"The {aug_name} augmentation is not possible in this augmentor! Please choose from {list(self._augmentations_dict().keys())}"
                )

    def _get_augmentation(self, aug_name: str) -> Callable:
        """
        Return the augmentation specified by `aug_name`
        """

        return self._augmentations_dict()[aug_name]

    def _visualize(self, main_input, title):
        coord_keys = [x for x in main_input.keys() if x.split("---")[0] == "coords"]
        if len(coord_keys) > 0:
            coords = [main_input[key][0, :, :, 0].detach().cpu() for key in coord_keys]
            centers = [None]
        else:
            coord_keys = [
                x for x in main_input.keys() if x.split("---")[0] == "coord_diff"
            ]
            coords = []
            centers = []
            for coord_diff_key in coord_keys:
                if len(coord_diff_key.split("---")) == 2:
                    ind = coord_diff_key.split("---")[1]
                    center_key = f"center---{ind}"
                else:
                    center_key = "center"
                    ind = ""
                if center_key in main_input.keys():
                    coords.append(
                        main_input[center_key][0, :, :, 0].detach().cpu()
                        + main_input[coord_diff_key][0, :, :, 0].detach().cpu()
                    )
                    title += f", {ind}: {main_input[center_key][0, 0, :, 0].data}"
                    center = main_input[center_key][0, :, :, 0].detach().cpu()
                else:
                    coords.append(main_input[coord_diff_key][0, :, :, 0].detach().cpu())
                    center = None
                centers.append(center)
        colors = ["blue", "orange", "green", "purple", "pink"]
        if coords[0].shape[1] == 2:
            plt.figure(figsize=(15, 15))
        else:
            fig = plt.figure(figsize=(15, 15))
            ax = fig.add_subplot(projection="3d")
        for i, coord in enumerate(coords):
            if coord.shape[1] == 2:
                plt.scatter(coord[:, 0], coord[:, 1], color=colors[i])
                plt.xlim((-0.5, 0.5))
                plt.ylim((-0.5, 0.5))
            else:
                ax.scatter(
                    coord[:, 0].detach().cpu(),
                    coord[:, 1].detach().cpu(),
                    coord[:, 2].detach().cpu(),
                    color=colors[i],
                )
                if centers[i] is not None:
                    ax.scatter(
                        centers[i][:, 0].detach().cpu(),
                        centers[i][:, 1].detach().cpu(),
                        0,
                        color="red",
                        s=30,
                    )
                    center = centers[i][0].detach().cpu()
                    ax.text(
                        center[0], center[1], 0, f"({center[0]:.2f}, {center[1]:.2f})"
                    )
                for i in [1, 8]:
                    ax.scatter(
                        coord[i : i + 1, 0].detach().cpu(),
                        coord[i : i + 1, 1].detach().cpu(),
                        coord[i : i + 1, 2].detach().cpu(),
                        color="purple",
                    )
                    ax.text(
                        coord[i, 0],
                        coord[i, 1],
                        coord[i, 2],
                        f"({coord[i, 0]:.2f}, {coord[i, 1]:.2f}, {coord[i, 2]:.2f})",
                    )

                plt.xlim((-3, 3))
                plt.ylim((-3, 3))
                ax.set_zlim(-2, 4)
                ax.set_xlabel("x")
                ax.set_ylabel("y")
                ax.set_zlabel("z")
        plt.title(title)
        plt.show()

    def _apply_augmentations(
        self,
        main_input: Dict,
        ssl_inputs: List = None,
        ssl_targets: List = None,
        augment: bool = False,
        ssl: bool = False,
    ) -> Tuple:
        """
        Apply the augmentations

        The same augmentations are applied to all inputs
        """

        visualize = False
        if ssl:
            augmentations = self.ssl_augmentations
        else:
            augmentations = self.augmentations
        if visualize:
            self._visualize(main_input, "before")
        if ssl_inputs is None:
            ssl_inputs = [None]
        if ssl_targets is None:
            ssl_targets = [None]
        if augment:
            for aug_name in augmentations:
                augment_func = self._get_augmentation(aug_name)
                main_input, ssl_inputs, ssl_targets = augment_func(
                    main_input, ssl_inputs, ssl_targets
                )
                if visualize:
                    self._visualize(main_input, aug_name)
        return main_input, ssl_inputs, ssl_targets

    def _make_tensor(self, x: Dict, model_name: str) -> Union[torch.Tensor, TensorList]:
        """
        Turn a feature dictionary into a tensor or a `dlc2action.utils.TensorList` object
        """

        if model_name == "ms_tcn_p":
            keys = sorted(list(x.keys()))
            groups = [key.split("---")[-1] for key in keys]
            unique_groups = sorted(set(groups))
            tensor = TensorList()
            for group in unique_groups:
                if not self.graph_features:
                    tensor.append(
                        torch.cat(
                            [x[key] for key, g in zip(keys, groups) if g == group],
                            dim=2,
                        )
                    )
                else:
                    tensor.append(
                        torch.cat(
                            [
                                x[key].reshape(
                                    (
                                        x[key].shape[0],
                                        self.num_graph_nodes,
                                        -1,
                                        x[key].shape[-1],
                                    )
                                )
                                for key, g in zip(keys, groups)
                                if g == group
                            ],
                            dim=2,
                        )
                    )
                    tensor[-1] = tensor[-1].reshape(
                        (tensor[-1].shape[0], -1, tensor[-1].shape[-1])
                    )
            if "loaded" in x:
                tensor.append(x["loaded"])
        elif model_name == "c2f_tcn_p":
            keys = sorted(
                [
                    key
                    for key in x.keys()
                    if len(key.split("---")) != 1
                    and len(key.split("---")[-1].split("+")) != 2
                ]
            )
            inds = [key.split("---")[-1] for key in keys]
            unique_inds = sorted(set(inds))
            tensor = TensorList()
            for ind in unique_inds:
                if not self.graph_features:
                    tensor.append(
                        torch.cat(
                            [x[key] for key, g in zip(keys, inds) if g == ind],
                            dim=1,
                        )
                    )
                else:
                    tensor.append(
                        torch.cat(
                            [
                                x[key].reshape(
                                    (
                                        x[key].shape[0],
                                        self.num_graph_nodes,
                                        -1,
                                        x[key].shape[-1],
                                    )
                                )
                                for key, g in zip(keys, inds)
                                if g == ind
                            ],
                            dim=1,
                        )
                    )
                    tensor[-1] = tensor[-1].reshape(
                        (tensor[-1].shape[0], -1, tensor[-1].shape[-1])
                    )
        elif model_name == "c3d_a":
            tensor = torch.cat([x[key] for key in sorted(list(x.keys()))], dim=1)
        else:
            if not self.graph_features:
                tensor = torch.cat([x[key] for key in sorted(list(x.keys()))], dim=1)
            else:
                tensor = torch.cat(
                    [
                        x[key].reshape(
                            (
                                x[key].shape[0],
                                self.num_graph_nodes,
                                -1,
                                x[key].shape[-1],
                            )
                        )
                        for key in sorted(list(x.keys()))
                    ],
                    dim=2,
                )
                tensor = tensor.reshape((tensor.shape[0], -1, tensor.shape[-1]))
        return tensor


class EmptyTransformer(Transformer):
    """
    Empty transformer class that does not apply augmentations
    """

    def _augmentations_dict(self) -> Dict:
        """
        Return a dictionary of possible augmentations

        The keys are augmentation names and the values are the corresponding functions

        Returns
        -------
        augmentations_dict : dict
            a dictionary of augmentation functions
        """

        return {}

    def _default_augmentations(self) -> List:
        """
        Return a list of default augmentation names

        In case an augmentation list is not provided to the class constructor and use_default_augmentations is True,
        this function is called to set the augmentations parameter. The elements of the list have to be keys of the
        dictionary returned by `self._augmentations_dict()`

        Returns
        -------
        default_augmentations : list
            a list of string names of the default augmentation functions
        """

        return []
class Transformer(ABC)
A base class for all transformers
A transformer should apply augmentations and generate model input and training target tensors.
All augmentation functions need to take (main_input: dict, ssl_inputs: list, ssl_targets: list)
as input and return an output of the same format. Here main_input is a feature dictionary of the sample
data, ssl_inputs is a list of SSL input feature dictionaries and ssl_targets is a list of SSL target
feature dictionaries. The same augmentations are applied to all inputs; None values are then replaced
according to the rules set by the keep_target_none and generate_ssl_input parameters, and the feature
dictionaries are compiled into tensors.
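For illustration, an augmentation function that follows this contract could look like the sketch below (the function name, the noise scale, and the assumption that all feature tensors are float-valued are hypothetical, not part of the library). In practice an augmentation would typically also reuse the same random parameters for every dictionary so that inputs and targets stay consistent.

import torch

def add_noise(main_input: dict, ssl_inputs: list, ssl_targets: list):
    # a hypothetical augmentation: jitter every float-valued feature with Gaussian noise
    def _noise(feature_dict):
        if feature_dict is None:
            return None
        return {
            key: value + 0.01 * torch.randn_like(value)
            if value.is_floating_point()
            else value
            for key, value in feature_dict.items()
        }

    main_input = _noise(main_input)
    ssl_inputs = [_noise(d) for d in ssl_inputs]
    ssl_targets = [_noise(d) for d in ssl_targets]
    return main_input, ssl_inputs, ssl_targets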
def __init__(self, model_name: str, augmentations: List = None, use_default_augmentations: bool = False, generate_ssl_input: List = None, keep_target_none: List = None, ssl_augmentations: List = None, graph_features: bool = False, bodyparts_order: List = None) -> None
Parameters
model_name : str
the name of the model for which the input is generated (it determines how the feature dictionaries are compiled into tensors)
augmentations : list, optional
a list of string names of augmentations to use (if not provided, either no augmentations are applied or,
if use_default_augmentations is True, a default list is used)
use_default_augmentations : bool, default False
if True and augmentations are not passed, default augmentations will be applied; otherwise no augmentations are used
generate_ssl_input : list, optional
a list of bool values of the length of the number of SSL modules being used; if the corresponding bool value
is True, the SSL input will be generated as a new augmentation of the main input (if not provided, defaults to
False for each module)
keep_target_none : list, optional
a list of bool values of the length of the number of SSL modules being used; if the corresponding bool value
is False and the SSL target is None, the target is set to augmented main input (if not provided defaults
to True for each module)
ssl_augmentations : list, optional
a list of augmentation names to be applied when generating SSL input (when generate_ssl_input is True)
(if not provided, defaults to the main augmentations list)
graph_features : bool, default False
if True, all features in each frame can be meaningfully reshaped to (#bodyparts, #features)
bodyparts_order : list, optional
a list of bodypart names (used to set the number of graph nodes when graph_features is True)
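To make these pieces concrete, here is a minimal sketch of a subclass and its construction; it builds on the hypothetical add_noise function above, and the "noise" augmentation name and "mymodel" model name are made up for illustration.

from typing import Dict, List

from dlc2action.transformer.base_transformer import Transformer

class NoiseTransformer(Transformer):
    # a hypothetical transformer whose only known augmentation is add_noise
    def _augmentations_dict(self) -> Dict:
        return {"noise": add_noise}

    def _default_augmentations(self) -> List:
        return ["noise"]

transformer = NoiseTransformer(
    model_name="mymodel",        # placeholder model name (not one of the special-cased models)
    augmentations=["noise"],     # must be keys of _augmentations_dict()
    generate_ssl_input=[False],  # one flag per SSL module
    keep_target_none=[True],
)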
def transform(self, main_input: Dict, ssl_inputs: List = None, ssl_targets: List = None, augment: bool = False, subsample: List = None) -> Tuple
Apply augmentations and generate tensors from feature dictionaries.
The same augmentations are applied to all the inputs (if they have the features known to the transformer).
If generate_ssl_input is set to True for some of the SSL pairs, those SSL inputs will be generated as
another augmentation of main_input.
Unless keep_target_none is set to True, None SSL targets will be replaced with augmented main_input.
All features are stacked together to form a tensor of shape (#features, #frames) that can be passed to
a model.
Parameters
main_input : dict
the feature dictionary of the main input
ssl_inputs : list, optional
a list of feature dictionaries of SSL inputs (some or all can be None)
ssl_targets : list, optional
a list of feature dictionaries of SSL targets (some or all can be None)
augment : bool, default False
if True, augmentations are applied
subsample : list, optional
a list of indices to subsample the input tensors (if not provided, no subsampling is applied)
Returns
main_input : torch.Tensor
the augmented tensor of the main input
ssl_inputs : list, optional
a list of augmented tensors of SSL inputs (some or all can be None)
ssl_targets : list, optional
a list of augmented tensors of SSL targets (some or all can be None)
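A usage sketch, continuing the hypothetical NoiseTransformer from above; the feature names and the (batch, features, frames) shapes are assumptions for illustration.

import torch

# one sample with two hypothetical feature groups, 128 frames, batch size 1
sample = {
    "coords": torch.randn(1, 20, 128),
    "speed": torch.randn(1, 10, 128),
}

model_input, ssl_inputs, ssl_targets = transformer.transform(
    sample, ssl_inputs=[None], ssl_targets=[None], augment=True
)
# the sorted feature groups are concatenated along the feature axis,
# so model_input here is a (1, 30, 128) tensor ready to be passed to the model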
class EmptyTransformer(Transformer)
Empty transformer class that does not apply augmentations
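For example (a minimal sketch; "mymodel" is again a placeholder model name), EmptyTransformer can be used when no augmentation is needed and transform should only compile the feature dictionaries into model input tensors:

import torch

from dlc2action.transformer.base_transformer import EmptyTransformer

transformer = EmptyTransformer(model_name="mymodel")
# augment=True has no effect here because the augmentations dictionary is empty
model_input, _, _ = transformer.transform({"coords": torch.randn(1, 20, 128)}, augment=True)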