dlc2action.model.edtcn

  1#
  2# Copyright 2020-present by A. Mathis Group and contributors. All rights reserved.
  3#
  4# This project and all its files are licensed under GNU AGPLv3 or later version. A copy is included in dlc2action/LICENSE.AGPL.
  5#
  6# Incorporates code adapted from ASRF by yiskw713
  7# Original work Copyright (c) 2020 yiskw713
  8# Source: https://github.com/yiskw713/asrf/blob/main/libs/models/tcn.py
  9# Originally licensed under MIT License
 10# Combined work licensed under GNU AGPLv3
 11#
 12from typing import Any, List, Tuple, Union
 13
 14import torch
 15from dlc2action.model.base_model import Model
 16from torch import nn
 17from torch.nn import functional as F
 18
 19
 20class NormalizedReLU(nn.Module):
 21    """
 22    Normalized ReLU Activation proposed in the original TCN paper.
 23    the values are divided by the max computed per frame
 24    """
 25
 26    def __init__(self, eps: float = 1e-5) -> None:
 27        super().__init__()
 28        self.eps = eps
 29
 30    def forward(self, x: torch.Tensor) -> torch.Tensor:
 31        """Forward pass."""
 32        x = F.relu(x)
 33        x = x / (x.max(dim=1, keepdim=True)[0] + self.eps)
 34
 35        return x
 36
 37
 38class EDTCNModule(nn.Module):
 39    """
 40    Encoder Decoder Temporal Convolutional Network
 41    """
 42
 43    def __init__(
 44        self,
 45        in_channel: int,
 46        output_dim: int,
 47        kernel_size: int = 25,
 48        mid_channels: Tuple[int, int] = [128, 160],
 49        **kwargs: Any
 50    ) -> None:
 51        """
 52        Args:
 53            in_channel: int. the number of the channels of input feature
 54            output_dim: int. output classes
 55            kernel_size: int. 25 is proposed in the original paper
 56            mid_channels: list. the list of the number of the channels of the middle layer.
 57                        [96 + 32*1, 96 + 32*2] is proposed in the original paper
 58        Note that this implementation only supports n_layer=2
 59        """
 60        super().__init__()
 61
 62        # encoder
 63        self.enc1 = nn.Conv1d(
 64            in_channel,
 65            mid_channels[0],
 66            kernel_size,
 67            stride=1,
 68            padding=(kernel_size - 1) // 2,
 69        )
 70        self.dropout1 = nn.Dropout(0.3)
 71        self.relu1 = NormalizedReLU()
 72
 73        self.enc2 = nn.Conv1d(
 74            mid_channels[0],
 75            mid_channels[1],
 76            kernel_size,
 77            stride=1,
 78            padding=(kernel_size - 1) // 2,
 79        )
 80        self.dropout2 = nn.Dropout(0.3)
 81        self.relu2 = NormalizedReLU()
 82
 83        # decoder
 84        self.dec1 = nn.Conv1d(
 85            mid_channels[1],
 86            mid_channels[1],
 87            kernel_size,
 88            stride=1,
 89            padding=(kernel_size - 1) // 2,
 90        )
 91        self.dropout3 = nn.Dropout(0.3)
 92        self.relu3 = NormalizedReLU()
 93
 94        self.dec2 = nn.Conv1d(
 95            mid_channels[1],
 96            mid_channels[0],
 97            kernel_size,
 98            stride=1,
 99            padding=(kernel_size - 1) // 2,
100        )
101        self.dropout4 = nn.Dropout(0.3)
102        self.relu4 = NormalizedReLU()
103
104        self.conv_out = nn.Conv1d(mid_channels[0], output_dim, 1, bias=True)
105
106        self.init_weight()
107
108    def forward(self, x: torch.Tensor) -> torch.Tensor:
109        """Forward pass."""
110        # encoder 1
111        x1 = self.relu1(self.dropout1(self.enc1(x)))
112        t1 = x1.shape[2]
113        x1 = F.max_pool1d(x1, 2)
114
115        # encoder 2
116        x2 = self.relu2(self.dropout2(self.enc2(x1)))
117        t2 = x2.shape[2]
118        x2 = F.max_pool1d(x2, 2)
119
120        # decoder 1
121        x3 = F.interpolate(x2, size=(t2,), mode="nearest")
122        x3 = self.relu3(self.dropout3(self.dec1(x3)))
123
124        # decoder 2
125        x4 = F.interpolate(x3, size=(t1,), mode="nearest")
126        x4 = self.relu4(self.dropout4(self.dec2(x4)))
127
128        out = self.conv_out(x4)
129
130        return out
131
132    def init_weight(self) -> None:
133        """Initialize weights of the model."""
134        for m in self.modules():
135            if isinstance(m, nn.Conv1d):
136                nn.init.xavier_normal_(m.weight)
137                if m.bias is not None:
138                    torch.nn.init.zeros_(m.bias)
139
140
class Predictor(nn.Module):
    """
    Prediction head: two pointwise (kernel size 1) convolutions with a ReLU in between
    """

    def __init__(self, dim, num_classes):
        """
        Args:
            dim: int. the number of the channels of the input (kept by the first convolution)
            num_classes: int. the number of the channels of the output
        """
        super().__init__()
        self.num_classes = num_classes
        self.conv_out_1 = nn.Conv1d(dim, dim, kernel_size=1)
        self.conv_out_2 = nn.Conv1d(dim, num_classes, kernel_size=1)

    def forward(self, x):
        """Forward pass."""
        hidden = F.relu(self.conv_out_1(x))
        return self.conv_out_2(hidden)
