dlc2action.model.ms_tcn
MS-TCN++ (multi-stage temporal convolutional network) variations
#
# Copyright 2020-present by A. Mathis Group and contributors. All rights reserved.
#
# This project and all its files are licensed under GNU AGPLv3 or later version. A copy is included in dlc2action/LICENSE.AGPL.
#
# Incorporates code adapted from MS-TCN++ by yabufarha
# Original work Copyright (c) 2019 June01
# Source: https://github.com/sj-li/MS-TCN2
# Originally licensed under MIT License
# Combined work licensed under GNU AGPLv3
#
"""
MS-TCN++ (multi-stage temporal convolutional network) variations
"""

# `torch` and `nn` are used directly below; import them explicitly instead of
# relying on them leaking through the star import of `ms_tcn_modules`.
import torch
from torch import nn

from dlc2action.model.base_model import Model
from dlc2action.model.ms_tcn_modules import *


class Compiled(nn.Module):
    """
    A sequential container whose sub-modules all take an extra `tag` argument

    Every module in the chain is called as `module(x, tag)` and its output is
    fed to the next one.
    """

    def __init__(self, modules):
        """
        Parameters
        ----------
        modules : iterable
            the modules to chain, in execution order
        """

        super().__init__()
        self.module_list = nn.ModuleList(modules)

    def forward(self, x, tag=None):
        """Forward pass."""
        for m in self.module_list:
            x = m(x, tag)
        return x


