dlc2action.model.transformer

  1#
  2# Copyright 2020-2022 by A. Mathis Group and contributors. All rights reserved.
  3#
  4# This project and all its files are licensed under GNU AGPLv3 or later version. A copy is included in dlc2action/LICENSE.AGPL.
  5#
  6import torch
  7from torch import nn
  8from torch.nn import functional as F
  9import math
 10import copy
 11from dlc2action.model.base_model import Model
 12
 13
 14class _Interpolate(nn.Module):
 15    def __init__(self, size, mode):
 16        super(_Interpolate, self).__init__()
 17        self.interp = F.interpolate
 18        self.size = size
 19        self.mode = mode
 20
 21    def forward(self, x):
 22        x = self.interp(x, size=self.size, mode=self.mode, align_corners=False)
 23        return x
 24
 25
 26class _FeedForward(nn.Module):
 27    def __init__(self, d_model, dropout=0.1):
 28        super().__init__()
 29        # We set d_ff as a default to 2048
 30        self.linear_1 = nn.Conv1d(d_model, d_model, 1)
 31        self.dropout = nn.Dropout(dropout)
 32        self.linear_2 = nn.Conv1d(d_model, d_model, 1)
 33
 34    def forward(self, x):
 35        x = self.dropout(F.relu(self.linear_1(x)))
 36        x = self.linear_2(x)
 37        return x
 38
 39
 40class _MultiHeadAttention(nn.Module):
 41    def __init__(self, heads, d_model, dropout=0.1):
 42        super().__init__()
 43
 44        self.d_model = d_model
 45        self.d_k = d_model // heads
 46        self.h = heads
 47
 48        self.q_linear = nn.Linear(d_model, d_model)
 49        self.v_linear = nn.Linear(d_model, d_model)
 50        self.k_linear = nn.Linear(d_model, d_model)
 51        self.dropout = nn.Dropout(dropout)
 52        self.out = nn.Linear(d_model, d_model)
 53
 54    def forward(self, q, k, v, mask=None):
 55        bs = q.size(0)
 56        q = q.transpose(1, 2)
 57        v = v.transpose(1, 2)
 58        k = k.transpose(1, 2)
 59
 60        # perform linear operation and split into h heads
 61        k = self.k_linear(k).view(bs, -1, self.h, self.d_k)
 62        q = self.q_linear(q).view(bs, -1, self.h, self.d_k)
 63        v = self.v_linear(v).view(bs, -1, self.h, self.d_k)
 64
 65        # transpose to get dimensions bs * h * sl * d_model
 66
 67        k = k.transpose(1, 2)
 68        q = q.transpose(1, 2)
 69        v = v.transpose(1, 2)
 70        # calculate attention using function we will define next
 71        scores = _attention(q, k, v, self.d_k, mask, self.dropout)
 72
 73        # concatenate heads and put through final linear layer
 74        concat = scores.transpose(1, 2).contiguous().view(bs, -1, self.d_model)
 75
 76        output = self.out(concat).transpose(1, 2)
 77
 78        return output
 79
 80
 81def _attention(q, k, v, d_k, mask=None, dropout=None):
 82    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
 83    if mask is not None:
 84        mask = mask.unsqueeze(1)
 85        scores = scores.masked_fill(mask == 0, -1e9)
 86    scores = F.softmax(scores, dim=-1)
 87
 88    if dropout is not None:
 89        scores = dropout(scores)
 90
 91    output = torch.matmul(scores, v)
 92    return output
 93
 94
 95class _EncoderLayer(nn.Module):
 96    def __init__(self, d_model, heads, dropout=0.1):
 97        super().__init__()
 98        self.norm_1 = nn.BatchNorm1d(d_model)
 99        self.norm_2 = nn.BatchNorm1d(d_model)
100        self.attn = _MultiHeadAttention(heads, d_model)
101        self.ff = _FeedForward(d_model)
102        self.dropout_1 = nn.Dropout(dropout)
103        self.dropout_2 = nn.Dropout(dropout)
104
105    def forward(self, x, mask):
106        x2 = self.norm_1(x)
107        x = x + self.dropout_1(self.attn(x2, x2, x2, mask))
108        x2 = self.norm_2(x)
109        x = x + self.dropout_2(self.ff(x2))
110        return x
111
112
113def _get_clones(module, N):
114    return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
115
116
class _Encoder(nn.Module):
    """
    Positional encoding followed by a stack of encoder layers and a final norm

    Parameters
    ----------
    d_model : int
        the number of channels
    N : int
        the number of stacked encoder layers
    heads : int
        the number of attention heads in every layer
    """

    def __init__(self, d_model, N, heads):
        super().__init__()
        self.pe = _PositionalEncoder(d_model)
        prototype = _EncoderLayer(d_model, heads)
        # every layer starts as a deep copy of the same prototype
        self.layers = _get_clones(prototype, N)
        self.norm = nn.BatchNorm1d(d_model)

    def forward(self, src, mask):
        """Encode `src`, handing `mask` to every layer's self-attention."""
        hidden = self.pe(src)
        for block in self.layers:
            hidden = block(hidden, mask)
        normalized = self.norm(hidden)
        return normalized
