dlc2action.ssl.segment_order
Segment order SSL constructors.
#
# Copyright 2020-present by A. Mathis Group and contributors. All rights reserved.
#
# This project and all its files are licensed under GNU AGPLv3 or later version.
# A copy is included in dlc2action/LICENSE.AGPL.
#
"""Segment order SSL constructors."""

from typing import Dict, Tuple, Union, List
import torch
from dlc2action.ssl.base_ssl import SSLConstructor
from abc import ABC, abstractmethod
from dlc2action.ssl.modules import *
from copy import deepcopy
from itertools import permutations

from torch.nn import CrossEntropyLoss, Linear, BCEWithLogitsLoss


class ReverseSSL(SSLConstructor, ABC):
    """A flip detection SSL.

    Reverse some of the segments and predict the flip with a binary classifier.
    """

    type = "ssl_input"

    def __init__(self, num_f_maps: torch.Size, len_segment: int) -> None:
        """Initialize the constructor.

        Parameters
        ----------
        num_f_maps : torch.Size
            the number of input feature maps
        len_segment : int
            the length of the input segments

        Raises
        ------
        RuntimeError
            if `num_f_maps` has more than one entry

        """
        super().__init__()
        self.ce = BCEWithLogitsLoss()
        if len(num_f_maps) > 1:
            raise RuntimeError(
                "The ReverseSSL constructor expects the input data to be 2-dimensional; "
                f"got {len(num_f_maps) + 1} dimensions"
            )
        num_f_maps = int(num_f_maps[0])
        # keyword arguments forwarded to FeatureExtractorTCN in construct_module;
        # output_dim is 1 because the module predicts a single flip logit
        self.pars = {
            "num_f_maps": num_f_maps,
            "len_segment": len_segment,
            "output_dim": 1,
            "kernel_1": 5,
            "kernel_2": 5,
            "stride": 2,
            "decrease_f_maps": True,
        }

    def transformation(self, sample_data: Dict) -> Tuple:
        """Reverse every entry of the sample along its last axis with probability 1/2.

        Parameters
        ----------
        sample_data : dict
            the sample features; every value must support `.flip(-1)`

        Returns
        -------
        ssl_input : dict
            a copy of the sample, flipped along the last axis when the target is 1
        ssl_target : dict
            `{"order": target}`, where `target` is a float tensor of shape `(1,)`
            holding 1 if the sample was flipped and 0 otherwise

        """
        ssl_target = torch.randint(2, (1,), dtype=torch.float)
        ssl_input = deepcopy(sample_data)
        if ssl_target == 1:
            for key, value in sample_data.items():
                ssl_input[key] = value.flip(-1)
        return ssl_input, {"order": ssl_target}

    def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
        """Binary cross-entropy (with logits) between the prediction and the flip target.

        Singleton dimensions of `target` are squeezed away before the comparison.
        """
        loss = self.ce(predicted, target.squeeze())
        return loss

    def construct_module(self) -> nn.Module:
        """Construct the SSL prediction module using the parameters specified at initialization."""
        module = FeatureExtractorTCN(**self.pars)
        return module


