dlc2action.ssl.contrastive
Implementations of `dlc2action.ssl.base_ssl.SSLConstructor` of the `'contrastive'` type.
#
# Copyright 2020-present by A. Mathis Group and contributors. All rights reserved.
#
# This project and all its files are licensed under GNU AGPLv3 or later version.
# A copy is included in dlc2action/LICENSE.AGPL.
#
"""Implementations of `dlc2action.ssl.base_ssl.SSLConstructor` of the `'contrastive'` type."""

from typing import Dict, Tuple, Union
from dlc2action.ssl.base_ssl import SSLConstructor
from dlc2action.loss.contrastive import *
from dlc2action.loss.contrastive_frame import *
from dlc2action.ssl.modules import *
from copy import deepcopy


def _single_feature_dim(num_f_maps: torch.Size, constructor_name: str) -> int:
    """Return the only entry of `num_f_maps` as an integer.

    Parameters
    ----------
    num_f_maps : torch.Size
        shape of feature extractor output
    constructor_name : str
        the class name to report in the error message

    Returns
    -------
    int
        the first (and only) entry of `num_f_maps`

    Raises
    ------
    RuntimeError
        if `num_f_maps` has more than one entry

    """
    if len(num_f_maps) > 1:
        raise RuntimeError(
            f"The {constructor_name} constructor expects the input data to be 2-dimensional; "
            f"got {len(num_f_maps) + 1} dimensions"
        )
    return int(num_f_maps[0])


def _mask_bounds(len_segment: int, num_masked: int) -> Tuple[int, int]:
    """Return the `(start, end)` indices of the masked window in the middle of a segment.

    The window spans `2 * (num_masked // 2)` frames around `len_segment // 2`.
    The bounds are clamped to `[0, len_segment]`: without the clamp a window larger
    than the segment yields a negative start, which wraps around when used as a
    slice index and masks the wrong frames.

    """
    center = len_segment // 2
    half = num_masked // 2
    start = max(0, int(center - half))
    end = min(int(len_segment), int(center + half))
    return start, end


class ContrastiveSSL(SSLConstructor):
    """A contrastive SSL class with an NT-Xent loss.

    The SSL input and target are left empty (the SSL input is generated as an augmentation of the
    input sample at runtime).
    """

    type = "contrastive"

    def __init__(
        self,
        num_f_maps: torch.Size,
        len_segment: int,
        ssl_features: int = 128,
        tau: float = 1,
    ) -> None:
        """Initialize the SSL constructor.

        Parameters
        ----------
        num_f_maps : torch.Size
            shape of feature extractor output
        len_segment : int
            length of segment in the base feature extractor output
        ssl_features : int, default 128
            the final number of features per clip
        tau : float, default 1
            the tau parameter of NT-Xent loss

        Raises
        ------
        RuntimeError
            if `num_f_maps` has more than one entry

        """
        super().__init__()
        self.loss_function = NTXent(tau)
        # keyword arguments forwarded to FeatureExtractorTCN in construct_module()
        self.pars = {
            "num_f_maps": _single_feature_dim(num_f_maps, "ContrastiveSSL"),
            "len_segment": len_segment,
            "output_dim": ssl_features,
            "kernel_1": 5,
            "kernel_2": 5,
            "stride": 2,
            "decrease_f_maps": True,
        }

    def transformation(self, sample_data: Dict) -> Tuple:
        """Empty transformation.

        `sample_data` is ignored; two NaN scalar tensors are returned as the
        SSL input and SSL target placeholders.
        """
        return torch.tensor(float("nan")), torch.tensor(float("nan"))

    def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
        """NT-Xent loss.

        `predicted` is unpacked into two feature sets; `target` is not used.
        """
        features1, features2 = predicted
        return self.loss_function(features1, features2)

    def construct_module(self) -> Union[nn.Module, None]:
        """Clip-wise feature TCN extractor."""
        return FeatureExtractorTCN(**self.pars)


class ContrastiveMaskedSSL(SSLConstructor):
    """A contrastive masked SSL class with an NT-Xent loss.

    A few frames in the middle of each segment are masked and then the output of the second layer of
    feature extraction for the segment is used to predict the output of the first layer for the missing frames.
    The SSL input and target are left empty (the SSL input is generated as an augmentation of the
    input sample at runtime).
    """

    type = "contrastive_2layers"

    def __init__(
        self,
        num_f_maps: torch.Size,
        len_segment: int,
        ssl_features: int = 128,
        tau: float = 1,
        num_masked: int = 10,
    ) -> None:
        """Initialize the ContrastiveMaskedSSL class.

        Parameters
        ----------
        num_f_maps : torch.Size
            shape of feature extractor output
        len_segment : int
            length of segment in the base feature extractor output
        ssl_features : int, default 128
            the final number of features per clip
        tau : float, default 1
            the tau parameter of NT-Xent loss
        num_masked : int, default 10
            number of frames to be masked in the middle of each segment

        Raises
        ------
        RuntimeError
            if `num_f_maps` has more than one entry

        """
        super().__init__()
        self.start, self.end = _mask_bounds(len_segment, num_masked)
        self.loss_function = NTXent(tau)
        # keyword arguments forwarded to MFeatureExtractorTCN in construct_module()
        self.pars = {
            "num_f_maps": _single_feature_dim(num_f_maps, "ContrastiveMaskedSSL"),
            "len_segment": len_segment,
            "output_dim": ssl_features,
            "kernel_1": 3,
            "kernel_2": 3,
            "stride": 1,
            "start": self.start,
            "end": self.end,
        }

    def transformation(self, sample_data: Dict) -> Tuple:
        """Mask the input data.

        Returns a deep copy of `sample_data` with the `[start, end)` window along
        the second axis of every value set to zero, and a NaN scalar tensor as the
        SSL target placeholder. The input dictionary is not modified.
        """
        data = deepcopy(sample_data)
        for value in data.values():
            value[:, self.start : self.end] = 0
        return data, torch.tensor(float("nan"))

    def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
        """NT-Xent loss.

        `predicted` is unpacked into two feature sets; `target` is not used.
        """
        features, ssl_features = predicted
        return self.loss_function(features, ssl_features)

    def construct_module(self) -> Union[nn.Module, None]:
        """Clip-wise feature TCN extractor."""
        return MFeatureExtractorTCN(**self.pars)


