dlc2action.ssl.masked
Implementations of dlc2action.ssl.base_ssl.SSLConstructor
that predict masked input features
#
# Copyright 2020-2022 by A. Mathis Group and contributors. All rights reserved.
#
# This project and all its files are licensed under GNU AGPLv3 or later version. A copy is included in dlc2action/LICENSE.AGPL.
#
"""
Implementations of `dlc2action.ssl.base_ssl.SSLConstructor` that predict masked input features
"""

from typing import Dict, Tuple, Union, List

import torch

from dlc2action.ssl.base_ssl import SSLConstructor
from abc import ABC, abstractmethod
from dlc2action.loss.mse import _MSE
from dlc2action.ssl.modules import _FC, _DilatedTCN
import torch
from torch import nn


class MaskedFeaturesSSL(SSLConstructor, ABC):
    """
    A base masked features SSL class

    Mask some of the input features randomly and predict the initial data.
    """

    type = "ssl_input"

    def __init__(self, frac_masked: float = 0.2) -> None:
        """
        Parameters
        ----------
        frac_masked : float, default 0.2
            fraction of features to mask
        """

        super().__init__()
        self.mse = _MSE()
        self.frac_masked = frac_masked

    def transformation(self, sample_data: Dict) -> Tuple:
        """
        Mask some of the features randomly

        Every element is zeroed independently with probability `frac_masked`.

        Parameters
        ----------
        sample_data : dict
            a dictionary of feature tensors (it is not modified)

        Returns
        -------
        ssl_input : dict
            a new dictionary with the masked tensors
        ssl_target : torch.Tensor
            the unmasked tensors concatenated along the first dimension
        """

        # the target is built before masking so that the initial data has to be recovered
        ssl_target = torch.cat(list(sample_data.values()))
        ssl_input = {}
        for key, value in sample_data.items():
            # uniform samples make the expected share of zeroed elements equal to frac_masked
            keep = torch.rand(value.shape, device=value.device) >= self.frac_masked
            ssl_input[key] = value * keep
        return (ssl_input, ssl_target)

    def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
        """
        MSE loss
        """

        return self.mse(predicted, target)

    @abstractmethod
    def construct_module(self) -> nn.Module:
        """
        Construct the SSL prediction module using the parameters specified at initialization
        """


class MaskedFeaturesSSL_FC(MaskedFeaturesSSL):
    """
    A fully connected masked features SSL class

    Mask some of the input features randomly and predict the initial data.
    """

    def __init__(
        self,
        dims: torch.Size,
        num_f_maps: torch.Size,
        frac_masked: float = 0.2,
        num_ssl_layers: int = 5,
        num_ssl_f_maps: int = 16,
    ) -> None:
        """
        Parameters
        ----------
        dims : dict
            the shapes of features in model input (a mapping; only its values are read)
        num_f_maps : torch.Size
            shape of feature extraction output
        frac_masked : float, default 0.2
            fraction of features to mask
        num_ssl_layers : int, default 5
            number of layers in the SSL module
        num_ssl_f_maps : int, default 16
            number of feature maps in the SSL module
        """

        super().__init__(frac_masked)
        # the first dimensions of all feature shapes are summed
        dim = int(sum(shape[0] for shape in dims.values()))
        num_f_maps = int(num_f_maps[0])
        self.pars = {
            "dim": dim,
            "num_f_maps": num_f_maps,
            "num_ssl_layers": num_ssl_layers,
            "num_ssl_f_maps": num_ssl_f_maps,
        }

    def construct_module(self) -> Union[nn.Module, None]:
        """
        A fully connected module
        """

        return _FC(**self.pars)


class MaskedFeaturesSSL_TCN(MaskedFeaturesSSL):
    """
    A TCN masked features SSL class

    Mask some of the input features randomly and predict the initial data.
    """

    def __init__(
        self,
        dims: Dict,
        num_f_maps: torch.Size,
        frac_masked: float = 0.2,
        num_ssl_layers: int = 5,
    ) -> None:
        """
        Parameters
        ----------
        dims : dict
            the shapes of features in model input (a mapping; only its values are read)
        num_f_maps : torch.Size
            shape of feature extraction output
        frac_masked : float, default 0.2
            fraction of features to mask
        num_ssl_layers : int, default 5
            number of layers in the SSL module
        """

        super().__init__(frac_masked)
        # the first dimensions of all feature shapes are summed
        dim = int(sum(shape[0] for shape in dims.values()))
        num_f_maps = int(num_f_maps[0])
        self.pars = {
            "input_dim": num_f_maps,
            "num_layers": num_ssl_layers,
            "output_dim": dim,
        }

    def construct_module(self) -> Union[nn.Module, None]:
        """
        A TCN module
        """

        return _DilatedTCN(**self.pars)


