dlc2action.ssl.masked

Implementations of dlc2action.ssl.base_ssl.SSLConstructor that predict masked input features.

  1#
  2# Copyright 2020-present by A. Mathis Group and contributors. All rights reserved.
  3#
  4# This project and all its files are licensed under GNU AGPLv3 or later version. 
  5# A copy is included in dlc2action/LICENSE.AGPL.
  6#
  7"""Implementations of `dlc2action.ssl.base_ssl.SSLConstructor` that predict masked input features."""
  8
  9from typing import Dict, Tuple, Union, List
 10
 11import torch
 12
 13from dlc2action.ssl.base_ssl import SSLConstructor
 14from abc import ABC, abstractmethod
 15from dlc2action.loss.mse import MSE
 16from dlc2action.ssl.modules import *
 17
 18
 19class MaskedFeaturesSSL(SSLConstructor, ABC):
 20    """A base masked features SSL class.
 21
 22    Mask some of the input features randomly and predict the initial data.
 23    """
 24
 25    type = "ssl_input"
 26
 27    def __init__(self, frac_masked: float = 0.2) -> None:
 28        """Initialize the constructor.
 29
 30        Parameters
 31        ----------
 32        frac_masked : float
 33            fraction of features to real_lens
 34
 35        """
 36        super().__init__()
 37        self.mse = MSE()
 38        self.frac_masked = frac_masked
 39
 40    def transformation(self, sample_data: Dict) -> Tuple:
 41        """Mask some of the features randomly."""
 42        for key in sample_data:
 43            mask = torch.empty(sample_data[key].shape).normal_() > self.frac_masked
 44            sample_data[key] = sample_data[key] * mask
 45        ssl_target = torch.cat(list(sample_data.values()))
 46        return (sample_data, ssl_target)
 47
 48    def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
 49        """MSE loss."""
 50        loss = self.mse(predicted, target)
 51        return loss
 52
 53    @abstractmethod
 54    def construct_module(self) -> nn.Module:
 55        """Construct the SSL prediction module using the parameters specified at initialization."""
 56
 57
 58class MaskedFeaturesSSL_FC(MaskedFeaturesSSL):
 59    """A fully connected masked features SSL class.
 60
 61    Mask some of the input features randomly and predict the initial data.
 62    """
 63
 64    type = "ssl_input"
 65
 66    def __init__(
 67        self,
 68        dims: torch.Size,
 69        num_f_maps: torch.Size,
 70        frac_masked: float = 0.2,
 71        num_ssl_layers: int = 5,
 72        num_ssl_f_maps: int = 16,
 73    ) -> None:
 74        """Initialize the constructor.
 75
 76        Parameters
 77        ----------
 78        dims : torch.Size
 79            the shape of features in model input
 80        num_f_maps : torch.Size
 81            shape of feature extraction output
 82        frac_masked : float, default 0.1
 83            fraction of features to real_lens
 84        num_ssl_layers : int, default 5
 85            number of layers in the SSL module
 86        num_ssl_f_maps : int, default 16
 87            number of feature maps in the SSL module
 88
 89        """
 90        super().__init__(frac_masked)
 91        dim = int(sum([s[0] for s in dims.values()]))
 92        num_f_maps = int(num_f_maps[0])
 93        self.pars = {
 94            "dim": dim,
 95            "num_f_maps": num_f_maps,
 96            "num_ssl_layers": num_ssl_layers,
 97            "num_ssl_f_maps": num_ssl_f_maps,
 98        }
 99
100    def construct_module(self) -> Union[nn.Module, None]:
101        """Construct a fully connected module."""
102        module = FC(**self.pars)
103        return module
104
105
class MaskedFeaturesSSL_TCN(MaskedFeaturesSSL):
    """A TCN masked features SSL class.

    Mask some of the input features randomly and predict the initial data.
    """

    def __init__(
        self,
        dims: Dict,
        num_f_maps: torch.Size,
        frac_masked: float = 0.2,
        num_ssl_layers: int = 5,
        num_ssl_f_maps: Union[int, None] = None,
    ) -> None:
        """Initialize the class.

        Parameters
        ----------
        dims : dict
            a dictionary of feature shapes in model input (the first elements
            of the shapes are summed to get the SSL output dimension)
        num_f_maps : torch.Size
            shape of feature extraction output (only the first element is used)
        frac_masked : float, default 0.2
            fraction of features to mask
        num_ssl_layers : int, default 5
            number of layers in the SSL module
        num_ssl_f_maps : int, optional
            ignored by the TCN module; a message is printed when it is set

        """
        super().__init__(frac_masked)

        if num_ssl_f_maps is not None:
            print(f"num_ssl_f_maps is set to {num_ssl_f_maps} but is ignored for TCN")

        dim = int(sum(s[0] for s in dims.values()))
        num_f_maps = int(num_f_maps[0])
        self.pars = {
            "input_dim": num_f_maps,
            "num_layers": num_ssl_layers,
            "output_dim": dim,
        }

    def construct_module(self) -> Union[nn.Module, None]:
        """Construct a TCN module from the stored parameters."""
        return DilatedTCN(**self.pars)
