dlc2action.utils
Utility functions
`TensorDict` is a convenient data structure for keeping (and indexing) lists of feature dictionaries; `apply_threshold`, `apply_threshold_hysteresis` and `apply_threshold_max` are utility functions for `dlc2action.data.dataset.BehaviorDataset.find_valleys`; `strip_suffix` is used to get rid of suffixes if a string (usually a filename) ends with one of them; `strip_prefix` is used to get rid of prefixes if a string (usually a filename) starts with one of them; `rotation_matrix_2d` and `rotation_matrix_3d` are used to generate rotation matrices by `dlc2action.transformer.base_transformer.Transformer` instances.
#
# Copyright 2020-present by A. Mathis Group and contributors. All rights reserved.
#
# This project and all its files are licensed under GNU AGPLv3 or later version.
# A copy is included in dlc2action/LICENSE.AGPL.
#
"""
## Utility functions

- `TensorDict` is a convenient data structure for keeping (and indexing) lists of feature dictionaries,
- `apply_threshold`, `apply_threshold_hysteresis` and `apply_threshold_max` are utility functions for
`dlc2action.data.dataset.BehaviorDataset.find_valleys`,
- `strip_suffix` is used to get rid of suffixes if a string (usually filename) ends with one of them,
- `strip_prefix` is used to get rid of prefixes if a string (usually filename) starts with one of them,
- `rotation_matrix_2d` and `rotation_matrix_3d` are used to generate rotation matrices by
`dlc2action.transformer.base_transformer.Transformer` instances
"""

import torch
from typing import List, Dict, Union
from collections.abc import Iterable
import warnings
import os
import numpy as np
from torch import nn
from torch.nn import functional as F
import pickle
import re
import math


class TensorDict:
    """
    A class that handles indexing in a dictionary of tensors of the same length
    """

    def __init__(self, obj: Union[Dict, Iterable] = None) -> None:
        """
        Parameters
        ----------
        obj : dict | iterable, optional
            either a dictionary of torch.Tensor instances of the same length or a list/tuple of dictionaries with
            the same keys (if not passed, or if the sequence is empty, a blank TensorDict is initialized)

        Raises
        ------
        TypeError
            if `obj` is neither a dictionary nor a list/tuple, or if a dictionary value is not a tensor
        ValueError
            if the tensors in a dictionary have different lengths
        """

        if obj is None:
            obj = {}
        if isinstance(obj, (list, tuple)):
            if len(obj) == 0:
                # nothing to stack: behave like the blank constructor
                self.keys = []
                self.dict = {}
            else:
                # the keys of the first element define the keys of the TensorDict
                self.keys = list(obj[0].keys())
                collected = {key: [] for key in self.keys}
                with warnings.catch_warnings():
                    # torch.tensor(tensor) emits a UserWarning about copy construction
                    warnings.filterwarnings("ignore", category=UserWarning)
                    for element in obj:
                        for key in self.keys:
                            collected[key].append(torch.tensor(element[key]))
                self.dict = {}
                for key, values in collected.items():
                    self.dict[key] = torch.stack(values)
        elif isinstance(obj, dict):
            if not all(isinstance(value, torch.Tensor) for value in obj.values()):
                raise TypeError(
                    f"The values in the dictionary passed to TensorDict need to be torch.Tensor instances; "
                    f"got {[type(obj[key]) for key in obj]}"
                )
            lengths = [len(obj[key]) for key in obj]
            if not all(x == lengths[0] for x in lengths):
                raise ValueError(
                    f"The value tensors in the dictionary passed to TensorDict need to have the same length; "
                    f"got {lengths}"
                )
            # the dictionary is stored by reference, not copied
            self.dict = obj
            self.keys = list(self.dict.keys())
        else:
            raise TypeError(
                f"TensorDict can only be constructed from an iterable of dictionaries or from a dictionary "
                f"of tensors (got {type(obj)})"
            )
        if len(self.keys) > 0:
            self.type = type(self.dict[self.keys[0]])
        else:
            # `None` marks a blank TensorDict; `append` initializes it
            self.type = None

    def __len__(self) -> int:
        """Return the number of stored elements (0 for a blank TensorDict)."""
        if not self.keys:
            return 0
        return len(self.dict[self.keys[0]])

    def __getitem__(self, ind: Union[int, List]):
        """
        Index the TensorDict

        Parameters
        ----------
        ind : int | list
            the index/indices

        Returns
        -------
        dict | TensorDict
            a dictionary of the indexed values (if `ind` is an `int`) or a new TensorDict (otherwise)
        """

        selected = {key: value[ind] for key, value in self.dict.items()}
        # exact type check on purpose: only a plain int selects a single element
        if type(ind) is int:
            return selected
        return TensorDict(selected)

    def append(self, element: Dict) -> None:
        """
        Append an element

        The keys are validated before anything is modified, so a failed call leaves the TensorDict untouched.

        Parameters
        ----------
        element : dict
            a dictionary of tensors (one entry per key of the TensorDict)

        Raises
        ------
        ValueError
            if `element` is missing one of the keys of a non-blank TensorDict
        """

        if self.type is None:
            # blank TensorDict: the first element defines keys and value type
            self.dict = {k: v.unsqueeze(0) for k, v in element.items()}
            self.keys = list(element.keys())
            self.type = type(element[self.keys[0]])
            return
        if any(key not in element for key in self.keys):
            raise ValueError(
                f"The dictionary appended to TensorDict needs to have the same keys as the "
                f"TensorDict; got {element.keys()} and {self.keys}"
            )
        for key in self.keys:
            self.dict[key] = torch.cat([self.dict[key], element[key].unsqueeze(0)])

    def remove(self, indices: List) -> None:
        """
        Remove indexed elements

        Parameters
        ----------
        indices : list
            the indices to remove
        """

        keep = torch.ones(len(self), dtype=torch.bool)
        keep[indices] = False
        for key in list(self.dict):
            self.dict[key] = self.dict[key][keep]


def _passing_intervals(passed: torch.Tensor, min_frames: int):
    """
    Find the runs of non-zero values in a 1D tensor that are longer than `min_frames`

    Returns two lists of python ints: the first index of each run and the index right after its last element.
    """
    values, counts = torch.unique_consecutive(passed, return_counts=True)
    # the cumulative run lengths are the (exclusive) run ends
    ends = torch.cumsum(counts, dim=0)
    starts = ends - counts
    keep = values.bool() & (counts > min_frames)
    return starts[keep].tolist(), ends[keep].tolist()


def _mask_passed(
    passed: torch.Tensor,
    error_mask: torch.Tensor,
    smooth_interval: int,
    masked_intervals: List,
) -> torch.Tensor:
    """Smooth a pass/fail tensor, then apply the error mask and blank out the masked intervals."""
    passed = smooth(passed, smooth_interval)
    if error_mask is not None:
        passed = passed * error_mask
    if masked_intervals is not None:
        for start, end in masked_intervals:
            passed[start:end] = False
    return passed


def apply_threshold(
    tensor: torch.Tensor,
    threshold: float,
    low: bool = True,
    error_mask: torch.Tensor = None,
    min_frames: int = 0,
    smooth_interval: int = 0,
    masked_intervals: List = None,
):
    """
    Apply a hard threshold to a tensor and return indices of the intervals that passed.

    If `error_mask` is not `None`, the elements marked `False` are treated as if they did not pass the threshold.
    If `min_frames` is not 0, the intervals are additionally filtered by length.

    Parameters
    ----------
    tensor : torch.Tensor
        the (one-dimensional) tensor to apply the threshold to
    threshold : float
        the threshold
    low : bool, default True
        whether to select the intervals below the threshold (if `True`) or above (if `False`);
        values equal to the threshold pass in both cases
    error_mask : torch.Tensor, optional
        a boolean mask to apply to the results
    min_frames : int, default 0
        only intervals strictly longer than this number of frames are kept
    smooth_interval : int, default 0
        the interval length passed to `smooth` (applied to the thresholded values)
    masked_intervals : list, optional
        a list of `(start, end)` intervals to mask out

    Returns
    -------
    indices_start : list
        a list of indices of the first frames of the chosen intervals
    indices_end : list
        a list of (exclusive) end indices of the chosen intervals, i.e. last frame + 1

    """
    passed = tensor <= threshold if low else tensor >= threshold
    passed = _mask_passed(passed, error_mask, smooth_interval, masked_intervals)
    return _passing_intervals(passed, min_frames)


def apply_threshold_hysteresis(
    tensor: torch.Tensor,
    soft_threshold: float,
    hard_threshold: float,
    low: bool = True,
    error_mask: torch.Tensor = None,
    min_frames: int = 0,
    smooth_interval: int = 0,
    masked_intervals: List = None,
):
    """
    Apply a hysteresis threshold to a tensor and return indices of the intervals that passed.

    In the chosen intervals all values pass the soft threshold and at least one value passes the hard threshold.
    If `error_mask` is not `None`, the elements marked `False` are treated as if they did not pass the threshold.
    If `min_frames` is not 0, the intervals are additionally filtered by length.

    Parameters
    ----------
    tensor : torch.Tensor
        the (one-dimensional) tensor to apply the threshold to
    soft_threshold : float
        the soft threshold
    hard_threshold : float
        the hard threshold
    low : bool, default True
        whether to select values below the threshold (`True`) or above (`False`)
    error_mask : torch.Tensor, optional
        a boolean mask to apply to the results
    min_frames : int, default 0
        only intervals strictly longer than this number of frames are kept
    smooth_interval : int, default 0
        the interval length passed to `smooth` (applied to the soft-thresholded values)
    masked_intervals : list, default None
        a list of `(start, end)` intervals to mask out

    Returns
    -------
    indices_start : list
        a list of indices of the first frames of the chosen intervals
    indices_end : list
        a list of (exclusive) end indices of the chosen intervals, i.e. last frame + 1

    """
    if low:
        passed = tensor <= soft_threshold
        hard = tensor <= hard_threshold
    else:
        passed = tensor >= soft_threshold
        hard = tensor >= hard_threshold
    passed = _mask_passed(passed, error_mask, smooth_interval, masked_intervals)
    indices_start, indices_end = _passing_intervals(passed, min_frames)
    res_indices_start = []
    res_indices_end = []
    for start, end in zip(indices_start, indices_end):
        # keep the interval only if the hard threshold is reached somewhere inside
        if bool(hard[start:end].any()):
            res_indices_start.append(start)
            res_indices_end.append(end)
    return res_indices_start, res_indices_end