class MS_TCN3(Model):
    """
    A modification of MS-TCN++ model with additional options
    """

    def __init__(
        self,
        num_f_maps,
        num_classes,
        exclusive,
        dims,
        num_layers_R,
        num_R,
        num_layers_PG,
        num_f_maps_R=None,
        num_layers_S=0,
        dropout_rate=0.5,
        shared_weights=False,
        skip_connections_refinement=True,
        block_size_prediction=0,
        block_size_refinement=0,
        kernel_size_prediction=3,
        direction_PG=None,
        direction_R=None,
        PG_in_FE=False,
        rare_dilations=False,
        num_heads=1,
        R_attention="none",
        PG_attention="none",
        state_dict_path=None,
        ssl_constructors=None,
        ssl_types=None,
        ssl_modules=None,
        multihead=False,
        *args,
        **kwargs,
    ):
        """
        Parameters
        ----------
        num_f_maps : int
            number of feature maps
        num_classes : int
            number of classes to predict
        exclusive : bool
            if `True`, single-label predictions are made; otherwise multi-label
        dims : dict
            a dictionary of input feature shapes; the first element of every shape is summed
            to get the input dimension
        num_layers_R : int
            number of layers in the refinement stages
        num_R : int
            number of refinement stages
        num_layers_PG : int
            number of layers in the prediction generation stage
        num_f_maps_R : int, optional
            number of feature maps in the refinement stages (if `None`, `num_f_maps` is used)
        num_layers_S : int, default 0
            number of layers in the spatial feature extraction stage
        dropout_rate : float, default 0.5
            dropout rate
        shared_weights : bool, default False
            if `True`, weights are shared across refinement stages
        skip_connections_refinement : bool, default True
            if `True`, skip connections are added to the refinement stages
        block_size_prediction : int, default 0
            if not 0, skip connections are added to the prediction generation stage with this interval
        block_size_refinement : int, default 0
            if not 0, skip connections are added to the refinement stage with this interval
        kernel_size_prediction : int, default 3
            kernel size passed to the prediction generation stage
        direction_PG : [None, 'bidirectional', 'forward', 'backward']
            if not `None`, a combination of causal and anticausal convolutions are used in the
            prediction generation stage
        direction_R : [None, 'bidirectional', 'forward', 'backward']
            if not `None`, a combination of causal and anticausal convolutions are used in the refinement stages
        PG_in_FE : bool, default False
            if `True`, the prediction generation stage is included in the feature extractor and otherwise in the
            predictor (the output of the feature extractor is used in SSL tasks)
        rare_dilations : bool, default False
            if `False`, dilation increases every layer, otherwise every second layer in
            the prediction generation stage
        num_heads : int, default 1
            the number of parallel refinement stages
        R_attention : str, default "none"
            the `attention` option passed to the refinement module
            (accepted values are defined by that module)
        PG_attention : str, default "none"
            the `attention` option passed to `DilatedTCN` in the prediction generation stage
            (not used when `direction_PG` is 'bidirectional')
        state_dict_path : str, optional
            if not `None`, the model state dictionary will be loaded from this path
        ssl_constructors : list, optional
            a list of `dlc2action.ssl.base_ssl.SSLConstructor` instances to integrate
        ssl_types : list, optional
            a list of types of the SSL modules to integrate (used alternatively to `ssl_constructors`)
        ssl_modules : list, optional
            a list of SSL modules to integrate (used alternatively to `ssl_constructors`)
        multihead : bool, default False
            the `multihead` option passed to `DilatedTCN` in the prediction generation stage
            (not used when `direction_PG` is 'bidirectional')
        *args, **kwargs
            accepted for compatibility and ignored
        """

        # int(float(...)) also accepts float-like values such as 8.0 or "8.0"
        self.num_layers_R = int(float(num_layers_R))
        self.num_R = int(float(num_R))
        self.num_f_maps = int(float(num_f_maps))
        self.num_classes = int(float(num_classes))
        self.dropout_rate = float(dropout_rate)
        self.exclusive = bool(exclusive)
        self.num_layers_PG = int(float(num_layers_PG))
        self.num_layers_S = int(float(num_layers_S))
        self.dim = self._get_dims(dims)
        self.shared_weights = bool(shared_weights)
        self.skip_connections_ref = bool(skip_connections_refinement)
        self.block_size_prediction = int(float(block_size_prediction))
        self.block_size_refinement = int(float(block_size_refinement))
        self.direction_R = direction_R
        self.direction_PG = direction_PG
        self.kernel_size_prediction = int(float(kernel_size_prediction))
        self.PG_in_FE = PG_in_FE
        self.rare_dilations = rare_dilations
        self.num_heads = int(float(num_heads))
        self.PG_attention = PG_attention
        self.R_attention = R_attention
        self.multihead = multihead
        if num_f_maps_R is None:
            num_f_maps_R = self.num_f_maps
        self.num_f_maps_R = num_f_maps_R
        # NOTE(review): the attributes are set before the base constructor runs,
        # presumably because it builds the sub-modules from them -- keep this order.
        super().__init__(ssl_constructors, ssl_modules, ssl_types, state_dict_path)

    def _get_dims(self, dims):
        """Sum the first element of every feature shape in `dims`."""
        return int(sum(s[0] for s in dims.values()))

    def _PG(self):
        """
        Build the prediction generation stage

        The input dimension is the raw feature dimension when there is no spatial
        feature extraction stage and `num_f_maps` otherwise.
        """

        if self.num_layers_S == 0:
            dim = self.dim
        else:
            dim = self.num_f_maps
        if self.direction_PG == "bidirectional":
            # the bidirectional variant is built without the attention / multihead options
            PG = DilatedTCNB(
                num_layers=self.num_layers_PG,
                num_f_maps=self.num_f_maps,
                dim=dim,
                block_size=self.block_size_prediction,
                kernel_size=self.kernel_size_prediction,
                rare_dilations=self.rare_dilations,
            )
        else:
            PG = DilatedTCN(
                num_layers=self.num_layers_PG,
                num_f_maps=self.num_f_maps,
                dim=dim,
                direction=self.direction_PG,
                block_size=self.block_size_prediction,
                kernel_size=self.kernel_size_prediction,
                rare_dilations=self.rare_dilations,
                attention=self.PG_attention,
                multihead=self.multihead,
            )
        return PG

    def _feature_extractor(self):
        """
        Build the feature extractor

        Returns
        -------
        feature_extractor : torch.nn.Module | list
            the prediction generator or `nn.Identity` when there is no spatial stage,
            the spatial feature module otherwise (returned in a list together with the
            prediction generator when `PG_in_FE` is `True`)
        """

        if self.num_layers_S == 0:
            if self.PG_in_FE:
                print("MS-TCN using the prediction generator as a feature extractor")
                return self._PG()
            else:
                print("MS-TCN without a feature extractor -> no SSL possible!")
                return nn.Identity()

        print("MS-TCN using a spatial feature extractor")
        feature_extractor = SpatialFeatures(
            self.num_layers_S,
            self.num_f_maps,
            self.dim,
            self.block_size_prediction,
        )
        if self.PG_in_FE:
            print(" -> also has the prediction generator as a feature extractor")
            PG = self._PG()
            feature_extractor = [feature_extractor, PG]
        return feature_extractor

    def _predictor(self):
        """
        Build the predictor

        The refinement module is preceded by the prediction generator when the
        latter is not part of the feature extractor.
        """

        if self.shared_weights:
            prediction_module = MSRefinementShared
        else:
            prediction_module = MSRefinement
        predictor = prediction_module(
            num_layers_R=int(self.num_layers_R),
            num_R=int(self.num_R),
            num_f_maps_input=int(self.num_f_maps),
            num_f_maps=int(self.num_f_maps_R),
            num_classes=int(self.num_classes),
            dropout_rate=self.dropout_rate,
            exclusive=self.exclusive,
            skip_connections=self.skip_connections_ref,
            direction=self.direction_R,
            block_size=self.block_size_refinement,
            num_heads=self.num_heads,
            attention=self.R_attention,
        )
        if not self.PG_in_FE:
            PG = self._PG()
            predictor = Compiled([PG, predictor])
        return predictor

    def features_shape(self) -> torch.Size:
        """
        Get the shape of feature extractor output

        Returns
        -------
        feature_shape : torch.Size
            shape of feature extractor output
        """

        # NOTE(review): `num_f_maps` is returned even when the feature extractor is
        # `nn.Identity` (no spatial stage and `PG_in_FE` is False) -- confirm this is intended.
        return torch.Size([self.num_f_maps])


class MS_TCN_P(MS_TCN3):
    """
    A variation of `MS_TCN3` with a `MultiDilatedTCN` prediction generator and per-group input dimensions
    """

    def _get_dims(self, dims):
        """
        Get a list of input dimensions, one per feature group

        The group of a feature is the part of its key after the last "---";
        the first elements of the shapes are summed within each group and the
        groups are ordered alphabetically.
        """

        groups = [key.split("---")[-1] for key in dims]
        res = [
            int(sum(x[0] for x, g in zip(dims.values(), groups) if g == group))
            for group in sorted(set(groups))
        ]
        # NOTE(review): a key that is exactly "loaded" has no "---", so it already
        # forms its own group above; this appends its size a second time.
        # Confirm against `MultiDilatedTCN` whether the duplicate entry is intended.
        if "loaded" in dims:
            res.append(int(dims["loaded"][0]))
        return res

    def _PG(self):
        """Build the `MultiDilatedTCN` prediction generation stage."""
        PG = MultiDilatedTCN(
            self.num_layers_PG,
            self.num_f_maps,
            self.dim,
            self.direction_PG,
            self.block_size_prediction,
            self.kernel_size_prediction,
            self.rare_dilations,
        )
        return PG


