dlc2action.ssl.base_ssl

Abstract class for defining SSL tasks

  1#
  2# Copyright 2020-2022 by A. Mathis Group and contributors. All rights reserved.
  3#
  4# This project and all its files are licensed under GNU AGPLv3 or later version. A copy is included in dlc2action/LICENSE.AGPL.
  5#
  6"""
  7Abstract class for defining SSL tasks
  8"""
  9
 10from typing import Callable, Tuple, Dict, Union
 11from abc import ABC, abstractmethod
 12import torch
 13from torch import nn
 14
 15
 16class SSLConstructor(ABC):
 17    """
 18    A base class for all SSL constructors
 19
 20    An SSL method is defined by three things: a *transformation* that maps a sample into SSL input and output,
 21    a neural net *module* that takes features as input and predicts SSL target, a *type* and a *loss function*.
 22    """
 23
 24    type = "none"
 25    """
 26    The `type` parameter defines interaction with the model:
 27    
 28    - `'ssl_input'`: a modification of the input data passes through the base network feature extraction module and the
 29    SSL module; it is returned as SSL output and compared to SSL target (or, if it is None, to the input data),
 30    - `'ssl_output'`:  the input data passes through the base network feature extraction module and the SSL module; it
 31    is returned as SSL output and compared to SSL target (or, if it is None, to the input data),
 32    - `'contrastive'`:  the input data and its modification pass through the base network feature extraction module and
 33    the SSL module; an (input results, modification results) tuple is returned as SSL output,
 34    - `'contrastive_2layers'`: the input data and its modification pass through the base network feature extraction module; 
 35    the output of the second feature extraction layer for the modified data goes through an SSL module and then, 
 36    optionally, that result and the first-level unmodified features pass another transformation; 
 37    an (input results, modified results) tuple is returned as SSL output,
 38    """
 39
 40    def __init__(self, *args, **kwargs) -> None:
 41        ...
 42
 43    @abstractmethod
 44    def transformation(self, sample_data: Dict) -> Tuple[Dict, Dict]:
 45        """
 46        Transform a sample feature dictionary into SSL input and target
 47
 48        Either input, target or both can be left as `None`. Transformers can be configured to replace `None` SSL targets
 49        with the input sample at runtime and/or to replace `None SSL` inputs with a new augmentation of the input sample.
 50        If the keys of the feature dictionaries are recognized by the transformer, they will be augmented before
 51        all features are stacked together.
 52
 53        Parameters
 54        ----------
 55        sample_data : dict
 56            a feature dictionary
 57
 58        Returns
 59        -------
 60        ssl_input : dict | torch.float('nan')
 61            a feature dictionary of SSL inputs
 62        ssl_target : dict | torch.float('nan')
 63            a feature dictionary of SSL targets
 64        """
 65
 66    @abstractmethod
 67    def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
 68        """
 69        Calculate the SSL loss
 70
 71        Parameters
 72        ----------
 73        predicted : torch.Tensor
 74            output of the SSL module
 75        target : torch.Tensor
 76            augmented and stacked SSL_target
 77
 78        Returns
 79        -------
 80        loss : float
 81            the loss value
 82        """
 83
 84    @abstractmethod
 85    def construct_module(self) -> nn.Module:
 86        """
 87        Construct the SSL module
 88
 89        Returns
 90        -------
 91        ssl_module : torch.nn.Module
 92            a neural net module that takes features extracted by a model's feature extractor as input and
 93            returns SSL output
 94        """
 95
 96
 97class EmptySSL(SSLConstructor):
 98    """
 99    Empty SSL class
100    """
101
102    def transformation(self, sample_data: Dict) -> Tuple:
103        """
104        Empty transformation
105        """
106
107        return (torch.tensor(float("nan")), torch.tensor(float("nan")))
108
109    def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
110        """
111        Empty loss
112        """
113
114        return 0
115
116    def construct_module(self) -> None:
117        """
118        Empty module
119        """
120
121        return None
class SSLConstructor(ABC):
    """
    A base class for all SSL constructors

    An SSL method is defined by four things: a *transformation* that maps a sample into SSL input and target,
    a neural net *module* that takes features as input and predicts the SSL target, a *type* that defines how
    the module interacts with the model, and a *loss function* that compares the SSL output with the SSL target.

    Subclasses must implement `transformation`, `loss` and `construct_module`; the class cannot be
    instantiated until all three are overridden.
    """

    type = "none"
    """
    The `type` parameter defines interaction with the model (the default on this base class is `'none'`):

    - `'ssl_input'`: a modification of the input data passes through the base network feature extraction module and the
    SSL module; it is returned as SSL output and compared to SSL target (or, if it is None, to the input data);
    - `'ssl_output'`: the input data passes through the base network feature extraction module and the SSL module; it
    is returned as SSL output and compared to SSL target (or, if it is None, to the input data);
    - `'contrastive'`: the input data and its modification pass through the base network feature extraction module and
    the SSL module; an (input results, modification results) tuple is returned as SSL output;
    - `'contrastive_2layers'`: the input data and its modification pass through the base network feature extraction
    module; the output of the second feature extraction layer for the modified data goes through an SSL module and
    then, optionally, that result and the first-level unmodified features pass another transformation;
    an (input results, modified results) tuple is returned as SSL output.
    """

    def __init__(self, *args, **kwargs) -> None:
        """
        Accept and ignore any positional and keyword arguments

        The base constructor stores nothing; subclasses define and store their own parameters.
        """

    @abstractmethod
    def transformation(self, sample_data: Dict) -> Tuple[Dict, Dict]:
        """
        Transform a sample feature dictionary into SSL input and target

        Either input, target or both can be left as `None`. Transformers can be configured to replace `None` SSL
        targets with the input sample at runtime and/or to replace `None` SSL inputs with a new augmentation of the
        input sample. If the keys of the feature dictionaries are recognized by the transformer, they will be
        augmented before all features are stacked together.

        Parameters
        ----------
        sample_data : dict
            a feature dictionary

        Returns
        -------
        ssl_input : dict | torch.Tensor
            a feature dictionary of SSL inputs, or a placeholder when there is no SSL input
            (`EmptySSL` returns a NaN scalar tensor, `torch.tensor(float("nan"))`)
        ssl_target : dict | torch.Tensor
            a feature dictionary of SSL targets, or a placeholder when there is no SSL target
            (`EmptySSL` returns a NaN scalar tensor, `torch.tensor(float("nan"))`)
        """

    @abstractmethod
    def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
        """
        Calculate the SSL loss

        Parameters
        ----------
        predicted : torch.Tensor
            output of the SSL module
        target : torch.Tensor
            augmented and stacked SSL target

        Returns
        -------
        loss : float
            the loss value
        """

    @abstractmethod
    def construct_module(self) -> nn.Module:
        """
        Construct the SSL module

        Subclasses that do not need an SSL module may return `None` (as `EmptySSL` does).

        Returns
        -------
        ssl_module : torch.nn.Module
            a neural net module that takes features extracted by a model's feature extractor as input and
            returns SSL output
        """

A base class for all SSL constructors

An SSL method is defined by four things: a transformation that maps a sample into SSL input and target, a neural net module that takes features as input and predicts the SSL target, a type and a loss function.

type = 'none'

The type parameter defines interaction with the model:

  • 'ssl_input': a modification of the input data passes through the base network feature extraction module and the SSL module; it is returned as SSL output and compared to SSL target (or, if it is None, to the input data),
  • 'ssl_output': the input data passes through the base network feature extraction module and the SSL module; it is returned as SSL output and compared to SSL target (or, if it is None, to the input data),
  • 'contrastive': the input data and its modification pass through the base network feature extraction module and the SSL module; an (input results, modification results) tuple is returned as SSL output,
  • 'contrastive_2layers': the input data and its modification pass through the base network feature extraction module; the output of the second feature extraction layer for the modified data goes through an SSL module and then, optionally, that result and the first-level unmodified features pass another transformation; an (input results, modified results) tuple is returned as SSL output.
@abstractmethod
def transformation(self, sample_data: Dict) -> Tuple[Dict, Dict]:
    """
    Transform a sample feature dictionary into SSL input and target

    Either input, target or both can be left as `None`. Transformers can be configured to replace `None` SSL
    targets with the input sample at runtime and/or to replace `None` SSL inputs with a new augmentation of the
    input sample. If the keys of the feature dictionaries are recognized by the transformer, they will be
    augmented before all features are stacked together.

    Parameters
    ----------
    sample_data : dict
        a feature dictionary

    Returns
    -------
    ssl_input : dict | torch.Tensor
        a feature dictionary of SSL inputs, or a placeholder when there is no SSL input
        (`EmptySSL` returns a NaN scalar tensor, `torch.tensor(float("nan"))`)
    ssl_target : dict | torch.Tensor
        a feature dictionary of SSL targets, or a placeholder when there is no SSL target
        (`EmptySSL` returns a NaN scalar tensor, `torch.tensor(float("nan"))`)
    """

Transform a sample feature dictionary into SSL input and target

Either input, target or both can be left as None. Transformers can be configured to replace None SSL targets with the input sample at runtime and/or to replace None SSL inputs with a new augmentation of the input sample. If the keys of the feature dictionaries are recognized by the transformer, they will be augmented before all features are stacked together.

Parameters

sample_data : dict — a feature dictionary

Returns

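ssl_input : dict | torch.Tensor — a feature dictionary of SSL inputs, or a NaN scalar tensor (torch.tensor(float('nan'))) when there is no SSL input
ssl_target : dict | torch.Tensor — a feature dictionary of SSL targets, or a NaN scalar tensor (torch.tensor(float('nan'))) when there is no SSL target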

@abstractmethod
def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
    """
    Compute the SSL loss

    Parameters
    ----------
    predicted : torch.Tensor
        the SSL output produced by the SSL module
    target : torch.Tensor
        the SSL target after augmentation and stacking

    Returns
    -------
    loss : float
        value of the SSL loss
    """

Calculate the SSL loss

Parameters

predicted : torch.Tensor — output of the SSL module
target : torch.Tensor — augmented and stacked SSL target

Returns

loss : float — the loss value

@abstractmethod
def construct_module(self) -> nn.Module:
    """
    Build the neural net module used for SSL

    Subclasses without an SSL module may return `None` (as `EmptySSL` does).

    Returns
    -------
    ssl_module : torch.nn.Module
        a module that consumes the features produced by a model's feature extractor and
        produces the SSL output
    """

Construct the SSL module

Returns

ssl_module : torch.nn.Module — a neural net module that takes features extracted by a model's feature extractor as input and returns SSL output

class EmptySSL(SSLConstructor):
    """
    Empty SSL class

    A no-op SSL constructor: it yields NaN placeholders instead of SSL input/target,
    a constant zero loss and no SSL module.
    """

    def transformation(self, sample_data: Dict) -> Tuple:
        """
        Produce placeholder SSL input and target

        `sample_data` is not used; two independent scalar tensors holding NaN are returned.
        """
        nan_value = float("nan")
        ssl_input = torch.tensor(nan_value)
        ssl_target = torch.tensor(nan_value)
        return (ssl_input, ssl_target)

    def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
        """
        Return a constant zero loss; `predicted` and `target` are ignored
        """
        no_loss = 0
        return no_loss

    def construct_module(self) -> None:
        """
        Return `None`: this constructor does not build an SSL module
        """
        module = None
        return module

Empty SSL class

def transformation(self, sample_data: Dict) -> Tuple:
    """
    Produce placeholder SSL input and target

    `sample_data` is not used; two independent scalar tensors holding NaN are returned.
    """
    nan_value = float("nan")
    ssl_input = torch.tensor(nan_value)
    ssl_target = torch.tensor(nan_value)
    return (ssl_input, ssl_target)

Empty transformation

def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
    """
    Return a constant zero loss; `predicted` and `target` are ignored
    """
    no_loss = 0
    return no_loss

Empty loss

def construct_module(self) -> None:
    """
    Return `None`: this constructor does not build an SSL module
    """
    module = None
    return module

Empty module

Inherited Members
SSLConstructor
SSLConstructor
type