class PairwiseSSL(SSLConstructor):
    """A pairwise SSL class with triplet or circle loss.

    The SSL input and target are left empty (the SSL input is generated as an augmentation of the
    input sample at runtime).
    """

    type = "contrastive"

    def __init__(
        self,
        num_f_maps: torch.Size,
        len_segment: int,
        ssl_features: int = 128,
        margin: float = 0,
        distance: str = "cosine",
        loss: str = "triplet",
        gamma: float = 1,
    ) -> None:
        """Initialize the PairwiseSSL class.

        Parameters
        ----------
        num_f_maps : torch.Size
            shape of feature extractor output
        len_segment : int
            length of segment in feature extractor output
        ssl_features : int, default 128
            final number of features per clip
        margin : float, default 0
            the margin parameter of triplet or circle loss
        distance : {'cosine', 'euclidean'}
            the distance calculation method for triplet or circle loss
        loss : {'triplet', 'circle'}
            the loss function name
        gamma : float, default 1
            the gamma parameter of circle loss

        Raises
        ------
        ValueError
            if `loss` is neither 'triplet' nor 'circle'
        RuntimeError
            if `num_f_maps` has more than one entry

        """
        super().__init__()
        if loss == "triplet":
            self.loss_function = TripletLoss(margin=margin, distance=distance)
        elif loss == "circle":
            self.loss_function = CircleLoss(
                margin=margin, gamma=gamma, distance=distance
            )
        else:
            raise ValueError(
                f'The {loss} loss is unavailable, please choose from "triplet" and "circle"'
            )
        # keyword arguments forwarded to FeatureExtractorTCN in construct_module()
        self.pars = {
            "num_f_maps": _single_feature_dim(num_f_maps, "PairwiseSSL"),
            "len_segment": len_segment,
            "output_dim": ssl_features,
            "kernel_1": 5,
            "kernel_2": 5,
            "stride": 2,
            "decrease_f_maps": True,
        }

    def transformation(self, sample_data: Dict) -> Tuple:
        """Empty transformation.

        `sample_data` is ignored; two NaN scalar tensors are returned as the
        SSL input and SSL target placeholders.
        """
        return torch.tensor(float("nan")), torch.tensor(float("nan"))

    def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
        """Triplet or circle loss.

        `predicted` is unpacked into two feature sets; `target` is not used.
        """
        features1, features2 = predicted
        return self.loss_function(features1, features2)

    def construct_module(self) -> Union[nn.Module, None]:
        """Clip-wise feature TCN extractor."""
        return FeatureExtractorTCN(**self.pars)


class PairwiseMaskedSSL(PairwiseSSL):
    """A contrastive SSL class with triplet or circle loss and masked input.

    A few frames in the middle of each segment are masked and then the output of the second layer of
    feature extraction for the segment is used to predict the output of the first layer for the missing frames.
    The SSL input and target are left empty (the SSL input is generated as an augmentation of the
    input sample at runtime).
    """

    type = "contrastive_2layers"

    def __init__(
        self,
        num_f_maps: torch.Size,
        len_segment: int,
        ssl_features: int = 128,
        margin: float = 0,
        distance: str = "cosine",
        loss: str = "triplet",
        gamma: float = 1,
        num_masked: int = 10,
    ) -> None:
        """Initialize the PairwiseMaskedSSL class.

        Parameters
        ----------
        num_f_maps : torch.Size
            shape of feature extractor output
        len_segment : int
            length of segment in feature extractor output
        ssl_features : int, default 128
            final number of features per clip
        margin : float, default 0
            the margin parameter of triplet or circle loss
        distance : {'cosine', 'euclidean'}
            the distance calculation method for triplet or circle loss
        loss : {'triplet', 'circle'}
            the loss function name
        gamma : float, default 1
            the gamma parameter of circle loss
        num_masked : int, default 10
            number of masked frames

        Raises
        ------
        RuntimeError
            if `num_f_maps` has more than one entry
        ValueError
            if `loss` is neither 'triplet' nor 'circle'

        """
        # Validate before delegating: PairwiseSSL.__init__ runs the same check, so
        # doing it afterwards would never report this class' name.
        flat_f_maps = _single_feature_dim(num_f_maps, "PairwiseMaskedSSL")
        super().__init__(
            num_f_maps, len_segment, ssl_features, margin, distance, loss, gamma
        )
        self.num_masked = num_masked
        self.start, self.end = _mask_bounds(len_segment, num_masked)
        # replaces the parameters set by PairwiseSSL with the ones for MFeatureExtractorTCN
        self.pars = {
            "num_f_maps": flat_f_maps,
            "len_segment": len_segment,
            "output_dim": ssl_features,
            "kernel_1": 3,
            "kernel_2": 3,
            "stride": 1,
            "start": self.start,
            "end": self.end,
        }

    def transformation(self, sample_data: Dict) -> Tuple:
        """Mask the input data.

        Returns a deep copy of `sample_data` with the `[start, end)` window along
        the second axis of every value set to zero, and a NaN scalar tensor as the
        SSL target placeholder. The input dictionary is not modified.
        """
        data = deepcopy(sample_data)
        for value in data.values():
            value[:, self.start : self.end] = 0
        return data, torch.tensor(float("nan"))

    def construct_module(self) -> Union[nn.Module, None]:
        """Clip-wise feature TCN extractor."""
        return MFeatureExtractorTCN(**self.pars)


