dlc2action.model.base_model

Abstract parent class for models used in dlc2action.task.universal_task.Task

  1#
  2# Copyright 2020-2022 by A. Mathis Group and contributors. All rights reserved.
  3#
  4# This project and all its files are licensed under GNU AGPLv3 or later version. A copy is included in dlc2action/LICENSE.AGPL.
  5#
  6"""
  7Abstract parent class for models used in `dlc2action.task.universal_task.Task`
  8"""
  9
 10from typing import Dict, Union, List, Callable, Tuple
 11from torch import nn
 12from abc import ABC, abstractmethod
 13import torch
 14import warnings
 15from collections.abc import Iterable
 16import copy
 17
# SSL type names accepted by `Model.set_ssl`; any other value raises a ValueError there
available_ssl_types = [
    "ssl_input",
    "ssl_target",
    "contrastive",
    "none",
    "contrastive_2layers",
]
 25
 26
 27class Model(nn.Module, ABC):
 28    """
 29    Base class for all models
 30
 31    Manages interaction of base model and SSL modules + ensures consistent input and output format
 32    """
 33
 34    process_labels = False
 35
 36    def __init__(
 37        self,
 38        ssl_constructors: List = None,
 39        ssl_modules: List = None,
 40        ssl_types: List = None,
 41        state_dict_path: str = None,
 42        strict: bool = False,
 43        prompt_function: Callable = None,
 44    ) -> None:
 45        """
 46        Parameters
 47        ----------
 48        ssl_constructors : list, optional
 49            a list of SSL constructors that build the necessary SSL modules
 50        ssl_modules : list, optional
 51            a list of torch.nn.Module instances that will serve as SSL modules
 52        ssl_types : list, optional
 53            a list of string SSL types
 54        state_dict_path : str, optional
 55            path to the model state dictionary to load
 56        strict : bool, default False
 57            when True, the state dictionary will only be loaded if the current and the loaded architecture are the same;
 58            otherwise missing or extra keys, as well as shaoe inconsistencies, are ignored
 59        """
 60
 61        super(Model, self).__init__()
 62        feature_extractors = self._feature_extractor()
 63        if not isinstance(feature_extractors, list):
 64            feature_extractors = [feature_extractors]
 65        self.feature_extractor = feature_extractors[0]
 66        self.feature_extractors = nn.ModuleList(feature_extractors[1:])
 67        self.predictor = self._predictor()
 68        self.set_ssl(ssl_constructors, ssl_types, ssl_modules)
 69        self.ssl_active = True
 70        self.main_task_active = True
 71        self.prompt_function = prompt_function
 72        self.class_tensors = None
 73        if state_dict_path is not None:
 74            self.load_state_dict(torch.load(state_dict_path), strict=strict)
 75        # self.feature_extractors = nn.ModuleList([nn.DataParallel(x) for x in self.feature_extractors])
 76        # self.predictor = nn.DataParallel(self.predictor)
 77        # if self.ssl != [None]:
 78        #     self.ssl = nn.ModuleList([nn.DataParallel(x) for x in self.ssl])
 79
 80    # def to(self, device, *args, **kwargs):
 81    #     if self.class_tensors is not None:
 82    #         self.class_tensors = {
 83    #             k: v.to(device) for k, v in self.class_tensors.items()
 84    #         }
 85    #     return super().to(device, *args, **kwargs)
 86
 87    def freeze_feature_extractor(self) -> None:
 88        """
 89        Freeze the parameters of the feature extraction module
 90        """
 91
 92        for param in self.feature_extractor.parameters():
 93            param.requires_grad = False
 94
 95    def unfreeze_feature_extractor(self) -> None:
 96        """
 97        Unfreeze the parameters of the feature extraction module
 98        """
 99
