dlc2action.ssl.contrastive

Implementations of `dlc2action.ssl.base_ssl.SSLConstructor` of the 'contrastive' and 'contrastive_2layers' types

  1#
  2# Copyright 2020-2022 by A. Mathis Group and contributors. All rights reserved.
  3#
  4# This project and all its files are licensed under GNU AGPLv3 or later version. A copy is included in dlc2action/LICENSE.AGPL.
  5#
  6"""
  7Implementations of `dlc2action.ssl.base_ssl.SSLConstructor` of the `'contrastive'` type
  8"""
  9
 10from typing import Dict, Tuple, Union
 11from dlc2action.ssl.base_ssl import SSLConstructor
 12from dlc2action.loss.contrastive import _NTXent, _CircleLoss, _TripletLoss
 13from dlc2action.loss.contrastive_frame import _ContrastiveRegressionLoss
 14from dlc2action.ssl.modules import _FeatureExtractorTCN, _MFeatureExtractorTCN, _FC
 15import torch
 16from torch import nn
 17from copy import deepcopy
 18
 19
 20class ContrastiveSSL(SSLConstructor):
 21    """
 22    A contrastive SSL class with an NT-Xent loss
 23
 24    The SSL input and target are left empty (the SSL input is generated as an augmentation of the
 25    input sample at runtime).
 26    """
 27
 28    type = "contrastive"
 29
 30    def __init__(
 31        self,
 32        num_f_maps: torch.Size,
 33        len_segment: int,
 34        ssl_features: int = 128,
 35        tau: float = 1,
 36    ) -> None:
 37        """
 38        Parameters
 39        ----------
 40        num_f_maps : torch.Size
 41            shape of feature extractor output
 42        len_segment : int
 43            length of segment in the base feature extractor output
 44        ssl_features : int, default 128
 45            the final number of features per clip
 46        tau : float, default 1
 47            the tau parameter of NT-Xent loss
 48        """
 49
 50        super().__init__()
 51        self.loss_function = _NTXent(tau)
 52        if len(num_f_maps) > 1:
 53            raise RuntimeError(
 54                "The ContrastiveSSL constructor expects the input data to be 2-dimensional; "
 55                f"got {len(num_f_maps) + 1} dimensions"
 56            )
 57        num_f_maps = int(num_f_maps[0])
 58        self.pars = {
 59            "num_f_maps": num_f_maps,
 60            "len_segment": len_segment,
 61            "output_dim": ssl_features,
 62            "kernel_1": 5,
 63            "kernel_2": 5,
 64            "stride": 2,
 65            "decrease_f_maps": True,
 66        }
 67
 68    def transformation(self, sample_data: Dict) -> Tuple:
 69        """
 70        Empty transformation
 71        """
 72
 73        return torch.tensor(float("nan")), torch.tensor(float("nan"))
 74
 75    def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
 76        """
 77        NT-Xent loss
 78        """
 79
 80        features1, features2 = predicted
 81        loss = self.loss_function(features1, features2)
 82        return loss
 83
 84    def construct_module(self) -> Union[nn.Module, None]:
 85        """
 86        Clip-wise feature TCN extractor
 87        """
 88
 89        module = _FeatureExtractorTCN(**self.pars)
 90        return module
 91
 92
 93class ContrastiveMaskedSSL(SSLConstructor):
 94    """
 95    A contrastive masked SSL class with an NT-Xent loss
 96
 97    A few frames in the middle of each segment are masked and then the output of the second layer of
 98    feature extraction for the segment is used to predict the output of the first layer for the missing frames.
 99    The SSL input and target are left empty (the SSL input is generated as an augmentation of the
100    input sample at runtime).
101    """
102
103    type = "contrastive_2layers"
104
105    def __init__(
106        self,
107        num_f_maps: torch.Size,
108        len_segment: int,
109        ssl_features: int = 128,
110        tau: float = 1,
111        num_masked: int = 10,
112    ) -> None:
113        """
114        Parameters
115        ----------
116        num_f_maps : torch.Size
117            shape of feature extractor output
118        len_segment : int
119            length of segment in the base feature extractor output
120        ssl_features : int, default 128
121            the final number of features per clip
122        tau : float, default 1
123            the tau parameter of NT-Xent loss
124        """
125
126        super().__init__()
127        self.start = int(len_segment // 2 - num_masked // 2)
128        self.end = int(len_segment // 2 + num_masked // 2)
129        self.loss_function = _NTXent(tau)
130        if len(num_f_maps) > 1:
131            raise RuntimeError(
132                "The ContrastiveMaskedSSL constructor expects the input data to be 2-dimensional; "
133                f"got {len(num_f_maps) + 1} dimensions"
134            )
135        num_f_maps = int(num_f_maps[0])
136        self.pars = {
137            "num_f_maps": num_f_maps,
138            "len_segment": len_segment,
139            "output_dim": ssl_features,
140            "kernel_1": 3,
141            "kernel_2": 3,
142            "stride": 1,
143            "start": self.start,
144            "end": self.end,
145        }
146
147    def transformation(self, sample_data: Dict) -> Tuple:
148        """
149        Empty transformation
150        """
151
152        data = deepcopy(sample_data)
153        for key in data.keys():
154            data[key][:, self.start : self.end] = 0
155        return data, torch.tensor(float("nan"))
156        # return torch.tensor(float("nan")), torch.tensor(float("nan"))
157
158    def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
159        """
160        NT-Xent loss
161        """
162
163        features, ssl_features = predicted
164        loss = self.loss_function(features, ssl_features)
165        return loss
166
167    def construct_module(self) -> Union[nn.Module, None]:
168        """
169        Clip-wise feature TCN extractor
170        """
171
172        module = _MFeatureExtractorTCN(**self.pars)
173        return module
174
175
class PairwiseSSL(SSLConstructor):
    """
    Pairwise SSL constructor optimised with a triplet or a circle loss

    Both the SSL input and the SSL target are NaN placeholders; the paired view of each
    sample is produced at runtime as an augmentation of the input.
    """

    type = "contrastive"

    def __init__(
        self,
        num_f_maps: torch.Size,
        len_segment: int,
        ssl_features: int = 128,
        margin: float = 0,
        distance: str = "cosine",
        loss: str = "triplet",
        gamma: float = 1,
    ) -> None:
        """
        Parameters
        ----------
        num_f_maps : torch.Size
            shape of feature extractor output
        len_segment : int
            length of segment in feature extractor output
        ssl_features : int, default 128
            final number of features per clip
        margin : float, default 0
            the margin parameter of triplet or circle loss
        distance : {'cosine', 'euclidean'}
            the distance calculation method for triplet or circle loss
        loss : {'triplet', 'circle'}
            the loss function name
        gamma : float, default 1
            the gamma parameter of circle loss (ignored for the triplet loss)

        Raises
        ------
        ValueError
            if `loss` is neither "triplet" nor "circle"
        RuntimeError
            if `num_f_maps` has more than one entry
        """
        super().__init__()
        if loss not in ("triplet", "circle"):
            raise ValueError(
                f'The {loss} loss is unavailable, please choose from "triplet" and "circle"'
            )
        if loss == "triplet":
            self.loss_function = _TripletLoss(margin=margin, distance=distance)
        else:
            self.loss_function = _CircleLoss(
                margin=margin, gamma=gamma, distance=distance
            )
        n_dims = len(num_f_maps)
        if n_dims > 1:
            raise RuntimeError(
                "The PairwiseSSL constructor expects the input data to be 2-dimensional; "
                f"got {n_dims + 1} dimensions"
            )
        channels = int(num_f_maps[0])
        # keyword arguments for `_FeatureExtractorTCN` (see `construct_module`)
        self.pars = dict(
            num_f_maps=channels,
            len_segment=len_segment,
            output_dim=ssl_features,
            kernel_1=5,
            kernel_2=5,
            stride=2,
            decrease_f_maps=True,
        )

    def transformation(self, sample_data: Dict) -> Tuple:
        """
        Empty transformation: the SSL input and target are both NaN placeholders
        """
        placeholder = float("nan")
        return torch.tensor(placeholder), torch.tensor(placeholder)

    def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
        """
        Triplet or circle loss between the two views in `predicted` (`target` is unused)
        """
        view_a, view_b = predicted
        return self.loss_function(view_a, view_b)

    def construct_module(self) -> Union[nn.Module, None]:
        """
        Build the clip-wise TCN feature extractor from `self.pars`
        """
        return _FeatureExtractorTCN(**self.pars)
265
266
class PairwiseMaskedSSL(PairwiseSSL):
    """
    A pairwise masked SSL class with triplet or circle loss

    A few frames in the middle of each segment are masked: the SSL input is a copy of the
    input sample with those frames set to zero (see `transformation`) and the SSL target is
    left empty. The loss function is selected as in `PairwiseSSL`; the SSL module is the
    masked TCN feature extractor.
    """

    type = "contrastive_2layers"

    def __init__(
        self,
        num_f_maps: torch.Size,
        len_segment: int,
        ssl_features: int = 128,
        margin: float = 0,
        distance: str = "cosine",
        loss: str = "triplet",
        gamma: float = 1,
        num_masked: int = 10,
    ) -> None:
        """
        Parameters
        ----------
        num_f_maps : torch.Size
            shape of feature extractor output
        len_segment : int
            length of segment in feature extractor output
        ssl_features : int, default 128
            final number of features per clip
        margin : float, default 0
            the margin parameter of triplet or circle loss
        distance : {'cosine', 'euclidean'}
            the distance calculation method for triplet or circle loss
        loss : {'triplet', 'circle'}
            the loss function name
        gamma : float, default 1
            the gamma parameter of circle loss
        num_masked : int, default 10
            the number of frames masked in the middle of each segment

        Raises
        ------
        RuntimeError
            if `num_f_maps` has more than one entry
        ValueError
            if `loss` is not one of the supported names (raised by `PairwiseSSL`)
        """
        # Validate before delegating: the parent constructor performs the same check
        # and would otherwise raise first with a message naming PairwiseSSL, leaving
        # this check unreachable.
        if len(num_f_maps) > 1:
            raise RuntimeError(
                "The PairwiseMaskedSSL constructor expects the input data to be 2-dimensional; "
                f"got {len(num_f_maps) + 1} dimensions"
            )
        super().__init__(
            num_f_maps, len_segment, ssl_features, margin, distance, loss, gamma
        )
        self.num_masked = num_masked
        # centred window of exactly `num_masked` frames (also for odd values)
        self.start = int(len_segment // 2 - num_masked // 2)
        self.end = self.start + int(num_masked)
        # the parent filled `self.pars` for `_FeatureExtractorTCN`; replace it with
        # the keyword arguments of `_MFeatureExtractorTCN`
        self.pars = {
            "num_f_maps": int(num_f_maps[0]),
            "len_segment": len_segment,
            "output_dim": ssl_features,
            "kernel_1": 3,
            "kernel_2": 3,
            "stride": 1,
            "start": self.start,
            "end": self.end,
        }

    def transformation(self, sample_data: Dict) -> Tuple:
        """
        Mask the middle frames of the sample

        Returns a deep copy of `sample_data` in which `[:, self.start : self.end]` of
        every value is set to zero (the SSL input), and a NaN placeholder SSL target.
        `sample_data` itself is not modified.
        """

        data = deepcopy(sample_data)
        # slice assignment on each stored array mutates the copy in place
        for value in data.values():
            value[:, self.start : self.end] = 0
        return data, torch.tensor(float("nan"))

    def construct_module(self) -> Union[nn.Module, None]:
        """
        Build the masked clip-wise TCN feature extractor from `self.pars`
        """

        module = _MFeatureExtractorTCN(**self.pars)
        return module
322
323
class ContrastiveRegressionSSL(SSLConstructor):
    """
    A contrastive SSL class with a contrastive regression loss

    The SSL input and target are left empty (NaN placeholders). The SSL module is the
    fully connected `_FC` head and the loss is `_ContrastiveRegressionLoss`.
    """

    type = "contrastive"

    def __init__(
        self,
        num_f_maps: torch.Size,
        num_features: int = 128,
        num_ssl_layers: int = 1,
        distance: str = "cosine",
        temperature: float = 1,
        break_factor: Union[int, None] = None,
    ) -> None:
        """
        Parameters
        ----------
        num_f_maps : torch.Size
            shape of feature extractor output
        num_features : int, default 128
            passed to `_FC` as both `num_ssl_f_maps` and `dim`
        num_ssl_layers : int, default 1
            passed to `_FC` as `num_ssl_layers`
        distance : str, default "cosine"
            the distance parameter of `_ContrastiveRegressionLoss`
        temperature : float, default 1
            the temperature parameter of `_ContrastiveRegressionLoss`
        break_factor : int, optional
            the break_factor parameter of `_ContrastiveRegressionLoss`

        Raises
        ------
        RuntimeError
            if `num_f_maps` has more than one entry
        """
        # initialise the base class first, as the other constructors in this module do,
        # so nothing it sets up can clobber the attributes assigned below
        super().__init__()
        if len(num_f_maps) > 1:
            raise RuntimeError(
                "The ContrastiveRegressionSSL constructor expects the input data to be 2-dimensional; "
                f"got {len(num_f_maps) + 1} dimensions"
            )
        num_f_maps = int(num_f_maps[0])
        self.loss_function = _ContrastiveRegressionLoss(
            temperature, distance, break_factor
        )
        # keyword arguments for `_FC` (see `construct_module`)
        self.pars = {
            "num_f_maps": num_f_maps,
            "num_ssl_layers": num_ssl_layers,
            "num_ssl_f_maps": num_features,
            "dim": num_features,
        }

    def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
        """
        Contrastive regression loss between the two elements of `predicted`
        (`target` is unused)
        """

        features1, features2 = predicted
        loss = self.loss_function(features1, features2)
        return loss

    def transformation(self, sample_data: Dict) -> Tuple:
        """
        Empty transformation: the SSL input and target are both NaN placeholders
        """

        return torch.tensor(float("nan")), torch.tensor(float("nan"))

    def construct_module(self) -> Union[nn.Module, None]:
        """
        Build the fully connected SSL head (`_FC`) from `self.pars`
        """

        return _FC(**self.pars)
class ContrastiveSSL(dlc2action.ssl.base_ssl.SSLConstructor):
class ContrastiveSSL(SSLConstructor):
    """
    Contrastive SSL constructor optimised with the NT-Xent loss

    Nothing is prepared as SSL input or target (both are NaN placeholders); the second
    view of each sample is produced at runtime as an augmentation of the input.
    """

    type = "contrastive"

    def __init__(
        self,
        num_f_maps: torch.Size,
        len_segment: int,
        ssl_features: int = 128,
        tau: float = 1,
    ) -> None:
        """
        Parameters
        ----------
        num_f_maps : torch.Size
            shape of the feature extractor output; only its first entry is used and
            shapes with more than one entry are rejected
        len_segment : int
            length of segment in the base feature extractor output
        ssl_features : int, default 128
            the final number of features per clip
        tau : float, default 1
            the tau parameter of NT-Xent loss

        Raises
        ------
        RuntimeError
            if `num_f_maps` has more than one entry
        """
        super().__init__()
        self.loss_function = _NTXent(tau)
        n_dims = len(num_f_maps)
        if n_dims > 1:
            raise RuntimeError(
                "The ContrastiveSSL constructor expects the input data to be 2-dimensional; "
                f"got {n_dims + 1} dimensions"
            )
        channels = int(num_f_maps[0])
        # keyword arguments for `_FeatureExtractorTCN` (see `construct_module`)
        self.pars = dict(
            num_f_maps=channels,
            len_segment=len_segment,
            output_dim=ssl_features,
            kernel_1=5,
            kernel_2=5,
            stride=2,
            decrease_f_maps=True,
        )

    def transformation(self, sample_data: Dict) -> Tuple:
        """
        Empty transformation: the SSL input and target are both NaN placeholders
        """
        placeholder = float("nan")
        return torch.tensor(placeholder), torch.tensor(placeholder)

    def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
        """
        NT-Xent loss between the two views in `predicted` (`target` is unused)
        """
        view_a, view_b = predicted
        return self.loss_function(view_a, view_b)

    def construct_module(self) -> Union[nn.Module, None]:
        """
        Build the clip-wise TCN feature extractor from `self.pars`
        """
        return _FeatureExtractorTCN(**self.pars)

A contrastive SSL class with an NT-Xent loss

The SSL input and target are left empty (the SSL input is generated as an augmentation of the input sample at runtime).

ContrastiveSSL( num_f_maps: torch.Size, len_segment: int, ssl_features: int = 128, tau: float = 1)
def __init__(
    self,
    num_f_maps: torch.Size,
    len_segment: int,
    ssl_features: int = 128,
    tau: float = 1,
) -> None:
    """
    Parameters
    ----------
    num_f_maps : torch.Size
        shape of the feature extractor output; only its first entry is used and
        shapes with more than one entry are rejected
    len_segment : int
        length of segment in the base feature extractor output
    ssl_features : int, default 128
        the final number of features per clip
    tau : float, default 1
        the tau parameter of NT-Xent loss

    Raises
    ------
    RuntimeError
        if `num_f_maps` has more than one entry
    """
    super().__init__()
    self.loss_function = _NTXent(tau)
    n_dims = len(num_f_maps)
    if n_dims > 1:
        raise RuntimeError(
            "The ContrastiveSSL constructor expects the input data to be 2-dimensional; "
            f"got {n_dims + 1} dimensions"
        )
    channels = int(num_f_maps[0])
    # keyword arguments for `_FeatureExtractorTCN` (see `construct_module`)
    self.pars = dict(
        num_f_maps=channels,
        len_segment=len_segment,
        output_dim=ssl_features,
        kernel_1=5,
        kernel_2=5,
        stride=2,
        decrease_f_maps=True,
    )

Parameters

  • num_f_maps : torch.Size — shape of the feature extractor output
  • len_segment : int — length of segment in the base feature extractor output
  • ssl_features : int, default 128 — the final number of features per clip
  • tau : float, default 1 — the tau parameter of NT-Xent loss

type = 'contrastive'

The type parameter defines interaction with the model:

  • 'ssl_input': a modification of the input data passes through the base network feature extraction module and the SSL module; it is returned as SSL output and compared to SSL target (or, if it is None, to the input data),
  • 'ssl_output': the input data passes through the base network feature extraction module and the SSL module; it is returned as SSL output and compared to SSL target (or, if it is None, to the input data),
  • 'contrastive': the input data and its modification pass through the base network feature extraction module and the SSL module; an (input results, modification results) tuple is returned as SSL output,
  • 'contrastive_2layers': the input data and its modification pass through the base network feature extraction module; the output of the second feature extraction layer for the modified data goes through an SSL module and then, optionally, that result and the first-level unmodified features pass another transformation; an (input results, modified results) tuple is returned as SSL output,
def transformation(self, sample_data: Dict) -> Tuple:
def transformation(self, sample_data: Dict) -> Tuple:
    """
    Empty transformation: the SSL input and target are both NaN placeholders
    """
    placeholder = float("nan")
    return torch.tensor(placeholder), torch.tensor(placeholder)

Empty transformation

def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
    """
    NT-Xent loss between the two views in `predicted` (`target` is unused)
    """
    view_a, view_b = predicted
    return self.loss_function(view_a, view_b)

NT-Xent loss

def construct_module(self) -> Optional[torch.nn.modules.module.Module]:
def construct_module(self) -> Union[nn.Module, None]:
    """
    Build the clip-wise TCN feature extractor from `self.pars`
    """
    return _FeatureExtractorTCN(**self.pars)

Clip-wise feature TCN extractor

class ContrastiveMaskedSSL(dlc2action.ssl.base_ssl.SSLConstructor):
 94class ContrastiveMaskedSSL(SSLConstructor):
 95    """
 96    A contrastive masked SSL class with an NT-Xent loss
 97
 98    A few frames in the middle of each segment are masked and then the output of the second layer of
 99    feature extraction for the segment is used to predict the output of the first layer for the missing frames.
100    The SSL input and target are left empty (the SSL input is generated as an augmentation of the
101    input sample at runtime).
102    """
103
104    type = "contrastive_2layers"
105
106    def __init__(
107        self,
108        num_f_maps: torch.Size,
109        len_segment: int,
110        ssl_features: int = 128,
111        tau: float = 1,
112        num_masked: int = 10,
113    ) -> None:
114        """
115        Parameters
116        ----------
117        num_f_maps : torch.Size
118            shape of feature extractor output
119        len_segment : int
120            length of segment in the base feature extractor output
121        ssl_features : int, default 128
122            the final number of features per clip
123        tau : float, default 1
124            the tau parameter of NT-Xent loss
125        """
126
127        super().__init__()
128        self.start = int(len_segment // 2 - num_masked // 2)
129        self.end = int(len_segment // 2 + num_masked // 2)
130        self.loss_function = _NTXent(tau)
131        if len(num_f_maps) > 1:
132            raise RuntimeError(
133                "The ContrastiveMaskedSSL constructor expects the input data to be 2-dimensional; "
134                f"got {len(num_f_maps) + 1} dimensions"
135            )
136        num_f_maps = int(num_f_maps[0])
137        self.pars = {
138            "num_f_maps": num_f_maps,
139            "len_segment": len_segment,
140            "output_dim": ssl_features,
141            "kernel_1": 3,
142            "kernel_2": 3,
143            "stride": 1,
144            "start": self.start,
145            "end": self.end,
146        }
147
148    def transformation(self, sample_data: Dict) -> Tuple:
149        """
150        Empty transformation
151        """
152
153        data = deepcopy(sample_data)
154        for key in data.keys():
155            data[key][:, self.start : self.end] = 0
156        return data, torch.tensor(float("nan"))
157        # return torch.tensor(float("nan")), torch.tensor(float("nan"))
158
159    def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
160        """
161        NT-Xent loss
162        """
163
164        features, ssl_features = predicted
165        loss = self.loss_function(features, ssl_features)
166        return loss
167
168    def construct_module(self) -> Union[nn.Module, None]:
169        """
170        Clip-wise feature TCN extractor
171        """
172
173        module = _MFeatureExtractorTCN(**self.pars)
174        return module

A contrastive masked SSL class with an NT-Xent loss

A few frames in the middle of each segment are masked and then the output of the second layer of feature extraction for the segment is used to predict the output of the first layer for the missing frames. The SSL input is a copy of the input sample with these frames set to zero; the SSL target is left empty.

ContrastiveMaskedSSL( num_f_maps: torch.Size, len_segment: int, ssl_features: int = 128, tau: float = 1, num_masked: int = 10)
def __init__(
    self,
    num_f_maps: torch.Size,
    len_segment: int,
    ssl_features: int = 128,
    tau: float = 1,
    num_masked: int = 10,
) -> None:
    """
    Parameters
    ----------
    num_f_maps : torch.Size
        shape of feature extractor output
    len_segment : int
        length of segment in the base feature extractor output
    ssl_features : int, default 128
        the final number of features per clip
    tau : float, default 1
        the tau parameter of NT-Xent loss
    num_masked : int, default 10
        the number of frames masked in the middle of each segment

    Raises
    ------
    RuntimeError
        if `num_f_maps` has more than one entry
    """

    super().__init__()
    # Masked window centred on the middle of the segment. `end` is derived from
    # `start` so the window is exactly `num_masked` frames long; computing it as
    # `len_segment // 2 + num_masked // 2` would lose a frame for odd `num_masked`.
    self.start = int(len_segment // 2 - num_masked // 2)
    self.end = self.start + int(num_masked)
    self.loss_function = _NTXent(tau)
    if len(num_f_maps) > 1:
        raise RuntimeError(
            "The ContrastiveMaskedSSL constructor expects the input data to be 2-dimensional; "
            f"got {len(num_f_maps) + 1} dimensions"
        )
    num_f_maps = int(num_f_maps[0])
    # keyword arguments for `_MFeatureExtractorTCN` (see `construct_module`)
    self.pars = {
        "num_f_maps": num_f_maps,
        "len_segment": len_segment,
        "output_dim": ssl_features,
        "kernel_1": 3,
        "kernel_2": 3,
        "stride": 1,
        "start": self.start,
        "end": self.end,
    }

Parameters

  • num_f_maps : torch.Size — shape of the feature extractor output
  • len_segment : int — length of segment in the base feature extractor output
  • ssl_features : int, default 128 — the final number of features per clip
  • tau : float, default 1 — the tau parameter of NT-Xent loss
  • num_masked : int, default 10 — the number of frames masked in the middle of each segment

type = 'contrastive_2layers'

The type parameter defines interaction with the model:

  • 'ssl_input': a modification of the input data passes through the base network feature extraction module and the SSL module; it is returned as SSL output and compared to SSL target (or, if it is None, to the input data),
  • 'ssl_output': the input data passes through the base network feature extraction module and the SSL module; it is returned as SSL output and compared to SSL target (or, if it is None, to the input data),
  • 'contrastive': the input data and its modification pass through the base network feature extraction module and the SSL module; an (input results, modification results) tuple is returned as SSL output,
  • 'contrastive_2layers': the input data and its modification pass through the base network feature extraction module; the output of the second feature extraction layer for the modified data goes through an SSL module and then, optionally, that result and the first-level unmodified features pass another transformation; an (input results, modified results) tuple is returned as SSL output,
def transformation(self, sample_data: Dict) -> Tuple:
148    def transformation(self, sample_data: Dict) -> Tuple:
149        """
150        Empty transformation
151        """
152
153        data = deepcopy(sample_data)
154        for key in data.keys():
155            data[key][:, self.start : self.end] = 0
156        return data, torch.tensor(float("nan"))
157        # return torch.tensor(float("nan")), torch.tensor(float("nan"))

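Mask the middle frames: returns a deep copy of the sample in which frames start:end of every array are set to zero (the SSL input), together with a NaN placeholder as the SSL target.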

def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
    """
    NT-Xent loss between the two elements of `predicted` (`target` is unused)
    """
    clean_features, masked_features = predicted
    return self.loss_function(clean_features, masked_features)

NT-Xent loss

def construct_module(self) -> Optional[torch.nn.modules.module.Module]:
def construct_module(self) -> Union[nn.Module, None]:
    """
    Build the masked clip-wise TCN feature extractor from `self.pars`
    """
    return _MFeatureExtractorTCN(**self.pars)

Clip-wise feature TCN extractor

class PairwiseSSL(dlc2action.ssl.base_ssl.SSLConstructor):
class PairwiseSSL(SSLConstructor):
    """
    Pairwise SSL constructor optimised with a triplet or a circle loss

    Both the SSL input and the SSL target are NaN placeholders; the paired view of each
    sample is produced at runtime as an augmentation of the input.
    """

    type = "contrastive"

    def __init__(
        self,
        num_f_maps: torch.Size,
        len_segment: int,
        ssl_features: int = 128,
        margin: float = 0,
        distance: str = "cosine",
        loss: str = "triplet",
        gamma: float = 1,
    ) -> None:
        """
        Parameters
        ----------
        num_f_maps : torch.Size
            shape of feature extractor output
        len_segment : int
            length of segment in feature extractor output
        ssl_features : int, default 128
            final number of features per clip
        margin : float, default 0
            the margin parameter of triplet or circle loss
        distance : {'cosine', 'euclidean'}
            the distance calculation method for triplet or circle loss
        loss : {'triplet', 'circle'}
            the loss function name
        gamma : float, default 1
            the gamma parameter of circle loss (ignored for the triplet loss)

        Raises
        ------
        ValueError
            if `loss` is neither "triplet" nor "circle"
        RuntimeError
            if `num_f_maps` has more than one entry
        """
        super().__init__()
        if loss not in ("triplet", "circle"):
            raise ValueError(
                f'The {loss} loss is unavailable, please choose from "triplet" and "circle"'
            )
        if loss == "triplet":
            self.loss_function = _TripletLoss(margin=margin, distance=distance)
        else:
            self.loss_function = _CircleLoss(
                margin=margin, gamma=gamma, distance=distance
            )
        n_dims = len(num_f_maps)
        if n_dims > 1:
            raise RuntimeError(
                "The PairwiseSSL constructor expects the input data to be 2-dimensional; "
                f"got {n_dims + 1} dimensions"
            )
        channels = int(num_f_maps[0])
        # keyword arguments for `_FeatureExtractorTCN` (see `construct_module`)
        self.pars = dict(
            num_f_maps=channels,
            len_segment=len_segment,
            output_dim=ssl_features,
            kernel_1=5,
            kernel_2=5,
            stride=2,
            decrease_f_maps=True,
        )

    def transformation(self, sample_data: Dict) -> Tuple:
        """
        Empty transformation: the SSL input and target are both NaN placeholders
        """
        placeholder = float("nan")
        return torch.tensor(placeholder), torch.tensor(placeholder)

    def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
        """
        Triplet or circle loss between the two views in `predicted` (`target` is unused)
        """
        view_a, view_b = predicted
        return self.loss_function(view_a, view_b)

    def construct_module(self) -> Union[nn.Module, None]:
        """
        Build the clip-wise TCN feature extractor from `self.pars`
        """
        return _FeatureExtractorTCN(**self.pars)

A pairwise SSL class with triplet or circle loss

The SSL input and target are left empty (the SSL input is generated as an augmentation of the input sample at runtime).

PairwiseSSL( num_f_maps: torch.Size, len_segment: int, ssl_features: int = 128, margin: float = 0, distance: str = 'cosine', loss: str = 'triplet', gamma: float = 1)
def __init__(
    self,
    num_f_maps: torch.Size,
    len_segment: int,
    ssl_features: int = 128,
    margin: float = 0,
    distance: str = "cosine",
    loss: str = "triplet",
    gamma: float = 1,
) -> None:
    """
    Parameters
    ----------
    num_f_maps : torch.Size
        shape of feature extractor output
    len_segment : int
        length of segment in feature extractor output
    ssl_features : int, default 128
        final number of features per clip
    margin : float, default 0
        the margin parameter of triplet or circle loss
    distance : {'cosine', 'euclidean'}
        the distance calculation method for triplet or circle loss
    loss : {'triplet', 'circle'}
        the loss function name
    gamma : float, default 1
        the gamma parameter of circle loss (ignored for the triplet loss)

    Raises
    ------
    ValueError
        if `loss` is neither "triplet" nor "circle"
    RuntimeError
        if `num_f_maps` has more than one entry
    """
    super().__init__()
    if loss not in ("triplet", "circle"):
        raise ValueError(
            f'The {loss} loss is unavailable, please choose from "triplet" and "circle"'
        )
    if loss == "triplet":
        self.loss_function = _TripletLoss(margin=margin, distance=distance)
    else:
        self.loss_function = _CircleLoss(
            margin=margin, gamma=gamma, distance=distance
        )
    n_dims = len(num_f_maps)
    if n_dims > 1:
        raise RuntimeError(
            "The PairwiseSSL constructor expects the input data to be 2-dimensional; "
            f"got {n_dims + 1} dimensions"
        )
    channels = int(num_f_maps[0])
    # keyword arguments for `_FeatureExtractorTCN` (see `construct_module`)
    self.pars = dict(
        num_f_maps=channels,
        len_segment=len_segment,
        output_dim=ssl_features,
        kernel_1=5,
        kernel_2=5,
        stride=2,
        decrease_f_maps=True,
    )

Parameters

  • num_f_maps : torch.Size — shape of the feature extractor output
  • len_segment : int — length of segment in the feature extractor output
  • ssl_features : int, default 128 — final number of features per clip
  • margin : float, default 0 — the margin parameter of triplet or circle loss
  • distance : {'cosine', 'euclidean'} — the distance calculation method for triplet or circle loss
  • loss : {'triplet', 'circle'} — the loss function name
  • gamma : float, default 1 — the gamma parameter of circle loss

type = 'contrastive'

The type parameter defines interaction with the model:

  • 'ssl_input': a modification of the input data passes through the base network feature extraction module and the SSL module; it is returned as SSL output and compared to SSL target (or, if it is None, to the input data),
  • 'ssl_output': the input data passes through the base network feature extraction module and the SSL module; it is returned as SSL output and compared to SSL target (or, if it is None, to the input data),
  • 'contrastive': the input data and its modification pass through the base network feature extraction module and the SSL module; an (input results, modification results) tuple is returned as SSL output,
  • 'contrastive_2layers': the input data and its modification pass through the base network feature extraction module; the output of the second feature extraction layer for the modified data goes through an SSL module and then, optionally, that result and the first-level unmodified features pass another transformation; an (input results, modified results) tuple is returned as SSL output.
def transformation(self, sample_data: Dict) -> Tuple:
def transformation(self, sample_data: Dict) -> Tuple:
    """
    Return placeholder SSL input and target

    Both returned values are scalar NaN tensors: nothing is derived from
    ``sample_data`` here, because the paired SSL input is generated as an
    augmentation of the sample at runtime.

    Parameters
    ----------
    sample_data : dict
        the input sample (not used)

    Returns
    -------
    ssl_input : torch.Tensor
        a scalar NaN placeholder
    ssl_target : torch.Tensor
        a scalar NaN placeholder
    """
    placeholder_input = torch.tensor(float("nan"))
    placeholder_target = torch.tensor(float("nan"))
    return placeholder_input, placeholder_target

Empty transformation

def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
    """
    Triplet or circle loss

    Parameters
    ----------
    predicted : tuple
        the (input results, modification results) pair returned as SSL output
    target : torch.Tensor
        not used (the SSL target of this constructor is empty)

    Returns
    -------
    loss : float
        the value of ``self.loss_function`` for the two feature sets
    """
    anchor_features, paired_features = predicted
    return self.loss_function(anchor_features, paired_features)

Triplet or circle loss

def construct_module(self) -> Optional[torch.nn.modules.module.Module]:
def construct_module(self) -> Union[nn.Module, None]:
    """
    Build the clip-wise TCN feature extractor used as the SSL module

    Returns
    -------
    module : nn.Module
        a ``_FeatureExtractorTCN`` instance configured with ``self.pars``
    """
    tcn_kwargs = self.pars
    return _FeatureExtractorTCN(**tcn_kwargs)

Clip-wise feature TCN extractor

class PairwiseMaskedSSL(PairwiseSSL):
class PairwiseMaskedSSL(PairwiseSSL):
    """
    A pairwise SSL class with triplet or circle loss and a masked SSL input

    The SSL input is a copy of the input sample in which a window of
    ``num_masked`` frames in the middle of the segment is set to zero; the SSL
    target is left empty.
    """

    type = "contrastive_2layers"

    def __init__(
        self,
        num_f_maps: torch.Size,
        len_segment: int,
        ssl_features: int = 128,
        margin: float = 0,
        distance: str = "cosine",
        loss: str = "triplet",
        gamma: float = 1,
        num_masked: int = 10,
    ) -> None:
        """
        Parameters
        ----------
        num_f_maps : torch.Size
            shape of feature extractor output
        len_segment : int
            length of segment in feature extractor output
        ssl_features : int, default 128
            final number of features per clip
        margin : float, default 0
            the margin parameter of triplet or circle loss
        distance : {'cosine', 'euclidean'}
            the distance calculation method for triplet or circle loss
        loss : {'triplet', 'circle'}
            the loss function name
        gamma : float, default 1
            the gamma parameter of circle loss
        num_masked : int, default 10
            the number of frames in the middle of the segment that are set to
            zero in the SSL input

        Raises
        ------
        RuntimeError
            if ``num_f_maps`` has more than one dimension
        """
        super().__init__(
            num_f_maps, len_segment, ssl_features, margin, distance, loss, gamma
        )
        self.num_masked = num_masked
        # The window starts num_masked // 2 frames before the middle of the
        # segment. Deriving ``end`` from ``start`` (instead of mirroring
        # ``num_masked // 2`` on the other side) makes the window exactly
        # ``num_masked`` frames long for odd values as well.
        self.start = int(len_segment // 2 - num_masked // 2)
        self.end = int(self.start + num_masked)
        if len(num_f_maps) > 1:
            raise RuntimeError(
                "The PairwiseMaskedSSL constructor expects the input data to be 2-dimensional; "
                f"got {len(num_f_maps) + 1} dimensions"
            )
        num_f_maps = int(num_f_maps[0])
        # Overrides the parameters set by PairwiseSSL.__init__: the masked
        # extractor also needs the bounds of the masked window.
        self.pars = {
            "num_f_maps": num_f_maps,
            "len_segment": len_segment,
            "output_dim": ssl_features,
            "kernel_1": 3,
            "kernel_2": 3,
            "stride": 1,
            "start": self.start,
            "end": self.end,
        }

    def transformation(self, sample_data: Dict) -> Tuple:
        """
        Mask the central frames of a sample to create the SSL input

        The sample is deep-copied so the caller's data is not modified, and in
        every entry of the copy the slice ``[:, self.start:self.end]`` is set
        to zero.

        Parameters
        ----------
        sample_data : dict
            the input sample (values presumably tensors with time along
            dimension 1 -- confirm)

        Returns
        -------
        ssl_input : dict
            the masked copy of ``sample_data``
        ssl_target : torch.Tensor
            a scalar NaN placeholder (the SSL target is empty)
        """
        masked = deepcopy(sample_data)
        for key in masked:
            masked[key][:, self.start : self.end] = 0
        return masked, torch.tensor(float("nan"))

    def construct_module(self) -> Union[nn.Module, None]:
        """
        Build the masked clip-wise TCN feature extractor used as the SSL module

        Returns
        -------
        module : nn.Module
            a ``_MFeatureExtractorTCN`` configured with ``self.pars`` (which
            include the ``start`` and ``end`` of the masked window)
        """
        return _MFeatureExtractorTCN(**self.pars)

A pairwise SSL class with triplet or circle loss

The SSL input is a copy of the input sample with a window of frames in the middle of the segment set to zero; the SSL target is left empty.

PairwiseMaskedSSL( num_f_maps: torch.Size, len_segment: int, ssl_features: int = 128, margin: float = 0, distance: str = 'cosine', loss: str = 'triplet', gamma: float = 1, num_masked: int = 10)
def __init__(
    self,
    num_f_maps: torch.Size,
    len_segment: int,
    ssl_features: int = 128,
    margin: float = 0,
    distance: str = "cosine",
    loss: str = "triplet",
    gamma: float = 1,
    num_masked: int = 10,
) -> None:
    """
    Parameters
    ----------
    num_f_maps : torch.Size
        shape of feature extractor output
    len_segment : int
        length of segment in feature extractor output
    ssl_features : int, default 128
        final number of features per clip
    margin : float, default 0
        the margin parameter of triplet or circle loss
    distance : {'cosine', 'euclidean'}
        the distance calculation method for triplet or circle loss
    loss : {'triplet', 'circle'}
        the loss function name
    gamma : float, default 1
        the gamma parameter of circle loss
    num_masked : int, default 10
        the number of frames in the middle of the segment that are set to
        zero in the SSL input

    Raises
    ------
    RuntimeError
        if ``num_f_maps`` has more than one dimension
    """
    super().__init__(
        num_f_maps, len_segment, ssl_features, margin, distance, loss, gamma
    )
    self.num_masked = num_masked
    # The window starts num_masked // 2 frames before the middle of the
    # segment. Deriving ``end`` from ``start`` keeps the window exactly
    # ``num_masked`` frames long for odd values too.
    self.start = int(len_segment // 2 - num_masked // 2)
    self.end = int(self.start + num_masked)
    if len(num_f_maps) > 1:
        raise RuntimeError(
            "The PairwiseMaskedSSL constructor expects the input data to be 2-dimensional; "
            f"got {len(num_f_maps) + 1} dimensions"
        )
    num_f_maps = int(num_f_maps[0])
    # The masked extractor also needs the bounds of the masked window.
    self.pars = {
        "num_f_maps": num_f_maps,
        "len_segment": len_segment,
        "output_dim": ssl_features,
        "kernel_1": 3,
        "kernel_2": 3,
        "stride": 1,
        "start": self.start,
        "end": self.end,
    }

Parameters

num_f_maps : torch.Size
    shape of feature extractor output
len_segment : int
    length of segment in feature extractor output
ssl_features : int, default 128
    final number of features per clip
margin : float, default 0
    the margin parameter of triplet or circle loss
distance : {'cosine', 'euclidean'}
    the distance calculation method for triplet or circle loss
loss : {'triplet', 'circle'}
    the loss function name
gamma : float, default 1
    the gamma parameter of circle loss
num_masked : int, default 10
    controls the size of the window of frames in the middle of the segment that is set to zero in the SSL input

type = 'contrastive_2layers'

The type parameter defines interaction with the model:

  • 'ssl_input': a modification of the input data passes through the base network feature extraction module and the SSL module; it is returned as SSL output and compared to SSL target (or, if it is None, to the input data),
  • 'ssl_output': the input data passes through the base network feature extraction module and the SSL module; it is returned as SSL output and compared to SSL target (or, if it is None, to the input data),
  • 'contrastive': the input data and its modification pass through the base network feature extraction module and the SSL module; an (input results, modification results) tuple is returned as SSL output,
  • 'contrastive_2layers': the input data and its modification pass through the base network feature extraction module; the output of the second feature extraction layer for the modified data goes through an SSL module and then, optionally, that result and the first-level unmodified features pass another transformation; an (input results, modified results) tuple is returned as SSL output.
def transformation(self, sample_data: Dict) -> Tuple:
def transformation(self, sample_data: Dict) -> Tuple:
    """
    Mask the central frames of a sample to create the SSL input

    The sample is deep-copied so the caller's data is not modified, and in
    every entry of the copy the slice ``[:, self.start:self.end]`` is set to
    zero.

    Parameters
    ----------
    sample_data : dict
        the input sample (values presumably tensors with time along
        dimension 1 -- confirm)

    Returns
    -------
    ssl_input : dict
        the masked copy of ``sample_data``
    ssl_target : torch.Tensor
        a scalar NaN placeholder (the SSL target is empty)
    """
    masked = deepcopy(sample_data)
    for key in masked:
        masked[key][:, self.start : self.end] = 0
    return masked, torch.tensor(float("nan"))

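Masking transformation: returns a deep copy of the sample in which the frames between `start` and `end` of every entry are set to zero, together with an empty (NaN) SSL target.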

def construct_module(self) -> Optional[torch.nn.modules.module.Module]:
def construct_module(self) -> Union[nn.Module, None]:
    """
    Build the masked clip-wise TCN feature extractor used as the SSL module

    Returns
    -------
    module : nn.Module
        a ``_MFeatureExtractorTCN`` configured with ``self.pars`` (which
        include the ``start`` and ``end`` of the masked window)
    """
    extractor_kwargs = self.pars
    return _MFeatureExtractorTCN(**extractor_kwargs)

Clip-wise feature TCN extractor

Inherited Members
PairwiseSSL
loss
class ContrastiveRegressionSSL(dlc2action.ssl.base_ssl.SSLConstructor):
class ContrastiveRegressionSSL(SSLConstructor):
    """
    A contrastive SSL class with a contrastive regression loss

    The SSL input and target are left empty (``transformation`` returns NaN
    placeholders); the pair of feature tensors compared by the loss is produced
    at runtime from the input data and its modification.
    """

    type = "contrastive"

    def __init__(
        self,
        num_f_maps: torch.Size,
        num_features: int = 128,
        num_ssl_layers: int = 1,
        distance: str = "cosine",
        temperature: float = 1,
        break_factor: Union[int, None] = None,
    ) -> None:
        """
        Parameters
        ----------
        num_f_maps : torch.Size
            shape of feature extractor output (must have a single element)
        num_features : int, default 128
            passed to the SSL module as both ``num_ssl_f_maps`` and ``dim``
        num_ssl_layers : int, default 1
            passed to the SSL module as ``num_ssl_layers``
        distance : str, default 'cosine'
            the distance parameter of the contrastive regression loss
        temperature : float, default 1
            the temperature parameter of the contrastive regression loss
        break_factor : int, optional
            the break factor parameter of the contrastive regression loss

        Raises
        ------
        RuntimeError
            if ``num_f_maps`` has more than one dimension
        """
        # Initialize the base class before setting attributes, as the other
        # constructors in this module do.
        super().__init__()
        if len(num_f_maps) > 1:
            raise RuntimeError(
                "The ContrastiveRegressionSSL constructor expects the input data to be 2-dimensional; "
                f"got {len(num_f_maps) + 1} dimensions"
            )
        num_f_maps = int(num_f_maps[0])
        self.loss_function = _ContrastiveRegressionLoss(
            temperature, distance, break_factor
        )
        self.pars = {
            "num_f_maps": num_f_maps,
            "num_ssl_layers": num_ssl_layers,
            "num_ssl_f_maps": num_features,
            "dim": num_features,
        }

    def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
        """
        Contrastive regression loss

        Parameters
        ----------
        predicted : tuple
            the (input results, modification results) pair returned as SSL output
        target : torch.Tensor
            not used (the SSL target of this constructor is empty)

        Returns
        -------
        loss : float
            the value of ``self.loss_function`` (a ``_ContrastiveRegressionLoss``)
            for the two feature sets
        """
        features1, features2 = predicted
        return self.loss_function(features1, features2)

    def transformation(self, sample_data: Dict) -> Tuple:
        """
        Empty transformation: return scalar NaN placeholders for the SSL input
        and target
        """
        return torch.tensor(float("nan")), torch.tensor(float("nan"))

    def construct_module(self) -> Union[nn.Module, None]:
        """
        Build the ``_FC`` SSL module configured with ``self.pars``
        """
        return _FC(**self.pars)

A base class for all SSL constructors

An SSL method is defined by four things: a transformation that maps a sample into SSL input and target, a neural net module that takes features as input and predicts SSL target, a type, and a loss function.

ContrastiveRegressionSSL( num_f_maps: torch.Size, num_features: int = 128, num_ssl_layers: int = 1, distance: str = 'cosine', temperature: float = 1, break_factor: int = None)
def __init__(
    self,
    num_f_maps: torch.Size,
    num_features: int = 128,
    num_ssl_layers: int = 1,
    distance: str = "cosine",
    temperature: float = 1,
    break_factor: Union[int, None] = None,
) -> None:
    """
    Parameters
    ----------
    num_f_maps : torch.Size
        shape of feature extractor output (must have a single element)
    num_features : int, default 128
        passed to the SSL module as both ``num_ssl_f_maps`` and ``dim``
    num_ssl_layers : int, default 1
        passed to the SSL module as ``num_ssl_layers``
    distance : str, default 'cosine'
        the distance parameter of the contrastive regression loss
    temperature : float, default 1
        the temperature parameter of the contrastive regression loss
    break_factor : int, optional
        the break factor parameter of the contrastive regression loss

    Raises
    ------
    RuntimeError
        if ``num_f_maps`` has more than one dimension
    """
    # Initialize the base class before setting attributes, as the other
    # constructors in this module do.
    super().__init__()
    if len(num_f_maps) > 1:
        raise RuntimeError(
            "The ContrastiveRegressionSSL constructor expects the input data to be 2-dimensional; "
            f"got {len(num_f_maps) + 1} dimensions"
        )
    num_f_maps = int(num_f_maps[0])
    self.loss_function = _ContrastiveRegressionLoss(
        temperature, distance, break_factor
    )
    self.pars = {
        "num_f_maps": num_f_maps,
        "num_ssl_layers": num_ssl_layers,
        "num_ssl_f_maps": num_features,
        "dim": num_features,
    }
type = 'contrastive'

The type parameter defines interaction with the model:

  • 'ssl_input': a modification of the input data passes through the base network feature extraction module and the SSL module; it is returned as SSL output and compared to SSL target (or, if it is None, to the input data),
  • 'ssl_output': the input data passes through the base network feature extraction module and the SSL module; it is returned as SSL output and compared to SSL target (or, if it is None, to the input data),
  • 'contrastive': the input data and its modification pass through the base network feature extraction module and the SSL module; an (input results, modification results) tuple is returned as SSL output,
  • 'contrastive_2layers': the input data and its modification pass through the base network feature extraction module; the output of the second feature extraction layer for the modified data goes through an SSL module and then, optionally, that result and the first-level unmodified features pass another transformation; an (input results, modified results) tuple is returned as SSL output.
def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
    """
    Contrastive regression loss

    Parameters
    ----------
    predicted : tuple
        the (input results, modification results) pair returned as SSL output
    target : torch.Tensor
        not used (the SSL target of this constructor is empty)

    Returns
    -------
    loss : float
        the value of ``self.loss_function`` for the two feature sets
    """
    features1, features2 = predicted
    return self.loss_function(features1, features2)

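Contrastive regression loss between the two feature tensors of the SSL output (computed with `_ContrastiveRegressionLoss`)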

def transformation(self, sample_data: Dict) -> Tuple:
def transformation(self, sample_data: Dict) -> Tuple:
    """
    Return placeholder SSL input and target

    Nothing is derived from ``sample_data``; both returned values are scalar
    NaN tensors.

    Parameters
    ----------
    sample_data : dict
        the input sample (not used)

    Returns
    -------
    ssl_input, ssl_target : torch.Tensor
        scalar NaN placeholders
    """
    empty_input = torch.tensor(float("nan"))
    empty_target = torch.tensor(float("nan"))
    return empty_input, empty_target

Empty transformation

def construct_module(self) -> Optional[torch.nn.modules.module.Module]:
def construct_module(self) -> Union[nn.Module, None]:
    """
    Build the SSL module

    Returns
    -------
    module : nn.Module
        an ``_FC`` instance configured with ``self.pars`` (``num_f_maps``,
        ``num_ssl_layers``, ``num_ssl_f_maps`` and ``dim``)
    """
    return _FC(**self.pars)

Construct the `_FC` SSL module from the stored parameters