class MaskedKinematicSSL(SSLConstructor, ABC):
    """
    A base masked joints SSL class

    Mask some of the joints randomly and predict the initial data.
    """

    type = "ssl_input"

    def __init__(self, frac_masked: float = 0.2) -> None:
        """
        Parameters
        ----------
        frac_masked : float, default 0.2
            fraction of joints to mask
        """

        super().__init__()
        self.mse = _MSE()
        self.frac_masked = frac_masked

    def _get_keys(self, key_bases, x):
        """
        Get keys of x that start with one of the strings in key_bases
        """

        return [key for key in x if key.startswith(key_bases)]

    def transformation(self, sample_data: Dict) -> Tuple:
        """
        Mask joints randomly

        Every joint is selected with probability `frac_masked` and all rows that belong to
        a selected joint are set to zero. The number of joints is taken as half the number of
        rows of the first key that starts with `"coords"`.

        Parameters
        ----------
        sample_data : dict
            a dictionary of feature tensors (neither the dictionary nor the tensors are modified)

        Returns
        -------
        ssl_input : dict
            a new dictionary where the recognized feature tensors are masked copies
        ssl_target : torch.Tensor
            the unmasked tensors concatenated along the first dimension
        """

        coords_key = self._get_keys("coords", sample_data)[0]
        features, _ = sample_data[coords_key].shape
        n_bp = features // 2
        # the target is built before masking so that the initial data has to be recovered
        ssl_target = torch.cat(list(sample_data.values()))
        # True marks a joint that gets masked
        masked_joints = torch.rand(n_bp) < self.frac_masked
        first, second = torch.triu_indices(n_bp, n_bp, 1)
        row_masks = (
            # a pairwise distance row is masked if either of its joints is masked
            (
                ("intra_distance", "inter_distance"),
                masked_joints[first] | masked_joints[second],
            ),
            # two consecutive rows per joint
            (
                ("speed_joints", "coords", "acc_joints"),
                masked_joints.repeat_interleave(2),
            ),
            # one row per joint
            ("angle_joints_radian", masked_joints),
        )
        ssl_input = dict(sample_data)
        for key_bases, row_mask in row_masks:
            for key in self._get_keys(key_bases, sample_data):
                masked = sample_data[key].clone()
                masked[row_mask.to(masked.device)] = 0
                ssl_input[key] = masked
        return ssl_input, ssl_target

    def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
        """
        MSE loss
        """

        return self.mse(predicted, target)

    @abstractmethod
    def construct_module(self) -> Union[nn.Module, None]:
        """
        Construct the SSL prediction module using the parameters specified at initialization
        """


class MaskedKinematicSSL_FC(MaskedKinematicSSL):
    """
    A fully connected masked joints SSL class

    Mask some of the joints randomly and predict the initial data.
    """

    def __init__(
        self,
        dims: torch.Size,
        num_f_maps: torch.Size,
        frac_masked: float = 0.2,
        num_ssl_layers: int = 5,
        num_ssl_f_maps: int = 16,
    ) -> None:
        """
        Parameters
        ----------
        dims : dict
            the shapes of features in model input (a mapping; only its values are read)
        num_f_maps : torch.Size
            shape of feature extraction output
        frac_masked : float, default 0.2
            fraction of joints to mask
        num_ssl_layers : int, default 5
            number of layers in the SSL module
        num_ssl_f_maps : int, default 16
            number of feature maps in the SSL module
        """

        super().__init__(frac_masked)
        # the first dimensions of all feature shapes are summed
        dim = int(sum(shape[0] for shape in dims.values()))
        num_f_maps = int(num_f_maps[0])
        self.pars = {
            "dim": dim,
            "num_f_maps": num_f_maps,
            "num_ssl_layers": num_ssl_layers,
            "num_ssl_f_maps": num_ssl_f_maps,
        }

    def construct_module(self) -> Union[nn.Module, None]:
        """
        A fully connected module
        """

        return _FC(**self.pars)


class MaskedKinematicSSL_TCN(MaskedKinematicSSL):
    """
    A TCN masked joints SSL class

    Mask some of the joints randomly and predict the initial data.
    """

    def __init__(
        self,
        dims: torch.Size,
        num_f_maps: torch.Size,
        frac_masked: float = 0.2,
        num_ssl_layers: int = 5,
    ) -> None:
        """
        Parameters
        ----------
        dims : dict
            the shapes of features in model input (a mapping; only its values are read)
        num_f_maps : torch.Size
            shape of feature extraction output
        frac_masked : float, default 0.2
            fraction of joints to mask
        num_ssl_layers : int, default 5
            number of layers in the SSL module
        """

        super().__init__(frac_masked)
        # the first dimensions of all feature shapes are summed
        dim = int(sum(shape[0] for shape in dims.values()))
        num_f_maps = int(num_f_maps[0])
        self.pars = {
            "input_dim": num_f_maps,
            "num_layers": num_ssl_layers,
            "output_dim": dim,
        }

    def construct_module(self) -> Union[nn.Module, None]:
        """
        A TCN module
        """

        return _DilatedTCN(**self.pars)


class MaskedFramesSSL(SSLConstructor, ABC):
    """
    A base masked frames SSL class

    Mask some of the input frames randomly and predict the initial data.
    """

    type = "ssl_input"

    def __init__(self, frac_masked: float = 0.1) -> None:
        """
        Parameters
        ----------
        frac_masked : float, default 0.1
            fraction of frames to mask
        """

        super().__init__()
        self.frac_masked = frac_masked
        self.mse = _MSE()

    def transformation(self, sample_data: Dict) -> Tuple:
        """
        Mask some of the frames randomly

        Every position along the last dimension is zeroed with probability `frac_masked`;
        the same positions are zeroed in every tensor.

        Parameters
        ----------
        sample_data : dict
            a dictionary of feature tensors (it is not modified)

        Returns
        -------
        ssl_input : dict
            a new dictionary with the masked tensors
        ssl_target : torch.Tensor
            the unmasked tensors concatenated along the first dimension
        """

        values = list(sample_data.values())
        num_frames = values[0].shape[-1]
        # the target is built before masking so that the initial data has to be recovered
        ssl_target = torch.cat(values)
        # uniform samples make the expected share of zeroed frames equal to frac_masked
        keep = (torch.rand(num_frames) >= self.frac_masked).unsqueeze(0)
        ssl_input = {
            key: value * keep.to(value.device) for key, value in sample_data.items()
        }
        return (ssl_input, ssl_target)

    def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
        """
        MSE loss
        """

        return self.mse(predicted, target)

    @abstractmethod
    def construct_module(self) -> Union[nn.Module, None]:
        """
        Construct the SSL prediction module using the parameters specified at initialization
        """