class ContrastiveRegressionSSL(SSLConstructor):
    """Contrastive SSL class with regression loss."""

    type = "contrastive"

    def __init__(
        self,
        num_f_maps: torch.Size,
        num_features: int = 128,
        num_ssl_layers: int = 1,
        distance: str = "cosine",
        temperature: float = 1,
        break_factor: Union[int, None] = None,
    ) -> None:
        """Initialize the ContrastiveRegressionSSL class.

        Parameters
        ----------
        num_f_maps : torch.Size
            shape of feature extractor output
        num_features : int, default 128
            final number of features per clip
        num_ssl_layers : int, default 1
            number of SSL layers
        distance : {'cosine', 'euclidean'}
            the distance calculation method of the contrastive regression loss
        temperature : float, default 1
            the temperature parameter of contrastive loss
        break_factor : int, optional
            the break factor parameter of contrastive loss

        Raises
        ------
        RuntimeError
            if `num_f_maps` has more than one entry

        """
        self.loss_function = ContrastiveRegressionLoss(
            temperature, distance, break_factor
        )
        # keyword arguments forwarded to FC in construct_module()
        self.pars = {
            "num_f_maps": _single_feature_dim(num_f_maps, "ContrastiveRegressionSSL"),
            "num_ssl_layers": num_ssl_layers,
            "num_ssl_f_maps": num_features,
            "dim": num_features,
            "ssl_input": False,
        }
        # kept after the attribute assignments, as in the original implementation
        super().__init__()

    def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
        """Contrastive regression loss.

        `predicted` is unpacked into two feature sets; `target` is not used.
        """
        features1, features2 = predicted
        return self.loss_function(features1, features2)

    def transformation(self, sample_data: Dict) -> Tuple:
        """Empty transformation.

        `sample_data` is ignored; two NaN scalar tensors are returned as the
        SSL input and SSL target placeholders.
        """
        return torch.tensor(float("nan")), torch.tensor(float("nan"))

    def construct_module(self) -> Union[nn.Module, None]:
        """Construct the `FC` SSL module from the stored parameters."""
        return FC(**self.pars)
class ContrastiveSSL(SSLConstructor):
    """Contrastive SSL constructor that uses an NT-Xent loss.

    No SSL input or target is produced here (the SSL input is generated as an
    augmentation of the input sample at runtime).
    """

    type = "contrastive"

    def __init__(
        self,
        num_f_maps: torch.Size,
        len_segment: int,
        ssl_features: int = 128,
        tau: float = 1,
    ) -> None:
        """Set up the loss function and the SSL module parameters.

        Parameters
        ----------
        num_f_maps : torch.Size
            shape of feature extractor output
        len_segment : int
            length of segment in the base feature extractor output
        ssl_features : int, default 128
            the final number of features per clip
        tau : float, default 1
            the tau parameter of NT-Xent loss

        Raises
        ------
        RuntimeError
            if `num_f_maps` has more than one entry

        """
        super().__init__()
        self.loss_function = NTXent(tau)
        n_dims = len(num_f_maps)
        if n_dims > 1:
            raise RuntimeError(
                "The ContrastiveSSL constructor expects the input data to be 2-dimensional; "
                f"got {n_dims + 1} dimensions"
            )
        # keyword arguments for FeatureExtractorTCN (see construct_module)
        self.pars = dict(
            num_f_maps=int(num_f_maps[0]),
            len_segment=len_segment,
            output_dim=ssl_features,
            kernel_1=5,
            kernel_2=5,
            stride=2,
            decrease_f_maps=True,
        )

    def transformation(self, sample_data: Dict) -> Tuple:
        """Return NaN placeholders for the SSL input and target (the sample is ignored)."""
        nan = float("nan")
        return torch.tensor(nan), torch.tensor(nan)

    def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
        """Compute the NT-Xent loss on the two items of `predicted`; `target` is unused."""
        first_view, second_view = predicted
        return self.loss_function(first_view, second_view)

    def construct_module(self) -> Union[nn.Module, None]:
        """Build the clip-wise feature TCN extractor from the stored parameters."""
        return FeatureExtractorTCN(**self.pars)
A contrastive SSL class with an NT-Xent loss.
The SSL input and target are left empty (the SSL input is generated as an augmentation of the input sample at runtime).
def __init__(
    self,
    num_f_maps: torch.Size,
    len_segment: int,
    ssl_features: int = 128,
    tau: float = 1,
) -> None:
    """Set up the loss function and the SSL module parameters.

    Parameters
    ----------
    num_f_maps : torch.Size
        shape of feature extractor output
    len_segment : int
        length of segment in the base feature extractor output
    ssl_features : int, default 128
        the final number of features per clip
    tau : float, default 1
        the tau parameter of NT-Xent loss

    Raises
    ------
    RuntimeError
        if `num_f_maps` has more than one entry

    """
    super().__init__()
    self.loss_function = NTXent(tau)
    n_dims = len(num_f_maps)
    if n_dims > 1:
        raise RuntimeError(
            "The ContrastiveSSL constructor expects the input data to be 2-dimensional; "
            f"got {n_dims + 1} dimensions"
        )
    # keyword arguments later unpacked into the SSL module constructor
    self.pars = dict(
        num_f_maps=int(num_f_maps[0]),
        len_segment=len_segment,
        output_dim=ssl_features,
        kernel_1=5,
        kernel_2=5,
        stride=2,
        decrease_f_maps=True,
    )
