dlc2action.ssl.segment_order
#
# Copyright 2020-2022 by A. Mathis Group and contributors. All rights reserved.
#
# This project and all its files are licensed under GNU AGPLv3 or later version. A copy is included in dlc2action/LICENSE.AGPL.
#
from typing import Dict, Tuple, Union, List
import torch
from dlc2action.ssl.base_ssl import SSLConstructor
from abc import ABC, abstractmethod
from dlc2action.ssl.modules import _FeatureExtractorTCN
import torch
from torch import nn
from copy import deepcopy
from itertools import permutations

from torch.nn import CrossEntropyLoss, Linear, BCEWithLogitsLoss


class ReverseSSL(SSLConstructor, ABC):
    """
    A flip detection SSL

    Reverse some of the segments and predict the flip with a binary classifier.
    """

    type = "ssl_input"

    def __init__(self, num_f_maps: torch.Size, len_segment: int) -> None:
        """
        Parameters
        ----------
        num_f_maps : torch.Size
            the number of input feature maps
        len_segment : int
            the length of the input segments

        Raises
        ------
        RuntimeError
            if `num_f_maps` has more than one element
        """

        super().__init__()
        self.ce = BCEWithLogitsLoss()
        if len(num_f_maps) > 1:
            raise RuntimeError(
                "The ReverseSSL constructor expects the input data to be 2-dimensional; "
                f"got {len(num_f_maps) + 1} dimensions"
            )
        num_f_maps = int(num_f_maps[0])
        # keyword arguments for _FeatureExtractorTCN; a single output (the flip logit)
        self.pars = {
            "num_f_maps": num_f_maps,
            "len_segment": len_segment,
            "output_dim": 1,
            "kernel_1": 5,
            "kernel_2": 5,
            "stride": 2,
            "decrease_f_maps": True,
        }

    def transformation(self, sample_data: Dict) -> Tuple:
        """
        Do the flip

        With a randomly drawn 0/1 target, every value of a deep copy of `sample_data`
        is reversed along its last axis when the target is 1 and left untouched otherwise.

        Returns
        -------
        ssl_input : dict
            the (possibly reversed) copy of `sample_data`
        ssl_target : dict
            a dictionary with the float target tensor of shape `(1,)` under the `"order"` key
        """

        ssl_target = torch.randint(2, (1,), dtype=torch.float)
        ssl_input = deepcopy(sample_data)
        if ssl_target == 1:
            for key, value in sample_data.items():
                ssl_input[key] = value.flip(-1)
        return ssl_input, {"order": ssl_target}

    def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Binary cross-entropy loss (computed on logits)

        The target is squeezed before it is compared to the prediction.
        """

        return self.ce(predicted, target.squeeze())

    def construct_module(self) -> nn.Module:
        """
        Construct the SSL prediction module using the parameters specified at initialization
        """

        return _FeatureExtractorTCN(**self.pars)


class OrderSSL(SSLConstructor, ABC):
    """
    An order prediction SSL

    Cut out segments from the features, permute them and predict the order.
    """

    type = "ssl_target"

    def __init__(
        self,
        num_f_maps: torch.Size,
        len_segment: int,
        num_segments: int = 3,
        ssl_features: int = 32,
        skip_frames: int = 10,
    ) -> None:
        """
        Parameters
        ----------
        num_f_maps : torch.Size
            the number of the input feature maps
        len_segment : int
            the length of the input segments
        num_segments : int, default 3
            the number of segments to permute
        ssl_features : int, default 32
            the number of features per permuted segment
        skip_frames : int, default 10
            the number of frames to cut from each permuted segment

        Raises
        ------
        RuntimeError
            if `num_f_maps` has more than one element
        """

        super().__init__()
        self.ce = CrossEntropyLoss(ignore_index=-100)
        if len(num_f_maps) > 1:
            raise RuntimeError(
                "The OrderSSL constructor expects the input data to be 2-dimensional; "
                f"got {len(num_f_maps) + 1} dimensions"
            )
        num_f_maps = int(num_f_maps[0])
        # every possible ordering of the segments is one class
        self.orders = [list(x) for x in permutations(range(num_segments), num_segments)]
        # length of one permuted segment (before skip_frames are cut)
        self.len_segment = len_segment // num_segments
        self.num_segments = num_segments
        self.skip_frames = skip_frames
        self.pars = {
            "num_f_maps": num_f_maps,
            "len_segment": self.len_segment,
            "output_dim": ssl_features,
            "kernel_1": 5,
            "kernel_2": 5,
            "stride": 2,
            "decrease_f_maps": True,
        }

    def transformation(self, sample_data: Dict) -> Tuple:
        """
        Empty transformation

        The input is ignored and a pair of scalar NaN tensors is returned.
        """

        return torch.tensor(float("nan")), torch.tensor(float("nan"))

    def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Cross-entropy loss

        `predicted` is unpacked into a `(logits, target)` pair (the output format of the
        module built by `construct_module`); the `target` argument itself is not used.
        """

        logits, order_target = predicted
        return self.ce(logits, order_target)

    def construct_module(self) -> nn.Module:
        """
        Construct the SSL prediction module using the parameters specified at initialization

        Returns
        -------
        module : nn.Module
            a module whose forward pass returns a `(logits, target)` tuple, where `target`
            holds the index of the randomly drawn permutation for every batch element

        Raises
        ------
        ValueError
            if `skip_frames` is negative or not smaller than the length of one permuted segment
        """

        class Classifier(nn.Module):
            """Shuffle the segments of the input and classify the applied permutation."""

            def __init__(self, num_segments, num_classes, skip_frames, **pars):
                super().__init__()
                self.len_segment = pars["len_segment"]
                if not 0 <= skip_frames < self.len_segment:
                    raise ValueError(
                        "skip_frames has to be non-negative and smaller than the length of "
                        f"one permuted segment ({self.len_segment}); got {skip_frames}"
                    )
                # the extractor only sees the frames that remain after trimming
                pars["len_segment"] -= skip_frames
                self.extractor = _FeatureExtractorTCN(**pars)
                self.num_segments = num_segments
                self.skip_frames = skip_frames
                self.fc = Linear(pars["output_dim"] * self.num_segments, num_classes)
                # plain (CPU) attribute: row i is the segment order of class i
                self.orders = torch.tensor(
                    [list(x) for x in permutations(range(num_segments), num_segments)]
                )

            def forward(self, x):
                x = x[:, :, : self.num_segments * self.len_segment]
                batch_size, num_features, _ = x.shape
                # sample and look up the permutation on the CPU, where self.orders lives,
                # and only then move the results to the device of the input
                target = torch.randint(len(self.orders), (batch_size,))
                order = self.orders[target].to(x.device)
                target = target.to(x.device)
                kept_frames = self.len_segment - self.skip_frames
                x = x.reshape((batch_size, num_features, -1, self.len_segment))
                # explicit end index: `: -skip_frames` would be empty for skip_frames == 0
                x = x[:, :, :, :kept_frames]
                batch_index = torch.arange(batch_size, device=x.device).unsqueeze(-1)
                # advanced indices separated by a slice put the broadcast dimensions first:
                # the result is (batch, segment, feature, frame) with segments permuted
                x = x[batch_index, :, order]
                x = x.reshape((-1, num_features, kept_frames))
                x = self.extractor(x).reshape((batch_size, -1))
                x = self.fc(x)
                return (x, target)

        module = Classifier(
            self.num_segments,
            len(self.orders),
            skip_frames=self.skip_frames,
            **self.pars,
        )
        return module
