dlc2action.transformer.kinematic

Kinematic transformer

  1#
  2# Copyright 2020-2022 by A. Mathis Group and contributors. All rights reserved.
  3#
  4# This project and all its files are licensed under GNU AGPLv3 or later version. A copy is included in dlc2action/LICENSE.AGPL.
  5#
  6"""
  7Kinematic transformer
  8"""
  9
 10from typing import Dict, Tuple, List, Set, Iterable
 11import torch
 12import numpy as np
 13from copy import copy, deepcopy
 14from random import getrandbits
 15from collections import defaultdict
 16from dlc2action.transformer.base_transformer import Transformer
 17from dlc2action.utils import rotation_matrix_2d, rotation_matrix_3d
 18
 19
 20class KinematicTransformer(Transformer):
 21    """
 22    A transformer that augments the output of the Kinematic feature extractor
 23
 24    The available augmentations are `'rotate'`, `'mirror'`, `'shift'`, `'add_noise'` and `'zoom'`
 25    """
 26
 27    def __init__(
 28        self,
 29        model_name: str,
 30        augmentations: List = None,
 31        use_default_augmentations: bool = False,
 32        rotation_limits: List = None,
 33        mirror_dim: Set = None,
 34        noise_std: float = 0.05,
 35        zoom_limits: List = None,
 36        masking_probability: float = 0.05,
 37        dim: int = 2,
 38        graph_features: bool = False,
 39        bodyparts_order: List = None,
 40        canvas_shape: List = None,
 41        move_around_image_center: bool = True,
 42        **kwargs,
 43    ) -> None:
 44        """
 45        Parameters
 46        ----------
 47        augmentations : list, optional
 48            list of augmentation names to use ("rotate", "mirror", "shift", "add_noise", "zoom")
 49        use_default_augmentations : bool, default False
 50            if `True` and augmentations are not passed, default augmentations will be applied; otherwise no augmentations
 51        rotation_limits : list, default [-pi/2, pi/2]
 52            list of float rotation angle limits (`[low, high]``, or `[[low_x, high_x], [low_y, high_y], [low_z, high_z]]`
 53            for 3D data)
 54        mirror_dim : set, default {0}
 55            set of integer indices of dimensions that can be mirrored
 56        noise_std : float, default 0.05
 57            standard deviation of noise
 58        zoom_limits : list, default [0.5, 1.5]
 59            list of float zoom limits ([low, high])
 60        masking_probability : float, default 0.1
 61            the probability of masking a joint
 62        dim : int, default 2
 63            the dimensionality of the input data
 64        **kwargs : dict
 65            other parameters for the base transformer class
 66        """
 67
 68        if augmentations is None:
 69            augmentations = []
 70        if canvas_shape is None:
 71            canvas_shape = [1, 1]
 72        self.dim = int(dim)
 73
 74        self.offset = [0 for _ in range(self.dim)]
 75        self.scale = canvas_shape[1] / canvas_shape[0]
 76        self.image_center = move_around_image_center
 77        # if canvas_shape is None:
 78        #     self.offset = [0.5 for _ in range(self.dim)]
 79        # else:
 80        #     self.offset = [0.5 * canvas_shape[i] / canvas_shape[0] for i in range(self.dim)]
 81
 82        self.blank = 0  # the value that nan values are set to (shouldn't be changed in augmentations)
 83        super().__init__(
 84            model_name,
 85            augmentations,
 86            use_default_augmentations,
 87            graph_features=graph_features,
 88            bodyparts_order=bodyparts_order,
 89            **kwargs,
 90        )
 91        if rotation_limits is None:
 92            rotation_limits = [-np.pi / 2, np.pi / 2]
 93        if mirror_dim is None:
 94            mirror_dim = [0]
 95        if zoom_limits is None:
 96            zoom_limits = [0.5, 1.5]
 97        self.rotation_limits = rotation_limits
 98        self.n_bodyparts = None
 99        self.mirror_dim = mirror_dim
