dlc2action.transformer.kinematic
Kinematic transformer
#
# Copyright 2020-present by A. Mathis Group and contributors. All rights reserved.
#
# This project and all its files are licensed under GNU AGPLv3 or later version.
# A copy is included in dlc2action/LICENSE.AGPL.
#
"""
Kinematic transformer
"""

from typing import Dict, Tuple, List, Set, Iterable
import torch
import numpy as np
from copy import copy, deepcopy
from random import getrandbits
from collections import defaultdict
from dlc2action.transformer.base_transformer import Transformer
from dlc2action.utils import rotation_matrix_2d, rotation_matrix_3d


class KinematicTransformer(Transformer):
    """
    A transformer that augments the output of the Kinematic feature extractor

    The available augmentations are `'rotate'`, `'mirror'`, `'shift'`, `'add_noise'`, `'zoom'`,
    `'mask'` and `'switch'`
    """

    # features that `_apply_augmentations` reshapes to (batch, -1, dim, frames) before augmenting
    _VECTOR_KEYS = (
        "coords",
        "speed_joints",
        "speed_direction",
        "speed_bones",
        "acc_bones",
        "bones",
        "coord_diff",
    )
    # features that are transformed around their own mean when `image_center` is `False`
    _POSITION_KEYS = ("coords", "center")
    # temporary stand-in for blank values while `_zoom` looks for the minimum coordinate
    # (presumably larger than any real coordinate -- TODO confirm the coordinate range)
    _ZOOM_FILL = 10

    def __init__(
        self,
        model_name: str,
        augmentations: List = None,
        use_default_augmentations: bool = False,
        rotation_limits: List = None,
        mirror_dim: Set = None,
        noise_std: float = 0.05,
        zoom_limits: List = None,
        masking_probability: float = 0.05,
        dim: int = 2,
        graph_features: bool = False,
        bodyparts_order: List = None,
        canvas_shape: List = None,
        move_around_image_center: bool = True,
        **kwargs,
    ) -> None:
        """
        Parameters
        ----------
        model_name : str
            the name of the model (passed to the base transformer class)
        augmentations : list, optional
            list of augmentation names to use ("rotate", "mirror", "shift", "add_noise", "zoom",
            "mask", "switch")
        use_default_augmentations : bool, default False
            if `True` and augmentations are not passed, default augmentations will be applied; otherwise no augmentations
        rotation_limits : list, default [-pi/2, pi/2]
            list of float rotation angle limits (`[low, high]`, or `[[low_x, high_x], [low_y, high_y], [low_z, high_z]]`
            for 3D data)
        mirror_dim : set, default [0]
            collection of integer indices of dimensions that can be mirrored
        noise_std : float, default 0.05
            standard deviation of noise
        zoom_limits : list, default [0.5, 1.5]
            list of float zoom limits ([low, high])
        masking_probability : float, default 0.05
            the probability of masking a joint
        dim : int, default 2
            the dimensionality of the input data
        graph_features : bool, default False
            passed to the base transformer class
        bodyparts_order : list, optional
            passed to the base transformer class
        canvas_shape : list, default [1, 1]
            the canvas size; `canvas_shape[1] / canvas_shape[0]` is used as the extent of the
            second axis in the "shift" augmentation (a string is evaluated to get the list)
        move_around_image_center : bool, default True
            if `False`, "coords" and "center" features are rotated and mirrored around their own
            mean instead of the origin
        **kwargs : dict
            other parameters for the base transformer class
        """

        if augmentations is None:
            augmentations = []
        if canvas_shape is None:
            canvas_shape = [1, 1]
        self.dim = int(dim)

        self.offset = [0 for _ in range(self.dim)]
        if isinstance(canvas_shape, str):
            # NOTE(review): eval executes arbitrary code; this is only safe while canvas_shape
            # comes from trusted configuration. Consider ast.literal_eval.
            canvas_shape = eval(canvas_shape)
        self.scale = canvas_shape[1] / canvas_shape[0]
        self.image_center = move_around_image_center

        self.blank = 0  # the value that nan values are set to (shouldn't be changed in augmentations)
        super().__init__(
            model_name,
            augmentations,
            use_default_augmentations,
            graph_features=graph_features,
            bodyparts_order=bodyparts_order,
            **kwargs,
        )
        if rotation_limits is None:
            rotation_limits = [-np.pi / 2, np.pi / 2]
        if mirror_dim is None:
            mirror_dim = [0]
        if zoom_limits is None:
            zoom_limits = [0.5, 1.5]
        self.rotation_limits = rotation_limits
        self.n_bodyparts = None
        self.mirror_dim = mirror_dim
        self.noise_std = noise_std
        self.zoom_limits = zoom_limits
        self.masking_probability = masking_probability

    @staticmethod
    def _present_dicts(main_input: Dict, ssl_inputs: List, ssl_targets: List) -> List:
        """
        Collect the main input and all SSL dictionaries that are not `None`
        """

        return [x for x in [main_input] + ssl_inputs + ssl_targets if x is not None]

    def _apply_augmentations(
        self,
        main_input: Dict,
        ssl_inputs: List = None,
        ssl_targets: List = None,
        augment: bool = False,
        ssl: bool = False,
    ) -> Tuple:
        """
        Reshape the features, run the base class augmentations and restore the original shapes

        Vector features are reshaped to `(batch, -1, dim, frames)` and "center" features to
        `(batch, 1, -1, frames)` so that the augmentations can index the spatial axis. The `ssl`
        parameter is accepted but not used here.
        """

        if ssl_targets is None:
            ssl_targets = [None]
        if ssl_inputs is None:
            ssl_inputs = [None]

        keys_main = self._get_keys(self._VECTOR_KEYS, main_input)
        final_shape = None
        if len(keys_main) > 0:
            first = main_input[keys_main[0]]
            final_shape = first.shape
            batch = final_shape[0]
            x_shape = first.reshape((batch, -1, self.dim, final_shape[-1])).shape
            self.n_bodyparts = x_shape[1]
            if self.dim == 3 and len(self.rotation_limits) == 2:
                # a single [low, high] pair is applied to the third rotation angle only
                self.rotation_limits = [[0, 0], [0, 0], self.rotation_limits]
            for x in self._present_dicts(main_input, ssl_inputs, ssl_targets):
                for key in self._get_keys(self._VECTOR_KEYS, x):
                    x[key] = x[key].reshape(x_shape)
        if len(self._get_keys("center", main_input)) > 0:
            for x in self._present_dicts(main_input, ssl_inputs, ssl_targets):
                for key in self._get_keys("center", x):
                    x[key] = x[key].reshape((x[key].shape[0], 1, -1, x[key].shape[-1]))
        main_input, ssl_inputs, ssl_targets = super()._apply_augmentations(
            main_input, ssl_inputs, ssl_targets, augment
        )
        if len(keys_main) > 0:
            for x in self._present_dicts(main_input, ssl_inputs, ssl_targets):
                for key in self._get_keys(self._VECTOR_KEYS, x):
                    x[key] = x[key].reshape(final_shape)
        if len(self._get_keys("center", main_input)) > 0:
            for x in self._present_dicts(main_input, ssl_inputs, ssl_targets):
                for key in self._get_keys("center", x):
                    x[key] = x[key].reshape((x[key].shape[0], -1, x[key].shape[-1]))
        return main_input, ssl_inputs, ssl_targets

    def _rotate(
        self, main_input: Dict, ssl_inputs: List, ssl_targets: List
    ) -> Tuple[Dict, List, List]:
        """
        Rotate the vector features and "center" features of the input by a random angle

        One rotation is sampled per batch element and applied to every dictionary. Blank values
        are left unchanged.
        """

        rotated_keys = self._VECTOR_KEYS + ("center",)
        # the same key list decides both whether to rotate and what to rotate
        keys = self._get_keys(rotated_keys, main_input)
        if len(keys) == 0:
            return main_input, ssl_inputs, ssl_targets
        reference = main_input[keys[0]]
        batch = reference.shape[0]
        if self.dim == 2:
            angles = torch.FloatTensor(batch).uniform_(*self.rotation_limits)
            R = rotation_matrix_2d(angles).to(reference.device)
        else:
            angles_x = torch.FloatTensor(batch).uniform_(*self.rotation_limits[0])
            angles_y = torch.FloatTensor(batch).uniform_(*self.rotation_limits[1])
            angles_z = torch.FloatTensor(batch).uniform_(*self.rotation_limits[2])
            R = rotation_matrix_3d(angles_x, angles_y, angles_z).to(reference.device)
        for x in self._present_dicts(main_input, ssl_inputs, ssl_targets):
            for key in self._get_keys(rotated_keys, x):
                mask = x[key] == self.blank
                if key.startswith(self._POSITION_KEYS) and not self.image_center:
                    # rotate positions around their mean over axis 1
                    offset = x[key].mean(1).unsqueeze(1)
                else:
                    offset = 0
                rotated = torch.einsum("abjd,aij->abid", x[key] - offset, R) + offset
                rotated[mask] = self.blank
                x[key] = rotated
        return main_input, ssl_inputs, ssl_targets

    def _mirror(
        self, main_input: Dict, ssl_inputs: List, ssl_targets: List
    ) -> Tuple[Dict, List, List]:
        """
        Mirror the vector features and "center" features of the input randomly

        Every dimension listed in `mirror_dim` is flipped with probability 0.5. Dimensions
        outside of `range(dim)` are ignored.
        """

        mirror = [
            i
            for i in range(self.dim)
            if i in self.mirror_dim and bool(getrandbits(1))
        ]
        if len(mirror) == 0:
            return main_input, ssl_inputs, ssl_targets
        mirrored_keys = self._VECTOR_KEYS + ("center",)
        for x in self._present_dicts(main_input, ssl_inputs, ssl_targets):
            for key in self._get_keys(mirrored_keys, x):
                mask = x[key] == self.blank
                y = x[key].clone()
                if not self.image_center and key.startswith(self._POSITION_KEYS):
                    # reflect positions around their mean over axis 1
                    mean = y.mean(1).unsqueeze(1)
                    for dim in mirror:
                        y[:, :, dim, :] = 2 * mean[:, :, dim, :] - y[:, :, dim, :]
                else:
                    for dim in mirror:
                        y[:, :, dim, :] *= -1
                y[mask] = self.blank
                x[key] = y
        return main_input, ssl_inputs, ssl_targets

    def _shift(
        self, main_input: Dict, ssl_inputs: List, ssl_targets: List
    ) -> Tuple[Dict, List, List]:
        """
        Shift the "coords" and "center" features of the input randomly

        The shift is sampled separately for every batch element so that the non-blank points of
        the main input stay within `[-0.5, 0.5]` on the first axis and
        `[-0.5 * scale, 0.5 * scale]` on the second one. Only the first two axes are shifted.
        """

        keys = self._get_keys(("coords", "center"), main_input)
        if len(keys) == 0:
            return main_input, ssl_inputs, ssl_targets
        reference = main_input[keys[0]]
        batch = reference.shape[0]
        dim = reference.shape[2]
        device = reference.device
        coords = torch.cat(
            [main_input[key].transpose(-1, -2).reshape(batch, -1, dim) for key in keys],
            dim=1,
        )
        # a point counts as blank when its first coordinate equals the blank value
        valid = coords[:, :, 0] != self.blank
        has_valid = valid.any(dim=-1)
        shifts = []
        for axis, half_extent in ((0, 0.5), (1, 0.5 * self.scale)):
            values = coords[:, :, axis]
            # per-sample extrema over non-blank points (indexing with the 2D mask would
            # flatten the batch and yield a single value for all samples)
            lowest = values.masked_fill(~valid, float("inf")).min(dim=-1)[0]
            highest = values.masked_fill(~valid, float("-inf")).max(dim=-1)[0]
            low = -half_extent - lowest
            high = half_extent - highest
            shift = low + torch.FloatTensor(batch).uniform_().to(device) * (high - low)
            # samples without any valid point have infinite bounds; leave them unshifted
            shift = torch.where(has_valid, shift, torch.zeros_like(shift))
            shifts.append(shift.unsqueeze(-1).unsqueeze(-1))
        del coords
        shift_x, shift_y = shifts
        for x in self._present_dicts(main_input, ssl_inputs, ssl_targets):
            for key in self._get_keys(("coords", "center"), x):
                mask = x[key] == self.blank
                y = x[key].clone()
                y[:, :, 0, :] += shift_x
                y[:, :, 1, :] += shift_y
                y[mask] = self.blank
                x[key] = y
        return main_input, ssl_inputs, ssl_targets

    def _add_noise(
        self, main_input: Dict, ssl_inputs: List, ssl_targets: List
    ) -> Tuple[Dict, List, List]:
        """
        Add normal noise with standard deviation `noise_std` to the features of the input

        Blank values are left unchanged.
        """

        noisy_keys = (
            "coords",
            "speed_joints",
            "speed_direction",
            "intra_distance",
            "acc_joints",
            "angle_speeds",
            "inter_distance",
            "speed_bones",
            "acc_bones",
            "bones",
            "coord_diff",
            "center",
        )
        for x in self._present_dicts(main_input, ssl_inputs, ssl_targets):
            for key in self._get_keys(noisy_keys, x):
                mask = x[key] == self.blank
                noise = torch.FloatTensor(x[key].shape).normal_(std=self.noise_std)
                x[key] = x[key] + noise.to(x[key].device)
                x[key][mask] = self.blank
        return main_input, ssl_inputs, ssl_targets

    def _zoom(
        self, main_input: Dict, ssl_inputs: List, ssl_targets: List
    ) -> Tuple[Dict, List, List]:
        """
        Add random zoom to the features of the input

        One zoom factor is sampled per batch element. Distance-like features are multiplied by
        it; "coords" features are scaled around the per-frame minimum of the non-blank
        coordinates. The input tensors are not modified in place.
        """

        reference = list(main_input.values())[0]
        batch = reference.shape[0]
        device = reference.device
        zoom = torch.FloatTensor(batch).uniform_(*self.zoom_limits).to(device)
        scaled_keys = (
            "speed_joints",
            "intra_distance",
            "acc_joints",
            "inter_distance",
            "speed_bones",
            "acc_bones",
            "bones",
            "coord_diff",
            "center",
            "speed_value",
        )
        for x in self._present_dicts(main_input, ssl_inputs, ssl_targets):
            for key in self._get_keys(scaled_keys, x):
                value = x[key]
                mask = value == self.blank
                # broadcast the per-sample factor over all remaining axes
                factor = zoom.reshape([value.shape[0]] + [1] * (value.dim() - 1))
                scaled = value * factor
                scaled[mask] = self.blank
                x[key] = scaled
            for key in self._get_keys("coords", x):
                value = x[key]
                mask = value == self.blank
                factor = zoom.reshape([value.shape[0]] + [1] * (value.dim() - 1))
                # replace blanks on a copy so that they do not win the minimum and the
                # caller's tensor is not polluted with the fill value
                filled = value.masked_fill(mask, self._ZOOM_FILL)
                corner = filled.min(dim=1, keepdim=True)[0]
                zoomed = (filled - corner) * factor + corner
                zoomed[mask] = self.blank
                x[key] = zoomed

        return main_input, ssl_inputs, ssl_targets

    def _mask_joints(
        self, main_input: Dict, ssl_inputs: List, ssl_targets: List
    ) -> Tuple[Dict, List, List]:
        """
        Mask joints randomly

        Every joint of every batch element is masked with probability `masking_probability`.
        Masked distances are set to the blank value; masked vector features and angle speeds are
        replaced with the mean over axis 1. Relies on `n_bodyparts`, which is set in
        `_apply_augmentations`.
        """

        key = list(main_input.keys())[0]
        batch, *_, frames = main_input[key].shape
        masked_joints = (
            torch.FloatTensor(batch, self.n_bodyparts).uniform_()
            < self.masking_probability
        )
        for x in self._present_dicts(main_input, ssl_inputs, ssl_targets):
            for key in self._get_keys(("intra_distance", "inter_distance"), x):
                # (batch, bodyparts, bodyparts, frames), True where the first joint is masked
                mask = (
                    masked_joints.repeat(self.n_bodyparts, frames, 1, 1)
                    .transpose(0, 2)
                    .transpose(1, 3)
                    .to(x[key].device)
                )
                indices = torch.triu_indices(self.n_bodyparts, self.n_bodyparts, 1)

                # unpack the upper-triangle values into a square matrix to mask rows and columns
                X = torch.zeros((batch, self.n_bodyparts, self.n_bodyparts, frames)).to(
                    x[key].device
                )
                X[:, indices[0], indices[1], :] = x[key]
                X[mask] = self.blank
                X[mask.transpose(1, 2)] = self.blank
                x[key] = X[:, indices[0], indices[1], :].reshape(batch, -1, frames)
            vector_keys = (
                "speed_joints",
                "speed_direction",
                "coords",
                "acc_joints",
                "speed_bones",
                "acc_bones",
                "bones",
                "coord_diff",
            )
            for key in self._get_keys(vector_keys, x):
                # (batch, bodyparts, dim, frames)
                mask = (
                    masked_joints.repeat(self.dim, frames, 1, 1)
                    .transpose(0, 2)
                    .transpose(1, 3)
                    .to(x[key].device)
                )
                x[key][mask] = (
                    x[key].mean(1).unsqueeze(1).repeat(1, x[key].shape[1], 1, 1)[mask]
                )
            for key in self._get_keys("angle_speeds", x):
                # (batch, bodyparts, frames)
                mask = (
                    masked_joints.repeat(frames, 1, 1)
                    .transpose(0, 1)
                    .transpose(1, 2)
                    .to(x[key].device)
                )
                x[key][mask] = (
                    x[key].mean(1).unsqueeze(1).repeat(1, x[key].shape[1], 1)[mask]
                )

        return main_input, ssl_inputs, ssl_targets

    def _switch(
        self, main_input: Dict, ssl_inputs: List, ssl_targets: List
    ) -> Tuple[Dict, List, List]:
        """
        Swap the features of pairs of individuals with probability 0.5

        Individuals are read from keys of the form `"{feature}---{individual}"`; keys whose
        individual part contains `"+"` are ignored. Only features that both individuals of a
        pair have are swapped.
        """

        if bool(getrandbits(1)):
            return main_input, ssl_inputs, ssl_targets
        individuals = set()
        ind_dict = defaultdict(set)
        for key in main_input:
            parts = key.split("---")
            if len(parts) != 2:
                continue
            feature, ind = parts
            if "+" in ind:
                continue
            individuals.add(ind)
            ind_dict[ind].add(feature)
        individuals = list(individuals)
        if len(individuals) < 2:
            return main_input, ssl_inputs, ssl_targets
        for x, y in zip(individuals[::2], individuals[1::2]):
            for key in ind_dict[x]:
                if key not in ind_dict[y]:
                    continue
                key_x = f"{key}---{x}"
                key_y = f"{key}---{y}"
                main_input[key_x], main_input[key_y] = (
                    main_input[key_y],
                    main_input[key_x],
                )
                for d in ssl_inputs + ssl_targets:
                    if d is None:
                        continue
                    if key_x in d and key_y in d:
                        d[key_x], d[key_y] = d[key_y], d[key_x]
        return main_input, ssl_inputs, ssl_targets

    def _get_keys(self, key_bases: Iterable, x: Dict) -> List:
        """
        Get the keys of x whose part before the first "---" equals one of the strings from key_bases

        A single string is accepted in place of an iterable of strings.
        """

        if isinstance(key_bases, str):
            key_bases = [key_bases]
        return [
            key
            for key in x
            if any(base == key.split("---")[0] for base in key_bases)
        ]

    def _augmentations_dict(self) -> Dict:
        """
        Get the mapping from augmentation names to functions
        """

        return {
            "mirror": self._mirror,
            "shift": self._shift,
            "add_noise": self._add_noise,
            "zoom": self._zoom,
            "rotate": self._rotate,
            "mask": self._mask_joints,
            "switch": self._switch,
        }

    def _default_augmentations(self) -> List:
        """
        Get the list of default augmentation names
        """

        return ["mirror", "shift", "add_noise"]

    def _get_bodyparts(self, shape: Tuple) -> None:
        """
        Set the number of bodyparts from the data if it is not known

        `shape` is unpacked as three values; the second one divided by `dim` is taken as the
        number of bodyparts.
        """

        if self.n_bodyparts is None:
            N, B, F = shape
            self.n_bodyparts = B // self.dim
