dlc2action.ssl.modules

Network modules used by implementations of dlc2action.ssl.base_ssl.SSLConstructor.

  1#
  2# Copyright 2020-present by A. Mathis Group and contributors. All rights reserved.
  3#
  4# This project and all its files are licensed under GNU AGPLv3 or later version. 
  5# A copy is included in dlc2action/LICENSE.AGPL.
  6#
  7"""Network modules used by implementations of `dlc2action.ssl.base_ssl.SSLConstructor`."""
  8
  9import torch
 10from torch import nn
 11import copy
 12from math import floor
 13import torch.nn.functional as F
 14from torch.nn import Linear
 15
 16class FeatureExtractorTCN(nn.Module):
 17    """A module that extracts clip-level features with a TCN."""
 18
 19    def __init__(
 20        self,
 21        num_f_maps: int,
 22        output_dim: int,
 23        len_segment: int,
 24        kernel_1: int,
 25        kernel_2: int,
 26        stride: int,
 27        decrease_f_maps: bool = False,
 28    ) -> None:
 29        """Initialize the module.
 30
 31        Parameters
 32        ----------
 33        num_f_maps : int
 34            number of features in input
 35        output_dim : int
 36            number of features in output
 37        len_segment : int
 38            length of segment in input
 39        kernel_1 : int
 40            kernel size of the first layer
 41        kernel_2 : int
 42            kernel size of the second layer
 43        stride : int
 44            stride
 45        decrease_f_maps : bool, default False
 46            if `True`, number of feature maps is halved at each new layer
 47
 48        """
 49        super().__init__()
 50        num_f_maps = int(num_f_maps)
 51        output_dim = int(output_dim)
 52        if decrease_f_maps:
 53            f_maps_2 = max(num_f_maps // 2, 1)
 54            f_maps_3 = max(num_f_maps // 4, 1)
 55        else:
 56            f_maps_2 = f_maps_3 = num_f_maps
 57        length = int(floor((len_segment - kernel_1) / stride + 1))
 58        length = floor((length - kernel_2) / stride + 1)
 59        features = length * f_maps_3
 60        self.conv = nn.ModuleList()
 61        self.conv.append(
 62            nn.Conv1d(num_f_maps, f_maps_2, kernel_1, padding=0, stride=stride)
 63        )
 64        self.conv.append(
 65            nn.Conv1d(f_maps_2, f_maps_3, kernel_2, padding=0, stride=stride)
 66        )
 67        self.conv_1x1_out = nn.Conv1d(features, output_dim, 1)
 68        self.dropout = nn.Dropout()
 69
 70    def forward(self, f):
 71        """Forward pass."""
 72        for conv in self.conv:
 73            f = conv(f)
 74            f = F.relu(f)
 75            f = self.dropout(f)
 76        f = f.reshape((f.shape[0], -1, 1))
 77        f = self.conv_1x1_out(f).squeeze()
 78        return f
 79
 80
 81class MFeatureExtractorTCN(nn.Module):
 82    """A module that extracts segment-level features with a TCN."""
 83
 84    def __init__(
 85        self,
 86        num_f_maps: int,
 87        output_dim: int,
 88        len_segment: int,
 89        kernel_1: int,
 90        kernel_2: int,
 91        stride: int,
 92        start: int,
 93        end: int,
 94        num_layers: int = 3,
 95    ):
 96        """Initialize the module.
 97
 98        Parameters
 99        ----------
100        num_f_maps : int
101            number of features in input
102        output_dim : int
103            number of features in output
104        len_segment : int
105            length of segment in input
106        kernel_1 : int
107            kernel size of the first layer
108        kernel_2 : int
109            kernel size of the second layer
110        stride : int
111            stride
112        start : int
113            the start index of the segment to extract
114        end : int
115            the end index of the segment to extract
116        num_layers : int
117            number of layers
118
119        """
120        super(MFeatureExtractorTCN, self).__init__()
121        self.main_module = DilatedTCN(num_layers, num_f_maps, num_f_maps)
122        self.extractor = FeatureExtractorTCN(
123            num_f_maps, output_dim, len_segment, kernel_1, kernel_2, stride
124        )
125        in_features = int(len_segment * num_f_maps)
126        out_features = int((end - start) * num_f_maps)
127        self.linear = Linear(in_features=in_features, out_features=out_features)
128        self.start = int(start)
129        self.end = int(end)
130
131    def forward(self, f, extract_features=True):
132        """Forward pass."""
133        if extract_features:
134            f = self.main_module(f)
135            f = F.relu(f)
136            f = f.reshape((f.shape[0], -1))
137            f = self.linear(f)
138        else:
139            f = f[:, :, self.start : self.end]
140            f = f.reshape((f.shape[0], -1))
141        return f
142
143
class FC(nn.Module):
    """Fully connected module that predicts input data given features.

    The same stack of linear layers is applied independently at every time
    step of a ``(N, C, T)`` input.
    """

    def __init__(
        self,
        dim: int,
        num_f_maps: int,
        num_ssl_layers: int,
        num_ssl_f_maps: int,
        ssl_input: bool = False,
    ) -> None:
        """Initialize the module.

        Parameters
        ----------
        dim : int
            output number of features
        num_f_maps : int
            number of features in input
        num_ssl_layers : int
            number of layers in the module; the first and last layers are
            always created, so values below 2 still give two layers
        num_ssl_f_maps : int
            number of feature maps in the module
        ssl_input : bool, default False
            if `True`, the first layer takes `dim` input features instead of
            `num_f_maps`

        """
        super().__init__()
        # Cast everything before choosing `in_features` so that it is always
        # an int (`nn.Linear` cannot be built from a float size).
        dim = int(dim)
        num_f_maps = int(num_f_maps)
        num_ssl_layers = int(num_ssl_layers)
        num_ssl_f_maps = int(num_ssl_f_maps)
        in_features = dim if ssl_input else num_f_maps
        # NOTE(review): `forward` applies no nonlinearity between these layers,
        # so the stack is equivalent to a single affine map -- confirm intended.
        self.layers = nn.ModuleList(
            [nn.Linear(in_features=in_features, out_features=num_ssl_f_maps)]
        )
        self.layers.extend(
            nn.Linear(in_features=num_ssl_f_maps, out_features=num_ssl_f_maps)
            for _ in range(num_ssl_layers - 2)
        )
        self.layers.append(nn.Linear(in_features=num_ssl_f_maps, out_features=dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass.

        Parameters
        ----------
        x : torch.Tensor
            input of shape ``(N, C, T)``, where ``C`` is the input size of the
            first layer

        Returns
        -------
        torch.Tensor
            output of shape ``(N, dim, T)``

        """
        # Lower-case names: `F` is torch.nn.functional in this module.
        n, c, t = x.shape
        # One row per time step: (N, C, T) -> (N * T, C).
        x = x.transpose(1, 2).reshape(n * t, c)
        for layer in self.layers:
            x = layer(x)
        # Explicit output width instead of -1 so that empty inputs
        # (N == 0 or T == 0) can be reshaped back.
        return x.reshape(n, t, x.shape[-1]).transpose(1, 2)
187
188
class DilatedTCN(nn.Module):
    """TCN module that predicts input data given features.

    Each layer runs two dilated convolutions in parallel (dilations
    ``2 ** (num_layers - 1 - i)`` and ``2 ** i``), fuses them with a 1x1
    convolution, applies ReLU and dropout and adds a residual connection.
    A final 1x1 convolution maps to ``output_dim`` channels.
    """

    def __init__(self, num_layers, input_dim, output_dim):
        """Initialize the module.

        Parameters
        ----------
        num_layers : int
            number of layers in the module
        input_dim : int
            number of features in input
        output_dim : int
            output number of features

        """
        super().__init__()
        num_layers = int(num_layers)
        input_dim = int(input_dim)
        output_dim = int(output_dim)
        self.num_layers = num_layers
        # With kernel size 3, padding equal to the dilation keeps the sequence
        # length unchanged, so the residual sum in `forward` is well-defined.
        self.conv_dilated_1 = nn.ModuleList(
            (
                nn.Conv1d(
                    input_dim,
                    input_dim,
                    3,
                    padding=2 ** (num_layers - 1 - i),
                    dilation=2 ** (num_layers - 1 - i),
                )
                for i in range(num_layers)
            )
        )

        self.conv_dilated_2 = nn.ModuleList(
            (
                nn.Conv1d(input_dim, input_dim, 3, padding=2**i, dilation=2**i)
                for i in range(num_layers)
            )
        )

        # 1x1 convolutions merging the two concatenated branches per layer.
        self.conv_fusion = nn.ModuleList(
            (nn.Conv1d(2 * input_dim, input_dim, 1) for i in range(num_layers))
        )

        self.conv_1x1_out = nn.Conv1d(input_dim, output_dim, 1)

        self.dropout = nn.Dropout()

    def forward(self, f: torch.Tensor) -> torch.Tensor:
        """Forward pass.

        Parameters
        ----------
        f : torch.Tensor
            input of shape ``(N, input_dim, T)``

        Returns
        -------
        torch.Tensor
            output of shape ``(N, output_dim, T)``

        """
        for conv_1, conv_2, fusion in zip(
            self.conv_dilated_1, self.conv_dilated_2, self.conv_fusion
        ):
            # A plain reference is enough for the residual: `f` is rebound,
            # never modified in place. Copying the tensor (e.g. copy.copy)
            # rebuilds it as a new leaf and cuts the skip path out of autograd.
            f_in = f
            f = fusion(torch.cat([conv_1(f), conv_2(f)], 1))
            f = self.dropout(F.relu(f))
            f = f + f_in
        return self.conv_1x1_out(f)
class FeatureExtractorTCN(torch.nn.modules.module.Module):
class FeatureExtractorTCN(nn.Module):
    """A module that extracts clip-level features with a TCN.

    Two unpadded, strided ``Conv1d`` layers (each followed by ReLU and
    dropout) shorten the time axis; the remaining channels and time steps are
    then flattened together and mapped to ``output_dim`` features by a 1x1
    convolution.
    """

    def __init__(
        self,
        num_f_maps: int,
        output_dim: int,
        len_segment: int,
        kernel_1: int,
        kernel_2: int,
        stride: int,
        decrease_f_maps: bool = False,
    ) -> None:
        """Initialize the module.

        Parameters
        ----------
        num_f_maps : int
            number of features in input
        output_dim : int
            number of features in output
        len_segment : int
            length of segment in input
        kernel_1 : int
            kernel size of the first layer
        kernel_2 : int
            kernel size of the second layer
        stride : int
            stride of both convolutions
        decrease_f_maps : bool, default False
            if `True`, number of feature maps is halved at each new layer

        Raises
        ------
        ValueError
            if a segment of length `len_segment` is too short to leave at
            least one time step after both convolutions

        """
        super().__init__()
        # Sizes may come from a config as floats; ``nn.Conv1d`` needs ints.
        num_f_maps = int(num_f_maps)
        output_dim = int(output_dim)
        len_segment = int(len_segment)
        kernel_1 = int(kernel_1)
        kernel_2 = int(kernel_2)
        stride = int(stride)
        if decrease_f_maps:
            f_maps_2 = max(num_f_maps // 2, 1)
            f_maps_3 = max(num_f_maps // 4, 1)
        else:
            f_maps_2 = f_maps_3 = num_f_maps
        # Output length of an unpadded convolution: (L - kernel) // stride + 1.
        length = (len_segment - kernel_1) // stride + 1
        length = (length - kernel_2) // stride + 1
        if length < 1:
            raise ValueError(
                f"len_segment={len_segment} is too short for kernel_1={kernel_1}, "
                f"kernel_2={kernel_2} and stride={stride}"
            )
        # The 1x1 output conv sees every (channel, time step) pair as a channel.
        features = length * f_maps_3
        self.conv = nn.ModuleList()
        self.conv.append(
            nn.Conv1d(num_f_maps, f_maps_2, kernel_1, padding=0, stride=stride)
        )
        self.conv.append(
            nn.Conv1d(f_maps_2, f_maps_3, kernel_2, padding=0, stride=stride)
        )
        self.conv_1x1_out = nn.Conv1d(features, output_dim, 1)
        self.dropout = nn.Dropout()

    def forward(self, f: torch.Tensor) -> torch.Tensor:
        """Forward pass.

        Parameters
        ----------
        f : torch.Tensor
            input of shape ``(N, num_f_maps, len_segment)``

        Returns
        -------
        torch.Tensor
            clip-level features of shape ``(N, output_dim)``

        """
        for conv in self.conv:
            f = self.dropout(F.relu(conv(f)))
        # Fold channels and time into one axis so the 1x1 conv acts as a fully
        # connected layer over the whole clip; flatten also works for N == 0.
        f = f.flatten(start_dim=1).unsqueeze(-1)
        # Drop only the length-1 time axis: a bare ``squeeze()`` would also
        # remove the batch axis when N == 1.
        return self.conv_1x1_out(f).squeeze(-1)

A module that extracts clip-level features with a TCN.

FeatureExtractorTCN( num_f_maps: int, output_dim: int, len_segment: int, kernel_1: int, kernel_2: int, stride: int, decrease_f_maps: bool = False)
20    def __init__(
21        self,
22        num_f_maps: int,
23        output_dim: int,
24        len_segment: int,
25        kernel_1: int,
26        kernel_2: int,
27        stride: int,
28        decrease_f_maps: bool = False,
29    ) -> None:
30        """Initialize the module.
31
32        Parameters
33        ----------
34        num_f_maps : int
35            number of features in input
36        output_dim : int
37            number of features in output
38        len_segment : int
39            length of segment in input
40        kernel_1 : int
41            kernel size of the first layer
42        kernel_2 : int
43            kernel size of the second layer
44        stride : int
45            stride
46        decrease_f_maps : bool, default False
47            if `True`, number of feature maps is halved at each new layer
48
49        """
50        super().__init__()
51        num_f_maps = int(num_f_maps)
52        output_dim = int(output_dim)
53        if decrease_f_maps:
54            f_maps_2 = max(num_f_maps // 2, 1)
55            f_maps_3 = max(num_f_maps // 4, 1)
56        else:
57            f_maps_2 = f_maps_3 = num_f_maps
58        length = int(floor((len_segment - kernel_1) / stride + 1))
59        length = floor((length - kernel_2) / stride + 1)
60        features = length * f_maps_3
61        self.conv = nn.ModuleList()
62        self.conv.append(
63            nn.Conv1d(num_f_maps, f_maps_2, kernel_1, padding=0, stride=stride)
64        )
65        self.conv.append(
66            nn.Conv1d(f_maps_2, f_maps_3, kernel_2, padding=0, stride=stride)
67        )
68        self.conv_1x1_out = nn.Conv1d(features, output_dim, 1)
69        self.dropout = nn.Dropout()

Initialize the module.

Parameters

num_f_maps : int — number of features in input
output_dim : int — number of features in output
len_segment : int — length of segment in input
kernel_1 : int — kernel size of the first layer
kernel_2 : int — kernel size of the second layer
stride : int — stride of both convolutions
decrease_f_maps : bool, default False — if True, the number of feature maps is halved at each new layer

conv
conv_1x1_out
dropout
def forward(self, f):
71    def forward(self, f):
72        """Forward pass."""
73        for conv in self.conv:
74            f = conv(f)
75            f = F.relu(f)
76            f = self.dropout(f)
77        f = f.reshape((f.shape[0], -1, 1))
78        f = self.conv_1x1_out(f).squeeze()
79        return f

Forward pass.

class MFeatureExtractorTCN(torch.nn.modules.module.Module):
 82class MFeatureExtractorTCN(nn.Module):
 83    """A module that extracts segment-level features with a TCN."""
 84
 85    def __init__(
 86        self,
 87        num_f_maps: int,
 88        output_dim: int,
 89        len_segment: int,
 90        kernel_1: int,
 91        kernel_2: int,
 92        stride: int,
 93        start: int,
 94        end: int,
 95        num_layers: int = 3,
 96    ):
 97        """Initialize the module.
 98
 99        Parameters
100        ----------
101        num_f_maps : int
102            number of features in input
103        output_dim : int
104            number of features in output
105        len_segment : int
106            length of segment in input
107        kernel_1 : int
108            kernel size of the first layer
109        kernel_2 : int
110            kernel size of the second layer
111        stride : int
112            stride
113        start : int
114            the start index of the segment to extract
115        end : int
116            the end index of the segment to extract
117        num_layers : int
118            number of layers
119
120        """
121        super(MFeatureExtractorTCN, self).__init__()
122        self.main_module = DilatedTCN(num_layers, num_f_maps, num_f_maps)
123        self.extractor = FeatureExtractorTCN(
124            num_f_maps, output_dim, len_segment, kernel_1, kernel_2, stride
125        )
126        in_features = int(len_segment * num_f_maps)
127        out_features = int((end - start) * num_f_maps)
128        self.linear = Linear(in_features=in_features, out_features=out_features)
129        self.start = int(start)
130        self.end = int(end)
131
132    def forward(self, f, extract_features=True):
133        """Forward pass."""
134        if extract_features:
135            f = self.main_module(f)
136            f = F.relu(f)
137            f = f.reshape((f.shape[0], -1))
138            f = self.linear(f)
139        else:
140            f = f[:, :, self.start : self.end]
141            f = f.reshape((f.shape[0], -1))
142        return f

A module that extracts segment-level features with a TCN.

MFeatureExtractorTCN( num_f_maps: int, output_dim: int, len_segment: int, kernel_1: int, kernel_2: int, stride: int, start: int, end: int, num_layers: int = 3)
 85    def __init__(
 86        self,
 87        num_f_maps: int,
 88        output_dim: int,
 89        len_segment: int,
 90        kernel_1: int,
 91        kernel_2: int,
 92        stride: int,
 93        start: int,
 94        end: int,
 95        num_layers: int = 3,
 96    ):
 97        """Initialize the module.
 98
 99        Parameters
100        ----------
101        num_f_maps : int
102            number of features in input
103        output_dim : int
104            number of features in output
105        len_segment : int
106            length of segment in input
107        kernel_1 : int
108            kernel size of the first layer
109        kernel_2 : int
110            kernel size of the second layer
111        stride : int
112            stride
113        start : int
114            the start index of the segment to extract
115        end : int
116            the end index of the segment to extract
117        num_layers : int
118            number of layers
119
120        """
121        super(MFeatureExtractorTCN, self).__init__()
122        self.main_module = DilatedTCN(num_layers, num_f_maps, num_f_maps)
123        self.extractor = FeatureExtractorTCN(
124            num_f_maps, output_dim, len_segment, kernel_1, kernel_2, stride
125        )
126        in_features = int(len_segment * num_f_maps)
127        out_features = int((end - start) * num_f_maps)
128        self.linear = Linear(in_features=in_features, out_features=out_features)
129        self.start = int(start)
130        self.end = int(end)

Initialize the module.

Parameters

num_f_maps : int — number of features in input
output_dim : int — number of features in output
len_segment : int — length of segment in input
kernel_1 : int — kernel size of the first layer
kernel_2 : int — kernel size of the second layer
stride : int — stride
start : int — the start index of the segment to extract
end : int — the end index of the segment to extract
num_layers : int, default 3 — number of layers

main_module
extractor
linear
start
end
def forward(self, f, extract_features=True):
132    def forward(self, f, extract_features=True):
133        """Forward pass."""
134        if extract_features:
135            f = self.main_module(f)
136            f = F.relu(f)
137            f = f.reshape((f.shape[0], -1))
138            f = self.linear(f)
139        else:
140            f = f[:, :, self.start : self.end]
141            f = f.reshape((f.shape[0], -1))
142        return f

Forward pass.

class FC(torch.nn.modules.module.Module):
class FC(nn.Module):
    """Fully connected module that predicts input data given features.

    The same stack of linear layers is applied independently at every time
    step of a ``(N, C, T)`` input.
    """

    def __init__(
        self,
        dim: int,
        num_f_maps: int,
        num_ssl_layers: int,
        num_ssl_f_maps: int,
        ssl_input: bool = False,
    ) -> None:
        """Initialize the module.

        Parameters
        ----------
        dim : int
            output number of features
        num_f_maps : int
            number of features in input
        num_ssl_layers : int
            number of layers in the module; the first and last layers are
            always created, so values below 2 still give two layers
        num_ssl_f_maps : int
            number of feature maps in the module
        ssl_input : bool, default False
            if `True`, the first layer takes `dim` input features instead of
            `num_f_maps`

        """
        super().__init__()
        # Cast everything before choosing `in_features` so that it is always
        # an int (`nn.Linear` cannot be built from a float size).
        dim = int(dim)
        num_f_maps = int(num_f_maps)
        num_ssl_layers = int(num_ssl_layers)
        num_ssl_f_maps = int(num_ssl_f_maps)
        in_features = dim if ssl_input else num_f_maps
        # NOTE(review): `forward` applies no nonlinearity between these layers,
        # so the stack is equivalent to a single affine map -- confirm intended.
        self.layers = nn.ModuleList(
            [nn.Linear(in_features=in_features, out_features=num_ssl_f_maps)]
        )
        self.layers.extend(
            nn.Linear(in_features=num_ssl_f_maps, out_features=num_ssl_f_maps)
            for _ in range(num_ssl_layers - 2)
        )
        self.layers.append(nn.Linear(in_features=num_ssl_f_maps, out_features=dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass.

        Parameters
        ----------
        x : torch.Tensor
            input of shape ``(N, C, T)``, where ``C`` is the input size of the
            first layer

        Returns
        -------
        torch.Tensor
            output of shape ``(N, dim, T)``

        """
        # Lower-case names: `F` is torch.nn.functional in this module.
        n, c, t = x.shape
        # One row per time step: (N, C, T) -> (N * T, C).
        x = x.transpose(1, 2).reshape(n * t, c)
        for layer in self.layers:
            x = layer(x)
        # Explicit output width instead of -1 so that empty inputs
        # (N == 0 or T == 0) can be reshaped back.
        return x.reshape(n, t, x.shape[-1]).transpose(1, 2)

Fully connected module that predicts input data given features.

FC( dim: int, num_f_maps: int, num_ssl_layers: int, num_ssl_f_maps: int, ssl_input: bool = False)
148    def __init__(
149        self, dim: int, num_f_maps: int, num_ssl_layers: int, num_ssl_f_maps: int, ssl_input:bool = False
150    ) -> None:
151        """Initialize the module.
152
153        Parameters
154        ----------
155        dim : int
156            output number of features
157        num_f_maps : int
158            number of features in input
159        num_ssl_layers : int
160            number of layers in the module
161        num_ssl_f_maps : int
162            number of feature maps in the module
163
164        """
165        super().__init__()
166        dim = int(dim)
167        in_features = dim if ssl_input else num_f_maps
168        num_f_maps = int(num_f_maps)
169        num_ssl_layers = int(num_ssl_layers)
170        num_ssl_f_maps = int(num_ssl_f_maps)
171        self.layers = nn.ModuleList(
172            [nn.Linear(in_features=in_features, out_features=num_ssl_f_maps)]
173        )
174        for _ in range(num_ssl_layers - 2):
175            self.layers.append(
176                nn.Linear(in_features=num_ssl_f_maps, out_features=num_ssl_f_maps)
177            )
178        self.layers.append(nn.Linear(in_features=num_ssl_f_maps, out_features=dim))

Initialize the module.

Parameters

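dim : int — output number of features
num_f_maps : int — number of features in input
num_ssl_layers : int — number of layers in the module
num_ssl_f_maps : int — number of feature maps in the module
ssl_input : bool, default False — if True, the first layer takes dim input features instead of num_f_maps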

layers
def forward(self, x: torch.Tensor) -> torch.Tensor:
180    def forward(self, x: torch.Tensor) -> torch.Tensor:
181        """Forward pass."""
182        N, C, F = x.shape
183        x = x.transpose(1, 2).reshape(-1, C)
184        for layer in self.layers:
185            x = layer(x)
186        x = x.reshape(N, F, -1).transpose(1, 2)
187        return x

Forward pass.

class DilatedTCN(torch.nn.modules.module.Module):
class DilatedTCN(nn.Module):
    """TCN module that predicts input data given features.

    Each layer runs two dilated convolutions in parallel (dilations
    ``2 ** (num_layers - 1 - i)`` and ``2 ** i``), fuses them with a 1x1
    convolution, applies ReLU and dropout and adds a residual connection.
    A final 1x1 convolution maps to ``output_dim`` channels.
    """

    def __init__(self, num_layers, input_dim, output_dim):
        """Initialize the module.

        Parameters
        ----------
        num_layers : int
            number of layers in the module
        input_dim : int
            number of features in input
        output_dim : int
            output number of features

        """
        super().__init__()
        num_layers = int(num_layers)
        input_dim = int(input_dim)
        output_dim = int(output_dim)
        self.num_layers = num_layers
        # With kernel size 3, padding equal to the dilation keeps the sequence
        # length unchanged, so the residual sum in `forward` is well-defined.
        self.conv_dilated_1 = nn.ModuleList(
            (
                nn.Conv1d(
                    input_dim,
                    input_dim,
                    3,
                    padding=2 ** (num_layers - 1 - i),
                    dilation=2 ** (num_layers - 1 - i),
                )
                for i in range(num_layers)
            )
        )

        self.conv_dilated_2 = nn.ModuleList(
            (
                nn.Conv1d(input_dim, input_dim, 3, padding=2**i, dilation=2**i)
                for i in range(num_layers)
            )
        )

        # 1x1 convolutions merging the two concatenated branches per layer.
        self.conv_fusion = nn.ModuleList(
            (nn.Conv1d(2 * input_dim, input_dim, 1) for i in range(num_layers))
        )

        self.conv_1x1_out = nn.Conv1d(input_dim, output_dim, 1)

        self.dropout = nn.Dropout()

    def forward(self, f: torch.Tensor) -> torch.Tensor:
        """Forward pass.

        Parameters
        ----------
        f : torch.Tensor
            input of shape ``(N, input_dim, T)``

        Returns
        -------
        torch.Tensor
            output of shape ``(N, output_dim, T)``

        """
        for conv_1, conv_2, fusion in zip(
            self.conv_dilated_1, self.conv_dilated_2, self.conv_fusion
        ):
            # A plain reference is enough for the residual: `f` is rebound,
            # never modified in place. Copying the tensor (e.g. copy.copy)
            # rebuilds it as a new leaf and cuts the skip path out of autograd.
            f_in = f
            f = fusion(torch.cat([conv_1(f), conv_2(f)], 1))
            f = self.dropout(F.relu(f))
            f = f + f_in
        return self.conv_1x1_out(f)

TCN module that predicts input data given features.

DilatedTCN(num_layers, input_dim, output_dim)
193    def __init__(self, num_layers, input_dim, output_dim):
194        """Initialize the module.
195
196        Parameters
197        ----------
198        output_dim : int
199            output number of features
200        input_dim : int
201            number of features in input
202        num_layers : int
203            number of layers in the module
204
205        """
206        super().__init__()
207        num_layers = int(num_layers)
208        input_dim = int(input_dim)
209        output_dim = int(output_dim)
210        self.num_layers = num_layers
211        self.conv_dilated_1 = nn.ModuleList(
212            (
213                nn.Conv1d(
214                    input_dim,
215                    input_dim,
216                    3,
217                    padding=2 ** (num_layers - 1 - i),
218                    dilation=2 ** (num_layers - 1 - i),
219                )
220                for i in range(num_layers)
221            )
222        )
223
224        self.conv_dilated_2 = nn.ModuleList(
225            (
226                nn.Conv1d(input_dim, input_dim, 3, padding=2**i, dilation=2**i)
227                for i in range(num_layers)
228            )
229        )
230
231        self.conv_fusion = nn.ModuleList(
232            (nn.Conv1d(2 * input_dim, input_dim, 1) for i in range(num_layers))
233        )
234
235        self.conv_1x1_out = nn.Conv1d(input_dim, output_dim, 1)
236
237        self.dropout = nn.Dropout()

Initialize the module.

Parameters

output_dim : int — output number of features
input_dim : int — number of features in input
num_layers : int — number of layers in the module

num_layers
conv_dilated_1
conv_dilated_2
conv_fusion
conv_1x1_out
dropout
def forward(self, f):
239    def forward(self, f):
240        """Forward pass."""
241        for i in range(self.num_layers):
242            f_in = copy.copy(f)
243            f = self.conv_fusion[i](
244                torch.cat([self.conv_dilated_1[i](f), self.conv_dilated_2[i](f)], 1)
245            )
246            f = F.relu(f)
247            f = self.dropout(f)
248            f = f + f_in
249        f = self.conv_1x1_out(f)
250        return f

Forward pass.