dlc2action.ssl.contrastive
Implementations of `dlc2action.ssl.base_ssl.SSLConstructor` of the `'contrastive'` type
#
# Copyright 2020-2022 by A. Mathis Group and contributors. All rights reserved.
#
# This project and all its files are licensed under GNU AGPLv3 or later version. A copy is included in dlc2action/LICENSE.AGPL.
#
"""
Implementations of `dlc2action.ssl.base_ssl.SSLConstructor` of the `'contrastive'` type
"""

from typing import Dict, Tuple, Union
from dlc2action.ssl.base_ssl import SSLConstructor
from dlc2action.loss.contrastive import _NTXent, _CircleLoss, _TripletLoss
from dlc2action.loss.contrastive_frame import _ContrastiveRegressionLoss
from dlc2action.ssl.modules import _FeatureExtractorTCN, _MFeatureExtractorTCN, _FC
import torch
from torch import nn
from copy import deepcopy


def _single_feature_dim(num_f_maps: torch.Size, constructor_name: str) -> int:
    """
    Return the only entry of `num_f_maps` as an integer

    Parameters
    ----------
    num_f_maps : torch.Size
        shape of feature extractor output
    constructor_name : str
        the class name to report in the error message

    Returns
    -------
    int
        the first entry of `num_f_maps`

    Raises
    ------
    RuntimeError
        if `num_f_maps` has more than one entry
    """

    if len(num_f_maps) > 1:
        raise RuntimeError(
            f"The {constructor_name} constructor expects the input data to be 2-dimensional; "
            f"got {len(num_f_maps) + 1} dimensions"
        )
    return int(num_f_maps[0])


def _empty_ssl_pair() -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Return a pair of NaN scalar tensors used as empty SSL input and SSL target
    """

    return torch.tensor(float("nan")), torch.tensor(float("nan"))


def _masked_copy(sample_data: Dict, start: int, end: int) -> Dict:
    """
    Return a deep copy of `sample_data` with `[:, start:end]` of every value set to zero

    The original `sample_data` is not modified.
    """

    data = deepcopy(sample_data)
    for key in data:
        data[key][:, start:end] = 0
    return data


class ContrastiveSSL(SSLConstructor):
    """
    A contrastive SSL class with an NT-Xent loss

    The SSL input and target are left empty (the SSL input is generated as an augmentation of the
    input sample at runtime).
    """

    type = "contrastive"

    def __init__(
        self,
        num_f_maps: torch.Size,
        len_segment: int,
        ssl_features: int = 128,
        tau: float = 1,
    ) -> None:
        """
        Parameters
        ----------
        num_f_maps : torch.Size
            shape of feature extractor output
        len_segment : int
            length of segment in the base feature extractor output
        ssl_features : int, default 128
            the final number of features per clip
        tau : float, default 1
            the tau parameter of NT-Xent loss
        """

        super().__init__()
        self.loss_function = _NTXent(tau)
        self.pars = {
            "num_f_maps": _single_feature_dim(num_f_maps, "ContrastiveSSL"),
            "len_segment": len_segment,
            "output_dim": ssl_features,
            "kernel_1": 5,
            "kernel_2": 5,
            "stride": 2,
            "decrease_f_maps": True,
        }

    def transformation(self, sample_data: Dict) -> Tuple:
        """
        Empty transformation
        """

        return _empty_ssl_pair()

    def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
        """
        NT-Xent loss
        """

        features1, features2 = predicted
        return self.loss_function(features1, features2)

    def construct_module(self) -> Union[nn.Module, None]:
        """
        Clip-wise feature TCN extractor
        """

        return _FeatureExtractorTCN(**self.pars)


class ContrastiveMaskedSSL(SSLConstructor):
    """
    A contrastive masked SSL class with an NT-Xent loss

    A few frames in the middle of each segment are masked and then the output of the second layer of
    feature extraction for the segment is used to predict the output of the first layer for the missing frames.
    The SSL input is a copy of the sample with the masked frames set to zero; the SSL target is left empty.
    """

    type = "contrastive_2layers"

    def __init__(
        self,
        num_f_maps: torch.Size,
        len_segment: int,
        ssl_features: int = 128,
        tau: float = 1,
        num_masked: int = 10,
    ) -> None:
        """
        Parameters
        ----------
        num_f_maps : torch.Size
            shape of feature extractor output
        len_segment : int
            length of segment in the base feature extractor output
        ssl_features : int, default 128
            the final number of features per clip
        tau : float, default 1
            the tau parameter of NT-Xent loss
        num_masked : int, default 10
            the number of frames masked around the middle of the segment
            (`num_masked // 2` frames on each side of `len_segment // 2`)
        """

        super().__init__()
        self.start = int(len_segment // 2 - num_masked // 2)
        self.end = int(len_segment // 2 + num_masked // 2)
        self.loss_function = _NTXent(tau)
        self.pars = {
            "num_f_maps": _single_feature_dim(num_f_maps, "ContrastiveMaskedSSL"),
            "len_segment": len_segment,
            "output_dim": ssl_features,
            "kernel_1": 3,
            "kernel_2": 3,
            "stride": 1,
            "start": self.start,
            "end": self.end,
        }

    def transformation(self, sample_data: Dict) -> Tuple:
        """
        Mask the middle frames of the sample

        Returns a deep copy of `sample_data` with `[:, start:end]` of every value set to zero
        as the SSL input and a NaN tensor as the (empty) SSL target.
        """

        return _masked_copy(sample_data, self.start, self.end), torch.tensor(float("nan"))

    def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
        """
        NT-Xent loss
        """

        features, ssl_features = predicted
        return self.loss_function(features, ssl_features)

    def construct_module(self) -> Union[nn.Module, None]:
        """
        Clip-wise feature TCN extractor
        """

        return _MFeatureExtractorTCN(**self.pars)


class PairwiseSSL(SSLConstructor):
    """
    A pairwise SSL class with triplet or circle loss

    The SSL input and target are left empty (the SSL input is generated as an augmentation of the
    input sample at runtime).
    """

    type = "contrastive"

    def __init__(
        self,
        num_f_maps: torch.Size,
        len_segment: int,
        ssl_features: int = 128,
        margin: float = 0,
        distance: str = "cosine",
        loss: str = "triplet",
        gamma: float = 1,
    ) -> None:
        """
        Parameters
        ----------
        num_f_maps : torch.Size
            shape of feature extractor output
        len_segment : int
            length of segment in feature extractor output
        ssl_features : int, default 128
            final number of features per clip
        margin : float, default 0
            the margin parameter of triplet or circle loss
        distance : {'cosine', 'euclidean'}
            the distance calculation method for triplet or circle loss
        loss : {'triplet', 'circle'}
            the loss function name
        gamma : float, default 1
            the gamma parameter of circle loss
        """

        super().__init__()
        if loss == "triplet":
            self.loss_function = _TripletLoss(margin=margin, distance=distance)
        elif loss == "circle":
            self.loss_function = _CircleLoss(
                margin=margin, gamma=gamma, distance=distance
            )
        else:
            raise ValueError(
                f'The {loss} loss is unavailable, please choose from "triplet" and "circle"'
            )
        # the runtime class name is reported so that subclasses name themselves in the error
        self.pars = {
            "num_f_maps": _single_feature_dim(num_f_maps, type(self).__name__),
            "len_segment": len_segment,
            "output_dim": ssl_features,
            "kernel_1": 5,
            "kernel_2": 5,
            "stride": 2,
            "decrease_f_maps": True,
        }

    def transformation(self, sample_data: Dict) -> Tuple:
        """
        Empty transformation
        """

        return _empty_ssl_pair()

    def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
        """
        Triplet or circle loss
        """

        features1, features2 = predicted
        return self.loss_function(features1, features2)

    def construct_module(self) -> Union[nn.Module, None]:
        """
        Clip-wise feature TCN extractor
        """

        return _FeatureExtractorTCN(**self.pars)


class PairwiseMaskedSSL(PairwiseSSL):
    """
    A pairwise masked SSL class with triplet or circle loss

    The SSL input is a copy of the sample with the frames between `start` and `end` set to zero;
    the SSL target is left empty.
    """

    type = "contrastive_2layers"

    def __init__(
        self,
        num_f_maps: torch.Size,
        len_segment: int,
        ssl_features: int = 128,
        margin: float = 0,
        distance: str = "cosine",
        loss: str = "triplet",
        gamma: float = 1,
        num_masked: int = 10,
    ) -> None:
        """
        Parameters
        ----------
        num_f_maps : torch.Size
            shape of feature extractor output
        len_segment : int
            length of segment in feature extractor output
        ssl_features : int, default 128
            final number of features per clip
        margin : float, default 0
            the margin parameter of triplet or circle loss
        distance : {'cosine', 'euclidean'}
            the distance calculation method for triplet or circle loss
        loss : {'triplet', 'circle'}
            the loss function name
        gamma : float, default 1
            the gamma parameter of circle loss
        num_masked : int, default 10
            the number of frames masked around the middle of the segment
            (`num_masked // 2` frames on each side of `len_segment // 2`)
        """

        # the parent constructor validates `num_f_maps` and stores its integer value in `self.pars`
        super().__init__(
            num_f_maps, len_segment, ssl_features, margin, distance, loss, gamma
        )
        self.num_masked = num_masked
        self.start = int(len_segment // 2 - num_masked // 2)
        self.end = int(len_segment // 2 + num_masked // 2)
        self.pars = {
            "num_f_maps": self.pars["num_f_maps"],
            "len_segment": len_segment,
            "output_dim": ssl_features,
            "kernel_1": 3,
            "kernel_2": 3,
            "stride": 1,
            "start": self.start,
            "end": self.end,
        }

    def transformation(self, sample_data: Dict) -> Tuple:
        """
        Mask the middle frames of the sample

        Returns a deep copy of `sample_data` with `[:, start:end]` of every value set to zero
        as the SSL input and a NaN tensor as the (empty) SSL target.
        """

        return _masked_copy(sample_data, self.start, self.end), torch.tensor(float("nan"))

    def construct_module(self) -> Union[nn.Module, None]:
        """
        Clip-wise feature TCN extractor
        """

        return _MFeatureExtractorTCN(**self.pars)


class ContrastiveRegressionSSL(SSLConstructor):
    """
    A contrastive SSL class with a contrastive regression loss

    The transformation returns NaN placeholders for both the SSL input and the SSL target.
    """

    type = "contrastive"

    def __init__(
        self,
        num_f_maps: torch.Size,
        num_features: int = 128,
        num_ssl_layers: int = 1,
        distance: str = "cosine",
        temperature: float = 1,
        break_factor: Union[int, None] = None,
    ) -> None:
        """
        Parameters
        ----------
        num_f_maps : torch.Size
            shape of feature extractor output
        num_features : int, default 128
            passed to the `_FC` module as both `num_ssl_f_maps` and `dim`
        num_ssl_layers : int, default 1
            passed to the `_FC` module as `num_ssl_layers`
        distance : str, default 'cosine'
            the distance argument of the contrastive regression loss
        temperature : float, default 1
            the temperature argument of the contrastive regression loss
        break_factor : int, optional
            the break factor argument of the contrastive regression loss
        """

        num_f_maps = _single_feature_dim(num_f_maps, "ContrastiveRegressionSSL")
        self.loss_function = _ContrastiveRegressionLoss(
            temperature, distance, break_factor
        )
        self.pars = {
            "num_f_maps": num_f_maps,
            "num_ssl_layers": num_ssl_layers,
            "num_ssl_f_maps": num_features,
            "dim": num_features,
        }
        super().__init__()

    def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
        """
        Contrastive regression loss
        """

        features1, features2 = predicted
        return self.loss_function(features1, features2)

    def transformation(self, sample_data: Dict) -> Tuple:
        """
        Empty transformation
        """

        return _empty_ssl_pair()

    def construct_module(self) -> Union[nn.Module, None]:
        """
        Construct the `_FC` SSL module from the stored parameters
        """

        return _FC(**self.pars)
class ContrastiveSSL(SSLConstructor):
    """
    An SSL constructor that compares two views of a sample with an NT-Xent loss

    The transformation produces no SSL input or target: the second view is generated as an
    augmentation of the input sample at runtime.
    """

    type = "contrastive"

    def __init__(
        self,
        num_f_maps: torch.Size,
        len_segment: int,
        ssl_features: int = 128,
        tau: float = 1,
    ) -> None:
        """
        Parameters
        ----------
        num_f_maps : torch.Size
            shape of feature extractor output (must have a single entry)
        len_segment : int
            length of segment in the base feature extractor output
        ssl_features : int, default 128
            the final number of features per clip
        tau : float, default 1
            the tau parameter of NT-Xent loss
        """

        super().__init__()
        self.loss_function = _NTXent(tau)
        n_entries = len(num_f_maps)
        if n_entries > 1:
            raise RuntimeError(
                "The ContrastiveSSL constructor expects the input data to be 2-dimensional; "
                f"got {n_entries + 1} dimensions"
            )
        # keyword arguments for `_FeatureExtractorTCN`, see `construct_module`
        self.pars = {
            "num_f_maps": int(num_f_maps[0]),
            "len_segment": len_segment,
            "output_dim": ssl_features,
            "kernel_1": 5,
            "kernel_2": 5,
            "stride": 2,
            "decrease_f_maps": True,
        }

    def transformation(self, sample_data: Dict) -> Tuple:
        """
        Return NaN placeholders for the SSL input and the SSL target
        """

        nan = float("nan")
        return torch.tensor(nan), torch.tensor(nan)

    def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
        """
        Apply the NT-Xent loss to the two entries of `predicted` (`target` is not used)
        """

        first_view, second_view = predicted
        return self.loss_function(first_view, second_view)

    def construct_module(self) -> Union[nn.Module, None]:
        """
        Build the clip-wise feature TCN extractor from the stored parameters
        """

        return _FeatureExtractorTCN(**self.pars)
A contrastive SSL class with an NT-Xent loss
The SSL input and target are left empty (the SSL input is generated as an augmentation of the input sample at runtime).
def __init__(
    self,
    num_f_maps: torch.Size,
    len_segment: int,
    ssl_features: int = 128,
    tau: float = 1,
) -> None:
    """
    Parameters
    ----------
    num_f_maps : torch.Size
        shape of feature extractor output (must have a single entry)
    len_segment : int
        length of segment in the base feature extractor output
    ssl_features : int, default 128
        the final number of features per clip
    tau : float, default 1
        the tau parameter of NT-Xent loss
    """

    super().__init__()
    self.loss_function = _NTXent(tau)
    n_entries = len(num_f_maps)
    if n_entries > 1:
        raise RuntimeError(
            "The ContrastiveSSL constructor expects the input data to be 2-dimensional; "
            f"got {n_entries + 1} dimensions"
        )
    # keyword arguments later unpacked into the SSL module constructor
    self.pars = {
        "num_f_maps": int(num_f_maps[0]),
        "len_segment": len_segment,
        "output_dim": ssl_features,
        "kernel_1": 5,
        "kernel_2": 5,
        "stride": 2,
        "decrease_f_maps": True,
    }
Parameters
num_f_maps : torch.Size shape of feature extractor output len_segment : int length of segment in the base feature extractor output ssl_features : int, default 128 the final number of features per clip tau : float, default 1 the tau parameter of NT-Xent loss
The `type` parameter defines interaction with the model:
- 'ssl_input': a modification of the input data passes through the base network feature extraction module and the SSL module; it is returned as SSL output and compared to SSL target (or, if it is None, to the input data),
- 'ssl_output': the input data passes through the base network feature extraction module and the SSL module; it is returned as SSL output and compared to SSL target (or, if it is None, to the input data),
- 'contrastive': the input data and its modification pass through the base network feature extraction module and the SSL module; an (input results, modification results) tuple is returned as SSL output,
- 'contrastive_2layers': the input data and its modification pass through the base network feature extraction module; the output of the second feature extraction layer for the modified data goes through an SSL module and then, optionally, that result and the first-level unmodified features pass another transformation; an (input results, modified results) tuple is returned as SSL output.
def transformation(self, sample_data: Dict) -> Tuple:
    """
    Return NaN placeholders for the SSL input and the SSL target

    `sample_data` is not used.
    """

    nan = float("nan")
    ssl_input = torch.tensor(nan)
    ssl_target = torch.tensor(nan)
    return ssl_input, ssl_target
Empty transformation
class ContrastiveMaskedSSL(SSLConstructor):
    """
    A contrastive masked SSL class with an NT-Xent loss

    A few frames in the middle of each segment are masked and then the output of the second layer of
    feature extraction for the segment is used to predict the output of the first layer for the missing frames.
    The SSL input is a copy of the sample with the masked frames set to zero; the SSL target is left empty.
    """

    type = "contrastive_2layers"

    def __init__(
        self,
        num_f_maps: torch.Size,
        len_segment: int,
        ssl_features: int = 128,
        tau: float = 1,
        num_masked: int = 10,
    ) -> None:
        """
        Parameters
        ----------
        num_f_maps : torch.Size
            shape of feature extractor output
        len_segment : int
            length of segment in the base feature extractor output
        ssl_features : int, default 128
            the final number of features per clip
        tau : float, default 1
            the tau parameter of NT-Xent loss
        num_masked : int, default 10
            the number of frames masked around the middle of the segment
            (`num_masked // 2` frames on each side of `len_segment // 2`)
        """

        super().__init__()
        # masked window: `num_masked // 2` frames on each side of the segment middle
        self.start = int(len_segment // 2 - num_masked // 2)
        self.end = int(len_segment // 2 + num_masked // 2)
        self.loss_function = _NTXent(tau)
        if len(num_f_maps) > 1:
            raise RuntimeError(
                "The ContrastiveMaskedSSL constructor expects the input data to be 2-dimensional; "
                f"got {len(num_f_maps) + 1} dimensions"
            )
        num_f_maps = int(num_f_maps[0])
        self.pars = {
            "num_f_maps": num_f_maps,
            "len_segment": len_segment,
            "output_dim": ssl_features,
            "kernel_1": 3,
            "kernel_2": 3,
            "stride": 1,
            "start": self.start,
            "end": self.end,
        }

    def transformation(self, sample_data: Dict) -> Tuple:
        """
        Mask the middle frames of the sample

        Returns a deep copy of `sample_data` with `[:, start:end]` of every value set to zero
        as the SSL input and a NaN tensor as the (empty) SSL target. The original
        `sample_data` is not modified.
        """

        data = deepcopy(sample_data)
        for key in data:
            data[key][:, self.start : self.end] = 0
        return data, torch.tensor(float("nan"))

    def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
        """
        NT-Xent loss

        `predicted` is unpacked into two feature sets that are passed to the loss function;
        `target` is not used.
        """

        features, ssl_features = predicted
        return self.loss_function(features, ssl_features)

    def construct_module(self) -> Union[nn.Module, None]:
        """
        Clip-wise feature TCN extractor
        """

        return _MFeatureExtractorTCN(**self.pars)
A contrastive masked SSL class with an NT-Xent loss
A few frames in the middle of each segment are masked and then the output of the second layer of feature extraction for the segment is used to predict the output of the first layer for the missing frames. The SSL input is a copy of the sample with the masked frames set to zero; the SSL target is left empty.
def __init__(
    self,
    num_f_maps: torch.Size,
    len_segment: int,
    ssl_features: int = 128,
    tau: float = 1,
    num_masked: int = 10,
) -> None:
    """
    Parameters
    ----------
    num_f_maps : torch.Size
        shape of feature extractor output (must have a single entry)
    len_segment : int
        length of segment in the base feature extractor output
    ssl_features : int, default 128
        the final number of features per clip
    tau : float, default 1
        the tau parameter of NT-Xent loss
    num_masked : int, default 10
        the number of frames masked around the middle of the segment
        (`num_masked // 2` frames on each side of `len_segment // 2`)
    """

    super().__init__()
    middle = len_segment // 2
    half_window = num_masked // 2
    self.start = int(middle - half_window)
    self.end = int(middle + half_window)
    self.loss_function = _NTXent(tau)
    n_entries = len(num_f_maps)
    if n_entries > 1:
        raise RuntimeError(
            "The ContrastiveMaskedSSL constructor expects the input data to be 2-dimensional; "
            f"got {n_entries + 1} dimensions"
        )
    # keyword arguments later unpacked into the SSL module constructor
    self.pars = {
        "num_f_maps": int(num_f_maps[0]),
        "len_segment": len_segment,
        "output_dim": ssl_features,
        "kernel_1": 3,
        "kernel_2": 3,
        "stride": 1,
        "start": self.start,
        "end": self.end,
    }
Parameters
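num_f_maps : torch.Size — shape of feature extractor output
len_segment : int — length of segment in the base feature extractor output
ssl_features : int, default 128 — the final number of features per clip
tau : float, default 1 — the tau parameter of NT-Xent loss
num_masked : int, default 10 — the number of frames masked around the middle of the segment (`num_masked // 2` frames on each side of `len_segment // 2`)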
The `type` parameter defines interaction with the model:
- 'ssl_input': a modification of the input data passes through the base network feature extraction module and the SSL module; it is returned as SSL output and compared to SSL target (or, if it is None, to the input data),
- 'ssl_output': the input data passes through the base network feature extraction module and the SSL module; it is returned as SSL output and compared to SSL target (or, if it is None, to the input data),
- 'contrastive': the input data and its modification pass through the base network feature extraction module and the SSL module; an (input results, modification results) tuple is returned as SSL output,
- 'contrastive_2layers': the input data and its modification pass through the base network feature extraction module; the output of the second feature extraction layer for the modified data goes through an SSL module and then, optionally, that result and the first-level unmodified features pass another transformation; an (input results, modified results) tuple is returned as SSL output.
def transformation(self, sample_data: Dict) -> Tuple:
    """
    Mask the middle frames of the sample

    Returns a deep copy of `sample_data` with `[:, start:end]` of every value set to zero
    as the SSL input and a NaN tensor as the (empty) SSL target. The original
    `sample_data` is not modified.

    Parameters
    ----------
    sample_data : dict
        a dictionary whose values support `[:, start:end]` slice assignment

    Returns
    -------
    ssl_input : dict
        the masked copy of `sample_data`
    ssl_target : torch.Tensor
        a NaN scalar tensor
    """

    data = deepcopy(sample_data)
    for key in data:
        data[key][:, self.start : self.end] = 0
    return data, torch.tensor(float("nan"))
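Masking transformation: the SSL input is a deep copy of the sample data with the frames between `start` and `end` set to zero; the SSL target is left empty.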
def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
    """
    Apply the NT-Xent loss to the two entries of `predicted`

    `target` is not used.
    """

    base_features, masked_features = predicted
    return self.loss_function(base_features, masked_features)
NT-Xent loss
class PairwiseSSL(SSLConstructor):
    """
    A pairwise SSL class with triplet or circle loss

    The SSL input and target are left empty (the SSL input is generated as an augmentation of the
    input sample at runtime).
    """

    type = "contrastive"

    def __init__(
        self,
        num_f_maps: torch.Size,
        len_segment: int,
        ssl_features: int = 128,
        margin: float = 0,
        distance: str = "cosine",
        loss: str = "triplet",
        gamma: float = 1,
    ) -> None:
        """
        Parameters
        ----------
        num_f_maps : torch.Size
            shape of feature extractor output
        len_segment : int
            length of segment in feature extractor output
        ssl_features : int, default 128
            final number of features per clip
        margin : float, default 0
            the margin parameter of triplet or circle loss
        distance : {'cosine', 'euclidean'}
            the distance calculation method for triplet or circle loss
        loss : {'triplet', 'circle'}
            the loss function name
        gamma : float, default 1
            the gamma parameter of circle loss

        Raises
        ------
        ValueError
            if `loss` is neither 'triplet' nor 'circle'
        RuntimeError
            if `num_f_maps` has more than one entry
        """

        super().__init__()
        if loss == "triplet":
            self.loss_function = _TripletLoss(margin=margin, distance=distance)
        elif loss == "circle":
            self.loss_function = _CircleLoss(
                margin=margin, gamma=gamma, distance=distance
            )
        else:
            raise ValueError(
                f'The {loss} loss is unavailable, please choose from "triplet" and "circle"'
            )
        if len(num_f_maps) > 1:
            # the runtime class name is reported so that subclasses calling this
            # constructor name themselves in the error message
            raise RuntimeError(
                f"The {type(self).__name__} constructor expects the input data to be 2-dimensional; "
                f"got {len(num_f_maps) + 1} dimensions"
            )
        num_f_maps = int(num_f_maps[0])
        self.pars = {
            "num_f_maps": num_f_maps,
            "len_segment": len_segment,
            "output_dim": ssl_features,
            "kernel_1": 5,
            "kernel_2": 5,
            "stride": 2,
            "decrease_f_maps": True,
        }

    def transformation(self, sample_data: Dict) -> Tuple:
        """
        Empty transformation

        Returns NaN tensors as placeholders for the SSL input and the SSL target.
        """

        return torch.tensor(float("nan")), torch.tensor(float("nan"))

    def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
        """
        Triplet or circle loss

        `predicted` is unpacked into two feature sets that are passed to the loss function;
        `target` is not used.
        """

        features1, features2 = predicted
        return self.loss_function(features1, features2)

    def construct_module(self) -> Union[nn.Module, None]:
        """
        Clip-wise feature TCN extractor
        """

        return _FeatureExtractorTCN(**self.pars)
A pairwise SSL class with triplet or circle loss
The SSL input and target are left empty (the SSL input is generated as an augmentation of the input sample at runtime).
def __init__(
    self,
    num_f_maps: torch.Size,
    len_segment: int,
    ssl_features: int = 128,
    margin: float = 0,
    distance: str = "cosine",
    loss: str = "triplet",
    gamma: float = 1,
) -> None:
    """
    Parameters
    ----------
    num_f_maps : torch.Size
        shape of feature extractor output
    len_segment : int
        length of segment in feature extractor output
    ssl_features : int, default 128
        final number of features per clip
    margin : float, default 0
        the margin parameter of triplet or circle loss
    distance : {'cosine', 'euclidean'}
        the distance calculation method for triplet or circle loss
    loss : {'triplet', 'circle'}
        the loss function name
    gamma : float, default 1
        the gamma parameter of circle loss

    Raises
    ------
    ValueError
        if `loss` is neither 'triplet' nor 'circle'
    RuntimeError
        if `num_f_maps` has more than one entry
    """

    super().__init__()
    if loss == "triplet":
        self.loss_function = _TripletLoss(margin=margin, distance=distance)
    elif loss == "circle":
        self.loss_function = _CircleLoss(
            margin=margin, gamma=gamma, distance=distance
        )
    else:
        raise ValueError(
            f'The {loss} loss is unavailable, please choose from "triplet" and "circle"'
        )
    if len(num_f_maps) > 1:
        # the runtime class name is reported so that subclasses calling this
        # constructor name themselves in the error message
        raise RuntimeError(
            f"The {type(self).__name__} constructor expects the input data to be 2-dimensional; "
            f"got {len(num_f_maps) + 1} dimensions"
        )
    num_f_maps = int(num_f_maps[0])
    self.pars = {
        "num_f_maps": num_f_maps,
        "len_segment": len_segment,
        "output_dim": ssl_features,
        "kernel_1": 5,
        "kernel_2": 5,
        "stride": 2,
        "decrease_f_maps": True,
    }
Parameters
num_f_maps : torch.Size shape of feature extractor output len_segment : int length of segment in feature extractor output ssl_features : int, default 128 final number of features per clip margin : float, default 0 the margin parameter of triplet or circle loss distance : {'cosine', 'euclidean'} the distance calculation method for triplet or circle loss loss : {'triplet', 'circle'} the loss function name gamma : float, default 1 the gamma parameter of circle loss
The `type` parameter defines interaction with the model:
- 'ssl_input': a modification of the input data passes through the base network feature extraction module and the SSL module; it is returned as SSL output and compared to SSL target (or, if it is None, to the input data),
- 'ssl_output': the input data passes through the base network feature extraction module and the SSL module; it is returned as SSL output and compared to SSL target (or, if it is None, to the input data),
- 'contrastive': the input data and its modification pass through the base network feature extraction module and the SSL module; an (input results, modification results) tuple is returned as SSL output,
- 'contrastive_2layers': the input data and its modification pass through the base network feature extraction module; the output of the second feature extraction layer for the modified data goes through an SSL module and then, optionally, that result and the first-level unmodified features pass another transformation; an (input results, modified results) tuple is returned as SSL output.
def transformation(self, sample_data: Dict) -> Tuple:
    """
    Return NaN placeholders for the SSL input and the SSL target

    `sample_data` is not used.
    """

    placeholders = tuple(torch.tensor(float("nan")) for _ in range(2))
    return placeholders
Empty transformation
def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
    """
    Apply the configured triplet or circle loss to the two entries of `predicted`

    `target` is not used.
    """

    first_view, second_view = predicted
    return self.loss_function(first_view, second_view)
Triplet or circle loss
class PairwiseMaskedSSL(PairwiseSSL):
    """
    A pairwise masked SSL class with triplet or circle loss

    The SSL input is a copy of the sample with the frames between `start` and `end` set to zero;
    the SSL target is left empty.
    """

    type = "contrastive_2layers"

    def __init__(
        self,
        num_f_maps: torch.Size,
        len_segment: int,
        ssl_features: int = 128,
        margin: float = 0,
        distance: str = "cosine",
        loss: str = "triplet",
        gamma: float = 1,
        num_masked: int = 10,
    ) -> None:
        """
        Parameters
        ----------
        num_f_maps : torch.Size
            shape of feature extractor output
        len_segment : int
            length of segment in feature extractor output
        ssl_features : int, default 128
            final number of features per clip
        margin : float, default 0
            the margin parameter of triplet or circle loss
        distance : {'cosine', 'euclidean'}
            the distance calculation method for triplet or circle loss
        loss : {'triplet', 'circle'}
            the loss function name
        gamma : float, default 1
            the gamma parameter of circle loss
        num_masked : int, default 10
            the number of frames masked around the middle of the segment
            (`num_masked // 2` frames on each side of `len_segment // 2`)
        """

        # The parent constructor already rejects `num_f_maps` with more than one entry
        # and stores its integer value in `self.pars`, so the check is not repeated here.
        super().__init__(
            num_f_maps, len_segment, ssl_features, margin, distance, loss, gamma
        )
        self.num_masked = num_masked
        self.start = int(len_segment // 2 - num_masked // 2)
        self.end = int(len_segment // 2 + num_masked // 2)
        self.pars = {
            "num_f_maps": self.pars["num_f_maps"],
            "len_segment": len_segment,
            "output_dim": ssl_features,
            "kernel_1": 3,
            "kernel_2": 3,
            "stride": 1,
            "start": self.start,
            "end": self.end,
        }

    def transformation(self, sample_data: Dict) -> Tuple:
        """
        Mask the middle frames of the sample

        Returns a deep copy of `sample_data` with `[:, start:end]` of every value set to zero
        as the SSL input and a NaN tensor as the (empty) SSL target. The original
        `sample_data` is not modified.
        """

        data = deepcopy(sample_data)
        for key in data:
            data[key][:, self.start : self.end] = 0
        return data, torch.tensor(float("nan"))

    def construct_module(self) -> Union[nn.Module, None]:
        """
        Clip-wise feature TCN extractor
        """

        return _MFeatureExtractorTCN(**self.pars)
A pairwise masked SSL class with triplet or circle loss
The SSL input is a copy of the input sample with the frames between `start` and `end` set to zero; the SSL target is left empty.
def __init__(
    self,
    num_f_maps: torch.Size,
    len_segment: int,
    ssl_features: int = 128,
    margin: float = 0,
    distance: str = "cosine",
    loss: str = "triplet",
    gamma: float = 1,
    num_masked: int = 10,
) -> None:
    """
    Parameters
    ----------
    num_f_maps : torch.Size
        shape of feature extractor output
    len_segment : int
        length of segment in feature extractor output
    ssl_features : int, default 128
        final number of features per clip
    margin : float, default 0
        the margin parameter of triplet or circle loss
    distance : {'cosine', 'euclidean'}
        the distance calculation method for triplet or circle loss
    loss : {'triplet', 'circle'}
        the loss function name
    gamma : float, default 1
        the gamma parameter of circle loss
    num_masked : int, default 10
        the number of frames masked around the middle of the segment
        (`num_masked // 2` frames on each side of `len_segment // 2`)
    """

    # The parent constructor already rejects `num_f_maps` with more than one entry
    # and stores its integer value in `self.pars`, so the check is not repeated here.
    super().__init__(
        num_f_maps, len_segment, ssl_features, margin, distance, loss, gamma
    )
    self.num_masked = num_masked
    self.start = int(len_segment // 2 - num_masked // 2)
    self.end = int(len_segment // 2 + num_masked // 2)
    self.pars = {
        "num_f_maps": self.pars["num_f_maps"],
        "len_segment": len_segment,
        "output_dim": ssl_features,
        "kernel_1": 3,
        "kernel_2": 3,
        "stride": 1,
        "start": self.start,
        "end": self.end,
    }
Parameters
num_f_maps : torch.Size — shape of feature extractor output
len_segment : int — length of segment in feature extractor output
ssl_features : int, default 128 — final number of features per clip
margin : float, default 0 — the margin parameter of triplet or circle loss
distance : {'cosine', 'euclidean'} — the distance calculation method for triplet or circle loss
loss : {'triplet', 'circle'} — the loss function name
gamma : float, default 1 — the gamma parameter of circle loss
num_masked : int, default 10 — the number of frames masked around the middle of the segment (`num_masked // 2` frames on each side of `len_segment // 2`)
The `type` parameter defines interaction with the model:
- 'ssl_input': a modification of the input data passes through the base network feature extraction module and the SSL module; it is returned as SSL output and compared to SSL target (or, if it is None, to the input data),
- 'ssl_output': the input data passes through the base network feature extraction module and the SSL module; it is returned as SSL output and compared to SSL target (or, if it is None, to the input data),
- 'contrastive': the input data and its modification pass through the base network feature extraction module and the SSL module; an (input results, modification results) tuple is returned as SSL output,
- 'contrastive_2layers': the input data and its modification pass through the base network feature extraction module; the output of the second feature extraction layer for the modified data goes through an SSL module and then, optionally, that result and the first-level unmodified features pass another transformation; an (input results, modified results) tuple is returned as SSL output.
def transformation(self, sample_data: Dict) -> Tuple:
    """
    Mask the middle frames of the sample

    Returns a deep copy of `sample_data` with `[:, start:end]` of every value set to zero
    as the SSL input and a NaN tensor as the (empty) SSL target. The original
    `sample_data` is not modified.

    Parameters
    ----------
    sample_data : dict
        a dictionary whose values support `[:, start:end]` slice assignment

    Returns
    -------
    ssl_input : dict
        the masked copy of `sample_data`
    ssl_target : torch.Tensor
        a NaN scalar tensor
    """

    data = deepcopy(sample_data)
    for key in data:
        data[key][:, self.start : self.end] = 0
    return data, torch.tensor(float("nan"))
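Masking transformation: the SSL input is a deep copy of the sample data with the frames between `start` and `end` set to zero; the SSL target is left empty.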
def construct_module(self) -> Union[nn.Module, None]:
    """
    Build the clip-wise feature TCN extractor from the stored parameters

    The keyword arguments are taken from `self.pars`.
    """

    return _MFeatureExtractorTCN(**self.pars)
Clip-wise feature TCN extractor
Inherited Members
class ContrastiveRegressionSSL(SSLConstructor):
    """
    A contrastive SSL class with a contrastive regression loss

    The transformation returns NaN placeholders for both the SSL input and the SSL target.
    """

    type = "contrastive"

    def __init__(
        self,
        num_f_maps: torch.Size,
        num_features: int = 128,
        num_ssl_layers: int = 1,
        distance: str = "cosine",
        temperature: float = 1,
        break_factor: Union[int, None] = None,
    ) -> None:
        """
        Parameters
        ----------
        num_f_maps : torch.Size
            shape of feature extractor output
        num_features : int, default 128
            passed to the `_FC` module as both `num_ssl_f_maps` and `dim`
        num_ssl_layers : int, default 1
            passed to the `_FC` module as `num_ssl_layers`
        distance : str, default 'cosine'
            the distance argument of the contrastive regression loss
        temperature : float, default 1
            the temperature argument of the contrastive regression loss
        break_factor : int, optional
            the break factor argument of the contrastive regression loss

        Raises
        ------
        RuntimeError
            if `num_f_maps` has more than one entry
        """

        if len(num_f_maps) > 1:
            raise RuntimeError(
                "The ContrastiveRegressionSSL constructor expects the input data to be 2-dimensional; "
                f"got {len(num_f_maps) + 1} dimensions"
            )
        num_f_maps = int(num_f_maps[0])
        self.loss_function = _ContrastiveRegressionLoss(
            temperature, distance, break_factor
        )
        self.pars = {
            "num_f_maps": num_f_maps,
            "num_ssl_layers": num_ssl_layers,
            "num_ssl_f_maps": num_features,
            "dim": num_features,
        }
        # NOTE(review): unlike the sibling constructors, the parent constructor is called
        # last here; the order is kept as is.
        super().__init__()

    def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
        """
        Contrastive regression loss

        `predicted` is unpacked into two feature sets that are passed to the loss function;
        `target` is not used.
        """

        features1, features2 = predicted
        return self.loss_function(features1, features2)

    def transformation(self, sample_data: Dict) -> Tuple:
        """
        Empty transformation

        Returns NaN tensors as placeholders for the SSL input and the SSL target.
        """

        return torch.tensor(float("nan")), torch.tensor(float("nan"))

    def construct_module(self) -> Union[nn.Module, None]:
        """
        Construct the `_FC` SSL module from the stored parameters
        """

        return _FC(**self.pars)
A base class for all SSL constructors
An SSL method is defined by four things: a transformation that maps a sample into SSL input and output, a neural net module that takes features as input and predicts SSL target, a type, and a loss function.
def __init__(
    self,
    num_f_maps: torch.Size,
    num_features: int = 128,
    num_ssl_layers: int = 1,
    distance: str = "cosine",
    temperature: float = 1,
    break_factor: Union[int, None] = None,
) -> None:
    """
    Parameters
    ----------
    num_f_maps : torch.Size
        shape of feature extractor output
    num_features : int, default 128
        stored in `self.pars` as both `num_ssl_f_maps` and `dim`
    num_ssl_layers : int, default 1
        stored in `self.pars` as `num_ssl_layers`
    distance : str, default 'cosine'
        the distance argument of the contrastive regression loss
    temperature : float, default 1
        the temperature argument of the contrastive regression loss
    break_factor : int, optional
        the break factor argument of the contrastive regression loss

    Raises
    ------
    RuntimeError
        if `num_f_maps` has more than one entry
    """

    if len(num_f_maps) > 1:
        raise RuntimeError(
            "The ContrastiveRegressionSSL constructor expects the input data to be 2-dimensional; "
            f"got {len(num_f_maps) + 1} dimensions"
        )
    num_f_maps = int(num_f_maps[0])
    self.loss_function = _ContrastiveRegressionLoss(
        temperature, distance, break_factor
    )
    self.pars = {
        "num_f_maps": num_f_maps,
        "num_ssl_layers": num_ssl_layers,
        "num_ssl_f_maps": num_features,
        "dim": num_features,
    }
    # the parent constructor is called last, after the attributes are set
    super().__init__()
The `type` parameter defines interaction with the model:
- 'ssl_input': a modification of the input data passes through the base network feature extraction module and the SSL module; it is returned as SSL output and compared to SSL target (or, if it is None, to the input data),
- 'ssl_output': the input data passes through the base network feature extraction module and the SSL module; it is returned as SSL output and compared to SSL target (or, if it is None, to the input data),
- 'contrastive': the input data and its modification pass through the base network feature extraction module and the SSL module; an (input results, modification results) tuple is returned as SSL output,
- 'contrastive_2layers': the input data and its modification pass through the base network feature extraction module; the output of the second feature extraction layer for the modified data goes through an SSL module and then, optionally, that result and the first-level unmodified features pass another transformation; an (input results, modified results) tuple is returned as SSL output.
def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
    """
    Apply the configured loss function to the two entries of `predicted`

    `target` is not used.
    """

    first_view, second_view = predicted
    return self.loss_function(first_view, second_view)
Contrastive regression loss