def apply_threshold_max(
    tensor: torch.Tensor,
    threshold: float,
    main_class: int,
    error_mask: torch.Tensor = None,
    min_frames: int = 0,
    smooth_interval: int = 0,
    masked_intervals: List = None,
):
    """
    Apply a max hysteresis threshold to a tensor and return indices of the intervals that passed.

    In the chosen intervals the values at the `main_class` index are larger than the others everywhere
    and at least one value at the `main_class` index passes the threshold.
    If `error_mask` is not `None`, the elements marked `False` are treated as if they did not pass the threshold.
    If `min_frames` is not 0, the intervals are additionally filtered by length.

    Parameters
    ----------
    tensor : torch.Tensor
        the tensor to apply the threshold to (of shape `(#classes, #frames)`)
    threshold : float
        the threshold (if `None`, the intervals are not filtered by value)
    main_class : int
        the class that conditions the soft threshold
    error_mask : torch.Tensor, optional
        a boolean mask to apply to the results
    min_frames : int, default 0
        only intervals strictly longer than this number of frames are kept
    smooth_interval : int, default 0
        the interval length passed to `smooth`
    masked_intervals : list, default None
        a list of `(start, end)` intervals to mask out

    Returns
    -------
    indices_start : list
        a list of indices of the first frames of the chosen intervals (along dimension 1 of input tensor)
    indices_end : list
        a list of (exclusive) end indices of the chosen intervals (along dimension 1 of input tensor)

    """
    _, best_class = torch.max(tensor, dim=0)
    passed = best_class == main_class
    passed = _mask_passed(passed, error_mask, smooth_interval, masked_intervals)
    indices_start, indices_end = _passing_intervals(passed, min_frames)
    if threshold is None:
        return indices_start, indices_end
    hard = tensor[main_class, :] > threshold
    res_indices_start = []
    res_indices_end = []
    for start, end in zip(indices_start, indices_end):
        if bool(hard[start:end].any()):
            res_indices_start.append(start)
            res_indices_end.append(end)
    return res_indices_start, res_indices_end


def strip_suffix(text: str, suffix: Iterable):
    """
    Strip a suffix from a string if it is contained in a list

    Only the first matching suffix is stripped.

    Parameters
    ----------
    text : str
        the main string
    suffix : iterable
        the list of suffixes to be stripped (`None` is treated as an empty list)

    Returns
    -------
    result : str
        the stripped string
    """

    if suffix is None:
        suffix = []
    for s in suffix:
        if text.endswith(s):
            # `text[:-len(s)]` would return "" for an empty suffix
            return text[: len(text) - len(s)]
    return text


def strip_prefix(text: str, prefix: Iterable):
    """
    Strip a prefix from a string if it is contained in a list

    Only the first matching prefix is stripped.

    Parameters
    ----------
    text : str
        the main string
    prefix : iterable
        the list of prefixes to be stripped (`None` is treated as an empty list)

    Returns
    -------
    result : str
        the stripped string
    """

    if prefix is None:
        prefix = []
    for s in prefix:
        if text.startswith(s):
            return text[len(s) :]
    return text


def rotation_matrix_2d(angles: torch.Tensor) -> torch.Tensor:
    """
    Create a tensor of 2D rotation matrices from a tensor of angles

    Parameters
    ----------
    angles : torch.Tensor
        a tensor of angles of arbitrary shape `(...)`

    Returns
    -------
    rotation_matrices : torch.Tensor
        a tensor of 2D rotation matrices of shape `(..., 2, 2)`
    """

    cos = torch.cos(angles)
    sin = torch.sin(angles)
    # row-major entries of [[cos, -sin], [sin, cos]]
    R = torch.stack([cos, -sin, sin, cos], dim=-1).reshape(*angles.shape, 2, 2)
    return R


def rotation_matrix_3d(alpha: torch.Tensor, beta: torch.Tensor, gamma: torch.Tensor):
    """
    Create a tensor of 3D rotation matrices from a tensor of angles

    The result is the product `Rx @ Ry @ Rz`. The constant entries are created with the dtype and device
    of the corresponding angle tensor.

    Parameters
    ----------
    alpha : torch.Tensor
        a tensor of rotation angles around the x axis of arbitrary shape `(...)`
    beta : torch.Tensor
        a tensor of rotation angles around the y axis of arbitrary shape `(...)`
    gamma : torch.Tensor
        a tensor of rotation angles around the z axis of arbitrary shape `(...)`

    Returns
    -------
    rotation_matrices : torch.Tensor
        a tensor of 3D rotation matrices of shape `(..., 3, 3)`
    """

    cos = torch.cos(alpha)
    sin = torch.sin(alpha)
    zero = torch.zeros_like(cos)
    one = torch.ones_like(cos)
    Rx = torch.stack(
        [one, zero, zero, zero, cos, -sin, zero, sin, cos],
        dim=-1,
    ).reshape(*alpha.shape, 3, 3)
    cos = torch.cos(beta)
    sin = torch.sin(beta)
    zero = torch.zeros_like(cos)
    one = torch.ones_like(cos)
    Ry = torch.stack(
        [cos, zero, sin, zero, one, zero, -sin, zero, cos],
        dim=-1,
    ).reshape(*beta.shape, 3, 3)
    cos = torch.cos(gamma)
    sin = torch.sin(gamma)
    zero = torch.zeros_like(cos)
    one = torch.ones_like(cos)
    Rz = torch.stack(
        [cos, -sin, zero, sin, cos, zero, zero, zero, one],
        dim=-1,
    ).reshape(*gamma.shape, 3, 3)
    R = torch.einsum("...ij,...jk,...kl->...il", Rx, Ry, Rz)
    return R


def correct_path(path, project_path):
    """
    Correct a path to a file or directory to be relative to the project path.

    Everything after the first `results` (or, failing that, `saved_datasets`) path component is re-rooted
    under `<project_path>/results`. Paths without such a component are returned normalized.

    Parameters
    ----------
    path : str
        the path to be corrected (non-string values are returned unchanged)
    project_path : str
        the path to the project root directory

    Returns
    -------
    corrected_path : str
        the corrected path

    """
    if not isinstance(path, str):
        return path
    path_split = os.path.normpath(path).split(os.path.sep)
    # match whole path components only, so that a name which merely contains
    # "saved_datasets" is not mistaken for the folder
    for name in ("results", "saved_datasets"):
        if name in path_split:
            ind = path_split.index(name) + 1
            # NOTE(review): "results" is used as the target folder for both
            # names, as in the original code - confirm this is intended for
            # "saved_datasets" paths
            return os.path.join(project_path, "results", *path_split[ind:])
    return os.path.sep.join(path_split)  # No need for correction


class TensorList(list):
    """
    A list of tensors that can send each element to a `torch` device
    """

    def to_device(self, device: torch.device):
        """
        Send each element of the list to a `torch` device (in place).

        Elements that define a `to_device` method are moved with it; everything else
        (e.g. `torch.Tensor`, which has no such method) is moved with `.to(device)`.
        """
        for i, x in enumerate(self):
            move = getattr(x, "to_device", None)
            self[i] = move(device) if move is not None else x.to(device)


def get_intervals(tensor: torch.Tensor) -> torch.Tensor:
    """
    Get a list of True group beginning and end indices from a boolean tensor

    Parameters
    ----------
    tensor : torch.Tensor
        a one-dimensional boolean tensor

    Returns
    -------
    intervals : torch.Tensor
        a tensor of shape `(#groups, 2)` with the first index and the (exclusive) end index of each group
    """

    values, counts = torch.unique_consecutive(tensor, return_counts=True)
    # the cumulative run lengths are the (exclusive) run ends
    ends = torch.cumsum(counts, dim=0)
    starts = ends - counts
    is_true = values.bool()
    return torch.stack([starts[is_true], ends[is_true]]).T


def smooth(tensor: torch.Tensor, smooth_interval: int = 0) -> torch.Tensor:
    """
    Get rid of jittering in a non-exclusive classification tensor

    First, fill intervals of 0 that are not longer than `smooth_interval`. Then, remove intervals of 1 that
    are not longer than `smooth_interval`. The input tensor is modified in place and returned.

    Parameters
    ----------
    tensor : torch.Tensor
        a one-dimensional tensor of zeros and ones (or booleans)
    smooth_interval : int, default 0
        the maximum length of the intervals to flip (if 0, the tensor is returned unchanged)

    Returns
    -------
    tensor : torch.Tensor
        the smoothed tensor (the same object as the input)
    """

    if smooth_interval == 0:
        return tensor
    # order matters: short gaps are filled before short positive runs are removed
    for run_value, fill_value in ((0, 1), (1, 0)):
        intervals = get_intervals(tensor == run_value)
        interval_lengths = intervals[:, 1] - intervals[:, 0]
        for start, end in intervals[interval_lengths <= smooth_interval]:
            tensor[start:end] = fill_value
    return tensor


