dlc2action.transformer.kinematic

Kinematic transformer

  1#
  2# Copyright 2020-present by A. Mathis Group and contributors. All rights reserved.
  3#
  4# This project and all its files are licensed under GNU AGPLv3 or later version. 
  5# A copy is included in dlc2action/LICENSE.AGPL.
  6#
  7"""
  8Kinematic transformer
  9"""
 10
 11from typing import Dict, Tuple, List, Set, Iterable
 12import torch
 13import numpy as np
 14from copy import copy, deepcopy
 15from random import getrandbits
 16from collections import defaultdict
 17from dlc2action.transformer.base_transformer import Transformer
 18from dlc2action.utils import rotation_matrix_2d, rotation_matrix_3d
 19
 20
 21class KinematicTransformer(Transformer):
 22    """
 23    A transformer that augments the output of the Kinematic feature extractor
 24
 25    The available augmentations are `'rotate'`, `'mirror'`, `'shift'`, `'add_noise'` and `'zoom'`
 26    """
 27
 28    def __init__(
 29        self,
 30        model_name: str,
 31        augmentations: List = None,
 32        use_default_augmentations: bool = False,
 33        rotation_limits: List = None,
 34        mirror_dim: Set = None,
 35        noise_std: float = 0.05,
 36        zoom_limits: List = None,
 37        masking_probability: float = 0.05,
 38        dim: int = 2,
 39        graph_features: bool = False,
 40        bodyparts_order: List = None,
 41        canvas_shape: List = None,
 42        move_around_image_center: bool = True,
 43        **kwargs,
 44    ) -> None:
 45        """
 46        Parameters
 47        ----------
 48        augmentations : list, optional
 49            list of augmentation names to use ("rotate", "mirror", "shift", "add_noise", "zoom")
 50        use_default_augmentations : bool, default False
 51            if `True` and augmentations are not passed, default augmentations will be applied; otherwise no augmentations
 52        rotation_limits : list, default [-pi/2, pi/2]
 53            list of float rotation angle limits (`[low, high]``, or `[[low_x, high_x], [low_y, high_y], [low_z, high_z]]`
 54            for 3D data)
 55        mirror_dim : set, default {0}
 56            set of integer indices of dimensions that can be mirrored
 57        noise_std : float, default 0.05
 58            standard deviation of noise
 59        zoom_limits : list, default [0.5, 1.5]
 60            list of float zoom limits ([low, high])
 61        masking_probability : float, default 0.1
 62            the probability of masking a joint
 63        dim : int, default 2
 64            the dimensionality of the input data
 65        **kwargs : dict
 66            other parameters for the base transformer class
 67        """
 68
 69        if augmentations is None:
 70            augmentations = []
 71        if canvas_shape is None:
 72            canvas_shape = [1, 1]
 73        self.dim = int(dim)
 74
 75        self.offset = [0 for _ in range(self.dim)]
 76        if isinstance(canvas_shape, str):
 77            canvas_shape = eval(canvas_shape)
 78        self.scale = canvas_shape[1] / canvas_shape[0]
 79        self.image_center = move_around_image_center
 80        # if canvas_shape is None:
 81        #     self.offset = [0.5 for _ in range(self.dim)]
 82        # else:
 83        #     self.offset = [0.5 * canvas_shape[i] / canvas_shape[0] for i in range(self.dim)]
 84
 85        self.blank = 0  # the value that nan values are set to (shouldn't be changed in augmentations)
 86        super().__init__(
 87            model_name,
 88            augmentations,
 89            use_default_augmentations,
 90            graph_features=graph_features,
 91            bodyparts_order=bodyparts_order,
 92            **kwargs,
 93        )
 94        if rotation_limits is None:
 95            rotation_limits = [-np.pi / 2, np.pi / 2]
 96        if mirror_dim is None:
 97            mirror_dim = [0]
 98        if zoom_limits is None:
 99            zoom_limits = [0.5, 1.5]