# class MS_TCNC(Model):
#     """
#     Basic MS-TCN++ model with options for shared weights and added skip connections
#     """
#
#     def __init__(
#         self,
#         num_layers_R,
#         num_R,
#         num_f_maps,
#         num_classes,
#         exclusive,
#         num_layers_PG,
#         num_layers_S,
#         dims,
#         len_segment,
#         dropout_rate=0.5,
#         shared_weights=False,
#         skip_connections_refinement=True,
#         block_size_prediction=5,
#         block_size_refinement=0,
#         kernel_size_prediction=3,
#         direction_PG=None,
#         direction_R=None,
#         PG_in_FE=False,
#         state_dict_path=None,
#         ssl_constructors=None,
#         ssl_types=None,
#         ssl_modules=None,
#     ):
#         """
#         Parameters
#         ----------
#         num_layers_R : int
#             number of layers in the refinement stages
#         num_R : int
#             number of refinement stages
#         num_f_maps : int
#             number of feature maps
#         num_classes : int
#             number of classes to predict
#         exclusive : bool
#             if `True`, single-label predictions are made; otherwise multi-label
#         num_layers_PG : int
#             number of layers in the prediction generation stage
#         dims : torch.Size
#             shape of features in the input data
#         dropout_rate : float, default 0.5
#             dropout rate
#         shared_weights : bool, default False
#             if `True`, weights are shared across refinement stages
#         skip_connections_refinement : bool, default False
#             if `True`, skip connections are added to the refinement stages
#         block_size_prediction : int, optional
#             if not 'None', skip connections are added to the prediction generation stage with this interval
#         direction_PG : bool, default True
#             if True, causal convolutions are used in the prediction generation stage
#         direction_R : bool, default False
#             if True, causal convolutions are used in the refinement stages
#         state_dict_path : str, optional
#             if not `None`, the model state dictionary will be loaded from this path
#         ssl_constructors : list, optional
#             a list of `dlc2action.ssl.base_ssl.SSLConstructor` instances to integrate
#         ssl_types : list, optional
#             a list of types of the SSL modules to integrate (used alternatively to `ssl_constructors`)
#         ssl_modules : list, optional
#             a list of SSL modules to integrate (used alternatively to `ssl_constructors`)
#         """
#
#         if len(dims) > 1:
#             raise RuntimeError(
#                 "The MS-TCN++ model expects the input data to be 2-dimensional; "
#                 f"got {len(dims) + 1} dimensions"
#             )
#         self.num_layers_R = int(num_layers_R)
#         self.num_R = int(num_R)
#         self.num_f_maps = int(num_f_maps)
#         self.num_classes = int(num_classes)
#         self.dropout_rate = dropout_rate
#         self.exclusive = exclusive
#         self.num_layers_PG = int(num_layers_PG)
#         self.num_layers_S = int(num_layers_S)
#         self.dim = int(dims[0])
#         self.shared_weights = shared_weights
#         self.skip_connections_ref = skip_connections_refinement
#         self.block_size_prediction = int(block_size_prediction)
#         self.block_size_refinement = int(block_size_refinement)
#         self.direction_R = direction_R
#         self.direction_PG = direction_PG
#         self.kernel_size_prediction = int(kernel_size_prediction)
#         self.PG_in_FE = PG_in_FE
#         self.len_segment = len_segment
#         super().__init__(ssl_constructors, ssl_modules, ssl_types, state_dict_path)
#
#     def _PG(self):
#         PG = DilatedTCNC(
#             num_f_maps=self.num_f_maps,
#             num_layers_PG=self.num_layers_PG,
#             len_segment=self.len_segment,
#             block_size_prediction=self.block_size_prediction,
#             kernel_size_prediction=self.kernel_size_prediction,
#             direction_PG=self.direction_PG,
#         )
#         return PG
#
#     def _feature_extractor(self):
#         feature_extractor = SpatialFeatures(
#             num_layers=self.num_layers_S,
#             num_f_maps=self.num_f_maps,
#             dim=self.dim,
#             block_size=self.block_size_prediction,
#         )
#         if self.PG_in_FE:
#             PG = self._PG()
#             feature_extractor = [feature_extractor, PG]
#         return feature_extractor
#
#     def _predictor(self):
#
#         if self.shared_weights:
#             prediction_module = MSRefinementShared
#         else:
#             prediction_module = MSRefinement
#         predictor = prediction_module(
#             num_layers_R=int(self.num_layers_R),
#             num_R=int(self.num_R),
#             num_f_maps=int(self.num_f_maps),
#             num_classes=int(self.num_classes),
#             dropout_rate=self.dropout_rate,
#             exclusive=self.exclusive,
#             skip_connections=self.skip_connections_ref,
#             direction=self.direction_R,
#             block_size=self.block_size_refinement,
#         )
#         if not self.PG_in_FE:
#             PG = self._PG()
#             predictor = Compiled([PG, predictor])
#         return predictor
#
#     def features_shape(self) -> torch.Size:
#         """
#         Get the shape of feature extractor output
#
#         Returns
#         -------
#         feature_shape : torch.Size
#             shape of feature extractor output
#         """
#
#         return torch.Size([self.num_f_maps])
#
# class MS_TCNA(Model):
#     """
#     Basic MS-TCN++ model with additional options
#     """
#
#     def __init__(
#         self,
#         num_f_maps,
#         num_classes,
#         exclusive,
#         dims,
#         num_layers_R,
#         num_R,
#         num_layers_PG,
#         len_segment,
#         num_f_maps_R=None,
#         num_layers_S=0,
#         dropout_rate=0.5,
#         skip_connections_refinement=True,
#         block_size_prediction=0,
#         block_size_refinement=0,
#         kernel_size_prediction=3,
#         direction_PG=None,
#         direction_R=None,
#         PG_in_FE=False,
#         rare_dilations=False,
#         state_dict_path=None,
#         ssl_constructors=None,
#         ssl_types=None,
#         ssl_modules=None,
#         *args, **kwargs
#     ):
#         """
#         Parameters
#         ----------
#         num_f_maps : int
#             number of feature maps
#         num_classes : int
#             number of classes to predict
#         exclusive : bool
#             if `True`, single-label predictions are made; otherwise multi-label
#         dims : torch.Size
#             shape of features in the input data
#         num_layers_R : int
#             number of layers in the refinement stages
#         num_R : int
#             number of refinement stages
#         num_layers_PG : int
#             number of layers in the prediction generation stage
#         num_layers_S : int, default 0
#             number of layers in the spatial feature extraction stage
#         dropout_rate : float, default 0.5
#             dropout rate
#         shared_weights : bool, default False
#             if `True`, weights are shared across refinement stages
#         skip_connections_refinement : bool, default False
#             if `True`, skip connections are added to the refinement stages
#         block_size_prediction : int, default 0
#             if not 0, skip connections are added to the prediction generation stage with this interval
#         block_size_refinement : int, default 0
#             if not 0, skip connections are added to the refinement stage with this interval
#         direction_PG : [None, 'bidirectional', 'forward', 'backward']
#             if not `None`, a combination of causal and anticausal convolutions are used in the
#             prediction generation stage
#         direction_R : [None, 'bidirectional', 'forward', 'backward']
#             if not `None`, a combination of causal and anticausal convolutions are used in the refinement stages
#         PG_in_FE : bool, default True
#             if `True`, the prediction generation stage is included in the feature extractor and otherwise in the
#             predictor (the output of the feature extractor is used in SSL tasks)
#         rare_dilations : bool, default False
#             if `False`, dilation increases every layer, otherwise every second layer in
#             the prediction generation stage
#         num_heads : int, default 1
#             the number of parallel refinement stages
#         state_dict_path : str, optional
#             if not `None`, the model state dictionary will be loaded from this path
#         ssl_constructors : list, optional
#             a list of `dlc2action.ssl.base_ssl.SSLConstructor` instances to integrate
#         ssl_types : list, optional
#             a list of types of the SSL modules to integrate (used alternatively to `ssl_constructors`)
#         ssl_modules : list, optional
#             a list of SSL modules to integrate (used alternatively to `ssl_constructors`)
#         """
#
#         if len(dims) > 1:
#             raise RuntimeError(
#                 "The MS-TCN++ model expects the input data to be 2-dimensional; "
#                 f"got {len(dims) + 1} dimensions"
#             )
#         self.num_layers_R = int(num_layers_R)
#         self.num_R = int(num_R)
#         self.num_f_maps = int(num_f_maps)
#         self.num_classes = int(num_classes)
#         self.dropout_rate = dropout_rate
#         self.exclusive = exclusive
#         self.num_layers_PG = int(num_layers_PG)
#         self.num_layers_S = int(num_layers_S)
#         self.dim = int(dims[0])
#         self.skip_connections_ref = skip_connections_refinement
#         self.block_size_prediction = int(block_size_prediction)
#         self.block_size_refinement = int(block_size_refinement)
#         self.direction_R = direction_R
#         self.direction_PG = direction_PG
#         self.kernel_size_prediction = int(kernel_size_prediction)
#         self.PG_in_FE = PG_in_FE
#         self.rare_dilations = rare_dilations
#         self.len_segment = len_segment
#         if num_f_maps_R is None:
#             num_f_maps_R = num_f_maps
#         self.num_f_maps_R = num_f_maps_R
#         super().__init__(ssl_constructors, ssl_modules, ssl_types, state_dict_path)
#
#     def _PG(self):
#         if self.num_layers_S == 0:
#             dim = self.dim
#         else:
#             dim = self.num_f_maps
#         if self.direction_PG == "bidirectional":
#             PG = DilatedTCNB(
#                 num_layers=self.num_layers_PG,
#                 num_f_maps=self.num_f_maps,
#                 dim=dim,
#                 block_size=self.block_size_prediction,
#                 kernel_size=self.kernel_size_prediction,
#                 rare_dilations=self.rare_dilations,
#             )
#         else:
#             PG = DilatedTCN(
#                 num_layers=self.num_layers_PG,
#                 num_f_maps=self.num_f_maps,
#                 dim=dim,
#                 direction=self.direction_PG,
#                 block_size=self.block_size_prediction,
#                 kernel_size=self.kernel_size_prediction,
#                 rare_dilations=self.rare_dilations,
#             )
#         return PG
#
#     def _feature_extractor(self):
#         if self.num_layers_S == 0:
#             if self.PG_in_FE:
#                 return self._PG()
#             else:
#                 return nn.Identity()
#         feature_extractor = SpatialFeatures(
#             self.num_layers_S,
#             self.num_f_maps,
#             self.dim,
#             self.block_size_prediction,
#         )
#         if self.PG_in_FE:
#             PG = self._PG()
#             feature_extractor = [feature_extractor, PG]
#         return feature_extractor
#
#     def _predictor(self):
#         predictor = MSRefinementAttention(
#             num_layers_R=int(self.num_layers_R),
#             num_R=int(self.num_R),
#             num_f_maps_input=int(self.num_f_maps),
#             num_f_maps=int(self.num_f_maps_R),
#             num_classes=int(self.num_classes),
#             dropout_rate=self.dropout_rate,
#             exclusive=self.exclusive,
#             skip_connections=self.skip_connections_ref,
#             block_size=self.block_size_refinement,
#             len_segment=self.len_segment,
#         )
#         if not self.PG_in_FE:
#             PG = self._PG()
#             predictor = Compiled([PG, predictor])
#         return predictor
#
#     def features_shape(self) -> torch.Size:
#         """
#         Get the shape of feature extractor output
#
#         Returns
#         -------
#         feature_shape : torch.Size
#             shape of feature extractor output
#         """
#
#         return torch.Size([self.num_f_maps])
class Compiled(nn.Module):
    """
    Chain of modules that are applied one after another

    Each stage is called as `stage(x, tag)`; the output of one stage is the
    input of the next, and the same `tag` is handed to every stage.
    """

    def __init__(self, modules):
        """
        Parameters
        ----------
        modules : iterable
            the stages of the chain, in execution order
        """

        super().__init__()
        # nn.ModuleList registers the stages as sub-modules of this module
        self.module_list = nn.ModuleList(modules)

    def forward(self, x, tag=None):
        """Forward pass."""
        out = x
        for stage in self.module_list:
            out = stage(out, tag)
        return out
