dlc2action.model.c3d
#
# Copyright 2020-present by A. Mathis Group and contributors. All rights reserved.
#
# This project and all its files are licensed under GNU AGPLv3 or later version.
# A copy is included in dlc2action/LICENSE.AGPL.
#
"""Residual 3D-convolutional feature extractor (C3D) and the models built on it."""

from typing import List, Union

import torch
from dlc2action.model.base_model import Model
from dlc2action.model.ms_tcn_modules import MSRefinement
from torch import nn
from torch.nn import functional as F


class ResLayer3D(nn.Module):
    """Residual block made of two 3x3x3 convolutions with batch normalization.

    When the channel counts differ, the shortcut is a 1x1x1 convolution followed
    by batch normalization; otherwise the input itself is added.
    """

    def __init__(self, in_channels, out_channels):
        """Create the convolutions, the batch norms and the optional shortcut."""
        super().__init__()
        # kernel 3 with padding 1 keeps the three convolved axes at their input size
        self.conv1 = nn.Conv3d(
            in_channels, out_channels, kernel_size=(3, 3, 3), padding=(1, 1, 1)
        )
        self.conv2 = nn.Conv3d(
            out_channels, out_channels, kernel_size=(3, 3, 3), padding=(1, 1, 1)
        )
        self.bn1 = nn.BatchNorm3d(out_channels)
        self.bn2 = nn.BatchNorm3d(out_channels)
        # bn3 is created unconditionally, although forward() only uses it
        # together with conv3
        self.bn3 = nn.BatchNorm3d(out_channels)
        if in_channels != out_channels:
            self.conv3 = nn.Conv3d(in_channels, out_channels, kernel_size=(1, 1, 1))
        else:
            self.conv3 = None

    def forward(self, x):
        """Return ``relu(main_branch(x) + shortcut(x))``."""
        f = self.conv1(x)
        f = F.relu(self.bn1(f))
        f = self.conv2(f)
        f = self.bn2(f)
        if self.conv3:
            # project the input so that its channel count matches `f`
            x = self.bn3(self.conv3(x))
        return F.relu(f + x)


class C3D(nn.Module):
    """Stack of five residual 3D blocks, each followed by max pooling.

    The channel widths are ``dim -> 32 -> 64 -> 128 -> 128 -> 128``. Pooling uses
    kernel and stride ``(1, 2, 2)``, so only the last two axes are downsampled.
    """

    def __init__(self, dim, loaded_dim):
        """Build the layer stack and, if `loaded_dim` is given, the extra 1D convolutions."""
        super(C3D, self).__init__()
        self.layers = nn.ModuleList()

        self.layers.append(ResLayer3D(dim, 32))
        self.layers.append(nn.MaxPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2)))

        self.layers.append(ResLayer3D(32, 64))
        self.layers.append(nn.MaxPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2)))

        self.layers.append(ResLayer3D(64, 128))
        self.layers.append(nn.MaxPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2)))
        self.layers.append(ResLayer3D(128, 128))
        self.layers.append(nn.MaxPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2)))

        self.layers.append(ResLayer3D(128, 128))
        self.layers.append(nn.MaxPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2)))

        if loaded_dim is not None:
            # NOTE(review): these convolutions use kernel size 3 and the default
            # padding of 0, so each one shortens the last axis by 2 -- confirm
            # that the lengths still line up for the concatenation in forward()
            self.conv1 = nn.Conv1d(loaded_dim, 16, 3)
            self.conv2 = nn.Conv1d(16, 16, 3)
            self.conv3 = nn.Conv1d(128 + 16, 128, 3)

    def forward(self, x):
        """Extract features from `x`, optionally merged with a second input.

        `x` is either a tensor or a two-element list ``[x, loaded]``.
        """
        loaded = None
        if isinstance(x, list):
            x, loaded = x
        for layer in self.layers:
            x = layer(x)
        # merge the last two axes and average over them
        # (presumably height and width -- TODO confirm)
        x = torch.reshape(x, (*x.shape[:-2], -1))
        x = torch.mean(x, dim=-1)
        if loaded is not None:
            loaded = F.relu(self.conv1(loaded))
            loaded = F.relu(self.conv2(loaded))
            # concatenate along the channel axis, then map back to 128 channels
            x = F.relu(self.conv3(torch.cat([x, loaded], dim=1)))
        return x


