dlc2action.transformer.kinematic
Kinematic transformer
#
# Copyright 2020-2022 by A. Mathis Group and contributors. All rights reserved.
#
# This project and all its files are licensed under GNU AGPLv3 or later version. A copy is included in dlc2action/LICENSE.AGPL.
#
"""
Kinematic transformer
"""

from typing import Dict, Tuple, List, Set, Iterable
import torch
import numpy as np
from copy import copy, deepcopy
from random import getrandbits
from collections import defaultdict
from dlc2action.transformer.base_transformer import Transformer
from dlc2action.utils import rotation_matrix_2d, rotation_matrix_3d


class KinematicTransformer(Transformer):
    """
    A transformer that augments the output of the Kinematic feature extractor

    The available augmentations are `'rotate'`, `'mirror'`, `'shift'`, `'add_noise'`, `'zoom'`,
    `'mask'` and `'switch'`

    Every augmentation method takes and returns `(main_input, ssl_inputs, ssl_targets)` and updates
    the feature dictionaries in place. Values equal to `self.blank` mark missing data and are set
    back to `self.blank` after each transformation.
    """

    # per-bodypart vector features; reshaped to (batch, -1, dim, frames) while augmenting
    _VECTOR_KEYS = (
        "coords",
        "speed_joints",
        "speed_direction",
        "speed_bones",
        "acc_bones",
        "bones",
        "coord_diff",
    )
    # features that are rotated and mirrored
    _GEOMETRIC_KEYS = _VECTOR_KEYS + ("center",)
    # absolute position features: shifted, and rotated / mirrored around their per-frame mean
    # when `move_around_image_center` is `False`
    _POSITION_KEYS = ("coords", "center")
    # features that receive additive Gaussian noise
    _NOISE_KEYS = (
        "coords",
        "speed_joints",
        "speed_direction",
        "intra_distance",
        "acc_joints",
        "angle_speeds",
        "inter_distance",
        "speed_bones",
        "acc_bones",
        "bones",
        "coord_diff",
        "center",
    )
    # features multiplied by the zoom factor ("coords" are zoomed separately around an anchor)
    _ZOOM_KEYS = (
        "speed_joints",
        "intra_distance",
        "acc_joints",
        "inter_distance",
        "speed_bones",
        "acc_bones",
        "bones",
        "coord_diff",
        "center",
        "speed_value",
    )
    # pairwise distance features, blanked for every pair that involves a masked joint
    _DISTANCE_KEYS = ("intra_distance", "inter_distance")
    # vector features whose masked joints are replaced with the mean over bodyparts
    _MASKABLE_KEYS = (
        "speed_joints",
        "speed_direction",
        "coords",
        "acc_joints",
        "speed_bones",
        "acc_bones",
        "bones",
        "coord_diff",
    )
    # temporary value for blanks in `_zoom` so that they do not become the zoom anchor
    # (presumably larger than any real coordinate)
    _ZOOM_BLANK_FILL = 10

    def __init__(
        self,
        model_name: str,
        augmentations: List = None,
        use_default_augmentations: bool = False,
        rotation_limits: List = None,
        mirror_dim: Set = None,
        noise_std: float = 0.05,
        zoom_limits: List = None,
        masking_probability: float = 0.05,
        dim: int = 2,
        graph_features: bool = False,
        bodyparts_order: List = None,
        canvas_shape: List = None,
        move_around_image_center: bool = True,
        **kwargs,
    ) -> None:
        """
        Parameters
        ----------
        model_name : str
            the name of the model (passed on to the base transformer class)
        augmentations : list, optional
            list of augmentation names to use ("rotate", "mirror", "shift", "add_noise", "zoom",
            "mask", "switch")
        use_default_augmentations : bool, default False
            if `True` and augmentations are not passed, default augmentations will be applied;
            otherwise no augmentations
        rotation_limits : list, default [-pi/2, pi/2]
            list of float rotation angle limits (`[low, high]`, or
            `[[low_x, high_x], [low_y, high_y], [low_z, high_z]]` for 3D data; a single
            `[low, high]` pair given for 3D data is used for the third angle only, the other two
            ranges are set to `[0, 0]`)
        mirror_dim : set, default {0}
            set of integer indices of dimensions that can be mirrored
        noise_std : float, default 0.05
            standard deviation of noise
        zoom_limits : list, default [0.5, 1.5]
            list of float zoom limits ([low, high])
        masking_probability : float, default 0.05
            the probability of masking a joint
        dim : int, default 2
            the dimensionality of the input data
        graph_features : bool, default False
            passed on to the base transformer class
        bodyparts_order : list, optional
            passed on to the base transformer class
        canvas_shape : list, default [1, 1]
            the canvas shape; the ratio `canvas_shape[1] / canvas_shape[0]` scales the vertical
            limits of the `'shift'` augmentation
        move_around_image_center : bool, default True
            if `True`, positions are rotated and mirrored around the origin; otherwise around
            their per-frame mean over bodyparts
        **kwargs : dict
            other parameters for the base transformer class
        """

        if augmentations is None:
            augmentations = []
        if canvas_shape is None:
            canvas_shape = [1, 1]
        self.dim = int(dim)

        self.offset = [0 for _ in range(self.dim)]
        self.scale = canvas_shape[1] / canvas_shape[0]
        self.image_center = move_around_image_center

        self.blank = 0  # the value that nan values are set to (shouldn't be changed in augmentations)
        super().__init__(
            model_name,
            augmentations,
            use_default_augmentations,
            graph_features=graph_features,
            bodyparts_order=bodyparts_order,
            **kwargs,
        )
        if rotation_limits is None:
            rotation_limits = [-np.pi / 2, np.pi / 2]
        if self.dim == 3 and len(rotation_limits) == 2:
            # expand a single [low, high] pair to one range per rotation angle
            rotation_limits = [[0, 0], [0, 0], rotation_limits]
        if mirror_dim is None:
            mirror_dim = [0]
        if zoom_limits is None:
            zoom_limits = [0.5, 1.5]
        self.rotation_limits = rotation_limits
        self.n_bodyparts = None
        self.mirror_dim = mirror_dim
        self.noise_std = noise_std
        self.zoom_limits = zoom_limits
        self.masking_probability = masking_probability

    @staticmethod
    def _iter_dicts(main_input: Dict, ssl_inputs: List, ssl_targets: List) -> List:
        """
        Collect the main input and all SSL dictionaries that are not `None`
        """

        return [d for d in [main_input, *ssl_inputs, *ssl_targets] if d is not None]

    def _reshape_features(
        self, key_bases: Iterable, get_shape, main_input: Dict, ssl_inputs: List, ssl_targets: List
    ) -> None:
        """
        Reshape every feature matching `key_bases` in all dictionaries to `get_shape(tensor)`
        """

        for x in self._iter_dicts(main_input, ssl_inputs, ssl_targets):
            for key in self._get_keys(key_bases, x):
                x[key] = x[key].reshape(get_shape(x[key]))

    def _apply_augmentations(
        self,
        main_input: Dict,
        ssl_inputs: List = None,
        ssl_targets: List = None,
        augment: bool = False,
        ssl: bool = False,
    ) -> Tuple:
        """
        Reshape the kinematic features, run the base augmentation pipeline and restore the shapes

        While the augmentations run, the vector features (`_VECTOR_KEYS`) are viewed as
        (batch, bodyparts, dim, frames) and the "center" features as (batch, 1, -1, frames), so
        that every augmentation can index the coordinate axis directly.

        Parameters
        ----------
        main_input : dict
            the main input feature dictionary
        ssl_inputs : list, optional
            SSL input feature dictionaries (entries may be `None`)
        ssl_targets : list, optional
            SSL target feature dictionaries (entries may be `None`)
        augment : bool, default False
            passed on to the base class implementation
        ssl : bool, default False
            not used by this implementation

        Returns
        -------
        main_input : dict
            the (augmented) main input
        ssl_inputs : list
            the (augmented) SSL inputs
        ssl_targets : list
            the (augmented) SSL targets
        """

        if ssl_targets is None:
            ssl_targets = [None]
        if ssl_inputs is None:
            ssl_inputs = [None]

        keys_main = self._get_keys(self._VECTOR_KEYS, main_input)
        if keys_main:
            reference = main_input[keys_main[0]]
            final_shape = reference.shape
            x_shape = reference.reshape(
                (final_shape[0], -1, self.dim, final_shape[-1])
            ).shape
            self.n_bodyparts = x_shape[1]
            if self.dim == 3 and len(self.rotation_limits) == 2:
                # a single [low, high] pair in 3D is used for the third rotation angle only
                self.rotation_limits = [[0, 0], [0, 0], self.rotation_limits]
            self._reshape_features(
                self._VECTOR_KEYS, lambda _: x_shape, main_input, ssl_inputs, ssl_targets
            )
        if self._get_keys(("center",), main_input):
            self._reshape_features(
                ("center",),
                lambda t: (t.shape[0], 1, -1, t.shape[-1]),
                main_input,
                ssl_inputs,
                ssl_targets,
            )
        main_input, ssl_inputs, ssl_targets = super()._apply_augmentations(
            main_input, ssl_inputs, ssl_targets, augment
        )
        if keys_main:
            self._reshape_features(
                self._VECTOR_KEYS, lambda _: final_shape, main_input, ssl_inputs, ssl_targets
            )
        if self._get_keys(("center",), main_input):
            self._reshape_features(
                ("center",),
                lambda t: (t.shape[0], -1, t.shape[-1]),
                main_input,
                ssl_inputs,
                ssl_targets,
            )
        return main_input, ssl_inputs, ssl_targets

    def _rotate(
        self, main_input: Dict, ssl_inputs: List, ssl_targets: List
    ) -> Tuple[Dict, List, List]:
        """
        Rotate the geometric features (`_GEOMETRIC_KEYS`) of the input by a random angle

        One rotation matrix is drawn per sample from `rotation_limits` (three angle ranges for 3D
        data). Positions are rotated around the origin, or around their per-frame mean over
        bodyparts when `move_around_image_center` is `False`.
        """

        keys = self._get_keys(self._GEOMETRIC_KEYS, main_input)
        if not keys:
            return main_input, ssl_inputs, ssl_targets
        reference = main_input[keys[0]]
        batch = reference.shape[0]
        if self.dim == 2:
            angles = torch.FloatTensor(batch).uniform_(*self.rotation_limits)
            R = rotation_matrix_2d(angles)
        else:
            angles_x = torch.FloatTensor(batch).uniform_(*self.rotation_limits[0])
            angles_y = torch.FloatTensor(batch).uniform_(*self.rotation_limits[1])
            angles_z = torch.FloatTensor(batch).uniform_(*self.rotation_limits[2])
            R = rotation_matrix_3d(angles_x, angles_y, angles_z)
        R = R.to(reference.device)
        for x in self._iter_dicts(main_input, ssl_inputs, ssl_targets):
            for key in self._get_keys(self._GEOMETRIC_KEYS, x):
                value = x[key]
                mask = value == self.blank
                if not self.image_center and key.startswith(self._POSITION_KEYS):
                    offset = value.mean(1).unsqueeze(1)
                else:
                    offset = 0
                # match the feature dtype so that e.g. float64 features do not break einsum
                rotated = (
                    torch.einsum("abjd,aij->abid", value - offset, R.to(value.dtype))
                    + offset
                )
                x[key] = rotated.masked_fill(mask, self.blank)
        return main_input, ssl_inputs, ssl_targets

    def _mirror(
        self, main_input: Dict, ssl_inputs: List, ssl_targets: List
    ) -> Tuple[Dict, List, List]:
        """
        Mirror the geometric features (`_GEOMETRIC_KEYS`) of the input randomly

        Each axis listed in `mirror_dim` (and present in the data) is flipped with probability
        1/2. Positions are flipped around the origin, or around their per-frame mean over
        bodyparts when `move_around_image_center` is `False`.
        """

        mirror = [
            i for i in range(self.dim) if i in self.mirror_dim and bool(getrandbits(1))
        ]
        if not mirror:
            return main_input, ssl_inputs, ssl_targets
        for x in self._iter_dicts(main_input, ssl_inputs, ssl_targets):
            for key in self._get_keys(self._GEOMETRIC_KEYS, x):
                value = x[key]
                mask = value == self.blank
                flipped = value.clone()
                if not self.image_center and key.startswith(self._POSITION_KEYS):
                    mean = value.mean(1).unsqueeze(1)
                    for axis in mirror:
                        flipped[:, :, axis, :] = (
                            2 * mean[:, :, axis, :] - value[:, :, axis, :]
                        )
                else:
                    for axis in mirror:
                        flipped[:, :, axis, :] = -value[:, :, axis, :]
                x[key] = flipped.masked_fill(mask, self.blank)
        return main_input, ssl_inputs, ssl_targets

    def _shift(
        self, main_input: Dict, ssl_inputs: List, ssl_targets: List
    ) -> Tuple[Dict, List, List]:
        """
        Shift the position features ("coords" and "center") of the input randomly

        For every sample the x and y offsets are drawn uniformly from the range that keeps all of
        that sample's non-blank points inside [-0.5, 0.5] horizontally and
        [-0.5 * scale, 0.5 * scale] vertically (if they started there).
        """

        keys = self._get_keys(self._POSITION_KEYS, main_input)
        if not keys:
            return main_input, ssl_inputs, ssl_targets
        reference = main_input[keys[0]]
        batch, _, dim = reference.shape[:3]
        device = reference.device
        points = torch.cat(
            [main_input[key].transpose(-1, -2).reshape(batch, -1, dim) for key in keys],
            dim=1,
        )
        # a point counts as present when its x coordinate is not blank
        valid = points[:, :, 0] != self.blank
        has_valid = valid.any(dim=-1)

        def extremes(values):
            # per-sample min / max over valid points only
            low = values.masked_fill(~valid, float("inf")).min(dim=-1)[0]
            high = values.masked_fill(~valid, float("-inf")).max(dim=-1)[0]
            # samples without valid points get neutral bounds instead of +-inf
            zeros = torch.zeros_like(low)
            return torch.where(has_valid, low, zeros), torch.where(has_valid, high, zeros)

        low_x, high_x = extremes(points[:, :, 0])
        low_y, high_y = extremes(points[:, :, 1])
        del points
        min_x, max_x = -0.5 - low_x, 0.5 - high_x
        min_y, max_y = -0.5 * self.scale - low_y, 0.5 * self.scale - high_y
        shift_x = min_x + torch.FloatTensor(batch).uniform_().to(device) * (max_x - min_x)
        shift_y = min_y + torch.FloatTensor(batch).uniform_().to(device) * (max_y - min_y)
        shift_x = shift_x.reshape(batch, 1, 1)
        shift_y = shift_y.reshape(batch, 1, 1)
        for x in self._iter_dicts(main_input, ssl_inputs, ssl_targets):
            for key in self._get_keys(self._POSITION_KEYS, x):
                value = x[key]
                mask = value == self.blank
                shifted = value.clone()
                shifted[:, :, 0, :] += shift_x
                shifted[:, :, 1, :] += shift_y
                x[key] = shifted.masked_fill(mask, self.blank)
        return main_input, ssl_inputs, ssl_targets

    def _add_noise(
        self, main_input: Dict, ssl_inputs: List, ssl_targets: List
    ) -> Tuple[Dict, List, List]:
        """
        Add Gaussian noise with standard deviation `noise_std` to the `_NOISE_KEYS` features
        """

        for x in self._iter_dicts(main_input, ssl_inputs, ssl_targets):
            for key in self._get_keys(self._NOISE_KEYS, x):
                value = x[key]
                mask = value == self.blank
                # generated directly on the feature's device instead of on CPU + copy
                noise = torch.randn(value.shape, device=value.device) * self.noise_std
                x[key] = (value + noise).masked_fill(mask, self.blank)
        return main_input, ssl_inputs, ssl_targets

    def _zoom(
        self, main_input: Dict, ssl_inputs: List, ssl_targets: List
    ) -> Tuple[Dict, List, List]:
        """
        Scale the features of the input by a random per-sample zoom factor

        The factor is drawn uniformly from `zoom_limits`. The `_ZOOM_KEYS` features are multiplied
        by it directly. "coords" are scaled around their per-frame minimum over bodyparts (taken
        separately for every coordinate axis), so that the minimum stays in place; blanks are
        temporarily replaced by `_ZOOM_BLANK_FILL` so that they do not affect that minimum.
        """

        reference = main_input[next(iter(main_input))]
        batch = reference.shape[0]
        zoom = torch.FloatTensor(batch).uniform_(*self.zoom_limits).to(reference.device)
        for x in self._iter_dicts(main_input, ssl_inputs, ssl_targets):
            for key in self._get_keys(self._ZOOM_KEYS, x):
                value = x[key]
                mask = value == self.blank
                factor = zoom.reshape((value.shape[0],) + (1,) * (value.dim() - 1))
                x[key] = (value * factor.to(value.dtype)).masked_fill(mask, self.blank)
            for key in self._get_keys(("coords",), x):
                value = x[key]
                mask = value == self.blank
                factor = zoom.reshape((value.shape[0],) + (1,) * (value.dim() - 1))
                # out-of-place fill: the input tensor (possibly shared) is never modified
                filled = value.masked_fill(mask, self._ZOOM_BLANK_FILL)
                anchor = filled.min(dim=1, keepdim=True)[0]
                zoomed = (filled - anchor) * factor + anchor
                x[key] = zoomed.masked_fill(mask, self.blank)
        return main_input, ssl_inputs, ssl_targets

    def _mask_joints(
        self, main_input: Dict, ssl_inputs: List, ssl_targets: List
    ) -> Tuple[Dict, List, List]:
        """
        Mask joints randomly

        Every bodypart of every sample is masked with probability `masking_probability`. Pairwise
        distance features involving a masked joint are set to blank; vector features and
        "angle_speeds" of a masked joint are replaced with their mean over bodyparts. Nothing is
        done while the number of bodyparts is still unknown.
        """

        if self.n_bodyparts is None:
            # only set once `_apply_augmentations` has seen per-bodypart vector features
            return main_input, ssl_inputs, ssl_targets
        n = self.n_bodyparts
        batch, *_, frames = main_input[next(iter(main_input))].shape
        masked_joints = (
            torch.FloatTensor(batch, n).uniform_() < self.masking_probability
        )
        pairs = torch.triu_indices(n, n, 1)
        for x in self._iter_dicts(main_input, ssl_inputs, ssl_targets):
            for key in self._get_keys(self._DISTANCE_KEYS, x):
                value = x[key]
                device = value.device
                # rows[b, i, j, f] is True when joint i of sample b is masked
                rows = masked_joints.to(device)[:, :, None, None].expand(batch, n, n, frames)
                first, second = pairs.to(device)
                full = torch.zeros(
                    (batch, n, n, frames), dtype=value.dtype, device=device
                )
                full[:, first, second, :] = value
                full[rows] = self.blank
                full[rows.transpose(1, 2)] = self.blank
                x[key] = full[:, first, second, :].reshape(batch, -1, frames)
            for key in self._get_keys(self._MASKABLE_KEYS, x):
                value = x[key]
                mask = masked_joints.to(value.device)[:, :, None, None].expand(
                    batch, n, self.dim, frames
                )
                x[key] = torch.where(mask, value.mean(1, keepdim=True), value)
            for key in self._get_keys(("angle_speeds",), x):
                value = x[key]
                mask = masked_joints.to(value.device)[:, :, None].expand(batch, n, frames)
                x[key] = torch.where(mask, value.mean(1, keepdim=True), value)
        return main_input, ssl_inputs, ssl_targets

    def _switch(
        self, main_input: Dict, ssl_inputs: List, ssl_targets: List
    ) -> Tuple[Dict, List, List]:
        """
        Swap the features of pairs of individuals (applied with probability 1/2)

        Keys are expected to look like `"{feature}---{individual}"`. Keys whose individual part
        contains "+" (presumably features of an individual pair) are left untouched. Individuals
        are sorted and paired consecutively; every feature present for both members of a pair is
        swapped in the main input and in every SSL dictionary that has both keys.
        """

        if bool(getrandbits(1)):
            return main_input, ssl_inputs, ssl_targets
        features_by_individual = defaultdict(set)
        for key in main_input:
            parts = key.split("---")
            if len(parts) != 2:
                continue
            feature, individual = parts
            if "+" in individual:
                continue
            features_by_individual[individual].add(feature)
        # sorted so that the pairing does not depend on string hash randomization
        individuals = sorted(features_by_individual)
        if len(individuals) < 2:
            return main_input, ssl_inputs, ssl_targets
        for first, second in zip(individuals[::2], individuals[1::2]):
            shared = features_by_individual[first] & features_by_individual[second]
            for feature in shared:
                key_a, key_b = f"{feature}---{first}", f"{feature}---{second}"
                main_input[key_a], main_input[key_b] = main_input[key_b], main_input[key_a]
                for d in [*ssl_inputs, *ssl_targets]:
                    if d is not None and key_a in d and key_b in d:
                        d[key_a], d[key_b] = d[key_b], d[key_a]
        return main_input, ssl_inputs, ssl_targets

    def _get_keys(self, key_bases: Iterable, x: Dict) -> List:
        """
        Get the keys of x whose part before the first "---" equals one of `key_bases`

        A single string is accepted as `key_bases`. The keys are returned in the order of `x`.
        """

        if isinstance(key_bases, str):
            key_bases = [key_bases]
        bases = set(key_bases)
        return [key for key in x if key.split("---")[0] in bases]

    def _augmentations_dict(self) -> Dict:
        """
        Get the mapping from augmentation names to functions
        """

        return {
            "mirror": self._mirror,
            "shift": self._shift,
            "add_noise": self._add_noise,
            "zoom": self._zoom,
            "rotate": self._rotate,
            "mask": self._mask_joints,
            "switch": self._switch,
        }

    def _default_augmentations(self) -> List:
        """
        Get the list of default augmentation names
        """

        return ["mirror", "shift", "add_noise"]

    def _get_bodyparts(self, shape: Tuple) -> None:
        """
        Set the number of bodyparts from a (batch, features, frames) shape if it is not known

        Assumes the features axis holds `dim` coordinates per bodypart.
        """

        if self.n_bodyparts is None:
            _, n_features, _ = shape
            self.n_bodyparts = n_features // self.dim
