dlc2action.model.motionbert_modules
Module attribution header, imports, and the private helper _no_grad_trunc_normal_ used by trunc_normal_ documented below:

#
# Copyright 2020-present by A. Mathis Group and contributors. All rights reserved.
#
# This project and all its files are licensed under GNU AGPLv3 or later version. A copy is included in dlc2action/LICENSE.AGPL.
#
# Incorporates code adapted from MotionBERT by Walter0807
# Original work Copyright (c) 2023 Walter0807
# Source: https://github.com/Walter0807/MotionBERT
# Originally licensed under Apache License Version 2.0, 2023
# Combined work licensed under GNU AGPLv3
#
import torch
import torch.nn as nn
import math
import warnings
import numpy as np
from collections import OrderedDict


def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    # Cut & paste from PyTorch official master until it's in a few official releases - RW
    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1. + math.erf(x / math.sqrt(2.))) / 2.

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
                      "The distribution of values may be incorrect.",
                      stacklevel=2)

    with torch.no_grad():
        # Values are generated by using a truncated uniform distribution and
        # then using the inverse CDF for the normal distribution.
        # Get upper and lower cdf values
        l = norm_cdf((a - mean) / std)
        u = norm_cdf((b - mean) / std)

        # Uniformly fill tensor with values from [l, u], then translate to
        # [2l-1, 2u-1].
        tensor.uniform_(2 * l - 1, 2 * u - 1)

        # Use inverse cdf transform for normal distribution to get truncated
        # standard normal
        tensor.erfinv_()

        # Transform to proper mean, std
        tensor.mul_(std * math.sqrt(2.))
        tensor.add_(mean)

        # Clamp to ensure it's in the proper range
        tensor.clamp_(min=a, max=b)
        return tensor
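As a quick, illustrative sanity check (not part of the module), the helper can be called directly to confirm that every generated value lands inside [a, b] and that the empirical mean and standard deviation are close to the requested ones:

import torch

# Hypothetical check of the truncated-normal helper defined above.
t = torch.empty(10000)
_no_grad_trunc_normal_(t, mean=0., std=0.02, a=-2., b=2.)
assert t.min().item() >= -2. and t.max().item() <= 2.   # all values inside [a, b]
print(float(t.mean()), float(t.std()))                  # expected to be close to 0 and 0.02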
def drop_path(x, drop_prob: float = 0., training: bool = False):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
    changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
    'survival rate' as the argument.
    """
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # binarize
    output = x.div(keep_prob) * random_tensor
    return output
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the argument.
class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    """
    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
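A minimal usage sketch (the ToyResidual class and its sizes are illustrative, not part of dlc2action): DropPath is meant to wrap the residual branch, so that during training entire per-sample branch outputs are occasionally zeroed while the survivors are rescaled by 1 / keep_prob; in eval mode it is a no-op.

import torch
import torch.nn as nn

# Illustrative only: a toy residual block using DropPath as defined above.
class ToyResidual(nn.Module):
    def __init__(self, dim, drop_prob=0.1):
        super().__init__()
        self.fc = nn.Linear(dim, dim)
        self.drop_path = DropPath(drop_prob)

    def forward(self, x):
        # the branch output is dropped per sample only in training mode
        return x + self.drop_path(self.fc(x))

block = ToyResidual(16)
block.train()
y = block(torch.randn(4, 16))   # some samples may skip the fc branch entirely
block.eval()
y = block(torch.randn(4, 16))   # deterministic: drop_path returns x unchanged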
def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.) -> torch.Tensor:
    r"""Fills the input Tensor with values drawn from a truncated
    normal distribution. The values are effectively drawn from the
    normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
    with values outside :math:`[a, b]` redrawn until they are within
    the bounds. The method used for generating the random values works
    best when :math:`a \leq \text{mean} \leq b`.
    Args:
        tensor: an n-dimensional `torch.Tensor`
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution
        a: the minimum cutoff value
        b: the maximum cutoff value
    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.trunc_normal_(w)
    """
    return _no_grad_trunc_normal_(tensor, mean, std, a, b)
Fills the input Tensor with values drawn from a truncated normal distribution. The values are effectively drawn from the normal distribution \( \mathcal{N}(\text{mean}, \text{std}^2) \) with values outside \( [a, b] \) redrawn until they are within the bounds. The method used for generating the random values works best when \( a \leq \text{mean} \leq b \).

Args:
    tensor: an n-dimensional torch.Tensor
    mean: the mean of the normal distribution
    std: the standard deviation of the normal distribution
    a: the minimum cutoff value
    b: the maximum cutoff value

Examples:
    w = torch.empty(3, 5)
    nn.init.trunc_normal_(w)
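A usage sketch (mirroring how DSTformer below initializes its learnable embeddings; the tensor name and sizes here are illustrative):

import torch
import torch.nn as nn

# Illustrative: initialize a learnable positional embedding with a small
# truncated normal, as DSTformer does for pos_embed and temp_embed.
num_joints, dim_feat = 17, 256
pos_embed = nn.Parameter(torch.zeros(1, num_joints, dim_feat))
trunc_normal_(pos_embed, std=.02)   # values are kept within [-2, 2]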
class MLP(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x
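A brief usage sketch (shapes are illustrative): this is the standard transformer feed-forward block, applied independently to every token along the last dimension.

import torch

# Illustrative: the feed-forward used inside each transformer Block below.
mlp = MLP(in_features=256, hidden_features=1024, drop=0.1)   # hidden = 4 * dim
tokens = torch.randn(8, 17, 256)    # (batch*frames, joints, features)
out = mlp(tokens)
print(out.shape)                    # torch.Size([8, 17, 256])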
class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., st_mode='vanilla'):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
        self.scale = qk_scale or head_dim ** -0.5

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.mode = st_mode
        if self.mode == 'parallel':
            self.ts_attn = nn.Linear(dim*2, dim*2)
            self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        else:
            self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj_drop = nn.Dropout(proj_drop)

        self.attn_count_s = None
        self.attn_count_t = None

    def forward(self, x, seqlen=1):
        B, N, C = x.shape

        if self.mode == 'series':
            qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
            q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)
            x = self.forward_spatial(q, k, v)
            qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
            q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)
            x = self.forward_temporal(q, k, v, seqlen=seqlen)
        elif self.mode == 'parallel':
            qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
            q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)
            x_t = self.forward_temporal(q, k, v, seqlen=seqlen)
            x_s = self.forward_spatial(q, k, v)

            alpha = torch.cat([x_s, x_t], dim=-1)
            alpha = alpha.mean(dim=1, keepdim=True)
            alpha = self.ts_attn(alpha).reshape(B, 1, C, 2)
            alpha = alpha.softmax(dim=-1)
            x = x_t * alpha[:,:,:,1] + x_s * alpha[:,:,:,0]
        elif self.mode == 'coupling':
            qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
            q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)
            x = self.forward_coupling(q, k, v, seqlen=seqlen)
        elif self.mode == 'vanilla':
            qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
            q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)
            x = self.forward_spatial(q, k, v)
        elif self.mode == 'temporal':
            qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
            q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)
            x = self.forward_temporal(q, k, v, seqlen=seqlen)
        elif self.mode == 'spatial':
            qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
            q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)
            x = self.forward_spatial(q, k, v)
        else:
            raise NotImplementedError(self.mode)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

    def reshape_T(self, x, seqlen=1, inverse=False):
        if not inverse:
            N, C = x.shape[-2:]
            x = x.reshape(-1, seqlen, self.num_heads, N, C).transpose(1,2)
            x = x.reshape(-1, self.num_heads, seqlen*N, C)  # (B, H, TN, c)
        else:
            TN, C = x.shape[-2:]
            x = x.reshape(-1, self.num_heads, seqlen, TN // seqlen, C).transpose(1,2)
            x = x.reshape(-1, self.num_heads, TN // seqlen, C)  # (BT, H, N, C)
        return x

    def forward_coupling(self, q, k, v, seqlen=8):
        BT, _, N, C = q.shape
        q = self.reshape_T(q, seqlen)
        k = self.reshape_T(k, seqlen)
        v = self.reshape_T(v, seqlen)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = attn @ v
        x = self.reshape_T(x, seqlen, inverse=True)
        x = x.transpose(1,2).reshape(BT, N, C*self.num_heads)
        return x

    def forward_spatial(self, q, k, v):
        B, _, N, C = q.shape
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = attn @ v
        x = x.transpose(1,2).reshape(B, N, C*self.num_heads)
        return x

    def forward_temporal(self, q, k, v, seqlen=8):
        B, _, N, C = q.shape
        qt = q.reshape(-1, seqlen, self.num_heads, N, C).permute(0, 2, 3, 1, 4)  # (B, H, N, T, C)
        kt = k.reshape(-1, seqlen, self.num_heads, N, C).permute(0, 2, 3, 1, 4)  # (B, H, N, T, C)
        vt = v.reshape(-1, seqlen, self.num_heads, N, C).permute(0, 2, 3, 1, 4)  # (B, H, N, T, C)

        attn = (qt @ kt.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = attn @ vt  # (B, H, N, T, C)
        x = x.permute(0, 3, 2, 1, 4).reshape(B, N, C*self.num_heads)
        return x

    def count_attn(self, attn):
        attn = attn.detach().cpu().numpy()
        attn = attn.mean(axis=1)
        attn_t = attn[:, :, 1].mean(axis=1)
        attn_s = attn[:, :, 0].mean(axis=1)
        if self.attn_count_s is None:
            self.attn_count_s = attn_s
            self.attn_count_t = attn_t
        else:
            self.attn_count_s = np.concatenate([self.attn_count_s, attn_s], axis=0)
            self.attn_count_t = np.concatenate([self.attn_count_t, attn_t], axis=0)
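A shape-level usage sketch (the sizes are illustrative): the module receives tokens flattened over batch and frames as (B*T, J, C); in 'spatial' mode the joints of one frame attend to each other, while in 'temporal' mode seqlen is used to regroup the frames so that each joint attends to itself across time.

import torch

# Illustrative: spatial vs. temporal attention over pose tokens.
B, T, J, C = 2, 8, 17, 256              # batch, frames, joints, feature dim
x = torch.randn(B * T, J, C)            # tokens flattened over batch and frames

spatial = Attention(dim=C, num_heads=8, st_mode='spatial')
temporal = Attention(dim=C, num_heads=8, st_mode='temporal')

y_s = spatial(x)                        # joints attend to each other within a frame
y_t = temporal(x, seqlen=T)             # each joint attends to itself across frames
print(y_s.shape, y_t.shape)             # both torch.Size([16, 17, 256])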
class Block(nn.Module):

    def __init__(self, dim, num_heads, mlp_ratio=4., mlp_out_ratio=1., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, st_mode='stage_st', att_fuse=False):
        super().__init__()
        # assert 'stage' in st_mode
        self.st_mode = st_mode
        self.norm1_s = norm_layer(dim)
        self.norm1_t = norm_layer(dim)
        self.attn_s = Attention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, st_mode="spatial")
        self.attn_t = Attention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, st_mode="temporal")

        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2_s = norm_layer(dim)
        self.norm2_t = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        mlp_out_dim = int(dim * mlp_out_ratio)
        self.mlp_s = MLP(in_features=dim, hidden_features=mlp_hidden_dim, out_features=mlp_out_dim, act_layer=act_layer, drop=drop)
        self.mlp_t = MLP(in_features=dim, hidden_features=mlp_hidden_dim, out_features=mlp_out_dim, act_layer=act_layer, drop=drop)
        self.att_fuse = att_fuse
        if self.att_fuse:
            self.ts_attn = nn.Linear(dim*2, dim*2)

    def forward(self, x, seqlen=1):
        if self.st_mode=='stage_st':
            x = x + self.drop_path(self.attn_s(self.norm1_s(x), seqlen))
            x = x + self.drop_path(self.mlp_s(self.norm2_s(x)))
            x = x + self.drop_path(self.attn_t(self.norm1_t(x), seqlen))
            x = x + self.drop_path(self.mlp_t(self.norm2_t(x)))
        elif self.st_mode=='stage_ts':
            x = x + self.drop_path(self.attn_t(self.norm1_t(x), seqlen))
            x = x + self.drop_path(self.mlp_t(self.norm2_t(x)))
            x = x + self.drop_path(self.attn_s(self.norm1_s(x), seqlen))
            x = x + self.drop_path(self.mlp_s(self.norm2_s(x)))
        elif self.st_mode=='stage_para':
            x_t = x + self.drop_path(self.attn_t(self.norm1_t(x), seqlen))
            x_t = x_t + self.drop_path(self.mlp_t(self.norm2_t(x_t)))
            x_s = x + self.drop_path(self.attn_s(self.norm1_s(x), seqlen))
            x_s = x_s + self.drop_path(self.mlp_s(self.norm2_s(x_s)))
            if self.att_fuse:
                # x_s, x_t: [BF, J, dim]
                alpha = torch.cat([x_s, x_t], dim=-1)
                BF, J = alpha.shape[:2]
                # alpha = alpha.mean(dim=1, keepdim=True)
                alpha = self.ts_attn(alpha).reshape(BF, J, -1, 2)
                alpha = alpha.softmax(dim=-1)
                x = x_t * alpha[:,:,:,1] + x_s * alpha[:,:,:,0]
            else:
                x = (x_s + x_t)*0.5
        else:
            raise NotImplementedError(self.st_mode)
        return x
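A minimal usage sketch (sizes illustrative): with st_mode='stage_st' one Block applies spatial attention and its MLP, then temporal attention and its MLP, each as a pre-norm residual connection with optional DropPath.

import torch

# Illustrative: one spatial-then-temporal transformer block.
B, T, J, C = 2, 8, 17, 256
x = torch.randn(B * T, J, C)

blk = Block(dim=C, num_heads=8, mlp_ratio=4., st_mode='stage_st', drop_path=0.1)
y = blk(x, seqlen=T)                    # residual structure preserves the shape
print(y.shape)                          # torch.Size([16, 17, 256])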
class DSTformer(nn.Module):
    def __init__(self, dim_in=3, dim_out=3, dim_feat=256, dim_rep=512,
                 depth=5, num_heads=8, mlp_ratio=4,
                 num_joints=17, maxlen=243,
                 qkv_bias=True, qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm, att_fuse=True):
        super().__init__()
        self.dim_out = dim_out
        self.dim_feat = dim_feat
        self.joints_embed = nn.Linear(dim_in, dim_feat)
        self.pos_drop = nn.Dropout(p=drop_rate)
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
        self.blocks_st = nn.ModuleList([
            Block(
                dim=dim_feat, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
                st_mode="stage_st")
            for i in range(depth)])
        self.blocks_ts = nn.ModuleList([
            Block(
                dim=dim_feat, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
                st_mode="stage_ts")
            for i in range(depth)])
        self.norm = norm_layer(dim_feat)
        if dim_rep:
            self.pre_logits = nn.Sequential(OrderedDict([
                ('fc', nn.Linear(dim_feat, dim_rep)),
                ('act', nn.Tanh())
            ]))
        else:
            self.pre_logits = nn.Identity()
        self.head = nn.Linear(dim_rep, dim_out) if dim_out > 0 else nn.Identity()
        self.temp_embed = nn.Parameter(torch.zeros(1, maxlen, 1, dim_feat))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_joints, dim_feat))
        trunc_normal_(self.temp_embed, std=.02)
        trunc_normal_(self.pos_embed, std=.02)
        self.apply(self._init_weights)
        self.att_fuse = att_fuse
        if self.att_fuse:
            self.ts_attn = nn.ModuleList([nn.Linear(dim_feat*2, 2) for i in range(depth)])
            for i in range(depth):
                self.ts_attn[i].weight.data.fill_(0)
                self.ts_attn[i].bias.data.fill_(0.5)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def get_classifier(self):
        return self.head

    def reset_classifier(self, dim_out, global_pool=''):
        self.dim_out = dim_out
        self.head = nn.Linear(self.dim_feat, dim_out) if dim_out > 0 else nn.Identity()

    def forward(self, x, return_rep=False):
        B, F, J, C = x.shape
        x = x.reshape(-1, J, C)
        BF = x.shape[0]
        x = self.joints_embed(x)
        x = x + self.pos_embed
        _, J, C = x.shape
        x = x.reshape(-1, F, J, C) + self.temp_embed[:,:F,:,:]
        x = x.reshape(BF, J, C)
        x = self.pos_drop(x)
        alphas = []
        for idx, (blk_st, blk_ts) in enumerate(zip(self.blocks_st, self.blocks_ts)):
            x_st = blk_st(x, F)
            x_ts = blk_ts(x, F)
            if self.att_fuse:
                att = self.ts_attn[idx]
                alpha = torch.cat([x_st, x_ts], dim=-1)
                BF, J = alpha.shape[:2]
                alpha = att(alpha)
                alpha = alpha.softmax(dim=-1)
                x = x_st * alpha[:,:,0:1] + x_ts * alpha[:,:,1:2]
            else:
                x = (x_st + x_ts)*0.5
        x = self.norm(x)
        x = x.reshape(B, F, J, -1)
        x = self.pre_logits(x)  # [B, F, J, dim_feat]
        if return_rep:
            return x
        x = self.head(x)
        return x

    def get_representation(self, x):
        return self.forward(x, return_rep=True)
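An end-to-end usage sketch (batch size and frame count are illustrative): DSTformer expects input of shape [B, F, J, C] (batch, frames, joints, per-joint channels), fuses the two branch orders ('stage_st' and 'stage_ts') at every depth, and returns either the output of the head or, via get_representation, the dim_rep-dimensional pre-logits features.

import torch

# Illustrative: end-to-end pass through the dual-stream spatio-temporal transformer.
model = DSTformer(dim_in=3, dim_out=3, dim_feat=256, dim_rep=512,
                  depth=5, num_heads=8, num_joints=17, maxlen=243)
x = torch.randn(2, 16, 17, 3)           # [batch, frames, joints, coords]

out = model(x)                          # per-joint output from the head
rep = model.get_representation(x)       # pre-logits motion representation
print(out.shape, rep.shape)             # torch.Size([2, 16, 17, 3]) torch.Size([2, 16, 17, 512])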