dlc2action.ssl.tcc

 1#
 2# Copyright 2020-2022 by A. Mathis Group and contributors. All rights reserved.
 3#
 4# This project and all its files are licensed under GNU AGPLv3 or later version. A copy is included in dlc2action/LICENSE.AGPL.
 5#
 6from dlc2action.ssl.base_ssl import SSLConstructor
 7import torch
 8from dlc2action.loss.tcc import _TCCLoss
 9from typing import Dict, Union, Tuple
10from torch import nn
11from dlc2action.ssl.modules import _FC
12
13
class TCCSSL(SSLConstructor):
    """
    An SSL class with a TCC (temporal cycle-consistency) loss

    The SSL input is a NaN placeholder and the SSL target is a dictionary with a
    ``"loaded"`` mask that marks the positions where at least one input has a
    non-zero sum over its first dimension (see ``transformation``). The loss itself
    is computed by ``_TCCLoss``.
    """

    type = "ssl_target"

    def __init__(
        self,
        num_f_maps: torch.Size,
        len_segment: int,
        projection_head_f_maps: Union[int, None] = None,
        loss_type: str = "regression_mse_var",
        variance_lambda: float = 0.001,
        normalize_indices: bool = True,
        normalize_embeddings: bool = False,
        similarity_type: str = "l2",
        num_cycles: int = 20,
        cycle_length: int = 2,
        temperature: float = 0.1,
        label_smoothing: float = 0.1,
    ) -> None:
        """
        Parameters
        ----------
        num_f_maps : torch.Size
            shape of the features the SSL module receives, excluding one dimension
            (presumably time - TODO confirm); must have exactly one element, i.e.
            the input data must be 2-dimensional
        len_segment : int
            segment length; the size of the mask built by ``transformation``
        projection_head_f_maps : int, optional
            number of output features of the ``_FC`` projection head; defaults to
            ``num_f_maps``
        loss_type, variance_lambda, normalize_indices, normalize_embeddings,
        similarity_type, num_cycles, cycle_length, temperature, label_smoothing
            forwarded positionally to ``_TCCLoss`` (``num_cycles`` and
            ``cycle_length`` are cast to ``int`` first)

        Raises
        ------
        RuntimeError
            if ``num_f_maps`` does not have exactly one element
        """
        super().__init__()
        # ``num_f_maps`` must hold exactly one size; an empty size would otherwise
        # fail at ``num_f_maps[0]`` with an opaque IndexError.
        if len(num_f_maps) != 1:
            raise RuntimeError(
                "The TCC constructor expects the input data to be 2-dimensional; "
                f"got {len(num_f_maps) + 1} dimensions"
            )
        num_f_maps = int(num_f_maps[0])
        if projection_head_f_maps is None:
            projection_head_f_maps = num_f_maps
        self.len_segment = int(len_segment)
        self.loss_function = _TCCLoss(
            loss_type,
            variance_lambda,
            normalize_indices,
            normalize_embeddings,
            similarity_type,
            int(num_cycles),
            int(cycle_length),
            temperature,
            label_smoothing,
        )
        # Keyword arguments for the ``_FC`` projection head built in
        # ``construct_module``; ``num_ssl_f_maps`` is always an int here.
        self.pars = {
            "dim": num_f_maps,
            "num_f_maps": num_f_maps,
            "num_ssl_layers": 1,
            "num_ssl_f_maps": int(projection_head_f_maps),
        }

    def transformation(self, sample_data: Dict) -> Tuple:
        """
        Build the SSL input and target for a sample

        Parameters
        ----------
        sample_data : dict
            the sample's data; each value is summed over its first dimension and
            the result must broadcast to ``(1, len_segment)`` (presumably each value
            is ``(features, len_segment)`` - TODO confirm)

        Returns
        -------
        ssl_input : torch.Tensor
            a scalar NaN placeholder
        ssl_target : dict
            ``{"loaded": mask}``, where ``mask`` is a float tensor of shape
            ``(1, len_segment)`` equal to 1 where at least one value has a non-zero
            sum over its first dimension and 0 elsewhere (all zeros if
            ``sample_data`` is empty)
        """
        # Accumulate "every value sums to zero here" as a product of 0/1
        # indicators, then invert it into "at least one value is non-zero here".
        mask = torch.ones((1, self.len_segment))
        for value in sample_data.values():
            mask *= (torch.sum(value, 0) == 0).unsqueeze(0)
        mask = 1 - mask
        return torch.tensor(float("nan")), {"loaded": mask}

    def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
        """
        Calculate the SSL loss with ``self.loss_function`` (a ``_TCCLoss`` instance)

        Parameters
        ----------
        predicted : torch.Tensor
            output of the SSL module
        target : torch.Tensor
            SSL target passed through unchanged to ``self.loss_function``

        Returns
        -------
        loss : float
            whatever ``self.loss_function`` returns
        """
        return self.loss_function(predicted, target)

    def construct_module(self) -> Union[nn.Module, None]:
        """
        Construct the SSL module

        Returns
        -------
        ssl_module : torch.nn.Module
            ``nn.Identity()`` if ``self.pars["num_ssl_f_maps"]`` is None, otherwise
            an ``_FC`` projection head built from ``self.pars``
        """
        # NOTE(review): ``__init__`` always stores an int in "num_ssl_f_maps", so
        # the identity branch is only reachable if ``self.pars`` is edited later.
        if self.pars["num_ssl_f_maps"] is None:
            return nn.Identity()
        return _FC(**self.pars)
class TCCSSL(dlc2action.ssl.base_ssl.SSLConstructor):
class TCCSSL(SSLConstructor):
    """
    An SSL class with a TCC (temporal cycle-consistency) loss

    The SSL input is a NaN placeholder and the SSL target is a dictionary with a
    ``"loaded"`` mask that marks the positions where at least one input has a
    non-zero sum over its first dimension (see ``transformation``). The loss itself
    is computed by ``_TCCLoss``.
    """

    type = "ssl_target"

    def __init__(
        self,
        num_f_maps: torch.Size,
        len_segment: int,
        projection_head_f_maps: Union[int, None] = None,
        loss_type: str = "regression_mse_var",
        variance_lambda: float = 0.001,
        normalize_indices: bool = True,
        normalize_embeddings: bool = False,
        similarity_type: str = "l2",
        num_cycles: int = 20,
        cycle_length: int = 2,
        temperature: float = 0.1,
        label_smoothing: float = 0.1,
    ) -> None:
        """
        Parameters
        ----------
        num_f_maps : torch.Size
            shape of the features the SSL module receives, excluding one dimension
            (presumably time - TODO confirm); must have exactly one element, i.e.
            the input data must be 2-dimensional
        len_segment : int
            segment length; the size of the mask built by ``transformation``
        projection_head_f_maps : int, optional
            number of output features of the ``_FC`` projection head; defaults to
            ``num_f_maps``
        loss_type, variance_lambda, normalize_indices, normalize_embeddings,
        similarity_type, num_cycles, cycle_length, temperature, label_smoothing
            forwarded positionally to ``_TCCLoss`` (``num_cycles`` and
            ``cycle_length`` are cast to ``int`` first)

        Raises
        ------
        RuntimeError
            if ``num_f_maps`` does not have exactly one element
        """
        super().__init__()
        # ``num_f_maps`` must hold exactly one size; an empty size would otherwise
        # fail at ``num_f_maps[0]`` with an opaque IndexError.
        if len(num_f_maps) != 1:
            raise RuntimeError(
                "The TCC constructor expects the input data to be 2-dimensional; "
                f"got {len(num_f_maps) + 1} dimensions"
            )
        num_f_maps = int(num_f_maps[0])
        if projection_head_f_maps is None:
            projection_head_f_maps = num_f_maps
        self.len_segment = int(len_segment)
        self.loss_function = _TCCLoss(
            loss_type,
            variance_lambda,
            normalize_indices,
            normalize_embeddings,
            similarity_type,
            int(num_cycles),
            int(cycle_length),
            temperature,
            label_smoothing,
        )
        # Keyword arguments for the ``_FC`` projection head built in
        # ``construct_module``; ``num_ssl_f_maps`` is always an int here.
        self.pars = {
            "dim": num_f_maps,
            "num_f_maps": num_f_maps,
            "num_ssl_layers": 1,
            "num_ssl_f_maps": int(projection_head_f_maps),
        }

    def transformation(self, sample_data: Dict) -> Tuple:
        """
        Build the SSL input and target for a sample

        Parameters
        ----------
        sample_data : dict
            the sample's data; each value is summed over its first dimension and
            the result must broadcast to ``(1, len_segment)`` (presumably each value
            is ``(features, len_segment)`` - TODO confirm)

        Returns
        -------
        ssl_input : torch.Tensor
            a scalar NaN placeholder
        ssl_target : dict
            ``{"loaded": mask}``, where ``mask`` is a float tensor of shape
            ``(1, len_segment)`` equal to 1 where at least one value has a non-zero
            sum over its first dimension and 0 elsewhere (all zeros if
            ``sample_data`` is empty)
        """
        # Accumulate "every value sums to zero here" as a product of 0/1
        # indicators, then invert it into "at least one value is non-zero here".
        mask = torch.ones((1, self.len_segment))
        for value in sample_data.values():
            mask *= (torch.sum(value, 0) == 0).unsqueeze(0)
        mask = 1 - mask
        return torch.tensor(float("nan")), {"loaded": mask}

    def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
        """
        Calculate the SSL loss with ``self.loss_function`` (a ``_TCCLoss`` instance)

        Parameters
        ----------
        predicted : torch.Tensor
            output of the SSL module
        target : torch.Tensor
            SSL target passed through unchanged to ``self.loss_function``

        Returns
        -------
        loss : float
            whatever ``self.loss_function`` returns
        """
        return self.loss_function(predicted, target)

    def construct_module(self) -> Union[nn.Module, None]:
        """
        Construct the SSL module

        Returns
        -------
        ssl_module : torch.nn.Module
            ``nn.Identity()`` if ``self.pars["num_ssl_f_maps"]`` is None, otherwise
            an ``_FC`` projection head built from ``self.pars``
        """
        # NOTE(review): ``__init__`` always stores an int in "num_ssl_f_maps", so
        # the identity branch is only reachable if ``self.pars`` is edited later.
        if self.pars["num_ssl_f_maps"] is None:
            return nn.Identity()
        return _FC(**self.pars)

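An SSL class with a TCC (temporal cycle-consistency) loss

The SSL input is a NaN placeholder and the SSL target is a dictionary with a "loaded" mask that marks the positions where at least one input has a non-zero sum over its first dimension; the loss itself is computed by _TCCLoss.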

TCCSSL( num_f_maps: torch.Size, len_segment: int, projection_head_f_maps: int = None, loss_type: str = 'regression_mse_var', variance_lambda: float = 0.001, normalize_indices: bool = True, normalize_embeddings: bool = False, similarity_type: str = 'l2', num_cycles: int = 20, cycle_length: int = 2, temperature: float = 0.1, label_smoothing: float = 0.1)
def __init__(
    self,
    num_f_maps: torch.Size,
    len_segment: int,
    projection_head_f_maps: Union[int, None] = None,
    loss_type: str = "regression_mse_var",
    variance_lambda: float = 0.001,
    normalize_indices: bool = True,
    normalize_embeddings: bool = False,
    similarity_type: str = "l2",
    num_cycles: int = 20,
    cycle_length: int = 2,
    temperature: float = 0.1,
    label_smoothing: float = 0.1,
) -> None:
    """
    Parameters
    ----------
    num_f_maps : torch.Size
        shape of the features the SSL module receives, excluding one dimension
        (presumably time - TODO confirm); must have exactly one element, i.e. the
        input data must be 2-dimensional
    len_segment : int
        segment length; the size of the mask built by ``transformation``
    projection_head_f_maps : int, optional
        number of output features of the ``_FC`` projection head; defaults to
        ``num_f_maps``
    loss_type, variance_lambda, normalize_indices, normalize_embeddings,
    similarity_type, num_cycles, cycle_length, temperature, label_smoothing
        forwarded positionally to ``_TCCLoss`` (``num_cycles`` and
        ``cycle_length`` are cast to ``int`` first)

    Raises
    ------
    RuntimeError
        if ``num_f_maps`` does not have exactly one element
    """
    super().__init__()
    # ``num_f_maps`` must hold exactly one size; an empty size would otherwise
    # fail at ``num_f_maps[0]`` with an opaque IndexError.
    if len(num_f_maps) != 1:
        raise RuntimeError(
            "The TCC constructor expects the input data to be 2-dimensional; "
            f"got {len(num_f_maps) + 1} dimensions"
        )
    num_f_maps = int(num_f_maps[0])
    if projection_head_f_maps is None:
        projection_head_f_maps = num_f_maps
    self.len_segment = int(len_segment)
    self.loss_function = _TCCLoss(
        loss_type,
        variance_lambda,
        normalize_indices,
        normalize_embeddings,
        similarity_type,
        int(num_cycles),
        int(cycle_length),
        temperature,
        label_smoothing,
    )
    # Keyword arguments later passed to the ``_FC`` projection head;
    # ``num_ssl_f_maps`` is always an int here.
    self.pars = {
        "dim": num_f_maps,
        "num_f_maps": num_f_maps,
        "num_ssl_layers": 1,
        "num_ssl_f_maps": int(projection_head_f_maps),
    }
type = 'ssl_target'

The type parameter defines interaction with the model:

  • 'ssl_input': a modification of the input data passes through the base network feature extraction module and the SSL module; it is returned as SSL output and compared to SSL target (or, if it is None, to the input data),
  • 'ssl_output': the input data passes through the base network feature extraction module and the SSL module; it is returned as SSL output and compared to SSL target (or, if it is None, to the input data),
  • 'contrastive': the input data and its modification pass through the base network feature extraction module and the SSL module; an (input results, modification results) tuple is returned as SSL output,
  • 'contrastive_2layers': the input data and its modification pass through the base network feature extraction module; the output of the second feature extraction layer for the modified data goes through an SSL module and then, optionally, that result and the first-level unmodified features pass through another transformation; an (input results, modified results) tuple is returned as SSL output,
def transformation(self, sample_data: Dict) -> Tuple:
def transformation(self, sample_data: Dict) -> Tuple:
    """
    Build the SSL input and target for a sample

    Parameters
    ----------
    sample_data : dict
        the sample's data; each value is summed over its first dimension and the
        result must broadcast to ``(1, len_segment)`` (presumably each value is
        ``(features, len_segment)`` - TODO confirm)

    Returns
    -------
    ssl_input : torch.Tensor
        a scalar NaN placeholder
    ssl_target : dict
        ``{"loaded": mask}``, where ``mask`` is a float tensor of shape
        ``(1, len_segment)`` equal to 1 where at least one value has a non-zero
        sum over its first dimension and 0 elsewhere (all zeros if
        ``sample_data`` is empty)
    """
    # Accumulate "every value sums to zero here" as a product of 0/1 indicators,
    # then invert it into "at least one value is non-zero here".
    mask = torch.ones((1, self.len_segment))
    for value in sample_data.values():
        mask *= (torch.sum(value, 0) == 0).unsqueeze(0)
    mask = 1 - mask
    return torch.tensor(float("nan")), {"loaded": mask}

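Build the SSL input (a NaN placeholder) and the SSL target (a dictionary with a "loaded" mask of the positions where at least one input has a non-zero sum over its first dimension).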

def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
    """
    Compute the SSL loss by delegating to ``self.loss_function``

    Parameters
    ----------
    predicted : torch.Tensor
        output of the SSL module
    target : torch.Tensor
        SSL target, passed through unchanged

    Returns
    -------
    loss : float
        whatever ``self.loss_function(predicted, target)`` returns
    """
    return self.loss_function(predicted, target)

Calculate the SSL loss

Parameters

predicted : torch.Tensor
    output of the SSL module
target : torch.Tensor
    augmented and stacked SSL target

Returns

loss : float
    the loss value

def construct_module(self) -> Optional[torch.nn.modules.module.Module]:
def construct_module(self) -> Union[nn.Module, None]:
    """
    Construct the SSL module

    Returns
    -------
    ssl_module : torch.nn.Module
        an ``_FC`` projection head built from ``self.pars``, or ``nn.Identity()``
        when ``self.pars["num_ssl_f_maps"]`` is None
    """
    # NOTE(review): the constructor always stores an int in "num_ssl_f_maps", so
    # the identity fallback only triggers if ``self.pars`` is edited afterwards.
    wants_projection = self.pars["num_ssl_f_maps"] is not None
    return _FC(**self.pars) if wants_projection else nn.Identity()

Construct the SSL module

Returns

ssl_module : torch.nn.Module
    a neural net module that takes features extracted by a model's feature extractor as input and returns SSL output