class KinematicTransformer(Transformer):
    """
    A transformer that augments the output of the Kinematic feature extractor

    The available augmentations are `'rotate'`, `'mirror'`, `'shift'`, `'add_noise'`, `'zoom'`,
    `'mask'` and `'switch'`

    Every augmentation method takes and returns `(main_input, ssl_inputs, ssl_targets)` and updates
    the feature dictionaries in place. Values equal to `self.blank` mark missing data and are set
    back to `self.blank` after each transformation.
    """

    # per-bodypart vector features; reshaped to (batch, -1, dim, frames) while augmenting
    _VECTOR_KEYS = (
        "coords",
        "speed_joints",
        "speed_direction",
        "speed_bones",
        "acc_bones",
        "bones",
        "coord_diff",
    )
    # features that are rotated and mirrored
    _GEOMETRIC_KEYS = _VECTOR_KEYS + ("center",)
    # absolute position features: shifted, and rotated / mirrored around their per-frame mean
    # when `move_around_image_center` is `False`
    _POSITION_KEYS = ("coords", "center")
    # features that receive additive Gaussian noise
    _NOISE_KEYS = (
        "coords",
        "speed_joints",
        "speed_direction",
        "intra_distance",
        "acc_joints",
        "angle_speeds",
        "inter_distance",
        "speed_bones",
        "acc_bones",
        "bones",
        "coord_diff",
        "center",
    )
    # features multiplied by the zoom factor ("coords" are zoomed separately around an anchor)
    _ZOOM_KEYS = (
        "speed_joints",
        "intra_distance",
        "acc_joints",
        "inter_distance",
        "speed_bones",
        "acc_bones",
        "bones",
        "coord_diff",
        "center",
        "speed_value",
    )
    # pairwise distance features, blanked for every pair that involves a masked joint
    _DISTANCE_KEYS = ("intra_distance", "inter_distance")
    # vector features whose masked joints are replaced with the mean over bodyparts
    _MASKABLE_KEYS = (
        "speed_joints",
        "speed_direction",
        "coords",
        "acc_joints",
        "speed_bones",
        "acc_bones",
        "bones",
        "coord_diff",
    )
    # temporary value for blanks in `_zoom` so that they do not become the zoom anchor
    # (presumably larger than any real coordinate)
    _ZOOM_BLANK_FILL = 10

    def __init__(
        self,
        model_name: str,
        augmentations: List = None,
        use_default_augmentations: bool = False,
        rotation_limits: List = None,
        mirror_dim: Set = None,
        noise_std: float = 0.05,
        zoom_limits: List = None,
        masking_probability: float = 0.05,
        dim: int = 2,
        graph_features: bool = False,
        bodyparts_order: List = None,
        canvas_shape: List = None,
        move_around_image_center: bool = True,
        **kwargs,
    ) -> None:
        """
        Parameters
        ----------
        model_name : str
            the name of the model (passed on to the base transformer class)
        augmentations : list, optional
            list of augmentation names to use ("rotate", "mirror", "shift", "add_noise", "zoom",
            "mask", "switch")
        use_default_augmentations : bool, default False
            if `True` and augmentations are not passed, default augmentations will be applied;
            otherwise no augmentations
        rotation_limits : list, default [-pi/2, pi/2]
            list of float rotation angle limits (`[low, high]`, or
            `[[low_x, high_x], [low_y, high_y], [low_z, high_z]]` for 3D data; a single
            `[low, high]` pair given for 3D data is used for the third angle only, the other two
            ranges are set to `[0, 0]`)
        mirror_dim : set, default {0}
            set of integer indices of dimensions that can be mirrored
        noise_std : float, default 0.05
            standard deviation of noise
        zoom_limits : list, default [0.5, 1.5]
            list of float zoom limits ([low, high])
        masking_probability : float, default 0.05
            the probability of masking a joint
        dim : int, default 2
            the dimensionality of the input data
        graph_features : bool, default False
            passed on to the base transformer class
        bodyparts_order : list, optional
            passed on to the base transformer class
        canvas_shape : list, default [1, 1]
            the canvas shape; the ratio `canvas_shape[1] / canvas_shape[0]` scales the vertical
            limits of the `'shift'` augmentation
        move_around_image_center : bool, default True
            if `True`, positions are rotated and mirrored around the origin; otherwise around
            their per-frame mean over bodyparts
        **kwargs : dict
            other parameters for the base transformer class
        """

        if augmentations is None:
            augmentations = []
        if canvas_shape is None:
            canvas_shape = [1, 1]
        self.dim = int(dim)

        self.offset = [0 for _ in range(self.dim)]
        self.scale = canvas_shape[1] / canvas_shape[0]
        self.image_center = move_around_image_center

        self.blank = 0  # the value that nan values are set to (shouldn't be changed in augmentations)
        super().__init__(
            model_name,
            augmentations,
            use_default_augmentations,
            graph_features=graph_features,
            bodyparts_order=bodyparts_order,
            **kwargs,
        )
        if rotation_limits is None:
            rotation_limits = [-np.pi / 2, np.pi / 2]
        if self.dim == 3 and len(rotation_limits) == 2:
            # expand a single [low, high] pair to one range per rotation angle
            rotation_limits = [[0, 0], [0, 0], rotation_limits]
        if mirror_dim is None:
            mirror_dim = [0]
        if zoom_limits is None:
            zoom_limits = [0.5, 1.5]
        self.rotation_limits = rotation_limits
        self.n_bodyparts = None
        self.mirror_dim = mirror_dim
        self.noise_std = noise_std
        self.zoom_limits = zoom_limits
        self.masking_probability = masking_probability

    @staticmethod
    def _iter_dicts(main_input: Dict, ssl_inputs: List, ssl_targets: List) -> List:
        """
        Collect the main input and all SSL dictionaries that are not `None`
        """

        return [d for d in [main_input, *ssl_inputs, *ssl_targets] if d is not None]

    def _reshape_features(
        self, key_bases: Iterable, get_shape, main_input: Dict, ssl_inputs: List, ssl_targets: List
    ) -> None:
        """
        Reshape every feature matching `key_bases` in all dictionaries to `get_shape(tensor)`
        """

        for x in self._iter_dicts(main_input, ssl_inputs, ssl_targets):
            for key in self._get_keys(key_bases, x):
                x[key] = x[key].reshape(get_shape(x[key]))

    def _apply_augmentations(
        self,
        main_input: Dict,
        ssl_inputs: List = None,
        ssl_targets: List = None,
        augment: bool = False,
        ssl: bool = False,
    ) -> Tuple:
        """
        Reshape the kinematic features, run the base augmentation pipeline and restore the shapes

        While the augmentations run, the vector features (`_VECTOR_KEYS`) are viewed as
        (batch, bodyparts, dim, frames) and the "center" features as (batch, 1, -1, frames), so
        that every augmentation can index the coordinate axis directly.

        Parameters
        ----------
        main_input : dict
            the main input feature dictionary
        ssl_inputs : list, optional
            SSL input feature dictionaries (entries may be `None`)
        ssl_targets : list, optional
            SSL target feature dictionaries (entries may be `None`)
        augment : bool, default False
            passed on to the base class implementation
        ssl : bool, default False
            not used by this implementation

        Returns
        -------
        main_input : dict
            the (augmented) main input
        ssl_inputs : list
            the (augmented) SSL inputs
        ssl_targets : list
            the (augmented) SSL targets
        """

        if ssl_targets is None:
            ssl_targets = [None]
        if ssl_inputs is None:
            ssl_inputs = [None]

        keys_main = self._get_keys(self._VECTOR_KEYS, main_input)
        if keys_main:
            reference = main_input[keys_main[0]]
            final_shape = reference.shape
            x_shape = reference.reshape(
                (final_shape[0], -1, self.dim, final_shape[-1])
            ).shape
            self.n_bodyparts = x_shape[1]
            if self.dim == 3 and len(self.rotation_limits) == 2:
                # a single [low, high] pair in 3D is used for the third rotation angle only
                self.rotation_limits = [[0, 0], [0, 0], self.rotation_limits]
            self._reshape_features(
                self._VECTOR_KEYS, lambda _: x_shape, main_input, ssl_inputs, ssl_targets
            )
        if self._get_keys(("center",), main_input):
            self._reshape_features(
                ("center",),
                lambda t: (t.shape[0], 1, -1, t.shape[-1]),
                main_input,
                ssl_inputs,
                ssl_targets,
            )
        main_input, ssl_inputs, ssl_targets = super()._apply_augmentations(
            main_input, ssl_inputs, ssl_targets, augment
        )
        if keys_main:
            self._reshape_features(
                self._VECTOR_KEYS, lambda _: final_shape, main_input, ssl_inputs, ssl_targets
            )
        if self._get_keys(("center",), main_input):
            self._reshape_features(
                ("center",),
                lambda t: (t.shape[0], -1, t.shape[-1]),
                main_input,
                ssl_inputs,
                ssl_targets,
            )
        return main_input, ssl_inputs, ssl_targets

    def _rotate(
        self, main_input: Dict, ssl_inputs: List, ssl_targets: List
    ) -> Tuple[Dict, List, List]:
        """
        Rotate the geometric features (`_GEOMETRIC_KEYS`) of the input by a random angle

        One rotation matrix is drawn per sample from `rotation_limits` (three angle ranges for 3D
        data). Positions are rotated around the origin, or around their per-frame mean over
        bodyparts when `move_around_image_center` is `False`.
        """

        keys = self._get_keys(self._GEOMETRIC_KEYS, main_input)
        if not keys:
            return main_input, ssl_inputs, ssl_targets
        reference = main_input[keys[0]]
        batch = reference.shape[0]
        if self.dim == 2:
            angles = torch.FloatTensor(batch).uniform_(*self.rotation_limits)
            R = rotation_matrix_2d(angles)
        else:
            angles_x = torch.FloatTensor(batch).uniform_(*self.rotation_limits[0])
            angles_y = torch.FloatTensor(batch).uniform_(*self.rotation_limits[1])
            angles_z = torch.FloatTensor(batch).uniform_(*self.rotation_limits[2])
            R = rotation_matrix_3d(angles_x, angles_y, angles_z)
        R = R.to(reference.device)
        for x in self._iter_dicts(main_input, ssl_inputs, ssl_targets):
            for key in self._get_keys(self._GEOMETRIC_KEYS, x):
                value = x[key]
                mask = value == self.blank
                if not self.image_center and key.startswith(self._POSITION_KEYS):
                    offset = value.mean(1).unsqueeze(1)
                else:
                    offset = 0
                # match the feature dtype so that e.g. float64 features do not break einsum
                rotated = (
                    torch.einsum("abjd,aij->abid", value - offset, R.to(value.dtype))
                    + offset
                )
                x[key] = rotated.masked_fill(mask, self.blank)
        return main_input, ssl_inputs, ssl_targets

    def _mirror(
        self, main_input: Dict, ssl_inputs: List, ssl_targets: List
    ) -> Tuple[Dict, List, List]:
        """
        Mirror the geometric features (`_GEOMETRIC_KEYS`) of the input randomly

        Each axis listed in `mirror_dim` (and present in the data) is flipped with probability
        1/2. Positions are flipped around the origin, or around their per-frame mean over
        bodyparts when `move_around_image_center` is `False`.
        """

        mirror = [
            i for i in range(self.dim) if i in self.mirror_dim and bool(getrandbits(1))
        ]
        if not mirror:
            return main_input, ssl_inputs, ssl_targets
        for x in self._iter_dicts(main_input, ssl_inputs, ssl_targets):
            for key in self._get_keys(self._GEOMETRIC_KEYS, x):
                value = x[key]
                mask = value == self.blank
                flipped = value.clone()
                if not self.image_center and key.startswith(self._POSITION_KEYS):
                    mean = value.mean(1).unsqueeze(1)
                    for axis in mirror:
                        flipped[:, :, axis, :] = (
                            2 * mean[:, :, axis, :] - value[:, :, axis, :]
                        )
                else:
                    for axis in mirror:
                        flipped[:, :, axis, :] = -value[:, :, axis, :]
                x[key] = flipped.masked_fill(mask, self.blank)
        return main_input, ssl_inputs, ssl_targets

    def _shift(
        self, main_input: Dict, ssl_inputs: List, ssl_targets: List
    ) -> Tuple[Dict, List, List]:
        """
        Shift the position features ("coords" and "center") of the input randomly

        For every sample the x and y offsets are drawn uniformly from the range that keeps all of
        that sample's non-blank points inside [-0.5, 0.5] horizontally and
        [-0.5 * scale, 0.5 * scale] vertically (if they started there).
        """

        keys = self._get_keys(self._POSITION_KEYS, main_input)
        if not keys:
            return main_input, ssl_inputs, ssl_targets
        reference = main_input[keys[0]]
        batch, _, dim = reference.shape[:3]
        device = reference.device
        points = torch.cat(
            [main_input[key].transpose(-1, -2).reshape(batch, -1, dim) for key in keys],
            dim=1,
        )
        # a point counts as present when its x coordinate is not blank
        valid = points[:, :, 0] != self.blank
        has_valid = valid.any(dim=-1)

        def extremes(values):
            # per-sample min / max over valid points only
            low = values.masked_fill(~valid, float("inf")).min(dim=-1)[0]
            high = values.masked_fill(~valid, float("-inf")).max(dim=-1)[0]
            # samples without valid points get neutral bounds instead of +-inf
            zeros = torch.zeros_like(low)
            return torch.where(has_valid, low, zeros), torch.where(has_valid, high, zeros)

        low_x, high_x = extremes(points[:, :, 0])
        low_y, high_y = extremes(points[:, :, 1])
        del points
        min_x, max_x = -0.5 - low_x, 0.5 - high_x
        min_y, max_y = -0.5 * self.scale - low_y, 0.5 * self.scale - high_y
        shift_x = min_x + torch.FloatTensor(batch).uniform_().to(device) * (max_x - min_x)
        shift_y = min_y + torch.FloatTensor(batch).uniform_().to(device) * (max_y - min_y)
        shift_x = shift_x.reshape(batch, 1, 1)
        shift_y = shift_y.reshape(batch, 1, 1)
        for x in self._iter_dicts(main_input, ssl_inputs, ssl_targets):
            for key in self._get_keys(self._POSITION_KEYS, x):
                value = x[key]
                mask = value == self.blank
                shifted = value.clone()
                shifted[:, :, 0, :] += shift_x
                shifted[:, :, 1, :] += shift_y
                x[key] = shifted.masked_fill(mask, self.blank)
        return main_input, ssl_inputs, ssl_targets

    def _add_noise(
        self, main_input: Dict, ssl_inputs: List, ssl_targets: List
    ) -> Tuple[Dict, List, List]:
        """
        Add Gaussian noise with standard deviation `noise_std` to the `_NOISE_KEYS` features
        """

        for x in self._iter_dicts(main_input, ssl_inputs, ssl_targets):
            for key in self._get_keys(self._NOISE_KEYS, x):
                value = x[key]
                mask = value == self.blank
                # generated directly on the feature's device instead of on CPU + copy
                noise = torch.randn(value.shape, device=value.device) * self.noise_std
                x[key] = (value + noise).masked_fill(mask, self.blank)
        return main_input, ssl_inputs, ssl_targets

    def _zoom(
        self, main_input: Dict, ssl_inputs: List, ssl_targets: List
    ) -> Tuple[Dict, List, List]:
        """
        Scale the features of the input by a random per-sample zoom factor

        The factor is drawn uniformly from `zoom_limits`. The `_ZOOM_KEYS` features are multiplied
        by it directly. "coords" are scaled around their per-frame minimum over bodyparts (taken
        separately for every coordinate axis), so that the minimum stays in place; blanks are
        temporarily replaced by `_ZOOM_BLANK_FILL` so that they do not affect that minimum.
        """

        reference = main_input[next(iter(main_input))]
        batch = reference.shape[0]
        zoom = torch.FloatTensor(batch).uniform_(*self.zoom_limits).to(reference.device)
        for x in self._iter_dicts(main_input, ssl_inputs, ssl_targets):
            for key in self._get_keys(self._ZOOM_KEYS, x):
                value = x[key]
                mask = value == self.blank
                factor = zoom.reshape((value.shape[0],) + (1,) * (value.dim() - 1))
                x[key] = (value * factor.to(value.dtype)).masked_fill(mask, self.blank)
            for key in self._get_keys(("coords",), x):
                value = x[key]
                mask = value == self.blank
                factor = zoom.reshape((value.shape[0],) + (1,) * (value.dim() - 1))
                # out-of-place fill: the input tensor (possibly shared) is never modified
                filled = value.masked_fill(mask, self._ZOOM_BLANK_FILL)
                anchor = filled.min(dim=1, keepdim=True)[0]
                zoomed = (filled - anchor) * factor + anchor
                x[key] = zoomed.masked_fill(mask, self.blank)
        return main_input, ssl_inputs, ssl_targets

    def _mask_joints(
        self, main_input: Dict, ssl_inputs: List, ssl_targets: List
    ) -> Tuple[Dict, List, List]:
        """
        Mask joints randomly

        Every bodypart of every sample is masked with probability `masking_probability`. Pairwise
        distance features involving a masked joint are set to blank; vector features and
        "angle_speeds" of a masked joint are replaced with their mean over bodyparts. Nothing is
        done while the number of bodyparts is still unknown.
        """

        if self.n_bodyparts is None:
            # only set once `_apply_augmentations` has seen per-bodypart vector features
            return main_input, ssl_inputs, ssl_targets
        n = self.n_bodyparts
        batch, *_, frames = main_input[next(iter(main_input))].shape
        masked_joints = (
            torch.FloatTensor(batch, n).uniform_() < self.masking_probability
        )
        pairs = torch.triu_indices(n, n, 1)
        for x in self._iter_dicts(main_input, ssl_inputs, ssl_targets):
            for key in self._get_keys(self._DISTANCE_KEYS, x):
                value = x[key]
                device = value.device
                # rows[b, i, j, f] is True when joint i of sample b is masked
                rows = masked_joints.to(device)[:, :, None, None].expand(batch, n, n, frames)
                first, second = pairs.to(device)
                full = torch.zeros(
                    (batch, n, n, frames), dtype=value.dtype, device=device
                )
                full[:, first, second, :] = value
                full[rows] = self.blank
                full[rows.transpose(1, 2)] = self.blank
                x[key] = full[:, first, second, :].reshape(batch, -1, frames)
            for key in self._get_keys(self._MASKABLE_KEYS, x):
                value = x[key]
                mask = masked_joints.to(value.device)[:, :, None, None].expand(
                    batch, n, self.dim, frames
                )
                x[key] = torch.where(mask, value.mean(1, keepdim=True), value)
            for key in self._get_keys(("angle_speeds",), x):
                value = x[key]
                mask = masked_joints.to(value.device)[:, :, None].expand(batch, n, frames)
                x[key] = torch.where(mask, value.mean(1, keepdim=True), value)
        return main_input, ssl_inputs, ssl_targets

    def _switch(
        self, main_input: Dict, ssl_inputs: List, ssl_targets: List
    ) -> Tuple[Dict, List, List]:
        """
        Swap the features of pairs of individuals (applied with probability 1/2)

        Keys are expected to look like `"{feature}---{individual}"`. Keys whose individual part
        contains "+" (presumably features of an individual pair) are left untouched. Individuals
        are sorted and paired consecutively; every feature present for both members of a pair is
        swapped in the main input and in every SSL dictionary that has both keys.
        """

        if bool(getrandbits(1)):
            return main_input, ssl_inputs, ssl_targets
        features_by_individual = defaultdict(set)
        for key in main_input:
            parts = key.split("---")
            if len(parts) != 2:
                continue
            feature, individual = parts
            if "+" in individual:
                continue
            features_by_individual[individual].add(feature)
        # sorted so that the pairing does not depend on string hash randomization
        individuals = sorted(features_by_individual)
        if len(individuals) < 2:
            return main_input, ssl_inputs, ssl_targets
        for first, second in zip(individuals[::2], individuals[1::2]):
            shared = features_by_individual[first] & features_by_individual[second]
            for feature in shared:
                key_a, key_b = f"{feature}---{first}", f"{feature}---{second}"
                main_input[key_a], main_input[key_b] = main_input[key_b], main_input[key_a]
                for d in [*ssl_inputs, *ssl_targets]:
                    if d is not None and key_a in d and key_b in d:
                        d[key_a], d[key_b] = d[key_b], d[key_a]
        return main_input, ssl_inputs, ssl_targets

    def _get_keys(self, key_bases: Iterable, x: Dict) -> List:
        """
        Get the keys of x whose part before the first "---" equals one of `key_bases`

        A single string is accepted as `key_bases`. The keys are returned in the order of `x`.
        """

        if isinstance(key_bases, str):
            key_bases = [key_bases]
        bases = set(key_bases)
        return [key for key in x if key.split("---")[0] in bases]

    def _augmentations_dict(self) -> Dict:
        """
        Get the mapping from augmentation names to functions
        """

        return {
            "mirror": self._mirror,
            "shift": self._shift,
            "add_noise": self._add_noise,
            "zoom": self._zoom,
            "rotate": self._rotate,
            "mask": self._mask_joints,
            "switch": self._switch,
        }

    def _default_augmentations(self) -> List:
        """
        Get the list of default augmentation names
        """

        return ["mirror", "shift", "add_noise"]

    def _get_bodyparts(self, shape: Tuple) -> None:
        """
        Set the number of bodyparts from a (batch, features, frames) shape if it is not known

        Assumes the features axis holds `dim` coordinates per bodypart.
        """

        if self.n_bodyparts is None:
            _, n_features, _ = shape
            self.n_bodyparts = n_features // self.dim
