dlc2action.transformer.base_transformer

Abstract parent class for transformers

  1#
  2# Copyright 2020-present by A. Mathis Group and contributors. All rights reserved.
  3#
  4# This project and all its files are licensed under GNU AGPLv3 or later version. 
  5# A copy is included in dlc2action/LICENSE.AGPL.
  6#
  7"""
  8Abstract parent class for transformers
  9"""
 10
 11from typing import Dict, List, Callable, Union, Tuple
 12import torch
 13from abc import ABC, abstractmethod
 14from dlc2action.utils import TensorList
 15from copy import deepcopy
 16from matplotlib import pyplot as plt
 17
 18
 19class Transformer(ABC):
 20    """
 21    A base class for all transformers
 22
 23    A transformer should apply augmentations and generate model input and training target tensors.
 24
 25    All augmentation functions need to take `(main_input: dict, ssl_inputs: list, ssl_targets: list)`
 26    as input and return an output of the same format. Here `main_input` is a feature dictionary of the sample
 27    data, `ssl_inputs` is a list of SSL input feature dictionaries and `ssl_targets` is a list of SSL target
 28    feature dictionaries. The same augmentations are applied to all inputs and then `None` values are replaced
 29    according to the rules set by `keep_target_none` and `generate_ssl_input` parameters and the feature
 30    dictionaries are compiled into tensors.
 31    """
 32
 33    def __init__(
 34        self,
 35        model_name: str,
 36        augmentations: List = None,
 37        use_default_augmentations: bool = False,
 38        generate_ssl_input: List = None,
 39        keep_target_none: List = None,
 40        ssl_augmentations: List = None,
 41        graph_features: bool = False,
 42        bodyparts_order: List = None,
 43    ) -> None:
 44        """
 45        Parameters
 46        ----------
 47        augmentations : list, optional
 48            a list of string names of augmentations to use (if not provided, either no augmentations are applied or
 49            (if use_default_augmentations is True) a default list is used
 50        use_default_augmentations : bool, default False
 51            if True and augmentations are not passed, default augmentations will be applied; otherwise no augmentations
 52        generate_ssl_input : list, optional
 53            a list of bool values of the length of the number of SSL modules being used; if the corresponding bool value
 54            is `True`, the ssl input will be generated as a new augmentation of main input (if not provided defaults to
 55            `False` for each module)
 56        keep_target_none : list, optional
 57            a list of bool values of the length of the number of SSL modules being used; if the corresponding bool value
 58            is `False` and the SSL target is `None`, the target is set to augmented main input (if not provided defaults
 59            to `True` for each module)
 60        ssl_augmentations : list, optional
 61            a list of augmentation names to be applied with generating SSL input (when `generate_ssl_input` is True)
 62            (if not provided, defaults to the main augmentations list)
 63        graph_features : bool, default False
 64            if `True`, all features in each frame can be meaningfully reshaped to `(#bodyparts, #features)`
 65        bodyparts_order : list, optional
 66            a list of bodypart names, optional
 67        """
 68
 69        if augmentations is None:
 70            augmentations = []
 71        if generate_ssl_input is None:
 72            generate_ssl_input = [None]
 73        if keep_target_none is None:
 74            keep_target_none = [None]
 75        self.model_name = model_name
 76        self.augmentations = augmentations
 77        self.generate_ssl_input = generate_ssl_input
 78        self.keep_target_none = keep_target_none
 79        if len(self.augmentations) == 0 and use_default_augmentations:
 80            self.augmentations = self._default_augmentations()
 81        if ssl_augmentations is None:
 82            ssl_augmentations = self.augmentations
 83        self.ssl_augmentations = ssl_augmentations
 84        self.graph_features = graph_features
 85        self.num_graph_nodes = (
 86            len(bodyparts_order) if bodyparts_order is not None else None
 87        )
 88        self._check_augmentations(self.augmentations)
 89        self._check_augmentations(self.ssl_augmentations)
 90
 91    def transform(
 92        self,
 93        main_input: Dict,
 94        ssl_inputs: List = None,
 95        ssl_targets: List = None,
 96        augment: bool = False,
 97        subsample: List = None,
 98    ) -> Tuple:
 99        """
100        Apply augmentations and generate tensors from feature dictionaries.
101
102        The same augmentations are applied to all the inputs (if they have the features known to the transformer).
103
104        If `generate_ssl_input` is set to True for some of the SSL pairs, those SSL inputs will be generated as
105        another augmentation of main_input.
106        Unless `keep_target_none` is set to True, `None` SSL targets will be replaced with augmented `main_input`.
107        All features are stacked together to form a tensor of shape `(#features, #frames)` that can be passed to
108        a model.
109
110        Parameters
111        ----------
112        main_input : dict
113            the feature dictionary of the main input
114        ssl_inputs : list, optional
115            a list of feature dictionaries of SSL inputs (some or all can be None)
116        ssl_targets : list, optional
117            a list of feature dictionaries of SSL targets (some or all can be None)
118        augment : bool, default True
119            if True, augmentations are applied
120        subsample : list, optional
121            a list of indices to subsample the input tensors (if not provided, no subsampling is applied)
122
123        Returns
124        -------
125        main_input : torch.Tensor
126            the augmented tensor of the main input
127        ssl_inputs : list, optional
128            a list of augmented tensors of SSL inputs (some or all can be None)
129        ssl_targets : list, optional
130            a list of augmented tensors of SSL targets (some or all can be None)
131
132        """
133        if subsample is not None:
134            original_len = list(main_input.values())[0].shape[-1]
135            for key in main_input:
136                main_input[key] = main_input[key][..., subsample]
137            # subsample_ssl = sorted(random.sample(range(original_len), len(subsample)))
138            for x in ssl_inputs + ssl_targets:
139                if x is not None:
140                    for key in x:
141                        if len(x[key].shape) == 3 and x[key].shape[-1] == original_len:
142                            x[key] = x[key][..., subsample]
143        main_input, ssl_inputs, ssl_targets = self._apply_augmentations(
144            main_input, ssl_inputs, ssl_targets, augment
145        )
146        meta = [None for _ in ssl_inputs]
147        for i, x in enumerate(ssl_inputs):
148            if type(x) is tuple:
149                x, meta_x = x
150                meta[i] = meta_x
151                ssl_inputs[i] = x
152        for (i, ssl_x), generate in zip(enumerate(ssl_inputs), self.generate_ssl_input):
153            if ssl_x is None and generate:
154                ssl_inputs[i] = self._apply_augmentations(
155                    deepcopy(main_input), None, None, augment=True, ssl=True
156                )[0]
157        output = []
158        num_ssl = len(ssl_inputs)
159        dicts = [main_input] + ssl_inputs + ssl_targets
160
161        for x in dicts:
162            if x is None:
163                output.append(None)
164            else:
165                output.append(self._make_tensor(x, self.model_name))
166                # output.append(
167                #     torch.cat([x[key] for key in sorted(list(x.keys()))], dim=1)
168                # )
169        main_input, ssl_inputs, ssl_targets = (
170            output[0],
171            output[1 : num_ssl + 1],
172            output[num_ssl + 1 :],
173        )
174        for (i, ssl_x), keep in zip(enumerate(ssl_targets), self.keep_target_none):
175            if not keep and ssl_x is None:
176                ssl_targets[i] = main_input
177        for i, meta_x in enumerate(meta):
178            if meta_x is not None:
179                ssl_inputs[i] = (ssl_inputs[i], meta_x)
180        return main_input, ssl_inputs, ssl_targets
181
182    @abstractmethod
183    def _augmentations_dict(self) -> Dict:
184        """
185        Return a dictionary of possible augmentations
186
187        The keys are augmentation names and the values are the corresponding functions
188
189        Returns
190        -------
191        augmentations_dict : dict
192            a dictionary of augmentation functions (each function needs to take
193            `(main_input: dict, ssl_inputs: list, ssl_targets: list)` as input and return an output
194            of the same format)
195        """
196
197    @abstractmethod
198    def _default_augmentations(self) -> List:
199        """
200        Return a list of default augmentation names
201
202        In case an augmentation list is not provided to the class constructor and use_default_augmentations is `True`,
203        this function is called to set the augmentations parameter. The elements of the list have to be keys of the
204        dictionary returned by `self._augmentations_dict()`
205
206        Returns
207        -------
208        default_augmentations : list
209            a list of string names of the default augmentation functions
210        """
211
212    def _check_augmentations(self, augmentations: List) -> None:
213        """
214        Check the validity of an augmentations list
215        """
216
217        for aug_name in augmentations:
218            if aug_name not in self._augmentations_dict():
219                raise ValueError(
220                    f"The {aug_name} augmentation is not possible in this augmentor! Please choose from {list(self._augmentations_dict().keys())}"
221                )
222
223    def _get_augmentation(self, aug_name: str) -> Callable:
224        """
225        Return the augmentation specified by `aug_name`
226        """
227
228        return self._augmentations_dict()[aug_name]
229
230    def _visualize(self, main_input, title):
231        coord_keys = [x for x in main_input.keys() if x.split("---")[0] == "coords"]
232        if len(coord_keys) > 0:
233            coords = [main_input[key][0, :, :, 0].detach().cpu() for key in coord_keys]
234            centers = [None]
235        else:
236            coord_keys = [
237                x for x in main_input.keys() if x.split("---")[0] == "coord_diff"
238            ]
239            coords = []
240            centers = []
241            for coord_diff_key in coord_keys:
242                if len(coord_diff_key.split("---")) == 2:
243                    ind = coord_diff_key.split("---")[1]
244                    center_key = f"center---{ind}"
245                else:
246                    center_key = "center"
247                    ind = ""
248                if center_key in main_input.keys():
249                    coords.append(
250                        main_input[center_key][0, :, :, 0].detach().cpu()
251                        + main_input[coord_diff_key][0, :, :, 0].detach().cpu()
252                    )
253                    title += f", {ind}: {main_input[center_key][0, 0, :, 0].data}"
254                    center = main_input[center_key][0, :, :, 0].detach().cpu()
255                else:
256                    coords.append(main_input[coord_diff_key][0, :, :, 0].detach().cpu())
257                    center = None
258                centers.append(center)
259        colors = ["blue", "orange", "green", "purple", "pink"]
260        if coords[0].shape[1] == 2:
261            plt.figure(figsize=(15, 15))
262        else:
263            fig = plt.figure(figsize=(15, 15))
264            ax = fig.add_subplot(projection="3d")
265        for i, coord in enumerate(coords):
266            if coord.shape[1] == 2:
267                plt.scatter(coord[:, 0], coord[:, 1], color=colors[i])
268                plt.xlim((-0.5, 0.5))
269                plt.ylim((-0.5, 0.5))
270            else:
271                ax.scatter(
272                    coord[:, 0].detach().cpu(),
273                    coord[:, 1].detach().cpu(),
274                    coord[:, 2].detach().cpu(),
275                    color=colors[i],
276                )
277                if centers[i] is not None:
278                    ax.scatter(
279                        centers[i][:, 0].detach().cpu(),
280                        centers[i][:, 1].detach().cpu(),
281                        0,
282                        color="red",
283                        s=30,
284                    )
285                    center = centers[i][0].detach().cpu()
286                    ax.text(
287                        center[0], center[1], 0, f"({center[0]:.2f}, {center[1]:.2f})"
288                    )
289                for i in [1, 8]:
290                    ax.scatter(
291                        coord[i : i + 1, 0].detach().cpu(),
292                        coord[i : i + 1, 1].detach().cpu(),
293                        coord[i : i + 1, 2].detach().cpu(),
294                        color="purple",
295                    )
296                    ax.text(
297                        coord[i, 0],
298                        coord[i, 1],
299                        coord[i, 2],
300                        f"({coord[i, 0]:.2f}, {coord[i, 1]:.2f}, {coord[i, 2]:.2f})",
301                    )
302
303        plt.xlim((-3, 3))
304        plt.ylim((-3, 3))
305        ax.set_zlim(-2, 4)
306        ax.set_xlabel("x")
307        ax.set_ylabel("y")
308        ax.set_zlabel("z")
309        plt.title(title)
310        plt.show()
311
312    def _apply_augmentations(
313        self,
314        main_input: Dict,
315        ssl_inputs: List = None,
316        ssl_targets: List = None,
317        augment: bool = False,
318        ssl: bool = False,
319    ) -> Tuple:
320        """
321        Apply the augmentations
322
323        The same augmentations are applied to all inputs
324        """
325
326        visualize = False
327        if ssl:
328            augmentations = self.ssl_augmentations
329        else:
330            augmentations = self.augmentations
331        if visualize:
332            self._visualize(main_input, "before")
333        if ssl_inputs is None:
334            ssl_inputs = [None]
335        if ssl_targets is None:
336            ssl_targets = [None]
337        if augment:
338            for aug_name in augmentations:
339                augment_func = self._get_augmentation(aug_name)
340                main_input, ssl_inputs, ssl_targets = augment_func(
341                    main_input, ssl_inputs, ssl_targets
342                )
343                if visualize:
344                    self._visualize(main_input, aug_name)
345        return main_input, ssl_inputs, ssl_targets
346
347    def _make_tensor(self, x: Dict, model_name: str) -> Union[torch.Tensor, TensorList]:
348        """
349        Turn a feature dictionary into a tensor or a `dlc2action.utils.TensorList` object
350        """
351
352        if model_name == "ms_tcn_p":
353            keys = sorted(list(x.keys()))
354            groups = [key.split("---")[-1] for key in keys]
355            unique_groups = sorted(set(groups))
356            tensor = TensorList()
357            for group in unique_groups:
358                if not self.graph_features:
359                    tensor.append(
360                        torch.cat(
361                            [x[key] for key, g in zip(keys, groups) if g == group],
362                            dim=2,
363                        )
364                    )
365                else:
366                    tensor.append(
367                        torch.cat(
368                            [
369                                x[key].reshape(
370                                    (
371                                        x[key].shape[0],
372                                        self.num_graph_nodes,
373                                        -1,
374                                        x[key].shape[-1],
375                                    )
376                                )
377                                for key, g in zip(keys, groups)
378                                if g == group
379                            ],
380                            dim=2,
381                        )
382                    )
383                    tensor[-1] = tensor[-1].reshape(
384                        (tensor[-1].shape[0], -1, tensor[-1].shape[-1])
385                    )
386            if "loaded" in x:
387                tensor.append(x["loaded"])
388        elif model_name == "c2f_tcn_p":
389            keys = sorted(
390                [
391                    key
392                    for key in x.keys()
393                    if len(key.split("---")) != 1
394                    and len(key.split("---")[-1].split("+")) != 2
395                ]
396            )
397            inds = [key.split("---")[-1] for key in keys]
398            unique_inds = sorted(set(inds))
399            tensor = TensorList()
400            for ind in unique_inds:
401                if not self.graph_features:
402                    tensor.append(
403                        torch.cat(
404                            [x[key] for key, g in zip(keys, inds) if g == ind],
405                            dim=1,
406                        )
407                    )
408                else:
409                    tensor.append(
410                        torch.cat(
411                            [
412                                x[key].reshape(
413                                    (
414                                        x[key].shape[0],
415                                        self.num_graph_nodes,
416                                        -1,
417                                        x[key].shape[-1],
418                                    )
419                                )
420                                for key, g in zip(keys, inds)
421                                if g == ind
422                            ],
423                            dim=1,
424                        )
425                    )
426                    tensor[-1] = tensor[-1].reshape(
427                        (tensor[-1].shape[0], -1, tensor[-1].shape[-1])
428                    )
429        elif model_name == "c3d_a":
430            tensor = torch.cat([x[key] for key in sorted(list(x.keys()))], dim=1)
431        else:
432            if not self.graph_features:
433                tensor = torch.cat([x[key] for key in sorted(list(x.keys()))], dim=1)
434            else:
435                tensor = torch.cat(
436                    [
437                        x[key].reshape(
438                            (
439                                x[key].shape[0],
440                                self.num_graph_nodes,
441                                -1,
442                                x[key].shape[-1],
443                            )
444                        )
445                        for key in sorted(list(x.keys()))
446                    ],
447                    dim=2,
448                )
449                tensor = tensor.reshape((tensor.shape[0], -1, tensor.shape[-1]))
450        return tensor
451
452
class EmptyTransformer(Transformer):
    """
    A transformer without any augmentations

    It only turns feature dictionaries into tensors: the dictionary of possible augmentations
    and the list of default augmentations are both empty.
    """

    def _augmentations_dict(self) -> Dict:
        """
        Map augmentation names to augmentation functions

        This transformer does not know any augmentations.

        Returns
        -------
        augmentations_dict : dict
            an empty dictionary
        """

        return dict()

    def _default_augmentations(self) -> List:
        """
        List the names of the augmentations that are used by default

        This transformer does not have any default augmentations.

        Returns
        -------
        default_augmentations : list
            an empty list
        """

        return list()
