dlc2action.ssl.masked
Implementations of dlc2action.ssl.base_ssl.SSLConstructor that predict masked input features.
#
# Copyright 2020-present by A. Mathis Group and contributors. All rights reserved.
#
# This project and all its files are licensed under GNU AGPLv3 or later version.
# A copy is included in dlc2action/LICENSE.AGPL.
#
"""Implementations of `dlc2action.ssl.base_ssl.SSLConstructor` that predict masked input features."""

import warnings
from typing import Dict, Tuple, Union, List

import torch

from dlc2action.ssl.base_ssl import SSLConstructor
from abc import ABC, abstractmethod
from dlc2action.loss.mse import MSE
from dlc2action.ssl.modules import *


def _total_feature_dim(dims: Dict) -> int:
    """Return the sum of the first dimension of every shape stored in `dims`."""
    return int(sum(shape[0] for shape in dims.values()))


class MaskedFeaturesSSL(SSLConstructor, ABC):
    """A base masked features SSL class.

    Mask some of the input features randomly and predict the initial data.
    """

    type = "ssl_input"

    def __init__(self, frac_masked: float = 0.2) -> None:
        """Initialize the constructor.

        Parameters
        ----------
        frac_masked : float, default 0.2
            fraction of feature values to mask (set to zero)

        """
        super().__init__()
        self.mse = MSE()
        self.frac_masked = frac_masked

    def transformation(self, sample_data: Dict) -> Tuple:
        """Mask some of the features randomly.

        Parameters
        ----------
        sample_data : dict
            feature tensors; the values are concatenated with `torch.cat`, so they have to
            agree in every dimension but the first

        Returns
        -------
        sample_data : dict
            the same dictionary with every value replaced by its masked version
        ssl_target : torch.Tensor
            the concatenation of the *unmasked* values

        """
        # The target has to be taken before masking, otherwise the module would only be
        # asked to reproduce the already masked input.
        ssl_target = torch.cat(list(sample_data.values()))
        for key in sample_data:
            value = sample_data[key]
            # Uniform noise makes the probability of masking an element equal to frac_masked.
            keep = torch.rand(value.shape, device=value.device) >= self.frac_masked
            sample_data[key] = value * keep
        return (sample_data, ssl_target)

    def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
        """MSE loss."""
        return self.mse(predicted, target)

    @abstractmethod
    def construct_module(self) -> nn.Module:
        """Construct the SSL prediction module using the parameters specified at initialization."""


class MaskedFeaturesSSL_FC(MaskedFeaturesSSL):
    """A fully connected masked features SSL class.

    Mask some of the input features randomly and predict the initial data.
    """

    type = "ssl_input"

    def __init__(
        self,
        dims: Dict,
        num_f_maps: torch.Size,
        frac_masked: float = 0.2,
        num_ssl_layers: int = 5,
        num_ssl_f_maps: int = 16,
    ) -> None:
        """Initialize the constructor.

        Parameters
        ----------
        dims : dict
            the shapes of the features in model input (first dimensions are summed)
        num_f_maps : torch.Size
            shape of feature extraction output (only the first element is used)
        frac_masked : float, default 0.2
            fraction of features to mask
        num_ssl_layers : int, default 5
            number of layers in the SSL module
        num_ssl_f_maps : int, default 16
            number of feature maps in the SSL module

        """
        super().__init__(frac_masked)
        self.pars = {
            "dim": _total_feature_dim(dims),
            "num_f_maps": int(num_f_maps[0]),
            "num_ssl_layers": num_ssl_layers,
            "num_ssl_f_maps": num_ssl_f_maps,
        }

    def construct_module(self) -> Union[nn.Module, None]:
        """Construct a fully connected module."""
        return FC(**self.pars)


class MaskedFeaturesSSL_TCN(MaskedFeaturesSSL):
    """A TCN masked features SSL class.

    Mask some of the input features randomly and predict the initial data.
    """

    def __init__(
        self,
        dims: Dict,
        num_f_maps: torch.Size,
        frac_masked: float = 0.2,
        num_ssl_layers: int = 5,
        num_ssl_f_maps: Union[int, None] = None,
    ) -> None:
        """Initialize the class.

        Parameters
        ----------
        dims : dict
            the shapes of the features in model input (first dimensions are summed)
        num_f_maps : torch.Size
            shape of feature extraction output (only the first element is used)
        frac_masked : float, default 0.2
            fraction of features to mask
        num_ssl_layers : int, default 5
            number of layers in the SSL module
        num_ssl_f_maps : int, optional
            accepted but ignored by the TCN module; a warning is emitted when it is set

        """
        super().__init__(frac_masked)

        if num_ssl_f_maps is not None:
            warnings.warn(
                f"num_ssl_f_maps is set to {num_ssl_f_maps} but is ignored for TCN"
            )

        self.pars = {
            "input_dim": int(num_f_maps[0]),
            "num_layers": num_ssl_layers,
            "output_dim": _total_feature_dim(dims),
        }

    def construct_module(self) -> Union[nn.Module, None]:
        """Construct a TCN module."""
        return DilatedTCN(**self.pars)


