dlc2action.model.mlp

 1#
 2# Copyright 2020-2022 by A. Mathis Group and contributors. All rights reserved.
 3#
 4# This project and all its files are licensed under GNU AGPLv3 or later version. A copy is included in dlc2action/LICENSE.AGPL.
 5#
 6from dlc2action.model.base_model import Model
 7import torch
 8from torch import nn
 9from typing import List, Union
10from torch.nn import functional as F
11
12
13class _MLPModule(nn.Module):
14    def __init__(self, f_maps_list, input_dims, num_classes, dropout_rates=None):
15        super(_MLPModule, self).__init__()
16        input_dims = int(sum([s[0] for s in input_dims.values()]))
17        if dropout_rates is None:
18            dropout_rates = 0.5
19        if not isinstance(dropout_rates, list):
20            dropout_rates = [dropout_rates for _ in range(input_dims)]
21        input_f_maps = [input_dims] + f_maps_list
22        output_f_maps = f_maps_list + [num_classes]
23        self.layers = nn.ModuleList(
24            [
25                nn.Conv1d(in_f_maps, out_f_maps, 1)
26                for in_f_maps, out_f_maps in zip(input_f_maps, output_f_maps)
27            ]
28        )
29        self.dropout = nn.ModuleList([nn.Dropout(r) for r in dropout_rates])
30
31    def forward(self, x):
32        for i, layer in enumerate(self.layers):
33            x = layer(x)
34            if i < len(self.layers) - 1:
35                x = self.dropout[i](x)
36                x = F.relu(x)
37        return x
38
39
class MLP(Model):
    """
    A Multi-Layer Perceptron
    """

    def __init__(
        self,
        f_maps_list,
        input_dims,
        num_classes,
        dropout_rates=None,
        state_dict_path=None,
        ssl_constructors=None,
        ssl_types=None,
        ssl_modules=None,
    ):
        """
        Parameters
        ----------
        f_maps_list : list
            passed to the feature extractor module as `f_maps_list`
        input_dims : dict
            passed to the feature extractor module as `input_dims`
        num_classes : int
            passed to the feature extractor module as `num_classes`
        dropout_rates : float | list, optional
            passed to the feature extractor module as `dropout_rates`
        state_dict_path : str, optional
            forwarded to the base model constructor
        ssl_constructors : list, optional
            forwarded to the base model constructor
        ssl_types : list, optional
            forwarded to the base model constructor
        ssl_modules : list, optional
            forwarded to the base model constructor
        """

        # the keys mirror the keyword arguments of _MLPModule.__init__;
        # the parameters are stored before the base constructor is called
        # (presumably the base class builds the modules there -- confirm against Model)
        self.params = dict(
            f_maps_list=f_maps_list,
            input_dims=input_dims,
            num_classes=num_classes,
            dropout_rates=dropout_rates,
        )
        # note the positional order expected here: constructors, modules, types, path
        super().__init__(
            ssl_constructors,
            ssl_modules,
            ssl_types,
            state_dict_path,
        )

    def _feature_extractor(self) -> Union[torch.nn.Module, List]:
        """
        Build the feature extractor from the stored parameters

        Returns
        -------
        feature_extractor : torch.nn.Module
            a new `_MLPModule` instance
        """

        module = _MLPModule(**self.params)
        return module

    def _predictor(self) -> torch.nn.Module:
        """
        Build the prediction module

        Returns
        -------
        predictor : torch.nn.Module
            an identity module (the feature extractor output is passed through unchanged)
        """

        predictor = nn.Identity()
        return predictor

    def features_shape(self) -> torch.Size:
        """
        Get the shape of feature extractor output

        Returns
        -------
        feature_shape : torch.Size
            always `None` in this implementation
        """

        return None