100        self.noise_std = noise_std
101        self.zoom_limits = zoom_limits
102        self.masking_probability = masking_probability
103
104    def _apply_augmentations(
105        self,
106        main_input: Dict,
107        ssl_inputs: List = None,
108        ssl_targets: List = None,
109        augment: bool = False,
110        ssl: bool = False,
111    ) -> Tuple:
112
113        if ssl_targets is None:
114            ssl_targets = [None]
115        if ssl_inputs is None:
116            ssl_inputs = [None]
117
118        keys_main = self._get_keys(
119            (
120                "coords",
121                "speed_joints",
122                "speed_direction",
123                "speed_bones",
124                "acc_bones",
125                "bones",
126                "coord_diff",
127            ),
128            main_input,
129        )
130        if len(keys_main) > 0:
131            key = keys_main[0]
132            final_shape = main_input[key].shape
133            # self._get_bodyparts(final_shape)
134            batch = final_shape[0]
135            s = main_input[key].reshape((batch, -1, self.dim, final_shape[-1]))
136            x_shape = s.shape
137            # print(f'{x_shape=}, {self.dim=}')
138            self.n_bodyparts = x_shape[1]
139            if self.dim == 3:
140                if len(self.rotation_limits) == 2:
141                    self.rotation_limits = [[0, 0], [0, 0], self.rotation_limits]
142            dicts = [main_input] + ssl_inputs + ssl_targets
143            for x in dicts:
144                if x is not None:
145                    keys = self._get_keys(
146                        (
147                            "coords",
148                            "speed_joints",
149                            "speed_direction",
150                            "speed_bones",
151                            "acc_bones",
152                            "bones",
153                            "coord_diff",
154                        ),
155                        x,
156                    )
157                    for key in keys:
158                        x[key] = x[key].reshape(x_shape)
159        if len(self._get_keys(("center"), main_input)) > 0:
160            dicts = [main_input] + ssl_inputs + ssl_targets
161            for x in dicts:
162                if x is not None:
163                    key_bases = ["center"]
164                    keys = self._get_keys(
165                        key_bases,
166                        x,
167                    )
168                    for key in keys:
169                        x[key] = x[key].reshape(
170                            (x[key].shape[0], 1, -1, x[key].shape[-1])
171                        )
172        main_input, ssl_inputs, ssl_targets = super()._apply_augmentations(
173            main_input, ssl_inputs, ssl_targets, augment
174        )
175        if len(keys_main) > 0:
176            dicts = [main_input] + ssl_inputs + ssl_targets
177            for x in dicts:
178                if x is not None:
179                    keys = self._get_keys(
180                        (
181                            "coords",
182                            "speed_joints",
183                            "speed_direction",
184                            "speed_bones",
185                            "acc_bones",
186                            "bones",
187                            "coord_diff",
188                        ),
189                        x,
190                    )
191                    for key in keys:
192                        x[key] = x[key].reshape(final_shape)
193        if len(self._get_keys(("center"), main_input)) > 0:
194            dicts = [main_input] + ssl_inputs + ssl_targets
195            for x in dicts:
196                if x is not None:
197                    key_bases = ["center"]
198                    keys = self._get_keys(
199                        key_bases,
200                        x,
201                    )
202                    for key in keys:
203                        x[key] = x[key].reshape((x[key].shape[0], -1, x[key].shape[-1]))
204        return main_input, ssl_inputs, ssl_targets
205
206    def _rotate(
207        self, main_input: Dict, ssl_inputs: List, ssl_targets: List
208    ) -> Tuple[Dict, List, List]:
209        """
210        Rotate the "coords" and "speed_joints" features of the input to a random degree
211        """
212
213        keys = self._get_keys(
214            (
215                "coords",
216                "bones",
217                "coord_diff",
218                "center",
219                "speed_joints",
220                "speed_direction",
221                "center",
222            ),
223            main_input,
224        )
225        if len(keys) == 0:
226            return main_input, ssl_inputs, ssl_targets
227        batch = main_input[keys[0]].shape[0]
228        if self.dim == 2:
229            angles = torch.FloatTensor(batch).uniform_(*self.rotation_limits)
230            R = rotation_matrix_2d(angles).to(main_input[keys[0]].device)
231        else:
232            angles_x = torch.FloatTensor(batch).uniform_(*self.rotation_limits[0])
233            angles_y = torch.FloatTensor(batch).uniform_(*self.rotation_limits[1])
234            angles_z = torch.FloatTensor(batch).uniform_(*self.rotation_limits[2])
235            R = rotation_matrix_3d(angles_x, angles_y, angles_z).to(
236                main_input[keys[0]].device
237            )
238        dicts = [main_input] + ssl_inputs + ssl_targets
239        for x in dicts:
240            if x is not None:
241                keys = self._get_keys(
242                    (
243                        "coords",
244                        "speed_joints",
245                        "speed_direction",
246                        "speed_bones",
247                        "acc_bones",
248                        "bones",
249                        "coord_diff",
250                        "center",
251                    ),
252                    x,
253                )
254                for key in keys:
255                    if key in x:
256                        mask = x[key] == self.blank
257                        if (
258                            any([key.startswith(x) for x in ["coords", "center"]])
259                            and not self.image_center
260                        ):
261                            offset = x[key].mean(1).unsqueeze(1)
262                        else:
263                            offset = 0
264                        x[key] = (
265                            torch.einsum(
266                                "abjd,aij->abid",
267                                x[key] - offset,
268                                R,
269                            )
270                            + offset
271                        )
272                        x[key][mask] = self.blank
273        return main_input, ssl_inputs, ssl_targets
274
275    def _mirror(
276        self, main_input: Dict, ssl_inputs: List, ssl_targets: List
277    ) -> Tuple[Dict, List, List]:
278        """
279        Mirror the "coords" and "speed_joints" features of the input randomly
280        """
281
282        mirror = []
283        for i in range(3):
284            if i in self.mirror_dim and bool(getrandbits(1)):
285                mirror.append(i)
286        if len(mirror) == 0:
287            return main_input, ssl_inputs, ssl_targets
288        dicts = [main_input] + ssl_inputs + ssl_targets
289        for x in dicts:
290            if x is not None:
291                keys = self._get_keys(
292                    (
293                        "coords",
294                        "speed_joints",
295                        "speed_direction",
296                        "speed_bones",
297                        "acc_bones",
298                        "bones",
299                        "center",
300                        "coord_diff",
301                    ),
302                    x,
303                )
304                for key in keys:
305                    if key in x:
306                        mask = x[key] == self.blank
307                        y = deepcopy(x[key])
308                        if not self.image_center and any(
309                            [key.startswith(x) for x in ["coords", "center"]]
310                        ):
311                            mean = y.mean(1).unsqueeze(1)
312                            for dim in mirror:
313                                y[:, :, dim, :] = (
314                                    2 * mean[:, :, dim, :] - y[:, :, dim, :]
315                                )
316                        else:
317                            for dim in mirror:
318                                y[:, :, dim, :] *= -1
319                        x[key] = y
320                        x[key][mask] = self.blank
321        return main_input, ssl_inputs, ssl_targets
322
323    def _shift(
324        self, main_input: Dict, ssl_inputs: List, ssl_targets: List
325    ) -> Tuple[Dict, List, List]:
326        """
327        Shift the "coords" features of the input randomly
328        """
329
330        keys = self._get_keys(("coords", "center"), main_input)
331        if len(keys) == 0:
332            return main_input, ssl_inputs, ssl_targets
333        batch = main_input[keys[0]].shape[0]
334        dim = main_input[keys[0]].shape[2]
335        device = main_input[keys[0]].device
336        coords = torch.cat(
337            [main_input[key].transpose(-1, -2).reshape(batch, -1, dim) for key in keys],
338            dim=1,
339        )
340        minimask = coords[:, :, 0] != self.blank
341        min_x = -0.5 - torch.min(coords[:, :, 0], dim=-1)[0]
342        min_y = -0.5 * self.scale - torch.min(coords[:, :, 1], dim=-1)[0]
343        max_x = 0.5 - torch.max(coords[:, :, 0][minimask], dim=-1)[0]
344        max_y = 0.5 * self.scale - torch.max(coords[:, :, 1][minimask], dim=-1)[0]
345        del coords
346        shift_x = min_x + torch.FloatTensor(batch).uniform_().to(device) * (
347            max_x - min_x
348        )
349        shift_y = min_y + torch.FloatTensor(batch).uniform_().to(device) * (
350            max_y - min_y
351        )
352        shift_x = shift_x.unsqueeze(-1).unsqueeze(-1)
353        shift_y = shift_y.unsqueeze(-1).unsqueeze(-1)
354        dicts = [main_input] + ssl_inputs + ssl_targets
355        for x in dicts:
356            if x is not None:
357                keys = self._get_keys(("coords", "center"), x)
358                for key in keys:
359                    y = deepcopy(x[key])
360                    mask = y == self.blank
361                    y[:, :, 0, :] += shift_x
362                    y[:, :, 1, :] += shift_y
363                    x[key] = y
364                    x[key][mask] = self.blank
365        return main_input, ssl_inputs, ssl_targets
366
367    def _add_noise(
368        self, main_input: Dict, ssl_inputs: List, ssl_targets: List
369    ) -> Tuple[Dict, List, List]:
370        """
371        Add normal noise to all features of the input
372        """
373
374        dicts = [main_input] + ssl_inputs + ssl_targets
375        for x in dicts:
376            if x is not None:
377                keys = self._get_keys(
378                    (
379                        "coords",
380                        "speed_joints",
381                        "speed_direction",
382                        "intra_distance",
383                        "acc_joints",
384                        "angle_speeds",
385                        "inter_distance",
386                        "speed_bones",
387                        "acc_bones",
388                        "bones",
389                        "coord_diff",
390                        "center",
391                    ),
392                    x,
393                )
394                for key in keys:
395                    mask = x[key] == self.blank
396                    x[key] = x[key] + torch.FloatTensor(x[key].shape).normal_(
397                        std=self.noise_std
398                    ).to(x[key].device)
399                    x[key][mask] = self.blank
400        return main_input, ssl_inputs, ssl_targets
401
402    def _zoom(
403        self, main_input: Dict, ssl_inputs: List, ssl_targets: List
404    ) -> Tuple[Dict, List, List]:
405        """
406        Add random zoom to all features of the input
407        """
408
409        key = list(main_input.keys())[0]
410        batch = main_input[key].shape[0]
411        device = main_input[key].device
412        zoom = torch.FloatTensor(batch).uniform_(*self.zoom_limits).to(device)
413        dicts = [main_input] + ssl_inputs + ssl_targets
414        for x in dicts:
415            if x is not None:
416                keys = self._get_keys(
417                    (
418                        "speed_joints",
419                        "intra_distance",
420                        "acc_joints",
421                        "inter_distance",
422                        "speed_bones",
423                        "acc_bones",
424                        "bones",
425                        "coord_diff",
426                        "center",
427                        "speed_value",
428                    ),
429                    x,
430                )
431                for key in keys:
432                    mask = x[key] == self.blank
433                    shape = np.array(x[key].shape)
434                    shape[1:] = 1
435                    y = deepcopy(x[key])
436                    y *= zoom.reshape(list(shape))
437                    x[key] = y
438                    x[key][mask] = self.blank
439                keys = self._get_keys("coords", x)
440                for key in keys:
441                    mask = x[key] == self.blank
442                    shape = np.array(x[key].shape)
443                    shape[1:] = 1
444                    zoom = zoom.reshape(list(shape))
445                    x[key][mask] = 10
446                    min_x = x[key][:, :, 0, :].min(1)[0].unsqueeze(1)
447                    min_y = x[key][:, :, 1, :].min(1)[0].unsqueeze(1)
448                    center = torch.stack([min_x, min_y], dim=2)
449                    coords = (x[key] - center) * zoom + center
450                    x[key] = coords
451                    x[key][mask] = self.blank
452
453        return main_input, ssl_inputs, ssl_targets
454
    def _mask_joints(
        self, main_input: Dict, ssl_inputs: List, ssl_targets: List
    ) -> Tuple[Dict, List, List]:
        """
        Mask joints randomly

        A boolean `(batch, self.n_bodyparts)` selection is drawn once (each joint is selected with
        probability `self.masking_probability`) and applied to `main_input` and to every SSL dictionary:
        distance features of selected joints are set to `self.blank`, per-joint features and
        "angle_speeds" of selected joints are replaced with the mean over the second axis.

        Parameters
        ----------
        main_input : dict
            the main input feature dictionary (its first entry defines the batch size and the number
            of frames)
        ssl_inputs : list
            a list of SSL input dictionaries (or `None` entries)
        ssl_targets : list
            a list of SSL target dictionaries (or `None` entries)

        Returns
        -------
        main_input : dict
            the masked main input
        ssl_inputs : list
            the masked SSL inputs
        ssl_targets : list
            the masked SSL targets
        """

        # batch size is the first and the number of frames the last axis of the first feature
        key = list(main_input.keys())[0]
        batch, *_, frames = main_input[key].shape
        # True where a joint is selected for masking; relies on self.n_bodyparts being set beforehand
        masked_joints = (
            torch.FloatTensor(batch, self.n_bodyparts).uniform_()
            < self.masking_probability
        )
        dicts = [main_input] + ssl_inputs + ssl_targets
        for x in [y for y in dicts if y is not None]:
            keys = self._get_keys(("intra_distance", "inter_distance"), x)
            for key in keys:
                # (batch, n_bodyparts, n_bodyparts, frames) mask with mask[b, i, j, f] = masked_joints[b, i]
                mask = (
                    masked_joints.repeat(self.n_bodyparts, frames, 1, 1)
                    .transpose(0, 2)
                    .transpose(1, 3)
                )
                # index pairs of the strict upper triangle (diagonal excluded)
                indices = torch.triu_indices(self.n_bodyparts, self.n_bodyparts, 1)

                # scatter the feature into a dense pairwise matrix
                # (assumes x[key] holds one row per upper-triangle joint pair -- TODO confirm)
                X = torch.zeros((batch, self.n_bodyparts, self.n_bodyparts, frames)).to(
                    x[key].device
                )
                X[:, indices[0], indices[1], :] = x[key]
                # blank out every pair that involves a selected joint (as row, then as column)
                X[mask] = self.blank
                X[mask.transpose(1, 2)] = self.blank
                # gather the upper triangle back into the (batch, n_pairs, frames) layout
                x[key] = X[:, indices[0], indices[1], :].reshape(batch, -1, frames)
            keys = self._get_keys(
                (
                    "speed_joints",
                    "speed_direction",
                    "coords",
                    "acc_joints",
                    "speed_bones",
                    "acc_bones",
                    "bones",
                    "coord_diff",
                ),
                x,
            )
            for key in keys:
                # (batch, n_bodyparts, self.dim, frames) mask with mask[b, j, d, f] = masked_joints[b, j]
                mask = (
                    masked_joints.repeat(self.dim, frames, 1, 1)
                    .transpose(0, 2)
                    .transpose(1, 3)
                )
                # selected joints get the mean over axis 1 (computed over all joints, selected ones
                # included); the tensor stored in x[key] is modified in place
                x[key][mask] = (
                    x[key].mean(1).unsqueeze(1).repeat(1, x[key].shape[1], 1, 1)[mask]
                )
            keys = self._get_keys("angle_speeds", x)
            for key in keys:
                # (batch, n_bodyparts, frames) mask with mask[b, j, f] = masked_joints[b, j]
                mask = (
                    masked_joints.repeat(frames, 1, 1).transpose(0, 1).transpose(1, 2)
                )
                # same mean replacement as above, for 3D features; also in place
                x[key][mask] = (
                    x[key].mean(1).unsqueeze(1).repeat(1, x[key].shape[1], 1)[mask]
                )

        return main_input, ssl_inputs, ssl_targets
