dlc2action.model.c3d

  1#
  2# Copyright 2020-present by A. Mathis Group and contributors. All rights reserved.
  3#
  4# This project and all its files are licensed under GNU AGPLv3 or later version. 
  5# A copy is included in dlc2action/LICENSE.AGPL.
  6#
  7from typing import List, Union
  8
  9import torch
 10from dlc2action.model.base_model import Model
 11from dlc2action.model.ms_tcn_modules import MSRefinement
 12from torch import nn
 13from torch.nn import functional as F
 14
 15
 16class ResLayer3D(nn.Module):
 17    def __init__(self, in_channels, out_channels):
 18        super().__init__()
 19        self.conv1 = nn.Conv3d(
 20            in_channels, out_channels, kernel_size=(3, 3, 3), padding=(1, 1, 1)
 21        )
 22        self.conv2 = nn.Conv3d(
 23            out_channels, out_channels, kernel_size=(3, 3, 3), padding=(1, 1, 1)
 24        )
 25        self.bn1 = nn.BatchNorm3d(out_channels)
 26        self.bn2 = nn.BatchNorm3d(out_channels)
 27        self.bn3 = nn.BatchNorm3d(out_channels)
 28        if in_channels != out_channels:
 29            self.conv3 = nn.Conv3d(in_channels, out_channels, kernel_size=(1, 1, 1))
 30        else:
 31            self.conv3 = None
 32
 33    def forward(self, x):
 34        f = self.conv1(x)
 35        f = F.relu(self.bn1(f))
 36        f = self.conv2(f)
 37        f = self.bn2(f)
 38        if self.conv3:
 39            x = self.bn3(self.conv3(x))
 40        return F.relu(f + x)
 41
 42
 43class C3D(nn.Module):
 44    def __init__(self, dim, loaded_dim):
 45        super(C3D, self).__init__()
 46        self.layers = nn.ModuleList()
 47
 48        self.layers.append(ResLayer3D(dim, 32))
 49        self.layers.append(nn.MaxPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2)))
 50
 51        self.layers.append(ResLayer3D(32, 64))
 52        self.layers.append(nn.MaxPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2)))
 53
 54        self.layers.append(ResLayer3D(64, 128))
 55        self.layers.append(nn.MaxPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2)))
 56        self.layers.append(ResLayer3D(128, 128))
 57        self.layers.append(nn.MaxPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2)))
 58
 59        self.layers.append(ResLayer3D(128, 128))
 60        self.layers.append(nn.MaxPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2)))
 61
 62        if loaded_dim is not None:
 63            self.conv1 = nn.Conv1d(loaded_dim, 16, 3)
 64            self.conv2 = nn.Conv1d(16, 16, 3)
 65            self.conv3 = nn.Conv1d(128 + 16, 128, 3)
 66
 67    def forward(self, x):
 68        loaded = None
 69        if isinstance(x, list):
 70            x, loaded = x
 71        for layer in self.layers:
 72            x = layer(x)
 73        x = torch.reshape(x, (*x.shape[:-2], -1))
 74        x = torch.mean(x, dim=-1)
 75        if loaded is not None:
 76            loaded = F.relu(self.conv1(loaded))
 77            loaded = F.relu(self.conv2(loaded))
 78            x = F.relu(self.conv3(torch.cat([x, loaded], dim=1)))
 79        return x
 80
 81
 82class Predictor(nn.Module):
 83    def __init__(self, dim, num_classes):
 84        super(Predictor, self).__init__()
 85        self.conv_out_1 = nn.Conv1d(dim, 64, kernel_size=1)
 86        self.conv_out_2 = nn.Conv1d(64, num_classes, kernel_size=1)
 87
 88    def forward(self, x):
 89        x = self.conv_out_1(x)
 90        x = F.relu(x)
 91        x = self.conv_out_2(x)
 92        return x
 93
 94
 95class C3D_A(Model):
 96    def __init__(
 97        self,
 98        dims,
 99        num_classes,
100        state_dict_path=None,
101        ssl_constructors=None,
102        ssl_types=None,
103        ssl_modules=None,
104    ):
105        dim = sum([x[0] for k, x in dims.items() if k != "loaded"])
106        if "loaded" in dims:
107            loaded_dim = dims["loaded"][0]
108        else:
109            loaded_dim = None
110        self.pars1 = {"dim": dim, "loaded_dim": loaded_dim}
111        output_dims = C3D(**self.pars1)(
112            torch.ones((1, dim, *list(dims.values())[0][1:]))
113        ).shape
114        self.num_f_maps = output_dims[1]
115        self.pars2 = {"dim": self.num_f_maps, "num_classes": num_classes}
116        super().__init__(ssl_constructors, ssl_modules, ssl_types, state_dict_path)
117
118    def _feature_extractor(self) -> Union[torch.nn.Module, List]:
119        return C3D(**self.pars1)
120
121    def _predictor(self) -> torch.nn.Module:
122        return Predictor(**self.pars2)
123
124    def features_shape(self) -> torch.Size:
125        return torch.Size([128])
126
127
class C3D_A_MS(Model):
    """Model that pairs the ``C3D`` feature extractor with ``MSRefinement``.

    Parameters
    ----------
    dims : dict
        maps input keys to shapes; the first element of every shape is taken
        as its channel count. The ``"loaded"`` key, if present, describes the
        additional features; the channels of all other keys are summed into
        the video channel count
    num_classes : int
        cast to int and forwarded to ``MSRefinement`` as ``num_classes``
    num_layers_R : int
        cast to int and forwarded to ``MSRefinement`` as ``num_layers_R``
    num_R : int
        cast to int and forwarded to ``MSRefinement`` as ``num_R``
    num_f_maps_R : int
        cast to int and forwarded to ``MSRefinement`` as ``num_f_maps``
    dropout_rate
        forwarded to ``MSRefinement`` as ``dropout_rate``
    direction
        forwarded to ``MSRefinement`` as ``direction``
    skip_connections
        forwarded to ``MSRefinement`` as ``skip_connections``
    exclusive
        forwarded to ``MSRefinement`` as ``exclusive``
    attention_R : default "none"
        forwarded to ``MSRefinement`` as ``attention``
    block_size_R : int, default 0
        cast to int and forwarded to ``MSRefinement`` as ``block_size``
    num_heads : int, default 1
        cast to int and forwarded to ``MSRefinement`` as ``num_heads``
    state_dict_path : str, optional
        passed on to the base ``Model`` constructor
    ssl_constructors : list, optional
        passed on to the base ``Model`` constructor
    ssl_types : list, optional
        passed on to the base ``Model`` constructor
    ssl_modules : list, optional
        passed on to the base ``Model`` constructor
    """

    def __init__(
        self,
        dims,
        num_classes,
        num_layers_R,
        num_R,
        num_f_maps_R,
        dropout_rate,
        direction,
        skip_connections,
        exclusive,
        attention_R="none",
        block_size_R=0,
        num_heads=1,
        state_dict_path=None,
        ssl_constructors=None,
        ssl_types=None,
        ssl_modules=None,
    ):
        video_shapes = [shape for key, shape in dims.items() if key != "loaded"]
        if not video_shapes:
            raise ValueError(
                'dims must contain at least one input other than "loaded"'
            )
        dim = sum(shape[0] for shape in video_shapes)
        loaded_dim = dims["loaded"][0] if "loaded" in dims else None
        self.pars1 = {"dim": dim, "loaded_dim": loaded_dim}
        # Probe the extractor with a dummy batch to learn its output channels.
        # The dummy shape comes from the first non-"loaded" input (so the key
        # order of `dims` does not matter); eval mode and no_grad keep the
        # probe from updating batch-norm statistics or building a graph.
        probe = C3D(**self.pars1).eval()
        with torch.no_grad():
            output_dims = probe(torch.ones((1, dim, *video_shapes[0][1:]))).shape
        self.num_f_maps = output_dims[1]
        self.pars_R = {
            "exclusive": exclusive,
            "num_layers_R": int(num_layers_R),
            "num_R": int(num_R),
            # Use the probed channel count rather than repeating the literal
            "num_f_maps_input": int(self.num_f_maps),
            "num_f_maps": int(num_f_maps_R),
            "num_classes": int(num_classes),
            "dropout_rate": dropout_rate,
            "skip_connections": skip_connections,
            "direction": direction,
            "block_size": int(block_size_R),
            "num_heads": int(num_heads),
            "attention": attention_R,
        }
        super().__init__(ssl_constructors, ssl_modules, ssl_types, state_dict_path)

    def _feature_extractor(self) -> Union[torch.nn.Module, List]:
        """Build the ``C3D`` feature extractor from the stored parameters."""
        return C3D(**self.pars1)

    def _predictor(self) -> torch.nn.Module:
        """Build the ``MSRefinement`` predictor from the stored parameters."""
        return MSRefinement(**self.pars_R)

    def features_shape(self) -> torch.Size:
        """Get the shape of feature extractor output.

        Returns
        -------
        feature_shape : torch.Size
            one-element size holding the probed number of feature channels
        """
        return torch.Size([self.num_f_maps])
