dlc2action.model.ms_tcn

MS-TCN++ (multi-stage temporal convolutional network) variations

  1#
  2# Copyright 2020-present by A. Mathis Group and contributors. All rights reserved.
  3#
  4# This project and all its files are licensed under GNU AGPLv3 or later version. A copy is included in dlc2action/LICENSE.AGPL.
  5#
  6# Incorporates code adapted from MS-TCN++ by yabufarha
  7# Original work Copyright (c) 2019 June01
  8# Source: https://github.com/sj-li/MS-TCN2
  9# Originally licensed under MIT License
 10# Combined work licensed under GNU AGPLv3
 11#
 12"""
 13MS-TCN++ (multi-stage temporal convolutional network) variations
 14"""
 15
 16from dlc2action.model.base_model import Model
 17from dlc2action.model.ms_tcn_modules import *
 18
 19
 20class Compiled(nn.Module):
 21    def __init__(self, modules):
 22        super(Compiled, self).__init__()
 23        self.module_list = nn.ModuleList(modules)
 24
 25    def forward(self, x, tag=None):
 26        """Forward pass."""
 27        for m in self.module_list:
 28            x = m(x, tag)
 29        return x
 30
 31
 32class MS_TCN3(Model):
 33    """
 34    A modification of MS-TCN++ model with additional options
 35    """
 36
 37    def __init__(
 38        self,
 39        num_f_maps,
 40        num_classes,
 41        exclusive,
 42        dims,
 43        num_layers_R,
 44        num_R,
 45        num_layers_PG,
 46        num_f_maps_R=None,
 47        num_layers_S=0,
 48        dropout_rate=0.5,
 49        shared_weights=False,
 50        skip_connections_refinement=True,
 51        block_size_prediction=0,
 52        block_size_refinement=0,
 53        kernel_size_prediction=3,
 54        direction_PG=None,
 55        direction_R=None,
 56        PG_in_FE=False,
 57        rare_dilations=False,
 58        num_heads=1,
 59        R_attention="none",
 60        PG_attention="none",
 61        state_dict_path=None,
 62        ssl_constructors=None,
 63        ssl_types=None,
 64        ssl_modules=None,
 65        multihead=False,
 66        *args,
 67        **kwargs,
 68    ):
 69        """
 70        Parameters
 71        ----------
 72        num_f_maps : int
 73            number of feature maps
 74        num_classes : int
 75            number of classes to predict
 76        exclusive : bool
 77            if `True`, single-label predictions are made; otherwise multi-label
 78        dims : torch.Size
 79            shape of features in the input data
 80        num_layers_R : int
 81            number of layers in the refinement stages
 82        num_R : int
 83            number of refinement stages
 84        num_layers_PG : int
 85            number of layers in the prediction generation stage
 86        num_layers_S : int, default 0
 87            number of layers in the spatial feature extraction stage
 88        dropout_rate : float, default 0.5
 89            dropout rate
 90        shared_weights : bool, default False
 91            if `True`, weights are shared across refinement stages
 92        skip_connections_refinement : bool, default False
 93            if `True`, skip connections are added to the refinement stages
 94        block_size_prediction : int, default 0
 95            if not 0, skip connections are added to the prediction generation stage with this interval
 96        block_size_refinement : int, default 0
 97            if not 0, skip connections are added to the refinement stage with this interval
 98        direction_PG : [None, 'bidirectional', 'forward', 'backward']
 99            if not `None`, a combination of causal and anticausal convolutions are used in the
100            prediction generation stage
101        direction_R : [None, 'bidirectional', 'forward', 'backward']
102            if not `None`, a combination of causal and anticausal convolutions are used in the refinement stages
103        PG_in_FE : bool, default True
104            if `True`, the prediction generation stage is included in the feature extractor and otherwise in the
105            predictor (the output of the feature extractor is used in SSL tasks)
106        rare_dilations : bool, default False
107            if `False`, dilation increases every layer, otherwise every second layer in
108            the prediction generation stage
109        num_heads : int, default 1
110            the number of parallel refinement stages
111        PG_attention : bool, default False
112            if `True`, an attention layer is added to the prediction generation stage
113        R_attention : bool, default False
114            if `True`, an attention layer is added to the refinement stages
115        state_dict_path : str, optional
116            if not `None`, the model state dictionary will be loaded from this path
117        ssl_constructors : list, optional
118            a list of `dlc2action.ssl.base_ssl.SSLConstructor` instances to integrate
119        ssl_types : list, optional
120            a list of types of the SSL modules to integrate (used alternatively to `ssl_constructors`)
121        ssl_modules : list, optional
122            a list of SSL modules to integrate (used alternatively to `ssl_constructors`)
123        """
124
125        self.num_layers_R = int(float(num_layers_R))
126        self.num_R = int(float(num_R))
127        self.num_f_maps = int(float(num_f_maps))
128        self.num_classes = int(float(num_classes))
129        self.dropout_rate = float(dropout_rate)
130        self.exclusive = bool(exclusive)
131        self.num_layers_PG = int(float(num_layers_PG))
132        self.num_layers_S = int(float(num_layers_S))
133        self.dim = self._get_dims(dims)
134        self.shared_weights = bool(shared_weights)
135        self.skip_connections_ref = bool(skip_connections_refinement)
136        self.block_size_prediction = int(float(block_size_prediction))
137        self.block_size_refinement = int(float(block_size_refinement))
138        self.direction_R = direction_R
139        self.direction_PG = direction_PG
140        self.kernel_size_prediction = int(float(kernel_size_prediction))
141        self.PG_in_FE = PG_in_FE
142        self.rare_dilations = rare_dilations
143        self.num_heads = int(float(num_heads))
144        self.PG_attention = PG_attention
145        self.R_attention = R_attention
146        self.multihead = multihead
147        if num_f_maps_R is None:
148            num_f_maps_R = self.num_f_maps
149        self.num_f_maps_R = num_f_maps_R
150        super().__init__(ssl_constructors, ssl_modules, ssl_types, state_dict_path)
151
152    def _get_dims(self, dims):
153        return int(sum([s[0] for s in dims.values()]))
154
155    def _PG(self):
156        if self.num_layers_S == 0:
157            dim = self.dim
158        else:
159            dim = self.num_f_maps
160        if self.direction_PG == "bidirectional":
161            PG = DilatedTCNB(
162                num_layers=self.num_layers_PG,
163                num_f_maps=self.num_f_maps,
164                dim=dim,
165                block_size=self.block_size_prediction,
166                kernel_size=self.kernel_size_prediction,
167                rare_dilations=self.rare_dilations,
168            )
169        else:
170            PG = DilatedTCN(
171                num_layers=self.num_layers_PG,
172                num_f_maps=self.num_f_maps,
173                dim=dim,
174                direction=self.direction_PG,
175                block_size=self.block_size_prediction,
176                kernel_size=self.kernel_size_prediction,
177                rare_dilations=self.rare_dilations,
178                attention=self.PG_attention,
179                multihead=self.multihead,
180            )
181        return PG
182
183    def _feature_extractor(self):
184        if self.num_layers_S == 0:
185            if self.PG_in_FE:
186                print("MS-TCN using the prediction generator as a feature extractor")
187                return self._PG()
188            else:
189                print("MS-TCN without a feature extractor -> no SSL possible!")
190                return nn.Identity()
191
192        print("MS-TCN using a spatial feature extractor")
193        feature_extractor = SpatialFeatures(
194            self.num_layers_S,
195            self.num_f_maps,
196            self.dim,
197            self.block_size_prediction,
198        )
199        if self.PG_in_FE:
200            print("  -> also has the prediction generator as a feature extractor")
201            PG = self._PG()
202            feature_extractor = [feature_extractor, PG]
203        return feature_extractor
204
205    def _predictor(self):
206        if self.shared_weights:
207            prediction_module = MSRefinementShared
208        else:
209            prediction_module = MSRefinement
210        predictor = prediction_module(
211            num_layers_R=int(self.num_layers_R),
212            num_R=int(self.num_R),
213            num_f_maps_input=int(self.num_f_maps),
214            num_f_maps=int(self.num_f_maps_R),
215            num_classes=int(self.num_classes),
216            dropout_rate=self.dropout_rate,
217            exclusive=self.exclusive,
218            skip_connections=self.skip_connections_ref,
219            direction=self.direction_R,
220            block_size=self.block_size_refinement,
221            num_heads=self.num_heads,
222            attention=self.R_attention,
223        )
224        if not self.PG_in_FE:
225            PG = self._PG()
226            predictor = Compiled([PG, predictor])
227        return predictor
228
229    def features_shape(self) -> torch.Size:
230        """
231        Get the shape of feature extractor output
232
233        Returns
234        -------
235        feature_shape : torch.Size
236            shape of feature extractor output
237        """
238
239        return torch.Size([self.num_f_maps])
240
241
class MS_TCN_P(MS_TCN3):
    """
    A variation of `MS_TCN3` that uses a `MultiDilatedTCN` prediction generation stage
    with a separate input size for every group of input features
    """

    def _get_dims(self, dims):
        """
        Get the list of input sizes, one per feature group

        The group of a feature is the part of its key after the last "---". The sizes are
        ordered by sorted group name; if there is a "loaded" key, its size is appended at the end.
        """

        totals = {}
        for key, shape in dims.items():
            group = key.split("---")[-1]
            totals[group] = totals.get(group, 0) + shape[0]
        sizes = [int(totals[group]) for group in sorted(totals)]
        # NOTE(review): a key that is exactly "loaded" already forms its own group above, so
        # its size ends up in the result twice -- confirm that this is intended
        if "loaded" in dims:
            sizes.append(int(dims["loaded"][0]))
        return sizes

    def _PG(self):
        """
        Construct the prediction generation stage from the per-group input sizes
        """

        return MultiDilatedTCN(
            self.num_layers_PG,
            self.num_f_maps,
            self.dim,
            self.direction_PG,
            self.block_size_prediction,
            self.kernel_size_prediction,
            self.rare_dilations,
        )