Base class for all neural network modules.
Your models should also subclass this class.
Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
    """Example module: two 5x5 convolutions, each followed by a ReLU."""

    def __init__(self) -> None:
        # the parent constructor must run before sub-modules are assigned
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        """Apply both convolution + ReLU stages to `x`."""
        hidden = F.relu(self.conv1(x))
        out = F.relu(self.conv2(hidden))
        return out
Submodules assigned in this way will be registered, and will also have their parameters converted when you call `to()`, etc.
As per the example above, an `__init__()` call to the parent class must be made before assignment on the child.
training : bool
Boolean that represents whether this module is in training or evaluation mode.
class MS_TCN3(Model):
    """
    A modification of MS-TCN++ model with additional options
    """

    def __init__(
        self,
        num_f_maps,
        num_classes,
        exclusive,
        dims,
        num_layers_R,
        num_R,
        num_layers_PG,
        num_f_maps_R=None,
        num_layers_S=0,
        dropout_rate=0.5,
        shared_weights=False,
        skip_connections_refinement=True,
        block_size_prediction=0,
        block_size_refinement=0,
        kernel_size_prediction=3,
        direction_PG=None,
        direction_R=None,
        PG_in_FE=False,
        rare_dilations=False,
        num_heads=1,
        R_attention="none",
        PG_attention="none",
        state_dict_path=None,
        ssl_constructors=None,
        ssl_types=None,
        ssl_modules=None,
        multihead=False,
        *args,
        **kwargs,
    ):
        """
        Parameters
        ----------
        num_f_maps : int
            number of feature maps
        num_classes : int
            number of classes to predict
        exclusive : bool
            if `True`, single-label predictions are made; otherwise multi-label
        dims : dict
            a dictionary of feature shapes in the input data (the first dimensions of its values are
            summed to get the input dimension)
        num_layers_R : int
            number of layers in the refinement stages
        num_R : int
            number of refinement stages
        num_layers_PG : int
            number of layers in the prediction generation stage
        num_f_maps_R : int, optional
            number of feature maps in the refinement stages (if `None`, `num_f_maps` is used)
        num_layers_S : int, default 0
            number of layers in the spatial feature extraction stage
        dropout_rate : float, default 0.5
            dropout rate
        shared_weights : bool, default False
            if `True`, weights are shared across refinement stages
        skip_connections_refinement : bool, default True
            if `True`, skip connections are added to the refinement stages
        block_size_prediction : int, default 0
            if not 0, skip connections are added to the prediction generation stage with this interval
        block_size_refinement : int, default 0
            if not 0, skip connections are added to the refinement stage with this interval
        kernel_size_prediction : int, default 3
            kernel size passed to the prediction generation stage
        direction_PG : [None, 'bidirectional', 'forward', 'backward']
            if not `None`, a combination of causal and anticausal convolutions are used in the
            prediction generation stage
        direction_R : [None, 'bidirectional', 'forward', 'backward']
            if not `None`, a combination of causal and anticausal convolutions are used in the refinement stages
        PG_in_FE : bool, default False
            if `True`, the prediction generation stage is included in the feature extractor and otherwise in the
            predictor (the output of the feature extractor is used in SSL tasks)
        rare_dilations : bool, default False
            if `False`, dilation increases every layer, otherwise every second layer in
            the prediction generation stage
        num_heads : int, default 1
            the number of parallel refinement stages
        R_attention : str, default "none"
            the attention setting passed to the refinement stages
            (NOTE(review): the accepted values are defined by the refinement modules -- confirm there)
        PG_attention : str, default "none"
            the attention setting passed to the prediction generation stage; not used when
            `direction_PG` is `'bidirectional'`
            (NOTE(review): the accepted values are defined by `DilatedTCN` -- confirm there)
        state_dict_path : str, optional
            if not `None`, the model state dictionary will be loaded from this path
        ssl_constructors : list, optional
            a list of `dlc2action.ssl.base_ssl.SSLConstructor` instances to integrate
        ssl_types : list, optional
            a list of types of the SSL modules to integrate (used alternatively to `ssl_constructors`)
        ssl_modules : list, optional
            a list of SSL modules to integrate (used alternatively to `ssl_constructors`)
        multihead : bool, default False
            passed as `multihead` to the prediction generation stage; not used when
            `direction_PG` is `'bidirectional'`
        *args, **kwargs
            accepted and ignored
        """

        # The hyperparameters are stored before the parent constructor runs
        # (presumably it builds the modules through `_feature_extractor` and
        # `_predictor` -- verify against `Model`).
        self.num_layers_R = int(float(num_layers_R))
        self.num_R = int(float(num_R))
        self.num_f_maps = int(float(num_f_maps))
        self.num_classes = int(float(num_classes))
        self.dropout_rate = float(dropout_rate)
        self.exclusive = bool(exclusive)
        self.num_layers_PG = int(float(num_layers_PG))
        self.num_layers_S = int(float(num_layers_S))
        self.dim = self._get_dims(dims)
        self.shared_weights = bool(shared_weights)
        self.skip_connections_ref = bool(skip_connections_refinement)
        self.block_size_prediction = int(float(block_size_prediction))
        self.block_size_refinement = int(float(block_size_refinement))
        self.direction_R = direction_R
        self.direction_PG = direction_PG
        self.kernel_size_prediction = int(float(kernel_size_prediction))
        self.PG_in_FE = PG_in_FE
        self.rare_dilations = rare_dilations
        self.num_heads = int(float(num_heads))
        self.PG_attention = PG_attention
        self.R_attention = R_attention
        self.multihead = multihead
        if num_f_maps_R is None:
            num_f_maps_R = self.num_f_maps
        # normalized like the other integer hyperparameters, so that values such as
        # "64.0" or 64.0 do not fail later at `int(self.num_f_maps_R)` in `_predictor`
        self.num_f_maps_R = int(float(num_f_maps_R))
        super().__init__(ssl_constructors, ssl_modules, ssl_types, state_dict_path)

    def _get_dims(self, dims):
        """
        Get the input dimension as the sum of the first dimensions of all feature shapes in `dims`
        """

        return int(sum([s[0] for s in dims.values()]))

    def _PG(self):
        """
        Create the prediction generation stage

        The input dimension is the raw feature dimension when there is no spatial feature extraction
        stage and `num_f_maps` otherwise.
        """

        if self.num_layers_S == 0:
            dim = self.dim
        else:
            dim = self.num_f_maps
        if self.direction_PG == "bidirectional":
            # the bidirectional variant is not passed `direction`, `attention` or `multihead`
            PG = DilatedTCNB(
                num_layers=self.num_layers_PG,
                num_f_maps=self.num_f_maps,
                dim=dim,
                block_size=self.block_size_prediction,
                kernel_size=self.kernel_size_prediction,
                rare_dilations=self.rare_dilations,
            )
        else:
            PG = DilatedTCN(
                num_layers=self.num_layers_PG,
                num_f_maps=self.num_f_maps,
                dim=dim,
                direction=self.direction_PG,
                block_size=self.block_size_prediction,
                kernel_size=self.kernel_size_prediction,
                rare_dilations=self.rare_dilations,
                attention=self.PG_attention,
                multihead=self.multihead,
            )
        return PG

    def _feature_extractor(self):
        """
        Create the feature extractor

        Returns an `nn.Identity` or the prediction generation stage when `num_layers_S` is 0; otherwise
        a `SpatialFeatures` module or, if `PG_in_FE` is set, a list of that module and the prediction
        generation stage.
        """

        if self.num_layers_S == 0:
            if self.PG_in_FE:
                print("MS-TCN using the prediction generator as a feature extractor")
                return self._PG()
            else:
                print("MS-TCN without a feature extractor -> no SSL possible!")
                return nn.Identity()

        print("MS-TCN using a spatial feature extractor")
        feature_extractor = SpatialFeatures(
            self.num_layers_S,
            self.num_f_maps,
            self.dim,
            self.block_size_prediction,
        )
        if self.PG_in_FE:
            print(" -> also has the prediction generator as a feature extractor")
            PG = self._PG()
            feature_extractor = [feature_extractor, PG]
        return feature_extractor

    def _predictor(self):
        """
        Create the predictor

        The refinement module is `MSRefinementShared` when `shared_weights` is set and `MSRefinement`
        otherwise; if the prediction generation stage is not part of the feature extractor, it is
        chained in front of the refinement module.
        """

        if self.shared_weights:
            prediction_module = MSRefinementShared
        else:
            prediction_module = MSRefinement
        predictor = prediction_module(
            num_layers_R=int(self.num_layers_R),
            num_R=int(self.num_R),
            num_f_maps_input=int(self.num_f_maps),
            num_f_maps=int(self.num_f_maps_R),
            num_classes=int(self.num_classes),
            dropout_rate=self.dropout_rate,
            exclusive=self.exclusive,
            skip_connections=self.skip_connections_ref,
            direction=self.direction_R,
            block_size=self.block_size_refinement,
            num_heads=self.num_heads,
            attention=self.R_attention,
        )
        if not self.PG_in_FE:
            PG = self._PG()
            predictor = Compiled([PG, predictor])
        return predictor

    def features_shape(self) -> torch.Size:
        """
        Get the shape of feature extractor output

        Returns
        -------
        feature_shape : torch.Size
            shape of feature extractor output
        """

        return torch.Size([self.num_f_maps])
