
Kinematic transformer

  2# Copyright 2020-2022 by A. Mathis Group and contributors. All rights reserved.
  4# This project and all its files are licensed under GNU AGPLv3 or later version. A copy is included in dlc2action/LICENSE.AGPL.
  7Kinematic transformer
 10from typing import Dict, Tuple, List, Set, Iterable
 11import torch
 12import numpy as np
 13from copy import copy, deepcopy
 14from random import getrandbits
 15from collections import defaultdict
 16from dlc2action.transformer.base_transformer import Transformer
 17from dlc2action.utils import rotation_matrix_2d, rotation_matrix_3d
 20class KinematicTransformer(Transformer):
 21    """
 22    A transformer that augments the output of the Kinematic feature extractor
    The available augmentations are `'rotate'`, `'mirror'`, `'shift'`, `'add_noise'`, `'zoom'`, `'mask'` and `'switch'`
 25    """
 27    def __init__(
 28        self,
 29        model_name: str,
 30        augmentations: List = None,
 31        use_default_augmentations: bool = False,
 32        rotation_limits: List = None,
 33        mirror_dim: Set = None,
 34        noise_std: float = 0.05,
 35        zoom_limits: List = None,
 36        masking_probability: float = 0.05,
 37        dim: int = 2,
 38        graph_features: bool = False,
 39        bodyparts_order: List = None,
 40        canvas_shape: List = None,
 41        move_around_image_center: bool = True,
 42        **kwargs,
 43    ) -> None:
 44        """
 45        Parameters
 46        ----------
 47        augmentations : list, optional
 48            list of augmentation names to use ("rotate", "mirror", "shift", "add_noise", "zoom")
 49        use_default_augmentations : bool, default False
 50            if `True` and augmentations are not passed, default augmentations will be applied; otherwise no augmentations
 51        rotation_limits : list, default [-pi/2, pi/2]
 52            list of float rotation angle limits (`[low, high]``, or `[[low_x, high_x], [low_y, high_y], [low_z, high_z]]`
 53            for 3D data)
 54        mirror_dim : set, default {0}
 55            set of integer indices of dimensions that can be mirrored
 56        noise_std : float, default 0.05
 57            standard deviation of noise
 58        zoom_limits : list, default [0.5, 1.5]
 59            list of float zoom limits ([low, high])
 60        masking_probability : float, default 0.1
 61            the probability of masking a joint
 62        dim : int, default 2
 63            the dimensionality of the input data
 64        **kwargs : dict
 65            other parameters for the base transformer class
 66        """
 68        if augmentations is None:
 69            augmentations = []
 70        if canvas_shape is None:
 71            canvas_shape = [1, 1]
 72        self.dim = int(dim)
 74        self.offset = [0 for _ in range(self.dim)]
 75        self.scale = canvas_shape[1] / canvas_shape[0]
 76        self.image_center = move_around_image_center
 77        # if canvas_shape is None:
 78        #     self.offset = [0.5 for _ in range(self.dim)]
 79        # else:
 80        #     self.offset = [0.5 * canvas_shape[i] / canvas_shape[0] for i in range(self.dim)]
 82        self.blank = 0  # the value that nan values are set to (shouldn't be changed in augmentations)
 83        super().__init__(
 84            model_name,
 85            augmentations,
 86            use_default_augmentations,
 87            graph_features=graph_features,
 88            bodyparts_order=bodyparts_order,
 89            **kwargs,
 90        )
 91        if rotation_limits is None:
 92            rotation_limits = [-np.pi / 2, np.pi / 2]
 93        if mirror_dim is None:
 94            mirror_dim = [0]
 95        if zoom_limits is None:
 96            zoom_limits = [0.5, 1.5]
 97        self.rotation_limits = rotation_limits
 98        self.n_bodyparts = None
 99        self.mirror_dim = mirror_dim