Initialize the SSL constructor.
Parameters
num_f_maps : torch.Size — shape of feature extractor output
len_segment : int — length of segment in the base feature extractor output
ssl_features : int, default 128 — the final number of features per clip
tau : float, default 1 — the tau parameter of NT-Xent loss
The `type` parameter defines interaction with the model:
- 'ssl_input': a modification of the input data passes through the base network feature extraction module and the SSL module; it is returned as SSL output and compared to SSL target (or, if it is None, to the input data),
- 'ssl_output': the input data passes through the base network feature extraction module and the SSL module; it is returned as SSL output and compared to SSL target (or, if it is None, to the input data),
- 'contrastive': the input data and its modification pass through the base network feature extraction module and the SSL module; an (input results, modification results) tuple is returned as SSL output,
- 'contrastive_2layers': the input data and its modification pass through the base network feature extraction module; the output of the second feature extraction layer for the modified data goes through an SSL module and then, optionally, that result and the first-level unmodified features pass another transformation; an (input results, modified results) tuple is returned as SSL output.
def transformation(self, sample_data: Dict) -> Tuple:
    """Return NaN placeholders for the SSL input and target (the sample is ignored)."""
    nan = float("nan")
    ssl_input = torch.tensor(nan)
    ssl_target = torch.tensor(nan)
    return ssl_input, ssl_target
Empty transformation.
class ContrastiveMaskedSSL(SSLConstructor):
    """A contrastive masked SSL class with an NT-Xent loss.

    A few frames in the middle of each segment are masked and then the output of the second layer of
    feature extraction for the segment is used to predict the output of the first layer for the missing frames.
    The SSL input and target are left empty (the SSL input is generated as an augmentation of the
    input sample at runtime).
    """

    type = "contrastive_2layers"

    def __init__(
        self,
        num_f_maps: torch.Size,
        len_segment: int,
        ssl_features: int = 128,
        tau: float = 1,
        num_masked: int = 10,
    ) -> None:
        """Initialize the ContrastiveMaskedSSL class.

        Parameters
        ----------
        num_f_maps : torch.Size
            shape of feature extractor output
        len_segment : int
            length of segment in the base feature extractor output
        ssl_features : int, default 128
            the final number of features per clip
        tau : float, default 1
            the tau parameter of NT-Xent loss
        num_masked : int, default 10
            number of frames to be masked in the middle of each segment
            (the window spans `2 * (num_masked // 2)` frames)

        Raises
        ------
        RuntimeError
            if `num_f_maps` has more than one entry

        """
        super().__init__()
        center = len_segment // 2
        half = num_masked // 2
        # Clamp to [0, len_segment]: a negative start would wrap around when
        # used as a slice index and mask the wrong frames.
        self.start = max(0, int(center - half))
        self.end = min(int(len_segment), int(center + half))
        self.loss_function = NTXent(tau)
        if len(num_f_maps) > 1:
            raise RuntimeError(
                "The ContrastiveMaskedSSL constructor expects the input data to be 2-dimensional; "
                f"got {len(num_f_maps) + 1} dimensions"
            )
        num_f_maps = int(num_f_maps[0])
        # keyword arguments forwarded to MFeatureExtractorTCN in construct_module()
        self.pars = {
            "num_f_maps": num_f_maps,
            "len_segment": len_segment,
            "output_dim": ssl_features,
            "kernel_1": 3,
            "kernel_2": 3,
            "stride": 1,
            "start": self.start,
            "end": self.end,
        }

    def transformation(self, sample_data: Dict) -> Tuple:
        """Mask the input data.

        Returns a deep copy of `sample_data` with the `[start, end)` window along
        the second axis of every value set to zero, and a NaN scalar tensor as the
        SSL target placeholder. The input dictionary is not modified.
        """
        data = deepcopy(sample_data)
        for value in data.values():
            value[:, self.start : self.end] = 0
        return data, torch.tensor(float("nan"))

    def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
        """NT-Xent loss.

        `predicted` is unpacked into two feature sets; `target` is not used.
        """
        features, ssl_features = predicted
        return self.loss_function(features, ssl_features)

    def construct_module(self) -> Union[nn.Module, None]:
        """Clip-wise feature TCN extractor."""
        return MFeatureExtractorTCN(**self.pars)
A contrastive masked SSL class with an NT-Xent loss.
A few frames in the middle of each segment are masked and then the output of the second layer of feature extraction for the segment is used to predict the output of the first layer for the missing frames. The SSL input and target are left empty (the SSL input is generated as an augmentation of the input sample at runtime).
def __init__(
    self,
    num_f_maps: torch.Size,
    len_segment: int,
    ssl_features: int = 128,
    tau: float = 1,
    num_masked: int = 10,
) -> None:
    """Initialize the ContrastiveMaskedSSL class.

    Parameters
    ----------
    num_f_maps : torch.Size
        shape of feature extractor output
    len_segment : int
        length of segment in the base feature extractor output
    ssl_features : int, default 128
        the final number of features per clip
    tau : float, default 1
        the tau parameter of NT-Xent loss
    num_masked : int, default 10
        number of frames to be masked in the middle of each segment
        (the window spans `2 * (num_masked // 2)` frames)

    Raises
    ------
    RuntimeError
        if `num_f_maps` has more than one entry

    """
    super().__init__()
    center = len_segment // 2
    half = num_masked // 2
    # Clamp to [0, len_segment]: a negative start would wrap around when
    # used as a slice index and mask the wrong frames.
    self.start = max(0, int(center - half))
    self.end = min(int(len_segment), int(center + half))
    self.loss_function = NTXent(tau)
    if len(num_f_maps) > 1:
        raise RuntimeError(
            "The ContrastiveMaskedSSL constructor expects the input data to be 2-dimensional; "
            f"got {len(num_f_maps) + 1} dimensions"
        )
    num_f_maps = int(num_f_maps[0])
    # keyword arguments later unpacked into the SSL module constructor
    self.pars = {
        "num_f_maps": num_f_maps,
        "len_segment": len_segment,
        "output_dim": ssl_features,
        "kernel_1": 3,
        "kernel_2": 3,
        "stride": 1,
        "start": self.start,
        "end": self.end,
    }