class MaskedKinematicSSL(SSLConstructor, ABC):
    """A base masked joints SSL class.

    Mask some of the joints randomly and predict the initial data.
    """

    type = "ssl_input"

    def __init__(self, frac_masked: float = 0.2) -> None:
        """Initialize the class.

        Parameters
        ----------
        frac_masked : float, default 0.2
            fraction of joints to mask

        """
        super().__init__()
        self.mse = MSE()
        self.frac_masked = frac_masked

    def _get_keys(self, key_bases, x):
        """Get the keys of x that are equal to one of the strings in key_bases."""
        return [key for key in x if key in key_bases]

    def transformation(self, sample_data: Dict) -> Tuple:
        """Mask joints randomly.

        Parameters
        ----------
        sample_data : dict
            feature tensors; 'coords' or 'coord_diff' and at least one single-dimensional
            feature are required; the first multi-dimensional feature is unpacked as
            `(features, frames)`

        Returns
        -------
        sample_data : dict
            the same dictionary with the masked joints set to zero
        ssl_target : torch.Tensor
            the concatenation of the *unmasked* values

        """
        assert (
            "coords" in sample_data.keys() or "coord_diff" in sample_data.keys()
        ), "'coords' or 'coord_diff' features are required when using MaskedKinematicSSL"

        multi_dim_features = self._get_keys(
            ("coords", "coord_diff", "speed_direction"),
            sample_data,
        )
        single_dim_features = self._get_keys(
            (
                "speed_joints",
                "acc_joints",
                "angle_joints_radian",
                "angle_speeds",
                "speed_value",
            ),
            sample_data,
        )

        assert (
            len(multi_dim_features) > 0
        ), "No multi-dimensional features found in sample_data"
        assert (
            len(single_dim_features) > 0
        ), "No single-dimensional features found in sample_data"

        # Take the target first: the masking below modifies the tensors in place.
        ssl_target = torch.cat(list(sample_data.values()))

        features, frames = sample_data[multi_dim_features[0]].shape

        # NOTE(review): assumes two rows per joint in the first multi-dimensional feature - confirm.
        n_bp = features // 2
        # True marks a joint that gets masked; each joint is picked with probability frac_masked.
        masked_joints = torch.rand(n_bp) < self.frac_masked

        distance_keys = self._get_keys(("intra_distance", "inter_distance"), sample_data)
        if distance_keys:
            pair_mask = masked_joints.repeat(n_bp, frames, 1).transpose(1, 2)
            indices = torch.triu_indices(n_bp, n_bp, 1)
            for key in distance_keys:
                # Unfold the upper-triangle values into a square matrix, zero the rows and
                # columns of the masked joints, and fold it back.
                X = torch.zeros((n_bp, n_bp, frames)).to(sample_data[key].device)
                X[indices[0], indices[1], :] = sample_data[key]
                X[pair_mask] = 0
                X[pair_mask.transpose(0, 1)] = 0
                sample_data[key] = X[indices[0], indices[1], :].reshape(-1, frames)

        reference_rows = sample_data[single_dim_features[0]].shape[0]
        for key in multi_dim_features + single_dim_features:
            # Number of consecutive rows that belong to one joint in this feature.
            rows_per_joint = sample_data[key].shape[0] // reference_rows
            mask = (
                masked_joints.repeat(rows_per_joint, frames, 1)
                .transpose(0, 2)
                .reshape((-1, frames))
            )
            sample_data[key][mask] = 0

        return sample_data, ssl_target

    def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
        """MSE loss."""
        return self.mse(predicted, target)

    @abstractmethod
    def construct_module(self) -> Union[nn.Module, None]:
        """Construct the SSL prediction module using the parameters specified at initialization."""


class MaskedKinematicSSL_FC(MaskedKinematicSSL):
    """Masked kinematic SSL class with fully connected module."""

    def __init__(
        self,
        dims: Dict,
        num_f_maps: torch.Size,
        frac_masked: float = 0.2,
        num_ssl_layers: int = 5,
        num_ssl_f_maps: int = 16,
    ) -> None:
        """Initialize the constructor.

        Parameters
        ----------
        dims : dict
            the shapes of the features in model input (first dimensions are summed)
        num_f_maps : torch.Size
            shape of feature extraction output (only the first element is used)
        frac_masked : float, default 0.2
            fraction of joints to mask
        num_ssl_layers : int, default 5
            number of layers in the SSL module
        num_ssl_f_maps : int, default 16
            number of feature maps in the SSL module

        """
        super().__init__(frac_masked)
        self.pars = {
            "dim": _total_feature_dim(dims),
            "num_f_maps": int(num_f_maps[0]),
            "num_ssl_layers": num_ssl_layers,
            "num_ssl_f_maps": num_ssl_f_maps,
        }

    def construct_module(self) -> Union[nn.Module, None]:
        """Construct a fully connected module."""
        return FC(**self.pars)


class MaskedKinematicSSL_TCN(MaskedKinematicSSL):
    """Masked kinematic SSL using a TCN module."""

    def __init__(
        self,
        dims: Dict,
        num_f_maps: torch.Size,
        frac_masked: float = 0.2,
        num_ssl_layers: int = 5,
    ) -> None:
        """Initialise the constructor.

        Parameters
        ----------
        dims : dict
            the shapes of the features in model input (first dimensions are summed)
        num_f_maps : torch.Size
            shape of feature extraction output (only the first element is used)
        frac_masked : float, default 0.2
            fraction of joints to mask
        num_ssl_layers : int, default 5
            number of layers in the SSL module

        """
        super().__init__(frac_masked)
        self.pars = {
            "input_dim": int(num_f_maps[0]),
            "num_layers": num_ssl_layers,
            "output_dim": _total_feature_dim(dims),
        }

    def construct_module(self) -> Union[nn.Module, None]:
        """Construct a TCN module."""
        return DilatedTCN(**self.pars)


class MaskedFramesSSL(SSLConstructor, ABC):
    """Abstract class for masked frame SSL constructors.

    Generates the functions necessary to build a masked frames SSL: mask some of the input frames randomly
    and predict the initial data
    """

    type = "ssl_input"

    def __init__(self, frac_masked: float = 0.1) -> None:
        """Initialize the SSL constructor.

        Parameters
        ----------
        frac_masked : float, default 0.1
            fraction of frames to mask

        """
        super().__init__()
        self.frac_masked = frac_masked
        self.mse = MSE()

    def transformation(self, sample_data: Dict) -> Tuple:
        """Mask some of the frames randomly.

        The same frame mask, built along the last dimension of the first value, is applied to
        every feature.

        Returns
        -------
        sample_data : dict
            the same dictionary with every value replaced by its masked version
        ssl_target : torch.Tensor
            the concatenation of the *unmasked* values

        """
        # Taken before masking so that the module has to reconstruct the original frames.
        ssl_target = torch.cat(list(sample_data.values()))
        first_value = sample_data[list(sample_data.keys())[0]]
        num_frames = first_value.shape[-1]
        # Uniform noise makes the probability of masking a frame equal to frac_masked.
        keep = torch.rand(num_frames, device=first_value.device) >= self.frac_masked
        keep = keep.unsqueeze(0)
        for key in sample_data:
            sample_data[key] = sample_data[key] * keep
        return (sample_data, ssl_target)

    def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
        """MSE loss."""
        return self.mse(predicted, target)

    @abstractmethod
    def construct_module(self) -> Union[nn.Module, None]:
        """Construct the SSL prediction module using the parameters specified at initialization."""


