dlc2action.ssl.segment_order

Segment order SSL constructors.

  1#
  2# Copyright 2020-present by A. Mathis Group and contributors. All rights reserved.
  3#
  4# This project and all its files are licensed under GNU AGPLv3 or later version. 
  5# A copy is included in dlc2action/LICENSE.AGPL.
  6#
  7"""Segment order SSL constructors."""
  8
  9from typing import Dict, Tuple, Union, List
 10import torch
 11from dlc2action.ssl.base_ssl import SSLConstructor
 12from abc import ABC, abstractmethod
 13from dlc2action.ssl.modules import *
 14from copy import deepcopy
 15from itertools import permutations
 16
 17from torch.nn import CrossEntropyLoss, Linear, BCEWithLogitsLoss
 18
 19class ReverseSSL(SSLConstructor, ABC):
 20    """A flip detection SSL.
 21
 22    Reverse some of the segments and predict the flip with a binary classifier.
 23    """
 24
 25    type = "ssl_input"
 26
 27    def __init__(self, num_f_maps: torch.Size, len_segment: int) -> None:
 28        """Initialize the constructor.
 29
 30        Parameters
 31        ----------
 32        num_f_maps : torch.Size
 33            the number of input feature maps
 34        len_segment : int
 35            the length of the input segments
 36
 37        """
 38        super().__init__()
 39        self.ce = BCEWithLogitsLoss()
 40        if len(num_f_maps) > 1:
 41            raise RuntimeError(
 42                "The ContrastiveSSL constructor expects the input data to be 2-dimensional; "
 43                f"got {len(num_f_maps) + 1} dimensions"
 44            )
 45        num_f_maps = int(num_f_maps[0])
 46        self.pars = {
 47            "num_f_maps": num_f_maps,
 48            "len_segment": len_segment,
 49            "output_dim": 1,
 50            "kernel_1": 5,
 51            "kernel_2": 5,
 52            "stride": 2,
 53            "decrease_f_maps": True,
 54        }
 55
 56    def transformation(self, sample_data: Dict) -> Tuple:
 57        """Do the flip."""
 58        ssl_target = torch.randint(2, (1,), dtype=torch.float)
 59        ssl_input = deepcopy(sample_data)
 60        if ssl_target == 1:
 61            for key, value in sample_data.items():
 62                ssl_input[key] = value.flip(-1)
 63        return ssl_input, {"order": ssl_target}
 64
 65    def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
 66        """Cross-entropy loss."""
 67        loss = self.ce(predicted, target.squeeze())
 68        return loss
 69
 70    def construct_module(self) -> nn.Module:
 71        """Construct the SSL prediction module using the parameters specified at initialization."""
 72        module = FeatureExtractorTCN(**self.pars)
 73        return module
 74
 75
 76class OrderSSL(SSLConstructor, ABC):
 77    """An order prediction SSL.
 78
 79    Cut out segments from the features, permute them and predict the order.
 80    """
 81
 82    type = "ssl_target"
 83
 84    def __init__(
 85        self,
 86        num_f_maps: torch.Size,
 87        len_segment: int,
 88        num_segments: int = 3,
 89        ssl_features: int = 32,
 90        skip_frames: int = 10,
 91    ) -> None:
 92        """Initialize the OrderSSL.
 93
 94        Parameters
 95        ----------
 96        num_f_maps : torch.Size
 97            the number of the input feature maps
 98        len_segment : int
 99            the length of the input segments
100        num_segments : int, default 3
101            the number of segments to permute
102        ssl_features : int, default 32
103            the number of features per permuted segment
104        skip_frames : int, default 10
105            the number of frames to cut from each permuted segment
106
107        """
108        super().__init__()
109        self.ce = CrossEntropyLoss(ignore_index=-100)
110        if len(num_f_maps) > 1:
111            raise RuntimeError(
112                "The ContrastiveSSL constructor expects the input data to be 2-dimensional; "
113                f"got {len(num_f_maps) + 1} dimensions"
114            )
115        num_f_maps = int(num_f_maps[0])
116        self.orders = [list(x) for x in permutations(range(num_segments), num_segments)]
117        self.len_segment = len_segment // num_segments
118        self.num_segments = num_segments
119        self.skip_frames = skip_frames
120        self.pars = {
121            "num_f_maps": num_f_maps,
122            "len_segment": len_segment // num_segments,
123            "output_dim": ssl_features,
124            "kernel_1": 5,
125            "kernel_2": 5,
126            "stride": 2,
127            "decrease_f_maps": True,
128        }
129
130    def transformation(self, sample_data: Dict) -> Tuple:
131        """Empty transformation."""
132        return torch.tensor(float("nan")), torch.tensor(float("nan"))
133
134    def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
135        """Cross-entropy loss."""
136        predicted, target = predicted
137        loss = self.ce(predicted, target)
138        return loss
139
140    def construct_module(self) -> nn.Module:
141        """Construct the SSL prediction module using the parameters specified at initialization."""
142        class Classifier(nn.Module):
143            def __init__(self, num_segments, num_classes, skip_frames, **pars):
144                super().__init__()
145                self.len_segment = pars["len_segment"]
146                pars["len_segment"] -= skip_frames
147                self.extractor = FeatureExtractorTCN(**pars)
148                self.num_segments = num_segments
149                self.skip_frames = skip_frames
150                self.fc = Linear(pars["output_dim"] * self.num_segments, num_classes)
151                self.orders = torch.tensor(
152                    [list(x) for x in permutations(range(num_segments), num_segments)]
153                )
154
155            def forward(self, x):
156                target = torch.randint(len(self.orders), (x.shape[0],)).to(x.device)
157                order = self.orders.to(x.device)[target]
158                x = x[:, :, : self.num_segments * self.len_segment]
159                B, F, L = x.shape
160                x = x.reshape((B, F, -1, self.len_segment))
161                x = x[:, :, :, : -self.skip_frames]
162                x = x[
163                    torch.arange(x.shape[0]).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1),
164                    torch.arange(x.shape[1]).unsqueeze(-1).unsqueeze(-1).unsqueeze(0),
165                    order.unsqueeze(1).unsqueeze(-1),
166                    torch.arange(x.shape[-1]).unsqueeze(0).unsqueeze(0).unsqueeze(0),
167                ]
168                x = x.transpose(1, 2).reshape(
169                    (-1, F, self.len_segment - self.skip_frames)
170                )
171                x = self.extractor(x).reshape((B, -1))
172                x = self.fc(x)
173                return (x, target)
174
175        module = Classifier(
176            self.num_segments,
177            len(self.orders),
178            skip_frames=self.skip_frames,
179            **self.pars,
180        )
181        return module
class ReverseSSL(dlc2action.ssl.base_ssl.SSLConstructor, abc.ABC):
class ReverseSSL(SSLConstructor, ABC):
    """A flip detection SSL.

    Reverse some of the segments and predict the flip with a binary classifier.
    """

    type = "ssl_input"

    def __init__(self, num_f_maps: torch.Size, len_segment: int) -> None:
        """Initialize the constructor.

        Parameters
        ----------
        num_f_maps : torch.Size
            the number of input feature maps, given as a one-element size
        len_segment : int
            the length of the input segments

        Raises
        ------
        RuntimeError
            if ``num_f_maps`` has more than one element (the input data is
            not 2-dimensional)

        """
        super().__init__()
        self.ce = BCEWithLogitsLoss()
        if len(num_f_maps) > 1:
            raise RuntimeError(
                "The ReverseSSL constructor expects the input data to be 2-dimensional; "
                f"got {len(num_f_maps) + 1} dimensions"
            )
        num_f_maps = int(num_f_maps[0])
        # keyword arguments for the FeatureExtractorTCN built in construct_module;
        # its single output value serves as the flip logit
        self.pars = {
            "num_f_maps": num_f_maps,
            "len_segment": len_segment,
            "output_dim": 1,
            "kernel_1": 5,
            "kernel_2": 5,
            "stride": 2,
            "decrease_f_maps": True,
        }

    def transformation(self, sample_data: Dict) -> Tuple:
        """Do the flip.

        With probability 1/2 every tensor in ``sample_data`` is reversed
        along its last axis.

        Returns
        -------
        ssl_input : dict
            a deep copy of ``sample_data``, flipped when the target is 1
        ssl_target : dict
            ``{"order": target}`` where ``target`` is a float tensor of shape
            ``(1,)`` holding 0. (unchanged) or 1. (flipped)

        """
        ssl_target = torch.randint(2, (1,), dtype=torch.float)
        ssl_input = deepcopy(sample_data)
        if ssl_target == 1:
            for key, value in sample_data.items():
                ssl_input[key] = value.flip(-1)
        return ssl_input, {"order": ssl_target}

    def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Binary cross-entropy loss on logits (``BCEWithLogitsLoss``).

        ``target`` is squeezed first, so its singleton dimensions are dropped
        before it is compared with ``predicted``.
        """
        loss = self.ce(predicted, target.squeeze())
        return loss

    def construct_module(self) -> nn.Module:
        """Construct the SSL prediction module using the parameters specified at initialization."""
        module = FeatureExtractorTCN(**self.pars)
        return module

A flip detection SSL.

Reverse some of the segments and predict the flip with a binary classifier.

ReverseSSL(num_f_maps: torch.Size, len_segment: int)
def __init__(self, num_f_maps: torch.Size, len_segment: int) -> None:
    """Initialize the constructor.

    Parameters
    ----------
    num_f_maps : torch.Size
        the number of input feature maps, given as a one-element size
    len_segment : int
        the length of the input segments

    Raises
    ------
    RuntimeError
        if ``num_f_maps`` has more than one element (the input data is not
        2-dimensional)

    """
    super().__init__()
    self.ce = BCEWithLogitsLoss()
    if len(num_f_maps) > 1:
        raise RuntimeError(
            "The ReverseSSL constructor expects the input data to be 2-dimensional; "
            f"got {len(num_f_maps) + 1} dimensions"
        )
    num_f_maps = int(num_f_maps[0])
    # keyword arguments for the FeatureExtractorTCN built in construct_module;
    # its single output value serves as the flip logit
    self.pars = {
        "num_f_maps": num_f_maps,
        "len_segment": len_segment,
        "output_dim": 1,
        "kernel_1": 5,
        "kernel_2": 5,
        "stride": 2,
        "decrease_f_maps": True,
    }

Initialize the constructor.

Parameters

  • num_f_maps : torch.Size — the number of input feature maps
  • len_segment : int — the length of the input segments

type = 'ssl_input'

The type parameter defines interaction with the model:

  • 'ssl_input': a modification of the input data passes through the base network feature extraction module and the SSL module; it is returned as SSL output and compared to SSL target (or, if it is None, to the input data),
  • 'ssl_output': the input data passes through the base network feature extraction module and the SSL module; it is returned as SSL output and compared to SSL target (or, if it is None, to the input data),
  • 'contrastive': the input data and its modification pass through the base network feature extraction module and the SSL module; an (input results, modification results) tuple is returned as SSL output,
  • 'contrastive_2layers': the input data and its modification pass through the base network feature extraction module; the output of the second feature extraction layer for the modified data goes through an SSL module and then, optionally, that result and the first-level unmodified features pass another transformation; an (input results, modified results) tuple is returned as SSL output.
ce
pars
def transformation(self, sample_data: Dict) -> Tuple:
def transformation(self, sample_data: Dict) -> Tuple:
    """Randomly reverse the sample in time.

    A fair coin (a float tensor of shape ``(1,)``) decides whether every
    tensor of the deep-copied sample is flipped along its last axis; the
    coin value is returned as the ``"order"`` target.
    """
    flip_label = torch.randint(2, (1,), dtype=torch.float)
    transformed = deepcopy(sample_data)
    if flip_label == 1:
        transformed.update(
            (name, tensor.flip(-1)) for name, tensor in sample_data.items()
        )
    return transformed, {"order": flip_label}

Do the flip.

def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Binary cross-entropy loss on logits (``self.ce``, a ``BCEWithLogitsLoss``).

    ``target`` is squeezed first, so its singleton dimensions are dropped
    before it is compared with ``predicted``; the loss is returned as a
    tensor, not a Python float.
    """
    loss = self.ce(predicted, target.squeeze())
    return loss

