dlc2action.ssl.tcc
#
# Copyright 2020-2022 by A. Mathis Group and contributors. All rights reserved.
#
# This project and all its files are licensed under GNU AGPLv3 or later version. A copy is included in dlc2action/LICENSE.AGPL.
#
from dlc2action.ssl.base_ssl import SSLConstructor
import torch
from dlc2action.loss.tcc import _TCCLoss
from typing import Dict, Union, Tuple
from torch import nn
from dlc2action.ssl.modules import _FC


class TCCSSL(SSLConstructor):
    """
    An SSL class with a temporal cycle-consistency (TCC) loss.

    No separate SSL input is generated: `transformation` returns a NaN placeholder as
    the SSL input and, as the SSL target, a `{"loaded": mask}` dictionary where the
    mask marks the frames of the segment that contain data (are not all-zero).
    """

    type = "ssl_target"

    def __init__(
        self,
        num_f_maps: torch.Size,
        len_segment: int,
        projection_head_f_maps: Union[int, None] = None,
        loss_type: str = "regression_mse_var",
        variance_lambda: float = 0.001,
        normalize_indices: bool = True,
        normalize_embeddings: bool = False,
        similarity_type: str = "l2",
        num_cycles: int = 20,
        cycle_length: int = 2,
        temperature: float = 0.1,
        label_smoothing: float = 0.1,
    ) -> None:
        """
        Parameters
        ----------
        num_f_maps : torch.Size
            shape of the extracted features without one dimension (presumably the
            time dimension — TODO confirm); must contain exactly one element
        len_segment : int
            segment length; used as the size of the mask built in `transformation`
        projection_head_f_maps : int, optional
            number of feature maps of the `_FC` projection head; defaults to the
            single element of `num_f_maps`
        loss_type, variance_lambda, normalize_indices, normalize_embeddings,
        similarity_type, num_cycles, cycle_length, temperature, label_smoothing
            passed positionally to `_TCCLoss` (`num_cycles` and `cycle_length` are
            cast to int first)

        Raises
        ------
        RuntimeError
            if `num_f_maps` does not contain exactly one element
        """
        super().__init__()
        # `num_f_maps` omits one dimension of the data, hence the `+ 1` in the message;
        # `!= 1` (not `> 1`) also rejects an empty size instead of failing on `[0]` below
        if len(num_f_maps) != 1:
            raise RuntimeError(
                "The TCC constructor expects the input data to be 2-dimensional; "
                f"got {len(num_f_maps) + 1} dimensions"
            )
        num_f_maps = int(num_f_maps[0])
        if projection_head_f_maps is None:
            projection_head_f_maps = num_f_maps
        self.len_segment = int(len_segment)
        self.loss_function = _TCCLoss(
            loss_type,
            variance_lambda,
            normalize_indices,
            normalize_embeddings,
            similarity_type,
            int(num_cycles),
            int(cycle_length),
            temperature,
            label_smoothing,
        )
        # keyword arguments for the `_FC` module built in `construct_module`
        self.pars = {
            "dim": num_f_maps,
            "num_f_maps": num_f_maps,
            "num_ssl_layers": 1,
            "num_ssl_f_maps": int(projection_head_f_maps),
        }

    def transformation(self, sample_data: Dict) -> Tuple:
        """
        Generate a placeholder SSL input and a mask of loaded frames as the SSL target.

        Parameters
        ----------
        sample_data : dict
            the input sample; every value is a tensor that is reduced over
            dimension 0, and the result must broadcast to ``(1, len_segment)``
            (presumably values are ``(features, len_segment)`` — TODO confirm)

        Returns
        -------
        ssl_input : torch.Tensor
            a scalar NaN tensor (no separate SSL input is used)
        ssl_target : dict
            ``{"loaded": mask}`` where ``mask`` is a float tensor of shape
            ``(1, len_segment)``: 1 for frames where some value is non-zero,
            0 for frames that are all-zero in every entry of `sample_data`
        """
        empty = torch.ones((1, self.len_segment))
        for value in sample_data.values():
            # a frame is empty only if all of its features are exactly zero; testing
            # the sum would misclassify frames whose non-zero features cancel out
            empty *= (value == 0).all(dim=0).unsqueeze(0)
        loaded = 1 - empty
        return torch.tensor(float("nan")), {"loaded": loaded}

    def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
        """
        Calculate the SSL loss.

        Parameters
        ----------
        predicted : torch.Tensor
            output of the SSL module
        target : torch.Tensor
            SSL target

        Returns
        -------
        loss : float
            the value returned by the `_TCCLoss` instance built in `__init__`
        """
        return self.loss_function(predicted, target)

    def construct_module(self) -> Union[nn.Module, None]:
        """
        Construct the SSL module.

        Returns
        -------
        ssl_module : torch.nn.Module
            an `_FC` module built from `self.pars`, or `nn.Identity` when
            `self.pars["num_ssl_f_maps"]` is None
        """
        # NOTE: `__init__` always stores an int under "num_ssl_f_maps", so the identity
        # branch is only reached if `pars` is modified after construction
        if self.pars["num_ssl_f_maps"] is None:
            return nn.Identity()
        return _FC(**self.pars)
class TCCSSL(SSLConstructor):
    """
    An SSL class with a temporal cycle-consistency (TCC) loss.

    No separate SSL input is generated: `transformation` returns a NaN placeholder as
    the SSL input and, as the SSL target, a `{"loaded": mask}` dictionary where the
    mask marks the frames of the segment that contain data (are not all-zero).
    """

    type = "ssl_target"

    def __init__(
        self,
        num_f_maps: torch.Size,
        len_segment: int,
        projection_head_f_maps: Union[int, None] = None,
        loss_type: str = "regression_mse_var",
        variance_lambda: float = 0.001,
        normalize_indices: bool = True,
        normalize_embeddings: bool = False,
        similarity_type: str = "l2",
        num_cycles: int = 20,
        cycle_length: int = 2,
        temperature: float = 0.1,
        label_smoothing: float = 0.1,
    ) -> None:
        """
        Parameters
        ----------
        num_f_maps : torch.Size
            shape of the extracted features without one dimension (presumably the
            time dimension — TODO confirm); must contain exactly one element
        len_segment : int
            segment length; used as the size of the mask built in `transformation`
        projection_head_f_maps : int, optional
            number of feature maps of the `_FC` projection head; defaults to the
            single element of `num_f_maps`
        loss_type, variance_lambda, normalize_indices, normalize_embeddings,
        similarity_type, num_cycles, cycle_length, temperature, label_smoothing
            passed positionally to `_TCCLoss` (`num_cycles` and `cycle_length` are
            cast to int first)

        Raises
        ------
        RuntimeError
            if `num_f_maps` does not contain exactly one element
        """
        super().__init__()
        # `num_f_maps` omits one dimension of the data, hence the `+ 1` in the message;
        # `!= 1` (not `> 1`) also rejects an empty size instead of failing on `[0]` below
        if len(num_f_maps) != 1:
            raise RuntimeError(
                "The TCC constructor expects the input data to be 2-dimensional; "
                f"got {len(num_f_maps) + 1} dimensions"
            )
        num_f_maps = int(num_f_maps[0])
        if projection_head_f_maps is None:
            projection_head_f_maps = num_f_maps
        self.len_segment = int(len_segment)
        self.loss_function = _TCCLoss(
            loss_type,
            variance_lambda,
            normalize_indices,
            normalize_embeddings,
            similarity_type,
            int(num_cycles),
            int(cycle_length),
            temperature,
            label_smoothing,
        )
        # keyword arguments for the `_FC` module built in `construct_module`
        self.pars = {
            "dim": num_f_maps,
            "num_f_maps": num_f_maps,
            "num_ssl_layers": 1,
            "num_ssl_f_maps": int(projection_head_f_maps),
        }

    def transformation(self, sample_data: Dict) -> Tuple:
        """
        Generate a placeholder SSL input and a mask of loaded frames as the SSL target.

        Parameters
        ----------
        sample_data : dict
            the input sample; every value is a tensor that is reduced over
            dimension 0, and the result must broadcast to ``(1, len_segment)``
            (presumably values are ``(features, len_segment)`` — TODO confirm)

        Returns
        -------
        ssl_input : torch.Tensor
            a scalar NaN tensor (no separate SSL input is used)
        ssl_target : dict
            ``{"loaded": mask}`` where ``mask`` is a float tensor of shape
            ``(1, len_segment)``: 1 for frames where some value is non-zero,
            0 for frames that are all-zero in every entry of `sample_data`
        """
        empty = torch.ones((1, self.len_segment))
        for value in sample_data.values():
            # a frame is empty only if all of its features are exactly zero; testing
            # the sum would misclassify frames whose non-zero features cancel out
            empty *= (value == 0).all(dim=0).unsqueeze(0)
        loaded = 1 - empty
        return torch.tensor(float("nan")), {"loaded": loaded}

    def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
        """
        Calculate the SSL loss.

        Parameters
        ----------
        predicted : torch.Tensor
            output of the SSL module
        target : torch.Tensor
            SSL target

        Returns
        -------
        loss : float
            the value returned by the `_TCCLoss` instance built in `__init__`
        """
        return self.loss_function(predicted, target)

    def construct_module(self) -> Union[nn.Module, None]:
        """
        Construct the SSL module.

        Returns
        -------
        ssl_module : torch.nn.Module
            an `_FC` module built from `self.pars`, or `nn.Identity` when
            `self.pars["num_ssl_f_maps"]` is None
        """
        # NOTE: `__init__` always stores an int under "num_ssl_f_maps", so the identity
        # branch is only reached if `pars` is modified after construction
        if self.pars["num_ssl_f_maps"] is None:
            return nn.Identity()
        return _FC(**self.pars)
An SSL class with a temporal cycle-consistency (TCC) loss
No separate SSL input is generated: the SSL input is a NaN placeholder, and the SSL target is a `{"loaded": mask}` dictionary marking which frames of the segment contain data.
TCCSSL( num_f_maps: torch.Size, len_segment: int, projection_head_f_maps: int = None, loss_type: str = 'regression_mse_var', variance_lambda: float = 0.001, normalize_indices: bool = True, normalize_embeddings: bool = False, similarity_type: str = 'l2', num_cycles: int = 20, cycle_length: int = 2, temperature: float = 0.1, label_smoothing: float = 0.1)
def __init__(
    self,
    num_f_maps: torch.Size,
    len_segment: int,
    projection_head_f_maps: Union[int, None] = None,
    loss_type: str = "regression_mse_var",
    variance_lambda: float = 0.001,
    normalize_indices: bool = True,
    normalize_embeddings: bool = False,
    similarity_type: str = "l2",
    num_cycles: int = 20,
    cycle_length: int = 2,
    temperature: float = 0.1,
    label_smoothing: float = 0.1,
) -> None:
    """
    Parameters
    ----------
    num_f_maps : torch.Size
        shape of the extracted features without one dimension (presumably the
        time dimension — TODO confirm); must contain exactly one element
    len_segment : int
        segment length; used as the size of the mask built in `transformation`
    projection_head_f_maps : int, optional
        number of feature maps of the `_FC` projection head; defaults to the
        single element of `num_f_maps`
    loss_type, variance_lambda, normalize_indices, normalize_embeddings,
    similarity_type, num_cycles, cycle_length, temperature, label_smoothing
        passed positionally to `_TCCLoss` (`num_cycles` and `cycle_length` are
        cast to int first)

    Raises
    ------
    RuntimeError
        if `num_f_maps` does not contain exactly one element
    """
    super().__init__()
    # `num_f_maps` omits one dimension of the data, hence the `+ 1` in the message;
    # `!= 1` (not `> 1`) also rejects an empty size instead of failing on `[0]` below
    if len(num_f_maps) != 1:
        raise RuntimeError(
            "The TCC constructor expects the input data to be 2-dimensional; "
            f"got {len(num_f_maps) + 1} dimensions"
        )
    num_f_maps = int(num_f_maps[0])
    if projection_head_f_maps is None:
        projection_head_f_maps = num_f_maps
    self.len_segment = int(len_segment)
    self.loss_function = _TCCLoss(
        loss_type,
        variance_lambda,
        normalize_indices,
        normalize_embeddings,
        similarity_type,
        int(num_cycles),
        int(cycle_length),
        temperature,
        label_smoothing,
    )
    # keyword arguments for the `_FC` module built in `construct_module`
    self.pars = {
        "dim": num_f_maps,
        "num_f_maps": num_f_maps,
        "num_ssl_layers": 1,
        "num_ssl_f_maps": int(projection_head_f_maps),
    }
type = 'ssl_target'
The `type` attribute defines how the SSL module interacts with the model:
'ssl_input': a modification of the input data passes through the base network feature extraction module and the SSL module; it is returned as SSL output and compared to SSL target (or, if it is None, to the input data),
'ssl_output': the input data passes through the base network feature extraction module and the SSL module; it is returned as SSL output and compared to SSL target (or, if it is None, to the input data),
'contrastive': the input data and its modification pass through the base network feature extraction module and the SSL module; an (input results, modification results) tuple is returned as SSL output,
'contrastive_2layers': the input data and its modification pass through the base network feature extraction module; the output of the second feature extraction layer for the modified data goes through an SSL module and then, optionally, that result and the first-level unmodified features pass another transformation; an (input results, modified results) tuple is returned as SSL output.
def
transformation(self, sample_data: Dict) -> Tuple:
def transformation(self, sample_data: Dict) -> Tuple:
    """
    Generate a placeholder SSL input and a mask of loaded frames as the SSL target.

    Parameters
    ----------
    sample_data : dict
        the input sample; every value is a tensor that is reduced over
        dimension 0, and the result must broadcast to ``(1, len_segment)``
        (presumably values are ``(features, len_segment)`` — TODO confirm)

    Returns
    -------
    ssl_input : torch.Tensor
        a scalar NaN tensor (no separate SSL input is used)
    ssl_target : dict
        ``{"loaded": mask}`` where ``mask`` is a float tensor of shape
        ``(1, len_segment)``: 1 for frames where some value is non-zero,
        0 for frames that are all-zero in every entry of `sample_data`
    """
    empty = torch.ones((1, self.len_segment))
    for value in sample_data.values():
        # a frame is empty only if all of its features are exactly zero; testing
        # the sum would misclassify frames whose non-zero features cancel out
        empty *= (value == 0).all(dim=0).unsqueeze(0)
    loaded = 1 - empty
    return torch.tensor(float("nan")), {"loaded": loaded}
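Return a NaN placeholder as the SSL input and, as the SSL target, a `{"loaded": mask}` dictionary marking the frames that contain data.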
def
loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
    """
    Calculate the SSL loss.

    Parameters
    ----------
    predicted : torch.Tensor
        output of the SSL module
    target : torch.Tensor
        SSL target

    Returns
    -------
    loss : float
        whatever `self.loss_function` returns for these arguments
    """
    criterion = self.loss_function
    return criterion(predicted, target)
Calculate the SSL loss
Parameters
predicted : torch.Tensor
    output of the SSL module
target : torch.Tensor
    augmented and stacked SSL target
Returns
loss : float the loss value
def
construct_module(self) -> Optional[torch.nn.modules.module.Module]:
def construct_module(self) -> Union[nn.Module, None]:
    """
    Construct the SSL module.

    Returns
    -------
    ssl_module : torch.nn.Module
        an `_FC` module built from `self.pars`, or `nn.Identity` when
        `self.pars["num_ssl_f_maps"]` is None
    """
    # NOTE: the constructor always stores an int under "num_ssl_f_maps", so the
    # identity branch is only reached if `pars` is modified after construction
    if self.pars["num_ssl_f_maps"] is not None:
        return _FC(**self.pars)
    return nn.Identity()
Construct the SSL module
Returns
ssl_module : torch.nn.Module a neural net module that takes features extracted by a model's feature extractor as input and returns SSL output