class Transformer(abc.ABC):
 20class Transformer(ABC):
 21    """
 22    A base class for all transformers
 23
 24    A transformer should apply augmentations and generate model input and training target tensors.
 25
 26    All augmentation functions need to take `(main_input: dict, ssl_inputs: list, ssl_targets: list)`
 27    as input and return an output of the same format. Here `main_input` is a feature dictionary of the sample
 28    data, `ssl_inputs` is a list of SSL input feature dictionaries and `ssl_targets` is a list of SSL target
 29    feature dictionaries. The same augmentations are applied to all inputs and then `None` values are replaced
 30    according to the rules set by `keep_target_none` and `generate_ssl_input` parameters and the feature
 31    dictionaries are compiled into tensors.
 32    """
 33
 34    def __init__(
 35        self,
 36        model_name: str,
 37        augmentations: List = None,
 38        use_default_augmentations: bool = False,
 39        generate_ssl_input: List = None,
 40        keep_target_none: List = None,
 41        ssl_augmentations: List = None,
 42        graph_features: bool = False,
 43        bodyparts_order: List = None,
 44    ) -> None:
 45        """
 46        Parameters
 47        ----------
 48        augmentations : list, optional
 49            a list of string names of augmentations to use (if not provided, either no augmentations are applied or
 50            (if use_default_augmentations is True) a default list is used
 51        use_default_augmentations : bool, default False
 52            if True and augmentations are not passed, default augmentations will be applied; otherwise no augmentations
 53        generate_ssl_input : list, optional
 54            a list of bool values of the length of the number of SSL modules being used; if the corresponding bool value
 55            is `True`, the ssl input will be generated as a new augmentation of main input (if not provided defaults to
 56            `False` for each module)
 57        keep_target_none : list, optional
 58            a list of bool values of the length of the number of SSL modules being used; if the corresponding bool value
 59            is `False` and the SSL target is `None`, the target is set to augmented main input (if not provided defaults
 60            to `True` for each module)
 61        ssl_augmentations : list, optional
 62            a list of augmentation names to be applied with generating SSL input (when `generate_ssl_input` is True)
 63            (if not provided, defaults to the main augmentations list)
 64        graph_features : bool, default False
 65            if `True`, all features in each frame can be meaningfully reshaped to `(#bodyparts, #features)`
 66        bodyparts_order : list, optional
 67            a list of bodypart names, optional
 68        """
 69
 70        if augmentations is None:
 71            augmentations = []
 72        if generate_ssl_input is None:
 73            generate_ssl_input = [None]
 74        if keep_target_none is None:
 75            keep_target_none = [None]
 76        self.model_name = model_name
 77        self.augmentations = augmentations
 78        self.generate_ssl_input = generate_ssl_input
 79        self.keep_target_none = keep_target_none
 80        if len(self.augmentations) == 0 and use_default_augmentations:
 81            self.augmentations = self._default_augmentations()
 82        if ssl_augmentations is None:
 83            ssl_augmentations = self.augmentations
 84        self.ssl_augmentations = ssl_augmentations
 85        self.graph_features = graph_features
 86        self.num_graph_nodes = (
 87            len(bodyparts_order) if bodyparts_order is not None else None
 88        )
 89        self._check_augmentations(self.augmentations)
 90        self._check_augmentations(self.ssl_augmentations)
 91
 92    def transform(
 93        self,
 94        main_input: Dict,
 95        ssl_inputs: List = None,
 96        ssl_targets: List = None,
 97        augment: bool = False,
 98        subsample: List = None,
 99    ) -> Tuple:
100        """
101        Apply augmentations and generate tensors from feature dictionaries.
102
103        The same augmentations are applied to all the inputs (if they have the features known to the transformer).
104
105        If `generate_ssl_input` is set to True for some of the SSL pairs, those SSL inputs will be generated as
106        another augmentation of main_input.
107        Unless `keep_target_none` is set to True, `None` SSL targets will be replaced with augmented `main_input`.
108        All features are stacked together to form a tensor of shape `(#features, #frames)` that can be passed to
109        a model.
110
111        Parameters
112        ----------
113        main_input : dict
114            the feature dictionary of the main input
115        ssl_inputs : list, optional
116            a list of feature dictionaries of SSL inputs (some or all can be None)
117        ssl_targets : list, optional
118            a list of feature dictionaries of SSL targets (some or all can be None)
119        augment : bool, default True
120            if True, augmentations are applied
121        subsample : list, optional
122            a list of indices to subsample the input tensors (if not provided, no subsampling is applied)
123
124        Returns
125        -------
126        main_input : torch.Tensor
127            the augmented tensor of the main input
128        ssl_inputs : list, optional
129            a list of augmented tensors of SSL inputs (some or all can be None)
130        ssl_targets : list, optional
131            a list of augmented tensors of SSL targets (some or all can be None)
132
133        """
134        if subsample is not None:
135            original_len = list(main_input.values())[0].shape[-1]
136            for key in main_input:
137                main_input[key] = main_input[key][..., subsample]
138            # subsample_ssl = sorted(random.sample(range(original_len), len(subsample)))
139            for x in ssl_inputs + ssl_targets:
140                if x is not None:
141                    for key in x:
142                        if len(x[key].shape) == 3 and x[key].shape[-1] == original_len:
143                            x[key] = x[key][..., subsample]
144        main_input, ssl_inputs, ssl_targets = self._apply_augmentations(
145            main_input, ssl_inputs, ssl_targets, augment
146        )
147        meta = [None for _ in ssl_inputs]
148        for i, x in enumerate(ssl_inputs):
149            if type(x) is tuple:
150                x, meta_x = x
151                meta[i] = meta_x
152                ssl_inputs[i] = x
153        for (i, ssl_x), generate in zip(enumerate(ssl_inputs), self.generate_ssl_input):
154            if ssl_x is None and generate:
155                ssl_inputs[i] = self._apply_augmentations(
156                    deepcopy(main_input), None, None, augment=True, ssl=True
157                )[0]
158        output = []
159        num_ssl = len(ssl_inputs)
160        dicts = [main_input] + ssl_inputs + ssl_targets
161
162        for x in dicts:
163            if x is None:
164                output.append(None)
165            else:
166                output.append(self._make_tensor(x, self.model_name))
167                # output.append(
168                #     torch.cat([x[key] for key in sorted(list(x.keys()))], dim=1)
169                # )
170        main_input, ssl_inputs, ssl_targets = (
171            output[0],
172            output[1 : num_ssl + 1],
173            output[num_ssl + 1 :],
174        )
175        for (i, ssl_x), keep in zip(enumerate(ssl_targets), self.keep_target_none):
176            if not keep and ssl_x is None:
177                ssl_targets[i] = main_input
178        for i, meta_x in enumerate(meta):
179            if meta_x is not None:
180                ssl_inputs[i] = (ssl_inputs[i], meta_x)
181        return main_input, ssl_inputs, ssl_targets
182
183    @abstractmethod
184    def _augmentations_dict(self) -> Dict:
185        """
186        Return a dictionary of possible augmentations
187
188        The keys are augmentation names and the values are the corresponding functions
189
190        Returns
191        -------
192        augmentations_dict : dict
193            a dictionary of augmentation functions (each function needs to take
194            `(main_input: dict, ssl_inputs: list, ssl_targets: list)` as input and return an output
195            of the same format)
196        """
197
198    @abstractmethod
199    def _default_augmentations(self) -> List:
200        """
201        Return a list of default augmentation names
202
203        In case an augmentation list is not provided to the class constructor and use_default_augmentations is `True`,
204        this function is called to set the augmentations parameter. The elements of the list have to be keys of the
205        dictionary returned by `self._augmentations_dict()`
206
207        Returns
208        -------
209        default_augmentations : list
210            a list of string names of the default augmentation functions
211        """
212
213    def _check_augmentations(self, augmentations: List) -> None:
214        """
215        Check the validity of an augmentations list
216        """
217
218        for aug_name in augmentations:
219            if aug_name not in self._augmentations_dict():
220                raise ValueError(
221                    f"The {aug_name} augmentation is not possible in this augmentor! Please choose from {list(self._augmentations_dict().keys())}"
222                )
223
224    def _get_augmentation(self, aug_name: str) -> Callable:
225        """
226        Return the augmentation specified by `aug_name`
227        """
228
229        return self._augmentations_dict()[aug_name]
230
231    def _visualize(self, main_input, title):
232        coord_keys = [x for x in main_input.keys() if x.split("---")[0] == "coords"]
233        if len(coord_keys) > 0:
234            coords = [main_input[key][0, :, :, 0].detach().cpu() for key in coord_keys]
235            centers = [None]
236        else:
237            coord_keys = [
238                x for x in main_input.keys() if x.split("---")[0] == "coord_diff"
239            ]
240            coords = []
241            centers = []
242            for coord_diff_key in coord_keys:
243                if len(coord_diff_key.split("---")) == 2:
244                    ind = coord_diff_key.split("---")[1]
245                    center_key = f"center---{ind}"
246                else:
247                    center_key = "center"
248                    ind = ""
249                if center_key in main_input.keys():
250                    coords.append(
251                        main_input[center_key][0, :, :, 0].detach().cpu()
252                        + main_input[coord_diff_key][0, :, :, 0].detach().cpu()
253                    )
254                    title += f", {ind}: {main_input[center_key][0, 0, :, 0].data}"
255                    center = main_input[center_key][0, :, :, 0].detach().cpu()
256                else:
257                    coords.append(main_input[coord_diff_key][0, :, :, 0].detach().cpu())
258                    center = None
259                centers.append(center)
260        colors = ["blue", "orange", "green", "purple", "pink"]
261        if coords[0].shape[1] == 2:
262            plt.figure(figsize=(15, 15))
263        else:
264            fig = plt.figure(figsize=(15, 15))
265            ax = fig.add_subplot(projection="3d")
266        for i, coord in enumerate(coords):
267            if coord.shape[1] == 2:
268                plt.scatter(coord[:, 0], coord[:, 1], color=colors[i])
269                plt.xlim((-0.5, 0.5))
270                plt.ylim((-0.5, 0.5))
271            else:
272                ax.scatter(
273                    coord[:, 0].detach().cpu(),
274                    coord[:, 1].detach().cpu(),
275                    coord[:, 2].detach().cpu(),
276                    color=colors[i],
277                )
278                if centers[i] is not None:
279                    ax.scatter(
280                        centers[i][:, 0].detach().cpu(),
281                        centers[i][:, 1].detach().cpu(),
282                        0,
283                        color="red",
284                        s=30,
285                    )
286                    center = centers[i][0].detach().cpu()
287                    ax.text(
288                        center[0], center[1], 0, f"({center[0]:.2f}, {center[1]:.2f})"
289                    )
290                for i in [1, 8]:
291                    ax.scatter(
292                        coord[i : i + 1, 0].detach().cpu(),
293                        coord[i : i + 1, 1].detach().cpu(),
294                        coord[i : i + 1, 2].detach().cpu(),
295                        color="purple",
296                    )
297                    ax.text(
298                        coord[i, 0],
299                        coord[i, 1],
300                        coord[i, 2],
301                        f"({coord[i, 0]:.2f}, {coord[i, 1]:.2f}, {coord[i, 2]:.2f})",
302                    )
303
304        plt.xlim((-3, 3))
305        plt.ylim((-3, 3))
306        ax.set_zlim(-2, 4)
307        ax.set_xlabel("x")
308        ax.set_ylabel("y")
309        ax.set_zlabel("z")
310        plt.title(title)
311        plt.show()
312
313    def _apply_augmentations(
314        self,
315        main_input: Dict,
316        ssl_inputs: List = None,
317        ssl_targets: List = None,
318        augment: bool = False,
319        ssl: bool = False,
320    ) -> Tuple:
321        """
322        Apply the augmentations
323
324        The same augmentations are applied to all inputs
325        """
326
327        visualize = False
328        if ssl:
329            augmentations = self.ssl_augmentations
330        else:
331            augmentations = self.augmentations
332        if visualize:
333            self._visualize(main_input, "before")
334        if ssl_inputs is None:
335            ssl_inputs = [None]
336        if ssl_targets is None:
337            ssl_targets = [None]
338        if augment:
339            for aug_name in augmentations:
340                augment_func = self._get_augmentation(aug_name)
341                main_input, ssl_inputs, ssl_targets = augment_func(
342                    main_input, ssl_inputs, ssl_targets
343                )
344                if visualize:
345                    self._visualize(main_input, aug_name)
346        return main_input, ssl_inputs, ssl_targets
347
348    def _make_tensor(self, x: Dict, model_name: str) -> Union[torch.Tensor, TensorList]:
349        """
350        Turn a feature dictionary into a tensor or a `dlc2action.utils.TensorList` object
351        """
352
353        if model_name == "ms_tcn_p":
354            keys = sorted(list(x.keys()))
355            groups = [key.split("---")[-1] for key in keys]
356            unique_groups = sorted(set(groups))
357            tensor = TensorList()
358            for group in unique_groups:
359                if not self.graph_features:
360                    tensor.append(
361                        torch.cat(
362                            [x[key] for key, g in zip(keys, groups) if g == group],
363                            dim=2,
364                        )
365                    )
366                else:
367                    tensor.append(
368                        torch.cat(
369                            [
370                                x[key].reshape(
371                                    (
372                                        x[key].shape[0],
373                                        self.num_graph_nodes,
374                                        -1,
375                                        x[key].shape[-1],
376                                    )
377                                )
378                                for key, g in zip(keys, groups)
379                                if g == group
380                            ],
381                            dim=2,
382                        )
383                    )
384                    tensor[-1] = tensor[-1].reshape(
385                        (tensor[-1].shape[0], -1, tensor[-1].shape[-1])
386                    )
387            if "loaded" in x:
388                tensor.append(x["loaded"])
389        elif model_name == "c2f_tcn_p":
390            keys = sorted(
391                [
392                    key
393                    for key in x.keys()
394                    if len(key.split("---")) != 1
395                    and len(key.split("---")[-1].split("+")) != 2
396                ]
397            )
398            inds = [key.split("---")[-1] for key in keys]
399            unique_inds = sorted(set(inds))
400            tensor = TensorList()
401            for ind in unique_inds:
402                if not self.graph_features:
403                    tensor.append(
404                        torch.cat(
405                            [x[key] for key, g in zip(keys, inds) if g == ind],
406                            dim=1,
407                        )
408                    )
409                else:
410                    tensor.append(
411                        torch.cat(
412                            [
413                                x[key].reshape(
414                                    (
415                                        x[key].shape[0],
416                                        self.num_graph_nodes,
417                                        -1,
418                                        x[key].shape[-1],
419                                    )
420                                )
421                                for key, g in zip(keys, inds)
422                                if g == ind
423                            ],
424                            dim=1,
425                        )
426                    )
427                    tensor[-1] = tensor[-1].reshape(
428                        (tensor[-1].shape[0], -1, tensor[-1].shape[-1])
429                    )
430        elif model_name == "c3d_a":
431            tensor = torch.cat([x[key] for key in sorted(list(x.keys()))], dim=1)
432        else:
433            if not self.graph_features:
434                tensor = torch.cat([x[key] for key in sorted(list(x.keys()))], dim=1)
435            else:
436                tensor = torch.cat(
437                    [
438                        x[key].reshape(
439                            (
440                                x[key].shape[0],
441                                self.num_graph_nodes,
442                                -1,
443                                x[key].shape[-1],
444                            )
445                        )
446                        for key in sorted(list(x.keys()))
447                    ],
448                    dim=2,
449                )
450                tensor = tensor.reshape((tensor.shape[0], -1, tensor.shape[-1]))
451        return tensor

