
 2# Copyright 2020-2022 by A. Mathis Group and contributors. All rights reserved.
 4# This project and all its files are licensed under GNU AGPLv3 or later version. A copy is included in dlc2action/LICENSE.AGPL.
 6from dlc2action.ssl.base_ssl import SSLConstructor
 7import torch
 8from dlc2action.loss.tcc import _TCCLoss
 9from typing import Dict, Union, Tuple
10from torch import nn
11from dlc2action.ssl.modules import _FC
class TCCSSL(SSLConstructor):
    """
    An SSL class with a temporal cycle-consistency (TCC) loss.

    The loss is computed by ``dlc2action.loss.tcc._TCCLoss``. The SSL input is
    left empty (a scalar NaN placeholder); the SSL target is a
    ``{"loaded": mask}`` dictionary whose ``mask`` marks the frames of the
    segment that contain non-zero data.
    """

    type = "ssl_target"

    def __init__(
        self,
        num_f_maps: torch.Size,
        len_segment: int,
        projection_head_f_maps: Union[int, None] = None,
        loss_type: str = "regression_mse_var",
        variance_lambda: float = 0.001,
        normalize_indices: bool = True,
        normalize_embeddings: bool = False,
        similarity_type: str = "l2",
        num_cycles: int = 20,
        cycle_length: int = 2,
        temperature: float = 0.1,
        label_smoothing: float = 0.1,
    ) -> None:
        """
        Parameters
        ----------
        num_f_maps : torch.Size
            shape of the input features without one dimension (the input data
            is expected to have ``len(num_f_maps) + 1`` dimensions); must have
            exactly one element
        len_segment : int
            length of the segment; sets the width of the ``"loaded"`` mask
        projection_head_f_maps : int, optional
            passed to ``_FC`` as ``num_ssl_f_maps``; defaults to the number of
            input feature maps
        loss_type, variance_lambda, normalize_indices, normalize_embeddings,
        similarity_type, num_cycles, cycle_length, temperature, label_smoothing
            forwarded positionally to ``_TCCLoss`` (``num_cycles`` and
            ``cycle_length`` are cast to ``int`` first)

        Raises
        ------
        RuntimeError
            if ``num_f_maps`` does not have exactly one element
        """
        super().__init__()
        # The TCC constructor requires 2-dimensional input data, i.e. exactly one
        # feature dimension. Checking ``!= 1`` (not ``> 1``) also turns an empty
        # shape into this RuntimeError instead of an IndexError below.
        if len(num_f_maps) != 1:
            raise RuntimeError(
                "The TCC constructor expects the input data to be 2-dimensional; "
                f"got {len(num_f_maps) + 1} dimensions"
            )
        num_f_maps = int(num_f_maps[0])
        if projection_head_f_maps is None:
            projection_head_f_maps = num_f_maps
        self.len_segment = int(len_segment)
        self.loss_function = _TCCLoss(
            loss_type,
            variance_lambda,
            normalize_indices,
            normalize_embeddings,
            similarity_type,
            int(num_cycles),
            int(cycle_length),
            temperature,
            label_smoothing,
        )
        # keyword arguments for the ``_FC`` projection head built in ``construct_module``
        self.pars = {
            "dim": num_f_maps,
            "num_f_maps": num_f_maps,
            "num_ssl_layers": 1,
            "num_ssl_f_maps": int(projection_head_f_maps),
        }

    def transformation(self, sample_data: Dict) -> Tuple:
        """
        Generate the SSL input and target for one sample.

        Parameters
        ----------
        sample_data : dict
            input data; every value is summed over dimension 0 and the result
            must broadcast to ``(1, len_segment)`` (presumably ``(features,
            time)`` tensors -- TODO confirm)

        Returns
        -------
        ssl_input : torch.Tensor
            a scalar NaN placeholder (no SSL input is used)
        ssl_target : dict
            ``{"loaded": mask}`` where ``mask`` is a float tensor of shape
            ``(1, len_segment)`` equal to 1 at frames where at least one value
            has a non-zero sum over dimension 0 and 0 elsewhere
        """
        mask = torch.ones((1, self.len_segment))
        for value in sample_data.values():
            # a frame stays marked as empty only while every value sums to zero there
            mask *= (torch.sum(value, 0) == 0).unsqueeze(0)
        mask = 1 - mask
        return torch.tensor(float("nan")), {"loaded": mask}

    def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
        """
        Compute the SSL loss with the ``_TCCLoss`` instance built in ``__init__``.

        Parameters
        ----------
        predicted : torch.Tensor
            output of the SSL module
        target : torch.Tensor
            SSL target (built from the ``{"loaded": mask}`` dictionaries
            returned by ``transformation`` -- exact collated form not shown here)

        Returns
        -------
        loss : float
            whatever ``_TCCLoss`` returns (presumably a scalar tensor -- TODO confirm)
        """
        return self.loss_function(predicted, target)

    def construct_module(self) -> Union[nn.Module, None]:
        """
        Construct the SSL projection head.

        Returns
        -------
        ssl_module : torch.nn.Module
            ``nn.Identity()`` when ``self.pars["num_ssl_f_maps"]`` is ``None``,
            otherwise ``_FC(**self.pars)``
        """
        # NOTE(review): ``__init__`` always stores an ``int`` under
        # "num_ssl_f_maps", so the identity branch is unreachable for
        # instances created through ``__init__``.
        if self.pars["num_ssl_f_maps"] is None:
            return nn.Identity()
        return _FC(**self.pars)
