dlc2action.model.mlp

 1#
 2# Copyright 2020-present by A. Mathis Group and contributors. All rights reserved.
 3#
 4# This project and all its files are licensed under GNU AGPLv3 or later version. 
 5# A copy is included in dlc2action/LICENSE.AGPL.
 6#
 7from typing import List, Union
 8
 9import torch
10from dlc2action.model.base_model import Model
11from torch import nn
12from torch.nn import functional as F
13
14
class MLPModule(nn.Module):
    """Multi-layer perceptron built from kernel-size-1 ``nn.Conv1d`` layers.

    Every layer except the last is followed by dropout and a ReLU; the last layer
    outputs ``num_classes`` channels.

    Parameters
    ----------
    f_maps_list : list of int
        numbers of feature maps (channels) in the hidden layers
    input_dims : dict
        mapping whose values are shapes; the first element of every shape is summed
        to get the number of input channels
    num_classes : int
        number of output channels of the last layer
    dropout_rates : float or list of float, optional
        dropout rate(s) for the hidden layers (default 0.5); a single number is used
        for every hidden layer, a list/tuple gives one rate per hidden layer

    Raises
    ------
    ValueError
        if a sequence of dropout rates is shorter than the number of hidden layers
    """

    def __init__(self, f_maps_list, input_dims, num_classes, dropout_rates=None):
        super(MLPModule, self).__init__()
        f_maps_list = list(f_maps_list)
        n_hidden = len(f_maps_list)
        input_dims = int(sum([s[0] for s in input_dims.values()]))
        if dropout_rates is None:
            dropout_rates = 0.5
        if isinstance(dropout_rates, (list, tuple)):
            dropout_rates = list(dropout_rates)
            # forward() indexes one dropout per hidden layer, so fail early here
            # instead of with an IndexError in the middle of a forward pass
            if len(dropout_rates) < n_hidden:
                raise ValueError(
                    f"Expected at least {n_hidden} dropout rates (one per hidden "
                    f"layer), got {len(dropout_rates)}"
                )
        else:
            # one dropout layer per hidden layer (the count must follow the depth,
            # not the number of input channels)
            dropout_rates = [dropout_rates] * n_hidden
        input_f_maps = [input_dims] + f_maps_list
        output_f_maps = f_maps_list + [num_classes]
        self.layers = nn.ModuleList(
            [
                nn.Conv1d(in_f_maps, out_f_maps, 1)
                for in_f_maps, out_f_maps in zip(input_f_maps, output_f_maps)
            ]
        )
        self.dropout = nn.ModuleList([nn.Dropout(r) for r in dropout_rates])

    def forward(self, x):
        """Forward pass.

        Parameters
        ----------
        x : torch.Tensor
            a tensor accepted by ``nn.Conv1d``, i.e. (batch, channels, length), with
            channels equal to the summed input dims

        Returns
        -------
        torch.Tensor
            output with ``num_classes`` channels
        """
        for i, layer in enumerate(self.layers):
            x = layer(x)
            if i < len(self.layers) - 1:
                x = self.dropout[i](x)
                x = F.relu(x)
        return x
41
42
class MLP(Model):
    """
    A Multi-Layer Perceptron

    The feature extractor is an ``MLPModule`` built from the constructor arguments;
    the predictor is an identity mapping.
    """

    def __init__(
        self,
        f_maps_list,
        input_dims,
        num_classes,
        dropout_rates=None,
        state_dict_path=None,
        ssl_constructors=None,
        ssl_types=None,
        ssl_modules=None,
    ):
        # Kept before the base constructor call, presumably because it builds the
        # feature extractor from self.params — confirm in base_model.
        self.params = dict(
            f_maps_list=f_maps_list,
            input_dims=input_dims,
            num_classes=num_classes,
            dropout_rates=dropout_rates,
        )
        super().__init__(ssl_constructors, ssl_modules, ssl_types, state_dict_path)

    def _feature_extractor(self) -> Union[torch.nn.Module, List]:
        """Build an ``MLPModule`` from the stored constructor parameters."""
        extractor = MLPModule(**self.params)
        return extractor

    def _predictor(self) -> torch.nn.Module:
        """Return an identity predictor (the MLP already outputs num_classes channels)."""
        identity = nn.Identity()
        return identity

    def features_shape(self) -> torch.Size:
        """Return the feature extractor output shape; this model reports None."""
        shape = None
        return shape