class OrderSSL(SSLConstructor, ABC):
    """An order prediction SSL.

    Cut out segments from the features, permute them and predict the order.
    """

    type = "ssl_target"

    def __init__(
        self,
        num_f_maps: torch.Size,
        len_segment: int,
        num_segments: int = 3,
        ssl_features: int = 32,
        skip_frames: int = 10,
    ) -> None:
        """Initialize the OrderSSL.

        Parameters
        ----------
        num_f_maps : torch.Size
            the number of the input feature maps
        len_segment : int
            the length of the input segments
        num_segments : int, default 3
            the number of segments to permute
        ssl_features : int, default 32
            the number of features per permuted segment
        skip_frames : int, default 10
            the number of frames to cut from each permuted segment

        Raises
        ------
        RuntimeError
            if `num_f_maps` has more than one entry

        """
        super().__init__()
        self.ce = CrossEntropyLoss(ignore_index=-100)
        if len(num_f_maps) > 1:
            raise RuntimeError(
                "The OrderSSL constructor expects the input data to be 2-dimensional; "
                f"got {len(num_f_maps) + 1} dimensions"
            )
        num_f_maps = int(num_f_maps[0])
        # every possible ordering of the segments; one class per ordering
        self.orders = [list(x) for x in permutations(range(num_segments), num_segments)]
        # length of a single segment (the input is split into num_segments parts)
        self.len_segment = len_segment // num_segments
        self.num_segments = num_segments
        self.skip_frames = skip_frames
        self.pars = {
            "num_f_maps": num_f_maps,
            "len_segment": self.len_segment,
            "output_dim": ssl_features,
            "kernel_1": 5,
            "kernel_2": 5,
            "stride": 2,
            "decrease_f_maps": True,
        }

    def transformation(self, sample_data: Dict) -> Tuple:
        """Empty transformation.

        The sample is not used; a pair of scalar NaN tensors is returned as a
        placeholder because the permutation happens inside the prediction module.
        """
        return torch.tensor(float("nan")), torch.tensor(float("nan"))

    def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
        """Cross-entropy loss.

        `predicted` is the `(logits, target)` tuple returned by the prediction
        module; the `target` argument of this method is not used.
        """
        predicted, target = predicted
        loss = self.ce(predicted, target)
        return loss

    def construct_module(self) -> nn.Module:
        """Construct the SSL prediction module using the parameters specified at initialization.

        Returns
        -------
        nn.Module
            a module whose forward pass returns a `(logits, target)` tuple, where
            `target` holds the index of the permutation applied to each sample

        """

        class Classifier(nn.Module):
            """Shuffle temporal segments and classify which permutation was applied."""

            def __init__(self, num_segments, num_classes, skip_frames, **pars):
                super().__init__()
                # full length of one segment, before the trailing frames are dropped
                self.len_segment = pars["len_segment"]
                if not 0 <= skip_frames < self.len_segment:
                    raise ValueError(
                        f"skip_frames must be in [0, {self.len_segment}) so that every "
                        f"segment keeps at least one frame; got {skip_frames}"
                    )
                # `pars` is this call's own keyword dict, so shrinking the length
                # here does not modify the dictionary stored on the constructor
                pars["len_segment"] -= skip_frames
                self.extractor = FeatureExtractorTCN(**pars)
                self.num_segments = num_segments
                self.skip_frames = skip_frames
                self.fc = Linear(pars["output_dim"] * self.num_segments, num_classes)
                self.orders = torch.tensor(
                    [list(x) for x in permutations(range(num_segments), num_segments)]
                )

            def forward(self, x):
                # drawn on the CPU and then moved, so seeded runs keep their random stream
                target = torch.randint(len(self.orders), (x.shape[0],)).to(x.device)
                order = self.orders.to(x.device)[target]
                kept = self.len_segment - self.skip_frames
                x = x[:, :, : self.num_segments * self.len_segment]
                batch_size, num_features, _ = x.shape
                x = x.reshape((batch_size, num_features, -1, self.len_segment))
                # explicit end index: `: -skip_frames` selects nothing when skip_frames is 0
                x = x[:, :, :, :kept]
                # (batch, features, segments, frames) -> (batch, segments, features, frames),
                # then reorder the segment axis of every sample by its own permutation
                x = x.transpose(1, 2)
                batch_index = torch.arange(batch_size, device=x.device).unsqueeze(-1)
                x = x[batch_index, order]
                x = x.reshape((-1, num_features, kept))
                # presumably the extractor yields output_dim values per segment -- TODO confirm
                x = self.extractor(x).reshape((batch_size, -1))
                x = self.fc(x)
                return (x, target)

        module = Classifier(
            self.num_segments,
            len(self.orders),
            skip_frames=self.skip_frames,
            **self.pars,
        )
        return module