class TCCSSL(dlc2action.ssl.base_ssl.SSLConstructor):
class TCCSSL(SSLConstructor):
    """
    An SSL class with a temporal cycle-consistency (TCC) loss.

    The loss is computed by ``dlc2action.loss.tcc._TCCLoss``. The SSL input is
    left empty (a scalar NaN placeholder); the SSL target is a
    ``{"loaded": mask}`` dictionary whose ``mask`` marks the frames of the
    segment that contain non-zero data.
    """

    type = "ssl_target"

    def __init__(
        self,
        num_f_maps: torch.Size,
        len_segment: int,
        projection_head_f_maps: Union[int, None] = None,
        loss_type: str = "regression_mse_var",
        variance_lambda: float = 0.001,
        normalize_indices: bool = True,
        normalize_embeddings: bool = False,
        similarity_type: str = "l2",
        num_cycles: int = 20,
        cycle_length: int = 2,
        temperature: float = 0.1,
        label_smoothing: float = 0.1,
    ) -> None:
        """
        Parameters
        ----------
        num_f_maps : torch.Size
            shape of the input features without one dimension (the input data
            is expected to have ``len(num_f_maps) + 1`` dimensions); must have
            exactly one element
        len_segment : int
            length of the segment; sets the width of the ``"loaded"`` mask
        projection_head_f_maps : int, optional
            passed to ``_FC`` as ``num_ssl_f_maps``; defaults to the number of
            input feature maps
        loss_type, variance_lambda, normalize_indices, normalize_embeddings,
        similarity_type, num_cycles, cycle_length, temperature, label_smoothing
            forwarded positionally to ``_TCCLoss`` (``num_cycles`` and
            ``cycle_length`` are cast to ``int`` first)

        Raises
        ------
        RuntimeError
            if ``num_f_maps`` does not have exactly one element
        """
        super().__init__()
        # The TCC constructor requires 2-dimensional input data, i.e. exactly one
        # feature dimension. Checking ``!= 1`` (not ``> 1``) also turns an empty
        # shape into this RuntimeError instead of an IndexError below.
        if len(num_f_maps) != 1:
            raise RuntimeError(
                "The TCC constructor expects the input data to be 2-dimensional; "
                f"got {len(num_f_maps) + 1} dimensions"
            )
        num_f_maps = int(num_f_maps[0])
        if projection_head_f_maps is None:
            projection_head_f_maps = num_f_maps
        self.len_segment = int(len_segment)
        self.loss_function = _TCCLoss(
            loss_type,
            variance_lambda,
            normalize_indices,
            normalize_embeddings,
            similarity_type,
            int(num_cycles),
            int(cycle_length),
            temperature,
            label_smoothing,
        )
        # keyword arguments for the ``_FC`` projection head built in ``construct_module``
        self.pars = {
            "dim": num_f_maps,
            "num_f_maps": num_f_maps,
            "num_ssl_layers": 1,
            "num_ssl_f_maps": int(projection_head_f_maps),
        }

    def transformation(self, sample_data: Dict) -> Tuple:
        """
        Generate the SSL input and target for one sample.

        Parameters
        ----------
        sample_data : dict
            input data; every value is summed over dimension 0 and the result
            must broadcast to ``(1, len_segment)`` (presumably ``(features,
            time)`` tensors -- TODO confirm)

        Returns
        -------
        ssl_input : torch.Tensor
            a scalar NaN placeholder (no SSL input is used)
        ssl_target : dict
            ``{"loaded": mask}`` where ``mask`` is a float tensor of shape
            ``(1, len_segment)`` equal to 1 at frames where at least one value
            has a non-zero sum over dimension 0 and 0 elsewhere
        """
        mask = torch.ones((1, self.len_segment))
        for value in sample_data.values():
            # a frame stays marked as empty only while every value sums to zero there
            mask *= (torch.sum(value, 0) == 0).unsqueeze(0)
        mask = 1 - mask
        return torch.tensor(float("nan")), {"loaded": mask}

    def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
        """
        Compute the SSL loss with the ``_TCCLoss`` instance built in ``__init__``.

        Parameters
        ----------
        predicted : torch.Tensor
            output of the SSL module
        target : torch.Tensor
            SSL target (built from the ``{"loaded": mask}`` dictionaries
            returned by ``transformation`` -- exact collated form not shown here)

        Returns
        -------
        loss : float
            whatever ``_TCCLoss`` returns (presumably a scalar tensor -- TODO confirm)
        """
        return self.loss_function(predicted, target)

    def construct_module(self) -> Union[nn.Module, None]:
        """
        Construct the SSL projection head.

        Returns
        -------
        ssl_module : torch.nn.Module
            ``nn.Identity()`` when ``self.pars["num_ssl_f_maps"]`` is ``None``,
            otherwise ``_FC(**self.pars)``
        """
        # NOTE(review): ``__init__`` always stores an ``int`` under
        # "num_ssl_f_maps", so the identity branch is unreachable for
        # instances created through ``__init__``.
        if self.pars["num_ssl_f_maps"] is None:
            return nn.Identity()
        return _FC(**self.pars)

