dlc2action.model.mlp
#
# Copyright 2020-present by A. Mathis Group and contributors. All rights reserved.
#
# This project and all its files are licensed under GNU AGPLv3 or later version.
# A copy is included in dlc2action/LICENSE.AGPL.
#
from typing import List, Union

import torch
from dlc2action.model.base_model import Model
from torch import nn
from torch.nn import functional as F


class MLPModule(nn.Module):
    """A stack of ``nn.Conv1d`` layers with kernel size 1.

    Every layer except the last one is followed by dropout and then ReLU.
    """

    def __init__(self, f_maps_list, input_dims, num_classes, dropout_rates=None):
        """Build the layers.

        Parameters
        ----------
        f_maps_list : list of int
            output channel counts of the hidden layers, in order
        input_dims : dict
            a dictionary whose values are shapes; the first entry of every
            shape is summed to get the number of input channels
        num_classes : int
            the number of output channels of the last layer
        dropout_rates : float or list of float, optional
            a single rate shared by all hidden layers or one rate per hidden
            layer (a list needs at least `len(f_maps_list)` entries);
            0.5 is used when omitted
        """
        super(MLPModule, self).__init__()
        input_dims = int(sum(s[0] for s in input_dims.values()))
        f_maps_list = list(f_maps_list)
        if dropout_rates is None:
            dropout_rates = 0.5
        if not isinstance(dropout_rates, list):
            # One dropout per hidden layer. Sizing this list by the number of
            # input channels (as before) made forward() fail with IndexError
            # whenever there were fewer input channels than hidden layers.
            dropout_rates = [dropout_rates for _ in f_maps_list]
        input_f_maps = [input_dims] + f_maps_list
        output_f_maps = f_maps_list + [num_classes]
        self.layers = nn.ModuleList(
            [
                nn.Conv1d(in_f_maps, out_f_maps, 1)
                for in_f_maps, out_f_maps in zip(input_f_maps, output_f_maps)
            ]
        )
        self.dropout = nn.ModuleList([nn.Dropout(r) for r in dropout_rates])

    def forward(self, x):
        """Forward pass.

        Runs `x` through all layers; dropout and ReLU are applied after every
        layer except the last one, so the output is not passed through ReLU.
        """
        for i, layer in enumerate(self.layers):
            x = layer(x)
            if i < len(self.layers) - 1:
                x = self.dropout[i](x)
                x = F.relu(x)
        return x


class MLP(Model):
    """
    A Multi-Layer Perceptron
    """

    def __init__(
        self,
        f_maps_list,
        input_dims,
        num_classes,
        dropout_rates=None,
        state_dict_path=None,
        ssl_constructors=None,
        ssl_types=None,
        ssl_modules=None,
    ):
        """Initialize the model.

        Parameters
        ----------
        f_maps_list : list of int
            output channel counts of the hidden layers (passed to `MLPModule`)
        input_dims : dict
            a dictionary of input shapes (passed to `MLPModule`)
        num_classes : int
            the number of output channels (passed to `MLPModule`)
        dropout_rates : float or list of float, optional
            dropout rate(s) for the hidden layers (passed to `MLPModule`)
        state_dict_path : str, optional
            path to the model state dictionary to load
        ssl_constructors : list, optional
            a list of SSL constructors that build the necessary SSL modules
        ssl_types : list, optional
            a list of string SSL types
        ssl_modules : list, optional
            a list of torch.nn.Module instances that will serve as SSL modules
        """
        # The parameters are stored before the base constructor runs so that
        # they are available to _feature_extractor().
        self.params = {
            "f_maps_list": f_maps_list,
            "input_dims": input_dims,
            "num_classes": num_classes,
            "dropout_rates": dropout_rates,
        }
        super().__init__(ssl_constructors, ssl_modules, ssl_types, state_dict_path)

    def _feature_extractor(self) -> Union[torch.nn.Module, List]:
        """Create the `MLPModule` from the stored parameters."""
        return MLPModule(**self.params)

    def _predictor(self) -> torch.nn.Module:
        """Return an identity module (the last `MLPModule` layer already has `num_classes` channels)."""
        return nn.Identity()

    def features_shape(self) -> torch.Size:
        """Get the shape of feature extractor output.

        This model does not report a shape and always returns `None`.
        """
        return None