class ReverseSSL(SSLConstructor, ABC):
    """
    A flip detection SSL

    Reverse some of the segments and predict the flip with a binary classifier.
    """

    type = "ssl_input"

    def __init__(self, num_f_maps: torch.Size, len_segment: int) -> None:
        """
        Parameters
        ----------
        num_f_maps : torch.Size
            the number of input feature maps
        len_segment : int
            the length of the input segments

        Raises
        ------
        RuntimeError
            if `num_f_maps` has more than one element
        """

        super().__init__()
        self.ce = BCEWithLogitsLoss()
        if len(num_f_maps) > 1:
            raise RuntimeError(
                "The ReverseSSL constructor expects the input data to be 2-dimensional; "
                f"got {len(num_f_maps) + 1} dimensions"
            )
        num_f_maps = int(num_f_maps[0])
        # keyword arguments for _FeatureExtractorTCN; a single output (the flip logit)
        self.pars = {
            "num_f_maps": num_f_maps,
            "len_segment": len_segment,
            "output_dim": 1,
            "kernel_1": 5,
            "kernel_2": 5,
            "stride": 2,
            "decrease_f_maps": True,
        }

    def transformation(self, sample_data: Dict) -> Tuple:
        """
        Do the flip

        With a randomly drawn 0/1 target, every value of a deep copy of `sample_data`
        is reversed along its last axis when the target is 1 and left untouched otherwise.

        Returns
        -------
        ssl_input : dict
            the (possibly reversed) copy of `sample_data`
        ssl_target : dict
            a dictionary with the float target tensor of shape `(1,)` under the `"order"` key
        """

        ssl_target = torch.randint(2, (1,), dtype=torch.float)
        ssl_input = deepcopy(sample_data)
        if ssl_target == 1:
            for key, value in sample_data.items():
                ssl_input[key] = value.flip(-1)
        return ssl_input, {"order": ssl_target}

    def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Binary cross-entropy loss (computed on logits)

        The target is squeezed before it is compared to the prediction.
        """

        return self.ce(predicted, target.squeeze())

    def construct_module(self) -> nn.Module:
        """
        Construct the SSL prediction module using the parameters specified at initialization
        """

        return _FeatureExtractorTCN(**self.pars)
A flip detection SSL
Reverse some of the segments and predict the flip with a binary classifier.
def __init__(self, num_f_maps: torch.Size, len_segment: int) -> None:
    """
    Parameters
    ----------
    num_f_maps : torch.Size
        the number of input feature maps
    len_segment : int
        the length of the input segments

    Raises
    ------
    RuntimeError
        if `num_f_maps` has more than one element
    """

    super().__init__()
    self.ce = BCEWithLogitsLoss()
    if len(num_f_maps) > 1:
        # the message names this class (it used to mention ContrastiveSSL)
        raise RuntimeError(
            "The ReverseSSL constructor expects the input data to be 2-dimensional; "
            f"got {len(num_f_maps) + 1} dimensions"
        )
    num_f_maps = int(num_f_maps[0])
    # keyword arguments later unpacked into _FeatureExtractorTCN
    self.pars = {
        "num_f_maps": num_f_maps,
        "len_segment": len_segment,
        "output_dim": 1,
        "kernel_1": 5,
        "kernel_2": 5,
        "stride": 2,
        "decrease_f_maps": True,
    }
Parameters

- num_f_maps : torch.Size — the number of input feature maps
- len_segment : int — the length of the input segments
The `type` parameter defines interaction with the model:

- 'ssl_input': a modification of the input data passes through the base network feature extraction module and the SSL module; it is returned as SSL output and compared to SSL target (or, if it is None, to the input data),
- 'ssl_output': the input data passes through the base network feature extraction module and the SSL module; it is returned as SSL output and compared to SSL target (or, if it is None, to the input data),
- 'contrastive': the input data and its modification pass through the base network feature extraction module and the SSL module; an (input results, modification results) tuple is returned as SSL output,
- 'contrastive_2layers': the input data and its modification pass through the base network feature extraction module; the output of the second feature extraction layer for the modified data goes through an SSL module and then, optionally, that result and the first-level unmodified features pass another transformation; an (input results, modified results) tuple is returned as SSL output.
def transformation(self, sample_data: Dict) -> Tuple:
    """
    Do the flip

    A 0/1 target is drawn at random; when it is 1, every value of the copied
    sample is replaced by the original value reversed along its last axis.
    The copy and a `{"order": target}` dictionary are returned.
    """

    flip_flag = torch.randint(2, (1,), dtype=torch.float)
    transformed = deepcopy(sample_data)
    if flip_flag == 1:
        transformed.update(
            {name: tensor.flip(-1) for name, tensor in sample_data.items()}
        )
    return transformed, {"order": flip_flag}
Do the flip
def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
    """
    Cross-entropy loss

    Applies the criterion stored in `self.ce` to the prediction and the squeezed target.
    """

    squeezed_target = target.squeeze()
    return self.ce(predicted, squeezed_target)
Cross-entropy loss
def construct_module(self) -> nn.Module:
    """
    Construct the SSL prediction module using the parameters specified at initialization

    The stored `self.pars` dictionary is unpacked into `_FeatureExtractorTCN`.
    """

    return _FeatureExtractorTCN(**self.pars)
Construct the SSL prediction module using the parameters specified at initialization
class OrderSSL(SSLConstructor, ABC):
    """
    An order prediction SSL

    Cut out segments from the features, permute them and predict the order.
    """

    type = "ssl_target"

    def __init__(
        self,
        num_f_maps: torch.Size,
        len_segment: int,
        num_segments: int = 3,
        ssl_features: int = 32,
        skip_frames: int = 10,
    ) -> None:
        """
        Parameters
        ----------
        num_f_maps : torch.Size
            the number of the input feature maps
        len_segment : int
            the length of the input segments
        num_segments : int, default 3
            the number of segments to permute
        ssl_features : int, default 32
            the number of features per permuted segment
        skip_frames : int, default 10
            the number of frames to cut from each permuted segment

        Raises
        ------
        RuntimeError
            if `num_f_maps` has more than one element
        """

        super().__init__()
        self.ce = CrossEntropyLoss(ignore_index=-100)
        if len(num_f_maps) > 1:
            raise RuntimeError(
                "The OrderSSL constructor expects the input data to be 2-dimensional; "
                f"got {len(num_f_maps) + 1} dimensions"
            )
        num_f_maps = int(num_f_maps[0])
        # every possible ordering of the segments is one class
        self.orders = [list(x) for x in permutations(range(num_segments), num_segments)]
        # length of one permuted segment (before skip_frames are cut)
        self.len_segment = len_segment // num_segments
        self.num_segments = num_segments
        self.skip_frames = skip_frames
        self.pars = {
            "num_f_maps": num_f_maps,
            "len_segment": self.len_segment,
            "output_dim": ssl_features,
            "kernel_1": 5,
            "kernel_2": 5,
            "stride": 2,
            "decrease_f_maps": True,
        }

    def transformation(self, sample_data: Dict) -> Tuple:
        """
        Empty transformation

        The input is ignored and a pair of scalar NaN tensors is returned.
        """

        return torch.tensor(float("nan")), torch.tensor(float("nan"))

    def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Cross-entropy loss

        `predicted` is unpacked into a `(logits, target)` pair (the output format of the
        module built by `construct_module`); the `target` argument itself is not used.
        """

        logits, order_target = predicted
        return self.ce(logits, order_target)

    def construct_module(self) -> nn.Module:
        """
        Construct the SSL prediction module using the parameters specified at initialization

        Returns
        -------
        module : nn.Module
            a module whose forward pass returns a `(logits, target)` tuple, where `target`
            holds the index of the randomly drawn permutation for every batch element

        Raises
        ------
        ValueError
            if `skip_frames` is negative or not smaller than the length of one permuted segment
        """

        class Classifier(nn.Module):
            """Shuffle the segments of the input and classify the applied permutation."""

            def __init__(self, num_segments, num_classes, skip_frames, **pars):
                super().__init__()
                self.len_segment = pars["len_segment"]
                if not 0 <= skip_frames < self.len_segment:
                    raise ValueError(
                        "skip_frames has to be non-negative and smaller than the length of "
                        f"one permuted segment ({self.len_segment}); got {skip_frames}"
                    )
                # the extractor only sees the frames that remain after trimming
                pars["len_segment"] -= skip_frames
                self.extractor = _FeatureExtractorTCN(**pars)
                self.num_segments = num_segments
                self.skip_frames = skip_frames
                self.fc = Linear(pars["output_dim"] * self.num_segments, num_classes)
                # plain (CPU) attribute: row i is the segment order of class i
                self.orders = torch.tensor(
                    [list(x) for x in permutations(range(num_segments), num_segments)]
                )

            def forward(self, x):
                x = x[:, :, : self.num_segments * self.len_segment]
                batch_size, num_features, _ = x.shape
                # sample and look up the permutation on the CPU, where self.orders lives,
                # and only then move the results to the device of the input
                target = torch.randint(len(self.orders), (batch_size,))
                order = self.orders[target].to(x.device)
                target = target.to(x.device)
                kept_frames = self.len_segment - self.skip_frames
                x = x.reshape((batch_size, num_features, -1, self.len_segment))
                # explicit end index: `: -skip_frames` would be empty for skip_frames == 0
                x = x[:, :, :, :kept_frames]
                batch_index = torch.arange(batch_size, device=x.device).unsqueeze(-1)
                # advanced indices separated by a slice put the broadcast dimensions first:
                # the result is (batch, segment, feature, frame) with segments permuted
                x = x[batch_index, :, order]
                x = x.reshape((-1, num_features, kept_frames))
                x = self.extractor(x).reshape((batch_size, -1))
                x = self.fc(x)
                return (x, target)

        module = Classifier(
            self.num_segments,
            len(self.orders),
            skip_frames=self.skip_frames,
            **self.pars,
        )
        return module
