dlc2action.transformer.kinematic

Kinematic transformer

  1#
  2# Copyright 2020-present by A. Mathis Group and contributors. All rights reserved.
  3#
  4# This project and all its files are licensed under GNU AGPLv3 or later version. 
  5# A copy is included in dlc2action/LICENSE.AGPL.
  6#
  7"""
  8Kinematic transformer
  9"""
 10
 11from typing import Dict, Tuple, List, Set, Iterable
 12import torch
 13import numpy as np
 14from copy import copy, deepcopy
 15from random import getrandbits
 16from collections import defaultdict
 17from dlc2action.transformer.base_transformer import Transformer
 18from dlc2action.utils import rotation_matrix_2d, rotation_matrix_3d
 19
 20
 21class KinematicTransformer(Transformer):
 22    """
 23    A transformer that augments the output of the Kinematic feature extractor
 24
 25    The available augmentations are `'rotate'`, `'mirror'`, `'shift'`, `'add_noise'`, `'zoom'`, `'mask'` and `'switch'`
 26    """
 27
 28    def __init__(
 29        self,
 30        model_name: str,
 31        augmentations: List = None,
 32        use_default_augmentations: bool = False,
 33        rotation_limits: List = None,
 34        mirror_dim: Set = None,
 35        noise_std: float = 0.05,
 36        zoom_limits: List = None,
 37        masking_probability: float = 0.05,
 38        dim: int = 2,
 39        graph_features: bool = False,
 40        bodyparts_order: List = None,
 41        canvas_shape: List = None,
 42        move_around_image_center: bool = True,
 43        **kwargs,
 44    ) -> None:
 45        """
 46        Parameters
 47        ----------
 48        augmentations : list, optional
 49            list of augmentation names to use ("rotate", "mirror", "shift", "add_noise", "zoom")
 50        use_default_augmentations : bool, default False
 51            if `True` and augmentations are not passed, default augmentations will be applied; otherwise no augmentations
 52        rotation_limits : list, default [-pi/2, pi/2]
 53            list of float rotation angle limits (`[low, high]``, or `[[low_x, high_x], [low_y, high_y], [low_z, high_z]]`
 54            for 3D data)
 55        mirror_dim : set, default {0}
 56            set of integer indices of dimensions that can be mirrored
 57        noise_std : float, default 0.05
 58            standard deviation of noise
 59        zoom_limits : list, default [0.5, 1.5]
 60            list of float zoom limits ([low, high])
 61        masking_probability : float, default 0.1
 62            the probability of masking a joint
 63        dim : int, default 2
 64            the dimensionality of the input data
 65        **kwargs : dict
 66            other parameters for the base transformer class
 67        """
 68
 69        if augmentations is None:
 70            augmentations = []
 71        if canvas_shape is None:
 72            canvas_shape = [1, 1]
 73        self.dim = int(dim)
 74
 75        self.offset = [0 for _ in range(self.dim)]
 76        if isinstance(canvas_shape, str):
 77            canvas_shape = eval(canvas_shape)
 78        self.scale = canvas_shape[1] / canvas_shape[0]
 79        self.image_center = move_around_image_center
 80        # if canvas_shape is None:
 81        #     self.offset = [0.5 for _ in range(self.dim)]
 82        # else:
 83        #     self.offset = [0.5 * canvas_shape[i] / canvas_shape[0] for i in range(self.dim)]
 84
 85        self.blank = 0  # the value that nan values are set to (shouldn't be changed in augmentations)
 86        super().__init__(
 87            model_name,
 88            augmentations,
 89            use_default_augmentations,
 90            graph_features=graph_features,
 91            bodyparts_order=bodyparts_order,
 92            **kwargs,
 93        )
 94        if rotation_limits is None:
 95            rotation_limits = [-np.pi / 2, np.pi / 2]
 96        if mirror_dim is None:
 97            mirror_dim = [0]
 98        if zoom_limits is None:
 99            zoom_limits = [0.5, 1.5]
