dlc2action.model.c2f_tcn
C2F-TCN
Adapted from https://github.com/dipika-singhania/C2F-TCN
#
# Copyright 2020-2022 by A. Mathis Group and contributors. All rights reserved.
#
# This project and all its files are licensed under GNU AGPLv3 or later version. A copy is included in dlc2action/LICENSE.AGPL.
#
#
# Adapted from C2F-TCN by dipika singhania
# Adapted from https://github.com/dipika-singhania/C2F-TCN
# Licensed under MIT License
#
""" C2F-TCN

Adapted from https://github.com/dipika-singhania/C2F-TCN
"""

import torch.nn.functional as F
import torch.nn as nn
import torch
from functools import partial
from dlc2action.model.base_model import Model
from typing import Union, List, Optional

nonlinearity = partial(F.relu, inplace=True)


class _double_conv(nn.Module):
    """Two (Conv1d k=5 -> BatchNorm1d -> ReLU) stages.

    Kernel size 5 with padding 2 and stride 1 keeps the temporal length unchanged.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv1d(in_ch, out_ch, kernel_size=5, padding=2),
            nn.BatchNorm1d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv1d(out_ch, out_ch, kernel_size=5, padding=2),
            nn.BatchNorm1d(out_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.conv(x)


class _inconv(nn.Module):
    """Input stem: a single `_double_conv` from `in_ch` to `out_ch` channels."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = _double_conv(in_ch, out_ch)

    def forward(self, x):
        return self.conv(x)


class _outconv(nn.Module):
    """Pointwise (kernel size 1) Conv1d projection from `in_ch` to `out_ch` channels."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Conv1d(in_ch, out_ch, kernel_size=1)

    def forward(self, x):
        return self.conv(x)


class _down(nn.Module):
    """Encoder stage: halve the temporal length with MaxPool1d(2), then `_double_conv`."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.max_pool_conv = nn.Sequential(nn.MaxPool1d(2), _double_conv(in_ch, out_ch))

    def forward(self, x):
        return self.max_pool_conv(x)