class ReverseSSL(SSLConstructor, ABC):
    """A flip detection SSL.

    Reverse some of the segments and predict the flip with a binary classifier.
    """

    type = "ssl_input"

    def __init__(self, num_f_maps: torch.Size, len_segment: int) -> None:
        """Initialize the constructor.

        Parameters
        ----------
        num_f_maps : torch.Size
            the number of input feature maps
        len_segment : int
            the length of the input segments

        Raises
        ------
        RuntimeError
            if `num_f_maps` has more than one entry

        """
        super().__init__()
        self.ce = BCEWithLogitsLoss()
        if len(num_f_maps) > 1:
            # the message used to name ContrastiveSSL (copy-paste slip)
            raise RuntimeError(
                "The ReverseSSL constructor expects the input data to be 2-dimensional; "
                f"got {len(num_f_maps) + 1} dimensions"
            )
        num_f_maps = int(num_f_maps[0])
        # keyword arguments forwarded to FeatureExtractorTCN in construct_module;
        # output_dim is 1 because the module predicts a single flip logit
        self.pars = {
            "num_f_maps": num_f_maps,
            "len_segment": len_segment,
            "output_dim": 1,
            "kernel_1": 5,
            "kernel_2": 5,
            "stride": 2,
            "decrease_f_maps": True,
        }

    def transformation(self, sample_data: Dict) -> Tuple:
        """Reverse every entry of the sample along its last axis with probability 1/2.

        Parameters
        ----------
        sample_data : dict
            the sample features; every value must support `.flip(-1)`

        Returns
        -------
        ssl_input : dict
            a copy of the sample, flipped along the last axis when the target is 1
        ssl_target : dict
            `{"order": target}`, where `target` is a float tensor of shape `(1,)`
            holding 1 if the sample was flipped and 0 otherwise

        """
        ssl_target = torch.randint(2, (1,), dtype=torch.float)
        ssl_input = deepcopy(sample_data)
        if ssl_target == 1:
            for key, value in sample_data.items():
                ssl_input[key] = value.flip(-1)
        return ssl_input, {"order": ssl_target}

    def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
        """Binary cross-entropy (with logits) between the prediction and the flip target.

        Singleton dimensions of `target` are squeezed away before the comparison.
        """
        loss = self.ce(predicted, target.squeeze())
        return loss

    def construct_module(self) -> nn.Module:
        """Construct the SSL prediction module using the parameters specified at initialization."""
        module = FeatureExtractorTCN(**self.pars)
        return module
A flip detection SSL.
Reverse some of the segments and predict the flip with a binary classifier.
def __init__(self, num_f_maps: torch.Size, len_segment: int) -> None:
    """Initialize the constructor.

    Parameters
    ----------
    num_f_maps : torch.Size
        the number of input feature maps
    len_segment : int
        the length of the input segments

    Raises
    ------
    RuntimeError
        if `num_f_maps` has more than one entry

    """
    super().__init__()
    self.ce = BCEWithLogitsLoss()
    if len(num_f_maps) > 1:
        # the message used to name ContrastiveSSL (copy-paste slip)
        raise RuntimeError(
            "The ReverseSSL constructor expects the input data to be 2-dimensional; "
            f"got {len(num_f_maps) + 1} dimensions"
        )
    num_f_maps = int(num_f_maps[0])
    # keyword arguments later unpacked into FeatureExtractorTCN;
    # output_dim is 1 because the module predicts a single flip logit
    self.pars = {
        "num_f_maps": num_f_maps,
        "len_segment": len_segment,
        "output_dim": 1,
        "kernel_1": 5,
        "kernel_2": 5,
        "stride": 2,
        "decrease_f_maps": True,
    }
Initialize the constructor.
Parameters
num_f_maps : torch.Size — the number of input feature maps
len_segment : int — the length of the input segments
The type parameter defines interaction with the model:
- 'ssl_input': a modification of the input data passes through the base network feature extraction module and the SSL module; it is returned as SSL output and compared to SSL target (or, if it is None, to the input data);
- 'ssl_output': the input data passes through the base network feature extraction module and the SSL module; it is returned as SSL output and compared to SSL target (or, if it is None, to the input data);
- 'contrastive': the input data and its modification pass through the base network feature extraction module and the SSL module; an (input results, modification results) tuple is returned as SSL output;
- 'contrastive_2layers': the input data and its modification pass through the base network feature extraction module; the output of the second feature extraction layer for the modified data goes through an SSL module and then, optionally, that result and the first-level unmodified features pass another transformation; an (input results, modified results) tuple is returned as SSL output.
def transformation(self, sample_data: Dict) -> Tuple:
    """Randomly reverse the sample along its last axis.

    A single 0/1 value is drawn; when it is 1, every entry of the sample is
    flipped along its last dimension.

    Returns
    -------
    ssl_input : dict
        a deep copy of `sample_data`, with all values flipped if the draw was 1
    ssl_target : dict
        `{"order": flag}`, where `flag` is the drawn float tensor of shape `(1,)`

    """
    flipped = torch.randint(2, (1,), dtype=torch.float)
    augmented = deepcopy(sample_data)
    if flipped == 1:
        augmented.update(
            {name: features.flip(-1) for name, features in sample_data.items()}
        )
    return augmented, {"order": flipped}