class GaussianSmoothing(nn.Module):
    """Apply gaussian smoothing on a 1d tensor.

    Filtering is performed separately for each channel
    in the input using a depthwise convolution.
    Arguments:
        kernel_size (int): Size of the gaussian kernel.
        sigma (float): Standard deviation of the gaussian kernel.
    """

    def __init__(self, kernel_size: int = 15, sigma: float = 1.0) -> None:
        super().__init__()
        self.kernel_size = kernel_size

        positions = torch.arange(kernel_size).float()
        mean = (kernel_size - 1) / 2
        scale = 1 / (sigma * math.sqrt(2 * math.pi))
        # samples of the gaussian density; the kernel is NOT re-normalized,
        # so its values do not necessarily sum to 1
        kernel = scale * torch.exp(-(((positions - mean) / sigma) ** 2) / 2)

        # shape (1, 1, kernel_size); kept as a plain attribute (not a buffer)
        self.kernel = kernel.view(1, 1, *kernel.size())

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """
        Apply gaussian filter to input.

        Parameters
        ----------
        inputs : torch.Tensor
            Input to apply gaussian filter on (three-dimensional, channels along dimension 1).

        Returns
        -------
        filtered : torch.Tensor
            Filtered output.

        """
        _, c, _ = inputs.shape
        half = (self.kernel_size - 1) // 2
        inputs = F.pad(inputs, pad=(half, half), mode="reflect")
        # one copy of the kernel per channel (depthwise convolution)
        kernel = self.kernel.repeat(c, *[1] * (self.kernel.dim() - 1)).to(inputs.device)
        return F.conv1d(inputs, weight=kernel, groups=c)


def argrelmax(prob: np.ndarray, threshold: float = 0.7) -> List[int]:
    """
    Calculate arguments of relative maxima.

    The input array is not modified.

    Parameters
    ----------
    prob : np.ndarray
        boundary probability maps distributed in [0, 1], shape is (T)
    threshold : float, default 0.7
        ignore the peak whose value is under threshold

    Returns
    -------
    peak_idx : List[int]
        Index of peaks (the first frame is always included)

    """
    # ignore the values under threshold (on a copy, so the caller's array is kept intact)
    prob = np.asarray(prob)
    prob = np.where(prob < threshold, 0.0, prob)

    # calculate the relative maxima of boundary maps
    # treat the first frame as boundary; the last frame is never a peak
    # (`bool` instead of the `np.bool` alias, which newer NumPy versions no longer provide)
    peak = np.concatenate(
        [
            np.ones((1), dtype=bool),
            (prob[:-2] < prob[1:-1]) & (prob[2:] < prob[1:-1]),
            np.zeros((1), dtype=bool),
        ],
        axis=0,
    )

    peak_idx = np.where(peak)[0].tolist()

    return peak_idx


def decide_boundary_prob_with_similarity(x: torch.Tensor) -> torch.Tensor:
    """
    Decide action boundary probabilities based on adjacent frame similarities.

    Only the first element of the batch (`x[0]`) is used.

    Args:
        x: frame-wise video features (N, C, T)
    Return:
        boundary: action boundary probability (1, 1, T)
    """
    device = x.device

    # gaussian kernel on the distance between consecutive frames.
    diff = x[0, :, 1:] - x[0, :, :-1]
    similarity = torch.exp(-torch.norm(diff, dim=0) / (2 * 1.0))

    # define action starting point as action boundary.
    start = torch.ones(1).float().to(device)
    boundary = torch.cat([start, similarity])
    boundary = boundary.view(1, 1, -1)
    return boundary


class PostProcessor(object):
    """
    Post-process frame-wise action segmentation outputs

    The method is chosen by `name`: `"refinement_with_boundary"`, `"relabeling"` or `"smoothing"`.
    """

    def __init__(
        self,
        name: str,
        boundary_th: float = 0.7,
        theta_t: int = 15,
        kernel_size: int = 15,
    ) -> None:
        """
        Args:
            name: the post-processing method
            boundary_th: the threshold used to detect boundary peaks
            theta_t: the segment length threshold used by relabeling
            kernel_size: the gaussian kernel size used by smoothing
        """
        self.func = {
            "refinement_with_boundary": self._refinement_with_boundary,
            "relabeling": self._relabeling,
            "smoothing": self._smoothing,
        }
        assert name in self.func

        self.name = name
        self.boundary_th = boundary_th
        self.theta_t = theta_t
        self.kernel_size = kernel_size

        if name == "smoothing":
            self.filter = GaussianSmoothing(self.kernel_size)

    def _is_probability(self, x: np.ndarray) -> bool:
        """Check whether `x` (N, C, T) already looks like sigmoid (C == 1) or softmax (C > 1) output."""
        assert x.ndim == 3

        if x.shape[1] == 1:
            # sigmoid
            return bool(x.min() >= 0 and x.max() <= 1)
        # softmax
        _sum = np.sum(x, axis=1).astype(np.float32)
        _ones = np.ones_like(_sum, dtype=np.float32)
        return np.allclose(_sum, _ones)

    def _convert2probability(self, x: np.ndarray) -> np.ndarray:
        """
        Convert scores to probabilities (returned unchanged if they already are)

        Args: x (N, C, T)
        """
        assert x.ndim == 3

        if self._is_probability(x):
            return x
        if x.shape[1] == 1:
            # sigmoid
            prob = 1 / (1 + np.exp(-x))
        else:
            # softmax; subtracting the per-frame maximum avoids overflow in exp
            # without changing the result
            exp = np.exp(x - np.max(x, axis=1, keepdims=True))
            prob = exp / np.sum(exp, axis=1, keepdims=True)
        return prob.astype(np.float32)

    def _convert2label(self, x: np.ndarray) -> np.ndarray:
        """Convert labels (N, T) or scores (N, C, T) to an int64 label array (N, T)."""
        assert x.ndim == 2 or x.ndim == 3

        if x.ndim == 2:
            return x.astype(np.int64)
        if not self._is_probability(x):
            x = self._convert2probability(x)

        label = np.argmax(x, axis=1)
        return label.astype(np.int64)

    def _refinement_with_boundary(
        self,
        outputs: np.array,
        boundaries: np.ndarray,
    ) -> np.ndarray:
        """
        Get segments which is defined as the span b/w two boundaries,
        and decide their classes by majority vote.
        Args:
            outputs: numpy array. shape (N, C, T)
                the model output for frame-level class prediction.
            boundaries: numpy array. shape (N, 1, T)
                boundary prediction.
        Return:
            preds: np.array. shape (N, T)
                final class prediction considering boundaries.
        """

        preds = self._convert2label(outputs)
        boundaries = self._convert2probability(boundaries)

        for i, (output, pred, boundary) in enumerate(zip(outputs, preds, boundaries)):
            idx = argrelmax(boundary.squeeze(), threshold=self.boundary_th)

            # add the index of the last action ending
            T = pred.shape[0]
            idx.append(T)

            # majority vote
            for j in range(len(idx) - 1):
                segment = slice(idx[j], idx[j + 1])
                count = np.bincount(pred[segment])
                modes = np.where(count == count.max())[0]
                if len(modes) == 1 or outputs.ndim != 3:
                    # a unique majority class, or no scores to break the tie
                    # with (the first mode is chosen)
                    mode = modes[0]
                else:
                    # if more than one majority class exist, choose the one
                    # with the largest summed score (the first one on ties);
                    # this also works when all sums are non-positive
                    mode = max(modes, key=lambda m: output[m, segment].sum())

                preds[i, segment] = mode

        return preds

    def _relabeling(self, outputs: np.ndarray, **kwargs: np.ndarray) -> np.ndarray:
        """
        Relabeling small action segments with their previous action segment
        Args:
            output: the results of action segmentation. (N, T) or (N, C, T)
        Return:
            relabeled output. (N, T)

        Segments that are not longer than `self.theta_t` are relabeled.
        """

        preds = self._convert2label(outputs)

        for i in range(preds.shape[0]):
            # shape (T,)
            last = preds[i][0]
            cnt = 1  # length of the current run of equal labels
            for j in range(1, preds.shape[1]):
                if last == preds[i][j]:
                    cnt += 1
                    continue
                if cnt <= self.theta_t:
                    # NOTE(review): for a short run at the very start of the
                    # row `j - cnt - 1` is -1, so the label of the LAST frame
                    # is used - confirm this is intended
                    preds[i][j - cnt : j] = preds[i][j - cnt - 1]
                cnt = 1
                last = preds[i][j]

            # NOTE(review): here `j` is the last frame index, so the trailing
            # run is [j - cnt + 1, j + 1) while the slice below is shifted by
            # one frame; `j` is also undefined when T == 1. Kept as in the
            # original - confirm before changing.
            if cnt <= self.theta_t:
                preds[i][j - cnt : j] = preds[i][j - cnt - 1]

        return preds

    def _smoothing(self, outputs: np.ndarray, **kwargs: np.ndarray) -> np.ndarray:
        """
        Smoothing action probabilities with gaussian filter.
        Args:
            outputs: frame-wise action probabilities. (N, C, T)
        Return:
            predictions: final prediction. (N, T)
        """

        outputs = self._convert2probability(outputs)
        outputs = self.filter(torch.Tensor(outputs)).numpy()

        preds = self._convert2label(outputs)
        return preds

    def __call__(self, outputs, **kwargs: np.ndarray) -> np.ndarray:
        """Apply the post-processing method chosen at construction to `outputs`."""
        preds = self.func[self.name](outputs, **kwargs)
        return preds


