dlc2action.model.asformer

  1#
  2# Copyright 2020-present by A. Mathis Group and contributors. All rights reserved.
  3#
  4# This project and all its files are licensed under GNU AGPLv3 or later version. A copy is included in dlc2action/LICENSE.AGPL.
  5#
  6# Incorporates code adapted from ASFormer by ChinaYi
  7# Original work Copyright (c) 2021 ChinaYi
  8# Source: https://github.com/ChinaYi/ASFormer
  9# Originally licensed under MIT License
 10# Combined work licensed under GNU AGPLv3
 11#
 12import copy
 13import math
 14from typing import List, Union
 15
 16import numpy as np
 17import torch
 18import torch.nn as nn
 19import torch.nn.functional as F
 20from dlc2action.model.base_model import Model
 21from torch import optim
 22
 23device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 24
 25
 26def exponential_descrease(idx_decoder, p=3):
 27    """Exponentially decreasing weight (alpha) applied to each decoder's attention."""
 28    return math.exp(-p * idx_decoder)
 29
 30
 31class AttentionHelper(nn.Module):
 32    def __init__(self):
 33        super(AttentionHelper, self).__init__()
 34        self.softmax = nn.Softmax(dim=-1)
 35
 36    def scalar_dot_att(self, proj_query, proj_key, proj_val, padding_mask):
 37        """
 38        scalar dot attention.
 39        :param proj_query: shape of (B, C, L) => (Batch_Size, Feature_Dimension, Length)
 40        :param proj_key: shape of (B, C, L)
 41        :param proj_val: shape of (B, C, L)
 42        :param padding_mask: shape of (B, C, L)
 43        :return: attention value of shape (B, C, L)
 44        """
 45        m, c1, l1 = proj_query.shape
 46        m, c2, l2 = proj_key.shape
 47
 48        assert c1 == c2
 49
 50        energy = torch.bmm(
 51            proj_query.permute(0, 2, 1), proj_key
 52        )  # out of shape (B, L1, L2)
 53        attention = energy / np.sqrt(c1)
 54        attention = attention + torch.log(
 55            padding_mask + 1e-6
 56        )  # mask the zero paddings. log(1e-6) for zero paddings
 57        attention = self.softmax(attention)
 58        attention = attention * padding_mask
 59        attention = attention.permute(0, 2, 1)
 60        out = torch.bmm(proj_val, attention)
 61        return out, attention
 62
 63
 64class AttLayer(nn.Module):
 65    def __init__(self, q_dim, k_dim, v_dim, r1, r2, r3, bl, stage, att_type):  # r1 = r2
 66        self._fix_types(q_dim, k_dim, v_dim, r1, r2, r3, bl, stage, att_type)
 67        super(AttLayer, self).__init__()
 68
 69        self.query_conv = nn.Conv1d(
 70            in_channels=self.q_dim, out_channels=self.q_dim // self.r1, kernel_size=1
 71        )
 72        self.key_conv = nn.Conv1d(
 73            in_channels=self.k_dim, out_channels=self.k_dim // self.r2, kernel_size=1
 74        )
 75        self.value_conv = nn.Conv1d(
 76            in_channels=self.v_dim, out_channels=self.v_dim // self.r3, kernel_size=1
 77        )
 78
 79        self.conv_out = nn.Conv1d(
 80            in_channels=self.v_dim // self.r3, out_channels=self.v_dim, kernel_size=1
 81        )
 82
 83        assert self.att_type in ["normal_att", "block_att", "sliding_att"]
 84        assert self.stage in ["encoder", "decoder"]
 85
 86        self.att_helper = AttentionHelper()
 87        self.window_mask = self.construct_window_mask()
 88
 89    def _fix_types(self, q_dim, k_dim, v_dim, r1, r2, r3, bl, stage, att_type):
 90        self.q_dim = int(float(q_dim))
 91        self.k_dim = int(float(k_dim))
 92        self.v_dim = int(float(v_dim))
 93        self.r1 = int(float(r1))
 94        self.r2 = int(float(r2))
 95        self.r3 = int(float(r3))
 96        self.bl = int(float(bl))
 97        self.stage = stage
 98        self.att_type = att_type
 99
100    def construct_window_mask(self):
101        """
102        construct window mask of shape (1, l, l + l//2 + l//2), used for sliding window self attention
103        """
104        window_mask = torch.zeros((1, self.bl, self.bl + 2 * (self.bl // 2)))
105        for i in range(self.bl):
106            window_mask[:, :, i : i + self.bl] = 1
107        return window_mask.to(device)
108
109    def forward(self, x1, x2, mask):
110        """Forward pass."""
111        # x1 from the encoder
112        # x2 from the decoder
113
114        query = self.query_conv(x1)
115        key = self.key_conv(x1)
116
117        if self.stage == "decoder":
118            assert x2 is not None
119            value = self.value_conv(x2)
120        else:
121            value = self.value_conv(x1)
122
123        if self.att_type == "normal_att":
124            return self._normal_self_att(query, key, value, mask)
125        elif self.att_type == "block_att":
126            return self._block_wise_self_att(query, key, value, mask)
127        elif self.att_type == "sliding_att":
128            return self._sliding_window_self_att(query, key, value, mask)
129
130    def _normal_self_att(self, q, k, v, mask):
131        m_batchsize, c1, L = q.size()
132        _, c2, L = k.size()
133        _, c3, L = v.size()
134        padding_mask = torch.ones((m_batchsize, 1, L)).to(device) * mask[:, 0:1, :]
135        output, attentions = self.att_helper.scalar_dot_att(q, k, v, padding_mask)
136        output = self.conv_out(F.relu(output))
137        output = output[:, :, 0:L]
138        return output * mask[:, 0:1, :]
139
140    def _block_wise_self_att(self, q, k, v, mask):
141        m_batchsize, c1, L = q.size()
142        _, c2, L = k.size()
143        _, c3, L = v.size()
144
145        nb = L // self.bl
146        if L % self.bl != 0:
147            q = torch.cat(
148                [q, torch.zeros((m_batchsize, c1, self.bl - L % self.bl)).to(device)],
149                dim=-1,
150            )
151            k = torch.cat(
152                [k, torch.zeros((m_batchsize, c2, self.bl - L % self.bl)).to(device)],
153                dim=-1,
154            )
155            v = torch.cat(
156                [v, torch.zeros((m_batchsize, c3, self.bl - L % self.bl)).to(device)],
157                dim=-1,
158            )
159            nb += 1
160
161        padding_mask = torch.cat(
162            [
163                torch.ones((m_batchsize, 1, L)).to(device) * mask[:, 0:1, :],
164                torch.zeros((m_batchsize, 1, self.bl * nb - L)).to(device),
165            ],
166            dim=-1,
167        )
168
169        q = (
170            q.reshape(m_batchsize, c1, nb, self.bl)
171            .permute(0, 2, 1, 3)
172            .reshape(m_batchsize * nb, c1, self.bl)
173        )
174        padding_mask = (
175            padding_mask.reshape(m_batchsize, 1, nb, self.bl)
176            .permute(0, 2, 1, 3)
177            .reshape(m_batchsize * nb, 1, self.bl)
178        )
179        k = (
180            k.reshape(m_batchsize, c2, nb, self.bl)
181            .permute(0, 2, 1, 3)
182            .reshape(m_batchsize * nb, c2, self.bl)
183        )
184        v = (
185            v.reshape(m_batchsize, c3, nb, self.bl)
186            .permute(0, 2, 1, 3)
187            .reshape(m_batchsize * nb, c3, self.bl)
188        )
189
190        output, attentions = self.att_helper.scalar_dot_att(q, k, v, padding_mask)
191        output = self.conv_out(F.relu(output))
192
193        output = (
194            output.reshape(m_batchsize, nb, c3, self.bl)
195            .permute(0, 2, 1, 3)
196            .reshape(m_batchsize, c3, nb * self.bl)
197        )
198        output = output[:, :, 0:L]
199        return output * mask[:, 0:1, :]
200
201    def _sliding_window_self_att(self, q, k, v, mask):
202        m_batchsize, c1, L = q.size()
203        _, c2, _ = k.size()
204        _, c3, _ = v.size()
205
206        # padding zeros for the last segment
207        nb = L // self.bl
208        if L % self.bl != 0:
209            q = torch.cat(
210                [q, torch.zeros((m_batchsize, c1, self.bl - L % self.bl)).to(device)],
211                dim=-1,
212            )
213            k = torch.cat(
214                [k, torch.zeros((m_batchsize, c2, self.bl - L % self.bl)).to(device)],
215                dim=-1,
216            )
217            v = torch.cat(
218                [v, torch.zeros((m_batchsize, c3, self.bl - L % self.bl)).to(device)],
219                dim=-1,
220            )
221            nb += 1
222        padding_mask = torch.cat(
223            [
224                torch.ones((m_batchsize, 1, L)).to(device) * mask[:, 0:1, :],
225                torch.zeros((m_batchsize, 1, self.bl * nb - L)).to(device),
226            ],
227            dim=-1,
228        )
229
230        # sliding window approach, by splitting query_proj and key_proj into shape (c1, l) x (c1, 2l)
231        # sliding window for query_proj: reshape
232        q = (
233            q.reshape(m_batchsize, c1, nb, self.bl)
234            .permute(0, 2, 1, 3)
235            .reshape(m_batchsize * nb, c1, self.bl)
236        )
237
238        # sliding window approach for key_proj
239        # 1. add paddings at the start and end
240        k = torch.cat(
241            [
242                torch.zeros(m_batchsize, c2, self.bl // 2).to(device),
243                k,
244                torch.zeros(m_batchsize, c2, self.bl // 2).to(device),
245            ],
246            dim=-1,
247        )
248        v = torch.cat(
249            [
250                torch.zeros(m_batchsize, c3, self.bl // 2).to(device),
251                v,
252                torch.zeros(m_batchsize, c3, self.bl // 2).to(device),
253            ],
254            dim=-1,
255        )
256        padding_mask = torch.cat(
257            [
258                torch.zeros(m_batchsize, 1, self.bl // 2).to(device),
259                padding_mask,
260                torch.zeros(m_batchsize, 1, self.bl // 2).to(device),
261            ],
262            dim=-1,
263        )
264
265        # 2. reshape key_proj of shape (m_batchsize*nb, c1, 2*self.bl)
266        k = torch.cat(
267            [
268                k[:, :, i * self.bl : (i + 1) * self.bl + (self.bl // 2) * 2]
269                for i in range(nb)
270            ],
271            dim=0,
272        )  # special case when self.bl = 1
273        v = torch.cat(
274            [
275                v[:, :, i * self.bl : (i + 1) * self.bl + (self.bl // 2) * 2]
276                for i in range(nb)
277            ],
278            dim=0,
279        )
280        # 3. construct window mask of shape (1, l, 2l), and use it to generate final mask
281        padding_mask = torch.cat(
282            [
283                padding_mask[:, :, i * self.bl : (i + 1) * self.bl + (self.bl // 2) * 2]
284                for i in range(nb)
285            ],
286            dim=0,
287        )  # of shape (m*nb, 1, 2l)
288        final_mask = self.window_mask.repeat(m_batchsize * nb, 1, 1) * padding_mask
289
290        output, attention = self.att_helper.scalar_dot_att(q, k, v, final_mask)
291        output = self.conv_out(F.relu(output))
292
293        output = (
294            output.reshape(m_batchsize, nb, -1, self.bl)
295            .permute(0, 2, 1, 3)
296            .reshape(m_batchsize, -1, nb * self.bl)
297        )
298        output = output[:, :, 0:L]
299        return output * mask[:, 0:1, :]
300
301
302class MultiHeadAttLayer(nn.Module):
303    def __init__(self, q_dim, k_dim, v_dim, r1, r2, r3, bl, stage, att_type, num_head):
304        super(MultiHeadAttLayer, self).__init__()
305        #         assert v_dim % num_head == 0
306        self.conv_out = nn.Conv1d(v_dim * num_head, v_dim, 1)
307        self.layers = nn.ModuleList(
308            [
309                copy.deepcopy(
310                    AttLayer(q_dim, k_dim, v_dim, r1, r2, r3, bl, stage, att_type)
311                )
312                for i in range(num_head)
313            ]
314        )
315        self.dropout = nn.Dropout(p=0.5)
316
317    def forward(self, x1, x2, mask):
318        """Forward pass."""
319        out = torch.cat([layer(x1, x2, mask) for layer in self.layers], dim=1)
320        out = self.conv_out(self.dropout(out))
321        return out
322
323
324class ConvFeedForward(nn.Module):
325    def __init__(self, dilation, in_channels, out_channels):
326        super(ConvFeedForward, self).__init__()
327        self.layer = nn.Sequential(
328            nn.Conv1d(
329                in_channels, out_channels, 3, padding=dilation, dilation=dilation
330            ),
331            nn.ReLU(),
332        )
333
334    def forward(self, x):
335        """Forward pass."""
336        return self.layer(x)
337
338
339class FCFeedForward(nn.Module):
340    def __init__(self, in_channels, out_channels):
341        super(FCFeedForward, self).__init__()
342        self.layer = nn.Sequential(
343            nn.Conv1d(in_channels, out_channels, 1),  # conv1d equals fc
344            nn.ReLU(),
345            nn.Dropout(),
346            nn.Conv1d(out_channels, out_channels, 1),
347        )
348
349    def forward(self, x):
350        """Forward pass."""
351        return self.layer(x)
352
353
354class AttModule(nn.Module):
355    def __init__(
356        self, dilation, in_channels, out_channels, r1, r2, att_type, stage, alpha
357    ):
358        super(AttModule, self).__init__()
359        self.feed_forward = ConvFeedForward(dilation, in_channels, out_channels)
360        self.instance_norm = nn.InstanceNorm1d(in_channels, track_running_stats=False)
361        self.att_layer = AttLayer(
362            in_channels,
363            in_channels,
364            out_channels,
365            r1,
366            r1,
367            r2,
368            dilation,
369            att_type=att_type,
370            stage=stage,
371        )  # dilation
372        self.conv_1x1 = nn.Conv1d(out_channels, out_channels, 1)
373        self.dropout = nn.Dropout()
374        self.alpha = alpha
375
376    def forward(self, x, f, mask):
377        """Forward pass."""
378        out = self.feed_forward(x)
379        out = self.alpha * self.att_layer(self.instance_norm(out), f, mask) + out
380        out = self.conv_1x1(out)
381        out = self.dropout(out)
382        return (x + out) * mask[:, 0:1, :]
383
384
385class PositionalEncoding(nn.Module):
386    "Implement the PE function."
387
388    def __init__(self, d_model, max_len=10000):
389        super(PositionalEncoding, self).__init__()
390        # Compute the positional encodings once in log space.
391        pe = torch.zeros(max_len, d_model)
392        position = torch.arange(0, max_len).unsqueeze(1)
393        div_term = torch.exp(
394            torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model)
395        )
396        pe[:, 0::2] = torch.sin(position * div_term)
397        pe[:, 1::2] = torch.cos(position * div_term)
398        pe = pe.unsqueeze(0).permute(0, 2, 1)  # of shape (1, d_model, l)
399        self.pe = nn.Parameter(pe, requires_grad=True)
400
401    #         self.register_buffer('pe', pe)
402
403    def forward(self, x):
404        """Forward pass."""
405        return x + self.pe[:, :, 0 : x.shape[2]]
406
407
408class Encoder(nn.Module):
409    def __init__(
410        self,
411        num_layers,
412        r1,
413        r2,
414        num_f_maps,
415        input_dim,
416        channel_masking_rate,
417        att_type,
418        alpha,
419    ):
420        super(Encoder, self).__init__()
421        self.conv_1x1 = nn.Conv1d(input_dim, num_f_maps, 1)  # fc layer
422        self.layers = nn.ModuleList(
423            [
424                AttModule(
425                    2**i, num_f_maps, num_f_maps, r1, r2, att_type, "encoder", alpha
426                )
427                for i in range(num_layers)  # 2**i
428            ]
429        )
430
431        self.dropout = nn.Dropout2d(p=channel_masking_rate)
432        self.channel_masking_rate = channel_masking_rate
433
434    def forward(self, x, mask):
435        """Forward pass."""
436        if self.channel_masking_rate > 0:
437            x = x.unsqueeze(2)
438            x = self.dropout(x)
439            x = x.squeeze(2)
440
441        feature = self.conv_1x1(x)
442        for layer in self.layers:
443            feature = layer(feature, None, mask)
444
445        return feature
446
447
448class Decoder(nn.Module):
449    def __init__(
450        self, num_layers, r1, r2, num_f_maps, input_dim, num_classes, att_type, alpha
451    ):
452        super(
453            Decoder, self
454        ).__init__()  # self.position_en = PositionalEncoding(d_model=num_f_maps)
455        self.conv_1x1 = nn.Conv1d(input_dim, num_f_maps, 1)
456        self.layers = nn.ModuleList(
457            [
458                AttModule(
459                    2**i, num_f_maps, num_f_maps, r1, r2, att_type, "decoder", alpha
460                )
461                for i in range(num_layers)  # 2 ** i
462            ]
463        )
464        self.conv_out = nn.Conv1d(num_f_maps, num_classes, 1)
465
466    def forward(self, x, fencoder, mask):
467        """Forward pass."""
468        feature = self.conv_1x1(x)
469        for layer in self.layers:
470            feature = layer(feature, fencoder, mask)
471
472        out = self.conv_out(feature) * mask[:, 0:1, :]
473
474        return out, feature
475
476
477class MyTransformer(nn.Module):
478    def __init__(
479        self,
480        num_layers,
481        r1,
482        r2,
483        num_f_maps,
484        input_dim,
485        channel_masking_rate,
486    ):
487        super(MyTransformer, self).__init__()
488        self.encoder = Encoder(
489            num_layers,
490            r1,
491            r2,
492            num_f_maps,
493            input_dim,
494            channel_masking_rate,
495            att_type="sliding_att",
496            alpha=1,
497        )
498
499    def forward(self, x):
500        """Forward pass."""
501        mask = (x.sum(1).unsqueeze(1) != 0).int()
502        feature = self.encoder(x, mask)
503        feature = feature * mask
504
505        return feature
506
507
508class Predictor(nn.Module):
509    def __init__(
510        self,
511        num_layers,
512        r1,
513        r2,
514        num_f_maps,
515        num_classes,
516        num_decoders,
517    ):
518        super(Predictor, self).__init__()
519        self.decoders = nn.ModuleList(
520            [
521                copy.deepcopy(
522                    Decoder(
523                        num_layers,
524                        r1,
525                        r2,
526                        num_f_maps,
527                        num_classes,
528                        num_classes,
529                        att_type="sliding_att",
530                        alpha=exponential_descrease(s),
531                    )
532                )
533                for s in range(num_decoders)
534            ]
535        )  # num_decoders
536
537        self.conv_out = nn.Conv1d(num_f_maps, num_classes, 1)
538
539    def forward(self, x):
540        """Forward pass."""
541        mask = (x.sum(1).unsqueeze(1) != 0).int()
542        out = self.conv_out(x) * mask[:, 0:1, :]
543        outputs = out.unsqueeze(0)
544
545        for decoder in self.decoders:
546            out, x = decoder(
547                F.softmax(out, dim=1) * mask[:, 0:1, :], x * mask[:, 0:1, :], mask
548            )
549            outputs = torch.cat((outputs, out.unsqueeze(0)), dim=0)
550        return outputs
551
552
553class ASFormer(Model):
554    """
555    An implementation of ASFormer, a transformer model for temporal action segmentation
556    """
557
558    def __init__(
559        self,
560        num_decoders: int,
561        num_layers: int,
562        r1: float,
563        r2: float,
564        num_f_maps: int,
565        input_dim: dict,
566        num_classes: int,
567        channel_masking_rate: float,
568        state_dict_path: str = None,
569        ssl_constructors: List = None,
570        ssl_types: List = None,
571        ssl_modules: List = None,
572    ):
573        input_dim = sum([x[0] for x in input_dim.values()])
574        self.num_f_maps = int(float(num_f_maps))
575        self.params = {
576            "num_layers": int(float(num_layers)),
577            "r1": float(r1),
578            "r2": float(r2),
579            "num_f_maps": self.num_f_maps,
580            "input_dim": int(float(input_dim)),
581            "channel_masking_rate": float(channel_masking_rate),
582        }
583        self.params_predictor = {
584            "num_layers": int(float(num_layers)),
585            "r1": r1,
586            "r2": r2,
587            "num_f_maps": self.num_f_maps,
588            "num_classes": int(float(num_classes)),
589            "num_decoders": int(float(num_decoders)),
590        }
591        super().__init__(ssl_constructors, ssl_modules, ssl_types, state_dict_path)
592
593    def _feature_extractor(self) -> Union[torch.nn.Module, List]:
594        return MyTransformer(**self.params)
595
596    def _predictor(self) -> torch.nn.Module:
597        return Predictor(**self.params_predictor)
598
599    def features_shape(self) -> torch.Size:
600        return torch.Size([self.num_f_maps])
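
The pieces above fit together as follows: MyTransformer is the feature extractor (an encoder made of dilated sliding-window attention blocks), and Predictor stacks num_decoders refinement decoders whose attention contribution is scaled by exponential_descrease; the ASFormer wrapper builds exactly these two blocks in _feature_extractor() and _predictor() after casting its (possibly string-valued) hyperparameters. The snippet below is a minimal usage sketch, assuming the module is importable as dlc2action.model.asformer; the sizes and hyperparameter values are illustrative, not dlc2action defaults.

import torch
from dlc2action.model.asformer import MyTransformer, Predictor, exponential_descrease

dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Illustrative sizes: a batch of 2 sequences, 120 frames, 32 features per frame.
batch, input_dim, num_frames = 2, 32, 120
num_f_maps, num_layers = 64, 4
num_classes, num_decoders = 6, 3

encoder = MyTransformer(
    num_layers=num_layers, r1=2, r2=2, num_f_maps=num_f_maps,
    input_dim=input_dim, channel_masking_rate=0.3,
).to(dev)
predictor = Predictor(
    num_layers=num_layers, r1=2, r2=2, num_f_maps=num_f_maps,
    num_classes=num_classes, num_decoders=num_decoders,
).to(dev)

x = torch.randn(batch, input_dim, num_frames, device=dev)
features = encoder(x)          # (B, num_f_maps, L); frames whose feature sum is 0 are treated as padding
outputs = predictor(features)  # (num_decoders + 1, B, num_classes, L): the initial prediction
                               # plus one refined prediction per decoder

# Each decoder's attention is weighted by an exponentially decreasing alpha:
print([round(exponential_descrease(s), 4) for s in range(num_decoders)])  # [1.0, 0.0498, 0.0025]
print(features.shape, outputs.shape)  # torch.Size([2, 64, 120]) torch.Size([4, 2, 6, 120])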
device = device(type='cpu')
def exponential_descrease(idx_decoder, p=3):
27def exponential_descrease(idx_decoder, p=3):
28    """Exponential decrease function for the attention window."""
29    return math.exp(-p * idx_decoder)

Exponential decrease function for the attention window.

class AttentionHelper(torch.nn.modules.module.Module):
32class AttentionHelper(nn.Module):
33    def __init__(self):
34        super(AttentionHelper, self).__init__()
35        self.softmax = nn.Softmax(dim=-1)
36
37    def scalar_dot_att(self, proj_query, proj_key, proj_val, padding_mask):
38        """
39        scalar dot attention.
40        :param proj_query: shape of (B, C, L) => (Batch_Size, Feature_Dimension, Length)
41        :param proj_key: shape of (B, C, L)
42        :param proj_val: shape of (B, C, L)
43        :param padding_mask: shape of (B, C, L)
44        :return: attention value of shape (B, C, L)
45        """
46        m, c1, l1 = proj_query.shape
47        m, c2, l2 = proj_key.shape
48
49        assert c1 == c2
50
51        energy = torch.bmm(
52            proj_query.permute(0, 2, 1), proj_key
53        )  # out of shape (B, L1, L2)
54        attention = energy / np.sqrt(c1)
55        attention = attention + torch.log(
56            padding_mask + 1e-6
57        )  # mask the zero paddings. log(1e-6) for zero paddings
58        attention = self.softmax(attention)
59        attention = attention * padding_mask
60        attention = attention.permute(0, 2, 1)
61        out = torch.bmm(proj_val, attention)
62        return out, attention

Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::

import torch.nn as nn
import torch.nn.functional as F


class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will also have their parameters converted when you call to(), etc.

As per the example above, an __init__() call to the parent class must be made before assignment on the child.

:ivar training: Boolean represents whether this module is in training or evaluation mode. :vartype training: bool

AttentionHelper()
33    def __init__(self):
34        super(AttentionHelper, self).__init__()
35        self.softmax = nn.Softmax(dim=-1)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

softmax
def scalar_dot_att(self, proj_query, proj_key, proj_val, padding_mask):
37    def scalar_dot_att(self, proj_query, proj_key, proj_val, padding_mask):
38        """
39        scalar dot attention.
40        :param proj_query: shape of (B, C, L) => (Batch_Size, Feature_Dimension, Length)
41        :param proj_key: shape of (B, C, L)
42        :param proj_val: shape of (B, C, L)
43        :param padding_mask: shape of (B, C, L)
44        :return: attention value of shape (B, C, L)
45        """
46        m, c1, l1 = proj_query.shape
47        m, c2, l2 = proj_key.shape
48
49        assert c1 == c2
50
51        energy = torch.bmm(
52            proj_query.permute(0, 2, 1), proj_key
53        )  # out of shape (B, L1, L2)
54        attention = energy / np.sqrt(c1)
55        attention = attention + torch.log(
56            padding_mask + 1e-6
57        )  # mask the zero paddings. log(1e-6) for zero paddings
58        attention = self.softmax(attention)
59        attention = attention * padding_mask
60        attention = attention.permute(0, 2, 1)
61        out = torch.bmm(proj_val, attention)
62        return out, attention

scalar dot attention.

Parameters
  • proj_query: shape of (B, C, L) => (Batch_Size, Feature_Dimension, Length)
  • proj_key: shape of (B, C, L)
  • proj_val: shape of (B, C, L)
  • padding_mask: shape of (B, C, L)
Returns

attention value of shape (B, C, L)

class AttLayer(torch.nn.modules.module.Module):
 65class AttLayer(nn.Module):
 66    def __init__(self, q_dim, k_dim, v_dim, r1, r2, r3, bl, stage, att_type):  # r1 = r2
 67        self._fix_types(q_dim, k_dim, v_dim, r1, r2, r3, bl, stage, att_type)
 68        super(AttLayer, self).__init__()
 69
 70        self.query_conv = nn.Conv1d(
 71            in_channels=self.q_dim, out_channels=self.q_dim // self.r1, kernel_size=1
 72        )
 73        self.key_conv = nn.Conv1d(
 74            in_channels=self.k_dim, out_channels=self.k_dim // self.r2, kernel_size=1
 75        )
 76        self.value_conv = nn.Conv1d(
 77            in_channels=self.v_dim, out_channels=self.v_dim // self.r3, kernel_size=1
 78        )
 79
 80        self.conv_out = nn.Conv1d(
 81            in_channels=self.v_dim // self.r3, out_channels=self.v_dim, kernel_size=1
 82        )
 83
 84        assert self.att_type in ["normal_att", "block_att", "sliding_att"]
 85        assert self.stage in ["encoder", "decoder"]
 86
 87        self.att_helper = AttentionHelper()
 88        self.window_mask = self.construct_window_mask()
 89
 90    def _fix_types(self, q_dim, k_dim, v_dim, r1, r2, r3, bl, stage, att_type):
 91        self.q_dim = int(float(q_dim))
 92        self.k_dim = int(float(k_dim))
 93        self.v_dim = int(float(v_dim))
 94        self.r1 = int(float(r1))
 95        self.r2 = int(float(r2))
 96        self.r3 = int(float(r3))
 97        self.bl = int(float(bl))
 98        self.stage = stage
 99        self.att_type = att_type
100
101    def construct_window_mask(self):
102        """
103        construct window mask of shape (1, l, l + l//2 + l//2), used for sliding window self attention
104        """
105        window_mask = torch.zeros((1, self.bl, self.bl + 2 * (self.bl // 2)))
106        for i in range(self.bl):
107            window_mask[:, :, i : i + self.bl] = 1
108        return window_mask.to(device)
109
110    def forward(self, x1, x2, mask):
111        """Forward pass."""
112        # x1 from the encoder
113        # x2 from the decoder
114
115        query = self.query_conv(x1)
116        key = self.key_conv(x1)
117
118        if self.stage == "decoder":
119            assert x2 is not None
120            value = self.value_conv(x2)
121        else:
122            value = self.value_conv(x1)
123
124        if self.att_type == "normal_att":
125            return self._normal_self_att(query, key, value, mask)
126        elif self.att_type == "block_att":
127            return self._block_wise_self_att(query, key, value, mask)
128        elif self.att_type == "sliding_att":
129            return self._sliding_window_self_att(query, key, value, mask)
130
131    def _normal_self_att(self, q, k, v, mask):
132        m_batchsize, c1, L = q.size()
133        _, c2, L = k.size()
134        _, c3, L = v.size()
135        padding_mask = torch.ones((m_batchsize, 1, L)).to(device) * mask[:, 0:1, :]
136        output, attentions = self.att_helper.scalar_dot_att(q, k, v, padding_mask)
137        output = self.conv_out(F.relu(output))
138        output = output[:, :, 0:L]
139        return output * mask[:, 0:1, :]
140
141    def _block_wise_self_att(self, q, k, v, mask):
142        m_batchsize, c1, L = q.size()
143        _, c2, L = k.size()
144        _, c3, L = v.size()
145
146        nb = L // self.bl
147        if L % self.bl != 0:
148            q = torch.cat(
149                [q, torch.zeros((m_batchsize, c1, self.bl - L % self.bl)).to(device)],
150                dim=-1,
151            )
152            k = torch.cat(
153                [k, torch.zeros((m_batchsize, c2, self.bl - L % self.bl)).to(device)],
154                dim=-1,
155            )
156            v = torch.cat(
157                [v, torch.zeros((m_batchsize, c3, self.bl - L % self.bl)).to(device)],
158                dim=-1,
159            )
160            nb += 1
161
162        padding_mask = torch.cat(
163            [
164                torch.ones((m_batchsize, 1, L)).to(device) * mask[:, 0:1, :],
165                torch.zeros((m_batchsize, 1, self.bl * nb - L)).to(device),
166            ],
167            dim=-1,
168        )
169
170        q = (
171            q.reshape(m_batchsize, c1, nb, self.bl)
172            .permute(0, 2, 1, 3)
173            .reshape(m_batchsize * nb, c1, self.bl)
174        )
175        padding_mask = (
176            padding_mask.reshape(m_batchsize, 1, nb, self.bl)
177            .permute(0, 2, 1, 3)
178            .reshape(m_batchsize * nb, 1, self.bl)
179        )
180        k = (
181            k.reshape(m_batchsize, c2, nb, self.bl)
182            .permute(0, 2, 1, 3)
183            .reshape(m_batchsize * nb, c2, self.bl)
184        )
185        v = (
186            v.reshape(m_batchsize, c3, nb, self.bl)
187            .permute(0, 2, 1, 3)
188            .reshape(m_batchsize * nb, c3, self.bl)
189        )
190
191        output, attentions = self.att_helper.scalar_dot_att(q, k, v, padding_mask)
192        output = self.conv_out(F.relu(output))
193
194        output = (
195            output.reshape(m_batchsize, nb, c3, self.bl)
196            .permute(0, 2, 1, 3)
197            .reshape(m_batchsize, c3, nb * self.bl)
198        )
199        output = output[:, :, 0:L]
200        return output * mask[:, 0:1, :]
201
202    def _sliding_window_self_att(self, q, k, v, mask):
203        m_batchsize, c1, L = q.size()
204        _, c2, _ = k.size()
205        _, c3, _ = v.size()
206
207        # padding zeros for the last segment
208        nb = L // self.bl
209        if L % self.bl != 0:
210            q = torch.cat(
211                [q, torch.zeros((m_batchsize, c1, self.bl - L % self.bl)).to(device)],
212                dim=-1,
213            )
214            k = torch.cat(
215                [k, torch.zeros((m_batchsize, c2, self.bl - L % self.bl)).to(device)],
216                dim=-1,
217            )
218            v = torch.cat(
219                [v, torch.zeros((m_batchsize, c3, self.bl - L % self.bl)).to(device)],
220                dim=-1,
221            )
222            nb += 1
223        padding_mask = torch.cat(
224            [
225                torch.ones((m_batchsize, 1, L)).to(device) * mask[:, 0:1, :],
226                torch.zeros((m_batchsize, 1, self.bl * nb - L)).to(device),
227            ],
228            dim=-1,
229        )
230
231        # sliding window approach, by splitting query_proj and key_proj into shape (c1, l) x (c1, 2l)
232        # sliding window for query_proj: reshape
233        q = (
234            q.reshape(m_batchsize, c1, nb, self.bl)
235            .permute(0, 2, 1, 3)
236            .reshape(m_batchsize * nb, c1, self.bl)
237        )
238
239        # sliding window approach for key_proj
240        # 1. add paddings at the start and end
241        k = torch.cat(
242            [
243                torch.zeros(m_batchsize, c2, self.bl // 2).to(device),
244                k,
245                torch.zeros(m_batchsize, c2, self.bl // 2).to(device),
246            ],
247            dim=-1,
248        )
249        v = torch.cat(
250            [
251                torch.zeros(m_batchsize, c3, self.bl // 2).to(device),
252                v,
253                torch.zeros(m_batchsize, c3, self.bl // 2).to(device),
254            ],
255            dim=-1,
256        )
257        padding_mask = torch.cat(
258            [
259                torch.zeros(m_batchsize, 1, self.bl // 2).to(device),
260                padding_mask,
261                torch.zeros(m_batchsize, 1, self.bl // 2).to(device),
262            ],
263            dim=-1,
264        )
265
266        # 2. reshape key_proj of shape (m_batchsize*nb, c1, 2*self.bl)
267        k = torch.cat(
268            [
269                k[:, :, i * self.bl : (i + 1) * self.bl + (self.bl // 2) * 2]
270                for i in range(nb)
271            ],
272            dim=0,
273        )  # special case when self.bl = 1
274        v = torch.cat(
275            [
276                v[:, :, i * self.bl : (i + 1) * self.bl + (self.bl // 2) * 2]
277                for i in range(nb)
278            ],
279            dim=0,
280        )
281        # 3. construct window mask of shape (1, l, 2l), and use it to generate final mask
282        padding_mask = torch.cat(
283            [
284                padding_mask[:, :, i * self.bl : (i + 1) * self.bl + (self.bl // 2) * 2]
285                for i in range(nb)
286            ],
287            dim=0,
288        )  # of shape (m*nb, 1, 2l)
289        final_mask = self.window_mask.repeat(m_batchsize * nb, 1, 1) * padding_mask
290
291        output, attention = self.att_helper.scalar_dot_att(q, k, v, final_mask)
292        output = self.conv_out(F.relu(output))
293
294        output = (
295            output.reshape(m_batchsize, nb, -1, self.bl)
296            .permute(0, 2, 1, 3)
297            .reshape(m_batchsize, -1, nb * self.bl)
298        )
299        output = output[:, :, 0:L]
300        return output * mask[:, 0:1, :]

Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::

import torch.nn as nn
import torch.nn.functional as F


class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will also have their parameters converted when you call to(), etc.

As per the example above, an __init__() call to the parent class must be made before assignment on the child.

:ivar training: Boolean represents whether this module is in training or evaluation mode. :vartype training: bool

AttLayer(q_dim, k_dim, v_dim, r1, r2, r3, bl, stage, att_type)
66    def __init__(self, q_dim, k_dim, v_dim, r1, r2, r3, bl, stage, att_type):  # r1 = r2
67        self._fix_types(q_dim, k_dim, v_dim, r1, r2, r3, bl, stage, att_type)
68        super(AttLayer, self).__init__()
69
70        self.query_conv = nn.Conv1d(
71            in_channels=self.q_dim, out_channels=self.q_dim // self.r1, kernel_size=1
72        )
73        self.key_conv = nn.Conv1d(
74            in_channels=self.k_dim, out_channels=self.k_dim // self.r2, kernel_size=1
75        )
76        self.value_conv = nn.Conv1d(
77            in_channels=self.v_dim, out_channels=self.v_dim // self.r3, kernel_size=1
78        )
79
80        self.conv_out = nn.Conv1d(
81            in_channels=self.v_dim // self.r3, out_channels=self.v_dim, kernel_size=1
82        )
83
84        assert self.att_type in ["normal_att", "block_att", "sliding_att"]
85        assert self.stage in ["encoder", "decoder"]
86
87        self.att_helper = AttentionHelper()
88        self.window_mask = self.construct_window_mask()

Initialize internal Module state, shared by both nn.Module and ScriptModule.

query_conv
key_conv
value_conv
conv_out
att_helper
window_mask
def construct_window_mask(self):
101    def construct_window_mask(self):
102        """
103        construct window mask of shape (1, l, l + l//2 + l//2), used for sliding window self attention
104        """
105        window_mask = torch.zeros((1, self.bl, self.bl + 2 * (self.bl // 2)))
106        for i in range(self.bl):
107            window_mask[:, :, i : i + self.bl] = 1
108        return window_mask.to(device)

construct window mask of shape (1, l, l + l//2 + l//2), used for sliding window self attention

def forward(self, x1, x2, mask):
110    def forward(self, x1, x2, mask):
111        """Forward pass."""
112        # x1 from the encoder
113        # x2 from the decoder
114
115        query = self.query_conv(x1)
116        key = self.key_conv(x1)
117
118        if self.stage == "decoder":
119            assert x2 is not None
120            value = self.value_conv(x2)
121        else:
122            value = self.value_conv(x1)
123
124        if self.att_type == "normal_att":
125            return self._normal_self_att(query, key, value, mask)
126        elif self.att_type == "block_att":
127            return self._block_wise_self_att(query, key, value, mask)
128        elif self.att_type == "sliding_att":
129            return self._sliding_window_self_att(query, key, value, mask)

Forward pass.
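
At the lowest level, AttentionHelper.scalar_dot_att implements masked scaled dot-product attention in the channels-first (B, C, L) layout used throughout the module. The short check below uses arbitrary toy shapes (it is not part of dlc2action's test suite) and verifies that key positions flagged as padding end up with exactly zero attention weight.

import torch
from dlc2action.model.asformer import AttentionHelper

helper = AttentionHelper()
B, C, L = 1, 8, 6
q, k, v = (torch.randn(B, C, L) for _ in range(3))

padding_mask = torch.ones(B, 1, L)
padding_mask[:, :, -2:] = 0            # treat the last two positions as padding

out, att = helper.scalar_dot_att(q, k, v, padding_mask)
print(out.shape)                       # (1, 8, 6): same channels-first layout as the values
print(att.shape)                       # (1, 6, 6): (key, query) axes after the final permute
assert torch.all(att[:, -2:, :] == 0)  # padded keys receive zero weight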

class MultiHeadAttLayer(torch.nn.modules.module.Module):
303class MultiHeadAttLayer(nn.Module):
304    def __init__(self, q_dim, k_dim, v_dim, r1, r2, r3, bl, stage, att_type, num_head):
305        super(MultiHeadAttLayer, self).__init__()
306        #         assert v_dim % num_head == 0
307        self.conv_out = nn.Conv1d(v_dim * num_head, v_dim, 1)
308        self.layers = nn.ModuleList(
309            [
310                copy.deepcopy(
311                    AttLayer(q_dim, k_dim, v_dim, r1, r2, r3, bl, stage, att_type)
312                )
313                for i in range(num_head)
314            ]
315        )
316        self.dropout = nn.Dropout(p=0.5)
317
318    def forward(self, x1, x2, mask):
319        """Forward pass."""
320        out = torch.cat([layer(x1, x2, mask) for layer in self.layers], dim=1)
321        out = self.conv_out(self.dropout(out))
322        return out

Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::

import torch.nn as nn
import torch.nn.functional as F


class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will also have their parameters converted when you call to(), etc.

As per the example above, an __init__() call to the parent class must be made before assignment on the child.

:ivar training: Boolean represents whether this module is in training or evaluation mode. :vartype training: bool

MultiHeadAttLayer(q_dim, k_dim, v_dim, r1, r2, r3, bl, stage, att_type, num_head)
304    def __init__(self, q_dim, k_dim, v_dim, r1, r2, r3, bl, stage, att_type, num_head):
305        super(MultiHeadAttLayer, self).__init__()
306        #         assert v_dim % num_head == 0
307        self.conv_out = nn.Conv1d(v_dim * num_head, v_dim, 1)
308        self.layers = nn.ModuleList(
309            [
310                copy.deepcopy(
311                    AttLayer(q_dim, k_dim, v_dim, r1, r2, r3, bl, stage, att_type)
312                )
313                for i in range(num_head)
314            ]
315        )
316        self.dropout = nn.Dropout(p=0.5)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

conv_out
layers
dropout
def forward(self, x1, x2, mask):
318    def forward(self, x1, x2, mask):
319        """Forward pass."""
320        out = torch.cat([layer(x1, x2, mask) for layer in self.layers], dim=1)
321        out = self.conv_out(self.dropout(out))
322        return out

Forward pass.

class ConvFeedForward(torch.nn.modules.module.Module):
325class ConvFeedForward(nn.Module):
326    def __init__(self, dilation, in_channels, out_channels):
327        super(ConvFeedForward, self).__init__()
328        self.layer = nn.Sequential(
329            nn.Conv1d(
330                in_channels, out_channels, 3, padding=dilation, dilation=dilation
331            ),
332            nn.ReLU(),
333        )
334
335    def forward(self, x):
336        """Forward pass."""
337        return self.layer(x)

Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::

import torch.nn as nn
import torch.nn.functional as F


class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will also have their parameters converted when you call to(), etc.

As per the example above, an __init__() call to the parent class must be made before assignment on the child.

:ivar training: Boolean represents whether this module is in training or evaluation mode. :vartype training: bool

ConvFeedForward(dilation, in_channels, out_channels)
326    def __init__(self, dilation, in_channels, out_channels):
327        super(ConvFeedForward, self).__init__()
328        self.layer = nn.Sequential(
329            nn.Conv1d(
330                in_channels, out_channels, 3, padding=dilation, dilation=dilation
331            ),
332            nn.ReLU(),
333        )

Initialize internal Module state, shared by both nn.Module and ScriptModule.

layer
def forward(self, x):
335    def forward(self, x):
336        """Forward pass."""
337        return self.layer(x)

Forward pass.

class FCFeedForward(torch.nn.modules.module.Module):
340class FCFeedForward(nn.Module):
341    def __init__(self, in_channels, out_channels):
342        super(FCFeedForward, self).__init__()
343        self.layer = nn.Sequential(
344            nn.Conv1d(in_channels, out_channels, 1),  # conv1d equals fc
345            nn.ReLU(),
346            nn.Dropout(),
347            nn.Conv1d(out_channels, out_channels, 1),
348        )
349
350    def forward(self, x):
351        """Forward pass."""
352        return self.layer(x)

Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::

import torch.nn as nn
import torch.nn.functional as F


class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will also have their parameters converted when you call to(), etc.

As per the example above, an __init__() call to the parent class must be made before assignment on the child.

:ivar training: Boolean represents whether this module is in training or evaluation mode. :vartype training: bool

FCFeedForward(in_channels, out_channels)
341    def __init__(self, in_channels, out_channels):
342        super(FCFeedForward, self).__init__()
343        self.layer = nn.Sequential(
344            nn.Conv1d(in_channels, out_channels, 1),  # conv1d equals fc
345            nn.ReLU(),
346            nn.Dropout(),
347            nn.Conv1d(out_channels, out_channels, 1),
348        )

Initialize internal Module state, shared by both nn.Module and ScriptModule.

layer
def forward(self, x):
350    def forward(self, x):
351        """Forward pass."""
352        return self.layer(x)

Forward pass.

class AttModule(torch.nn.modules.module.Module):
355class AttModule(nn.Module):
356    def __init__(
357        self, dilation, in_channels, out_channels, r1, r2, att_type, stage, alpha
358    ):
359        super(AttModule, self).__init__()
360        self.feed_forward = ConvFeedForward(dilation, in_channels, out_channels)
361        self.instance_norm = nn.InstanceNorm1d(in_channels, track_running_stats=False)
362        self.att_layer = AttLayer(
363            in_channels,
364            in_channels,
365            out_channels,
366            r1,
367            r1,
368            r2,
369            dilation,
370            att_type=att_type,
371            stage=stage,
372        )  # dilation
373        self.conv_1x1 = nn.Conv1d(out_channels, out_channels, 1)
374        self.dropout = nn.Dropout()
375        self.alpha = alpha
376
377    def forward(self, x, f, mask):
378        """Forward pass."""
379        out = self.feed_forward(x)
380        out = self.alpha * self.att_layer(self.instance_norm(out), f, mask) + out
381        out = self.conv_1x1(out)
382        out = self.dropout(out)
383        return (x + out) * mask[:, 0:1, :]

Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::

import torch.nn as nn
import torch.nn.functional as F


class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will also have their parameters converted when you call to(), etc.

As per the example above, an __init__() call to the parent class must be made before assignment on the child.

:ivar training: Boolean represents whether this module is in training or evaluation mode. :vartype training: bool

AttModule(dilation, in_channels, out_channels, r1, r2, att_type, stage, alpha)
356    def __init__(
357        self, dilation, in_channels, out_channels, r1, r2, att_type, stage, alpha
358    ):
359        super(AttModule, self).__init__()
360        self.feed_forward = ConvFeedForward(dilation, in_channels, out_channels)
361        self.instance_norm = nn.InstanceNorm1d(in_channels, track_running_stats=False)
362        self.att_layer = AttLayer(
363            in_channels,
364            in_channels,
365            out_channels,
366            r1,
367            r1,
368            r2,
369            dilation,
370            att_type=att_type,
371            stage=stage,
372        )  # dilation
373        self.conv_1x1 = nn.Conv1d(out_channels, out_channels, 1)
374        self.dropout = nn.Dropout()
375        self.alpha = alpha

Initialize internal Module state, shared by both nn.Module and ScriptModule.

feed_forward
instance_norm
att_layer
conv_1x1
dropout
alpha
def forward(self, x, f, mask):
377    def forward(self, x, f, mask):
378        """Forward pass."""
379        out = self.feed_forward(x)
380        out = self.alpha * self.att_layer(self.instance_norm(out), f, mask) + out
381        out = self.conv_1x1(out)
382        out = self.dropout(out)
383        return (x + out) * mask[:, 0:1, :]

Forward pass.

class PositionalEncoding(torch.nn.modules.module.Module):
386class PositionalEncoding(nn.Module):
387    "Implement the PE function."
388
389    def __init__(self, d_model, max_len=10000):
390        super(PositionalEncoding, self).__init__()
391        # Compute the positional encodings once in log space.
392        pe = torch.zeros(max_len, d_model)
393        position = torch.arange(0, max_len).unsqueeze(1)
394        div_term = torch.exp(
395            torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model)
396        )
397        pe[:, 0::2] = torch.sin(position * div_term)
398        pe[:, 1::2] = torch.cos(position * div_term)
399        pe = pe.unsqueeze(0).permute(0, 2, 1)  # of shape (1, d_model, l)
400        self.pe = nn.Parameter(pe, requires_grad=True)
401
402    #         self.register_buffer('pe', pe)
403
404    def forward(self, x):
405        """Forward pass."""
406        return x + self.pe[:, :, 0 : x.shape[2]]

Implement the PE function.

PositionalEncoding(d_model, max_len=10000)
389    def __init__(self, d_model, max_len=10000):
390        super(PositionalEncoding, self).__init__()
391        # Compute the positional encodings once in log space.
392        pe = torch.zeros(max_len, d_model)
393        position = torch.arange(0, max_len).unsqueeze(1)
394        div_term = torch.exp(
395            torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model)
396        )
397        pe[:, 0::2] = torch.sin(position * div_term)
398        pe[:, 1::2] = torch.cos(position * div_term)
399        pe = pe.unsqueeze(0).permute(0, 2, 1)  # of shape (1, d_model, l)
400        self.pe = nn.Parameter(pe, requires_grad=True)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

pe
def forward(self, x):
404    def forward(self, x):
405        """Forward pass."""
406        return x + self.pe[:, :, 0 : x.shape[2]]

Forward pass.

class Encoder(torch.nn.modules.module.Module):
409class Encoder(nn.Module):
410    def __init__(
411        self,
412        num_layers,
413        r1,
414        r2,
415        num_f_maps,
416        input_dim,
417        channel_masking_rate,
418        att_type,
419        alpha,
420    ):
421        super(Encoder, self).__init__()
422        self.conv_1x1 = nn.Conv1d(input_dim, num_f_maps, 1)  # fc layer
423        self.layers = nn.ModuleList(
424            [
425                AttModule(
426                    2**i, num_f_maps, num_f_maps, r1, r2, att_type, "encoder", alpha
427                )
428                for i in range(num_layers)  # 2**i
429            ]
430        )
431
432        self.dropout = nn.Dropout2d(p=channel_masking_rate)
433        self.channel_masking_rate = channel_masking_rate
434
435    def forward(self, x, mask):
436        """Forward pass."""
437        if self.channel_masking_rate > 0:
438            x = x.unsqueeze(2)
439            x = self.dropout(x)
440            x = x.squeeze(2)
441
442        feature = self.conv_1x1(x)
443        for layer in self.layers:
444            feature = layer(feature, None, mask)
445
446        return feature

Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::

import torch.nn as nn
import torch.nn.functional as F


class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will also have their parameters converted when you call to(), etc.

As per the example above, an __init__() call to the parent class must be made before assignment on the child.

:ivar training: Boolean represents whether this module is in training or evaluation mode. :vartype training: bool

Encoder( num_layers, r1, r2, num_f_maps, input_dim, channel_masking_rate, att_type, alpha)
410    def __init__(
411        self,
412        num_layers,
413        r1,
414        r2,
415        num_f_maps,
416        input_dim,
417        channel_masking_rate,
418        att_type,
419        alpha,
420    ):
421        super(Encoder, self).__init__()
422        self.conv_1x1 = nn.Conv1d(input_dim, num_f_maps, 1)  # fc layer
423        self.layers = nn.ModuleList(
424            [
425                AttModule(
426                    2**i, num_f_maps, num_f_maps, r1, r2, att_type, "encoder", alpha
427                )
428                for i in range(num_layers)  # 2**i
429            ]
430        )
431
432        self.dropout = nn.Dropout2d(p=channel_masking_rate)
433        self.channel_masking_rate = channel_masking_rate

Initialize internal Module state, shared by both nn.Module and ScriptModule.

conv_1x1
layers
dropout
channel_masking_rate
def forward(self, x, mask):
435    def forward(self, x, mask):
436        """Forward pass."""
437        if self.channel_masking_rate > 0:
438            x = x.unsqueeze(2)
439            x = self.dropout(x)
440            x = x.squeeze(2)
441
442        feature = self.conv_1x1(x)
443        for layer in self.layers:
444            feature = layer(feature, None, mask)
445
446        return feature

Forward pass.

class Decoder(torch.nn.modules.module.Module):
449class Decoder(nn.Module):
450    def __init__(
451        self, num_layers, r1, r2, num_f_maps, input_dim, num_classes, att_type, alpha
452    ):
453        super(
454            Decoder, self
455        ).__init__()  # self.position_en = PositionalEncoding(d_model=num_f_maps)
456        self.conv_1x1 = nn.Conv1d(input_dim, num_f_maps, 1)
457        self.layers = nn.ModuleList(
458            [
459                AttModule(
460                    2**i, num_f_maps, num_f_maps, r1, r2, att_type, "decoder", alpha
461                )
462                for i in range(num_layers)  # 2 ** i
463            ]
464        )
465        self.conv_out = nn.Conv1d(num_f_maps, num_classes, 1)
466
467    def forward(self, x, fencoder, mask):
468        """Forward pass."""
469        feature = self.conv_1x1(x)
470        for layer in self.layers:
471            feature = layer(feature, fencoder, mask)
472
473        out = self.conv_out(feature) * mask[:, 0:1, :]
474
475        return out, feature

Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::

import torch.nn as nn
import torch.nn.functional as F


class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will also have their parameters converted when you call to(), etc.

As per the example above, an __init__() call to the parent class must be made before assignment on the child.

:ivar training: Boolean represents whether this module is in training or evaluation mode. :vartype training: bool

Decoder( num_layers, r1, r2, num_f_maps, input_dim, num_classes, att_type, alpha)
450    def __init__(
451        self, num_layers, r1, r2, num_f_maps, input_dim, num_classes, att_type, alpha
452    ):
453        super(
454            Decoder, self
455        ).__init__()  # self.position_en = PositionalEncoding(d_model=num_f_maps)
456        self.conv_1x1 = nn.Conv1d(input_dim, num_f_maps, 1)
457        self.layers = nn.ModuleList(
458            [
459                AttModule(
460                    2**i, num_f_maps, num_f_maps, r1, r2, att_type, "decoder", alpha
461                )
462                for i in range(num_layers)  # 2 ** i
463            ]
464        )
465        self.conv_out = nn.Conv1d(num_f_maps, num_classes, 1)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

conv_1x1
layers
conv_out
def forward(self, x, fencoder, mask):
467    def forward(self, x, fencoder, mask):
468        """Forward pass."""
469        feature = self.conv_1x1(x)
470        for layer in self.layers:
471            feature = layer(feature, fencoder, mask)
472
473        out = self.conv_out(feature) * mask[:, 0:1, :]
474
475        return out, feature

Forward pass.

class MyTransformer(torch.nn.modules.module.Module):
478class MyTransformer(nn.Module):
479    def __init__(
480        self,
481        num_layers,
482        r1,
483        r2,
484        num_f_maps,
485        input_dim,
486        channel_masking_rate,
487    ):
488        super(MyTransformer, self).__init__()
489        self.encoder = Encoder(
490            num_layers,
491            r1,
492            r2,
493            num_f_maps,
494            input_dim,
495            channel_masking_rate,
496            att_type="sliding_att",
497            alpha=1,
498        )
499
500    def forward(self, x):
501        """Forward pass."""
502        mask = (x.sum(1).unsqueeze(1) != 0).int()
503        feature = self.encoder(x, mask)
504        feature = feature * mask
505
506        return feature

A wrapper around the ASFormer encoder that serves as the dlc2action feature extractor; it derives a padding mask from all-zero input columns and returns masked encoder features.

MyTransformer(num_layers, r1, r2, num_f_maps, input_dim, channel_masking_rate)
479    def __init__(
480        self,
481        num_layers,
482        r1,
483        r2,
484        num_f_maps,
485        input_dim,
486        channel_masking_rate,
487    ):
488        super(MyTransformer, self).__init__()
489        self.encoder = Encoder(
490            num_layers,
491            r1,
492            r2,
493            num_f_maps,
494            input_dim,
495            channel_masking_rate,
496            att_type="sliding_att",
497            alpha=1,
498        )

Initialize the encoder wrapper.

encoder
def forward(self, x):
500    def forward(self, x):
501        """Forward pass."""
502        mask = (x.sum(1).unsqueeze(1) != 0).int()
503        feature = self.encoder(x, mask)
504        feature = feature * mask
505
506        return feature

Forward pass.
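
A minimal usage sketch (all sizes are illustrative): frames whose input columns are entirely zero are treated as padding and zeroed in the output.

import torch
from dlc2action.model.asformer import MyTransformer, device

encoder = MyTransformer(
    num_layers=4,              # number of attention blocks in the encoder
    r1=2, r2=2,                # channel reduction ratios inside the attention layers
    num_f_maps=64,             # feature dimension of the encoder output
    input_dim=32,              # number of input channels
    channel_masking_rate=0.3,  # probability of masking input channels during training
).to(device)

x = torch.randn(8, 32, 128, device=device)  # (batch, input_dim, num_frames)
features = encoder(x)
print(features.shape)                       # torch.Size([8, 64, 128])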

class Predictor(torch.nn.modules.module.Module):
509class Predictor(nn.Module):
510    def __init__(
511        self,
512        num_layers,
513        r1,
514        r2,
515        num_f_maps,
516        num_classes,
517        num_decoders,
518    ):
519        super(Predictor, self).__init__()
520        self.decoders = nn.ModuleList(
521            [
522                copy.deepcopy(
523                    Decoder(
524                        num_layers,
525                        r1,
526                        r2,
527                        num_f_maps,
528                        num_classes,
529                        num_classes,
530                        att_type="sliding_att",
531                        alpha=exponential_descrease(s),
532                    )
533                )
534                for s in range(num_decoders)
535            ]
536        )  # num_decoders
537
538        self.conv_out = nn.Conv1d(num_f_maps, num_classes, 1)
539
540    def forward(self, x):
541        """Forward pass."""
542        mask = (x.sum(1).unsqueeze(1) != 0).int()
543        out = self.conv_out(x) * mask[:, 0:1, :]
544        outputs = out.unsqueeze(0)
545
546        for decoder in self.decoders:
547            out, x = decoder(
548                F.softmax(out, dim=1) * mask[:, 0:1, :], x * mask[:, 0:1, :], mask
549            )
550            outputs = torch.cat((outputs, out.unsqueeze(0)), dim=0)
551        return outputs

The ASFormer prediction head: an initial 1x1 classification layer followed by a stack of decoders that iteratively refine the frame-wise predictions.

Predictor(num_layers, r1, r2, num_f_maps, num_classes, num_decoders)
510    def __init__(
511        self,
512        num_layers,
513        r1,
514        r2,
515        num_f_maps,
516        num_classes,
517        num_decoders,
518    ):
519        super(Predictor, self).__init__()
520        self.decoders = nn.ModuleList(
521            [
522                copy.deepcopy(
523                    Decoder(
524                        num_layers,
525                        r1,
526                        r2,
527                        num_f_maps,
528                        num_classes,
529                        num_classes,
530                        att_type="sliding_att",
531                        alpha=exponential_descrease(s),
532                    )
533                )
534                for s in range(num_decoders)
535            ]
536        )  # num_decoders
537
538        self.conv_out = nn.Conv1d(num_f_maps, num_classes, 1)

Initialize the predictor.

decoders
conv_out
def forward(self, x):
540    def forward(self, x):
541        """Forward pass."""
542        mask = (x.sum(1).unsqueeze(1) != 0).int()
543        out = self.conv_out(x) * mask[:, 0:1, :]
544        outputs = out.unsqueeze(0)
545
546        for decoder in self.decoders:
547            out, x = decoder(
548                F.softmax(out, dim=1) * mask[:, 0:1, :], x * mask[:, 0:1, :], mask
549            )
550            outputs = torch.cat((outputs, out.unsqueeze(0)), dim=0)
551        return outputs

Forward pass.
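
Continuing the MyTransformer sketch above (all sizes remain hypothetical), the Predictor returns a stack of predictions: the initial 1x1-convolution output plus one refined output per decoder, concatenated along a new leading dimension.

import torch
from dlc2action.model.asformer import MyTransformer, Predictor, device

encoder = MyTransformer(
    num_layers=4, r1=2, r2=2, num_f_maps=64,
    input_dim=32, channel_masking_rate=0.3,
).to(device)
predictor = Predictor(
    num_layers=4, r1=2, r2=2, num_f_maps=64,
    num_classes=10, num_decoders=3,
).to(device)

x = torch.randn(8, 32, 128, device=device)  # (batch, input_dim, num_frames)
features = encoder(x)                       # (8, 64, 128)
outputs = predictor(features)               # (num_decoders + 1, batch, num_classes, num_frames)
print(outputs.shape)                        # torch.Size([4, 8, 10, 128])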

class ASFormer(dlc2action.model.base_model.Model):
554class ASFormer(Model):
555    """
556    An implementation of ASFormer
557    """
558
559    def __init__(
560        self,
561        num_decoders: int,
562        num_layers: int,
563        r1: float,
564        r2: float,
565        num_f_maps: int,
566        input_dim: dict,
567        num_classes: int,
568        channel_masking_rate: float,
569        state_dict_path: str = None,
570        ssl_constructors: List = None,
571        ssl_types: List = None,
572        ssl_modules: List = None,
573    ):
574        input_dim = sum([x[0] for x in input_dim.values()])
575        self.num_f_maps = int(float(num_f_maps))
576        self.params = {
577            "num_layers": int(float(num_layers)),
578            "r1": float(r1),
579            "r2": float(r2),
580            "num_f_maps": self.num_f_maps,
581            "input_dim": int(float(input_dim)),
582            "channel_masking_rate": float(channel_masking_rate),
583        }
584        self.params_predictor = {
585            "num_layers": int(float(num_layers)),
586            "r1": r1,
587            "r2": r2,
588            "num_f_maps": self.num_f_maps,
589            "num_classes": int(float(num_classes)),
590            "num_decoders": int(float(num_decoders)),
591        }
592        super().__init__(ssl_constructors, ssl_modules, ssl_types, state_dict_path)
593
594    def _feature_extractor(self) -> Union[torch.nn.Module, List]:
595        return MyTransformer(**self.params)
596
597    def _predictor(self) -> torch.nn.Module:
598        return Predictor(**self.params_predictor)
599
600    def features_shape(self) -> torch.Size:
601        return torch.Size([self.num_f_maps])

An implementation of ASFormer

ASFormer(num_decoders: int, num_layers: int, r1: float, r2: float, num_f_maps: int, input_dim: dict, num_classes: int, channel_masking_rate: float, state_dict_path: str = None, ssl_constructors: List = None, ssl_types: List = None, ssl_modules: List = None)
559    def __init__(
560        self,
561        num_decoders: int,
562        num_layers: int,
563        r1: float,
564        r2: float,
565        num_f_maps: int,
566        input_dim: dict,
567        num_classes: int,
568        channel_masking_rate: float,
569        state_dict_path: str = None,
570        ssl_constructors: List = None,
571        ssl_types: List = None,
572        ssl_modules: List = None,
573    ):
574        input_dim = sum([x[0] for x in input_dim.values()])
575        self.num_f_maps = int(float(num_f_maps))
576        self.params = {
577            "num_layers": int(float(num_layers)),
578            "r1": float(r1),
579            "r2": float(r2),
580            "num_f_maps": self.num_f_maps,
581            "input_dim": int(float(input_dim)),
582            "channel_masking_rate": float(channel_masking_rate),
583        }
584        self.params_predictor = {
585            "num_layers": int(float(num_layers)),
586            "r1": r1,
587            "r2": r2,
588            "num_f_maps": self.num_f_maps,
589            "num_classes": int(float(num_classes)),
590            "num_decoders": int(float(num_decoders)),
591        }
592        super().__init__(ssl_constructors, ssl_modules, ssl_types, state_dict_path)

Initialize the model.

Parameters

ssl_constructors : list, optional
    a list of SSL constructors that build the necessary SSL modules
ssl_modules : list, optional
    a list of torch.nn.Module instances that will serve as SSL modules
ssl_types : list, optional
    a list of string SSL types
state_dict_path : str, optional
    path to the model state dictionary to load
strict : bool, default False
    when True, the state dictionary will only be loaded if the current and the loaded
    architectures are the same; otherwise missing or extra keys, as well as shape
    inconsistencies, are ignored
prompt_function : callable, optional
    a function that takes a list of strings and returns a string prompt

num_f_maps
params
params_predictor
def features_shape(self) -> torch.Size:
600    def features_shape(self) -> torch.Size:
601        return torch.Size([self.num_f_maps])

Get the shape of feature extractor output.

Returns

feature_shape : torch.Size
    shape of feature extractor output
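
As a closing sketch, the snippet below builds the full model; the feature keys, shapes and hyperparameter values are hypothetical. Note that input_dim is a dictionary mapping feature names to their shapes, and only the first element of each shape (the channel count) is summed to obtain the total number of input channels.

from dlc2action.model.asformer import ASFormer

model = ASFormer(
    num_decoders=3,
    num_layers=4,
    r1=2,
    r2=2,
    num_f_maps=64,
    input_dim={"pose": (26, 128), "loaded": (6, 128)},  # 26 + 6 = 32 input channels
    num_classes=10,
    channel_masking_rate=0.3,
)
print(model.features_shape())  # torch.Size([64])

The forward pass itself is provided by the dlc2action.model.base_model.Model base class, which wires together the MyTransformer returned by _feature_extractor() and the Predictor returned by _predictor().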