class ResLayer3D(torch.nn.modules.module.Module):
17class ResLayer3D(nn.Module):
18    def __init__(self, in_channels, out_channels):
19        super().__init__()
20        self.conv1 = nn.Conv3d(
21            in_channels, out_channels, kernel_size=(3, 3, 3), padding=(1, 1, 1)
22        )
23        self.conv2 = nn.Conv3d(
24            out_channels, out_channels, kernel_size=(3, 3, 3), padding=(1, 1, 1)
25        )
26        self.bn1 = nn.BatchNorm3d(out_channels)
27        self.bn2 = nn.BatchNorm3d(out_channels)
28        self.bn3 = nn.BatchNorm3d(out_channels)
29        if in_channels != out_channels:
30            self.conv3 = nn.Conv3d(in_channels, out_channels, kernel_size=(1, 1, 1))
31        else:
32            self.conv3 = None
33
34    def forward(self, x):
35        f = self.conv1(x)
36        f = F.relu(self.bn1(f))
37        f = self.conv2(f)
38        f = self.bn2(f)
39        if self.conv3:
40            x = self.bn3(self.conv3(x))
41        return F.relu(f + x)

Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::

import torch.nn as nn
import torch.nn.functional as F


class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will also have their parameters converted when you call to(), etc.

As per the example above, an __init__() call to the parent class must be made before assignment on the child.