class MLPModule(nn.Module):
    """A stack of ``nn.Conv1d`` layers with kernel size 1.

    Every layer except the last one is followed by dropout and then ReLU.
    """

    def __init__(self, f_maps_list, input_dims, num_classes, dropout_rates=None):
        """Build the layers.

        Parameters
        ----------
        f_maps_list : list of int
            output channel counts of the hidden layers, in order
        input_dims : dict
            a dictionary whose values are shapes; the first entry of every
            shape is summed to get the number of input channels
        num_classes : int
            the number of output channels of the last layer
        dropout_rates : float or list of float, optional
            a single rate shared by all hidden layers or one rate per hidden
            layer (a list needs at least `len(f_maps_list)` entries);
            0.5 is used when omitted
        """
        super().__init__()
        in_channels = int(sum(shape[0] for shape in input_dims.values()))
        hidden_f_maps = list(f_maps_list)
        if dropout_rates is None:
            dropout_rates = 0.5
        if not isinstance(dropout_rates, list):
            # One dropout per hidden layer. Sizing this list by the number of
            # input channels (as before) made forward() fail with IndexError
            # whenever there were fewer input channels than hidden layers.
            dropout_rates = [dropout_rates] * len(hidden_f_maps)
        input_f_maps = [in_channels] + hidden_f_maps
        output_f_maps = hidden_f_maps + [num_classes]
        self.layers = nn.ModuleList(
            [
                nn.Conv1d(in_f_maps, out_f_maps, 1)
                for in_f_maps, out_f_maps in zip(input_f_maps, output_f_maps)
            ]
        )
        self.dropout = nn.ModuleList([nn.Dropout(r) for r in dropout_rates])

    def forward(self, x):
        """Forward pass.

        Runs `x` through all layers; dropout and ReLU are applied after every
        layer except the last one, so the output is not passed through ReLU.
        """
        last_index = len(self.layers) - 1
        for i, layer in enumerate(self.layers):
            x = layer(x)
            if i < last_index:
                x = self.dropout[i](x)
                x = F.relu(x)
        return x
Base class for all neural network modules.
Your models should also subclass this class.
Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
    """Example module: two 2-D convolutions, each followed by ReLU."""

    def __init__(self) -> None:
        super().__init__()
        # Submodules are assigned as regular attributes.
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        """Apply conv1 and conv2 in order, with ReLU after each of them."""
        hidden = F.relu(self.conv1(x))
        out = self.conv2(hidden)
        return F.relu(out)
Submodules assigned in this way will be registered, and will also have their
parameters converted when you call to(), etc.
As per the example above, an __init__() call to the parent class must be made before assignment on the child.
:ivar training: Boolean represents whether this module is in training or evaluation mode. :vartype training: bool
def __init__(self, f_maps_list, input_dims, num_classes, dropout_rates=None):
    """Build the layers.

    Parameters
    ----------
    f_maps_list : list of int
        output channel counts of the hidden layers, in order
    input_dims : dict
        a dictionary whose values are shapes; the first entry of every
        shape is summed to get the number of input channels
    num_classes : int
        the number of output channels of the last layer
    dropout_rates : float or list of float, optional
        a single rate shared by all hidden layers or one rate per hidden
        layer; 0.5 is used when omitted
    """
    super(MLPModule, self).__init__()
    input_dims = int(sum(s[0] for s in input_dims.values()))
    f_maps_list = list(f_maps_list)
    if dropout_rates is None:
        dropout_rates = 0.5
    if not isinstance(dropout_rates, list):
        # One dropout per hidden layer; sizing this list by the number of
        # input channels (as before) left too few dropout modules whenever
        # there were fewer input channels than hidden layers.
        dropout_rates = [dropout_rates for _ in f_maps_list]
    input_f_maps = [input_dims] + f_maps_list
    output_f_maps = f_maps_list + [num_classes]
    self.layers = nn.ModuleList(
        [
            nn.Conv1d(in_f_maps, out_f_maps, 1)
            for in_f_maps, out_f_maps in zip(input_f_maps, output_f_maps)
        ]
    )
    self.dropout = nn.ModuleList([nn.Dropout(r) for r in dropout_rates])