class MaskedFramesSSL_FC(MaskedFramesSSL):
    """
    A fully connected masked frames SSL class

    Mask some of the input frames randomly and predict the initial data.
    """

    def __init__(
        self,
        dims: torch.Size,
        num_f_maps: torch.Size,
        frac_masked: float = 0.1,
        num_ssl_layers: int = 3,
        num_ssl_f_maps: int = 16,
    ) -> None:
        """
        Parameters
        ----------
        dims : dict
            the shapes of features in model input (a mapping; only its values are read)
        num_f_maps : torch.Size
            shape of feature extraction output
        frac_masked : float, default 0.1
            fraction of frames to mask
        num_ssl_layers : int, default 3
            number of layers in the SSL module
        num_ssl_f_maps : int, default 16
            number of feature maps in the SSL module
        """

        super().__init__(frac_masked)
        # the first dimensions of all feature shapes are summed
        dim = int(sum(shape[0] for shape in dims.values()))
        num_f_maps = int(num_f_maps[0])
        self.pars = {
            "dim": dim,
            "num_f_maps": num_f_maps,
            "num_ssl_layers": num_ssl_layers,
            "num_ssl_f_maps": num_ssl_f_maps,
        }

    def construct_module(self) -> Union[nn.Module, None]:
        """
        A fully connected module
        """

        return _FC(**self.pars)


class MaskedFramesSSL_TCN(MaskedFramesSSL):
    """
    A TCN masked frames SSL class

    Mask some of the input frames randomly and predict the initial data.
    """

    def __init__(
        self,
        dims: torch.Size,
        num_f_maps: torch.Size,
        frac_masked: float = 0.2,
        num_ssl_layers: int = 5,
    ) -> None:
        """
        Parameters
        ----------
        dims : dict
            the shapes of features in model input (a mapping; only its values are read)
        num_f_maps : torch.Size
            shape of feature extraction output
        frac_masked : float, default 0.2
            fraction of frames to mask
        num_ssl_layers : int, default 5
            number of layers in the SSL module
        """

        super().__init__(frac_masked)
        # the first dimensions of all feature shapes are summed
        dim = int(sum(shape[0] for shape in dims.values()))
        num_f_maps = int(num_f_maps[0])
        self.pars = {
            "input_dim": num_f_maps,
            "num_layers": num_ssl_layers,
            "output_dim": dim,
        }

    def construct_module(self) -> Union[nn.Module, None]:
        """
        A TCN module
        """

        return _DilatedTCN(**self.pars)
class MaskedFeaturesSSL(SSLConstructor, ABC):
    """
    A base masked features SSL class

    Mask some of the input features randomly and predict the initial data.
    """

    type = "ssl_input"

    def __init__(self, frac_masked: float = 0.2) -> None:
        """
        Parameters
        ----------
        frac_masked : float, default 0.2
            fraction of features to mask
        """

        super().__init__()
        self.mse = _MSE()
        self.frac_masked = frac_masked

    def transformation(self, sample_data: Dict) -> Tuple:
        """
        Mask some of the features randomly

        Every element is zeroed independently with probability `frac_masked`.

        Parameters
        ----------
        sample_data : dict
            a dictionary of feature tensors (it is not modified)

        Returns
        -------
        ssl_input : dict
            a new dictionary with the masked tensors
        ssl_target : torch.Tensor
            the unmasked tensors concatenated along the first dimension
        """

        # the target is built before masking so that the initial data has to be recovered
        ssl_target = torch.cat(list(sample_data.values()))
        ssl_input = {}
        for key, value in sample_data.items():
            # uniform samples make the expected share of zeroed elements equal to frac_masked
            keep = torch.rand(value.shape, device=value.device) >= self.frac_masked
            ssl_input[key] = value * keep
        return (ssl_input, ssl_target)

    def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
        """
        MSE loss
        """

        return self.mse(predicted, target)

    @abstractmethod
    def construct_module(self) -> nn.Module:
        """
        Construct the SSL prediction module using the parameters specified at initialization
        """
A base masked features SSL class
Mask some of the input features randomly and predict the initial data.
def __init__(self, frac_masked: float = 0.2) -> None:
    """
    Store the masking fraction and create the MSE loss object

    Parameters
    ----------
    frac_masked : float, default 0.2
        fraction of features to mask
    """

    super().__init__()
    self.frac_masked = frac_masked
    self.mse = _MSE()
Parameters
frac_masked : float, default 0.2 fraction of features to mask
The `type` parameter defines interaction with the model:
- 'ssl_input': a modification of the input data passes through the base network feature extraction module and the SSL module; it is returned as SSL output and compared to SSL target (or, if it is None, to the input data),
- 'ssl_output': the input data passes through the base network feature extraction module and the SSL module; it is returned as SSL output and compared to SSL target (or, if it is None, to the input data),
- 'contrastive': the input data and its modification pass through the base network feature extraction module and the SSL module; an (input results, modification results) tuple is returned as SSL output,
- 'contrastive_2layers': the input data and its modification pass through the base network feature extraction module; the output of the second feature extraction layer for the modified data goes through an SSL module and then, optionally, that result and the first-level unmodified features pass another transformation; an (input results, modified results) tuple is returned as SSL output.
def transformation(self, sample_data: Dict) -> Tuple:
    """
    Mask some of the features randomly

    Every element is zeroed independently with probability `frac_masked`.

    Parameters
    ----------
    sample_data : dict
        a dictionary of feature tensors (it is not modified)

    Returns
    -------
    ssl_input : dict
        a new dictionary with the masked tensors
    ssl_target : torch.Tensor
        the unmasked tensors concatenated along the first dimension
    """

    # the target is built before masking so that the initial data has to be recovered
    ssl_target = torch.cat(list(sample_data.values()))
    ssl_input = {}
    for key, value in sample_data.items():
        # uniform samples make the expected share of zeroed elements equal to frac_masked
        keep = torch.rand(value.shape, device=value.device) >= self.frac_masked
        ssl_input[key] = value * keep
    return (ssl_input, ssl_target)