266
267
268# class MS_TCNC(Model):
269#     """
270#     Basic MS-TCN++ model with options for shared weights and added skip connections
271#     """
272#
273#     def __init__(
274#         self,
275#         num_layers_R,
276#         num_R,
277#         num_f_maps,
278#         num_classes,
279#         exclusive,
280#         num_layers_PG,
281#         num_layers_S,
282#         dims,
283#         len_segment,
284#         dropout_rate=0.5,
285#         shared_weights=False,
286#         skip_connections_refinement=True,
287#         block_size_prediction=5,
288#         block_size_refinement=0,
289#         kernel_size_prediction=3,
290#         direction_PG=None,
291#         direction_R=None,
292#         PG_in_FE=False,
293#         state_dict_path=None,
294#         ssl_constructors=None,
295#         ssl_types=None,
296#         ssl_modules=None,
297#     ):
298#         """
299#         Parameters
300#         ----------
301#         num_layers_R : int
302#             number of layers in the refinement stages
303#         num_R : int
304#             number of refinement stages
305#         num_f_maps : int
306#             number of feature maps
307#         num_classes : int
308#             number of classes to predict
309#         exclusive : bool
310#             if `True`, single-label predictions are made; otherwise multi-label
311#         num_layers_PG : int
312#             number of layers in the prediction generation stage
313#         dims : torch.Size
314#             shape of features in the input data
315#         dropout_rate : float, default 0.5
316#             dropout rate
317#         shared_weights : bool, default False
318#             if `True`, weights are shared across refinement stages
319#         skip_connections_refinement : bool, default False
320#             if `True`, skip connections are added to the refinement stages
321#         block_size_prediction : int, optional
322#             if not 'None', skip connections are added to the prediction generation stage with this interval
323#         direction_PG : bool, default True
324#             if True, causal convolutions are used in the prediction generation stage
325#         direction_R : bool, default False
326#             if True, causal convolutions are used in the refinement stages
327#         state_dict_path : str, optional
328#             if not `None`, the model state dictionary will be loaded from this path
329#         ssl_constructors : list, optional
330#             a list of `dlc2action.ssl.base_ssl.SSLConstructor` instances to integrate
331#         ssl_types : list, optional
332#             a list of types of the SSL modules to integrate (used alternatively to `ssl_constructors`)
333#         ssl_modules : list, optional
334#             a list of SSL modules to integrate (used alternatively to `ssl_constructors`)
335#         """
336#
337#         if len(dims) > 1:
338#             raise RuntimeError(
339#                 "The MS-TCN++ model expects the input data to be 2-dimensional; "
340#                 f"got {len(dims) + 1} dimensions"
341#             )
342#         self.num_layers_R = int(num_layers_R)
343#         self.num_R = int(num_R)
344#         self.num_f_maps = int(num_f_maps)
345#         self.num_classes = int(num_classes)
346#         self.dropout_rate = dropout_rate
347#         self.exclusive = exclusive
348#         self.num_layers_PG = int(num_layers_PG)
349#         self.num_layers_S = int(num_layers_S)
350#         self.dim = int(dims[0])
351#         self.shared_weights = shared_weights
352#         self.skip_connections_ref = skip_connections_refinement
353#         self.block_size_prediction = int(block_size_prediction)
354#         self.block_size_refinement = int(block_size_refinement)
355#         self.direction_R = direction_R
356#         self.direction_PG = direction_PG
357#         self.kernel_size_prediction = int(kernel_size_prediction)
358#         self.PG_in_FE = PG_in_FE
359#         self.len_segment = len_segment
360#         super().__init__(ssl_constructors, ssl_modules, ssl_types, state_dict_path)
361#
362#     def _PG(self):
363#         PG = DilatedTCNC(
364#             num_f_maps=self.num_f_maps,
365#             num_layers_PG=self.num_layers_PG,
366#             len_segment=self.len_segment,
367#             block_size_prediction=self.block_size_prediction,
368#             kernel_size_prediction=self.kernel_size_prediction,
369#             direction_PG=self.direction_PG,
370#         )
371#         return PG
372#
373#     def _feature_extractor(self):
374#         feature_extractor = SpatialFeatures(
375#             num_layers=self.num_layers_S,
376#             num_f_maps=self.num_f_maps,
377#             dim=self.dim,
378#             block_size=self.block_size_prediction,
379#         )
380#         if self.PG_in_FE:
381#             PG = self._PG()
382#             feature_extractor = [feature_extractor, PG]
383#         return feature_extractor
384#
385#     def _predictor(self):
386#
387#         if self.shared_weights:
388#             prediction_module = MSRefinementShared
389#         else:
390#             prediction_module = MSRefinement
391#         predictor = prediction_module(
392#             num_layers_R=int(self.num_layers_R),
393#             num_R=int(self.num_R),
394#             num_f_maps=int(self.num_f_maps),
395#             num_classes=int(self.num_classes),
396#             dropout_rate=self.dropout_rate,
397#             exclusive=self.exclusive,
398#             skip_connections=self.skip_connections_ref,
399#             direction=self.direction_R,
400#             block_size=self.block_size_refinement,
401#         )
402#         if not self.PG_in_FE:
403#             PG = self._PG()
404#             predictor = Compiled([PG, predictor])
405#         return predictor
406#
407#     def features_shape(self) -> torch.Size:
408#         """
409#         Get the shape of feature extractor output
410#
411#         Returns
412#         -------
413#         feature_shape : torch.Size
414#             shape of feature extractor output
415#         """
416#
417#         return torch.Size([self.num_f_maps])
418#
419# class MS_TCNA(Model):
420#     """
421#     Basic MS-TCN++ model with additional options
422#     """
423#
424#     def __init__(
425#         self,
426#         num_f_maps,
427#         num_classes,
428#         exclusive,
429#         dims,
430#         num_layers_R,
431#         num_R,
432#         num_layers_PG,
433#         len_segment,
434#         num_f_maps_R=None,
435#         num_layers_S=0,
436#         dropout_rate=0.5,
437#         skip_connections_refinement=True,
438#         block_size_prediction=0,
439#         block_size_refinement=0,
440#         kernel_size_prediction=3,
441#         direction_PG=None,
442#         direction_R=None,
443#         PG_in_FE=False,
444#         rare_dilations=False,
445#         state_dict_path=None,
446#         ssl_constructors=None,
447#         ssl_types=None,
448#         ssl_modules=None,
449#         *args, **kwargs
450#     ):
451#         """
452#         Parameters
453#         ----------
454#         num_f_maps : int
455#             number of feature maps
456#         num_classes : int
457#             number of classes to predict
458#         exclusive : bool
459#             if `True`, single-label predictions are made; otherwise multi-label
460#         dims : torch.Size
461#             shape of features in the input data
462#         num_layers_R : int
463#             number of layers in the refinement stages
464#         num_R : int
465#             number of refinement stages
466#         num_layers_PG : int
467#             number of layers in the prediction generation stage
468#         num_layers_S : int, default 0
469#             number of layers in the spatial feature extraction stage
470#         dropout_rate : float, default 0.5
471#             dropout rate
472#         shared_weights : bool, default False
473#             if `True`, weights are shared across refinement stages
474#         skip_connections_refinement : bool, default False
475#             if `True`, skip connections are added to the refinement stages
476#         block_size_prediction : int, default 0
477#             if not 0, skip connections are added to the prediction generation stage with this interval
478#         block_size_refinement : int, default 0
479#             if not 0, skip connections are added to the refinement stage with this interval
480#         direction_PG : [None, 'bidirectional', 'forward', 'backward']
481#             if not `None`, a combination of causal and anticausal convolutions are used in the
482#             prediction generation stage
483#         direction_R : [None, 'bidirectional', 'forward', 'backward']
484#             if not `None`, a combination of causal and anticausal convolutions are used in the refinement stages
485#         PG_in_FE : bool, default True
486#             if `True`, the prediction generation stage is included in the feature extractor and otherwise in the
487#             predictor (the output of the feature extractor is used in SSL tasks)
488#         rare_dilations : bool, default False
489#             if `False`, dilation increases every layer, otherwise every second layer in
490#             the prediction generation stage
491#         num_heads : int, default 1
492#             the number of parallel refinement stages
493#         state_dict_path : str, optional
494#             if not `None`, the model state dictionary will be loaded from this path
495#         ssl_constructors : list, optional
496#             a list of `dlc2action.ssl.base_ssl.SSLConstructor` instances to integrate
497#         ssl_types : list, optional
498#             a list of types of the SSL modules to integrate (used alternatively to `ssl_constructors`)
499#         ssl_modules : list, optional
500#             a list of SSL modules to integrate (used alternatively to `ssl_constructors`)
501#         """
502#
503#         if len(dims) > 1:
504#             raise RuntimeError(
505#                 "The MS-TCN++ model expects the input data to be 2-dimensional; "
506#                 f"got {len(dims) + 1} dimensions"
507#             )
508#         self.num_layers_R = int(num_layers_R)
509#         self.num_R = int(num_R)
510#         self.num_f_maps = int(num_f_maps)
511#         self.num_classes = int(num_classes)
512#         self.dropout_rate = dropout_rate
513#         self.exclusive = exclusive
514#         self.num_layers_PG = int(num_layers_PG)
515#         self.num_layers_S = int(num_layers_S)
516#         self.dim = int(dims[0])
517#         self.skip_connections_ref = skip_connections_refinement
518#         self.block_size_prediction = int(block_size_prediction)
519#         self.block_size_refinement = int(block_size_refinement)
520#         self.direction_R = direction_R
521#         self.direction_PG = direction_PG
522#         self.kernel_size_prediction = int(kernel_size_prediction)
523#         self.PG_in_FE = PG_in_FE
524#         self.rare_dilations = rare_dilations
525#         self.len_segment = len_segment
526#         if num_f_maps_R is None:
527#             num_f_maps_R = num_f_maps
528#         self.num_f_maps_R = num_f_maps_R
529#         super().__init__(ssl_constructors, ssl_modules, ssl_types, state_dict_path)
530#
531#     def _PG(self):
532#         if self.num_layers_S == 0:
533#             dim = self.dim
534#         else:
535#             dim = self.num_f_maps
536#         if self.direction_PG == "bidirectional":
537#             PG = DilatedTCNB(
538#                 num_layers=self.num_layers_PG,
539#                 num_f_maps=self.num_f_maps,
540#                 dim=dim,
541#                 block_size=self.block_size_prediction,
542#                 kernel_size=self.kernel_size_prediction,
543#                 rare_dilations=self.rare_dilations,
544#             )
545#         else:
546#             PG = DilatedTCN(
547#                 num_layers=self.num_layers_PG,
548#                 num_f_maps=self.num_f_maps,
549#                 dim=dim,
550#                 direction=self.direction_PG,
551#                 block_size=self.block_size_prediction,
552#                 kernel_size=self.kernel_size_prediction,
553#                 rare_dilations=self.rare_dilations,
554#             )
555#         return PG
556#
557#     def _feature_extractor(self):
558#         if self.num_layers_S == 0:
559#             if self.PG_in_FE:
560#                 return self._PG()
561#             else:
562#                 return nn.Identity()
563#         feature_extractor = SpatialFeatures(
564#             self.num_layers_S,
565#             self.num_f_maps,
566#             self.dim,
567#             self.block_size_prediction,
568#         )
569#         if self.PG_in_FE:
570#             PG = self._PG()
571#             feature_extractor = [feature_extractor, PG]
572#         return feature_extractor
573#
574#     def _predictor(self):
575#         predictor = MSRefinementAttention(
576#             num_layers_R=int(self.num_layers_R),
577#             num_R=int(self.num_R),
578#             num_f_maps_input=int(self.num_f_maps),
579#             num_f_maps=int(self.num_f_maps_R),
580#             num_classes=int(self.num_classes),
581#             dropout_rate=self.dropout_rate,
582#             exclusive=self.exclusive,
583#             skip_connections=self.skip_connections_ref,
584#             block_size=self.block_size_refinement,
585#             len_segment=self.len_segment,
586#         )
587#         if not self.PG_in_FE:
588#             PG = self._PG()
589#             predictor = Compiled([PG, predictor])
590#         return predictor
591#
592#     def features_shape(self) -> torch.Size:
593#         """
594#         Get the shape of feature extractor output
595#
596#         Returns
597#         -------
598#         feature_shape : torch.Size
599#             shape of feature extractor output
600#         """
601#
602#         return torch.Size([self.num_f_maps])
class Compiled(torch.nn.modules.module.Module):
class Compiled(nn.Module):
    """
    A chain of modules that are applied one after another

    Each module receives the output of the previous one together with the same `tag`.
    """

    def __init__(self, modules):
        """
        Parameters
        ----------
        modules : iterable
            the modules to apply, in order
        """
        super().__init__()
        self.module_list = nn.ModuleList(modules)

    def forward(self, x, tag=None):
        """Forward pass."""
        result = x
        for module in self.module_list:
            result = module(result, tag)
        return result

Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::