class KinematicTransformer(Transformer):
    """
    A transformer that augments the output of the Kinematic feature extractor

    The available augmentations are `'rotate'`, `'mirror'`, `'shift'`, `'add_noise'`, `'zoom'`,
    `'mask'` and `'switch'`
    """

    # features that `_apply_augmentations` reshapes to (batch, -1, dim, frames) before augmenting
    _VECTOR_KEYS = (
        "coords",
        "speed_joints",
        "speed_direction",
        "speed_bones",
        "acc_bones",
        "bones",
        "coord_diff",
    )
    # features that are transformed around their own mean when `image_center` is `False`
    _POSITION_KEYS = ("coords", "center")
    # temporary stand-in for blank values while `_zoom` looks for the minimum coordinate
    # (presumably larger than any real coordinate -- TODO confirm the coordinate range)
    _ZOOM_FILL = 10

    def __init__(
        self,
        model_name: str,
        augmentations: List = None,
        use_default_augmentations: bool = False,
        rotation_limits: List = None,
        mirror_dim: Set = None,
        noise_std: float = 0.05,
        zoom_limits: List = None,
        masking_probability: float = 0.05,
        dim: int = 2,
        graph_features: bool = False,
        bodyparts_order: List = None,
        canvas_shape: List = None,
        move_around_image_center: bool = True,
        **kwargs,
    ) -> None:
        """
        Parameters
        ----------
        model_name : str
            the name of the model (passed to the base transformer class)
        augmentations : list, optional
            list of augmentation names to use ("rotate", "mirror", "shift", "add_noise", "zoom",
            "mask", "switch")
        use_default_augmentations : bool, default False
            if `True` and augmentations are not passed, default augmentations will be applied; otherwise no augmentations
        rotation_limits : list, default [-pi/2, pi/2]
            list of float rotation angle limits (`[low, high]`, or `[[low_x, high_x], [low_y, high_y], [low_z, high_z]]`
            for 3D data)
        mirror_dim : set, default [0]
            collection of integer indices of dimensions that can be mirrored
        noise_std : float, default 0.05
            standard deviation of noise
        zoom_limits : list, default [0.5, 1.5]
            list of float zoom limits ([low, high])
        masking_probability : float, default 0.05
            the probability of masking a joint
        dim : int, default 2
            the dimensionality of the input data
        graph_features : bool, default False
            passed to the base transformer class
        bodyparts_order : list, optional
            passed to the base transformer class
        canvas_shape : list, default [1, 1]
            the canvas size; `canvas_shape[1] / canvas_shape[0]` is used as the extent of the
            second axis in the "shift" augmentation (a string is evaluated to get the list)
        move_around_image_center : bool, default True
            if `False`, "coords" and "center" features are rotated and mirrored around their own
            mean instead of the origin
        **kwargs : dict
            other parameters for the base transformer class
        """

        if augmentations is None:
            augmentations = []
        if canvas_shape is None:
            canvas_shape = [1, 1]
        self.dim = int(dim)

        self.offset = [0 for _ in range(self.dim)]
        if isinstance(canvas_shape, str):
            # NOTE(review): eval executes arbitrary code; this is only safe while canvas_shape
            # comes from trusted configuration. Consider ast.literal_eval.
            canvas_shape = eval(canvas_shape)
        self.scale = canvas_shape[1] / canvas_shape[0]
        self.image_center = move_around_image_center

        self.blank = 0  # the value that nan values are set to (shouldn't be changed in augmentations)
        super().__init__(
            model_name,
            augmentations,
            use_default_augmentations,
            graph_features=graph_features,
            bodyparts_order=bodyparts_order,
            **kwargs,
        )
        if rotation_limits is None:
            rotation_limits = [-np.pi / 2, np.pi / 2]
        if mirror_dim is None:
            mirror_dim = [0]
        if zoom_limits is None:
            zoom_limits = [0.5, 1.5]
        self.rotation_limits = rotation_limits
        self.n_bodyparts = None
        self.mirror_dim = mirror_dim
        self.noise_std = noise_std
        self.zoom_limits = zoom_limits
        self.masking_probability = masking_probability

    @staticmethod
    def _present_dicts(main_input: Dict, ssl_inputs: List, ssl_targets: List) -> List:
        """
        Collect the main input and all SSL dictionaries that are not `None`
        """

        return [x for x in [main_input] + ssl_inputs + ssl_targets if x is not None]

    def _apply_augmentations(
        self,
        main_input: Dict,
        ssl_inputs: List = None,
        ssl_targets: List = None,
        augment: bool = False,
        ssl: bool = False,
    ) -> Tuple:
        """
        Reshape the features, run the base class augmentations and restore the original shapes

        Vector features are reshaped to `(batch, -1, dim, frames)` and "center" features to
        `(batch, 1, -1, frames)` so that the augmentations can index the spatial axis. The `ssl`
        parameter is accepted but not used here.
        """

        if ssl_targets is None:
            ssl_targets = [None]
        if ssl_inputs is None:
            ssl_inputs = [None]

        keys_main = self._get_keys(self._VECTOR_KEYS, main_input)
        final_shape = None
        if len(keys_main) > 0:
            first = main_input[keys_main[0]]
            final_shape = first.shape
            batch = final_shape[0]
            x_shape = first.reshape((batch, -1, self.dim, final_shape[-1])).shape
            self.n_bodyparts = x_shape[1]
            if self.dim == 3 and len(self.rotation_limits) == 2:
                # a single [low, high] pair is applied to the third rotation angle only
                self.rotation_limits = [[0, 0], [0, 0], self.rotation_limits]
            for x in self._present_dicts(main_input, ssl_inputs, ssl_targets):
                for key in self._get_keys(self._VECTOR_KEYS, x):
                    x[key] = x[key].reshape(x_shape)
        if len(self._get_keys("center", main_input)) > 0:
            for x in self._present_dicts(main_input, ssl_inputs, ssl_targets):
                for key in self._get_keys("center", x):
                    x[key] = x[key].reshape((x[key].shape[0], 1, -1, x[key].shape[-1]))
        main_input, ssl_inputs, ssl_targets = super()._apply_augmentations(
            main_input, ssl_inputs, ssl_targets, augment
        )
        if len(keys_main) > 0:
            for x in self._present_dicts(main_input, ssl_inputs, ssl_targets):
                for key in self._get_keys(self._VECTOR_KEYS, x):
                    x[key] = x[key].reshape(final_shape)
        if len(self._get_keys("center", main_input)) > 0:
            for x in self._present_dicts(main_input, ssl_inputs, ssl_targets):
                for key in self._get_keys("center", x):
                    x[key] = x[key].reshape((x[key].shape[0], -1, x[key].shape[-1]))
        return main_input, ssl_inputs, ssl_targets

    def _rotate(
        self, main_input: Dict, ssl_inputs: List, ssl_targets: List
    ) -> Tuple[Dict, List, List]:
        """
        Rotate the vector features and "center" features of the input by a random angle

        One rotation is sampled per batch element and applied to every dictionary. Blank values
        are left unchanged.
        """

        rotated_keys = self._VECTOR_KEYS + ("center",)
        # the same key list decides both whether to rotate and what to rotate
        keys = self._get_keys(rotated_keys, main_input)
        if len(keys) == 0:
            return main_input, ssl_inputs, ssl_targets
        reference = main_input[keys[0]]
        batch = reference.shape[0]
        if self.dim == 2:
            angles = torch.FloatTensor(batch).uniform_(*self.rotation_limits)
            R = rotation_matrix_2d(angles).to(reference.device)
        else:
            angles_x = torch.FloatTensor(batch).uniform_(*self.rotation_limits[0])
            angles_y = torch.FloatTensor(batch).uniform_(*self.rotation_limits[1])
            angles_z = torch.FloatTensor(batch).uniform_(*self.rotation_limits[2])
            R = rotation_matrix_3d(angles_x, angles_y, angles_z).to(reference.device)
        for x in self._present_dicts(main_input, ssl_inputs, ssl_targets):
            for key in self._get_keys(rotated_keys, x):
                mask = x[key] == self.blank
                if key.startswith(self._POSITION_KEYS) and not self.image_center:
                    # rotate positions around their mean over axis 1
                    offset = x[key].mean(1).unsqueeze(1)
                else:
                    offset = 0
                rotated = torch.einsum("abjd,aij->abid", x[key] - offset, R) + offset
                rotated[mask] = self.blank
                x[key] = rotated
        return main_input, ssl_inputs, ssl_targets

    def _mirror(
        self, main_input: Dict, ssl_inputs: List, ssl_targets: List
    ) -> Tuple[Dict, List, List]:
        """
        Mirror the vector features and "center" features of the input randomly

        Every dimension listed in `mirror_dim` is flipped with probability 0.5. Dimensions
        outside of `range(dim)` are ignored.
        """

        mirror = [
            i
            for i in range(self.dim)
            if i in self.mirror_dim and bool(getrandbits(1))
        ]
        if len(mirror) == 0:
            return main_input, ssl_inputs, ssl_targets
        mirrored_keys = self._VECTOR_KEYS + ("center",)
        for x in self._present_dicts(main_input, ssl_inputs, ssl_targets):
            for key in self._get_keys(mirrored_keys, x):
                mask = x[key] == self.blank
                y = x[key].clone()
                if not self.image_center and key.startswith(self._POSITION_KEYS):
                    # reflect positions around their mean over axis 1
                    mean = y.mean(1).unsqueeze(1)
                    for dim in mirror:
                        y[:, :, dim, :] = 2 * mean[:, :, dim, :] - y[:, :, dim, :]
                else:
                    for dim in mirror:
                        y[:, :, dim, :] *= -1
                y[mask] = self.blank
                x[key] = y
        return main_input, ssl_inputs, ssl_targets

    def _shift(
        self, main_input: Dict, ssl_inputs: List, ssl_targets: List
    ) -> Tuple[Dict, List, List]:
        """
        Shift the "coords" and "center" features of the input randomly

        The shift is sampled separately for every batch element so that the non-blank points of
        the main input stay within `[-0.5, 0.5]` on the first axis and
        `[-0.5 * scale, 0.5 * scale]` on the second one. Only the first two axes are shifted.
        """

        keys = self._get_keys(("coords", "center"), main_input)
        if len(keys) == 0:
            return main_input, ssl_inputs, ssl_targets
        reference = main_input[keys[0]]
        batch = reference.shape[0]
        dim = reference.shape[2]
        device = reference.device
        coords = torch.cat(
            [main_input[key].transpose(-1, -2).reshape(batch, -1, dim) for key in keys],
            dim=1,
        )
        # a point counts as blank when its first coordinate equals the blank value
        valid = coords[:, :, 0] != self.blank
        has_valid = valid.any(dim=-1)
        shifts = []
        for axis, half_extent in ((0, 0.5), (1, 0.5 * self.scale)):
            values = coords[:, :, axis]
            # per-sample extrema over non-blank points (indexing with the 2D mask would
            # flatten the batch and yield a single value for all samples)
            lowest = values.masked_fill(~valid, float("inf")).min(dim=-1)[0]
            highest = values.masked_fill(~valid, float("-inf")).max(dim=-1)[0]
            low = -half_extent - lowest
            high = half_extent - highest
            shift = low + torch.FloatTensor(batch).uniform_().to(device) * (high - low)
            # samples without any valid point have infinite bounds; leave them unshifted
            shift = torch.where(has_valid, shift, torch.zeros_like(shift))
            shifts.append(shift.unsqueeze(-1).unsqueeze(-1))
        del coords
        shift_x, shift_y = shifts
        for x in self._present_dicts(main_input, ssl_inputs, ssl_targets):
            for key in self._get_keys(("coords", "center"), x):
                mask = x[key] == self.blank
                y = x[key].clone()
                y[:, :, 0, :] += shift_x
                y[:, :, 1, :] += shift_y
                y[mask] = self.blank
                x[key] = y
        return main_input, ssl_inputs, ssl_targets

    def _add_noise(
        self, main_input: Dict, ssl_inputs: List, ssl_targets: List
    ) -> Tuple[Dict, List, List]:
        """
        Add normal noise with standard deviation `noise_std` to the features of the input

        Blank values are left unchanged.
        """

        noisy_keys = (
            "coords",
            "speed_joints",
            "speed_direction",
            "intra_distance",
            "acc_joints",
            "angle_speeds",
            "inter_distance",
            "speed_bones",
            "acc_bones",
            "bones",
            "coord_diff",
            "center",
        )
        for x in self._present_dicts(main_input, ssl_inputs, ssl_targets):
            for key in self._get_keys(noisy_keys, x):
                mask = x[key] == self.blank
                noise = torch.FloatTensor(x[key].shape).normal_(std=self.noise_std)
                x[key] = x[key] + noise.to(x[key].device)
                x[key][mask] = self.blank
        return main_input, ssl_inputs, ssl_targets

    def _zoom(
        self, main_input: Dict, ssl_inputs: List, ssl_targets: List
    ) -> Tuple[Dict, List, List]:
        """
        Add random zoom to the features of the input

        One zoom factor is sampled per batch element. Distance-like features are multiplied by
        it; "coords" features are scaled around the per-frame minimum of the non-blank
        coordinates. The input tensors are not modified in place.
        """

        reference = list(main_input.values())[0]
        batch = reference.shape[0]
        device = reference.device
        zoom = torch.FloatTensor(batch).uniform_(*self.zoom_limits).to(device)
        scaled_keys = (
            "speed_joints",
            "intra_distance",
            "acc_joints",
            "inter_distance",
            "speed_bones",
            "acc_bones",
            "bones",
            "coord_diff",
            "center",
            "speed_value",
        )
        for x in self._present_dicts(main_input, ssl_inputs, ssl_targets):
            for key in self._get_keys(scaled_keys, x):
                value = x[key]
                mask = value == self.blank
                # broadcast the per-sample factor over all remaining axes
                factor = zoom.reshape([value.shape[0]] + [1] * (value.dim() - 1))
                scaled = value * factor
                scaled[mask] = self.blank
                x[key] = scaled
            for key in self._get_keys("coords", x):
                value = x[key]
                mask = value == self.blank
                factor = zoom.reshape([value.shape[0]] + [1] * (value.dim() - 1))
                # replace blanks on a copy so that they do not win the minimum and the
                # caller's tensor is not polluted with the fill value
                filled = value.masked_fill(mask, self._ZOOM_FILL)
                corner = filled.min(dim=1, keepdim=True)[0]
                zoomed = (filled - corner) * factor + corner
                zoomed[mask] = self.blank
                x[key] = zoomed

        return main_input, ssl_inputs, ssl_targets

    def _mask_joints(
        self, main_input: Dict, ssl_inputs: List, ssl_targets: List
    ) -> Tuple[Dict, List, List]:
        """
        Mask joints randomly

        Every joint of every batch element is masked with probability `masking_probability`.
        Masked distances are set to the blank value; masked vector features and angle speeds are
        replaced with the mean over axis 1. Relies on `n_bodyparts`, which is set in
        `_apply_augmentations`.
        """

        key = list(main_input.keys())[0]
        batch, *_, frames = main_input[key].shape
        masked_joints = (
            torch.FloatTensor(batch, self.n_bodyparts).uniform_()
            < self.masking_probability
        )
        for x in self._present_dicts(main_input, ssl_inputs, ssl_targets):
            for key in self._get_keys(("intra_distance", "inter_distance"), x):
                # (batch, bodyparts, bodyparts, frames), True where the first joint is masked
                mask = (
                    masked_joints.repeat(self.n_bodyparts, frames, 1, 1)
                    .transpose(0, 2)
                    .transpose(1, 3)
                    .to(x[key].device)
                )
                indices = torch.triu_indices(self.n_bodyparts, self.n_bodyparts, 1)

                # unpack the upper-triangle values into a square matrix to mask rows and columns
                X = torch.zeros((batch, self.n_bodyparts, self.n_bodyparts, frames)).to(
                    x[key].device
                )
                X[:, indices[0], indices[1], :] = x[key]
                X[mask] = self.blank
                X[mask.transpose(1, 2)] = self.blank
                x[key] = X[:, indices[0], indices[1], :].reshape(batch, -1, frames)
            vector_keys = (
                "speed_joints",
                "speed_direction",
                "coords",
                "acc_joints",
                "speed_bones",
                "acc_bones",
                "bones",
                "coord_diff",
            )
            for key in self._get_keys(vector_keys, x):
                # (batch, bodyparts, dim, frames)
                mask = (
                    masked_joints.repeat(self.dim, frames, 1, 1)
                    .transpose(0, 2)
                    .transpose(1, 3)
                    .to(x[key].device)
                )
                x[key][mask] = (
                    x[key].mean(1).unsqueeze(1).repeat(1, x[key].shape[1], 1, 1)[mask]
                )
            for key in self._get_keys("angle_speeds", x):
                # (batch, bodyparts, frames)
                mask = (
                    masked_joints.repeat(frames, 1, 1)
                    .transpose(0, 1)
                    .transpose(1, 2)
                    .to(x[key].device)
                )
                x[key][mask] = (
                    x[key].mean(1).unsqueeze(1).repeat(1, x[key].shape[1], 1)[mask]
                )

        return main_input, ssl_inputs, ssl_targets

    def _switch(
        self, main_input: Dict, ssl_inputs: List, ssl_targets: List
    ) -> Tuple[Dict, List, List]:
        """
        Swap the features of pairs of individuals with probability 0.5

        Individuals are read from keys of the form `"{feature}---{individual}"`; keys whose
        individual part contains `"+"` are ignored. Only features that both individuals of a
        pair have are swapped.
        """

        if bool(getrandbits(1)):
            return main_input, ssl_inputs, ssl_targets
        individuals = set()
        ind_dict = defaultdict(set)
        for key in main_input:
            parts = key.split("---")
            if len(parts) != 2:
                continue
            feature, ind = parts
            if "+" in ind:
                continue
            individuals.add(ind)
            ind_dict[ind].add(feature)
        individuals = list(individuals)
        if len(individuals) < 2:
            return main_input, ssl_inputs, ssl_targets
        for x, y in zip(individuals[::2], individuals[1::2]):
            for key in ind_dict[x]:
                if key not in ind_dict[y]:
                    continue
                key_x = f"{key}---{x}"
                key_y = f"{key}---{y}"
                main_input[key_x], main_input[key_y] = (
                    main_input[key_y],
                    main_input[key_x],
                )
                for d in ssl_inputs + ssl_targets:
                    if d is None:
                        continue
                    if key_x in d and key_y in d:
                        d[key_x], d[key_y] = d[key_y], d[key_x]
        return main_input, ssl_inputs, ssl_targets

    def _get_keys(self, key_bases: Iterable, x: Dict) -> List:
        """
        Get the keys of x whose part before the first "---" equals one of the strings from key_bases

        A single string is accepted in place of an iterable of strings.
        """

        if isinstance(key_bases, str):
            key_bases = [key_bases]
        return [
            key
            for key in x
            if any(base == key.split("---")[0] for base in key_bases)
        ]

    def _augmentations_dict(self) -> Dict:
        """
        Get the mapping from augmentation names to functions
        """

        return {
            "mirror": self._mirror,
            "shift": self._shift,
            "add_noise": self._add_noise,
            "zoom": self._zoom,
            "rotate": self._rotate,
            "mask": self._mask_joints,
            "switch": self._switch,
        }

    def _default_augmentations(self) -> List:
        """
        Get the list of default augmentation names
        """

        return ["mirror", "shift", "add_noise"]

    def _get_bodyparts(self, shape: Tuple) -> None:
        """
        Set the number of bodyparts from the data if it is not known

        `shape` is unpacked as three values; the second one divided by `dim` is taken as the
        number of bodyparts.
        """

        if self.n_bodyparts is None:
            N, B, F = shape
            self.n_bodyparts = B // self.dim
