dlc2action.utils
Utility functions
- `TensorDict` is a convenient data structure for keeping (and indexing) lists of feature dictionaries,
- `apply_threshold`, `apply_threshold_hysteresis` and `apply_threshold_max` are utility functions for `dlc2action.data.dataset.BehaviorDataset.find_valleys`,
- `strip_suffix` is used to get rid of suffixes if a string (usually a filename) ends with one of them,
- `strip_prefix` is used to get rid of prefixes if a string (usually a filename) starts with one of them,
- `rotation_matrix_2d` and `rotation_matrix_3d` are used to generate rotation matrices by `dlc2action.transformer.base_transformer.Transformer` instances.
#
# Copyright 2020-2022 by A. Mathis Group and contributors. All rights reserved.
#
# This project and all its files are licensed under GNU AGPLv3 or later version. A copy is included in dlc2action/LICENSE.AGPL.
#
"""
## Utility functions

- `TensorDict` is a convenient data structure for keeping (and indexing) lists of feature dictionaries,
- `apply_threshold`, `apply_threshold_hysteresis` and `apply_threshold_max` are utility functions for
`dlc2action.data.dataset.BehaviorDataset.find_valleys`,
- `strip_suffix` is used to get rid of suffixes if a string (usually filename) ends with one of them,
- `strip_prefix` is used to get rid of prefixes if a string (usually filename) starts with one of them,
- `rotation_matrix_2d` and `rotation_matrix_3d` are used to generate rotation matrices by
`dlc2action.transformer.base_transformer.Transformer` instances
"""

import torch
from typing import List, Dict, Union
from collections.abc import Iterable
import warnings
import os
import numpy as np
from torch import nn
from torch.nn import functional as F
import math


class TensorDict:
    """
    A class that handles indexing in a dictionary of tensors of the same length
    """

    def __init__(self, obj: Union[Dict, Iterable] = None) -> None:
        """
        Parameters
        ----------
        obj : dict | list | tuple, optional
            either a dictionary of torch.Tensor instances of the same length or a list/tuple of dictionaries with
            the same keys (if not passed or empty, a blank TensorDict is initialized)

        Raises
        ------
        TypeError
            if `obj` is neither a dictionary nor a list/tuple, or if a dictionary value is not a tensor
        ValueError
            if the tensors in a dictionary have different lengths
        """

        if obj is None:
            obj = {}
        if isinstance(obj, (list, tuple)):
            # an empty sequence yields a blank TensorDict instead of an IndexError
            self.keys = list(obj[0].keys()) if len(obj) > 0 else []
            columns = {key: [] for key in self.keys}
            with warnings.catch_warnings():
                # torch.tensor(existing_tensor) emits a UserWarning about copy construction
                warnings.filterwarnings("ignore", category=UserWarning)
                for element in obj:
                    for key in self.keys:
                        columns[key].append(torch.tensor(element[key]))
            self.dict = {key: torch.stack(values) for key, values in columns.items()}
        elif isinstance(obj, dict):
            if not all([isinstance(obj[key], torch.Tensor) for key in obj]):
                raise TypeError(
                    f"The values in the dictionary passed to TensorDict need to be torch.Tensor instances;"
                    f"got {[type(obj[key]) for key in obj]}"
                )
            lengths = [len(obj[key]) for key in obj]
            if not all([x == lengths[0] for x in lengths]):
                raise ValueError(
                    f"The value tensors in the dictionary passed to TensorDict need to have the same length;"
                    f"got {lengths}"
                )
            # NOTE: the dictionary is stored by reference, so `append` / `remove` also modify the caller's dict
            self.dict = obj
            self.keys = list(self.dict.keys())
        else:
            raise TypeError(
                f"TensorDict can only be constructed from an iterable of dictionaries of from a dictionary "
                f"of tensors (got {type(obj)})"
            )
        if len(self.keys) > 0:
            self.type = type(self.dict[self.keys[0]])
        else:
            self.type = None

    def __len__(self) -> int:
        """
        Get the number of stored elements (0 for a blank TensorDict)
        """

        if len(self.keys) == 0:
            return 0
        return len(self.dict[self.keys[0]])

    @staticmethod
    def _is_scalar_index(ind) -> bool:
        """
        Check whether an index selects a single element (python/numpy integer or 0-dim integer tensor)
        """

        if isinstance(ind, torch.Tensor):
            return ind.dim() == 0 and ind.dtype != torch.bool
        return hasattr(ind, "__index__") and not isinstance(ind, bool)

    def __getitem__(self, ind: Union[int, List]):
        """
        Index the TensorDict

        Parameters
        ----------
        ind : int | list
            the index/indices

        Returns
        -------
        dict | TensorDict
            a dictionary of the indexed elements (if `ind` is a single integer index) or a TensorDict
        """

        x = {key: self.dict[key][ind] for key in self.dict}
        if self._is_scalar_index(ind):
            return x
        return TensorDict(x)

    def append(self, element: Dict) -> None:
        """
        Append an element

        Parameters
        ----------
        element : dict
            a dictionary of tensors (one value per key of the TensorDict)

        Raises
        ------
        ValueError
            if `element` is empty or does not contain all keys of the TensorDict
        """

        if len(element) == 0:
            raise ValueError("Cannot append an empty dictionary to TensorDict")
        if self.type is None:
            self.dict = {k: v.unsqueeze(0) for k, v in element.items()}
            self.keys = list(element.keys())
            self.type = type(next(iter(element.values())))
            return
        # validate everything first so that a failed append leaves the TensorDict untouched
        for key in self.keys:
            if key not in element:
                raise ValueError(
                    f"The dictionary appended to TensorDict needs to have the same keys as the "
                    f"TensorDict; got {element.keys()} and {self.keys}"
                )
        for key in self.keys:
            self.dict[key] = torch.cat([self.dict[key], element[key].unsqueeze(0)])

    def remove(self, indices: List) -> None:
        """
        Remove indexed elements

        Parameters
        ----------
        indices : list
            the indices to remove
        """

        mask = torch.ones(len(self), dtype=torch.bool)
        index = torch.as_tensor(indices)
        if index.dtype != torch.bool:
            # an empty list becomes a float tensor by default, which cannot be used as an index
            index = index.long()
        mask[index] = False
        for key in list(self.dict.keys()):
            self.dict[key] = self.dict[key][mask]


def _mask_to_intervals(
    p: torch.Tensor,
    error_mask: torch.Tensor = None,
    min_frames: int = 0,
    masked_intervals: List = None,
):
    """
    Turn a 1D boolean tensor into lists of start / (exclusive) end indices of its `True` runs

    Runs are computed after `error_mask` and `masked_intervals` are applied; only runs that are strictly
    longer than `min_frames` are kept.
    """

    if error_mask is not None:
        p = p & torch.as_tensor(error_mask, dtype=torch.bool, device=p.device)
    for start, end in masked_intervals or []:
        p[start:end] = False
    values, counts = torch.unique_consecutive(p, return_counts=True)
    # run i covers [ends[i] - counts[i], ends[i]), so no per-run search over the tensor is needed
    ends = torch.cumsum(counts, dim=0)
    starts = ends - counts
    keep = values.bool() & (counts > min_frames)
    return starts[keep].tolist(), ends[keep].tolist()


def apply_threshold(
    tensor: torch.Tensor,
    threshold: float,
    low: bool = True,
    error_mask: torch.Tensor = None,
    min_frames: int = 0,
    smooth_interval: int = 0,
    masked_intervals: List = None,
):
    """
    Apply a hard threshold to a tensor and return indices of the intervals that passed

    If `error_mask` is not `None`, the elements marked `False` are treated as if they did not pass the threshold.
    If `min_frames` is not 0, the intervals are additionally filtered by length.

    Parameters
    ----------
    tensor : torch.Tensor
        the (1D) tensor to apply the threshold to
    threshold : float
        the threshold
    low : bool, default True
        if `True`, values `<= threshold` pass; otherwise values `>= threshold` pass
    error_mask : torch.Tensor, optional
        a boolean mask to apply to the results
    min_frames : int, default 0
        only intervals strictly longer than this number of frames are kept
    smooth_interval : int, default 0
        the interval length passed to `smooth` (0 disables smoothing)
    masked_intervals : list, optional
        a list of `(start, end)` intervals that are treated as if they did not pass the threshold

    Returns
    -------
    indices_start : list
        a list of indices of the first frames of the chosen intervals
    indices_end : list
        a list of (exclusive) end indices of the chosen intervals
    """

    if low:
        p = tensor <= threshold
    else:
        p = tensor >= threshold
    p = smooth(p, smooth_interval)
    return _mask_to_intervals(p, error_mask, min_frames, masked_intervals)