Do the flip.
def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
    """Evaluate the flip-detection criterion stored in `self.ce`.

    Singleton dimensions are squeezed out of `target` before it is compared
    with `predicted`.
    """
    flat_target = target.squeeze()
    return self.ce(predicted, flat_target)
Binary cross-entropy loss (computed from logits).
def construct_module(self) -> nn.Module:
    """Build the SSL prediction module from the parameters stored at initialization.

    The entries of `self.pars` are passed to `FeatureExtractorTCN` as keyword arguments.
    """
    return FeatureExtractorTCN(**self.pars)
Construct the SSL prediction module using the parameters specified at initialization.
class OrderSSL(SSLConstructor, ABC):
    """An order prediction SSL.

    Cut out segments from the features, permute them and predict the order.
    """

    type = "ssl_target"

    def __init__(
        self,
        num_f_maps: torch.Size,
        len_segment: int,
        num_segments: int = 3,
        ssl_features: int = 32,
        skip_frames: int = 10,
    ) -> None:
        """Initialize the OrderSSL.

        Parameters
        ----------
        num_f_maps : torch.Size
            the number of the input feature maps
        len_segment : int
            the length of the input segments
        num_segments : int, default 3
            the number of segments to permute
        ssl_features : int, default 32
            the number of features per permuted segment
        skip_frames : int, default 10
            the number of frames to cut from each permuted segment

        Raises
        ------
        RuntimeError
            if `num_f_maps` has more than one entry

        """
        super().__init__()
        self.ce = CrossEntropyLoss(ignore_index=-100)
        if len(num_f_maps) > 1:
            # the message used to name ContrastiveSSL (copy-paste slip)
            raise RuntimeError(
                "The OrderSSL constructor expects the input data to be 2-dimensional; "
                f"got {len(num_f_maps) + 1} dimensions"
            )
        num_f_maps = int(num_f_maps[0])
        # every possible ordering of the segments; one class per ordering
        self.orders = [list(x) for x in permutations(range(num_segments), num_segments)]
        # length of a single segment (the input is split into num_segments parts)
        self.len_segment = len_segment // num_segments
        self.num_segments = num_segments
        self.skip_frames = skip_frames
        self.pars = {
            "num_f_maps": num_f_maps,
            "len_segment": self.len_segment,
            "output_dim": ssl_features,
            "kernel_1": 5,
            "kernel_2": 5,
            "stride": 2,
            "decrease_f_maps": True,
        }

    def transformation(self, sample_data: Dict) -> Tuple:
        """Empty transformation.

        The sample is not used; a pair of scalar NaN tensors is returned as a
        placeholder because the permutation happens inside the prediction module.
        """
        return torch.tensor(float("nan")), torch.tensor(float("nan"))

    def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
        """Cross-entropy loss.

        `predicted` is the `(logits, target)` tuple returned by the prediction
        module; the `target` argument of this method is not used.
        """
        predicted, target = predicted
        loss = self.ce(predicted, target)
        return loss

    def construct_module(self) -> nn.Module:
        """Construct the SSL prediction module using the parameters specified at initialization.

        Returns
        -------
        nn.Module
            a module whose forward pass returns a `(logits, target)` tuple, where
            `target` holds the index of the permutation applied to each sample

        """

        class Classifier(nn.Module):
            """Shuffle temporal segments and classify which permutation was applied."""

            def __init__(self, num_segments, num_classes, skip_frames, **pars):
                super().__init__()
                # full length of one segment, before the trailing frames are dropped
                self.len_segment = pars["len_segment"]
                if not 0 <= skip_frames < self.len_segment:
                    raise ValueError(
                        f"skip_frames must be in [0, {self.len_segment}) so that every "
                        f"segment keeps at least one frame; got {skip_frames}"
                    )
                # `pars` is this call's own keyword dict, so shrinking the length
                # here does not modify the dictionary stored on the constructor
                pars["len_segment"] -= skip_frames
                self.extractor = FeatureExtractorTCN(**pars)
                self.num_segments = num_segments
                self.skip_frames = skip_frames
                self.fc = Linear(pars["output_dim"] * self.num_segments, num_classes)
                self.orders = torch.tensor(
                    [list(x) for x in permutations(range(num_segments), num_segments)]
                )

            def forward(self, x):
                # drawn on the CPU and then moved, so seeded runs keep their random stream
                target = torch.randint(len(self.orders), (x.shape[0],)).to(x.device)
                order = self.orders.to(x.device)[target]
                kept = self.len_segment - self.skip_frames
                x = x[:, :, : self.num_segments * self.len_segment]
                batch_size, num_features, _ = x.shape
                x = x.reshape((batch_size, num_features, -1, self.len_segment))
                # explicit end index: `: -skip_frames` selects nothing when skip_frames is 0
                x = x[:, :, :, :kept]
                # (batch, features, segments, frames) -> (batch, segments, features, frames),
                # then reorder the segment axis of every sample by its own permutation
                x = x.transpose(1, 2)
                batch_index = torch.arange(batch_size, device=x.device).unsqueeze(-1)
                x = x[batch_index, order]
                x = x.reshape((-1, num_features, kept))
                # presumably the extractor yields output_dim values per segment -- TODO confirm
                x = self.extractor(x).reshape((batch_size, -1))
                x = self.fc(x)
                return (x, target)

        module = Classifier(
            self.num_segments,
            len(self.orders),
            skip_frames=self.skip_frames,
            **self.pars,
        )
        return module