:ivar training: Boolean represents whether this module is in training or evaluation mode. :vartype training: bool

ResLayer3D(in_channels, out_channels)
18    def __init__(self, in_channels, out_channels):
19        super().__init__()
20        self.conv1 = nn.Conv3d(
21            in_channels, out_channels, kernel_size=(3, 3, 3), padding=(1, 1, 1)
22        )
23        self.conv2 = nn.Conv3d(
24            out_channels, out_channels, kernel_size=(3, 3, 3), padding=(1, 1, 1)
25        )
26        self.bn1 = nn.BatchNorm3d(out_channels)
27        self.bn2 = nn.BatchNorm3d(out_channels)
28        self.bn3 = nn.BatchNorm3d(out_channels)
29        if in_channels != out_channels:
30            self.conv3 = nn.Conv3d(in_channels, out_channels, kernel_size=(1, 1, 1))
31        else:
32            self.conv3 = None

Initialize internal Module state, shared by both nn.Module and ScriptModule.

conv1
conv2
bn1
bn2
bn3
def forward(self, x):
34    def forward(self, x):
35        f = self.conv1(x)
36        f = F.relu(self.bn1(f))
37        f = self.conv2(f)
38        f = self.bn2(f)
39        if self.conv3:
40            x = self.bn3(self.conv3(x))
41        return F.relu(f + x)

Define the computation performed at every call.