100        self.rotation_limits = rotation_limits
101        self.n_bodyparts = None
102        self.mirror_dim = mirror_dim
103        self.noise_std = noise_std
104        self.zoom_limits = zoom_limits
105        self.masking_probability = masking_probability
106
107    def _apply_augmentations(
108        self,
109        main_input: Dict,
110        ssl_inputs: List = None,
111        ssl_targets: List = None,
112        augment: bool = False,
113        ssl: bool = False,
114    ) -> Tuple:
115
116        if ssl_targets is None:
117            ssl_targets = [None]
118        if ssl_inputs is None:
119            ssl_inputs = [None]
120
121        keys_main = self._get_keys(
122            (
123                "coords",
124                "speed_joints",
125                "speed_direction",
126                "speed_bones",
127                "acc_bones",
128                "bones",
129                "coord_diff",
130            ),
131            main_input,
132        )
133        if len(keys_main) > 0:
134            key = keys_main[0]
135            final_shape = main_input[key].shape
136            # self._get_bodyparts(final_shape)
137            batch = final_shape[0]
138            s = main_input[key].reshape((batch, -1, self.dim, final_shape[-1]))
139            x_shape = s.shape
140            # print(f'{x_shape=}, {self.dim=}')
141            self.n_bodyparts = x_shape[1]
142            if self.dim == 3:
143                if len(self.rotation_limits) == 2:
144                    self.rotation_limits = [[0, 0], [0, 0], self.rotation_limits]
145            dicts = [main_input] + ssl_inputs + ssl_targets
146            for x in dicts:
147                if x is not None:
148                    keys = self._get_keys(
149                        (
150                            "coords",
151                            "speed_joints",
152                            "speed_direction",
153                            "speed_bones",
154                            "acc_bones",
155                            "bones",
156                            "coord_diff",
157                        ),
158                        x,
159                    )
160                    for key in keys:
161                        x[key] = x[key].reshape(x_shape)
162        if len(self._get_keys(("center"), main_input)) > 0:
163            dicts = [main_input] + ssl_inputs + ssl_targets
164            for x in dicts:
165                if x is not None:
166                    key_bases = ["center"]
167                    keys = self._get_keys(
168                        key_bases,
169                        x,
170                    )
171                    for key in keys:
172                        x[key] = x[key].reshape(
173                            (x[key].shape[0], 1, -1, x[key].shape[-1])
174                        )
175        main_input, ssl_inputs, ssl_targets = super()._apply_augmentations(
176            main_input, ssl_inputs, ssl_targets, augment
177        )
178        if len(keys_main) > 0:
179            dicts = [main_input] + ssl_inputs + ssl_targets
180            for x in dicts:
181                if x is not None:
182                    keys = self._get_keys(
183                        (
184                            "coords",
185                            "speed_joints",
186                            "speed_direction",
187                            "speed_bones",
188                            "acc_bones",
189                            "bones",
190                            "coord_diff",
191                        ),
192                        x,
193                    )
194                    for key in keys:
195                        x[key] = x[key].reshape(final_shape)
196        if len(self._get_keys(("center"), main_input)) > 0:
197            dicts = [main_input] + ssl_inputs + ssl_targets
198            for x in dicts:
199                if x is not None:
200                    key_bases = ["center"]
201                    keys = self._get_keys(
202                        key_bases,
203                        x,
204                    )
205                    for key in keys:
206                        x[key] = x[key].reshape((x[key].shape[0], -1, x[key].shape[-1]))
207        return main_input, ssl_inputs, ssl_targets
208
209    def _rotate(
210        self, main_input: Dict, ssl_inputs: List, ssl_targets: List
211    ) -> (Dict, List, List):
212        """
213        Rotate the "coords" and "speed_joints" features of the input to a random degree
214        """
215
216        keys = self._get_keys(
217            (
218                "coords",
219                "bones",
220                "coord_diff",
221                "center",
222                "speed_joints",
223                "speed_direction",
224                "center",
225            ),
226            main_input,
227        )
228        if len(keys) == 0:
229            return main_input, ssl_inputs, ssl_targets
230        batch = main_input[keys[0]].shape[0]
231        if self.dim == 2:
232            angles = torch.FloatTensor(batch).uniform_(*self.rotation_limits)
233            R = rotation_matrix_2d(angles).to(main_input[keys[0]].device)
234        else:
235            angles_x = torch.FloatTensor(batch).uniform_(*self.rotation_limits[0])
236            angles_y = torch.FloatTensor(batch).uniform_(*self.rotation_limits[1])
237            angles_z = torch.FloatTensor(batch).uniform_(*self.rotation_limits[2])
238            R = rotation_matrix_3d(angles_x, angles_y, angles_z).to(
239                main_input[keys[0]].device
240            )
241        dicts = [main_input] + ssl_inputs + ssl_targets
242        for x in dicts:
243            if x is not None:
244                keys = self._get_keys(
245                    (
246                        "coords",
247                        "speed_joints",
248                        "speed_direction",
249                        "speed_bones",
250                        "acc_bones",
251                        "bones",
252                        "coord_diff",
253                        "center",
254                    ),
255                    x,
256                )
257                for key in keys:
258                    if key in x:
259                        mask = x[key] == self.blank
260                        if (
261                            any([key.startswith(x) for x in ["coords", "center"]])
262                            and not self.image_center
263                        ):
264                            offset = x[key].mean(1).unsqueeze(1)
265                        else:
266                            offset = 0
267                        x[key] = (
268                            torch.einsum(
269                                "abjd,aij->abid",
270                                x[key] - offset,
271                                R,
272                            )
273                            + offset
274                        )
275                        x[key][mask] = self.blank
276        return main_input, ssl_inputs, ssl_targets
277
278    def _mirror(
279        self, main_input: Dict, ssl_inputs: List, ssl_targets: List
280    ) -> (Dict, List, List):
281        """
282        Mirror the "coords" and "speed_joints" features of the input randomly
283        """
284
285        mirror = []
286        for i in range(3):
287            if i in self.mirror_dim and bool(getrandbits(1)):
288                mirror.append(i)
289        if len(mirror) == 0:
290            return main_input, ssl_inputs, ssl_targets
291        dicts = [main_input] + ssl_inputs + ssl_targets
292        for x in dicts:
293            if x is not None:
294                keys = self._get_keys(
295                    (
296                        "coords",
297                        "speed_joints",
298                        "speed_direction",
299                        "speed_bones",
300                        "acc_bones",
301                        "bones",
302                        "center",
303                        "coord_diff",
304                    ),
305                    x,
306                )
307                for key in keys:
308                    if key in x:
309                        mask = x[key] == self.blank
310                        y = deepcopy(x[key])
311                        if not self.image_center and any(
312                            [key.startswith(x) for x in ["coords", "center"]]
313                        ):
314                            mean = y.mean(1).unsqueeze(1)
315                            for dim in mirror:
316                                y[:, :, dim, :] = (
317                                    2 * mean[:, :, dim, :] - y[:, :, dim, :]
318                                )
319                        else:
320                            for dim in mirror:
321                                y[:, :, dim, :] *= -1
322                        x[key] = y
323                        x[key][mask] = self.blank
324        return main_input, ssl_inputs, ssl_targets
325
326    def _shift(
327        self, main_input: Dict, ssl_inputs: List, ssl_targets: List
328    ) -> (Dict, List, List):
329        """
330        Shift the "coords" features of the input randomly
331        """
332
333        keys = self._get_keys(("coords", "center"), main_input)
334        if len(keys) == 0:
335            return main_input, ssl_inputs, ssl_targets
336        batch = main_input[keys[0]].shape[0]
337        dim = main_input[keys[0]].shape[2]
338        device = main_input[keys[0]].device
339        coords = torch.cat(
340            [main_input[key].transpose(-1, -2).reshape(batch, -1, dim) for key in keys],
341            dim=1,
342        )
343        minimask = coords[:, :, 0] != self.blank
344        min_x = -0.5 - torch.min(coords[:, :, 0], dim=-1)[0]
345        min_y = -0.5 * self.scale - torch.min(coords[:, :, 1], dim=-1)[0]
346        max_x = 0.5 - torch.max(coords[:, :, 0][minimask], dim=-1)[0]
347        max_y = 0.5 * self.scale - torch.max(coords[:, :, 1][minimask], dim=-1)[0]
348        del coords
349        shift_x = min_x + torch.FloatTensor(batch).uniform_().to(device) * (
350            max_x - min_x
351        )
352        shift_y = min_y + torch.FloatTensor(batch).uniform_().to(device) * (
353            max_y - min_y
354        )
355        shift_x = shift_x.unsqueeze(-1).unsqueeze(-1)
356        shift_y = shift_y.unsqueeze(-1).unsqueeze(-1)
357        dicts = [main_input] + ssl_inputs + ssl_targets
358        for x in dicts:
359            if x is not None:
360                keys = self._get_keys(("coords", "center"), x)
361                for key in keys:
362                    y = deepcopy(x[key])
363                    mask = y == self.blank
364                    y[:, :, 0, :] += shift_x
365                    y[:, :, 1, :] += shift_y
366                    x[key] = y
367                    x[key][mask] = self.blank
368        return main_input, ssl_inputs, ssl_targets
369
370    def _add_noise(
371        self, main_input: Dict, ssl_inputs: List, ssl_targets: List
372    ) -> (Dict, List, List):
373        """
374        Add normal noise to all features of the input
375        """
376
377        dicts = [main_input] + ssl_inputs + ssl_targets
378        for x in dicts:
379            if x is not None:
380                keys = self._get_keys(
381                    (
382                        "coords",
383                        "speed_joints",
384                        "speed_direction",
385                        "intra_distance",
386                        "acc_joints",
387                        "angle_speeds",
388                        "inter_distance",
389                        "speed_bones",
390                        "acc_bones",
391                        "bones",
392                        "coord_diff",
393                        "center",
394                    ),
395                    x,
396                )
397                for key in keys:
398                    mask = x[key] == self.blank
399                    x[key] = x[key] + torch.FloatTensor(x[key].shape).normal_(
400                        std=self.noise_std
401                    ).to(x[key].device)
402                    x[key][mask] = self.blank
403        return main_input, ssl_inputs, ssl_targets
404
405    def _zoom(
406        self, main_input: Dict, ssl_inputs: List, ssl_targets: List
407    ) -> (Dict, List, List):
408        """
409        Add random zoom to all features of the input
410        """
411
412        key = list(main_input.keys())[0]
413        batch = main_input[key].shape[0]
414        device = main_input[key].device
415        zoom = torch.FloatTensor(batch).uniform_(*self.zoom_limits).to(device)
416        dicts = [main_input] + ssl_inputs + ssl_targets
417        for x in dicts:
418            if x is not None:
419                keys = self._get_keys(
420                    (
421                        "speed_joints",
422                        "intra_distance",
423                        "acc_joints",
424                        "inter_distance",
425                        "speed_bones",
426                        "acc_bones",
427                        "bones",
428                        "coord_diff",
429                        "center",
430                        "speed_value",
431                    ),
432                    x,
433                )
434                for key in keys:
435                    mask = x[key] == self.blank
436                    shape = np.array(x[key].shape)
437                    shape[1:] = 1
438                    y = deepcopy(x[key])
439                    y *= zoom.reshape(list(shape))
440                    x[key] = y
441                    x[key][mask] = self.blank
442                keys = self._get_keys("coords", x)
443                for key in keys:
444                    mask = x[key] == self.blank
445                    shape = np.array(x[key].shape)
446                    shape[1:] = 1
447                    zoom = zoom.reshape(list(shape))
448                    x[key][mask] = 10
449                    min_x = x[key][:, :, 0, :].min(1)[0].unsqueeze(1)
450                    min_y = x[key][:, :, 1, :].min(1)[0].unsqueeze(1)
451                    center = torch.stack([min_x, min_y], dim=2)
452                    coords = (x[key] - center) * zoom + center
453                    x[key] = coords
454                    x[key][mask] = self.blank
455
456        return main_input, ssl_inputs, ssl_targets
457
458    def _mask_joints(
459        self, main_input: Dict, ssl_inputs: List, ssl_targets: List
460    ) -> (Dict, List, List):
461        """
462        Mask joints randomly
463        """
464
465        key = list(main_input.keys())[0]
466        batch, *_, frames = main_input[key].shape
467        masked_joints = (
468            torch.FloatTensor(batch, self.n_bodyparts).uniform_()
469            < self.masking_probability
470        )
471        dicts = [main_input] + ssl_inputs + ssl_targets
472        for x in [y for y in dicts if y is not None]:
473            keys = self._get_keys(("intra_distance", "inter_distance"), x)
474            for key in keys:
475                mask = (
476                    masked_joints.repeat(self.n_bodyparts, frames, 1, 1)
477                    .transpose(0, 2)
478                    .transpose(1, 3)
479                )
480                indices = torch.triu_indices(self.n_bodyparts, self.n_bodyparts, 1)
481
482                X = torch.zeros((batch, self.n_bodyparts, self.n_bodyparts, frames)).to(
483                    x[key].device
484                )
485                X[:, indices[0], indices[1], :] = x[key]
486                X[mask] = self.blank
487                X[mask.transpose(1, 2)] = self.blank
488                x[key] = X[:, indices[0], indices[1], :].reshape(batch, -1, frames)
489            keys = self._get_keys(
490                (
491                    "speed_joints",
492                    "speed_direction",
493                    "coords",
494                    "acc_joints",
495                    "speed_bones",
496                    "acc_bones",
497                    "bones",
498                    "coord_diff",
499                ),
500                x,
501            )
502            for key in keys:
503                mask = (
504                    masked_joints.repeat(self.dim, frames, 1, 1)
505                    .transpose(0, 2)
506                    .transpose(1, 3)
507                )
508                x[key][mask] = (
509                    x[key].mean(1).unsqueeze(1).repeat(1, x[key].shape[1], 1, 1)[mask]
510                )
511            keys = self._get_keys("angle_speeds", x)
512            for key in keys:
513                mask = (
514                    masked_joints.repeat(frames, 1, 1).transpose(0, 1).transpose(1, 2)
515                )
516                x[key][mask] = (
517                    x[key].mean(1).unsqueeze(1).repeat(1, x[key].shape[1], 1)[mask]
518                )
519
520        return main_input, ssl_inputs, ssl_targets
521
522    def _switch(
523        self, main_input: Dict, ssl_inputs: List, ssl_targets: List
524    ) -> (Dict, List, List):
525        if bool(getrandbits(1)):
526            return main_input, ssl_inputs, ssl_targets
527        individuals = set()
528        ind_dict = defaultdict(lambda: set())
529        for key in main_input:
530            if len(key.split("---")) != 2:
531                continue
532            ind = key.split("---")[1]
533            if "+" in ind:
534                continue
535            individuals.add(ind)
536            ind_dict[ind].add(key.split("---")[0])
537        individuals = list(individuals)
538        if len(individuals) < 2:
539            return main_input, ssl_inputs, ssl_targets
540        for x, y in zip(individuals[::2], individuals[1::2]):
541            for key in ind_dict[x]:
542                if key in ind_dict[y]:
543                    main_input[f"{key}---{x}"], main_input[f"{key}---{y}"] = (
544                        main_input[f"{key}---{y}"],
545                        main_input[f"{key}---{x}"],
546                    )
547                for d in ssl_inputs + ssl_targets:
548                    if d is None:
549                        continue
550                    if f"{key}---{x}" in d and f"{key}---{y}" in d:
551                        d[f"{key}---{x}"], d[f"{key}---{y}"] = (
552                            d[f"{key}---{y}"],
553                            d[f"{key}---{x}"],
554                        )
555        return main_input, ssl_inputs, ssl_targets
556
557    def _get_keys(self, key_bases: Iterable, x: Dict) -> List:
558        """
559        Get the keys of x that start with one of the strings from key_bases
560        """
561
562        keys = []
563        if isinstance(key_bases, str):
564            key_bases = [key_bases]
565        for key in x:
566            if any([x == key.split("---")[0] for x in key_bases]):
567                keys.append(key)
568        return keys
569
570    def _augmentations_dict(self) -> Dict:
571        """
572        Get the mapping from augmentation names to functions
573        """
574
575        return {
576            "mirror": self._mirror,
577            "shift": self._shift,
578            "add_noise": self._add_noise,
579            "zoom": self._zoom,
580            "rotate": self._rotate,
581            "mask": self._mask_joints,
582            "switch": self._switch,
583        }
584
585    def _default_augmentations(self) -> List:
586        """
587        Get the list of default augmentation names
588        """
589
590        return ["mirror", "shift", "add_noise"]
591
592    def _get_bodyparts(self, shape: Tuple) -> None:
593        """
594        Set the number of bodyparts from the data if it is not known
595        """
596
597        if self.n_bodyparts is None:
598            N, B, F = shape
599            self.n_bodyparts = B // 2
class KinematicTransformer(dlc2action.transformer.base_transformer.Transformer):
 22class KinematicTransformer(Transformer):
 23    """
 24    A transformer that augments the output of the Kinematic feature extractor
 25
 26    The available augmentations are `'rotate'`, `'mirror'`, `'shift'`, `'add_noise'`, `'zoom'`, `'mask'` and `'switch'`
 27    """
 28
 29    def __init__(
 30        self,
 31        model_name: str,
 32        augmentations: List = None,
 33        use_default_augmentations: bool = False,
 34        rotation_limits: List = None,
 35        mirror_dim: Set = None,
 36        noise_std: float = 0.05,
 37        zoom_limits: List = None,
 38        masking_probability: float = 0.05,
 39        dim: int = 2,
 40        graph_features: bool = False,
 41        bodyparts_order: List = None,
 42        canvas_shape: List = None,
 43        move_around_image_center: bool = True,
 44        **kwargs,
 45    ) -> None:
 46        """
 47        Parameters
 48        ----------
 49        augmentations : list, optional
 50            list of augmentation names to use ("rotate", "mirror", "shift", "add_noise", "zoom")
 51        use_default_augmentations : bool, default False
 52            if `True` and augmentations are not passed, default augmentations will be applied; otherwise no augmentations
 53        rotation_limits : list, default [-pi/2, pi/2]
 54            list of float rotation angle limits (`[low, high]``, or `[[low_x, high_x], [low_y, high_y], [low_z, high_z]]`
 55            for 3D data)
 56        mirror_dim : set, default {0}
 57            set of integer indices of dimensions that can be mirrored
 58        noise_std : float, default 0.05
 59            standard deviation of noise
 60        zoom_limits : list, default [0.5, 1.5]
 61            list of float zoom limits ([low, high])
 62        masking_probability : float, default 0.1
 63            the probability of masking a joint
 64        dim : int, default 2
 65            the dimensionality of the input data
 66        **kwargs : dict
 67            other parameters for the base transformer class
 68        """
 69
 70        if augmentations is None:
 71            augmentations = []
 72        if canvas_shape is None:
 73            canvas_shape = [1, 1]
 74        self.dim = int(dim)
 75
 76        self.offset = [0 for _ in range(self.dim)]
 77        if isinstance(canvas_shape, str):
 78            canvas_shape = eval(canvas_shape)
 79        self.scale = canvas_shape[1] / canvas_shape[0]
 80        self.image_center = move_around_image_center
 81        # if canvas_shape is None:
 82        #     self.offset = [0.5 for _ in range(self.dim)]
 83        # else:
 84        #     self.offset = [0.5 * canvas_shape[i] / canvas_shape[0] for i in range(self.dim)]
 85
 86        self.blank = 0  # the value that nan values are set to (shouldn't be changed in augmentations)
 87        super().__init__(
 88            model_name,
 89            augmentations,
 90            use_default_augmentations,
 91            graph_features=graph_features,
 92            bodyparts_order=bodyparts_order,
 93            **kwargs,
 94        )
 95        if rotation_limits is None:
 96            rotation_limits = [-np.pi / 2, np.pi / 2]
 97        if mirror_dim is None:
 98            mirror_dim = [0]
 99        if zoom_limits is None:
100            zoom_limits = [0.5, 1.5]
101        self.rotation_limits = rotation_limits
102        self.n_bodyparts = None
103        self.mirror_dim = mirror_dim
104        self.noise_std = noise_std
105        self.zoom_limits = zoom_limits
106        self.masking_probability = masking_probability
107
108    def _apply_augmentations(
109        self,
110        main_input: Dict,
111        ssl_inputs: List = None,
112        ssl_targets: List = None,
113        augment: bool = False,
114        ssl: bool = False,
115    ) -> Tuple:
116
117        if ssl_targets is None:
118            ssl_targets = [None]
119        if ssl_inputs is None:
120            ssl_inputs = [None]
121
122        keys_main = self._get_keys(
123            (
124                "coords",
125                "speed_joints",
126                "speed_direction",
127                "speed_bones",
128                "acc_bones",
129                "bones",
130                "coord_diff",
131            ),
132            main_input,
133        )
134        if len(keys_main) > 0:
135            key = keys_main[0]
136            final_shape = main_input[key].shape
137            # self._get_bodyparts(final_shape)
138            batch = final_shape[0]
139            s = main_input[key].reshape((batch, -1, self.dim, final_shape[-1]))
140            x_shape = s.shape
141            # print(f'{x_shape=}, {self.dim=}')
142            self.n_bodyparts = x_shape[1]
143            if self.dim == 3:
144                if len(self.rotation_limits) == 2:
145                    self.rotation_limits = [[0, 0], [0, 0], self.rotation_limits]
146            dicts = [main_input] + ssl_inputs + ssl_targets
147            for x in dicts:
148                if x is not None:
149                    keys = self._get_keys(
150                        (
151                            "coords",
152                            "speed_joints",
153                            "speed_direction",
154                            "speed_bones",
155                            "acc_bones",
156                            "bones",
157                            "coord_diff",
158                        ),
159                        x,
160                    )
161                    for key in keys:
162                        x[key] = x[key].reshape(x_shape)
163        if len(self._get_keys(("center"), main_input)) > 0:
164            dicts = [main_input] + ssl_inputs + ssl_targets
165            for x in dicts:
166                if x is not None:
167                    key_bases = ["center"]
168                    keys = self._get_keys(
169                        key_bases,
170                        x,
171                    )
172                    for key in keys:
173                        x[key] = x[key].reshape(
174                            (x[key].shape[0], 1, -1, x[key].shape[-1])
175                        )
176        main_input, ssl_inputs, ssl_targets = super()._apply_augmentations(
177            main_input, ssl_inputs, ssl_targets, augment
178        )
179        if len(keys_main) > 0:
180            dicts = [main_input] + ssl_inputs + ssl_targets
181            for x in dicts:
182                if x is not None:
183                    keys = self._get_keys(
184                        (
185                            "coords",
186                            "speed_joints",
187                            "speed_direction",
188                            "speed_bones",
189                            "acc_bones",
190                            "bones",
191                            "coord_diff",
192                        ),
193                        x,
194                    )
195                    for key in keys:
196                        x[key] = x[key].reshape(final_shape)
197        if len(self._get_keys(("center"), main_input)) > 0:
198            dicts = [main_input] + ssl_inputs + ssl_targets
199            for x in dicts:
200                if x is not None:
201                    key_bases = ["center"]
202                    keys = self._get_keys(
203                        key_bases,
204                        x,
205                    )
206                    for key in keys:
207                        x[key] = x[key].reshape((x[key].shape[0], -1, x[key].shape[-1]))
208        return main_input, ssl_inputs, ssl_targets
209
210    def _rotate(
211        self, main_input: Dict, ssl_inputs: List, ssl_targets: List
212    ) -> (Dict, List, List):
213        """
214        Rotate the "coords" and "speed_joints" features of the input to a random degree
215        """
216
217        keys = self._get_keys(
218            (
219                "coords",
220                "bones",
221                "coord_diff",
222                "center",
223                "speed_joints",
224                "speed_direction",
225                "center",
226            ),
227            main_input,
228        )
229        if len(keys) == 0:
230            return main_input, ssl_inputs, ssl_targets
231        batch = main_input[keys[0]].shape[0]
232        if self.dim == 2:
233            angles = torch.FloatTensor(batch).uniform_(*self.rotation_limits)
234            R = rotation_matrix_2d(angles).to(main_input[keys[0]].device)
235        else:
236            angles_x = torch.FloatTensor(batch).uniform_(*self.rotation_limits[0])
237            angles_y = torch.FloatTensor(batch).uniform_(*self.rotation_limits[1])
238            angles_z = torch.FloatTensor(batch).uniform_(*self.rotation_limits[2])
239            R = rotation_matrix_3d(angles_x, angles_y, angles_z).to(
240                main_input[keys[0]].device
241            )
242        dicts = [main_input] + ssl_inputs + ssl_targets
243        for x in dicts:
244            if x is not None:
245                keys = self._get_keys(
246                    (
247                        "coords",
248                        "speed_joints",
249                        "speed_direction",
250                        "speed_bones",
251                        "acc_bones",
252                        "bones",
253                        "coord_diff",
254                        "center",
255                    ),
256                    x,
257                )
258                for key in keys:
259                    if key in x:
260                        mask = x[key] == self.blank
261                        if (
262                            any([key.startswith(x) for x in ["coords", "center"]])
263                            and not self.image_center
264                        ):
265                            offset = x[key].mean(1).unsqueeze(1)
266                        else:
267                            offset = 0
268                        x[key] = (
269                            torch.einsum(
270                                "abjd,aij->abid",
271                                x[key] - offset,
272                                R,
273                            )
274                            + offset
275                        )
276                        x[key][mask] = self.blank
277        return main_input, ssl_inputs, ssl_targets
278
279    def _mirror(
280        self, main_input: Dict, ssl_inputs: List, ssl_targets: List
281    ) -> (Dict, List, List):
282        """
283        Mirror the "coords" and "speed_joints" features of the input randomly
284        """
285
286        mirror = []
287        for i in range(3):
288            if i in self.mirror_dim and bool(getrandbits(1)):
289                mirror.append(i)
290        if len(mirror) == 0:
291            return main_input, ssl_inputs, ssl_targets
292        dicts = [main_input] + ssl_inputs + ssl_targets
293        for x in dicts:
294            if x is not None:
295                keys = self._get_keys(
296                    (
297                        "coords",
298                        "speed_joints",
299                        "speed_direction",
300                        "speed_bones",
301                        "acc_bones",
302                        "bones",
303                        "center",
304                        "coord_diff",
305                    ),
306                    x,
307                )
308                for key in keys:
309                    if key in x:
310                        mask = x[key] == self.blank
311                        y = deepcopy(x[key])
312                        if not self.image_center and any(
313                            [key.startswith(x) for x in ["coords", "center"]]
314                        ):
315                            mean = y.mean(1).unsqueeze(1)
316                            for dim in mirror:
317                                y[:, :, dim, :] = (
318                                    2 * mean[:, :, dim, :] - y[:, :, dim, :]
319                                )
320                        else:
321                            for dim in mirror:
322                                y[:, :, dim, :] *= -1
323                        x[key] = y
324                        x[key][mask] = self.blank
325        return main_input, ssl_inputs, ssl_targets
326
327    def _shift(
328        self, main_input: Dict, ssl_inputs: List, ssl_targets: List
329    ) -> (Dict, List, List):
330        """
331        Shift the "coords" features of the input randomly
332        """
333
334        keys = self._get_keys(("coords", "center"), main_input)
335        if len(keys) == 0:
336            return main_input, ssl_inputs, ssl_targets
337        batch = main_input[keys[0]].shape[0]
338        dim = main_input[keys[0]].shape[2]
339        device = main_input[keys[0]].device
340        coords = torch.cat(
341            [main_input[key].transpose(-1, -2).reshape(batch, -1, dim) for key in keys],
342            dim=1,
343        )
344        minimask = coords[:, :, 0] != self.blank
345        min_x = -0.5 - torch.min(coords[:, :, 0], dim=-1)[0]
346        min_y = -0.5 * self.scale - torch.min(coords[:, :, 1], dim=-1)[0]
347        max_x = 0.5 - torch.max(coords[:, :, 0][minimask], dim=-1)[0]
348        max_y = 0.5 * self.scale - torch.max(coords[:, :, 1][minimask], dim=-1)[0]
349        del coords
350        shift_x = min_x + torch.FloatTensor(batch).uniform_().to(device) * (
351            max_x - min_x
352        )
353        shift_y = min_y + torch.FloatTensor(batch).uniform_().to(device) * (
354            max_y - min_y
355        )
356        shift_x = shift_x.unsqueeze(-1).unsqueeze(-1)
357        shift_y = shift_y.unsqueeze(-1).unsqueeze(-1)
358        dicts = [main_input] + ssl_inputs + ssl_targets
359        for x in dicts:
360            if x is not None:
361                keys = self._get_keys(("coords", "center"), x)
362                for key in keys:
363                    y = deepcopy(x[key])
364                    mask = y == self.blank
365                    y[:, :, 0, :] += shift_x
366                    y[:, :, 1, :] += shift_y
367                    x[key] = y
368                    x[key][mask] = self.blank
369        return main_input, ssl_inputs, ssl_targets
370
371    def _add_noise(
372        self, main_input: Dict, ssl_inputs: List, ssl_targets: List
373    ) -> (Dict, List, List):
374        """
375        Add normal noise to all features of the input
376        """
377
378        dicts = [main_input] + ssl_inputs + ssl_targets
379        for x in dicts:
380            if x is not None:
381                keys = self._get_keys(
382                    (
383                        "coords",
384                        "speed_joints",
385                        "speed_direction",
386                        "intra_distance",
387                        "acc_joints",
388                        "angle_speeds",
389                        "inter_distance",
390                        "speed_bones",
391                        "acc_bones",
392                        "bones",
393                        "coord_diff",
394                        "center",
395                    ),
396                    x,
397                )
398                for key in keys:
399                    mask = x[key] == self.blank
400                    x[key] = x[key] + torch.FloatTensor(x[key].shape).normal_(
401                        std=self.noise_std
402                    ).to(x[key].device)
403                    x[key][mask] = self.blank
404        return main_input, ssl_inputs, ssl_targets
405
406    def _zoom(
407        self, main_input: Dict, ssl_inputs: List, ssl_targets: List
408    ) -> (Dict, List, List):
409        """
410        Add random zoom to all features of the input
411        """
412
413        key = list(main_input.keys())[0]
414        batch = main_input[key].shape[0]
415        device = main_input[key].device
416        zoom = torch.FloatTensor(batch).uniform_(*self.zoom_limits).to(device)
417        dicts = [main_input] + ssl_inputs + ssl_targets
418        for x in dicts:
419            if x is not None:
420                keys = self._get_keys(
421                    (
422                        "speed_joints",
423                        "intra_distance",
424                        "acc_joints",
425                        "inter_distance",
426                        "speed_bones",
427                        "acc_bones",
428                        "bones",
429                        "coord_diff",
430                        "center",
431                        "speed_value",
432                    ),
433                    x,
434                )
435                for key in keys:
436                    mask = x[key] == self.blank
437                    shape = np.array(x[key].shape)
438                    shape[1:] = 1
439                    y = deepcopy(x[key])
440                    y *= zoom.reshape(list(shape))
441                    x[key] = y
442                    x[key][mask] = self.blank
443                keys = self._get_keys("coords", x)
444                for key in keys:
445                    mask = x[key] == self.blank
446                    shape = np.array(x[key].shape)
447                    shape[1:] = 1
448                    zoom = zoom.reshape(list(shape))
449                    x[key][mask] = 10
450                    min_x = x[key][:, :, 0, :].min(1)[0].unsqueeze(1)
451                    min_y = x[key][:, :, 1, :].min(1)[0].unsqueeze(1)
452                    center = torch.stack([min_x, min_y], dim=2)
453                    coords = (x[key] - center) * zoom + center
454                    x[key] = coords
455                    x[key][mask] = self.blank
456
457        return main_input, ssl_inputs, ssl_targets
458
    def _mask_joints(
        self, main_input: Dict, ssl_inputs: List, ssl_targets: List
    ) -> (Dict, List, List):
        """
        Mask joints randomly

        Every (sample, bodypart) pair is selected independently with probability
        `self.masking_probability`; the same selection is applied to the main input and to all
        SSL inputs and targets. For "intra_distance" and "inter_distance" features, every
        entry that involves a selected bodypart is set to `self.blank`. For the other listed
        features, the entries of a selected bodypart are overwritten in place with the mean
        of the feature over its second axis.

        Parameters
        ----------
        main_input : dict
            the main feature dictionary (the batch size and the number of frames are read
            from the first and the last dimension of its first value)
        ssl_inputs : list
            a list of SSL input feature dictionaries (`None` entries are skipped)
        ssl_targets : list
            a list of SSL target feature dictionaries (`None` entries are skipped)

        Returns
        -------
        main_input : dict
            the main feature dictionary with masked features
        ssl_inputs : list
            the SSL inputs with masked features
        ssl_targets : list
            the SSL targets with masked features
        """

        # any feature works here: only the first and the last dimension are used
        key = list(main_input.keys())[0]
        batch, *_, frames = main_input[key].shape
        # boolean (batch, n_bodyparts) selection, created on the CPU
        # NOTE(review): self.n_bodyparts is initialized to None in __init__, so it has to be
        # set before this method runs (presumably through _get_bodyparts) - confirm
        masked_joints = (
            torch.FloatTensor(batch, self.n_bodyparts).uniform_()
            < self.masking_probability
        )
        dicts = [main_input] + ssl_inputs + ssl_targets
        for x in [y for y in dicts if y is not None]:
            keys = self._get_keys(("intra_distance", "inter_distance"), x)
            for key in keys:
                # (batch, n_bodyparts, n_bodyparts, frames) mask that is True on every row
                # that belongs to a selected bodypart
                mask = (
                    masked_joints.repeat(self.n_bodyparts, frames, 1, 1)
                    .transpose(0, 2)
                    .transpose(1, 3)
                )
                # index pairs (i, j) with i < j: the upper triangle without the diagonal
                indices = torch.triu_indices(self.n_bodyparts, self.n_bodyparts, 1)

                # scatter the feature into a square bodypart x bodypart matrix
                # assumes x[key] is (batch, n_pairs, frames) with the pairs in the order of
                # torch.triu_indices - TODO confirm against the feature extractor
                X = torch.zeros((batch, self.n_bodyparts, self.n_bodyparts, frames)).to(
                    x[key].device
                )
                X[:, indices[0], indices[1], :] = x[key]
                # blank both the rows and the columns of the selected bodyparts
                X[mask] = self.blank
                X[mask.transpose(1, 2)] = self.blank
                # gather the upper triangle back into the flat per-pair layout
                x[key] = X[:, indices[0], indices[1], :].reshape(batch, -1, frames)
            keys = self._get_keys(
                (
                    "speed_joints",
                    "speed_direction",
                    "coords",
                    "acc_joints",
                    "speed_bones",
                    "acc_bones",
                    "bones",
                    "coord_diff",
                ),
                x,
            )
            for key in keys:
                # (batch, n_bodyparts, self.dim, frames) mask
                mask = (
                    masked_joints.repeat(self.dim, frames, 1, 1)
                    .transpose(0, 2)
                    .transpose(1, 3)
                )
                # in-place assignment: the tensor stored in x is modified directly; the mean
                # is taken over the whole second axis, before any entry is overwritten
                x[key][mask] = (
                    x[key].mean(1).unsqueeze(1).repeat(1, x[key].shape[1], 1, 1)[mask]
                )
            keys = self._get_keys("angle_speeds", x)
            for key in keys:
                # (batch, n_bodyparts, frames) mask for features without a coordinate axis
                mask = (
                    masked_joints.repeat(frames, 1, 1).transpose(0, 1).transpose(1, 2)
                )
                x[key][mask] = (
                    x[key].mean(1).unsqueeze(1).repeat(1, x[key].shape[1], 1)[mask]
                )

        return main_input, ssl_inputs, ssl_targets
