dlc2action.model.base_model
Abstract parent class for models used in dlc2action.task.universal_task.Task
#
# Copyright 2020-2022 by A. Mathis Group and contributors. All rights reserved.
#
# This project and all its files are licensed under GNU AGPLv3 or later version. A copy is included in dlc2action/LICENSE.AGPL.
#
"""
Abstract parent class for models used in `dlc2action.task.universal_task.Task`
"""

from typing import Dict, Union, List, Callable, Tuple
from torch import nn
from abc import ABC, abstractmethod
import torch
import warnings
from collections.abc import Iterable
import copy

# SSL types accepted by Model.set_ssl
available_ssl_types = [
    "ssl_input",
    "ssl_target",
    "contrastive",
    "none",
    "contrastive_2layers",
]


class Model(nn.Module, ABC):
    """
    Base class for all models

    Manages the interaction between the base model and the SSL modules and ensures a consistent
    input and output format. Subclasses implement `_feature_extractor`, `_predictor` and
    `features_shape`.
    """

    process_labels = False

    def __init__(
        self,
        ssl_constructors: List = None,
        ssl_modules: List = None,
        ssl_types: List = None,
        state_dict_path: str = None,
        strict: bool = False,
        prompt_function: Callable = None,
    ) -> None:
        """
        Parameters
        ----------
        ssl_constructors : list, optional
            a list of SSL constructors that build the necessary SSL modules
        ssl_modules : list, optional
            a list of torch.nn.Module instances that will serve as SSL modules
        ssl_types : list, optional
            a list of string SSL types
        state_dict_path : str, optional
            path to the model state dictionary to load
        strict : bool, default False
            when True, the state dictionary will only be loaded if the current and the loaded architecture are the same;
            otherwise missing or extra keys, as well as shape inconsistencies, are ignored
        prompt_function : callable, optional
            stored as ``self.prompt_function`` (not used by this base class)
        """

        super(Model, self).__init__()
        feature_extractors = self._feature_extractor()
        if not isinstance(feature_extractors, list):
            feature_extractors = [feature_extractors]
        # the first extractor is kept separately; the remaining ones are applied after it
        self.feature_extractor = feature_extractors[0]
        self.feature_extractors = nn.ModuleList(feature_extractors[1:])
        self.predictor = self._predictor()
        self.set_ssl(ssl_constructors, ssl_types, ssl_modules)
        self.ssl_active = True
        self.main_task_active = True
        self.prompt_function = prompt_function
        self.class_tensors = None
        if state_dict_path is not None:
            # load onto the CPU so that checkpoints saved on a GPU can be read on any machine;
            # load_state_dict then copies the values into this module's parameters.
            # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
            state_dict = torch.load(state_dict_path, map_location="cpu")
            self.load_state_dict(state_dict, strict=strict)

    def freeze_feature_extractor(self) -> None:
        """
        Freeze the parameters of the feature extraction module

        Only ``self.feature_extractor`` (the first extractor) is affected.
        """

        for param in self.feature_extractor.parameters():
            param.requires_grad = False

    def unfreeze_feature_extractor(self) -> None:
        """
        Unfreeze the parameters of the feature extraction module

        Only ``self.feature_extractor`` (the first extractor) is affected.
        """

        for param in self.feature_extractor.parameters():
            param.requires_grad = True

    def load_state_dict(self, state_dict: Dict, strict: bool = True) -> None:
        """
        Load a model state dictionary

        Parameters
        ----------
        state_dict : dict
            the state dictionary to load (e.g. the output of ``torch.load``)
        strict : bool, default True
            when True, the state dictionary will only be loaded if the current and the loaded architecture are the same;
            otherwise missing or extra keys, as well as shape inconsistencies, are ignored

        Raises
        ------
        RuntimeError
            if ``strict`` is True and the state dictionary does not match the model
        """

        try:
            super().load_state_dict(state_dict, strict)
        except RuntimeError:
            if strict:
                raise
            warnings.warn(
                "Some of the layer shapes do not match the loaded state dictionary, skipping those"
            )
            own_state = self.state_dict()
            for name, param in state_dict.items():
                if name not in own_state:
                    continue
                if isinstance(param, nn.Parameter):
                    # backwards compatibility for serialized parameters
                    param = param.data
                try:
                    own_state[name].copy_(param)
                except (RuntimeError, TypeError, AttributeError):
                    # shape mismatch or non-tensor entry: skip it, as announced by the warning
                    pass

    def ssl_off(self) -> None:
        """
        Turn SSL off (SSL output will not be computed by the forward function)
        """

        self.ssl_active = False

    def ssl_on(self) -> None:
        """
        Turn SSL on (SSL output will be computed by the forward function)
        """

        self.ssl_active = True

    def main_task_on(self) -> None:
        """
        Turn main task training on (the forward function will compute the main prediction)
        """

        self.main_task_active = True

    def main_task_off(self) -> None:
        """
        Turn main task training off (the forward function will return None instead of the main prediction)
        """

        self.main_task_active = False

    def set_ssl(
        self,
        ssl_constructors: List = None,
        ssl_types: List = None,
        ssl_modules: List = None,
    ) -> None:
        """
        Set the SSL modules

        Parameters
        ----------
        ssl_constructors : list, optional
            SSL constructors; when given, their ``type`` attributes and ``construct_module()``
            outputs replace ``ssl_types`` and ``ssl_modules``
        ssl_types : str or list, optional
            SSL type(s); a single value is wrapped in a list together with ``ssl_modules``
        ssl_modules : torch.nn.Module or list, optional
            SSL module(s) corresponding to ``ssl_types``

        Raises
        ------
        ValueError
            if an SSL type is not in ``available_ssl_types``
        """

        if ssl_constructors is None and ssl_types is None:
            self.ssl_type = ["none"]
            self.ssl = [None]
            return
        if ssl_constructors is not None:
            ssl_types = [ssl_constructor.type for ssl_constructor in ssl_constructors]
            ssl_modules = [
                ssl_constructor.construct_module()
                for ssl_constructor in ssl_constructors
            ]
        # a string is Iterable but names a single SSL type; it must not be split into characters
        if isinstance(ssl_types, str) or not isinstance(ssl_types, Iterable):
            ssl_types = [ssl_types]
            ssl_modules = [ssl_modules]
        for t in ssl_types:
            if t not in available_ssl_types:
                raise ValueError(
                    f"SSL type {t} is not implemented yet, please choose from {available_ssl_types}"
                )
        self.ssl_type = ssl_types
        self.ssl = nn.ModuleList(ssl_modules)

    @abstractmethod
    def _feature_extractor(self) -> Union[torch.nn.Module, List]:
        """
        Construct the feature extractor module(s)

        Returns
        -------
        feature_extractor : torch.nn.Module or list
            an instance of torch.nn.Module that has a forward method receiving input and
            returning features that can be passed to an SSL module and to a prediction module;
            a list of modules is applied consecutively
        """

    @abstractmethod
    def _predictor(self) -> torch.nn.Module:
        """
        Construct the predictor module

        Returns
        -------
        predictor : torch.nn.Module
            an instance of torch.nn.Module that has a forward method receiving features
            extracted by the feature extractors and returning a prediction
        """

    @abstractmethod
    def features_shape(self) -> torch.Size:
        """
        Get the shape of feature extractor output

        Returns
        -------
        feature_shape : torch.Size
            shape of feature extractor output
        """

    def extract_features(self, x, start=0):
        """
        Apply the feature extraction modules consecutively

        Parameters
        ----------
        x : torch.Tensor
            the input tensor
        start : int, default 0
            the index of the feature extraction module to start with (0 is
            ``self.feature_extractor``, k > 0 is ``self.feature_extractors[k - 1]``)

        Returns
        -------
        output : torch.Tensor
            the output tensor
        """

        if start == 0:
            x = self.feature_extractor(x)
        for extractor in self.feature_extractors[max(0, start - 1) :]:
            x = extractor(x)
        return x

    def _extract_features_first(self, x):
        """
        Apply only the first feature extraction module
        """

        return self.feature_extractor(x)

    def forward(
        self,
        x: torch.Tensor,
        ssl_xs: list,
        tag: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, list]:
        """
        Generate a prediction for x

        Parameters
        ----------
        x : torch.Tensor
            the main input
        ssl_xs : list
            a list of SSL input tensors (matched by position with the SSL modules)
        tag : any, optional
            a meta information tag; passed to the predictor as a second argument when not None

        Returns
        -------
        prediction : torch.Tensor or None
            prediction for the main input (None when the main task is off)
        ssl_out : list or None
            a list of SSL prediction tensors (None when SSL is off; SSL types such as
            "none" add no entry)
        """

        ssl_out = None
        features_0 = self._extract_features_first(x)
        # apply the remaining extractors whenever there is at least one of them
        # (the previous `> 1` check silently skipped a single additional extractor)
        if len(self.feature_extractors) > 0:
            # keep features_0 (used by "contrastive_2layers") independent of later extractors;
            # clone() stays in the autograd graph, while copy.copy on a tensor may not
            if isinstance(features_0, torch.Tensor):
                features = features_0.clone()
            else:
                features = copy.copy(features_0)
            features = self.extract_features(features, start=1)
        else:
            features = features_0
        if self.ssl_active:
            ssl_out = []
            for ssl, ssl_x, ssl_type in zip(self.ssl, ssl_xs, self.ssl_type):
                if ssl_type in ["contrastive", "ssl_input", "contrastive_2layers"]:
                    # SSL inputs go through the whole feature extraction pipeline
                    ssl_features = self.extract_features(ssl_x)
                if ssl_type == "ssl_input":
                    ssl_out.append(ssl(ssl_features))
                elif ssl_type == "contrastive_2layers":
                    ssl_out.append(
                        (ssl(features_0, extract_features=False), ssl(ssl_features))
                    )
                elif ssl_type == "contrastive":
                    ssl_out.append((ssl(features), ssl(ssl_features)))
                elif ssl_type == "ssl_target":
                    ssl_out.append(ssl(features))
        args = [features]
        if tag is not None:
            args.append(tag)
        if self.main_task_active:
            x = self.predictor(*args)
        else:
            x = None
        return x, ssl_out


