dlc2action.model.motionbert_modules

#
# Copyright 2020-present by A. Mathis Group and contributors. All rights reserved.
#
# This project and all its files are licensed under GNU AGPLv3 or later version. A copy is included in dlc2action/LICENSE.AGPL.
#
# Incorporates code adapted from MotionBERT by Walter0807
# Original work Copyright (c) 2023 Walter0807
# Source: https://github.com/Walter0807/MotionBERT
# Originally licensed under Apache License Version 2.0, 2023
# Combined work licensed under GNU AGPLv3
#
import torch
import torch.nn as nn
import math
import warnings
import numpy as np
from collections import OrderedDict


def drop_path(x, drop_prob: float = 0., training: bool = False):
    """Drop paths (Stochastic Depth) per sample (when applied in the main path of residual blocks).

    This is the same as the DropConnect implementation created for EfficientNet-style networks;
    however, the original name is misleading, as 'DropConnect' is a different form of dropout
    described in a separate paper. See the discussion at
    https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956. The layer and argument
    are named 'drop path' here rather than mixing 'DropConnect' as a layer name with
    'survival rate' as the argument.
    """
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # works with tensors of any dimensionality, not just 2D ConvNets
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # binarize
    output = x.div(keep_prob) * random_tensor
    return output


class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in the main path of residual blocks)."""

    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)

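# Behavior summary: with an illustrative drop_prob of 0.1 and training=True, drop_path
# keeps each sample in the batch with probability 0.9, zeroes it otherwise, and rescales
# the survivors by 1 / 0.9 so the expected value of the output matches the input. With
# drop_prob=0 or training=False the input is returned unchanged. DropPath wraps this
# function as a module so it can be dropped into a residual branch.
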
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    # Cut & paste from PyTorch official master until it's in a few official releases - RW
    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1. + math.erf(x / math.sqrt(2.))) / 2.

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
                      "The distribution of values may be incorrect.",
                      stacklevel=2)

    with torch.no_grad():
        # Values are generated by using a truncated uniform distribution and
        # then using the inverse CDF for the normal distribution.
        # Get upper and lower cdf values
        l = norm_cdf((a - mean) / std)
        u = norm_cdf((b - mean) / std)

        # Uniformly fill tensor with values from [l, u], then translate to
        # [2l-1, 2u-1].
        tensor.uniform_(2 * l - 1, 2 * u - 1)

        # Use inverse cdf transform for normal distribution to get truncated
        # standard normal
        tensor.erfinv_()

        # Transform to proper mean, std
        tensor.mul_(std * math.sqrt(2.))
        tensor.add_(mean)

        # Clamp to ensure it's in the proper range
        tensor.clamp_(min=a, max=b)
        return tensor


def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.) -> torch.Tensor:
    r"""Fills the input Tensor with values drawn from a truncated
    normal distribution. The values are effectively drawn from the
    normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
    with values outside :math:`[a, b]` redrawn until they are within
    the bounds. The method used for generating the random values works
    best when :math:`a \leq \text{mean} \leq b`.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution
        a: the minimum cutoff value
        b: the maximum cutoff value

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.trunc_normal_(w)
    """
    return _no_grad_trunc_normal_(tensor, mean, std, a, b)


class MLP(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., st_mode='vanilla'):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
        self.scale = qk_scale or head_dim ** -0.5

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.mode = st_mode
        if self.mode == 'parallel':
            self.ts_attn = nn.Linear(dim*2, dim*2)
            self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        else:
            self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj_drop = nn.Dropout(proj_drop)

        self.attn_count_s = None
        self.attn_count_t = None

    def forward(self, x, seqlen=1):
        B, N, C = x.shape

        if self.mode == 'series':
            qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
            q, k, v = qkv[0], qkv[1], qkv[2]   # make torchscript happy (cannot use tensor as tuple)
            x = self.forward_spatial(q, k, v)
            qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
            q, k, v = qkv[0], qkv[1], qkv[2]   # make torchscript happy (cannot use tensor as tuple)
            x = self.forward_temporal(q, k, v, seqlen=seqlen)
        elif self.mode == 'parallel':
            qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
            q, k, v = qkv[0], qkv[1], qkv[2]   # make torchscript happy (cannot use tensor as tuple)
            x_t = self.forward_temporal(q, k, v, seqlen=seqlen)
            x_s = self.forward_spatial(q, k, v)

            alpha = torch.cat([x_s, x_t], dim=-1)
            alpha = alpha.mean(dim=1, keepdim=True)
            alpha = self.ts_attn(alpha).reshape(B, 1, C, 2)
            alpha = alpha.softmax(dim=-1)
            x = x_t * alpha[:,:,:,1] + x_s * alpha[:,:,:,0]
        elif self.mode == 'coupling':
            qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
            q, k, v = qkv[0], qkv[1], qkv[2]   # make torchscript happy (cannot use tensor as tuple)
            x = self.forward_coupling(q, k, v, seqlen=seqlen)
        elif self.mode == 'vanilla':
            qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
            q, k, v = qkv[0], qkv[1], qkv[2]   # make torchscript happy (cannot use tensor as tuple)
            x = self.forward_spatial(q, k, v)
        elif self.mode == 'temporal':
            qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
            q, k, v = qkv[0], qkv[1], qkv[2]   # make torchscript happy (cannot use tensor as tuple)
            x = self.forward_temporal(q, k, v, seqlen=seqlen)
        elif self.mode == 'spatial':
            qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
            q, k, v = qkv[0], qkv[1], qkv[2]   # make torchscript happy (cannot use tensor as tuple)
            x = self.forward_spatial(q, k, v)
        else:
            raise NotImplementedError(self.mode)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

    def reshape_T(self, x, seqlen=1, inverse=False):
        if not inverse:
            N, C = x.shape[-2:]
            x = x.reshape(-1, seqlen, self.num_heads, N, C).transpose(1,2)
            x = x.reshape(-1, self.num_heads, seqlen*N, C)  # (B, H, TN, C)
        else:
            TN, C = x.shape[-2:]
            x = x.reshape(-1, self.num_heads, seqlen, TN // seqlen, C).transpose(1,2)
            x = x.reshape(-1, self.num_heads, TN // seqlen, C)  # (BT, H, N, C)
        return x

    def forward_coupling(self, q, k, v, seqlen=8):
        BT, _, N, C = q.shape
        q = self.reshape_T(q, seqlen)
        k = self.reshape_T(k, seqlen)
        v = self.reshape_T(v, seqlen)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = attn @ v
        x = self.reshape_T(x, seqlen, inverse=True)
        x = x.transpose(1,2).reshape(BT, N, C*self.num_heads)
        return x

    def forward_spatial(self, q, k, v):
        B, _, N, C = q.shape
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = attn @ v
        x = x.transpose(1,2).reshape(B, N, C*self.num_heads)
        return x

    def forward_temporal(self, q, k, v, seqlen=8):
        B, _, N, C = q.shape
        qt = q.reshape(-1, seqlen, self.num_heads, N, C).permute(0, 2, 3, 1, 4)  # (B, H, N, T, C)
        kt = k.reshape(-1, seqlen, self.num_heads, N, C).permute(0, 2, 3, 1, 4)  # (B, H, N, T, C)
        vt = v.reshape(-1, seqlen, self.num_heads, N, C).permute(0, 2, 3, 1, 4)  # (B, H, N, T, C)

        attn = (qt @ kt.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = attn @ vt  # (B, H, N, T, C)
        x = x.permute(0, 3, 2, 1, 4).reshape(B, N, C*self.num_heads)
        return x

    def count_attn(self, attn):
        attn = attn.detach().cpu().numpy()
        attn = attn.mean(axis=1)
        attn_t = attn[:, :, 1].mean(axis=1)
        attn_s = attn[:, :, 0].mean(axis=1)
        if self.attn_count_s is None:
            self.attn_count_s = attn_s
            self.attn_count_t = attn_t
        else:
            self.attn_count_s = np.concatenate([self.attn_count_s, attn_s], axis=0)
            self.attn_count_t = np.concatenate([self.attn_count_t, attn_t], axis=0)

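# Shape conventions for Attention, as used by DSTformer below: x is (B*T, N, C), i.e.
# every frame is a separate item holding N joint tokens of dimension C. 'spatial' (and
# 'vanilla') attention mixes the N joints within a frame; 'temporal' regroups q, k and v
# with seqlen=T so that attention runs over the T frames of each joint; 'series' applies
# spatial then temporal attention, 'parallel' runs both and fuses them with a learned
# weighting, and 'coupling' attends jointly over all seqlen*N tokens.
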
class Block(nn.Module):

    def __init__(self, dim, num_heads, mlp_ratio=4., mlp_out_ratio=1., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, st_mode='stage_st', att_fuse=False):
        super().__init__()
        # assert 'stage' in st_mode
        self.st_mode = st_mode
        self.norm1_s = norm_layer(dim)
        self.norm1_t = norm_layer(dim)
        self.attn_s = Attention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, st_mode="spatial")
        self.attn_t = Attention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, st_mode="temporal")

        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2_s = norm_layer(dim)
        self.norm2_t = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        mlp_out_dim = int(dim * mlp_out_ratio)
        self.mlp_s = MLP(in_features=dim, hidden_features=mlp_hidden_dim, out_features=mlp_out_dim, act_layer=act_layer, drop=drop)
        self.mlp_t = MLP(in_features=dim, hidden_features=mlp_hidden_dim, out_features=mlp_out_dim, act_layer=act_layer, drop=drop)
        self.att_fuse = att_fuse
        if self.att_fuse:
            self.ts_attn = nn.Linear(dim*2, dim*2)

    def forward(self, x, seqlen=1):
        if self.st_mode == 'stage_st':
            x = x + self.drop_path(self.attn_s(self.norm1_s(x), seqlen))
            x = x + self.drop_path(self.mlp_s(self.norm2_s(x)))
            x = x + self.drop_path(self.attn_t(self.norm1_t(x), seqlen))
            x = x + self.drop_path(self.mlp_t(self.norm2_t(x)))
        elif self.st_mode == 'stage_ts':
            x = x + self.drop_path(self.attn_t(self.norm1_t(x), seqlen))
            x = x + self.drop_path(self.mlp_t(self.norm2_t(x)))
            x = x + self.drop_path(self.attn_s(self.norm1_s(x), seqlen))
            x = x + self.drop_path(self.mlp_s(self.norm2_s(x)))
        elif self.st_mode == 'stage_para':
            x_t = x + self.drop_path(self.attn_t(self.norm1_t(x), seqlen))
            x_t = x_t + self.drop_path(self.mlp_t(self.norm2_t(x_t)))
            x_s = x + self.drop_path(self.attn_s(self.norm1_s(x), seqlen))
            x_s = x_s + self.drop_path(self.mlp_s(self.norm2_s(x_s)))
            if self.att_fuse:
                # x_s, x_t: [BF, J, dim]
                alpha = torch.cat([x_s, x_t], dim=-1)
                BF, J = alpha.shape[:2]
                # alpha = alpha.mean(dim=1, keepdim=True)
                alpha = self.ts_attn(alpha).reshape(BF, J, -1, 2)
                alpha = alpha.softmax(dim=-1)
                x = x_t * alpha[:,:,:,1] + x_s * alpha[:,:,:,0]
            else:
                x = (x_s + x_t)*0.5
        else:
            raise NotImplementedError(self.st_mode)
        return x

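# Block stacks pre-norm residual sub-layers: spatial attention + MLP and temporal
# attention + MLP, each optionally wrapped in DropPath. st_mode sets the order:
# 'stage_st' runs spatial then temporal, 'stage_ts' temporal then spatial, and
# 'stage_para' runs both streams in parallel, fusing them with a learned per-channel
# weighting when att_fuse=True and by plain averaging otherwise.
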
class DSTformer(nn.Module):
    def __init__(self, dim_in=3, dim_out=3, dim_feat=256, dim_rep=512,
                 depth=5, num_heads=8, mlp_ratio=4,
                 num_joints=17, maxlen=243,
                 qkv_bias=True, qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm, att_fuse=True):
        super().__init__()
        self.dim_out = dim_out
        self.dim_feat = dim_feat
        self.joints_embed = nn.Linear(dim_in, dim_feat)
        self.pos_drop = nn.Dropout(p=drop_rate)
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
        self.blocks_st = nn.ModuleList([
            Block(
                dim=dim_feat, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
                st_mode="stage_st")
            for i in range(depth)])
        self.blocks_ts = nn.ModuleList([
            Block(
                dim=dim_feat, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
                st_mode="stage_ts")
            for i in range(depth)])
        self.norm = norm_layer(dim_feat)
        if dim_rep:
            self.pre_logits = nn.Sequential(OrderedDict([
                ('fc', nn.Linear(dim_feat, dim_rep)),
                ('act', nn.Tanh())
            ]))
        else:
            self.pre_logits = nn.Identity()
        self.head = nn.Linear(dim_rep, dim_out) if dim_out > 0 else nn.Identity()
        self.temp_embed = nn.Parameter(torch.zeros(1, maxlen, 1, dim_feat))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_joints, dim_feat))
        trunc_normal_(self.temp_embed, std=.02)
        trunc_normal_(self.pos_embed, std=.02)
        self.apply(self._init_weights)
        self.att_fuse = att_fuse
        if self.att_fuse:
            self.ts_attn = nn.ModuleList([nn.Linear(dim_feat*2, 2) for i in range(depth)])
            for i in range(depth):
                self.ts_attn[i].weight.data.fill_(0)
                self.ts_attn[i].bias.data.fill_(0.5)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def get_classifier(self):
        return self.head

    def reset_classifier(self, dim_out, global_pool=''):
        self.dim_out = dim_out
        self.head = nn.Linear(self.dim_feat, dim_out) if dim_out > 0 else nn.Identity()

    def forward(self, x, return_rep=False):
        B, F, J, C = x.shape
        x = x.reshape(-1, J, C)
        BF = x.shape[0]
        x = self.joints_embed(x)
        x = x + self.pos_embed
        _, J, C = x.shape
        x = x.reshape(-1, F, J, C) + self.temp_embed[:,:F,:,:]
        x = x.reshape(BF, J, C)
        x = self.pos_drop(x)
        alphas = []
        for idx, (blk_st, blk_ts) in enumerate(zip(self.blocks_st, self.blocks_ts)):
            x_st = blk_st(x, F)
            x_ts = blk_ts(x, F)
            if self.att_fuse:
                att = self.ts_attn[idx]
                alpha = torch.cat([x_st, x_ts], dim=-1)
                BF, J = alpha.shape[:2]
                alpha = att(alpha)
                alpha = alpha.softmax(dim=-1)
                x = x_st * alpha[:,:,0:1] + x_ts * alpha[:,:,1:2]
            else:
                x = (x_st + x_ts)*0.5
        x = self.norm(x)
        x = x.reshape(B, F, J, -1)
        x = self.pre_logits(x)  # [B, F, J, dim_rep]
        if return_rep:
            return x
        x = self.head(x)
        return x

    def get_representation(self, x):
        return self.forward(x, return_rep=True)
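
A minimal usage sketch (not part of the module source; the constructor arguments below are simply the class defaults written out, and the tensor sizes are illustrative): DSTformer expects keypoint sequences of shape (batch, frames, joints, channels), adds learned spatial and temporal position embeddings, runs the paired stage_st/stage_ts blocks with attention-based fusion, and returns either the output of the final head or, via get_representation, the dim_rep-dimensional pre-logits features.

import torch

from dlc2action.model.motionbert_modules import DSTformer

# 2 clips, 16 frames, 17 joints, 3 input channels per joint (e.g. x, y, confidence)
model = DSTformer(dim_in=3, dim_out=3, dim_feat=256, dim_rep=512,
                  depth=5, num_heads=8, num_joints=17, maxlen=243)
x = torch.randn(2, 16, 17, 3)        # (B, F, J, C); F must not exceed maxlen
out = model(x)                       # (2, 16, 17, 3): output of the final head
rep = model.get_representation(x)    # (2, 16, 17, 512): pre-logits features of size dim_rep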