An order prediction SSL
Cut out segments from the features, permute them and predict the order.
def __init__(
    self,
    num_f_maps: torch.Size,
    len_segment: int,
    num_segments: int = 3,
    ssl_features: int = 32,
    skip_frames: int = 10,
) -> None:
    """
    Parameters
    ----------
    num_f_maps : torch.Size
        the number of the input feature maps
    len_segment : int
        the length of the input segments
    num_segments : int, default 3
        the number of segments to permute
    ssl_features : int, default 32
        the number of features per permuted segment
    skip_frames : int, default 10
        the number of frames to cut from each permuted segment

    Raises
    ------
    RuntimeError
        if `num_f_maps` has more than one element
    """

    super().__init__()
    self.ce = CrossEntropyLoss(ignore_index=-100)
    if len(num_f_maps) > 1:
        # the message names this class (it used to mention ContrastiveSSL)
        raise RuntimeError(
            "The OrderSSL constructor expects the input data to be 2-dimensional; "
            f"got {len(num_f_maps) + 1} dimensions"
        )
    num_f_maps = int(num_f_maps[0])
    # every possible ordering of the segments is one class
    self.orders = [list(x) for x in permutations(range(num_segments), num_segments)]
    # length of one permuted segment (before skip_frames are cut)
    self.len_segment = len_segment // num_segments
    self.num_segments = num_segments
    self.skip_frames = skip_frames
    self.pars = {
        "num_f_maps": num_f_maps,
        "len_segment": self.len_segment,
        "output_dim": ssl_features,
        "kernel_1": 5,
        "kernel_2": 5,
        "stride": 2,
        "decrease_f_maps": True,
    }