100        self.noise_std = noise_std
101        self.zoom_limits = zoom_limits
102        self.masking_probability = masking_probability
104    def _apply_augmentations(
105        self,
106        main_input: Dict,
107        ssl_inputs: List = None,
108        ssl_targets: List = None,
109        augment: bool = False,
110        ssl: bool = False,
111    ) -> Tuple:
113        if ssl_targets is None:
114            ssl_targets = [None]
115        if ssl_inputs is None:
116            ssl_inputs = [None]
118        keys_main = self._get_keys(
119            (
120                "coords",
121                "speed_joints",
122                "speed_direction",
123                "speed_bones",
124                "acc_bones",
125                "bones",
126                "coord_diff",
127            ),
128            main_input,
129        )
130        if len(keys_main) > 0:
131            key = keys_main[0]
132            final_shape = main_input[key].shape
133            # self._get_bodyparts(final_shape)
134            batch = final_shape[0]
135            s = main_input[key].reshape((batch, -1, self.dim, final_shape[-1]))
136            x_shape = s.shape
137            # print(f'{x_shape=}, {self.dim=}')
138            self.n_bodyparts = x_shape[1]
139            if self.dim == 3:
140                if len(self.rotation_limits) == 2:
141                    self.rotation_limits = [[0, 0], [0, 0], self.rotation_limits]
142            dicts = [main_input] + ssl_inputs + ssl_targets
143            for x in dicts:
144                if x is not None:
145                    keys = self._get_keys(
146                        (
147                            "coords",
148                            "speed_joints",
149                            "speed_direction",
150                            "speed_bones",
151                            "acc_bones",
152                            "bones",
153                            "coord_diff",
154                        ),
155                        x,
156                    )
157                    for key in keys:
158                        x[key] = x[key].reshape(x_shape)
159        if len(self._get_keys(("center"), main_input)) > 0:
160            dicts = [main_input] + ssl_inputs + ssl_targets
161            for x in dicts:
162                if x is not None:
163                    key_bases = ["center"]
164                    keys = self._get_keys(
165                        key_bases,
166                        x,
167                    )
168                    for key in keys:
169                        x[key] = x[key].reshape(
170                            (x[key].shape[0], 1, -1, x[key].shape[-1])
171                        )
172        main_input, ssl_inputs, ssl_targets = super()._apply_augmentations(
173            main_input, ssl_inputs, ssl_targets, augment
174        )
175        if len(keys_main) > 0:
176            dicts = [main_input] + ssl_inputs + ssl_targets
177            for x in dicts:
178                if x is not None:
179                    keys = self._get_keys(
180                        (
181                            "coords",
182                            "speed_joints",
183                            "speed_direction",
184                            "speed_bones",
185                            "acc_bones",
186                            "bones",
187                            "coord_diff",
188                        ),
189                        x,
190                    )
191                    for key in keys:
192                        x[key] = x[key].reshape(final_shape)
193        if len(self._get_keys(("center"), main_input)) > 0:
194            dicts = [main_input] + ssl_inputs + ssl_targets
195            for x in dicts:
196                if x is not None:
197                    key_bases = ["center"]
198                    keys = self._get_keys(
199                        key_bases,
200                        x,
201                    )
202                    for key in keys:
203                        x[key] = x[key].reshape((x[key].shape[0], -1, x[key].shape[-1]))
204        return main_input, ssl_inputs, ssl_targets
206    def _rotate(
207        self, main_input: Dict, ssl_inputs: List, ssl_targets: List
208    ) -> Tuple[Dict, List, List]:
209        """
210        Rotate the "coords" and "speed_joints" features of the input to a random degree
211        """
213        keys = self._get_keys(
214            (
215                "coords",
216                "bones",
217                "coord_diff",
218                "center",
219                "speed_joints",
220                "speed_direction",
221                "center",
222            ),
223            main_input,
224        )
225        if len(keys) == 0:
226            return main_input, ssl_inputs, ssl_targets
227        batch = main_input[keys[0]].shape[0]
228        if self.dim == 2:
229            angles = torch.FloatTensor(batch).uniform_(*self.rotation_limits)
230            R = rotation_matrix_2d(angles).to(main_input[keys[0]].device)
231        else:
232            angles_x = torch.FloatTensor(batch).uniform_(*self.rotation_limits[0])
233            angles_y = torch.FloatTensor(batch).uniform_(*self.rotation_limits[1])
234            angles_z = torch.FloatTensor(batch).uniform_(*self.rotation_limits[2])
235            R = rotation_matrix_3d(angles_x, angles_y, angles_z).to(
236                main_input[keys[0]].device
237            )
238        dicts = [main_input] + ssl_inputs + ssl_targets
239        for x in dicts:
240            if x is not None:
241                keys = self._get_keys(
242                    (
243                        "coords",
244                        "speed_joints",
245                        "speed_direction",
246                        "speed_bones",
247                        "acc_bones",
248                        "bones",
249                        "coord_diff",
250                        "center",
251                    ),
252                    x,
253                )
254                for key in keys:
255                    if key in x:
256                        mask = x[key] == self.blank
257                        if (
258                            any([key.startswith(x) for x in ["coords", "center"]])
259                            and not self.image_center
260                        ):
261                            offset = x[key].mean(1).unsqueeze(1)
262                        else:
263                            offset = 0
264                        x[key] = (
265                            torch.einsum(
266                                "abjd,aij->abid",
267                                x[key] - offset,
268                                R,
269                            )
270                            + offset
271                        )
272                        x[key][mask] = self.blank
273        return main_input, ssl_inputs, ssl_targets
275    def _mirror(
276        self, main_input: Dict, ssl_inputs: List, ssl_targets: List
277    ) -> Tuple[Dict, List, List]:
278        """
279        Mirror the "coords" and "speed_joints" features of the input randomly
280        """
282        mirror = []
283        for i in range(3):
284            if i in self.mirror_dim and bool(getrandbits(1)):
285                mirror.append(i)
286        if len(mirror) == 0:
287            return main_input, ssl_inputs, ssl_targets
288        dicts = [main_input] + ssl_inputs + ssl_targets
289        for x in dicts:
290            if x is not None:
291                keys = self._get_keys(
292                    (
293                        "coords",
294                        "speed_joints",
295                        "speed_direction",
296                        "speed_bones",
297                        "acc_bones",
298                        "bones",
299                        "center",
300                        "coord_diff",
301                    ),
302                    x,
303                )
304                for key in keys:
305                    if key in x:
306                        mask = x[key] == self.blank
307                        y = deepcopy(x[key])
308                        if not self.image_center and any(
309                            [key.startswith(x) for x in ["coords", "center"]]
310                        ):
311                            mean = y.mean(1).unsqueeze(1)
312                            for dim in mirror:
313                                y[:, :, dim, :] = (
314                                    2 * mean[:, :, dim, :] - y[:, :, dim, :]
315                                )
316                        else:
317                            for dim in mirror:
318                                y[:, :, dim, :] *= -1
319                        x[key] = y
320                        x[key][mask] = self.blank
321        return main_input, ssl_inputs, ssl_targets
323    def _shift(
324        self, main_input: Dict, ssl_inputs: List, ssl_targets: List
325    ) -> Tuple[Dict, List, List]:
326        """
327        Shift the "coords" features of the input randomly
328        """
330        keys = self._get_keys(("coords", "center"), main_input)
331        if len(keys) == 0:
332            return main_input, ssl_inputs, ssl_targets
333        batch = main_input[keys[0]].shape[0]
334        dim = main_input[keys[0]].shape[2]
335        device = main_input[keys[0]].device
336        coords = torch.cat(
337            [main_input[key].transpose(-1, -2).reshape(batch, -1, dim) for key in keys],
338            dim=1,
339        )
340        minimask = coords[:, :, 0] != self.blank
341        min_x = -0.5 - torch.min(coords[:, :, 0], dim=-1)[0]
342        min_y = -0.5 * self.scale - torch.min(coords[:, :, 1], dim=-1)[0]
343        max_x = 0.5 - torch.max(coords[:, :, 0][minimask], dim=-1)[0]
344        max_y = 0.5 * self.scale - torch.max(coords[:, :, 1][minimask], dim=-1)[0]
345        del coords
346        shift_x = min_x + torch.FloatTensor(batch).uniform_().to(device) * (
347            max_x - min_x
348        )
349        shift_y = min_y + torch.FloatTensor(batch).uniform_().to(device) * (
350            max_y - min_y
351        )
352        shift_x = shift_x.unsqueeze(-1).unsqueeze(-1)
353        shift_y = shift_y.unsqueeze(-1).unsqueeze(-1)
354        dicts = [main_input] + ssl_inputs + ssl_targets
355        for x in dicts:
356            if x is not None:
357                keys = self._get_keys(("coords", "center"), x)
358                for key in keys:
359                    y = deepcopy(x[key])
360                    mask = y == self.blank
361                    y[:, :, 0, :] += shift_x
362                    y[:, :, 1, :] += shift_y
363                    x[key] = y
364                    x[key][mask] = self.blank
365        return main_input, ssl_inputs, ssl_targets
367    def _add_noise(
368        self, main_input: Dict, ssl_inputs: List, ssl_targets: List
369    ) -> Tuple[Dict, List, List]:
370        """
371        Add normal noise to all features of the input
372        """
374        dicts = [main_input] + ssl_inputs + ssl_targets
375        for x in dicts:
376            if x is not None:
377                keys = self._get_keys(
378                    (
379                        "coords",
380                        "speed_joints",
381                        "speed_direction",
382                        "intra_distance",
383                        "acc_joints",
384                        "angle_speeds",
385                        "inter_distance",
386                        "speed_bones",
387                        "acc_bones",
388                        "bones",
389                        "coord_diff",
390                        "center",
391                    ),
392                    x,
393                )
394                for key in keys:
395                    mask = x[key] == self.blank
396                    x[key] = x[key] + torch.FloatTensor(x[key].shape).normal_(
397                        std=self.noise_std
398                    ).to(x[key].device)
399                    x[key][mask] = self.blank
400        return main_input, ssl_inputs, ssl_targets
402    def _zoom(
403        self, main_input: Dict, ssl_inputs: List, ssl_targets: List
404    ) -> Tuple[Dict, List, List]:
405        """
406        Add random zoom to all features of the input
407        """
409        key = list(main_input.keys())[0]
410        batch = main_input[key].shape[0]
411        device = main_input[key].device
412        zoom = torch.FloatTensor(batch).uniform_(*self.zoom_limits).to(device)
413        dicts = [main_input] + ssl_inputs + ssl_targets
414        for x in dicts:
415            if x is not None:
416                keys = self._get_keys(
417                    (
418                        "speed_joints",
419                        "intra_distance",
420                        "acc_joints",
421                        "inter_distance",
422                        "speed_bones",
423                        "acc_bones",
424                        "bones",
425                        "coord_diff",
426                        "center",
427                        "speed_value",
428                    ),
429                    x,
430                )
431                for key in keys:
432                    mask = x[key] == self.blank
433                    shape = np.array(x[key].shape)
434                    shape[1:] = 1
435                    y = deepcopy(x[key])
436                    y *= zoom.reshape(list(shape))
437                    x[key] = y
438                    x[key][mask] = self.blank
439                keys = self._get_keys("coords", x)
440                for key in keys:
441                    mask = x[key] == self.blank
442                    shape = np.array(x[key].shape)
443                    shape[1:] = 1
444                    zoom = zoom.reshape(list(shape))
445                    x[key][mask] = 10
446                    min_x = x[key][:, :, 0, :].min(1)[0].unsqueeze(1)
447                    min_y = x[key][:, :, 1, :].min(1)[0].unsqueeze(1)
448                    center = torch.stack([min_x, min_y], dim=2)
449                    coords = (x[key] - center) * zoom + center
450                    x[key] = coords
451                    x[key][mask] = self.blank
453        return main_input, ssl_inputs, ssl_targets
455    def _mask_joints(
456        self, main_input: Dict, ssl_inputs: List, ssl_targets: List
457    ) -> Tuple[Dict, List, List]:
458        """
459        Mask joints randomly
460        """
462        key = list(main_input.keys())[0]
463        batch, *_, frames = main_input[key].shape
464        masked_joints = (
465            torch.FloatTensor(batch, self.n_bodyparts).uniform_()
466            < self.masking_probability
467        )
468        dicts = [main_input] + ssl_inputs + ssl_targets
469        for x in [y for y in dicts if y is not None]:
470            keys = self._get_keys(("intra_distance", "inter_distance"), x)
471            for key in keys:
472                mask = (
473                    masked_joints.repeat(self.n_bodyparts, frames, 1, 1)
474                    .transpose(0, 2)
475                    .transpose(1, 3)
476                )
477                indices = torch.triu_indices(self.n_bodyparts, self.n_bodyparts, 1)
479                X = torch.zeros((batch, self.n_bodyparts, self.n_bodyparts, frames)).to(
480                    x[key].device
481                )
482                X[:, indices[0], indices[1], :] = x[key]
483                X[mask] = self.blank
484                X[mask.transpose(1, 2)] = self.blank
485                x[key] = X[:, indices[0], indices[1], :].reshape(batch, -1, frames)
486            keys = self._get_keys(
487                (
488                    "speed_joints",
489                    "speed_direction",
490                    "coords",
491                    "acc_joints",
492                    "speed_bones",
493                    "acc_bones",
494                    "bones",
495                    "coord_diff",
496                ),
497                x,
498            )
499            for key in keys:
500                mask = (
501                    masked_joints.repeat(self.dim, frames, 1, 1)
502                    .transpose(0, 2)
503                    .transpose(1, 3)
504                )
505                x[key][mask] = (
506                    x[key].mean(1).unsqueeze(1).repeat(1, x[key].shape[1], 1, 1)[mask]
507                )
508            keys = self._get_keys("angle_speeds", x)
509            for key in keys:
510                mask = (
511                    masked_joints.repeat(frames, 1, 1).transpose(0, 1).transpose(1, 2)
512                )
513                x[key][mask] = (
514                    x[key].mean(1).unsqueeze(1).repeat(1, x[key].shape[1], 1)[mask]
515                )
517        return main_input, ssl_inputs, ssl_targets
519    def _switch(
520        self, main_input: Dict, ssl_inputs: List, ssl_targets: List
521    ) -> Tuple[Dict, List, List]:
522        if bool(getrandbits(1)):
523            return main_input, ssl_inputs, ssl_targets
524        individuals = set()
525        ind_dict = defaultdict(lambda: set())
526        for key in main_input:
527            if len(key.split("---")) != 2:
528                continue
529            ind = key.split("---")[1]
530            if "+" in ind:
531                continue
532            individuals.add(ind)
533            ind_dict[ind].add(key.split("---")[0])
534        individuals = list(individuals)
535        if len(individuals) < 2:
536            return main_input, ssl_inputs, ssl_targets
537        for x, y in zip(individuals[::2], individuals[1::2]):
538            for key in ind_dict[x]:
539                if key in ind_dict[y]:
540                    main_input[f"{key}---{x}"], main_input[f"{key}---{y}"] = (
541                        main_input[f"{key}---{y}"],
542                        main_input[f"{key}---{x}"],
543                    )
544                for d in ssl_inputs + ssl_targets:
545                    if d is None:
546                        continue
547                    if f"{key}---{x}" in d and f"{key}---{y}" in d:
548                        d[f"{key}---{x}"], d[f"{key}---{y}"] = (
549                            d[f"{key}---{y}"],
550                            d[f"{key}---{x}"],
551                        )
552        return main_input, ssl_inputs, ssl_targets
554    def _get_keys(self, key_bases: Iterable, x: Dict) -> List:
555        """
556        Get the keys of x that start with one of the strings from key_bases
557        """
559        keys = []
560        if isinstance(key_bases, str):
561            key_bases = [key_bases]
562        for key in x:
563            if any([x == key.split("---")[0] for x in key_bases]):
564                keys.append(key)
565        return keys
567    def _augmentations_dict(self) -> Dict:
568        """
569        Get the mapping from augmentation names to functions
570        """
572        return {
573            "mirror": self._mirror,
574            "shift": self._shift,
575            "add_noise": self._add_noise,
576            "zoom": self._zoom,
577            "rotate": self._rotate,
578            "mask": self._mask_joints,
579            "switch": self._switch,
580        }
582    def _default_augmentations(self) -> List:
583        """
584        Get the list of default augmentation names
585        """
587        return ["mirror", "shift", "add_noise"]
589    def _get_bodyparts(self, shape: Tuple) -> None:
590        """
591        Set the number of bodyparts from the data if it is not known
592        """
594        if self.n_bodyparts is None:
595            N, B, F = shape
596            self.n_bodyparts = B // 2
class KinematicTransformer(dlc2action.transformer.base_transformer.Transformer):
 21class KinematicTransformer(Transformer):
 22    """
 23    A transformer that augments the output of the Kinematic feature extractor
    The available augmentations are `'rotate'`, `'mirror'`, `'shift'`, `'add_noise'`, `'zoom'`, `'mask'` and `'switch'`
 26    """
 28    def __init__(
 29        self,
 30        model_name: str,
 31        augmentations: List = None,
 32        use_default_augmentations: bool = False,
 33        rotation_limits: List = None,
 34        mirror_dim: Set = None,
 35        noise_std: float = 0.05,
 36        zoom_limits: List = None,
 37        masking_probability: float = 0.05,
 38        dim: int = 2,
 39        graph_features: bool = False,
 40        bodyparts_order: List = None,
 41        canvas_shape: List = None,
 42        move_around_image_center: bool = True,
 43        **kwargs,
 44    ) -> None:
 45        """
 46        Parameters
 47        ----------
 48        augmentations : list, optional
 49            list of augmentation names to use ("rotate", "mirror", "shift", "add_noise", "zoom")
 50        use_default_augmentations : bool, default False
 51            if `True` and augmentations are not passed, default augmentations will be applied; otherwise no augmentations
 52        rotation_limits : list, default [-pi/2, pi/2]
 53            list of float rotation angle limits (`[low, high]``, or `[[low_x, high_x], [low_y, high_y], [low_z, high_z]]`
 54            for 3D data)
 55        mirror_dim : set, default {0}
 56            set of integer indices of dimensions that can be mirrored
 57        noise_std : float, default 0.05
 58            standard deviation of noise
 59        zoom_limits : list, default [0.5, 1.5]
 60            list of float zoom limits ([low, high])
 61        masking_probability : float, default 0.1
 62            the probability of masking a joint
 63        dim : int, default 2
 64            the dimensionality of the input data
 65        **kwargs : dict
 66            other parameters for the base transformer class
 67        """
 69        if augmentations is None:
 70            augmentations = []
 71        if canvas_shape is None:
 72            canvas_shape = [1, 1]
 73        self.dim = int(dim)
 75        self.offset = [0 for _ in range(self.dim)]
 76        self.scale = canvas_shape[1] / canvas_shape[0]
 77        self.image_center = move_around_image_center
 78        # if canvas_shape is None:
 79        #     self.offset = [0.5 for _ in range(self.dim)]
 80        # else:
 81        #     self.offset = [0.5 * canvas_shape[i] / canvas_shape[0] for i in range(self.dim)]
 83        self.blank = 0  # the value that nan values are set to (shouldn't be changed in augmentations)
 84        super().__init__(
 85            model_name,
 86            augmentations,
 87            use_default_augmentations,
 88            graph_features=graph_features,
 89            bodyparts_order=bodyparts_order,
 90            **kwargs,
 91        )
 92        if rotation_limits is None:
 93            rotation_limits = [-np.pi / 2, np.pi / 2]
 94        if mirror_dim is None:
 95            mirror_dim = [0]
 96        if zoom_limits is None:
 97            zoom_limits = [0.5, 1.5]
 98        self.rotation_limits = rotation_limits
 99        self.n_bodyparts = None