Mask some of the features randomly
def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
    """
    MSE loss

    Pass the prediction and the target to the stored `mse` object and return its result.
    """

    return self.mse(predicted, target)
MSE loss
@abstractmethod
def construct_module(self) -> nn.Module:
    """
    Construct the SSL prediction module using the parameters specified at initialization

    This method is abstract and has no implementation here.
    """
Construct the SSL prediction module using the parameters specified at initialization
class MaskedFeaturesSSL_FC(MaskedFeaturesSSL):
    """
    A fully connected masked features SSL class

    Mask some of the input features randomly and predict the initial data.
    """

    def __init__(
        self,
        dims: torch.Size,
        num_f_maps: torch.Size,
        frac_masked: float = 0.2,
        num_ssl_layers: int = 5,
        num_ssl_f_maps: int = 16,
    ) -> None:
        """
        Parameters
        ----------
        dims : dict
            the shapes of features in model input (a mapping; only its values are read)
        num_f_maps : torch.Size
            shape of feature extraction output
        frac_masked : float, default 0.2
            fraction of features to mask
        num_ssl_layers : int, default 5
            number of layers in the SSL module
        num_ssl_f_maps : int, default 16
            number of feature maps in the SSL module
        """

        super().__init__(frac_masked)
        # the first dimensions of all feature shapes are summed
        total_dim = int(sum(shape[0] for shape in dims.values()))
        self.pars = dict(
            dim=total_dim,
            num_f_maps=int(num_f_maps[0]),
            num_ssl_layers=num_ssl_layers,
            num_ssl_f_maps=num_ssl_f_maps,
        )

    def construct_module(self) -> Union[nn.Module, None]:
        """
        A fully connected module
        """

        return _FC(**self.pars)
A fully connected masked features SSL class
Mask some of the input features randomly and predict the initial data.
def __init__(
    self,
    dims: torch.Size,
    num_f_maps: torch.Size,
    frac_masked: float = 0.2,
    num_ssl_layers: int = 5,
    num_ssl_f_maps: int = 16,
) -> None:
    """
    Parameters
    ----------
    dims : dict
        the shapes of features in model input (a mapping; only its values are read)
    num_f_maps : torch.Size
        shape of feature extraction output
    frac_masked : float, default 0.2
        fraction of features to mask
    num_ssl_layers : int, default 5
        number of layers in the SSL module
    num_ssl_f_maps : int, default 16
        number of feature maps in the SSL module
    """

    super().__init__(frac_masked)
    # the first dimensions of all feature shapes are summed
    total_dim = int(sum(shape[0] for shape in dims.values()))
    self.pars = dict(
        dim=total_dim,
        num_f_maps=int(num_f_maps[0]),
        num_ssl_layers=num_ssl_layers,
        num_ssl_f_maps=num_ssl_f_maps,
    )
Parameters
dims : torch.Size the shape of features in model input num_f_maps : torch.Size shape of feature extraction output frac_masked : float, default 0.2 fraction of features to mask num_ssl_layers : int, default 5 number of layers in the SSL module num_ssl_f_maps : int, default 16 number of feature maps in the SSL module
def construct_module(self) -> Union[nn.Module, None]:
    """
    A fully connected module

    Build an `_FC` instance from the keyword arguments stored in `self.pars`.
    """

    return _FC(**self.pars)
A fully connected module
Inherited Members
class MaskedFeaturesSSL_TCN(MaskedFeaturesSSL):
    """
    A TCN masked features SSL class

    Mask some of the input features randomly and predict the initial data.
    """

    def __init__(
        self,
        dims: Dict,
        num_f_maps: torch.Size,
        frac_masked: float = 0.2,
        num_ssl_layers: int = 5,
    ) -> None:
        """
        Parameters
        ----------
        dims : dict
            the shapes of features in model input (a mapping; only its values are read)
        num_f_maps : torch.Size
            shape of feature extraction output
        frac_masked : float, default 0.2
            fraction of features to mask
        num_ssl_layers : int, default 5
            number of layers in the SSL module
        """

        super().__init__(frac_masked)
        # the first dimensions of all feature shapes are summed
        output_dim = int(sum(shape[0] for shape in dims.values()))
        input_dim = int(num_f_maps[0])
        self.pars = dict(
            input_dim=input_dim,
            num_layers=num_ssl_layers,
            output_dim=output_dim,
        )

    def construct_module(self) -> Union[nn.Module, None]:
        """
        A TCN module
        """

        return _DilatedTCN(**self.pars)
A TCN masked features SSL class
Mask some of the input features randomly and predict the initial data.
def __init__(
    self,
    dims: Dict,
    num_f_maps: torch.Size,
    frac_masked: float = 0.2,
    num_ssl_layers: int = 5,
) -> None:
    """
    Parameters
    ----------
    dims : dict
        the shapes of features in model input (a mapping; only its values are read)
    num_f_maps : torch.Size
        shape of feature extraction output
    frac_masked : float, default 0.2
        fraction of features to mask
    num_ssl_layers : int, default 5
        number of layers in the SSL module
    """

    super().__init__(frac_masked)
    # the first dimensions of all feature shapes are summed
    output_dim = int(sum(shape[0] for shape in dims.values()))
    input_dim = int(num_f_maps[0])
    self.pars = dict(
        input_dim=input_dim,
        num_layers=num_ssl_layers,
        output_dim=output_dim,
    )