class MaskedFramesSSL_FC(MaskedFramesSSL):
    """Masked frames SSL with a fully connected module."""

    def __init__(
        self,
        dims: Dict,
        num_f_maps: torch.Size,
        frac_masked: float = 0.1,
        num_ssl_layers: int = 3,
        num_ssl_f_maps: int = 16,
    ) -> None:
        """Initialize the constructor.

        Parameters
        ----------
        dims : dict
            the shapes of the features in model input (first dimensions are summed)
        num_f_maps : torch.Size
            shape of feature extraction output (only the first element is used)
        frac_masked : float, default 0.1
            fraction of frames to mask
        num_ssl_layers : int, default 3
            number of layers in the SSL module
        num_ssl_f_maps : int, default 16
            number of feature maps in the SSL module

        """
        super().__init__(frac_masked)
        self.pars = {
            "dim": _total_feature_dim(dims),
            "num_f_maps": int(num_f_maps[0]),
            "num_ssl_layers": num_ssl_layers,
            "num_ssl_f_maps": num_ssl_f_maps,
        }

    def construct_module(self) -> Union[nn.Module, None]:
        """Construct a fully connected module."""
        return FC(**self.pars)


class MaskedFramesSSL_TCN(MaskedFramesSSL):
    """Masked frames SSL with a TCN module."""

    def __init__(
        self,
        dims: Dict,
        num_f_maps: torch.Size,
        frac_masked: float = 0.2,
        num_ssl_layers: int = 5,
    ) -> None:
        """Initialize the SSL constructor.

        Parameters
        ----------
        dims : dict
            the shapes of the features in model input (first dimensions are summed)
        num_f_maps : torch.Size
            shape of feature extraction output (only the first element is used)
        frac_masked : float, default 0.2
            fraction of frames to mask
        num_ssl_layers : int, default 5
            number of layers in the SSL module

        """
        super().__init__(frac_masked)
        self.pars = {
            "input_dim": int(num_f_maps[0]),
            "num_layers": num_ssl_layers,
            "output_dim": _total_feature_dim(dims),
        }

    def construct_module(self) -> Union[nn.Module, None]:
        """Construct a TCN module."""
        return DilatedTCN(**self.pars)
class MaskedFeaturesSSL(SSLConstructor, ABC):
    """A base masked features SSL class.

    Mask some of the input features randomly and predict the initial data.
    """

    type = "ssl_input"

    def __init__(self, frac_masked: float = 0.2) -> None:
        """Initialize the constructor.

        Parameters
        ----------
        frac_masked : float, default 0.2
            fraction of feature values to mask (set to zero)

        """
        super().__init__()
        self.mse = MSE()
        self.frac_masked = frac_masked

    def transformation(self, sample_data: Dict) -> Tuple:
        """Mask some of the features randomly.

        Parameters
        ----------
        sample_data : dict
            feature tensors; the values are concatenated with `torch.cat`, so they have to
            agree in every dimension but the first

        Returns
        -------
        sample_data : dict
            the same dictionary with every value replaced by its masked version
        ssl_target : torch.Tensor
            the concatenation of the *unmasked* values

        """
        # The target has to be taken before masking, otherwise the module would only be
        # asked to reproduce the already masked input.
        ssl_target = torch.cat(list(sample_data.values()))
        for key in sample_data:
            value = sample_data[key]
            # Uniform noise (instead of normal noise) makes the probability of masking an
            # element equal to frac_masked; the mask is created on the data's device.
            keep = torch.rand(value.shape, device=value.device) >= self.frac_masked
            sample_data[key] = value * keep
        return (sample_data, ssl_target)

    def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
        """MSE loss."""
        return self.mse(predicted, target)

    @abstractmethod
    def construct_module(self) -> nn.Module:
        """Construct the SSL prediction module using the parameters specified at initialization."""
A base masked features SSL class.
Mask some of the input features randomly and predict the initial data.
def __init__(self, frac_masked: float = 0.2) -> None:
    """Set up the masking fraction and the loss object.

    Parameters
    ----------
    frac_masked : float, default 0.2
        fraction of features to mask; stored as `self.frac_masked`

    """
    super().__init__()
    # The reconstruction loss object is created once and kept on the instance.
    loss_function = MSE()
    self.mse = loss_function
    self.frac_masked = frac_masked
Initialize the constructor.
Parameters
frac_masked : float, default 0.2 fraction of features to mask
The type parameter defines interaction with the model:
'ssl_input': a modification of the input data passes through the base network feature extraction module and the SSL module; it is returned as SSL output and compared to SSL target (or, if it is None, to the input data),'ssl_output': the input data passes through the base network feature extraction module and the SSL module; it is returned as SSL output and compared to SSL target (or, if it is None, to the input data),'contrastive': the input data and its modification pass through the base network feature extraction module and the SSL module; an (input results, modification results) tuple is returned as SSL output,'contrastive_2layers': the input data and its modification pass through the base network feature extraction module; the output of the second feature extraction layer for the modified data goes through an SSL module and then, optionally, that result and the first-level unmodified features pass another transformation; an (input results, modified results) tuple is returned as SSL output,
def transformation(self, sample_data: Dict) -> Tuple:
    """Mask some of the features randomly.

    Parameters
    ----------
    sample_data : dict
        feature tensors; the values are concatenated with `torch.cat`, so they have to
        agree in every dimension but the first

    Returns
    -------
    sample_data : dict
        the same dictionary with every value replaced by its masked version
    ssl_target : torch.Tensor
        the concatenation of the *unmasked* values

    """
    # The target has to be taken before masking, otherwise the module would only be
    # asked to reproduce the already masked input.
    ssl_target = torch.cat(list(sample_data.values()))
    for key in sample_data:
        value = sample_data[key]
        # Uniform noise (instead of normal noise) makes the probability of masking an
        # element equal to frac_masked; the mask is created on the data's device.
        keep = torch.rand(value.shape, device=value.device) >= self.frac_masked
        sample_data[key] = value * keep
    return (sample_data, ssl_target)
Mask some of the features randomly.
def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
    """Evaluate `self.mse` on the prediction and the target and return the result."""
    return self.mse(predicted, target)
MSE loss.
@abstractmethod
def construct_module(self) -> nn.Module:
    """Build the SSL prediction module from the parameters given at initialization.

    Abstract: the body is empty here and has to be provided by subclasses.
    """
