dlc2action.ssl.tcc
Temporal Cycle Consistency SSL constructor.
#
# Copyright 2020-present by A. Mathis Group and contributors. All rights reserved.
#
# This project and all its files are licensed under GNU AGPLv3 or later version.
# A copy is included in dlc2action/LICENSE.AGPL.
#
"""Temporal Cycle Consistency SSL constructor."""

from dlc2action.ssl.base_ssl import SSLConstructor
import torch
from dlc2action.loss.tcc import TCCLoss
from typing import Dict, Union, Tuple
from torch import nn
from dlc2action.ssl.modules import FC


class TCCSSL(SSLConstructor):
    """Temporal Cycle Consistency SSL constructor.

    Bundles three things: a `TCCLoss` instance, a transformation that
    produces a per-position "loaded" mask, and the parameters of the `FC`
    module built by `construct_module`.
    """

    # Defines how this constructor interacts with the model.
    type = "ssl_target"

    def __init__(
        self,
        num_f_maps: torch.Size,
        len_segment: int,
        projection_head_f_maps: Union[int, None] = None,
        variance_lambda: float = 0.001,
        normalize_indices: bool = True,
        normalize_embeddings: bool = False,
        similarity_type: str = "l2",
        num_cycles: int = 20,
        cycle_length: int = 2,
        temperature: float = 0.1,
        label_smoothing: float = 0.1,
    ) -> None:
        """Initialize the constructor.

        Parameters
        ----------
        num_f_maps : torch.Size
            Feature shape; it must hold a single entry, which is used as the
            number of feature maps.
        len_segment : int
            Length of the mask produced by `transformation`.
        projection_head_f_maps : int, optional
            Passed to `FC` as `num_ssl_f_maps`; defaults to the number of
            feature maps when `None`.
        variance_lambda, normalize_indices, normalize_embeddings, \
similarity_type, num_cycles, cycle_length, temperature, label_smoothing
            Forwarded (in this order) to `TCCLoss`; see that class for
            their meaning.

        Raises
        ------
        RuntimeError
            If `num_f_maps` has more than one entry.

        """
        super().__init__()
        if len(num_f_maps) > 1:
            raise RuntimeError(
                "The TCC constructor expects the input data to be 2-dimensional; "
                f"got {len(num_f_maps) + 1} dimensions"
            )
        num_f_maps = int(num_f_maps[0])
        if projection_head_f_maps is None:
            projection_head_f_maps = num_f_maps
        self.len_segment = int(len_segment)
        self.loss_function = TCCLoss(
            "regression_mse_var",
            variance_lambda,
            normalize_indices,
            normalize_embeddings,
            similarity_type,
            int(num_cycles),
            int(cycle_length),
            temperature,
            label_smoothing,
        )
        # Keyword arguments for `FC`; `num_f_maps` is already an int here.
        self.pars = {
            "dim": num_f_maps,
            "num_f_maps": num_f_maps,
            "num_ssl_layers": 1,
            "num_ssl_f_maps": int(projection_head_f_maps),
        }

    def transformation(self, sample_data: Dict) -> Tuple:
        """Get the mask.

        Parameters
        ----------
        sample_data : dict
            Tensors of one sample; only the values are used.

        Returns
        -------
        ssl_target : torch.Tensor
            A scalar NaN placeholder.
        extra : dict
            `{"loaded": mask}` where `mask` has shape `(1, len_segment)`. It
            is 0 at positions where every tensor in `sample_data` sums to
            exactly zero over its first axis and 1 everywhere else. An empty
            `sample_data` therefore yields an all-zero mask.

        """
        mask = torch.ones((1, self.len_segment))
        for value in sample_data.values():
            # Keep a 1 only where this tensor sums to zero as well.
            mask *= (torch.sum(value, 0) == 0).unsqueeze(0)
        # Invert: 1 now marks positions where at least one tensor is non-zero.
        mask = 1 - mask
        return torch.tensor(float("nan")), {"loaded": mask}

    def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
        """TCC loss.

        Returns whatever the `TCCLoss` instance created at initialization
        returns for `(predicted, target)`.
        """
        loss = self.loss_function(predicted, target)
        return loss

    def construct_module(self) -> Union[nn.Module, None]:
        """Construct the SSL prediction module using the parameters specified at initialization.

        Returns an `FC` module built from `self.pars`, or `nn.Identity` if
        `self.pars["num_ssl_f_maps"]` is `None`.
        """
        # NOTE(review): `__init__` always stores an int here, so the identity
        # branch is only reachable if `pars` is modified afterwards.
        if self.pars["num_ssl_f_maps"] is None:
            module = nn.Identity()
        else:
            module = FC(**self.pars)
        return module