class LoadedModel(Model):
    """
    A class to generate a Model instance from a torch.nn.Module
    """

    ssl_types = ["none"]

    def __init__(self, model: nn.Module, **kwargs) -> None:
        """
        Parameters
        ----------
        model : torch.nn.Module
            a model with a forward function that takes a single tensor as input and returns a single tensor as output
        **kwargs
            accepted but ignored
        """

        super(LoadedModel, self).__init__()
        self.ssl_active = False
        self.feature_extractor = model

    def _feature_extractor(self) -> None:
        """
        Set feature extractor (placeholder: the wrapped model is assigned in __init__)
        """

        pass

    def _predictor(self) -> nn.Module:
        """
        Construct the predictor: an identity, so the wrapped model's output is returned as-is
        """

        # Model.__init__ assigns the returned value to self.predictor, so it must be returned
        return nn.Identity()

    def features_shape(self) -> None:
        """
        Get the shape of feature extractor output

        The output shape of an arbitrary wrapped model is not known, so None is returned.
        """

        return None

    def ssl_on(self):
        """
        Do nothing: SSL is not supported for loaded models and stays off
        """

        pass
class Model(nn.Module, ABC):
    """
    Base class for all models

    Manages the interaction between the base model and the SSL modules and ensures a consistent
    input and output format. Subclasses implement `_feature_extractor`, `_predictor` and
    `features_shape`.
    """

    process_labels = False

    def __init__(
        self,
        ssl_constructors: List = None,
        ssl_modules: List = None,
        ssl_types: List = None,
        state_dict_path: str = None,
        strict: bool = False,
        prompt_function: Callable = None,
    ) -> None:
        """
        Parameters
        ----------
        ssl_constructors : list, optional
            a list of SSL constructors that build the necessary SSL modules
        ssl_modules : list, optional
            a list of torch.nn.Module instances that will serve as SSL modules
        ssl_types : list, optional
            a list of string SSL types
        state_dict_path : str, optional
            path to the model state dictionary to load
        strict : bool, default False
            when True, the state dictionary will only be loaded if the current and the loaded architecture are the same;
            otherwise missing or extra keys, as well as shape inconsistencies, are ignored
        prompt_function : callable, optional
            stored as ``self.prompt_function`` (not used by this base class)
        """

        super(Model, self).__init__()
        feature_extractors = self._feature_extractor()
        if not isinstance(feature_extractors, list):
            feature_extractors = [feature_extractors]
        # the first extractor is kept separately; the remaining ones are applied after it
        self.feature_extractor = feature_extractors[0]
        self.feature_extractors = nn.ModuleList(feature_extractors[1:])
        self.predictor = self._predictor()
        self.set_ssl(ssl_constructors, ssl_types, ssl_modules)
        self.ssl_active = True
        self.main_task_active = True
        self.prompt_function = prompt_function
        self.class_tensors = None
        if state_dict_path is not None:
            # load onto the CPU so that checkpoints saved on a GPU can be read on any machine;
            # load_state_dict then copies the values into this module's parameters.
            # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
            state_dict = torch.load(state_dict_path, map_location="cpu")
            self.load_state_dict(state_dict, strict=strict)

    def freeze_feature_extractor(self) -> None:
        """
        Freeze the parameters of the feature extraction module

        Only ``self.feature_extractor`` (the first extractor) is affected.
        """

        for param in self.feature_extractor.parameters():
            param.requires_grad = False

    def unfreeze_feature_extractor(self) -> None:
        """
        Unfreeze the parameters of the feature extraction module

        Only ``self.feature_extractor`` (the first extractor) is affected.
        """

        for param in self.feature_extractor.parameters():
            param.requires_grad = True

    def load_state_dict(self, state_dict: Dict, strict: bool = True) -> None:
        """
        Load a model state dictionary

        Parameters
        ----------
        state_dict : dict
            the state dictionary to load (e.g. the output of ``torch.load``)
        strict : bool, default True
            when True, the state dictionary will only be loaded if the current and the loaded architecture are the same;
            otherwise missing or extra keys, as well as shape inconsistencies, are ignored

        Raises
        ------
        RuntimeError
            if ``strict`` is True and the state dictionary does not match the model
        """

        try:
            super().load_state_dict(state_dict, strict)
        except RuntimeError:
            if strict:
                raise
            warnings.warn(
                "Some of the layer shapes do not match the loaded state dictionary, skipping those"
            )
            own_state = self.state_dict()
            for name, param in state_dict.items():
                if name not in own_state:
                    continue
                if isinstance(param, nn.Parameter):
                    # backwards compatibility for serialized parameters
                    param = param.data
                try:
                    own_state[name].copy_(param)
                except (RuntimeError, TypeError, AttributeError):
                    # shape mismatch or non-tensor entry: skip it, as announced by the warning
                    pass

    def ssl_off(self) -> None:
        """
        Turn SSL off (SSL output will not be computed by the forward function)
        """

        self.ssl_active = False

    def ssl_on(self) -> None:
        """
        Turn SSL on (SSL output will be computed by the forward function)
        """

        self.ssl_active = True

    def main_task_on(self) -> None:
        """
        Turn main task training on (the forward function will compute the main prediction)
        """

        self.main_task_active = True

    def main_task_off(self) -> None:
        """
        Turn main task training off (the forward function will return None instead of the main prediction)
        """

        self.main_task_active = False

    def set_ssl(
        self,
        ssl_constructors: List = None,
        ssl_types: List = None,
        ssl_modules: List = None,
    ) -> None:
        """
        Set the SSL modules

        Parameters
        ----------
        ssl_constructors : list, optional
            SSL constructors; when given, their ``type`` attributes and ``construct_module()``
            outputs replace ``ssl_types`` and ``ssl_modules``
        ssl_types : str or list, optional
            SSL type(s); a single value is wrapped in a list together with ``ssl_modules``
        ssl_modules : torch.nn.Module or list, optional
            SSL module(s) corresponding to ``ssl_types``

        Raises
        ------
        ValueError
            if an SSL type is not in ``available_ssl_types``
        """

        if ssl_constructors is None and ssl_types is None:
            self.ssl_type = ["none"]
            self.ssl = [None]
            return
        if ssl_constructors is not None:
            ssl_types = [ssl_constructor.type for ssl_constructor in ssl_constructors]
            ssl_modules = [
                ssl_constructor.construct_module()
                for ssl_constructor in ssl_constructors
            ]
        # a string is Iterable but names a single SSL type; it must not be split into characters
        if isinstance(ssl_types, str) or not isinstance(ssl_types, Iterable):
            ssl_types = [ssl_types]
            ssl_modules = [ssl_modules]
        for t in ssl_types:
            if t not in available_ssl_types:
                raise ValueError(
                    f"SSL type {t} is not implemented yet, please choose from {available_ssl_types}"
                )
        self.ssl_type = ssl_types
        self.ssl = nn.ModuleList(ssl_modules)

    @abstractmethod
    def _feature_extractor(self) -> Union[torch.nn.Module, List]:
        """
        Construct the feature extractor module(s)

        Returns
        -------
        feature_extractor : torch.nn.Module or list
            an instance of torch.nn.Module that has a forward method receiving input and
            returning features that can be passed to an SSL module and to a prediction module;
            a list of modules is applied consecutively
        """

    @abstractmethod
    def _predictor(self) -> torch.nn.Module:
        """
        Construct the predictor module

        Returns
        -------
        predictor : torch.nn.Module
            an instance of torch.nn.Module that has a forward method receiving features
            extracted by the feature extractors and returning a prediction
        """

    @abstractmethod
    def features_shape(self) -> torch.Size:
        """
        Get the shape of feature extractor output

        Returns
        -------
        feature_shape : torch.Size
            shape of feature extractor output
        """

    def extract_features(self, x, start=0):
        """
        Apply the feature extraction modules consecutively

        Parameters
        ----------
        x : torch.Tensor
            the input tensor
        start : int, default 0
            the index of the feature extraction module to start with (0 is
            ``self.feature_extractor``, k > 0 is ``self.feature_extractors[k - 1]``)

        Returns
        -------
        output : torch.Tensor
            the output tensor
        """

        if start == 0:
            x = self.feature_extractor(x)
        for extractor in self.feature_extractors[max(0, start - 1) :]:
            x = extractor(x)
        return x

    def _extract_features_first(self, x):
        """
        Apply only the first feature extraction module
        """

        return self.feature_extractor(x)

    def forward(
        self,
        x: torch.Tensor,
        ssl_xs: list,
        tag: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, list]:
        """
        Generate a prediction for x

        Parameters
        ----------
        x : torch.Tensor
            the main input
        ssl_xs : list
            a list of SSL input tensors (matched by position with the SSL modules)
        tag : any, optional
            a meta information tag; passed to the predictor as a second argument when not None

        Returns
        -------
        prediction : torch.Tensor or None
            prediction for the main input (None when the main task is off)
        ssl_out : list or None
            a list of SSL prediction tensors (None when SSL is off; SSL types such as
            "none" add no entry)
        """

        ssl_out = None
        features_0 = self._extract_features_first(x)
        # apply the remaining extractors whenever there is at least one of them
        # (the previous `> 1` check silently skipped a single additional extractor)
        if len(self.feature_extractors) > 0:
            # keep features_0 (used by "contrastive_2layers") independent of later extractors;
            # clone() stays in the autograd graph, while copy.copy on a tensor may not
            if isinstance(features_0, torch.Tensor):
                features = features_0.clone()
            else:
                features = copy.copy(features_0)
            features = self.extract_features(features, start=1)
        else:
            features = features_0
        if self.ssl_active:
            ssl_out = []
            for ssl, ssl_x, ssl_type in zip(self.ssl, ssl_xs, self.ssl_type):
                if ssl_type in ["contrastive", "ssl_input", "contrastive_2layers"]:
                    # SSL inputs go through the whole feature extraction pipeline
                    ssl_features = self.extract_features(ssl_x)
                if ssl_type == "ssl_input":
                    ssl_out.append(ssl(ssl_features))
                elif ssl_type == "contrastive_2layers":
                    ssl_out.append(
                        (ssl(features_0, extract_features=False), ssl(ssl_features))
                    )
                elif ssl_type == "contrastive":
                    ssl_out.append((ssl(features), ssl(ssl_features)))
                elif ssl_type == "ssl_target":
                    ssl_out.append(ssl(features))
        args = [features]
        if tag is not None:
            args.append(tag)
        if self.main_task_active:
            x = self.predictor(*args)
        else:
            x = None
        return x, ssl_out
