dlc2action.model.ms_tcn_modules

   1#
   2# Copyright 2020-present by A. Mathis Group and contributors. All rights reserved.
   3#
   4# This project and all its files are licensed under GNU AGPLv3 or later version. 
   5# A copy is included in dlc2action/LICENSE.AGPL.
   6#
   7import copy
   8
   9import torch
  10import torch.nn as nn
  11import torch.nn.functional as F
  12from dlc2action.model.asformer import AttLayer
  13
  14
  15class Refinement(nn.Module):
  16    """
  17    Refinement module
  18    """
  19
  20    def __init__(
  21        self,
  22        num_layers,
  23        num_f_maps_input,
  24        num_f_maps,
  25        dim,
  26        num_classes,
  27        dropout_rate,
  28        direction,
  29        skip_connections,
  30        attention="none",
  31        block_size=0,
  32    ):
  33        """
  34        Parameters
  35        ----------
  36        num_layers : int
  37            the number of layers
   38        num_f_maps_input, num_f_maps : int
   39            the number of feature maps in the input and inside the module, respectively
  40        dim : int
  41            the number of features in input
  42        num_classes : int
  43            the number of target classes
  44        dropout_rate : float
  45            dropout rate
  46        direction : [None, 'forward', 'backward']
  47            the direction of convolutions; if None, regular convolutions are used
  48        skip_connections : bool
  49            if `True`, skip connections are added
   50        block_size : int, default 0
   51            if not 0, skip connections are added inside the module with this interval
  52        """
  53
  54        super(Refinement, self).__init__()
  55        self.block_size = block_size
  56        self.direction = direction
  57        if skip_connections:
  58            self.conv_1x1 = nn.Conv1d(dim + num_f_maps_input, num_f_maps, 1)
  59        else:
  60            self.conv_1x1 = nn.Conv1d(dim, num_f_maps, 1)
  61        self.layers = nn.ModuleList(
  62            [
  63                copy.deepcopy(
  64                    DilatedResidualLayer(
  65                        dilation=2**i,
  66                        in_channels=num_f_maps,
  67                        out_channels=num_f_maps,
  68                        dropout_rate=dropout_rate,
  69                        causal=(direction is not None),
  70                    )
  71                )
  72                for i in range(num_layers)
  73            ]
  74        )
  75        self.conv_out = nn.Conv1d(num_f_maps, num_classes, 1)
  76        self.attention_layers = nn.ModuleList([])
  77        self.attention = attention
  78        if self.attention == "basic":
  79            self.attention_layers += nn.ModuleList(
  80                [
  81                    nn.Conv1d(num_f_maps, num_f_maps, 3, padding=1),
  82                    nn.ReLU(),
  83                    nn.Conv1d(num_f_maps, num_f_maps, 3, padding=1),
  84                    nn.Sigmoid(),
  85                ]
  86            )
  87
  88    def forward(self, x):
  89        """Forward pass."""
  90        x = self.conv_1x1(x)
  91        out = copy.copy(x)
  92        for l_i, layer in enumerate(self.layers):
  93            out = layer(out, self.direction)
  94            if self.block_size != 0 and (l_i + 1) % self.block_size == 0:
  95                out = out + x
  96                x = copy.copy(out)
  97        if self.attention != "none":
  98            x = copy.copy(out)
  99            for layer in self.attention_layers:
 100                x = layer(x)
 101            out = out * x
 102        out = self.conv_out(out)
 103        return out
 104
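# Illustrative usage sketch (hypothetical parameter values; tensors follow the
# (batch, channels, #frames) layout used throughout this module): refining
# frame-wise class scores with a non-causal Refinement module.
_example_refinement = Refinement(
    num_layers=10, num_f_maps_input=64, num_f_maps=64, dim=6, num_classes=6,
    dropout_rate=0.5, direction=None, skip_connections=False,
)
_example_scores = torch.randn(8, 6, 128)                 # (batch, num_classes, #frames)
_example_refined = _example_refinement(_example_scores)  # -> (8, 6, 128)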
 105
 106class Refinement_SE(nn.Module):
 107    """
  108    Refinement module with squeeze-and-excitation (SE) channel gating
 109    """
 110
 111    def __init__(
 112        self,
 113        num_layers,
 114        num_f_maps_input,
 115        num_f_maps,
 116        dim,
 117        num_classes,
 118        dropout_rate,
 119        direction,
 120        skip_connections,
 121        len_segment,
 122        block_size=0,
 123    ):
 124        """
 125        Parameters
 126        ----------
 127        num_layers : int
 128            the number of layers
  129        num_f_maps_input, num_f_maps : int
  130            the number of feature maps in the input and inside the module, respectively
 131        dim : int
 132            the number of features in input
 133        num_classes : int
 134            the number of target classes
 135        dropout_rate : float
 136            dropout rate
 137        direction : [None, 'forward', 'backward']
 138            the direction of convolutions; if None, regular convolutions are used
 139        skip_connections : bool
 140            if `True`, skip connections are added
  141        block_size : int, default 0
  142            if not 0, skip connections are added inside the module with this interval
 143        """
 144
 145        super().__init__()
 146        self.block_size = block_size
 147        self.direction = direction
 148        if skip_connections:
 149            self.conv_1x1 = nn.Conv1d(dim + num_f_maps_input, num_f_maps, 1)
 150        else:
 151            self.conv_1x1 = nn.Conv1d(dim, num_f_maps, 1)
 152        self.layers = nn.ModuleList(
 153            [
 154                DilatedResidualLayer(
 155                    dilation=2**i,
 156                    in_channels=num_f_maps,
 157                    out_channels=num_f_maps,
 158                    dropout_rate=dropout_rate,
 159                    causal=(direction is not None),
 160                )
 161                for i in range(num_layers)
 162            ]
 163        )
 164        self.conv_out = nn.Conv1d(num_f_maps, num_classes, 1)
 165        self.fc1_f = nn.ModuleList(
 166            [nn.Linear(num_f_maps, num_f_maps // 2) for _ in range(6)]
 167        )
 168        self.fc2_f = nn.ModuleList(
 169            [nn.Linear(num_f_maps // 2, num_f_maps) for _ in range(6)]
 170        )
 171
 172    def _fc1_f(self, tag):
 173        if tag is None:
 174            for i in range(1, len(self.fc1_f)):
 175                self.fc1_f[i].load_state_dict(self.fc1_f[0].state_dict())
 176            return self.fc1_f[0]
 177        else:
 178            return self.fc1_f[tag]
 179
 180    def _fc2_f(self, tag):
 181        if tag is None:
 182            for i in range(1, len(self.fc2_f)):
 183                self.fc2_f[i].load_state_dict(self.fc2_f[0].state_dict())
 184            return self.fc2_f[0]
 185        else:
 186            return self.fc2_f[tag]
 187
 188    def forward(self, x, tag):
 189        """Forward pass."""
 190        x = self.conv_1x1(x)
 191        scale = torch.mean(x, -1)
 192        scale = self._fc2_f(tag)(F.relu(self._fc1_f(tag)(scale)))
  193        scale = torch.sigmoid(scale).unsqueeze(-1)
 194        x = scale * x
 195        out = copy.copy(x)
 196        for l_i, layer in enumerate(self.layers):
 197            out = layer(out, self.direction)
 198            if self.block_size != 0 and (l_i + 1) % self.block_size == 0:
 199                out = out + x
 200                x = copy.copy(out)
 201        out = self.conv_out(out)
 202        return out
 203
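# Illustrative usage sketch (hypothetical values): Refinement_SE adds a
# squeeze-and-excitation gate with 6 parallel heads selected by `tag`;
# tag=None copies the weights of head 0 into all heads and uses head 0.
_example_ref_se = Refinement_SE(
    num_layers=10, num_f_maps_input=64, num_f_maps=64, dim=6, num_classes=6,
    dropout_rate=0.5, direction=None, skip_connections=False, len_segment=128,
)
_example_se_out = _example_ref_se(torch.randn(8, 6, 128), tag=None)  # -> (8, 6, 128)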
 204
 205class RefinementB(Refinement):
 206    """
 207    Bidirectional refinement module
 208    """
 209
 210    def forward(self, x):
 211        """Forward pass."""
 212        x_f = self.conv_1x1(x)
 213        x_b = copy.copy(x_f)
 214        forward = copy.copy(x_f)
 215        backward = copy.copy(x_f)
 216        for i, layer_f in enumerate(self.layers):
 217            forward = layer_f(forward, "forward")
 218            backward = layer_f(backward, "backward")
 219            if self.block_size != 0 and (i + 1) % self.block_size == 0:
 220                forward = forward + x_f
 221                backward = backward + x_b
 222                x_f = copy.copy(forward)
 223                x_b = copy.copy(backward)
 224        out = torch.cat([forward, backward], 1)
 225        out = self.conv_out(out)
 226        return out
 227
 228
 229class SimpleResidualLayer(nn.Module):
 230    """
 231    Basic residual layer
 232    """
 233
 234    def __init__(self, num_f_maps, dropout_rate):
 235        """
 236        Parameters
 237        ----------
  238        num_f_maps : int
  239            the number of input and output feature maps
  240            (the layer applies two 1x1 convolutions with a residual
  241            connection, so the channel count is preserved)
  242        dropout_rate : float
  243            dropout rate
 244        """
 245
 246        super().__init__()
 247        self.conv_1x1_in = nn.Conv1d(num_f_maps, num_f_maps, 1)
 248        self.conv_1x1 = nn.Conv1d(num_f_maps, num_f_maps, 1)
 249        self.dropout = nn.Dropout(dropout_rate)
 250
 251    def forward(self, x):
 252        """Forward pass."""
 253        out = self.conv_1x1_in(x)
 254        out = F.relu(out)
 255        out = self.conv_1x1(out)
 256        out = self.dropout(out)
 257        return x + out
 258
 259
 260class DilatedResidualLayer(nn.Module):
 261    """
 262    Dilated residual layer
 263    """
 264
 265    def __init__(self, dilation, in_channels, out_channels, dropout_rate, causal):
 266        """
 267        Parameters
 268        ----------
 269        dilation : int
 270            dilation
 271        in_channels : int
 272            number of input channels
 273        out_channels : int
 274            number of output channels
 275        dropout_rate : float
 276            dropout rate
 277        causal : bool
 278            if `True`, causal convolutions are used
 279        """
 280
 281        super(DilatedResidualLayer, self).__init__()
 282        self.padding = dilation * 2
 283        self.causal = causal
 284        if self.causal:
 285            padding = 0
 286        else:
 287            padding = dilation
 288        self.conv_dilated = nn.Conv1d(
 289            in_channels, out_channels, 3, padding=padding, dilation=dilation
 290        )
 291        self.conv_1x1 = nn.Conv1d(out_channels, out_channels, 1)
 292        self.dropout = nn.Dropout(dropout_rate)
 293
 294    def forward(self, x, direction=None):
 295        """Forward pass."""
 296        if direction is not None and not self.causal:
 297            raise ValueError("Cannot set direction in a non-causal layer!")
 298        elif direction is None and self.causal:
 299            direction = "forward"
 300        if direction == "forward":
 301            padding = (0, self.padding)
 302        elif direction == "backward":
 303            padding = (self.padding, 0)
 304        elif direction is not None:
 305            raise ValueError(
 306                f"Unrecognized direction: {direction}, please choose from"
 307                f'"backward", "forward" and None'
 308            )
 309        if direction is not None:
 310            out = self.conv_dilated(F.pad(x, padding))
 311        else:
 312            out = self.conv_dilated(x)
 313        out = F.relu(out)
 314        out = self.conv_1x1(out)
 315        out = self.dropout(out)
 316        return x + out
 317
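# Illustrative usage sketch (hypothetical values): a causal dilated residual
# layer; the residual sum requires in_channels == out_channels, and causal
# layers are padded on one side according to `direction`.
_example_drl = DilatedResidualLayer(
    dilation=4, in_channels=64, out_channels=64, dropout_rate=0.5, causal=True
)
_example_drl_out = _example_drl(torch.randn(8, 64, 128), direction="forward")  # -> (8, 64, 128)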
 318
 319class DilatedResidualLayer_SE(nn.Module):
 320    """
  321    Dilated residual layer with squeeze-and-excitation (SE) channel gating
 322    """
 323
 324    def __init__(
 325        self, dilation, in_channels, out_channels, dropout_rate, causal, len_segment
 326    ):
 327        """
 328        Parameters
 329        ----------
 330        dilation : int
 331            dilation
 332        in_channels : int
 333            number of input channels
 334        out_channels : int
 335            number of output channels
 336        dropout_rate : float
 337            dropout rate
 338        causal : bool
 339            if `True`, causal convolutions are used
 340        """
 341
 342        super().__init__()
 343        self.padding = dilation * 2
 344        self.causal = causal
 345        if self.causal:
 346            padding = 0
 347        else:
 348            padding = dilation
 349        self.conv_dilated = nn.Conv1d(
 350            in_channels, out_channels, 3, padding=padding, dilation=dilation
 351        )
 352        self.conv_1x1 = nn.Conv1d(out_channels, out_channels, 1)
 353        self.dropout = nn.Dropout(dropout_rate)
 354        self.fc1_f = nn.ModuleList(
 355            [nn.Linear(out_channels, out_channels // 2) for _ in range(6)]
 356        )
 357        self.fc2_f = nn.ModuleList(
 358            [nn.Linear(out_channels // 2, out_channels) for _ in range(6)]
 359        )
 360        # self.fc1_t = nn.ModuleList([nn.Linear(len_segment, len_segment // 2) for _ in range(6)])
 361        # self.fc2_t = nn.ModuleList([nn.Linear(len_segment // 2, len_segment) for _ in range(6)])
 362
 363    def _fc1_f(self, tag):
 364        if tag is None:
 365            for i in range(1, len(self.fc1_f)):
 366                self.fc1_f[i].load_state_dict(self.fc1_f[0].state_dict())
 367            return self.fc1_f[0]
 368        else:
 369            return self.fc1_f[tag]
 370
 371    def _fc2_f(self, tag):
 372        if tag is None:
 373            for i in range(1, len(self.fc2_f)):
 374                self.fc2_f[i].load_state_dict(self.fc2_f[0].state_dict())
 375            return self.fc2_f[0]
 376        else:
 377            return self.fc2_f[tag]
 378
 379    def _fc1_t(self, tag):
 380        if tag is None:
 381            for i in range(1, len(self.fc1_t)):
 382                self.fc1_t[i].load_state_dict(self.fc1_t[0].state_dict())
 383            return self.fc1_t[0]
 384        else:
 385            return self.fc1_t[tag]
 386
 387    def _fc2_t(self, tag):
 388        if tag is None:
 389            for i in range(1, len(self.fc2_t)):
 390                self.fc2_t[i].load_state_dict(self.fc2_t[0].state_dict())
 391            return self.fc2_t[0]
 392        else:
 393            return self.fc2_t[tag]
 394
 395    def forward(self, x, direction, tag):
 396        """Forward pass."""
 397        if direction is not None and not self.causal:
 398            raise ValueError("Cannot set direction in a non-causal layer!")
 399        elif direction is None and self.causal:
 400            direction = "forward"
 401        if direction == "forward":
 402            padding = (0, self.padding)
 403        elif direction == "backward":
 404            padding = (self.padding, 0)
 405        elif direction is not None:
 406            raise ValueError(
 407                f"Unrecognized direction: {direction}, please choose from"
 408                f'"backward", "forward" and None'
 409            )
 410        if direction is not None:
 411            out = self.conv_dilated(F.pad(x, padding))
 412        else:
 413            out = self.conv_dilated(x)
 414        out = F.relu(out)
 415        out = self.conv_1x1(out)
 416        out = self.dropout(out)
 417        scale = torch.mean(out, -1)
 418        scale = self._fc2_f(tag)(F.relu(self._fc1_f(tag)(scale)))
  419        scale = torch.sigmoid(scale).unsqueeze(-1)
 420        # time_scale = torch.mean(out, 1)
 421        # time_scale = self._fc2_t(tag)(F.relu(self._fc1_t(tag)(time_scale)))
 422        # time_scale = F.sigmoid(time_scale).unsqueeze(1)
 423        out = out * scale  # * time_scale
 424        return x + out
 425
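# Illustrative usage sketch (hypothetical values): the SE variant additionally
# rescales the output channels with a squeeze-and-excitation block chosen by
# `tag` (an integer from 0 to 5, or None to share the weights of head 0).
_example_drl_se = DilatedResidualLayer_SE(
    dilation=2, in_channels=64, out_channels=64, dropout_rate=0.5,
    causal=False, len_segment=128,
)
_example_drl_se_out = _example_drl_se(torch.randn(8, 64, 128), direction=None, tag=3)  # -> (8, 64, 128)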
 426
 427class DilatedResidualLayer_SEC(nn.Module):
 428    """
  429    Dilated residual layer with a convolutional squeeze-and-excitation gate
 430    """
 431
 432    def __init__(
 433        self, dilation, in_channels, out_channels, dropout_rate, causal, len_segment
 434    ):
 435        """
 436        Parameters
 437        ----------
 438        dilation : int
 439            dilation
 440        in_channels : int
 441            number of input channels
 442        out_channels : int
 443            number of output channels
 444        dropout_rate : float
 445            dropout rate
 446        causal : bool
 447            if `True`, causal convolutions are used
 448        """
 449
 450        super().__init__()
 451        self.padding = dilation * 2
 452        self.causal = causal
 453        if self.causal:
 454            padding = 0
 455        else:
 456            padding = dilation
 457        self.conv_dilated = nn.Conv1d(
 458            in_channels, out_channels, 3, padding=padding, dilation=dilation
 459        )
 460        self.conv_1x1 = nn.Conv1d(out_channels, out_channels, 1)
 461        self.dropout = nn.Dropout(dropout_rate)
 462        self.conv_sq = nn.ModuleList(
 463            [
 464                nn.Conv1d(out_channels, 1, 3, padding=padding, dilation=dilation)
 465                for _ in range(6)
 466            ]
 467        )
 468
 469    def _conv_sq(self, tag):
 470        if tag is None:
 471            for i in range(1, len(self.conv_sq)):
 472                self.conv_sq[i].load_state_dict(self.conv_sq[0].state_dict())
 473            return self.conv_sq[0]
 474        else:
 475            return self.conv_sq[tag]
 476
 477    def forward(self, x, direction, tag):
 478        """Forward pass."""
 479        if direction is not None and not self.causal:
 480            raise ValueError("Cannot set direction in a non-causal layer!")
 481        elif direction is None and self.causal:
 482            direction = "forward"
 483        if direction == "forward":
 484            padding = (0, self.padding)
 485        elif direction == "backward":
 486            padding = (self.padding, 0)
 487        elif direction is not None:
 488            raise ValueError(
 489                f"Unrecognized direction: {direction}, please choose from"
 490                f'"backward", "forward" and None'
 491            )
 492        if direction is not None:
 493            out = self.conv_dilated(F.pad(x, padding))
 494        else:
 495            out = self.conv_dilated(x)
 496        out = F.relu(out)
 497        out = self.conv_1x1(out)
 498        out = self.dropout(out)
 499        scale = torch.sigmoid(self._conv_sq(tag)(out))
 500        out = out * scale
 501        return x + out
 502
 503
 504class DualDilatedResidualLayer(nn.Module):
 505    """
 506    Dual dilated residual layer
 507    """
 508
 509    def __init__(
 510        self,
 511        dilation1,
 512        dilation2,
 513        in_channels,
 514        out_channels,
 515        dropout_rate,
 516        causal,
 517        kernel_size=3,
 518    ):
 519        """
 520        Parameters
 521        ----------
 522        dilation1, dilation2 : int
  523            the dilations of the two parallel convolution branches
 524        in_channels : int
 525            number of input channels
 526        out_channels : int
 527            number of output channels
 528        dropout_rate : float
 529            dropout rate
 530        causal : bool
 531            if `True`, causal convolutions are used
 532        kernel_size : int, default 3
 533            kernel size
 534        """
 535
 536        super(DualDilatedResidualLayer, self).__init__()
 537        self.causal = causal
 538        self.padding1 = dilation1 * (kernel_size - 1) // 2
 539        self.padding2 = dilation2 * (kernel_size - 1) // 2
 540        if self.causal:
 541            self.padding1 *= 2
 542            self.padding2 *= 2
 543        self.conv_dilated_1 = nn.Conv1d(
 544            in_channels,
 545            out_channels,
 546            kernel_size=kernel_size,
 547            padding=self.padding1,
 548            dilation=dilation1,
 549        )
 550        self.conv_dilated_2 = nn.Conv1d(
 551            in_channels,
 552            out_channels,
 553            kernel_size=kernel_size,
 554            padding=self.padding2,
 555            dilation=dilation2,
 556        )
 557        self.conv_fusion = nn.Conv1d(2 * out_channels, out_channels, 1)
 558        self.dropout = nn.Dropout(dropout_rate)
 559
 560    def forward(self, x, direction=None):
 561        """Forward pass."""
 562        if direction is not None and not self.causal:
 563            raise ValueError("Cannot set direction in a non-causal layer!")
 564        elif direction is None and self.causal:
 565            direction = "forward"
 566        if direction not in ["backward", "forward", None]:
 567            raise ValueError(
 568                f"Unrecognized direction: {direction}, please choose from"
 569                f'"backward", "forward" and None'
 570            )
 571        out_1 = self.conv_dilated_1(x)
 572        out_2 = self.conv_dilated_2(x)
 573        if direction == "forward":
 574            out_1 = out_1[:, :, : -self.padding1]
 575            out_2 = out_2[:, :, : -self.padding2]
 576        elif direction == "backward":
 577            out_1 = out_1[:, :, self.padding1 :]
 578            out_2 = out_2[:, :, self.padding2 :]
 579
 580        out = self.conv_fusion(
 581            torch.cat(
 582                [
 583                    out_1,
 584                    out_2,
 585                ],
 586                1,
 587            )
 588        )
 589        out = F.relu(out)
 590        out = self.dropout(out)
 591        out = out + x
 592        return out
 593
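# Illustrative usage sketch (hypothetical values): the dual layer runs two
# dilated convolutions in parallel (typically with opposite dilation schedules)
# and fuses them with a 1x1 convolution before the residual sum.
_example_dual = DualDilatedResidualLayer(
    dilation1=16, dilation2=1, in_channels=64, out_channels=64,
    dropout_rate=0.5, causal=False,
)
_example_dual_out = _example_dual(torch.randn(8, 64, 128))  # -> (8, 64, 128)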
 594
 595class MultiDilatedTCN(nn.Module):
 596    """
 597    Multiple prediction generation stages in parallel
 598    """
 599
 600    def __init__(
 601        self,
 602        num_layers,
 603        num_f_maps,
 604        dims,
 605        direction,
 606        block_size=5,
 607        kernel_size=3,
 608        rare_dilations=False,
 609    ):
 610        super(MultiDilatedTCN, self).__init__()
 611        self.PGs = nn.ModuleList(
 612            [
  613                DilatedTCN(  # pass by keyword: DilatedTCN also takes num_bp before block_size
  614                    num_layers=num_layers,
  615                    num_f_maps=num_f_maps,
  616                    dim=dim,
  617                    direction=direction,
  618                    block_size=block_size,
  619                    kernel_size=kernel_size,
  620                    rare_dilations=rare_dilations,
  621                )
 622                for dim in dims
 623            ]
 624        )
 625        self.conv_out = nn.Conv1d(num_f_maps * len(dims), num_f_maps, 1)
 626
 627    def forward(self, x, tag=None):
 628        """Forward pass."""
 629        out = []
 630        for arr, PG in zip(x, self.PGs):
 631            out.append(PG(arr))
 632        out = torch.cat(out, 1)
 633        out = self.conv_out(out)
 634        return out
 635
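# Illustrative usage sketch (hypothetical values): one DilatedTCN per input
# stream, with the per-stream outputs concatenated and fused by a 1x1 convolution.
_example_multi = MultiDilatedTCN(
    num_layers=5, num_f_maps=64, dims=[10, 20], direction=None
)
_example_streams = [torch.randn(8, 10, 128), torch.randn(8, 20, 128)]
_example_multi_out = _example_multi(_example_streams)  # -> (8, 64, 128)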
 636
 637class DilatedTCN(nn.Module):
 638    """
 639    Prediction generation stage
 640    """
 641
 642    def __init__(
 643        self,
 644        num_layers,
 645        num_f_maps,
 646        dim,
 647        direction,
 648        num_bp=None,
 649        block_size=5,
 650        kernel_size=3,
 651        rare_dilations=False,
 652        attention="none",
 653        multihead=False,
 654    ):
 655        """
 656        Parameters
 657        ----------
 658        num_layers : int
 659            number of layers
 660        num_f_maps : int
 661            number of feature maps
 662        dim : int
 663            number of features in input
 664        direction : [None, 'forward', 'backward']
 665            the direction of convolutions; if None, regular convolutions are used
  666        block_size : int, default 5
 667            if not 0, skip connections are added to the prediction generation stage with this interval
 668        kernel_size : int, default 3
 669            kernel size
 670        rare_dilations : bool, default False
 671            if `False`, dilation increases every layer, otherwise every second layer
 672        """
 673
 674        super().__init__()
 675        self.num_layers = num_layers
 676        self.block_size = block_size
 677        self.direction = direction
 678        self.conv_1x1_in = nn.Conv1d(dim, num_f_maps, 1)
 679        pars = {
 680            "in_channels": num_f_maps,
 681            "out_channels": num_f_maps,
 682            "dropout_rate": 0.5,
 683            "causal": (direction is not None),
 684            "kernel_size": kernel_size,
 685        }
 686        module = DualDilatedResidualLayer
 687        if not rare_dilations:
 688            self.layers = nn.ModuleList(
 689                [
 690                    module(dilation1=2 ** (num_layers - 1 - i), dilation2=2**i, **pars)
 691                    for i in range(num_layers)
 692                ]
 693            )
 694        else:
 695            self.layers = nn.ModuleList(
 696                [
 697                    module(
 698                        dilation1=2 ** (num_layers // 2 - 1 - i // 2),
 699                        dilation2=2 ** (i // 2),
 700                        **pars,
 701                    )
 702                    for i in range(num_layers)
 703                ]
 704            )
 705        self.attention_layers = nn.ModuleList([])
 706        self.attention = attention
 707        if self.attention in ["basic"]:
 708            self.attention_layers += nn.ModuleList(
 709                [
 710                    nn.Conv1d(num_f_maps, num_f_maps, 3, padding=1),
 711                    nn.ReLU(),
 712                    nn.Conv1d(num_f_maps, num_f_maps, 3, padding=1),
 713                    nn.Sigmoid(),
 714                ]
 715            )
 716        elif isinstance(self.attention, str) and self.attention != "none":
 717            self.attention_layers = AttLayer(
 718                num_f_maps,
 719                num_f_maps,
 720                num_f_maps,
 721                4,
 722                4,
 723                2,
 724                64,
 725                att_type=self.attention,
 726                stage="encoder",
 727            )
 728        self.multihead = multihead
 729
 730    def forward(self, x, tag=None):
 731        """Forward pass."""
 732        x = self.conv_1x1_in(x)
 733        f = copy.copy(x)
 734        for i, layer in enumerate(self.layers):
 735            f = layer(f, self.direction)
 736            if self.block_size != 0 and (i + 1) % self.block_size == 0:
 737                f = f + x
 738                x = copy.copy(f)
 739        if isinstance(self.attention, str) and self.attention != "none":
 740            x = copy.copy(f)
 741            if self.attention != "basic":
 742                f = self.attention_layers(x)
 743            elif not self.multihead:
 744                for layer in self.attention_layers:
 745                    x = layer(x)
 746                f = f * x
 747            else:
 748                outputs = []
 749                for layers in self.attention_layers:
 750                    y = copy.copy(x)
 751                    for layer in layers:
 752                        y = layer(y)
 753                    outputs.append(copy.copy(f * y))
 754                outputs = torch.cat(outputs, dim=1)
 755                f = self.conv_att(self.dropout(outputs))
 756        return f
 757
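# Illustrative usage sketch (hypothetical values): the prediction generation
# stage maps raw features of dimension `dim` to `num_f_maps` feature maps
# while preserving the temporal length.
_example_pg = DilatedTCN(num_layers=10, num_f_maps=64, dim=1024, direction=None)
_example_pg_out = _example_pg(torch.randn(8, 1024, 128))  # -> (8, 64, 128)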
 758
 759class DilatedTCNB(nn.Module):
 760    """
 761    Bidirectional prediction generation stage
 762    """
 763
 764    def __init__(
 765        self,
 766        num_layers,
 767        num_f_maps,
 768        dim,
 769        block_size=5,
 770        kernel_size=3,
 771        rare_dilations=False,
 772    ):
 773        """
 774        Parameters
 775        ----------
 776        num_layers : int
 777            number of layers
 778        num_f_maps : int
 779            number of feature maps
 780        dim : int
 781            number of features in input
  782        block_size : int, default 5
 783            if not 0, skip connections are added to the prediction generation stage with this interval
 784        kernel_size : int, default 3
 785            kernel size
 786        rare_dilations : bool, default False
 787            if `False`, dilation increases every layer, otherwise every second layer
 788        """
 789
 790        super().__init__()
 791        self.num_layers = num_layers
 792        self.block_size = block_size
 793        self.conv_1x1_in = nn.Conv1d(dim, num_f_maps, 1)
 794        self.conv_1x1_out = nn.Conv1d(num_f_maps * 2, num_f_maps, 1)
 795        if not rare_dilations:
 796            self.layers = nn.ModuleList(
 797                [
 798                    DualDilatedResidualLayer(
 799                        dilation1=2 ** (num_layers - 1 - i),
 800                        dilation2=2**i,
 801                        in_channels=num_f_maps,
 802                        out_channels=num_f_maps,
 803                        dropout_rate=0.5,
 804                        causal=True,
 805                        kernel_size=kernel_size,
 806                    )
 807                    for i in range(num_layers)
 808                ]
 809            )
 810        else:
 811            self.layers = nn.ModuleList(
 812                [
 813                    DualDilatedResidualLayer(
 814                        dilation1=2 ** (num_layers // 2 - 1 - i // 2),
 815                        dilation2=2 ** (i // 2),
 816                        in_channels=num_f_maps,
 817                        out_channels=num_f_maps,
 818                        dropout_rate=0.5,
 819                        causal=True,
 820                        kernel_size=kernel_size,
 821                    )
 822                    for i in range(num_layers)
 823                ]
 824            )
 825
 826    def forward(self, x, tag=None):
 827        """Forward pass."""
 828        x_f = self.conv_1x1_in(x)
 829        x_b = copy.copy(x_f)
 830        forward = copy.copy(x_f)
 831        backward = copy.copy(x_f)
 832        for i, layer_f in enumerate(self.layers):
 833            forward = layer_f(forward, "forward")
 834            backward = layer_f(backward, "backward")
 835            if self.block_size != 0 and (i + 1) % self.block_size == 0:
 836                forward = forward + x_f
 837                backward = backward + x_b
 838                x_f = copy.copy(forward)
 839                x_b = copy.copy(backward)
 840        out = torch.cat([forward, backward], 1)
 841        out = self.conv_1x1_out(out)
 842        return out
 843
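# Illustrative usage sketch (hypothetical values): the bidirectional variant
# runs each layer causally in both directions and fuses the forward and
# backward passes with a 1x1 convolution.
_example_pgb = DilatedTCNB(num_layers=10, num_f_maps=64, dim=1024)
_example_pgb_out = _example_pgb(torch.randn(8, 1024, 128))  # -> (8, 64, 128)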
 844
 845class SpatialFeatures(nn.Module):
 846    """
 847    Spatial features extraction stage
 848    """
 849
 850    def __init__(
 851        self,
 852        num_layers,
 853        num_f_maps,
 854        dim,
 855        block_size=5,
 856        graph_edges=None,
 857        num_nodes=None,
 858        denom: int = 8,
 859    ):
 860        """
 861        Parameters
 862        ----------
 863        num_layers : int
 864            number of layers
 865        num_f_maps : int
 866            number of feature maps
 867        dim : int
 868            number of features in input
 869        block_size : int, default 5
 870            if not 0, skip connections are added to the prediction generation stage with this interval
 871        """
 872
 873        super().__init__()
 874        self.num_nodes = num_nodes
 875        if graph_edges is None:
 876            module = SimpleResidualLayer
 877            self.graph = False
 878            pars = {"num_f_maps": num_f_maps, "dropout_rate": 0.5}
 879            self.conv_1x1_in = nn.Conv1d(dim, num_f_maps, 1)
 880        else:
 881            raise NotImplementedError("Graph not implemented")
 882        
 883        self.num_layers = num_layers
 884        self.block_size = block_size
 885        self.layers = nn.ModuleList([module(**pars) for _ in range(num_layers)])
 886
 887    def forward(self, x):
 888        """Forward pass."""
 889        if self.graph:
 890            B, _, L = x.shape
 891            x = x.transpose(-1, -2)
 892            x = x.reshape((-1, x.shape[-1]))
 893            x = x.reshape((x.shape[0], self.num_nodes, -1))
 894            x = x.transpose(-1, -2)
 895        x = self.conv_1x1_in(x)
 896        f = copy.copy(x)
 897        for i, layer in enumerate(self.layers):
 898            f = layer(f)
 899            if self.block_size != 0 and (i + 1) % self.block_size == 0:
 900                f = f + x
 901                x = copy.copy(f)
 902        if self.graph:
 903            f = f.reshape((B, L, -1))
 904            f = f.transpose(-1, -2)
 905            f = self.conv_1x1_out(f)
 906        return f
 907
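# Illustrative usage sketch (hypothetical values): with graph_edges=None the
# stage is a stack of SimpleResidualLayer blocks applied to each frame.
_example_sf = SpatialFeatures(num_layers=4, num_f_maps=64, dim=1024)
_example_sf_out = _example_sf(torch.randn(8, 1024, 128))  # -> (8, 64, 128)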
 908
 909class MSRefinement(nn.Module):
 910    """
 911    Refinement stage
 912    """
 913
 914    def __init__(
 915        self,
 916        num_layers_R,
 917        num_R,
 918        num_f_maps_input,
 919        num_f_maps,
 920        num_classes,
 921        dropout_rate,
 922        exclusive,
 923        skip_connections,
 924        direction,
 925        block_size=0,
 926        num_heads=1,
 927        attention="none",
 928    ):
 929        """
 930        Parameters
 931        ----------
 932        num_layers_R : int
 933            number of layers in refinement modules
 934        num_R : int
 935            number of refinement modules
  936        num_f_maps_input, num_f_maps : int
  937            number of feature maps in the input and in the refinement modules, respectively
 938        num_classes : int
 939            number of target classes
 940        dropout_rate : float
 941            dropout rate
 942        exclusive : bool
 943            set `False` for multi-label classification
 944        skip_connections : bool
 945            if `True`, skip connections are added
 946        direction : [None, 'bidirectional', 'forward', 'backward']
 947            the direction of convolutions; if None, regular convolutions are used
 948        block_size : int, default 0
  949            if not 0, skip connections are added to the refinement modules with this interval
 950        num_heads : int, default 1
 951            number of parallel refinement stages
 952        """
 953
 954        super().__init__()
 955        self.skip_connections = skip_connections
 956        self.num_heads = num_heads
 957        if exclusive:
 958            self.nl = lambda x: F.softmax(x, dim=1)
 959        else:
 960            self.nl = lambda x: torch.sigmoid(x)
 961        if direction == "bidirectional":
 962            refinement_module = RefinementB
 963        else:
 964            refinement_module = Refinement
 965        self.Rs = nn.ModuleList(
 966            [
 967                nn.ModuleList(
 968                    [
 969                        refinement_module(
 970                            num_layers=num_layers_R,
 971                            num_f_maps=num_f_maps,
 972                            num_f_maps_input=num_f_maps_input,
 973                            dim=num_classes,
 974                            num_classes=num_classes,
 975                            dropout_rate=dropout_rate,
 976                            direction=direction,
 977                            skip_connections=skip_connections,
 978                            block_size=block_size,
 979                            attention=attention,
 980                        )
 981                        for s in range(num_R)
 982                    ]
 983                )
 984                for _ in range(self.num_heads)
 985            ]
 986        )
 987        self.conv_out = nn.ModuleList(
 988            [nn.Conv1d(num_f_maps_input, num_classes, 1) for _ in range(self.num_heads)]
 989        )
 990        if self.num_heads == 1:
 991            self.Rs = self.Rs[0]
 992            self.conv_out = self.conv_out[0]
 993
 994    def _Rs(self, tag):
 995        if self.num_heads == 1:
 996            return self.Rs
 997        if tag is None:
 998            tag = 0
 999            for i in range(1, self.num_heads):
1000                self.Rs[i].load_state_dict(self.Rs[0].state_dict())
1001        return self.Rs[tag]
1002
1003    def _conv_out(self, tag):
1004        if self.num_heads == 1:
1005            return self.conv_out
1006        if tag is None:
1007            tag = 0
1008            for i in range(1, self.num_heads):
1009                self.conv_out[i].load_state_dict(self.conv_out[0].state_dict())
1010        return self.conv_out[tag]
1011
1012    def forward(self, x, tag=None):
1013        """Forward pass."""
1014        if tag is not None:
1015            tag = tag[0]
1016        out = self._conv_out(tag)(x)
1017        outputs = out.unsqueeze(0)
1018        for R in self._Rs(tag):
1019            if self.skip_connections:
1020                out = R(torch.cat([self.nl(out), x], axis=1))
1021                # out = R(torch.cat([out, x], axis=1))
1022            else:
1023                out = R(self.nl(out))
1024                # out = R(out)
1025            outputs = torch.cat((outputs, out.unsqueeze(0)), dim=0)
1026
1027        return outputs
1028
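# Illustrative usage sketch (hypothetical values): the refinement stage first
# projects the backbone features to class scores and then passes them through
# `num_R` refinement modules; the predictions of every stage are stacked along
# a new leading dimension.
_example_ms = MSRefinement(
    num_layers_R=10, num_R=3, num_f_maps_input=64, num_f_maps=64, num_classes=6,
    dropout_rate=0.5, exclusive=True, skip_connections=True, direction=None,
)
_example_ms_out = _example_ms(torch.randn(8, 64, 128))  # -> (num_R + 1, 8, 6, 128)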
1029
1030class MSRefinementShared(nn.Module):
1031    """
1032    Refinement stage with shared weights across modules
1033    """
1034
1035    def __init__(
1036        self,
1037        num_layers_R,
1038        num_R,
1039        num_f_maps_input,
1040        num_f_maps,
1041        num_classes,
1042        dropout_rate,
1043        exclusive,
1044        skip_connections,
1045        direction,
1046        block_size=0,
1047        num_heads=1,
1048        attention="none",
1049    ):
1050        """
1051        Parameters
1052        ----------
1053        num_layers_R : int
1054            number of layers in refinement modules
1055        num_R : int
1056            number of refinement modules
 1057        num_f_maps_input, num_f_maps : int
 1058            number of feature maps in the input and in the refinement modules, respectively
1059        num_classes : int
1060            number of target classes
1061        dropout_rate : float
1062            dropout rate
1063        exclusive : bool
1064            set `False` for multi-label classification
1065        skip_connections : bool
1066            if `True`, skip connections are added
1067        direction : [None, 'bidirectional', 'forward', 'backward']
1068            the direction of convolutions; if None, regular convolutions are used
1069        block_size : int, default 0
 1070            if not 0, skip connections are added to the refinement modules with this interval
1071        num_heads : int, default 1
1072            number of parallel refinement stages
1073        """
1074
1075        super().__init__()
1076        if exclusive:
1077            self.nl = lambda x: F.softmax(x, dim=1)
1078        else:
1079            self.nl = lambda x: torch.sigmoid(x)
1080        if direction == "bidirectional":
1081            refinement_module = RefinementB
1082        else:
1083            refinement_module = Refinement
1084        self.num_heads = num_heads
1085        self.R = nn.ModuleList(
1086            [
1087                refinement_module(
1088                    num_layers=num_layers_R,
1089                    num_f_maps_input=num_f_maps_input,
1090                    num_f_maps=num_f_maps,
1091                    dim=num_classes,
1092                    num_classes=num_classes,
1093                    dropout_rate=dropout_rate,
1094                    direction=direction,
1095                    skip_connections=skip_connections,
1096                    block_size=block_size,
1097                    attention=attention,
1098                )
1099                for _ in range(self.num_heads)
1100            ]
1101        )
1102        self.num_R = num_R
1103        self.conv_out = nn.ModuleList(
1104            [nn.Conv1d(num_f_maps_input, num_classes, 1) for _ in range(self.num_heads)]
1105        )
1106        self.skip_connections = skip_connections
1107        if self.num_heads == 1:
1108            self.R = self.R[0]
1109            self.conv_out = self.conv_out[0]
1110
1111    def _R(self, tag):
1112        if self.num_heads == 1:
1113            return self.R
1114        if tag is None:
1115            tag = 0
1116            for i in range(1, self.num_heads):
1117                self.R[i].load_state_dict(self.R[0].state_dict())
1118        return self.R[tag]
1119
1120    def _conv_out(self, tag):
1121        if self.num_heads == 1:
1122            return self.conv_out
1123        if tag is None:
1124            tag = 0
1125            for i in range(1, self.num_heads):
1126                self.conv_out[i].load_state_dict(self.conv_out[0].state_dict())
1127        return self.conv_out[tag]
1128
1129    def forward(self, x, tag=None):
1130        """Forward pass."""
1131        if tag is not None:
1132            tag = tag[0]
1133        out = self._conv_out(tag)(x)
1134        outputs = out.unsqueeze(0)
1135        for _ in range(self.num_R):
1136            if self.skip_connections:
1137                # out = self._R(tag)(torch.cat([self.nl(out), x], axis=1))
1138                out = self._R(tag)(torch.cat([out, x], axis=1))
1139            else:
1140                # out = self._R(tag)(self.nl(out))
1141                out = self._R(tag)(out)
1142            outputs = torch.cat((outputs, out.unsqueeze(0)), dim=0)
1143
1144        return outputs
1145
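# Illustrative usage sketch (hypothetical values): same interface as
# MSRefinement, but a single refinement module is applied `num_R` times, so
# the refinement weights are shared across iterations.
_example_shared = MSRefinementShared(
    num_layers_R=10, num_R=3, num_f_maps_input=64, num_f_maps=64, num_classes=6,
    dropout_rate=0.5, exclusive=True, skip_connections=True, direction=None,
)
_example_shared_out = _example_shared(torch.randn(8, 64, 128))  # -> (4, 8, 6, 128)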
1146
1147class MSRefinementAttention(nn.Module):
1148    """
 1149    Refinement stage built from squeeze-and-excitation refinement modules
1150    """
1151
1152    def __init__(
1153        self,
1154        num_layers_R,
1155        num_R,
1156        num_f_maps_input,
1157        num_f_maps,
1158        num_classes,
1159        dropout_rate,
1160        exclusive,
1161        skip_connections,
1162        len_segment,
1163        block_size=0,
1164    ):
1165        """
1166        Parameters
1167        ----------
1168        num_layers_R : int
1169            number of layers in refinement modules
1170        num_R : int
1171            number of refinement modules
 1172        num_f_maps_input, num_f_maps : int
 1173            number of feature maps in the input and in the refinement modules, respectively
1174        num_classes : int
1175            number of target classes
1176        dropout_rate : float
1177            dropout rate
1178        exclusive : bool
1179            set `False` for multi-label classification
1180        skip_connections : bool
1181            if `True`, skip connections are added
 1182        len_segment : int
 1183            the length of the input segments
 1184        block_size : int, default 0
 1185            if not 0, skip connections are added to the refinement modules with this interval
1188        """
1189
1190        super().__init__()
1191        self.skip_connections = skip_connections
1192        if exclusive:
1193            self.nl = lambda x: F.softmax(x, dim=1)
1194        else:
1195            self.nl = lambda x: torch.sigmoid(x)
1196        refinement_module = Refinement_SE
1197        self.Rs = nn.ModuleList(
1198            [
1199                refinement_module(
1200                    num_layers=num_layers_R,
1201                    num_f_maps=num_f_maps,
1202                    num_f_maps_input=num_f_maps_input,
1203                    dim=num_classes,
1204                    num_classes=num_classes,
1205                    dropout_rate=dropout_rate,
1206                    direction=None,
1207                    skip_connections=skip_connections,
1208                    block_size=block_size,
1209                    len_segment=len_segment,
1210                )
1211                for s in range(num_R)
1212            ]
1213        )
1214        self.conv_out = nn.Conv1d(num_f_maps_input, num_classes, 1)
1215
1216    def forward(self, x, tag=None):
1217        """Forward pass."""
1218        if tag is not None:
1219            tag = tag[0]
1220        out = self.conv_out(x)
1221        outputs = out.unsqueeze(0)
1222        for R in self.Rs:
1223            if self.skip_connections:
1224                out = R(torch.cat([self.nl(out), x], axis=1), tag)
1225                # out = R(torch.cat([out, x], axis=1), tag)
1226            else:
1227                out = R(self.nl(out), tag)
1228                # out = R(out, tag)
1229            outputs = torch.cat((outputs, out.unsqueeze(0)), dim=0)
1230
1231        return outputs
1232
1233
1234class DilatedTCNC(nn.Module):
1235    def __init__(
1236        self,
1237        num_f_maps,
1238        num_layers_PG,
1239        len_segment,
1240        block_size_prediction=5,
1241        kernel_size_prediction=3,
1242        direction_PG=None,
1243    ):
1244        super(DilatedTCNC, self).__init__()
1245        if direction_PG == "bidirectional":
1246            self.PG_S = DilatedTCNB(
1247                num_layers=num_layers_PG,
1248                num_f_maps=num_f_maps,
1249                dim=num_f_maps,
1250                block_size=block_size_prediction,
1251            )
1252            self.PG_T = DilatedTCNB(
1253                num_layers=num_layers_PG,
1254                num_f_maps=len_segment,
1255                dim=len_segment,
1256                block_size=block_size_prediction,
1257            )
1258        else:
1259            self.PG_S = DilatedTCN(
1260                num_layers=num_layers_PG,
1261                num_f_maps=num_f_maps,
1262                dim=num_f_maps,
1263                direction=direction_PG,
1264                block_size=block_size_prediction,
1265                kernel_size=kernel_size_prediction,
1266            )
1267            self.PG_T = DilatedTCN(
1268                num_layers=num_layers_PG,
1269                num_f_maps=len_segment,
1270                dim=len_segment,
1271                direction=direction_PG,
1272                block_size=block_size_prediction,
1273                kernel_size=kernel_size_prediction,
1274            )
1275
1276    def forward(self, x):
1277        """Forward pass."""
1278        x = self.PG_S(x)
1279        x = torch.transpose(x, 1, 2)
1280        x = self.PG_T(x)
1281        x = torch.transpose(x, 1, 2)
1282        return x
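
# Illustrative usage sketch (hypothetical values): DilatedTCNC applies one
# prediction generation stage along the time axis and, after transposing, a
# second one along the feature axis, which is why `len_segment` must equal
# the temporal length of the input.
_example_tcnc = DilatedTCNC(num_f_maps=64, num_layers_PG=5, len_segment=128)
_example_tcnc_out = _example_tcnc(torch.randn(8, 64, 128))  # -> (8, 64, 128)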
class Refinement(torch.nn.modules.module.Module):
 16class Refinement(nn.Module):
 17    """
 18    Refinement module
 19    """
 20
 21    def __init__(
 22        self,
 23        num_layers,
 24        num_f_maps_input,
 25        num_f_maps,
 26        dim,
 27        num_classes,
 28        dropout_rate,
 29        direction,
 30        skip_connections,
 31        attention="none",
 32        block_size=0,
 33    ):
 34        """
 35        Parameters
 36        ----------
 37        num_layers : int
 38            the number of layers
 39        num_f_maps : int
 40            the number of feature maps
 41        dim : int
 42            the number of features in input
 43        num_classes : int
 44            the number of target classes
 45        dropout_rate : float
 46            dropout rate
 47        direction : [None, 'forward', 'backward']
 48            the direction of convolutions; if None, regular convolutions are used
 49        skip_connections : bool
 50            if `True`, skip connections are added
 51        block_size : int, default 0
 52            if not 0, skip connections are added to the prediction generation stage with this interval
 53        """
 54
 55        super(Refinement, self).__init__()
 56        self.block_size = block_size
 57        self.direction = direction
 58        if skip_connections:
 59            self.conv_1x1 = nn.Conv1d(dim + num_f_maps_input, num_f_maps, 1)
 60        else:
 61            self.conv_1x1 = nn.Conv1d(dim, num_f_maps, 1)
 62        self.layers = nn.ModuleList(
 63            [
 64                copy.deepcopy(
 65                    DilatedResidualLayer(
 66                        dilation=2**i,
 67                        in_channels=num_f_maps,
 68                        out_channels=num_f_maps,
 69                        dropout_rate=dropout_rate,
 70                        causal=(direction is not None),
 71                    )
 72                )
 73                for i in range(num_layers)
 74            ]
 75        )
 76        self.conv_out = nn.Conv1d(num_f_maps, num_classes, 1)
 77        self.attention_layers = nn.ModuleList([])
 78        self.attention = attention
 79        if self.attention == "basic":
 80            self.attention_layers += nn.ModuleList(
 81                [
 82                    nn.Conv1d(num_f_maps, num_f_maps, 3, padding=1),
 83                    nn.ReLU(),
 84                    nn.Conv1d(num_f_maps, num_f_maps, 3, padding=1),
 85                    nn.Sigmoid(),
 86                ]
 87            )
 88
 89    def forward(self, x):
 90        """Forward pass."""
 91        x = self.conv_1x1(x)
 92        out = copy.copy(x)
 93        for l_i, layer in enumerate(self.layers):
 94            out = layer(out, self.direction)
 95            if self.block_size != 0 and (l_i + 1) % self.block_size == 0:
 96                out = out + x
 97                x = copy.copy(out)
 98        if self.attention != "none":
 99            x = copy.copy(out)
100            for layer in self.attention_layers:
101                x = layer(x)
102            out = out * x
103        out = self.conv_out(out)
104        return out

Refinement module

Refinement( num_layers, num_f_maps_input, num_f_maps, dim, num_classes, dropout_rate, direction, skip_connections, attention='none', block_size=0)
21    def __init__(
22        self,
23        num_layers,
24        num_f_maps_input,
25        num_f_maps,
26        dim,
27        num_classes,
28        dropout_rate,
29        direction,
30        skip_connections,
31        attention="none",
32        block_size=0,
33    ):
34        """
35        Parameters
36        ----------
37        num_layers : int
38            the number of layers
39        num_f_maps : int
40            the number of feature maps
41        dim : int
42            the number of features in input
43        num_classes : int
44            the number of target classes
45        dropout_rate : float
46            dropout rate
47        direction : [None, 'forward', 'backward']
48            the direction of convolutions; if None, regular convolutions are used
49        skip_connections : bool
50            if `True`, skip connections are added
51        block_size : int, default 0
52            if not 0, skip connections are added to the prediction generation stage with this interval
53        """
54
55        super(Refinement, self).__init__()
56        self.block_size = block_size
57        self.direction = direction
58        if skip_connections:
59            self.conv_1x1 = nn.Conv1d(dim + num_f_maps_input, num_f_maps, 1)
60        else:
61            self.conv_1x1 = nn.Conv1d(dim, num_f_maps, 1)
62        self.layers = nn.ModuleList(
63            [
64                copy.deepcopy(
65                    DilatedResidualLayer(
66                        dilation=2**i,
67                        in_channels=num_f_maps,
68                        out_channels=num_f_maps,
69                        dropout_rate=dropout_rate,
70                        causal=(direction is not None),
71                    )
72                )
73                for i in range(num_layers)
74            ]
75        )
76        self.conv_out = nn.Conv1d(num_f_maps, num_classes, 1)
77        self.attention_layers = nn.ModuleList([])
78        self.attention = attention
79        if self.attention == "basic":
80            self.attention_layers += nn.ModuleList(
81                [
82                    nn.Conv1d(num_f_maps, num_f_maps, 3, padding=1),
83                    nn.ReLU(),
84                    nn.Conv1d(num_f_maps, num_f_maps, 3, padding=1),
85                    nn.Sigmoid(),
86                ]
87            )

Parameters

num_layers : int the number of layers num_f_maps : int the number of feature maps dim : int the number of features in input num_classes : int the number of target classes dropout_rate : float dropout rate direction : [None, 'forward', 'backward'] the direction of convolutions; if None, regular convolutions are used skip_connections : bool if True, skip connections are added block_size : int, default 0 if not 0, skip connections are added to the prediction generation stage with this interval

block_size
direction
layers
conv_out
attention_layers
attention
def forward(self, x):
 89    def forward(self, x):
 90        """Forward pass."""
 91        x = self.conv_1x1(x)
 92        out = copy.copy(x)
 93        for l_i, layer in enumerate(self.layers):
 94            out = layer(out, self.direction)
 95            if self.block_size != 0 and (l_i + 1) % self.block_size == 0:
 96                out = out + x
 97                x = copy.copy(out)
 98        if self.attention != "none":
 99            x = copy.copy(out)
100            for layer in self.attention_layers:
101                x = layer(x)
102            out = out * x
103        out = self.conv_out(out)
104        return out

Forward pass.

class Refinement_SE(torch.nn.modules.module.Module):
107class Refinement_SE(nn.Module):
108    """
109    Refinement module
110    """
111
112    def __init__(
113        self,
114        num_layers,
115        num_f_maps_input,
116        num_f_maps,
117        dim,
118        num_classes,
119        dropout_rate,
120        direction,
121        skip_connections,
122        len_segment,
123        block_size=0,
124    ):
125        """
126        Parameters
127        ----------
128        num_layers : int
129            the number of layers
130        num_f_maps : int
131            the number of feature maps
132        dim : int
133            the number of features in input
134        num_classes : int
135            the number of target classes
136        dropout_rate : float
137            dropout rate
138        direction : [None, 'forward', 'backward']
139            the direction of convolutions; if None, regular convolutions are used
140        skip_connections : bool
141            if `True`, skip connections are added
142        block_size : int, default 0
143            if not 0, skip connections are added to the prediction generation stage with this interval
144        """
145
146        super().__init__()
147        self.block_size = block_size
148        self.direction = direction
149        if skip_connections:
150            self.conv_1x1 = nn.Conv1d(dim + num_f_maps_input, num_f_maps, 1)
151        else:
152            self.conv_1x1 = nn.Conv1d(dim, num_f_maps, 1)
153        self.layers = nn.ModuleList(
154            [
155                DilatedResidualLayer(
156                    dilation=2**i,
157                    in_channels=num_f_maps,
158                    out_channels=num_f_maps,
159                    dropout_rate=dropout_rate,
160                    causal=(direction is not None),
161                )
162                for i in range(num_layers)
163            ]
164        )
165        self.conv_out = nn.Conv1d(num_f_maps, num_classes, 1)
166        self.fc1_f = nn.ModuleList(
167            [nn.Linear(num_f_maps, num_f_maps // 2) for _ in range(6)]
168        )
169        self.fc2_f = nn.ModuleList(
170            [nn.Linear(num_f_maps // 2, num_f_maps) for _ in range(6)]
171        )
172
173    def _fc1_f(self, tag):
174        if tag is None:
175            for i in range(1, len(self.fc1_f)):
176                self.fc1_f[i].load_state_dict(self.fc1_f[0].state_dict())
177            return self.fc1_f[0]
178        else:
179            return self.fc1_f[tag]
180
181    def _fc2_f(self, tag):
182        if tag is None:
183            for i in range(1, len(self.fc2_f)):
184                self.fc2_f[i].load_state_dict(self.fc2_f[0].state_dict())
185            return self.fc2_f[0]
186        else:
187            return self.fc2_f[tag]
188
189    def forward(self, x, tag):
190        """Forward pass."""
191        x = self.conv_1x1(x)
192        scale = torch.mean(x, -1)
193        scale = self._fc2_f(tag)(F.relu(self._fc1_f(tag)(scale)))
194        scale = F.sigmoid(scale).unsqueeze(-1)
195        x = scale * x
196        out = copy.copy(x)
197        for l_i, layer in enumerate(self.layers):
198            out = layer(out, self.direction)
199            if self.block_size != 0 and (l_i + 1) % self.block_size == 0:
200                out = out + x
201                x = copy.copy(out)
202        out = self.conv_out(out)
203        return out

Refinement module

Refinement_SE( num_layers, num_f_maps_input, num_f_maps, dim, num_classes, dropout_rate, direction, skip_connections, len_segment, block_size=0)

Parameters

num_layers : int
    the number of layers
num_f_maps_input : int
    the number of channels in the feature input (used when `skip_connections` is `True`)
num_f_maps : int
    the number of feature maps
dim : int
    the number of features in input
num_classes : int
    the number of target classes
dropout_rate : float
    dropout rate
direction : [None, 'forward', 'backward']
    the direction of convolutions; if None, regular convolutions are used
skip_connections : bool
    if `True`, skip connections are added
len_segment : int
    the segment length (unused in the current implementation)
block_size : int, default 0
    if not 0, skip connections are added to the prediction generation stage with this interval

block_size
direction
layers
conv_out
fc1_f
fc2_f
def forward(self, x, tag):

Forward pass.
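
A minimal usage sketch for Refinement_SE (all shapes and hyperparameter values below are illustrative assumptions, not library defaults): the input has `dim` channels, and the `tag` argument selects one of the six squeeze-and-excitation heads, with `tag=None` tying them together.

    import torch
    from dlc2action.model.ms_tcn_modules import Refinement_SE

    # illustrative values: 2 sequences, 6 classes, 128 frames, 64 feature maps
    ref = Refinement_SE(
        num_layers=4, num_f_maps_input=64, num_f_maps=64, dim=6, num_classes=6,
        dropout_rate=0.5, direction=None, skip_connections=False, len_segment=128,
    )
    x = torch.randn(2, 6, 128)   # (batch, dim, frames)
    out = ref(x, tag=None)       # tag=None ties the SE heads; an int in [0, 5] picks one
    print(out.shape)             # torch.Size([2, 6, 128])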

class RefinementB(Refinement):
206class RefinementB(Refinement):
207    """
208    Bidirectional refinement module
209    """
210
211    def forward(self, x):
212        """Forward pass."""
213        x_f = self.conv_1x1(x)
214        x_b = copy.copy(x_f)
215        forward = copy.copy(x_f)
216        backward = copy.copy(x_f)
217        for i, layer_f in enumerate(self.layers):
218            forward = layer_f(forward, "forward")
219            backward = layer_f(backward, "backward")
220            if self.block_size != 0 and (i + 1) % self.block_size == 0:
221                forward = forward + x_f
222                backward = backward + x_b
223                x_f = copy.copy(forward)
224                x_b = copy.copy(backward)
225        out = torch.cat([forward, backward], 1)
226        out = self.conv_out(out)
227        return out

Bidirectional refinement module

def forward(self, x):

Forward pass.

class SimpleResidualLayer(torch.nn.modules.module.Module):
230class SimpleResidualLayer(nn.Module):
231    """
232    Basic residual layer
233    """
234
235    def __init__(self, num_f_maps, dropout_rate):
236        """
237        Parameters
238        ----------
239        num_f_maps : int
240            number of input and output channels
243        dropout_rate : float
244            dropout rate
245        """
246
247        super().__init__()
248        self.conv_1x1_in = nn.Conv1d(num_f_maps, num_f_maps, 1)
249        self.conv_1x1 = nn.Conv1d(num_f_maps, num_f_maps, 1)
250        self.dropout = nn.Dropout(dropout_rate)
251
252    def forward(self, x):
253        """Forward pass."""
254        out = self.conv_1x1_in(x)
255        out = F.relu(out)
256        out = self.conv_1x1(out)
257        out = self.dropout(out)
258        return x + out

Basic residual layer

SimpleResidualLayer(num_f_maps, dropout_rate)

Parameters

num_f_maps : int
    the number of input and output channels
dropout_rate : float
    dropout rate

conv_1x1_in
conv_1x1
dropout
def forward(self, x):

Forward pass.
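
A minimal sketch of how SimpleResidualLayer is applied (channel count and length are illustrative): the layer keeps both the channel dimension and the temporal length unchanged.

    import torch
    from dlc2action.model.ms_tcn_modules import SimpleResidualLayer

    layer = SimpleResidualLayer(num_f_maps=64, dropout_rate=0.5)
    x = torch.randn(2, 64, 128)   # (batch, channels, frames)
    print(layer(x).shape)         # torch.Size([2, 64, 128])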

class DilatedResidualLayer(torch.nn.modules.module.Module):
261class DilatedResidualLayer(nn.Module):
262    """
263    Dilated residual layer
264    """
265
266    def __init__(self, dilation, in_channels, out_channels, dropout_rate, causal):
267        """
268        Parameters
269        ----------
270        dilation : int
271            dilation
272        in_channels : int
273            number of input channels
274        out_channels : int
275            number of output channels
276        dropout_rate : float
277            dropout rate
278        causal : bool
279            if `True`, causal convolutions are used
280        """
281
282        super(DilatedResidualLayer, self).__init__()
283        self.padding = dilation * 2
284        self.causal = causal
285        if self.causal:
286            padding = 0
287        else:
288            padding = dilation
289        self.conv_dilated = nn.Conv1d(
290            in_channels, out_channels, 3, padding=padding, dilation=dilation
291        )
292        self.conv_1x1 = nn.Conv1d(out_channels, out_channels, 1)
293        self.dropout = nn.Dropout(dropout_rate)
294
295    def forward(self, x, direction=None):
296        """Forward pass."""
297        if direction is not None and not self.causal:
298            raise ValueError("Cannot set direction in a non-causal layer!")
299        elif direction is None and self.causal:
300            direction = "forward"
301        if direction == "forward":
302            padding = (0, self.padding)
303        elif direction == "backward":
304            padding = (self.padding, 0)
305        elif direction is not None:
306            raise ValueError(
307                f"Unrecognized direction: {direction}, please choose from"
308                f'"backward", "forward" and None'
309            )
310        if direction is not None:
311            out = self.conv_dilated(F.pad(x, padding))
312        else:
313            out = self.conv_dilated(x)
314        out = F.relu(out)
315        out = self.conv_1x1(out)
316        out = self.dropout(out)
317        return x + out

Dilated residual layer

DilatedResidualLayer(dilation, in_channels, out_channels, dropout_rate, causal)

Parameters

dilation : int
    dilation
in_channels : int
    number of input channels
out_channels : int
    number of output channels
dropout_rate : float
    dropout rate
causal : bool
    if `True`, causal convolutions are used

padding
causal
conv_dilated
conv_1x1
dropout
def forward(self, x, direction=None):

Forward pass.
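
A minimal sketch contrasting the causal and non-causal modes (values are illustrative): a causal layer pads the sequence on one side according to `direction`, so the temporal length is preserved in both cases.

    import torch
    from dlc2action.model.ms_tcn_modules import DilatedResidualLayer

    x = torch.randn(2, 64, 128)  # (batch, channels, frames)

    causal = DilatedResidualLayer(dilation=2, in_channels=64, out_channels=64,
                                  dropout_rate=0.5, causal=True)
    out = causal(x, direction="forward")   # "backward" also works; None defaults to "forward"
    print(out.shape)                       # torch.Size([2, 64, 128])

    regular = DilatedResidualLayer(dilation=2, in_channels=64, out_channels=64,
                                   dropout_rate=0.5, causal=False)
    print(regular(x).shape)                # torch.Size([2, 64, 128]); passing a direction here raises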

class DilatedResidualLayer_SE(torch.nn.modules.module.Module):
320class DilatedResidualLayer_SE(nn.Module):
321    """
322    Dilated residual layer with squeeze-and-excitation channel gating
323    """
324
325    def __init__(
326        self, dilation, in_channels, out_channels, dropout_rate, causal, len_segment
327    ):
328        """
329        Parameters
330        ----------
331        dilation : int
332            dilation
333        in_channels : int
334            number of input channels
335        out_channels : int
336            number of output channels
337        dropout_rate : float
338            dropout rate
339        causal : bool
340            if `True`, causal convolutions are used
341        """
342
343        super().__init__()
344        self.padding = dilation * 2
345        self.causal = causal
346        if self.causal:
347            padding = 0
348        else:
349            padding = dilation
350        self.conv_dilated = nn.Conv1d(
351            in_channels, out_channels, 3, padding=padding, dilation=dilation
352        )
353        self.conv_1x1 = nn.Conv1d(out_channels, out_channels, 1)
354        self.dropout = nn.Dropout(dropout_rate)
355        self.fc1_f = nn.ModuleList(
356            [nn.Linear(out_channels, out_channels // 2) for _ in range(6)]
357        )
358        self.fc2_f = nn.ModuleList(
359            [nn.Linear(out_channels // 2, out_channels) for _ in range(6)]
360        )
361        # self.fc1_t = nn.ModuleList([nn.Linear(len_segment, len_segment // 2) for _ in range(6)])
362        # self.fc2_t = nn.ModuleList([nn.Linear(len_segment // 2, len_segment) for _ in range(6)])
363
364    def _fc1_f(self, tag):
365        if tag is None:
366            for i in range(1, len(self.fc1_f)):
367                self.fc1_f[i].load_state_dict(self.fc1_f[0].state_dict())
368            return self.fc1_f[0]
369        else:
370            return self.fc1_f[tag]
371
372    def _fc2_f(self, tag):
373        if tag is None:
374            for i in range(1, len(self.fc2_f)):
375                self.fc2_f[i].load_state_dict(self.fc2_f[0].state_dict())
376            return self.fc2_f[0]
377        else:
378            return self.fc2_f[tag]
379
380    def _fc1_t(self, tag):
381        if tag is None:
382            for i in range(1, len(self.fc1_t)):
383                self.fc1_t[i].load_state_dict(self.fc1_t[0].state_dict())
384            return self.fc1_t[0]
385        else:
386            return self.fc1_t[tag]
387
388    def _fc2_t(self, tag):
389        if tag is None:
390            for i in range(1, len(self.fc2_t)):
391                self.fc2_t[i].load_state_dict(self.fc2_t[0].state_dict())
392            return self.fc2_t[0]
393        else:
394            return self.fc2_t[tag]
395
396    def forward(self, x, direction, tag):
397        """Forward pass."""
398        if direction is not None and not self.causal:
399            raise ValueError("Cannot set direction in a non-causal layer!")
400        elif direction is None and self.causal:
401            direction = "forward"
402        if direction == "forward":
403            padding = (0, self.padding)
404        elif direction == "backward":
405            padding = (self.padding, 0)
406        elif direction is not None:
407            raise ValueError(
408                f"Unrecognized direction: {direction}, please choose from"
409                f'"backward", "forward" and None'
410            )
411        if direction is not None:
412            out = self.conv_dilated(F.pad(x, padding))
413        else:
414            out = self.conv_dilated(x)
415        out = F.relu(out)
416        out = self.conv_1x1(out)
417        out = self.dropout(out)
418        scale = torch.mean(out, -1)
419        scale = self._fc2_f(tag)(F.relu(self._fc1_f(tag)(scale)))
420        scale = torch.sigmoid(scale).unsqueeze(-1)
421        # time_scale = torch.mean(out, 1)
422        # time_scale = self._fc2_t(tag)(F.relu(self._fc1_t(tag)(time_scale)))
423        # time_scale = F.sigmoid(time_scale).unsqueeze(1)
424        out = out * scale  # * time_scale
425        return x + out

Dilated residual layer with squeeze-and-excitation channel gating

DilatedResidualLayer_SE( dilation, in_channels, out_channels, dropout_rate, causal, len_segment)

Parameters

dilation : int
    dilation
in_channels : int
    number of input channels
out_channels : int
    number of output channels
dropout_rate : float
    dropout rate
causal : bool
    if `True`, causal convolutions are used
len_segment : int
    the segment length (only referenced by the commented-out temporal gate)

padding
causal
conv_dilated
conv_1x1
dropout
fc1_f
fc2_f
def forward(self, x, direction, tag):

Forward pass.
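
A minimal sketch (illustrative values) showing the extra `tag` argument, which selects one of the six squeeze-and-excitation heads; `tag=None` keeps all six heads synchronized by copying the weights of the first one.

    import torch
    from dlc2action.model.ms_tcn_modules import DilatedResidualLayer_SE

    layer = DilatedResidualLayer_SE(dilation=4, in_channels=64, out_channels=64,
                                    dropout_rate=0.5, causal=False, len_segment=128)
    x = torch.randn(2, 64, 128)
    out = layer(x, direction=None, tag=2)   # non-causal layers require direction=None
    print(out.shape)                        # torch.Size([2, 64, 128])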

class DilatedResidualLayer_SEC(torch.nn.modules.module.Module):
428class DilatedResidualLayer_SEC(nn.Module):
429    """
430    Dilated residual layer with a convolutional squeeze-and-excitation gate
431    """
432
433    def __init__(
434        self, dilation, in_channels, out_channels, dropout_rate, causal, len_segment
435    ):
436        """
437        Parameters
438        ----------
439        dilation : int
440            dilation
441        in_channels : int
442            number of input channels
443        out_channels : int
444            number of output channels
445        dropout_rate : float
446            dropout rate
447        causal : bool
448            if `True`, causal convolutions are used
449        """
450
451        super().__init__()
452        self.padding = dilation * 2
453        self.causal = causal
454        if self.causal:
455            padding = 0
456        else:
457            padding = dilation
458        self.conv_dilated = nn.Conv1d(
459            in_channels, out_channels, 3, padding=padding, dilation=dilation
460        )
461        self.conv_1x1 = nn.Conv1d(out_channels, out_channels, 1)
462        self.dropout = nn.Dropout(dropout_rate)
463        self.conv_sq = nn.ModuleList(
464            [
465                nn.Conv1d(out_channels, 1, 3, padding=padding, dilation=dilation)
466                for _ in range(6)
467            ]
468        )
469
470    def _conv_sq(self, tag):
471        if tag is None:
472            for i in range(1, len(self.conv_sq)):
473                self.conv_sq[i].load_state_dict(self.conv_sq[0].state_dict())
474            return self.conv_sq[0]
475        else:
476            return self.conv_sq[tag]
477
478    def forward(self, x, direction, tag):
479        """Forward pass."""
480        if direction is not None and not self.causal:
481            raise ValueError("Cannot set direction in a non-causal layer!")
482        elif direction is None and self.causal:
483            direction = "forward"
484        if direction == "forward":
485            padding = (0, self.padding)
486        elif direction == "backward":
487            padding = (self.padding, 0)
488        elif direction is not None:
489            raise ValueError(
490                f"Unrecognized direction: {direction}, please choose from"
491                f'"backward", "forward" and None'
492            )
493        if direction is not None:
494            out = self.conv_dilated(F.pad(x, padding))
495        else:
496            out = self.conv_dilated(x)
497        out = F.relu(out)
498        out = self.conv_1x1(out)
499        out = self.dropout(out)
500        scale = torch.sigmoid(self._conv_sq(tag)(out))
501        out = out * scale
502        return x + out

Dilated residual layer with a convolutional squeeze-and-excitation gate

DilatedResidualLayer_SEC( dilation, in_channels, out_channels, dropout_rate, causal, len_segment)

Parameters

dilation : int
    dilation
in_channels : int
    number of input channels
out_channels : int
    number of output channels
dropout_rate : float
    dropout rate
causal : bool
    if `True`, causal convolutions are used
len_segment : int
    the segment length (unused in the current implementation)

padding
causal
conv_dilated
conv_1x1
dropout
conv_sq
def forward(self, x, direction, tag):

Forward pass.
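
A minimal sketch (illustrative values) for the convolutional gating variant; here the gate is a single-channel convolution over time rather than a fully connected squeeze-and-excitation block.

    import torch
    from dlc2action.model.ms_tcn_modules import DilatedResidualLayer_SEC

    layer = DilatedResidualLayer_SEC(dilation=2, in_channels=64, out_channels=64,
                                     dropout_rate=0.5, causal=False, len_segment=128)
    x = torch.randn(2, 64, 128)
    print(layer(x, direction=None, tag=0).shape)   # torch.Size([2, 64, 128])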

class DualDilatedResidualLayer(torch.nn.modules.module.Module):
505class DualDilatedResidualLayer(nn.Module):
506    """
507    Dual dilated residual layer
508    """
509
510    def __init__(
511        self,
512        dilation1,
513        dilation2,
514        in_channels,
515        out_channels,
516        dropout_rate,
517        causal,
518        kernel_size=3,
519    ):
520        """
521        Parameters
522        ----------
523        dilation1, dilation2 : int
524            the dilations of the two parallel convolution branches
525        in_channels : int
526            number of input channels
527        out_channels : int
528            number of output channels
529        dropout_rate : float
530            dropout rate
531        causal : bool
532            if `True`, causal convolutions are used
533        kernel_size : int, default 3
534            kernel size
535        """
536
537        super(DualDilatedResidualLayer, self).__init__()
538        self.causal = causal
539        self.padding1 = dilation1 * (kernel_size - 1) // 2
540        self.padding2 = dilation2 * (kernel_size - 1) // 2
541        if self.causal:
542            self.padding1 *= 2
543            self.padding2 *= 2
544        self.conv_dilated_1 = nn.Conv1d(
545            in_channels,
546            out_channels,
547            kernel_size=kernel_size,
548            padding=self.padding1,
549            dilation=dilation1,
550        )
551        self.conv_dilated_2 = nn.Conv1d(
552            in_channels,
553            out_channels,
554            kernel_size=kernel_size,
555            padding=self.padding2,
556            dilation=dilation2,
557        )
558        self.conv_fusion = nn.Conv1d(2 * out_channels, out_channels, 1)
559        self.dropout = nn.Dropout(dropout_rate)
560
561    def forward(self, x, direction=None):
562        """Forward pass."""
563        if direction is not None and not self.causal:
564            raise ValueError("Cannot set direction in a non-causal layer!")
565        elif direction is None and self.causal:
566            direction = "forward"
567        if direction not in ["backward", "forward", None]:
568            raise ValueError(
569                f"Unrecognized direction: {direction}, please choose from"
570                f'"backward", "forward" and None'
571            )
572        out_1 = self.conv_dilated_1(x)
573        out_2 = self.conv_dilated_2(x)
574        if direction == "forward":
575            out_1 = out_1[:, :, : -self.padding1]
576            out_2 = out_2[:, :, : -self.padding2]
577        elif direction == "backward":
578            out_1 = out_1[:, :, self.padding1 :]
579            out_2 = out_2[:, :, self.padding2 :]
580
581        out = self.conv_fusion(
582            torch.cat(
583                [
584                    out_1,
585                    out_2,
586                ],
587                1,
588            )
589        )
590        out = F.relu(out)
591        out = self.dropout(out)
592        out = out + x
593        return out

Dual dilated residual layer

DualDilatedResidualLayer( dilation1, dilation2, in_channels, out_channels, dropout_rate, causal, kernel_size=3)

Parameters

dilation1, dilation2 : int
    the dilations of the two parallel convolution branches
in_channels : int
    number of input channels
out_channels : int
    number of output channels
dropout_rate : float
    dropout rate
causal : bool
    if `True`, causal convolutions are used
kernel_size : int, default 3
    kernel size

causal
padding1
padding2
conv_dilated_1
conv_dilated_2
conv_fusion
dropout
def forward(self, x, direction=None):

Forward pass.
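
A minimal sketch (illustrative values): two convolutions with different dilations are fused by a 1x1 convolution, and the residual connection keeps the shape unchanged.

    import torch
    from dlc2action.model.ms_tcn_modules import DualDilatedResidualLayer

    layer = DualDilatedResidualLayer(dilation1=8, dilation2=1, in_channels=64,
                                     out_channels=64, dropout_rate=0.5, causal=False)
    x = torch.randn(2, 64, 128)
    print(layer(x).shape)   # torch.Size([2, 64, 128])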

class MultiDilatedTCN(torch.nn.modules.module.Module):
596class MultiDilatedTCN(nn.Module):
597    """
598    Multiple prediction generation stages in parallel
599    """
600
601    def __init__(
602        self,
603        num_layers,
604        num_f_maps,
605        dims,
606        direction,
607        block_size=5,
608        kernel_size=3,
609        rare_dilations=False,
610    ):
611        super(MultiDilatedTCN, self).__init__()
612        self.PGs = nn.ModuleList(
613            [
614                DilatedTCN(
615                    num_layers,
616                    num_f_maps,
617                    dim,
618                    direction,
619                    block_size=block_size,
620                    kernel_size=kernel_size,
621                    rare_dilations=rare_dilations,
622                )
623                for dim in dims
624            ]
625        )
626        self.conv_out = nn.Conv1d(num_f_maps * len(dims), num_f_maps, 1)
627
628    def forward(self, x, tag=None):
629        """Forward pass."""
630        out = []
631        for arr, PG in zip(x, self.PGs):
632            out.append(PG(arr))
633        out = torch.cat(out, 1)
634        out = self.conv_out(out)
635        return out

Multiple prediction generation stages in parallel

MultiDilatedTCN( num_layers, num_f_maps, dims, direction, block_size=5, kernel_size=3, rare_dilations=False)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

PGs
conv_out
def forward(self, x, tag=None):

Forward pass.
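
A minimal sketch (illustrative values), assuming the keyword-argument call in the constructor above: the input is a list with one tensor per entry of `dims`, and the per-branch outputs are concatenated and projected back to `num_f_maps` channels.

    import torch
    from dlc2action.model.ms_tcn_modules import MultiDilatedTCN

    mtcn = MultiDilatedTCN(num_layers=6, num_f_maps=64, dims=[20, 30], direction=None)
    x = [torch.randn(2, 20, 128), torch.randn(2, 30, 128)]   # one tensor per input group
    print(mtcn(x).shape)                                     # torch.Size([2, 64, 128])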

class DilatedTCN(torch.nn.modules.module.Module):
638class DilatedTCN(nn.Module):
639    """
640    Prediction generation stage
641    """
642
643    def __init__(
644        self,
645        num_layers,
646        num_f_maps,
647        dim,
648        direction,
649        num_bp=None,
650        block_size=5,
651        kernel_size=3,
652        rare_dilations=False,
653        attention="none",
654        multihead=False,
655    ):
656        """
657        Parameters
658        ----------
659        num_layers : int
660            number of layers
661        num_f_maps : int
662            number of feature maps
663        dim : int
664            number of features in input
665        direction : [None, 'forward', 'backward']
666            the direction of convolutions; if None, regular convolutions are used
667        block_size : int, default 5
668            if not 0, skip connections are added to the prediction generation stage with this interval
669        kernel_size : int, default 3
670            kernel size
671        rare_dilations : bool, default False
672            if `False`, dilation increases every layer, otherwise every second layer
673        """
674
675        super().__init__()
676        self.num_layers = num_layers
677        self.block_size = block_size
678        self.direction = direction
679        self.conv_1x1_in = nn.Conv1d(dim, num_f_maps, 1)
680        pars = {
681            "in_channels": num_f_maps,
682            "out_channels": num_f_maps,
683            "dropout_rate": 0.5,
684            "causal": (direction is not None),
685            "kernel_size": kernel_size,
686        }
687        module = DualDilatedResidualLayer
688        if not rare_dilations:
689            self.layers = nn.ModuleList(
690                [
691                    module(dilation1=2 ** (num_layers - 1 - i), dilation2=2**i, **pars)
692                    for i in range(num_layers)
693                ]
694            )
695        else:
696            self.layers = nn.ModuleList(
697                [
698                    module(
699                        dilation1=2 ** (num_layers // 2 - 1 - i // 2),
700                        dilation2=2 ** (i // 2),
701                        **pars,
702                    )
703                    for i in range(num_layers)
704                ]
705            )
706        self.attention_layers = nn.ModuleList([])
707        self.attention = attention
708        if self.attention in ["basic"]:
709            self.attention_layers += nn.ModuleList(
710                [
711                    nn.Conv1d(num_f_maps, num_f_maps, 3, padding=1),
712                    nn.ReLU(),
713                    nn.Conv1d(num_f_maps, num_f_maps, 3, padding=1),
714                    nn.Sigmoid(),
715                ]
716            )
717        elif isinstance(self.attention, str) and self.attention != "none":
718            self.attention_layers = AttLayer(
719                num_f_maps,
720                num_f_maps,
721                num_f_maps,
722                4,
723                4,
724                2,
725                64,
726                att_type=self.attention,
727                stage="encoder",
728            )
729        self.multihead = multihead
730
731    def forward(self, x, tag=None):
732        """Forward pass."""
733        x = self.conv_1x1_in(x)
734        f = copy.copy(x)
735        for i, layer in enumerate(self.layers):
736            f = layer(f, self.direction)
737            if self.block_size != 0 and (i + 1) % self.block_size == 0:
738                f = f + x
739                x = copy.copy(f)
740        if isinstance(self.attention, str) and self.attention != "none":
741            x = copy.copy(f)
742            if self.attention != "basic":
743                f = self.attention_layers(x)
744            elif not self.multihead:
745                for layer in self.attention_layers:
746                    x = layer(x)
747                f = f * x
748            else:
749                outputs = []
750                for layers in self.attention_layers:
751                    y = copy.copy(x)
752                    for layer in layers:
753                        y = layer(y)
754                    outputs.append(copy.copy(f * y))
755                outputs = torch.cat(outputs, dim=1)
756                f = self.conv_att(self.dropout(outputs))
757        return f

Prediction generation stage

DilatedTCN( num_layers, num_f_maps, dim, direction, num_bp=None, block_size=5, kernel_size=3, rare_dilations=False, attention='none', multihead=False)

Parameters

num_layers : int
    number of layers
num_f_maps : int
    number of feature maps
dim : int
    number of features in input
direction : [None, 'forward', 'backward']
    the direction of convolutions; if None, regular convolutions are used
num_bp : int, optional
    unused in the current implementation
block_size : int, default 5
    if not 0, skip connections are added to the prediction generation stage with this interval
kernel_size : int, default 3
    kernel size
rare_dilations : bool, default False
    if `False`, dilation increases every layer, otherwise every second layer
attention : str, default 'none'
    the attention type; 'basic' adds a convolutional gate, other non-'none' values use an asformer `AttLayer`
multihead : bool, default False
    if `True`, several 'basic' attention heads are combined

num_layers
block_size
direction
conv_1x1_in
attention_layers
attention
multihead
def forward(self, x, tag=None):

Forward pass.
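
A minimal sketch (illustrative values) of the prediction generation stage without attention; the output has `num_f_maps` channels and the same temporal length as the input.

    import torch
    from dlc2action.model.ms_tcn_modules import DilatedTCN

    pg = DilatedTCN(num_layers=8, num_f_maps=64, dim=30, direction=None)
    x = torch.randn(2, 30, 512)   # (batch, input features, frames)
    print(pg(x).shape)            # torch.Size([2, 64, 512])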

class DilatedTCNB(torch.nn.modules.module.Module):
760class DilatedTCNB(nn.Module):
761    """
762    Bidirectional prediction generation stage
763    """
764
765    def __init__(
766        self,
767        num_layers,
768        num_f_maps,
769        dim,
770        block_size=5,
771        kernel_size=3,
772        rare_dilations=False,
773    ):
774        """
775        Parameters
776        ----------
777        num_layers : int
778            number of layers
779        num_f_maps : int
780            number of feature maps
781        dim : int
782            number of features in input
783        block_size : int, default 5
784            if not 0, skip connections are added to the prediction generation stage with this interval
785        kernel_size : int, default 3
786            kernel size
787        rare_dilations : bool, default False
788            if `False`, dilation increases every layer, otherwise every second layer
789        """
790
791        super().__init__()
792        self.num_layers = num_layers
793        self.block_size = block_size
794        self.conv_1x1_in = nn.Conv1d(dim, num_f_maps, 1)
795        self.conv_1x1_out = nn.Conv1d(num_f_maps * 2, num_f_maps, 1)
796        if not rare_dilations:
797            self.layers = nn.ModuleList(
798                [
799                    DualDilatedResidualLayer(
800                        dilation1=2 ** (num_layers - 1 - i),
801                        dilation2=2**i,
802                        in_channels=num_f_maps,
803                        out_channels=num_f_maps,
804                        dropout_rate=0.5,
805                        causal=True,
806                        kernel_size=kernel_size,
807                    )
808                    for i in range(num_layers)
809                ]
810            )
811        else:
812            self.layers = nn.ModuleList(
813                [
814                    DualDilatedResidualLayer(
815                        dilation1=2 ** (num_layers // 2 - 1 - i // 2),
816                        dilation2=2 ** (i // 2),
817                        in_channels=num_f_maps,
818                        out_channels=num_f_maps,
819                        dropout_rate=0.5,
820                        causal=True,
821                        kernel_size=kernel_size,
822                    )
823                    for i in range(num_layers)
824                ]
825            )
826
827    def forward(self, x, tag=None):
828        """Forward pass."""
829        x_f = self.conv_1x1_in(x)
830        x_b = copy.copy(x_f)
831        forward = copy.copy(x_f)
832        backward = copy.copy(x_f)
833        for i, layer_f in enumerate(self.layers):
834            forward = layer_f(forward, "forward")
835            backward = layer_f(backward, "backward")
836            if self.block_size != 0 and (i + 1) % self.block_size == 0:
837                forward = forward + x_f
838                backward = backward + x_b
839                x_f = copy.copy(forward)
840                x_b = copy.copy(backward)
841        out = torch.cat([forward, backward], 1)
842        out = self.conv_1x1_out(out)
843        return out

Bidirectional prediction generation stage

DilatedTCNB( num_layers, num_f_maps, dim, block_size=5, kernel_size=3, rare_dilations=False)

Parameters

num_layers : int
    number of layers
num_f_maps : int
    number of feature maps
dim : int
    number of features in input
block_size : int, default 5
    if not 0, skip connections are added to the prediction generation stage with this interval
kernel_size : int, default 3
    kernel size
rare_dilations : bool, default False
    if `False`, dilation increases every layer, otherwise every second layer

num_layers
block_size
conv_1x1_in
conv_1x1_out
def forward(self, x, tag=None):

Forward pass.
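
A minimal sketch (illustrative values): the bidirectional stage runs each layer once with forward and once with backward causal convolutions and fuses the two streams with a 1x1 convolution.

    import torch
    from dlc2action.model.ms_tcn_modules import DilatedTCNB

    pg = DilatedTCNB(num_layers=8, num_f_maps=64, dim=30)
    x = torch.randn(2, 30, 512)
    print(pg(x).shape)   # torch.Size([2, 64, 512])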

class SpatialFeatures(torch.nn.modules.module.Module):
846class SpatialFeatures(nn.Module):
847    """
848    Spatial features extraction stage
849    """
850
851    def __init__(
852        self,
853        num_layers,
854        num_f_maps,
855        dim,
856        block_size=5,
857        graph_edges=None,
858        num_nodes=None,
859        denom: int = 8,
860    ):
861        """
862        Parameters
863        ----------
864        num_layers : int
865            number of layers
866        num_f_maps : int
867            number of feature maps
868        dim : int
869            number of features in input
870        block_size : int, default 5
871            if not 0, skip connections are added to the prediction generation stage with this interval
872        """
873
874        super().__init__()
875        self.num_nodes = num_nodes
876        if graph_edges is None:
877            module = SimpleResidualLayer
878            self.graph = False
879            pars = {"num_f_maps": num_f_maps, "dropout_rate": 0.5}
880            self.conv_1x1_in = nn.Conv1d(dim, num_f_maps, 1)
881        else:
882            raise NotImplementedError("Graph not implemented")
883        
884        self.num_layers = num_layers
885        self.block_size = block_size
886        self.layers = nn.ModuleList([module(**pars) for _ in range(num_layers)])
887
888    def forward(self, x):
889        """Forward pass."""
890        if self.graph:
891            B, _, L = x.shape
892            x = x.transpose(-1, -2)
893            x = x.reshape((-1, x.shape[-1]))
894            x = x.reshape((x.shape[0], self.num_nodes, -1))
895            x = x.transpose(-1, -2)
896        x = self.conv_1x1_in(x)
897        f = copy.copy(x)
898        for i, layer in enumerate(self.layers):
899            f = layer(f)
900            if self.block_size != 0 and (i + 1) % self.block_size == 0:
901                f = f + x
902                x = copy.copy(f)
903        if self.graph:
904            f = f.reshape((B, L, -1))
905            f = f.transpose(-1, -2)
906            f = self.conv_1x1_out(f)
907        return f

Spatial features extraction stage

SpatialFeatures( num_layers, num_f_maps, dim, block_size=5, graph_edges=None, num_nodes=None, denom: int = 8)

Parameters

num_layers : int
    number of layers
num_f_maps : int
    number of feature maps
dim : int
    number of features in input
block_size : int, default 5
    if not 0, skip connections are added to the prediction generation stage with this interval
graph_edges : optional, default None
    graph edges; only `None` is supported (the graph branch raises `NotImplementedError`)
num_nodes : int, optional
    number of graph nodes (only relevant for the unimplemented graph branch)
denom : int, default 8
    unused in the current implementation

num_nodes
num_layers
block_size
layers
def forward(self, x):

Forward pass.
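
A minimal sketch (illustrative values) with the default `graph_edges=None`, i.e. the plain residual branch; the graph branch is not implemented.

    import torch
    from dlc2action.model.ms_tcn_modules import SpatialFeatures

    sf = SpatialFeatures(num_layers=4, num_f_maps=64, dim=30)
    x = torch.randn(2, 30, 128)
    print(sf(x).shape)   # torch.Size([2, 64, 128])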

class MSRefinement(torch.nn.modules.module.Module):
 910class MSRefinement(nn.Module):
 911    """
 912    Refinement stage
 913    """
 914
 915    def __init__(
 916        self,
 917        num_layers_R,
 918        num_R,
 919        num_f_maps_input,
 920        num_f_maps,
 921        num_classes,
 922        dropout_rate,
 923        exclusive,
 924        skip_connections,
 925        direction,
 926        block_size=0,
 927        num_heads=1,
 928        attention="none",
 929    ):
 930        """
 931        Parameters
 932        ----------
 933        num_layers_R : int
 934            number of layers in refinement modules
 935        num_R : int
 936            number of refinement modules
 937        num_f_maps : int
 938            number of feature maps
 939        num_classes : int
 940            number of target classes
 941        dropout_rate : float
 942            dropout rate
 943        exclusive : bool
 944            set `False` for multi-label classification
 945        skip_connections : bool
 946            if `True`, skip connections are added
 947        direction : [None, 'bidirectional', 'forward', 'backward']
 948            the direction of convolutions; if None, regular convolutions are used
 949        block_size : int, default 0
 950            if not 0, skip connections are added to the prediction generation stage with this interval
 951        num_heads : int, default 1
 952            number of parallel refinement stages
 953        """
 954
 955        super().__init__()
 956        self.skip_connections = skip_connections
 957        self.num_heads = num_heads
 958        if exclusive:
 959            self.nl = lambda x: F.softmax(x, dim=1)
 960        else:
 961            self.nl = lambda x: torch.sigmoid(x)
 962        if direction == "bidirectional":
 963            refinement_module = RefinementB
 964        else:
 965            refinement_module = Refinement
 966        self.Rs = nn.ModuleList(
 967            [
 968                nn.ModuleList(
 969                    [
 970                        refinement_module(
 971                            num_layers=num_layers_R,
 972                            num_f_maps=num_f_maps,
 973                            num_f_maps_input=num_f_maps_input,
 974                            dim=num_classes,
 975                            num_classes=num_classes,
 976                            dropout_rate=dropout_rate,
 977                            direction=direction,
 978                            skip_connections=skip_connections,
 979                            block_size=block_size,
 980                            attention=attention,
 981                        )
 982                        for s in range(num_R)
 983                    ]
 984                )
 985                for _ in range(self.num_heads)
 986            ]
 987        )
 988        self.conv_out = nn.ModuleList(
 989            [nn.Conv1d(num_f_maps_input, num_classes, 1) for _ in range(self.num_heads)]
 990        )
 991        if self.num_heads == 1:
 992            self.Rs = self.Rs[0]
 993            self.conv_out = self.conv_out[0]
 994
 995    def _Rs(self, tag):
 996        if self.num_heads == 1:
 997            return self.Rs
 998        if tag is None:
 999            tag = 0
1000            for i in range(1, self.num_heads):
1001                self.Rs[i].load_state_dict(self.Rs[0].state_dict())
1002        return self.Rs[tag]
1003
1004    def _conv_out(self, tag):
1005        if self.num_heads == 1:
1006            return self.conv_out
1007        if tag is None:
1008            tag = 0
1009            for i in range(1, self.num_heads):
1010                self.conv_out[i].load_state_dict(self.conv_out[0].state_dict())
1011        return self.conv_out[tag]
1012
1013    def forward(self, x, tag=None):
1014        """Forward pass."""
1015        if tag is not None:
1016            tag = tag[0]
1017        out = self._conv_out(tag)(x)
1018        outputs = out.unsqueeze(0)
1019        for R in self._Rs(tag):
1020            if self.skip_connections:
1021                out = R(torch.cat([self.nl(out), x], axis=1))
1022                # out = R(torch.cat([out, x], axis=1))
1023            else:
1024                out = R(self.nl(out))
1025                # out = R(out)
1026            outputs = torch.cat((outputs, out.unsqueeze(0)), dim=0)
1027
1028        return outputs

Refinement stage

MSRefinement( num_layers_R, num_R, num_f_maps_input, num_f_maps, num_classes, dropout_rate, exclusive, skip_connections, direction, block_size=0, num_heads=1, attention='none')

Parameters

num_layers_R : int
    number of layers in the refinement modules
num_R : int
    number of refinement modules
num_f_maps_input : int
    number of feature maps in the input
num_f_maps : int
    number of feature maps
num_classes : int
    number of target classes
dropout_rate : float
    dropout rate
exclusive : bool
    set `False` for multi-label classification
skip_connections : bool
    if `True`, skip connections are added
direction : [None, 'bidirectional', 'forward', 'backward']
    the direction of convolutions; if None, regular convolutions are used
block_size : int, default 0
    if not 0, skip connections are added to the prediction generation stage with this interval
num_heads : int, default 1
    number of parallel refinement stages
attention : str, default 'none'
    the attention type used in the refinement modules ('none' or 'basic')

skip_connections
num_heads
Rs
conv_out
def forward(self, x, tag=None):

Forward pass.
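
To make the tensor flow concrete, here is a minimal usage sketch. The import path follows this module's name; the batch size, segment length, and hyperparameter values are illustrative assumptions, not library defaults.

import torch

from dlc2action.model.ms_tcn_modules import MSRefinement

# Illustrative values (not library defaults).
batch, frames = 2, 128
num_f_maps_input, num_classes = 64, 5

refiner = MSRefinement(
    num_layers_R=10,
    num_R=3,
    num_f_maps_input=num_f_maps_input,
    num_f_maps=64,
    num_classes=num_classes,
    dropout_rate=0.5,
    exclusive=True,          # softmax is applied to the logits passed between stages
    skip_connections=True,   # each stage also receives the original features
    direction=None,          # regular (non-causal) convolutions
)

x = torch.randn(batch, num_f_maps_input, frames)
outputs = refiner(x)

# One prediction per stage: the initial 1x1 projection plus num_R refinements.
print(outputs.shape)  # torch.Size([4, 2, 5, 128]) == (num_R + 1, batch, num_classes, frames)

With num_heads > 1, forward expects a tag argument (for example refiner(x, tag=[0])) that selects one of the parallel heads; when tag is None, the weights of head 0 are copied into the remaining heads before use.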

class MSRefinementShared(torch.nn.modules.module.Module):
1031class MSRefinementShared(nn.Module):
1032    """
1033    Refinement stage with shared weights across modules
1034    """
1035
1036    def __init__(
1037        self,
1038        num_layers_R,
1039        num_R,
1040        num_f_maps_input,
1041        num_f_maps,
1042        num_classes,
1043        dropout_rate,
1044        exclusive,
1045        skip_connections,
1046        direction,
1047        block_size=0,
1048        num_heads=1,
1049        attention="none",
1050    ):
1051        """
1052        Parameters
1053        ----------
1054        num_layers_R : int
1055            number of layers in refinement modules
1056        num_R : int
1057            number of refinement modules
1058        num_f_maps : int
1059            number of feature maps
1060        num_classes : int
1061            number of target classes
1062        dropout_rate : float
1063            dropout rate
1064        exclusive : bool
1065            set `False` for multi-label classification
1066        skip_connections : bool
1067            if `True`, skip connections are added
1068        direction : [None, 'bidirectional', 'forward', 'backward']
1069            the direction of convolutions; if None, regular convolutions are used
1070        block_size : int, default 0
1071            if not 0, skip connections are added to the prediction generation stage with this interval
1072        num_heads : int, default 1
1073            number of parallel refinement stages
1074        """
1075
1076        super().__init__()
1077        if exclusive:
1078            self.nl = lambda x: F.softmax(x, dim=1)
1079        else:
1080            self.nl = lambda x: torch.sigmoid(x)
1081        if direction == "bidirectional":
1082            refinement_module = RefinementB
1083        else:
1084            refinement_module = Refinement
1085        self.num_heads = num_heads
1086        self.R = nn.ModuleList(
1087            [
1088                refinement_module(
1089                    num_layers=num_layers_R,
1090                    num_f_maps_input=num_f_maps_input,
1091                    num_f_maps=num_f_maps,
1092                    dim=num_classes,
1093                    num_classes=num_classes,
1094                    dropout_rate=dropout_rate,
1095                    direction=direction,
1096                    skip_connections=skip_connections,
1097                    block_size=block_size,
1098                    attention=attention,
1099                )
1100                for _ in range(self.num_heads)
1101            ]
1102        )
1103        self.num_R = num_R
1104        self.conv_out = nn.ModuleList(
1105            [nn.Conv1d(num_f_maps_input, num_classes, 1) for _ in range(self.num_heads)]
1106        )
1107        self.skip_connections = skip_connections
1108        if self.num_heads == 1:
1109            self.R = self.R[0]
1110            self.conv_out = self.conv_out[0]
1111
1112    def _R(self, tag):
1113        if self.num_heads == 1:
1114            return self.R
1115        if tag is None:
1116            tag = 0
1117            for i in range(1, self.num_heads):
1118                self.R[i].load_state_dict(self.R[0].state_dict())
1119        return self.R[tag]
1120
1121    def _conv_out(self, tag):
1122        if self.num_heads == 1:
1123            return self.conv_out
1124        if tag is None:
1125            tag = 0
1126            for i in range(1, self.num_heads):
1127                self.conv_out[i].load_state_dict(self.conv_out[0].state_dict())
1128        return self.conv_out[tag]
1129
1130    def forward(self, x, tag=None):
1131        """Forward pass."""
1132        if tag is not None:
1133            tag = tag[0]
1134        out = self._conv_out(tag)(x)
1135        outputs = out.unsqueeze(0)
1136        for _ in range(self.num_R):
1137            if self.skip_connections:
1138                # out = self._R(tag)(torch.cat([self.nl(out), x], axis=1))
1139                out = self._R(tag)(torch.cat([out, x], axis=1))
1140            else:
1141                # out = self._R(tag)(self.nl(out))
1142                out = self._R(tag)(out)
1143            outputs = torch.cat((outputs, out.unsqueeze(0)), dim=0)
1144
1145        return outputs

Refinement stage with shared weights across modules

MSRefinementShared( num_layers_R, num_R, num_f_maps_input, num_f_maps, num_classes, dropout_rate, exclusive, skip_connections, direction, block_size=0, num_heads=1, attention='none')

Parameters

num_layers_R : int
    number of layers in the refinement modules
num_R : int
    number of refinement passes (the single shared refinement module is applied this many times)
num_f_maps_input : int
    number of feature maps in the input
num_f_maps : int
    number of feature maps
num_classes : int
    number of target classes
dropout_rate : float
    dropout rate
exclusive : bool
    set `False` for multi-label classification
skip_connections : bool
    if `True`, skip connections are added
direction : [None, 'bidirectional', 'forward', 'backward']
    the direction of convolutions; if None, regular convolutions are used
block_size : int, default 0
    if not 0, skip connections are added to the prediction generation stage with this interval
num_heads : int, default 1
    number of parallel refinement stages
attention : str, default 'none'
    the attention type used in the refinement modules ('none' or 'basic')

num_heads
R
num_R
conv_out
skip_connections
def forward(self, x, tag=None):

Forward pass.
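
MSRefinementShared keeps a single refinement module and applies it num_R times, so its parameter count does not depend on num_R; note also that, unlike MSRefinement, the logits are passed between passes without the softmax/sigmoid non-linearity. A sketch under the same illustrative assumptions as above:

import torch

from dlc2action.model.ms_tcn_modules import MSRefinement, MSRefinementShared


def n_params(module):
    return sum(p.numel() for p in module.parameters())


# Illustrative configuration shared by both variants.
cfg = dict(
    num_layers_R=10, num_f_maps_input=64, num_f_maps=64, num_classes=5,
    dropout_rate=0.5, exclusive=True, skip_connections=True, direction=None,
)

shared_3, shared_6 = MSRefinementShared(num_R=3, **cfg), MSRefinementShared(num_R=6, **cfg)
unshared_3, unshared_6 = MSRefinement(num_R=3, **cfg), MSRefinement(num_R=6, **cfg)

# The shared module is simply applied more times, so its size is constant,
# while the unshared refinement stage grows with the number of modules.
assert n_params(shared_3) == n_params(shared_6)
assert n_params(unshared_3) < n_params(unshared_6)

x = torch.randn(2, 64, 128)
print(shared_6(x).shape)  # torch.Size([7, 2, 5, 128]) == (num_R + 1, batch, num_classes, frames)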

class MSRefinementAttention(torch.nn.modules.module.Module):
1148class MSRefinementAttention(nn.Module):
1149    """
1150    Refinement stage built from attention refinement modules (`Refinement_SE`)
1151    """
1152
1153    def __init__(
1154        self,
1155        num_layers_R,
1156        num_R,
1157        num_f_maps_input,
1158        num_f_maps,
1159        num_classes,
1160        dropout_rate,
1161        exclusive,
1162        skip_connections,
1163        len_segment,
1164        block_size=0,
1165    ):
1166        """
1167        Parameters
1168        ----------
1169        num_layers_R : int
1170            number of layers in refinement modules
1171        num_R : int
1172            number of refinement modules
1173        num_f_maps : int
1174            number of feature maps
1175        num_classes : int
1176            number of target classes
1177        dropout_rate : float
1178            dropout rate
1179        exclusive : bool
1180            set `False` for multi-label classification
1181        skip_connections : bool
1182            if `True`, skip connections are added
1183        len_segment : int
1184            the length of the input segments
1185        block_size : int, default 0
1186            if not 0, skip connections are added to the prediction generation stage with this interval
1189        """
1190
1191        super().__init__()
1192        self.skip_connections = skip_connections
1193        if exclusive:
1194            self.nl = lambda x: F.softmax(x, dim=1)
1195        else:
1196            self.nl = lambda x: torch.sigmoid(x)
1197        refinement_module = Refinement_SE
1198        self.Rs = nn.ModuleList(
1199            [
1200                refinement_module(
1201                    num_layers=num_layers_R,
1202                    num_f_maps=num_f_maps,
1203                    num_f_maps_input=num_f_maps_input,
1204                    dim=num_classes,
1205                    num_classes=num_classes,
1206                    dropout_rate=dropout_rate,
1207                    direction=None,
1208                    skip_connections=skip_connections,
1209                    block_size=block_size,
1210                    len_segment=len_segment,
1211                )
1212                for s in range(num_R)
1213            ]
1214        )
1215        self.conv_out = nn.Conv1d(num_f_maps_input, num_classes, 1)
1216
1217    def forward(self, x, tag=None):
1218        """Forward pass."""
1219        if tag is not None:
1220            tag = tag[0]
1221        out = self.conv_out(x)
1222        outputs = out.unsqueeze(0)
1223        for R in self.Rs:
1224            if self.skip_connections:
1225                out = R(torch.cat([self.nl(out), x], axis=1), tag)
1226                # out = R(torch.cat([out, x], axis=1), tag)
1227            else:
1228                out = R(self.nl(out), tag)
1229                # out = R(out, tag)
1230            outputs = torch.cat((outputs, out.unsqueeze(0)), dim=0)
1231
1232        return outputs

Refinement stage built from attention refinement modules (Refinement_SE)

MSRefinementAttention( num_layers_R, num_R, num_f_maps_input, num_f_maps, num_classes, dropout_rate, exclusive, skip_connections, len_segment, block_size=0)

Parameters

num_layers_R : int
    number of layers in the refinement modules
num_R : int
    number of refinement modules
num_f_maps_input : int
    number of feature maps in the input
num_f_maps : int
    number of feature maps
num_classes : int
    number of target classes
dropout_rate : float
    dropout rate
exclusive : bool
    set `False` for multi-label classification
skip_connections : bool
    if `True`, skip connections are added
len_segment : int
    the length of the input segments
block_size : int, default 0
    if not 0, skip connections are added to the prediction generation stage with this interval

skip_connections
Rs
conv_out
def forward(self, x, tag=None):

Forward pass.
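
Usage mirrors the other refinement stages, with len_segment taking the place of direction and num_heads. The sketch below assumes, as the forward pass above implies, that len_segment equals the temporal length of the input and that each Refinement_SE module returns a (batch, num_classes, len_segment) tensor; all values are illustrative.

import torch

from dlc2action.model.ms_tcn_modules import MSRefinementAttention

len_segment = 128  # assumed to equal the number of frames in the input segments

att_refiner = MSRefinementAttention(
    num_layers_R=10,
    num_R=3,
    num_f_maps_input=64,
    num_f_maps=64,
    num_classes=5,
    dropout_rate=0.5,
    exclusive=True,
    skip_connections=True,
    len_segment=len_segment,
)

x = torch.randn(2, 64, len_segment)
outputs = att_refiner(x)
print(outputs.shape)  # expected: (num_R + 1, batch, num_classes, len_segment) = (4, 2, 5, 128)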

class DilatedTCNC(torch.nn.modules.module.Module):
1235class DilatedTCNC(nn.Module):
1236    def __init__(
1237        self,
1238        num_f_maps,
1239        num_layers_PG,
1240        len_segment,
1241        block_size_prediction=5,
1242        kernel_size_prediction=3,
1243        direction_PG=None,
1244    ):
1245        super(DilatedTCNC, self).__init__()
1246        if direction_PG == "bidirectional":
1247            self.PG_S = DilatedTCNB(
1248                num_layers=num_layers_PG,
1249                num_f_maps=num_f_maps,
1250                dim=num_f_maps,
1251                block_size=block_size_prediction,
1252            )
1253            self.PG_T = DilatedTCNB(
1254                num_layers=num_layers_PG,
1255                num_f_maps=len_segment,
1256                dim=len_segment,
1257                block_size=block_size_prediction,
1258            )
1259        else:
1260            self.PG_S = DilatedTCN(
1261                num_layers=num_layers_PG,
1262                num_f_maps=num_f_maps,
1263                dim=num_f_maps,
1264                direction=direction_PG,
1265                block_size=block_size_prediction,
1266                kernel_size=kernel_size_prediction,
1267            )
1268            self.PG_T = DilatedTCN(
1269                num_layers=num_layers_PG,
1270                num_f_maps=len_segment,
1271                dim=len_segment,
1272                direction=direction_PG,
1273                block_size=block_size_prediction,
1274                kernel_size=kernel_size_prediction,
1275            )
1276
1277    def forward(self, x):
1278        """Forward pass."""
1279        x = self.PG_S(x)
1280        x = torch.transpose(x, 1, 2)
1281        x = self.PG_T(x)
1282        x = torch.transpose(x, 1, 2)
1283        return x

Prediction generation module that stacks two dilated TCNs: PG_S convolves over the time axis with the features as channels, and, after a transpose, PG_T convolves over the feature axis with the time steps as channels (so the temporal length of the input has to equal len_segment).

DilatedTCNC( num_f_maps, num_layers_PG, len_segment, block_size_prediction=5, kernel_size_prediction=3, direction_PG=None)

def forward(self, x):

Forward pass.
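
Because PG_T convolves over the transposed input, the time axis is treated as the channel axis: the temporal length of the input has to equal len_segment, and the output keeps the input's shape. The sketch below uses illustrative values and assumes, as the constructor arguments suggest, that DilatedTCN maps (batch, dim, length) tensors to (batch, num_f_maps, length).

import torch

from dlc2action.model.ms_tcn_modules import DilatedTCNC

num_f_maps, len_segment = 64, 128

tcnc = DilatedTCNC(
    num_f_maps=num_f_maps,
    num_layers_PG=10,
    len_segment=len_segment,   # must equal the number of frames in the input
)

x = torch.randn(2, num_f_maps, len_segment)   # (batch, features, frames)
y = tcnc(x)

# PG_S runs a dilated TCN along the time axis (channels = features); after the
# transpose, PG_T runs a dilated TCN along the feature axis (channels = frames).
print(y.shape)  # torch.Size([2, 64, 128]) -- same shape as the input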