def apply_threshold_hysteresis(
    tensor: torch.Tensor,
    soft_threshold: float,
    hard_threshold: float,
    low: bool = True,
    error_mask: torch.Tensor = None,
    min_frames: int = 0,
    smooth_interval: int = 0,
    masked_intervals: List = None,
):
    """
    Apply a hysteresis threshold to a tensor and return indices of the intervals that passed

    In the chosen intervals all values pass the soft threshold and at least one value passes the hard threshold.
    If `error_mask` is not `None`, the elements marked `False` are treated as if they did not pass the threshold.
    If `min_frames` is not 0, the intervals are additionally filtered by length.

    Parameters
    ----------
    tensor : torch.Tensor
        the (1D) tensor to apply the threshold to
    soft_threshold : float
        the soft threshold
    hard_threshold : float
        the hard threshold
    low : bool, default True
        if `True`, values `<=` the thresholds pass; otherwise values `>=` the thresholds pass
    error_mask : torch.Tensor, optional
        a boolean mask to apply to the results
    min_frames : int, default 0
        only intervals strictly longer than this number of frames are kept
    smooth_interval : int, default 0
        the interval length passed to `smooth` (0 disables smoothing)
    masked_intervals : list, optional
        a list of `(start, end)` intervals that are treated as if they did not pass the threshold

    Returns
    -------
    indices_start : list
        a list of indices of the first frames of the chosen intervals
    indices_end : list
        a list of (exclusive) end indices of the chosen intervals
    """

    if low:
        p = tensor <= soft_threshold
        hard = tensor <= hard_threshold
    else:
        p = tensor >= soft_threshold
        hard = tensor >= hard_threshold
    p = smooth(p, smooth_interval)
    indices_start, indices_end = _mask_to_intervals(
        p, error_mask, min_frames, masked_intervals
    )
    res_indices_start = []
    res_indices_end = []
    for start, end in zip(indices_start, indices_end):
        if bool(hard[start:end].any()):
            res_indices_start.append(start)
            res_indices_end.append(end)
    return res_indices_start, res_indices_end


def apply_threshold_max(
    tensor: torch.Tensor,
    threshold: float,
    main_class: int,
    error_mask: torch.Tensor = None,
    min_frames: int = 0,
    smooth_interval: int = 0,
    masked_intervals: List = None,
):
    """
    Apply a max hysteresis threshold to a tensor and return indices of the intervals that passed

    In the chosen intervals the values at the `main_class` index are larger than the others everywhere
    and at least one value at the `main_class` index passes the threshold.
    If `error_mask` is not `None`, the elements marked `False` are treated as if they did not pass the threshold.
    If `min_frames` is not 0, the intervals are additionally filtered by length.

    Parameters
    ----------
    tensor : torch.Tensor
        the tensor to apply the threshold to (of shape `(#classes, #frames)`)
    threshold : float
        the threshold (if `None`, the hard threshold check is skipped)
    main_class : int
        the class that conditions the soft threshold
    error_mask : torch.Tensor, optional
        a boolean mask to apply to the results
    min_frames : int, default 0
        only intervals strictly longer than this number of frames are kept
    smooth_interval : int, default 0
        the interval length passed to `smooth` (0 disables smoothing)
    masked_intervals : list, optional
        a list of `(start, end)` intervals that are treated as if they did not pass the threshold

    Returns
    -------
    indices_start : list
        a list of indices of the first frames of the chosen intervals (along dimension 1 of input tensor)
    indices_end : list
        a list of (exclusive) end indices of the chosen intervals (along dimension 1 of input tensor)
    """

    _, best_class = torch.max(tensor, dim=0)
    p = best_class == main_class
    p = smooth(p, smooth_interval)
    indices_start, indices_end = _mask_to_intervals(
        p, error_mask, min_frames, masked_intervals
    )
    if threshold is None:
        return indices_start, indices_end
    hard = tensor[main_class, :] > threshold
    res_indices_start = []
    res_indices_end = []
    for start, end in zip(indices_start, indices_end):
        if bool(hard[start:end].any()):
            res_indices_start.append(start)
            res_indices_end.append(end)
    return res_indices_start, res_indices_end


def strip_suffix(text: str, suffix: Iterable):
    """
    Strip a suffix from a string if it is contained in a list

    Only the first matching suffix is stripped.

    Parameters
    ----------
    text : str
        the main string
    suffix : iterable
        the list of suffixes to be stripped

    Returns
    -------
    result : str
        the stripped string
    """

    if suffix is None:
        suffix = []
    for s in suffix:
        # an empty suffix would turn text[:-0] into an empty string
        if len(s) > 0 and text.endswith(s):
            return text[: -len(s)]
    return text


def strip_prefix(text: str, prefix: Iterable):
    """
    Strip a prefix from a string if it is contained in a list

    Only the first matching prefix is stripped.

    Parameters
    ----------
    text : str
        the main string
    prefix : iterable
        the list of prefixes to be stripped

    Returns
    -------
    result : str
        the stripped string
    """

    if prefix is None:
        prefix = []
    for s in prefix:
        if text.startswith(s):
            return text[len(s) :]
    return text


def rotation_matrix_2d(angles: torch.Tensor) -> torch.Tensor:
    """
    Create a tensor of 2D rotation matrices from a tensor of angles

    Parameters
    ----------
    angles : torch.Tensor
        a tensor of angles of arbitrary shape `(...)`

    Returns
    -------
    rotation_matrices : torch.Tensor
        a tensor of 2D rotation matrices of shape `(..., 2, 2)`
    """

    cos = torch.cos(angles)
    sin = torch.sin(angles)
    R = torch.stack([cos, -sin, sin, cos], dim=-1).reshape(*angles.shape, 2, 2)
    return R


def rotation_matrix_3d(alpha: torch.Tensor, beta: torch.Tensor, gamma: torch.Tensor):
    """
    Create a tensor of 3D rotation matrices from a tensor of angles

    The result is the product `Rx @ Ry @ Rz`.

    Parameters
    ----------
    alpha : torch.Tensor
        a tensor of rotation angles around the x axis of arbitrary shape `(...)`
    beta : torch.Tensor
        a tensor of rotation angles around the y axis of arbitrary shape `(...)`
    gamma : torch.Tensor
        a tensor of rotation angles around the z axis of arbitrary shape `(...)`

    Returns
    -------
    rotation_matrices : torch.Tensor
        a tensor of 3D rotation matrices of shape `(..., 3, 3)`
    """

    # zeros_like / ones_like keep the device and dtype of the input angles
    cos = torch.cos(alpha)
    sin = torch.sin(alpha)
    zero = torch.zeros_like(cos)
    one = torch.ones_like(cos)
    Rx = torch.stack(
        [one, zero, zero, zero, cos, -sin, zero, sin, cos],
        dim=-1,
    ).reshape(*alpha.shape, 3, 3)
    cos = torch.cos(beta)
    sin = torch.sin(beta)
    zero = torch.zeros_like(cos)
    one = torch.ones_like(cos)
    Ry = torch.stack(
        [cos, zero, sin, zero, one, zero, -sin, zero, cos],
        dim=-1,
    ).reshape(*beta.shape, 3, 3)
    cos = torch.cos(gamma)
    sin = torch.sin(gamma)
    zero = torch.zeros_like(cos)
    one = torch.ones_like(cos)
    Rz = torch.stack(
        [cos, -sin, zero, sin, cos, zero, zero, zero, one],
        dim=-1,
    ).reshape(*gamma.shape, 3, 3)
    R = torch.einsum("...ij,...jk,...kl->...il", Rx, Ry, Rz)
    return R


def correct_path(path, project_path):
    """
    Re-root a path that contains a `results` (or `saved_datasets`) folder at `project_path`

    Non-string inputs are returned unchanged. The components that follow the first `results` component
    (or, if there is none, the first `saved_datasets` component) are appended to `<project_path>/results`.
    A `ValueError` is raised if the path contains neither component.
    """

    if not isinstance(path, str):
        return path
    path = os.path.normpath(path).split(os.path.sep)
    if "results" in path:
        name = "results"
    else:
        name = "saved_datasets"
    ind = path.index(name) + 1
    # NOTE(review): "saved_datasets" paths are also re-rooted under "results" -- confirm this is intended
    return os.path.join(project_path, "results", *path[ind:])