518
519    def _switch(
520        self, main_input: Dict, ssl_inputs: List, ssl_targets: List
521    ) -> Tuple[Dict, List, List]:
522        if bool(getrandbits(1)):
523            return main_input, ssl_inputs, ssl_targets
524        individuals = set()
525        ind_dict = defaultdict(lambda: set())
526        for key in main_input:
527            if len(key.split("---")) != 2:
528                continue
529            ind = key.split("---")[1]
530            if "+" in ind:
531                continue
532            individuals.add(ind)
533            ind_dict[ind].add(key.split("---")[0])
534        individuals = list(individuals)
535        if len(individuals) < 2:
536            return main_input, ssl_inputs, ssl_targets
537        for x, y in zip(individuals[::2], individuals[1::2]):
538            for key in ind_dict[x]:
539                if key in ind_dict[y]:
540                    main_input[f"{key}---{x}"], main_input[f"{key}---{y}"] = (
541                        main_input[f"{key}---{y}"],
542                        main_input[f"{key}---{x}"],
543                    )
544                for d in ssl_inputs + ssl_targets:
545                    if d is None:
546                        continue
547                    if f"{key}---{x}" in d and f"{key}---{y}" in d:
548                        d[f"{key}---{x}"], d[f"{key}---{y}"] = (
549                            d[f"{key}---{y}"],
550                            d[f"{key}---{x}"],
551                        )
552        return main_input, ssl_inputs, ssl_targets
553
554    def _get_keys(self, key_bases: Iterable, x: Dict) -> List:
555        """
556        Get the keys of x that start with one of the strings from key_bases
557        """
558
559        keys = []
560        if isinstance(key_bases, str):
561            key_bases = [key_bases]
562        for key in x:
563            if any([x == key.split("---")[0] for x in key_bases]):
564                keys.append(key)
565        return keys
566
567    def _augmentations_dict(self) -> Dict:
568        """
569        Get the mapping from augmentation names to functions
570        """
571
572        return {
573            "mirror": self._mirror,
574            "shift": self._shift,
575            "add_noise": self._add_noise,
576            "zoom": self._zoom,
577            "rotate": self._rotate,
578            "mask": self._mask_joints,
579            "switch": self._switch,
580        }
581
582    def _default_augmentations(self) -> List:
583        """
584        Get the list of default augmentation names
585        """
586
587        return ["mirror", "shift", "add_noise"]
588
589    def _get_bodyparts(self, shape: Tuple) -> None:
590        """
591        Set the number of bodyparts from the data if it is not known
592        """
593
594        if self.n_bodyparts is None:
595            N, B, F = shape
596            self.n_bodyparts = B // 2
class KinematicTransformer(dlc2action.transformer.base_transformer.Transformer):
 21class KinematicTransformer(Transformer):
 22    """
 23    A transformer that augments the output of the Kinematic feature extractor
 24
 25    The available augmentations are `'rotate'`, `'mirror'`, `'shift'`, `'add_noise'` and `'zoom'`
 26    """
 27
 28    def __init__(
 29        self,
 30        model_name: str,
 31        augmentations: List = None,
 32        use_default_augmentations: bool = False,
 33        rotation_limits: List = None,
 34        mirror_dim: Set = None,
 35        noise_std: float = 0.05,
 36        zoom_limits: List = None,
 37        masking_probability: float = 0.05,
 38        dim: int = 2,
 39        graph_features: bool = False,
 40        bodyparts_order: List = None,
 41        canvas_shape: List = None,
 42        move_around_image_center: bool = True,
 43        **kwargs,
 44    ) -> None:
 45        """
 46        Parameters
 47        ----------
 48        augmentations : list, optional
 49            list of augmentation names to use ("rotate", "mirror", "shift", "add_noise", "zoom")
 50        use_default_augmentations : bool, default False
 51            if `True` and augmentations are not passed, default augmentations will be applied; otherwise no augmentations
 52        rotation_limits : list, default [-pi/2, pi/2]
 53            list of float rotation angle limits (`[low, high]``, or `[[low_x, high_x], [low_y, high_y], [low_z, high_z]]`
 54            for 3D data)
 55        mirror_dim : set, default {0}
 56            set of integer indices of dimensions that can be mirrored
 57        noise_std : float, default 0.05
 58            standard deviation of noise
 59        zoom_limits : list, default [0.5, 1.5]
 60            list of float zoom limits ([low, high])
 61        masking_probability : float, default 0.1
 62            the probability of masking a joint
 63        dim : int, default 2
 64            the dimensionality of the input data
 65        **kwargs : dict
 66            other parameters for the base transformer class
 67        """
 68
 69        if augmentations is None:
 70            augmentations = []
 71        if canvas_shape is None:
 72            canvas_shape = [1, 1]
 73        self.dim = int(dim)
 74
 75        self.offset = [0 for _ in range(self.dim)]
 76        self.scale = canvas_shape[1] / canvas_shape[0]
 77        self.image_center = move_around_image_center
 78        # if canvas_shape is None:
 79        #     self.offset = [0.5 for _ in range(self.dim)]
 80        # else:
 81        #     self.offset = [0.5 * canvas_shape[i] / canvas_shape[0] for i in range(self.dim)]
 82
 83        self.blank = 0  # the value that nan values are set to (shouldn't be changed in augmentations)
 84        super().__init__(
 85            model_name,
 86            augmentations,
 87            use_default_augmentations,
 88            graph_features=graph_features,
 89            bodyparts_order=bodyparts_order,
 90            **kwargs,
 91        )
 92        if rotation_limits is None:
 93            rotation_limits = [-np.pi / 2, np.pi / 2]
 94        if mirror_dim is None:
 95            mirror_dim = [0]
 96        if zoom_limits is None:
 97            zoom_limits = [0.5, 1.5]
 98        self.rotation_limits = rotation_limits
 99        self.n_bodyparts = None
