dlc2action.ssl.tcc

Temporal Cycle Consistency SSL constructor.

 1#
 2# Copyright 2020-present by A. Mathis Group and contributors. All rights reserved.
 3#
 4# This project and all its files are licensed under GNU AGPLv3 or later version. 
 5# A copy is included in dlc2action/LICENSE.AGPL.
 6#
 7"""Temporal Cycle Consistency SSL constructor."""
 8
 9from dlc2action.ssl.base_ssl import SSLConstructor
10import torch
11from dlc2action.loss.tcc import TCCLoss
12from typing import Dict, Union, Tuple
13from torch import nn
14from dlc2action.ssl.modules import FC
15
class TCCSSL(SSLConstructor):
    """Temporal Cycle Consistency SSL constructor.

    Holds a `TCCLoss` instance, builds a per-sample "loaded" mask in
    `transformation` and creates the SSL module in `construct_module`.
    """

    type = "ssl_target"

    def __init__(
        self,
        num_f_maps: torch.Size,
        len_segment: int,
        projection_head_f_maps: Union[int, None] = None,
        variance_lambda: float = 0.001,
        normalize_indices: bool = True,
        normalize_embeddings: bool = False,
        similarity_type: str = "l2",
        num_cycles: int = 20,
        cycle_length: int = 2,
        temperature: float = 0.1,
        label_smoothing: float = 0.1,
    ) -> None:
        """Initialize the constructor.

        Parameters
        ----------
        num_f_maps : torch.Size
            shape of the features the SSL module receives; must have exactly one element
        len_segment : int
            length of the mask created by `transformation`
        projection_head_f_maps : int, optional
            value stored as `"num_ssl_f_maps"` in `self.pars`; if `None`, `num_f_maps[0]` is used
        variance_lambda, normalize_indices, normalize_embeddings, similarity_type, temperature, label_smoothing
            passed to `TCCLoss` unchanged (see `TCCLoss` for their meaning)
        num_cycles, cycle_length : int
            passed to `TCCLoss` after conversion to `int`

        Raises
        ------
        RuntimeError
            if `num_f_maps` does not have exactly one element

        """
        super().__init__()
        # An empty shape used to slip past a `> 1` check and fail on indexing below
        if len(num_f_maps) != 1:
            raise RuntimeError(
                "The TCC constructor expects the input data to be 2-dimensional; "
                f"got {len(num_f_maps) + 1} dimensions"
            )
        num_f_maps = int(num_f_maps[0])
        if projection_head_f_maps is None:
            projection_head_f_maps = num_f_maps
        self.len_segment = int(len_segment)
        self.loss_function = TCCLoss(
            "regression_mse_var",
            variance_lambda,
            normalize_indices,
            normalize_embeddings,
            similarity_type,
            int(num_cycles),
            int(cycle_length),
            temperature,
            label_smoothing,
        )
        # Keyword arguments for `FC` in `construct_module`
        self.pars = {
            "dim": num_f_maps,
            "num_f_maps": num_f_maps,
            "num_ssl_layers": 1,
            "num_ssl_f_maps": int(projection_head_f_maps),
        }

    def transformation(self, sample_data: Dict) -> Tuple:
        """Get the mask.

        Parameters
        ----------
        sample_data : dict
            a dictionary of feature tensors; each one is summed over dimension 0
            and the result has to be compatible with a `(1, len_segment)` mask

        Returns
        -------
        ssl_input : torch.Tensor
            a scalar NaN tensor (a placeholder)
        ssl_target : dict
            a dictionary with a single `"loaded"` key; the value is a `(1, len_segment)`
            tensor that is 0 where every tensor in `sample_data` sums to zero over
            dimension 0 and 1 everywhere else

        """
        # Stays 1 only at positions where all tensors sum to zero
        empty = torch.ones((1, self.len_segment))
        for value in sample_data.values():
            empty *= (torch.sum(value, 0) == 0).unsqueeze(0)
        loaded = 1 - empty
        return torch.tensor(float("nan")), {"loaded": loaded}

    def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
        """TCC loss.

        Calls the `TCCLoss` instance created at initialization with
        `(predicted, target)` and returns its result.
        """
        return self.loss_function(predicted, target)

    def construct_module(self) -> Union[nn.Module, None]:
        """Construct the SSL prediction module using the parameters specified at initialization.

        Returns
        -------
        module : nn.Module
            an `FC` module created from `self.pars`, or `nn.Identity` if
            `self.pars["num_ssl_f_maps"]` is `None`

        """
        # NOTE: `__init__` always stores an int under "num_ssl_f_maps", so the identity
        # branch is only taken if `self.pars` is modified after initialization
        if self.pars["num_ssl_f_maps"] is None:
            return nn.Identity()
        return FC(**self.pars)