100        for param in self.feature_extractor.parameters():
101            param.requires_grad = True
102
103    def load_state_dict(self, state_dict: str, strict: bool = True) -> None:
104        """
105        Load a model state dictionary
106
107        Parameters
108        ----------
109        state_dict : str
110            the path to the saved state dictionary
111        strict : bool, default True
112            when True, the state dictionary will only be loaded if the current and the loaded architecture are the same;
113            otherwise missing or extra keys, as well as shaoe inconsistencies, are ignored
114        """
115
116        try:
117            super().load_state_dict(state_dict, strict)
118        except RuntimeError as e:
119            if strict:
120                raise e
121            else:
122                warnings.warn(
123                    "Some of the layer shapes do not match the loaded state dictionary, skipping those"
124                )
125                own_state = self.state_dict()
126                for name, param in state_dict.items():
127                    if name not in own_state:
128                        continue
129                    if isinstance(param, nn.Parameter):
130                        # backwards compatibility for serialized parameters
131                        param = param.data
132                    try:
133                        own_state[name].copy_(param)
134                    except:
135                        pass
136
137    def ssl_off(self) -> None:
138        """
139        Turn SSL off (SSL output will not be computed by the forward function)
140        """
141
142        self.ssl_active = False
143
144    def ssl_on(self) -> None:
145        """
146        Turn SSL on (SSL output will be computed by the forward function)
147        """
148
149        self.ssl_active = True
150
151    def main_task_on(self) -> None:
152        """
153        Turn main task training on
154        """
155
156        self.main_task_active = True
157
158    def main_task_off(self) -> None:
159        """
160        Turn main task training on
161        """
162
163        self.main_task_active = False
164
165    def set_ssl(
166        self,
167        ssl_constructors: List = None,
168        ssl_types: List = None,
169        ssl_modules: List = None,
170    ) -> None:
171        """
172        Set the SSL modules
173        """
174
175        if ssl_constructors is None and ssl_types is None:
176            self.ssl_type = ["none"]
177            self.ssl = [None]
178        else:
179            if ssl_constructors is not None:
180                ssl_types = [
181                    ssl_constructor.type for ssl_constructor in ssl_constructors
182                ]
183                ssl_modules = [
184                    ssl_constructor.construct_module()
185                    for ssl_constructor in ssl_constructors
186                ]
187            if not isinstance(ssl_types, Iterable):
188                ssl_types = [ssl_types]
189                ssl_modules = [ssl_modules]
190            for t in ssl_types:
191                if t not in available_ssl_types:
192                    raise ValueError(
193                        f"SSL type {t} is not implemented yet, please choose from {available_ssl_types}"
194                    )
195            self.ssl_type = ssl_types
196            self.ssl = nn.ModuleList(ssl_modules)
197
198    @abstractmethod
199    def _feature_extractor(self) -> Union[torch.nn.Module, List]:
200        """
201        Construct the feature extractor module
202
203        Returns
204        -------
205        feature_extractor : torch.nn.Module
206            an instance of torch.nn.Module that has a forward method receiving input and
207            returning features that can be passed to an SSL module and to a prediction module
208        """
209
210    @abstractmethod
211    def _predictor(self) -> torch.nn.Module:
212        """
213        Construct the predictor module
214
215        Returns
216        -------
217        predictor : torch.nn.Module
218            an instance of torch.nn.Module that has a forward method receiving features
219            exctracted by self.feature_extractor and returning a prediction
220        """
221
222    @abstractmethod
223    def features_shape(self) -> torch.Size:
224        """
225        Get the shape of feature extractor output
226
227        Returns
228        -------
229        feature_shape : torch.Size
230            shape of feature extractor output
231        """
232
233    def extract_features(self, x, start=0):
234        """
235        Apply the feature extraction modules consecutively
236
237        Parameters
238        ----------
239        x : torch.Tensor
240            the input tensor
241        start : int, default 0
242            the index of the feature extraction module to start with
243
244        Returns
245        -------
246        output : torch.Tensor
247            the output tensor
248        """
249
250        if start == 0:
251            x = self.feature_extractor(x)
252        for extractor in self.feature_extractors[max(0, start - 1) :]:
253            x = extractor(x)
254        return x
255
256    def _extract_features_first(self, x):
257        return self.feature_extractor(x)
258
259    def forward(
260        self,
261        x: torch.Tensor,
262        ssl_xs: list,
263        tag: torch.Tensor = None,
264    ) -> Tuple[torch.Tensor, list]:
265        """
266        Generate a prediction for x
267
268        Parameters
269        ----------
270        x : torch.Tensor
271            the main input
272        ssl_xs : list
273            a list of SSL input tensors
274        tag : any, optional
275            a meta information tag
276
277        Returns
278        -------
279        prediction : torch.Tensor
280            prediction for the main input
281        ssl_out : list
282            a list of SSL prediction tensors
283        """
284
285        ssl_out = None
286        features_0 = self._extract_features_first(x)
287        if len(self.feature_extractors) > 1:
288            features = copy.copy(features_0)
289            features = self.extract_features(features, start=1)
290        else:
291            features = features_0
292        if self.ssl_active:
293            ssl_out = []
294            for ssl, ssl_x, ssl_type in zip(self.ssl, ssl_xs, self.ssl_type):
295                if ssl_type in ["contrastive", "ssl_input", "contrastive_2layers"]:
296                    ssl_features = self.extract_features(ssl_x)
297                if ssl_type == "ssl_input":
298                    ssl_out.append(ssl(ssl_features))
299                elif ssl_type == "contrastive_2layers":
300                    ssl_out.append(
301                        (ssl(features_0, extract_features=False), ssl(ssl_features))
302                    )
303                elif ssl_type == "contrastive":
304                    ssl_out.append((ssl(features), ssl(ssl_features)))
305                elif ssl_type == "ssl_target":
306                    ssl_out.append(ssl(features))
307        args = [features]
308        if tag is not None:
309            args.append(tag)
310        if self.main_task_active:
311            x = self.predictor(*args)
312        else:
313            x = None
314        return x, ssl_out
315
316
class LoadedModel(Model):
    """
    A class to generate a Model instance from a torch.nn.Module

    The wrapped module is used as the feature extractor and the predictor is an identity,
    so the prediction is the output of the wrapped module. SSL is always off.
    """

    ssl_types = ["none"]

    def __init__(self, model: nn.Module, **kwargs) -> None:
        """
        Parameters
        ----------
        model : torch.nn.Module
            a model with a forward function that takes a single tensor as input and returns a single tensor as output
        **kwargs
            accepted for interface compatibility and ignored
        """

        super(LoadedModel, self).__init__()
        self.ssl_active = False
        # the module can only be registered after nn.Module.__init__ has run, so the placeholder
        # set by the parent constructor is replaced here
        self.feature_extractor = model

    def _feature_extractor(self) -> None:
        """
        Return a placeholder; the real feature extractor is assigned in `__init__`
        """

        return None

    def _predictor(self) -> nn.Module:
        """
        Construct the predictor (an identity module)

        The module has to be returned: the parent constructor assigns the return value
        to `self.predictor`, so returning None would leave the model without a predictor.
        """

        return nn.Identity()

    def features_shape(self) -> None:
        """
        Get the shape of feature extractor output

        The output shape of an arbitrary wrapped module is not known to this class, so None
        is returned. The method has to be defined because it is abstract in the parent class
        and the class could not be instantiated without it.
        """

        return None

    def ssl_on(self):
        """
        Do nothing (SSL is not supported, so it cannot be turned on)
        """

        pass