129
130
131class _PositionalEncoder(nn.Module):
132    def __init__(self, d_model, max_seq_len=512):
133        super().__init__()
134        self.d_model = d_model
135
136        # create constant 'pe' matrix with values dependant on
137        # pos and i
138        pe = torch.zeros(d_model, max_seq_len)
139        for pos in range(max_seq_len):
140            for i in range(0, d_model, 2):
141                pe[i, pos] = math.sin(pos / (10000 ** ((2 * i) / d_model)))
142                if i + 1 < d_model:
143                    pe[i + 1, pos] = math.cos(
144                        pos / (10000 ** ((2 * (i + 1)) / d_model))
145                    )
146
147        self.pe = pe.unsqueeze(0)
148
149    def forward(self, x):
150        # make embeddings relatively larger
151        x = x * math.sqrt(self.d_model)
152        # add constant to embedding
153        seq_len = x.size(-1)
154        x = x + self.pe[:, :, :seq_len].to(x.device)
155        return x
156
157
class _TransformerModule(nn.Module):
    """
    Convolutional down-sampling, a transformer encoder and convolutional up-sampling

    The input stages halve the sequence length with max-pooling, the encoder
    runs on the shortened sequence and the output stages restore the lengths
    recorded on the way in.

    Parameters
    ----------
    heads : int
        the number of attention heads
    N : int
        the number of encoder layers
    d_model : int
        the number of channels inside the module
    input_dim : int
        the number of input channels
    output_dim : int
        the number of output channels
    num_pool : int, default 3
        the number of max-pooling stages (0 disables pooling and up-sampling)
    add_batchnorm : bool, default True
        if True, `BatchNorm1d` is added to every convolutional stage except
        the first input stage
    """

    def __init__(
        self, heads, N, d_model, input_dim, output_dim, num_pool=3, add_batchnorm=True
    ):
        super(_TransformerModule, self).__init__()
        self.encoder = _Encoder(d_model, N, heads)
        self.in_layers = nn.ModuleList()
        self.out_layers = nn.ModuleList()

        # first input stage: maps input_dim -> d_model, never uses batch norm
        stem = [nn.Conv1d(input_dim, d_model, 3, padding=1), nn.ReLU()]
        if num_pool > 0:
            stem.append(nn.MaxPool1d(2, 2))
        self.in_layers.append(nn.ModuleList(stem))

        # remaining pooling stages
        for _ in range(num_pool - 1):
            stage = [nn.Conv1d(d_model, d_model, 3, padding=1), nn.ReLU()]
            if add_batchnorm:
                stage.append(nn.BatchNorm1d(d_model))
            stage.append(nn.MaxPool1d(2, 2))
            self.in_layers.append(nn.ModuleList(stage))

        # one output stage per pooling stage; resizing happens in forward()
        for _ in range(num_pool):
            stage = [nn.Conv1d(d_model, d_model, 3, padding=1), nn.ReLU()]
            if add_batchnorm:
                stage.append(nn.BatchNorm1d(d_model))
            self.out_layers.append(nn.ModuleList(stage))

        self.conv_out = nn.Conv1d(d_model, output_dim, 3, padding=1)

    def forward(self, x):
        """
        Run the module on `x`

        Parameters
        ----------
        x : torch.Tensor
            the input with channels in dimension 1 and the sequence in the
            last dimension

        Returns
        -------
        x : torch.Tensor
            the output with `output_dim` channels
        """
        lengths = []
        for stage in self.in_layers:
            # remember the length before this stage so it can be restored later
            lengths.append(x.shape[-1])
            for op in stage:
                x = op(x)
        # positions whose channels sum to exactly zero are masked out of attention
        mask = (x.sum(1).unsqueeze(1) != 0).int()
        x = self.encoder(x, mask)
        for stage, target_len in zip(self.out_layers, reversed(lengths)):
            for op in stage:
                x = op(x)
            x = F.interpolate(x, target_len)
        return self.conv_out(x)
