dlc2action.ssl.masked

Implementations of dlc2action.ssl.base_ssl.SSLConstructor that predict masked input features

  1#
  2# Copyright 2020-2022 by A. Mathis Group and contributors. All rights reserved.
  3#
  4# This project and all its files are licensed under GNU AGPLv3 or later version. A copy is included in dlc2action/LICENSE.AGPL.
  5#
  6"""
  7Implementations of `dlc2action.ssl.base_ssl.SSLConstructor` that predict masked input features
  8"""
  9
 10from typing import Dict, Tuple, Union, List
 11
 12import torch
 13
 14from dlc2action.ssl.base_ssl import SSLConstructor
 15from abc import ABC, abstractmethod
 16from dlc2action.loss.mse import _MSE
 17from dlc2action.ssl.modules import _FC, _DilatedTCN
 18import torch
 19from torch import nn
 20
 21
 22class MaskedFeaturesSSL(SSLConstructor, ABC):
 23    """
 24    A base masked features SSL class
 25
 26    Mask some of the input features randomly and predict the initial data.
 27    """
 28
 29    type = "ssl_input"
 30
 31    def __init__(self, frac_masked: float = 0.2) -> None:
 32        """
 33        Parameters
 34        ----------
 35        frac_masked : float
 36            fraction of features to real_lens
 37        """
 38
 39        super().__init__()
 40        self.mse = _MSE()
 41        self.frac_masked = frac_masked
 42
 43    def transformation(self, sample_data: Dict) -> Tuple:
 44        """
 45        Mask some of the features randomly
 46        """
 47
 48        for key in sample_data:
 49            mask = torch.empty(sample_data[key].shape).normal_() > self.frac_masked
 50            sample_data[key] = sample_data[key] * mask
 51        ssl_target = torch.cat(list(sample_data.values()))
 52        return (sample_data, ssl_target)
 53
 54    def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
 55        """
 56        MSE loss
 57        """
 58
 59        loss = self.mse(predicted, target)
 60        return loss
 61
 62    @abstractmethod
 63    def construct_module(self) -> nn.Module:
 64        """
 65        Construct the SSL prediction module using the parameters specified at initialization
 66        """
 67
 68
 69class MaskedFeaturesSSL_FC(MaskedFeaturesSSL):
 70    """
 71    A fully connected masked features SSL class
 72
 73    Mask some of the input features randomly and predict the initial data.
 74    """
 75
 76    def __init__(
 77        self,
 78        dims: torch.Size,
 79        num_f_maps: torch.Size,
 80        frac_masked: float = 0.2,
 81        num_ssl_layers: int = 5,
 82        num_ssl_f_maps: int = 16,
 83    ) -> None:
 84        """
 85        Parameters
 86        ----------
 87        dims : torch.Size
 88            the shape of features in model input
 89        num_f_maps : torch.Size
 90            shape of feature extraction output
 91        frac_masked : float, default 0.1
 92            fraction of features to real_lens
 93        num_ssl_layers : int, default 5
 94            number of layers in the SSL module
 95        num_ssl_f_maps : int, default 16
 96            number of feature maps in the SSL module
 97        """
 98
 99        super().__init__(frac_masked)
100        dim = int(sum([s[0] for s in dims.values()]))
101        num_f_maps = int(num_f_maps[0])
102        self.pars = {
103            "dim": dim,
104            "num_f_maps": num_f_maps,
105            "num_ssl_layers": num_ssl_layers,
106            "num_ssl_f_maps": num_ssl_f_maps,
107        }
108
109    def construct_module(self) -> Union[nn.Module, None]:
110        """
111        A fully connected module
112        """
113
114        module = _FC(**self.pars)
115        return module
116
117
class MaskedFeaturesSSL_TCN(MaskedFeaturesSSL):
    """
    A TCN masked features SSL class

    Mask some of the input features randomly and predict the initial data.
    """

    def __init__(
        self,
        dims: Dict,
        num_f_maps: torch.Size,
        frac_masked: float = 0.2,
        num_ssl_layers: int = 5,
    ) -> None:
        """
        Parameters
        ----------
        dims : dict
            a dictionary whose values are the shapes of the features in model input
            (the first element of every shape is summed to get the output dimension)
        num_f_maps : torch.Size
            shape of feature extraction output (only the first element is used)
        frac_masked : float, default 0.2
            fraction of features to mask
        num_ssl_layers : int, default 5
            number of layers in the SSL module
        """

        super().__init__(frac_masked)
        # keyword arguments that are later passed to `_DilatedTCN`
        self.pars = {
            "input_dim": int(num_f_maps[0]),
            "num_layers": num_ssl_layers,
            "output_dim": int(sum(shape[0] for shape in dims.values())),
        }

    def construct_module(self) -> Union[nn.Module, None]:
        """
        A TCN module

        Returns
        -------
        module : nn.Module
            a `_DilatedTCN` instance created with the parameters stored at initialization
        """

        return _DilatedTCN(**self.pars)
