dlc2action.model.edtcn
EDTCN
Adapted from https://github.com/yiskw713/asrf/blob/main/libs/models/tcn.py
#
# Copyright 2020-2022 by A. Mathis Group and contributors. All rights reserved.
#
# This project and all its files are licensed under GNU AGPLv3 or later version. A copy is included in dlc2action/LICENSE.AGPL.
#
#
# Adapted from ASRF by yiskw713
# Adapted from https://github.com/yiskw713/asrf/blob/main/libs/models/tcn.py
# Licensed under MIT License
#
""" EDTCN

Adapted from https://github.com/yiskw713/asrf/blob/main/libs/models/tcn.py
"""

import torch
from torch import nn
from typing import Tuple, Any
from torch.nn import functional as F
from dlc2action.model.base_model import Model
from typing import Union, List


class _NormalizedReLU(nn.Module):
    """
    Normalized ReLU activation proposed in the original TCN paper.

    The input is passed through a ReLU and then divided by its maximum over
    dimension 1 (the channel dimension of a Conv1d output, i.e. the max is
    computed per frame). ``eps`` keeps the division finite when that maximum
    is zero.
    """

    def __init__(self, eps: float = 1e-5) -> None:
        super().__init__()
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = F.relu(x)
        x = x / (x.max(dim=1, keepdim=True)[0] + self.eps)

        return x