Cross-entropy loss.

def construct_module(self) -> torch.nn.modules.module.Module:
def construct_module(self) -> nn.Module:
    """Build the flip classifier.

    Returns a ``FeatureExtractorTCN`` configured with the keyword
    arguments stored in ``self.pars`` at initialization.
    """
    return FeatureExtractorTCN(**self.pars)

Construct the SSL prediction module using the parameters specified at initialization.

class OrderSSL(dlc2action.ssl.base_ssl.SSLConstructor, abc.ABC):
 77class OrderSSL(SSLConstructor, ABC):
 78    """An order prediction SSL.
 79
 80    Cut out segments from the features, permute them and predict the order.
 81    """
 82
 83    type = "ssl_target"
 84
 85    def __init__(
 86        self,
 87        num_f_maps: torch.Size,
 88        len_segment: int,
 89        num_segments: int = 3,
 90        ssl_features: int = 32,
 91        skip_frames: int = 10,
 92    ) -> None:
 93        """Initialize the OrderSSL.
 94
 95        Parameters
 96        ----------
 97        num_f_maps : torch.Size
 98            the number of the input feature maps
 99        len_segment : int
100            the length of the input segments
101        num_segments : int, default 3
102            the number of segments to permute
103        ssl_features : int, default 32
104            the number of features per permuted segment
105        skip_frames : int, default 10
106            the number of frames to cut from each permuted segment
107
108        """
109        super().__init__()
110        self.ce = CrossEntropyLoss(ignore_index=-100)
111        if len(num_f_maps) > 1:
112            raise RuntimeError(
113                "The ContrastiveSSL constructor expects the input data to be 2-dimensional; "
114                f"got {len(num_f_maps) + 1} dimensions"
115            )
116        num_f_maps = int(num_f_maps[0])
117        self.orders = [list(x) for x in permutations(range(num_segments), num_segments)]
118        self.len_segment = len_segment // num_segments
119        self.num_segments = num_segments
120        self.skip_frames = skip_frames
121        self.pars = {
122            "num_f_maps": num_f_maps,
123            "len_segment": len_segment // num_segments,
124            "output_dim": ssl_features,
125            "kernel_1": 5,
126            "kernel_2": 5,
127            "stride": 2,
128            "decrease_f_maps": True,
129        }
130
131    def transformation(self, sample_data: Dict) -> Tuple:
132        """Empty transformation."""
133        return torch.tensor(float("nan")), torch.tensor(float("nan"))
134
135    def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
136        """Cross-entropy loss."""
137        predicted, target = predicted
138        loss = self.ce(predicted, target)
139        return loss
140
141    def construct_module(self) -> nn.Module:
142        """Construct the SSL prediction module using the parameters specified at initialization."""
143        class Classifier(nn.Module):
144            def __init__(self, num_segments, num_classes, skip_frames, **pars):
145                super().__init__()
146                self.len_segment = pars["len_segment"]
147                pars["len_segment"] -= skip_frames
148                self.extractor = FeatureExtractorTCN(**pars)
149                self.num_segments = num_segments
150                self.skip_frames = skip_frames
151                self.fc = Linear(pars["output_dim"] * self.num_segments, num_classes)
152                self.orders = torch.tensor(
153                    [list(x) for x in permutations(range(num_segments), num_segments)]
154                )
155
156            def forward(self, x):
157                target = torch.randint(len(self.orders), (x.shape[0],)).to(x.device)
158                order = self.orders.to(x.device)[target]
159                x = x[:, :, : self.num_segments * self.len_segment]
160                B, F, L = x.shape
161                x = x.reshape((B, F, -1, self.len_segment))
162                x = x[:, :, :, : -self.skip_frames]
163                x = x[
164                    torch.arange(x.shape[0]).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1),
165                    torch.arange(x.shape[1]).unsqueeze(-1).unsqueeze(-1).unsqueeze(0),
166                    order.unsqueeze(1).unsqueeze(-1),
167                    torch.arange(x.shape[-1]).unsqueeze(0).unsqueeze(0).unsqueeze(0),
168                ]
169                x = x.transpose(1, 2).reshape(
170                    (-1, F, self.len_segment - self.skip_frames)
171                )
172                x = self.extractor(x).reshape((B, -1))
173                x = self.fc(x)
174                return (x, target)
175
176        module = Classifier(
177            self.num_segments,
178            len(self.orders),
179            skip_frames=self.skip_frames,
180            **self.pars,
181        )
182        return module

