dlc2action.transformer.kinematic
Kinematic transformer
#
# Copyright 2020-present by A. Mathis Group and contributors. All rights reserved.
#
# This project and all its files are licensed under GNU AGPLv3 or later version.
# A copy is included in dlc2action/LICENSE.AGPL.
#
"""
Kinematic transformer
"""

from typing import Dict, Tuple, List, Set, Iterable
import torch
import numpy as np
from copy import copy, deepcopy
from random import getrandbits
from collections import defaultdict
from dlc2action.transformer.base_transformer import Transformer
from dlc2action.utils import rotation_matrix_2d, rotation_matrix_3d


class KinematicTransformer(Transformer):
    """
    A transformer that augments the output of the Kinematic feature extractor

    The available augmentations are `'rotate'`, `'mirror'`, `'shift'`, `'add_noise'`, `'zoom'`,
    `'mask'` and `'switch'`
    """

    # feature name bases that are reshaped to (batch, -1, dim, frames) while augmentations run
    _SPATIAL_KEYS = (
        "coords",
        "speed_joints",
        "speed_direction",
        "speed_bones",
        "acc_bones",
        "bones",
        "coord_diff",
    )

    def __init__(
        self,
        model_name: str,
        augmentations: List = None,
        use_default_augmentations: bool = False,
        rotation_limits: List = None,
        mirror_dim: Set = None,
        noise_std: float = 0.05,
        zoom_limits: List = None,
        masking_probability: float = 0.05,
        dim: int = 2,
        graph_features: bool = False,
        bodyparts_order: List = None,
        canvas_shape: List = None,
        move_around_image_center: bool = True,
        **kwargs,
    ) -> None:
        """
        Parameters
        ----------
        model_name : str
            the name of the model (passed to the base transformer class)
        augmentations : list, optional
            list of augmentation names to use ("rotate", "mirror", "shift", "add_noise", "zoom", "mask",
            "switch")
        use_default_augmentations : bool, default False
            if `True` and augmentations are not passed, default augmentations will be applied; otherwise
            no augmentations
        rotation_limits : list, default [-pi/2, pi/2]
            list of float rotation angle limits (`[low, high]`, or
            `[[low_x, high_x], [low_y, high_y], [low_z, high_z]]` for 3D data)
        mirror_dim : set, default [0]
            collection of integer indices of dimensions that can be mirrored
        noise_std : float, default 0.05
            standard deviation of noise
        zoom_limits : list, default [0.5, 1.5]
            list of float zoom limits ([low, high])
        masking_probability : float, default 0.05
            the probability of masking a joint
        dim : int, default 2
            the dimensionality of the input data
        graph_features : bool, default False
            passed to the base transformer class
        bodyparts_order : list, optional
            passed to the base transformer class
        canvas_shape : list | str, default [1, 1]
            the canvas size; only the ratio `canvas_shape[1] / canvas_shape[0]` is used (it bounds the
            vertical shift); a string value is evaluated as a Python expression
        move_around_image_center : bool, default True
            if `True`, "coords" and "center" features are rotated and mirrored around the origin;
            otherwise around their mean over the second axis
        **kwargs : dict
            other parameters for the base transformer class
        """

        if augmentations is None:
            augmentations = []
        if canvas_shape is None:
            canvas_shape = [1, 1]
        self.dim = int(dim)

        self.offset = [0 for _ in range(self.dim)]
        if isinstance(canvas_shape, str):
            # NOTE(review): `eval` executes arbitrary code; this is only safe while the string
            # comes from a trusted configuration source
            canvas_shape = eval(canvas_shape)
        self.scale = canvas_shape[1] / canvas_shape[0]
        self.image_center = move_around_image_center

        self.blank = 0  # the value that nan values are set to (shouldn't be changed in augmentations)
        # the attributes above are set before the base constructor runs (original order is kept)
        super().__init__(
            model_name,
            augmentations,
            use_default_augmentations,
            graph_features=graph_features,
            bodyparts_order=bodyparts_order,
            **kwargs,
        )
        if rotation_limits is None:
            rotation_limits = [-np.pi / 2, np.pi / 2]
        if mirror_dim is None:
            mirror_dim = [0]
        if zoom_limits is None:
            zoom_limits = [0.5, 1.5]
        self.rotation_limits = rotation_limits
        self.n_bodyparts = None
        self.mirror_dim = mirror_dim
        self.noise_std = noise_std
        self.zoom_limits = zoom_limits
        self.masking_probability = masking_probability

    def _reshape_features(self, dicts: List, key_bases: Iterable, get_shape) -> None:
        """
        Reshape every feature matching `key_bases` in each non-`None` dictionary of `dicts`

        `get_shape` receives the current tensor and returns the shape it should be reshaped to.
        """

        for x in dicts:
            if x is None:
                continue
            for key in self._get_keys(key_bases, x):
                x[key] = x[key].reshape(get_shape(x[key]))

    def _apply_augmentations(
        self,
        main_input: Dict,
        ssl_inputs: List = None,
        ssl_targets: List = None,
        augment: bool = False,
        ssl: bool = False,
    ) -> Tuple:
        """
        Reshape the features, run the base class augmentations and restore the original shapes

        Spatial features are reshaped to `(batch, -1, dim, frames)` (using the shape of the first
        spatial feature of the main input) and "center" features to `(batch, 1, -1, frames)` before
        the base class `_apply_augmentations` is called; the shapes are restored afterwards.
        The `ssl` argument is accepted but not used here.
        """

        if ssl_targets is None:
            ssl_targets = [None]
        if ssl_inputs is None:
            ssl_inputs = [None]

        keys_main = self._get_keys(self._SPATIAL_KEYS, main_input)
        final_shape = None
        if len(keys_main) > 0:
            reference = main_input[keys_main[0]]
            final_shape = reference.shape
            x_shape = reference.reshape(
                (final_shape[0], -1, self.dim, final_shape[-1])
            ).shape
            self.n_bodyparts = x_shape[1]
            if self.dim == 3 and len(self.rotation_limits) == 2:
                # a single [low, high] pair for 3D data is applied to the third angle only
                self.rotation_limits = [[0, 0], [0, 0], self.rotation_limits]
            self._reshape_features(
                [main_input] + ssl_inputs + ssl_targets,
                self._SPATIAL_KEYS,
                lambda tensor: x_shape,
            )
        if len(self._get_keys("center", main_input)) > 0:
            self._reshape_features(
                [main_input] + ssl_inputs + ssl_targets,
                "center",
                lambda tensor: (tensor.shape[0], 1, -1, tensor.shape[-1]),
            )
        main_input, ssl_inputs, ssl_targets = super()._apply_augmentations(
            main_input, ssl_inputs, ssl_targets, augment
        )
        if final_shape is not None:
            self._reshape_features(
                [main_input] + ssl_inputs + ssl_targets,
                self._SPATIAL_KEYS,
                lambda tensor: final_shape,
            )
        if len(self._get_keys("center", main_input)) > 0:
            self._reshape_features(
                [main_input] + ssl_inputs + ssl_targets,
                "center",
                lambda tensor: (tensor.shape[0], -1, tensor.shape[-1]),
            )
        return main_input, ssl_inputs, ssl_targets

    def _rotate(
        self, main_input: Dict, ssl_inputs: List, ssl_targets: List
    ) -> Tuple[Dict, List, List]:
        """
        Rotate the spatial features of the input by a random angle (one angle set per sample)
        """

        keys = self._get_keys(
            (
                "coords",
                "bones",
                "coord_diff",
                "center",
                "speed_joints",
                "speed_direction",
            ),
            main_input,
        )
        if len(keys) == 0:
            return main_input, ssl_inputs, ssl_targets
        batch = main_input[keys[0]].shape[0]
        device = main_input[keys[0]].device
        if self.dim == 2:
            angles = torch.FloatTensor(batch).uniform_(*self.rotation_limits)
            R = rotation_matrix_2d(angles).to(device)
        else:
            angles_x = torch.FloatTensor(batch).uniform_(*self.rotation_limits[0])
            angles_y = torch.FloatTensor(batch).uniform_(*self.rotation_limits[1])
            angles_z = torch.FloatTensor(batch).uniform_(*self.rotation_limits[2])
            R = rotation_matrix_3d(angles_x, angles_y, angles_z).to(device)
        for x in [main_input] + ssl_inputs + ssl_targets:
            if x is None:
                continue
            for key in self._get_keys(self._SPATIAL_KEYS + ("center",), x):
                mask = x[key] == self.blank
                if not self.image_center and key.startswith(("coords", "center")):
                    # rotate positions around their mean instead of the origin
                    offset = x[key].mean(1).unsqueeze(1)
                else:
                    offset = 0
                x[key] = torch.einsum("abjd,aij->abid", x[key] - offset, R) + offset
                x[key][mask] = self.blank
        return main_input, ssl_inputs, ssl_targets

    def _mirror(
        self, main_input: Dict, ssl_inputs: List, ssl_targets: List
    ) -> Tuple[Dict, List, List]:
        """
        Mirror the spatial features of the input randomly along the dimensions in `mirror_dim`
        """

        # only dimensions that exist in the data can be mirrored
        mirror = [
            i
            for i in range(self.dim)
            if i in self.mirror_dim and bool(getrandbits(1))
        ]
        if len(mirror) == 0:
            return main_input, ssl_inputs, ssl_targets
        for x in [main_input] + ssl_inputs + ssl_targets:
            if x is None:
                continue
            for key in self._get_keys(self._SPATIAL_KEYS + ("center",), x):
                mask = x[key] == self.blank
                y = deepcopy(x[key])
                if not self.image_center and key.startswith(("coords", "center")):
                    # reflect positions around their mean instead of the origin
                    mean = y.mean(1).unsqueeze(1)
                    for dim in mirror:
                        y[:, :, dim, :] = 2 * mean[:, :, dim, :] - y[:, :, dim, :]
                else:
                    for dim in mirror:
                        y[:, :, dim, :] *= -1
                x[key] = y
                x[key][mask] = self.blank
        return main_input, ssl_inputs, ssl_targets

    def _shift(
        self, main_input: Dict, ssl_inputs: List, ssl_targets: List
    ) -> Tuple[Dict, List, List]:
        """
        Shift the "coords" and "center" features of the input randomly

        The shift of every sample is drawn uniformly between the values that bring its non-blank
        points to the `[-0.5, 0.5]` (x) and `[-0.5 * scale, 0.5 * scale]` (y) borders; samples
        without any non-blank points are not shifted.
        """

        keys = self._get_keys(("coords", "center"), main_input)
        if len(keys) == 0:
            return main_input, ssl_inputs, ssl_targets
        batch = main_input[keys[0]].shape[0]
        dim = main_input[keys[0]].shape[2]
        device = main_input[keys[0]].device
        coords = torch.cat(
            [main_input[key].transpose(-1, -2).reshape(batch, -1, dim) for key in keys],
            dim=1,
        )
        # a point is blank when all of its coordinates are blank; the bounds are computed per
        # sample over the non-blank points only
        valid = (coords != self.blank).any(dim=-1)
        has_valid = valid.any(dim=-1)
        inf = float("inf")
        xs = coords[:, :, 0]
        ys = coords[:, :, 1]
        min_x = -0.5 - xs.masked_fill(~valid, inf).min(dim=-1)[0]
        max_x = 0.5 - xs.masked_fill(~valid, -inf).max(dim=-1)[0]
        min_y = -0.5 * self.scale - ys.masked_fill(~valid, inf).min(dim=-1)[0]
        max_y = 0.5 * self.scale - ys.masked_fill(~valid, -inf).max(dim=-1)[0]
        del coords, xs, ys
        shift_x = min_x + torch.FloatTensor(batch).uniform_().to(device) * (
            max_x - min_x
        )
        shift_y = min_y + torch.FloatTensor(batch).uniform_().to(device) * (
            max_y - min_y
        )
        # samples with no valid points have infinite bounds; leave them in place
        shift_x = torch.where(has_valid, shift_x, torch.zeros_like(shift_x))
        shift_y = torch.where(has_valid, shift_y, torch.zeros_like(shift_y))
        shift_x = shift_x.unsqueeze(-1).unsqueeze(-1)
        shift_y = shift_y.unsqueeze(-1).unsqueeze(-1)
        for x in [main_input] + ssl_inputs + ssl_targets:
            if x is None:
                continue
            for key in self._get_keys(("coords", "center"), x):
                y = deepcopy(x[key])
                mask = y == self.blank
                y[:, :, 0, :] += shift_x
                y[:, :, 1, :] += shift_y
                x[key] = y
                x[key][mask] = self.blank
        return main_input, ssl_inputs, ssl_targets

    def _add_noise(
        self, main_input: Dict, ssl_inputs: List, ssl_targets: List
    ) -> Tuple[Dict, List, List]:
        """
        Add normal noise to all features of the input (blank values are left blank)
        """

        noisy_keys = (
            "coords",
            "speed_joints",
            "speed_direction",
            "intra_distance",
            "acc_joints",
            "angle_speeds",
            "inter_distance",
            "speed_bones",
            "acc_bones",
            "bones",
            "coord_diff",
            "center",
        )
        for x in [main_input] + ssl_inputs + ssl_targets:
            if x is None:
                continue
            for key in self._get_keys(noisy_keys, x):
                mask = x[key] == self.blank
                noise = torch.FloatTensor(x[key].shape).normal_(std=self.noise_std)
                x[key] = x[key] + noise.to(x[key].device)
                x[key][mask] = self.blank
        return main_input, ssl_inputs, ssl_targets

    def _zoom(
        self, main_input: Dict, ssl_inputs: List, ssl_targets: List
    ) -> Tuple[Dict, List, List]:
        """
        Add random zoom to all features of the input

        Magnitude-like features are multiplied by the per-sample zoom factor; "coords" features are
        scaled around the per-frame minimum corner of their non-blank points.
        """

        scaled_keys = (
            "speed_joints",
            "intra_distance",
            "acc_joints",
            "inter_distance",
            "speed_bones",
            "acc_bones",
            "bones",
            "coord_diff",
            "center",
            "speed_value",
        )
        key = list(main_input.keys())[0]
        batch = main_input[key].shape[0]
        device = main_input[key].device
        zoom = torch.FloatTensor(batch).uniform_(*self.zoom_limits).to(device)
        for x in [main_input] + ssl_inputs + ssl_targets:
            if x is None:
                continue
            for key in self._get_keys(scaled_keys, x):
                mask = x[key] == self.blank
                factor = zoom.reshape([x[key].shape[0]] + [1] * (x[key].dim() - 1))
                y = deepcopy(x[key])
                y *= factor
                x[key] = y
                x[key][mask] = self.blank
            for key in self._get_keys("coords", x):
                mask = x[key] == self.blank
                factor = zoom.reshape([x[key].shape[0]] + [1] * (x[key].dim() - 1))
                # blanks are replaced by a large sentinel so that they do not win the minimum;
                # `masked_fill` returns a new tensor, so the caller's data is left untouched
                filled = x[key].masked_fill(mask, 10)
                corner = filled.min(1, keepdim=True)[0]
                x[key] = (filled - corner) * factor + corner
                x[key][mask] = self.blank

        return main_input, ssl_inputs, ssl_targets

    def _mask_joints(
        self, main_input: Dict, ssl_inputs: List, ssl_targets: List
    ) -> Tuple[Dict, List, List]:
        """
        Mask joints randomly

        Every joint of every sample is masked with probability `masking_probability`: distances that
        involve a masked joint are set to the blank value and the other features of a masked joint
        are replaced with the mean over the second axis. Input tensors are not modified in place.
        """

        key = list(main_input.keys())[0]
        batch, *_, frames = main_input[key].shape
        n = self.n_bodyparts
        masked_joints = (
            torch.FloatTensor(batch, n).uniform_() < self.masking_probability
        )
        # masks that are True wherever the joint on the second axis is masked
        pair_mask = masked_joints[:, :, None, None].expand(batch, n, n, frames)
        joint_mask = masked_joints[:, :, None, None].expand(batch, n, self.dim, frames)
        angle_mask = masked_joints[:, :, None].expand(batch, n, frames)
        for x in [main_input] + ssl_inputs + ssl_targets:
            if x is None:
                continue
            for key in self._get_keys(("intra_distance", "inter_distance"), x):
                mask = pair_mask.to(x[key].device)
                indices = torch.triu_indices(n, n, 1)
                # unpack the upper-triangle distances into a square matrix, blank the rows and
                # columns of masked joints and pack the upper triangle again
                X = torch.zeros((batch, n, n, frames)).to(x[key].device)
                X[:, indices[0], indices[1], :] = x[key]
                X[mask] = self.blank
                X[mask.transpose(1, 2)] = self.blank
                x[key] = X[:, indices[0], indices[1], :].reshape(batch, -1, frames)
            keys = self._get_keys(
                (
                    "speed_joints",
                    "speed_direction",
                    "coords",
                    "acc_joints",
                    "speed_bones",
                    "acc_bones",
                    "bones",
                    "coord_diff",
                ),
                x,
            )
            for key in keys:
                mask = joint_mask.to(x[key].device)
                y = x[key].clone()
                fill = y.mean(1).unsqueeze(1).repeat(1, y.shape[1], 1, 1)
                y[mask] = fill[mask]
                x[key] = y
            for key in self._get_keys("angle_speeds", x):
                mask = angle_mask.to(x[key].device)
                y = x[key].clone()
                fill = y.mean(1).unsqueeze(1).repeat(1, y.shape[1], 1)
                y[mask] = fill[mask]
                x[key] = y

        return main_input, ssl_inputs, ssl_targets

    def _switch(
        self, main_input: Dict, ssl_inputs: List, ssl_targets: List
    ) -> Tuple[Dict, List, List]:
        """
        Swap the features of pairs of individuals with probability 0.5

        Individuals are read from keys of the form `"{feature}---{individual}"`; keys whose
        individual part contains `"+"` are ignored.
        """

        if bool(getrandbits(1)):
            return main_input, ssl_inputs, ssl_targets
        individuals = set()
        ind_dict = defaultdict(set)
        for key in main_input:
            parts = key.split("---")
            if len(parts) != 2:
                continue
            feature, ind = parts
            if "+" in ind:
                continue
            individuals.add(ind)
            ind_dict[ind].add(feature)
        individuals = list(individuals)
        if len(individuals) < 2:
            return main_input, ssl_inputs, ssl_targets
        for x, y in zip(individuals[::2], individuals[1::2]):
            for key in ind_dict[x]:
                if key not in ind_dict[y]:
                    continue
                key_x = f"{key}---{x}"
                key_y = f"{key}---{y}"
                main_input[key_x], main_input[key_y] = (
                    main_input[key_y],
                    main_input[key_x],
                )
                for d in ssl_inputs + ssl_targets:
                    if d is None:
                        continue
                    if key_x in d and key_y in d:
                        d[key_x], d[key_y] = d[key_y], d[key_x]
        return main_input, ssl_inputs, ssl_targets

    def _get_keys(self, key_bases: Iterable, x: Dict) -> List:
        """
        Get the keys of x whose part before `"---"` is equal to one of the strings from key_bases

        A single string is accepted in place of an iterable of strings.
        """

        if isinstance(key_bases, str):
            key_bases = [key_bases]
        return [
            key
            for key in x
            if any(base == key.split("---")[0] for base in key_bases)
        ]

    def _augmentations_dict(self) -> Dict:
        """
        Get the mapping from augmentation names to functions
        """

        return {
            "mirror": self._mirror,
            "shift": self._shift,
            "add_noise": self._add_noise,
            "zoom": self._zoom,
            "rotate": self._rotate,
            "mask": self._mask_joints,
            "switch": self._switch,
        }

    def _default_augmentations(self) -> List:
        """
        Get the list of default augmentation names
        """

        return ["mirror", "shift", "add_noise"]

    def _get_bodyparts(self, shape: Tuple) -> None:
        """
        Set the number of bodyparts from the data if it is not known

        `shape` is a three-element shape whose second element is divided by the data
        dimensionality to get the number of bodyparts.
        """

        if self.n_bodyparts is None:
            N, B, F = shape
            self.n_bodyparts = B // self.dim