161
162
class MaskedKinematicSSL(SSLConstructor, ABC):
    """
    A base masked joints SSL class

    Mask some of the joints randomly and predict the initial data.
    """

    type = "ssl_input"

    def __init__(self, frac_masked: float = 0.2) -> None:
        """
        Parameters
        ----------
        frac_masked : float, default 0.2
            fraction of joints to mask
        """

        super().__init__()
        self.mse = _MSE()
        self.frac_masked = frac_masked

    def _get_keys(self, key_bases, x):
        """
        Get keys of x that start with one of the strings in key_bases

        Parameters
        ----------
        key_bases : str or tuple of str
            the prefix (or prefixes) to look for
        x : iterable of str
            the keys to filter (e.g. a dictionary)

        Returns
        -------
        keys : list
            the matching keys, in iteration order
        """

        return [key for key in x if key.startswith(key_bases)]

    def transformation(self, sample_data: Dict) -> Tuple:
        """
        Mask joints randomly

        Every joint is selected for masking with probability `frac_masked`; all
        values that belong to a selected joint are set to zero in the entries
        whose keys start with `"intra_distance"`, `"inter_distance"`,
        `"speed_joints"`, `"coords"`, `"acc_joints"` or `"angle_joints_radian"`.
        The number of joints is taken as half the number of rows of the first
        `"coords"` entry. Masked tensors are stored back in `sample_data`; the
        tensors that were passed in are not modified.

        Parameters
        ----------
        sample_data : dict
            a dictionary of 2-dimensional feature tensors; it has to contain at least
            one key that starts with `"coords"`

        Returns
        -------
        sample_data : dict
            the same dictionary, now holding the masked tensors
        ssl_target : torch.Tensor
            the unmasked tensors concatenated along the first dimension
        """

        key = self._get_keys("coords", sample_data)[0]
        features, frames = sample_data[key].shape
        n_bp = features // 2

        # The target has to be taken before masking: the task is to recover
        # the initial data from its masked version.
        ssl_target = torch.cat(list(sample_data.values()))

        # True marks a joint that gets zeroed, so the comparison has to select
        # a `frac_masked` share of the joints (not the complement).
        masked_joints = torch.rand(n_bp) < self.frac_masked

        # Pairwise features: rebuild the (joint, joint, frame) matrix from its
        # upper triangle and zero every row and column of a masked joint.
        pair_mask = masked_joints.repeat(n_bp, frames, 1).transpose(1, 2)
        indices = torch.triu_indices(n_bp, n_bp, 1)
        for key in self._get_keys(("intra_distance", "inter_distance"), sample_data):
            X = torch.zeros((n_bp, n_bp, frames)).to(sample_data[key].device)
            X[indices[0], indices[1], :] = sample_data[key]
            X[pair_mask] = 0
            X[pair_mask.transpose(0, 1)] = 0
            sample_data[key] = X[indices[0], indices[1], :].reshape(-1, frames)

        # Features with two consecutive rows per joint.
        double_mask = (
            masked_joints.repeat(2, frames, 1).transpose(0, 2).reshape((-1, frames))
        )
        for key in self._get_keys(("speed_joints", "coords", "acc_joints"), sample_data):
            # work on a copy so the caller's tensor stays intact
            masked = sample_data[key].clone()
            masked[double_mask] = 0
            sample_data[key] = masked

        # Features with one row per joint.
        single_mask = masked_joints.repeat(frames, 1).transpose(0, 1)
        for key in self._get_keys("angle_joints_radian", sample_data):
            masked = sample_data[key].clone()
            masked[single_mask] = 0
            sample_data[key] = masked

        return sample_data, ssl_target

    def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
        """
        MSE loss

        Parameters
        ----------
        predicted : torch.Tensor
            the output of the SSL module
        target : torch.Tensor
            the SSL target

        Returns
        -------
        loss
            the value returned by `self.mse(predicted, target)`
        """

        return self.mse(predicted, target)

    @abstractmethod
    def construct_module(self) -> Union[nn.Module, None]:
        """
        Construct the SSL prediction module using the parameters specified at initialization
        """
240
241
class MaskedKinematicSSL_FC(MaskedKinematicSSL):
    """
    A fully connected masked joints SSL class

    Mask some of the joints randomly and predict the initial data.
    """

    def __init__(
        self,
        dims: Dict,
        num_f_maps: torch.Size,
        frac_masked: float = 0.2,
        num_ssl_layers: int = 5,
        num_ssl_f_maps: int = 16,
    ) -> None:
        """
        Parameters
        ----------
        dims : dict
            a dictionary whose values are the shapes of the features in model input
            (the first element of every shape is summed to get the output dimension)
        num_f_maps : torch.Size
            shape of feature extraction output (only the first element is used)
        frac_masked : float, default 0.2
            fraction of joints to mask
        num_ssl_layers : int, default 5
            number of layers in the SSL module
        num_ssl_f_maps : int, default 16
            number of feature maps in the SSL module
        """

        super().__init__(frac_masked)
        # keyword arguments that are later passed to `_FC`
        self.pars = {
            "dim": int(sum(shape[0] for shape in dims.values())),
            "num_f_maps": int(num_f_maps[0]),
            "num_ssl_layers": num_ssl_layers,
            "num_ssl_f_maps": num_ssl_f_maps,
        }

    def construct_module(self) -> Union[nn.Module, None]:
        """
        A fully connected module

        Returns
        -------
        module : nn.Module
            an `_FC` instance created with the parameters stored at initialization
        """

        return _FC(**self.pars)
283
284
class MaskedKinematicSSL_TCN(MaskedKinematicSSL):
    """
    A TCN masked joints SSL class

    Mask some of the joints randomly and predict the initial data.
    """

    def __init__(
        self,
        dims: Dict,
        num_f_maps: torch.Size,
        frac_masked: float = 0.2,
        num_ssl_layers: int = 5,
    ) -> None:
        """
        Parameters
        ----------
        dims : dict
            a dictionary whose values are the shapes of the features in model input
            (the first element of every shape is summed to get the output dimension)
        num_f_maps : torch.Size
            shape of feature extraction output (only the first element is used)
        frac_masked : float, default 0.2
            fraction of joints to mask
        num_ssl_layers : int, default 5
            number of layers in the SSL module
        """

        super().__init__(frac_masked)
        # keyword arguments that are later passed to `_DilatedTCN`
        self.pars = {
            "input_dim": int(num_f_maps[0]),
            "num_layers": num_ssl_layers,
            "output_dim": int(sum(shape[0] for shape in dims.values())),
        }

    def construct_module(self) -> Union[nn.Module, None]:
        """
        A TCN module

        Returns
        -------
        module : nn.Module
            a `_DilatedTCN` instance created with the parameters stored at initialization
        """

        return _DilatedTCN(**self.pars)