class MLPModule(torch.nn.modules.module.Module):
class MLPModule(nn.Module):
    """Multi-layer perceptron built from kernel-size-1 ``nn.Conv1d`` layers.

    Every layer except the last is followed by dropout and a ReLU; the last layer
    outputs ``num_classes`` channels.

    Parameters
    ----------
    f_maps_list : list of int
        numbers of feature maps (channels) in the hidden layers
    input_dims : dict
        mapping whose values are shapes; the first element of every shape is summed
        to get the number of input channels
    num_classes : int
        number of output channels of the last layer
    dropout_rates : float or list of float, optional
        dropout rate(s) for the hidden layers (default 0.5); a single number is used
        for every hidden layer, a list/tuple gives one rate per hidden layer

    Raises
    ------
    ValueError
        if a sequence of dropout rates is shorter than the number of hidden layers
    """

    def __init__(self, f_maps_list, input_dims, num_classes, dropout_rates=None):
        super(MLPModule, self).__init__()
        f_maps_list = list(f_maps_list)
        n_hidden = len(f_maps_list)
        input_dims = int(sum([s[0] for s in input_dims.values()]))
        if dropout_rates is None:
            dropout_rates = 0.5
        if isinstance(dropout_rates, (list, tuple)):
            dropout_rates = list(dropout_rates)
            # forward() indexes one dropout per hidden layer, so fail early here
            # instead of with an IndexError in the middle of a forward pass
            if len(dropout_rates) < n_hidden:
                raise ValueError(
                    f"Expected at least {n_hidden} dropout rates (one per hidden "
                    f"layer), got {len(dropout_rates)}"
                )
        else:
            # one dropout layer per hidden layer (the count must follow the depth,
            # not the number of input channels)
            dropout_rates = [dropout_rates] * n_hidden
        input_f_maps = [input_dims] + f_maps_list
        output_f_maps = f_maps_list + [num_classes]
        self.layers = nn.ModuleList(
            [
                nn.Conv1d(in_f_maps, out_f_maps, 1)
                for in_f_maps, out_f_maps in zip(input_f_maps, output_f_maps)
            ]
        )
        self.dropout = nn.ModuleList([nn.Dropout(r) for r in dropout_rates])

    def forward(self, x):
        """Forward pass.

        Parameters
        ----------
        x : torch.Tensor
            a tensor accepted by ``nn.Conv1d``, i.e. (batch, channels, length), with
            channels equal to the summed input dims

        Returns
        -------
        torch.Tensor
            output with ``num_classes`` channels
        """
        for i, layer in enumerate(self.layers):
            x = layer(x)
            if i < len(self.layers) - 1:
                x = self.dropout[i](x)
                x = F.relu(x)
        return x

Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::

import torch.nn as nn
import torch.nn.functional as F


class Model(nn.Module):
    """Two stacked Conv2d layers, each followed by a ReLU."""

    def __init__(self) -> None:
        super().__init__()
        # submodules are assigned only after the parent __init__ has run
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        """Apply conv1 -> ReLU -> conv2 -> ReLU."""
        hidden = F.relu(self.conv1(x))
        out = self.conv2(hidden)
        return F.relu(out)

Submodules assigned in this way will be registered, and will also have their parameters converted when you call to(), etc.

As per the example above, an __init__() call to the parent class must be made before assignment on the child.

:ivar training: Boolean that indicates whether this module is in training or evaluation mode. :vartype training: bool

MLPModule(f_maps_list, input_dims, num_classes, dropout_rates=None)
def __init__(self, f_maps_list, input_dims, num_classes, dropout_rates=None):
    """Build the kernel-size-1 convolution layers and the hidden-layer dropouts.

    A scalar ``dropout_rates`` (default 0.5) is used for every hidden layer; a
    list/tuple gives one rate per hidden layer and must not be shorter than
    ``f_maps_list`` (otherwise ValueError is raised).
    """
    super(MLPModule, self).__init__()
    f_maps_list = list(f_maps_list)
    n_hidden = len(f_maps_list)
    input_dims = int(sum([s[0] for s in input_dims.values()]))
    if dropout_rates is None:
        dropout_rates = 0.5
    if isinstance(dropout_rates, (list, tuple)):
        dropout_rates = list(dropout_rates)
        # forward() indexes one dropout per hidden layer, so fail early here
        if len(dropout_rates) < n_hidden:
            raise ValueError(
                f"Expected at least {n_hidden} dropout rates (one per hidden "
                f"layer), got {len(dropout_rates)}"
            )
    else:
        # one dropout layer per hidden layer, not per input channel
        dropout_rates = [dropout_rates] * n_hidden
    input_f_maps = [input_dims] + f_maps_list
    output_f_maps = f_maps_list + [num_classes]
    self.layers = nn.ModuleList(
        [
            nn.Conv1d(in_f_maps, out_f_maps, 1)
            for in_f_maps, out_f_maps in zip(input_f_maps, output_f_maps)
        ]
    )
    self.dropout = nn.ModuleList([nn.Dropout(r) for r in dropout_rates])