An order prediction SSL.

Cut out segments from the features, permute them and predict the order.

OrderSSL( num_f_maps: torch.Size, len_segment: int, num_segments: int = 3, ssl_features: int = 32, skip_frames: int = 10)
 85    def __init__(
 86        self,
 87        num_f_maps: torch.Size,
 88        len_segment: int,
 89        num_segments: int = 3,
 90        ssl_features: int = 32,
 91        skip_frames: int = 10,
 92    ) -> None:
 93        """Initialize the OrderSSL.
 94
 95        Parameters
 96        ----------
 97        num_f_maps : torch.Size
 98            the number of the input feature maps
 99        len_segment : int
100            the length of the input segments
101        num_segments : int, default 3
102            the number of segments to permute
103        ssl_features : int, default 32
104            the number of features per permuted segment
105        skip_frames : int, default 10
106            the number of frames to cut from each permuted segment
107
108        """
109        super().__init__()
110        self.ce = CrossEntropyLoss(ignore_index=-100)
111        if len(num_f_maps) > 1:
112            raise RuntimeError(
113                "The ContrastiveSSL constructor expects the input data to be 2-dimensional; "
114                f"got {len(num_f_maps) + 1} dimensions"
115            )
116        num_f_maps = int(num_f_maps[0])
117        self.orders = [list(x) for x in permutations(range(num_segments), num_segments)]
118        self.len_segment = len_segment // num_segments
119        self.num_segments = num_segments
120        self.skip_frames = skip_frames
121        self.pars = {
122            "num_f_maps": num_f_maps,
123            "len_segment": len_segment // num_segments,
124            "output_dim": ssl_features,
125            "kernel_1": 5,
126            "kernel_2": 5,
127            "stride": 2,
128            "decrease_f_maps": True,
129        }