154
155
class EDTCN(Model):
    """
    An implementation of EDTCN (Encoder-Decoder TCN)
    """

    def __init__(
        self,
        num_classes,
        input_dims,
        kernel_size,
        mid_channels,
        feature_dim=None,
        state_dict_path=None,
        ssl_constructors=None,
        ssl_types=None,
        ssl_modules=None,
    ):
        """
        Args:
            num_classes: int. the number of output classes
            input_dims: dict. maps each input key to its shape; the first entries
                        of all shapes are summed to get the number of input channels
            kernel_size: int. kernel size of the temporal convolutions
            mid_channels: list. the numbers of channels of the middle layers
            feature_dim: int, optional. when set, the feature extractor outputs
                        `feature_dim` channels and a separate predictor maps them to
                        `num_classes`; otherwise the feature extractor outputs
                        `num_classes` channels directly and the predictor is an identity
            state_dict_path, ssl_constructors, ssl_types, ssl_modules: forwarded
                        unchanged to the base `Model` constructor
        """
        # NOTE(review): the attributes are filled in before calling the base constructor,
        # presumably because it builds the modules from them -- confirm against `Model`
        in_channel = int(sum(shape[0] for shape in input_dims.values()))
        if feature_dim is not None:
            self.params_predictor = {
                "dim": int(feature_dim),
                "num_classes": int(num_classes),
            }
            self.f_shape = torch.Size([int(feature_dim)])
        else:
            # no separate predictor: the extractor predicts the classes itself
            feature_dim = num_classes
            self.params_predictor = None
            self.f_shape = None
        self.params = {
            "output_dim": int(feature_dim),
            "in_channel": in_channel,
            "kernel_size": int(kernel_size),
            "mid_channels": mid_channels,
        }
        super().__init__(ssl_constructors, ssl_modules, ssl_types, state_dict_path)

    def _feature_extractor(self) -> Union[torch.nn.Module, List]:
        """Build the EDTCN feature extractor from the stored parameters."""
        return EDTCNModule(**self.params)

    def _predictor(self) -> torch.nn.Module:
        """Build the prediction head (an identity when no `feature_dim` was given)."""
        if self.params_predictor is not None:
            return Predictor(**self.params_predictor)
        return nn.Identity()

    def features_shape(self) -> torch.Size:
        """Return the stored feature shape (`None` when no `feature_dim` was given)."""
        return self.f_shape
class NormalizedReLU(torch.nn.modules.module.Module):
21class NormalizedReLU(nn.Module):
22    """
23    Normalized ReLU Activation proposed in the original TCN paper.
24    the values are divided by the max computed per frame
25    """
26
27    def __init__(self, eps: float = 1e-5) -> None:
28        super().__init__()
29        self.eps = eps
30
31    def forward(self, x: torch.Tensor) -> torch.Tensor:
32        """Forward pass."""
33        x = F.relu(x)
34        x = x / (x.max(dim=1, keepdim=True)[0] + self.eps)
35
36        return x

Normalized ReLU activation proposed in the original TCN paper. The values are divided by the max computed per frame.

NormalizedReLU(eps: float = 1e-05)
27    def __init__(self, eps: float = 1e-5) -> None:
28        super().__init__()
29        self.eps = eps