class TensorList(list):
    """
    A list of tensors that can send each element to a `torch` device
    """

    def to_device(self, device: torch.device):
        """
        Move every element to `device` in place (nested `TensorList` instances are handled recursively)
        """

        for i, x in enumerate(self):
            if isinstance(x, TensorList):
                x.to_device(device)
            else:
                # torch.Tensor has no `to_device` method; `to` is the device transfer API
                self[i] = x.to(device)


def get_intervals(tensor: torch.Tensor) -> torch.Tensor:
    """
    Get a list of True group beginning and end indices from a boolean tensor

    Returns a tensor of shape `(#groups, 2)` with `(start, exclusive end)` rows.
    """

    values, counts = torch.unique_consecutive(tensor, return_counts=True)
    ends = torch.cumsum(counts, dim=0)
    starts = ends - counts
    keep = values.bool()
    return torch.stack([starts[keep], ends[keep]]).T


def smooth(tensor: torch.Tensor, smooth_interval: int = 0) -> torch.Tensor:
    """
    Get rid of jittering in a non-exclusive classification tensor

    First, remove intervals of 0 shorter than `smooth_interval`. Then, remove intervals of 1 shorter than
    `smooth_interval`. Intervals of exactly `smooth_interval` frames are removed as well.
    The input tensor is modified in place and returned.
    """

    if smooth_interval == 0:
        return tensor
    for value, fill in ((0, 1), (1, 0)):
        intervals = get_intervals(tensor == value)
        lengths = intervals[:, 1] - intervals[:, 0]
        for start, end in intervals[lengths <= smooth_interval].tolist():
            tensor[start:end] = fill
    return tensor


class GaussianSmoothing(nn.Module):
    """
    Apply gaussian smoothing on a 1d tensor.
    Filtering is performed separately for each channel
    in the input using a depthwise convolution.
    Arguments:
        kernel_size (int): Size of the gaussian kernel.
        sigma (float): Standard deviation of the gaussian kernel.
    """

    def __init__(self, kernel_size: int = 15, sigma: float = 1.0) -> None:
        super().__init__()
        self.kernel_size = kernel_size

        grid = torch.arange(kernel_size, dtype=torch.float32)
        mean = (kernel_size - 1) / 2
        kernel = torch.exp(-(((grid - mean) / sigma) ** 2) / 2) / (
            sigma * math.sqrt(2 * math.pi)
        )

        # The kernel holds samples of the gaussian density and is not re-normalized.
        # It is kept as a plain attribute (not a buffer) so that state dicts are unchanged.
        self.kernel = kernel.view(1, 1, -1)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """
        Apply gaussian filter to input.
        Arguments:
            inputs (torch.Tensor): Input of shape (N, C, T) to apply gaussian filter on.
        Returns:
            filtered (torch.Tensor): Filtered output (same shape as the input for odd kernel sizes).
        """
        _, c, _ = inputs.shape
        pad = (self.kernel_size - 1) // 2
        inputs = F.pad(inputs, pad=(pad, pad), mode="reflect")
        # one copy of the kernel per channel (depthwise convolution)
        kernel = self.kernel.repeat(c, 1, 1).to(
            device=inputs.device, dtype=inputs.dtype
        )
        return F.conv1d(inputs, weight=kernel, groups=c)


def argrelmax(prob: np.ndarray, threshold: float = 0.7) -> List[int]:
    """
    Calculate arguments of relative maxima.
    prob: np.array. boundary probability maps distributed in [0, 1]
        prob shape is (T)
    ignore the peak whose value is under threshold
    Return:
        Index of peaks (the first frame is always included)
    """
    # ignore the values under threshold (on a copy, so that the caller's array is untouched)
    prob = np.where(prob < threshold, 0.0, prob)

    # calculate the relative maxima of boundary maps
    # treat the first frame as boundary
    # (the builtin `bool` replaces the `np.bool` alias that was removed in NumPy 1.24)
    peak = np.concatenate(
        [
            np.ones((1), dtype=bool),
            (prob[:-2] < prob[1:-1]) & (prob[2:] < prob[1:-1]),
            np.zeros((1), dtype=bool),
        ],
        axis=0,
    )

    peak_idx = np.where(peak)[0].tolist()

    return peak_idx


def decide_boundary_prob_with_similarity(x: torch.Tensor) -> torch.Tensor:
    """
    Decide action boundary probabilities based on adjacent frame similarities.
    Only the first sample of the batch is used.
    Args:
        x: frame-wise video features (N, C, T)
    Return:
        boundary: action boundary probability (1, 1, T)
    """
    device = x.device

    # gaussian kernel.
    diff = x[0, :, 1:] - x[0, :, :-1]
    similarity = torch.exp(-torch.norm(diff, dim=0) / (2 * 1.0))

    # define action starting point as action boundary.
    start = torch.ones(1).float().to(device)
    boundary = torch.cat([start, similarity])
    boundary = boundary.view(1, 1, -1)
    return boundary


class PostProcessor(object):
    """
    Post-process frame-wise action segmentation outputs

    `name` selects one of "refinement_with_boundary", "relabeling" and "smoothing".
    """

    def __init__(
        self,
        name: str,
        boundary_th: float = 0.7,
        theta_t: int = 15,
        kernel_size: int = 15,
    ) -> None:
        self.func = {
            "refinement_with_boundary": self._refinement_with_boundary,
            "relabeling": self._relabeling,
            "smoothing": self._smoothing,
        }
        assert name in self.func, f"unknown post-processing method: {name}"

        self.name = name
        self.boundary_th = boundary_th
        self.theta_t = theta_t
        self.kernel_size = kernel_size

        if name == "smoothing":
            self.filter = GaussianSmoothing(self.kernel_size)

    def _is_probability(self, x: np.ndarray) -> bool:
        """
        Check whether x (N, C, T) is in [0, 1] (one channel) or sums to 1 over channels
        """
        assert x.ndim == 3

        if x.shape[1] == 1:
            # sigmoid
            return bool(x.min() >= 0 and x.max() <= 1)
        # softmax
        _sum = np.sum(x, axis=1).astype(np.float32)
        _ones = np.ones_like(_sum, dtype=np.float32)
        return np.allclose(_sum, _ones)

    def _convert2probability(self, x: np.ndarray) -> np.ndarray:
        """
        Args: x (N, C, T)
        """
        assert x.ndim == 3

        if self._is_probability(x):
            return x
        if x.shape[1] == 1:
            # sigmoid
            prob = 1 / (1 + np.exp(-x))
        else:
            # softmax (shifted by the channel-wise maximum to avoid overflow)
            exp = np.exp(x - np.max(x, axis=1, keepdims=True))
            prob = exp / np.sum(exp, axis=1, keepdims=True)
        return prob.astype(np.float32)

    def _convert2label(self, x: np.ndarray) -> np.ndarray:
        """
        Convert labels (N, T) or scores (N, C, T) to an int64 label array (N, T)
        """
        assert x.ndim == 2 or x.ndim == 3

        if x.ndim == 2:
            return x.astype(np.int64)
        if not self._is_probability(x):
            x = self._convert2probability(x)

        label = np.argmax(x, axis=1)
        return label.astype(np.int64)

    def _refinement_with_boundary(
        self,
        outputs: np.array,
        boundaries: np.ndarray,
    ) -> np.ndarray:
        """
        Get segments which is defined as the span b/w two boundaries,
        and decide their classes by majority vote.
        Args:
            outputs: numpy array. shape (N, C, T)
                the model output for frame-level class prediction.
            boundaries: numpy array. shape (N, 1, T)
                boundary prediction.
        Return:
            preds: np.array. shape (N, T)
                final class prediction considering boundaries.
        """

        preds = self._convert2label(outputs)
        boundaries = self._convert2probability(boundaries)

        for i, (output, pred, boundary) in enumerate(zip(outputs, preds, boundaries)):
            idx = argrelmax(boundary.reshape(-1), threshold=self.boundary_th)

            # add the index of the last action ending
            idx.append(pred.shape[0])

            # majority vote
            for seg_start, seg_end in zip(idx[:-1], idx[1:]):
                count = np.bincount(pred[seg_start:seg_end])
                modes = np.where(count == count.max())[0]
                if len(modes) == 1 or outputs.ndim != 3:
                    # a single majority class, or no scores to break the tie with
                    mode = modes[0]
                else:
                    # break ties with the largest summed score; argmax always picks one
                    # of the modes, even when all the sums are non-positive
                    sums = [output[m, seg_start:seg_end].sum() for m in modes]
                    mode = modes[int(np.argmax(sums))]

                preds[i, seg_start:seg_end] = mode

        return preds

    def _relabeling(self, outputs: np.ndarray, **kwargs: np.ndarray) -> np.ndarray:
        """
        Relabeling small action segments with their previous action segment
        Args:
            output: the results of action segmentation. (N, T) or (N, C, T)
            theta_t: the threshold of the size of action segments.
        Return:
            relabeled output. (N, T)
        """

        preds = self._convert2label(outputs)

        for row in preds:
            n_frames = row.shape[0]
            start = 0  # first frame of the current segment
            for j in range(1, n_frames + 1):
                if j < n_frames and row[j] == row[j - 1]:
                    continue
                # the segment [start, j) has just ended (j == n_frames closes the last one);
                # the very first segment has no previous segment and is left unchanged
                if j - start <= self.theta_t and start > 0:
                    row[start:j] = row[start - 1]
                start = j

        return preds

    def _smoothing(self, outputs: np.ndarray, **kwargs: np.ndarray) -> np.ndarray:
        """
        Smoothing action probabilities with gaussian filter.
        Args:
            outputs: frame-wise action probabilities. (N, C, T)
        Return:
            predictions: final prediction. (N, T)
        """

        outputs = self._convert2probability(outputs)
        outputs = self.filter(torch.Tensor(outputs)).numpy()

        preds = self._convert2label(outputs)
        return preds

    def __call__(self, outputs, **kwargs: np.ndarray) -> np.ndarray:
        """
        Run the post-processing method selected by `name`
        """

        preds = self.func[self.name](outputs, **kwargs)
        return preds