Initialize the OrderSSL.

Parameters

  • num_f_maps : torch.Size — the number of the input feature maps
  • len_segment : int — the length of the input segments
  • num_segments : int, default 3 — the number of segments to permute
  • ssl_features : int, default 32 — the number of features per permuted segment
  • skip_frames : int, default 10 — the number of frames to cut from the end of each permuted segment

type = 'ssl_target'

The type parameter defines interaction with the model:

  • 'ssl_input': a modification of the input data passes through the base network feature extraction module and the SSL module; it is returned as SSL output and compared to SSL target (or, if it is None, to the input data),
  • 'ssl_output': the input data passes through the base network feature extraction module and the SSL module; it is returned as SSL output and compared to SSL target (or, if it is None, to the input data),
  • 'contrastive': the input data and its modification pass through the base network feature extraction module and the SSL module; an (input results, modification results) tuple is returned as SSL output,
  • 'contrastive_2layers': the input data and its modification pass through the base network feature extraction module; the output of the second feature extraction layer for the modified data goes through an SSL module and then, optionally, that result and the first-level unmodified features pass another transformation; an (input results, modified results) tuple is returned as SSL output.
ce
orders
len_segment
num_segments
skip_frames
pars
def transformation(self, sample_data: Dict) -> Tuple:
def transformation(self, sample_data: Dict) -> Tuple:
    """Return placeholder SSL input and target.

    ``sample_data`` is ignored: the permutation and its target are produced
    inside the SSL module, so both returned values are NaN scalar tensors.
    """
    placeholder = float("nan")
    return torch.tensor(placeholder), torch.tensor(placeholder)