Initialize the ContrastiveMaskedSSL class.
Parameters
num_f_maps : torch.Size — shape of feature extractor output
len_segment : int — length of segment in the base feature extractor output
ssl_features : int, default 128 — the final number of features per clip
tau : float, default 1 — the tau parameter of NT-Xent loss
num_masked : int, default 10 — number of frames to be masked in the middle of each segment
The `type` parameter defines interaction with the model:
- 'ssl_input': a modification of the input data passes through the base network feature extraction module and the SSL module; it is returned as SSL output and compared to SSL target (or, if it is None, to the input data),
- 'ssl_output': the input data passes through the base network feature extraction module and the SSL module; it is returned as SSL output and compared to SSL target (or, if it is None, to the input data),
- 'contrastive': the input data and its modification pass through the base network feature extraction module and the SSL module; an (input results, modification results) tuple is returned as SSL output,
- 'contrastive_2layers': the input data and its modification pass through the base network feature extraction module; the output of the second feature extraction layer for the modified data goes through an SSL module and then, optionally, that result and the first-level unmodified features pass another transformation; an (input results, modified results) tuple is returned as SSL output.
def transformation(self, sample_data: Dict) -> Tuple:
    """Mask the input data.

    Returns a deep copy of `sample_data` with the `[self.start, self.end)`
    window along the second axis of every value set to zero, together with a
    NaN scalar tensor as the SSL target placeholder. The input dictionary is
    not modified.
    """
    data = deepcopy(sample_data)
    # assumes every value supports 2-D slice assignment (e.g. a tensor) - TODO confirm
    for value in data.values():
        value[:, self.start : self.end] = 0
    return data, torch.tensor(float("nan"))
Mask the input data.
class PairwiseSSL(SSLConstructor):
    """Pairwise SSL constructor that uses a triplet or a circle loss.

    No SSL input or target is produced here (the SSL input is generated as an
    augmentation of the input sample at runtime).
    """

    type = "contrastive"

    def __init__(
        self,
        num_f_maps: torch.Size,
        len_segment: int,
        ssl_features: int = 128,
        margin: float = 0,
        distance: str = "cosine",
        loss: str = "triplet",
        gamma: float = 1,
    ) -> None:
        """Select the loss function and store the SSL module parameters.

        Parameters
        ----------
        num_f_maps : torch.Size
            shape of feature extractor output
        len_segment : int
            length of segment in feature extractor output
        ssl_features : int, default 128
            final number of features per clip
        margin : float, default 0
            the margin parameter of triplet or circle loss
        distance : {'cosine', 'euclidean'}
            the distance calculation method for triplet or circle loss
        loss : {'triplet', 'circle'}
            the loss function name
        gamma : float, default 1
            the gamma parameter of circle loss

        Raises
        ------
        ValueError
            if `loss` is neither 'triplet' nor 'circle'
        RuntimeError
            if `num_f_maps` has more than one entry

        """
        super().__init__()
        is_triplet = loss == "triplet"
        if not is_triplet and not loss == "circle":
            raise ValueError(
                f'The {loss} loss is unavailable, please choose from "triplet" and "circle"'
            )
        if is_triplet:
            self.loss_function = TripletLoss(margin=margin, distance=distance)
        else:
            self.loss_function = CircleLoss(
                margin=margin, gamma=gamma, distance=distance
            )
        n_dims = len(num_f_maps)
        if n_dims > 1:
            raise RuntimeError(
                "The PairwiseSSL constructor expects the input data to be 2-dimensional; "
                f"got {n_dims + 1} dimensions"
            )
        # keyword arguments for FeatureExtractorTCN (see construct_module)
        self.pars = dict(
            num_f_maps=int(num_f_maps[0]),
            len_segment=len_segment,
            output_dim=ssl_features,
            kernel_1=5,
            kernel_2=5,
            stride=2,
            decrease_f_maps=True,
        )

    def transformation(self, sample_data: Dict) -> Tuple:
        """Return NaN placeholders for the SSL input and target (the sample is ignored)."""
        nan = float("nan")
        return torch.tensor(nan), torch.tensor(nan)

    def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
        """Compute the triplet or circle loss on the two items of `predicted`; `target` is unused."""
        first_view, second_view = predicted
        return self.loss_function(first_view, second_view)

    def construct_module(self) -> Union[nn.Module, None]:
        """Build the clip-wise feature TCN extractor from the stored parameters."""
        return FeatureExtractorTCN(**self.pars)
