dlc2action.transformer.base_transformer
Abstract parent class for transformers
#
# Copyright 2020-2022 by A. Mathis Group and contributors. All rights reserved.
#
# This project and all its files are licensed under GNU AGPLv3 or later version. A copy is included in dlc2action/LICENSE.AGPL.
#
"""
Abstract parent class for transformers
"""

from typing import Dict, List, Callable, Union, Tuple
import torch
from abc import ABC, abstractmethod
from dlc2action.utils import TensorList
from copy import deepcopy
from matplotlib import pyplot as plt


class Transformer(ABC):
    """
    A base class for all transformers

    A transformer should apply augmentations and generate model input and training target tensors.

    All augmentation functions need to take `(main_input: dict, ssl_inputs: list, ssl_targets: list)`
    as input and return an output of the same format. Here `main_input` is a feature dictionary of the sample
    data, `ssl_inputs` is a list of SSL input feature dictionaries and `ssl_targets` is a list of SSL target
    feature dictionaries. The same augmentations are applied to all inputs and then `None` values are replaced
    according to the rules set by `keep_target_none` and `generate_ssl_input` parameters and the feature
    dictionaries are compiled into tensors.
    """

    def __init__(
        self,
        model_name: str,
        augmentations: List = None,
        use_default_augmentations: bool = False,
        generate_ssl_input: List = None,
        keep_target_none: List = None,
        ssl_augmentations: List = None,
        graph_features: bool = False,
        bodyparts_order: List = None,
    ) -> None:
        """
        Parameters
        ----------
        model_name : str
            the name of the model the tensors are generated for (selects the tensor layout in `_make_tensor`)
        augmentations : list, optional
            a list of string names of augmentations to use (if not provided, either no augmentations are applied or
            (if use_default_augmentations is True) a default list is used)
        use_default_augmentations : bool, default False
            if True and the augmentations list is not passed or is empty, default augmentations will be applied;
            otherwise no augmentations
        generate_ssl_input : list, optional
            a list of bool values of the length of the number of SSL modules being used; if the corresponding bool value
            is `True`, the ssl input will be generated as a new augmentation of main input (if not provided, no SSL
            inputs are generated)
        keep_target_none : list, optional
            a list of bool values of the length of the number of SSL modules being used; if the corresponding bool value
            is `False` and the SSL target is `None`, the target is set to augmented main input (if not provided,
            `[None]` is used, so a `None` target of the first SSL module is replaced with the main input and the
            targets of any further modules are left as they are)
        ssl_augmentations : list, optional
            a list of augmentation names to be applied when generating SSL input (when `generate_ssl_input` is True)
            (if not provided, defaults to the main augmentations list)
        graph_features : bool, default False
            if `True`, all features in each frame can be meaningfully reshaped to `(#bodyparts, #features)`
        bodyparts_order : list, optional
            a list of bodypart names; its length is used as the number of graph nodes

        Raises
        ------
        ValueError
            if an augmentation name is not a key of `self._augmentations_dict()`
        """

        if augmentations is None:
            augmentations = []
        if generate_ssl_input is None:
            generate_ssl_input = [None]
        if keep_target_none is None:
            # NOTE(review): `None` is falsy, so this default makes `transform` replace a missing target of the
            # first SSL module with the main input; confirm this is the intended default
            keep_target_none = [None]
        self.model_name = model_name
        self.augmentations = augmentations
        self.generate_ssl_input = generate_ssl_input
        self.keep_target_none = keep_target_none
        if len(self.augmentations) == 0 and use_default_augmentations:
            self.augmentations = self._default_augmentations()
        if ssl_augmentations is None:
            ssl_augmentations = self.augmentations
        self.ssl_augmentations = ssl_augmentations
        self.graph_features = graph_features
        self.num_graph_nodes = (
            len(bodyparts_order) if bodyparts_order is not None else None
        )
        self._check_augmentations(self.augmentations)
        self._check_augmentations(self.ssl_augmentations)

    def transform(
        self,
        main_input: Dict,
        ssl_inputs: List = None,
        ssl_targets: List = None,
        augment: bool = False,
        subsample: List = None,
    ) -> Tuple:
        """
        Apply augmentations and generate tensors from feature dictionaries

        The same augmentations are applied to all the inputs (if they have the features known to the transformer).

        If `generate_ssl_input` is set to True for some of the SSL pairs, those SSL inputs will be generated as
        another augmentation of main_input.
        Unless `keep_target_none` is set to True, `None` SSL targets will be replaced with augmented `main_input`.
        All features are then stacked into a tensor (or a `TensorList`, depending on the model; see
        `_make_tensor`) that can be passed to a model.

        Parameters
        ----------
        main_input : dict
            the feature dictionary of the main input
        ssl_inputs : list, optional
            a list of feature dictionaries of SSL inputs (some or all can be None); an element can also be a
            `(feature_dictionary, meta)` tuple, in which case the meta is re-attached to the output tensor
        ssl_targets : list, optional
            a list of feature dictionaries of SSL targets (some or all can be None)
        augment : bool, default False
            if `True`, the augmentations are applied
        subsample : list, optional
            indices of the frames to keep; every main input feature and every three-dimensional SSL feature
            with the same number of frames as the main input is indexed along its last axis
            (the dictionaries are modified in place)

        Returns
        -------
        main_input : torch.Tensor
            the augmented tensor of the main input
        ssl_inputs : list, optional
            a list of augmented tensors of SSL inputs (some or all can be None)
        ssl_targets : list, optional
            a list of augmented tensors of SSL targets (some or all can be None)
        """

        # Normalise missing SSL lists up front (exactly as `_apply_augmentations` does) so that the
        # subsampling step below can iterate over them instead of failing on `None + None`
        if ssl_inputs is None:
            ssl_inputs = [None]
        if ssl_targets is None:
            ssl_targets = [None]
        if subsample is not None:
            original_len = list(main_input.values())[0].shape[-1]
            for key in main_input:
                main_input[key] = main_input[key][..., subsample]
            for x in ssl_inputs + ssl_targets:
                if x is not None:
                    for key in x:
                        if len(x[key].shape) == 3 and x[key].shape[-1] == original_len:
                            x[key] = x[key][..., subsample]
        main_input, ssl_inputs, ssl_targets = self._apply_augmentations(
            main_input, ssl_inputs, ssl_targets, augment
        )
        # `(features, meta)` tuples are split here and the meta is re-attached to the tensor at the end
        meta = [None for _ in ssl_inputs]
        for i, x in enumerate(ssl_inputs):
            if type(x) is tuple:
                x, meta_x = x
                meta[i] = meta_x
                ssl_inputs[i] = x
        for (i, ssl_x), generate in zip(enumerate(ssl_inputs), self.generate_ssl_input):
            if ssl_x is None and generate:
                ssl_inputs[i] = self._apply_augmentations(
                    deepcopy(main_input), None, None, augment=True, ssl=True
                )[0]
        num_ssl = len(ssl_inputs)
        dicts = [main_input] + ssl_inputs + ssl_targets
        output = [
            None if x is None else self._make_tensor(x, self.model_name)
            for x in dicts
        ]
        main_input, ssl_inputs, ssl_targets = (
            output[0],
            output[1 : num_ssl + 1],
            output[num_ssl + 1 :],
        )
        for (i, ssl_x), keep in zip(enumerate(ssl_targets), self.keep_target_none):
            if not keep and ssl_x is None:
                ssl_targets[i] = main_input
        for i, meta_x in enumerate(meta):
            if meta_x is not None:
                ssl_inputs[i] = (ssl_inputs[i], meta_x)
        return main_input, ssl_inputs, ssl_targets

    @abstractmethod
    def _augmentations_dict(self) -> Dict:
        """
        Return a dictionary of possible augmentations

        The keys are augmentation names and the values are the corresponding functions

        Returns
        -------
        augmentations_dict : dict
            a dictionary of augmentation functions (each function needs to take
            `(main_input: dict, ssl_inputs: list, ssl_targets: list)` as input and return an output
            of the same format)
        """

    @abstractmethod
    def _default_augmentations(self) -> List:
        """
        Return a list of default augmentation names

        In case an augmentation list is not provided to the class constructor and use_default_augmentations is `True`,
        this function is called to set the augmentations parameter. The elements of the list have to be keys of the
        dictionary returned by `self._augmentations_dict()`

        Returns
        -------
        default_augmentations : list
            a list of string names of the default augmentation functions
        """

    def _check_augmentations(self, augmentations: List) -> None:
        """
        Check the validity of an augmentations list

        Raises
        ------
        ValueError
            if a name in `augmentations` is not a key of `self._augmentations_dict()`
        """

        available = self._augmentations_dict()
        for aug_name in augmentations:
            if aug_name not in available:
                raise ValueError(
                    f"The {aug_name} augmentation is not possible in this augmentor! Please choose from {list(available.keys())}"
                )

    def _get_augmentation(self, aug_name: str) -> Callable:
        """
        Return the augmentation specified by `aug_name`
        """

        return self._augmentations_dict()[aug_name]

    def _visualize(self, main_input, title):
        """
        Plot the coordinates stored in a feature dictionary (debugging helper)

        Coordinates are read from the `[0, :, :, 0]` slice of the `"coords"` features or, if there are none,
        reconstructed as `center + coord_diff` (or `coord_diff` alone when the matching `"center"` feature is
        missing). Two-column coordinates are drawn in a 2D scatter plot, others in a 3D plot. Nothing is drawn
        when the dictionary has no coordinate features.

        Parameters
        ----------
        main_input : dict
            a feature dictionary
        title : str
            the plot title
        """

        coord_keys = [x for x in main_input.keys() if x.split("---")[0] == "coords"]
        if len(coord_keys) > 0:
            coords = [main_input[key][0, :, :, 0].detach().cpu() for key in coord_keys]
            # one (absent) center per coordinate set so that `centers[i]` is valid for every `i`
            centers = [None] * len(coords)
        else:
            coord_keys = [
                x for x in main_input.keys() if x.split("---")[0] == "coord_diff"
            ]
            coords = []
            centers = []
            for coord_diff_key in coord_keys:
                if len(coord_diff_key.split("---")) == 2:
                    ind = coord_diff_key.split("---")[1]
                    center_key = f"center---{ind}"
                else:
                    center_key = "center"
                    ind = ""
                if center_key in main_input.keys():
                    coords.append(
                        main_input[center_key][0, :, :, 0].detach().cpu()
                        + main_input[coord_diff_key][0, :, :, 0].detach().cpu()
                    )
                    title += f", {ind}: {main_input[center_key][0, 0, :, 0].data}"
                    center = main_input[center_key][0, :, :, 0].detach().cpu()
                else:
                    coords.append(main_input[coord_diff_key][0, :, :, 0].detach().cpu())
                    center = None
                centers.append(center)
        if len(coords) == 0:
            return
        colors = ["blue", "orange", "green", "purple", "pink"]
        if coords[0].shape[1] == 2:
            plt.figure(figsize=(15, 15))
        else:
            fig = plt.figure(figsize=(15, 15))
            ax = fig.add_subplot(projection="3d")
        for i, coord in enumerate(coords):
            color = colors[i % len(colors)]
            if coord.shape[1] == 2:
                plt.scatter(coord[:, 0], coord[:, 1], color=color)
                plt.xlim((-0.5, 0.5))
                plt.ylim((-0.5, 0.5))
            else:
                ax.scatter(
                    coord[:, 0].detach().cpu(),
                    coord[:, 1].detach().cpu(),
                    coord[:, 2].detach().cpu(),
                    color=color,
                )
                if centers[i] is not None:
                    ax.scatter(
                        centers[i][:, 0].detach().cpu(),
                        centers[i][:, 1].detach().cpu(),
                        0,
                        color="red",
                        s=30,
                    )
                    center = centers[i][0].detach().cpu()
                    ax.text(
                        center[0], center[1], 0, f"({center[0]:.2f}, {center[1]:.2f})"
                    )
                # annotate two fixed rows (indices 1 and 8) when the coordinate set has them
                for j in (1, 8):
                    if j >= coord.shape[0]:
                        continue
                    ax.scatter(
                        coord[j : j + 1, 0].detach().cpu(),
                        coord[j : j + 1, 1].detach().cpu(),
                        coord[j : j + 1, 2].detach().cpu(),
                        color="purple",
                    )
                    ax.text(
                        coord[j, 0],
                        coord[j, 1],
                        coord[j, 2],
                        f"({coord[j, 0]:.2f}, {coord[j, 1]:.2f}, {coord[j, 2]:.2f})",
                    )

                plt.xlim((-3, 3))
                plt.ylim((-3, 3))
                ax.set_zlim(-2, 4)
                ax.set_xlabel("x")
                ax.set_ylabel("y")
                ax.set_zlabel("z")
        plt.title(title)
        plt.show()

    def _apply_augmentations(
        self,
        main_input: Dict,
        ssl_inputs: List = None,
        ssl_targets: List = None,
        augment: bool = False,
        ssl: bool = False,
    ) -> Tuple:
        """
        Apply the augmentations

        The same augmentations are applied to all inputs

        Parameters
        ----------
        main_input : dict
            the feature dictionary of the main input
        ssl_inputs : list, optional
            SSL input dictionaries (`[None]` if not provided)
        ssl_targets : list, optional
            SSL target dictionaries (`[None]` if not provided)
        augment : bool, default False
            if `False`, the inputs are returned without augmentation
        ssl : bool, default False
            if `True`, `self.ssl_augmentations` is used instead of `self.augmentations`

        Returns
        -------
        main_input, ssl_inputs, ssl_targets
            the (possibly augmented) inputs in the same format
        """

        # debugging switch: set to True to plot the input before and after every augmentation
        visualize = False
        if ssl:
            augmentations = self.ssl_augmentations
        else:
            augmentations = self.augmentations
        if visualize:
            self._visualize(main_input, "before")
        if ssl_inputs is None:
            ssl_inputs = [None]
        if ssl_targets is None:
            ssl_targets = [None]
        if augment:
            for aug_name in augmentations:
                augment_func = self._get_augmentation(aug_name)
                main_input, ssl_inputs, ssl_targets = augment_func(
                    main_input, ssl_inputs, ssl_targets
                )
                if visualize:
                    self._visualize(main_input, aug_name)
        return main_input, ssl_inputs, ssl_targets

    def _concat_features(
        self, features: List, flat_dim: int, graph_dim: int
    ) -> torch.Tensor:
        """
        Concatenate feature tensors, optionally per graph node

        Without graph features the tensors are concatenated along `flat_dim`. With graph features each tensor
        is reshaped to `(shape[0], self.num_graph_nodes, -1, shape[-1])`, the results are concatenated along
        `graph_dim` and flattened back to `(shape[0], -1, shape[-1])`.
        """

        if not self.graph_features:
            return torch.cat(features, dim=flat_dim)
        stacked = torch.cat(
            [
                feature.reshape(
                    (feature.shape[0], self.num_graph_nodes, -1, feature.shape[-1])
                )
                for feature in features
            ],
            dim=graph_dim,
        )
        return stacked.reshape((stacked.shape[0], -1, stacked.shape[-1]))

    def _make_tensor(
        self, x: Dict, model_name: str
    ) -> "Union[torch.Tensor, TensorList]":
        """
        Turn a feature dictionary into a tensor or a `dlc2action.utils.TensorList` object

        Features are always taken in sorted key order. The layout depends on `model_name`:

        - `"ms_tcn_p"`: features are grouped by the key suffix after the last `"---"`; each group is
          concatenated along dimension 2 into one `TensorList` element and the `"loaded"` feature, if present,
          is appended at the end,
        - `"c2f_tcn_p"`: only keys that contain `"---"` and whose suffix is not made of exactly two
          `"+"`-separated parts are used; they are grouped by suffix and each group is concatenated along
          dimension 1 into one `TensorList` element,
        - `"c3d_a"`: all features are concatenated along dimension 1,
        - any other model: all features are concatenated along dimension 1 (with graph features, per node
          along dimension 2 and then flattened).

        Parameters
        ----------
        x : dict
            a feature dictionary
        model_name : str
            the model name that selects the layout

        Returns
        -------
        tensor : torch.Tensor or dlc2action.utils.TensorList
            the compiled input
        """

        if model_name == "ms_tcn_p":
            keys = sorted(x.keys())
            groups = [key.split("---")[-1] for key in keys]
            tensor = TensorList()
            for group in sorted(set(groups)):
                features = [x[key] for key, g in zip(keys, groups) if g == group]
                tensor.append(self._concat_features(features, flat_dim=2, graph_dim=2))
            # NOTE(review): "loaded" is also one of `keys` (it forms its own group above), so it
            # ends up in the list twice; confirm this is intended
            if "loaded" in x:
                tensor.append(x["loaded"])
        elif model_name == "c2f_tcn_p":
            keys = sorted(
                [
                    key
                    for key in x.keys()
                    if len(key.split("---")) != 1
                    and len(key.split("---")[-1].split("+")) != 2
                ]
            )
            inds = [key.split("---")[-1] for key in keys]
            tensor = TensorList()
            for ind in sorted(set(inds)):
                features = [x[key] for key, g in zip(keys, inds) if g == ind]
                tensor.append(self._concat_features(features, flat_dim=1, graph_dim=1))
        elif model_name == "c3d_a":
            tensor = torch.cat([x[key] for key in sorted(x.keys())], dim=1)
        else:
            tensor = self._concat_features(
                [x[key] for key in sorted(x.keys())], flat_dim=1, graph_dim=2
            )
        return tensor


