dlc2action.ssl.segment_order

  1#
  2# Copyright 2020-2022 by A. Mathis Group and contributors. All rights reserved.
  3#
  4# This project and all its files are licensed under GNU AGPLv3 or later version. A copy is included in dlc2action/LICENSE.AGPL.
  5#
  6from typing import Dict, Tuple, Union, List
  7import torch
  8from dlc2action.ssl.base_ssl import SSLConstructor
  9from abc import ABC, abstractmethod
 10from dlc2action.ssl.modules import _FeatureExtractorTCN
 11import torch
 12from torch import nn
 13from copy import deepcopy
 14from itertools import permutations
 15
 16from torch.nn import CrossEntropyLoss, Linear, BCEWithLogitsLoss
 17
 18
 19class ReverseSSL(SSLConstructor, ABC):
 20    """
 21    A flip detection SSL
 22
 23    Reverse some of the segments and predict the flip with a binary classifier.
 24    """
 25
 26    type = "ssl_input"
 27
 28    def __init__(self, num_f_maps: torch.Size, len_segment: int) -> None:
 29        """
 30        Parameters
 31        ----------
 32        num_f_maps : torch.Size
 33            the number of input feature maps
 34        len_segment : int
 35            the length of the input segments
 36        """
 37
 38        super().__init__()
 39        self.ce = BCEWithLogitsLoss()
 40        if len(num_f_maps) > 1:
 41            raise RuntimeError(
 42                "The ContrastiveSSL constructor expects the input data to be 2-dimensional; "
 43                f"got {len(num_f_maps) + 1} dimensions"
 44            )
 45        num_f_maps = int(num_f_maps[0])
 46        self.pars = {
 47            "num_f_maps": num_f_maps,
 48            "len_segment": len_segment,
 49            "output_dim": 1,
 50            "kernel_1": 5,
 51            "kernel_2": 5,
 52            "stride": 2,
 53            "decrease_f_maps": True,
 54        }
 55
 56    def transformation(self, sample_data: Dict) -> Tuple:
 57        """
 58        Do the flip
 59        """
 60
 61        ssl_target = torch.randint(2, (1,), dtype=torch.float)
 62        ssl_input = deepcopy(sample_data)
 63        if ssl_target == 1:
 64            for key, value in sample_data.items():
 65                ssl_input[key] = value.flip(-1)
 66        return ssl_input, {"order": ssl_target}
 67
 68    def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
 69        """
 70        Cross-entropy loss
 71        """
 72
 73        loss = self.ce(predicted, target.squeeze())
 74        return loss
 75
 76    def construct_module(self) -> nn.Module:
 77        """
 78        Construct the SSL prediction module using the parameters specified at initialization
 79        """
 80
 81        module = _FeatureExtractorTCN(**self.pars)
 82        return module
 83
 84
 85class OrderSSL(SSLConstructor, ABC):
 86    """
 87    An order prediction SSL
 88
 89    Cut out segments from the features, permute them and predict the order.
 90    """
 91
 92    type = "ssl_target"
 93
 94    def __init__(
 95        self,
 96        num_f_maps: torch.Size,
 97        len_segment: int,
 98        num_segments: int = 3,
 99        ssl_features: int = 32,
100        skip_frames: int = 10,
101    ) -> None:
102        """
103        Parameters
104        ----------
105        num_f_maps : torch.Size
106            the number of the input feature maps
107        len_segment : int
108            the length of the input segments
109        num_segments : int, default 3
110            the number of segments to permute
111        ssl_features : int, default 32
112            the number of features per permuted segment
113        skip_frames : int, default 10
114            the number of frames to cut from each permuted segment
115        """
116
117        super().__init__()
118        self.ce = CrossEntropyLoss(ignore_index=-100)
119        if len(num_f_maps) > 1:
120            raise RuntimeError(
121                "The ContrastiveSSL constructor expects the input data to be 2-dimensional; "
122                f"got {len(num_f_maps) + 1} dimensions"
123            )
124        num_f_maps = int(num_f_maps[0])
125        self.orders = [list(x) for x in permutations(range(num_segments), num_segments)]
126        self.len_segment = len_segment // num_segments
127        self.num_segments = num_segments
128        self.skip_frames = skip_frames
129        self.pars = {
130            "num_f_maps": num_f_maps,
131            "len_segment": len_segment // num_segments,
132            "output_dim": ssl_features,
133            "kernel_1": 5,
134            "kernel_2": 5,
135            "stride": 2,
136            "decrease_f_maps": True,
137        }
138
139    def transformation(self, sample_data: Dict) -> Tuple:
140        """
141        Empty transformation
142        """
143
144        return torch.tensor(float("nan")), torch.tensor(float("nan"))
145
146    def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
147        """
148        Cross-entropy loss
149        """
150
151        predicted, target = predicted
152        loss = self.ce(predicted, target)
153        return loss
154
155    def construct_module(self) -> nn.Module:
156        """
157        Construct the SSL prediction module using the parameters specified at initialization
158        """
159
160        class Classifier(nn.Module):
161            def __init__(self, num_segments, num_classes, skip_frames, **pars):
162                super().__init__()
163                self.len_segment = pars["len_segment"]
164                pars["len_segment"] -= skip_frames
165                self.extractor = _FeatureExtractorTCN(**pars)
166                self.num_segments = num_segments
167                self.skip_frames = skip_frames
168                self.fc = Linear(pars["output_dim"] * self.num_segments, num_classes)
169                self.orders = torch.tensor(
170                    [list(x) for x in permutations(range(num_segments), num_segments)]
171                )
172
173            def forward(self, x):
174                target = torch.randint(len(self.orders), (x.shape[0],)).to(x.device)
175                order = self.orders[target]
176                x = x[:, :, : self.num_segments * self.len_segment]
177                B, F, L = x.shape
178                x = x.reshape((B, F, -1, self.len_segment))
179                x = x[:, :, :, : -self.skip_frames]
180                x = x[
181                    torch.arange(x.shape[0]).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1),
182                    torch.arange(x.shape[1]).unsqueeze(-1).unsqueeze(-1).unsqueeze(0),
183                    order.unsqueeze(1).unsqueeze(-1),
184                    torch.arange(x.shape[-1]).unsqueeze(0).unsqueeze(0).unsqueeze(0),
185                ]
186                x = x.transpose(1, 2).reshape(
187                    (-1, F, self.len_segment - self.skip_frames)
188                )
189                x = self.extractor(x).reshape((B, -1))
190                x = self.fc(x)
191                return (x, target)
192
193        module = Classifier(
194            self.num_segments,
195            len(self.orders),
196            skip_frames=self.skip_frames,
197            **self.pars,
198        )
199        return module
class ReverseSSL(dlc2action.ssl.base_ssl.SSLConstructor, abc.ABC):
class ReverseSSL(SSLConstructor, ABC):
    """
    A flip detection SSL

    Reverse some of the segments and predict the flip with a binary classifier.
    """

    type = "ssl_input"

    def __init__(self, num_f_maps: torch.Size, len_segment: int) -> None:
        """
        Parameters
        ----------
        num_f_maps : torch.Size
            the number of input feature maps
        len_segment : int
            the length of the input segments

        Raises
        ------
        RuntimeError
            if `num_f_maps` has more than one element
        """

        super().__init__()
        # binary cross-entropy computed directly on the logits
        self.ce = BCEWithLogitsLoss()
        if len(num_f_maps) > 1:
            raise RuntimeError(
                "The ReverseSSL constructor expects the input data to be 2-dimensional; "
                f"got {len(num_f_maps) + 1} dimensions"
            )
        num_f_maps = int(num_f_maps[0])
        # keyword arguments forwarded to _FeatureExtractorTCN in construct_module;
        # output_dim is 1 because a single flip logit is predicted
        self.pars = {
            "num_f_maps": num_f_maps,
            "len_segment": len_segment,
            "output_dim": 1,
            "kernel_1": 5,
            "kernel_2": 5,
            "stride": 2,
            "decrease_f_maps": True,
        }

    def transformation(self, sample_data: Dict) -> Tuple:
        """
        Do the flip

        With probability 1/2 every value of the sample is reversed along its
        last axis.

        Parameters
        ----------
        sample_data : dict
            the sample data; every value needs to support `.flip(-1)`

        Returns
        -------
        ssl_input : dict
            a copy of `sample_data`, with the values flipped if the target is 1
        ssl_target : dict
            a dictionary with a single `"order"` key holding a float tensor of
            shape `(1,)` (1 if the data was flipped, 0 otherwise)
        """

        ssl_target = torch.randint(2, (1,), dtype=torch.float)
        ssl_input = deepcopy(sample_data)
        if ssl_target == 1:
            for key, value in sample_data.items():
                ssl_input[key] = value.flip(-1)
        return ssl_input, {"order": ssl_target}

    def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Binary cross-entropy loss (on logits)

        Parameters
        ----------
        predicted : torch.Tensor
            the predicted flip logits
        target : torch.Tensor
            the flip targets (same number of elements as `predicted`)

        Returns
        -------
        loss : torch.Tensor
            the loss value
        """

        # BCEWithLogitsLoss requires identical shapes. A bare squeeze() would
        # also drop the batch dimension when the batch holds a single sample,
        # so the target is reshaped to the prediction's shape instead.
        loss = self.ce(predicted, target.reshape(predicted.shape))
        return loss

    def construct_module(self) -> nn.Module:
        """
        Construct the SSL prediction module using the parameters specified at initialization

        Returns
        -------
        module : nn.Module
            a `_FeatureExtractorTCN` built from `self.pars`
        """

        module = _FeatureExtractorTCN(**self.pars)
        return module

A flip detection SSL

Reverse some of the segments and predict the flip with a binary classifier.

ReverseSSL(num_f_maps: torch.Size, len_segment: int)
29    def __init__(self, num_f_maps: torch.Size, len_segment: int) -> None:
30        """
31        Parameters
32        ----------
33        num_f_maps : torch.Size
34            the number of input feature maps
35        len_segment : int
36            the length of the input segments
37        """
38
39        super().__init__()
40        self.ce = BCEWithLogitsLoss()
41        if len(num_f_maps) > 1:
42            raise RuntimeError(
43                "The ContrastiveSSL constructor expects the input data to be 2-dimensional; "
44                f"got {len(num_f_maps) + 1} dimensions"
45            )
46        num_f_maps = int(num_f_maps[0])
47        self.pars = {
48            "num_f_maps": num_f_maps,
49            "len_segment": len_segment,
50            "output_dim": 1,
51            "kernel_1": 5,
52            "kernel_2": 5,
53            "stride": 2,
54            "decrease_f_maps": True,
55        }

Parameters

  • num_f_maps : torch.Size — the number of input feature maps
  • len_segment : int — the length of the input segments

type = 'ssl_input'

The type parameter defines interaction with the model:

  • 'ssl_input': a modification of the input data passes through the base network feature extraction module and the SSL module; it is returned as SSL output and compared to SSL target (or, if it is None, to the input data),
  • 'ssl_output': the input data passes through the base network feature extraction module and the SSL module; it is returned as SSL output and compared to SSL target (or, if it is None, to the input data),
  • 'contrastive': the input data and its modification pass through the base network feature extraction module and the SSL module; an (input results, modification results) tuple is returned as SSL output,
  • 'contrastive_2layers': the input data and its modification pass through the base network feature extraction module; the output of the second feature extraction layer for the modified data goes through an SSL module and then, optionally, that result and the first-level unmodified features pass another transformation; an (input results, modified results) tuple is returned as SSL output,
def transformation(self, sample_data: Dict) -> Tuple:
57    def transformation(self, sample_data: Dict) -> Tuple:
58        """
59        Do the flip
60        """
61
62        ssl_target = torch.randint(2, (1,), dtype=torch.float)
63        ssl_input = deepcopy(sample_data)
64        if ssl_target == 1:
65            for key, value in sample_data.items():
66                ssl_input[key] = value.flip(-1)
67        return ssl_input, {"order": ssl_target}

Do the flip

def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
69    def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
70        """
71        Cross-entropy loss
72        """
73
74        loss = self.ce(predicted, target.squeeze())
75        return loss

Binary cross-entropy loss (computed on logits)

def construct_module(self) -> torch.nn.modules.module.Module:
77    def construct_module(self) -> nn.Module:
78        """
79        Construct the SSL prediction module using the parameters specified at initialization
80        """
81
82        module = _FeatureExtractorTCN(**self.pars)
83        return module

Construct the SSL prediction module using the parameters specified at initialization

class OrderSSL(dlc2action.ssl.base_ssl.SSLConstructor, abc.ABC):
 86class OrderSSL(SSLConstructor, ABC):
 87    """
 88    An order prediction SSL
 89
 90    Cut out segments from the features, permute them and predict the order.
 91    """
 92
 93    type = "ssl_target"
 94
 95    def __init__(
 96        self,
 97        num_f_maps: torch.Size,
 98        len_segment: int,
 99        num_segments: int = 3,
100        ssl_features: int = 32,
101        skip_frames: int = 10,
102    ) -> None:
103        """
104        Parameters
105        ----------
106        num_f_maps : torch.Size
107            the number of the input feature maps
108        len_segment : int
109            the length of the input segments
110        num_segments : int, default 3
111            the number of segments to permute
112        ssl_features : int, default 32
113            the number of features per permuted segment
114        skip_frames : int, default 10
115            the number of frames to cut from each permuted segment
116        """
117
118        super().__init__()
119        self.ce = CrossEntropyLoss(ignore_index=-100)
120        if len(num_f_maps) > 1:
121            raise RuntimeError(
122                "The ContrastiveSSL constructor expects the input data to be 2-dimensional; "
123                f"got {len(num_f_maps) + 1} dimensions"
124            )
125        num_f_maps = int(num_f_maps[0])
126        self.orders = [list(x) for x in permutations(range(num_segments), num_segments)]
127        self.len_segment = len_segment // num_segments
128        self.num_segments = num_segments
129        self.skip_frames = skip_frames
130        self.pars = {
131            "num_f_maps": num_f_maps,
132            "len_segment": len_segment // num_segments,
133            "output_dim": ssl_features,
134            "kernel_1": 5,
135            "kernel_2": 5,
136            "stride": 2,
137            "decrease_f_maps": True,
138        }
139
140    def transformation(self, sample_data: Dict) -> Tuple:
141        """
142        Empty transformation
143        """
144
145        return torch.tensor(float("nan")), torch.tensor(float("nan"))
146
147    def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
148        """
149        Cross-entropy loss
150        """
151
152        predicted, target = predicted
153        loss = self.ce(predicted, target)
154        return loss
155
156    def construct_module(self) -> nn.Module:
157        """
158        Construct the SSL prediction module using the parameters specified at initialization
159        """
160
161        class Classifier(nn.Module):
162            def __init__(self, num_segments, num_classes, skip_frames, **pars):
163                super().__init__()
164                self.len_segment = pars["len_segment"]
165                pars["len_segment"] -= skip_frames
166                self.extractor = _FeatureExtractorTCN(**pars)
167                self.num_segments = num_segments
168                self.skip_frames = skip_frames
169                self.fc = Linear(pars["output_dim"] * self.num_segments, num_classes)
170                self.orders = torch.tensor(
171                    [list(x) for x in permutations(range(num_segments), num_segments)]
172                )
173
174            def forward(self, x):
175                target = torch.randint(len(self.orders), (x.shape[0],)).to(x.device)
176                order = self.orders[target]
177                x = x[:, :, : self.num_segments * self.len_segment]
178                B, F, L = x.shape
179                x = x.reshape((B, F, -1, self.len_segment))
180                x = x[:, :, :, : -self.skip_frames]
181                x = x[
182                    torch.arange(x.shape[0]).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1),
183                    torch.arange(x.shape[1]).unsqueeze(-1).unsqueeze(-1).unsqueeze(0),
184                    order.unsqueeze(1).unsqueeze(-1),
185                    torch.arange(x.shape[-1]).unsqueeze(0).unsqueeze(0).unsqueeze(0),
186                ]
187                x = x.transpose(1, 2).reshape(
188                    (-1, F, self.len_segment - self.skip_frames)
189                )
190                x = self.extractor(x).reshape((B, -1))
191                x = self.fc(x)
192                return (x, target)
193
194        module = Classifier(
195            self.num_segments,
196            len(self.orders),
197            skip_frames=self.skip_frames,
198            **self.pars,
199        )
200        return module

An order prediction SSL

Cut out segments from the features, permute them and predict the order.

OrderSSL( num_f_maps: torch.Size, len_segment: int, num_segments: int = 3, ssl_features: int = 32, skip_frames: int = 10)
 95    def __init__(
 96        self,
 97        num_f_maps: torch.Size,
 98        len_segment: int,
 99        num_segments: int = 3,
100        ssl_features: int = 32,
101        skip_frames: int = 10,
102    ) -> None:
103        """
104        Parameters
105        ----------
106        num_f_maps : torch.Size
107            the number of the input feature maps
108        len_segment : int
109            the length of the input segments
110        num_segments : int, default 3
111            the number of segments to permute
112        ssl_features : int, default 32
113            the number of features per permuted segment
114        skip_frames : int, default 10
115            the number of frames to cut from each permuted segment
116        """
117
118        super().__init__()
119        self.ce = CrossEntropyLoss(ignore_index=-100)
120        if len(num_f_maps) > 1:
121            raise RuntimeError(
122                "The ContrastiveSSL constructor expects the input data to be 2-dimensional; "
123                f"got {len(num_f_maps) + 1} dimensions"
124            )
125        num_f_maps = int(num_f_maps[0])
126        self.orders = [list(x) for x in permutations(range(num_segments), num_segments)]
127        self.len_segment = len_segment // num_segments
128        self.num_segments = num_segments
129        self.skip_frames = skip_frames
130        self.pars = {
131            "num_f_maps": num_f_maps,
132            "len_segment": len_segment // num_segments,
133            "output_dim": ssl_features,
134            "kernel_1": 5,
135            "kernel_2": 5,
136            "stride": 2,
137            "decrease_f_maps": True,
138        }

Parameters

  • num_f_maps : torch.Size — the number of the input feature maps
  • len_segment : int — the length of the input segments
  • num_segments : int, default 3 — the number of segments to permute
  • ssl_features : int, default 32 — the number of features per permuted segment
  • skip_frames : int, default 10 — the number of frames to cut from each permuted segment

type = 'ssl_target'

The type parameter defines interaction with the model:

  • 'ssl_input': a modification of the input data passes through the base network feature extraction module and the SSL module; it is returned as SSL output and compared to SSL target (or, if it is None, to the input data),
  • 'ssl_output': the input data passes through the base network feature extraction module and the SSL module; it is returned as SSL output and compared to SSL target (or, if it is None, to the input data),
  • 'contrastive': the input data and its modification pass through the base network feature extraction module and the SSL module; an (input results, modification results) tuple is returned as SSL output,
  • 'contrastive_2layers': the input data and its modification pass through the base network feature extraction module; the output of the second feature extraction layer for the modified data goes through an SSL module and then, optionally, that result and the first-level unmodified features pass another transformation; an (input results, modified results) tuple is returned as SSL output,
def transformation(self, sample_data: Dict) -> Tuple:
140    def transformation(self, sample_data: Dict) -> Tuple:
141        """
142        Empty transformation
143        """
144
145        return torch.tensor(float("nan")), torch.tensor(float("nan"))

Empty transformation: returns a pair of NaN placeholder tensors.

def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
147    def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
148        """
149        Cross-entropy loss
150        """
151
152        predicted, target = predicted
153        loss = self.ce(predicted, target)
154        return loss

Cross-entropy loss

def construct_module(self) -> torch.nn.modules.module.Module:
156    def construct_module(self) -> nn.Module:
157        """
158        Construct the SSL prediction module using the parameters specified at initialization
159        """
160
161        class Classifier(nn.Module):
162            def __init__(self, num_segments, num_classes, skip_frames, **pars):
163                super().__init__()
164                self.len_segment = pars["len_segment"]
165                pars["len_segment"] -= skip_frames
166                self.extractor = _FeatureExtractorTCN(**pars)
167                self.num_segments = num_segments
168                self.skip_frames = skip_frames
169                self.fc = Linear(pars["output_dim"] * self.num_segments, num_classes)
170                self.orders = torch.tensor(
171                    [list(x) for x in permutations(range(num_segments), num_segments)]
172                )
173
174            def forward(self, x):
175                target = torch.randint(len(self.orders), (x.shape[0],)).to(x.device)
176                order = self.orders[target]
177                x = x[:, :, : self.num_segments * self.len_segment]
178                B, F, L = x.shape
179                x = x.reshape((B, F, -1, self.len_segment))
180                x = x[:, :, :, : -self.skip_frames]
181                x = x[
182                    torch.arange(x.shape[0]).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1),
183                    torch.arange(x.shape[1]).unsqueeze(-1).unsqueeze(-1).unsqueeze(0),
184                    order.unsqueeze(1).unsqueeze(-1),
185                    torch.arange(x.shape[-1]).unsqueeze(0).unsqueeze(0).unsqueeze(0),
186                ]
187                x = x.transpose(1, 2).reshape(
188                    (-1, F, self.len_segment - self.skip_frames)
189                )
190                x = self.extractor(x).reshape((B, -1))
191                x = self.fc(x)
192                return (x, target)
193
194        module = Classifier(
195            self.num_segments,
196            len(self.orders),
197            skip_frames=self.skip_frames,
198            **self.pars,
199        )
200        return module

Construct the SSL prediction module using the parameters specified at initialization