dlc2action.model.asformer
ASFormer
Adapted from https://github.com/ChinaYi/ASFormer
#
# Copyright 2020-2022 by A. Mathis Group and contributors. All rights reserved.
#
# This project and all its files are licensed under GNU AGPLv3 or later version. A copy is included in dlc2action/LICENSE.AGPL.
#
#
# Adapted from ASFormer by ChinaYi
# Adapted from https://github.com/ChinaYi/ASFormer
# Licensed under MIT License
#
""" ASFormer

Adapted from https://github.com/ChinaYi/ASFormer
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from dlc2action.model.base_model import Model
from typing import Union, List

import copy
import numpy as np
import math

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# alpha weight for a decoder stage; decays exponentially with the decoder index
def _exponential_descrease(idx_decoder, p=3):
    return math.exp(-p * idx_decoder)


class _AttentionHelper(nn.Module):
    def __init__(self):
        super(_AttentionHelper, self).__init__()
        self.softmax = nn.Softmax(dim=-1)

    def scalar_dot_att(self, proj_query, proj_key, proj_val, padding_mask):
        """
        scalar dot attention.
        :param proj_query: shape of (B, C, L) => (Batch_Size, Feature_Dimension, Length)
        :param proj_key: shape of (B, C, L)
        :param proj_val: shape of (B, C, L)
        :param padding_mask: shape of (B, C, L)
        :return: attention value of shape (B, C, L)
        """
        m, c1, l1 = proj_query.shape
        m, c2, l2 = proj_key.shape

        assert c1 == c2

        energy = torch.bmm(proj_query.permute(0, 2, 1), proj_key)  # of shape (B, L1, L2)
        attention = energy / np.sqrt(c1)
        # mask the zero paddings. log(1e-6) for zero paddings
        attention = attention + torch.log(padding_mask + 1e-6)
        attention = self.softmax(attention)
        attention = attention * padding_mask
        attention = attention.permute(0, 2, 1)
        out = torch.bmm(proj_val, attention)
        return out, attention


# Single attention layer; supports "normal_att", "block_att" and "sliding_att" variants.
class _AttLayer(nn.Module):
    def __init__(self, q_dim, k_dim, v_dim, r1, r2, r3, bl, stage, att_type):  # r1 = r2
        super(_AttLayer, self).__init__()

        self.query_conv = nn.Conv1d(in_channels=q_dim, out_channels=q_dim // r1, kernel_size=1)
        self.key_conv = nn.Conv1d(in_channels=k_dim, out_channels=k_dim // r2, kernel_size=1)
        self.value_conv = nn.Conv1d(in_channels=v_dim, out_channels=v_dim // r3, kernel_size=1)

        self.conv_out = nn.Conv1d(in_channels=v_dim // r3, out_channels=v_dim, kernel_size=1)

        self.bl = bl
        self.stage = stage
        self.att_type = att_type
        assert self.att_type in ["normal_att", "block_att", "sliding_att"]
        assert self.stage in ["encoder", "decoder"]

        self.att_helper = _AttentionHelper()
        self.window_mask = self.construct_window_mask()

    def construct_window_mask(self):
        """
        construct window mask of shape (1, l, l + l//2 + l//2), used for sliding window self attention
        """
        window_mask = torch.zeros((1, self.bl, self.bl + 2 * (self.bl // 2)))
        for i in range(self.bl):
            window_mask[:, :, i : i + self.bl] = 1
        return window_mask.to(device)

    def forward(self, x1, x2, mask):
        # x1: features of the current branch (queries and keys)
        # x2: encoder output used as values in the decoder cross-attention (None in the encoder)

        query = self.query_conv(x1)
        key = self.key_conv(x1)

        if self.stage == "decoder":
            assert x2 is not None
            value = self.value_conv(x2)
        else:
            value = self.value_conv(x1)

        if self.att_type == "normal_att":
            return self._normal_self_att(query, key, value, mask)
        elif self.att_type == "block_att":
            return self._block_wise_self_att(query, key, value, mask)
        elif self.att_type == "sliding_att":
            return self._sliding_window_self_att(query, key, value, mask)

    def _normal_self_att(self, q, k, v, mask):
        m_batchsize, c1, L = q.size()
        _, c2, L = k.size()
        _, c3, L = v.size()
        padding_mask = torch.ones((m_batchsize, 1, L)).to(device) * mask[:, 0:1, :]
        output, attentions = self.att_helper.scalar_dot_att(q, k, v, padding_mask)
        output = self.conv_out(F.relu(output))
        output = output[:, :, 0:L]
        return output * mask[:, 0:1, :]

    def _block_wise_self_att(self, q, k, v, mask):
        m_batchsize, c1, L = q.size()
        _, c2, L = k.size()
        _, c3, L = v.size()

        nb = L // self.bl
        if L % self.bl != 0:
            q = torch.cat(
                [q, torch.zeros((m_batchsize, c1, self.bl - L % self.bl)).to(device)], dim=-1
            )
            k = torch.cat(
                [k, torch.zeros((m_batchsize, c2, self.bl - L % self.bl)).to(device)], dim=-1
            )
            v = torch.cat(
                [v, torch.zeros((m_batchsize, c3, self.bl - L % self.bl)).to(device)], dim=-1
            )
            nb += 1

        padding_mask = torch.cat(
            [
                torch.ones((m_batchsize, 1, L)).to(device) * mask[:, 0:1, :],
                torch.zeros((m_batchsize, 1, self.bl * nb - L)).to(device),
            ],
            dim=-1,
        )

        q = (
            q.reshape(m_batchsize, c1, nb, self.bl)
            .permute(0, 2, 1, 3)
            .reshape(m_batchsize * nb, c1, self.bl)
        )
        padding_mask = (
            padding_mask.reshape(m_batchsize, 1, nb, self.bl)
            .permute(0, 2, 1, 3)
            .reshape(m_batchsize * nb, 1, self.bl)
        )
        k = (
            k.reshape(m_batchsize, c2, nb, self.bl)
            .permute(0, 2, 1, 3)
            .reshape(m_batchsize * nb, c2, self.bl)
        )
        v = (
            v.reshape(m_batchsize, c3, nb, self.bl)
            .permute(0, 2, 1, 3)
            .reshape(m_batchsize * nb, c3, self.bl)
        )

        output, attentions = self.att_helper.scalar_dot_att(q, k, v, padding_mask)
        output = self.conv_out(F.relu(output))

        output = (
            output.reshape(m_batchsize, nb, c3, self.bl)
            .permute(0, 2, 1, 3)
            .reshape(m_batchsize, c3, nb * self.bl)
        )
        output = output[:, :, 0:L]
        return output * mask[:, 0:1, :]

    def _sliding_window_self_att(self, q, k, v, mask):
        m_batchsize, c1, L = q.size()
        _, c2, _ = k.size()
        _, c3, _ = v.size()

        # padding zeros for the last segment
        nb = L // self.bl
        if L % self.bl != 0:
            q = torch.cat(
                [q, torch.zeros((m_batchsize, c1, self.bl - L % self.bl)).to(device)], dim=-1
            )
            k = torch.cat(
                [k, torch.zeros((m_batchsize, c2, self.bl - L % self.bl)).to(device)], dim=-1
            )
            v = torch.cat(
                [v, torch.zeros((m_batchsize, c3, self.bl - L % self.bl)).to(device)], dim=-1
            )
            nb += 1
        padding_mask = torch.cat(
            [
                torch.ones((m_batchsize, 1, L)).to(device) * mask[:, 0:1, :],
                torch.zeros((m_batchsize, 1, self.bl * nb - L)).to(device),
            ],
            dim=-1,
        )

        # sliding window approach, by splitting query_proj and key_proj into shape (c1, l) x (c1, 2l)
        # sliding window for query_proj: reshape
        q = (
            q.reshape(m_batchsize, c1, nb, self.bl)
            .permute(0, 2, 1, 3)
            .reshape(m_batchsize * nb, c1, self.bl)
        )

        # sliding window approach for key_proj
        # 1. add paddings at the start and end
        k = torch.cat(
            [
                torch.zeros(m_batchsize, c2, self.bl // 2).to(device),
                k,
                torch.zeros(m_batchsize, c2, self.bl // 2).to(device),
            ],
            dim=-1,
        )
        v = torch.cat(
            [
                torch.zeros(m_batchsize, c3, self.bl // 2).to(device),
                v,
                torch.zeros(m_batchsize, c3, self.bl // 2).to(device),
            ],
            dim=-1,
        )
        padding_mask = torch.cat(
            [
                torch.zeros(m_batchsize, 1, self.bl // 2).to(device),
                padding_mask,
                torch.zeros(m_batchsize, 1, self.bl // 2).to(device),
            ],
            dim=-1,
        )

        # 2. reshape key_proj of shape (m_batchsize*nb, c1, 2*self.bl)
        k = torch.cat(
            [k[:, :, i * self.bl : (i + 1) * self.bl + (self.bl // 2) * 2] for i in range(nb)],
            dim=0,
        )  # special case when self.bl = 1
        v = torch.cat(
            [v[:, :, i * self.bl : (i + 1) * self.bl + (self.bl // 2) * 2] for i in range(nb)],
            dim=0,
        )
        # 3. construct window mask of shape (1, l, 2l), and use it to generate final mask
        padding_mask = torch.cat(
            [
                padding_mask[:, :, i * self.bl : (i + 1) * self.bl + (self.bl // 2) * 2]
                for i in range(nb)
            ],
            dim=0,
        )  # of shape (m*nb, 1, 2l)
        final_mask = self.window_mask.repeat(m_batchsize * nb, 1, 1) * padding_mask

        output, attention = self.att_helper.scalar_dot_att(q, k, v, final_mask)
        output = self.conv_out(F.relu(output))

        output = (
            output.reshape(m_batchsize, nb, -1, self.bl)
            .permute(0, 2, 1, 3)
            .reshape(m_batchsize, -1, nb * self.bl)
        )
        output = output[:, :, 0:L]
        return output * mask[:, 0:1, :]


class _MultiHeadAttLayer(nn.Module):
    def __init__(self, q_dim, k_dim, v_dim, r1, r2, r3, bl, stage, att_type, num_head):
        super(_MultiHeadAttLayer, self).__init__()
        # assert v_dim % num_head == 0
        self.conv_out = nn.Conv1d(v_dim * num_head, v_dim, 1)
        self.layers = nn.ModuleList(
            [
                copy.deepcopy(_AttLayer(q_dim, k_dim, v_dim, r1, r2, r3, bl, stage, att_type))
                for i in range(num_head)
            ]
        )
        self.dropout = nn.Dropout(p=0.5)

    def forward(self, x1, x2, mask):
        out = torch.cat([layer(x1, x2, mask) for layer in self.layers], dim=1)
        out = self.conv_out(self.dropout(out))
        return out


class _ConvFeedForward(nn.Module):
    def __init__(self, dilation, in_channels, out_channels):
        super(_ConvFeedForward, self).__init__()
        self.layer = nn.Sequential(
            nn.Conv1d(in_channels, out_channels, 3, padding=dilation, dilation=dilation),
            nn.ReLU(),
        )

    def forward(self, x):
        return self.layer(x)


class _FCFeedForward(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(_FCFeedForward, self).__init__()
        self.layer = nn.Sequential(
            nn.Conv1d(in_channels, out_channels, 1),  # conv1d equals fc
            nn.ReLU(),
            nn.Dropout(),
            nn.Conv1d(out_channels, out_channels, 1),
        )

    def forward(self, x):
        return self.layer(x)


# Dilated feed-forward + attention block with a residual connection, shared by encoder and decoder.
class _AttModule(nn.Module):
    def __init__(self, dilation, in_channels, out_channels, r1, r2, att_type, stage, alpha):
        super(_AttModule, self).__init__()
        self.feed_forward = _ConvFeedForward(dilation, in_channels, out_channels)
        self.instance_norm = nn.InstanceNorm1d(in_channels, track_running_stats=False)
        self.att_layer = _AttLayer(
            in_channels,
            in_channels,
            out_channels,
            r1,
            r1,
            r2,
            dilation,
            att_type=att_type,
            stage=stage,
        )  # dilation
        self.conv_1x1 = nn.Conv1d(out_channels, out_channels, 1)
        self.dropout = nn.Dropout()
        self.alpha = alpha

    def forward(self, x, f, mask):
        out = self.feed_forward(x)
        out = self.alpha * self.att_layer(self.instance_norm(out), f, mask) + out
        out = self.conv_1x1(out)
        out = self.dropout(out)
        return (x + out) * mask[:, 0:1, :]


class _PositionalEncoding(nn.Module):
    "Implement the PE function."

    def __init__(self, d_model, max_len=10000):
        super(_PositionalEncoding, self).__init__()
        # Compute the positional encodings once in log space.
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).permute(0, 2, 1)  # of shape (1, d_model, l)
        self.pe = nn.Parameter(pe, requires_grad=True)

        # self.register_buffer('pe', pe)

    def forward(self, x):
        return x + self.pe[:, :, 0 : x.shape[2]]


# ASFormer encoder: a stack of dilated self-attention blocks over the input features.
class _Encoder(nn.Module):
    def __init__(
        self,
        num_layers,
        r1,
        r2,
        num_f_maps,
        input_dim,
        channel_masking_rate,
        att_type,
        alpha,
    ):
        super(_Encoder, self).__init__()
        self.conv_1x1 = nn.Conv1d(input_dim, num_f_maps, 1)  # fc layer
        self.layers = nn.ModuleList(
            [
                _AttModule(2**i, num_f_maps, num_f_maps, r1, r2, att_type, "encoder", alpha)
                for i in range(num_layers)  # 2**i
            ]
        )

        self.dropout = nn.Dropout2d(p=channel_masking_rate)
        self.channel_masking_rate = channel_masking_rate

    def forward(self, x, mask):
        """
        :param x: (N, C, L)
        :param mask:
        :return:
        """

        if self.channel_masking_rate > 0:
            x = x.unsqueeze(2)
            x = self.dropout(x)
            x = x.squeeze(2)

        feature = self.conv_1x1(x)
        for layer in self.layers:
            feature = layer(feature, None, mask)

        return feature


# ASFormer decoder: dilated attention blocks that cross-attend to the encoder features.
class _Decoder(nn.Module):
    def __init__(self, num_layers, r1, r2, num_f_maps, input_dim, num_classes, att_type, alpha):
        super(_Decoder, self).__init__()  # self.position_en = PositionalEncoding(d_model=num_f_maps)
        self.conv_1x1 = nn.Conv1d(input_dim, num_f_maps, 1)
        self.layers = nn.ModuleList(
            [
                _AttModule(2**i, num_f_maps, num_f_maps, r1, r2, att_type, "decoder", alpha)
                for i in range(num_layers)  # 2 ** i
            ]
        )
        self.conv_out = nn.Conv1d(num_f_maps, num_classes, 1)

    def forward(self, x, fencoder, mask):
        feature = self.conv_1x1(x)
        for layer in self.layers:
            feature = layer(feature, fencoder, mask)

        out = self.conv_out(feature) * mask[:, 0:1, :]

        return out, feature


# Encoder-only wrapper used as the dlc2action feature extractor.
class _MyTransformer(nn.Module):
    def __init__(
        self,
        num_layers,
        r1,
        r2,
        num_f_maps,
        input_dim,
        channel_masking_rate,
    ):
        super(_MyTransformer, self).__init__()
        self.encoder = _Encoder(
            num_layers,
            r1,
            r2,
            num_f_maps,
            input_dim,
            channel_masking_rate,
            att_type="sliding_att",
            alpha=1,
        )

    def forward(self, x):
        mask = (x.sum(1).unsqueeze(1) != 0).int()
        feature = self.encoder(x, mask)
        feature = feature * mask

        return feature


# Prediction head: an initial 1x1 classifier followed by a stack of refining decoders.
class _Predictor(nn.Module):
    def __init__(
        self,
        num_layers,
        r1,
        r2,
        num_f_maps,
        num_classes,
        num_decoders,
    ):
        super(_Predictor, self).__init__()
        self.decoders = nn.ModuleList(
            [
                copy.deepcopy(
                    _Decoder(
                        num_layers,
                        r1,
                        r2,
                        num_f_maps,
                        num_classes,
                        num_classes,
                        att_type="sliding_att",
                        alpha=_exponential_descrease(s),
                    )
                )
                for s in range(num_decoders)
            ]
        )  # num_decoders

        self.conv_out = nn.Conv1d(num_f_maps, num_classes, 1)

    def forward(self, x):
        mask = (x.sum(1).unsqueeze(1) != 0).int()
        out = self.conv_out(x) * mask[:, 0:1, :]
        outputs = out.unsqueeze(0)

        for decoder in self.decoders:
            out, x = decoder(
                F.softmax(out, dim=1) * mask[:, 0:1, :], x * mask[:, 0:1, :], mask
            )
            outputs = torch.cat((outputs, out.unsqueeze(0)), dim=0)
        return outputs


class ASFormer(Model):
    """
    An implementation of ASFormer
    """

    def __init__(
        self,
        num_decoders,
        num_layers,
        r1,
        r2,
        num_f_maps,
        input_dim,
        num_classes,
        channel_masking_rate,
        state_dict_path=None,
        ssl_constructors=None,
        ssl_types=None,
        ssl_modules=None,
    ):
        input_dim = sum([x[0] for x in input_dim.values()])
        self.num_f_maps = num_f_maps
        self.params = {
            "num_layers": int(num_layers),
            "r1": r1,
            "r2": r2,
            "num_f_maps": int(num_f_maps),
            "input_dim": int(input_dim),
            "channel_masking_rate": channel_masking_rate,
        }
        self.params_predictor = {
            "num_layers": int(num_layers),
            "r1": r1,
            "r2": r2,
            "num_f_maps": int(num_f_maps),
            "num_classes": int(num_classes),
            "num_decoders": int(num_decoders),
        }
        super().__init__(ssl_constructors, ssl_modules, ssl_types, state_dict_path)

    def _feature_extractor(self) -> Union[torch.nn.Module, List]:
        return _MyTransformer(**self.params)

    def _predictor(self) -> torch.nn.Module:
        return _Predictor(**self.params_predictor)

    def features_shape(self) -> torch.Size:
        return torch.Size([self.num_f_maps])
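The encoder (`_MyTransformer`) serves as the feature extractor and the decoder stack (`_Predictor`) as the prediction head; the base `Model` class wires them together. The sketch below only illustrates the tensor shapes that can be read off the code above: it uses the private helper classes directly and arbitrary hyperparameter values, so treat it as a minimal illustration rather than the intended public API.

```python
import torch
from dlc2action.model.asformer import _MyTransformer, _Predictor, device

# Arbitrary sizes, chosen only to illustrate the tensor contracts.
batch, in_channels, length = 2, 32, 128
num_layers, num_f_maps, num_classes, num_decoders = 5, 64, 6, 3

encoder = _MyTransformer(
    num_layers=num_layers, r1=2, r2=2,
    num_f_maps=num_f_maps, input_dim=in_channels, channel_masking_rate=0.3,
).to(device)
predictor = _Predictor(
    num_layers=num_layers, r1=2, r2=2,
    num_f_maps=num_f_maps, num_classes=num_classes, num_decoders=num_decoders,
).to(device)

x = torch.randn(batch, in_channels, length, device=device)  # (N, C, L) frame-wise features
features = encoder(x)          # (N, num_f_maps, L), zeroed where the input is all zeros
outputs = predictor(features)  # (num_decoders + 1, N, num_classes, L)
```

`outputs[0]` holds the logits of the initial 1x1 prediction head; each following slice is the output of one refinement decoder, so `outputs[-1]` is the most refined prediction.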
class ASFormer(dlc2action.model.base_model.Model):
An implementation of ASFormer
ASFormer(num_decoders, num_layers, r1, r2, num_f_maps, input_dim, num_classes, channel_masking_rate, state_dict_path=None, ssl_constructors=None, ssl_types=None, ssl_modules=None)
Parameters
ssl_constructors : list, optional
    a list of SSL constructors that build the necessary SSL modules
ssl_modules : list, optional
    a list of torch.nn.Module instances that will serve as SSL modules
ssl_types : list, optional
    a list of string SSL types
state_dict_path : str, optional
    path to the model state dictionary to load
strict : bool, default False
    when True, the state dictionary will only be loaded if the current and the loaded architectures
    are the same; otherwise missing or extra keys, as well as shape inconsistencies, are ignored
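For orientation, a hedged construction sketch follows. The exact structure of the `input_dim` dictionary is normally supplied by the dlc2action data pipeline; the constructor only sums the first element of each value into the total input channel count (see `__init__` in the source above), so a plain dictionary of shape tuples is assumed here, and all hyperparameter values are illustrative.

```python
from dlc2action.model.asformer import ASFormer

# Hypothetical feature dictionary: 26 + 13 = 39 input channels in total.
input_dim = {"coords": (26,), "speeds": (13,)}

model = ASFormer(
    num_decoders=3,            # number of refining decoder stages in the prediction head
    num_layers=10,             # attention blocks (dilation 2**i) per encoder/decoder
    r1=2,                      # channel reduction factor for the query/key projections
    r2=2,                      # channel reduction factor for the value projection
    num_f_maps=64,             # width of the intermediate feature maps
    input_dim=input_dim,
    num_classes=6,             # number of behavior classes
    channel_masking_rate=0.3,  # probability of dropping whole input channels during training
)
```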
Get the shape of feature extractor output
Returns
feature_shape : torch.Size
    shape of feature extractor output
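Continuing the hypothetical model above, `features_shape` simply reports the channel dimension produced by the feature extractor (`num_f_maps`):

```python
import torch

assert model.features_shape() == torch.Size([64])  # num_f_maps=64 in the sketch above
```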
Inherited Members
- dlc2action.model.base_model.Model
- process_labels
- freeze_feature_extractor
- unfreeze_feature_extractor
- load_state_dict
- ssl_off
- ssl_on
- main_task_on
- main_task_off
- set_ssl
- extract_features
- forward
- torch.nn.modules.module.Module
- dump_patches
- register_buffer
- register_parameter
- add_module
- register_module
- get_submodule
- get_parameter
- get_buffer
- get_extra_state
- set_extra_state
- apply
- cuda
- ipu
- xpu
- cpu
- type
- float
- double
- half
- bfloat16
- to_empty
- to
- register_backward_hook
- register_full_backward_hook
- register_forward_pre_hook
- register_forward_hook
- T_destination
- state_dict
- register_load_state_dict_post_hook
- parameters
- named_parameters
- buffers
- named_buffers
- children
- named_children
- modules
- named_modules
- train
- eval
- requires_grad_
- zero_grad
- extra_repr