Base class for all models
Manages the interaction between the base model and the SSL modules and ensures a consistent input and output format
37 def __init__( 38 self, 39 ssl_constructors: List = None, 40 ssl_modules: List = None, 41 ssl_types: List = None, 42 state_dict_path: str = None, 43 strict: bool = False, 44 prompt_function: Callable = None, 45 ) -> None: 46 """ 47 Parameters 48 ---------- 49 ssl_constructors : list, optional 50 a list of SSL constructors that build the necessary SSL modules 51 ssl_modules : list, optional 52 a list of torch.nn.Module instances that will serve as SSL modules 53 ssl_types : list, optional 54 a list of string SSL types 55 state_dict_path : str, optional 56 path to the model state dictionary to load 57 strict : bool, default False 58 when True, the state dictionary will only be loaded if the current and the loaded architecture are the same; 59 otherwise missing or extra keys, as well as shaoe inconsistencies, are ignored 60 """ 61 62 super(Model, self).__init__() 63 feature_extractors = self._feature_extractor() 64 if not isinstance(feature_extractors, list): 65 feature_extractors = [feature_extractors] 66 self.feature_extractor = feature_extractors[0] 67 self.feature_extractors = nn.ModuleList(feature_extractors[1:]) 68 self.predictor = self._predictor() 69 self.set_ssl(ssl_constructors, ssl_types, ssl_modules) 70 self.ssl_active = True 71 self.main_task_active = True 72 self.prompt_function = prompt_function 73 self.class_tensors = None 74 if state_dict_path is not None: 75 self.load_state_dict(torch.load(state_dict_path), strict=strict) 76 # self.feature_extractors = nn.ModuleList([nn.DataParallel(x) for x in self.feature_extractors]) 77 # self.predictor = nn.DataParallel(self.predictor) 78 # if self.ssl != [None]: 79 # self.ssl = nn.ModuleList([nn.DataParallel(x) for x in self.ssl])
Parameters
ssl_constructors : list, optional a list of SSL constructors that build the necessary SSL modules ssl_modules : list, optional a list of torch.nn.Module instances that will serve as SSL modules ssl_types : list, optional a list of string SSL types state_dict_path : str, optional path to the model state dictionary to load strict : bool, default False when True, the state dictionary will only be loaded if the current and the loaded architecture are the same; otherwise missing or extra keys, as well as shape inconsistencies, are ignored
def freeze_feature_extractor(self) -> None:
    """
    Stop gradient updates for the first feature extraction module.

    Every parameter of ``self.feature_extractor`` gets ``requires_grad`` set to False;
    the modules in ``self.feature_extractors`` are left untouched.
    """

    for weight in self.feature_extractor.parameters():
        weight.requires_grad_(False)