322
323
class MaskedFramesSSL(SSLConstructor, ABC):
    """
    A base masked frames SSL class

    Mask some of the input frames randomly and predict the initial data.
    """

    type = "ssl_input"

    def __init__(self, frac_masked: float = 0.1) -> None:
        """
        Parameters
        ----------
        frac_masked : float, default 0.1
            fraction of frames to mask
        """

        super().__init__()
        self.frac_masked = frac_masked
        self.mse = _MSE()

    def transformation(self, sample_data: Dict) -> Tuple:
        """
        Mask some of the frames randomly

        Every frame (index along the last dimension) is independently selected
        with probability `frac_masked`; selected frames are set to zero in all
        tensors of `sample_data`. The number of frames is read from the first
        entry. The entries of `sample_data` are replaced with the masked tensors.

        Parameters
        ----------
        sample_data : dict
            a dictionary of feature tensors

        Returns
        -------
        sample_data : dict
            the same dictionary, now holding the masked tensors
        ssl_target : torch.Tensor
            the unmasked tensors concatenated along the first dimension
        """

        values = list(sample_data.values())
        num_frames = values[0].shape[-1]
        # The target has to be taken before masking: the task is to recover
        # the initial data from its masked version.
        ssl_target = torch.cat(values)
        # Uniform samples make `frac_masked` the actual masking probability;
        # the extra leading dimension broadcasts the mask over the features.
        keep = (torch.rand(num_frames) >= self.frac_masked).unsqueeze(0)
        for key, value in sample_data.items():
            sample_data[key] = value * keep.to(value.device)
        return (sample_data, ssl_target)

    def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
        """
        MSE loss

        Parameters
        ----------
        predicted : torch.Tensor
            the output of the SSL module
        target : torch.Tensor
            the SSL target

        Returns
        -------
        loss
            the value returned by `self.mse(predicted, target)`
        """

        return self.mse(predicted, target)

    @abstractmethod
    def construct_module(self) -> Union[nn.Module, None]:
        """
        Construct the SSL prediction module using the parameters specified at initialization
        """
369
370
class MaskedFramesSSL_FC(MaskedFramesSSL):
    """
    A fully connected masked frames SSL class

    Mask some of the input frames randomly and predict the initial data.
    """

    def __init__(
        self,
        dims: Dict,
        num_f_maps: torch.Size,
        frac_masked: float = 0.1,
        num_ssl_layers: int = 3,
        num_ssl_f_maps: int = 16,
    ) -> None:
        """
        Parameters
        ----------
        dims : dict
            a dictionary whose values are the shapes of the features in model input
            (the first element of every shape is summed to get the output dimension)
        num_f_maps : torch.Size
            shape of feature extraction output (only the first element is used)
        frac_masked : float, default 0.1
            fraction of frames to mask
        num_ssl_layers : int, default 3
            number of layers in the SSL module
        num_ssl_f_maps : int, default 16
            number of feature maps in the SSL module
        """

        super().__init__(frac_masked)
        # keyword arguments that are later passed to `_FC`
        self.pars = {
            "dim": int(sum(shape[0] for shape in dims.values())),
            "num_f_maps": int(num_f_maps[0]),
            "num_ssl_layers": num_ssl_layers,
            "num_ssl_f_maps": num_ssl_f_maps,
        }

    def construct_module(self) -> Union[nn.Module, None]:
        """
        A fully connected module

        Returns
        -------
        module : nn.Module
            an `_FC` instance created with the parameters stored at initialization
        """

        return _FC(**self.pars)
411
412
class MaskedFramesSSL_TCN(MaskedFramesSSL):
    """
    A TCN masked frames SSL class

    Mask some of the input frames randomly and predict the initial data.
    """

    def __init__(
        self,
        dims: Dict,
        num_f_maps: torch.Size,
        frac_masked: float = 0.2,
        num_ssl_layers: int = 5,
    ) -> None:
        """
        Parameters
        ----------
        dims : dict
            a dictionary whose values are the shapes of the features in model input
            (the first element of every shape is summed to get the output dimension)
        num_f_maps : torch.Size
            shape of feature extraction output (only the first element is used)
        frac_masked : float, default 0.2
            fraction of frames to mask
        num_ssl_layers : int, default 5
            number of layers in the SSL module
        """

        super().__init__(frac_masked)
        # keyword arguments that are later passed to `_DilatedTCN`
        self.pars = {
            "input_dim": int(num_f_maps[0]),
            "num_layers": num_ssl_layers,
            "output_dim": int(sum(shape[0] for shape in dims.values())),
        }

    def construct_module(self) -> Union[nn.Module, None]:
        """
        A TCN module

        Returns
        -------
        module : nn.Module
            a `_DilatedTCN` instance created with the parameters stored at initialization
        """

        return _DilatedTCN(**self.pars)
class MaskedFeaturesSSL(dlc2action.ssl.base_ssl.SSLConstructor, abc.ABC):
23class MaskedFeaturesSSL(SSLConstructor, ABC):
24    """
25    A base masked features SSL class
26
27    Mask some of the input features randomly and predict the initial data.
28    """
29
30    type = "ssl_input"
31
32    def __init__(self, frac_masked: float = 0.2) -> None:
33        """
34        Parameters
35        ----------
36        frac_masked : float
37            fraction of features to real_lens
38        """
39
40        super().__init__()
41        self.mse = _MSE()
42        self.frac_masked = frac_masked
43
44    def transformation(self, sample_data: Dict) -> Tuple:
45        """
46        Mask some of the features randomly
47        """
48
49        for key in sample_data:
50            mask = torch.empty(sample_data[key].shape).normal_() > self.frac_masked
51            sample_data[key] = sample_data[key] * mask
52        ssl_target = torch.cat(list(sample_data.values()))
53        return (sample_data, ssl_target)
54
55    def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
56        """
57        MSE loss
58        """
59
60        loss = self.mse(predicted, target)
61        return loss
62
63    @abstractmethod
64    def construct_module(self) -> nn.Module:
65        """
66        Construct the SSL prediction module using the parameters specified at initialization
67        """

A base masked features SSL class

Mask some of the input features randomly and predict the initial data.

MaskedFeaturesSSL(frac_masked: float = 0.2)
32    def __init__(self, frac_masked: float = 0.2) -> None:
33        """
34        Parameters
35        ----------
36        frac_masked : float
37            fraction of features to real_lens
38        """
39
40        super().__init__()
41        self.mse = _MSE()
42        self.frac_masked = frac_masked

Parameters

frac_masked : float, default 0.2 fraction of features to mask

type = 'ssl_input'

The type parameter defines interaction with the model:

  • 'ssl_input': a modification of the input data passes through the base network feature extraction module and the SSL module; it is returned as SSL output and compared to SSL target (or, if it is None, to the input data),
  • 'ssl_output': the input data passes through the base network feature extraction module and the SSL module; it is returned as SSL output and compared to SSL target (or, if it is None, to the input data),
  • 'contrastive': the input data and its modification pass through the base network feature extraction module and the SSL module; an (input results, modification results) tuple is returned as SSL output,
  • 'contrastive_2layers': the input data and its modification pass through the base network feature extraction module; the output of the second feature extraction layer for the modified data goes through an SSL module and then, optionally, that result and the first-level unmodified features pass another transformation; an (input results, modified results) tuple is returned as SSL output,
