dlc2action.ssl.base_ssl
Abstract base class for defining self-supervised learning (SSL) tasks
#
# Copyright 2020-2022 by A. Mathis Group and contributors. All rights reserved.
#
# This project and all its files are licensed under GNU AGPLv3 or later version. A copy is included in dlc2action/LICENSE.AGPL.
#
"""
Abstract class for defining SSL tasks
"""

from typing import Callable, Tuple, Dict, Union
from abc import ABC, abstractmethod
import torch
from torch import nn


class SSLConstructor(ABC):
    """
    A base class for all SSL constructors

    An SSL method is defined by four things: a *transformation* that maps a sample into SSL input and target,
    a neural net *module* that takes features as input and predicts the SSL target, a *type* and a *loss function*.
    """

    type = "none"
    """
    The `type` parameter defines interaction with the model:

    - `'ssl_input'`: a modification of the input data passes through the base network feature extraction module and the
        SSL module; it is returned as SSL output and compared to the SSL target (or, if it is None, to the input data);
    - `'ssl_output'`: the input data passes through the base network feature extraction module and the SSL module; it
        is returned as SSL output and compared to the SSL target (or, if it is None, to the input data);
    - `'contrastive'`: the input data and its modification pass through the base network feature extraction module and
        the SSL module; an (input results, modification results) tuple is returned as SSL output;
    - `'contrastive_2layers'`: the input data and its modification pass through the base network feature extraction
        module; the output of the second feature extraction layer for the modified data goes through an SSL module and
        then, optionally, that result and the first-level unmodified features pass through another transformation;
        an (input results, modified results) tuple is returned as SSL output.

    The default value, `'none'`, is the one inherited by `EmptySSL`.
    """

    def __init__(self, *args, **kwargs) -> None:
        """
        Accept and ignore any positional and keyword arguments.
        """

    @abstractmethod
    def transformation(
        self, sample_data: Dict
    ) -> Tuple[Union[Dict, torch.Tensor, None], Union[Dict, torch.Tensor, None]]:
        """
        Transform a sample feature dictionary into SSL input and target

        Either input, target or both can be left as `None`. Transformers can be configured to replace `None` SSL
        targets with the input sample at runtime and/or to replace `None` SSL inputs with a new augmentation of the
        input sample. If the keys of the feature dictionaries are recognized by the transformer, they will be
        augmented before all features are stacked together.

        Parameters
        ----------
        sample_data : dict
            a feature dictionary

        Returns
        -------
        ssl_input : dict | torch.Tensor | None
            a feature dictionary of SSL inputs, or a placeholder (`None`, or a NaN scalar tensor
            `torch.tensor(float("nan"))` as returned by `EmptySSL`)
        ssl_target : dict | torch.Tensor | None
            a feature dictionary of SSL targets, or a placeholder (`None`, or a NaN scalar tensor)
        """

    @abstractmethod
    def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
        """
        Calculate the SSL loss

        Parameters
        ----------
        predicted : torch.Tensor
            output of the SSL module
        target : torch.Tensor
            augmented and stacked SSL target

        Returns
        -------
        loss : float
            the loss value
        """

    @abstractmethod
    def construct_module(self) -> nn.Module:
        """
        Construct the SSL module

        Returns
        -------
        ssl_module : torch.nn.Module
            a neural net module that takes features extracted by a model's feature extractor as input and
            returns SSL output
        """


class EmptySSL(SSLConstructor):
    """
    Empty SSL class

    Its transformation returns NaN placeholders, its loss is always 0 and it builds no module.
    """

    def transformation(self, sample_data: Dict) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Empty transformation

        `sample_data` is ignored; two separate NaN scalar tensors are returned as the
        (SSL input, SSL target) placeholders.
        """

        return (torch.tensor(float("nan")), torch.tensor(float("nan")))

    def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
        """
        Empty loss

        Always returns 0, whatever the inputs.
        """

        return 0

    def construct_module(self) -> None:
        """
        Empty module

        No SSL module is built, so `None` is returned.
        """

        return None
class SSLConstructor(ABC):
    """
    A base class for all SSL constructors

    An SSL method is defined by four things: a *transformation* that maps a sample into SSL input and target,
    a neural net *module* that takes features as input and predicts the SSL target, a *type* and a *loss function*.
    """

    type = "none"
    """
    The `type` parameter defines interaction with the model:

    - `'ssl_input'`: a modification of the input data passes through the base network feature extraction module and the
        SSL module; it is returned as SSL output and compared to the SSL target (or, if it is None, to the input data);
    - `'ssl_output'`: the input data passes through the base network feature extraction module and the SSL module; it
        is returned as SSL output and compared to the SSL target (or, if it is None, to the input data);
    - `'contrastive'`: the input data and its modification pass through the base network feature extraction module and
        the SSL module; an (input results, modification results) tuple is returned as SSL output;
    - `'contrastive_2layers'`: the input data and its modification pass through the base network feature extraction
        module; the output of the second feature extraction layer for the modified data goes through an SSL module and
        then, optionally, that result and the first-level unmodified features pass through another transformation;
        an (input results, modified results) tuple is returned as SSL output.

    Subclasses that do not override it keep the default value, `'none'`.
    """

    def __init__(self, *args, **kwargs) -> None:
        """
        Accept and ignore any positional and keyword arguments.
        """

    @abstractmethod
    def transformation(
        self, sample_data: Dict
    ) -> Tuple[Union[Dict, torch.Tensor, None], Union[Dict, torch.Tensor, None]]:
        """
        Transform a sample feature dictionary into SSL input and target

        Either input, target or both can be left as `None`. Transformers can be configured to replace `None` SSL
        targets with the input sample at runtime and/or to replace `None` SSL inputs with a new augmentation of the
        input sample. If the keys of the feature dictionaries are recognized by the transformer, they will be
        augmented before all features are stacked together.

        Parameters
        ----------
        sample_data : dict
            a feature dictionary

        Returns
        -------
        ssl_input : dict | torch.Tensor | None
            a feature dictionary of SSL inputs, or a placeholder (`None`, or a NaN scalar tensor
            `torch.tensor(float("nan"))`)
        ssl_target : dict | torch.Tensor | None
            a feature dictionary of SSL targets, or a placeholder (`None`, or a NaN scalar tensor)
        """

    @abstractmethod
    def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
        """
        Calculate the SSL loss

        Parameters
        ----------
        predicted : torch.Tensor
            output of the SSL module
        target : torch.Tensor
            augmented and stacked SSL target

        Returns
        -------
        loss : float
            the loss value
        """

    @abstractmethod
    def construct_module(self) -> nn.Module:
        """
        Construct the SSL module

        Returns
        -------
        ssl_module : torch.nn.Module
            a neural net module that takes features extracted by a model's feature extractor as input and
            returns SSL output
        """
A base class for all SSL constructors
An SSL method is defined by four things: a transformation that maps a sample into SSL input and target, a neural net module that takes features as input and predicts the SSL target, a type and a loss function.
The `type` parameter defines how the SSL task interacts with the model:
- `'ssl_input'`: a modification of the input data passes through the base network feature extraction module and the SSL module; it is returned as SSL output and compared to the SSL target (or, if it is None, to the input data);
- `'ssl_output'`: the input data passes through the base network feature extraction module and the SSL module; it is returned as SSL output and compared to the SSL target (or, if it is None, to the input data);
- `'contrastive'`: the input data and its modification pass through the base network feature extraction module and the SSL module; an (input results, modification results) tuple is returned as SSL output;
- `'contrastive_2layers'`: the input data and its modification pass through the base network feature extraction module; the output of the second feature extraction layer for the modified data goes through an SSL module and then, optionally, that result and the first-level unmodified features pass through another transformation; an (input results, modified results) tuple is returned as SSL output.
@abstractmethod
def transformation(
    self, sample_data: Dict
) -> Tuple[Union[Dict, torch.Tensor, None], Union[Dict, torch.Tensor, None]]:
    """
    Transform a sample feature dictionary into SSL input and target

    Either input, target or both can be left as `None`. Transformers can be configured to replace `None` SSL
    targets with the input sample at runtime and/or to replace `None` SSL inputs with a new augmentation of the
    input sample. If the keys of the feature dictionaries are recognized by the transformer, they will be
    augmented before all features are stacked together.

    Parameters
    ----------
    sample_data : dict
        a feature dictionary

    Returns
    -------
    ssl_input : dict | torch.Tensor | None
        a feature dictionary of SSL inputs, or a placeholder (`None`, or a NaN scalar tensor
        `torch.tensor(float("nan"))`)
    ssl_target : dict | torch.Tensor | None
        a feature dictionary of SSL targets, or a placeholder (`None`, or a NaN scalar tensor)
    """
Transform a sample feature dictionary into SSL input and target
Either input, target or both can be left as `None`. Transformers can be configured to replace `None` SSL targets with the input sample at runtime and/or to replace `None` SSL inputs with a new augmentation of the input sample. If the keys of the feature dictionaries are recognized by the transformer, they will be augmented before all features are stacked together.
Parameters
sample_data : dict — a feature dictionary
Returns
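ssl_input : dict | torch.Tensor | None — a feature dictionary of SSL inputs, or a placeholder (`None`, or a NaN scalar tensor `torch.tensor(float("nan"))`)
ssl_target : dict | torch.Tensor | None — a feature dictionary of SSL targets, or a placeholder (`None`, or a NaN scalar tensor)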
@abstractmethod
def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
    """
    Compute the loss of the SSL task

    Parameters
    ----------
    predicted : torch.Tensor
        what the SSL module produced
    target : torch.Tensor
        the SSL target, after augmentation and stacking

    Returns
    -------
    loss : float
        value of the SSL loss
    """
Calculate the SSL loss
Parameters
predicted : torch.Tensor — output of the SSL module
target : torch.Tensor — augmented and stacked SSL target
Returns
loss : float — the loss value
@abstractmethod
def construct_module(self) -> nn.Module:
    """
    Build the neural network used by the SSL task

    Returns
    -------
    ssl_module : torch.nn.Module
        a module that receives the features produced by a model's feature extractor
        and returns the SSL output
    """
Construct the SSL module
Returns
ssl_module : torch.nn.Module — a neural net module that takes features extracted by a model's feature extractor as input and returns SSL output
class EmptySSL(SSLConstructor):
    """
    Empty SSL class

    A constructor whose transformation, loss and module are all empty.
    """

    def transformation(self, sample_data: Dict) -> Tuple:
        """
        Empty transformation: the sample is ignored and two NaN scalars come back
        """
        placeholders = tuple(torch.tensor(float("nan")) for _ in range(2))
        return placeholders

    def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
        """
        Empty loss: always zero
        """
        zero_loss = 0
        return zero_loss

    def construct_module(self) -> None:
        """
        Empty module: there is nothing to build
        """
        return
Empty SSL class: its transformation returns NaN placeholders, its loss is always 0 and it builds no module.
def transformation(self, sample_data: Dict) -> Tuple:
    """
    Empty transformation

    The input sample is ignored; two separate NaN scalar tensors are returned
    as the (SSL input, SSL target) pair.
    """
    nan_pair = tuple(torch.tensor(float("nan")) for _ in range(2))
    return nan_pair
Empty transformation
def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
    """
    Empty loss

    Both arguments are ignored and the constant 0 is returned.
    """
    no_loss = 0
    return no_loss
Empty loss