class TCCSSL(SSLConstructor):
    """Temporal Cycle Consistency SSL constructor.

    Bundles three things: a `TCCLoss` instance, a transformation that
    produces a per-position "loaded" mask, and the parameters of the `FC`
    module built by `construct_module`.
    """

    # Defines how this constructor interacts with the model.
    type = "ssl_target"

    def __init__(
        self,
        num_f_maps: torch.Size,
        len_segment: int,
        projection_head_f_maps: Union[int, None] = None,
        variance_lambda: float = 0.001,
        normalize_indices: bool = True,
        normalize_embeddings: bool = False,
        similarity_type: str = "l2",
        num_cycles: int = 20,
        cycle_length: int = 2,
        temperature: float = 0.1,
        label_smoothing: float = 0.1,
    ) -> None:
        """Initialize the constructor.

        Parameters
        ----------
        num_f_maps : torch.Size
            Feature shape; it must hold a single entry, which is used as the
            number of feature maps.
        len_segment : int
            Length of the mask produced by `transformation`.
        projection_head_f_maps : int, optional
            Passed to `FC` as `num_ssl_f_maps`; defaults to the number of
            feature maps when `None`.
        variance_lambda, normalize_indices, normalize_embeddings, \
similarity_type, num_cycles, cycle_length, temperature, label_smoothing
            Forwarded (in this order) to `TCCLoss`; see that class for
            their meaning.

        Raises
        ------
        RuntimeError
            If `num_f_maps` has more than one entry.

        """
        super().__init__()
        if len(num_f_maps) > 1:
            raise RuntimeError(
                "The TCC constructor expects the input data to be 2-dimensional; "
                f"got {len(num_f_maps) + 1} dimensions"
            )
        num_f_maps = int(num_f_maps[0])
        if projection_head_f_maps is None:
            projection_head_f_maps = num_f_maps
        self.len_segment = int(len_segment)
        self.loss_function = TCCLoss(
            "regression_mse_var",
            variance_lambda,
            normalize_indices,
            normalize_embeddings,
            similarity_type,
            int(num_cycles),
            int(cycle_length),
            temperature,
            label_smoothing,
        )
        # Keyword arguments for `FC`; `num_f_maps` is already an int here.
        self.pars = {
            "dim": num_f_maps,
            "num_f_maps": num_f_maps,
            "num_ssl_layers": 1,
            "num_ssl_f_maps": int(projection_head_f_maps),
        }

    def transformation(self, sample_data: Dict) -> Tuple:
        """Get the mask.

        Parameters
        ----------
        sample_data : dict
            Tensors of one sample; only the values are used.

        Returns
        -------
        ssl_target : torch.Tensor
            A scalar NaN placeholder.
        extra : dict
            `{"loaded": mask}` where `mask` has shape `(1, len_segment)`. It
            is 0 at positions where every tensor in `sample_data` sums to
            exactly zero over its first axis and 1 everywhere else. An empty
            `sample_data` therefore yields an all-zero mask.

        """
        mask = torch.ones((1, self.len_segment))
        for value in sample_data.values():
            # Keep a 1 only where this tensor sums to zero as well.
            mask *= (torch.sum(value, 0) == 0).unsqueeze(0)
        # Invert: 1 now marks positions where at least one tensor is non-zero.
        mask = 1 - mask
        return torch.tensor(float("nan")), {"loaded": mask}

    def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
        """TCC loss.

        Returns whatever the `TCCLoss` instance created at initialization
        returns for `(predicted, target)`.
        """
        loss = self.loss_function(predicted, target)
        return loss

    def construct_module(self) -> Union[nn.Module, None]:
        """Construct the SSL prediction module using the parameters specified at initialization.

        Returns an `FC` module built from `self.pars`, or `nn.Identity` if
        `self.pars["num_ssl_f_maps"]` is `None`.
        """
        # NOTE(review): `__init__` always stores an int here, so the identity
        # branch is only reachable if `pars` is modified afterwards.
        if self.pars["num_ssl_f_maps"] is None:
            module = nn.Identity()
        else:
            module = FC(**self.pars)
        return module