class EmptyTransformer(Transformer):
    """
    Empty transformer class that does not apply augmentations
    """

    def _augmentations_dict(self) -> Dict:
        """
        Return a dictionary of possible augmentations

        The keys are augmentation names and the values are the corresponding functions

        Returns
        -------
        augmentations_dict : dict
            a dictionary of augmentation functions (always empty for this class)
        """

        return {}

    def _default_augmentations(self) -> List:
        """
        Return a list of default augmentation names

        In case an augmentation list is not provided to the class constructor and use_default_augmentations is True,
        this function is called to set the augmentations parameter. The elements of the list have to be keys of the
        dictionary returned by `self._augmentations_dict()`

        Returns
        -------
        default_augmentations : list
            a list of string names of the default augmentation functions (always empty for this class)
        """

        return []
class Transformer(ABC):
    """
    A base class for all transformers

    A transformer should apply augmentations and generate model input and training target tensors.

    All augmentation functions need to take `(main_input: dict, ssl_inputs: list, ssl_targets: list)`
    as input and return an output of the same format. Here `main_input` is a feature dictionary of the sample
    data, `ssl_inputs` is a list of SSL input feature dictionaries and `ssl_targets` is a list of SSL target
    feature dictionaries. The same augmentations are applied to all inputs and then `None` values are replaced
    according to the rules set by `keep_target_none` and `generate_ssl_input` parameters and the feature
    dictionaries are compiled into tensors.
    """

    def __init__(
        self,
        model_name: str,
        augmentations: List = None,
        use_default_augmentations: bool = False,
        generate_ssl_input: List = None,
        keep_target_none: List = None,
        ssl_augmentations: List = None,
        graph_features: bool = False,
        bodyparts_order: List = None,
    ) -> None:
        """
        Parameters
        ----------
        model_name : str
            the name of the model the tensors are generated for (selects the tensor layout in `_make_tensor`)
        augmentations : list, optional
            a list of string names of augmentations to use (if not provided, either no augmentations are applied or
            (if use_default_augmentations is True) a default list is used)
        use_default_augmentations : bool, default False
            if True and the augmentations list is not passed or is empty, default augmentations will be applied;
            otherwise no augmentations
        generate_ssl_input : list, optional
            a list of bool values of the length of the number of SSL modules being used; if the corresponding bool value
            is `True`, the ssl input will be generated as a new augmentation of main input (if not provided, no SSL
            inputs are generated)
        keep_target_none : list, optional
            a list of bool values of the length of the number of SSL modules being used; if the corresponding bool value
            is `False` and the SSL target is `None`, the target is set to augmented main input (if not provided,
            `[None]` is used, so a `None` target of the first SSL module is replaced with the main input and the
            targets of any further modules are left as they are)
        ssl_augmentations : list, optional
            a list of augmentation names to be applied when generating SSL input (when `generate_ssl_input` is True)
            (if not provided, defaults to the main augmentations list)
        graph_features : bool, default False
            if `True`, all features in each frame can be meaningfully reshaped to `(#bodyparts, #features)`
        bodyparts_order : list, optional
            a list of bodypart names; its length is used as the number of graph nodes

        Raises
        ------
        ValueError
            if an augmentation name is not a key of `self._augmentations_dict()`
        """

        if augmentations is None:
            augmentations = []
        if generate_ssl_input is None:
            generate_ssl_input = [None]
        if keep_target_none is None:
            # NOTE(review): `None` is falsy, so this default makes `transform` replace a missing target of the
            # first SSL module with the main input; confirm this is the intended default
            keep_target_none = [None]
        self.model_name = model_name
        self.augmentations = augmentations
        self.generate_ssl_input = generate_ssl_input
        self.keep_target_none = keep_target_none
        if len(self.augmentations) == 0 and use_default_augmentations:
            self.augmentations = self._default_augmentations()
        if ssl_augmentations is None:
            ssl_augmentations = self.augmentations
        self.ssl_augmentations = ssl_augmentations
        self.graph_features = graph_features
        self.num_graph_nodes = (
            len(bodyparts_order) if bodyparts_order is not None else None
        )
        self._check_augmentations(self.augmentations)
        self._check_augmentations(self.ssl_augmentations)

    def transform(
        self,
        main_input: Dict,
        ssl_inputs: List = None,
        ssl_targets: List = None,
        augment: bool = False,
        subsample: List = None,
    ) -> Tuple:
        """
        Apply augmentations and generate tensors from feature dictionaries

        The same augmentations are applied to all the inputs (if they have the features known to the transformer).

        If `generate_ssl_input` is set to True for some of the SSL pairs, those SSL inputs will be generated as
        another augmentation of main_input.
        Unless `keep_target_none` is set to True, `None` SSL targets will be replaced with augmented `main_input`.
        All features are then stacked into a tensor (or a `TensorList`, depending on the model; see
        `_make_tensor`) that can be passed to a model.

        Parameters
        ----------
        main_input : dict
            the feature dictionary of the main input
        ssl_inputs : list, optional
            a list of feature dictionaries of SSL inputs (some or all can be None); an element can also be a
            `(feature_dictionary, meta)` tuple, in which case the meta is re-attached to the output tensor
        ssl_targets : list, optional
            a list of feature dictionaries of SSL targets (some or all can be None)
        augment : bool, default False
            if `True`, the augmentations are applied
        subsample : list, optional
            indices of the frames to keep; every main input feature and every three-dimensional SSL feature
            with the same number of frames as the main input is indexed along its last axis
            (the dictionaries are modified in place)

        Returns
        -------
        main_input : torch.Tensor
            the augmented tensor of the main input
        ssl_inputs : list, optional
            a list of augmented tensors of SSL inputs (some or all can be None)
        ssl_targets : list, optional
            a list of augmented tensors of SSL targets (some or all can be None)
        """

        # Normalise missing SSL lists up front (exactly as `_apply_augmentations` does) so that the
        # subsampling step below can iterate over them instead of failing on `None + None`
        if ssl_inputs is None:
            ssl_inputs = [None]
        if ssl_targets is None:
            ssl_targets = [None]
        if subsample is not None:
            original_len = list(main_input.values())[0].shape[-1]
            for key in main_input:
                main_input[key] = main_input[key][..., subsample]
            for x in ssl_inputs + ssl_targets:
                if x is not None:
                    for key in x:
                        if len(x[key].shape) == 3 and x[key].shape[-1] == original_len:
                            x[key] = x[key][..., subsample]
        main_input, ssl_inputs, ssl_targets = self._apply_augmentations(
            main_input, ssl_inputs, ssl_targets, augment
        )
        # `(features, meta)` tuples are split here and the meta is re-attached to the tensor at the end
        meta = [None for _ in ssl_inputs]
        for i, x in enumerate(ssl_inputs):
            if type(x) is tuple:
                x, meta_x = x
                meta[i] = meta_x
                ssl_inputs[i] = x
        for (i, ssl_x), generate in zip(enumerate(ssl_inputs), self.generate_ssl_input):
            if ssl_x is None and generate:
                ssl_inputs[i] = self._apply_augmentations(
                    deepcopy(main_input), None, None, augment=True, ssl=True
                )[0]
        num_ssl = len(ssl_inputs)
        dicts = [main_input] + ssl_inputs + ssl_targets
        output = [
            None if x is None else self._make_tensor(x, self.model_name)
            for x in dicts
        ]
        main_input, ssl_inputs, ssl_targets = (
            output[0],
            output[1 : num_ssl + 1],
            output[num_ssl + 1 :],
        )
        for (i, ssl_x), keep in zip(enumerate(ssl_targets), self.keep_target_none):
            if not keep and ssl_x is None:
                ssl_targets[i] = main_input
        for i, meta_x in enumerate(meta):
            if meta_x is not None:
                ssl_inputs[i] = (ssl_inputs[i], meta_x)
        return main_input, ssl_inputs, ssl_targets

    @abstractmethod
    def _augmentations_dict(self) -> Dict:
        """
        Return a dictionary of possible augmentations

        The keys are augmentation names and the values are the corresponding functions

        Returns
        -------
        augmentations_dict : dict
            a dictionary of augmentation functions (each function needs to take
            `(main_input: dict, ssl_inputs: list, ssl_targets: list)` as input and return an output
            of the same format)
        """

    @abstractmethod
    def _default_augmentations(self) -> List:
        """
        Return a list of default augmentation names

        In case an augmentation list is not provided to the class constructor and use_default_augmentations is `True`,
        this function is called to set the augmentations parameter. The elements of the list have to be keys of the
        dictionary returned by `self._augmentations_dict()`

        Returns
        -------
        default_augmentations : list
            a list of string names of the default augmentation functions
        """

    def _check_augmentations(self, augmentations: List) -> None:
        """
        Check the validity of an augmentations list

        Raises
        ------
        ValueError
            if a name in `augmentations` is not a key of `self._augmentations_dict()`
        """

        available = self._augmentations_dict()
        for aug_name in augmentations:
            if aug_name not in available:
                raise ValueError(
                    f"The {aug_name} augmentation is not possible in this augmentor! Please choose from {list(available.keys())}"
                )

    def _get_augmentation(self, aug_name: str) -> Callable:
        """
        Return the augmentation specified by `aug_name`
        """

        return self._augmentations_dict()[aug_name]

    def _visualize(self, main_input, title):
        """
        Plot the coordinates stored in a feature dictionary (debugging helper)

        Coordinates are read from the `[0, :, :, 0]` slice of the `"coords"` features or, if there are none,
        reconstructed as `center + coord_diff` (or `coord_diff` alone when the matching `"center"` feature is
        missing). Two-column coordinates are drawn in a 2D scatter plot, others in a 3D plot. Nothing is drawn
        when the dictionary has no coordinate features.

        Parameters
        ----------
        main_input : dict
            a feature dictionary
        title : str
            the plot title
        """

        coord_keys = [x for x in main_input.keys() if x.split("---")[0] == "coords"]
        if len(coord_keys) > 0:
            coords = [main_input[key][0, :, :, 0].detach().cpu() for key in coord_keys]
            # one (absent) center per coordinate set so that `centers[i]` is valid for every `i`
            centers = [None] * len(coords)
        else:
            coord_keys = [
                x for x in main_input.keys() if x.split("---")[0] == "coord_diff"
            ]
            coords = []
            centers = []
            for coord_diff_key in coord_keys:
                if len(coord_diff_key.split("---")) == 2:
                    ind = coord_diff_key.split("---")[1]
                    center_key = f"center---{ind}"
                else:
                    center_key = "center"
                    ind = ""
                if center_key in main_input.keys():
                    coords.append(
                        main_input[center_key][0, :, :, 0].detach().cpu()
                        + main_input[coord_diff_key][0, :, :, 0].detach().cpu()
                    )
                    title += f", {ind}: {main_input[center_key][0, 0, :, 0].data}"
                    center = main_input[center_key][0, :, :, 0].detach().cpu()
                else:
                    coords.append(main_input[coord_diff_key][0, :, :, 0].detach().cpu())
                    center = None
                centers.append(center)
        if len(coords) == 0:
            return
        colors = ["blue", "orange", "green", "purple", "pink"]
        if coords[0].shape[1] == 2:
            plt.figure(figsize=(15, 15))
        else:
            fig = plt.figure(figsize=(15, 15))
            ax = fig.add_subplot(projection="3d")
        for i, coord in enumerate(coords):
            color = colors[i % len(colors)]
            if coord.shape[1] == 2:
                plt.scatter(coord[:, 0], coord[:, 1], color=color)
                plt.xlim((-0.5, 0.5))
                plt.ylim((-0.5, 0.5))
            else:
                ax.scatter(
                    coord[:, 0].detach().cpu(),
                    coord[:, 1].detach().cpu(),
                    coord[:, 2].detach().cpu(),
                    color=color,
                )
                if centers[i] is not None:
                    ax.scatter(
                        centers[i][:, 0].detach().cpu(),
                        centers[i][:, 1].detach().cpu(),
                        0,
                        color="red",
                        s=30,
                    )
                    center = centers[i][0].detach().cpu()
                    ax.text(
                        center[0], center[1], 0, f"({center[0]:.2f}, {center[1]:.2f})"
                    )
                # annotate two fixed rows (indices 1 and 8) when the coordinate set has them
                for j in (1, 8):
                    if j >= coord.shape[0]:
                        continue
                    ax.scatter(
                        coord[j : j + 1, 0].detach().cpu(),
                        coord[j : j + 1, 1].detach().cpu(),
                        coord[j : j + 1, 2].detach().cpu(),
                        color="purple",
                    )
                    ax.text(
                        coord[j, 0],
                        coord[j, 1],
                        coord[j, 2],
                        f"({coord[j, 0]:.2f}, {coord[j, 1]:.2f}, {coord[j, 2]:.2f})",
                    )

                plt.xlim((-3, 3))
                plt.ylim((-3, 3))
                ax.set_zlim(-2, 4)
                ax.set_xlabel("x")
                ax.set_ylabel("y")
                ax.set_zlabel("z")
        plt.title(title)
        plt.show()

    def _apply_augmentations(
        self,
        main_input: Dict,
        ssl_inputs: List = None,
        ssl_targets: List = None,
        augment: bool = False,
        ssl: bool = False,
    ) -> Tuple:
        """
        Apply the augmentations

        The same augmentations are applied to all inputs

        Parameters
        ----------
        main_input : dict
            the feature dictionary of the main input
        ssl_inputs : list, optional
            SSL input dictionaries (`[None]` if not provided)
        ssl_targets : list, optional
            SSL target dictionaries (`[None]` if not provided)
        augment : bool, default False
            if `False`, the inputs are returned without augmentation
        ssl : bool, default False
            if `True`, `self.ssl_augmentations` is used instead of `self.augmentations`

        Returns
        -------
        main_input, ssl_inputs, ssl_targets
            the (possibly augmented) inputs in the same format
        """

        # debugging switch: set to True to plot the input before and after every augmentation
        visualize = False
        if ssl:
            augmentations = self.ssl_augmentations
        else:
            augmentations = self.augmentations
        if visualize:
            self._visualize(main_input, "before")
        if ssl_inputs is None:
            ssl_inputs = [None]
        if ssl_targets is None:
            ssl_targets = [None]
        if augment:
            for aug_name in augmentations:
                augment_func = self._get_augmentation(aug_name)
                main_input, ssl_inputs, ssl_targets = augment_func(
                    main_input, ssl_inputs, ssl_targets
                )
                if visualize:
                    self._visualize(main_input, aug_name)
        return main_input, ssl_inputs, ssl_targets

    def _concat_features(
        self, features: List, flat_dim: int, graph_dim: int
    ) -> torch.Tensor:
        """
        Concatenate feature tensors, optionally per graph node

        Without graph features the tensors are concatenated along `flat_dim`. With graph features each tensor
        is reshaped to `(shape[0], self.num_graph_nodes, -1, shape[-1])`, the results are concatenated along
        `graph_dim` and flattened back to `(shape[0], -1, shape[-1])`.
        """

        if not self.graph_features:
            return torch.cat(features, dim=flat_dim)
        stacked = torch.cat(
            [
                feature.reshape(
                    (feature.shape[0], self.num_graph_nodes, -1, feature.shape[-1])
                )
                for feature in features
            ],
            dim=graph_dim,
        )
        return stacked.reshape((stacked.shape[0], -1, stacked.shape[-1]))

    def _make_tensor(
        self, x: Dict, model_name: str
    ) -> "Union[torch.Tensor, TensorList]":
        """
        Turn a feature dictionary into a tensor or a `dlc2action.utils.TensorList` object

        Features are always taken in sorted key order. The layout depends on `model_name`:

        - `"ms_tcn_p"`: features are grouped by the key suffix after the last `"---"`; each group is
          concatenated along dimension 2 into one `TensorList` element and the `"loaded"` feature, if present,
          is appended at the end,
        - `"c2f_tcn_p"`: only keys that contain `"---"` and whose suffix is not made of exactly two
          `"+"`-separated parts are used; they are grouped by suffix and each group is concatenated along
          dimension 1 into one `TensorList` element,
        - `"c3d_a"`: all features are concatenated along dimension 1,
        - any other model: all features are concatenated along dimension 1 (with graph features, per node
          along dimension 2 and then flattened).

        Parameters
        ----------
        x : dict
            a feature dictionary
        model_name : str
            the model name that selects the layout

        Returns
        -------
        tensor : torch.Tensor or dlc2action.utils.TensorList
            the compiled input
        """

        if model_name == "ms_tcn_p":
            keys = sorted(x.keys())
            groups = [key.split("---")[-1] for key in keys]
            tensor = TensorList()
            for group in sorted(set(groups)):
                features = [x[key] for key, g in zip(keys, groups) if g == group]
                tensor.append(self._concat_features(features, flat_dim=2, graph_dim=2))
            # NOTE(review): "loaded" is also one of `keys` (it forms its own group above), so it
            # ends up in the list twice; confirm this is intended
            if "loaded" in x:
                tensor.append(x["loaded"])
        elif model_name == "c2f_tcn_p":
            keys = sorted(
                [
                    key
                    for key in x.keys()
                    if len(key.split("---")) != 1
                    and len(key.split("---")[-1].split("+")) != 2
                ]
            )
            inds = [key.split("---")[-1] for key in keys]
            tensor = TensorList()
            for ind in sorted(set(inds)):
                features = [x[key] for key, g in zip(keys, inds) if g == ind]
                tensor.append(self._concat_features(features, flat_dim=1, graph_dim=1))
        elif model_name == "c3d_a":
            tensor = torch.cat([x[key] for key in sorted(x.keys())], dim=1)
        else:
            tensor = self._concat_features(
                [x[key] for key in sorted(x.keys())], flat_dim=1, graph_dim=2
            )
        return tensor
