dlc2action.ssl.base_ssl
Abstract class for defining SSL tasks.
#
# Copyright 2020-present by A. Mathis Group and contributors. All rights reserved.
#
# This project and all its files are licensed under GNU AGPLv3 or later version.
# A copy is included in dlc2action/LICENSE.AGPL.
#
"""Abstract class for defining SSL tasks."""

from typing import Callable, Tuple, Dict, Union
from abc import ABC, abstractmethod
import torch
from torch import nn

# NOTE(review): this seeds torch's global RNG as an import-time side effect;
# kept as-is because existing code may rely on it.
torch.manual_seed(0)


class SSLConstructor(ABC):
    """A base class for all SSL constructors.

    An SSL method is defined by four things: a *transformation* that maps a sample into SSL input and target,
    a neural net *module* that takes features as input and predicts the SSL target, a *type* and a *loss function*.
    """

    type = "none"
    """
    The `type` parameter defines interaction with the model:

    - `'ssl_input'`: a modification of the input data passes through the base network feature extraction module and the
      SSL module; it is returned as SSL output and compared to SSL target (or, if it is None, to the input data),
    - `'ssl_output'`: the input data passes through the base network feature extraction module and the SSL module; it
      is returned as SSL output and compared to SSL target (or, if it is None, to the input data),
    - `'contrastive'`: the input data and its modification pass through the base network feature extraction module and
      the SSL module; an (input results, modification results) tuple is returned as SSL output,
    - `'contrastive_2layers'`: the input data and its modification pass through the base network feature extraction module;
      the output of the second feature extraction layer for the modified data goes through an SSL module and then,
      optionally, that result and the first-level unmodified features pass another transformation;
      an (input results, modified results) tuple is returned as SSL output.
    """

    def __init__(self, *args, **kwargs) -> None:
        """Initialize the SSL constructor.

        The base implementation accepts and ignores any positional and keyword arguments.
        """
        ...

    @abstractmethod
    def transformation(
        self, sample_data: Dict
    ) -> Tuple[Union[Dict, torch.Tensor, None], Union[Dict, torch.Tensor, None]]:
        """Transform a sample feature dictionary into SSL input and target.

        Either input, target or both can be left as `None` (`EmptySSL` returns a
        `torch.tensor(float('nan'))` placeholder instead). Transformers can be configured to replace `None` SSL
        targets with the input sample at runtime and/or to replace `None` SSL inputs with a new augmentation of the
        input sample. If the keys of the feature dictionaries are recognized by the transformer, they will be
        augmented before all features are stacked together.

        Parameters
        ----------
        sample_data : dict
            a feature dictionary

        Returns
        -------
        ssl_input : dict | torch.Tensor
            a feature dictionary of SSL inputs, or a `torch.tensor(float('nan'))` placeholder
        ssl_target : dict | torch.Tensor
            a feature dictionary of SSL targets, or a `torch.tensor(float('nan'))` placeholder

        """

    @abstractmethod
    def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
        """Calculate the SSL loss.

        Parameters
        ----------
        predicted : torch.Tensor
            output of the SSL module
        target : torch.Tensor
            augmented and stacked SSL target

        Returns
        -------
        loss : float
            the loss value

        """

    @abstractmethod
    def construct_module(self) -> nn.Module:
        """Construct the SSL module.

        Returns
        -------
        ssl_module : torch.nn.Module
            a neural net module that takes features extracted by a model's feature extractor as input and
            returns SSL output

        """


class EmptySSL(SSLConstructor):
    """Empty SSL class: NaN placeholder transformation, zero loss and no module."""

    def transformation(self, sample_data: Dict) -> Tuple:
        """Return NaN placeholder tensors for both the SSL input and the SSL target."""
        return (torch.tensor(float("nan")), torch.tensor(float("nan")))

    def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
        """Return a constant zero loss; the arguments are ignored."""
        return 0

    def construct_module(self) -> None:
        """Return `None`: this constructor defines no SSL module."""
        return None
class SSLConstructor(ABC):
    """A base class for all SSL constructors.

    An SSL method is defined by four things: a *transformation* that maps a sample into SSL input and target,
    a neural net *module* that takes features as input and predicts the SSL target, a *type* and a *loss function*.
    """

    type = "none"
    """
    The `type` parameter defines interaction with the model:

    - `'ssl_input'`: a modification of the input data passes through the base network feature extraction module and the
      SSL module; it is returned as SSL output and compared to SSL target (or, if it is None, to the input data),
    - `'ssl_output'`: the input data passes through the base network feature extraction module and the SSL module; it
      is returned as SSL output and compared to SSL target (or, if it is None, to the input data),
    - `'contrastive'`: the input data and its modification pass through the base network feature extraction module and
      the SSL module; an (input results, modification results) tuple is returned as SSL output,
    - `'contrastive_2layers'`: the input data and its modification pass through the base network feature extraction module;
      the output of the second feature extraction layer for the modified data goes through an SSL module and then,
      optionally, that result and the first-level unmodified features pass another transformation;
      an (input results, modified results) tuple is returned as SSL output.
    """

    def __init__(self, *args, **kwargs) -> None:
        """Initialize the SSL constructor.

        The base implementation accepts and ignores any positional and keyword arguments.
        """
        ...

    @abstractmethod
    def transformation(
        self, sample_data: Dict
    ) -> Tuple[Union[Dict, torch.Tensor, None], Union[Dict, torch.Tensor, None]]:
        """Transform a sample feature dictionary into SSL input and target.

        Either input, target or both can be left as `None` (`EmptySSL` returns a
        `torch.tensor(float('nan'))` placeholder instead). Transformers can be configured to replace `None` SSL
        targets with the input sample at runtime and/or to replace `None` SSL inputs with a new augmentation of the
        input sample. If the keys of the feature dictionaries are recognized by the transformer, they will be
        augmented before all features are stacked together.

        Parameters
        ----------
        sample_data : dict
            a feature dictionary

        Returns
        -------
        ssl_input : dict | torch.Tensor
            a feature dictionary of SSL inputs, or a `torch.tensor(float('nan'))` placeholder
        ssl_target : dict | torch.Tensor
            a feature dictionary of SSL targets, or a `torch.tensor(float('nan'))` placeholder

        """

    @abstractmethod
    def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
        """Calculate the SSL loss.

        Parameters
        ----------
        predicted : torch.Tensor
            output of the SSL module
        target : torch.Tensor
            augmented and stacked SSL target

        Returns
        -------
        loss : float
            the loss value

        """

    @abstractmethod
    def construct_module(self) -> nn.Module:
        """Construct the SSL module.

        Returns
        -------
        ssl_module : torch.nn.Module
            a neural net module that takes features extracted by a model's feature extractor as input and
            returns SSL output

        """