A base class for all transformers

A transformer should apply augmentations and generate model input and training target tensors.

All augmentation functions need to take (main_input: dict, ssl_inputs: list, ssl_targets: list) as input and return an output of the same format. Here main_input is a feature dictionary of the sample data, ssl_inputs is a list of SSL input feature dictionaries and ssl_targets is a list of SSL target feature dictionaries. The same augmentations are applied to all inputs and then None values are replaced according to the rules set by keep_target_none and generate_ssl_input parameters and the feature dictionaries are compiled into tensors.

Transformer( model_name: str, augmentations: List = None, use_default_augmentations: bool = False, generate_ssl_input: List = None, keep_target_none: List = None, ssl_augmentations: List = None, graph_features: bool = False, bodyparts_order: List = None)
34    def __init__(
35        self,
36        model_name: str,
37        augmentations: List = None,
38        use_default_augmentations: bool = False,
39        generate_ssl_input: List = None,
40        keep_target_none: List = None,
41        ssl_augmentations: List = None,
42        graph_features: bool = False,
43        bodyparts_order: List = None,
44    ) -> None:
45        """
46        Parameters
47        ----------
48        augmentations : list, optional
49            a list of string names of augmentations to use (if not provided, either no augmentations are applied or
50            (if use_default_augmentations is True) a default list is used
51        use_default_augmentations : bool, default False
52            if True and augmentations are not passed, default augmentations will be applied; otherwise no augmentations
53        generate_ssl_input : list, optional
54            a list of bool values of the length of the number of SSL modules being used; if the corresponding bool value
55            is `True`, the ssl input will be generated as a new augmentation of main input (if not provided defaults to
56            `False` for each module)
57        keep_target_none : list, optional
58            a list of bool values of the length of the number of SSL modules being used; if the corresponding bool value
59            is `False` and the SSL target is `None`, the target is set to augmented main input (if not provided defaults
60            to `True` for each module)
61        ssl_augmentations : list, optional
62            a list of augmentation names to be applied with generating SSL input (when `generate_ssl_input` is True)
63            (if not provided, defaults to the main augmentations list)
64        graph_features : bool, default False
65            if `True`, all features in each frame can be meaningfully reshaped to `(#bodyparts, #features)`
66        bodyparts_order : list, optional
67            a list of bodypart names, optional
68        """
69
70        if augmentations is None:
71            augmentations = []
72        if generate_ssl_input is None:
73            generate_ssl_input = [None]
74        if keep_target_none is None:
75            keep_target_none = [None]
76        self.model_name = model_name
77        self.augmentations = augmentations
78        self.generate_ssl_input = generate_ssl_input
79        self.keep_target_none = keep_target_none
80        if len(self.augmentations) == 0 and use_default_augmentations:
81            self.augmentations = self._default_augmentations()
82        if ssl_augmentations is None:
83            ssl_augmentations = self.augmentations
84        self.ssl_augmentations = ssl_augmentations
85        self.graph_features = graph_features
86        self.num_graph_nodes = (
87            len(bodyparts_order) if bodyparts_order is not None else None
88        )
89        self._check_augmentations(self.augmentations)
90        self._check_augmentations(self.ssl_augmentations)