Should be overridden by all subclasses.

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class C3D(torch.nn.modules.module.Module):
44class C3D(nn.Module):
45    def __init__(self, dim, loaded_dim):
46        super(C3D, self).__init__()
47        self.layers = nn.ModuleList()
48
49        self.layers.append(ResLayer3D(dim, 32))
50        self.layers.append(nn.MaxPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2)))
51
52        self.layers.append(ResLayer3D(32, 64))
53        self.layers.append(nn.MaxPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2)))
54
55        self.layers.append(ResLayer3D(64, 128))
56        self.layers.append(nn.MaxPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2)))
57        self.layers.append(ResLayer3D(128, 128))
58        self.layers.append(nn.MaxPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2)))
59
60        self.layers.append(ResLayer3D(128, 128))
61        self.layers.append(nn.MaxPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2)))
62
63        if loaded_dim is not None:
64            self.conv1 = nn.Conv1d(loaded_dim, 16, 3)
65            self.conv2 = nn.Conv1d(16, 16, 3)
66            self.conv3 = nn.Conv1d(128 + 16, 128, 3)
67
68    def forward(self, x):
69        loaded = None
70        if isinstance(x, list):
71            x, loaded = x
72        for layer in self.layers:
73            x = layer(x)
74        x = torch.reshape(x, (*x.shape[:-2], -1))
75        x = torch.mean(x, dim=-1)
76        if loaded is not None:
77            loaded = F.relu(self.conv1(loaded))
78            loaded = F.relu(self.conv2(loaded))
79            x = F.relu(self.conv3(torch.cat([x, loaded], dim=1)))
80        return x

Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::

import torch.nn as nn
import torch.nn.functional as F


class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will also have their parameters converted when you call to(), etc.

As per the example above, an __init__() call to the parent class must be made before assignment on the child.

:ivar training: Boolean represents whether this module is in training or evaluation mode. :vartype training: bool

C3D(dim, loaded_dim)
45    def __init__(self, dim, loaded_dim):
46        super(C3D, self).__init__()
47        self.layers = nn.ModuleList()
48
49        self.layers.append(ResLayer3D(dim, 32))
50        self.layers.append(nn.MaxPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2)))
51
52        self.layers.append(ResLayer3D(32, 64))
53        self.layers.append(nn.MaxPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2)))
54
55        self.layers.append(ResLayer3D(64, 128))
56        self.layers.append(nn.MaxPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2)))
57        self.layers.append(ResLayer3D(128, 128))
58        self.layers.append(nn.MaxPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2)))
59
60        self.layers.append(ResLayer3D(128, 128))
61        self.layers.append(nn.MaxPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2)))
62
63        if loaded_dim is not None:
64            self.conv1 = nn.Conv1d(loaded_dim, 16, 3)
65            self.conv2 = nn.Conv1d(16, 16, 3)
66            self.conv3 = nn.Conv1d(128 + 16, 128, 3)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

layers
def forward(self, x):
68    def forward(self, x):
69        loaded = None
70        if isinstance(x, list):
71            x, loaded = x
72        for layer in self.layers:
73            x = layer(x)
74        x = torch.reshape(x, (*x.shape[:-2], -1))
75        x = torch.mean(x, dim=-1)
76        if loaded is not None:
77            loaded = F.relu(self.conv1(loaded))
78            loaded = F.relu(self.conv2(loaded))
79            x = F.relu(self.conv3(torch.cat([x, loaded], dim=1)))
80        return x

Define the computation performed at every call.

Should be overridden by all subclasses.

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class Predictor(torch.nn.modules.module.Module):
83class Predictor(nn.Module):
84    def __init__(self, dim, num_classes):
85        super(Predictor, self).__init__()
86        self.conv_out_1 = nn.Conv1d(dim, 64, kernel_size=1)
87        self.conv_out_2 = nn.Conv1d(64, num_classes, kernel_size=1)
88
89    def forward(self, x):
90        x = self.conv_out_1(x)
91        x = F.relu(x)
92        x = self.conv_out_2(x)
93        return x

Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::

import torch.nn as nn
import torch.nn.functional as F


class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will also have their parameters converted when you call to(), etc.

As per the example above, an __init__() call to the parent class must be made before assignment on the child.

:ivar training: Boolean represents whether this module is in training or evaluation mode. :vartype training: bool

Predictor(dim, num_classes)
84    def __init__(self, dim, num_classes):
85        super(Predictor, self).__init__()
86        self.conv_out_1 = nn.Conv1d(dim, 64, kernel_size=1)
87        self.conv_out_2 = nn.Conv1d(64, num_classes, kernel_size=1)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

