#
# Copyright 2020-present by A. Mathis Group and contributors. All rights reserved.
#
# This project and all its files are licensed under GNU AGPLv3 or later version. A copy is included in dlc2action/LICENSE.AGPL.
#
# Incorporates code adapted from MS-TCN++ by yabufarha
# Original work Copyright (c) 2019 June01
# Source: https://github.com/sj-li/MS-TCN2
# Originally licensed under MIT License
# Combined work licensed under GNU AGPLv3
#
"""
MS-TCN++ (multi-stage temporal convolutional network) variations
"""

from dlc2action.model.base_model import Model
from dlc2action.model.ms_tcn_modules import *


class Compiled(nn.Module):
    """Chain several sub-modules, threading an optional tag through each call."""

    def __init__(self, modules):
        """
        Parameters
        ----------
        modules : list
            modules to apply in order; each must accept ``(x, tag)``
        """
        super(Compiled, self).__init__()
        self.module_list = nn.ModuleList(modules)

    def forward(self, x, tag=None):
        """Apply every stored module to `x` sequentially, passing `tag` along."""
        for m in self.module_list:
            x = m(x, tag)
        return x


class MS_TCN3(Model):
    """
    A modification of MS-TCN++ model with additional options
    """

    def __init__(
        self,
        num_f_maps,
        num_classes,
        exclusive,
        dims,
        num_layers_R,
        num_R,
        num_layers_PG,
        num_f_maps_R=None,
        num_layers_S=0,
        dropout_rate=0.5,
        shared_weights=False,
        skip_connections_refinement=True,
        block_size_prediction=0,
        block_size_refinement=0,
        kernel_size_prediction=3,
        direction_PG=None,
        direction_R=None,
        PG_in_FE=False,
        rare_dilations=False,
        num_heads=1,
        R_attention="none",
        PG_attention="none",
        state_dict_path=None,
        ssl_constructors=None,
        ssl_types=None,
        ssl_modules=None,
        multihead=False,
        *args,
        **kwargs,
    ):
        """
        Parameters
        ----------
        num_f_maps : int
            number of feature maps
        num_classes : int
            number of classes to predict
        exclusive : bool
            if `True`, single-label predictions are made; otherwise multi-label
        dims : dict
            shape of features in the input data (a mapping from feature key to shape;
            the first element of each shape is summed to get the input dimension)
        num_layers_R : int
            number of layers in the refinement stages
        num_R : int
            number of refinement stages
        num_layers_PG : int
            number of layers in the prediction generation stage
        num_f_maps_R : int, optional
            number of feature maps in the refinement stages (defaults to `num_f_maps`)
        num_layers_S : int, default 0
            number of layers in the spatial feature extraction stage
        dropout_rate : float, default 0.5
            dropout rate
        shared_weights : bool, default False
            if `True`, weights are shared across refinement stages
        skip_connections_refinement : bool, default True
            if `True`, skip connections are added to the refinement stages
        block_size_prediction : int, default 0
            if not 0, skip connections are added to the prediction generation stage with this interval
        block_size_refinement : int, default 0
            if not 0, skip connections are added to the refinement stage with this interval
        kernel_size_prediction : int, default 3
            kernel size in the prediction generation stage
        direction_PG : [None, 'bidirectional', 'forward', 'backward']
            if not `None`, a combination of causal and anticausal convolutions are used in the
            prediction generation stage
        direction_R : [None, 'bidirectional', 'forward', 'backward']
            if not `None`, a combination of causal and anticausal convolutions are used in the refinement stages
        PG_in_FE : bool, default False
            if `True`, the prediction generation stage is included in the feature extractor and otherwise in the
            predictor (the output of the feature extractor is used in SSL tasks)
        rare_dilations : bool, default False
            if `False`, dilation increases every layer, otherwise every second layer in
            the prediction generation stage
        num_heads : int, default 1
            the number of parallel refinement stages
        R_attention : str, default "none"
            attention type for the refinement stages ("none" disables attention)
        PG_attention : str, default "none"
            attention type for the prediction generation stage ("none" disables attention)
        state_dict_path : str, optional
            if not `None`, the model state dictionary will be loaded from this path
        ssl_constructors : list, optional
            a list of `dlc2action.ssl.base_ssl.SSLConstructor` instances to integrate
        ssl_types : list, optional
            a list of types of the SSL modules to integrate (used alternatively to `ssl_constructors`)
        ssl_modules : list, optional
            a list of SSL modules to integrate (used alternatively to `ssl_constructors`)
        multihead : bool, default False
            passed to the prediction generation stage
            # NOTE(review): semantics defined in DilatedTCN (ms_tcn_modules) — confirm there
        """

        # Hyperparameters may arrive as strings (e.g. "5" or "5.0" from config files),
        # hence the int(float(...)) round-trips below.
        self.num_layers_R = int(float(num_layers_R))
        self.num_R = int(float(num_R))
        self.num_f_maps = int(float(num_f_maps))
        self.num_classes = int(float(num_classes))
        self.dropout_rate = float(dropout_rate)
        self.exclusive = bool(exclusive)
        self.num_layers_PG = int(float(num_layers_PG))
        self.num_layers_S = int(float(num_layers_S))
        self.dim = self._get_dims(dims)
        self.shared_weights = bool(shared_weights)
        self.skip_connections_ref = bool(skip_connections_refinement)
        self.block_size_prediction = int(float(block_size_prediction))
        self.block_size_refinement = int(float(block_size_refinement))
        self.direction_R = direction_R
        self.direction_PG = direction_PG
        self.kernel_size_prediction = int(float(kernel_size_prediction))
        self.PG_in_FE = PG_in_FE
        self.rare_dilations = rare_dilations
        self.num_heads = int(float(num_heads))
        self.PG_attention = PG_attention
        self.R_attention = R_attention
        self.multihead = multihead
        if num_f_maps_R is None:
            num_f_maps_R = self.num_f_maps
        self.num_f_maps_R = num_f_maps_R
        super().__init__(ssl_constructors, ssl_modules, ssl_types, state_dict_path)

    def _get_dims(self, dims):
        """Sum the first element of every shape in `dims` to get the input dimension."""
        return int(sum(s[0] for s in dims.values()))

    def _PG(self):
        """Construct the prediction generation stage."""
        # With no spatial stage the PG consumes raw features; otherwise it consumes
        # the spatial extractor's output channels.
        if self.num_layers_S == 0:
            dim = self.dim
        else:
            dim = self.num_f_maps
        if self.direction_PG == "bidirectional":
            PG = DilatedTCNB(
                num_layers=self.num_layers_PG,
                num_f_maps=self.num_f_maps,
                dim=dim,
                block_size=self.block_size_prediction,
                kernel_size=self.kernel_size_prediction,
                rare_dilations=self.rare_dilations,
            )
        else:
            PG = DilatedTCN(
                num_layers=self.num_layers_PG,
                num_f_maps=self.num_f_maps,
                dim=dim,
                direction=self.direction_PG,
                block_size=self.block_size_prediction,
                kernel_size=self.kernel_size_prediction,
                rare_dilations=self.rare_dilations,
                attention=self.PG_attention,
                multihead=self.multihead,
            )
        return PG

    def _feature_extractor(self):
        """Construct the feature extractor (spatial stage and/or prediction generator)."""
        if self.num_layers_S == 0:
            if self.PG_in_FE:
                print("MS-TCN using the prediction generator as a feature extractor")
                return self._PG()
            else:
                print("MS-TCN without a feature extractor -> no SSL possible!")
                return nn.Identity()

        print("MS-TCN using a spatial feature extractor")
        feature_extractor = SpatialFeatures(
            self.num_layers_S,
            self.num_f_maps,
            self.dim,
            self.block_size_prediction,
        )
        if self.PG_in_FE:
            print(" -> also has the prediction generator as a feature extractor")
            PG = self._PG()
            feature_extractor = [feature_extractor, PG]
        return feature_extractor

    def _predictor(self):
        """Construct the predictor (refinement stages, optionally preceded by the PG)."""
        if self.shared_weights:
            prediction_module = MSRefinementShared
        else:
            prediction_module = MSRefinement
        predictor = prediction_module(
            num_layers_R=int(self.num_layers_R),
            num_R=int(self.num_R),
            num_f_maps_input=int(self.num_f_maps),
            num_f_maps=int(self.num_f_maps_R),
            num_classes=int(self.num_classes),
            dropout_rate=self.dropout_rate,
            exclusive=self.exclusive,
            skip_connections=self.skip_connections_ref,
            direction=self.direction_R,
            block_size=self.block_size_refinement,
            num_heads=self.num_heads,
            attention=self.R_attention,
        )
        if not self.PG_in_FE:
            # PG not in the feature extractor -> chain it in front of the refinement stages
            PG = self._PG()
            predictor = Compiled([PG, predictor])
        return predictor

    def features_shape(self) -> torch.Size:
        """
        Get the shape of feature extractor output

        Returns
        -------
        feature_shape : torch.Size
            shape of feature extractor output
        """

        return torch.Size([self.num_f_maps])