Initialize internal Module state, shared by both nn.Module and ScriptModule.

eps
def forward(self, x: torch.Tensor) -> torch.Tensor:
31    def forward(self, x: torch.Tensor) -> torch.Tensor:
32        """Forward pass."""
33        x = F.relu(x)
34        x = x / (x.max(dim=1, keepdim=True)[0] + self.eps)
35
36        return x

Forward pass.

class EDTCNModule(torch.nn.modules.module.Module):
 39class EDTCNModule(nn.Module):
 40    """
 41    Encoder Decoder Temporal Convolutional Network
 42    """
 43
 44    def __init__(
 45        self,
 46        in_channel: int,
 47        output_dim: int,
 48        kernel_size: int = 25,
 49        mid_channels: Tuple[int, int] = [128, 160],
 50        **kwargs: Any
 51    ) -> None:
 52        """
 53        Args:
 54            in_channel: int. the number of the channels of input feature
 55            output_dim: int. output classes
 56            kernel_size: int. 25 is proposed in the original paper
 57            mid_channels: list. the list of the number of the channels of the middle layer.
 58                        [96 + 32*1, 96 + 32*2] is proposed in the original paper
 59        Note that this implementation only supports n_layer=2
 60        """
 61        super().__init__()
 62
 63        # encoder
 64        self.enc1 = nn.Conv1d(
 65            in_channel,
 66            mid_channels[0],
 67            kernel_size,
 68            stride=1,
 69            padding=(kernel_size - 1) // 2,
 70        )
 71        self.dropout1 = nn.Dropout(0.3)
 72        self.relu1 = NormalizedReLU()
 73
 74        self.enc2 = nn.Conv1d(
 75            mid_channels[0],
 76            mid_channels[1],
 77            kernel_size,
 78            stride=1,
 79            padding=(kernel_size - 1) // 2,
 80        )
 81        self.dropout2 = nn.Dropout(0.3)
 82        self.relu2 = NormalizedReLU()
 83
 84        # decoder
 85        self.dec1 = nn.Conv1d(
 86            mid_channels[1],
 87            mid_channels[1],
 88            kernel_size,
 89            stride=1,
 90            padding=(kernel_size - 1) // 2,
 91        )
 92        self.dropout3 = nn.Dropout(0.3)
 93        self.relu3 = NormalizedReLU()
 94
 95        self.dec2 = nn.Conv1d(
 96            mid_channels[1],
 97            mid_channels[0],
 98            kernel_size,
 99            stride=1,
100            padding=(kernel_size - 1) // 2,
101        )
102        self.dropout4 = nn.Dropout(0.3)
103        self.relu4 = NormalizedReLU()
104
105        self.conv_out = nn.Conv1d(mid_channels[0], output_dim, 1, bias=True)
106
107        self.init_weight()
108
109    def forward(self, x: torch.Tensor) -> torch.Tensor:
110        """Forward pass."""
111        # encoder 1
112        x1 = self.relu1(self.dropout1(self.enc1(x)))
113        t1 = x1.shape[2]
114        x1 = F.max_pool1d(x1, 2)
115
116        # encoder 2
117        x2 = self.relu2(self.dropout2(self.enc2(x1)))
118        t2 = x2.shape[2]
119        x2 = F.max_pool1d(x2, 2)
120
121        # decoder 1
122        x3 = F.interpolate(x2, size=(t2,), mode="nearest")
123        x3 = self.relu3(self.dropout3(self.dec1(x3)))
124
125        # decoder 2
126        x4 = F.interpolate(x3, size=(t1,), mode="nearest")
127        x4 = self.relu4(self.dropout4(self.dec2(x4)))
128
129        out = self.conv_out(x4)
130
131        return out
132
133    def init_weight(self) -> None:
134        """Initialize weights of the model."""
135        for m in self.modules():
136            if isinstance(m, nn.Conv1d):
137                nn.init.xavier_normal_(m.weight)
138                if m.bias is not None:
139                    torch.nn.init.zeros_(m.bias)

Encoder Decoder Temporal Convolutional Network