conv_out_1
conv_out_2
def forward(self, x):
89    def forward(self, x):
90        x = self.conv_out_1(x)
91        x = F.relu(x)
92        x = self.conv_out_2(x)
93        return x

Define the computation performed at every call.

Should be overridden by all subclasses.

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class C3D_A(dlc2action.model.base_model.Model):
 96class C3D_A(Model):
 97    def __init__(
 98        self,
 99        dims,
100        num_classes,
101        state_dict_path=None,
102        ssl_constructors=None,
103        ssl_types=None,
104        ssl_modules=None,
105    ):
106        dim = sum([x[0] for k, x in dims.items() if k != "loaded"])
107        if "loaded" in dims:
108            loaded_dim = dims["loaded"][0]
109        else:
110            loaded_dim = None
111        self.pars1 = {"dim": dim, "loaded_dim": loaded_dim}
112        output_dims = C3D(**self.pars1)(
113            torch.ones((1, dim, *list(dims.values())[0][1:]))
114        ).shape
115        self.num_f_maps = output_dims[1]
116        self.pars2 = {"dim": self.num_f_maps, "num_classes": num_classes}
117        super().__init__(ssl_constructors, ssl_modules, ssl_types, state_dict_path)
118
119    def _feature_extractor(self) -> Union[torch.nn.Module, List]:
120        return C3D(**self.pars1)
121
122    def _predictor(self) -> torch.nn.Module:
123        return Predictor(**self.pars2)
124
125    def features_shape(self) -> torch.Size:
126        return torch.Size([128])

Base class for all models.

Manages interaction of base model and SSL modules + ensures consistent input and output format.

C3D_A( dims, num_classes, state_dict_path=None, ssl_constructors=None, ssl_types=None, ssl_modules=None)
 97    def __init__(
 98        self,
 99        dims,
100        num_classes,
101        state_dict_path=None,
102        ssl_constructors=None,
103        ssl_types=None,
104        ssl_modules=None,
105    ):
106        dim = sum([x[0] for k, x in dims.items() if k != "loaded"])
107        if "loaded" in dims:
108            loaded_dim = dims["loaded"][0]
109        else:
110            loaded_dim = None
111        self.pars1 = {"dim": dim, "loaded_dim": loaded_dim}
112        output_dims = C3D(**self.pars1)(
113            torch.ones((1, dim, *list(dims.values())[0][1:]))
114        ).shape
115        self.num_f_maps = output_dims[1]
116        self.pars2 = {"dim": self.num_f_maps, "num_classes": num_classes}
117        super().__init__(ssl_constructors, ssl_modules, ssl_types, state_dict_path)

Initialize the model.

Parameters

ssl_constructors : list, optional a list of SSL constructors that build the necessary SSL modules ssl_modules : list, optional a list of torch.nn.Module instances that will serve as SSL modules ssl_types : list, optional a list of string SSL types state_dict_path : str, optional path to the model state dictionary to load strict : bool, default False when True, the state dictionary will only be loaded if the current and the loaded architecture are the same; otherwise missing or extra keys, as well as shape inconsistencies, are ignored prompt_function : callable, optional a function that takes a list of strings and returns a string prompt

pars1
num_f_maps
pars2
def features_shape(self) -> torch.Size:
125    def features_shape(self) -> torch.Size:
126        return torch.Size([128])

Get the shape of feature extractor output.

Returns

feature_shape : torch.Size shape of feature extractor output

