dlc2action.data.annotation_store

Specific implementations of dlc2action.data.base_store.BehaviorStore are defined here.

   1#
   2# Copyright 2020-present by A. Mathis Group and contributors. All rights reserved.
   3#
   4# This project and all its files are licensed under GNU AGPLv3 or later version. 
   5# A copy is included in dlc2action/LICENSE.AGPL.
   6#
   7"""Specific implementations of `dlc2action.data.base_store.BehaviorStore` are defined here."""
   8
   9import os
  10import pickle
  11from abc import abstractmethod
  12from collections import Counter, defaultdict
  13from collections.abc import Iterable
  14from copy import copy
  15from itertools import combinations
  16from typing import Dict, List, Set, Tuple, Union
  17
  18import numpy as np
  19import pandas as pd
  20import torch
  21from tqdm import tqdm
  22from dlc2action.data.base_store import BehaviorStore
  23from dlc2action.utils import strip_suffix
  24
  25
  26class EmptyBehaviorStore(BehaviorStore):
  27    """An empty annotation store that does not contain any data samples."""
  28
  29    def __init__(
  30        self, video_order: List = None, key_objects: Tuple = None, *args, **kwargs
  31    ):
  32        """Initialize the store.
  33
  34        Parameters
  35        ----------
  36        video_order : list, optional
  37            a list of video ids that should be processed in the same order (not passed if creating from key objects)
  38        key_objects : tuple, optional
  39            a tuple of key objects
  40
  41        """
  42        pass
  43
  44    def __len__(self) -> int:
  45        """Get the number of available samples.
  46
  47        Returns
  48        -------
  49        length : int
  50            the number of available samples
  51
  52        """
  53        return 0
  54
  55    def remove(self, indices: List) -> None:
  56        """Remove the samples corresponding to indices.
  57
  58        Parameters
  59        ----------
   60            indices : list
  61            a list of integer indices to remove
  62
  63        """
  64        pass
  65
  66    def key_objects(self) -> Tuple:
  67        """Return a tuple of the key objects necessary to re-create the Store.
  68
  69        Returns
  70        -------
  71        key_objects : tuple
  72            a tuple of key objects
  73
  74        """
  75        return ()
  76
  77    def load_from_key_objects(self, key_objects: Tuple) -> None:
  78        """Load the information from a tuple of key objects.
  79
  80        Parameters
  81        ----------
  82        key_objects : tuple
  83            a tuple of key objects
  84
  85        """
  86        pass
  87
  88    def to_ram(self) -> None:
  89        """Transfer the data samples to RAM if they were previously stored as file paths."""
  90        pass
  91
  92    def get_original_coordinates(self) -> np.ndarray:
  93        """Return the original coordinates array.
  94
  95        Returns
  96        -------
  97        np.ndarray
  98            an array that contains the coordinates of the data samples in original input data (video id, clip id,
  99            start frame)
 100
 101        """
 102        return None
 103
 104    def create_subsample(self, indices: List, ssl_indices: List = None):
 105        """Create a new store that contains a subsample of the data.
 106
 107        Parameters
 108        ----------
 109        indices : list
 110            the indices to be included in the subsample
 111        ssl_indices : list, optional
 112            the indices to be included in the subsample without the annotation data
 113
 114        """
 115        return self.new()
 116
 117    @classmethod
 118    def get_file_ids(cls, *args, **kwargs) -> List:
 119        """Get file ids.
 120
  121    Process data parameters and return a list of ids of the videos that should
 122        be processed by the `__init__` function.
 123
 124        Returns
 125        -------
 126        video_ids : list
 127            a list of video file ids
 128
 129        """
 130        return None
 131
 132    def __getitem__(self, ind: int) -> torch.Tensor:
 133        """Return the annotation of the sample corresponding to an index.
 134
 135        Parameters
 136        ----------
 137        ind : int
 138            index of the sample
 139
 140        Returns
 141        -------
 142        sample : torch.Tensor
 143            the corresponding annotation tensor
 144
 145        """
 146        return torch.tensor(float("nan"))
 147
 148    def get_len(self, return_unlabeled: bool) -> int:
 149        """Get the length of the subsample of labeled/unlabeled data.
 150
  151        If `return_unlabeled` is `True`, the length of the unlabeled subsample is returned; if `False`, the length
  152        of the labeled subsample; and if `None`, the length of the whole dataset.
 153
 154        Parameters
 155        ----------
 156        return_unlabeled : bool
 157            the identifier for the subsample
 158
 159        Returns
 160        -------
 161        length : int
 162            the length of the subsample
 163
 164        """
 165        return None
 166
 167    def get_idx(self, index: int, return_unlabeled: bool) -> int:
 168        """Convert from an index in the subsample of labeled/unlabeled data to an index in the full array.
 169
 170        If `return_unlabeled` is `True`, the index is in the subsample of unlabeled data, if `False` in labeled
 171        and if `return_unlabeled` is `None` the index is already correct.
 172
 173        Parameters
 174        ----------
 175        index : int
 176            the index in the subsample
 177        return_unlabeled : bool
 178            the identifier for the subsample
 179
 180        Returns
 181        -------
 182        corrected_index : int
 183            the index in the full dataset
 184
 185        """
 186        return index
 187
 188    def count_classes(
 189        self, perc: bool = False, zeros: bool = False, bouts: bool = False
 190    ) -> Dict:
 191        """Get a dictionary with class-wise frame counts.
 192
 193        Parameters
 194        ----------
 195        perc : bool, default False
 196            if `True`, a fraction of the total frame count is returned
 197        zeros : bool, default False
 198            if `True` and annotation is not exclusive, zero counts are returned
 199        bouts : bool, default False
 200            if `True`, instead of frame counts segment counts are returned
 201
 202        Returns
 203        -------
 204        count_dictionary : dict
 205            a dictionary with class indices as keys and frame counts as values
 206
 207        """
 208        return {}
 209
 210    def behaviors_dict(self) -> Dict:
 211        """Get a dictionary of class names.
 212
 213        Returns
 214        -------
 215        behavior_dictionary: dict
 216            a dictionary with class indices as keys and class names as values
 217
 218        """
 219        return {}
 220
 221    def annotation_class(self) -> str:
 222        """Get the type of annotation ('exclusive_classification', 'nonexclusive_classification', more coming soon).
 223
 224        Returns
 225        -------
 226        annotation_class : str
 227            the type of annotation
 228
 229        """
 230        return "none"
 231
 232    def size(self) -> int:
 233        """Get the total number of frames in the data.
 234
 235        Returns
 236        -------
 237        size : int
 238            the total number of frames
 239
 240        """
 241        return None
 242
 243    def filtered_indices(self) -> List:
 244        """Return the indices of the samples that should be removed.
 245
  246        Choosing the indices can be based on any kind of filtering defined in the `__init__` function by the data
 247        parameters.
 248
 249        Returns
 250        -------
 251        indices_to_remove : list
 252            a list of integer indices that should be removed
 253
 254        """
 255        return []
 256
 257    def set_pseudo_labels(self, labels: torch.Tensor) -> None:
 258        """Set pseudo labels to the unlabeled data.
 259
 260        Parameters
 261        ----------
 262        labels : torch.Tensor
 263            a tensor of pseudo-labels for the unlabeled data
 264
 265        """
 266        pass
 267
 268
 269class ActionSegmentationStore(BehaviorStore):  # +
 270    """A general realization of an annotation store for action segmentation tasks.
 271
 272    Assumes the following file structure:
 273    ```
 274    annotation_path
 275    ├── video1_annotation.pickle
 276    └── video2_labels.pickle
 277    ```
 278    Here `annotation_suffix` is `{'_annotation.pickle', '_labels.pickle'}`.
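
         A minimal usage sketch (assuming a concrete subclass such as `DLC2ActionStore`
         and hypothetical paths):
         ```
         from dlc2action.data.annotation_store import DLC2ActionStore

         ids = DLC2ActionStore.get_file_ids(
             annotation_path="/path/to/annotation_path",
             annotation_suffix={"_annotation.pickle", "_labels.pickle"},
         )
         store = DLC2ActionStore(
             video_order=ids,
             annotation_path="/path/to/annotation_path",
             annotation_suffix={"_annotation.pickle", "_labels.pickle"},
         )
         ```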
 279    """
 280
 281    def __init__(
 282        self,
 283        video_order: List = None,
 284        min_frames: Dict = None,
 285        max_frames: Dict = None,
 286        visibility: Dict = None,
 287        exclusive: bool = True,
 288        len_segment: int = 128,
 289        overlap: int = 0,
 290        behaviors: Set = None,
 291        ignored_classes: Set = None,
 292        ignored_clips: Set = None,
 293        annotation_suffix: Union[Set, str] = None,
 294        annotation_path: Union[Set, str] = None,
 295        behavior_file: str = None,
 296        correction: Dict = None,
 297        frame_limit: int = 0,
 298        filter_annotated: bool = False,
 299        filter_background: bool = False,
 300        error_class: str = None,
 301        min_frames_action: int = None,
 302        key_objects: Tuple = None,
 303        visibility_min_score: float = 0.2,
 304        visibility_min_frac: float = 0.7,
 305        mask: Dict = None,
 306        use_hard_negatives: bool = False,
 307        interactive: bool = False,
 308        *args,
 309        **kwargs,
 310    ) -> None:
 311        """Initialize the store.
 312
 313        Parameters
 314        ----------
 315        video_order : list, optional
 316            a list of video ids that should be processed in the same order (not passed if creating from key objects)
 317        min_frames : dict, optional
 318            a nested dictionary where first-level keys are video ids, second-level keys are clip ids and values are
 319            clip start frames (not passed if creating from key objects)
 320        max_frames : dict, optional
 321            a nested dictionary where first-level keys are video ids, second-level keys are clip ids and values are
 322            clip end frames (not passed if creating from key objects)
 323        visibility : dict, optional
 324            a nested dictionary where first-level keys are video ids, second-level keys are clip ids and values are
 325            visibility score arrays (not passed if creating from key objects or if irrelevant for the dataset)
 326        exclusive : bool, default True
 327            if True, the annotation is single-label; if False, multi-label
 328        len_segment : int, default 128
 329            the length of the segments in which the data should be cut (in frames)
 330        overlap : int, default 0
 331            the length of the overlap between neighboring segments (in frames)
 332        behaviors : set, optional
 333            the list of behaviors to put in the annotation (not passed if creating a blank instance or if behaviors are
 334            loaded from a file)
 335        ignored_classes : set, optional
 336            the list of behaviors from the behaviors list or file to not annotate
 337        ignored_clips : set, optional
 338            clip ids to ignore
 339        annotation_suffix : str | set, optional
  340            the suffix or the set of suffixes such that the annotation files are named {video_id}{annotation_suffix}
 341            (not passed if creating from key objects or if irrelevant for the dataset)
 342        annotation_path : str | set, optional
 343            the path or the set of paths to the folder where the annotation files are stored (not passed if creating
 344            from key objects)
 345        behavior_file : str, optional
 346            the path to an .xlsx behavior file (not passed if creating from key objects or if irrelevant for the dataset)
 347        correction : dict, optional
  348            a dictionary of corrections for the labels (e.g. {'sleping': 'sleeping', 'calm locomotion': 'locomotion'});
  349            can be used to correct for variations in naming or to merge several labels into one
 350        frame_limit : int, default 0
 351            the smallest possible length of a clip (shorter clips are discarded)
 352        filter_annotated : bool, default False
 353            if True, the samples that do not have any labels will be filtered
 354        filter_background : bool, default False
 355            if True, only the unlabeled frames that are close to annotated frames will be labeled as background
 356        error_class : str, optional
 357            the name of the error class (the annotations that intersect with this label will be discarded)
  358        min_frames_action : int, optional
 359            the minimum length of an action (shorter actions are not annotated)
 360        key_objects : tuple, optional
 361            the key objects to load the BehaviorStore from
  362        visibility_min_score : float, default 0.2
 363            the minimum visibility score for visibility filtering
 364        visibility_min_frac : float, default 0.7
 365            the minimum fraction of visible frames for visibility filtering
 366        mask : dict, optional
 367            a masked value dictionary (for active learning simulation experiments)
 368        use_hard_negatives : bool, default False
 369            mark hard negatives as 2 instead of 0 or 1, for loss functions that have options for hard negative processing
 370        interactive : bool, default False
 371            if `True`, annotation is assigned to pairs of individuals
 372
 373        """
 374        super().__init__()
 375
 376        if ignored_clips is None:
 377            ignored_clips = []
 378        self.len_segment = int(len_segment)
 379        self.exclusive = exclusive
 380        if isinstance(overlap, str):
 381            overlap = float(overlap)
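             # note: an overlap value below 1 is interpreted as a fraction of the segment length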
 382        if overlap < 1:
 383            overlap = overlap * self.len_segment
 384        self.overlap = int(overlap)
 385        self.video_order = video_order
 386        self.min_frames = min_frames
 387        self.max_frames = max_frames
 388        self.visibility = visibility
 389        self.vis_min_score = visibility_min_score
 390        self.vis_min_frac = visibility_min_frac
 391        self.mask = mask
 392        self.use_negatives = use_hard_negatives
 393        self.interactive = interactive
 394        self.ignored_clips = ignored_clips
 395        self.file_paths = self._get_file_paths(annotation_path)
 396        self.ignored_classes = ignored_classes
 397        self.update_behaviors = False
 398
 399        self.ram = True
 400        self.original_coordinates = []
 401        self.filtered = []
 402
 403        self.step = self.len_segment - self.overlap
 404
 405        self.ann_suffix = annotation_suffix
 406        self.annotation_folder = annotation_path
 407        self.filter_annotated = filter_annotated
 408        self.filter_background = filter_background
 409        self.frame_limit = frame_limit
 410        self.min_frames_action = min_frames_action
 411        self.error_class = error_class
 412
 413        if correction is None:
 414            correction = {}
 415        self.correction = correction
 416
 417        if self.max_frames is None:
 418            self.max_frames = defaultdict(lambda: {})
 419        if self.min_frames is None:
 420            self.min_frames = defaultdict(lambda: {})
 421
 422        lists = [self.annotation_folder, self.ann_suffix]
 423        for i in range(len(lists)):
  424            iterable = isinstance(lists[i], Iterable) and not isinstance(lists[i], str)
 425            if lists[i] is not None:
 426                if not iterable:
 427                    lists[i] = [lists[i]]
 428                lists[i] = [x for x in lists[i]]
 429        self.annotation_folder, self.ann_suffix = lists
 430
 431        if ignored_classes is None:
 432            ignored_classes = []
 433        self.ignored_classes = ignored_classes
 434        self._set_behaviors(behaviors, ignored_classes, behavior_file)
 435
 436        if key_objects is None and self.video_order is not None:
 437            self.data = self._load_data()
 438        elif key_objects is not None:
 439            self.load_from_key_objects(key_objects)
 440        else:
 441            self.data = None
 442        self.labeled_indices, self.unlabeled_indices = self._compute_labeled()
 443
 444    def __getitem__(self, ind):
 445        if self.data is None:
 446            raise RuntimeError("The annotation store data has not been initialized!")
 447        return self.data[ind]
 448
 449    def __len__(self) -> int:
 450        if self.data is None:
 451            raise RuntimeError("The annotation store data has not been initialized!")
 452        return len(self.data)
 453
 454    def remove(self, indices: List) -> None:
 455        """Remove the samples corresponding to indices.
 456
 457        Parameters
 458        ----------
 459        indices : list
 460            a list of integer indices to remove
 461
 462        """
 463        if len(indices) > 0:
 464            mask = np.ones(len(self.data))
 465            mask[indices] = 0
 466            mask = mask.astype(bool)
 467            self.data = self.data[mask]
 468            self.original_coordinates = self.original_coordinates[mask]
 469
 470    def key_objects(self) -> Tuple:
 471        """Return a tuple of the key objects necessary to re-create the Store.
 472
 473        Returns
 474        -------
 475        key_objects : tuple
 476            a tuple of key objects
 477
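             A sketch of saving and re-creating a store from its key objects (assuming a
             concrete subclass such as `DLC2ActionStore`):
             ```
             key_objects = store.key_objects()
             restored = DLC2ActionStore(key_objects=key_objects)
             ```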
 478        """
 479        return (
 480            self.original_coordinates,
 481            self.data,
 482            self.behaviors,
 483            self.exclusive,
 484            self.len_segment,
 485            self.step,
 486            self.overlap,
 487        )
 488
 489    def load_from_key_objects(self, key_objects: Tuple) -> None:
 490        """Load the information from a tuple of key objects.
 491
 492        Parameters
 493        ----------
 494        key_objects : tuple
 495            a tuple of key objects
 496
 497        """
 498        (
 499            self.original_coordinates,
 500            self.data,
 501            self.behaviors,
 502            self.exclusive,
 503            self.len_segment,
 504            self.step,
 505            self.overlap,
 506        ) = key_objects
 507        self.labeled_indices, self.unlabeled_indices = self._compute_labeled()
 508
 509    def to_ram(self) -> None:
 510        """Transfer the data samples to RAM if they were previously stored as file paths."""
 511        pass
 512
 513    def get_original_coordinates(self) -> np.ndarray:
 514        """Return the `video_indices` array.
 515
 516        Returns
 517        -------
 518        original_coordinates : numpy.ndarray
 519            an array that contains the coordinates of the data samples in original input data
 520
 521        """
 522        return self.original_coordinates
 523
 524    def create_subsample(self, indices: List, ssl_indices: List = None):
 525        """Create a new store that contains a subsample of the data.
 526
 527        Parameters
 528        ----------
 529        indices : list
 530            the indices to be included in the subsample
 531        ssl_indices : list, optional
 532            the indices to be included in the subsample without the annotation data
 533
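             A sketch (the index lists are illustrative):
             ```
             subsample = store.create_subsample(indices=[0, 1, 2], ssl_indices=[3, 4])
             ```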
 534        """
 535        if ssl_indices is None:
 536            ssl_indices = []
 537        data = copy(self.data)
 538        data[ssl_indices, ...] = -100
 539        new = self.new()
 540        new.original_coordinates = self.original_coordinates[indices + ssl_indices]
  541        new.data = data[indices + ssl_indices]
 542        new.labeled_indices, new.unlabeled_indices = new._compute_labeled()
 543        new.behaviors = self.behaviors
 544        new.exclusive = self.exclusive
 545        new.len_segment = self.len_segment
 546        new.step = self.step
 547        new.overlap = self.overlap
 548        new.max_frames = self.max_frames
 549        new.min_frames = self.min_frames
 550        return new
 551
 552    def get_len(self, return_unlabeled: bool) -> int:
 553        """Get the length of the subsample of labeled/unlabeled data.
 554
  555        If `return_unlabeled` is `True`, the length of the unlabeled subsample is returned; if `False`, the length
  556        of the labeled subsample; and if `None`, the length of the whole dataset.
 557
 558        Parameters
 559        ----------
 560        return_unlabeled : bool
 561            the identifier for the subsample
 562
 563        Returns
 564        -------
 565        length : int
 566            the length of the subsample
 567
 568        """
 569        if self.data is None:
 570            raise RuntimeError("The annotation store data has not been initialized!")
 571        elif return_unlabeled is None:
 572            return len(self.data)
 573        elif return_unlabeled:
 574            return len(self.unlabeled_indices)
 575        else:
 576            return len(self.labeled_indices)
 577
 578    def get_indices(self, return_unlabeled: bool) -> List:
 579        """Get a list of indices of samples in the labeled/unlabeled subset.
 580
 581        Parameters
 582        ----------
 583        return_unlabeled : bool
 584            the identifier for the subsample (`True` for unlabeled, `False` for labeled, `None` for the
 585            whole dataset)
 586
 587        Returns
 588        -------
 589        indices : list
 590            a list of indices that meet the criteria
 591
 592        """
 593        return list(range(len(self.data)))
 594
 595    def count_classes(
 596        self, perc: bool = False, zeros: bool = False, bouts: bool = False
 597    ) -> Dict:
 598        """Get a dictionary with class-wise frame counts.
 599
 600        Parameters
 601        ----------
 602        perc : bool, default False
 603            if `True`, a fraction of the total frame count is returned
 604        zeros : bool, default False
 605            if `True` and annotation is not exclusive, zero counts are returned
 606        bouts : bool, default False
 607            if `True`, instead of frame counts segment counts are returned
 608
 609        Returns
 610        -------
 611        count_dictionary : dict
 612            a dictionary with class indices as keys and frame counts as values
 613
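             For example, with the behavior dictionary `{0: "other", 1: "running", 2: "eating"}`
             the result could look like this (illustrative values):
             ```
             {0: 51200, 1: 1200, 2: 450}
             ```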
 614        """
 615        if bouts:
 616            if self.overlap != 0:
 617                data = {}
 618                for video, value in self.max_frames.items():
 619                    for clip, end in value.items():
 620                        length = end - self._get_min_frame(video, clip)
 621                        if self.exclusive:
 622                            data[f"{video}---{clip}"] = -100 * torch.ones(length)
 623                        else:
 624                            data[f"{video}---{clip}"] = -100 * torch.ones(
 625                                (len(self.behaviors_dict()), length)
 626                            )
 627                for x, coords in zip(self.data, self.original_coordinates):
 628                    split = coords[0].split("---")
 629                    l = self._get_max_frame(split[0], split[1]) - self._get_min_frame(
 630                        split[0], split[1]
 631                    )
 632                    i = coords[1]
 633                    start = int(i) * self.step
 634                    end = min(start + self.len_segment, l)
 635                    data[coords[0]][..., start:end] = x[..., : end - start]
 636                values = []
 637                for key, value in data.items():
 638                    values.append(value)
 639                    values.append(-100 * torch.ones((*value.shape[:-1], 1)))
 640                data = torch.cat(values, -1).T
 641            else:
 642                data = copy(self.data)
 643                if self.exclusive:
 644                    data = data.flatten()
 645                else:
 646                    data = data.transpose(1, 2).reshape(-1, len(self.behaviors))
 647            count_dictionary = {}
 648            for c in self.behaviors_dict():
 649                if self.exclusive:
 650                    arr = data == c
 651                else:
 652                    if zeros:
 653                        arr = data[:, c] == 0
 654                    else:
 655                        arr = data[:, c] == 1
 656                output, indices = torch.unique_consecutive(arr, return_inverse=True)
 657                true_indices = torch.where(output)[0]
 658                count_dictionary[c] = len(true_indices)
 659        else:
 660            ind = 1
 661            if zeros:
 662                ind = 0
 663            if self.exclusive:
 664                count_dictionary = dict(Counter(self.data.flatten().cpu().numpy()))
 665            else:
 666                d = {}
 667                for i in range(self.data.shape[1]):
 668                    cnt = Counter(self.data[:, i, :].flatten().cpu().numpy())
 669                    d[i] = cnt[ind]
 670                count_dictionary = d
 671            if perc:
 672                total = sum([v for k, v in count_dictionary.items()])
 673                count_dictionary = {k: v / total for k, v in count_dictionary.items()}
 674        for i in self.behaviors_dict():
 675            if i not in count_dictionary:
 676                count_dictionary[i] = 0
 677        return {int(k): v for k, v in count_dictionary.items()}
 678
 679    def behaviors_dict(self) -> Dict:
 680        """Get a dictionary of class names.
 681
 682        Returns
 683        -------
 684        behavior_dictionary: dict
 685            a dictionary with class indices as keys and class names as values
 686
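             For example, with `exclusive=True` and `behaviors=["running", "eating"]`:
             ```
             {0: "other", 1: "running", 2: "eating"}
             ```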
 687        """
 688        if self.exclusive and "other" not in self.behaviors:
 689            d = {i + 1: b for i, b in enumerate(self.behaviors)}
 690            d[0] = "other"
 691        else:
 692            d = {i: b for i, b in enumerate(self.behaviors)}
 693        return d
 694
 695    def annotation_class(self) -> str:
 696        """Get the type of annotation ('exclusive_classification', 'nonexclusive_classification').
 697
 698        Returns
 699        -------
 700        annotation_class : str
 701            the type of annotation
 702
 703        """
 704        if self.exclusive:
 705            return "exclusive_classification"
 706        else:
 707            return "nonexclusive_classification"
 708
 709    def size(self) -> int:
 710        """Get the total number of frames in the data.
 711
 712        Returns
 713        -------
 714        size : int
 715            the total number of frames
 716
 717        """
 718        return self.data.shape[0] * self.data.shape[-1]
 719
 720    def filtered_indices(self) -> List:
 721        """Return the indices of the samples that should be removed.
 722
  723        Choosing the indices can be based on any kind of filtering defined in the `__init__` function by the data
 724        parameters.
 725
 726        Returns
 727        -------
 728        indices_to_remove : list
 729            a list of integer indices that should be removed
 730
 731        """
 732        return self.filtered
 733
 734    def set_pseudo_labels(self, labels: torch.Tensor) -> None:
 735        """Set pseudo labels to the unlabeled data.
 736
 737        Parameters
 738        ----------
 739        labels : torch.Tensor
 740            a tensor of pseudo-labels for the unlabeled data
 741
 742        """
 743        self.data[self.unlabeled_indices] = labels
 744
 745    @classmethod
 746    def get_file_ids(
 747        cls,
 748        annotation_path: Union[str, Set],
 749        annotation_suffix: Union[str, Set],
 750        *args,
 751        **kwargs,
 752    ) -> List:
 753        """Get file ids.
 754
  755        Process data parameters and return a list of ids of the videos that should
 756        be processed by the `__init__` function.
 757
 758        Parameters
 759        ----------
 760        annotation_path : str | set
 761            the path or the set of paths to the folder where the annotation files are stored
 762        annotation_suffix : str | set, optional
  763            the suffix or the set of suffixes such that the annotation files are named {video_id}{annotation_suffix}
 764
 765        Returns
 766        -------
 767        video_ids : list
 768            a list of video file ids
 769
 770        """
 771        lists = [annotation_path, annotation_suffix]
 772        for i in range(len(lists)):
  773            iterable = isinstance(lists[i], Iterable) and not isinstance(lists[i], str)
 774            if lists[i] is not None:
 775                if not iterable:
 776                    lists[i] = [lists[i]]
 777                lists[i] = [x for x in lists[i]]
 778        annotation_path, annotation_suffix = lists
 779        files = []
 780        for folder in annotation_path:
 781            files += [
 782                strip_suffix(os.path.basename(file), annotation_suffix)
 783                for file in os.listdir(folder)
 784                if file.endswith(tuple([x for x in annotation_suffix]))
 785            ]
 786        files = sorted(files, key=lambda x: os.path.basename(x))
 787        return files
 788
 789    def _set_behaviors(
 790        self, behaviors: List, ignored_classes: List, behavior_file: str
 791    ):
 792        """Get a list of behaviors that should be annotated from behavior parameters."""
 793        if behaviors is not None:
 794            for b in ignored_classes:
 795                if b in behaviors:
 796                    behaviors.remove(b)
 797        self.behaviors = behaviors
 798
 799    def _compute_labeled(self) -> Tuple[torch.Tensor, torch.Tensor]:
 800        """Get the indices of labeled (annotated) and unlabeled samples."""
 801        if self.data is not None and len(self.data) > 0:
 802            unlabeled = torch.sum(self.data != -100, dim=1) == 0
 803            labeled_indices = torch.where(~unlabeled)[0]
 804            unlabeled_indices = torch.where(unlabeled)[0]
 805        else:
 806            labeled_indices, unlabeled_indices = torch.tensor([]), torch.tensor([])
 807        return labeled_indices, unlabeled_indices
 808
 809    def _generate_annotation(self, times: Dict, name: str) -> Dict:
 810        """Process a loaded annotation file to generate a training labels dictionary."""
 811        annotation = {}
 812        if self.behaviors is None and times is not None:
 813            self.update_behaviors = True
 814            behaviors = set()
 815            for d in times.values():
 816                behaviors.update([k for k, v in d.items()])
 817            self.behaviors = [
 818                x
 819                for x in sorted(behaviors)
 820                if x not in self.ignored_classes
 821                and not x.startswith("negative")
 822                and not x.startswith("unknown")
 823            ]
 824        beh_inv = {v: k for k, v in self.behaviors_dict().items()}
 825        # if there is no annotation file, generate empty annotation
 826        if self.interactive:
 827            clips = [
 828                "+".join(sorted(x))
 829                for x in combinations(self.max_frames[name].keys(), 2)
 830            ]
 831        else:
 832            clips = list(self.max_frames[name].keys())
 833        if times is None:
 834            clips = [x for x in clips if x not in self.ignored_clips]
 835        # otherwise, apply filters and generate label arrays
 836        else:
 837            clips = [
 838                x
 839                for x in clips
 840                if x not in self.ignored_clips and x not in times.keys()
 841            ]
 842            for ind in times.keys():
 843                try:
 844                    min_frame = self._get_min_frame(name, ind)
 845                    max_frame = self._get_max_frame(name, ind)
 846                except KeyError:
 847                    continue
 848                go_on = max_frame - min_frame + 1 >= self.frame_limit
 849                if go_on:
 850                    v_len = max_frame - min_frame + 1
 851                    if self.exclusive:
 852                        if not self.filter_background:
 853                            value = beh_inv.get("other", 0)
 854                            labels = np.ones(v_len, dtype=int) * value
 855                        else:
 856                            labels = -100 * np.ones(v_len, dtype=int)
 857                    else:
 858                        labels = np.zeros(
 859                            (len(self.behaviors), v_len), dtype=np.float32
 860                        )
 861                    cat_new = []
 862                    for cat in times[ind].keys():
 863                        if cat.startswith("unknown"):
 864                            cat_new.append(cat)
 865                    for cat in times[ind].keys():
 866                        if cat.startswith("negative"):
 867                            cat_new.append(cat)
 868                    for cat in times[ind].keys():
 869                        if not cat.startswith("negative") and not cat.startswith(
 870                            "unknown"
 871                        ):
 872                            cat_new.append(cat)
 873                    for cat in cat_new:
 874                        neg = False
 875                        unknown = False
 876                        cat_times = times[ind][cat]
 877                        if self.use_negatives and cat.startswith("negative"):
 878                            cat = " ".join(cat.split()[1:])
 879                            neg = True
 880                        elif cat.startswith("unknown"):
 881                            cat = " ".join(cat.split()[1:])
 882                            unknown = True
 883                        if cat in self.correction:
 884                            cat = self.correction[cat]
 885                        for entry in cat_times:
 886                            if len(entry) == 3:
 887                                start, end, amb = entry
 888                            else:
 889                                start, end = entry
 890                                amb = 0
 891                            if end > self._get_max_frame(name, ind) + 1:
 892                                end = self._get_max_frame(name, ind) + 1
 893                            if amb != 0:
 894                                continue
 895                            start -= min_frame
 896                            end -= min_frame
 897                            if (
 898                                self.min_frames_action is not None
 899                                and end - start < self.min_frames_action
 900                            ):
 901                                continue
 902                            if (
 903                                self.vis_min_frac > 0
 904                                and self.vis_min_score > 0
 905                                and self.visibility is not None
 906                            ):
 907                                s = 0
 908                                for ind_k in ind.split("+"):
 909                                    s += np.sum(
 910                                        self.visibility[name][ind_k][start:end]
 911                                        > self.vis_min_score
 912                                    )
 913                                if s < self.vis_min_frac * (end - start) * len(
 914                                    ind.split("+")
 915                                ):
 916                                    continue
 917                            if cat in beh_inv:
 918                                cat_i_global = beh_inv[cat]
 919                                if self.exclusive:
 920                                    labels[start:end] = cat_i_global
 921                                else:
 922                                    if unknown:
 923                                        labels[cat_i_global, start:end] = -100
 924                                    elif neg:
 925                                        labels[cat_i_global, start:end] = 2
 926                                    else:
 927                                        labels[cat_i_global, start:end] = 1
 928                            else:
 929                                self.not_found.add(cat)
 930                                if self.filter_background:
 931                                    if not self.exclusive:
 932                                        labels[:, start:end][
 933                                            labels[:, start:end] == 0
 934                                        ] = 3
 935                                    else:
 936                                        labels[start:end][labels[start:end] == -100] = 0
 937
 938                    if self.error_class is not None and self.error_class in times[ind]:
 939                        for start, end, amb in times[ind][self.error_class]:
 940                            if self.exclusive:
 941                                labels[start:end] = -100
 942                            else:
 943                                labels[:, start:end] = -100
 944                    annotation[os.path.basename(name) + "---" + str(ind)] = labels
 945        for ind in clips:
 946            try:
 947                min_frame = self._get_min_frame(name, ind)
 948                max_frame = self._get_max_frame(name, ind)
 949            except KeyError:
 950                continue
 951            go_on = max_frame - min_frame + 1 >= self.frame_limit
 952            if go_on:
 953                v_len = max_frame - min_frame + 1
 954                if self.exclusive:
 955                    annotation[os.path.basename(name) + "---" + str(ind)] = (
 956                        -100 * np.ones(v_len, dtype=int)
 957                    )
 958                else:
 959                    annotation[os.path.basename(name) + "---" + str(ind)] = (
 960                        -100 * np.ones((len(self.behaviors), v_len), dtype=np.float32)
 961                    )
 962        return annotation
 963
 964    def _make_trimmed_annotations(self, annotations_dict: Dict) -> torch.Tensor:
 965        """Cut a label dictionary into overlapping pieces of equal length."""
 966        labels = []
 967        self.original_coordinates = []
 968        masked_all = []
 969        for v_id in sorted(annotations_dict.keys()):
 970            if v_id in annotations_dict:
 971                annotations = annotations_dict[v_id]
 972            else:
 973                raise ValueError(
 974                    f'The id list in {v_id.split("---")[0]} is not consistent across files'
 975                )
 976            split = v_id.split("---")
 977            if len(split) > 1:
 978                video_id, ind = split
 979            else:
 980                video_id = split[0]
 981                ind = ""
 982            min_frame = self._get_min_frame(video_id, ind)
 983            max_frame = self._get_max_frame(video_id, ind)
 984            v_len = max_frame - min_frame + 1
 985            sp = np.arange(0, v_len, self.step)
 986            pad = sp[-1] + self.len_segment - v_len
 987            if self.exclusive:
 988                annotations = np.pad(annotations, ((0, pad)), constant_values=-100)
 989            else:
 990                annotations = np.pad(
 991                    annotations, ((0, 0), (0, pad)), constant_values=-100
 992                )
 993            masked = np.zeros(annotations.shape)
 994            if (
 995                self.mask is not None
 996                and video_id in self.mask["masked"]
 997                and ind in self.mask["masked"][video_id]
 998            ):
 999                for start, end in self.mask["masked"][video_id][ind]:
1000                    masked[..., int(start) : int(end)] = 1
1001            for i, start in enumerate(sp):
1002                self.original_coordinates.append((v_id, i))
1003                if self.exclusive:
1004                    ann = annotations[start : start + self.len_segment]
1005                    m = masked[start : start + self.len_segment]
1006                else:
1007                    ann = annotations[:, start : start + self.len_segment]
1008                    m = masked[:, start : start + self.len_segment]
1009                labels.append(ann)
1010                masked_all.append(m)
1011        self.original_coordinates = np.array(self.original_coordinates)
1012        labels = torch.tensor(np.array(labels))
1013        masked_all = torch.tensor(np.array(masked_all)).int().bool()
1014        if self.filter_background and not self.exclusive:
1015            for i, label in enumerate(labels):
1016                label[:, torch.sum((label == 1) | (label == 3), 0) == 0] = -100
1017                label[label == 3] = 0
1018        labels[(labels != -100) & masked_all] = -200
1019        return labels
1020
1021    @classmethod
1022    def _get_file_paths(cls, annotation_path: Union[str, Set]) -> List:
1023        """Get a list of relevant files."""
1024        file_paths = []
1025        if annotation_path is not None:
1026            if isinstance(annotation_path, str):
1027                annotation_path = [annotation_path]
1028            for folder in annotation_path:
1029                file_paths += [os.path.join(folder, x) for x in os.listdir(folder)]
1030        return file_paths
1031
1032    def _get_max_frame(self, video_id: str, clip_id: str):
1033        """Get the end frame of a clip in a video."""
1034        if clip_id in self.max_frames[video_id]:
1035            return self.max_frames[video_id][clip_id]
1036        else:
1037            return min(
1038                [self.max_frames[video_id][ind_k] for ind_k in clip_id.split("+")]
1039            )
1040
1041    def _get_min_frame(self, video_id, clip_id):
1042        """Get the start frame of a clip in a video."""
1043        if clip_id in self.min_frames[video_id]:
1048            return self.min_frames[video_id][clip_id]
1049        else:
1050            return max(
1051                [self.min_frames[video_id][ind_k] for ind_k in clip_id.split("+")]
1052            )
1053
1054    @abstractmethod
1055    def _load_data(self) -> torch.Tensor:
1056        """Load behavior annotation and generate annotation prompts."""
1057
1058
1059class FileActionSegStore(ActionSegmentationStore):  # +
1060    """A generalized implementation of `ActionSegmentationStore` for datasets where one file corresponds to one video."""
1061
1062    def _generate_max_min_frames(self, times: Dict, video_id: str) -> None:
1063        """Generate `max_frames` and `min_frames` objects in case they were not passed from an `InputStore`."""
1064        if video_id in self.max_frames:
1065            return
1066        for ind, cat_dict in times.items():
1067            maxes = []
1068            mins = []
1069            for cat, cat_list in cat_dict.items():
1070                if len(cat_list) > 0:
1071                    maxes.append(max([x[1] for x in cat_list]))
1072                    mins.append(min([x[0] for x in cat_list]))
1073            self.max_frames[video_id][ind] = max(maxes)
1074            self.min_frames[video_id][ind] = min(mins)
1075
1076    def _load_data(self) -> torch.Tensor:
1077        """Load behavior annotation and generate annotation prompts."""
1078        if self.video_order is None:
1079            return None
1080
1081        files = []
1082        for x in self.video_order:
1083            ok = False
1084            for folder in self.annotation_folder:
1085                for s in self.ann_suffix:
1086                    file = os.path.join(folder, x + s)
1087                    if os.path.exists(file):
1088                        files.append(file)
1089                        ok = True
1090                        break
1091            if not ok:
1092                files.append(None)
1093        self.not_found = set()
1094        annotations_dict = {}
1095        print("Computing annotation arrays...")
1096        for name, filename in tqdm(list(zip(self.video_order, files))):
1097            if filename is not None:
1098                times = self._open_annotations(filename)
1099            else:
1100                times = None
1101            if times is not None:
1102                self._generate_max_min_frames(times, name)
1103            annotations_dict.update(self._generate_annotation(times, name))
1104            del times
1105        annotation = self._make_trimmed_annotations(annotations_dict)
1106        del annotations_dict
1107        if self.filter_annotated:
1108            if self.exclusive:
1109                s = torch.sum((annotation != -100), dim=1)
1110            else:
1111                s = torch.sum(
1112                    torch.sum((annotation != -100), dim=1) == annotation.shape[1], dim=1
1113                )
1114            self.filtered += torch.where(s == 0)[0].tolist()
1115        annotation[annotation == -200] = -100
1116        return annotation
1117
1118    @abstractmethod
1119    def _open_annotations(self, filename: str) -> Dict:
1120        """Load the annotation from filename.
1121
1122        Parameters
1123        ----------
1124        filename : str
1125            path to an annotation file
1126
1127        Returns
1128        -------
1129        times : dict
1130            a nested dictionary where first-level keys are clip ids, second-level keys are categories and values are
1131            lists of (start, end, ambiguity status) lists
1132
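             A sketch of a valid return value (two clips, one annotated interval):
             ```
             {
                 "ind0": {"running": [[10, 40, 0]], "eating": []},
                 "ind1": {},
             }
             ```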
1133        """
1134
1135
1136class SequenceActionSegStore(ActionSegmentationStore):  # +
1137    """A generalized implementation of `ActionSegmentationStore` for datasets where one file corresponds to multiple videos."""
1138
1139    def _generate_max_min_frames(self, times: Dict) -> None:
1140        """Generate `max_frames` and `min_frames` objects in case they were not passed from an `InputStore`."""
1141        for video_id in times:
1142            if video_id in self.max_frames:
1143                continue
1144            self.max_frames[video_id] = {}
1145            for ind, cat_dict in times[video_id].items():
1146                maxes = []
1147                mins = []
1148                for cat, cat_list in cat_dict.items():
 1149                    maxes.extend(x[1] for x in cat_list)
 1150                    mins.extend(x[0] for x in cat_list)
1151                self.max_frames[video_id][ind] = max(maxes)
1152                self.min_frames[video_id][ind] = min(mins)
1153
1154    @classmethod
1155    def get_file_ids(
1156        cls,
1157        filenames: List = None,
1158        annotation_path: str = None,
1159        *args,
1160        **kwargs,
1161    ) -> List:
1162        """Get file ids.
1163
 1164        Process data parameters and return a list of ids of the videos that should
1165        be processed by the `__init__` function.
1166
1167        Parameters
1168        ----------
1169        filenames : list, optional
1170            a list of annotation file paths
1171        annotation_path : str, optional
1172            path to the annotation folder
1173
1174        Returns
1175        -------
1176        video_ids : list
1177            a list of video file ids
1178
1179        """
1180        file_paths = []
1181        if annotation_path is not None:
1182            if isinstance(annotation_path, str):
1183                annotation_path = [annotation_path]
1184            file_paths = []
1185            for folder in annotation_path:
1186                file_paths += [os.path.join(folder, x) for x in os.listdir(folder)]
1187        ids = set()
1188        for f in file_paths:
1189            if os.path.basename(f) in filenames:
1190                ids.add(os.path.basename(f))
1191        ids = sorted(ids)
1192        return ids
1193
1194    def _load_data(self) -> torch.Tensor:
1195        """Load behavior annotation and generate annotation prompts."""
1196        if self.video_order is None:
1197            return None
1198
1199        files = []
1200        for f in self.file_paths:
1201            if os.path.basename(f) in self.video_order:
1202                files.append(f)
1203        self.not_found = set()
1204        annotations_dict = {}
1205        for name, filename in tqdm(zip(self.video_order, files)):
1206            if filename is not None:
1207                times = self._open_sequences(filename)
1208            else:
1209                times = None
1210            if times is not None:
1211                self._generate_max_min_frames(times)
1212                none_ids = []
1213                for video_id, sequence_dict in times.items():
1214                    if sequence_dict is None:
 1215                        none_ids.append(video_id)
1216                        continue
1217                    annotations_dict.update(
1218                        self._generate_annotation(sequence_dict, video_id)
1219                    )
1220                for video_id in none_ids:
1221                    annotations_dict.update(self._generate_annotation(None, video_id))
1222                del times
1223        annotation = self._make_trimmed_annotations(annotations_dict)
1224        del annotations_dict
1225        if self.filter_annotated:
1226            if self.exclusive:
1227                s = torch.sum((annotation != -100), dim=1)
1228            else:
1229                s = torch.sum(
1230                    torch.sum((annotation != -100), dim=1) == annotation.shape[1], dim=1
1231                )
1232            self.filtered += torch.where(s == 0)[0].tolist()
1233        annotation[annotation == -200] = -100
1234        return annotation
1235
1236    @abstractmethod
1237    def _open_sequences(self, filename: str) -> Dict:
1238        """Load the annotation from filename.
1239
1240        Parameters
1241        ----------
1242        filename : str
1243            path to an annotation file
1244
1245        Returns
1246        -------
1247        times : dict
1248            a nested dictionary where first-level keys are video ids, second-level keys are clip ids,
1249            third-level keys are categories and values are
1250            lists of (start, end, ambiguity status) lists
1251
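             A sketch of a valid return value (one video with two clips):
             ```
             {
                 "video1": {
                     "ind0": {"running": [[10, 40, 0]]},
                     "ind1": {"eating": [[50, 100, 0]]},
                 }
             }
             ```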
1252        """
1253
1254
1255class DLC2ActionStore(FileActionSegStore):  # +
1256    """DLC type annotation data.
1257
1258    The files are either the DLC2Action GUI output or a pickled dictionary of the following structure:
1259        - nested dictionary,
1260        - first-level keys are individual IDs,
1261        - second-level keys are labels,
1262        - values are lists of intervals,
  1263        - each interval is formatted as `[start_frame, end_frame, ambiguity]`,
1264        - ambiguity is 1 if the action is ambiguous (!!at the moment DLC2Action will IGNORE those intervals!!) or 0 if it isn't.
1265    A minimum working example of such a dictionary is:
1266    ```
1267    {
1268        "ind0": {},
1269        "ind1": {
1270            "running": [60, 70, 0]],
1271            "eating": []
1272        }
1273    }
1274    ```
1275    Here there are two animals: `"ind0"` and `"ind1"`, and two actions: running and eating.
  1276    The only annotated action is running for `"ind1"` between frames 60 and 70.
1277    If you generate those files manually, run this code for a sanity check:
1278    ```
1279    import pickle
1280    with open("/path/to/annotation.pickle", "rb") as f:
  1281        data = pickle.load(f)
1282    for ind, ind_dict in data.items():
1283        print(f'individual {ind}:')
1284        for label, intervals in ind_dict.items():
1285            for start, end, ambiguity in intervals:
1286                if ambiguity == 0:
1287                    print(f'  from {start} to {end} frame: {label}')
1288    ```
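         Conversely, a minimal annotation file can be written like this (a sketch with a
         hypothetical path):
         ```
         import pickle

         annotation = {
             "ind0": {},
             "ind1": {"running": [[60, 70, 0]], "eating": []},
         }
         with open("/path/to/video1_annotation.pickle", "wb") as f:
             pickle.dump(annotation, f)
         ```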
1289    Assumes the following file structure:
1290    ```
1291    annotation_path
1292    ├── video1_annotation.pickle
1293    └── video2_labels.pickle
1294    ```
1295    Here `annotation_suffix` is `{'_annotation.pickle', '_labels.pickle'}`.
1296    """
1297
1298    def _open_annotations(self, filename: str) -> Dict:
1299        """Load the annotation from `filename`."""
1300        try:
1301            with open(filename, "rb") as f:
1302                data = pickle.load(f)
1303            if isinstance(data, dict):
1304                annotation = data
1305                for ind in annotation:
1306                    for cat, cat_list in annotation[ind].items():
1307                        annotation[ind][cat] = [
1308                            [start, end, 0] for start, end in cat_list
1309                        ]
1310            else:
1311                _, loaded_labels, animals, loaded_times = data
1312                annotation = {}
1313                for ind, ind_list in zip(animals, loaded_times):
1314                    annotation[ind] = {}
1315                    for cat, cat_list in zip(loaded_labels, ind_list):
1316                        annotation[ind][cat] = cat_list
1317            return annotation
  1318        except Exception:
1319            print(f"{filename} is invalid or does not exist")
1320            return None
1321
1322
1323class BorisStore(FileActionSegStore):  # +
1324    """BORIS type annotation data.
1325
1326    Assumes the following file structure:
1327    ```
1328    annotation_path
1329    ├── video1_annotation.pickle
1330    └── video2_labels.pickle
1331    ```
1332    Here `annotation_suffix` is `{'_annotation.pickle', '_labels.pickle'}`.
1333    """
1334
1335    def __init__(
1336        self,
1337        video_order: List = None,
1338        min_frames: Dict = None,
1339        max_frames: Dict = None,
1340        visibility: Dict = None,
1341        exclusive: bool = True,
1342        len_segment: int = 128,
1343        overlap: int = 0,
1344        behaviors: Set = None,
1345        ignored_classes: Set = None,
1346        annotation_suffix: Union[Set, str] = None,
1347        annotation_path: Union[Set, str] = None,
1348        behavior_file: str = None,
1349        correction: Dict = None,
1350        frame_limit: int = 0,
1351        filter_annotated: bool = False,
1352        filter_background: bool = False,
1353        error_class: str = None,
1354        min_frames_action: int = None,
1355        key_objects: Tuple = None,
1356        visibility_min_score: float = 0.2,
1357        visibility_min_frac: float = 0.7,
1358        mask: Dict = None,
1359        use_hard_negatives: bool = False,
1360        default_agent_name: str = "ind0",
1361        interactive: bool = False,
1362        ignored_clips: Set = None,
1363        *args,
1364        **kwargs,
1365    ) -> None:
1366        """Initialize a store.
1367
1368        Parameters
1369        ----------
1370        video_order : list, optional
1371            a list of video ids that should be processed in the same order (not passed if creating from key objects)
1372        min_frames : dict, optional
1373            a nested dictionary where first-level keys are video ids, second-level keys are clip ids and values are
1374            clip start frames (not passed if creating from key objects)
1375        max_frames : dict, optional
1376            a nested dictionary where first-level keys are video ids, second-level keys are clip ids and values are
1377            clip end frames (not passed if creating from key objects)
1378        visibility : dict, optional
1379            a nested dictionary where first-level keys are video ids, second-level keys are clip ids and values are
1380            visibility score arrays (not passed if creating from key objects or if irrelevant for the dataset)
1381        exclusive : bool, default True
1382            if True, the annotation is single-label; if False, multi-label
1383        len_segment : int, default 128
1384            the length of the segments in which the data should be cut (in frames)
1385        overlap : int, default 0
1386            the length of the overlap between neighboring segments (in frames)
1387        behaviors : set, optional
1388            the list of behaviors to put in the annotation (not passed if creating a blank instance or if behaviors are
1389            loaded from a file)
1390        ignored_classes : set, optional
1391            the list of behaviors from the behaviors list or file to not annotate
1392        annotation_suffix : str | set, optional
1393            the suffix or the set of suffices such that the annotation files are named {video_id}{annotation_suffix}
1394            (not passed if creating from key objects or if irrelevant for the dataset)
1395        annotation_path : str | set, optional
1396            the path or the set of paths to the folder where the annotation files are stored (not passed if creating
1397            from key objects)
1398        behavior_file : str, optional
1399            the path to an .xlsx behavior file (not passed if creating from key objects or if irrelevant for the dataset)
1400        correction : dict, optional
1401            a dictionary of label corrections (e.g. {'sleping': 'sleeping', 'calm locomotion': 'locomotion'});
1402            can be used to correct for variations in naming or to merge several labels into one
1403        frame_limit : int, default 0
1404            the smallest possible length of a clip (shorter clips are discarded)
1405        filter_annotated : bool, default False
1406            if True, the samples that do not have any labels will be filtered
1407        filter_background : bool, default False
1408            if True, only the unlabeled frames that are close to annotated frames will be labeled as background
1409        error_class : str, optional
1410            the name of the error class (the annotations that intersect with this label will be discarded)
1411        min_frames_action : int, default 0
1412            the minimum length of an action (shorter actions are not annotated)
1413        key_objects : tuple, optional
1414            the key objects to load the BehaviorStore from
1415        visibility_min_score : float, default 0.2
1416            the minimum visibility score for visibility filtering
1417        visibility_min_frac : float, default 0.7
1418            the minimum fraction of visible frames for visibility filtering
1419        mask : dict, optional
1420            a masked value dictionary (for active learning simulation experiments)
1421        use_hard_negatives : bool, default False
1422            mark hard negatives as 2 instead of 0 or 1, for loss functions that have options for hard negative processing
1423        default_agent_name : str, default 'ind0'
1424            the name of the default agent
1425        interactive : bool, default False
1426            if `True`, annotation is assigned to pairs of individuals
1427        ignored_clips : set, optional
1428            a set of clip ids to ignore
1429
1430        """
1431        self.default_agent_name = default_agent_name
1432        super().__init__(
1433            video_order=video_order,
1434            min_frames=min_frames,
1435            max_frames=max_frames,
1436            visibility=visibility,
1437            exclusive=exclusive,
1438            len_segment=len_segment,
1439            overlap=overlap,
1440            behaviors=behaviors,
1441            ignored_classes=ignored_classes,
1442            annotation_suffix=annotation_suffix,
1443            annotation_path=annotation_path,
1444            behavior_file=behavior_file,
1445            correction=correction,
1446            frame_limit=frame_limit,
1447            filter_annotated=filter_annotated,
1448            filter_background=filter_background,
1449            error_class=error_class,
1450            min_frames_action=min_frames_action,
1451            key_objects=key_objects,
1452            visibility_min_score=visibility_min_score,
1453            visibility_min_frac=visibility_min_frac,
1454            mask=mask,
1455            use_hard_negatives=use_hard_negatives,
1456            interactive=interactive,
1457            ignored_clips=ignored_clips,
1458        )
1459
1460    def _open_annotations(self, filename: str) -> Dict:
1461        """Load the annotation from filename."""
1462        try:
1463            df = pd.read_csv(filename, header=15)
1464            fps = df.iloc[0]["FPS"]
1465            df["Subject"] = df["Subject"].fillna(self.default_agent_name)
1466            loaded_labels = list(df["Behavior"].unique())
1467            animals = list(df["Subject"].unique())
1468            loaded_times = {}
1469            for ind in animals:
1470                loaded_times[ind] = {}
1471                agent_df = df[df["Subject"] == ind]
1472                for cat in loaded_labels:
1473                    filtered_df = agent_df[agent_df["Behavior"] == cat]
1474                    starts = (
1475                        filtered_df["Time"][filtered_df["Status"] == "START"] * fps
1476                    ).astype(int)
1477                    ends = (
1478                        filtered_df["Time"][filtered_df["Status"] == "STOP"] * fps
1479                    ).astype(int)
1480                    loaded_times[ind][cat] = [
1481                        [start, end, 0] for start, end in zip(starts, ends)
1482                    ]
1483            return loaded_times
1484        except Exception:
1485            print(f"{filename} is invalid or does not exist")
1486            return None
1487
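# Illustrative sketch (not part of the module): the columns that BorisStore._open_annotations
# reads from a BORIS export ("Time", "FPS", "Subject", "Behavior", "Status") and how a
# START/STOP pair becomes a frame interval. In real exports the first 15 metadata rows are
# skipped via `header=15`; the values below are hypothetical.
import pandas as pd

events = pd.DataFrame(
    {
        "Time": [1.0, 2.5],  # seconds
        "FPS": [30, 30],
        "Subject": ["mouse1", "mouse1"],
        "Behavior": ["sniffing", "sniffing"],
        "Status": ["START", "STOP"],
    }
)
fps = events.iloc[0]["FPS"]
start = int(events["Time"][events["Status"] == "START"].iloc[0] * fps)  # 30
end = int(events["Time"][events["Status"] == "STOP"].iloc[0] * fps)  # 75
interval = [start, end, 0]  # ambiguity flag 0, as in the loader above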
1488
1489class CalMS21Store(SequenceActionSegStore):  # +
1490    """CalMS21 annotation data.
1491
1492    Use the `'random:test_from_name:{name}'` and `'val-from-name:{val_name}:test-from-name:{test_name}'`
1493    partitioning methods with `'train'`, `'test'` and `'unlabeled'` names to separate into train, test and validation
1494    subsets according to the original files. For example, with `'val-from-name:test:test-from-name:unlabeled'`
1495    the data from the test file will go into the validation subset and the unlabeled files will become the test subset.
1496
1497    Assumes the following file structure:
1498    ```
1499    annotation_path
1500    ├── calms21_task_train.npy
1501    ├── calms21_task_test.npy
1502    ├── calms21_unlabeled_videos_part1.npy
1503    ├── calms21_unlabeled_videos_part2.npy
1504    └── calms21_unlabeled_videos_part3.npy
1505    ```
1506    """
1507
1508    def __init__(
1509        self,
1510        task_n: int = 1,
1511        include_task1: bool = True,
1512        video_order: List = None,
1513        min_frames: Dict = None,
1514        max_frames: Dict = None,
1515        len_segment: int = 128,
1516        overlap: int = 0,
1517        ignored_classes: Set = None,
1518        annotation_path: Union[Set, str] = None,
1519        key_objects: Tuple = None,
1520        treba_files: bool = False,
1521        *args,
1522        **kwargs,
1523    ) -> None:
1524        """Initialize the store.
1525
1526        Parameters
1527        ----------
1528        task_n : {1, 2, 3}, default 1
1529            the index of the CalMS21 challenge task
1530        include_task1 : bool, default True
1531            if `True`, include the task 1 training data in the training set
1532        video_order : list, optional
1533            a list of video ids that should be processed in the same order (not passed if creating from key objects)
1534        min_frames : dict, optional
1535            a nested dictionary where first-level keys are video ids, second-level keys are clip ids and values are
1536            clip start frames (not passed if creating from key objects)
1537        max_frames : dict, optional
1538            a nested dictionary where first-level keys are video ids, second-level keys are clip ids and values are
1539            clip end frames (not passed if creating from key objects)
1540        len_segment : int, default 128
1541            the length of the segments in which the data should be cut (in frames)
1542        overlap : int, default 0
1543            the length of the overlap between neighboring segments (in frames)
1544        ignored_classes : set, optional
1545            the list of behaviors from the behaviors list or file to not annotate
1546        annotation_path : str | set, optional
1547            the path or the set of paths to the folder where the annotation files are stored (not passed if creating
1548            from key objects)
1549        key_objects : tuple, optional
1550            the key objects to load the BehaviorStore from
1551        treba_files : bool, default False
1552            if `True`, TREBA feature files will be loaded
1553
1554        """
1555        self.task_n = int(task_n)
1556        self.include_task1 = include_task1
1557        if self.task_n == 1:
1558            self.include_task1 = False
1559        self.treba_files = treba_files
1560        if "exclusive" in kwargs:
1561            exclusive = kwargs["exclusive"]
1562        else:
1563            exclusive = True
1564        if "behaviors" in kwargs and kwargs["behaviors"] is not None:
1565            behaviors = kwargs["behaviors"]
1566        else:
1567            behaviors = ["attack", "investigation", "mount", "other"]
1568            if self.task_n == 3:
1569                exclusive = False
1570                behaviors += [
1571                    "approach",
1572                    "disengaged",
1573                    "groom",
1574                    "intromission",
1575                    "mount_attempt",
1576                    "sniff_face",
1577                    "whiterearing",
1578                ]
1579        super().__init__(
1580            video_order=video_order,
1581            min_frames=min_frames,
1582            max_frames=max_frames,
1583            exclusive=exclusive,
1584            len_segment=len_segment,
1585            overlap=overlap,
1586            behaviors=behaviors,
1587            ignored_classes=ignored_classes,
1588            annotation_path=annotation_path,
1589            key_objects=key_objects,
1590            filter_annotated=False,
1591            interactive=True,
1592        )
1593
1594    @classmethod
1595    def get_file_ids(
1596        cls,
1597        task_n: int = 1,
1598        include_task1: bool = False,
1599        treba_files: bool = False,
1600        annotation_path: Union[str, Set] = None,
1601        file_paths=None,
1602        *args,
1603        **kwargs,
1604    ) -> Iterable:
1605        """Get file ids.
1606
1607        Process data parameters and return a list of ids of the videos that should
1608        be processed by the `__init__` function.
1609
1610        Parameters
1611        ----------
1612        task_n : {1, 2, 3}
1613            the index of the CalMS21 challenge task
1614        include_task1 : bool, default False
1615            if `True`, the training file of the task 1 will be loaded
1616        treba_files : bool, default False
1617            if `True`, the TREBA feature files will be loaded
1618        filenames : set, optional
1619            a set of string filenames to search for (only basenames, not the whole paths)
1620        annotation_path : str | set, optional
1621            the path to the folder where the pose and feature files are stored or a set of such paths
1622            (not passed if creating from key objects or from `file_paths`)
1623        file_paths : set, optional
1624            a set of string paths to the pose and feature files
1625            (not passed if creating from key objects or from `annotation_path`)
1626
1627        Returns
1628        -------
1629        video_ids : list
1630            a list of video file ids
1631
1632        """
1633        task_n = int(task_n)
1634        if task_n == 1:
1635            include_task1 = False
1636        files = []
1637        if treba_files:
1638            postfix = "_features"
1639        else:
1640            postfix = ""
1641        files.append(f"calms21_task{task_n}_train{postfix}.npy")
1642        files.append(f"calms21_task{task_n}_test{postfix}.npy")
1643        if include_task1:
1644            files.append(f"calms21_task1_train{postfix}.npy")
1645        filenames = set(files)
1646        return SequenceActionSegStore.get_file_ids(
1647            filenames, annotation_path=annotation_path
1648        )
1649
1650    def _open_sequences(self, filename: str) -> Dict:
1651        """Load the annotation from filename.
1652
1653        Parameters
1654        ----------
1655        filename : str
1656            path to an annotation file
1657
1658        Returns
1659        -------
1660        times : dict
1661            a nested dictionary where first-level keys are video ids, second-level keys are clip ids,
1662            third-level keys are categories and values are
1663            lists of (start, end, ambiguity status) lists
1664
1665        """
1666        data_dict = np.load(filename, allow_pickle=True).item()
1667        data = {}
1668        result = {}
1669        keys = list(data_dict.keys())
1670        if "test" in os.path.basename(filename):
1671            mode = "test"
1672        elif "unlabeled" in os.path.basename(filename):
1673            mode = "unlabeled"
1674        else:
1675            mode = "train"
1676        if "approach" in keys:
1677            for behavior in keys:
1678                for key in data_dict[behavior].keys():
1679                    ann = data_dict[behavior][key]["annotations"]
1680                    result[f'{mode}--{key.split("/")[-1]}'] = {
1681                        "mouse1+mouse2": defaultdict(lambda: [])
1682                    }
1683                    starts = np.where(
1684                        np.diff(np.concatenate([np.array([0]), ann, np.array([0])]))
1685                        == 1
1686                    )[0]
1687                    ends = np.where(
1688                        np.diff(np.concatenate([np.array([0]), ann, np.array([0])]))
1689                        == -1
1690                    )[0]
1691                    for start, end in zip(starts, ends):
1692                        result[f'{mode}--{key.split("/")[-1]}']["mouse1+mouse2"][
1693                            behavior
1694                        ].append([start, end, 0])
1695                    for b in self.behaviors:
1696                        result[f'{mode}--{key.split("/")[-1]}'][
1697                            "mouse1+mouse2"
1698                        ][f"unknown {b}"].append([0, len(ann), 0])
1699        for key in keys:
1700            data.update(data_dict[key])
1701            data_dict.pop(key)
1702        if "approach" not in keys and self.task_n == 3:
1703            for key in data.keys():
1704                result[f'{mode}--{key.split("/")[-1]}'] = {"mouse1+mouse2": {}}
1705                ann = data[key]["annotations"]
1706                for i in range(4):
1707                    starts = np.where(
1708                        np.diff(
1709                            np.concatenate(
1710                                [np.array([0]), (ann == i).astype(int), np.array([0])]
1711                            )
1712                        )
1713                        == 1
1714                    )[0]
1715                    ends = np.where(
1716                        np.diff(
1717                            np.concatenate(
1718                                [np.array([0]), (ann == i).astype(int), np.array([0])]
1719                            )
1720                        )
1721                        == -1
1722                    )[0]
1723                    result[f'{mode}--{key.split("/")[-1]}']["mouse1+mouse2"][
1724                        self.behaviors_dict()[i]
1725                    ] = [[start, end, 0] for start, end in zip(starts, ends)]
1726        if self.task_n != 3:
1727            for seq_name, seq_dict in data.items():
1728                if "annotations" not in seq_dict:
1729                    return None
1730                behaviors = np.unique(seq_dict["annotations"])
1731                ann = seq_dict["annotations"]
1732                key = f'{mode}--{seq_name.split("/")[-1]}'
1733                result[key] = {"mouse1+mouse2": {}}
1734                for i in behaviors:
1735                    starts = np.where(
1736                        np.diff(
1737                            np.concatenate(
1738                                [np.array([0]), (ann == i).astype(int), np.array([0])]
1739                            )
1740                        )
1741                        == 1
1742                    )[0]
1743                    ends = np.where(
1744                        np.diff(
1745                            np.concatenate(
1746                                [np.array([0]), (ann == i).astype(int), np.array([0])]
1747                            )
1748                        )
1749                        == -1
1750                    )[0]
1751                    result[key]["mouse1+mouse2"][self.behaviors_dict()[i]] = [
1752                        [start, end, 0] for start, end in zip(starts, ends)
1753                    ]
1754        return result
1755
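# Illustrative sketch (not part of the module): the padded-diff idiom used above to turn a
# per-frame binary annotation vector into (start, end) intervals. The example vector is made up.
import numpy as np

ann = np.array([0, 1, 1, 1, 0, 0, 1, 1, 0])
padded = np.concatenate([np.array([0]), ann, np.array([0])])
starts = np.where(np.diff(padded) == 1)[0]  # [1, 6]
ends = np.where(np.diff(padded) == -1)[0]  # [4, 8]
intervals = [[s, e, 0] for s, e in zip(starts, ends)]  # [[1, 4, 0], [6, 8, 0]]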
1756
1757class CSVActionSegStore(FileActionSegStore):  # +
1758    """CSV type annotation data.
1759
1760    Assumes that files are saved as .csv tables with at least the following columns:
1761    - from / start : start of action,
1762    - to / end : end of action,
1763    - class / behavior / behaviour / label / type : action label.
1764
1765    If the times are set in seconds instead of frames, don't forget to set the `fps` parameter to your frame rate.
1766
1767    Assumes the following file structure:
1768    ```
1769    annotation_path
1770    ├── video1_annotation.csv
1771    └── video2_labels.csv
1772    ```
1773    Here `annotation_suffix` is `{'_annotation.csv', '_labels.csv'}`.
1774    """
1775
1776    def __init__(
1777        self,
1778        video_order: List = None,
1779        min_frames: Dict = None,
1780        max_frames: Dict = None,
1781        visibility: Dict = None,
1782        exclusive: bool = True,
1783        len_segment: int = 128,
1784        overlap: int = 0,
1785        behaviors: Set = None,
1786        ignored_classes: Set = None,
1787        annotation_suffix: Union[Set, str] = None,
1788        annotation_path: Union[Set, str] = None,
1789        behavior_file: str = None,
1790        correction: Dict = None,
1791        frame_limit: int = 0,
1792        filter_annotated: bool = False,
1793        filter_background: bool = False,
1794        error_class: str = None,
1795        min_frames_action: int = None,
1796        key_objects: Tuple = None,
1797        visibility_min_score: float = 0.2,
1798        visibility_min_frac: float = 0.7,
1799        mask: Dict = None,
1800        default_agent_name: str = "ind0",
1801        separator: str = ",",
1802        fps: int = 30,
1803        *args,
1804        **kwargs,
1805    ) -> None:
1806        """Initialize the store.
1807
1808        Parameters
1809        ----------
1810        video_order : list, optional
1811            a list of video ids that should be processed in the same order (not passed if creating from key objects)
1812        min_frames : dict, optional
1813            a nested dictionary where first-level keys are video ids, second-level keys are clip ids and values are
1814            clip start frames (not passed if creating from key objects)
1815        max_frames : dict, optional
1816            a nested dictionary where first-level keys are video ids, second-level keys are clip ids and values are
1817            clip end frames (not passed if creating from key objects)
1818        visibility : dict, optional
1819            a nested dictionary where first-level keys are video ids, second-level keys are clip ids and values are
1820            visibility score arrays (not passed if creating from key objects or if irrelevant for the dataset)
1821        exclusive : bool, default True
1822            if True, the annotation is single-label; if False, multi-label
1823        len_segment : int, default 128
1824            the length of the segments in which the data should be cut (in frames)
1825        overlap : int, default 0
1826            the length of the overlap between neighboring segments (in frames)
1827        behaviors : set, optional
1828            the list of behaviors to put in the annotation (not passed if creating a blank instance or if behaviors are
1829            loaded from a file)
1830        ignored_classes : set, optional
1831            the list of behaviors from the behaviors list or file to not annotate
1832        annotation_suffix : str | set, optional
1833            the suffix or the set of suffices such that the annotation files are named {video_id}{annotation_suffix}
1834            (not passed if creating from key objects or if irrelevant for the dataset)
1835        annotation_path : str | set, optional
1836            the path or the set of paths to the folder where the annotation files are stored (not passed if creating
1837            from key objects)
1838        behavior_file : str, optional
1839            the path to an .xlsx behavior file (not passed if creating from key objects or if irrelevant for the dataset)
1840        correction : dict, optional
1841            a dictionary of label corrections (e.g. {'sleping': 'sleeping', 'calm locomotion': 'locomotion'});
1842            can be used to correct for variations in naming or to merge several labels into one
1843        frame_limit : int, default 0
1844            the smallest possible length of a clip (shorter clips are discarded)
1845        filter_annotated : bool, default False
1846            if True, the samples that do not have any labels will be filtered
1847        filter_background : bool, default False
1848            if True, only the unlabeled frames that are close to annotated frames will be labeled as background
1849        error_class : str, optional
1850            the name of the error class (the annotations that intersect with this label will be discarded)
1851        min_frames_action : int, default 0
1852            the minimum length of an action (shorter actions are not annotated)
1853        key_objects : tuple, optional
1854            the key objects to load the BehaviorStore from
1855        visibility_min_score : float, default 0.2
1856            the minimum visibility score for visibility filtering
1857        visibility_min_frac : float, default 0.7
1858            the minimum fraction of visible frames for visibility filtering
1859        mask : dict, optional
1860            a masked value dictionary (for active learning simulation experiments)
1861        default_agent_name : str, default "ind0"
1862            the clip id to use when none is specified in the annotation file
1863        separator : str, default ","
1864            the separator in the csv files
1865        fps : int, default 30
1866            frames per second in the videos
1867
1868        """
1869        self.default_agent_name = default_agent_name
1870        self.separator = separator
1871        self.fps = fps
1872        super().__init__(
1873            video_order=video_order,
1874            min_frames=min_frames,
1875            max_frames=max_frames,
1876            visibility=visibility,
1877            exclusive=exclusive,
1878            len_segment=len_segment,
1879            overlap=overlap,
1880            behaviors=behaviors,
1881            ignored_classes=ignored_classes,
1882            ignored_clips=None,
1883            annotation_suffix=annotation_suffix,
1884            annotation_path=annotation_path,
1885            behavior_file=behavior_file,
1886            correction=correction,
1887            frame_limit=frame_limit,
1888            filter_annotated=filter_annotated,
1889            filter_background=filter_background,
1890            error_class=error_class,
1891            min_frames_action=min_frames_action,
1892            key_objects=key_objects,
1893            visibility_min_score=visibility_min_score,
1894            visibility_min_frac=visibility_min_frac,
1895            mask=mask,
1896        )
1897
1898    def _open_annotations(self, filename: str) -> Dict:
1899        """Load the annotation from `filename`."""
1900        data = pd.read_csv(filename, sep=self.separator)
1901        data.columns = list(map(lambda x: x.lower(), data.columns))
1902        starts, ends, actions = None, None, None
1903        start_names = ["from", "start"]
1904        for x in start_names:
1905            if x in data.columns:
1906                starts = data[x]
1907        end_names = ["to", "end"]
1908        for x in end_names:
1909            if x in data.columns:
1910                ends = data[x]
1911        class_names = ["class", "behavior", "behaviour", "type", "label"]
1912        for x in class_names:
1913            if x in data.columns:
1914                actions = data[x]
1915        if starts is None:
1916            raise ValueError("The file must have a column titled 'from' or 'start'!")
1917        if ends is None:
1918            raise ValueError("The file must have a column titled 'to' or 'end'!")
1919        if actions is None:
1920            raise ValueError(
1921                "The file must have a column titled 'class', 'behavior', 'behaviour', 'type' or 'label'!"
1922            )
1923        times = defaultdict(lambda: defaultdict(lambda: []))
1924        for start, end, action in zip(starts, ends, actions):
1925            if any([np.isnan(x) for x in [start, end]]):
1926                continue
1927            times[self.default_agent_name][action].append(
1928                [int(start * self.fps), int(end * self.fps), 0]
1929            )
1930        return times
1931
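# Illustrative sketch (not part of the module): a minimal table that CSVActionSegStore can
# parse -- one of {'from', 'start'}, one of {'to', 'end'} and one of
# {'class', 'behavior', 'behaviour', 'type', 'label'} columns, with times in seconds.
# The file name and labels are hypothetical.
import pandas as pd

pd.DataFrame(
    {
        "start": [0.0, 4.5],  # seconds; multiplied by `fps` when loaded
        "end": [3.2, 7.0],
        "behavior": ["rearing", "grooming"],
    }
).to_csv("video1_annotation.csv", index=False)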
1932
1933class SIMBAStore(FileActionSegStore):  # +
1934    """SIMBA paper format data.
1935
1936    Assumes the following file structure:
1937    ```
1938    annotation_path
1939    ├── Video1.csv
1940    ...
1941    └── Video9.csv
1942    """
1943
1944    def __init__(
1945        self,
1946        video_order: List = None,
1947        min_frames: Dict = None,
1948        max_frames: Dict = None,
1949        visibility: Dict = None,
1950        exclusive: bool = True,
1951        len_segment: int = 128,
1952        overlap: int = 0,
1953        behaviors: Set = None,
1954        ignored_classes: Set = None,
1955        ignored_clips: Set = None,
1956        annotation_path: Union[Set, str] = None,
1957        correction: Dict = None,
1958        filter_annotated: bool = False,
1959        filter_background: bool = False,
1960        error_class: str = None,
1961        min_frames_action: int = None,
1962        key_objects: Tuple = None,
1963        visibility_min_score: float = 0.2,
1964        visibility_min_frac: float = 0.7,
1965        mask: Dict = None,
1966        use_hard_negatives: bool = False,
1967        annotation_suffix: str = None,
1968        *args,
1969        **kwargs,
1970    ) -> None:
1971        """Initialize the annotation store.
1972
1973        Parameters
1974        ----------
1975        video_order : list, optional
1976            a list of video ids that should be processed in the same order (not passed if creating from key objects)
1977        min_frames : dict, optional
1978            a nested dictionary where first-level keys are video ids, second-level keys are clip ids and values are
1979            clip start frames (not passed if creating from key objects)
1980        max_frames : dict, optional
1981            a nested dictionary where first-level keys are video ids, second-level keys are clip ids and values are
1982            clip end frames (not passed if creating from key objects)
1983        visibility : dict, optional
1984            a nested dictionary where first-level keys are video ids, second-level keys are clip ids and values are
1985            visibility score arrays (not passed if creating from key objects or if irrelevant for the dataset)
1986        exclusive : bool, default True
1987            if True, the annotation is single-label; if False, multi-label
1988        len_segment : int, default 128
1989            the length of the segments in which the data should be cut (in frames)
1990        overlap : int, default 0
1991            the length of the overlap between neighboring segments (in frames)
1992        behaviors : set, optional
1993            the list of behaviors to put in the annotation (not passed if creating a blank instance or if behaviors are
1994            loaded from a file)
1995        ignored_classes : set, optional
1996            the list of behaviors from the behaviors list or file to not annotate
1997        ignored_clips : set, optional
1998            clip ids to ignore
1999        annotation_path : str | set, optional
2000            the path or the set of paths to the folder where the annotation files are stored (not passed if creating
2001            from key objects)
2002        behavior_file : str, optional
2003            the path to an .xlsx behavior file (not passed if creating from key objects or if irrelevant for the dataset)
2004        correction : dict, optional
2005            a dictionary of label corrections (e.g. {'sleping': 'sleeping', 'calm locomotion': 'locomotion'});
2006            can be used to correct for variations in naming or to merge several labels into one
2007        filter_annotated : bool, default False
2008            if True, the samples that do not have any labels will be filtered
2009        filter_background : bool, default False
2010            if True, only the unlabeled frames that are close to annotated frames will be labeled as background
2011        error_class : str, optional
2012            the name of the error class (the annotations that intersect with this label will be discarded)
2013        min_frames_action : int, default 0
2014            the minimum length of an action (shorter actions are not annotated)
2015        key_objects : tuple, optional
2016            the key objects to load the BehaviorStore from
2017        visibility_min_score : float, default 0.2
2018            the minimum visibility score for visibility filtering
2019        visibility_min_frac : float, default 0.7
2020            the minimum fraction of visible frames for visibility filtering
2021        mask : dict, optional
2022            a masked value dictionary (for active learning simulation experiments)
2023        use_hard_negatives : bool, default False
2024            mark hard negatives as 2 instead of 0 or 1, for loss functions that have options for hard negative processing
2025        annotation_suffix : str | set, optional
2026            the suffix or the set of suffices such that the annotation files are named {video_id}{annotation_suffix}
2027            (not passed if creating from key objects or if irrelevant for the dataset)
2028
2029        """
2030        super().__init__(
2031            video_order=video_order,
2032            min_frames=min_frames,
2033            max_frames=max_frames,
2034            visibility=visibility,
2035            exclusive=exclusive,
2036            len_segment=len_segment,
2037            overlap=overlap,
2038            behaviors=behaviors,
2039            ignored_classes=ignored_classes,
2040            ignored_clips=ignored_clips,
2041            annotation_suffix=annotation_suffix,
2042            annotation_path=annotation_path,
2043            behavior_file=None,
2044            correction=correction,
2045            frame_limit=0,
2046            filter_annotated=filter_annotated,
2047            filter_background=filter_background,
2048            error_class=error_class,
2049            min_frames_action=min_frames_action,
2050            key_objects=key_objects,
2051            visibility_min_score=visibility_min_score,
2052            visibility_min_frac=visibility_min_frac,
2053            mask=mask,
2054            use_hard_negatives=use_hard_negatives,
2055            interactive=True,
2056        )
2057
2058    def _open_annotations(self, filename: str) -> Dict:
2059        """Load the annotation from filename.
2060
2061        Parameters
2062        ----------
2063        filename : str
2064            path to an annotation file
2065
2066        Returns
2067        -------
2068        times : dict
2069            a nested dictionary where first-level keys are clip ids, second-level keys are categories and values are
2070            lists of (start, end, ambiguity status) lists
2071
2072        """
2073        data = pd.read_csv(filename)
2074        columns = [x for x in data.columns if x.split("_")[-1] == "x"]
2075        animals = sorted(set([x.split("_")[-2] for x in columns]))
2076        if len(animals) > 2:
2077            raise ValueError(
2078                "SIMBAStore is only implemented for files with 1 or 2 animals!"
2079            )
2080        if len(animals) == 1:
2081            ind = animals[0]
2082        else:
2083            ind = "+".join(animals)
2084        behaviors = [
2085            "_".join(x.split("_")[:-1])
2086            for x in data.columns
2087            if x.split("_")[-1] == "prediction"
2088        ]
2089        result = {}
2090        for behavior in behaviors:
2091            ann = data[f"{behavior}_prediction"].values
2092            diff = np.diff(
2093                np.concatenate([np.array([0]), (ann == 1).astype(int), np.array([0])])
2094            )
2095            starts = np.where(diff == 1)[0]
2096            ends = np.where(diff == -1)[0]
2097            result[behavior] = [[start, end, 0] for start, end in zip(starts, ends)]
2098            diff = np.diff(
2099                np.concatenate(
2100                    [np.array([0]), (np.isnan(ann)).astype(int), np.array([0])]
2101                )
2102            )
2103            starts = np.where(diff == 1)[0]
2104            ends = np.where(diff == -1)[0]
2105            result[f"unknown {behavior}"] = [
2106                [start, end, 0] for start, end in zip(starts, ends)
2107            ]
2108        if self.behaviors is not None:
2109            for behavior in self.behaviors:
2110                if behavior not in behaviors:
2111                    result[f"unknown {behavior}"] = [[0, len(data), 0]]
2112        return {ind: result}
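
# Illustrative sketch (not part of the module): the column layout that SIMBAStore._open_annotations
# relies on -- pose columns ending in '_x' whose second-to-last underscore-separated field is the
# animal name, plus one '{behavior}_prediction' column per behavior (1 = present, 0 = absent,
# NaN = unknown). Column and file names are hypothetical.
import numpy as np
import pandas as pd

pd.DataFrame(
    {
        "nose_mouse1_x": [10.0, 11.0, 12.0],
        "nose_mouse1_y": [5.0, 5.5, 6.0],
        "nose_mouse2_x": [20.0, 21.0, 22.0],
        "nose_mouse2_y": [8.0, 8.5, 9.0],
        "attack_prediction": [0, 1, np.nan],
    }
).to_csv("Video1.csv", index=False)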


class ActionSegmentationStore(dlc2action.data.base_store.BehaviorStore):
 270class ActionSegmentationStore(BehaviorStore):  # +
 271    """A general realization of an annotation store for action segmentation tasks.
 272
 273    Assumes the following file structure:
 274    ```
 275    annotation_path
 276    ├── video1_annotation.pickle
 277    └── video2_labels.pickle
 278    ```
 279    Here `annotation_suffix` is `{'_annotation.pickle', '_labels.pickle'}`.
 280    """
 281
 282    def __init__(
 283        self,
 284        video_order: List = None,
 285        min_frames: Dict = None,
 286        max_frames: Dict = None,
 287        visibility: Dict = None,
 288        exclusive: bool = True,
 289        len_segment: int = 128,
 290        overlap: int = 0,
 291        behaviors: Set = None,
 292        ignored_classes: Set = None,
 293        ignored_clips: Set = None,
 294        annotation_suffix: Union[Set, str] = None,
 295        annotation_path: Union[Set, str] = None,
 296        behavior_file: str = None,
 297        correction: Dict = None,
 298        frame_limit: int = 0,
 299        filter_annotated: bool = False,
 300        filter_background: bool = False,
 301        error_class: str = None,
 302        min_frames_action: int = None,
 303        key_objects: Tuple = None,
 304        visibility_min_score: float = 0.2,
 305        visibility_min_frac: float = 0.7,
 306        mask: Dict = None,
 307        use_hard_negatives: bool = False,
 308        interactive: bool = False,
 309        *args,
 310        **kwargs,
 311    ) -> None:
 312        """Initialize the store.
 313
 314        Parameters
 315        ----------
 316        video_order : list, optional
 317            a list of video ids that should be processed in the same order (not passed if creating from key objects)
 318        min_frames : dict, optional
 319            a nested dictionary where first-level keys are video ids, second-level keys are clip ids and values are
 320            clip start frames (not passed if creating from key objects)
 321        max_frames : dict, optional
 322            a nested dictionary where first-level keys are video ids, second-level keys are clip ids and values are
 323            clip end frames (not passed if creating from key objects)
 324        visibility : dict, optional
 325            a nested dictionary where first-level keys are video ids, second-level keys are clip ids and values are
 326            visibility score arrays (not passed if creating from key objects or if irrelevant for the dataset)
 327        exclusive : bool, default True
 328            if True, the annotation is single-label; if False, multi-label
 329        len_segment : int, default 128
 330            the length of the segments in which the data should be cut (in frames)
 331        overlap : int, default 0
 332            the length of the overlap between neighboring segments (in frames)
 333        behaviors : set, optional
 334            the list of behaviors to put in the annotation (not passed if creating a blank instance or if behaviors are
 335            loaded from a file)
 336        ignored_classes : set, optional
 337            the list of behaviors from the behaviors list or file to not annotate
 338        ignored_clips : set, optional
 339            clip ids to ignore
 340        annotation_suffix : str | set, optional
 341            the suffix or the set of suffices such that the annotation files are named {video_id}{annotation_suffix}
 342            (not passed if creating from key objects or if irrelevant for the dataset)
 343        annotation_path : str | set, optional
 344            the path or the set of paths to the folder where the annotation files are stored (not passed if creating
 345            from key objects)
 346        behavior_file : str, optional
 347            the path to an .xlsx behavior file (not passed if creating from key objects or if irrelevant for the dataset)
 348        correction : dict, optional
 349            a dictionary of label corrections (e.g. {'sleping': 'sleeping', 'calm locomotion': 'locomotion'});
 350            can be used to correct for variations in naming or to merge several labels into one
 351        frame_limit : int, default 0
 352            the smallest possible length of a clip (shorter clips are discarded)
 353        filter_annotated : bool, default False
 354            if True, the samples that do not have any labels will be filtered
 355        filter_background : bool, default False
 356            if True, only the unlabeled frames that are close to annotated frames will be labeled as background
 357        error_class : str, optional
 358            the name of the error class (the annotations that intersect with this label will be discarded)
 359        min_frames_action : int, optional
 360            the minimum length of an action (shorter actions are not annotated)
 361        key_objects : tuple, optional
 362            the key objects to load the BehaviorStore from
 363        visibility_min_score : float, default 0.2
 364            the minimum visibility score for visibility filtering
 365        visibility_min_frac : float, default 0.7
 366            the minimum fraction of visible frames for visibility filtering
 367        mask : dict, optional
 368            a masked value dictionary (for active learning simulation experiments)
 369        use_hard_negatives : bool, default False
 370            mark hard negatives as 2 instead of 0 or 1, for loss functions that have options for hard negative processing
 371        interactive : bool, default False
 372            if `True`, annotation is assigned to pairs of individuals
 373
 374        """
 375        super().__init__()
 376
 377        if ignored_clips is None:
 378            ignored_clips = []
 379        self.len_segment = int(len_segment)
 380        self.exclusive = exclusive
 381        if isinstance(overlap, str):
 382            overlap = float(overlap)
 383        if overlap < 1:
 384            overlap = overlap * self.len_segment
 385        self.overlap = int(overlap)
 386        self.video_order = video_order
 387        self.min_frames = min_frames
 388        self.max_frames = max_frames
 389        self.visibility = visibility
 390        self.vis_min_score = visibility_min_score
 391        self.vis_min_frac = visibility_min_frac
 392        self.mask = mask
 393        self.use_negatives = use_hard_negatives
 394        self.interactive = interactive
 395        self.ignored_clips = ignored_clips
 396        self.file_paths = self._get_file_paths(annotation_path)
 397        self.ignored_classes = ignored_classes
 398        self.update_behaviors = False
 399
 400        self.ram = True
 401        self.original_coordinates = []
 402        self.filtered = []
 403
 404        self.step = self.len_segment - self.overlap
 405
 406        self.ann_suffix = annotation_suffix
 407        self.annotation_folder = annotation_path
 408        self.filter_annotated = filter_annotated
 409        self.filter_background = filter_background
 410        self.frame_limit = frame_limit
 411        self.min_frames_action = min_frames_action
 412        self.error_class = error_class
 413
 414        if correction is None:
 415            correction = {}
 416        self.correction = correction
 417
 418        if self.max_frames is None:
 419            self.max_frames = defaultdict(lambda: {})
 420        if self.min_frames is None:
 421            self.min_frames = defaultdict(lambda: {})
 422
 423        lists = [self.annotation_folder, self.ann_suffix]
 424        for i in range(len(lists)):
 425            iterable = isinstance(lists[i], Iterable) * (not isinstance(lists[i], str))
 426            if lists[i] is not None:
 427                if not iterable:
 428                    lists[i] = [lists[i]]
 429                lists[i] = [x for x in lists[i]]
 430        self.annotation_folder, self.ann_suffix = lists
 431
 432        if ignored_classes is None:
 433            ignored_classes = []
 434        self.ignored_classes = ignored_classes
 435        self._set_behaviors(behaviors, ignored_classes, behavior_file)
 436
 437        if key_objects is None and self.video_order is not None:
 438            self.data = self._load_data()
 439        elif key_objects is not None:
 440            self.load_from_key_objects(key_objects)
 441        else:
 442            self.data = None
 443        self.labeled_indices, self.unlabeled_indices = self._compute_labeled()
 444
 445    def __getitem__(self, ind):
 446        if self.data is None:
 447            raise RuntimeError("The annotation store data has not been initialized!")
 448        return self.data[ind]
 449
 450    def __len__(self) -> int:
 451        if self.data is None:
 452            raise RuntimeError("The annotation store data has not been initialized!")
 453        return len(self.data)
 454
 455    def remove(self, indices: List) -> None:
 456        """Remove the samples corresponding to indices.
 457
 458        Parameters
 459        ----------
 460        indices : list
 461            a list of integer indices to remove
 462
 463        """
 464        if len(indices) > 0:
 465            mask = np.ones(len(self.data))
 466            mask[indices] = 0
 467            mask = mask.astype(bool)
 468            self.data = self.data[mask]
 469            self.original_coordinates = self.original_coordinates[mask]
 470
 471    def key_objects(self) -> Tuple:
 472        """Return a tuple of the key objects necessary to re-create the Store.
 473
 474        Returns
 475        -------
 476        key_objects : tuple
 477            a tuple of key objects
 478
 479        """
 480        return (
 481            self.original_coordinates,
 482            self.data,
 483            self.behaviors,
 484            self.exclusive,
 485            self.len_segment,
 486            self.step,
 487            self.overlap,
 488        )
 489
 490    def load_from_key_objects(self, key_objects: Tuple) -> None:
 491        """Load the information from a tuple of key objects.
 492
 493        Parameters
 494        ----------
 495        key_objects : tuple
 496            a tuple of key objects
 497
 498        """
 499        (
 500            self.original_coordinates,
 501            self.data,
 502            self.behaviors,
 503            self.exclusive,
 504            self.len_segment,
 505            self.step,
 506            self.overlap,
 507        ) = key_objects
 508        self.labeled_indices, self.unlabeled_indices = self._compute_labeled()
 509
 510    def to_ram(self) -> None:
 511        """Transfer the data samples to RAM if they were previously stored as file paths."""
 512        pass
 513
 514    def get_original_coordinates(self) -> np.ndarray:
 515        """Return the `original_coordinates` array.
 516
 517        Returns
 518        -------
 519        original_coordinates : numpy.ndarray
 520            an array that contains the coordinates of the data samples in original input data
 521
 522        """
 523        return self.original_coordinates
 524
 525    def create_subsample(self, indices: List, ssl_indices: List = None):
 526        """Create a new store that contains a subsample of the data.
 527
 528        Parameters
 529        ----------
 530        indices : list
 531            the indices to be included in the subsample
 532        ssl_indices : list, optional
 533            the indices to be included in the subsample without the annotation data
 534
 535        """
 536        if ssl_indices is None:
 537            ssl_indices = []
 538        data = copy(self.data)
 539        data[ssl_indices, ...] = -100
 540        new = self.new()
 541        new.original_coordinates = self.original_coordinates[indices + ssl_indices]
 542        new.data = data[indices + ssl_indices]
 543        new.labeled_indices, new.unlabeled_indices = new._compute_labeled()
 544        new.behaviors = self.behaviors
 545        new.exclusive = self.exclusive
 546        new.len_segment = self.len_segment
 547        new.step = self.step
 548        new.overlap = self.overlap
 549        new.max_frames = self.max_frames
 550        new.min_frames = self.min_frames
 551        return new
 552
 553    def get_len(self, return_unlabeled: bool) -> int:
 554        """Get the length of the subsample of labeled/unlabeled data.
 555
 556        If return_unlabeled is True, the index is in the subsample of unlabeled data, if False in labeled
 557        and if return_unlabeled is None the index is already correct.
 558
 559        Parameters
 560        ----------
 561        return_unlabeled : bool
 562            the identifier for the subsample
 563
 564        Returns
 565        -------
 566        length : int
 567            the length of the subsample
 568
 569        """
 570        if self.data is None:
 571            raise RuntimeError("The annotation store data has not been initialized!")
 572        elif return_unlabeled is None:
 573            return len(self.data)
 574        elif return_unlabeled:
 575            return len(self.unlabeled_indices)
 576        else:
 577            return len(self.labeled_indices)
 578
 579    def get_indices(self, return_unlabeled: bool) -> List:
 580        """Get a list of indices of samples in the labeled/unlabeled subset.
 581
 582        Parameters
 583        ----------
 584        return_unlabeled : bool
 585            the identifier for the subsample (`True` for unlabeled, `False` for labeled, `None` for the
 586            whole dataset)
 587
 588        Returns
 589        -------
 590        indices : list
 591            a list of indices that meet the criteria
 592
 593        """
 594        return list(range(len(self.data)))
 595
 596    def count_classes(
 597        self, perc: bool = False, zeros: bool = False, bouts: bool = False
 598    ) -> Dict:
 599        """Get a dictionary with class-wise frame counts.
 600
 601        Parameters
 602        ----------
 603        perc : bool, default False
 604            if `True`, a fraction of the total frame count is returned
 605        zeros : bool, default False
 606            if `True` and annotation is not exclusive, zero counts are returned
 607        bouts : bool, default False
 608            if `True`, instead of frame counts segment counts are returned
 609
 610        Returns
 611        -------
 612        count_dictionary : dict
 613            a dictionary with class indices as keys and frame counts as values
 614
 615        """
 616        if bouts:
 617            if self.overlap != 0:
 618                data = {}
 619                for video, value in self.max_frames.items():
 620                    for clip, end in value.items():
 621                        length = end - self._get_min_frame(video, clip)
 622                        if self.exclusive:
 623                            data[f"{video}---{clip}"] = -100 * torch.ones(length)
 624                        else:
 625                            data[f"{video}---{clip}"] = -100 * torch.ones(
 626                                (len(self.behaviors_dict()), length)
 627                            )
 628                for x, coords in zip(self.data, self.original_coordinates):
 629                    split = coords[0].split("---")
 630                    l = self._get_max_frame(split[0], split[1]) - self._get_min_frame(
 631                        split[0], split[1]
 632                    )
 633                    i = coords[1]
 634                    start = int(i) * self.step
 635                    end = min(start + self.len_segment, l)
 636                    data[coords[0]][..., start:end] = x[..., : end - start]
 637                values = []
 638                for key, value in data.items():
 639                    values.append(value)
 640                    values.append(-100 * torch.ones((*value.shape[:-1], 1)))
 641                data = torch.cat(values, -1).T
 642            else:
 643                data = copy(self.data)
 644                if self.exclusive:
 645                    data = data.flatten()
 646                else:
 647                    data = data.transpose(1, 2).reshape(-1, len(self.behaviors))
 648            count_dictionary = {}
 649            for c in self.behaviors_dict():
 650                if self.exclusive:
 651                    arr = data == c
 652                else:
 653                    if zeros:
 654                        arr = data[:, c] == 0
 655                    else:
 656                        arr = data[:, c] == 1
 657                output, indices = torch.unique_consecutive(arr, return_inverse=True)
 658                true_indices = torch.where(output)[0]
 659                count_dictionary[c] = len(true_indices)
 660        else:
 661            ind = 1
 662            if zeros:
 663                ind = 0
 664            if self.exclusive:
 665                count_dictionary = dict(Counter(self.data.flatten().cpu().numpy()))
 666            else:
 667                d = {}
 668                for i in range(self.data.shape[1]):
 669                    cnt = Counter(self.data[:, i, :].flatten().cpu().numpy())
 670                    d[i] = cnt[ind]
 671                count_dictionary = d
 672            if perc:
 673                total = sum([v for k, v in count_dictionary.items()])
 674                count_dictionary = {k: v / total for k, v in count_dictionary.items()}
 675        for i in self.behaviors_dict():
 676            if i not in count_dictionary:
 677                count_dictionary[i] = 0
 678        return {int(k): v for k, v in count_dictionary.items()}
 679
 680    def behaviors_dict(self) -> Dict:
 681        """Get a dictionary of class names.
 682
 683        Returns
 684        -------
 685        behavior_dictionary: dict
 686            a dictionary with class indices as keys and class names as values
 687
 688        """
 689        if self.exclusive and "other" not in self.behaviors:
 690            d = {i + 1: b for i, b in enumerate(self.behaviors)}
 691            d[0] = "other"
 692        else:
 693            d = {i: b for i, b in enumerate(self.behaviors)}
 694        return d
 695
 696    def annotation_class(self) -> str:
 697        """Get the type of annotation ('exclusive_classification', 'nonexclusive_classification').
 698
 699        Returns
 700        -------
 701        annotation_class : str
 702            the type of annotation
 703
 704        """
 705        if self.exclusive:
 706            return "exclusive_classification"
 707        else:
 708            return "nonexclusive_classification"
 709
 710    def size(self) -> int:
 711        """Get the total number of frames in the data.
 712
 713        Returns
 714        -------
 715        size : int
 716            the total number of frames
 717
 718        """
 719        return self.data.shape[0] * self.data.shape[-1]
 720
 721    def filtered_indices(self) -> List:
 722        """Return the indices of the samples that should be removed.
 723
 724        The choice of indices can be based on any kind of filtering defined by the data parameters
 725        passed to the `__init__` function.
 726
 727        Returns
 728        -------
 729        indices_to_remove : list
 730            a list of integer indices that should be removed
 731
 732        """
 733        return self.filtered
 734
 735    def set_pseudo_labels(self, labels: torch.Tensor) -> None:
 736        """Set pseudo labels to the unlabeled data.
 737
 738        Parameters
 739        ----------
 740        labels : torch.Tensor
 741            a tensor of pseudo-labels for the unlabeled data
 742
 743        """
 744        self.data[self.unlabeled_indices] = labels
 745
 746    @classmethod
 747    def get_file_ids(
 748        cls,
 749        annotation_path: Union[str, Set],
 750        annotation_suffix: Union[str, Set],
 751        *args,
 752        **kwargs,
 753    ) -> List:
 754        """Get file ids.
 755
 756        Process data parameters and return a list of ids of the videos that should
 757        be processed by the `__init__` function.
 758
 759        Parameters
 760        ----------
 761        annotation_path : str | set
 762            the path or the set of paths to the folder where the annotation files are stored
 763        annotation_suffix : str | set, optional
 764            the suffix or the set of suffixes such that the annotation files are named {video_id}{annotation_suffix}
 765
 766        Returns
 767        -------
 768        video_ids : list
 769            a list of video file ids
 770
 771        """
 772        lists = [annotation_path, annotation_suffix]
 773        for i in range(len(lists)):
 774            iterable = isinstance(lists[i], Iterable) * (not isinstance(lists[i], str))
 775            if lists[i] is not None:
 776                if not iterable:
 777                    lists[i] = [lists[i]]
 778                lists[i] = [x for x in lists[i]]
 779        annotation_path, annotation_suffix = lists
 780        files = []
 781        for folder in annotation_path:
 782            files += [
 783                strip_suffix(os.path.basename(file), annotation_suffix)
 784                for file in os.listdir(folder)
 785                if file.endswith(tuple([x for x in annotation_suffix]))
 786            ]
 787        files = sorted(files, key=lambda x: os.path.basename(x))
 788        return files
 789
 790    def _set_behaviors(
 791        self, behaviors: List, ignored_classes: List, behavior_file: str
 792    ):
 793        """Get a list of behaviors that should be annotated from behavior parameters."""
 794        if behaviors is not None:
 795            for b in ignored_classes:
 796                if b in behaviors:
 797                    behaviors.remove(b)
 798        self.behaviors = behaviors
 799
 800    def _compute_labeled(self) -> Tuple[torch.Tensor, torch.Tensor]:
 801        """Get the indices of labeled (annotated) and unlabeled samples."""
 802        if self.data is not None and len(self.data) > 0:
 803            unlabeled = torch.sum(self.data != -100, dim=1) == 0
 804            labeled_indices = torch.where(~unlabeled)[0]
 805            unlabeled_indices = torch.where(unlabeled)[0]
 806        else:
 807            labeled_indices, unlabeled_indices = torch.tensor([]), torch.tensor([])
 808        return labeled_indices, unlabeled_indices
 809
 810    def _generate_annotation(self, times: Dict, name: str) -> Dict:
 811        """Process a loaded annotation file to generate a training labels dictionary."""
 812        annotation = {}
 813        if self.behaviors is None and times is not None:
 814            self.update_behaviors = True
 815            behaviors = set()
 816            for d in times.values():
 817                behaviors.update([k for k, v in d.items()])
 818            self.behaviors = [
 819                x
 820                for x in sorted(behaviors)
 821                if x not in self.ignored_classes
 822                and not x.startswith("negative")
 823                and not x.startswith("unknown")
 824            ]
 825        beh_inv = {v: k for k, v in self.behaviors_dict().items()}
 826        # if there is no annotation file, generate empty annotation
 827        if self.interactive:
 828            clips = [
 829                "+".join(sorted(x))
 830                for x in combinations(self.max_frames[name].keys(), 2)
 831            ]
 832        else:
 833            clips = list(self.max_frames[name].keys())
 834        if times is None:
 835            clips = [x for x in clips if x not in self.ignored_clips]
 836        # otherwise, apply filters and generate label arrays
 837        else:
 838            clips = [
 839                x
 840                for x in clips
 841                if x not in self.ignored_clips and x not in times.keys()
 842            ]
 843            for ind in times.keys():
 844                try:
 845                    min_frame = self._get_min_frame(name, ind)
 846                    max_frame = self._get_max_frame(name, ind)
 847                except KeyError:
 848                    continue
 849                go_on = max_frame - min_frame + 1 >= self.frame_limit
 850                if go_on:
 851                    v_len = max_frame - min_frame + 1
 852                    if self.exclusive:
 853                        if not self.filter_background:
 854                            value = beh_inv.get("other", 0)
 855                            labels = np.ones(v_len, dtype=int) * value
 856                        else:
 857                            labels = -100 * np.ones(v_len, dtype=int)
 858                    else:
 859                        labels = np.zeros(
 860                            (len(self.behaviors), v_len), dtype=np.float32
 861                        )
 862                    cat_new = []
 863                    for cat in times[ind].keys():
 864                        if cat.startswith("unknown"):
 865                            cat_new.append(cat)
 866                    for cat in times[ind].keys():
 867                        if cat.startswith("negative"):
 868                            cat_new.append(cat)
 869                    for cat in times[ind].keys():
 870                        if not cat.startswith("negative") and not cat.startswith(
 871                            "unknown"
 872                        ):
 873                            cat_new.append(cat)
 874                    for cat in cat_new:
 875                        neg = False
 876                        unknown = False
 877                        cat_times = times[ind][cat]
 878                        if self.use_negatives and cat.startswith("negative"):
 879                            cat = " ".join(cat.split()[1:])
 880                            neg = True
 881                        elif cat.startswith("unknown"):
 882                            cat = " ".join(cat.split()[1:])
 883                            unknown = True
 884                        if cat in self.correction:
 885                            cat = self.correction[cat]
 886                        for entry in cat_times:
 887                            if len(entry) == 3:
 888                                start, end, amb = entry
 889                            else:
 890                                start, end = entry
 891                                amb = 0
 892                            if end > self._get_max_frame(name, ind) + 1:
 893                                end = self._get_max_frame(name, ind) + 1
 894                            if amb != 0:
 895                                continue
 896                            start -= min_frame
 897                            end -= min_frame
 898                            if (
 899                                self.min_frames_action is not None
 900                                and end - start < self.min_frames_action
 901                            ):
 902                                continue
 903                            if (
 904                                self.vis_min_frac > 0
 905                                and self.vis_min_score > 0
 906                                and self.visibility is not None
 907                            ):
 908                                s = 0
 909                                for ind_k in ind.split("+"):
 910                                    s += np.sum(
 911                                        self.visibility[name][ind_k][start:end]
 912                                        > self.vis_min_score
 913                                    )
 914                                if s < self.vis_min_frac * (end - start) * len(
 915                                    ind.split("+")
 916                                ):
 917                                    continue
 918                            if cat in beh_inv:
 919                                cat_i_global = beh_inv[cat]
 920                                if self.exclusive:
 921                                    labels[start:end] = cat_i_global
 922                                else:
 923                                    if unknown:
 924                                        labels[cat_i_global, start:end] = -100
 925                                    elif neg:
 926                                        labels[cat_i_global, start:end] = 2
 927                                    else:
 928                                        labels[cat_i_global, start:end] = 1
 929                            else:
 930                                self.not_found.add(cat)
 931                                if self.filter_background:
 932                                    if not self.exclusive:
 933                                        labels[:, start:end][
 934                                            labels[:, start:end] == 0
 935                                        ] = 3
 936                                    else:
 937                                        labels[start:end][labels[start:end] == -100] = 0
 938
 939                    if self.error_class is not None and self.error_class in times[ind]:
 940                        for start, end, amb in times[ind][self.error_class]:
 941                            if self.exclusive:
 942                                labels[start:end] = -100
 943                            else:
 944                                labels[:, start:end] = -100
 945                    annotation[os.path.basename(name) + "---" + str(ind)] = labels
 946        for ind in clips:
 947            try:
 948                min_frame = self._get_min_frame(name, ind)
 949                max_frame = self._get_max_frame(name, ind)
 950            except KeyError:
 951                continue
 952            go_on = max_frame - min_frame + 1 >= self.frame_limit
 953            if go_on:
 954                v_len = max_frame - min_frame + 1
 955                if self.exclusive:
 956                    annotation[os.path.basename(name) + "---" + str(ind)] = (
 957                        -100 * np.ones(v_len, dtype=int)
 958                    )
 959                else:
 960                    annotation[os.path.basename(name) + "---" + str(ind)] = (
 961                        -100 * np.ones((len(self.behaviors), v_len), dtype=np.float32)
 962                    )
 963        return annotation
 964
 965    def _make_trimmed_annotations(self, annotations_dict: Dict) -> torch.Tensor:
 966        """Cut a label dictionary into overlapping pieces of equal length."""
 967        labels = []
 968        self.original_coordinates = []
 969        masked_all = []
 970        for v_id in sorted(annotations_dict.keys()):
 971            if v_id in annotations_dict:
 972                annotations = annotations_dict[v_id]
 973            else:
 974                raise ValueError(
 975                    f'The id list in {v_id.split("---")[0]} is not consistent across files'
 976                )
 977            split = v_id.split("---")
 978            if len(split) > 1:
 979                video_id, ind = split
 980            else:
 981                video_id = split[0]
 982                ind = ""
 983            min_frame = self._get_min_frame(video_id, ind)
 984            max_frame = self._get_max_frame(video_id, ind)
 985            v_len = max_frame - min_frame + 1
 986            sp = np.arange(0, v_len, self.step)
 987            pad = sp[-1] + self.len_segment - v_len
 988            if self.exclusive:
 989                annotations = np.pad(annotations, ((0, pad)), constant_values=-100)
 990            else:
 991                annotations = np.pad(
 992                    annotations, ((0, 0), (0, pad)), constant_values=-100
 993                )
 994            masked = np.zeros(annotations.shape)
 995            if (
 996                self.mask is not None
 997                and video_id in self.mask["masked"]
 998                and ind in self.mask["masked"][video_id]
 999            ):
1000                for start, end in self.mask["masked"][video_id][ind]:
1001                    masked[..., int(start) : int(end)] = 1
1002            for i, start in enumerate(sp):
1003                self.original_coordinates.append((v_id, i))
1004                if self.exclusive:
1005                    ann = annotations[start : start + self.len_segment]
1006                    m = masked[start : start + self.len_segment]
1007                else:
1008                    ann = annotations[:, start : start + self.len_segment]
1009                    m = masked[:, start : start + self.len_segment]
1010                labels.append(ann)
1011                masked_all.append(m)
1012        self.original_coordinates = np.array(self.original_coordinates)
1013        labels = torch.tensor(np.array(labels))
1014        masked_all = torch.tensor(np.array(masked_all)).int().bool()
1015        if self.filter_background and not self.exclusive:
1016            for i, label in enumerate(labels):
1017                label[:, torch.sum((label == 1) | (label == 3), 0) == 0] = -100
1018                label[label == 3] = 0
1019        labels[(labels != -100) & masked_all] = -200
1020        return labels
1021
1022    @classmethod
1023    def _get_file_paths(cls, annotation_path: Union[str, Set]) -> List:
1024        """Get a list of relevant files."""
1025        file_paths = []
1026        if annotation_path is not None:
1027            if isinstance(annotation_path, str):
1028                annotation_path = [annotation_path]
1029            for folder in annotation_path:
1030                file_paths += [os.path.join(folder, x) for x in os.listdir(folder)]
1031        return file_paths
1032
1033    def _get_max_frame(self, video_id: str, clip_id: str):
1034        """Get the end frame of a clip in a video."""
1035        if clip_id in self.max_frames[video_id]:
1036            return self.max_frames[video_id][clip_id]
1037        else:
1038            return min(
1039                [self.max_frames[video_id][ind_k] for ind_k in clip_id.split("+")]
1040            )
1041
1042    def _get_min_frame(self, video_id, clip_id):
1043        """Get the start frame of a clip in a video."""
1044        if clip_id in self.min_frames[video_id]:
1045            if clip_id not in self.min_frames[video_id]:
1046                raise KeyError(
1047                    f"Check your individual names, video_id : {video_id}, clip_id : {clip_id}"
1048                )
1049            return self.min_frames[video_id][clip_id]
1050        else:
1051            return max(
1052                [self.min_frames[video_id][ind_k] for ind_k in clip_id.split("+")]
1053            )
1054
1055    @abstractmethod
1056    def _load_data(self) -> torch.Tensor:
1057        """Load behavior annotation and generate annotation prompts."""

class ActionSegmentationStore(BehaviorStore):

A general realization of an annotation store for action segmentation tasks.

Assumes the following file structure:

annotation_path
├── video1_annotation.pickle
└── video2_labels.pickle

Here annotation_suffix is {'_annotation.pickle', '_labels.pickle'}.
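For illustration, a minimal sketch of how video ids would be resolved for this layout with the `get_file_ids` class method; the folder name and files are the hypothetical ones from the tree above:

    from dlc2action.data.annotation_store import ActionSegmentationStore

    ids = ActionSegmentationStore.get_file_ids(
        annotation_path="annotation_path",
        annotation_suffix={"_annotation.pickle", "_labels.pickle"},
    )
    # ids == ["video1", "video2"]: each suffix is stripped and the ids are sorted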

ActionSegmentationStore( video_order: List = None, min_frames: Dict = None, max_frames: Dict = None, visibility: Dict = None, exclusive: bool = True, len_segment: int = 128, overlap: int = 0, behaviors: Set = None, ignored_classes: Set = None, ignored_clips: Set = None, annotation_suffix: Union[Set, str] = None, annotation_path: Union[Set, str] = None, behavior_file: str = None, correction: Dict = None, frame_limit: int = 0, filter_annotated: bool = False, filter_background: bool = False, error_class: str = None, min_frames_action: int = None, key_objects: Tuple = None, visibility_min_score: float = 0.2, visibility_min_frac: float = 0.7, mask: Dict = None, use_hard_negatives: bool = False, interactive: bool = False, *args, **kwargs)
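One detail of the constructor worth spelling out: an `overlap` value below 1 is interpreted as a fraction of `len_segment`, and the sliding-window step is `len_segment - overlap`. A small arithmetic sketch (the numbers are illustrative only):

    len_segment = 128
    overlap = 0.5                          # below 1, so read as a fraction of len_segment
    overlap = int(overlap * len_segment)   # 64 frames
    step = len_segment - overlap           # 64: segments start at frames 0, 64, 128, ...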

Instance attributes set in `__init__`:

len_segment
exclusive
overlap
video_order
min_frames
max_frames
visibility
vis_min_score
vis_min_frac
mask
use_negatives
interactive
ignored_clips
file_paths
ignored_classes
update_behaviors
ram
original_coordinates
filtered
step
ann_suffix
annotation_folder
filter_annotated
filter_background
frame_limit
min_frames_action
error_class
correction
def remove(self, indices: List) -> None:
Remove the samples corresponding to indices.

def key_objects(self) -> Tuple:
Return a tuple of the key objects necessary to re-create the Store.

def load_from_key_objects(self, key_objects: Tuple) -> None:
Load the information from a tuple of key objects.
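Together, `key_objects` and `load_from_key_objects` allow a store to be serialized and re-created. A hedged sketch of the round trip, assuming `store` is an already initialized instance of a concrete subclass that accepts the same keyword arguments:

    import pickle

    # save the key objects of an existing store
    with open("key_objects.pickle", "wb") as f:
        pickle.dump(store.key_objects(), f)

    # ... later, re-create a store of the same class from them
    with open("key_objects.pickle", "rb") as f:
        new_store = type(store)(key_objects=pickle.load(f))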

def to_ram(self) -> None:
Transfer the data samples to RAM if they were previously stored as file paths.

def get_original_coordinates(self) -> numpy.ndarray:
Return the `original_coordinates` array (the coordinates of the data samples in the original input data).
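In this store each entry of the array is a pair of the form ("{video_id}---{clip_id}", segment_index), as built in `_make_trimmed_annotations`. A hedged sketch of decoding one entry (assuming `store` is an initialized instance with clip-level annotation):

    coords = store.get_original_coordinates()
    video_id, clip_id = coords[0][0].split("---")
    # the segment starts segment_index * step frames after the clip start
    start_frame = int(coords[0][1]) * store.step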

def create_subsample(self, indices: List, ssl_indices: List = None):
Create a new store that contains a subsample of the data.
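For instance, a hedged sketch of a random train/validation split built on top of this method; `store` stands for an already initialized concrete store instance and the 80/20 split is illustrative:

    import numpy as np

    n = len(store)                  # number of samples in the full store
    perm = np.random.permutation(n)
    cut = int(0.8 * n)
    train_store = store.create_subsample(perm[:cut].tolist())
    val_store = store.create_subsample(perm[cut:].tolist())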

def get_len(self, return_unlabeled: bool) -> int:
Get the length of the subsample of labeled/unlabeled data.
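A minimal usage sketch (assuming `store` is an initialized instance):

    store.get_len(return_unlabeled=None)    # total number of samples
    store.get_len(return_unlabeled=False)   # number of labeled samples
    store.get_len(return_unlabeled=True)    # number of unlabeled samples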

def get_indices(self, return_unlabeled: bool) -> List:
Get a list of indices of samples in the labeled/unlabeled subset.

def count_classes( self, perc: bool = False, zeros: bool = False, bouts: bool = False) -> Dict:
Get a dictionary with class-wise frame counts.
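As an illustration, a hedged sketch of reading the result together with `behaviors_dict` (assuming `store` is an initialized instance; keys that are not behavior indices, such as -100 for unannotated frames, may also appear in the counts):

    counts = store.count_classes(perc=True)   # class index -> fraction of the total frame count
    names = store.behaviors_dict()            # class index -> behavior name
    for i, frac in sorted(counts.items()):
        print(f"{names.get(i, i)}: {frac:.1%}")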

def behaviors_dict(self) -> Dict:
Get a dictionary of class names.
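For example, with an illustrative behavior list the mapping depends on the annotation mode (setting the attributes directly here is only to demonstrate the logic):

    store.behaviors = ["approach", "sniff"]
    store.exclusive = True
    store.behaviors_dict()   # {1: "approach", 2: "sniff", 0: "other"}
    store.exclusive = False
    store.behaviors_dict()   # {0: "approach", 1: "sniff"}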

def annotation_class(self) -> str:
Get the type of annotation ('exclusive_classification' or 'nonexclusive_classification').

def size(self) -> int:
Get the total number of frames in the data.

def filtered_indices(self) -> List:
Return the indices of the samples that should be removed (based on the filtering defined by the data parameters in `__init__`).

def set_pseudo_labels(self, labels: torch.Tensor) -> None:
735    def set_pseudo_labels(self, labels: torch.Tensor) -> None:
736        """Set pseudo labels to the unlabeled data.
737
738        Parameters
739        ----------
740        labels : torch.Tensor
741            a tensor of pseudo-labels for the unlabeled data
742
743        """
744        self.data[self.unlabeled_indices] = labels


@classmethod
def get_file_ids( cls, annotation_path: Union[str, Set], annotation_suffix: Union[str, Set], *args, **kwargs) -> List:
746    @classmethod
747    def get_file_ids(
748        cls,
749        annotation_path: Union[str, Set],
750        annotation_suffix: Union[str, Set],
751        *args,
752        **kwargs,
753    ) -> List:
754        """Get file ids.
755
756        Process data parameters and return a list of ids  of the videos that should
757        be processed by the `__init__` function.
758
759        Parameters
760        ----------
761        annotation_path : str | set
762            the path or the set of paths to the folder where the annotation files are stored
763        annotation_suffix : str | set, optional
764            the suffix or the set of suffixes such that the annotation files are named {video_id}{annotation_suffix}
765
766        Returns
767        -------
768        video_ids : list
769            a list of video file ids
770
771        """
772        lists = [annotation_path, annotation_suffix]
773        for i in range(len(lists)):
774            iterable = isinstance(lists[i], Iterable) * (not isinstance(lists[i], str))
775            if lists[i] is not None:
776                if not iterable:
777                    lists[i] = [lists[i]]
778                lists[i] = [x for x in lists[i]]
779        annotation_path, annotation_suffix = lists
780        files = []
781        for folder in annotation_path:
782            files += [
783                strip_suffix(os.path.basename(file), annotation_suffix)
784                for file in os.listdir(folder)
785                if file.endswith(tuple([x for x in annotation_suffix]))
786            ]
787        files = sorted(files, key=lambda x: os.path.basename(x))
788        return files

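A hedged usage sketch, assuming a subclass such as `DLC2ActionStore` inherits this method unchanged and with a hypothetical folder path: with files named `video1_annotation.pickle` and `video2_labels.pickle` in the folder, the call below would return `['video1', 'video2']`.

```
from dlc2action.data.annotation_store import DLC2ActionStore

ids = DLC2ActionStore.get_file_ids(
    annotation_path="/path/to/annotation_folder",  # hypothetical folder
    annotation_suffix={"_annotation.pickle", "_labels.pickle"},
)
print(ids)  # ['video1', 'video2']
```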

class FileActionSegStore(ActionSegmentationStore):
1060class FileActionSegStore(ActionSegmentationStore):  # +
1061    """A generalized implementation of `ActionSegmentationStore` for datasets where one file corresponds to one video."""
1062
1063    def _generate_max_min_frames(self, times: Dict, video_id: str) -> None:
1064        """Generate `max_frames` and `min_frames` objects in case they were not passed from an `InputStore`."""
1065        if video_id in self.max_frames:
1066            return
1067        for ind, cat_dict in times.items():
1068            maxes = []
1069            mins = []
1070            for cat, cat_list in cat_dict.items():
1071                if len(cat_list) > 0:
1072                    maxes.append(max([x[1] for x in cat_list]))
1073                    mins.append(min([x[0] for x in cat_list]))
1074            self.max_frames[video_id][ind] = max(maxes)
1075            self.min_frames[video_id][ind] = min(mins)
1076
1077    def _load_data(self) -> torch.Tensor:
1078        """Load behavior annotation and generate annotation prompts."""
1079        if self.video_order is None:
1080            return None
1081
1082        files = []
1083        for x in self.video_order:
1084            ok = False
1085            for folder in self.annotation_folder:
1086                for s in self.ann_suffix:
1087                    file = os.path.join(folder, x + s)
1088                    if os.path.exists(file):
1089                        files.append(file)
1090                        ok = True
1091                        break
1092            if not ok:
1093                files.append(None)
1094        self.not_found = set()
1095        annotations_dict = {}
1096        print("Computing annotation arrays...")
1097        for name, filename in tqdm(list(zip(self.video_order, files))):
1098            if filename is not None:
1099                times = self._open_annotations(filename)
1100            else:
1101                times = None
1102            if times is not None:
1103                self._generate_max_min_frames(times, name)
1104            annotations_dict.update(self._generate_annotation(times, name))
1105            del times
1106        annotation = self._make_trimmed_annotations(annotations_dict)
1107        del annotations_dict
1108        if self.filter_annotated:
1109            if self.exclusive:
1110                s = torch.sum((annotation != -100), dim=1)
1111            else:
1112                s = torch.sum(
1113                    torch.sum((annotation != -100), dim=1) == annotation.shape[1], dim=1
1114                )
1115            self.filtered += torch.where(s == 0)[0].tolist()
1116        annotation[annotation == -200] = -100
1117        return annotation
1118
1119    @abstractmethod
1120    def _open_annotations(self, filename: str) -> Dict:
1121        """Load the annotation from filename.
1122
1123        Parameters
1124        ----------
1125        filename : str
1126            path to an annotation file
1127
1128        Returns
1129        -------
1130        times : dict
1131            a nested dictionary where first-level keys are clip ids, second-level keys are categories and values are
1132            lists of (start, end, ambiguity status) lists
1133
1134        """

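A hedged sketch of plugging a custom one-file-per-video dataset into this class: only `_open_annotations` needs to be implemented, returning the nested `{clip_id: {category: [[start, end, ambiguity], ...]}}` dictionary described above. The JSON layout used here is hypothetical, and the rest of the pipeline is assumed to be inherited unchanged.

```
import json
from typing import Dict

from dlc2action.data.annotation_store import FileActionSegStore


class MyJSONStore(FileActionSegStore):
    """Hypothetical store for {clip_id: {category: [[start, end], ...]}} JSON files."""

    def _open_annotations(self, filename: str) -> Dict:
        with open(filename) as f:
            raw = json.load(f)
        # add an ambiguity flag of 0 to every interval
        return {
            clip_id: {
                cat: [[start, end, 0] for start, end in intervals]
                for cat, intervals in cat_dict.items()
            }
            for clip_id, cat_dict in raw.items()
        }
```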

class SequenceActionSegStore(ActionSegmentationStore):
1137class SequenceActionSegStore(ActionSegmentationStore):  # +
1138    """A generalized implementation of `ActionSegmentationStore` for datasets where one file corresponds to multiple videos."""
1139
1140    def _generate_max_min_frames(self, times: Dict) -> None:
1141        """Generate `max_frames` and `min_frames` objects in case they were not passed from an `InputStore`."""
1142        for video_id in times:
1143            if video_id in self.max_frames:
1144                continue
1145            self.max_frames[video_id] = {}
1146            for ind, cat_dict in times[video_id].items():
1147                maxes = []
1148                mins = []
1149                for cat, cat_list in cat_dict.items():
1150                    maxes.append(max([x[1] for x in cat_list]))
1151                    mins.append(min([x[0] for x in cat_list]))
1152                self.max_frames[video_id][ind] = max(maxes)
1153                self.min_frames[video_id][ind] = min(mins)
1154
1155    @classmethod
1156    def get_file_ids(
1157        cls,
1158        filenames: List = None,
1159        annotation_path: str = None,
1160        *args,
1161        **kwargs,
1162    ) -> List:
1163        """Get file ids.
1164
1165        Process data parameters and return a list of ids  of the videos that should
1166        be processed by the `__init__` function.
1167
1168        Parameters
1169        ----------
1170        filenames : list, optional
1171            a list of annotation file paths
1172        annotation_path : str, optional
1173            path to the annotation folder
1174
1175        Returns
1176        -------
1177        video_ids : list
1178            a list of video file ids
1179
1180        """
1181        file_paths = []
1182        if annotation_path is not None:
1183            if isinstance(annotation_path, str):
1184                annotation_path = [annotation_path]
1185            file_paths = []
1186            for folder in annotation_path:
1187                file_paths += [os.path.join(folder, x) for x in os.listdir(folder)]
1188        ids = set()
1189        for f in file_paths:
1190            if os.path.basename(f) in filenames:
1191                ids.add(os.path.basename(f))
1192        ids = sorted(ids)
1193        return ids
1194
1195    def _load_data(self) -> torch.Tensor:
1196        """Load behavior annotation and generate annotation prompts."""
1197        if self.video_order is None:
1198            return None
1199
1200        files = []
1201        for f in self.file_paths:
1202            if os.path.basename(f) in self.video_order:
1203                files.append(f)
1204        self.not_found = set()
1205        annotations_dict = {}
1206        for name, filename in tqdm(zip(self.video_order, files)):
1207            if filename is not None:
1208                times = self._open_sequences(filename)
1209            else:
1210                times = None
1211            if times is not None:
1212                self._generate_max_min_frames(times)
1213                none_ids = []
1214                for video_id, sequence_dict in times.items():
1215                    if sequence_dict is None:
1216                        none_ids.append(video_id)
1217                        continue
1218                    annotations_dict.update(
1219                        self._generate_annotation(sequence_dict, video_id)
1220                    )
1221                for video_id in none_ids:
1222                    annotations_dict.update(self._generate_annotation(None, video_id))
1223                del times
1224        annotation = self._make_trimmed_annotations(annotations_dict)
1225        del annotations_dict
1226        if self.filter_annotated:
1227            if self.exclusive:
1228                s = torch.sum((annotation != -100), dim=1)
1229            else:
1230                s = torch.sum(
1231                    torch.sum((annotation != -100), dim=1) == annotation.shape[1], dim=1
1232                )
1233            self.filtered += torch.where(s == 0)[0].tolist()
1234        annotation[annotation == -200] = -100
1235        return annotation
1236
1237    @abstractmethod
1238    def _open_sequences(self, filename: str) -> Dict:
1239        """Load the annotation from filename.
1240
1241        Parameters
1242        ----------
1243        filename : str
1244            path to an annotation file
1245
1246        Returns
1247        -------
1248        times : dict
1249            a nested dictionary where first-level keys are video ids, second-level keys are clip ids,
1250            third-level keys are categories and values are
1251            lists of (start, end, ambiguity status) lists
1252
1253        """


@classmethod
def get_file_ids( cls, filenames: List = None, annotation_path: str = None, *args, **kwargs) -> List:
1155    @classmethod
1156    def get_file_ids(
1157        cls,
1158        filenames: List = None,
1159        annotation_path: str = None,
1160        *args,
1161        **kwargs,
1162    ) -> List:
1163        """Get file ids.
1164
1165        Process data parameters and return a list of ids  of the videos that should
1166        be processed by the `__init__` function.
1167
1168        Parameters
1169        ----------
1170        filenames : list, optional
1171            a list of annotation file paths
1172        annotation_path : str, optional
1173            path to the annotation folder
1174
1175        Returns
1176        -------
1177        video_ids : list
1178            a list of video file ids
1179
1180        """
1181        file_paths = []
1182        if annotation_path is not None:
1183            if isinstance(annotation_path, str):
1184                annotation_path = [annotation_path]
1185            file_paths = []
1186            for folder in annotation_path:
1187                file_paths += [os.path.join(folder, x) for x in os.listdir(folder)]
1188        ids = set()
1189        for f in file_paths:
1190            if os.path.basename(f) in filenames:
1191                ids.add(os.path.basename(f))
1192        ids = sorted(ids)
1193        return ids


class DLC2ActionStore(FileActionSegStore):
1256class DLC2ActionStore(FileActionSegStore):  # +
1257    """DLC type annotation data.
1258
1259    The files are either the DLC2Action GUI output or a pickled dictionary of the following structure:
1260        - nested dictionary,
1261        - first-level keys are individual IDs,
1262        - second-level keys are labels,
1263        - values are lists of intervals,
1264        - the lists of intervals is formatted as `[start_frame, end_frame, ambiguity]`,
1265        - ambiguity is 1 if the action is ambiguous (!!at the moment DLC2Action will IGNORE those intervals!!) or 0 if it isn't.
1266    A minimum working example of such a dictionary is:
1267    ```
1268    {
1269        "ind0": {},
1270        "ind1": {
1271            "running": [[60, 70, 0]],
1272            "eating": []
1273        }
1274    }
1275    ```
1276    Here there are two animals: `"ind0"` and `"ind1"`, and two actions: running and eating.
1277    The only annotated action is running for `"ind1"` between frames 60 and 70.
1278    If you generate those files manually, run this code for a sanity check:
1279    ```
1280    import pickle
1281    with open("/path/to/annotation.pickle", "rb") as f:
1282        data = pickle.load(f)
1283    for ind, ind_dict in data.items():
1284        print(f'individual {ind}:')
1285        for label, intervals in ind_dict.items():
1286            for start, end, ambiguity in intervals:
1287                if ambiguity == 0:
1288                    print(f'  from {start} to {end} frame: {label}')
1289    ```
1290    Assumes the following file structure:
1291    ```
1292    annotation_path
1293    ├── video1_annotation.pickle
1294    └── video2_labels.pickle
1295    ```
1296    Here `annotation_suffix` is `{'_annotation.pickle', '_labels.pickle'}`.
1297    """
1298
1299    def _open_annotations(self, filename: str) -> Dict:
1300        """Load the annotation from `filename`."""
1301        try:
1302            with open(filename, "rb") as f:
1303                data = pickle.load(f)
1304            if isinstance(data, dict):
1305                annotation = data
1306                for ind in annotation:
1307                    for cat, cat_list in annotation[ind].items():
1308                        annotation[ind][cat] = [
1309                            [start, end, 0] for start, end in cat_list
1310                        ]
1311            else:
1312                _, loaded_labels, animals, loaded_times = data
1313                annotation = {}
1314                for ind, ind_list in zip(animals, loaded_times):
1315                    annotation[ind] = {}
1316                    for cat, cat_list in zip(loaded_labels, ind_list):
1317                        annotation[ind][cat] = cat_list
1318            return annotation
1319        except:
1320            print(f"{filename} is invalid or does not exist")
1321            return None

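A minimal sketch of writing an annotation file in the dictionary format described in the class docstring; the file name, individual ids and labels are hypothetical.

```
import pickle

annotation = {
    "ind0": {},
    "ind1": {
        "running": [[60, 70, 0]],  # frames 60 to 70, unambiguous
        "eating": [],
    },
}
with open("video1_annotation.pickle", "wb") as f:
    pickle.dump(annotation, f)
```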

class BorisStore(FileActionSegStore):
1324class BorisStore(FileActionSegStore):  # +
1325    """BORIS type annotation data.
1326
1327    Assumes the following file structure:
1328    ```
1329    annotation_path
1330    ├── video1_annotation.csv
1331    └── video2_labels.csv
1332    ```
1333    Here `annotation_suffix` is `{'_annotation.csv', '_labels.csv'}`.
1334    """
1335
1336    def __init__(
1337        self,
1338        video_order: List = None,
1339        min_frames: Dict = None,
1340        max_frames: Dict = None,
1341        visibility: Dict = None,
1342        exclusive: bool = True,
1343        len_segment: int = 128,
1344        overlap: int = 0,
1345        behaviors: Set = None,
1346        ignored_classes: Set = None,
1347        annotation_suffix: Union[Set, str] = None,
1348        annotation_path: Union[Set, str] = None,
1349        behavior_file: str = None,
1350        correction: Dict = None,
1351        frame_limit: int = 0,
1352        filter_annotated: bool = False,
1353        filter_background: bool = False,
1354        error_class: str = None,
1355        min_frames_action: int = None,
1356        key_objects: Tuple = None,
1357        visibility_min_score: float = 0.2,
1358        visibility_min_frac: float = 0.7,
1359        mask: Dict = None,
1360        use_hard_negatives: bool = False,
1361        default_agent_name: str = "ind0",
1362        interactive: bool = False,
1363        ignored_clips: Set = None,
1364        *args,
1365        **kwargs,
1366    ) -> None:
1367        """Initialize a store.
1368
1369        Parameters
1370        ----------
1371        video_order : list, optional
1372            a list of video ids that should be processed in the same order (not passed if creating from key objects)
1373        min_frames : dict, optional
1374            a nested dictionary where first-level keys are video ids, second-level keys are clip ids and values are
1375            clip start frames (not passed if creating from key objects)
1376        max_frames : dict, optional
1377            a nested dictionary where first-level keys are video ids, second-level keys are clip ids and values are
1378            clip end frames (not passed if creating from key objects)
1379        visibility : dict, optional
1380            a nested dictionary where first-level keys are video ids, second-level keys are clip ids and values are
1381            visibility score arrays (not passed if creating from key objects or if irrelevant for the dataset)
1382        exclusive : bool, default True
1383            if True, the annotation is single-label; if False, multi-label
1384        len_segment : int, default 128
1385            the length of the segments in which the data should be cut (in frames)
1386        overlap : int, default 0
1387            the length of the overlap between neighboring segments (in frames)
1388        behaviors : set, optional
1389            the list of behaviors to put in the annotation (not passed if creating a blank instance or if behaviors are
1390            loaded from a file)
1391        ignored_classes : set, optional
1392            the list of behaviors from the behaviors list or file to not annotate
1393        annotation_suffix : str | set, optional
1394            the suffix or the set of suffixes such that the annotation files are named {video_id}{annotation_suffix}
1395            (not passed if creating from key objects or if irrelevant for the dataset)
1396        annotation_path : str | set, optional
1397            the path or the set of paths to the folder where the annotation files are stored (not passed if creating
1398            from key objects)
1399        behavior_file : str, optional
1400            the path to an .xlsx behavior file (not passed if creating from key objects or if irrelevant for the dataset)
1401        correction : dict, optional
1402            a dictionary of corrections for the labels (e.g. {'sleping': 'sleeping', 'calm locomotion': 'locomotion'});
1403            it can be used to correct for variations in naming or to merge several labels into one
1404        frame_limit : int, default 0
1405            the smallest possible length of a clip (shorter clips are discarded)
1406        filter_annotated : bool, default False
1407            if True, the samples that do not have any labels will be filtered
1408        filter_background : bool, default False
1409            if True, only the unlabeled frames that are close to annotated frames will be labeled as background
1410        error_class : str, optional
1411            the name of the error class (the annotations that intersect with this label will be discarded)
1412        min_frames_action : int, default 0
1413            the minimum length of an action (shorter actions are not annotated)
1414        key_objects : tuple, optional
1415            the key objects to load the BehaviorStore from
1416        visibility_min_score : float, default 0.2
1417            the minimum visibility score for visibility filtering
1418        visibility_min_frac : float, default 0.7
1419            the minimum fraction of visible frames for visibility filtering
1420        mask : dict, optional
1421            a masked value dictionary (for active learning simulation experiments)
1422        use_hard_negatives : bool, default False
1423            mark hard negatives as 2 instead of 0 or 1, for loss functions that have options for hard negative processing
1424        default_agent_name : str, default 'ind0'
1425            the name of the default agent
1426        interactive : bool, default False
1427            if `True`, annotation is assigned to pairs of individuals
1428        ignored_clips : set, optional
1429            a set of clip ids to ignore
1430
1431        """
1432        self.default_agent_name = default_agent_name
1433        super().__init__(
1434            video_order=video_order,
1435            min_frames=min_frames,
1436            max_frames=max_frames,
1437            visibility=visibility,
1438            exclusive=exclusive,
1439            len_segment=len_segment,
1440            overlap=overlap,
1441            behaviors=behaviors,
1442            ignored_classes=ignored_classes,
1443            annotation_suffix=annotation_suffix,
1444            annotation_path=annotation_path,
1445            behavior_file=behavior_file,
1446            correction=correction,
1447            frame_limit=frame_limit,
1448            filter_annotated=filter_annotated,
1449            filter_background=filter_background,
1450            error_class=error_class,
1451            min_frames_action=min_frames_action,
1452            key_objects=key_objects,
1453            visibility_min_score=visibility_min_score,
1454            visibility_min_frac=visibility_min_frac,
1455            mask=mask,
1456            use_hard_negatives=use_hard_negatives,
1457            interactive=interactive,
1458            ignored_clips=ignored_clips,
1459        )
1460
1461    def _open_annotations(self, filename: str) -> Dict:
1462        """Load the annotation from filename."""
1463        try:
1464            df = pd.read_csv(filename, header=15)
1465            fps = df.iloc[0]["FPS"]
1466            df["Subject"] = df["Subject"].fillna(self.default_agent_name)
1467            loaded_labels = list(df["Behavior"].unique())
1468            animals = list(df["Subject"].unique())
1469            loaded_times = {}
1470            for ind in animals:
1471                loaded_times[ind] = {}
1472                agent_df = df[df["Subject"] == ind]
1473                for cat in loaded_labels:
1474                    filtered_df = agent_df[agent_df["Behavior"] == cat]
1475                    starts = (
1476                        filtered_df["Time"][filtered_df["Status"] == "START"] * fps
1477                    ).astype(int)
1478                    ends = (
1479                        filtered_df["Time"][filtered_df["Status"] == "STOP"] * fps
1480                    ).astype(int)
1481                    loaded_times[ind][cat] = [
1482                        [start, end, 0] for start, end in zip(starts, ends)
1483                    ]
1484            return loaded_times
1485        except:
1486            print(f"{filename} is invalid or does not exist")
1487            return None

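To make the conversion in `_open_annotations` concrete, here is a hedged toy example of how BORIS START/STOP rows are paired into frame intervals; the DataFrame is made up and far simpler than a real BORIS export.

```
import pandas as pd

fps = 30
df = pd.DataFrame(
    {
        "Subject": ["mouse1", "mouse1"],
        "Behavior": ["grooming", "grooming"],
        "Status": ["START", "STOP"],
        "Time": [1.0, 2.5],  # seconds
    }
)
starts = (df["Time"][df["Status"] == "START"] * fps).astype(int)
ends = (df["Time"][df["Status"] == "STOP"] * fps).astype(int)
print([[int(s), int(e), 0] for s, e in zip(starts, ends)])  # [[30, 75, 0]]
```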

BorisStore( video_order: List = None, min_frames: Dict = None, max_frames: Dict = None, visibility: Dict = None, exclusive: bool = True, len_segment: int = 128, overlap: int = 0, behaviors: Set = None, ignored_classes: Set = None, annotation_suffix: Union[Set, str] = None, annotation_path: Union[Set, str] = None, behavior_file: str = None, correction: Dict = None, frame_limit: int = 0, filter_annotated: bool = False, filter_background: bool = False, error_class: str = None, min_frames_action: int = None, key_objects: Tuple = None, visibility_min_score: float = 0.2, visibility_min_frac: float = 0.7, mask: Dict = None, use_hard_negatives: bool = False, default_agent_name: str = 'ind0', interactive: bool = False, ignored_clips: Set = None, *args, **kwargs)
1336    def __init__(
1337        self,
1338        video_order: List = None,
1339        min_frames: Dict = None,
1340        max_frames: Dict = None,
1341        visibility: Dict = None,
1342        exclusive: bool = True,
1343        len_segment: int = 128,
1344        overlap: int = 0,
1345        behaviors: Set = None,
1346        ignored_classes: Set = None,
1347        annotation_suffix: Union[Set, str] = None,
1348        annotation_path: Union[Set, str] = None,
1349        behavior_file: str = None,
1350        correction: Dict = None,
1351        frame_limit: int = 0,
1352        filter_annotated: bool = False,
1353        filter_background: bool = False,
1354        error_class: str = None,
1355        min_frames_action: int = None,
1356        key_objects: Tuple = None,
1357        visibility_min_score: float = 0.2,
1358        visibility_min_frac: float = 0.7,
1359        mask: Dict = None,
1360        use_hard_negatives: bool = False,
1361        default_agent_name: str = "ind0",
1362        interactive: bool = False,
1363        ignored_clips: Set = None,
1364        *args,
1365        **kwargs,
1366    ) -> None:
1367        """Initialize a store.
1368
1369        Parameters
1370        ----------
1371        video_order : list, optional
1372            a list of video ids that should be processed in the same order (not passed if creating from key objects)
1373        min_frames : dict, optional
1374            a nested dictionary where first-level keys are video ids, second-level keys are clip ids and values are
1375            clip start frames (not passed if creating from key objects)
1376        max_frames : dict, optional
1377            a nested dictionary where first-level keys are video ids, second-level keys are clip ids and values are
1378            clip end frames (not passed if creating from key objects)
1379        visibility : dict, optional
1380            a nested dictionary where first-level keys are video ids, second-level keys are clip ids and values are
1381            visibility score arrays (not passed if creating from key objects or if irrelevant for the dataset)
1382        exclusive : bool, default True
1383            if True, the annotation is single-label; if False, multi-label
1384        len_segment : int, default 128
1385            the length of the segments in which the data should be cut (in frames)
1386        overlap : int, default 0
1387            the length of the overlap between neighboring segments (in frames)
1388        behaviors : set, optional
1389            the list of behaviors to put in the annotation (not passed if creating a blank instance or if behaviors are
1390            loaded from a file)
1391        ignored_classes : set, optional
1392            the list of behaviors from the behaviors list or file to not annotate
1393        annotation_suffix : str | set, optional
1394            the suffix or the set of suffixes such that the annotation files are named {video_id}{annotation_suffix}
1395            (not passed if creating from key objects or if irrelevant for the dataset)
1396        annotation_path : str | set, optional
1397            the path or the set of paths to the folder where the annotation files are stored (not passed if creating
1398            from key objects)
1399        behavior_file : str, optional
1400            the path to an .xlsx behavior file (not passed if creating from key objects or if irrelevant for the dataset)
1401        correction : dict, optional
1402            a dictionary of corrections for the labels (e.g. {'sleping': 'sleeping', 'calm locomotion': 'locomotion'});
1403            it can be used to correct for variations in naming or to merge several labels into one
1404        frame_limit : int, default 0
1405            the smallest possible length of a clip (shorter clips are discarded)
1406        filter_annotated : bool, default False
1407            if True, the samples that do not have any labels will be filtered
1408        filter_background : bool, default False
1409            if True, only the unlabeled frames that are close to annotated frames will be labeled as background
1410        error_class : str, optional
1411            the name of the error class (the annotations that intersect with this label will be discarded)
1412        min_frames_action : int, default 0
1413            the minimum length of an action (shorter actions are not annotated)
1414        key_objects : tuple, optional
1415            the key objects to load the BehaviorStore from
1416        visibility_min_score : float, default 0.2
1417            the minimum visibility score for visibility filtering
1418        visibility_min_frac : float, default 0.7
1419            the minimum fraction of visible frames for visibility filtering
1420        mask : dict, optional
1421            a masked value dictionary (for active learning simulation experiments)
1422        use_hard_negatives : bool, default False
1423            mark hard negatives as 2 instead of 0 or 1, for loss functions that have options for hard negative processing
1424        default_agent_name : str, default 'ind0'
1425            the name of the default agent
1426        interactive : bool, default False
1427            if `True`, annotation is assigned to pairs of individuals
1428        ignored_clips : set, optional
1429            a set of clip ids to ignore
1430
1431        """
1432        self.default_agent_name = default_agent_name
1433        super().__init__(
1434            video_order=video_order,
1435            min_frames=min_frames,
1436            max_frames=max_frames,
1437            visibility=visibility,
1438            exclusive=exclusive,
1439            len_segment=len_segment,
1440            overlap=overlap,
1441            behaviors=behaviors,
1442            ignored_classes=ignored_classes,
1443            annotation_suffix=annotation_suffix,
1444            annotation_path=annotation_path,
1445            behavior_file=behavior_file,
1446            correction=correction,
1447            frame_limit=frame_limit,
1448            filter_annotated=filter_annotated,
1449            filter_background=filter_background,
1450            error_class=error_class,
1451            min_frames_action=min_frames_action,
1452            key_objects=key_objects,
1453            visibility_min_score=visibility_min_score,
1454            visibility_min_frac=visibility_min_frac,
1455            mask=mask,
1456            use_hard_negatives=use_hard_negatives,
1457            interactive=interactive,
1458            ignored_clips=ignored_clips,
1459        )

class CalMS21Store(SequenceActionSegStore):
1490class CalMS21Store(SequenceActionSegStore):  # +
1491    """CalMS21 annotation data.
1492
1493    Use the `'random:test_from_name:{name}'` and `'val-from-name:{val_name}:test-from-name:{test_name}'`
1494    partitioning methods with `'train'`, `'test'` and `'unlabeled'` names to separate into train, test and validation
1495    subsets according to the original files. For example, with `'val-from-name:test:test-from-name:unlabeled'`
1496    the data from the test file will go into validation and the unlabeled files will be the test.
1497
1498    Assumes the following file structure:
1499    ```
1500    annotation_path
1501    ├── calms21_task_train.npy
1502    ├── calms21_task_test.npy
1503    ├── calms21_unlabeled_videos_part1.npy
1504    ├── calms21_unlabeled_videos_part2.npy
1505    └── calms21_unlabeled_videos_part3.npy
1506    ```
1507    """
1508
1509    def __init__(
1510        self,
1511        task_n: int = 1,
1512        include_task1: bool = True,
1513        video_order: List = None,
1514        min_frames: Dict = None,
1515        max_frames: Dict = None,
1516        len_segment: int = 128,
1517        overlap: int = 0,
1518        ignored_classes: Set = None,
1519        annotation_path: Union[Set, str] = None,
1520        key_objects: Tuple = None,
1521        treba_files: bool = False,
1522        *args,
1523        **kwargs,
1524    ) -> None:
1525        """Initialize the store.
1526
1527        Parameters
1528        ----------
1529        task_n : {1, 2, 3}
1530            the index of the CalMS21 challenge task
1531        include_task1 : bool, default True
1532            if `True`, the task 1 training data is added to the training set
1533        video_order : list, optional
1534            a list of video ids that should be processed in the same order (not passed if creating from key objects)
1535        min_frames : dict, optional
1536            a nested dictionary where first-level keys are video ids, second-level keys are clip ids and values are
1537            clip start frames (not passed if creating from key objects)
1538        max_frames : dict, optional
1539            a nested dictionary where first-level keys are video ids, second-level keys are clip ids and values are
1540            clip end frames (not passed if creating from key objects)
1541        len_segment : int, default 128
1542            the length of the segments in which the data should be cut (in frames)
1543        overlap : int, default 0
1544            the length of the overlap between neighboring segments (in frames)
1545        ignored_classes : set, optional
1546            the list of behaviors from the behaviors list or file to not annotate
1547        annotation_path : str | set, optional
1548            the path or the set of paths to the folder where the annotation files are stored (not passed if creating
1549            from key objects)
1550        key_objects : tuple, optional
1551            the key objects to load the BehaviorStore from
1552        treba_files : bool, default False
1553            if `True`, TREBA feature files will be loaded
1554
1555        """
1556        self.task_n = int(task_n)
1557        self.include_task1 = include_task1
1558        if self.task_n == 1:
1559            self.include_task1 = False
1560        self.treba_files = treba_files
1561        if "exclusive" in kwargs:
1562            exclusive = kwargs["exclusive"]
1563        else:
1564            exclusive = True
1565        if "behaviors" in kwargs and kwargs["behaviors"] is not None:
1566            behaviors = kwargs["behaviors"]
1567        else:
1568            behaviors = ["attack", "investigation", "mount", "other"]
1569            if task_n == 3:
1570                exclusive = False
1571                behaviors += [
1572                    "approach",
1573                    "disengaged",
1574                    "groom",
1575                    "intromission",
1576                    "mount_attempt",
1577                    "sniff_face",
1578                    "whiterearing",
1579                ]
1580        super().__init__(
1581            video_order=video_order,
1582            min_frames=min_frames,
1583            max_frames=max_frames,
1584            exclusive=exclusive,
1585            len_segment=len_segment,
1586            overlap=overlap,
1587            behaviors=behaviors,
1588            ignored_classes=ignored_classes,
1589            annotation_path=annotation_path,
1590            key_objects=key_objects,
1591            filter_annotated=False,
1592            interactive=True,
1593        )
1594
1595    @classmethod
1596    def get_file_ids(
1597        cls,
1598        task_n: int = 1,
1599        include_task1: bool = False,
1600        treba_files: bool = False,
1601        annotation_path: Union[str, Set] = None,
1602        file_paths=None,
1603        *args,
1604        **kwargs,
1605    ) -> Iterable:
1606        """Get file ids.
1607
1608        Process data parameters and return a list of ids  of the videos that should
1609        be processed by the `__init__` function.
1610
1611        Parameters
1612        ----------
1613        task_n : {1, 2, 3}
1614            the index of the CalMS21 challenge task
1615        include_task1 : bool, default False
1616            if `True`, the training file of the task 1 will be loaded
1617        treba_files : bool, default False
1618            if `True`, the TREBA feature files will be loaded
1619        filenames : set, optional
1620            a set of string filenames to search for (only basenames, not the whole paths)
1621        annotation_path : str | set, optional
1622            the path to the folder where the pose and feature files are stored or a set of such paths
1623            (not passed if creating from key objects or from `file_paths`)
1624        file_paths : set, optional
1625            a set of string paths to the pose and feature files
1626            (not passed if creating from key objects or from `data_path`)
1627
1628        Returns
1629        -------
1630        video_ids : list
1631            a list of video file ids
1632
1633        """
1634        task_n = int(task_n)
1635        if task_n == 1:
1636            include_task1 = False
1637        files = []
1638        if treba_files:
1639            postfix = "_features"
1640        else:
1641            postfix = ""
1642        files.append(f"calms21_task{task_n}_train{postfix}.npy")
1643        files.append(f"calms21_task{task_n}_test{postfix}.npy")
1644        if include_task1:
1645            files.append(f"calms21_task1_train{postfix}.npy")
1646        filenames = set(files)
1647        return SequenceActionSegStore.get_file_ids(
1648            filenames, annotation_path=annotation_path
1649        )
1650
1651    def _open_sequences(self, filename: str) -> Dict:
1652        """Load the annotation from filename.
1653
1654        Parameters
1655        ----------
1656        filename : str
1657            path to an annotation file
1658
1659        Returns
1660        -------
1661        times : dict
1662            a nested dictionary where first-level keys are video ids, second-level keys are clip ids,
1663            third-level keys are categories and values are
1664            lists of (start, end, ambiguity status) lists
1665
1666        """
1667        data_dict = np.load(filename, allow_pickle=True).item()
1668        data = {}
1669        result = {}
1670        keys = list(data_dict.keys())
1671        if "test" in os.path.basename(filename):
1672            mode = "test"
1673        elif "unlabeled" in os.path.basename(filename):
1674            mode = "unlabeled"
1675        else:
1676            mode = "train"
1677        if "approach" in keys:
1678            for behavior in keys:
1679                for key in data_dict[behavior].keys():
1680                    ann = data_dict[behavior][key]["annotations"]
1681                    result[f'{mode}--{key.split("/")[-1]}'] = {
1682                        "mouse1+mouse2": defaultdict(lambda: [])
1683                    }
1684                    starts = np.where(
1685                        np.diff(np.concatenate([np.array([0]), ann, np.array([0])]))
1686                        == 1
1687                    )[0]
1688                    ends = np.where(
1689                        np.diff(np.concatenate([np.array([0]), ann, np.array([0])]))
1690                        == -1
1691                    )[0]
1692                    for start, end in zip(starts, ends):
1693                        result[f'{mode}--{key.split("/")[-1]}']["mouse1+mouse2"][
1694                            behavior
1695                        ].append([start, end, 0])
1696                    for b in self.behaviors:
1697                        result[f'{mode}--{key.split("/")[-1]}'][
1698                            "mouse1+mouse2"
1699                        ][f"unknown {b}"].append([0, len(ann), 0])
1700        for key in keys:
1701            data.update(data_dict[key])
1702            data_dict.pop(key)
1703        if "approach" not in keys and self.task_n == 3:
1704            for key in data.keys():
1705                result[f'{mode}--{key.split("/")[-1]}'] = {"mouse1+mouse2": {}}
1706                ann = data[key]["annotations"]
1707                for i in range(4):
1708                    starts = np.where(
1709                        np.diff(
1710                            np.concatenate(
1711                                [np.array([0]), (ann == i).astype(int), np.array([0])]
1712                            )
1713                        )
1714                        == 1
1715                    )[0]
1716                    ends = np.where(
1717                        np.diff(
1718                            np.concatenate(
1719                                [np.array([0]), (ann == i).astype(int), np.array([0])]
1720                            )
1721                        )
1722                        == -1
1723                    )[0]
1724                    result[f'{mode}--{key.split("/")[-1]}']["mouse1+mouse2"][
1725                        self.behaviors_dict()[i]
1726                    ] = [[start, end, 0] for start, end in zip(starts, ends)]
1727        if self.task_n != 3:
1728            for seq_name, seq_dict in data.items():
1729                if "annotations" not in seq_dict:
1730                    return None
1731                behaviors = np.unique(seq_dict["annotations"])
1732                ann = seq_dict["annotations"]
1733                key = f'{mode}--{seq_name.split("/")[-1]}'
1734                result[key] = {"mouse1+mouse2": {}}
1735                for i in behaviors:
1736                    starts = np.where(
1737                        np.diff(
1738                            np.concatenate(
1739                                [np.array([0]), (ann == i).astype(int), np.array([0])]
1740                            )
1741                        )
1742                        == 1
1743                    )[0]
1744                    ends = np.where(
1745                        np.diff(
1746                            np.concatenate(
1747                                [np.array([0]), (ann == i).astype(int), np.array([0])]
1748                            )
1749                        )
1750                        == -1
1751                    )[0]
1752                    result[key]["mouse1+mouse2"][self.behaviors_dict()[i]] = [
1753                        [start, end, 0] for start, end in zip(starts, ends)
1754                    ]
1755        return result

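The interval extraction in `_open_sequences` uses a standard idiom: pad a per-frame binary mask with zeros and locate rising and falling edges with `np.diff`. A small self-contained illustration with a made-up mask:

```
import numpy as np

ann = np.array([0, 0, 1, 1, 1, 0, 1, 0])  # toy per-frame mask for one behavior
padded = np.concatenate([np.array([0]), ann, np.array([0])])
starts = np.where(np.diff(padded) == 1)[0]   # array([2, 6])
ends = np.where(np.diff(padded) == -1)[0]    # array([5, 7])
print([[int(s), int(e), 0] for s, e in zip(starts, ends)])  # [[2, 5, 0], [6, 7, 0]]
```
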
CalMS21Store( task_n: int = 1, include_task1: bool = True, video_order: List = None, min_frames: Dict = None, max_frames: Dict = None, len_segment: int = 128, overlap: int = 0, ignored_classes: Set = None, annotation_path: Union[Set, str] = None, key_objects: Tuple = None, treba_files: bool = False, *args, **kwargs)
1509    def __init__(
1510        self,
1511        task_n: int = 1,
1512        include_task1: bool = True,
1513        video_order: List = None,
1514        min_frames: Dict = None,
1515        max_frames: Dict = None,
1516        len_segment: int = 128,
1517        overlap: int = 0,
1518        ignored_classes: Set = None,
1519        annotation_path: Union[Set, str] = None,
1520        key_objects: Tuple = None,
1521        treba_files: bool = False,
1522        *args,
1523        **kwargs,
1524    ) -> None:
1525        """Initialize the store.
1526
1527        Parameters
1528        ----------
1529        task_n : {1, 2, 3}
1530            the index of the CalMS21 challenge task
1531        include_task1 : bool, default True
1532            if `True`, the task 1 training data is added to the training set
1533        video_order : list, optional
1534            a list of video ids that should be processed in the same order (not passed if creating from key objects)
1535        min_frames : dict, optional
1536            a nested dictionary where first-level keys are video ids, second-level keys are clip ids and values are
1537            clip start frames (not passed if creating from key objects)
1538        max_frames : dict, optional
1539            a nested dictionary where first-level keys are video ids, second-level keys are clip ids and values are
1540            clip end frames (not passed if creating from key objects)
1541        len_segment : int, default 128
1542            the length of the segments in which the data should be cut (in frames)
1543        overlap : int, default 0
1544            the length of the overlap between neighboring segments (in frames)
1545        ignored_classes : set, optional
1546            the list of behaviors from the behaviors list or file to not annotate
1547        annotation_path : str | set, optional
1548            the path or the set of paths to the folder where the annotation files are stored (not passed if creating
1549            from key objects)
1550        key_objects : tuple, optional
1551            the key objects to load the BehaviorStore from
1552        treba_files : bool, default False
1553            if `True`, TREBA feature files will be loaded
1554
1555        """
1556        self.task_n = int(task_n)
1557        self.include_task1 = include_task1
1558        if self.task_n == 1:
1559            self.include_task1 = False
1560        self.treba_files = treba_files
1561        if "exclusive" in kwargs:
1562            exclusive = kwargs["exclusive"]
1563        else:
1564            exclusive = True
1565        if "behaviors" in kwargs and kwargs["behaviors"] is not None:
1566            behaviors = kwargs["behaviors"]
1567        else:
1568            behaviors = ["attack", "investigation", "mount", "other"]
1569            if task_n == 3:
1570                exclusive = False
1571                behaviors += [
1572                    "approach",
1573                    "disengaged",
1574                    "groom",
1575                    "intromission",
1576                    "mount_attempt",
1577                    "sniff_face",
1578                    "whiterearing",
1579                ]
1580        super().__init__(
1581            video_order=video_order,
1582            min_frames=min_frames,
1583            max_frames=max_frames,
1584            exclusive=exclusive,
1585            len_segment=len_segment,
1586            overlap=overlap,
1587            behaviors=behaviors,
1588            ignored_classes=ignored_classes,
1589            annotation_path=annotation_path,
1590            key_objects=key_objects,
1591            filter_annotated=False,
1592            interactive=True,
1593        )

@classmethod
def get_file_ids( cls, task_n: int = 1, include_task1: bool = False, treba_files: bool = False, annotation_path: Union[str, Set] = None, file_paths=None, *args, **kwargs) -> Iterable:
1595    @classmethod
1596    def get_file_ids(
1597        cls,
1598        task_n: int = 1,
1599        include_task1: bool = False,
1600        treba_files: bool = False,
1601        annotation_path: Union[str, Set] = None,
1602        file_paths=None,
1603        *args,
1604        **kwargs,
1605    ) -> Iterable:
1606        """Get file ids.
1607
1608        Process data parameters and return a list of ids  of the videos that should
1609        be processed by the `__init__` function.
1610
1611        Parameters
1612        ----------
1613        task_n : {1, 2, 3}
1614            the index of the CalMS21 challenge task
1615        include_task1 : bool, default False
1616            if `True`, the training file of the task 1 will be loaded
1617        treba_files : bool, default False
1618            if `True`, the TREBA feature files will be loaded
1619        filenames : set, optional
1620            a set of string filenames to search for (only basenames, not the whole paths)
1621        annotation_path : str | set, optional
1622            the path to the folder where the pose and feature files are stored or a set of such paths
1623            (not passed if creating from key objects or from `file_paths`)
1624        file_paths : set, optional
1625            a set of string paths to the pose and feature files
1626            (not passed if creating from key objects or from `data_path`)
1627
1628        Returns
1629        -------
1630        video_ids : list
1631            a list of video file ids
1632
1633        """
1634        task_n = int(task_n)
1635        if task_n == 1:
1636            include_task1 = False
1637        files = []
1638        if treba_files:
1639            postfix = "_features"
1640        else:
1641            postfix = ""
1642        files.append(f"calms21_task{task_n}_train{postfix}.npy")
1643        files.append(f"calms21_task{task_n}_test{postfix}.npy")
1644        if include_task1:
1645            files.append(f"calms21_task1_train{postfix}.npy")
1646        filenames = set(files)
1647        return SequenceActionSegStore.get_file_ids(
1648            filenames, annotation_path=annotation_path
1649        )

Get file ids.

Process data parameters and return a list of ids of the videos that should be processed by the __init__ function.

Parameters

task_n : {1, 2, 3} the index of the CalMS21 challenge task include_task1 : bool, default False if True, the training file of the task 1 will be loaded treba_files : bool, default False if True, the TREBA feature files will be loaded filenames : set, optional a set of string filenames to search for (only basenames, not the whole paths) annotation_path : str | set, optional the path to the folder where the pose and feature files are stored or a set of such paths (not passed if creating from key objects or from file_paths) file_paths : set, optional a set of string paths to the pose and feature files (not passed if creating from key objects or from data_path)

Returns

video_ids : list
    a list of video file ids
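As the source above shows, the ids are derived from fixed CalMS21 file names: `calms21_task{task_n}_train{postfix}.npy`, `calms21_task{task_n}_test{postfix}.npy`, and optionally the task 1 training file, where `postfix` is `"_features"` for TREBA files and empty otherwise. A small self-contained sketch of that naming logic (the helper function here is not part of the library):

```python
def calms21_filenames(task_n: int = 1, include_task1: bool = False,
                      treba_files: bool = False) -> set:
    """Reproduce the file names that get_file_ids searches for (sketch only)."""
    task_n = int(task_n)
    if task_n == 1:
        include_task1 = False  # the task 1 training file is already included
    postfix = "_features" if treba_files else ""
    files = [
        f"calms21_task{task_n}_train{postfix}.npy",
        f"calms21_task{task_n}_test{postfix}.npy",
    ]
    if include_task1:
        files.append(f"calms21_task1_train{postfix}.npy")
    return set(files)


print(calms21_filenames(task_n=2, include_task1=True, treba_files=True))
# (set order may vary)
# {'calms21_task2_train_features.npy', 'calms21_task2_test_features.npy',
#  'calms21_task1_train_features.npy'}
```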

class CSVActionSegStore(FileActionSegStore):
1758class CSVActionSegStore(FileActionSegStore):  # +
1759    """CSV type annotation data.
1760
1761    Assumes that files are saved as .csv tables with at least the following columns:
1762    - from / start : start of action,
1763    - to / end : end of action,
1764    - class / behavior / behaviour / label / type : action label.
1765
1766    If the times are set in seconds instead of frames, don't forget to set the `fps` parameter to your frame rate.
1767
1768    Assumes the following file structure:
1769    ```
1770    annotation_path
1771    ├── video1_annotation.csv
1772    └── video2_labels.csv
1773    ```
1774    Here `annotation_suffix` is `{'_annotation.csv', '_labels.csv'}`.
1775    """
1776
1777    def __init__(
1778        self,
1779        video_order: List = None,
1780        min_frames: Dict = None,
1781        max_frames: Dict = None,
1782        visibility: Dict = None,
1783        exclusive: bool = True,
1784        len_segment: int = 128,
1785        overlap: int = 0,
1786        behaviors: Set = None,
1787        ignored_classes: Set = None,
1788        annotation_suffix: Union[Set, str] = None,
1789        annotation_path: Union[Set, str] = None,
1790        behavior_file: str = None,
1791        correction: Dict = None,
1792        frame_limit: int = 0,
1793        filter_annotated: bool = False,
1794        filter_background: bool = False,
1795        error_class: str = None,
1796        min_frames_action: int = None,
1797        key_objects: Tuple = None,
1798        visibility_min_score: float = 0.2,
1799        visibility_min_frac: float = 0.7,
1800        mask: Dict = None,
1801        default_agent_name: str = "ind0",
1802        separator: str = ",",
1803        fps: int = 30,
1804        *args,
1805        **kwargs,
1806    ) -> None:
1807        """Initialize the store.
1808
1809        Parameters
1810        ----------
1811        video_order : list, optional
1812            a list of video ids that should be processed in the same order (not passed if creating from key objects)
1813        min_frames : dict, optional
1814            a nested dictionary where first-level keys are video ids, second-level keys are clip ids and values are
1815            clip start frames (not passed if creating from key objects)
1816        max_frames : dict, optional
1817            a nested dictionary where first-level keys are video ids, second-level keys are clip ids and values are
1818            clip end frames (not passed if creating from key objects)
1819        visibility : dict, optional
1820            a nested dictionary where first-level keys are video ids, second-level keys are clip ids and values are
1821            visibility score arrays (not passed if creating from key objects or if irrelevant for the dataset)
1822        exclusive : bool, default True
1823            if True, the annotation is single-label; if False, multi-label
1824        len_segment : int, default 128
1825            the length of the segments in which the data should be cut (in frames)
1826        overlap : int, default 0
1827            the length of the overlap between neighboring segments (in frames)
1828        behaviors : set, optional
1829            the list of behaviors to put in the annotation (not passed if creating a blank instance or if behaviors are
1830            loaded from a file)
1831        ignored_classes : set, optional
1832            the list of behaviors from the behaviors list or file to not annotate
1833        annotation_suffix : str | set, optional
1834            the suffix or the set of suffices such that the annotation files are named {video_id}{annotation_suffix}
1835            (not passed if creating from key objects or if irrelevant for the dataset)
1836        annotation_path : str | set, optional
1837            the path or the set of paths to the folder where the annotation files are stored (not passed if creating
1838            from key objects)
1839        behavior_file : str, optional
1840            the path to an .xlsx behavior file (not passed if creating from key objects or if irrelevant for the dataset)
1841        correction : dict, optional
1842            a dictionary of corrections for the labels (e.g. {'sleping': 'sleeping', 'calm locomotion': 'locomotion'},
1843            can be used to correct for variations in naming or to merge several labels into one)
1844        frame_limit : int, default 0
1845            the smallest possible length of a clip (shorter clips are discarded)
1846        filter_annotated : bool, default False
1847            if True, the samples that do not have any labels will be filtered
1848        filter_background : bool, default False
1849            if True, only the unlabeled frames that are close to annotated frames will be labeled as background
1850        error_class : str, optional
1851            the name of the error class (the annotations that intersect with this label will be discarded)
1852        min_frames_action : int, default 0
1853            the minimum length of an action (shorter actions are not annotated)
1854        key_objects : tuple, optional
1855            the key objects to load the BehaviorStore from
1856        visibility_min_score : float, default 0.2
1857            the minimum visibility score for visibility filtering
1858        visibility_min_frac : float, default 0.7
1859            the minimum fraction of visible frames for visibility filtering
1860        mask : dict, optional
1861            a masked value dictionary (for active learning simulation experiments)
1862        default_agent_name : str, default "ind0"
1863            the clip id to use when none is given in the file
1864        separator : str, default ","
1865            the separator in the csv files
1866        fps : int, default 30
1867            frames per second in the videos
1868
1869        """
1870        self.default_agent_name = default_agent_name
1871        self.separator = separator
1872        self.fps = fps
1873        super().__init__(
1874            video_order=video_order,
1875            min_frames=min_frames,
1876            max_frames=max_frames,
1877            visibility=visibility,
1878            exclusive=exclusive,
1879            len_segment=len_segment,
1880            overlap=overlap,
1881            behaviors=behaviors,
1882            ignored_classes=ignored_classes,
1883            ignored_clips=None,
1884            annotation_suffix=annotation_suffix,
1885            annotation_path=annotation_path,
1886            behavior_file=behavior_file,
1887            correction=correction,
1888            frame_limit=frame_limit,
1889            filter_annotated=filter_annotated,
1890            filter_background=filter_background,
1891            error_class=error_class,
1892            min_frames_action=min_frames_action,
1893            key_objects=key_objects,
1894            visibility_min_score=visibility_min_score,
1895            visibility_min_frac=visibility_min_frac,
1896            mask=mask,
1897        )
1898
1899    def _open_annotations(self, filename: str) -> Dict:
1900        """Load the annotation from `filename`."""
1901        data = pd.read_csv(filename, sep=self.separator)
1902        data.columns = list(map(lambda x: x.lower(), data.columns))
1903        starts, ends, actions = None, None, None
1904        start_names = ["from", "start"]
1905        for x in start_names:
1906            if x in data.columns:
1907                starts = data[x]
1908        end_names = ["to", "end"]
1909        for x in end_names:
1910            if x in data.columns:
1911                ends = data[x]
1912        class_names = ["class", "behavior", "behaviour", "type", "label"]
1913        for x in class_names:
1914            if x in data.columns:
1915                actions = data[x]
1916        if starts is None:
1917            raise ValueError("The file must have a column titled 'from' or 'start'!")
1918        if ends is None:
1919            raise ValueError("The file must have a column titled 'to' or 'end'!")
1920        if actions is None:
1921            raise ValueError(
1922                "The file must have a column titled 'class', 'behavior', 'behaviour', 'type' or 'label'!"
1923            )
1924        times = defaultdict(lambda: defaultdict(lambda: []))
1925        for start, end, action in zip(starts, ends, actions):
1926            if any([np.isnan(x) for x in [start, end]]):
1927                continue
1928            times[self.default_agent_name][action].append(
1929                [int(start * self.fps), int(end * self.fps), 0]
1930            )
1931        return times
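For reference, the parser above converts the start and end columns from seconds to frames (multiplying by `self.fps`) and groups the intervals under `default_agent_name`. A hedged sketch of the structure it would return for a two-row file read with `fps=30` and `default_agent_name="ind0"` (the file contents are hypothetical):

```python
# Hypothetical video1_annotation.csv (times in seconds):
#   start,end,label
#   0.5,2.0,walking
#   3.0,4.5,grooming
#
# Approximate result of _open_annotations on that file:
times = {
    "ind0": {
        "walking": [[15, 60, 0]],    # int(0.5 * 30), int(2.0 * 30), ambiguity flag 0
        "grooming": [[90, 135, 0]],  # int(3.0 * 30), int(4.5 * 30), ambiguity flag 0
    }
}
```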

CSV type annotation data.

Assumes that files are saved as .csv tables with at least the following columns:

  • from / start : start of action,
  • to / end : end of action,
  • class / behavior / behaviour / label / type : action label.

If the times are set in seconds instead of frames, don't forget to set the fps parameter to your frame rate.

Assumes the following file structure:

annotation_path
├── video1_annotation.csv
└── video2_labels.csv

Here annotation_suffix is {'_annotation.csv', '_labels.csv'}.
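A minimal sketch of writing an annotation file that matches this layout (it uses the `start`/`end`/`label` column variants; the path and the actions are illustrative):

```python
import pandas as pd

# Three toy actions, with times in seconds (set fps accordingly when loading).
annotation = pd.DataFrame(
    {
        "start": [0.5, 3.0, 10.2],
        "end": [2.0, 4.5, 12.0],
        "label": ["walking", "grooming", "walking"],
    }
)
annotation.to_csv("annotation_path/video1_annotation.csv", index=False)
```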

CSVActionSegStore( video_order: List = None, min_frames: Dict = None, max_frames: Dict = None, visibility: Dict = None, exclusive: bool = True, len_segment: int = 128, overlap: int = 0, behaviors: Set = None, ignored_classes: Set = None, annotation_suffix: Union[Set, str] = None, annotation_path: Union[Set, str] = None, behavior_file: str = None, correction: Dict = None, frame_limit: int = 0, filter_annotated: bool = False, filter_background: bool = False, error_class: str = None, min_frames_action: int = None, key_objects: Tuple = None, visibility_min_score: float = 0.2, visibility_min_frac: float = 0.7, mask: Dict = None, default_agent_name: str = 'ind0', separator: str = ',', fps: int = 30, *args, **kwargs)
1777    def __init__(
1778        self,
1779        video_order: List = None,
1780        min_frames: Dict = None,
1781        max_frames: Dict = None,
1782        visibility: Dict = None,
1783        exclusive: bool = True,
1784        len_segment: int = 128,
1785        overlap: int = 0,
1786        behaviors: Set = None,
1787        ignored_classes: Set = None,
1788        annotation_suffix: Union[Set, str] = None,
1789        annotation_path: Union[Set, str] = None,
1790        behavior_file: str = None,
1791        correction: Dict = None,
1792        frame_limit: int = 0,
1793        filter_annotated: bool = False,
1794        filter_background: bool = False,
1795        error_class: str = None,
1796        min_frames_action: int = None,
1797        key_objects: Tuple = None,
1798        visibility_min_score: float = 0.2,
1799        visibility_min_frac: float = 0.7,
1800        mask: Dict = None,
1801        default_agent_name: str = "ind0",
1802        separator: str = ",",
1803        fps: int = 30,
1804        *args,
1805        **kwargs,
1806    ) -> None:
1807        """Initialize the store.
1808
1809        Parameters
1810        ----------
1811        video_order : list, optional
1812            a list of video ids that should be processed in the same order (not passed if creating from key objects)
1813        min_frames : dict, optional
1814            a nested dictionary where first-level keys are video ids, second-level keys are clip ids and values are
1815            clip start frames (not passed if creating from key objects)
1816        max_frames : dict, optional
1817            a nested dictionary where first-level keys are video ids, second-level keys are clip ids and values are
1818            clip end frames (not passed if creating from key objects)
1819        visibility : dict, optional
1820            a nested dictionary where first-level keys are video ids, second-level keys are clip ids and values are
1821            visibility score arrays (not passed if creating from key objects or if irrelevant for the dataset)
1822        exclusive : bool, default True
1823            if True, the annotation is single-label; if False, multi-label
1824        len_segment : int, default 128
1825            the length of the segments in which the data should be cut (in frames)
1826        overlap : int, default 0
1827            the length of the overlap between neighboring segments (in frames)
1828        behaviors : set, optional
1829            the list of behaviors to put in the annotation (not passed if creating a blank instance or if behaviors are
1830            loaded from a file)
1831        ignored_classes : set, optional
1832            the list of behaviors from the behaviors list or file to not annotate
1833        annotation_suffix : str | set, optional
1834            the suffix or the set of suffices such that the annotation files are named {video_id}{annotation_suffix}
1835            (not passed if creating from key objects or if irrelevant for the dataset)
1836        annotation_path : str | set, optional
1837            the path or the set of paths to the folder where the annotation files are stored (not passed if creating
1838            from key objects)
1839        behavior_file : str, optional
1840            the path to an .xlsx behavior file (not passed if creating from key objects or if irrelevant for the dataset)
1841        correction : dict, optional
1842            a dictionary of corrections for the labels (e.g. {'sleping': 'sleeping', 'calm locomotion': 'locomotion'},
1843            can be used to correct for variations in naming or to merge several labels into one)
1844        frame_limit : int, default 0
1845            the smallest possible length of a clip (shorter clips are discarded)
1846        filter_annotated : bool, default False
1847            if True, the samples that do not have any labels will be filtered
1848        filter_background : bool, default False
1849            if True, only the unlabeled frames that are close to annotated frames will be labeled as background
1850        error_class : str, optional
1851            the name of the error class (the annotations that intersect with this label will be discarded)
1852        min_frames_action : int, default 0
1853            the minimum length of an action (shorter actions are not annotated)
1854        key_objects : tuple, optional
1855            the key objects to load the BehaviorStore from
1856        visibility_min_score : float, default 0.2
1857            the minimum visibility score for visibility filtering
1858        visibility_min_frac : float, default 0.7
1859            the minimum fraction of visible frames for visibility filtering
1860        mask : dict, optional
1861            a masked value dictionary (for active learning simulation experiments)
1862        default_agent_name : str, default "ind0"
1863            the clip id to use when none is given in the file
1864        separator : str, default ","
1865            the separator in the csv files
1866        fps : int, default 30
1867            frames per second in the videos
1868
1869        """
1870        self.default_agent_name = default_agent_name
1871        self.separator = separator
1872        self.fps = fps
1873        super().__init__(
1874            video_order=video_order,
1875            min_frames=min_frames,
1876            max_frames=max_frames,
1877            visibility=visibility,
1878            exclusive=exclusive,
1879            len_segment=len_segment,
1880            overlap=overlap,
1881            behaviors=behaviors,
1882            ignored_classes=ignored_classes,
1883            ignored_clips=None,
1884            annotation_suffix=annotation_suffix,
1885            annotation_path=annotation_path,
1886            behavior_file=behavior_file,
1887            correction=correction,
1888            frame_limit=frame_limit,
1889            filter_annotated=filter_annotated,
1890            filter_background=filter_background,
1891            error_class=error_class,
1892            min_frames_action=min_frames_action,
1893            key_objects=key_objects,
1894            visibility_min_score=visibility_min_score,
1895            visibility_min_frac=visibility_min_frac,
1896            mask=mask,
1897        )

Initialize the store.

Parameters

video_order : list, optional
    a list of video ids that should be processed in the same order (not passed if creating from key objects)
min_frames : dict, optional
    a nested dictionary where first-level keys are video ids, second-level keys are clip ids and values are clip start frames (not passed if creating from key objects)
max_frames : dict, optional
    a nested dictionary where first-level keys are video ids, second-level keys are clip ids and values are clip end frames (not passed if creating from key objects)
visibility : dict, optional
    a nested dictionary where first-level keys are video ids, second-level keys are clip ids and values are visibility score arrays (not passed if creating from key objects or if irrelevant for the dataset)
exclusive : bool, default True
    if True, the annotation is single-label; if False, multi-label
len_segment : int, default 128
    the length of the segments in which the data should be cut (in frames)
overlap : int, default 0
    the length of the overlap between neighboring segments (in frames)
behaviors : set, optional
    the list of behaviors to put in the annotation (not passed if creating a blank instance or if behaviors are loaded from a file)
ignored_classes : set, optional
    the list of behaviors from the behaviors list or file to not annotate
annotation_suffix : str | set, optional
    the suffix or the set of suffixes such that the annotation files are named {video_id}{annotation_suffix} (not passed if creating from key objects or if irrelevant for the dataset)
annotation_path : str | set, optional
    the path or the set of paths to the folder where the annotation files are stored (not passed if creating from key objects)
behavior_file : str, optional
    the path to an .xlsx behavior file (not passed if creating from key objects or if irrelevant for the dataset)
correction : dict, optional
    a dictionary of corrections for the labels (e.g. {'sleping': 'sleeping', 'calm locomotion': 'locomotion'}, can be used to correct for variations in naming or to merge several labels into one)
frame_limit : int, default 0
    the smallest possible length of a clip (shorter clips are discarded)
filter_annotated : bool, default False
    if True, the samples that do not have any labels will be filtered
filter_background : bool, default False
    if True, only the unlabeled frames that are close to annotated frames will be labeled as background
error_class : str, optional
    the name of the error class (the annotations that intersect with this label will be discarded)
min_frames_action : int, default 0
    the minimum length of an action (shorter actions are not annotated)
key_objects : tuple, optional
    the key objects to load the BehaviorStore from
visibility_min_score : float, default 0.2
    the minimum visibility score for visibility filtering
visibility_min_frac : float, default 0.7
    the minimum fraction of visible frames for visibility filtering
mask : dict, optional
    a masked value dictionary (for active learning simulation experiments)
default_agent_name : str, default "ind0"
    the clip id to use when none is given in the file
separator : str, default ","
    the separator in the csv files
fps : int, default 30
    frames per second in the videos

default_agent_name
separator
fps
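A hedged construction sketch for this class, assuming a folder of annotation files like the one shown earlier; the argument values are illustrative, and in a full pipeline most of them are normally filled in by the dlc2action dataset machinery rather than by hand:

```python
from dlc2action.data.annotation_store import CSVActionSegStore

store = CSVActionSegStore(
    video_order=["video1", "video2"],
    min_frames={"video1": {"ind0": 0}, "video2": {"ind0": 0}},
    max_frames={"video1": {"ind0": 1800}, "video2": {"ind0": 900}},
    annotation_path="annotation_path",
    annotation_suffix={"_annotation.csv", "_labels.csv"},
    behaviors={"walking", "grooming"},
    exclusive=True,
    len_segment=128,
    overlap=0,
    fps=30,
)
print(len(store))  # number of generated samples
```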
class SIMBAStore(FileActionSegStore):
1934class SIMBAStore(FileActionSegStore):  # +
1935    """SIMBA paper format data.
1936
1937    Assumes the following file structure:
1938    ```
1939    annotation_path
1940    ├── Video1.csv
1941    ...
1942    └── Video9.csv
1943    """
1944
1945    def __init__(
1946        self,
1947        video_order: List = None,
1948        min_frames: Dict = None,
1949        max_frames: Dict = None,
1950        visibility: Dict = None,
1951        exclusive: bool = True,
1952        len_segment: int = 128,
1953        overlap: int = 0,
1954        behaviors: Set = None,
1955        ignored_classes: Set = None,
1956        ignored_clips: Set = None,
1957        annotation_path: Union[Set, str] = None,
1958        correction: Dict = None,
1959        filter_annotated: bool = False,
1960        filter_background: bool = False,
1961        error_class: str = None,
1962        min_frames_action: int = None,
1963        key_objects: Tuple = None,
1964        visibility_min_score: float = 0.2,
1965        visibility_min_frac: float = 0.7,
1966        mask: Dict = None,
1967        use_hard_negatives: bool = False,
1968        annotation_suffix: str = None,
1969        *args,
1970        **kwargs,
1971    ) -> None:
1972        """Initialize the annotation store.
1973
1974        Parameters
1975        ----------
1976        video_order : list, optional
1977            a list of video ids that should be processed in the same order (not passed if creating from key objects)
1978        min_frames : dict, optional
1979            a nested dictionary where first-level keys are video ids, second-level keys are clip ids and values are
1980            clip start frames (not passed if creating from key objects)
1981        max_frames : dict, optional
1982            a nested dictionary where first-level keys are video ids, second-level keys are clip ids and values are
1983            clip end frames (not passed if creating from key objects)
1984        visibility : dict, optional
1985            a nested dictionary where first-level keys are video ids, second-level keys are clip ids and values are
1986            visibility score arrays (not passed if creating from key objects or if irrelevant for the dataset)
1987        exclusive : bool, default True
1988            if True, the annotation is single-label; if False, multi-label
1989        len_segment : int, default 128
1990            the length of the segments in which the data should be cut (in frames)
1991        overlap : int, default 0
1992            the length of the overlap between neighboring segments (in frames)
1993        behaviors : set, optional
1994            the list of behaviors to put in the annotation (not passed if creating a blank instance or if behaviors are
1995            loaded from a file)
1996        ignored_classes : set, optional
1997            the list of behaviors from the behaviors list or file to not annotate
1998        ignored_clips : set, optional
1999            clip ids to ignore
2000        annotation_path : str | set, optional
2001            the path or the set of paths to the folder where the annotation files are stored (not passed if creating
2002            from key objects)
2003        behavior_file : str, optional
2004            the path to an .xlsx behavior file (not passed if creating from key objects or if irrelevant for the dataset)
2005        correction : dict, optional
2006            a dictionary of corrections for the labels (e.g. {'sleping': 'sleeping', 'calm locomotion': 'locomotion'},
2007            can be used to correct for variations in naming or to merge several labels into one)
2008        filter_annotated : bool, default False
2009            if True, the samples that do not have any labels will be filtered
2010        filter_background : bool, default False
2011            if True, only the unlabeled frames that are close to annotated frames will be labeled as background
2012        error_class : str, optional
2013            the name of the error class (the annotations that intersect with this label will be discarded)
2014        min_frames_action : int, default 0
2015            the minimum length of an action (shorter actions are not annotated)
2016        key_objects : tuple, optional
2017            the key objects to load the BehaviorStore from
2018        visibility_min_score : float, default 0.2
2019            the minimum visibility score for visibility filtering
2020        visibility_min_frac : float, default 0.7
2021            the minimum fraction of visible frames for visibility filtering
2022        mask : dict, optional
2023            a masked value dictionary (for active learning simulation experiments)
2024        use_hard_negatives : bool, default False
2025            mark hard negatives as 2 instead of 0 or 1, for loss functions that have options for hard negative processing
2026        annotation_suffix : str | set, optional
2027            the suffix or the set of suffices such that the annotation files are named {video_id}{annotation_suffix}
2028            (not passed if creating from key objects or if irrelevant for the dataset)
2029
2030        """
2031        super().__init__(
2032            video_order=video_order,
2033            min_frames=min_frames,
2034            max_frames=max_frames,
2035            visibility=visibility,
2036            exclusive=exclusive,
2037            len_segment=len_segment,
2038            overlap=overlap,
2039            behaviors=behaviors,
2040            ignored_classes=ignored_classes,
2041            ignored_clips=ignored_clips,
2042            annotation_suffix=annotation_suffix,
2043            annotation_path=annotation_path,
2044            behavior_file=None,
2045            correction=correction,
2046            frame_limit=0,
2047            filter_annotated=filter_annotated,
2048            filter_background=filter_background,
2049            error_class=error_class,
2050            min_frames_action=min_frames_action,
2051            key_objects=key_objects,
2052            visibility_min_score=visibility_min_score,
2053            visibility_min_frac=visibility_min_frac,
2054            mask=mask,
2055            use_hard_negatives=use_hard_negatives,
2056            interactive=True,
2057        )
2058
2059    def _open_annotations(self, filename: str) -> Dict:
2060        """Load the annotation from filename.
2061
2062        Parameters
2063        ----------
2064        filename : str
2065            path to an annotation file
2066
2067        Returns
2068        -------
2069        times : dict
2070            a nested dictionary where first-level keys are clip ids, second-level keys are categories and values are
2071            lists of (start, end, ambiguity status) lists
2072
2073        """
2074        data = pd.read_csv(filename)
2075        columns = [x for x in data.columns if x.split("_")[-1] == "x"]
2076        animals = sorted(set([x.split("_")[-2] for x in columns]))
2077        if len(animals) > 2:
2078            raise ValueError(
2079                "SIMBAStore is only implemented for files with 1 or 2 animals!"
2080            )
2081        if len(animals) == 1:
2082            ind = animals[0]
2083        else:
2084            ind = "+".join(animals)
2085        behaviors = [
2086            "_".join(x.split("_")[:-1])
2087            for x in data.columns
2088            if x.split("_")[-1] == "prediction"
2089        ]
2090        result = {}
2091        for behavior in behaviors:
2092            ann = data[f"{behavior}_prediction"].values
2093            diff = np.diff(
2094                np.concatenate([np.array([0]), (ann == 1).astype(int), np.array([0])])
2095            )
2096            starts = np.where(diff == 1)[0]
2097            ends = np.where(diff == -1)[0]
2098            result[behavior] = [[start, end, 0] for start, end in zip(starts, ends)]
2099            diff = np.diff(
2100                np.concatenate(
2101                    [np.array([0]), (np.isnan(ann)).astype(int), np.array([0])]
2102                )
2103            )
2104            starts = np.where(diff == 1)[0]
2105            ends = np.where(diff == -1)[0]
2106            result[f"unknown {behavior}"] = [
2107                [start, end, 0] for start, end in zip(starts, ends)
2108            ]
2109        if self.behaviors is not None:
2110            for behavior in self.behaviors:
2111                if behavior not in behaviors:
2112                    result[f"unknown {behavior}"] = [[0, len(data), 0]]
2113        return {ind: result}
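The interval extraction above relies on `np.diff` over a zero-padded 0/1 array to find where each behavior run starts and ends. A small standalone sketch of that step:

```python
import numpy as np

# Toy per-frame predictions for one behavior (1 = behavior present).
ann = np.array([0, 1, 1, 1, 0, 0, 1, 1, 0])

# Zero-padding makes runs that touch either edge detectable as well.
diff = np.diff(np.concatenate([[0], (ann == 1).astype(int), [0]]))
starts = np.where(diff == 1)[0]  # first frame of each run
ends = np.where(diff == -1)[0]   # one past the last frame of each run
intervals = [[int(s), int(e), 0] for s, e in zip(starts, ends)]
print(intervals)  # [[1, 4, 0], [6, 8, 0]]
```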

SIMBA paper format data.

Assumes the following file structure:

annotation_path
├── Video1.csv
...
└── Video9.csv

SIMBAStore( video_order: List = None, min_frames: Dict = None, max_frames: Dict = None, visibility: Dict = None, exclusive: bool = True, len_segment: int = 128, overlap: int = 0, behaviors: Set = None, ignored_classes: Set = None, ignored_clips: Set = None, annotation_path: Union[Set, str] = None, correction: Dict = None, filter_annotated: bool = False, filter_background: bool = False, error_class: str = None, min_frames_action: int = None, key_objects: Tuple = None, visibility_min_score: float = 0.2, visibility_min_frac: float = 0.7, mask: Dict = None, use_hard_negatives: bool = False, annotation_suffix: str = None, *args, **kwargs)
1945    def __init__(
1946        self,
1947        video_order: List = None,
1948        min_frames: Dict = None,
1949        max_frames: Dict = None,
1950        visibility: Dict = None,
1951        exclusive: bool = True,
1952        len_segment: int = 128,
1953        overlap: int = 0,
1954        behaviors: Set = None,
1955        ignored_classes: Set = None,
1956        ignored_clips: Set = None,
1957        annotation_path: Union[Set, str] = None,
1958        correction: Dict = None,
1959        filter_annotated: bool = False,
1960        filter_background: bool = False,
1961        error_class: str = None,
1962        min_frames_action: int = None,
1963        key_objects: Tuple = None,
1964        visibility_min_score: float = 0.2,
1965        visibility_min_frac: float = 0.7,
1966        mask: Dict = None,
1967        use_hard_negatives: bool = False,
1968        annotation_suffix: str = None,
1969        *args,
1970        **kwargs,
1971    ) -> None:
1972        """Initialize the annotation store.
1973
1974        Parameters
1975        ----------
1976        video_order : list, optional
1977            a list of video ids that should be processed in the same order (not passed if creating from key objects)
1978        min_frames : dict, optional
1979            a nested dictionary where first-level keys are video ids, second-level keys are clip ids and values are
1980            clip start frames (not passed if creating from key objects)
1981        max_frames : dict, optional
1982            a nested dictionary where first-level keys are video ids, second-level keys are clip ids and values are
1983            clip end frames (not passed if creating from key objects)
1984        visibility : dict, optional
1985            a nested dictionary where first-level keys are video ids, second-level keys are clip ids and values are
1986            visibility score arrays (not passed if creating from key objects or if irrelevant for the dataset)
1987        exclusive : bool, default True
1988            if True, the annotation is single-label; if False, multi-label
1989        len_segment : int, default 128
1990            the length of the segments in which the data should be cut (in frames)
1991        overlap : int, default 0
1992            the length of the overlap between neighboring segments (in frames)
1993        behaviors : set, optional
1994            the list of behaviors to put in the annotation (not passed if creating a blank instance or if behaviors are
1995            loaded from a file)
1996        ignored_classes : set, optional
1997            the list of behaviors from the behaviors list or file to not annotate
1998        ignored_clips : set, optional
1999            clip ids to ignore
2000        annotation_path : str | set, optional
2001            the path or the set of paths to the folder where the annotation files are stored (not passed if creating
2002            from key objects)
2003        behavior_file : str, optional
2004            the path to an .xlsx behavior file (not passed if creating from key objects or if irrelevant for the dataset)
2005        correction : dict, optional
2006            a dictionary of corrections for the labels (e.g. {'sleping': 'sleeping', 'calm locomotion': 'locomotion'},
2007            can be used to correct for variations in naming or to merge several labels into one)
2008        filter_annotated : bool, default False
2009            if True, the samples that do not have any labels will be filtered
2010        filter_background : bool, default False
2011            if True, only the unlabeled frames that are close to annotated frames will be labeled as background
2012        error_class : str, optional
2013            the name of the error class (the annotations that intersect with this label will be discarded)
2014        min_frames_action : int, default 0
2015            the minimum length of an action (shorter actions are not annotated)
2016        key_objects : tuple, optional
2017            the key objects to load the BehaviorStore from
2018        visibility_min_score : float, default 0.2
2019            the minimum visibility score for visibility filtering
2020        visibility_min_frac : float, default 0.7
2021            the minimum fraction of visible frames for visibility filtering
2022        mask : dict, optional
2023            a masked value dictionary (for active learning simulation experiments)
2024        use_hard_negatives : bool, default False
2025            mark hard negatives as 2 instead of 0 or 1, for loss functions that have options for hard negative processing
2026        annotation_suffix : str | set, optional
2027            the suffix or the set of suffices such that the annotation files are named {video_id}{annotation_suffix}
2028            (not passed if creating from key objects or if irrelevant for the dataset)
2029
2030        """
2031        super().__init__(
2032            video_order=video_order,
2033            min_frames=min_frames,
2034            max_frames=max_frames,
2035            visibility=visibility,
2036            exclusive=exclusive,
2037            len_segment=len_segment,
2038            overlap=overlap,
2039            behaviors=behaviors,
2040            ignored_classes=ignored_classes,
2041            ignored_clips=ignored_clips,
2042            annotation_suffix=annotation_suffix,
2043            annotation_path=annotation_path,
2044            behavior_file=None,
2045            correction=correction,
2046            frame_limit=0,
2047            filter_annotated=filter_annotated,
2048            filter_background=filter_background,
2049            error_class=error_class,
2050            min_frames_action=min_frames_action,
2051            key_objects=key_objects,
2052            visibility_min_score=visibility_min_score,
2053            visibility_min_frac=visibility_min_frac,
2054            mask=mask,
2055            use_hard_negatives=use_hard_negatives,
2056            interactive=True,
2057        )

Initialize the annotation store.

Parameters

video_order : list, optional
    a list of video ids that should be processed in the same order (not passed if creating from key objects)
min_frames : dict, optional
    a nested dictionary where first-level keys are video ids, second-level keys are clip ids and values are clip start frames (not passed if creating from key objects)
max_frames : dict, optional
    a nested dictionary where first-level keys are video ids, second-level keys are clip ids and values are clip end frames (not passed if creating from key objects)
visibility : dict, optional
    a nested dictionary where first-level keys are video ids, second-level keys are clip ids and values are visibility score arrays (not passed if creating from key objects or if irrelevant for the dataset)
exclusive : bool, default True
    if True, the annotation is single-label; if False, multi-label
len_segment : int, default 128
    the length of the segments in which the data should be cut (in frames)
overlap : int, default 0
    the length of the overlap between neighboring segments (in frames)
behaviors : set, optional
    the list of behaviors to put in the annotation (not passed if creating a blank instance or if behaviors are loaded from a file)
ignored_classes : set, optional
    the list of behaviors from the behaviors list or file to not annotate
ignored_clips : set, optional
    clip ids to ignore
annotation_path : str | set, optional
    the path or the set of paths to the folder where the annotation files are stored (not passed if creating from key objects)
behavior_file : str, optional
    the path to an .xlsx behavior file (not passed if creating from key objects or if irrelevant for the dataset)
correction : dict, optional
    a dictionary of corrections for the labels (e.g. {'sleping': 'sleeping', 'calm locomotion': 'locomotion'}, can be used to correct for variations in naming or to merge several labels into one)
filter_annotated : bool, default False
    if True, the samples that do not have any labels will be filtered
filter_background : bool, default False
    if True, only the unlabeled frames that are close to annotated frames will be labeled as background
error_class : str, optional
    the name of the error class (the annotations that intersect with this label will be discarded)
min_frames_action : int, default 0
    the minimum length of an action (shorter actions are not annotated)
key_objects : tuple, optional
    the key objects to load the BehaviorStore from
visibility_min_score : float, default 0.2
    the minimum visibility score for visibility filtering
visibility_min_frac : float, default 0.7
    the minimum fraction of visible frames for visibility filtering
mask : dict, optional
    a masked value dictionary (for active learning simulation experiments)
use_hard_negatives : bool, default False
    mark hard negatives as 2 instead of 0 or 1, for loss functions that have options for hard negative processing
annotation_suffix : str | set, optional
    the suffix or the set of suffixes such that the annotation files are named {video_id}{annotation_suffix} (not passed if creating from key objects or if irrelevant for the dataset)
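A hedged sketch of constructing the store from a folder of SIMBA prediction files; all values are illustrative (including the "mouse1+mouse2" clip id, which follows the "+"-joined animal naming used in `_open_annotations`), and in practice these arguments are usually supplied by the surrounding dlc2action pipeline:

```python
from dlc2action.data.annotation_store import SIMBAStore

store = SIMBAStore(
    video_order=["Video1", "Video2"],
    min_frames={"Video1": {"mouse1+mouse2": 0}, "Video2": {"mouse1+mouse2": 0}},
    max_frames={"Video1": {"mouse1+mouse2": 1800}, "Video2": {"mouse1+mouse2": 1800}},
    annotation_path="annotation_path",
    annotation_suffix=".csv",
    behaviors={"attack", "sniffing"},
    exclusive=False,
    len_segment=128,
    overlap=0,
    use_hard_negatives=False,
)
print(len(store))
```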