A pairwise SSL class with triplet or circle loss.
The SSL input and target are left empty (the SSL input is generated as an augmentation of the input sample at runtime).
def __init__(
    self,
    num_f_maps: torch.Size,
    len_segment: int,
    ssl_features: int = 128,
    margin: float = 0,
    distance: str = "cosine",
    loss: str = "triplet",
    gamma: float = 1,
) -> None:
    """Select the loss function and store the SSL module parameters.

    Parameters
    ----------
    num_f_maps : torch.Size
        shape of feature extractor output
    len_segment : int
        length of segment in feature extractor output
    ssl_features : int, default 128
        final number of features per clip
    margin : float, default 0
        the margin parameter of triplet or circle loss
    distance : {'cosine', 'euclidean'}
        the distance calculation method for triplet or circle loss
    loss : {'triplet', 'circle'}
        the loss function name
    gamma : float, default 1
        the gamma parameter of circle loss

    Raises
    ------
    ValueError
        if `loss` is neither 'triplet' nor 'circle'
    RuntimeError
        if `num_f_maps` has more than one entry

    """
    super().__init__()
    is_triplet = loss == "triplet"
    if not is_triplet and not loss == "circle":
        raise ValueError(
            f'The {loss} loss is unavailable, please choose from "triplet" and "circle"'
        )
    if is_triplet:
        self.loss_function = TripletLoss(margin=margin, distance=distance)
    else:
        self.loss_function = CircleLoss(
            margin=margin, gamma=gamma, distance=distance
        )
    n_dims = len(num_f_maps)
    if n_dims > 1:
        raise RuntimeError(
            "The PairwiseSSL constructor expects the input data to be 2-dimensional; "
            f"got {n_dims + 1} dimensions"
        )
    # keyword arguments later unpacked into the SSL module constructor
    self.pars = dict(
        num_f_maps=int(num_f_maps[0]),
        len_segment=len_segment,
        output_dim=ssl_features,
        kernel_1=5,
        kernel_2=5,
        stride=2,
        decrease_f_maps=True,
    )
Initialize the PairwiseSSL class.
Parameters
num_f_maps : torch.Size — shape of feature extractor output
len_segment : int — length of segment in feature extractor output
ssl_features : int, default 128 — final number of features per clip
margin : float, default 0 — the margin parameter of triplet or circle loss
distance : {'cosine', 'euclidean'} — the distance calculation method for triplet or circle loss
loss : {'triplet', 'circle'} — the loss function name
gamma : float, default 1 — the gamma parameter of circle loss
The `type` parameter defines interaction with the model:
- 'ssl_input': a modification of the input data passes through the base network feature extraction module and the SSL module; it is returned as SSL output and compared to SSL target (or, if it is None, to the input data),
- 'ssl_output': the input data passes through the base network feature extraction module and the SSL module; it is returned as SSL output and compared to SSL target (or, if it is None, to the input data),
- 'contrastive': the input data and its modification pass through the base network feature extraction module and the SSL module; an (input results, modification results) tuple is returned as SSL output,
- 'contrastive_2layers': the input data and its modification pass through the base network feature extraction module; the output of the second feature extraction layer for the modified data goes through an SSL module and then, optionally, that result and the first-level unmodified features pass another transformation; an (input results, modified results) tuple is returned as SSL output.
def transformation(self, sample_data: Dict) -> Tuple:
    """Return NaN placeholders for the SSL input and target (the sample is ignored)."""
    nan = float("nan")
    ssl_input = torch.tensor(nan)
    ssl_target = torch.tensor(nan)
    return ssl_input, ssl_target
Empty transformation.
def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
    """Compute the triplet or circle loss.

    `predicted` must unpack into exactly two items, which are passed to
    `self.loss_function`; `target` is not used.
    """
    first_view, second_view = predicted
    return self.loss_function(first_view, second_view)
Triplet or circle loss.
class PairwiseMaskedSSL(PairwiseSSL):
    """A contrastive SSL class with triplet or circle loss and masked input.

    A few frames in the middle of each segment are masked and then the output of the second layer of
    feature extraction for the segment is used to predict the output of the first layer for the missing frames.
    The SSL input and target are left empty (the SSL input is generated as an augmentation of the
    input sample at runtime).
    """

    type = "contrastive_2layers"

    def __init__(
        self,
        num_f_maps: torch.Size,
        len_segment: int,
        ssl_features: int = 128,
        margin: float = 0,
        distance: str = "cosine",
        loss: str = "triplet",
        gamma: float = 1,
        num_masked: int = 10,
    ) -> None:
        """Initialize the PairwiseMaskedSSL class.

        Parameters
        ----------
        num_f_maps : torch.Size
            shape of feature extractor output
        len_segment : int
            length of segment in feature extractor output
        ssl_features : int, default 128
            final number of features per clip
        margin : float, default 0
            the margin parameter of triplet or circle loss
        distance : {'cosine', 'euclidean'}
            the distance calculation method for triplet or circle loss
        loss : {'triplet', 'circle'}
            the loss function name
        gamma : float, default 1
            the gamma parameter of circle loss
        num_masked : int, default 10
            number of masked frames (the window spans `2 * (num_masked // 2)` frames)

        Raises
        ------
        RuntimeError
            if `num_f_maps` has more than one entry
        ValueError
            if `loss` is neither 'triplet' nor 'circle'

        """
        # Validate before delegating: PairwiseSSL.__init__ runs the same check, so
        # doing it afterwards would never report this class' name.
        if len(num_f_maps) > 1:
            raise RuntimeError(
                "The PairwiseMaskedSSL constructor expects the input data to be 2-dimensional; "
                f"got {len(num_f_maps) + 1} dimensions"
            )
        super().__init__(
            num_f_maps, len_segment, ssl_features, margin, distance, loss, gamma
        )
        self.num_masked = num_masked
        center = len_segment // 2
        half = num_masked // 2
        # Clamp to [0, len_segment]: a negative start would wrap around when
        # used as a slice index and mask the wrong frames.
        self.start = max(0, int(center - half))
        self.end = min(int(len_segment), int(center + half))
        num_f_maps = int(num_f_maps[0])
        # replaces the parameters set by PairwiseSSL with the ones for MFeatureExtractorTCN
        self.pars = {
            "num_f_maps": num_f_maps,
            "len_segment": len_segment,
            "output_dim": ssl_features,
            "kernel_1": 3,
            "kernel_2": 3,
            "stride": 1,
            "start": self.start,
            "end": self.end,
        }

    def transformation(self, sample_data: Dict) -> Tuple:
        """Mask the input data.

        Returns a deep copy of `sample_data` with the `[start, end)` window along
        the second axis of every value set to zero, and a NaN scalar tensor as the
        SSL target placeholder. The input dictionary is not modified.
        """
        data = deepcopy(sample_data)
        for value in data.values():
            value[:, self.start : self.end] = 0
        return data, torch.tensor(float("nan"))

    def construct_module(self) -> Union[nn.Module, None]:
        """Clip-wise feature TCN extractor."""
        return MFeatureExtractorTCN(**self.pars)