Construct the SSL prediction module using the parameters specified at initialization.
class MaskedFeaturesSSL_FC(MaskedFeaturesSSL):
    """Masked features SSL constructor that uses a fully connected prediction module.

    Mask some of the input features randomly and predict the initial data.
    """

    type = "ssl_input"

    def __init__(
        self,
        dims: torch.Size,
        num_f_maps: torch.Size,
        frac_masked: float = 0.2,
        num_ssl_layers: int = 5,
        num_ssl_f_maps: int = 16,
    ) -> None:
        """Store the keyword arguments that are later passed to `FC`.

        Parameters
        ----------
        dims : torch.Size
            the shape of features in model input; `dims.values()` is iterated and the first
            element of every item is summed
            (NOTE(review): used like a dictionary of shapes despite the annotation - confirm)
        num_f_maps : torch.Size
            shape of feature extraction output (only the first element is used)
        frac_masked : float, default 0.2
            fraction of features to mask
        num_ssl_layers : int, default 5
            number of layers in the SSL module
        num_ssl_f_maps : int, default 16
            number of feature maps in the SSL module

        """
        super().__init__(frac_masked)
        total_dim = int(sum(shape[0] for shape in dims.values()))
        extractor_maps = int(num_f_maps[0])
        self.pars = dict(
            dim=total_dim,
            num_f_maps=extractor_maps,
            num_ssl_layers=num_ssl_layers,
            num_ssl_f_maps=num_ssl_f_maps,
        )

    def construct_module(self) -> Union[nn.Module, None]:
        """Return an `FC` module created from the stored parameters."""
        return FC(**self.pars)
A fully connected masked features SSL class.
Mask some of the input features randomly and predict the initial data.
def __init__(
    self,
    dims: torch.Size,
    num_f_maps: torch.Size,
    frac_masked: float = 0.2,
    num_ssl_layers: int = 5,
    num_ssl_f_maps: int = 16,
) -> None:
    """Store the keyword arguments for the fully connected SSL module in `self.pars`.

    Parameters
    ----------
    dims : torch.Size
        the shape of features in model input; `dims.values()` is iterated and the first
        element of every item is summed
        (NOTE(review): used like a dictionary of shapes despite the annotation - confirm)
    num_f_maps : torch.Size
        shape of feature extraction output (only the first element is used)
    frac_masked : float, default 0.2
        fraction of features to mask
    num_ssl_layers : int, default 5
        number of layers in the SSL module
    num_ssl_f_maps : int, default 16
        number of feature maps in the SSL module

    """
    super().__init__(frac_masked)
    total_dim = int(sum(shape[0] for shape in dims.values()))
    extractor_maps = int(num_f_maps[0])
    self.pars = dict(
        dim=total_dim,
        num_f_maps=extractor_maps,
        num_ssl_layers=num_ssl_layers,
        num_ssl_f_maps=num_ssl_f_maps,
    )
Initialize the constructor.
Parameters
dims : torch.Size the shape of features in model input num_f_maps : torch.Size shape of feature extraction output frac_masked : float, default 0.2 fraction of features to mask num_ssl_layers : int, default 5 number of layers in the SSL module num_ssl_f_maps : int, default 16 number of feature maps in the SSL module
The type parameter defines interaction with the model:
'ssl_input': a modification of the input data passes through the base network feature extraction module and the SSL module; it is returned as SSL output and compared to SSL target (or, if it is None, to the input data),'ssl_output': the input data passes through the base network feature extraction module and the SSL module; it is returned as SSL output and compared to SSL target (or, if it is None, to the input data),'contrastive': the input data and its modification pass through the base network feature extraction module and the SSL module; an (input results, modification results) tuple is returned as SSL output,'contrastive_2layers': the input data and its modification pass through the base network feature extraction module; the output of the second feature extraction layer for the modified data goes through an SSL module and then, optionally, that result and the first-level unmodified features pass another transformation; an (input results, modified results) tuple is returned as SSL output,
def construct_module(self) -> Union[nn.Module, None]:
    """Return an `FC` module created from the parameters stored in `self.pars`."""
    return FC(**self.pars)
Construct a fully connected module.
Inherited Members
class MaskedFeaturesSSL_TCN(MaskedFeaturesSSL):
    """A TCN masked features SSL class.

    Mask some of the input features randomly and predict the initial data.
    """

    def __init__(
        self,
        dims: Dict,
        num_f_maps: torch.Size,
        frac_masked: float = 0.2,
        num_ssl_layers: int = 5,
        num_ssl_f_maps: Union[int, None] = None,
    ) -> None:
        """Initialize the class.

        Parameters
        ----------
        dims : dict
            the shapes of the features in model input (first dimensions are summed)
        num_f_maps : torch.Size
            shape of feature extraction output (only the first element is used)
        frac_masked : float, default 0.2
            fraction of features to mask
        num_ssl_layers : int, default 5
            number of layers in the SSL module
        num_ssl_f_maps : int, optional
            accepted but ignored by the TCN module; a warning is emitted when it is set

        """
        # Imported locally so that the class does not depend on a new module-level import.
        import warnings

        super().__init__(frac_masked)

        if num_ssl_f_maps is not None:
            # A warning (rather than a print) can be filtered or escalated by the caller.
            warnings.warn(
                f"num_ssl_f_maps is set to {num_ssl_f_maps} but is ignored for TCN"
            )

        dim = int(sum(s[0] for s in dims.values()))
        num_f_maps = int(num_f_maps[0])
        self.pars = {
            "input_dim": num_f_maps,
            "num_layers": num_ssl_layers,
            "output_dim": dim,
        }

    def construct_module(self) -> Union[nn.Module, None]:
        """Construct a TCN module."""
        return DilatedTCN(**self.pars)
A TCN masked features SSL class.
Mask some of the input features randomly and predict the initial data.
def __init__(
    self,
    dims: Dict,
    num_f_maps: torch.Size,
    frac_masked: float = 0.2,
    num_ssl_layers: int = 5,
    num_ssl_f_maps: Union[int, None] = None,
) -> None:
    """Initialize the class.

    Parameters
    ----------
    dims : dict
        the shapes of the features in model input (first dimensions are summed)
    num_f_maps : torch.Size
        shape of feature extraction output (only the first element is used)
    frac_masked : float, default 0.2
        fraction of features to mask
    num_ssl_layers : int, default 5
        number of layers in the SSL module
    num_ssl_f_maps : int, optional
        accepted but ignored by the TCN module; a warning is emitted when it is set

    """
    # Imported locally so that the method does not depend on a new module-level import.
    import warnings

    super().__init__(frac_masked)

    if num_ssl_f_maps is not None:
        # A warning (rather than a print) can be filtered or escalated by the caller.
        warnings.warn(
            f"num_ssl_f_maps is set to {num_ssl_f_maps} but is ignored for TCN"
        )

    dim = int(sum(s[0] for s in dims.values()))
    num_f_maps = int(num_f_maps[0])
    self.pars = {
        "input_dim": num_f_maps,
        "num_layers": num_ssl_layers,
        "output_dim": dim,
    }