Parameters
dims : dict the shapes of features in model input num_f_maps : torch.Size shape of feature extraction output frac_masked : float, default 0.2 fraction of features to mask num_ssl_layers : int, default 5 number of layers in the SSL module
def construct_module(self) -> Union[nn.Module, None]:
    """
    A TCN module

    Build a `_DilatedTCN` instance from the keyword arguments stored in `self.pars`.
    """

    return _DilatedTCN(**self.pars)
A TCN module
Inherited Members
class MaskedKinematicSSL(SSLConstructor, ABC):
    """
    A base masked joints SSL class

    Mask some of the joints randomly and predict the initial data.
    """

    type = "ssl_input"

    def __init__(self, frac_masked: float = 0.2) -> None:
        """
        Parameters
        ----------
        frac_masked : float, default 0.2
            fraction of joints to mask
        """

        super().__init__()
        self.mse = _MSE()
        self.frac_masked = frac_masked

    def _get_keys(self, key_bases, x):
        """
        Get keys of x that start with one of the strings in key_bases
        """

        return [key for key in x if key.startswith(key_bases)]

    def transformation(self, sample_data: Dict) -> Tuple:
        """
        Mask joints randomly

        Every joint is selected with probability `frac_masked` and all rows that belong to
        a selected joint are set to zero. The number of joints is taken as half the number of
        rows of the first key that starts with `"coords"`.

        Parameters
        ----------
        sample_data : dict
            a dictionary of feature tensors (neither the dictionary nor the tensors are modified)

        Returns
        -------
        ssl_input : dict
            a new dictionary where the recognized feature tensors are masked copies
        ssl_target : torch.Tensor
            the unmasked tensors concatenated along the first dimension
        """

        coords_key = self._get_keys("coords", sample_data)[0]
        features, _ = sample_data[coords_key].shape
        n_bp = features // 2
        # the target is built before masking so that the initial data has to be recovered
        ssl_target = torch.cat(list(sample_data.values()))
        # True marks a joint that gets masked
        masked_joints = torch.rand(n_bp) < self.frac_masked
        first, second = torch.triu_indices(n_bp, n_bp, 1)
        row_masks = (
            # a pairwise distance row is masked if either of its joints is masked
            (
                ("intra_distance", "inter_distance"),
                masked_joints[first] | masked_joints[second],
            ),
            # two consecutive rows per joint
            (
                ("speed_joints", "coords", "acc_joints"),
                masked_joints.repeat_interleave(2),
            ),
            # one row per joint
            ("angle_joints_radian", masked_joints),
        )
        ssl_input = dict(sample_data)
        for key_bases, row_mask in row_masks:
            for key in self._get_keys(key_bases, sample_data):
                masked = sample_data[key].clone()
                masked[row_mask.to(masked.device)] = 0
                ssl_input[key] = masked
        return ssl_input, ssl_target

    def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
        """
        MSE loss
        """

        return self.mse(predicted, target)

    @abstractmethod
    def construct_module(self) -> Union[nn.Module, None]:
        """
        Construct the SSL prediction module using the parameters specified at initialization
        """
A base masked joints SSL class
Mask some of the joints randomly and predict the initial data.
def __init__(self, frac_masked: float = 0.2) -> None:
    """
    Store the masking fraction and create the MSE loss object

    Parameters
    ----------
    frac_masked : float, default 0.2
        fraction of joints to mask
    """

    super().__init__()
    self.frac_masked = frac_masked
    self.mse = _MSE()
Parameters
frac_masked : float, default 0.2 fraction of joints to mask
The `type` parameter defines interaction with the model:
- 'ssl_input': a modification of the input data passes through the base network feature extraction module and the SSL module; it is returned as SSL output and compared to SSL target (or, if it is None, to the input data),
- 'ssl_output': the input data passes through the base network feature extraction module and the SSL module; it is returned as SSL output and compared to SSL target (or, if it is None, to the input data),
- 'contrastive': the input data and its modification pass through the base network feature extraction module and the SSL module; an (input results, modification results) tuple is returned as SSL output,
- 'contrastive_2layers': the input data and its modification pass through the base network feature extraction module; the output of the second feature extraction layer for the modified data goes through an SSL module and then, optionally, that result and the first-level unmodified features pass another transformation; an (input results, modified results) tuple is returned as SSL output.
def transformation(self, sample_data: Dict) -> Tuple:
    """
    Mask joints randomly

    Every joint is selected with probability `frac_masked` and all rows that belong to
    a selected joint are set to zero. The number of joints is taken as half the number of
    rows of the first key that starts with `"coords"`.

    Parameters
    ----------
    sample_data : dict
        a dictionary of feature tensors (neither the dictionary nor the tensors are modified)

    Returns
    -------
    ssl_input : dict
        a new dictionary where the recognized feature tensors are masked copies
    ssl_target : torch.Tensor
        the unmasked tensors concatenated along the first dimension
    """

    coords_key = self._get_keys("coords", sample_data)[0]
    features, _ = sample_data[coords_key].shape
    n_bp = features // 2
    # the target is built before masking so that the initial data has to be recovered
    ssl_target = torch.cat(list(sample_data.values()))
    # True marks a joint that gets masked
    masked_joints = torch.rand(n_bp) < self.frac_masked
    first, second = torch.triu_indices(n_bp, n_bp, 1)
    row_masks = (
        # a pairwise distance row is masked if either of its joints is masked
        (
            ("intra_distance", "inter_distance"),
            masked_joints[first] | masked_joints[second],
        ),
        # two consecutive rows per joint
        (
            ("speed_joints", "coords", "acc_joints"),
            masked_joints.repeat_interleave(2),
        ),
        # one row per joint
        ("angle_joints_radian", masked_joints),
    )
    ssl_input = dict(sample_data)
    for key_bases, row_mask in row_masks:
        for key in self._get_keys(key_bases, sample_data):
            masked = sample_data[key].clone()
            masked[row_mask.to(masked.device)] = 0
            ssl_input[key] = masked
    return ssl_input, ssl_target
