dlc2action.ssl

Self-supervised tasks

Self-supervised learning tasks in dlc2action are implemented as subclasses of the base_ssl.SSLConstructor abstract class. To create a new task, you need to define four things: a data transformation, a network module, a loss function, and the type of the task, which defines how the SSL module interacts with the base network.
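As a rough illustration of these four ingredients, the sketch below bundles them into a plain-Python abstract class. The class name SSLTaskSketch and its method names are hypothetical stand-ins chosen for this example. Consult dlc2action.ssl.base_ssl.SSLConstructor for the actual abstract methods and their signatures.

```python
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional, Tuple

# Hypothetical layout: feature name -> feature values
Features = Dict[str, Any]


class SSLTaskSketch(ABC):
    """Hypothetical interface mirroring the four parts of an SSL task.

    This is NOT dlc2action's SSLConstructor; all names are illustrative only.
    """

    # (4) task type: decides how the SSL branch interacts with the base network
    type: str = "ssl_input"

    @abstractmethod
    def transformation(
        self, data: Features
    ) -> Tuple[Optional[Features], Optional[Features]]:
        """(1) Build an (ssl_input, ssl_target) pair from a feature dictionary."""

    @abstractmethod
    def construct_module(self) -> Any:
        """(2) Return the network module stacked on the base feature extractor."""

    @abstractmethod
    def loss(self, predicted: Any, target: Any) -> Any:
        """(3) Compare the SSL output with the SSL target and return a loss value."""
```

A concrete task would subclass it and implement all three methods. Python refuses to instantiate any subclass that leaves one of them abstract.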

The transformation is applied to the model input data at runtime to generate the SSL input and the SSL target. It receives the input data as a feature dictionary, so complex, generalised transformations can be defined easily. The SSL input and target should also be returned as similarly formatted dictionaries so that augmentations can be applied to them correctly. If you are using the dlc2action.task.task_dispatcher.TaskDispatcher (or dlc2action.project.project.Project) interface, the returned values are handled differently depending on the type of the task:
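For instance, a masking-style transformation could hide a random subset of frames in every feature and ask the network to reconstruct the originals. The sketch below assumes a hypothetical layout: each dictionary value is a list of frames, and all features have the same number of frames. dlc2action's own feature data may be laid out differently, so treat this as an illustration only. Both outputs keep the keys of the input dictionary, which is what lets augmentations be applied to them key by key.

```python
import random
from typing import Dict, List, Optional, Tuple

# Hypothetical layout: feature name -> list of frames, each frame a list of values
Features = Dict[str, List[List[float]]]


def masking_transformation(
    data: Features, mask_prob: float = 0.3, seed: Optional[int] = None
) -> Tuple[Features, Features]:
    """Zero out randomly chosen frames (the same frames in every feature).

    Returns (ssl_input, ssl_target) as dictionaries with the same keys as `data`.
    Assumes `data` is non-empty and every feature has the same number of frames.
    """
    rng = random.Random(seed)
    n_frames = len(next(iter(data.values())))
    masked = {i for i in range(n_frames) if rng.random() < mask_prob}

    ssl_input: Features = {}
    ssl_target: Features = {}
    for name, frames in data.items():
        # masked frames become zeros; other frames are copied so `data` is never mutated
        ssl_input[name] = [
            [0.0] * len(frame) if i in masked else list(frame)
            for i, frame in enumerate(frames)
        ]
        # the target is an untouched copy of the original features
        ssl_target[name] = [list(frame) for frame in frames]
    return ssl_input, ssl_target
```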

  • 'ssl_input': if the input generated by the transformation is None, an error is raised; if the generated target is None, it is replaced by the unmodified input data;
  • 'ssl_target': the input generated by the transformation is disregarded; if the generated target is None, it is replaced by the unmodified input data;
  • 'contrastive', 'contrastive_2layers': if the transformation returns None for the input, the transformer creates the input as an augmentation of the input data; if the target is None, it stays None.
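The three rules can be restated as a small plain-Python function. resolve_ssl_pair and its augment argument are hypothetical names used only for this sketch; they are not part of the dlc2action API, where this handling is done by the library itself.

```python
from typing import Any, Callable, Optional, Tuple


def resolve_ssl_pair(
    task_type: str,
    data: Any,
    ssl_input: Optional[Any],
    ssl_target: Optional[Any],
    augment: Callable[[Any], Any],
) -> Tuple[Optional[Any], Optional[Any]]:
    """Apply the None-handling rules for each SSL task type (hypothetical helper)."""
    if task_type == "ssl_input":
        if ssl_input is None:
            raise ValueError("an 'ssl_input' task needs the transformation to return an input")
        if ssl_target is None:
            ssl_target = data  # fall back to the unmodified input data
    elif task_type == "ssl_target":
        ssl_input = None  # the generated input is disregarded
        if ssl_target is None:
            ssl_target = data  # fall back to the unmodified input data
    elif task_type in ("contrastive", "contrastive_2layers"):
        if ssl_input is None:
            ssl_input = augment(data)  # create the input as an augmentation of the data
        # a None target is deliberately left as None
    else:
        raise ValueError(f"unknown SSL task type: {task_type!r}")
    return ssl_input, ssl_target
```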

You can also set these rules manually with the keep_target_none and generate_ssl_input parameters of the transformers.

The SSL module is stacked on top of the base network's feature extraction module in the way determined by the task type. The loss function takes the SSL output and the SSL target as input and returns a loss value (all operations must be differentiable with torch). The available task types are described in dlc2action.ssl.base_ssl.SSLConstructor.
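Here is a minimal PyTorch sketch of the stacking and the loss, assuming a "reconstruct the masked input" task. FeatureExtractor stands in for the base network's feature extraction module and ReconstructionHead for the SSL module. Both are hypothetical toy modules, and the (batch, channels, frames) tensor layout is an assumption. Because the loss is built only from torch operations, calling backward() propagates gradients through the SSL head into the shared feature extractor.

```python
import torch
import torch.nn.functional as F
from torch import nn


class FeatureExtractor(nn.Module):
    """Toy stand-in for the base network's feature extraction module."""

    def __init__(self, in_channels: int, hidden: int) -> None:
        super().__init__()
        self.conv = nn.Conv1d(in_channels, hidden, kernel_size=3, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, channels, frames) -> (batch, hidden, frames)
        return torch.relu(self.conv(x))


class ReconstructionHead(nn.Module):
    """Hypothetical SSL module stacked on top of the extracted features."""

    def __init__(self, hidden: int, out_channels: int) -> None:
        super().__init__()
        self.proj = nn.Conv1d(hidden, out_channels, kernel_size=1)

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        return self.proj(features)


def ssl_loss(predicted: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    # mean squared error uses only differentiable torch ops, so autograd handles it
    return F.mse_loss(predicted, target)


torch.manual_seed(0)
extractor = FeatureExtractor(in_channels=4, hidden=8)
head = ReconstructionHead(hidden=8, out_channels=4)

target = torch.randn(2, 4, 16)  # e.g. the unmasked features
ssl_input = target.clone()
ssl_input[:, :, 5:9] = 0.0  # e.g. a masked stretch of frames

prediction = head(extractor(ssl_input))
loss = ssl_loss(prediction, target)
loss.backward()  # gradients reach both the SSL head and the shared extractor
```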

```python
#
# Copyright 2020-2022 by A. Mathis Group and contributors. All rights reserved.
#
# This project and all its files are licensed under GNU AGPLv3 or later version. A copy is included in dlc2action/LICENSE.AGPL.
#
"""
## Self-supervised tasks

Self-supervised learning tasks in `dlc2action` are implemented as subclasses of the `base_ssl.SSLConstructor`
abstract class.
To create a new task, you need to define four things: a data *transformation*, a network *module*, a *loss
function*, and the *type* of the task, which defines how the SSL module interacts with the base network.

The *transformation* is applied to the model input data at runtime to generate the SSL input and the SSL target.
It receives the input data as a feature dictionary, so complex, generalised transformations can be defined easily.
The SSL input and target should also be returned as similarly formatted dictionaries so that augmentations can be
applied to them correctly. If you are using the `dlc2action.task.task_dispatcher.TaskDispatcher`
(or `dlc2action.project.project.Project`) interface, the returned values are handled
differently depending on the type of the task:

- `'ssl_input'`: if the input generated by the transformation is `None`, an error is raised; if the generated target
    is `None`, it is replaced by the unmodified input data;
- `'ssl_target'`: the input generated by the transformation is disregarded; if the generated target is `None`, it is
    replaced by the unmodified input data;
- `'contrastive'`, `'contrastive_2layers'`: if the transformation returns `None` for the input, the transformer
    creates the input as an augmentation of the input data; if the target is `None`, it stays `None`.

You can also set these rules manually with the `keep_target_none` and `generate_ssl_input` parameters of the
transformers.

The SSL *module* is stacked on top of the base network's feature extraction module in the way determined by the task
type. The *loss function* takes the SSL output and the SSL target as input and returns a loss value (all operations
must be differentiable with `torch`). The available task types are described in `dlc2action.ssl.base_ssl.SSLConstructor`.
"""

from dlc2action.ssl.contrastive import *
from dlc2action.ssl.masked import *
```