dlc2action.model.ms_tcn_modules
#
# Copyright 2020-present by A. Mathis Group and contributors. All rights reserved.
#
# This project and all its files are licensed under GNU AGPLv3 or later version.
# A copy is included in dlc2action/LICENSE.AGPL.
#
"""Building blocks for MS-TCN-style models.

The module defines dilated residual layers, prediction generation stages
(``DilatedTCN`` and its variants) and refinement stages (``MSRefinement`` and
its variants).

Implementation note: intermediate tensors are passed around by reference
instead of through ``copy.copy``. None of the operations below modify tensors
in place, so aliasing is safe. ``copy.copy`` on a ``torch.Tensor`` goes through
``Tensor.__reduce_ex__``, which rebuilds a tensor from the underlying storage
and can therefore cut it out of the autograd graph.
"""
import copy
import functools

import torch
import torch.nn as nn
import torch.nn.functional as F
from dlc2action.model.asformer import AttLayer


def _resolve_direction(direction, causal):
    """Validate a convolution direction and return the one to use.

    Parameters
    ----------
    direction : [None, 'forward', 'backward']
        the requested direction
    causal : bool
        whether the calling layer was built for directional (causal) convolutions

    Returns
    -------
    direction : [None, 'forward', 'backward']
        the requested direction; causal layers fall back to ``'forward'`` when
        ``direction`` is None

    Raises
    ------
    ValueError
        if a direction is requested from a non-causal layer, or if the direction
        is not recognized
    """
    if direction is not None and not causal:
        raise ValueError("Cannot set direction in a non-causal layer!")
    if direction is None and causal:
        direction = "forward"
    if direction not in ("forward", "backward", None):
        raise ValueError(
            f"Unrecognized direction: {direction}, please choose from "
            f'"backward", "forward" and None'
        )
    return direction


def _pad_for_direction(x, direction, amount):
    """Zero-pad the last axis of ``x`` by ``amount`` for a directional convolution.

    ``'forward'`` pads on the right, ``'backward'`` pads on the left, and with
    ``direction=None`` the input is returned unchanged.
    """
    if direction == "forward":
        return F.pad(x, (0, amount))
    if direction == "backward":
        return F.pad(x, (amount, 0))
    return x


def _select_head(heads, tag):
    """Return the module in ``heads`` that corresponds to ``tag``.

    With ``tag=None``, the weights of head 0 are first copied into every other
    head and head 0 is returned. Otherwise ``heads[tag]`` is returned without
    touching any weights.
    """
    if tag is None:
        for i in range(1, len(heads)):
            heads[i].load_state_dict(heads[0].state_dict())
        return heads[0]
    return heads[tag]


def _residual_stack(layers, x, block_size, *layer_args):
    """Apply ``layers`` in sequence with optional long skip connections.

    If ``block_size`` is not 0, then after every ``block_size`` layers the input
    of the current block is added to the running output, and that sum becomes
    the input of the next block. ``layer_args`` are passed to every layer after
    the running tensor.
    """
    block_input = x
    out = x
    for depth, layer in enumerate(layers, start=1):
        out = layer(out, *layer_args)
        if block_size != 0 and depth % block_size == 0:
            out = out + block_input
            block_input = out
    return out


class Refinement(nn.Module):
    """
    Refinement module
    """

    def __init__(
        self,
        num_layers,
        num_f_maps_input,
        num_f_maps,
        dim,
        num_classes,
        dropout_rate,
        direction,
        skip_connections,
        attention="none",
        block_size=0,
    ):
        """
        Parameters
        ----------
        num_layers : int
            the number of layers
        num_f_maps_input : int
            the number of extra input channels expected when `skip_connections` is `True`
        num_f_maps : int
            the number of feature maps
        dim : int
            the number of features in input
        num_classes : int
            the number of target classes
        dropout_rate : float
            dropout rate
        direction : [None, 'forward', 'backward']
            the direction of convolutions; if None, regular convolutions are used
        skip_connections : bool
            if `True`, skip connections are added
        attention : str, default "none"
            if "basic", the output of the dilated layers is multiplied by a gate
            computed by a small convolutional block ending in a sigmoid
        block_size : int, default 0
            if not 0, skip connections are added to the prediction generation stage with this interval
        """

        super().__init__()
        self.block_size = block_size
        self.direction = direction
        in_channels = dim + num_f_maps_input if skip_connections else dim
        self.conv_1x1 = nn.Conv1d(in_channels, num_f_maps, 1)
        self.layers = nn.ModuleList(
            [
                DilatedResidualLayer(
                    dilation=2**i,
                    in_channels=num_f_maps,
                    out_channels=num_f_maps,
                    dropout_rate=dropout_rate,
                    causal=(direction is not None),
                )
                for i in range(num_layers)
            ]
        )
        self.conv_out = nn.Conv1d(num_f_maps, num_classes, 1)
        self.attention_layers = nn.ModuleList([])
        self.attention = attention
        if self.attention == "basic":
            self.attention_layers += nn.ModuleList(
                [
                    nn.Conv1d(num_f_maps, num_f_maps, 3, padding=1),
                    nn.ReLU(),
                    nn.Conv1d(num_f_maps, num_f_maps, 3, padding=1),
                    nn.Sigmoid(),
                ]
            )

    def forward(self, x):
        """Forward pass."""
        x = self.conv_1x1(x)
        out = _residual_stack(self.layers, x, self.block_size, self.direction)
        if self.attention != "none":
            # NOTE(review): only "basic" builds gating layers. Any other value
            # leaves `attention_layers` empty, so the gate equals `out` and the
            # output is effectively squared. Confirm this is intended.
            gate = out
            for layer in self.attention_layers:
                gate = layer(gate)
            out = out * gate
        return self.conv_out(out)