Parameters

augmentations : list, optional a list of string names of augmentations to use (if not provided, either no augmentations are applied or, if use_default_augmentations is True, a default list is used) use_default_augmentations : bool, default False if True and augmentations are not passed, default augmentations will be applied; otherwise no augmentations are applied generate_ssl_input : list, optional a list of bool values of the length of the number of SSL modules being used; if the corresponding bool value is True, the ssl input will be generated as a new augmentation of main input (if not provided defaults to False for each module) keep_target_none : list, optional a list of bool values of the length of the number of SSL modules being used; if the corresponding bool value is False and the SSL target is None, the target is set to augmented main input (if not provided defaults to True for each module) ssl_augmentations : list, optional a list of augmentation names to be applied when generating SSL input (when generate_ssl_input is True) (if not provided, defaults to the main augmentations list) graph_features : bool, default False if True, all features in each frame can be meaningfully reshaped to (#bodyparts, #features) bodyparts_order : list, optional a list of bodypart names

model_name
augmentations
generate_ssl_input
keep_target_none
ssl_augmentations
graph_features
num_graph_nodes
def transform( self, main_input: Dict, ssl_inputs: List = None, ssl_targets: List = None, augment: bool = False, subsample: List = None) -> Tuple:
 92    def transform(
 93        self,
 94        main_input: Dict,
 95        ssl_inputs: List = None,
 96        ssl_targets: List = None,
 97        augment: bool = False,
 98        subsample: List = None,
 99    ) -> Tuple:
100        """
101        Apply augmentations and generate tensors from feature dictionaries.
102
103        The same augmentations are applied to all the inputs (if they have the features known to the transformer).
104
105        If `generate_ssl_input` is set to True for some of the SSL pairs, those SSL inputs will be generated as
106        another augmentation of main_input.
107        Unless `keep_target_none` is set to True, `None` SSL targets will be replaced with augmented `main_input`.
108        All features are stacked together to form a tensor of shape `(#features, #frames)` that can be passed to
109        a model.
110
111        Parameters
112        ----------
113        main_input : dict
114            the feature dictionary of the main input
115        ssl_inputs : list, optional
116            a list of feature dictionaries of SSL inputs (some or all can be None)
117        ssl_targets : list, optional
118            a list of feature dictionaries of SSL targets (some or all can be None)
119        augment : bool, default True
120            if True, augmentations are applied
121        subsample : list, optional
122            a list of indices to subsample the input tensors (if not provided, no subsampling is applied)
123
124        Returns
125        -------
126        main_input : torch.Tensor
127            the augmented tensor of the main input
128        ssl_inputs : list, optional
129            a list of augmented tensors of SSL inputs (some or all can be None)
130        ssl_targets : list, optional
131            a list of augmented tensors of SSL targets (some or all can be None)
132
133        """
134        if subsample is not None:
135            original_len = list(main_input.values())[0].shape[-1]
136            for key in main_input:
137                main_input[key] = main_input[key][..., subsample]
138            # subsample_ssl = sorted(random.sample(range(original_len), len(subsample)))
139            for x in ssl_inputs + ssl_targets:
140                if x is not None:
141                    for key in x:
142                        if len(x[key].shape) == 3 and x[key].shape[-1] == original_len:
143                            x[key] = x[key][..., subsample]
144        main_input, ssl_inputs, ssl_targets = self._apply_augmentations(
145            main_input, ssl_inputs, ssl_targets, augment
146        )
147        meta = [None for _ in ssl_inputs]
148        for i, x in enumerate(ssl_inputs):
149            if type(x) is tuple:
150                x, meta_x = x
151                meta[i] = meta_x
152                ssl_inputs[i] = x
153        for (i, ssl_x), generate in zip(enumerate(ssl_inputs), self.generate_ssl_input):
154            if ssl_x is None and generate:
155                ssl_inputs[i] = self._apply_augmentations(
156                    deepcopy(main_input), None, None, augment=True, ssl=True
157                )[0]
158        output = []
159        num_ssl = len(ssl_inputs)
160        dicts = [main_input] + ssl_inputs + ssl_targets
161
162        for x in dicts:
163            if x is None:
164                output.append(None)
165            else:
166                output.append(self._make_tensor(x, self.model_name))
167                # output.append(
168                #     torch.cat([x[key] for key in sorted(list(x.keys()))], dim=1)
169                # )
170        main_input, ssl_inputs, ssl_targets = (
171            output[0],
172            output[1 : num_ssl + 1],
173            output[num_ssl + 1 :],
174        )
175        for (i, ssl_x), keep in zip(enumerate(ssl_targets), self.keep_target_none):
176            if not keep and ssl_x is None:
177                ssl_targets[i] = main_input
178        for i, meta_x in enumerate(meta):
179            if meta_x is not None:
180                ssl_inputs[i] = (ssl_inputs[i], meta_x)
181        return main_input, ssl_inputs, ssl_targets

Apply augmentations and generate tensors from feature dictionaries.

The same augmentations are applied to all the inputs (if they have the features known to the transformer).

If generate_ssl_input is set to True for some of the SSL pairs, those SSL inputs will be generated as another augmentation of main_input. Unless keep_target_none is set to True, None SSL targets will be replaced with augmented main_input. All features are stacked together to form a tensor of shape (#features, #frames) that can be passed to a model.

Parameters

main_input : dict the feature dictionary of the main input ssl_inputs : list, optional a list of feature dictionaries of SSL inputs (some or all can be None) ssl_targets : list, optional a list of feature dictionaries of SSL targets (some or all can be None) augment : bool, default False if True, augmentations are applied subsample : list, optional a list of indices to subsample the input tensors (if not provided, no subsampling is applied)

Returns

main_input : torch.Tensor the augmented tensor of the main input ssl_inputs : list, optional a list of augmented tensors of SSL inputs (some or all can be None) ssl_targets : list, optional a list of augmented tensors of SSL targets (some or all can be None)

class EmptyTransformer(Transformer):
class EmptyTransformer(Transformer):
    """
    A transformer without any augmentations

    Both the dictionary of possible augmentations and the list of default augmentations are empty.
    """

    def _augmentations_dict(self) -> Dict:
        """
        Return the dictionary of possible augmentations

        The dictionary maps augmentation names to the corresponding functions; this class defines none.

        Returns
        -------
        augmentations_dict : dict
            an empty dictionary
        """

        no_augmentations = dict()
        return no_augmentations

    def _default_augmentations(self) -> List:
        """
        Return the list of default augmentation names

        The list is used when no augmentations are passed to the constructor and `use_default_augmentations`
        is `True`; its elements have to be keys of the dictionary returned by `self._augmentations_dict()`,
        so this class returns no names.

        Returns
        -------
        default_augmentations : list
            an empty list
        """

        no_defaults = list()
        return no_defaults

Empty transformer class that does not apply augmentations