An SSL class with a temporal cycle-consistency (TCC) loss

The SSL input is left empty (a NaN placeholder); the SSL target is a mask that marks which frames of the segment contain loaded (non-zero) data.

TCCSSL( num_f_maps: torch.Size, len_segment: int, projection_head_f_maps: int = None, loss_type: str = 'regression_mse_var', variance_lambda: float = 0.001, normalize_indices: bool = True, normalize_embeddings: bool = False, similarity_type: str = 'l2', num_cycles: int = 20, cycle_length: int = 2, temperature: float = 0.1, label_smoothing: float = 0.1)
def __init__(
    self,
    num_f_maps: torch.Size,
    len_segment: int,
    projection_head_f_maps: Union[int, None] = None,
    loss_type: str = "regression_mse_var",
    variance_lambda: float = 0.001,
    normalize_indices: bool = True,
    normalize_embeddings: bool = False,
    similarity_type: str = "l2",
    num_cycles: int = 20,
    cycle_length: int = 2,
    temperature: float = 0.1,
    label_smoothing: float = 0.1,
) -> None:
    """
    Parameters
    ----------
    num_f_maps : torch.Size
        shape of the input features without one dimension (the input data is
        expected to have ``len(num_f_maps) + 1`` dimensions); must have exactly
        one element
    len_segment : int
        length of the segment; sets the width of the ``"loaded"`` mask
    projection_head_f_maps : int, optional
        passed to ``_FC`` as ``num_ssl_f_maps``; defaults to the number of input
        feature maps
    loss_type, variance_lambda, normalize_indices, normalize_embeddings,
    similarity_type, num_cycles, cycle_length, temperature, label_smoothing
        forwarded positionally to ``_TCCLoss`` (``num_cycles`` and
        ``cycle_length`` are cast to ``int`` first)

    Raises
    ------
    RuntimeError
        if ``num_f_maps`` does not have exactly one element
    """
    super().__init__()
    # The TCC constructor requires 2-dimensional input data, i.e. exactly one
    # feature dimension. Checking ``!= 1`` (not ``> 1``) also turns an empty
    # shape into this RuntimeError instead of an IndexError below.
    if len(num_f_maps) != 1:
        raise RuntimeError(
            "The TCC constructor expects the input data to be 2-dimensional; "
            f"got {len(num_f_maps) + 1} dimensions"
        )
    num_f_maps = int(num_f_maps[0])
    if projection_head_f_maps is None:
        projection_head_f_maps = num_f_maps
    self.len_segment = int(len_segment)
    self.loss_function = _TCCLoss(
        loss_type,
        variance_lambda,
        normalize_indices,
        normalize_embeddings,
        similarity_type,
        int(num_cycles),
        int(cycle_length),
        temperature,
        label_smoothing,
    )
    # keyword arguments for the ``_FC`` projection head built in ``construct_module``
    self.pars = {
        "dim": num_f_maps,
        "num_f_maps": num_f_maps,
        "num_ssl_layers": 1,
        "num_ssl_f_maps": int(projection_head_f_maps),
    }