204
205
206class _Predictor(nn.Module):
207    def __init__(self, dim, num_classes):
208        super(_Predictor, self).__init__()
209        self.conv_out_1 = nn.Conv1d(dim, 64, kernel_size=1)
210        self.conv_out_2 = nn.Conv1d(64, num_classes, kernel_size=1)
211
212    def forward(self, x):
213        x = self.conv_out_1(x)
214        x = F.relu(x)
215        x = self.conv_out_2(x)
216        return x
217
218
class Transformer(Model):
    """
    A modification of Transformer-Encoder with additional max-pooling and upsampling

    Set `num_pool` to 0 to get a standard transformer-encoder.
    """

    def __init__(
        self,
        N,
        heads,
        num_f_maps,
        input_dim,
        num_classes,
        num_pool,
        add_batchnorm=False,
        feature_dim=None,
        state_dict_path=None,
        ssl_constructors=None,
        ssl_types=None,
        ssl_modules=None,
    ):
        """
        Parameters
        ----------
        N : int
            the number of encoder layers
        heads : int
            the number of attention heads
        num_f_maps : int
            the number of channels inside the feature extractor (`d_model`)
        input_dim : dict
            a dictionary of indexable values; the first elements of the values
            are summed to get the number of input channels (presumably the
            values are feature shapes -- confirm against the caller)
        num_classes : int
            the number of output classes
        num_pool : int
            the number of max-pooling stages in the feature extractor
        add_batchnorm : bool, default False
            whether batch normalization is added to the convolutional stages
            of the feature extractor
        feature_dim : int, optional
            the number of channels the feature extractor outputs; when set, a
            separate prediction module maps them to `num_classes`, otherwise
            the feature extractor outputs `num_classes` channels directly
        state_dict_path : str, optional
            handed to the base `Model` constructor
        ssl_constructors : list, optional
            handed to the base `Model` constructor
        ssl_types : list, optional
            handed to the base `Model` constructor
        ssl_modules : list, optional
            handed to the base `Model` constructor
        """
        n_channels = sum(shape[0] for shape in input_dim.values())
        if feature_dim is not None:
            self.f_shape = torch.Size([int(feature_dim)])
            self.params_predictor = dict(
                dim=int(feature_dim),
                num_classes=int(num_classes),
            )
        else:
            # no separate predictor: the extractor emits class scores itself
            feature_dim = num_classes
            self.f_shape = None
            self.params_predictor = None
        self.params = dict(
            d_model=int(num_f_maps),
            input_dim=int(n_channels),
            N=int(N),
            heads=int(heads),
            add_batchnorm=add_batchnorm,
            num_pool=int(num_pool),
            output_dim=int(feature_dim),
        )
        # the attributes above are set before the base constructor runs
        # (presumably it builds the sub-modules from them -- confirm in Model)
        super().__init__(ssl_constructors, ssl_modules, ssl_types, state_dict_path)

    def _feature_extractor(self):
        """Build the feature extractor from `self.params`."""
        return _TransformerModule(**self.params)

    def _predictor(self) -> torch.nn.Module:
        """Build the prediction head (identity when `feature_dim` was not set)."""
        if self.params_predictor is not None:
            return _Predictor(**self.params_predictor)
        return nn.Identity()

    def features_shape(self) -> torch.Size:
        """
        Get the shape of feature extractor output

        Returns
        -------
        feature_shape : torch.Size
            shape of feature extractor output (`None` when `feature_dim` was
            not set)
        """
        return self.f_shape