class TensorDict:
    """
    A class that handles indexing in a dictionary of tensors of the same length
    """

    def __init__(self, obj: Union[Dict, Iterable] = None) -> None:
        """
        Parameters
        ----------
        obj : dict | list | tuple, optional
            either a dictionary of torch.Tensor instances of the same length or a list/tuple of dictionaries with
            the same keys (if not passed or empty, a blank TensorDict is initialized)

        Raises
        ------
        TypeError
            if `obj` is neither a dictionary nor a list/tuple, or if a dictionary value is not a tensor
        ValueError
            if the tensors in a dictionary have different lengths
        """

        if obj is None:
            obj = {}
        if isinstance(obj, (list, tuple)):
            # an empty sequence yields a blank TensorDict instead of an IndexError
            self.keys = list(obj[0].keys()) if len(obj) > 0 else []
            columns = {key: [] for key in self.keys}
            with warnings.catch_warnings():
                # torch.tensor(existing_tensor) emits a UserWarning about copy construction
                warnings.filterwarnings("ignore", category=UserWarning)
                for element in obj:
                    for key in self.keys:
                        columns[key].append(torch.tensor(element[key]))
            self.dict = {key: torch.stack(values) for key, values in columns.items()}
        elif isinstance(obj, dict):
            if not all([isinstance(obj[key], torch.Tensor) for key in obj]):
                raise TypeError(
                    f"The values in the dictionary passed to TensorDict need to be torch.Tensor instances;"
                    f"got {[type(obj[key]) for key in obj]}"
                )
            lengths = [len(obj[key]) for key in obj]
            if not all([x == lengths[0] for x in lengths]):
                raise ValueError(
                    f"The value tensors in the dictionary passed to TensorDict need to have the same length;"
                    f"got {lengths}"
                )
            # NOTE: the dictionary is stored by reference, so `append` / `remove` also modify the caller's dict
            self.dict = obj
            self.keys = list(self.dict.keys())
        else:
            raise TypeError(
                f"TensorDict can only be constructed from an iterable of dictionaries of from a dictionary "
                f"of tensors (got {type(obj)})"
            )
        if len(self.keys) > 0:
            self.type = type(self.dict[self.keys[0]])
        else:
            self.type = None

    def __len__(self) -> int:
        """
        Get the number of stored elements (0 for a blank TensorDict)
        """

        if len(self.keys) == 0:
            return 0
        return len(self.dict[self.keys[0]])

    @staticmethod
    def _is_scalar_index(ind) -> bool:
        """
        Check whether an index selects a single element (python/numpy integer or 0-dim integer tensor)
        """

        if isinstance(ind, torch.Tensor):
            return ind.dim() == 0 and ind.dtype != torch.bool
        return hasattr(ind, "__index__") and not isinstance(ind, bool)

    def __getitem__(self, ind: Union[int, List]):
        """
        Index the TensorDict

        Parameters
        ----------
        ind : int | list
            the index/indices

        Returns
        -------
        dict | TensorDict
            a dictionary of the indexed elements (if `ind` is a single integer index) or a TensorDict
        """

        x = {key: self.dict[key][ind] for key in self.dict}
        if self._is_scalar_index(ind):
            return x
        return TensorDict(x)

    def append(self, element: Dict) -> None:
        """
        Append an element

        Parameters
        ----------
        element : dict
            a dictionary of tensors (one value per key of the TensorDict)

        Raises
        ------
        ValueError
            if `element` is empty or does not contain all keys of the TensorDict
        """

        if len(element) == 0:
            raise ValueError("Cannot append an empty dictionary to TensorDict")
        if self.type is None:
            self.dict = {k: v.unsqueeze(0) for k, v in element.items()}
            self.keys = list(element.keys())
            self.type = type(next(iter(element.values())))
            return
        # validate everything first so that a failed append leaves the TensorDict untouched
        for key in self.keys:
            if key not in element:
                raise ValueError(
                    f"The dictionary appended to TensorDict needs to have the same keys as the "
                    f"TensorDict; got {element.keys()} and {self.keys}"
                )
        for key in self.keys:
            self.dict[key] = torch.cat([self.dict[key], element[key].unsqueeze(0)])

    def remove(self, indices: List) -> None:
        """
        Remove indexed elements

        Parameters
        ----------
        indices : list
            the indices to remove
        """

        mask = torch.ones(len(self), dtype=torch.bool)
        index = torch.as_tensor(indices)
        if index.dtype != torch.bool:
            # an empty list becomes a float tensor by default, which cannot be used as an index
            index = index.long()
        mask[index] = False
        for key in list(self.dict.keys()):
            self.dict[key] = self.dict[key][mask]