Mask joints randomly
def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
    """
    MSE loss

    Pass the prediction and the target to the stored `mse` object and return its result.
    """

    return self.mse(predicted, target)
MSE loss
@abstractmethod
def construct_module(self) -> Union[nn.Module, None]:
    """
    Construct the SSL prediction module using the parameters specified at initialization

    This method is abstract and has no implementation here.
    """
Construct the SSL prediction module using the parameters specified at initialization
class MaskedKinematicSSL_FC(MaskedKinematicSSL):
    """
    A fully connected masked joints SSL class

    Mask some of the joints randomly and predict the initial data.
    """

    def __init__(
        self,
        dims: torch.Size,
        num_f_maps: torch.Size,
        frac_masked: float = 0.2,
        num_ssl_layers: int = 5,
        num_ssl_f_maps: int = 16,
    ) -> None:
        """
        Parameters
        ----------
        dims : dict
            the shapes of features in model input (a mapping; only its values are read)
        num_f_maps : torch.Size
            shape of feature extraction output
        frac_masked : float, default 0.2
            fraction of joints to mask
        num_ssl_layers : int, default 5
            number of layers in the SSL module
        num_ssl_f_maps : int, default 16
            number of feature maps in the SSL module
        """

        super().__init__(frac_masked)
        # the first dimensions of all feature shapes are summed
        total_dim = int(sum(shape[0] for shape in dims.values()))
        self.pars = dict(
            dim=total_dim,
            num_f_maps=int(num_f_maps[0]),
            num_ssl_layers=num_ssl_layers,
            num_ssl_f_maps=num_ssl_f_maps,
        )

    def construct_module(self) -> Union[nn.Module, None]:
        """
        A fully connected module
        """

        return _FC(**self.pars)
A fully connected masked joints SSL class
Mask some of the joints randomly and predict the initial data.
def __init__(
    self,
    dims: torch.Size,
    num_f_maps: torch.Size,
    frac_masked: float = 0.2,
    num_ssl_layers: int = 5,
    num_ssl_f_maps: int = 16,
) -> None:
    """
    Parameters
    ----------
    dims : dict
        the shapes of features in model input (a mapping; only its values are read)
    num_f_maps : torch.Size
        shape of feature extraction output
    frac_masked : float, default 0.2
        fraction of joints to mask
    num_ssl_layers : int, default 5
        number of layers in the SSL module
    num_ssl_f_maps : int, default 16
        number of feature maps in the SSL module
    """

    super().__init__(frac_masked)
    # the first dimensions of all feature shapes are summed
    total_dim = int(sum(shape[0] for shape in dims.values()))
    self.pars = dict(
        dim=total_dim,
        num_f_maps=int(num_f_maps[0]),
        num_ssl_layers=num_ssl_layers,
        num_ssl_f_maps=num_ssl_f_maps,
    )
Parameters
dims : torch.Size the shape of features in model input num_f_maps : torch.Size shape of feature extraction output frac_masked : float, default 0.2 fraction of joints to mask num_ssl_layers : int, default 5 number of layers in the SSL module num_ssl_f_maps : int, default 16 number of feature maps in the SSL module
def construct_module(self) -> Union[nn.Module, None]:
    """
    A fully connected module

    Build an `_FC` instance from the keyword arguments stored in `self.pars`.
    """

    return _FC(**self.pars)
A fully connected module
Inherited Members
class MaskedKinematicSSL_TCN(MaskedKinematicSSL):
    """
    A TCN masked joints SSL class

    Mask some of the joints randomly and predict the initial data.
    """

    def __init__(
        self,
        dims: torch.Size,
        num_f_maps: torch.Size,
        frac_masked: float = 0.2,
        num_ssl_layers: int = 5,
    ) -> None:
        """
        Parameters
        ----------
        dims : dict
            the shapes of features in model input (a mapping; only its values are read)
        num_f_maps : torch.Size
            shape of feature extraction output
        frac_masked : float, default 0.2
            fraction of joints to mask
        num_ssl_layers : int, default 5
            number of layers in the SSL module
        """

        super().__init__(frac_masked)
        # the first dimensions of all feature shapes are summed
        output_dim = int(sum(shape[0] for shape in dims.values()))
        input_dim = int(num_f_maps[0])
        self.pars = dict(
            input_dim=input_dim,
            num_layers=num_ssl_layers,
            output_dim=output_dim,
        )

    def construct_module(self) -> Union[nn.Module, None]:
        """
        A TCN module
        """

        return _DilatedTCN(**self.pars)
A TCN masked joints SSL class
Mask some of the joints randomly and predict the initial data.
def __init__(
    self,
    dims: torch.Size,
    num_f_maps: torch.Size,
    frac_masked: float = 0.2,
    num_ssl_layers: int = 5,
) -> None:
    """
    Parameters
    ----------
    dims : dict
        the shapes of features in model input (a mapping; only its values are read)
    num_f_maps : torch.Size
        shape of feature extraction output
    frac_masked : float, default 0.2
        fraction of joints to mask
    num_ssl_layers : int, default 5
        number of layers in the SSL module
    """

    super().__init__(frac_masked)
    # the first dimensions of all feature shapes are summed
    output_dim = int(sum(shape[0] for shape in dims.values()))
    input_dim = int(num_f_maps[0])
    self.pars = dict(
        input_dim=input_dim,
        num_layers=num_ssl_layers,
        output_dim=output_dim,
    )
