dlc2action.utils

Utility functions

  1#
  2# Copyright 2020-2022 by A. Mathis Group and contributors. All rights reserved.
  3#
  4# This project and all its files are licensed under GNU AGPLv3 or later version. A copy is included in dlc2action/LICENSE.AGPL.
  5#
  6"""
  7## Utility functions
  8
  9- `TensorDict` is a convenient data structure for keeping (and indexing) lists of feature dictionaries,
 10- `apply_threshold`, `apply_threshold_hysteresis` and `apply_threshold_max` are utility functions for
 11`dlc2action.data.dataset.BehaviorDataset.find_valleys`,
 12- `strip_suffix` is used to get rid of suffixes if a string (usually filename) ends with one of them,
 13- `strip_prefix` is used to get rid of prefixes if a string (usually filename) starts with one of them,
 14- `rotation_matrix_2d` and `rotation_matrix_3d` are used to generate rotation matrices by
 15`dlc2action.transformer.base_transformer.Transformer` instances
 16"""
 17
 18import torch
 19from typing import List, Dict, Union
 20from collections.abc import Iterable
 21import warnings
 22import os
 23import numpy as np
 24from torch import nn
 25from torch.nn import functional as F
 26import math
 27
 28
 29class TensorDict:
 30    """
 31    A class that handles indexing in a dictionary of tensors of the same length
 32    """
 33
 34    def __init__(self, obj: Union[Dict, Iterable] = None) -> None:
 35        """
 36        Parameters
 37        ----------
 38        obj : dict | iterable, optional
 39            either a dictionary of torch.Tensor instances of the same length or an iterable of dictionaries with
 40            the same keys (if not passed, a blank TensorDict is initialized)
 41        """
 42
 43        if obj is None:
 44            obj = {}
 45        if isinstance(obj, list):
 46            self.keys = list(obj[0].keys())
 47            self.dict = {key: [] for key in self.keys}
 48            with warnings.catch_warnings():
 49                warnings.filterwarnings("ignore", category=UserWarning)
 50                for element in obj:
 51                    for key in self.keys:
 52                        self.dict[key].append(torch.tensor(element[key]))
 53            # self.dict = {k: torch.stack(v) for k, v in self.dict.items()}
 54            old_dict = self.dict
 55            self.dict = {}
 56            for k, v in old_dict.items():
 57                self.dict[k] = torch.stack(v)
 58        elif isinstance(obj, dict):
 59            if not all([isinstance(obj[key], torch.Tensor) for key in obj]):
 60                raise TypeError(
 61                    f"The values in the dictionary passed to TensorDict need to be torch.Tensor instances;"
 62                    f"got {[type(obj[key]) for key in obj]}"
 63                )
 64            lengths = [len(obj[key]) for key in obj]
 65            if not all([x == lengths[0] for x in lengths]):
 66                raise ValueError(
 67                    f"The value tensors in the dictionary passed to TensorDict need to have the same length;"
 68                    f"got {lengths}"
 69                )
 70            self.dict = obj
 71            self.keys = list(self.dict.keys())
 72        else:
 73            raise TypeError(
 74                f"TensorDict can only be constructed from an iterable of dictionaries of from a dictionary "
 75                f"of tensors (got {type(obj)})"
 76            )
 77        if len(self.keys) > 0:
 78            self.type = type(self.dict[self.keys[0]])
 79        else:
 80            self.type = None
 81
 82    def __len__(self) -> int:
 83        return len(self.dict[self.keys[0]])
 84
 85    def __getitem__(self, ind: Union[int, List]):
 86        """
 87        Index the TensorDict
 88
 89        Parameters
 90        ----------
 91        ind : int | list
 92            the index/indices
 93
 94        Returns
 95        -------
 96        dict | TensorDict
 97            the indexed elements of all value lists combined in a dictionary (if length of ind is 1) or a TensorDict
 98        """
 99
100        x = {key: self.dict[key][ind] for key in self.dict}
101        if type(ind) is not int:
102            d = TensorDict(x)
103            return d
104        else:
105            return x
106
107    def append(self, element: Dict) -> None:
108        """
109        Append an element
110
111        Parameters
112        ----------
113        element : dict
114            a dictionary
115        """
116
117        type_element = type(element[list(element.keys())[0]])
118        if self.type is None:
119            self.dict = {k: v.unsqueeze(0) for k, v in element.items()}
120            self.keys = list(element.keys())
121            self.type = type_element
122        else:
123            for key in self.keys:
124                if key not in element:
125                    raise ValueError(
126                        f"The dictionary appended to TensorDict needs to have the same keys as the "
127                        f"TensorDict; got {element.keys()} and {self.keys}"
128                    )
129                self.dict[key] = torch.cat([self.dict[key], element[key].unsqueeze(0)])
130
131    def remove(self, indices: List) -> None:
132        """
133        Remove indexed elements
134
135        Parameters
136        ----------
137        indices : list
138            the indices to remove
139        """
140
141        mask = torch.ones(len(self))
142        mask[indices] = 0
143        mask = mask.bool()
144        for key, value in self.dict.items():
145            self.dict[key] = value[mask]
146
147
def apply_threshold(
    tensor: torch.Tensor,
    threshold: float,
    low: bool = True,
    error_mask: torch.Tensor = None,
    min_frames: int = 0,
    smooth_interval: int = 0,
    masked_intervals: List = None,
):
    """
    Apply a hard threshold to a tensor and return indices of the intervals that passed

    If `error_mask` is not `None`, the elements marked `False` are treated as if they did not pass the threshold.
    If `min_frames` is not 0, the intervals are additionally filtered by length.

    Parameters
    ----------
    tensor : torch.Tensor
        the one-dimensional tensor to apply the threshold to
    threshold : float
        the threshold
    low : bool, default True
        if `True`, values lower than or equal to the threshold pass; otherwise values greater than or
        equal to it pass
    error_mask : torch.Tensor, optional
        a boolean mask to apply to the results
    min_frames : int, default 0
        only intervals with strictly more than `min_frames` frames are kept
    smooth_interval : int, default 0
        the interval passed to `smooth` (0 disables smoothing)
    masked_intervals : list, optional
        a list of `(start, end)` pairs; the frames in `[start, end)` are treated as not passing

    Returns
    -------
    indices_start : list
        a list of indices of the first frames of the chosen intervals
    indices_end : list
        a list of exclusive end indices of the chosen intervals (index of the last frame + 1)
    """

    if masked_intervals is None:
        masked_intervals = []
    passed = tensor <= threshold if low else tensor >= threshold
    if smooth_interval != 0:
        passed = smooth(passed, smooth_interval)
    if error_mask is not None:
        passed = passed * error_mask
    for start, end in masked_intervals:
        passed[start:end] = False
    # run-length encode the mask: the cumulative sum of the run lengths gives every run's exclusive
    # end, and subtracting the length gives its start (no per-run scan of the whole tensor)
    values, counts = torch.unique_consecutive(passed, return_counts=True)
    run_ends = torch.cumsum(counts, dim=0)
    run_starts = run_ends - counts
    chosen = torch.where(values * (counts > min_frames))[0]
    return run_starts[chosen].tolist(), run_ends[chosen].tolist()
204
205
def apply_threshold_hysteresis(
    tensor: torch.Tensor,
    soft_threshold: float,
    hard_threshold: float,
    low: bool = True,
    error_mask: torch.Tensor = None,
    min_frames: int = 0,
    smooth_interval: int = 0,
    masked_intervals: List = None,
):
    """
    Apply a hysteresis threshold to a tensor and return indices of the intervals that passed

    In the chosen intervals all values pass the soft threshold and at least one value passes the hard threshold.
    If `error_mask` is not `None`, the elements marked `False` are treated as if they did not pass the threshold.
    If `min_frames` is not 0, the intervals are additionally filtered by length.

    Parameters
    ----------
    tensor : torch.Tensor
        the one-dimensional tensor to apply the threshold to
    soft_threshold : float
        the soft threshold
    hard_threshold : float
        the hard threshold
    low : bool, default True
        if `True`, values lower than or equal to a threshold pass it; otherwise values greater than or
        equal to it pass
    error_mask : torch.Tensor, optional
        a boolean mask to apply to the results (it affects the soft threshold mask only)
    min_frames : int, default 0
        only intervals with strictly more than `min_frames` frames are kept
    smooth_interval : int, default 0
        the interval passed to `smooth` for the soft threshold mask (0 disables smoothing)
    masked_intervals : list, optional
        a list of `(start, end)` pairs; the frames in `[start, end)` are treated as not passing

    Returns
    -------
    indices_start : list
        a list of indices of the first frames of the chosen intervals
    indices_end : list
        a list of exclusive end indices of the chosen intervals (index of the last frame + 1)
    """

    if masked_intervals is None:
        masked_intervals = []
    if low:
        soft = tensor <= soft_threshold
        hard = tensor <= hard_threshold
    else:
        soft = tensor >= soft_threshold
        hard = tensor >= hard_threshold
    if smooth_interval != 0:
        soft = smooth(soft, smooth_interval)
    if error_mask is not None:
        soft = soft * error_mask
    for start, end in masked_intervals:
        soft[start:end] = False
    # run-length encode the soft mask: cumulative run lengths are the exclusive run ends
    values, counts = torch.unique_consecutive(soft, return_counts=True)
    run_ends = torch.cumsum(counts, dim=0)
    run_starts = run_ends - counts
    chosen = torch.where(values * (counts > min_frames))[0]
    res_indices_start = []
    res_indices_end = []
    for start, end in zip(run_starts[chosen].tolist(), run_ends[chosen].tolist()):
        # keep the interval only if at least one of its frames passes the hard threshold
        if torch.sum(hard[start:end]) > 0:
            res_indices_start.append(start)
            res_indices_end.append(end)
    return res_indices_start, res_indices_end