100        self.mirror_dim = mirror_dim
101        self.noise_std = noise_std
102        self.zoom_limits = zoom_limits
103        self.masking_probability = masking_probability
105    def _apply_augmentations(
106        self,
107        main_input: Dict,
108        ssl_inputs: List = None,
109        ssl_targets: List = None,
110        augment: bool = False,
111        ssl: bool = False,
112    ) -> Tuple:
114        if ssl_targets is None:
115            ssl_targets = [None]
116        if ssl_inputs is None:
117            ssl_inputs = [None]
119        keys_main = self._get_keys(
120            (
121                "coords",
122                "speed_joints",
123                "speed_direction",
124                "speed_bones",
125                "acc_bones",
126                "bones",
127                "coord_diff",
128            ),
129            main_input,
130        )
131        if len(keys_main) > 0:
132            key = keys_main[0]
133            final_shape = main_input[key].shape
134            # self._get_bodyparts(final_shape)
135            batch = final_shape[0]
136            s = main_input[key].reshape((batch, -1, self.dim, final_shape[-1]))
137            x_shape = s.shape
138            # print(f'{x_shape=}, {self.dim=}')
139            self.n_bodyparts = x_shape[1]
140            if self.dim == 3:
141                if len(self.rotation_limits) == 2:
142                    self.rotation_limits = [[0, 0], [0, 0], self.rotation_limits]
143            dicts = [main_input] + ssl_inputs + ssl_targets
144            for x in dicts:
145                if x is not None:
146                    keys = self._get_keys(
147                        (
148                            "coords",
149                            "speed_joints",
150                            "speed_direction",
151                            "speed_bones",
152                            "acc_bones",
153                            "bones",
154                            "coord_diff",
155                        ),
156                        x,
157                    )
158                    for key in keys:
159                        x[key] = x[key].reshape(x_shape)
160        if len(self._get_keys(("center"), main_input)) > 0:
161            dicts = [main_input] + ssl_inputs + ssl_targets
162            for x in dicts:
163                if x is not None:
164                    key_bases = ["center"]
165                    keys = self._get_keys(
166                        key_bases,
167                        x,
168                    )
169                    for key in keys:
170                        x[key] = x[key].reshape(
171                            (x[key].shape[0], 1, -1, x[key].shape[-1])
172                        )
173        main_input, ssl_inputs, ssl_targets = super()._apply_augmentations(
174            main_input, ssl_inputs, ssl_targets, augment
175        )
176        if len(keys_main) > 0:
177            dicts = [main_input] + ssl_inputs + ssl_targets
178            for x in dicts:
179                if x is not None:
180                    keys = self._get_keys(
181                        (
182                            "coords",
183                            "speed_joints",
184                            "speed_direction",
185                            "speed_bones",
186                            "acc_bones",
187                            "bones",
188                            "coord_diff",
189                        ),
190                        x,
191                    )
192                    for key in keys:
193                        x[key] = x[key].reshape(final_shape)
194        if len(self._get_keys(("center"), main_input)) > 0:
195            dicts = [main_input] + ssl_inputs + ssl_targets
196            for x in dicts:
197                if x is not None:
198                    key_bases = ["center"]
199                    keys = self._get_keys(
200                        key_bases,
201                        x,
202                    )
203                    for key in keys:
204                        x[key] = x[key].reshape((x[key].shape[0], -1, x[key].shape[-1]))
205        return main_input, ssl_inputs, ssl_targets
207    def _rotate(
208        self, main_input: Dict, ssl_inputs: List, ssl_targets: List
209    ) -> Tuple[Dict, List, List]:
210        """
211        Rotate the "coords" and "speed_joints" features of the input to a random degree
212        """
214        keys = self._get_keys(
215            (
216                "coords",
217                "bones",
218                "coord_diff",
219                "center",
220                "speed_joints",
221                "speed_direction",
222                "center",
223            ),
224            main_input,
225        )
226        if len(keys) == 0:
227            return main_input, ssl_inputs, ssl_targets
228        batch = main_input[keys[0]].shape[0]
229        if self.dim == 2:
230            angles = torch.FloatTensor(batch).uniform_(*self.rotation_limits)
231            R = rotation_matrix_2d(angles).to(main_input[keys[0]].device)
232        else:
233            angles_x = torch.FloatTensor(batch).uniform_(*self.rotation_limits[0])
234            angles_y = torch.FloatTensor(batch).uniform_(*self.rotation_limits[1])
235            angles_z = torch.FloatTensor(batch).uniform_(*self.rotation_limits[2])
236            R = rotation_matrix_3d(angles_x, angles_y, angles_z).to(
237                main_input[keys[0]].device
238            )
239        dicts = [main_input] + ssl_inputs + ssl_targets
240        for x in dicts:
241            if x is not None:
242                keys = self._get_keys(
243                    (
244                        "coords",
245                        "speed_joints",
246                        "speed_direction",
247                        "speed_bones",
248                        "acc_bones",
249                        "bones",
250                        "coord_diff",
251                        "center",
252                    ),
253                    x,
254                )
255                for key in keys:
256                    if key in x:
257                        mask = x[key] == self.blank
258                        if (
259                            any([key.startswith(x) for x in ["coords", "center"]])
260                            and not self.image_center
261                        ):
262                            offset = x[key].mean(1).unsqueeze(1)
263                        else:
264                            offset = 0
265                        x[key] = (
266                            torch.einsum(
267                                "abjd,aij->abid",
268                                x[key] - offset,
269                                R,
270                            )
271                            + offset
272                        )
273                        x[key][mask] = self.blank
274        return main_input, ssl_inputs, ssl_targets
276    def _mirror(
277        self, main_input: Dict, ssl_inputs: List, ssl_targets: List
278    ) -> Tuple[Dict, List, List]:
279        """
280        Mirror the "coords" and "speed_joints" features of the input randomly
281        """
283        mirror = []
284        for i in range(3):
285            if i in self.mirror_dim and bool(getrandbits(1)):
286                mirror.append(i)
287        if len(mirror) == 0:
288            return main_input, ssl_inputs, ssl_targets
289        dicts = [main_input] + ssl_inputs + ssl_targets
290        for x in dicts:
291            if x is not None:
292                keys = self._get_keys(
293                    (
294                        "coords",
295                        "speed_joints",
296                        "speed_direction",
297                        "speed_bones",
298                        "acc_bones",
299                        "bones",
300                        "center",
301                        "coord_diff",
302                    ),
303                    x,
304                )
305                for key in keys:
306                    if key in x:
307                        mask = x[key] == self.blank
308                        y = deepcopy(x[key])
309                        if not self.image_center and any(
310                            [key.startswith(x) for x in ["coords", "center"]]
311                        ):
312                            mean = y.mean(1).unsqueeze(1)
313                            for dim in mirror:
314                                y[:, :, dim, :] = (
315                                    2 * mean[:, :, dim, :] - y[:, :, dim, :]
316                                )
317                        else:
318                            for dim in mirror:
319                                y[:, :, dim, :] *= -1
320                        x[key] = y
321                        x[key][mask] = self.blank
322        return main_input, ssl_inputs, ssl_targets
324    def _shift(
325        self, main_input: Dict, ssl_inputs: List, ssl_targets: List
326    ) -> Tuple[Dict, List, List]:
327        """
328        Shift the "coords" features of the input randomly
329        """
331        keys = self._get_keys(("coords", "center"), main_input)
332        if len(keys) == 0:
333            return main_input, ssl_inputs, ssl_targets
334        batch = main_input[keys[0]].shape[0]
335        dim = main_input[keys[0]].shape[2]
336        device = main_input[keys[0]].device
337        coords = torch.cat(
338            [main_input[key].transpose(-1, -2).reshape(batch, -1, dim) for key in keys],
339            dim=1,
340        )
341        minimask = coords[:, :, 0] != self.blank
342        min_x = -0.5 - torch.min(coords[:, :, 0], dim=-1)[0]
343        min_y = -0.5 * self.scale - torch.min(coords[:, :, 1], dim=-1)[0]
344        max_x = 0.5 - torch.max(coords[:, :, 0][minimask], dim=-1)[0]
345        max_y = 0.5 * self.scale - torch.max(coords[:, :, 1][minimask], dim=-1)[0]
346        del coords
347        shift_x = min_x + torch.FloatTensor(batch).uniform_().to(device) * (
348            max_x - min_x
349        )
350        shift_y = min_y + torch.FloatTensor(batch).uniform_().to(device) * (
351            max_y - min_y
352        )
353        shift_x = shift_x.unsqueeze(-1).unsqueeze(-1)
354        shift_y = shift_y.unsqueeze(-1).unsqueeze(-1)
355        dicts = [main_input] + ssl_inputs + ssl_targets
356        for x in dicts:
357            if x is not None:
358                keys = self._get_keys(("coords", "center"), x)
359                for key in keys:
360                    y = deepcopy(x[key])
361                    mask = y == self.blank
362                    y[:, :, 0, :] += shift_x
363                    y[:, :, 1, :] += shift_y
364                    x[key] = y
365                    x[key][mask] = self.blank
366        return main_input, ssl_inputs, ssl_targets
368    def _add_noise(
369        self, main_input: Dict, ssl_inputs: List, ssl_targets: List
370    ) -> Tuple[Dict, List, List]:
371        """
372        Add normal noise to all features of the input
373        """
375        dicts = [main_input] + ssl_inputs + ssl_targets
376        for x in dicts:
377            if x is not None:
378                keys = self._get_keys(
379                    (
380                        "coords",
381                        "speed_joints",
382                        "speed_direction",
383                        "intra_distance",
384                        "acc_joints",
385                        "angle_speeds",
386                        "inter_distance",
387                        "speed_bones",
388                        "acc_bones",
389                        "bones",
390                        "coord_diff",
391                        "center",
392                    ),
393                    x,
394                )
395                for key in keys:
396                    mask = x[key] == self.blank
397                    x[key] = x[key] + torch.FloatTensor(x[key].shape).normal_(
398                        std=self.noise_std
399                    ).to(x[key].device)
400                    x[key][mask] = self.blank
401        return main_input, ssl_inputs, ssl_targets
403    def _zoom(
404        self, main_input: Dict, ssl_inputs: List, ssl_targets: List
405    ) -> Tuple[Dict, List, List]:
406        """
407        Add random zoom to all features of the input
408        """
410        key = list(main_input.keys())[0]
411        batch = main_input[key].shape[0]
412        device = main_input[key].device
413        zoom = torch.FloatTensor(batch).uniform_(*self.zoom_limits).to(device)
414        dicts = [main_input] + ssl_inputs + ssl_targets
415        for x in dicts:
416            if x is not None:
417                keys = self._get_keys(
418                    (
419                        "speed_joints",
420                        "intra_distance",
421                        "acc_joints",
422                        "inter_distance",
423                        "speed_bones",
424                        "acc_bones",
425                        "bones",
426                        "coord_diff",
427                        "center",
428                        "speed_value",
429                    ),
430                    x,
431                )
432                for key in keys:
433                    mask = x[key] == self.blank
434                    shape = np.array(x[key].shape)
435                    shape[1:] = 1
436                    y = deepcopy(x[key])
437                    y *= zoom.reshape(list(shape))
438                    x[key] = y
439                    x[key][mask] = self.blank
440                keys = self._get_keys("coords", x)
441                for key in keys:
442                    mask = x[key] == self.blank
443                    shape = np.array(x[key].shape)
444                    shape[1:] = 1
445                    zoom = zoom.reshape(list(shape))
446                    x[key][mask] = 10
447                    min_x = x[key][:, :, 0, :].min(1)[0].unsqueeze(1)
448                    min_y = x[key][:, :, 1, :].min(1)[0].unsqueeze(1)
449                    center = torch.stack([min_x, min_y], dim=2)
450                    coords = (x[key] - center) * zoom + center
451                    x[key] = coords
452                    x[key][mask] = self.blank
454        return main_input, ssl_inputs, ssl_targets
    def _mask_joints(
        self, main_input: Dict, ssl_inputs: List, ssl_targets: List
    ) -> Tuple[Dict, List, List]:
        """
        Mask joints randomly

        Every joint of every sample is selected for masking independently with probability
        `self.masking_probability`; the same selection is used for the main input and for every
        SSL input / target that is not `None`. Pairwise distance features of a masked joint are
        set to `self.blank`, all other supported features are replaced with the mean over axis 1.
        The tensors are modified in place (except the distance features, which are replaced).
        """

        # the batch size and the number of frames are read from the first and the last axis
        # of the first feature in the main input
        key = list(main_input.keys())[0]
        batch, *_, frames = main_input[key].shape
        # boolean tensor of shape (batch, n_bodyparts), True where a joint is masked
        masked_joints = (
            torch.FloatTensor(batch, self.n_bodyparts).uniform_()
            < self.masking_probability
        )
        dicts = [main_input] + ssl_inputs + ssl_targets
        for x in [y for y in dicts if y is not None]:
            keys = self._get_keys(("intra_distance", "inter_distance"), x)
            for key in keys:
                # (batch, n_bodyparts, n_bodyparts, frames) mask that is True on all the rows
                # that belong to a masked joint
                mask = (
                    masked_joints.repeat(self.n_bodyparts, frames, 1, 1)
                    .transpose(0, 2)
                    .transpose(1, 3)
                )
                # row and column indices of the strictly upper triangle of the joint x joint matrix
                indices = torch.triu_indices(self.n_bodyparts, self.n_bodyparts, 1)

                # unpack the feature into a square matrix (presumably the feature stores one value
                # per joint pair in upper triangle order -- confirm against the feature extractor)
                X = torch.zeros((batch, self.n_bodyparts, self.n_bodyparts, frames)).to(
                    x[key].device
                )
                X[:, indices[0], indices[1], :] = x[key]
                # blank out both the rows and the columns of the masked joints
                X[mask] = self.blank
                X[mask.transpose(1, 2)] = self.blank
                # pack the upper triangle back into a (batch, n_pairs, frames) tensor
                x[key] = X[:, indices[0], indices[1], :].reshape(batch, -1, frames)
            keys = self._get_keys(
                (
                    "speed_joints",
                    "speed_direction",
                    "coords",
                    "acc_joints",
                    "speed_bones",
                    "acc_bones",
                    "bones",
                    "coord_diff",
                ),
                x,
            )
            for key in keys:
                # (batch, n_bodyparts, self.dim, frames) mask that is True for the masked joints
                mask = (
                    masked_joints.repeat(self.dim, frames, 1, 1)
                    .transpose(0, 2)
                    .transpose(1, 3)
                )
                # masked joints are replaced with the mean over axis 1 (computed before the replacement)
                x[key][mask] = (
                    x[key].mean(1).unsqueeze(1).repeat(1, x[key].shape[1], 1, 1)[mask]
                )
            keys = self._get_keys("angle_speeds", x)
            for key in keys:
                # (batch, n_bodyparts, frames) mask that is True for the masked joints
                mask = (
                    masked_joints.repeat(frames, 1, 1).transpose(0, 1).transpose(1, 2)
                )
                # masked joints are replaced with the mean over axis 1 (computed before the replacement)
                x[key][mask] = (
                    x[key].mean(1).unsqueeze(1).repeat(1, x[key].shape[1], 1)[mask]
                )

        return main_input, ssl_inputs, ssl_targets