Empty transformation.

def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
def loss(
    self,
    predicted: Tuple[torch.Tensor, torch.Tensor],
    target: torch.Tensor,
) -> torch.Tensor:
    """Cross-entropy loss between the order logits and the applied permutation.

    Parameters
    ----------
    predicted : tuple
        the ``(logits, target)`` pair returned by the SSL module, where
        ``target`` holds the indices of the applied permutations
    target : torch.Tensor
        unused; this SSL's ``transformation`` only provides a NaN placeholder

    Returns
    -------
    torch.Tensor
        the value of ``self.ce`` on the unpacked logits and targets

    """
    logits, applied_order = predicted
    return self.ce(logits, applied_order)

Cross-entropy loss.

def construct_module(self) -> torch.nn.modules.module.Module:
def construct_module(self) -> nn.Module:
    """Construct the SSL prediction module using the parameters specified at initialization.

    On each forward pass the module draws one random permutation per sample,
    reorders the input segments accordingly and returns ``(logits, target)``,
    where ``target`` is the index of the drawn permutation.
    """

    class Classifier(nn.Module):
        """Permute input segments and classify which permutation was applied."""

        def __init__(self, num_segments, num_classes, skip_frames, **pars):
            super().__init__()
            # full segment length; the extractor only sees the first
            # ``len_segment - skip_frames`` frames of each segment
            self.len_segment = pars["len_segment"]
            pars["len_segment"] -= skip_frames
            self.extractor = FeatureExtractorTCN(**pars)
            self.num_segments = num_segments
            self.skip_frames = skip_frames
            # features of all segments are concatenated before classification
            self.fc = Linear(pars["output_dim"] * self.num_segments, num_classes)
            self.orders = torch.tensor(
                [list(x) for x in permutations(range(num_segments), num_segments)]
            )

        def forward(self, x):
            target = torch.randint(len(self.orders), (x.shape[0],)).to(x.device)
            order = self.orders.to(x.device)[target]
            x = x[:, :, : self.num_segments * self.len_segment]
            B, F, _ = x.shape
            x = x.reshape((B, F, -1, self.len_segment))
            # explicit end index: ``-skip_frames`` would select nothing
            # when skip_frames == 0
            kept = self.len_segment - self.skip_frames
            x = x[:, :, :, :kept]
            # x[b, f, s, t] <- x[b, f, order[b, s], t]
            x = x[
                torch.arange(x.shape[0]).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1),
                torch.arange(x.shape[1]).unsqueeze(-1).unsqueeze(-1).unsqueeze(0),
                order.unsqueeze(1).unsqueeze(-1),
                torch.arange(x.shape[-1]).unsqueeze(0).unsqueeze(0).unsqueeze(0),
            ]
            # (B, F, S, kept) -> (B, S, F, kept) -> (B * S, F, kept):
            # every segment is encoded separately
            x = x.transpose(1, 2).reshape((-1, F, kept))
            x = self.extractor(x).reshape((B, -1))
            x = self.fc(x)
            return (x, target)

    module = Classifier(
        self.num_segments,
        len(self.orders),
        skip_frames=self.skip_frames,
        **self.pars,
    )
    return module

Construct the SSL prediction module using the parameters specified at initialization.