class Transformer(dlc2action.model.base_model.Model):
class Transformer(Model):
    """
    A modification of Transformer-Encoder with additional max-pooling and upsampling

    Set `num_pool` to 0 to get a standard transformer-encoder.
    """

    def __init__(
        self,
        N,
        heads,
        num_f_maps,
        input_dim,
        num_classes,
        num_pool,
        add_batchnorm=False,
        feature_dim=None,
        state_dict_path=None,
        ssl_constructors=None,
        ssl_types=None,
        ssl_modules=None,
    ):
        """
        Parameters
        ----------
        N : int
            the number of encoder layers
        heads : int
            the number of attention heads
        num_f_maps : int
            the number of channels inside the feature extractor (`d_model`)
        input_dim : dict
            a dictionary of indexable values; the first elements of the values
            are summed to get the number of input channels (presumably the
            values are feature shapes -- confirm against the caller)
        num_classes : int
            the number of output classes
        num_pool : int
            the number of max-pooling stages in the feature extractor
        add_batchnorm : bool, default False
            whether batch normalization is added to the convolutional stages
            of the feature extractor
        feature_dim : int, optional
            the number of channels the feature extractor outputs; when set, a
            separate prediction module maps them to `num_classes`, otherwise
            the feature extractor outputs `num_classes` channels directly
        state_dict_path : str, optional
            handed to the base `Model` constructor
        ssl_constructors : list, optional
            handed to the base `Model` constructor
        ssl_types : list, optional
            handed to the base `Model` constructor
        ssl_modules : list, optional
            handed to the base `Model` constructor
        """
        n_channels = sum(shape[0] for shape in input_dim.values())
        if feature_dim is not None:
            self.f_shape = torch.Size([int(feature_dim)])
            self.params_predictor = dict(
                dim=int(feature_dim),
                num_classes=int(num_classes),
            )
        else:
            # no separate predictor: the extractor emits class scores itself
            feature_dim = num_classes
            self.f_shape = None
            self.params_predictor = None
        self.params = dict(
            d_model=int(num_f_maps),
            input_dim=int(n_channels),
            N=int(N),
            heads=int(heads),
            add_batchnorm=add_batchnorm,
            num_pool=int(num_pool),
            output_dim=int(feature_dim),
        )
        # the attributes above are set before the base constructor runs
        # (presumably it builds the sub-modules from them -- confirm in Model)
        super().__init__(ssl_constructors, ssl_modules, ssl_types, state_dict_path)

    def _feature_extractor(self):
        """Build the feature extractor from `self.params`."""
        return _TransformerModule(**self.params)

    def _predictor(self) -> torch.nn.Module:
        """Build the prediction head (identity when `feature_dim` was not set)."""
        if self.params_predictor is not None:
            return _Predictor(**self.params_predictor)
        return nn.Identity()

    def features_shape(self) -> torch.Size:
        """
        Get the shape of feature extractor output

        Returns
        -------
        feature_shape : torch.Size
            shape of feature extractor output (`None` when `feature_dim` was
            not set)
        """
        return self.f_shape

A modification of Transformer-Encoder with additional max-pooling and upsampling

Set num_pool to 0 to get a standard transformer-encoder.

Transformer( N, heads, num_f_maps, input_dim, num_classes, num_pool, add_batchnorm=False, feature_dim=None, state_dict_path=None, ssl_constructors=None, ssl_types=None, ssl_modules=None)
227    def __init__(
228        self,
229        N,
230        heads,
231        num_f_maps,
232        input_dim,
233        num_classes,
234        num_pool,
235        add_batchnorm=False,
236        feature_dim=None,
237        state_dict_path=None,
238        ssl_constructors=None,
239        ssl_types=None,
240        ssl_modules=None,
241    ):
242        input_dim = sum([x[0] for x in input_dim.values()])
243        if feature_dim is None:
244            feature_dim = num_classes
245            self.f_shape = None
246            self.params_predictor = None
247        else:
248            self.f_shape = torch.Size([int(feature_dim)])
249            self.params_predictor = {
250                "dim": int(feature_dim),
251                "num_classes": int(num_classes),
252            }
253        self.params = {
254            "d_model": int(num_f_maps),
255            "input_dim": int(input_dim),
256            "N": int(N),
257            "heads": int(heads),
258            "add_batchnorm": add_batchnorm,
259            "num_pool": int(num_pool),
260            "output_dim": int(feature_dim),
261        }
262        super().__init__(ssl_constructors, ssl_modules, ssl_types, state_dict_path)

Parameters

ssl_constructors : list, optional
    a list of SSL constructors that build the necessary SSL modules
ssl_modules : list, optional
    a list of torch.nn.Module instances that will serve as SSL modules
ssl_types : list, optional
    a list of string SSL types
state_dict_path : str, optional
    path to the model state dictionary to load
strict : bool, default False
    when True, the state dictionary will only be loaded if the current and the loaded architecture are the same; otherwise missing or extra keys, as well as shape inconsistencies, are ignored

def features_shape(self) -> torch.Size:
273    def features_shape(self) -> torch.Size:
274        return self.f_shape

Get the shape of feature extractor output

Returns

feature_shape : torch.Size shape of feature extractor output

Inherited Members
dlc2action.model.base_model.Model
process_labels
freeze_feature_extractor
unfreeze_feature_extractor
load_state_dict
ssl_off
ssl_on
main_task_on
main_task_off
set_ssl
extract_features
forward
torch.nn.modules.module.Module
dump_patches
register_buffer
register_parameter
add_module
register_module
get_submodule
get_parameter
get_buffer
get_extra_state
set_extra_state
apply
cuda
ipu
xpu
cpu
type
float
double
half
bfloat16
to_empty
to
register_backward_hook
register_full_backward_hook
register_forward_pre_hook
register_forward_hook
T_destination
state_dict
register_load_state_dict_post_hook
parameters
named_parameters
buffers
named_buffers
children
named_children
modules
named_modules
train
eval
requires_grad_
zero_grad
share_memory
extra_repr