class Refinement_SE(nn.Module):
    """
    Refinement module with squeeze-and-excitation channel gating

    The channels of the projected input are rescaled by a gate computed from
    their temporal mean. There is one gating head per `tag` value (6 heads).
    """

    def __init__(
        self,
        num_layers,
        num_f_maps_input,
        num_f_maps,
        dim,
        num_classes,
        dropout_rate,
        direction,
        skip_connections,
        len_segment,
        block_size=0,
    ):
        """
        Parameters
        ----------
        num_layers : int
            the number of layers
        num_f_maps_input : int
            the number of extra input channels expected when `skip_connections` is `True`
        num_f_maps : int
            the number of feature maps
        dim : int
            the number of features in input
        num_classes : int
            the number of target classes
        dropout_rate : float
            dropout rate
        direction : [None, 'forward', 'backward']
            the direction of convolutions; if None, regular convolutions are used
        skip_connections : bool
            if `True`, skip connections are added
        len_segment : int
            segment length (currently unused)
        block_size : int, default 0
            if not 0, skip connections are added to the prediction generation stage with this interval
        """

        super().__init__()
        self.block_size = block_size
        self.direction = direction
        in_channels = dim + num_f_maps_input if skip_connections else dim
        self.conv_1x1 = nn.Conv1d(in_channels, num_f_maps, 1)
        self.layers = nn.ModuleList(
            [
                DilatedResidualLayer(
                    dilation=2**i,
                    in_channels=num_f_maps,
                    out_channels=num_f_maps,
                    dropout_rate=dropout_rate,
                    causal=(direction is not None),
                )
                for i in range(num_layers)
            ]
        )
        self.conv_out = nn.Conv1d(num_f_maps, num_classes, 1)
        self.fc1_f = nn.ModuleList(
            [nn.Linear(num_f_maps, num_f_maps // 2) for _ in range(6)]
        )
        self.fc2_f = nn.ModuleList(
            [nn.Linear(num_f_maps // 2, num_f_maps) for _ in range(6)]
        )

    def _fc1_f(self, tag):
        return _select_head(self.fc1_f, tag)

    def _fc2_f(self, tag):
        return _select_head(self.fc2_f, tag)

    def forward(self, x, tag):
        """Forward pass."""
        x = self.conv_1x1(x)
        # Channel gate from the mean over the last (time) axis.
        scale = torch.mean(x, -1)
        scale = self._fc2_f(tag)(F.relu(self._fc1_f(tag)(scale)))
        scale = torch.sigmoid(scale).unsqueeze(-1)
        x = scale * x
        out = _residual_stack(self.layers, x, self.block_size, self.direction)
        return self.conv_out(out)


class RefinementB(Refinement):
    """
    Bidirectional refinement module

    The dilated layers are shared between a forward and a backward stream. Their
    outputs are concatenated along the channel axis before the output
    convolution. The `attention` option of the base class is not applied here.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # The base class sizes `conv_out` for one stream, but `forward`
        # concatenates two streams, which doubles the number of input channels.
        self.conv_out = nn.Conv1d(
            2 * self.conv_out.in_channels, self.conv_out.out_channels, 1
        )

    def forward(self, x):
        """Forward pass."""
        x_f = self.conv_1x1(x)
        x_b = x_f
        forward = x_f
        backward = x_f
        # The streams are advanced in lockstep, which keeps the order of dropout
        # draws identical to the original implementation.
        for depth, layer in enumerate(self.layers, start=1):
            forward = layer(forward, "forward")
            backward = layer(backward, "backward")
            if self.block_size != 0 and depth % self.block_size == 0:
                forward = forward + x_f
                backward = backward + x_b
                x_f = forward
                x_b = backward
        out = torch.cat([forward, backward], 1)
        return self.conv_out(out)


class SimpleResidualLayer(nn.Module):
    """
    Basic residual layer
    """

    def __init__(self, num_f_maps, dropout_rate):
        """
        Parameters
        ----------
        num_f_maps : int
            number of input and output channels
        dropout_rate : float
            dropout rate
        """

        super().__init__()
        self.conv_1x1_in = nn.Conv1d(num_f_maps, num_f_maps, 1)
        self.conv_1x1 = nn.Conv1d(num_f_maps, num_f_maps, 1)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x):
        """Forward pass."""
        out = self.conv_1x1_in(x)
        out = F.relu(out)
        out = self.conv_1x1(out)
        out = self.dropout(out)
        return x + out


class DilatedResidualLayer(nn.Module):
    """
    Dilated residual layer
    """

    def __init__(self, dilation, in_channels, out_channels, dropout_rate, causal):
        """
        Parameters
        ----------
        dilation : int
            dilation
        in_channels : int
            number of input channels
        out_channels : int
            number of output channels
        dropout_rate : float
            dropout rate
        causal : bool
            if `True`, causal convolutions are used
        """

        super().__init__()
        # Total one-sided padding needed by a kernel-3 dilated convolution.
        self.padding = dilation * 2
        self.causal = causal
        # Causal layers pad explicitly in `forward`, on the side chosen by direction.
        padding = 0 if self.causal else dilation
        self.conv_dilated = nn.Conv1d(
            in_channels, out_channels, 3, padding=padding, dilation=dilation
        )
        self.conv_1x1 = nn.Conv1d(out_channels, out_channels, 1)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x, direction=None):
        """Forward pass.

        Raises
        ------
        ValueError
            if `direction` is set on a non-causal layer or is not recognized
        """
        direction = _resolve_direction(direction, self.causal)
        # NOTE(review): "forward" pads on the right here, so each output step
        # depends on the current and later steps. DualDilatedResidualLayer's
        # "forward" depends on earlier steps instead. Confirm the intended convention.
        out = self.conv_dilated(_pad_for_direction(x, direction, self.padding))
        out = F.relu(out)
        out = self.conv_1x1(out)
        out = self.dropout(out)
        return x + out


class DilatedResidualLayer_SE(nn.Module):
    """
    Dilated residual layer with squeeze-and-excitation channel gating
    """

    def __init__(
        self, dilation, in_channels, out_channels, dropout_rate, causal, len_segment
    ):
        """
        Parameters
        ----------
        dilation : int
            dilation
        in_channels : int
            number of input channels
        out_channels : int
            number of output channels
        dropout_rate : float
            dropout rate
        causal : bool
            if `True`, causal convolutions are used
        len_segment : int
            segment length (currently unused; the temporal gate that needed it is disabled)
        """

        super().__init__()
        self.padding = dilation * 2
        self.causal = causal
        padding = 0 if self.causal else dilation
        self.conv_dilated = nn.Conv1d(
            in_channels, out_channels, 3, padding=padding, dilation=dilation
        )
        self.conv_1x1 = nn.Conv1d(out_channels, out_channels, 1)
        self.dropout = nn.Dropout(dropout_rate)
        self.fc1_f = nn.ModuleList(
            [nn.Linear(out_channels, out_channels // 2) for _ in range(6)]
        )
        self.fc2_f = nn.ModuleList(
            [nn.Linear(out_channels // 2, out_channels) for _ in range(6)]
        )
        # A temporal gate (`fc1_t` / `fc2_t` over `len_segment`) is disabled and
        # its modules are not created.

    def _fc1_f(self, tag):
        return _select_head(self.fc1_f, tag)

    def _fc2_f(self, tag):
        return _select_head(self.fc2_f, tag)

    def _fc1_t(self, tag):
        # Requires `self.fc1_t`, which `__init__` does not create.
        return _select_head(self.fc1_t, tag)

    def _fc2_t(self, tag):
        # Requires `self.fc2_t`, which `__init__` does not create.
        return _select_head(self.fc2_t, tag)

    def forward(self, x, direction, tag):
        """Forward pass.

        Raises
        ------
        ValueError
            if `direction` is set on a non-causal layer or is not recognized
        """
        direction = _resolve_direction(direction, self.causal)
        out = self.conv_dilated(_pad_for_direction(x, direction, self.padding))
        out = F.relu(out)
        out = self.conv_1x1(out)
        out = self.dropout(out)
        # Channel gate from the mean over the last (time) axis.
        scale = torch.mean(out, -1)
        scale = self._fc2_f(tag)(F.relu(self._fc1_f(tag)(scale)))
        scale = torch.sigmoid(scale).unsqueeze(-1)
        out = out * scale
        return x + out


class DilatedResidualLayer_SEC(nn.Module):
    """
    Dilated residual layer with a convolutional (per-time-step) gate
    """

    def __init__(
        self, dilation, in_channels, out_channels, dropout_rate, causal, len_segment
    ):
        """
        Parameters
        ----------
        dilation : int
            dilation
        in_channels : int
            number of input channels
        out_channels : int
            number of output channels
        dropout_rate : float
            dropout rate
        causal : bool
            if `True`, causal convolutions are used
        len_segment : int
            segment length (currently unused)
        """

        super().__init__()
        self.padding = dilation * 2
        self.causal = causal
        padding = 0 if self.causal else dilation
        self.conv_dilated = nn.Conv1d(
            in_channels, out_channels, 3, padding=padding, dilation=dilation
        )
        self.conv_1x1 = nn.Conv1d(out_channels, out_channels, 1)
        self.dropout = nn.Dropout(dropout_rate)
        self.conv_sq = nn.ModuleList(
            [
                nn.Conv1d(out_channels, 1, 3, padding=padding, dilation=dilation)
                for _ in range(6)
            ]
        )

    def _conv_sq(self, tag):
        return _select_head(self.conv_sq, tag)

    def forward(self, x, direction, tag):
        """Forward pass.

        Raises
        ------
        ValueError
            if `direction` is set on a non-causal layer or is not recognized
        """
        direction = _resolve_direction(direction, self.causal)
        out = self.conv_dilated(_pad_for_direction(x, direction, self.padding))
        out = F.relu(out)
        out = self.conv_1x1(out)
        out = self.dropout(out)
        # In causal mode `conv_sq` has no built-in padding, so its input is
        # padded like `conv_dilated`'s to keep the gate as long as `out`.
        gate_input = _pad_for_direction(out, direction, self.padding)
        scale = torch.sigmoid(self._conv_sq(tag)(gate_input))
        out = out * scale
        return x + out


class DualDilatedResidualLayer(nn.Module):
    """
    Dual dilated residual layer
    """

    def __init__(
        self,
        dilation1,
        dilation2,
        in_channels,
        out_channels,
        dropout_rate,
        causal,
        kernel_size=3,
    ):
        """
        Parameters
        ----------
        dilation1, dilation2 : int
            dilation of one of the blocks
        in_channels : int
            number of input channels
        out_channels : int
            number of output channels
        dropout_rate : float
            dropout rate
        causal : bool
            if `True`, causal convolutions are used
        kernel_size : int, default 3
            kernel size
        """

        super().__init__()
        self.causal = causal
        self.padding1 = dilation1 * (kernel_size - 1) // 2
        self.padding2 = dilation2 * (kernel_size - 1) // 2
        if self.causal:
            # Pad twice as much on both sides; `forward` trims the surplus on
            # the side chosen by the direction.
            self.padding1 *= 2
            self.padding2 *= 2
        self.conv_dilated_1 = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            padding=self.padding1,
            dilation=dilation1,
        )
        self.conv_dilated_2 = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            padding=self.padding2,
            dilation=dilation2,
        )
        self.conv_fusion = nn.Conv1d(2 * out_channels, out_channels, 1)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x, direction=None):
        """Forward pass.

        Raises
        ------
        ValueError
            if `direction` is set on a non-causal layer or is not recognized
        """
        direction = _resolve_direction(direction, self.causal)
        out_1 = self.conv_dilated_1(x)
        out_2 = self.conv_dilated_2(x)
        if direction == "forward":
            # An explicit end index keeps the full tensor when padding is 0
            # (`t[..., :-0]` would be empty).
            out_1 = out_1[:, :, : out_1.shape[-1] - self.padding1]
            out_2 = out_2[:, :, : out_2.shape[-1] - self.padding2]
        elif direction == "backward":
            out_1 = out_1[:, :, self.padding1 :]
            out_2 = out_2[:, :, self.padding2 :]

        out = self.conv_fusion(torch.cat([out_1, out_2], 1))
        out = F.relu(out)
        out = self.dropout(out)
        return out + x


class MultiDilatedTCN(nn.Module):
    """
    Multiple prediction generation stages in parallel
    """

    def __init__(
        self,
        num_layers,
        num_f_maps,
        dims,
        direction,
        block_size=5,
        kernel_size=3,
        rare_dilations=False,
    ):
        """
        Parameters
        ----------
        num_layers : int
            number of layers in each stage
        num_f_maps : int
            number of feature maps
        dims : list of int
            number of input features of each stage; one stage is built per entry
        direction : [None, 'forward', 'backward']
            the direction of convolutions; if None, regular convolutions are used
        block_size : int, default 5
            if not 0, skip connections are added with this interval
        kernel_size : int, default 3
            kernel size
        rare_dilations : bool, default False
            if `False`, dilation increases every layer, otherwise every second layer
        """
        super().__init__()
        # Keywords are required: `DilatedTCN` takes `num_bp` before `block_size`.
        self.PGs = nn.ModuleList(
            [
                DilatedTCN(
                    num_layers=num_layers,
                    num_f_maps=num_f_maps,
                    dim=dim,
                    direction=direction,
                    block_size=block_size,
                    kernel_size=kernel_size,
                    rare_dilations=rare_dilations,
                )
                for dim in dims
            ]
        )
        self.conv_out = nn.Conv1d(num_f_maps * len(dims), num_f_maps, 1)

    def forward(self, x, tag=None):
        """Forward pass.

        `x` is an iterable with one input per stage, in the order of `dims`.
        """
        out = torch.cat([PG(arr) for arr, PG in zip(x, self.PGs)], 1)
        return self.conv_out(out)


class DilatedTCN(nn.Module):
    """
    Prediction generation stage
    """

    def __init__(
        self,
        num_layers,
        num_f_maps,
        dim,
        direction,
        num_bp=None,
        block_size=5,
        kernel_size=3,
        rare_dilations=False,
        attention="none",
        multihead=False,
    ):
        """
        Parameters
        ----------
        num_layers : int
            number of layers
        num_f_maps : int
            number of feature maps
        dim : int
            number of features in input
        direction : [None, 'forward', 'backward']
            the direction of convolutions; if None, regular convolutions are used
        num_bp : optional
            unused
        block_size : int, default 5
            if not 0, skip connections are added to the prediction generation stage with this interval
        kernel_size : int, default 3
            kernel size
        rare_dilations : bool, default False
            if `False`, dilation increases every layer, otherwise every second layer
        attention : str, default "none"
            "basic" multiplies the output by a convolutional sigmoid gate; any
            other value except "none" is used as `att_type` of an `AttLayer`
        multihead : bool, default False
            only meaningful with "basic" attention; not implemented (see `forward`)
        """

        super().__init__()
        self.num_layers = num_layers
        self.block_size = block_size
        self.direction = direction
        self.conv_1x1_in = nn.Conv1d(dim, num_f_maps, 1)
        pars = {
            "in_channels": num_f_maps,
            "out_channels": num_f_maps,
            "dropout_rate": 0.5,
            "causal": (direction is not None),
            "kernel_size": kernel_size,
        }
        module = DualDilatedResidualLayer
        if not rare_dilations:
            self.layers = nn.ModuleList(
                [
                    module(dilation1=2 ** (num_layers - 1 - i), dilation2=2**i, **pars)
                    for i in range(num_layers)
                ]
            )
        else:
            # NOTE(review): for odd `num_layers` the last exponent of
            # `dilation1` is negative, giving a fractional dilation. Confirm
            # that only even `num_layers` are used here.
            self.layers = nn.ModuleList(
                [
                    module(
                        dilation1=2 ** (num_layers // 2 - 1 - i // 2),
                        dilation2=2 ** (i // 2),
                        **pars,
                    )
                    for i in range(num_layers)
                ]
            )
        self.attention_layers = nn.ModuleList([])
        self.attention = attention
        if self.attention in ["basic"]:
            self.attention_layers += nn.ModuleList(
                [
                    nn.Conv1d(num_f_maps, num_f_maps, 3, padding=1),
                    nn.ReLU(),
                    nn.Conv1d(num_f_maps, num_f_maps, 3, padding=1),
                    nn.Sigmoid(),
                ]
            )
        elif isinstance(self.attention, str) and self.attention != "none":
            # The meaning of the positional AttLayer arguments is defined in
            # dlc2action.model.asformer.
            self.attention_layers = AttLayer(
                num_f_maps,
                num_f_maps,
                num_f_maps,
                4,
                4,
                2,
                64,
                att_type=self.attention,
                stage="encoder",
            )
        self.multihead = multihead

    def forward(self, x, tag=None):
        """Forward pass.

        Raises
        ------
        NotImplementedError
            if "basic" attention is combined with `multihead=True`
        """
        x = self.conv_1x1_in(x)
        f = _residual_stack(self.layers, x, self.block_size, self.direction)
        if isinstance(self.attention, str) and self.attention != "none":
            x = f
            if self.attention != "basic":
                f = self.attention_layers(x)
            elif not self.multihead:
                for layer in self.attention_layers:
                    x = layer(x)
                f = f * x
            else:
                # The multi-head variant needs per-head attention stacks plus
                # `self.conv_att` and `self.dropout`, none of which `__init__` creates.
                raise NotImplementedError(
                    "Multi-head basic attention is not implemented"
                )
        return f


class DilatedTCNB(nn.Module):
    """
    Bidirectional prediction generation stage
    """

    def __init__(
        self,
        num_layers,
        num_f_maps,
        dim,
        block_size=5,
        kernel_size=3,
        rare_dilations=False,
    ):
        """
        Parameters
        ----------
        num_layers : int
            number of layers
        num_f_maps : int
            number of feature maps
        dim : int
            number of features in input
        block_size : int, default 5
            if not 0, skip connections are added to the prediction generation stage with this interval
        kernel_size : int, default 3
            kernel size
        rare_dilations : bool, default False
            if `False`, dilation increases every layer, otherwise every second layer
        """

        super().__init__()
        self.num_layers = num_layers
        self.block_size = block_size
        self.conv_1x1_in = nn.Conv1d(dim, num_f_maps, 1)
        self.conv_1x1_out = nn.Conv1d(num_f_maps * 2, num_f_maps, 1)
        if not rare_dilations:
            dilations = [(2 ** (num_layers - 1 - i), 2**i) for i in range(num_layers)]
        else:
            dilations = [
                (2 ** (num_layers // 2 - 1 - i // 2), 2 ** (i // 2))
                for i in range(num_layers)
            ]
        self.layers = nn.ModuleList(
            [
                DualDilatedResidualLayer(
                    dilation1=d1,
                    dilation2=d2,
                    in_channels=num_f_maps,
                    out_channels=num_f_maps,
                    dropout_rate=0.5,
                    causal=True,
                    kernel_size=kernel_size,
                )
                for d1, d2 in dilations
            ]
        )

    def forward(self, x, tag=None):
        """Forward pass."""
        x_f = self.conv_1x1_in(x)
        x_b = x_f
        forward = x_f
        backward = x_f
        for depth, layer in enumerate(self.layers, start=1):
            forward = layer(forward, "forward")
            backward = layer(backward, "backward")
            if self.block_size != 0 and depth % self.block_size == 0:
                forward = forward + x_f
                backward = backward + x_b
                x_f = forward
                x_b = backward
        out = torch.cat([forward, backward], 1)
        return self.conv_1x1_out(out)


class SpatialFeatures(nn.Module):
    """
    Spatial features extraction stage
    """

    def __init__(
        self,
        num_layers,
        num_f_maps,
        dim,
        block_size=5,
        graph_edges=None,
        num_nodes=None,
        denom: int = 8,
    ):
        """
        Parameters
        ----------
        num_layers : int
            number of layers
        num_f_maps : int
            number of feature maps
        dim : int
            number of features in input
        block_size : int, default 5
            if not 0, skip connections are added to the prediction generation stage with this interval
        graph_edges : optional
            must be None; graph-based layers are not implemented
        num_nodes : int, optional
            stored but currently unused
        denom : int, default 8
            unused

        Raises
        ------
        NotImplementedError
            if `graph_edges` is not None
        """

        super().__init__()
        self.num_nodes = num_nodes
        if graph_edges is not None:
            raise NotImplementedError("Graph not implemented")
        self.graph = False
        self.conv_1x1_in = nn.Conv1d(dim, num_f_maps, 1)

        self.num_layers = num_layers
        self.block_size = block_size
        self.layers = nn.ModuleList(
            [
                SimpleResidualLayer(num_f_maps=num_f_maps, dropout_rate=0.5)
                for _ in range(num_layers)
            ]
        )

    def forward(self, x):
        """Forward pass."""
        # Only the non-graph path exists (see `__init__`). It never had an
        # output projection: `conv_1x1_out` was referenced but never created.
        x = self.conv_1x1_in(x)
        return _residual_stack(self.layers, x, self.block_size)


class MSRefinement(nn.Module):
    """
    Refinement stage
    """

    def __init__(
        self,
        num_layers_R,
        num_R,
        num_f_maps_input,
        num_f_maps,
        num_classes,
        dropout_rate,
        exclusive,
        skip_connections,
        direction,
        block_size=0,
        num_heads=1,
        attention="none",
    ):
        """
        Parameters
        ----------
        num_layers_R : int
            number of layers in refinement modules
        num_R : int
            number of refinement modules
        num_f_maps_input : int
            number of channels of the input features
        num_f_maps : int
            number of feature maps
        num_classes : int
            number of target classes
        dropout_rate : float
            dropout rate
        exclusive : bool
            set `False` for multi-label classification
        skip_connections : bool
            if `True`, skip connections are added
        direction : [None, 'bidirectional', 'forward', 'backward']
            the direction of convolutions; if None, regular convolutions are used
        block_size : int, default 0
            if not 0, skip connections are added to the prediction generation stage with this interval
        num_heads : int, default 1
            number of parallel refinement stages
        attention : str, default "none"
            passed to every refinement module
        """

        super().__init__()
        self.skip_connections = skip_connections
        self.num_heads = num_heads
        # Named callables instead of lambdas keep the module picklable.
        if exclusive:
            self.nl = functools.partial(F.softmax, dim=1)
        else:
            self.nl = torch.sigmoid
        if direction == "bidirectional":
            refinement_module = RefinementB
        else:
            refinement_module = Refinement
        self.Rs = nn.ModuleList(
            [
                nn.ModuleList(
                    [
                        refinement_module(
                            num_layers=num_layers_R,
                            num_f_maps=num_f_maps,
                            num_f_maps_input=num_f_maps_input,
                            dim=num_classes,
                            num_classes=num_classes,
                            dropout_rate=dropout_rate,
                            direction=direction,
                            skip_connections=skip_connections,
                            block_size=block_size,
                            attention=attention,
                        )
                        for _ in range(num_R)
                    ]
                )
                for _ in range(self.num_heads)
            ]
        )
        self.conv_out = nn.ModuleList(
            [nn.Conv1d(num_f_maps_input, num_classes, 1) for _ in range(self.num_heads)]
        )
        if self.num_heads == 1:
            self.Rs = self.Rs[0]
            self.conv_out = self.conv_out[0]

    def _Rs(self, tag):
        if self.num_heads == 1:
            return self.Rs
        return _select_head(self.Rs, tag)

    def _conv_out(self, tag):
        if self.num_heads == 1:
            return self.conv_out
        return _select_head(self.conv_out, tag)

    def forward(self, x, tag=None):
        """Forward pass.

        Returns the initial prediction followed by the output of each refinement
        module, stacked along a new leading axis.
        """
        if tag is not None:
            tag = tag[0]
        out = self._conv_out(tag)(x)
        outputs = [out]
        for R in self._Rs(tag):
            if self.skip_connections:
                out = R(torch.cat([self.nl(out), x], dim=1))
            else:
                out = R(self.nl(out))
            outputs.append(out)
        return torch.stack(outputs, dim=0)


class MSRefinementShared(nn.Module):
    """
    Refinement stage with shared weights across modules
    """

    def __init__(
        self,
        num_layers_R,
        num_R,
        num_f_maps_input,
        num_f_maps,
        num_classes,
        dropout_rate,
        exclusive,
        skip_connections,
        direction,
        block_size=0,
        num_heads=1,
        attention="none",
    ):
        """
        Parameters
        ----------
        num_layers_R : int
            number of layers in refinement modules
        num_R : int
            number of times the shared refinement module is applied
        num_f_maps_input : int
            number of channels of the input features
        num_f_maps : int
            number of feature maps
        num_classes : int
            number of target classes
        dropout_rate : float
            dropout rate
        exclusive : bool
            set `False` for multi-label classification
        skip_connections : bool
            if `True`, skip connections are added
        direction : [None, 'bidirectional', 'forward', 'backward']
            the direction of convolutions; if None, regular convolutions are used
        block_size : int, default 0
            if not 0, skip connections are added to the prediction generation stage with this interval
        num_heads : int, default 1
            number of parallel refinement stages
        attention : str, default "none"
            passed to the refinement module
        """

        super().__init__()
        # Named callables instead of lambdas keep the module picklable.
        if exclusive:
            self.nl = functools.partial(F.softmax, dim=1)
        else:
            self.nl = torch.sigmoid
        if direction == "bidirectional":
            refinement_module = RefinementB
        else:
            refinement_module = Refinement
        self.num_heads = num_heads
        self.R = nn.ModuleList(
            [
                refinement_module(
                    num_layers=num_layers_R,
                    num_f_maps_input=num_f_maps_input,
                    num_f_maps=num_f_maps,
                    dim=num_classes,
                    num_classes=num_classes,
                    dropout_rate=dropout_rate,
                    direction=direction,
                    skip_connections=skip_connections,
                    block_size=block_size,
                    attention=attention,
                )
                for _ in range(self.num_heads)
            ]
        )
        self.num_R = num_R
        self.conv_out = nn.ModuleList(
            [nn.Conv1d(num_f_maps_input, num_classes, 1) for _ in range(self.num_heads)]
        )
        self.skip_connections = skip_connections
        if self.num_heads == 1:
            self.R = self.R[0]
            self.conv_out = self.conv_out[0]

    def _R(self, tag):
        if self.num_heads == 1:
            return self.R
        return _select_head(self.R, tag)

    def _conv_out(self, tag):
        if self.num_heads == 1:
            return self.conv_out
        return _select_head(self.conv_out, tag)

    def forward(self, x, tag=None):
        """Forward pass.

        Returns the initial prediction followed by the output of each of the
        `num_R` refinement passes, stacked along a new leading axis.
        """
        if tag is not None:
            tag = tag[0]
        out = self._conv_out(tag)(x)
        refine = self._R(tag)
        outputs = [out]
        for _ in range(self.num_R):
            # Unlike MSRefinement, the previous prediction is passed on without
            # applying `self.nl`.
            if self.skip_connections:
                out = refine(torch.cat([out, x], dim=1))
            else:
                out = refine(out)
            outputs.append(out)
        return torch.stack(outputs, dim=0)


class MSRefinementAttention(nn.Module):
    """
    Refinement stage built from squeeze-and-excitation refinement modules
    """

    def __init__(
        self,
        num_layers_R,
        num_R,
        num_f_maps_input,
        num_f_maps,
        num_classes,
        dropout_rate,
        exclusive,
        skip_connections,
        len_segment,
        block_size=0,
    ):
        """
        Parameters
        ----------
        num_layers_R : int
            number of layers in refinement modules
        num_R : int
            number of refinement modules
        num_f_maps_input : int
            number of channels of the input features
        num_f_maps : int
            number of feature maps
        num_classes : int
            number of target classes
        dropout_rate : float
            dropout rate
        exclusive : bool
            set `False` for multi-label classification
        skip_connections : bool
            if `True`, skip connections are added
        len_segment : int
            segment length, passed to the refinement modules
        block_size : int, default 0
            if not 0, skip connections are added to the prediction generation stage with this interval
        """

        super().__init__()
        self.skip_connections = skip_connections
        # Named callables instead of lambdas keep the module picklable.
        if exclusive:
            self.nl = functools.partial(F.softmax, dim=1)
        else:
            self.nl = torch.sigmoid
        self.Rs = nn.ModuleList(
            [
                Refinement_SE(
                    num_layers=num_layers_R,
                    num_f_maps=num_f_maps,
                    num_f_maps_input=num_f_maps_input,
                    dim=num_classes,
                    num_classes=num_classes,
                    dropout_rate=dropout_rate,
                    direction=None,
                    skip_connections=skip_connections,
                    block_size=block_size,
                    len_segment=len_segment,
                )
                for _ in range(num_R)
            ]
        )
        self.conv_out = nn.Conv1d(num_f_maps_input, num_classes, 1)

    def forward(self, x, tag=None):
        """Forward pass.

        Returns the initial prediction followed by the output of each refinement
        module, stacked along a new leading axis.
        """
        if tag is not None:
            tag = tag[0]
        out = self.conv_out(x)
        outputs = [out]
        for R in self.Rs:
            if self.skip_connections:
                out = R(torch.cat([self.nl(out), x], dim=1), tag)
            else:
                out = R(self.nl(out), tag)
            outputs.append(out)
        return torch.stack(outputs, dim=0)


class DilatedTCNC(nn.Module):
    """
    Prediction generation over both axes

    One stage is applied along the time axis, then the input is transposed so
    that a second stage mixes along the feature axis, and transposed back.
    """

    def __init__(
        self,
        num_f_maps,
        num_layers_PG,
        len_segment,
        block_size_prediction=5,
        kernel_size_prediction=3,
        direction_PG=None,
    ):
        """
        Parameters
        ----------
        num_f_maps : int
            number of feature maps
        num_layers_PG : int
            number of layers in each stage
        len_segment : int
            segment length (number of channels seen by the transposed stage)
        block_size_prediction : int, default 5
            if not 0, skip connections are added with this interval
        kernel_size_prediction : int, default 3
            kernel size
        direction_PG : [None, 'bidirectional', 'forward', 'backward']
            the direction of convolutions; if None, regular convolutions are used
        """
        super().__init__()
        if direction_PG == "bidirectional":
            self.PG_S = DilatedTCNB(
                num_layers=num_layers_PG,
                num_f_maps=num_f_maps,
                dim=num_f_maps,
                block_size=block_size_prediction,
                kernel_size=kernel_size_prediction,
            )
            self.PG_T = DilatedTCNB(
                num_layers=num_layers_PG,
                num_f_maps=len_segment,
                dim=len_segment,
                block_size=block_size_prediction,
                kernel_size=kernel_size_prediction,
            )
        else:
            self.PG_S = DilatedTCN(
                num_layers=num_layers_PG,
                num_f_maps=num_f_maps,
                dim=num_f_maps,
                direction=direction_PG,
                block_size=block_size_prediction,
                kernel_size=kernel_size_prediction,
            )
            self.PG_T = DilatedTCN(
                num_layers=num_layers_PG,
                num_f_maps=len_segment,
                dim=len_segment,
                direction=direction_PG,
                block_size=block_size_prediction,
                kernel_size=kernel_size_prediction,
            )

    def forward(self, x):
        """Forward pass."""
        x = self.PG_S(x)
        x = torch.transpose(x, 1, 2)
        x = self.PG_T(x)
        return torch.transpose(x, 1, 2)
class Refinement(nn.Module):
    """
    Refinement module

    A 1x1 input projection, a stack of dilated residual layers (optionally with
    block-wise skip connections), an optional convolutional attention gate and
    a 1x1 output projection to `num_classes` channels.
    """

    def __init__(
        self,
        num_layers,
        num_f_maps_input,
        num_f_maps,
        dim,
        num_classes,
        dropout_rate,
        direction,
        skip_connections,
        attention="none",
        block_size=0,
    ):
        """
        Parameters
        ----------
        num_layers : int
            the number of layers
        num_f_maps_input : int
            the number of extra input channels expected when `skip_connections` is `True`
        num_f_maps : int
            the number of feature maps
        dim : int
            the number of features in input
        num_classes : int
            the number of target classes
        dropout_rate : float
            dropout rate
        direction : [None, 'forward', 'backward']
            the direction of convolutions; if None, regular convolutions are used
        skip_connections : bool
            if `True`, the input is expected to have `dim + num_f_maps_input` channels
        attention : ["none", "basic"], default "none"
            if "basic", the features are multiplied by a sigmoid gate computed by
            two 3-wide convolutions before the output projection
        block_size : int, default 0
            if not 0, skip connections are added to the prediction generation stage with this interval

        Raises
        ------
        ValueError
            if `attention` is neither "none" nor "basic"
        """

        super(Refinement, self).__init__()
        if attention not in ("none", "basic"):
            # no gate layers exist for other values; they used to make
            # `forward` silently multiply the features by themselves
            raise ValueError(
                f"Unrecognized attention: {attention}, please choose from "
                '"none" and "basic"'
            )
        self.block_size = block_size
        self.direction = direction
        in_channels = dim + num_f_maps_input if skip_connections else dim
        self.conv_1x1 = nn.Conv1d(in_channels, num_f_maps, 1)
        self.layers = nn.ModuleList(
            [
                DilatedResidualLayer(
                    dilation=2**i,
                    in_channels=num_f_maps,
                    out_channels=num_f_maps,
                    dropout_rate=dropout_rate,
                    causal=(direction is not None),
                )
                for i in range(num_layers)
            ]
        )
        self.conv_out = nn.Conv1d(num_f_maps, num_classes, 1)
        self.attention_layers = nn.ModuleList([])
        self.attention = attention
        if self.attention == "basic":
            self.attention_layers += nn.ModuleList(
                [
                    nn.Conv1d(num_f_maps, num_f_maps, 3, padding=1),
                    nn.ReLU(),
                    nn.Conv1d(num_f_maps, num_f_maps, 3, padding=1),
                    nn.Sigmoid(),
                ]
            )

    def forward(self, x):
        """
        Forward pass.

        Parameters
        ----------
        x : torch.Tensor
            input with channels on dimension 1, as expected by `nn.Conv1d`

        Returns
        -------
        out : torch.Tensor
            output of the final 1x1 convolution (`num_classes` channels)
        """
        x = self.conv_1x1(x)
        # tensors are only rebound below, never modified in place, so plain
        # references are enough (no `copy.copy`)
        out = x
        for l_i, layer in enumerate(self.layers):
            out = layer(out, self.direction)
            if self.block_size != 0 and (l_i + 1) % self.block_size == 0:
                # skip connection over the last `block_size` layers
                out = out + x
                x = out
        if self.attention == "basic":
            gate = out
            for layer in self.attention_layers:
                gate = layer(gate)
            out = out * gate
        return self.conv_out(out)
Refinement module
def __init__(
    self,
    num_layers,
    num_f_maps_input,
    num_f_maps,
    dim,
    num_classes,
    dropout_rate,
    direction,
    skip_connections,
    attention="none",
    block_size=0,
):
    """
    Parameters
    ----------
    num_layers : int
        the number of layers
    num_f_maps_input : int
        the number of extra input channels expected when `skip_connections` is `True`
    num_f_maps : int
        the number of feature maps
    dim : int
        the number of features in input
    num_classes : int
        the number of target classes
    dropout_rate : float
        dropout rate
    direction : [None, 'forward', 'backward']
        the direction of convolutions; if None, regular convolutions are used
    skip_connections : bool
        if `True`, the input is expected to have `dim + num_f_maps_input` channels
    attention : ["none", "basic"], default "none"
        if "basic", the features are multiplied by a sigmoid gate computed by
        two 3-wide convolutions before the output projection
    block_size : int, default 0
        if not 0, skip connections are added to the prediction generation stage with this interval

    Raises
    ------
    ValueError
        if `attention` is neither "none" nor "basic"
    """

    super(Refinement, self).__init__()
    if attention not in ("none", "basic"):
        # no gate layers exist for other values; they used to make the
        # forward pass silently multiply the features by themselves
        raise ValueError(
            f"Unrecognized attention: {attention}, please choose from "
            '"none" and "basic"'
        )
    self.block_size = block_size
    self.direction = direction
    in_channels = dim + num_f_maps_input if skip_connections else dim
    self.conv_1x1 = nn.Conv1d(in_channels, num_f_maps, 1)
    # freshly constructed modules need no deep copy
    self.layers = nn.ModuleList(
        [
            DilatedResidualLayer(
                dilation=2**i,
                in_channels=num_f_maps,
                out_channels=num_f_maps,
                dropout_rate=dropout_rate,
                causal=(direction is not None),
            )
            for i in range(num_layers)
        ]
    )
    self.conv_out = nn.Conv1d(num_f_maps, num_classes, 1)
    self.attention_layers = nn.ModuleList([])
    self.attention = attention
    if self.attention == "basic":
        self.attention_layers += nn.ModuleList(
            [
                nn.Conv1d(num_f_maps, num_f_maps, 3, padding=1),
                nn.ReLU(),
                nn.Conv1d(num_f_maps, num_f_maps, 3, padding=1),
                nn.Sigmoid(),
            ]
        )
Parameters
num_layers : int
the number of layers
num_f_maps : int
the number of feature maps
dim : int
the number of features in input
num_classes : int
the number of target classes
dropout_rate : float
dropout rate
direction : [None, 'forward', 'backward']
the direction of convolutions; if None, regular convolutions are used
skip_connections : bool
if `True`, skip connections are added
block_size : int, default 0
if not 0, skip connections are added to the prediction generation stage with this interval
def forward(self, x):
    """
    Forward pass.

    Parameters
    ----------
    x : torch.Tensor
        input with channels on dimension 1, as expected by `nn.Conv1d`

    Returns
    -------
    out : torch.Tensor
        output of the final 1x1 convolution
    """
    x = self.conv_1x1(x)
    # tensors are only rebound below, never modified in place, so plain
    # references are enough (no `copy.copy`)
    out = x
    for l_i, layer in enumerate(self.layers):
        out = layer(out, self.direction)
        if self.block_size != 0 and (l_i + 1) % self.block_size == 0:
            # skip connection over the last `block_size` layers
            out = out + x
            x = out
    # only the "basic" setting builds gate layers; gating with an empty
    # layer list would multiply the features by themselves
    if self.attention == "basic":
        gate = out
        for layer in self.attention_layers:
            gate = layer(gate)
        out = out * gate
    return self.conv_out(out)
Forward pass.
class Refinement_SE(nn.Module):
    """
    Refinement module with squeeze-and-excitation scaling of its input features
    """

    def __init__(
        self,
        num_layers,
        num_f_maps_input,
        num_f_maps,
        dim,
        num_classes,
        dropout_rate,
        direction,
        skip_connections,
        len_segment,
        block_size=0,
    ):
        """
        Parameters
        ----------
        num_layers : int
            the number of layers
        num_f_maps_input : int
            the number of extra input channels expected when `skip_connections` is `True`
        num_f_maps : int
            the number of feature maps
        dim : int
            the number of features in input
        num_classes : int
            the number of target classes
        dropout_rate : float
            dropout rate
        direction : [None, 'forward', 'backward']
            the direction of convolutions; if None, regular convolutions are used
        skip_connections : bool
            if `True`, the input is expected to have `dim + num_f_maps_input` channels
        len_segment : int
            not used by this module
        block_size : int, default 0
            if not 0, skip connections are added to the prediction generation stage with this interval
        """

        super().__init__()
        self.block_size = block_size
        self.direction = direction
        if skip_connections:
            self.conv_1x1 = nn.Conv1d(dim + num_f_maps_input, num_f_maps, 1)
        else:
            self.conv_1x1 = nn.Conv1d(dim, num_f_maps, 1)
        self.layers = nn.ModuleList(
            [
                DilatedResidualLayer(
                    dilation=2**i,
                    in_channels=num_f_maps,
                    out_channels=num_f_maps,
                    dropout_rate=dropout_rate,
                    causal=(direction is not None),
                )
                for i in range(num_layers)
            ]
        )
        self.conv_out = nn.Conv1d(num_f_maps, num_classes, 1)
        # six copies of the excitation layers, selected by `tag` in `forward`
        self.fc1_f = nn.ModuleList(
            [nn.Linear(num_f_maps, num_f_maps // 2) for _ in range(6)]
        )
        self.fc2_f = nn.ModuleList(
            [nn.Linear(num_f_maps // 2, num_f_maps) for _ in range(6)]
        )

    @staticmethod
    def _select(modules, tag):
        """
        Return `modules[tag]`; if `tag` is None, copy the weights of the first
        module into all the others and return the first one.
        """
        if tag is None:
            first = modules[0]
            for module in modules[1:]:
                module.load_state_dict(first.state_dict())
            return first
        return modules[tag]

    def _fc1_f(self, tag):
        return self._select(self.fc1_f, tag)

    def _fc2_f(self, tag):
        return self._select(self.fc2_f, tag)

    def forward(self, x, tag):
        """
        Forward pass.

        Parameters
        ----------
        x : torch.Tensor
            input with channels on dimension 1, as expected by `nn.Conv1d`
        tag : int or None
            index of the excitation layers to use; if None, all copies are tied
            to the first one

        Returns
        -------
        out : torch.Tensor
            output of the final 1x1 convolution (`num_classes` channels)
        """
        x = self.conv_1x1(x)
        # squeeze: average over the last (sequence) axis; excite: two linear
        # layers and a sigmoid giving one weight per channel
        scale = torch.mean(x, -1)
        scale = self._fc2_f(tag)(F.relu(self._fc1_f(tag)(scale)))
        x = torch.sigmoid(scale).unsqueeze(-1) * x
        # tensors are only rebound, never modified in place: no copies needed
        out = x
        for l_i, layer in enumerate(self.layers):
            out = layer(out, self.direction)
            if self.block_size != 0 and (l_i + 1) % self.block_size == 0:
                out = out + x
                x = out
        return self.conv_out(out)
Refinement module with squeeze-and-excitation scaling of its input features
def __init__(
    self,
    num_layers,
    num_f_maps_input,
    num_f_maps,
    dim,
    num_classes,
    dropout_rate,
    direction,
    skip_connections,
    len_segment,
    block_size=0,
):
    """
    Build a refinement stage with squeeze-and-excitation input scaling.

    Parameters
    ----------
    num_layers : int
        the number of dilated residual layers
    num_f_maps_input : int
        the number of extra input channels expected when `skip_connections` is `True`
    num_f_maps : int
        the number of feature maps
    dim : int
        the number of features in input
    num_classes : int
        the number of target classes
    dropout_rate : float
        dropout rate
    direction : [None, 'forward', 'backward']
        the direction of convolutions; if None, regular convolutions are used
    skip_connections : bool
        if `True`, the input is expected to have `dim + num_f_maps_input` channels
    len_segment : int
        not used here
    block_size : int, default 0
        if not 0, skip connections are added to the prediction generation stage with this interval
    """

    super().__init__()
    self.block_size = block_size
    self.direction = direction
    input_channels = dim + num_f_maps_input if skip_connections else dim
    self.conv_1x1 = nn.Conv1d(input_channels, num_f_maps, 1)
    residual_layers = []
    for depth in range(num_layers):
        residual_layers.append(
            DilatedResidualLayer(
                dilation=2**depth,
                in_channels=num_f_maps,
                out_channels=num_f_maps,
                dropout_rate=dropout_rate,
                causal=direction is not None,
            )
        )
    self.layers = nn.ModuleList(residual_layers)
    self.conv_out = nn.Conv1d(num_f_maps, num_classes, 1)
    # six tag-indexed copies of the two excitation layers
    hidden = num_f_maps // 2
    self.fc1_f = nn.ModuleList([nn.Linear(num_f_maps, hidden) for _ in range(6)])
    self.fc2_f = nn.ModuleList([nn.Linear(hidden, num_f_maps) for _ in range(6)])
Parameters
num_layers : int
the number of layers
num_f_maps : int
the number of feature maps
dim : int
the number of features in input
num_classes : int
the number of target classes
dropout_rate : float
dropout rate
direction : [None, 'forward', 'backward']
the direction of convolutions; if None, regular convolutions are used
skip_connections : bool
if `True`, skip connections are added
block_size : int, default 0
if not 0, skip connections are added to the prediction generation stage with this interval
def forward(self, x, tag):
    """
    Forward pass.

    Parameters
    ----------
    x : torch.Tensor
        input with channels on dimension 1, as expected by `nn.Conv1d`
    tag : int or None
        selects the excitation layers (see `_fc1_f` / `_fc2_f`)

    Returns
    -------
    out : torch.Tensor
        output of the final 1x1 convolution
    """
    x = self.conv_1x1(x)
    # squeeze: average over the last (sequence) axis; excite: two linear
    # layers and a sigmoid giving one weight per channel
    scale = torch.mean(x, -1)
    scale = self._fc2_f(tag)(F.relu(self._fc1_f(tag)(scale)))
    x = torch.sigmoid(scale).unsqueeze(-1) * x
    # tensors are only rebound, never modified in place: no copies needed
    out = x
    for l_i, layer in enumerate(self.layers):
        out = layer(out, self.direction)
        if self.block_size != 0 and (l_i + 1) % self.block_size == 0:
            out = out + x
            x = out
    return self.conv_out(out)
Forward pass.
class RefinementB(Refinement):
    """
    Bidirectional refinement module

    The shared residual layers are run once with "forward" and once with
    "backward" padding, and the two streams are concatenated along the channel
    dimension before the output projection. The residual layers must be causal
    (`direction` not None), otherwise they reject an explicit direction. The
    attention gate of `Refinement` is not applied here.
    """

    def __init__(self, *args, **kwargs):
        """
        Accept the same arguments as `Refinement`.

        The inherited output projection takes `num_f_maps` channels, but
        `forward` feeds it two concatenated `num_f_maps`-channel streams, so it
        is rebuilt with twice as many input channels.
        """
        super().__init__(*args, **kwargs)
        self.conv_out = nn.Conv1d(
            2 * self.conv_out.in_channels, self.conv_out.out_channels, 1
        )

    def forward(self, x):
        """
        Forward pass.

        Parameters
        ----------
        x : torch.Tensor
            input with channels on dimension 1, as expected by `nn.Conv1d`

        Returns
        -------
        out : torch.Tensor
            output of the final 1x1 convolution
        """
        x_f = self.conv_1x1(x)
        x_b = x_f
        fwd = x_f
        bwd = x_f
        for i, layer in enumerate(self.layers):
            fwd = layer(fwd, "forward")
            bwd = layer(bwd, "backward")
            if self.block_size != 0 and (i + 1) % self.block_size == 0:
                fwd = fwd + x_f
                bwd = bwd + x_b
                x_f = fwd
                x_b = bwd
        out = torch.cat([fwd, bwd], 1)
        return self.conv_out(out)
Bidirectional refinement module
def forward(self, x):
    """
    Forward pass.

    The shared residual layers are applied twice, once with "forward" and once
    with "backward" padding, and the two streams are concatenated along the
    channel dimension before the output projection.

    NOTE(review): `self.conv_out` must accept `2 * num_f_maps` input channels;
    the projection built by `Refinement.__init__` takes only `num_f_maps`.
    """
    x_f = self.conv_1x1(x)
    x_b = x_f
    fwd = x_f
    bwd = x_f
    for i, layer in enumerate(self.layers):
        fwd = layer(fwd, "forward")
        bwd = layer(bwd, "backward")
        if self.block_size != 0 and (i + 1) % self.block_size == 0:
            fwd = fwd + x_f
            bwd = bwd + x_b
            x_f = fwd
            x_b = bwd
    out = torch.cat([fwd, bwd], 1)
    return self.conv_out(out)
Forward pass.
Inherited Members
class SimpleResidualLayer(nn.Module):
    """
    Basic residual layer: two 1x1 convolutions with a ReLU in between,
    dropout, and an identity shortcut
    """

    def __init__(self, num_f_maps, dropout_rate):
        """
        Parameters
        ----------
        num_f_maps : int
            number of input and output channels
        dropout_rate : float
            dropout rate
        """

        super().__init__()
        channels = num_f_maps
        self.conv_1x1_in = nn.Conv1d(channels, channels, 1)
        self.conv_1x1 = nn.Conv1d(channels, channels, 1)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x):
        """Forward pass: return `x` plus the dropped-out two-convolution branch."""
        hidden = F.relu(self.conv_1x1_in(x))
        residual = self.dropout(self.conv_1x1(hidden))
        return x + residual
Basic residual layer
def __init__(self, num_f_maps, dropout_rate):
    """
    Create the two 1x1 convolutions and the dropout of the layer.

    Parameters
    ----------
    num_f_maps : int
        number of input and output channels
    dropout_rate : float
        dropout rate
    """

    super().__init__()
    width = num_f_maps
    self.conv_1x1_in = nn.Conv1d(width, width, kernel_size=1)
    self.conv_1x1 = nn.Conv1d(width, width, kernel_size=1)
    self.dropout = nn.Dropout(p=dropout_rate)
Parameters
num_f_maps : int
number of input and output channels
dropout_rate : float
dropout rate
class DilatedResidualLayer(nn.Module):
    """
    Dilated residual layer: a dilated 3-wide convolution, ReLU, a 1x1
    convolution and dropout, added to the input
    """

    def __init__(self, dilation, in_channels, out_channels, dropout_rate, causal):
        """
        Parameters
        ----------
        dilation : int
            dilation
        in_channels : int
            number of input channels
        out_channels : int
            number of output channels
        dropout_rate : float
            dropout rate
        causal : bool
            if `True`, causal convolutions are used
        """

        super(DilatedResidualLayer, self).__init__()
        # one-sided padding applied in `forward` for causal layers: a 3-wide
        # kernel with this dilation spans 2 * dilation extra steps
        self.padding = dilation * 2
        self.causal = causal
        # regular layers use symmetric padding inside the convolution
        padding = 0 if self.causal else dilation
        self.conv_dilated = nn.Conv1d(
            in_channels, out_channels, 3, padding=padding, dilation=dilation
        )
        self.conv_1x1 = nn.Conv1d(out_channels, out_channels, 1)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x, direction=None):
        """
        Forward pass.

        Parameters
        ----------
        x : torch.Tensor
            input with channels on dimension 1, as expected by `nn.Conv1d`
        direction : [None, 'forward', 'backward'], default None
            for causal layers, which side is padded ("forward" if None);
            must be None for non-causal layers

        Returns
        -------
        out : torch.Tensor
            `x` plus the output of the residual branch

        Raises
        ------
        ValueError
            if a direction is set on a non-causal layer or is not recognized
        """
        if direction is not None and not self.causal:
            raise ValueError("Cannot set direction in a non-causal layer!")
        if direction is None and self.causal:
            direction = "forward"
        if direction == "forward":
            # right padding: output t sees inputs t .. t + 2 * dilation
            conv_in = F.pad(x, (0, self.padding))
        elif direction == "backward":
            # left padding: output t sees inputs t - 2 * dilation .. t
            conv_in = F.pad(x, (self.padding, 0))
        elif direction is None:
            conv_in = x
        else:
            raise ValueError(
                f"Unrecognized direction: {direction}, please choose from "
                '"backward", "forward" and None'
            )
        out = F.relu(self.conv_dilated(conv_in))
        out = self.dropout(self.conv_1x1(out))
        return x + out
Dilated residual layer
def __init__(self, dilation, in_channels, out_channels, dropout_rate, causal):
    """
    Create the convolutions of a dilated residual layer.

    Parameters
    ----------
    dilation : int
        dilation of the 3-wide convolution
    in_channels : int
        number of input channels
    out_channels : int
        number of output channels
    dropout_rate : float
        dropout rate
    causal : bool
        if `True`, causal convolutions are used (the convolution itself is then
        unpadded and one-sided padding of `self.padding` is applied in `forward`)
    """

    super(DilatedResidualLayer, self).__init__()
    self.padding = 2 * dilation
    self.causal = causal
    symmetric_padding = 0 if causal else dilation
    self.conv_dilated = nn.Conv1d(
        in_channels,
        out_channels,
        kernel_size=3,
        padding=symmetric_padding,
        dilation=dilation,
    )
    self.conv_1x1 = nn.Conv1d(out_channels, out_channels, kernel_size=1)
    self.dropout = nn.Dropout(dropout_rate)
Parameters
dilation : int
dilation
in_channels : int
number of input channels
out_channels : int
number of output channels
dropout_rate : float
dropout rate
causal : bool
if `True`, causal convolutions are used
def forward(self, x, direction=None):
    """
    Forward pass.

    Parameters
    ----------
    x : torch.Tensor
        input with channels on dimension 1, as expected by `nn.Conv1d`
    direction : [None, 'forward', 'backward'], default None
        for causal layers, which side is padded ("forward" if None);
        must be None for non-causal layers

    Returns
    -------
    out : torch.Tensor
        `x` plus the output of the residual branch

    Raises
    ------
    ValueError
        if a direction is set on a non-causal layer or is not recognized
    """
    if direction is not None and not self.causal:
        raise ValueError("Cannot set direction in a non-causal layer!")
    if direction is None and self.causal:
        direction = "forward"
    if direction == "forward":
        # right padding: output t sees inputs t .. t + self.padding
        conv_in = F.pad(x, (0, self.padding))
    elif direction == "backward":
        # left padding: output t sees inputs t - self.padding .. t
        conv_in = F.pad(x, (self.padding, 0))
    elif direction is None:
        conv_in = x
    else:
        raise ValueError(
            f"Unrecognized direction: {direction}, please choose from "
            '"backward", "forward" and None'
        )
    out = F.relu(self.conv_dilated(conv_in))
    out = self.dropout(self.conv_1x1(out))
    return x + out
Forward pass.
class DilatedResidualLayer_SE(nn.Module):
    """
    Dilated residual layer with squeeze-and-excitation channel scaling
    """

    def __init__(
        self, dilation, in_channels, out_channels, dropout_rate, causal, len_segment
    ):
        """
        Parameters
        ----------
        dilation : int
            dilation
        in_channels : int
            number of input channels
        out_channels : int
            number of output channels
        dropout_rate : float
            dropout rate
        causal : bool
            if `True`, causal convolutions are used
        len_segment : int
            not used by this module
        """

        super().__init__()
        # one-sided padding applied in `forward` for causal layers
        self.padding = dilation * 2
        self.causal = causal
        padding = 0 if self.causal else dilation
        self.conv_dilated = nn.Conv1d(
            in_channels, out_channels, 3, padding=padding, dilation=dilation
        )
        self.conv_1x1 = nn.Conv1d(out_channels, out_channels, 1)
        self.dropout = nn.Dropout(dropout_rate)
        # six copies of the excitation layers, selected by `tag` in `forward`
        self.fc1_f = nn.ModuleList(
            [nn.Linear(out_channels, out_channels // 2) for _ in range(6)]
        )
        self.fc2_f = nn.ModuleList(
            [nn.Linear(out_channels // 2, out_channels) for _ in range(6)]
        )

    @staticmethod
    def _select(modules, tag):
        """
        Return `modules[tag]`; if `tag` is None, copy the weights of the first
        module into all the others and return the first one.
        """
        if tag is None:
            first = modules[0]
            for module in modules[1:]:
                module.load_state_dict(first.state_dict())
            return first
        return modules[tag]

    def _fc1_f(self, tag):
        return self._select(self.fc1_f, tag)

    def _fc2_f(self, tag):
        return self._select(self.fc2_f, tag)

    def forward(self, x, direction, tag):
        """
        Forward pass.

        Parameters
        ----------
        x : torch.Tensor
            input with channels on dimension 1, as expected by `nn.Conv1d`
        direction : [None, 'forward', 'backward']
            for causal layers, which side is padded ("forward" if None);
            must be None for non-causal layers
        tag : int or None
            index of the excitation layers to use; if None, all copies are tied
            to the first one

        Raises
        ------
        ValueError
            if a direction is set on a non-causal layer or is not recognized
        """
        if direction is not None and not self.causal:
            raise ValueError("Cannot set direction in a non-causal layer!")
        if direction is None and self.causal:
            direction = "forward"
        if direction == "forward":
            out = self.conv_dilated(F.pad(x, (0, self.padding)))
        elif direction == "backward":
            out = self.conv_dilated(F.pad(x, (self.padding, 0)))
        elif direction is None:
            out = self.conv_dilated(x)
        else:
            raise ValueError(
                f"Unrecognized direction: {direction}, please choose from "
                '"backward", "forward" and None'
            )
        out = F.relu(out)
        out = self.dropout(self.conv_1x1(out))
        # squeeze: average over the last (sequence) axis; excite: two linear
        # layers and a sigmoid giving one weight per channel
        scale = torch.mean(out, -1)
        scale = self._fc2_f(tag)(F.relu(self._fc1_f(tag)(scale)))
        out = out * torch.sigmoid(scale).unsqueeze(-1)
        return x + out
Dilated residual layer with squeeze-and-excitation channel scaling
def __init__(
    self, dilation, in_channels, out_channels, dropout_rate, causal, len_segment
):
    """
    Create the convolutions and the tag-indexed excitation layers.

    Parameters
    ----------
    dilation : int
        dilation of the 3-wide convolution
    in_channels : int
        number of input channels
    out_channels : int
        number of output channels
    dropout_rate : float
        dropout rate
    causal : bool
        if `True`, causal convolutions are used
    len_segment : int
        not used here
    """

    super().__init__()
    self.padding = 2 * dilation
    self.causal = causal
    symmetric_padding = 0 if causal else dilation
    self.conv_dilated = nn.Conv1d(
        in_channels,
        out_channels,
        kernel_size=3,
        padding=symmetric_padding,
        dilation=dilation,
    )
    self.conv_1x1 = nn.Conv1d(out_channels, out_channels, kernel_size=1)
    self.dropout = nn.Dropout(dropout_rate)
    reduced = out_channels // 2
    self.fc1_f = nn.ModuleList([nn.Linear(out_channels, reduced) for _ in range(6)])
    self.fc2_f = nn.ModuleList([nn.Linear(reduced, out_channels) for _ in range(6)])
Parameters
dilation : int
dilation
in_channels : int
number of input channels
out_channels : int
number of output channels
dropout_rate : float
dropout rate
causal : bool
if `True`, causal convolutions are used
def forward(self, x, direction, tag):
    """
    Forward pass.

    Parameters
    ----------
    x : torch.Tensor
        input with channels on dimension 1, as expected by `nn.Conv1d`
    direction : [None, 'forward', 'backward']
        for causal layers, which side is padded ("forward" if None);
        must be None for non-causal layers
    tag : int or None
        selects the excitation layers (see `_fc1_f` / `_fc2_f`)

    Raises
    ------
    ValueError
        if a direction is set on a non-causal layer or is not recognized
    """
    if direction is not None and not self.causal:
        raise ValueError("Cannot set direction in a non-causal layer!")
    if direction is None and self.causal:
        direction = "forward"
    if direction == "forward":
        out = self.conv_dilated(F.pad(x, (0, self.padding)))
    elif direction == "backward":
        out = self.conv_dilated(F.pad(x, (self.padding, 0)))
    elif direction is None:
        out = self.conv_dilated(x)
    else:
        raise ValueError(
            f"Unrecognized direction: {direction}, please choose from "
            '"backward", "forward" and None'
        )
    out = F.relu(out)
    out = self.dropout(self.conv_1x1(out))
    # squeeze: average over the last (sequence) axis; excite: two linear
    # layers and a sigmoid giving one weight per channel
    scale = torch.mean(out, -1)
    scale = self._fc2_f(tag)(F.relu(self._fc1_f(tag)(scale)))
    out = out * torch.sigmoid(scale).unsqueeze(-1)
    return x + out
Forward pass.
class DilatedResidualLayer_SEC(nn.Module):
    """
    Dilated residual layer whose branch output is multiplied by a sigmoid gate
    computed by a single-output-channel convolution
    """

    def __init__(
        self, dilation, in_channels, out_channels, dropout_rate, causal, len_segment
    ):
        """
        Parameters
        ----------
        dilation : int
            dilation
        in_channels : int
            number of input channels
        out_channels : int
            number of output channels
        dropout_rate : float
            dropout rate
        causal : bool
            if `True`, causal convolutions are used
        len_segment : int
            not used by this module
        """

        super().__init__()
        # one-sided padding applied in `forward` for causal layers
        self.padding = dilation * 2
        self.causal = causal
        padding = 0 if self.causal else dilation
        self.conv_dilated = nn.Conv1d(
            in_channels, out_channels, 3, padding=padding, dilation=dilation
        )
        self.conv_1x1 = nn.Conv1d(out_channels, out_channels, 1)
        self.dropout = nn.Dropout(dropout_rate)
        # six copies of the gating convolution, selected by `tag` in `forward`
        self.conv_sq = nn.ModuleList(
            [
                nn.Conv1d(out_channels, 1, 3, padding=padding, dilation=dilation)
                for _ in range(6)
            ]
        )

    def _conv_sq(self, tag):
        """
        Return `self.conv_sq[tag]`; if `tag` is None, copy the weights of the
        first convolution into all the others and return the first one.
        """
        if tag is None:
            first = self.conv_sq[0]
            for module in self.conv_sq[1:]:
                module.load_state_dict(first.state_dict())
            return first
        return self.conv_sq[tag]

    def forward(self, x, direction, tag):
        """
        Forward pass.

        Parameters
        ----------
        x : torch.Tensor
            input with channels on dimension 1, as expected by `nn.Conv1d`
        direction : [None, 'forward', 'backward']
            for causal layers, which side is padded ("forward" if None);
            must be None for non-causal layers
        tag : int or None
            index of the gating convolution to use

        Raises
        ------
        ValueError
            if a direction is set on a non-causal layer or is not recognized
        """
        if direction is not None and not self.causal:
            raise ValueError("Cannot set direction in a non-causal layer!")
        if direction is None and self.causal:
            direction = "forward"
        if direction == "forward":
            padding = (0, self.padding)
        elif direction == "backward":
            padding = (self.padding, 0)
        elif direction is None:
            padding = None
        else:
            raise ValueError(
                f"Unrecognized direction: {direction}, please choose from "
                '"backward", "forward" and None'
            )
        out = x if padding is None else F.pad(x, padding)
        out = F.relu(self.conv_dilated(out))
        out = self.dropout(self.conv_1x1(out))
        # the gating convolution has the same geometry as `conv_dilated` (no
        # built-in padding when causal), so it needs the same one-sided padding
        # for the gate to match the length of `out`
        gate_in = out if padding is None else F.pad(out, padding)
        scale = torch.sigmoid(self._conv_sq(tag)(gate_in))
        return x + out * scale
Dilated residual layer with a sigmoid gate computed by a single-output-channel convolution
def __init__(
    self, dilation, in_channels, out_channels, dropout_rate, causal, len_segment
):
    """
    Create the convolutions and the tag-indexed gating convolutions.

    Parameters
    ----------
    dilation : int
        dilation of the 3-wide convolutions
    in_channels : int
        number of input channels
    out_channels : int
        number of output channels
    dropout_rate : float
        dropout rate
    causal : bool
        if `True`, causal convolutions are used
    len_segment : int
        not used here
    """

    super().__init__()
    self.padding = 2 * dilation
    self.causal = causal
    symmetric_padding = 0 if causal else dilation
    self.conv_dilated = nn.Conv1d(
        in_channels,
        out_channels,
        kernel_size=3,
        padding=symmetric_padding,
        dilation=dilation,
    )
    self.conv_1x1 = nn.Conv1d(out_channels, out_channels, kernel_size=1)
    self.dropout = nn.Dropout(dropout_rate)
    gates = []
    for _ in range(6):
        gates.append(
            nn.Conv1d(
                out_channels,
                1,
                kernel_size=3,
                padding=symmetric_padding,
                dilation=dilation,
            )
        )
    self.conv_sq = nn.ModuleList(gates)
Parameters
dilation : int
dilation
in_channels : int
number of input channels
out_channels : int
number of output channels
dropout_rate : float
dropout rate
causal : bool
if `True`, causal convolutions are used
def forward(self, x, direction, tag):
    """
    Forward pass.

    Parameters
    ----------
    x : torch.Tensor
        input with channels on dimension 1, as expected by `nn.Conv1d`
    direction : [None, 'forward', 'backward']
        for causal layers, which side is padded ("forward" if None);
        must be None for non-causal layers
    tag : int or None
        selects the gating convolution (see `_conv_sq`)

    Raises
    ------
    ValueError
        if a direction is set on a non-causal layer or is not recognized
    """
    if direction is not None and not self.causal:
        raise ValueError("Cannot set direction in a non-causal layer!")
    if direction is None and self.causal:
        direction = "forward"
    if direction == "forward":
        padding = (0, self.padding)
    elif direction == "backward":
        padding = (self.padding, 0)
    elif direction is None:
        padding = None
    else:
        raise ValueError(
            f"Unrecognized direction: {direction}, please choose from "
            '"backward", "forward" and None'
        )
    out = x if padding is None else F.pad(x, padding)
    out = F.relu(self.conv_dilated(out))
    out = self.dropout(self.conv_1x1(out))
    # the gating convolution is unpadded in causal layers, so it gets the
    # same one-sided padding to keep the gate as long as `out`
    gate_in = out if padding is None else F.pad(out, padding)
    scale = torch.sigmoid(self._conv_sq(tag)(gate_in))
    return x + out * scale
Forward pass.
class DualDilatedResidualLayer(nn.Module):
    """
    Dual dilated residual layer: two dilated convolutions with different
    dilations, fused by a 1x1 convolution, followed by ReLU, dropout and an
    identity shortcut
    """

    def __init__(
        self,
        dilation1,
        dilation2,
        in_channels,
        out_channels,
        dropout_rate,
        causal,
        kernel_size=3,
    ):
        """
        Parameters
        ----------
        dilation1, dilation2 : int
            dilation of one of the blocks
        in_channels : int
            number of input channels
        out_channels : int
            number of output channels
        dropout_rate : float
            dropout rate
        causal : bool
            if `True`, causal convolutions are used
        kernel_size : int, default 3
            kernel size; in non-causal mode `dilation * (kernel_size - 1)` must be
            even for the output to keep the input length (always true for odd
            kernel sizes)
        """

        super(DualDilatedResidualLayer, self).__init__()
        self.causal = causal
        if self.causal:
            # pad by the full receptive span on both sides; `forward` trims one
            # side, which preserves the length for odd and even kernel sizes
            self.padding1 = dilation1 * (kernel_size - 1)
            self.padding2 = dilation2 * (kernel_size - 1)
        else:
            self.padding1 = dilation1 * (kernel_size - 1) // 2
            self.padding2 = dilation2 * (kernel_size - 1) // 2
        self.conv_dilated_1 = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            padding=self.padding1,
            dilation=dilation1,
        )
        self.conv_dilated_2 = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            padding=self.padding2,
            dilation=dilation2,
        )
        self.conv_fusion = nn.Conv1d(2 * out_channels, out_channels, 1)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x, direction=None):
        """
        Forward pass.

        Parameters
        ----------
        x : torch.Tensor
            input with channels on dimension 1 and time on the last dimension
        direction : [None, 'forward', 'backward'], default None
            for causal layers, which outputs are kept ("forward" if None);
            must be None for non-causal layers

        Raises
        ------
        ValueError
            if a direction is set on a non-causal layer or is not recognized
        """
        if direction is not None and not self.causal:
            raise ValueError("Cannot set direction in a non-causal layer!")
        if direction is None and self.causal:
            direction = "forward"
        if direction not in ("backward", "forward", None):
            raise ValueError(
                f"Unrecognized direction: {direction}, please choose from "
                '"backward", "forward" and None'
            )
        length = x.shape[-1]
        out_1 = self.conv_dilated_1(x)
        out_2 = self.conv_dilated_2(x)
        # trimming to the input length also works when the padding is 0
        # (slicing with `[:-0]` would return an empty tensor)
        if direction == "forward":
            # keep the first outputs: output t only sees inputs up to t
            out_1 = out_1[..., :length]
            out_2 = out_2[..., :length]
        elif direction == "backward":
            # keep the last outputs: output t only sees inputs from t on
            out_1 = out_1[..., -length:]
            out_2 = out_2[..., -length:]
        out = self.conv_fusion(torch.cat([out_1, out_2], 1))
        out = self.dropout(F.relu(out))
        return out + x
Dual dilated residual layer
def __init__(
    self,
    dilation1,
    dilation2,
    in_channels,
    out_channels,
    dropout_rate,
    causal,
    kernel_size=3,
):
    """
    Parameters
    ----------
    dilation1, dilation2 : int
        dilation of one of the blocks
    in_channels : int
        number of input channels
    out_channels : int
        number of output channels
    dropout_rate : float
        dropout rate
    causal : bool
        if `True`, causal convolutions are used
    kernel_size : int, default 3
        kernel size; in non-causal mode `dilation * (kernel_size - 1)` must be
        even for the output to keep the input length (always true for odd
        kernel sizes)
    """

    super(DualDilatedResidualLayer, self).__init__()
    self.causal = causal
    if self.causal:
        # pad by the full receptive span on both sides; the forward pass trims
        # this many outputs from one side. Equal to the former
        # `d * (k - 1) // 2 * 2` for odd kernel sizes, correct for even ones.
        self.padding1 = dilation1 * (kernel_size - 1)
        self.padding2 = dilation2 * (kernel_size - 1)
    else:
        self.padding1 = dilation1 * (kernel_size - 1) // 2
        self.padding2 = dilation2 * (kernel_size - 1) // 2
    self.conv_dilated_1 = nn.Conv1d(
        in_channels,
        out_channels,
        kernel_size=kernel_size,
        padding=self.padding1,
        dilation=dilation1,
    )
    self.conv_dilated_2 = nn.Conv1d(
        in_channels,
        out_channels,
        kernel_size=kernel_size,
        padding=self.padding2,
        dilation=dilation2,
    )
    self.conv_fusion = nn.Conv1d(2 * out_channels, out_channels, 1)
    self.dropout = nn.Dropout(dropout_rate)
Parameters
dilation1, dilation2 : int
dilation of one of the blocks
in_channels : int
number of input channels
out_channels : int
number of output channels
dropout_rate : float
dropout rate
causal : bool
if `True`, causal convolutions are used
kernel_size : int, default 3
kernel size
def forward(self, x, direction=None):
    """
    Forward pass.

    Parameters
    ----------
    x : torch.Tensor
        input with channels on dimension 1 and time on the last dimension
    direction : [None, 'forward', 'backward'], default None
        for causal layers, which outputs are kept ("forward" if None);
        must be None for non-causal layers

    Raises
    ------
    ValueError
        if a direction is set on a non-causal layer or is not recognized
    """
    if direction is not None and not self.causal:
        raise ValueError("Cannot set direction in a non-causal layer!")
    if direction is None and self.causal:
        direction = "forward"
    if direction not in ("backward", "forward", None):
        raise ValueError(
            f"Unrecognized direction: {direction}, please choose from "
            '"backward", "forward" and None'
        )
    length = x.shape[-1]
    out_1 = self.conv_dilated_1(x)
    out_2 = self.conv_dilated_2(x)
    # trim relative to the input length: slicing with `[:-padding]` returns an
    # empty tensor when the padding is 0 (e.g. kernel_size=1)
    if direction == "forward":
        out_1 = out_1[..., :length]
        out_2 = out_2[..., :length]
    elif direction == "backward":
        out_1 = out_1[..., -length:]
        out_2 = out_2[..., -length:]
    out = self.conv_fusion(torch.cat([out_1, out_2], 1))
    out = self.dropout(F.relu(out))
    return out + x
Forward pass.
class MultiDilatedTCN(nn.Module):
    """
    Multiple prediction generation stages in parallel
    """

    def __init__(
        self,
        num_layers,
        num_f_maps,
        dims,
        direction,
        block_size=5,
        kernel_size=3,
        rare_dilations=False,
    ):
        """
        Parameters
        ----------
        num_layers : int
            number of layers in each stage
        num_f_maps : int
            number of feature maps in each stage and in the fused output
        dims : sequence of int
            number of input features of each stage (one stage per entry)
        direction : [None, 'forward', 'backward']
            the direction of convolutions; if None, regular convolutions are used
        block_size : int, default 5
            if not 0, skip connections are added inside each stage with this interval
        kernel_size : int, default 3
            kernel size
        rare_dilations : bool, default False
            if `False`, dilation increases every layer, otherwise every second layer
        """
        super(MultiDilatedTCN, self).__init__()
        # keyword arguments are required: `DilatedTCN` takes `num_bp` right
        # after `direction`, so positional passing shifted every value by one
        self.PGs = nn.ModuleList(
            [
                DilatedTCN(
                    num_layers=num_layers,
                    num_f_maps=num_f_maps,
                    dim=dim,
                    direction=direction,
                    block_size=block_size,
                    kernel_size=kernel_size,
                    rare_dilations=rare_dilations,
                )
                for dim in dims
            ]
        )
        self.conv_out = nn.Conv1d(num_f_maps * len(dims), num_f_maps, 1)

    def forward(self, x, tag=None):
        """
        Forward pass.

        Parameters
        ----------
        x : sequence of torch.Tensor
            one input per stage, in the order of `dims`
        tag : optional
            not used by this module

        Returns
        -------
        out : torch.Tensor
            the stage outputs concatenated on dimension 1 and fused by a 1x1
            convolution into `num_f_maps` channels
        """
        out = torch.cat([PG(arr) for arr, PG in zip(x, self.PGs)], 1)
        return self.conv_out(out)
Multiple prediction generation stages in parallel
def __init__(
    self,
    num_layers,
    num_f_maps,
    dims,
    direction,
    block_size=5,
    kernel_size=3,
    rare_dilations=False,
):
    """
    Build one `DilatedTCN` stage per entry of `dims` and a 1x1 convolution
    that fuses their concatenated outputs.

    Parameters
    ----------
    num_layers : int
        number of layers in each stage
    num_f_maps : int
        number of feature maps in each stage and in the fused output
    dims : sequence of int
        number of input features of each stage
    direction : [None, 'forward', 'backward']
        the direction of convolutions; if None, regular convolutions are used
    block_size : int, default 5
        if not 0, skip connections are added inside each stage with this interval
    kernel_size : int, default 3
        kernel size
    rare_dilations : bool, default False
        if `False`, dilation increases every layer, otherwise every second layer
    """
    super(MultiDilatedTCN, self).__init__()
    # keyword arguments are required: `DilatedTCN` takes `num_bp` right after
    # `direction`, so positional passing shifted every value by one
    self.PGs = nn.ModuleList(
        [
            DilatedTCN(
                num_layers=num_layers,
                num_f_maps=num_f_maps,
                dim=dim,
                direction=direction,
                block_size=block_size,
                kernel_size=kernel_size,
                rare_dilations=rare_dilations,
            )
            for dim in dims
        ]
    )
    self.conv_out = nn.Conv1d(num_f_maps * len(dims), num_f_maps, 1)
Build one `DilatedTCN` prediction generation stage per entry of `dims` and a 1x1 convolution that fuses their concatenated outputs into `num_f_maps` channels.
class DilatedTCN(nn.Module):
    """
    Prediction generation stage
    """

    def __init__(
        self,
        num_layers,
        num_f_maps,
        dim,
        direction,
        num_bp=None,
        block_size=5,
        kernel_size=3,
        rare_dilations=False,
        attention="none",
        multihead=False,
    ):
        """
        Parameters
        ----------
        num_layers : int
            number of layers
        num_f_maps : int
            number of feature maps
        dim : int
            number of features in input
        direction : [None, 'forward', 'backward']
            the direction of convolutions; if None, regular convolutions are used
        num_bp : optional
            not used by this module
        block_size : int, default 5
            if not 0, skip connections are added to the prediction generation stage with this interval
        kernel_size : int, default 3
            kernel size
        rare_dilations : bool, default False
            if `False`, dilation increases every layer, otherwise every second layer
        attention : str, default "none"
            "none" for no attention, "basic" for a convolutional sigmoid gate;
            any other string is passed to `AttLayer` as `att_type`
        multihead : bool, default False
            not supported together with `attention="basic"`

        Raises
        ------
        NotImplementedError
            if `multihead` is `True` and `attention` is "basic"
        """

        super().__init__()
        if multihead and attention == "basic":
            # the multi-head path relied on layers that were never built
            # (`conv_att`, `dropout`, nested gate lists) and failed at runtime
            raise NotImplementedError(
                "Multihead basic attention is not supported in DilatedTCN"
            )
        self.num_layers = num_layers
        self.block_size = block_size
        self.direction = direction
        self.conv_1x1_in = nn.Conv1d(dim, num_f_maps, 1)
        pars = {
            "in_channels": num_f_maps,
            "out_channels": num_f_maps,
            "dropout_rate": 0.5,
            "causal": (direction is not None),
            "kernel_size": kernel_size,
        }
        if not rare_dilations:
            dilations = [(2 ** (num_layers - 1 - i), 2**i) for i in range(num_layers)]
        else:
            # dilations change every second layer; `(num_layers - 1) // 2`
            # equals `num_layers // 2 - 1` for even `num_layers` and keeps the
            # exponent non-negative (integer dilation) for odd `num_layers`
            top = (num_layers - 1) // 2
            dilations = [
                (2 ** (top - i // 2), 2 ** (i // 2)) for i in range(num_layers)
            ]
        self.layers = nn.ModuleList(
            [
                DualDilatedResidualLayer(dilation1=d1, dilation2=d2, **pars)
                for d1, d2 in dilations
            ]
        )
        self.attention_layers = nn.ModuleList([])
        self.attention = attention
        if self.attention in ["basic"]:
            self.attention_layers += nn.ModuleList(
                [
                    nn.Conv1d(num_f_maps, num_f_maps, 3, padding=1),
                    nn.ReLU(),
                    nn.Conv1d(num_f_maps, num_f_maps, 3, padding=1),
                    nn.Sigmoid(),
                ]
            )
        elif isinstance(self.attention, str) and self.attention != "none":
            self.attention_layers = AttLayer(
                num_f_maps,
                num_f_maps,
                num_f_maps,
                4,
                4,
                2,
                64,
                att_type=self.attention,
                stage="encoder",
            )
        self.multihead = multihead

    def forward(self, x, tag=None):
        """
        Forward pass.

        Parameters
        ----------
        x : torch.Tensor
            input with channels on dimension 1, as expected by `nn.Conv1d`
        tag : optional
            not used by this module

        Returns
        -------
        f : torch.Tensor
            features with `num_f_maps` channels
        """
        x = self.conv_1x1_in(x)
        # tensors are only rebound, never modified in place: no copies needed
        f = x
        for i, layer in enumerate(self.layers):
            f = layer(f, self.direction)
            if self.block_size != 0 and (i + 1) % self.block_size == 0:
                # skip connection over the last `block_size` layers
                f = f + x
                x = f
        if isinstance(self.attention, str) and self.attention != "none":
            if self.attention == "basic":
                gate = f
                for layer in self.attention_layers:
                    gate = layer(gate)
                f = f * gate
            else:
                f = self.attention_layers(f)
        return f
Prediction generation stage
def __init__(
    self,
    num_layers,
    num_f_maps,
    dim,
    direction,
    num_bp=None,
    block_size=5,
    kernel_size=3,
    rare_dilations=False,
    attention="none",
    multihead=False,
):
    """
    Parameters
    ----------
    num_layers : int
        number of layers
    num_f_maps : int
        number of feature maps
    dim : int
        number of features in input
    direction : [None, 'forward', 'backward']
        the direction of convolutions; if None, regular convolutions are used
    num_bp : optional
        not used by this module
    block_size : int, default 5
        if not 0, skip connections are added to the prediction generation stage with this interval
    kernel_size : int, default 3
        kernel size
    rare_dilations : bool, default False
        if `False`, dilation increases every layer, otherwise every second layer
    attention : str, default "none"
        "none" disables attention; "basic" builds a convolutional sigmoid gate;
        any other string is passed to `AttLayer` as `att_type`
    multihead : bool, default False
        stored as `self.multihead`
    """

    super().__init__()
    self.num_layers = num_layers
    self.block_size = block_size
    self.direction = direction
    self.conv_1x1_in = nn.Conv1d(dim, num_f_maps, 1)
    pars = {
        "in_channels": num_f_maps,
        "out_channels": num_f_maps,
        "dropout_rate": 0.5,
        "causal": (direction is not None),
        "kernel_size": kernel_size,
    }
    if not rare_dilations:
        # dilation2 doubles every layer; dilation1 mirrors it (largest first)
        dilations = [(2 ** (num_layers - 1 - i), 2**i) for i in range(num_layers)]
    else:
        # dilation2 doubles every second layer; dilation1 mirrors it.
        # (num_layers - 1) // 2 equals num_layers // 2 - 1 for even num_layers
        # and avoids a negative exponent (dilation 0.5) for odd num_layers.
        top = (num_layers - 1) // 2
        dilations = [(2 ** (top - i // 2), 2 ** (i // 2)) for i in range(num_layers)]
    self.layers = nn.ModuleList(
        [
            DualDilatedResidualLayer(dilation1=d1, dilation2=d2, **pars)
            for d1, d2 in dilations
        ]
    )
    self.attention_layers = nn.ModuleList([])
    self.attention = attention
    if self.attention in ["basic"]:
        self.attention_layers += nn.ModuleList(
            [
                nn.Conv1d(num_f_maps, num_f_maps, 3, padding=1),
                nn.ReLU(),
                nn.Conv1d(num_f_maps, num_f_maps, 3, padding=1),
                nn.Sigmoid(),
            ]
        )
    elif isinstance(self.attention, str) and self.attention != "none":
        self.attention_layers = AttLayer(
            num_f_maps,
            num_f_maps,
            num_f_maps,
            4,
            4,
            2,
            64,
            att_type=self.attention,
            stage="encoder",
        )
    self.multihead = multihead
Parameters
num_layers : int
number of layers
num_f_maps : int
number of feature maps
dim : int
number of features in input
direction : [None, 'forward', 'backward']
the direction of convolutions; if None, regular convolutions are used
block_size : int, default 5
if not 0, skip connections are added to the prediction generation stage with this interval
kernel_size : int, default 3
kernel size
rare_dilations : bool, default False
if `False`, dilation increases every layer, otherwise every second layer
731 def forward(self, x, tag=None): 732 """Forward pass.""" 733 x = self.conv_1x1_in(x) 734 f = copy.copy(x) 735 for i, layer in enumerate(self.layers): 736 f = layer(f, self.direction) 737 if self.block_size != 0 and (i + 1) % self.block_size == 0: 738 f = f + x 739 x = copy.copy(f) 740 if isinstance(self.attention, str) and self.attention != "none": 741 x = copy.copy(f) 742 if self.attention != "basic": 743 f = self.attention_layers(x) 744 elif not self.multihead: 745 for layer in self.attention_layers: 746 x = layer(x) 747 f = f * x 748 else: 749 outputs = [] 750 for layers in self.attention_layers: 751 y = copy.copy(x) 752 for layer in layers: 753 y = layer(y) 754 outputs.append(copy.copy(f * y)) 755 outputs = torch.cat(outputs, dim=1) 756 f = self.conv_att(self.dropout(outputs)) 757 return f
Forward pass.
class DilatedTCNB(nn.Module):
    """
    Bidirectional prediction generation stage

    Every `DualDilatedResidualLayer` is applied twice, once with direction
    "forward" and once with "backward"; the two streams are concatenated along
    dim 1 and merged by a 1x1 convolution.
    """

    def __init__(
        self,
        num_layers,
        num_f_maps,
        dim,
        block_size=5,
        kernel_size=3,
        rare_dilations=False,
    ):
        """
        Parameters
        ----------
        num_layers : int
            number of layers
        num_f_maps : int
            number of feature maps
        dim : int
            number of features in input
        block_size : int, default 5
            if not 0, skip connections are added to the prediction generation stage with this interval
        kernel_size : int, default 3
            kernel size
        rare_dilations : bool, default False
            if `False`, dilation increases every layer, otherwise every second layer
        """

        super().__init__()
        self.num_layers = num_layers
        self.block_size = block_size
        self.conv_1x1_in = nn.Conv1d(dim, num_f_maps, 1)
        self.conv_1x1_out = nn.Conv1d(num_f_maps * 2, num_f_maps, 1)
        if not rare_dilations:
            dilations = [(2 ** (num_layers - 1 - i), 2**i) for i in range(num_layers)]
        else:
            # (num_layers - 1) // 2 equals num_layers // 2 - 1 for even
            # num_layers and avoids a negative exponent (dilation 0.5) when
            # num_layers is odd
            top = (num_layers - 1) // 2
            dilations = [
                (2 ** (top - i // 2), 2 ** (i // 2)) for i in range(num_layers)
            ]
        self.layers = nn.ModuleList(
            [
                DualDilatedResidualLayer(
                    dilation1=d1,
                    dilation2=d2,
                    in_channels=num_f_maps,
                    out_channels=num_f_maps,
                    dropout_rate=0.5,
                    causal=True,
                    kernel_size=kernel_size,
                )
                for d1, d2 in dilations
            ]
        )

    def forward(self, x, tag=None):
        """
        Forward pass.

        The same layers (shared weights) process a forward and a backward
        stream; `tag` is not used.
        """
        x_f = self.conv_1x1_in(x)
        x_b = copy.copy(x_f)
        forward = copy.copy(x_f)
        backward = copy.copy(x_f)
        for i, layer_f in enumerate(self.layers):
            forward = layer_f(forward, "forward")
            backward = layer_f(backward, "backward")
            if self.block_size != 0 and (i + 1) % self.block_size == 0:
                forward = forward + x_f
                backward = backward + x_b
                x_f = copy.copy(forward)
                x_b = copy.copy(backward)
        out = torch.cat([forward, backward], 1)
        out = self.conv_1x1_out(out)
        return out
Bidirectional prediction generation stage
765 def __init__( 766 self, 767 num_layers, 768 num_f_maps, 769 dim, 770 block_size=5, 771 kernel_size=3, 772 rare_dilations=False, 773 ): 774 """ 775 Parameters 776 ---------- 777 num_layers : int 778 number of layers 779 num_f_maps : int 780 number of feature maps 781 dim : int 782 number of features in input 783 block_size : int, default 0 784 if not 0, skip connections are added to the prediction generation stage with this interval 785 kernel_size : int, default 3 786 kernel size 787 rare_dilations : bool, default False 788 if `False`, dilation increases every layer, otherwise every second layer 789 """ 790 791 super().__init__() 792 self.num_layers = num_layers 793 self.block_size = block_size 794 self.conv_1x1_in = nn.Conv1d(dim, num_f_maps, 1) 795 self.conv_1x1_out = nn.Conv1d(num_f_maps * 2, num_f_maps, 1) 796 if not rare_dilations: 797 self.layers = nn.ModuleList( 798 [ 799 DualDilatedResidualLayer( 800 dilation1=2 ** (num_layers - 1 - i), 801 dilation2=2**i, 802 in_channels=num_f_maps, 803 out_channels=num_f_maps, 804 dropout_rate=0.5, 805 causal=True, 806 kernel_size=kernel_size, 807 ) 808 for i in range(num_layers) 809 ] 810 ) 811 else: 812 self.layers = nn.ModuleList( 813 [ 814 DualDilatedResidualLayer( 815 dilation1=2 ** (num_layers // 2 - 1 - i // 2), 816 dilation2=2 ** (i // 2), 817 in_channels=num_f_maps, 818 out_channels=num_f_maps, 819 dropout_rate=0.5, 820 causal=True, 821 kernel_size=kernel_size, 822 ) 823 for i in range(num_layers) 824 ] 825 )
Parameters
num_layers : int
number of layers
num_f_maps : int
number of feature maps
dim : int
number of features in input
block_size : int, default 5
if not 0, skip connections are added to the prediction generation stage with this interval
kernel_size : int, default 3
kernel size
rare_dilations : bool, default False
if `False`, dilation increases every layer, otherwise every second layer
827 def forward(self, x, tag=None): 828 """Forward pass.""" 829 x_f = self.conv_1x1_in(x) 830 x_b = copy.copy(x_f) 831 forward = copy.copy(x_f) 832 backward = copy.copy(x_f) 833 for i, layer_f in enumerate(self.layers): 834 forward = layer_f(forward, "forward") 835 backward = layer_f(backward, "backward") 836 if self.block_size != 0 and (i + 1) % self.block_size == 0: 837 forward = forward + x_f 838 backward = backward + x_b 839 x_f = copy.copy(forward) 840 x_b = copy.copy(backward) 841 out = torch.cat([forward, backward], 1) 842 out = self.conv_1x1_out(out) 843 return out
Forward pass.
846class SpatialFeatures(nn.Module): 847 """ 848 Spatial features extraction stage 849 """ 850 851 def __init__( 852 self, 853 num_layers, 854 num_f_maps, 855 dim, 856 block_size=5, 857 graph_edges=None, 858 num_nodes=None, 859 denom: int = 8, 860 ): 861 """ 862 Parameters 863 ---------- 864 num_layers : int 865 number of layers 866 num_f_maps : int 867 number of feature maps 868 dim : int 869 number of features in input 870 block_size : int, default 5 871 if not 0, skip connections are added to the prediction generation stage with this interval 872 """ 873 874 super().__init__() 875 self.num_nodes = num_nodes 876 if graph_edges is None: 877 module = SimpleResidualLayer 878 self.graph = False 879 pars = {"num_f_maps": num_f_maps, "dropout_rate": 0.5} 880 self.conv_1x1_in = nn.Conv1d(dim, num_f_maps, 1) 881 else: 882 raise NotImplementedError("Graph not implemented") 883 884 self.num_layers = num_layers 885 self.block_size = block_size 886 self.layers = nn.ModuleList([module(**pars) for _ in range(num_layers)]) 887 888 def forward(self, x): 889 """Forward pass.""" 890 if self.graph: 891 B, _, L = x.shape 892 x = x.transpose(-1, -2) 893 x = x.reshape((-1, x.shape[-1])) 894 x = x.reshape((x.shape[0], self.num_nodes, -1)) 895 x = x.transpose(-1, -2) 896 x = self.conv_1x1_in(x) 897 f = copy.copy(x) 898 for i, layer in enumerate(self.layers): 899 f = layer(f) 900 if self.block_size != 0 and (i + 1) % self.block_size == 0: 901 f = f + x 902 x = copy.copy(f) 903 if self.graph: 904 f = f.reshape((B, L, -1)) 905 f = f.transpose(-1, -2) 906 f = self.conv_1x1_out(f) 907 return f
Spatial features extraction stage
851 def __init__( 852 self, 853 num_layers, 854 num_f_maps, 855 dim, 856 block_size=5, 857 graph_edges=None, 858 num_nodes=None, 859 denom: int = 8, 860 ): 861 """ 862 Parameters 863 ---------- 864 num_layers : int 865 number of layers 866 num_f_maps : int 867 number of feature maps 868 dim : int 869 number of features in input 870 block_size : int, default 5 871 if not 0, skip connections are added to the prediction generation stage with this interval 872 """ 873 874 super().__init__() 875 self.num_nodes = num_nodes 876 if graph_edges is None: 877 module = SimpleResidualLayer 878 self.graph = False 879 pars = {"num_f_maps": num_f_maps, "dropout_rate": 0.5} 880 self.conv_1x1_in = nn.Conv1d(dim, num_f_maps, 1) 881 else: 882 raise NotImplementedError("Graph not implemented") 883 884 self.num_layers = num_layers 885 self.block_size = block_size 886 self.layers = nn.ModuleList([module(**pars) for _ in range(num_layers)])
Parameters
num_layers : int
number of layers
num_f_maps : int
number of feature maps
dim : int
number of features in input
block_size : int, default 5
if not 0, skip connections are added to the prediction generation stage with this interval
888 def forward(self, x): 889 """Forward pass.""" 890 if self.graph: 891 B, _, L = x.shape 892 x = x.transpose(-1, -2) 893 x = x.reshape((-1, x.shape[-1])) 894 x = x.reshape((x.shape[0], self.num_nodes, -1)) 895 x = x.transpose(-1, -2) 896 x = self.conv_1x1_in(x) 897 f = copy.copy(x) 898 for i, layer in enumerate(self.layers): 899 f = layer(f) 900 if self.block_size != 0 and (i + 1) % self.block_size == 0: 901 f = f + x 902 x = copy.copy(f) 903 if self.graph: 904 f = f.reshape((B, L, -1)) 905 f = f.transpose(-1, -2) 906 f = self.conv_1x1_out(f) 907 return f
Forward pass.
class MSRefinement(nn.Module):
    """
    Refinement stage

    Produces an initial prediction with a 1x1 convolution, then passes it
    through a chain of refinement modules. Each module receives the normalized
    previous prediction, concatenated with the stage input when skip
    connections are on. All intermediate predictions are stacked on a new
    leading dimension.
    """

    def __init__(
        self,
        num_layers_R,
        num_R,
        num_f_maps_input,
        num_f_maps,
        num_classes,
        dropout_rate,
        exclusive,
        skip_connections,
        direction,
        block_size=0,
        num_heads=1,
        attention="none",
    ):
        """
        Parameters
        ----------
        num_layers_R : int
            number of layers in refinement modules
        num_R : int
            number of refinement modules
        num_f_maps_input : int
            number of channels of the stage input
        num_f_maps : int
            number of feature maps
        num_classes : int
            number of target classes
        dropout_rate : float
            dropout rate
        exclusive : bool
            set `False` for multi-label classification (sigmoid instead of softmax)
        skip_connections : bool
            if `True`, skip connections are added
        direction : [None, 'bidirectional', 'forward', 'backward']
            the direction of convolutions; if None, regular convolutions are used
        block_size : int, default 0
            if not 0, skip connections are added to the prediction generation stage with this interval
        num_heads : int, default 1
            number of parallel refinement stages
        attention : str, default "none"
            passed to every refinement module
        """

        super().__init__()
        self.skip_connections = skip_connections
        self.num_heads = num_heads
        self.nl = (
            (lambda x: F.softmax(x, dim=1))
            if exclusive
            else (lambda x: torch.sigmoid(x))
        )
        block_cls = RefinementB if direction == "bidirectional" else Refinement
        block_kwargs = dict(
            num_layers=num_layers_R,
            num_f_maps=num_f_maps,
            num_f_maps_input=num_f_maps_input,
            dim=num_classes,
            num_classes=num_classes,
            dropout_rate=dropout_rate,
            direction=direction,
            skip_connections=skip_connections,
            block_size=block_size,
            attention=attention,
        )

        def build_head():
            # one chain of num_R refinement modules
            return nn.ModuleList([block_cls(**block_kwargs) for _ in range(num_R)])

        self.Rs = nn.ModuleList([build_head() for _ in range(self.num_heads)])
        self.conv_out = nn.ModuleList(
            [nn.Conv1d(num_f_maps_input, num_classes, 1) for _ in range(self.num_heads)]
        )
        if self.num_heads == 1:
            # with a single head the containers are unwrapped
            self.Rs = self.Rs[0]
            self.conv_out = self.conv_out[0]

    def _Rs(self, tag):
        """Return the refinement chain for head `tag` (head 0 if None)."""
        if self.num_heads == 1:
            return self.Rs
        head = 0 if tag is None else tag
        # NOTE(review): every call overwrites heads 1.. with head 0's weights,
        # so all heads stay identical — confirm this is intended
        reference = self.Rs[0].state_dict()
        for k in range(1, self.num_heads):
            self.Rs[k].load_state_dict(reference)
        return self.Rs[head]

    def _conv_out(self, tag):
        """Return the initial 1x1 convolution for head `tag` (head 0 if None)."""
        if self.num_heads == 1:
            return self.conv_out
        head = 0 if tag is None else tag
        # same weight synchronisation as in `_Rs`
        reference = self.conv_out[0].state_dict()
        for k in range(1, self.num_heads):
            self.conv_out[k].load_state_dict(reference)
        return self.conv_out[head]

    def forward(self, x, tag=None):
        """Forward pass."""
        head = None if tag is None else tag[0]
        prediction = self._conv_out(head)(x)
        history = [prediction]
        for refine in self._Rs(head):
            inputs = self.nl(prediction)
            if self.skip_connections:
                inputs = torch.cat([inputs, x], dim=1)
            prediction = refine(inputs)
            history.append(prediction)
        return torch.stack(history, dim=0)
Refinement stage
def __init__(
    self,
    num_layers_R,
    num_R,
    num_f_maps_input,
    num_f_maps,
    num_classes,
    dropout_rate,
    exclusive,
    skip_connections,
    direction,
    block_size=0,
    num_heads=1,
    attention="none",
):
    """
    Build the refinement stage.

    Parameters
    ----------
    num_layers_R : int
        number of layers in refinement modules
    num_R : int
        number of refinement modules
    num_f_maps_input : int
        number of channels of the stage input
    num_f_maps : int
        number of feature maps
    num_classes : int
        number of target classes
    dropout_rate : float
        dropout rate
    exclusive : bool
        set `False` for multi-label classification (sigmoid instead of softmax)
    skip_connections : bool
        if `True`, skip connections are added
    direction : [None, 'bidirectional', 'forward', 'backward']
        the direction of convolutions; if None, regular convolutions are used
    block_size : int, default 0
        if not 0, skip connections are added to the prediction generation stage with this interval
    num_heads : int, default 1
        number of parallel refinement stages
    attention : str, default "none"
        passed to every refinement module
    """

    super().__init__()
    self.skip_connections = skip_connections
    self.num_heads = num_heads
    self.nl = (
        (lambda x: F.softmax(x, dim=1))
        if exclusive
        else (lambda x: torch.sigmoid(x))
    )
    block_cls = RefinementB if direction == "bidirectional" else Refinement
    block_kwargs = dict(
        num_layers=num_layers_R,
        num_f_maps=num_f_maps,
        num_f_maps_input=num_f_maps_input,
        dim=num_classes,
        num_classes=num_classes,
        dropout_rate=dropout_rate,
        direction=direction,
        skip_connections=skip_connections,
        block_size=block_size,
        attention=attention,
    )

    def build_head():
        # one chain of num_R refinement modules
        return nn.ModuleList([block_cls(**block_kwargs) for _ in range(num_R)])

    self.Rs = nn.ModuleList([build_head() for _ in range(self.num_heads)])
    self.conv_out = nn.ModuleList(
        [nn.Conv1d(num_f_maps_input, num_classes, 1) for _ in range(self.num_heads)]
    )
    if self.num_heads == 1:
        # with a single head the containers are unwrapped
        self.Rs = self.Rs[0]
        self.conv_out = self.conv_out[0]
Parameters
num_layers_R : int
number of layers in refinement modules
num_R : int
number of refinement modules
num_f_maps : int
number of feature maps
num_classes : int
number of target classes
dropout_rate : float
dropout rate
exclusive : bool
set `False` for multi-label classification
skip_connections : bool
if `True`, skip connections are added
direction : [None, 'bidirectional', 'forward', 'backward']
the direction of convolutions; if None, regular convolutions are used
block_size : int, default 0
if not 0, skip connections are added to the prediction generation stage with this interval
num_heads : int, default 1
number of parallel refinement stages
1013 def forward(self, x, tag=None): 1014 """Forward pass.""" 1015 if tag is not None: 1016 tag = tag[0] 1017 out = self._conv_out(tag)(x) 1018 outputs = out.unsqueeze(0) 1019 for R in self._Rs(tag): 1020 if self.skip_connections: 1021 out = R(torch.cat([self.nl(out), x], axis=1)) 1022 # out = R(torch.cat([out, x], axis=1)) 1023 else: 1024 out = R(self.nl(out)) 1025 # out = R(out) 1026 outputs = torch.cat((outputs, out.unsqueeze(0)), dim=0) 1027 1028 return outputs
Forward pass.
class MSRefinementAttention(nn.Module):
    """
    Refinement stage

    Like `MSRefinement`, but built from `Refinement_SE` modules. These always
    use regular (non-directional) convolutions and receive the `tag` in
    `forward`.
    """

    def __init__(
        self,
        num_layers_R,
        num_R,
        num_f_maps_input,
        num_f_maps,
        num_classes,
        dropout_rate,
        exclusive,
        skip_connections,
        len_segment,
        block_size=0,
    ):
        """
        Parameters
        ----------
        num_layers_R : int
            number of layers in refinement modules
        num_R : int
            number of refinement modules
        num_f_maps_input : int
            number of channels of the stage input
        num_f_maps : int
            number of feature maps
        num_classes : int
            number of target classes
        dropout_rate : float
            dropout rate
        exclusive : bool
            set `False` for multi-label classification (sigmoid instead of softmax)
        skip_connections : bool
            if `True`, skip connections are added
        len_segment : int
            passed to every `Refinement_SE` module
        block_size : int, default 0
            if not 0, skip connections are added to the prediction generation stage with this interval
        """

        super().__init__()
        self.skip_connections = skip_connections
        if exclusive:
            self.nl = lambda x: F.softmax(x, dim=1)
        else:
            self.nl = lambda x: torch.sigmoid(x)
        refinement_module = Refinement_SE
        self.Rs = nn.ModuleList(
            [
                refinement_module(
                    num_layers=num_layers_R,
                    num_f_maps=num_f_maps,
                    num_f_maps_input=num_f_maps_input,
                    dim=num_classes,
                    num_classes=num_classes,
                    dropout_rate=dropout_rate,
                    direction=None,
                    skip_connections=skip_connections,
                    block_size=block_size,
                    len_segment=len_segment,
                )
                for _ in range(num_R)
            ]
        )
        self.conv_out = nn.Conv1d(num_f_maps_input, num_classes, 1)

    def forward(self, x, tag=None):
        """
        Forward pass.

        Returns the initial prediction and every refined prediction, stacked
        on a new leading dimension. If given, `tag[0]` is passed to every
        refinement module.
        """
        if tag is not None:
            tag = tag[0]
        out = self.conv_out(x)
        outputs = out.unsqueeze(0)
        for R in self.Rs:
            if self.skip_connections:
                out = R(torch.cat([self.nl(out), x], axis=1), tag)
            else:
                out = R(self.nl(out), tag)
            outputs = torch.cat((outputs, out.unsqueeze(0)), dim=0)

        return outputs
Refinement stage
1153 def __init__( 1154 self, 1155 num_layers_R, 1156 num_R, 1157 num_f_maps_input, 1158 num_f_maps, 1159 num_classes, 1160 dropout_rate, 1161 exclusive, 1162 skip_connections, 1163 len_segment, 1164 block_size=0, 1165 ): 1166 """ 1167 Parameters 1168 ---------- 1169 num_layers_R : int 1170 number of layers in refinement modules 1171 num_R : int 1172 number of refinement modules 1173 num_f_maps : int 1174 number of feature maps 1175 num_classes : int 1176 number of target classes 1177 dropout_rate : float 1178 dropout rate 1179 exclusive : bool 1180 set `False` for multi-label classification 1181 skip_connections : bool 1182 if `True`, skip connections are added 1183 direction : [None, 'bidirectional', 'forward', 'backward'] 1184 the direction of convolutions; if None, regular convolutions are used 1185 block_size : int, default 0 1186 if not 0, skip connections are added to the prediction generation stage with this interval 1187 num_heads : int, default 1 1188 number of parallel refinement stages 1189 """ 1190 1191 super().__init__() 1192 self.skip_connections = skip_connections 1193 if exclusive: 1194 self.nl = lambda x: F.softmax(x, dim=1) 1195 else: 1196 self.nl = lambda x: torch.sigmoid(x) 1197 refinement_module = Refinement_SE 1198 self.Rs = nn.ModuleList( 1199 [ 1200 refinement_module( 1201 num_layers=num_layers_R, 1202 num_f_maps=num_f_maps, 1203 num_f_maps_input=num_f_maps_input, 1204 dim=num_classes, 1205 num_classes=num_classes, 1206 dropout_rate=dropout_rate, 1207 direction=None, 1208 skip_connections=skip_connections, 1209 block_size=block_size, 1210 len_segment=len_segment, 1211 ) 1212 for s in range(num_R) 1213 ] 1214 ) 1215 self.conv_out = nn.Conv1d(num_f_maps_input, num_classes, 1)
Parameters
num_layers_R : int
number of layers in refinement modules
num_R : int
number of refinement modules
num_f_maps : int
number of feature maps
num_classes : int
number of target classes
dropout_rate : float
dropout rate
exclusive : bool
set `False` for multi-label classification
skip_connections : bool
if `True`, skip connections are added
direction : not accepted by this class
the refinement modules always use regular convolutions (`direction=None` is passed to them)
block_size : int, default 0
if not 0, skip connections are added to the prediction generation stage with this interval
num_heads : not accepted by this class
a single chain of refinement modules is used
1217 def forward(self, x, tag=None): 1218 """Forward pass.""" 1219 if tag is not None: 1220 tag = tag[0] 1221 out = self.conv_out(x) 1222 outputs = out.unsqueeze(0) 1223 for R in self.Rs: 1224 if self.skip_connections: 1225 out = R(torch.cat([self.nl(out), x], axis=1), tag) 1226 # out = R(torch.cat([out, x], axis=1), tag) 1227 else: 1228 out = R(self.nl(out), tag) 1229 # out = R(out, tag) 1230 outputs = torch.cat((outputs, out.unsqueeze(0)), dim=0) 1231 1232 return outputs
Forward pass.
1235class DilatedTCNC(nn.Module): 1236 def __init__( 1237 self, 1238 num_f_maps, 1239 num_layers_PG, 1240 len_segment, 1241 block_size_prediction=5, 1242 kernel_size_prediction=3, 1243 direction_PG=None, 1244 ): 1245 super(DilatedTCNC, self).__init__() 1246 if direction_PG == "bidirectional": 1247 self.PG_S = DilatedTCNB( 1248 num_layers=num_layers_PG, 1249 num_f_maps=num_f_maps, 1250 dim=num_f_maps, 1251 block_size=block_size_prediction, 1252 ) 1253 self.PG_T = DilatedTCNB( 1254 num_layers=num_layers_PG, 1255 num_f_maps=len_segment, 1256 dim=len_segment, 1257 block_size=block_size_prediction, 1258 ) 1259 else: 1260 self.PG_S = DilatedTCN( 1261 num_layers=num_layers_PG, 1262 num_f_maps=num_f_maps, 1263 dim=num_f_maps, 1264 direction=direction_PG, 1265 block_size=block_size_prediction, 1266 kernel_size=kernel_size_prediction, 1267 ) 1268 self.PG_T = DilatedTCN( 1269 num_layers=num_layers_PG, 1270 num_f_maps=len_segment, 1271 dim=len_segment, 1272 direction=direction_PG, 1273 block_size=block_size_prediction, 1274 kernel_size=kernel_size_prediction, 1275 ) 1276 1277 def forward(self, x): 1278 """Forward pass.""" 1279 x = self.PG_S(x) 1280 x = torch.transpose(x, 1, 2) 1281 x = self.PG_T(x) 1282 x = torch.transpose(x, 1, 2) 1283 return x
Base class for all neural network modules.
Your models should also subclass this class.
Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
def __init__(self) -> None:
super().__init__()
self.conv1 = nn.Conv2d(1, 20, 5)
self.conv2 = nn.Conv2d(20, 20, 5)
def forward(self, x):
x = F.relu(self.conv1(x))
return F.relu(self.conv2(x))
Submodules assigned in this way will be registered, and will also have their
parameters converted when you call `to()`, etc.
As per the example above, an `__init__()` call to the parent class must be made before assignment on the child.
:ivar training: Boolean representing whether this module is in training or evaluation mode.
:vartype training: bool
1236 def __init__( 1237 self, 1238 num_f_maps, 1239 num_layers_PG, 1240 len_segment, 1241 block_size_prediction=5, 1242 kernel_size_prediction=3, 1243 direction_PG=None, 1244 ): 1245 super(DilatedTCNC, self).__init__() 1246 if direction_PG == "bidirectional": 1247 self.PG_S = DilatedTCNB( 1248 num_layers=num_layers_PG, 1249 num_f_maps=num_f_maps, 1250 dim=num_f_maps, 1251 block_size=block_size_prediction, 1252 ) 1253 self.PG_T = DilatedTCNB( 1254 num_layers=num_layers_PG, 1255 num_f_maps=len_segment, 1256 dim=len_segment, 1257 block_size=block_size_prediction, 1258 ) 1259 else: 1260 self.PG_S = DilatedTCN( 1261 num_layers=num_layers_PG, 1262 num_f_maps=num_f_maps, 1263 dim=num_f_maps, 1264 direction=direction_PG, 1265 block_size=block_size_prediction, 1266 kernel_size=kernel_size_prediction, 1267 ) 1268 self.PG_T = DilatedTCN( 1269 num_layers=num_layers_PG, 1270 num_f_maps=len_segment, 1271 dim=len_segment, 1272 direction=direction_PG, 1273 block_size=block_size_prediction, 1274 kernel_size=kernel_size_prediction, 1275 )
Initialize internal Module state, shared by both nn.Module and ScriptModule.