def transformation(self, sample_data: Dict) -> Tuple:
44    def transformation(self, sample_data: Dict) -> Tuple:
45        """
46        Mask some of the features randomly
47        """
48
49        for key in sample_data:
50            mask = torch.empty(sample_data[key].shape).normal_() > self.frac_masked
51            sample_data[key] = sample_data[key] * mask
52        ssl_target = torch.cat(list(sample_data.values()))
53        return (sample_data, ssl_target)

Mask some of the features randomly

def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
55    def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
56        """
57        MSE loss
58        """
59
60        loss = self.mse(predicted, target)
61        return loss

MSE loss

@abstractmethod
def construct_module(self) -> torch.nn.modules.module.Module:
63    @abstractmethod
64    def construct_module(self) -> nn.Module:
65        """
66        Construct the SSL prediction module using the parameters specified at initialization
67        """

Construct the SSL prediction module using the parameters specified at initialization

class MaskedFeaturesSSL_FC(MaskedFeaturesSSL):
 70class MaskedFeaturesSSL_FC(MaskedFeaturesSSL):
 71    """
 72    A fully connected masked features SSL class
 73
 74    Mask some of the input features randomly and predict the initial data.
 75    """
 76
 77    def __init__(
 78        self,
 79        dims: torch.Size,
 80        num_f_maps: torch.Size,
 81        frac_masked: float = 0.2,
 82        num_ssl_layers: int = 5,
 83        num_ssl_f_maps: int = 16,
 84    ) -> None:
 85        """
 86        Parameters
 87        ----------
 88        dims : torch.Size
 89            the shape of features in model input
 90        num_f_maps : torch.Size
 91            shape of feature extraction output
 92        frac_masked : float, default 0.1
 93            fraction of features to real_lens
 94        num_ssl_layers : int, default 5
 95            number of layers in the SSL module
 96        num_ssl_f_maps : int, default 16
 97            number of feature maps in the SSL module
 98        """
 99
100        super().__init__(frac_masked)
101        dim = int(sum([s[0] for s in dims.values()]))
102        num_f_maps = int(num_f_maps[0])
103        self.pars = {
104            "dim": dim,
105            "num_f_maps": num_f_maps,
106            "num_ssl_layers": num_ssl_layers,
107            "num_ssl_f_maps": num_ssl_f_maps,
108        }
109
110    def construct_module(self) -> Union[nn.Module, None]:
111        """
112        A fully connected module
113        """
114
115        module = _FC(**self.pars)
116        return module

A fully connected masked features SSL class

Mask some of the input features randomly and predict the initial data.

MaskedFeaturesSSL_FC( dims: torch.Size, num_f_maps: torch.Size, frac_masked: float = 0.2, num_ssl_layers: int = 5, num_ssl_f_maps: int = 16)
 77    def __init__(
 78        self,
 79        dims: torch.Size,
 80        num_f_maps: torch.Size,
 81        frac_masked: float = 0.2,
 82        num_ssl_layers: int = 5,
 83        num_ssl_f_maps: int = 16,
 84    ) -> None:
 85        """
 86        Parameters
 87        ----------
 88        dims : torch.Size
 89            the shape of features in model input
 90        num_f_maps : torch.Size
 91            shape of feature extraction output
 92        frac_masked : float, default 0.1
 93            fraction of features to real_lens
 94        num_ssl_layers : int, default 5
 95            number of layers in the SSL module
 96        num_ssl_f_maps : int, default 16
 97            number of feature maps in the SSL module
 98        """
 99
100        super().__init__(frac_masked)
101        dim = int(sum([s[0] for s in dims.values()]))
102        num_f_maps = int(num_f_maps[0])
103        self.pars = {
104            "dim": dim,
105            "num_f_maps": num_f_maps,
106            "num_ssl_layers": num_ssl_layers,
107            "num_ssl_f_maps": num_ssl_f_maps,
108        }

Parameters

dims : dict a dictionary whose values are the shapes of the features in model input num_f_maps : torch.Size shape of feature extraction output frac_masked : float, default 0.2 fraction of features to mask num_ssl_layers : int, default 5 number of layers in the SSL module num_ssl_f_maps : int, default 16 number of feature maps in the SSL module

def construct_module(self) -> Optional[torch.nn.modules.module.Module]:
110    def construct_module(self) -> Union[nn.Module, None]:
111        """
112        A fully connected module
113        """
114
115        module = _FC(**self.pars)
116        return module

A fully connected module

class MaskedFeaturesSSL_TCN(MaskedFeaturesSSL):
119class MaskedFeaturesSSL_TCN(MaskedFeaturesSSL):
120    """
121    A TCN masked features SSL class
122
123    Mask some of the input features randomly and predict the initial data.
124    """
125
126    def __init__(
127        self,
128        dims: Dict,
129        num_f_maps: torch.Size,
130        frac_masked: float = 0.2,
131        num_ssl_layers: int = 5,
132    ) -> None:
133        """
134        Parameters
135        ----------
136        dims : torch.Size
137            the shape of features in model input
138        num_f_maps : torch.Size
139            shape of feature extraction output
140        frac_masked : float, default 0.1
141            fraction of features to real_lens
142        num_ssl_layers : int, default 5
143            number of layers in the SSL module
144        """
145
146        super().__init__(frac_masked)
147        dim = int(sum([s[0] for s in dims.values()]))
148        num_f_maps = int(num_f_maps[0])
149        self.pars = {
150            "input_dim": num_f_maps,
151            "num_layers": num_ssl_layers,
152            "output_dim": dim,
153        }
154
155    def construct_module(self) -> Union[nn.Module, None]:
156        """
157        A TCN module
158        """
159
160        module = _DilatedTCN(**self.pars)
161        return module

A TCN masked features SSL class

Mask some of the input features randomly and predict the initial data.