def get_sentences_from_article(
    pdf_filepath: str, max_symbols=77, min_symbols=15
) -> List:
    """
    Read a pdf file, clean the text from references and split into sentences

    Parameters
    ----------
    pdf_filepath : str
        the path to the pdf file
    max_symbols : int, default 77
        sentences longer than this number of symbols will be omitted
    min_symbols : int, default 15
        sentences shorter than this number of symbols will be omitted

    Returns
    -------
    sentences : list
        a list of unique sentences in arbitrary order (empty if no text could be extracted)
    """

    # imported here so that `tika` is only required when this function is used
    from tika import parser

    rawText = parser.from_file(pdf_filepath)
    content = rawText.get("content")
    if not content:
        # nothing was extracted from the file
        return []
    # re-join words hyphenated at line breaks, then flatten the text
    text = content.replace("-\n", "")
    text = text.replace("\n", " ")
    # drop thousands separators inside numbers
    text = re.sub(r"(?<=\d),(?=\d)", "", text)
    # keep only the part before the reference list
    text = text.split("References")[0]
    # remove "(Author, 2020)"-style and "[12]"-style citations
    text = re.sub(r" \([\w&, ;]+\d{4}\)", "", text).strip()
    text = re.sub(r" \[\d+\]", "", text).strip()
    sentences = [x.strip() for x in text.split(".")]
    sentences = [
        x
        for x in sentences
        if not x.endswith(" al")
        and len(x) <= max_symbols
        and len(x) >= min_symbols
        and sum(c.isdigit() for c in x) < len(x) * 0.2
        and sum(len(w) < 2 for w in x.split()) < len(x) * 0.2
    ]
    return list(set(sentences))


def load_pickle(file_path: str):
    """
    Load a pickle file

    Parameters
    ----------
    file_path : str
        the path to the pickle file

    Returns
    -------
    obj : any
        the loaded object
    """

    # SECURITY: unpickling can execute arbitrary code - only use on trusted files
    with open(file_path, "rb") as f:
        obj = pickle.load(f)
    return obj


def binarize_data(data: list, max_frame: int = -1) -> np.array:
    """
    Convert annotations given as (start, stop, confusion) rows into a binary matrix

    Parameters
    ----------
    data : list
        the loaded annotation; `data[3][0]` is read as a per-class list of `(start, stop, confusion)` rows
    max_frame : int, default -1
        the number of time points (if -1, the maximum value found in the annotation is used)

    Returns
    -------
    label_array : np.ndarray
        a binary matrix of size `(N_classes, N_timepoints)`; only intervals with `confusion < 0.5` are marked
    """
    data = data[3][0]
    if max_frame == -1:
        # the annotation only needs to be scanned when no length is given
        all_annot = np.concatenate([arr for arr in data if len(arr)], axis=0)
        max_frame = np.max(all_annot)
    label_array = np.zeros((len(data), max_frame))
    for k, arr in enumerate(data):
        for start, stop, confusion in arr:
            if confusion < 0.5:
                label_array[k, start:stop] = 1
    return label_array
class TensorDict:
    """
    A class that handles indexing in a dictionary of tensors of the same length
    """

    def __init__(self, obj: Union[Dict, Iterable] = None) -> None:
        """
        Parameters
        ----------
        obj : dict | iterable, optional
            either a dictionary of torch.Tensor instances of the same length or a list/tuple of dictionaries with
            the same keys (if not passed, or if the sequence is empty, a blank TensorDict is initialized)

        Raises
        ------
        TypeError
            if `obj` is neither a dictionary nor a list/tuple, or if a dictionary value is not a tensor
        ValueError
            if the tensors in a dictionary have different lengths
        """

        if obj is None:
            obj = {}
        if isinstance(obj, (list, tuple)):
            if len(obj) == 0:
                # nothing to stack: behave like the blank constructor
                self.keys = []
                self.dict = {}
            else:
                # the keys of the first element define the keys of the TensorDict
                self.keys = list(obj[0].keys())
                collected = {key: [] for key in self.keys}
                with warnings.catch_warnings():
                    # torch.tensor(tensor) emits a UserWarning about copy construction
                    warnings.filterwarnings("ignore", category=UserWarning)
                    for element in obj:
                        for key in self.keys:
                            collected[key].append(torch.tensor(element[key]))
                self.dict = {}
                for key, values in collected.items():
                    self.dict[key] = torch.stack(values)
        elif isinstance(obj, dict):
            if not all(isinstance(value, torch.Tensor) for value in obj.values()):
                raise TypeError(
                    f"The values in the dictionary passed to TensorDict need to be torch.Tensor instances; "
                    f"got {[type(obj[key]) for key in obj]}"
                )
            lengths = [len(obj[key]) for key in obj]
            if not all(x == lengths[0] for x in lengths):
                raise ValueError(
                    f"The value tensors in the dictionary passed to TensorDict need to have the same length; "
                    f"got {lengths}"
                )
            # the dictionary is stored by reference, not copied
            self.dict = obj
            self.keys = list(self.dict.keys())
        else:
            raise TypeError(
                f"TensorDict can only be constructed from an iterable of dictionaries or from a dictionary "
                f"of tensors (got {type(obj)})"
            )
        if len(self.keys) > 0:
            self.type = type(self.dict[self.keys[0]])
        else:
            # `None` marks a blank TensorDict; `append` initializes it
            self.type = None

    def __len__(self) -> int:
        """Return the number of stored elements (0 for a blank TensorDict)."""
        if not self.keys:
            return 0
        return len(self.dict[self.keys[0]])

    def __getitem__(self, ind: Union[int, List]):
        """
        Index the TensorDict

        Parameters
        ----------
        ind : int | list
            the index/indices

        Returns
        -------
        dict | TensorDict
            a dictionary of the indexed values (if `ind` is an `int`) or a new TensorDict (otherwise)
        """

        selected = {key: value[ind] for key, value in self.dict.items()}
        # exact type check on purpose: only a plain int selects a single element
        if type(ind) is int:
            return selected
        return TensorDict(selected)

    def append(self, element: Dict) -> None:
        """
        Append an element

        The keys are validated before anything is modified, so a failed call leaves the TensorDict untouched.

        Parameters
        ----------
        element : dict
            a dictionary of tensors (one entry per key of the TensorDict)

        Raises
        ------
        ValueError
            if `element` is missing one of the keys of a non-blank TensorDict
        """

        if self.type is None:
            # blank TensorDict: the first element defines keys and value type
            self.dict = {k: v.unsqueeze(0) for k, v in element.items()}
            self.keys = list(element.keys())
            self.type = type(element[self.keys[0]])
            return
        if any(key not in element for key in self.keys):
            raise ValueError(
                f"The dictionary appended to TensorDict needs to have the same keys as the "
                f"TensorDict; got {element.keys()} and {self.keys}"
            )
        for key in self.keys:
            self.dict[key] = torch.cat([self.dict[key], element[key].unsqueeze(0)])

    def remove(self, indices: List) -> None:
        """
        Remove indexed elements

        Parameters
        ----------
        indices : list
            the indices to remove
        """

        keep = torch.ones(len(self), dtype=torch.bool)
        keep[indices] = False
        for key in list(self.dict):
            self.dict[key] = self.dict[key][keep]
A class that handles indexing in a dictionary of tensors of the same length
def __init__(self, obj: Union[Dict, Iterable] = None) -> None:
    """
    Parameters
    ----------
    obj : dict | iterable, optional
        either a dictionary of torch.Tensor instances of the same length or a list/tuple of dictionaries with
        the same keys (if not passed, or if the sequence is empty, a blank TensorDict is initialized)

    Raises
    ------
    TypeError
        if `obj` is neither a dictionary nor a list/tuple, or if a dictionary value is not a tensor
    ValueError
        if the tensors in a dictionary have different lengths
    """

    if obj is None:
        obj = {}
    if isinstance(obj, (list, tuple)):
        if len(obj) == 0:
            # nothing to stack: behave like the blank constructor
            self.keys = []
            self.dict = {}
        else:
            # the keys of the first element define the keys of the TensorDict
            self.keys = list(obj[0].keys())
            collected = {key: [] for key in self.keys}
            with warnings.catch_warnings():
                # torch.tensor(tensor) emits a UserWarning about copy construction
                warnings.filterwarnings("ignore", category=UserWarning)
                for element in obj:
                    for key in self.keys:
                        collected[key].append(torch.tensor(element[key]))
            self.dict = {}
            for key, values in collected.items():
                self.dict[key] = torch.stack(values)
    elif isinstance(obj, dict):
        if not all(isinstance(value, torch.Tensor) for value in obj.values()):
            raise TypeError(
                f"The values in the dictionary passed to TensorDict need to be torch.Tensor instances; "
                f"got {[type(obj[key]) for key in obj]}"
            )
        lengths = [len(obj[key]) for key in obj]
        if not all(x == lengths[0] for x in lengths):
            raise ValueError(
                f"The value tensors in the dictionary passed to TensorDict need to have the same length; "
                f"got {lengths}"
            )
        # the dictionary is stored by reference, not copied
        self.dict = obj
        self.keys = list(self.dict.keys())
    else:
        raise TypeError(
            f"TensorDict can only be constructed from an iterable of dictionaries or from a dictionary "
            f"of tensors (got {type(obj)})"
        )
    if len(self.keys) > 0:
        self.type = type(self.dict[self.keys[0]])
    else:
        # `None` marks a blank TensorDict
        self.type = None