class MS_TCN_P(MS_TCN3):
    """MS-TCN3 variant that groups input features and uses a multi-group prediction stage."""

    def _get_dims(self, dims):
        """Return per-group input dimensions, grouping feature keys by their '---' suffix."""
        keys = list(dims.keys())
        values = list(dims.values())
        groups = [key.split("---")[-1] for key in keys]
        unique_groups = sorted(set(groups))
        res = []
        for group in unique_groups:
            res.append(int(sum(x[0] for x, g in zip(values, groups) if g == group)))
        # NOTE(review): a "loaded" key is also counted by the group loop above, so its
        # dimension appears twice in the result — confirm this is intentional.
        if "loaded" in dims:
            res.append(int(dims["loaded"][0]))
        return res

    def _PG(self):
        """Construct the multi-group prediction generation stage."""
        PG = MultiDilatedTCN(
            self.num_layers_PG,
            self.num_f_maps,
            self.dim,
            self.direction_PG,
            self.block_size_prediction,
            self.kernel_size_prediction,
            self.rare_dilations,
        )
        return PG
class Compiled(nn.Module):
    """Sequentially apply a list of modules, threading an optional tag through each."""

    def __init__(self, modules):
        """
        Parameters
        ----------
        modules : list
            modules to apply in order; each must accept ``(x, tag)``
        """
        super().__init__()
        self.module_list = nn.ModuleList(modules)

    def forward(self, x, tag=None):
        """Forward pass."""
        output = x
        for stage in self.module_list:
            output = stage(output, tag)
        return output