Parameters

- num_f_maps : torch.Size — the number of the input feature maps
- len_segment : int — the length of the input segments
- num_segments : int, default 3 — the number of segments to permute
- ssl_features : int, default 32 — the number of features per permuted segment
- skip_frames : int, default 10 — the number of frames to cut from each permuted segment
The `type` parameter defines interaction with the model:

- 'ssl_input': a modification of the input data passes through the base network feature extraction module and the SSL module; it is returned as SSL output and compared to SSL target (or, if it is None, to the input data),
- 'ssl_output': the input data passes through the base network feature extraction module and the SSL module; it is returned as SSL output and compared to SSL target (or, if it is None, to the input data),
- 'contrastive': the input data and its modification pass through the base network feature extraction module and the SSL module; an (input results, modification results) tuple is returned as SSL output,
- 'contrastive_2layers': the input data and its modification pass through the base network feature extraction module; the output of the second feature extraction layer for the modified data goes through an SSL module and then, optionally, that result and the first-level unmodified features pass another transformation; an (input results, modified results) tuple is returned as SSL output.
def transformation(self, sample_data: Dict) -> Tuple:
    """
    Empty transformation

    `sample_data` is not used; two separate scalar NaN tensors are returned as a tuple.
    """

    nan_value = float("nan")
    return tuple(torch.tensor(nan_value) for _ in range(2))
