dlc2action.model.c2f_tcn_par
#
# Copyright 2020-present by A. Mathis Group and contributors. All rights reserved.
#
# This project and all its files are licensed under GNU AGPLv3 or later version. A copy is included in dlc2action/LICENSE.AGPL.
#
# Incorporates code adapted from C2F-TCN by dipika-singhania
# Original work Copyright (c) 2021 dipika-singhania
# Source: https://github.com/dipika-singhania/C2F-TCN
# Originally licensed under MIT License
# Combined work licensed under GNU AGPLv3
#
"""C2F-TCN (coarse-to-fine temporal convolutional network) with per-individual encoding.

:class:`C2F_TCN_P` runs one shared :class:`C2F_TCN_Module` over each individual's
input, concatenates the resulting features along the channel axis and classifies
them with :class:`Predictor`.
"""
from functools import partial
from typing import List, Optional, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from dlc2action.model.base_model import Model

# Not referenced in this module; kept because it is a public module-level name.
nonlinearity = partial(F.relu, inplace=True)


class double_conv(nn.Module):
    """Two ``Conv1d -> BatchNorm1d -> ReLU`` stages.

    Both convolutions use ``kernel_size=5`` and ``padding=2``, so the temporal
    length is preserved.

    Parameters
    ----------
    in_ch : int
        number of input channels
    out_ch : int
        number of output channels
    """

    def __init__(self, in_ch, out_ch):
        super(double_conv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv1d(in_ch, out_ch, kernel_size=5, padding=2),
            nn.BatchNorm1d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv1d(out_ch, out_ch, kernel_size=5, padding=2),
            nn.BatchNorm1d(out_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Forward pass."""
        x = self.conv(x)
        return x


class inconv(nn.Module):
    """Input stem: a single :class:`double_conv` from ``in_ch`` to ``out_ch`` channels."""

    def __init__(self, in_ch, out_ch):
        super(inconv, self).__init__()
        self.conv = double_conv(in_ch, out_ch)

    def forward(self, x):
        """Forward pass."""
        x = self.conv(x)
        return x


class outconv(nn.Module):
    """Output head: a pointwise (``kernel_size=1``) ``Conv1d``."""

    def __init__(self, in_ch, out_ch):
        super(outconv, self).__init__()
        self.conv = nn.Conv1d(in_ch, out_ch, kernel_size=1)

    def forward(self, x):
        """Forward pass."""
        x = self.conv(x)
        return x


class down(nn.Module):
    """Encoder stage: ``MaxPool1d(2)`` (halves the length) followed by :class:`double_conv`."""

    def __init__(self, in_ch, out_ch):
        super(down, self).__init__()
        self.max_pool_conv = nn.Sequential(nn.MaxPool1d(2), double_conv(in_ch, out_ch))

    def forward(self, x):
        """Forward pass."""
        x = self.max_pool_conv(x)
        return x


class up(nn.Module):
    """Upscaling then double conv.

    The coarser map ``x1`` is upsampled by a factor of 2, padded (or cropped)
    to the length of the skip connection ``x2``, concatenated after ``x2``
    along the channel axis and passed through :class:`double_conv`.

    Parameters
    ----------
    in_channels : int
        number of channels after concatenation (``C(x2) + C(x1)``)
    out_channels : int
        number of output channels
    bilinear : bool, default True
        if True, use parameter-free linear interpolation; otherwise a learned
        ``ConvTranspose1d`` that expects ``x1`` to have ``in_channels // 2``
        channels
    """

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()

        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode="linear", align_corners=True)
        else:
            self.up = nn.ConvTranspose1d(
                in_channels // 2, in_channels // 2, kernel_size=2, stride=2
            )

        self.conv = double_conv(in_channels, out_channels)

    def forward(self, x1, x2):
        """Forward pass.

        Parameters
        ----------
        x1 : torch.Tensor
            coarser feature map ``(N, C1, T1)`` to upsample
        x2 : torch.Tensor
            skip-connection feature map ``(N, C2, T2)``

        Returns
        -------
        torch.Tensor
            tensor of shape ``(N, out_channels, T2)``
        """
        x1 = self.up(x1)
        # Inputs are (N, C, T). Pad the upsampled map to the skip connection's
        # length (a negative difference crops); an odd difference goes right.
        # Plain ints avoid allocating a CPU tensor per call and keep F.pad's
        # ``pad`` argument the list of ints it is documented to take.
        diff = x2.size(2) - x1.size(2)
        x1 = F.pad(x1, [diff // 2, diff - diff // 2])
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)


class TPPblock(nn.Module):
    """Temporal pyramid pooling block.

    The input is max-pooled at four scales (kernel = stride = 2, 3, 5, 6). Each
    pooled map is projected to one channel by a shared 1x1 convolution and
    linearly resized back to the input length. The four maps are concatenated
    in front of the input, so the output has ``in_channels + 4`` channels.
    NOTE(review): the largest pool needs an input length of at least 6.
    """

    def __init__(self, in_channels):
        super(TPPblock, self).__init__()
        self.pool1 = nn.MaxPool1d(kernel_size=2, stride=2)
        self.pool2 = nn.MaxPool1d(kernel_size=3, stride=3)
        self.pool3 = nn.MaxPool1d(kernel_size=5, stride=5)
        self.pool4 = nn.MaxPool1d(kernel_size=6, stride=6)

        self.conv = nn.Conv1d(
            in_channels=in_channels, out_channels=1, kernel_size=1, padding=0
        )

    def forward(self, x):
        """Forward pass.

        Parameters
        ----------
        x : torch.Tensor
            input of shape ``(N, in_channels, T)``

        Returns
        -------
        torch.Tensor
            tensor of shape ``(N, in_channels + 4, T)``
        """
        t = x.size(2)
        # Intermediate maps are locals, not module attributes, so activations
        # (and their autograd graph) are not retained between calls.
        pyramid = [
            F.interpolate(self.conv(pool(x)), size=t, mode="linear", align_corners=True)
            for pool in (self.pool1, self.pool2, self.pool3, self.pool4)
        ]
        return torch.cat(pyramid + [x], 1)


class Predictor(nn.Module):
    """Two-layer pointwise-convolution classification head.

    Parameters
    ----------
    dim : int
        number of input channels
    num_classes : int
        number of output classes
    num_outputs : int, default 4
        number of stacked decoder outputs folded into the batch axis of the
        input; 4 matches the stack produced by :class:`C2F_TCN_Module`
    """

    def __init__(self, dim, num_classes, num_outputs=4):
        super(Predictor, self).__init__()
        self.num_classes = num_classes
        self.num_outputs = num_outputs
        self.conv_out_1 = nn.Conv1d(dim, dim, kernel_size=1)
        self.conv_out_2 = nn.Conv1d(dim, num_classes, kernel_size=1)

    def forward(self, x):
        """Forward pass.

        Parameters
        ----------
        x : torch.Tensor
            features of shape ``(num_outputs * B, dim, T)``

        Returns
        -------
        torch.Tensor
            logits of shape ``(num_outputs, B, num_classes, T)``
        """
        x = self.conv_out_1(x)
        x = F.relu(x)
        x = self.conv_out_2(x)
        return x.reshape((self.num_outputs, -1, self.num_classes, x.shape[-1]))


class C2F_TCN_P_Module(nn.Module):
    """Apply one shared :class:`C2F_TCN_Module` to each element of the input.

    Parameters
    ----------
    n_channels : int
        number of input channels per element
    output_dim : int
        number of feature channels produced per element
    num_f_maps : int
        base channel width of the encoder/decoder
    """

    def __init__(self, n_channels, output_dim, num_f_maps):
        super().__init__()
        self.c2f_tcn = C2F_TCN_Module(
            n_channels, output_dim, num_f_maps, use_predictor=True
        )

    def forward(self, x):
        """Forward pass.

        Every item of ``x`` goes through the same ``C2F_TCN_Module``. With
        ``use_predictor=True`` each result is ``(4 * B, output_dim, T)``. The
        results are concatenated along dim 1.
        NOTE(review): items are presumably per-individual ``(B, n_channels, T)``
        tensors; confirm against the caller.
        """
        output = []
        for ind_x in x:
            output.append(self.c2f_tcn(ind_x))
        return torch.cat(output, dim=1)


class C2F_TCN_Module(nn.Module):
    """
    Features are extracted at the last layer of decoder.

    A U-Net-like encoder/decoder. Six ``down`` stages halve the length each time.
    A temporal pyramid pooling block sits at the bottleneck. Six ``up`` stages
    use skip connections. The final output is stacked with three intermediate
    decoder outputs resized to the input length.

    Parameters
    ----------
    n_channels : int
        number of input channels
    output_dim : int
        number of output channels of every output head
    num_f_maps : int
        base channel width
    use_predictor : bool, default False
        if True, fold the output stack into the batch dimension
    """

    def __init__(self, n_channels, output_dim, num_f_maps, use_predictor=False):
        super().__init__()
        self.use_predictor = use_predictor
        self.inc = inconv(n_channels, num_f_maps * 2)
        self.down1 = down(num_f_maps * 2, num_f_maps * 2)
        self.down2 = down(num_f_maps * 2, num_f_maps * 2)
        self.down3 = down(num_f_maps * 2, num_f_maps)
        self.down4 = down(num_f_maps, num_f_maps)
        self.down5 = down(num_f_maps, num_f_maps)
        self.down6 = down(num_f_maps, num_f_maps)
        # +4: the TPP block prepends four single-channel pyramid maps
        self.up = up(num_f_maps * 2 + 4, num_f_maps)
        self.outcc0 = outconv(num_f_maps, output_dim)
        self.up0 = up(num_f_maps * 2, num_f_maps)
        self.outcc1 = outconv(num_f_maps, output_dim)
        self.up1 = up(num_f_maps * 2, num_f_maps)
        self.outcc2 = outconv(num_f_maps, output_dim)
        self.up2 = up(num_f_maps * 3, num_f_maps)
        self.outcc3 = outconv(num_f_maps, output_dim)
        self.up3 = up(num_f_maps * 3, num_f_maps)
        self.outcc4 = outconv(num_f_maps, output_dim)
        self.up4 = up(num_f_maps * 3, num_f_maps)
        self.outcc = outconv(num_f_maps, output_dim)
        self.tpp = TPPblock(num_f_maps)
        # Not used in forward; kept so saved state dicts keep loading.
        self.weights = torch.nn.Parameter(torch.ones(6))

    def forward(self, x):
        """Forward pass.

        Parameters
        ----------
        x : torch.Tensor
            input of shape ``(B, n_channels, T)``. NOTE(review): ``T`` presumably
            needs to be at least ``6 * 2**6`` for the TPP pools; confirm.

        Returns
        -------
        torch.Tensor
            ``(4, B, output_dim, T)``: the final output followed by the outputs
            after ``up3``, ``up2`` and ``up1``, linearly resized to ``T``.
            With ``use_predictor`` it is reshaped to ``(4 * B, output_dim, T)``.
        """
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x6 = self.down5(x5)
        x7 = self.tpp(self.down6(x6))
        # outcc0 / outcc1 remain registered for checkpoint compatibility, but
        # their outputs never reached the result, so they are not evaluated.
        x = self.up(x7, x6)
        x = self.up0(x, x5)
        x = self.up1(x, x4)
        y3 = self.outcc2(F.relu(x))
        x = self.up2(x, x3)
        y4 = self.outcc3(F.relu(x))
        x = self.up3(x, x2)
        y5 = self.outcc4(F.relu(x))
        x = self.up4(x, x1)
        y = self.outcc(x)
        length = y.shape[-1]
        output = [y] + [
            F.interpolate(side, size=length, mode="linear", align_corners=True)
            for side in (y5, y4, y3)
        ]
        output = torch.stack(output, dim=0)
        if self.use_predictor:
            # (K, B, C, T) -> (K * B, C, T)
            output = output.flatten(0, 1)
        return output


class C2F_TCN_P(Model):
    """C2F-TCN model that encodes each individual with shared weights.

    Features from :class:`C2F_TCN_P_Module` are concatenated across
    individuals and classified by :class:`Predictor`.
    """

    def __init__(
        self,
        num_classes,
        input_dims,
        num_f_maps=128,
        feature_dim=None,
        state_dict_path=None,
        ssl_constructors=None,
        ssl_types=None,
        ssl_modules=None,
    ):
        """Initialize the model.

        Parameters
        ----------
        num_classes : int
            number of classes predicted by the head
        input_dims : dict
            mapping from input keys to shapes. Keys are split on ``"---"``.
            Keys without it, and keys whose last part has exactly two
            ``"+"``-separated names, are ignored when counting individuals.
            The per-individual channel count is the sum of ``shape[0]`` over
            the first individual's keys.
        num_f_maps : int, default 128
            base channel width
        feature_dim : int, optional
            features per individual (defaults to ``num_f_maps``)
        state_dict_path, ssl_constructors, ssl_types, ssl_modules
            forwarded to :class:`Model`
        """
        if feature_dim is None:
            feature_dim = num_f_maps
        keys = [
            key
            for key in input_dims.keys()
            if len(key.split("---")) != 1 and len(key.split("---")[-1].split("+")) != 2
        ]
        num_ind = len({key.split("---")[-1] for key in keys})
        key = keys[0]
        ind = key.split("---")[-1]
        input_dims = int(
            sum(v[0] for k, v in input_dims.items() if k.split("---")[-1] == ind)
        )
        self.f_shape = torch.Size([feature_dim * num_ind])
        self.params_predictor = {
            "dim": int(feature_dim * num_ind),
            "num_classes": num_classes,
        }
        self.params = {
            "output_dim": int(feature_dim),
            "n_channels": int(input_dims),
            "num_f_maps": int(num_f_maps),
        }
        # attributes above are set first because the base class may build
        # the feature extractor / predictor from them
        super().__init__(ssl_constructors, ssl_modules, ssl_types, state_dict_path)

    def _feature_extractor(self) -> Union[torch.nn.Module, List]:
        """Build the shared per-individual encoder."""
        return C2F_TCN_P_Module(**self.params)

    def _predictor(self) -> torch.nn.Module:
        """Build the classification head."""
        return Predictor(**self.params_predictor)

    def features_shape(self) -> Optional[torch.Size]:
        """Return the shape of the feature extractor output (``feature_dim * num_ind``)."""
        return self.f_shape
class double_conv(nn.Module):
    """Conv1d(k=5, pad=2) -> BatchNorm1d -> ReLU, applied twice.

    The padding keeps the temporal length unchanged. Only the channel count
    changes, from ``in_ch`` to ``out_ch``.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        stages = []
        for c_in in (in_ch, out_ch):
            stages += [
                nn.Conv1d(c_in, out_ch, kernel_size=5, padding=2),
                nn.BatchNorm1d(out_ch),
                nn.ReLU(inplace=True),
            ]
        # layer order fixes the state-dict keys (conv.0 ... conv.5)
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Forward pass."""
        return self.conv(x)
Base class for all neural network modules.
Your models should also subclass this class.
Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
def __init__(self) -> None:
super().__init__()
self.conv1 = nn.Conv2d(1, 20, 5)
self.conv2 = nn.Conv2d(20, 20, 5)
def forward(self, x):
x = F.relu(self.conv1(x))
return F.relu(self.conv2(x))
Submodules assigned in this way will be registered, and will also have their parameters converted when you call to(), etc.
As per the example above, an __init__() call to the parent class must be made before assignment on the child.
:ivar training: Boolean representing whether this module is in training or evaluation mode.
:vartype training: bool
def __init__(self, in_ch, out_ch):
    super(double_conv, self).__init__()
    # conv(k=5, pad=2) -> BN -> ReLU, twice; the padding keeps the length unchanged
    stages = []
    for c_in in (in_ch, out_ch):
        stages.extend(
            [
                nn.Conv1d(c_in, out_ch, kernel_size=5, padding=2),
                nn.BatchNorm1d(out_ch),
                nn.ReLU(inplace=True),
            ]
        )
    self.conv = nn.Sequential(*stages)
Initialize internal Module state, shared by both nn.Module and ScriptModule.
class inconv(nn.Module):
    """Input stem: a single :class:`double_conv` from ``in_ch`` to ``out_ch`` channels."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = double_conv(in_ch, out_ch)

    def forward(self, x):
        """Forward pass."""
        return self.conv(x)
Base class for all neural network modules.
Your models should also subclass this class.
Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
def __init__(self) -> None:
super().__init__()
self.conv1 = nn.Conv2d(1, 20, 5)
self.conv2 = nn.Conv2d(20, 20, 5)
def forward(self, x):
x = F.relu(self.conv1(x))
return F.relu(self.conv2(x))
Submodules assigned in this way will be registered, and will also have their parameters converted when you call to(), etc.
As per the example above, an __init__() call to the parent class must be made before assignment on the child.
:ivar training: Boolean representing whether this module is in training or evaluation mode.
:vartype training: bool
class outconv(nn.Module):
    """Output head: a pointwise (``kernel_size=1``) ``Conv1d`` from ``in_ch`` to ``out_ch``."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        pointwise = nn.Conv1d(in_ch, out_ch, kernel_size=1)
        self.conv = pointwise

    def forward(self, x):
        """Forward pass."""
        return self.conv(x)
Base class for all neural network modules.
Your models should also subclass this class.
Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
def __init__(self) -> None:
super().__init__()
self.conv1 = nn.Conv2d(1, 20, 5)
self.conv2 = nn.Conv2d(20, 20, 5)
def forward(self, x):
x = F.relu(self.conv1(x))
return F.relu(self.conv2(x))
Submodules assigned in this way will be registered, and will also have their parameters converted when you call to(), etc.
As per the example above, an __init__() call to the parent class must be made before assignment on the child.
:ivar training: Boolean representing whether this module is in training or evaluation mode.
:vartype training: bool
class down(nn.Module):
    """Encoder stage: ``MaxPool1d(2)`` (halves the length), then :class:`double_conv`."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        halve = nn.MaxPool1d(2)
        self.max_pool_conv = nn.Sequential(halve, double_conv(in_ch, out_ch))

    def forward(self, x):
        """Forward pass."""
        return self.max_pool_conv(x)
Base class for all neural network modules.
Your models should also subclass this class.
Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
def __init__(self) -> None:
super().__init__()
self.conv1 = nn.Conv2d(1, 20, 5)
self.conv2 = nn.Conv2d(20, 20, 5)
def forward(self, x):
x = F.relu(self.conv1(x))
return F.relu(self.conv2(x))
Submodules assigned in this way will be registered, and will also have their parameters converted when you call to(), etc.
As per the example above, an __init__() call to the parent class must be made before assignment on the child.
:ivar training: Boolean representing whether this module is in training or evaluation mode.
:vartype training: bool
class up(nn.Module):
    """Upscaling then double conv.

    The coarser map ``x1`` is upsampled by a factor of 2, padded (or cropped)
    to the length of the skip connection ``x2``, concatenated after ``x2``
    along the channel axis and passed through :class:`double_conv`.

    Parameters
    ----------
    in_channels : int
        number of channels after concatenation (``C(x2) + C(x1)``)
    out_channels : int
        number of output channels
    bilinear : bool, default True
        if True, use parameter-free linear interpolation; otherwise a learned
        ``ConvTranspose1d`` that expects ``x1`` to have ``in_channels // 2``
        channels
    """

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()

        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode="linear", align_corners=True)
        else:
            self.up = nn.ConvTranspose1d(
                in_channels // 2, in_channels // 2, kernel_size=2, stride=2
            )

        self.conv = double_conv(in_channels, out_channels)

    def forward(self, x1, x2):
        """Forward pass.

        Parameters
        ----------
        x1 : torch.Tensor
            coarser feature map ``(N, C1, T1)`` to upsample
        x2 : torch.Tensor
            skip-connection feature map ``(N, C2, T2)``

        Returns
        -------
        torch.Tensor
            tensor of shape ``(N, out_channels, T2)``
        """
        x1 = self.up(x1)
        # Inputs are (N, C, T). Pad the upsampled map to the skip connection's
        # length (a negative difference crops); an odd difference goes right.
        # Plain ints avoid allocating a CPU tensor per call and keep F.pad's
        # ``pad`` argument the list of ints it is documented to take.
        diff = x2.size(2) - x1.size(2)
        x1 = F.pad(x1, [diff // 2, diff - diff // 2])
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)
Upscaling then double conv
def __init__(self, in_channels, out_channels, bilinear=True):
    super().__init__()
    # ``bilinear`` picks parameter-free linear interpolation; otherwise a
    # learned transposed convolution doubles the temporal length
    if not bilinear:
        self.up = nn.ConvTranspose1d(
            in_channels // 2, in_channels // 2, kernel_size=2, stride=2
        )
    else:
        self.up = nn.Upsample(scale_factor=2, mode="linear", align_corners=True)
    self.conv = double_conv(in_channels, out_channels)
Initialize internal Module state, shared by both nn.Module and ScriptModule.
class TPPblock(nn.Module):
    """Temporal pyramid pooling block.

    The input is max-pooled at four scales (kernel = stride = 2, 3, 5, 6). Each
    pooled map is projected to one channel by a shared 1x1 convolution and
    linearly resized back to the input length. The four maps are concatenated
    in front of the input, so the output has ``in_channels + 4`` channels.
    NOTE(review): the largest pool needs an input length of at least 6.
    """

    def __init__(self, in_channels):
        super(TPPblock, self).__init__()
        self.pool1 = nn.MaxPool1d(kernel_size=2, stride=2)
        self.pool2 = nn.MaxPool1d(kernel_size=3, stride=3)
        self.pool3 = nn.MaxPool1d(kernel_size=5, stride=5)
        self.pool4 = nn.MaxPool1d(kernel_size=6, stride=6)

        self.conv = nn.Conv1d(
            in_channels=in_channels, out_channels=1, kernel_size=1, padding=0
        )

    def forward(self, x):
        """Forward pass.

        Parameters
        ----------
        x : torch.Tensor
            input of shape ``(N, in_channels, T)``

        Returns
        -------
        torch.Tensor
            tensor of shape ``(N, in_channels + 4, T)``
        """
        t = x.size(2)
        # Intermediate maps are locals, not module attributes, so activations
        # (and their autograd graph) are not retained between calls.
        # F.interpolate replaces the deprecated F.upsample (same computation).
        pyramid = [
            F.interpolate(self.conv(pool(x)), size=t, mode="linear", align_corners=True)
            for pool in (self.pool1, self.pool2, self.pool3, self.pool4)
        ]
        return torch.cat(pyramid + [x], 1)
Base class for all neural network modules.
Your models should also subclass this class.
Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
def __init__(self) -> None:
super().__init__()
self.conv1 = nn.Conv2d(1, 20, 5)
self.conv2 = nn.Conv2d(20, 20, 5)
def forward(self, x):
x = F.relu(self.conv1(x))
return F.relu(self.conv2(x))
Submodules assigned in this way will be registered, and will also have their parameters converted when you call to(), etc.
As per the example above, an __init__() call to the parent class must be made before assignment on the child.
:ivar training: Boolean representing whether this module is in training or evaluation mode.
:vartype training: bool
def __init__(self, in_channels):
    super(TPPblock, self).__init__()
    # four non-overlapping max-pool pyramid levels (kernel == stride)
    self.pool1, self.pool2, self.pool3, self.pool4 = (
        nn.MaxPool1d(kernel_size=k, stride=k) for k in (2, 3, 5, 6)
    )
    # shared pointwise projection of each pooled map to a single channel
    self.conv = nn.Conv1d(
        in_channels=in_channels, out_channels=1, kernel_size=1, padding=0
    )
Initialize internal Module state, shared by both nn.Module and ScriptModule.
def forward(self, x):
    """Forward pass.

    Parameters
    ----------
    x : torch.Tensor
        input of shape ``(N, in_channels, T)``

    Returns
    -------
    torch.Tensor
        the four single-channel pyramid maps (resized to ``T``) followed by
        ``x``, i.e. shape ``(N, in_channels + 4, T)``
    """
    t = x.size(2)
    # Locals instead of module attributes: activations (and their autograd
    # graph) are not retained between calls, and forward has no side effects.
    # F.interpolate replaces the deprecated F.upsample (same computation).
    pyramid = [
        F.interpolate(self.conv(pool(x)), size=t, mode="linear", align_corners=True)
        for pool in (self.pool1, self.pool2, self.pool3, self.pool4)
    ]
    return torch.cat(pyramid + [x], 1)
Forward pass.
class Predictor(nn.Module):
    """Two-layer pointwise-convolution classification head.

    Parameters
    ----------
    dim : int
        number of input channels
    num_classes : int
        number of output classes
    num_outputs : int, default 4
        number of stacked decoder outputs folded into the batch axis of the
        input; 4 matches the stack produced by ``C2F_TCN_Module``
    """

    def __init__(self, dim, num_classes, num_outputs=4):
        super(Predictor, self).__init__()
        self.num_classes = num_classes
        self.num_outputs = num_outputs
        self.conv_out_1 = nn.Conv1d(dim, dim, kernel_size=1)
        self.conv_out_2 = nn.Conv1d(dim, num_classes, kernel_size=1)

    def forward(self, x):
        """Forward pass.

        Parameters
        ----------
        x : torch.Tensor
            features of shape ``(num_outputs * B, dim, T)``

        Returns
        -------
        torch.Tensor
            logits of shape ``(num_outputs, B, num_classes, T)``
        """
        x = self.conv_out_1(x)
        x = F.relu(x)
        x = self.conv_out_2(x)
        return x.reshape((self.num_outputs, -1, self.num_classes, x.shape[-1]))
Base class for all neural network modules.
Your models should also subclass this class.
Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
def __init__(self) -> None:
super().__init__()
self.conv1 = nn.Conv2d(1, 20, 5)
self.conv2 = nn.Conv2d(20, 20, 5)
def forward(self, x):
x = F.relu(self.conv1(x))
return F.relu(self.conv2(x))
Submodules assigned in this way will be registered, and will also have their parameters converted when you call to(), etc.
As per the example above, an __init__() call to the parent class must be made before assignment on the child.
:ivar training: Boolean representing whether this module is in training or evaluation mode.
:vartype training: bool
def __init__(self, dim, num_classes):
    super(Predictor, self).__init__()
    self.num_classes = num_classes
    # pointwise conv (dim -> dim), then pointwise conv (dim -> num_classes)
    hidden = nn.Conv1d(dim, dim, kernel_size=1)
    self.conv_out_1 = hidden
    self.conv_out_2 = nn.Conv1d(dim, num_classes, kernel_size=1)
Initialize internal Module state, shared by both nn.Module and ScriptModule.
class C2F_TCN_P_Module(nn.Module):
    """Apply one shared ``C2F_TCN_Module`` to each item of the input and concatenate.

    Parameters
    ----------
    n_channels : int
        number of input channels per item
    output_dim : int
        number of feature channels produced per item
    num_f_maps : int
        base channel width of the encoder/decoder
    """

    def __init__(self, n_channels, output_dim, num_f_maps):
        super().__init__()
        self.c2f_tcn = C2F_TCN_Module(
            n_channels=n_channels,
            output_dim=output_dim,
            num_f_maps=num_f_maps,
            use_predictor=True,
        )

    def forward(self, x):
        """Forward pass.

        Each item of ``x`` goes through the same ``C2F_TCN_Module`` and the
        results are concatenated along dim 1.
        NOTE(review): items are presumably per-individual ``(B, n_channels, T)``
        tensors; confirm against the caller.
        """
        return torch.cat([self.c2f_tcn(item) for item in x], dim=1)
Base class for all neural network modules.
Your models should also subclass this class.
Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
def __init__(self) -> None:
super().__init__()
self.conv1 = nn.Conv2d(1, 20, 5)
self.conv2 = nn.Conv2d(20, 20, 5)
def forward(self, x):
x = F.relu(self.conv1(x))
return F.relu(self.conv2(x))
Submodules assigned in this way will be registered, and will also have their parameters converted when you call to(), etc.
As per the example above, an __init__() call to the parent class must be made before assignment on the child.
:ivar training: Boolean representing whether this module is in training or evaluation mode.
:vartype training: bool
def __init__(self, n_channels, output_dim, num_f_maps):
    super().__init__()
    # one encoder-decoder shared by every item; use_predictor=True makes it
    # fold its stacked outputs into the batch dimension
    self.c2f_tcn = C2F_TCN_Module(
        n_channels=n_channels,
        output_dim=output_dim,
        num_f_maps=num_f_maps,
        use_predictor=True,
    )
Initialize internal Module state, shared by both nn.Module and ScriptModule.
class C2F_TCN_Module(nn.Module):
    """
    Features are extracted at the last layer of decoder.

    A U-Net-like encoder/decoder. Six ``down`` stages halve the length each time.
    A temporal pyramid pooling block sits at the bottleneck. Six ``up`` stages
    use skip connections. The final output is stacked with three intermediate
    decoder outputs resized to the input length.

    Parameters
    ----------
    n_channels : int
        number of input channels
    output_dim : int
        number of output channels of every output head
    num_f_maps : int
        base channel width
    use_predictor : bool, default False
        if True, fold the output stack into the batch dimension
    """

    def __init__(self, n_channels, output_dim, num_f_maps, use_predictor=False):
        super().__init__()
        self.use_predictor = use_predictor
        self.inc = inconv(n_channels, num_f_maps * 2)
        self.down1 = down(num_f_maps * 2, num_f_maps * 2)
        self.down2 = down(num_f_maps * 2, num_f_maps * 2)
        self.down3 = down(num_f_maps * 2, num_f_maps)
        self.down4 = down(num_f_maps, num_f_maps)
        self.down5 = down(num_f_maps, num_f_maps)
        self.down6 = down(num_f_maps, num_f_maps)
        # +4: the TPP block prepends four single-channel pyramid maps
        self.up = up(num_f_maps * 2 + 4, num_f_maps)
        self.outcc0 = outconv(num_f_maps, output_dim)
        self.up0 = up(num_f_maps * 2, num_f_maps)
        self.outcc1 = outconv(num_f_maps, output_dim)
        self.up1 = up(num_f_maps * 2, num_f_maps)
        self.outcc2 = outconv(num_f_maps, output_dim)
        self.up2 = up(num_f_maps * 3, num_f_maps)
        self.outcc3 = outconv(num_f_maps, output_dim)
        self.up3 = up(num_f_maps * 3, num_f_maps)
        self.outcc4 = outconv(num_f_maps, output_dim)
        self.up4 = up(num_f_maps * 3, num_f_maps)
        self.outcc = outconv(num_f_maps, output_dim)
        self.tpp = TPPblock(num_f_maps)
        # Not used in forward; kept so saved state dicts keep loading.
        self.weights = torch.nn.Parameter(torch.ones(6))

    def forward(self, x):
        """Forward pass.

        Parameters
        ----------
        x : torch.Tensor
            input of shape ``(B, n_channels, T)``

        Returns
        -------
        torch.Tensor
            ``(4, B, output_dim, T)``: the final output followed by the outputs
            after ``up3``, ``up2`` and ``up1``, linearly resized to ``T``.
            With ``use_predictor`` it is reshaped to ``(4 * B, output_dim, T)``.
        """
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x6 = self.down5(x5)
        x7 = self.tpp(self.down6(x6))
        # outcc0 / outcc1 remain registered for checkpoint compatibility, but
        # their outputs never reached the result, so they are not evaluated.
        x = self.up(x7, x6)
        x = self.up0(x, x5)
        x = self.up1(x, x4)
        y3 = self.outcc2(F.relu(x))
        x = self.up2(x, x3)
        y4 = self.outcc3(F.relu(x))
        x = self.up3(x, x2)
        y5 = self.outcc4(F.relu(x))
        x = self.up4(x, x1)
        y = self.outcc(x)
        length = y.shape[-1]
        # F.interpolate replaces the deprecated F.upsample (same computation)
        output = [y] + [
            F.interpolate(side, size=length, mode="linear", align_corners=True)
            for side in (y5, y4, y3)
        ]
        output = torch.stack(output, dim=0)
        if self.use_predictor:
            # (K, B, C, T) -> (K * B, C, T)
            output = output.flatten(0, 1)
        return output
self.up4(x, x1) 222 y = self.outcc(x) 223 # print("y.shape=", y.shape) 224 output = [y] 225 for outp_ele in [y5, y4, y3]: 226 output.append( 227 F.upsample( 228 outp_ele, size=y.shape[-1], mode="linear", align_corners=True 229 ) 230 ) 231 output = torch.stack(output, dim=0) 232 if self.use_predictor: 233 K, B, C, T = output.shape 234 output = output.reshape((-1, C, T)) 235 return output
Features are extracted at the last layer of decoder.
def __init__(self, n_channels, output_dim, num_f_maps, use_predictor=False):
    super().__init__()
    self.use_predictor = use_predictor
    wide = num_f_maps * 2  # width of the first encoder stages
    # encoder: every ``down`` stage halves the temporal length
    self.inc = inconv(n_channels, wide)
    self.down1 = down(wide, wide)
    self.down2 = down(wide, wide)
    self.down3 = down(wide, num_f_maps)
    self.down4 = down(num_f_maps, num_f_maps)
    self.down5 = down(num_f_maps, num_f_maps)
    self.down6 = down(num_f_maps, num_f_maps)
    # decoder stages interleaved with their output heads; creation order is
    # kept because it fixes the parameter-initialisation order
    self.up = up(wide + 4, num_f_maps)  # +4 pyramid maps from the TPP block
    self.outcc0 = outconv(num_f_maps, output_dim)
    self.up0 = up(wide, num_f_maps)
    self.outcc1 = outconv(num_f_maps, output_dim)
    self.up1 = up(wide, num_f_maps)
    self.outcc2 = outconv(num_f_maps, output_dim)
    self.up2 = up(num_f_maps * 3, num_f_maps)
    self.outcc3 = outconv(num_f_maps, output_dim)
    self.up3 = up(num_f_maps * 3, num_f_maps)
    self.outcc4 = outconv(num_f_maps, output_dim)
    self.up4 = up(num_f_maps * 3, num_f_maps)
    self.outcc = outconv(num_f_maps, output_dim)
    self.tpp = TPPblock(num_f_maps)
    self.weights = torch.nn.Parameter(torch.ones(6))
Initialize internal Module state, shared by both nn.Module and ScriptModule.
def forward(self, x):
    """Forward pass.

    Parameters
    ----------
    x : torch.Tensor
        input of shape ``(B, n_channels, T)``

    Returns
    -------
    torch.Tensor
        ``(4, B, output_dim, T)``: the final output followed by the outputs
        after ``up3``, ``up2`` and ``up1``, linearly resized to ``T``.
        With ``use_predictor`` it is reshaped to ``(4 * B, output_dim, T)``.
    """
    x1 = self.inc(x)
    x2 = self.down1(x1)
    x3 = self.down2(x2)
    x4 = self.down3(x3)
    x5 = self.down4(x4)
    x6 = self.down5(x5)
    x7 = self.tpp(self.down6(x6))
    # outcc0 / outcc1 outputs never reached the result, so they are not evaluated
    x = self.up(x7, x6)
    x = self.up0(x, x5)
    x = self.up1(x, x4)
    y3 = self.outcc2(F.relu(x))
    x = self.up2(x, x3)
    y4 = self.outcc3(F.relu(x))
    x = self.up3(x, x2)
    y5 = self.outcc4(F.relu(x))
    x = self.up4(x, x1)
    y = self.outcc(x)
    length = y.shape[-1]
    # F.interpolate replaces the deprecated F.upsample (same computation)
    output = [y] + [
        F.interpolate(side, size=length, mode="linear", align_corners=True)
        for side in (y5, y4, y3)
    ]
    output = torch.stack(output, dim=0)
    if self.use_predictor:
        # (K, B, C, T) -> (K * B, C, T)
        output = output.flatten(0, 1)
    return output
Forward pass.
class C2F_TCN_P(Model):
    """C2F-TCN model that encodes each individual with shared weights.

    Features from ``C2F_TCN_P_Module`` are concatenated across individuals and
    classified by ``Predictor``.
    """

    def __init__(
        self,
        num_classes,
        input_dims,
        num_f_maps=128,
        feature_dim=None,
        state_dict_path=None,
        ssl_constructors=None,
        ssl_types=None,
        ssl_modules=None,
    ):
        feature_dim = num_f_maps if feature_dim is None else feature_dim
        # keys are "<name>---<individual>"; suffixes with exactly one "+" are skipped
        individual_keys = [
            name
            for name in input_dims
            if "---" in name and name.rpartition("---")[2].count("+") != 1
        ]
        num_ind = len({name.rpartition("---")[2] for name in individual_keys})
        first_ind = individual_keys[0].rpartition("---")[2]
        n_channels = int(
            sum(
                shape[0]
                for name, shape in input_dims.items()
                if name.rpartition("---")[2] == first_ind
            )
        )
        self.f_shape = torch.Size([feature_dim * num_ind])
        self.params_predictor = {
            "dim": int(feature_dim * num_ind),
            "num_classes": num_classes,
        }
        self.params = {
            "output_dim": int(feature_dim),
            "n_channels": int(n_channels),
            "num_f_maps": int(num_f_maps),
        }
        super().__init__(ssl_constructors, ssl_modules, ssl_types, state_dict_path)

    def _feature_extractor(self) -> Union[torch.nn.Module, List]:
        """Build the shared per-individual encoder."""
        return C2F_TCN_P_Module(**self.params)

    def _predictor(self) -> torch.nn.Module:
        """Build the classification head."""
        return Predictor(**self.params_predictor)

    def features_shape(self) -> Optional[torch.Size]:
        """Return the feature-extractor output shape (``feature_dim * num_ind``)."""
        return self.f_shape
Base class for all models.
Manages interaction of base model and SSL modules + ensures consistent input and output format.
def __init__(
    self,
    num_classes,
    input_dims,
    num_f_maps=128,
    feature_dim=None,
    state_dict_path=None,
    ssl_constructors=None,
    ssl_types=None,
    ssl_modules=None,
):
    feature_dim = num_f_maps if feature_dim is None else feature_dim
    # keys are "<name>---<individual>"; suffixes with exactly one "+" are skipped
    individual_keys = [
        name
        for name in input_dims
        if "---" in name and name.rpartition("---")[2].count("+") != 1
    ]
    num_ind = len({name.rpartition("---")[2] for name in individual_keys})
    first_ind = individual_keys[0].rpartition("---")[2]
    # per-individual channel count: sum of shape[0] over the first individual's keys
    n_channels = int(
        sum(
            shape[0]
            for name, shape in input_dims.items()
            if name.rpartition("---")[2] == first_ind
        )
    )
    self.f_shape = torch.Size([feature_dim * num_ind])
    self.params_predictor = {
        "dim": int(feature_dim * num_ind),
        "num_classes": num_classes,
    }
    self.params = {
        "output_dim": int(feature_dim),
        "n_channels": int(n_channels),
        "num_f_maps": int(num_f_maps),
    }
    super().__init__(ssl_constructors, ssl_modules, ssl_types, state_dict_path)
Initialize the model.
Parameters
ssl_constructors : list, optional: a list of SSL constructors that build the necessary SSL modules
ssl_modules : list, optional: a list of torch.nn.Module instances that will serve as SSL modules
ssl_types : list, optional: a list of string SSL types
state_dict_path : str, optional: path to the model state dictionary to load
strict : bool, default False: when True, the state dictionary will only be loaded if the current and the loaded architecture are the same; otherwise missing or extra keys, as well as shape inconsistencies, are ignored
prompt_function : callable, optional: a function that takes a list of strings and returns a string prompt
Get the shape of feature extractor output.
Returns
feature_shape : torch.Size shape of feature extractor output
Inherited Members
- dlc2action.model.base_model.Model
- process_labels
- feature_extractor
- feature_extractors
- predictor
- ssl_active
- main_task_active
- prompt_function
- class_tensors
- freeze_feature_extractor
- unfreeze_feature_extractor
- load_state_dict
- ssl_off
- ssl_on
- main_task_on
- main_task_off
- set_ssl
- extract_features
- transform_labels
- forward