Parameters
obj : dict | iterable, optional either a dictionary of torch.Tensor instances of the same length or an iterable of dictionaries with the same keys (if not passed, a blank TensorDict is initialized)
111 def append(self, element: Dict) -> None: 112 """ 113 Append an element 114 115 Parameters 116 ---------- 117 element : dict 118 a dictionary 119 """ 120 121 type_element = type(element[list(element.keys())[0]]) 122 if self.type is None: 123 self.dict = {k: v.unsqueeze(0) for k, v in element.items()} 124 self.keys = list(element.keys()) 125 self.type = type_element 126 else: 127 for key in self.keys: 128 if key not in element: 129 raise ValueError( 130 f"The dictionary appended to TensorDict needs to have the same keys as the " 131 f"TensorDict; got {element.keys()} and {self.keys}" 132 ) 133 self.dict[key] = torch.cat([self.dict[key], element[key].unsqueeze(0)])
Append an element
Parameters
element : dict a dictionary
135 def remove(self, indices: List) -> None: 136 """ 137 Remove indexed elements 138 139 Parameters 140 ---------- 141 indices : list 142 the indices to remove 143 """ 144 145 mask = torch.ones(len(self)) 146 mask[indices] = 0 147 mask = mask.bool() 148 for key, value in self.dict.items(): 149 self.dict[key] = value[mask]
Remove indexed elements
Parameters
indices : list the indices to remove
def apply_threshold(
    tensor: torch.Tensor,
    threshold: float,
    low: bool = True,
    error_mask: torch.Tensor = None,
    min_frames: int = 0,
    smooth_interval: int = 0,
    masked_intervals: List = None,
):
    """
    Apply a hard threshold to a tensor and return indices of the intervals that passed.

    If `error_mask` is not `None`, the elements marked `False` are treated as if they did not pass the threshold.
    If `min_frames` is not 0, the intervals are additionally filtered by length.

    Parameters
    ----------
    tensor : torch.Tensor
        the tensor to apply the threshold to (expected to be one-dimensional)
    threshold : float
        the threshold
    low : bool, default True
        whether to select the intervals below the threshold (if `True`) or above (if `False`);
        values equal to the threshold pass in both cases
    error_mask : torch.Tensor, optional
        a boolean mask to apply to the results
    min_frames : int, default 0
        only intervals strictly longer than this number of frames are kept
    smooth_interval : int, default 0
        the interval length passed to `smooth` to remove short gaps and short detections from the
        thresholded mask
    masked_intervals : list, optional
        a list of `(start, end)` intervals to mask out (`end` is exclusive)

    Returns
    -------
    indices_start : list
        a list of indices of the first frames of the chosen intervals
    indices_end : list
        a list of indices one past the last frames of the chosen intervals

    """
    if masked_intervals is None:
        masked_intervals = []
    if low:
        p = tensor <= threshold
    else:
        p = tensor >= threshold
    p = smooth(p, smooth_interval)
    if error_mask is not None:
        p = p * error_mask
    for start, end in masked_intervals:
        p[start:end] = False
    # runs of equal values: run i covers [run_ends[i] - counts[i], run_ends[i])
    values, counts = torch.unique_consecutive(p, return_counts=True)
    run_ends = torch.cumsum(counts, dim=0)
    run_starts = run_ends - counts
    # keep the runs that passed the threshold and are long enough
    long_indices = torch.where(values * (counts > min_frames))[0]
    return run_starts[long_indices].tolist(), run_ends[long_indices].tolist()
Apply a hard threshold to a tensor and return indices of the intervals that passed.
If error_mask is not None, the elements marked False are treated as if they did not pass the threshold.
If min_frames is not 0, the intervals are additionally filtered by length.
Parameters
tensor : torch.Tensor
the tensor to apply the threshold to
threshold : float
the threshold
low : bool, default True
whether to select the intervals below the threshold (if True) or above (if False)
error_mask : torch.Tensor, optional
a boolean mask to apply to the results
min_frames : int, default 0
the minimum number of frames in the resulting intervals (shorter intervals are discarded)
smooth_interval : int, default 0
the number of frames to smooth the tensor with before applying the threshold
masked_intervals : list, optional
a list of intervals to mask out
Returns
indices_start : list a list of indices of the first frames of the chosen intervals indices_end : list a list of indices of the last frames of the chosen intervals
def apply_threshold_hysteresis(
    tensor: torch.Tensor,
    soft_threshold: float,
    hard_threshold: float,
    low: bool = True,
    error_mask: torch.Tensor = None,
    min_frames: int = 0,
    smooth_interval: int = 0,
    masked_intervals: List = None,
):
    """
    Apply a hysteresis threshold to a tensor and return indices of the intervals that passed.

    In the chosen intervals all values pass the soft threshold (up to the effect of smoothing) and at least
    one value passes the hard threshold.
    If `error_mask` is not `None`, the elements marked `False` are treated as if they did not pass the threshold.
    If `min_frames` is not 0, the intervals are additionally filtered by length.

    Parameters
    ----------
    tensor : torch.Tensor
        the tensor to apply the threshold to (expected to be one-dimensional)
    soft_threshold : float
        the soft threshold
    hard_threshold : float
        the hard threshold
    low : bool, default True
        whether to select values below the threshold (`True`) or above (`False`);
        values equal to a threshold pass it in both cases
    error_mask : torch.Tensor, optional
        a boolean mask to apply to the results
    min_frames : int, default 0
        only intervals strictly longer than this number of frames are kept
    smooth_interval : int, default 0
        the interval length passed to `smooth` to remove short gaps and short detections from the
        soft-thresholded mask
    masked_intervals : list, default None
        a list of `(start, end)` intervals to mask out (`end` is exclusive)

    Returns
    -------
    indices_start : list
        a list of indices of the first frames of the chosen intervals
    indices_end : list
        a list of indices one past the last frames of the chosen intervals

    """
    if masked_intervals is None:
        masked_intervals = []
    if low:
        p = tensor <= soft_threshold
        hard = tensor <= hard_threshold
    else:
        p = tensor >= soft_threshold
        hard = tensor >= hard_threshold
    p = smooth(p, smooth_interval)
    if error_mask is not None:
        p = p * error_mask
    for start, end in masked_intervals:
        p[start:end] = False
    # runs of equal values: run i covers [run_ends[i] - counts[i], run_ends[i])
    values, counts = torch.unique_consecutive(p, return_counts=True)
    run_ends = torch.cumsum(counts, dim=0)
    run_starts = run_ends - counts
    long_indices = torch.where(values * (counts > min_frames))[0]
    res_indices_start = []
    res_indices_end = []
    for start, end in zip(
        run_starts[long_indices].tolist(), run_ends[long_indices].tolist()
    ):
        # an interval is only reported if the hard threshold is passed somewhere inside it
        if torch.sum(hard[start:end]) > 0:
            res_indices_start.append(start)
            res_indices_end.append(end)
    return res_indices_start, res_indices_end
Apply a hysteresis threshold to a tensor and return indices of the intervals that passed.
In the chosen intervals all values pass the soft threshold and at least one value passes the hard threshold.
If error_mask is not None, the elements marked False are treated as if they did not pass the threshold.
If min_frames is not 0, the intervals are additionally filtered by length.
Parameters
tensor : torch.Tensor
the tensor to apply the threshold to
soft_threshold : float
the soft threshold
hard_threshold : float
the hard threshold
low : bool, default True
whether to select values below the threshold (True) or above (False)
error_mask : torch.Tensor, optional
a boolean mask to apply to the results
min_frames : int, default 0
the minimum number of frames in the resulting intervals (shorter intervals are discarded)
smooth_interval : int, default 0
the number of frames to smooth the tensor with
masked_intervals : list, default None
a list of intervals to mask out
Returns
indices_start : list a list of indices of the first frames of the chosen intervals indices_end : list a list of indices of the last frames of the chosen intervals
def apply_threshold_max(
    tensor: torch.Tensor,
    threshold: float,
    main_class: int,
    error_mask: torch.Tensor = None,
    min_frames: int = 0,
    smooth_interval: int = 0,
    masked_intervals: List = None,
):
    """
    Apply a max hysteresis threshold to a tensor and return indices of the intervals that passed.

    In the chosen intervals the values at the `main_class` index are the largest among the classes
    (up to the effect of smoothing) and at least one value at the `main_class` index passes the threshold.
    If `error_mask` is not `None`, the elements marked `False` are treated as if they did not pass the threshold.
    If `min_frames` is not 0, the intervals are additionally filtered by length.

    Parameters
    ----------
    tensor : torch.Tensor
        the tensor to apply the threshold to (of shape `(#classes, #frames)`)
    threshold : float
        the threshold (if `None`, the intervals are not filtered by value)
    main_class : int
        the class that conditions the soft threshold
    error_mask : torch.Tensor, optional
        a boolean mask to apply to the results
    min_frames : int, default 0
        only intervals strictly longer than this number of frames are kept
    smooth_interval : int, default 0
        the interval length passed to `smooth` to remove short gaps and short detections from the
        per-frame winner mask
    masked_intervals : list, default None
        a list of `(start, end)` intervals to mask out (`end` is exclusive)

    Returns
    -------
    indices_start : list
        a list of indices of the first frames of the chosen intervals (along dimension 1 of input tensor)
    indices_end : list
        a list of indices one past the last frames of the chosen intervals (along dimension 1 of input tensor)

    """
    if masked_intervals is None:
        masked_intervals = []
    _, winners = torch.max(tensor, dim=0)
    p = winners == main_class
    p = smooth(p, smooth_interval)
    if error_mask is not None:
        p = p * error_mask
    for start, end in masked_intervals:
        p[start:end] = False
    # runs of equal values: run i covers [run_ends[i] - counts[i], run_ends[i])
    values, counts = torch.unique_consecutive(p, return_counts=True)
    run_ends = torch.cumsum(counts, dim=0)
    run_starts = run_ends - counts
    long_indices = torch.where(values * (counts > min_frames))[0]
    indices_start = run_starts[long_indices].tolist()
    indices_end = run_ends[long_indices].tolist()
    if threshold is None:
        return indices_start, indices_end
    res_indices_start = []
    res_indices_end = []
    hard = tensor[main_class, :] > threshold
    for start, end in zip(indices_start, indices_end):
        # an interval is only reported if the threshold is exceeded somewhere inside it
        if torch.sum(hard[start:end]) > 0:
            res_indices_start.append(start)
            res_indices_end.append(end)
    return res_indices_start, res_indices_end
