dlc2action.transformer.base_transformer

Abstract parent class for transformers

  1#
  2# Copyright 2020-2022 by A. Mathis Group and contributors. All rights reserved.
  3#
  4# This project and all its files are licensed under GNU AGPLv3 or later version. A copy is included in dlc2action/LICENSE.AGPL.
  5#
  6"""
  7Abstract parent class for transformers
  8"""
  9
 10from typing import Dict, List, Callable, Union, Tuple
 11import torch
 12from abc import ABC, abstractmethod
 13from dlc2action.utils import TensorList
 14from copy import deepcopy
 15from matplotlib import pyplot as plt
 16
 17
 18class Transformer(ABC):
 19    """
 20    A base class for all transformers
 21
 22    A transformer should apply augmentations and generate model input and training target tensors.
 23
 24    All augmentation functions need to take `(main_input: dict, ssl_inputs: list, ssl_targets: list)`
 25    as input and return an output of the same format. Here `main_input` is a feature dictionary of the sample
 26    data, `ssl_inputs` is a list of SSL input feature dictionaries and `ssl_targets` is a list of SSL target
 27    feature dictionaries. The same augmentations are applied to all inputs and then `None` values are replaced
 28    according to the rules set by `keep_target_none` and `generate_ssl_input` parameters and the feature
 29    dictionaries are compiled into tensors.
 30    """
 31
 32    def __init__(
 33        self,
 34        model_name: str,
 35        augmentations: List = None,
 36        use_default_augmentations: bool = False,
 37        generate_ssl_input: List = None,
 38        keep_target_none: List = None,
 39        ssl_augmentations: List = None,
 40        graph_features: bool = False,
 41        bodyparts_order: List = None,
 42    ) -> None:
 43        """
 44        Parameters
 45        ----------
 46        augmentations : list, optional
 47            a list of string names of augmentations to use (if not provided, either no augmentations are applied or,
 48            if `use_default_augmentations` is `True`, a default list is used)
 49        use_default_augmentations : bool, default False
 50            if True and augmentations are not passed, default augmentations will be applied; otherwise no augmentations
 51        generate_ssl_input : list, optional
 52            a list of bool values of the length of the number of SSL modules being used; if the corresponding bool value
 53            is `True`, the ssl input will be generated as a new augmentation of main input (if not provided defaults to
 54            `False` for each module)
 55        keep_target_none : list, optional
 56            a list of bool values of the length of the number of SSL modules being used; if the corresponding bool value
 57            is `False` and the SSL target is `None`, the target is set to augmented main input (if not provided defaults
 58            to `True` for each module)
 59        ssl_augmentations : list, optional
 60            a list of augmentation names to be applied with generating SSL input (when `generate_ssl_input` is True)
 61            (if not provided, defaults to the main augmentations list)
 62        graph_features : bool, default False
 63            if `True`, all features in each frame can be meaningfully reshaped to `(#bodyparts, #features)`
 64        bodyparts_order : list, optional
 65            a list of bodypart names
 66        """
 67
 68        if augmentations is None:
 69            augmentations = []
 70        if generate_ssl_input is None:
 71            generate_ssl_input = [None]
 72        if keep_target_none is None:
 73            keep_target_none = [None]
 74        self.model_name = model_name
 75        self.augmentations = augmentations
 76        self.generate_ssl_input = generate_ssl_input
 77        self.keep_target_none = keep_target_none
 78        if len(self.augmentations) == 0 and use_default_augmentations:
 79            self.augmentations = self._default_augmentations()
 80        if ssl_augmentations is None:
 81            ssl_augmentations = self.augmentations
 82        self.ssl_augmentations = ssl_augmentations
 83        self.graph_features = graph_features
 84        self.num_graph_nodes = (
 85            len(bodyparts_order) if bodyparts_order is not None else None
 86        )
 87        self._check_augmentations(self.augmentations)
 88        self._check_augmentations(self.ssl_augmentations)
 89
 90    def transform(
 91        self,
 92        main_input: Dict,
 93        ssl_inputs: List = None,
 94        ssl_targets: List = None,
 95        augment: bool = False,
 96        subsample: List = None,
 97    ) -> Tuple:
 98        """
 99        Apply augmentations and generate tensors from feature dictionaries
100
101        The same augmentations are applied to all the inputs (if they have the features known to the transformer).
102
103        If `generate_ssl_input` is set to True for some of the SSL pairs, those SSL inputs will be generated as
104        another augmentation of main_input.
105        Unless `keep_target_none` is set to True, `None` SSL targets will be replaced with augmented `main_input`.
106        All features are stacked together to form a tensor of shape `(#features, #frames)` that can be passed to
107        a model.
108
109        Parameters
110        ----------
111        main_input : dict
112            the feature dictionary of the main input
113        ssl_inputs : list, optional
114            a list of feature dictionaries of SSL inputs (some or all can be None)
115        ssl_targets : list, optional
116            a list of feature dictionaries of SSL targets (some or all can be None)
117        augment : bool, default False
118
119        Returns
120        -------
121        main_input : torch.Tensor
122            the augmented tensor of the main input
123        ssl_inputs : list, optional
124            a list of augmented tensors of SSL inputs (some or all can be None)
125        ssl_targets : list, optional
126            a list of augmented tensors of SSL targets (some or all can be None)
127        """
128
129        if subsample is not None:
130            original_len = list(main_input.values())[0].shape[-1]
131            for key in main_input:
132                main_input[key] = main_input[key][..., subsample]
133            # subsample_ssl = sorted(random.sample(range(original_len), len(subsample)))
134            for x in ssl_inputs + ssl_targets:
135                if x is not None:
136                    for key in x:
137                        if len(x[key].shape) == 3 and x[key].shape[-1] == original_len:
138                            x[key] = x[key][..., subsample]
139        main_input, ssl_inputs, ssl_targets = self._apply_augmentations(
140            main_input, ssl_inputs, ssl_targets, augment
141        )
142        meta = [None for _ in ssl_inputs]
143        for i, x in enumerate(ssl_inputs):
144            if type(x) is tuple:
145                x, meta_x = x
146                meta[i] = meta_x
147                ssl_inputs[i] = x
148        for (i, ssl_x), generate in zip(enumerate(ssl_inputs), self.generate_ssl_input):
149            if ssl_x is None and generate:
150                ssl_inputs[i] = self._apply_augmentations(
151                    deepcopy(main_input), None, None, augment=True, ssl=True
152                )[0]
153        output = []
154        num_ssl = len(ssl_inputs)
155        dicts = [main_input] + ssl_inputs + ssl_targets
156
157        for x in dicts:
158            if x is None:
159                output.append(None)
160            else:
161                output.append(self._make_tensor(x, self.model_name))
162                # output.append(
163                #     torch.cat([x[key] for key in sorted(list(x.keys()))], dim=1)
164                # )
165        main_input, ssl_inputs, ssl_targets = (
166            output[0],
167            output[1 : num_ssl + 1],
168            output[num_ssl + 1 :],
169        )
170        for (i, ssl_x), keep in zip(enumerate(ssl_targets), self.keep_target_none):
171            if not keep and ssl_x is None:
172                ssl_targets[i] = main_input
173        for i, meta_x in enumerate(meta):
174            if meta_x is not None:
175                ssl_inputs[i] = (ssl_inputs[i], meta_x)
176        return main_input, ssl_inputs, ssl_targets
177
178    @abstractmethod
179    def _augmentations_dict(self) -> Dict:
180        """
181        Return a dictionary of possible augmentations
182
183        The keys are augmentation names and the values are the corresponding functions
184
185        Returns
186        -------
187        augmentations_dict : dict
188            a dictionary of augmentation functions (each function needs to take
189            `(main_input: dict, ssl_inputs: list, ssl_targets: list)` as input and return an output
190            of the same format)
191        """
192
193    @abstractmethod
194    def _default_augmentations(self) -> List:
195        """
196        Return a list of default augmentation names
197
198        In case an augmentation list is not provided to the class constructor and use_default_augmentations is `True`,
199        this function is called to set the augmentations parameter. The elements of the list have to be keys of the
200        dictionary returned by `self._augmentations_dict()`
201
202        Returns
203        -------
204        default_augmentations : list
205            a list of string names of the default augmentation functions
206        """
207
208    def _check_augmentations(self, augmentations: List) -> None:
209        """
210        Check the validity of an augmentations list
211        """
212
213        for aug_name in augmentations:
214            if aug_name not in self._augmentations_dict():
215                raise ValueError(
216                    f"The {aug_name} augmentation is not possible in this augmentor! Please choose from {list(self._augmentations_dict().keys())}"
217                )
218
219    def _get_augmentation(self, aug_name: str) -> Callable:
220        """
221        Return the augmentation specified by `aug_name`
222        """
223
224        return self._augmentations_dict()[aug_name]
225
226    def _visualize(self, main_input, title):
227        coord_keys = [x for x in main_input.keys() if x.split("---")[0] == "coords"]
228        if len(coord_keys) > 0:
229            coords = [main_input[key][0, :, :, 0].detach().cpu() for key in coord_keys]
230            centers = [None]
231        else:
232            coord_keys = [
233                x for x in main_input.keys() if x.split("---")[0] == "coord_diff"
234            ]
235            coords = []
236            centers = []
237            for coord_diff_key in coord_keys:
238                if len(coord_diff_key.split("---")) == 2:
239                    ind = coord_diff_key.split("---")[1]
240                    center_key = f"center---{ind}"
241                else:
242                    center_key = "center"
243                    ind = ""
244                if center_key in main_input.keys():
245                    coords.append(
246                        main_input[center_key][0, :, :, 0].detach().cpu()
247                        + main_input[coord_diff_key][0, :, :, 0].detach().cpu()
248                    )
249                    title += f", {ind}: {main_input[center_key][0, 0, :, 0].data}"
250                    center = main_input[center_key][0, :, :, 0].detach().cpu()
251                else:
252                    coords.append(main_input[coord_diff_key][0, :, :, 0].detach().cpu())
253                    center = None
254                centers.append(center)
255        colors = ["blue", "orange", "green", "purple", "pink"]
256        if coords[0].shape[1] == 2:
257            plt.figure(figsize=(15, 15))
258        else:
259            fig = plt.figure(figsize=(15, 15))
260            ax = fig.add_subplot(projection="3d")
261        for i, coord in enumerate(coords):
262            if coord.shape[1] == 2:
263                plt.scatter(coord[:, 0], coord[:, 1], color=colors[i])
264                plt.xlim((-0.5, 0.5))
265                plt.ylim((-0.5, 0.5))
266            else:
267                ax.scatter(
268                    coord[:, 0].detach().cpu(),
269                    coord[:, 1].detach().cpu(),
270                    coord[:, 2].detach().cpu(),
271                    color=colors[i],
272                )
273                if centers[i] is not None:
274                    ax.scatter(
275                        centers[i][:, 0].detach().cpu(),
276                        centers[i][:, 1].detach().cpu(),
277                        0,
278                        color="red",
279                        s=30,
280                    )
281                    center = centers[i][0].detach().cpu()
282                    ax.text(
283                        center[0], center[1], 0, f"({center[0]:.2f}, {center[1]:.2f})"
284                    )
285                for j in [1, 8]:
286                    ax.scatter(
287                        coord[j : j + 1, 0].detach().cpu(),
288                        coord[j : j + 1, 1].detach().cpu(),
289                        coord[j : j + 1, 2].detach().cpu(),
290                        color="purple",
291                    )
292                    ax.text(
293                        coord[j, 0],
294                        coord[j, 1],
295                        coord[j, 2],
296                        f"({coord[j, 0]:.2f}, {coord[j, 1]:.2f}, {coord[j, 2]:.2f})",
297                    )
298
299        plt.xlim((-3, 3))
300        plt.ylim((-3, 3))
301        ax.set_zlim(-2, 4)
302        ax.set_xlabel("x")
303        ax.set_ylabel("y")
304        ax.set_zlabel("z")
305        plt.title(title)
306        plt.show()
307
308    def _apply_augmentations(
309        self,
310        main_input: Dict,
311        ssl_inputs: List = None,
312        ssl_targets: List = None,
313        augment: bool = False,
314        ssl: bool = False,
315    ) -> Tuple:
316        """
317        Apply the augmentations
318
319        The same augmentations are applied to all inputs
320        """
321
322        visualize = False
323        if ssl:
324            augmentations = self.ssl_augmentations
325        else:
326            augmentations = self.augmentations
327        if visualize:
328            self._visualize(main_input, "before")
329        if ssl_inputs is None:
330            ssl_inputs = [None]
331        if ssl_targets is None:
332            ssl_targets = [None]
333        if augment:
334            for aug_name in augmentations:
335                augment_func = self._get_augmentation(aug_name)
336                main_input, ssl_inputs, ssl_targets = augment_func(
337                    main_input, ssl_inputs, ssl_targets
338                )
339                if visualize:
340                    self._visualize(main_input, aug_name)
341        return main_input, ssl_inputs, ssl_targets
342
343    def _make_tensor(self, x: Dict, model_name: str) -> Union[torch.Tensor, TensorList]:
344        """
345        Turn a feature dictionary into a tensor or a `dlc2action.utils.TensorList` object
346        """
347
348        if model_name == "ms_tcn_p":
349            keys = sorted(list(x.keys()))
350            groups = [key.split("---")[-1] for key in keys]
351            unique_groups = sorted(set(groups))
352            tensor = TensorList()
353            for group in unique_groups:
354                if not self.graph_features:
355                    tensor.append(
356                        torch.cat(
357                            [x[key] for key, g in zip(keys, groups) if g == group],
358                            dim=2,
359                        )
360                    )
361                else:
362                    tensor.append(
363                        torch.cat(
364                            [
365                                x[key].reshape(
366                                    (
367                                        x[key].shape[0],
368                                        self.num_graph_nodes,
369                                        -1,
370                                        x[key].shape[-1],
371                                    )
372                                )
373                                for key, g in zip(keys, groups)
374                                if g == group
375                            ],
376                            dim=2,
377                        )
378                    )
379                    tensor[-1] = tensor[-1].reshape(
380                        (tensor[-1].shape[0], -1, tensor[-1].shape[-1])
381                    )
382            if "loaded" in x:
383                tensor.append(x["loaded"])
384        elif model_name == "c2f_tcn_p":
385            keys = sorted(
386                [
387                    key
388                    for key in x.keys()
389                    if len(key.split("---")) != 1
390                    and len(key.split("---")[-1].split("+")) != 2
391                ]
392            )
393            inds = [key.split("---")[-1] for key in keys]
394            unique_inds = sorted(set(inds))
395            tensor = TensorList()
396            for ind in unique_inds:
397                if not self.graph_features:
398                    tensor.append(
399                        torch.cat(
400                            [x[key] for key, g in zip(keys, inds) if g == ind],
401                            dim=1,
402                        )
403                    )
404                else:
405                    tensor.append(
406                        torch.cat(
407                            [
408                                x[key].reshape(
409                                    (
410                                        x[key].shape[0],
411                                        self.num_graph_nodes,
412                                        -1,
413                                        x[key].shape[-1],
414                                    )
415                                )
416                                for key, g in zip(keys, inds)
417                                if g == ind
418                            ],
419                            dim=1,
420                        )
421                    )
422                    tensor[-1] = tensor[-1].reshape(
423                        (tensor[-1].shape[0], -1, tensor[-1].shape[-1])
424                    )
425        elif model_name == "c3d_a":
426            tensor = torch.cat([x[key] for key in sorted(list(x.keys()))], dim=1)
427        else:
428            if not self.graph_features:
429                tensor = torch.cat([x[key] for key in sorted(list(x.keys()))], dim=1)
430            else:
431                tensor = torch.cat(
432                    [
433                        x[key].reshape(
434                            (
435                                x[key].shape[0],
436                                self.num_graph_nodes,
437                                -1,
438                                x[key].shape[-1],
439                            )
440                        )
441                        for key in sorted(list(x.keys()))
442                    ],
443                    dim=2,
444                )
445                tensor = tensor.reshape((tensor.shape[0], -1, tensor.shape[-1]))
446        return tensor
447
448
449class EmptyTransformer(Transformer):
450    """
451    Empty transformer class that does not apply augmentations
452    """
453
454    def _augmentations_dict(self) -> Dict:
455        """
456        Return a dictionary of possible augmentations
457
458        The keys are augmentation names and the values are the corresponding functions
459
460        Returns
461        -------
462        augmentations_dict : dict
463            a dictionary of augmentation functions
464        """
465
466        return {}
467
468    def _default_augmentations(self) -> List:
469        """
470        Return a list of default augmentation names
471
472        In case an augmentation list is not provided to the class constructor and use_default_augmentations is True,
473        this function is called to set the augmentations parameter. The elements of the list have to be keys of the
474        dictionary returned by `self._augmentations_dict()`
475
476        Returns
477        -------
478        default_augmentations : list
479            a list of string names of the default augmentation functions
480        """
481
482        return []
class Transformer(abc.ABC):

A base class for all transformers

A transformer should apply augmentations and generate model input and training target tensors.

All augmentation functions need to take `(main_input: dict, ssl_inputs: list, ssl_targets: list)` as input and return an output of the same format. Here `main_input` is a feature dictionary of the sample data, `ssl_inputs` is a list of SSL input feature dictionaries and `ssl_targets` is a list of SSL target feature dictionaries. The same augmentations are applied to all inputs, then `None` values are replaced according to the rules set by the `keep_target_none` and `generate_ssl_input` parameters, and the feature dictionaries are compiled into tensors.
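
As an illustration of this contract, here is a minimal sketch of a concrete subclass; the NoiseTransformer class and its "add_noise" augmentation are hypothetical examples, not part of dlc2action.

from typing import Dict, List, Tuple

import torch

from dlc2action.transformer.base_transformer import Transformer


class NoiseTransformer(Transformer):
    """A hypothetical transformer with a single Gaussian-noise augmentation."""

    def _augmentations_dict(self) -> Dict:
        # keys are augmentation names, values follow the
        # (main_input, ssl_inputs, ssl_targets) -> same format contract
        return {"add_noise": self._add_noise}

    def _default_augmentations(self) -> List:
        return ["add_noise"]

    def _add_noise(
        self, main_input: Dict, ssl_inputs: List, ssl_targets: List
    ) -> Tuple[Dict, List, List]:
        # a real augmentation should apply the same transform to every input;
        # independent noise is used here only to keep the sketch short
        for x in [main_input] + ssl_inputs + ssl_targets:
            if x is None:
                continue
            for key in x:
                x[key] = x[key] + 0.01 * torch.randn_like(x[key])
        return main_input, ssl_inputs, ssl_targets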

Transformer(model_name: str, augmentations: List = None, use_default_augmentations: bool = False, generate_ssl_input: List = None, keep_target_none: List = None, ssl_augmentations: List = None, graph_features: bool = False, bodyparts_order: List = None)

Parameters

model_name : str
    the name of the model the input is generated for (determines how the feature dictionaries are compiled into tensors)
augmentations : list, optional
    a list of string names of augmentations to use (if not provided, either no augmentations are applied or, if `use_default_augmentations` is `True`, a default list is used)
use_default_augmentations : bool, default False
    if `True` and augmentations are not passed, default augmentations will be applied; otherwise no augmentations
generate_ssl_input : list, optional
    a list of bool values of the length of the number of SSL modules being used; if the corresponding bool value is `True`, the SSL input will be generated as a new augmentation of the main input (if not provided, defaults to `False` for each module)
keep_target_none : list, optional
    a list of bool values of the length of the number of SSL modules being used; if the corresponding bool value is `False` and the SSL target is `None`, the target is set to the augmented main input (if not provided, defaults to `True` for each module)
ssl_augmentations : list, optional
    a list of augmentation names to be applied when generating SSL input (when `generate_ssl_input` is `True`); if not provided, defaults to the main augmentations list
graph_features : bool, default False
    if `True`, all features in each frame can be meaningfully reshaped to `(#bodyparts, #features)`
bodyparts_order : list, optional
    a list of bodypart names
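
To illustrate how these parameters interact, a possible construction is sketched below, assuming the hypothetical NoiseTransformer subclass shown earlier; all argument values are illustrative.

transformer = NoiseTransformer(
    model_name="ms_tcn_p",        # determines how the feature dictionaries are compiled into tensors
    augmentations=["add_noise"],  # must be keys of _augmentations_dict()
    generate_ssl_input=[True],    # one flag per SSL module: re-augment the main input when the SSL input is None
    keep_target_none=[False],     # a None SSL target is replaced with the augmented main input
    ssl_augmentations=["add_noise"],
)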

def transform(self, main_input: Dict, ssl_inputs: List = None, ssl_targets: List = None, augment: bool = False, subsample: List = None) -> Tuple:

Apply augmentations and generate tensors from feature dictionaries

The same augmentations are applied to all the inputs (if they have the features known to the transformer).

If `generate_ssl_input` is set to `True` for some of the SSL pairs, those SSL inputs will be generated as another augmentation of `main_input`. Unless `keep_target_none` is set to `True`, `None` SSL targets will be replaced with the augmented `main_input`. All features are stacked together to form a tensor of shape `(#features, #frames)` that can be passed to a model.

Parameters

main_input : dict
    the feature dictionary of the main input
ssl_inputs : list, optional
    a list of feature dictionaries of SSL inputs (some or all can be `None`)
ssl_targets : list, optional
    a list of feature dictionaries of SSL targets (some or all can be `None`)
augment : bool, default False
    if `True`, augmentations are applied to the inputs
subsample : list, optional
    a list of frame indices; if provided, the inputs are subsampled to these frames before augmentation

Returns

main_input : torch.Tensor
    the augmented tensor of the main input
ssl_inputs : list
    a list of augmented tensors of SSL inputs (some or all can be `None`)
ssl_targets : list
    a list of augmented tensors of SSL targets (some or all can be `None`)
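
A sketch of a typical call, again using the hypothetical NoiseTransformer; the model name, feature names and the (#batch, #features, #frames) layout are assumptions for illustration.

import torch

transformer = NoiseTransformer(model_name="my_model", augmentations=["add_noise"])
main_input = {
    "coords": torch.randn(1, 16, 100),  # assumed (#batch, #features, #frames) layout
    "speed": torch.randn(1, 8, 100),
}
model_input, ssl_inputs, ssl_targets = transformer.transform(
    main_input,
    ssl_inputs=[None],  # one entry per SSL module
    ssl_targets=[None],
    augment=True,       # apply the augmentation pipeline
)
# for model names without special handling in _make_tensor, the features are
# concatenated along the feature dimension: model_input.shape == (1, 24, 100)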

class EmptyTransformer(Transformer):

Empty transformer class that does not apply augmentations
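
For completeness, a minimal usage sketch; the model name and feature dictionary are illustrative assumptions.

import torch

from dlc2action.transformer.base_transformer import EmptyTransformer

transformer = EmptyTransformer(model_name="my_model")
main_input = {"coords": torch.randn(1, 16, 100)}  # assumed (#batch, #features, #frames) layout
# no augmentations are registered, so augment=True is effectively a no-op
model_input, _, _ = transformer.transform(main_input, augment=True)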

Inherited Members
    Transformer
        Transformer
        transform