Empty transformation
def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
    """
    Cross-entropy loss

    `predicted` is unpacked into a pair whose two elements are passed to `self.ce`;
    the `target` argument is not read.
    """

    logits, labels = predicted
    return self.ce(logits, labels)
Cross-entropy loss
def construct_module(self) -> nn.Module:
    """
    Construct the SSL prediction module using the parameters specified at initialization

    Returns
    -------
    module : nn.Module
        a module whose forward pass returns a `(logits, target)` tuple, where `target`
        holds the index of the randomly drawn permutation for every batch element

    Raises
    ------
    ValueError
        if `skip_frames` is negative or not smaller than the length of one permuted segment
    """

    class Classifier(nn.Module):
        """Shuffle the segments of the input and classify the applied permutation."""

        def __init__(self, num_segments, num_classes, skip_frames, **pars):
            super().__init__()
            self.len_segment = pars["len_segment"]
            if not 0 <= skip_frames < self.len_segment:
                raise ValueError(
                    "skip_frames has to be non-negative and smaller than the length of "
                    f"one permuted segment ({self.len_segment}); got {skip_frames}"
                )
            # the extractor only sees the frames that remain after trimming
            pars["len_segment"] -= skip_frames
            self.extractor = _FeatureExtractorTCN(**pars)
            self.num_segments = num_segments
            self.skip_frames = skip_frames
            self.fc = Linear(pars["output_dim"] * self.num_segments, num_classes)
            # plain (CPU) attribute: row i is the segment order of class i
            self.orders = torch.tensor(
                [list(x) for x in permutations(range(num_segments), num_segments)]
            )

        def forward(self, x):
            x = x[:, :, : self.num_segments * self.len_segment]
            batch_size, num_features, _ = x.shape
            # sample and look up the permutation on the CPU, where self.orders lives,
            # and only then move the results to the device of the input
            target = torch.randint(len(self.orders), (batch_size,))
            order = self.orders[target].to(x.device)
            target = target.to(x.device)
            kept_frames = self.len_segment - self.skip_frames
            x = x.reshape((batch_size, num_features, -1, self.len_segment))
            # explicit end index: `: -skip_frames` would be empty for skip_frames == 0
            x = x[:, :, :, :kept_frames]
            batch_index = torch.arange(batch_size, device=x.device).unsqueeze(-1)
            # advanced indices separated by a slice put the broadcast dimensions first:
            # the result is (batch, segment, feature, frame) with segments permuted
            x = x[batch_index, :, order]
            x = x.reshape((-1, num_features, kept_frames))
            x = self.extractor(x).reshape((batch_size, -1))
            x = self.fc(x)
            return (x, target)

    module = Classifier(
        self.num_segments,
        len(self.orders),
        skip_frames=self.skip_frames,
        **self.pars,
    )
    return module
Construct the SSL prediction module using the parameters specified at initialization