274
275
def apply_threshold_max(
    tensor: torch.Tensor,
    threshold: float,
    main_class: int,
    error_mask: torch.Tensor = None,
    min_frames: int = 0,
    smooth_interval: int = 0,
    masked_intervals: List = None,
):
    """
    Apply a max hysteresis threshold to a tensor and return indices of the intervals that passed

    In the chosen intervals the values at the `main_class` index are larger than the others everywhere
    and at least one value at the `main_class` index passes the threshold.
    If `error_mask` is not `None`, the elements marked `False` are treated as if they did not pass the threshold.
    If `min_frames` is not 0, the intervals are additionally filtered by length.

    Parameters
    ----------
    tensor : torch.Tensor
        the tensor to apply the threshold to (of shape `(#classes, #frames)`)
    threshold : float
        the threshold (strictly greater values pass); if `None`, the threshold condition is skipped
    main_class : int
        the class that conditions the soft threshold
    error_mask : torch.Tensor, optional
        a boolean mask to apply to the results
    min_frames : int, default 0
        only intervals with strictly more than `min_frames` frames are kept
    smooth_interval : int, default 0
        the interval passed to `smooth` (0 disables smoothing)
    masked_intervals : list, optional
        a list of `(start, end)` pairs; the frames in `[start, end)` are treated as not passing

    Returns
    -------
    indices_start : list
        a list of indices of the first frames of the chosen intervals (along dimension 1 of input tensor)
    indices_end : list
        a list of exclusive end indices of the chosen intervals (along dimension 1 of input tensor)
    """

    if masked_intervals is None:
        masked_intervals = []
    _, winners = torch.max(tensor, dim=0)
    is_main = winners == main_class
    if smooth_interval != 0:
        is_main = smooth(is_main, smooth_interval)
    if error_mask is not None:
        is_main = is_main * error_mask
    for start, end in masked_intervals:
        is_main[start:end] = False
    # run-length encode the mask: cumulative run lengths are the exclusive run ends
    values, counts = torch.unique_consecutive(is_main, return_counts=True)
    run_ends = torch.cumsum(counts, dim=0)
    run_starts = run_ends - counts
    chosen = torch.where(values * (counts > min_frames))[0]
    indices_start = run_starts[chosen].tolist()
    indices_end = run_ends[chosen].tolist()
    if threshold is None:
        return indices_start, indices_end
    hard = tensor[main_class, :] > threshold
    res_indices_start = []
    res_indices_end = []
    for start, end in zip(indices_start, indices_end):
        # keep the interval only if the main class exceeds the threshold somewhere inside it
        if torch.sum(hard[start:end]) > 0:
            res_indices_start.append(start)
            res_indices_end.append(end)
    return res_indices_start, res_indices_end
344
345
def strip_suffix(text: str, suffix: Iterable):
    """
    Strip a suffix from a string if it is contained in a list

    Only the first matching suffix (in iteration order) is stripped.

    Parameters
    ----------
    text : str
        the main string
    suffix : iterable
        the list of suffixes to be stripped (`None` is treated as an empty list)

    Returns
    -------
    result : str
        the stripped string
    """

    if suffix is None:
        suffix = []
    for s in suffix:
        if text.endswith(s):
            # slice by an explicit end position: `text[:-len(s)]` turns into `text[:0]` (an empty
            # string) when `s` is empty
            return text[: len(text) - len(s)]
    return text
367
368
def strip_prefix(text: str, prefix: Iterable):
    """
    Remove the first matching prefix from a string

    Parameters
    ----------
    text : str
        the string to process
    prefix : iterable
        the candidate prefixes, checked in iteration order (`None` means no candidates)

    Returns
    -------
    result : str
        `text` without the first candidate it starts with, or `text` itself if no candidate matches
    """

    candidates = [] if prefix is None else prefix
    matched = next((item for item in candidates if text.startswith(item)), None)
    if matched is None:
        return text
    return text[len(matched) :]
392
393
def rotation_matrix_2d(angles: torch.Tensor) -> torch.Tensor:
    """
    Build 2D rotation matrices for a tensor of angles

    Each matrix is `[[cos, -sin], [sin, cos]]` of the corresponding angle.

    Parameters
    ----------
    angles : torch.Tensor
        a tensor of angles of arbitrary shape `(...)`

    Returns
    -------
    rotation_matrices : torch.Tensor
        a tensor of 2D rotation matrices of shape `(..., 2, 2)`
    """

    c, s = torch.cos(angles), torch.sin(angles)
    # assemble the two rows separately, then stack them along a new second-to-last dimension
    top_row = torch.stack([c, -s], dim=-1)
    bottom_row = torch.stack([s, c], dim=-1)
    return torch.stack([top_row, bottom_row], dim=-2)
413
414
def rotation_matrix_3d(alpha: torch.Tensor, beta: torch.Tensor, gamma: torch.Tensor):
    """
    Create a tensor of 3D rotation matrices from a tensor of angles

    The result is the product `Rx(alpha) @ Ry(beta) @ Rz(gamma)`. The constant entries are created with
    the dtype and device of the angles, so the result follows the inputs.

    Parameters
    ----------
    alpha : torch.Tensor
        a tensor of rotation angles around the x axis of arbitrary shape `(...)`
    beta : torch.Tensor
        a tensor of rotation angles around the y axis of arbitrary shape `(...)`
    gamma : torch.Tensor
        a tensor of rotation angles around the z axis of arbitrary shape `(...)`

    Returns
    -------
    rotation_matrices : torch.Tensor
        a tensor of 3D rotation matrices of shape `(..., 3, 3)`
    """

    cos = torch.cos(alpha)
    sin = torch.sin(alpha)
    # *_like keeps the constants on the same device and in the same dtype as the angles
    zero = torch.zeros_like(cos)
    one = torch.ones_like(cos)
    Rx = torch.stack(
        [
            one,
            zero,
            zero,
            zero,
            cos,
            -sin,
            zero,
            sin,
            cos,
        ],
        dim=-1,
    ).reshape(*alpha.shape, 3, 3)
    cos = torch.cos(beta)
    sin = torch.sin(beta)
    zero = torch.zeros_like(cos)
    one = torch.ones_like(cos)
    Ry = torch.stack(
        [
            cos,
            zero,
            sin,
            zero,
            one,
            zero,
            -sin,
            zero,
            cos,
        ],
        dim=-1,
    ).reshape(*beta.shape, 3, 3)
    cos = torch.cos(gamma)
    sin = torch.sin(gamma)
    zero = torch.zeros_like(cos)
    one = torch.ones_like(cos)
    Rz = torch.stack(
        [
            cos,
            -sin,
            zero,
            sin,
            cos,
            zero,
            zero,
            zero,
            one,
        ],
        dim=-1,
    ).reshape(*gamma.shape, 3, 3)
    R = torch.einsum("...ij,...jk,...kl->...il", Rx, Ry, Rz)
    return R
484
485
def correct_path(path, project_path):
    """
    Re-root a path inside a project folder

    Everything that follows the first `results` component of `path` (or the first `saved_datasets`
    component if there is no `results` one) is appended to the same folder under `project_path`.

    Parameters
    ----------
    path : str
        the path to correct (objects that are not strings are returned unchanged)
    project_path : str
        the path to the project folder

    Returns
    -------
    corrected_path : str
        the path under `project_path`

    Raises
    ------
    ValueError
        if `path` contains neither a `results` nor a `saved_datasets` component
    """

    if not isinstance(path, str):
        return path
    parts = os.path.normpath(path).split(os.path.sep)
    name = "results" if "results" in parts else "saved_datasets"
    if name not in parts:
        raise ValueError(
            f"Cannot correct the path {os.path.sep.join(parts)}: it contains neither a 'results' nor a "
            f"'saved_datasets' folder"
        )
    ind = parts.index(name) + 1
    # re-root under the folder that was actually found, not always under "results"
    return os.path.join(project_path, name, *parts[ind:])
496
497
class TensorList(list):
    """
    A list of tensors that can send each element to a `torch` device
    """

    def to_device(self, device: torch.device):
        """
        Move every element to a device, replacing the elements in place

        Parameters
        ----------
        device : torch.device
            the target device
        """

        for i, x in enumerate(self):
            # torch.Tensor has no `to_device` method, it is moved with `.to()`; elements that only
            # provide `to_device` are still supported
            move = getattr(x, "to", None)
            if move is None:
                move = x.to_device
            self[i] = move(device)
