dlc2action.model.c2f_tcn_par

  1#
  2# Copyright 2020-present by A. Mathis Group and contributors. All rights reserved.
  3#
  4# This project and all its files are licensed under GNU AGPLv3 or later version. A copy is included in dlc2action/LICENSE.AGPL.
  5#
  6# Incorporates code adapted from C2F-TCN by dipika-singhania
  7# Original work Copyright (c) 2021 dipika-singhania
  8# Source: https://github.com/dipika-singhania/C2F-TCN
  9# Originally licensed under MIT License
 10# Combined work licensed under GNU AGPLv3
 11#
 12from functools import partial
 13from typing import List, Optional, Union
 14
 15import torch
 16import torch.nn as nn
 17import torch.nn.functional as F
 18from dlc2action.model.base_model import Model
 19
 20nonlinearity = partial(F.relu, inplace=True)
 21
 22
 23class double_conv(nn.Module):
 24    def __init__(self, in_ch, out_ch):
 25        super(double_conv, self).__init__()
 26        self.conv = nn.Sequential(
 27            nn.Conv1d(in_ch, out_ch, kernel_size=5, padding=2),
 28            nn.BatchNorm1d(out_ch),
 29            nn.ReLU(inplace=True),
 30            nn.Conv1d(out_ch, out_ch, kernel_size=5, padding=2),
 31            nn.BatchNorm1d(out_ch),
 32            nn.ReLU(inplace=True),
 33        )
 34
 35    def forward(self, x):
 36        """Forward pass."""
 37        x = self.conv(x)
 38        return x
 39
 40
 41class inconv(nn.Module):
 42    def __init__(self, in_ch, out_ch):
 43        super(inconv, self).__init__()
 44        self.conv = double_conv(in_ch, out_ch)
 45
 46    def forward(self, x):
 47        """Forward pass."""
 48        x = self.conv(x)
 49        return x
 50
 51
 52class outconv(nn.Module):
 53    def __init__(self, in_ch, out_ch):
 54        super(outconv, self).__init__()
 55        self.conv = nn.Conv1d(in_ch, out_ch, kernel_size=1)
 56
 57    def forward(self, x):
 58        """Forward pass."""
 59        x = self.conv(x)
 60        return x
 61
 62
 63class down(nn.Module):
 64    def __init__(self, in_ch, out_ch):
 65        super(down, self).__init__()
 66        self.max_pool_conv = nn.Sequential(nn.MaxPool1d(2), double_conv(in_ch, out_ch))
 67
 68    def forward(self, x):
 69        """Forward pass."""
 70        x = self.max_pool_conv(x)
 71        return x
 72
 73
 74class up(nn.Module):
 75    """Upscaling then double conv"""
 76
 77    def __init__(self, in_channels, out_channels, bilinear=True):
 78        super().__init__()
 79
 80        if bilinear:
 81            self.up = nn.Upsample(scale_factor=2, mode="linear", align_corners=True)
 82        else:
 83            self.up = nn.ConvTranspose1d(
 84                in_channels // 2, in_channels // 2, kernel_size=2, stride=2
 85            )
 86
 87        self.conv = double_conv(in_channels, out_channels)
 88
 89    def forward(self, x1, x2):
 90        """Forward pass."""
 91        x1 = self.up(x1)
 92        # input is CHW
 93        diff = torch.tensor([x2.size()[2] - x1.size()[2]])
 94
 95        x1 = F.pad(x1, [diff // 2, diff - diff // 2])
 96        x = torch.cat([x2, x1], dim=1)
 97        return self.conv(x)
 98
 99
class TPPblock(nn.Module):
    """Temporal pyramid pooling block.

    The input is max-pooled at four temporal scales (kernel = stride = 2, 3, 5
    and 6). Each pooled map is projected to one channel by a shared 1x1
    convolution and linearly upsampled back to the input length. The four
    resulting channels are concatenated in front of the input, so the output
    has ``in_channels + 4`` channels.

    Parameters
    ----------
    in_channels : int
        Number of input channels.
    """

    def __init__(self, in_channels):
        super(TPPblock, self).__init__()
        # Recorded once here rather than being reassigned on every forward.
        self.in_channels = in_channels
        self.pool1 = nn.MaxPool1d(kernel_size=2, stride=2)
        self.pool2 = nn.MaxPool1d(kernel_size=3, stride=3)
        self.pool3 = nn.MaxPool1d(kernel_size=5, stride=5)
        self.pool4 = nn.MaxPool1d(kernel_size=6, stride=6)

        self.conv = nn.Conv1d(
            in_channels=in_channels, out_channels=1, kernel_size=1, padding=0
        )

    def forward(self, x):
        """Forward pass.

        ``x`` is indexed as ``(batch, in_channels, time)`` and needs
        ``time >= 6`` (the largest pooling window). Returns
        ``(batch, in_channels + 4, time)``.
        """
        t = x.size(2)
        # Branches are locals. Storing them on ``self`` (as before) kept the
        # previous batch's activations and autograd graph alive between calls.
        branches = [
            F.interpolate(
                self.conv(pool(x)), size=t, mode="linear", align_corners=True
            )
            for pool in (self.pool1, self.pool2, self.pool3, self.pool4)
        ]
        return torch.cat(branches + [x], 1)
131
132
class Predictor(nn.Module):
    """Two-layer 1x1-convolution classification head.

    Maps stacked per-stage features to per-frame class scores and splits the
    stage dimension back out of the batch dimension.

    Parameters
    ----------
    dim : int
        Number of input channels.
    num_classes : int
        Number of output classes.
    num_stages : int, default 4
        Number of stage outputs folded (stage-major) into the first input
        dimension. ``C2F_TCN_Module`` stacks 4: the final output plus three
        upsampled intermediate ones.
    """

    # Class-level default, so instances unpickled from before ``num_stages``
    # became a parameter still reshape into 4 stages.
    num_stages = 4

    def __init__(self, dim, num_classes, num_stages=4):
        super(Predictor, self).__init__()
        self.num_classes = num_classes
        self.num_stages = num_stages
        self.conv_out_1 = nn.Conv1d(dim, dim, kernel_size=1)
        self.conv_out_2 = nn.Conv1d(dim, num_classes, kernel_size=1)

    def forward(self, x):
        """Forward pass.

        Expects ``x`` of shape ``(num_stages * B, dim, T)`` with the stage as
        the outer factor of the first dimension. Returns
        ``(num_stages, B, num_classes, T)``.
        """
        x = self.conv_out_1(x)
        x = F.relu(x)
        x = self.conv_out_2(x)
        return x.reshape((self.num_stages, -1, self.num_classes, x.shape[-1]))
147
148
class C2F_TCN_P_Module(nn.Module):
    """Run one shared ``C2F_TCN_Module`` on every element of the input.

    The backbone is built with ``use_predictor=True``, so each per-element
    output has its stage dimension folded into the first dimension. The
    per-element outputs are concatenated along dimension 1.

    Parameters
    ----------
    n_channels : int
        Input channels of each element.
    output_dim : int
        Output channels produced by the backbone for each element.
    num_f_maps : int
        Base number of feature maps of the backbone.
    """

    def __init__(self, n_channels, output_dim, num_f_maps):
        super().__init__()
        self.c2f_tcn = C2F_TCN_Module(
            n_channels=n_channels,
            output_dim=output_dim,
            num_f_maps=num_f_maps,
            use_predictor=True,
        )

    def forward(self, x):
        """Forward pass.

        ``x`` is iterated over (NOTE(review): presumably one tensor per
        individual). Every element goes through the same backbone, so the
        weights are shared across elements.
        """
        per_element = [self.c2f_tcn(element) for element in x]
        return torch.cat(per_element, dim=1)
162
163
class C2F_TCN_Module(nn.Module):
    """Coarse-to-fine temporal convolutional U-Net backbone.

    Features are extracted at the last layer of the decoder. The encoder
    halves the temporal length six times (``down1`` .. ``down6``). A temporal
    pyramid pooling block (``tpp``) adds 4 channels at the bottleneck, and the
    decoder upsamples back with skip connections. The full-resolution output
    and the three finest intermediate decoder outputs are upsampled to a
    common length and stacked along a new leading dimension.

    Because of the six halvings followed by a pooling window of 6, the input
    length must be at least 384 frames.

    Parameters
    ----------
    n_channels : int
        Number of input channels.
    output_dim : int
        Channels of every output head.
    num_f_maps : int
        Base number of feature maps.
    use_predictor : bool, default False
        If True, the stacked output ``(4, B, C, T)`` is flattened stage-major
        to ``(4 * B, C, T)``.
    """

    def __init__(self, n_channels, output_dim, num_f_maps, use_predictor=False):
        super().__init__()
        self.use_predictor = use_predictor
        # Encoder.
        self.inc = inconv(n_channels, num_f_maps * 2)
        self.down1 = down(num_f_maps * 2, num_f_maps * 2)
        self.down2 = down(num_f_maps * 2, num_f_maps * 2)
        self.down3 = down(num_f_maps * 2, num_f_maps)
        self.down4 = down(num_f_maps, num_f_maps)
        self.down5 = down(num_f_maps, num_f_maps)
        self.down6 = down(num_f_maps, num_f_maps)
        # Decoder. Each ``up`` takes upsampled channels + skip channels; the
        # extra 4 on the first level are the TPP pyramid channels.
        self.up = up(num_f_maps * 2 + 4, num_f_maps)
        self.outcc0 = outconv(num_f_maps, output_dim)
        self.up0 = up(num_f_maps * 2, num_f_maps)
        self.outcc1 = outconv(num_f_maps, output_dim)
        self.up1 = up(num_f_maps * 2, num_f_maps)
        self.outcc2 = outconv(num_f_maps, output_dim)
        self.up2 = up(num_f_maps * 3, num_f_maps)
        self.outcc3 = outconv(num_f_maps, output_dim)
        self.up3 = up(num_f_maps * 3, num_f_maps)
        self.outcc4 = outconv(num_f_maps, output_dim)
        self.up4 = up(num_f_maps * 3, num_f_maps)
        self.outcc = outconv(num_f_maps, output_dim)
        self.tpp = TPPblock(num_f_maps)
        # NOTE(review): not used in ``forward``. Kept so that checkpoints
        # containing this parameter still load.
        self.weights = torch.nn.Parameter(torch.ones(6))

    def forward(self, x):
        """Forward pass.

        Assuming ``x`` is ``(B, n_channels, T)``, returns ``(4, B, output_dim, T)``,
        or ``(4 * B, output_dim, T)`` when ``use_predictor`` is set. Index 0 is
        the final decoder output; indices 1-3 are the upsampled outputs of
        the three preceding decoder levels, finest first.
        """
        # Encoder.
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x6 = self.down5(x5)
        x7 = self.tpp(self.down6(x6))

        # Decoder. The heads of the two coarsest levels (``outcc0``,
        # ``outcc1``) never contributed to the output, so they are no longer
        # evaluated; the modules stay for checkpoint compatibility.
        x = self.up(x7, x6)
        x = self.up0(x, x5)
        x = self.up1(x, x4)
        y3 = self.outcc2(F.relu(x))
        x = self.up2(x, x3)
        y4 = self.outcc3(F.relu(x))
        x = self.up3(x, x2)
        y5 = self.outcc4(F.relu(x))
        x = self.up4(x, x1)
        y = self.outcc(x)

        length = y.shape[-1]
        output = [y] + [
            F.interpolate(coarse, size=length, mode="linear", align_corners=True)
            for coarse in (y5, y4, y3)
        ]
        output = torch.stack(output, dim=0)
        if self.use_predictor:
            # (K, B, C, T) -> (K * B, C, T), stage-major.
            output = output.flatten(0, 1)
        return output
235
236
class C2F_TCN_P(Model):
    """C2F-TCN applied separately to each individual.

    A single ``C2F_TCN_P_Module`` backbone processes the features of every
    individual with shared weights. The per-individual features are
    concatenated along the channel axis and passed to a ``Predictor`` head.

    Parameters
    ----------
    num_classes : int
        Number of classes predicted by the head.
    input_dims : dict
        Mapping from feature key to a shape whose first element is a channel
        count. Keys containing ``"---"`` are split on it, and the last part is
        read as an individual id. Keys whose id splits into exactly two parts
        on ``"+"`` (presumably pairs of individuals; TODO confirm) are
        ignored when counting individuals, as are keys without ``"---"``.
    num_f_maps : int, default 128
        Base number of feature maps of the backbone.
    feature_dim : int, optional
        Per-individual output dimension of the backbone; defaults to
        ``num_f_maps``.
    state_dict_path, ssl_constructors, ssl_types, ssl_modules
        Forwarded to the ``Model`` base class.

    Raises
    ------
    ValueError
        If ``input_dims`` contains no per-individual keys.
    """

    def __init__(
        self,
        num_classes,
        input_dims,
        num_f_maps=128,
        feature_dim=None,
        state_dict_path=None,
        ssl_constructors=None,
        ssl_types=None,
        ssl_modules=None,
    ):
        if feature_dim is None:
            feature_dim = num_f_maps
        individual_keys = [
            key
            for key in input_dims
            if "---" in key and len(key.split("---")[-1].split("+")) != 2
        ]
        if not individual_keys:
            # Previously an opaque IndexError from ``keys[0]``.
            raise ValueError(
                "C2F_TCN_P needs per-individual features: no key of the form "
                "'<name>---<individual>' was found in input_dims"
            )
        num_ind = len({key.split("---")[-1] for key in individual_keys})
        # The channel count is taken from the first individual only.
        # NOTE(review): this assumes all individuals share the same layout.
        first_ind = individual_keys[0].split("---")[-1]
        n_channels = int(
            sum(v[0] for k, v in input_dims.items() if k.split("---")[-1] == first_ind)
        )
        self.f_shape = torch.Size([feature_dim * num_ind])
        self.params_predictor = {
            "dim": int(feature_dim * num_ind),
            "num_classes": num_classes,
        }
        self.params = {
            "output_dim": int(feature_dim),
            "n_channels": n_channels,
            "num_f_maps": int(num_f_maps),
        }
        # These attributes are set before the base initializer, which
        # presumably builds the sub-networks from them (TODO confirm).
        super().__init__(ssl_constructors, ssl_modules, ssl_types, state_dict_path)

    def _feature_extractor(self) -> Union[torch.nn.Module, List]:
        """Build the shared per-individual C2F-TCN backbone."""
        return C2F_TCN_P_Module(**self.params)

    def _predictor(self) -> torch.nn.Module:
        """Build the classification head over concatenated individual features."""
        return Predictor(**self.params_predictor)

    def features_shape(self) -> Optional[torch.Size]:
        """Return the feature shape: ``feature_dim * number of individuals``."""
        return self.f_shape
nonlinearity = functools.partial(<function relu>, inplace=True)
class double_conv(torch.nn.modules.module.Module):
24class double_conv(nn.Module):
25    def __init__(self, in_ch, out_ch):
26        super(double_conv, self).__init__()
27        self.conv = nn.Sequential(
28            nn.Conv1d(in_ch, out_ch, kernel_size=5, padding=2),
29            nn.BatchNorm1d(out_ch),
30            nn.ReLU(inplace=True),
31            nn.Conv1d(out_ch, out_ch, kernel_size=5, padding=2),
32            nn.BatchNorm1d(out_ch),
33            nn.ReLU(inplace=True),
34        )
35
36    def forward(self, x):
37        """Forward pass."""
38        x = self.conv(x)
39        return x

Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes:

import torch.nn as nn
import torch.nn.functional as F


class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will also have their parameters converted when you call to(), etc.

As per the example above, an __init__() call to the parent class must be made before assignment on the child.

training (bool): whether this module is in training or evaluation mode.

double_conv(in_ch, out_ch)
25    def __init__(self, in_ch, out_ch):
26        super(double_conv, self).__init__()
27        self.conv = nn.Sequential(
28            nn.Conv1d(in_ch, out_ch, kernel_size=5, padding=2),
29            nn.BatchNorm1d(out_ch),
30            nn.ReLU(inplace=True),
31            nn.Conv1d(out_ch, out_ch, kernel_size=5, padding=2),
32            nn.BatchNorm1d(out_ch),
33            nn.ReLU(inplace=True),
34        )

Initialize internal Module state, shared by both nn.Module and ScriptModule.

conv
def forward(self, x):
36    def forward(self, x):
37        """Forward pass."""
38        x = self.conv(x)
39        return x

Forward pass.

class inconv(torch.nn.modules.module.Module):
42class inconv(nn.Module):
43    def __init__(self, in_ch, out_ch):
44        super(inconv, self).__init__()
45        self.conv = double_conv(in_ch, out_ch)
46
47    def forward(self, x):
48        """Forward pass."""
49        x = self.conv(x)
50        return x

Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes:

import torch.nn as nn
import torch.nn.functional as F


class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will also have their parameters converted when you call to(), etc.

As per the example above, an __init__() call to the parent class must be made before assignment on the child.

training (bool): whether this module is in training or evaluation mode.

inconv(in_ch, out_ch)
43    def __init__(self, in_ch, out_ch):
44        super(inconv, self).__init__()
45        self.conv = double_conv(in_ch, out_ch)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

conv
def forward(self, x):
47    def forward(self, x):
48        """Forward pass."""
49        x = self.conv(x)
50        return x

Forward pass.

class outconv(torch.nn.modules.module.Module):
53class outconv(nn.Module):
54    def __init__(self, in_ch, out_ch):
55        super(outconv, self).__init__()
56        self.conv = nn.Conv1d(in_ch, out_ch, kernel_size=1)
57
58    def forward(self, x):
59        """Forward pass."""
60        x = self.conv(x)
61        return x

Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes:

import torch.nn as nn
import torch.nn.functional as F


class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will also have their parameters converted when you call to(), etc.

As per the example above, an __init__() call to the parent class must be made before assignment on the child.

training (bool): whether this module is in training or evaluation mode.

outconv(in_ch, out_ch)
54    def __init__(self, in_ch, out_ch):
55        super(outconv, self).__init__()
56        self.conv = nn.Conv1d(in_ch, out_ch, kernel_size=1)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

conv
def forward(self, x):
58    def forward(self, x):
59        """Forward pass."""
60        x = self.conv(x)
61        return x

Forward pass.

class down(torch.nn.modules.module.Module):
64class down(nn.Module):
65    def __init__(self, in_ch, out_ch):
66        super(down, self).__init__()
67        self.max_pool_conv = nn.Sequential(nn.MaxPool1d(2), double_conv(in_ch, out_ch))
68
69    def forward(self, x):
70        """Forward pass."""
71        x = self.max_pool_conv(x)
72        return x

Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes:

import torch.nn as nn
import torch.nn.functional as F


class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will also have their parameters converted when you call to(), etc.

As per the example above, an __init__() call to the parent class must be made before assignment on the child.

training (bool): whether this module is in training or evaluation mode.

down(in_ch, out_ch)
65    def __init__(self, in_ch, out_ch):
66        super(down, self).__init__()
67        self.max_pool_conv = nn.Sequential(nn.MaxPool1d(2), double_conv(in_ch, out_ch))

Initialize internal Module state, shared by both nn.Module and ScriptModule.

max_pool_conv
def forward(self, x):
69    def forward(self, x):
70        """Forward pass."""
71        x = self.max_pool_conv(x)
72        return x

Forward pass.

class up(torch.nn.modules.module.Module):
75class up(nn.Module):
76    """Upscaling then double conv"""
77
78    def __init__(self, in_channels, out_channels, bilinear=True):
79        super().__init__()
80
81        if bilinear:
82            self.up = nn.Upsample(scale_factor=2, mode="linear", align_corners=True)
83        else:
84            self.up = nn.ConvTranspose1d(
85                in_channels // 2, in_channels // 2, kernel_size=2, stride=2
86            )
87
88        self.conv = double_conv(in_channels, out_channels)
89
90    def forward(self, x1, x2):
91        """Forward pass."""
92        x1 = self.up(x1)
93        # input is CHW
94        diff = torch.tensor([x2.size()[2] - x1.size()[2]])
95
96        x1 = F.pad(x1, [diff // 2, diff - diff // 2])
97        x = torch.cat([x2, x1], dim=1)
98        return self.conv(x)

Upsample the coarse input, then apply a double convolution.

up(in_channels, out_channels, bilinear=True)
78    def __init__(self, in_channels, out_channels, bilinear=True):
79        super().__init__()
80
81        if bilinear:
82            self.up = nn.Upsample(scale_factor=2, mode="linear", align_corners=True)
83        else:
84            self.up = nn.ConvTranspose1d(
85                in_channels // 2, in_channels // 2, kernel_size=2, stride=2
86            )
87
88        self.conv = double_conv(in_channels, out_channels)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

conv
def forward(self, x1, x2):
90    def forward(self, x1, x2):
91        """Forward pass."""
92        x1 = self.up(x1)
93        # input is CHW
94        diff = torch.tensor([x2.size()[2] - x1.size()[2]])
95
96        x1 = F.pad(x1, [diff // 2, diff - diff // 2])
97        x = torch.cat([x2, x1], dim=1)
98        return self.conv(x)

Forward pass.

class TPPblock(torch.nn.modules.module.Module):
101class TPPblock(nn.Module):
102    def __init__(self, in_channels):
103        super(TPPblock, self).__init__()
104        self.pool1 = nn.MaxPool1d(kernel_size=2, stride=2)
105        self.pool2 = nn.MaxPool1d(kernel_size=3, stride=3)
106        self.pool3 = nn.MaxPool1d(kernel_size=5, stride=5)
107        self.pool4 = nn.MaxPool1d(kernel_size=6, stride=6)
108
109        self.conv = nn.Conv1d(
110            in_channels=in_channels, out_channels=1, kernel_size=1, padding=0
111        )
112
113    def forward(self, x):
114        """Forward pass."""
115        self.in_channels, t = x.size(1), x.size(2)
116        self.layer1 = F.upsample(
117            self.conv(self.pool1(x)), size=t, mode="linear", align_corners=True
118        )
119        self.layer2 = F.upsample(
120            self.conv(self.pool2(x)), size=t, mode="linear", align_corners=True
121        )
122        self.layer3 = F.upsample(
123            self.conv(self.pool3(x)), size=t, mode="linear", align_corners=True
124        )
125        self.layer4 = F.upsample(
126            self.conv(self.pool4(x)), size=t, mode="linear", align_corners=True
127        )
128
129        out = torch.cat([self.layer1, self.layer2, self.layer3, self.layer4, x], 1)
130
131        return out

Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes:

import torch.nn as nn
import torch.nn.functional as F


class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will also have their parameters converted when you call to(), etc.

As per the example above, an __init__() call to the parent class must be made before assignment on the child.

training (bool): whether this module is in training or evaluation mode.

TPPblock(in_channels)
102    def __init__(self, in_channels):
103        super(TPPblock, self).__init__()
104        self.pool1 = nn.MaxPool1d(kernel_size=2, stride=2)
105        self.pool2 = nn.MaxPool1d(kernel_size=3, stride=3)
106        self.pool3 = nn.MaxPool1d(kernel_size=5, stride=5)
107        self.pool4 = nn.MaxPool1d(kernel_size=6, stride=6)
108
109        self.conv = nn.Conv1d(
110            in_channels=in_channels, out_channels=1, kernel_size=1, padding=0
111        )

Initialize internal Module state, shared by both nn.Module and ScriptModule.

pool1
pool2
pool3
pool4
conv
def forward(self, x):
113    def forward(self, x):
114        """Forward pass."""
115        self.in_channels, t = x.size(1), x.size(2)
116        self.layer1 = F.upsample(
117            self.conv(self.pool1(x)), size=t, mode="linear", align_corners=True
118        )
119        self.layer2 = F.upsample(
120            self.conv(self.pool2(x)), size=t, mode="linear", align_corners=True
121        )
122        self.layer3 = F.upsample(
123            self.conv(self.pool3(x)), size=t, mode="linear", align_corners=True
124        )
125        self.layer4 = F.upsample(
126            self.conv(self.pool4(x)), size=t, mode="linear", align_corners=True
127        )
128
129        out = torch.cat([self.layer1, self.layer2, self.layer3, self.layer4, x], 1)
130
131        return out

Forward pass.

class Predictor(torch.nn.modules.module.Module):
134class Predictor(nn.Module):
135    def __init__(self, dim, num_classes):
136        super(Predictor, self).__init__()
137        self.num_classes = num_classes
138        self.conv_out_1 = nn.Conv1d(dim, dim, kernel_size=1)
139        self.conv_out_2 = nn.Conv1d(dim, num_classes, kernel_size=1)
140
141    def forward(self, x):
142        """Forward pass."""
143        x = self.conv_out_1(x)
144        x = F.relu(x)
145        x = self.conv_out_2(x)
146        x = x.reshape((4, -1, self.num_classes, x.shape[-1]))
147        return x

Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes:

import torch.nn as nn
import torch.nn.functional as F


class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will also have their parameters converted when you call to(), etc.

As per the example above, an __init__() call to the parent class must be made before assignment on the child.

training (bool): whether this module is in training or evaluation mode.

Predictor(dim, num_classes)
135    def __init__(self, dim, num_classes):
136        super(Predictor, self).__init__()
137        self.num_classes = num_classes
138        self.conv_out_1 = nn.Conv1d(dim, dim, kernel_size=1)
139        self.conv_out_2 = nn.Conv1d(dim, num_classes, kernel_size=1)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

num_classes
conv_out_1
conv_out_2
def forward(self, x):
141    def forward(self, x):
142        """Forward pass."""
143        x = self.conv_out_1(x)
144        x = F.relu(x)
145        x = self.conv_out_2(x)
146        x = x.reshape((4, -1, self.num_classes, x.shape[-1]))
147        return x

Forward pass.

class C2F_TCN_P_Module(torch.nn.modules.module.Module):
150class C2F_TCN_P_Module(nn.Module):
151    def __init__(self, n_channels, output_dim, num_f_maps):
152        super().__init__()
153        self.c2f_tcn = C2F_TCN_Module(
154            n_channels, output_dim, num_f_maps, use_predictor=True
155        )
156
157    def forward(self, x):
158        """Forward pass."""
159        output = []
160        for ind_x in x:
161            output.append(self.c2f_tcn(ind_x))
162        return torch.cat(output, dim=1)

Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes:

import torch.nn as nn
import torch.nn.functional as F


class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will also have their parameters converted when you call to(), etc.

As per the example above, an __init__() call to the parent class must be made before assignment on the child.

training (bool): whether this module is in training or evaluation mode.

C2F_TCN_P_Module(n_channels, output_dim, num_f_maps)
151    def __init__(self, n_channels, output_dim, num_f_maps):
152        super().__init__()
153        self.c2f_tcn = C2F_TCN_Module(
154            n_channels, output_dim, num_f_maps, use_predictor=True
155        )

Initialize internal Module state, shared by both nn.Module and ScriptModule.

c2f_tcn
def forward(self, x):
157    def forward(self, x):
158        """Forward pass."""
159        output = []
160        for ind_x in x:
161            output.append(self.c2f_tcn(ind_x))
162        return torch.cat(output, dim=1)

Forward pass.

class C2F_TCN_Module(torch.nn.modules.module.Module):
165class C2F_TCN_Module(nn.Module):
166    """
167    Features are extracted at the last layer of decoder.
168    """
169
170    def __init__(self, n_channels, output_dim, num_f_maps, use_predictor=False):
171        super().__init__()
172        self.use_predictor = use_predictor
173        self.inc = inconv(n_channels, num_f_maps * 2)
174        self.down1 = down(num_f_maps * 2, num_f_maps * 2)
175        self.down2 = down(num_f_maps * 2, num_f_maps * 2)
176        self.down3 = down(num_f_maps * 2, num_f_maps)
177        self.down4 = down(num_f_maps, num_f_maps)
178        self.down5 = down(num_f_maps, num_f_maps)
179        self.down6 = down(num_f_maps, num_f_maps)
180        self.up = up(num_f_maps * 2 + 4, num_f_maps)
181        self.outcc0 = outconv(num_f_maps, output_dim)
182        self.up0 = up(num_f_maps * 2, num_f_maps)
183        self.outcc1 = outconv(num_f_maps, output_dim)
184        self.up1 = up(num_f_maps * 2, num_f_maps)
185        self.outcc2 = outconv(num_f_maps, output_dim)
186        self.up2 = up(num_f_maps * 3, num_f_maps)
187        self.outcc3 = outconv(num_f_maps, output_dim)
188        self.up3 = up(num_f_maps * 3, num_f_maps)
189        self.outcc4 = outconv(num_f_maps, output_dim)
190        self.up4 = up(num_f_maps * 3, num_f_maps)
191        self.outcc = outconv(num_f_maps, output_dim)
192        self.tpp = TPPblock(num_f_maps)
193        self.weights = torch.nn.Parameter(torch.ones(6))
194
195    def forward(self, x):
196        """Forward pass."""
197        x1 = self.inc(x)
198        x2 = self.down1(x1)
199        x3 = self.down2(x2)
200        x4 = self.down3(x3)
201        x5 = self.down4(x4)
202        x6 = self.down5(x5)
203        x7 = self.down6(x6)
204        # x7 = self.dac(x7)
205        x7 = self.tpp(x7)
206        x = self.up(x7, x6)
207        y1 = self.outcc0(F.relu(x))
208        # print("y1.shape=", y1.shape)
209        x = self.up0(x, x5)
210        y2 = self.outcc1(F.relu(x))
211        # print("y2.shape=", y2.shape)
212        x = self.up1(x, x4)
213        y3 = self.outcc2(F.relu(x))
214        # print("y3.shape=", y3.shape)
215        x = self.up2(x, x3)
216        y4 = self.outcc3(F.relu(x))
217        # print("y4.shape=", y4.shape)
218        x = self.up3(x, x2)
219        y5 = self.outcc4(F.relu(x))
220        # print("y5.shape=", y5.shape)
221        x = self.up4(x, x1)
222        y = self.outcc(x)
223        # print("y.shape=", y.shape)
224        output = [y]
225        for outp_ele in [y5, y4, y3]:
226            output.append(
227                F.upsample(
228                    outp_ele, size=y.shape[-1], mode="linear", align_corners=True
229                )
230            )
231        output = torch.stack(output, dim=0)
232        if self.use_predictor:
233            K, B, C, T = output.shape
234            output = output.reshape((-1, C, T))
235        return output

Features are extracted at the last layer of the decoder.

C2F_TCN_Module(n_channels, output_dim, num_f_maps, use_predictor=False)
170    def __init__(self, n_channels, output_dim, num_f_maps, use_predictor=False):
171        super().__init__()
172        self.use_predictor = use_predictor
173        self.inc = inconv(n_channels, num_f_maps * 2)
174        self.down1 = down(num_f_maps * 2, num_f_maps * 2)
175        self.down2 = down(num_f_maps * 2, num_f_maps * 2)
176        self.down3 = down(num_f_maps * 2, num_f_maps)
177        self.down4 = down(num_f_maps, num_f_maps)
178        self.down5 = down(num_f_maps, num_f_maps)
179        self.down6 = down(num_f_maps, num_f_maps)
180        self.up = up(num_f_maps * 2 + 4, num_f_maps)
181        self.outcc0 = outconv(num_f_maps, output_dim)
182        self.up0 = up(num_f_maps * 2, num_f_maps)
183        self.outcc1 = outconv(num_f_maps, output_dim)
184        self.up1 = up(num_f_maps * 2, num_f_maps)
185        self.outcc2 = outconv(num_f_maps, output_dim)
186        self.up2 = up(num_f_maps * 3, num_f_maps)
187        self.outcc3 = outconv(num_f_maps, output_dim)
188        self.up3 = up(num_f_maps * 3, num_f_maps)
189        self.outcc4 = outconv(num_f_maps, output_dim)
190        self.up4 = up(num_f_maps * 3, num_f_maps)
191        self.outcc = outconv(num_f_maps, output_dim)
192        self.tpp = TPPblock(num_f_maps)
193        self.weights = torch.nn.Parameter(torch.ones(6))

Initialize internal Module state, shared by both nn.Module and ScriptModule.

use_predictor
inc
down1
down2
down3
down4
down5
down6
up
outcc0
up0
outcc1
up1
outcc2
up2
outcc3
up3
outcc4
up4
outcc
tpp
weights
def forward(self, x):
195    def forward(self, x):
196        """Forward pass."""
197        x1 = self.inc(x)
198        x2 = self.down1(x1)
199        x3 = self.down2(x2)
200        x4 = self.down3(x3)
201        x5 = self.down4(x4)
202        x6 = self.down5(x5)
203        x7 = self.down6(x6)
204        # x7 = self.dac(x7)
205        x7 = self.tpp(x7)
206        x = self.up(x7, x6)
207        y1 = self.outcc0(F.relu(x))
208        # print("y1.shape=", y1.shape)
209        x = self.up0(x, x5)
210        y2 = self.outcc1(F.relu(x))
211        # print("y2.shape=", y2.shape)
212        x = self.up1(x, x4)
213        y3 = self.outcc2(F.relu(x))
214        # print("y3.shape=", y3.shape)
215        x = self.up2(x, x3)
216        y4 = self.outcc3(F.relu(x))
217        # print("y4.shape=", y4.shape)
218        x = self.up3(x, x2)
219        y5 = self.outcc4(F.relu(x))
220        # print("y5.shape=", y5.shape)
221        x = self.up4(x, x1)
222        y = self.outcc(x)
223        # print("y.shape=", y.shape)
224        output = [y]
225        for outp_ele in [y5, y4, y3]:
226            output.append(
227                F.upsample(
228                    outp_ele, size=y.shape[-1], mode="linear", align_corners=True
229                )
230            )
231        output = torch.stack(output, dim=0)
232        if self.use_predictor:
233            K, B, C, T = output.shape
234            output = output.reshape((-1, C, T))
235        return output

Forward pass.

class C2F_TCN_P(dlc2action.model.base_model.Model):
238class C2F_TCN_P(Model):
239    def __init__(
240        self,
241        num_classes,
242        input_dims,
243        num_f_maps=128,
244        feature_dim=None,
245        state_dict_path=None,
246        ssl_constructors=None,
247        ssl_types=None,
248        ssl_modules=None,
249    ):
250        if feature_dim is None:
251            feature_dim = num_f_maps
252        keys = [
253            key
254            for key in input_dims.keys()
255            if len(key.split("---")) != 1 and len(key.split("---")[-1].split("+")) != 2
256        ]
257        num_ind = len(set([key.split("---")[-1] for key in keys]))
258        key = keys[0]
259        ind = key.split("---")[-1]
260        input_dims = int(
261            sum([v[0] for k, v in input_dims.items() if k.split("---")[-1] == ind])
262        )
263        self.f_shape = torch.Size([feature_dim * num_ind])
264        self.params_predictor = {
265            "dim": int(feature_dim * num_ind),
266            "num_classes": num_classes,
267        }
268        self.params = {
269            "output_dim": int(feature_dim),
270            "n_channels": int(input_dims),
271            "num_f_maps": int(num_f_maps),
272        }
273        super().__init__(ssl_constructors, ssl_modules, ssl_types, state_dict_path)
274
275    def _feature_extractor(self) -> Union[torch.nn.Module, List]:
276        return C2F_TCN_P_Module(**self.params)
277
278    def _predictor(self) -> torch.nn.Module:
279        return Predictor(**self.params_predictor)
280
281    def features_shape(self) -> Optional[torch.Size]:
282        return self.f_shape

Base class for all models.

Manages the interaction between the base model and the SSL modules, and ensures a consistent input and output format.

C2F_TCN_P(num_classes, input_dims, num_f_maps=128, feature_dim=None, state_dict_path=None, ssl_constructors=None, ssl_types=None, ssl_modules=None)
239    def __init__(
240        self,
241        num_classes,
242        input_dims,
243        num_f_maps=128,
244        feature_dim=None,
245        state_dict_path=None,
246        ssl_constructors=None,
247        ssl_types=None,
248        ssl_modules=None,
249    ):
250        if feature_dim is None:
251            feature_dim = num_f_maps
252        keys = [
253            key
254            for key in input_dims.keys()
255            if len(key.split("---")) != 1 and len(key.split("---")[-1].split("+")) != 2
256        ]
257        num_ind = len(set([key.split("---")[-1] for key in keys]))
258        key = keys[0]
259        ind = key.split("---")[-1]
260        input_dims = int(
261            sum([v[0] for k, v in input_dims.items() if k.split("---")[-1] == ind])
262        )
263        self.f_shape = torch.Size([feature_dim * num_ind])
264        self.params_predictor = {
265            "dim": int(feature_dim * num_ind),
266            "num_classes": num_classes,
267        }
268        self.params = {
269            "output_dim": int(feature_dim),
270            "n_channels": int(input_dims),
271            "num_f_maps": int(num_f_maps),
272        }
273        super().__init__(ssl_constructors, ssl_modules, ssl_types, state_dict_path)

Initialize the model.

Parameters

ssl_constructors : list, optional
    a list of SSL constructors that build the necessary SSL modules
ssl_modules : list, optional
    a list of torch.nn.Module instances that will serve as SSL modules
ssl_types : list, optional
    a list of string SSL types
state_dict_path : str, optional
    path to the model state dictionary to load
strict : bool, default False
    when True, the state dictionary will only be loaded if the current and the loaded architecture are the same; otherwise missing or extra keys, as well as shape inconsistencies, are ignored
prompt_function : callable, optional
    a function that takes a list of strings and returns a string prompt

f_shape
params_predictor
params
def features_shape(self) -> Optional[torch.Size]:
281    def features_shape(self) -> Optional[torch.Size]:
282        return self.f_shape

Get the shape of the feature extractor output.

Returns

feature_shape : torch.Size
    shape of the feature extractor output