MaskedFeaturesSSL_TCN( dims: Dict, num_f_maps: torch.Size, frac_masked: float = 0.2, num_ssl_layers: int = 5)
126    def __init__(
127        self,
128        dims: Dict,
129        num_f_maps: torch.Size,
130        frac_masked: float = 0.2,
131        num_ssl_layers: int = 5,
132    ) -> None:
133        """
134        Parameters
135        ----------
136        dims : torch.Size
137            the shape of features in model input
138        num_f_maps : torch.Size
139            shape of feature extraction output
140        frac_masked : float, default 0.1
141            fraction of features to real_lens
142        num_ssl_layers : int, default 5
143            number of layers in the SSL module
144        """
145
146        super().__init__(frac_masked)
147        dim = int(sum([s[0] for s in dims.values()]))
148        num_f_maps = int(num_f_maps[0])
149        self.pars = {
150            "input_dim": num_f_maps,
151            "num_layers": num_ssl_layers,
152            "output_dim": dim,
153        }

Parameters

dims : dict a dictionary whose values are the shapes of the features in model input num_f_maps : torch.Size shape of feature extraction output frac_masked : float, default 0.2 fraction of features to mask num_ssl_layers : int, default 5 number of layers in the SSL module

def construct_module(self) -> Optional[torch.nn.modules.module.Module]:
155    def construct_module(self) -> Union[nn.Module, None]:
156        """
157        A TCN module
158        """
159
160        module = _DilatedTCN(**self.pars)
161        return module

A TCN module

class MaskedKinematicSSL(dlc2action.ssl.base_ssl.SSLConstructor, abc.ABC):
164class MaskedKinematicSSL(SSLConstructor, ABC):
165    """
166    A base masked joints SSL class
167
168    Mask some of the joints randomly and predict the initial data.
169    """
170
171    type = "ssl_input"
172
173    def __init__(self, frac_masked: float = 0.2) -> None:
174        """
175        Parameters
176        ----------
177        frac_masked : float, default 0.1
178            fraction of features to real_lens
179        """
180
181        super().__init__()
182        self.mse = _MSE()
183        self.frac_masked = frac_masked
184
185    def _get_keys(self, key_bases, x):
186        """
187        Get keys of x that start with one of the strings in key_bases
188        """
189
190        keys = []
191        for key in x:
192            if key.startswith(key_bases):
193                keys.append(key)
194        return keys
195
196    def transformation(self, sample_data: Dict) -> Tuple:
197        """
198        Mask joints randomly
199        """
200
201        key = self._get_keys("coords", sample_data)[0]
202        features, frames = sample_data[key].shape
203        n_bp = features // 2
204        masked_joints = torch.FloatTensor(n_bp).uniform_() > self.frac_masked
205        keys = self._get_keys(("intra_distance", "inter_distance"), sample_data)
206        for key in keys:
207            mask = masked_joints.repeat(n_bp, frames, 1).transpose(1, 2)
208            indices = torch.triu_indices(n_bp, n_bp, 1)
209
210            X = torch.zeros((n_bp, n_bp, frames)).to(sample_data[key].device)
211            X[indices[0], indices[1], :] = sample_data[key]
212            X[mask] = 0
213            X[mask.transpose(0, 1)] = 0
214            sample_data[key] = X[indices[0], indices[1], :].reshape(-1, frames)
215        keys = self._get_keys(("speed_joints", "coords", "acc_joints"), sample_data)
216        for key in keys:
217            mask = (
218                masked_joints.repeat(2, frames, 1).transpose(0, 2).reshape((-1, frames))
219            )
220            sample_data[key][mask] = 0
221        keys = self._get_keys("angle_joints_radian", sample_data)
222        for key in keys:
223            mask = masked_joints.repeat(frames, 1).transpose(0, 1)
224            sample_data[key][mask] = 0
225
226        return sample_data, torch.cat(list(sample_data.values()))
227
228    def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
229        """
230        MSE loss
231        """
232
233        loss = self.mse(predicted, target)
234        return loss
235
236    @abstractmethod
237    def construct_module(self) -> Union[nn.Module, None]:
238        """
239        Construct the SSL prediction module using the parameters specified at initialization
240        """

A base masked joints SSL class

Mask some of the joints randomly and predict the initial data.

MaskedKinematicSSL(frac_masked: float = 0.2)
173    def __init__(self, frac_masked: float = 0.2) -> None:
174        """
175        Parameters
176        ----------
177        frac_masked : float, default 0.1
178            fraction of features to real_lens
179        """
180
181        super().__init__()
182        self.mse = _MSE()
183        self.frac_masked = frac_masked

Parameters

frac_masked : float, default 0.2 fraction of joints to mask

type = 'ssl_input'

The type parameter defines interaction with the model:

  • 'ssl_input': a modification of the input data passes through the base network feature extraction module and the SSL module; it is returned as SSL output and compared to SSL target (or, if it is None, to the input data),
  • 'ssl_output': the input data passes through the base network feature extraction module and the SSL module; it is returned as SSL output and compared to SSL target (or, if it is None, to the input data),
  • 'contrastive': the input data and its modification pass through the base network feature extraction module and the SSL module; an (input results, modification results) tuple is returned as SSL output,
  • 'contrastive_2layers': the input data and its modification pass through the base network feature extraction module; the output of the second feature extraction layer for the modified data goes through an SSL module and then, optionally, that result and the first-level unmodified features pass another transformation; an (input results, modified results) tuple is returned as SSL output,
def transformation(self, sample_data: Dict) -> Tuple:
196    def transformation(self, sample_data: Dict) -> Tuple:
197        """
198        Mask joints randomly
199        """
200
201        key = self._get_keys("coords", sample_data)[0]
202        features, frames = sample_data[key].shape
203        n_bp = features // 2
204        masked_joints = torch.FloatTensor(n_bp).uniform_() > self.frac_masked
205        keys = self._get_keys(("intra_distance", "inter_distance"), sample_data)
206        for key in keys:
207            mask = masked_joints.repeat(n_bp, frames, 1).transpose(1, 2)
208            indices = torch.triu_indices(n_bp, n_bp, 1)
209
210            X = torch.zeros((n_bp, n_bp, frames)).to(sample_data[key].device)
211            X[indices[0], indices[1], :] = sample_data[key]
212            X[mask] = 0
213            X[mask.transpose(0, 1)] = 0
214            sample_data[key] = X[indices[0], indices[1], :].reshape(-1, frames)
215        keys = self._get_keys(("speed_joints", "coords", "acc_joints"), sample_data)
216        for key in keys:
217            mask = (
218                masked_joints.repeat(2, frames, 1).transpose(0, 2).reshape((-1, frames))
219            )
220            sample_data[key][mask] = 0
221        keys = self._get_keys("angle_joints_radian", sample_data)
222        for key in keys:
223            mask = masked_joints.repeat(frames, 1).transpose(0, 1)
224            sample_data[key][mask] = 0
225
226        return sample_data, torch.cat(list(sample_data.values()))