506
507
def get_intervals(tensor: torch.Tensor) -> torch.Tensor:
    """
    Get a list of True group beginning and end indices from a boolean tensor

    Parameters
    ----------
    tensor : torch.Tensor
        a one-dimensional boolean tensor

    Returns
    -------
    intervals : torch.Tensor
        an integer tensor of shape `(#groups, 2)`; every row holds the index of the first element of a
        group of consecutive `True` values and the exclusive end index of that group (last index + 1)
    """

    # run-length encode the tensor: the cumulative sum of the run lengths gives every run's
    # exclusive end, and subtracting the length gives its start
    values, counts = torch.unique_consecutive(tensor, return_counts=True)
    run_ends = torch.cumsum(counts, dim=0)
    run_starts = run_ends - counts
    true_runs = torch.where(values)[0]
    return torch.stack([run_starts[true_runs], run_ends[true_runs]], dim=1)
522
523
def smooth(tensor: torch.Tensor, smooth_interval: int = 0) -> torch.Tensor:
    """
    Remove short runs from a binary classification tensor

    Runs of 0 that are not longer than `smooth_interval` are filled with 1 first; after that, runs of 1
    that are not longer than `smooth_interval` are set to 0. The tensor is modified in place and also
    returned. With `smooth_interval` equal to 0 the tensor is returned untouched.

    Parameters
    ----------
    tensor : torch.Tensor
        the tensor to smooth
    smooth_interval : int, default 0
        the maximum length of the runs that get replaced

    Returns
    -------
    tensor : torch.Tensor
        the smoothed tensor (the same object as the input)
    """

    if smooth_interval == 0:
        return tensor
    # the order matters: short gaps are closed before short detections are erased
    for target, replacement in ((0, 1), (1, 0)):
        runs = get_intervals(tensor == target)
        run_lengths = torch.tensor([run[1] - run[0] for run in runs])
        for first, stop in runs[run_lengths <= smooth_interval]:
            tensor[first:stop] = replacement
    return tensor
549
550
class GaussianSmoothing(nn.Module):
    """
    Apply gaussian smoothing on a 1d tensor.
    Filtering is performed separately for each channel
    in the input using a depthwise convolution.
    Arguments:
        kernel_size (int): Size of the gaussian kernel.
        sigma (float): Standard deviation of the gaussian kernel.
    """

    def __init__(self, kernel_size: int = 15, sigma: float = 1.0) -> None:
        super().__init__()
        self.kernel_size = kernel_size

        # sample positions 0 .. kernel_size - 1; a plain arange avoids the deprecation warning
        # that a single-tensor torch.meshgrid call raises
        positions = torch.arange(kernel_size).float()

        mean = (kernel_size - 1) / 2
        scale = 1 / (sigma * math.sqrt(2 * math.pi))
        # values of the gaussian density at the sample positions; they are NOT re-normalized,
        # so the kernel only sums to approximately 1
        kernel = scale * torch.exp(-(((positions - mean) / sigma) ** 2) / 2)

        # shape (1, 1, kernel_size); kept as a plain attribute (not a registered buffer)
        self.kernel = kernel.view(1, 1, *kernel.size())

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """
        Apply gaussian filter to input.
        Arguments:
            inputs (torch.Tensor): Input to apply gaussian filter on, of shape (N, C, T).
        Returns:
            filtered (torch.Tensor): Filtered output.
        """
        _, c, _ = inputs.shape
        half = (self.kernel_size - 1) // 2
        padded = F.pad(inputs, pad=(half, half), mode="reflect")
        # one copy of the kernel per channel (depthwise convolution), matched to the input's
        # device and dtype so that e.g. float64 inputs work as well
        kernel = self.kernel.repeat(c, *[1] * (self.kernel.dim() - 1)).to(
            device=inputs.device, dtype=inputs.dtype
        )
        return F.conv1d(padded, weight=kernel, groups=c)
597
598
def argrelmax(prob: np.ndarray, threshold: float = 0.7) -> List[int]:
    """
    Calculate arguments of relative maxima.

    Values under `threshold` are treated as 0. The first frame is always reported as a peak and
    the last frame never is. The input array is not modified.

    prob: np.array. boundary probability maps distributed in [0, 1]
    prob shape is (T)
    ignore the peak whose value is under threshold
    Return:
        Index of peaks (a list of ints)
    """
    # ignore the values under threshold (np.where builds a new array, the input stays untouched)
    prob = np.asarray(prob)
    prob = np.where(prob < threshold, 0.0, prob)

    # an interior frame is a peak if it is strictly greater than both of its neighbours
    interior = (prob[:-2] < prob[1:-1]) & (prob[2:] < prob[1:-1])

    # treat the first frame as boundary; the builtin `bool` replaces the removed `np.bool` alias
    peak = np.concatenate(
        [
            np.ones((1), dtype=bool),
            interior,
            np.zeros((1), dtype=bool),
        ],
        axis=0,
    )

    peak_idx = np.where(peak)[0].tolist()

    return peak_idx
625
626
def decide_boundary_prob_with_similarity(x: torch.Tensor) -> torch.Tensor:
    """
    Decide action boundary probabilities based on adjacent frame similarities.

    For every pair of adjacent frames the value is `exp(-||x_t - x_(t-1)|| / 2)`, with the norm
    taken over the channel dimension. The first frame of every sequence gets the value 1.
    All sequences in the batch are processed.

    Args:
        x: frame-wise video features (N, C, T)
    Return:
        boundary: action boundary probability (N, 1, T)
    """
    # gaussian kernel on the difference between adjacent frames
    diff = x[:, :, 1:] - x[:, :, :-1]
    similarity = torch.exp(-torch.norm(diff, dim=1) / (2 * 1.0))

    # define action starting point as action boundary
    start = x.new_ones(x.shape[0], 1)
    boundary = torch.cat([start, similarity], dim=1)
    return boundary.unsqueeze(1)
646
647
class PostProcessor(object):
    """
    Post-process frame-wise action segmentation outputs

    The instance is callable; `name` selects which of the three procedures a call runs:
    "refinement_with_boundary", "relabeling" or "smoothing".
    """

    def __init__(
        self,
        name: str,
        boundary_th: float = 0.7,
        theta_t: int = 15,
        kernel_size: int = 15,
    ) -> None:
        """
        Args:
            name: the post-processing procedure to run when the instance is called
            boundary_th: the boundary probability threshold used by "refinement_with_boundary"
            theta_t: the segment length threshold used by "relabeling"
            kernel_size: the gaussian kernel size used by "smoothing"
        """
        self.func = {
            "refinement_with_boundary": self._refinement_with_boundary,
            "relabeling": self._relabeling,
            "smoothing": self._smoothing,
        }
        assert name in self.func

        self.name = name
        self.boundary_th = boundary_th
        self.theta_t = theta_t
        self.kernel_size = kernel_size

        # the filter is only created (and only needed) for the smoothing procedure
        if name == "smoothing":
            self.filter = GaussianSmoothing(self.kernel_size)

    def _is_probability(self, x: np.ndarray) -> bool:
        """
        Check whether an array of shape (N, C, T) already holds probabilities

        With a single channel all values have to lie in [0, 1] (sigmoid output); otherwise the
        values have to sum to 1 over the channel axis (softmax output).
        """
        assert x.ndim == 3

        if x.shape[1] == 1:
            # sigmoid
            return bool(x.min() >= 0 and x.max() <= 1)
        # softmax
        _sum = np.sum(x, axis=1).astype(np.float32)
        _ones = np.ones_like(_sum, dtype=np.float32)
        return np.allclose(_sum, _ones)

    def _convert2probability(self, x: np.ndarray) -> np.ndarray:
        """
        Convert scores to probabilities (sigmoid for one channel, softmax otherwise)

        Arrays that already hold probabilities are returned as they are.
        Args: x (N, C, T)
        """
        assert x.ndim == 3

        if self._is_probability(x):
            return x
        if x.shape[1] == 1:
            # sigmoid
            prob = 1 / (1 + np.exp(-x))
        else:
            # softmax; subtracting the per-frame maximum does not change the result but keeps
            # np.exp from overflowing on large scores
            exp = np.exp(x - np.max(x, axis=1, keepdims=True))
            prob = exp / np.sum(exp, axis=1, keepdims=True)
        return prob.astype(np.float32)

    def _convert2label(self, x: np.ndarray) -> np.ndarray:
        """
        Convert (N, T) labels or (N, C, T) scores to an int64 label array of shape (N, T)
        """
        assert x.ndim == 2 or x.ndim == 3

        if x.ndim == 2:
            return x.astype(np.int64)
        if not self._is_probability(x):
            x = self._convert2probability(x)

        label = np.argmax(x, axis=1)
        return label.astype(np.int64)

    def _refinement_with_boundary(
        self,
        outputs: np.array,
        boundaries: np.ndarray,
        **kwargs: np.ndarray,
    ) -> np.ndarray:
        """
        Get segments which is defined as the span b/w two boundaries,
        and decide their classes by majority vote.
        Args:
            outputs: numpy array. shape (N, C, T)
                the model output for frame-level class prediction.
            boundaries: numpy array.  shape (N, 1, T)
                boundary prediction.
            kwargs: additional keyword arguments (e.g. masks) are accepted and ignored
        Return:
            preds: np.array. shape (N, T)
                final class prediction considering boundaries.
        """

        preds = self._convert2label(outputs)
        boundaries = self._convert2probability(boundaries)

        for i, (output, pred, boundary) in enumerate(zip(outputs, preds, boundaries)):
            # a copy is passed so that the thresholding inside argrelmax cannot touch `boundaries`
            idx = argrelmax(boundary.squeeze().copy(), threshold=self.boundary_th)

            # add the index of the last action ending
            T = pred.shape[0]
            idx.append(T)

            # majority vote
            for j in range(len(idx) - 1):
                start, end = idx[j], idx[j + 1]
                count = np.bincount(pred[start:end])
                modes = np.where(count == count.max())[0]
                if len(modes) == 1 or outputs.ndim != 3:
                    # a single majority class, or label inputs (oracle experiment) where the
                    # first of the tied classes is taken
                    mode = modes[0]
                else:
                    # several majority classes: take the one with the largest summed score in the
                    # segment; argmax also works when all the sums are negative (raw logits)
                    score_sums = [output[m, start:end].sum() for m in modes]
                    mode = modes[int(np.argmax(score_sums))]

                preds[i, start:end] = mode

        return preds

    def _relabeling(self, outputs: np.ndarray, **kwargs: np.ndarray) -> np.ndarray:
        """
        Relabeling small action segments with their previous action segment

        A segment that is not longer than `theta_t` frames gets the label of the frame right before
        it. A short segment at the very beginning of a sequence has no previous segment and is
        left unchanged.
        Args:
            output: the results of action segmentation. (N, T) or (N, C, T)
            theta_t: the threshold of the size of action segments.
        Return:
            relabeled output. (N, T)
        """

        preds = self._convert2label(outputs)
        n_frames = preds.shape[1]
        if n_frames == 0:
            return preds

        for row in preds:
            # `row` is a view of shape (T,), so the assignments below modify `preds`
            last = row[0]
            cnt = 1
            for j in range(1, n_frames):
                if last == row[j]:
                    cnt += 1
                    continue
                # the segment [j - cnt, j) has just ended
                start = j - cnt
                if cnt <= self.theta_t and start > 0:
                    row[start:j] = row[start - 1]
                cnt = 1
                last = row[j]

            # the trailing segment is [n_frames - cnt, n_frames)
            start = n_frames - cnt
            if cnt <= self.theta_t and start > 0:
                row[start:n_frames] = row[start - 1]

        return preds

    def _smoothing(self, outputs: np.ndarray, **kwargs: np.ndarray) -> np.ndarray:
        """
        Smoothing action probabilities with gaussian filter.
        Args:
            outputs: frame-wise action probabilities. (N, C, T)
        Return:
            predictions: final prediction. (N, T)
        """

        outputs = self._convert2probability(outputs)
        outputs = self.filter(torch.Tensor(outputs)).numpy()

        preds = self._convert2label(outputs)
        return preds

    def __call__(self, outputs, **kwargs: np.ndarray) -> np.ndarray:
        """
        Run the procedure selected by `name` on `outputs`; keyword arguments are passed through
        """

        preds = self.func[self.name](outputs, **kwargs)
        return preds