522
523    def _switch(
524        self, main_input: Dict, ssl_inputs: List, ssl_targets: List
525    ) -> (Dict, List, List):
526        if bool(getrandbits(1)):
527            return main_input, ssl_inputs, ssl_targets
528        individuals = set()
529        ind_dict = defaultdict(lambda: set())
530        for key in main_input:
531            if len(key.split("---")) != 2:
532                continue
533            ind = key.split("---")[1]
534            if "+" in ind:
535                continue
536            individuals.add(ind)
537            ind_dict[ind].add(key.split("---")[0])
538        individuals = list(individuals)
539        if len(individuals) < 2:
540            return main_input, ssl_inputs, ssl_targets
541        for x, y in zip(individuals[::2], individuals[1::2]):
542            for key in ind_dict[x]:
543                if key in ind_dict[y]:
544                    main_input[f"{key}---{x}"], main_input[f"{key}---{y}"] = (
545                        main_input[f"{key}---{y}"],
546                        main_input[f"{key}---{x}"],
547                    )
548                for d in ssl_inputs + ssl_targets:
549                    if d is None:
550                        continue
551                    if f"{key}---{x}" in d and f"{key}---{y}" in d:
552                        d[f"{key}---{x}"], d[f"{key}---{y}"] = (
553                            d[f"{key}---{y}"],
554                            d[f"{key}---{x}"],
555                        )
556        return main_input, ssl_inputs, ssl_targets
557
558    def _get_keys(self, key_bases: Iterable, x: Dict) -> List:
559        """
560        Get the keys of x that start with one of the strings from key_bases
561        """
562
563        keys = []
564        if isinstance(key_bases, str):
565            key_bases = [key_bases]
566        for key in x:
567            if any([x == key.split("---")[0] for x in key_bases]):
568                keys.append(key)
569        return keys
570
571    def _augmentations_dict(self) -> Dict:
572        """
573        Get the mapping from augmentation names to functions
574        """
575
576        return {
577            "mirror": self._mirror,
578            "shift": self._shift,
579            "add_noise": self._add_noise,
580            "zoom": self._zoom,
581            "rotate": self._rotate,
582            "mask": self._mask_joints,
583            "switch": self._switch,
584        }
585
586    def _default_augmentations(self) -> List:
587        """
588        Get the list of default augmentation names
589        """
590
591        return ["mirror", "shift", "add_noise"]
592
593    def _get_bodyparts(self, shape: Tuple) -> None:
594        """
595        Set the number of bodyparts from the data if it is not known
596        """
597
598        if self.n_bodyparts is None:
599            N, B, F = shape
600            self.n_bodyparts = B // 2