520    def _switch(
521        self, main_input: Dict, ssl_inputs: List, ssl_targets: List
522    ) -> Tuple[Dict, List, List]:
523        if bool(getrandbits(1)):
524            return main_input, ssl_inputs, ssl_targets
525        individuals = set()
526        ind_dict = defaultdict(lambda: set())
527        for key in main_input:
528            if len(key.split("---")) != 2:
529                continue
530            ind = key.split("---")[1]
531            if "+" in ind:
532                continue
533            individuals.add(ind)
534            ind_dict[ind].add(key.split("---")[0])
535        individuals = list(individuals)
536        if len(individuals) < 2:
537            return main_input, ssl_inputs, ssl_targets
538        for x, y in zip(individuals[::2], individuals[1::2]):
539            for key in ind_dict[x]:
540                if key in ind_dict[y]:
541                    main_input[f"{key}---{x}"], main_input[f"{key}---{y}"] = (
542                        main_input[f"{key}---{y}"],
543                        main_input[f"{key}---{x}"],
544                    )
545                for d in ssl_inputs + ssl_targets:
546                    if d is None:
547                        continue
548                    if f"{key}---{x}" in d and f"{key}---{y}" in d:
549                        d[f"{key}---{x}"], d[f"{key}---{y}"] = (
550                            d[f"{key}---{y}"],
551                            d[f"{key}---{x}"],
552                        )
553        return main_input, ssl_inputs, ssl_targets
555    def _get_keys(self, key_bases: Iterable, x: Dict) -> List:
556        """
557        Get the keys of x that start with one of the strings from key_bases
558        """
560        keys = []
561        if isinstance(key_bases, str):
562            key_bases = [key_bases]
563        for key in x:
564            if any([x == key.split("---")[0] for x in key_bases]):
565                keys.append(key)
566        return keys
568    def _augmentations_dict(self) -> Dict:
569        """
570        Get the mapping from augmentation names to functions
571        """
573        return {
574            "mirror": self._mirror,
575            "shift": self._shift,
576            "add_noise": self._add_noise,
577            "zoom": self._zoom,
578            "rotate": self._rotate,
579            "mask": self._mask_joints,
580            "switch": self._switch,
581        }
583    def _default_augmentations(self) -> List:
584        """
585        Get the list of default augmentation names
586        """
588        return ["mirror", "shift", "add_noise"]
590    def _get_bodyparts(self, shape: Tuple) -> None:
591        """
592        Set the number of bodyparts from the data if it is not known
593        """
595        if self.n_bodyparts is None:
596            N, B, F = shape
597            self.n_bodyparts = B // 2