151
152
class MaskedKinematicSSL(SSLConstructor, ABC):
    """A base masked joints SSL class.

    Mask some of the joints randomly and predict the initial data.
    """

    type = "ssl_input"

    def __init__(self, frac_masked: float = 0.2) -> None:
        """Initialize the class.

        Parameters
        ----------
        frac_masked : float, default 0.2
            fraction of joints to mask

        """
        super().__init__()
        self.mse = MSE()
        self.frac_masked = frac_masked

    def _get_keys(self, key_bases, x):
        """Get keys of x that are equal to one of the strings in key_bases."""
        return [key for key in x if key in key_bases]

    def transformation(self, sample_data: Dict) -> Tuple:
        """Mask joints randomly.

        Parameters
        ----------
        sample_data : dict
            a dictionary of feature tensors; the masked features are expected
            to be two-dimensional with frames as the last dimension

        Returns
        -------
        ssl_input : dict
            a new dictionary with the same keys where the rows that belong to
            the masked joints are set to zero (the input tensors are not modified)
        ssl_target : torch.Tensor
            the unmasked feature tensors concatenated along the first dimension

        """
        assert (
            "coords" in sample_data or "coord_diff" in sample_data
        ), "'coords' or 'coord_diff' features are required when using MaskedKinematicSSL"

        multi_dim_features = self._get_keys(
            ("coords", "coord_diff", "speed_direction"),
            sample_data,
        )
        single_dim_features = self._get_keys(
            (
                "speed_joints",
                "acc_joints",
                "angle_joints_radian",
                "angle_speeds",
                "speed_value",
            ),
            sample_data,
        )

        assert (
            len(multi_dim_features) > 0
        ), "No multi-dimensional features found in sample_data"
        assert (
            len(single_dim_features) > 0
        ), "No single-dimensional features found in sample_data"

        # Build the target before masking so that it holds the initial data.
        ssl_target = torch.cat(list(sample_data.values()))

        features, frames = sample_data[multi_dim_features[0]].shape

        # presumably two coordinate rows per joint — TODO confirm for 3D data
        n_bp = features // 2
        # True marks a joint that is zeroed, which happens with probability
        # frac_masked (the inverse comparison would mask 1 - frac_masked).
        masked_joints = torch.rand(n_bp) < self.frac_masked

        # Shallow copy: masked tensors are stored here, inputs stay untouched.
        ssl_input = dict(sample_data)

        # Distances are unpacked to an (n_bp, n_bp, frames) upper triangle so
        # that a pair is zeroed when either of its joints is masked.
        pair_mask = masked_joints.repeat(n_bp, frames, 1).transpose(1, 2)
        indices = torch.triu_indices(n_bp, n_bp, 1)
        for key in self._get_keys(("intra_distance", "inter_distance"), sample_data):
            X = torch.zeros((n_bp, n_bp, frames)).to(sample_data[key].device)
            X[indices[0], indices[1], :] = sample_data[key]
            X[pair_mask] = 0
            X[pair_mask.transpose(0, 1)] = 0
            ssl_input[key] = X[indices[0], indices[1], :].reshape(-1, frames)

        n_rows = sample_data[single_dim_features[0]].shape[0]
        for key in multi_dim_features + single_dim_features:
            # Repeat each joint flag over all rows of that joint and all frames.
            mask = (
                masked_joints.repeat(sample_data[key].shape[0] // n_rows, frames, 1)
                .transpose(0, 2)
                .reshape((-1, frames))
            )
            masked = sample_data[key].clone()
            masked[mask] = 0
            ssl_input[key] = masked

        return ssl_input, ssl_target

    def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
        """Compute the MSE loss between the prediction and the target."""
        return self.mse(predicted, target)

    @abstractmethod
    def construct_module(self) -> Union[nn.Module, None]:
        """Construct the SSL prediction module using the parameters specified at initialization."""
249
250
class MaskedKinematicSSL_FC(MaskedKinematicSSL):
    """Masked kinematic SSL class with fully connected module."""

    def __init__(
        self,
        dims: Dict,
        num_f_maps: torch.Size,
        frac_masked: float = 0.2,
        num_ssl_layers: int = 5,
        num_ssl_f_maps: int = 16,
    ) -> None:
        """Initialize the constructor.

        Parameters
        ----------
        dims : dict
            a dictionary of feature shapes in model input (the first elements
            of the shapes are summed to get the SSL output dimension)
        num_f_maps : torch.Size
            shape of feature extraction output (only the first element is used)
        frac_masked : float, default 0.2
            fraction of joints to mask
        num_ssl_layers : int, default 5
            number of layers in the SSL module
        num_ssl_f_maps : int, default 16
            number of feature maps in the SSL module

        """
        super().__init__(frac_masked)
        # `dims` is consumed through `.values()`, i.e. it is a mapping
        dim = int(sum(s[0] for s in dims.values()))
        num_f_maps = int(num_f_maps[0])
        self.pars = {
            "dim": dim,
            "num_f_maps": num_f_maps,
            "num_ssl_layers": num_ssl_layers,
            "num_ssl_f_maps": num_ssl_f_maps,
        }

    def construct_module(self) -> Union[nn.Module, None]:
        """Construct a fully connected module from the stored parameters."""
        return FC(**self.pars)
292
293
class MaskedKinematicSSL_TCN(MaskedKinematicSSL):
    """Masked kinematic SSL using a TCN module."""

    def __init__(
        self,
        dims: Dict,
        num_f_maps: torch.Size,
        frac_masked: float = 0.2,
        num_ssl_layers: int = 5,
    ) -> None:
        """Initialize the constructor.

        Parameters
        ----------
        dims : dict
            a dictionary of feature shapes in model input (the first elements
            of the shapes are summed to get the SSL output dimension)
        num_f_maps : torch.Size
            shape of feature extraction output (only the first element is used)
        frac_masked : float, default 0.2
            fraction of joints to mask
        num_ssl_layers : int, default 5
            number of layers in the SSL module

        """
        super().__init__(frac_masked)
        # `dims` is consumed through `.values()`, i.e. it is a mapping
        dim = int(sum(s[0] for s in dims.values()))
        num_f_maps = int(num_f_maps[0])
        self.pars = {
            "input_dim": num_f_maps,
            "num_layers": num_ssl_layers,
            "output_dim": dim,
        }

    def construct_module(self) -> Union[nn.Module, None]:
        """Construct a TCN module from the stored parameters."""
        return DilatedTCN(**self.pars)
331
332
class MaskedFramesSSL(SSLConstructor, ABC):
    """Abstract class for masked frame SSL constructors.

    Generates the functions necessary to build a masked frames SSL: mask some of the input frames randomly
    and predict the initial data
    """

    type = "ssl_input"

    def __init__(self, frac_masked: float = 0.1) -> None:
        """Initialize the SSL constructor.

        Parameters
        ----------
        frac_masked : float, default 0.1
            fraction of frames to mask

        """
        super().__init__()
        self.frac_masked = frac_masked
        self.mse = MSE()

    def transformation(self, sample_data: Dict) -> Tuple:
        """Mask some of the frames randomly.

        Parameters
        ----------
        sample_data : dict
            a dictionary of feature tensors with frames as the last dimension

        Returns
        -------
        ssl_input : dict
            a new dictionary with the same keys where every frame is zeroed
            (in all features at once) with probability `frac_masked`
        ssl_target : torch.Tensor
            the unmasked feature tensors concatenated along the first dimension

        """
        values = list(sample_data.values())
        # The first feature defines the number of frames for the shared mask.
        num_frames = values[0].shape[-1]
        # Build the target before masking so that it holds the initial data.
        ssl_target = torch.cat(values)
        # Uniform noise keeps a frame with probability 1 - frac_masked
        # (normal noise would not mask the requested fraction).
        keep = (torch.rand(num_frames) >= self.frac_masked).unsqueeze(0)
        ssl_input = {
            key: value * keep.to(value.device) for key, value in sample_data.items()
        }
        return (ssl_input, ssl_target)

    def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
        """Compute the MSE loss between the prediction and the target."""
        return self.mse(predicted, target)

    @abstractmethod
    def construct_module(self) -> Union[nn.Module, None]:
        """Construct the SSL prediction module using the parameters specified at initialization."""
374
375
class MaskedFramesSSL_FC(MaskedFramesSSL):
    """Masked frames SSL with a fully connected module."""

    def __init__(
        self,
        dims: Dict,
        num_f_maps: torch.Size,
        frac_masked: float = 0.1,
        num_ssl_layers: int = 3,
        num_ssl_f_maps: int = 16,
    ) -> None:
        """Initialize the constructor.

        Parameters
        ----------
        dims : dict
            a dictionary of feature shapes in model input (the first elements
            of the shapes are summed to get the SSL output dimension)
        num_f_maps : torch.Size
            shape of feature extraction output (only the first element is used)
        frac_masked : float, default 0.1
            fraction of frames to mask
        num_ssl_layers : int, default 3
            number of layers in the SSL module
        num_ssl_f_maps : int, default 16
            number of feature maps in the SSL module

        """
        super().__init__(frac_masked)
        # `dims` is consumed through `.values()`, i.e. it is a mapping
        dim = int(sum(s[0] for s in dims.values()))
        num_f_maps = int(num_f_maps[0])
        self.pars = {
            "dim": dim,
            "num_f_maps": num_f_maps,
            "num_ssl_layers": num_ssl_layers,
            "num_ssl_f_maps": num_ssl_f_maps,
        }

    def construct_module(self) -> Union[nn.Module, None]:
        """Construct a fully connected module from the stored parameters."""
        return FC(**self.pars)
417
418
class MaskedFramesSSL_TCN(MaskedFramesSSL):
    """Masked frames SSL with a TCN module."""

    def __init__(
        self,
        dims: Dict,
        num_f_maps: torch.Size,
        frac_masked: float = 0.2,
        num_ssl_layers: int = 5,
    ) -> None:
        """Initialize the SSL constructor.

        Parameters
        ----------
        dims : dict
            a dictionary of feature shapes in model input (the first elements
            of the shapes are summed to get the SSL output dimension)
        num_f_maps : torch.Size
            shape of feature extraction output (only the first element is used)
        frac_masked : float, default 0.2
            fraction of frames to mask
        num_ssl_layers : int, default 5
            number of layers in the SSL module

        """
        super().__init__(frac_masked)
        # `dims` is consumed through `.values()`, i.e. it is a mapping
        dim = int(sum(s[0] for s in dims.values()))
        num_f_maps = int(num_f_maps[0])
        self.pars = {
            "input_dim": num_f_maps,
            "num_layers": num_ssl_layers,
            "output_dim": dim,
        }

    def construct_module(self) -> Union[nn.Module, None]:
        """Construct a TCN module from the stored parameters."""
        return DilatedTCN(**self.pars)
class MaskedFeaturesSSL(dlc2action.ssl.base_ssl.SSLConstructor, abc.ABC):
class MaskedFeaturesSSL(SSLConstructor, ABC):
    """A base masked features SSL class.

    Mask some of the input features randomly and predict the initial data.
    """

    type = "ssl_input"

    def __init__(self, frac_masked: float = 0.2) -> None:
        """Initialize the constructor.

        Parameters
        ----------
        frac_masked : float, default 0.2
            fraction of features to mask

        """
        super().__init__()
        self.mse = MSE()
        self.frac_masked = frac_masked

    def transformation(self, sample_data: Dict) -> Tuple:
        """Mask some of the features randomly.

        Parameters
        ----------
        sample_data : dict
            a dictionary of feature tensors

        Returns
        -------
        ssl_input : dict
            a new dictionary with the same keys where every element is zeroed
            with probability `frac_masked`
        ssl_target : torch.Tensor
            the unmasked feature tensors concatenated along the first dimension

        """
        # Build the target before masking: the module has to reconstruct the
        # initial data, not the masked copy of it.
        ssl_target = torch.cat(list(sample_data.values()))
        ssl_input = {}
        for key, value in sample_data.items():
            # Uniform noise keeps an element with probability 1 - frac_masked
            # (normal noise would not mask the requested fraction); the mask is
            # created on the device of the features it multiplies.
            keep = torch.rand(value.shape, device=value.device) >= self.frac_masked
            ssl_input[key] = value * keep
        return (ssl_input, ssl_target)

    def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
        """Compute the MSE loss between the prediction and the target."""
        return self.mse(predicted, target)

    @abstractmethod
    def construct_module(self) -> nn.Module:
        """Construct the SSL prediction module using the parameters specified at initialization."""

A base masked features SSL class.

Mask some of the input features randomly and predict the initial data.

MaskedFeaturesSSL(frac_masked: float = 0.2)
    def __init__(self, frac_masked: float = 0.2) -> None:
        """Initialize the constructor.

        Parameters
        ----------
        frac_masked : float, default 0.2
            fraction of features to mask

        """
        super().__init__()
        self.mse = MSE()
        self.frac_masked = frac_masked

Initialize the constructor.

Parameters

frac_masked : float, default 0.2 fraction of features to mask

type = 'ssl_input'

The type parameter defines interaction with the model:

  • 'ssl_input': a modification of the input data passes through the base network feature extraction module and the SSL module; it is returned as SSL output and compared to SSL target (or, if it is None, to the input data),
  • 'ssl_output': the input data passes through the base network feature extraction module and the SSL module; it is returned as SSL output and compared to SSL target (or, if it is None, to the input data),
  • 'contrastive': the input data and its modification pass through the base network feature extraction module and the SSL module; an (input results, modification results) tuple is returned as SSL output,
  • 'contrastive_2layers': the input data and its modification pass through the base network feature extraction module; the output of the second feature extraction layer for the modified data goes through an SSL module and then, optionally, that result and the first-level unmodified features pass another transformation; an (input results, modified results) tuple is returned as SSL output,
mse
frac_masked
def transformation(self, sample_data: Dict) -> Tuple:
41    def transformation(self, sample_data: Dict) -> Tuple:
42        """Mask some of the features randomly."""
43        for key in sample_data:
44            mask = torch.empty(sample_data[key].shape).normal_() > self.frac_masked
45            sample_data[key] = sample_data[key] * mask
46        ssl_target = torch.cat(list(sample_data.values()))
47        return (sample_data, ssl_target)

Mask some of the features randomly.

def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
49    def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
50        """MSE loss."""
51        loss = self.mse(predicted, target)
52        return loss

MSE loss.

@abstractmethod
def construct_module(self) -> torch.nn.modules.module.Module:
    @abstractmethod
    def construct_module(self) -> nn.Module:
        """Construct the SSL prediction module using the parameters specified at initialization.

        Abstract: it has no body here and must be overridden by subclasses.
        """

Construct the SSL prediction module using the parameters specified at initialization.

class MaskedFeaturesSSL_FC(MaskedFeaturesSSL):
 59class MaskedFeaturesSSL_FC(MaskedFeaturesSSL):
 60    """A fully connected masked features SSL class.
 61
 62    Mask some of the input features randomly and predict the initial data.
 63    """
 64
 65    type = "ssl_input"
 66
 67    def __init__(
 68        self,
 69        dims: torch.Size,
 70        num_f_maps: torch.Size,
 71        frac_masked: float = 0.2,
 72        num_ssl_layers: int = 5,
 73        num_ssl_f_maps: int = 16,
 74    ) -> None:
 75        """Initialize the constructor.
 76
 77        Parameters
 78        ----------
 79        dims : torch.Size
 80            the shape of features in model input
 81        num_f_maps : torch.Size
 82            shape of feature extraction output
 83        frac_masked : float, default 0.1
 84            fraction of features to real_lens
 85        num_ssl_layers : int, default 5
 86            number of layers in the SSL module
 87        num_ssl_f_maps : int, default 16
 88            number of feature maps in the SSL module
 89
 90        """
 91        super().__init__(frac_masked)
 92        dim = int(sum([s[0] for s in dims.values()]))
 93        num_f_maps = int(num_f_maps[0])
 94        self.pars = {
 95            "dim": dim,
 96            "num_f_maps": num_f_maps,
 97            "num_ssl_layers": num_ssl_layers,
 98            "num_ssl_f_maps": num_ssl_f_maps,
 99        }
100
101    def construct_module(self) -> Union[nn.Module, None]:
102        """Construct a fully connected module."""
103        module = FC(**self.pars)
104        return module

A fully connected masked features SSL class.

Mask some of the input features randomly and predict the initial data.

MaskedFeaturesSSL_FC( dims: torch.Size, num_f_maps: torch.Size, frac_masked: float = 0.2, num_ssl_layers: int = 5, num_ssl_f_maps: int = 16)
67    def __init__(
68        self,
69        dims: torch.Size,
70        num_f_maps: torch.Size,
71        frac_masked: float = 0.2,
72        num_ssl_layers: int = 5,
73        num_ssl_f_maps: int = 16,
74    ) -> None:
75        """Initialize the constructor.
76
77        Parameters
78        ----------
79        dims : torch.Size
80            the shape of features in model input
81        num_f_maps : torch.Size
82            shape of feature extraction output
83        frac_masked : float, default 0.1
84            fraction of features to real_lens
85        num_ssl_layers : int, default 5
86            number of layers in the SSL module
87        num_ssl_f_maps : int, default 16
88            number of feature maps in the SSL module
89
90        """
91        super().__init__(frac_masked)
92        dim = int(sum([s[0] for s in dims.values()]))
93        num_f_maps = int(num_f_maps[0])
94        self.pars = {
95            "dim": dim,
96            "num_f_maps": num_f_maps,
97            "num_ssl_layers": num_ssl_layers,
98            "num_ssl_f_maps": num_ssl_f_maps,
99        }

Initialize the constructor.

Parameters

dims : torch.Size the shape of features in model input num_f_maps : torch.Size shape of feature extraction output frac_masked : float, default 0.2 fraction of features to mask num_ssl_layers : int, default 5 number of layers in the SSL module num_ssl_f_maps : int, default 16 number of feature maps in the SSL module

type = 'ssl_input'

The type parameter defines interaction with the model:

  • 'ssl_input': a modification of the input data passes through the base network feature extraction module and the SSL module; it is returned as SSL output and compared to SSL target (or, if it is None, to the input data),
  • 'ssl_output': the input data passes through the base network feature extraction module and the SSL module; it is returned as SSL output and compared to SSL target (or, if it is None, to the input data),
  • 'contrastive': the input data and its modification pass through the base network feature extraction module and the SSL module; an (input results, modification results) tuple is returned as SSL output,
  • 'contrastive_2layers': the input data and its modification pass through the base network feature extraction module; the output of the second feature extraction layer for the modified data goes through an SSL module and then, optionally, that result and the first-level unmodified features pass another transformation; an (input results, modified results) tuple is returned as SSL output,
pars
def construct_module(self) -> Optional[torch.nn.modules.module.Module]:
101    def construct_module(self) -> Union[nn.Module, None]:
102        """Construct a fully connected module."""
103        module = FC(**self.pars)
104        return module

Construct a fully connected module.

class MaskedFeaturesSSL_TCN(MaskedFeaturesSSL):
class MaskedFeaturesSSL_TCN(MaskedFeaturesSSL):
    """A TCN masked features SSL class.

    Mask some of the input features randomly and predict the initial data.
    """

    def __init__(
        self,
        dims: Dict,
        num_f_maps: torch.Size,
        frac_masked: float = 0.2,
        num_ssl_layers: int = 5,
        num_ssl_f_maps: Union[int, None] = None,
    ) -> None:
        """Initialize the class.

        Parameters
        ----------
        dims : dict
            a dictionary of feature shapes in model input (the first elements
            of the shapes are summed to get the SSL output dimension)
        num_f_maps : torch.Size
            shape of feature extraction output (only the first element is used)
        frac_masked : float, default 0.2
            fraction of features to mask
        num_ssl_layers : int, default 5
            number of layers in the SSL module
        num_ssl_f_maps : int, optional
            ignored by the TCN module; a message is printed when it is set

        """
        super().__init__(frac_masked)

        if num_ssl_f_maps is not None:
            print(f"num_ssl_f_maps is set to {num_ssl_f_maps} but is ignored for TCN")

        dim = int(sum(s[0] for s in dims.values()))
        num_f_maps = int(num_f_maps[0])
        self.pars = {
            "input_dim": num_f_maps,
            "num_layers": num_ssl_layers,
            "output_dim": dim,
        }

    def construct_module(self) -> Union[nn.Module, None]:
        """Construct a TCN module from the stored parameters."""
        return DilatedTCN(**self.pars)

A TCN masked features SSL class.

Mask some of the input features randomly and predict the initial data.

MaskedFeaturesSSL_TCN( dims: Dict, num_f_maps: torch.Size, frac_masked: float = 0.2, num_ssl_layers: int = 5, num_ssl_f_maps: int = None)
113    def __init__(
114        self,
115        dims: Dict,
116        num_f_maps: torch.Size,
117        frac_masked: float = 0.2,
118        num_ssl_layers: int = 5,
119        num_ssl_f_maps:int = None,
120    ) -> None:
121        """Initialize the class.
122
123        Parameters
124        ----------
125        dims : torch.Size
126            the shape of features in model input
127        num_f_maps : torch.Size
128            shape of feature extraction output
129        frac_masked : float, default 0.1
130            fraction of features to real_lens
131        num_ssl_layers : int, default 5
132            number of layers in the SSL module
133
134        """
135        super().__init__(frac_masked)
136
137        if not num_ssl_f_maps is None:
138            print(f"num_ssl_f_maps is set to {num_ssl_f_maps} but is ignored for TCN")
139
140        dim = int(sum([s[0] for s in dims.values()]))
141        num_f_maps = int(num_f_maps[0])
142        self.pars = {
143            "input_dim": num_f_maps,
144            "num_layers": num_ssl_layers,
145            "output_dim": dim,
146        }

Initialize the class.

Parameters

dims : torch.Size the shape of features in model input num_f_maps : torch.Size shape of feature extraction output frac_masked : float, default 0.2 fraction of features to mask num_ssl_layers : int, default 5 number of layers in the SSL module

pars
def construct_module(self) -> Optional[torch.nn.modules.module.Module]:
148    def construct_module(self) -> Union[nn.Module, None]:
149        """Construct a TCN module."""
150        module = DilatedTCN(**self.pars)
151        return module

Construct a TCN module.

class MaskedKinematicSSL(dlc2action.ssl.base_ssl.SSLConstructor, abc.ABC):
class MaskedKinematicSSL(SSLConstructor, ABC):
    """A base masked joints SSL class.

    Mask some of the joints randomly and predict the initial data.
    """

    type = "ssl_input"

    def __init__(self, frac_masked: float = 0.2) -> None:
        """Initialize the class.

        Parameters
        ----------
        frac_masked : float, default 0.2
            fraction of joints to mask

        """
        super().__init__()
        self.mse = MSE()
        self.frac_masked = frac_masked

    def _get_keys(self, key_bases, x):
        """Get keys of x that are equal to one of the strings in key_bases."""
        return [key for key in x if key in key_bases]

    def transformation(self, sample_data: Dict) -> Tuple:
        """Mask joints randomly.

        Parameters
        ----------
        sample_data : dict
            a dictionary of feature tensors; the masked features are expected
            to be two-dimensional with frames as the last dimension

        Returns
        -------
        ssl_input : dict
            a new dictionary with the same keys where the rows that belong to
            the masked joints are set to zero (the input tensors are not modified)
        ssl_target : torch.Tensor
            the unmasked feature tensors concatenated along the first dimension

        """
        assert (
            "coords" in sample_data or "coord_diff" in sample_data
        ), "'coords' or 'coord_diff' features are required when using MaskedKinematicSSL"

        multi_dim_features = self._get_keys(
            ("coords", "coord_diff", "speed_direction"),
            sample_data,
        )
        single_dim_features = self._get_keys(
            (
                "speed_joints",
                "acc_joints",
                "angle_joints_radian",
                "angle_speeds",
                "speed_value",
            ),
            sample_data,
        )

        assert (
            len(multi_dim_features) > 0
        ), "No multi-dimensional features found in sample_data"
        assert (
            len(single_dim_features) > 0
        ), "No single-dimensional features found in sample_data"

        # Build the target before masking so that it holds the initial data.
        ssl_target = torch.cat(list(sample_data.values()))

        features, frames = sample_data[multi_dim_features[0]].shape

        # presumably two coordinate rows per joint — TODO confirm for 3D data
        n_bp = features // 2
        # True marks a joint that is zeroed, which happens with probability
        # frac_masked (the inverse comparison would mask 1 - frac_masked).
        masked_joints = torch.rand(n_bp) < self.frac_masked

        # Shallow copy: masked tensors are stored here, inputs stay untouched.
        ssl_input = dict(sample_data)

        # Distances are unpacked to an (n_bp, n_bp, frames) upper triangle so
        # that a pair is zeroed when either of its joints is masked.
        pair_mask = masked_joints.repeat(n_bp, frames, 1).transpose(1, 2)
        indices = torch.triu_indices(n_bp, n_bp, 1)
        for key in self._get_keys(("intra_distance", "inter_distance"), sample_data):
            X = torch.zeros((n_bp, n_bp, frames)).to(sample_data[key].device)
            X[indices[0], indices[1], :] = sample_data[key]
            X[pair_mask] = 0
            X[pair_mask.transpose(0, 1)] = 0
            ssl_input[key] = X[indices[0], indices[1], :].reshape(-1, frames)

        n_rows = sample_data[single_dim_features[0]].shape[0]
        for key in multi_dim_features + single_dim_features:
            # Repeat each joint flag over all rows of that joint and all frames.
            mask = (
                masked_joints.repeat(sample_data[key].shape[0] // n_rows, frames, 1)
                .transpose(0, 2)
                .reshape((-1, frames))
            )
            masked = sample_data[key].clone()
            masked[mask] = 0
            ssl_input[key] = masked

        return ssl_input, ssl_target

    def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
        """Compute the MSE loss between the prediction and the target."""
        return self.mse(predicted, target)

    @abstractmethod
    def construct_module(self) -> Union[nn.Module, None]:
        """Construct the SSL prediction module using the parameters specified at initialization."""

A base masked joints SSL class.

Mask some of the joints randomly and predict the initial data.

MaskedKinematicSSL(frac_masked: float = 0.2)
    def __init__(self, frac_masked: float = 0.2) -> None:
        """Initialize the class.

        Parameters
        ----------
        frac_masked : float, default 0.2
            fraction of joints to mask

        """
        super().__init__()
        self.mse = MSE()
        self.frac_masked = frac_masked

Initialize the class.

Parameters

frac_masked : float, default 0.2 fraction of joints to mask

type = 'ssl_input'

The type parameter defines interaction with the model:

  • 'ssl_input': a modification of the input data passes through the base network feature extraction module and the SSL module; it is returned as SSL output and compared to SSL target (or, if it is None, to the input data),
  • 'ssl_output': the input data passes through the base network feature extraction module and the SSL module; it is returned as SSL output and compared to SSL target (or, if it is None, to the input data),
  • 'contrastive': the input data and its modification pass through the base network feature extraction module and the SSL module; an (input results, modification results) tuple is returned as SSL output,
  • 'contrastive_2layers': the input data and its modification pass through the base network feature extraction module; the output of the second feature extraction layer for the modified data goes through an SSL module and then, optionally, that result and the first-level unmodified features pass another transformation; an (input results, modified results) tuple is returned as SSL output,
mse
frac_masked
def transformation(self, sample_data: Dict) -> Tuple:
183    def transformation(self, sample_data: Dict) -> Tuple:
184        """Mask joints randomly."""
185        assert (
186            "coords" in sample_data.keys() or "coord_diff" in sample_data.keys()
187        ), "'coords' or 'coord_diff' features are required when using MaskedKinematicSSL"
188
189        multi_dim_features = self._get_keys(
190            (
191                "coords",
192                "coord_diff",
193                "speed_direction",
194            ),
195            sample_data,
196        )
197
198        single_dim_features = self._get_keys(
199            (
200                "speed_joints",
201                "acc_joints",
202                "angle_joints_radian",
203                "angle_speeds",
204                "speed_value",
205            ),
206            sample_data,
207        )
208
209        assert (
210            len(multi_dim_features) > 0
211        ), "No multi-dimensional features found in sample_data"
212        assert (
213            len(single_dim_features) > 0
214        ), "No single-dimensional features found in sample_data"
215
216        features, frames = sample_data[multi_dim_features[0]].shape
217
218        n_bp = features // 2
219        masked_joints = torch.FloatTensor(n_bp).uniform_() > self.frac_masked
220
221        keys = self._get_keys(("intra_distance", "inter_distance"), sample_data)
222        for key in keys:
223            mask = masked_joints.repeat(n_bp, frames, 1).transpose(1, 2)
224            indices = torch.triu_indices(n_bp, n_bp, 1)
225
226            X = torch.zeros((n_bp, n_bp, frames)).to(sample_data[key].device)
227            X[indices[0], indices[1], :] = sample_data[key]
228            X[mask] = 0
229            X[mask.transpose(0, 1)] = 0
230            sample_data[key] = X[indices[0], indices[1], :].reshape(-1, frames)
231
232        for key in multi_dim_features + single_dim_features:
233            mask = (
234                masked_joints.repeat(sample_data[key].shape[0]//sample_data[single_dim_features[0]].shape[0], frames, 1)
235                .transpose(0, 2)
236                .reshape((-1, frames))
237            )
238            sample_data[key][mask] = 0
239
240        return sample_data, torch.cat(list(sample_data.values()))

Mask joints randomly.

def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
242    def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
243        """MSE loss."""
244        loss = self.mse(predicted, target)
245        return loss

MSE loss.

@abstractmethod
def construct_module(self) -> Optional[torch.nn.modules.module.Module]:
247    @abstractmethod
248    def construct_module(self) -> Union[nn.Module, None]:
249        """Construct the SSL prediction module using the parameters specified at initialization."""

Construct the SSL prediction module using the parameters specified at initialization.

class MaskedKinematicSSL_FC(MaskedKinematicSSL):
class MaskedKinematicSSL_FC(MaskedKinematicSSL):
    """Masked kinematic SSL class with fully connected module."""

    def __init__(
        self,
        dims: Dict,
        num_f_maps: torch.Size,
        frac_masked: float = 0.2,
        num_ssl_layers: int = 5,
        num_ssl_f_maps: int = 16,
    ) -> None:
        """Initialize the constructor.

        Parameters
        ----------
        dims : dict
            a dictionary of the shapes of features in model input (the first
            dimensions of all shapes are summed to get the module output size)
        num_f_maps : torch.Size
            shape of feature extraction output
        frac_masked : float, default 0.2
            fraction of joints to mask
        num_ssl_layers : int, default 5
            number of layers in the SSL module
        num_ssl_f_maps : int, default 16
            number of feature maps in the SSL module

        """
        super().__init__(frac_masked)
        dim = int(sum(s[0] for s in dims.values()))
        num_f_maps = int(num_f_maps[0])
        self.pars = {
            "dim": dim,
            "num_f_maps": num_f_maps,
            "num_ssl_layers": num_ssl_layers,
            "num_ssl_f_maps": num_ssl_f_maps,
        }

    def construct_module(self) -> Union[nn.Module, None]:
        """Construct a fully connected module."""
        return FC(**self.pars)

Masked kinematic SSL class with fully connected module.

MaskedKinematicSSL_FC( dims: torch.Size, num_f_maps: torch.Size, frac_masked: float = 0.2, num_ssl_layers: int = 5, num_ssl_f_maps: int = 16)
255    def __init__(
256        self,
257        dims: torch.Size,
258        num_f_maps: torch.Size,
259        frac_masked: float = 0.2,
260        num_ssl_layers: int = 5,
261        num_ssl_f_maps: int = 16,
262    ) -> None:
263        """Initialize the constructor.
264
265        Parameters
266        ----------
267        dims : torch.Size
268            the number of features in model input
269        num_f_maps : torch.Size
270            shape of feature extraction output
271        frac_masked : float, default 0.1
272            fraction of joints to real_lens
273        num_ssl_layers : int, default 5
274            number of layers in the SSL module
275        num_ssl_f_maps : int, default 16
276            number of feature maps in the SSL module
277
278        """
279        super().__init__(frac_masked)
280        dim = int(sum([s[0] for s in dims.values()]))
281        num_f_maps = int(num_f_maps[0])
282        self.pars = {
283            "dim": dim,
284            "num_f_maps": num_f_maps,
285            "num_ssl_layers": num_ssl_layers,
286            "num_ssl_f_maps": num_ssl_f_maps,
287        }

Initialize the constructor.

Parameters

dims : torch.Size the number of features in model input num_f_maps : torch.Size shape of feature extraction output frac_masked : float, default 0.2 fraction of joints to mask num_ssl_layers : int, default 5 number of layers in the SSL module num_ssl_f_maps : int, default 16 number of feature maps in the SSL module

pars
def construct_module(self) -> Optional[torch.nn.modules.module.Module]:
289    def construct_module(self) -> Union[nn.Module, None]:
290        """Construct a fully connected module."""
291        module = FC(**self.pars)
292        return module

Construct a fully connected module.

class MaskedKinematicSSL_TCN(MaskedKinematicSSL):
class MaskedKinematicSSL_TCN(MaskedKinematicSSL):
    """Masked kinematic SSL using a TCN module."""

    def __init__(
        self,
        dims: Dict,
        num_f_maps: torch.Size,
        frac_masked: float = 0.2,
        num_ssl_layers: int = 5,
    ) -> None:
        """Initialise the constructor.

        Parameters
        ----------
        dims : dict
            a dictionary of the shapes of features in model input (the first
            dimensions of all shapes are summed to get the module output size)
        num_f_maps : torch.Size
            shape of feature extraction output
        frac_masked : float, default 0.2
            fraction of joints to mask
        num_ssl_layers : int, default 5
            number of layers in the SSL module

        """
        super().__init__(frac_masked)
        dim = int(sum(s[0] for s in dims.values()))
        num_f_maps = int(num_f_maps[0])
        self.pars = {
            "input_dim": num_f_maps,
            "num_layers": num_ssl_layers,
            "output_dim": dim,
        }

    def construct_module(self) -> Union[nn.Module, None]:
        """Construct a TCN module."""
        return DilatedTCN(**self.pars)

Masked kinematic SSL using a TCN module.

MaskedKinematicSSL_TCN( dims: torch.Size, num_f_maps: torch.Size, frac_masked: float = 0.2, num_ssl_layers: int = 5)
298    def __init__(
299        self,
300        dims: torch.Size,
301        num_f_maps: torch.Size,
302        frac_masked: float = 0.2,
303        num_ssl_layers: int = 5,
304    ) -> None:
305        """Initialise the constructor.
306
307        Parameters
308        ----------
309        dims : torch.Size
310            the shape of features in model input
311        num_f_maps : torch.Size
312            shape of feature extraction output
313        frac_masked : float, default 0.1
314            fraction of joints to real_lens
315        num_ssl_layers : int, default 5
316            number of layers in the SSL module
317
318        """
319        super().__init__(frac_masked)
320        dim = int(sum([s[0] for s in dims.values()]))
321        num_f_maps = int(num_f_maps[0])
322        self.pars = {
323            "input_dim": num_f_maps,
324            "num_layers": num_ssl_layers,
325            "output_dim": dim,
326        }

Initialise the constructor.

Parameters

dims : torch.Size the shape of features in model input num_f_maps : torch.Size shape of feature extraction output frac_masked : float, default 0.2 fraction of joints to mask num_ssl_layers : int, default 5 number of layers in the SSL module

pars
def construct_module(self) -> Optional[torch.nn.modules.module.Module]:
328    def construct_module(self) -> Union[nn.Module, None]:
329        """Construct a TCN module."""
330        module = DilatedTCN(**self.pars)
331        return module

Construct a TCN module.

class MaskedFramesSSL(dlc2action.ssl.base_ssl.SSLConstructor, abc.ABC):
class MaskedFramesSSL(SSLConstructor, ABC):
    """Abstract class for masked frame SSL constructors.

    Generates the functions necessary to build a masked frames SSL: mask some of the input frames randomly
    and predict the initial data
    """

    type = "ssl_input"

    def __init__(self, frac_masked: float = 0.1) -> None:
        """Initialize the SSL constructor.

        Parameters
        ----------
        frac_masked : float, default 0.1
            fraction of frames to mask

        """
        super().__init__()
        self.frac_masked = frac_masked
        self.mse = MSE()

    def transformation(self, sample_data: Dict) -> Tuple:
        """Mask some of the frames randomly.

        Every frame is selected independently with probability `frac_masked` and set to
        zero in all features; the same frames are masked in every feature.

        Parameters
        ----------
        sample_data : dict
            a dictionary of feature tensors that have frames as the last dimension

        Returns
        -------
        sample_data : dict
            the input dictionary with the masked tensors stored under the same keys
        ssl_target : torch.Tensor
            the initial (unmasked) features concatenated along the first dimension

        """
        first_key = list(sample_data.keys())[0]
        reference = sample_data[first_key]
        num_frames = reference.shape[-1]
        # The target is the initial data, so it is assembled before masking
        ssl_target = torch.cat(list(sample_data.values()))
        # A uniform draw makes the share of zeroed frames equal to frac_masked on
        # average; the mask lives on the data device so that the product is valid
        keep = torch.rand(num_frames, device=reference.device) >= self.frac_masked
        keep = keep.unsqueeze(0)
        for key in sample_data:
            sample_data[key] = sample_data[key] * keep
        return (sample_data, ssl_target)

    def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
        """MSE loss."""
        return self.mse(predicted, target)

    @abstractmethod
    def construct_module(self) -> Union[nn.Module, None]:
        """Construct the SSL prediction module using the parameters specified at initialization."""

Abstract class for masked frame SSL constructors.

Generates the functions necessary to build a masked frames SSL: mask some of the input frames randomly and predict the initial data.

MaskedFramesSSL(frac_masked: float = 0.1)
343    def __init__(self, frac_masked: float = 0.1) -> None:
344        """Initialize the SSL constructor.
345
346        Parameters
347        ----------
348        frac_masked : float, default 0.1
349            fraction of frames to real_lens
350
351        """
352        super().__init__()
353        self.frac_masked = frac_masked
354        self.mse = MSE()

Initialize the SSL constructor.

Parameters

frac_masked : float, default 0.1 fraction of frames to mask

type = 'ssl_input'

The type parameter defines interaction with the model:

  • 'ssl_input': a modification of the input data passes through the base network feature extraction module and the SSL module; it is returned as SSL output and compared to SSL target (or, if it is None, to the input data),
  • 'ssl_output': the input data passes through the base network feature extraction module and the SSL module; it is returned as SSL output and compared to SSL target (or, if it is None, to the input data),
  • 'contrastive': the input data and its modification pass through the base network feature extraction module and the SSL module; an (input results, modification results) tuple is returned as SSL output,
  • 'contrastive_2layers': the input data and its modification pass through the base network feature extraction module; the output of the second feature extraction layer for the modified data goes through an SSL module and then, optionally, that result and the first-level unmodified features pass another transformation; an (input results, modified results) tuple is returned as SSL output,
frac_masked
mse
def transformation(self, sample_data: Dict) -> Tuple:
356    def transformation(self, sample_data: Dict) -> Tuple:
357        """Mask some of the frames randomly."""
358        key = list(sample_data.keys())[0]
359        num_frames = sample_data[key].shape[-1]
360        mask = torch.empty(num_frames).normal_() > self.frac_masked
361        mask = mask.unsqueeze(0)
362        for key in sample_data:
363            sample_data[key] = sample_data[key] * mask
364        ssl_target = torch.cat(list(sample_data.values()))
365        return (sample_data, ssl_target)

Mask some of the frames randomly.

def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
367    def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
368        """MSE loss."""
369        loss = self.mse(predicted, target)
370        return loss

MSE loss.

@abstractmethod
def construct_module(self) -> Optional[torch.nn.modules.module.Module]:
372    @abstractmethod
373    def construct_module(self) -> Union[nn.Module, None]:
374        """Construct the SSL prediction module using the parameters specified at initialization."""

Construct the SSL prediction module using the parameters specified at initialization.

class MaskedFramesSSL_FC(MaskedFramesSSL):
class MaskedFramesSSL_FC(MaskedFramesSSL):
    """Masked frames SSL with a fully connected module."""

    def __init__(
        self,
        dims: Dict,
        num_f_maps: torch.Size,
        frac_masked: float = 0.1,
        num_ssl_layers: int = 3,
        num_ssl_f_maps: int = 16,
    ) -> None:
        """Initialize the constructor.

        Parameters
        ----------
        dims : dict
            a dictionary of the shapes of features in model input (the first
            dimensions of all shapes are summed to get the module output size)
        num_f_maps : torch.Size
            shape of feature extraction output
        frac_masked : float, default 0.1
            fraction of frames to mask
        num_ssl_layers : int, default 3
            number of layers in the SSL module
        num_ssl_f_maps : int, default 16
            number of feature maps in the SSL module

        """
        super().__init__(frac_masked)
        dim = int(sum(s[0] for s in dims.values()))
        num_f_maps = int(num_f_maps[0])
        self.pars = {
            "dim": dim,
            "num_f_maps": num_f_maps,
            "num_ssl_layers": num_ssl_layers,
            "num_ssl_f_maps": num_ssl_f_maps,
        }

    def construct_module(self) -> Union[nn.Module, None]:
        """Construct a fully connected module."""
        return FC(**self.pars)

Masked frames SSL with a fully connected module.

MaskedFramesSSL_FC( dims: torch.Size, num_f_maps: torch.Size, frac_masked: float = 0.1, num_ssl_layers: int = 3, num_ssl_f_maps: int = 16)
380    def __init__(
381        self,
382        dims: torch.Size,
383        num_f_maps: torch.Size,
384        frac_masked: float = 0.1,
385        num_ssl_layers: int = 3,
386        num_ssl_f_maps: int = 16,
387    ) -> None:
388        """Initialize the constructor.
389
390        Parameters
391        ----------
392        dims : torch.Size
393            the shape of features in model input
394        num_f_maps : torch.Size
395            shape of feature extraction output
396        frac_masked : float, default 0.1
397            fraction of frames to real_lens
398        num_ssl_layers : int, default 5
399            number of layers in the SSL module
400        num_ssl_f_maps : int, default 16
401            number of feature maps in the SSL module
402
403        """
404        super().__init__(frac_masked)
405        dim = int(sum([s[0] for s in dims.values()]))
406        num_f_maps = int(num_f_maps[0])
407        self.pars = {
408            "dim": dim,
409            "num_f_maps": num_f_maps,
410            "num_ssl_layers": num_ssl_layers,
411            "num_ssl_f_maps": num_ssl_f_maps,
412        }

Initialize the constructor.

Parameters

dims : torch.Size the shape of features in model input num_f_maps : torch.Size shape of feature extraction output frac_masked : float, default 0.1 fraction of frames to mask num_ssl_layers : int, default 3 number of layers in the SSL module num_ssl_f_maps : int, default 16 number of feature maps in the SSL module

pars
def construct_module(self) -> Optional[torch.nn.modules.module.Module]:
414    def construct_module(self) -> Union[nn.Module, None]:
415        """Construct a fully connected module."""
416        module = FC(**self.pars)
417        return module

Construct a fully connected module.

class MaskedFramesSSL_TCN(MaskedFramesSSL):
class MaskedFramesSSL_TCN(MaskedFramesSSL):
    """Masked frames SSL with a TCN module."""

    def __init__(
        self,
        dims: Dict,
        num_f_maps: torch.Size,
        frac_masked: float = 0.2,
        num_ssl_layers: int = 5,
    ) -> None:
        """Initialize the SSL constructor.

        Parameters
        ----------
        dims : dict
            a dictionary of the shapes of features in model input (the first
            dimensions of all shapes are summed to get the module output size)
        num_f_maps : torch.Size
            shape of feature extraction output
        frac_masked : float, default 0.2
            fraction of frames to mask
        num_ssl_layers : int, default 5
            number of layers in the SSL module

        """
        super().__init__(frac_masked)
        dim = int(sum(s[0] for s in dims.values()))
        num_f_maps = int(num_f_maps[0])
        self.pars = {
            "input_dim": num_f_maps,
            "num_layers": num_ssl_layers,
            "output_dim": dim,
        }

    def construct_module(self) -> Union[nn.Module, None]:
        """Construct a TCN module."""
        return DilatedTCN(**self.pars)

Masked frames SSL with a TCN module.

MaskedFramesSSL_TCN( dims: torch.Size, num_f_maps: torch.Size, frac_masked: float = 0.2, num_ssl_layers: int = 5)
423    def __init__(
424        self,
425        dims: torch.Size,
426        num_f_maps: torch.Size,
427        frac_masked: float = 0.2,
428        num_ssl_layers: int = 5,
429    ) -> None:
430        """Initialize the SSL constructor.
431
432        Parameters
433        ----------
434        dims : torch.Size
435            the number of features in model input
436        num_f_maps : torch.Size
437            shape of feature extraction output
438        frac_masked : float, default 0.1
439            fraction of frames to real_lens
440        num_ssl_layers : int, default 5
441            number of layers in the SSL module
442
443        """
444        super().__init__(frac_masked)
445        dim = int(sum([s[0] for s in dims.values()]))
446        num_f_maps = int(num_f_maps[0])
447        self.pars = {
448            "input_dim": num_f_maps,
449            "num_layers": num_ssl_layers,
450            "output_dim": dim,
451        }

Initialize the SSL constructor.

Parameters

dims : torch.Size the number of features in model input num_f_maps : torch.Size shape of feature extraction output frac_masked : float, default 0.2 fraction of frames to mask num_ssl_layers : int, default 5 number of layers in the SSL module

pars
def construct_module(self) -> Optional[torch.nn.modules.module.Module]:
453    def construct_module(self) -> Union[nn.Module, None]:
454        """Construct a TCN module."""
455        module = DilatedTCN(**self.pars)
456        return module

Construct a TCN module.