import torch.nn as nn
import torch.nn.functional as F


class Model(nn.Module):
    """Example module: two 5x5 convolutions, each followed by a ReLU."""

    def __init__(self) -> None:
        super().__init__()
        # Sub-modules assigned as attributes are registered automatically
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=20, kernel_size=5)
        self.conv2 = nn.Conv2d(in_channels=20, out_channels=20, kernel_size=5)

    def forward(self, x):
        hidden = F.relu(self.conv1(x))
        out = F.relu(self.conv2(hidden))
        return out

Submodules assigned in this way will be registered, and will also have their parameters converted when you call to(), etc.

As per the example above, an __init__() call to the parent class must be made before assignment on the child.

Attributes: `training` (bool) – whether this module is in training or evaluation mode.

Compiled(modules)
22    def __init__(self, modules):
23        super(Compiled, self).__init__()
24        self.module_list = nn.ModuleList(modules)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

module_list
def forward(self, x, tag=None):
26    def forward(self, x, tag=None):
27        """Forward pass."""
28        for m in self.module_list:
29            x = m(x, tag)
30        return x

Forward pass.

class MS_TCN3(dlc2action.model.base_model.Model):
 33class MS_TCN3(Model):
 34    """
 35    A modification of MS-TCN++ model with additional options
 36    """
 37
 38    def __init__(
 39        self,
 40        num_f_maps,
 41        num_classes,
 42        exclusive,
 43        dims,
 44        num_layers_R,
 45        num_R,
 46        num_layers_PG,
 47        num_f_maps_R=None,
 48        num_layers_S=0,
 49        dropout_rate=0.5,
 50        shared_weights=False,
 51        skip_connections_refinement=True,
 52        block_size_prediction=0,
 53        block_size_refinement=0,
 54        kernel_size_prediction=3,
 55        direction_PG=None,
 56        direction_R=None,
 57        PG_in_FE=False,
 58        rare_dilations=False,
 59        num_heads=1,
 60        R_attention="none",
 61        PG_attention="none",
 62        state_dict_path=None,
 63        ssl_constructors=None,
 64        ssl_types=None,
 65        ssl_modules=None,
 66        multihead=False,
 67        *args,
 68        **kwargs,
 69    ):
 70        """
 71        Parameters
 72        ----------
 73        num_f_maps : int
 74            number of feature maps
 75        num_classes : int
 76            number of classes to predict
 77        exclusive : bool
 78            if `True`, single-label predictions are made; otherwise multi-label
 79        dims : dict
 80            a dictionary whose values are the shapes of features in the input data (their first dimensions are summed)
 81        num_layers_R : int
 82            number of layers in the refinement stages
 83        num_R : int
 84            number of refinement stages
 85        num_layers_PG : int
 86            number of layers in the prediction generation stage
 87        num_layers_S : int, default 0
 88            number of layers in the spatial feature extraction stage
 89        dropout_rate : float, default 0.5
 90            dropout rate
 91        shared_weights : bool, default False
 92            if `True`, weights are shared across refinement stages
 93        skip_connections_refinement : bool, default True
 94            if `True`, skip connections are added to the refinement stages
 95        block_size_prediction : int, default 0
 96            if not 0, skip connections are added to the prediction generation stage with this interval
 97        block_size_refinement : int, default 0
 98            if not 0, skip connections are added to the refinement stage with this interval
 99        direction_PG : [None, 'bidirectional', 'forward', 'backward']
100            if not `None`, a combination of causal and anticausal convolutions are used in the
101            prediction generation stage
102        direction_R : [None, 'bidirectional', 'forward', 'backward']
103            if not `None`, a combination of causal and anticausal convolutions are used in the refinement stages
104        PG_in_FE : bool, default False
105            if `True`, the prediction generation stage is included in the feature extractor and otherwise in the
106            predictor (the output of the feature extractor is used in SSL tasks)
107        rare_dilations : bool, default False
108            if `False`, dilation increases every layer, otherwise every second layer in
109            the prediction generation stage
110        num_heads : int, default 1
111            the number of parallel refinement stages
112        PG_attention : str, default 'none'
113            the attention option of the prediction generation stage (`'none'` presumably means no attention layer is added)
114        R_attention : str, default 'none'
115            the attention option of the refinement stages (`'none'` presumably means no attention layer is added)
116        state_dict_path : str, optional
117            if not `None`, the model state dictionary will be loaded from this path
118        ssl_constructors : list, optional
119            a list of `dlc2action.ssl.base_ssl.SSLConstructor` instances to integrate
120        ssl_types : list, optional
121            a list of types of the SSL modules to integrate (used alternatively to `ssl_constructors`)
122        ssl_modules : list, optional
123            a list of SSL modules to integrate (used alternatively to `ssl_constructors`)
124        """
125
126        self.num_layers_R = int(float(num_layers_R))
127        self.num_R = int(float(num_R))
128        self.num_f_maps = int(float(num_f_maps))
129        self.num_classes = int(float(num_classes))
130        self.dropout_rate = float(dropout_rate)
131        self.exclusive = bool(exclusive)
132        self.num_layers_PG = int(float(num_layers_PG))
133        self.num_layers_S = int(float(num_layers_S))
134        self.dim = self._get_dims(dims)
135        self.shared_weights = bool(shared_weights)
136        self.skip_connections_ref = bool(skip_connections_refinement)
137        self.block_size_prediction = int(float(block_size_prediction))
138        self.block_size_refinement = int(float(block_size_refinement))
139        self.direction_R = direction_R
140        self.direction_PG = direction_PG
141        self.kernel_size_prediction = int(float(kernel_size_prediction))
142        self.PG_in_FE = PG_in_FE
143        self.rare_dilations = rare_dilations
144        self.num_heads = int(float(num_heads))
145        self.PG_attention = PG_attention
146        self.R_attention = R_attention
147        self.multihead = multihead
148        if num_f_maps_R is None:
149            num_f_maps_R = self.num_f_maps
150        self.num_f_maps_R = num_f_maps_R
151        super().__init__(ssl_constructors, ssl_modules, ssl_types, state_dict_path)
152
153    def _get_dims(self, dims):
154        return int(sum([s[0] for s in dims.values()]))
155
156    def _PG(self):
157        if self.num_layers_S == 0:
158            dim = self.dim
159        else:
160            dim = self.num_f_maps
161        if self.direction_PG == "bidirectional":
162            PG = DilatedTCNB(
163                num_layers=self.num_layers_PG,
164                num_f_maps=self.num_f_maps,
165                dim=dim,
166                block_size=self.block_size_prediction,
167                kernel_size=self.kernel_size_prediction,
168                rare_dilations=self.rare_dilations,
169            )
170        else:
171            PG = DilatedTCN(
172                num_layers=self.num_layers_PG,
173                num_f_maps=self.num_f_maps,
174                dim=dim,
175                direction=self.direction_PG,
176                block_size=self.block_size_prediction,
177                kernel_size=self.kernel_size_prediction,
178                rare_dilations=self.rare_dilations,
179                attention=self.PG_attention,
180                multihead=self.multihead,
181            )
182        return PG
183
184    def _feature_extractor(self):
185        if self.num_layers_S == 0:
186            if self.PG_in_FE:
187                print("MS-TCN using the prediction generator as a feature extractor")
188                return self._PG()
189            else:
190                print("MS-TCN without a feature extractor -> no SSL possible!")
191                return nn.Identity()
192
193        print("MS-TCN using a spatial feature extractor")
194        feature_extractor = SpatialFeatures(
195            self.num_layers_S,
196            self.num_f_maps,
197            self.dim,
198            self.block_size_prediction,
199        )
200        if self.PG_in_FE:
201            print("  -> also has the prediction generator as a feature extractor")
202            PG = self._PG()
203            feature_extractor = [feature_extractor, PG]
204        return feature_extractor
205
206    def _predictor(self):
207        if self.shared_weights:
208            prediction_module = MSRefinementShared
209        else:
210            prediction_module = MSRefinement
211        predictor = prediction_module(
212            num_layers_R=int(self.num_layers_R),
213            num_R=int(self.num_R),
214            num_f_maps_input=int(self.num_f_maps),
215            num_f_maps=int(self.num_f_maps_R),
216            num_classes=int(self.num_classes),
217            dropout_rate=self.dropout_rate,
218            exclusive=self.exclusive,
219            skip_connections=self.skip_connections_ref,
220            direction=self.direction_R,
221            block_size=self.block_size_refinement,
222            num_heads=self.num_heads,
223            attention=self.R_attention,
224        )
225        if not self.PG_in_FE:
226            PG = self._PG()
227            predictor = Compiled([PG, predictor])
228        return predictor
229
230    def features_shape(self) -> torch.Size:
231        """
232        Get the shape of feature extractor output
233
234        Returns
235        -------
236        feature_shape : torch.Size
237            shape of feature extractor output
238        """
239
240        return torch.Size([self.num_f_maps])