EDTCNModule( in_channel: int, output_dim: int, kernel_size: int = 25, mid_channels: Tuple[int, int] = [128, 160], **kwargs: Any)
 44    def __init__(
 45        self,
 46        in_channel: int,
 47        output_dim: int,
 48        kernel_size: int = 25,
 49        mid_channels: Tuple[int, int] = [128, 160],
 50        **kwargs: Any
 51    ) -> None:
 52        """
 53        Args:
 54            in_channel: int. the number of the channels of input feature
 55            output_dim: int. output classes
 56            kernel_size: int. 25 is proposed in the original paper
 57            mid_channels: list. the list of the number of the channels of the middle layer.
 58                        [96 + 32*1, 96 + 32*2] is proposed in the original paper
 59        Note that this implementation only supports n_layer=2
 60        """
 61        super().__init__()
 62
 63        # encoder
 64        self.enc1 = nn.Conv1d(
 65            in_channel,
 66            mid_channels[0],
 67            kernel_size,
 68            stride=1,
 69            padding=(kernel_size - 1) // 2,
 70        )
 71        self.dropout1 = nn.Dropout(0.3)
 72        self.relu1 = NormalizedReLU()
 73
 74        self.enc2 = nn.Conv1d(
 75            mid_channels[0],
 76            mid_channels[1],
 77            kernel_size,
 78            stride=1,
 79            padding=(kernel_size - 1) // 2,
 80        )
 81        self.dropout2 = nn.Dropout(0.3)
 82        self.relu2 = NormalizedReLU()
 83
 84        # decoder
 85        self.dec1 = nn.Conv1d(
 86            mid_channels[1],
 87            mid_channels[1],
 88            kernel_size,
 89            stride=1,
 90            padding=(kernel_size - 1) // 2,
 91        )
 92        self.dropout3 = nn.Dropout(0.3)
 93        self.relu3 = NormalizedReLU()
 94
 95        self.dec2 = nn.Conv1d(
 96            mid_channels[1],
 97            mid_channels[0],
 98            kernel_size,
 99            stride=1,
100            padding=(kernel_size - 1) // 2,
101        )
102        self.dropout4 = nn.Dropout(0.3)
103        self.relu4 = NormalizedReLU()
104
105        self.conv_out = nn.Conv1d(mid_channels[0], output_dim, 1, bias=True)
106
107        self.init_weight()

Args: in_channel: int. the number of the channels of input feature; output_dim: int. output classes; kernel_size: int. 25 is proposed in the original paper; mid_channels: list. the list of the number of the channels of the middle layer. [96 + 32*1, 96 + 32*2] is proposed in the original paper. Note that this implementation only supports n_layer=2.

enc1
dropout1
relu1
enc2
dropout2
relu2
dec1
dropout3
relu3
dec2
dropout4
relu4
conv_out
def forward(self, x: torch.Tensor) -> torch.Tensor:
109    def forward(self, x: torch.Tensor) -> torch.Tensor:
110        """Forward pass."""
111        # encoder 1
112        x1 = self.relu1(self.dropout1(self.enc1(x)))
113        t1 = x1.shape[2]
114        x1 = F.max_pool1d(x1, 2)
115
116        # encoder 2
117        x2 = self.relu2(self.dropout2(self.enc2(x1)))
118        t2 = x2.shape[2]
119        x2 = F.max_pool1d(x2, 2)
120
121        # decoder 1
122        x3 = F.interpolate(x2, size=(t2,), mode="nearest")
123        x3 = self.relu3(self.dropout3(self.dec1(x3)))
124
125        # decoder 2
126        x4 = F.interpolate(x3, size=(t1,), mode="nearest")
127        x4 = self.relu4(self.dropout4(self.dec2(x4)))
128
129        out = self.conv_out(x4)
130
131        return out

Forward pass.

def init_weight(self) -> None:
133    def init_weight(self) -> None:
134        """Initialize weights of the model."""
135        for m in self.modules():
136            if isinstance(m, nn.Conv1d):
137                nn.init.xavier_normal_(m.weight)
138                if m.bias is not None:
139                    torch.nn.init.zeros_(m.bias)

Initialize weights of the model.

class Predictor(torch.nn.modules.module.Module):
142class Predictor(nn.Module):
143    def __init__(self, dim, num_classes):
144        super(Predictor, self).__init__()
145        self.num_classes = num_classes
146        self.conv_out_1 = nn.Conv1d(dim, dim, kernel_size=1)
147        self.conv_out_2 = nn.Conv1d(dim, num_classes, kernel_size=1)
148
149    def forward(self, x):
150        """Forward pass."""
151        x = self.conv_out_1(x)
152        x = F.relu(x)
153        x = self.conv_out_2(x)
154        return x

Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::