class TensorDict:
 30class TensorDict:
 31    """
 32    A class that handles indexing in a dictionary of tensors of the same length
 33    """
 34
 35    def __init__(self, obj: Union[Dict, Iterable] = None) -> None:
 36        """
 37        Parameters
 38        ----------
 39        obj : dict | iterable, optional
 40            either a dictionary of torch.Tensor instances of the same length or an iterable of dictionaries with
 41            the same keys (if not passed, a blank TensorDict is initialized)
 42        """
 43
 44        if obj is None:
 45            obj = {}
 46        if isinstance(obj, list):
 47            self.keys = list(obj[0].keys())
 48            self.dict = {key: [] for key in self.keys}
 49            with warnings.catch_warnings():
 50                warnings.filterwarnings("ignore", category=UserWarning)
 51                for element in obj:
 52                    for key in self.keys:
 53                        self.dict[key].append(torch.tensor(element[key]))
 54            # self.dict = {k: torch.stack(v) for k, v in self.dict.items()}
 55            old_dict = self.dict
 56            self.dict = {}
 57            for k, v in old_dict.items():
 58                self.dict[k] = torch.stack(v)
 59        elif isinstance(obj, dict):
 60            if not all([isinstance(obj[key], torch.Tensor) for key in obj]):
 61                raise TypeError(
 62                    f"The values in the dictionary passed to TensorDict need to be torch.Tensor instances;"
 63                    f"got {[type(obj[key]) for key in obj]}"
 64                )
 65            lengths = [len(obj[key]) for key in obj]
 66            if not all([x == lengths[0] for x in lengths]):
 67                raise ValueError(
 68                    f"The value tensors in the dictionary passed to TensorDict need to have the same length;"
 69                    f"got {lengths}"
 70                )
 71            self.dict = obj
 72            self.keys = list(self.dict.keys())
 73        else:
 74            raise TypeError(
 75                f"TensorDict can only be constructed from an iterable of dictionaries of from a dictionary "
 76                f"of tensors (got {type(obj)})"
 77            )
 78        if len(self.keys) > 0:
 79            self.type = type(self.dict[self.keys[0]])
 80        else:
 81            self.type = None
 82
 83    def __len__(self) -> int:
 84        return len(self.dict[self.keys[0]])
 85
 86    def __getitem__(self, ind: Union[int, List]):
 87        """
 88        Index the TensorDict
 89
 90        Parameters
 91        ----------
 92        ind : int | list
 93            the index/indices
 94
 95        Returns
 96        -------
 97        dict | TensorDict
 98            the indexed elements of all value lists combined in a dictionary (if length of ind is 1) or a TensorDict
 99        """
100
101        x = {key: self.dict[key][ind] for key in self.dict}
102        if type(ind) is not int:
103            d = TensorDict(x)
104            return d
105        else:
106            return x
107
108    def append(self, element: Dict) -> None:
109        """
110        Append an element
111
112        Parameters
113        ----------
114        element : dict
115            a dictionary
116        """
117
118        type_element = type(element[list(element.keys())[0]])
119        if self.type is None:
120            self.dict = {k: v.unsqueeze(0) for k, v in element.items()}
121            self.keys = list(element.keys())
122            self.type = type_element
123        else:
124            for key in self.keys:
125                if key not in element:
126                    raise ValueError(
127                        f"The dictionary appended to TensorDict needs to have the same keys as the "
128                        f"TensorDict; got {element.keys()} and {self.keys}"
129                    )
130                self.dict[key] = torch.cat([self.dict[key], element[key].unsqueeze(0)])
131
132    def remove(self, indices: List) -> None:
133        """
134        Remove indexed elements
135
136        Parameters
137        ----------
138        indices : list
139            the indices to remove
140        """
141
142        mask = torch.ones(len(self))
143        mask[indices] = 0
144        mask = mask.bool()
145        for key, value in self.dict.items():
146            self.dict[key] = value[mask]

A class that handles indexing in a dictionary of tensors of the same length

TensorDict(obj: Union[Dict, collections.abc.Iterable] = None)
35    def __init__(self, obj: Union[Dict, Iterable] = None) -> None:
36        """
37        Parameters
38        ----------
39        obj : dict | iterable, optional
40            either a dictionary of torch.Tensor instances of the same length or an iterable of dictionaries with
41            the same keys (if not passed, a blank TensorDict is initialized)
42        """
43
44        if obj is None:
45            obj = {}
46        if isinstance(obj, list):
47            self.keys = list(obj[0].keys())
48            self.dict = {key: [] for key in self.keys}
49            with warnings.catch_warnings():
50                warnings.filterwarnings("ignore", category=UserWarning)
51                for element in obj:
52                    for key in self.keys:
53                        self.dict[key].append(torch.tensor(element[key]))
54            # self.dict = {k: torch.stack(v) for k, v in self.dict.items()}
55            old_dict = self.dict
56            self.dict = {}
57            for k, v in old_dict.items():
58                self.dict[k] = torch.stack(v)
59        elif isinstance(obj, dict):
60            if not all([isinstance(obj[key], torch.Tensor) for key in obj]):
61                raise TypeError(
62                    f"The values in the dictionary passed to TensorDict need to be torch.Tensor instances;"
63                    f"got {[type(obj[key]) for key in obj]}"
64                )
65            lengths = [len(obj[key]) for key in obj]
66            if not all([x == lengths[0] for x in lengths]):
67                raise ValueError(
68                    f"The value tensors in the dictionary passed to TensorDict need to have the same length;"
69                    f"got {lengths}"
70                )
71            self.dict = obj
72            self.keys = list(self.dict.keys())
73        else:
74            raise TypeError(
75                f"TensorDict can only be constructed from an iterable of dictionaries of from a dictionary "
76                f"of tensors (got {type(obj)})"
77            )
78        if len(self.keys) > 0:
79            self.type = type(self.dict[self.keys[0]])
80        else:
81            self.type = None