class Predictor(nn.Module):
    """Two kernel-size-1 convolutions (``dim -> 64 -> num_classes``) with a ReLU in between."""

    def __init__(self, dim, num_classes):
        """Create the two 1D convolutions."""
        super(Predictor, self).__init__()
        self.conv_out_1 = nn.Conv1d(dim, 64, kernel_size=1)
        self.conv_out_2 = nn.Conv1d(64, num_classes, kernel_size=1)

    def forward(self, x):
        """Apply ``conv_out_1``, ReLU and ``conv_out_2`` to `x`."""
        x = self.conv_out_1(x)
        x = F.relu(x)
        x = self.conv_out_2(x)
        return x


class C3D_A(Model):
    """Model that uses `C3D` as feature extractor and `Predictor` as predictor."""

    def __init__(
        self,
        dims,
        num_classes,
        state_dict_path=None,
        ssl_constructors=None,
        ssl_types=None,
        ssl_modules=None,
    ):
        """Collect the constructor arguments of both sub-modules and initialize `Model`.

        The first element of every `dims` value is read as a channel count; the
        ``"loaded"`` entry is handled separately from all other entries.
        """
        dim = sum([x[0] for k, x in dims.items() if k != "loaded"])
        if "loaded" in dims:
            loaded_dim = dims["loaded"][0]
        else:
            loaded_dim = None
        self.pars1 = {"dim": dim, "loaded_dim": loaded_dim}
        # run a throwaway C3D on a tensor of ones to read the output channel count;
        # the trailing dimensions are taken from the first entry of `dims`
        output_dims = C3D(**self.pars1)(
            torch.ones((1, dim, *list(dims.values())[0][1:]))
        ).shape
        self.num_f_maps = output_dims[1]
        self.pars2 = {"dim": self.num_f_maps, "num_classes": num_classes}
        # pars1 / pars2 are set before this call
        # (presumably because Model.__init__ invokes the factory methods -- TODO confirm)
        super().__init__(ssl_constructors, ssl_modules, ssl_types, state_dict_path)

    def _feature_extractor(self) -> Union[torch.nn.Module, List]:
        """Return a new `C3D` built from ``self.pars1``."""
        return C3D(**self.pars1)

    def _predictor(self) -> torch.nn.Module:
        """Return a new `Predictor` built from ``self.pars2``."""
        return Predictor(**self.pars2)

    def features_shape(self) -> torch.Size:
        """Return ``torch.Size([128])``."""
        return torch.Size([128])


class C3D_A_MS(Model):
    """Model that uses `C3D` as feature extractor and `MSRefinement` as predictor."""

    def __init__(
        self,
        dims,
        num_classes,
        num_layers_R,
        num_R,
        num_f_maps_R,
        dropout_rate,
        direction,
        skip_connections,
        exclusive,
        attention_R="none",
        block_size_R=0,
        num_heads=1,
        state_dict_path=None,
        ssl_constructors=None,
        ssl_types=None,
        ssl_modules=None,
    ):
        """Collect the constructor arguments of both sub-modules and initialize `Model`.

        The first element of every `dims` value is read as a channel count; the
        ``"loaded"`` entry is handled separately from all other entries. The
        refinement settings are forwarded to `MSRefinement`; count-like values
        are cast to ``int`` first.
        """
        dim = sum([x[0] for k, x in dims.items() if k != "loaded"])
        if "loaded" in dims:
            loaded_dim = dims["loaded"][0]
        else:
            loaded_dim = None
        self.pars1 = {"dim": dim, "loaded_dim": loaded_dim}
        # run a throwaway C3D on a tensor of ones to read the output channel count
        output_dims = C3D(**self.pars1)(
            torch.ones((1, dim, *list(dims.values())[0][1:]))
        ).shape
        self.num_f_maps = output_dims[1]
        self.pars_R = {
            "exclusive": exclusive,
            "num_layers_R": int(num_layers_R),
            "num_R": int(num_R),
            # hard-coded; self.num_f_maps computed above is not used here
            "num_f_maps_input": 128,
            "num_f_maps": int(num_f_maps_R),
            "num_classes": int(num_classes),
            "dropout_rate": dropout_rate,
            "skip_connections": skip_connections,
            "direction": direction,
            "block_size": int(block_size_R),
            "num_heads": int(num_heads),
            "attention": attention_R,
        }
        super().__init__(ssl_constructors, ssl_modules, ssl_types, state_dict_path)

    def _feature_extractor(self) -> Union[torch.nn.Module, List]:
        """Return a new `C3D` built from ``self.pars1``."""
        return C3D(**self.pars1)

    def _predictor(self) -> torch.nn.Module:
        """Return a new `MSRefinement` built from ``self.pars_R``."""
        return MSRefinement(**self.pars_R)

    def features_shape(self) -> torch.Size:
        """Return ``torch.Size([128])``."""
        return torch.Size([128])
