dlc2action.utils

Utility functions

  1#
  2# Copyright 2020-present by A. Mathis Group and contributors. All rights reserved.
  3#
  4# This project and all its files are licensed under GNU AGPLv3 or later version. 
  5# A copy is included in dlc2action/LICENSE.AGPL.
  6#
  7"""
  8## Utility functions
  9
 10- `TensorDict` is a convenient data structure for keeping (and indexing) lists of feature dictionaries,
 11- `apply_threshold`, `apply_threshold_hysteresis` and `apply_threshold_max` are utility functions for
 12`dlc2action.data.dataset.BehaviorDataset.find_valleys`,
 13- `strip_suffix` is used to get rid of suffixes if a string (usually a filename) ends with one of them,
 14- `strip_prefix` is used to get rid of prefixes if a string (usually a filename) starts with one of them,
 15- `rotation_matrix_2d` and `rotation_matrix_3d` are used to generate rotation matrices by
 16`dlc2action.transformer.base_transformer.Transformer` instances
 17"""
 18
 19import torch
 20from typing import List, Dict, Union
 21from collections.abc import Iterable
 22import warnings
 23import os
 24import numpy as np
 25from torch import nn
 26from torch.nn import functional as F
 27import pickle
 28import re
 29import math
 30
 31
 32class TensorDict:
 33    """
 34    A class that handles indexing in a dictionary of tensors of the same length
 35    """
 36
 37    def __init__(self, obj: Union[Dict, Iterable] = None) -> None:
 38        """
 39        Parameters
 40        ----------
 41        obj : dict | iterable, optional
 42            either a dictionary of torch.Tensor instances of the same length or an iterable of dictionaries with
 43            the same keys (if not passed, a blank TensorDict is initialized)
 44        """
 45
 46        if obj is None:
 47            obj = {}
 48        if isinstance(obj, list):
 49            self.keys = list(obj[0].keys())
 50            self.dict = {key: [] for key in self.keys}
 51            with warnings.catch_warnings():
 52                warnings.filterwarnings("ignore", category=UserWarning)
 53                for element in obj:
 54                    for key in self.keys:
 55                        self.dict[key].append(torch.tensor(element[key]))
 56            # self.dict = {k: torch.stack(v) for k, v in self.dict.items()}
 57            old_dict = self.dict
 58            self.dict = {}
 59            for k, v in old_dict.items():
 60                self.dict[k] = torch.stack(v)
 61        elif isinstance(obj, dict):
 62            if not all([isinstance(obj[key], torch.Tensor) for key in obj]):
 63                raise TypeError(
 64                    f"The values in the dictionary passed to TensorDict need to be torch.Tensor instances;"
 65                    f"got {[type(obj[key]) for key in obj]}"
 66                )
 67            lengths = [len(obj[key]) for key in obj]
 68            if not all([x == lengths[0] for x in lengths]):
 69                raise ValueError(
 70                    f"The value tensors in the dictionary passed to TensorDict need to have the same length;"
 71                    f"got {lengths}"
 72                )
 73            self.dict = obj
 74            self.keys = list(self.dict.keys())
 75        else:
 76            raise TypeError(
 77                f"TensorDict can only be constructed from an iterable of dictionaries of from a dictionary "
 78                f"of tensors (got {type(obj)})"
 79            )
 80        if len(self.keys) > 0:
 81            self.type = type(self.dict[self.keys[0]])
 82        else:
 83            self.type = None
 84
 85    def __len__(self) -> int:
 86        return len(self.dict[self.keys[0]])
 87
 88    def __getitem__(self, ind: Union[int, List]):
 89        """
 90        Index the TensorDict
 91
 92        Parameters
 93        ----------
 94        ind : int | list
 95            the index/indices
 96
 97        Returns
 98        -------
 99        dict | TensorDict
100            the indexed elements of all value lists combined in a dictionary (if length of ind is 1) or a TensorDict
101        """
102
103        x = {key: self.dict[key][ind] for key in self.dict}
104        if type(ind) is not int:
105            d = TensorDict(x)
106            return d
107        else:
108            return x
109
110    def append(self, element: Dict) -> None:
111        """
112        Append an element
113
114        Parameters
115        ----------
116        element : dict
117            a dictionary
118        """
119
120        type_element = type(element[list(element.keys())[0]])
121        if self.type is None:
122            self.dict = {k: v.unsqueeze(0) for k, v in element.items()}
123            self.keys = list(element.keys())
124            self.type = type_element
125        else:
126            for key in self.keys:
127                if key not in element:
128                    raise ValueError(
129                        f"The dictionary appended to TensorDict needs to have the same keys as the "
130                        f"TensorDict; got {element.keys()} and {self.keys}"
131                    )
132                self.dict[key] = torch.cat([self.dict[key], element[key].unsqueeze(0)])
133
134    def remove(self, indices: List) -> None:
135        """
136        Remove indexed elements
137
138        Parameters
139        ----------
140        indices : list
141            the indices to remove
142        """
143
144        mask = torch.ones(len(self))
145        mask[indices] = 0
146        mask = mask.bool()
147        for key, value in self.dict.items():
148            self.dict[key] = value[mask]
149
150
def apply_threshold(
    tensor: torch.Tensor,
    threshold: float,
    low: bool = True,
    error_mask: torch.Tensor = None,
    min_frames: int = 0,
    smooth_interval: int = 0,
    masked_intervals: List = None,
):
    """
    Apply a hard threshold to a tensor and return indices of the intervals that passed.

    If `error_mask` is not `None`, the elements marked `False` are treated as if they did not pass the threshold.
    If `min_frames` is not 0, the intervals are additionally filtered by length.

    Parameters
    ----------
    tensor : torch.Tensor
        the tensor to apply the threshold to (one value per frame)
    threshold : float
        the threshold
    low : bool, default True
        whether to select the intervals below the threshold (if `True`) or above (if `False`)
    error_mask : torch.Tensor, optional
        a boolean mask to apply to the results
    min_frames : int, default 0
        intervals of `min_frames` frames or fewer are discarded (only strictly longer intervals are kept)
    smooth_interval : int, default 0
        the number of frames to smooth the tensor with before applying the threshold
    masked_intervals : list, optional
        a list of `(start, end)` intervals to mask out

    Returns
    -------
    indices_start : list
        a list of indices of the first frames of the chosen intervals
    indices_end : list
        a list of exclusive end indices of the chosen intervals (one past the last frame)

    """
    if masked_intervals is None:
        masked_intervals = []
    p = tensor <= threshold if low else tensor >= threshold
    if smooth_interval:
        # `smooth` returns its input unchanged for a zero interval, so the call is only needed otherwise
        p = smooth(p, smooth_interval)
    if error_mask is not None:
        p = p * error_mask
    for start, end in masked_intervals:
        p[start:end] = False
    # run-length encode the mask; the run boundaries follow from the cumulative run lengths in O(T)
    output, counts = torch.unique_consecutive(p, return_counts=True)
    ends = torch.cumsum(counts, dim=0)
    starts = ends - counts
    keep = output.bool() & (counts > min_frames)
    return starts[keep].tolist(), ends[keep].tolist()
213
214
def apply_threshold_hysteresis(
    tensor: torch.Tensor,
    soft_threshold: float,
    hard_threshold: float,
    low: bool = True,
    error_mask: torch.Tensor = None,
    min_frames: int = 0,
    smooth_interval: int = 0,
    masked_intervals: List = None,
):
    """
    Apply a hysteresis threshold to a tensor and return indices of the intervals that passed.

    In the chosen intervals all values pass the soft threshold and at least one value passes the hard threshold.
    If `error_mask` is not `None`, the elements marked `False` are treated as if they did not pass the threshold.
    If `min_frames` is not 0, the intervals are additionally filtered by length.

    Parameters
    ----------
    tensor : torch.Tensor
        the tensor to apply the threshold to (one value per frame)
    soft_threshold : float
        the soft threshold
    hard_threshold : float
        the hard threshold
    low : bool, default True
        whether to select values below the threshold (`True`) or above (`False`)
    error_mask : torch.Tensor, optional
        a boolean mask to apply to the results
    min_frames : int, default 0
        intervals of `min_frames` frames or fewer are discarded (only strictly longer intervals are kept)
    smooth_interval : int, default 0
        the number of frames to smooth the soft-threshold mask with
    masked_intervals : list, default None
        a list of `(start, end)` intervals to mask out

    Returns
    -------
    indices_start : list
        a list of indices of the first frames of the chosen intervals
    indices_end : list
        a list of exclusive end indices of the chosen intervals (one past the last frame)

    """
    if masked_intervals is None:
        masked_intervals = []
    if low:
        p = tensor <= soft_threshold
        hard = tensor <= hard_threshold
    else:
        p = tensor >= soft_threshold
        hard = tensor >= hard_threshold
    if smooth_interval:
        # `smooth` returns its input unchanged for a zero interval, so the call is only needed otherwise
        p = smooth(p, smooth_interval)
    if error_mask is not None:
        p = p * error_mask
    for start, end in masked_intervals:
        p[start:end] = False
    # run-length encode the soft mask; run boundaries follow from the cumulative run lengths in O(T)
    output, counts = torch.unique_consecutive(p, return_counts=True)
    ends = torch.cumsum(counts, dim=0)
    starts = ends - counts
    keep = output.bool() & (counts > min_frames)
    res_indices_start = []
    res_indices_end = []
    for start, end in zip(starts[keep].tolist(), ends[keep].tolist()):
        # keep only the soft intervals that contain at least one frame passing the hard threshold
        if hard[start:end].any():
            res_indices_start.append(start)
            res_indices_end.append(end)
    return res_indices_start, res_indices_end
289
290
def apply_threshold_max(
    tensor: torch.Tensor,
    threshold: Union[float, None],
    main_class: int,
    error_mask: torch.Tensor = None,
    min_frames: int = 0,
    smooth_interval: int = 0,
    masked_intervals: List = None,
):
    """
    Apply a max hysteresis threshold to a tensor and return indices of the intervals that passed.

    In the chosen intervals the values at the `main_class` index are the largest everywhere
    and (unless `threshold` is `None`) at least one value at the `main_class` index exceeds the threshold.
    If `error_mask` is not `None`, the elements marked `False` are treated as if they did not pass the threshold.
    If `min_frames` is not 0, the intervals are additionally filtered by length.

    Parameters
    ----------
    tensor : torch.Tensor
        the tensor to apply the threshold to (of shape `(#classes, #frames)`)
    threshold : float | None
        the hard threshold (if `None`, only the argmax condition is applied)
    main_class : int
        the class that conditions the soft threshold
    error_mask : torch.Tensor, optional
        a boolean mask to apply to the results
    min_frames : int, default 0
        intervals of `min_frames` frames or fewer are discarded (only strictly longer intervals are kept)
    smooth_interval : int, default 0
        the number of frames to smooth the argmax mask with
    masked_intervals : list, default None
        a list of `(start, end)` intervals to mask out

    Returns
    -------
    indices_start : list
        a list of indices of the first frames of the chosen intervals (along dimension 1 of input tensor)
    indices_end : list
        a list of exclusive end indices of the chosen intervals (along dimension 1 of input tensor)

    """
    if masked_intervals is None:
        masked_intervals = []
    _, best_class = torch.max(tensor, dim=0)
    p = best_class == main_class
    if smooth_interval:
        # `smooth` returns its input unchanged for a zero interval, so the call is only needed otherwise
        p = smooth(p, smooth_interval)
    if error_mask is not None:
        p = p * error_mask
    for start, end in masked_intervals:
        p[start:end] = False
    # run-length encode the argmax mask; run boundaries follow from the cumulative run lengths in O(T)
    output, counts = torch.unique_consecutive(p, return_counts=True)
    ends = torch.cumsum(counts, dim=0)
    starts = ends - counts
    keep = output.bool() & (counts > min_frames)
    indices_start = starts[keep].tolist()
    indices_end = ends[keep].tolist()
    if threshold is None:
        return indices_start, indices_end
    hard = tensor[main_class, :] > threshold
    res_indices_start = []
    res_indices_end = []
    for start, end in zip(indices_start, indices_end):
        if hard[start:end].any():
            res_indices_start.append(start)
            res_indices_end.append(end)
    return res_indices_start, res_indices_end
363
364
def strip_suffix(text: str, suffix: Iterable):
    """
    Strip a suffix from a string if it is contained in a list

    Only the first matching suffix (in iteration order) is stripped.

    Parameters
    ----------
    text : str
        the main string
    suffix : iterable
        the list of suffixes to be stripped (`None` is treated as an empty list)

    Returns
    -------
    result : str
        the stripped string
    """

    if suffix is None:
        suffix = []
    for s in suffix:
        if text.endswith(s):
            # `len(text) - len(s)` instead of `-len(s)`: `text[:-0]` would wrongly return "" for an empty suffix
            return text[: len(text) - len(s)]
    return text
386
387
def strip_prefix(text: str, prefix: Iterable):
    """
    Remove the first prefix from `prefix` that `text` starts with

    Parameters
    ----------
    text : str
        the string to process
    prefix : iterable
        candidate prefixes, checked in order (`None` means no candidates)

    Returns
    -------
    result : str
        `text` without the first matching prefix, or `text` itself when nothing matches
    """

    candidates = [] if prefix is None else prefix
    matched = next((candidate for candidate in candidates if text.startswith(candidate)), None)
    if matched is None:
        return text
    return text[len(matched) :]
411
412
def rotation_matrix_2d(angles: torch.Tensor) -> torch.Tensor:
    """
    Build 2D rotation matrices `[[cos, -sin], [sin, cos]]` from a tensor of angles

    Parameters
    ----------
    angles : torch.Tensor
        rotation angles (in radians) of any shape `(...)`

    Returns
    -------
    rotation_matrices : torch.Tensor
        the rotation matrices, of shape `(..., 2, 2)`
    """

    c, s = torch.cos(angles), torch.sin(angles)
    first_row = torch.stack((c, -s), dim=-1)
    second_row = torch.stack((s, c), dim=-1)
    return torch.stack((first_row, second_row), dim=-2)
432
433
def rotation_matrix_3d(alpha: torch.Tensor, beta: torch.Tensor, gamma: torch.Tensor):
    """
    Create a tensor of 3D rotation matrices from a tensor of angles

    The result is the product `Rx(alpha) @ Ry(beta) @ Rz(gamma)` of the elementary rotations.
    The matrices have the same dtype and device as the angle tensors.

    Parameters
    ----------
    alpha : torch.Tensor
        a tensor of rotation angles around the x axis of arbitrary shape `(...)`
    beta : torch.Tensor
        a tensor of rotation angles around the y axis of arbitrary shape `(...)`
    gamma : torch.Tensor
        a tensor of rotation angles around the z axis of arbitrary shape `(...)`

    Returns
    -------
    rotation_matrices : torch.Tensor
        a tensor of 3D rotation matrices of shape `(..., 3, 3)`
    """

    def _matrix(entries, shape):
        # stack 9 row-major entries into matrices of shape `(*shape, 3, 3)`
        return torch.stack(entries, dim=-1).reshape(*shape, 3, 3)

    # `ones_like` / `zeros_like` keep the dtype and device of the angles (plain `torch.ones(shape)`
    # would always be a CPU float32 tensor)
    cos, sin = torch.cos(alpha), torch.sin(alpha)
    one, zero = torch.ones_like(cos), torch.zeros_like(cos)
    Rx = _matrix([one, zero, zero, zero, cos, -sin, zero, sin, cos], alpha.shape)

    cos, sin = torch.cos(beta), torch.sin(beta)
    one, zero = torch.ones_like(cos), torch.zeros_like(cos)
    Ry = _matrix([cos, zero, sin, zero, one, zero, -sin, zero, cos], beta.shape)

    cos, sin = torch.cos(gamma), torch.sin(gamma)
    one, zero = torch.ones_like(cos), torch.zeros_like(cos)
    Rz = _matrix([cos, -sin, zero, sin, cos, zero, zero, zero, one], gamma.shape)

    return torch.einsum("...ij,...jk,...kl->...il", Rx, Ry, Rz)
503
504
def correct_path(path, project_path):
    """
    Correct a path to a file or directory to be relative to the project path.

    If the path contains a `results` (checked first) or `saved_datasets` directory component, everything
    after the first occurrence of that component is re-rooted at the same directory inside `project_path`.
    Other paths are only normalized. Non-string values are returned unchanged.

    Parameters
    ----------
    path : str
        the path to be corrected
    project_path : str
        the path to the project root directory

    Returns
    -------
    corrected_path : str
        the corrected path

    """
    if not isinstance(path, str):
        return path
    path_split = os.path.normpath(path).split(os.path.sep)
    for name in ("results", "saved_datasets"):
        # compare whole path components (a substring test could match e.g. "my_saved_datasets")
        if name in path_split:
            ind = path_split.index(name) + 1
            return os.path.join(project_path, name, *path_split[ind:])
    return os.path.sep.join(path_split)  # No need for correction
533
534
class TensorList(list):
    """
    A list of tensors that can send each element to a `torch` device
    """

    def to_device(self, device: torch.device):
        """
        Send each element of the list to a `torch` device (the list is updated in place).

        Parameters
        ----------
        device : torch.device
            the target device
        """
        for i, x in enumerate(self):
            # `torch.Tensor` has no `to_device` method; `Tensor.to` returns the tensor on `device`
            self[i] = x.to(device)
544
545
def get_intervals(tensor: torch.Tensor) -> torch.Tensor:
    """
    Get a tensor of True group beginning and end indices from a boolean tensor

    Parameters
    ----------
    tensor : torch.Tensor
        a boolean tensor (one value per frame)

    Returns
    -------
    intervals : torch.Tensor
        an integer tensor of shape `(#groups, 2)` holding `[start, end)` of every run of `True` values
        (shape `(0, 2)` if there are none)
    """

    # run-length encode the input; run boundaries follow from the cumulative run lengths in O(T)
    output, counts = torch.unique_consecutive(tensor, return_counts=True)
    ends = torch.cumsum(counts, dim=0)
    starts = ends - counts
    keep = output.bool()
    return torch.stack([starts[keep], ends[keep]], dim=1)
560
561
def smooth(tensor: torch.Tensor, smooth_interval: int = 0) -> torch.Tensor:
    """
    Get rid of jittering in a non-exclusive classification tensor

    First, intervals of 0 that are at most `smooth_interval` frames long are set to 1. Then, intervals of 1
    that are at most `smooth_interval` frames long are set to 0.

    Parameters
    ----------
    tensor : torch.Tensor
        a tensor of 0/1 (or boolean) values, presumably 1D with one value per frame
    smooth_interval : int, default 0
        the maximum length of the intervals to remove (if 0, `tensor` is returned as is)

    Returns
    -------
    tensor : torch.Tensor
        the smoothed tensor; NOTE: this is the input tensor itself, modified in place
    """

    if smooth_interval == 0:
        return tensor
    # fill short gaps first, then drop short islands (order matters)
    for value, replacement in ((0, 1), (1, 0)):
        intervals = get_intervals(tensor == value)
        lengths = intervals[:, 1] - intervals[:, 0]
        for start, end in intervals[lengths <= smooth_interval]:
            tensor[start:end] = replacement
    return tensor
587
588
class GaussianSmoothing(nn.Module):
    """
    Apply gaussian smoothing on a 1d tensor.

    Filtering is performed separately for each channel in the input using a depthwise convolution
    with a fixed (non-trainable) gaussian kernel.

    Parameters
    ----------
    kernel_size : int, default 15
        size of the gaussian kernel
    sigma : float, default 1.0
        standard deviation of the gaussian kernel
    """

    def __init__(self, kernel_size: int = 15, sigma: float = 1.0) -> None:
        super().__init__()
        self.kernel_size = kernel_size

        positions = torch.arange(kernel_size, dtype=torch.float32)
        mean = (kernel_size - 1) / 2
        scale = 1 / (sigma * math.sqrt(2 * math.pi))
        kernel = scale * torch.exp(-(((positions - mean) / sigma) ** 2) / 2)
        # NOTE: the discrete kernel is not renormalized, so its sum is only close to 1
        # when `kernel_size` is large compared to `sigma`

        # stored as a plain attribute (not a registered buffer); `forward` moves it to the input's device
        self.kernel = kernel.view(1, 1, *kernel.size())

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """
        Apply gaussian filter to input.

        Parameters
        ----------
        inputs : torch.Tensor
            Input to apply gaussian filter on, of shape `(N, C, T)`; reflect padding requires
            `T` to be larger than half of the kernel size.

        Returns
        -------
        filtered : torch.Tensor
            Filtered output of the same shape as `inputs`.

        """
        _, c, _ = inputs.shape
        # asymmetric padding keeps the output length equal to T for even kernel sizes as well
        left = (self.kernel_size - 1) // 2
        right = self.kernel_size - 1 - left
        inputs = F.pad(inputs, pad=(left, right), mode="reflect")
        kernel = self.kernel.to(device=inputs.device, dtype=inputs.dtype).repeat(c, 1, 1)
        return F.conv1d(inputs, weight=kernel, groups=c)
642
643
def argrelmax(prob: np.ndarray, threshold: float = 0.7) -> List[int]:
    """
    Calculate arguments of relative maxima.

    The first frame is always reported as a boundary; other frames are reported if their value is at least
    `threshold` and strictly larger than both neighbours. The input array is not modified.

    Parameters
    ----------
    prob : np.ndarray
        boundary probability map with values in [0, 1], shape `(T,)`
    threshold : float, default 0.7
        ignore the peaks whose value is under threshold

    Returns
    -------
    peak_idx : List[int]
        indices of the peaks

    """
    # work on a copy so that the caller's array is left untouched
    prob = np.array(prob, copy=True)
    if prob.size == 0:
        return []

    # ignore the values under threshold
    prob[prob < threshold] = 0.0

    # calculate the relative maxima of boundary maps
    # treat the first frame as boundary
    peak = np.concatenate(
        [
            np.ones(1, dtype=bool),  # builtin `bool`: the `np.bool` alias was removed in NumPy 1.24
            (prob[:-2] < prob[1:-1]) & (prob[2:] < prob[1:-1]),
            np.zeros(1, dtype=bool),
        ],
        axis=0,
    )

    return np.where(peak)[0].tolist()
678
679
def decide_boundary_prob_with_similarity(x: torch.Tensor) -> torch.Tensor:
    """
    Decide action boundary probabilities based on adjacent frame similarities.

    For every frame after the first the value is `exp(-||x[n, :, t] - x[n, :, t - 1]|| / 2)`; the first frame
    of every sequence is set to 1 (treated as an action start).

    Parameters
    ----------
    x : torch.Tensor
        frame-wise video features of shape `(N, C, T)`

    Returns
    -------
    boundary : torch.Tensor
        action boundary probability of shape `(N, 1, T)`
    """
    # NOTE(review): the values are close to 1 for *similar* adjacent frames, i.e. they grow with
    # similarity rather than with the likelihood of a boundary — confirm this convention is intended

    # gaussian kernel over the feature distance of adjacent frames, shape (N, T - 1)
    diff = x[:, :, 1:] - x[:, :, :-1]
    similarity = torch.exp(-torch.norm(diff, dim=1) / (2 * 1.0))

    # define action starting point as action boundary.
    start = torch.ones(x.shape[0], 1, dtype=similarity.dtype, device=x.device)
    return torch.cat([start, similarity], dim=1).unsqueeze(1)
699
700
class PostProcessor(object):
    """
    Post-process frame-wise action segmentation outputs.

    The method is chosen with `name`:

    - `"refinement_with_boundary"`: split the sequences at predicted boundaries and give every segment
      its majority class,
    - `"relabeling"`: give short segments the label of the preceding segment,
    - `"smoothing"`: smooth class probabilities with a gaussian filter before taking the argmax.
    """

    def __init__(
        self,
        name: str,
        boundary_th: float = 0.7,
        theta_t: int = 15,
        kernel_size: int = 15,
    ) -> None:
        """
        Parameters
        ----------
        name : str
            the post-processing method (`"refinement_with_boundary"`, `"relabeling"` or `"smoothing"`)
        boundary_th : float, default 0.7
            boundary probabilities below this value are ignored by `"refinement_with_boundary"`
        theta_t : int, default 15
            segments of at most this many frames are relabeled by `"relabeling"`
        kernel_size : int, default 15
            the size of the gaussian kernel used by `"smoothing"`

        Raises
        ------
        ValueError
            if `name` is not one of the supported methods
        """
        self.func = {
            "refinement_with_boundary": self._refinement_with_boundary,
            "relabeling": self._relabeling,
            "smoothing": self._smoothing,
        }
        if name not in self.func:
            raise ValueError(
                f"Unknown post-processing method {name!r}; expected one of {list(self.func)}"
            )

        self.name = name
        self.boundary_th = boundary_th
        self.theta_t = theta_t
        self.kernel_size = kernel_size

        if name == "smoothing":
            self.filter = GaussianSmoothing(self.kernel_size)

    def _is_probability(self, x: np.ndarray) -> bool:
        """
        Check whether `x` of shape `(N, C, T)` already holds probabilities.

        With a single channel the values must lie in [0, 1] (sigmoid output); otherwise they must sum to 1
        over the channel axis (softmax output).
        """
        assert x.ndim == 3

        if x.shape[1] == 1:
            # sigmoid
            return bool(x.min() >= 0 and x.max() <= 1)
        # softmax
        _sum = np.sum(x, axis=1).astype(np.float32)
        _ones = np.ones_like(_sum, dtype=np.float32)
        return np.allclose(_sum, _ones)

    def _convert2probability(self, x: np.ndarray) -> np.ndarray:
        """
        Convert scores of shape `(N, C, T)` to probabilities.

        A sigmoid is used for `C == 1` and a softmax over the channel axis otherwise. Inputs that already
        are probabilities are returned unchanged (the same object).
        """
        assert x.ndim == 3

        if self._is_probability(x):
            return x
        if x.shape[1] == 1:
            # sigmoid
            prob = 1 / (1 + np.exp(-x))
        else:
            # softmax; subtracting the per-frame maximum keeps np.exp from overflowing (inf / inf = nan)
            exp = np.exp(x - np.max(x, axis=1, keepdims=True))
            prob = exp / np.sum(exp, axis=1, keepdims=True)
        return prob.astype(np.float32)

    def _convert2label(self, x: np.ndarray) -> np.ndarray:
        """
        Convert labels of shape `(N, T)` or scores of shape `(N, C, T)` to integer labels of shape `(N, T)`.
        """
        assert x.ndim == 2 or x.ndim == 3

        if x.ndim == 2:
            return x.astype(np.int64)
        if not self._is_probability(x):
            x = self._convert2probability(x)
        label = np.argmax(x, axis=1)
        return label.astype(np.int64)

    def _refinement_with_boundary(
        self,
        outputs: np.ndarray,
        boundaries: np.ndarray,
    ) -> np.ndarray:
        """
        Get segments which are defined as the spans between two boundaries,
        and decide their classes by majority vote.

        Parameters
        ----------
        outputs : np.ndarray
            the model output for frame-level class prediction, shape `(N, C, T)` (or labels, shape `(N, T)`)
        boundaries : np.ndarray
            boundary prediction, shape `(N, 1, T)`

        Returns
        -------
        preds : np.ndarray
            final class prediction considering boundaries, shape `(N, T)`
        """

        preds = self._convert2label(outputs)
        boundaries = self._convert2probability(boundaries)

        for i, (output, pred, boundary) in enumerate(zip(outputs, preds, boundaries)):
            # copy so that the boundary search can never modify the caller's array;
            # `atleast_1d` keeps a single-frame sequence one-dimensional after `squeeze`
            idx = argrelmax(
                np.atleast_1d(boundary.squeeze()).copy(), threshold=self.boundary_th
            )

            # add the index of the last action ending
            idx.append(pred.shape[0])

            # majority vote
            for start, end in zip(idx[:-1], idx[1:]):
                if end <= start:
                    continue
                count = np.bincount(pred[start:end])
                modes = np.where(count == count.max())[0]
                if len(modes) == 1:
                    mode = modes[0]
                elif outputs.ndim == 3:
                    # several majority classes: pick the one with the largest summed score
                    # (argmax also works when all the sums are negative, e.g. for logits)
                    scores = [output[m, start:end].sum() for m in modes]
                    mode = modes[int(np.argmax(scores))]
                else:
                    # decide first mode when more than one majority class
                    # have the same number during oracle experiment
                    mode = modes[0]

                preds[i, start:end] = mode

        return preds

    def _relabeling(self, outputs: np.ndarray, **kwargs: np.ndarray) -> np.ndarray:
        """
        Relabel small action segments with their previous action segment

        Segments of at most `self.theta_t` frames take the label of the frame right before them. The first
        segment of a sequence has no predecessor and is left unchanged.

        Parameters
        ----------
        outputs : np.ndarray
            the results of action segmentation, shape `(N, T)` or `(N, C, T)`

        Returns
        -------
        preds : np.ndarray
            relabeled output, shape `(N, T)`
        """

        preds = self._convert2label(outputs)

        for row in preds:
            length = row.shape[0]
            start = 0
            while start < length:
                # find the end (exclusive) of the segment that begins at `start`
                end = start + 1
                while end < length and row[end] == row[start]:
                    end += 1
                if start > 0 and end - start <= self.theta_t:
                    row[start:end] = row[start - 1]
                start = end

        return preds

    def _smoothing(self, outputs: np.ndarray, **kwargs: np.ndarray) -> np.ndarray:
        """
        Smooth action probabilities with a gaussian filter.

        Parameters
        ----------
        outputs : np.ndarray
            frame-wise action scores or probabilities, shape `(N, C, T)`

        Returns
        -------
        preds : np.ndarray
            final prediction, shape `(N, T)`
        """

        outputs = self._convert2probability(outputs)
        outputs = self.filter(torch.Tensor(outputs)).numpy()

        preds = self._convert2label(outputs)
        return preds

    def __call__(self, outputs, **kwargs: np.ndarray) -> np.ndarray:
        """Apply the post-processing method chosen at construction (extra keyword arguments are forwarded)."""

        preds = self.func[self.name](outputs, **kwargs)
        return preds
874
875
def get_sentences_from_article(
    pdf_filepath: str, max_symbols=77, min_symbols=15
) -> List:
    """
    Read a pdf file, clean the text from references and split into sentences

    Parameters
    ----------
    pdf_filepath : str
        the path to the pdf file
    max_symbols : int, default 77
        sentences longer than this number of symbols will be omitted
    min_symbols : int, default 15
        sentences shorter than this number of symbols will be omitted

    Returns
    -------
    sentences : list
        a list of unique sentences in the order of their first appearance
        (empty if no text content was extracted)
    """

    from tika import parser

    raw_text = parser.from_file(pdf_filepath)
    content = raw_text.get("content")
    if not content:
        # NOTE(review): `content` is presumably None when tika extracts no text — confirm
        return []
    text = content.replace("-\n", "")  # re-join words hyphenated across line breaks
    text = text.replace("\n", " ")
    text = re.sub(r"(?<=\d),(?=\d)", "", text)  # drop commas between digits, e.g. "1,000" -> "1000"
    text = text.split("References")[0]  # discard everything from the first "References" on
    text = re.sub(r" \([\w&, ;]+\d{4}\)", "", text).strip()  # citations like " (Smith & Jones, 2020)"
    text = re.sub(r" \[\d+\]", "", text).strip()  # numeric citations like " [12]"
    sentences = [x.strip() for x in text.split(".")]
    sentences = [
        x
        for x in sentences
        if not x.endswith(" al")  # fragments cut at "et al."
        and min_symbols <= len(x) <= max_symbols
        and sum(c.isdigit() for c in x) < len(x) * 0.2
        and sum(len(w) < 2 for w in x.split()) < len(x) * 0.2
    ]
    # deduplicate while keeping a deterministic order (a set would make the order arbitrary)
    return list(dict.fromkeys(sentences))
917
918
919
def load_pickle(file_path: str):
    """
    Read and return the object stored in a pickle file

    Parameters
    ----------
    file_path : str
        location of the pickle file

    Returns
    -------
    obj : any
        the unpickled object

    Notes
    -----
    Unpickling can execute arbitrary code, so only load files from trusted sources.
    """

    with open(file_path, "rb") as handle:
        return pickle.load(handle)
938
939
def binarize_data(data: list, max_frame: int = -1) -> np.ndarray:
    """
    Convert start, stop, confusion annotations into a binary matrix of size (N_classes, N_timepoints)

    Parameters
    ----------
    data : list
        loaded annotation data; `data[3][0]` is read as a per-class sequence of `(start, stop, confusion)`
        rows (presumably the layout of the annotation files this is used with — TODO confirm)
    max_frame : int, default -1
        the number of time points (columns); if -1, it is inferred as the largest value found in the
        annotation rows (taken over all columns, confusion included)

    Returns
    -------
    label_array : np.ndarray
        a float array with 1 in `start:stop` of every annotation with `confusion < 0.5` and 0 elsewhere
    """
    class_annotations = data[3][0]
    if max_frame == -1:
        non_empty = [arr for arr in class_annotations if len(arr)]
        # cast to int: the rows mix integer frames with float confusion values, so the maximum is a float
        max_frame = int(np.max(np.concatenate(non_empty, axis=0))) if non_empty else 0
    label_array = np.zeros((len(class_annotations), int(max_frame)))
    for k, arr in enumerate(class_annotations):
        for start, stop, confusion in arr:
            if confusion < 0.5:
                label_array[k, int(start) : int(stop)] = 1
    return label_array
class TensorDict:
 33class TensorDict:
 34    """
 35    A class that handles indexing in a dictionary of tensors of the same length
 36    """
 37
 38    def __init__(self, obj: Union[Dict, Iterable] = None) -> None:
 39        """
 40        Parameters
 41        ----------
 42        obj : dict | iterable, optional
 43            either a dictionary of torch.Tensor instances of the same length or an iterable of dictionaries with
 44            the same keys (if not passed, a blank TensorDict is initialized)
 45        """
 46
 47        if obj is None:
 48            obj = {}
 49        if isinstance(obj, list):
 50            self.keys = list(obj[0].keys())
 51            self.dict = {key: [] for key in self.keys}
 52            with warnings.catch_warnings():
 53                warnings.filterwarnings("ignore", category=UserWarning)
 54                for element in obj:
 55                    for key in self.keys:
 56                        self.dict[key].append(torch.tensor(element[key]))
 57            # self.dict = {k: torch.stack(v) for k, v in self.dict.items()}
 58            old_dict = self.dict
 59            self.dict = {}
 60            for k, v in old_dict.items():
 61                self.dict[k] = torch.stack(v)
 62        elif isinstance(obj, dict):
 63            if not all([isinstance(obj[key], torch.Tensor) for key in obj]):
 64                raise TypeError(
 65                    f"The values in the dictionary passed to TensorDict need to be torch.Tensor instances;"
 66                    f"got {[type(obj[key]) for key in obj]}"
 67                )
 68            lengths = [len(obj[key]) for key in obj]
 69            if not all([x == lengths[0] for x in lengths]):
 70                raise ValueError(
 71                    f"The value tensors in the dictionary passed to TensorDict need to have the same length;"
 72                    f"got {lengths}"
 73                )
 74            self.dict = obj
 75            self.keys = list(self.dict.keys())
 76        else:
 77            raise TypeError(
 78                f"TensorDict can only be constructed from an iterable of dictionaries of from a dictionary "
 79                f"of tensors (got {type(obj)})"
 80            )
 81        if len(self.keys) > 0:
 82            self.type = type(self.dict[self.keys[0]])
 83        else:
 84            self.type = None
 85
 86    def __len__(self) -> int:
 87        return len(self.dict[self.keys[0]])
 88
 89    def __getitem__(self, ind: Union[int, List]):
 90        """
 91        Index the TensorDict
 92
 93        Parameters
 94        ----------
 95        ind : int | list
 96            the index/indices
 97
 98        Returns
 99        -------
100        dict | TensorDict
101            the indexed elements of all value lists combined in a dictionary (if length of ind is 1) or a TensorDict
102        """
103
104        x = {key: self.dict[key][ind] for key in self.dict}
105        if type(ind) is not int:
106            d = TensorDict(x)
107            return d
108        else:
109            return x
110
111    def append(self, element: Dict) -> None:
112        """
113        Append an element
114
115        Parameters
116        ----------
117        element : dict
118            a dictionary
119        """
120
121        type_element = type(element[list(element.keys())[0]])
122        if self.type is None:
123            self.dict = {k: v.unsqueeze(0) for k, v in element.items()}
124            self.keys = list(element.keys())
125            self.type = type_element
126        else:
127            for key in self.keys:
128                if key not in element:
129                    raise ValueError(
130                        f"The dictionary appended to TensorDict needs to have the same keys as the "
131                        f"TensorDict; got {element.keys()} and {self.keys}"
132                    )
133                self.dict[key] = torch.cat([self.dict[key], element[key].unsqueeze(0)])
134
135    def remove(self, indices: List) -> None:
136        """
137        Remove indexed elements
138
139        Parameters
140        ----------
141        indices : list
142            the indices to remove
143        """
144
145        mask = torch.ones(len(self))
146        mask[indices] = 0
147        mask = mask.bool()
148        for key, value in self.dict.items():
149            self.dict[key] = value[mask]

A class that handles indexing in a dictionary of tensors of the same length

TensorDict(obj: Union[Dict, Iterable] = None)
38    def __init__(self, obj: Union[Dict, Iterable] = None) -> None:
39        """
40        Parameters
41        ----------
42        obj : dict | iterable, optional
43            either a dictionary of torch.Tensor instances of the same length or an iterable of dictionaries with
44            the same keys (if not passed, a blank TensorDict is initialized)
45        """
46
47        if obj is None:
48            obj = {}
49        if isinstance(obj, list):
50            self.keys = list(obj[0].keys())
51            self.dict = {key: [] for key in self.keys}
52            with warnings.catch_warnings():
53                warnings.filterwarnings("ignore", category=UserWarning)
54                for element in obj:
55                    for key in self.keys:
56                        self.dict[key].append(torch.tensor(element[key]))
57            # self.dict = {k: torch.stack(v) for k, v in self.dict.items()}
58            old_dict = self.dict
59            self.dict = {}
60            for k, v in old_dict.items():
61                self.dict[k] = torch.stack(v)
62        elif isinstance(obj, dict):
63            if not all([isinstance(obj[key], torch.Tensor) for key in obj]):
64                raise TypeError(
65                    f"The values in the dictionary passed to TensorDict need to be torch.Tensor instances;"
66                    f"got {[type(obj[key]) for key in obj]}"
67                )
68            lengths = [len(obj[key]) for key in obj]
69            if not all([x == lengths[0] for x in lengths]):
70                raise ValueError(
71                    f"The value tensors in the dictionary passed to TensorDict need to have the same length;"
72                    f"got {lengths}"
73                )
74            self.dict = obj
75            self.keys = list(self.dict.keys())
76        else:
77            raise TypeError(
78                f"TensorDict can only be constructed from an iterable of dictionaries of from a dictionary "
79                f"of tensors (got {type(obj)})"
80            )
81        if len(self.keys) > 0:
82            self.type = type(self.dict[self.keys[0]])
83        else:
84            self.type = None

Parameters

obj : dict | iterable, optional either a dictionary of torch.Tensor instances of the same length or an iterable of dictionaries with the same keys (if not passed, a blank TensorDict is initialized)

def append(self, element: Dict) -> None:
111    def append(self, element: Dict) -> None:
112        """
113        Append an element
114
115        Parameters
116        ----------
117        element : dict
118            a dictionary
119        """
120
121        type_element = type(element[list(element.keys())[0]])
122        if self.type is None:
123            self.dict = {k: v.unsqueeze(0) for k, v in element.items()}
124            self.keys = list(element.keys())
125            self.type = type_element
126        else:
127            for key in self.keys:
128                if key not in element:
129                    raise ValueError(
130                        f"The dictionary appended to TensorDict needs to have the same keys as the "
131                        f"TensorDict; got {element.keys()} and {self.keys}"
132                    )
133                self.dict[key] = torch.cat([self.dict[key], element[key].unsqueeze(0)])

Append an element

Parameters

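element : dict a dictionary of tensors that contains all the keys of the TensorDict (each value is added as a new entry along the first dimension)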

def remove(self, indices: List) -> None:
135    def remove(self, indices: List) -> None:
136        """
137        Remove indexed elements
138
139        Parameters
140        ----------
141        indices : list
142            the indices to remove
143        """
144
145        mask = torch.ones(len(self))
146        mask[indices] = 0
147        mask = mask.bool()
148        for key, value in self.dict.items():
149            self.dict[key] = value[mask]

Remove indexed elements

Parameters

indices : list the indices to remove

def apply_threshold( tensor: torch.Tensor, threshold: float, low: bool = True, error_mask: torch.Tensor = None, min_frames: int = 0, smooth_interval: int = 0, masked_intervals: List = None):
def apply_threshold(
    tensor: torch.Tensor,
    threshold: float,
    low: bool = True,
    error_mask: torch.Tensor = None,
    min_frames: int = 0,
    smooth_interval: int = 0,
    masked_intervals: List = None,
):
    """
    Apply a hard threshold to a tensor and return indices of the intervals that passed.

    If `error_mask` is not `None`, the elements marked `False` are treated as if they did not pass the threshold.
    If `min_frames` is not 0, the intervals are additionally filtered by length.

    Parameters
    ----------
    tensor : torch.Tensor
        the tensor to apply the threshold to
    threshold : float
        the threshold
    low : bool, default True
        whether to select the intervals below the threshold (if `True`) or above (if `False`)
    error_mask : torch.Tensor, optional
        a boolean mask to apply to the results
    min_frames : int, default 0
        intervals that are `min_frames` frames long or shorter are discarded
    smooth_interval : int, default 0
        the number of frames to smooth the tensor with before applying the threshold (see `smooth`)
    masked_intervals : list, optional
        a list of `(start, end)` intervals to mask out

    Returns
    -------
    indices_start : list
        a list of indices of the first frames of the chosen intervals
    indices_end : list
        a list of indices one past the last frames of the chosen intervals (exclusive ends)
    """
    if masked_intervals is None:
        masked_intervals = []
    p = tensor <= threshold if low else tensor >= threshold
    p = smooth(p, smooth_interval)
    if error_mask is not None:
        p = p * error_mask
    for start, end in masked_intervals:
        p[start:end] = False
    # run-length encode p; run boundaries follow from the cumulative run lengths, which avoids
    # scanning the whole tensor once per interval
    output, counts = torch.unique_consecutive(p, return_counts=True)
    ends = torch.cumsum(counts, dim=0)
    starts = ends - counts
    long_runs = torch.where(output * (counts > min_frames))[0]
    return starts[long_runs].tolist(), ends[long_runs].tolist()

Apply a hard threshold to a tensor and return indices of the intervals that passed.

If error_mask is not None, the elements marked False are treated as if they did not pass the threshold. If min_frames is not 0, the intervals are additionally filtered by length.

Parameters

tensor : torch.Tensor the tensor to apply the threshold to threshold : float the threshold low : bool, default True whether to select the intervals below the threshold (if True) or above (if False) error_mask : torch.Tensor, optional a boolean mask to apply to the results min_frames : int, default 0 intervals that are min_frames frames long or shorter are discarded smooth_interval : int, default 0 the number of frames to smooth the tensor with before applying the threshold masked_intervals : list, optional a list of (start, end) intervals to mask out

Returns

indices_start : list a list of indices of the first frames of the chosen intervals indices_end : list a list of indices one past the last frames of the chosen intervals (exclusive ends)

def apply_threshold_hysteresis( tensor: torch.Tensor, soft_threshold: float, hard_threshold: float, low: bool = True, error_mask: torch.Tensor = None, min_frames: int = 0, smooth_interval: int = 0, masked_intervals: List = None):
def apply_threshold_hysteresis(
    tensor: torch.Tensor,
    soft_threshold: float,
    hard_threshold: float,
    low: bool = True,
    error_mask: torch.Tensor = None,
    min_frames: int = 0,
    smooth_interval: int = 0,
    masked_intervals: List = None,
):
    """
    Apply a hysteresis threshold to a tensor and return indices of the intervals that passed.

    In the chosen intervals all values pass the soft threshold and at least one value passes the hard threshold.
    If `error_mask` is not `None`, the elements marked `False` are treated as if they did not pass the threshold.
    If `min_frames` is not 0, the intervals are additionally filtered by length.

    Parameters
    ----------
    tensor : torch.Tensor
        the tensor to apply the threshold to
    soft_threshold : float
        the soft threshold
    hard_threshold : float
        the hard threshold
    low : bool, default True
        whether to select values below the threshold (`True`) or above (`False`)
    error_mask : torch.Tensor, optional
        a boolean mask to apply to the results
    min_frames : int, default 0
        intervals that are `min_frames` frames long or shorter are discarded
    smooth_interval : int, default 0
        the number of frames to smooth the tensor with (see `smooth`)
    masked_intervals : list, default None
        a list of `(start, end)` intervals to mask out

    Returns
    -------
    indices_start : list
        a list of indices of the first frames of the chosen intervals
    indices_end : list
        a list of indices one past the last frames of the chosen intervals (exclusive ends)
    """
    if masked_intervals is None:
        masked_intervals = []
    if low:
        p = tensor <= soft_threshold
        hard = tensor <= hard_threshold
    else:
        p = tensor >= soft_threshold
        hard = tensor >= hard_threshold
    p = smooth(p, smooth_interval)
    if error_mask is not None:
        p = p * error_mask
    for start, end in masked_intervals:
        p[start:end] = False
    # run-length encode p; run boundaries follow from the cumulative run lengths, which avoids
    # scanning the whole tensor once per interval
    output, counts = torch.unique_consecutive(p, return_counts=True)
    ends = torch.cumsum(counts, dim=0)
    starts = ends - counts
    long_runs = torch.where(output * (counts > min_frames))[0]
    res_indices_start = []
    res_indices_end = []
    for start, end in zip(starts[long_runs].tolist(), ends[long_runs].tolist()):
        # keep only the soft intervals that reach the hard threshold at least once
        if torch.sum(hard[start:end]) > 0:
            res_indices_start.append(start)
            res_indices_end.append(end)
    return res_indices_start, res_indices_end

Apply a hysteresis threshold to a tensor and return indices of the intervals that passed.

In the chosen intervals all values pass the soft threshold and at least one value passes the hard threshold. If error_mask is not None, the elements marked False are treated as if they did not pass the threshold. If min_frames is not 0, the intervals are additionally filtered by length.

Parameters

tensor : torch.Tensor the tensor to apply the threshold to soft_threshold : float the soft threshold hard_threshold : float the hard threshold low : bool, default True whether to select values below the threshold (True) or above (False) error_mask : torch.Tensor, optional a boolean mask to apply to the results min_frames : int, default 0 intervals that are min_frames frames long or shorter are discarded smooth_interval : int, default 0 the number of frames to smooth the tensor with masked_intervals : list, default None a list of (start, end) intervals to mask out

Returns

indices_start : list a list of indices of the first frames of the chosen intervals indices_end : list a list of indices one past the last frames of the chosen intervals (exclusive ends)

def apply_threshold_max( tensor: torch.Tensor, threshold: float, main_class: int, error_mask: torch.Tensor = None, min_frames: int = 0, smooth_interval: int = 0, masked_intervals: List = None):
def apply_threshold_max(
    tensor: torch.Tensor,
    threshold: float,
    main_class: int,
    error_mask: torch.Tensor = None,
    min_frames: int = 0,
    smooth_interval: int = 0,
    masked_intervals: List = None,
):
    """
    Apply a max hysteresis threshold to a tensor and return indices of the intervals that passed.

    In the chosen intervals the values at the `main_class` index are larger than the others everywhere
    and at least one value at the `main_class` index is strictly above the threshold.
    If `error_mask` is not `None`, the elements marked `False` are treated as if they did not pass the threshold.
    If `min_frames` is not 0, the intervals are additionally filtered by length.

    Parameters
    ----------
    tensor : torch.Tensor
        the tensor to apply the threshold to (of shape `(#classes, #frames)`)
    threshold : float
        the threshold; if `None`, only the "largest class" condition is applied
    main_class : int
        the class that conditions the soft threshold
    error_mask : torch.Tensor, optional
        a boolean mask to apply to the results
    min_frames : int, default 0
        intervals that are `min_frames` frames long or shorter are discarded
    smooth_interval : int, default 0
        the number of frames to smooth the tensor with (see `smooth`)
    masked_intervals : list, default None
        a list of `(start, end)` intervals to mask out

    Returns
    -------
    indices_start : list
        a list of indices of the first frames of the chosen intervals (along dimension 1 of input tensor)
    indices_end : list
        a list of indices one past the last frames of the chosen intervals (exclusive ends, along dimension 1
        of input tensor)
    """
    if masked_intervals is None:
        masked_intervals = []
    _, argmax = torch.max(tensor, dim=0)
    p = argmax == main_class
    p = smooth(p, smooth_interval)
    if error_mask is not None:
        p = p * error_mask
    for start, end in masked_intervals:
        p[start:end] = False
    # run-length encode p; run boundaries follow from the cumulative run lengths, which avoids
    # scanning the whole tensor once per interval
    output, counts = torch.unique_consecutive(p, return_counts=True)
    ends = torch.cumsum(counts, dim=0)
    starts = ends - counts
    long_runs = torch.where(output * (counts > min_frames))[0]
    indices_start = starts[long_runs].tolist()
    indices_end = ends[long_runs].tolist()
    if threshold is None:
        return indices_start, indices_end
    # the hard condition uses a strict comparison
    hard = tensor[main_class, :] > threshold
    res_indices_start = []
    res_indices_end = []
    for start, end in zip(indices_start, indices_end):
        if torch.sum(hard[start:end]) > 0:
            res_indices_start.append(start)
            res_indices_end.append(end)
    return res_indices_start, res_indices_end

Apply a max hysteresis threshold to a tensor and return indices of the intervals that passed.

In the chosen intervals the values at the main_class index are larger than the others everywhere and at least one value at the main_class index is strictly above the threshold (this condition is skipped if threshold is None). If error_mask is not None, the elements marked False are treated as if they did not pass the threshold. If min_frames is not 0, the intervals are additionally filtered by length.

Parameters

tensor : torch.Tensor the tensor to apply the threshold to (of shape (#classes, #frames)) threshold : float the threshold main_class : int the class that conditions the soft threshold error_mask : torch.Tensor, optional a boolean mask to apply to the results min_frames : int, default 0 intervals that are min_frames frames long or shorter are discarded smooth_interval : int, default 0 the number of frames to smooth the tensor with masked_intervals : list, default None a list of (start, end) intervals to mask out

Returns

indices_start : list a list of indices of the first frames of the chosen intervals (along dimension 1 of input tensor) indices_end : list a list of indices one past the last frames of the chosen intervals (exclusive ends, along dimension 1 of input tensor)

def strip_suffix(text: str, suffix: Iterable):
def strip_suffix(text: str, suffix: Iterable):
    """
    Strip a suffix from a string if it is contained in a list

    The suffixes are checked in order and only the first matching one is removed.

    Parameters
    ----------
    text : str
        the main string
    suffix : iterable
        the list of suffixes to be stripped (`None` is treated as an empty list)

    Returns
    -------
    result : str
        the stripped string
    """

    if suffix is None:
        suffix = []
    for s in suffix:
        if text.endswith(s):
            # slice by the remaining length: `text[:-len(s)]` would return "" for an empty suffix
            return text[: len(text) - len(s)]
    return text

Strip a suffix from a string if it is contained in a list

Parameters

text : str the main string suffix : iterable the list of suffixes to be stripped

Returns

result : str the stripped string

def strip_prefix(text: str, prefix: Iterable):
def strip_prefix(text: str, prefix: Iterable):
    """
    Remove the first matching prefix from a string

    Parameters
    ----------
    text : str
        the string to process
    prefix : iterable
        candidate prefixes, checked in order; `None` means there are no prefixes

    Returns
    -------
    result : str
        `text` without the first prefix it starts with, or `text` unchanged if none matches
    """

    candidates = [] if prefix is None else prefix
    matched = next((s for s in candidates if text.startswith(s)), None)
    if matched is None:
        return text
    return text[len(matched) :]

Strip a prefix from a string if it is contained in a list

Parameters

text : str the main string prefix : iterable the list of prefixes to be stripped

Returns

result : str the stripped string

def rotation_matrix_2d(angles: torch.Tensor) -> torch.Tensor:
def rotation_matrix_2d(angles: torch.Tensor) -> torch.Tensor:
    """
    Build one 2D rotation matrix per angle

    Parameters
    ----------
    angles : torch.Tensor
        rotation angles (passed to `torch.cos`/`torch.sin`) in a tensor of any shape `(...)`

    Returns
    -------
    rotation_matrices : torch.Tensor
        the matrices `[[cos, -sin], [sin, cos]]`, in a tensor of shape `(..., 2, 2)`
    """

    cosines = torch.cos(angles)
    sines = torch.sin(angles)
    first_row = torch.stack([cosines, -sines], dim=-1)
    second_row = torch.stack([sines, cosines], dim=-1)
    return torch.stack([first_row, second_row], dim=-2)

Create a tensor of 2D rotation matrices from a tensor of angles

Parameters

angles : torch.Tensor a tensor of angles of arbitrary shape (...)

Returns

rotation_matrices : torch.Tensor a tensor of 2D rotation matrices of shape (..., 2, 2)

def rotation_matrix_3d(alpha: torch.Tensor, beta: torch.Tensor, gamma: torch.Tensor):
def rotation_matrix_3d(alpha: torch.Tensor, beta: torch.Tensor, gamma: torch.Tensor):
    """
    Create a tensor of 3D rotation matrices from a tensor of angles

    The result is the product `Rx(alpha) @ Ry(beta) @ Rz(gamma)`. The matrices are created on the device and
    with the dtype of the (cosines of the) angles.

    Parameters
    ----------
    alpha : torch.Tensor
        a tensor of rotation angles around the x axis of arbitrary shape `(...)`
    beta : torch.Tensor
        a tensor of rotation angles around the y axis of arbitrary shape `(...)`
    gamma : torch.Tensor
        a tensor of rotation angles around the z axis of arbitrary shape `(...)`

    Returns
    -------
    rotation_matrices : torch.Tensor
        a tensor of 3D rotation matrices of shape `(..., 3, 3)`
    """

    cos, sin = torch.cos(alpha), torch.sin(alpha)
    # *_like keeps the constant entries on the same device/dtype as the angles
    one, zero = torch.ones_like(cos), torch.zeros_like(cos)
    Rx = torch.stack(
        [one, zero, zero, zero, cos, -sin, zero, sin, cos], dim=-1
    ).reshape(*alpha.shape, 3, 3)

    cos, sin = torch.cos(beta), torch.sin(beta)
    one, zero = torch.ones_like(cos), torch.zeros_like(cos)
    Ry = torch.stack(
        [cos, zero, sin, zero, one, zero, -sin, zero, cos], dim=-1
    ).reshape(*beta.shape, 3, 3)

    cos, sin = torch.cos(gamma), torch.sin(gamma)
    one, zero = torch.ones_like(cos), torch.zeros_like(cos)
    Rz = torch.stack(
        [cos, -sin, zero, sin, cos, zero, zero, zero, one], dim=-1
    ).reshape(*gamma.shape, 3, 3)

    R = torch.einsum("...ij,...jk,...kl->...il", Rx, Ry, Rz)
    return R

Create a tensor of 3D rotation matrices from a tensor of angles

Parameters

alpha : torch.Tensor a tensor of rotation angles around the x axis of arbitrary shape (...) beta : torch.Tensor a tensor of rotation angles around the y axis of arbitrary shape (...) gamma : torch.Tensor a tensor of rotation angles around the z axis of arbitrary shape (...)

Returns

rotation_matrices : torch.Tensor a tensor of 3D rotation matrices of shape (..., 3, 3)

def correct_path(path, project_path):
def correct_path(path, project_path):
    """
    Correct a path to a file or directory to be relative to the project path.

    If the path contains a `results` or a `saved_datasets` directory component, everything after that
    component is re-rooted at `project_path/<component>`. Other paths are only normalized, and non-string
    inputs are returned unchanged.

    Parameters
    ----------
    path : str
        the path to be corrected
    project_path : str
        the path to the project root directory

    Returns
    -------
    corrected_path : str
        the corrected path
    """
    if not isinstance(path, str):
        return path
    path_split = os.path.normpath(path).split(os.path.sep)
    # compare whole path components (a substring test would match e.g. "my_saved_datasets_old"
    # and make the `.index` call below raise)
    if "results" in path_split:
        name = "results"
    elif "saved_datasets" in path_split:
        name = "saved_datasets"
    else:
        return os.path.sep.join(path_split)  # No need for correction
    ind = path_split.index(name) + 1
    return os.path.join(project_path, name, *path_split[ind:])

Correct a path to a file or directory to be relative to the project path.

Parameters

path : str the path to be corrected project_path : str the path to the project root directory

Returns

corrected_path : str the corrected path

class TensorList(builtins.list):
class TensorList(list):
    """
    A list of tensors that can send each element to a `torch` device
    """

    def to_device(self, device: torch.device):
        """
        Send each element of the list to a `torch` device (in place).

        `torch.Tensor` elements are moved with `Tensor.to`; any other element is expected to provide a
        `to_device` method that returns the moved object.
        """
        for i, x in enumerate(self):
            # torch.Tensor has no `to_device` method, so tensors are moved with `Tensor.to`
            self[i] = x.to(device) if isinstance(x, torch.Tensor) else x.to_device(device)

A list of tensors that can send each element to a torch device

def to_device(self, device: torch.device):
541    def to_device(self, device: torch.device):
542        """Send each element of the list to a `torch` device."""
543        for i, x in enumerate(self):
544            self[i] = x.to_device(device)

Send each element of the list to a torch device.

def get_intervals(tensor: torch.Tensor) -> torch.Tensor:
def get_intervals(tensor: torch.Tensor) -> torch.Tensor:
    """
    Get the beginning and end indices of every group of consecutive True values in a boolean tensor

    Parameters
    ----------
    tensor : torch.Tensor
        a 1D tensor; nonzero values count as True

    Returns
    -------
    intervals : torch.Tensor
        an integer tensor of shape `(#groups, 2)` where each row holds the start index and the exclusive end
        index of one group (shape `(0, 2)` if there are no True values)
    """

    output, counts = torch.unique_consecutive(tensor, return_counts=True)
    # run boundaries follow from the cumulative run lengths (one pass instead of one scan per group)
    ends = torch.cumsum(counts, dim=0)
    starts = ends - counts
    true_groups = torch.where(output)[0]
    return torch.stack([starts[true_groups], ends[true_groups]], dim=1)

Get the start and (exclusive) end indices of every group of consecutive True values in a boolean tensor, returned as a tensor of shape (#groups, 2)

def smooth(tensor: torch.Tensor, smooth_interval: int = 0) -> torch.Tensor:
def smooth(tensor: torch.Tensor, smooth_interval: int = 0) -> torch.Tensor:
    """
    Get rid of jittering in a non-exclusive classification tensor

    Runs of 0 that are at most `smooth_interval` frames long are first filled with 1; afterwards, runs of 1
    that are at most `smooth_interval` frames long are set to 0. The input tensor is modified in place and
    returned; with `smooth_interval == 0` it is returned untouched.
    """

    if smooth_interval == 0:
        return tensor
    # first close short gaps, then drop short bursts (the second pass sees the result of the first)
    for value, replacement in ((0, 1), (1, 0)):
        runs = get_intervals(tensor == value)
        run_lengths = torch.tensor([run[1] - run[0] for run in runs])
        for start, end in runs[run_lengths <= smooth_interval]:
            tensor[start:end] = replacement
    return tensor

Get rid of jittering in a non-exclusive classification tensor

First, fill intervals of 0 that are at most smooth_interval frames long with 1. Then, set intervals of 1 that are at most smooth_interval frames long to 0. The input tensor is modified in place and returned.

class GaussianSmoothing(torch.nn.modules.module.Module):
class GaussianSmoothing(nn.Module):
    """Apply gaussian smoothing on a 1d tensor.

    Filtering is performed separately for each channel
    in the input using a depthwise convolution.

    Parameters
    ----------
    kernel_size : int, default 15
        size of the gaussian kernel
    sigma : float, default 1.0
        standard deviation of the gaussian kernel
    """

    def __init__(self, kernel_size: int = 15, sigma: float = 1.0) -> None:
        super().__init__()
        self.kernel_size = kernel_size

        # sample the gaussian density at the integer positions 0..kernel_size-1, centered on the middle
        positions = torch.arange(kernel_size).float()
        mean = (kernel_size - 1) / 2
        scale = 1 / (sigma * math.sqrt(2 * math.pi))
        kernel = scale * torch.exp(-(((positions - mean) / sigma) ** 2) / 2)

        # NOTE: the kernel is not normalized, so it sums to 1 only approximately
        # (when kernel_size is large compared to sigma)

        self.kernel = kernel.view(1, 1, *kernel.size())

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """
        Apply gaussian filter to input.

        Parameters
        ----------
        inputs : torch.Tensor
            input of shape `(batch, channels, length)` to apply the gaussian filter on

        Returns
        -------
        filtered : torch.Tensor
            filtered output of the same shape as `inputs`
        """
        _, c, _ = inputs.shape
        # split the padding so that the output keeps the input length for even kernel sizes too
        left = (self.kernel_size - 1) // 2
        right = self.kernel_size // 2
        inputs = F.pad(inputs, pad=(left, right), mode="reflect")
        # one copy of the kernel per channel (depthwise convolution), matching the input device and dtype
        kernel = self.kernel.repeat(c, *[1] * (self.kernel.dim() - 1)).to(
            device=inputs.device, dtype=inputs.dtype
        )
        return F.conv1d(inputs, weight=kernel, groups=c)

Apply gaussian smoothing on a 1d tensor.

Filtering is performed separately for each channel in the input using a depthwise convolution. Arguments: kernel_size (int): Size of the gaussian kernel. sigma (float): Standard deviation of the gaussian kernel.

GaussianSmoothing(kernel_size: int = 15, sigma: float = 1.0)
602    def __init__(self, kernel_size: int = 15, sigma: float = 1.0) -> None:
603        super().__init__()
604        self.kernel_size = kernel_size
605
606        # The gaussian kernel is the product of the
607        # gaussian function of each dimension.
608        kernel = 1
609        meshgrid = torch.meshgrid(torch.arange(kernel_size))[0].float()
610
611        mean = (kernel_size - 1) / 2
612        kernel = kernel / (sigma * math.sqrt(2 * math.pi))
613        kernel = kernel * torch.exp(-(((meshgrid - mean) / sigma) ** 2) / 2)
614
615        # Make sure sum of values in gaussian kernel equals 1.
616        # kernel = kernel / torch.max(kernel)
617
618        self.kernel = kernel.view(1, 1, *kernel.size())

Build the 1D gaussian kernel by sampling the gaussian density at kernel_size integer positions around its center (the kernel is not normalized).

kernel_size
kernel
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
    """
    Apply gaussian filter to input.

    Each channel is filtered independently (depthwise convolution with
    ``groups=c``) after reflect-padding the time axis, so the output has the
    same length as the input for both odd and even ``kernel_size``.

    Parameters
    ----------
    inputs : torch.Tensor
        Input to apply gaussian filter on, of shape (N, C, T).

    Returns
    -------
    filtered : torch.Tensor
        Filtered output of shape (N, C, T).

    """
    _, c, _ = inputs.shape
    # conv1d removes kernel_size - 1 frames; split that padding asymmetrically
    # so an even kernel size does not shorten the sequence by one frame
    # (for odd sizes this is the same symmetric padding as before).
    left = (self.kernel_size - 1) // 2
    right = self.kernel_size - 1 - left
    inputs = F.pad(inputs, pad=(left, right), mode="reflect")
    # One copy of the kernel per channel, on the input's device AND dtype so
    # that e.g. float64 inputs do not fail in conv1d with a dtype mismatch.
    kernel = self.kernel.repeat(c, *[1] * (self.kernel.dim() - 1)).to(
        device=inputs.device, dtype=inputs.dtype
    )
    return F.conv1d(inputs, weight=kernel, groups=c)

Apply gaussian filter to input.

Parameters

inputs : torch.Tensor Input to apply gaussian filter on.

Returns

filtered : torch.Tensor Filtered output.

def argrelmax(prob: numpy.ndarray, threshold: float = 0.7) -> List[int]:
def argrelmax(prob: np.ndarray, threshold: float = 0.7) -> List[int]:
    """
    Calculate arguments of relative maxima.

    The first frame is always reported as a boundary; the last frame never is.
    An interior frame is a peak when it is strictly greater than both of its
    neighbours (after thresholding).

    Parameters
    ----------
    prob : np.ndarray
        boundary probability map with values in [0, 1], shape (T,)
    threshold : float, default 0.7
        values below this threshold are set to 0 before peak detection

    Returns
    -------
    peak_idx : List[int]
        indices of the peaks in increasing order (empty for an empty input)

    """
    # Work on a copy (same dtype): zeroing in place would silently modify
    # the caller's array.
    prob = np.array(prob, copy=True)

    # ignore the values under threshold
    prob[prob < threshold] = 0.0

    length = prob.shape[0]
    if length == 0:
        return []

    # `bool` instead of the `np.bool` alias, which was removed in NumPy 1.24.
    peak = np.zeros(length, dtype=bool)
    # treat the first frame as boundary
    peak[0] = True
    # relative maxima of the interior frames (empty slices when length < 3)
    peak[1:-1] = (prob[:-2] < prob[1:-1]) & (prob[2:] < prob[1:-1])

    return np.flatnonzero(peak).tolist()

Calculate arguments of relative maxima.

Parameters

prob : np.ndarray boundary probability map with values in [0, 1], shape (T) threshold : float, default 0.7 values below this threshold are set to 0 before peak detection

Returns

peak_idx : List[int] Indices of the peaks

def decide_boundary_prob_with_similarity(x: torch.Tensor) -> torch.Tensor:
def decide_boundary_prob_with_similarity(x: torch.Tensor) -> torch.Tensor:
    """
    Decide action boundary probabilities based on adjacent frame similarities.

    Args:
        x: frame-wise video features (N, C, T)
    Return:
        boundary: action boundary probability (N, 1, T). Frame 0 of every
        sequence is always 1; frame t > 0 holds
        exp(-||x[:, :, t] - x[:, :, t - 1]|| / 2), the norm taken over channels.
    """
    # Differences between consecutive frames for every sequence in the batch
    # (previously only x[0] was used, so the result was always (1, 1, T)).
    diff = x[:, :, 1:] - x[:, :, :-1]  # (N, C, T - 1)

    # gaussian kernel.
    similarity = torch.exp(-torch.norm(diff, dim=1) / (2 * 1.0))  # (N, T - 1)

    # define action starting point as action boundary.
    start = torch.ones(x.shape[0], 1, dtype=similarity.dtype, device=x.device)
    boundary = torch.cat([start, similarity], dim=1)  # (N, T)
    return boundary.unsqueeze(1)

Decide action boundary probabilities based on adjacent frame similarities. Args: x: frame-wise video features (N, C, T) Return: boundary: action boundary probability (N, 1, T)

class PostProcessor:
class PostProcessor(object):
    """
    Post-process frame-wise action segmentation outputs.

    The strategy is selected by ``name``:

    - ``"refinement_with_boundary"``: split each sequence at peaks of a boundary
      probability map and give every segment its majority class,
    - ``"relabeling"``: relabel segments of at most ``theta_t`` frames with the
      label of the preceding segment,
    - ``"smoothing"``: smooth class probabilities along time with a gaussian
      filter before taking the argmax.

    Calling the instance dispatches to the selected strategy and returns
    integer (int64) labels.
    """

    def __init__(
        self,
        name: str,
        boundary_th: float = 0.7,
        theta_t: int = 15,
        kernel_size: int = 15,
    ) -> None:
        """
        Parameters
        ----------
        name : {"refinement_with_boundary", "relabeling", "smoothing"}
            the post-processing strategy
        boundary_th : float, default 0.7
            boundary probabilities below this value are ignored when looking
            for segment boundaries (``"refinement_with_boundary"`` only)
        theta_t : int, default 15
            segments of at most this many frames are relabeled
            (``"relabeling"`` only)
        kernel_size : int, default 15
            size of the gaussian kernel (``"smoothing"`` only)
        """
        self.func = {
            "refinement_with_boundary": self._refinement_with_boundary,
            "relabeling": self._relabeling,
            "smoothing": self._smoothing,
        }
        assert name in self.func

        self.name = name
        self.boundary_th = boundary_th
        self.theta_t = theta_t
        self.kernel_size = kernel_size

        if name == "smoothing":
            self.filter = GaussianSmoothing(self.kernel_size)

    def _is_probability(self, x: np.ndarray) -> bool:
        """
        Check whether ``x`` (N, C, T) already holds probabilities.

        With a single channel, all values must lie in [0, 1] (sigmoid output);
        otherwise they must sum to 1 over the channel axis (softmax output).
        """
        assert x.ndim == 3

        if x.shape[1] == 1:
            # sigmoid
            return bool(x.min() >= 0 and x.max() <= 1)
        # softmax
        channel_sums = np.sum(x, axis=1).astype(np.float32)
        return np.allclose(
            channel_sums, np.ones_like(channel_sums, dtype=np.float32)
        )

    def _convert2probability(self, x: np.ndarray) -> np.ndarray:
        """
        Convert logits to float32 probabilities.

        Inputs that already are probabilities are returned unchanged (the same
        object, not a copy).

        Args: x (N, C, T)
        """
        assert x.ndim == 3

        if self._is_probability(x):
            return x
        if x.shape[1] == 1:
            # sigmoid
            prob = 1 / (1 + np.exp(-x))
        else:
            # softmax over channels. Subtracting the per-frame maximum does not
            # change the result mathematically but prevents np.exp overflow,
            # which produced inf / inf = nan for large logits.
            shifted = np.exp(x - np.max(x, axis=1, keepdims=True))
            prob = shifted / np.sum(shifted, axis=1, keepdims=True)
        return prob.astype(np.float32)

    def _convert2label(self, x: np.ndarray) -> np.ndarray:
        """
        Turn labels (N, T) or scores (N, C, T) into int64 labels (N, T).

        Scores are converted with an argmax over the channel axis.
        NOTE(review): with a single channel the argmax is always 0 — confirm
        that single-channel inputs are never passed here.
        """
        assert x.ndim == 2 or x.ndim == 3

        if x.ndim == 2:
            return x.astype(np.int64)
        if not self._is_probability(x):
            x = self._convert2probability(x)
        label = np.argmax(x, axis=1)
        return label.astype(np.int64)

    def _refinement_with_boundary(
        self,
        outputs: np.ndarray,
        boundaries: np.ndarray,
        **kwargs: np.ndarray,
    ) -> np.ndarray:
        """
        Get segments which are defined as the span between two boundaries,
        and decide their classes by majority vote.

        Args:
            outputs: numpy array. shape (N, C, T) or labels (N, T)
                the model output for frame-level class prediction.
            boundaries: numpy array. shape (N, 1, T)
                boundary prediction.
            **kwargs: ignored; accepted so that every strategy can be called
                with the same keyword arguments.
        Return:
            preds: np.ndarray. shape (N, T)
                final class prediction considering boundaries.
        """
        preds = self._convert2label(outputs)
        boundaries = self._convert2probability(boundaries)

        for i, (output, pred, boundary) in enumerate(zip(outputs, preds, boundaries)):
            # reshape(-1) also works for T == 1 (squeeze would give a 0-d
            # array); the copy keeps argrelmax's thresholding away from the
            # caller's boundary array.
            idx = argrelmax(boundary.reshape(-1).copy(), threshold=self.boundary_th)

            # add the index of the last action ending
            idx.append(pred.shape[0])

            # majority vote
            for start, end in zip(idx[:-1], idx[1:]):
                if start >= end:
                    continue
                count = np.bincount(pred[start:end])
                modes = np.flatnonzero(count == count.max())
                if len(modes) == 1:
                    mode = modes[0]
                elif outputs.ndim == 3:
                    # if more than one majority class exists, pick the one with
                    # the largest summed score over the segment (first on ties).
                    # Comparing against each other rather than against 0 keeps
                    # `mode` well-defined even when all scores are <= 0.
                    scores = [output[m, start:end].sum() for m in modes]
                    mode = modes[int(np.argmax(scores))]
                else:
                    # decide first mode when more than one majority class
                    # have the same number during oracle experiment
                    mode = modes[0]

                preds[i, start:end] = mode

        return preds

    def _relabeling(self, outputs: np.ndarray, **kwargs: np.ndarray) -> np.ndarray:
        """
        Relabel small action segments with their previous action segment.

        A segment of at most ``self.theta_t`` frames takes the label of the
        frame just before it. The first segment has no predecessor and is
        left unchanged.

        Args:
            outputs: the results of action segmentation. (N, T) or (N, C, T)
        Return:
            relabeled output. (N, T)
        """
        preds = self._convert2label(outputs)

        for row in preds:  # rows are views, so edits write through to preds
            length = row.shape[0]
            start = 0  # first frame of the current segment
            for j in range(1, length + 1):
                if j < length and row[j] == row[start]:
                    continue
                # the segment row[start:j] ends here (or at the sequence end)
                if start > 0 and j - start <= self.theta_t:
                    row[start:j] = row[start - 1]
                start = j

        return preds

    def _smoothing(self, outputs: np.ndarray, **kwargs: np.ndarray) -> np.ndarray:
        """
        Smooth action probabilities along time with a gaussian filter.

        Args:
            outputs: frame-wise action scores. (N, C, T)
        Return:
            predictions: final prediction. (N, T)
        """
        outputs = self._convert2probability(outputs)
        outputs = self.filter(torch.Tensor(outputs)).numpy()

        preds = self._convert2label(outputs)
        return preds

    def __call__(self, outputs, **kwargs: np.ndarray) -> np.ndarray:
        """
        Run the configured strategy on ``outputs``.

        Keyword arguments (e.g. ``boundaries`` for
        ``"refinement_with_boundary"``) are forwarded to the strategy.
        """
        preds = self.func[self.name](outputs, **kwargs)
        return preds
PostProcessor( name: str, boundary_th: float = 0.7, theta_t: int = 15, kernel_size: int = 15)
def __init__(
    self,
    name: str,
    boundary_th: float = 0.7,
    theta_t: int = 15,
    kernel_size: int = 15,
) -> None:
    """
    Parameters
    ----------
    name : {"refinement_with_boundary", "relabeling", "smoothing"}
        the post-processing strategy
    boundary_th : float, default 0.7
        boundary probabilities below this value are ignored when looking for
        segment boundaries (``"refinement_with_boundary"`` only)
    theta_t : int, default 15
        segments of at most this many frames are relabeled
        (``"relabeling"`` only)
    kernel_size : int, default 15
        size of the gaussian kernel (``"smoothing"`` only)
    """
    # dispatch table used by __call__
    self.func = {
        "refinement_with_boundary": self._refinement_with_boundary,
        "relabeling": self._relabeling,
        "smoothing": self._smoothing,
    }
    assert name in self.func

    self.name = name
    self.boundary_th = boundary_th
    self.theta_t = theta_t
    self.kernel_size = kernel_size

    # the filter is only needed (and only built) for the smoothing strategy
    if name == "smoothing":
        self.filter = GaussianSmoothing(self.kernel_size)
func
name
boundary_th
theta_t
kernel_size
def get_sentences_from_article(pdf_filepath: str, max_symbols=77, min_symbols=15) -> List:
def get_sentences_from_article(
    pdf_filepath: str, max_symbols=77, min_symbols=15
) -> List:
    """
    Read a pdf file, clean the text from references and split into sentences

    Parameters
    ----------
    pdf_filepath : str
        the path to the pdf file
    max_symbols : int, default 77
        sentences longer than this number of symbols will be omitted
    min_symbols : int, default 15
        sentences shorter than this number of symbols will be omitted

    Returns
    -------
    sentences : list
        the unique sentences in order of their first appearance in the text
        (empty if no text could be extracted)
    """

    from tika import parser

    parsed = parser.from_file(pdf_filepath)
    # guard against a missing or None "content" (no extractable text) instead
    # of crashing on .replace()
    content = parsed.get("content") or ""

    text = content.replace("-\n", "")  # re-join words hyphenated at line breaks
    text = text.replace("\n", " ")
    # drop commas between digits (e.g. "1,000" -> "1000")
    text = re.sub(r"(?<=\d),(?=\d)", "", text)
    # keep only what precedes the first occurrence of "References"
    text = text.split("References")[0]
    # author-year citations such as " (Smith & Jones, 2020)"
    text = re.sub(r" \([\w&, ;]+\d{4}\)", "", text).strip()
    # numeric citations such as " [12]"
    text = re.sub(r" \[\d+\]", "", text).strip()

    sentences = [
        x
        for x in (part.strip() for part in text.split("."))
        if not x.endswith(" al")  # fragment cut at "et al."
        and min_symbols <= len(x) <= max_symbols
        and sum(c.isdigit() for c in x) < len(x) * 0.2
        and sum(len(w) < 2 for w in x.split()) < len(x) * 0.2
    ]
    # dict.fromkeys removes duplicates like set() did, but with a
    # deterministic (first-appearance) order
    return list(dict.fromkeys(sentences))

Read a pdf file, clean the text from references and split into sentences

Parameters

pdf_filepath : str the path to the pdf file max_symbols : int, default 77 sentences longer than this number of symbols will be omitted min_symbols : int, default 15 sentences shorter than this number of symbols will be omitted

Returns

sentences : list a list of sentences

def load_pickle(file_path: str):
def load_pickle(file_path: str):
    """
    Load a pickle file

    Parameters
    ----------
    file_path : str
        path of the file to unpickle

    Returns
    -------
    obj : any
        the object stored in the file

    Notes
    -----
    Unpickling can execute arbitrary code; only use this on trusted files.
    """
    with open(file_path, "rb") as stream:
        return pickle.load(stream)

Load a pickle file

Parameters

file_path : str the path to the pickle file

Returns

obj : any the loaded object

def binarize_data(data: list, max_frame: int = -1) -> numpy.ndarray:
def binarize_data(data: list, max_frame: int = -1) -> np.ndarray:
    """
    Convert (start, stop, confusion) annotations into a binary label matrix.

    Parameters
    ----------
    data : list
        loaded annotation structure; ``data[3][0]`` must be a sequence with one
        entry per class, each an iterable of ``(start, stop, confusion)`` rows
    max_frame : int, default -1
        number of columns (time points) of the output; if -1, the largest value
        found anywhere in the annotations is used (0 if there are none)

    Returns
    -------
    label_array : np.ndarray
        float array of shape (N_classes, N_timepoints) holding 1 on
        ``[start, stop)`` for every annotation whose confusion is below 0.5,
        and 0 elsewhere
    """
    annotations = data[3][0]
    non_empty = [np.asarray(arr) for arr in annotations if len(arr)]
    if max_frame == -1:
        # np.concatenate fails on an empty list, so handle "no annotations" here
        max_frame = np.max(np.concatenate(non_empty, axis=0)) if non_empty else 0
    # int() casts keep float-typed annotation arrays (e.g. float confusion
    # values) from breaking np.zeros and the slicing below
    label_array = np.zeros((len(annotations), int(max_frame)))
    for k, arr in enumerate(annotations):
        for start, stop, confusion in arr:
            if confusion < 0.5:
                label_array[k, int(start) : int(stop)] = 1
    return label_array

Convert per-class (start, stop, confusion) annotations into a binary matrix of shape (N_classes, N_timepoints).