dlc2action.ssl
Self-supervised tasks
Self-supervised learning tasks in `dlc2action` are implemented in the form of the `base_ssl.SSLConstructor`
abstract class.
In order to create a new task you need to define four things: a data *transformation*, a network *module*, a *loss
function* and the *type* of the task, which defines how the task interacts with the base network.
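As a rough illustration of how the four components fit together, the sketch below bundles them in a plain dataclass. `SSLTaskSpec` and its field names are hypothetical and do not reproduce the interface of `base_ssl.SSLConstructor`; only the four type strings are taken from this page.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, Optional, Tuple

# The task types named on this page.
SSL_TYPES = ("ssl_input", "ssl_target", "contrastive", "contrastive_2layers")

Features = Dict[str, Any]


@dataclass
class SSLTaskSpec:
    """Hypothetical bundle of the four parts of an SSL task (not the dlc2action API)."""

    # Maps a feature dictionary to (ssl_input, ssl_target) dictionaries; either may be None.
    transformation: Callable[[Features], Tuple[Optional[Features], Optional[Features]]]
    # Network module stacked on top of the base feature extractor.
    module: Callable[[Any], Any]
    # Takes (ssl_output, ssl_target) and returns a loss value.
    loss: Callable[[Any, Any], Any]
    # Defines how the task interacts with the base network.
    ssl_type: str

    def __post_init__(self) -> None:
        # Reject unknown type strings as soon as the task is defined.
        if self.ssl_type not in SSL_TYPES:
            raise ValueError(
                f"unknown SSL type {self.ssl_type!r}; expected one of {SSL_TYPES}"
            )


# Example: a contrastive task with placeholder components.
spec = SSLTaskSpec(
    transformation=lambda features: (None, None),
    module=lambda x: x,
    loss=lambda output, target: 0.0,
    ssl_type="contrastive",
)
```

Validating the type string at construction time makes a misspelled type fail early rather than partway through training.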
The *transformation* is applied to the model input data at runtime to generate the SSL input and target.
It receives the input data as a feature dictionary, so complex generalised transformations can be defined easily.
The SSL input and target should also be returned in similarly formatted dictionaries so that augmentations can be
applied to them correctly. If you are using the `dlc2action.task.task_dispatcher.TaskDispatcher`
(or `dlc2action.project.project.Project`) interface, they will be handled
differently depending on the type of the task:
- `'ssl_input'`: if the input generated by the transformation is `None`, an error will be raised; if the generated
  target is `None`, it will be replaced by the unmodified input data;
- `'ssl_target'`: the input generated by the transformation will be disregarded; if the generated target is `None`,
  it will be replaced by the unmodified input data;
- `'contrastive'`, `'contrastive_2layers'`: if the transformation returns `None` for the input, the input will be
  created by the transformer as an augmentation of the input data; if the target is `None`, it will stay `None`.
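A transformation of the kind described above could look like the following masking sketch. It assumes that each feature in the dictionary is a list of per-frame floats sharing one length (in `dlc2action` the values would be tensors), and `mask_transformation` is a hypothetical name, not part of the library. It returns the SSL input as a dictionary with the same keys and a `None` target, so under the `'ssl_input'` rule the unmodified input data would become the target.

```python
import random
from typing import Dict, List, Optional, Tuple

# Assumption: feature name -> list of per-frame values, all of equal length.
Features = Dict[str, List[float]]


def mask_transformation(
    features: Features, mask_fraction: float = 0.2, seed: Optional[int] = None
) -> Tuple[Features, Optional[Features]]:
    """Hypothetical 'ssl_input'-style transformation: zero out a random subset of frames."""
    if not features:
        return {}, None
    rng = random.Random(seed)
    n_frames = len(next(iter(features.values())))
    n_masked = int(round(mask_fraction * n_frames))
    # The same frames are masked in every feature so the dictionary stays consistent.
    masked = set(rng.sample(range(n_frames), n_masked))
    ssl_input = {
        key: [0.0 if i in masked else value for i, value in enumerate(values)]
        for key, values in features.items()
    }
    # Returning None as the target leaves the fallback to the task-type rules.
    return ssl_input, None


features = {"coord_x": [1.0, 2.0, 3.0, 4.0, 5.0], "coord_y": [5.0, 4.0, 3.0, 2.0, 1.0]}
ssl_input, ssl_target = mask_transformation(features, mask_fraction=0.4, seed=0)
```

Building a new dictionary instead of editing `features` in place keeps the original data intact, which matters when it later serves as the target.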
You can also set these rules manually with the `keep_target_none` and `generate_ssl_input` parameters of the
transformers.
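One possible reading of the two parameters, inferred from their names and from the rules listed above, is sketched below. `resolve_ssl_pair` and `ASSUMED_DEFAULTS` are hypothetical, the exact semantics of `keep_target_none` and `generate_ssl_input` in `dlc2action` may differ, and the `'ssl_target'` case, where the generated input is discarded anyway, is left out.

```python
from typing import Any, Callable, Dict, Optional, Tuple

Features = Dict[str, Any]

# Assumed flag values reproducing the per-type rules (not taken from dlc2action).
ASSUMED_DEFAULTS = {
    "ssl_input": {"keep_target_none": False, "generate_ssl_input": False},
    "contrastive": {"keep_target_none": True, "generate_ssl_input": True},
    "contrastive_2layers": {"keep_target_none": True, "generate_ssl_input": True},
}


def resolve_ssl_pair(
    data: Features,
    ssl_input: Optional[Features],
    ssl_target: Optional[Features],
    keep_target_none: bool,
    generate_ssl_input: bool,
    augment: Callable[[Features], Features],
) -> Tuple[Optional[Features], Optional[Features]]:
    """Hypothetical handling of None outputs from an SSL transformation."""
    if ssl_input is None:
        if generate_ssl_input:
            # Create the SSL input as an augmentation of the original data.
            ssl_input = augment(data)
        else:
            raise ValueError("the transformation returned None for the SSL input")
    if ssl_target is None and not keep_target_none:
        # Fall back to the unmodified input data as the target.
        ssl_target = data
    return ssl_input, ssl_target
```

Expressing the rules as two independent flags means a custom task can mix behaviours, for example generating a missing input while still falling back to the input data as the target.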
The SSL *module* is stacked with the base network feature extraction module as described above. The *loss function*
takes the SSL output and the SSL target as input and returns a loss value (all operations must be differentiable
with `torch`).
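For a reconstruction-style task, the loss could be a mean squared error between the SSL output and the SSL target. The sketch below uses plain Python lists to stay dependency-free; in `dlc2action` the same computation would be written with `torch` tensor operations (for example `torch.nn.functional.mse_loss`) so that autograd can backpropagate through it. The explicit gradient function only illustrates what that differentiability provides; both function names are hypothetical.

```python
from typing import List, Sequence


def mse_loss(ssl_output: Sequence[float], ssl_target: Sequence[float]) -> float:
    """Mean squared error between SSL output and SSL target (hypothetical SSL loss)."""
    if len(ssl_output) != len(ssl_target) or not ssl_output:
        raise ValueError("ssl_output and ssl_target must be non-empty and equally long")
    return sum((o - t) ** 2 for o, t in zip(ssl_output, ssl_target)) / len(ssl_output)


def mse_loss_grad(ssl_output: Sequence[float], ssl_target: Sequence[float]) -> List[float]:
    """Gradient of mse_loss with respect to each SSL output element: 2 * (o - t) / n."""
    n = len(ssl_output)
    return [2.0 * (o - t) / n for o, t in zip(ssl_output, ssl_target)]
```

With `torch`, only the forward computation needs to be written; autograd derives the equivalent of `mse_loss_grad` automatically, which is why every operation in the loss must be one that `torch` can differentiate.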
The available types are described in the documentation of `dlc2action.ssl.base_ssl.SSLConstructor`.

    #
    # Copyright 2020-2022 by A. Mathis Group and contributors. All rights reserved.
    #
    # This project and all its files are licensed under GNU AGPLv3 or later version. A copy is included in dlc2action/LICENSE.AGPL.
    #

    from dlc2action.ssl.contrastive import *
    from dlc2action.ssl.masked import *