Parameters

obj : dict | iterable, optional either a dictionary of torch.Tensor instances of the same length or an iterable of dictionaries with the same keys (if not passed, a blank TensorDict is initialized)

def append(self, element: Dict) -> None:
108    def append(self, element: Dict) -> None:
109        """
110        Append an element
111
112        Parameters
113        ----------
114        element : dict
115            a dictionary
116        """
117
118        type_element = type(element[list(element.keys())[0]])
119        if self.type is None:
120            self.dict = {k: v.unsqueeze(0) for k, v in element.items()}
121            self.keys = list(element.keys())
122            self.type = type_element
123        else:
124            for key in self.keys:
125                if key not in element:
126                    raise ValueError(
127                        f"The dictionary appended to TensorDict needs to have the same keys as the "
128                        f"TensorDict; got {element.keys()} and {self.keys}"
129                    )
130                self.dict[key] = torch.cat([self.dict[key], element[key].unsqueeze(0)])

Append an element

Parameters

element : dict a dictionary

def remove(self, indices: List) -> None:
132    def remove(self, indices: List) -> None:
133        """
134        Remove indexed elements
135
136        Parameters
137        ----------
138        indices : list
139            the indices to remove
140        """
141
142        mask = torch.ones(len(self))
143        mask[indices] = 0
144        mask = mask.bool()
145        for key, value in self.dict.items():
146            self.dict[key] = value[mask]

Remove indexed elements

Parameters

indices : list the indices to remove

def apply_threshold( tensor: torch.Tensor, threshold: float, low: bool = True, error_mask: torch.Tensor = None, min_frames: int = 0, smooth_interval: int = 0, masked_intervals: List = None)
def apply_threshold(
    tensor: torch.Tensor,
    threshold: float,
    low: bool = True,
    error_mask: torch.Tensor = None,
    min_frames: int = 0,
    smooth_interval: int = 0,
    masked_intervals: List = None,
):
    """
    Apply a hard threshold to a tensor and return indices of the intervals that passed

    If `error_mask` is not `None`, the elements marked `False` are treated as if they did not pass the threshold.
    If `min_frames` is not 0, the intervals are additionally filtered by length.

    Parameters
    ----------
    tensor : torch.Tensor
        the one-dimensional tensor to apply the threshold to
    threshold : float
        the threshold
    low : bool, default True
        if `True`, values smaller than or equal to the threshold pass; otherwise values larger than or equal to it
    error_mask : torch.Tensor, optional
        a boolean mask to apply to the results
    min_frames : int, default 0
        only intervals strictly longer than this number of frames are kept
    smooth_interval : int, default 0
        the interval length passed to `smooth` (0 disables smoothing)
    masked_intervals : list, optional
        a list of `(start, end)` pairs; frames in `[start, end)` are treated as if they did not pass

    Returns
    -------
    indices_start : list
        a list of indices of the first frames of the chosen intervals
    indices_end : list
        a list of exclusive end indices of the chosen intervals (index of the last frame plus one)
    """

    if masked_intervals is None:
        masked_intervals = []
    if low:
        p = tensor <= threshold
    else:
        p = tensor >= threshold
    # `smooth` returns its input unchanged for an interval of 0
    if smooth_interval != 0:
        p = smooth(p, smooth_interval)
    if error_mask is not None:
        p = p * error_mask
    for start, end in masked_intervals:
        p[start:end] = False
    # run-length encode `p`; the cumulative sum of the run lengths gives the
    # (exclusive) end of every run in one pass
    values, counts = torch.unique_consecutive(p, return_counts=True)
    ends = torch.cumsum(counts, dim=0)
    starts = ends - counts
    keep = values.bool() & (counts > min_frames)
    return starts[keep].tolist(), ends[keep].tolist()

Apply a hard threshold to a tensor and return indices of the intervals that passed

If error_mask is not None, the elements marked False are treated as if they did not pass the threshold. If min_frames is not 0, the intervals are additionally filtered by length.

Parameters

tensor : torch.Tensor the tensor to apply the threshold to threshold : float the threshold error_mask : torch.Tensor, optional a boolean mask to apply to the results min_frames : int, default 0 the minimum number of frames in the resulting intervals (shorter intervals are discarded)

Returns

indices_start : list a list of indices of the first frames of the chosen intervals indices_end : list a list of indices of the last frames of the chosen intervals

def apply_threshold_hysteresis( tensor: torch.Tensor, soft_threshold: float, hard_threshold: float, low: bool = True, error_mask: torch.Tensor = None, min_frames: int = 0, smooth_interval: int = 0, masked_intervals: List = None)
def apply_threshold_hysteresis(
    tensor: torch.Tensor,
    soft_threshold: float,
    hard_threshold: float,
    low: bool = True,
    error_mask: torch.Tensor = None,
    min_frames: int = 0,
    smooth_interval: int = 0,
    masked_intervals: List = None,
):
    """
    Apply a hysteresis threshold to a tensor and return indices of the intervals that passed

    In the chosen intervals all values pass the soft threshold and at least one value passes the hard threshold.
    If `error_mask` is not `None`, the elements marked `False` are treated as if they did not pass the threshold.
    If `min_frames` is not 0, the intervals are additionally filtered by length.

    Parameters
    ----------
    tensor : torch.Tensor
        the one-dimensional tensor to apply the threshold to
    soft_threshold : float
        the soft threshold
    hard_threshold : float
        the hard threshold
    low : bool, default True
        if `True`, values smaller than or equal to a threshold pass it; otherwise values larger than or equal to it
    error_mask : torch.Tensor, optional
        a boolean mask to apply to the results
    min_frames : int, default 0
        only intervals strictly longer than this number of frames are kept
    smooth_interval : int, default 0
        the interval length passed to `smooth` (0 disables smoothing)
    masked_intervals : list, optional
        a list of `(start, end)` pairs; frames in `[start, end)` are treated as if they did not pass

    Returns
    -------
    indices_start : list
        a list of indices of the first frames of the chosen intervals
    indices_end : list
        a list of exclusive end indices of the chosen intervals (index of the last frame plus one)
    """

    if masked_intervals is None:
        masked_intervals = []
    if low:
        p = tensor <= soft_threshold
        hard = tensor <= hard_threshold
    else:
        p = tensor >= soft_threshold
        hard = tensor >= hard_threshold
    # `smooth` returns its input unchanged for an interval of 0
    if smooth_interval != 0:
        p = smooth(p, smooth_interval)
    if error_mask is not None:
        p = p * error_mask
    for start, end in masked_intervals:
        p[start:end] = False
    # run-length encode `p`; the cumulative sum of the run lengths gives the
    # (exclusive) end of every run in one pass
    values, counts = torch.unique_consecutive(p, return_counts=True)
    ends = torch.cumsum(counts, dim=0)
    starts = ends - counts
    keep = values.bool() & (counts > min_frames)
    res_indices_start = []
    res_indices_end = []
    for start, end in zip(starts[keep].tolist(), ends[keep].tolist()):
        # `hard` is computed from the raw tensor: smoothing and masking do not affect it
        if hard[start:end].any():
            res_indices_start.append(start)
            res_indices_end.append(end)
    return res_indices_start, res_indices_end

Apply a hysteresis threshold to a tensor and return indices of the intervals that passed

In the chosen intervals all values pass the soft threshold and at least one value passes the hard threshold. If error_mask is not None, the elements marked False are treated as if they did not pass the threshold. If min_frames is not 0, the intervals are additionally filtered by length.

Parameters

tensor : torch.Tensor the tensor to apply the threshold to soft_threshold : float the soft threshold hard_threshold : float the hard threshold error_mask : torch.Tensor, optional a boolean mask to apply to the results min_frames : int, default 0 the minimum number of frames in the resulting intervals (shorter intervals are discarded)

Returns

indices_start : list a list of indices of the first frames of the chosen intervals indices_end : list a list of indices of the last frames of the chosen intervals