class C3D_A_MS(dlc2action.model.base_model.Model):
129class C3D_A_MS(Model):
130    def __init__(
131        self,
132        dims,
133        num_classes,
134        num_layers_R,
135        num_R,
136        num_f_maps_R,
137        dropout_rate,
138        direction,
139        skip_connections,
140        exclusive,
141        attention_R="none",
142        block_size_R=0,
143        num_heads=1,
144        state_dict_path=None,
145        ssl_constructors=None,
146        ssl_types=None,
147        ssl_modules=None,
148    ):
149        dim = sum([x[0] for k, x in dims.items() if k != "loaded"])
150        if "loaded" in dims:
151            loaded_dim = dims["loaded"][0]
152        else:
153            loaded_dim = None
154        self.pars1 = {"dim": dim, "loaded_dim": loaded_dim}
155        output_dims = C3D(**self.pars1)(
156            torch.ones((1, dim, *list(dims.values())[0][1:]))
157        ).shape
158        self.num_f_maps = output_dims[1]
159        self.pars_R = {
160            "exclusive": exclusive,
161            "num_layers_R": int(num_layers_R),
162            "num_R": int(num_R),
163            "num_f_maps_input": 128,
164            "num_f_maps": int(num_f_maps_R),
165            "num_classes": int(num_classes),
166            "dropout_rate": dropout_rate,
167            "skip_connections": skip_connections,
168            "direction": direction,
169            "block_size": int(block_size_R),
170            "num_heads": int(num_heads),
171            "attention": attention_R,
172        }
173        super().__init__(ssl_constructors, ssl_modules, ssl_types, state_dict_path)
174
175    def _feature_extractor(self) -> Union[torch.nn.Module, List]:
176        return C3D(**self.pars1)
177
178    def _predictor(self) -> torch.nn.Module:
179        return MSRefinement(**self.pars_R)
180
181    def features_shape(self) -> torch.Size:
182        return torch.Size([128])

Base class for all models.

Manages interaction of base model and SSL modules + ensures consistent input and output format.

C3D_A_MS( dims, num_classes, num_layers_R, num_R, num_f_maps_R, dropout_rate, direction, skip_connections, exclusive, attention_R='none', block_size_R=0, num_heads=1, state_dict_path=None, ssl_constructors=None, ssl_types=None, ssl_modules=None)
130    def __init__(
131        self,
132        dims,
133        num_classes,
134        num_layers_R,
135        num_R,
136        num_f_maps_R,
137        dropout_rate,
138        direction,
139        skip_connections,
140        exclusive,
141        attention_R="none",
142        block_size_R=0,
143        num_heads=1,
144        state_dict_path=None,
145        ssl_constructors=None,
146        ssl_types=None,
147        ssl_modules=None,
148    ):
149        dim = sum([x[0] for k, x in dims.items() if k != "loaded"])
150        if "loaded" in dims:
151            loaded_dim = dims["loaded"][0]
152        else:
153            loaded_dim = None
154        self.pars1 = {"dim": dim, "loaded_dim": loaded_dim}
155        output_dims = C3D(**self.pars1)(
156            torch.ones((1, dim, *list(dims.values())[0][1:]))
157        ).shape
158        self.num_f_maps = output_dims[1]
159        self.pars_R = {
160            "exclusive": exclusive,
161            "num_layers_R": int(num_layers_R),
162            "num_R": int(num_R),
163            "num_f_maps_input": 128,
164            "num_f_maps": int(num_f_maps_R),
165            "num_classes": int(num_classes),
166            "dropout_rate": dropout_rate,
167            "skip_connections": skip_connections,
168            "direction": direction,
169            "block_size": int(block_size_R),
170            "num_heads": int(num_heads),
171            "attention": attention_R,
172        }
173        super().__init__(ssl_constructors, ssl_modules, ssl_types, state_dict_path)

Initialize the model.

Parameters

ssl_constructors : list, optional a list of SSL constructors that build the necessary SSL modules ssl_modules : list, optional a list of torch.nn.Module instances that will serve as SSL modules ssl_types : list, optional a list of string SSL types state_dict_path : str, optional path to the model state dictionary to load strict : bool, default False when True, the state dictionary will only be loaded if the current and the loaded architecture are the same; otherwise missing or extra keys, as well as shape inconsistencies, are ignored prompt_function : callable, optional a function that takes a list of strings and returns a string prompt

pars1
num_f_maps
pars_R
def features_shape(self) -> torch.Size:
181    def features_shape(self) -> torch.Size:
182        return torch.Size([128])

Get the shape of feature extractor output.

Returns

feature_shape : torch.Size shape of feature extractor output