import torch.nn as nn
import torch.nn.functional as F


class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will also have their parameters converted when you call to(), etc.

As per the example above, an __init__() call to the parent class must be made before assignment on the child.

:ivar training: Boolean represents whether this module is in training or evaluation mode. :vartype training: bool

Predictor(dim, num_classes)
143    def __init__(self, dim, num_classes):
144        super(Predictor, self).__init__()
145        self.num_classes = num_classes
146        self.conv_out_1 = nn.Conv1d(dim, dim, kernel_size=1)
147        self.conv_out_2 = nn.Conv1d(dim, num_classes, kernel_size=1)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

num_classes
conv_out_1
conv_out_2
def forward(self, x):
149    def forward(self, x):
150        """Forward pass."""
151        x = self.conv_out_1(x)
152        x = F.relu(x)
153        x = self.conv_out_2(x)
154        return x

Forward pass.

class EDTCN(dlc2action.model.base_model.Model):
157class EDTCN(Model):
158    """
159    An implementation of EDTCN (Endoder-Decoder TCN)
160    """
161
162    def __init__(
163        self,
164        num_classes,
165        input_dims,
166        kernel_size,
167        mid_channels,
168        feature_dim=None,
169        state_dict_path=None,
170        ssl_constructors=None,
171        ssl_types=None,
172        ssl_modules=None,
173    ):
174        input_dims = int(sum([s[0] for s in input_dims.values()]))
175        if feature_dim is None:
176            feature_dim = num_classes
177            self.params_predictor = None
178            self.f_shape = None
179        else:
180            self.params_predictor = {
181                "dim": int(feature_dim),
182                "num_classes": int(num_classes),
183            }
184            self.f_shape = torch.Size([int(feature_dim)])
185        self.params = {
186            "output_dim": int(feature_dim),
187            "in_channel": input_dims,
188            "kernel_size": int(kernel_size),
189            "mid_channels": mid_channels,
190        }
191        super().__init__(ssl_constructors, ssl_modules, ssl_types, state_dict_path)
192
193    def _feature_extractor(self) -> Union[torch.nn.Module, List]:
194        return EDTCNModule(**self.params)
195
196    def _predictor(self) -> torch.nn.Module:
197        if self.params_predictor is None:
198            return nn.Identity()
199        else:
200            return Predictor(**self.params_predictor)
201
202    def features_shape(self) -> torch.Size:
203        return self.f_shape

An implementation of EDTCN (Encoder-Decoder TCN)

EDTCN( num_classes, input_dims, kernel_size, mid_channels, feature_dim=None, state_dict_path=None, ssl_constructors=None, ssl_types=None, ssl_modules=None)
162    def __init__(
163        self,
164        num_classes,
165        input_dims,
166        kernel_size,
167        mid_channels,
168        feature_dim=None,
169        state_dict_path=None,
170        ssl_constructors=None,
171        ssl_types=None,
172        ssl_modules=None,
173    ):
174        input_dims = int(sum([s[0] for s in input_dims.values()]))
175        if feature_dim is None:
176            feature_dim = num_classes
177            self.params_predictor = None
178            self.f_shape = None
179        else:
180            self.params_predictor = {
181                "dim": int(feature_dim),
182                "num_classes": int(num_classes),
183            }
184            self.f_shape = torch.Size([int(feature_dim)])
185        self.params = {
186            "output_dim": int(feature_dim),
187            "in_channel": input_dims,
188            "kernel_size": int(kernel_size),
189            "mid_channels": mid_channels,
190        }
191        super().__init__(ssl_constructors, ssl_modules, ssl_types, state_dict_path)

Initialize the model.

Parameters

ssl_constructors : list, optional. A list of SSL constructors that build the necessary SSL modules. ssl_modules : list, optional. A list of torch.nn.Module instances that will serve as SSL modules. ssl_types : list, optional. A list of string SSL types. state_dict_path : str, optional. Path to the model state dictionary to load. strict : bool, default False. When True, the state dictionary will only be loaded if the current and the loaded architecture are the same; otherwise missing or extra keys, as well as shape inconsistencies, are ignored. prompt_function : callable, optional. A function that takes a list of strings and returns a string prompt.

params
def features_shape(self) -> torch.Size:
202    def features_shape(self) -> torch.Size:
203        return self.f_shape

Get the shape of feature extractor output.

Returns

feature_shape : torch.Size shape of feature extractor output