A transformer that augments the output of the Kinematic feature extractor

The available augmentations are 'rotate', 'mirror', 'shift', 'add_noise' and 'zoom'

KinematicTransformer( model_name: str, augmentations: List = None, use_default_augmentations: bool = False, rotation_limits: List = None, mirror_dim: Set = None, noise_std: float = 0.05, zoom_limits: List = None, masking_probability: float = 0.05, dim: int = 2, graph_features: bool = False, bodyparts_order: List = None, canvas_shape: List = None, move_around_image_center: bool = True, **kwargs)
 28    def __init__(
 29        self,
 30        model_name: str,
 31        augmentations: List = None,
 32        use_default_augmentations: bool = False,
 33        rotation_limits: List = None,
 34        mirror_dim: Set = None,
 35        noise_std: float = 0.05,
 36        zoom_limits: List = None,
 37        masking_probability: float = 0.05,
 38        dim: int = 2,
 39        graph_features: bool = False,
 40        bodyparts_order: List = None,
 41        canvas_shape: List = None,
 42        move_around_image_center: bool = True,
 43        **kwargs,
 44    ) -> None:
 45        """
 46        Parameters
 47        ----------
 48        augmentations : list, optional
 49            list of augmentation names to use ("rotate", "mirror", "shift", "add_noise", "zoom")
 50        use_default_augmentations : bool, default False
 51            if `True` and augmentations are not passed, default augmentations will be applied; otherwise no augmentations
 52        rotation_limits : list, default [-pi/2, pi/2]
 53            list of float rotation angle limits (`[low, high]``, or `[[low_x, high_x], [low_y, high_y], [low_z, high_z]]`
 54            for 3D data)
 55        mirror_dim : set, default {0}
 56            set of integer indices of dimensions that can be mirrored
 57        noise_std : float, default 0.05
 58            standard deviation of noise
 59        zoom_limits : list, default [0.5, 1.5]
 60            list of float zoom limits ([low, high])
 61        masking_probability : float, default 0.1
 62            the probability of masking a joint
 63        dim : int, default 2
 64            the dimensionality of the input data
 65        **kwargs : dict
 66            other parameters for the base transformer class
 67        """
 69        if augmentations is None:
 70            augmentations = []
 71        if canvas_shape is None:
 72            canvas_shape = [1, 1]
 73        self.dim = int(dim)
 75        self.offset = [0 for _ in range(self.dim)]
 76        self.scale = canvas_shape[1] / canvas_shape[0]
 77        self.image_center = move_around_image_center
 78        # if canvas_shape is None:
 79        #     self.offset = [0.5 for _ in range(self.dim)]
 80        # else:
 81        #     self.offset = [0.5 * canvas_shape[i] / canvas_shape[0] for i in range(self.dim)]
 83        self.blank = 0  # the value that nan values are set to (shouldn't be changed in augmentations)
 84        super().__init__(
 85            model_name,
 86            augmentations,
 87            use_default_augmentations,
 88            graph_features=graph_features,
 89            bodyparts_order=bodyparts_order,
 90            **kwargs,
 91        )
 92        if rotation_limits is None:
 93            rotation_limits = [-np.pi / 2, np.pi / 2]
 94        if mirror_dim is None:
 95            mirror_dim = [0]
 96        if zoom_limits is None:
 97            zoom_limits = [0.5, 1.5]
 98        self.rotation_limits = rotation_limits
 99        self.n_bodyparts = None
100        self.mirror_dim = mirror_dim
101        self.noise_std = noise_std
102        self.zoom_limits = zoom_limits
103        self.masking_probability = masking_probability


augmentations : list, optional list of augmentation names to use ("rotate", "mirror", "shift", "add_noise", "zoom") use_default_augmentations : bool, default False if True and augmentations are not passed, default augmentations will be applied; otherwise no augmentations rotation_limits : list, default [-pi/2, pi/2] list of float rotation angle limits ([low, high], or [[low_x, high_x], [low_y, high_y], [low_z, high_z]] for 3D data) mirror_dim : set, default {0} set of integer indices of dimensions that can be mirrored noise_std : float, default 0.05 standard deviation of noise zoom_limits : list, default [0.5, 1.5] list of float zoom limits ([low, high]) masking_probability : float, default 0.05 the probability of masking a joint dim : int, default 2 the dimensionality of the input data **kwargs : dict other parameters for the base transformer class