dlc2action.model.c2f_tcn

C2F-TCN

Adapted from https://github.com/dipika-singhania/C2F-TCN

#
# Copyright 2020-2022 by A. Mathis Group and contributors. All rights reserved.
#
# This project and all its files are licensed under GNU AGPLv3 or later version. A copy is included in dlc2action/LICENSE.AGPL.
#
#
# Adapted from C2F-TCN by Dipika Singhania
# Adapted from https://github.com/dipika-singhania/C2F-TCN
# Licensed under MIT License
#
""" C2F-TCN

Adapted from https://github.com/dipika-singhania/C2F-TCN
"""

import torch.nn.functional as F
import torch.nn as nn
import torch
from functools import partial
from dlc2action.model.base_model import Model
from typing import Union, List, Optional

nonlinearity = partial(F.relu, inplace=True)

class _double_conv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(_double_conv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv1d(in_ch, out_ch, kernel_size=5, padding=2),
            nn.BatchNorm1d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv1d(out_ch, out_ch, kernel_size=5, padding=2),
            nn.BatchNorm1d(out_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        x = self.conv(x)
        return x


class _inconv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(_inconv, self).__init__()
        self.conv = _double_conv(in_ch, out_ch)

    def forward(self, x):
        x = self.conv(x)
        return x


class _outconv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(_outconv, self).__init__()
        self.conv = nn.Conv1d(in_ch, out_ch, kernel_size=1)

    def forward(self, x):
        x = self.conv(x)
        return x


class _down(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(_down, self).__init__()
        self.max_pool_conv = nn.Sequential(nn.MaxPool1d(2), _double_conv(in_ch, out_ch))

    def forward(self, x):
        x = self.max_pool_conv(x)
        return x

class _up(nn.Module):
    """Upscaling then double conv"""

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()

        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode="linear", align_corners=True)
        else:
            self.up = nn.ConvTranspose1d(
                in_channels // 2, in_channels // 2, kernel_size=2, stride=2
            )

        self.conv = _double_conv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        # inputs are (N, C, T); pad x1 along the temporal dimension to match x2
        diff = x2.size(2) - x1.size(2)

        x1 = F.pad(x1, [diff // 2, diff - diff // 2])
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)

class _TPPblock(nn.Module):
    """Temporal pyramid pooling: pool the input at several temporal scales,
    project each pooled map to a single channel and concatenate the
    re-upsampled results with the input."""

    def __init__(self, in_channels):
        super(_TPPblock, self).__init__()
        self.pool1 = nn.MaxPool1d(kernel_size=2, stride=2)
        self.pool2 = nn.MaxPool1d(kernel_size=3, stride=3)
        self.pool3 = nn.MaxPool1d(kernel_size=5, stride=5)
        self.pool4 = nn.MaxPool1d(kernel_size=6, stride=6)

        self.conv = nn.Conv1d(
            in_channels=in_channels, out_channels=1, kernel_size=1, padding=0
        )

    def forward(self, x):
        self.in_channels, t = x.size(1), x.size(2)
        self.layer1 = F.interpolate(
            self.conv(self.pool1(x)), size=t, mode="linear", align_corners=True
        )
        self.layer2 = F.interpolate(
            self.conv(self.pool2(x)), size=t, mode="linear", align_corners=True
        )
        self.layer3 = F.interpolate(
            self.conv(self.pool3(x)), size=t, mode="linear", align_corners=True
        )
        self.layer4 = F.interpolate(
            self.conv(self.pool4(x)), size=t, mode="linear", align_corners=True
        )

        out = torch.cat([self.layer1, self.layer2, self.layer3, self.layer4, x], 1)

        return out

class _Predictor(nn.Module):
    def __init__(self, dim, num_classes):
        super(_Predictor, self).__init__()
        self.num_classes = num_classes
        self.conv_out_1 = nn.Conv1d(dim, dim, kernel_size=1)
        self.conv_out_2 = nn.Conv1d(dim, num_classes, kernel_size=1)

    def forward(self, x):
        x = self.conv_out_1(x)
        x = F.relu(x)
        x = self.conv_out_2(x)
        # the leading 4 corresponds to the four stacked decoder outputs
        x = x.reshape((4, -1, self.num_classes, x.shape[-1]))
        return x

class _C2F_TCN_Module(nn.Module):
    """
    Features are extracted at the last layer of decoder.
    """

    def __init__(self, n_channels, output_dim, num_f_maps, use_predictor=False):
        super().__init__()
        self.use_predictor = use_predictor
        self.inc = _inconv(n_channels, num_f_maps * 2)
        self.down1 = _down(num_f_maps * 2, num_f_maps * 2)
        self.down2 = _down(num_f_maps * 2, num_f_maps * 2)
        self.down3 = _down(num_f_maps * 2, num_f_maps)
        self.down4 = _down(num_f_maps, num_f_maps)
        self.down5 = _down(num_f_maps, num_f_maps)
        self.down6 = _down(num_f_maps, num_f_maps)
        # the TPP block adds 4 single-channel pyramid levels, hence the "+ 4"
        self.up = _up(num_f_maps * 2 + 4, num_f_maps)
        self.outcc0 = _outconv(num_f_maps, output_dim)
        self.up0 = _up(num_f_maps * 2, num_f_maps)
        self.outcc1 = _outconv(num_f_maps, output_dim)
        self.up1 = _up(num_f_maps * 2, num_f_maps)
        self.outcc2 = _outconv(num_f_maps, output_dim)
        self.up2 = _up(num_f_maps * 3, num_f_maps)
        self.outcc3 = _outconv(num_f_maps, output_dim)
        self.up3 = _up(num_f_maps * 3, num_f_maps)
        self.outcc4 = _outconv(num_f_maps, output_dim)
        self.up4 = _up(num_f_maps * 3, num_f_maps)
        self.outcc = _outconv(num_f_maps, output_dim)
        self.tpp = _TPPblock(num_f_maps)
        self.weights = torch.nn.Parameter(torch.ones(6))

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x6 = self.down5(x5)
        x7 = self.down6(x6)
        # x7 = self.dac(x7)
        x7 = self.tpp(x7)
        x = self.up(x7, x6)
        y1 = self.outcc0(F.relu(x))
        # print("y1.shape=", y1.shape)
        x = self.up0(x, x5)
        y2 = self.outcc1(F.relu(x))
        # print("y2.shape=", y2.shape)
        x = self.up1(x, x4)
        y3 = self.outcc2(F.relu(x))
        # print("y3.shape=", y3.shape)
        x = self.up2(x, x3)
        y4 = self.outcc3(F.relu(x))
        # print("y4.shape=", y4.shape)
        x = self.up3(x, x2)
        y5 = self.outcc4(F.relu(x))
        # print("y5.shape=", y5.shape)
        x = self.up4(x, x1)
        y = self.outcc(x)
        # print("y.shape=", y.shape)
        # upsample the coarser predictions to full temporal resolution and stack them
        output = [y]
        for outp_ele in [y5, y4, y3]:
            output.append(
                F.interpolate(
                    outp_ele, size=y.shape[-1], mode="linear", align_corners=True
                )
            )
        output = torch.stack(output, dim=0)
        if self.use_predictor:
            K, B, C, T = output.shape
            output = output.reshape((-1, C, T))
        return output

class C2F_TCN(Model):
    """
    An implementation of C2F-TCN

    Requires the `"general/len_segment"` parameter to be at least 512
    """

    def __init__(
        self,
        num_classes,
        input_dims,
        num_f_maps=128,
        feature_dim=None,
        state_dict_path=None,
        ssl_constructors=None,
        ssl_types=None,
        ssl_modules=None,
    ):
        input_dims = int(sum([s[0] for s in input_dims.values()]))
        if feature_dim is None:
            feature_dim = num_classes
            self.f_shape = None
            self.params_predictor = None
        else:
            self.f_shape = torch.Size([int(feature_dim)])
            self.params_predictor = {
                "dim": int(feature_dim),
                "num_classes": num_classes,
            }
        self.params = {
            "output_dim": int(feature_dim),
            "n_channels": int(input_dims),
            "num_f_maps": int(num_f_maps),
            "use_predictor": self.f_shape is not None,
        }
        super().__init__(ssl_constructors, ssl_modules, ssl_types, state_dict_path)

    def _feature_extractor(self) -> Union[torch.nn.Module, List]:
        return _C2F_TCN_Module(**self.params)

    def _predictor(self) -> torch.nn.Module:
        if self.params_predictor is not None:
            return _Predictor(**self.params_predictor)
        else:
            return nn.Identity()

    def features_shape(self) -> Optional[torch.Size]:
        return self.f_shape

class C2F_TCN(dlc2action.model.base_model.Model):

An implementation of C2F-TCN

Requires the "general/len_segment" parameter to be at least 512

C2F_TCN(num_classes, input_dims, num_f_maps=128, feature_dim=None, state_dict_path=None, ssl_constructors=None, ssl_types=None, ssl_modules=None)

Parameters

ssl_constructors : list, optional
    a list of SSL constructors that build the necessary SSL modules
ssl_modules : list, optional
    a list of torch.nn.Module instances that will serve as SSL modules
ssl_types : list, optional
    a list of string SSL types
state_dict_path : str, optional
    path to the model state dictionary to load
strict : bool, default False
    when True, the state dictionary will only be loaded if the current and the loaded architecture are the same; otherwise missing or extra keys, as well as shape inconsistencies, are ignored
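
For orientation, a minimal usage sketch. The feature names, channel counts and batch size below are illustrative assumptions, not values from dlc2action; only the first entry of each input_dims shape (the channel count) is used by the constructor, and the segment length follows the documented minimum of 512.

import torch

from dlc2action.model.c2f_tcn import C2F_TCN

# hypothetical input description: 34 + 17 = 51 input channels in total
input_dims = {"coordinates": (34,), "speeds": (17,)}

model = C2F_TCN(num_classes=6, input_dims=input_dims, num_f_maps=128)

# "general/len_segment" must be at least 512
x = torch.randn(8, 51, 512)  # (batch, channels, frames)

# _feature_extractor() builds a fresh _C2F_TCN_Module from the stored parameters;
# in training one would go through the inherited Model.forward / extract_features
# instead, but this shows what the underlying network produces
module = model._feature_extractor()
output = module(x)
print(output.shape)  # torch.Size([4, 8, 6, 512]): four decoder resolutions stacked

With feature_dim left at None, the module's output dimension equals num_classes and the predictor is an nn.Identity, so the stacked tensor above already contains the class scores.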

def features_shape(self) -> Optional[torch.Size]:

Get the shape of the feature extractor output

Returns

feature_shape : torch.Size
    shape of the feature extractor output
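
A small illustration of the two cases (the numbers are placeholders): the shape is only reported when feature_dim was passed to the constructor, otherwise the method returns None.

from dlc2action.model.c2f_tcn import C2F_TCN

# feature_dim left at None: no separate feature shape is reported
C2F_TCN(num_classes=6, input_dims={"coordinates": (34,)}).features_shape()  # None

# feature_dim set: features have feature_dim channels per frame
C2F_TCN(
    num_classes=6, input_dims={"coordinates": (34,)}, feature_dim=256
).features_shape()  # torch.Size([256])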

Inherited Members
dlc2action.model.base_model.Model
process_labels
freeze_feature_extractor
unfreeze_feature_extractor
load_state_dict
ssl_off
ssl_on
main_task_on
main_task_off
set_ssl
extract_features
forward