A class that handles indexing in a dictionary of tensors of the same length
35 def __init__(self, obj: Union[Dict, Iterable] = None) -> None: 36 """ 37 Parameters 38 ---------- 39 obj : dict | iterable, optional 40 either a dictionary of torch.Tensor instances of the same length or an iterable of dictionaries with 41 the same keys (if not passed, a blank TensorDict is initialized) 42 """ 43 44 if obj is None: 45 obj = {} 46 if isinstance(obj, list): 47 self.keys = list(obj[0].keys()) 48 self.dict = {key: [] for key in self.keys} 49 with warnings.catch_warnings(): 50 warnings.filterwarnings("ignore", category=UserWarning) 51 for element in obj: 52 for key in self.keys: 53 self.dict[key].append(torch.tensor(element[key])) 54 # self.dict = {k: torch.stack(v) for k, v in self.dict.items()} 55 old_dict = self.dict 56 self.dict = {} 57 for k, v in old_dict.items(): 58 self.dict[k] = torch.stack(v) 59 elif isinstance(obj, dict): 60 if not all([isinstance(obj[key], torch.Tensor) for key in obj]): 61 raise TypeError( 62 f"The values in the dictionary passed to TensorDict need to be torch.Tensor instances;" 63 f"got {[type(obj[key]) for key in obj]}" 64 ) 65 lengths = [len(obj[key]) for key in obj] 66 if not all([x == lengths[0] for x in lengths]): 67 raise ValueError( 68 f"The value tensors in the dictionary passed to TensorDict need to have the same length;" 69 f"got {lengths}" 70 ) 71 self.dict = obj 72 self.keys = list(self.dict.keys()) 73 else: 74 raise TypeError( 75 f"TensorDict can only be constructed from an iterable of dictionaries of from a dictionary " 76 f"of tensors (got {type(obj)})" 77 ) 78 if len(self.keys) > 0: 79 self.type = type(self.dict[self.keys[0]]) 80 else: 81 self.type = None
Parameters
obj : dict | iterable, optional either a dictionary of torch.Tensor instances of the same length or an iterable of dictionaries with the same keys (if not passed, a blank TensorDict is initialized)
108 def append(self, element: Dict) -> None: 109 """ 110 Append an element 111 112 Parameters 113 ---------- 114 element : dict 115 a dictionary 116 """ 117 118 type_element = type(element[list(element.keys())[0]]) 119 if self.type is None: 120 self.dict = {k: v.unsqueeze(0) for k, v in element.items()} 121 self.keys = list(element.keys()) 122 self.type = type_element 123 else: 124 for key in self.keys: 125 if key not in element: 126 raise ValueError( 127 f"The dictionary appended to TensorDict needs to have the same keys as the " 128 f"TensorDict; got {element.keys()} and {self.keys}" 129 ) 130 self.dict[key] = torch.cat([self.dict[key], element[key].unsqueeze(0)])
Append an element
Parameters
element : dict a dictionary
132 def remove(self, indices: List) -> None: 133 """ 134 Remove indexed elements 135 136 Parameters 137 ---------- 138 indices : list 139 the indices to remove 140 """ 141 142 mask = torch.ones(len(self)) 143 mask[indices] = 0 144 mask = mask.bool() 145 for key, value in self.dict.items(): 146 self.dict[key] = value[mask]
Remove indexed elements
Parameters
indices : list the indices to remove
def apply_threshold(
    tensor: torch.Tensor,
    threshold: float,
    low: bool = True,
    error_mask: torch.Tensor = None,
    min_frames: int = 0,
    smooth_interval: int = 0,
    masked_intervals: List = None,
):
    """
    Apply a hard threshold to a tensor and return indices of the intervals that passed

    If `error_mask` is not `None`, the elements marked `False` are treated as if they did not pass the threshold.
    If `min_frames` is not 0, the intervals are additionally filtered by length.

    Parameters
    ----------
    tensor : torch.Tensor
        the (1D) tensor to apply the threshold to
    threshold : float
        the threshold
    low : bool, default True
        if `True`, values `<= threshold` pass; otherwise values `>= threshold` pass
    error_mask : torch.Tensor, optional
        a boolean mask to apply to the results
    min_frames : int, default 0
        only intervals strictly longer than this number of frames are kept
    smooth_interval : int, default 0
        the interval length passed to `smooth` (0 disables smoothing)
    masked_intervals : list, optional
        a list of `(start, end)` intervals that are treated as if they did not pass the threshold

    Returns
    -------
    indices_start : list
        a list of indices of the first frames of the chosen intervals
    indices_end : list
        a list of (exclusive) end indices of the chosen intervals
    """

    if masked_intervals is None:
        masked_intervals = []
    if low:
        p = tensor <= threshold
    else:
        p = tensor >= threshold
    # `smooth` returns its input unchanged for an interval of 0, so the call can be skipped
    if smooth_interval != 0:
        p = smooth(p, smooth_interval)
    if error_mask is not None:
        p = p & torch.as_tensor(error_mask, dtype=torch.bool, device=p.device)
    for start, end in masked_intervals:
        p[start:end] = False
    values, counts = torch.unique_consecutive(p, return_counts=True)
    # run i covers [ends[i] - counts[i], ends[i]), so no per-run search over the tensor is needed
    ends = torch.cumsum(counts, dim=0)
    starts = ends - counts
    keep = values.bool() & (counts > min_frames)
    return starts[keep].tolist(), ends[keep].tolist()
Apply a hard threshold to a tensor and return indices of the intervals that passed
If error_mask
is not None
, the elements marked False
are treated as if they did not pass the threshold.
If min_frames
is not 0, the intervals are additionally filtered by length.
Parameters
tensor : torch.Tensor the tensor to apply the threshold to threshold : float the threshold error_mask : torch.Tensor, optional a boolean mask to apply to the results min_frames : int, default 0 the minimum number of frames in the resulting intervals (shorter intervals are discarded)
Returns
indices_start : list a list of indices of the first frames of the chosen intervals indices_end : list a list of indices of the last frames of the chosen intervals
def apply_threshold_hysteresis(
    tensor: torch.Tensor,
    soft_threshold: float,
    hard_threshold: float,
    low: bool = True,
    error_mask: torch.Tensor = None,
    min_frames: int = 0,
    smooth_interval: int = 0,
    masked_intervals: List = None,
):
    """
    Apply a hysteresis threshold to a tensor and return indices of the intervals that passed

    In the chosen intervals all values pass the soft threshold and at least one value passes the hard threshold.
    If `error_mask` is not `None`, the elements marked `False` are treated as if they did not pass the threshold.
    If `min_frames` is not 0, the intervals are additionally filtered by length.

    Parameters
    ----------
    tensor : torch.Tensor
        the tensor to apply the threshold to (assumed to be one-dimensional, one value per frame)
    soft_threshold : float
        the soft threshold
    hard_threshold : float
        the hard threshold
    low : bool, default True
        if `True`, values `<=` the thresholds pass; otherwise values `>=` the thresholds pass
    error_mask : torch.Tensor, optional
        a boolean mask to apply to the results
    min_frames : int, default 0
        only intervals strictly longer than this number of frames are kept
    smooth_interval : int, default 0
        passed to `smooth` together with the soft-thresholded mask (0 disables smoothing)
    masked_intervals : list, optional
        `(start, end)` pairs of frame ranges that are forced to not pass

    Returns
    -------
    indices_start : list
        a list of indices of the first frames of the chosen intervals
    indices_end : list
        a list of indices one past the last frames of the chosen intervals
    """

    if masked_intervals is None:
        masked_intervals = []
    if low:
        p = tensor <= soft_threshold
        hard = tensor <= hard_threshold
    else:
        p = tensor >= soft_threshold
        hard = tensor >= hard_threshold
    p = smooth(p, smooth_interval)
    if error_mask is not None:
        p = p * error_mask
    for start, end in masked_intervals:
        p[start:end] = False
    # run-length encode the soft mask; the cumulative run lengths give every
    # run's boundaries in a single pass instead of re-scanning for each run
    values, counts = torch.unique_consecutive(p, return_counts=True)
    run_ends = torch.cumsum(counts, dim=0)
    run_starts = run_ends - counts
    keep = values.bool() & (counts > min_frames)
    res_indices_start = []
    res_indices_end = []
    for start, end in zip(run_starts[keep].tolist(), run_ends[keep].tolist()):
        # the hard mask is checked as computed from the raw values
        # (it is neither smoothed nor masked)
        if hard[start:end].any():
            res_indices_start.append(start)
            res_indices_end.append(end)
    return res_indices_start, res_indices_end
Apply a hysteresis threshold to a tensor and return indices of the intervals that passed
In the chosen intervals all values pass the soft threshold and at least one value passes the hard threshold.
If error_mask
is not None
, the elements marked False
are treated as if they did not pass the threshold.
If min_frames
is not 0, the intervals are additionally filtered by length.
Parameters
tensor : torch.Tensor the tensor to apply the threshold to soft_threshold : float the soft threshold hard_threshold : float the hard threshold error_mask : torch.Tensor, optional a boolean mask to apply to the results min_frames : int, default 0 the minimum number of frames in the resulting intervals (shorter intervals are discarded)
Returns
indices_start : list a list of indices of the first frames of the chosen intervals indices_end : list a list of indices of the last frames of the chosen intervals
def apply_threshold_max(
    tensor: torch.Tensor,
    threshold: float,
    main_class: int,
    error_mask: torch.Tensor = None,
    min_frames: int = 0,
    smooth_interval: int = 0,
    masked_intervals: List = None,
):
    """
    Apply a max hysteresis threshold to a tensor and return indices of the intervals that passed

    In the chosen intervals the values at the `main_class` index are larger than the others everywhere
    and at least one value at the `main_class` index passes the threshold.
    If `error_mask` is not `None`, the elements marked `False` are treated as if they did not pass the threshold.
    If `min_frames` is not 0, the intervals are additionally filtered by length.

    Parameters
    ----------
    tensor : torch.Tensor
        the tensor to apply the threshold to (of shape `(#classes, #frames)`)
    threshold : float
        the threshold (values must be strictly larger to pass); if `None`, the hard threshold check
        is skipped and all intervals where `main_class` is the argmax are returned
    main_class : int
        the class that conditions the soft threshold
    error_mask : torch.Tensor, optional
        a boolean mask to apply to the results
    min_frames : int, default 0
        only intervals strictly longer than this number of frames are kept
    smooth_interval : int, default 0
        passed to `smooth` together with the argmax mask (0 disables smoothing)
    masked_intervals : list, optional
        `(start, end)` pairs of frame ranges that are forced to not pass

    Returns
    -------
    indices_start : list
        a list of indices of the first frames of the chosen intervals (along dimension 1 of input tensor)
    indices_end : list
        a list of indices one past the last frames of the chosen intervals (along dimension 1 of input tensor)
    """

    if masked_intervals is None:
        masked_intervals = []
    _, winners = torch.max(tensor, dim=0)
    p = winners == main_class
    p = smooth(p, smooth_interval)
    if error_mask is not None:
        p = p * error_mask
    for start, end in masked_intervals:
        p[start:end] = False
    # run-length encode the mask; the cumulative run lengths give every run's
    # boundaries in a single pass instead of re-scanning the mask for each run
    values, counts = torch.unique_consecutive(p, return_counts=True)
    run_ends = torch.cumsum(counts, dim=0)
    run_starts = run_ends - counts
    keep = values.bool() & (counts > min_frames)
    indices_start = run_starts[keep].tolist()
    indices_end = run_ends[keep].tolist()
    if threshold is None:
        return indices_start, indices_end
    hard = tensor[main_class, :] > threshold
    res_indices_start = []
    res_indices_end = []
    for start, end in zip(indices_start, indices_end):
        if hard[start:end].any():
            res_indices_start.append(start)
            res_indices_end.append(end)
    return res_indices_start, res_indices_end