class ResLayer3D(nn.Module):
    """Residual block made of two 3x3x3 convolutions with batch normalization.

    The main branch is ``conv1 -> bn1 -> ReLU -> conv2 -> bn2``. The shortcut is
    the input itself when ``in_channels == out_channels`` and ``bn3(conv3(x))``
    (a 1x1x1 convolution plus batch norm) otherwise. Both branches are summed
    and passed through a ReLU.

    Parameters
    ----------
    in_channels : int
        number of channels of the input tensor
    out_channels : int
        number of channels of the output tensor
    """

    def __init__(self, in_channels, out_channels):
        """Create the convolutions, the batch norms and the optional shortcut."""
        super().__init__()
        # kernel 3 with padding 1 keeps the three convolved axes at their input size
        self.conv1 = nn.Conv3d(
            in_channels, out_channels, kernel_size=(3, 3, 3), padding=(1, 1, 1)
        )
        self.conv2 = nn.Conv3d(
            out_channels, out_channels, kernel_size=(3, 3, 3), padding=(1, 1, 1)
        )
        self.bn1 = nn.BatchNorm3d(out_channels)
        self.bn2 = nn.BatchNorm3d(out_channels)
        # bn3 is only used together with conv3, but it is always registered so
        # that the state-dict layout does not depend on the channel counts
        self.bn3 = nn.BatchNorm3d(out_channels)
        if in_channels != out_channels:
            self.conv3 = nn.Conv3d(in_channels, out_channels, kernel_size=(1, 1, 1))
        else:
            self.conv3 = None

    def forward(self, x):
        """Apply the residual block.

        Parameters
        ----------
        x : torch.Tensor
            input accepted by ``nn.Conv3d`` with ``in_channels`` channels

        Returns
        -------
        torch.Tensor
            tensor with ``out_channels`` channels; the convolved axes keep
            their size
        """
        f = self.conv1(x)
        f = F.relu(self.bn1(f))
        f = self.conv2(f)
        f = self.bn2(f)
        # explicit presence check: do not rely on the truthiness of a module
        if self.conv3 is not None:
            x = self.bn3(self.conv3(x))
        return F.relu(f + x)
Base class for all neural network modules.
Your models should also subclass this class.
Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
def __init__(self) -> None:
super().__init__()
self.conv1 = nn.Conv2d(1, 20, 5)
self.conv2 = nn.Conv2d(20, 20, 5)
def forward(self, x):
x = F.relu(self.conv1(x))
return F.relu(self.conv2(x))
Submodules assigned in this way will be registered, and will also have their parameters converted when you call `to()`, etc.
As per the example above, an `__init__()` call to the parent class must be made before assignment on the child.
:ivar training: Boolean represents whether this module is in training or evaluation mode. :vartype training: bool
18 def __init__(self, in_channels, out_channels): 19 super().__init__() 20 self.conv1 = nn.Conv3d( 21 in_channels, out_channels, kernel_size=(3, 3, 3), padding=(1, 1, 1) 22 ) 23 self.conv2 = nn.Conv3d( 24 out_channels, out_channels, kernel_size=(3, 3, 3), padding=(1, 1, 1) 25 ) 26 self.bn1 = nn.BatchNorm3d(out_channels) 27 self.bn2 = nn.BatchNorm3d(out_channels) 28 self.bn3 = nn.BatchNorm3d(out_channels) 29 if in_channels != out_channels: 30 self.conv3 = nn.Conv3d(in_channels, out_channels, kernel_size=(1, 1, 1)) 31 else: 32 self.conv3 = None
Initialize internal Module state, shared by both nn.Module and ScriptModule.
def forward(self, x):
    """Apply the residual block to `x`.

    The main branch is ``conv1 -> bn1 -> ReLU -> conv2 -> bn2``. The shortcut
    is `x` itself, or ``bn3(conv3(x))`` when a projection layer exists. The
    two branches are summed and passed through a ReLU.

    Parameters
    ----------
    x : torch.Tensor
        input tensor accepted by ``self.conv1``

    Returns
    -------
    torch.Tensor
        the activated sum of the main branch and the shortcut
    """
    residual = F.relu(self.bn1(self.conv1(x)))
    residual = self.bn2(self.conv2(residual))
    # test for presence explicitly: module truthiness can be False for
    # containers that define __len__ (e.g. an empty nn.Sequential)
    if self.conv3 is not None:
        x = self.bn3(self.conv3(x))
    return F.relu(residual + x)