Parameters
dims : torch.Size the shape of features in model input num_f_maps : torch.Size shape of feature extraction output frac_masked : float, default 0.2 fraction of joints to mask num_ssl_layers : int, default 5 number of layers in the SSL module
def construct_module(self) -> Union[nn.Module, None]:
    """
    A TCN module

    Build a `_DilatedTCN` instance from the keyword arguments stored in `self.pars`.
    """

    return _DilatedTCN(**self.pars)
A TCN module
Inherited Members
class MaskedFramesSSL(SSLConstructor, ABC):
    """
    A base masked frames SSL class

    Mask some of the input frames randomly and predict the initial data.
    """

    type = "ssl_input"

    def __init__(self, frac_masked: float = 0.1) -> None:
        """
        Parameters
        ----------
        frac_masked : float, default 0.1
            fraction of frames to mask
        """

        super().__init__()
        self.frac_masked = frac_masked
        self.mse = _MSE()

    def transformation(self, sample_data: Dict) -> Tuple:
        """
        Mask some of the frames randomly

        Every position along the last dimension is zeroed with probability `frac_masked`;
        the same positions are zeroed in every tensor.

        Parameters
        ----------
        sample_data : dict
            a dictionary of feature tensors (it is not modified)

        Returns
        -------
        ssl_input : dict
            a new dictionary with the masked tensors
        ssl_target : torch.Tensor
            the unmasked tensors concatenated along the first dimension
        """

        values = list(sample_data.values())
        num_frames = values[0].shape[-1]
        # the target is built before masking so that the initial data has to be recovered
        ssl_target = torch.cat(values)
        # uniform samples make the expected share of zeroed frames equal to frac_masked
        keep = (torch.rand(num_frames) >= self.frac_masked).unsqueeze(0)
        ssl_input = {
            key: value * keep.to(value.device) for key, value in sample_data.items()
        }
        return (ssl_input, ssl_target)

    def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
        """
        MSE loss
        """

        return self.mse(predicted, target)

    @abstractmethod
    def construct_module(self) -> Union[nn.Module, None]:
        """
        Construct the SSL prediction module using the parameters specified at initialization
        """
Generates the functions necessary to build a masked frames SSL: mask some of the input frames randomly and predict the initial data
def __init__(self, frac_masked: float = 0.1) -> None:
    """
    Store the masking fraction and create the MSE loss object

    Parameters
    ----------
    frac_masked : float, default 0.1
        fraction of frames to mask
    """

    super().__init__()
    self.mse = _MSE()
    self.frac_masked = frac_masked
Parameters
frac_masked : float, default 0.1 fraction of frames to mask
The `type` parameter defines interaction with the model:
- 'ssl_input': a modification of the input data passes through the base network feature extraction module and the SSL module; it is returned as SSL output and compared to SSL target (or, if it is None, to the input data),
- 'ssl_output': the input data passes through the base network feature extraction module and the SSL module; it is returned as SSL output and compared to SSL target (or, if it is None, to the input data),
- 'contrastive': the input data and its modification pass through the base network feature extraction module and the SSL module; an (input results, modification results) tuple is returned as SSL output,
- 'contrastive_2layers': the input data and its modification pass through the base network feature extraction module; the output of the second feature extraction layer for the modified data goes through an SSL module and then, optionally, that result and the first-level unmodified features pass another transformation; an (input results, modified results) tuple is returned as SSL output.
def transformation(self, sample_data: Dict) -> Tuple:
    """
    Mask some of the frames randomly

    The same frame mask is applied to every tensor in `sample_data`; the last
    dimension of the tensors is treated as the frame dimension.

    Parameters
    ----------
    sample_data : dict
        a dictionary of input tensors (its values are replaced with the masked tensors in place)

    Returns
    -------
    sample_data : dict
        the input dictionary with the masked frames multiplied by zero
    ssl_target : torch.Tensor
        the original (unmasked) tensors concatenated along the first dimension
    """

    key = list(sample_data.keys())[0]
    reference = sample_data[key]
    # the target has to be built before masking, otherwise the module is trained
    # to reproduce the masked data instead of the initial data
    ssl_target = torch.cat(list(sample_data.values()))
    # uniform samples in [0, 1): a frame is kept with probability 1 - frac_masked
    # (normally distributed samples would not mask the requested fraction)
    mask = torch.rand(reference.shape[-1], device=reference.device) > self.frac_masked
    mask = mask.unsqueeze(0)
    for key in sample_data:
        sample_data[key] = sample_data[key] * mask
    return (sample_data, ssl_target)
Mask some of the frames randomly
def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
    """
    MSE loss

    Parameters
    ----------
    predicted : torch.Tensor
        the output of the SSL module
    target : torch.Tensor
        the SSL target

    Returns
    -------
    loss : float
        the value returned by `self.mse` for the two tensors
    """

    return self.mse(predicted, target)
MSE loss
@abstractmethod
def construct_module(self) -> Union[nn.Module, None]:
    """
    Construct the SSL prediction module using the parameters specified at initialization

    This method is abstract and has to be implemented by the subclasses.
    """