A modification of MS-TCN++ model with additional options

MS_TCN3( num_f_maps, num_classes, exclusive, dims, num_layers_R, num_R, num_layers_PG, num_f_maps_R=None, num_layers_S=0, dropout_rate=0.5, shared_weights=False, skip_connections_refinement=True, block_size_prediction=0, block_size_refinement=0, kernel_size_prediction=3, direction_PG=None, direction_R=None, PG_in_FE=False, rare_dilations=False, num_heads=1, R_attention='none', PG_attention='none', state_dict_path=None, ssl_constructors=None, ssl_types=None, ssl_modules=None, multihead=False, *args, **kwargs)
 38    def __init__(
 39        self,
 40        num_f_maps,
 41        num_classes,
 42        exclusive,
 43        dims,
 44        num_layers_R,
 45        num_R,
 46        num_layers_PG,
 47        num_f_maps_R=None,
 48        num_layers_S=0,
 49        dropout_rate=0.5,
 50        shared_weights=False,
 51        skip_connections_refinement=True,
 52        block_size_prediction=0,
 53        block_size_refinement=0,
 54        kernel_size_prediction=3,
 55        direction_PG=None,
 56        direction_R=None,
 57        PG_in_FE=False,
 58        rare_dilations=False,
 59        num_heads=1,
 60        R_attention="none",
 61        PG_attention="none",
 62        state_dict_path=None,
 63        ssl_constructors=None,
 64        ssl_types=None,
 65        ssl_modules=None,
 66        multihead=False,
 67        *args,
 68        **kwargs,
 69    ):
 70        """
 71        Parameters
 72        ----------
 73        num_f_maps : int
 74            number of feature maps
 75        num_classes : int
 76            number of classes to predict
 77        exclusive : bool
 78            if `True`, single-label predictions are made; otherwise multi-label
 79        dims : dict
 80            a dictionary whose values are the shapes of features in the input data (their first dimensions are summed)
 81        num_layers_R : int
 82            number of layers in the refinement stages
 83        num_R : int
 84            number of refinement stages
 85        num_layers_PG : int
 86            number of layers in the prediction generation stage
 87        num_layers_S : int, default 0
 88            number of layers in the spatial feature extraction stage
 89        dropout_rate : float, default 0.5
 90            dropout rate
 91        shared_weights : bool, default False
 92            if `True`, weights are shared across refinement stages
 93        skip_connections_refinement : bool, default True
 94            if `True`, skip connections are added to the refinement stages
 95        block_size_prediction : int, default 0
 96            if not 0, skip connections are added to the prediction generation stage with this interval
 97        block_size_refinement : int, default 0
 98            if not 0, skip connections are added to the refinement stage with this interval
 99        direction_PG : [None, 'bidirectional', 'forward', 'backward']
100            if not `None`, a combination of causal and anticausal convolutions are used in the
101            prediction generation stage
102        direction_R : [None, 'bidirectional', 'forward', 'backward']
103            if not `None`, a combination of causal and anticausal convolutions are used in the refinement stages
104        PG_in_FE : bool, default False
105            if `True`, the prediction generation stage is included in the feature extractor and otherwise in the
106            predictor (the output of the feature extractor is used in SSL tasks)
107        rare_dilations : bool, default False
108            if `False`, dilation increases every layer, otherwise every second layer in
109            the prediction generation stage
110        num_heads : int, default 1
111            the number of parallel refinement stages
112        PG_attention : str, default 'none'
113            the attention option of the prediction generation stage (`'none'` presumably means no attention layer is added)
114        R_attention : str, default 'none'
115            the attention option of the refinement stages (`'none'` presumably means no attention layer is added)
116        state_dict_path : str, optional
117            if not `None`, the model state dictionary will be loaded from this path
118        ssl_constructors : list, optional
119            a list of `dlc2action.ssl.base_ssl.SSLConstructor` instances to integrate
120        ssl_types : list, optional
121            a list of types of the SSL modules to integrate (used alternatively to `ssl_constructors`)
122        ssl_modules : list, optional
123            a list of SSL modules to integrate (used alternatively to `ssl_constructors`)
124        """
125
126        self.num_layers_R = int(float(num_layers_R))
127        self.num_R = int(float(num_R))
128        self.num_f_maps = int(float(num_f_maps))
129        self.num_classes = int(float(num_classes))
130        self.dropout_rate = float(dropout_rate)
131        self.exclusive = bool(exclusive)
132        self.num_layers_PG = int(float(num_layers_PG))
133        self.num_layers_S = int(float(num_layers_S))
134        self.dim = self._get_dims(dims)
135        self.shared_weights = bool(shared_weights)
136        self.skip_connections_ref = bool(skip_connections_refinement)
137        self.block_size_prediction = int(float(block_size_prediction))
138        self.block_size_refinement = int(float(block_size_refinement))
139        self.direction_R = direction_R
140        self.direction_PG = direction_PG
141        self.kernel_size_prediction = int(float(kernel_size_prediction))
142        self.PG_in_FE = PG_in_FE
143        self.rare_dilations = rare_dilations
144        self.num_heads = int(float(num_heads))
145        self.PG_attention = PG_attention
146        self.R_attention = R_attention
147        self.multihead = multihead
148        if num_f_maps_R is None:
149            num_f_maps_R = self.num_f_maps
150        self.num_f_maps_R = num_f_maps_R
151        super().__init__(ssl_constructors, ssl_modules, ssl_types, state_dict_path)