Define the computation performed at every call.
Should be overridden by all subclasses.
Although the recipe for the forward pass needs to be defined within this function, one should call the `Module` instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
class C3D(nn.Module):
    """Stack of five residual 3D blocks, each followed by max pooling.

    The channel widths are ``dim -> 32 -> 64 -> 128 -> 128 -> 128``. Pooling
    uses kernel and stride ``(1, 2, 2)``, so only the last two axes are
    downsampled. After the stack, the last two axes are averaged out.

    When `loaded_dim` is given, a second input is passed through two 1D
    convolutions, concatenated with the features along the channel axis and
    reduced back to 128 channels.

    Parameters
    ----------
    dim : int
        number of channels of the main input
    loaded_dim : int or None
        number of channels of the optional second input
    """

    def __init__(self, dim, loaded_dim):
        """Build the layer stack and the optional 1D convolutions."""
        super().__init__()
        self.layers = nn.ModuleList()
        widths = (dim, 32, 64, 128, 128, 128)
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            self.layers.append(ResLayer3D(c_in, c_out))
            self.layers.append(
                nn.MaxPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2))
            )

        if loaded_dim is not None:
            # padding=1 keeps the sequence length for kernel size 3; without it
            # `loaded` ends up 4 steps shorter than the video features and the
            # concatenation in forward() fails for inputs of equal length.
            # Weight shapes are unchanged, so existing checkpoints still load.
            self.conv1 = nn.Conv1d(loaded_dim, 16, 3, padding=1)
            self.conv2 = nn.Conv1d(16, 16, 3, padding=1)
            self.conv3 = nn.Conv1d(128 + 16, 128, 3, padding=1)

    def forward(self, x):
        """Extract features from `x`.

        Parameters
        ----------
        x : torch.Tensor or list or tuple
            either the main input tensor or a pair ``(x, loaded)`` with the
            additional input for the 1D convolutions

        Returns
        -------
        torch.Tensor
            features with 128 channels; the last two axes of the layer-stack
            output are averaged out
        """
        loaded = None
        if isinstance(x, (list, tuple)):
            x, loaded = x
        for layer in self.layers:
            x = layer(x)
        # merge the last two axes and average over them
        x = x.flatten(start_dim=-2).mean(dim=-1)
        if loaded is not None:
            loaded = F.relu(self.conv1(loaded))
            loaded = F.relu(self.conv2(loaded))
            x = F.relu(self.conv3(torch.cat([x, loaded], dim=1)))
        return x
Base class for all neural network modules.
Your models should also subclass this class.
Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
def __init__(self) -> None:
super().__init__()
self.conv1 = nn.Conv2d(1, 20, 5)
self.conv2 = nn.Conv2d(20, 20, 5)
def forward(self, x):
x = F.relu(self.conv1(x))
return F.relu(self.conv2(x))
Submodules assigned in this way will be registered, and will also have their parameters converted when you call `to()`, etc.
As per the example above, an `__init__()` call to the parent class must be made before assignment on the child.
:ivar training: Boolean represents whether this module is in training or evaluation mode. :vartype training: bool
def __init__(self, dim, loaded_dim):
    """Build the residual/pooling stack and the optional 1D convolutions.

    Parameters
    ----------
    dim : int
        number of channels of the main input
    loaded_dim : int or None
        number of channels of the optional second input; when ``None`` the
        1D convolutions are not created
    """
    super().__init__()
    self.layers = nn.ModuleList()
    # channel widths of the five residual blocks, each followed by a pooling
    # layer that halves only the last two axes
    widths = (dim, 32, 64, 128, 128, 128)
    for c_in, c_out in zip(widths[:-1], widths[1:]):
        self.layers.append(ResLayer3D(c_in, c_out))
        self.layers.append(nn.MaxPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2)))

    if loaded_dim is not None:
        # padding=1 keeps the sequence length for kernel size 3; without it
        # the second input ends up 4 steps shorter than the video features
        # and cannot be concatenated with them. Weight shapes are unchanged.
        self.conv1 = nn.Conv1d(loaded_dim, 16, 3, padding=1)
        self.conv2 = nn.Conv1d(16, 16, 3, padding=1)
        self.conv3 = nn.Conv1d(128 + 16, 128, 3, padding=1)