A transformer that augments the output of the Kinematic feature extractor
The available augmentations are 'rotate', 'mirror', 'shift', 'add_noise', 'zoom', 'mask' and 'switch'.
def __init__(
    self,
    model_name: str,
    augmentations: List = None,
    use_default_augmentations: bool = False,
    rotation_limits: List = None,
    mirror_dim: Set = None,
    noise_std: float = 0.05,
    zoom_limits: List = None,
    masking_probability: float = 0.05,
    dim: int = 2,
    graph_features: bool = False,
    bodyparts_order: List = None,
    canvas_shape: List = None,
    move_around_image_center: bool = True,
    **kwargs,
) -> None:
    """
    Parameters
    ----------
    model_name : str
        the name of the model (passed on to the base transformer class)
    augmentations : list, optional
        list of augmentation names to use ("rotate", "mirror", "shift", "add_noise", "zoom",
        "mask", "switch")
    use_default_augmentations : bool, default False
        if `True` and augmentations are not passed, default augmentations will be applied;
        otherwise no augmentations
    rotation_limits : list, default [-pi/2, pi/2]
        list of float rotation angle limits (`[low, high]`, or
        `[[low_x, high_x], [low_y, high_y], [low_z, high_z]]` for 3D data; a single
        `[low, high]` pair given for 3D data is used for the third angle only, the other two
        ranges are set to `[0, 0]`)
    mirror_dim : set, default {0}
        set of integer indices of dimensions that can be mirrored
    noise_std : float, default 0.05
        standard deviation of noise
    zoom_limits : list, default [0.5, 1.5]
        list of float zoom limits ([low, high])
    masking_probability : float, default 0.05
        the probability of masking a joint
    dim : int, default 2
        the dimensionality of the input data
    graph_features : bool, default False
        passed on to the base transformer class
    bodyparts_order : list, optional
        passed on to the base transformer class
    canvas_shape : list, default [1, 1]
        the canvas shape; the ratio `canvas_shape[1] / canvas_shape[0]` scales the vertical
        limits of the `'shift'` augmentation
    move_around_image_center : bool, default True
        if `True`, positions are rotated and mirrored around the origin; otherwise around
        their per-frame mean over bodyparts
    **kwargs : dict
        other parameters for the base transformer class
    """

    if augmentations is None:
        augmentations = []
    if canvas_shape is None:
        canvas_shape = [1, 1]
    self.dim = int(dim)

    self.offset = [0 for _ in range(self.dim)]
    self.scale = canvas_shape[1] / canvas_shape[0]
    self.image_center = move_around_image_center

    self.blank = 0  # the value that nan values are set to (shouldn't be changed in augmentations)
    super().__init__(
        model_name,
        augmentations,
        use_default_augmentations,
        graph_features=graph_features,
        bodyparts_order=bodyparts_order,
        **kwargs,
    )
    if rotation_limits is None:
        rotation_limits = [-np.pi / 2, np.pi / 2]
    if self.dim == 3 and len(rotation_limits) == 2:
        # expand a single [low, high] pair to one range per rotation angle
        rotation_limits = [[0, 0], [0, 0], rotation_limits]
    if mirror_dim is None:
        mirror_dim = [0]
    if zoom_limits is None:
        zoom_limits = [0.5, 1.5]
    self.rotation_limits = rotation_limits
    self.n_bodyparts = None
    self.mirror_dim = mirror_dim
    self.noise_std = noise_std
    self.zoom_limits = zoom_limits
    self.masking_probability = masking_probability
Parameters
augmentations : list, optional
list of augmentation names to use ("rotate", "mirror", "shift", "add_noise", "zoom", "mask", "switch")
use_default_augmentations : bool, default False
if True and augmentations are not passed, default augmentations will be applied; otherwise no augmentations
rotation_limits : list, default [-pi/2, pi/2]
list of float rotation angle limits ([low, high], or [[low_x, high_x], [low_y, high_y], [low_z, high_z]] for 3D data)
mirror_dim : set, default {0}
set of integer indices of dimensions that can be mirrored
noise_std : float, default 0.05
standard deviation of noise
zoom_limits : list, default [0.5, 1.5]
list of float zoom limits ([low, high])
masking_probability : float, default 0.05
the probability of masking a joint
dim : int, default 2
the dimensionality of the input data
**kwargs : dict
other parameters for the base transformer class