A modification of MS-TCN++ model with additional options
def __init__(
    self,
    num_f_maps,
    num_classes,
    exclusive,
    dims,
    num_layers_R,
    num_R,
    num_layers_PG,
    num_f_maps_R=None,
    num_layers_S=0,
    dropout_rate=0.5,
    shared_weights=False,
    skip_connections_refinement=True,
    block_size_prediction=0,
    block_size_refinement=0,
    kernel_size_prediction=3,
    direction_PG=None,
    direction_R=None,
    PG_in_FE=False,
    rare_dilations=False,
    num_heads=1,
    R_attention="none",
    PG_attention="none",
    state_dict_path=None,
    ssl_constructors=None,
    ssl_types=None,
    ssl_modules=None,
    multihead=False,
    *args,
    **kwargs,
):
    """
    Parameters
    ----------
    num_f_maps : int
        number of feature maps
    num_classes : int
        number of classes to predict
    exclusive : bool
        if `True`, single-label predictions are made; otherwise multi-label
    dims : dict
        a dictionary of feature shapes in the input data (passed to `_get_dims` to compute the
        input dimension)
    num_layers_R : int
        number of layers in the refinement stages
    num_R : int
        number of refinement stages
    num_layers_PG : int
        number of layers in the prediction generation stage
    num_f_maps_R : int, optional
        number of feature maps in the refinement stages (if `None`, `num_f_maps` is used)
    num_layers_S : int, default 0
        number of layers in the spatial feature extraction stage
    dropout_rate : float, default 0.5
        dropout rate
    shared_weights : bool, default False
        if `True`, weights are shared across refinement stages
    skip_connections_refinement : bool, default True
        if `True`, skip connections are added to the refinement stages
    block_size_prediction : int, default 0
        if not 0, skip connections are added to the prediction generation stage with this interval
    block_size_refinement : int, default 0
        if not 0, skip connections are added to the refinement stage with this interval
    kernel_size_prediction : int, default 3
        kernel size for the prediction generation stage
    direction_PG : [None, 'bidirectional', 'forward', 'backward']
        if not `None`, a combination of causal and anticausal convolutions are used in the
        prediction generation stage
    direction_R : [None, 'bidirectional', 'forward', 'backward']
        if not `None`, a combination of causal and anticausal convolutions are used in the refinement stages
    PG_in_FE : bool, default False
        if `True`, the prediction generation stage is included in the feature extractor and otherwise in the
        predictor (the output of the feature extractor is used in SSL tasks)
    rare_dilations : bool, default False
        if `False`, dilation increases every layer, otherwise every second layer in
        the prediction generation stage
    num_heads : int, default 1
        the number of parallel refinement stages
    R_attention : str, default "none"
        the attention setting for the refinement stages
        (NOTE(review): the accepted values are not defined in this constructor -- confirm in the modules)
    PG_attention : str, default "none"
        the attention setting for the prediction generation stage
        (NOTE(review): the accepted values are not defined in this constructor -- confirm in the modules)
    state_dict_path : str, optional
        if not `None`, the model state dictionary will be loaded from this path
    ssl_constructors : list, optional
        a list of `dlc2action.ssl.base_ssl.SSLConstructor` instances to integrate
    ssl_types : list, optional
        a list of types of the SSL modules to integrate (used alternatively to `ssl_constructors`)
    ssl_modules : list, optional
        a list of SSL modules to integrate (used alternatively to `ssl_constructors`)
    multihead : bool, default False
        stored as `self.multihead`
        (NOTE(review): its effect is defined where the modules are built -- confirm there)
    *args, **kwargs
        accepted and ignored
    """

    # The hyperparameters are stored before the parent constructor runs
    # (presumably it needs them to build the modules -- verify against `Model`).
    self.num_layers_R = int(float(num_layers_R))
    self.num_R = int(float(num_R))
    self.num_f_maps = int(float(num_f_maps))
    self.num_classes = int(float(num_classes))
    self.dropout_rate = float(dropout_rate)
    self.exclusive = bool(exclusive)
    self.num_layers_PG = int(float(num_layers_PG))
    self.num_layers_S = int(float(num_layers_S))
    self.dim = self._get_dims(dims)
    self.shared_weights = bool(shared_weights)
    self.skip_connections_ref = bool(skip_connections_refinement)
    self.block_size_prediction = int(float(block_size_prediction))
    self.block_size_refinement = int(float(block_size_refinement))
    self.direction_R = direction_R
    self.direction_PG = direction_PG
    self.kernel_size_prediction = int(float(kernel_size_prediction))
    self.PG_in_FE = PG_in_FE
    self.rare_dilations = rare_dilations
    self.num_heads = int(float(num_heads))
    self.PG_attention = PG_attention
    self.R_attention = R_attention
    self.multihead = multihead
    if num_f_maps_R is None:
        num_f_maps_R = self.num_f_maps
    # normalized like the other integer hyperparameters, so that values such as
    # "64.0" or 64.0 are stored as a plain int instead of failing on a later int() cast
    self.num_f_maps_R = int(float(num_f_maps_R))
    super().__init__(ssl_constructors, ssl_modules, ssl_types, state_dict_path)