A contrastive SSL class with triplet or circle loss and masked input.
A few frames in the middle of each segment are masked and then the output of the second layer of feature extraction for the segment is used to predict the output of the first layer for the missing frames. The SSL input and target are left empty (the SSL input is generated as an augmentation of the input sample at runtime).
def __init__(
    self,
    num_f_maps: torch.Size,
    len_segment: int,
    ssl_features: int = 128,
    margin: float = 0,
    distance: str = "cosine",
    loss: str = "triplet",
    gamma: float = 1,
    num_masked: int = 10,
) -> None:
    """Initialize the PairwiseMaskedSSL class.

    Parameters
    ----------
    num_f_maps : torch.Size
        shape of feature extractor output
    len_segment : int
        length of segment in feature extractor output
    ssl_features : int, default 128
        final number of features per clip
    margin : float, default 0
        the margin parameter of triplet or circle loss
    distance : {'cosine', 'euclidean'}
        the distance calculation method for triplet or circle loss
    loss : {'triplet', 'circle'}
        the loss function name
    gamma : float, default 1
        the gamma parameter of circle loss
    num_masked : int, default 10
        number of masked frames (the window spans `2 * (num_masked // 2)` frames)

    Raises
    ------
    RuntimeError
        if `num_f_maps` has more than one entry

    """
    # Validate before delegating: the parent constructor runs the same check, so
    # doing it afterwards would never report this class' name.
    if len(num_f_maps) > 1:
        raise RuntimeError(
            "The PairwiseMaskedSSL constructor expects the input data to be 2-dimensional; "
            f"got {len(num_f_maps) + 1} dimensions"
        )
    super().__init__(
        num_f_maps, len_segment, ssl_features, margin, distance, loss, gamma
    )
    self.num_masked = num_masked
    center = len_segment // 2
    half = num_masked // 2
    # Clamp to [0, len_segment]: a negative start would wrap around when
    # used as a slice index and mask the wrong frames.
    self.start = max(0, int(center - half))
    self.end = min(int(len_segment), int(center + half))
    num_f_maps = int(num_f_maps[0])
    # replaces the parameters set by the parent constructor
    self.pars = {
        "num_f_maps": num_f_maps,
        "len_segment": len_segment,
        "output_dim": ssl_features,
        "kernel_1": 3,
        "kernel_2": 3,
        "stride": 1,
        "start": self.start,
        "end": self.end,
    }
Initialize the PairwiseMaskedSSL class.
Parameters
num_f_maps : torch.Size shape of feature extractor output len_segment : int length of segment in feature extractor output ssl_features : int, default 128 final number of features per clip margin : float, default 0 the margin parameter of triplet or circle loss distance : {'cosine', 'euclidean'} the distance calculation method for triplet or circle loss loss : {'triplet', 'circle'} the loss function name gamma : float, default 1 the gamma parameter of circle loss num_masked : int, default 10 number of masked frames
The type parameter defines interaction with the model:
- 'ssl_input': a modification of the input data passes through the base network feature extraction module and the SSL module; it is returned as SSL output and compared to SSL target (or, if it is None, to the input data),
- 'ssl_output': the input data passes through the base network feature extraction module and the SSL module; it is returned as SSL output and compared to SSL target (or, if it is None, to the input data),
- 'contrastive': the input data and its modification pass through the base network feature extraction module and the SSL module; an (input results, modification results) tuple is returned as SSL output,
- 'contrastive_2layers': the input data and its modification pass through the base network feature extraction module; the output of the second feature extraction layer for the modified data goes through an SSL module and then, optionally, that result and the first-level unmodified features pass another transformation; an (input results, modified results) tuple is returned as SSL output.
def transformation(self, sample_data: Dict) -> Tuple:
    """Return a masked copy of the sample as the SSL input.

    Every value of `sample_data` is deep-copied and its entries
    `self.start:self.end` along the second axis are set to zero; the
    original sample is left untouched. The SSL target is a NaN scalar
    tensor.
    """
    masked = deepcopy(sample_data)
    window = slice(self.start, self.end)
    for features in masked.values():
        features[:, window] = 0
    placeholder_target = torch.tensor(float("nan"))
    return masked, placeholder_target
Mask the input data.
def construct_module(self) -> Union[nn.Module, None]:
    """Build the SSL module: an `MFeatureExtractorTCN` configured from `self.pars`."""
    return MFeatureExtractorTCN(**self.pars)