Initialize internal Module state, shared by both nn.Module and ScriptModule.
def forward(self, x):
    """Extract features from `x`, optionally merged with a second input.

    Parameters
    ----------
    x : torch.Tensor or list or tuple
        either the main input tensor or a pair ``(x, loaded)``; tuples are
        accepted as well as lists

    Returns
    -------
    torch.Tensor
        the output of the layer stack with its last two axes averaged out;
        when a second input is given, it is passed through ``conv1`` and
        ``conv2``, concatenated on axis 1 and reduced by ``conv3``
    """
    loaded = None
    if isinstance(x, (list, tuple)):
        x, loaded = x
    for layer in self.layers:
        x = layer(x)
    # merge the last two axes and average over them
    x = x.flatten(start_dim=-2).mean(dim=-1)
    if loaded is not None:
        loaded = F.relu(self.conv1(loaded))
        loaded = F.relu(self.conv2(loaded))
        x = F.relu(self.conv3(torch.cat([x, loaded], dim=1)))
    return x
Define the computation performed at every call.
Should be overridden by all subclasses.
Although the recipe for the forward pass needs to be defined within this function, one should call the `Module` instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
class Predictor(nn.Module):
    """Prediction head: two kernel-size-1 convolutions with a ReLU in between.

    Parameters
    ----------
    dim : int
        number of input channels
    num_classes : int
        number of output channels
    """

    def __init__(self, dim, num_classes):
        """Create the ``dim -> 64 -> num_classes`` convolutions."""
        super().__init__()
        hidden = 64
        self.conv_out_1 = nn.Conv1d(
            in_channels=dim, out_channels=hidden, kernel_size=1
        )
        self.conv_out_2 = nn.Conv1d(
            in_channels=hidden, out_channels=num_classes, kernel_size=1
        )

    def forward(self, x):
        """Return ``conv_out_2(relu(conv_out_1(x)))``."""
        hidden = F.relu(self.conv_out_1(x))
        return self.conv_out_2(hidden)
Base class for all neural network modules.
Your models should also subclass this class.
Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
def __init__(self) -> None:
super().__init__()
self.conv1 = nn.Conv2d(1, 20, 5)
self.conv2 = nn.Conv2d(20, 20, 5)
def forward(self, x):
x = F.relu(self.conv1(x))
return F.relu(self.conv2(x))
Submodules assigned in this way will be registered, and will also have their parameters converted when you call `to()`, etc.
As per the example above, an `__init__()` call to the parent class must be made before assignment on the child.
:ivar training: Boolean represents whether this module is in training or evaluation mode. :vartype training: bool
84 def __init__(self, dim, num_classes): 85 super(Predictor, self).__init__() 86 self.conv_out_1 = nn.Conv1d(dim, 64, kernel_size=1) 87 self.conv_out_2 = nn.Conv1d(64, num_classes, kernel_size=1)
Initialize internal Module state, shared by both nn.Module and ScriptModule.
def forward(self, x):
    """Run the prediction head.

    Parameters
    ----------
    x : torch.Tensor
        input accepted by ``self.conv_out_1``

    Returns
    -------
    torch.Tensor
        ``conv_out_2(relu(conv_out_1(x)))``
    """
    hidden = F.relu(self.conv_out_1(x))
    return self.conv_out_2(hidden)
