dlc2action.ssl.base_ssl

Abstract class for defining SSL tasks.

  1#
  2# Copyright 2020-present by A. Mathis Group and contributors. All rights reserved.
  3#
  4# This project and all its files are licensed under GNU AGPLv3 or later version. 
  5# A copy is included in dlc2action/LICENSE.AGPL.
  6#
  7"""Abstract class for defining SSL tasks."""
  8
  9from typing import Callable, Tuple, Dict, Union
 10from abc import ABC, abstractmethod
 11import torch
 12from torch import nn
 13
 14torch.manual_seed(0)
 15
 16class SSLConstructor(ABC):
 17    """A base class for all SSL constructors.
 18
 19    An SSL method is defined by three things: a *transformation* that maps a sample into SSL input and output,
 20    a neural net *module* that takes features as input and predicts SSL target, a *type* and a *loss function*.
 21    """
 22
 23    type = "none"
 24    """
 25    The `type` parameter defines interaction with the model:
 26
 27    - `'ssl_input'`: a modification of the input data passes through the base network feature extraction module and the
 28    SSL module; it is returned as SSL output and compared to SSL target (or, if it is None, to the input data),
 29    - `'ssl_output'`:  the input data passes through the base network feature extraction module and the SSL module; it
 30    is returned as SSL output and compared to SSL target (or, if it is None, to the input data),
 31    - `'contrastive'`:  the input data and its modification pass through the base network feature extraction module and
 32    the SSL module; an (input results, modification results) tuple is returned as SSL output,
 33    - `'contrastive_2layers'`: the input data and its modification pass through the base network feature extraction module;
 34    the output of the second feature extraction layer for the modified data goes through an SSL module and then,
 35    optionally, that result and the first-level unmodified features pass another transformation;
 36    an (input results, modified results) tuple is returned as SSL output,
 37    """
 38
 39    def __init__(self, *args, **kwargs) -> None:
 40        """Initialize the SSL constructor."""
 41        ...
 42
 43    @abstractmethod
 44    def transformation(self, sample_data: Dict) -> (Dict, Dict):
 45        """Transform a sample feature dictionary into SSL input and target.
 46
 47        Either input, target or both can be left as `None`. Transformers can be configured to replace `None` SSL targets
 48        with the input sample at runtime and/or to replace `None SSL` inputs with a new augmentation of the input sample.
 49        If the keys of the feature dictionaries are recognized by the transformer, they will be augmented before
 50        all features are stacked together.
 51
 52        Parameters
 53        ----------
 54        sample_data : dict
 55            a feature dictionary
 56
 57        Returns
 58        -------
 59        ssl_input : dict | torch.float('nan')
 60            a feature dictionary of SSL inputs
 61        ssl_target : dict | torch.float('nan')
 62            a feature dictionary of SSL targets
 63
 64        """
 65
 66    @abstractmethod
 67    def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
 68        """Calculate the SSL loss.
 69
 70        Parameters
 71        ----------
 72        predicted : torch.Tensor
 73            output of the SSL module
 74        target : torch.Tensor
 75            augmented and stacked SSL_target
 76
 77        Returns
 78        -------
 79        loss : float
 80            the loss value
 81
 82        """
 83
 84    @abstractmethod
 85    def construct_module(self) -> nn.Module:
 86        """Construct the SSL module.
 87
 88        Returns
 89        -------
 90        ssl_module : torch.nn.Module
 91            a neural net module that takes features extracted by a model's feature extractor as input and
 92            returns SSL output
 93
 94        """
 95
 96
 97class EmptySSL(SSLConstructor):
 98    """Empty SSL class."""
 99
100    def transformation(self, sample_data: Dict) -> Tuple:
101        """Empty transformation."""
102        return (torch.tensor(float("nan")), torch.tensor(float("nan")))
103
104    def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
105        """Empty loss."""
106        return 0
107
108    def construct_module(self) -> None:
109        """Empty module."""
110        return None
class SSLConstructor(ABC):
    """A base class for all SSL constructors.

    An SSL method is defined by four things: a *transformation* that maps a sample into SSL input and target,
    a neural net *module* that takes features as input and predicts the SSL target, a *type* and a *loss function*.
    Concrete subclasses must implement the abstract `transformation`, `loss` and `construct_module` methods.
    """

    type = "none"
    """
    The `type` parameter defines interaction with the model:

    - `'ssl_input'`: a modification of the input data passes through the base network feature extraction module and the
    SSL module; it is returned as SSL output and compared to SSL target (or, if it is None, to the input data);
    - `'ssl_output'`: the input data passes through the base network feature extraction module and the SSL module; it
    is returned as SSL output and compared to SSL target (or, if it is None, to the input data);
    - `'contrastive'`: the input data and its modification pass through the base network feature extraction module and
    the SSL module; an (input results, modification results) tuple is returned as SSL output;
    - `'contrastive_2layers'`: the input data and its modification pass through the base network feature extraction
    module; the output of the second feature extraction layer for the modified data goes through an SSL module and
    then, optionally, that result and the first-level unmodified features pass through another transformation;
    an (input results, modified results) tuple is returned as SSL output.
    """

    def __init__(self, *args, **kwargs) -> None:
        """Initialize the SSL constructor.

        The base implementation accepts and ignores any positional and keyword arguments.
        """
        ...

    @abstractmethod
    def transformation(self, sample_data: Dict) -> Tuple[Dict, Dict]:
        """Transform a sample feature dictionary into SSL input and target.

        Either input, target or both can be left as `None`. Transformers can be configured to replace `None` SSL targets
        with the input sample at runtime and/or to replace `None` SSL inputs with a new augmentation of the input sample.
        If the keys of the feature dictionaries are recognized by the transformer, they will be augmented before
        all features are stacked together.

        Parameters
        ----------
        sample_data : dict
            a feature dictionary

        Returns
        -------
        ssl_input : dict | torch.Tensor
            a feature dictionary of SSL inputs, or a placeholder scalar tensor
            (e.g. ``torch.tensor(float("nan"))``) when there is no SSL input
        ssl_target : dict | torch.Tensor
            a feature dictionary of SSL targets, or a placeholder scalar tensor
            (e.g. ``torch.tensor(float("nan"))``) when there is no SSL target

        """

    @abstractmethod
    def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
        """Calculate the SSL loss.

        Parameters
        ----------
        predicted : torch.Tensor
            output of the SSL module
        target : torch.Tensor
            augmented and stacked SSL target

        Returns
        -------
        loss : float
            the loss value

        """

    @abstractmethod
    def construct_module(self) -> nn.Module:
        """Construct the SSL module.

        Returns
        -------
        ssl_module : torch.nn.Module
            a neural net module that takes features extracted by a model's feature extractor as input and
            returns SSL output

        """

A base class for all SSL constructors.

An SSL method is defined by four things: a transformation that maps a sample into SSL input and target, a neural net module that takes features as input and predicts the SSL target, a type, and a loss function.

def __init__(self, *args, **kwargs) -> None:
    """Initialize the SSL constructor.

    The base implementation accepts and ignores any positional and keyword arguments.
    """

Initialize the SSL constructor.

# Default SSL type of SSLConstructor; subclasses select their interaction mode via this attribute.
type = "none"

The type parameter defines interaction with the model:

  • 'ssl_input': a modification of the input data passes through the base network feature extraction module and the SSL module; it is returned as SSL output and compared to SSL target (or, if it is None, to the input data);
  • 'ssl_output': the input data passes through the base network feature extraction module and the SSL module; it is returned as SSL output and compared to SSL target (or, if it is None, to the input data);
  • 'contrastive': the input data and its modification pass through the base network feature extraction module and the SSL module; an (input results, modification results) tuple is returned as SSL output;
  • 'contrastive_2layers': the input data and its modification pass through the base network feature extraction module; the output of the second feature extraction layer for the modified data goes through an SSL module and then, optionally, that result and the first-level unmodified features pass through another transformation; an (input results, modified results) tuple is returned as SSL output.
@abstractmethod
def transformation(self, sample_data: Dict) -> Tuple[Dict, Dict]:
    """Transform a sample feature dictionary into SSL input and target.

    Either input, target or both can be left as `None`. Transformers can be configured to replace `None` SSL targets
    with the input sample at runtime and/or to replace `None` SSL inputs with a new augmentation of the input sample.
    If the keys of the feature dictionaries are recognized by the transformer, they will be augmented before
    all features are stacked together.

    Parameters
    ----------
    sample_data : dict
        a feature dictionary

    Returns
    -------
    ssl_input : dict | torch.Tensor
        a feature dictionary of SSL inputs, or a placeholder scalar tensor
        (e.g. ``torch.tensor(float("nan"))``) when there is no SSL input
    ssl_target : dict | torch.Tensor
        a feature dictionary of SSL targets, or a placeholder scalar tensor
        (e.g. ``torch.tensor(float("nan"))``) when there is no SSL target

    """

Transform a sample feature dictionary into SSL input and target.

Either input, target or both can be left as None. Transformers can be configured to replace None SSL targets with the input sample at runtime and/or to replace None SSL inputs with a new augmentation of the input sample. If the keys of the feature dictionaries are recognized by the transformer, they will be augmented before all features are stacked together.

Parameters

sample_data : dict a feature dictionary

Returns

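ssl_input : dict | torch.Tensor — a feature dictionary of SSL inputs, or a placeholder scalar tensor (e.g. torch.tensor(float("nan"))) when there is no SSL input. ssl_target : dict | torch.Tensor — a feature dictionary of SSL targets, or a placeholder scalar tensor (e.g. torch.tensor(float("nan"))) when there is no SSL target.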

@abstractmethod
def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
    """Calculate the SSL loss.

    Parameters
    ----------
    predicted : torch.Tensor
        output of the SSL module
    target : torch.Tensor
        augmented and stacked SSL target

    Returns
    -------
    loss : float
        the loss value

    """

Calculate the SSL loss.

Parameters

predicted : torch.Tensor — output of the SSL module. target : torch.Tensor — augmented and stacked SSL target.

Returns

loss : float the loss value

@abstractmethod
def construct_module(self) -> nn.Module:
    """Construct the SSL module.

    Returns
    -------
    ssl_module : torch.nn.Module
        a neural net module that takes features extracted by a model's feature extractor as input and
        returns SSL output

    """

Construct the SSL module.

Returns

ssl_module : torch.nn.Module a neural net module that takes features extracted by a model's feature extractor as input and returns SSL output

class EmptySSL(SSLConstructor):
    """Empty SSL class.

    A no-op SSL constructor: it yields NaN placeholder tensors instead of SSL data,
    always reports a zero loss and builds no SSL module.
    """

    def transformation(self, sample_data: Dict) -> Tuple:
        """Empty transformation.

        Ignores ``sample_data`` and returns two separate scalar NaN tensors as the
        (SSL input, SSL target) placeholders.
        """
        nan_value = float("nan")
        ssl_input = torch.tensor(nan_value)
        ssl_target = torch.tensor(nan_value)
        return ssl_input, ssl_target

    def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
        """Empty loss.

        Always returns ``0``, regardless of the arguments.
        """
        return 0

    def construct_module(self) -> None:
        """Empty module.

        Returns ``None``: no SSL module is constructed.
        """
        return None

Empty SSL class.

def transformation(self, sample_data: Dict) -> Tuple:
    """Empty transformation.

    Ignores ``sample_data`` and returns two separate scalar NaN tensors as the
    (SSL input, SSL target) placeholders.
    """
    nan_value = float("nan")
    return torch.tensor(nan_value), torch.tensor(nan_value)

Empty transformation.

def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
    """Empty loss.

    Always returns ``0``, regardless of the arguments.
    """
    return 0

Empty loss.

def construct_module(self) -> None:
    """Empty module.

    Returns ``None``: no SSL module is constructed.
    """
    return None

Empty module.

Inherited Members
SSLConstructor
SSLConstructor
type