Freeze the parameters of the feature extraction module
def unfreeze_feature_extractor(self) -> None:
    """
    Re-enable gradient updates for the first feature extraction module.

    Every parameter of ``self.feature_extractor`` gets ``requires_grad`` set to True;
    the modules in ``self.feature_extractors`` are left untouched.
    """

    for weight in self.feature_extractor.parameters():
        weight.requires_grad_(True)
Unfreeze the parameters of the feature extraction module
104 def load_state_dict(self, state_dict: str, strict: bool = True) -> None: 105 """ 106 Load a model state dictionary 107 108 Parameters 109 ---------- 110 state_dict : str 111 the path to the saved state dictionary 112 strict : bool, default True 113 when True, the state dictionary will only be loaded if the current and the loaded architecture are the same; 114 otherwise missing or extra keys, as well as shaoe inconsistencies, are ignored 115 """ 116 117 try: 118 super().load_state_dict(state_dict, strict) 119 except RuntimeError as e: 120 if strict: 121 raise e 122 else: 123 warnings.warn( 124 "Some of the layer shapes do not match the loaded state dictionary, skipping those" 125 ) 126 own_state = self.state_dict() 127 for name, param in state_dict.items(): 128 if name not in own_state: 129 continue 130 if isinstance(param, nn.Parameter): 131 # backwards compatibility for serialized parameters 132 param = param.data 133 try: 134 own_state[name].copy_(param) 135 except: 136 pass
Load a model state dictionary
Parameters
state_dict : dict the state dictionary to load (for example, the output of torch.load) strict : bool, default True when True, the state dictionary will only be loaded if the current and the loaded architecture are the same; otherwise missing or extra keys, as well as shape inconsistencies, are ignored
def ssl_off(self) -> None:
    """
    Disable the SSL branch.

    After this call the forward function skips the SSL modules and returns None
    as its SSL output.
    """

    self.ssl_active = False