100        self.rotation_limits = rotation_limits
101        self.n_bodyparts = None
102        self.mirror_dim = mirror_dim
103        self.noise_std = noise_std
104        self.zoom_limits = zoom_limits
105        self.masking_probability = masking_probability
106
107    def _apply_augmentations(
108        self,
109        main_input: Dict,
110        ssl_inputs: List = None,
111        ssl_targets: List = None,
112        augment: bool = False,
113        ssl: bool = False,
114    ) -> Tuple:
115
116        if ssl_targets is None:
117            ssl_targets = [None]
118        if ssl_inputs is None:
119            ssl_inputs = [None]
120
121        keys_main = self._get_keys(
122            (
123                "coords",
124                "speed_joints",
125                "speed_direction",
126                "speed_bones",
127                "acc_bones",
128                "bones",
129                "coord_diff",
130            ),
131            main_input,
132        )
133        if len(keys_main) > 0:
134            key = keys_main[0]
135            final_shape = main_input[key].shape
136            # self._get_bodyparts(final_shape)
137            batch = final_shape[0]
138            s = main_input[key].reshape((batch, -1, self.dim, final_shape[-1]))
139            x_shape = s.shape
140            # print(f'{x_shape=}, {self.dim=}')
141            self.n_bodyparts = x_shape[1]
142            if self.dim == 3:
143                if len(self.rotation_limits) == 2:
144                    self.rotation_limits = [[0, 0], [0, 0], self.rotation_limits]
145            dicts = [main_input] + ssl_inputs + ssl_targets
146            for x in dicts:
147                if x is not None:
148                    keys = self._get_keys(
149                        (
150                            "coords",
151                            "speed_joints",
152                            "speed_direction",
153                            "speed_bones",
154                            "acc_bones",
155                            "bones",
156                            "coord_diff",
157                        ),
158                        x,
159                    )
160                    for key in keys:
161                        x[key] = x[key].reshape(x_shape)
162        if len(self._get_keys(("center"), main_input)) > 0:
163            dicts = [main_input] + ssl_inputs + ssl_targets
164            for x in dicts:
165                if x is not None:
166                    key_bases = ["center"]
167                    keys = self._get_keys(
168                        key_bases,
169                        x,
170                    )
171                    for key in keys:
172                        x[key] = x[key].reshape(
173                            (x[key].shape[0], 1, -1, x[key].shape[-1])
174                        )
175        main_input, ssl_inputs, ssl_targets = super()._apply_augmentations(
176            main_input, ssl_inputs, ssl_targets, augment
177        )
178        if len(keys_main) > 0:
179            dicts = [main_input] + ssl_inputs + ssl_targets
180            for x in dicts:
181                if x is not None:
182                    keys = self._get_keys(
183                        (
184                            "coords",
185                            "speed_joints",
186                            "speed_direction",
187                            "speed_bones",
188                            "acc_bones",
189                            "bones",
190                            "coord_diff",
191                        ),
192                        x,
193                    )
194                    for key in keys:
195                        x[key] = x[key].reshape(final_shape)
196        if len(self._get_keys(("center"), main_input)) > 0:
197            dicts = [main_input] + ssl_inputs + ssl_targets
198            for x in dicts:
199                if x is not None:
200                    key_bases = ["center"]
201                    keys = self._get_keys(
202                        key_bases,
203                        x,
204                    )
205                    for key in keys:
206                        x[key] = x[key].reshape((x[key].shape[0], -1, x[key].shape[-1]))
207        return main_input, ssl_inputs, ssl_targets
208
209    def _rotate(
210        self, main_input: Dict, ssl_inputs: List, ssl_targets: List
211    ) -> (Dict, List, List):
212        """
213        Rotate the "coords" and "speed_joints" features of the input to a random degree
214        """
215
216        keys = self._get_keys(
217            (
218                "coords",
219                "bones",
220                "coord_diff",
221                "center",
222                "speed_joints",
223                "speed_direction",
224                "center",
225            ),
226            main_input,
227        )
228        if len(keys) == 0:
229            return main_input, ssl_inputs, ssl_targets
230        batch = main_input[keys[0]].shape[0]
231        if self.dim == 2:
232            angles = torch.FloatTensor(batch).uniform_(*self.rotation_limits)
233            R = rotation_matrix_2d(angles).to(main_input[keys[0]].device)
234        else:
235            angles_x = torch.FloatTensor(batch).uniform_(*self.rotation_limits[0])
236            angles_y = torch.FloatTensor(batch).uniform_(*self.rotation_limits[1])
237            angles_z = torch.FloatTensor(batch).uniform_(*self.rotation_limits[2])
238            R = rotation_matrix_3d(angles_x, angles_y, angles_z).to(
239                main_input[keys[0]].device
240            )
241        dicts = [main_input] + ssl_inputs + ssl_targets
242        for x in dicts:
243            if x is not None:
244                keys = self._get_keys(
245                    (
246                        "coords",
247                        "speed_joints",
248                        "speed_direction",
249                        "speed_bones",
250                        "acc_bones",
251                        "bones",
252                        "coord_diff",
253                        "center",
254                    ),
255                    x,
256                )
257                for key in keys:
258                    if key in x:
259                        mask = x[key] == self.blank
260                        if (
261                            any([key.startswith(x) for x in ["coords", "center"]])
262                            and not self.image_center
263                        ):
264                            offset = x[key].mean(1).unsqueeze(1)
265                        else:
266                            offset = 0
267                        x[key] = (
268                            torch.einsum(
269                                "abjd,aij->abid",
270                                x[key] - offset,
271                                R,
272                            )
273                            + offset
274                        )
275                        x[key][mask] = self.blank
276        return main_input, ssl_inputs, ssl_targets
277
278    def _mirror(
279        self, main_input: Dict, ssl_inputs: List, ssl_targets: List
280    ) -> (Dict, List, List):
281        """
282        Mirror the "coords" and "speed_joints" features of the input randomly
283        """
284
285        mirror = []
286        for i in range(3):
287            if i in self.mirror_dim and bool(getrandbits(1)):
288                mirror.append(i)
289        if len(mirror) == 0:
290            return main_input, ssl_inputs, ssl_targets
291        dicts = [main_input] + ssl_inputs + ssl_targets
292        for x in dicts:
293            if x is not None:
294                keys = self._get_keys(
295                    (
296                        "coords",
297                        "speed_joints",
298                        "speed_direction",
299                        "speed_bones",
300                        "acc_bones",
301                        "bones",
302                        "center",
303                        "coord_diff",
304                    ),
305                    x,
306                )
307                for key in keys:
308                    if key in x:
309                        mask = x[key] == self.blank
310                        y = deepcopy(x[key])
311                        if not self.image_center and any(
312                            [key.startswith(x) for x in ["coords", "center"]]
313                        ):
314                            mean = y.mean(1).unsqueeze(1)
315                            for dim in mirror:
316                                y[:, :, dim, :] = (
317                                    2 * mean[:, :, dim, :] - y[:, :, dim, :]
318                                )
319                        else:
320                            for dim in mirror:
321                                y[:, :, dim, :] *= -1
322                        x[key] = y
323                        x[key][mask] = self.blank
324        return main_input, ssl_inputs, ssl_targets
325
326    def _shift(
327        self, main_input: Dict, ssl_inputs: List, ssl_targets: List
328    ) -> (Dict, List, List):
329        """
330        Shift the "coords" features of the input randomly
331        """
332
333        keys = self._get_keys(("coords", "center"), main_input)
334        if len(keys) == 0:
335            return main_input, ssl_inputs, ssl_targets
336        batch = main_input[keys[0]].shape[0]
337        dim = main_input[keys[0]].shape[2]
338        device = main_input[keys[0]].device
339        coords = torch.cat(
340            [main_input[key].transpose(-1, -2).reshape(batch, -1, dim) for key in keys],
341            dim=1,
342        )
343        minimask = coords[:, :, 0] != self.blank
344        min_x = -0.5 - torch.min(coords[:, :, 0], dim=-1)[0]
345        min_y = -0.5 * self.scale - torch.min(coords[:, :, 1], dim=-1)[0]
346        max_x = 0.5 - torch.max(coords[:, :, 0][minimask], dim=-1)[0]
347        max_y = 0.5 * self.scale - torch.max(coords[:, :, 1][minimask], dim=-1)[0]
348        del coords
349        shift_x = min_x + torch.FloatTensor(batch).uniform_().to(device) * (
350            max_x - min_x
351        )
352        shift_y = min_y + torch.FloatTensor(batch).uniform_().to(device) * (
353            max_y - min_y
354        )
355        shift_x = shift_x.unsqueeze(-1).unsqueeze(-1)
356        shift_y = shift_y.unsqueeze(-1).unsqueeze(-1)
357        dicts = [main_input] + ssl_inputs + ssl_targets
358        for x in dicts:
359            if x is not None:
360                keys = self._get_keys(("coords", "center"), x)
361                for key in keys:
362                    y = deepcopy(x[key])
363                    mask = y == self.blank
364                    y[:, :, 0, :] += shift_x
365                    y[:, :, 1, :] += shift_y
366                    x[key] = y
367                    x[key][mask] = self.blank
368        return main_input, ssl_inputs, ssl_targets
369
370    def _add_noise(
371        self, main_input: Dict, ssl_inputs: List, ssl_targets: List
372    ) -> (Dict, List, List):
373        """
374        Add normal noise to all features of the input
375        """
376
377        dicts = [main_input] + ssl_inputs + ssl_targets
378        for x in dicts:
379            if x is not None:
380                keys = self._get_keys(
381                    (
382                        "coords",
383                        "speed_joints",
384                        "speed_direction",
385                        "intra_distance",
386                        "acc_joints",
387                        "angle_speeds",
388                        "inter_distance",
389                        "speed_bones",
390                        "acc_bones",
391                        "bones",
392                        "coord_diff",
393                        "center",
394                    ),
395                    x,
396                )
397                for key in keys:
398                    mask = x[key] == self.blank
399                    x[key] = x[key] + torch.FloatTensor(x[key].shape).normal_(
400                        std=self.noise_std
401                    ).to(x[key].device)
402                    x[key][mask] = self.blank
403        return main_input, ssl_inputs, ssl_targets
404
405    def _zoom(
406        self, main_input: Dict, ssl_inputs: List, ssl_targets: List
407    ) -> (Dict, List, List):
408        """
409        Add random zoom to all features of the input
410        """
411
412        key = list(main_input.keys())[0]
413        batch = main_input[key].shape[0]
414        device = main_input[key].device
415        zoom = torch.FloatTensor(batch).uniform_(*self.zoom_limits).to(device)
416        dicts = [main_input] + ssl_inputs + ssl_targets
417        for x in dicts:
418            if x is not None:
419                keys = self._get_keys(
420                    (
421                        "speed_joints",
422                        "intra_distance",
423                        "acc_joints",
424                        "inter_distance",
425                        "speed_bones",
426                        "acc_bones",
427                        "bones",
428                        "coord_diff",
429                        "center",
430                        "speed_value",
431                    ),
432                    x,
433                )
434                for key in keys:
435                    mask = x[key] == self.blank
436                    shape = np.array(x[key].shape)
437                    shape[1:] = 1
438                    y = deepcopy(x[key])
439                    y *= zoom.reshape(list(shape))
440                    x[key] = y
441                    x[key][mask] = self.blank
442                keys = self._get_keys("coords", x)
443                for key in keys:
444                    mask = x[key] == self.blank
445                    shape = np.array(x[key].shape)
446                    shape[1:] = 1
447                    zoom = zoom.reshape(list(shape))
448                    x[key][mask] = 10
449                    min_x = x[key][:, :, 0, :].min(1)[0].unsqueeze(1)
450                    min_y = x[key][:, :, 1, :].min(1)[0].unsqueeze(1)
451                    if x[key].shape[2] == 2:
452                        center = torch.stack([min_x, min_y], dim=2)
453                    if x[key].shape[2] == 3:
454                        min_z = x[key][:, :, 2, :].min(1)[0].unsqueeze(1)
455                        center = torch.stack([min_x, min_y, min_z], dim=2)
456                    coords = (x[key] - center) * zoom + center
457                    x[key] = coords
458                    x[key][mask] = self.blank
459
460        return main_input, ssl_inputs, ssl_targets
461
462    def _mask_joints(
463        self, main_input: Dict, ssl_inputs: List, ssl_targets: List
464    ) -> (Dict, List, List):
465        """
466        Mask joints randomly
467        """
468
469        key = list(main_input.keys())[0]
470        batch, *_, frames = main_input[key].shape
471        masked_joints = (
472            torch.FloatTensor(batch, self.n_bodyparts).uniform_()
473            < self.masking_probability
474        )
475        dicts = [main_input] + ssl_inputs + ssl_targets
476        for x in [y for y in dicts if y is not None]:
477            keys = self._get_keys(("intra_distance", "inter_distance"), x)
478            for key in keys:
479                mask = (
480                    masked_joints.repeat(self.n_bodyparts, frames, 1, 1)
481                    .transpose(0, 2)
482                    .transpose(1, 3)
483                )
484                indices = torch.triu_indices(self.n_bodyparts, self.n_bodyparts, 1)
485
486                X = torch.zeros((batch, self.n_bodyparts, self.n_bodyparts, frames)).to(
487                    x[key].device
488                )
489                X[:, indices[0], indices[1], :] = x[key]
490                X[mask] = self.blank
491                X[mask.transpose(1, 2)] = self.blank
492                x[key] = X[:, indices[0], indices[1], :].reshape(batch, -1, frames)
493            keys = self._get_keys(
494                (
495                    "speed_joints",
496                    "speed_direction",
497                    "coords",
498                    "acc_joints",
499                    "speed_bones",
500                    "acc_bones",
501                    "bones",
502                    "coord_diff",
503                ),
504                x,
505            )
506            for key in keys:
507                mask = (
508                    masked_joints.repeat(self.dim, frames, 1, 1)
509                    .transpose(0, 2)
510                    .transpose(1, 3)
511                )
512                x[key][mask] = (
513                    x[key].mean(1).unsqueeze(1).repeat(1, x[key].shape[1], 1, 1)[mask]
514                )
515            keys = self._get_keys("angle_speeds", x)
516            for key in keys:
517                mask = (
518                    masked_joints.repeat(frames, 1, 1).transpose(0, 1).transpose(1, 2)
519                )
520                x[key][mask] = (
521                    x[key].mean(1).unsqueeze(1).repeat(1, x[key].shape[1], 1)[mask]
522                )
523
524        return main_input, ssl_inputs, ssl_targets
525
526    def _switch(
527        self, main_input: Dict, ssl_inputs: List, ssl_targets: List
528    ) -> (Dict, List, List):
529        if bool(getrandbits(1)):
530            return main_input, ssl_inputs, ssl_targets
531        individuals = set()
532        ind_dict = defaultdict(lambda: set())
533        for key in main_input:
534            if len(key.split("---")) != 2:
535                continue
536            ind = key.split("---")[1]
537            if "+" in ind:
538                continue
539            individuals.add(ind)
540            ind_dict[ind].add(key.split("---")[0])
541        individuals = list(individuals)
542        if len(individuals) < 2:
543            return main_input, ssl_inputs, ssl_targets
544        for x, y in zip(individuals[::2], individuals[1::2]):
545            for key in ind_dict[x]:
546                if key in ind_dict[y]:
547                    main_input[f"{key}---{x}"], main_input[f"{key}---{y}"] = (
548                        main_input[f"{key}---{y}"],
549                        main_input[f"{key}---{x}"],
550                    )
551                for d in ssl_inputs + ssl_targets:
552                    if d is None:
553                        continue
554                    if f"{key}---{x}" in d and f"{key}---{y}" in d:
555                        d[f"{key}---{x}"], d[f"{key}---{y}"] = (
556                            d[f"{key}---{y}"],
557                            d[f"{key}---{x}"],
558                        )
559        return main_input, ssl_inputs, ssl_targets
560
561    def _get_keys(self, key_bases: Iterable, x: Dict) -> List:
562        """
563        Get the keys of x that start with one of the strings from key_bases
564        """
565
566        keys = []
567        if isinstance(key_bases, str):
568            key_bases = [key_bases]
569        for key in x:
570            if any([x == key.split("---")[0] for x in key_bases]):
571                keys.append(key)
572        return keys
573
574    def _augmentations_dict(self) -> Dict:
575        """
576        Get the mapping from augmentation names to functions
577        """
578
579        return {
580            "mirror": self._mirror,
581            "shift": self._shift,
582            "add_noise": self._add_noise,
583            "zoom": self._zoom,
584            "rotate": self._rotate,
585            "mask": self._mask_joints,
586            "switch": self._switch,
587        }
588
589    def _default_augmentations(self) -> List:
590        """
591        Get the list of default augmentation names
592        """
593
594        return ["mirror", "shift", "add_noise"]
595
596    def _get_bodyparts(self, shape: Tuple) -> None:
597        """
598        Set the number of bodyparts from the data if it is not known
599        """
600
601        if self.n_bodyparts is None:
602            N, B, F = shape
603            self.n_bodyparts = B // 2
class KinematicTransformer(dlc2action.transformer.base_transformer.Transformer):
 22class KinematicTransformer(Transformer):
 23    """
 24    A transformer that augments the output of the Kinematic feature extractor
 25
 26    The available augmentations are `'rotate'`, `'mirror'`, `'shift'`, `'add_noise'` and `'zoom'`
 27    """
 28
 29    def __init__(
 30        self,
 31        model_name: str,
 32        augmentations: List = None,
 33        use_default_augmentations: bool = False,
 34        rotation_limits: List = None,
 35        mirror_dim: Set = None,
 36        noise_std: float = 0.05,
 37        zoom_limits: List = None,
 38        masking_probability: float = 0.05,
 39        dim: int = 2,
 40        graph_features: bool = False,
 41        bodyparts_order: List = None,
 42        canvas_shape: List = None,
 43        move_around_image_center: bool = True,
 44        **kwargs,
 45    ) -> None:
 46        """
 47        Parameters
 48        ----------
 49        augmentations : list, optional
 50            list of augmentation names to use ("rotate", "mirror", "shift", "add_noise", "zoom")
 51        use_default_augmentations : bool, default False
 52            if `True` and augmentations are not passed, default augmentations will be applied; otherwise no augmentations
 53        rotation_limits : list, default [-pi/2, pi/2]
 54            list of float rotation angle limits (`[low, high]``, or `[[low_x, high_x], [low_y, high_y], [low_z, high_z]]`
 55            for 3D data)
 56        mirror_dim : set, default {0}
 57            set of integer indices of dimensions that can be mirrored
 58        noise_std : float, default 0.05
 59            standard deviation of noise
 60        zoom_limits : list, default [0.5, 1.5]
 61            list of float zoom limits ([low, high])
 62        masking_probability : float, default 0.1
 63            the probability of masking a joint
 64        dim : int, default 2
 65            the dimensionality of the input data
 66        **kwargs : dict
 67            other parameters for the base transformer class
 68        """
 69
 70        if augmentations is None:
 71            augmentations = []
 72        if canvas_shape is None:
 73            canvas_shape = [1, 1]
 74        self.dim = int(dim)
 75
 76        self.offset = [0 for _ in range(self.dim)]
 77        if isinstance(canvas_shape, str):
 78            canvas_shape = eval(canvas_shape)
 79        self.scale = canvas_shape[1] / canvas_shape[0]
 80        self.image_center = move_around_image_center
 81        # if canvas_shape is None:
 82        #     self.offset = [0.5 for _ in range(self.dim)]
 83        # else:
 84        #     self.offset = [0.5 * canvas_shape[i] / canvas_shape[0] for i in range(self.dim)]
 85
 86        self.blank = 0  # the value that nan values are set to (shouldn't be changed in augmentations)
 87        super().__init__(
 88            model_name,
 89            augmentations,
 90            use_default_augmentations,
 91            graph_features=graph_features,
 92            bodyparts_order=bodyparts_order,
 93            **kwargs,
 94        )
 95        if rotation_limits is None:
 96            rotation_limits = [-np.pi / 2, np.pi / 2]
 97        if mirror_dim is None:
 98            mirror_dim = [0]
 99        if zoom_limits is None:
100            zoom_limits = [0.5, 1.5]
101        self.rotation_limits = rotation_limits
102        self.n_bodyparts = None
103        self.mirror_dim = mirror_dim
104        self.noise_std = noise_std
105        self.zoom_limits = zoom_limits
106        self.masking_probability = masking_probability
107
108    def _apply_augmentations(
109        self,
110        main_input: Dict,
111        ssl_inputs: List = None,
112        ssl_targets: List = None,
113        augment: bool = False,
114        ssl: bool = False,
115    ) -> Tuple:
116
117        if ssl_targets is None:
118            ssl_targets = [None]
119        if ssl_inputs is None:
120            ssl_inputs = [None]
121
122        keys_main = self._get_keys(
123            (
124                "coords",
125                "speed_joints",
126                "speed_direction",
127                "speed_bones",
128                "acc_bones",
129                "bones",
130                "coord_diff",
131            ),
132            main_input,
133        )
134        if len(keys_main) > 0:
135            key = keys_main[0]
136            final_shape = main_input[key].shape
137            # self._get_bodyparts(final_shape)
138            batch = final_shape[0]
139            s = main_input[key].reshape((batch, -1, self.dim, final_shape[-1]))
140            x_shape = s.shape
141            # print(f'{x_shape=}, {self.dim=}')
142            self.n_bodyparts = x_shape[1]
143            if self.dim == 3:
144                if len(self.rotation_limits) == 2:
145                    self.rotation_limits = [[0, 0], [0, 0], self.rotation_limits]
146            dicts = [main_input] + ssl_inputs + ssl_targets
147            for x in dicts:
148                if x is not None:
149                    keys = self._get_keys(
150                        (
151                            "coords",
152                            "speed_joints",
153                            "speed_direction",
154                            "speed_bones",
155                            "acc_bones",
156                            "bones",
157                            "coord_diff",
158                        ),
159                        x,
160                    )
161                    for key in keys:
162                        x[key] = x[key].reshape(x_shape)
163        if len(self._get_keys(("center"), main_input)) > 0:
164            dicts = [main_input] + ssl_inputs + ssl_targets
165            for x in dicts:
166                if x is not None:
167                    key_bases = ["center"]
168                    keys = self._get_keys(
169                        key_bases,
170                        x,
171                    )
172                    for key in keys:
173                        x[key] = x[key].reshape(
174                            (x[key].shape[0], 1, -1, x[key].shape[-1])
175                        )
176        main_input, ssl_inputs, ssl_targets = super()._apply_augmentations(
177            main_input, ssl_inputs, ssl_targets, augment
178        )
179        if len(keys_main) > 0:
180            dicts = [main_input] + ssl_inputs + ssl_targets
181            for x in dicts:
182                if x is not None:
183                    keys = self._get_keys(
184                        (
185                            "coords",
186                            "speed_joints",
187                            "speed_direction",
188                            "speed_bones",
189                            "acc_bones",
190                            "bones",
191                            "coord_diff",
192                        ),
193                        x,
194                    )
195                    for key in keys:
196                        x[key] = x[key].reshape(final_shape)
197        if len(self._get_keys(("center"), main_input)) > 0:
198            dicts = [main_input] + ssl_inputs + ssl_targets
199            for x in dicts:
200                if x is not None:
201                    key_bases = ["center"]
202                    keys = self._get_keys(
203                        key_bases,
204                        x,
205                    )
206                    for key in keys:
207                        x[key] = x[key].reshape((x[key].shape[0], -1, x[key].shape[-1]))
208        return main_input, ssl_inputs, ssl_targets
209
210    def _rotate(
211        self, main_input: Dict, ssl_inputs: List, ssl_targets: List
212    ) -> (Dict, List, List):
213        """
214        Rotate the "coords" and "speed_joints" features of the input to a random degree
215        """
216
217        keys = self._get_keys(
218            (
219                "coords",
220                "bones",
221                "coord_diff",
222                "center",
223                "speed_joints",
224                "speed_direction",
225                "center",
226            ),
227            main_input,
228        )
229        if len(keys) == 0:
230            return main_input, ssl_inputs, ssl_targets
231        batch = main_input[keys[0]].shape[0]
232        if self.dim == 2:
233            angles = torch.FloatTensor(batch).uniform_(*self.rotation_limits)
234            R = rotation_matrix_2d(angles).to(main_input[keys[0]].device)
235        else:
236            angles_x = torch.FloatTensor(batch).uniform_(*self.rotation_limits[0])
237            angles_y = torch.FloatTensor(batch).uniform_(*self.rotation_limits[1])
238            angles_z = torch.FloatTensor(batch).uniform_(*self.rotation_limits[2])
239            R = rotation_matrix_3d(angles_x, angles_y, angles_z).to(
240                main_input[keys[0]].device
241            )
242        dicts = [main_input] + ssl_inputs + ssl_targets
243        for x in dicts:
244            if x is not None:
245                keys = self._get_keys(
246                    (
247                        "coords",
248                        "speed_joints",
249                        "speed_direction",
250                        "speed_bones",
251                        "acc_bones",
252                        "bones",
253                        "coord_diff",
254                        "center",
255                    ),
256                    x,
257                )
258                for key in keys:
259                    if key in x:
260                        mask = x[key] == self.blank
261                        if (
262                            any([key.startswith(x) for x in ["coords", "center"]])
263                            and not self.image_center
264                        ):
265                            offset = x[key].mean(1).unsqueeze(1)
266                        else:
267                            offset = 0
268                        x[key] = (
269                            torch.einsum(
270                                "abjd,aij->abid",
271                                x[key] - offset,
272                                R,
273                            )
274                            + offset
275                        )
276                        x[key][mask] = self.blank
277        return main_input, ssl_inputs, ssl_targets
278
279    def _mirror(
280        self, main_input: Dict, ssl_inputs: List, ssl_targets: List
281    ) -> (Dict, List, List):
282        """
283        Mirror the "coords" and "speed_joints" features of the input randomly
284        """
285
286        mirror = []
287        for i in range(3):
288            if i in self.mirror_dim and bool(getrandbits(1)):
289                mirror.append(i)
290        if len(mirror) == 0:
291            return main_input, ssl_inputs, ssl_targets
292        dicts = [main_input] + ssl_inputs + ssl_targets
293        for x in dicts:
294            if x is not None:
295                keys = self._get_keys(
296                    (
297                        "coords",
298                        "speed_joints",
299                        "speed_direction",
300                        "speed_bones",
301                        "acc_bones",
302                        "bones",
303                        "center",
304                        "coord_diff",
305                    ),
306                    x,
307                )
308                for key in keys:
309                    if key in x:
310                        mask = x[key] == self.blank
311                        y = deepcopy(x[key])
312                        if not self.image_center and any(
313                            [key.startswith(x) for x in ["coords", "center"]]
314                        ):
315                            mean = y.mean(1).unsqueeze(1)
316                            for dim in mirror:
317                                y[:, :, dim, :] = (
318                                    2 * mean[:, :, dim, :] - y[:, :, dim, :]
319                                )
320                        else:
321                            for dim in mirror:
322                                y[:, :, dim, :] *= -1
323                        x[key] = y
324                        x[key][mask] = self.blank
325        return main_input, ssl_inputs, ssl_targets
326
327    def _shift(
328        self, main_input: Dict, ssl_inputs: List, ssl_targets: List
329    ) -> (Dict, List, List):
330        """
331        Shift the "coords" features of the input randomly
332        """
333
334        keys = self._get_keys(("coords", "center"), main_input)
335        if len(keys) == 0:
336            return main_input, ssl_inputs, ssl_targets
337        batch = main_input[keys[0]].shape[0]
338        dim = main_input[keys[0]].shape[2]
339        device = main_input[keys[0]].device
340        coords = torch.cat(
341            [main_input[key].transpose(-1, -2).reshape(batch, -1, dim) for key in keys],
342            dim=1,
343        )
344        minimask = coords[:, :, 0] != self.blank
345        min_x = -0.5 - torch.min(coords[:, :, 0], dim=-1)[0]
346        min_y = -0.5 * self.scale - torch.min(coords[:, :, 1], dim=-1)[0]
347        max_x = 0.5 - torch.max(coords[:, :, 0][minimask], dim=-1)[0]
348        max_y = 0.5 * self.scale - torch.max(coords[:, :, 1][minimask], dim=-1)[0]
349        del coords
350        shift_x = min_x + torch.FloatTensor(batch).uniform_().to(device) * (
351            max_x - min_x
352        )
353        shift_y = min_y + torch.FloatTensor(batch).uniform_().to(device) * (
354            max_y - min_y
355        )
356        shift_x = shift_x.unsqueeze(-1).unsqueeze(-1)
357        shift_y = shift_y.unsqueeze(-1).unsqueeze(-1)
358        dicts = [main_input] + ssl_inputs + ssl_targets
359        for x in dicts:
360            if x is not None:
361                keys = self._get_keys(("coords", "center"), x)
362                for key in keys:
363                    y = deepcopy(x[key])
364                    mask = y == self.blank
365                    y[:, :, 0, :] += shift_x
366                    y[:, :, 1, :] += shift_y
367                    x[key] = y
368                    x[key][mask] = self.blank
369        return main_input, ssl_inputs, ssl_targets
370
371    def _add_noise(
372        self, main_input: Dict, ssl_inputs: List, ssl_targets: List
373    ) -> (Dict, List, List):
374        """
375        Add normal noise to all features of the input
376        """
377
378        dicts = [main_input] + ssl_inputs + ssl_targets
379        for x in dicts:
380            if x is not None:
381                keys = self._get_keys(
382                    (
383                        "coords",
384                        "speed_joints",
385                        "speed_direction",
386                        "intra_distance",
387                        "acc_joints",
388                        "angle_speeds",
389                        "inter_distance",
390                        "speed_bones",
391                        "acc_bones",
392                        "bones",
393                        "coord_diff",
394                        "center",
395                    ),
396                    x,
397                )
398                for key in keys:
399                    mask = x[key] == self.blank
400                    x[key] = x[key] + torch.FloatTensor(x[key].shape).normal_(
401                        std=self.noise_std
402                    ).to(x[key].device)
403                    x[key][mask] = self.blank
404        return main_input, ssl_inputs, ssl_targets
405
406    def _zoom(
407        self, main_input: Dict, ssl_inputs: List, ssl_targets: List
408    ) -> (Dict, List, List):
409        """
410        Add random zoom to all features of the input
411        """
412
413        key = list(main_input.keys())[0]
414        batch = main_input[key].shape[0]
415        device = main_input[key].device
416        zoom = torch.FloatTensor(batch).uniform_(*self.zoom_limits).to(device)
417        dicts = [main_input] + ssl_inputs + ssl_targets
418        for x in dicts:
419            if x is not None:
420                keys = self._get_keys(
421                    (
422                        "speed_joints",
423                        "intra_distance",
424                        "acc_joints",
425                        "inter_distance",
426                        "speed_bones",
427                        "acc_bones",
428                        "bones",
429                        "coord_diff",
430                        "center",
431                        "speed_value",
432                    ),
433                    x,
434                )
435                for key in keys:
436                    mask = x[key] == self.blank
437                    shape = np.array(x[key].shape)
438                    shape[1:] = 1
439                    y = deepcopy(x[key])
440                    y *= zoom.reshape(list(shape))
441                    x[key] = y
442                    x[key][mask] = self.blank
443                keys = self._get_keys("coords", x)
444                for key in keys:
445                    mask = x[key] == self.blank
446                    shape = np.array(x[key].shape)
447                    shape[1:] = 1
448                    zoom = zoom.reshape(list(shape))
449                    x[key][mask] = 10
450                    min_x = x[key][:, :, 0, :].min(1)[0].unsqueeze(1)
451                    min_y = x[key][:, :, 1, :].min(1)[0].unsqueeze(1)
452                    if x[key].shape[2] == 2:
453                        center = torch.stack([min_x, min_y], dim=2)
454                    if x[key].shape[2] == 3:
455                        min_z = x[key][:, :, 2, :].min(1)[0].unsqueeze(1)
456                        center = torch.stack([min_x, min_y, min_z], dim=2)
457                    coords = (x[key] - center) * zoom + center
458                    x[key] = coords
459                    x[key][mask] = self.blank
460
461        return main_input, ssl_inputs, ssl_targets
462
463    def _mask_joints(
464        self, main_input: Dict, ssl_inputs: List, ssl_targets: List
465    ) -> (Dict, List, List):
466        """
467        Mask joints randomly
468        """
469
470        key = list(main_input.keys())[0]
471        batch, *_, frames = main_input[key].shape
472        masked_joints = (
473            torch.FloatTensor(batch, self.n_bodyparts).uniform_()
474            < self.masking_probability
475        )
476        dicts = [main_input] + ssl_inputs + ssl_targets
477        for x in [y for y in dicts if y is not None]:
478            keys = self._get_keys(("intra_distance", "inter_distance"), x)
479            for key in keys:
480                mask = (
481                    masked_joints.repeat(self.n_bodyparts, frames, 1, 1)
482                    .transpose(0, 2)
483                    .transpose(1, 3)
484                )
485                indices = torch.triu_indices(self.n_bodyparts, self.n_bodyparts, 1)
486
487                X = torch.zeros((batch, self.n_bodyparts, self.n_bodyparts, frames)).to(
488                    x[key].device
489                )
490                X[:, indices[0], indices[1], :] = x[key]
491                X[mask] = self.blank
492                X[mask.transpose(1, 2)] = self.blank
493                x[key] = X[:, indices[0], indices[1], :].reshape(batch, -1, frames)
494            keys = self._get_keys(
495                (
496                    "speed_joints",
497                    "speed_direction",
498                    "coords",
499                    "acc_joints",
500                    "speed_bones",
501                    "acc_bones",
502                    "bones",
503                    "coord_diff",
504                ),
505                x,
506            )
507            for key in keys:
508                mask = (
509                    masked_joints.repeat(self.dim, frames, 1, 1)
510                    .transpose(0, 2)
511                    .transpose(1, 3)
512                )
513                x[key][mask] = (
514                    x[key].mean(1).unsqueeze(1).repeat(1, x[key].shape[1], 1, 1)[mask]
515                )
516            keys = self._get_keys("angle_speeds", x)
517            for key in keys:
518                mask = (
519                    masked_joints.repeat(frames, 1, 1).transpose(0, 1).transpose(1, 2)
520                )
521                x[key][mask] = (
522                    x[key].mean(1).unsqueeze(1).repeat(1, x[key].shape[1], 1)[mask]
523                )
524
525        return main_input, ssl_inputs, ssl_targets
526
527    def _switch(
528        self, main_input: Dict, ssl_inputs: List, ssl_targets: List
529    ) -> (Dict, List, List):
530        if bool(getrandbits(1)):
531            return main_input, ssl_inputs, ssl_targets
532        individuals = set()
533        ind_dict = defaultdict(lambda: set())
534        for key in main_input:
535            if len(key.split("---")) != 2:
536                continue
537            ind = key.split("---")[1]
538            if "+" in ind:
539                continue
540            individuals.add(ind)
541            ind_dict[ind].add(key.split("---")[0])
542        individuals = list(individuals)
543        if len(individuals) < 2:
544            return main_input, ssl_inputs, ssl_targets
545        for x, y in zip(individuals[::2], individuals[1::2]):
546            for key in ind_dict[x]:
547                if key in ind_dict[y]:
548                    main_input[f"{key}---{x}"], main_input[f"{key}---{y}"] = (
549                        main_input[f"{key}---{y}"],
550                        main_input[f"{key}---{x}"],
551                    )
552                for d in ssl_inputs + ssl_targets:
553                    if d is None:
554                        continue
555                    if f"{key}---{x}" in d and f"{key}---{y}" in d:
556                        d[f"{key}---{x}"], d[f"{key}---{y}"] = (
557                            d[f"{key}---{y}"],
558                            d[f"{key}---{x}"],
559                        )
560        return main_input, ssl_inputs, ssl_targets
561
562    def _get_keys(self, key_bases: Iterable, x: Dict) -> List:
563        """
564        Get the keys of x that start with one of the strings from key_bases
565        """
566
567        keys = []
568        if isinstance(key_bases, str):
569            key_bases = [key_bases]
570        for key in x:
571            if any([x == key.split("---")[0] for x in key_bases]):
572                keys.append(key)
573        return keys
574
575    def _augmentations_dict(self) -> Dict:
576        """
577        Get the mapping from augmentation names to functions
578        """
579
580        return {
581            "mirror": self._mirror,
582            "shift": self._shift,
583            "add_noise": self._add_noise,
584            "zoom": self._zoom,
585            "rotate": self._rotate,
586            "mask": self._mask_joints,
587            "switch": self._switch,
588        }
589
590    def _default_augmentations(self) -> List:
591        """
592        Get the list of default augmentation names
593        """
594
595        return ["mirror", "shift", "add_noise"]
596
597    def _get_bodyparts(self, shape: Tuple) -> None:
598        """
599        Set the number of bodyparts from the data if it is not known
600        """
601
602        if self.n_bodyparts is None:
603            N, B, F = shape
604            self.n_bodyparts = B // 2

