dlc2action.ssl.modules

Network modules used by implementations of `dlc2action.ssl.base_ssl.SSLConstructor`.

  1#
  2# Copyright 2020-2022 by A. Mathis Group and contributors. All rights reserved.
  3#
  4# This project and all its files are licensed under GNU AGPLv3 or later version. A copy is included in dlc2action/LICENSE.AGPL.
  5#
  6"""
  7Network modules used by implementations of `dlc2action.ssl.base_ssl.SSLConstructor`
  8"""
  9
 10import torch
 11from torch import nn
 12import copy
 13from math import floor
 14import torch.nn.functional as F
 15from torch.nn import Linear
 16
 17
 18class _FeatureExtractorTCN(nn.Module):
 19    """
 20    A module that extracts clip-level features with a TCN
 21    """
 22
 23    def __init__(
 24        self,
 25        num_f_maps: int,
 26        output_dim: int,
 27        len_segment: int,
 28        kernel_1: int,
 29        kernel_2: int,
 30        stride: int,
 31        decrease_f_maps: bool = False,
 32    ) -> None:
 33        """
 34        Parameters
 35        ----------
 36        num_f_maps : int
 37            number of features in input
 38        output_dim : int
 39            number of features in output
 40        len_segment : int
 41            length of segment in input
 42        kernel_1 : int
 43            kernel size of the first layer
 44        kernel_2 : int
 45            kernel size of the second layer
 46        stride : int
 47            stride
 48        decrease_f_maps : bool, default False
 49            if `True`, number of feature maps is halved at each new layer
 50        """
 51
 52        super().__init__()
 53        num_f_maps = int(num_f_maps)
 54        output_dim = int(output_dim)
 55        if decrease_f_maps:
 56            f_maps_2 = max(num_f_maps // 2, 1)
 57            f_maps_3 = max(num_f_maps // 4, 1)
 58        else:
 59            f_maps_2 = f_maps_3 = num_f_maps
 60        length = int(floor((len_segment - kernel_1) / stride + 1))
 61        length = floor((length - kernel_2) / stride + 1)
 62        features = length * f_maps_3
 63        self.conv = nn.ModuleList()
 64        self.conv.append(
 65            nn.Conv1d(num_f_maps, f_maps_2, kernel_1, padding=0, stride=stride)
 66        )
 67        self.conv.append(
 68            nn.Conv1d(f_maps_2, f_maps_3, kernel_2, padding=0, stride=stride)
 69        )
 70        self.conv_1x1_out = nn.Conv1d(features, output_dim, 1)
 71        self.dropout = nn.Dropout()
 72
 73    def forward(self, f):
 74        for conv in self.conv:
 75            f = conv(f)
 76            f = F.relu(f)
 77            f = self.dropout(f)
 78        f = f.reshape((f.shape[0], -1, 1))
 79        f = self.conv_1x1_out(f).squeeze()
 80        return f
 81
 82
 83class _MFeatureExtractorTCN(nn.Module):
 84    """
 85    A module that extracts segment-level features with a TCN
 86    """
 87
 88    def __init__(
 89        self,
 90        num_f_maps: int,
 91        output_dim: int,
 92        len_segment: int,
 93        kernel_1: int,
 94        kernel_2: int,
 95        stride: int,
 96        start: int,
 97        end: int,
 98        num_layers: int = 3,
 99    ):
100        """
101        Parameters
102        ----------
103        num_f_maps : int
104            number of features in input
105        output_dim : int
106            number of features in output
107        len_segment : int
108            length of segment in input
109        kernel_1 : int
110            kernel size of the first layer
111        kernel_2 : int
112            kernel size of the second layer
113        stride : int
114            stride
115        start : int
116            the start index of the segment to extract
117        end : int
118            the end index of the segment to extract
119        num_layers : int
120            number of layers
121        """
122
123        super(_MFeatureExtractorTCN, self).__init__()
124        self.main_module = _DilatedTCN(num_layers, num_f_maps, num_f_maps)
125        self.extractor = _FeatureExtractorTCN(
126            num_f_maps, output_dim, len_segment, kernel_1, kernel_2, stride
127        )
128        in_features = int(len_segment * num_f_maps)
129        out_features = int((end - start) * num_f_maps)
130        self.linear = Linear(in_features=in_features, out_features=out_features)
131        self.start = int(start)
132        self.end = int(end)
133
134    def forward(self, f, extract_features=True):
135        if extract_features:
136            f = self.main_module(f)
137            f = F.relu(f)
138            f = f.reshape((f.shape[0], -1))
139            f = self.linear(f)
140        else:
141            f = f[:, :, self.start : self.end]
142            f = f.reshape((f.shape[0], -1))
143        return f
144
145
class _FC(nn.Module):
    """
    Fully connected module that predicts input data given features

    The same stack of linear layers is applied independently at every time
    step of a ``(N, num_f_maps, T)`` input, producing a ``(N, dim, T)`` output.
    """

    def __init__(
        self, dim: int, num_f_maps: int, num_ssl_layers: int, num_ssl_f_maps: int
    ) -> None:
        """
        Parameters
        ----------
        dim : int
            output number of features
        num_f_maps : int
            number of features in input
        num_ssl_layers : int
            number of layers in the module; values below 2 still produce two
            layers (input -> hidden -> output)
        num_ssl_f_maps : int
            number of feature maps in the module
        """

        super().__init__()
        dim = int(dim)
        num_f_maps = int(num_f_maps)
        num_ssl_layers = int(num_ssl_layers)
        num_ssl_f_maps = int(num_ssl_f_maps)
        # NOTE(review): no nonlinearity is applied between these layers, so
        # the stack is equivalent to a single affine map. Kept as-is to
        # preserve the behaviour of existing models.
        self.layers = nn.ModuleList(
            [nn.Linear(in_features=num_f_maps, out_features=num_ssl_f_maps)]
        )
        for _ in range(num_ssl_layers - 2):
            self.layers.append(
                nn.Linear(in_features=num_ssl_f_maps, out_features=num_ssl_f_maps)
            )
        self.layers.append(nn.Linear(in_features=num_ssl_f_maps, out_features=dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply the linear stack along the channel dimension.

        Parameters
        ----------
        x : torch.Tensor
            input of shape ``(N, num_f_maps, T)``

        Returns
        -------
        torch.Tensor
            output of shape ``(N, dim, T)``

        Raises
        ------
        ValueError
            if `x` is not three-dimensional
        """
        if x.dim() != 3:
            raise ValueError(f"expected a 3D input (N, C, T), got shape {tuple(x.shape)}")
        # nn.Linear acts on the last dimension and broadcasts over the leading
        # ones, so moving channels last is enough. Unlike a flatten/reshape
        # round trip this also works for empty batches.
        x = x.transpose(1, 2)
        for layer in self.layers:
            x = layer(x)
        return x.transpose(1, 2)
188
189
class _DilatedTCN(nn.Module):
    """
    Residual TCN with two parallel dilated convolutions per layer

    In layer ``i`` one kernel-3 convolution uses dilation
    ``2 ** (num_layers - 1 - i)`` and the other ``2 ** i``. Their outputs are
    concatenated, fused by a 1x1 convolution, passed through ReLU and dropout,
    and added to the layer input. Padding equals dilation, so the sequence
    length is preserved. A final 1x1 convolution maps ``input_dim`` channels to
    ``output_dim``.
    """

    def __init__(self, num_layers, input_dim, output_dim):
        """
        Parameters
        ----------
        num_layers : int
            number of layers in the module
        input_dim : int
            number of features in input
        output_dim : int
            output number of features
        """

        super().__init__()
        num_layers = int(num_layers)
        input_dim = int(input_dim)
        output_dim = int(output_dim)
        self.num_layers = num_layers
        # Dilation decreases from 2 ** (num_layers - 1) down to 1.
        self.conv_dilated_1 = nn.ModuleList(
            nn.Conv1d(
                input_dim,
                input_dim,
                3,
                padding=2 ** (num_layers - 1 - i),
                dilation=2 ** (num_layers - 1 - i),
            )
            for i in range(num_layers)
        )
        # Dilation increases from 1 up to 2 ** (num_layers - 1).
        self.conv_dilated_2 = nn.ModuleList(
            nn.Conv1d(input_dim, input_dim, 3, padding=2**i, dilation=2**i)
            for i in range(num_layers)
        )
        self.conv_fusion = nn.ModuleList(
            nn.Conv1d(2 * input_dim, input_dim, 1) for _ in range(num_layers)
        )
        self.conv_1x1_out = nn.Conv1d(input_dim, output_dim, 1)
        self.dropout = nn.Dropout()

    def forward(self, f):
        """
        Parameters
        ----------
        f : torch.Tensor
            input of shape ``(N, input_dim, T)``

        Returns
        -------
        torch.Tensor
            output of shape ``(N, output_dim, T)``
        """
        for conv_1, conv_2, fusion in zip(
            self.conv_dilated_1, self.conv_dilated_2, self.conv_fusion
        ):
            # A plain reference is enough for the skip connection: nothing
            # below modifies ``f`` in place, and keeping the same tensor keeps
            # the residual path in the autograd graph.
            f_in = f
            f = fusion(torch.cat([conv_1(f), conv_2(f)], 1))
            f = self.dropout(F.relu(f))
            f = f + f_in
        return self.conv_1x1_out(f)