class Model(torch.nn.modules.module.Module, abc.ABC):
 28class Model(nn.Module, ABC):
 29    """
 30    Base class for all models
 31
 32    Manages interaction of base model and SSL modules + ensures consistent input and output format
 33    """
 34
 35    process_labels = False
 36
 37    def __init__(
 38        self,
 39        ssl_constructors: List = None,
 40        ssl_modules: List = None,
 41        ssl_types: List = None,
 42        state_dict_path: str = None,
 43        strict: bool = False,
 44        prompt_function: Callable = None,
 45    ) -> None:
 46        """
 47        Parameters
 48        ----------
 49        ssl_constructors : list, optional
 50            a list of SSL constructors that build the necessary SSL modules
 51        ssl_modules : list, optional
 52            a list of torch.nn.Module instances that will serve as SSL modules
 53        ssl_types : list, optional
 54            a list of string SSL types
 55        state_dict_path : str, optional
 56            path to the model state dictionary to load
 57        strict : bool, default False
 58            when True, the state dictionary will only be loaded if the current and the loaded architecture are the same;
 59            otherwise missing or extra keys, as well as shaoe inconsistencies, are ignored
 60        """
 61
 62        super(Model, self).__init__()
 63        feature_extractors = self._feature_extractor()
 64        if not isinstance(feature_extractors, list):
 65            feature_extractors = [feature_extractors]
 66        self.feature_extractor = feature_extractors[0]
 67        self.feature_extractors = nn.ModuleList(feature_extractors[1:])
 68        self.predictor = self._predictor()
 69        self.set_ssl(ssl_constructors, ssl_types, ssl_modules)
 70        self.ssl_active = True
 71        self.main_task_active = True
 72        self.prompt_function = prompt_function
 73        self.class_tensors = None
 74        if state_dict_path is not None:
 75            self.load_state_dict(torch.load(state_dict_path), strict=strict)
 76        # self.feature_extractors = nn.ModuleList([nn.DataParallel(x) for x in self.feature_extractors])
 77        # self.predictor = nn.DataParallel(self.predictor)
 78        # if self.ssl != [None]:
 79        #     self.ssl = nn.ModuleList([nn.DataParallel(x) for x in self.ssl])
 80
 81    # def to(self, device, *args, **kwargs):
 82    #     if self.class_tensors is not None:
 83    #         self.class_tensors = {
 84    #             k: v.to(device) for k, v in self.class_tensors.items()
 85    #         }
 86    #     return super().to(device, *args, **kwargs)
 87
 88    def freeze_feature_extractor(self) -> None:
 89        """
 90        Freeze the parameters of the feature extraction module
 91        """
 92
 93        for param in self.feature_extractor.parameters():
 94            param.requires_grad = False
 95
 96    def unfreeze_feature_extractor(self) -> None:
 97        """
 98        Unfreeze the parameters of the feature extraction module
 99        """
