dlc2action.ssl.contrastive

Implementations of dlc2action.ssl.base_ssl.SSLConstructor of the 'contrastive' type.

  1#
  2# Copyright 2020-present by A. Mathis Group and contributors. All rights reserved.
  3#
  4# This project and all its files are licensed under GNU AGPLv3 or later version. 
  5# A copy is included in dlc2action/LICENSE.AGPL.
  6#
  7"""Implementations of `dlc2action.ssl.base_ssl.SSLConstructor` of the `'contrastive'` type."""
  8
  9from typing import Dict, Tuple, Union
 10from dlc2action.ssl.base_ssl import SSLConstructor
 11from dlc2action.loss.contrastive import *
 12from dlc2action.loss.contrastive_frame import *
 13from dlc2action.ssl.modules import *
 14from copy import deepcopy
 15
 16
 17class ContrastiveSSL(SSLConstructor):
 18    """A contrastive SSL class with an NT-Xent loss.
 19
 20    The SSL input and target are left empty (the SSL input is generated as an augmentation of the
 21    input sample at runtime).
 22    """
 23
 24    type = "contrastive"
 25
 26    def __init__(
 27        self,
 28        num_f_maps: torch.Size,
 29        len_segment: int,
 30        ssl_features: int = 128,
 31        tau: float = 1,
 32    ) -> None:
 33        """Initialize the SSL constructor.
 34
 35        Parameters
 36        ----------
 37        num_f_maps : torch.Size
 38            shape of feature extractor output
 39        len_segment : int
 40            length of segment in the base feature extractor output
 41        ssl_features : int, default 128
 42            the final number of features per clip
 43        tau : float, default 1
 44            the tau parameter of NT-Xent loss
 45
 46        """
 47        super().__init__()
 48        self.loss_function = NTXent(tau)
 49        if len(num_f_maps) > 1:
 50            raise RuntimeError(
 51                "The ContrastiveSSL constructor expects the input data to be 2-dimensional; "
 52                f"got {len(num_f_maps) + 1} dimensions"
 53            )
 54        num_f_maps = int(num_f_maps[0])
 55        self.pars = {
 56            "num_f_maps": num_f_maps,
 57            "len_segment": len_segment,
 58            "output_dim": ssl_features,
 59            "kernel_1": 5,
 60            "kernel_2": 5,
 61            "stride": 2,
 62            "decrease_f_maps": True,
 63        }
 64
 65    def transformation(self, sample_data: Dict) -> Tuple:
 66        """Empty transformation."""
 67        return torch.tensor(float("nan")), torch.tensor(float("nan"))
 68
 69    def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
 70        """NT-Xent loss."""
 71        features1, features2 = predicted
 72        loss = self.loss_function(features1, features2)
 73        return loss
 74
 75    def construct_module(self) -> Union[nn.Module, None]:
 76        """Clip-wise feature TCN extractor."""
 77        module = FeatureExtractorTCN(**self.pars)
 78        return module
 79
 80
 81class ContrastiveMaskedSSL(SSLConstructor):
 82    """A contrastive masked SSL class with an NT-Xent loss.
 83
 84    A few frames in the middle of each segment are masked and then the output of the second layer of
 85    feature extraction for the segment is used to predict the output of the first layer for the missing frames.
 86    The SSL input and target are left empty (the SSL input is generated as an augmentation of the
 87    input sample at runtime).
 88    """
 89
 90    type = "contrastive_2layers"
 91
 92    def __init__(
 93        self,
 94        num_f_maps: torch.Size,
 95        len_segment: int,
 96        ssl_features: int = 128,
 97        tau: float = 1,
 98        num_masked: int = 10,
 99    ) -> None:
100        """Initialize the ContrastiveMaskedSSL class.
101
102        Parameters
103        ----------
104        num_f_maps : torch.Size
105            shape of feature extractor output
106        len_segment : int
107            length of segment in the base feature extractor output
108        ssl_features : int, default 128
109            the final number of features per clip
110        tau : float, default 1
111            the tau parameter of NT-Xent loss
112        num_masked : int, default 10
113            number of frames to be masked in the middle of each segment
114
115        """
116        super().__init__()
117        self.start = int(len_segment // 2 - num_masked // 2)
118        self.end = int(len_segment // 2 + num_masked // 2)
119        self.loss_function = NTXent(tau)
120        if len(num_f_maps) > 1:
121            raise RuntimeError(
122                "The ContrastiveMaskedSSL constructor expects the input data to be 2-dimensional; "
123                f"got {len(num_f_maps) + 1} dimensions"
124            )
125        num_f_maps = int(num_f_maps[0])
126        self.pars = {
127            "num_f_maps": num_f_maps,
128            "len_segment": len_segment,
129            "output_dim": ssl_features,
130            "kernel_1": 3,
131            "kernel_2": 3,
132            "stride": 1,
133            "start": self.start,
134            "end": self.end,
135        }
136
137    def transformation(self, sample_data: Dict) -> Tuple:
138        """Mask the input data."""
139        data = deepcopy(sample_data)
140        for key in data.keys():
141            data[key][:, self.start : self.end] = 0
142        return data, torch.tensor(float("nan"))
143        # return torch.tensor(float("nan")), torch.tensor(float("nan"))
144
145    def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
146        """NT-Xent loss."""
147        features, ssl_features = predicted
148        loss = self.loss_function(features, ssl_features)
149        return loss
150
151    def construct_module(self) -> Union[nn.Module, None]:
152        """Clip-wise feature TCN extractor."""
153        module = MFeatureExtractorTCN(**self.pars)
154        return module
155
156
class PairwiseSSL(SSLConstructor):
    """Pairwise SSL constructor trained with a triplet or a circle loss.

    `transformation` returns NaN placeholders for the SSL input and target;
    the augmented view is generated from the input sample at runtime.
    """

    type = "contrastive"

    def __init__(
        self,
        num_f_maps: torch.Size,
        len_segment: int,
        ssl_features: int = 128,
        margin: float = 0,
        distance: str = "cosine",
        loss: str = "triplet",
        gamma: float = 1,
    ) -> None:
        """Pick the pairwise loss and store the TCN extractor parameters.

        Parameters
        ----------
        num_f_maps : torch.Size
            shape of feature extractor output; only a single dimension is accepted
        len_segment : int
            length of segment in feature extractor output
        ssl_features : int, default 128
            final number of features per clip
        margin : float, default 0
            the margin parameter of triplet or circle loss
        distance : {'cosine', 'euclidean'}
            the distance calculation method for triplet or circle loss
        loss : {'triplet', 'circle'}
            the loss function name
        gamma : float, default 1
            the gamma parameter of circle loss (ignored by the triplet loss)

        Raises
        ------
        ValueError
            if `loss` is neither "triplet" nor "circle"
        RuntimeError
            if `num_f_maps` has more than one dimension

        """
        super().__init__()
        if loss not in ("triplet", "circle"):
            raise ValueError(
                f'The {loss} loss is unavailable, please choose from "triplet" and "circle"'
            )
        if loss == "circle":
            self.loss_function = CircleLoss(
                margin=margin, gamma=gamma, distance=distance
            )
        else:
            self.loss_function = TripletLoss(margin=margin, distance=distance)
        n_dims = len(num_f_maps)
        if n_dims > 1:
            raise RuntimeError(
                "The PairwiseSSL constructor expects the input data to be 2-dimensional; "
                f"got {n_dims + 1} dimensions"
            )
        # Keyword arguments forwarded to FeatureExtractorTCN in construct_module.
        self.pars = dict(
            num_f_maps=int(num_f_maps[0]),
            len_segment=len_segment,
            output_dim=ssl_features,
            kernel_1=5,
            kernel_2=5,
            stride=2,
            decrease_f_maps=True,
        )

    def transformation(self, sample_data: Dict) -> Tuple:
        """Return NaN placeholders as SSL input and target; `sample_data` is ignored."""
        placeholder = float("nan")
        return torch.tensor(placeholder), torch.tensor(placeholder)

    def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
        """Apply the configured triplet/circle loss to the two predicted feature sets."""
        first_view, second_view = predicted
        return self.loss_function(first_view, second_view)

    def construct_module(self) -> Union[nn.Module, None]:
        """Build the clip-wise TCN feature extractor from `self.pars`."""
        return FeatureExtractorTCN(**self.pars)
237
238
class PairwiseMaskedSSL(PairwiseSSL):
    """A contrastive SSL class with triplet or circle loss and masked input.

    A few frames in the middle of each segment are masked and then the output of the second layer of
    feature extraction for the segment is used to predict the output of the first layer for the missing frames.
    The SSL target is left empty (a NaN placeholder); the SSL input is a copy of the
    input sample with the central window zeroed out.
    """

    type = "contrastive_2layers"

    def __init__(
        self,
        num_f_maps: torch.Size,
        len_segment: int,
        ssl_features: int = 128,
        margin: float = 0,
        distance: str = "cosine",
        loss: str = "triplet",
        gamma: float = 1,
        num_masked: int = 10,
    ) -> None:
        """Initialize the PairwiseMaskedSSL class.

        Parameters
        ----------
        num_f_maps : torch.Size
            shape of feature extractor output; only a single dimension is accepted
        len_segment : int
            length of segment in feature extractor output
        ssl_features : int, default 128
            final number of features per clip
        margin : float, default 0
            the margin parameter of triplet or circle loss
        distance : {'cosine', 'euclidean'}
            the distance calculation method for triplet or circle loss
        loss : {'triplet', 'circle'}
            the loss function name
        gamma : float, default 1
            the gamma parameter of circle loss
        num_masked : int, default 10
            number of masked frames

        Raises
        ------
        RuntimeError
            if `num_f_maps` has more than one dimension
        ValueError
            if `loss` is neither "triplet" nor "circle" (raised by PairwiseSSL)

        """
        # Validate before delegating: PairwiseSSL.__init__ runs the same check,
        # so checking afterwards would make this class-specific message unreachable.
        if len(num_f_maps) > 1:
            raise RuntimeError(
                "The PairwiseMaskedSSL constructor expects the input data to be 2-dimensional; "
                f"got {len(num_f_maps) + 1} dimensions"
            )
        super().__init__(
            num_f_maps, len_segment, ssl_features, margin, distance, loss, gamma
        )
        self.num_masked = num_masked
        # Centre the masked window; deriving `end` from `start` masks exactly
        # `num_masked` frames, including when `num_masked` is odd.
        self.start = int(len_segment // 2 - num_masked // 2)
        self.end = int(self.start + num_masked)
        num_f_maps = int(num_f_maps[0])
        # Replaces the parent's parameters: MFeatureExtractorTCN is used instead.
        self.pars = {
            "num_f_maps": num_f_maps,
            "len_segment": len_segment,
            "output_dim": ssl_features,
            "kernel_1": 3,
            "kernel_2": 3,
            "stride": 1,
            "start": self.start,
            "end": self.end,
        }

    def transformation(self, sample_data: Dict) -> Tuple:
        """Mask the input data.

        Returns a deep copy of `sample_data` in which positions
        `self.start:self.end` along axis 1 of every value are set to 0
        (presumably the time axis — TODO confirm), plus a NaN placeholder target.
        """
        data = deepcopy(sample_data)  # never modify the caller's sample in place
        for key in data.keys():
            data[key][:, self.start : self.end] = 0
        return data, torch.tensor(float("nan"))

    def construct_module(self) -> Union[nn.Module, None]:
        """Build the masked clip-wise TCN feature extractor from `self.pars`."""
        module = MFeatureExtractorTCN(**self.pars)
        return module
317
318
class ContrastiveRegressionSSL(SSLConstructor):
    """Contrastive SSL class with a contrastive regression loss.

    `transformation` returns NaN placeholders for the SSL input and target;
    the loss compares the two predicted feature sets.
    """

    type = "contrastive"

    def __init__(
        self,
        num_f_maps: torch.Size,
        num_features: int = 128,
        num_ssl_layers: int = 1,
        distance: str = "cosine",
        temperature: float = 1,
        break_factor: Union[int, None] = None,
    ) -> None:
        """Initialize the ContrastiveRegressionSSL class.

        Parameters
        ----------
        num_f_maps : torch.Size
            shape of feature extractor output; only a single dimension is accepted
        num_features : int, default 128
            final number of features per clip (used both as the number of SSL
            feature maps and as the output dimension of the `FC` module)
        num_ssl_layers : int, default 1
            number of SSL layers
        distance : str, default "cosine"
            the distance calculation method of the contrastive regression loss
        temperature : float, default 1
            the temperature parameter of contrastive loss
        break_factor : int, optional
            the break factor parameter of contrastive loss

        Raises
        ------
        RuntimeError
            if `num_f_maps` has more than one dimension

        """
        # Initialise the base class first, as the sibling constructors do; assigning
        # attributes (e.g. loss modules) before it breaks if the base is an nn.Module.
        super().__init__()
        if len(num_f_maps) > 1:
            raise RuntimeError(
                "The ContrastiveRegressionSSL constructor expects the input data to be 2-dimensional; "
                f"got {len(num_f_maps) + 1} dimensions"
            )
        num_f_maps = int(num_f_maps[0])
        self.loss_function = ContrastiveRegressionLoss(
            temperature, distance, break_factor
        )
        # Keyword arguments forwarded to FC in construct_module.
        self.pars = {
            "num_f_maps": num_f_maps,
            "num_ssl_layers": num_ssl_layers,
            "num_ssl_f_maps": num_features,
            "dim": num_features,
            "ssl_input": False,
        }

    def loss(
        self, predicted: Tuple[torch.Tensor, torch.Tensor], target: torch.Tensor
    ) -> torch.Tensor:
        """Contrastive regression loss between the two predicted feature sets; `target` is unused."""
        features1, features2 = predicted
        loss = self.loss_function(features1, features2)
        return loss

    def transformation(self, sample_data: Dict) -> Tuple:
        """Empty transformation: NaN placeholders for SSL input and target."""
        return torch.tensor(float("nan")), torch.tensor(float("nan"))

    def construct_module(self) -> Union[nn.Module, None]:
        """Build the `FC` SSL module from `self.pars`."""
        return FC(**self.pars)
class ContrastiveSSL(dlc2action.ssl.base_ssl.SSLConstructor):
class ContrastiveSSL(SSLConstructor):
    """Contrastive SSL constructor that uses the NT-Xent loss.

    No SSL input or target is prepared by `transformation` (it returns NaN
    placeholders); the augmented view is generated from the input sample
    at runtime.
    """

    type = "contrastive"

    def __init__(
        self,
        num_f_maps: torch.Size,
        len_segment: int,
        ssl_features: int = 128,
        tau: float = 1,
    ) -> None:
        """Create the NT-Xent loss and store the TCN extractor parameters.

        Parameters
        ----------
        num_f_maps : torch.Size
            shape of feature extractor output; only a single dimension is accepted
        len_segment : int
            length of segment in the base feature extractor output
        ssl_features : int, default 128
            the final number of features per clip
        tau : float, default 1
            the tau parameter of NT-Xent loss

        Raises
        ------
        RuntimeError
            if `num_f_maps` has more than one dimension

        """
        super().__init__()
        self.loss_function = NTXent(tau)
        n_dims = len(num_f_maps)
        if n_dims > 1:
            raise RuntimeError(
                "The ContrastiveSSL constructor expects the input data to be 2-dimensional; "
                f"got {n_dims + 1} dimensions"
            )
        # Keyword arguments forwarded to FeatureExtractorTCN in construct_module.
        self.pars = dict(
            num_f_maps=int(num_f_maps[0]),
            len_segment=len_segment,
            output_dim=ssl_features,
            kernel_1=5,
            kernel_2=5,
            stride=2,
            decrease_f_maps=True,
        )

    def transformation(self, sample_data: Dict) -> Tuple:
        """Return NaN placeholders as SSL input and target; `sample_data` is ignored."""
        placeholder = float("nan")
        return torch.tensor(placeholder), torch.tensor(placeholder)

    def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
        """Apply NT-Xent to the two predicted feature sets; `target` is unused."""
        first_view, second_view = predicted
        return self.loss_function(first_view, second_view)

    def construct_module(self) -> Union[nn.Module, None]:
        """Build the clip-wise TCN feature extractor from `self.pars`."""
        return FeatureExtractorTCN(**self.pars)

A contrastive SSL class with an NT-Xent loss.

The SSL input and target are left empty (the SSL input is generated as an augmentation of the input sample at runtime).

ContrastiveSSL( num_f_maps: torch.Size, len_segment: int, ssl_features: int = 128, tau: float = 1)
27    def __init__(
28        self,
29        num_f_maps: torch.Size,
30        len_segment: int,
31        ssl_features: int = 128,
32        tau: float = 1,
33    ) -> None:
34        """Initialize the SSL constructor.
35
36        Parameters
37        ----------
38        num_f_maps : torch.Size
39            shape of feature extractor output
40        len_segment : int
41            length of segment in the base feature extractor output
42        ssl_features : int, default 128
43            the final number of features per clip
44        tau : float, default 1
45            the tau parameter of NT-Xent loss
46
47        """
48        super().__init__()
49        self.loss_function = NTXent(tau)
50        if len(num_f_maps) > 1:
51            raise RuntimeError(
52                "The ContrastiveSSL constructor expects the input data to be 2-dimensional; "
53                f"got {len(num_f_maps) + 1} dimensions"
54            )
55        num_f_maps = int(num_f_maps[0])
56        self.pars = {
57            "num_f_maps": num_f_maps,
58            "len_segment": len_segment,
59            "output_dim": ssl_features,
60            "kernel_1": 5,
61            "kernel_2": 5,
62            "stride": 2,
63            "decrease_f_maps": True,
64        }

Initialize the SSL constructor.

Parameters

  • num_f_maps : torch.Size — shape of the feature extractor output
  • len_segment : int — length of segment in the base feature extractor output
  • ssl_features : int, default 128 — the final number of features per clip
  • tau : float, default 1 — the tau parameter of the NT-Xent loss

type = 'contrastive'

The type parameter defines interaction with the model:

  • 'ssl_input': a modification of the input data passes through the base network feature extraction module and the SSL module; it is returned as SSL output and compared to SSL target (or, if it is None, to the input data),
  • 'ssl_output': the input data passes through the base network feature extraction module and the SSL module; it is returned as SSL output and compared to SSL target (or, if it is None, to the input data),
  • 'contrastive': the input data and its modification pass through the base network feature extraction module and the SSL module; an (input results, modification results) tuple is returned as SSL output,
  • 'contrastive_2layers': the input data and its modification pass through the base network feature extraction module; the output of the second feature extraction layer for the modified data goes through an SSL module and then, optionally, that result and the first-level unmodified features pass another transformation; an (input results, modified results) tuple is returned as SSL output.
loss_function
pars
def transformation(self, sample_data: Dict) -> Tuple:
66    def transformation(self, sample_data: Dict) -> Tuple:
67        """Empty transformation."""
68        return torch.tensor(float("nan")), torch.tensor(float("nan"))

Empty transformation.

def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
70    def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
71        """NT-Xent loss."""
72        features1, features2 = predicted
73        loss = self.loss_function(features1, features2)
74        return loss

NT-Xent loss.

def construct_module(self) -> Optional[torch.nn.modules.module.Module]:
76    def construct_module(self) -> Union[nn.Module, None]:
77        """Clip-wise feature TCN extractor."""
78        module = FeatureExtractorTCN(**self.pars)
79        return module

Clip-wise feature TCN extractor.

class ContrastiveMaskedSSL(dlc2action.ssl.base_ssl.SSLConstructor):
 82class ContrastiveMaskedSSL(SSLConstructor):
 83    """A contrastive masked SSL class with an NT-Xent loss.
 84
 85    A few frames in the middle of each segment are masked and then the output of the second layer of
 86    feature extraction for the segment is used to predict the output of the first layer for the missing frames.
 87    The SSL input and target are left empty (the SSL input is generated as an augmentation of the
 88    input sample at runtime).
 89    """
 90
 91    type = "contrastive_2layers"
 92
 93    def __init__(
 94        self,
 95        num_f_maps: torch.Size,
 96        len_segment: int,
 97        ssl_features: int = 128,
 98        tau: float = 1,
 99        num_masked: int = 10,
100    ) -> None:
101        """Initialize the ContrastiveMaskedSSL class.
102
103        Parameters
104        ----------
105        num_f_maps : torch.Size
106            shape of feature extractor output
107        len_segment : int
108            length of segment in the base feature extractor output
109        ssl_features : int, default 128
110            the final number of features per clip
111        tau : float, default 1
112            the tau parameter of NT-Xent loss
113        num_masked : int, default 10
114            number of frames to be masked in the middle of each segment
115
116        """
117        super().__init__()
118        self.start = int(len_segment // 2 - num_masked // 2)
119        self.end = int(len_segment // 2 + num_masked // 2)
120        self.loss_function = NTXent(tau)
121        if len(num_f_maps) > 1:
122            raise RuntimeError(
123                "The ContrastiveMaskedSSL constructor expects the input data to be 2-dimensional; "
124                f"got {len(num_f_maps) + 1} dimensions"
125            )
126        num_f_maps = int(num_f_maps[0])
127        self.pars = {
128            "num_f_maps": num_f_maps,
129            "len_segment": len_segment,
130            "output_dim": ssl_features,
131            "kernel_1": 3,
132            "kernel_2": 3,
133            "stride": 1,
134            "start": self.start,
135            "end": self.end,
136        }
137
138    def transformation(self, sample_data: Dict) -> Tuple:
139        """Mask the input data."""
140        data = deepcopy(sample_data)
141        for key in data.keys():
142            data[key][:, self.start : self.end] = 0
143        return data, torch.tensor(float("nan"))
144        # return torch.tensor(float("nan")), torch.tensor(float("nan"))
145
146    def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
147        """NT-Xent loss."""
148        features, ssl_features = predicted
149        loss = self.loss_function(features, ssl_features)
150        return loss
151
152    def construct_module(self) -> Union[nn.Module, None]:
153        """Clip-wise feature TCN extractor."""
154        module = MFeatureExtractorTCN(**self.pars)
155        return module

A contrastive masked SSL class with an NT-Xent loss.

A few frames in the middle of each segment are masked and then the output of the second layer of feature extraction for the segment is used to predict the output of the first layer for the missing frames. The SSL input and target are left empty (the SSL input is generated as an augmentation of the input sample at runtime).

ContrastiveMaskedSSL( num_f_maps: torch.Size, len_segment: int, ssl_features: int = 128, tau: float = 1, num_masked: int = 10)
 93    def __init__(
 94        self,
 95        num_f_maps: torch.Size,
 96        len_segment: int,
 97        ssl_features: int = 128,
 98        tau: float = 1,
 99        num_masked: int = 10,
100    ) -> None:
101        """Initialize the ContrastiveMaskedSSL class.
102
103        Parameters
104        ----------
105        num_f_maps : torch.Size
106            shape of feature extractor output
107        len_segment : int
108            length of segment in the base feature extractor output
109        ssl_features : int, default 128
110            the final number of features per clip
111        tau : float, default 1
112            the tau parameter of NT-Xent loss
113        num_masked : int, default 10
114            number of frames to be masked in the middle of each segment
115
116        """
117        super().__init__()
118        self.start = int(len_segment // 2 - num_masked // 2)
119        self.end = int(len_segment // 2 + num_masked // 2)
120        self.loss_function = NTXent(tau)
121        if len(num_f_maps) > 1:
122            raise RuntimeError(
123                "The ContrastiveMaskedSSL constructor expects the input data to be 2-dimensional; "
124                f"got {len(num_f_maps) + 1} dimensions"
125            )
126        num_f_maps = int(num_f_maps[0])
127        self.pars = {
128            "num_f_maps": num_f_maps,
129            "len_segment": len_segment,
130            "output_dim": ssl_features,
131            "kernel_1": 3,
132            "kernel_2": 3,
133            "stride": 1,
134            "start": self.start,
135            "end": self.end,
136        }

Initialize the ContrastiveMaskedSSL class.

Parameters

  • num_f_maps : torch.Size — shape of the feature extractor output
  • len_segment : int — length of segment in the base feature extractor output
  • ssl_features : int, default 128 — the final number of features per clip
  • tau : float, default 1 — the tau parameter of the NT-Xent loss
  • num_masked : int, default 10 — number of frames to be masked in the middle of each segment

type = 'contrastive_2layers'

The type parameter defines interaction with the model:

  • 'ssl_input': a modification of the input data passes through the base network feature extraction module and the SSL module; it is returned as SSL output and compared to SSL target (or, if it is None, to the input data),
  • 'ssl_output': the input data passes through the base network feature extraction module and the SSL module; it is returned as SSL output and compared to SSL target (or, if it is None, to the input data),
  • 'contrastive': the input data and its modification pass through the base network feature extraction module and the SSL module; an (input results, modification results) tuple is returned as SSL output,
  • 'contrastive_2layers': the input data and its modification pass through the base network feature extraction module; the output of the second feature extraction layer for the modified data goes through an SSL module and then, optionally, that result and the first-level unmodified features pass another transformation; an (input results, modified results) tuple is returned as SSL output.
start
end
loss_function
pars
def transformation(self, sample_data: Dict) -> Tuple:
138    def transformation(self, sample_data: Dict) -> Tuple:
139        """Mask the input data."""
140        data = deepcopy(sample_data)
141        for key in data.keys():
142            data[key][:, self.start : self.end] = 0
143        return data, torch.tensor(float("nan"))
144        # return torch.tensor(float("nan")), torch.tensor(float("nan"))

Mask the input data.

def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
146    def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
147        """NT-Xent loss."""
148        features, ssl_features = predicted
149        loss = self.loss_function(features, ssl_features)
150        return loss

NT-Xent loss.

def construct_module(self) -> Optional[torch.nn.modules.module.Module]:
152    def construct_module(self) -> Union[nn.Module, None]:
153        """Clip-wise feature TCN extractor."""
154        module = MFeatureExtractorTCN(**self.pars)
155        return module

Clip-wise feature TCN extractor.

class PairwiseSSL(dlc2action.ssl.base_ssl.SSLConstructor):
class PairwiseSSL(SSLConstructor):
    """Pairwise SSL constructor trained with a triplet or a circle loss.

    `transformation` returns NaN placeholders for the SSL input and target;
    the augmented view is generated from the input sample at runtime.
    """

    type = "contrastive"

    def __init__(
        self,
        num_f_maps: torch.Size,
        len_segment: int,
        ssl_features: int = 128,
        margin: float = 0,
        distance: str = "cosine",
        loss: str = "triplet",
        gamma: float = 1,
    ) -> None:
        """Pick the pairwise loss and store the TCN extractor parameters.

        Parameters
        ----------
        num_f_maps : torch.Size
            shape of feature extractor output; only a single dimension is accepted
        len_segment : int
            length of segment in feature extractor output
        ssl_features : int, default 128
            final number of features per clip
        margin : float, default 0
            the margin parameter of triplet or circle loss
        distance : {'cosine', 'euclidean'}
            the distance calculation method for triplet or circle loss
        loss : {'triplet', 'circle'}
            the loss function name
        gamma : float, default 1
            the gamma parameter of circle loss (ignored by the triplet loss)

        Raises
        ------
        ValueError
            if `loss` is neither "triplet" nor "circle"
        RuntimeError
            if `num_f_maps` has more than one dimension

        """
        super().__init__()
        if loss not in ("triplet", "circle"):
            raise ValueError(
                f'The {loss} loss is unavailable, please choose from "triplet" and "circle"'
            )
        if loss == "circle":
            self.loss_function = CircleLoss(
                margin=margin, gamma=gamma, distance=distance
            )
        else:
            self.loss_function = TripletLoss(margin=margin, distance=distance)
        n_dims = len(num_f_maps)
        if n_dims > 1:
            raise RuntimeError(
                "The PairwiseSSL constructor expects the input data to be 2-dimensional; "
                f"got {n_dims + 1} dimensions"
            )
        # Keyword arguments forwarded to FeatureExtractorTCN in construct_module.
        self.pars = dict(
            num_f_maps=int(num_f_maps[0]),
            len_segment=len_segment,
            output_dim=ssl_features,
            kernel_1=5,
            kernel_2=5,
            stride=2,
            decrease_f_maps=True,
        )

    def transformation(self, sample_data: Dict) -> Tuple:
        """Return NaN placeholders as SSL input and target; `sample_data` is ignored."""
        placeholder = float("nan")
        return torch.tensor(placeholder), torch.tensor(placeholder)

    def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
        """Apply the configured triplet/circle loss to the two predicted feature sets."""
        first_view, second_view = predicted
        return self.loss_function(first_view, second_view)

    def construct_module(self) -> Union[nn.Module, None]:
        """Build the clip-wise TCN feature extractor from `self.pars`."""
        return FeatureExtractorTCN(**self.pars)

A pairwise SSL class with triplet or circle loss.

The SSL input and target are left empty (the SSL input is generated as an augmentation of the input sample at runtime).

PairwiseSSL( num_f_maps: torch.Size, len_segment: int, ssl_features: int = 128, margin: float = 0, distance: str = 'cosine', loss: str = 'triplet', gamma: float = 1)
167    def __init__(
168        self,
169        num_f_maps: torch.Size,
170        len_segment: int,
171        ssl_features: int = 128,
172        margin: float = 0,
173        distance: str = "cosine",
174        loss: str = "triplet",
175        gamma: float = 1,
176    ) -> None:
177        """Initialize the PairwiseSSL class.
178
179        Parameters
180        ----------
181        num_f_maps : torch.Size
182            shape of feature extractor output
183        len_segment : int
184            length of segment in feature extractor output
185        ssl_features : int, default 128
186            final number of features per clip
187        margin : float, default 0
188            the margin parameter of triplet or circle loss
189        distance : {'cosine', 'euclidean'}
190            the distance calculation method for triplet or circle loss
191        loss : {'triplet', 'circle'}
192            the loss function name
193        gamma : float, default 1
194            the gamma parameter of circle loss
195
196        """
197        super().__init__()
198        if loss == "triplet":
199            self.loss_function = TripletLoss(margin=margin, distance=distance)
200        elif loss == "circle":
201            self.loss_function = CircleLoss(
202                margin=margin, gamma=gamma, distance=distance
203            )
204        else:
205            raise ValueError(
206                f'The {loss} loss is unavailable, please choose from "triplet" and "circle"'
207            )
208        if len(num_f_maps) > 1:
209            raise RuntimeError(
210                "The PairwiseSSL constructor expects the input data to be 2-dimensional; "
211                f"got {len(num_f_maps) + 1} dimensions"
212            )
213        num_f_maps = int(num_f_maps[0])
214        self.pars = {
215            "num_f_maps": num_f_maps,
216            "len_segment": len_segment,
217            "output_dim": ssl_features,
218            "kernel_1": 5,
219            "kernel_2": 5,
220            "stride": 2,
221            "decrease_f_maps": True,
222        }

Initialize the PairwiseSSL class.

Parameters

  • num_f_maps : torch.Size — shape of the feature extractor output
  • len_segment : int — length of segment in the feature extractor output
  • ssl_features : int, default 128 — final number of features per clip
  • margin : float, default 0 — the margin parameter of the triplet or circle loss
  • distance : {'cosine', 'euclidean'} — the distance calculation method for the triplet or circle loss
  • loss : {'triplet', 'circle'} — the loss function name
  • gamma : float, default 1 — the gamma parameter of the circle loss

type = 'contrastive'

The `type` attribute defines how the SSL module interacts with the model:

  • 'ssl_input': a modification of the input data passes through the base network feature extraction module and the SSL module; it is returned as SSL output and compared to SSL target (or, if it is None, to the input data),
  • 'ssl_output': the input data passes through the base network feature extraction module and the SSL module; it is returned as SSL output and compared to SSL target (or, if it is None, to the input data),
  • 'contrastive': the input data and its modification pass through the base network feature extraction module and the SSL module; an (input results, modification results) tuple is returned as SSL output,
  • 'contrastive_2layers': the input data and its modification pass through the base network feature extraction module; the output of the second feature extraction layer for the modified data goes through an SSL module and then, optionally, that result and the first-level unmodified features pass another transformation; an (input results, modified results) tuple is returned as SSL output.
pars
def transformation(self, sample_data: Dict) -> Tuple:
224    def transformation(self, sample_data: Dict) -> Tuple:
225        """Empty transformation."""
226        return torch.tensor(float("nan")), torch.tensor(float("nan"))

Empty transformation: returns NaN placeholders for both the SSL input and the SSL target.

def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
228    def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
229        """Triplet or circle loss."""
230        features1, features2 = predicted
231        loss = self.loss_function(features1, features2)
232        return loss

Triplet or circle loss.

def construct_module(self) -> Optional[torch.nn.modules.module.Module]:
234    def construct_module(self) -> Union[nn.Module, None]:
235        """Clip-wise feature TCN extractor."""
236        module = FeatureExtractorTCN(**self.pars)
237        return module

Clip-wise feature TCN extractor.

class PairwiseMaskedSSL(PairwiseSSL):
240class PairwiseMaskedSSL(PairwiseSSL):
241    """A contrastive SSL class with triplet or circle loss and masked input.
242
243    A few frames in the middle of each segment are masked and then the output of the second layer of
244    feature extraction for the segment is used to predict the output of the first layer for the missing frames.
245    The SSL input and target are left empty (the SSL input is generated as an augmentation of the
246    input sample at runtime).
247    """
248
249    type = "contrastive_2layers"
250
251    def __init__(
252        self,
253        num_f_maps: torch.Size,
254        len_segment: int,
255        ssl_features: int = 128,
256        margin: float = 0,
257        distance: str = "cosine",
258        loss: str = "triplet",
259        gamma: float = 1,
260        num_masked: int = 10,
261    ) -> None:
262        """Initialize the PairwiseMaskedSSL class.
263
264        Parameters
265        ----------
266        num_f_maps : torch.Size
267            shape of feature extractor output
268        len_segment : int
269            length of segment in feature extractor output
270        ssl_features : int, default 128
271            final number of features per clip
272        margin : float, default 0
273            the margin parameter of triplet or circle loss
274        distance : {'cosine', 'euclidean'}
275            the distance calculation method for triplet or circle loss
276        loss : {'triplet', 'circle'}
277            the loss function name
278        gamma : float, default 1
279            the gamma parameter of circle loss
280        num_masked : int, default 10
281            number of masked frames
282
283        """
284        super().__init__(
285            num_f_maps, len_segment, ssl_features, margin, distance, loss, gamma
286        )
287        self.num_masked = num_masked
288        self.start = int(len_segment // 2 - num_masked // 2)
289        self.end = int(len_segment // 2 + num_masked // 2)
290        if len(num_f_maps) > 1:
291            raise RuntimeError(
292                "The PairwiseMaskedSSL constructor expects the input data to be 2-dimensional; "
293                f"got {len(num_f_maps) + 1} dimensions"
294            )
295        num_f_maps = int(num_f_maps[0])
296        self.pars = {
297            "num_f_maps": num_f_maps,
298            "len_segment": len_segment,
299            "output_dim": ssl_features,
300            "kernel_1": 3,
301            "kernel_2": 3,
302            "stride": 1,
303            "start": self.start,
304            "end": self.end,
305        }
306
307    def transformation(self, sample_data: Dict) -> Tuple:
308        """Mask the input data."""
309        data = deepcopy(sample_data)
310        for key in data.keys():
311            data[key][:, self.start : self.end] = 0
312        return data, torch.tensor(float("nan"))
313
314    def construct_module(self) -> Union[nn.Module, None]:
315        """Clip-wise feature TCN extractor."""
316        module = MFeatureExtractorTCN(**self.pars)
317        return module

A contrastive SSL class with triplet or circle loss and masked input.

A few frames in the middle of each segment are masked, and the output of the second feature-extraction layer for the segment is then used to predict the output of the first layer for the missing frames. The SSL target is left empty (a NaN placeholder); the SSL input is a copy of the input sample with the middle frames set to zero (see `transformation`).

PairwiseMaskedSSL( num_f_maps: torch.Size, len_segment: int, ssl_features: int = 128, margin: float = 0, distance: str = 'cosine', loss: str = 'triplet', gamma: float = 1, num_masked: int = 10)
251    def __init__(
252        self,
253        num_f_maps: torch.Size,
254        len_segment: int,
255        ssl_features: int = 128,
256        margin: float = 0,
257        distance: str = "cosine",
258        loss: str = "triplet",
259        gamma: float = 1,
260        num_masked: int = 10,
261    ) -> None:
262        """Initialize the PairwiseMaskedSSL class.
263
264        Parameters
265        ----------
266        num_f_maps : torch.Size
267            shape of feature extractor output
268        len_segment : int
269            length of segment in feature extractor output
270        ssl_features : int, default 128
271            final number of features per clip
272        margin : float, default 0
273            the margin parameter of triplet or circle loss
274        distance : {'cosine', 'euclidean'}
275            the distance calculation method for triplet or circle loss
276        loss : {'triplet', 'circle'}
277            the loss function name
278        gamma : float, default 1
279            the gamma parameter of circle loss
280        num_masked : int, default 10
281            number of masked frames
282
283        """
284        super().__init__(
285            num_f_maps, len_segment, ssl_features, margin, distance, loss, gamma
286        )
287        self.num_masked = num_masked
288        self.start = int(len_segment // 2 - num_masked // 2)
289        self.end = int(len_segment // 2 + num_masked // 2)
290        if len(num_f_maps) > 1:
291            raise RuntimeError(
292                "The PairwiseMaskedSSL constructor expects the input data to be 2-dimensional; "
293                f"got {len(num_f_maps) + 1} dimensions"
294            )
295        num_f_maps = int(num_f_maps[0])
296        self.pars = {
297            "num_f_maps": num_f_maps,
298            "len_segment": len_segment,
299            "output_dim": ssl_features,
300            "kernel_1": 3,
301            "kernel_2": 3,
302            "stride": 1,
303            "start": self.start,
304            "end": self.end,
305        }

Initialize the PairwiseMaskedSSL class.

Parameters

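  • num_f_maps : torch.Size — shape of the feature extractor output; it must have exactly one dimension, otherwise a RuntimeError is raised
  • len_segment : int — length of the segment in the feature extractor output
  • ssl_features : int, default 128 — final number of features per clip
  • margin : float, default 0 — the margin parameter of the triplet or circle loss
  • distance : {'cosine', 'euclidean'}, default 'cosine' — the distance calculation method for the triplet or circle loss
  • loss : {'triplet', 'circle'}, default 'triplet' — the loss function name
  • gamma : float, default 1 — the gamma parameter of the circle loss (not used by the triplet loss)
  • num_masked : int, default 10 — number of masked frames; the masked span is [len_segment // 2 - num_masked // 2, len_segment // 2 + num_masked // 2), so an odd value is effectively rounded down to the nearest even number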

type = 'contrastive_2layers'

The `type` attribute defines how the SSL module interacts with the model:

  • 'ssl_input': a modification of the input data passes through the base network feature extraction module and the SSL module; it is returned as SSL output and compared to SSL target (or, if it is None, to the input data),
  • 'ssl_output': the input data passes through the base network feature extraction module and the SSL module; it is returned as SSL output and compared to SSL target (or, if it is None, to the input data),
  • 'contrastive': the input data and its modification pass through the base network feature extraction module and the SSL module; an (input results, modification results) tuple is returned as SSL output,
  • 'contrastive_2layers': the input data and its modification pass through the base network feature extraction module; the output of the second feature extraction layer for the modified data goes through an SSL module and then, optionally, that result and the first-level unmodified features pass another transformation; an (input results, modified results) tuple is returned as SSL output.
num_masked
start
end
pars
def transformation(self, sample_data: Dict) -> Tuple:
307    def transformation(self, sample_data: Dict) -> Tuple:
308        """Mask the input data."""
309        data = deepcopy(sample_data)
310        for key in data.keys():
311            data[key][:, self.start : self.end] = 0
312        return data, torch.tensor(float("nan"))

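Mask the input data: every value in a deep copy of `sample_data` is set to zero at indices `start:end` along its second axis, and the masked copy is returned together with a NaN placeholder target.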

def construct_module(self) -> Optional[torch.nn.modules.module.Module]:
314    def construct_module(self) -> Union[nn.Module, None]:
315        """Clip-wise feature TCN extractor."""
316        module = MFeatureExtractorTCN(**self.pars)
317        return module

Clip-wise feature TCN extractor.

Inherited Members
PairwiseSSL
loss
class ContrastiveRegressionSSL(dlc2action.ssl.base_ssl.SSLConstructor):
320class ContrastiveRegressionSSL(SSLConstructor):
321    """Contrastive SSL class with regression loss."""
322
323    type = "contrastive"
324
325    def __init__(
326        self,
327        num_f_maps: torch.Size,
328        num_features: int = 128,
329        num_ssl_layers: int = 1,
330        distance: str = "cosine",
331        temperature: float = 1,
332        break_factor: int = None,
333    ) -> None:
334        """Initialize the ContrastiveRegressionSSL class.
335
336        Parameters
337        ----------
338        num_f_maps : torch.Size
339            shape of feature extractor output
340        num_features : int, default 128
341            final number of features per clip
342        num_ssl_layers : int, default 1
343            number of SSL layers
344        distance : {'cosine', 'euclidean'}
345            the distance calculation method for triplet or circle loss
346        temperature : float, default 1
347            the temperature parameter of contrastive loss
348        break_factor : int, default None
349            the break factor parameter of contrastive loss
350
351        """
352        if len(num_f_maps) > 1:
353            raise RuntimeError(
354                "The ContrastiveRegressionSSL constructor expects the input data to be 2-dimensional; "
355                f"got {len(num_f_maps) + 1} dimensions"
356            )
357        num_f_maps = int(num_f_maps[0])
358        self.loss_function = ContrastiveRegressionLoss(
359            temperature, distance, break_factor
360        )
361        self.pars = {
362            "num_f_maps": num_f_maps,
363            "num_ssl_layers": num_ssl_layers,
364            "num_ssl_f_maps": num_features,
365            "dim": num_features,
366            "ssl_input": False,
367        }
368        super().__init__()
369
370    def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
371        """NT-Xent loss."""
372        features1, features2 = predicted
373        loss = self.loss_function(features1, features2)
374        return loss
375
376    def transformation(self, sample_data: Dict) -> Tuple:
377        """Empty transformation."""
378        return torch.tensor(float("nan")), torch.tensor(float("nan"))
379
380    def construct_module(self) -> Union[nn.Module, None]:
381        """Clip-wise feature TCN extractor."""
382        return FC(**self.pars)

Contrastive SSL class with regression loss.

ContrastiveRegressionSSL( num_f_maps: torch.Size, num_features: int = 128, num_ssl_layers: int = 1, distance: str = 'cosine', temperature: float = 1, break_factor: int = None)
325    def __init__(
326        self,
327        num_f_maps: torch.Size,
328        num_features: int = 128,
329        num_ssl_layers: int = 1,
330        distance: str = "cosine",
331        temperature: float = 1,
332        break_factor: int = None,
333    ) -> None:
334        """Initialize the ContrastiveRegressionSSL class.
335
336        Parameters
337        ----------
338        num_f_maps : torch.Size
339            shape of feature extractor output
340        num_features : int, default 128
341            final number of features per clip
342        num_ssl_layers : int, default 1
343            number of SSL layers
344        distance : {'cosine', 'euclidean'}
345            the distance calculation method for triplet or circle loss
346        temperature : float, default 1
347            the temperature parameter of contrastive loss
348        break_factor : int, default None
349            the break factor parameter of contrastive loss
350
351        """
352        if len(num_f_maps) > 1:
353            raise RuntimeError(
354                "The ContrastiveRegressionSSL constructor expects the input data to be 2-dimensional; "
355                f"got {len(num_f_maps) + 1} dimensions"
356            )
357        num_f_maps = int(num_f_maps[0])
358        self.loss_function = ContrastiveRegressionLoss(
359            temperature, distance, break_factor
360        )
361        self.pars = {
362            "num_f_maps": num_f_maps,
363            "num_ssl_layers": num_ssl_layers,
364            "num_ssl_f_maps": num_features,
365            "dim": num_features,
366            "ssl_input": False,
367        }
368        super().__init__()

Initialize the ContrastiveRegressionSSL class.

Parameters

  • num_f_maps : torch.Size — shape of the feature extractor output; it must have exactly one dimension, otherwise a RuntimeError is raised
  • num_features : int, default 128 — final number of features per clip
  • num_ssl_layers : int, default 1 — number of SSL layers
  • distance : {'cosine', 'euclidean'}, default 'cosine' — the distance calculation method for the contrastive regression loss
  • temperature : float, default 1 — the temperature parameter of the contrastive regression loss
  • break_factor : int, default None — the break factor parameter of the contrastive regression loss

type = 'contrastive'

The `type` attribute defines how the SSL module interacts with the model:

  • 'ssl_input': a modification of the input data passes through the base network feature extraction module and the SSL module; it is returned as SSL output and compared to SSL target (or, if it is None, to the input data),
  • 'ssl_output': the input data passes through the base network feature extraction module and the SSL module; it is returned as SSL output and compared to SSL target (or, if it is None, to the input data),
  • 'contrastive': the input data and its modification pass through the base network feature extraction module and the SSL module; an (input results, modification results) tuple is returned as SSL output,
  • 'contrastive_2layers': the input data and its modification pass through the base network feature extraction module; the output of the second feature extraction layer for the modified data goes through an SSL module and then, optionally, that result and the first-level unmodified features pass another transformation; an (input results, modified results) tuple is returned as SSL output.
loss_function
pars
def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
370    def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
371        """NT-Xent loss."""
372        features1, features2 = predicted
373        loss = self.loss_function(features1, features2)
374        return loss

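Contrastive regression loss (ContrastiveRegressionLoss) computed between the two feature sets unpacked from `predicted`; `target` is not used.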

def transformation(self, sample_data: Dict) -> Tuple:
376    def transformation(self, sample_data: Dict) -> Tuple:
377        """Empty transformation."""
378        return torch.tensor(float("nan")), torch.tensor(float("nan"))

Empty transformation: returns NaN placeholders for both the SSL input and the SSL target.

def construct_module(self) -> Optional[torch.nn.modules.module.Module]:
380    def construct_module(self) -> Union[nn.Module, None]:
381        """Clip-wise feature TCN extractor."""
382        return FC(**self.pars)

Construct the SSL module: an `FC` module built from `self.pars`.