Apply a max hysteresis threshold to a tensor and return indices of the intervals that passed
In the chosen intervals the values at the main_class
index are larger than the others everywhere
and at least one value at the main_class
index passes the threshold.
If error_mask
is not None
, the elements marked False
are treated as if they did not pass the threshold.
If min_frames is not 0, the intervals are additionally filtered by length.
Parameters
tensor : torch.Tensor
the tensor to apply the threshold to (of shape (#classes, #frames)
)
threshold : float
the threshold
main_class : int
the class that conditions the soft threshold
error_mask : torch.Tensor, optional
a boolean mask to apply to the results
min_frames : int, default 0
the minimum number of frames in the resulting intervals (shorter intervals are discarded)
Returns
indices_start : list a list of indices of the first frames of the chosen intervals (along dimension 1 of input tensor) indices_end : list a list of indices of the last frames of the chosen intervals (along dimension 1 of input tensor)
def strip_suffix(text: str, suffix: Iterable):
    """
    Strip a suffix from a string if it is contained in a list

    The first entry of `suffix` that `text` ends with is removed; later entries are not tried.

    Parameters
    ----------
    text : str
        the main string
    suffix : iterable
        the list of suffixes to be stripped (`None` is treated as an empty list)

    Returns
    -------
    result : str
        the stripped string
    """

    if suffix is None:
        suffix = []
    for s in suffix:
        if text.endswith(s):
            # slice by explicit length: `text[:-len(s)]` collapses to `text[:0]`
            # (an empty string) when `s` is empty
            return text[: len(text) - len(s)]
    return text
Strip a suffix from a string if it is contained in a list
Parameters
text : str the main string suffix : iterable the list of suffixes to be stripped
Returns
result : str the stripped string
def strip_prefix(text: str, prefix: Iterable):
    """
    Remove the first matching prefix from a string

    Parameters
    ----------
    text : str
        the string to process
    prefix : iterable
        candidate prefixes, tried in order (`None` means no candidates)

    Returns
    -------
    result : str
        `text` without the first candidate it starts with, or `text` itself if none matches
    """

    candidates = () if prefix is None else prefix
    matched = next((p for p in candidates if text.startswith(p)), None)
    if matched is None:
        return text
    return text[len(matched) :]
Strip a prefix from a string if it is contained in a list
Parameters
text : str the main string prefix : iterable the list of prefixes to be stripped
Returns
result : str the stripped string
def rotation_matrix_2d(angles: torch.Tensor) -> torch.Tensor:
    """
    Build 2D rotation matrices for a tensor of angles

    Parameters
    ----------
    angles : torch.Tensor
        a tensor of angles of arbitrary shape `(...)`

    Returns
    -------
    rotation_matrices : torch.Tensor
        a tensor of 2D rotation matrices of shape `(..., 2, 2)`, each laid out as
        `[[cos, -sin], [sin, cos]]`
    """

    c, s = torch.cos(angles), torch.sin(angles)
    top_row = torch.stack([c, -s], dim=-1)
    bottom_row = torch.stack([s, c], dim=-1)
    return torch.stack([top_row, bottom_row], dim=-2)
Create a tensor of 2D rotation matrices from a tensor of angles
Parameters
angles : torch.Tensor
a tensor of angles of arbitrary shape (...)
Returns
rotation_matrices : torch.Tensor
a tensor of 2D rotation matrices of shape (..., 2, 2)
def rotation_matrix_3d(
    alpha: torch.Tensor, beta: torch.Tensor, gamma: torch.Tensor
) -> torch.Tensor:
    """
    Create a tensor of 3D rotation matrices from a tensor of angles

    The result is the product `Rx(alpha) @ Ry(beta) @ Rz(gamma)` of the three elementary rotations.

    Parameters
    ----------
    alpha : torch.Tensor
        a tensor of rotation angles around the x axis of arbitrary shape `(...)`
    beta : torch.Tensor
        a tensor of rotation angles around the y axis of arbitrary shape `(...)`
    gamma : torch.Tensor
        a tensor of rotation angles around the z axis of arbitrary shape `(...)`

    Returns
    -------
    rotation_matrices : torch.Tensor
        a tensor of 3D rotation matrices of shape `(..., 3, 3)`
    """

    # the constant entries are created with `ones_like` / `zeros_like` so that they
    # live on the same device and have the same dtype as the trigonometric entries
    cos = torch.cos(alpha)
    sin = torch.sin(alpha)
    one = torch.ones_like(cos)
    zero = torch.zeros_like(cos)
    Rx = torch.stack(
        [one, zero, zero, zero, cos, -sin, zero, sin, cos],
        dim=-1,
    ).reshape(*alpha.shape, 3, 3)

    cos = torch.cos(beta)
    sin = torch.sin(beta)
    one = torch.ones_like(cos)
    zero = torch.zeros_like(cos)
    Ry = torch.stack(
        [cos, zero, sin, zero, one, zero, -sin, zero, cos],
        dim=-1,
    ).reshape(*beta.shape, 3, 3)

    cos = torch.cos(gamma)
    sin = torch.sin(gamma)
    one = torch.ones_like(cos)
    zero = torch.zeros_like(cos)
    Rz = torch.stack(
        [cos, -sin, zero, sin, cos, zero, zero, zero, one],
        dim=-1,
    ).reshape(*gamma.shape, 3, 3)

    R = torch.einsum("...ij,...jk,...kl->...il", Rx, Ry, Rz)
    return R
Create a tensor of 3D rotation matrices from a tensor of angles
Parameters
alpha : torch.Tensor
a tensor of rotation angles around the x axis of arbitrary shape (...)
beta : torch.Tensor
a tensor of rotation angles around the y axis of arbitrary shape (...)
gamma : torch.Tensor
a tensor of rotation angles around the z axis of arbitrary shape (...)
Returns
rotation_matrices : torch.Tensor
a tensor of 3D rotation matrices of shape (..., 3, 3)
def correct_path(path, project_path):
    """
    Re-root a stored path under the `results` folder of a project

    Parameters
    ----------
    path : str | any
        the path to correct; values that are not strings are returned unchanged
    project_path : str
        the path to the project folder

    Returns
    -------
    corrected_path : str | any
        `project_path/results/<tail>`, where `<tail>` is everything that follows the first
        `"results"` component of `path` (or the first `"saved_datasets"` component if there is no
        `"results"` component)
    """

    if not isinstance(path, str):
        return path
    components = os.path.normpath(path).split(os.path.sep)
    anchor = "results" if "results" in components else "saved_datasets"
    # raises ValueError when neither folder name occurs in the path
    tail = components[components.index(anchor) + 1 :]
    # NOTE(review): paths anchored at "saved_datasets" are also rejoined under
    # "results" -- confirm that this is intended
    return os.path.join(project_path, "results", *tail)
class TensorList(list):
    """
    A list of tensors that can send each element to a `torch` device
    """

    def to_device(self, device: torch.device):
        """
        Move every element to a device, replacing the elements in place

        Parameters
        ----------
        device : torch.device
            the target device
        """

        for i, x in enumerate(self):
            if isinstance(x, torch.Tensor):
                # `torch.Tensor` has no `to_device` method, `to` is the tensor API
                self[i] = x.to(device)
            else:
                # other element types keep the original contract: they must provide
                # `to_device` and its return value replaces the element
                self[i] = x.to_device(device)
A list of tensors that can send each element to a torch
device
Inherited Members
- builtins.list
- list
- clear
- copy
- append
- insert
- extend
- pop
- remove
- index
- count
- reverse
- sort
def get_intervals(tensor: torch.Tensor) -> torch.Tensor:
    """
    Get a list of True group beginning and end indices from a boolean tensor

    Parameters
    ----------
    tensor : torch.Tensor
        a boolean tensor (assumed to be one-dimensional)

    Returns
    -------
    intervals : torch.Tensor
        an integer tensor of shape `(#groups, 2)`; each row holds the index of the first element
        of a group of consecutive `True` values and the index one past its last element
    """

    # run-length encode the input; the cumulative run lengths give every run's
    # boundaries in one pass and keep the result integer-typed even when it is empty
    values, counts = torch.unique_consecutive(tensor, return_counts=True)
    ends = torch.cumsum(counts, dim=0)
    starts = ends - counts
    is_true_run = values.bool()
    return torch.stack([starts[is_true_run], ends[is_true_run]], dim=1)
Get a list of True group beginning and end indices from a boolean tensor
def smooth(tensor: torch.Tensor, smooth_interval: int = 0) -> torch.Tensor:
    """
    Get rid of jittering in a non-exclusive classification tensor

    First, intervals of 0 that are at most `smooth_interval` long are filled with 1. Then, intervals
    of 1 that are at most `smooth_interval` long are filled with 0. The input tensor is modified
    in place and returned.

    Parameters
    ----------
    tensor : torch.Tensor
        the tensor of 0 / 1 values to smooth
    smooth_interval : int, default 0
        the maximum length of the intervals that get removed (0 returns the input untouched)

    Returns
    -------
    tensor : torch.Tensor
        the same tensor object after smoothing
    """

    if smooth_interval == 0:
        return tensor
    # pass 1 closes short gaps (runs of 0), pass 2 erases short runs of 1;
    # the second pass sees the result of the first
    for run_value, fill_value in ((0, 1), (1, 0)):
        runs = get_intervals(tensor == run_value)
        run_lengths = runs[:, 1] - runs[:, 0]
        for start, end in runs[run_lengths <= smooth_interval]:
            tensor[start:end] = fill_value
    return tensor
Get rid of jittering in a non-exclusive classification tensor
First, remove intervals of 0 shorter than smooth_interval
. Then, remove intervals of 1 shorter than
smooth_interval
.
class GaussianSmoothing(nn.Module):
    """
    Apply gaussian smoothing on a 1d tensor.
    Filtering is performed separately for each channel
    in the input using a depthwise convolution.
    Arguments:
        kernel_size (int): Size of the gaussian kernel.
        sigma (float): Standard deviation of the gaussian kernel.
    """

    def __init__(self, kernel_size: int = 15, sigma: float = 1.0) -> None:
        super().__init__()
        self.kernel_size = kernel_size

        # sample positions 0 .. kernel_size - 1, centred on the middle of the kernel
        positions = torch.arange(kernel_size).float()
        mean = (kernel_size - 1) / 2

        # gaussian probability density evaluated at the sample positions;
        # the values are NOT rescaled to sum to 1
        scale = 1 / (sigma * math.sqrt(2 * math.pi))
        kernel = scale * torch.exp(-(((positions - mean) / sigma) ** 2) / 2)

        # shape (1, 1, kernel_size): one filter that `forward` repeats per channel
        self.kernel = kernel.view(1, 1, *kernel.size())

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """
        Apply gaussian filter to input.
        Arguments:
            inputs (torch.Tensor): Input to apply gaussian filter on, of shape (N, C, T).
        Returns:
            filtered (torch.Tensor): Filtered output.
        """
        _, c, _ = inputs.shape
        half = (self.kernel_size - 1) // 2
        # reflect-pad both ends so that an odd kernel size keeps the length T
        inputs = F.pad(inputs, pad=(half, half), mode="reflect")
        # match the input's device and dtype, then use one copy of the filter per
        # channel (`groups=c` makes the convolution depthwise)
        kernel = self.kernel.to(device=inputs.device, dtype=inputs.dtype)
        kernel = kernel.repeat(c, 1, 1)
        return F.conv1d(inputs, weight=kernel, groups=c)
Apply gaussian smoothing on a 1d tensor. Filtering is performed separately for each channel in the input using a depthwise convolution. Arguments: channels (int, sequence): Number of channels of the input tensors. Output will have this number of channels as well. kernel_size (int, sequence): Size of the gaussian kernel. sigma (float, sequence): Standard deviation of the gaussian kernel.
564 def __init__(self, kernel_size: int = 15, sigma: float = 1.0) -> None: 565 super().__init__() 566 self.kernel_size = kernel_size 567 568 # The gaussian kernel is the product of the 569 # gaussian function of each dimension. 570 kernel = 1 571 meshgrid = torch.meshgrid(torch.arange(kernel_size))[0].float() 572 573 mean = (kernel_size - 1) / 2 574 kernel = kernel / (sigma * math.sqrt(2 * math.pi)) 575 kernel = kernel * torch.exp(-(((meshgrid - mean) / sigma) ** 2) / 2) 576 577 # Make sure sum of values in gaussian kernel equals 1. 578 # kernel = kernel / torch.max(kernel) 579 580 self.kernel = kernel.view(1, 1, *kernel.size())
Initializes internal Module state, shared by both nn.Module and ScriptModule.
582 def forward(self, inputs: torch.Tensor) -> torch.Tensor: 583 """ 584 Apply gaussian filter to input. 585 Arguments: 586 input (torch.Tensor): Input to apply gaussian filter on. 587 Returns: 588 filtered (torch.Tensor): Filtered output. 589 """ 590 _, c, _ = inputs.shape 591 inputs = F.pad( 592 inputs, 593 pad=((self.kernel_size - 1) // 2, (self.kernel_size - 1) // 2), 594 mode="reflect", 595 ) 596 kernel = self.kernel.repeat(c, *[1] * (self.kernel.dim() - 1)).to(inputs.device) 597 return F.conv1d(inputs, weight=kernel, groups=c)
Apply gaussian filter to input. Arguments: input (torch.Tensor): Input to apply gaussian filter on. Returns: filtered (torch.Tensor): Filtered output.
Inherited Members
- torch.nn.modules.module.Module
- dump_patches
- register_buffer
- register_parameter
- add_module
- register_module
- get_submodule
- get_parameter
- get_buffer
- get_extra_state
- set_extra_state
- apply
- cuda
- ipu
- xpu
- cpu
- type
- float
- double
- half
- bfloat16
- to_empty
- to
- register_backward_hook
- register_full_backward_hook
- register_forward_pre_hook
- register_forward_hook
- T_destination
- state_dict
- register_load_state_dict_post_hook
- load_state_dict
- parameters
- named_parameters
- buffers
- named_buffers
- children
- named_children
- modules
- named_modules
- train
- eval
- requires_grad_
- zero_grad
- extra_repr
def argrelmax(prob: np.ndarray, threshold: float = 0.7) -> List[int]:
    """
    Calculate arguments of relative maxima.

    Values under `threshold` are treated as 0 before the peaks are searched. The first frame is
    always reported as a peak and the last frame never is.

    Parameters
    ----------
    prob : np.ndarray
        boundary probability map of shape `(T,)`
    threshold : float, default 0.7
        peaks whose value is under the threshold are ignored

    Returns
    -------
    peak_idx : list
        the indices of the peaks
    """
    # zero the sub-threshold values on a copy so the caller's array is not modified
    prob = np.where(prob < threshold, 0.0, prob)

    # a frame is a peak when it is strictly larger than both of its neighbours;
    # the builtin `bool` is used because the `np.bool` alias is missing in
    # NumPy 1.24 - 1.26
    interior = (prob[:-2] < prob[1:-1]) & (prob[2:] < prob[1:-1])
    peak = np.concatenate(
        [
            np.ones(1, dtype=bool),  # treat the first frame as a boundary
            interior,
            np.zeros(1, dtype=bool),
        ],
        axis=0,
    )

    return np.where(peak)[0].tolist()
Calculate arguments of relative maxima. prob: np.array. boundary probability maps distributed in [0, 1] prob shape is (T) ignore the peak whose value is under threshold Return: Index of peaks for each batch
def decide_boundary_prob_with_similarity(x: torch.Tensor) -> torch.Tensor:
    """
    Decide action boundary probabilities based on adjacent frame similarities.

    For every pair of adjacent frames the value is `exp(-||x_t - x_{t-1}|| / 2)`, where the norm is
    taken over the channel dimension. The first frame of every sequence gets the value 1.

    Args:
        x: frame-wise video features (N, C, T)
    Return:
        boundary: action boundary probability (N, 1, T)
    """
    device = x.device

    # gaussian kernel on the difference of adjacent frames, computed for every
    # batch element: (N, C, T - 1) -> (N, T - 1)
    diff = x[:, :, 1:] - x[:, :, :-1]
    similarity = torch.exp(-torch.norm(diff, dim=1) / (2 * 1.0))

    # define action starting point as action boundary.
    start = torch.ones(x.shape[0], 1).float().to(device)
    boundary = torch.cat([start, similarity], dim=1)
    return boundary.unsqueeze(1)
Decide action boundary probabilities based on adjacent frame similarities. Args: x: frame-wise video features (N, C, T) Return: boundary: action boundary probability (N, 1, T)
class PostProcessor(object):
    """
    Post-process frame-wise action segmentation outputs

    The instance is configured with one of three methods (`"refinement_with_boundary"`,
    `"relabeling"` or `"smoothing"`) and applies it when called.
    """

    def __init__(
        self,
        name: str,
        boundary_th: float = 0.7,
        theta_t: int = 15,
        kernel_size: int = 15,
    ) -> None:
        """
        Parameters
        ----------
        name : str
            the post-processing method (`"refinement_with_boundary"`, `"relabeling"` or `"smoothing"`)
        boundary_th : float, default 0.7
            the boundary probability threshold used by `"refinement_with_boundary"`
        theta_t : int, default 15
            the maximum length of the segments that `"relabeling"` replaces
        kernel_size : int, default 15
            the gaussian kernel size used by `"smoothing"`
        """
        self.func = {
            "refinement_with_boundary": self._refinement_with_boundary,
            "relabeling": self._relabeling,
            "smoothing": self._smoothing,
        }
        assert name in self.func

        self.name = name
        self.boundary_th = boundary_th
        self.theta_t = theta_t
        self.kernel_size = kernel_size

        if name == "smoothing":
            self.filter = GaussianSmoothing(self.kernel_size)

    def _is_probability(self, x: np.ndarray) -> bool:
        """
        Check whether an (N, C, T) array already holds probabilities
        """
        assert x.ndim == 3

        if x.shape[1] == 1:
            # sigmoid: every value has to lie in [0, 1]
            return bool(x.min() >= 0 and x.max() <= 1)
        # softmax: the class dimension has to sum to 1
        _sum = np.sum(x, axis=1).astype(np.float32)
        _ones = np.ones_like(_sum, dtype=np.float32)
        return np.allclose(_sum, _ones)

    def _convert2probability(self, x: np.ndarray) -> np.ndarray:
        """
        Convert an (N, C, T) array to probabilities (returned unchanged if it already holds them)

        Args: x (N, C, T)
        """
        assert x.ndim == 3

        if self._is_probability(x):
            return x
        if x.shape[1] == 1:
            # sigmoid
            prob = 1 / (1 + np.exp(-x))
        else:
            # softmax; the per-frame maximum is subtracted first, which leaves the
            # result unchanged but keeps `exp` from overflowing on large logits
            shifted = np.exp(x - np.max(x, axis=1, keepdims=True))
            prob = shifted / np.sum(shifted, axis=1, keepdims=True)
        return prob.astype(np.float32)

    def _convert2label(self, x: np.ndarray) -> np.ndarray:
        """
        Convert (N, T) labels or (N, C, T) scores to an int64 label array of shape (N, T)
        """
        assert x.ndim == 2 or x.ndim == 3

        if x.ndim == 2:
            return x.astype(np.int64)
        if not self._is_probability(x):
            x = self._convert2probability(x)
        label = np.argmax(x, axis=1)
        return label.astype(np.int64)

    def _refinement_with_boundary(
        self,
        outputs: np.ndarray,
        boundaries: np.ndarray,
    ) -> np.ndarray:
        """
        Get segments which is defined as the span b/w two boundaries,
        and decide their classes by majority vote.
        Args:
            outputs: numpy array. shape (N, C, T)
                the model output for frame-level class prediction.
            boundaries: numpy array. shape (N, 1, T)
                boundary prediction.
        Return:
            preds: np.array. shape (N, T)
                final class prediction considering boundaries.
        """

        preds = self._convert2label(outputs)
        boundaries = self._convert2probability(boundaries)

        for i, (output, pred, boundary) in enumerate(zip(outputs, preds, boundaries)):
            idx = argrelmax(boundary.squeeze(), threshold=self.boundary_th)

            # add the index of the last action ending
            idx.append(pred.shape[0])

            # majority vote inside every segment
            for seg_start, seg_end in zip(idx[:-1], idx[1:]):
                count = np.bincount(pred[seg_start:seg_end])
                modes = np.where(count == count.max())[0]
                if len(modes) == 1 or outputs.ndim != 3:
                    # a unique winner, or no scores to break the tie with
                    # (the first of the tied classes is taken then)
                    mode = modes[0]
                else:
                    # several classes share the majority: take the one with the
                    # largest summed score over the segment (the first one on equal
                    # sums); `argmax` also works when every sum is zero or negative
                    scores = [output[m, seg_start:seg_end].sum() for m in modes]
                    mode = modes[int(np.argmax(scores))]

                preds[i, seg_start:seg_end] = mode

        return preds

    def _relabeling(self, outputs: np.ndarray, **kwargs: np.ndarray) -> np.ndarray:
        """
        Relabeling small action segments with their previous action segment
        Args:
            output: the results of action segmentation. (N, T) or (N, C, T)
        Return:
            relabeled output. (N, T)

        A segment is a run of equal labels; segments that are at most `self.theta_t` frames long are
        overwritten with the label of the frame right before them. The first segment has no
        previous segment and is always left unchanged.
        """

        preds = self._convert2label(outputs)

        for row in preds:
            n_frames = row.shape[0]
            seg_start = 0
            # `j == n_frames` is a virtual frame that closes the final segment
            for j in range(1, n_frames + 1):
                if j < n_frames and row[j] == row[seg_start]:
                    continue
                # the segment [seg_start, j) has ended
                if j - seg_start <= self.theta_t and seg_start > 0:
                    row[seg_start:j] = row[seg_start - 1]
                seg_start = j

        return preds

    def _smoothing(self, outputs: np.ndarray, **kwargs: np.ndarray) -> np.ndarray:
        """
        Smoothing action probabilities with gaussian filter.
        Args:
            outputs: frame-wise action probabilities. (N, C, T)
        Return:
            predictions: final prediction. (N, T)
        """

        outputs = self._convert2probability(outputs)
        outputs = self.filter(torch.Tensor(outputs)).numpy()

        preds = self._convert2label(outputs)
        return preds

    def __call__(self, outputs, **kwargs: np.ndarray) -> np.ndarray:
        """
        Apply the configured post-processing method to `outputs`

        The keyword arguments are passed on to the method (`"refinement_with_boundary"` requires
        `boundaries`).
        """

        preds = self.func[self.name](outputs, **kwargs)
        return preds
650 def __init__( 651 self, 652 name: str, 653 boundary_th: int = 0.7, 654 theta_t: int = 15, 655 kernel_size: int = 15, 656 ) -> None: 657 self.func = { 658 "refinement_with_boundary": self._refinement_with_boundary, 659 "relabeling": self._relabeling, 660 "smoothing": self._smoothing, 661 } 662 assert name in self.func 663 664 self.name = name 665 self.boundary_th = boundary_th 666 self.theta_t = theta_t 667 self.kernel_size = kernel_size 668 669 if name == "smoothing": 670 self.filter = GaussianSmoothing(self.kernel_size)