A transformer that augments the output of the Kinematic feature extractor

The available augmentations are 'rotate', 'mirror', 'shift', 'add_noise', 'zoom', 'mask' and 'switch'

KinematicTransformer( model_name: str, augmentations: List = None, use_default_augmentations: bool = False, rotation_limits: List = None, mirror_dim: Set = None, noise_std: float = 0.05, zoom_limits: List = None, masking_probability: float = 0.05, dim: int = 2, graph_features: bool = False, bodyparts_order: List = None, canvas_shape: List = None, move_around_image_center: bool = True, **kwargs)
    def __init__(
        self,
        model_name: str,
        augmentations: List = None,
        use_default_augmentations: bool = False,
        rotation_limits: List = None,
        mirror_dim: Set = None,
        noise_std: float = 0.05,
        zoom_limits: List = None,
        masking_probability: float = 0.05,
        dim: int = 2,
        graph_features: bool = False,
        bodyparts_order: List = None,
        canvas_shape: List = None,
        move_around_image_center: bool = True,
        **kwargs,
    ) -> None:
        """
        Parameters
        ----------
        model_name : str
            the name of the model (passed to the base transformer class)
        augmentations : list, optional
            list of augmentation names to use ("rotate", "mirror", "shift", "add_noise", "zoom",
            "mask", "switch")
        use_default_augmentations : bool, default False
            if `True` and augmentations are not passed, default augmentations will be applied; otherwise no augmentations
        rotation_limits : list, default [-pi/2, pi/2]
            list of float rotation angle limits (`[low, high]`, or `[[low_x, high_x], [low_y, high_y], [low_z, high_z]]`
            for 3D data)
        mirror_dim : set, default [0]
            collection of integer indices of dimensions that can be mirrored
        noise_std : float, default 0.05
            standard deviation of noise
        zoom_limits : list, default [0.5, 1.5]
            list of float zoom limits (`[low, high]`)
        masking_probability : float, default 0.05
            the probability of masking a joint
        dim : int, default 2
            the dimensionality of the input data
        graph_features : bool, default False
            passed to the base transformer class
        bodyparts_order : list, optional
            passed to the base transformer class
        canvas_shape : list or str, default [1, 1]
            the canvas shape; a string is evaluated with `eval`. Only the ratio
            `canvas_shape[1] / canvas_shape[0]` is stored (as `self.scale`, which sets the
            limits of the "shift" augmentation along the second axis)
        move_around_image_center : bool, default True
            if `False`, "coords" and "center" features are rotated and mirrored around their
            mean over the second axis instead of around the origin
        **kwargs : dict
            other parameters for the base transformer class
        """

        if augmentations is None:
            augmentations = []
        if canvas_shape is None:
            canvas_shape = [1, 1]
        self.dim = int(dim)

        self.offset = [0 for _ in range(self.dim)]
        if isinstance(canvas_shape, str):
            # NOTE(review): eval executes arbitrary code, so only trusted strings may be
            # passed here; ast.literal_eval would be a safer choice if plain literals suffice
            canvas_shape = eval(canvas_shape)
        self.scale = canvas_shape[1] / canvas_shape[0]
        self.image_center = move_around_image_center
        # if canvas_shape is None:
        #     self.offset = [0.5 for _ in range(self.dim)]
        # else:
        #     self.offset = [0.5 * canvas_shape[i] / canvas_shape[0] for i in range(self.dim)]

        self.blank = 0  # the value that nan values are set to (shouldn't be changed in augmentations)
        # the attributes above are assigned before the base class is initialized
        super().__init__(
            model_name,
            augmentations,
            use_default_augmentations,
            graph_features=graph_features,
            bodyparts_order=bodyparts_order,
            **kwargs,
        )
        if rotation_limits is None:
            rotation_limits = [-np.pi / 2, np.pi / 2]
        if mirror_dim is None:
            mirror_dim = [0]
        if zoom_limits is None:
            zoom_limits = [0.5, 1.5]
        self.rotation_limits = rotation_limits
        # stays None here; _get_bodyparts fills it in from a feature shape
        self.n_bodyparts = None
        self.mirror_dim = mirror_dim
        self.noise_std = noise_std
        self.zoom_limits = zoom_limits
        self.masking_probability = masking_probability

Parameters

augmentations : list, optional
    list of augmentation names to use ("rotate", "mirror", "shift", "add_noise", "zoom")
use_default_augmentations : bool, default False
    if True and augmentations are not passed, default augmentations will be applied; otherwise no augmentations
rotation_limits : list, default [-pi/2, pi/2]
    list of float rotation angle limits (`[low, high]`, or `[[low_x, high_x], [low_y, high_y], [low_z, high_z]]` for 3D data)
mirror_dim : set, default {0}
    set of integer indices of dimensions that can be mirrored
noise_std : float, default 0.05
    standard deviation of noise
zoom_limits : list, default [0.5, 1.5]
    list of float zoom limits (`[low, high]`)
masking_probability : float, default 0.05
    the probability of masking a joint
dim : int, default 2
    the dimensionality of the input data
**kwargs : dict
    other parameters for the base transformer class

dim
offset
scale
image_center
blank
rotation_limits
n_bodyparts
mirror_dim
noise_std
zoom_limits
masking_probability