100        self.mirror_dim = mirror_dim
101        self.noise_std = noise_std
102        self.zoom_limits = zoom_limits
103        self.masking_probability = masking_probability
104
105    def _apply_augmentations(
106        self,
107        main_input: Dict,
108        ssl_inputs: List = None,
109        ssl_targets: List = None,
110        augment: bool = False,
111        ssl: bool = False,
112    ) -> Tuple:
113
114        if ssl_targets is None:
115            ssl_targets = [None]
116        if ssl_inputs is None:
117            ssl_inputs = [None]
118
119        keys_main = self._get_keys(
120            (
121                "coords",
122                "speed_joints",
123                "speed_direction",
124                "speed_bones",
125                "acc_bones",
126                "bones",
127                "coord_diff",
128            ),
129            main_input,
130        )
131        if len(keys_main) > 0:
132            key = keys_main[0]
133            final_shape = main_input[key].shape
134            # self._get_bodyparts(final_shape)
135            batch = final_shape[0]
136            s = main_input[key].reshape((batch, -1, self.dim, final_shape[-1]))
137            x_shape = s.shape
138            # print(f'{x_shape=}, {self.dim=}')
139            self.n_bodyparts = x_shape[1]
140            if self.dim == 3:
141                if len(self.rotation_limits) == 2:
142                    self.rotation_limits = [[0, 0], [0, 0], self.rotation_limits]
143            dicts = [main_input] + ssl_inputs + ssl_targets
144            for x in dicts:
145                if x is not None:
146                    keys = self._get_keys(
147                        (
148                            "coords",
149                            "speed_joints",
150                            "speed_direction",
151                            "speed_bones",
152                            "acc_bones",
153                            "bones",
154                            "coord_diff",
155                        ),
156                        x,
157                    )
158                    for key in keys:
159                        x[key] = x[key].reshape(x_shape)
160        if len(self._get_keys(("center"), main_input)) > 0:
161            dicts = [main_input] + ssl_inputs + ssl_targets
162            for x in dicts:
163                if x is not None:
164                    key_bases = ["center"]
165                    keys = self._get_keys(
166                        key_bases,
167                        x,
168                    )
169                    for key in keys:
170                        x[key] = x[key].reshape(
171                            (x[key].shape[0], 1, -1, x[key].shape[-1])
172                        )
173        main_input, ssl_inputs, ssl_targets = super()._apply_augmentations(
174            main_input, ssl_inputs, ssl_targets, augment
175        )
176        if len(keys_main) > 0:
177            dicts = [main_input] + ssl_inputs + ssl_targets
178            for x in dicts:
179                if x is not None:
180                    keys = self._get_keys(
181                        (
182                            "coords",
183                            "speed_joints",
184                            "speed_direction",
185                            "speed_bones",
186                            "acc_bones",
187                            "bones",
188                            "coord_diff",
189                        ),
190                        x,
191                    )
192                    for key in keys:
193                        x[key] = x[key].reshape(final_shape)
194        if len(self._get_keys(("center"), main_input)) > 0:
195            dicts = [main_input] + ssl_inputs + ssl_targets
196            for x in dicts:
197                if x is not None:
198                    key_bases = ["center"]
199                    keys = self._get_keys(
200                        key_bases,
201                        x,
202                    )
203                    for key in keys:
204                        x[key] = x[key].reshape((x[key].shape[0], -1, x[key].shape[-1]))
205        return main_input, ssl_inputs, ssl_targets
206
207    def _rotate(
208        self, main_input: Dict, ssl_inputs: List, ssl_targets: List
209    ) -> Tuple[Dict, List, List]:
210        """
211        Rotate the "coords" and "speed_joints" features of the input to a random degree
212        """
213
214        keys = self._get_keys(
215            (
216                "coords",
217                "bones",
218                "coord_diff",
219                "center",
220                "speed_joints",
221                "speed_direction",
222                "center",
223            ),
224            main_input,
225        )
226        if len(keys) == 0:
227            return main_input, ssl_inputs, ssl_targets
228        batch = main_input[keys[0]].shape[0]
229        if self.dim == 2:
230            angles = torch.FloatTensor(batch).uniform_(*self.rotation_limits)
231            R = rotation_matrix_2d(angles).to(main_input[keys[0]].device)
232        else:
233            angles_x = torch.FloatTensor(batch).uniform_(*self.rotation_limits[0])
234            angles_y = torch.FloatTensor(batch).uniform_(*self.rotation_limits[1])
235            angles_z = torch.FloatTensor(batch).uniform_(*self.rotation_limits[2])
236            R = rotation_matrix_3d(angles_x, angles_y, angles_z).to(
237                main_input[keys[0]].device
238            )
239        dicts = [main_input] + ssl_inputs + ssl_targets
240        for x in dicts:
241            if x is not None:
242                keys = self._get_keys(
243                    (
244                        "coords",
245                        "speed_joints",
246                        "speed_direction",
247                        "speed_bones",
248                        "acc_bones",
249                        "bones",
250                        "coord_diff",
251                        "center",
252                    ),
253                    x,
254                )
255                for key in keys:
256                    if key in x:
257                        mask = x[key] == self.blank
258                        if (
259                            any([key.startswith(x) for x in ["coords", "center"]])
260                            and not self.image_center
261                        ):
262                            offset = x[key].mean(1).unsqueeze(1)
263                        else:
264                            offset = 0
265                        x[key] = (
266                            torch.einsum(
267                                "abjd,aij->abid",
268                                x[key] - offset,
269                                R,
270                            )
271                            + offset
272                        )
273                        x[key][mask] = self.blank
274        return main_input, ssl_inputs, ssl_targets
275
276    def _mirror(
277        self, main_input: Dict, ssl_inputs: List, ssl_targets: List
278    ) -> Tuple[Dict, List, List]:
279        """
280        Mirror the "coords" and "speed_joints" features of the input randomly
281        """
282
283        mirror = []
284        for i in range(3):
285            if i in self.mirror_dim and bool(getrandbits(1)):
286                mirror.append(i)
287        if len(mirror) == 0:
288            return main_input, ssl_inputs, ssl_targets
289        dicts = [main_input] + ssl_inputs + ssl_targets
290        for x in dicts:
291            if x is not None:
292                keys = self._get_keys(
293                    (
294                        "coords",
295                        "speed_joints",
296                        "speed_direction",
297                        "speed_bones",
298                        "acc_bones",
299                        "bones",
300                        "center",
301                        "coord_diff",
302                    ),
303                    x,
304                )
305                for key in keys:
306                    if key in x:
307                        mask = x[key] == self.blank
308                        y = deepcopy(x[key])
309                        if not self.image_center and any(
310                            [key.startswith(x) for x in ["coords", "center"]]
311                        ):
312                            mean = y.mean(1).unsqueeze(1)
313                            for dim in mirror:
314                                y[:, :, dim, :] = (
315                                    2 * mean[:, :, dim, :] - y[:, :, dim, :]
316                                )
317                        else:
318                            for dim in mirror:
319                                y[:, :, dim, :] *= -1
320                        x[key] = y
321                        x[key][mask] = self.blank
322        return main_input, ssl_inputs, ssl_targets
323
324    def _shift(
325        self, main_input: Dict, ssl_inputs: List, ssl_targets: List
326    ) -> Tuple[Dict, List, List]:
327        """
328        Shift the "coords" features of the input randomly
329        """
330
331        keys = self._get_keys(("coords", "center"), main_input)
332        if len(keys) == 0:
333            return main_input, ssl_inputs, ssl_targets
334        batch = main_input[keys[0]].shape[0]
335        dim = main_input[keys[0]].shape[2]
336        device = main_input[keys[0]].device
337        coords = torch.cat(
338            [main_input[key].transpose(-1, -2).reshape(batch, -1, dim) for key in keys],
339            dim=1,
340        )
341        minimask = coords[:, :, 0] != self.blank
342        min_x = -0.5 - torch.min(coords[:, :, 0], dim=-1)[0]
343        min_y = -0.5 * self.scale - torch.min(coords[:, :, 1], dim=-1)[0]
344        max_x = 0.5 - torch.max(coords[:, :, 0][minimask], dim=-1)[0]
345        max_y = 0.5 * self.scale - torch.max(coords[:, :, 1][minimask], dim=-1)[0]
346        del coords
347        shift_x = min_x + torch.FloatTensor(batch).uniform_().to(device) * (
348            max_x - min_x
349        )
350        shift_y = min_y + torch.FloatTensor(batch).uniform_().to(device) * (
351            max_y - min_y
352        )
353        shift_x = shift_x.unsqueeze(-1).unsqueeze(-1)
354        shift_y = shift_y.unsqueeze(-1).unsqueeze(-1)
355        dicts = [main_input] + ssl_inputs + ssl_targets
356        for x in dicts:
357            if x is not None:
358                keys = self._get_keys(("coords", "center"), x)
359                for key in keys:
360                    y = deepcopy(x[key])
361                    mask = y == self.blank
362                    y[:, :, 0, :] += shift_x
363                    y[:, :, 1, :] += shift_y
364                    x[key] = y
365                    x[key][mask] = self.blank
366        return main_input, ssl_inputs, ssl_targets
367
368    def _add_noise(
369        self, main_input: Dict, ssl_inputs: List, ssl_targets: List
370    ) -> Tuple[Dict, List, List]:
371        """
372        Add normal noise to all features of the input
373        """
374
375        dicts = [main_input] + ssl_inputs + ssl_targets
376        for x in dicts:
377            if x is not None:
378                keys = self._get_keys(
379                    (
380                        "coords",
381                        "speed_joints",
382                        "speed_direction",
383                        "intra_distance",
384                        "acc_joints",
385                        "angle_speeds",
386                        "inter_distance",
387                        "speed_bones",
388                        "acc_bones",
389                        "bones",
390                        "coord_diff",
391                        "center",
392                    ),
393                    x,
394                )
395                for key in keys:
396                    mask = x[key] == self.blank
397                    x[key] = x[key] + torch.FloatTensor(x[key].shape).normal_(
398                        std=self.noise_std
399                    ).to(x[key].device)
400                    x[key][mask] = self.blank
401        return main_input, ssl_inputs, ssl_targets
402
403    def _zoom(
404        self, main_input: Dict, ssl_inputs: List, ssl_targets: List
405    ) -> Tuple[Dict, List, List]:
406        """
407        Add random zoom to all features of the input
408        """
409
410        key = list(main_input.keys())[0]
411        batch = main_input[key].shape[0]
412        device = main_input[key].device
413        zoom = torch.FloatTensor(batch).uniform_(*self.zoom_limits).to(device)
414        dicts = [main_input] + ssl_inputs + ssl_targets
415        for x in dicts:
416            if x is not None:
417                keys = self._get_keys(
418                    (
419                        "speed_joints",
420                        "intra_distance",
421                        "acc_joints",
422                        "inter_distance",
423                        "speed_bones",
424                        "acc_bones",
425                        "bones",
426                        "coord_diff",
427                        "center",
428                        "speed_value",
429                    ),
430                    x,
431                )
432                for key in keys:
433                    mask = x[key] == self.blank
434                    shape = np.array(x[key].shape)
435                    shape[1:] = 1
436                    y = deepcopy(x[key])
437                    y *= zoom.reshape(list(shape))
438                    x[key] = y
439                    x[key][mask] = self.blank
440                keys = self._get_keys("coords", x)
441                for key in keys:
442                    mask = x[key] == self.blank
443                    shape = np.array(x[key].shape)
444                    shape[1:] = 1
445                    zoom = zoom.reshape(list(shape))
446                    x[key][mask] = 10
447                    min_x = x[key][:, :, 0, :].min(1)[0].unsqueeze(1)
448                    min_y = x[key][:, :, 1, :].min(1)[0].unsqueeze(1)
449                    center = torch.stack([min_x, min_y], dim=2)
450                    coords = (x[key] - center) * zoom + center
451                    x[key] = coords
452                    x[key][mask] = self.blank
453
454        return main_input, ssl_inputs, ssl_targets
455
456    def _mask_joints(
457        self, main_input: Dict, ssl_inputs: List, ssl_targets: List
458    ) -> Tuple[Dict, List, List]:
459        """
460        Mask joints randomly
461        """
462
463        key = list(main_input.keys())[0]
464        batch, *_, frames = main_input[key].shape
465        masked_joints = (
466            torch.FloatTensor(batch, self.n_bodyparts).uniform_()
467            < self.masking_probability
468        )
469        dicts = [main_input] + ssl_inputs + ssl_targets
470        for x in [y for y in dicts if y is not None]:
471            keys = self._get_keys(("intra_distance", "inter_distance"), x)
472            for key in keys:
473                mask = (
474                    masked_joints.repeat(self.n_bodyparts, frames, 1, 1)
475                    .transpose(0, 2)
476                    .transpose(1, 3)
477                )
478                indices = torch.triu_indices(self.n_bodyparts, self.n_bodyparts, 1)
479
480                X = torch.zeros((batch, self.n_bodyparts, self.n_bodyparts, frames)).to(
481                    x[key].device
482                )
483                X[:, indices[0], indices[1], :] = x[key]
484                X[mask] = self.blank
485                X[mask.transpose(1, 2)] = self.blank
486                x[key] = X[:, indices[0], indices[1], :].reshape(batch, -1, frames)
487            keys = self._get_keys(
488                (
489                    "speed_joints",
490                    "speed_direction",
491                    "coords",
492                    "acc_joints",
493                    "speed_bones",
494                    "acc_bones",
495                    "bones",
496                    "coord_diff",
497                ),
498                x,
499            )
500            for key in keys:
501                mask = (
502                    masked_joints.repeat(self.dim, frames, 1, 1)
503                    .transpose(0, 2)
504                    .transpose(1, 3)
505                )
506                x[key][mask] = (
507                    x[key].mean(1).unsqueeze(1).repeat(1, x[key].shape[1], 1, 1)[mask]
508                )
509            keys = self._get_keys("angle_speeds", x)
510            for key in keys:
511                mask = (
512                    masked_joints.repeat(frames, 1, 1).transpose(0, 1).transpose(1, 2)
513                )
514                x[key][mask] = (
515                    x[key].mean(1).unsqueeze(1).repeat(1, x[key].shape[1], 1)[mask]
516                )
517
518        return main_input, ssl_inputs, ssl_targets
519
520    def _switch(
521        self, main_input: Dict, ssl_inputs: List, ssl_targets: List
522    ) -> Tuple[Dict, List, List]:
523        if bool(getrandbits(1)):
524            return main_input, ssl_inputs, ssl_targets
525        individuals = set()
526        ind_dict = defaultdict(lambda: set())
527        for key in main_input:
528            if len(key.split("---")) != 2:
529                continue
530            ind = key.split("---")[1]
531            if "+" in ind:
532                continue
533            individuals.add(ind)
534            ind_dict[ind].add(key.split("---")[0])
535        individuals = list(individuals)
536        if len(individuals) < 2:
537            return main_input, ssl_inputs, ssl_targets
538        for x, y in zip(individuals[::2], individuals[1::2]):
539            for key in ind_dict[x]:
540                if key in ind_dict[y]:
541                    main_input[f"{key}---{x}"], main_input[f"{key}---{y}"] = (
542                        main_input[f"{key}---{y}"],
543                        main_input[f"{key}---{x}"],
544                    )
545                for d in ssl_inputs + ssl_targets:
546                    if d is None:
547                        continue
548                    if f"{key}---{x}" in d and f"{key}---{y}" in d:
549                        d[f"{key}---{x}"], d[f"{key}---{y}"] = (
550                            d[f"{key}---{y}"],
551                            d[f"{key}---{x}"],
552                        )
553        return main_input, ssl_inputs, ssl_targets
554
555    def _get_keys(self, key_bases: Iterable, x: Dict) -> List:
556        """
557        Get the keys of x that start with one of the strings from key_bases
558        """
559
560        keys = []
561        if isinstance(key_bases, str):
562            key_bases = [key_bases]
563        for key in x:
564            if any([x == key.split("---")[0] for x in key_bases]):
565                keys.append(key)
566        return keys
567
568    def _augmentations_dict(self) -> Dict:
569        """
570        Get the mapping from augmentation names to functions
571        """
572
573        return {
574            "mirror": self._mirror,
575            "shift": self._shift,
576            "add_noise": self._add_noise,
577            "zoom": self._zoom,
578            "rotate": self._rotate,
579            "mask": self._mask_joints,
580            "switch": self._switch,
581        }
582
583    def _default_augmentations(self) -> List:
584        """
585        Get the list of default augmentation names
586        """
587
588        return ["mirror", "shift", "add_noise"]
589
590    def _get_bodyparts(self, shape: Tuple) -> None:
591        """
592        Set the number of bodyparts from the data if it is not known
593        """
594
595        if self.n_bodyparts is None:
596            N, B, F = shape
597            self.n_bodyparts = B // 2