Parameters
num_f_maps : int
number of feature maps
num_classes : int
number of classes to predict
exclusive : bool
if `True`, single-label predictions are made; otherwise multi-label
dims : dict
a dictionary of feature shapes in the input data (the first dimensions of its values are summed to get the input dimension)
num_layers_R : int
number of layers in the refinement stages
num_R : int
number of refinement stages
num_layers_PG : int
number of layers in the prediction generation stage
num_layers_S : int, default 0
number of layers in the spatial feature extraction stage
dropout_rate : float, default 0.5
dropout rate
shared_weights : bool, default False
if `True`, weights are shared across refinement stages
skip_connections_refinement : bool, default True
if `True`, skip connections are added to the refinement stages
block_size_prediction : int, default 0
if not 0, skip connections are added to the prediction generation stage with this interval
block_size_refinement : int, default 0
if not 0, skip connections are added to the refinement stage with this interval
direction_PG : [None, 'bidirectional', 'forward', 'backward']
if not `None`, a combination of causal and anticausal convolutions are used in the prediction generation stage
direction_R : [None, 'bidirectional', 'forward', 'backward']
if not `None`, a combination of causal and anticausal convolutions are used in the refinement stages
PG_in_FE : bool, default False
if `True`, the prediction generation stage is included in the feature extractor and otherwise in the predictor (the output of the feature extractor is used in SSL tasks)
rare_dilations : bool, default False
if `False`, dilation increases every layer, otherwise every second layer in the prediction generation stage
num_heads : int, default 1
the number of parallel refinement stages
PG_attention : str, default "none"
the attention setting for the prediction generation stage (with the default `"none"`, presumably no attention layer is added)
R_attention : str, default "none"
the attention setting for the refinement stages (with the default `"none"`, presumably no attention layer is added)
state_dict_path : str, optional
if not `None`, the model state dictionary will be loaded from this path
ssl_constructors : list, optional
a list of `dlc2action.ssl.base_ssl.SSLConstructor` instances to integrate
ssl_types : list, optional
a list of types of the SSL modules to integrate (used alternatively to `ssl_constructors`)
ssl_modules : list, optional
a list of SSL modules to integrate (used alternatively to `ssl_constructors`)
def features_shape(self) -> torch.Size:
    """
    Get the shape of feature extractor output

    Returns
    -------
    feature_shape : torch.Size
        shape of feature extractor output (a single dimension equal to `num_f_maps`)
    """

    feature_shape = torch.Size((self.num_f_maps,))
    return feature_shape
