dlc2action.ssl.modules
Network modules used by implementations of dlc2action.ssl.base_ssl.SSLConstructor
#
# Copyright 2020-2022 by A. Mathis Group and contributors. All rights reserved.
#
# This project and all its files are licensed under GNU AGPLv3 or later version. A copy is included in dlc2action/LICENSE.AGPL.
#
"""
Network modules used by implementations of `dlc2action.ssl.base_ssl.SSLConstructor`
"""

import torch
from torch import nn
import copy
from math import floor
import torch.nn.functional as F
from torch.nn import Linear


class _FeatureExtractorTCN(nn.Module):
    """
    A module that extracts clip-level features with a TCN

    Two unpadded, strided convolutions shrink the time axis; the result is
    flattened and projected to `output_dim` features per sample with a 1x1
    convolution.
    """

    def __init__(
        self,
        num_f_maps: int,
        output_dim: int,
        len_segment: int,
        kernel_1: int,
        kernel_2: int,
        stride: int,
        decrease_f_maps: bool = False,
    ) -> None:
        """
        Parameters
        ----------
        num_f_maps : int
            number of features in input
        output_dim : int
            number of features in output
        len_segment : int
            length of segment in input (the input time length must equal this,
            since the flattened size below is computed from it)
        kernel_1 : int
            kernel size of the first layer
        kernel_2 : int
            kernel size of the second layer
        stride : int
            stride
        decrease_f_maps : bool, default False
            if `True`, number of feature maps is halved at each new layer
        """

        super().__init__()
        num_f_maps = int(num_f_maps)
        output_dim = int(output_dim)
        if decrease_f_maps:
            f_maps_2 = max(num_f_maps // 2, 1)
            f_maps_3 = max(num_f_maps // 4, 1)
        else:
            f_maps_2 = f_maps_3 = num_f_maps
        # output length of an unpadded Conv1d: floor((L - kernel) / stride) + 1
        length = int(floor((len_segment - kernel_1) / stride + 1))
        length = floor((length - kernel_2) / stride + 1)
        # channels x remaining time steps after both convolutions
        features = length * f_maps_3
        self.conv = nn.ModuleList()
        self.conv.append(
            nn.Conv1d(num_f_maps, f_maps_2, kernel_1, padding=0, stride=stride)
        )
        self.conv.append(
            nn.Conv1d(f_maps_2, f_maps_3, kernel_2, padding=0, stride=stride)
        )
        self.conv_1x1_out = nn.Conv1d(features, output_dim, 1)
        self.dropout = nn.Dropout()

    def forward(self, f):
        """
        Parameters
        ----------
        f : torch.Tensor
            input of shape (N, num_f_maps, len_segment)

        Returns
        -------
        torch.Tensor
            features of shape (N, output_dim)
        """
        for conv in self.conv:
            f = conv(f)
            f = F.relu(f)
            f = self.dropout(f)
        # flatten channels x time into the channel axis of a length-1 sequence
        f = f.reshape((f.shape[0], -1, 1))
        # squeeze only the dummy time axis: a bare .squeeze() would also drop
        # the batch axis when the batch size is 1
        f = self.conv_1x1_out(f).squeeze(-1)
        return f


class _MFeatureExtractorTCN(nn.Module):
    """
    A module that extracts segment-level features with a TCN

    With `extract_features=True` the whole input is passed through a dilated
    TCN and linearly mapped to a vector of size `(end - start) * num_f_maps`;
    otherwise the raw input slice `[start:end]` is flattened to the same size.
    """

    def __init__(
        self,
        num_f_maps: int,
        output_dim: int,
        len_segment: int,
        kernel_1: int,
        kernel_2: int,
        stride: int,
        start: int,
        end: int,
        num_layers: int = 3,
    ) -> None:
        """
        Parameters
        ----------
        num_f_maps : int
            number of features in input
        output_dim : int
            number of features in output (only used by the unused
            `self.extractor` submodule)
        len_segment : int
            length of segment in input
        kernel_1 : int
            kernel size of the first layer
        kernel_2 : int
            kernel size of the second layer
        stride : int
            stride
        start : int
            the start index of the segment to extract
        end : int
            the end index of the segment to extract
        num_layers : int, default 3
            number of layers
        """

        super().__init__()
        self.main_module = _DilatedTCN(num_layers, num_f_maps, num_f_maps)
        # NOTE(review): `extractor` is never called in `forward`; it is kept so
        # that the module's parameters / state_dict keys stay unchanged
        self.extractor = _FeatureExtractorTCN(
            num_f_maps, output_dim, len_segment, kernel_1, kernel_2, stride
        )
        in_features = int(len_segment * num_f_maps)
        out_features = int((end - start) * num_f_maps)
        self.linear = Linear(in_features=in_features, out_features=out_features)
        self.start = int(start)
        self.end = int(end)

    def forward(self, f, extract_features=True):
        """
        Parameters
        ----------
        f : torch.Tensor
            input of shape (N, num_f_maps, len_segment)
        extract_features : bool, default True
            if `True`, predict the segment from the whole input; otherwise
            return the flattened ground-truth slice `[start:end]`

        Returns
        -------
        torch.Tensor
            tensor of shape (N, (end - start) * num_f_maps)
        """
        if extract_features:
            f = self.main_module(f)
            f = F.relu(f)
            f = f.reshape((f.shape[0], -1))
            f = self.linear(f)
        else:
            f = f[:, :, self.start : self.end]
            f = f.reshape((f.shape[0], -1))
        return f


class _FC(nn.Module):
    """
    Fully connected module that predicts input data given features

    The same stack of linear layers is applied independently to the feature
    vector of every time step.
    """

    def __init__(
        self, dim: int, num_f_maps: int, num_ssl_layers: int, num_ssl_f_maps: int
    ) -> None:
        """
        Parameters
        ----------
        dim : int
            output number of features
        num_f_maps : int
            number of features in input
        num_ssl_layers : int
            number of layers in the module (values below 2 still produce the
            mandatory input and output layers)
        num_ssl_f_maps : int
            number of feature maps in the module
        """

        super().__init__()
        dim = int(dim)
        num_f_maps = int(num_f_maps)
        num_ssl_layers = int(num_ssl_layers)
        num_ssl_f_maps = int(num_ssl_f_maps)
        self.layers = nn.ModuleList(
            [nn.Linear(in_features=num_f_maps, out_features=num_ssl_f_maps)]
        )
        for _ in range(num_ssl_layers - 2):
            self.layers.append(
                nn.Linear(in_features=num_ssl_f_maps, out_features=num_ssl_f_maps)
            )
        self.layers.append(nn.Linear(in_features=num_ssl_f_maps, out_features=dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Parameters
        ----------
        x : torch.Tensor
            input of shape (N, num_f_maps, T)

        Returns
        -------
        torch.Tensor
            output of shape (N, dim, T)
        """
        n_batch, n_channels, n_frames = x.shape
        # (N, C, T) -> (N * T, C) so every time step is an independent row
        x = x.transpose(1, 2).reshape(-1, n_channels)
        # NOTE(review): no nonlinearity between layers, so the stack composes
        # to a single affine map; confirm this is intended
        for layer in self.layers:
            x = layer(x)
        # (N * T, dim) -> (N, dim, T); explicit last size so empty inputs reshape
        x = x.reshape(n_batch, n_frames, x.shape[-1]).transpose(1, 2)
        return x


class _DilatedTCN(nn.Module):
    """
    TCN module that predicts input data given features

    A stack of residual layers, each fusing two kernel-3 dilated convolutions
    (one with dilation decreasing over layers, one increasing), followed by a
    1x1 output convolution. Padding equals dilation, so the time length of the
    input is preserved.
    """

    def __init__(self, num_layers, input_dim, output_dim):
        """
        Parameters
        ----------
        output_dim : int
            output number of features
        input_dim : int
            number of features in input
        num_layers : int
            number of layers in the module
        """

        super().__init__()
        num_layers = int(num_layers)
        input_dim = int(input_dim)
        output_dim = int(output_dim)
        self.num_layers = num_layers
        # dilation 2**(num_layers - 1), ..., 2, 1
        self.conv_dilated_1 = nn.ModuleList(
            (
                nn.Conv1d(
                    input_dim,
                    input_dim,
                    3,
                    padding=2 ** (num_layers - 1 - i),
                    dilation=2 ** (num_layers - 1 - i),
                )
                for i in range(num_layers)
            )
        )

        # dilation 1, 2, ..., 2**(num_layers - 1)
        self.conv_dilated_2 = nn.ModuleList(
            (
                nn.Conv1d(input_dim, input_dim, 3, padding=2**i, dilation=2**i)
                for i in range(num_layers)
            )
        )

        self.conv_fusion = nn.ModuleList(
            (nn.Conv1d(2 * input_dim, input_dim, 1) for i in range(num_layers))
        )

        self.conv_1x1_out = nn.Conv1d(input_dim, output_dim, 1)

        self.dropout = nn.Dropout()

    def forward(self, f):
        """
        Parameters
        ----------
        f : torch.Tensor
            input of shape (N, input_dim, T)

        Returns
        -------
        torch.Tensor
            output of shape (N, output_dim, T)
        """
        for dilated_1, dilated_2, fusion in zip(
            self.conv_dilated_1, self.conv_dilated_2, self.conv_fusion
        ):
            # keep a plain reference: `f` is rebound below, never modified in
            # place, and `copy.copy` on a tensor rebuilds it as a new autograd
            # leaf, which removed the skip connection from the backward pass
            f_in = f
            f = fusion(torch.cat([dilated_1(f), dilated_2(f)], 1))
            f = F.relu(f)
            f = self.dropout(f)
            f = f + f_in
        f = self.conv_1x1_out(f)
        return f