100
101        for param in self.feature_extractor.parameters():
102            param.requires_grad = True
103
104    def load_state_dict(self, state_dict: str, strict: bool = True) -> None:
105        """
106        Load a model state dictionary
107
108        Parameters
109        ----------
110        state_dict : str
111            the path to the saved state dictionary
112        strict : bool, default True
113            when True, the state dictionary will only be loaded if the current and the loaded architecture are the same;
114            otherwise missing or extra keys, as well as shaoe inconsistencies, are ignored
115        """
116
117        try:
118            super().load_state_dict(state_dict, strict)
119        except RuntimeError as e:
120            if strict:
121                raise e
122            else:
123                warnings.warn(
124                    "Some of the layer shapes do not match the loaded state dictionary, skipping those"
125                )
126                own_state = self.state_dict()
127                for name, param in state_dict.items():
128                    if name not in own_state:
129                        continue
130                    if isinstance(param, nn.Parameter):
131                        # backwards compatibility for serialized parameters
132                        param = param.data
133                    try:
134                        own_state[name].copy_(param)
135                    except:
136                        pass
137
138    def ssl_off(self) -> None:
139        """
140        Turn SSL off (SSL output will not be computed by the forward function)
141        """
142
143        self.ssl_active = False
144
145    def ssl_on(self) -> None:
146        """
147        Turn SSL on (SSL output will be computed by the forward function)
148        """
149
150        self.ssl_active = True
151
152    def main_task_on(self) -> None:
153        """
154        Turn main task training on
155        """
156
157        self.main_task_active = True
158
159    def main_task_off(self) -> None:
160        """
161        Turn main task training on
162        """
163
164        self.main_task_active = False
165
166    def set_ssl(
167        self,
168        ssl_constructors: List = None,
169        ssl_types: List = None,
170        ssl_modules: List = None,
171    ) -> None:
172        """
173        Set the SSL modules
174        """
175
176        if ssl_constructors is None and ssl_types is None:
177            self.ssl_type = ["none"]
178            self.ssl = [None]
179        else:
180            if ssl_constructors is not None:
181                ssl_types = [
182                    ssl_constructor.type for ssl_constructor in ssl_constructors
183                ]
184                ssl_modules = [
185                    ssl_constructor.construct_module()
186                    for ssl_constructor in ssl_constructors
187                ]
188            if not isinstance(ssl_types, Iterable):
189                ssl_types = [ssl_types]
190                ssl_modules = [ssl_modules]
191            for t in ssl_types:
192                if t not in available_ssl_types:
193                    raise ValueError(
194                        f"SSL type {t} is not implemented yet, please choose from {available_ssl_types}"
195                    )
196            self.ssl_type = ssl_types
197            self.ssl = nn.ModuleList(ssl_modules)
198
199    @abstractmethod
200    def _feature_extractor(self) -> Union[torch.nn.Module, List]:
201        """
202        Construct the feature extractor module
203
204        Returns
205        -------
206        feature_extractor : torch.nn.Module
207            an instance of torch.nn.Module that has a forward method receiving input and
208            returning features that can be passed to an SSL module and to a prediction module
209        """
210
211    @abstractmethod
212    def _predictor(self) -> torch.nn.Module:
213        """
214        Construct the predictor module
215
216        Returns
217        -------
218        predictor : torch.nn.Module
219            an instance of torch.nn.Module that has a forward method receiving features
220            exctracted by self.feature_extractor and returning a prediction
221        """
222
223    @abstractmethod
224    def features_shape(self) -> torch.Size:
225        """
226        Get the shape of feature extractor output
227
228        Returns
229        -------
230        feature_shape : torch.Size
231            shape of feature extractor output
232        """
233
234    def extract_features(self, x, start=0):
235        """
236        Apply the feature extraction modules consecutively
237
238        Parameters
239        ----------
240        x : torch.Tensor
241            the input tensor
242        start : int, default 0
243            the index of the feature extraction module to start with
244
245        Returns
246        -------
247        output : torch.Tensor
248            the output tensor
249        """
250
251        if start == 0:
252            x = self.feature_extractor(x)
253        for extractor in self.feature_extractors[max(0, start - 1) :]:
254            x = extractor(x)
255        return x
256
257    def _extract_features_first(self, x):
258        return self.feature_extractor(x)
259
260    def forward(
261        self,
262        x: torch.Tensor,
263        ssl_xs: list,
264        tag: torch.Tensor = None,
265    ) -> Tuple[torch.Tensor, list]:
266        """
267        Generate a prediction for x
268
269        Parameters
270        ----------
271        x : torch.Tensor
272            the main input
273        ssl_xs : list
274            a list of SSL input tensors
275        tag : any, optional
276            a meta information tag
277
278        Returns
279        -------
280        prediction : torch.Tensor
281            prediction for the main input
282        ssl_out : list
283            a list of SSL prediction tensors
284        """
285
286        ssl_out = None
287        features_0 = self._extract_features_first(x)
288        if len(self.feature_extractors) > 1:
289            features = copy.copy(features_0)
290            features = self.extract_features(features, start=1)
291        else:
292            features = features_0
293        if self.ssl_active:
294            ssl_out = []
295            for ssl, ssl_x, ssl_type in zip(self.ssl, ssl_xs, self.ssl_type):
296                if ssl_type in ["contrastive", "ssl_input", "contrastive_2layers"]:
297                    ssl_features = self.extract_features(ssl_x)
298                if ssl_type == "ssl_input":
299                    ssl_out.append(ssl(ssl_features))
300                elif ssl_type == "contrastive_2layers":
301                    ssl_out.append(
302                        (ssl(features_0, extract_features=False), ssl(ssl_features))
303                    )
304                elif ssl_type == "contrastive":
305                    ssl_out.append((ssl(features), ssl(ssl_features)))
306                elif ssl_type == "ssl_target":
307                    ssl_out.append(ssl(features))
308        args = [features]
309        if tag is not None:
310            args.append(tag)
311        if self.main_task_active:
312            x = self.predictor(*args)
313        else:
314            x = None
315        return x, ssl_out