Initialize the class.
Parameters
dims : Dict the shape of features in model input num_f_maps : torch.Size shape of feature extraction output frac_masked : float, default 0.2 fraction of features to mask num_ssl_layers : int, default 5 number of layers in the SSL module num_ssl_f_maps : int, optional ignored for TCN (a message is emitted when it is set)
def construct_module(self) -> Union[nn.Module, None]:
    """Return a `DilatedTCN` module created from the parameters stored in `self.pars`."""
    return DilatedTCN(**self.pars)
Construct a TCN module.
Inherited Members
class MaskedKinematicSSL(SSLConstructor, ABC):
    """A base masked joints SSL class.

    Mask some of the joints randomly and predict the initial data.
    """

    type = "ssl_input"

    def __init__(self, frac_masked: float = 0.2) -> None:
        """Initialize the class.

        Parameters
        ----------
        frac_masked : float, default 0.2
            fraction of joints to mask

        """
        super().__init__()
        self.mse = MSE()
        self.frac_masked = frac_masked

    def _get_keys(self, key_bases, x):
        """Get the keys of x that are equal to one of the strings in key_bases."""
        return [key for key in x if key in key_bases]

    def transformation(self, sample_data: Dict) -> Tuple:
        """Mask joints randomly.

        Parameters
        ----------
        sample_data : dict
            feature tensors; 'coords' or 'coord_diff' and at least one single-dimensional
            feature are required; the first multi-dimensional feature is unpacked as
            `(features, frames)`

        Returns
        -------
        sample_data : dict
            the same dictionary with the masked joints set to zero
        ssl_target : torch.Tensor
            the concatenation of the *unmasked* values

        """
        assert (
            "coords" in sample_data.keys() or "coord_diff" in sample_data.keys()
        ), "'coords' or 'coord_diff' features are required when using MaskedKinematicSSL"

        multi_dim_features = self._get_keys(
            ("coords", "coord_diff", "speed_direction"),
            sample_data,
        )
        single_dim_features = self._get_keys(
            (
                "speed_joints",
                "acc_joints",
                "angle_joints_radian",
                "angle_speeds",
                "speed_value",
            ),
            sample_data,
        )

        assert (
            len(multi_dim_features) > 0
        ), "No multi-dimensional features found in sample_data"
        assert (
            len(single_dim_features) > 0
        ), "No single-dimensional features found in sample_data"

        # Take the target first: the masking below modifies the tensors in place.
        ssl_target = torch.cat(list(sample_data.values()))

        features, frames = sample_data[multi_dim_features[0]].shape

        # NOTE(review): assumes two rows per joint in the first multi-dimensional feature - confirm.
        n_bp = features // 2
        # True marks a joint that gets masked; each joint is picked with probability frac_masked.
        masked_joints = torch.rand(n_bp) < self.frac_masked

        distance_keys = self._get_keys(("intra_distance", "inter_distance"), sample_data)
        if distance_keys:
            pair_mask = masked_joints.repeat(n_bp, frames, 1).transpose(1, 2)
            indices = torch.triu_indices(n_bp, n_bp, 1)
            for key in distance_keys:
                # Unfold the upper-triangle values into a square matrix, zero the rows and
                # columns of the masked joints, and fold it back.
                X = torch.zeros((n_bp, n_bp, frames)).to(sample_data[key].device)
                X[indices[0], indices[1], :] = sample_data[key]
                X[pair_mask] = 0
                X[pair_mask.transpose(0, 1)] = 0
                sample_data[key] = X[indices[0], indices[1], :].reshape(-1, frames)

        reference_rows = sample_data[single_dim_features[0]].shape[0]
        for key in multi_dim_features + single_dim_features:
            # Number of consecutive rows that belong to one joint in this feature.
            rows_per_joint = sample_data[key].shape[0] // reference_rows
            mask = (
                masked_joints.repeat(rows_per_joint, frames, 1)
                .transpose(0, 2)
                .reshape((-1, frames))
            )
            sample_data[key][mask] = 0

        return sample_data, ssl_target

    def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
        """MSE loss."""
        return self.mse(predicted, target)

    @abstractmethod
    def construct_module(self) -> Union[nn.Module, None]:
        """Construct the SSL prediction module using the parameters specified at initialization."""
A base masked joints SSL class.
Mask some of the joints randomly and predict the initial data.
def __init__(self, frac_masked: float = 0.2) -> None:
    """Set up the masking fraction and the loss object.

    Parameters
    ----------
    frac_masked : float, default 0.2
        fraction of joints to mask; stored as `self.frac_masked`

    """
    super().__init__()
    # The reconstruction loss object is created once and kept on the instance.
    loss_function = MSE()
    self.mse = loss_function
    self.frac_masked = frac_masked
Initialize the class.
Parameters
frac_masked : float, default 0.2 fraction of joints to mask
The type parameter defines interaction with the model:
'ssl_input': a modification of the input data passes through the base network feature extraction module and the SSL module; it is returned as SSL output and compared to SSL target (or, if it is None, to the input data),'ssl_output': the input data passes through the base network feature extraction module and the SSL module; it is returned as SSL output and compared to SSL target (or, if it is None, to the input data),'contrastive': the input data and its modification pass through the base network feature extraction module and the SSL module; an (input results, modification results) tuple is returned as SSL output,'contrastive_2layers': the input data and its modification pass through the base network feature extraction module; the output of the second feature extraction layer for the modified data goes through an SSL module and then, optionally, that result and the first-level unmodified features pass another transformation; an (input results, modified results) tuple is returned as SSL output,
def transformation(self, sample_data: Dict) -> Tuple:
    """Mask joints randomly.

    Parameters
    ----------
    sample_data : dict
        feature tensors; 'coords' or 'coord_diff' and at least one single-dimensional
        feature are required; the first multi-dimensional feature is unpacked as
        `(features, frames)`

    Returns
    -------
    sample_data : dict
        the same dictionary with the masked joints set to zero
    ssl_target : torch.Tensor
        the concatenation of the *unmasked* values

    """
    assert (
        "coords" in sample_data.keys() or "coord_diff" in sample_data.keys()
    ), "'coords' or 'coord_diff' features are required when using MaskedKinematicSSL"

    multi_dim_features = self._get_keys(
        ("coords", "coord_diff", "speed_direction"),
        sample_data,
    )
    single_dim_features = self._get_keys(
        (
            "speed_joints",
            "acc_joints",
            "angle_joints_radian",
            "angle_speeds",
            "speed_value",
        ),
        sample_data,
    )

    assert (
        len(multi_dim_features) > 0
    ), "No multi-dimensional features found in sample_data"
    assert (
        len(single_dim_features) > 0
    ), "No single-dimensional features found in sample_data"

    # Take the target first: the masking below modifies the tensors in place.
    ssl_target = torch.cat(list(sample_data.values()))

    features, frames = sample_data[multi_dim_features[0]].shape

    # NOTE(review): assumes two rows per joint in the first multi-dimensional feature - confirm.
    n_bp = features // 2
    # True marks a joint that gets masked; each joint is picked with probability frac_masked.
    masked_joints = torch.rand(n_bp) < self.frac_masked

    distance_keys = self._get_keys(("intra_distance", "inter_distance"), sample_data)
    if distance_keys:
        pair_mask = masked_joints.repeat(n_bp, frames, 1).transpose(1, 2)
        indices = torch.triu_indices(n_bp, n_bp, 1)
        for key in distance_keys:
            # Unfold the upper-triangle values into a square matrix, zero the rows and
            # columns of the masked joints, and fold it back.
            X = torch.zeros((n_bp, n_bp, frames)).to(sample_data[key].device)
            X[indices[0], indices[1], :] = sample_data[key]
            X[pair_mask] = 0
            X[pair_mask.transpose(0, 1)] = 0
            sample_data[key] = X[indices[0], indices[1], :].reshape(-1, frames)

    reference_rows = sample_data[single_dim_features[0]].shape[0]
    for key in multi_dim_features + single_dim_features:
        # Number of consecutive rows that belong to one joint in this feature.
        rows_per_joint = sample_data[key].shape[0] // reference_rows
        mask = (
            masked_joints.repeat(rows_per_joint, frames, 1)
            .transpose(0, 2)
            .reshape((-1, frames))
        )
        sample_data[key][mask] = 0

    return sample_data, ssl_target