Construct the SSL prediction module using the parameters specified at initialization
class MaskedFramesSSL_FC(MaskedFramesSSL):
    """
    A fully connected masked frames SSL class

    Mask some of the input frames randomly and predict the initial data.
    """

    def __init__(
        self,
        dims: Dict,
        num_f_maps: torch.Size,
        frac_masked: float = 0.1,
        num_ssl_layers: int = 3,
        num_ssl_f_maps: int = 16,
    ) -> None:
        """
        Parameters
        ----------
        dims : dict
            a dictionary of the shapes of features in model input (the first elements
            of the shapes are summed to get the input dimension)
        num_f_maps : torch.Size
            shape of feature extraction output (only the first element is used)
        frac_masked : float, default 0.1
            fraction of frames to mask
        num_ssl_layers : int, default 3
            number of layers in the SSL module
        num_ssl_f_maps : int, default 16
            number of feature maps in the SSL module
        """

        super().__init__(frac_masked)
        dim = int(sum(s[0] for s in dims.values()))
        num_f_maps = int(num_f_maps[0])
        self.pars = {
            "dim": dim,
            "num_f_maps": num_f_maps,
            "num_ssl_layers": num_ssl_layers,
            "num_ssl_f_maps": num_ssl_f_maps,
        }

    def construct_module(self) -> Union[nn.Module, None]:
        """
        A fully connected module

        Returns
        -------
        module : nn.Module
            an `_FC` instance created with the keyword arguments stored in `self.pars`
        """

        return _FC(**self.pars)
Generates the functions necessary to build a masked frames SSL: mask some of the input frames randomly and predict the initial data
def __init__(
    self,
    dims: Dict,
    num_f_maps: torch.Size,
    frac_masked: float = 0.1,
    num_ssl_layers: int = 3,
    num_ssl_f_maps: int = 16,
) -> None:
    """
    Parameters
    ----------
    dims : dict
        a dictionary of the shapes of features in model input (the first elements
        of the shapes are summed to get the input dimension)
    num_f_maps : torch.Size
        shape of feature extraction output (only the first element is used)
    frac_masked : float, default 0.1
        fraction of frames to mask
    num_ssl_layers : int, default 3
        number of layers in the SSL module
    num_ssl_f_maps : int, default 16
        number of feature maps in the SSL module
    """

    super().__init__(frac_masked)
    dim = int(sum(s[0] for s in dims.values()))
    num_f_maps = int(num_f_maps[0])
    self.pars = {
        "dim": dim,
        "num_f_maps": num_f_maps,
        "num_ssl_layers": num_ssl_layers,
        "num_ssl_f_maps": num_ssl_f_maps,
    }
Parameters
dims : torch.Size the shape of features in model input num_f_maps : torch.Size shape of feature extraction output frac_masked : float, default 0.1 fraction of frames to mask num_ssl_layers : int, default 3 number of layers in the SSL module num_ssl_f_maps : int, default 16 number of feature maps in the SSL module
def construct_module(self) -> Union[nn.Module, None]:
    """
    Build the fully connected prediction module

    Returns
    -------
    module : nn.Module
        an `_FC` instance created with the keyword arguments stored in `self.pars`
    """

    return _FC(**self.pars)
A fully connected module
Inherited Members
class MaskedFramesSSL_TCN(MaskedFramesSSL):
    """
    A TCN masked frames SSL class

    Mask some of the input frames randomly and predict the initial data.
    """

    def __init__(
        self,
        dims: Dict,
        num_f_maps: torch.Size,
        frac_masked: float = 0.2,
        num_ssl_layers: int = 5,
    ) -> None:
        """
        Parameters
        ----------
        dims : dict
            a dictionary of the shapes of features in model input (the first elements
            of the shapes are summed to get the output dimension of the module)
        num_f_maps : torch.Size
            shape of feature extraction output (only the first element is used)
        frac_masked : float, default 0.2
            fraction of frames to mask
        num_ssl_layers : int, default 5
            number of layers in the SSL module
        """

        super().__init__(frac_masked)
        dim = int(sum(s[0] for s in dims.values()))
        num_f_maps = int(num_f_maps[0])
        self.pars = {
            "input_dim": num_f_maps,
            "num_layers": num_ssl_layers,
            "output_dim": dim,
        }

    def construct_module(self) -> Union[nn.Module, None]:
        """
        A TCN module

        Returns
        -------
        module : nn.Module
            a `_DilatedTCN` instance created with the keyword arguments stored in `self.pars`
        """

        return _DilatedTCN(**self.pars)
Generates the functions necessary to build a masked frames SSL: mask some of the input frames randomly and predict the initial data
def __init__(
    self,
    dims: Dict,
    num_f_maps: torch.Size,
    frac_masked: float = 0.2,
    num_ssl_layers: int = 5,
) -> None:
    """
    Parameters
    ----------
    dims : dict
        a dictionary of the shapes of features in model input (the first elements
        of the shapes are summed to get the output dimension of the module)
    num_f_maps : torch.Size
        shape of feature extraction output (only the first element is used)
    frac_masked : float, default 0.2
        fraction of frames to mask
    num_ssl_layers : int, default 5
        number of layers in the SSL module
    """

    super().__init__(frac_masked)
    dim = int(sum(s[0] for s in dims.values()))
    num_f_maps = int(num_f_maps[0])
    self.pars = {
        "input_dim": num_f_maps,
        "num_layers": num_ssl_layers,
        "output_dim": dim,
    }
Parameters
dims : torch.Size the number of features in model input num_f_maps : torch.Size shape of feature extraction output frac_masked : float, default 0.2 fraction of frames to mask num_ssl_layers : int, default 5 number of layers in the SSL module
def construct_module(self) -> Union[nn.Module, None]:
    """
    Build the TCN prediction module

    Returns
    -------
    module : nn.Module
        a `_DilatedTCN` instance created with the keyword arguments stored in `self.pars`
    """

    return _DilatedTCN(**self.pars)
A TCN module