dlc2action.model.c2f_tcn

  1#
  2# Copyright 2020-present by A. Mathis Group and contributors. All rights reserved.
  3#
  4# This project and all its files are licensed under GNU AGPLv3 or later version. A copy is included in dlc2action/LICENSE.AGPL.
  5#
  6# Incorporates code adapted from C2F-TCN by dipika-singhania
  7# Original work Copyright (c) 2021 dipika-singhania
  8# Source: https://github.com/dipika-singhania/C2F-TCN
  9# Originally licensed under MIT License
 10# Combined work licensed under GNU AGPLv3
 11#
 12import torch.nn.functional as F
 13import torch.nn as nn
 14import torch
 15from functools import partial
 16from dlc2action.model.base_model import Model
 17from typing import Union, List, Optional
 18
 19nonlinearity = partial(F.relu, inplace=True)
 20
 21
 22class double_conv(nn.Module):
 23    def __init__(self, in_ch, out_ch):
 24        super(double_conv, self).__init__()
 25        self.conv = nn.Sequential(
 26            nn.Conv1d(in_ch, out_ch, kernel_size=5, padding=2),
 27            nn.BatchNorm1d(out_ch),
 28            nn.ReLU(inplace=True),
 29            nn.Conv1d(out_ch, out_ch, kernel_size=5, padding=2),
 30            nn.BatchNorm1d(out_ch),
 31            nn.ReLU(inplace=True),
 32        )
 33
 34    def forward(self, x):
 35        """Forward pass."""
 36        x = self.conv(x)
 37        return x
 38
 39
 40class inconv(nn.Module):
 41    def __init__(self, in_ch, out_ch):
 42        super(inconv, self).__init__()
 43        self.conv = double_conv(in_ch, out_ch)
 44
 45    def forward(self, x):
 46        """Forward pass."""
 47        x = self.conv(x)
 48        return x
 49
 50
 51class outconv(nn.Module):
 52    def __init__(self, in_ch, out_ch):
 53        super(outconv, self).__init__()
 54        self.conv = nn.Conv1d(in_ch, out_ch, kernel_size=1)
 55
 56    def forward(self, x):
 57        """Forward pass."""
 58        x = self.conv(x)
 59        return x
 60
 61
 62class down(nn.Module):
 63    def __init__(self, in_ch, out_ch):
 64        super(down, self).__init__()
 65        self.max_pool_conv = nn.Sequential(nn.MaxPool1d(2), double_conv(in_ch, out_ch))
 66
 67    def forward(self, x):
 68        """Forward pass."""
 69        x = self.max_pool_conv(x)
 70        return x
 71
 72
 73class up(nn.Module):
 74    """Upscaling then double conv"""
 75
 76    def __init__(self, in_channels, out_channels, bilinear=True):
 77        super().__init__()
 78
 79        if bilinear:
 80            self.up = nn.Upsample(scale_factor=2, mode="linear", align_corners=True)
 81        else:
 82            self.up = nn.ConvTranspose1d(
 83                in_channels // 2, in_channels // 2, kernel_size=2, stride=2
 84            )
 85
 86        self.conv = double_conv(in_channels, out_channels)
 87
 88    def forward(self, x1, x2):
 89        """Forward pass."""
 90        x1 = self.up(x1)
 91        # input is CHW
 92        diff = torch.tensor([x2.size()[2] - x1.size()[2]])
 93
 94        x1 = F.pad(x1, [diff // 2, diff - diff // 2])
 95        x = torch.cat([x2, x1], dim=1)
 96        return self.conv(x)
 97
 98
 99class TPPblock(nn.Module):
100    def __init__(self, in_channels):
101        super(TPPblock, self).__init__()
102        self.pool1 = nn.MaxPool1d(kernel_size=2, stride=2)
103        self.pool2 = nn.MaxPool1d(kernel_size=3, stride=3)
104        self.pool3 = nn.MaxPool1d(kernel_size=5, stride=5)
105        self.pool4 = nn.MaxPool1d(kernel_size=6, stride=6)
106
107        self.conv = nn.Conv1d(
108            in_channels=in_channels, out_channels=1, kernel_size=1, padding=0
109        )
110
111    def forward(self, x):
112        """Forward pass."""
113        self.in_channels, t = x.size(1), x.size(2)
114        self.layer1 = F.interpolate(
115            self.conv(self.pool1(x)), size=t, mode="linear", align_corners=True
116        )
117        self.layer2 = F.interpolate(
118            self.conv(self.pool2(x)), size=t, mode="linear", align_corners=True
119        )
120        self.layer3 = F.interpolate(
121            self.conv(self.pool3(x)), size=t, mode="linear", align_corners=True
122        )
123        self.layer4 = F.interpolate(
124            self.conv(self.pool4(x)), size=t, mode="linear", align_corners=True
125        )
126
127        out = torch.cat([self.layer1, self.layer2, self.layer3, self.layer4, x], 1)
128
129        return out
130
131
class Predictor(nn.Module):
    """Classification head applied to stacked per-stage features.

    Two pointwise convolutions (``dim -> dim``, ReLU, ``dim -> num_classes``) are
    applied to an input of shape ``(num_stages * batch, dim, time)``; the result
    is reshaped to ``(num_stages, batch, num_classes, time)``.

    Parameters
    ----------
    dim : int
        number of input feature channels
    num_classes : int
        number of output classes
    num_stages : int, default 4
        number of stacked outputs folded into the batch dimension of the input
        (``C2F_TCN_Module`` stacks 4)
    """

    def __init__(self, dim, num_classes, num_stages=4):
        super(Predictor, self).__init__()
        self.num_classes = num_classes
        self.num_stages = num_stages
        self.conv_out_1 = nn.Conv1d(dim, dim, kernel_size=1)
        self.conv_out_2 = nn.Conv1d(dim, num_classes, kernel_size=1)

    def forward(self, x):
        """Forward pass."""
        x = self.conv_out_1(x)
        x = F.relu(x)
        x = self.conv_out_2(x)
        # split the stage dimension back out of the merged batch dimension
        return x.reshape((self.num_stages, -1, self.num_classes, x.shape[-1]))
146
147
class C2F_TCN_Module(nn.Module):
    """
    Features are extracted at the last layer of decoder.

    A U-Net-style temporal network: an input ``double_conv`` stage and six
    max-pooling encoder stages, a temporal pyramid pooling block at the
    bottleneck, and six upsampling decoder stages with skip connections. The
    output of the last decoder stage and the heads of three intermediate decoder
    stages (linearly interpolated to the full temporal length) are stacked.

    Parameters
    ----------
    n_channels : int
        number of input channels
    output_dim : int
        number of channels of each stacked output
    num_f_maps : int
        base number of feature maps
    use_predictor : bool, default False
        if True, the stage dimension of the output is merged into the batch
        dimension
    """

    def __init__(self, n_channels, output_dim, num_f_maps, use_predictor=False):
        super().__init__()
        self.use_predictor = use_predictor
        # encoder
        self.inc = inconv(n_channels, num_f_maps * 2)
        self.down1 = down(num_f_maps * 2, num_f_maps * 2)
        self.down2 = down(num_f_maps * 2, num_f_maps * 2)
        self.down3 = down(num_f_maps * 2, num_f_maps)
        self.down4 = down(num_f_maps, num_f_maps)
        self.down5 = down(num_f_maps, num_f_maps)
        self.down6 = down(num_f_maps, num_f_maps)
        # decoder; ``+ 4`` accounts for the four pyramid channels added by TPPblock
        self.up = up(num_f_maps * 2 + 4, num_f_maps)
        self.outcc0 = outconv(num_f_maps, output_dim)
        self.up0 = up(num_f_maps * 2, num_f_maps)
        self.outcc1 = outconv(num_f_maps, output_dim)
        self.up1 = up(num_f_maps * 2, num_f_maps)
        self.outcc2 = outconv(num_f_maps, output_dim)
        self.up2 = up(num_f_maps * 3, num_f_maps)
        self.outcc3 = outconv(num_f_maps, output_dim)
        self.up3 = up(num_f_maps * 3, num_f_maps)
        self.outcc4 = outconv(num_f_maps, output_dim)
        self.up4 = up(num_f_maps * 3, num_f_maps)
        self.outcc = outconv(num_f_maps, output_dim)
        self.tpp = TPPblock(num_f_maps)
        # NOTE(review): ``outcc0``, ``outcc1`` and ``weights`` do not contribute to
        # the output of ``forward``; presumably kept so existing checkpoints keep
        # loading - confirm before removing.
        self.weights = torch.nn.Parameter(torch.ones(6))

    def forward(self, x):
        """Forward pass.

        Parameters
        ----------
        x : torch.Tensor
            input of shape ``(batch, n_channels, time)``; ``time`` must be at
            least 384 so that six halvings leave at least 6 steps for the
            coarsest pooling in ``TPPblock``

        Returns
        -------
        output : torch.Tensor
            tensor of shape ``(4, batch, output_dim, time)`` (final decoder output
            followed by three intermediate heads, finest first), or
            ``(4 * batch, output_dim, time)`` when ``use_predictor`` is True
        """
        # encoder
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x6 = self.down5(x5)
        x7 = self.tpp(self.down6(x6))
        # decoder; the heads of the two coarsest stages (outcc0, outcc1) are not
        # part of the returned output, so they are not evaluated
        x = self.up(x7, x6)
        x = self.up0(x, x5)
        x = self.up1(x, x4)
        y3 = self.outcc2(F.relu(x))
        x = self.up2(x, x3)
        y4 = self.outcc3(F.relu(x))
        x = self.up3(x, x2)
        y5 = self.outcc4(F.relu(x))
        x = self.up4(x, x1)
        y = self.outcc(x)
        length = y.shape[-1]
        output = [y] + [
            F.interpolate(head, size=length, mode="linear", align_corners=True)
            for head in (y5, y4, y3)
        ]
        output = torch.stack(output, dim=0)
        if self.use_predictor:
            # (stages, batch, C, T) -> (stages * batch, C, T)
            output = output.flatten(0, 1)
        return output
219
220
class C2F_TCN(Model):
    """
    An implementation of C2F-TCN

    Requires the `"general/len_segment"` parameter to be at least 512
    """

    def __init__(
        self,
        num_classes: int,
        input_dims: dict,
        num_f_maps: int = 128,
        feature_dim: Optional[int] = None,
        state_dict_path: Optional[str] = None,
        ssl_constructors: Optional[List] = None,
        ssl_types: Optional[List] = None,
        ssl_modules: Optional[List] = None,
    ):
        """Initialize the model.

        Parameters
        ----------
        num_classes : int
            number of output classes
        input_dims : dict
            input feature dimensions; the first element of every value is summed
            to obtain the number of input channels
        num_f_maps : int, default 128
            base number of feature maps of the feature extractor
        feature_dim : int, optional
            if given, the feature extractor outputs ``feature_dim`` channels and a
            ``Predictor`` head maps them to ``num_classes``; otherwise the feature
            extractor outputs ``num_classes`` channels and the predictor is an
            identity
        state_dict_path, ssl_constructors, ssl_types, ssl_modules
            passed on to the base ``Model`` constructor
        """
        n_channels = int(sum(s[0] for s in input_dims.values()))
        # cast like the other numeric settings, so a value read from a config
        # does not reach Predictor / Conv1d as a non-int
        num_classes = int(num_classes)
        # these attributes are set before the base constructor runs because the
        # builder methods below read them
        if feature_dim is None:
            feature_dim = num_classes
            self.f_shape = None
            self.params_predictor = None
        else:
            self.f_shape = torch.Size([int(feature_dim)])
            self.params_predictor = {
                "dim": int(feature_dim),
                "num_classes": num_classes,
            }
        self.params = {
            "output_dim": int(feature_dim),
            "n_channels": n_channels,
            "num_f_maps": int(float(num_f_maps)),
            "use_predictor": self.f_shape is not None,
        }
        super().__init__(ssl_constructors, ssl_modules, ssl_types, state_dict_path)

    def _feature_extractor(self) -> Union[torch.nn.Module, List]:
        """Build the ``C2F_TCN_Module`` feature extractor from ``self.params``."""
        return C2F_TCN_Module(**self.params)

    def _predictor(self) -> torch.nn.Module:
        """Build a ``Predictor`` head if ``feature_dim`` was given, otherwise an identity."""
        if self.params_predictor is not None:
            return Predictor(**self.params_predictor)
        else:
            return nn.Identity()

    def features_shape(self) -> Optional[torch.Size]:
        """Get the shape of feature extractor output.

        Returns
        -------
        feature_shape : torch.Size or None
            ``torch.Size([feature_dim])`` when ``feature_dim`` was given, else None
        """
        return self.f_shape
nonlinearity = functools.partial(<function relu>, inplace=True)
class double_conv(torch.nn.modules.module.Module):
23class double_conv(nn.Module):
24    def __init__(self, in_ch, out_ch):
25        super(double_conv, self).__init__()
26        self.conv = nn.Sequential(
27            nn.Conv1d(in_ch, out_ch, kernel_size=5, padding=2),
28            nn.BatchNorm1d(out_ch),
29            nn.ReLU(inplace=True),
30            nn.Conv1d(out_ch, out_ch, kernel_size=5, padding=2),
31            nn.BatchNorm1d(out_ch),
32            nn.ReLU(inplace=True),
33        )
34
35    def forward(self, x):
36        """Forward pass."""
37        x = self.conv(x)
38        return x

Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::

import torch.nn as nn
import torch.nn.functional as F


class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will also have their parameters converted when you call to(), etc.

As per the example above, an __init__() call to the parent class must be made before assignment on the child.

:ivar training: Boolean represents whether this module is in training or evaluation mode. :vartype training: bool

double_conv(in_ch, out_ch)
24    def __init__(self, in_ch, out_ch):
25        super(double_conv, self).__init__()
26        self.conv = nn.Sequential(
27            nn.Conv1d(in_ch, out_ch, kernel_size=5, padding=2),
28            nn.BatchNorm1d(out_ch),
29            nn.ReLU(inplace=True),
30            nn.Conv1d(out_ch, out_ch, kernel_size=5, padding=2),
31            nn.BatchNorm1d(out_ch),
32            nn.ReLU(inplace=True),
33        )

Initialize internal Module state, shared by both nn.Module and ScriptModule.

conv
def forward(self, x):
35    def forward(self, x):
36        """Forward pass."""
37        x = self.conv(x)
38        return x

Forward pass.

class inconv(torch.nn.modules.module.Module):
41class inconv(nn.Module):
42    def __init__(self, in_ch, out_ch):
43        super(inconv, self).__init__()
44        self.conv = double_conv(in_ch, out_ch)
45
46    def forward(self, x):
47        """Forward pass."""
48        x = self.conv(x)
49        return x

Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::

import torch.nn as nn
import torch.nn.functional as F


class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will also have their parameters converted when you call to(), etc.

As per the example above, an __init__() call to the parent class must be made before assignment on the child.

:ivar training: Boolean represents whether this module is in training or evaluation mode. :vartype training: bool

inconv(in_ch, out_ch)
42    def __init__(self, in_ch, out_ch):
43        super(inconv, self).__init__()
44        self.conv = double_conv(in_ch, out_ch)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

conv
def forward(self, x):
46    def forward(self, x):
47        """Forward pass."""
48        x = self.conv(x)
49        return x

Forward pass.

class outconv(torch.nn.modules.module.Module):
52class outconv(nn.Module):
53    def __init__(self, in_ch, out_ch):
54        super(outconv, self).__init__()
55        self.conv = nn.Conv1d(in_ch, out_ch, kernel_size=1)
56
57    def forward(self, x):
58        """Forward pass."""
59        x = self.conv(x)
60        return x

Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::

import torch.nn as nn
import torch.nn.functional as F


class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will also have their parameters converted when you call to(), etc.

As per the example above, an __init__() call to the parent class must be made before assignment on the child.

:ivar training: Boolean represents whether this module is in training or evaluation mode. :vartype training: bool

outconv(in_ch, out_ch)
53    def __init__(self, in_ch, out_ch):
54        super(outconv, self).__init__()
55        self.conv = nn.Conv1d(in_ch, out_ch, kernel_size=1)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

conv
def forward(self, x):
57    def forward(self, x):
58        """Forward pass."""
59        x = self.conv(x)
60        return x

Forward pass.

class down(torch.nn.modules.module.Module):
63class down(nn.Module):
64    def __init__(self, in_ch, out_ch):
65        super(down, self).__init__()
66        self.max_pool_conv = nn.Sequential(nn.MaxPool1d(2), double_conv(in_ch, out_ch))
67
68    def forward(self, x):
69        """Forward pass."""
70        x = self.max_pool_conv(x)
71        return x

Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::

import torch.nn as nn
import torch.nn.functional as F


class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will also have their parameters converted when you call to(), etc.

As per the example above, an __init__() call to the parent class must be made before assignment on the child.

:ivar training: Boolean represents whether this module is in training or evaluation mode. :vartype training: bool

down(in_ch, out_ch)
64    def __init__(self, in_ch, out_ch):
65        super(down, self).__init__()
66        self.max_pool_conv = nn.Sequential(nn.MaxPool1d(2), double_conv(in_ch, out_ch))

Initialize internal Module state, shared by both nn.Module and ScriptModule.

max_pool_conv
def forward(self, x):
68    def forward(self, x):
69        """Forward pass."""
70        x = self.max_pool_conv(x)
71        return x

Forward pass.

class up(torch.nn.modules.module.Module):
74class up(nn.Module):
75    """Upscaling then double conv"""
76
77    def __init__(self, in_channels, out_channels, bilinear=True):
78        super().__init__()
79
80        if bilinear:
81            self.up = nn.Upsample(scale_factor=2, mode="linear", align_corners=True)
82        else:
83            self.up = nn.ConvTranspose1d(
84                in_channels // 2, in_channels // 2, kernel_size=2, stride=2
85            )
86
87        self.conv = double_conv(in_channels, out_channels)
88
89    def forward(self, x1, x2):
90        """Forward pass."""
91        x1 = self.up(x1)
92        # input is CHW
93        diff = torch.tensor([x2.size()[2] - x1.size()[2]])
94
95        x1 = F.pad(x1, [diff // 2, diff - diff // 2])
96        x = torch.cat([x2, x1], dim=1)
97        return self.conv(x)

Upscaling then double conv

up(in_channels, out_channels, bilinear=True)
77    def __init__(self, in_channels, out_channels, bilinear=True):
78        super().__init__()
79
80        if bilinear:
81            self.up = nn.Upsample(scale_factor=2, mode="linear", align_corners=True)
82        else:
83            self.up = nn.ConvTranspose1d(
84                in_channels // 2, in_channels // 2, kernel_size=2, stride=2
85            )
86
87        self.conv = double_conv(in_channels, out_channels)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

conv
def forward(self, x1, x2):
89    def forward(self, x1, x2):
90        """Forward pass."""
91        x1 = self.up(x1)
92        # input is CHW
93        diff = torch.tensor([x2.size()[2] - x1.size()[2]])
94
95        x1 = F.pad(x1, [diff // 2, diff - diff // 2])
96        x = torch.cat([x2, x1], dim=1)
97        return self.conv(x)

Forward pass.

class TPPblock(torch.nn.modules.module.Module):
100class TPPblock(nn.Module):
101    def __init__(self, in_channels):
102        super(TPPblock, self).__init__()
103        self.pool1 = nn.MaxPool1d(kernel_size=2, stride=2)
104        self.pool2 = nn.MaxPool1d(kernel_size=3, stride=3)
105        self.pool3 = nn.MaxPool1d(kernel_size=5, stride=5)
106        self.pool4 = nn.MaxPool1d(kernel_size=6, stride=6)
107
108        self.conv = nn.Conv1d(
109            in_channels=in_channels, out_channels=1, kernel_size=1, padding=0
110        )
111
112    def forward(self, x):
113        """Forward pass."""
114        self.in_channels, t = x.size(1), x.size(2)
115        self.layer1 = F.interpolate(
116            self.conv(self.pool1(x)), size=t, mode="linear", align_corners=True
117        )
118        self.layer2 = F.interpolate(
119            self.conv(self.pool2(x)), size=t, mode="linear", align_corners=True
120        )
121        self.layer3 = F.interpolate(
122            self.conv(self.pool3(x)), size=t, mode="linear", align_corners=True
123        )
124        self.layer4 = F.interpolate(
125            self.conv(self.pool4(x)), size=t, mode="linear", align_corners=True
126        )
127
128        out = torch.cat([self.layer1, self.layer2, self.layer3, self.layer4, x], 1)
129
130        return out

Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::

import torch.nn as nn
import torch.nn.functional as F


class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will also have their parameters converted when you call to(), etc.

As per the example above, an __init__() call to the parent class must be made before assignment on the child.

:ivar training: Boolean represents whether this module is in training or evaluation mode. :vartype training: bool

TPPblock(in_channels)
101    def __init__(self, in_channels):
102        super(TPPblock, self).__init__()
103        self.pool1 = nn.MaxPool1d(kernel_size=2, stride=2)
104        self.pool2 = nn.MaxPool1d(kernel_size=3, stride=3)
105        self.pool3 = nn.MaxPool1d(kernel_size=5, stride=5)
106        self.pool4 = nn.MaxPool1d(kernel_size=6, stride=6)
107
108        self.conv = nn.Conv1d(
109            in_channels=in_channels, out_channels=1, kernel_size=1, padding=0
110        )

Initialize internal Module state, shared by both nn.Module and ScriptModule.

pool1
pool2
pool3
pool4
conv
def forward(self, x):
112    def forward(self, x):
113        """Forward pass."""
114        self.in_channels, t = x.size(1), x.size(2)
115        self.layer1 = F.interpolate(
116            self.conv(self.pool1(x)), size=t, mode="linear", align_corners=True
117        )
118        self.layer2 = F.interpolate(
119            self.conv(self.pool2(x)), size=t, mode="linear", align_corners=True
120        )
121        self.layer3 = F.interpolate(
122            self.conv(self.pool3(x)), size=t, mode="linear", align_corners=True
123        )
124        self.layer4 = F.interpolate(
125            self.conv(self.pool4(x)), size=t, mode="linear", align_corners=True
126        )
127
128        out = torch.cat([self.layer1, self.layer2, self.layer3, self.layer4, x], 1)
129
130        return out

Forward pass.

class Predictor(torch.nn.modules.module.Module):
133class Predictor(nn.Module):
134    def __init__(self, dim, num_classes):
135        super(Predictor, self).__init__()
136        self.num_classes = num_classes
137        self.conv_out_1 = nn.Conv1d(dim, dim, kernel_size=1)
138        self.conv_out_2 = nn.Conv1d(dim, num_classes, kernel_size=1)
139
140    def forward(self, x):
141        """Forward pass."""
142        x = self.conv_out_1(x)
143        x = F.relu(x)
144        x = self.conv_out_2(x)
145        x = x.reshape((4, -1, self.num_classes, x.shape[-1]))
146        return x

Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::

import torch.nn as nn
import torch.nn.functional as F


class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will also have their parameters converted when you call to(), etc.

As per the example above, an __init__() call to the parent class must be made before assignment on the child.

:ivar training: Boolean represents whether this module is in training or evaluation mode. :vartype training: bool

Predictor(dim, num_classes)
134    def __init__(self, dim, num_classes):
135        super(Predictor, self).__init__()
136        self.num_classes = num_classes
137        self.conv_out_1 = nn.Conv1d(dim, dim, kernel_size=1)
138        self.conv_out_2 = nn.Conv1d(dim, num_classes, kernel_size=1)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

num_classes
conv_out_1
conv_out_2
def forward(self, x):
140    def forward(self, x):
141        """Forward pass."""
142        x = self.conv_out_1(x)
143        x = F.relu(x)
144        x = self.conv_out_2(x)
145        x = x.reshape((4, -1, self.num_classes, x.shape[-1]))
146        return x

Forward pass.

class C2F_TCN_Module(torch.nn.modules.module.Module):
149class C2F_TCN_Module(nn.Module):
150    """
151    Features are extracted at the last layer of decoder.
152    """
153
154    def __init__(self, n_channels, output_dim, num_f_maps, use_predictor=False):
155        super().__init__()
156        self.use_predictor = use_predictor
157        self.inc = inconv(n_channels, num_f_maps * 2)
158        self.down1 = down(num_f_maps * 2, num_f_maps * 2)
159        self.down2 = down(num_f_maps * 2, num_f_maps * 2)
160        self.down3 = down(num_f_maps * 2, num_f_maps)
161        self.down4 = down(num_f_maps, num_f_maps)
162        self.down5 = down(num_f_maps, num_f_maps)
163        self.down6 = down(num_f_maps, num_f_maps)
164        self.up = up(num_f_maps * 2 + 4, num_f_maps)
165        self.outcc0 = outconv(num_f_maps, output_dim)
166        self.up0 = up(num_f_maps * 2, num_f_maps)
167        self.outcc1 = outconv(num_f_maps, output_dim)
168        self.up1 = up(num_f_maps * 2, num_f_maps)
169        self.outcc2 = outconv(num_f_maps, output_dim)
170        self.up2 = up(num_f_maps * 3, num_f_maps)
171        self.outcc3 = outconv(num_f_maps, output_dim)
172        self.up3 = up(num_f_maps * 3, num_f_maps)
173        self.outcc4 = outconv(num_f_maps, output_dim)
174        self.up4 = up(num_f_maps * 3, num_f_maps)
175        self.outcc = outconv(num_f_maps, output_dim)
176        self.tpp = TPPblock(num_f_maps)
177        self.weights = torch.nn.Parameter(torch.ones(6))
178
179    def forward(self, x):
180        """Forward pass."""
181        x1 = self.inc(x)
182        x2 = self.down1(x1)
183        x3 = self.down2(x2)
184        x4 = self.down3(x3)
185        x5 = self.down4(x4)
186        x6 = self.down5(x5)
187        x7 = self.down6(x6)
188        # x7 = self.dac(x7)
189        x7 = self.tpp(x7)
190        x = self.up(x7, x6)
191        y1 = self.outcc0(F.relu(x))
192        # print("y1.shape=", y1.shape)
193        x = self.up0(x, x5)
194        y2 = self.outcc1(F.relu(x))
195        # print("y2.shape=", y2.shape)
196        x = self.up1(x, x4)
197        y3 = self.outcc2(F.relu(x))
198        # print("y3.shape=", y3.shape)
199        x = self.up2(x, x3)
200        y4 = self.outcc3(F.relu(x))
201        # print("y4.shape=", y4.shape)
202        x = self.up3(x, x2)
203        y5 = self.outcc4(F.relu(x))
204        # print("y5.shape=", y5.shape)
205        x = self.up4(x, x1)
206        y = self.outcc(x)
207        # print("y.shape=", y.shape)
208        output = [y]
209        for outp_ele in [y5, y4, y3]:
210            output.append(
211                F.interpolate(
212                    outp_ele, size=y.shape[-1], mode="linear", align_corners=True
213                )
214            )
215        output = torch.stack(output, dim=0)
216        if self.use_predictor:
217            K, B, C, T = output.shape
218            output = output.reshape((-1, C, T))
219        return output

Features are extracted at the last layer of decoder.

C2F_TCN_Module(n_channels, output_dim, num_f_maps, use_predictor=False)
154    def __init__(self, n_channels, output_dim, num_f_maps, use_predictor=False):
155        super().__init__()
156        self.use_predictor = use_predictor
157        self.inc = inconv(n_channels, num_f_maps * 2)
158        self.down1 = down(num_f_maps * 2, num_f_maps * 2)
159        self.down2 = down(num_f_maps * 2, num_f_maps * 2)
160        self.down3 = down(num_f_maps * 2, num_f_maps)
161        self.down4 = down(num_f_maps, num_f_maps)
162        self.down5 = down(num_f_maps, num_f_maps)
163        self.down6 = down(num_f_maps, num_f_maps)
164        self.up = up(num_f_maps * 2 + 4, num_f_maps)
165        self.outcc0 = outconv(num_f_maps, output_dim)
166        self.up0 = up(num_f_maps * 2, num_f_maps)
167        self.outcc1 = outconv(num_f_maps, output_dim)
168        self.up1 = up(num_f_maps * 2, num_f_maps)
169        self.outcc2 = outconv(num_f_maps, output_dim)
170        self.up2 = up(num_f_maps * 3, num_f_maps)
171        self.outcc3 = outconv(num_f_maps, output_dim)
172        self.up3 = up(num_f_maps * 3, num_f_maps)
173        self.outcc4 = outconv(num_f_maps, output_dim)
174        self.up4 = up(num_f_maps * 3, num_f_maps)
175        self.outcc = outconv(num_f_maps, output_dim)
176        self.tpp = TPPblock(num_f_maps)
177        self.weights = torch.nn.Parameter(torch.ones(6))

Initialize internal Module state, shared by both nn.Module and ScriptModule.

use_predictor
inc
down1
down2
down3
down4
down5
down6
up
outcc0
up0
outcc1
up1
outcc2
up2
outcc3
up3
outcc4
up4
outcc
tpp
weights
def forward(self, x):
179    def forward(self, x):
180        """Forward pass."""
181        x1 = self.inc(x)
182        x2 = self.down1(x1)
183        x3 = self.down2(x2)
184        x4 = self.down3(x3)
185        x5 = self.down4(x4)
186        x6 = self.down5(x5)
187        x7 = self.down6(x6)
188        # x7 = self.dac(x7)
189        x7 = self.tpp(x7)
190        x = self.up(x7, x6)
191        y1 = self.outcc0(F.relu(x))
192        # print("y1.shape=", y1.shape)
193        x = self.up0(x, x5)
194        y2 = self.outcc1(F.relu(x))
195        # print("y2.shape=", y2.shape)
196        x = self.up1(x, x4)
197        y3 = self.outcc2(F.relu(x))
198        # print("y3.shape=", y3.shape)
199        x = self.up2(x, x3)
200        y4 = self.outcc3(F.relu(x))
201        # print("y4.shape=", y4.shape)
202        x = self.up3(x, x2)
203        y5 = self.outcc4(F.relu(x))
204        # print("y5.shape=", y5.shape)
205        x = self.up4(x, x1)
206        y = self.outcc(x)
207        # print("y.shape=", y.shape)
208        output = [y]
209        for outp_ele in [y5, y4, y3]:
210            output.append(
211                F.interpolate(
212                    outp_ele, size=y.shape[-1], mode="linear", align_corners=True
213                )
214            )
215        output = torch.stack(output, dim=0)
216        if self.use_predictor:
217            K, B, C, T = output.shape
218            output = output.reshape((-1, C, T))
219        return output

Forward pass.

class C2F_TCN(dlc2action.model.base_model.Model):
222class C2F_TCN(Model):
223    """
224    An implementation of C2F-TCN
225
226    Requires the `"general/len_segment"` parameter to be at least 512
227    """
228
229    def __init__(
230        self,
231        num_classes:int,
232        input_dims:dict,
233        num_f_maps:int=128,
234        feature_dim:int=None,
235        state_dict_path:str=None,
236        ssl_constructors:List=None,
237        ssl_types:List=None,
238        ssl_modules:List=None,
239    ):
240        input_dims = int(sum([s[0] for s in input_dims.values()]))
241        if feature_dim is None:
242            feature_dim = num_classes
243            self.f_shape = None
244            self.params_predictor = None
245        else:
246            self.f_shape = torch.Size([int(feature_dim)])
247            self.params_predictor = {
248                "dim": int(feature_dim),
249                "num_classes": num_classes,
250            }
251        self.params = {
252            "output_dim": int(feature_dim),
253            "n_channels": int(input_dims),
254            "num_f_maps": int(float(num_f_maps)),
255            "use_predictor": self.f_shape is not None,
256        }
257        super().__init__(ssl_constructors, ssl_modules, ssl_types, state_dict_path)
258
259    def _feature_extractor(self) -> Union[torch.nn.Module, List]:
260        return C2F_TCN_Module(**self.params)
261
262    def _predictor(self) -> torch.nn.Module:
263        if self.params_predictor is not None:
264            return Predictor(**self.params_predictor)
265        else:
266            return nn.Identity()
267
268    def features_shape(self) -> Optional[torch.Size]:
269        return self.f_shape

An implementation of C2F-TCN

Requires the "general/len_segment" parameter to be at least 512

C2F_TCN( num_classes: int, input_dims: dict, num_f_maps: int = 128, feature_dim: int = None, state_dict_path: str = None, ssl_constructors: List = None, ssl_types: List = None, ssl_modules: List = None)
229    def __init__(
230        self,
231        num_classes:int,
232        input_dims:dict,
233        num_f_maps:int=128,
234        feature_dim:int=None,
235        state_dict_path:str=None,
236        ssl_constructors:List=None,
237        ssl_types:List=None,
238        ssl_modules:List=None,
239    ):
240        input_dims = int(sum([s[0] for s in input_dims.values()]))
241        if feature_dim is None:
242            feature_dim = num_classes
243            self.f_shape = None
244            self.params_predictor = None
245        else:
246            self.f_shape = torch.Size([int(feature_dim)])
247            self.params_predictor = {
248                "dim": int(feature_dim),
249                "num_classes": num_classes,
250            }
251        self.params = {
252            "output_dim": int(feature_dim),
253            "n_channels": int(input_dims),
254            "num_f_maps": int(float(num_f_maps)),
255            "use_predictor": self.f_shape is not None,
256        }
257        super().__init__(ssl_constructors, ssl_modules, ssl_types, state_dict_path)

Initialize the model.

Parameters

ssl_constructors : list, optional a list of SSL constructors that build the necessary SSL modules ssl_modules : list, optional a list of torch.nn.Module instances that will serve as SSL modules ssl_types : list, optional a list of string SSL types state_dict_path : str, optional path to the model state dictionary to load strict : bool, default False when True, the state dictionary will only be loaded if the current and the loaded architecture are the same; otherwise missing or extra keys, as well as shape inconsistencies, are ignored prompt_function : callable, optional a function that takes a list of strings and returns a string prompt

params
def features_shape(self) -> Optional[torch.Size]:
268    def features_shape(self) -> Optional[torch.Size]:
269        return self.f_shape

Get the shape of feature extractor output.

Returns

feature_shape : torch.Size shape of feature extractor output