dlc2action.model.base_model

Abstract parent class for models used in dlc2action.task.universal_task.Task.

  1#
  2# Copyright 2020-present by A. Mathis Group and contributors. All rights reserved.
  3#
  4# This project and all its files are licensed under GNU AGPLv3 or later version. 
  5# A copy is included in dlc2action/LICENSE.AGPL.
  6#
  7"""Abstract parent class for models used in `dlc2action.task.universal_task.Task`."""
  8
  9import copy
 10import warnings
 11from abc import ABC, abstractmethod
 12from collections.abc import Iterable
 13from typing import Callable, Dict, List, Union
 14
 15import torch
 16from torch import nn
 17
 18available_ssl_types = [
 19    "ssl_input",
 20    "ssl_target",
 21    "contrastive",
 22    "none",
 23    "contrastive_2layers",
 24]
 25
 26
 27class Model(nn.Module, ABC):
 28    """Base class for all models.
 29
 30    Manages the interaction of the base model and SSL modules and ensures a consistent input and output format.
 31    """
 32
 33    process_labels = False
 34
 35    def __init__(
 36        self,
 37        ssl_constructors: List = None,
 38        ssl_modules: List = None,
 39        ssl_types: List = None,
 40        state_dict_path: str = None,
 41        strict: bool = False,
 42        prompt_function: Callable = None,
 43    ) -> None:
 44        """Initialize the model.
 45
 46        Parameters
 47        ----------
 48        ssl_constructors : list, optional
 49            a list of SSL constructors that build the necessary SSL modules
 50        ssl_modules : list, optional
 51            a list of torch.nn.Module instances that will serve as SSL modules
 52        ssl_types : list, optional
 53            a list of string SSL types
 54        state_dict_path : str, optional
 55            path to the model state dictionary to load
 56        strict : bool, default False
 57            when True, the state dictionary will only be loaded if the current and the loaded architectures are the same;
 58            otherwise missing or extra keys, as well as shape inconsistencies, are ignored
 59        prompt_function : callable, optional
 60            a function that takes a list of strings and returns a string prompt
 61
 62        """
 63        super(Model, self).__init__()
 64        feature_extractors = self._feature_extractor()
 65        if not isinstance(feature_extractors, list):
 66            feature_extractors = [feature_extractors]
 67        self.feature_extractor = feature_extractors[0]
 68        self.feature_extractors = nn.ModuleList(feature_extractors[1:])
 69        self.predictor = self._predictor()
 70        self.set_ssl(ssl_constructors, ssl_types, ssl_modules)
 71        self.ssl_active = True
 72        self.main_task_active = True
 73        self.prompt_function = prompt_function
 74        self.class_tensors = None
 75        if state_dict_path is not None:
 76            self.load_state_dict(torch.load(state_dict_path), strict=strict)
 77        # self.feature_extractors = nn.ModuleList([nn.DataParallel(x) for x in self.feature_extractors])
 78        # self.predictor = nn.DataParallel(self.predictor)
 79        # if self.ssl != [None]:
 80        #     self.ssl = nn.ModuleList([nn.DataParallel(x) for x in self.ssl])
 81
 82    # def to(self, device, *args, **kwargs):
 83    #     if self.class_tensors is not None:
 84    #         self.class_tensors = {
 85    #             k: v.to(device) for k, v in self.class_tensors.items()
 86    #         }
 87    #     return super().to(device, *args, **kwargs)
 88
 89    def freeze_feature_extractor(self) -> None:
 90        """Freeze the parameters of the feature extraction module."""
 91        for param in self.feature_extractor.parameters():
 92            param.requires_grad = False
 93
 94    def unfreeze_feature_extractor(self) -> None:
 95        """Unfreeze the parameters of the feature extraction module."""
 96        for param in self.feature_extractor.parameters():
 97            param.requires_grad = True
 98
 99    def load_state_dict(self, state_dict: Dict, strict: bool = True) -> None:
100        """Load a model state dictionary.
101
102        Parameters
103        ----------
104        state_dict : dict
105            the state dictionary to load (e.g. as returned by torch.load)
106        strict : bool, default True
107            when True, the state dictionary will only be loaded if the current and the loaded architectures are the same;
108            otherwise missing or extra keys, as well as shape inconsistencies, are ignored
109
110        """
111        try:
112            super().load_state_dict(state_dict, strict)
113        except RuntimeError as e:
114            if strict:
115                raise e
116            else:
117                warnings.warn(
118                    "Some of the layer shapes do not match the loaded state dictionary, skipping those"
119                )
120                own_state = self.state_dict()
121                for name, param in state_dict.items():
122                    if name not in own_state:
123                        continue
124                    if isinstance(param, nn.Parameter):
125                        # backwards compatibility for serialized parameters
126                        param = param.data
127                    try:
128                        own_state[name].copy_(param)
129                    except Exception:  # e.g. a shape mismatch; skip this parameter
130                        pass
131
132    def ssl_off(self) -> None:
133        """Turn SSL off (SSL output will not be computed by the forward function)."""
134        self.ssl_active = False
135
136    def ssl_on(self) -> None:
137        """Turn SSL on (SSL output will be computed by the forward function)."""
138        self.ssl_active = True
139
140    def main_task_on(self) -> None:
141        """Turn main task training on."""
142        self.main_task_active = True
143
144    def main_task_off(self) -> None:
145        """Turn main task training off."""
146        self.main_task_active = False
147
148    def set_ssl(
149        self,
150        ssl_constructors: List = None,
151        ssl_types: List = None,
152        ssl_modules: List = None,
153    ) -> None:
154        """Set the SSL modules."""
155        if ssl_constructors is None and ssl_types is None:
156            self.ssl_type = ["none"]
157            self.ssl = [None]
158        else:
159            if ssl_constructors is not None:
160                ssl_types = [
161                    ssl_constructor.type for ssl_constructor in ssl_constructors
162                ]
163                ssl_modules = [
164                    ssl_constructor.construct_module()
165                    for ssl_constructor in ssl_constructors
166                ]
167            if not isinstance(ssl_types, Iterable):
168                ssl_types = [ssl_types]
169                ssl_modules = [ssl_modules]
170            for t in ssl_types:
171                if t not in available_ssl_types:
172                    raise ValueError(
173                        f"SSL type {t} is not implemented yet, please choose from {available_ssl_types}"
174                    )
175            self.ssl_type = ssl_types
176            self.ssl = nn.ModuleList(ssl_modules)
177
178    @abstractmethod
179    def _feature_extractor(self) -> Union[torch.nn.Module, List]:
180        """Construct the feature extractor module.
181
182        Returns
183        -------
184        feature_extractor : torch.nn.Module or list
185            a torch.nn.Module instance (or a list of such modules, applied consecutively) with a forward method that
186            receives the input and returns features that can be passed to an SSL module and to a prediction module
187
188        """
189
190    @abstractmethod
191    def _predictor(self) -> torch.nn.Module:
192        """Construct the predictor module.
193
194        Returns
195        -------
196        predictor : torch.nn.Module
197            an instance of torch.nn.Module that has a forward method receiving features
198            extracted by self.feature_extractor and returning a prediction
199
200        """
201
202    @abstractmethod
203    def features_shape(self) -> torch.Size:
204        """Get the shape of feature extractor output.
205
206        Returns
207        -------
208        feature_shape : torch.Size
209            shape of feature extractor output
210
211        """
212
213    def extract_features(self, x, start=0):
214        """Apply the feature extraction modules consecutively.
215
216        Parameters
217        ----------
218        x : torch.Tensor
219            the input tensor
220        start : int, default 0
221            the index of the feature extraction module to start with
222
223        Returns
224        -------
225        output : torch.Tensor
226            the output tensor
227
228        """
229        if start == 0:
230            x = self.feature_extractor(x)
231        for extractor in self.feature_extractors[max(0, start - 1) :]:
232            x = extractor(x)
233        return x
234
235    def _extract_features_first(self, x):
236        """Extract features from the first feature extractor."""
237        return self.feature_extractor(x)
238
239    def transform_labels(self, device):
240        """Transform the label tensors in self.class_tensors into tensors of shape (1, n_classes, n_features)."""
241        return {
242            k: self.transform_label(v.to(device)).mean(0).unsqueeze(0)
243            for k, v in self.class_tensors.items()
244        }
245
246    def forward(
247        self,
248        x: torch.Tensor,
249        ssl_xs: list,
250        tag: torch.Tensor = None,
251    ) -> tuple[torch.Tensor, list]:
252        """Generate a prediction for x.
253
254        Parameters
255        ----------
256        x : torch.Tensor
257            the main input
258        ssl_xs : list
259            a list of SSL input tensors
260        tag : any, optional
261            a meta information tag
262
263        Returns
264        -------
265        prediction : torch.Tensor
266            prediction for the main input
267        ssl_out : list
268            a list of SSL prediction tensors
269
270        """
271        ssl_out = None
272        features_0 = self._extract_features_first(x)
273        if len(self.feature_extractors) > 1:
274            features = copy.copy(features_0)
275            features = self.extract_features(features, start=1)
276        else:
277            features = features_0
278        if self.ssl_active:
279            ssl_out = []
280            for ssl, ssl_x, ssl_type in zip(self.ssl, ssl_xs, self.ssl_type):
281                if ssl_type in ["contrastive", "ssl_input", "contrastive_2layers"]:
282                    ssl_features = self.extract_features(ssl_x)
283                if ssl_type == "ssl_input":
284                    if ssl_features.shape[0] > ssl_x.shape[0]:
285                        if ssl_features.shape[0] % ssl_x.shape[0] != 0:
286                            raise ValueError(
287                                "The length of the SSL input tensor must be a multiple of the main input tensor"
288                            )
289                        ssl_features = ssl_features[:ssl_x.shape[0]] #This for C2F-TCN where you want to keep onlw the first (which is the last) feature set
290                    ssl_out.append(ssl(ssl_features))
291                elif ssl_type == "contrastive_2layers":
292                    ssl_out.append(
293                        (ssl(features_0, extract_features=False), ssl(ssl_features))
294                    )
295                elif ssl_type == "contrastive":
296                    ssl_out.append((ssl(features), ssl(ssl_features)))
297                elif ssl_type == "ssl_target":
298                    ssl_out.append(ssl(features))
299        args = [features]
300        if self.main_task_active:
301            x = self.predictor(*args)
302        else:
303            x = None
304        return x, ssl_out
305
306
307class LoadedModel(Model):
308    """A class to generate a Model instance from a torch.nn.Module."""
309
310    ssl_types = ["none"]
311
312    def __init__(self, model: nn.Module, **kwargs) -> None:
313        """Initialize the model.
314
315        Parameters
316        ----------
317        model : torch.nn.Module
318            a model with a forward function that takes a single tensor as input and returns a single tensor as output
319
320        """
321        super(LoadedModel, self).__init__()
322        self.ssl_active = False
323        self.feature_extractor = model
324
325    def _feature_extractor(self) -> None:
326        """Set feature extractor."""
327        pass
328
329    def _predictor(self) -> None:
330        """Set predictor."""
331        self.predictor = nn.Identity()
332
333    def ssl_on(self):
334        """Do nothing: SSL is not available for a LoadedModel, so it stays off."""
335        pass

available_ssl_types = ['ssl_input', 'ssl_target', 'contrastive', 'none', 'contrastive_2layers']

class Model(torch.nn.modules.module.Module, abc.ABC):

Base class for all models.

Manages the interaction of the base model and SSL modules and ensures a consistent input and output format.
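
A minimal subclass only needs to implement the three abstract methods (_feature_extractor, _predictor and features_shape). The sketch below is purely illustrative and not part of dlc2action; the layer sizes (input_dim, num_features, num_classes) are made-up placeholders.

import torch
from torch import nn

from dlc2action.model.base_model import Model


class MyModel(Model):
    """Illustrative subclass: a 1D-convolutional feature extractor and a per-frame predictor."""

    def __init__(self, input_dim=10, num_features=64, num_classes=5, **kwargs):
        # plain attributes must be set before super().__init__(), which immediately
        # calls self._feature_extractor() and self._predictor()
        self.input_dim = input_dim
        self.num_features = num_features
        self.num_classes = num_classes
        super().__init__(**kwargs)

    def _feature_extractor(self):
        # (batch, input_dim, frames) -> (batch, num_features, frames)
        return nn.Conv1d(self.input_dim, self.num_features, kernel_size=3, padding=1)

    def _predictor(self):
        # (batch, num_features, frames) -> (batch, num_classes, frames)
        return nn.Conv1d(self.num_features, self.num_classes, kernel_size=1)

    def features_shape(self):
        return torch.Size([self.num_features])


model = MyModel()
x = torch.randn(8, 10, 100)                # (batch, channels, frames)
prediction, ssl_out = model(x, ssl_xs=[])  # no SSL modules attached, so ssl_out is an empty list
# prediction has shape (8, 5, 100)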

Model(ssl_constructors: List = None, ssl_modules: List = None, ssl_types: List = None, state_dict_path: str = None, strict: bool = False, prompt_function: Callable = None)

Initialize the model.

Parameters

ssl_constructors : list, optional
    a list of SSL constructors that build the necessary SSL modules
ssl_modules : list, optional
    a list of torch.nn.Module instances that will serve as SSL modules
ssl_types : list, optional
    a list of string SSL types
state_dict_path : str, optional
    path to the model state dictionary to load
strict : bool, default False
    when True, the state dictionary will only be loaded if the current and the loaded architectures are the same;
    otherwise missing or extra keys, as well as shape inconsistencies, are ignored
prompt_function : callable, optional
    a function that takes a list of strings and returns a string prompt

process_labels = False
feature_extractor
feature_extractors
predictor
ssl_active
main_task_active
prompt_function
class_tensors
def freeze_feature_extractor(self) -> None:

Freeze the parameters of the feature extraction module.

def unfreeze_feature_extractor(self) -> None:

Unfreeze the parameters of the feature extraction module.
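
For fine-tuning, the feature extraction module can be frozen so that only the remaining parameters are trained; a brief sketch, continuing the hypothetical MyModel example above:

model.freeze_feature_extractor()                     # feature extractor parameters stop receiving gradients
trainable = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(trainable, lr=1e-3)     # optimize only the remaining (predictor and SSL) parameters
# ... train ...
model.unfreeze_feature_extractor()                   # switch back to end-to-end training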

def load_state_dict(self, state_dict: Dict, strict: bool = True) -> None:

Load a model state dictionary.

Parameters

state_dict : dict
    the state dictionary to load (e.g. as returned by torch.load)
strict : bool, default True
    when True, the state dictionary will only be loaded if the current and the loaded architectures are the same;
    otherwise missing or extra keys, as well as shape inconsistencies, are ignored
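
A usage sketch of non-strict loading (the checkpoint path is hypothetical): layers whose names or shapes do not match are skipped with a warning instead of raising an error.

state_dict = torch.load("checkpoints/best_model.pt")  # hypothetical path
model.load_state_dict(state_dict, strict=False)       # mismatched or missing layers are skipped with a warning
# with strict=True the same mismatch would raise a RuntimeError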

def ssl_off(self) -> None:

Turn SSL off (SSL output will not be computed by the forward function).

def ssl_on(self) -> None:

Turn SSL on (SSL output will be computed by the forward function).

def main_task_on(self) -> None:

Turn main task training on.

def main_task_off(self) -> None:

Turn main task training off.

def set_ssl(self, ssl_constructors: List = None, ssl_types: List = None, ssl_modules: List = None) -> None:

Set the SSL modules.
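
SSL modules can also be attached after construction by passing matching lists of types and modules, and then toggled together with the main task. The SSL head below is a placeholder architecture, not one shipped with dlc2action; it maps the 64 features of the hypothetical MyModel to a 10-channel SSL target.

ssl_head = nn.Conv1d(64, 10, kernel_size=1)
model.set_ssl(ssl_types=["ssl_target"], ssl_modules=[ssl_head])

model.ssl_on()          # forward() will also return SSL predictions
model.main_task_off()   # e.g. SSL-only pre-training: forward() returns (None, ssl_out)
model.main_task_on()
model.ssl_off()         # back to main-task-only outputs: forward() returns (prediction, None)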

@abstractmethod
def features_shape(self) -> torch.Size:

Get the shape of feature extractor output.

Returns

feature_shape : torch.Size
    shape of the feature extractor output

def extract_features(self, x, start=0):

Apply the feature extraction modules consecutively.

Parameters

x : torch.Tensor
    the input tensor
start : int, default 0
    the index of the feature extraction module to start with

Returns

output : torch.Tensor
    the output tensor
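
A short usage sketch, continuing the MyModel example (with a single feature extraction module, start=1 is simply a no-op):

features = model.extract_features(x)               # apply self.feature_extractor, then every module in self.feature_extractors
later = model.extract_features(features, start=1)  # skip self.feature_extractor and apply only self.feature_extractors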

def transform_labels(self, device):

Transform the label tensors in self.class_tensors into tensors of shape (1, n_classes, n_features).

def forward(self, x: torch.Tensor, ssl_xs: list, tag: torch.Tensor = None) -> tuple[torch.Tensor, list]:

Generate a prediction for x.

Parameters

x : torch.Tensor
    the main input
ssl_xs : list
    a list of SSL input tensors
tag : any, optional
    a meta information tag

Returns

prediction : torch.Tensor
    the prediction for the main input
ssl_out : list
    a list of SSL prediction tensors
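
A sketch of a full forward call, continuing the MyModel example with the placeholder "ssl_target" head attached and SSL switched on. Note that ssl_xs must contain one entry per SSL module so that modules and inputs can be zipped together, even when (as for "ssl_target") the entry itself is not used.

model.ssl_on()
x = torch.randn(8, 10, 100)      # main input: (batch, channels, frames)
ssl_x = torch.randn(8, 10, 100)  # SSL input; not used by "ssl_target" modules but still expected in ssl_xs
prediction, ssl_out = model(x, ssl_xs=[ssl_x])
# prediction: (8, 5, 100) main-task scores, or None after main_task_off()
# ssl_out: a list with one tensor per SSL module, or None after ssl_off()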

class LoadedModel(Model):

A class to generate a Model instance from a torch.nn.Module.

LoadedModel(model: torch.nn.modules.module.Module, **kwargs)

Initialize the model.

Parameters

model : torch.nn.Module
    a model with a forward function that takes a single tensor as input and returns a single tensor as output

ssl_types = ['none']
ssl_active
feature_extractor
def ssl_on(self):

Do nothing: SSL is not available for a LoadedModel, so it stays off.
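
A brief sketch of wrapping an existing network (the small Conv1d stack below is purely illustrative). The wrapped module must map the input tensor directly to the prediction, since the predictor of a LoadedModel is the identity.

from dlc2action.model.base_model import LoadedModel

backbone = nn.Sequential(
    nn.Conv1d(10, 64, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv1d(64, 5, kernel_size=1),
)
model = LoadedModel(model=backbone)
prediction, ssl_out = model(torch.randn(8, 10, 100), ssl_xs=[])
# prediction: (8, 5, 100); ssl_out is None because SSL is always inactive for a LoadedModel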