Turn SSL off (SSL output will not be computed by the forward function)
def ssl_on(self) -> None:
    """
    Enable the SSL branch.

    After this call the forward function runs the SSL modules and returns their
    outputs as a list.
    """

    self.ssl_active = True
Turn SSL on (SSL output will be computed by the forward function)
def main_task_on(self) -> None:
    """
    Enable the main task.

    After this call the forward function passes the extracted features to the
    predictor and returns its output.
    """

    self.main_task_active = True
Turn main task training on
def main_task_off(self) -> None:
    """
    Turn main task training off

    After this call the forward function skips the predictor and returns None
    instead of the main prediction.
    """

    self.main_task_active = False
Turn main task training off
def set_ssl(
    self,
    ssl_constructors: List = None,
    ssl_types: List = None,
    ssl_modules: List = None,
) -> None:
    """
    Set the SSL modules

    Parameters
    ----------
    ssl_constructors : list, optional
        SSL constructors; when given, their ``type`` attributes and ``construct_module()``
        outputs replace ``ssl_types`` and ``ssl_modules``
    ssl_types : str or list, optional
        SSL type(s); a single value is wrapped in a list together with ``ssl_modules``
    ssl_modules : torch.nn.Module or list, optional
        SSL module(s) corresponding to ``ssl_types``

    Raises
    ------
    ValueError
        if an SSL type is not in ``available_ssl_types``
    """

    if ssl_constructors is None and ssl_types is None:
        self.ssl_type = ["none"]
        self.ssl = [None]
        return
    if ssl_constructors is not None:
        ssl_types = [ssl_constructor.type for ssl_constructor in ssl_constructors]
        ssl_modules = [
            ssl_constructor.construct_module()
            for ssl_constructor in ssl_constructors
        ]
    # a string is Iterable but names a single SSL type; it must not be split into characters
    if isinstance(ssl_types, str) or not isinstance(ssl_types, Iterable):
        ssl_types = [ssl_types]
        ssl_modules = [ssl_modules]
    for t in ssl_types:
        if t not in available_ssl_types:
            raise ValueError(
                f"SSL type {t} is not implemented yet, please choose from {available_ssl_types}"
            )
    self.ssl_type = ssl_types
    self.ssl = nn.ModuleList(ssl_modules)