type = 'ssl_target'

The type parameter defines interaction with the model:

  • 'ssl_input': a modification of the input data passes through the base network feature extraction module and the SSL module; it is returned as SSL output and compared to SSL target (or, if it is None, to the input data),
  • 'ssl_output': the input data passes through the base network feature extraction module and the SSL module; it is returned as SSL output and compared to SSL target (or, if it is None, to the input data),
  • 'contrastive': the input data and its modification pass through the base network feature extraction module and the SSL module; an (input results, modification results) tuple is returned as SSL output,
  • 'contrastive_2layers': the input data and its modification pass through the base network feature extraction module; the output of the second feature extraction layer for the modified data goes through an SSL module and then, optionally, that result and the first-level unmodified features pass another transformation; an (input results, modified results) tuple is returned as SSL output,
def transformation(self, sample_data: Dict) -> Tuple:
def transformation(self, sample_data: Dict) -> Tuple:
    """
    Generate the SSL input and target for one sample.

    Parameters
    ----------
    sample_data : dict
        input data; every value is summed over dimension 0 and the result must
        broadcast to ``(1, len_segment)`` (presumably ``(features, time)``
        tensors -- TODO confirm)

    Returns
    -------
    ssl_input : torch.Tensor
        a scalar NaN placeholder (no SSL input is used)
    ssl_target : dict
        ``{"loaded": mask}`` where ``mask`` is a float tensor of shape
        ``(1, len_segment)`` equal to 1 at frames where at least one value has
        a non-zero sum over dimension 0 and 0 elsewhere
    """
    mask = torch.ones((1, self.len_segment))
    for value in sample_data.values():
        # a frame stays marked as empty only while every value sums to zero there
        mask *= (torch.sum(value, 0) == 0).unsqueeze(0)
    mask = 1 - mask
    return torch.tensor(float("nan")), {"loaded": mask}

Generate an empty (NaN) SSL input and an SSL target that marks the loaded (non-zero) frames of the segment.

def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
    """
    Compute the SSL loss by delegating to ``self.loss_function``.

    Parameters
    ----------
    predicted : torch.Tensor
        output of the SSL module
    target : torch.Tensor
        SSL target

    Returns
    -------
    loss : float
        whatever ``self.loss_function`` returns (presumably a scalar tensor --
        TODO confirm)
    """
    return self.loss_function(predicted, target)

Calculate the SSL loss


Parameters: predicted (torch.Tensor) — output of the SSL module; target (torch.Tensor) — augmented and stacked SSL target.


Returns: loss (float) — the loss value.

def construct_module(self) -> Optional[torch.nn.modules.module.Module]:
def construct_module(self) -> Union[nn.Module, None]:
    """
    Build the SSL projection head.

    Returns
    -------
    ssl_module : torch.nn.Module
        ``nn.Identity()`` when ``self.pars["num_ssl_f_maps"]`` is ``None``,
        otherwise ``_FC(**self.pars)``
    """
    # NOTE(review): ``__init__`` always stores an ``int`` under
    # "num_ssl_f_maps", so the identity branch is unreachable for instances
    # created through ``__init__``.
    head_f_maps = self.pars["num_ssl_f_maps"]
    return nn.Identity() if head_f_maps is None else _FC(**self.pars)

Construct the SSL module


Returns: ssl_module (torch.nn.Module) — a neural network module that takes the features extracted by the model's feature extractor as input and returns the SSL output.