Base class for all models

Manages interaction of base model and SSL modules + ensures consistent input and output format

Model( ssl_constructors: List = None, ssl_modules: List = None, ssl_types: List = None, state_dict_path: str = None, strict: bool = False, prompt_function: Callable = None)
37    def __init__(
38        self,
39        ssl_constructors: List = None,
40        ssl_modules: List = None,
41        ssl_types: List = None,
42        state_dict_path: str = None,
43        strict: bool = False,
44        prompt_function: Callable = None,
45    ) -> None:
46        """
47        Parameters
48        ----------
49        ssl_constructors : list, optional
50            a list of SSL constructors that build the necessary SSL modules
51        ssl_modules : list, optional
52            a list of torch.nn.Module instances that will serve as SSL modules
53        ssl_types : list, optional
54            a list of string SSL types
55        state_dict_path : str, optional
56            path to the model state dictionary to load
57        strict : bool, default False
58            when True, the state dictionary will only be loaded if the current and the loaded architecture are the same;
59            otherwise missing or extra keys, as well as shaoe inconsistencies, are ignored
60        """
61
62        super(Model, self).__init__()
63        feature_extractors = self._feature_extractor()
64        if not isinstance(feature_extractors, list):
65            feature_extractors = [feature_extractors]
66        self.feature_extractor = feature_extractors[0]
67        self.feature_extractors = nn.ModuleList(feature_extractors[1:])
68        self.predictor = self._predictor()
69        self.set_ssl(ssl_constructors, ssl_types, ssl_modules)
70        self.ssl_active = True
71        self.main_task_active = True
72        self.prompt_function = prompt_function
73        self.class_tensors = None
74        if state_dict_path is not None:
75            self.load_state_dict(torch.load(state_dict_path), strict=strict)
76        # self.feature_extractors = nn.ModuleList([nn.DataParallel(x) for x in self.feature_extractors])
77        # self.predictor = nn.DataParallel(self.predictor)
78        # if self.ssl != [None]:
79        #     self.ssl = nn.ModuleList([nn.DataParallel(x) for x in self.ssl])