class TCCSSL(dlc2action.ssl.base_ssl.SSLConstructor):
class TCCSSL(SSLConstructor):
    """Temporal Cycle Consistency SSL constructor.

    Holds a `TCCLoss` instance, builds a per-sample "loaded" mask in
    `transformation` and creates the SSL module in `construct_module`.
    """

    type = "ssl_target"

    def __init__(
        self,
        num_f_maps: torch.Size,
        len_segment: int,
        projection_head_f_maps: Union[int, None] = None,
        variance_lambda: float = 0.001,
        normalize_indices: bool = True,
        normalize_embeddings: bool = False,
        similarity_type: str = "l2",
        num_cycles: int = 20,
        cycle_length: int = 2,
        temperature: float = 0.1,
        label_smoothing: float = 0.1,
    ) -> None:
        """Initialize the constructor.

        Parameters
        ----------
        num_f_maps : torch.Size
            shape of the features the SSL module receives; must have exactly one element
        len_segment : int
            length of the mask created by `transformation`
        projection_head_f_maps : int, optional
            value stored as `"num_ssl_f_maps"` in `self.pars`; if `None`, `num_f_maps[0]` is used
        variance_lambda, normalize_indices, normalize_embeddings, similarity_type, temperature, label_smoothing
            passed to `TCCLoss` unchanged (see `TCCLoss` for their meaning)
        num_cycles, cycle_length : int
            passed to `TCCLoss` after conversion to `int`

        Raises
        ------
        RuntimeError
            if `num_f_maps` does not have exactly one element

        """
        super().__init__()
        # An empty shape used to slip past a `> 1` check and fail on indexing below
        if len(num_f_maps) != 1:
            raise RuntimeError(
                "The TCC constructor expects the input data to be 2-dimensional; "
                f"got {len(num_f_maps) + 1} dimensions"
            )
        num_f_maps = int(num_f_maps[0])
        if projection_head_f_maps is None:
            projection_head_f_maps = num_f_maps
        self.len_segment = int(len_segment)
        self.loss_function = TCCLoss(
            "regression_mse_var",
            variance_lambda,
            normalize_indices,
            normalize_embeddings,
            similarity_type,
            int(num_cycles),
            int(cycle_length),
            temperature,
            label_smoothing,
        )
        # Keyword arguments for `FC` in `construct_module`
        self.pars = {
            "dim": num_f_maps,
            "num_f_maps": num_f_maps,
            "num_ssl_layers": 1,
            "num_ssl_f_maps": int(projection_head_f_maps),
        }

    def transformation(self, sample_data: Dict) -> Tuple:
        """Get the mask.

        Parameters
        ----------
        sample_data : dict
            a dictionary of feature tensors; each one is summed over dimension 0
            and the result has to be compatible with a `(1, len_segment)` mask

        Returns
        -------
        ssl_input : torch.Tensor
            a scalar NaN tensor (a placeholder)
        ssl_target : dict
            a dictionary with a single `"loaded"` key; the value is a `(1, len_segment)`
            tensor that is 0 where every tensor in `sample_data` sums to zero over
            dimension 0 and 1 everywhere else

        """
        # Stays 1 only at positions where all tensors sum to zero
        empty = torch.ones((1, self.len_segment))
        for value in sample_data.values():
            empty *= (torch.sum(value, 0) == 0).unsqueeze(0)
        loaded = 1 - empty
        return torch.tensor(float("nan")), {"loaded": loaded}

    def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
        """TCC loss.

        Calls the `TCCLoss` instance created at initialization with
        `(predicted, target)` and returns its result.
        """
        return self.loss_function(predicted, target)

    def construct_module(self) -> Union[nn.Module, None]:
        """Construct the SSL prediction module using the parameters specified at initialization.

        Returns
        -------
        module : nn.Module
            an `FC` module created from `self.pars`, or `nn.Identity` if
            `self.pars["num_ssl_f_maps"]` is `None`

        """
        # NOTE: `__init__` always stores an int under "num_ssl_f_maps", so the identity
        # branch is only taken if `self.pars` is modified after initialization
        if self.pars["num_ssl_f_maps"] is None:
            return nn.Identity()
        return FC(**self.pars)