Mask joints randomly.
def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
    """Evaluate `self.mse` on the prediction and the target and return the result."""
    return self.mse(predicted, target)
MSE loss.
@abstractmethod
def construct_module(self) -> Union[nn.Module, None]:
    """Build the SSL prediction module from the parameters given at initialization.

    Abstract: the body is empty here and has to be provided by subclasses.
    """
Construct the SSL prediction module using the parameters specified at initialization.
class MaskedKinematicSSL_FC(MaskedKinematicSSL):
    """Masked kinematic SSL constructor that uses a fully connected prediction module."""

    def __init__(
        self,
        dims: torch.Size,
        num_f_maps: torch.Size,
        frac_masked: float = 0.2,
        num_ssl_layers: int = 5,
        num_ssl_f_maps: int = 16,
    ) -> None:
        """Store the keyword arguments that are later passed to `FC`.

        Parameters
        ----------
        dims : torch.Size
            the shape of features in model input; `dims.values()` is iterated and the first
            element of every item is summed
            (NOTE(review): used like a dictionary of shapes despite the annotation - confirm)
        num_f_maps : torch.Size
            shape of feature extraction output (only the first element is used)
        frac_masked : float, default 0.2
            fraction of joints to mask
        num_ssl_layers : int, default 5
            number of layers in the SSL module
        num_ssl_f_maps : int, default 16
            number of feature maps in the SSL module

        """
        super().__init__(frac_masked)
        total_dim = int(sum(shape[0] for shape in dims.values()))
        extractor_maps = int(num_f_maps[0])
        self.pars = dict(
            dim=total_dim,
            num_f_maps=extractor_maps,
            num_ssl_layers=num_ssl_layers,
            num_ssl_f_maps=num_ssl_f_maps,
        )

    def construct_module(self) -> Union[nn.Module, None]:
        """Return an `FC` module created from the stored parameters."""
        return FC(**self.pars)
Masked kinematic SSL class with fully connected module.
def __init__(
    self,
    dims: torch.Size,
    num_f_maps: torch.Size,
    frac_masked: float = 0.2,
    num_ssl_layers: int = 5,
    num_ssl_f_maps: int = 16,
) -> None:
    """Store the keyword arguments for the fully connected SSL module in `self.pars`.

    Parameters
    ----------
    dims : torch.Size
        the shape of features in model input; `dims.values()` is iterated and the first
        element of every item is summed
        (NOTE(review): used like a dictionary of shapes despite the annotation - confirm)
    num_f_maps : torch.Size
        shape of feature extraction output (only the first element is used)
    frac_masked : float, default 0.2
        fraction of joints to mask
    num_ssl_layers : int, default 5
        number of layers in the SSL module
    num_ssl_f_maps : int, default 16
        number of feature maps in the SSL module

    """
    super().__init__(frac_masked)
    total_dim = int(sum(shape[0] for shape in dims.values()))
    extractor_maps = int(num_f_maps[0])
    self.pars = dict(
        dim=total_dim,
        num_f_maps=extractor_maps,
        num_ssl_layers=num_ssl_layers,
        num_ssl_f_maps=num_ssl_f_maps,
    )
Initialize the constructor.
Parameters
dims : torch.Size the number of features in model input num_f_maps : torch.Size shape of feature extraction output frac_masked : float, default 0.2 fraction of joints to mask num_ssl_layers : int, default 5 number of layers in the SSL module num_ssl_f_maps : int, default 16 number of feature maps in the SSL module
def construct_module(self) -> Union[nn.Module, None]:
    """Return an `FC` module created from the parameters stored in `self.pars`."""
    return FC(**self.pars)
Construct a fully connected module.
Inherited Members
class MaskedKinematicSSL_TCN(MaskedKinematicSSL):
    """Masked kinematic SSL constructor that uses a TCN prediction module."""

    def __init__(
        self,
        dims: torch.Size,
        num_f_maps: torch.Size,
        frac_masked: float = 0.2,
        num_ssl_layers: int = 5,
    ) -> None:
        """Store the keyword arguments that are later passed to `DilatedTCN`.

        Parameters
        ----------
        dims : torch.Size
            the shape of features in model input; `dims.values()` is iterated and the first
            element of every item is summed
            (NOTE(review): used like a dictionary of shapes despite the annotation - confirm)
        num_f_maps : torch.Size
            shape of feature extraction output (only the first element is used)
        frac_masked : float, default 0.2
            fraction of joints to mask
        num_ssl_layers : int, default 5
            number of layers in the SSL module

        """
        super().__init__(frac_masked)
        total_dim = int(sum(shape[0] for shape in dims.values()))
        extractor_maps = int(num_f_maps[0])
        self.pars = dict(
            input_dim=extractor_maps,
            num_layers=num_ssl_layers,
            output_dim=total_dim,
        )

    def construct_module(self) -> Union[nn.Module, None]:
        """Return a `DilatedTCN` module created from the stored parameters."""
        return DilatedTCN(**self.pars)