Parameters

ssl_constructors : list, optional a list of SSL constructors that build the necessary SSL modules ssl_modules : list, optional a list of torch.nn.Module instances that will serve as SSL modules ssl_types : list, optional a list of string SSL types state_dict_path : str, optional path to the model state dictionary to load strict : bool, default False when True, the state dictionary will only be loaded if the current and the loaded architecture are the same; otherwise missing or extra keys, as well as shape inconsistencies, are ignored

process_labels = False
def freeze_feature_extractor(self) -> None:
88    def freeze_feature_extractor(self) -> None:
89        """
90        Freeze the parameters of the feature extraction module
91        """
92
93        for param in self.feature_extractor.parameters():
94            param.requires_grad = False

Freeze the parameters of the feature extraction module

def unfreeze_feature_extractor(self) -> None:
 96    def unfreeze_feature_extractor(self) -> None:
 97        """
 98        Unfreeze the parameters of the feature extraction module
 99        """
100
101        for param in self.feature_extractor.parameters():
102            param.requires_grad = True

Unfreeze the parameters of the feature extraction module

def load_state_dict(self, state_dict: str, strict: bool = True) -> None:
104    def load_state_dict(self, state_dict: str, strict: bool = True) -> None:
105        """
106        Load a model state dictionary
107
108        Parameters
109        ----------
110        state_dict : str
111            the path to the saved state dictionary
112        strict : bool, default True
113            when True, the state dictionary will only be loaded if the current and the loaded architecture are the same;
114            otherwise missing or extra keys, as well as shaoe inconsistencies, are ignored
115        """
116
117        try:
118            super().load_state_dict(state_dict, strict)
119        except RuntimeError as e:
120            if strict:
121                raise e
122            else:
123                warnings.warn(
124                    "Some of the layer shapes do not match the loaded state dictionary, skipping those"
125                )
126                own_state = self.state_dict()
127                for name, param in state_dict.items():
128                    if name not in own_state:
129                        continue
130                    if isinstance(param, nn.Parameter):
131                        # backwards compatibility for serialized parameters
132                        param = param.data
133                    try:
134                        own_state[name].copy_(param)
135                    except:
136                        pass

Load a model state dictionary

Parameters

state_dict : dict the state dictionary to load (a mapping from parameter names to tensors, not a path) strict : bool, default True when True, the state dictionary will only be loaded if the current and the loaded architecture are the same; otherwise missing or extra keys, as well as shape inconsistencies, are ignored

def ssl_off(self) -> None:
138    def ssl_off(self) -> None:
139        """
140        Turn SSL off (SSL output will not be computed by the forward function)
141        """
142
143        self.ssl_active = False

Turn SSL off (SSL output will not be computed by the forward function)

def ssl_on(self) -> None:
145    def ssl_on(self) -> None:
146        """
147        Turn SSL on (SSL output will be computed by the forward function)
148        """
149
150        self.ssl_active = True

Turn SSL on (SSL output will be computed by the forward function)

def main_task_on(self) -> None:
152    def main_task_on(self) -> None:
153        """
154        Turn main task training on
155        """
156
157        self.main_task_active = True

Turn main task training on

def main_task_off(self) -> None:
159    def main_task_off(self) -> None:
160        """
161        Turn main task training on
162        """
163
164        self.main_task_active = False

Turn main task training off