class _EDTCNModule(nn.Module):
    """
    Encoder Decoder Temporal Convolutional Network.

    Two encoder stages (conv -> dropout -> normalized ReLU -> max-pool by 2)
    are followed by two decoder stages (nearest-neighbour upsampling back to
    the pre-pooling length -> conv -> dropout -> normalized ReLU) and a 1x1
    output convolution.
    """

    def __init__(
        self,
        in_channel: int,
        output_dim: int,
        kernel_size: int = 25,
        mid_channels: Tuple[int, int] = (128, 160),
        **kwargs: Any
    ) -> None:
        """
        Args:
            in_channel: int. the number of the channels of input feature
            output_dim: int. output classes
            kernel_size: int. 25 is proposed in the original paper.
                Should be odd: the convolutions use stride 1 and padding
                ``(kernel_size - 1) // 2``, which only preserves the temporal
                length for odd kernel sizes.
            mid_channels: sequence of two ints. the numbers of channels of the
                middle layers. [96 + 32*1, 96 + 32*2] is proposed in the
                original paper. Note that this implementation only supports
                n_layer=2; only the first two entries are used.
            **kwargs: accepted but unused.
        """
        super().__init__()

        def same_conv(c_in: int, c_out: int) -> nn.Conv1d:
            # stride-1 convolution with symmetric zero padding
            return nn.Conv1d(
                c_in,
                c_out,
                kernel_size,
                stride=1,
                padding=(kernel_size - 1) // 2,
            )

        # NOTE: submodules are created in the original order so that
        # state-dict keys and the weight-initialization order in
        # `init_weight` stay unchanged.

        # encoder
        self.enc1 = same_conv(in_channel, mid_channels[0])
        self.dropout1 = nn.Dropout(0.3)
        self.relu1 = _NormalizedReLU()

        self.enc2 = same_conv(mid_channels[0], mid_channels[1])
        self.dropout2 = nn.Dropout(0.3)
        self.relu2 = _NormalizedReLU()

        # decoder
        self.dec1 = same_conv(mid_channels[1], mid_channels[1])
        self.dropout3 = nn.Dropout(0.3)
        self.relu3 = _NormalizedReLU()

        self.dec2 = same_conv(mid_channels[1], mid_channels[0])
        self.dropout4 = nn.Dropout(0.3)
        self.relu4 = _NormalizedReLU()

        self.conv_out = nn.Conv1d(mid_channels[0], output_dim, 1, bias=True)

        self.init_weight()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: tensor of shape (batch, in_channel, time); the time axis is
                read via ``shape[2]``. NOTE(review): very short sequences
                (fewer than 4 frames) will likely fail in the max-pooling steps.

        Returns:
            tensor of shape (batch, output_dim, t1), where t1 is the length
            after the first encoder convolution (equal to ``time`` for odd
            kernel sizes).
        """
        # encoder 1
        x1 = self.relu1(self.dropout1(self.enc1(x)))
        t1 = x1.shape[2]  # remembered to upsample back to this length
        x1 = F.max_pool1d(x1, 2)

        # encoder 2
        x2 = self.relu2(self.dropout2(self.enc2(x1)))
        t2 = x2.shape[2]
        x2 = F.max_pool1d(x2, 2)

        # decoder 1: undo the second pooling (handles odd lengths exactly)
        x3 = F.interpolate(x2, size=(t2,), mode="nearest")
        x3 = self.relu3(self.dropout3(self.dec1(x3)))

        # decoder 2: undo the first pooling
        x4 = F.interpolate(x3, size=(t1,), mode="nearest")
        x4 = self.relu4(self.dropout4(self.dec2(x4)))

        out = self.conv_out(x4)

        return out

    def init_weight(self) -> None:
        """Xavier-normal init for every Conv1d weight; zero its bias if present."""
        for m in self.modules():
            if isinstance(m, nn.Conv1d):
                nn.init.xavier_normal_(m.weight)
                if m.bias is not None:
                    torch.nn.init.zeros_(m.bias)


class _Predictor(nn.Module):
    """
    Two 1x1 convolutions with a ReLU in between, mapping ``dim`` channels
    to ``num_classes`` channels.
    """

    def __init__(self, dim, num_classes):
        super().__init__()
        self.num_classes = num_classes
        self.conv_out_1 = nn.Conv1d(dim, dim, kernel_size=1)
        self.conv_out_2 = nn.Conv1d(dim, num_classes, kernel_size=1)

    def forward(self, x):
        x = self.conv_out_1(x)
        x = F.relu(x)
        x = self.conv_out_2(x)
        return x


class EDTCN(Model):
    """
    An implementation of EDTCN (Encoder-Decoder TCN)
    """

    def __init__(
        self,
        num_classes,
        input_dims,
        kernel_size,
        mid_channels,
        feature_dim=None,
        state_dict_path=None,
        ssl_constructors=None,
        ssl_types=None,
        ssl_modules=None,
    ):
        """
        Parameters
        ----------
        num_classes : int
            number of output classes
        input_dims : dict
            the first element of every value is summed to get the number of
            input channels of the feature extractor
        kernel_size : int
            convolution kernel size (should be odd to preserve sequence length)
        mid_channels : sequence of two ints
            channel numbers of the middle layers
        feature_dim : int, optional
            if given, the feature extractor outputs this many channels and a
            `_Predictor` maps them to `num_classes`; otherwise the extractor
            outputs `num_classes` channels and the predictor is an identity
        state_dict_path, ssl_constructors, ssl_types, ssl_modules
            passed on to `Model.__init__`
        """
        input_dims = int(sum(s[0] for s in input_dims.values()))
        if feature_dim is None:
            feature_dim = num_classes
            self.params_predictor = None
            self.f_shape = None
        else:
            self.params_predictor = {
                "dim": int(feature_dim),
                "num_classes": int(num_classes),
            }
            self.f_shape = torch.Size([int(feature_dim)])
        # must be set before the base constructor, which builds the
        # feature extractor from `self.params` via `_feature_extractor`
        # (presumably — TODO confirm against Model.__init__)
        self.params = {
            "output_dim": int(feature_dim),
            "in_channel": input_dims,
            "kernel_size": int(kernel_size),
            "mid_channels": mid_channels,
        }
        super().__init__(ssl_constructors, ssl_modules, ssl_types, state_dict_path)

    def _feature_extractor(self) -> Union[torch.nn.Module, List]:
        return _EDTCNModule(**self.params)

    def _predictor(self) -> torch.nn.Module:
        if self.params_predictor is None:
            return nn.Identity()
        else:
            return _Predictor(**self.params_predictor)

    def features_shape(self) -> torch.Size:
        # None when the model was built without `feature_dim`
        return self.f_shape
157class EDTCN(Model): 158 """ 159 An implementation of EDTCN (Encoder-Decoder TCN) 160 """ 161 162 def __init__( 163 self, 164 num_classes, 165 input_dims, 166 kernel_size, 167 mid_channels, 168 feature_dim=None, 169 state_dict_path=None, 170 ssl_constructors=None, 171 ssl_types=None, 172 ssl_modules=None, 173 ): 174 input_dims = int(sum([s[0] for s in input_dims.values()])) 175 if feature_dim is None: 176 feature_dim = num_classes 177 self.params_predictor = None 178 self.f_shape = None 179 else: 180 self.params_predictor = { 181 "dim": int(feature_dim), 182 "num_classes": int(num_classes), 183 } 184 self.f_shape = torch.Size([int(feature_dim)]) 185 self.params = { 186 "output_dim": int(feature_dim), 187 "in_channel": input_dims, 188 "kernel_size": int(kernel_size), 189 "mid_channels": mid_channels, 190 } 191 super().__init__(ssl_constructors, ssl_modules, ssl_types, state_dict_path) 192 193 def _feature_extractor(self) -> Union[torch.nn.Module, List]: 194 return _EDTCNModule(**self.params) 195 196 def _predictor(self) -> torch.nn.Module: 197 if self.params_predictor is None: 198 return nn.Identity() 199 else: 200 return _Predictor(**self.params_predictor) 201 202 def features_shape(self) -> torch.Size: 203 return self.f_shape
An implementation of EDTCN (Encoder-Decoder TCN)
162 def __init__( 163 self, 164 num_classes, 165 input_dims, 166 kernel_size, 167 mid_channels, 168 feature_dim=None, 169 state_dict_path=None, 170 ssl_constructors=None, 171 ssl_types=None, 172 ssl_modules=None, 173 ): 174 input_dims = int(sum([s[0] for s in input_dims.values()])) 175 if feature_dim is None: 176 feature_dim = num_classes 177 self.params_predictor = None 178 self.f_shape = None 179 else: 180 self.params_predictor = { 181 "dim": int(feature_dim), 182 "num_classes": int(num_classes), 183 } 184 self.f_shape = torch.Size([int(feature_dim)]) 185 self.params = { 186 "output_dim": int(feature_dim), 187 "in_channel": input_dims, 188 "kernel_size": int(kernel_size), 189 "mid_channels": mid_channels, 190 } 191 super().__init__(ssl_constructors, ssl_modules, ssl_types, state_dict_path)
Parameters
ssl_constructors : list, optional a list of SSL constructors that build the necessary SSL modules ssl_modules : list, optional a list of torch.nn.Module instances that will serve as SSL modules ssl_types : list, optional a list of string SSL types state_dict_path : str, optional path to the model state dictionary to load strict : bool, default False when True, the state dictionary will only be loaded if the current and the loaded architecture are the same; otherwise missing or extra keys, as well as shape inconsistencies, are ignored
Get the shape of the feature extractor output
Returns
feature_shape : torch.Size or None shape of the feature extractor output (None when the model was built without feature_dim)
Inherited Members
- dlc2action.model.base_model.Model
- process_labels
- freeze_feature_extractor
- unfreeze_feature_extractor
- load_state_dict
- ssl_off
- ssl_on
- main_task_on
- main_task_off
- set_ssl
- extract_features
- forward
- torch.nn.modules.module.Module
- dump_patches
- register_buffer
- register_parameter
- add_module
- register_module
- get_submodule
- get_parameter
- get_buffer
- get_extra_state
- set_extra_state
- apply
- cuda
- ipu
- xpu
- cpu
- type
- float
- double
- half
- bfloat16
- to_empty
- to
- register_backward_hook
- register_full_backward_hook
- register_forward_pre_hook
- register_forward_hook
- T_destination
- state_dict
- register_load_state_dict_post_hook
- parameters
- named_parameters
- buffers
- named_buffers
- children
- named_children
- modules
- named_modules
- train
- eval
- requires_grad_
- zero_grad
- extra_repr