A transformer that augments the output of the Kinematic feature extractor

The available augmentations are 'rotate', 'mirror', 'shift', 'add_noise' and 'zoom'

KinematicTransformer( model_name: str, augmentations: List = None, use_default_augmentations: bool = False, rotation_limits: List = None, mirror_dim: Set = None, noise_std: float = 0.05, zoom_limits: List = None, masking_probability: float = 0.05, dim: int = 2, graph_features: bool = False, bodyparts_order: List = None, canvas_shape: List = None, move_around_image_center: bool = True, **kwargs)
 29    def __init__(
 30        self,
 31        model_name: str,
 32        augmentations: List = None,
 33        use_default_augmentations: bool = False,
 34        rotation_limits: List = None,
 35        mirror_dim: Set = None,
 36        noise_std: float = 0.05,
 37        zoom_limits: List = None,
 38        masking_probability: float = 0.05,
 39        dim: int = 2,
 40        graph_features: bool = False,
 41        bodyparts_order: List = None,
 42        canvas_shape: List = None,
 43        move_around_image_center: bool = True,
 44        **kwargs,
 45    ) -> None:
 46        """
 47        Parameters
 48        ----------
 49        augmentations : list, optional
 50            list of augmentation names to use ("rotate", "mirror", "shift", "add_noise", "zoom")
 51        use_default_augmentations : bool, default False
 52            if `True` and augmentations are not passed, default augmentations will be applied; otherwise no augmentations
 53        rotation_limits : list, default [-pi/2, pi/2]
 54            list of float rotation angle limits (`[low, high]``, or `[[low_x, high_x], [low_y, high_y], [low_z, high_z]]`
 55            for 3D data)
 56        mirror_dim : set, default {0}
 57            set of integer indices of dimensions that can be mirrored
 58        noise_std : float, default 0.05
 59            standard deviation of noise
 60        zoom_limits : list, default [0.5, 1.5]
 61            list of float zoom limits ([low, high])
 62        masking_probability : float, default 0.1
 63            the probability of masking a joint
 64        dim : int, default 2
 65            the dimensionality of the input data
 66        **kwargs : dict
 67            other parameters for the base transformer class
 68        """
 69
 70        if augmentations is None:
 71            augmentations = []
 72        if canvas_shape is None:
 73            canvas_shape = [1, 1]
 74        self.dim = int(dim)
 75
 76        self.offset = [0 for _ in range(self.dim)]
 77        if isinstance(canvas_shape, str):
 78            canvas_shape = eval(canvas_shape)
 79        self.scale = canvas_shape[1] / canvas_shape[0]
 80        self.image_center = move_around_image_center
 81        # if canvas_shape is None:
 82        #     self.offset = [0.5 for _ in range(self.dim)]
 83        # else:
 84        #     self.offset = [0.5 * canvas_shape[i] / canvas_shape[0] for i in range(self.dim)]
 85
 86        self.blank = 0  # the value that nan values are set to (shouldn't be changed in augmentations)
 87        super().__init__(
 88            model_name,
 89            augmentations,
 90            use_default_augmentations,
 91            graph_features=graph_features,
 92            bodyparts_order=bodyparts_order,
 93            **kwargs,
 94        )
 95        if rotation_limits is None:
 96            rotation_limits = [-np.pi / 2, np.pi / 2]
 97        if mirror_dim is None:
 98            mirror_dim = [0]
 99        if zoom_limits is None:
100            zoom_limits = [0.5, 1.5]
101        self.rotation_limits = rotation_limits
102        self.n_bodyparts = None
103        self.mirror_dim = mirror_dim
104        self.noise_std = noise_std
105        self.zoom_limits = zoom_limits
106        self.masking_probability = masking_probability

Parameters

augmentations : list, optional list of augmentation names to use ("rotate", "mirror", "shift", "add_noise", "zoom") use_default_augmentations : bool, default False if True and augmentations are not passed, default augmentations will be applied; otherwise no augmentations rotation_limits : list, default [-pi/2, pi/2] list of float rotation angle limits ([low, high], or [[low_x, high_x], [low_y, high_y], [low_z, high_z]] for 3D data) mirror_dim : set, default {0} set of integer indices of dimensions that can be mirrored noise_std : float, default 0.05 standard deviation of noise zoom_limits : list, default [0.5, 1.5] list of float zoom limits ([low, high]) masking_probability : float, default 0.05 the probability of masking a joint dim : int, default 2 the dimensionality of the input data **kwargs : dict other parameters for the base transformer class

dim
offset
scale
image_center
blank
rotation_limits
n_bodyparts
mirror_dim
noise_std
zoom_limits
masking_probability