class _up(nn.Module):
    """Decoder stage: upsample x2, align with the skip connection, concatenate, double conv.

    Parameters
    ----------
    in_channels : int
        number of channels after concatenating the upsampled input with the skip tensor
    out_channels : int
        number of output channels
    bilinear : bool, default True
        if True, upsample with linear interpolation; otherwise with a transposed convolution
    """

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()

        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode="linear", align_corners=True)
        else:
            self.up = nn.ConvTranspose1d(
                in_channels // 2, in_channels // 2, kernel_size=2, stride=2
            )

        self.conv = _double_conv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        # Odd lengths are floored by max-pooling in the encoder, so the upsampled tensor
        # can differ in length from the skip tensor. Pad (or crop, if diff < 0) it to match.
        # Plain ints are used for the pad sizes (no per-call tensor allocation).
        diff = x2.size(2) - x1.size(2)
        x1 = F.pad(x1, [diff // 2, diff - diff // 2])
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)


class _TPPblock(nn.Module):
    """Temporal pyramid pooling.

    Max-pools the input at four scales (kernel = stride = 2, 3, 5, 6). Each pooled
    tensor is projected to one channel by a shared 1x1 convolution and linearly
    resized back to the input length. The four maps are concatenated in front of the
    input, so the output has `in_channels + 4` channels.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.pool1 = nn.MaxPool1d(kernel_size=2, stride=2)
        self.pool2 = nn.MaxPool1d(kernel_size=3, stride=3)
        self.pool3 = nn.MaxPool1d(kernel_size=5, stride=5)
        self.pool4 = nn.MaxPool1d(kernel_size=6, stride=6)

        self.conv = nn.Conv1d(
            in_channels=in_channels, out_channels=1, kernel_size=1, padding=0
        )

    def forward(self, x):
        t = x.size(2)
        # Locals only: the activations are not stored on the module, so they are
        # not kept alive (with their autograd graph) after forward returns.
        pyramid = [
            F.interpolate(
                self.conv(pool(x)), size=t, mode="linear", align_corners=True
            )
            for pool in (self.pool1, self.pool2, self.pool3, self.pool4)
        ]
        return torch.cat(pyramid + [x], 1)


class _Predictor(nn.Module):
    """Two pointwise convolutions mapping `dim` features to `num_classes` logits.

    The input is expected to be the flattened stack produced by `_C2F_TCN_Module`
    with `use_predictor=True`. It is reshaped back to
    `(num_outputs, batch, num_classes, T)`.

    Parameters
    ----------
    dim : int
        number of input feature channels
    num_classes : int
        number of output classes
    num_outputs : int, default 4
        number of stacked outputs folded into the batch dimension (the feature
        extractor stacks 4: the final output plus three upsampled intermediates)
    """

    def __init__(self, dim, num_classes, num_outputs=4):
        super().__init__()
        self.num_classes = num_classes
        self.num_outputs = num_outputs
        self.conv_out_1 = nn.Conv1d(dim, dim, kernel_size=1)
        self.conv_out_2 = nn.Conv1d(dim, num_classes, kernel_size=1)

    def forward(self, x):
        x = self.conv_out_1(x)
        x = F.relu(x)
        x = self.conv_out_2(x)
        return x.reshape((self.num_outputs, -1, self.num_classes, x.shape[-1]))


class _C2F_TCN_Module(nn.Module):
    """
    Coarse-to-fine U-Net-style temporal convolutional network.

    Features are extracted at the last layer of decoder.

    The encoder has an input stem and six `_down` stages, followed by a TPP
    bottleneck. The decoder has six `_up` stages with skip connections. The output
    stacks the final decoder projection with the projections of the third-, fourth-
    and fifth-from-bottom decoder stages, each linearly resized to the final length.
    The stacked shape is `(4, B, output_dim, T)`. With `use_predictor=True` it is
    flattened to `(4 * B, output_dim, T)` for `_Predictor`.
    """

    def __init__(self, n_channels, output_dim, num_f_maps, use_predictor=False):
        super().__init__()
        self.use_predictor = use_predictor
        self.inc = _inconv(n_channels, num_f_maps * 2)
        self.down1 = _down(num_f_maps * 2, num_f_maps * 2)
        self.down2 = _down(num_f_maps * 2, num_f_maps * 2)
        self.down3 = _down(num_f_maps * 2, num_f_maps)
        self.down4 = _down(num_f_maps, num_f_maps)
        self.down5 = _down(num_f_maps, num_f_maps)
        self.down6 = _down(num_f_maps, num_f_maps)
        # +4 accounts for the four pyramid channels added by _TPPblock
        self.up = _up(num_f_maps * 2 + 4, num_f_maps)
        # outcc0 / outcc1 outputs are not part of the returned stack; the modules
        # are kept so that existing state dicts still load.
        self.outcc0 = _outconv(num_f_maps, output_dim)
        self.up0 = _up(num_f_maps * 2, num_f_maps)
        self.outcc1 = _outconv(num_f_maps, output_dim)
        self.up1 = _up(num_f_maps * 2, num_f_maps)
        self.outcc2 = _outconv(num_f_maps, output_dim)
        self.up2 = _up(num_f_maps * 3, num_f_maps)
        self.outcc3 = _outconv(num_f_maps, output_dim)
        self.up3 = _up(num_f_maps * 3, num_f_maps)
        self.outcc4 = _outconv(num_f_maps, output_dim)
        self.up4 = _up(num_f_maps * 3, num_f_maps)
        self.outcc = _outconv(num_f_maps, output_dim)
        self.tpp = _TPPblock(num_f_maps)
        # not used in forward; kept for state-dict compatibility
        self.weights = torch.nn.Parameter(torch.ones(6))

    def forward(self, x):
        # encoder
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x6 = self.down5(x5)
        x7 = self.tpp(self.down6(x6))

        # decoder; the two coarsest stages produce no returned output, so their
        # projections (outcc0 / outcc1) are not computed
        x = self.up(x7, x6)
        x = self.up0(x, x5)
        x = self.up1(x, x4)
        y3 = self.outcc2(F.relu(x))
        x = self.up2(x, x3)
        y4 = self.outcc3(F.relu(x))
        x = self.up3(x, x2)
        y5 = self.outcc4(F.relu(x))
        x = self.up4(x, x1)
        y = self.outcc(x)

        # resize intermediate outputs to the final temporal length and stack them
        length = y.shape[-1]
        output = [y] + [
            F.interpolate(out, size=length, mode="linear", align_corners=True)
            for out in (y5, y4, y3)
        ]
        output = torch.stack(output, dim=0)
        if self.use_predictor:
            # fold the stack dimension into the batch dimension for _Predictor
            _, _, channels, length = output.shape
            output = output.reshape((-1, channels, length))
        return output


class C2F_TCN(Model):
    """
    An implementation of C2F-TCN

    Requires the `"general/len_segment"` parameter to be at least 512

    Parameters
    ----------
    num_classes : int
        number of output classes
    input_dims : dict
        mapping whose values are shape sequences; their first elements are summed
        to get the number of input channels
    num_f_maps : int, default 128
        base number of feature maps in the encoder-decoder
    feature_dim : int, optional
        if given, the feature extractor outputs `feature_dim` channels and a
        separate `_Predictor` head maps them to `num_classes`; otherwise the
        feature extractor outputs `num_classes` channels and the predictor is
        an identity
    state_dict_path, ssl_constructors, ssl_types, ssl_modules
        forwarded to the base `Model` constructor
    """

    def __init__(
        self,
        num_classes,
        input_dims,
        num_f_maps=128,
        feature_dim=None,
        state_dict_path=None,
        ssl_constructors=None,
        ssl_types=None,
        ssl_modules=None,
    ):
        input_dims = int(sum([s[0] for s in input_dims.values()]))
        if feature_dim is None:
            feature_dim = num_classes
            self.f_shape = None
            self.params_predictor = None
        else:
            self.f_shape = torch.Size([int(feature_dim)])
            self.params_predictor = {
                "dim": int(feature_dim),
                # cast like every other dimension passed to a layer constructor
                "num_classes": int(num_classes),
            }
        # these attributes are set before the base constructor runs;
        # presumably Model.__init__ builds the submodules from them — confirm
        self.params = {
            "output_dim": int(feature_dim),
            "n_channels": int(input_dims),
            "num_f_maps": int(num_f_maps),
            "use_predictor": self.f_shape is not None,
        }
        super().__init__(ssl_constructors, ssl_modules, ssl_types, state_dict_path)

    def _feature_extractor(self) -> Union[torch.nn.Module, List]:
        """Build the C2F-TCN encoder-decoder from `self.params`."""
        return _C2F_TCN_Module(**self.params)

    def _predictor(self) -> torch.nn.Module:
        """Build the classification head (an identity when no `feature_dim` was given)."""
        if self.params_predictor is not None:
            return _Predictor(**self.params_predictor)
        return nn.Identity()

    def features_shape(self) -> Optional[torch.Size]:
        """Return the feature extractor output shape, or None when `feature_dim` was not set."""
        return self.f_shape
class C2F_TCN(Model):
    """
    An implementation of C2F-TCN

    Requires the `"general/len_segment"` parameter to be at least 512

    Parameters
    ----------
    num_classes : int
        number of output classes
    input_dims : dict
        mapping whose values are shape sequences; their first elements are summed
        to get the number of input channels
    num_f_maps : int, default 128
        base number of feature maps in the encoder-decoder
    feature_dim : int, optional
        if given, the feature extractor outputs `feature_dim` channels and a
        separate `_Predictor` head maps them to `num_classes`; otherwise the
        feature extractor outputs `num_classes` channels and the predictor is
        an identity
    state_dict_path, ssl_constructors, ssl_types, ssl_modules
        forwarded to the base `Model` constructor
    """

    def __init__(
        self,
        num_classes,
        input_dims,
        num_f_maps=128,
        feature_dim=None,
        state_dict_path=None,
        ssl_constructors=None,
        ssl_types=None,
        ssl_modules=None,
    ):
        input_dims = int(sum([s[0] for s in input_dims.values()]))
        if feature_dim is None:
            feature_dim = num_classes
            self.f_shape = None
            self.params_predictor = None
        else:
            self.f_shape = torch.Size([int(feature_dim)])
            self.params_predictor = {
                "dim": int(feature_dim),
                # cast like every other dimension passed to a layer constructor
                "num_classes": int(num_classes),
            }
        # these attributes are set before the base constructor runs;
        # presumably Model.__init__ builds the submodules from them — confirm
        self.params = {
            "output_dim": int(feature_dim),
            "n_channels": int(input_dims),
            "num_f_maps": int(num_f_maps),
            "use_predictor": self.f_shape is not None,
        }
        super().__init__(ssl_constructors, ssl_modules, ssl_types, state_dict_path)

    def _feature_extractor(self) -> Union[torch.nn.Module, List]:
        """Build the C2F-TCN encoder-decoder from `self.params`."""
        return _C2F_TCN_Module(**self.params)

    def _predictor(self) -> torch.nn.Module:
        """Build the classification head (an identity when no `feature_dim` was given)."""
        if self.params_predictor is not None:
            return _Predictor(**self.params_predictor)
        return nn.Identity()

    def features_shape(self) -> Optional[torch.Size]:
        """Return the feature extractor output shape, or None when `feature_dim` was not set."""
        return self.f_shape
An implementation of C2F-TCN
Requires the "general/len_segment" parameter to be at least 512.
def __init__(
    self,
    num_classes,
    input_dims,
    num_f_maps=128,
    feature_dim=None,
    state_dict_path=None,
    ssl_constructors=None,
    ssl_types=None,
    ssl_modules=None,
):
    """Configure the C2F-TCN feature extractor and optional predictor head.

    Parameters
    ----------
    num_classes : int
        number of output classes
    input_dims : dict
        mapping whose values are shape sequences; their first elements are summed
        to get the number of input channels
    num_f_maps : int, default 128
        base number of feature maps in the encoder-decoder
    feature_dim : int, optional
        if given, the feature extractor outputs `feature_dim` channels and a
        separate predictor head maps them to `num_classes`; otherwise the feature
        extractor outputs `num_classes` channels directly
    state_dict_path, ssl_constructors, ssl_types, ssl_modules
        forwarded to the base `Model` constructor
    """
    input_dims = int(sum([s[0] for s in input_dims.values()]))
    if feature_dim is None:
        feature_dim = num_classes
        self.f_shape = None
        self.params_predictor = None
    else:
        self.f_shape = torch.Size([int(feature_dim)])
        self.params_predictor = {
            "dim": int(feature_dim),
            # cast like every other dimension passed to a layer constructor
            "num_classes": int(num_classes),
        }
    # these attributes are set before the base constructor runs;
    # presumably Model.__init__ builds the submodules from them — confirm
    self.params = {
        "output_dim": int(feature_dim),
        "n_channels": int(input_dims),
        "num_f_maps": int(num_f_maps),
        "use_predictor": self.f_shape is not None,
    }
    super().__init__(ssl_constructors, ssl_modules, ssl_types, state_dict_path)
Parameters
ssl_constructors : list, optional — a list of SSL constructors that build the necessary SSL modules; ssl_modules : list, optional — a list of torch.nn.Module instances that will serve as SSL modules; ssl_types : list, optional — a list of string SSL types; state_dict_path : str, optional — path to the model state dictionary to load; strict : bool, default False — when True, the state dictionary will only be loaded if the current and the loaded architectures are the same; otherwise missing or extra keys, as well as shape inconsistencies, are ignored.
Get the shape of feature extractor output
Returns
feature_shape : torch.Size or None — shape of the feature extractor output (None when no separate feature_dim was given)
Inherited Members
- dlc2action.model.base_model.Model
- process_labels
- freeze_feature_extractor
- unfreeze_feature_extractor
- load_state_dict
- ssl_off
- ssl_on
- main_task_on
- main_task_off
- set_ssl
- extract_features
- forward
- torch.nn.modules.module.Module
- dump_patches
- register_buffer
- register_parameter
- add_module
- register_module
- get_submodule
- get_parameter
- get_buffer
- get_extra_state
- set_extra_state
- apply
- cuda
- ipu
- xpu
- cpu
- type
- float
- double
- half
- bfloat16
- to_empty
- to
- register_backward_hook
- register_full_backward_hook
- register_forward_pre_hook
- register_forward_hook
- T_destination
- state_dict
- register_load_state_dict_post_hook
- parameters
- named_parameters
- buffers
- named_buffers
- children
- named_children
- modules
- named_modules
- train
- eval
- requires_grad_
- zero_grad
- extra_repr