class KinematicTransformer(Transformer):
    """
    A transformer that augments the output of the Kinematic feature extractor

    The available augmentations are `'rotate'`, `'mirror'`, `'shift'`, `'add_noise'`, `'zoom'`,
    `'mask'` and `'switch'`
    """

    # feature name bases that are reshaped to (batch, -1, dim, frames) while augmentations run
    _SPATIAL_KEYS = (
        "coords",
        "speed_joints",
        "speed_direction",
        "speed_bones",
        "acc_bones",
        "bones",
        "coord_diff",
    )

    def __init__(
        self,
        model_name: str,
        augmentations: List = None,
        use_default_augmentations: bool = False,
        rotation_limits: List = None,
        mirror_dim: Set = None,
        noise_std: float = 0.05,
        zoom_limits: List = None,
        masking_probability: float = 0.05,
        dim: int = 2,
        graph_features: bool = False,
        bodyparts_order: List = None,
        canvas_shape: List = None,
        move_around_image_center: bool = True,
        **kwargs,
    ) -> None:
        """
        Parameters
        ----------
        model_name : str
            the name of the model (passed to the base transformer class)
        augmentations : list, optional
            list of augmentation names to use ("rotate", "mirror", "shift", "add_noise", "zoom", "mask",
            "switch")
        use_default_augmentations : bool, default False
            if `True` and augmentations are not passed, default augmentations will be applied; otherwise
            no augmentations
        rotation_limits : list, default [-pi/2, pi/2]
            list of float rotation angle limits (`[low, high]`, or
            `[[low_x, high_x], [low_y, high_y], [low_z, high_z]]` for 3D data)
        mirror_dim : set, default [0]
            collection of integer indices of dimensions that can be mirrored
        noise_std : float, default 0.05
            standard deviation of noise
        zoom_limits : list, default [0.5, 1.5]
            list of float zoom limits ([low, high])
        masking_probability : float, default 0.05
            the probability of masking a joint
        dim : int, default 2
            the dimensionality of the input data
        graph_features : bool, default False
            passed to the base transformer class
        bodyparts_order : list, optional
            passed to the base transformer class
        canvas_shape : list | str, default [1, 1]
            the canvas size; only the ratio `canvas_shape[1] / canvas_shape[0]` is used (it bounds the
            vertical shift); a string value is evaluated as a Python expression
        move_around_image_center : bool, default True
            if `True`, "coords" and "center" features are rotated and mirrored around the origin;
            otherwise around their mean over the second axis
        **kwargs : dict
            other parameters for the base transformer class
        """

        if augmentations is None:
            augmentations = []
        if canvas_shape is None:
            canvas_shape = [1, 1]
        self.dim = int(dim)

        self.offset = [0 for _ in range(self.dim)]
        if isinstance(canvas_shape, str):
            # NOTE(review): `eval` executes arbitrary code; this is only safe while the string
            # comes from a trusted configuration source
            canvas_shape = eval(canvas_shape)
        self.scale = canvas_shape[1] / canvas_shape[0]
        self.image_center = move_around_image_center

        self.blank = 0  # the value that nan values are set to (shouldn't be changed in augmentations)
        # the attributes above are set before the base constructor runs (original order is kept)
        super().__init__(
            model_name,
            augmentations,
            use_default_augmentations,
            graph_features=graph_features,
            bodyparts_order=bodyparts_order,
            **kwargs,
        )
        if rotation_limits is None:
            rotation_limits = [-np.pi / 2, np.pi / 2]
        if mirror_dim is None:
            mirror_dim = [0]
        if zoom_limits is None:
            zoom_limits = [0.5, 1.5]
        self.rotation_limits = rotation_limits
        self.n_bodyparts = None
        self.mirror_dim = mirror_dim
        self.noise_std = noise_std
        self.zoom_limits = zoom_limits
        self.masking_probability = masking_probability

    def _reshape_features(self, dicts: List, key_bases: Iterable, get_shape) -> None:
        """
        Reshape every feature matching `key_bases` in each non-`None` dictionary of `dicts`

        `get_shape` receives the current tensor and returns the shape it should be reshaped to.
        """

        for x in dicts:
            if x is None:
                continue
            for key in self._get_keys(key_bases, x):
                x[key] = x[key].reshape(get_shape(x[key]))

    def _apply_augmentations(
        self,
        main_input: Dict,
        ssl_inputs: List = None,
        ssl_targets: List = None,
        augment: bool = False,
        ssl: bool = False,
    ) -> Tuple:
        """
        Reshape the features, run the base class augmentations and restore the original shapes

        Spatial features are reshaped to `(batch, -1, dim, frames)` (using the shape of the first
        spatial feature of the main input) and "center" features to `(batch, 1, -1, frames)` before
        the base class `_apply_augmentations` is called; the shapes are restored afterwards.
        The `ssl` argument is accepted but not used here.
        """

        if ssl_targets is None:
            ssl_targets = [None]
        if ssl_inputs is None:
            ssl_inputs = [None]

        keys_main = self._get_keys(self._SPATIAL_KEYS, main_input)
        final_shape = None
        if len(keys_main) > 0:
            reference = main_input[keys_main[0]]
            final_shape = reference.shape
            x_shape = reference.reshape(
                (final_shape[0], -1, self.dim, final_shape[-1])
            ).shape
            self.n_bodyparts = x_shape[1]
            if self.dim == 3 and len(self.rotation_limits) == 2:
                # a single [low, high] pair for 3D data is applied to the third angle only
                self.rotation_limits = [[0, 0], [0, 0], self.rotation_limits]
            self._reshape_features(
                [main_input] + ssl_inputs + ssl_targets,
                self._SPATIAL_KEYS,
                lambda tensor: x_shape,
            )
        if len(self._get_keys("center", main_input)) > 0:
            self._reshape_features(
                [main_input] + ssl_inputs + ssl_targets,
                "center",
                lambda tensor: (tensor.shape[0], 1, -1, tensor.shape[-1]),
            )
        main_input, ssl_inputs, ssl_targets = super()._apply_augmentations(
            main_input, ssl_inputs, ssl_targets, augment
        )
        if final_shape is not None:
            self._reshape_features(
                [main_input] + ssl_inputs + ssl_targets,
                self._SPATIAL_KEYS,
                lambda tensor: final_shape,
            )
        if len(self._get_keys("center", main_input)) > 0:
            self._reshape_features(
                [main_input] + ssl_inputs + ssl_targets,
                "center",
                lambda tensor: (tensor.shape[0], -1, tensor.shape[-1]),
            )
        return main_input, ssl_inputs, ssl_targets

    def _rotate(
        self, main_input: Dict, ssl_inputs: List, ssl_targets: List
    ) -> Tuple[Dict, List, List]:
        """
        Rotate the spatial features of the input by a random angle (one angle set per sample)
        """

        keys = self._get_keys(
            (
                "coords",
                "bones",
                "coord_diff",
                "center",
                "speed_joints",
                "speed_direction",
            ),
            main_input,
        )
        if len(keys) == 0:
            return main_input, ssl_inputs, ssl_targets
        batch = main_input[keys[0]].shape[0]
        device = main_input[keys[0]].device
        if self.dim == 2:
            angles = torch.FloatTensor(batch).uniform_(*self.rotation_limits)
            R = rotation_matrix_2d(angles).to(device)
        else:
            angles_x = torch.FloatTensor(batch).uniform_(*self.rotation_limits[0])
            angles_y = torch.FloatTensor(batch).uniform_(*self.rotation_limits[1])
            angles_z = torch.FloatTensor(batch).uniform_(*self.rotation_limits[2])
            R = rotation_matrix_3d(angles_x, angles_y, angles_z).to(device)
        for x in [main_input] + ssl_inputs + ssl_targets:
            if x is None:
                continue
            for key in self._get_keys(self._SPATIAL_KEYS + ("center",), x):
                mask = x[key] == self.blank
                if not self.image_center and key.startswith(("coords", "center")):
                    # rotate positions around their mean instead of the origin
                    offset = x[key].mean(1).unsqueeze(1)
                else:
                    offset = 0
                x[key] = torch.einsum("abjd,aij->abid", x[key] - offset, R) + offset
                x[key][mask] = self.blank
        return main_input, ssl_inputs, ssl_targets

    def _mirror(
        self, main_input: Dict, ssl_inputs: List, ssl_targets: List
    ) -> Tuple[Dict, List, List]:
        """
        Mirror the spatial features of the input randomly along the dimensions in `mirror_dim`
        """

        # only dimensions that exist in the data can be mirrored
        mirror = [
            i
            for i in range(self.dim)
            if i in self.mirror_dim and bool(getrandbits(1))
        ]
        if len(mirror) == 0:
            return main_input, ssl_inputs, ssl_targets
        for x in [main_input] + ssl_inputs + ssl_targets:
            if x is None:
                continue
            for key in self._get_keys(self._SPATIAL_KEYS + ("center",), x):
                mask = x[key] == self.blank
                y = deepcopy(x[key])
                if not self.image_center and key.startswith(("coords", "center")):
                    # reflect positions around their mean instead of the origin
                    mean = y.mean(1).unsqueeze(1)
                    for dim in mirror:
                        y[:, :, dim, :] = 2 * mean[:, :, dim, :] - y[:, :, dim, :]
                else:
                    for dim in mirror:
                        y[:, :, dim, :] *= -1
                x[key] = y
                x[key][mask] = self.blank
        return main_input, ssl_inputs, ssl_targets

    def _shift(
        self, main_input: Dict, ssl_inputs: List, ssl_targets: List
    ) -> Tuple[Dict, List, List]:
        """
        Shift the "coords" and "center" features of the input randomly

        The shift of every sample is drawn uniformly between the values that bring its non-blank
        points to the `[-0.5, 0.5]` (x) and `[-0.5 * scale, 0.5 * scale]` (y) borders; samples
        without any non-blank points are not shifted.
        """

        keys = self._get_keys(("coords", "center"), main_input)
        if len(keys) == 0:
            return main_input, ssl_inputs, ssl_targets
        batch = main_input[keys[0]].shape[0]
        dim = main_input[keys[0]].shape[2]
        device = main_input[keys[0]].device
        coords = torch.cat(
            [main_input[key].transpose(-1, -2).reshape(batch, -1, dim) for key in keys],
            dim=1,
        )
        # a point is blank when all of its coordinates are blank; the bounds are computed per
        # sample over the non-blank points only
        valid = (coords != self.blank).any(dim=-1)
        has_valid = valid.any(dim=-1)
        inf = float("inf")
        xs = coords[:, :, 0]
        ys = coords[:, :, 1]
        min_x = -0.5 - xs.masked_fill(~valid, inf).min(dim=-1)[0]
        max_x = 0.5 - xs.masked_fill(~valid, -inf).max(dim=-1)[0]
        min_y = -0.5 * self.scale - ys.masked_fill(~valid, inf).min(dim=-1)[0]
        max_y = 0.5 * self.scale - ys.masked_fill(~valid, -inf).max(dim=-1)[0]
        del coords, xs, ys
        shift_x = min_x + torch.FloatTensor(batch).uniform_().to(device) * (
            max_x - min_x
        )
        shift_y = min_y + torch.FloatTensor(batch).uniform_().to(device) * (
            max_y - min_y
        )
        # samples with no valid points have infinite bounds; leave them in place
        shift_x = torch.where(has_valid, shift_x, torch.zeros_like(shift_x))
        shift_y = torch.where(has_valid, shift_y, torch.zeros_like(shift_y))
        shift_x = shift_x.unsqueeze(-1).unsqueeze(-1)
        shift_y = shift_y.unsqueeze(-1).unsqueeze(-1)
        for x in [main_input] + ssl_inputs + ssl_targets:
            if x is None:
                continue
            for key in self._get_keys(("coords", "center"), x):
                y = deepcopy(x[key])
                mask = y == self.blank
                y[:, :, 0, :] += shift_x
                y[:, :, 1, :] += shift_y
                x[key] = y
                x[key][mask] = self.blank
        return main_input, ssl_inputs, ssl_targets

    def _add_noise(
        self, main_input: Dict, ssl_inputs: List, ssl_targets: List
    ) -> Tuple[Dict, List, List]:
        """
        Add normal noise to all features of the input (blank values are left blank)
        """

        noisy_keys = (
            "coords",
            "speed_joints",
            "speed_direction",
            "intra_distance",
            "acc_joints",
            "angle_speeds",
            "inter_distance",
            "speed_bones",
            "acc_bones",
            "bones",
            "coord_diff",
            "center",
        )
        for x in [main_input] + ssl_inputs + ssl_targets:
            if x is None:
                continue
            for key in self._get_keys(noisy_keys, x):
                mask = x[key] == self.blank
                noise = torch.FloatTensor(x[key].shape).normal_(std=self.noise_std)
                x[key] = x[key] + noise.to(x[key].device)
                x[key][mask] = self.blank
        return main_input, ssl_inputs, ssl_targets

    def _zoom(
        self, main_input: Dict, ssl_inputs: List, ssl_targets: List
    ) -> Tuple[Dict, List, List]:
        """
        Add random zoom to all features of the input

        Magnitude-like features are multiplied by the per-sample zoom factor; "coords" features are
        scaled around the per-frame minimum corner of their non-blank points.
        """

        scaled_keys = (
            "speed_joints",
            "intra_distance",
            "acc_joints",
            "inter_distance",
            "speed_bones",
            "acc_bones",
            "bones",
            "coord_diff",
            "center",
            "speed_value",
        )
        key = list(main_input.keys())[0]
        batch = main_input[key].shape[0]
        device = main_input[key].device
        zoom = torch.FloatTensor(batch).uniform_(*self.zoom_limits).to(device)
        for x in [main_input] + ssl_inputs + ssl_targets:
            if x is None:
                continue
            for key in self._get_keys(scaled_keys, x):
                mask = x[key] == self.blank
                factor = zoom.reshape([x[key].shape[0]] + [1] * (x[key].dim() - 1))
                y = deepcopy(x[key])
                y *= factor
                x[key] = y
                x[key][mask] = self.blank
            for key in self._get_keys("coords", x):
                mask = x[key] == self.blank
                factor = zoom.reshape([x[key].shape[0]] + [1] * (x[key].dim() - 1))
                # blanks are replaced by a large sentinel so that they do not win the minimum;
                # `masked_fill` returns a new tensor, so the caller's data is left untouched
                filled = x[key].masked_fill(mask, 10)
                corner = filled.min(1, keepdim=True)[0]
                x[key] = (filled - corner) * factor + corner
                x[key][mask] = self.blank

        return main_input, ssl_inputs, ssl_targets

    def _mask_joints(
        self, main_input: Dict, ssl_inputs: List, ssl_targets: List
    ) -> Tuple[Dict, List, List]:
        """
        Mask joints randomly

        Every joint of every sample is masked with probability `masking_probability`: distances that
        involve a masked joint are set to the blank value and the other features of a masked joint
        are replaced with the mean over the second axis. Input tensors are not modified in place.
        """

        key = list(main_input.keys())[0]
        batch, *_, frames = main_input[key].shape
        n = self.n_bodyparts
        masked_joints = (
            torch.FloatTensor(batch, n).uniform_() < self.masking_probability
        )
        # masks that are True wherever the joint on the second axis is masked
        pair_mask = masked_joints[:, :, None, None].expand(batch, n, n, frames)
        joint_mask = masked_joints[:, :, None, None].expand(batch, n, self.dim, frames)
        angle_mask = masked_joints[:, :, None].expand(batch, n, frames)
        for x in [main_input] + ssl_inputs + ssl_targets:
            if x is None:
                continue
            for key in self._get_keys(("intra_distance", "inter_distance"), x):
                mask = pair_mask.to(x[key].device)
                indices = torch.triu_indices(n, n, 1)
                # unpack the upper-triangle distances into a square matrix, blank the rows and
                # columns of masked joints and pack the upper triangle again
                X = torch.zeros((batch, n, n, frames)).to(x[key].device)
                X[:, indices[0], indices[1], :] = x[key]
                X[mask] = self.blank
                X[mask.transpose(1, 2)] = self.blank
                x[key] = X[:, indices[0], indices[1], :].reshape(batch, -1, frames)
            keys = self._get_keys(
                (
                    "speed_joints",
                    "speed_direction",
                    "coords",
                    "acc_joints",
                    "speed_bones",
                    "acc_bones",
                    "bones",
                    "coord_diff",
                ),
                x,
            )
            for key in keys:
                mask = joint_mask.to(x[key].device)
                y = x[key].clone()
                fill = y.mean(1).unsqueeze(1).repeat(1, y.shape[1], 1, 1)
                y[mask] = fill[mask]
                x[key] = y
            for key in self._get_keys("angle_speeds", x):
                mask = angle_mask.to(x[key].device)
                y = x[key].clone()
                fill = y.mean(1).unsqueeze(1).repeat(1, y.shape[1], 1)
                y[mask] = fill[mask]
                x[key] = y

        return main_input, ssl_inputs, ssl_targets

    def _switch(
        self, main_input: Dict, ssl_inputs: List, ssl_targets: List
    ) -> Tuple[Dict, List, List]:
        """
        Swap the features of pairs of individuals with probability 0.5

        Individuals are read from keys of the form `"{feature}---{individual}"`; keys whose
        individual part contains `"+"` are ignored.
        """

        if bool(getrandbits(1)):
            return main_input, ssl_inputs, ssl_targets
        individuals = set()
        ind_dict = defaultdict(set)
        for key in main_input:
            parts = key.split("---")
            if len(parts) != 2:
                continue
            feature, ind = parts
            if "+" in ind:
                continue
            individuals.add(ind)
            ind_dict[ind].add(feature)
        individuals = list(individuals)
        if len(individuals) < 2:
            return main_input, ssl_inputs, ssl_targets
        for x, y in zip(individuals[::2], individuals[1::2]):
            for key in ind_dict[x]:
                if key not in ind_dict[y]:
                    continue
                key_x = f"{key}---{x}"
                key_y = f"{key}---{y}"
                main_input[key_x], main_input[key_y] = (
                    main_input[key_y],
                    main_input[key_x],
                )
                for d in ssl_inputs + ssl_targets:
                    if d is None:
                        continue
                    if key_x in d and key_y in d:
                        d[key_x], d[key_y] = d[key_y], d[key_x]
        return main_input, ssl_inputs, ssl_targets

    def _get_keys(self, key_bases: Iterable, x: Dict) -> List:
        """
        Get the keys of x whose part before `"---"` is equal to one of the strings from key_bases

        A single string is accepted in place of an iterable of strings.
        """

        if isinstance(key_bases, str):
            key_bases = [key_bases]
        return [
            key
            for key in x
            if any(base == key.split("---")[0] for base in key_bases)
        ]

    def _augmentations_dict(self) -> Dict:
        """
        Get the mapping from augmentation names to functions
        """

        return {
            "mirror": self._mirror,
            "shift": self._shift,
            "add_noise": self._add_noise,
            "zoom": self._zoom,
            "rotate": self._rotate,
            "mask": self._mask_joints,
            "switch": self._switch,
        }

    def _default_augmentations(self) -> List:
        """
        Get the list of default augmentation names
        """

        return ["mirror", "shift", "add_noise"]

    def _get_bodyparts(self, shape: Tuple) -> None:
        """
        Set the number of bodyparts from the data if it is not known

        `shape` is a three-element shape whose second element is divided by the data
        dimensionality to get the number of bodyparts.
        """

        if self.n_bodyparts is None:
            N, B, F = shape
            self.n_bodyparts = B // self.dim