def set_ssl( self, ssl_constructors: List = None, ssl_types: List = None, ssl_modules: List = None) -> None:
166    def set_ssl(
167        self,
168        ssl_constructors: List = None,
169        ssl_types: List = None,
170        ssl_modules: List = None,
171    ) -> None:
172        """
173        Set the SSL modules
174        """
175
176        if ssl_constructors is None and ssl_types is None:
177            self.ssl_type = ["none"]
178            self.ssl = [None]
179        else:
180            if ssl_constructors is not None:
181                ssl_types = [
182                    ssl_constructor.type for ssl_constructor in ssl_constructors
183                ]
184                ssl_modules = [
185                    ssl_constructor.construct_module()
186                    for ssl_constructor in ssl_constructors
187                ]
188            if not isinstance(ssl_types, Iterable):
189                ssl_types = [ssl_types]
190                ssl_modules = [ssl_modules]
191            for t in ssl_types:
192                if t not in available_ssl_types:
193                    raise ValueError(
194                        f"SSL type {t} is not implemented yet, please choose from {available_ssl_types}"
195                    )
196            self.ssl_type = ssl_types
197            self.ssl = nn.ModuleList(ssl_modules)

Set the SSL modules

@abstractmethod
def features_shape(self) -> torch.Size:
    @abstractmethod
    def features_shape(self) -> torch.Size:
        """
        Get the shape of feature extractor output

        Abstract: no implementation is provided here, so every concrete subclass has to define it.

        Returns
        -------
        feature_shape : torch.Size
            shape of feature extractor output
        """

Get the shape of feature extractor output

Returns

feature_shape : torch.Size shape of feature extractor output

def extract_features(self, x, start=0)
234    def extract_features(self, x, start=0):
235        """
236        Apply the feature extraction modules consecutively
237
238        Parameters
239        ----------
240        x : torch.Tensor
241            the input tensor
242        start : int, default 0
243            the index of the feature extraction module to start with
244
245        Returns
246        -------
247        output : torch.Tensor
248            the output tensor
249        """
250
251        if start == 0:
252            x = self.feature_extractor(x)
253        for extractor in self.feature_extractors[max(0, start - 1) :]:
254            x = extractor(x)
255        return x

Apply the feature extraction modules consecutively

Parameters

x : torch.Tensor the input tensor start : int, default 0 the index of the feature extraction module to start with

Returns

output : torch.Tensor the output tensor

def forward( self, x: torch.Tensor, ssl_xs: list, tag: torch.Tensor = None) -> Tuple[torch.Tensor, list]:
def forward(
    self,
    x: torch.Tensor,
    ssl_xs: list,
    tag: torch.Tensor = None,
) -> Tuple[torch.Tensor, list]:
    """
    Compute the main prediction and the SSL outputs for one input

    Parameters
    ----------
    x : torch.Tensor
        the main input
    ssl_xs : list
        a list of SSL input tensors (only read when SSL is active)
    tag : torch.Tensor, optional
        a meta information tag; when it is not `None` it is handed to the
        predictor as a second positional argument

    Returns
    -------
    prediction : torch.Tensor
        prediction for the main input (`None` when the main task is switched off)
    ssl_out : list
        a list of SSL prediction tensors (`None` when SSL is inactive); for the
        contrastive types an entry is a tuple of (main branch, SSL branch) outputs
    """

    first_features = self._extract_features_first(x)
    if len(self.feature_extractors) > 1:
        # the remaining extractors receive a shallow copy of the first-stage features
        features = self.extract_features(copy.copy(first_features), start=1)
    else:
        features = first_features

    ssl_out = None
    if self.ssl_active:
        ssl_out = []
        # SSL types that need features computed from their own SSL input
        types_with_ssl_input = ("contrastive", "ssl_input", "contrastive_2layers")
        for module, ssl_input, kind in zip(self.ssl, ssl_xs, self.ssl_type):
            if kind in types_with_ssl_input:
                ssl_features = self.extract_features(ssl_input)
            if kind == "ssl_input":
                ssl_out.append(module(ssl_features))
            elif kind == "contrastive_2layers":
                # the main branch uses the first-stage features here
                main_branch = module(first_features, extract_features=False)
                ssl_branch = module(ssl_features)
                ssl_out.append((main_branch, ssl_branch))
            elif kind == "contrastive":
                main_branch = module(features)
                ssl_branch = module(ssl_features)
                ssl_out.append((main_branch, ssl_branch))
            elif kind == "ssl_target":
                ssl_out.append(module(features))
            # any other type contributes no entry to ssl_out

    if not self.main_task_active:
        return None, ssl_out
    predictor_args = [features] if tag is None else [features, tag]
    return self.predictor(*predictor_args), ssl_out