Mask joints randomly

def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
228    def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
229        """
230        MSE loss
231        """
232
233        loss = self.mse(predicted, target)
234        return loss

MSE loss

@abstractmethod
def construct_module(self) -> Optional[torch.nn.modules.module.Module]:
236    @abstractmethod
237    def construct_module(self) -> Union[nn.Module, None]:
238        """
239        Construct the SSL prediction module using the parameters specified at initialization
240        """

Construct the SSL prediction module using the parameters specified at initialization

class MaskedKinematicSSL_FC(MaskedKinematicSSL):
class MaskedKinematicSSL_FC(MaskedKinematicSSL):
    """
    A masked joints SSL class with a fully connected prediction module

    Only the constructor and `construct_module` are defined here; everything
    else is inherited from `MaskedKinematicSSL`.
    """

    def __init__(
        self,
        dims: torch.Size,
        num_f_maps: torch.Size,
        frac_masked: float = 0.2,
        num_ssl_layers: int = 5,
        num_ssl_f_maps: int = 16,
    ) -> None:
        """
        Parameters
        ----------
        dims : torch.Size
            the shapes of the model input features; accessed as a mapping
            (`dims.values()`), the first elements of all values are summed
            to get the total number of input features
        num_f_maps : torch.Size
            shape of feature extraction output (only the first element is used)
        frac_masked : float, default 0.2
            fraction of joints to mask (forwarded to the parent constructor)
        num_ssl_layers : int, default 5
            number of layers in the SSL module
        num_ssl_f_maps : int, default 16
            number of feature maps in the SSL module
        """

        super().__init__(frac_masked)
        # total input dimensionality: sum of the first dimension of every input
        dim = int(sum(s[0] for s in dims.values()))
        # keyword arguments later passed to `_FC` in `construct_module`
        self.pars = {
            "dim": dim,
            "num_f_maps": int(num_f_maps[0]),
            "num_ssl_layers": num_ssl_layers,
            "num_ssl_f_maps": num_ssl_f_maps,
        }

    def construct_module(self) -> Union[nn.Module, None]:
        """
        Construct a fully connected module from the stored parameters
        """

        return _FC(**self.pars)

A base masked joints SSL class

Mask some of the joints randomly and predict the initial data.

MaskedKinematicSSL_FC( dims: torch.Size, num_f_maps: torch.Size, frac_masked: float = 0.2, num_ssl_layers: int = 5, num_ssl_f_maps: int = 16)
244    def __init__(
245        self,
246        dims: torch.Size,
247        num_f_maps: torch.Size,
248        frac_masked: float = 0.2,
249        num_ssl_layers: int = 5,
250        num_ssl_f_maps: int = 16,
251    ) -> None:
252        """
253        Parameters
254        ----------
255        dims : torch.Size
256            the number of features in model input
257        num_f_maps : torch.Size
258            shape of feature extraction output
259        frac_masked : float, default 0.1
260            fraction of joints to real_lens
261        num_ssl_layers : int, default 5
262            number of layers in the SSL module
263        num_ssl_f_maps : int, default 16
264            number of feature maps in the SSL module
265        """
266
267        super().__init__(frac_masked)
268        dim = int(sum([s[0] for s in dims.values()]))
269        num_f_maps = int(num_f_maps[0])
270        self.pars = {
271            "dim": dim,
272            "num_f_maps": num_f_maps,
273            "num_ssl_layers": num_ssl_layers,
274            "num_ssl_f_maps": num_ssl_f_maps,
275        }

Parameters

dims : torch.Size the number of features in model input num_f_maps : torch.Size shape of feature extraction output frac_masked : float, default 0.2 fraction of joints to mask num_ssl_layers : int, default 5 number of layers in the SSL module num_ssl_f_maps : int, default 16 number of feature maps in the SSL module

def construct_module(self) -> Optional[torch.nn.modules.module.Module]:
277    def construct_module(self) -> Union[nn.Module, None]:
278        """
279        A fully connected module
280        """
281
282        module = _FC(**self.pars)
283        return module

A fully connected module

class MaskedKinematicSSL_TCN(MaskedKinematicSSL):
class MaskedKinematicSSL_TCN(MaskedKinematicSSL):
    """
    A masked joints SSL class with a TCN prediction module

    Only the constructor and `construct_module` are defined here; everything
    else is inherited from `MaskedKinematicSSL`.
    """

    def __init__(
        self,
        dims: torch.Size,
        num_f_maps: torch.Size,
        frac_masked: float = 0.2,
        num_ssl_layers: int = 5,
    ) -> None:
        """
        Parameters
        ----------
        dims : torch.Size
            the shapes of the model input features; accessed as a mapping
            (`dims.values()`), the first elements of all values are summed
            to get the output dimension of the SSL module
        num_f_maps : torch.Size
            shape of feature extraction output (only the first element is used)
        frac_masked : float, default 0.2
            fraction of joints to mask (forwarded to the parent constructor)
        num_ssl_layers : int, default 5
            number of layers in the SSL module
        """

        super().__init__(frac_masked)
        # the module maps the extracted features back to the input dimensionality
        dim = int(sum(s[0] for s in dims.values()))
        # keyword arguments later passed to `_DilatedTCN` in `construct_module`
        self.pars = {
            "input_dim": int(num_f_maps[0]),
            "num_layers": num_ssl_layers,
            "output_dim": dim,
        }

    def construct_module(self) -> Union[nn.Module, None]:
        """
        Construct a TCN module from the stored parameters
        """

        return _DilatedTCN(**self.pars)

A base masked joints SSL class

Mask some of the joints randomly and predict the initial data.