Set the SSL modules
@abstractmethod
def features_shape(self) -> torch.Size:
    """
    Report the output shape of the feature extractor.

    Every concrete model has to implement this.

    Returns
    -------
    feature_shape : torch.Size
        the shape of the output produced by the feature extractor
    """
Get the shape of feature extractor output
Returns
feature_shape : torch.Size shape of feature extractor output
def extract_features(self, x, start=0):
    """
    Run the input through the feature extraction modules in order.

    Parameters
    ----------
    x : torch.Tensor
        the input tensor
    start : int, default 0
        position in the extractor chain to begin at: 0 includes ``self.feature_extractor``,
        k > 0 begins at ``self.feature_extractors[k - 1]``

    Returns
    -------
    output : torch.Tensor
        the result after the last extractor
    """

    pipeline = list(self.feature_extractors)[max(start - 1, 0):]
    if start == 0:
        pipeline.insert(0, self.feature_extractor)
    for module in pipeline:
        x = module(x)
    return x
Apply the feature extraction modules consecutively
Parameters
x : torch.Tensor the input tensor start : int, default 0 the index of the feature extraction module to start with
Returns
output : torch.Tensor the output tensor
def forward(
    self,
    x: torch.Tensor,
    ssl_xs: list,
    tag: torch.Tensor = None,
) -> Tuple[torch.Tensor, list]:
    """
    Generate a prediction for x

    Parameters
    ----------
    x : torch.Tensor
        the main input
    ssl_xs : list
        a list of SSL input tensors (matched by position with the SSL modules)
    tag : any, optional
        a meta information tag; passed to the predictor as a second argument when not None

    Returns
    -------
    prediction : torch.Tensor or None
        prediction for the main input (None when the main task is off)
    ssl_out : list or None
        a list of SSL prediction tensors (None when SSL is off; SSL types such as
        "none" add no entry)
    """

    ssl_out = None
    features_0 = self._extract_features_first(x)
    # apply the remaining extractors whenever there is at least one of them
    # (the previous `> 1` check silently skipped a single additional extractor)
    if len(self.feature_extractors) > 0:
        # keep features_0 (used by "contrastive_2layers") independent of later extractors;
        # clone() stays in the autograd graph, while copy.copy on a tensor may not
        if isinstance(features_0, torch.Tensor):
            features = features_0.clone()
        else:
            features = copy.copy(features_0)
        features = self.extract_features(features, start=1)
    else:
        features = features_0
    if self.ssl_active:
        ssl_out = []
        for ssl, ssl_x, ssl_type in zip(self.ssl, ssl_xs, self.ssl_type):
            if ssl_type in ["contrastive", "ssl_input", "contrastive_2layers"]:
                # SSL inputs go through the whole feature extraction pipeline
                ssl_features = self.extract_features(ssl_x)
            if ssl_type == "ssl_input":
                ssl_out.append(ssl(ssl_features))
            elif ssl_type == "contrastive_2layers":
                ssl_out.append(
                    (ssl(features_0, extract_features=False), ssl(ssl_features))
                )
            elif ssl_type == "contrastive":
                ssl_out.append((ssl(features), ssl(ssl_features)))
            elif ssl_type == "ssl_target":
                ssl_out.append(ssl(features))
    args = [features]
    if tag is not None:
        args.append(tag)
    if self.main_task_active:
        x = self.predictor(*args)
    else:
        x = None
    return x, ssl_out
