dlc2action.model.c2f_tcn
#
# Copyright 2020-present by A. Mathis Group and contributors. All rights reserved.
#
# This project and all its files are licensed under GNU AGPLv3 or later version. A copy is included in dlc2action/LICENSE.AGPL.
#
# Incorporates code adapted from C2F-TCN by dipika-singhania
# Original work Copyright (c) 2021 dipika-singhania
# Source: https://github.com/dipika-singhania/C2F-TCN
# Originally licensed under MIT License
# Combined work licensed under GNU AGPLv3
#
import torch.nn.functional as F
import torch.nn as nn
import torch
from functools import partial
from dlc2action.model.base_model import Model
from typing import Union, List, Optional

nonlinearity = partial(F.relu, inplace=True)


class double_conv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(double_conv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv1d(in_ch, out_ch, kernel_size=5, padding=2),
            nn.BatchNorm1d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv1d(out_ch, out_ch, kernel_size=5, padding=2),
            nn.BatchNorm1d(out_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Forward pass."""
        x = self.conv(x)
        return x


class inconv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(inconv, self).__init__()
        self.conv = double_conv(in_ch, out_ch)

    def forward(self, x):
        """Forward pass."""
        x = self.conv(x)
        return x


class outconv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(outconv, self).__init__()
        self.conv = nn.Conv1d(in_ch, out_ch, kernel_size=1)

    def forward(self, x):
        """Forward pass."""
        x = self.conv(x)
        return x


class down(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(down, self).__init__()
        self.max_pool_conv = nn.Sequential(nn.MaxPool1d(2), double_conv(in_ch, out_ch))

    def forward(self, x):
        """Forward pass."""
        x = self.max_pool_conv(x)
        return x


class up(nn.Module):
    """Upscaling then double conv"""

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()

        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode="linear", align_corners=True)
        else:
            self.up = nn.ConvTranspose1d(
                in_channels // 2, in_channels // 2, kernel_size=2, stride=2
            )

        self.conv = double_conv(in_channels, out_channels)

    def forward(self, x1, x2):
        """Forward pass."""
        x1 = self.up(x1)
        # input is CHW
        diff = torch.tensor([x2.size()[2] - x1.size()[2]])

        x1 = F.pad(x1, [diff // 2, diff - diff // 2])
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)


class TPPblock(nn.Module):
    def __init__(self, in_channels):
        super(TPPblock, self).__init__()
        self.pool1 = nn.MaxPool1d(kernel_size=2, stride=2)
        self.pool2 = nn.MaxPool1d(kernel_size=3, stride=3)
        self.pool3 = nn.MaxPool1d(kernel_size=5, stride=5)
        self.pool4 = nn.MaxPool1d(kernel_size=6, stride=6)

        self.conv = nn.Conv1d(
            in_channels=in_channels, out_channels=1, kernel_size=1, padding=0
        )

    def forward(self, x):
        """Forward pass."""
        self.in_channels, t = x.size(1), x.size(2)
        self.layer1 = F.interpolate(
            self.conv(self.pool1(x)), size=t, mode="linear", align_corners=True
        )
        self.layer2 = F.interpolate(
            self.conv(self.pool2(x)), size=t, mode="linear", align_corners=True
        )
        self.layer3 = F.interpolate(
            self.conv(self.pool3(x)), size=t, mode="linear", align_corners=True
        )
        self.layer4 = F.interpolate(
            self.conv(self.pool4(x)), size=t, mode="linear", align_corners=True
        )

        out = torch.cat([self.layer1, self.layer2, self.layer3, self.layer4, x], 1)

        return out


class Predictor(nn.Module):
    def __init__(self, dim, num_classes):
        super(Predictor, self).__init__()
        self.num_classes = num_classes
        self.conv_out_1 = nn.Conv1d(dim, dim, kernel_size=1)
        self.conv_out_2 = nn.Conv1d(dim, num_classes, kernel_size=1)

    def forward(self, x):
        """Forward pass."""
        x = self.conv_out_1(x)
        x = F.relu(x)
        x = self.conv_out_2(x)
        x = x.reshape((4, -1, self.num_classes, x.shape[-1]))
        return x


class C2F_TCN_Module(nn.Module):
    """
    Features are extracted at the last layer of decoder.
    """

    def __init__(self, n_channels, output_dim, num_f_maps, use_predictor=False):
        super().__init__()
        self.use_predictor = use_predictor
        self.inc = inconv(n_channels, num_f_maps * 2)
        self.down1 = down(num_f_maps * 2, num_f_maps * 2)
        self.down2 = down(num_f_maps * 2, num_f_maps * 2)
        self.down3 = down(num_f_maps * 2, num_f_maps)
        self.down4 = down(num_f_maps, num_f_maps)
        self.down5 = down(num_f_maps, num_f_maps)
        self.down6 = down(num_f_maps, num_f_maps)
        self.up = up(num_f_maps * 2 + 4, num_f_maps)
        self.outcc0 = outconv(num_f_maps, output_dim)
        self.up0 = up(num_f_maps * 2, num_f_maps)
        self.outcc1 = outconv(num_f_maps, output_dim)
        self.up1 = up(num_f_maps * 2, num_f_maps)
        self.outcc2 = outconv(num_f_maps, output_dim)
        self.up2 = up(num_f_maps * 3, num_f_maps)
        self.outcc3 = outconv(num_f_maps, output_dim)
        self.up3 = up(num_f_maps * 3, num_f_maps)
        self.outcc4 = outconv(num_f_maps, output_dim)
        self.up4 = up(num_f_maps * 3, num_f_maps)
        self.outcc = outconv(num_f_maps, output_dim)
        self.tpp = TPPblock(num_f_maps)
        self.weights = torch.nn.Parameter(torch.ones(6))

    def forward(self, x):
        """Forward pass."""
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x6 = self.down5(x5)
        x7 = self.down6(x6)
        # x7 = self.dac(x7)
        x7 = self.tpp(x7)
        x = self.up(x7, x6)
        y1 = self.outcc0(F.relu(x))
        # print("y1.shape=", y1.shape)
        x = self.up0(x, x5)
        y2 = self.outcc1(F.relu(x))
        # print("y2.shape=", y2.shape)
        x = self.up1(x, x4)
        y3 = self.outcc2(F.relu(x))
        # print("y3.shape=", y3.shape)
        x = self.up2(x, x3)
        y4 = self.outcc3(F.relu(x))
        # print("y4.shape=", y4.shape)
        x = self.up3(x, x2)
        y5 = self.outcc4(F.relu(x))
        # print("y5.shape=", y5.shape)
        x = self.up4(x, x1)
        y = self.outcc(x)
        # print("y.shape=", y.shape)
        output = [y]
        for outp_ele in [y5, y4, y3]:
            output.append(
                F.interpolate(
                    outp_ele, size=y.shape[-1], mode="linear", align_corners=True
                )
            )
        output = torch.stack(output, dim=0)
        if self.use_predictor:
            K, B, C, T = output.shape
            output = output.reshape((-1, C, T))
        return output


class C2F_TCN(Model):
    """
    An implementation of C2F-TCN

    Requires the `"general/len_segment"` parameter to be at least 512
    """

    def __init__(
        self,
        num_classes: int,
        input_dims: dict,
        num_f_maps: int = 128,
        feature_dim: int = None,
        state_dict_path: str = None,
        ssl_constructors: List = None,
        ssl_types: List = None,
        ssl_modules: List = None,
    ):
        input_dims = int(sum([s[0] for s in input_dims.values()]))
        if feature_dim is None:
            feature_dim = num_classes
            self.f_shape = None
            self.params_predictor = None
        else:
            self.f_shape = torch.Size([int(feature_dim)])
            self.params_predictor = {
                "dim": int(feature_dim),
                "num_classes": num_classes,
            }
        self.params = {
            "output_dim": int(feature_dim),
            "n_channels": int(input_dims),
            "num_f_maps": int(float(num_f_maps)),
            "use_predictor": self.f_shape is not None,
        }
        super().__init__(ssl_constructors, ssl_modules, ssl_types, state_dict_path)

    def _feature_extractor(self) -> Union[torch.nn.Module, List]:
        return C2F_TCN_Module(**self.params)

    def _predictor(self) -> torch.nn.Module:
        if self.params_predictor is not None:
            return Predictor(**self.params_predictor)
        else:
            return nn.Identity()

    def features_shape(self) -> Optional[torch.Size]:
        return self.f_shape
class double_conv(nn.Module):
class inconv(nn.Module):
class outconv(nn.Module):
class down(nn.Module):
class up(nn.Module):

Upscaling then double conv
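For orientation, a minimal sketch of how `up` merges a coarse decoder tensor with an encoder skip connection: the first argument is upsampled by a factor of two, padded to the length of the second, and both are concatenated along the channel axis before the double convolution, so `in_channels` must equal the sum of the two inputs' channel counts. The channel counts and lengths below are made up for illustration.

import torch
from dlc2action.model.c2f_tcn import up

# hypothetical sizes: 4 decoder channels + 4 skip channels -> in_channels=8
block = up(in_channels=8, out_channels=16)
decoder = torch.randn(2, 4, 50)   # (batch, channels, frames) at the coarser scale
skip = torch.randn(2, 4, 100)     # encoder features at twice the temporal resolution
out = block(decoder, skip)
print(out.shape)                  # expected: torch.Size([2, 16, 100])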
class TPPblock(nn.Module):
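The temporal pyramid pooling block pools the input at four scales (kernel sizes 2, 3, 5 and 6), projects each pooled map to a single channel, upsamples it back to the original length, and concatenates the result with the input, so the output has `in_channels + 4` channels. A small sketch with illustrative sizes (the input must be at least 6 frames long for the largest pooling kernel):

import torch
from dlc2action.model.c2f_tcn import TPPblock

tpp = TPPblock(in_channels=64)
x = torch.randn(2, 64, 8)   # (batch, channels, frames); 8 frames remain after six halvings of a 512-frame segment
out = tpp(x)
print(out.shape)            # expected: torch.Size([2, 68, 8]) -- 64 original + 4 pyramid channels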
class Predictor(nn.Module):
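`Predictor` applies two pointwise convolutions and then reshapes its output back into four decoder scales, so it expects the flattened `(4 * batch, dim, frames)` tensor that `C2F_TCN_Module` produces when `use_predictor=True`. A minimal sketch with made-up sizes:

import torch
from dlc2action.model.c2f_tcn import Predictor

head = Predictor(dim=16, num_classes=6)
features = torch.randn(4 * 2, 16, 512)   # 4 scales x batch of 2, 16 feature channels, 512 frames
logits = head(features)
print(logits.shape)                      # expected: torch.Size([4, 2, 6, 512])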
class C2F_TCN_Module(nn.Module):

Features are extracted at the last layer of decoder.
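A shape-only smoke test of the full encoder-decoder, with illustrative channel counts. A 512-frame input survives six halvings (512 to 8 frames at the bottleneck), which is still long enough for the largest TPP pooling kernel; this is consistent with the minimum segment length of 512 noted for `C2F_TCN` below. The result stacks the full-resolution prediction with the predictions from the three preceding decoder stages, each upsampled back to the input length.

import torch
from dlc2action.model.c2f_tcn import C2F_TCN_Module

module = C2F_TCN_Module(n_channels=32, output_dim=6, num_f_maps=64)
x = torch.randn(2, 32, 512)   # (batch, input channels, frames)
out = module(x)
print(out.shape)              # expected: torch.Size([4, 2, 6, 512]) -- 4 prediction scales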
class C2F_TCN(Model):
An implementation of C2F-TCN
Requires the "general/len_segment" parameter to be at least 512
C2F_TCN(num_classes: int, input_dims: dict, num_f_maps: int = 128, feature_dim: int = None, state_dict_path: str = None, ssl_constructors: List = None, ssl_types: List = None, ssl_modules: List = None)
Initialize the model.
Parameters
ssl_constructors : list, optional
    a list of SSL constructors that build the necessary SSL modules
ssl_modules : list, optional
    a list of torch.nn.Module instances that will serve as SSL modules
ssl_types : list, optional
    a list of string SSL types
state_dict_path : str, optional
    path to the model state dictionary to load
strict : bool, default False
    when True, the state dictionary will only be loaded if the current and the loaded architecture are the same; otherwise missing or extra keys, as well as shape inconsistencies, are ignored
prompt_function : callable, optional
    a function that takes a list of strings and returns a string prompt
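The parameter list above appears to come from the shared base `Model` constructor; as the source above shows, `C2F_TCN` additionally takes `num_classes`, `input_dims`, `num_f_maps` and `feature_dim`. Only the first element of each value in `input_dims` (the channel count) is used, and the counts are summed into the total number of input channels. A construction sketch with made-up feature groups, assuming the optional SSL arguments can be left at their `None` defaults:

from dlc2action.model.c2f_tcn import C2F_TCN

# hypothetical feature groups: keys are arbitrary names, values are shape-like
# tuples whose first element is the channel count (28 + 4 = 32 input channels)
model = C2F_TCN(
    num_classes=6,
    input_dims={"coords": (28,), "distances": (4,)},
    num_f_maps=128,
    feature_dim=None,  # with None, the decoder predicts num_classes directly and no Predictor head is added
)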
def features_shape(self) -> Optional[torch.Size]

Get the shape of feature extractor output.

Returns

feature_shape : torch.Size
    shape of feature extractor output
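As the constructor source above shows, a non-trivial shape is only reported when a `feature_dim` was requested at construction time; otherwise the predictor is an identity and the method returns `None`. A short sketch with made-up arguments:

from dlc2action.model.c2f_tcn import C2F_TCN

model = C2F_TCN(num_classes=6, input_dims={"coords": (32,)}, feature_dim=256)
print(model.features_shape())   # expected: torch.Size([256])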
Inherited Members
- dlc2action.model.base_model.Model
- process_labels
- feature_extractor
- feature_extractors
- predictor
- ssl_active
- main_task_active
- prompt_function
- class_tensors
- freeze_feature_extractor
- unfreeze_feature_extractor
- load_state_dict
- ssl_off
- ssl_on
- main_task_on
- main_task_off
- set_ssl
- extract_features
- transform_labels
- forward