A transformer that augments the output of the Kinematic feature extractor

The available augmentations are 'rotate', 'mirror', 'shift', 'add_noise', 'zoom', 'mask' and 'switch'

KinematicTransformer( model_name: str, augmentations: List = None, use_default_augmentations: bool = False, rotation_limits: List = None, mirror_dim: Set = None, noise_std: float = 0.05, zoom_limits: List = None, masking_probability: float = 0.05, dim: int = 2, graph_features: bool = False, bodyparts_order: List = None, canvas_shape: List = None, move_around_image_center: bool = True, **kwargs)
 28    def __init__(
 29        self,
 30        model_name: str,
 31        augmentations: List = None,
 32        use_default_augmentations: bool = False,
 33        rotation_limits: List = None,
 34        mirror_dim: Set = None,
 35        noise_std: float = 0.05,
 36        zoom_limits: List = None,
 37        masking_probability: float = 0.05,
 38        dim: int = 2,
 39        graph_features: bool = False,
 40        bodyparts_order: List = None,
 41        canvas_shape: List = None,
 42        move_around_image_center: bool = True,
 43        **kwargs,
 44    ) -> None:
 45        """
 46        Parameters
 47        ----------
 48        augmentations : list, optional
 49            list of augmentation names to use ("rotate", "mirror", "shift", "add_noise", "zoom")
 50        use_default_augmentations : bool, default False
 51            if `True` and augmentations are not passed, default augmentations will be applied; otherwise no augmentations
 52        rotation_limits : list, default [-pi/2, pi/2]
 53            list of float rotation angle limits (`[low, high]``, or `[[low_x, high_x], [low_y, high_y], [low_z, high_z]]`
 54            for 3D data)
 55        mirror_dim : set, default {0}
 56            set of integer indices of dimensions that can be mirrored
 57        noise_std : float, default 0.05
 58            standard deviation of noise
 59        zoom_limits : list, default [0.5, 1.5]
 60            list of float zoom limits ([low, high])
 61        masking_probability : float, default 0.1
 62            the probability of masking a joint
 63        dim : int, default 2
 64            the dimensionality of the input data
 65        **kwargs : dict
 66            other parameters for the base transformer class
 67        """
 68
 69        if augmentations is None:
 70            augmentations = []
 71        if canvas_shape is None:
 72            canvas_shape = [1, 1]
 73        self.dim = int(dim)
 74
 75        self.offset = [0 for _ in range(self.dim)]
 76        self.scale = canvas_shape[1] / canvas_shape[0]
 77        self.image_center = move_around_image_center
 78        # if canvas_shape is None:
 79        #     self.offset = [0.5 for _ in range(self.dim)]
 80        # else:
 81        #     self.offset = [0.5 * canvas_shape[i] / canvas_shape[0] for i in range(self.dim)]
 82
 83        self.blank = 0  # the value that nan values are set to (shouldn't be changed in augmentations)
 84        super().__init__(
 85            model_name,
 86            augmentations,
 87            use_default_augmentations,
 88            graph_features=graph_features,
 89            bodyparts_order=bodyparts_order,
 90            **kwargs,
 91        )
 92        if rotation_limits is None:
 93            rotation_limits = [-np.pi / 2, np.pi / 2]
 94        if mirror_dim is None:
 95            mirror_dim = [0]
 96        if zoom_limits is None:
 97            zoom_limits = [0.5, 1.5]
 98        self.rotation_limits = rotation_limits
 99        self.n_bodyparts = None
100        self.mirror_dim = mirror_dim
101        self.noise_std = noise_std
102        self.zoom_limits = zoom_limits
103        self.masking_probability = masking_probability

Parameters

augmentations : list, optional list of augmentation names to use ("rotate", "mirror", "shift", "add_noise", "zoom") use_default_augmentations : bool, default False if True and augmentations are not passed, default augmentations will be applied; otherwise no augmentations rotation_limits : list, default [-pi/2, pi/2] list of float rotation angle limits (`[low, high]`, or `[[low_x, high_x], [low_y, high_y], [low_z, high_z]]` for 3D data) mirror_dim : set, default {0} set of integer indices of dimensions that can be mirrored noise_std : float, default 0.05 standard deviation of noise zoom_limits : list, default [0.5, 1.5] list of float zoom limits ([low, high]) masking_probability : float, default 0.05 the probability of masking a joint dim : int, default 2 the dimensionality of the input data **kwargs : dict other parameters for the base transformer class