Initialize internal Module state, shared by both nn.Module and ScriptModule.

layers
dropout
def forward(self, x):
def forward(self, x):
    """Forward pass: every layer but the last is followed by dropout and a ReLU."""
    last_idx = len(self.layers) - 1
    for idx, conv in enumerate(self.layers):
        x = conv(x)
        if idx == last_idx:
            break
        x = F.relu(self.dropout[idx](x))
    return x

Forward pass.

class MLP(dlc2action.model.base_model.Model):
class MLP(Model):
    """
    A Multi-Layer Perceptron

    The feature extractor is an ``MLPModule`` built from the constructor arguments;
    the predictor is an identity mapping.
    """

    def __init__(
        self,
        f_maps_list,
        input_dims,
        num_classes,
        dropout_rates=None,
        state_dict_path=None,
        ssl_constructors=None,
        ssl_types=None,
        ssl_modules=None,
    ):
        # Kept before the base constructor call, presumably because it builds the
        # feature extractor from self.params — confirm in base_model.
        self.params = dict(
            f_maps_list=f_maps_list,
            input_dims=input_dims,
            num_classes=num_classes,
            dropout_rates=dropout_rates,
        )
        super().__init__(ssl_constructors, ssl_modules, ssl_types, state_dict_path)

    def _feature_extractor(self) -> Union[torch.nn.Module, List]:
        """Build an ``MLPModule`` from the stored constructor parameters."""
        extractor = MLPModule(**self.params)
        return extractor

    def _predictor(self) -> torch.nn.Module:
        """Return an identity predictor (the MLP already outputs num_classes channels)."""
        identity = nn.Identity()
        return identity

    def features_shape(self) -> torch.Size:
        """Return the feature extractor output shape; this model reports None."""
        shape = None
        return shape

A Multi-Layer Perceptron

MLP( f_maps_list, input_dims, num_classes, dropout_rates=None, state_dict_path=None, ssl_constructors=None, ssl_types=None, ssl_modules=None)
def __init__(
    self,
    f_maps_list,
    input_dims,
    num_classes,
    dropout_rates=None,
    state_dict_path=None,
    ssl_constructors=None,
    ssl_types=None,
    ssl_modules=None,
):
    """Store the MLPModule arguments, then run the base Model constructor."""
    # Kept before the base constructor call, presumably because it builds the
    # feature extractor from self.params — confirm in base_model.
    self.params = dict(
        f_maps_list=f_maps_list,
        input_dims=input_dims,
        num_classes=num_classes,
        dropout_rates=dropout_rates,
    )
    super().__init__(ssl_constructors, ssl_modules, ssl_types, state_dict_path)

Initialize the model.

Parameters

ssl_constructors : list, optional a list of SSL constructors that build the necessary SSL modules ssl_modules : list, optional a list of torch.nn.Module instances that will serve as SSL modules ssl_types : list, optional a list of string SSL types state_dict_path : str, optional path to the model state dictionary to load strict : bool, default False when True, the state dictionary will only be loaded if the current and the loaded architectures are the same; otherwise missing or extra keys, as well as shape inconsistencies, are ignored prompt_function : callable, optional a function that takes a list of strings and returns a string prompt (strict and prompt_function are base-class parameters that the MLP constructor does not expose)

params
def features_shape(self) -> torch.Size:
def features_shape(self) -> torch.Size:
    """Return the feature extractor output shape; this model reports None."""
    shape = None
    return shape

Get the shape of the feature extractor output.

Returns

feature_shape : torch.Size shape of the feature extractor output (the MLP model returns None)