def apply_threshold_max( tensor: torch.Tensor, threshold: float, main_class: int, error_mask: torch.Tensor = None, min_frames: int = 0, smooth_interval: int = 0, masked_intervals: List = None)
def apply_threshold_max(
    tensor: torch.Tensor,
    threshold: float,
    main_class: int,
    error_mask: torch.Tensor = None,
    min_frames: int = 0,
    smooth_interval: int = 0,
    masked_intervals: List = None,
):
    """
    Apply a max hysteresis threshold to a tensor and return indices of the intervals that passed

    In the chosen intervals the values at the `main_class` index are larger than the others everywhere
    and at least one value at the `main_class` index passes the threshold.
    If `error_mask` is not `None`, the elements marked `False` are treated as if they did not pass the threshold.
    If min_frames is not 0, the intervals are additionally filtered by length.

    Parameters
    ----------
    tensor : torch.Tensor
        the tensor to apply the threshold to (of shape `(#classes, #frames)`)
    threshold : float
        the threshold (if `None`, the intervals are not filtered by value)
    main_class : int
        the class that conditions the soft threshold
    error_mask : torch.Tensor, optional
        a boolean mask to apply to the results
    min_frames : int, default 0
        only intervals strictly longer than this number of frames are kept
    smooth_interval : int, default 0
        the interval length passed to `smooth` (0 disables smoothing)
    masked_intervals : list, optional
        a list of `(start, end)` pairs; frames in `[start, end)` are treated as if they did not pass

    Returns
    -------
    indices_start : list
        a list of indices of the first frames of the chosen intervals (along dimension 1 of input tensor)
    indices_end : list
        a list of exclusive end indices of the chosen intervals (along dimension 1 of input tensor)
    """

    if masked_intervals is None:
        masked_intervals = []
    _, winners = torch.max(tensor, dim=0)
    p = winners == main_class
    # `smooth` returns its input unchanged for an interval of 0
    if smooth_interval != 0:
        p = smooth(p, smooth_interval)
    if error_mask is not None:
        p = p * error_mask
    for start, end in masked_intervals:
        p[start:end] = False
    # run-length encode `p`; the cumulative sum of the run lengths gives the
    # (exclusive) end of every run in one pass
    values, counts = torch.unique_consecutive(p, return_counts=True)
    ends = torch.cumsum(counts, dim=0)
    starts = ends - counts
    keep = values.bool() & (counts > min_frames)
    indices_start = starts[keep].tolist()
    indices_end = ends[keep].tolist()
    if threshold is None:
        return indices_start, indices_end
    # computed from the raw tensor: smoothing and masking do not affect it
    hard = tensor[main_class, :] > threshold
    res_indices_start = []
    res_indices_end = []
    for start, end in zip(indices_start, indices_end):
        if hard[start:end].any():
            res_indices_start.append(start)
            res_indices_end.append(end)
    return res_indices_start, res_indices_end

Apply a max hysteresis threshold to a tensor and return indices of the intervals that passed

In the chosen intervals the values at the main_class index are larger than the others everywhere and at least one value at the main_class index passes the threshold. If error_mask is not None, the elements marked False are treated as if they did not pass the threshold. If min_frames is not 0, the intervals are additionally filtered by length.

Parameters