Generate a prediction for x

Parameters

x : torch.Tensor — the main input
ssl_xs : list — a list of SSL input tensors
tag : any, optional — a meta information tag

Returns

prediction : torch.Tensor — prediction for the main input
ssl_out : list — a list of SSL prediction tensors

Inherited Members
torch.nn.modules.module.Module
dump_patches
register_buffer
register_parameter
add_module
register_module
get_submodule
get_parameter
get_buffer
get_extra_state
set_extra_state
apply
cuda
ipu
xpu
cpu
type
float
double
half
bfloat16
to_empty
to
register_backward_hook
register_full_backward_hook
register_forward_pre_hook
register_forward_hook
T_destination
state_dict
register_load_state_dict_post_hook
parameters
named_parameters
buffers
named_buffers
children
named_children
modules
named_modules
train
eval
requires_grad_
zero_grad
share_memory
extra_repr
class LoadedModel(Model):
class LoadedModel(Model):
    """
    A class to generate a Model instance from a torch.nn.Module

    The wrapped module is stored as the feature extractor, `_predictor`
    installs an identity mapping as the predictor and `ssl_on` is a no-op.
    """

    ssl_types = ["none"]

    def __init__(self, model: nn.Module, **kwargs) -> None:
        """
        Parameters
        ----------
        model : torch.nn.Module
            a model with a forward function that takes a single tensor as input and returns a single tensor as output
        **kwargs
            accepted but ignored: nothing is forwarded to the parent constructor
        """

        super().__init__()
        self.ssl_active = False
        self.feature_extractor = model

    def _feature_extractor(self) -> None:
        """
        Set feature extractor

        Nothing is built here; the feature extractor is the module assigned in `__init__`
        """

        pass

    def _predictor(self) -> None:
        """
        Set predictor (an identity mapping)
        """

        self.predictor = nn.Identity()

    def features_shape(self) -> Union[torch.Size, None]:
        """
        Get the shape of feature extractor output

        Overrides the abstract parent method; without a concrete
        implementation this class could not be instantiated. The output
        shape of an arbitrary wrapped module is not known here.

        Returns
        -------
        feature_shape : None
            always `None`
        """

        return None

    def ssl_on(self):
        """
        Request to turn SSL on; deliberately does nothing in this class

        The body is empty, so the call leaves `ssl_active` untouched
        """

        pass

A class to generate a Model instance from a torch.nn.Module

LoadedModel(model: torch.nn.modules.module.Module, **kwargs)
def __init__(self, model: nn.Module, **kwargs) -> None:
    """
    Wrap an existing module as the feature extractor

    Parameters
    ----------
    model : torch.nn.Module
        a model with a forward function that takes a single tensor as input and returns a single tensor as output
    **kwargs
        accepted but ignored: nothing is forwarded to the parent constructor
    """

    # the parent constructor is called without arguments
    super().__init__()
    self.ssl_active = False
    self.feature_extractor = model

Parameters

model : torch.nn.Module — a model with a forward function that takes a single tensor as input and returns a single tensor as output

ssl_types = ['none']
def ssl_on(self)
def ssl_on(self):
    """
    Request to turn SSL on; deliberately does nothing in this class

    The body is empty, so calling this method changes no attribute and
    returns `None`
    """

    pass

Turn SSL on. For LoadedModel this method is a no-op: its body is empty, so calling it does not enable any SSL output.

Inherited Members
Model
process_labels
freeze_feature_extractor
unfreeze_feature_extractor
load_state_dict
ssl_off
main_task_on
main_task_off
set_ssl
features_shape
extract_features
forward
torch.nn.modules.module.Module
dump_patches
register_buffer
register_parameter
add_module
register_module
get_submodule
get_parameter
get_buffer
get_extra_state
set_extra_state
apply
cuda
ipu
xpu
cpu
type
float
double
half
bfloat16
to_empty
to
register_backward_hook
register_full_backward_hook
register_forward_pre_hook
register_forward_hook
T_destination
state_dict
register_load_state_dict_post_hook
parameters
named_parameters
buffers
named_buffers
children
named_children
modules
named_modules
train
eval
requires_grad_
zero_grad
share_memory
extra_repr