Base class for all neural network modules.
Your models should also subclass this class.
Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
def __init__(self) -> None:
super().__init__()
self.conv1 = nn.Conv2d(1, 20, 5)
self.conv2 = nn.Conv2d(20, 20, 5)
def forward(self, x):
x = F.relu(self.conv1(x))
return F.relu(self.conv2(x))
Submodules assigned in this way will be registered, and will also have their
parameters converted when you call to(), etc.
As per the example above, an __init__() call to the parent class
must be made before assignment on the child.
:ivar training: Boolean representing whether this module is in training or evaluation mode. :vartype training: bool
class MS_TCN3(Model):
    """
    A modification of MS-TCN++ model with additional options
    """

    def __init__(
        self,
        num_f_maps,
        num_classes,
        exclusive,
        dims,
        num_layers_R,
        num_R,
        num_layers_PG,
        num_f_maps_R=None,
        num_layers_S=0,
        dropout_rate=0.5,
        shared_weights=False,
        skip_connections_refinement=True,
        block_size_prediction=0,
        block_size_refinement=0,
        kernel_size_prediction=3,
        direction_PG=None,
        direction_R=None,
        PG_in_FE=False,
        rare_dilations=False,
        num_heads=1,
        R_attention="none",
        PG_attention="none",
        state_dict_path=None,
        ssl_constructors=None,
        ssl_types=None,
        ssl_modules=None,
        multihead=False,
        *args,
        **kwargs,
    ):
        """
        Parameters
        ----------
        num_f_maps : int
            number of feature maps
        num_classes : int
            number of classes to predict
        exclusive : bool
            if `True`, single-label predictions are made; otherwise multi-label
        dims : torch.Size
            shape of features in the input data
        num_layers_R : int
            number of layers in the refinement stages
        num_R : int
            number of refinement stages
        num_layers_PG : int
            number of layers in the prediction generation stage
        num_f_maps_R : int, optional
            number of feature maps in the refinement stages (defaults to `num_f_maps`)
        num_layers_S : int, default 0
            number of layers in the spatial feature extraction stage
        dropout_rate : float, default 0.5
            dropout rate
        shared_weights : bool, default False
            if `True`, weights are shared across refinement stages
        skip_connections_refinement : bool, default True
            if `True`, skip connections are added to the refinement stages
        block_size_prediction : int, default 0
            if not 0, skip connections are added to the prediction generation stage with this interval
        block_size_refinement : int, default 0
            if not 0, skip connections are added to the refinement stage with this interval
        kernel_size_prediction : int, default 3
            kernel size of the convolutions in the prediction generation stage
        direction_PG : [None, 'bidirectional', 'forward', 'backward']
            if not `None`, a combination of causal and anticausal convolutions are used in the
            prediction generation stage
        direction_R : [None, 'bidirectional', 'forward', 'backward']
            if not `None`, a combination of causal and anticausal convolutions are used in the refinement stages
        PG_in_FE : bool, default False
            if `True`, the prediction generation stage is included in the feature extractor and otherwise in the
            predictor (the output of the feature extractor is used in SSL tasks)
        rare_dilations : bool, default False
            if `False`, dilation increases every layer, otherwise every second layer in
            the prediction generation stage
        num_heads : int, default 1
            the number of parallel refinement stages
        R_attention : str, default "none"
            if not `"none"`, an attention layer of this type is added to the refinement stages
        PG_attention : str, default "none"
            if not `"none"`, an attention layer of this type is added to the prediction generation stage
        state_dict_path : str, optional
            if not `None`, the model state dictionary will be loaded from this path
        ssl_constructors : list, optional
            a list of `dlc2action.ssl.base_ssl.SSLConstructor` instances to integrate
        ssl_types : list, optional
            a list of types of the SSL modules to integrate (used alternatively to `ssl_constructors`)
        ssl_modules : list, optional
            a list of SSL modules to integrate (used alternatively to `ssl_constructors`)
        multihead : bool, default False
            forwarded to the prediction generation stage; presumably enables multi-head
            attention there -- TODO(review): confirm against `DilatedTCN`
        """

        # Config values may arrive as numeric strings or floats, hence the
        # float -> int double cast on every integer option.
        self.num_layers_R = int(float(num_layers_R))
        self.num_R = int(float(num_R))
        self.num_f_maps = int(float(num_f_maps))
        self.num_classes = int(float(num_classes))
        self.dropout_rate = float(dropout_rate)
        self.exclusive = bool(exclusive)
        self.num_layers_PG = int(float(num_layers_PG))
        self.num_layers_S = int(float(num_layers_S))
        self.dim = self._get_dims(dims)
        self.shared_weights = bool(shared_weights)
        self.skip_connections_ref = bool(skip_connections_refinement)
        self.block_size_prediction = int(float(block_size_prediction))
        self.block_size_refinement = int(float(block_size_refinement))
        self.direction_R = direction_R
        self.direction_PG = direction_PG
        self.kernel_size_prediction = int(float(kernel_size_prediction))
        self.PG_in_FE = PG_in_FE
        self.rare_dilations = rare_dilations
        self.num_heads = int(float(num_heads))
        self.PG_attention = PG_attention
        self.R_attention = R_attention
        self.multihead = multihead
        if num_f_maps_R is None:
            num_f_maps_R = self.num_f_maps
        self.num_f_maps_R = num_f_maps_R
        # All attributes must be set before the base constructor runs, since it
        # builds the submodules via _feature_extractor() / _predictor().
        super().__init__(ssl_constructors, ssl_modules, ssl_types, state_dict_path)

    def _get_dims(self, dims):
        # Total input channel count: sum of the first dimension of every entry.
        return int(sum([s[0] for s in dims.values()]))

    def _PG(self):
        """Build the prediction generation stage."""
        # With no spatial stage the PG consumes raw features; otherwise it
        # consumes the spatial extractor's feature maps.
        if self.num_layers_S == 0:
            dim = self.dim
        else:
            dim = self.num_f_maps
        if self.direction_PG == "bidirectional":
            PG = DilatedTCNB(
                num_layers=self.num_layers_PG,
                num_f_maps=self.num_f_maps,
                dim=dim,
                block_size=self.block_size_prediction,
                kernel_size=self.kernel_size_prediction,
                rare_dilations=self.rare_dilations,
            )
        else:
            PG = DilatedTCN(
                num_layers=self.num_layers_PG,
                num_f_maps=self.num_f_maps,
                dim=dim,
                direction=self.direction_PG,
                block_size=self.block_size_prediction,
                kernel_size=self.kernel_size_prediction,
                rare_dilations=self.rare_dilations,
                attention=self.PG_attention,
                multihead=self.multihead,
            )
        return PG

    def _feature_extractor(self):
        """Build the feature extractor (identity, PG, spatial stage, or spatial + PG)."""
        if self.num_layers_S == 0:
            if self.PG_in_FE:
                print("MS-TCN using the prediction generator as a feature extractor")
                return self._PG()
            else:
                print("MS-TCN without a feature extractor -> no SSL possible!")
                return nn.Identity()

        print("MS-TCN using a spatial feature extractor")
        feature_extractor = SpatialFeatures(
            self.num_layers_S,
            self.num_f_maps,
            self.dim,
            self.block_size_prediction,
        )
        if self.PG_in_FE:
            print(" -> also has the prediction generator as a feature extractor")
            PG = self._PG()
            feature_extractor = [feature_extractor, PG]
        return feature_extractor

    def _predictor(self):
        """Build the refinement predictor, prepending the PG stage when it is not in the feature extractor."""
        if self.shared_weights:
            prediction_module = MSRefinementShared
        else:
            prediction_module = MSRefinement
        predictor = prediction_module(
            num_layers_R=int(self.num_layers_R),
            num_R=int(self.num_R),
            num_f_maps_input=int(self.num_f_maps),
            num_f_maps=int(self.num_f_maps_R),
            num_classes=int(self.num_classes),
            dropout_rate=self.dropout_rate,
            exclusive=self.exclusive,
            skip_connections=self.skip_connections_ref,
            direction=self.direction_R,
            block_size=self.block_size_refinement,
            num_heads=self.num_heads,
            attention=self.R_attention,
        )
        if not self.PG_in_FE:
            PG = self._PG()
            predictor = Compiled([PG, predictor])
        return predictor

    def features_shape(self) -> torch.Size:
        """
        Get the shape of feature extractor output

        Returns
        -------
        feature_shape : torch.Size
            shape of feature extractor output
        """

        return torch.Size([self.num_f_maps])