Masked kinematic SSL using a TCN module.
def __init__(
    self,
    dims: Dict,
    num_f_maps: torch.Size,
    frac_masked: float = 0.2,
    num_ssl_layers: int = 5,
) -> None:
    """Initialise the constructor.

    Parameters
    ----------
    dims : dict
        a dictionary whose values are the shapes of the features in model input
        (the first elements of the shapes are summed to get the output
        dimension of the SSL module)
    num_f_maps : torch.Size
        shape of feature extraction output (only the first element is used)
    frac_masked : float, default 0.2
        fraction of joints to mask
    num_ssl_layers : int, default 5
        number of layers in the SSL module

    """
    super().__init__(frac_masked)
    # `dims` is iterated with `.values()`, so it has to be a mapping of shapes
    dim = int(sum(shape[0] for shape in dims.values()))
    num_f_maps = int(num_f_maps[0])
    # keyword arguments that `construct_module` passes to the TCN module
    self.pars = {
        "input_dim": num_f_maps,
        "num_layers": num_ssl_layers,
        "output_dim": dim,
    }
Initialise the constructor.
Parameters
dims : torch.Size the shape of features in model input num_f_maps : torch.Size shape of feature extraction output frac_masked : float, default 0.2 fraction of joints to mask num_ssl_layers : int, default 5 number of layers in the SSL module
def construct_module(self) -> Union[nn.Module, None]:
    """Build the TCN SSL module from the stored parameters.

    Returns
    -------
    module : nn.Module
        a `DilatedTCN` instance created with the keyword arguments in `self.pars`

    """
    return DilatedTCN(**self.pars)
Construct a TCN module.
Inherited Members
class MaskedFramesSSL(SSLConstructor, ABC):
    """Abstract class for masked frame SSL constructors.

    Generates the functions necessary to build a masked frames SSL: mask some of the input frames randomly
    and predict the initial data
    """

    type = "ssl_input"

    def __init__(self, frac_masked: float = 0.1) -> None:
        """Initialize the SSL constructor.

        Parameters
        ----------
        frac_masked : float, default 0.1
            fraction of frames to mask

        """
        super().__init__()
        self.frac_masked = frac_masked
        self.mse = MSE()

    def transformation(self, sample_data: Dict) -> Tuple:
        """Mask some of the frames randomly.

        The same frame mask is applied to every entry of `sample_data`; each frame
        (position on the last axis) is zeroed out with probability `self.frac_masked`.

        Parameters
        ----------
        sample_data : dict
            a dictionary of feature tensors (assumed to share the size of the last axis);
            the entries are replaced with their masked versions in place

        Returns
        -------
        sample_data : dict
            the input dictionary with masked tensors
        ssl_target : torch.Tensor
            the unmasked tensors concatenated along the first axis

        """
        first_key = list(sample_data.keys())[0]
        num_frames = sample_data[first_key].shape[-1]
        # the target has to be built before masking, otherwise the module would be
        # trained to reproduce the masked input instead of the initial data
        ssl_target = torch.cat(list(sample_data.values()))
        # uniform samples in [0, 1) exceed frac_masked with probability 1 - frac_masked,
        # so frac_masked is the expected fraction of masked frames
        keep = (torch.rand(num_frames) > self.frac_masked).unsqueeze(0)
        for key in sample_data:
            sample_data[key] = sample_data[key] * keep
        return (sample_data, ssl_target)

    def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
        """MSE loss.

        Returns the value produced by `self.mse` for the predicted and target tensors.
        """
        return self.mse(predicted, target)

    @abstractmethod
    def construct_module(self) -> Union[nn.Module, None]:
        """Construct the SSL prediction module using the parameters specified at initialization."""
Abstract class for masked frame SSL constructors.
Generates the functions necessary to build a masked frames SSL: mask some of the input frames randomly and predict the initial data.
def __init__(self, frac_masked: float = 0.1) -> None:
    """Set up the masked frames SSL constructor.

    Parameters
    ----------
    frac_masked : float, default 0.1
        fraction of frames to mask

    """
    super().__init__()
    # loss module used by `loss`
    self.mse = MSE()
    self.frac_masked = frac_masked
Initialize the SSL constructor.
Parameters
frac_masked : float, default 0.1 fraction of frames to mask
The type parameter defines interaction with the model:
- 'ssl_input': a modification of the input data passes through the base network feature extraction module and the SSL module; it is returned as SSL output and compared to SSL target (or, if it is None, to the input data);
- 'ssl_output': the input data passes through the base network feature extraction module and the SSL module; it is returned as SSL output and compared to SSL target (or, if it is None, to the input data);
- 'contrastive': the input data and its modification pass through the base network feature extraction module and the SSL module; an (input results, modification results) tuple is returned as SSL output;
- 'contrastive_2layers': the input data and its modification pass through the base network feature extraction module; the output of the second feature extraction layer for the modified data goes through an SSL module and then, optionally, that result and the first-level unmodified features pass another transformation; an (input results, modified results) tuple is returned as SSL output.
def transformation(self, sample_data: Dict) -> Tuple:
    """Mask some of the frames randomly.

    The same frame mask is applied to every entry of `sample_data`; each frame
    (position on the last axis) is zeroed out with probability `self.frac_masked`.

    Parameters
    ----------
    sample_data : dict
        a dictionary of feature tensors (assumed to share the size of the last axis);
        the entries are replaced with their masked versions in place

    Returns
    -------
    sample_data : dict
        the input dictionary with masked tensors
    ssl_target : torch.Tensor
        the unmasked tensors concatenated along the first axis

    """
    first_key = list(sample_data.keys())[0]
    num_frames = sample_data[first_key].shape[-1]
    # the target has to be built before masking, otherwise the module would be
    # trained to reproduce the masked input instead of the initial data
    ssl_target = torch.cat(list(sample_data.values()))
    # uniform samples in [0, 1) exceed frac_masked with probability 1 - frac_masked,
    # so frac_masked is the expected fraction of masked frames
    keep = (torch.rand(num_frames) > self.frac_masked).unsqueeze(0)
    for key in sample_data:
        sample_data[key] = sample_data[key] * keep
    return (sample_data, ssl_target)
Mask some of the frames randomly.
def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
    """Compute the MSE loss between the SSL prediction and the SSL target.

    Returns the value produced by `self.mse` for the two tensors.
    """
    return self.mse(predicted, target)
MSE loss.
@abstractmethod
def construct_module(self) -> Union[nn.Module, None]:
    """Build the SSL prediction module from the parameters given at initialization.

    Abstract: the body holds no logic, concrete subclasses provide the implementation.
    """