Define the computation performed at every call.
Should be overridden by all subclasses.
Although the recipe for the forward pass needs to be defined within this function, one should call the `Module` instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
class C3D_A(Model):
    """Model that uses `C3D` as feature extractor and `Predictor` as predictor.

    Parameters
    ----------
    dims : dict
        maps input keys to shapes; the first element of every shape is read as
        a channel count. Entries other than ``"loaded"`` are summed into the
        C3D input channels; ``"loaded"`` (if present) sets ``loaded_dim``
    num_classes : int
        number of output channels of the predictor
    state_dict_path, ssl_constructors, ssl_types, ssl_modules : optional
        forwarded unchanged to ``Model.__init__``
    """

    def __init__(
        self,
        dims,
        num_classes,
        state_dict_path=None,
        ssl_constructors=None,
        ssl_types=None,
        ssl_modules=None,
    ):
        """Collect the sub-module arguments and initialize `Model`.

        Raises
        ------
        ValueError
            if `dims` has no entry other than ``"loaded"``
        """
        main_shapes = [shape for key, shape in dims.items() if key != "loaded"]
        if not main_shapes:
            raise ValueError('dims must contain at least one entry besides "loaded"')
        dim = sum(shape[0] for shape in main_shapes)
        loaded_dim = dims["loaded"][0] if "loaded" in dims else None
        self.pars1 = {"dim": dim, "loaded_dim": loaded_dim}
        # Probe a throwaway C3D to read the number of output channels. The
        # probe shape comes from the first non-"loaded" entry, so the key order
        # of `dims` does not matter. eval() + no_grad() avoid building an
        # autograd graph and the batch-norm training-mode restriction on
        # single-value inputs.
        probe = C3D(**self.pars1).eval()
        with torch.no_grad():
            output_dims = probe(torch.ones((1, dim, *main_shapes[0][1:]))).shape
        self.num_f_maps = output_dims[1]
        self.pars2 = {"dim": self.num_f_maps, "num_classes": num_classes}
        # pars1 / pars2 must exist before the base-class initializer runs
        super().__init__(ssl_constructors, ssl_modules, ssl_types, state_dict_path)

    def _feature_extractor(self) -> Union[torch.nn.Module, List]:
        """Return a new `C3D` built from ``self.pars1``."""
        return C3D(**self.pars1)

    def _predictor(self) -> torch.nn.Module:
        """Return a new `Predictor` built from ``self.pars2``."""
        return Predictor(**self.pars2)

    def features_shape(self) -> torch.Size:
        """Return the feature shape, i.e. the probed channel count (128 for `C3D`)."""
        return torch.Size([self.num_f_maps])
Base class for all models.
Manages interaction of base model and SSL modules + ensures consistent input and output format.
def __init__(
    self,
    dims,
    num_classes,
    state_dict_path=None,
    ssl_constructors=None,
    ssl_types=None,
    ssl_modules=None,
):
    """Collect the sub-module arguments and initialize `Model`.

    Parameters
    ----------
    dims : dict
        maps input keys to shapes; the first element of every shape is read as
        a channel count. Entries other than ``"loaded"`` are summed into the
        C3D input channels; ``"loaded"`` (if present) sets ``loaded_dim``
    num_classes : int
        number of output channels of the predictor
    state_dict_path, ssl_constructors, ssl_types, ssl_modules : optional
        forwarded unchanged to ``Model.__init__``

    Raises
    ------
    ValueError
        if `dims` has no entry other than ``"loaded"``
    """
    main_shapes = [shape for key, shape in dims.items() if key != "loaded"]
    if not main_shapes:
        raise ValueError('dims must contain at least one entry besides "loaded"')
    dim = sum(shape[0] for shape in main_shapes)
    loaded_dim = dims["loaded"][0] if "loaded" in dims else None
    self.pars1 = {"dim": dim, "loaded_dim": loaded_dim}
    # Probe a throwaway C3D to read the number of output channels. The probe
    # shape comes from the first non-"loaded" entry, so the key order of `dims`
    # does not matter. eval() + no_grad() avoid building an autograd graph and
    # the batch-norm training-mode restriction on single-value inputs.
    probe = C3D(**self.pars1).eval()
    with torch.no_grad():
        output_dims = probe(torch.ones((1, dim, *main_shapes[0][1:]))).shape
    self.num_f_maps = output_dims[1]
    self.pars2 = {"dim": self.num_f_maps, "num_classes": num_classes}
    # pars1 / pars2 must exist before the base-class initializer runs
    super().__init__(ssl_constructors, ssl_modules, ssl_types, state_dict_path)
