dlc2action.model.asformer

ASFormer

Adapted from https://github.com/ChinaYi/ASFormer

  1#
  2# Copyright 2020-2022 by A. Mathis Group and contributors. All rights reserved.
  3#
  4# This project and all its files are licensed under GNU AGPLv3 or later version. A copy is included in dlc2action/LICENSE.AGPL.
  5#
  6#
  7# Adapted from ASFormer by ChinaYi
  8# Adapted from https://github.com/ChinaYi/ASFormer
  9# Licensed under MIT License
 10#
 11""" ASFormer
 12
 13Adapted from https://github.com/ChinaYi/ASFormer
 14"""
 15
 16import torch
 17import torch.nn as nn
 18import torch.nn.functional as F
 19from torch import optim
 20from dlc2action.model.base_model import Model
 21from typing import Union, List
 22
 23import copy
 24import numpy as np
 25import math
 26
 27device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 28
 29
 30def _exponential_descrease(idx_decoder, p=3):
 31    return math.exp(-p * idx_decoder)
 32
 33
 34class _AttentionHelper(nn.Module):
 35    def __init__(self):
 36        super(_AttentionHelper, self).__init__()
 37        self.softmax = nn.Softmax(dim=-1)
 38
 39    def scalar_dot_att(self, proj_query, proj_key, proj_val, padding_mask):
 40        """
 41        scalar dot attention.
 42        :param proj_query: shape of (B, C, L) => (Batch_Size, Feature_Dimension, Length)
 43        :param proj_key: shape of (B, C, L)
 44        :param proj_val: shape of (B, C, L)
 45        :param padding_mask: shape of (B, C, L)
 46        :return: attention value of shape (B, C, L)
 47        """
 48        m, c1, l1 = proj_query.shape
 49        m, c2, l2 = proj_key.shape
 50
 51        assert c1 == c2
 52
 53        energy = torch.bmm(
 54            proj_query.permute(0, 2, 1), proj_key
 55        )  # out of shape (B, L1, L2)
 56        attention = energy / np.sqrt(c1)
 57        attention = attention + torch.log(
 58            padding_mask + 1e-6
 59        )  # mask the zero paddings. log(1e-6) for zero paddings
 60        attention = self.softmax(attention)
 61        attention = attention * padding_mask
 62        attention = attention.permute(0, 2, 1)
 63        out = torch.bmm(proj_val, attention)
 64        return out, attention
 65
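# Illustration (not from the ASFormer source): scalar_dot_att computes
# softmax(Q^T K / sqrt(C)) with padded frames masked out and returns the
# attended values in the original (B, C, L) layout, e.g.
#   helper = _AttentionHelper()
#   q = k = v = torch.randn(2, 32, 100)       # (B, C, L)
#   padding_mask = torch.ones(2, 1, 100)      # 1 = valid frame, 0 = padding
#   out, att = helper.scalar_dot_att(q, k, v, padding_mask)
#   # out: (2, 32, 100), att: (2, 100, 100)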
 66
 67class _AttLayer(nn.Module):
 68    def __init__(self, q_dim, k_dim, v_dim, r1, r2, r3, bl, stage, att_type):  # r1 = r2
 69        super(_AttLayer, self).__init__()
 70
 71        self.query_conv = nn.Conv1d(
 72            in_channels=q_dim, out_channels=q_dim // r1, kernel_size=1
 73        )
 74        self.key_conv = nn.Conv1d(
 75            in_channels=k_dim, out_channels=k_dim // r2, kernel_size=1
 76        )
 77        self.value_conv = nn.Conv1d(
 78            in_channels=v_dim, out_channels=v_dim // r3, kernel_size=1
 79        )
 80
 81        self.conv_out = nn.Conv1d(
 82            in_channels=v_dim // r3, out_channels=v_dim, kernel_size=1
 83        )
 84
 85        self.bl = bl
 86        self.stage = stage
 87        self.att_type = att_type
 88        assert self.att_type in ["normal_att", "block_att", "sliding_att"]
 89        assert self.stage in ["encoder", "decoder"]
 90
 91        self.att_helper = _AttentionHelper()
 92        self.window_mask = self.construct_window_mask()
 93
 94    def construct_window_mask(self):
 95        """
 96        construct window mask of shape (1, l, l + l//2 + l//2), used for sliding window self attention
 97        """
 98        window_mask = torch.zeros((1, self.bl, self.bl + 2 * (self.bl // 2)))
 99        for i in range(self.bl):
 100            window_mask[:, i, i : i + self.bl] = 1
101        return window_mask.to(device)
102
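    # Illustration (not from the ASFormer source): with self.bl = 3 the window
    # mask is banded, so query position i in a block may only attend to padded
    # key positions i .. i + bl - 1, a window of length bl centered on the query:
    #   tensor([[[1., 1., 1., 0., 0.],
    #            [0., 1., 1., 1., 0.],
    #            [0., 0., 1., 1., 1.]]])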
103    def forward(self, x1, x2, mask):
104        # x1 from the encoder
105        # x2 from the decoder
106
107        query = self.query_conv(x1)
108        key = self.key_conv(x1)
109
110        if self.stage == "decoder":
111            assert x2 is not None
112            value = self.value_conv(x2)
113        else:
114            value = self.value_conv(x1)
115
116        if self.att_type == "normal_att":
117            return self._normal_self_att(query, key, value, mask)
118        elif self.att_type == "block_att":
119            return self._block_wise_self_att(query, key, value, mask)
120        elif self.att_type == "sliding_att":
121            return self._sliding_window_self_att(query, key, value, mask)
122
123    def _normal_self_att(self, q, k, v, mask):
124        m_batchsize, c1, L = q.size()
125        _, c2, L = k.size()
126        _, c3, L = v.size()
127        padding_mask = torch.ones((m_batchsize, 1, L)).to(device) * mask[:, 0:1, :]
128        output, attentions = self.att_helper.scalar_dot_att(q, k, v, padding_mask)
129        output = self.conv_out(F.relu(output))
130        output = output[:, :, 0:L]
131        return output * mask[:, 0:1, :]
132
133    def _block_wise_self_att(self, q, k, v, mask):
134        m_batchsize, c1, L = q.size()
135        _, c2, L = k.size()
136        _, c3, L = v.size()
137
138        nb = L // self.bl
139        if L % self.bl != 0:
140            q = torch.cat(
141                [q, torch.zeros((m_batchsize, c1, self.bl - L % self.bl)).to(device)],
142                dim=-1,
143            )
144            k = torch.cat(
145                [k, torch.zeros((m_batchsize, c2, self.bl - L % self.bl)).to(device)],
146                dim=-1,
147            )
148            v = torch.cat(
149                [v, torch.zeros((m_batchsize, c3, self.bl - L % self.bl)).to(device)],
150                dim=-1,
151            )
152            nb += 1
153
154        padding_mask = torch.cat(
155            [
156                torch.ones((m_batchsize, 1, L)).to(device) * mask[:, 0:1, :],
157                torch.zeros((m_batchsize, 1, self.bl * nb - L)).to(device),
158            ],
159            dim=-1,
160        )
161
162        q = (
163            q.reshape(m_batchsize, c1, nb, self.bl)
164            .permute(0, 2, 1, 3)
165            .reshape(m_batchsize * nb, c1, self.bl)
166        )
167        padding_mask = (
168            padding_mask.reshape(m_batchsize, 1, nb, self.bl)
169            .permute(0, 2, 1, 3)
170            .reshape(m_batchsize * nb, 1, self.bl)
171        )
172        k = (
173            k.reshape(m_batchsize, c2, nb, self.bl)
174            .permute(0, 2, 1, 3)
175            .reshape(m_batchsize * nb, c2, self.bl)
176        )
177        v = (
178            v.reshape(m_batchsize, c3, nb, self.bl)
179            .permute(0, 2, 1, 3)
180            .reshape(m_batchsize * nb, c3, self.bl)
181        )
182
183        output, attentions = self.att_helper.scalar_dot_att(q, k, v, padding_mask)
184        output = self.conv_out(F.relu(output))
185
186        output = (
187            output.reshape(m_batchsize, nb, c3, self.bl)
188            .permute(0, 2, 1, 3)
189            .reshape(m_batchsize, c3, nb * self.bl)
190        )
191        output = output[:, :, 0:L]
192        return output * mask[:, 0:1, :]
193
194    def _sliding_window_self_att(self, q, k, v, mask):
195        m_batchsize, c1, L = q.size()
196        _, c2, _ = k.size()
197        _, c3, _ = v.size()
198
199        # padding zeros for the last segment
200        nb = L // self.bl
201        if L % self.bl != 0:
202            q = torch.cat(
203                [q, torch.zeros((m_batchsize, c1, self.bl - L % self.bl)).to(device)],
204                dim=-1,
205            )
206            k = torch.cat(
207                [k, torch.zeros((m_batchsize, c2, self.bl - L % self.bl)).to(device)],
208                dim=-1,
209            )
210            v = torch.cat(
211                [v, torch.zeros((m_batchsize, c3, self.bl - L % self.bl)).to(device)],
212                dim=-1,
213            )
214            nb += 1
215        padding_mask = torch.cat(
216            [
217                torch.ones((m_batchsize, 1, L)).to(device) * mask[:, 0:1, :],
218                torch.zeros((m_batchsize, 1, self.bl * nb - L)).to(device),
219            ],
220            dim=-1,
221        )
222
223        # sliding window approach, by splitting query_proj and key_proj into shape (c1, l) x (c1, 2l)
224        # sliding window for query_proj: reshape
225        q = (
226            q.reshape(m_batchsize, c1, nb, self.bl)
227            .permute(0, 2, 1, 3)
228            .reshape(m_batchsize * nb, c1, self.bl)
229        )
230
231        # sliding window approach for key_proj
232        # 1. add paddings at the start and end
233        k = torch.cat(
234            [
235                torch.zeros(m_batchsize, c2, self.bl // 2).to(device),
236                k,
237                torch.zeros(m_batchsize, c2, self.bl // 2).to(device),
238            ],
239            dim=-1,
240        )
241        v = torch.cat(
242            [
243                torch.zeros(m_batchsize, c3, self.bl // 2).to(device),
244                v,
245                torch.zeros(m_batchsize, c3, self.bl // 2).to(device),
246            ],
247            dim=-1,
248        )
249        padding_mask = torch.cat(
250            [
251                torch.zeros(m_batchsize, 1, self.bl // 2).to(device),
252                padding_mask,
253                torch.zeros(m_batchsize, 1, self.bl // 2).to(device),
254            ],
255            dim=-1,
256        )
257
258        # 2. reshape key_proj of shape (m_batchsize*nb, c1, 2*self.bl)
259        k = torch.cat(
260            [
261                k[:, :, i * self.bl : (i + 1) * self.bl + (self.bl // 2) * 2]
262                for i in range(nb)
263            ],
264            dim=0,
265        )  # special case when self.bl = 1
266        v = torch.cat(
267            [
268                v[:, :, i * self.bl : (i + 1) * self.bl + (self.bl // 2) * 2]
269                for i in range(nb)
270            ],
271            dim=0,
272        )
273        # 3. construct window mask of shape (1, l, 2l), and use it to generate final mask
274        padding_mask = torch.cat(
275            [
276                padding_mask[:, :, i * self.bl : (i + 1) * self.bl + (self.bl // 2) * 2]
277                for i in range(nb)
278            ],
279            dim=0,
280        )  # of shape (m*nb, 1, 2l)
281        final_mask = self.window_mask.repeat(m_batchsize * nb, 1, 1) * padding_mask
282
283        output, attention = self.att_helper.scalar_dot_att(q, k, v, final_mask)
284        output = self.conv_out(F.relu(output))
285
286        output = (
287            output.reshape(m_batchsize, nb, -1, self.bl)
288            .permute(0, 2, 1, 3)
289            .reshape(m_batchsize, -1, nb * self.bl)
290        )
291        output = output[:, :, 0:L]
292        return output * mask[:, 0:1, :]
293
294
295class _MultiHeadAttLayer(nn.Module):
296    def __init__(self, q_dim, k_dim, v_dim, r1, r2, r3, bl, stage, att_type, num_head):
297        super(_MultiHeadAttLayer, self).__init__()
298        #         assert v_dim % num_head == 0
299        self.conv_out = nn.Conv1d(v_dim * num_head, v_dim, 1)
300        self.layers = nn.ModuleList(
301            [
302                copy.deepcopy(
303                    _AttLayer(q_dim, k_dim, v_dim, r1, r2, r3, bl, stage, att_type)
304                )
305                for i in range(num_head)
306            ]
307        )
308        self.dropout = nn.Dropout(p=0.5)
309
310    def forward(self, x1, x2, mask):
311        out = torch.cat([layer(x1, x2, mask) for layer in self.layers], dim=1)
312        out = self.conv_out(self.dropout(out))
313        return out
314
315
316class _ConvFeedForward(nn.Module):
317    def __init__(self, dilation, in_channels, out_channels):
318        super(_ConvFeedForward, self).__init__()
319        self.layer = nn.Sequential(
320            nn.Conv1d(
321                in_channels, out_channels, 3, padding=dilation, dilation=dilation
322            ),
323            nn.ReLU(),
324        )
325
326    def forward(self, x):
327        return self.layer(x)
328
329
330class _FCFeedForward(nn.Module):
331    def __init__(self, in_channels, out_channels):
332        super(_FCFeedForward, self).__init__()
333        self.layer = nn.Sequential(
334            nn.Conv1d(in_channels, out_channels, 1),  # conv1d equals fc
335            nn.ReLU(),
336            nn.Dropout(),
337            nn.Conv1d(out_channels, out_channels, 1),
338        )
339
340    def forward(self, x):
341        return self.layer(x)
342
343
344class _AttModule(nn.Module):
345    def __init__(
346        self, dilation, in_channels, out_channels, r1, r2, att_type, stage, alpha
347    ):
348        super(_AttModule, self).__init__()
349        self.feed_forward = _ConvFeedForward(dilation, in_channels, out_channels)
350        self.instance_norm = nn.InstanceNorm1d(in_channels, track_running_stats=False)
351        self.att_layer = _AttLayer(
352            in_channels,
353            in_channels,
354            out_channels,
355            r1,
356            r1,
357            r2,
358            dilation,
359            att_type=att_type,
360            stage=stage,
361        )  # dilation
362        self.conv_1x1 = nn.Conv1d(out_channels, out_channels, 1)
363        self.dropout = nn.Dropout()
364        self.alpha = alpha
365
366    def forward(self, x, f, mask):
367        out = self.feed_forward(x)
368        out = self.alpha * self.att_layer(self.instance_norm(out), f, mask) + out
369        out = self.conv_1x1(out)
370        out = self.dropout(out)
371        return (x + out) * mask[:, 0:1, :]
372
373
374class _PositionalEncoding(nn.Module):
375    "Implement the PE function."
376
377    def __init__(self, d_model, max_len=10000):
378        super(_PositionalEncoding, self).__init__()
379        # Compute the positional encodings once in log space.
380        pe = torch.zeros(max_len, d_model)
381        position = torch.arange(0, max_len).unsqueeze(1)
382        div_term = torch.exp(
383            torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model)
384        )
385        pe[:, 0::2] = torch.sin(position * div_term)
386        pe[:, 1::2] = torch.cos(position * div_term)
387        pe = pe.unsqueeze(0).permute(0, 2, 1)  # of shape (1, d_model, l)
388        self.pe = nn.Parameter(pe, requires_grad=True)
389
390    #         self.register_buffer('pe', pe)
391
392    def forward(self, x):
393        return x + self.pe[:, :, 0 : x.shape[2]]
394
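    # For reference (not from the ASFormer source): the table built above is the
    # standard sinusoidal encoding, stored as a learnable (1, d_model, max_len)
    # parameter,
    #   PE[pos, 2i]     = sin(pos / 10000^(2i / d_model))
    #   PE[pos, 2i + 1] = cos(pos / 10000^(2i / d_model))
    # Note that the modules below do not apply it; only a commented-out
    # reference remains in _Decoder.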
395
396class _Encoder(nn.Module):
397    def __init__(
398        self,
399        num_layers,
400        r1,
401        r2,
402        num_f_maps,
403        input_dim,
404        channel_masking_rate,
405        att_type,
406        alpha,
407    ):
408        super(_Encoder, self).__init__()
409        self.conv_1x1 = nn.Conv1d(input_dim, num_f_maps, 1)  # fc layer
410        self.layers = nn.ModuleList(
411            [
412                _AttModule(
413                    2**i, num_f_maps, num_f_maps, r1, r2, att_type, "encoder", alpha
414                )
415                for i in range(num_layers)  # 2**i
416            ]
417        )
418
419        self.dropout = nn.Dropout2d(p=channel_masking_rate)
420        self.channel_masking_rate = channel_masking_rate
421
422    def forward(self, x, mask):
423        """
424        :param x: (N, C, L)
425        :param mask:
426        :return:
427        """
428
429        if self.channel_masking_rate > 0:
430            x = x.unsqueeze(2)
431            x = self.dropout(x)
432            x = x.squeeze(2)
433
434        feature = self.conv_1x1(x)
435        for layer in self.layers:
436            feature = layer(feature, None, mask)
437
438        return feature
439
440
441class _Decoder(nn.Module):
442    def __init__(
443        self, num_layers, r1, r2, num_f_maps, input_dim, num_classes, att_type, alpha
444    ):
445        super(
446            _Decoder, self
447        ).__init__()  # self.position_en = PositionalEncoding(d_model=num_f_maps)
448        self.conv_1x1 = nn.Conv1d(input_dim, num_f_maps, 1)
449        self.layers = nn.ModuleList(
450            [
451                _AttModule(
452                    2**i, num_f_maps, num_f_maps, r1, r2, att_type, "decoder", alpha
453                )
454                for i in range(num_layers)  # 2 ** i
455            ]
456        )
457        self.conv_out = nn.Conv1d(num_f_maps, num_classes, 1)
458
459    def forward(self, x, fencoder, mask):
460        feature = self.conv_1x1(x)
461        for layer in self.layers:
462            feature = layer(feature, fencoder, mask)
463
464        out = self.conv_out(feature) * mask[:, 0:1, :]
465
466        return out, feature
467
468
469class _MyTransformer(nn.Module):
470    def __init__(
471        self,
472        num_layers,
473        r1,
474        r2,
475        num_f_maps,
476        input_dim,
477        channel_masking_rate,
478    ):
479        super(_MyTransformer, self).__init__()
480        self.encoder = _Encoder(
481            num_layers,
482            r1,
483            r2,
484            num_f_maps,
485            input_dim,
486            channel_masking_rate,
487            att_type="sliding_att",
488            alpha=1,
489        )
490
491    def forward(self, x):
492        mask = (x.sum(1).unsqueeze(1) != 0).int()
493        feature = self.encoder(x, mask)
494        feature = feature * mask
495
496        return feature
497
498
499class _Predictor(nn.Module):
500    def __init__(
501        self,
502        num_layers,
503        r1,
504        r2,
505        num_f_maps,
506        num_classes,
507        num_decoders,
508    ):
509        super(_Predictor, self).__init__()
510        self.decoders = nn.ModuleList(
511            [
512                copy.deepcopy(
513                    _Decoder(
514                        num_layers,
515                        r1,
516                        r2,
517                        num_f_maps,
518                        num_classes,
519                        num_classes,
520                        att_type="sliding_att",
521                        alpha=_exponential_descrease(s),
522                    )
523                )
524                for s in range(num_decoders)
525            ]
526        )  # num_decoders
527
528        self.conv_out = nn.Conv1d(num_f_maps, num_classes, 1)
529
530    def forward(self, x):
531        mask = (x.sum(1).unsqueeze(1) != 0).int()
532        out = self.conv_out(x) * mask[:, 0:1, :]
533        outputs = out.unsqueeze(0)
534
535        for decoder in self.decoders:
536            out, x = decoder(
537                F.softmax(out, dim=1) * mask[:, 0:1, :], x * mask[:, 0:1, :], mask
538            )
539            outputs = torch.cat((outputs, out.unsqueeze(0)), dim=0)
540        return outputs
541
542
543class ASFormer(Model):
544    """
545    An implementation of ASFormer
546    """
547
548    def __init__(
549        self,
550        num_decoders,
551        num_layers,
552        r1,
553        r2,
554        num_f_maps,
555        input_dim,
556        num_classes,
557        channel_masking_rate,
558        state_dict_path=None,
559        ssl_constructors=None,
560        ssl_types=None,
561        ssl_modules=None,
562    ):
563        input_dim = sum([x[0] for x in input_dim.values()])
564        self.num_f_maps = num_f_maps
565        self.params = {
566            "num_layers": int(num_layers),
567            "r1": r1,
568            "r2": r2,
569            "num_f_maps": int(num_f_maps),
570            "input_dim": int(input_dim),
571            "channel_masking_rate": channel_masking_rate,
572        }
573        self.params_predictor = {
574            "num_layers": int(num_layers),
575            "r1": r1,
576            "r2": r2,
577            "num_f_maps": int(num_f_maps),
578            "num_classes": int(num_classes),
579            "num_decoders": int(num_decoders),
580        }
581        super().__init__(ssl_constructors, ssl_modules, ssl_types, state_dict_path)
582
583    def _feature_extractor(self) -> Union[torch.nn.Module, List]:
584        return _MyTransformer(**self.params)
585
586    def _predictor(self) -> torch.nn.Module:
587        return _Predictor(**self.params_predictor)
588
589    def features_shape(self) -> torch.Size:
590        return torch.Size([self.num_f_maps])
class ASFormer(dlc2action.model.base_model.Model):

An implementation of ASFormer

ASFormer( num_decoders, num_layers, r1, r2, num_f_maps, input_dim, num_classes, channel_masking_rate, state_dict_path=None, ssl_constructors=None, ssl_types=None, ssl_modules=None)

Parameters

ssl_constructors : list, optional
    a list of SSL constructors that build the necessary SSL modules
ssl_modules : list, optional
    a list of torch.nn.Module instances that will serve as SSL modules
ssl_types : list, optional
    a list of string SSL types
state_dict_path : str, optional
    path to the model state dictionary to load
strict : bool, default False
    when True, the state dictionary will only be loaded if the current and the loaded architectures are the same; otherwise missing or extra keys, as well as shape inconsistencies, are ignored
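
A minimal construction sketch (the hyperparameter values and the data type names in input_dim are assumptions used for illustration; in practice they come from the dlc2action configuration):

    # input_dim maps data type names to feature shapes; only the first element
    # of each shape (the number of channels) is used.
    input_dim = {"coords": (28,), "speeds": (6,)}
    model = ASFormer(
        num_decoders=3,
        num_layers=9,
        r1=2,
        r2=2,
        num_f_maps=64,
        input_dim=input_dim,
        num_classes=5,
        channel_masking_rate=0.3,
    )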

def features_shape(self) -> torch.Size:

Get the shape of the feature extractor output.

Returns

feature_shape : torch.Size
    shape of the feature extractor output
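
Continuing the construction sketch above, the reported shape depends only on num_f_maps:

    model.features_shape()  # torch.Size([64]) when num_f_maps=64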

Inherited Members
dlc2action.model.base_model.Model
process_labels
freeze_feature_extractor
unfreeze_feature_extractor
load_state_dict
ssl_off
ssl_on
main_task_on
main_task_off
set_ssl
extract_features
forward