Clip-wise feature TCN extractor.
Inherited Members
class ContrastiveRegressionSSL(SSLConstructor):
    """Contrastive SSL constructor that uses a contrastive regression loss.

    The SSL input and target are NaN placeholders; the loss is computed on
    the pair of feature sets returned as the SSL output.
    """

    type = "contrastive"

    def __init__(
        self,
        num_f_maps: torch.Size,
        num_features: int = 128,
        num_ssl_layers: int = 1,
        distance: str = "cosine",
        temperature: float = 1,
        break_factor: int = None,
    ) -> None:
        """Initialize the ContrastiveRegressionSSL class.

        Parameters
        ----------
        num_f_maps : torch.Size
            shape of feature extractor output (must have exactly one entry)
        num_features : int, default 128
            final number of features per clip (used both as the number of
            SSL feature maps and as the output dimension of the module)
        num_ssl_layers : int, default 1
            number of SSL layers
        distance : {'cosine', 'euclidean'}
            the distance calculation method passed to the contrastive regression loss
        temperature : float, default 1
            the temperature parameter of contrastive loss
        break_factor : int, default None
            the break factor parameter of contrastive loss

        Raises
        ------
        RuntimeError
            if `num_f_maps` has more than one entry

        """
        n_dims = len(num_f_maps)
        if n_dims > 1:
            raise RuntimeError(
                "The ContrastiveRegressionSSL constructor expects the input data to be 2-dimensional; "
                f"got {n_dims + 1} dimensions"
            )
        in_f_maps = int(num_f_maps[0])
        self.loss_function = ContrastiveRegressionLoss(
            temperature, distance, break_factor
        )
        # Keyword arguments later unpacked into the `FC` module constructor.
        self.pars = dict(
            num_f_maps=in_f_maps,
            num_ssl_layers=num_ssl_layers,
            num_ssl_f_maps=num_features,
            dim=num_features,
            ssl_input=False,
        )
        super().__init__()

    def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
        """Apply the contrastive regression loss to the two predicted feature sets.

        `predicted` is unpacked into a pair of feature sets; `target` is ignored.
        """
        first_view, second_view = predicted
        return self.loss_function(first_view, second_view)

    def transformation(self, sample_data: Dict) -> Tuple:
        """Return NaN placeholders for the SSL input and the SSL target."""
        nan = float("nan")
        return torch.tensor(nan), torch.tensor(nan)

    def construct_module(self) -> Union[nn.Module, None]:
        """Build the SSL module: an `FC` instance configured from `self.pars`."""
        ssl_module = FC(**self.pars)
        return ssl_module
Contrastive SSL class with regression loss.
def __init__(
    self,
    num_f_maps: torch.Size,
    num_features: int = 128,
    num_ssl_layers: int = 1,
    distance: str = "cosine",
    temperature: float = 1,
    break_factor: int = None,
) -> None:
    """Initialize the ContrastiveRegressionSSL class.

    Parameters
    ----------
    num_f_maps : torch.Size
        shape of feature extractor output (must have exactly one entry)
    num_features : int, default 128
        final number of features per clip (used both as the number of
        SSL feature maps and as the output dimension of the module)
    num_ssl_layers : int, default 1
        number of SSL layers
    distance : {'cosine', 'euclidean'}
        the distance calculation method passed to the contrastive regression loss
    temperature : float, default 1
        the temperature parameter of contrastive loss
    break_factor : int, default None
        the break factor parameter of contrastive loss

    Raises
    ------
    RuntimeError
        if `num_f_maps` has more than one entry

    """
    n_dims = len(num_f_maps)
    if n_dims > 1:
        raise RuntimeError(
            "The ContrastiveRegressionSSL constructor expects the input data to be 2-dimensional; "
            f"got {n_dims + 1} dimensions"
        )
    in_f_maps = int(num_f_maps[0])
    self.loss_function = ContrastiveRegressionLoss(
        temperature, distance, break_factor
    )
    # Keyword arguments later unpacked into the SSL module constructor.
    self.pars = dict(
        num_f_maps=in_f_maps,
        num_ssl_layers=num_ssl_layers,
        num_ssl_f_maps=num_features,
        dim=num_features,
        ssl_input=False,
    )
    super().__init__()
Initialize the ContrastiveRegressionSSL class.
Parameters
num_f_maps : torch.Size shape of feature extractor output num_features : int, default 128 final number of features per clip num_ssl_layers : int, default 1 number of SSL layers distance : {'cosine', 'euclidean'} the distance calculation method for the contrastive regression loss temperature : float, default 1 the temperature parameter of contrastive loss break_factor : int, default None the break factor parameter of contrastive loss
The type parameter defines interaction with the model:
- 'ssl_input': a modification of the input data passes through the base network feature extraction module and the SSL module; it is returned as SSL output and compared to SSL target (or, if it is None, to the input data),
- 'ssl_output': the input data passes through the base network feature extraction module and the SSL module; it is returned as SSL output and compared to SSL target (or, if it is None, to the input data),
- 'contrastive': the input data and its modification pass through the base network feature extraction module and the SSL module; an (input results, modification results) tuple is returned as SSL output,
- 'contrastive_2layers': the input data and its modification pass through the base network feature extraction module; the output of the second feature extraction layer for the modified data goes through an SSL module and then, optionally, that result and the first-level unmodified features pass another transformation; an (input results, modified results) tuple is returned as SSL output.
def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
    """Apply the configured loss function to the two predicted feature sets.

    `predicted` is unpacked into a pair of feature sets that are passed to
    `self.loss_function`; `target` is ignored.
    """
    first_view, second_view = predicted
    return self.loss_function(first_view, second_view)
Contrastive regression loss.