Initialize internal Module state, shared by both nn.Module and ScriptModule.
class MLP(Model):
    """
    A Multi-Layer Perceptron
    """

    def __init__(
        self,
        f_maps_list,
        input_dims,
        num_classes,
        dropout_rates=None,
        state_dict_path=None,
        ssl_constructors=None,
        ssl_types=None,
        ssl_modules=None,
    ):
        """Store the `MLPModule` arguments and initialize the base model.

        Parameters
        ----------
        f_maps_list, input_dims, num_classes, dropout_rates
            stored in `self.params` and passed to `MLPModule` as keyword arguments
        state_dict_path, ssl_constructors, ssl_types, ssl_modules
            passed on to the base class constructor
        """
        # Stored before the base constructor is called.
        self.params = dict(
            f_maps_list=f_maps_list,
            input_dims=input_dims,
            num_classes=num_classes,
            dropout_rates=dropout_rates,
        )
        super().__init__(ssl_constructors, ssl_modules, ssl_types, state_dict_path)

    def _feature_extractor(self) -> Union[torch.nn.Module, List]:
        """Create the `MLPModule` from the stored parameters."""
        extractor = MLPModule(**self.params)
        return extractor

    def _predictor(self) -> torch.nn.Module:
        """Return an identity module."""
        identity = nn.Identity()
        return identity

    def features_shape(self) -> torch.Size:
        """Get the shape of feature extractor output (always `None` for this model)."""
        return None
A Multi-Layer Perceptron
def __init__(
    self,
    f_maps_list,
    input_dims,
    num_classes,
    dropout_rates=None,
    state_dict_path=None,
    ssl_constructors=None,
    ssl_types=None,
    ssl_modules=None,
):
    """Store the feature extractor arguments and initialize the base model.

    The first four arguments are kept in `self.params`; the remaining ones
    are passed on to the base class constructor.
    """
    # Stored before the base constructor is called.
    self.params = dict(
        f_maps_list=f_maps_list,
        input_dims=input_dims,
        num_classes=num_classes,
        dropout_rates=dropout_rates,
    )
    super().__init__(ssl_constructors, ssl_modules, ssl_types, state_dict_path)
Initialize the model.
Parameters
ssl_constructors : list, optional
    a list of SSL constructors that build the necessary SSL modules
ssl_modules : list, optional
    a list of torch.nn.Module instances that will serve as SSL modules
ssl_types : list, optional
    a list of string SSL types
state_dict_path : str, optional
    path to the model state dictionary to load
strict : bool, default False
    when True, the state dictionary will only be loaded if the current and the loaded architecture are the same; otherwise missing or extra keys, as well as shape inconsistencies, are ignored
prompt_function : callable, optional
    a function that takes a list of strings and returns a string prompt
Get the shape of feature extractor output.
Returns
feature_shape : torch.Size shape of feature extractor output
Inherited Members
- dlc2action.model.base_model.Model
- process_labels
- feature_extractor
- feature_extractors
- predictor
- ssl_active
- main_task_active
- prompt_function
- class_tensors
- freeze_feature_extractor
- unfreeze_feature_extractor
- load_state_dict
- ssl_off
- ssl_on
- main_task_on
- main_task_off
- set_ssl
- extract_features
- transform_labels
- forward