A transformer that augments the output of the Kinematic feature extractor
The available augmentations are 'rotate', 'mirror', 'shift', 'add_noise', 'zoom', 'mask' and 'switch'
def __init__(
    self,
    model_name: str,
    augmentations: List = None,
    use_default_augmentations: bool = False,
    rotation_limits: List = None,
    mirror_dim: Set = None,
    noise_std: float = 0.05,
    zoom_limits: List = None,
    masking_probability: float = 0.05,
    dim: int = 2,
    graph_features: bool = False,
    bodyparts_order: List = None,
    canvas_shape: List = None,
    move_around_image_center: bool = True,
    **kwargs,
) -> None:
    """
    Parameters
    ----------
    model_name : str
        the name of the model (passed on to the base transformer class)
    augmentations : list, optional
        list of augmentation names to use ("rotate", "mirror", "shift", "add_noise", "zoom", "mask", "switch");
        an empty list is used if nothing is passed
    use_default_augmentations : bool, default False
        if `True` and augmentations are not passed, default augmentations will be applied; otherwise no augmentations
    rotation_limits : list, default [-pi/2, pi/2]
        list of float rotation angle limits (`[low, high]`, or `[[low_x, high_x], [low_y, high_y], [low_z, high_z]]`
        for 3D data)
    mirror_dim : set, default [0]
        collection of integer indices of dimensions that can be mirrored
    noise_std : float, default 0.05
        standard deviation of noise
    zoom_limits : list, default [0.5, 1.5]
        list of float zoom limits (`[low, high]`)
    masking_probability : float, default 0.05
        the probability of masking a joint
    dim : int, default 2
        the dimensionality of the input data (converted with `int`)
    graph_features : bool, default False
        passed on to the base transformer class
    bodyparts_order : list, optional
        passed on to the base transformer class
    canvas_shape : list, default [1, 1]
        the canvas shape; `canvas_shape[1] / canvas_shape[0]` is stored as `self.scale`
        (a string value is evaluated as a Python expression first)
    move_around_image_center : bool, default True
        stored as `self.image_center`
    **kwargs : dict
        other parameters for the base transformer class
    """

    self.dim = int(dim)
    self.offset = [0] * self.dim

    if canvas_shape is None:
        canvas_shape = [1, 1]
    elif isinstance(canvas_shape, str):
        # NOTE(review): eval executes arbitrary code; only safe while the string
        # comes from a trusted configuration -- consider ast.literal_eval
        canvas_shape = eval(canvas_shape)
    self.scale = canvas_shape[1] / canvas_shape[0]
    self.image_center = move_around_image_center

    # the value that nan values are set to (shouldn't be changed in augmentations)
    self.blank = 0

    super().__init__(
        model_name,
        [] if augmentations is None else augmentations,
        use_default_augmentations,
        graph_features=graph_features,
        bodyparts_order=bodyparts_order,
        **kwargs,
    )

    # fill in the defaults of the augmentation parameters
    self.rotation_limits = (
        [-np.pi / 2, np.pi / 2] if rotation_limits is None else rotation_limits
    )
    self.n_bodyparts = None  # unknown until it is derived from the data
    self.mirror_dim = [0] if mirror_dim is None else mirror_dim
    self.noise_std = noise_std
    self.zoom_limits = [0.5, 1.5] if zoom_limits is None else zoom_limits
    self.masking_probability = masking_probability
Parameters
augmentations : list, optional
list of augmentation names to use ("rotate", "mirror", "shift", "add_noise", "zoom", "mask", "switch")
use_default_augmentations : bool, default False
if True and augmentations are not passed, default augmentations will be applied; otherwise no augmentations
rotation_limits : list, default [-pi/2, pi/2]
list of float rotation angle limits (`[low, high]`, or `[[low_x, high_x], [low_y, high_y], [low_z, high_z]]`
for 3D data)
mirror_dim : set, default {0}
set of integer indices of dimensions that can be mirrored
noise_std : float, default 0.05
standard deviation of noise
zoom_limits : list, default [0.5, 1.5]
list of float zoom limits ([low, high])
masking_probability : float, default 0.05
the probability of masking a joint
dim : int, default 2
the dimensionality of the input data
**kwargs : dict
other parameters for the base transformer class