An order prediction SSL.
Cut out segments from the features, permute them and predict the order.
def __init__(
    self,
    num_f_maps: torch.Size,
    len_segment: int,
    num_segments: int = 3,
    ssl_features: int = 32,
    skip_frames: int = 10,
) -> None:
    """Initialize the OrderSSL.

    Parameters
    ----------
    num_f_maps : torch.Size
        the number of the input feature maps
    len_segment : int
        the length of the input segments
    num_segments : int, default 3
        the number of segments to permute
    ssl_features : int, default 32
        the number of features per permuted segment
    skip_frames : int, default 10
        the number of frames to cut from each permuted segment

    Raises
    ------
    RuntimeError
        if `num_f_maps` has more than one entry

    """
    super().__init__()
    self.ce = CrossEntropyLoss(ignore_index=-100)
    if len(num_f_maps) > 1:
        # the message used to name ContrastiveSSL (copy-paste slip)
        raise RuntimeError(
            "The OrderSSL constructor expects the input data to be 2-dimensional; "
            f"got {len(num_f_maps) + 1} dimensions"
        )
    num_f_maps = int(num_f_maps[0])
    # every possible ordering of the segments; one class per ordering
    self.orders = [list(x) for x in permutations(range(num_segments), num_segments)]
    # length of a single segment (the input is split into num_segments parts)
    self.len_segment = len_segment // num_segments
    self.num_segments = num_segments
    self.skip_frames = skip_frames
    self.pars = {
        "num_f_maps": num_f_maps,
        "len_segment": self.len_segment,
        "output_dim": ssl_features,
        "kernel_1": 5,
        "kernel_2": 5,
        "stride": 2,
        "decrease_f_maps": True,
    }