Initialize the model.
Parameters
ssl_constructors : list, optional a list of SSL constructors that build the necessary SSL modules ssl_modules : list, optional a list of torch.nn.Module instances that will serve as SSL modules ssl_types : list, optional a list of string SSL types state_dict_path : str, optional path to the model state dictionary to load strict : bool, default False when True, the state dictionary will only be loaded if the current and the loaded architecture are the same; otherwise missing or extra keys, as well as shape inconsistencies, are ignored prompt_function : callable, optional a function that takes a list of strings and returns a string prompt
Get the shape of feature extractor output.
Returns
feature_shape : torch.Size shape of feature extractor output
Inherited Members
- dlc2action.model.base_model.Model
- process_labels
- feature_extractor
- feature_extractors
- predictor
- ssl_active
- main_task_active
- prompt_function
- class_tensors
- freeze_feature_extractor
- unfreeze_feature_extractor
- load_state_dict
- ssl_off
- ssl_on
- main_task_on
- main_task_off
- set_ssl
- extract_features
- transform_labels
- forward
class C3D_A_MS(Model):
    """Model that uses `C3D` as feature extractor and `MSRefinement` as predictor.

    Parameters
    ----------
    dims : dict
        maps input keys to shapes; the first element of every shape is read as
        a channel count. Entries other than ``"loaded"`` are summed into the
        C3D input channels; ``"loaded"`` (if present) sets ``loaded_dim``
    num_classes : int
        forwarded to `MSRefinement` as ``num_classes``
    num_layers_R, num_R, num_f_maps_R : int
        forwarded to `MSRefinement` as ``num_layers_R``, ``num_R`` and
        ``num_f_maps`` after casting to ``int``
    dropout_rate, direction, skip_connections, exclusive
        forwarded to `MSRefinement` unchanged
    attention_R : default "none"
        forwarded to `MSRefinement` as ``attention``
    block_size_R : int, default 0
        forwarded to `MSRefinement` as ``block_size``
    num_heads : int, default 1
        forwarded to `MSRefinement` as ``num_heads``
    state_dict_path, ssl_constructors, ssl_types, ssl_modules : optional
        forwarded unchanged to ``Model.__init__``
    """

    def __init__(
        self,
        dims,
        num_classes,
        num_layers_R,
        num_R,
        num_f_maps_R,
        dropout_rate,
        direction,
        skip_connections,
        exclusive,
        attention_R="none",
        block_size_R=0,
        num_heads=1,
        state_dict_path=None,
        ssl_constructors=None,
        ssl_types=None,
        ssl_modules=None,
    ):
        """Collect the sub-module arguments and initialize `Model`.

        Raises
        ------
        ValueError
            if `dims` has no entry other than ``"loaded"``
        """
        main_shapes = [shape for key, shape in dims.items() if key != "loaded"]
        if not main_shapes:
            raise ValueError('dims must contain at least one entry besides "loaded"')
        dim = sum(shape[0] for shape in main_shapes)
        loaded_dim = dims["loaded"][0] if "loaded" in dims else None
        self.pars1 = {"dim": dim, "loaded_dim": loaded_dim}
        # Probe a throwaway C3D to read the number of output channels. The
        # probe shape comes from the first non-"loaded" entry, so the key order
        # of `dims` does not matter. eval() + no_grad() avoid building an
        # autograd graph and the batch-norm training-mode restriction on
        # single-value inputs.
        probe = C3D(**self.pars1).eval()
        with torch.no_grad():
            output_dims = probe(torch.ones((1, dim, *main_shapes[0][1:]))).shape
        self.num_f_maps = output_dims[1]
        self.pars_R = {
            "exclusive": exclusive,
            "num_layers_R": int(num_layers_R),
            "num_R": int(num_R),
            # use the probed channel count (128 for C3D) instead of repeating
            # the literal
            "num_f_maps_input": int(self.num_f_maps),
            "num_f_maps": int(num_f_maps_R),
            "num_classes": int(num_classes),
            "dropout_rate": dropout_rate,
            "skip_connections": skip_connections,
            "direction": direction,
            "block_size": int(block_size_R),
            "num_heads": int(num_heads),
            "attention": attention_R,
        }
        # pars1 / pars_R must exist before the base-class initializer runs
        super().__init__(ssl_constructors, ssl_modules, ssl_types, state_dict_path)

    def _feature_extractor(self) -> Union[torch.nn.Module, List]:
        """Return a new `C3D` built from ``self.pars1``."""
        return C3D(**self.pars1)

    def _predictor(self) -> torch.nn.Module:
        """Return a new `MSRefinement` built from ``self.pars_R``."""
        return MSRefinement(**self.pars_R)

    def features_shape(self) -> torch.Size:
        """Return the feature shape, i.e. the probed channel count (128 for `C3D`)."""
        return torch.Size([self.num_f_maps])
