dlc2action.model.asformer
#
# Copyright 2020-present by A. Mathis Group and contributors. All rights reserved.
#
# This project and all its files are licensed under GNU AGPLv3 or later version. A copy is included in dlc2action/LICENSE.AGPL.
#
# Incorporates code adapted from ASFormer by ChinaYi
# Original work Copyright (c) 2021 ChinaYi
# Source: https://github.com/ChinaYi/ASFormer
# Originally licensed under MIT License
# Combined work licensed under GNU AGPLv3
#
import copy
import math
from typing import List, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from dlc2action.model.base_model import Model
from torch import optim

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def exponential_descrease(idx_decoder, p=3):
    """Exponential decrease function for the attention window."""
    return math.exp(-p * idx_decoder)


class AttentionHelper(nn.Module):
    def __init__(self):
        super(AttentionHelper, self).__init__()
        self.softmax = nn.Softmax(dim=-1)

    def scalar_dot_att(self, proj_query, proj_key, proj_val, padding_mask):
        """
        scalar dot attention.
        :param proj_query: shape of (B, C, L) => (Batch_Size, Feature_Dimension, Length)
        :param proj_key: shape of (B, C, L)
        :param proj_val: shape of (B, C, L)
        :param padding_mask: shape of (B, C, L)
        :return: attention value of shape (B, C, L)
        """
        m, c1, l1 = proj_query.shape
        m, c2, l2 = proj_key.shape

        assert c1 == c2

        energy = torch.bmm(
            proj_query.permute(0, 2, 1), proj_key
        )  # out of shape (B, L1, L2)
        attention = energy / np.sqrt(c1)
        attention = attention + torch.log(
            padding_mask + 1e-6
        )  # mask the zero paddings. log(1e-6) for zero paddings
        attention = self.softmax(attention)
        attention = attention * padding_mask
        attention = attention.permute(0, 2, 1)
        out = torch.bmm(proj_val, attention)
        return out, attention


class AttLayer(nn.Module):
    def __init__(self, q_dim, k_dim, v_dim, r1, r2, r3, bl, stage, att_type):  # r1 = r2
        self._fix_types(q_dim, k_dim, v_dim, r1, r2, r3, bl, stage, att_type)
        super(AttLayer, self).__init__()

        self.query_conv = nn.Conv1d(
            in_channels=self.q_dim, out_channels=self.q_dim // self.r1, kernel_size=1
        )
        self.key_conv = nn.Conv1d(
            in_channels=self.k_dim, out_channels=self.k_dim // self.r2, kernel_size=1
        )
        self.value_conv = nn.Conv1d(
            in_channels=self.v_dim, out_channels=self.v_dim // self.r3, kernel_size=1
        )

        self.conv_out = nn.Conv1d(
            in_channels=self.v_dim // self.r3, out_channels=self.v_dim, kernel_size=1
        )

        assert self.att_type in ["normal_att", "block_att", "sliding_att"]
        assert self.stage in ["encoder", "decoder"]

        self.att_helper = AttentionHelper()
        self.window_mask = self.construct_window_mask()

    def _fix_types(self, q_dim, k_dim, v_dim, r1, r2, r3, bl, stage, att_type):
        self.q_dim = int(float(q_dim))
        self.k_dim = int(float(k_dim))
        self.v_dim = int(float(v_dim))
        self.r1 = int(float(r1))
        self.r2 = int(float(r2))
        self.r3 = int(float(r3))
        self.bl = int(float(bl))
        self.stage = stage
        self.att_type = att_type

    def construct_window_mask(self):
        """
        construct window mask of shape (1, l, l + l//2 + l//2), used for sliding window self attention
        """
        window_mask = torch.zeros((1, self.bl, self.bl + 2 * (self.bl // 2)))
        for i in range(self.bl):
            window_mask[:, :, i : i + self.bl] = 1
        return window_mask.to(device)

    def forward(self, x1, x2, mask):
        """Forward pass."""
        # x1 from the encoder
        # x2 from the decoder

        query = self.query_conv(x1)
        key = self.key_conv(x1)

        if self.stage == "decoder":
            assert x2 is not None
            value = self.value_conv(x2)
        else:
            value = self.value_conv(x1)

        if self.att_type == "normal_att":
            return self._normal_self_att(query, key, value, mask)
        elif self.att_type == "block_att":
            return self._block_wise_self_att(query, key, value, mask)
        elif self.att_type == "sliding_att":
            return self._sliding_window_self_att(query, key, value, mask)

    def _normal_self_att(self, q, k, v, mask):
        m_batchsize, c1, L = q.size()
        _, c2, L = k.size()
        _, c3, L = v.size()
        padding_mask = torch.ones((m_batchsize, 1, L)).to(device) * mask[:, 0:1, :]
        output, attentions = self.att_helper.scalar_dot_att(q, k, v, padding_mask)
        output = self.conv_out(F.relu(output))
        output = output[:, :, 0:L]
        return output * mask[:, 0:1, :]

    def _block_wise_self_att(self, q, k, v, mask):
        m_batchsize, c1, L = q.size()
        _, c2, L = k.size()
        _, c3, L = v.size()

        nb = L // self.bl
        if L % self.bl != 0:
            q = torch.cat(
                [q, torch.zeros((m_batchsize, c1, self.bl - L % self.bl)).to(device)],
                dim=-1,
            )
            k = torch.cat(
                [k, torch.zeros((m_batchsize, c2, self.bl - L % self.bl)).to(device)],
                dim=-1,
            )
            v = torch.cat(
                [v, torch.zeros((m_batchsize, c3, self.bl - L % self.bl)).to(device)],
                dim=-1,
            )
            nb += 1

        padding_mask = torch.cat(
            [
                torch.ones((m_batchsize, 1, L)).to(device) * mask[:, 0:1, :],
                torch.zeros((m_batchsize, 1, self.bl * nb - L)).to(device),
            ],
            dim=-1,
        )

        q = (
            q.reshape(m_batchsize, c1, nb, self.bl)
            .permute(0, 2, 1, 3)
            .reshape(m_batchsize * nb, c1, self.bl)
        )
        padding_mask = (
            padding_mask.reshape(m_batchsize, 1, nb, self.bl)
            .permute(0, 2, 1, 3)
            .reshape(m_batchsize * nb, 1, self.bl)
        )
        k = (
            k.reshape(m_batchsize, c2, nb, self.bl)
            .permute(0, 2, 1, 3)
            .reshape(m_batchsize * nb, c2, self.bl)
        )
        v = (
            v.reshape(m_batchsize, c3, nb, self.bl)
            .permute(0, 2, 1, 3)
            .reshape(m_batchsize * nb, c3, self.bl)
        )

        output, attentions = self.att_helper.scalar_dot_att(q, k, v, padding_mask)
        output = self.conv_out(F.relu(output))

        output = (
            output.reshape(m_batchsize, nb, c3, self.bl)
            .permute(0, 2, 1, 3)
            .reshape(m_batchsize, c3, nb * self.bl)
        )
        output = output[:, :, 0:L]
        return output * mask[:, 0:1, :]

    def _sliding_window_self_att(self, q, k, v, mask):
        m_batchsize, c1, L = q.size()
        _, c2, _ = k.size()
        _, c3, _ = v.size()

        # padding zeros for the last segment
        nb = L // self.bl
        if L % self.bl != 0:
            q = torch.cat(
                [q, torch.zeros((m_batchsize, c1, self.bl - L % self.bl)).to(device)],
                dim=-1,
            )
            k = torch.cat(
                [k, torch.zeros((m_batchsize, c2, self.bl - L % self.bl)).to(device)],
                dim=-1,
            )
            v = torch.cat(
                [v, torch.zeros((m_batchsize, c3, self.bl - L % self.bl)).to(device)],
                dim=-1,
            )
            nb += 1
        padding_mask = torch.cat(
            [
                torch.ones((m_batchsize, 1, L)).to(device) * mask[:, 0:1, :],
                torch.zeros((m_batchsize, 1, self.bl * nb - L)).to(device),
            ],
            dim=-1,
        )

        # sliding window approach, by splitting query_proj and key_proj into shape (c1, l) x (c1, 2l)
        # sliding window for query_proj: reshape
        q = (
            q.reshape(m_batchsize, c1, nb, self.bl)
            .permute(0, 2, 1, 3)
            .reshape(m_batchsize * nb, c1, self.bl)
        )

        # sliding window approach for key_proj
        # 1. add paddings at the start and end
        k = torch.cat(
            [
                torch.zeros(m_batchsize, c2, self.bl // 2).to(device),
                k,
                torch.zeros(m_batchsize, c2, self.bl // 2).to(device),
            ],
            dim=-1,
        )
        v = torch.cat(
            [
                torch.zeros(m_batchsize, c3, self.bl // 2).to(device),
                v,
                torch.zeros(m_batchsize, c3, self.bl // 2).to(device),
            ],
            dim=-1,
        )
        padding_mask = torch.cat(
            [
                torch.zeros(m_batchsize, 1, self.bl // 2).to(device),
                padding_mask,
                torch.zeros(m_batchsize, 1, self.bl // 2).to(device),
            ],
            dim=-1,
        )

        # 2. reshape key_proj of shape (m_batchsize*nb, c1, 2*self.bl)
        k = torch.cat(
            [
                k[:, :, i * self.bl : (i + 1) * self.bl + (self.bl // 2) * 2]
                for i in range(nb)
            ],
            dim=0,
        )  # special case when self.bl = 1
        v = torch.cat(
            [
                v[:, :, i * self.bl : (i + 1) * self.bl + (self.bl // 2) * 2]
                for i in range(nb)
            ],
            dim=0,
        )
        # 3. construct window mask of shape (1, l, 2l), and use it to generate final mask
        padding_mask = torch.cat(
            [
                padding_mask[:, :, i * self.bl : (i + 1) * self.bl + (self.bl // 2) * 2]
                for i in range(nb)
            ],
            dim=0,
        )  # of shape (m*nb, 1, 2l)
        final_mask = self.window_mask.repeat(m_batchsize * nb, 1, 1) * padding_mask

        output, attention = self.att_helper.scalar_dot_att(q, k, v, final_mask)
        output = self.conv_out(F.relu(output))

        output = (
            output.reshape(m_batchsize, nb, -1, self.bl)
            .permute(0, 2, 1, 3)
            .reshape(m_batchsize, -1, nb * self.bl)
        )
        output = output[:, :, 0:L]
        return output * mask[:, 0:1, :]


class MultiHeadAttLayer(nn.Module):
    def __init__(self, q_dim, k_dim, v_dim, r1, r2, r3, bl, stage, att_type, num_head):
        super(MultiHeadAttLayer, self).__init__()
        # assert v_dim % num_head == 0
        self.conv_out = nn.Conv1d(v_dim * num_head, v_dim, 1)
        self.layers = nn.ModuleList(
            [
                copy.deepcopy(
                    AttLayer(q_dim, k_dim, v_dim, r1, r2, r3, bl, stage, att_type)
                )
                for i in range(num_head)
            ]
        )
        self.dropout = nn.Dropout(p=0.5)

    def forward(self, x1, x2, mask):
        """Forward pass."""
        out = torch.cat([layer(x1, x2, mask) for layer in self.layers], dim=1)
        out = self.conv_out(self.dropout(out))
        return out


class ConvFeedForward(nn.Module):
    def __init__(self, dilation, in_channels, out_channels):
        super(ConvFeedForward, self).__init__()
        self.layer = nn.Sequential(
            nn.Conv1d(
                in_channels, out_channels, 3, padding=dilation, dilation=dilation
            ),
            nn.ReLU(),
        )

    def forward(self, x):
        """Forward pass."""
        return self.layer(x)


class FCFeedForward(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(FCFeedForward, self).__init__()
        self.layer = nn.Sequential(
            nn.Conv1d(in_channels, out_channels, 1),  # conv1d equals fc
            nn.ReLU(),
            nn.Dropout(),
            nn.Conv1d(out_channels, out_channels, 1),
        )

    def forward(self, x):
        """Forward pass."""
        return self.layer(x)


class AttModule(nn.Module):
    def __init__(
        self, dilation, in_channels, out_channels, r1, r2, att_type, stage, alpha
    ):
        super(AttModule, self).__init__()
        self.feed_forward = ConvFeedForward(dilation, in_channels, out_channels)
        self.instance_norm = nn.InstanceNorm1d(in_channels, track_running_stats=False)
        self.att_layer = AttLayer(
            in_channels,
            in_channels,
            out_channels,
            r1,
            r1,
            r2,
            dilation,
            att_type=att_type,
            stage=stage,
        )  # dilation
        self.conv_1x1 = nn.Conv1d(out_channels, out_channels, 1)
        self.dropout = nn.Dropout()
        self.alpha = alpha

    def forward(self, x, f, mask):
        """Forward pass."""
        out = self.feed_forward(x)
        out = self.alpha * self.att_layer(self.instance_norm(out), f, mask) + out
        out = self.conv_1x1(out)
        out = self.dropout(out)
        return (x + out) * mask[:, 0:1, :]


class PositionalEncoding(nn.Module):
    "Implement the PE function."

    def __init__(self, d_model, max_len=10000):
        super(PositionalEncoding, self).__init__()
        # Compute the positional encodings once in log space.
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).permute(0, 2, 1)  # of shape (1, d_model, l)
        self.pe = nn.Parameter(pe, requires_grad=True)

        # self.register_buffer('pe', pe)

    def forward(self, x):
        """Forward pass."""
        return x + self.pe[:, :, 0 : x.shape[2]]


class Encoder(nn.Module):
    def __init__(
        self,
        num_layers,
        r1,
        r2,
        num_f_maps,
        input_dim,
        channel_masking_rate,
        att_type,
        alpha,
    ):
        super(Encoder, self).__init__()
        self.conv_1x1 = nn.Conv1d(input_dim, num_f_maps, 1)  # fc layer
        self.layers = nn.ModuleList(
            [
                AttModule(
                    2**i, num_f_maps, num_f_maps, r1, r2, att_type, "encoder", alpha
                )
                for i in range(num_layers)  # 2**i
            ]
        )

        self.dropout = nn.Dropout2d(p=channel_masking_rate)
        self.channel_masking_rate = channel_masking_rate

    def forward(self, x, mask):
        """Forward pass."""
        if self.channel_masking_rate > 0:
            x = x.unsqueeze(2)
            x = self.dropout(x)
            x = x.squeeze(2)

        feature = self.conv_1x1(x)
        for layer in self.layers:
            feature = layer(feature, None, mask)

        return feature


class Decoder(nn.Module):
    def __init__(
        self, num_layers, r1, r2, num_f_maps, input_dim, num_classes, att_type, alpha
    ):
        super(
            Decoder, self
        ).__init__()  # self.position_en = PositionalEncoding(d_model=num_f_maps)
        self.conv_1x1 = nn.Conv1d(input_dim, num_f_maps, 1)
        self.layers = nn.ModuleList(
            [
                AttModule(
                    2**i, num_f_maps, num_f_maps, r1, r2, att_type, "decoder", alpha
                )
                for i in range(num_layers)  # 2 ** i
            ]
        )
        self.conv_out = nn.Conv1d(num_f_maps, num_classes, 1)

    def forward(self, x, fencoder, mask):
        """Forward pass."""
        feature = self.conv_1x1(x)
        for layer in self.layers:
            feature = layer(feature, fencoder, mask)

        out = self.conv_out(feature) * mask[:, 0:1, :]

        return out, feature


class MyTransformer(nn.Module):
    def __init__(
        self,
        num_layers,
        r1,
        r2,
        num_f_maps,
        input_dim,
        channel_masking_rate,
    ):
        super(MyTransformer, self).__init__()
        self.encoder = Encoder(
            num_layers,
            r1,
            r2,
            num_f_maps,
            input_dim,
            channel_masking_rate,
            att_type="sliding_att",
            alpha=1,
        )

    def forward(self, x):
        """Forward pass."""
        mask = (x.sum(1).unsqueeze(1) != 0).int()
        feature = self.encoder(x, mask)
        feature = feature * mask

        return feature


class Predictor(nn.Module):
    def __init__(
        self,
        num_layers,
        r1,
        r2,
        num_f_maps,
        num_classes,
        num_decoders,
    ):
        super(Predictor, self).__init__()
        self.decoders = nn.ModuleList(
            [
                copy.deepcopy(
                    Decoder(
                        num_layers,
                        r1,
                        r2,
                        num_f_maps,
                        num_classes,
                        num_classes,
                        att_type="sliding_att",
                        alpha=exponential_descrease(s),
                    )
                )
                for s in range(num_decoders)
            ]
        )  # num_decoders

        self.conv_out = nn.Conv1d(num_f_maps, num_classes, 1)

    def forward(self, x):
        """Forward pass."""
        mask = (x.sum(1).unsqueeze(1) != 0).int()
        out = self.conv_out(x) * mask[:, 0:1, :]
        outputs = out.unsqueeze(0)

        for decoder in self.decoders:
            out, x = decoder(
                F.softmax(out, dim=1) * mask[:, 0:1, :], x * mask[:, 0:1, :], mask
            )
            outputs = torch.cat((outputs, out.unsqueeze(0)), dim=0)
        return outputs


class ASFormer(Model):
    """
    An implementation of ASFormer
    """

    def __init__(
        self,
        num_decoders: int,
        num_layers: int,
        r1: float,
        r2: float,
        num_f_maps: int,
        input_dim: dict,
        num_classes: int,
        channel_masking_rate: float,
        state_dict_path: str = None,
        ssl_constructors: List = None,
        ssl_types: List = None,
        ssl_modules: List = None,
    ):
        input_dim = sum([x[0] for x in input_dim.values()])
        self.num_f_maps = int(float(num_f_maps))
        self.params = {
            "num_layers": int(float(num_layers)),
            "r1": float(r1),
            "r2": float(r2),
            "num_f_maps": self.num_f_maps,
            "input_dim": int(float(input_dim)),
            "channel_masking_rate": float(channel_masking_rate),
        }
        self.params_predictor = {
            "num_layers": int(float(num_layers)),
            "r1": r1,
            "r2": r2,
            "num_f_maps": self.num_f_maps,
            "num_classes": int(float(num_classes)),
            "num_decoders": int(float(num_decoders)),
        }
        super().__init__(ssl_constructors, ssl_modules, ssl_types, state_dict_path)

    def _feature_extractor(self) -> Union[torch.nn.Module, List]:
        return MyTransformer(**self.params)

    def _predictor(self) -> torch.nn.Module:
        return Predictor(**self.params_predictor)

    def features_shape(self) -> torch.Size:
        return torch.Size([self.num_f_maps])
exponential_descrease(idx_decoder, p=3)
Exponential decrease function for the attention window.
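In Predictor below, this value is used as the alpha weight of each successive decoder, so later decoders mix in their attention output with exponentially smaller weight. A minimal sketch of the first few values (only the function itself is assumed):

import math

def exponential_descrease(idx_decoder, p=3):
    return math.exp(-p * idx_decoder)

for s in range(4):
    print(s, round(exponential_descrease(s), 4))
# 0 1.0
# 1 0.0498
# 2 0.0025
# 3 0.0001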
class AttentionHelper(nn.Module)
__init__()
scalar_dot_att(proj_query, proj_key, proj_val, padding_mask)

Scalar dot-product attention.
Parameters
- proj_query: shape of (B, C, L) => (Batch_Size, Feature_Dimension, Length)
- proj_key: shape of (B, C, L)
- proj_val: shape of (B, C, L)
- padding_mask: shape of (B, C, L)
Returns
attention value of shape (B, C, L)
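A minimal usage sketch; the tensor sizes are illustrative, and the helper is the class defined in this module:

import torch
from dlc2action.model.asformer import AttentionHelper

helper = AttentionHelper()
B, C, L = 2, 64, 100
q = torch.randn(B, C, L)
k = torch.randn(B, C, L)
v = torch.randn(B, C, L)
padding_mask = torch.ones(B, 1, L)  # 1 = valid frame, 0 = padded frame
out, att = helper.scalar_dot_att(q, k, v, padding_mask)
print(out.shape)  # torch.Size([2, 64, 100])
print(att.shape)  # torch.Size([2, 100, 100])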
class AttLayer(nn.Module)
__init__(q_dim, k_dim, v_dim, r1, r2, r3, bl, stage, att_type)
construct_window_mask()

Construct a window mask of shape (1, l, l + l//2 + l//2), used for sliding-window self-attention.
forward(x1, x2, mask)

Forward pass. Queries and keys are computed from x1; in the decoder stage the values are computed from x2, otherwise from x1. att_type selects normal, block-wise or sliding-window self-attention.
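A sliding-window attention layer can be exercised on its own; a minimal sketch with illustrative hyperparameters (q_dim = k_dim = v_dim = 64, reduction factors r1 = r2 = r3 = 2, block length bl = 16), using the module-level device defined above:

import torch
from dlc2action.model.asformer import AttLayer, device

layer = AttLayer(64, 64, 64, 2, 2, 2, 16, stage="encoder", att_type="sliding_att").to(device)
x = torch.randn(2, 64, 100).to(device)    # (batch, channels, length)
mask = torch.ones(2, 1, 100).to(device)   # all frames valid
out = layer(x, None, mask)                # x2 is only used in the decoder stage
print(out.shape)                          # torch.Size([2, 64, 100])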
class MultiHeadAttLayer(nn.Module)
__init__(q_dim, k_dim, v_dim, r1, r2, r3, bl, stage, att_type, num_head)
class ConvFeedForward(nn.Module)
__init__(dilation, in_channels, out_channels)
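The dilated 3-tap convolution uses padding equal to its dilation, so the temporal length is preserved; a quick sketch with illustrative sizes:

import torch
from dlc2action.model.asformer import ConvFeedForward

ff = ConvFeedForward(dilation=4, in_channels=64, out_channels=64)
x = torch.randn(2, 64, 100)
print(ff(x).shape)  # torch.Size([2, 64, 100]), length unchanged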
class FCFeedForward(nn.Module)
__init__(in_channels, out_channels)
class AttModule(nn.Module)
__init__(dilation, in_channels, out_channels, r1, r2, att_type, stage, alpha)
forward(x, f, mask)
Forward pass.
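AttModule is the residual block that both Encoder and Decoder stack: a dilated feed-forward convolution, instance-normalized attention scaled by alpha, a 1x1 convolution and a residual connection. A minimal encoder-stage sketch (dilation 1 corresponds to the first Encoder layer; the other values are illustrative):

import torch
from dlc2action.model.asformer import AttModule, device

block = AttModule(
    dilation=1, in_channels=64, out_channels=64,
    r1=2, r2=2, att_type="sliding_att", stage="encoder", alpha=1,
).to(device)
x = torch.randn(2, 64, 100).to(device)
mask = torch.ones(2, 1, 100).to(device)
out = block(x, None, mask)   # f is only used by decoder-stage attention
print(out.shape)             # torch.Size([2, 64, 100])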
class PositionalEncoding(nn.Module)
Implement the PE function.
__init__(d_model, max_len=10000)
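The encoding is stored as a learnable parameter of shape (1, d_model, max_len) and added to the input along the temporal axis; in this module the class is defined but only referenced in a commented-out line of Decoder. A minimal sketch:

import torch
from dlc2action.model.asformer import PositionalEncoding

pe = PositionalEncoding(d_model=64)
x = torch.randn(2, 64, 100)   # (batch, d_model, length)
print(pe(x).shape)            # torch.Size([2, 64, 100])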
class Encoder(nn.Module)
__init__(num_layers, r1, r2, num_f_maps, input_dim, channel_masking_rate, att_type, alpha)
forward(x, mask)
Forward pass.
class Decoder(nn.Module)
__init__(num_layers, r1, r2, num_f_maps, input_dim, num_classes, att_type, alpha)
forward(x, fencoder, mask)
Forward pass.
class MyTransformer(nn.Module)
__init__(num_layers, r1, r2, num_f_maps, input_dim, channel_masking_rate)
class Predictor(nn.Module)
__init__(num_layers, r1, r2, num_f_maps, num_classes, num_decoders)
forward(x)
Forward pass.
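MyTransformer and Predictor are the two halves that ASFormer wires together: the transformer encoder extracts frame-level features and the predictor refines class scores through a stack of decoders. A minimal end-to-end sketch with illustrative hyperparameters:

import torch
from dlc2action.model.asformer import MyTransformer, Predictor, device

feature_extractor = MyTransformer(
    num_layers=9, r1=2, r2=2, num_f_maps=64,
    input_dim=36, channel_masking_rate=0.3,
).to(device)
predictor = Predictor(
    num_layers=9, r1=2, r2=2, num_f_maps=64,
    num_classes=5, num_decoders=3,
).to(device)

x = torch.randn(2, 36, 128).to(device)   # (batch, input features, length)
features = feature_extractor(x)          # (2, 64, 128)
outputs = predictor(features)            # (4, 2, 5, 128): initial prediction plus one stage per decoder
print(features.shape, outputs.shape)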
class ASFormer(Model)
An implementation of ASFormer
__init__(num_decoders: int, num_layers: int, r1: float, r2: float, num_f_maps: int, input_dim: dict, num_classes: int, channel_masking_rate: float, state_dict_path: str = None, ssl_constructors: List = None, ssl_types: List = None, ssl_modules: List = None)
Initialize the model.
Parameters
- ssl_constructors : list, optional
  a list of SSL constructors that build the necessary SSL modules
- ssl_modules : list, optional
  a list of torch.nn.Module instances that will serve as SSL modules
- ssl_types : list, optional
  a list of string SSL types
- state_dict_path : str, optional
  path to the model state dictionary to load
- strict : bool, default False
  when True, the state dictionary will only be loaded if the current and the loaded architecture are the same; otherwise missing or extra keys, as well as shape inconsistencies, are ignored
- prompt_function : callable, optional
  a function that takes a list of strings and returns a string prompt
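A minimal construction sketch. input_dim is a dictionary mapping input group names to their shapes, and only the first element of each shape (the number of channels) is used; the key name and all numbers below are illustrative, and the SSL-related arguments are left at their defaults, assuming the base Model class handles that case:

from dlc2action.model.asformer import ASFormer

model = ASFormer(
    num_decoders=3,
    num_layers=9,
    r1=2,
    r2=2,
    num_f_maps=64,
    input_dim={"coords": (36,)},   # hypothetical input group with 36 channels
    num_classes=5,
    channel_masking_rate=0.3,
)
print(model.features_shape())      # torch.Size([64])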
features_shape() -> torch.Size

Get the shape of feature extractor output.

Returns
- feature_shape : torch.Size
  shape of feature extractor output
Inherited Members
- dlc2action.model.base_model.Model
- process_labels
- feature_extractor
- feature_extractors
- predictor
- ssl_active
- main_task_active
- prompt_function
- class_tensors
- freeze_feature_extractor
- unfreeze_feature_extractor
- load_state_dict
- ssl_off
- ssl_on
- main_task_on
- main_task_off
- set_ssl
- extract_features
- transform_labels
- forward