Initialize the OrderSSL.
Parameters
num_f_maps : torch.Size — the number of the input feature maps
len_segment : int — the length of the input segments
num_segments : int, default 3 — the number of segments to permute
ssl_features : int, default 32 — the number of features per permuted segment
skip_frames : int, default 10 — the number of frames to cut from each permuted segment
The type parameter defines interaction with the model:
- 'ssl_input': a modification of the input data passes through the base network feature extraction module and the SSL module; it is returned as SSL output and compared to SSL target (or, if it is None, to the input data);
- 'ssl_output': the input data passes through the base network feature extraction module and the SSL module; it is returned as SSL output and compared to SSL target (or, if it is None, to the input data);
- 'contrastive': the input data and its modification pass through the base network feature extraction module and the SSL module; an (input results, modification results) tuple is returned as SSL output;
- 'contrastive_2layers': the input data and its modification pass through the base network feature extraction module; the output of the second feature extraction layer for the modified data goes through an SSL module and then, optionally, that result and the first-level unmodified features pass another transformation; an (input results, modified results) tuple is returned as SSL output.
def transformation(self, sample_data: Dict) -> Tuple:
    """Return placeholder SSL input and target.

    `sample_data` is not used; both elements of the returned pair are scalar
    NaN tensors.
    """
    nan = float("nan")
    placeholder_input = torch.tensor(nan)
    placeholder_target = torch.tensor(nan)
    return placeholder_input, placeholder_target
Empty transformation.
def loss(self, predicted: torch.Tensor, target: torch.Tensor) -> float:
    """Evaluate `self.ce` on the pair packed inside `predicted`.

    `predicted` is unpacked into two items, which are passed to the criterion
    in that order; the incoming `target` argument is discarded.
    """
    logits, labels = predicted
    return self.ce(logits, labels)
Cross-entropy loss.
def construct_module(self) -> nn.Module:
    """Construct the SSL prediction module using the parameters specified at initialization.

    Returns
    -------
    nn.Module
        a module whose forward pass returns a `(logits, target)` tuple, where
        `target` holds the index of the permutation applied to each sample

    """

    class Classifier(nn.Module):
        """Shuffle temporal segments and classify which permutation was applied."""

        def __init__(self, num_segments, num_classes, skip_frames, **pars):
            super().__init__()
            # full length of one segment, before the trailing frames are dropped
            self.len_segment = pars["len_segment"]
            if not 0 <= skip_frames < self.len_segment:
                raise ValueError(
                    f"skip_frames must be in [0, {self.len_segment}) so that every "
                    f"segment keeps at least one frame; got {skip_frames}"
                )
            # `pars` is this call's own keyword dict, so shrinking the length
            # here does not modify the dictionary stored on the constructor
            pars["len_segment"] -= skip_frames
            self.extractor = FeatureExtractorTCN(**pars)
            self.num_segments = num_segments
            self.skip_frames = skip_frames
            self.fc = Linear(pars["output_dim"] * self.num_segments, num_classes)
            self.orders = torch.tensor(
                [list(x) for x in permutations(range(num_segments), num_segments)]
            )

        def forward(self, x):
            # drawn on the CPU and then moved, so seeded runs keep their random stream
            target = torch.randint(len(self.orders), (x.shape[0],)).to(x.device)
            order = self.orders.to(x.device)[target]
            kept = self.len_segment - self.skip_frames
            x = x[:, :, : self.num_segments * self.len_segment]
            batch_size, num_features, _ = x.shape
            x = x.reshape((batch_size, num_features, -1, self.len_segment))
            # explicit end index: `: -skip_frames` selects nothing when skip_frames is 0
            x = x[:, :, :, :kept]
            # (batch, features, segments, frames) -> (batch, segments, features, frames),
            # then reorder the segment axis of every sample by its own permutation
            x = x.transpose(1, 2)
            batch_index = torch.arange(batch_size, device=x.device).unsqueeze(-1)
            x = x[batch_index, order]
            x = x.reshape((-1, num_features, kept))
            # presumably the extractor yields output_dim values per segment -- TODO confirm
            x = self.extractor(x).reshape((batch_size, -1))
            x = self.fc(x)
            return (x, target)

    module = Classifier(
        self.num_segments,
        len(self.orders),
        skip_frames=self.skip_frames,
        **self.pars,
    )
    return module
Construct the SSL prediction module using the parameters specified at initialization.