A base class for all transformers
A transformer should apply augmentations and generate model input and training target tensors.
All augmentation functions need to take `(main_input: dict, ssl_inputs: list, ssl_targets: list)`
as input and return an output of the same format. Here `main_input` is a feature dictionary of the sample
data, `ssl_inputs` is a list of SSL input feature dictionaries and `ssl_targets` is a list of SSL target
feature dictionaries. The same augmentations are applied to all inputs, then `None` values are replaced
according to the rules set by the `keep_target_none` and `generate_ssl_input` parameters, and the feature
dictionaries are compiled into tensors.
def __init__(
    self,
    model_name: str,
    augmentations: List = None,
    use_default_augmentations: bool = False,
    generate_ssl_input: List = None,
    keep_target_none: List = None,
    ssl_augmentations: List = None,
    graph_features: bool = False,
    bodyparts_order: List = None,
) -> None:
    """
    Parameters
    ----------
    model_name : str
        the name of the model the tensors are generated for (selects the tensor layout in `_make_tensor`)
    augmentations : list, optional
        a list of string names of augmentations to use (if not provided, either no augmentations are applied or
        (if use_default_augmentations is True) a default list is used)
    use_default_augmentations : bool, default False
        if True and the augmentations list is not passed or is empty, default augmentations will be applied;
        otherwise no augmentations
    generate_ssl_input : list, optional
        a list of bool values of the length of the number of SSL modules being used; if the corresponding bool value
        is `True`, the ssl input will be generated as a new augmentation of main input (if not provided, no SSL
        inputs are generated)
    keep_target_none : list, optional
        a list of bool values of the length of the number of SSL modules being used; if the corresponding bool value
        is `False` and the SSL target is `None`, the target is set to augmented main input (if not provided,
        `[None]` is used, so a `None` target of the first SSL module is replaced with the main input and the
        targets of any further modules are left as they are)
    ssl_augmentations : list, optional
        a list of augmentation names to be applied when generating SSL input (when `generate_ssl_input` is True)
        (if not provided, defaults to the main augmentations list)
    graph_features : bool, default False
        if `True`, all features in each frame can be meaningfully reshaped to `(#bodyparts, #features)`
    bodyparts_order : list, optional
        a list of bodypart names; its length is used as the number of graph nodes

    Raises
    ------
    ValueError
        if an augmentation name is not a key of `self._augmentations_dict()`
    """

    if augmentations is None:
        augmentations = []
    if generate_ssl_input is None:
        generate_ssl_input = [None]
    if keep_target_none is None:
        # NOTE(review): `None` is falsy, so this default makes `transform` replace a missing target of the
        # first SSL module with the main input; confirm this is the intended default
        keep_target_none = [None]
    self.model_name = model_name
    self.augmentations = augmentations
    self.generate_ssl_input = generate_ssl_input
    self.keep_target_none = keep_target_none
    if len(self.augmentations) == 0 and use_default_augmentations:
        self.augmentations = self._default_augmentations()
    if ssl_augmentations is None:
        ssl_augmentations = self.augmentations
    self.ssl_augmentations = ssl_augmentations
    self.graph_features = graph_features
    self.num_graph_nodes = (
        len(bodyparts_order) if bodyparts_order is not None else None
    )
    self._check_augmentations(self.augmentations)
    self._check_augmentations(self.ssl_augmentations)
