dlc2action.data.dataset

Behavior dataset (class that manages high-level data interactions).
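
A typical workflow constructs a `BehaviorDataset`, partitions it and feeds the resulting
subsets to a `torch.utils.data.DataLoader`. The sketch below is only illustrative: the
`data_type`, `annotation_type` and path arguments are hypothetical placeholders, and the
full parameter set depends on the chosen input and annotation stores (see
`BehaviorDataset.get_parameters`).

    from torch.utils.data import DataLoader

    from dlc2action.data.dataset import BehaviorDataset

    # list the registered store types first
    print(BehaviorDataset.data_types())
    print(BehaviorDataset.annotation_types())

    # hypothetical store types and paths
    dataset = BehaviorDataset(
        data_type="dlc_track",                    # placeholder
        annotation_type="csv",                    # placeholder
        data_path="/path/to/data",                # placeholder
        annotation_path="/path/to/annotations",   # placeholder
    )
    # note the return order: train, test, validation
    train, test, val = dataset.partition_train_test_val(
        method="random", val_frac=0.2, test_frac=0.1
    )
    loader = DataLoader(train, batch_size=32, shuffle=True)
    for batch in loader:
        features = batch["input"]   # dictionary of feature tensors
        target = batch["target"]    # annotation tensor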

   1#
   2# Copyright 2020-present by A. Mathis Group and contributors. All rights reserved.
   3#
   4# This project and all its files are licensed under GNU AGPLv3 or later version. 
   5# A copy is included in dlc2action/LICENSE.AGPL.
   6#
   7"""Behavior dataset (class that manages high-level data interactions)."""
   8
   9import inspect
  10import os
  11import pickle
  12import warnings
  13from abc import ABC
  14from collections import Counter, defaultdict
  15from copy import copy, deepcopy
  16from typing import Dict, List, Optional, Tuple, Union
  17
  18import numpy as np
  19import torch
  20from dlc2action import options
  21from dlc2action.data.base_store import BehaviorStore, InputStore
  22from dlc2action.utils import (
  23    apply_threshold,
  24    apply_threshold_hysteresis,
  25    apply_threshold_max,
  26)
  27from numpy import ndarray
  28from torch.utils.data import Dataset
  29from tqdm import tqdm
  30
  31
  32class BehaviorDataset(Dataset, ABC):
  33    """A generalized dataset class.
  34
  35    Data and annotation are stored in separate InputStore and BehaviorStore objects; the dataset class
  36    manages their interactions.
  37    """
  38
  39    def __init__(
  40        self,
  41        data_type: str,
  42        annotation_type: str = "none",
  43        ssl_transformations: List = None,
  44        saved_data_path: str = None,
  45        input_store: InputStore = None,
  46        annotation_store: BehaviorStore = None,
  47        only_load_annotated: bool = False,
  48        recompute_annotation: bool = False,
  49        # mask: str = None,
  50        ids: List = None,
  51        **data_parameters,
  52    ) -> None:
  53        """Initialize a dataset.
  54
  55        Parameters
  56        ----------
  57        data_type : str
  58            the data type (see available types by running BehaviorDataset.data_types())
  59        annotation_type : str
  60            the annotation type (see available types by running BehaviorDataset.annotation_types())
  61        ssl_transformations : list
  62            a list of functions that take a sample dictionary as input and return an (ssl input, ssl target) tuple
  63        saved_data_path : str
  64            the path to a pre-computed pickled dataset
  65        input_store : InputStore
  66            a pre-computed input store
  67        annotation_store : BehaviorStore
  68            a precomputed annotation store
  69        only_load_annotated : bool
  70            if `True`, the input files that don't have a matching annotation file will be disregarded
  71        recompute_annotation : bool
  72            if `True`, the annotation will be recomputed even if a precomputed annotation store is provided
  73        ids : list
  74            a list of ids to load from the input store
   75        **data_parameters : dict
  76            parameters to initialize the input and annotation stores
  77
  78        """
  79        mask = None
  80        if len(data_parameters) == 0:
  81            recompute_annotation = False
  82        feature_extraction = data_parameters.get("feature_extraction")
  83        if feature_extraction is not None and not issubclass(
  84            options.input_stores[data_type],
  85            options.feature_extractors[feature_extraction].input_store_class,
  86        ):
  87            raise ValueError(
  88                f"The {feature_extraction} feature extractor does not work with "
  89                f"the {data_type} data type, please choose a suclass of "
  90                f"{options.feature_extractors[feature_extraction].input_store_class}"
  91            )
  92        if ssl_transformations is None:
  93            ssl_transformations = []
  94        self.ssl_transformations = ssl_transformations
  95        self.input_type = data_type
  96        self.annotation_type = annotation_type
  97        self.stats = None
  98        if mask is not None:
  99            with open(mask, "rb") as f:
 100                self.mask = pickle.load(f)
 101        else:
 102            self.mask = None
 103        self.ids = ids
 104        self.tag = None
 105        self.return_unlabeled = None
 106        # load saved key objects for annotation and input if they exist
 107        input_key_objects, annotation_key_objects = None, None
 108        if saved_data_path is not None:
 109            if os.path.exists(saved_data_path):
 110                with open(saved_data_path, "rb") as f:
 111                    input_key_objects, annotation_key_objects = pickle.load(f)
 112        # if the input or the annotation store need to be created, generate the common video order
 113        if len(data_parameters) > 0:
 114            input_files = options.input_stores[data_type].get_file_ids(
 115                **data_parameters
 116            )
 117            annotation_files = options.annotation_stores[annotation_type].get_file_ids(
 118                **data_parameters
 119            )
 120            if only_load_annotated:
 121                data_parameters["video_order"] = [
 122                    x for x in input_files if x in annotation_files
 123                ]
 124            else:
 125                data_parameters["video_order"] = input_files
 126            if len(data_parameters["video_order"]) == 0:
 127                raise RuntimeError(
 128                    "The length of file list is 0! Please check your data parameters!"
 129                )
 130        data_parameters["mask"] = self.mask
 131        # load or create the input store
 132        ok = False
 133        if input_store is not None:
 134            self.input_store = input_store
 135            ok = True
 136        elif input_key_objects is not None:
 137            try:
 138                self.input_store = self._load_input_store(data_type, input_key_objects)
 139                ok = True
 140            except:
 141                warnings.warn("Loading input store from key objects failed")
 142        if not ok:
 143            self.input_store = self._get_input_store(
 144                data_type, deepcopy(data_parameters)
 145            )
 146        # get the objects needed to create the annotation store (like a clip length dictionary)
 147        annotation_objects = self.input_store.get_annotation_objects()
 148        data_parameters.update(annotation_objects)
 149        # load or create the annotation store
 150        ok = False
 151        if annotation_store is not None:
 152            self.annotation_store = annotation_store
 153            ok = True
 154        elif (
 155            (annotation_key_objects is not None)
 156            and mask is None
 157            and not recompute_annotation
 158        ):
 159            if len(annotation_key_objects) > 0:
 160                try:
 161                    self.annotation_store = self._load_annotation_store(
 162                        annotation_type, annotation_key_objects
 163                    )
 164                    ok = True
 165                except:
 166                    warnings.warn("Loading annotation store from key objects failed")
 167        if not ok:
 168            self.annotation_store = self._get_annotation_store(
 169                annotation_type, deepcopy(data_parameters)
 170            )
 171        to_remove = self.annotation_store.filtered_indices()
 172        if len(to_remove) > 0:
 173            print(
 174                f"Filtering {100 * len(to_remove) / len(self.annotation_store):.2f}% of samples"
 175            )
 176        if len(to_remove) == len(self.annotation_store) and len(to_remove) > 0:
 177            raise ValueError("All samples were filtered out!")
 178
 179        if len(self.input_store) == len(self.annotation_store):
 180            self.input_store.remove(to_remove)
 181        self.annotation_store.remove(to_remove)
 182        self.input_indices = list(range(len(self.input_store)))
 183        self.annotation_indices = list(range(len(self.input_store)))
 184        self.indices = list(range(len(self.input_store)))
 185
 186    def __getitem__(self, item: int) -> Dict:
 187        idx = self._get_idx(item)
 188        input = deepcopy(self.input_store[idx])
 189        target = self.annotation_store[idx]
 190        tag = self.input_store.get_tag(idx)
 191        ssl_inputs, ssl_targets = self._get_SSL_targets(input)
 192        batch = {"input": input}
 193        for name, x in zip(
 194            ["target", "ssl_inputs", "ssl_targets", "tag"],
 195            [target, ssl_inputs, ssl_targets, tag],
 196        ):
 197            if x is not None:
 198                batch[name] = x
 199        batch["index"] = idx
 200        if self.stats is not None:
 201            for key in batch["input"].keys():
 202                key_name = key.split("---")[0]
 203                if key_name in self.stats:
 204                    batch["input"][key][:, batch["input"][key].sum(0) != 0] = (
 205                        (batch["input"][key] - self.stats[key_name]["mean"])
 206                        / (self.stats[key_name]["std"] + 1e-7)
 207                    )[:, batch["input"][key].sum(0) != 0]
 208        return batch
 209
 210    def __len__(self) -> int:
 211        return len(self.indices)
 212        # if self.annotation_type != "none":
 213        #     return self.annotation_store.get_len(return_unlabeled=self.return_unlabeled)
 214        # else:
 215        #     return len(self.input_store)
 216
 217    def get_tags(self) -> List:
 218        """Get a list of all meta tags.
 219
 220        Returns
 221        -------
 222        tags: List
 223            a list of unique meta tag values
 224
 225        """
 226        return self.input_store.get_tags()
 227
 228    def save(self, save_path: str) -> None:
 229        """Save the dictionary.
 230
 231        Parameters
 232        ----------
 233        save_path : str
 234            the path where the pickled file will be stored
 235
 236        """
 237        input_obj = self.input_store.key_objects()
 238        annotation_obj = self.annotation_store.key_objects()
 239        with open(save_path, "wb") as f:
 240            pickle.dump((input_obj, annotation_obj), f)
 241
 242    def to_ram(self) -> None:
 243        """Transfer the dataset to RAM."""
 244        self.input_store.to_ram()
 245        self.annotation_store.to_ram()
 246
 247    def generate_full_length_gt(self) -> Dict:
 248        """Generate full-length ground truth from the annotations.
 249
 250        Returns
 251        -------
 252        full_length_gt : dict
 253            a nested dictionary where first-level keys are video ids, second-level keys are clip ids and
 254            values are the ground truth labels
 255
 256        """
 257        if self.annotation_class() == "exclusive_classification":
 258            gt = torch.zeros((len(self), self.len_segment()))
 259        else:
 260            gt = torch.zeros(
 261                (len(self), len(self.behaviors_dict()), self.len_segment())
 262            )
 263        for i in range(len(self)):
 264            gt[i] = self.annotation_store[i]
 265        return self.generate_full_length_prediction(gt)
 266
 267    def generate_full_length_prediction(self, predicted: torch.Tensor) -> Dict:
 268        """Map predictions for the equal-length pieces to predictions for the original data.
 269
 270        Probabilities are averaged over predictions on overlapping intervals.
 271
 272        Parameters
 273        ----------
 274        predicted: torch.Tensor
 275            a tensor of predicted probabilities of shape `(N, #classes, #frames)`
 276
 277        Returns
 278        -------
 279        full_length_prediction : dict
 280            a nested dictionary where first-level keys are video ids, second-level keys are clip ids and values are
 281            averaged probability tensors
 282
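             Examples
             --------
             A minimal sketch, using random numbers as a stand-in for model output:

             >>> probabilities = torch.rand(
             ...     len(dataset), dataset.num_classes(), dataset.len_segment()
             ... )
             >>> full = dataset.generate_full_length_prediction(probabilities)
             >>> # full[video_id][clip_id] is a (#classes, clip length) tensor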
 283        """
 284        result = defaultdict(lambda: {})
 285        counter = defaultdict(lambda: {})
 286        coordinates = self.input_store.get_original_coordinates()
 287        for coords, prediction in zip(coordinates, predicted):
 288            l = self.input_store.get_clip_length_from_coords(coords)
 289            video_name = self.input_store.get_video_id(coords)
 290            clip_id = self.input_store.get_clip_id(coords)
 291            start, end = self.input_store.get_clip_start_end(coords)
 292            if clip_id not in result[video_name].keys():
 293                result[video_name][clip_id] = torch.zeros(*prediction.shape[:-1], l)
 294                counter[video_name][clip_id] = torch.zeros(*prediction.shape[:-1], l)
 295            result[video_name][clip_id][..., start:end] += (
 296                prediction.squeeze()[..., : end - start].detach().cpu()
 297            )
 298            counter[video_name][clip_id][..., start:end] += 1
 299        for video_name in result:
 300            for clip_id in result[video_name]:
 301                result[video_name][clip_id] /= counter[video_name][clip_id]
 302                result[video_name][clip_id][counter[video_name][clip_id] == 0] = -100
 303        result = dict(result)
 304        return result
 305
 306    def find_valleys(
 307        self,
 308        predicted: Union[torch.Tensor, Dict],
 309        threshold: float = 0.5,
 310        min_frames: int = 0,
 311        visibility_min_score: float = 0,
 312        visibility_min_frac: float = 0,
 313        main_class: int = 1,
 314        low: bool = True,
 315        predicted_error: torch.Tensor = None,
 316        error_threshold: float = 0.5,
 317        hysteresis: bool = False,
 318        threshold_diff: float = None,
 319        min_frames_error: int = None,
 320        smooth_interval: int = 1,
 321        cut_annotated: bool = False,
 322    ) -> Dict:
 323        """Find the intervals where the probability of a certain class is below or above a certain hard_threshold.
 324
 325        Parameters
 326        ----------
 327        predicted : torch.Tensor | dict
  328            either a tensor of predictions for the dataset samples or the output of
 329            `BehaviorDataset.generate_full_length_prediction`
 330        threshold : float, default 0.5
  331            the main threshold
 332        min_frames : int, default 0
 333            the minimum length of the intervals
 334        visibility_min_score : float, default 0
 335            the minimum visibility score in the intervals
 336        visibility_min_frac : float, default 0
  337            the fraction of the interval that has to have a visibility score larger than visibility_min_score
 338        main_class : int, default 1
 339            the index of the class the function is inspecting
 340        low : bool, default True
  341            if True, the probability in the intervals has to be below the threshold; if False, it has to be above
 342        predicted_error : torch.Tensor, optional
  343            a tensor of error predictions for the dataset samples
 344        error_threshold : float, default 0.5
 345            maximum possible probability of error at the intervals
 346        hysteresis: bool, default False
  347            if True, the function will apply hysteresis thresholding with the soft threshold defined by threshold_diff
 348        threshold_diff: float, optional
  349            the difference between the soft and hard thresholds if hysteresis is used; if hysteresis is True, low is False and threshold_diff is None, the soft threshold condition is that main_class has a larger probability than the other classes
 350        min_frames_error: int, optional
  351            if not None, the intervals will only be considered where the error probability stays below error_threshold for at least min_frames_error consecutive frames
 352        smooth_interval: int, default 1
 353            the number of frames to smooth the predictions over
 354        cut_annotated: bool, default False
 355            if `True`, annotated intervals will be cut out of the predicted intervals
 356
 357        Returns
 358        -------
 359        valleys : dict
  360            a dictionary where keys are video ids and values are lists of (start, end, clip id) tuples that denote the chosen intervals
 361
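             Examples
             --------
             A minimal sketch that searches for intervals where class 2 is predicted
             with probability above 0.6 for at least 10 frames (random numbers stand
             in for model output):

             >>> probabilities = torch.rand(
             ...     len(dataset), dataset.num_classes(), dataset.len_segment()
             ... )
             >>> valleys = dataset.find_valleys(
             ...     probabilities, threshold=0.6, main_class=2, low=False, min_frames=10
             ... )
             >>> # valleys[video_id] is a list of [start, end, clip_id] intervals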
 362        """
 363        result = defaultdict(lambda: [])
 364        if type(predicted) is not dict:
 365            predicted = self.generate_full_length_prediction(predicted)
 366        if predicted_error is not None:
 367            predicted_error = self.generate_full_length_prediction(predicted_error)
 368        elif min_frames_error is not None and min_frames_error != 0:
 369            # warnings.warn(
 370            #     f"The min_frames_error parameter is set to {min_frames_error} but no error prediction "
 371            #     f"is given! Setting min_frames_error to 0."
 372            # )
 373            min_frames_error = 0
 374        if low and hysteresis and threshold_diff is None:
 375            raise ValueError(
 376                "Cannot set low=True, hysteresis=True and threshold_diff=None! Please set threshold_diff."
 377            )
 378        if cut_annotated:
 379            masked_intervals_dict = self.get_annotated_intervals()
 380        else:
 381            masked_intervals_dict = None
 382        print("Valleys found:")
 383        for v_id in predicted:
 384            for clip_id in predicted[v_id].keys():
 385                if predicted_error is not None:
 386                    error_mask = predicted_error[v_id][clip_id][1, :] < error_threshold
 387                    if min_frames_error is not None:
 388                        output, indices, counts = torch.unique_consecutive(
 389                            error_mask, return_inverse=True, return_counts=True
 390                        )
 391                        wrong_indices = torch.where(
 392                            output * (counts < min_frames_error)
 393                        )[0]
 394                        if len(wrong_indices) > 0:
 395                            for i in wrong_indices:
 396                                error_mask[indices == i] = False
 397                else:
 398                    error_mask = None
 399                if masked_intervals_dict is not None:
 400                    masked_intervals = masked_intervals_dict[v_id][clip_id]
 401                else:
 402                    masked_intervals = None
 403                if not hysteresis:
 404                    res_indices_start, res_indices_end = apply_threshold(
 405                        predicted[v_id][clip_id][main_class, :],
 406                        threshold,
 407                        low,
 408                        error_mask,
 409                        min_frames,
 410                        smooth_interval,
 411                        masked_intervals,
 412                    )
 413                elif threshold_diff is not None:
 414                    if low:
 415                        soft_threshold = threshold + threshold_diff
 416                    else:
 417                        soft_threshold = threshold - threshold_diff
 418                    res_indices_start, res_indices_end = apply_threshold_hysteresis(
 419                        predicted[v_id][clip_id][main_class, :],
 420                        soft_threshold,
 421                        threshold,
 422                        low,
 423                        error_mask,
 424                        min_frames,
 425                        smooth_interval,
 426                        masked_intervals,
 427                    )
 428                else:
 429                    res_indices_start, res_indices_end = apply_threshold_max(
 430                        predicted[v_id][clip_id],
 431                        threshold,
 432                        main_class,
 433                        error_mask,
 434                        min_frames,
 435                        smooth_interval,
 436                        masked_intervals,
 437                    )
 438                start = self.input_store.get_clip_start(v_id, clip_id)
 439                result[v_id] += [
 440                    [i + start, j + start, clip_id]
 441                    for i, j in zip(res_indices_start, res_indices_end)
 442                    if self.input_store.get_visibility(
 443                        v_id, clip_id, i, j, visibility_min_score
 444                    )
 445                    > visibility_min_frac
 446                ]
 447            result[v_id] = sorted(result[v_id])
 448            print(f"    {v_id}: {len(result[v_id])}")
 449        return dict(result)
 450
 451    def valleys_union(self, valleys_list) -> Dict:
 452        """Find the intersection of two valleys dictionaries.
 453
 454        Parameters
 455        ----------
 456        valleys_list : list
 457            a list of valleys dictionaries
 458
 459        Returns
 460        -------
  461        union : dict
  462            a new valleys dictionary with the union of the input intervals
 463
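             Examples
             --------
             A minimal sketch (`valleys_a` and `valleys_b` stand for outputs of
             `find_valleys`):

             >>> merged = dataset.valleys_union([valleys_a, valleys_b])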
 464        """
 465        valleys_list = [x for x in valleys_list if x is not None]
 466        if len(valleys_list) == 1:
 467            return valleys_list[0]
 468        elif len(valleys_list) == 0:
 469            return {}
 470        union = {}
 471        keys_list = [set(valleys.keys()) for valleys in valleys_list]
 472        keys = set.union(*keys_list)
 473        for v_id in keys:
 474            res = []
 475            clips_list = [
 476                set([x[-1] for x in valleys[v_id]]) for valleys in valleys_list
 477            ]
 478            clips = set.union(*clips_list)
 479            for clip_id in clips:
 480                clip_intervals = [
 481                    x
 482                    for valleys in valleys_list
 483                    for x in valleys[v_id]
 484                    if x[-1] == clip_id
 485                ]
 486                v_len = self.input_store.get_clip_length(v_id, clip_id)
 487                arr = torch.zeros(v_len)
 488                for start, end, _ in clip_intervals:
 489                    arr[start:end] += 1
 490                output, indices, counts = torch.unique_consecutive(
 491                    arr > 0, return_inverse=True, return_counts=True
 492                )
 493                long_indices = torch.where(output)[0]
 494                res += [
 495                    (
 496                        (indices == i).nonzero(as_tuple=True)[0][0].item(),
 497                        (indices == i).nonzero(as_tuple=True)[0][-1].item(),
 498                        clip_id,
 499                    )
 500                    for i in long_indices
 501                ]
 502            union[v_id] = res
 503        return union
 504
 505    def valleys_intersection(self, valleys_list) -> Dict:
 506        """Find the intersection of two valleys dictionaries.
 507
 508        Parameters
 509        ----------
 510        valleys_list : list
 511            a list of valleys dictionaries
 512
 513        Returns
 514        -------
 515        intersection : dict
 516            a new valleys dictionary with the intersection of the input intervals
 517
 518        """
 519        valleys_list = [x for x in valleys_list if x is not None]
 520        if len(valleys_list) == 1:
 521            return valleys_list[0]
 522        elif len(valleys_list) == 0:
 523            return {}
 524        intersection = {}
 525        keys_list = [set(valleys.keys()) for valleys in valleys_list]
 526        keys = set.intersection(*keys_list)
 527        for v_id in keys:
 528            res = []
 529            clips_list = [
 530                set([x[-1] for x in valleys[v_id]]) for valleys in valleys_list
 531            ]
 532            clips = set.intersection(*clips_list)
 533            for clip_id in clips:
 534                clip_intervals = [
 535                    x
 536                    for valleys in valleys_list
 537                    for x in valleys[v_id]
 538                    if x[-1] == clip_id
 539                ]
 540                v_len = self.input_store.get_clip_length(v_id, clip_id)
 541                arr = torch.zeros(v_len)
 542                for start, end, _ in clip_intervals:
 543                    arr[start:end] += 1
 544                output, indices, counts = torch.unique_consecutive(
 545                    arr, return_inverse=True, return_counts=True
 546                )
 547                long_indices = torch.where(output == 2)[0]
 548                res += [
 549                    (
 550                        (indices == i).nonzero(as_tuple=True)[0][0].item(),
 551                        (indices == i).nonzero(as_tuple=True)[0][-1].item(),
 552                        clip_id,
 553                    )
 554                    for i in long_indices
 555                ]
 556            intersection[v_id] = res
 557        return intersection
 558
 559    def partition_train_test_val(
 560        self,
 561        use_test: float = 0,
 562        split_path: str = None,
 563        method: str = "random",
 564        val_frac: float = 0,
 565        test_frac: float = 0,
 566        save_split: bool = False,
 567        normalize: bool = False,
 568        skip_normalization_keys: List = None,
 569        stats: Dict = None,
 570    ) -> Tuple:
 571        """Partition the dataset into three new datasets.
 572
 573        Parameters
 574        ----------
 575        use_test : float, default 0
 576            The fraction of the test dataset to be used in training without labels
 577        split_path : str, optional
 578            The path to load the split information from (if `'file'` method is used) and to save it to
 579            (if `'save_split'` is `True`)
 580        method : {'random', 'random:test-from-name', 'random:test-from-name:{name}',
 581            'val-from-name:{val_name}:test-from-name:{test_name}',
 582            'random:equalize:segments', 'random:equalize:videos',
 583            'folders', 'time', 'time:strict', 'file'}
 584            The partitioning method:
 585            - `'random'`: sort videos into subsets randomly,
 586            - `'random:test-from-name'` (or `'random:test-from-name:{name}'`): sort videos into training and validation
 587                subsets randomly and create
  588                the test subset from the video ids that start with a specific substring (`'test'` by default, or `name`
 589                if provided),
 590            - `'random:equalize:segments'` and `'random:equalize:videos'`: sort videos into subsets randomly but
 591                making sure that for the rarest classes at least `0.8 * val_frac` of the videos/segments that contain
 592                occurrences of the class get into the validation subset and `0.8 * test_frac` get into the test subset;
  593            this is ensured for all classes in order of increasing number of occurrences until the validation and
 594                test subsets are full
 595            - `'val-from-name:{val_name}:test-from-name:{test_name}'`: create the validation and test
 596                subsets from the video ids that start with specific substrings (`val_name` for validation
 597                and `test_name` for test) and sort all other videos into the training subset
 598            - `'folders'`: read videos from folders named *test*, *train* and *val* into corresponding subsets,
 599            - `'time'`: split each video into training, validation and test subsequences,
 600            - `'time:strict'`: split each video into validation, test and training subsequences
 601                and throw out the last segments in validation and test (to get rid of overlaps),
 602            - `'file'`: split according to a split file.
 603        val_frac : float, default 0
 604            The fraction of the dataset to be used in validation
 605        test_frac : float, default 0
 606            The fraction of the dataset to be used in test
 607        save_split : bool, default False
 608            Save a split file if True
 609        normalize : bool, default False
 610            Normalize the dataset if `True`
 611        skip_normalization_keys : list, optional
 612            A list of keys to skip normalization for
 613        stats : dict, optional
 614            A dictionary of (pre-computed) statistics to use for normalization
 615
 616        Returns
 617        -------
  618        train_dataset : BehaviorDataset
  619            train dataset
  620        test_dataset : BehaviorDataset
  621            test dataset
  622        val_dataset : BehaviorDataset
  623            validation dataset
 624
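             Examples
             --------
             A minimal sketch of a random split; note that the subsets are returned
             in train, test, validation order:

             >>> train, test, val = dataset.partition_train_test_val(
             ...     method="random", val_frac=0.2, test_frac=0.1, normalize=True
             ... )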
 625        """
 626        train_indices, test_indices, val_indices = self._partition_indices(
 627            split_path=split_path,
 628            method=method,
 629            val_frac=val_frac,
 630            test_frac=test_frac,
 631            save_split=save_split,
 632        )
 633        ssl_indices = None
 634        partition_method = method.split(":")
 635        if (
 636            partition_method[0] in ("leave-one-in", "leave-n-in")
  637            and len(partition_method) > 2
 638            and partition_method[2] == "val-for-ssl"
 639        ):
 640            print("Using validation samples for SSL!")
 641            ssl_indices = val_indices
 642
 643        val_dataset = self._create_new_dataset(val_indices)
 644        test_dataset = self._create_new_dataset(test_indices)
 645        train_dataset = self._create_new_dataset(train_indices, ssl_indices=ssl_indices)
 646
 647        train_classes = train_dataset.count_classes()
 648        val_classes = val_dataset.count_classes()
 649        test_classes = test_dataset.count_classes()
 650        print("Number of samples:")
 651        print(f"    validation:")
 652        print(f"      {[f'{k}: {val_classes[k]}' for k in sorted(val_classes.keys())]}")
 653        print(f"    training:")
 654        print(f"      {[f'{k}: {train_classes[k]}' for k in sorted(train_classes.keys())]}")
 655        print(f"    test:")
 656        print(f"      {[f'{k}: {test_classes[k]}' for k in sorted(test_classes.keys())]}")
 657        if normalize:
 658            if stats is None:
 659                print("Computing normalization statistics...")
 660                stats = train_dataset.get_normalization_stats(skip_normalization_keys)
 661            else:
 662                print("Setting loaded normalization statistics...")
 663            train_dataset.set_normalization_stats(stats)
 664            val_dataset.set_normalization_stats(stats)
 665            test_dataset.set_normalization_stats(stats)
 666        return train_dataset, test_dataset, val_dataset
 667
 668    def class_weights(self, proportional=False) -> List:
 669        """Calculate class weights in inverse proportion to number of samples.
 670
 671        Parameters
 672        ----------
 673        proportional : bool, default False
  674            If `True`, the weights are computed relative to the size of the most common class instead of the total number of samples
 675
 676        Returns
 677        -------
  678        weights : list | dict
  679            a list of class weights (for non-exclusive annotation, a dictionary with weight lists for the negative (0) and positive (1) labels)
 680
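             Examples
             --------
             A minimal sketch for the exclusive classification case, passing the
             weights to a PyTorch loss:

             >>> weights = torch.tensor(dataset.class_weights(), dtype=torch.float)
             >>> criterion = torch.nn.CrossEntropyLoss(weight=weights)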
 681        """
 682        items = sorted(
 683            [
 684                (k, v)
 685                for k, v in self.annotation_store.count_classes().items()
 686                if k != -100
 687            ]
 688        )
 689        if self.annotation_store.annotation_class() == "exclusive_classification":
 690            if not proportional:
 691                numerator = len(self.annotation_store)
 692            else:
 693                numerator = max([x[1] for x in items])
 694            weights = [numerator / (v + 1e-7) for _, v in items]
 695        else:
 696            items_zero = sorted(
 697                [
 698                    (k, v)
 699                    for k, v in self.annotation_store.count_classes(zeros=True).items()
 700                    if k != -100
 701                ]
 702            )
 703            if not proportional:
 704                numerators = defaultdict(lambda: len(self.annotation_store))
 705            else:
 706                numerators = {
 707                    item_one[0]: max(item_one[1], item_zero[1])
 708                    for item_one, item_zero in zip(items, items_zero)
 709                }
 710            weights = {}
 711            weights[0] = [numerators[k] / (v + 1e-7) for k, v in items_zero]
 712            weights[1] = [numerators[k] / (v + 1e-7) for k, v in items]
 713        return weights
 714
 715    def _boundary_class_weight(self):
 716        """Calculate the weight of the boundary class.
 717
 718        Returns
 719        -------
 720        weight: float
 721            the weight of the boundary class
 722
 723        """
 724        if self.annotation_type != "none":
 725            f = self.annotation_store.data.flatten()
 726            _, inv = torch.unique_consecutive(f, return_inverse=True)
 727            boundary = torch.cat([torch.tensor([0]), torch.diff(inv)]).reshape(
 728                self.annotation_store.data.shape
 729            )
 730            boundary[..., 0] = 0
 731            cnt = Counter(boundary.flatten().numpy())
 732            return cnt[1] / cnt[0]
 733        else:
 734            return 0
 735
 736    def count_classes(self, bouts: bool = False) -> Dict:
 737        """Get a class counter dictionary.
 738
 739        Parameters
 740        ----------
 741        bouts : bool, default False
 742            if `True`, instead of frame counts segment counts are returned
 743
 744        Returns
 745        -------
 746        count_dictionary : dict
 747            a dictionary with class indices as keys and frame or bout counts as values
 748
 749        """
 750        return self.annotation_store.count_classes(bouts=bouts)
 751
 752    def behaviors_dict(self) -> Dict:
 753        """Get a behavior dictionary.
 754
 755        Returns
 756        -------
  757        behaviors_dict : dict
  758            a dictionary mapping class indices to behavior names
 759
 760        """
 761        return self.annotation_store.behaviors_dict()
 762
 763    def bodyparts_order(self) -> List:
 764        """Get the order of bodyparts.
 765
 766        Returns
 767        -------
 768        bodyparts : List
 769            a list of bodyparts
 770
 771        """
 772        try:
 773            return self.input_store.get_bodyparts()
 774        except:
 775            raise RuntimeError(
 776                f"The {self.input_type} input store does not have bodyparts implemented!"
 777            )
 778
 779    def features_shape(self) -> Dict:
 780        """Get the shapes of the input features.
 781
 782        Returns
 783        -------
 784        shapes : Dict
 785            a dictionary with the shapes of the features
 786
 787        """
 788        sample = self.input_store[0]
 789        shapes = {k: v.shape for k, v in sample.items()}
 790        # for key, value in shapes.items():
 791        #     print(f'{key}: {value}')
 792        return shapes
 793
 794    def num_classes(self) -> int:
 795        """Get the number of classes in the data.
 796
 797        Returns
 798        -------
 799        num_classes : int
 800            the number of classes
 801
 802        """
 803        return len(self.annotation_store.behaviors_dict())
 804
 805    def len_segment(self) -> int:
 806        """Get the segment length in the data.
 807
 808        Returns
 809        -------
 810        len_segment : int
 811            the segment length
 812
 813        """
 814        sample = self.input_store[0]
 815        key = list(sample.keys())[0]
 816        return sample[key].shape[-1]
 817
 818    def set_ssl_transformations(self, ssl_transformations: List) -> None:
 819        """Set new SSL transformations.
 820
 821        Parameters
 822        ----------
 823        ssl_transformations : list
 824            a list of functions that take a sample feature dictionary as input and output ssl_inputs and ssl_targets
 825            lists
 826
 827        """
 828        self.ssl_transformations = ssl_transformations
 829
 830    @classmethod
 831    def new(cls, *args, **kwargs):
 832        """Create a new object of the same class.
 833
 834        Parameters
 835        ----------
 836        args : list
 837            arguments for the constructor
 838        kwargs : dict
 839            keyword arguments for the constructor
 840
 841        Returns
 842        -------
 843        new_instance: BehaviorDataset
 844            a new instance of the same class
 845
 846        """
 847        return cls(*args, **kwargs)
 848
 849    @classmethod
 850    def get_parameters(cls, data_type: str, annotation_type: str) -> List:
 851        """Get parameters necessary for initialization.
 852
 853        Parameters
 854        ----------
 855        data_type : str
 856            the data type
 857        annotation_type : str
 858            the annotation type
 859
 860        Returns
 861        -------
  862        parameters : list
  863            a list of parameter names
 864
 865        """
 866        input_features = options.input_stores[data_type].get_parameters()
 867        annotation_features = options.annotation_stores[
 868            annotation_type
 869        ].get_parameters()
 870        self_features = inspect.getfullargspec(cls.__init__).args
 871        return self_features + input_features + annotation_features
 872
 873    @staticmethod
 874    def data_types() -> List:
 875        """List available data types.
 876
 877        Returns
 878        -------
 879        data_types : list
 880            available data types
 881
 882        """
 883        return list(options.input_stores.keys())
 884
 885    @staticmethod
 886    def annotation_types() -> List:
 887        """List available annotation types.
 888
 889        Returns
 890        -------
 891        annotation_types : list
 892            available annotation types
 893
 894        """
 895        return list(options.annotation_stores.keys())
 896
 897    def _get_SSL_targets(self, input: Dict) -> Tuple[List, List]:
 898        """Get the SSL inputs and targets from a sample dictionary."""
 899        ssl_inputs = []
 900        ssl_targets = []
 901        for transform in self.ssl_transformations:
 902            ssl_input, ssl_target = transform(copy(input))
 903            ssl_inputs.append(ssl_input)
 904            ssl_targets.append(ssl_target)
 905        return ssl_inputs, ssl_targets
 906
 907    def _create_new_dataset(self, indices: List, ssl_indices: List = None):
 908        """Create a subsample of the dataset, with samples at ssl_indices losing the annotation."""
 909        if ssl_indices is None:
 910            ssl_indices = []
 911        input_store = self.input_store.create_subsample(indices, ssl_indices)
 912        annotation_store = self.annotation_store.create_subsample(indices, ssl_indices)
 913        new = self.new(
 914            data_type=self.input_type,
 915            annotation_type=self.annotation_type,
 916            ssl_transformations=self.ssl_transformations,
 917            annotation_store=annotation_store,
 918            input_store=input_store,
 919            ids=list(indices) + list(ssl_indices),
 920            recompute_annotation=False,
 921        )
 922        return new
 923
 924    def _load_input_store(self, data_type: str, key_objects: Tuple) -> InputStore:
 925        """Load input store from key objects."""
 926        input_store = options.input_stores[data_type](key_objects=key_objects)
 927        return input_store
 928
 929    def _load_annotation_store(
 930        self, annotation_type: str, key_objects: Tuple
 931    ) -> BehaviorStore:
 932        """Load annotation store from key objects."""
 933        annotation_store = options.annotation_stores[annotation_type](
 934            key_objects=key_objects
 935        )
 936        return annotation_store
 937
 938    def _get_input_store(self, data_type: str, data_parameters: Dict) -> InputStore:
 939        """Create input store from parameters."""
 940        data_parameters["key_objects"] = None
 941        input_store = options.input_stores[data_type](**data_parameters)
 942        return input_store
 943
 944    def _get_annotation_store(
 945        self, annotation_type: str, data_parameters: Dict
 946    ) -> BehaviorStore:
 947        """Create annotation store from parameters."""
 948        annotation_store = options.annotation_stores[annotation_type](**data_parameters)
 949        return annotation_store
 950
 951    def set_indexing_parameters(self, unlabeled: bool, tag: int) -> None:
 952        """Set the parameters that change the subset that is returned at `__getitem__`.
 953
 954        Parameters
 955        ----------
 956        unlabeled : bool
 957            a pseudolabeling parameter; return only unlabeled samples if `True`, only labeled if `False` and
 958            all if `None`
 959        tag : int
 960            if not `None`, only samples with this meta tag will be returned
 961
 962        """
 963        if unlabeled != self.return_unlabeled:
 964            self.annotation_indices = self.annotation_store.get_indices(unlabeled)
 965            self.return_unlabeled = unlabeled
 966        if tag != self.tag:
 967            self.input_indices = self.input_store.get_indices(tag)
 968            self.tag = tag
 969        self.indices = [x for x in self.annotation_indices if x in self.input_indices]
 970
 971    def _get_idx(self, index: int) -> int:
 972        """Get index in full dataset."""
 973        return self.indices[index]
 974
 975        # return self.annotation_store.get_idx(
 976        #     index, return_unlabeled=self.return_unlabeled
 977        # )
 978
 979    def _partition_indices(
 980        self,
 981        split_path: str = None,
 982        method: str = "random",
 983        val_frac: float = 0,
 984        test_frac: float = 0,
 985        save_split: bool = False,
 986    ) -> Tuple[List, List, List]:
 987        """Partition indices into train, validation, test subsets."""
 988        if self.mask is not None:
 989            val_indices = self.mask["val_ids"]
 990            train_indices = [x for x in range(len(self)) if x not in val_indices]
 991            test_indices = []
 992        elif method == "random":
 993            videos = np.array(self.input_store.get_video_id_order())
 994            all_videos = list(set(videos))
 995            if len(all_videos) == 1:
 996                warnings.warn(
 997                    "There is only one video in the dataset, so train/val/test split is done on segments; "
 998                    'that might lead to overlaps, please consider using "time" or "time:strict" as the '
 999                    "partitioning method instead"
1000                )
1001                # Quick fix for single video: the problem with this is that the segments can overlap
1002                # length = int(self.input_store.get_original_coordinates()[-1][1])    # number of segments
1003                length = len(self.input_store.get_original_coordinates())
1004                val_len = int(val_frac * length)
1005                test_len = int(test_frac * length)
1006                all_indices = np.random.choice(np.arange(length), length, replace=False)
1007                val_indices = all_indices[:val_len]
1008                test_indices = all_indices[val_len : val_len + test_len]
1009                train_indices = all_indices[val_len + test_len :]
1010                coords = self.input_store.get_original_coordinates()
1011                if save_split:
1012                    self._save_partition(
1013                        coords[train_indices],
1014                        coords[val_indices],
1015                        coords[test_indices],
1016                        split_path,
1017                        coords=True,
1018                    )
1019            else:
1020                length = len(all_videos)
1021                val_len = int(val_frac * length)
1022                test_len = int(test_frac * length)
1023                validation = all_videos[:val_len]
1024                test = all_videos[val_len : val_len + test_len]
1025                training = all_videos[val_len + test_len :]
1026                train_indices = np.where(np.isin(videos, training))[0]
1027                val_indices = np.where(np.isin(videos, validation))[0]
1028                test_indices = np.where(np.isin(videos, test))[0]
1029                if save_split:
1030                    self._save_partition(training, validation, test, split_path)
1031        elif method.startswith("random:equalize"):
1032            counter = self.count_classes()
1033            counter = sorted(list([(v, k) for k, v in counter.items()]))
1034            classes = [x[1] for x in counter]
1035            indicator = {c: [] for c in classes}
1036            if method.endswith("videos"):
1037                videos = np.array(self.input_store.get_video_id_order())
1038                all_videos = list(set(videos))
1039                total_len = len(all_videos)
1040                for video_id in all_videos:
1041                    video_coords = np.where(videos == video_id)[0]
1042                    ann = torch.cat(
1043                        [self.annotation_store[i] for i in video_coords], dim=-1
1044                    )
1045                    for c in classes:
1046                        if self.annotation_class() == "nonexclusive_classification":
1047                            indicator[c].append(torch.sum(ann[c] == 1) > 0)
1048                        elif self.annotation_class() == "exclusive_classification":
1049                            indicator[c].append(torch.sum(ann == c) > 0)
1050                        else:
1051                            raise ValueError(
1052                                f"The random:equalize partition method is not implemented"
1053                                f"for the {self.annotation_class()} annotation class"
1054                            )
1055            elif method.endswith("segments"):
1056                total_len = len(self)
1057                for ann in self.annotation_store:
1058                    for c in classes:
1059                        if self.annotation_class() == "nonexclusive_classification":
1060                            indicator[c].append(torch.sum(ann[c] == 1) > 0)
1061                        elif self.annotation_class() == "exclusive_classification":
1062                            indicator[c].append(torch.sum(ann == c) > 0)
1063                        else:
1064                            raise ValueError(
1065                                f"The random:equalize partition method is not implemented"
1066                                f"for the {self.annotation_class()} annotation class"
1067                            )
1068            else:
1069                values = []
1070                for v in options.partition_methods.values():
1071                    values += v
1072                raise ValueError(
1073                    f"The {method} partition method is not recognized; please choose from {values}"
1074                )
1075            val_indices = []
1076            test_indices = []
1077            for c in classes:
1078                indicator[c] = np.array(indicator[c])
1079                ind = np.where(indicator[c])[0]
1080                np.random.shuffle(ind)
1081                c_sum = len(ind)
1082                in_val = np.sum(indicator[c][val_indices])
1083                in_test = np.sum(indicator[c][test_indices])
1084                while (
1085                    len(val_indices) < val_frac * total_len
1086                    and in_val < val_frac * c_sum * 0.8
1087                ):
1088                    first, ind = ind[0], ind[1:]
1089                    val_indices = list(set(val_indices).union({first}))
1090                    in_val = np.sum(indicator[c][val_indices])
1091                while (
1092                    len(test_indices) < test_frac * total_len
1093                    and in_test < test_frac * c_sum * 0.8
1094                ):
1095                    first, ind = ind[0], ind[1:]
1096                    test_indices = list(set(test_indices).union({first}))
1097                    in_test = np.sum(indicator[c][test_indices])
1098            if len(val_indices) < int(val_frac * total_len):
1099                left_val = int(val_frac * total_len) - len(val_indices)
1100            else:
1101                left_val = 0
1102            if len(test_indices) < int(test_frac * total_len):
1103                left_test = int(test_frac * total_len) - len(test_indices)
1104            else:
1105                left_test = 0
1106            indicator = np.ones(total_len)
1107            indicator[val_indices] = 0
1108            indicator[test_indices] = 0
1109            ind = np.where(indicator)[0]
1110            np.random.shuffle(ind)
1111            val_indices += list(ind[:left_val])
1112            test_indices += list(ind[left_val : left_val + left_test])
1113            train_indices = list(ind[left_val + left_test :])
1114            if save_split:
1115                if method.endswith("segments"):
1116                    coords = self.input_store.get_original_coordinates()
1117                    self._save_partition(
1118                        coords[train_indices],
1119                        coords[val_indices],
1120                        coords[test_indices],
 1121                        split_path,
1122                        coords=True,
1123                    )
1124                else:
1125                    all_videos = np.array(all_videos)
1126                    validation = all_videos[val_indices]
1127                    test = all_videos[test_indices]
1128                    training = all_videos[train_indices]
1129                    self._save_partition(training, validation, test, split_path)
1130        elif method.startswith("random:test-from-name"):
1131            split = method.split(":")
1132            if len(split) > 2:
1133                test_name = split[-1]
1134            else:
1135                test_name = "test"
1136            videos = np.array(self.input_store.get_video_id_order())
1137            all_videos = list(set(videos))
1138            test = []
1139            train_videos = []
1140            for x in all_videos:
1141                if x.startswith(test_name):
1142                    test.append(x)
1143                else:
1144                    train_videos.append(x)
1145            length = len(train_videos)
1146            val_len = int(val_frac * length)
1147            validation = train_videos[:val_len]
1148            training = train_videos[val_len:]
1149            train_indices = np.where(np.isin(videos, training))[0]
1150            val_indices = np.where(np.isin(videos, validation))[0]
1151            test_indices = np.where(np.isin(videos, test))[0]
1152            if save_split:
1153                self._save_partition(training, validation, test, split_path)
1154        elif method.startswith("val-from-name"):
1155            split = method.split(":")
1156            if split[2] != "test-from-name":
1157                raise ValueError(
1158                    f"The {method} partition method is not recognized, please choose from {options.partition_methods}"
1159                )
1160            val_name = split[1]
1161            test_name = split[-1]
1162            videos = np.array(self.input_store.get_video_id_order())
1163            all_videos = list(set(videos))
1164            test = []
1165            validation = []
1166            training = []
1167            for x in all_videos:
1168                if x.startswith(test_name):
1169                    test.append(x)
1170                elif x.startswith(val_name):
1171                    validation.append(x)
1172                else:
1173                    training.append(x)
1174            train_indices = np.where(np.isin(videos, training))[0]
1175            val_indices = np.where(np.isin(videos, validation))[0]
1176            test_indices = np.where(np.isin(videos, test))[0]
1177        elif method == "folders":
1178            folders = np.array(self.input_store.get_folder_order())
1179            videos = np.array(self.input_store.get_video_id_order())
1180            train_indices = np.where(np.isin(folders, ["training", "train"]))[0]
1181            if np.sum(np.isin(folders, ["validation", "val"])) > 0:
1182                val_indices = np.where(np.isin(folders, ["validation", "val"]))[0]
1183            else:
1184                train_videos = list(set(videos[train_indices]))
1185                val_len = int(val_frac * len(train_videos))
1186                validation = train_videos[:val_len]
1187                training = train_videos[val_len:]
1188                train_indices = np.where(np.isin(videos, training))[0]
1189                val_indices = np.where(np.isin(videos, validation))[0]
1190            test_indices = np.where(folders == "test")[0]
1191            if save_split:
1192                self._save_partition(
1193                    list(set(videos[train_indices])),
1194                    list(set(videos[val_indices])),
1195                    list(set(videos[test_indices])),
1196                    split_path,
1197                )
1198        elif method.startswith("leave-one-out"):
1199            n = int(method.split(":")[-1])
1200            videos = np.array(self.input_store.get_video_id_order())
1201            all_videos = sorted(list(set(videos)))
1202            print(len(all_videos))
1203            validation = [all_videos.pop(n)]
1204            training = all_videos
1205            train_indices = np.where(np.isin(videos, training))[0]
1206            val_indices = np.where(np.isin(videos, validation))[0]
1207            test_indices = np.array([])
1208        elif method.startswith("leave-one-in"):
1209            n = int(method.split(":")[1])
1210            videos = np.array(self.input_store.get_video_id_order())
1211            all_videos = sorted(list(set(videos)))
1212            training = [all_videos.pop(n)]
1213            validation = all_videos
1214            train_indices = np.where(np.isin(videos, training))[0]
1215            val_indices = np.where(np.isin(videos, validation))[0]
1216            test_indices = np.array([])
1217        elif method.startswith("leave-n-in"):
1218            train_idx = [int(i) for i in method.split(":")[1].split(",")]
1219            videos = np.array(self.input_store.get_video_id_order())
1220            all_videos = sorted(list(set(videos)))
1221            training = [v for i, v in enumerate(all_videos) if i in train_idx]
1222            validation = [v for i, v in enumerate(all_videos) if i not in train_idx]
1223            train_indices = np.where(np.isin(videos, training))[0]
1224            val_indices = np.where(np.isin(videos, validation))[0]
1225            test_indices = np.array([])
1226        elif method.startswith("time"):
1227            if method.endswith("strict"):
1228                len_segment = self.len_segment()
1229                step = self.input_store.step
1230                num_removed = len_segment // step
1231            else:
1232                num_removed = 0
1233            videos = np.array(self.input_store.get_video_id_order())
1234            all_videos = set(videos)
1235            train_indices = []
1236            val_indices = []
1237            test_indices = []
1238            start = 0
1239            if len(method.split(":")) > 1 and method.split(":")[1] == "start-from":
1240                start = float(method.split(":")[2])
1241            for video_id in all_videos:
1242                video_indices = np.where(videos == video_id)[0]
1243                val_len = int(val_frac * len(video_indices))
1244                test_len = int(test_frac * len(video_indices))
1245                start_pos = int(start * len(video_indices))
1246                all_ind = np.ones(len(video_indices))
1247                val_indices += list(video_indices[start_pos : start_pos + val_len])
1248                all_ind[start_pos : start_pos + val_len] = 0
1249                if start_pos + val_len > len(video_indices):
1250                    p = start_pos + val_len - len(video_indices)
1251                    val_indices += list(video_indices[:p])
1252                    all_ind[:p] = 0
1253                else:
1254                    p = start_pos + val_len
1255                test_indices += list(video_indices[p : p + test_len])
1256                all_ind[p : p + test_len] = 0
1257                if p + test_len > len(video_indices):
1258                    p = test_len + p - len(video_indices)
1259                    test_indices += list(video_indices[:p])
1260                    all_ind[:p] = 0
1261                train_indices += list(video_indices[all_ind > 0])
1262                for _ in range(num_removed):
1263                    if len(val_indices) > 0:
1264                        val_indices.pop(-1)
1265                    if len(test_indices) > 0:
1266                        test_indices.pop(-1)
1267                    if start > 0 and len(train_indices) > 0:
1268                        train_indices.pop(-1)
1269        elif method == "file":
1270            if split_path is None:
1271                raise ValueError(
1272                    'You need to either set split_path or change partition method ("file" requires a file)'
1273                )
1274            active_list = None
1275            training, validation, test = [], [], []
1276            with open(split_path) as f:
1277                for line in f.readlines():
1278                    if line.startswith("Train"):
1279                        active_list = training
1280                    elif line.startswith("Valid"):
1281                        active_list = validation
1282                    elif line.startswith("Test"):
1283                        active_list = test
1284                    else:
1285                        stripped_line = line.rstrip(",\n ")
1286                        if stripped_line == "":
1287                            continue
1288                        if ", " in stripped_line:
1289                            active_list += stripped_line.split(", ")
1290                        else:
1291                            active_list.append(stripped_line)
1292            all_lines = training + validation + test
1293            if len(all_lines[0].split("---")) == 3:
1294                entry_type = "coords"
1295            else:
1296                entry_type = "videos"
1297
1298            if entry_type == "videos":
1299                videos = np.array(self.input_store.get_video_id_order())
1300                val_indices = np.where(np.isin(videos, validation))[0]
1301                test_indices = np.where(np.isin(videos, test))[0]
1302                train_indices = np.where(np.isin(videos, training))[0]
1303            elif entry_type == "coords":
1304                coords = self.input_store.get_original_coordinates()
1305                video_ids = self.input_store.get_video_id_order()
1306                clip_ids = [self.input_store.get_clip_id(coord) for coord in coords]
1307                starts, ends = zip(
1308                    *[self.input_store.get_clip_start_end(coord) for coord in coords]
1309                )
1310                coords = np.array(
1311                    [
1312                        f"{video_id}---{clip_id}---{start}-{end}"
1313                        for video_id, clip_id, start, end in zip(
1314                            video_ids, clip_ids, starts, ends
1315                        )
1316                    ]
1317                )
1318                val_indices = np.where(np.isin(coords, validation))[0]
1319                test_indices = np.where(np.isin(coords, test))[0]
1320                train_indices = np.where(np.isin(coords, training))[0]
1321            else:
1322                raise ValueError("The split path has unrecognized format!")
1323            all_indices = np.ones(len(self))
1324            if len(train_indices) == 0:
1325                all_indices[val_indices] = 0
1326                all_indices[test_indices] = 0
1327                train_indices = np.where(all_indices)[0]
1328            elif len(val_indices) == 0:
1329                all_indices[train_indices] = 0
1330                all_indices[test_indices] = 0
1331                val_indices = np.where(all_indices)[0]
1332            elif len(test_indices) == 0:
1333                all_indices[train_indices] = 0
1334                all_indices[val_indices] = 0
1335                test_indices = np.where(all_indices)[0]
1336        else:
1337            raise ValueError(
1338                f"The {method} partition is not recognized, please choose from {options.partition_methods}"
1339            )
1340        return sorted(train_indices), sorted(test_indices), sorted(val_indices)
1341
1342    def _save_partition(
1343        self,
1344        training: List,
1345        validation: List,
1346        test: List,
1347        split_path: str,
1348        coords: bool = False,
1349    ) -> None:
1350        """Save a split file."""
1351        if coords:
1352            name = "coords"
1353            training_coords = []
1354            val_coords = []
1355            test_coords = []
1356            for coord in training:
1357                video_id = self.input_store.get_video_id(coord)
1358                clip_id = self.input_store.get_clip_id(coord)
1359                start, end = self.input_store.get_clip_start_end(coord)
1360                training_coords.append(f"{video_id}---{clip_id}---{start}-{end}")
1361            for coord in validation:
1362                video_id = self.input_store.get_video_id(coord)
1363                clip_id = self.input_store.get_clip_id(coord)
1364                start, end = self.input_store.get_clip_start_end(coord)
1365                val_coords.append(f"{video_id}---{clip_id}---{start}-{end}")
1366            for coord in test:
1367                video_id = self.input_store.get_video_id(coord)
1368                clip_id = self.input_store.get_clip_id(coord)
1369                start, end = self.input_store.get_clip_start_end(coord)
1370                test_coords.append(f"{video_id}---{clip_id}---{start}-{end}")
1371            training, validation, test = training_coords, val_coords, test_coords
1372        else:
1373            name = "videos"
1374        if split_path is not None:
1375            with open(split_path, "w") as f:
1376                f.write(f"Training {name}:\n")
1377                for x in training:
1378                    f.write(x + "\n")
1379                f.write(f"Validation {name}:\n")
1380                for x in validation:
1381                    f.write(x + "\n")
1382                f.write(f"Test {name}:\n")
1383                for x in test:
1384                    f.write(x + "\n")
1385
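For reference, a split file written by this method with the default name ("videos"), and parsed back by the `"file"` partition method above, would look roughly like the following; the video ids are placeholders:

    Training videos:
    video_01
    video_02
    Validation videos:
    video_03
    Test videos:
    video_04

With `coords=True` each entry instead takes the `{video_id}---{clip_id}---{start}-{end}` form.
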
1386    def _get_intervals_from_ind(self, frame_indices: np.ndarray):
1387        """Get a list of intervals from a list of frame indices.
1388
1389        Example: `[0, 1, 2, 5, 6, 8] -> [[0, 3], [5, 7], [8, 9]]`.
1390
1391        Parameters
1392        ----------
1393        frame_indices : np.ndarray
1394            a list of frame indices
1395
1396        Returns
1397        -------
1398        intervals : list
1399            a list of interval boundaries
1400
1401        """
1402        masked_intervals = []
1403        breaks = np.where(np.diff(frame_indices) != 1)[0]
1404        if len(frame_indices) > 0:
1405            start = frame_indices[0]
1406            for k in breaks:
1407                masked_intervals.append([start, frame_indices[k] + 1])
1408                start = frame_indices[k + 1]
1409            masked_intervals.append([start, frame_indices[-1] + 1])
1410        return masked_intervals
1411
1412    def get_intervals(self) -> Tuple[dict, Optional[list]]:
1413        """Get a list of intervals covered by the dataset in the original coordinates.
1414
1415        Returns
1416        -------
1417        intervals : dict
1418            a nested dictionary where first-level keys are video ids, second-level keys are clip ids and
1419            values are lists of the intervals in `[start, end]` format
1420
1421        """
1422        counter = defaultdict(lambda: {})
1423        coordinates = self.input_store.get_original_coordinates()
1424        for coords in coordinates:
1425            l = self.input_store.get_clip_length_from_coords(coords)
1426            video_name = self.input_store.get_video_id(coords)
1427            clip_id = self.input_store.get_clip_id(coords)
1428            start, end = self.input_store.get_clip_start_end(coords)
1429            if clip_id not in counter[video_name]:
1430                counter[video_name][clip_id] = np.zeros(l)
1431            counter[video_name][clip_id][start:end] = 1
1432        result = {video_name: {} for video_name in counter}
1433        for video_name in counter:
1434            for clip_id in counter[video_name]:
1435                result[video_name][clip_id] = self._get_intervals_from_ind(
1436                    np.where(counter[video_name][clip_id])[0]
1437                )
1438        return result, self.ids
1439
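As a rough illustration (the video and clip ids below are made up), the returned structure might look like this:

    intervals, ids = dataset.get_intervals()
    # intervals == {"video_01": {"ind0": [[0, 1000]], "ind1": [[0, 400], [600, 1000]]}}

where each `[start, end]` pair is a half-open frame interval covered by the dataset.
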
1440    def get_unannotated_intervals(self, first_intervals=None) -> Dict:
1441        """Get a list of intervals in the original coordinates where there is no annotation.
1442
1443        Parameters
1444        ----------
1445        first_intervals : dict
1446            a nested dictionary where first-level keys are video ids, second-level keys are clip ids and
1447            values are lists of the intervals in `[start, end]` format. If provided, only the intersection with
1448            those intervals will be returned
1449
1450        Returns
1451        -------
1452        intervals : dict
1453            a nested dictionary where first-level keys are video ids, second-level keys are clip ids and
1454            values are lists of the intervals in `[start, end]` format
1455
1456        """
1457        counter_value = 2
1458        if first_intervals is None:
1459            first_intervals = defaultdict(lambda: defaultdict(lambda: []))
1460            counter_value = 1
1461        counter = defaultdict(lambda: {})
1462        coordinates = self.input_store.get_original_coordinates()
1463        for i, coords in enumerate(coordinates):
1464            l = self.input_store.get_clip_length_from_coords(coords)
1465            ann = self.annotation_store[i]
1466            if (
1467                self.annotation_store.annotation_class()
1468                == "nonexclusive_classification"
1469            ):
1470                ann = ann[0, :]
1471            video_name = self.input_store.get_video_id(coords)
1472            clip_id = self.input_store.get_clip_id(coords)
1473            start, end = self.input_store.get_clip_start_end(coords)
1474            if clip_id not in counter[video_name]:
1475                counter[video_name][clip_id] = np.ones(l)
1476            counter[video_name][clip_id][start:end] = (ann[: end - start] == -100).int()
1477        result = {video_name: {} for video_name in counter}
1478        for video_name in counter:
1479            for clip_id in counter[video_name]:
1480                for start, end in first_intervals[video_name][clip_id]:
1481                    counter[video_name][clip_id][start:end] += 1
1482                result[video_name][clip_id] = self._get_intervals_from_ind(
1483                    np.where(counter[video_name][clip_id] == counter_value)[0]
1484                )
1485        return result
1486
1487    def get_annotated_intervals(self) -> Dict:
1488        """Get a list of intervals in the original coordinates where there is no annotation.
1489
1490        Returns
1491        -------
1492        intervals : dict
1493            a nested dictionary where first-level keys are video ids, second-level keys are clip ids and
1494            values are lists of the intervals in `[start, end]` format
1495
1496        """
1497        if self.annotation_type == "none":
1498            return {}
1499        counter_value = 1
1500        counter = defaultdict(lambda: {})
1501        coordinates = self.input_store.get_original_coordinates()
1502        for i, coords in enumerate(coordinates):
1503            l = self.input_store.get_clip_length_from_coords(coords)
1504            ann = self.annotation_store[i]
1505            video_name = self.input_store.get_video_id(coords)
1506            clip_id = self.input_store.get_clip_id(coords)
1507            start, end = self.input_store.get_clip_start_end(coords)
1508            if clip_id not in counter[video_name]:
1509                counter[video_name][clip_id] = np.zeros(l)
1510            if (
1511                self.annotation_store.annotation_class()
1512                == "nonexclusive_classification"
1513            ):
1514                counter[video_name][clip_id][start:end] = (
1515                    torch.sum(ann[:, : end - start] != -100, dim=0) > 0
1516                ).int()
1517            else:
1518                counter[video_name][clip_id][start:end] = (
1519                    ann[: end - start] != -100
1520                ).int()
1521        result = {video_name: {} for video_name in counter}
1522        for video_name in counter:
1523            for clip_id in counter[video_name]:
1524                result[video_name][clip_id] = self._get_intervals_from_ind(
1525                    np.where(counter[video_name][clip_id] == counter_value)[0]
1526                )
1527        return result
1528
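A short usage sketch of the two interval queries above (assuming `dataset` is a BehaviorDataset with loaded annotation):

    annotated = dataset.get_annotated_intervals()
    unannotated = dataset.get_unannotated_intervals()
    # both are nested {video_id: {clip_id: [[start, end], ...]}} dictionaries
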
1529    def get_ids(self) -> Dict:
1530        """Get a dictionary of all clip ids in the dataset.
1531
1532        Returns
1533        -------
1534        ids : dict
1535            a dictionary where keys are video ids and values are lists of clip ids
1536
1537        """
1538        coordinates = self.input_store.get_original_coordinates()
1539        video_ids = np.array(self.input_store.get_video_id_order())
1540        id_set = set(video_ids)
1541        result = {}
1542        for video_id in id_set:
1543            coords = coordinates[video_ids == video_id]
1544            clip_ids = list({self.input_store.get_clip_id(c) for c in coords})
1545            result[video_id] = clip_ids
1546        return result
1547
1548    def get_len(self, video_id: str, clip_id: str) -> int:
1549        """Get the length of a specific clip.
1550
1551        Parameters
1552        ----------
1553        video_id : str
1554            the video id
1555        clip_id : str
1556            the clip id
1557
1558        Returns
1559        -------
1560        length : int
1561            the length
1562
1563        """
1564        return self.input_store.get_clip_length(video_id, clip_id)
1565
1566    def get_confusion_matrix(
1567        self, prediction: torch.Tensor, confusion_type: str = "recall"
1568    ) -> Tuple[ndarray, list, Optional[str]]:
1569        """Get a confusion matrix.
1570
1571        Parameters
1572        ----------
1573        prediction : torch.Tensor
1574            a tensor of predicted class probabilities of shape `(#samples, #classes, #frames)`
1575        confusion_type : {"recall", "precision"}
1576            for datasets with non-exclusive annotation, if `confusion_type` is `"recall"`, only false positives are taken
1577            into account, and if `confusion_type` is `"precision"`, only false negatives
1578
1579        Returns
1580        -------
1581        confusion_matrix : np.ndarray
1582            a confusion matrix of shape `(#classes, #classes)` where `A[i, j] = F_ij/N_i`, `F_ij` is the number of
1583            frames that have the i-th label in the ground truth and a false positive j-th label in the prediction,
1584            `N_i` is the number of frames that have the i-th label in the ground truth
1585        classes : list
1586            a list of classes
            confusion_type : str | None
                the confusion type that was actually used (`None` for exclusive classification)
1587
1588        """
1589        behaviors_dict = self.annotation_store.behaviors_dict()
1590        num_behaviors = len(behaviors_dict)
1591        confusion_matrix = np.zeros((num_behaviors, num_behaviors))
1592        if self.annotation_store.annotation_class() == "exclusive_classification":
1593            exclusive = True
1594            confusion_type = None
1595        elif self.annotation_store.annotation_class() == "nonexclusive_classification":
1596            exclusive = False
1597        else:
1598            raise RuntimeError(
1599                f"The {self.annotation_store.annotation_class()} annotation class is not recognized!"
1600            )
1601        for ann, p in zip(self.annotation_store, prediction):
1602            if exclusive:
1603                class_prediction = torch.max(p, dim=0)[1]
1604                for i in behaviors_dict.keys():
1605                    for j in behaviors_dict.keys():
1606                        confusion_matrix[i, j] += int(
1607                            torch.sum(class_prediction[ann == i] == j)
1608                        )
1609            else:
1610                class_prediction = (p > 0.5).int()
1611                for i in behaviors_dict.keys():
1612                    for j in behaviors_dict.keys():
1613                        if confusion_type == "recall":
1614                            pred = deepcopy(class_prediction[j])
1615                            if i != j:
1616                                pred[ann[j] == 1] = 0
1617                            confusion_matrix[i, j] += int(torch.sum(pred[ann[i] == 1]))
1618                        elif confusion_type == "precision":
1619                            annotation = deepcopy(ann[j])
1620                            if i != j:
1621                                annotation[class_prediction[j] == 1] = 0
1622                            confusion_matrix[i, j] += int(
1623                                torch.sum(annotation[class_prediction[i] == 1])
1624                            )
1625                        else:
1626                            raise ValueError(
1627                                f"The {confusion_type} type is not recognized; please choose from ['recall', 'precision']"
1628                            )
1629        counter = self.annotation_store.count_classes()
1630        for i in behaviors_dict.keys():
1631            if counter[i] != 0:
1632                if confusion_type == "recall" or confusion_type is None:
1633                    confusion_matrix[i, :] /= counter[i]
1634                else:
1635                    confusion_matrix[:, i] /= counter[i]
1636        return confusion_matrix, list(behaviors_dict.values()), confusion_type
1637
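A minimal usage sketch, assuming `prediction` is a probability tensor of shape `(#samples, #classes, #frames)` produced elsewhere (e.g. by a trained model):

    confusion, classes, used_type = dataset.get_confusion_matrix(prediction, confusion_type="recall")
    print(classes)
    print(confusion.round(2))
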
1638    def annotation_class(self) -> str:
1639        """Get the type of annotation ('exclusive_classification', 'nonexclusive_classification', more coming soon).
1640
1641        Returns
1642        -------
1643        annotation_class : str
1644            the type of annotation
1645
1646        """
1647        return self.annotation_store.annotation_class()
1648
1649    def set_normalization_stats(self, stats: Dict) -> None:
1650        """Set the stats to normalize data at runtime.
1651
1652        Parameters
1653        ----------
1654        stats : dict
1655            a nested dictionary where first-level keys are feature key names, second-level keys are 'mean' and 'std'
1656            and values are the statistics in `torch` tensors of shape `(#features, 1)`
1657
1658        """
1659        self.stats = stats
1660
1661    def get_min_max_frames(self, video_id) -> Tuple[Dict, Dict]:
1662        """Get the minimum and maximum frame numbers for each clip in a video.
1663
1664        Parameters
1665        ----------
1666        video_id : str
1667            the video id
1668
1669        Returns
1670        -------
1671        min_frames : dict
1672            a dictionary where keys are clip ids and values are the minimum frame numbers
1673        max_frames : dict
1674            a dictionary where keys are clip ids and values are the maximum frame numbers
1675
1676        """
1677        coords = self.input_store.get_original_coordinates()
1678        clips = set(
1679            [
1680                self.input_store.get_clip_id(c)
1681                for c in coords
1682                if self.input_store.get_video_id(c) == video_id
1683            ]
1684        )
1685        min_frames = {}
1686        max_frames = {}
1687        for clip in clips:
1688            start = self.input_store.get_clip_start(video_id, clip)
1689            end = start + self.input_store.get_clip_length(video_id, clip)
1690            min_frames[clip] = start
1691            max_frames[clip] = end - 1
1692        return min_frames, max_frames
1693
1694    def get_normalization_stats(self, skip_keys=None) -> Dict:
1695        """Get mean and standard deviation for each key.
1696
1697        Parameters
1698        ----------
1699        skip_keys : list, optional
1700            a list of keys to skip
1701
1702        Returns
1703        -------
1704        stats : dict
1705            a nested dictionary where first-level keys are feature key names, second-level keys are 'mean' and 'std'
1706            and values are the statistics in `torch` tensors of shape `(#features, 1)`
1707
1708        """
1709        stats = defaultdict(lambda: {})
1710        sums = defaultdict(lambda: 0)
1711        if skip_keys is None:
1712            skip_keys = []
1713        counter = defaultdict(lambda: 0)
1714        for sample in tqdm(self):
1715            for key, value in sample["input"].items():
1716                key_name = key.split("---")[0]
1717                if key_name not in skip_keys:
1718                    sums[key_name] += value[:, value.sum(0) != 0].sum(-1)
1719                counter[key_name] += torch.sum(value.sum(0) != 0)
1720        for key, value in sums.items():
1721            stats[key]["mean"] = (value / counter[key]).unsqueeze(-1)
1722        sums = defaultdict(lambda: 0)
1723        for sample in tqdm(self):
1724            for key, value in sample["input"].items():
1725                key_name = key.split("---")[0]
1726                if key_name not in skip_keys:
1727                    sums[key_name] += (
1728                        (value[:, value.sum(0) != 0] - stats[key_name]["mean"]) ** 2
1729                    ).sum(-1)
1730        for key, value in sums.items():
1731            stats[key]["std"] = np.sqrt(value.unsqueeze(-1) / counter[key])
1732        return stats
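A sketch of how the two normalization methods are typically combined (the "speed" key is a hypothetical feature name); this mirrors what `partition_train_test_val` does internally when `normalize=True`:

    stats = train_dataset.get_normalization_stats(skip_keys=["speed"])
    val_dataset.set_normalization_stats(stats)
    test_dataset.set_normalization_stats(stats)
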
class BehaviorDataset(Dataset, ABC):
  33class BehaviorDataset(Dataset, ABC):
  34    """A generalized dataset class.
  35
  36    Data and annotation are stored in separate InputStore and BehaviorStore objects; the dataset class
  37    manages their interactions.
  38    """
  39
  40    def __init__(
  41        self,
  42        data_type: str,
  43        annotation_type: str = "none",
  44        ssl_transformations: List = None,
  45        saved_data_path: str = None,
  46        input_store: InputStore = None,
  47        annotation_store: BehaviorStore = None,
  48        only_load_annotated: bool = False,
  49        recompute_annotation: bool = False,
  50        # mask: str = None,
  51        ids: List = None,
  52        **data_parameters,
  53    ) -> None:
  54        """Initialize a dataset.
  55
  56        Parameters
  57        ----------
  58        data_type : str
  59            the data type (see available types by running BehaviorDataset.data_types())
  60        annotation_type : str
  61            the annotation type (see available types by running BehaviorDataset.annotation_types())
  62        ssl_transformations : list
  63            a list of functions that take a sample dictionary as input and return an (ssl input, ssl target) tuple
  64        saved_data_path : str
  65            the path to a pre-computed pickled dataset
  66        input_store : InputStore
  67            a pre-computed input store
  68        annotation_store : BehaviorStore
  69            a precomputed annotation store
  70        only_load_annotated : bool
  71            if `True`, the input files that don't have a matching annotation file will be disregarded
  72        recompute_annotation : bool
  73            if `True`, the annotation will be recomputed even if a precomputed annotation store is provided
  74        ids : list
  75            a list of ids to load from the input store
  76        *data_parameters : dict
  77            parameters to initialize the input and annotation stores
  78
  79        """
  80        mask = None
  81        if len(data_parameters) == 0:
  82            recompute_annotation = False
  83        feature_extraction = data_parameters.get("feature_extraction")
  84        if feature_extraction is not None and not issubclass(
  85            options.input_stores[data_type],
  86            options.feature_extractors[feature_extraction].input_store_class,
  87        ):
  88            raise ValueError(
  89                f"The {feature_extraction} feature extractor does not work with "
  90                f"the {data_type} data type, please choose a suclass of "
  91                f"{options.feature_extractors[feature_extraction].input_store_class}"
  92            )
  93        if ssl_transformations is None:
  94            ssl_transformations = []
  95        self.ssl_transformations = ssl_transformations
  96        self.input_type = data_type
  97        self.annotation_type = annotation_type
  98        self.stats = None
  99        if mask is not None:
 100            with open(mask, "rb") as f:
 101                self.mask = pickle.load(f)
 102        else:
 103            self.mask = None
 104        self.ids = ids
 105        self.tag = None
 106        self.return_unlabeled = None
 107        # load saved key objects for annotation and input if they exist
 108        input_key_objects, annotation_key_objects = None, None
 109        if saved_data_path is not None:
 110            if os.path.exists(saved_data_path):
 111                with open(saved_data_path, "rb") as f:
 112                    input_key_objects, annotation_key_objects = pickle.load(f)
 113        # if the input or the annotation store need to be created, generate the common video order
 114        if len(data_parameters) > 0:
 115            input_files = options.input_stores[data_type].get_file_ids(
 116                **data_parameters
 117            )
 118            annotation_files = options.annotation_stores[annotation_type].get_file_ids(
 119                **data_parameters
 120            )
 121            if only_load_annotated:
 122                data_parameters["video_order"] = [
 123                    x for x in input_files if x in annotation_files
 124                ]
 125            else:
 126                data_parameters["video_order"] = input_files
 127            if len(data_parameters["video_order"]) == 0:
 128                raise RuntimeError(
 129                    "The length of file list is 0! Please check your data parameters!"
 130                )
 131        data_parameters["mask"] = self.mask
 132        # load or create the input store
 133        ok = False
 134        if input_store is not None:
 135            self.input_store = input_store
 136            ok = True
 137        elif input_key_objects is not None:
 138            try:
 139                self.input_store = self._load_input_store(data_type, input_key_objects)
 140                ok = True
 141            except:
 142                warnings.warn("Loading input store from key objects failed")
 143        if not ok:
 144            self.input_store = self._get_input_store(
 145                data_type, deepcopy(data_parameters)
 146            )
 147        # get the objects needed to create the annotation store (like a clip length dictionary)
 148        annotation_objects = self.input_store.get_annotation_objects()
 149        data_parameters.update(annotation_objects)
 150        # load or create the annotation store
 151        ok = False
 152        if annotation_store is not None:
 153            self.annotation_store = annotation_store
 154            ok = True
 155        elif (
 156            (annotation_key_objects is not None)
 157            and mask is None
 158            and not recompute_annotation
 159        ):
 160            if len(annotation_key_objects) > 0:
 161                try:
 162                    self.annotation_store = self._load_annotation_store(
 163                        annotation_type, annotation_key_objects
 164                    )
 165                    ok = True
 166                except:
 167                    warnings.warn("Loading annotation store from key objects failed")
 168        if not ok:
 169            self.annotation_store = self._get_annotation_store(
 170                annotation_type, deepcopy(data_parameters)
 171            )
 172        to_remove = self.annotation_store.filtered_indices()
 173        if len(to_remove) > 0:
 174            print(
 175                f"Filtering {100 * len(to_remove) / len(self.annotation_store):.2f}% of samples"
 176            )
 177        if len(to_remove) == len(self.annotation_store) and len(to_remove) > 0:
 178            raise ValueError("All samples were filtered out!")
 179
 180        if len(self.input_store) == len(self.annotation_store):
 181            self.input_store.remove(to_remove)
 182        self.annotation_store.remove(to_remove)
 183        self.input_indices = list(range(len(self.input_store)))
 184        self.annotation_indices = list(range(len(self.input_store)))
 185        self.indices = list(range(len(self.input_store)))
 186
 187    def __getitem__(self, item: int) -> Dict:
 188        idx = self._get_idx(item)
 189        input = deepcopy(self.input_store[idx])
 190        target = self.annotation_store[idx]
 191        tag = self.input_store.get_tag(idx)
 192        ssl_inputs, ssl_targets = self._get_SSL_targets(input)
 193        batch = {"input": input}
 194        for name, x in zip(
 195            ["target", "ssl_inputs", "ssl_targets", "tag"],
 196            [target, ssl_inputs, ssl_targets, tag],
 197        ):
 198            if x is not None:
 199                batch[name] = x
 200        batch["index"] = idx
 201        if self.stats is not None:
 202            for key in batch["input"].keys():
 203                key_name = key.split("---")[0]
 204                if key_name in self.stats:
 205                    batch["input"][key][:, batch["input"][key].sum(0) != 0] = (
 206                        (batch["input"][key] - self.stats[key_name]["mean"])
 207                        / (self.stats[key_name]["std"] + 1e-7)
 208                    )[:, batch["input"][key].sum(0) != 0]
 209        return batch
 210
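A sample drawn from the dataset is a dictionary; a short sketch of inspecting one (assuming `dataset` is an initialized BehaviorDataset):

    sample = dataset[0]
    print(sample["input"].keys())   # feature tensors keyed by name
    print(sample["index"])          # index into the underlying stores
    print(sample.get("target"))     # present when annotation is available
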
 211    def __len__(self) -> int:
 212        return len(self.indices)
 213        # if self.annotation_type != "none":
 214        #     return self.annotation_store.get_len(return_unlabeled=self.return_unlabeled)
 215        # else:
 216        #     return len(self.input_store)
 217
 218    def get_tags(self) -> List:
 219        """Get a list of all meta tags.
 220
 221        Returns
 222        -------
 223        tags: List
 224            a list of unique meta tag values
 225
 226        """
 227        return self.input_store.get_tags()
 228
 229    def save(self, save_path: str) -> None:
 230        """Save the dictionary.
 231
 232        Parameters
 233        ----------
 234        save_path : str
 235            the path where the pickled file will be stored
 236
 237        """
 238        input_obj = self.input_store.key_objects()
 239        annotation_obj = self.annotation_store.key_objects()
 240        with open(save_path, "wb") as f:
 241            pickle.dump((input_obj, annotation_obj), f)
 242
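A minimal sketch of saving the key objects and re-creating an equivalent dataset from them later (the path is a placeholder):

    dataset.save("dataset_checkpoint.pickle")
    reloaded = BehaviorDataset(
        data_type=dataset.input_type,
        annotation_type=dataset.annotation_type,
        saved_data_path="dataset_checkpoint.pickle",
    )
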
 243    def to_ram(self) -> None:
 244        """Transfer the dataset to RAM."""
 245        self.input_store.to_ram()
 246        self.annotation_store.to_ram()
 247
 248    def generate_full_length_gt(self) -> Dict:
 249        """Generate full-length ground truth from the annotations.
 250
 251        Returns
 252        -------
 253        full_length_gt : dict
 254            a nested dictionary where first-level keys are video ids, second-level keys are clip ids and
 255            values are the ground truth labels
 256
 257        """
 258        if self.annotation_class() == "exclusive_classification":
 259            gt = torch.zeros((len(self), self.len_segment()))
 260        else:
 261            gt = torch.zeros(
 262                (len(self), len(self.behaviors_dict()), self.len_segment())
 263            )
 264        for i in range(len(self)):
 265            gt[i] = self.annotation_store[i]
 266        return self.generate_full_length_prediction(gt)
 267
 268    def generate_full_length_prediction(self, predicted: torch.Tensor) -> Dict:
 269        """Map predictions for the equal-length pieces to predictions for the original data.
 270
 271        Probabilities are averaged over predictions on overlapping intervals.
 272
 273        Parameters
 274        ----------
 275        predicted: torch.Tensor
 276            a tensor of predicted probabilities of shape `(N, #classes, #frames)`
 277
 278        Returns
 279        -------
 280        full_length_prediction : dict
 281            a nested dictionary where first-level keys are video ids, second-level keys are clip ids and values are
 282            averaged probability tensors
 283
 284        """
 285        result = defaultdict(lambda: {})
 286        counter = defaultdict(lambda: {})
 287        coordinates = self.input_store.get_original_coordinates()
 288        for coords, prediction in zip(coordinates, predicted):
 289            l = self.input_store.get_clip_length_from_coords(coords)
 290            video_name = self.input_store.get_video_id(coords)
 291            clip_id = self.input_store.get_clip_id(coords)
 292            start, end = self.input_store.get_clip_start_end(coords)
 293            if clip_id not in result[video_name].keys():
 294                result[video_name][clip_id] = torch.zeros(*prediction.shape[:-1], l)
 295                counter[video_name][clip_id] = torch.zeros(*prediction.shape[:-1], l)
 296            result[video_name][clip_id][..., start:end] += (
 297                prediction.squeeze()[..., : end - start].detach().cpu()
 298            )
 299            counter[video_name][clip_id][..., start:end] += 1
 300        for video_name in result:
 301            for clip_id in result[video_name]:
 302                result[video_name][clip_id] /= counter[video_name][clip_id]
 303                result[video_name][clip_id][counter[video_name][clip_id] == 0] = -100
 304        result = dict(result)
 305        return result
 306
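A usage sketch, assuming `predicted` is a probability tensor of shape `(len(dataset), #classes, #frames)` collected from a model over the whole dataset:

    full_length = dataset.generate_full_length_prediction(predicted)
    for video_id, clips in full_length.items():
        for clip_id, probabilities in clips.items():
            print(video_id, clip_id, probabilities.shape)
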
 307    def find_valleys(
 308        self,
 309        predicted: Union[torch.Tensor, Dict],
 310        threshold: float = 0.5,
 311        min_frames: int = 0,
 312        visibility_min_score: float = 0,
 313        visibility_min_frac: float = 0,
 314        main_class: int = 1,
 315        low: bool = True,
 316        predicted_error: torch.Tensor = None,
 317        error_threshold: float = 0.5,
 318        hysteresis: bool = False,
 319        threshold_diff: float = None,
 320        min_frames_error: int = None,
 321        smooth_interval: int = 1,
 322        cut_annotated: bool = False,
 323    ) -> Dict:
 324        """Find the intervals where the probability of a certain class is below or above a certain hard_threshold.
 325
 326        Parameters
 327        ----------
 328        predicted : torch.Tensor | dict
 329            either a tensor of predictions for the data prompts or the output of
 330            `BehaviorDataset.generate_full_length_prediction`
 331        threshold : float, default 0.5
  332            the main threshold
 333        min_frames : int, default 0
 334            the minimum length of the intervals
 335        visibility_min_score : float, default 0
 336            the minimum visibility score in the intervals
 337        visibility_min_frac : float, default 0
  338            the fraction of the interval that has to have a visibility score larger than visibility_min_score
 339        main_class : int, default 1
 340            the index of the class the function is inspecting
 341        low : bool, default True
  342            if True, the probability in the intervals has to be below the threshold, and if False, it has to be above
 343        predicted_error : torch.Tensor, optional
 344            a tensor of error predictions for the data prompts
 345        error_threshold : float, default 0.5
 346            maximum possible probability of error at the intervals
  347        hysteresis : bool, default False
  348            if True, the function will apply hysteresis thresholding with the soft threshold defined by threshold_diff
  349        threshold_diff : float, optional
  350            the difference between the soft and hard thresholds if hysteresis is used; if hysteresis is True, low is False and threshold_diff is None, the soft threshold condition is that main_class has a higher probability than the other classes
  351        min_frames_error : int, optional
  352            if not None, intervals are only considered where the error probability stays below error_threshold for at least min_frames_error consecutive frames
  353        smooth_interval : int, default 1
  354            the number of frames to smooth the predictions over
  355        cut_annotated : bool, default False
 356            if `True`, annotated intervals will be cut out of the predicted intervals
 357
 358        Returns
 359        -------
 360        valleys : dict
 361            a dictionary where keys are video ids and values are lists of (start, end, individual name) tuples that denote the chosen intervals
 362
 363        """
 364        result = defaultdict(lambda: [])
 365        if type(predicted) is not dict:
 366            predicted = self.generate_full_length_prediction(predicted)
 367        if predicted_error is not None:
 368            predicted_error = self.generate_full_length_prediction(predicted_error)
 369        elif min_frames_error is not None and min_frames_error != 0:
 370            # warnings.warn(
 371            #     f"The min_frames_error parameter is set to {min_frames_error} but no error prediction "
 372            #     f"is given! Setting min_frames_error to 0."
 373            # )
 374            min_frames_error = 0
 375        if low and hysteresis and threshold_diff is None:
 376            raise ValueError(
 377                "Cannot set low=True, hysteresis=True and threshold_diff=None! Please set threshold_diff."
 378            )
 379        if cut_annotated:
 380            masked_intervals_dict = self.get_annotated_intervals()
 381        else:
 382            masked_intervals_dict = None
 383        print("Valleys found:")
 384        for v_id in predicted:
 385            for clip_id in predicted[v_id].keys():
 386                if predicted_error is not None:
 387                    error_mask = predicted_error[v_id][clip_id][1, :] < error_threshold
 388                    if min_frames_error is not None:
 389                        output, indices, counts = torch.unique_consecutive(
 390                            error_mask, return_inverse=True, return_counts=True
 391                        )
 392                        wrong_indices = torch.where(
 393                            output * (counts < min_frames_error)
 394                        )[0]
 395                        if len(wrong_indices) > 0:
 396                            for i in wrong_indices:
 397                                error_mask[indices == i] = False
 398                else:
 399                    error_mask = None
 400                if masked_intervals_dict is not None:
 401                    masked_intervals = masked_intervals_dict[v_id][clip_id]
 402                else:
 403                    masked_intervals = None
 404                if not hysteresis:
 405                    res_indices_start, res_indices_end = apply_threshold(
 406                        predicted[v_id][clip_id][main_class, :],
 407                        threshold,
 408                        low,
 409                        error_mask,
 410                        min_frames,
 411                        smooth_interval,
 412                        masked_intervals,
 413                    )
 414                elif threshold_diff is not None:
 415                    if low:
 416                        soft_threshold = threshold + threshold_diff
 417                    else:
 418                        soft_threshold = threshold - threshold_diff
 419                    res_indices_start, res_indices_end = apply_threshold_hysteresis(
 420                        predicted[v_id][clip_id][main_class, :],
 421                        soft_threshold,
 422                        threshold,
 423                        low,
 424                        error_mask,
 425                        min_frames,
 426                        smooth_interval,
 427                        masked_intervals,
 428                    )
 429                else:
 430                    res_indices_start, res_indices_end = apply_threshold_max(
 431                        predicted[v_id][clip_id],
 432                        threshold,
 433                        main_class,
 434                        error_mask,
 435                        min_frames,
 436                        smooth_interval,
 437                        masked_intervals,
 438                    )
 439                start = self.input_store.get_clip_start(v_id, clip_id)
 440                result[v_id] += [
 441                    [i + start, j + start, clip_id]
 442                    for i, j in zip(res_indices_start, res_indices_end)
 443                    if self.input_store.get_visibility(
 444                        v_id, clip_id, i, j, visibility_min_score
 445                    )
 446                    > visibility_min_frac
 447                ]
 448            result[v_id] = sorted(result[v_id])
 449            print(f"    {v_id}: {len(result[v_id])}")
 450        return dict(result)
 451
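A usage sketch of searching for intervals where a hypothetical class with index 2 has low predicted probability (the threshold, minimum length and class index are arbitrary):

    valleys = dataset.find_valleys(
        predicted,        # a prediction tensor or the output of generate_full_length_prediction
        threshold=0.4,
        min_frames=15,
        main_class=2,
        low=True,
    )
    for video_id, intervals in valleys.items():
        print(video_id, intervals[:3])   # [start, end, clip_id] entries
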
 452    def valleys_union(self, valleys_list) -> Dict:
 453        """Find the intersection of two valleys dictionaries.
 454
 455        Parameters
 456        ----------
 457        valleys_list : list
 458            a list of valleys dictionaries
 459
 460        Returns
 461        -------
  462        union : dict
  463            a new valleys dictionary with the union of the input intervals
 464
 465        """
 466        valleys_list = [x for x in valleys_list if x is not None]
 467        if len(valleys_list) == 1:
 468            return valleys_list[0]
 469        elif len(valleys_list) == 0:
 470            return {}
 471        union = {}
 472        keys_list = [set(valleys.keys()) for valleys in valleys_list]
 473        keys = set.union(*keys_list)
 474        for v_id in keys:
 475            res = []
 476            clips_list = [
 477                set([x[-1] for x in valleys[v_id]]) for valleys in valleys_list
 478            ]
 479            clips = set.union(*clips_list)
 480            for clip_id in clips:
 481                clip_intervals = [
 482                    x
 483                    for valleys in valleys_list
 484                    for x in valleys[v_id]
 485                    if x[-1] == clip_id
 486                ]
 487                v_len = self.input_store.get_clip_length(v_id, clip_id)
 488                arr = torch.zeros(v_len)
 489                for start, end, _ in clip_intervals:
 490                    arr[start:end] += 1
 491                output, indices, counts = torch.unique_consecutive(
 492                    arr > 0, return_inverse=True, return_counts=True
 493                )
 494                long_indices = torch.where(output)[0]
 495                res += [
 496                    (
 497                        (indices == i).nonzero(as_tuple=True)[0][0].item(),
 498                        (indices == i).nonzero(as_tuple=True)[0][-1].item(),
 499                        clip_id,
 500                    )
 501                    for i in long_indices
 502                ]
 503            union[v_id] = res
 504        return union
 505
 506    def valleys_intersection(self, valleys_list) -> Dict:
 507        """Find the intersection of two valleys dictionaries.
 508
 509        Parameters
 510        ----------
 511        valleys_list : list
 512            a list of valleys dictionaries
 513
 514        Returns
 515        -------
 516        intersection : dict
 517            a new valleys dictionary with the intersection of the input intervals
 518
 519        """
 520        valleys_list = [x for x in valleys_list if x is not None]
 521        if len(valleys_list) == 1:
 522            return valleys_list[0]
 523        elif len(valleys_list) == 0:
 524            return {}
 525        intersection = {}
 526        keys_list = [set(valleys.keys()) for valleys in valleys_list]
 527        keys = set.intersection(*keys_list)
 528        for v_id in keys:
 529            res = []
 530            clips_list = [
 531                set([x[-1] for x in valleys[v_id]]) for valleys in valleys_list
 532            ]
 533            clips = set.intersection(*clips_list)
 534            for clip_id in clips:
 535                clip_intervals = [
 536                    x
 537                    for valleys in valleys_list
 538                    for x in valleys[v_id]
 539                    if x[-1] == clip_id
 540                ]
 541                v_len = self.input_store.get_clip_length(v_id, clip_id)
 542                arr = torch.zeros(v_len)
 543                for start, end, _ in clip_intervals:
 544                    arr[start:end] += 1
 545                output, indices, counts = torch.unique_consecutive(
 546                    arr, return_inverse=True, return_counts=True
 547                )
 548                long_indices = torch.where(output == 2)[0]
 549                res += [
 550                    (
 551                        (indices == i).nonzero(as_tuple=True)[0][0].item(),
 552                        (indices == i).nonzero(as_tuple=True)[0][-1].item(),
 553                        clip_id,
 554                    )
 555                    for i in long_indices
 556                ]
 557            intersection[v_id] = res
 558        return intersection
 559
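Both methods above take a list of valleys dictionaries such as those returned by `find_valleys`; a short sketch, with `valleys_a` and `valleys_b` assumed to come from two different predictions:

    merged = dataset.valleys_union([valleys_a, valleys_b])
    common = dataset.valleys_intersection([valleys_a, valleys_b])
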
 560    def partition_train_test_val(
 561        self,
 562        use_test: float = 0,
 563        split_path: str = None,
 564        method: str = "random",
 565        val_frac: float = 0,
 566        test_frac: float = 0,
 567        save_split: bool = False,
 568        normalize: bool = False,
 569        skip_normalization_keys: List = None,
 570        stats: Dict = None,
 571    ) -> Tuple:
 572        """Partition the dataset into three new datasets.
 573
 574        Parameters
 575        ----------
 576        use_test : float, default 0
 577            The fraction of the test dataset to be used in training without labels
 578        split_path : str, optional
 579            The path to load the split information from (if `'file'` method is used) and to save it to
 580            (if `'save_split'` is `True`)
 581        method : {'random', 'random:test-from-name', 'random:test-from-name:{name}',
 582            'val-from-name:{val_name}:test-from-name:{test_name}',
 583            'random:equalize:segments', 'random:equalize:videos',
 584            'folders', 'time', 'time:strict', 'file'}
 585            The partitioning method:
 586            - `'random'`: sort videos into subsets randomly,
 587            - `'random:test-from-name'` (or `'random:test-from-name:{name}'`): sort videos into training and validation
 588                subsets randomly and create
  589                the test subset from the video ids that start with a specific substring (`'test'` by default, or `name`
 590                if provided),
 591            - `'random:equalize:segments'` and `'random:equalize:videos'`: sort videos into subsets randomly but
 592                making sure that for the rarest classes at least `0.8 * val_frac` of the videos/segments that contain
 593                occurrences of the class get into the validation subset and `0.8 * test_frac` get into the test subset;
  594                this is ensured for all classes in order of increasing number of occurrences until the validation and
 595                test subsets are full
 596            - `'val-from-name:{val_name}:test-from-name:{test_name}'`: create the validation and test
 597                subsets from the video ids that start with specific substrings (`val_name` for validation
 598                and `test_name` for test) and sort all other videos into the training subset
 599            - `'folders'`: read videos from folders named *test*, *train* and *val* into corresponding subsets,
 600            - `'time'`: split each video into training, validation and test subsequences,
 601            - `'time:strict'`: split each video into validation, test and training subsequences
 602                and throw out the last segments in validation and test (to get rid of overlaps),
 603            - `'file'`: split according to a split file.
 604        val_frac : float, default 0
 605            The fraction of the dataset to be used in validation
 606        test_frac : float, default 0
 607            The fraction of the dataset to be used in test
 608        save_split : bool, default False
 609            Save a split file if True
 610        normalize : bool, default False
 611            Normalize the dataset if `True`
 612        skip_normalization_keys : list, optional
 613            A list of keys to skip normalization for
 614        stats : dict, optional
 615            A dictionary of (pre-computed) statistics to use for normalization
 616
 617        Returns
 618        -------
  619        train_dataset : BehaviorDataset
  620            the training dataset
  621        test_dataset : BehaviorDataset
  622            the test dataset
  623        val_dataset : BehaviorDataset
  624            the validation dataset
 625
 626        """
 627        train_indices, test_indices, val_indices = self._partition_indices(
 628            split_path=split_path,
 629            method=method,
 630            val_frac=val_frac,
 631            test_frac=test_frac,
 632            save_split=save_split,
 633        )
 634        ssl_indices = None
 635        partition_method = method.split(":")
 636        if (
 637            partition_method[0] in ("leave-one-in", "leave-n-in")
  638            and len(partition_method) > 2
 639            and partition_method[2] == "val-for-ssl"
 640        ):
 641            print("Using validation samples for SSL!")
 642            ssl_indices = val_indices
 643
 644        val_dataset = self._create_new_dataset(val_indices)
 645        test_dataset = self._create_new_dataset(test_indices)
 646        train_dataset = self._create_new_dataset(train_indices, ssl_indices=ssl_indices)
 647
 648        train_classes = train_dataset.count_classes()
 649        val_classes = val_dataset.count_classes()
 650        test_classes = test_dataset.count_classes()
 651        print("Number of samples:")
 652        print(f"    validation:")
 653        print(f"      {[f'{k}: {val_classes[k]}' for k in sorted(val_classes.keys())]}")
 654        print(f"    training:")
 655        print(f"      {[f'{k}: {train_classes[k]}' for k in sorted(train_classes.keys())]}")
 656        print(f"    test:")
 657        print(f"      {[f'{k}: {test_classes[k]}' for k in sorted(test_classes.keys())]}")
 658        if normalize:
 659            if stats is None:
 660                print("Computing normalization statistics...")
 661                stats = train_dataset.get_normalization_stats(skip_normalization_keys)
 662            else:
 663                print("Setting loaded normalization statistics...")
 664            train_dataset.set_normalization_stats(stats)
 665            val_dataset.set_normalization_stats(stats)
 666            test_dataset.set_normalization_stats(stats)
 667        return train_dataset, test_dataset, val_dataset
 668
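A typical call might look like this (the fractions and the split file path are placeholders); note that the datasets are returned in train, test, validation order:

    train_ds, test_ds, val_ds = dataset.partition_train_test_val(
        method="random",
        val_frac=0.2,
        test_frac=0.1,
        save_split=True,
        split_path="split.txt",
        normalize=True,
    )
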
 669    def class_weights(self, proportional=False) -> List:
 670        """Calculate class weights in inverse proportion to number of samples.
 671
 672        Parameters
 673        ----------
  674        proportional : bool, default False
  675            If `True`, the count of the most common class is used as the numerator instead of the total number of samples
 676
 677        Returns
 678        -------
 679        weights: list
 680            a list of class weights
 681
 682        """
 683        items = sorted(
 684            [
 685                (k, v)
 686                for k, v in self.annotation_store.count_classes().items()
 687                if k != -100
 688            ]
 689        )
 690        if self.annotation_store.annotation_class() == "exclusive_classification":
 691            if not proportional:
 692                numerator = len(self.annotation_store)
 693            else:
 694                numerator = max([x[1] for x in items])
 695            weights = [numerator / (v + 1e-7) for _, v in items]
 696        else:
 697            items_zero = sorted(
 698                [
 699                    (k, v)
 700                    for k, v in self.annotation_store.count_classes(zeros=True).items()
 701                    if k != -100
 702                ]
 703            )
 704            if not proportional:
 705                numerators = defaultdict(lambda: len(self.annotation_store))
 706            else:
 707                numerators = {
 708                    item_one[0]: max(item_one[1], item_zero[1])
 709                    for item_one, item_zero in zip(items, items_zero)
 710                }
 711            weights = {}
 712            weights[0] = [numerators[k] / (v + 1e-7) for k, v in items_zero]
 713            weights[1] = [numerators[k] / (v + 1e-7) for k, v in items]
 714        return weights
 715
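A sketch of feeding the weights into a loss function; this assumes exclusive classification, where the method returns a plain list:

    import torch

    weights = torch.tensor(train_dataset.class_weights())
    criterion = torch.nn.CrossEntropyLoss(weight=weights, ignore_index=-100)
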
 716    def _boundary_class_weight(self):
 717        """Calculate the weight of the boundary class.
 718
 719        Returns
 720        -------
 721        weight: float
 722            the weight of the boundary class
 723
 724        """
 725        if self.annotation_type != "none":
 726            f = self.annotation_store.data.flatten()
 727            _, inv = torch.unique_consecutive(f, return_inverse=True)
 728            boundary = torch.cat([torch.tensor([0]), torch.diff(inv)]).reshape(
 729                self.annotation_store.data.shape
 730            )
 731            boundary[..., 0] = 0
 732            cnt = Counter(boundary.flatten().numpy())
 733            return cnt[1] / cnt[0]
 734        else:
 735            return 0
 736
 737    def count_classes(self, bouts: bool = False) -> Dict:
 738        """Get a class counter dictionary.
 739
 740        Parameters
 741        ----------
 742        bouts : bool, default False
 743            if `True`, instead of frame counts segment counts are returned
 744
 745        Returns
 746        -------
 747        count_dictionary : dict
 748            a dictionary with class indices as keys and frame or bout counts as values
 749
 750        """
 751        return self.annotation_store.count_classes(bouts=bouts)
 752
 753    def behaviors_dict(self) -> Dict:
 754        """Get a behavior dictionary.
 755
 756        Returns
 757        -------
 758        dict
 759            behavior dictionary
 760
 761        """
 762        return self.annotation_store.behaviors_dict()
 763
 764    def bodyparts_order(self) -> List:
 765        """Get the order of bodyparts.
 766
 767        Returns
 768        -------
 769        bodyparts : List
 770            a list of bodyparts
 771
 772        """
 773        try:
 774            return self.input_store.get_bodyparts()
 775        except:
 776            raise RuntimeError(
 777                f"The {self.input_type} input store does not have bodyparts implemented!"
 778            )
 779
 780    def features_shape(self) -> Dict:
 781        """Get the shapes of the input features.
 782
 783        Returns
 784        -------
 785        shapes : Dict
 786            a dictionary with the shapes of the features
 787
 788        """
 789        sample = self.input_store[0]
 790        shapes = {k: v.shape for k, v in sample.items()}
 791        # for key, value in shapes.items():
 792        #     print(f'{key}: {value}')
 793        return shapes
 794
 795    def num_classes(self) -> int:
 796        """Get the number of classes in the data.
 797
 798        Returns
 799        -------
 800        num_classes : int
 801            the number of classes
 802
 803        """
 804        return len(self.annotation_store.behaviors_dict())
 805
 806    def len_segment(self) -> int:
 807        """Get the segment length in the data.
 808
 809        Returns
 810        -------
 811        len_segment : int
 812            the segment length
 813
 814        """
 815        sample = self.input_store[0]
 816        key = list(sample.keys())[0]
 817        return sample[key].shape[-1]
 818
 819    def set_ssl_transformations(self, ssl_transformations: List) -> None:
 820        """Set new SSL transformations.
 821
 822        Parameters
 823        ----------
 824        ssl_transformations : list
 825            a list of functions that take a sample feature dictionary as input and output ssl_inputs and ssl_targets
 826            lists
 827
 828        """
 829        self.ssl_transformations = ssl_transformations
 830
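    # Minimal sketch of an SSL transformation (a hypothetical example, not part of
    # the library): each callable receives a sample feature dictionary and returns
    # an (ssl input, ssl target) pair, e.g. a simple frame-masking objective:
    #
    #     def mask_second_half(sample: Dict) -> Tuple[Dict, Dict]:
    #         """Hide the second half of every feature; the target is the full sample."""
    #         ssl_target = {k: v.clone() for k, v in sample.items()}
    #         ssl_input = {k: v.clone() for k, v in sample.items()}
    #         for v in ssl_input.values():
    #             v[..., v.shape[-1] // 2 :] = 0
    #         return ssl_input, ssl_target
    #
    #     dataset.set_ssl_transformations([mask_second_half])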
 831    @classmethod
 832    def new(cls, *args, **kwargs):
 833        """Create a new object of the same class.
 834
 835        Parameters
 836        ----------
 837        args : list
 838            arguments for the constructor
 839        kwargs : dict
 840            keyword arguments for the constructor
 841
 842        Returns
 843        -------
 844        new_instance: BehaviorDataset
 845            a new instance of the same class
 846
 847        """
 848        return cls(*args, **kwargs)
 849
 850    @classmethod
 851    def get_parameters(cls, data_type: str, annotation_type: str) -> List:
 852        """Get parameters necessary for initialization.
 853
 854        Parameters
 855        ----------
 856        data_type : str
 857            the data type
 858        annotation_type : str
 859            the annotation type
 860
 861        Returns
 862        -------
 863        parameters : list
 864            a list of parameters
 865
 866        """
 867        input_features = options.input_stores[data_type].get_parameters()
 868        annotation_features = options.annotation_stores[
 869            annotation_type
 870        ].get_parameters()
 871        self_features = inspect.getfullargspec(cls.__init__).args
 872        return self_features + input_features + annotation_features
 873
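    # Usage sketch (the "dlc_track" and "csv" store names are assumptions; the
    # exact names come from the stores registered in `options`):
    #
    #     >>> BehaviorDataset.data_types()        # registered input store names
    #     >>> BehaviorDataset.annotation_types()  # registered annotation store names
    #     >>> BehaviorDataset.get_parameters("dlc_track", "csv")
    #     ['self', 'data_type', 'annotation_type', ...]  # constructor + store parameters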
 874    @staticmethod
 875    def data_types() -> List:
 876        """List available data types.
 877
 878        Returns
 879        -------
 880        data_types : list
 881            available data types
 882
 883        """
 884        return list(options.input_stores.keys())
 885
 886    @staticmethod
 887    def annotation_types() -> List:
 888        """List available annotation types.
 889
 890        Returns
 891        -------
 892        annotation_types : list
 893            available annotation types
 894
 895        """
 896        return list(options.annotation_stores.keys())
 897
 898    def _get_SSL_targets(self, input: Dict) -> Tuple[List, List]:
 899        """Get the SSL inputs and targets from a sample dictionary."""
 900        ssl_inputs = []
 901        ssl_targets = []
 902        for transform in self.ssl_transformations:
 903            ssl_input, ssl_target = transform(copy(input))
 904            ssl_inputs.append(ssl_input)
 905            ssl_targets.append(ssl_target)
 906        return ssl_inputs, ssl_targets
 907
 908    def _create_new_dataset(self, indices: List, ssl_indices: List = None):
 909        """Create a subsample of the dataset, with samples at ssl_indices losing the annotation."""
 910        if ssl_indices is None:
 911            ssl_indices = []
 912        input_store = self.input_store.create_subsample(indices, ssl_indices)
 913        annotation_store = self.annotation_store.create_subsample(indices, ssl_indices)
 914        new = self.new(
 915            data_type=self.input_type,
 916            annotation_type=self.annotation_type,
 917            ssl_transformations=self.ssl_transformations,
 918            annotation_store=annotation_store,
 919            input_store=input_store,
 920            ids=list(indices) + list(ssl_indices),
 921            recompute_annotation=False,
 922        )
 923        return new
 924
 925    def _load_input_store(self, data_type: str, key_objects: Tuple) -> InputStore:
 926        """Load input store from key objects."""
 927        input_store = options.input_stores[data_type](key_objects=key_objects)
 928        return input_store
 929
 930    def _load_annotation_store(
 931        self, annotation_type: str, key_objects: Tuple
 932    ) -> BehaviorStore:
 933        """Load annotation store from key objects."""
 934        annotation_store = options.annotation_stores[annotation_type](
 935            key_objects=key_objects
 936        )
 937        return annotation_store
 938
 939    def _get_input_store(self, data_type: str, data_parameters: Dict) -> InputStore:
 940        """Create input store from parameters."""
 941        data_parameters["key_objects"] = None
 942        input_store = options.input_stores[data_type](**data_parameters)
 943        return input_store
 944
 945    def _get_annotation_store(
 946        self, annotation_type: str, data_parameters: Dict
 947    ) -> BehaviorStore:
 948        """Create annotation store from parameters."""
 949        annotation_store = options.annotation_stores[annotation_type](**data_parameters)
 950        return annotation_store
 951
 952    def set_indexing_parameters(self, unlabeled: bool, tag: int) -> None:
 953        """Set the parameters that change the subset that is returned at `__getitem__`.
 954
 955        Parameters
 956        ----------
 957        unlabeled : bool
 958            a pseudolabeling parameter; return only unlabeled samples if `True`, only labeled if `False` and
 959            all if `None`
 960        tag : int
 961            if not `None`, only samples with this meta tag will be returned
 962
 963        """
 964        if unlabeled != self.return_unlabeled:
 965            self.annotation_indices = self.annotation_store.get_indices(unlabeled)
 966            self.return_unlabeled = unlabeled
 967        if tag != self.tag:
 968            self.input_indices = self.input_store.get_indices(tag)
 969            self.tag = tag
 970        self.indices = [x for x in self.annotation_indices if x in self.input_indices]
 971
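    # Example (a hypothetical pseudolabeling setup): restrict __getitem__ to
    # unlabeled samples carrying meta tag 0, then restore the full view.
    #
    #     >>> dataset.set_indexing_parameters(unlabeled=True, tag=0)
    #     >>> len(dataset.indices)  # only unlabeled samples with tag 0 remain
    #     >>> dataset.set_indexing_parameters(unlabeled=None, tag=None)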
 972    def _get_idx(self, index: int) -> int:
 973        """Get index in full dataset."""
 974        return self.indices[index]
 975
 976        # return self.annotation_store.get_idx(
 977        #     index, return_unlabeled=self.return_unlabeled
 978        # )
 979
 980    def _partition_indices(
 981        self,
 982        split_path: str = None,
 983        method: str = "random",
 984        val_frac: float = 0,
 985        test_frac: float = 0,
 986        save_split: bool = False,
 987    ) -> Tuple[List, List, List]:
 988        """Partition indices into train, test and validation subsets (returned in that order)."""
 989        if self.mask is not None:
 990            val_indices = self.mask["val_ids"]
 991            train_indices = [x for x in range(len(self)) if x not in val_indices]
 992            test_indices = []
 993        elif method == "random":
 994            videos = np.array(self.input_store.get_video_id_order())
 995            all_videos = list(set(videos))
 996            if len(all_videos) == 1:
 997                warnings.warn(
 998                    "There is only one video in the dataset, so train/val/test split is done on segments; "
 999                    'that might lead to overlaps, please consider using "time" or "time:strict" as the '
1000                    "partitioning method instead"
1001                )
1002                # Quick fix for single video: the problem with this is that the segments can overlap
1003                # length = int(self.input_store.get_original_coordinates()[-1][1])    # number of segments
1004                length = len(self.input_store.get_original_coordinates())
1005                val_len = int(val_frac * length)
1006                test_len = int(test_frac * length)
1007                all_indices = np.random.choice(np.arange(length), length, replace=False)
1008                val_indices = all_indices[:val_len]
1009                test_indices = all_indices[val_len : val_len + test_len]
1010                train_indices = all_indices[val_len + test_len :]
1011                coords = self.input_store.get_original_coordinates()
1012                if save_split:
1013                    self._save_partition(
1014                        coords[train_indices],
1015                        coords[val_indices],
1016                        coords[test_indices],
1017                        split_path,
1018                        coords=True,
1019                    )
1020            else:
1021                length = len(all_videos)
1022                val_len = int(val_frac * length)
1023                test_len = int(test_frac * length)
1024                validation = all_videos[:val_len]
1025                test = all_videos[val_len : val_len + test_len]
1026                training = all_videos[val_len + test_len :]
1027                train_indices = np.where(np.isin(videos, training))[0]
1028                val_indices = np.where(np.isin(videos, validation))[0]
1029                test_indices = np.where(np.isin(videos, test))[0]
1030                if save_split:
1031                    self._save_partition(training, validation, test, split_path)
1032        elif method.startswith("random:equalize"):
1033            counter = self.count_classes()
1034            counter = sorted((v, k) for k, v in counter.items())
1035            classes = [x[1] for x in counter]
1036            indicator = {c: [] for c in classes}
1037            if method.endswith("videos"):
1038                videos = np.array(self.input_store.get_video_id_order())
1039                all_videos = list(set(videos))
1040                total_len = len(all_videos)
1041                for video_id in all_videos:
1042                    video_coords = np.where(videos == video_id)[0]
1043                    ann = torch.cat(
1044                        [self.annotation_store[i] for i in video_coords], dim=-1
1045                    )
1046                    for c in classes:
1047                        if self.annotation_class() == "nonexclusive_classification":
1048                            indicator[c].append(torch.sum(ann[c] == 1) > 0)
1049                        elif self.annotation_class() == "exclusive_classification":
1050                            indicator[c].append(torch.sum(ann == c) > 0)
1051                        else:
1052                            raise ValueError(
1053                                f"The random:equalize partition method is not implemented "
1054                                f"for the {self.annotation_class()} annotation class"
1055                            )
1056            elif method.endswith("segments"):
1057                total_len = len(self)
1058                for ann in self.annotation_store:
1059                    for c in classes:
1060                        if self.annotation_class() == "nonexclusive_classification":
1061                            indicator[c].append(torch.sum(ann[c] == 1) > 0)
1062                        elif self.annotation_class() == "exclusive_classification":
1063                            indicator[c].append(torch.sum(ann == c) > 0)
1064                        else:
1065                            raise ValueError(
1066                                f"The random:equalize partition method is not implemented "
1067                                f"for the {self.annotation_class()} annotation class"
1068                            )
1069            else:
1070                values = []
1071                for v in options.partition_methods.values():
1072                    values += v
1073                raise ValueError(
1074                    f"The {method} partition method is not recognized; please choose from {values}"
1075                )
1076            val_indices = []
1077            test_indices = []
1078            for c in classes:
1079                indicator[c] = np.array(indicator[c])
1080                ind = np.where(indicator[c])[0]
1081                np.random.shuffle(ind)
1082                c_sum = len(ind)
1083                in_val = np.sum(indicator[c][val_indices])
1084                in_test = np.sum(indicator[c][test_indices])
1085                while (
1086                    len(val_indices) < val_frac * total_len
1087                    and in_val < val_frac * c_sum * 0.8
1088                ):
1089                    first, ind = ind[0], ind[1:]
1090                    val_indices = list(set(val_indices).union({first}))
1091                    in_val = np.sum(indicator[c][val_indices])
1092                while (
1093                    len(test_indices) < test_frac * total_len
1094                    and in_test < test_frac * c_sum * 0.8
1095                ):
1096                    first, ind = ind[0], ind[1:]
1097                    test_indices = list(set(test_indices).union({first}))
1098                    in_test = np.sum(indicator[c][test_indices])
1099            if len(val_indices) < int(val_frac * total_len):
1100                left_val = int(val_frac * total_len) - len(val_indices)
1101            else:
1102                left_val = 0
1103            if len(test_indices) < int(test_frac * total_len):
1104                left_test = int(test_frac * total_len) - len(test_indices)
1105            else:
1106                left_test = 0
1107            indicator = np.ones(total_len)
1108            indicator[val_indices] = 0
1109            indicator[test_indices] = 0
1110            ind = np.where(indicator)[0]
1111            np.random.shuffle(ind)
1112            val_indices += list(ind[:left_val])
1113            test_indices += list(ind[left_val : left_val + left_test])
1114            train_indices = list(ind[left_val + left_test :])
1115            if save_split:
1116                if method.endswith("segments"):
1117                    coords = self.input_store.get_original_coordinates()
1118                    self._save_partition(
1119                        coords[train_indices],
1120                        coords[val_indices],
1121                        coords[test_indices],
1122                        split_path,
1123                        coords=True,
1124                    )
1125                else:
1126                    all_videos = np.array(all_videos)
1127                    validation = all_videos[val_indices]
1128                    test = all_videos[test_indices]
1129                    training = all_videos[train_indices]
1130                    self._save_partition(training, validation, test, split_path)
1131        elif method.startswith("random:test-from-name"):
1132            split = method.split(":")
1133            if len(split) > 2:
1134                test_name = split[-1]
1135            else:
1136                test_name = "test"
1137            videos = np.array(self.input_store.get_video_id_order())
1138            all_videos = list(set(videos))
1139            test = []
1140            train_videos = []
1141            for x in all_videos:
1142                if x.startswith(test_name):
1143                    test.append(x)
1144                else:
1145                    train_videos.append(x)
1146            length = len(train_videos)
1147            val_len = int(val_frac * length)
1148            validation = train_videos[:val_len]
1149            training = train_videos[val_len:]
1150            train_indices = np.where(np.isin(videos, training))[0]
1151            val_indices = np.where(np.isin(videos, validation))[0]
1152            test_indices = np.where(np.isin(videos, test))[0]
1153            if save_split:
1154                self._save_partition(training, validation, test, split_path)
1155        elif method.startswith("val-from-name"):
1156            split = method.split(":")
1157            if split[2] != "test-from-name":
1158                raise ValueError(
1159                    f"The {method} partition method is not recognized, please choose from {options.partition_methods}"
1160                )
1161            val_name = split[1]
1162            test_name = split[-1]
1163            videos = np.array(self.input_store.get_video_id_order())
1164            all_videos = list(set(videos))
1165            test = []
1166            validation = []
1167            training = []
1168            for x in all_videos:
1169                if x.startswith(test_name):
1170                    test.append(x)
1171                elif x.startswith(val_name):
1172                    validation.append(x)
1173                else:
1174                    training.append(x)
1175            train_indices = np.where(np.isin(videos, training))[0]
1176            val_indices = np.where(np.isin(videos, validation))[0]
1177            test_indices = np.where(np.isin(videos, test))[0]
1178        elif method == "folders":
1179            folders = np.array(self.input_store.get_folder_order())
1180            videos = np.array(self.input_store.get_video_id_order())
1181            train_indices = np.where(np.isin(folders, ["training", "train"]))[0]
1182            if np.sum(np.isin(folders, ["validation", "val"])) > 0:
1183                val_indices = np.where(np.isin(folders, ["validation", "val"]))[0]
1184            else:
1185                train_videos = list(set(videos[train_indices]))
1186                val_len = int(val_frac * len(train_videos))
1187                validation = train_videos[:val_len]
1188                training = train_videos[val_len:]
1189                train_indices = np.where(np.isin(videos, training))[0]
1190                val_indices = np.where(np.isin(videos, validation))[0]
1191            test_indices = np.where(folders == "test")[0]
1192            if save_split:
1193                self._save_partition(
1194                    list(set(videos[train_indices])),
1195                    list(set(videos[val_indices])),
1196                    list(set(videos[test_indices])),
1197                    split_path,
1198                )
1199        elif method.startswith("leave-one-out"):
1200            n = int(method.split(":")[-1])
1201            videos = np.array(self.input_store.get_video_id_order())
1202            all_videos = sorted(list(set(videos)))
1203            print(len(all_videos))
1204            validation = [all_videos.pop(n)]
1205            training = all_videos
1206            train_indices = np.where(np.isin(videos, training))[0]
1207            val_indices = np.where(np.isin(videos, validation))[0]
1208            test_indices = np.array([])
1209        elif method.startswith("leave-one-in"):
1210            n = int(method.split(":")[1])
1211            videos = np.array(self.input_store.get_video_id_order())
1212            all_videos = sorted(list(set(videos)))
1213            training = [all_videos.pop(n)]
1214            validation = all_videos
1215            train_indices = np.where(np.isin(videos, training))[0]
1216            val_indices = np.where(np.isin(videos, validation))[0]
1217            test_indices = np.array([])
1218        elif method.startswith("leave-n-in"):
1219            train_idx = [int(i) for i in method.split(":")[1].split(",")]
1220            videos = np.array(self.input_store.get_video_id_order())
1221            all_videos = sorted(list(set(videos)))
1222            training = [v for i, v in enumerate(all_videos) if i in train_idx]
1223            validation = [v for i, v in enumerate(all_videos) if i not in train_idx]
1224            train_indices = np.where(np.isin(videos, training))[0]
1225            val_indices = np.where(np.isin(videos, validation))[0]
1226            test_indices = np.array([])
1227        elif method.startswith("time"):
1228            if method.endswith("strict"):
1229                len_segment = self.len_segment()
1230                step = self.input_store.step
1231                num_removed = len_segment // step
1232            else:
1233                num_removed = 0
1234            videos = np.array(self.input_store.get_video_id_order())
1235            all_videos = set(videos)
1236            train_indices = []
1237            val_indices = []
1238            test_indices = []
1239            start = 0
1240            if len(method.split(":")) > 1 and method.split(":")[1] == "start-from":
1241                start = float(method.split(":")[2])
1242            for video_id in all_videos:
1243                video_indices = np.where(videos == video_id)[0]
1244                val_len = int(val_frac * len(video_indices))
1245                test_len = int(test_frac * len(video_indices))
1246                start_pos = int(start * len(video_indices))
1247                all_ind = np.ones(len(video_indices))
1248                val_indices += list(video_indices[start_pos : start_pos + val_len])
1249                all_ind[start_pos : start_pos + val_len] = 0
1250                if start_pos + val_len > len(video_indices):
1251                    p = start_pos + val_len - len(video_indices)
1252                    val_indices += list(video_indices[:p])
1253                    all_ind[:p] = 0
1254                else:
1255                    p = start_pos + val_len
1256                test_indices += list(video_indices[p : p + test_len])
1257                all_ind[p : p + test_len] = 0
1258                if p + test_len > len(video_indices):
1259                    p = test_len + p - len(video_indices)
1260                    test_indices += list(video_indices[:p])
1261                    all_ind[:p] = 0
1262                train_indices += list(video_indices[all_ind > 0])
1263                for _ in range(num_removed):
1264                    if len(val_indices) > 0:
1265                        val_indices.pop(-1)
1266                    if len(test_indices) > 0:
1267                        test_indices.pop(-1)
1268                    if start > 0 and len(train_indices) > 0:
1269                        train_indices.pop(-1)
1270        elif method == "file":
1271            if split_path is None:
1272                raise ValueError(
1273                    'You need to either set split_path or change partition method ("file" requires a file)'
1274                )
1275            active_list = None
1276            training, validation, test = [], [], []
1277            with open(split_path) as f:
1278                for line in f.readlines():
1279                    if line.startswith("Train"):
1280                        active_list = training
1281                    elif line.startswith("Valid"):
1282                        active_list = validation
1283                    elif line.startswith("Test"):
1284                        active_list = test
1285                    else:
1286                        stripped_line = line.rstrip(",\n ")
1287                        if stripped_line == "":
1288                            continue
1289                        if ", " in stripped_line:
1290                            active_list += stripped_line.split(", ")
1291                        else:
1292                            active_list.append(stripped_line)
1293            all_lines = training + validation + test
1294            if len(all_lines[0].split("---")) == 3:
1295                entry_type = "coords"
1296            else:
1297                entry_type = "videos"
1298
1299            if entry_type == "videos":
1300                videos = np.array(self.input_store.get_video_id_order())
1301                val_indices = np.where(np.isin(videos, validation))[0]
1302                test_indices = np.where(np.isin(videos, test))[0]
1303                train_indices = np.where(np.isin(videos, training))[0]
1304            elif entry_type == "coords":
1305                coords = self.input_store.get_original_coordinates()
1306                video_ids = self.input_store.get_video_id_order()
1307                clip_ids = [self.input_store.get_clip_id(coord) for coord in coords]
1308                starts, ends = zip(
1309                    *[self.input_store.get_clip_start_end(coord) for coord in coords]
1310                )
1311                coords = np.array(
1312                    [
1313                        f"{video_id}---{clip_id}---{start}-{end}"
1314                        for video_id, clip_id, start, end in zip(
1315                            video_ids, clip_ids, starts, ends
1316                        )
1317                    ]
1318                )
1319                val_indices = np.where(np.isin(coords, validation))[0]
1320                test_indices = np.where(np.isin(coords, test))[0]
1321                train_indices = np.where(np.isin(coords, training))[0]
1322            else:
1323                raise ValueError("The split path has unrecognized format!")
1324            all_indices = np.ones(len(self))
1325            if len(train_indices) == 0:
1326                all_indices[val_indices] = 0
1327                all_indices[test_indices] = 0
1328                train_indices = np.where(all_indices)[0]
1329            elif len(val_indices) == 0:
1330                all_indices[train_indices] = 0
1331                all_indices[test_indices] = 0
1332                val_indices = np.where(all_indices)[0]
1333            elif len(test_indices) == 0:
1334                all_indices[train_indices] = 0
1335                all_indices[val_indices] = 0
1336                test_indices = np.where(all_indices)[0]
1337        else:
1338            raise ValueError(
1339                f"The {method} partition is not recognized, please choose from {options.partition_methods}"
1340            )
1341        return sorted(train_indices), sorted(test_indices), sorted(val_indices)
1342
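    # Illustrative calls (a sketch; in practice this method is invoked indirectly
    # when the dataset is partitioned, and the recognized method strings are
    # listed in `options.partition_methods`; the split file name is illustrative).
    # Note the return order: train, test, val.
    #
    #     train, test, val = self._partition_indices(method="random", val_frac=0.2, test_frac=0.1)
    #     train, test, val = self._partition_indices(method="time:strict", val_frac=0.2)
    #     train, test, val = self._partition_indices(method="leave-one-out:3")
    #     train, test, val = self._partition_indices(method="file", split_path="split.txt")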
1343    def _save_partition(
1344        self,
1345        training: List,
1346        validation: List,
1347        test: List,
1348        split_path: str,
1349        coords: bool = False,
1350    ) -> None:
1351        """Save a split file."""
1352        if coords:
1353            name = "coords"
1354            training_coords = []
1355            val_coords = []
1356            test_coords = []
1357            for coord in training:
1358                video_id = self.input_store.get_video_id(coord)
1359                clip_id = self.input_store.get_clip_id(coord)
1360                start, end = self.input_store.get_clip_start_end(coord)
1361                training_coords.append(f"{video_id}---{clip_id}---{start}-{end}")
1362            for coord in validation:
1363                video_id = self.input_store.get_video_id(coord)
1364                clip_id = self.input_store.get_clip_id(coord)
1365                start, end = self.input_store.get_clip_start_end(coord)
1366                val_coords.append(f"{video_id}---{clip_id}---{start}-{end}")
1367            for coord in test:
1368                video_id = self.input_store.get_video_id(coord)
1369                clip_id = self.input_store.get_clip_id(coord)
1370                start, end = self.input_store.get_clip_start_end(coord)
1371                test_coords.append(f"{video_id}---{clip_id}---{start}-{end}")
1372            training, validation, test = training_coords, val_coords, test_coords
1373        else:
1374            name = "videos"
1375        if split_path is not None:
1376            with open(split_path, "w") as f:
1377                f.write(f"Training {name}:\n")
1378                for x in training:
1379                    f.write(x + "\n")
1380                f.write(f"Validation {name}:\n")
1381                for x in validation:
1382                    f.write(x + "\n")
1383                f.write(f"Test {name}:\n")
1384                for x in test:
1385                    f.write(x + "\n")
1386
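    # The split file written above (and read back by the "file" partition method)
    # is plain text; a minimal example with video-level entries (names illustrative):
    #
    #     Training videos:
    #     video_01
    #     video_02
    #     Validation videos:
    #     video_03
    #     Test videos:
    #     video_04
    #
    # With coords=True each entry instead has the form
    # "{video_id}---{clip_id}---{start}-{end}".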
1387    def _get_intervals_from_ind(self, frame_indices: np.ndarray):
1388        """Get a list of intervals from a list of frame indices.
1389
1390        Example: `[0, 1, 2, 5, 6, 8] -> [[0, 3], [5, 7], [8, 9]]`.
1391
1392        Parameters
1393        ----------
1394        frame_indices : np.ndarray
1395            a list of frame indices
1396
1397        Returns
1398        -------
1399        intervals : list
1400            a list of interval boundaries
1401
1402        """
1403        masked_intervals = []
1404        breaks = np.where(np.diff(frame_indices) != 1)[0]
1405        if len(frame_indices) > 0:
1406            start = frame_indices[0]
1407            for k in breaks:
1408                masked_intervals.append([start, frame_indices[k] + 1])
1409                start = frame_indices[k + 1]
1410            masked_intervals.append([start, frame_indices[-1] + 1])
1411        return masked_intervals
1412
1413    def get_intervals(self) -> Tuple[dict, Optional[list]]:
1414        """Get a list of intervals covered by the dataset in the original coordinates.
1415
1416        Returns
1417        -------
1418        intervals : dict
1419            a nested dictionary where first-level keys are video ids, second-level keys are clip ids and
1420            values are lists of the intervals in `[start, end]` format; the `ids` list the dataset was loaded with is returned as the second element
1421
1422        """
1423        counter = defaultdict(lambda: {})
1424        coordinates = self.input_store.get_original_coordinates()
1425        for coords in coordinates:
1426            l = self.input_store.get_clip_length_from_coords(coords)
1427            video_name = self.input_store.get_video_id(coords)
1428            clip_id = self.input_store.get_clip_id(coords)
1429            start, end = self.input_store.get_clip_start_end(coords)
1430            if clip_id not in counter[video_name]:
1431                counter[video_name][clip_id] = np.zeros(l)
1432            counter[video_name][clip_id][start:end] = 1
1433        result = {video_name: {} for video_name in counter}
1434        for video_name in counter:
1435            for clip_id in counter[video_name]:
1436                result[video_name][clip_id] = self._get_intervals_from_ind(
1437                    np.where(counter[video_name][clip_id])[0]
1438                )
1439        return result, self.ids
1440
1441    def get_unannotated_intervals(self, first_intervals=None) -> Dict:
1442        """Get a list of intervals in the original coordinates where there is no annotation.
1443
1444        Parameters
1445        ----------
1446        first_intervals : dict
1447            a nested dictionary where first-level keys are video ids, second-level keys are clip ids and
1448            values are lists of the intervals in `[start, end]` format. If provided, only the intersection with
1449            those intervals will be returned
1450
1451        Returns
1452        -------
1453        intervals : dict
1454            a nested dictionary where first-level keys are video ids, second-level keys are clip ids and
1455            values are lists of the intervals in `[start, end]` format
1456
1457        """
1458        counter_value = 2
1459        if first_intervals is None:
1460            first_intervals = defaultdict(lambda: defaultdict(lambda: []))
1461            counter_value = 1
1462        counter = defaultdict(lambda: {})
1463        coordinates = self.input_store.get_original_coordinates()
1464        for i, coords in enumerate(coordinates):
1465            l = self.input_store.get_clip_length_from_coords(coords)
1466            ann = self.annotation_store[i]
1467            if (
1468                self.annotation_store.annotation_class()
1469                == "nonexclusive_classification"
1470            ):
1471                ann = ann[0, :]
1472            video_name = self.input_store.get_video_id(coords)
1473            clip_id = self.input_store.get_clip_id(coords)
1474            start, end = self.input_store.get_clip_start_end(coords)
1475            if clip_id not in counter[video_name]:
1476                counter[video_name][clip_id] = np.ones(l)
1477            counter[video_name][clip_id][start:end] = (ann[: end - start] == -100).int()
1478        result = {video_name: {} for video_name in counter}
1479        for video_name in counter:
1480            for clip_id in counter[video_name]:
1481                for start, end in first_intervals[video_name][clip_id]:
1482                    counter[video_name][clip_id][start:end] += 1
1483                result[video_name][clip_id] = self._get_intervals_from_ind(
1484                    np.where(counter[video_name][clip_id] == counter_value)[0]
1485                )
1486        return result
1487
1488    def get_annotated_intervals(self) -> Dict:
1489        """Get a list of intervals in the original coordinates where there is annotation.
1490
1491        Returns
1492        -------
1493        intervals : dict
1494            a nested dictionary where first-level keys are video ids, second-level keys are clip ids and
1495            values are lists of the intervals in `[start, end]` format
1496
1497        """
1498        if self.annotation_type == "none":
1499            return []
1500        counter_value = 1
1501        counter = defaultdict(lambda: {})
1502        coordinates = self.input_store.get_original_coordinates()
1503        for i, coords in enumerate(coordinates):
1504            l = self.input_store.get_clip_length_from_coords(coords)
1505            ann = self.annotation_store[i]
1506            video_name = self.input_store.get_video_id(coords)
1507            clip_id = self.input_store.get_clip_id(coords)
1508            start, end = self.input_store.get_clip_start_end(coords)
1509            if clip_id not in counter[video_name]:
1510                counter[video_name][clip_id] = np.zeros(l)
1511            if (
1512                self.annotation_store.annotation_class()
1513                == "nonexclusive_classification"
1514            ):
1515                counter[video_name][clip_id][start:end] = (
1516                    torch.sum(ann[:, : end - start] != -100, dim=0) > 0
1517                ).int()
1518            else:
1519                counter[video_name][clip_id][start:end] = (
1520                    ann[: end - start] != -100
1521                ).int()
1522        result = {video_name: {} for video_name in counter}
1523        for video_name in counter:
1524            for clip_id in counter[video_name]:
1525                result[video_name][clip_id] = self._get_intervals_from_ind(
1526                    np.where(counter[video_name][clip_id] == counter_value)[0]
1527                )
1528        return result
1529
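    # Example of the nested interval structure returned by the interval getters
    # (video ids, clip ids and frame numbers are illustrative):
    #
    #     >>> dataset.get_annotated_intervals()
    #     {'video_01': {'ind0': [[0, 300], [450, 900]], 'ind1': [[0, 900]]}}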
1530    def get_ids(self) -> Dict:
1531        """Get a dictionary of all clip ids in the dataset.
1532
1533        Returns
1534        -------
1535        ids : dict
1536            a dictionary where keys are video ids and values are lists of clip ids
1537
1538        """
1539        coordinates = self.input_store.get_original_coordinates()
1540        video_ids = np.array(self.input_store.get_video_id_order())
1541        id_set = set(video_ids)
1542        result = {}
1543        for video_id in id_set:
1544            coords = coordinates[video_ids == video_id]
1545            clip_ids = list({self.input_store.get_clip_id(c) for c in coords})
1546            result[video_id] = clip_ids
1547        return result
1548
1549    def get_len(self, video_id: str, clip_id: str) -> int:
1550        """Get the length of a specific clip.
1551
1552        Parameters
1553        ----------
1554        video_id : str
1555            the video id
1556        clip_id : str
1557            the clip id
1558
1559        Returns
1560        -------
1561        length : int
1562            the length
1563
1564        """
1565        return self.input_store.get_clip_length(video_id, clip_id)
1566
1567    def get_confusion_matrix(
1568        self, prediction: torch.Tensor, confusion_type: str = "recall"
1569    ) -> Tuple[ndarray, list, Optional[str]]:
1570        """Get a confusion matrix.
1571
1572        Parameters
1573        ----------
1574        prediction : torch.Tensor
1575            a tensor of predicted class probabilities of shape `(#samples, #classes, #frames)`
1576        confusion_type : {"recall", "precision"}
1577            for datasets with non-exclusive annotation, if `confusion_type` is `"recall"`, only false
1578            positives are taken into account, and if `confusion_type` is `"precision"`, only false negatives
1579
1580        Returns
1581        -------
1582        confusion_matrix : np.ndarray
1583            a confusion matrix of shape `(#classes, #classes)` where `A[i, j] = F_ij/N_i`, `F_ij` is the number of
1584            frames that have the i-th label in the ground truth and a false positive j-th label in the prediction,
1585            `N_i` is the number of frames that have the i-th label in the ground truth
1586        classes : list
1587            a list of classes; the confusion type that was actually applied is returned as a third element
1588
1589        """
1590        behaviors_dict = self.annotation_store.behaviors_dict()
1591        num_behaviors = len(behaviors_dict)
1592        confusion_matrix = np.zeros((num_behaviors, num_behaviors))
1593        if self.annotation_store.annotation_class() == "exclusive_classification":
1594            exclusive = True
1595            confusion_type = None
1596        elif self.annotation_store.annotation_class() == "nonexclusive_classification":
1597            exclusive = False
1598        else:
1599            raise RuntimeError(
1600                f"The {self.annotation_store.annotation_class()} annotation class is not recognized!"
1601            )
1602        for ann, p in zip(self.annotation_store, prediction):
1603            if exclusive:
1604                class_prediction = torch.max(p, dim=0)[1]
1605                for i in behaviors_dict.keys():
1606                    for j in behaviors_dict.keys():
1607                        confusion_matrix[i, j] += int(
1608                            torch.sum(class_prediction[ann == i] == j)
1609                        )
1610            else:
1611                class_prediction = (p > 0.5).int()
1612                for i in behaviors_dict.keys():
1613                    for j in behaviors_dict.keys():
1614                        if confusion_type == "recall":
1615                            pred = deepcopy(class_prediction[j])
1616                            if i != j:
1617                                pred[ann[j] == 1] = 0
1618                            confusion_matrix[i, j] += int(torch.sum(pred[ann[i] == 1]))
1619                        elif confusion_type == "precision":
1620                            annotation = deepcopy(ann[j])
1621                            if i != j:
1622                                annotation[class_prediction[j] == 1] = 0
1623                            confusion_matrix[i, j] += int(
1624                                torch.sum(annotation[class_prediction[i] == 1])
1625                            )
1626                        else:
1627                            raise ValueError(
1628                                f"The {confusion_type} type is not recognized; please choose from ['recall', 'precision']"
1629                            )
1630        counter = self.annotation_store.count_classes()
1631        for i in behaviors_dict.keys():
1632            if counter[i] != 0:
1633                if confusion_type == "recall" or confusion_type is None:
1634                    confusion_matrix[i, :] /= counter[i]
1635                else:
1636                    confusion_matrix[:, i] /= counter[i]
1637        return confusion_matrix, list(behaviors_dict.values()), confusion_type
1638
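    # Usage sketch (assuming `prediction` holds class probabilities of shape
    # (#samples, #classes, #frames), e.g. produced by a trained model):
    #
    #     matrix, classes, used_type = dataset.get_confusion_matrix(prediction, confusion_type="recall")
    #     # matrix[i, j]: normalized count of frames labeled classes[i] that received a
    #     # false positive prediction of classes[j]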
1639    def annotation_class(self) -> str:
1640        """Get the type of annotation ('exclusive_classification', 'nonexclusive_classification', more coming soon).
1641
1642        Returns
1643        -------
1644        annotation_class : str
1645            the type of annotation
1646
1647        """
1648        return self.annotation_store.annotation_class()
1649
1650    def set_normalization_stats(self, stats: Dict) -> None:
1651        """Set the stats to normalize data at runtime.
1652
1653        Parameters
1654        ----------
1655        stats : dict
1656            a nested dictionary where first-level keys are feature key names, second-level keys are 'mean' and 'std'
1657            and values are the statistics in `torch` tensors of shape `(#features, 1)`
1658
1659        """
1660        self.stats = stats
1661
1662    def get_min_max_frames(self, video_id) -> Tuple[Dict, Dict]:
1663        """Get the minimum and maximum frame numbers for each clip in a video.
1664
1665        Parameters
1666        ----------
1667        video_id : str
1668            the video id
1669
1670        Returns
1671        -------
1672        min_frames : dict
1673            a dictionary where keys are clip ids and values are the minimum frame numbers
1674        max_frames : dict
1675            a dictionary where keys are clip ids and values are the maximum frame numbers
1676
1677        """
1678        coords = self.input_store.get_original_coordinates()
1679        clips = set(
1680            [
1681                self.input_store.get_clip_id(c)
1682                for c in coords
1683                if self.input_store.get_video_id(c) == video_id
1684            ]
1685        )
1686        min_frames = {}
1687        max_frames = {}
1688        for clip in clips:
1689            start = self.input_store.get_clip_start(video_id, clip)
1690            end = start + self.input_store.get_clip_length(video_id, clip)
1691            min_frames[clip] = start
1692            max_frames[clip] = end - 1
1693        return min_frames, max_frames
1694
1695    def get_normalization_stats(self, skip_keys=None) -> Dict:
1696        """Get mean and standard deviation for each key.
1697
1698        Parameters
1699        ----------
1700        skip_keys : list, optional
1701            a list of keys to skip
1702
1703        Returns
1704        -------
1705        stats : dict
1706            a nested dictionary where first-level keys are feature key names, second-level keys are 'mean' and 'std'
1707            and values are the statistics in `torch` tensors of shape `(#features, 1)`
1708
1709        """
1710        stats = defaultdict(lambda: {})
1711        sums = defaultdict(lambda: 0)
1712        if skip_keys is None:
1713            skip_keys = []
1714        counter = defaultdict(lambda: 0)
1715        for sample in tqdm(self):
1716            for key, value in sample["input"].items():
1717                key_name = key.split("---")[0]
1718                if key_name not in skip_keys:
1719                    sums[key_name] += value[:, value.sum(0) != 0].sum(-1)
1720                counter[key_name] += torch.sum(value.sum(0) != 0)
1721        for key, value in sums.items():
1722            stats[key]["mean"] = (value / counter[key]).unsqueeze(-1)
1723        sums = defaultdict(lambda: 0)
1724        for sample in tqdm(self):
1725            for key, value in sample["input"].items():
1726                key_name = key.split("---")[0]
1727                if key_name not in skip_keys:
1728                    sums[key_name] += (
1729                        (value[:, value.sum(0) != 0] - stats[key_name]["mean"]) ** 2
1730                    ).sum(-1)
1731        for key, value in sums.items():
1732            stats[key]["std"] = torch.sqrt(value.unsqueeze(-1) / counter[key])
1733        return stats
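    # Usage sketch: compute the statistics once and attach them to the dataset so
    # that samples can be normalized at runtime (the "speed" and "coords" key
    # names are illustrative):
    #
    #     stats = dataset.get_normalization_stats(skip_keys=["speed"])
    #     dataset.set_normalization_stats(stats)
    #     # stats["coords"]["mean"] and stats["coords"]["std"] are tensors of shape (#features, 1)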

Initialize a dataset.

Parameters

data_type : str the data type (see available types by running BehaviorDataset.data_types()) annotation_type : str the annotation type (see available types by running BehaviorDataset.annotation_types()) ssl_transformations : list a list of functions that take a sample dictionary as input and return an (ssl input, ssl target) tuple saved_data_path : str the path to a pre-computed pickled dataset input_store : InputStore a pre-computed input store annotation_store : BehaviorStore a precomputed annotation store only_load_annotated : bool if True, the input files that don't have a matching annotation file will be disregarded recompute_annotation : bool if True, the annotation will be recomputed even if a precomputed annotation store is provided ids : list a list of ids to load from the input store *data_parameters : dict parameters to initialize the input and annotation stores

ssl_transformations
input_type
annotation_type
stats
ids
tag
return_unlabeled
input_indices
annotation_indices
indices
def get_tags(self) -> List:
218    def get_tags(self) -> List:
219        """Get a list of all meta tags.
220
221        Returns
222        -------
223        tags: List
224            a list of unique meta tag values
225
226        """
227        return self.input_store.get_tags()

Get a list of all meta tags.

Returns

tags: List a list of unique meta tag values

def save(self, save_path: str) -> None:
229    def save(self, save_path: str) -> None:
230        """Save the dictionary.
231
232        Parameters
233        ----------
234        save_path : str
235            the path where the pickled file will be stored
236
237        """
238        input_obj = self.input_store.key_objects()
239        annotation_obj = self.annotation_store.key_objects()
240        with open(save_path, "wb") as f:
241            pickle.dump((input_obj, annotation_obj), f)

Save the dictionary.

Parameters

save_path : str the path where the pickled file will be stored

def to_ram(self) -> None:
243    def to_ram(self) -> None:
244        """Transfer the dataset to RAM."""
245        self.input_store.to_ram()
246        self.annotation_store.to_ram()

Transfer the dataset to RAM.

def generate_full_length_gt(self) -> Dict:
248    def generate_full_length_gt(self) -> Dict:
249        """Generate full-length ground truth from the annotations.
250
251        Returns
252        -------
253        full_length_gt : dict
254            a nested dictionary where first-level keys are video ids, second-level keys are clip ids and
255            values are the ground truth labels
256
257        """
258        if self.annotation_class() == "exclusive_classification":
259            gt = torch.zeros((len(self), self.len_segment()))
260        else:
261            gt = torch.zeros(
262                (len(self), len(self.behaviors_dict()), self.len_segment())
263            )
264        for i in range(len(self)):
265            gt[i] = self.annotation_store[i]
266        return self.generate_full_length_prediction(gt)

Generate full-length ground truth from the annotations.

Returns

full_length_gt : dict a nested dictionary where first-level keys are video ids, second-level keys are clip ids and values are the ground truth labels

def generate_full_length_prediction(self, predicted: torch.Tensor) -> Dict:
268    def generate_full_length_prediction(self, predicted: torch.Tensor) -> Dict:
269        """Map predictions for the equal-length pieces to predictions for the original data.
270
271        Probabilities are averaged over predictions on overlapping intervals.
272
273        Parameters
274        ----------
275        predicted: torch.Tensor
276            a tensor of predicted probabilities of shape `(N, #classes, #frames)`
277
278        Returns
279        -------
280        full_length_prediction : dict
281            a nested dictionary where first-level keys are video ids, second-level keys are clip ids and values are
282            averaged probability tensors
283
284        """
285        result = defaultdict(lambda: {})
286        counter = defaultdict(lambda: {})
287        coordinates = self.input_store.get_original_coordinates()
288        for coords, prediction in zip(coordinates, predicted):
289            l = self.input_store.get_clip_length_from_coords(coords)
290            video_name = self.input_store.get_video_id(coords)
291            clip_id = self.input_store.get_clip_id(coords)
292            start, end = self.input_store.get_clip_start_end(coords)
293            if clip_id not in result[video_name].keys():
294                result[video_name][clip_id] = torch.zeros(*prediction.shape[:-1], l)
295                counter[video_name][clip_id] = torch.zeros(*prediction.shape[:-1], l)
296            result[video_name][clip_id][..., start:end] += (
297                prediction.squeeze()[..., : end - start].detach().cpu()
298            )
299            counter[video_name][clip_id][..., start:end] += 1
300        for video_name in result:
301            for clip_id in result[video_name]:
302                result[video_name][clip_id] /= counter[video_name][clip_id]
303                result[video_name][clip_id][counter[video_name][clip_id] == 0] = -100
304        result = dict(result)
305        return result

Map predictions for the equal-length pieces to predictions for the original data.

Probabilities are averaged over predictions on overlapping intervals.

Parameters

predicted: torch.Tensor a tensor of predicted probabilities of shape (N, #classes, #frames)

Returns

full_length_prediction : dict a nested dictionary where first-level keys are video ids, second-level keys are clip ids and values are averaged probability tensors

def find_valleys( self, predicted: Union[torch.Tensor, Dict], threshold: float = 0.5, min_frames: int = 0, visibility_min_score: float = 0, visibility_min_frac: float = 0, main_class: int = 1, low: bool = True, predicted_error: torch.Tensor = None, error_threshold: float = 0.5, hysteresis: bool = False, threshold_diff: float = None, min_frames_error: int = None, smooth_interval: int = 1, cut_annotated: bool = False) -> Dict:
307    def find_valleys(
308        self,
309        predicted: Union[torch.Tensor, Dict],
310        threshold: float = 0.5,
311        min_frames: int = 0,
312        visibility_min_score: float = 0,
313        visibility_min_frac: float = 0,
314        main_class: int = 1,
315        low: bool = True,
316        predicted_error: torch.Tensor = None,
317        error_threshold: float = 0.5,
318        hysteresis: bool = False,
319        threshold_diff: float = None,
320        min_frames_error: int = None,
321        smooth_interval: int = 1,
322        cut_annotated: bool = False,
323    ) -> Dict:
324        """Find the intervals where the probability of a certain class is below or above a certain threshold.
325
326        Parameters
327        ----------
328        predicted : torch.Tensor | dict
329            either a tensor of predictions for the data prompts or the output of
330            `BehaviorDataset.generate_full_length_prediction`
331        threshold : float, default 0.5
332            the main threshold
333        min_frames : int, default 0
334            the minimum length of the intervals
335        visibility_min_score : float, default 0
336            the minimum visibility score in the intervals
337        visibility_min_frac : float, default 0
338            the minimum fraction of the interval that has to have a visibility score larger than visibility_min_score
339        main_class : int, default 1
340            the index of the class the function is inspecting
341        low : bool, default True
342            if True, the probability in the intervals has to be below the threshold; if False, it has to be above
343        predicted_error : torch.Tensor, optional
344            a tensor of error predictions for the data prompts
345        error_threshold : float, default 0.5
346            maximum possible probability of error at the intervals
347        hysteresis: bool, default False
348            if True, the function will apply hysteresis thresholding with the soft threshold defined by threshold_diff
349        threshold_diff: float, optional
350            the difference between the soft and hard thresholds if hysteresis is used; if hysteresis is True, low is False and threshold_diff is None, the soft threshold condition is that main_class has a larger probability than the other classes
351        min_frames_error: int, optional
352            if not None, the intervals will only be considered where the error probability is below error_threshold for at least min_frames_error consecutive frames
353        smooth_interval: int, default 1
354            the number of frames to smooth the predictions over
355        cut_annotated: bool, default False
356            if `True`, annotated intervals will be cut out of the predicted intervals
357
358        Returns
359        -------
360        valleys : dict
361            a dictionary where keys are video ids and values are lists of (start, end, clip id) tuples that denote the chosen intervals
362
363        """
364        result = defaultdict(lambda: [])
365        if type(predicted) is not dict:
366            predicted = self.generate_full_length_prediction(predicted)
367        if predicted_error is not None:
368            predicted_error = self.generate_full_length_prediction(predicted_error)
369        elif min_frames_error is not None and min_frames_error != 0:
370            # warnings.warn(
371            #     f"The min_frames_error parameter is set to {min_frames_error} but no error prediction "
372            #     f"is given! Setting min_frames_error to 0."
373            # )
374            min_frames_error = 0
375        if low and hysteresis and threshold_diff is None:
376            raise ValueError(
377                "Cannot set low=True, hysteresis=True and threshold_diff=None! Please set threshold_diff."
378            )
379        if cut_annotated:
380            masked_intervals_dict = self.get_annotated_intervals()
381        else:
382            masked_intervals_dict = None
383        print("Valleys found:")
384        for v_id in predicted:
385            for clip_id in predicted[v_id].keys():
386                if predicted_error is not None:
387                    error_mask = predicted_error[v_id][clip_id][1, :] < error_threshold
388                    if min_frames_error is not None:
389                        output, indices, counts = torch.unique_consecutive(
390                            error_mask, return_inverse=True, return_counts=True
391                        )
392                        wrong_indices = torch.where(
393                            output * (counts < min_frames_error)
394                        )[0]
395                        if len(wrong_indices) > 0:
396                            for i in wrong_indices:
397                                error_mask[indices == i] = False
398                else:
399                    error_mask = None
400                if masked_intervals_dict is not None:
401                    masked_intervals = masked_intervals_dict[v_id][clip_id]
402                else:
403                    masked_intervals = None
404                if not hysteresis:
405                    res_indices_start, res_indices_end = apply_threshold(
406                        predicted[v_id][clip_id][main_class, :],
407                        threshold,
408                        low,
409                        error_mask,
410                        min_frames,
411                        smooth_interval,
412                        masked_intervals,
413                    )
414                elif threshold_diff is not None:
415                    if low:
416                        soft_threshold = threshold + threshold_diff
417                    else:
418                        soft_threshold = threshold - threshold_diff
419                    res_indices_start, res_indices_end = apply_threshold_hysteresis(
420                        predicted[v_id][clip_id][main_class, :],
421                        soft_threshold,
422                        threshold,
423                        low,
424                        error_mask,
425                        min_frames,
426                        smooth_interval,
427                        masked_intervals,
428                    )
429                else:
430                    res_indices_start, res_indices_end = apply_threshold_max(
431                        predicted[v_id][clip_id],
432                        threshold,
433                        main_class,
434                        error_mask,
435                        min_frames,
436                        smooth_interval,
437                        masked_intervals,
438                    )
439                start = self.input_store.get_clip_start(v_id, clip_id)
440                result[v_id] += [
441                    [i + start, j + start, clip_id]
442                    for i, j in zip(res_indices_start, res_indices_end)
443                    if self.input_store.get_visibility(
444                        v_id, clip_id, i, j, visibility_min_score
445                    )
446                    > visibility_min_frac
447                ]
448            result[v_id] = sorted(result[v_id])
449            print(f"    {v_id}: {len(result[v_id])}")
450        return dict(result)
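
A hedged sketch of an active-learning style search with `find_valleys`; the `dataset` and `predicted` variables, the helper name and the threshold values are illustrative assumptions.

import torch
from dlc2action.data.dataset import BehaviorDataset

def low_confidence_intervals(dataset: BehaviorDataset, predicted: torch.Tensor) -> dict:
    # Returns {video_id: [(start, end, clip_id), ...]} for intervals where the probability
    # of class 1 stays below 0.4 for at least 15 frames, with annotated stretches cut out
    return dataset.find_valleys(
        predicted,
        threshold=0.4,
        min_frames=15,
        main_class=1,
        low=True,
        cut_annotated=True,
    )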

def valleys_union(self, valleys_list) -> Dict:
452    def valleys_union(self, valleys_list) -> Dict:
453        """Find the union of a list of valleys dictionaries.
454
455        Parameters
456        ----------
457        valleys_list : list
458            a list of valleys dictionaries
459
460        Returns
461        -------
462        union : dict
463            a new valleys dictionary with the union of the input intervals
464
465        """
466        valleys_list = [x for x in valleys_list if x is not None]
467        if len(valleys_list) == 1:
468            return valleys_list[0]
469        elif len(valleys_list) == 0:
470            return {}
471        union = {}
472        keys_list = [set(valleys.keys()) for valleys in valleys_list]
473        keys = set.union(*keys_list)
474        for v_id in keys:
475            res = []
476            clips_list = [
477                set([x[-1] for x in valleys[v_id]]) for valleys in valleys_list
478            ]
479            clips = set.union(*clips_list)
480            for clip_id in clips:
481                clip_intervals = [
482                    x
483                    for valleys in valleys_list
484                    for x in valleys[v_id]
485                    if x[-1] == clip_id
486                ]
487                v_len = self.input_store.get_clip_length(v_id, clip_id)
488                arr = torch.zeros(v_len)
489                for start, end, _ in clip_intervals:
490                    arr[start:end] += 1
491                output, indices, counts = torch.unique_consecutive(
492                    arr > 0, return_inverse=True, return_counts=True
493                )
494                long_indices = torch.where(output)[0]
495                res += [
496                    (
497                        (indices == i).nonzero(as_tuple=True)[0][0].item(),
498                        (indices == i).nonzero(as_tuple=True)[0][-1].item(),
499                        clip_id,
500                    )
501                    for i in long_indices
502                ]
503            union[v_id] = res
504        return union

def valleys_intersection(self, valleys_list) -> Dict:
506    def valleys_intersection(self, valleys_list) -> Dict:
507        """Find the intersection of two valleys dictionaries.
508
509        Parameters
510        ----------
511        valleys_list : list
512            a list of valleys dictionaries
513
514        Returns
515        -------
516        intersection : dict
517            a new valleys dictionary with the intersection of the input intervals
518
519        """
520        valleys_list = [x for x in valleys_list if x is not None]
521        if len(valleys_list) == 1:
522            return valleys_list[0]
523        elif len(valleys_list) == 0:
524            return {}
525        intersection = {}
526        keys_list = [set(valleys.keys()) for valleys in valleys_list]
527        keys = set.intersection(*keys_list)
528        for v_id in keys:
529            res = []
530            clips_list = [
531                set([x[-1] for x in valleys[v_id]]) for valleys in valleys_list
532            ]
533            clips = set.intersection(*clips_list)
534            for clip_id in clips:
535                clip_intervals = [
536                    x
537                    for valleys in valleys_list
538                    for x in valleys[v_id]
539                    if x[-1] == clip_id
540                ]
541                v_len = self.input_store.get_clip_length(v_id, clip_id)
542                arr = torch.zeros(v_len)
543                for start, end, _ in clip_intervals:
544                    arr[start:end] += 1
545                output, indices, counts = torch.unique_consecutive(
546                    arr, return_inverse=True, return_counts=True
547                )
548                long_indices = torch.where(output == 2)[0]
549                res += [
550                    (
551                        (indices == i).nonzero(as_tuple=True)[0][0].item(),
552                        (indices == i).nonzero(as_tuple=True)[0][-1].item(),
553                        clip_id,
554                    )
555                    for i in long_indices
556                ]
557            intersection[v_id] = res
558        return intersection
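
A small sketch of combining the interval dictionaries produced by two `find_valleys` runs; `valleys_a` and `valleys_b` are assumed to refer to the same dataset, and the helper name is hypothetical.

from dlc2action.data.dataset import BehaviorDataset

def combine_valleys(dataset: BehaviorDataset, valleys_a: dict, valleys_b: dict):
    # Frames flagged in at least one of the two runs
    union = dataset.valleys_union([valleys_a, valleys_b])
    # Frames flagged in both runs
    intersection = dataset.valleys_intersection([valleys_a, valleys_b])
    return union, intersection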

def partition_train_test_val( self, use_test: float = 0, split_path: str = None, method: str = 'random', val_frac: float = 0, test_frac: float = 0, save_split: bool = False, normalize: bool = False, skip_normalization_keys: List = None, stats: Dict = None) -> Tuple:
560    def partition_train_test_val(
561        self,
562        use_test: float = 0,
563        split_path: str = None,
564        method: str = "random",
565        val_frac: float = 0,
566        test_frac: float = 0,
567        save_split: bool = False,
568        normalize: bool = False,
569        skip_normalization_keys: List = None,
570        stats: Dict = None,
571    ) -> Tuple:
572        """Partition the dataset into three new datasets.
573
574        Parameters
575        ----------
576        use_test : float, default 0
577            The fraction of the test dataset to be used in training without labels
578        split_path : str, optional
579            The path to load the split information from (if `'file'` method is used) and to save it to
580            (if `'save_split'` is `True`)
581        method : {'random', 'random:test-from-name', 'random:test-from-name:{name}',
582            'val-from-name:{val_name}:test-from-name:{test_name}',
583            'random:equalize:segments', 'random:equalize:videos',
584            'folders', 'time', 'time:strict', 'file'}
585            The partitioning method:
586            - `'random'`: sort videos into subsets randomly,
587            - `'random:test-from-name'` (or `'random:test-from-name:{name}'`): sort videos into training and validation
588                subsets randomly and create
589                the test subset from the video ids that start with a specific substring (`'test'` by default, or `name`
590                if provided),
591            - `'random:equalize:segments'` and `'random:equalize:videos'`: sort videos into subsets randomly but
592                making sure that for the rarest classes at least `0.8 * val_frac` of the videos/segments that contain
593                occurrences of the class get into the validation subset and `0.8 * test_frac` get into the test subset;
594                this is ensured for all classes in order of increasing number of occurrences until the validation and
595                test subsets are full
596            - `'val-from-name:{val_name}:test-from-name:{test_name}'`: create the validation and test
597                subsets from the video ids that start with specific substrings (`val_name` for validation
598                and `test_name` for test) and sort all other videos into the training subset
599            - `'folders'`: read videos from folders named *test*, *train* and *val* into corresponding subsets,
600            - `'time'`: split each video into training, validation and test subsequences,
601            - `'time:strict'`: split each video into validation, test and training subsequences
602                and throw out the last segments in validation and test (to get rid of overlaps),
603            - `'file'`: split according to a split file.
604        val_frac : float, default 0
605            The fraction of the dataset to be used in validation
606        test_frac : float, default 0
607            The fraction of the dataset to be used in test
608        save_split : bool, default False
609            Save a split file if True
610        normalize : bool, default False
611            Normalize the dataset if `True`
612        skip_normalization_keys : list, optional
613            A list of keys to skip normalization for
614        stats : dict, optional
615            A dictionary of (pre-computed) statistics to use for normalization
616
617        Returns
618        -------
619        train_dataset : BehaviorDataset
620            train dataset
621        val_dataset : BehaviorDataset
622            validation dataset
623        test_dataset : BehaviorDataset
624            test dataset
625
626        """
627        train_indices, test_indices, val_indices = self._partition_indices(
628            split_path=split_path,
629            method=method,
630            val_frac=val_frac,
631            test_frac=test_frac,
632            save_split=save_split,
633        )
634        ssl_indices = None
635        partition_method = method.split(":")
636        if (
637            partition_method[0] in ("leave-one-in", "leave-n-in")
638            and len(partition_method) > 2
639            and partition_method[2] == "val-for-ssl"
640        ):
641            print("Using validation samples for SSL!")
642            ssl_indices = val_indices
643
644        val_dataset = self._create_new_dataset(val_indices)
645        test_dataset = self._create_new_dataset(test_indices)
646        train_dataset = self._create_new_dataset(train_indices, ssl_indices=ssl_indices)
647
648        train_classes = train_dataset.count_classes()
649        val_classes = val_dataset.count_classes()
650        test_classes = test_dataset.count_classes()
651        print("Number of samples:")
652        print("    validation:")
653        print(f"      {[f'{k}: {val_classes[k]}' for k in sorted(val_classes.keys())]}")
654        print("    training:")
655        print(f"      {[f'{k}: {train_classes[k]}' for k in sorted(train_classes.keys())]}")
656        print("    test:")
657        print(f"      {[f'{k}: {test_classes[k]}' for k in sorted(test_classes.keys())]}")
658        if normalize:
659            if stats is None:
660                print("Computing normalization statistics...")
661                stats = train_dataset.get_normalization_stats(skip_normalization_keys)
662            else:
663                print("Setting loaded normalization statistics...")
664            train_dataset.set_normalization_stats(stats)
665            val_dataset.set_normalization_stats(stats)
666            test_dataset.set_normalization_stats(stats)
667        return train_dataset, test_dataset, val_dataset
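
A minimal partitioning sketch; the fractions, the method string and the split path are illustrative choices, not prescribed values.

from dlc2action.data.dataset import BehaviorDataset

def make_splits(dataset: BehaviorDataset):
    # Random 70/15/15 split with runtime normalization statistics computed on the training subset
    train_ds, test_ds, val_ds = dataset.partition_train_test_val(
        method="random",
        val_frac=0.15,
        test_frac=0.15,
        split_path="split.txt",  # hypothetical path, written because save_split=True
        save_split=True,
        normalize=True,
    )
    # note the return order of the method: train, test, val
    return train_ds, val_ds, test_ds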

def class_weights(self, proportional=False) -> List:
669    def class_weights(self, proportional=False) -> List:
670        """Calculate class weights in inverse proportion to number of samples.
671
672        Parameters
673        ----------
674        proportional : bool, default False
675            If `True`, the weights are proportional to the number of samples in the most common class
676
677        Returns
678        -------
679        weights: list | dict
680            a list of class weights for exclusive classification, or a dictionary with weight lists for the negative (key 0) and positive (key 1) labels for non-exclusive classification
681
682        """
683        items = sorted(
684            [
685                (k, v)
686                for k, v in self.annotation_store.count_classes().items()
687                if k != -100
688            ]
689        )
690        if self.annotation_store.annotation_class() == "exclusive_classification":
691            if not proportional:
692                numerator = len(self.annotation_store)
693            else:
694                numerator = max([x[1] for x in items])
695            weights = [numerator / (v + 1e-7) for _, v in items]
696        else:
697            items_zero = sorted(
698                [
699                    (k, v)
700                    for k, v in self.annotation_store.count_classes(zeros=True).items()
701                    if k != -100
702                ]
703            )
704            if not proportional:
705                numerators = defaultdict(lambda: len(self.annotation_store))
706            else:
707                numerators = {
708                    item_one[0]: max(item_one[1], item_zero[1])
709                    for item_one, item_zero in zip(items, items_zero)
710                }
711            weights = {}
712            weights[0] = [numerators[k] / (v + 1e-7) for k, v in items_zero]
713            weights[1] = [numerators[k] / (v + 1e-7) for k, v in items]
714        return weights
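
A sketch of feeding the exclusive-classification weights into a loss function; the use of `CrossEntropyLoss` and the `ignore_index` value are assumptions made for illustration.

import torch
from dlc2action.data.dataset import BehaviorDataset

def weighted_loss(dataset: BehaviorDataset) -> torch.nn.CrossEntropyLoss:
    # For exclusive classification, class_weights() returns one weight per class,
    # ordered by class index
    weights = dataset.class_weights(proportional=False)
    return torch.nn.CrossEntropyLoss(
        weight=torch.tensor(weights, dtype=torch.float),
        ignore_index=-100,  # unlabeled frames are marked with -100 in the annotation
    )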

def count_classes(self, bouts: bool = False) -> Dict:
737    def count_classes(self, bouts: bool = False) -> Dict:
738        """Get a class counter dictionary.
739
740        Parameters
741        ----------
742        bouts : bool, default False
743            if `True`, instead of frame counts segment counts are returned
744
745        Returns
746        -------
747        count_dictionary : dict
748            a dictionary with class indices as keys and frame or bout counts as values
749
750        """
751        return self.annotation_store.count_classes(bouts=bouts)

def behaviors_dict(self) -> Dict:
753    def behaviors_dict(self) -> Dict:
754        """Get a behavior dictionary.
755
756        Returns
757        -------
758        dict
759            behavior dictionary
760
761        """
762        return self.annotation_store.behaviors_dict()
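
A small sketch that reports the class balance by combining `count_classes` with `behaviors_dict`; the helper name is hypothetical.

from dlc2action.data.dataset import BehaviorDataset

def report_class_balance(dataset: BehaviorDataset) -> None:
    names = dataset.behaviors_dict()           # {class index: behavior name}
    frames = dataset.count_classes()           # frame counts per class index
    bouts = dataset.count_classes(bouts=True)  # bout (segment) counts per class index
    for idx, name in names.items():
        print(f"{name}: {frames.get(idx, 0)} frames in {bouts.get(idx, 0)} bouts")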

def bodyparts_order(self) -> List:
764    def bodyparts_order(self) -> List:
765        """Get the order of bodyparts.
766
767        Returns
768        -------
769        bodyparts : List
770            a list of bodyparts
771
772        """
773        try:
774            return self.input_store.get_bodyparts()
775        except Exception:
776            raise RuntimeError(
777                f"The {self.input_type} input store does not have bodyparts implemented!"
778            )

def features_shape(self) -> Dict:
780    def features_shape(self) -> Dict:
781        """Get the shapes of the input features.
782
783        Returns
784        -------
785        shapes : Dict
786            a dictionary with the shapes of the features
787
788        """
789        sample = self.input_store[0]
790        shapes = {k: v.shape for k, v in sample.items()}
791        # for key, value in shapes.items():
792        #     print(f'{key}: {value}')
793        return shapes

def num_classes(self) -> int:
795    def num_classes(self) -> int:
796        """Get the number of classes in the data.
797
798        Returns
799        -------
800        num_classes : int
801            the number of classes
802
803        """
804        return len(self.annotation_store.behaviors_dict())

def len_segment(self) -> int:
806    def len_segment(self) -> int:
807        """Get the segment length in the data.
808
809        Returns
810        -------
811        len_segment : int
812            the segment length
813
814        """
815        sample = self.input_store[0]
816        key = list(sample.keys())[0]
817        return sample[key].shape[-1]
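
A quick inspection sketch using `features_shape`, `num_classes` and `len_segment`; purely illustrative, with a hypothetical helper name.

from dlc2action.data.dataset import BehaviorDataset

def describe_dataset(dataset: BehaviorDataset) -> None:
    print(f"samples: {len(dataset)}")
    print(f"segment length: {dataset.len_segment()} frames")
    print(f"classes: {dataset.num_classes()}")
    for key, shape in dataset.features_shape().items():
        print(f"  {key}: {tuple(shape)}")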

def set_ssl_transformations(self, ssl_transformations: List) -> None:
819    def set_ssl_transformations(self, ssl_transformations: List) -> None:
820        """Set new SSL transformations.
821
822        Parameters
823        ----------
824        ssl_transformations : list
825            a list of functions that take a sample feature dictionary as input and output ssl_inputs and ssl_targets
826            lists
827
828        """
829        self.ssl_transformations = ssl_transformations

@classmethod
def new(cls, *args, **kwargs):
831    @classmethod
832    def new(cls, *args, **kwargs):
833        """Create a new object of the same class.
834
835        Parameters
836        ----------
837        args : list
838            arguments for the constructor
839        kwargs : dict
840            keyword arguments for the constructor
841
842        Returns
843        -------
844        new_instance: BehaviorDataset
845            a new instance of the same class
846
847        """
848        return cls(*args, **kwargs)

@classmethod
def get_parameters(cls, data_type: str, annotation_type: str) -> List:
850    @classmethod
851    def get_parameters(cls, data_type: str, annotation_type: str) -> List:
852        """Get parameters necessary for initialization.
853
854        Parameters
855        ----------
856        data_type : str
857            the data type
858        annotation_type : str
859            the annotation type
860
861        Returns
862        -------
863        parameters : list
864            a list of parameters
865
866        """
867        input_features = options.input_stores[data_type].get_parameters()
868        annotation_features = options.annotation_stores[
869            annotation_type
870        ].get_parameters()
871        self_features = inspect.getfullargspec(cls.__init__).args
872        return self_features + input_features + annotation_features

@staticmethod
def data_types() -> List:
874    @staticmethod
875    def data_types() -> List:
876        """List available data types.
877
878        Returns
879        -------
880        data_types : list
881            available data types
882
883        """
884        return list(options.input_stores.keys())

@staticmethod
def annotation_types() -> List:
886    @staticmethod
887    def annotation_types() -> List:
888        """List available annotation types.
889
890        Returns
891        -------
892        annotation_types : list
893            available annotation types
894
895        """
896        return list(options.annotation_stores.keys())
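
A sketch of querying the registered types and the constructor parameters a combination requires; the `'dlc_track'` and `'csv'` names are placeholders, not guaranteed to be available in every installation.

from dlc2action.data.dataset import BehaviorDataset

def show_options() -> None:
    print("data types:", BehaviorDataset.data_types())
    print("annotation types:", BehaviorDataset.annotation_types())
    # hypothetical type names; use values from the lists printed above for your installation
    print("parameters:", BehaviorDataset.get_parameters("dlc_track", "csv"))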

def set_indexing_parameters(self, unlabeled: bool, tag: int) -> None:
952    def set_indexing_parameters(self, unlabeled: bool, tag: int) -> None:
953        """Set the parameters that change the subset that is returned at `__getitem__`.
954
955        Parameters
956        ----------
957        unlabeled : bool
958            a pseudolabeling parameter; return only unlabeled samples if `True`, only labeled if `False` and
959            all if `None`
960        tag : int
961            if not `None`, only samples with this meta tag will be returned
962
963        """
964        if unlabeled != self.return_unlabeled:
965            self.annotation_indices = self.annotation_store.get_indices(unlabeled)
966            self.return_unlabeled = unlabeled
967        if tag != self.tag:
968            self.input_indices = self.input_store.get_indices(tag)
969            self.tag = tag
970        self.indices = [x for x in self.annotation_indices if x in self.input_indices]
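
A pseudolabeling-style sketch: restrict `__getitem__` to unlabeled samples, then restore the default view; the helper is an assumption, not part of the library.

from dlc2action.data.dataset import BehaviorDataset

def peek_unlabeled(dataset: BehaviorDataset):
    # only samples without annotation will be returned, with no meta-tag filtering
    dataset.set_indexing_parameters(unlabeled=True, tag=None)
    sample = dataset[0]
    # restore the default behavior (all samples, no tag filtering)
    dataset.set_indexing_parameters(unlabeled=None, tag=None)
    return sample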

def get_intervals(self) -> Tuple[dict, Optional[list]]:
1413    def get_intervals(self) -> Tuple[dict, Optional[list]]:
1414        """Get a list of intervals covered by the dataset in the original coordinates.
1415
1416        Returns
1417        -------
1418        intervals : dict
1419            a nested dictionary where first-level keys are video ids, second-level keys are clip ids and
1420            values are lists of the intervals in `[start, end]` format
1421
1422        """
1423        counter = defaultdict(lambda: {})
1424        coordinates = self.input_store.get_original_coordinates()
1425        for coords in coordinates:
1426            l = self.input_store.get_clip_length_from_coords(coords)
1427            video_name = self.input_store.get_video_id(coords)
1428            clip_id = self.input_store.get_clip_id(coords)
1429            start, end = self.input_store.get_clip_start_end(coords)
1430            if clip_id not in counter[video_name]:
1431                counter[video_name][clip_id] = np.zeros(l)
1432            counter[video_name][clip_id][start:end] = 1
1433        result = {video_name: {} for video_name in counter}
1434        for video_name in counter:
1435            for clip_id in counter[video_name]:
1436                result[video_name][clip_id] = self._get_intervals_from_ind(
1437                    np.where(counter[video_name][clip_id])[0]
1438                )
1439        return result, self.ids

def get_unannotated_intervals(self, first_intervals=None) -> Dict:
1441    def get_unannotated_intervals(self, first_intervals=None) -> Dict:
1442        """Get a list of intervals in the original coordinates where there is no annotation.
1443
1444        Parameters
1445        ----------
1446        first_intervals : dict
1447            a nested dictionary where first-level keys are video ids, second-level keys are clip ids and
1448            values are lists of the intervals in `[start, end]` format. If provided, only the intersection with
1449            those intervals will be returned
1450
1451        Returns
1452        -------
1453        intervals : dict
1454            a nested dictionary where first-level keys are video ids, second-level keys are clip ids and
1455            values are lists of the intervals in `[start, end]` format
1456
1457        """
1458        counter_value = 2
1459        if first_intervals is None:
1460            first_intervals = defaultdict(lambda: defaultdict(lambda: []))
1461            counter_value = 1
1462        counter = defaultdict(lambda: {})
1463        coordinates = self.input_store.get_original_coordinates()
1464        for i, coords in enumerate(coordinates):
1465            l = self.input_store.get_clip_length_from_coords(coords)
1466            ann = self.annotation_store[i]
1467            if (
1468                self.annotation_store.annotation_class()
1469                == "nonexclusive_classification"
1470            ):
1471                ann = ann[0, :]
1472            video_name = self.input_store.get_video_id(coords)
1473            clip_id = self.input_store.get_clip_id(coords)
1474            start, end = self.input_store.get_clip_start_end(coords)
1475            if clip_id not in counter[video_name]:
1476                counter[video_name][clip_id] = np.ones(l)
1477            counter[video_name][clip_id][start:end] = (ann[: end - start] == -100).int()
1478        result = {video_name: {} for video_name in counter}
1479        for video_name in counter:
1480            for clip_id in counter[video_name]:
1481                for start, end in first_intervals[video_name][clip_id]:
1482                    counter[video_name][clip_id][start:end] += 1
1483                result[video_name][clip_id] = self._get_intervals_from_ind(
1484                    np.where(counter[video_name][clip_id] == counter_value)[0]
1485                )
1486        return result

def get_annotated_intervals(self) -> Dict:
1488    def get_annotated_intervals(self) -> Dict:
1489        """Get a list of intervals in the original coordinates where there is annotation.
1490
1491        Returns
1492        -------
1493        intervals : dict
1494            a nested dictionary where first-level keys are video ids, second-level keys are clip ids and
1495            values are lists of the intervals in `[start, end]` format
1496
1497        """
1498        if self.annotation_type == "none":
1499            return []
1500        counter_value = 1
1501        counter = defaultdict(lambda: {})
1502        coordinates = self.input_store.get_original_coordinates()
1503        for i, coords in enumerate(coordinates):
1504            l = self.input_store.get_clip_length_from_coords(coords)
1505            ann = self.annotation_store[i]
1506            video_name = self.input_store.get_video_id(coords)
1507            clip_id = self.input_store.get_clip_id(coords)
1508            start, end = self.input_store.get_clip_start_end(coords)
1509            if clip_id not in counter[video_name]:
1510                counter[video_name][clip_id] = np.zeros(l)
1511            if (
1512                self.annotation_store.annotation_class()
1513                == "nonexclusive_classification"
1514            ):
1515                counter[video_name][clip_id][start:end] = (
1516                    torch.sum(ann[:, : end - start] != -100, dim=0) > 0
1517                ).int()
1518            else:
1519                counter[video_name][clip_id][start:end] = (
1520                    ann[: end - start] != -100
1521                ).int()
1522        result = {video_name: {} for video_name in counter}
1523        for video_name in counter:
1524            for clip_id in counter[video_name]:
1525                result[video_name][clip_id] = self._get_intervals_from_ind(
1526                    np.where(counter[video_name][clip_id] == counter_value)[0]
1527                )
1528        return result
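
A sketch that reports how much of each clip is covered and annotated; it assumes the dataset was created with annotations, and the helper name is hypothetical.

from dlc2action.data.dataset import BehaviorDataset

def coverage_report(dataset: BehaviorDataset) -> None:
    covered, _ = dataset.get_intervals()
    annotated = dataset.get_annotated_intervals()
    # restrict the unannotated intervals to the frames the dataset actually covers
    unannotated = dataset.get_unannotated_intervals(first_intervals=covered)
    for video_id, clips in covered.items():
        for clip_id in clips:
            n_ann = len(annotated[video_id][clip_id])
            n_unann = len(unannotated[video_id][clip_id])
            print(f"{video_id}/{clip_id}: {n_ann} annotated, {n_unann} unannotated intervals")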

def get_ids(self) -> Dict:
1530    def get_ids(self) -> Dict:
1531        """Get a dictionary of all clip ids in the dataset.
1532
1533        Returns
1534        -------
1535        ids : dict
1536            a dictionary where keys are video ids and values are lists of clip ids
1537
1538        """
1539        coordinates = self.input_store.get_original_coordinates()
1540        video_ids = np.array(self.input_store.get_video_id_order())
1541        id_set = set(video_ids)
1542        result = {}
1543        for video_id in id_set:
1544            coords = coordinates[video_ids == video_id]
1545            clip_ids = list({self.input_store.get_clip_id(c) for c in coords})
1546            result[video_id] = clip_ids
1547        return result

def get_len(self, video_id: str, clip_id: str) -> int:
1549    def get_len(self, video_id: str, clip_id: str) -> int:
1550        """Get the length of a specific clip.
1551
1552        Parameters
1553        ----------
1554        video_id : str
1555            the video id
1556        clip_id : str
1557            the clip id
1558
1559        Returns
1560        -------
1561        length : int
1562            the length
1563
1564        """
1565        return self.input_store.get_clip_length(video_id, clip_id)

def get_confusion_matrix( self, prediction: torch.Tensor, confusion_type: str = 'recall') -> Tuple[numpy.ndarray, list, str]:
1567    def get_confusion_matrix(
1568        self, prediction: torch.Tensor, confusion_type: str = "recall"
1569    ) -> Tuple[ndarray, list, str]:
1570        """Get a confusion matrix.
1571
1572        Parameters
1573        ----------
1574        prediction : torch.Tensor
1575            a tensor of predicted class probabilities of shape `(#samples, #classes, #frames)`
1576        confusion_type : {"recall", "precision"}
1577            for datasets with non-exclusive annotation, if `confusion_type` is `"recall"`, only false positives are taken
1578            into account, and if `confusion_type` is `"precision"`, only false negatives
1579
1580        Returns
1581        -------
1582        confusion_matrix : np.ndarray
1583            a confusion matrix of shape `(#classes, #classes)` where `A[i, j] = F_ij/N_i`, `F_ij` is the number of
1584            frames that have the i-th label in the ground truth and a false positive j-th label in the prediction,
1585            `N_i` is the number of frames that have the i-th label in the ground truth
1586        classes : list
1587            a list of classes (the confusion type that was actually used is returned as a third element)
1588
1589        """
1590        behaviors_dict = self.annotation_store.behaviors_dict()
1591        num_behaviors = len(behaviors_dict)
1592        confusion_matrix = np.zeros((num_behaviors, num_behaviors))
1593        if self.annotation_store.annotation_class() == "exclusive_classification":
1594            exclusive = True
1595            confusion_type = None
1596        elif self.annotation_store.annotation_class() == "nonexclusive_classification":
1597            exclusive = False
1598        else:
1599            raise RuntimeError(
1600                f"The {self.annotation_store.annotation_class()} annotation class is not recognized!"
1601            )
1602        for ann, p in zip(self.annotation_store, prediction):
1603            if exclusive:
1604                class_prediction = torch.max(p, dim=0)[1]
1605                for i in behaviors_dict.keys():
1606                    for j in behaviors_dict.keys():
1607                        confusion_matrix[i, j] += int(
1608                            torch.sum(class_prediction[ann == i] == j)
1609                        )
1610            else:
1611                class_prediction = (p > 0.5).int()
1612                for i in behaviors_dict.keys():
1613                    for j in behaviors_dict.keys():
1614                        if confusion_type == "recall":
1615                            pred = deepcopy(class_prediction[j])
1616                            if i != j:
1617                                pred[ann[j] == 1] = 0
1618                            confusion_matrix[i, j] += int(torch.sum(pred[ann[i] == 1]))
1619                        elif confusion_type == "precision":
1620                            annotation = deepcopy(ann[j])
1621                            if i != j:
1622                                annotation[class_prediction[j] == 1] = 0
1623                            confusion_matrix[i, j] += int(
1624                                torch.sum(annotation[class_prediction[i] == 1])
1625                            )
1626                        else:
1627                            raise ValueError(
1628                                f"The {confusion_type} type is not recognized; please choose from ['recall', 'precision']"
1629                            )
1630        counter = self.annotation_store.count_classes()
1631        for i in behaviors_dict.keys():
1632            if counter[i] != 0:
1633                if confusion_type == "recall" or confusion_type is None:
1634                    confusion_matrix[i, :] /= counter[i]
1635                else:
1636                    confusion_matrix[:, i] /= counter[i]
1637        return confusion_matrix, list(behaviors_dict.values()), confusion_type
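
A small evaluation sketch; `predicted` is assumed to be the model's `(#samples, #classes, #frames)` probability tensor for this dataset, and the helper name is hypothetical. Note that the method also returns the confusion type that was actually used.

import torch
from dlc2action.data.dataset import BehaviorDataset

def print_confusion(dataset: BehaviorDataset, predicted: torch.Tensor) -> None:
    matrix, classes, used_type = dataset.get_confusion_matrix(predicted, confusion_type="recall")
    print(f"confusion type: {used_type}")
    print(f"{'':>12} " + " ".join(f"{c:>12}" for c in classes))
    for name, row in zip(classes, matrix):
        print(f"{name:>12} " + " ".join(f"{x:12.3f}" for x in row))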

def annotation_class(self) -> str:
1639    def annotation_class(self) -> str:
1640        """Get the type of annotation ('exclusive_classification', 'nonexclusive_classification', more coming soon).
1641
1642        Returns
1643        -------
1644        annotation_class : str
1645            the type of annotation
1646
1647        """
1648        return self.annotation_store.annotation_class()

def set_normalization_stats(self, stats: Dict) -> None:
1650    def set_normalization_stats(self, stats: Dict) -> None:
1651        """Set the stats to normalize data at runtime.
1652
1653        Parameters
1654        ----------
1655        stats : dict
1656            a nested dictionary where first-level keys are feature key names, second-level keys are 'mean' and 'std'
1657            and values are the statistics in `torch` tensors of shape `(#features, 1)`
1658
1659        """
1660        self.stats = stats

def get_min_max_frames(self, video_id) -> Tuple[Dict, Dict]:
1662    def get_min_max_frames(self, video_id) -> Tuple[Dict, Dict]:
1663        """Get the minimum and maximum frame numbers for each clip in a video.
1664
1665        Parameters
1666        ----------
1667        video_id : str
1668            the video id
1669
1670        Returns
1671        -------
1672        min_frames : dict
1673            a dictionary where keys are clip ids and values are the minimum frame numbers
1674        max_frames : dict
1675            a dictionary where keys are clip ids and values are the maximum frame numbers
1676
1677        """
1678        coords = self.input_store.get_original_coordinates()
1679        clips = set(
1680            [
1681                self.input_store.get_clip_id(c)
1682                for c in coords
1683                if self.input_store.get_video_id(c) == video_id
1684            ]
1685        )
1686        min_frames = {}
1687        max_frames = {}
1688        for clip in clips:
1689            start = self.input_store.get_clip_start(video_id, clip)
1690            end = start + self.input_store.get_clip_length(video_id, clip)
1691            min_frames[clip] = start
1692            max_frames[clip] = end - 1
1693        return min_frames, max_frames

def get_normalization_stats(self, skip_keys=None) -> Dict:
1695    def get_normalization_stats(self, skip_keys=None) -> Dict:
1696        """Get mean and standard deviation for each key.
1697
1698        Parameters
1699        ----------
1700        skip_keys : list, optional
1701            a list of keys to skip
1702
1703        Returns
1704        -------
1705        stats : dict
1706            a nested dictionary where first-level keys are feature key names, second-level keys are 'mean' and 'std'
1707            and values are the statistics in `torch` tensors of shape `(#features, 1)`
1708
1709        """
1710        stats = defaultdict(lambda: {})
1711        sums = defaultdict(lambda: 0)
1712        if skip_keys is None:
1713            skip_keys = []
1714        counter = defaultdict(lambda: 0)
1715        for sample in tqdm(self):
1716            for key, value in sample["input"].items():
1717                key_name = key.split("---")[0]
1718                if key_name not in skip_keys:
1719                    sums[key_name] += value[:, value.sum(0) != 0].sum(-1)
1720                counter[key_name] += torch.sum(value.sum(0) != 0)
1721        for key, value in sums.items():
1722            stats[key]["mean"] = (value / counter[key]).unsqueeze(-1)
1723        sums = defaultdict(lambda: 0)
1724        for sample in tqdm(self):
1725            for key, value in sample["input"].items():
1726                key_name = key.split("---")[0]
1727                if key_name not in skip_keys:
1728                    sums[key_name] += (
1729                        (value[:, value.sum(0) != 0] - stats[key_name]["mean"]) ** 2
1730                    ).sum(-1)
1731        for key, value in sums.items():
1732            stats[key]["std"] = np.sqrt(value.unsqueeze(-1) / counter[key])
1733        return stats
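
A sketch of computing normalization statistics on one split and reusing them on another; `train_ds` and `val_ds` are assumed to come from `partition_train_test_val`, and the skipped key is hypothetical.

from dlc2action.data.dataset import BehaviorDataset

def share_normalization(train_ds: BehaviorDataset, val_ds: BehaviorDataset) -> dict:
    # mean and std per feature key, computed over the training subset only
    stats = train_ds.get_normalization_stats(skip_keys=["loaded"])  # hypothetical key to skip
    train_ds.set_normalization_stats(stats)
    val_ds.set_normalization_stats(stats)
    return stats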