Generate a prediction for x
Parameters
x : torch.Tensor the main input ssl_xs : list a list of SSL input tensors tag : any, optional a meta information tag
Returns
prediction : torch.Tensor prediction for the main input ssl_out : list a list of SSL prediction tensors
Inherited Members
- torch.nn.modules.module.Module
- dump_patches
- register_buffer
- register_parameter
- add_module
- register_module
- get_submodule
- get_parameter
- get_buffer
- get_extra_state
- set_extra_state
- apply
- cuda
- ipu
- xpu
- cpu
- type
- float
- double
- half
- bfloat16
- to_empty
- to
- register_backward_hook
- register_full_backward_hook
- register_forward_pre_hook
- register_forward_hook
- T_destination
- state_dict
- register_load_state_dict_post_hook
- parameters
- named_parameters
- buffers
- named_buffers
- children
- named_children
- modules
- named_modules
- train
- eval
- requires_grad_
- zero_grad
- extra_repr
class LoadedModel(Model):
    """
    A class to generate a Model instance from a torch.nn.Module
    """

    ssl_types = ["none"]

    def __init__(self, model: nn.Module, **kwargs) -> None:
        """
        Parameters
        ----------
        model : torch.nn.Module
            a model with a forward function that takes a single tensor as input and returns a single tensor as output
        **kwargs
            accepted but ignored
        """

        super(LoadedModel, self).__init__()
        self.ssl_active = False
        self.feature_extractor = model

    def _feature_extractor(self) -> None:
        """
        Set feature extractor (placeholder: the wrapped model is assigned in __init__)
        """

        pass

    def _predictor(self) -> nn.Module:
        """
        Construct the predictor: an identity, so the wrapped model's output is returned as-is
        """

        # Model.__init__ assigns the returned value to self.predictor, so it must be returned
        return nn.Identity()

    def features_shape(self) -> None:
        """
        Get the shape of feature extractor output

        The output shape of an arbitrary wrapped model is not known, so None is returned.
        """

        return None

    def ssl_on(self):
        """
        Do nothing: SSL is not supported for loaded models and stays off
        """

        pass
