dlc2action.ssl.modules
Network modules used by implementations of dlc2action.ssl.base_ssl.SSLConstructor.
#
# Copyright 2020-present by A. Mathis Group and contributors. All rights reserved.
#
# This project and all its files are licensed under GNU AGPLv3 or later version.
# A copy is included in dlc2action/LICENSE.AGPL.
#
"""Network modules used by implementations of `dlc2action.ssl.base_ssl.SSLConstructor`."""

import torch
from torch import nn
import copy
from math import floor
import torch.nn.functional as F
from torch.nn import Linear


class FeatureExtractorTCN(nn.Module):
    """A module that extracts clip-level features with a TCN.

    Two unpadded, strided 1D convolutions (each followed by ReLU and dropout)
    are applied; the result is flattened and projected to `output_dim`
    features with a 1x1 convolution.
    """

    def __init__(
        self,
        num_f_maps: int,
        output_dim: int,
        len_segment: int,
        kernel_1: int,
        kernel_2: int,
        stride: int,
        decrease_f_maps: bool = False,
    ) -> None:
        """Initialize the module.

        Parameters
        ----------
        num_f_maps : int
            number of features in input
        output_dim : int
            number of features in output
        len_segment : int
            length of segment in input
        kernel_1 : int
            kernel size of the first layer
        kernel_2 : int
            kernel size of the second layer
        stride : int
            stride of both convolutional layers
        decrease_f_maps : bool, default False
            if `True`, number of feature maps is halved at each new layer

        """
        super().__init__()
        # Sizes may arrive as floats (e.g. from a config file); torch layers need ints.
        num_f_maps = int(num_f_maps)
        output_dim = int(output_dim)
        len_segment = int(len_segment)
        kernel_1 = int(kernel_1)
        kernel_2 = int(kernel_2)
        stride = int(stride)
        if decrease_f_maps:
            f_maps_2 = max(num_f_maps // 2, 1)
            f_maps_3 = max(num_f_maps // 4, 1)
        else:
            f_maps_2 = f_maps_3 = num_f_maps
        # Output length of an unpadded strided convolution, applied twice.
        length = floor((len_segment - kernel_1) / stride + 1)
        length = floor((length - kernel_2) / stride + 1)
        features = length * f_maps_3
        self.conv = nn.ModuleList()
        self.conv.append(
            nn.Conv1d(num_f_maps, f_maps_2, kernel_1, padding=0, stride=stride)
        )
        self.conv.append(
            nn.Conv1d(f_maps_2, f_maps_3, kernel_2, padding=0, stride=stride)
        )
        # Acts as a fully connected layer over the flattened (channels x time) map.
        self.conv_1x1_out = nn.Conv1d(features, output_dim, 1)
        self.dropout = nn.Dropout()

    def forward(self, f):
        """Forward pass.

        Parameters
        ----------
        f : torch.Tensor
            input with `num_f_maps` channels; presumably of shape
            `(batch, num_f_maps, len_segment)` (TODO confirm with callers)

        Returns
        -------
        torch.Tensor
            tensor of shape `(batch, output_dim)`

        """
        for conv in self.conv:
            f = conv(f)
            f = F.relu(f)
            f = self.dropout(f)
        f = f.reshape((f.shape[0], -1, 1))
        # Only drop the trailing length-1 axis: a bare `squeeze()` would also
        # remove the batch axis when the batch size (or `output_dim`) is 1.
        f = self.conv_1x1_out(f).squeeze(-1)
        return f


class MFeatureExtractorTCN(nn.Module):
    """A module that extracts segment-level features with a TCN."""

    def __init__(
        self,
        num_f_maps: int,
        output_dim: int,
        len_segment: int,
        kernel_1: int,
        kernel_2: int,
        stride: int,
        start: int,
        end: int,
        num_layers: int = 3,
    ):
        """Initialize the module.

        Parameters
        ----------
        num_f_maps : int
            number of features in input
        output_dim : int
            number of features in output
        len_segment : int
            length of segment in input
        kernel_1 : int
            kernel size of the first layer
        kernel_2 : int
            kernel size of the second layer
        stride : int
            stride
        start : int
            the start index of the segment to extract
        end : int
            the end index of the segment to extract
        num_layers : int, default 3
            number of layers of the dilated TCN

        """
        super().__init__()
        self.main_module = DilatedTCN(num_layers, num_f_maps, num_f_maps)
        # NOTE(review): `extractor` is not used in `forward`; it is kept so the
        # module's parameters (and saved checkpoints) stay unchanged.
        self.extractor = FeatureExtractorTCN(
            num_f_maps, output_dim, len_segment, kernel_1, kernel_2, stride
        )
        # The dilated TCN preserves sequence length, so the flattened output has
        # len_segment * num_f_maps values; the target is the [start, end) window.
        in_features = int(len_segment * num_f_maps)
        out_features = int((end - start) * num_f_maps)
        self.linear = Linear(in_features=in_features, out_features=out_features)
        self.start = int(start)
        self.end = int(end)

    def forward(self, f, extract_features=True):
        """Forward pass.

        Parameters
        ----------
        f : torch.Tensor
            input tensor with channels on axis 1 and time on axis 2
        extract_features : bool, default True
            if `True`, run the dilated TCN, ReLU, flatten and the linear layer;
            if `False`, return the flattened `[start, end)` slice of `f`
            along the last axis

        Returns
        -------
        torch.Tensor
            tensor of shape `(batch, (end - start) * num_f_maps)`

        """
        if extract_features:
            f = self.main_module(f)
            f = F.relu(f)
            f = f.reshape((f.shape[0], -1))
            f = self.linear(f)
        else:
            f = f[:, :, self.start : self.end]
            f = f.reshape((f.shape[0], -1))
        return f


class FC(nn.Module):
    """Fully connected module that predicts input data given features."""

    def __init__(
        self,
        dim: int,
        num_f_maps: int,
        num_ssl_layers: int,
        num_ssl_f_maps: int,
        ssl_input: bool = False,
    ) -> None:
        """Initialize the module.

        Parameters
        ----------
        dim : int
            output number of features
        num_f_maps : int
            number of features in input (used when `ssl_input` is `False`)
        num_ssl_layers : int
            number of linear layers in the module (at least two are always created)
        num_ssl_f_maps : int
            number of feature maps in the module
        ssl_input : bool, default False
            if `True`, the input has `dim` features instead of `num_f_maps`

        """
        super().__init__()
        dim = int(dim)
        num_f_maps = int(num_f_maps)
        num_ssl_layers = int(num_ssl_layers)
        num_ssl_f_maps = int(num_ssl_f_maps)
        # Chosen after the casts so that `in_features` is always an int.
        in_features = dim if ssl_input else num_f_maps
        # NOTE(review): no nonlinearity is applied between these layers.
        self.layers = nn.ModuleList(
            [nn.Linear(in_features=in_features, out_features=num_ssl_f_maps)]
        )
        for _ in range(num_ssl_layers - 2):
            self.layers.append(
                nn.Linear(in_features=num_ssl_f_maps, out_features=num_ssl_f_maps)
            )
        self.layers.append(nn.Linear(in_features=num_ssl_f_maps, out_features=dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass.

        Applies the linear layers independently at every time step.

        Parameters
        ----------
        x : torch.Tensor
            input of shape `(N, C, T)` where `C` equals the first layer's
            `in_features`

        Returns
        -------
        torch.Tensor
            output of shape `(N, dim, T)`

        """
        # `T` instead of `F` so the `torch.nn.functional` alias is not shadowed.
        N, C, T = x.shape
        x = x.transpose(1, 2).reshape(-1, C)
        for layer in self.layers:
            x = layer(x)
        x = x.reshape(N, T, -1).transpose(1, 2)
        return x


class DilatedTCN(nn.Module):
    """Dilated TCN module with residual dual-dilation layers."""

    def __init__(self, num_layers, input_dim, output_dim):
        """Initialize the module.

        Parameters
        ----------
        output_dim : int
            output number of features
        input_dim : int
            number of features in input
        num_layers : int
            number of layers in the module

        """
        super().__init__()
        num_layers = int(num_layers)
        input_dim = int(input_dim)
        output_dim = int(output_dim)
        self.num_layers = num_layers
        # First branch: dilation shrinks from 2**(num_layers-1) down to 1.
        # padding == dilation keeps the sequence length for kernel size 3.
        self.conv_dilated_1 = nn.ModuleList(
            (
                nn.Conv1d(
                    input_dim,
                    input_dim,
                    3,
                    padding=2 ** (num_layers - 1 - i),
                    dilation=2 ** (num_layers - 1 - i),
                )
                for i in range(num_layers)
            )
        )

        # Second branch: dilation grows from 1 up to 2**(num_layers-1).
        self.conv_dilated_2 = nn.ModuleList(
            (
                nn.Conv1d(input_dim, input_dim, 3, padding=2**i, dilation=2**i)
                for i in range(num_layers)
            )
        )

        # 1x1 convolutions merging the two concatenated branches back to input_dim.
        self.conv_fusion = nn.ModuleList(
            (nn.Conv1d(2 * input_dim, input_dim, 1) for i in range(num_layers))
        )

        self.conv_1x1_out = nn.Conv1d(input_dim, output_dim, 1)

        self.dropout = nn.Dropout()

    def forward(self, f):
        """Forward pass.

        Parameters
        ----------
        f : torch.Tensor
            input with `input_dim` channels on axis 1

        Returns
        -------
        torch.Tensor
            output with `output_dim` channels and the same length as `f`

        """
        for i in range(self.num_layers):
            # Keep a plain reference for the residual: `f` is rebound below, never
            # modified in place. `copy.copy` on a tensor rebuilds it from storage
            # as a new leaf, which cuts the skip connection out of autograd.
            f_in = f
            f = self.conv_fusion[i](
                torch.cat([self.conv_dilated_1[i](f), self.conv_dilated_2[i](f)], 1)
            )
            f = F.relu(f)
            f = self.dropout(f)
            f = f + f_in
        f = self.conv_1x1_out(f)
        return f
class FeatureExtractorTCN(nn.Module):
    """A module that extracts clip-level features with a TCN.

    Two unpadded, strided 1D convolutions (each followed by ReLU and dropout)
    are applied; the result is flattened and projected to `output_dim`
    features with a 1x1 convolution.
    """

    def __init__(
        self,
        num_f_maps: int,
        output_dim: int,
        len_segment: int,
        kernel_1: int,
        kernel_2: int,
        stride: int,
        decrease_f_maps: bool = False,
    ) -> None:
        """Initialize the module.

        Parameters
        ----------
        num_f_maps : int
            number of features in input
        output_dim : int
            number of features in output
        len_segment : int
            length of segment in input
        kernel_1 : int
            kernel size of the first layer
        kernel_2 : int
            kernel size of the second layer
        stride : int
            stride of both convolutional layers
        decrease_f_maps : bool, default False
            if `True`, number of feature maps is halved at each new layer

        """
        super().__init__()
        # Sizes may arrive as floats (e.g. from a config file); torch layers need ints.
        num_f_maps = int(num_f_maps)
        output_dim = int(output_dim)
        len_segment = int(len_segment)
        kernel_1 = int(kernel_1)
        kernel_2 = int(kernel_2)
        stride = int(stride)
        if decrease_f_maps:
            f_maps_2 = max(num_f_maps // 2, 1)
            f_maps_3 = max(num_f_maps // 4, 1)
        else:
            f_maps_2 = f_maps_3 = num_f_maps
        # Output length of an unpadded strided convolution, applied twice.
        length = floor((len_segment - kernel_1) / stride + 1)
        length = floor((length - kernel_2) / stride + 1)
        features = length * f_maps_3
        self.conv = nn.ModuleList()
        self.conv.append(
            nn.Conv1d(num_f_maps, f_maps_2, kernel_1, padding=0, stride=stride)
        )
        self.conv.append(
            nn.Conv1d(f_maps_2, f_maps_3, kernel_2, padding=0, stride=stride)
        )
        # Acts as a fully connected layer over the flattened (channels x time) map.
        self.conv_1x1_out = nn.Conv1d(features, output_dim, 1)
        self.dropout = nn.Dropout()

    def forward(self, f):
        """Forward pass.

        Parameters
        ----------
        f : torch.Tensor
            input with `num_f_maps` channels; presumably of shape
            `(batch, num_f_maps, len_segment)` (TODO confirm with callers)

        Returns
        -------
        torch.Tensor
            tensor of shape `(batch, output_dim)`

        """
        for conv in self.conv:
            f = conv(f)
            f = F.relu(f)
            f = self.dropout(f)
        f = f.reshape((f.shape[0], -1, 1))
        # Only drop the trailing length-1 axis: a bare `squeeze()` would also
        # remove the batch axis when the batch size (or `output_dim`) is 1.
        f = self.conv_1x1_out(f).squeeze(-1)
        return f
A module that extracts clip-level features with a TCN.
def __init__(
    self,
    num_f_maps: int,
    output_dim: int,
    len_segment: int,
    kernel_1: int,
    kernel_2: int,
    stride: int,
    decrease_f_maps: bool = False,
) -> None:
    """Build the two strided convolutions and the 1x1 output projection.

    Parameters
    ----------
    num_f_maps : int
        number of features in input
    output_dim : int
        number of features in output
    len_segment : int
        length of segment in input
    kernel_1 : int
        kernel size of the first layer
    kernel_2 : int
        kernel size of the second layer
    stride : int
        stride of both convolutional layers
    decrease_f_maps : bool, default False
        if `True`, number of feature maps is halved at each new layer

    """
    super().__init__()
    in_maps = int(num_f_maps)
    out_dim = int(output_dim)
    # Channel counts after the first and the second convolution.
    if not decrease_f_maps:
        mid_maps = last_maps = in_maps
    else:
        mid_maps, last_maps = max(in_maps // 2, 1), max(in_maps // 4, 1)
    # Length after each unpadded strided convolution: floor((L - k) / s + 1).
    out_len = int(floor((len_segment - kernel_1) / stride + 1))
    out_len = floor((out_len - kernel_2) / stride + 1)
    self.conv = nn.ModuleList(
        [
            nn.Conv1d(in_maps, mid_maps, kernel_1, padding=0, stride=stride),
            nn.Conv1d(mid_maps, last_maps, kernel_2, padding=0, stride=stride),
        ]
    )
    # The flattened (channels x time) map is projected with a 1x1 convolution.
    self.conv_1x1_out = nn.Conv1d(out_len * last_maps, out_dim, 1)
    self.dropout = nn.Dropout()
Initialize the module.
Parameters
num_f_maps : int
number of features in input
output_dim : int
number of features in output
len_segment : int
length of segment in input
kernel_1 : int
kernel size of the first layer
kernel_2 : int
kernel size of the second layer
stride : int
stride of both convolutional layers
decrease_f_maps : bool, default False
if True, number of feature maps is halved at each new layer
class MFeatureExtractorTCN(nn.Module):
    """Segment-level feature extractor built on a dilated TCN.

    The forward pass either encodes the whole input (dilated TCN, ReLU,
    flatten, linear) or returns the flattened `[start, end)` window of it.
    """

    def __init__(
        self,
        num_f_maps: int,
        output_dim: int,
        len_segment: int,
        kernel_1: int,
        kernel_2: int,
        stride: int,
        start: int,
        end: int,
        num_layers: int = 3,
    ):
        """Create the sub-modules.

        Parameters
        ----------
        num_f_maps : int
            number of features in input
        output_dim : int
            number of features in output
        len_segment : int
            length of segment in input
        kernel_1 : int
            kernel size of the first layer
        kernel_2 : int
            kernel size of the second layer
        stride : int
            stride
        start : int
            the start index of the segment to extract
        end : int
            the end index of the segment to extract
        num_layers : int, default 3
            number of layers of the dilated TCN

        """
        super().__init__()
        self.main_module = DilatedTCN(num_layers, num_f_maps, num_f_maps)
        # NOTE(review): not used by `forward`; kept as part of the parameter set.
        self.extractor = FeatureExtractorTCN(
            num_f_maps, output_dim, len_segment, kernel_1, kernel_2, stride
        )
        # Flattened TCN output -> flattened [start, end) window.
        self.linear = Linear(
            in_features=int(len_segment * num_f_maps),
            out_features=int((end - start) * num_f_maps),
        )
        self.start, self.end = int(start), int(end)

    def forward(self, f, extract_features=True):
        """Encode `f`, or return its flattened `[start, end)` window.

        Parameters
        ----------
        f : torch.Tensor
            input tensor with channels on axis 1 and time on axis 2
        extract_features : bool, default True
            `True` runs the encoder; `False` slices `f` along the last axis

        Returns
        -------
        torch.Tensor
            two-dimensional tensor `(batch, (end - start) * num_f_maps)`

        """
        if not extract_features:
            return f[:, :, self.start : self.end].reshape((f.shape[0], -1))
        hidden = F.relu(self.main_module(f))
        return self.linear(hidden.reshape((hidden.shape[0], -1)))
A module that extracts segment-level features with a TCN.
def __init__(
    self,
    num_f_maps: int,
    output_dim: int,
    len_segment: int,
    kernel_1: int,
    kernel_2: int,
    stride: int,
    start: int,
    end: int,
    num_layers: int = 3,
):
    """Create the dilated TCN, the (unused) clip extractor and the linear head.

    Parameters
    ----------
    num_f_maps : int
        number of features in input
    output_dim : int
        number of features in output
    len_segment : int
        length of segment in input
    kernel_1 : int
        kernel size of the first layer
    kernel_2 : int
        kernel size of the second layer
    stride : int
        stride
    start : int
        the start index of the segment to extract
    end : int
        the end index of the segment to extract
    num_layers : int, default 3
        number of layers of the dilated TCN

    """
    super().__init__()
    self.main_module = DilatedTCN(num_layers, num_f_maps, num_f_maps)
    # NOTE(review): this sub-module is not used in `forward`.
    self.extractor = FeatureExtractorTCN(
        num_f_maps, output_dim, len_segment, kernel_1, kernel_2, stride
    )
    # Maps the flattened TCN output to the flattened [start, end) window.
    self.linear = Linear(
        in_features=int(len_segment * num_f_maps),
        out_features=int((end - start) * num_f_maps),
    )
    self.start, self.end = int(start), int(end)
Initialize the module.
Parameters
num_f_maps : int
number of features in input
output_dim : int
number of features in output
len_segment : int
length of segment in input
kernel_1 : int
kernel size of the first layer
kernel_2 : int
kernel size of the second layer
stride : int
stride
start : int
the start index of the segment to extract
end : int
the end index of the segment to extract
num_layers : int, default 3
number of layers of the dilated TCN
def forward(self, f, extract_features=True):
    """Encode `f`, or return its flattened `[start, end)` window.

    With `extract_features=True`, `f` goes through `main_module`, ReLU, a
    flatten over all non-batch axes and `linear`. Otherwise the last axis of
    `f` is sliced to `[start, end)` and flattened.
    """
    if not extract_features:
        window = f[:, :, self.start : self.end]
        return window.reshape((window.shape[0], -1))
    hidden = F.relu(self.main_module(f))
    flat = hidden.reshape((hidden.shape[0], -1))
    return self.linear(flat)
Forward pass.
class FC(nn.Module):
    """Fully connected module that predicts input data given features."""

    def __init__(
        self,
        dim: int,
        num_f_maps: int,
        num_ssl_layers: int,
        num_ssl_f_maps: int,
        ssl_input: bool = False,
    ) -> None:
        """Initialize the module.

        Parameters
        ----------
        dim : int
            output number of features
        num_f_maps : int
            number of features in input (used when `ssl_input` is `False`)
        num_ssl_layers : int
            number of linear layers in the module (at least two are always created)
        num_ssl_f_maps : int
            number of feature maps in the module
        ssl_input : bool, default False
            if `True`, the input has `dim` features instead of `num_f_maps`

        """
        super().__init__()
        dim = int(dim)
        num_f_maps = int(num_f_maps)
        num_ssl_layers = int(num_ssl_layers)
        num_ssl_f_maps = int(num_ssl_f_maps)
        # Chosen after the casts so that `in_features` is always an int.
        in_features = dim if ssl_input else num_f_maps
        # NOTE(review): no nonlinearity is applied between these layers.
        self.layers = nn.ModuleList(
            [nn.Linear(in_features=in_features, out_features=num_ssl_f_maps)]
        )
        for _ in range(num_ssl_layers - 2):
            self.layers.append(
                nn.Linear(in_features=num_ssl_f_maps, out_features=num_ssl_f_maps)
            )
        self.layers.append(nn.Linear(in_features=num_ssl_f_maps, out_features=dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass.

        Applies the linear layers independently at every time step.

        Parameters
        ----------
        x : torch.Tensor
            input of shape `(N, C, T)` where `C` equals the first layer's
            `in_features`

        Returns
        -------
        torch.Tensor
            output of shape `(N, dim, T)`

        """
        # `T` instead of `F` so the `torch.nn.functional` alias is not shadowed.
        N, C, T = x.shape
        x = x.transpose(1, 2).reshape(-1, C)
        for layer in self.layers:
            x = layer(x)
        x = x.reshape(N, T, -1).transpose(1, 2)
        return x
Fully connected module that predicts input data given features.
def __init__(
    self,
    dim: int,
    num_f_maps: int,
    num_ssl_layers: int,
    num_ssl_f_maps: int,
    ssl_input: bool = False,
) -> None:
    """Initialize the module.

    Parameters
    ----------
    dim : int
        output number of features
    num_f_maps : int
        number of features in input (used when `ssl_input` is `False`)
    num_ssl_layers : int
        number of linear layers in the module (at least two are always created)
    num_ssl_f_maps : int
        number of feature maps in the module
    ssl_input : bool, default False
        if `True`, the input has `dim` features instead of `num_f_maps`

    """
    super().__init__()
    dim = int(dim)
    num_f_maps = int(num_f_maps)
    num_ssl_layers = int(num_ssl_layers)
    num_ssl_f_maps = int(num_ssl_f_maps)
    # Chosen after the casts so that `in_features` is always an int; before,
    # a float `num_f_maps` reached `nn.Linear` unconverted.
    in_features = dim if ssl_input else num_f_maps
    # NOTE(review): no nonlinearity is applied between these layers.
    self.layers = nn.ModuleList(
        [nn.Linear(in_features=in_features, out_features=num_ssl_f_maps)]
    )
    for _ in range(num_ssl_layers - 2):
        self.layers.append(
            nn.Linear(in_features=num_ssl_f_maps, out_features=num_ssl_f_maps)
        )
    self.layers.append(nn.Linear(in_features=num_ssl_f_maps, out_features=dim))
Initialize the module.
Parameters
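dim : int
output number of features
num_f_maps : int
number of features in input (used when ssl_input is False)
num_ssl_layers : int
number of linear layers in the module (at least two are always created)
num_ssl_f_maps : int
number of feature maps in the module
ssl_input : bool, default False
if True, the input has dim features instead of num_f_maps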
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Forward pass.

    Applies the linear layers independently at every time step.

    Parameters
    ----------
    x : torch.Tensor
        input of shape `(N, C, T)` where `C` equals the first layer's
        `in_features`

    Returns
    -------
    torch.Tensor
        output of shape `(N, dim, T)`

    """
    # `T` instead of `F` so the `torch.nn.functional` alias is not shadowed.
    N, C, T = x.shape
    x = x.transpose(1, 2).reshape(-1, C)
    for layer in self.layers:
        x = layer(x)
    x = x.reshape(N, T, -1).transpose(1, 2)
    return x
Forward pass.
class DilatedTCN(nn.Module):
    """Dilated TCN module with residual dual-dilation layers.

    Each layer concatenates two parallel dilated convolutions (one with
    shrinking, one with growing dilation), fuses them with a 1x1 convolution,
    applies ReLU and dropout and adds a residual connection. A final 1x1
    convolution maps `input_dim` to `output_dim` channels.
    """

    def __init__(self, num_layers, input_dim, output_dim):
        """Initialize the module.

        Parameters
        ----------
        output_dim : int
            output number of features
        input_dim : int
            number of features in input
        num_layers : int
            number of layers in the module

        """
        super().__init__()
        num_layers = int(num_layers)
        input_dim = int(input_dim)
        output_dim = int(output_dim)
        self.num_layers = num_layers
        # padding == dilation keeps the sequence length for kernel size 3.
        self.conv_dilated_1 = nn.ModuleList(
            (
                nn.Conv1d(
                    input_dim,
                    input_dim,
                    3,
                    padding=2 ** (num_layers - 1 - i),
                    dilation=2 ** (num_layers - 1 - i),
                )
                for i in range(num_layers)
            )
        )

        self.conv_dilated_2 = nn.ModuleList(
            (
                nn.Conv1d(input_dim, input_dim, 3, padding=2**i, dilation=2**i)
                for i in range(num_layers)
            )
        )

        self.conv_fusion = nn.ModuleList(
            (nn.Conv1d(2 * input_dim, input_dim, 1) for i in range(num_layers))
        )

        self.conv_1x1_out = nn.Conv1d(input_dim, output_dim, 1)

        self.dropout = nn.Dropout()

    def forward(self, f):
        """Forward pass.

        Parameters
        ----------
        f : torch.Tensor
            input with `input_dim` channels on axis 1

        Returns
        -------
        torch.Tensor
            output with `output_dim` channels and the same length as `f`

        """
        for i in range(self.num_layers):
            # Keep a plain reference for the residual: `f` is rebound below, never
            # modified in place. `copy.copy` on a tensor rebuilds it from storage
            # as a new leaf, which cuts the skip connection out of autograd.
            f_in = f
            f = self.conv_fusion[i](
                torch.cat([self.conv_dilated_1[i](f), self.conv_dilated_2[i](f)], 1)
            )
            f = F.relu(f)
            f = self.dropout(f)
            f = f + f_in
        f = self.conv_1x1_out(f)
        return f
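Dilated TCN module: a stack of residual layers, each fusing two parallel dilated convolutions (one with decreasing and one with increasing dilation), followed by a 1x1 output convolution.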
def __init__(self, num_layers, input_dim, output_dim):
    """Create the dual dilated branches, the fusion layers and the output projection.

    Parameters
    ----------
    output_dim : int
        output number of features
    input_dim : int
        number of features in input
    num_layers : int
        number of layers in the module

    """
    super().__init__()
    n_layers = int(num_layers)
    channels = int(input_dim)
    out_channels = int(output_dim)
    self.num_layers = n_layers
    # Dilation schedules: the first branch starts wide and shrinks, the second grows.
    # padding == dilation keeps the sequence length for kernel size 3.
    shrinking = [2 ** (n_layers - 1 - k) for k in range(n_layers)]
    growing = [2**k for k in range(n_layers)]
    self.conv_dilated_1 = nn.ModuleList(
        [nn.Conv1d(channels, channels, 3, padding=d, dilation=d) for d in shrinking]
    )
    self.conv_dilated_2 = nn.ModuleList(
        [nn.Conv1d(channels, channels, 3, padding=d, dilation=d) for d in growing]
    )
    self.conv_fusion = nn.ModuleList(
        [nn.Conv1d(2 * channels, channels, 1) for _ in range(n_layers)]
    )
    self.conv_1x1_out = nn.Conv1d(channels, out_channels, 1)
    self.dropout = nn.Dropout()
Initialize the module.
Parameters
output_dim : int
output number of features
input_dim : int
number of features in input
num_layers : int
number of layers in the module
def forward(self, f):
    """Forward pass.

    Each layer fuses two parallel dilated convolutions, applies ReLU and
    dropout and adds the layer input back (residual connection); a final
    1x1 convolution produces the output channels.

    Parameters
    ----------
    f : torch.Tensor
        input with `input_dim` channels on axis 1

    Returns
    -------
    torch.Tensor
        output with `output_dim` channels

    """
    for i in range(self.num_layers):
        # Keep a plain reference for the residual: `f` is rebound below, never
        # modified in place. `copy.copy` on a tensor rebuilds it from storage
        # as a new leaf, which cuts the skip connection out of autograd.
        f_in = f
        f = self.conv_fusion[i](
            torch.cat([self.conv_dilated_1[i](f), self.conv_dilated_2[i](f)], 1)
        )
        f = F.relu(f)
        f = self.dropout(f)
        f = f + f_in
    f = self.conv_1x1_out(f)
    return f
Forward pass.