dlc2action.model.transformer
#
# Copyright 2020-2022 by A. Mathis Group and contributors. All rights reserved.
#
# This project and all its files are licensed under GNU AGPLv3 or later version. A copy is included in dlc2action/LICENSE.AGPL.
#
import torch
from torch import nn
from torch.nn import functional as F
import math
import copy
from dlc2action.model.base_model import Model


class _Interpolate(nn.Module):
    def __init__(self, size, mode):
        super(_Interpolate, self).__init__()
        self.interp = F.interpolate
        self.size = size
        self.mode = mode

    def forward(self, x):
        x = self.interp(x, size=self.size, mode=self.mode, align_corners=False)
        return x


class _FeedForward(nn.Module):
    def __init__(self, d_model, dropout=0.1):
        super().__init__()
        # position-wise feed-forward block implemented with kernel-size-1 convolutions
        self.linear_1 = nn.Conv1d(d_model, d_model, 1)
        self.dropout = nn.Dropout(dropout)
        self.linear_2 = nn.Conv1d(d_model, d_model, 1)

    def forward(self, x):
        x = self.dropout(F.relu(self.linear_1(x)))
        x = self.linear_2(x)
        return x


class _MultiHeadAttention(nn.Module):
    def __init__(self, heads, d_model, dropout=0.1):
        super().__init__()

        self.d_model = d_model
        self.d_k = d_model // heads
        self.h = heads

        self.q_linear = nn.Linear(d_model, d_model)
        self.v_linear = nn.Linear(d_model, d_model)
        self.k_linear = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)
        self.out = nn.Linear(d_model, d_model)

    def forward(self, q, k, v, mask=None):
        bs = q.size(0)
        # inputs arrive channel-first (bs, d_model, sl); switch to (bs, sl, d_model)
        q = q.transpose(1, 2)
        v = v.transpose(1, 2)
        k = k.transpose(1, 2)

        # perform linear operation and split into h heads
        k = self.k_linear(k).view(bs, -1, self.h, self.d_k)
        q = self.q_linear(q).view(bs, -1, self.h, self.d_k)
        v = self.v_linear(v).view(bs, -1, self.h, self.d_k)

        # transpose to get dimensions bs * h * sl * d_k
        k = k.transpose(1, 2)
        q = q.transpose(1, 2)
        v = v.transpose(1, 2)
        # calculate scaled dot-product attention
        scores = _attention(q, k, v, self.d_k, mask, self.dropout)

        # concatenate heads and put through final linear layer
        concat = scores.transpose(1, 2).contiguous().view(bs, -1, self.d_model)

        output = self.out(concat).transpose(1, 2)

        return output


def _attention(q, k, v, d_k, mask=None, dropout=None):
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        mask = mask.unsqueeze(1)
        scores = scores.masked_fill(mask == 0, -1e9)
    scores = F.softmax(scores, dim=-1)

    if dropout is not None:
        scores = dropout(scores)

    output = torch.matmul(scores, v)
    return output


class _EncoderLayer(nn.Module):
    def __init__(self, d_model, heads, dropout=0.1):
        super().__init__()
        self.norm_1 = nn.BatchNorm1d(d_model)
        self.norm_2 = nn.BatchNorm1d(d_model)
        self.attn = _MultiHeadAttention(heads, d_model)
        self.ff = _FeedForward(d_model)
        self.dropout_1 = nn.Dropout(dropout)
        self.dropout_2 = nn.Dropout(dropout)

    def forward(self, x, mask):
        x2 = self.norm_1(x)
        x = x + self.dropout_1(self.attn(x2, x2, x2, mask))
        x2 = self.norm_2(x)
        x = x + self.dropout_2(self.ff(x2))
        return x


def _get_clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for i in range(N)])


class _Encoder(nn.Module):
    def __init__(self, d_model, N, heads):
        super().__init__()
        self.pe = _PositionalEncoder(d_model)
        self.layers = _get_clones(_EncoderLayer(d_model, heads), N)
        self.norm = nn.BatchNorm1d(d_model)

    def forward(self, src, mask):
        x = self.pe(src)
        for layer in self.layers:
            x = layer(x, mask)
        return self.norm(x)


class _PositionalEncoder(nn.Module):
    def __init__(self, d_model, max_seq_len=512):
        super().__init__()
        self.d_model = d_model

        # create constant 'pe' matrix with values dependent on pos and i
        pe = torch.zeros(d_model, max_seq_len)
        for pos in range(max_seq_len):
            for i in range(0, d_model, 2):
                pe[i, pos] = math.sin(pos / (10000 ** ((2 * i) / d_model)))
                if i + 1 < d_model:
                    pe[i + 1, pos] = math.cos(
                        pos / (10000 ** ((2 * (i + 1)) / d_model))
                    )

        self.pe = pe.unsqueeze(0)

    def forward(self, x):
        # make embeddings relatively larger
        x = x * math.sqrt(self.d_model)
        # add constant to embedding
        seq_len = x.size(-1)
        x = x + self.pe[:, :, :seq_len].to(x.device)
        return x


class _TransformerModule(nn.Module):
    # convolutional downsampling -> transformer encoder -> convolutional upsampling
    def __init__(
        self, heads, N, d_model, input_dim, output_dim, num_pool=3, add_batchnorm=True
    ):
        super(_TransformerModule, self).__init__()
        self.encoder = _Encoder(d_model, N, heads)
        self.in_layers = nn.ModuleList()
        self.out_layers = nn.ModuleList()
        layer = nn.ModuleList()
        layer.append(nn.Conv1d(input_dim, d_model, 3, padding=1))
        layer.append(nn.ReLU())
        if num_pool > 0:
            layer.append(nn.MaxPool1d(2, 2))
        self.in_layers.append(layer)
        for _ in range(num_pool - 1):
            layer = nn.ModuleList()
            layer.append(nn.Conv1d(d_model, d_model, 3, padding=1))
            layer.append(nn.ReLU())
            if add_batchnorm:
                layer.append(nn.BatchNorm1d(d_model))
            layer.append(nn.MaxPool1d(2, 2))
            self.in_layers.append(layer)
        for _ in range(num_pool):
            layer = nn.ModuleList()
            layer.append(nn.Conv1d(d_model, d_model, 3, padding=1))
            layer.append(nn.ReLU())
            if add_batchnorm:
                layer.append(nn.BatchNorm1d(d_model))
            self.out_layers.append(layer)
        self.conv_out = nn.Conv1d(d_model, output_dim, 3, padding=1)

    def forward(self, x):
        sizes = []
        for layer_list in self.in_layers:
            sizes.append(x.shape[-1])
            for layer in layer_list:
                x = layer(x)
        mask = (x.sum(1).unsqueeze(1) != 0).int()
        x = self.encoder(x, mask)
        sizes = sizes[::-1]
        for i, (layer_list, size) in enumerate(zip(self.out_layers, sizes)):
            for layer in layer_list:
                x = layer(x)
            x = F.interpolate(x, size)
        x = self.conv_out(x)
        return x


class _Predictor(nn.Module):
    def __init__(self, dim, num_classes):
        super(_Predictor, self).__init__()
        self.conv_out_1 = nn.Conv1d(dim, 64, kernel_size=1)
        self.conv_out_2 = nn.Conv1d(64, num_classes, kernel_size=1)

    def forward(self, x):
        x = self.conv_out_1(x)
        x = F.relu(x)
        x = self.conv_out_2(x)
        return x


class Transformer(Model):
    """
    A modification of Transformer-Encoder with additional max-pooling and upsampling

    Set `num_pool` to 0 to get a standard transformer-encoder.
    """

    def __init__(
        self,
        N,
        heads,
        num_f_maps,
        input_dim,
        num_classes,
        num_pool,
        add_batchnorm=False,
        feature_dim=None,
        state_dict_path=None,
        ssl_constructors=None,
        ssl_types=None,
        ssl_modules=None,
    ):
        input_dim = sum([x[0] for x in input_dim.values()])
        if feature_dim is None:
            feature_dim = num_classes
            self.f_shape = None
            self.params_predictor = None
        else:
            self.f_shape = torch.Size([int(feature_dim)])
            self.params_predictor = {
                "dim": int(feature_dim),
                "num_classes": int(num_classes),
            }
        self.params = {
            "d_model": int(num_f_maps),
            "input_dim": int(input_dim),
            "N": int(N),
            "heads": int(heads),
            "add_batchnorm": add_batchnorm,
            "num_pool": int(num_pool),
            "output_dim": int(feature_dim),
        }
        super().__init__(ssl_constructors, ssl_modules, ssl_types, state_dict_path)

    def _feature_extractor(self):
        return _TransformerModule(**self.params)

    def _predictor(self) -> torch.nn.Module:
        if self.params_predictor is None:
            return nn.Identity()
        else:
            return _Predictor(**self.params_predictor)

    def features_shape(self) -> torch.Size:
        return self.f_shape
class Transformer(Model):
A modification of Transformer-Encoder with additional max-pooling and upsampling.

Set `num_pool` to 0 to get a standard transformer-encoder.
Transformer(N, heads, num_f_maps, input_dim, num_classes, num_pool, add_batchnorm=False, feature_dim=None, state_dict_path=None, ssl_constructors=None, ssl_types=None, ssl_modules=None)
Parameters
ssl_constructors : list, optional
    a list of SSL constructors that build the necessary SSL modules
ssl_modules : list, optional
    a list of torch.nn.Module instances that will serve as SSL modules
ssl_types : list, optional
    a list of string SSL types
state_dict_path : str, optional
    path to the model state dictionary to load
strict : bool, default False
    when True, the state dictionary will only be loaded if the current and the loaded architecture are the same; otherwise missing or extra keys, as well as shape inconsistencies, are ignored
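A hypothetical construction sketch follows. The `input_dim` dictionary (feature-group name and channel count) is invented for illustration, since in dlc2action it normally comes from the dataset, the SSL-related arguments are left at their `None` defaults, and it is assumed that the base `Model` class can build the network from these arguments alone.

from dlc2action.model.transformer import Transformer

# input_dim maps feature-group names to shapes; only the channel counts
# (the first element of each shape) are used and summed
model = Transformer(
    N=2,                          # number of encoder layers
    heads=8,                      # num_f_maps must be divisible by heads
    num_f_maps=64,                # d_model of the encoder
    input_dim={"pose": (26,)},    # hypothetical single feature group, 26 channels
    num_classes=10,
    num_pool=3,                   # 0 disables the pooling/upsampling stages
)
# With feature_dim left unset, the feature extractor outputs num_classes
# channels directly and the predictor is an identity mapping.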
features_shape() -> torch.Size

Get the shape of feature extractor output.

Returns
feature_shape : torch.Size
    shape of feature extractor output
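A short sketch of the return value, under the same hypothetical arguments as the construction example above: the shape is only defined when `feature_dim` is passed to the constructor.

from dlc2action.model.transformer import Transformer

kwargs = dict(N=2, heads=8, num_f_maps=64, num_classes=10, num_pool=3,
              input_dim={"pose": (26,)})  # hypothetical feature group

print(Transformer(**kwargs).features_shape())                   # None
print(Transformer(**kwargs, feature_dim=128).features_shape())  # torch.Size([128])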
Inherited Members
- dlc2action.model.base_model.Model
  - process_labels
  - freeze_feature_extractor
  - unfreeze_feature_extractor
  - load_state_dict
  - ssl_off
  - ssl_on
  - main_task_on
  - main_task_off
  - set_ssl
  - extract_features
  - forward
- torch.nn.modules.module.Module
  - dump_patches
  - register_buffer
  - register_parameter
  - add_module
  - register_module
  - get_submodule
  - get_parameter
  - get_buffer
  - get_extra_state
  - set_extra_state
  - apply
  - cuda
  - ipu
  - xpu
  - cpu
  - type
  - float
  - double
  - half
  - bfloat16
  - to_empty
  - to
  - register_backward_hook
  - register_full_backward_hook
  - register_forward_pre_hook
  - register_forward_hook
  - T_destination
  - state_dict
  - register_load_state_dict_post_hook
  - parameters
  - named_parameters
  - buffers
  - named_buffers
  - children
  - named_children
  - modules
  - named_modules
  - train
  - eval
  - requires_grad_
  - zero_grad
  - extra_repr