MaskedKinematicSSL_TCN( dims: torch.Size, num_f_maps: torch.Size, frac_masked: float = 0.2, num_ssl_layers: int = 5)
287    def __init__(
288        self,
289        dims: torch.Size,
290        num_f_maps: torch.Size,
291        frac_masked: float = 0.2,
292        num_ssl_layers: int = 5,
293    ) -> None:
294        """
295        Parameters
296        ----------
297        dims : torch.Size
298            the shape of features in model input
299        num_f_maps : torch.Size
300            shape of feature extraction output
301        frac_masked : float, default 0.1
302            fraction of joints to real_lens
303        num_ssl_layers : int, default 5
304            number of layers in the SSL module
305        """
306
307        super().__init__(frac_masked)
308        dim = int(sum([s[0] for s in dims.values()]))
309        num_f_maps = int(num_f_maps[0])
310        self.pars = {
311            "input_dim": num_f_maps,
312            "num_layers": num_ssl_layers,
313            "output_dim": dim,
314        }

Parameters

dims : torch.Size the shape of features in model input num_f_maps : torch.Size shape of feature extraction output frac_masked : float, default 0.2 fraction of joints to mask num_ssl_layers : int, default 5 number of layers in the SSL module

def construct_module(self) -> Optional[torch.nn.modules.module.Module]:
316    def construct_module(self) -> Union[nn.Module, None]:
317        """
318        A TCN module
319        """
320
321        module = _DilatedTCN(**self.pars)
322        return module

A TCN module

class MaskedFramesSSL(dlc2action.ssl.base_ssl.SSLConstructor, abc.ABC):
class MaskedFramesSSL(SSLConstructor, ABC):
    """
    A base masked frames SSL class

    Mask some of the input frames randomly and predict the data with an SSL module.
    """

    type = "ssl_input"

    def __init__(self, frac_masked: float = 0.1) -> None:
        """
        Parameters
        ----------
        frac_masked : float, default 0.1
            fraction of frames to mask
        """

        super().__init__()
        self.frac_masked = frac_masked
        self.mse = _MSE()

    def transformation(self, sample_data: Dict) -> Tuple:
        """
        Mask some of the frames randomly

        One frame mask is drawn per call and shared by all features: every frame
        (index along the last dimension) is zeroed with probability `frac_masked`.

        Parameters
        ----------
        sample_data : dict
            a dictionary of feature tensors; the number of frames is read from
            the last dimension of the first value

        Returns
        -------
        sample_data : dict
            the same dictionary with masked tensors as values
        ssl_target : torch.Tensor
            the concatenation of the values of the returned dictionary
        """

        first = list(sample_data.values())[0]
        num_frames = first.shape[-1]
        # uniform samples in [0, 1): a frame is kept with probability 1 - frac_masked
        # (a normal sample here would zero ~54% of frames at frac_masked=0.1)
        keep = torch.rand(num_frames, device=first.device) >= self.frac_masked
        keep = keep.unsqueeze(0)
        for key in sample_data:
            sample_data[key] = sample_data[key] * keep
        # NOTE(review): the target is built after masking, so it contains the
        # masked values rather than the initial data -- confirm this is intended
        ssl_target = torch.cat(list(sample_data.values()))
        return (sample_data, ssl_target)

    def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
        """
        MSE loss
        """

        return self.mse(predicted, target)

    @abstractmethod
    def construct_module(self) -> Union[nn.Module, None]:
        """
        Construct the SSL prediction module using the parameters specified at initialization
        """

Generates the functions necessary to build a masked frames SSL: mask some of the input frames randomly and predict the initial data

MaskedFramesSSL(frac_masked: float = 0.1)
331    def __init__(self, frac_masked: float = 0.1) -> None:
332        """
333        Parameters
334        ----------
335        frac_masked : float, default 0.1
336            fraction of frames to real_lens
337        """
338
339        super().__init__()
340        self.frac_masked = frac_masked
341        self.mse = _MSE()

Parameters

frac_masked : float, default 0.1 fraction of frames to mask

type = 'ssl_input'

The type parameter defines interaction with the model:

  • 'ssl_input': a modification of the input data passes through the base network feature extraction module and the SSL module; it is returned as SSL output and compared to SSL target (or, if it is None, to the input data),
  • 'ssl_output': the input data passes through the base network feature extraction module and the SSL module; it is returned as SSL output and compared to SSL target (or, if it is None, to the input data),
  • 'contrastive': the input data and its modification pass through the base network feature extraction module and the SSL module; an (input results, modification results) tuple is returned as SSL output,
  • 'contrastive_2layers': the input data and its modification pass through the base network feature extraction module; the output of the second feature extraction layer for the modified data goes through an SSL module and then, optionally, that result and the first-level unmodified features pass another transformation; an (input results, modified results) tuple is returned as SSL output,
def transformation(self, sample_data: Dict) -> Tuple:
343    def transformation(self, sample_data: Dict) -> Tuple:
344        """
345        Mask some of the frames randomly
346        """
347
348        key = list(sample_data.keys())[0]
349        num_frames = sample_data[key].shape[-1]
350        mask = torch.empty(num_frames).normal_() > self.frac_masked
351        mask = mask.unsqueeze(0)
352        for key in sample_data:
353            sample_data[key] = sample_data[key] * mask
354        ssl_target = torch.cat(list(sample_data.values()))
355        return (sample_data, ssl_target)

Mask some of the frames randomly

def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
357    def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
358        """
359        MSE loss
360        """
361
362        loss = self.mse(predicted, target)
363        return loss

MSE loss

@abstractmethod
def construct_module(self) -> Optional[torch.nn.modules.module.Module]:
365    @abstractmethod
366    def construct_module(self) -> Union[nn.Module, None]:
367        """
368        Construct the SSL prediction module using the parameters specified at initialization
369        """

Construct the SSL prediction module using the parameters specified at initialization