A modification of MS-TCN++ model with additional options
def __init__(
    self,
    num_f_maps,
    num_classes,
    exclusive,
    dims,
    num_layers_R,
    num_R,
    num_layers_PG,
    num_f_maps_R=None,
    num_layers_S=0,
    dropout_rate=0.5,
    shared_weights=False,
    skip_connections_refinement=True,
    block_size_prediction=0,
    block_size_refinement=0,
    kernel_size_prediction=3,
    direction_PG=None,
    direction_R=None,
    PG_in_FE=False,
    rare_dilations=False,
    num_heads=1,
    R_attention="none",
    PG_attention="none",
    state_dict_path=None,
    ssl_constructors=None,
    ssl_types=None,
    ssl_modules=None,
    multihead=False,
    *args,
    **kwargs,
):
    """
    Parameters
    ----------
    num_f_maps : int
        number of feature maps
    num_classes : int
        number of classes to predict
    exclusive : bool
        if `True`, single-label predictions are made; otherwise multi-label
    dims : torch.Size
        shape of features in the input data
    num_layers_R : int
        number of layers in the refinement stages
    num_R : int
        number of refinement stages
    num_layers_PG : int
        number of layers in the prediction generation stage
    num_f_maps_R : int, optional
        number of feature maps in the refinement stages (defaults to `num_f_maps`)
    num_layers_S : int, default 0
        number of layers in the spatial feature extraction stage
    dropout_rate : float, default 0.5
        dropout rate
    shared_weights : bool, default False
        if `True`, weights are shared across refinement stages
    skip_connections_refinement : bool, default True
        if `True`, skip connections are added to the refinement stages
    block_size_prediction : int, default 0
        if not 0, skip connections are added to the prediction generation stage with this interval
    block_size_refinement : int, default 0
        if not 0, skip connections are added to the refinement stage with this interval
    kernel_size_prediction : int, default 3
        kernel size of the convolutions in the prediction generation stage
    direction_PG : [None, 'bidirectional', 'forward', 'backward']
        if not `None`, a combination of causal and anticausal convolutions are used in the
        prediction generation stage
    direction_R : [None, 'bidirectional', 'forward', 'backward']
        if not `None`, a combination of causal and anticausal convolutions are used in the refinement stages
    PG_in_FE : bool, default False
        if `True`, the prediction generation stage is included in the feature extractor and otherwise in the
        predictor (the output of the feature extractor is used in SSL tasks)
    rare_dilations : bool, default False
        if `False`, dilation increases every layer, otherwise every second layer in
        the prediction generation stage
    num_heads : int, default 1
        the number of parallel refinement stages
    R_attention : str, default "none"
        if not `"none"`, an attention layer of this type is added to the refinement stages
    PG_attention : str, default "none"
        if not `"none"`, an attention layer of this type is added to the prediction generation stage
    state_dict_path : str, optional
        if not `None`, the model state dictionary will be loaded from this path
    ssl_constructors : list, optional
        a list of `dlc2action.ssl.base_ssl.SSLConstructor` instances to integrate
    ssl_types : list, optional
        a list of types of the SSL modules to integrate (used alternatively to `ssl_constructors`)
    ssl_modules : list, optional
        a list of SSL modules to integrate (used alternatively to `ssl_constructors`)
    multihead : bool, default False
        forwarded to the prediction generation stage; presumably enables multi-head
        attention there -- TODO(review): confirm against `DilatedTCN`
    """

    # Config values may arrive as numeric strings or floats, hence the
    # float -> int double cast on every integer option.
    self.num_layers_R = int(float(num_layers_R))
    self.num_R = int(float(num_R))
    self.num_f_maps = int(float(num_f_maps))
    self.num_classes = int(float(num_classes))
    self.dropout_rate = float(dropout_rate)
    self.exclusive = bool(exclusive)
    self.num_layers_PG = int(float(num_layers_PG))
    self.num_layers_S = int(float(num_layers_S))
    self.dim = self._get_dims(dims)
    self.shared_weights = bool(shared_weights)
    self.skip_connections_ref = bool(skip_connections_refinement)
    self.block_size_prediction = int(float(block_size_prediction))
    self.block_size_refinement = int(float(block_size_refinement))
    self.direction_R = direction_R
    self.direction_PG = direction_PG
    self.kernel_size_prediction = int(float(kernel_size_prediction))
    self.PG_in_FE = PG_in_FE
    self.rare_dilations = rare_dilations
    self.num_heads = int(float(num_heads))
    self.PG_attention = PG_attention
    self.R_attention = R_attention
    self.multihead = multihead
    if num_f_maps_R is None:
        num_f_maps_R = self.num_f_maps
    self.num_f_maps_R = num_f_maps_R
    # All attributes must be set before the base constructor runs, since it
    # builds the submodules via _feature_extractor() / _predictor().
    super().__init__(ssl_constructors, ssl_modules, ssl_types, state_dict_path)