Temporal Cycle Consistency SSL constructor.
TCCSSL( num_f_maps: torch.Size, len_segment: int, projection_head_f_maps: int = None, variance_lambda: float = 0.001, normalize_indices: bool = True, normalize_embeddings: bool = False, similarity_type: str = 'l2', num_cycles: int = 20, cycle_length: int = 2, temperature: float = 0.1, label_smoothing: float = 0.1)
def __init__(
    self,
    num_f_maps: torch.Size,
    len_segment: int,
    projection_head_f_maps: Union[int, None] = None,
    variance_lambda: float = 0.001,
    normalize_indices: bool = True,
    normalize_embeddings: bool = False,
    similarity_type: str = "l2",
    num_cycles: int = 20,
    cycle_length: int = 2,
    temperature: float = 0.1,
    label_smoothing: float = 0.1,
) -> None:
    """Initialize the constructor.

    Parameters
    ----------
    num_f_maps : torch.Size
        Feature shape; it must hold a single entry, which is used as the
        number of feature maps.
    len_segment : int
        Stored (as an int) in `self.len_segment`.
    projection_head_f_maps : int, optional
        Stored in `self.pars` as `num_ssl_f_maps`; defaults to the number of
        feature maps when `None`.
    variance_lambda, normalize_indices, normalize_embeddings, \
similarity_type, num_cycles, cycle_length, temperature, label_smoothing
        Forwarded (in this order) to `TCCLoss`; see that class for their
        meaning.

    Raises
    ------
    RuntimeError
        If `num_f_maps` has more than one entry.

    """
    super().__init__()
    if len(num_f_maps) > 1:
        raise RuntimeError(
            "The TCC constructor expects the input data to be 2-dimensional; "
            f"got {len(num_f_maps) + 1} dimensions"
        )
    num_f_maps = int(num_f_maps[0])
    if projection_head_f_maps is None:
        projection_head_f_maps = num_f_maps
    self.len_segment = int(len_segment)
    self.loss_function = TCCLoss(
        "regression_mse_var",
        variance_lambda,
        normalize_indices,
        normalize_embeddings,
        similarity_type,
        int(num_cycles),
        int(cycle_length),
        temperature,
        label_smoothing,
    )
    # `num_f_maps` is already an int here, so no second cast is needed.
    self.pars = {
        "dim": num_f_maps,
        "num_f_maps": num_f_maps,
        "num_ssl_layers": 1,
        "num_ssl_f_maps": int(projection_head_f_maps),
    }
Initialize the constructor.
type =
'ssl_target'
The type parameter defines interaction with the model:
- 'ssl_input': a modification of the input data passes through the base network feature extraction module and the SSL module; it is returned as SSL output and compared to SSL target (or, if it is None, to the input data);
- 'ssl_output': the input data passes through the base network feature extraction module and the SSL module; it is returned as SSL output and compared to SSL target (or, if it is None, to the input data);
- 'contrastive': the input data and its modification pass through the base network feature extraction module and the SSL module; an (input results, modification results) tuple is returned as SSL output;
- 'contrastive_2layers': the input data and its modification pass through the base network feature extraction module; the output of the second feature extraction layer for the modified data goes through an SSL module and then, optionally, that result and the first-level unmodified features pass another transformation; an (input results, modified results) tuple is returned as SSL output.
def
transformation(self, sample_data: Dict) -> Tuple:
def transformation(self, sample_data: Dict) -> Tuple:
    """Get the mask.

    Parameters
    ----------
    sample_data : dict
        Tensors of one sample; only the values are used.

    Returns
    -------
    ssl_target : torch.Tensor
        A scalar NaN placeholder.
    extra : dict
        `{"loaded": mask}` where `mask` has shape `(1, len_segment)`. It is
        0 at positions where every tensor in `sample_data` sums to exactly
        zero over its first axis and 1 everywhere else. An empty
        `sample_data` therefore yields an all-zero mask.

    """
    mask = torch.ones((1, self.len_segment))
    for value in sample_data.values():
        # Keep a 1 only where this tensor sums to zero as well.
        mask *= (torch.sum(value, 0) == 0).unsqueeze(0)
    # Invert: 1 now marks positions where at least one tensor is non-zero.
    mask = 1 - mask
    return torch.tensor(float("nan")), {"loaded": mask}
Get the mask.
def
loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
    """Compute the TCC loss by delegating to `self.loss_function`."""
    return self.loss_function(predicted, target)
TCC loss.
def
construct_module(self) -> Optional[torch.nn.modules.module.Module]:
def construct_module(self) -> Union[nn.Module, None]:
    """Build the SSL prediction module from the parameters stored at initialization.

    Returns an `FC` module created from `self.pars`, or `nn.Identity` when
    `self.pars["num_ssl_f_maps"]` is `None`.
    """
    if self.pars["num_ssl_f_maps"] is not None:
        return FC(**self.pars)
    return nn.Identity()
Construct the SSL prediction module using the parameters specified at initialization.