Base class for all models.
Manages interaction of base model and SSL modules + ensures consistent input and output format.
def __init__(
    self,
    dims,
    num_classes,
    num_layers_R,
    num_R,
    num_f_maps_R,
    dropout_rate,
    direction,
    skip_connections,
    exclusive,
    attention_R="none",
    block_size_R=0,
    num_heads=1,
    state_dict_path=None,
    ssl_constructors=None,
    ssl_types=None,
    ssl_modules=None,
):
    """Collect the sub-module arguments and initialize `Model`.

    Parameters
    ----------
    dims : dict
        maps input keys to shapes; the first element of every shape is read as
        a channel count. Entries other than ``"loaded"`` are summed into the
        C3D input channels; ``"loaded"`` (if present) sets ``loaded_dim``
    num_classes : int
        forwarded to `MSRefinement` as ``num_classes``
    num_layers_R, num_R, num_f_maps_R : int
        forwarded to `MSRefinement` as ``num_layers_R``, ``num_R`` and
        ``num_f_maps`` after casting to ``int``
    dropout_rate, direction, skip_connections, exclusive
        forwarded to `MSRefinement` unchanged
    attention_R : default "none"
        forwarded to `MSRefinement` as ``attention``
    block_size_R : int, default 0
        forwarded to `MSRefinement` as ``block_size``
    num_heads : int, default 1
        forwarded to `MSRefinement` as ``num_heads``
    state_dict_path, ssl_constructors, ssl_types, ssl_modules : optional
        forwarded unchanged to ``Model.__init__``

    Raises
    ------
    ValueError
        if `dims` has no entry other than ``"loaded"``
    """
    main_shapes = [shape for key, shape in dims.items() if key != "loaded"]
    if not main_shapes:
        raise ValueError('dims must contain at least one entry besides "loaded"')
    dim = sum(shape[0] for shape in main_shapes)
    loaded_dim = dims["loaded"][0] if "loaded" in dims else None
    self.pars1 = {"dim": dim, "loaded_dim": loaded_dim}
    # Probe a throwaway C3D to read the number of output channels. The probe
    # shape comes from the first non-"loaded" entry, so the key order of `dims`
    # does not matter. eval() + no_grad() avoid building an autograd graph and
    # the batch-norm training-mode restriction on single-value inputs.
    probe = C3D(**self.pars1).eval()
    with torch.no_grad():
        output_dims = probe(torch.ones((1, dim, *main_shapes[0][1:]))).shape
    self.num_f_maps = output_dims[1]
    self.pars_R = {
        "exclusive": exclusive,
        "num_layers_R": int(num_layers_R),
        "num_R": int(num_R),
        # use the probed channel count (128 for C3D) instead of repeating the
        # literal
        "num_f_maps_input": int(self.num_f_maps),
        "num_f_maps": int(num_f_maps_R),
        "num_classes": int(num_classes),
        "dropout_rate": dropout_rate,
        "skip_connections": skip_connections,
        "direction": direction,
        "block_size": int(block_size_R),
        "num_heads": int(num_heads),
        "attention": attention_R,
    }
    # pars1 / pars_R must exist before the base-class initializer runs
    super().__init__(ssl_constructors, ssl_modules, ssl_types, state_dict_path)
Initialize the model.
Parameters
ssl_constructors : list, optional a list of SSL constructors that build the necessary SSL modules ssl_modules : list, optional a list of torch.nn.Module instances that will serve as SSL modules ssl_types : list, optional a list of string SSL types state_dict_path : str, optional path to the model state dictionary to load strict : bool, default False when True, the state dictionary will only be loaded if the current and the loaded architecture are the same; otherwise missing or extra keys, as well as shape inconsistencies, are ignored prompt_function : callable, optional a function that takes a list of strings and returns a string prompt
Get the shape of feature extractor output.
Returns
feature_shape : torch.Size shape of feature extractor output
Inherited Members
- dlc2action.model.base_model.Model
- process_labels
- feature_extractor
- feature_extractors
- predictor
- ssl_active
- main_task_active
- prompt_function
- class_tensors
- freeze_feature_extractor
- unfreeze_feature_extractor
- load_state_dict
- ssl_off
- ssl_on
- main_task_on
- main_task_off
- set_ssl
- extract_features
- transform_labels
- forward