Parameters
num_f_maps : int
number of feature maps
num_classes : int
number of classes to predict
exclusive : bool
if True, single-label predictions are made; otherwise multi-label
dims : torch.Size
shape of features in the input data
num_layers_R : int
number of layers in the refinement stages
num_R : int
number of refinement stages
num_layers_PG : int
number of layers in the prediction generation stage
num_layers_S : int, default 0
number of layers in the spatial feature extraction stage
dropout_rate : float, default 0.5
dropout rate
shared_weights : bool, default False
if True, weights are shared across refinement stages
skip_connections_refinement : bool, default True
if True, skip connections are added to the refinement stages
block_size_prediction : int, default 0
if not 0, skip connections are added to the prediction generation stage with this interval
block_size_refinement : int, default 0
if not 0, skip connections are added to the refinement stage with this interval
direction_PG : [None, 'bidirectional', 'forward', 'backward']
if not None, a combination of causal and anticausal convolutions are used in the
prediction generation stage
direction_R : [None, 'bidirectional', 'forward', 'backward']
if not None, a combination of causal and anticausal convolutions are used in the refinement stages
PG_in_FE : bool, default False
if True, the prediction generation stage is included in the feature extractor and otherwise in the
predictor (the output of the feature extractor is used in SSL tasks)
rare_dilations : bool, default False
if False, dilation increases every layer, otherwise every second layer in
the prediction generation stage
num_heads : int, default 1
the number of parallel refinement stages
PG_attention : str, default "none"
if not "none", an attention layer of this type is added to the prediction generation stage
R_attention : str, default "none"
if not "none", an attention layer of this type is added to the refinement stages
state_dict_path : str, optional
if not None, the model state dictionary will be loaded from this path
ssl_constructors : list, optional
a list of dlc2action.ssl.base_ssl.SSLConstructor instances to integrate
ssl_types : list, optional
a list of types of the SSL modules to integrate (used alternatively to ssl_constructors)
ssl_modules : list, optional
a list of SSL modules to integrate (used alternatively to ssl_constructors)
230 def features_shape(self) -> torch.Size: 231 """ 232 Get the shape of feature extractor output 233 234 Returns 235 ------- 236 feature_shape : torch.Size 237 shape of feature extractor output 238 """ 239 240 return torch.Size([self.num_f_maps])
Get the shape of feature extractor output
Returns
feature_shape : torch.Size shape of feature extractor output
Inherited Members
- dlc2action.model.base_model.Model
- process_labels
- feature_extractor
- feature_extractors
- predictor
- ssl_active
- main_task_active
- prompt_function
- class_tensors
- freeze_feature_extractor
- unfreeze_feature_extractor
- load_state_dict
- ssl_off
- ssl_on
- main_task_on
- main_task_off
- set_ssl
- extract_features
- transform_labels
- forward
class MS_TCN_P(MS_TCN3):
    """MS-TCN3 variant that keeps per-group input channel counts separate."""

    def _get_dims(self, dims):
        """Return per-group channel sums, ordered by group name, plus 'loaded' if present."""
        # Group keys by the suffix after the last '---' separator and
        # accumulate each group's first-dimension sizes.
        per_group = {}
        for key, shape in dims.items():
            group = key.split("---")[-1]
            per_group[group] = per_group.get(group, 0) + shape[0]
        sizes = [int(per_group[group]) for group in sorted(per_group)]
        # The 'loaded' entry is appended once more on top of its group sum,
        # matching the original ordering contract.
        if "loaded" in dims:
            sizes.append(int(dims["loaded"][0]))
        return sizes

    def _PG(self):
        """Build a multi-stream prediction generation stage (one stream per input group)."""
        return MultiDilatedTCN(
            self.num_layers_PG,
            self.num_f_maps,
            self.dim,
            self.direction_PG,
            self.block_size_prediction,
            self.kernel_size_prediction,
            self.rare_dilations,
        )
A modification of MS-TCN++ model with additional options
Inherited Members
- MS_TCN3
- MS_TCN3
- num_layers_R
- num_R
- num_f_maps
- num_classes
- dropout_rate
- exclusive
- num_layers_PG
- num_layers_S
- dim
- skip_connections_ref
- block_size_prediction
- block_size_refinement
- direction_R
- direction_PG
- kernel_size_prediction
- PG_in_FE
- rare_dilations
- num_heads
- PG_attention
- R_attention
- multihead
- num_f_maps_R
- features_shape
- dlc2action.model.base_model.Model
- process_labels
- feature_extractor
- feature_extractors
- predictor
- ssl_active
- main_task_active
- prompt_function
- class_tensors
- freeze_feature_extractor
- unfreeze_feature_extractor
- load_state_dict
- ssl_off
- ssl_on
- main_task_on
- main_task_off
- set_ssl
- extract_features
- transform_labels
- forward