class MLP(dlc2action.model.base_model.Model):
class MLP(Model):
    """
    A Multi-Layer Perceptron
    """

    def __init__(
        self,
        f_maps_list,
        input_dims,
        num_classes,
        dropout_rates=None,
        state_dict_path=None,
        ssl_constructors=None,
        ssl_types=None,
        ssl_modules=None,
    ):
        """
        Parameters
        ----------
        f_maps_list : list
            passed to the feature extractor module as `f_maps_list`
        input_dims : dict
            passed to the feature extractor module as `input_dims`
        num_classes : int
            passed to the feature extractor module as `num_classes`
        dropout_rates : float | list, optional
            passed to the feature extractor module as `dropout_rates`
        state_dict_path : str, optional
            forwarded to the base model constructor
        ssl_constructors : list, optional
            forwarded to the base model constructor
        ssl_types : list, optional
            forwarded to the base model constructor
        ssl_modules : list, optional
            forwarded to the base model constructor
        """

        # the keys mirror the keyword arguments of _MLPModule.__init__;
        # the parameters are stored before the base constructor is called
        # (presumably the base class builds the modules there -- confirm against Model)
        self.params = dict(
            f_maps_list=f_maps_list,
            input_dims=input_dims,
            num_classes=num_classes,
            dropout_rates=dropout_rates,
        )
        # note the positional order expected here: constructors, modules, types, path
        super().__init__(
            ssl_constructors,
            ssl_modules,
            ssl_types,
            state_dict_path,
        )

    def _feature_extractor(self) -> Union[torch.nn.Module, List]:
        """
        Build the feature extractor from the stored parameters

        Returns
        -------
        feature_extractor : torch.nn.Module
            a new `_MLPModule` instance
        """

        module = _MLPModule(**self.params)
        return module

    def _predictor(self) -> torch.nn.Module:
        """
        Build the prediction module

        Returns
        -------
        predictor : torch.nn.Module
            an identity module (the feature extractor output is passed through unchanged)
        """

        predictor = nn.Identity()
        return predictor

    def features_shape(self) -> torch.Size:
        """
        Get the shape of feature extractor output

        Returns
        -------
        feature_shape : torch.Size
            always `None` in this implementation
        """

        return None

A Multi-Layer Perceptron

MLP( f_maps_list, input_dims, num_classes, dropout_rates=None, state_dict_path=None, ssl_constructors=None, ssl_types=None, ssl_modules=None)
46    def __init__(
47        self,
48        f_maps_list,
49        input_dims,
50        num_classes,
51        dropout_rates=None,
52        state_dict_path=None,
53        ssl_constructors=None,
54        ssl_types=None,
55        ssl_modules=None,
56    ):
57        self.params = {
58            "f_maps_list": f_maps_list,
59            "input_dims": input_dims,
60            "num_classes": num_classes,
61            "dropout_rates": dropout_rates,
62        }
63        super().__init__(ssl_constructors, ssl_modules, ssl_types, state_dict_path)

Parameters

ssl_constructors : list, optional a list of SSL constructors that build the necessary SSL modules; ssl_modules : list, optional a list of torch.nn.Module instances that will serve as SSL modules; ssl_types : list, optional a list of string SSL types; state_dict_path : str, optional path to the model state dictionary to load; strict : bool, default False when True, the state dictionary will only be loaded if the current and the loaded architecture are the same; otherwise missing or extra keys, as well as shape inconsistencies, are ignored

def features_shape(self) -> torch.Size:
71    def features_shape(self) -> torch.Size:
72        return None

Get the shape of feature extractor output

Returns

feature_shape : torch.Size shape of feature extractor output

Inherited Members
dlc2action.model.base_model.Model
process_labels
freeze_feature_extractor
unfreeze_feature_extractor
load_state_dict
ssl_off
ssl_on
main_task_on
main_task_off
set_ssl
extract_features
forward
torch.nn.modules.module.Module
dump_patches
register_buffer
register_parameter
add_module
register_module
get_submodule
get_parameter
get_buffer
get_extra_state
set_extra_state
apply
cuda
ipu
xpu
cpu
type
float
double
half
bfloat16
to_empty
to
register_backward_hook
register_full_backward_hook
register_forward_pre_hook
register_forward_hook
T_destination
state_dict
register_load_state_dict_post_hook
parameters
named_parameters
buffers
named_buffers
children
named_children
modules
named_modules
train
eval
requires_grad_
zero_grad
share_memory
extra_repr