Parameters

num_f_maps : int — number of feature maps
num_classes : int — number of classes to predict
exclusive : bool — if True, single-label predictions are made; otherwise multi-label
dims : dict — a dictionary whose values are the shapes of features in the input data (their first dimensions are summed)
num_layers_R : int — number of layers in the refinement stages
num_R : int — number of refinement stages
num_layers_PG : int — number of layers in the prediction generation stage
num_layers_S : int, default 0 — number of layers in the spatial feature extraction stage
dropout_rate : float, default 0.5 — dropout rate
shared_weights : bool, default False — if True, weights are shared across refinement stages
skip_connections_refinement : bool, default True — if True, skip connections are added to the refinement stages
block_size_prediction : int, default 0 — if not 0, skip connections are added to the prediction generation stage with this interval
block_size_refinement : int, default 0 — if not 0, skip connections are added to the refinement stage with this interval
direction_PG : [None, 'bidirectional', 'forward', 'backward'] — if not None, a combination of causal and anticausal convolutions is used in the prediction generation stage
direction_R : [None, 'bidirectional', 'forward', 'backward'] — if not None, a combination of causal and anticausal convolutions is used in the refinement stages
PG_in_FE : bool, default False — if True, the prediction generation stage is included in the feature extractor and otherwise in the predictor (the output of the feature extractor is used in SSL tasks)
rare_dilations : bool, default False — if False, dilation increases every layer, otherwise every second layer in the prediction generation stage
num_heads : int, default 1 — the number of parallel refinement stages
PG_attention : str, default 'none' — the attention option of the prediction generation stage ('none' presumably means no attention layer is added)
R_attention : str, default 'none' — the attention option of the refinement stages ('none' presumably means no attention layer is added)
state_dict_path : str, optional — if not None, the model state dictionary will be loaded from this path
ssl_constructors : list, optional — a list of dlc2action.ssl.base_ssl.SSLConstructor instances to integrate
ssl_types : list, optional — a list of types of the SSL modules to integrate (used alternatively to ssl_constructors)
ssl_modules : list, optional — a list of SSL modules to integrate (used alternatively to ssl_constructors)
num_f_maps_R : int, optional — number of feature maps in the refinement stages; if None, num_f_maps is used
kernel_size_prediction : int, default 3 — kernel size passed to the prediction generation stage
multihead : bool, default False — option passed to the prediction generation stage (not used when direction_PG is 'bidirectional')