A base class for all SSL constructors.
An SSL method is defined by four things: a transformation that maps a sample into SSL input and target, a neural net module that takes features as input and predicts the SSL target, a type, and a loss function.
The `type` parameter defines interaction with the model:
- `'ssl_input'`: a modification of the input data passes through the base network feature extraction module and the SSL module; it is returned as SSL output and compared to the SSL target (or, if it is None, to the input data),
- `'ssl_output'`: the input data passes through the base network feature extraction module and the SSL module; it is returned as SSL output and compared to the SSL target (or, if it is None, to the input data),
- `'contrastive'`: the input data and its modification pass through the base network feature extraction module and the SSL module; an (input results, modification results) tuple is returned as SSL output,
- `'contrastive_2layers'`: the input data and its modification pass through the base network feature extraction module; the output of the second feature extraction layer for the modified data goes through an SSL module and then, optionally, that result and the first-level unmodified features pass another transformation; an (input results, modified results) tuple is returned as SSL output.
@abstractmethod
def transformation(
    self, sample_data: Dict
) -> Tuple[Union[Dict, torch.Tensor, None], Union[Dict, torch.Tensor, None]]:
    """Transform a sample feature dictionary into SSL input and target.

    Either input, target or both can be left as `None` (`EmptySSL` returns a
    `torch.tensor(float('nan'))` placeholder instead). Transformers can be configured to replace `None` SSL
    targets with the input sample at runtime and/or to replace `None` SSL inputs with a new augmentation of the
    input sample. If the keys of the feature dictionaries are recognized by the transformer, they will be
    augmented before all features are stacked together.

    Parameters
    ----------
    sample_data : dict
        a feature dictionary

    Returns
    -------
    ssl_input : dict | torch.Tensor
        a feature dictionary of SSL inputs, or a `torch.tensor(float('nan'))` placeholder
    ssl_target : dict | torch.Tensor
        a feature dictionary of SSL targets, or a `torch.tensor(float('nan'))` placeholder

    """
Transform a sample feature dictionary into SSL input and target.
Either input, target or both can be left as `None`. Transformers can be configured to replace `None` SSL targets with the input sample at runtime and/or to replace `None` SSL inputs with a new augmentation of the input sample. If the keys of the feature dictionaries are recognized by the transformer, they will be augmented before all features are stacked together.
Parameters
sample_data : dict a feature dictionary
Returns
ssl_input : dict | torch.Tensor a feature dictionary of SSL inputs, or a torch.tensor(float('nan')) placeholder ssl_target : dict | torch.Tensor a feature dictionary of SSL targets, or a torch.tensor(float('nan')) placeholder
@abstractmethod
def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
    """Compute the self-supervised loss.

    Concrete SSL constructors must override this method.

    Parameters
    ----------
    predicted : torch.Tensor
        what the SSL module produced
    target : torch.Tensor
        the SSL target, after augmentation and stacking

    Returns
    -------
    loss : float
        the computed loss value

    """
    ...
Calculate the SSL loss.
Parameters
predicted : torch.Tensor output of the SSL module target : torch.Tensor augmented and stacked SSL_target
Returns
loss : float the loss value
@abstractmethod
def construct_module(self) -> nn.Module:
    """Build the network that solves the SSL task.

    Concrete SSL constructors must override this method.

    Returns
    -------
    ssl_module : torch.nn.Module
        a neural network that consumes the features produced by a model's feature extractor and
        emits the SSL output

    """
    ...
Construct the SSL module.
Returns
ssl_module : torch.nn.Module a neural net module that takes features extracted by a model's feature extractor as input and returns SSL output
class EmptySSL(SSLConstructor):
    """Placeholder SSL constructor: NaN placeholder transformation, zero loss and no module."""

    def transformation(self, sample_data: Dict) -> Tuple:
        """Return NaN placeholder tensors in place of the SSL input and target."""
        placeholder_input = torch.tensor(float("nan"))
        placeholder_target = torch.tensor(float("nan"))
        return placeholder_input, placeholder_target

    def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
        """Return a constant zero loss, ignoring both arguments."""
        zero_loss = 0
        return zero_loss

    def construct_module(self) -> None:
        """Return `None`; this constructor defines no SSL module."""
        module = None
        return module
Empty SSL class.
def transformation(self, sample_data: Dict) -> Tuple:
    """Return NaN placeholder tensors in place of the SSL input and target.

    ``sample_data`` is ignored.
    """
    placeholder_input = torch.tensor(float("nan"))
    placeholder_target = torch.tensor(float("nan"))
    return placeholder_input, placeholder_target
Empty transformation.
def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
    """Return a constant zero loss, ignoring ``predicted`` and ``target``."""
    zero_loss = 0
    return zero_loss
Empty loss.