dlc2action.model.base_model
Abstract parent class for models used in dlc2action.task.universal_task.Task.
#
# Copyright 2020-present by A. Mathis Group and contributors. All rights reserved.
#
# This project and all its files are licensed under GNU AGPLv3 or later version.
# A copy is included in dlc2action/LICENSE.AGPL.
#
"""Abstract parent class for models used in `dlc2action.task.universal_task.Task`."""

import copy
import warnings
from abc import ABC, abstractmethod
from collections.abc import Iterable
from typing import Callable, Dict, List, Union

import torch
from torch import nn

available_ssl_types = [
    "ssl_input",
    "ssl_target",
    "contrastive",
    "none",
    "contrastive_2layers",
]


class Model(nn.Module, ABC):
    """Base class for all models.

    Manages interaction of base model and SSL modules + ensures consistent input and output format.
    """

    process_labels = False

    def __init__(
        self,
        ssl_constructors: List = None,
        ssl_modules: List = None,
        ssl_types: List = None,
        state_dict_path: str = None,
        strict: bool = False,
        prompt_function: Callable = None,
    ) -> None:
        """Initialize the model.

        Parameters
        ----------
        ssl_constructors : list, optional
            a list of SSL constructors that build the necessary SSL modules
        ssl_modules : list, optional
            a list of torch.nn.Module instances that will serve as SSL modules
        ssl_types : list, optional
            a list of string SSL types
        state_dict_path : str, optional
            path to the model state dictionary to load
        strict : bool, default False
            when True, the state dictionary will only be loaded if the current and the loaded architecture are the same;
            otherwise missing or extra keys, as well as shape inconsistencies, are ignored
        prompt_function : callable, optional
            a function that takes a list of strings and returns a string prompt

        """
        super(Model, self).__init__()
        feature_extractors = self._feature_extractor()
        if not isinstance(feature_extractors, list):
            feature_extractors = [feature_extractors]
        self.feature_extractor = feature_extractors[0]
        self.feature_extractors = nn.ModuleList(feature_extractors[1:])
        self.predictor = self._predictor()
        self.set_ssl(ssl_constructors, ssl_types, ssl_modules)
        self.ssl_active = True
        self.main_task_active = True
        self.prompt_function = prompt_function
        self.class_tensors = None
        if state_dict_path is not None:
            self.load_state_dict(torch.load(state_dict_path), strict=strict)
        # self.feature_extractors = nn.ModuleList([nn.DataParallel(x) for x in self.feature_extractors])
        # self.predictor = nn.DataParallel(self.predictor)
        # if self.ssl != [None]:
        #     self.ssl = nn.ModuleList([nn.DataParallel(x) for x in self.ssl])

    # def to(self, device, *args, **kwargs):
    #     if self.class_tensors is not None:
    #         self.class_tensors = {
    #             k: v.to(device) for k, v in self.class_tensors.items()
    #         }
    #     return super().to(device, *args, **kwargs)

    def freeze_feature_extractor(self) -> None:
        """Freeze the parameters of the feature extraction module."""
        for param in self.feature_extractor.parameters():
            param.requires_grad = False

    def unfreeze_feature_extractor(self) -> None:
        """Unfreeze the parameters of the feature extraction module."""
        for param in self.feature_extractor.parameters():
            param.requires_grad = True

    def load_state_dict(self, state_dict: Dict, strict: bool = True) -> None:
        """Load a model state dictionary.

        Parameters
        ----------
        state_dict : dict
            the state dictionary to load (e.g. the output of torch.load)
        strict : bool, default True
            when True, the state dictionary will only be loaded if the current and the loaded architecture are the same;
            otherwise missing or extra keys, as well as shape inconsistencies, are ignored

        """
        try:
            super().load_state_dict(state_dict, strict)
        except RuntimeError as e:
            if strict:
                raise e
            else:
                warnings.warn(
                    "Some of the layer shapes do not match the loaded state dictionary, skipping those"
                )
                own_state = self.state_dict()
                for name, param in state_dict.items():
                    if name not in own_state:
                        continue
                    if isinstance(param, nn.Parameter):
                        # backwards compatibility for serialized parameters
                        param = param.data
                    try:
                        own_state[name].copy_(param)
                    except:
                        pass

    def ssl_off(self) -> None:
        """Turn SSL off (SSL output will not be computed by the forward function)."""
        self.ssl_active = False

    def ssl_on(self) -> None:
        """Turn SSL on (SSL output will be computed by the forward function)."""
        self.ssl_active = True

    def main_task_on(self) -> None:
        """Turn main task training on."""
        self.main_task_active = True

    def main_task_off(self) -> None:
        """Turn main task training off."""
        self.main_task_active = False

    def set_ssl(
        self,
        ssl_constructors: List = None,
        ssl_types: List = None,
        ssl_modules: List = None,
    ) -> None:
        """Set the SSL modules."""
        if ssl_constructors is None and ssl_types is None:
            self.ssl_type = ["none"]
            self.ssl = [None]
        else:
            if ssl_constructors is not None:
                ssl_types = [
                    ssl_constructor.type for ssl_constructor in ssl_constructors
                ]
                ssl_modules = [
                    ssl_constructor.construct_module()
                    for ssl_constructor in ssl_constructors
                ]
            if not isinstance(ssl_types, Iterable):
                ssl_types = [ssl_types]
                ssl_modules = [ssl_modules]
            for t in ssl_types:
                if t not in available_ssl_types:
                    raise ValueError(
                        f"SSL type {t} is not implemented yet, please choose from {available_ssl_types}"
                    )
            self.ssl_type = ssl_types
            self.ssl = nn.ModuleList(ssl_modules)

    @abstractmethod
    def _feature_extractor(self) -> Union[torch.nn.Module, List]:
        """Construct the feature extractor module.

        Returns
        -------
        feature_extractor : torch.nn.Module
            an instance of torch.nn.Module that has a forward method receiving input and
            returning features that can be passed to an SSL module and to a prediction module

        """

    @abstractmethod
    def _predictor(self) -> torch.nn.Module:
        """Construct the predictor module.

        Returns
        -------
        predictor : torch.nn.Module
            an instance of torch.nn.Module that has a forward method receiving features
            extracted by self.feature_extractor and returning a prediction

        """

    @abstractmethod
    def features_shape(self) -> torch.Size:
        """Get the shape of feature extractor output.

        Returns
        -------
        feature_shape : torch.Size
            shape of feature extractor output

        """

    def extract_features(self, x, start=0):
        """Apply the feature extraction modules consecutively.

        Parameters
        ----------
        x : torch.Tensor
            the input tensor
        start : int, default 0
            the index of the feature extraction module to start with

        Returns
        -------
        output : torch.Tensor
            the output tensor

        """
        if start == 0:
            x = self.feature_extractor(x)
        for extractor in self.feature_extractors[max(0, start - 1) :]:
            x = extractor(x)
        return x

    def _extract_features_first(self, x):
        """Extract features from the first feature extractor."""
        return self.feature_extractor(x)

    def transform_labels(self, device):
        """Transform the labels into a tensor of shape (1, n_classes, n_features)."""
        return {
            k: self.transform_label(v.to(device)).mean(0).unsqueeze(0)
            for k, v in self.class_tensors.items()
        }

    def forward(
        self,
        x: torch.Tensor,
        ssl_xs: list,
        tag: torch.Tensor = None,
    ) -> tuple[torch.Tensor, list]:
        """Generate a prediction for x.

        Parameters
        ----------
        x : torch.Tensor
            the main input
        ssl_xs : list
            a list of SSL input tensors
        tag : any, optional
            a meta information tag

        Returns
        -------
        prediction : torch.Tensor
            prediction for the main input
        ssl_out : list
            a list of SSL prediction tensors

        """
        ssl_out = None
        features_0 = self._extract_features_first(x)
        if len(self.feature_extractors) > 1:
            features = copy.copy(features_0)
            features = self.extract_features(features, start=1)
        else:
            features = features_0
        if self.ssl_active:
            ssl_out = []
            for ssl, ssl_x, ssl_type in zip(self.ssl, ssl_xs, self.ssl_type):
                if ssl_type in ["contrastive", "ssl_input", "contrastive_2layers"]:
                    ssl_features = self.extract_features(ssl_x)
                if ssl_type == "ssl_input":
                    if ssl_features.shape[0] > ssl_x.shape[0]:
                        if ssl_features.shape[0] % ssl_x.shape[0] != 0:
                            raise ValueError(
                                "The length of the SSL input tensor must be a multiple of the main input tensor"
                            )
                        # This is for C2F-TCN, where only the first (which is the last) feature set is kept
                        ssl_features = ssl_features[: ssl_x.shape[0]]
                    ssl_out.append(ssl(ssl_features))
                elif ssl_type == "contrastive_2layers":
                    ssl_out.append(
                        (ssl(features_0, extract_features=False), ssl(ssl_features))
                    )
                elif ssl_type == "contrastive":
                    ssl_out.append((ssl(features), ssl(ssl_features)))
                elif ssl_type == "ssl_target":
                    ssl_out.append(ssl(features))
        args = [features]
        if self.main_task_active:
            x = self.predictor(*args)
        else:
            x = None
        return x, ssl_out


class LoadedModel(Model):
    """A class to generate a Model instance from a torch.nn.Module."""

    ssl_types = ["none"]

    def __init__(self, model: nn.Module, **kwargs) -> None:
        """Initialize the model.

        Parameters
        ----------
        model : torch.nn.Module
            a model with a forward function that takes a single tensor as input and returns a single tensor as output

        """
        super(LoadedModel, self).__init__()
        self.ssl_active = False
        self.feature_extractor = model

    def _feature_extractor(self) -> None:
        """Set feature extractor."""
        pass

    def _predictor(self) -> None:
        """Set predictor."""
        self.predictor = nn.Identity()

    def ssl_on(self):
        """Turn SSL on (SSL output will be computed by the forward function)."""
        pass
Base class for all models.
Manages interaction of base model and SSL modules + ensures consistent input and output format.
def __init__(
    self,
    ssl_constructors: List = None,
    ssl_modules: List = None,
    ssl_types: List = None,
    state_dict_path: str = None,
    strict: bool = False,
    prompt_function: Callable = None,
) -> None:
    """Initialize the model.

    Parameters
    ----------
    ssl_constructors : list, optional
        a list of SSL constructors that build the necessary SSL modules
    ssl_modules : list, optional
        a list of torch.nn.Module instances that will serve as SSL modules
    ssl_types : list, optional
        a list of string SSL types
    state_dict_path : str, optional
        path to the model state dictionary to load
    strict : bool, default False
        when True, the state dictionary will only be loaded if the current and the loaded architecture are the same;
        otherwise missing or extra keys, as well as shape inconsistencies, are ignored
    prompt_function : callable, optional
        a function that takes a list of strings and returns a string prompt

    """
    super(Model, self).__init__()
    feature_extractors = self._feature_extractor()
    if not isinstance(feature_extractors, list):
        feature_extractors = [feature_extractors]
    self.feature_extractor = feature_extractors[0]
    self.feature_extractors = nn.ModuleList(feature_extractors[1:])
    self.predictor = self._predictor()
    self.set_ssl(ssl_constructors, ssl_types, ssl_modules)
    self.ssl_active = True
    self.main_task_active = True
    self.prompt_function = prompt_function
    self.class_tensors = None
    if state_dict_path is not None:
        self.load_state_dict(torch.load(state_dict_path), strict=strict)
    # self.feature_extractors = nn.ModuleList([nn.DataParallel(x) for x in self.feature_extractors])
    # self.predictor = nn.DataParallel(self.predictor)
    # if self.ssl != [None]:
    #     self.ssl = nn.ModuleList([nn.DataParallel(x) for x in self.ssl])
Initialize the model.
Parameters
ssl_constructors : list, optional
    a list of SSL constructors that build the necessary SSL modules
ssl_modules : list, optional
    a list of torch.nn.Module instances that will serve as SSL modules
ssl_types : list, optional
    a list of string SSL types
state_dict_path : str, optional
    path to the model state dictionary to load
strict : bool, default False
    when True, the state dictionary will only be loaded if the current and the loaded architecture are the same;
    otherwise missing or extra keys, as well as shape inconsistencies, are ignored
prompt_function : callable, optional
    a function that takes a list of strings and returns a string prompt
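As a rough illustration of how the constructor and the abstract methods fit together, here is a minimal hypothetical subclass; the layer sizes and names are made up and only serve as a sketch:

    import torch
    from torch import nn
    from dlc2action.model.base_model import Model


    class TinyModel(Model):
        """Hypothetical minimal subclass: a 1D-conv feature extractor and a conv predictor."""

        def _feature_extractor(self) -> nn.Module:
            # maps (batch, 10 input channels, frames) -> (batch, 64 feature channels, frames)
            return nn.Conv1d(in_channels=10, out_channels=64, kernel_size=3, padding=1)

        def _predictor(self) -> nn.Module:
            # maps features to per-frame scores for 5 hypothetical classes
            return nn.Conv1d(in_channels=64, out_channels=5, kernel_size=1)

        def features_shape(self) -> torch.Size:
            # (channels, frames) of one feature sample; values are arbitrary here
            return torch.Size([64, 128])


    model = TinyModel()  # no SSL constructors -> ssl_type == ["none"]
    prediction, ssl_out = model(torch.randn(2, 10, 128), ssl_xs=[])
    print(prediction.shape)  # torch.Size([2, 5, 128])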
def freeze_feature_extractor(self) -> None:
    """Freeze the parameters of the feature extraction module."""
    for param in self.feature_extractor.parameters():
        param.requires_grad = False
Freeze the parameters of the feature extraction module.
def unfreeze_feature_extractor(self) -> None:
    """Unfreeze the parameters of the feature extraction module."""
    for param in self.feature_extractor.parameters():
        param.requires_grad = True
Unfreeze the parameters of the feature extraction module.
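For instance, a transfer-learning style setup might freeze the extractor while only the predictor is updated (a sketch, continuing the hypothetical TinyModel above):

    model = TinyModel()
    model.freeze_feature_extractor()
    optimizer = torch.optim.Adam(
        [p for p in model.parameters() if p.requires_grad], lr=1e-3
    )  # only the predictor's weights are updated
    # ... train the head ...
    model.unfreeze_feature_extractor()  # then fine-tune end to end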
def load_state_dict(self, state_dict: Dict, strict: bool = True) -> None:
    """Load a model state dictionary.

    Parameters
    ----------
    state_dict : dict
        the state dictionary to load (e.g. the output of torch.load)
    strict : bool, default True
        when True, the state dictionary will only be loaded if the current and the loaded architecture are the same;
        otherwise missing or extra keys, as well as shape inconsistencies, are ignored

    """
    try:
        super().load_state_dict(state_dict, strict)
    except RuntimeError as e:
        if strict:
            raise e
        else:
            warnings.warn(
                "Some of the layer shapes do not match the loaded state dictionary, skipping those"
            )
            own_state = self.state_dict()
            for name, param in state_dict.items():
                if name not in own_state:
                    continue
                if isinstance(param, nn.Parameter):
                    # backwards compatibility for serialized parameters
                    param = param.data
                try:
                    own_state[name].copy_(param)
                except:
                    pass
Load a model state dictionary.
Parameters
state_dict : dict
    the state dictionary to load (e.g. the output of torch.load)
strict : bool, default True
    when True, the state dictionary will only be loaded if the current and the loaded architecture are the same;
    otherwise missing or extra keys, as well as shape inconsistencies, are ignored
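A non-strict load can be used to warm-start a model whose head differs from a checkpoint; a sketch, assuming a hypothetical file saved earlier with torch.save(model.state_dict(), "checkpoint.pt"):

    model = TinyModel()
    state = torch.load("checkpoint.pt", map_location="cpu")  # hypothetical path
    model.load_state_dict(state, strict=False)  # mismatched or missing layers are skipped with a warning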
def ssl_off(self) -> None:
    """Turn SSL off (SSL output will not be computed by the forward function)."""
    self.ssl_active = False
Turn SSL off (SSL output will not be computed by the forward function).
def ssl_on(self) -> None:
    """Turn SSL on (SSL output will be computed by the forward function)."""
    self.ssl_active = True
Turn SSL on (SSL output will be computed by the forward function).
def main_task_on(self) -> None:
    """Turn main task training on."""
    self.main_task_active = True
Turn main task training on.
def main_task_off(self) -> None:
    """Turn main task training off."""
    self.main_task_active = False
Turn main task training off.
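These switches are typically used together, for example to alternate between an SSL-only pretraining phase and main-task training (a sketch, continuing the example above):

    # SSL-only pretraining: forward() returns (None, ssl_out)
    model.main_task_off()
    model.ssl_on()
    # ... optimize SSL losses ...

    # main-task training: forward() returns (prediction, None) when SSL is off
    model.ssl_off()
    model.main_task_on()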
def set_ssl(
    self,
    ssl_constructors: List = None,
    ssl_types: List = None,
    ssl_modules: List = None,
) -> None:
    """Set the SSL modules."""
    if ssl_constructors is None and ssl_types is None:
        self.ssl_type = ["none"]
        self.ssl = [None]
    else:
        if ssl_constructors is not None:
            ssl_types = [
                ssl_constructor.type for ssl_constructor in ssl_constructors
            ]
            ssl_modules = [
                ssl_constructor.construct_module()
                for ssl_constructor in ssl_constructors
            ]
        if not isinstance(ssl_types, Iterable):
            ssl_types = [ssl_types]
            ssl_modules = [ssl_modules]
        for t in ssl_types:
            if t not in available_ssl_types:
                raise ValueError(
                    f"SSL type {t} is not implemented yet, please choose from {available_ssl_types}"
                )
        self.ssl_type = ssl_types
        self.ssl = nn.ModuleList(ssl_modules)
Set the SSL modules.
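The SSL modules can also be set after construction; a sketch with a hypothetical "ssl_target" head, continuing the TinyModel example (any type from available_ssl_types can be used):

    # a hypothetical SSL head that maps the 64 feature channels back to the 10 input channels
    ssl_head = nn.Conv1d(64, 10, kernel_size=1)
    model.set_ssl(ssl_types=["ssl_target"], ssl_modules=[ssl_head])

    prediction, ssl_out = model(torch.randn(2, 10, 128), ssl_xs=[None])
    # ssl_out == [ssl_head(features)], with shape (2, 10, 128)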
@abstractmethod
def features_shape(self) -> torch.Size:
    """Get the shape of feature extractor output.

    Returns
    -------
    feature_shape : torch.Size
        shape of feature extractor output

    """
Get the shape of feature extractor output.
Returns
feature_shape : torch.Size
    shape of feature extractor output
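One common use of this shape is to size downstream heads; a sketch using the hypothetical TinyModel from above, where features_shape() returns torch.Size([64, 128]):

    n_features = model.features_shape()[0]                 # 64 feature channels in the sketch
    ssl_head = nn.Conv1d(n_features, 10, kernel_size=1)    # size a hypothetical SSL head from it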
def extract_features(self, x, start=0):
    """Apply the feature extraction modules consecutively.

    Parameters
    ----------
    x : torch.Tensor
        the input tensor
    start : int, default 0
        the index of the feature extraction module to start with

    Returns
    -------
    output : torch.Tensor
        the output tensor

    """
    if start == 0:
        x = self.feature_extractor(x)
    for extractor in self.feature_extractors[max(0, start - 1) :]:
        x = extractor(x)
    return x
Apply the feature extraction modules consecutively.
Parameters
x : torch.Tensor
    the input tensor
start : int, default 0
    the index of the feature extraction module to start with
Returns
output : torch.Tensor
    the output tensor
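Intermediate representations can also be pulled out directly for inspection or probing (a sketch, using the hypothetical TinyModel above):

    with torch.no_grad():
        feats = model.extract_features(torch.randn(2, 10, 128))  # applies all extractor stages
    print(feats.shape)  # (2, 64, 128) for the hypothetical TinyModel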
def transform_labels(self, device):
    """Transform the labels into a tensor of shape (1, n_classes, n_features)."""
    return {
        k: self.transform_label(v.to(device)).mean(0).unsqueeze(0)
        for k, v in self.class_tensors.items()
    }
Transform the labels into a tensor of shape (1, n_classes, n_features).
def forward(
    self,
    x: torch.Tensor,
    ssl_xs: list,
    tag: torch.Tensor = None,
) -> tuple[torch.Tensor, list]:
    """Generate a prediction for x.

    Parameters
    ----------
    x : torch.Tensor
        the main input
    ssl_xs : list
        a list of SSL input tensors
    tag : any, optional
        a meta information tag

    Returns
    -------
    prediction : torch.Tensor
        prediction for the main input
    ssl_out : list
        a list of SSL prediction tensors

    """
    ssl_out = None
    features_0 = self._extract_features_first(x)
    if len(self.feature_extractors) > 1:
        features = copy.copy(features_0)
        features = self.extract_features(features, start=1)
    else:
        features = features_0
    if self.ssl_active:
        ssl_out = []
        for ssl, ssl_x, ssl_type in zip(self.ssl, ssl_xs, self.ssl_type):
            if ssl_type in ["contrastive", "ssl_input", "contrastive_2layers"]:
                ssl_features = self.extract_features(ssl_x)
            if ssl_type == "ssl_input":
                if ssl_features.shape[0] > ssl_x.shape[0]:
                    if ssl_features.shape[0] % ssl_x.shape[0] != 0:
                        raise ValueError(
                            "The length of the SSL input tensor must be a multiple of the main input tensor"
                        )
                    # This is for C2F-TCN, where only the first (which is the last) feature set is kept
                    ssl_features = ssl_features[: ssl_x.shape[0]]
                ssl_out.append(ssl(ssl_features))
            elif ssl_type == "contrastive_2layers":
                ssl_out.append(
                    (ssl(features_0, extract_features=False), ssl(ssl_features))
                )
            elif ssl_type == "contrastive":
                ssl_out.append((ssl(features), ssl(ssl_features)))
            elif ssl_type == "ssl_target":
                ssl_out.append(ssl(features))
    args = [features]
    if self.main_task_active:
        x = self.predictor(*args)
    else:
        x = None
    return x, ssl_out
Generate a prediction for x.
Parameters
x : torch.Tensor
    the main input
ssl_xs : list
    a list of SSL input tensors
tag : any, optional
    a meta information tag
Returns
prediction : torch.Tensor
    prediction for the main input
ssl_out : list
    a list of SSL prediction tensors
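Putting it together, a single training-style call might look like this (a sketch, continuing the hypothetical TinyModel with one "ssl_target" head configured above):

    x = torch.randn(2, 10, 128)                       # main input: (batch, channels, frames)
    prediction, ssl_out = model(x, ssl_xs=[None], tag=None)
    # prediction: (2, 5, 128) per-frame class scores
    # ssl_out: a one-element list with the SSL head's output, here of shape (2, 10, 128)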
class LoadedModel(Model):
    """A class to generate a Model instance from a torch.nn.Module."""

    ssl_types = ["none"]

    def __init__(self, model: nn.Module, **kwargs) -> None:
        """Initialize the model.

        Parameters
        ----------
        model : torch.nn.Module
            a model with a forward function that takes a single tensor as input and returns a single tensor as output

        """
        super(LoadedModel, self).__init__()
        self.ssl_active = False
        self.feature_extractor = model

    def _feature_extractor(self) -> None:
        """Set feature extractor."""
        pass

    def _predictor(self) -> None:
        """Set predictor."""
        self.predictor = nn.Identity()

    def ssl_on(self):
        """Turn SSL on (SSL output will be computed by the forward function)."""
        pass
A class to generate a Model instance from a torch.nn.Module.
def __init__(self, model: nn.Module, **kwargs) -> None:
    """Initialize the model.

    Parameters
    ----------
    model : torch.nn.Module
        a model with a forward function that takes a single tensor as input and returns a single tensor as output

    """
    super(LoadedModel, self).__init__()
    self.ssl_active = False
    self.feature_extractor = model
Initialize the model.
Parameters
model : torch.nn.Module
    a model with a forward function that takes a single tensor as input and returns a single tensor as output
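A sketch of wrapping an existing backbone (the module below is hypothetical); the wrapped module becomes the feature extractor, SSL stays off, and the inherited utilities operate on it:

    backbone = nn.Sequential(nn.Conv1d(10, 5, kernel_size=1))  # any single-tensor-in/out module
    wrapped = LoadedModel(model=backbone)
    wrapped.freeze_feature_extractor()  # freezes the wrapped backbone's parameters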