Temporal Cycle Consistency SSL constructor.

TCCSSL( num_f_maps: torch.Size, len_segment: int, projection_head_f_maps: int = None, variance_lambda: float = 0.001, normalize_indices: bool = True, normalize_embeddings: bool = False, similarity_type: str = 'l2', num_cycles: int = 20, cycle_length: int = 2, temperature: float = 0.1, label_smoothing: float = 0.1)
22    def __init__(
23        self,
24        num_f_maps: torch.Size,
25        len_segment: int,
26        projection_head_f_maps: int = None,
27        variance_lambda: float = 0.001,
28        normalize_indices: bool = True,
29        normalize_embeddings: bool = False,
30        similarity_type: str = "l2",
31        num_cycles: int = 20,
32        cycle_length: int = 2,
33        temperature: float = 0.1,
34        label_smoothing: float = 0.1,
35    ) -> None:
36        """Initialize the constructor."""
37        super().__init__()
38        if len(num_f_maps) > 1:
39            raise RuntimeError(
40                "The TCC constructor expects the input data to be 2-dimensional; "
41                f"got {len(num_f_maps) + 1} dimensions"
42            )
43        num_f_maps = int(num_f_maps[0])
44        if projection_head_f_maps is None:
45            projection_head_f_maps = num_f_maps
46        self.len_segment = int(len_segment)
47        self.loss_function = TCCLoss(
48            "regression_mse_var",
49            variance_lambda,
50            normalize_indices,
51            normalize_embeddings,
52            similarity_type,
53            int(num_cycles),
54            int(cycle_length),
55            temperature,
56            label_smoothing,
57        )
58        self.pars = {
59            "dim": int(num_f_maps),
60            "num_f_maps": int(num_f_maps),
61            "num_ssl_layers": 1,
62            "num_ssl_f_maps": int(projection_head_f_maps),
63        }

Initialize the constructor.

type = 'ssl_target'

The type parameter defines interaction with the model:

  • 'ssl_input': a modification of the input data passes through the base network feature extraction module and the SSL module; it is returned as SSL output and compared to SSL target (or, if it is None, to the input data),
  • 'ssl_output': the input data passes through the base network feature extraction module and the SSL module; it is returned as SSL output and compared to SSL target (or, if it is None, to the input data),
  • 'contrastive': the input data and its modification pass through the base network feature extraction module and the SSL module; an (input results, modification results) tuple is returned as SSL output,
  • 'contrastive_2layers': the input data and its modification pass through the base network feature extraction module; the output of the second feature extraction layer for the modified data goes through an SSL module and then, optionally, that result and the first-level unmodified features pass another transformation; an (input results, modified results) tuple is returned as SSL output.
len_segment
loss_function
pars
def transformation(self, sample_data: Dict) -> Tuple:
65    def transformation(self, sample_data: Dict) -> Tuple:
66        """Get the mask."""
67        mask = torch.ones((1, self.len_segment))
68        for key, value in sample_data.items():
69            mask *= (torch.sum(value, 0) == 0).unsqueeze(0)
70        mask = 1 - mask
71        return torch.tensor(float("nan")), {"loaded": mask}

Get the mask.

def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
73    def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
74        """TCC loss."""
75        loss = self.loss_function(predicted, target)
76        return loss

TCC loss.

def construct_module(self) -> Optional[torch.nn.modules.module.Module]:
78    def construct_module(self) -> Union[nn.Module, None]:
79        """Construct the SSL prediction module using the parameters specified at initialization."""
80        if self.pars["num_ssl_f_maps"] is None:
81            module = nn.Identity()
82        else:
83            module = FC(**self.pars)
84        return module

Construct the SSL prediction module using the parameters specified at initialization.