Get the shape of feature extractor output
Returns
feature_shape : torch.Size shape of feature extractor output
Inherited Members
- dlc2action.model.base_model.Model
- process_labels
- feature_extractor
- feature_extractors
- predictor
- ssl_active
- main_task_active
- prompt_function
- class_tensors
- freeze_feature_extractor
- unfreeze_feature_extractor
- load_state_dict
- ssl_off
- ssl_on
- main_task_on
- main_task_off
- set_ssl
- extract_features
- transform_labels
- forward
class MS_TCN_P(MS_TCN3):
    """
    A variation of `MS_TCN3` that keeps a separate input dimension for every feature group
    and uses `MultiDilatedTCN` as the prediction generation stage
    """

    def _get_dims(self, dims):
        """
        Get the list of input dimensions, one per feature group

        The group of a key is the part after its last `"---"` separator; the first dimensions of all
        shapes in a group are summed and the totals are ordered by sorted group name. If `dims` has a
        `"loaded"` key, its first dimension is appended at the end.
        """

        totals = {}
        for key, shape in dims.items():
            group = key.split("---")[-1]
            totals[group] = totals.get(group, 0) + shape[0]
        res = [int(totals[group]) for group in sorted(totals)]
        # NOTE(review): a "loaded" key also forms its own group above, so its size
        # appears twice in the result -- confirm that this is what MultiDilatedTCN expects
        if "loaded" in dims:
            res.append(int(dims["loaded"][0]))
        return res

    def _PG(self):
        """
        Create the prediction generation stage (a `MultiDilatedTCN` over the per-group dimensions)
        """

        return MultiDilatedTCN(
            self.num_layers_PG,
            self.num_f_maps,
            self.dim,
            self.direction_PG,
            self.block_size_prediction,
            self.kernel_size_prediction,
            self.rare_dilations,
        )
A modification of MS-TCN++ model with additional options
Inherited Members
- MS_TCN3
- MS_TCN3
- num_layers_R
- num_R
- num_f_maps
- num_classes
- dropout_rate
- exclusive
- num_layers_PG
- num_layers_S
- dim
- skip_connections_ref
- block_size_prediction
- block_size_refinement
- direction_R
- direction_PG
- kernel_size_prediction
- PG_in_FE
- rare_dilations
- num_heads
- PG_attention
- R_attention
- multihead
- num_f_maps_R
- features_shape
- dlc2action.model.base_model.Model
- process_labels
- feature_extractor
- feature_extractors
- predictor
- ssl_active
- main_task_active
- prompt_function
- class_tensors
- freeze_feature_extractor
- unfreeze_feature_extractor
- load_state_dict
- ssl_off
- ssl_on
- main_task_on
- main_task_off
- set_ssl
- extract_features
- transform_labels
- forward