Apply a max hysteresis threshold to a tensor and return indices of the intervals that passed.
In the chosen intervals the values at the main_class index are larger than the others everywhere
and at least one value at the main_class index passes the threshold.
If error_mask is not None, the elements marked False are treated as if they did not pass the threshold.
If min_frames is not 0, the intervals are additionally filtered by length.
Parameters
tensor : torch.Tensor
the tensor to apply the threshold to (of shape (#classes, #frames))
threshold : float
the threshold
main_class : int
the class that conditions the soft threshold
error_mask : torch.Tensor, optional
a boolean mask to apply to the results
min_frames : int, default 0
the minimum number of frames in the resulting intervals (shorter intervals are discarded)
smooth_interval : int, default 0
the number of frames to smooth the tensor with
masked_intervals : list, default None
a list of intervals to mask out
Returns
indices_start : list a list of indices of the first frames of the chosen intervals (along dimension 1 of input tensor) indices_end : list a list of indices of the last frames of the chosen intervals (along dimension 1 of input tensor)
def strip_suffix(text: str, suffix: Iterable):
    """
    Strip a suffix from a string if it is contained in a list

    Only the first matching suffix (in iteration order) is removed.

    Parameters
    ----------
    text : str
        the main string
    suffix : iterable
        the list of suffixes to be stripped (`None` is treated as an empty list)

    Returns
    -------
    result : str
        the stripped string
    """

    if suffix is None:
        suffix = []
    for s in suffix:
        if text.endswith(s):
            # slicing by an explicit end index keeps the text intact for an empty suffix
            # (`text[:-0]` would return an empty string)
            return text[: len(text) - len(s)]
    return text
Strip a suffix from a string if it is contained in a list
Parameters
text : str the main string suffix : iterable the list of suffixes to be stripped
Returns
result : str the stripped string
def strip_prefix(text: str, prefix: Iterable):
    """
    Strip a prefix from a string if it is contained in a list

    Only the first matching prefix (in iteration order) is removed.

    Parameters
    ----------
    text : str
        the main string
    prefix : iterable
        the list of prefixes to be stripped (`None` is treated as an empty list)

    Returns
    -------
    result : str
        the stripped string
    """

    candidates = [] if prefix is None else prefix
    matched = next((p for p in candidates if text.startswith(p)), None)
    if matched is None:
        return text
    return text[len(matched) :]
Strip a prefix from a string if it is contained in a list
Parameters
text : str the main string prefix : iterable the list of prefixes to be stripped
Returns
result : str the stripped string
def rotation_matrix_2d(angles: torch.Tensor) -> torch.Tensor:
    """
    Create a tensor of 2D rotation matrices from a tensor of angles

    Each matrix has the form `[[cos, -sin], [sin, cos]]`.

    Parameters
    ----------
    angles : torch.Tensor
        a tensor of angles of arbitrary shape `(...)`

    Returns
    -------
    rotation_matrices : torch.Tensor
        a tensor of 2D rotation matrices of shape `(..., 2, 2)`
    """

    c, s = torch.cos(angles), torch.sin(angles)
    top_row = torch.stack([c, -s], dim=-1)
    bottom_row = torch.stack([s, c], dim=-1)
    return torch.stack([top_row, bottom_row], dim=-2)
Create a tensor of 2D rotation matrices from a tensor of angles
Parameters
angles : torch.Tensor
a tensor of angles of arbitrary shape (...)
Returns
rotation_matrices : torch.Tensor
a tensor of 2D rotation matrices of shape (..., 2, 2)
def rotation_matrix_3d(alpha: torch.Tensor, beta: torch.Tensor, gamma: torch.Tensor):
    """
    Create a tensor of 3D rotation matrices from a tensor of angles

    The result is the product `Rx(alpha) @ Ry(beta) @ Rz(gamma)`. The constant entries of the
    matrices are created with the dtype and device of the corresponding angle tensor.

    Parameters
    ----------
    alpha : torch.Tensor
        a tensor of rotation angles around the x axis of arbitrary shape `(...)`
    beta : torch.Tensor
        a tensor of rotation angles around the y axis of arbitrary shape `(...)`
    gamma : torch.Tensor
        a tensor of rotation angles around the z axis of arbitrary shape `(...)`

    Returns
    -------
    rotation_matrices : torch.Tensor
        a tensor of 3D rotation matrices of shape `(..., 3, 3)`
    """

    cos = torch.cos(alpha)
    sin = torch.sin(alpha)
    # `*_like` keeps the constants on the same device and in the same dtype as the angles
    one = torch.ones_like(cos)
    zero = torch.zeros_like(cos)
    Rx = torch.stack(
        [one, zero, zero, zero, cos, -sin, zero, sin, cos],
        dim=-1,
    ).reshape(*alpha.shape, 3, 3)
    cos = torch.cos(beta)
    sin = torch.sin(beta)
    one = torch.ones_like(cos)
    zero = torch.zeros_like(cos)
    Ry = torch.stack(
        [cos, zero, sin, zero, one, zero, -sin, zero, cos],
        dim=-1,
    ).reshape(*beta.shape, 3, 3)
    cos = torch.cos(gamma)
    sin = torch.sin(gamma)
    one = torch.ones_like(cos)
    zero = torch.zeros_like(cos)
    Rz = torch.stack(
        [cos, -sin, zero, sin, cos, zero, zero, zero, one],
        dim=-1,
    ).reshape(*gamma.shape, 3, 3)
    R = torch.einsum("...ij,...jk,...kl->...il", Rx, Ry, Rz)
    return R
Create a tensor of 3D rotation matrices from a tensor of angles
Parameters
alpha : torch.Tensor
a tensor of rotation angles around the x axis of arbitrary shape (...)
beta : torch.Tensor
a tensor of rotation angles around the y axis of arbitrary shape (...)
gamma : torch.Tensor
a tensor of rotation angles around the z axis of arbitrary shape (...)
Returns
rotation_matrices : torch.Tensor
a tensor of 3D rotation matrices of shape (..., 3, 3)
def correct_path(path, project_path):
    """
    Correct a path to a file or directory to be relative to the project path.

    The part of the path after the first `results` (or, failing that, `saved_datasets`) component
    is re-rooted under the folder of the same name inside `project_path`. Paths that contain
    neither component are only normalized.

    Parameters
    ----------
    path : str
        the path to be corrected (values that are not strings are returned unchanged)
    project_path : str
        the path to the project root directory

    Returns
    -------
    corrected_path : str
        the corrected path

    """
    if not isinstance(path, str):
        return path
    path_split = os.path.normpath(path).split(os.path.sep)
    # match whole path components only: a substring test on the raw string would also accept
    # e.g. "saved_datasets_old" and then fail in `index` below
    for name in ("results", "saved_datasets"):
        if name in path_split:
            break
    else:
        return os.path.sep.join(path_split)  # No need for correction
    ind = path_split.index(name) + 1
    # re-root under the folder that was found, so saved datasets do not end up in "results"
    return os.path.join(project_path, name, *path_split[ind:])
Correct a path to a file or directory to be relative to the project path.
Parameters
path : str the path to be corrected project_path : str the path to the project root directory
Returns
corrected_path : str the corrected path
class TensorList(list):
    """
    A list of tensors that can send each element to a `torch` device
    """

    def to_device(self, device: torch.device):
        """
        Send each element of the list to a `torch` device.

        The list is modified in place. Elements that define their own `to_device` method are moved
        with it; everything else (e.g. `torch.Tensor` instances, which have no such method) is
        moved with `.to(device)`.

        Parameters
        ----------
        device : torch.device
            the target device
        """
        for i, x in enumerate(self):
            mover = getattr(x, "to_device", None)
            self[i] = mover(device) if mover is not None else x.to(device)
A list of tensors that can send each element to a torch device
def get_intervals(tensor: torch.Tensor) -> torch.Tensor:
    """
    Get a list of True group beginning and end indices from a boolean tensor

    Parameters
    ----------
    tensor : torch.Tensor
        a one-dimensional boolean tensor

    Returns
    -------
    intervals : torch.Tensor
        an integer tensor of shape `(#groups, 2)` with one `(start, end)` row per group of
        consecutive `True` values (`end` is exclusive); the shape is `(0, 2)` if there are no groups
    """

    # runs of equal values: run i covers [run_ends[i] - counts[i], run_ends[i])
    values, counts = torch.unique_consecutive(tensor, return_counts=True)
    run_ends = torch.cumsum(counts, dim=0)
    run_starts = run_ends - counts
    true_indices = torch.where(values)[0]
    return torch.stack([run_starts[true_indices], run_ends[true_indices]]).T
Get a list of True group beginning and end indices from a boolean tensor
def smooth(tensor: torch.Tensor, smooth_interval: int = 0) -> torch.Tensor:
    """
    Get rid of jittering in a non-exclusive classification tensor

    First, intervals of 0 that are at most `smooth_interval` frames long are filled with 1. Then,
    intervals of 1 that are at most `smooth_interval` frames long are filled with 0. The input
    tensor is modified in place and also returned.

    Parameters
    ----------
    tensor : torch.Tensor
        the tensor to smooth
    smooth_interval : int, default 0
        the maximum length of the intervals to remove (if 0, the tensor is returned untouched)

    Returns
    -------
    tensor : torch.Tensor
        the smoothed tensor (the same object as the input)
    """

    if smooth_interval == 0:
        return tensor
    # the second pass looks at the tensor as it is after the first pass has filled the short gaps
    for short_value, fill_value in ((0, 1), (1, 0)):
        runs = get_intervals(tensor == short_value)
        run_lengths = torch.tensor([run[1] - run[0] for run in runs])
        for start, end in runs[run_lengths <= smooth_interval]:
            tensor[start:end] = fill_value
    return tensor
Get rid of jittering in a non-exclusive classification tensor
First, remove intervals of 0 shorter than smooth_interval. Then, remove intervals of 1 shorter than
smooth_interval.
class GaussianSmoothing(nn.Module):
    """Apply gaussian smoothing on a 1d tensor.

    Filtering is performed separately for each channel
    in the input using a depthwise convolution.
    Arguments:
        kernel_size (int): Size of the gaussian kernel.
        sigma (float): Standard deviation of the gaussian kernel.
    """

    def __init__(self, kernel_size: int = 15, sigma: float = 1.0) -> None:
        super().__init__()
        self.kernel_size = kernel_size

        # Sample the gaussian density at the integer positions 0 .. kernel_size - 1,
        # centered in the middle of the kernel.
        positions = torch.arange(kernel_size).float()
        mean = (kernel_size - 1) / 2
        scale = 1 / (sigma * math.sqrt(2 * math.pi))
        kernel = scale * torch.exp(-(((positions - mean) / sigma) ** 2) / 2)

        # NOTE: the kernel is not re-normalized after sampling, so its values only sum to
        # (approximately) 1 when kernel_size is large compared to sigma.
        # Stored as a plain attribute (not a buffer) with shape (1, 1, kernel_size).
        self.kernel = kernel.view(1, 1, *kernel.size())

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """
        Apply gaussian filter to input.

        Parameters
        ----------
        inputs : torch.Tensor
            Input to apply gaussian filter on, of shape `(N, C, T)`.

        Returns
        -------
        filtered : torch.Tensor
            Filtered output, of the same shape as the input.

        """
        _, c, _ = inputs.shape
        # Total padding is kernel_size - 1 so that the length is preserved for even kernel
        # sizes as well; for odd sizes both sides get (kernel_size - 1) // 2.
        inputs = F.pad(
            inputs,
            pad=((self.kernel_size - 1) // 2, self.kernel_size // 2),
            mode="reflect",
        )
        # one copy of the kernel per channel (depthwise convolution)
        kernel = self.kernel.repeat(c, *[1] * (self.kernel.dim() - 1)).to(inputs.device)
        if inputs.is_floating_point():
            # conv1d needs the weight and the input to share a dtype
            kernel = kernel.to(inputs.dtype)
        return F.conv1d(inputs, weight=kernel, groups=c)
Apply gaussian smoothing on a 1d tensor.
Filtering is performed separately for each channel in the input using a depthwise convolution. Arguments: channels (int, sequence): Number of channels of the input tensors. Output will have this number of channels as well. kernel_size (int, sequence): Size of the gaussian kernel. sigma (float, sequence): Standard deviation of the gaussian kernel.
602 def __init__(self, kernel_size: int = 15, sigma: float = 1.0) -> None: 603 super().__init__() 604 self.kernel_size = kernel_size 605 606 # The gaussian kernel is the product of the 607 # gaussian function of each dimension. 608 kernel = 1 609 meshgrid = torch.meshgrid(torch.arange(kernel_size))[0].float() 610 611 mean = (kernel_size - 1) / 2 612 kernel = kernel / (sigma * math.sqrt(2 * math.pi)) 613 kernel = kernel * torch.exp(-(((meshgrid - mean) / sigma) ** 2) / 2) 614 615 # Make sure sum of values in gaussian kernel equals 1. 616 # kernel = kernel / torch.max(kernel) 617 618 self.kernel = kernel.view(1, 1, *kernel.size())
Initialize internal Module state, shared by both nn.Module and ScriptModule.
620 def forward(self, inputs: torch.Tensor) -> torch.Tensor: 621 """ 622 Apply gaussian filter to input. 623 624 Parameters 625 ---------- 626 inputs : torch.Tensor 627 Input to apply gaussian filter on. 628 629 Returns 630 ------- 631 filtered : torch.Tensor) 632 Filtered output. 633 634 """ 635 _, c, _ = inputs.shape 636 inputs = F.pad( 637 inputs, 638 pad=((self.kernel_size - 1) // 2, (self.kernel_size - 1) // 2), 639 mode="reflect", 640 ) 641 kernel = self.kernel.repeat(c, *[1] * (self.kernel.dim() - 1)).to(inputs.device) 642 return F.conv1d(inputs, weight=kernel, groups=c)
Apply gaussian filter to input.
Parameters
inputs : torch.Tensor Input to apply gaussian filter on.
Returns
filtered : torch.Tensor Filtered output.
def argrelmax(prob: np.ndarray, threshold: float = 0.7) -> List[int]:
    """
    Calculate arguments of relative maxima.

    The first frame is always reported as a peak and the last frame never is. The input array
    is left untouched.

    Parameters
    ----------
    prob : np.ndarray
        boundary probability maps distributed in [0, 1], shape is (T)
    threshold : float, default 0.7
        ignore the peak whose value is under threshold

    Returns
    -------
    peak_idx : List[int]
        Index of peaks for each batch

    """
    # work on a copy so that the caller's array is not modified;
    # atleast_1d covers a 0-d input (e.g. a squeezed single-frame map)
    prob = np.atleast_1d(np.array(prob, copy=True))

    # ignore the values under threshold
    prob[prob < threshold] = 0.0

    # calculate the relative maxima of boundary maps
    # treat the first frame as boundary
    # (the builtin `bool` replaces the `np.bool` alias, which no longer exists in NumPy >= 1.24)
    peak = np.concatenate(
        [
            np.ones((1), dtype=bool),
            (prob[:-2] < prob[1:-1]) & (prob[2:] < prob[1:-1]),
            np.zeros((1), dtype=bool),
        ],
        axis=0,
    )

    peak_idx = np.where(peak)[0].tolist()

    return peak_idx
Calculate arguments of relative maxima.
Parameters
prob : np.ndarray boundary probability maps distributed in [0, 1], shape is (T) threshold : float, default 0.7 ignore the peak whose value is under threshold
Returns
peak_idx : List[int] Index of peaks for each batch
def decide_boundary_prob_with_similarity(x: torch.Tensor) -> torch.Tensor:
    """
    Decide action boundary probabilities based on adjacent frame similarities.

    Only the first element of the batch is used. The first frame always gets the value 1 and
    every other frame gets `exp(-d / 2)`, where `d` is the Euclidean distance between its
    features and the features of the previous frame.

    Args:
        x: frame-wise video features (N, C, T)
    Return:
        boundary: action boundary probability (1, 1, T)
    """
    features = x[0]

    # gaussian kernel on the distance between consecutive frames
    step = features[:, 1:] - features[:, :-1]
    distance = torch.norm(step, dim=0)
    similarity = torch.exp(-distance / (2 * 1.0))

    # define action starting point as action boundary.
    first_frame = torch.ones(1).float().to(x.device)
    return torch.cat([first_frame, similarity]).view(1, 1, -1)
Decide action boundary probabilities based on adjacent frame similarities. Args: x: frame-wise video features (N, C, T) Return: boundary: action boundary probability (N, 1, T)
class PostProcessor(object):
    """
    Post-process frame-wise action segmentation outputs.

    The instance is callable; the processing that is applied is selected by `name` at construction
    time: `"refinement_with_boundary"`, `"relabeling"` or `"smoothing"`.
    """

    def __init__(
        self,
        name: str,
        boundary_th: float = 0.7,
        theta_t: int = 15,
        kernel_size: int = 15,
    ) -> None:
        """
        Args:
            name: the post-processing method (one of the keys of `self.func`)
            boundary_th: the threshold under which boundary peaks are ignored
            theta_t: segments of at most this many frames are relabeled by "relabeling"
            kernel_size: the size of the gaussian kernel used by "smoothing"
        """
        self.func = {
            "refinement_with_boundary": self._refinement_with_boundary,
            "relabeling": self._relabeling,
            "smoothing": self._smoothing,
        }
        assert name in self.func, f"unknown post-processing method: {name}"

        self.name = name
        self.boundary_th = boundary_th
        self.theta_t = theta_t
        self.kernel_size = kernel_size

        # the filter is only needed (and only created) for gaussian smoothing
        if name == "smoothing":
            self.filter = GaussianSmoothing(self.kernel_size)

    def _is_probability(self, x: np.ndarray) -> bool:
        """
        Check whether x (N, C, T) already holds probabilities: values in [0, 1] for a single
        channel, values summing to 1 over the channels otherwise.
        """
        assert x.ndim == 3

        if x.shape[1] == 1:
            # sigmoid
            return bool(x.min() >= 0 and x.max() <= 1)
        # softmax
        _sum = np.sum(x, axis=1).astype(np.float32)
        _ones = np.ones_like(_sum, dtype=np.float32)
        return np.allclose(_sum, _ones)

    def _convert2probability(self, x: np.ndarray) -> np.ndarray:
        """
        Convert raw outputs to probabilities (inputs that already are probabilities are returned as is).
        Args: x (N, C, T)
        """
        assert x.ndim == 3

        if self._is_probability(x):
            return x
        if x.shape[1] == 1:
            # sigmoid
            prob = 1 / (1 + np.exp(-x))
        else:
            # softmax; subtracting the per-frame maximum leaves the result unchanged
            # and keeps np.exp from overflowing on large values
            exp = np.exp(x - np.max(x, axis=1, keepdims=True))
            prob = exp / np.expand_dims(np.sum(exp, axis=1), 1)
        return prob.astype(np.float32)

    def _convert2label(self, x: np.ndarray) -> np.ndarray:
        """
        Convert outputs to integer labels: (N, T) inputs are cast to int64 (a copy is returned),
        (N, C, T) inputs are reduced with an argmax over the channels.
        """
        assert x.ndim == 2 or x.ndim == 3

        if x.ndim == 2:
            return x.astype(np.int64)
        if not self._is_probability(x):
            x = self._convert2probability(x)
        label = np.argmax(x, axis=1)
        return label.astype(np.int64)

    def _refinement_with_boundary(
        self,
        outputs: np.array,
        boundaries: np.ndarray,
    ) -> np.ndarray:
        """
        Get segments which is defined as the span b/w two boundaries,
        and decide their classes by majority vote.
        Args:
            outputs: numpy array. shape (N, C, T)
                the model output for frame-level class prediction.
            boundaries: numpy array. shape (N, 1, T)
                boundary prediction.
        Return:
            preds: np.array. shape (N, T)
                final class prediction considering boundaries.
        """

        preds = self._convert2label(outputs)
        boundaries = self._convert2probability(boundaries)

        for i, (output, pred, boundary) in enumerate(zip(outputs, preds, boundaries)):
            idx = argrelmax(boundary.squeeze(), threshold=self.boundary_th)

            # add the index of the last action ending
            T = pred.shape[0]
            idx.append(T)

            # majority vote
            for j in range(len(idx) - 1):
                start, stop = idx[j], idx[j + 1]
                count = np.bincount(pred[start:stop])
                modes = np.where(count == count.max())[0]
                # default: the only mode, or the first one when several classes tie
                mode = modes[0]
                if len(modes) > 1 and outputs.ndim == 3:
                    # if more than one majority class exist, take the one with the largest
                    # summed output; starting from -inf guarantees that a class is chosen
                    # even when every sum is zero or negative
                    prob_sum_max = -np.inf
                    for m in modes:
                        prob_sum = output[m, start:stop].sum()
                        if prob_sum_max < prob_sum:
                            mode = m
                            prob_sum_max = prob_sum

                preds[i, start:stop] = mode

        return preds

    @staticmethod
    def _merge_into_previous(pred: np.ndarray, start: int, stop: int) -> None:
        """Give the frames [start, stop) the label of the frame just before them, if there is one."""
        if start > 0:
            pred[start:stop] = pred[start - 1]

    def _relabeling(self, outputs: np.ndarray, **kwargs: np.ndarray) -> np.ndarray:
        """
        Relabeling small action segments with their previous action segment

        Segments of at most `theta_t` frames get the label of the frame that precedes them.
        A short segment at the very beginning has no previous segment and is left unchanged.
        Args:
            output: the results of action segmentation. (N, T) or (N, C, T)
        Return:
            relabeled output. (N, T)
        """

        preds = self._convert2label(outputs)

        for pred in preds:
            # shape (T,); `pred` is a view, so the assignments below modify `preds`
            n_frames = pred.shape[0]
            if n_frames == 0:
                continue
            last = pred[0]
            cnt = 1
            for j in range(1, n_frames):
                if pred[j] == last:
                    cnt += 1
                    continue
                # the segment [j - cnt, j) has just ended
                if cnt <= self.theta_t:
                    self._merge_into_previous(pred, j - cnt, j)
                cnt = 1
                last = pred[j]

            # the last segment ends with the sequence: it covers [n_frames - cnt, n_frames)
            if cnt <= self.theta_t:
                self._merge_into_previous(pred, n_frames - cnt, n_frames)

        return preds

    def _smoothing(self, outputs: np.ndarray, **kwargs: np.ndarray) -> np.ndarray:
        """
        Smoothing action probabilities with gaussian filter.
        Args:
            outputs: frame-wise action probabilities. (N, C, T)
        Return:
            predictions: final prediction. (N, T)
        """

        outputs = self._convert2probability(outputs)
        outputs = self.filter(torch.Tensor(outputs)).numpy()

        preds = self._convert2label(outputs)
        return preds

    def __call__(self, outputs, **kwargs: np.ndarray) -> np.ndarray:
        """Apply the post-processing method selected at construction time."""
        preds = self.func[self.name](outputs, **kwargs)
        return preds
703 def __init__( 704 self, 705 name: str, 706 boundary_th: int = 0.7, 707 theta_t: int = 15, 708 kernel_size: int = 15, 709 ) -> None: 710 self.func = { 711 "refinement_with_boundary": self._refinement_with_boundary, 712 "relabeling": self._relabeling, 713 "smoothing": self._smoothing, 714 } 715 assert name in self.func 716 717 self.name = name 718 self.boundary_th = boundary_th 719 self.theta_t = theta_t 720 self.kernel_size = kernel_size 721 722 if name == "smoothing": 723 self.filter = GaussianSmoothing(self.kernel_size)
def get_sentences_from_article(
    pdf_filepath: str, max_symbols=77, min_symbols=15
) -> List:
    """
    Read a pdf file, clean the text from references and split into sentences

    The text is cut at the first occurrence of "References", citations of the form
    `(Author 2020)` and `[12]` are removed, and the remainder is split at full stops.

    Parameters
    ----------
    pdf_filepath : str
        the path to the pdf file
    max_symbols : int, default 77
        sentences longer than this number of symbols will be omitted
    min_symbols : int, default 15
        sentences shorter than this number of symbols will be omitted

    Returns
    -------
    sentences : list
        a list of unique sentences in order of first appearance (empty if no text could be
        extracted from the file)
    """

    # third-party dependency, imported lazily so that the module can be used without it
    from tika import parser

    rawText = parser.from_file(pdf_filepath)
    content = rawText.get("content")
    if content is None:
        # nothing could be extracted from the file
        return []
    text = content.replace("-\n", "")  # re-join words hyphenated at line breaks
    text = text.replace("\n", " ")
    text = re.sub(r"(?<=\d),(?=\d)", "", text)  # drop thousands separators
    text = text.split("References")[0]
    text = re.sub(r" \([\w&, ;]+\d{4}\)", "", text).strip()  # (Author et al, 2020)
    text = re.sub(r" \[\d+\]", "", text).strip()  # [12]
    sentences = [x.strip() for x in text.split(".")]
    sentences = [
        x
        for x in sentences
        if not x.endswith(" al")  # fragments produced by splitting "et al."
        and len(x) <= max_symbols
        and len(x) >= min_symbols
        and sum(c.isdigit() for c in x) < len(x) * 0.2
        and sum(len(w) < 2 for w in x.split()) < len(x) * 0.2
    ]
    # remove duplicates while keeping a reproducible order
    return list(dict.fromkeys(sentences))
Read a pdf file, clean the text from references and split into sentences
Parameters
pdf_filepath : str the path to the pdf file max_symbols : int, default 77 sentences longer than this number of symbols will be omitted min_symbols : int, default 15 sentences shorter than this number of symbols will be omitted
Returns
sentences : list a list of sentences
def load_pickle(file_path: str):
    """
    Load a pickle file

    Parameters
    ----------
    file_path : str
        the path to the pickle file

    Returns
    -------
    obj : any
        the loaded object
    """

    # NOTE(review): unpickling can execute arbitrary code; only use this on trusted files
    with open(file_path, "rb") as handle:
        return pickle.load(handle)
Load a pickle file
Parameters
file_path : str the path to the pickle file
Returns
obj : any the loaded object
def binarize_data(data: list, max_frame: int = -1) -> np.ndarray:
    """
    Convert file in start, stop, confusion format into a binary matrix of size (N_classes, N_timepoints)

    Parameters
    ----------
    data : list
        the loaded annotation; `data[3][0]` is read as one sequence of `(start, stop, confusion)`
        rows per class
    max_frame : int, default -1
        the number of time points in the output (if -1, the largest value found in the
        annotation is used, or 0 if there are no annotated intervals)

    Returns
    -------
    label_array : np.ndarray
        an array of shape `(N_classes, N_timepoints)` that is 1 inside the intervals with
        `confusion < 0.5` and 0 elsewhere
    """
    data = data[3][0]
    if max_frame == -1:
        # only look at the annotations when the length actually has to be inferred
        all_annot = [np.asarray(arr) for arr in data if len(arr)]
        max_frame = np.max(np.concatenate(all_annot, axis=0)) if all_annot else 0
    # int() makes array sizes and slice bounds valid for float-valued annotations as well
    label_array = np.zeros((len(data), int(max_frame)))
    for k, arr in enumerate(data):
        for start, stop, confusion in arr:
            if confusion < 0.5:
                label_array[k, int(start) : int(stop)] = 1
    return label_array
Convert file in start, stop, confusion in to binary matrix of size (N_classes, N_timepoints)