tensor : torch.Tensor the tensor to apply the threshold to (of shape (#classes, #frames)) threshold : float the threshold main_class : int the class that conditions the soft threshold error_mask : torch.Tensor, optional a boolean mask to apply to the results min_frames : int, default 0 the minimum number of frames in the resulting intervals (shorter intervals are discarded)

Returns

indices_start : list a list of indices of the first frames of the chosen intervals (along dimension 1 of input tensor) indices_end : list a list of indices of the last frames of the chosen intervals (along dimension 1 of input tensor)

def strip_suffix(text: str, suffix: collections.abc.Iterable)
def strip_suffix(text: str, suffix: Iterable):
    """
    Strip a suffix from a string if it is contained in a list

    Only the first matching suffix (in iteration order) is stripped.

    Parameters
    ----------
    text : str
        the main string
    suffix : iterable
        the list of suffixes to be stripped (`None` is treated as an empty list)

    Returns
    -------
    result : str
        the stripped string
    """

    if suffix is None:
        suffix = []
    for s in suffix:
        if text.endswith(s):
            # slice by an explicit end index: `text[:-len(s)]` would return an
            # empty string for an empty suffix
            return text[: len(text) - len(s)]
    return text

Strip a suffix from a string if it is contained in a list

Parameters

text : str the main string suffix : iterable the list of suffixes to be stripped

Returns

result : str the stripped string

def strip_prefix(text: str, prefix: collections.abc.Iterable)
def strip_prefix(text: str, prefix: Iterable):
    """
    Strip a prefix from a string if it is contained in a list

    Only the first matching prefix (in iteration order) is stripped.

    Parameters
    ----------
    text : str
        the main string
    prefix : iterable
        the list of prefixes to be stripped (`None` is treated as an empty list)

    Returns
    -------
    result : str
        the stripped string
    """

    candidates = [] if prefix is None else prefix
    matched = next((p for p in candidates if text.startswith(p)), None)
    if matched is None:
        return text
    return text[len(matched) :]

Strip a prefix from a string if it is contained in a list

Parameters

text : str the main string prefix : iterable the list of prefixes to be stripped

Returns

result : str the stripped string

def rotation_matrix_2d(angles: torch.Tensor) -> torch.Tensor:
def rotation_matrix_2d(angles: torch.Tensor) -> torch.Tensor:
    """
    Create a tensor of 2D rotation matrices from a tensor of angles

    Every matrix has the form `[[cos, -sin], [sin, cos]]`.

    Parameters
    ----------
    angles : torch.Tensor
        a tensor of angles of arbitrary shape `(...)`

    Returns
    -------
    rotation_matrices : torch.Tensor
        a tensor of 2D rotation matrices of shape `(..., 2, 2)`
    """

    c = torch.cos(angles)
    s = torch.sin(angles)
    top_row = torch.stack([c, -s], dim=-1)
    bottom_row = torch.stack([s, c], dim=-1)
    return torch.stack([top_row, bottom_row], dim=-2)

Create a tensor of 2D rotation matrices from a tensor of angles

Parameters

angles : torch.Tensor a tensor of angles of arbitrary shape (...)

Returns

rotation_matrices : torch.Tensor a tensor of 2D rotation matrices of shape (..., 2, 2)

def rotation_matrix_3d(alpha: torch.Tensor, beta: torch.Tensor, gamma: torch.Tensor)
def rotation_matrix_3d(alpha: torch.Tensor, beta: torch.Tensor, gamma: torch.Tensor):
    """
    Create a tensor of 3D rotation matrices from a tensor of angles

    The result is the matrix product `Rx @ Ry @ Rz` of the rotations around the x, y and z axes.

    Parameters
    ----------
    alpha : torch.Tensor
        a tensor of rotation angles around the x axis of arbitrary shape `(...)`
    beta : torch.Tensor
        a tensor of rotation angles around the y axis of arbitrary shape `(...)`
    gamma : torch.Tensor
        a tensor of rotation angles around the z axis of arbitrary shape `(...)`

    Returns
    -------
    rotation_matrices : torch.Tensor
        a tensor of 3D rotation matrices of shape `(..., 3, 3)`
    """

    # the constant entries are created with `*_like` so that they live on the
    # same device (and have the same dtype) as the angles
    cos = torch.cos(alpha)
    sin = torch.sin(alpha)
    one = torch.ones_like(cos)
    zero = torch.zeros_like(cos)
    Rx = torch.stack(
        [
            one, zero, zero,
            zero, cos, -sin,
            zero, sin, cos,
        ],
        dim=-1,
    ).reshape(*alpha.shape, 3, 3)
    cos = torch.cos(beta)
    sin = torch.sin(beta)
    one = torch.ones_like(cos)
    zero = torch.zeros_like(cos)
    Ry = torch.stack(
        [
            cos, zero, sin,
            zero, one, zero,
            -sin, zero, cos,
        ],
        dim=-1,
    ).reshape(*beta.shape, 3, 3)
    cos = torch.cos(gamma)
    sin = torch.sin(gamma)
    one = torch.ones_like(cos)
    zero = torch.zeros_like(cos)
    Rz = torch.stack(
        [
            cos, -sin, zero,
            sin, cos, zero,
            zero, zero, one,
        ],
        dim=-1,
    ).reshape(*gamma.shape, 3, 3)
    R = torch.einsum("...ij,...jk,...kl->...il", Rx, Ry, Rz)
    return R

Create a tensor of 3D rotation matrices from a tensor of angles

Parameters

alpha : torch.Tensor a tensor of rotation angles around the x axis of arbitrary shape (...) beta : torch.Tensor a tensor of rotation angles around the y axis of arbitrary shape (...) gamma : torch.Tensor a tensor of rotation angles around the z axis of arbitrary shape (...)

Returns

rotation_matrices : torch.Tensor a tensor of 3D rotation matrices of shape (..., 3, 3)

def correct_path(path, project_path)
def correct_path(path, project_path):
    """
    Rebuild a path so that it points inside the `results` folder of `project_path`

    Everything after the first `results` component of `path` (or, if there is none, after the first
    `saved_datasets` component) is appended to `<project_path>/results`.

    NOTE(review): the output is rooted at `results` even when the anchor component is
    `saved_datasets` -- confirm that this is intended.

    Parameters
    ----------
    path : str
        the path to correct (values that are not strings are returned unchanged)
    project_path : str
        the path to the project folder

    Returns
    -------
    corrected_path : str
        the corrected path

    Raises
    ------
    ValueError
        if `path` contains neither a `results` nor a `saved_datasets` component
    """

    if not isinstance(path, str):
        return path
    components = os.path.normpath(path).split(os.path.sep)
    anchor = "results" if "results" in components else "saved_datasets"
    tail = components[components.index(anchor) + 1 :]
    return os.path.join(project_path, "results", *tail)
class TensorList(builtins.list):
class TensorList(list):
    """
    A list of tensors that can send each element to a `torch` device
    """

    def to_device(self, device: torch.device):
        """
        Move every element to `device`, replacing the elements of the list in place

        Elements are moved with their `to` method (the `torch.Tensor` API); elements that
        only provide a `to_device` method are moved with that instead.

        Parameters
        ----------
        device : torch.device
            the target device
        """

        for i, x in enumerate(self):
            # torch.Tensor has no `to_device` method
            move = getattr(x, "to", None)
            if move is None:
                move = x.to_device
            self[i] = move(device)

A list of tensors that can send each element to a torch device

def to_device(self, device: torch.device)
504    def to_device(self, device: torch.device):
505        for i, x in enumerate(self):
506            self[i] = x.to_device(device)
Inherited Members
builtins.list
list
clear
copy
append
insert
extend
pop
remove
index
count
reverse
sort
def get_intervals(tensor: torch.Tensor) -> torch.Tensor:
def get_intervals(tensor: torch.Tensor) -> torch.Tensor:
    """
    Get a list of True group beginning and end indices from a boolean tensor

    Parameters
    ----------
    tensor : torch.Tensor
        a one-dimensional boolean tensor

    Returns
    -------
    intervals : torch.Tensor
        an integer tensor of shape `(#groups, 2)`; every row is the index of the first frame of a
        group of consecutive `True` values and the index of its last frame plus one
    """

    # run-length encode the tensor; the cumulative sum of the run lengths gives
    # the (exclusive) end of every run in one pass
    values, counts = torch.unique_consecutive(tensor, return_counts=True)
    ends = torch.cumsum(counts, dim=0)
    starts = ends - counts
    is_true = values.bool()
    return torch.stack([starts[is_true], ends[is_true]], dim=1)

Get a list of True group beginning and end indices from a boolean tensor

def smooth(tensor: torch.Tensor, smooth_interval: int = 0) -> torch.Tensor:
def smooth(tensor: torch.Tensor, smooth_interval: int = 0) -> torch.Tensor:
    """
    Get rid of jittering in a non-exclusive classification tensor

    First, intervals of 0 that are not longer than `smooth_interval` are filled with 1. Then, intervals
    of 1 that are not longer than `smooth_interval` are filled with 0. The input tensor is modified
    in place and returned.

    Parameters
    ----------
    tensor : torch.Tensor
        the tensor to smooth
    smooth_interval : int, default 0
        the maximum length of the intervals to remove (if 0, the tensor is returned unchanged)

    Returns
    -------
    tensor : torch.Tensor
        the smoothed tensor (the same object as the input)
    """

    if smooth_interval == 0:
        return tensor
    # (value of the short runs to remove, value they are overwritten with);
    # the order matters: gaps are closed before short positive runs are dropped
    for target, replacement in ((0, 1), (1, 0)):
        runs = get_intervals(tensor == target)
        run_lengths = torch.tensor([run[1] - run[0] for run in runs])
        for start, end in runs[run_lengths <= smooth_interval]:
            tensor[start:end] = replacement
    return tensor

Get rid of jittering in a non-exclusive classification tensor

First, remove intervals of 0 shorter than smooth_interval. Then, remove intervals of 1 shorter than smooth_interval.

class GaussianSmoothing(torch.nn.modules.module.Module):
class GaussianSmoothing(nn.Module):
    """
    Apply gaussian smoothing on a 1d tensor.
    Filtering is performed separately for each channel
    in the input using a depthwise convolution.
    Arguments:
        kernel_size (int): Size of the gaussian kernel.
        sigma (float): Standard deviation of the gaussian kernel.
    """

    def __init__(self, kernel_size: int = 15, sigma: float = 1.0) -> None:
        super().__init__()
        self.kernel_size = kernel_size

        # sample positions of the kernel (single-argument torch.meshgrid is an
        # identity and warns on recent torch versions, so arange is used directly)
        positions = torch.arange(kernel_size).float()

        mean = (kernel_size - 1) / 2
        # gaussian probability density evaluated at the sample positions
        normalization = 1 / (sigma * math.sqrt(2 * math.pi))
        kernel = normalization * torch.exp(-(((positions - mean) / sigma) ** 2) / 2)

        # NOTE: the sampled kernel is not re-normalized, so its sum is only
        # approximately equal to 1

        # shape (1, 1, kernel_size): a single-channel conv1d weight; stored as a
        # plain attribute, so it is not part of the module's state_dict
        self.kernel = kernel.view(1, 1, *kernel.size())

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """
        Apply gaussian filter to input.
        Arguments:
            inputs (torch.Tensor): Input to apply gaussian filter on, of shape (N, C, T).
        Returns:
            filtered (torch.Tensor): Filtered output (same shape as the input when
                `kernel_size` is odd).
        """
        _, c, _ = inputs.shape
        half_width = (self.kernel_size - 1) // 2
        inputs = F.pad(inputs, pad=(half_width, half_width), mode="reflect")
        # one copy of the kernel per channel; groups=c filters the channels independently
        kernel = self.kernel.repeat(c, 1, 1).to(inputs.device)
        return F.conv1d(inputs, weight=kernel, groups=c)

Apply gaussian smoothing on a 1d tensor. Filtering is performed separately for each channel in the input using a depthwise convolution. Arguments: channels (int, sequence): Number of channels of the input tensors. Output will have this number of channels as well. kernel_size (int, sequence): Size of the gaussian kernel. sigma (float, sequence): Standard deviation of the gaussian kernel.

GaussianSmoothing(kernel_size: int = 15, sigma: float = 1.0)
564    def __init__(self, kernel_size: int = 15, sigma: float = 1.0) -> None:
565        super().__init__()
566        self.kernel_size = kernel_size
567
568        # The gaussian kernel is the product of the
569        # gaussian function of each dimension.
570        kernel = 1
571        meshgrid = torch.meshgrid(torch.arange(kernel_size))[0].float()
572
573        mean = (kernel_size - 1) / 2
574        kernel = kernel / (sigma * math.sqrt(2 * math.pi))
575        kernel = kernel * torch.exp(-(((meshgrid - mean) / sigma) ** 2) / 2)
576
577        # Make sure sum of values in gaussian kernel equals 1.
578        # kernel = kernel / torch.max(kernel)
579
580        self.kernel = kernel.view(1, 1, *kernel.size())

Initializes internal Module state, shared by both nn.Module and ScriptModule.

def forward(self, inputs: torch.Tensor) -> torch.Tensor:
582    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
583        """
584        Apply gaussian filter to input.
585        Arguments:
586            input (torch.Tensor): Input to apply gaussian filter on.
587        Returns:
588            filtered (torch.Tensor): Filtered output.
589        """
590        _, c, _ = inputs.shape
591        inputs = F.pad(
592            inputs,
593            pad=((self.kernel_size - 1) // 2, (self.kernel_size - 1) // 2),
594            mode="reflect",
595        )
596        kernel = self.kernel.repeat(c, *[1] * (self.kernel.dim() - 1)).to(inputs.device)
597        return F.conv1d(inputs, weight=kernel, groups=c)

Apply gaussian filter to input. Arguments: input (torch.Tensor): Input to apply gaussian filter on. Returns: filtered (torch.Tensor): Filtered output.

Inherited Members
torch.nn.modules.module.Module
dump_patches
register_buffer
register_parameter
add_module
register_module
get_submodule
get_parameter
get_buffer
get_extra_state
set_extra_state
apply
cuda
ipu
xpu
cpu
type
float
double
half
bfloat16
to_empty
to
register_backward_hook
register_full_backward_hook
register_forward_pre_hook
register_forward_hook
T_destination
state_dict
register_load_state_dict_post_hook
load_state_dict
parameters
named_parameters
buffers
named_buffers
children
named_children
modules
named_modules
train
eval
requires_grad_
zero_grad
share_memory
extra_repr
def argrelmax(prob: numpy.ndarray, threshold: float = 0.7) -> List[int]:
def argrelmax(prob: np.ndarray, threshold: float = 0.7) -> List[int]:
    """
    Find the indices of relative maxima in a boundary probability sequence.

    Values below `threshold` are treated as zero before the search. A frame is a peak when it is
    strictly greater than both of its neighbors. The first frame is always reported as a peak and
    the last frame never is. The input array is not modified.

    Parameters
    ----------
    prob : np.ndarray
        boundary probabilities of shape `(T,)`, expected to lie in [0, 1]
    threshold : float, default 0.7
        peaks with a value under this threshold are ignored

    Returns
    -------
    peak_idx : list
        the indices of the peaks, in increasing order
    """

    # work on a copy so that thresholding does not leak into the caller's array
    prob = np.array(prob)
    prob[prob < threshold] = 0.0

    # strict local maxima among the interior frames
    interior = (prob[:-2] < prob[1:-1]) & (prob[2:] < prob[1:-1])

    # treat the first frame as a boundary; the last frame is never a peak.
    # the builtin `bool` is used because the `np.bool` alias was removed from NumPy
    peak = np.concatenate(
        [
            np.ones(1, dtype=bool),
            interior,
            np.zeros(1, dtype=bool),
        ],
        axis=0,
    )

    return np.where(peak)[0].tolist()

Calculate arguments of relative maxima. prob: np.array. boundary probability maps distributed in [0, 1]; prob shape is (T). Peaks whose value is under threshold are ignored. Return: indices of the peaks

def decide_boundary_prob_with_similarity(x: torch.Tensor) -> torch.Tensor:
def decide_boundary_prob_with_similarity(x: torch.Tensor) -> torch.Tensor:
    """
    Decide action boundary probabilities based on adjacent frame similarities.

    For every pair of consecutive frames the similarity is `exp(-||x_t - x_{t-1}|| / 2)`, where the
    norm is taken over the channel dimension. The first frame of every sequence gets the value 1.
    Every item of the batch is processed (not only the first one).

    Parameters
    ----------
    x : torch.Tensor
        frame-wise video features of shape `(N, C, T)`

    Returns
    -------
    boundary : torch.Tensor
        action boundary probability of shape `(N, 1, T)`
    """

    device = x.device
    batch_size = x.shape[0]

    # difference between adjacent frames, shape (N, C, T - 1)
    diff = x[:, :, 1:] - x[:, :, :-1]
    # exponential kernel over the channel-wise norm, shape (N, T - 1)
    similarity = torch.exp(-torch.norm(diff, dim=1) / (2 * 1.0))

    # define action starting point as action boundary.
    start = torch.ones(batch_size, 1).float().to(device)
    boundary = torch.cat([start, similarity], dim=1)
    return boundary.unsqueeze(1)

Decide action boundary probabilities based on adjacent frame similarities. Args: x: frame-wise video features (N, C, T) Return: boundary: action boundary probability (N, 1, T)

class PostProcessor:
class PostProcessor(object):
    """
    Post-processing of frame-wise action segmentation outputs.

    Depending on `name`, calling an instance refines the predictions with boundary probabilities
    (`"refinement_with_boundary"`), relabels short segments (`"relabeling"`) or smooths the class
    probabilities with a gaussian filter (`"smoothing"`).
    """

    def __init__(
        self,
        name: str,
        boundary_th: float = 0.7,
        theta_t: int = 15,
        kernel_size: int = 15,
    ) -> None:
        """
        Parameters
        ----------
        name : {"refinement_with_boundary", "relabeling", "smoothing"}
            the post-processing method applied by `__call__`
        boundary_th : float, default 0.7
            the threshold passed to `argrelmax` when looking for boundary peaks
        theta_t : int, default 15
            segments of at most this many frames are relabeled by the `"relabeling"` method
        kernel_size : int, default 15
            the kernel size of the gaussian filter used by the `"smoothing"` method
        """

        self.func = {
            "refinement_with_boundary": self._refinement_with_boundary,
            "relabeling": self._relabeling,
            "smoothing": self._smoothing,
        }
        assert name in self.func, f"unknown post-processing method: {name}"

        self.name = name
        self.boundary_th = boundary_th
        self.theta_t = theta_t
        self.kernel_size = kernel_size

        # the filter is only needed (and only created) for the smoothing method
        if name == "smoothing":
            self.filter = GaussianSmoothing(self.kernel_size)

    def _is_probability(self, x: np.ndarray) -> bool:
        """
        Check whether an `(N, C, T)` array already looks like probabilities.

        With a single channel the values have to lie in [0, 1] (sigmoid output); with several
        channels they have to be non-negative and sum to one over the channel axis (softmax output).
        """

        assert x.ndim == 3

        if x.shape[1] == 1:
            # sigmoid
            return bool(x.min() >= 0 and x.max() <= 1)

        # softmax: logits that merely happen to sum to one are not probabilities
        if x.size and x.min() < 0:
            return False
        _sum = np.sum(x, axis=1).astype(np.float32)
        _ones = np.ones_like(_sum, dtype=np.float32)
        return bool(np.allclose(_sum, _ones))

    def _convert2probability(self, x: np.ndarray) -> np.ndarray:
        """
        Convert an `(N, C, T)` array to probabilities.

        Arrays that already pass `_is_probability` are returned unchanged; otherwise a sigmoid
        (one channel) or a softmax over the channel axis (several channels) is applied and the
        result is returned as float32.
        """

        assert x.ndim == 3

        if self._is_probability(x):
            return x

        if x.shape[1] == 1:
            # sigmoid
            prob = 1 / (1 + np.exp(-x))
        else:
            # softmax; subtracting the per-frame maximum avoids overflow (inf / inf = nan)
            # for large logits and does not change the result
            exp = np.exp(x - np.max(x, axis=1, keepdims=True))
            prob = exp / np.sum(exp, axis=1, keepdims=True)
        return prob.astype(np.float32)

    def _convert2label(self, x: np.ndarray) -> np.ndarray:
        """
        Convert `(N, T)` labels or `(N, C, T)` scores to an int64 label array of shape `(N, T)`.

        The result is always a new array, so it can be modified without touching the input.
        """

        assert x.ndim == 2 or x.ndim == 3

        if x.ndim == 2:
            return x.astype(np.int64)

        if not self._is_probability(x):
            x = self._convert2probability(x)

        label = np.argmax(x, axis=1)
        return label.astype(np.int64)

    def _refinement_with_boundary(
        self,
        outputs: np.ndarray,
        boundaries: np.ndarray,
    ) -> np.ndarray:
        """
        Get segments which are defined as the span between two boundaries,
        and decide their classes by majority vote.

        Parameters
        ----------
        outputs : np.ndarray
            the model output for frame-level class prediction, shape `(N, C, T)`
            (or already decided labels of shape `(N, T)`)
        boundaries : np.ndarray
            boundary prediction, shape `(N, 1, T)`

        Returns
        -------
        preds : np.ndarray
            final class prediction considering boundaries, shape `(N, T)`
        """

        preds = self._convert2label(outputs)
        boundaries = self._convert2probability(boundaries)

        for i, (output, pred, boundary) in enumerate(zip(outputs, preds, boundaries)):
            # pass a flat copy so that peak detection cannot modify the caller's boundaries
            idx = argrelmax(boundary.reshape(-1).copy(), threshold=self.boundary_th)

            # add the index of the last action ending
            T = pred.shape[0]
            idx.append(T)

            # majority vote
            for j in range(len(idx) - 1):
                start, end = idx[j], idx[j + 1]
                count = np.bincount(pred[start:end])
                modes = np.where(count == count.max())[0]
                if len(modes) == 1:
                    mode = modes[0]
                elif outputs.ndim == 3:
                    # more than one majority class: pick the one with the largest summed score.
                    # argmax also works when all the sums are non-positive (e.g. raw logits)
                    score_sums = output[modes, start:end].sum(axis=1)
                    mode = modes[np.argmax(score_sums)]
                else:
                    # decide first mode when more than one majority class
                    # have the same number during oracle experiment
                    mode = modes[0]

                preds[i, start:end] = mode

        return preds

    def _relabeling(self, outputs: np.ndarray, **kwargs: np.ndarray) -> np.ndarray:
        """
        Relabel small action segments with the label of their previous action segment.

        A segment is small when it spans at most `theta_t` frames. The first segment of a sequence
        has no previous segment and is therefore left unchanged.

        Parameters
        ----------
        outputs : np.ndarray
            the results of action segmentation, shape `(N, T)` or `(N, C, T)`

        Returns
        -------
        preds : np.ndarray
            relabeled output, shape `(N, T)`
        """

        preds = self._convert2label(outputs)

        for row in preds:
            n_frames = row.shape[0]
            seg_start = 0
            # j == n_frames closes the trailing segment
            for j in range(1, n_frames + 1):
                if j < n_frames and row[j] == row[seg_start]:
                    continue
                # the segment [seg_start, j) has just ended
                if j - seg_start <= self.theta_t and seg_start > 0:
                    row[seg_start:j] = row[seg_start - 1]
                seg_start = j

        return preds

    def _smoothing(self, outputs: np.ndarray, **kwargs: np.ndarray) -> np.ndarray:
        """
        Smooth action probabilities with the gaussian filter and take the most likely class.

        Parameters
        ----------
        outputs : np.ndarray
            frame-wise action probabilities (or scores), shape `(N, C, T)`

        Returns
        -------
        preds : np.ndarray
            final prediction, shape `(N, T)`
        """

        outputs = self._convert2probability(outputs)
        outputs = self.filter(torch.Tensor(outputs)).numpy()

        preds = self._convert2label(outputs)
        return preds

    def __call__(self, outputs, **kwargs: np.ndarray) -> np.ndarray:
        """
        Apply the post-processing method chosen at initialization.

        Keyword arguments are forwarded to the chosen method (e.g. `boundaries` for
        `"refinement_with_boundary"`).
        """

        preds = self.func[self.name](outputs, **kwargs)
        return preds
PostProcessor( name: str, boundary_th: int = 0.7, theta_t: int = 15, kernel_size: int = 15)
650    def __init__(
651        self,
652        name: str,
653        boundary_th: int = 0.7,
654        theta_t: int = 15,
655        kernel_size: int = 15,
656    ) -> None:
657        self.func = {
658            "refinement_with_boundary": self._refinement_with_boundary,
659            "relabeling": self._relabeling,
660            "smoothing": self._smoothing,
661        }
662        assert name in self.func
663
664        self.name = name
665        self.boundary_th = boundary_th
666        self.theta_t = theta_t
667        self.kernel_size = kernel_size
668
669        if name == "smoothing":
670            self.filter = GaussianSmoothing(self.kernel_size)