Parameters
augmentations : list, optional
a list of string names of augmentations to use (if not provided, either no augmentations are applied or
(if use_default_augmentations is True) a default list is used)
use_default_augmentations : bool, default False
if True and augmentations are not passed, default augmentations will be applied; otherwise no augmentations
generate_ssl_input : list, optional
a list of bool values of the length of the number of SSL modules being used; if the corresponding bool value
is True, the SSL input will be generated as a new augmentation of the main input (if not provided, defaults to
False for each module)
keep_target_none : list, optional
a list of bool values of the length of the number of SSL modules being used; if the corresponding bool value
is False and the SSL target is None, the target is set to the augmented main input (if not provided, [None]
is used, so only a None target of the first SSL module is replaced)
ssl_augmentations : list, optional
a list of augmentation names to be applied when generating SSL input (when generate_ssl_input is True)
(if not provided, defaults to the main augmentations list)
graph_features : bool, default False
if True, all features in each frame can be meaningfully reshaped to (#bodyparts, #features)
bodyparts_order : list, optional
a list of bodypart names, optional
def transform(
    self,
    main_input: Dict,
    ssl_inputs: List = None,
    ssl_targets: List = None,
    augment: bool = False,
    subsample: List = None,
) -> Tuple:
    """
    Apply augmentations and generate tensors from feature dictionaries

    The same augmentations are applied to all the inputs (if they have the features known to the transformer).

    If `generate_ssl_input` is set to True for some of the SSL pairs, those SSL inputs will be generated as
    another augmentation of the augmented `main_input`.
    Unless `keep_target_none` is set to True, `None` SSL targets will be replaced with augmented `main_input`.
    The resulting feature dictionaries are compiled into model input tensors with `self._make_tensor`.

    Parameters
    ----------
    main_input : dict
        the feature dictionary of the main input (modified in place when `subsample` is given)
    ssl_inputs : list, optional
        a list of feature dictionaries of SSL inputs (some or all can be `None`)
    ssl_targets : list, optional
        a list of feature dictionaries of SSL targets (some or all can be `None`)
    augment : bool, default False
        passed on to `self._apply_augmentations` for the main augmentation step
    subsample : list, optional
        frame indices to keep; if given, every main input feature is indexed with it along the last axis,
        and so is every 3-dimensional SSL feature whose last dimension matches the main input length
        (SSL dictionaries are also modified in place)

    Returns
    -------
    main_input : torch.Tensor
        the augmented tensor of the main input
    ssl_inputs : list
        a list of augmented tensors of SSL inputs (some or all can be `None`); if an SSL input came out of
        the augmentation step as a `(features, meta)` tuple, the output is a `(tensor, meta)` tuple
    ssl_targets : list
        a list of augmented tensors of SSL targets (some or all can be `None`)
    """

    if subsample is not None:
        original_len = list(main_input.values())[0].shape[-1]
        for key in main_input:
            main_input[key] = main_input[key][..., subsample]
        # only SSL features that are 3-dimensional and span the same frames as the main input are
        # subsampled; missing (`None`) SSL lists are skipped instead of failing on `None + None`
        for x in [*(ssl_inputs or []), *(ssl_targets or [])]:
            if x is not None:
                for key in x:
                    if len(x[key].shape) == 3 and x[key].shape[-1] == original_len:
                        x[key] = x[key][..., subsample]
    main_input, ssl_inputs, ssl_targets = self._apply_augmentations(
        main_input, ssl_inputs, ssl_targets, augment
    )
    # SSL inputs may come back as (features, meta) tuples: split the meta data off here and re-attach
    # it to the corresponding tensor at the end
    meta = [None for _ in ssl_inputs]
    for i, x in enumerate(ssl_inputs):
        if type(x) is tuple:
            x, meta_x = x
            meta[i] = meta_x
            ssl_inputs[i] = x
    # fill in missing SSL inputs with an SSL augmentation of a copy of the (already augmented) main input
    for (i, ssl_x), generate in zip(enumerate(ssl_inputs), self.generate_ssl_input):
        if ssl_x is None and generate:
            ssl_inputs[i] = self._apply_augmentations(
                deepcopy(main_input), None, None, augment=True, ssl=True
            )[0]
    num_ssl = len(ssl_inputs)
    dicts = [main_input] + ssl_inputs + ssl_targets
    output = [
        None if x is None else self._make_tensor(x, self.model_name) for x in dicts
    ]
    main_input, ssl_inputs, ssl_targets = (
        output[0],
        output[1 : num_ssl + 1],
        output[num_ssl + 1 :],
    )
    for (i, ssl_x), keep in zip(enumerate(ssl_targets), self.keep_target_none):
        if not keep and ssl_x is None:
            ssl_targets[i] = main_input
    for i, meta_x in enumerate(meta):
        if meta_x is not None:
            ssl_inputs[i] = (ssl_inputs[i], meta_x)
    return main_input, ssl_inputs, ssl_targets
Apply augmentations and generate tensors from feature dictionaries
The same augmentations are applied to all the inputs (if they have the features known to the transformer).
If generate_ssl_input is set to True for some of the SSL pairs, those SSL inputs will be generated as
another augmentation of main_input.
Unless keep_target_none is set to True, None SSL targets will be replaced with augmented main_input.
All features are stacked together to form a tensor of shape (#features, #frames) that can be passed to
a model.
Parameters
main_input : dict
the feature dictionary of the main input
ssl_inputs : list, optional
a list of feature dictionaries of SSL inputs (some or all can be None)
ssl_targets : list, optional
a list of feature dictionaries of SSL targets (some or all can be None)
augment : bool, default False
if True, augmentations are applied
subsample : list, optional
frame indices to keep (applied along the last axis before augmentation)
Returns
main_input : torch.Tensor
the augmented tensor of the main input
ssl_inputs : list, optional
a list of augmented tensors of SSL inputs (some or all can be None)
ssl_targets : list, optional
a list of augmented tensors of SSL targets (some or all can be None)
450class EmptyTransformer(Transformer): 451 """ 452 Empty transformer class that does not apply augmentations 453 """ 454 455 def _augmentations_dict(self) -> Dict: 456 """ 457 Return a dictionary of possible augmentations 458 459 The keys are augmentation names and the values are the corresponding functions 460 461 Returns 462 ------- 463 augmentations_dict : dict 464 a dictionary of augmentation functions 465 """ 466 467 return {} 468 469 def _default_augmentations(self) -> List: 470 """ 471 Return a list of default augmentation names 472 473 In case an augmentation list is not provided to the class constructor and use_default_augmentations is True, 474 this function is called to set the augmentations parameter. The elements of the list have to be keys of the 475 dictionary returned by `self._augmentations_dict()` 476 477 Returns 478 ------- 479 default_augmentations : list 480 a list of string names of the default augmentation functions 481 """ 482 483 return []
Empty transformer class that does not apply augmentations