dlc2action.model.edtcn

EDTCN

Adapted from https://github.com/yiskw713/asrf/blob/main/libs/models/tcn.py

  1#
  2# Copyright 2020-2022 by A. Mathis Group and contributors. All rights reserved.
  3#
  4# This project and all its files are licensed under GNU AGPLv3 or later version. A copy is included in dlc2action/LICENSE.AGPL.
  5#
  6#
  7# Adapted from ASRF by yiskw713
  8# Adapted from https://github.com/yiskw713/asrf/blob/main/libs/models/tcn.py
  9# Licensed under MIT License
 10#
 11""" EDTCN
 12
 13Adapted from https://github.com/yiskw713/asrf/blob/main/libs/models/tcn.py
 14"""
 15
 16import torch
 17from torch import nn
 18from typing import Tuple, Any
 19from torch.nn import functional as F
 20from dlc2action.model.base_model import Model
 21from typing import Union, List
 22
 23
 24class _NormalizedReLU(nn.Module):
 25    """
 26    Normalized ReLU Activation proposed in the original TCN paper.
 27    the values are divided by the max computed per frame
 28    """
 29
 30    def __init__(self, eps: float = 1e-5) -> None:
 31        super().__init__()
 32        self.eps = eps
 33
 34    def forward(self, x: torch.Tensor) -> torch.Tensor:
 35        x = F.relu(x)
 36        x = x / (x.max(dim=1, keepdim=True)[0] + self.eps)
 37
 38        return x
 39
 40
 41class _EDTCNModule(nn.Module):
 42    """
 43    Encoder Decoder Temporal Convolutional Network
 44    """
 45
 46    def __init__(
 47        self,
 48        in_channel: int,
 49        output_dim: int,
 50        kernel_size: int = 25,
 51        mid_channels: Tuple[int, int] = [128, 160],
 52        **kwargs: Any
 53    ) -> None:
 54        """
 55        Args:
 56            in_channel: int. the number of the channels of input feature
 57            output_dim: int. output classes
 58            kernel_size: int. 25 is proposed in the original paper
 59            mid_channels: list. the list of the number of the channels of the middle layer.
 60                        [96 + 32*1, 96 + 32*2] is proposed in the original paper
 61        Note that this implementation only supports n_layer=2
 62        """
 63        super().__init__()
 64
 65        # encoder
 66        self.enc1 = nn.Conv1d(
 67            in_channel,
 68            mid_channels[0],
 69            kernel_size,
 70            stride=1,
 71            padding=(kernel_size - 1) // 2,
 72        )
 73        self.dropout1 = nn.Dropout(0.3)
 74        self.relu1 = _NormalizedReLU()
 75
 76        self.enc2 = nn.Conv1d(
 77            mid_channels[0],
 78            mid_channels[1],
 79            kernel_size,
 80            stride=1,
 81            padding=(kernel_size - 1) // 2,
 82        )
 83        self.dropout2 = nn.Dropout(0.3)
 84        self.relu2 = _NormalizedReLU()
 85
 86        # decoder
 87        self.dec1 = nn.Conv1d(
 88            mid_channels[1],
 89            mid_channels[1],
 90            kernel_size,
 91            stride=1,
 92            padding=(kernel_size - 1) // 2,
 93        )
 94        self.dropout3 = nn.Dropout(0.3)
 95        self.relu3 = _NormalizedReLU()
 96
 97        self.dec2 = nn.Conv1d(
 98            mid_channels[1],
 99            mid_channels[0],
100            kernel_size,
101            stride=1,
102            padding=(kernel_size - 1) // 2,
103        )
104        self.dropout4 = nn.Dropout(0.3)
105        self.relu4 = _NormalizedReLU()
106
107        self.conv_out = nn.Conv1d(mid_channels[0], output_dim, 1, bias=True)
108
109        self.init_weight()
110
111    def forward(self, x: torch.Tensor) -> torch.Tensor:
112        # encoder 1
113        x1 = self.relu1(self.dropout1(self.enc1(x)))
114        t1 = x1.shape[2]
115        x1 = F.max_pool1d(x1, 2)
116
117        # encoder 2
118        x2 = self.relu2(self.dropout2(self.enc2(x1)))
119        t2 = x2.shape[2]
120        x2 = F.max_pool1d(x2, 2)
121
122        # decoder 1
123        x3 = F.interpolate(x2, size=(t2,), mode="nearest")
124        x3 = self.relu3(self.dropout3(self.dec1(x3)))
125
126        # decoder 2
127        x4 = F.interpolate(x3, size=(t1,), mode="nearest")
128        x4 = self.relu4(self.dropout4(self.dec2(x4)))
129
130        out = self.conv_out(x4)
131
132        return out
133
134    def init_weight(self) -> None:
135        for m in self.modules():
136            if isinstance(m, nn.Conv1d):
137                nn.init.xavier_normal_(m.weight)
138                if m.bias is not None:
139                    torch.nn.init.zeros_(m.bias)
140
141
142class _Predictor(nn.Module):
143    def __init__(self, dim, num_classes):
144        super(_Predictor, self).__init__()
145        self.num_classes = num_classes
146        self.conv_out_1 = nn.Conv1d(dim, dim, kernel_size=1)
147        self.conv_out_2 = nn.Conv1d(dim, num_classes, kernel_size=1)
148
149    def forward(self, x):
150        x = self.conv_out_1(x)
151        x = F.relu(x)
152        x = self.conv_out_2(x)
153        return x
154
155
156class EDTCN(Model):
157    """
158    An implementation of EDTCN (Encoder-Decoder TCN)
159    """
160
161    def __init__(
162        self,
163        num_classes,
164        input_dims,
165        kernel_size,
166        mid_channels,
167        feature_dim=None,
168        state_dict_path=None,
169        ssl_constructors=None,
170        ssl_types=None,
171        ssl_modules=None,
172    ):
173        input_dims = int(sum([s[0] for s in input_dims.values()]))
174        if feature_dim is None:
175            feature_dim = num_classes
176            self.params_predictor = None
177            self.f_shape = None
178        else:
179            self.params_predictor = {
180                "dim": int(feature_dim),
181                "num_classes": int(num_classes),
182            }
183            self.f_shape = torch.Size([int(feature_dim)])
184        self.params = {
185            "output_dim": int(feature_dim),
186            "in_channel": input_dims,
187            "kernel_size": int(kernel_size),
188            "mid_channels": mid_channels,
189        }
190        super().__init__(ssl_constructors, ssl_modules, ssl_types, state_dict_path)
191
192    def _feature_extractor(self) -> Union[torch.nn.Module, List]:
193        return _EDTCNModule(**self.params)
194
195    def _predictor(self) -> torch.nn.Module:
196        if self.params_predictor is None:
197            return nn.Identity()
198        else:
199            return _Predictor(**self.params_predictor)
200
201    def features_shape(self) -> torch.Size:
202        return self.f_shape
class EDTCN(dlc2action.model.base_model.Model):
157class EDTCN(Model):
158    """
159    An implementation of EDTCN (Encoder-Decoder TCN)
160    """
161
162    def __init__(
163        self,
164        num_classes,
165        input_dims,
166        kernel_size,
167        mid_channels,
168        feature_dim=None,
169        state_dict_path=None,
170        ssl_constructors=None,
171        ssl_types=None,
172        ssl_modules=None,
173    ):
174        input_dims = int(sum([s[0] for s in input_dims.values()]))
175        if feature_dim is None:
176            feature_dim = num_classes
177            self.params_predictor = None
178            self.f_shape = None
179        else:
180            self.params_predictor = {
181                "dim": int(feature_dim),
182                "num_classes": int(num_classes),
183            }
184            self.f_shape = torch.Size([int(feature_dim)])
185        self.params = {
186            "output_dim": int(feature_dim),
187            "in_channel": input_dims,
188            "kernel_size": int(kernel_size),
189            "mid_channels": mid_channels,
190        }
191        super().__init__(ssl_constructors, ssl_modules, ssl_types, state_dict_path)
192
193    def _feature_extractor(self) -> Union[torch.nn.Module, List]:
194        return _EDTCNModule(**self.params)
195
196    def _predictor(self) -> torch.nn.Module:
197        if self.params_predictor is None:
198            return nn.Identity()
199        else:
200            return _Predictor(**self.params_predictor)
201
202    def features_shape(self) -> torch.Size:
203        return self.f_shape

An implementation of EDTCN (Encoder-Decoder TCN)

EDTCN( num_classes, input_dims, kernel_size, mid_channels, feature_dim=None, state_dict_path=None, ssl_constructors=None, ssl_types=None, ssl_modules=None)
162    def __init__(
163        self,
164        num_classes,
165        input_dims,
166        kernel_size,
167        mid_channels,
168        feature_dim=None,
169        state_dict_path=None,
170        ssl_constructors=None,
171        ssl_types=None,
172        ssl_modules=None,
173    ):
174        input_dims = int(sum([s[0] for s in input_dims.values()]))
175        if feature_dim is None:
176            feature_dim = num_classes
177            self.params_predictor = None
178            self.f_shape = None
179        else:
180            self.params_predictor = {
181                "dim": int(feature_dim),
182                "num_classes": int(num_classes),
183            }
184            self.f_shape = torch.Size([int(feature_dim)])
185        self.params = {
186            "output_dim": int(feature_dim),
187            "in_channel": input_dims,
188            "kernel_size": int(kernel_size),
189            "mid_channels": mid_channels,
190        }
191        super().__init__(ssl_constructors, ssl_modules, ssl_types, state_dict_path)

Parameters

ssl_constructors : list, optional
    a list of SSL constructors that build the necessary SSL modules
ssl_modules : list, optional
    a list of torch.nn.Module instances that will serve as SSL modules
ssl_types : list, optional
    a list of string SSL types
state_dict_path : str, optional
    path to the model state dictionary to load
strict : bool, default False
    when True, the state dictionary will only be loaded if the current and the loaded architecture are the same; otherwise missing or extra keys, as well as shape inconsistencies, are ignored (note: the EDTCN constructor signature shown above does not accept a strict argument)

def features_shape(self) -> torch.Size:
202    def features_shape(self) -> torch.Size:
203        return self.f_shape

Get the shape of the feature extractor output

Returns

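feature_shape : torch.Size or None
    shape of the feature extractor output; None when the model was created without feature_dim (in that case the predictor is an identity and the feature extractor outputs num_classes channels directly)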

Inherited Members
dlc2action.model.base_model.Model
process_labels
freeze_feature_extractor
unfreeze_feature_extractor
load_state_dict
ssl_off
ssl_on
main_task_on
main_task_off
set_ssl
extract_features
forward
torch.nn.modules.module.Module
dump_patches
register_buffer
register_parameter
add_module
register_module
get_submodule
get_parameter
get_buffer
get_extra_state
set_extra_state
apply
cuda
ipu
xpu
cpu
type
float
double
half
bfloat16
to_empty
to
register_backward_hook
register_full_backward_hook
register_forward_pre_hook
register_forward_hook
T_destination
state_dict
register_load_state_dict_post_hook
parameters
named_parameters
buffers
named_buffers
children
named_children
modules
named_modules
train
eval
requires_grad_
zero_grad
share_memory
extra_repr