A transformer that augments the output of the Kinematic feature extractor
The available augmentations are 'rotate', 'mirror', 'shift', 'add_noise', 'zoom', 'mask' and 'switch'
def __init__(
    self,
    model_name: str,
    augmentations: List = None,
    use_default_augmentations: bool = False,
    rotation_limits: List = None,
    mirror_dim: Set = None,
    noise_std: float = 0.05,
    zoom_limits: List = None,
    masking_probability: float = 0.05,
    dim: int = 2,
    graph_features: bool = False,
    bodyparts_order: List = None,
    canvas_shape: List = None,
    move_around_image_center: bool = True,
    **kwargs,
) -> None:
    """
    Parameters
    ----------
    model_name : str
        the name of the model (passed on to the base transformer class)
    augmentations : list, optional
        list of augmentation names to use ("rotate", "mirror", "shift", "add_noise", "zoom",
        "mask", "switch")
    use_default_augmentations : bool, default False
        if `True` and augmentations are not passed, default augmentations will be applied; otherwise no augmentations
    rotation_limits : list, default [-pi/2, pi/2]
        list of float rotation angle limits (`[low, high]`, or `[[low_x, high_x], [low_y, high_y], [low_z, high_z]]`
        for 3D data)
    mirror_dim : set, default {0}
        set of integer indices of dimensions that can be mirrored
    noise_std : float, default 0.05
        standard deviation of noise
    zoom_limits : list, default [0.5, 1.5]
        list of float zoom limits (`[low, high]`)
    masking_probability : float, default 0.05
        the probability of masking a joint
    dim : int, default 2
        the dimensionality of the input data
    graph_features : bool, default False
        passed on to the base transformer class
    bodyparts_order : list, optional
        passed on to the base transformer class
    canvas_shape : list, default [1, 1]
        the canvas shape (a list or its string representation); only the ratio
        `canvas_shape[1] / canvas_shape[0]` is kept, as `self.scale`
    move_around_image_center : bool, default True
        stored as `self.image_center` (presumably selects the image center as the pivot
        of the geometric augmentations -- confirm against their implementation)
    **kwargs : dict
        other parameters for the base transformer class
    """

    # local import: `ast` is not among the module-level imports
    from ast import literal_eval

    if augmentations is None:
        augmentations = []
    if canvas_shape is None:
        canvas_shape = [1, 1]
    self.dim = int(dim)

    self.offset = [0 for _ in range(self.dim)]
    if isinstance(canvas_shape, str):
        # NOTE(review): this used to be a bare `eval`, which executes arbitrary code
        # coming from a configuration string; `literal_eval` only accepts Python
        # literals such as "[640, 480]".
        canvas_shape = literal_eval(canvas_shape)
    self.scale = canvas_shape[1] / canvas_shape[0]
    self.image_center = move_around_image_center

    self.blank = 0  # the value that nan values are set to (shouldn't be changed in augmentations)
    super().__init__(
        model_name,
        augmentations,
        use_default_augmentations,
        graph_features=graph_features,
        bodyparts_order=bodyparts_order,
        **kwargs,
    )
    if rotation_limits is None:
        rotation_limits = [-np.pi / 2, np.pi / 2]
    if mirror_dim is None:
        mirror_dim = [0]
    if zoom_limits is None:
        zoom_limits = [0.5, 1.5]
    self.rotation_limits = rotation_limits
    # unknown until the first batch is seen
    self.n_bodyparts = None
    self.mirror_dim = mirror_dim
    self.noise_std = noise_std
    self.zoom_limits = zoom_limits
    self.masking_probability = masking_probability
Parameters
augmentations : list, optional
list of augmentation names to use ("rotate", "mirror", "shift", "add_noise", "zoom", "mask", "switch")
use_default_augmentations : bool, default False
if True and augmentations are not passed, default augmentations will be applied; otherwise no augmentations
rotation_limits : list, default [-pi/2, pi/2]
list of float rotation angle limits (`[low, high]`, or `[[low_x, high_x], [low_y, high_y], [low_z, high_z]]`
for 3D data)
mirror_dim : set, default {0}
set of integer indices of dimensions that can be mirrored
noise_std : float, default 0.05
standard deviation of noise
zoom_limits : list, default [0.5, 1.5]
list of float zoom limits ([low, high])
masking_probability : float, default 0.05
the probability of masking a joint
dim : int, default 2
the dimensionality of the input data
**kwargs : dict
other parameters for the base transformer class