Construct the SSL prediction module using the parameters specified at initialization.
class MaskedFramesSSL_FC(MaskedFramesSSL):
    """Masked frames SSL with a fully connected module."""

    def __init__(
        self,
        dims: Dict,
        num_f_maps: torch.Size,
        frac_masked: float = 0.1,
        num_ssl_layers: int = 3,
        num_ssl_f_maps: int = 16,
    ) -> None:
        """Initialize the constructor.

        Parameters
        ----------
        dims : dict
            a dictionary whose values are the shapes of the features in model input
            (the first elements of the shapes are summed to get the `dim` parameter
            of the SSL module)
        num_f_maps : torch.Size
            shape of feature extraction output (only the first element is used)
        frac_masked : float, default 0.1
            fraction of frames to mask
        num_ssl_layers : int, default 3
            number of layers in the SSL module
        num_ssl_f_maps : int, default 16
            number of feature maps in the SSL module

        """
        super().__init__(frac_masked)
        # `dims` is iterated with `.values()`, so it has to be a mapping of shapes
        dim = int(sum(shape[0] for shape in dims.values()))
        num_f_maps = int(num_f_maps[0])
        # keyword arguments that `construct_module` passes to the FC module
        self.pars = {
            "dim": dim,
            "num_f_maps": num_f_maps,
            "num_ssl_layers": num_ssl_layers,
            "num_ssl_f_maps": num_ssl_f_maps,
        }

    def construct_module(self) -> Union[nn.Module, None]:
        """Construct a fully connected module.

        Returns
        -------
        module : nn.Module
            an `FC` instance created with the keyword arguments in `self.pars`

        """
        return FC(**self.pars)
Masked frames SSL with a fully connected module.
def __init__(
    self,
    dims: Dict,
    num_f_maps: torch.Size,
    frac_masked: float = 0.1,
    num_ssl_layers: int = 3,
    num_ssl_f_maps: int = 16,
) -> None:
    """Initialize the constructor.

    Parameters
    ----------
    dims : dict
        a dictionary whose values are the shapes of the features in model input
        (the first elements of the shapes are summed to get the `dim` parameter
        of the SSL module)
    num_f_maps : torch.Size
        shape of feature extraction output (only the first element is used)
    frac_masked : float, default 0.1
        fraction of frames to mask
    num_ssl_layers : int, default 3
        number of layers in the SSL module
    num_ssl_f_maps : int, default 16
        number of feature maps in the SSL module

    """
    super().__init__(frac_masked)
    # `dims` is iterated with `.values()`, so it has to be a mapping of shapes
    dim = int(sum(shape[0] for shape in dims.values()))
    num_f_maps = int(num_f_maps[0])
    # keyword arguments that `construct_module` passes to the FC module
    self.pars = {
        "dim": dim,
        "num_f_maps": num_f_maps,
        "num_ssl_layers": num_ssl_layers,
        "num_ssl_f_maps": num_ssl_f_maps,
    }
Initialize the constructor.
Parameters
dims : torch.Size the shape of features in model input num_f_maps : torch.Size shape of feature extraction output frac_masked : float, default 0.1 fraction of frames to mask num_ssl_layers : int, default 3 number of layers in the SSL module num_ssl_f_maps : int, default 16 number of feature maps in the SSL module
def construct_module(self) -> Union[nn.Module, None]:
    """Build the fully connected SSL module from the stored parameters.

    Returns
    -------
    module : nn.Module
        an `FC` instance created with the keyword arguments in `self.pars`

    """
    return FC(**self.pars)
Construct a fully connected module.
Inherited Members
class MaskedFramesSSL_TCN(MaskedFramesSSL):
    """Masked frames SSL with a TCN module."""

    def __init__(
        self,
        dims: Dict,
        num_f_maps: torch.Size,
        frac_masked: float = 0.2,
        num_ssl_layers: int = 5,
    ) -> None:
        """Initialize the SSL constructor.

        Parameters
        ----------
        dims : dict
            a dictionary whose values are the shapes of the features in model input
            (the first elements of the shapes are summed to get the output
            dimension of the SSL module)
        num_f_maps : torch.Size
            shape of feature extraction output (only the first element is used)
        frac_masked : float, default 0.2
            fraction of frames to mask
        num_ssl_layers : int, default 5
            number of layers in the SSL module

        """
        super().__init__(frac_masked)
        # `dims` is iterated with `.values()`, so it has to be a mapping of shapes
        dim = int(sum(shape[0] for shape in dims.values()))
        num_f_maps = int(num_f_maps[0])
        # keyword arguments that `construct_module` passes to the TCN module
        self.pars = {
            "input_dim": num_f_maps,
            "num_layers": num_ssl_layers,
            "output_dim": dim,
        }

    def construct_module(self) -> Union[nn.Module, None]:
        """Construct a TCN module.

        Returns
        -------
        module : nn.Module
            a `DilatedTCN` instance created with the keyword arguments in `self.pars`

        """
        return DilatedTCN(**self.pars)
Masked frames SSL with a TCN module.
def __init__(
    self,
    dims: Dict,
    num_f_maps: torch.Size,
    frac_masked: float = 0.2,
    num_ssl_layers: int = 5,
) -> None:
    """Initialize the SSL constructor.

    Parameters
    ----------
    dims : dict
        a dictionary whose values are the shapes of the features in model input
        (the first elements of the shapes are summed to get the output
        dimension of the SSL module)
    num_f_maps : torch.Size
        shape of feature extraction output (only the first element is used)
    frac_masked : float, default 0.2
        fraction of frames to mask
    num_ssl_layers : int, default 5
        number of layers in the SSL module

    """
    super().__init__(frac_masked)
    # `dims` is iterated with `.values()`, so it has to be a mapping of shapes
    dim = int(sum(shape[0] for shape in dims.values()))
    num_f_maps = int(num_f_maps[0])
    # keyword arguments that `construct_module` passes to the TCN module
    self.pars = {
        "input_dim": num_f_maps,
        "num_layers": num_ssl_layers,
        "output_dim": dim,
    }
Initialize the SSL constructor.
Parameters
dims : torch.Size the number of features in model input num_f_maps : torch.Size shape of feature extraction output frac_masked : float, default 0.2 fraction of frames to mask num_ssl_layers : int, default 5 number of layers in the SSL module
def construct_module(self) -> Union[nn.Module, None]:
    """Build the TCN SSL module from the stored parameters.

    Returns
    -------
    module : nn.Module
        a `DilatedTCN` instance created with the keyword arguments in `self.pars`

    """
    return DilatedTCN(**self.pars)
Construct a TCN module.