dlc2action.model.c2f_transformer

#
# Copyright 2020-present by A. Mathis Group and contributors. All rights reserved.
#
# This project and all its files are licensed under GNU AGPLv3 or later version.
# A copy is included in dlc2action/LICENSE.AGPL.
#
import math
from functools import partial
from typing import List, Optional, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from dlc2action.model.base_model import Model

nonlinearity = partial(F.relu, inplace=True)


class double_conv(nn.Module):
    """Two Conv1d + BatchNorm1d + ReLU blocks (kernel size 5, or 1 when fc=True)."""

    def __init__(self, in_ch, out_ch, fc=False):
        super(double_conv, self).__init__()
        if fc:
            # pointwise convolutions: equivalent to a per-frame fully connected layer
            kernel_size = 1
            padding = 0
        else:
            kernel_size = 5
            padding = 2
        self.conv = nn.Sequential(
            nn.Conv1d(in_ch, out_ch, kernel_size=kernel_size, padding=padding),
            nn.BatchNorm1d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv1d(out_ch, out_ch, kernel_size=kernel_size, padding=padding),
            nn.BatchNorm1d(out_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Forward pass."""
        x = self.conv(x)
        return x


class inconv(nn.Module):
    """Input block: a double_conv applied to the raw input channels."""

    def __init__(self, in_ch, out_ch):
        super(inconv, self).__init__()
        self.conv = double_conv(in_ch, out_ch)

    def forward(self, x):
        """Forward pass."""
        x = self.conv(x)
        return x


class outconv(nn.Module):
    """Output block: a 1x1 convolution projecting features to prediction logits."""

    def __init__(self, in_ch, out_ch):
        super(outconv, self).__init__()
        self.conv = nn.Conv1d(in_ch, out_ch, kernel_size=1)

    def forward(self, x):
        """Forward pass."""
        x = self.conv(x)
        return x


class down(nn.Module):
    """Downscaling block: max-pool by a factor of 2, then double_conv."""

    def __init__(self, in_ch, out_ch):
        super(down, self).__init__()
        self.max_pool_conv = nn.Sequential(nn.MaxPool1d(2), double_conv(in_ch, out_ch))

    def forward(self, x):
        """Forward pass."""
        x = self.max_pool_conv(x)
        return x


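A minimal shape sketch of the encoder blocks (illustrative, not part of the module): they operate on `(batch, channels, frames)` tensors, `double_conv` preserves the temporal length, and each `down` halves it.

x = torch.randn(2, 10, 512)  # (batch, channels, frames)
x = inconv(10, 64)(x)        # -> (2, 64, 512): length preserved
x = down(64, 64)(x)          # -> (2, 64, 256): length halved

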
class up(nn.Module):
    """Upscaling then double conv."""

    def __init__(
        self, in_channels, out_channels, heads, att_in, bilinear=True, fc=False
    ):
        super().__init__()

        self.attn = MultiHeadAttention(heads=heads, d_model=att_in)
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode="linear", align_corners=True)
        else:
            self.up = nn.ConvTranspose1d(
                in_channels // 2, in_channels // 2, kernel_size=2, stride=2
            )

        self.out = double_conv(in_channels, out_channels, fc=fc)

    def forward(self, x1, x2):
        """Forward pass."""
        # self-attention over the coarse features, then upsample and merge
        # with the skip connection x2
        x1 = self.attn(x1, x1, x1)
        x1 = self.up(x1)
        diff = x2.size(2) - x1.size(2)

        x1 = F.pad(x1, [diff // 2, diff - diff // 2])
        x = torch.cat([x2, x1], dim=1)
        return self.out(x)


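A shape sketch for `up` (illustrative): `att_in` must match the channel count of the coarse input `x1`, and `in_channels` must equal the channel counts of `x1` and the skip connection `x2` combined.

x1 = torch.randn(2, 64, 8)                   # coarse decoder features
x2 = torch.randn(2, 64, 16)                  # encoder skip connection
y = up(128, 64, heads=4, att_in=64)(x1, x2)  # -> (2, 64, 16)

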
class TPPblock(nn.Module):
    """Temporal pyramid pooling: pool at four scales, project each pooled map
    to a single channel, upsample back to the input length and fuse with a
    1x1 convolution."""

    def __init__(self, in_channels):
        super(TPPblock, self).__init__()
        self.pool1 = nn.MaxPool1d(kernel_size=2, stride=2)
        self.pool2 = nn.MaxPool1d(kernel_size=3, stride=3)
        self.pool3 = nn.MaxPool1d(kernel_size=5, stride=5)
        self.pool4 = nn.MaxPool1d(kernel_size=6, stride=6)

        self.conv = nn.Conv1d(
            in_channels=in_channels, out_channels=1, kernel_size=1, padding=0
        )

        self.out_conv = nn.Conv1d(
            in_channels=in_channels + 4,
            out_channels=in_channels,
            kernel_size=1,
            padding=0,
        )

    def forward(self, x):
        """Forward pass."""
        t = x.size(2)
        layer1 = F.interpolate(
            self.conv(self.pool1(x)), size=t, mode="linear", align_corners=True
        )
        layer2 = F.interpolate(
            self.conv(self.pool2(x)), size=t, mode="linear", align_corners=True
        )
        layer3 = F.interpolate(
            self.conv(self.pool3(x)), size=t, mode="linear", align_corners=True
        )
        layer4 = F.interpolate(
            self.conv(self.pool4(x)), size=t, mode="linear", align_corners=True
        )

        out = torch.cat([layer1, layer2, layer3, layer4, x], 1)
        out = self.out_conv(out)

        return out


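A shape sketch (illustrative): the block preserves both channel count and length, but the kernel-6 pool requires at least 6 input frames. In the full model the block runs on features downsampled 64x, which is presumably why `"general/len_segment"` must be at least 512 (512 / 64 = 8 frames at the bottleneck, whereas 256 / 64 = 4 would be too short to pool).

x = torch.randn(2, 64, 8)
TPPblock(64)(x).shape  # -> torch.Size([2, 64, 8])

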
class C2F_Transformer_Module(nn.Module):
    """
    Features are extracted at the last layer of the decoder.
    """

    def __init__(
        self, n_channels, output_dim, heads, num_f_maps, use_predictor=False, fc=False
    ):
        super().__init__()
        self.use_predictor = use_predictor
        self.inc = inconv(n_channels, num_f_maps * 2)
        self.down1 = down(num_f_maps * 2, num_f_maps * 2)
        self.down2 = down(num_f_maps * 2, num_f_maps * 2)
        self.down3 = down(num_f_maps * 2, num_f_maps)
        self.down4 = down(num_f_maps, num_f_maps)
        self.down5 = down(num_f_maps, num_f_maps)
        self.down6 = down(num_f_maps, num_f_maps)
        self.pe = PositionalEncoder(num_f_maps)
        self.up = up(num_f_maps * 2, num_f_maps, heads, att_in=num_f_maps, fc=fc)
        self.outcc0 = outconv(num_f_maps, output_dim)
        self.up0 = up(num_f_maps * 2, num_f_maps, heads, att_in=num_f_maps, fc=fc)
        self.outcc1 = outconv(num_f_maps, output_dim)
        self.up1 = up(num_f_maps * 2, num_f_maps, heads, att_in=num_f_maps, fc=fc)
        self.outcc2 = outconv(num_f_maps, output_dim)
        self.up2 = up(num_f_maps * 3, num_f_maps, heads, att_in=num_f_maps, fc=fc)
        self.outcc3 = outconv(num_f_maps, output_dim)
        self.up3 = up(num_f_maps * 3, num_f_maps, heads, att_in=num_f_maps, fc=fc)
        self.outcc4 = outconv(num_f_maps, output_dim)
        self.up4 = up(num_f_maps * 3, num_f_maps, heads, att_in=num_f_maps, fc=fc)
        self.outcc = outconv(num_f_maps, output_dim)
        self.tpp = TPPblock(num_f_maps)
        self.weights = torch.nn.Parameter(torch.ones(6))

    def forward(self, x):
        """Forward pass."""
        # encoder: six downsampling stages (64x temporal reduction in total)
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x6 = self.down5(x5)
        x7 = self.down6(x6)
        # bottleneck: temporal pyramid pooling plus positional encoding
        x7 = self.tpp(x7)
        x7 = self.pe(x7)
        # decoder: attention-augmented upsampling with a prediction per stage
        x = self.up(x7, x6)
        y1 = self.outcc0(F.relu(x))
        x = self.up0(x, x5)
        y2 = self.outcc1(F.relu(x))
        x = self.up1(x, x4)
        y3 = self.outcc2(F.relu(x))
        x = self.up2(x, x3)
        y4 = self.outcc3(F.relu(x))
        x = self.up3(x, x2)
        y5 = self.outcc4(F.relu(x))
        x = self.up4(x, x1)
        y = self.outcc(x)
        # only the four finest stages are returned; the coarsest predictions
        # y1 and y2 are computed but not used
        output = [y]
        for outp_ele in [y5, y4, y3]:
            output.append(
                F.interpolate(
                    outp_ele, size=y.shape[-1], mode="linear", align_corners=True
                )
            )
        output = torch.stack(output, dim=0)
        if self.use_predictor:
            K, B, C, T = output.shape
            output = output.reshape((-1, C, T))
        return output


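A shape sketch of the full module (illustrative):

module = C2F_Transformer_Module(n_channels=10, output_dim=5, heads=4, num_f_maps=64)
x = torch.randn(2, 10, 512)  # (batch, channels, frames)
module(x).shape              # -> torch.Size([4, 2, 5, 512]): four decoder
                             #    stages, coarse ones upsampled to full length
# with use_predictor=True the stack is flattened to (4 * batch, 5, 512)
# so that Predictor can restore the stage axis afterwards

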
class C2F_Transformer(Model):
    """
    A modification of C2F-TCN that replaces some convolutions with attention.

    Requires the `"general/len_segment"` parameter to be at least 512.
    """

    def __init__(
        self,
        num_classes,
        input_dims,
        heads,
        num_f_maps,
        linear=False,
        feature_dim=None,
        state_dict_path=None,
        ssl_constructors=None,
        ssl_types=None,
        ssl_modules=None,
    ):
        # total input channel count, summed over all input data types
        input_dims = int(sum([s[0] for s in input_dims.values()]))
        if feature_dim is None:
            feature_dim = num_classes
            self.f_shape = None
            self.params_predictor = None
        else:
            self.f_shape = torch.Size([feature_dim])
            self.params_predictor = {
                "dim": int(feature_dim),
                "num_classes": num_classes,
            }
        self.params = {
            "output_dim": int(float(feature_dim)),
            "n_channels": int(float(input_dims)),
            "num_f_maps": int(float(num_f_maps)),
            "heads": int(float(heads)),
            "use_predictor": self.f_shape is not None,
            "fc": linear,
        }
        super().__init__(ssl_constructors, ssl_modules, ssl_types, state_dict_path)

    def _feature_extractor(self) -> Union[torch.nn.Module, List]:
        return C2F_Transformer_Module(**self.params)

    def _predictor(self) -> torch.nn.Module:
        if self.params_predictor is not None:
            return Predictor(**self.params_predictor)
        else:
            return nn.Identity()

    def features_shape(self) -> Optional[torch.Size]:
        """Get the shape of the feature extractor output."""
        return self.f_shape


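A construction sketch (illustrative; it assumes `input_dims` maps input data type names to shapes whose first entry is a channel count, as the summation in `__init__` suggests):

model = C2F_Transformer(
    num_classes=5,
    input_dims={"coords": torch.Size([10])},  # 10 input channels in total
    heads=4,
    num_f_maps=64,
)
# with feature_dim unset, the feature extractor predicts num_classes directly
# and _predictor() is an identity; setting feature_dim adds a Predictor head

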
class PositionalEncoder(nn.Module):
    """Fixed sinusoidal positional encoding, stored as a constant
    (d_model, max_seq_len) matrix and added along the channel axis."""

    def __init__(self, d_model, max_seq_len=512):
        super().__init__()
        self.d_model = d_model

        # create a constant 'pe' matrix with values dependent on
        # position and channel index
        pe = torch.zeros(d_model, max_seq_len)
        for pos in range(max_seq_len):
            for i in range(0, d_model, 2):
                pe[i, pos] = math.sin(pos / (10000 ** ((2 * i) / d_model)))
                if i + 1 < d_model:
                    pe[i + 1, pos] = math.cos(
                        pos / (10000 ** ((2 * (i + 1)) / d_model))
                    )

        self.pe = pe.unsqueeze(0)

    def forward(self, x):
        """Forward pass."""
        # make the embeddings relatively larger
        x = x * math.sqrt(self.d_model)
        # add the constant encoding to the embedding
        seq_len = x.size(-1)
        x = x + self.pe[:, :, :seq_len].to(x.device)
        return x


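A shape sketch (illustrative); note that the layout is `(batch, channels, frames)`, transposed relative to the usual `(frames, d_model)` transformer convention:

pe = PositionalEncoder(d_model=64)
pe(torch.randn(2, 64, 8)).shape  # -> torch.Size([2, 64, 8])

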
class Predictor(nn.Module):
    """Two 1x1 convolutions that map features to class logits and restore
    the stage axis flattened by C2F_Transformer_Module."""

    def __init__(self, dim, num_classes):
        super(Predictor, self).__init__()
        self.num_classes = num_classes
        self.conv_out_1 = nn.Conv1d(dim, dim, kernel_size=1)
        self.conv_out_2 = nn.Conv1d(dim, num_classes, kernel_size=1)

    def forward(self, x):
        """Forward pass."""
        x = self.conv_out_1(x)
        x = F.relu(x)
        x = self.conv_out_2(x)
        # the leading 4 matches the four decoder stages stacked by
        # C2F_Transformer_Module
        x = x.reshape((4, -1, self.num_classes, x.shape[-1]))
        return x


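A shape sketch (illustrative), matching the flattened output of the feature extractor:

Predictor(dim=5, num_classes=3)(torch.randn(8, 5, 512)).shape
# -> torch.Size([4, 2, 3, 512])

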
class MultiHeadAttention(nn.Module):
    """Standard multi-head attention operating on (batch, channels, frames)
    tensors; d_model must be divisible by heads."""

    def __init__(self, heads, d_model, dropout=0.1):
        super().__init__()

        self.d_model = d_model
        self.d_k = d_model // heads
        self.h = heads

        self.q_linear = nn.Linear(d_model, d_model)
        self.v_linear = nn.Linear(d_model, d_model)
        self.k_linear = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)
        self.out = nn.Linear(d_model, d_model)

    def forward(self, q, k, v, mask=None):
        """Forward pass."""
        bs = q.size(0)
        # switch to (batch, frames, channels) for the linear projections
        q = q.transpose(1, 2)
        v = v.transpose(1, 2)
        k = k.transpose(1, 2)

        # perform the linear operation and split into h heads
        k = self.k_linear(k).view(bs, -1, self.h, self.d_k)
        q = self.q_linear(q).view(bs, -1, self.h, self.d_k)
        v = self.v_linear(v).view(bs, -1, self.h, self.d_k)

        # transpose to get dimensions bs * h * sl * d_k
        k = k.transpose(1, 2)
        q = q.transpose(1, 2)
        v = v.transpose(1, 2)

        # calculate attention using the attention function defined below
        scores = attention(q, k, v, self.d_k, mask, self.dropout)

        # concatenate the heads and put them through the final linear layer
        concat = scores.transpose(1, 2).contiguous().view(bs, -1, self.d_model)

        output = self.out(concat).transpose(1, 2)

        return output


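A shape sketch (illustrative):

mha = MultiHeadAttention(heads=4, d_model=64)
x = torch.randn(2, 64, 8)  # (batch, channels, frames), as used inside `up`
mha(x, x, x).shape         # -> torch.Size([2, 64, 8])

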
def attention(q, k, v, d_k, mask=None, dropout=None):
    """Scaled dot-product attention."""
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        mask = mask.unsqueeze(1)
        scores = scores.masked_fill(mask == 0, -1e9)
    scores = F.softmax(scores, dim=-1)

    if dropout is not None:
        scores = dropout(scores)

    output = torch.matmul(scores, v)
    return output
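An equivalence sketch (illustrative; assumes torch >= 2.0 for `F.scaled_dot_product_attention`): with no mask and no dropout this function matches PyTorch's built-in kernel.

q = k = v = torch.randn(1, 4, 8, 16)  # (batch, heads, frames, d_k)
torch.allclose(
    attention(q, k, v, d_k=16),
    F.scaled_dot_product_attention(q, k, v),
    atol=1e-6,
)  # True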