class MaskedFramesSSL_FC(MaskedFramesSSL):
class MaskedFramesSSL_FC(MaskedFramesSSL):
    """
    A masked frames SSL class with a fully connected prediction module

    Only the constructor and `construct_module` are defined here; everything
    else is inherited from `MaskedFramesSSL`.
    """

    def __init__(
        self,
        dims: torch.Size,
        num_f_maps: torch.Size,
        frac_masked: float = 0.1,
        num_ssl_layers: int = 3,
        num_ssl_f_maps: int = 16,
    ) -> None:
        """
        Parameters
        ----------
        dims : torch.Size
            the shapes of the model input features; accessed as a mapping
            (`dims.values()`), the first elements of all values are summed
            to get the total number of input features
        num_f_maps : torch.Size
            shape of feature extraction output (only the first element is used)
        frac_masked : float, default 0.1
            fraction of frames to mask (forwarded to the parent constructor)
        num_ssl_layers : int, default 3
            number of layers in the SSL module
        num_ssl_f_maps : int, default 16
            number of feature maps in the SSL module
        """

        super().__init__(frac_masked)
        # total input dimensionality: sum of the first dimension of every input
        dim = int(sum(s[0] for s in dims.values()))
        # keyword arguments later passed to `_FC` in `construct_module`
        self.pars = {
            "dim": dim,
            "num_f_maps": int(num_f_maps[0]),
            "num_ssl_layers": num_ssl_layers,
            "num_ssl_f_maps": num_ssl_f_maps,
        }

    def construct_module(self) -> Union[nn.Module, None]:
        """
        Construct a fully connected module from the stored parameters
        """

        return _FC(**self.pars)

Generates the functions necessary to build a masked frames SSL: mask some of the input frames randomly and predict the initial data

MaskedFramesSSL_FC( dims: torch.Size, num_f_maps: torch.Size, frac_masked: float = 0.1, num_ssl_layers: int = 3, num_ssl_f_maps: int = 16)
373    def __init__(
374        self,
375        dims: torch.Size,
376        num_f_maps: torch.Size,
377        frac_masked: float = 0.1,
378        num_ssl_layers: int = 3,
379        num_ssl_f_maps: int = 16,
380    ) -> None:
381        """
382        Parameters
383        ----------
384        dims : torch.Size
385            the shape of features in model input
386        num_f_maps : torch.Size
387            shape of feature extraction output
388        frac_masked : float, default 0.1
389            fraction of frames to real_lens
390        num_ssl_layers : int, default 5
391            number of layers in the SSL module
392        num_ssl_f_maps : int, default 16
393            number of feature maps in the SSL module
394        """
395        super().__init__(frac_masked)
396        dim = int(sum([s[0] for s in dims.values()]))
397        num_f_maps = int(num_f_maps[0])
398        self.pars = {
399            "dim": dim,
400            "num_f_maps": num_f_maps,
401            "num_ssl_layers": num_ssl_layers,
402            "num_ssl_f_maps": num_ssl_f_maps,
403        }

Parameters

dims : torch.Size the shape of features in model input num_f_maps : torch.Size shape of feature extraction output frac_masked : float, default 0.1 fraction of frames to mask num_ssl_layers : int, default 3 number of layers in the SSL module num_ssl_f_maps : int, default 16 number of feature maps in the SSL module

def construct_module(self) -> Optional[torch.nn.modules.module.Module]:
405    def construct_module(self) -> Union[nn.Module, None]:
406        """
407        A fully connected module
408        """
409
410        module = _FC(**self.pars)
411        return module

A fully connected module

class MaskedFramesSSL_TCN(MaskedFramesSSL):
class MaskedFramesSSL_TCN(MaskedFramesSSL):
    """
    A masked frames SSL class with a TCN prediction module

    Only the constructor and `construct_module` are defined here; everything
    else is inherited from `MaskedFramesSSL`.
    """

    def __init__(
        self,
        dims: torch.Size,
        num_f_maps: torch.Size,
        frac_masked: float = 0.2,
        num_ssl_layers: int = 5,
    ) -> None:
        """
        Parameters
        ----------
        dims : torch.Size
            the shapes of the model input features; accessed as a mapping
            (`dims.values()`), the first elements of all values are summed
            to get the output dimension of the SSL module
        num_f_maps : torch.Size
            shape of feature extraction output (only the first element is used)
        frac_masked : float, default 0.2
            fraction of frames to mask (forwarded to the parent constructor)
        num_ssl_layers : int, default 5
            number of layers in the SSL module
        """

        super().__init__(frac_masked)
        # the module maps the extracted features back to the input dimensionality
        dim = int(sum(s[0] for s in dims.values()))
        # keyword arguments later passed to `_DilatedTCN` in `construct_module`
        self.pars = {
            "input_dim": int(num_f_maps[0]),
            "num_layers": num_ssl_layers,
            "output_dim": dim,
        }

    def construct_module(self) -> Union[nn.Module, None]:
        """
        Construct a TCN module from the stored parameters
        """

        return _DilatedTCN(**self.pars)

Generates the functions necessary to build a masked frames SSL: mask some of the input frames randomly and predict the initial data

MaskedFramesSSL_TCN( dims: torch.Size, num_f_maps: torch.Size, frac_masked: float = 0.2, num_ssl_layers: int = 5)
415    def __init__(
416        self,
417        dims: torch.Size,
418        num_f_maps: torch.Size,
419        frac_masked: float = 0.2,
420        num_ssl_layers: int = 5,
421    ) -> None:
422        """
423        Parameters
424        ----------
425        dims : torch.Size
426            the number of features in model input
427        num_f_maps : torch.Size
428            shape of feature extraction output
429        frac_masked : float, default 0.1
430            fraction of frames to real_lens
431        num_ssl_layers : int, default 5
432            number of layers in the SSL module
433        """
434
435        super().__init__(frac_masked)
436        dim = int(sum([s[0] for s in dims.values()]))
437        num_f_maps = int(num_f_maps[0])
438        self.pars = {
439            "input_dim": num_f_maps,
440            "num_layers": num_ssl_layers,
441            "output_dim": dim,
442        }

Parameters

dims : torch.Size the number of features in model input num_f_maps : torch.Size shape of feature extraction output frac_masked : float, default 0.2 fraction of frames to mask num_ssl_layers : int, default 5 number of layers in the SSL module

def construct_module(self) -> Optional[torch.nn.modules.module.Module]:
444    def construct_module(self) -> Union[nn.Module, None]:
445        """
446        A TCN module
447        """
448
449        module = _DilatedTCN(**self.pars)
450        return module

A TCN module