dlc2action.model.mlp
#
# Copyright 2020-2022 by A. Mathis Group and contributors. All rights reserved.
#
# This project and all its files are licensed under GNU AGPLv3 or later version. A copy is included in dlc2action/LICENSE.AGPL.
#
from dlc2action.model.base_model import Model
import torch
from torch import nn
from typing import List, Union
from torch.nn import functional as F


class _MLPModule(nn.Module):
    """
    A stack of pointwise (kernel size 1) ``Conv1d`` layers.

    Channel sizes go ``total input channels -> *f_maps_list -> num_classes``.
    Every layer except the last is followed by dropout and a ReLU; the last
    layer's output is returned as is.
    """

    def __init__(self, f_maps_list, input_dims, num_classes, dropout_rates=None):
        """
        Parameters
        ----------
        f_maps_list : list or tuple
            numbers of output channels of the hidden layers
        input_dims : dict
            input shapes; the number of input channels is the sum of the first
            elements of the values
        num_classes : int
            number of output channels of the final layer
        dropout_rates : float or list or tuple, optional
            dropout rate(s) applied after each hidden layer; a single value is
            used for every hidden layer (0.5 when None)

        Raises
        ------
        ValueError
            if a sequence of dropout rates has fewer entries than there are
            hidden layers
        """
        super(_MLPModule, self).__init__()
        # Total number of input channels: first element of every input shape.
        input_dims = int(sum([s[0] for s in input_dims.values()]))
        f_maps_list = list(f_maps_list)
        if dropout_rates is None:
            dropout_rates = 0.5
        if isinstance(dropout_rates, tuple):
            dropout_rates = list(dropout_rates)
        if not isinstance(dropout_rates, list):
            # One rate per hidden layer: dropout follows every layer except the
            # last one, so the count must match the hidden layers, not the
            # number of input channels.
            dropout_rates = [dropout_rates for _ in f_maps_list]
        if len(dropout_rates) < len(f_maps_list):
            # Fail at construction instead of with an IndexError in forward().
            raise ValueError(
                f"Expected at least {len(f_maps_list)} dropout rates (one per "
                f"hidden layer), got {len(dropout_rates)}"
            )
        input_f_maps = [input_dims] + f_maps_list
        output_f_maps = f_maps_list + [num_classes]
        self.layers = nn.ModuleList(
            [
                nn.Conv1d(in_f_maps, out_f_maps, 1)
                for in_f_maps, out_f_maps in zip(input_f_maps, output_f_maps)
            ]
        )
        # Dropout modules have no parameters, so their count does not affect
        # the state dictionary.
        self.dropout = nn.ModuleList([nn.Dropout(r) for r in dropout_rates])

    def forward(self, x):
        """
        Apply the layers in order.

        ``x`` is passed to ``nn.Conv1d``, so per PyTorch convention it is
        (N, C, L) or (C, L) with C equal to the total number of input channels.
        Every layer but the last is followed by dropout and ReLU.
        """
        last = len(self.layers) - 1
        for i, layer in enumerate(self.layers):
            x = layer(x)
            if i < last:
                x = self.dropout[i](x)
                x = F.relu(x)
        return x


class MLP(Model):
    """
    A Multi-Layer Perceptron

    The feature extractor is a stack of pointwise ``Conv1d`` layers built by
    ``_MLPModule``; the predictor is an identity because the last layer already
    outputs ``num_classes`` channels.

    Parameters
    ----------
    f_maps_list : list
        numbers of output channels of the hidden layers
    input_dims : dict
        input shapes; the number of input channels is the sum of the first
        elements of the values
    num_classes : int
        number of output channels of the final layer
    dropout_rates : float or list, optional
        dropout rate(s) applied after each hidden layer (0.5 when None)
    state_dict_path : str, optional
        path to the model state dictionary to load
    ssl_constructors : list, optional
        a list of SSL constructors that build the necessary SSL modules
    ssl_types : list, optional
        a list of string SSL types
    ssl_modules : list, optional
        a list of torch.nn.Module instances that will serve as SSL modules
    """

    def __init__(
        self,
        f_maps_list,
        input_dims,
        num_classes,
        dropout_rates=None,
        state_dict_path=None,
        ssl_constructors=None,
        ssl_types=None,
        ssl_modules=None,
    ):
        # Hyperparameters are stored before the base constructor runs;
        # presumably Model.__init__ calls _feature_extractor() (TODO confirm).
        self.params = {
            "f_maps_list": f_maps_list,
            "input_dims": input_dims,
            "num_classes": num_classes,
            "dropout_rates": dropout_rates,
        }
        super().__init__(ssl_constructors, ssl_modules, ssl_types, state_dict_path)

    def _feature_extractor(self) -> Union[torch.nn.Module, List]:
        """Build the MLP feature extractor from the stored hyperparameters."""
        return _MLPModule(**self.params)

    def _predictor(self) -> torch.nn.Module:
        """Return an identity predictor (the extractor already outputs class channels)."""
        return nn.Identity()

    def features_shape(self) -> Union[torch.Size, None]:
        """
        Get the shape of feature extractor output.

        Returns
        -------
        feature_shape : torch.Size or None
            always None for this model: the output shape is not reported
        """
        return None
class MLP(Model):
    """
    A Multi-Layer Perceptron

    The feature extractor is a stack of pointwise ``Conv1d`` layers built by
    ``_MLPModule``; the predictor is an identity because the last layer already
    outputs ``num_classes`` channels.

    Parameters
    ----------
    f_maps_list : list
        numbers of output channels of the hidden layers
    input_dims : dict
        input shapes; the number of input channels is the sum of the first
        elements of the values
    num_classes : int
        number of output channels of the final layer
    dropout_rates : float or list, optional
        dropout rate(s) applied after each hidden layer (0.5 when None)
    state_dict_path : str, optional
        path to the model state dictionary to load
    ssl_constructors : list, optional
        a list of SSL constructors that build the necessary SSL modules
    ssl_types : list, optional
        a list of string SSL types
    ssl_modules : list, optional
        a list of torch.nn.Module instances that will serve as SSL modules
    """

    def __init__(
        self,
        f_maps_list,
        input_dims,
        num_classes,
        dropout_rates=None,
        state_dict_path=None,
        ssl_constructors=None,
        ssl_types=None,
        ssl_modules=None,
    ):
        # Hyperparameters are stored before the base constructor runs;
        # presumably Model.__init__ calls _feature_extractor() (TODO confirm).
        self.params = {
            "f_maps_list": f_maps_list,
            "input_dims": input_dims,
            "num_classes": num_classes,
            "dropout_rates": dropout_rates,
        }
        super().__init__(ssl_constructors, ssl_modules, ssl_types, state_dict_path)

    def _feature_extractor(self) -> Union[torch.nn.Module, List]:
        """Build the MLP feature extractor from the stored hyperparameters."""
        return _MLPModule(**self.params)

    def _predictor(self) -> torch.nn.Module:
        """Return an identity predictor (the extractor already outputs class channels)."""
        return nn.Identity()

    def features_shape(self) -> Union[torch.Size, None]:
        """
        Get the shape of feature extractor output.

        Returns
        -------
        feature_shape : torch.Size or None
            always None for this model: the output shape is not reported
        """
        return None
A Multi-Layer Perceptron
MLP( f_maps_list, input_dims, num_classes, dropout_rates=None, state_dict_path=None, ssl_constructors=None, ssl_types=None, ssl_modules=None)
def __init__(
    self,
    f_maps_list,
    input_dims,
    num_classes,
    dropout_rates=None,
    state_dict_path=None,
    ssl_constructors=None,
    ssl_types=None,
    ssl_modules=None,
):
    """
    Record the MLP hyperparameters, then initialize the base model.

    The four MLP-specific arguments are kept in ``self.params``. The SSL
    arguments and the state dictionary path are forwarded positionally to the
    base constructor.
    """
    # Must exist before the base constructor runs; presumably it builds the
    # feature extractor from these values (TODO confirm).
    self.params = dict(
        f_maps_list=f_maps_list,
        input_dims=input_dims,
        num_classes=num_classes,
        dropout_rates=dropout_rates,
    )
    base_args = (ssl_constructors, ssl_modules, ssl_types, state_dict_path)
    super().__init__(*base_args)
Parameters
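f_maps_list : list
    numbers of output channels (feature maps) of the hidden layers
input_dims : dict
    input shapes; the number of input channels is the sum of the first elements of its values
num_classes : int
    number of output channels of the final layer
dropout_rates : float or list, optional
    dropout rate(s) applied after each hidden layer; 0.5 is used when None
ssl_constructors : list, optional
    a list of SSL constructors that build the necessary SSL modules
ssl_modules : list, optional
    a list of torch.nn.Module instances that will serve as SSL modules
ssl_types : list, optional
    a list of string SSL types
state_dict_path : str, optional
    path to the model state dictionary to load
strict : bool, default False
    when True, the state dictionary will only be loaded if the current and the loaded architecture are the same; otherwise missing or extra keys, as well as shape inconsistencies, are ignored (note: `MLP.__init__` does not take this argument)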
def
features_shape(self) -> torch.Size:
Get the shape of feature extractor output
Returns
feature_shape : torch.Size or None shape of feature extractor output (this MLP always returns None)
Inherited Members
- dlc2action.model.base_model.Model
- process_labels
- freeze_feature_extractor
- unfreeze_feature_extractor
- load_state_dict
- ssl_off
- ssl_on
- main_task_on
- main_task_off
- set_ssl
- extract_features
- forward
- torch.nn.modules.module.Module
- dump_patches
- register_buffer
- register_parameter
- add_module
- register_module
- get_submodule
- get_parameter
- get_buffer
- get_extra_state
- set_extra_state
- apply
- cuda
- ipu
- xpu
- cpu
- type
- float
- double
- half
- bfloat16
- to_empty
- to
- register_backward_hook
- register_full_backward_hook
- register_forward_pre_hook
- register_forward_hook
- T_destination
- state_dict
- register_load_state_dict_post_hook
- parameters
- named_parameters
- buffers
- named_buffers
- children
- named_children
- modules
- named_modules
- train
- eval
- requires_grad_
- zero_grad
- extra_repr