A class to generate a Model instance from a torch.nn.Module
def __init__(self, model: nn.Module, **kwargs) -> None:
    """
    Wrap an existing module as a Model.

    Parameters
    ----------
    model : torch.nn.Module
        a model with a forward function that takes a single tensor as input and returns a single tensor as output
    **kwargs
        accepted but not used
    """

    super(LoadedModel, self).__init__()
    # the wrapped module takes the place of the feature extractor set up by Model.__init__
    self.feature_extractor = model
    # loaded models never compute SSL output
    self.ssl_active = False
Parameters
model : torch.nn.Module a model with a forward function that takes a single tensor as input and returns a single tensor as output
def ssl_on(self):
    """
    Do nothing: SSL is not supported for loaded models

    Unlike ``Model.ssl_on``, this does not set ``ssl_active``, so SSL stays off
    (it is disabled in ``__init__``).
    """

    pass
Do nothing: SSL is not supported for loaded models, so SSL stays off (it is disabled when the model is created)
Inherited Members
- Model
- process_labels
- freeze_feature_extractor
- unfreeze_feature_extractor
- load_state_dict
- ssl_off
- main_task_on
- main_task_off
- set_ssl
- features_shape
- extract_features
- forward
- torch.nn.modules.module.Module
- dump_patches
- register_buffer
- register_parameter
- add_module
- register_module
- get_submodule
- get_parameter
- get_buffer
- get_extra_state
- set_extra_state
- apply
- cuda
- ipu
- xpu
- cpu
- type
- float
- double
- half
- bfloat16
- to_empty
- to
- register_backward_hook
- register_full_backward_hook
- register_forward_pre_hook
- register_forward_hook
- T_destination
- state_dict
- register_load_state_dict_post_hook
- parameters
- named_parameters
- buffers
- named_buffers
- children
- named_children
- modules
- named_modules
- train
- eval
- requires_grad_
- zero_grad
- extra_repr