num_layers_R
num_R
num_f_maps
num_classes
dropout_rate
exclusive
num_layers_PG
num_layers_S
dim
shared_weights
skip_connections_ref
block_size_prediction
block_size_refinement
direction_R
direction_PG
kernel_size_prediction
PG_in_FE
rare_dilations
num_heads
PG_attention
R_attention
multihead
num_f_maps_R
def features_shape(self) -> torch.Size:
230    def features_shape(self) -> torch.Size:
231        """
232        Get the shape of feature extractor output
233
234        Returns
235        -------
236        feature_shape : torch.Size
237            shape of feature extractor output
238        """
239
240        return torch.Size([self.num_f_maps])

Get the shape of feature extractor output

Returns

feature_shape : torch.Size shape of feature extractor output

class MS_TCN_P(MS_TCN3):
243class MS_TCN_P(MS_TCN3):
244    def _get_dims(self, dims):
245        keys = list(dims.keys())
246        values = list(dims.values())
247        groups = [key.split("---")[-1] for key in keys]
248        unique_groups = sorted(set(groups))
249        res = []
250        for group in unique_groups:
251            res.append(int(sum([x[0] for x, g in zip(values, groups) if g == group])))
252        if "loaded" in dims:
253            res.append(int(dims["loaded"][0]))
254        return res
255
256    def _PG(self):
257        PG = MultiDilatedTCN(
258            self.num_layers_PG,
259            self.num_f_maps,
260            self.dim,
261            self.direction_PG,
262            self.block_size_prediction,
263            self.kernel_size_prediction,
264            self.rare_dilations,
265        )
266        return PG

A modification of MS-TCN++ model with additional options