dlc2action.data.dataset

Behavior dataset (class that manages high-level data interactions)
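
A minimal usage sketch, assuming hypothetical store names and path arguments (the valid values depend on your installation; see BehaviorDataset.data_types(), BehaviorDataset.annotation_types() and BehaviorDataset.get_parameters()):

    from dlc2action.data.dataset import BehaviorDataset

    # The data/annotation type names and the path keyword arguments below are
    # assumptions for illustration; the required parameters depend on the
    # chosen input and annotation stores.
    dataset = BehaviorDataset(
        data_type="dlc_track",
        annotation_type="csv",
        data_path="/path/to/tracking/files",
        annotation_path="/path/to/annotation/files",
        only_load_annotated=True,
    )
    train, test, val = dataset.partition_train_test_val(
        method="random", val_frac=0.2, test_frac=0.1
    )
    sample = train[0]  # dict with "input", "target", "index" and the SSL fields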

   1#
   2# Copyright 2020-2022 by A. Mathis Group and contributors. All rights reserved.
   3#
   4# This project and all its files are licensed under GNU AGPLv3 or later version. A copy is included in dlc2action/LICENSE.AGPL.
   5#
   6"""
   7Behavior dataset (class that manages high-level data interactions)
   8"""
   9
  10import warnings
  11from typing import Dict, Union, Tuple, List, Optional, Any
  12
  13from numpy import ndarray
  14from torch.utils.data import Dataset
  15import torch
  16from abc import ABC
  17import numpy as np
  18from copy import copy
  19from tqdm import tqdm
  20from collections import Counter
  21import inspect
  22from collections import defaultdict
  23from dlc2action.utils import (
  24    apply_threshold_hysteresis,
  25    apply_threshold,
  26    apply_threshold_max,
  27)
  28from dlc2action.data.base_store import InputStore, AnnotationStore
  29from copy import deepcopy, copy
  30import os
  31import pickle
  32from dlc2action import options
  33
  34
  35class BehaviorDataset(Dataset, ABC):
  36    """
  37    A generalized dataset class
  38
  39    Data and annotation are stored in separate InputStore and AnnotationStore objects; the dataset class
  40    manages their interactions.
  41    """
  42
  43    def __init__(
  44        self,
  45        data_type: str,
  46        annotation_type: str = "none",
  47        ssl_transformations: List = None,
  48        saved_data_path: str = None,
  49        input_store: InputStore = None,
  50        annotation_store: AnnotationStore = None,
  51        only_load_annotated: bool = False,
  52        recompute_annotation: bool = False,
  53        # mask: str = None,
  54        ids: List = None,
  55        **data_parameters,
  56    ) -> None:
  57        """
  58        Parameters
  59        ----------
  60        data_type : str
  61            the data type (see available types by running BehaviorDataset.data_types())
  62        annotation_type : str
  63            the annotation type (see available types by running BehaviorDataset.annotation_types())
  64        ssl_transformations : list
  65            a list of functions that take a sample dictionary as input and return an (ssl input, ssl target) tuple
  66        saved_data_path : str
  67            the path to a pre-computed pickled dataset
  68        input_store : InputStore
  69            a pre-computed input store
  70        annotation_store : AnnotationStore
  71            a precomputed annotation store
  72        only_load_annotated : bool
  73            if True, the input files that don't have a matching annotation file will be disregarded
   74        **data_parameters : dict
  75            parameters to initialize the input and annotation stores
  76        """
  77
  78        mask = None
  79        if len(data_parameters) == 0:
  80            recompute_annotation = False
  81        feature_extraction = data_parameters.get("feature_extraction")
  82        if feature_extraction is not None and not issubclass(
  83            options.input_stores[data_type],
  84            options.feature_extractors[feature_extraction].input_store_class,
  85        ):
  86            raise ValueError(
  87                f"The {feature_extraction} feature extractor does not work with "
   88                f"the {data_type} data type, please choose a subclass of "
  89                f"{options.feature_extractors[feature_extraction].input_store_class}"
  90            )
  91        if ssl_transformations is None:
  92            ssl_transformations = []
  93        self.ssl_transformations = ssl_transformations
  94        self.input_type = data_type
  95        self.annotation_type = annotation_type
  96        self.stats = None
  97        if mask is not None:
  98            with open(mask, "rb") as f:
  99                self.mask = pickle.load(f)
 100        else:
 101            self.mask = None
 102        self.ids = ids
 103        self.tag = None
 104        self.return_unlabeled = None
 105        # load saved key objects for annotation and input if they exist
 106        input_key_objects, annotation_key_objects = None, None
 107        if saved_data_path is not None:
 108            if os.path.exists(saved_data_path):
 109                with open(saved_data_path, "rb") as f:
 110                    input_key_objects, annotation_key_objects = pickle.load(f)
 111        # if the input or the annotation store need to be created, generate the common video order
 112        if len(data_parameters) > 0:
 113            input_files = options.input_stores[data_type].get_file_ids(
 114                **data_parameters
 115            )
 116            annotation_files = options.annotation_stores[annotation_type].get_file_ids(
 117                **data_parameters
 118            )
 119            if only_load_annotated:
 120                data_parameters["video_order"] = [
 121                    x for x in input_files if x in annotation_files
 122                ]
 123            else:
 124                data_parameters["video_order"] = input_files
 125            if len(data_parameters["video_order"]) == 0:
 126                raise RuntimeError(
  127                    "The file list is empty! Please check your data parameters!"
 128                )
 129        data_parameters["mask"] = self.mask
 130        # load or create the input store
 131        ok = False
 132        if input_store is not None:
 133            self.input_store = input_store
 134            ok = True
 135        elif input_key_objects is not None:
 136            try:
 137                self.input_store = self._load_input_store(data_type, input_key_objects)
 138                ok = True
 139            except:
 140                warnings.warn("Loading input store from key objects failed")
 141        if not ok:
 142            self.input_store = self._get_input_store(
 143                data_type, deepcopy(data_parameters)
 144            )
 145        # get the objects needed to create the annotation store (like a clip length dictionary)
 146        annotation_objects = self.input_store.get_annotation_objects()
 147        data_parameters.update(annotation_objects)
 148        # load or create the annotation store
 149        ok = False
 150        if annotation_store is not None:
 151            self.annotation_store = annotation_store
 152            ok = True
 153        elif (
 154            annotation_key_objects is not None
 155            and mask is None
 156            and not recompute_annotation
 157        ):
 158            try:
 159                self.annotation_store = self._load_annotation_store(
 160                    annotation_type, annotation_key_objects
 161                )
 162                ok = True
 163            except:
 164                warnings.warn("Loading annotation store from key objects failed")
 165        if not ok:
 166            self.annotation_store = self._get_annotation_store(
 167                annotation_type, deepcopy(data_parameters)
 168            )
 169        if (
 170            mask is None
 171            and annotation_type != "none"
 172            and not recompute_annotation
 173            and (
 174                self.annotation_store.get_original_coordinates()
 175                != self.input_store.get_original_coordinates()
 176            ).any()
 177        ):
 178            raise RuntimeError(
 179                "The clip orders in the annotation store and input store are different!"
 180            )
 181        # filter the data based on data parameters
 182        # print(f"1 {self.annotation_store.get_original_coordinates().shape=}")
 183        # print(f"1 {self.input_store.get_original_coordinates().shape=}")
 184        to_remove = self.annotation_store.filtered_indices()
 185        if len(to_remove) > 0:
 186            print(
 187                f"Filtering {100 * len(to_remove) / len(self.annotation_store):.2f}% of samples"
 188            )
 189        if len(self.input_store) == len(self.annotation_store):
 190            self.input_store.remove(to_remove)
 191        self.annotation_store.remove(to_remove)
 192        self.input_indices = list(range(len(self.input_store)))
 193        self.annotation_indices = list(range(len(self.input_store)))
 194        self.indices = list(range(len(self.input_store)))
 195        # print(f'{data_parameters["video_order"]=}')
 196        # print(f"{self.annotation_store.get_original_coordinates().shape=}")
 197        # print(f"{self.input_store.get_original_coordinates().shape=}")
 198        # count = 0
 199        # for i, (x, y) in enumerate(zip(
 200        #     self.annotation_store.get_original_coordinates(),
 201        #     self.input_store.get_original_coordinates(),
 202        # )):
 203        #     if (x != y).any():
 204        #         count += 1
 205        #         print({i})
 206        #         print(f"ann: {x}")
 207        #         print(f"inp: {y}")
 208        #         print("\n")
 209        #     if count > 50:
 210        #         break
 211        if annotation_type != "none" and (
 212            self.annotation_store.get_original_coordinates().shape
 213            != self.input_store.get_original_coordinates().shape
 214            or (
 215                self.annotation_store.get_original_coordinates()
 216                != self.input_store.get_original_coordinates()
 217            ).any()
 218        ):
 219            raise RuntimeError(
 220                "The clip orders in the annotation store and input store are different!"
 221            )
 222
 223    def __getitem__(self, item: int) -> Dict:
 224        idx = self._get_idx(item)
 225        input = deepcopy(self.input_store[idx])
 226        target = self.annotation_store[idx]
 227        tag = self.input_store.get_tag(idx)
 228        ssl_inputs, ssl_targets = self._get_SSL_targets(input)
 229        batch = {"input": input}
 230        for name, x in zip(
 231            ["target", "ssl_inputs", "ssl_targets", "tag"],
 232            [target, ssl_inputs, ssl_targets, tag],
 233        ):
 234            if x is not None:
 235                batch[name] = x
 236        batch["index"] = idx
 237        if self.stats is not None:
 238            for key in batch["input"].keys():
 239                key_name = key.split("---")[0]
 240                if key_name in self.stats:
 241                    batch["input"][key][:, batch["input"][key].sum(0) != 0] = (
 242                        (batch["input"][key] - self.stats[key_name]["mean"])
 243                        / (self.stats[key_name]["std"] + 1e-7)
 244                    )[:, batch["input"][key].sum(0) != 0]
 245        return batch
 246
 247    def __len__(self) -> int:
 248        return len(self.indices)
 249        # if self.annotation_type != "none":
 250        #     return self.annotation_store.get_len(return_unlabeled=self.return_unlabeled)
 251        # else:
 252        #     return len(self.input_store)
 253
 254    def get_tags(self) -> List:
 255        """
 256        Get a list of all meta tags
 257
 258        Returns
 259        -------
 260        tags: List
 261            a list of unique meta tag values
 262        """
 263
 264        return self.input_store.get_tags()
 265
 266    def save(self, save_path: str) -> None:
 267        """
  268        Save the dataset
 269
 270        Parameters
 271        ----------
 272        save_path : str
 273            the path where the pickled file will be stored
 274        """
 275
 276        input_obj = self.input_store.key_objects()
 277        annotation_obj = self.annotation_store.key_objects()
 278        with open(save_path, "wb") as f:
 279            pickle.dump((input_obj, annotation_obj), f)
 280
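    # Illustrative usage sketch (the file name and type names are assumptions;
    # they must match the values used when the dataset was first created):
    #
    #     dataset.save("dataset_keys.pickle")
    #     restored = BehaviorDataset(
    #         data_type="dlc_track",
    #         annotation_type="csv",
    #         saved_data_path="dataset_keys.pickle",
    #     )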
 281    def to_ram(self) -> None:
 282        """
 283        Transfer the dataset to RAM
 284        """
 285
 286        self.input_store.to_ram()
 287        self.annotation_store.to_ram()
 288
 289    def generate_full_length_gt(self) -> Dict:
 290        if self.annotation_class() == "exclusive_classification":
 291            gt = torch.zeros((len(self), self.len_segment()))
 292        else:
 293            gt = torch.zeros(
 294                (len(self), len(self.behaviors_dict()), self.len_segment())
 295            )
 296        for i in range(len(self)):
 297            gt[i] = self.annotation_store[i]
 298        return self.generate_full_length_prediction(gt)
 299
 300    def generate_full_length_prediction(self, predicted: torch.Tensor) -> Dict:
 301        """
 302        Map predictions for the equal-length pieces to predictions for the original data
 303
 304        Probabilities are averaged over predictions on overlapping intervals.
 305
 306        Parameters
 307        ----------
 308        predicted: torch.Tensor
 309            a tensor of predicted probabilities of shape `(N, #classes, #frames)`
 310
 311        Returns
 312        -------
 313        full_length_prediction : dict
 314            a nested dictionary where first-level keys are video ids, second-level keys are clip ids and values are
 315            averaged probability tensors
 316        """
 317
 318        result = defaultdict(lambda: {})
 319        counter = defaultdict(lambda: {})
 320        coordinates = self.input_store.get_original_coordinates()
 321        for coords, prediction in zip(coordinates, predicted):
 322            l = self.input_store.get_clip_length_from_coords(coords)
 323            video_name = self.input_store.get_video_id(coords)
 324            clip_id = self.input_store.get_clip_id(coords)
 325            start, end = self.input_store.get_clip_start_end(coords)
 326            if clip_id not in result[video_name].keys():
 327                result[video_name][clip_id] = torch.zeros(*prediction.shape[:-1], l)
 328                counter[video_name][clip_id] = torch.zeros(*prediction.shape[:-1], l)
 329            result[video_name][clip_id][..., start:end] += (
 330                prediction.squeeze()[..., : end - start].detach().cpu()
 331            )
 332            counter[video_name][clip_id][..., start:end] += 1
 333        for video_name in result:
 334            for clip_id in result[video_name]:
 335                result[video_name][clip_id] /= counter[video_name][clip_id]
 336                result[video_name][clip_id][counter[video_name][clip_id] == 0] = -100
 337        result = dict(result)
 338        return result
 339
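    # Illustrative usage sketch (the video and clip ids are hypothetical):
    #
    #     probabilities = ...  # tensor of shape (N, #classes, len_segment)
    #     full = dataset.generate_full_length_prediction(probabilities)
    #     full["video_1"]["ind0"]  # tensor of shape (#classes, clip_length);
    #                              # overlapping segments are averaged and frames
    #                              # without any prediction are set to -100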
 340    def find_valleys(
 341        self,
 342        predicted: Union[torch.Tensor, Dict],
 343        threshold: float = 0.5,
 344        min_frames: int = 0,
 345        visibility_min_score: float = 0,
 346        visibility_min_frac: float = 0,
 347        main_class: int = 1,
 348        low: bool = True,
 349        predicted_error: torch.Tensor = None,
 350        error_threshold: float = 0.5,
 351        hysteresis: bool = False,
 352        threshold_diff: float = None,
 353        min_frames_error: int = None,
 354        smooth_interval: int = 1,
 355        cut_annotated: bool = False,
 356    ) -> Dict:
 357        """
  358        Find the intervals where the probability of a certain class is below or above a certain threshold
 359
 360        Parameters
 361        ----------
 362        predicted : torch.Tensor | dict
 363            either a tensor of predictions for the data prompts or the output of
 364            `BehaviorDataset.generate_full_length_prediction`
 365        threshold : float, default 0.5
  366            the main threshold
 367        min_frames : int, default 0
 368            the minimum length of the intervals
 369        visibility_min_score : float, default 0
 370            the minimum visibility score in the intervals
 371        visibility_min_frac : float, default 0
  372            the fraction of the interval that has to have a visibility score larger than visibility_min_score
 373        main_class : int, default 1
 374            the index of the class the function is inspecting
 375        low : bool, default True
  376            if True, the probability in the intervals has to be below the threshold, and if False, it has to be above
 377        predicted_error : torch.Tensor, optional
 378            a tensor of error predictions for the data prompts
 379        error_threshold : float, default 0.5
 380            maximum possible probability of error at the intervals
 381        hysteresis: bool, default False
  382            if True, the function will apply hysteresis thresholding with the soft threshold defined by threshold_diff
 383        threshold_diff: float, optional
  384            the difference between the soft and hard thresholds if hysteresis is used; if hysteresis is True, low is False and threshold_diff is None, the soft threshold condition is set to the main_class having a larger probability than the other classes
 385        min_frames_error: int, optional
 386            if not None, the intervals will only be considered where the error probability is below error_threshold at at least min_frames_error consecutive frames
 387
 388        Returns
 389        -------
 390        valleys : dict
 391            a dictionary where keys are video ids and values are lists of (start, end, individual name) tuples that denote the chosen intervals
 392        """
 393
 394        result = defaultdict(lambda: [])
 395        if type(predicted) is not dict:
 396            predicted = self.generate_full_length_prediction(predicted)
 397        if predicted_error is not None:
 398            predicted_error = self.generate_full_length_prediction(predicted_error)
 399        elif min_frames_error is not None and min_frames_error != 0:
 400            # warnings.warn(
 401            #     f"The min_frames_error parameter is set to {min_frames_error} but no error prediction "
 402            #     f"is given! Setting min_frames_error to 0."
 403            # )
 404            min_frames_error = 0
 405        if low and hysteresis and threshold_diff is None:
 406            raise ValueError(
 407                "Cannot set low=True, hysteresis=True and threshold_diff=None! Please set threshold_diff."
 408            )
 409        if cut_annotated:
 410            masked_intervals_dict = self.get_annotated_intervals()
 411        else:
 412            masked_intervals_dict = None
 413        print("Valleys found:")
 414        for v_id in predicted:
 415            for clip_id in predicted[v_id].keys():
 416                if predicted_error is not None:
 417                    error_mask = predicted_error[v_id][clip_id][1, :] < error_threshold
 418                    if min_frames_error is not None:
 419                        output, indices, counts = torch.unique_consecutive(
 420                            error_mask, return_inverse=True, return_counts=True
 421                        )
 422                        wrong_indices = torch.where(
 423                            output * (counts < min_frames_error)
 424                        )[0]
 425                        if len(wrong_indices) > 0:
 426                            for i in wrong_indices:
 427                                error_mask[indices == i] = False
 428                else:
 429                    error_mask = None
 430                if masked_intervals_dict is not None:
 431                    masked_intervals = masked_intervals_dict[v_id][clip_id]
 432                else:
 433                    masked_intervals = None
 434                if not hysteresis:
 435                    res_indices_start, res_indices_end = apply_threshold(
 436                        predicted[v_id][clip_id][main_class, :],
 437                        threshold,
 438                        low,
 439                        error_mask,
 440                        min_frames,
 441                        smooth_interval,
 442                        masked_intervals,
 443                    )
 444                elif threshold_diff is not None:
 445                    if low:
 446                        soft_threshold = threshold + threshold_diff
 447                    else:
 448                        soft_threshold = threshold - threshold_diff
 449                    res_indices_start, res_indices_end = apply_threshold_hysteresis(
 450                        predicted[v_id][clip_id][main_class, :],
 451                        soft_threshold,
 452                        threshold,
 453                        low,
 454                        error_mask,
 455                        min_frames,
 456                        smooth_interval,
 457                        masked_intervals,
 458                    )
 459                else:
 460                    res_indices_start, res_indices_end = apply_threshold_max(
 461                        predicted[v_id][clip_id],
 462                        threshold,
 463                        main_class,
 464                        error_mask,
 465                        min_frames,
 466                        smooth_interval,
 467                        masked_intervals,
 468                    )
 469                start = self.input_store.get_clip_start(v_id, clip_id)
 470                result[v_id] += [
 471                    [i + start, j + start, clip_id]
 472                    for i, j in zip(res_indices_start, res_indices_end)
 473                    if self.input_store.get_visibility(
 474                        v_id, clip_id, i, j, visibility_min_score
 475                    )
 476                    > visibility_min_frac
 477                ]
 478            result[v_id] = sorted(result[v_id])
 479            print(f"    {v_id}: {len(result[v_id])}")
 480        return dict(result)
 481
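    # Illustrative usage sketch (the class index and thresholds are arbitrary):
    #
    #     valleys = dataset.find_valleys(
    #         predicted=full,  # or the raw (N, #classes, #frames) prediction tensor
    #         main_class=2,
    #         threshold=0.6,
    #         low=False,       # keep intervals where p(class 2) is above 0.6
    #         min_frames=10,
    #     )
    #     # -> {"video_1": [[start, end, clip_id], ...], ...}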
 482    def valleys_union(self, valleys_list) -> Dict:
 483        """
  484        Find the union of a list of valleys dictionaries
 485
 486        Parameters
 487        ----------
 488        valleys_list : list
 489            a list of valleys dictionaries
 490
 491        Returns
 492        -------
  493        union : dict
  494            a new valleys dictionary with the union of the input intervals
 495        """
 496
 497        valleys_list = [x for x in valleys_list if x is not None]
 498        if len(valleys_list) == 1:
 499            return valleys_list[0]
 500        elif len(valleys_list) == 0:
 501            return {}
 502        union = {}
 503        keys_list = [set(valleys.keys()) for valleys in valleys_list]
 504        keys = set.union(*keys_list)
 505        for v_id in keys:
 506            res = []
 507            clips_list = [
 508                set([x[-1] for x in valleys[v_id]]) for valleys in valleys_list
 509            ]
 510            clips = set.union(*clips_list)
 511            for clip_id in clips:
 512                clip_intervals = [
 513                    x
 514                    for valleys in valleys_list
 515                    for x in valleys[v_id]
 516                    if x[-1] == clip_id
 517                ]
 518                v_len = self.input_store.get_clip_length(v_id, clip_id)
 519                arr = torch.zeros(v_len)
 520                for start, end, _ in clip_intervals:
 521                    arr[start:end] += 1
 522                output, indices, counts = torch.unique_consecutive(
 523                    arr > 0, return_inverse=True, return_counts=True
 524                )
 525                long_indices = torch.where(output)[0]
 526                res += [
 527                    (
 528                        (indices == i).nonzero(as_tuple=True)[0][0].item(),
 529                        (indices == i).nonzero(as_tuple=True)[0][-1].item(),
 530                        clip_id,
 531                    )
 532                    for i in long_indices
 533                ]
 534            union[v_id] = res
 535        return union
 536
 537    def valleys_intersection(self, valleys_list) -> Dict:
 538        """
 539        Find the intersection of two valleys dictionaries
 540
 541        Parameters
 542        ----------
 543        valleys_list : list
 544            a list of valleys dictionaries
 545
 546        Returns
 547        -------
 548        intersection : dict
 549            a new valleys dictionary with the intersection of the input intervals
 550        """
 551
 552        valleys_list = [x for x in valleys_list if x is not None]
 553        if len(valleys_list) == 1:
 554            return valleys_list[0]
 555        elif len(valleys_list) == 0:
 556            return {}
 557        intersection = {}
 558        keys_list = [set(valleys.keys()) for valleys in valleys_list]
 559        keys = set.intersection(*keys_list)
 560        for v_id in keys:
 561            res = []
 562            clips_list = [
 563                set([x[-1] for x in valleys[v_id]]) for valleys in valleys_list
 564            ]
 565            clips = set.intersection(*clips_list)
 566            for clip_id in clips:
 567                clip_intervals = [
 568                    x
 569                    for valleys in valleys_list
 570                    for x in valleys[v_id]
 571                    if x[-1] == clip_id
 572                ]
 573                v_len = self.input_store.get_clip_length(v_id, clip_id)
 574                arr = torch.zeros(v_len)
 575                for start, end, _ in clip_intervals:
 576                    arr[start:end] += 1
 577                output, indices, counts = torch.unique_consecutive(
 578                    arr, return_inverse=True, return_counts=True
 579                )
 580                long_indices = torch.where(output == 2)[0]
 581                res += [
 582                    (
 583                        (indices == i).nonzero(as_tuple=True)[0][0].item(),
 584                        (indices == i).nonzero(as_tuple=True)[0][-1].item(),
 585                        clip_id,
 586                    )
 587                    for i in long_indices
 588                ]
 589            intersection[v_id] = res
 590        return intersection
 591
 592    def partition_train_test_val(
 593        self,
 594        use_test: float = 0,
 595        split_path: str = None,
 596        method: str = "random",
 597        val_frac: float = 0,
 598        test_frac: float = 0,
 599        save_split: bool = False,
 600        normalize: bool = False,
 601        skip_normalization_keys: List = None,
 602        stats: Dict = None,
 603    ) -> Tuple:
 604        """
 605        Partition the dataset into three new datasets
 606
 607        Parameters
 608        ----------
 609        use_test : float, default 0
 610            The fraction of the test dataset to be used in training without labels
 611        split_path : str, optional
 612            The path to load the split information from (if `'file'` method is used) and to save it to
 613            (if `'save_split'` is `True`)
 614        method : {'random', 'random:test-from-name', 'random:test-from-name:{name}',
 615            'val-from-name:{val_name}:test-from-name:{test_name}',
 616            'random:equalize:segments', 'random:equalize:videos',
 617            'folders', 'time', 'time:strict', 'file'}
 618            The partitioning method:
 619            - `'random'`: sort videos into subsets randomly,
 620            - `'random:test-from-name'` (or `'random:test-from-name:{name}'`): sort videos into training and validation
 621                subsets randomly and create
  622                the test subset from the video ids that start with a specific substring (`'test'` by default, or `name`
 623                if provided),
 624            - `'random:equalize:segments'` and `'random:equalize:videos'`: sort videos into subsets randomly but
 625                making sure that for the rarest classes at least `0.8 * val_frac` of the videos/segments that contain
  626                occurrences of the class get into the validation subset and `0.8 * test_frac` get into the test subset;
  627                this is ensured for all classes in order of increasing number of occurrences until the validation and
 628                test subsets are full
 629            - `'val-from-name:{val_name}:test-from-name:{test_name}'`: create the validation and test
 630                subsets from the video ids that start with specific substrings (`val_name` for validation
 631                and `test_name` for test) and sort all other videos into the training subset
 632            - `'folders'`: read videos from folders named *test*, *train* and *val* into corresponding subsets,
 633            - `'time'`: split each video into training, validation and test subsequences,
 634            - `'time:strict'`: split each video into validation, test and training subsequences
 635                and throw out the last segments in validation and test (to get rid of overlaps),
 636            - `'file'`: split according to a split file.
 637        val_frac : float, default 0
 638            The fraction of the dataset to be used in validation
 639        test_frac : float, default 0
 640            The fraction of the dataset to be used in test
 641        save_split : bool, default False
 642            Save a split file if True
 643
 644        Returns
 645        -------
 646        train_dataset : BehaviorDataset
 647            train dataset
 648
  649        test_dataset : BehaviorDataset
  650            test dataset
  651
  652        val_dataset : BehaviorDataset
  653            validation dataset
 654        """
 655
 656        train_indices, test_indices, val_indices = self._partition_indices(
 657            split_path=split_path,
 658            method=method,
 659            val_frac=val_frac,
 660            test_frac=test_frac,
 661            save_split=save_split,
 662        )
 663        val_dataset = self._create_new_dataset(val_indices)
 664        test_dataset = self._create_new_dataset(test_indices)
 665        train_dataset = self._create_new_dataset(
 666            train_indices, ssl_indices=test_indices[: int(len(test_indices) * use_test)]
 667        )
 668        print("Number of samples:")
 669        print(f"    validation:")
 670        print(f"      {val_dataset.count_classes()}")
 671        print(f"    training:")
 672        print(f"      {train_dataset.count_classes()}")
 673        print(f"    test:")
 674        print(f"      {test_dataset.count_classes()}")
 675        if normalize:
 676            if stats is None:
 677                print("Computing normalization statistics...")
 678                stats = train_dataset.get_normalization_stats(skip_normalization_keys)
 679            else:
 680                print("Setting loaded normalization statistics...")
 681            train_dataset.set_normalization_stats(stats)
 682            val_dataset.set_normalization_stats(stats)
 683            test_dataset.set_normalization_stats(stats)
 684        return train_dataset, test_dataset, val_dataset
 685
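    # Illustrative usage sketch: normalization statistics are computed on the
    # training subset only and shared with the validation and test subsets
    # (the key name in skip_normalization_keys is an assumption):
    #
    #     train, test, val = dataset.partition_train_test_val(
    #         method="time:strict", val_frac=0.2, test_frac=0.1,
    #         normalize=True, skip_normalization_keys=["speed_direction"],
    #     )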
 686    def class_weights(self, proportional=False) -> List:
 687        """
  688        Calculate class weights in inverse proportion to the number of samples

  689        Returns
 690        -------
 691        weights: list
 692            a list of class weights
 693        """
 694
 695        items = sorted(
 696            [
 697                (k, v)
 698                for k, v in self.annotation_store.count_classes().items()
 699                if k != -100
 700            ]
 701        )
 702        if self.annotation_store.annotation_class() == "exclusive_classification":
 703            if not proportional:
 704                numerator = len(self.annotation_store)
 705            else:
 706                numerator = max([x[1] for x in items])
 707            weights = [numerator / (v + 1e-7) for _, v in items]
 708        else:
 709            items_zero = sorted(
 710                [
 711                    (k, v)
 712                    for k, v in self.annotation_store.count_classes(zeros=True).items()
 713                    if k != -100
 714                ]
 715            )
 716            if not proportional:
 717                numerators = defaultdict(lambda: len(self.annotation_store))
 718            else:
 719                numerators = {
 720                    item_one[0]: max(item_one[1], item_zero[1])
 721                    for item_one, item_zero in zip(items, items_zero)
 722                }
 723            weights = {}
 724            weights[0] = [numerators[k] / (v + 1e-7) for k, v in items_zero]
 725            weights[1] = [numerators[k] / (v + 1e-7) for k, v in items]
 726        return weights
 727
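    # Illustrative example of the weighting rule for exclusive classification
    # (the counts are hypothetical): if count_classes() returns {0: 800, 1: 200}
    # and the annotation store holds 1000 samples, the default weights are
    # [1000 / 800, 1000 / 200] = [1.25, 5.0]; with proportional=True the
    # numerator is the largest count instead, giving [1.0, 4.0].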
 728    def boundary_class_weight(self):
 729        if self.annotation_type != "none":
 730            f = self.annotation_store.data.flatten()
 731            _, inv = torch.unique_consecutive(f, return_inverse=True)
 732            boundary = torch.cat([torch.tensor([0]), torch.diff(inv)]).reshape(
 733                self.annotation_store.data.shape
 734            )
 735            boundary[..., 0] = 0
 736            cnt = Counter(boundary.flatten().numpy())
 737            return cnt[1] / cnt[0]
 738        else:
 739            return 0
 740
 741    def count_classes(self, bouts: bool = False) -> Dict:
 742        """
 743        Get a class counter dictionary
 744
 745        Parameters
 746        ----------
 747        bouts : bool, default False
  748            if `True`, bout counts are returned instead of frame counts
 749
 750        Returns
 751        -------
 752        count_dictionary : dict
 753            a dictionary with class indices as keys and frame or bout counts as values
 754        """
 755
 756        return self.annotation_store.count_classes(bouts=bouts)
 757
 758    def behaviors_dict(self) -> Dict:
 759        """
 760        Get a behavior dictionary
 761
 762        Returns
 763        -------
 764        dict
 765            behavior dictionary
 766        """
 767
 768        return self.annotation_store.behaviors_dict()
 769
 770    def bodyparts_order(self) -> List:
 771        try:
 772            return self.input_store.get_bodyparts()
 773        except:
 774            raise RuntimeError(
 775                f"The {self.input_type} input store does not have bodyparts implemented!"
 776            )
 777
 778    def features_shape(self) -> Dict:
 779        """
 780        Get the shapes of the input features
 781
 782        Returns
 783        -------
 784        shapes : Dict
 785            a dictionary with the shapes of the features
 786        """
 787
 788        sample = self.input_store[0]
 789        shapes = {k: v.shape for k, v in sample.items()}
 790        # for key, value in shapes.items():
 791        #     print(f'{key}: {value}')
 792        return shapes
 793
 794    def num_classes(self) -> int:
 795        """
 796        Get the number of classes in the data
 797
 798        Returns
 799        -------
 800        num_classes : int
 801            the number of classes
 802        """
 803
 804        return len(self.annotation_store.behaviors_dict())
 805
 806    def len_segment(self) -> int:
 807        """
 808        Get the segment length in the data
 809
 810        Returns
 811        -------
 812        len_segment : int
 813            the segment length
 814        """
 815
 816        sample = self.input_store[0]
 817        key = list(sample.keys())[0]
 818        return sample[key].shape[-1]
 819
 820    def set_ssl_transformations(self, ssl_transformations: List) -> None:
 821        """
 822        Set new SSL transformations
 823
 824        Parameters
 825        ----------
 826        ssl_transformations : list
 827            a list of functions that take a sample feature dictionary as input and output ssl_inputs and ssl_targets
 828            lists
 829        """
 830
 831        self.ssl_transformations = ssl_transformations
 832
 833    @classmethod
 834    def new(cls, *args, **kwargs):
 835        """
 836        Create a new object of the same class
 837
 838        Returns
 839        -------
 840        new_instance: BehaviorDataset
 841            a new instance of the same class
 842        """
 843
 844        return cls(*args, **kwargs)
 845
 846    @classmethod
 847    def get_parameters(cls, data_type: str, annotation_type: str) -> List:
 848        """
 849        Get parameters necessary for initialization
 850
 851        Parameters
 852        ----------
 853        data_type : str
 854            the data type
 855        annotation_type : str
 856            the annotation type
 857        """
 858
 859        input_features = options.input_stores[data_type].get_parameters()
 860        annotation_features = options.annotation_stores[
 861            annotation_type
 862        ].get_parameters()
 863        self_features = inspect.getfullargspec(cls.__init__).args
 864        return self_features + input_features + annotation_features
 865
 866    @staticmethod
 867    def data_types() -> List:
 868        """
 869        List available data types
 870
 871        Returns
 872        -------
 873        data_types : list
 874            available data types
 875        """
 876
 877        return list(options.input_stores.keys())
 878
 879    @staticmethod
 880    def annotation_types() -> List:
 881        """
 882        List available annotation types
 883
 884        Returns
 885        -------
 886        annotation_types : list
 887            available annotation types
 888        """
 889
 890        return list(options.annotation_stores.keys())
 891
 892    def _get_SSL_targets(self, input: Dict) -> Tuple[List, List]:
 893        """
 894        Get the SSL inputs and targets from a sample dictionary
 895        """
 896
 897        ssl_inputs = []
 898        ssl_targets = []
 899        for transform in self.ssl_transformations:
 900            ssl_input, ssl_target = transform(copy(input))
 901            ssl_inputs.append(ssl_input)
 902            ssl_targets.append(ssl_target)
 903        return ssl_inputs, ssl_targets
 904
 905    def _create_new_dataset(self, indices: List, ssl_indices: List = None):
 906        """
 907        Create a subsample of the dataset, with samples at ssl_indices losing the annotation
 908        """
 909
 910        if ssl_indices is None:
 911            ssl_indices = []
 912        input_store = self.input_store.create_subsample(indices, ssl_indices)
 913        annotation_store = self.annotation_store.create_subsample(indices, ssl_indices)
 914        new = self.new(
 915            data_type=self.input_type,
 916            annotation_type=self.annotation_type,
 917            ssl_transformations=self.ssl_transformations,
 918            annotation_store=annotation_store,
 919            input_store=input_store,
 920            ids=list(indices) + list(ssl_indices),
 921            recompute_annotation=False,
 922        )
 923        return new
 924
 925    def _load_input_store(self, data_type: str, key_objects: Tuple) -> InputStore:
 926        """
 927        Load input store from key objects
 928        """
 929
 930        input_store = options.input_stores[data_type](key_objects=key_objects)
 931        return input_store
 932
 933    def _load_annotation_store(
 934        self, annotation_type: str, key_objects: Tuple
 935    ) -> AnnotationStore:
 936        """
 937        Load annotation store from key objects
 938        """
 939
 940        annotation_store = options.annotation_stores[annotation_type](
 941            key_objects=key_objects
 942        )
 943        return annotation_store
 944
 945    def _get_input_store(self, data_type: str, data_parameters: Dict) -> InputStore:
 946        """
 947        Create input store from parameters
 948        """
 949
 950        data_parameters["key_objects"] = None
 951        input_store = options.input_stores[data_type](**data_parameters)
 952        return input_store
 953
 954    def _get_annotation_store(
 955        self, annotation_type: str, data_parameters: Dict
 956    ) -> AnnotationStore:
 957        """
 958        Create annotation store from parameters
 959        """
 960
 961        annotation_store = options.annotation_stores[annotation_type](**data_parameters)
 962        return annotation_store
 963
 964    def set_indexing_parameters(self, unlabeled: bool, tag: int) -> None:
 965        """
 966        Set the parameters that change the subset that is returned at `__getitem__`
 967
 968        Parameters
 969        ----------
 970        unlabeled : bool
 971            a pseudolabeling parameter; return only unlabeled samples if `True`, only labeled if `False` and
 972            all if `None`
 973        tag : int
 974            if not `None`, only samples with this meta tag will be returned
 975        """
 976
 977        if unlabeled != self.return_unlabeled:
 978            self.annotation_indices = self.annotation_store.get_indices(unlabeled)
 979            self.return_unlabeled = unlabeled
 980        if tag != self.tag:
 981            self.input_indices = self.input_store.get_indices(tag)
 982            self.tag = tag
 983        self.indices = [x for x in self.annotation_indices if x in self.input_indices]
 984
 985    def _get_idx(self, index: int) -> int:
 986        """
 987        Get index in full dataset
 988        """
 989
 990        return self.indices[index]
 991
 992        # return self.annotation_store.get_idx(
 993        #     index, return_unlabeled=self.return_unlabeled
 994        # )
 995
 996    def _partition_indices(
 997        self,
 998        split_path: str = None,
 999        method: str = "random",
1000        val_frac: float = 0,
1001        test_frac: float = 0,
1002        save_split: bool = False,
1003    ) -> Tuple[List, List, List]:
1004        """
1005        Partition indices into train, validation, test subsets
1006        """
1007
1008        if self.mask is not None:
1009            val_indices = self.mask["val_ids"]
1010            train_indices = [x for x in range(len(self)) if x not in val_indices]
1011            test_indices = []
1012        elif method == "random":
1013            videos = np.array(self.input_store.get_video_id_order())
1014            all_videos = list(set(videos))
1015            if len(all_videos) == 1:
1016                warnings.warn(
1017                    "There is only one video in the dataset, so train/val/test split is done on segments; "
1018                    'that might lead to overlaps, please consider using "time" or "time:strict" as the '
1019                    "partitioning method instead"
1020                )
1021                # Quick fix for single video: the problem with this is that the segments can overlap
1022                # length = int(self.input_store.get_original_coordinates()[-1][1])    # number of segments
1023                length = len(self.input_store.get_original_coordinates())
1024                val_len = int(val_frac * length)
1025                test_len = int(test_frac * length)
1026                all_indices = np.random.choice(np.arange(length), length, replace=False)
1027                val_indices = all_indices[:val_len]
1028                test_indices = all_indices[val_len : val_len + test_len]
1029                train_indices = all_indices[val_len + test_len :]
1030                coords = self.input_store.get_original_coordinates()
1031                if save_split:
1032                    self._save_partition(
1033                        coords[train_indices],
1034                        coords[val_indices],
1035                        coords[test_indices],
1036                        split_path,
1037                        coords=True,
1038                    )
1039            else:
1040                length = len(all_videos)
1041                val_len = int(val_frac * length)
1042                test_len = int(test_frac * length)
1043                validation = all_videos[:val_len]
1044                test = all_videos[val_len : val_len + test_len]
1045                training = all_videos[val_len + test_len :]
1046                train_indices = np.where(np.isin(videos, training))[0]
1047                val_indices = np.where(np.isin(videos, validation))[0]
1048                test_indices = np.where(np.isin(videos, test))[0]
1049                if save_split:
1050                    self._save_partition(training, validation, test, split_path)
1051        elif method.startswith("random:equalize"):
1052            counter = self.count_classes()
1053            counter = sorted(list([(v, k) for k, v in counter.items()]))
1054            classes = [x[1] for x in counter]
1055            indicator = {c: [] for c in classes}
1056            if method.endswith("videos"):
1057                videos = np.array(self.input_store.get_video_id_order())
1058                all_videos = list(set(videos))
1059                total_len = len(all_videos)
1060                for video_id in all_videos:
1061                    video_coords = np.where(videos == video_id)[0]
1062                    ann = torch.cat(
1063                        [self.annotation_store[i] for i in video_coords], dim=-1
1064                    )
1065                    for c in classes:
1066                        if self.annotation_class() == "nonexclusive_classification":
1067                            indicator[c].append(torch.sum(ann[c] == 1) > 0)
1068                        elif self.annotation_class() == "exclusive_classification":
1069                            indicator[c].append(torch.sum(ann == c) > 0)
1070                        else:
1071                            raise ValueError(
 1072                                f"The random:equalize partition method is not implemented "
1073                                f"for the {self.annotation_class()} annotation class"
1074                            )
1075            elif method.endswith("segments"):
1076                total_len = len(self)
1077                for ann in self.annotation_store:
1078                    for c in classes:
1079                        if self.annotation_class() == "nonexclusive_classification":
1080                            indicator[c].append(torch.sum(ann[c] == 1) > 0)
1081                        elif self.annotation_class() == "exclusive_classification":
1082                            indicator[c].append(torch.sum(ann == c) > 0)
1083                        else:
1084                            raise ValueError(
 1085                                f"The random:equalize partition method is not implemented "
1086                                f"for the {self.annotation_class()} annotation class"
1087                            )
1088            else:
1089                values = []
1090                for v in options.partition_methods.values():
1091                    values += v
1092                raise ValueError(
1093                    f"The {method} partition method is not recognized; please choose from {values}"
1094                )
1095            val_indices = []
1096            test_indices = []
1097            for c in classes:
1098                indicator[c] = np.array(indicator[c])
1099                ind = np.where(indicator[c])[0]
1100                np.random.shuffle(ind)
1101                c_sum = len(ind)
1102                in_val = np.sum(indicator[c][val_indices])
1103                in_test = np.sum(indicator[c][test_indices])
1104                while (
1105                    len(val_indices) < val_frac * total_len
1106                    and in_val < val_frac * c_sum * 0.8
1107                ):
1108                    first, ind = ind[0], ind[1:]
1109                    val_indices = list(set(val_indices).union({first}))
1110                    in_val = np.sum(indicator[c][val_indices])
1111                while (
1112                    len(test_indices) < test_frac * total_len
1113                    and in_test < test_frac * c_sum * 0.8
1114                ):
1115                    first, ind = ind[0], ind[1:]
1116                    test_indices = list(set(test_indices).union({first}))
1117                    in_test = np.sum(indicator[c][test_indices])
1118            if len(val_indices) < int(val_frac * total_len):
1119                left_val = int(val_frac * total_len) - len(val_indices)
1120            else:
1121                left_val = 0
1122            if len(test_indices) < int(test_frac * total_len):
1123                left_test = int(test_frac * total_len) - len(test_indices)
1124            else:
1125                left_test = 0
1126            indicator = np.ones(total_len)
1127            indicator[val_indices] = 0
1128            indicator[test_indices] = 0
1129            ind = np.where(indicator)[0]
1130            np.random.shuffle(ind)
1131            val_indices += list(ind[:left_val])
1132            test_indices += list(ind[left_val : left_val + left_test])
1133            train_indices = list(ind[left_val + left_test :])
1134            if save_split:
1135                if method.endswith("segments"):
1136                    coords = self.input_store.get_original_coordinates()
1137                    self._save_partition(
1138                        coords[train_indices],
1139                        coords[val_indices],
1140                        coords[test_indices],
 1141                        split_path,
1142                        coords=True,
1143                    )
1144                else:
1145                    all_videos = np.array(all_videos)
1146                    validation = all_videos[val_indices]
1147                    test = all_videos[test_indices]
1148                    training = all_videos[train_indices]
1149                    self._save_partition(training, validation, test, split_path)
1150        elif method.startswith("random:test-from-name"):
1151            split = method.split(":")
1152            if len(split) > 2:
1153                test_name = split[-1]
1154            else:
1155                test_name = "test"
1156            videos = np.array(self.input_store.get_video_id_order())
1157            all_videos = list(set(videos))
1158            test = []
1159            train_videos = []
1160            for x in all_videos:
1161                if x.startswith(test_name):
1162                    test.append(x)
1163                else:
1164                    train_videos.append(x)
1165            length = len(train_videos)
1166            val_len = int(val_frac * length)
1167            validation = train_videos[:val_len]
1168            training = train_videos[val_len:]
1169            train_indices = np.where(np.isin(videos, training))[0]
1170            val_indices = np.where(np.isin(videos, validation))[0]
1171            test_indices = np.where(np.isin(videos, test))[0]
1172            if save_split:
1173                self._save_partition(training, validation, test, split_path)
1174        elif method.startswith("val-from-name"):
1175            split = method.split(":")
1176            if split[2] != "test-from-name":
1177                raise ValueError(
1178                    f"The {method} partition method is not recognized, please choose from {options.partition_methods}"
1179                )
1180            val_name = split[1]
1181            test_name = split[-1]
1182            videos = np.array(self.input_store.get_video_id_order())
1183            all_videos = list(set(videos))
1184            test = []
1185            validation = []
1186            training = []
1187            for x in all_videos:
1188                if x.startswith(test_name):
1189                    test.append(x)
1190                elif x.startswith(val_name):
1191                    validation.append(x)
1192                else:
1193                    training.append(x)
1194            train_indices = np.where(np.isin(videos, training))[0]
1195            val_indices = np.where(np.isin(videos, validation))[0]
1196            test_indices = np.where(np.isin(videos, test))[0]
1197        elif method == "folders":
1198            folders = np.array(self.input_store.get_folder_order())
1199            videos = np.array(self.input_store.get_video_id_order())
1200            train_indices = np.where(np.isin(folders, ["training", "train"]))[0]
1201            if np.sum(np.isin(folders, ["validation", "val"])) > 0:
1202                val_indices = np.where(np.isin(folders, ["validation", "val"]))[0]
1203            else:
1204                train_videos = list(set(videos[train_indices]))
1205                val_len = int(val_frac * len(train_videos))
1206                validation = train_videos[:val_len]
1207                training = train_videos[val_len:]
1208                train_indices = np.where(np.isin(videos, training))[0]
1209                val_indices = np.where(np.isin(videos, validation))[0]
1210            test_indices = np.where(folders == "test")[0]
1211            if save_split:
1212                self._save_partition(
1213                    list(set(videos[train_indices])),
1214                    list(set(videos[val_indices])),
1215                    list(set(videos[test_indices])),
1216                    split_path,
1217                )
1218        elif method.startswith("leave-one-out"):
1219            n = int(method.split(":")[-1])
1220            videos = np.array(self.input_store.get_video_id_order())
1221            all_videos = sorted(list(set(videos)))
1222            validation = [all_videos.pop(n)]
1223            training = all_videos
1224            train_indices = np.where(np.isin(videos, training))[0]
1225            val_indices = np.where(np.isin(videos, validation))[0]
1226            test_indices = np.array([])
1227        elif method.startswith("time"):
1228            if method.endswith("strict"):
1229                len_segment = self.len_segment()
1230                # TODO: change this
1231                step = self.input_store.step
1232                num_removed = len_segment // step
1233            else:
1234                num_removed = 0
1235            videos = np.array(self.input_store.get_video_id_order())
1236            all_videos = set(videos)
1237            train_indices = []
1238            val_indices = []
1239            test_indices = []
1240            start = 0
1241            if len(method.split(":")) > 1 and method.split(":")[1] == "start-from":
1242                start = float(method.split(":")[2])
1243            for video_id in all_videos:
1244                video_indices = np.where(videos == video_id)[0]
1245                val_len = int(val_frac * len(video_indices))
1246                test_len = int(test_frac * len(video_indices))
1247                start_pos = int(start * len(video_indices))
1248                all_ind = np.ones(len(video_indices))
1249                val_indices += list(video_indices[start_pos : start_pos + val_len])
1250                all_ind[start_pos : start_pos + val_len] = 0
1251                if start_pos + val_len > len(video_indices):
1252                    p = start_pos + val_len - len(video_indices)
1253                    val_indices += list(video_indices[:p])
1254                    all_ind[:p] = 0
1255                else:
1256                    p = start_pos + val_len
1257                test_indices += list(video_indices[p : p + test_len])
1258                all_ind[p : p + test_len] = 0
1259                if p + test_len > len(video_indices):
1260                    p = test_len + p - len(video_indices)
1261                    test_indices += list(video_indices[:p])
1262                    all_ind[:p] = 0
1263                train_indices += list(video_indices[all_ind > 0])
1264                for _ in range(num_removed):
1265                    if len(val_indices) > 0:
1266                        val_indices.pop(-1)
1267                    if len(test_indices) > 0:
1268                        test_indices.pop(-1)
1269                    if start > 0 and len(train_indices) > 0:
1270                        train_indices.pop(-1)
1271        elif method == "file":
1272            if split_path is None:
1273                raise ValueError(
1274                    'You need to either set split_path or change partition method ("file" requires a file)'
1275                )
1276            with open(split_path) as f:
1277                train_line = f.readline()
1278                line = f.readline()
1279                while not line.startswith("Validation") and not line.startswith(
1280                    "Validataion"  # also accept this misspelling for backward compatibility
1281                ):
1282                    line = f.readline()
1283                if line.startswith("Validation"):
1284                    validation = []
1285                    test = []
1286                    while True:
1287                        line = f.readline()
1288                        if line.startswith("Test") or len(line) == 0:
1289                            break
1290                        validation.append(line.rstrip())
1291                    while True:
1292                        line = f.readline()
1293                        if len(line) == 0:
1294                            break
1295                        test.append(line.rstrip())
1296                    type = train_line[9:-2]
1297                else:
1298                    line = f.readline()
1299                    validation = line.rstrip(",\n ").split(", ")
1300                    test = []
1301                    type = "videos"
1302            if type == "videos":
1303                videos = np.array(self.input_store.get_video_id_order())
1304                val_indices = np.where(np.isin(videos, validation))[0]
1305                test_indices = np.where(np.isin(videos, test))[0]
1306            elif type == "coords":
1307                coords = self.input_store.get_original_coordinates()
1308                video_ids = self.input_store.get_video_id_order()
1309                clip_ids = [self.input_store.get_clip_id(coord) for coord in coords]
1310                starts, ends = zip(
1311                    *[self.input_store.get_clip_start_end(coord) for coord in coords]
1312                )
1313                coords = np.array(
1314                    [
1315                        f"{video_id}---{clip_id}---{start}-{end}"
1316                        for video_id, clip_id, start, end in zip(
1317                            video_ids, clip_ids, starts, ends
1318                        )
1319                    ]
1320                )
1321                val_indices = np.where(np.isin(coords, validation))[0]
1322                test_indices = np.where(np.isin(coords, test))[0]
1323            else:
1324                raise ValueError("The split path has unrecognized format!")
1325            all = np.ones(len(self))
1326            all[val_indices] = 0
1327            all[test_indices] = 0
1328            train_indices = np.where(all)[0]
1329        else:
1330            raise ValueError(
1331                f"The {method} partition is not recognized, please choose from {options.partition_methods}"
1332            )
1333        return sorted(train_indices), sorted(test_indices), sorted(val_indices)
1334
1335    def _save_partition(
1336        self,
1337        training: List,
1338        validation: List,
1339        test: List,
1340        split_path: str,
1341        coords: bool = False,
1342    ) -> None:
1343        """
1344        Save a split file
1345        """
1346
1347        if coords:
1348            name = "coords"
1349            training_coords = []
1350            val_coords = []
1351            test_coords = []
1352            for coord in training:
1353                video_id = self.input_store.get_video_id(coord)
1354                clip_id = self.input_store.get_clip_id(coord)
1355                start, end = self.input_store.get_clip_start_end(coord)
1356                training_coords.append(f"{video_id}---{clip_id}---{start}-{end}")
1357            for coord in validation:
1358                video_id = self.input_store.get_video_id(coord)
1359                clip_id = self.input_store.get_clip_id(coord)
1360                start, end = self.input_store.get_clip_start_end(coord)
1361                val_coords.append(f"{video_id}---{clip_id}---{start}-{end}")
1362            for coord in test:
1363                video_id = self.input_store.get_video_id(coord)
1364                clip_id = self.input_store.get_clip_id(coord)
1365                start, end = self.input_store.get_clip_start_end(coord)
1366                test_coords.append(f"{video_id}---{clip_id}---{start}-{end}")
1367            training, validation, test = training_coords, val_coords, test_coords
1368        else:
1369            name = "videos"
1370        if split_path is not None:
1371            with open(split_path, "w") as f:
1372                f.write(f"Training {name}:\n")
1373                for x in training:
1374                    f.write(x + "\n")
1375                f.write(f"Validation {name}:\n")
1376                for x in validation:
1377                    f.write(x + "\n")
1378                f.write(f"Test {name}:\n")
1379                for x in test:
1380                    f.write(x + "\n")
1381
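# --- Example (not part of the source): the split file layout used by the "file" method ---
# A minimal sketch with hypothetical video ids; `_save_partition` writes this format and
# `_partition_indices(method="file")` reads it back ("videos" may be replaced by "coords").
split_text = "\n".join(
    [
        "Training videos:",
        "video_01",
        "video_02",
        "Validation videos:",
        "video_03",
        "Test videos:",
        "video_04",
        "",
    ]
)
with open("split.txt", "w") as f:  # "split.txt" is a placeholder path
    f.write(split_text)
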
1382    def _get_intervals_from_ind(self, frame_indices: np.ndarray):
1383        """
1384        Get a list of intervals from a list of frame indices
1385
1386        Example: `[0, 1, 2, 5, 6, 8] -> [[0, 3], [5, 7], [8, 9]]`.
1387
1388        Parameters
1389        ----------
1390        frame_indices : np.ndarray
1391            a list of frame indices
1392
1393        Returns
1394        -------
1395        intervals : list
1396            a list of interval boundaries
1397        """
1398
1399        masked_intervals = []
1400        breaks = np.where(np.diff(frame_indices) != 1)[0]
1401        if len(frame_indices) > 0:
1402            start = frame_indices[0]
1403            for k in breaks:
1404                masked_intervals.append([start, frame_indices[k] + 1])
1405                start = frame_indices[k + 1]
1406            masked_intervals.append([start, frame_indices[-1] + 1])
1407        return masked_intervals
1408
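# --- Example (not part of the source): the index-to-interval conversion in isolation ---
# A standalone sketch of the same logic as `_get_intervals_from_ind`, handy for quick checks.
import numpy as np

def intervals_from_indices(frame_indices) -> list:
    intervals = []
    breaks = np.where(np.diff(frame_indices) != 1)[0]
    if len(frame_indices) > 0:
        start = frame_indices[0]
        for k in breaks:
            intervals.append([start, frame_indices[k] + 1])
            start = frame_indices[k + 1]
        intervals.append([start, frame_indices[-1] + 1])
    return intervals

print(intervals_from_indices([0, 1, 2, 5, 6, 8]))  # [[0, 3], [5, 7], [8, 9]]
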
1409    def get_intervals(self) -> Tuple[dict, Optional[list]]:
1410        """
1411        Get a list of intervals covered by the dataset in the original coordinates
1412
1413        Returns
1414        -------
1415        intervals : dict
1416            a nested dictionary where first-level keys are video ids, second-level keys are clip ids and
1417            values are lists of the intervals in `[start, end]` format
        ids : list, optional
            the `ids` list the dataset was created with (`None` if it was not set)
1418        """
1419
1420        counter = defaultdict(lambda: {})
1421        coordinates = self.input_store.get_original_coordinates()
1422        for coords in coordinates:
1423            l = self.input_store.get_clip_length_from_coords(coords)
1424            video_name = self.input_store.get_video_id(coords)
1425            clip_id = self.input_store.get_clip_id(coords)
1426            start, end = self.input_store.get_clip_start_end(coords)
1427            if clip_id not in counter[video_name]:
1428                counter[video_name][clip_id] = np.zeros(l)
1429            counter[video_name][clip_id][start:end] = 1
1430        result = {video_name: {} for video_name in counter}
1431        for video_name in counter:
1432            for clip_id in counter[video_name]:
1433                result[video_name][clip_id] = self._get_intervals_from_ind(
1434                    np.where(counter[video_name][clip_id])[0]
1435                )
1436        return result, self.ids
1437
1438    def get_unannotated_intervals(self, first_intervals=None) -> Dict:
1439        """
1440        Get a list of intervals in the original coordinates where there is no annotation
1441
        Parameters
        ----------
        first_intervals : dict, optional
            a nested dictionary of intervals in the same format as the output; if given, only the unannotated
            parts of those intervals are returned

1442        Returns
1443        -------
1444        intervals : dict
1445            a nested dictionary where first-level keys are video ids, second-level keys are clip ids and
1446            values are lists of the intervals in `[start, end]` format
1447        """
1448
1449        counter_value = 2
1450        if first_intervals is None:
1451            first_intervals = defaultdict(lambda: defaultdict(lambda: []))
1452            counter_value = 1
1453        counter = defaultdict(lambda: {})
1454        coordinates = self.input_store.get_original_coordinates()
1455        for i, coords in enumerate(coordinates):
1456            l = self.input_store.get_clip_length_from_coords(coords)
1457            ann = self.annotation_store[i]
1458            if (
1459                self.annotation_store.annotation_class()
1460                == "nonexclusive_classification"
1461            ):
1462                ann = ann[0, :]
1463            video_name = self.input_store.get_video_id(coords)
1464            clip_id = self.input_store.get_clip_id(coords)
1465            start, end = self.input_store.get_clip_start_end(coords)
1466            if clip_id not in counter[video_name]:
1467                counter[video_name][clip_id] = np.ones(l)
1468            counter[video_name][clip_id][start:end] = (ann[: end - start] == -100).int()
1469        result = {video_name: {} for video_name in counter}
1470        for video_name in counter:
1471            for clip_id in counter[video_name]:
1472                for start, end in first_intervals[video_name][clip_id]:
1473                    counter[video_name][clip_id][start:end] += 1
1474                result[video_name][clip_id] = self._get_intervals_from_ind(
1475                    np.where(counter[video_name][clip_id] == counter_value)[0]
1476                )
1477        return result
1478
1479    def get_annotated_intervals(self) -> Dict:
1480        """
1481        Get a list of intervals in the original coordinates where there is annotation
1482
1483        Returns
1484        -------
1485        intervals : dict
1486            a nested dictionary where first-level keys are video ids, second-level keys are clip ids and
1487            values are lists of the intervals in `[start, end]` format
1488        """
1489
1490        if self.annotation_type == "none":
1491            return {}
1492        counter_value = 1
1493        counter = defaultdict(lambda: {})
1494        coordinates = self.input_store.get_original_coordinates()
1495        for i, coords in enumerate(coordinates):
1496            l = self.input_store.get_clip_length_from_coords(coords)
1497            ann = self.annotation_store[i]
1498            video_name = self.input_store.get_video_id(coords)
1499            clip_id = self.input_store.get_clip_id(coords)
1500            start, end = self.input_store.get_clip_start_end(coords)
1501            if clip_id not in counter[video_name]:
1502                counter[video_name][clip_id] = np.zeros(l)
1503            if (
1504                self.annotation_store.annotation_class()
1505                == "nonexclusive_classification"
1506            ):
1507                counter[video_name][clip_id][start:end] = (
1508                    torch.sum(ann[:, : end - start] != -100, dim=0) > 0
1509                ).int()
1510            else:
1511                counter[video_name][clip_id][start:end] = (
1512                    ann[: end - start] != -100
1513                ).int()
1514        result = {video_name: {} for video_name in counter}
1515        for video_name in counter:
1516            for clip_id in counter[video_name]:
1517                result[video_name][clip_id] = self._get_intervals_from_ind(
1518                    np.where(counter[video_name][clip_id] == counter_value)[0]
1519                )
1520        return result
1521
1522    def get_ids(self) -> Dict:
1523        """
1524        Get a dictionary of all clip ids in the dataset
1525
1526        Returns
1527        -------
1528        ids : dict
1529            a dictionary where keys are video ids and values are lists of clip ids
1530        """
1531
1532        coordinates = self.input_store.get_original_coordinates()
1533        video_ids = np.array(self.input_store.get_video_id_order())
1534        id_set = set(video_ids)
1535        result = {}
1536        for video_id in id_set:
1537            coords = coordinates[video_ids == video_id]
1538            clip_ids = list({self.input_store.get_clip_id(c) for c in coords})
1539            result[video_id] = clip_ids
1540        return result
1541
1542    def get_len(self, video_id: str, clip_id: str) -> int:
1543        """
1544        Get the length of a specific clip
1545
1546        Parameters
1547        ----------
1548        video_id : str
1549            the video id
1550        clip_id : str
1551            the clip id
1552
1553        Returns
1554        -------
1555        length : int
1556            the length
1557        """
1558
1559        return self.input_store.get_clip_length(video_id, clip_id)
1560
1561    def get_confusion_matrix(
1562        self, prediction: torch.Tensor, confusion_type: str = "recall"
1563    ) -> Tuple[ndarray, list, Optional[str]]:
1564        """
1565        Get a confusion matrix
1566
1567        Parameters
1568        ----------
1569        prediction : torch.Tensor
1570            a tensor of predicted class probabilities of shape `(#samples, #classes, #frames)`
1571        confusion_type : {"recall", "precision"}
1572            for datasets with non-exclusive annotation, if `confusion_type` is `"recall"`, only false positives are
1573            taken into account, and if `confusion_type` is `"precision"`, only false negatives
1574
1575        Returns
1576        -------
1577        confusion_matrix : np.ndarray
1578            a confusion matrix of shape `(#classes, #classes)` where `A[i, j] = F_ij/N_i`, `F_ij` is the number of
1579            frames that have the i-th label in the ground truth and a false positive j-th label in the prediction,
1580            `N_i` is the number of frames that have the i-th label in the ground truth
1581        classes : list
1582            a list of classes
        confusion_type : str, optional
            the confusion type that was actually used (`None` for exclusive classification)
1583        """
1584
1585        behaviors_dict = self.annotation_store.behaviors_dict()
1586        num_behaviors = len(behaviors_dict)
1587        confusion_matrix = np.zeros((num_behaviors, num_behaviors))
1588        if self.annotation_store.annotation_class() == "exclusive_classification":
1589            exclusive = True
1590            confusion_type = None
1591        elif self.annotation_store.annotation_class() == "nonexclusive_classification":
1592            exclusive = False
1593        else:
1594            raise RuntimeError(
1595                f"The {self.annotation_store.annotation_class()} annotation class is not recognized!"
1596            )
1597        for ann, p in zip(self.annotation_store, prediction):
1598            if exclusive:
1599                class_prediction = torch.max(p, dim=0)[1]
1600                for i in behaviors_dict.keys():
1601                    for j in behaviors_dict.keys():
1602                        confusion_matrix[i, j] += int(
1603                            torch.sum(class_prediction[ann == i] == j)
1604                        )
1605            else:
1606                class_prediction = (p > 0.5).int()
1607                for i in behaviors_dict.keys():
1608                    for j in behaviors_dict.keys():
1609                        if confusion_type == "recall":
1610                            pred = deepcopy(class_prediction[j])
1611                            if i != j:
1612                                pred[ann[j] == 1] = 0
1613                            confusion_matrix[i, j] += int(torch.sum(pred[ann[i] == 1]))
1614                        elif confusion_type == "precision":
1615                            annotation = deepcopy(ann[j])
1616                            if i != j:
1617                                annotation[class_prediction[j] == 1] = 0
1618                            confusion_matrix[i, j] += int(
1619                                torch.sum(annotation[class_prediction[i] == 1])
1620                            )
1621                        else:
1622                            raise ValueError(
1623                                f"The {confusion_type} type is not recognized; please choose from ['recall', 'precision']"
1624                            )
1625        counter = self.annotation_store.count_classes()
1626        for i in behaviors_dict.keys():
1627            if counter[i] != 0:
1628                if confusion_type == "recall" or confusion_type is None:
1629                    confusion_matrix[i, :] /= counter[i]
1630                else:
1631                    confusion_matrix[:, i] /= counter[i]
1632        return confusion_matrix, list(behaviors_dict.values()), confusion_type
1633
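# --- Example (not part of the source): visualizing a confusion matrix ---
# A minimal sketch; `dataset` is a placeholder for a constructed BehaviorDataset and `prediction`
# is a placeholder tensor of class probabilities with shape (#samples, #classes, #frames).
import matplotlib.pyplot as plt

matrix, classes, used_type = dataset.get_confusion_matrix(prediction, confusion_type="recall")
fig, ax = plt.subplots()
ax.imshow(matrix)
ax.set_xticks(range(len(classes)))
ax.set_xticklabels(classes, rotation=90)
ax.set_yticks(range(len(classes)))
ax.set_yticklabels(classes)
ax.set_xlabel("predicted")
ax.set_ylabel("ground truth")
plt.show()
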
1634    def annotation_class(self) -> str:
1635        """
1636        Get the type of annotation ('exclusive_classification', 'nonexclusive_classification', more coming soon)
1637
1638        Returns
1639        -------
1640        annotation_class : str
1641            the type of annotation
1642        """
1643
1644        return self.annotation_store.annotation_class()
1645
1646    def set_normalization_stats(self, stats: Dict) -> None:
1647        """
1648        Set the stats to normalize data at runtime
1649
1650        Parameters
1651        ----------
1652        stats : dict
1653            a nested dictionary where first-level keys are feature key names, second-level keys are 'mean' and 'std'
1654            and values are the statistics in `torch` tensors of shape `(#features, 1)`
1655        """
1656
1657        self.stats = stats
1658
1659    def get_min_max_frames(self, video_id) -> Tuple[Dict, Dict]:
        """Get two dictionaries mapping each clip id of a video to its first and last frame indices."""
1660        coords = self.input_store.get_original_coordinates()
1661        clips = set(
1662            [
1663                self.input_store.get_clip_id(c)
1664                for c in coords
1665                if self.input_store.get_video_id(c) == video_id
1666            ]
1667        )
1668        min_frames = {}
1669        max_frames = {}
1670        for clip in clips:
1671            start = self.input_store.get_clip_start(video_id, clip)
1672            end = start + self.input_store.get_clip_length(video_id, clip)
1673            min_frames[clip] = start
1674            max_frames[clip] = end - 1
1675        return min_frames, max_frames
1676
1677    def get_normalization_stats(self, skip_keys=None) -> Dict:
1678        """
1679        Get mean and standard deviation for each key
1680
        Parameters
        ----------
        skip_keys : list, optional
            a list of feature key names to exclude from the statistics

1681        Returns
1682        -------
1683        stats : dict
1684            a nested dictionary where first-level keys are feature key names, second-level keys are 'mean' and 'std'
1685            and values are the statistics in `torch` tensors of shape `(#features, 1)`
1686        """
1687
1688        stats = defaultdict(lambda: {})
1689        sums = defaultdict(lambda: 0)
1690        if skip_keys is None:
1691            skip_keys = []
1692        counter = defaultdict(lambda: 0)
1693        for sample in tqdm(self):
1694            for key, value in sample["input"].items():
1695                key_name = key.split("---")[0]
1696                if key_name not in skip_keys:
1697                    sums[key_name] += value[:, value.sum(0) != 0].sum(-1)
1698                counter[key_name] += torch.sum(value.sum(0) != 0)
1699        for key, value in sums.items():
1700            stats[key]["mean"] = (value / counter[key]).unsqueeze(-1)
1701        sums = defaultdict(lambda: 0)
1702        for sample in tqdm(self):
1703            for key, value in sample["input"].items():
1704                key_name = key.split("---")[0]
1705                if key_name not in skip_keys:
1706                    sums[key_name] += (
1707                        (value[:, value.sum(0) != 0] - stats[key_name]["mean"]) ** 2
1708                    ).sum(-1)
1709        for key, value in sums.items():
1710            stats[key]["std"] = np.sqrt(value.unsqueeze(-1) / counter[key])
1711        return stats
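# --- Example (not part of the source): reusing training normalization statistics ---
# A minimal sketch; `train_dataset` and `val_dataset` are placeholders for datasets returned by
# `partition_train_test_val(..., normalize=False)`, and "loaded" is a hypothetical feature key to skip.
stats = train_dataset.get_normalization_stats(skip_keys=["loaded"])
train_dataset.set_normalization_stats(stats)
val_dataset.set_normalization_stats(stats)  # always normalize validation data with training statistics
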
class BehaviorDataset(Dataset, ABC):
  36class BehaviorDataset(Dataset, ABC):
  37    """
  38    A generalized dataset class
  39
  40    Data and annotation are stored in separate InputStore and AnnotationStore objects; the dataset class
  41    manages their interactions.
  42    """
  43
  44    def __init__(
  45        self,
  46        data_type: str,
  47        annotation_type: str = "none",
  48        ssl_transformations: List = None,
  49        saved_data_path: str = None,
  50        input_store: InputStore = None,
  51        annotation_store: AnnotationStore = None,
  52        only_load_annotated: bool = False,
  53        recompute_annotation: bool = False,
  54        # mask: str = None,
  55        ids: List = None,
  56        **data_parameters,
  57    ) -> None:
  58        """
  59        Parameters
  60        ----------
  61        data_type : str
  62            the data type (see available types by running BehaviorDataset.data_types())
  63        annotation_type : str
  64            the annotation type (see available types by running BehaviorDataset.annotation_types())
  65        ssl_transformations : list
  66            a list of functions that take a sample dictionary as input and return an (ssl input, ssl target) tuple
  67        saved_data_path : str
  68            the path to a pre-computed pickled dataset
  69        input_store : InputStore
  70            a pre-computed input store
  71        annotation_store : AnnotationStore
  72            a precomputed annotation store
  73        only_load_annotated : bool
  74            if True, the input files that don't have a matching annotation file will be disregarded
  75        *data_parameters : dict
  76            parameters to initialize the input and annotation stores
  77        """
  78
  79        mask = None
  80        if len(data_parameters) == 0:
  81            recompute_annotation = False
  82        feature_extraction = data_parameters.get("feature_extraction")
  83        if feature_extraction is not None and not issubclass(
  84            options.input_stores[data_type],
  85            options.feature_extractors[feature_extraction].input_store_class,
  86        ):
  87            raise ValueError(
  88                f"The {feature_extraction} feature extractor does not work with "
  89                f"the {data_type} data type, please choose a subclass of "
  90                f"{options.feature_extractors[feature_extraction].input_store_class}"
  91            )
  92        if ssl_transformations is None:
  93            ssl_transformations = []
  94        self.ssl_transformations = ssl_transformations
  95        self.input_type = data_type
  96        self.annotation_type = annotation_type
  97        self.stats = None
  98        if mask is not None:
  99            with open(mask, "rb") as f:
 100                self.mask = pickle.load(f)
 101        else:
 102            self.mask = None
 103        self.ids = ids
 104        self.tag = None
 105        self.return_unlabeled = None
 106        # load saved key objects for annotation and input if they exist
 107        input_key_objects, annotation_key_objects = None, None
 108        if saved_data_path is not None:
 109            if os.path.exists(saved_data_path):
 110                with open(saved_data_path, "rb") as f:
 111                    input_key_objects, annotation_key_objects = pickle.load(f)
 112        # if the input or the annotation store need to be created, generate the common video order
 113        if len(data_parameters) > 0:
 114            input_files = options.input_stores[data_type].get_file_ids(
 115                **data_parameters
 116            )
 117            annotation_files = options.annotation_stores[annotation_type].get_file_ids(
 118                **data_parameters
 119            )
 120            if only_load_annotated:
 121                data_parameters["video_order"] = [
 122                    x for x in input_files if x in annotation_files
 123                ]
 124            else:
 125                data_parameters["video_order"] = input_files
 126            if len(data_parameters["video_order"]) == 0:
 127                raise RuntimeError(
 128                    "The length of file list is 0! Please check your data parameters!"
 129                )
 130        data_parameters["mask"] = self.mask
 131        # load or create the input store
 132        ok = False
 133        if input_store is not None:
 134            self.input_store = input_store
 135            ok = True
 136        elif input_key_objects is not None:
 137            try:
 138                self.input_store = self._load_input_store(data_type, input_key_objects)
 139                ok = True
 140            except:
 141                warnings.warn("Loading input store from key objects failed")
 142        if not ok:
 143            self.input_store = self._get_input_store(
 144                data_type, deepcopy(data_parameters)
 145            )
 146        # get the objects needed to create the annotation store (like a clip length dictionary)
 147        annotation_objects = self.input_store.get_annotation_objects()
 148        data_parameters.update(annotation_objects)
 149        # load or create the annotation store
 150        ok = False
 151        if annotation_store is not None:
 152            self.annotation_store = annotation_store
 153            ok = True
 154        elif (
 155            annotation_key_objects is not None
 156            and mask is None
 157            and not recompute_annotation
 158        ):
 159            try:
 160                self.annotation_store = self._load_annotation_store(
 161                    annotation_type, annotation_key_objects
 162                )
 163                ok = True
 164            except:
 165                warnings.warn("Loading annotation store from key objects failed")
 166        if not ok:
 167            self.annotation_store = self._get_annotation_store(
 168                annotation_type, deepcopy(data_parameters)
 169            )
 170        if (
 171            mask is None
 172            and annotation_type != "none"
 173            and not recompute_annotation
 174            and (
 175                self.annotation_store.get_original_coordinates()
 176                != self.input_store.get_original_coordinates()
 177            ).any()
 178        ):
 179            raise RuntimeError(
 180                "The clip orders in the annotation store and input store are different!"
 181            )
 182        # filter the data based on data parameters
 183        # print(f"1 {self.annotation_store.get_original_coordinates().shape=}")
 184        # print(f"1 {self.input_store.get_original_coordinates().shape=}")
 185        to_remove = self.annotation_store.filtered_indices()
 186        if len(to_remove) > 0:
 187            print(
 188                f"Filtering {100 * len(to_remove) / len(self.annotation_store):.2f}% of samples"
 189            )
 190        if len(self.input_store) == len(self.annotation_store):
 191            self.input_store.remove(to_remove)
 192        self.annotation_store.remove(to_remove)
 193        self.input_indices = list(range(len(self.input_store)))
 194        self.annotation_indices = list(range(len(self.input_store)))
 195        self.indices = list(range(len(self.input_store)))
 196        # print(f'{data_parameters["video_order"]=}')
 197        # print(f"{self.annotation_store.get_original_coordinates().shape=}")
 198        # print(f"{self.input_store.get_original_coordinates().shape=}")
 199        # count = 0
 200        # for i, (x, y) in enumerate(zip(
 201        #     self.annotation_store.get_original_coordinates(),
 202        #     self.input_store.get_original_coordinates(),
 203        # )):
 204        #     if (x != y).any():
 205        #         count += 1
 206        #         print({i})
 207        #         print(f"ann: {x}")
 208        #         print(f"inp: {y}")
 209        #         print("\n")
 210        #     if count > 50:
 211        #         break
 212        if annotation_type != "none" and (
 213            self.annotation_store.get_original_coordinates().shape
 214            != self.input_store.get_original_coordinates().shape
 215            or (
 216                self.annotation_store.get_original_coordinates()
 217                != self.input_store.get_original_coordinates()
 218            ).any()
 219        ):
 220            raise RuntimeError(
 221                "The clip orders in the annotation store and input store are different!"
 222            )
 223
 224    def __getitem__(self, item: int) -> Dict:
 225        idx = self._get_idx(item)
 226        input = deepcopy(self.input_store[idx])
 227        target = self.annotation_store[idx]
 228        tag = self.input_store.get_tag(idx)
 229        ssl_inputs, ssl_targets = self._get_SSL_targets(input)
 230        batch = {"input": input}
 231        for name, x in zip(
 232            ["target", "ssl_inputs", "ssl_targets", "tag"],
 233            [target, ssl_inputs, ssl_targets, tag],
 234        ):
 235            if x is not None:
 236                batch[name] = x
 237        batch["index"] = idx
 238        if self.stats is not None:
 239            for key in batch["input"].keys():
 240                key_name = key.split("---")[0]
 241                if key_name in self.stats:
 242                    batch["input"][key][:, batch["input"][key].sum(0) != 0] = (
 243                        (batch["input"][key] - self.stats[key_name]["mean"])
 244                        / (self.stats[key_name]["std"] + 1e-7)
 245                    )[:, batch["input"][key].sum(0) != 0]
 246        return batch
 247
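# --- Example (not part of the source): inspecting a single sample ---
# A minimal sketch; `dataset` is a placeholder for a constructed BehaviorDataset.
sample = dataset[0]
print(list(sample["input"].keys()))  # feature dictionary: key name -> tensor
print(sample["index"])               # index of the sample in the full dataset
if "target" in sample:               # present when annotation is loaded
    print(sample["target"].shape)
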
 248    def __len__(self) -> int:
 249        return len(self.indices)
 250        # if self.annotation_type != "none":
 251        #     return self.annotation_store.get_len(return_unlabeled=self.return_unlabeled)
 252        # else:
 253        #     return len(self.input_store)
 254
 255    def get_tags(self) -> List:
 256        """
 257        Get a list of all meta tags
 258
 259        Returns
 260        -------
 261        tags: List
 262            a list of unique meta tag values
 263        """
 264
 265        return self.input_store.get_tags()
 266
 267    def save(self, save_path: str) -> None:
 268        """
 269        Save the dataset to a pickled file
 270
 271        Parameters
 272        ----------
 273        save_path : str
 274            the path where the pickled file will be stored
 275        """
 276
 277        input_obj = self.input_store.key_objects()
 278        annotation_obj = self.annotation_store.key_objects()
 279        with open(save_path, "wb") as f:
 280            pickle.dump((input_obj, annotation_obj), f)
 281
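# --- Example (not part of the source): saving key objects and reloading the dataset ---
# A minimal sketch; the path is a placeholder.
from dlc2action.data.dataset import BehaviorDataset

dataset.save("dataset.pickle")
reloaded = BehaviorDataset(
    data_type=dataset.input_type,
    annotation_type=dataset.annotation_type,
    saved_data_path="dataset.pickle",
)
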
 282    def to_ram(self) -> None:
 283        """
 284        Transfer the dataset to RAM
 285        """
 286
 287        self.input_store.to_ram()
 288        self.annotation_store.to_ram()
 289
 290    def generate_full_length_gt(self) -> Dict:
        """Generate full-length ground truth in the same format as `generate_full_length_prediction`."""
 291        if self.annotation_class() == "exclusive_classification":
 292            gt = torch.zeros((len(self), self.len_segment()))
 293        else:
 294            gt = torch.zeros(
 295                (len(self), len(self.behaviors_dict()), self.len_segment())
 296            )
 297        for i in range(len(self)):
 298            gt[i] = self.annotation_store[i]
 299        return self.generate_full_length_prediction(gt)
 300
 301    def generate_full_length_prediction(self, predicted: torch.Tensor) -> Dict:
 302        """
 303        Map predictions for the equal-length pieces to predictions for the original data
 304
 305        Probabilities are averaged over predictions on overlapping intervals.
 306
 307        Parameters
 308        ----------
 309        predicted: torch.Tensor
 310            a tensor of predicted probabilities of shape `(N, #classes, #frames)`
 311
 312        Returns
 313        -------
 314        full_length_prediction : dict
 315            a nested dictionary where first-level keys are video ids, second-level keys are clip ids and values are
 316            averaged probability tensors
 317        """
 318
 319        result = defaultdict(lambda: {})
 320        counter = defaultdict(lambda: {})
 321        coordinates = self.input_store.get_original_coordinates()
 322        for coords, prediction in zip(coordinates, predicted):
 323            l = self.input_store.get_clip_length_from_coords(coords)
 324            video_name = self.input_store.get_video_id(coords)
 325            clip_id = self.input_store.get_clip_id(coords)
 326            start, end = self.input_store.get_clip_start_end(coords)
 327            if clip_id not in result[video_name].keys():
 328                result[video_name][clip_id] = torch.zeros(*prediction.shape[:-1], l)
 329                counter[video_name][clip_id] = torch.zeros(*prediction.shape[:-1], l)
 330            result[video_name][clip_id][..., start:end] += (
 331                prediction.squeeze()[..., : end - start].detach().cpu()
 332            )
 333            counter[video_name][clip_id][..., start:end] += 1
 334        for video_name in result:
 335            for clip_id in result[video_name]:
 336                result[video_name][clip_id] /= counter[video_name][clip_id]
 337                result[video_name][clip_id][counter[video_name][clip_id] == 0] = -100
 338        result = dict(result)
 339        return result
 340
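# --- Example (not part of the source): mapping per-segment output back to full-length clips ---
# A minimal sketch with random numbers standing in for model probabilities; `dataset` is a placeholder.
import torch

predicted = torch.rand(len(dataset), dataset.num_classes(), dataset.len_segment())
full_length = dataset.generate_full_length_prediction(predicted)
for video_id, clips in full_length.items():
    for clip_id, probabilities in clips.items():
        print(video_id, clip_id, probabilities.shape)  # (#classes, original clip length)
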
 341    def find_valleys(
 342        self,
 343        predicted: Union[torch.Tensor, Dict],
 344        threshold: float = 0.5,
 345        min_frames: int = 0,
 346        visibility_min_score: float = 0,
 347        visibility_min_frac: float = 0,
 348        main_class: int = 1,
 349        low: bool = True,
 350        predicted_error: torch.Tensor = None,
 351        error_threshold: float = 0.5,
 352        hysteresis: bool = False,
 353        threshold_diff: float = None,
 354        min_frames_error: int = None,
 355        smooth_interval: int = 1,
 356        cut_annotated: bool = False,
 357    ) -> Dict:
 358        """
 359        Find the intervals where the probability of a certain class is below or above a certain threshold
 360
 361        Parameters
 362        ----------
 363        predicted : torch.Tensor | dict
 364            either a tensor of predictions for the data prompts or the output of
 365            `BehaviorDataset.generate_full_length_prediction`
 366        threshold : float, default 0.5
 367            the main threshold
 368        min_frames : int, default 0
 369            the minimum length of the intervals
 370        visibility_min_score : float, default 0
 371            the minimum visibility score in the intervals
 372        visibility_min_frac : float, default 0
 373            fraction of the interval that has to have the visibility score larger than visibility_score_thr
 374        main_class : int, default 1
 375            the index of the class the function is inspecting
 376        low : bool, default True
 377            if True, the probability in the intervals has to be below the threshold, and if False, it has to be above
 378        predicted_error : torch.Tensor, optional
 379            a tensor of error predictions for the data prompts
 380        error_threshold : float, default 0.5
 381            maximum possible probability of error at the intervals
 382        hysteresis: bool, default False
 383            if True, the function will apply a hysteresis threshold with the soft threshold defined by threshold_diff
 384        threshold_diff: float, optional
 385            the difference between the soft and hard threshold if hysteresis is used; if hysteresis is True, low is False and threshold_diff is None, the soft threshold condition is set to the main_class having a larger probability than other classes
 386        min_frames_error: int, optional
 387            if not None, the intervals will only be considered where the error probability is below error_threshold for at least min_frames_error consecutive frames
 388
 389        Returns
 390        -------
 391        valleys : dict
 392            a dictionary where keys are video ids and values are lists of (start, end, individual name) tuples that denote the chosen intervals
 393        """
 394
 395        result = defaultdict(lambda: [])
 396        if type(predicted) is not dict:
 397            predicted = self.generate_full_length_prediction(predicted)
 398        if predicted_error is not None:
 399            predicted_error = self.generate_full_length_prediction(predicted_error)
 400        elif min_frames_error is not None and min_frames_error != 0:
 401            # warnings.warn(
 402            #     f"The min_frames_error parameter is set to {min_frames_error} but no error prediction "
 403            #     f"is given! Setting min_frames_error to 0."
 404            # )
 405            min_frames_error = 0
 406        if low and hysteresis and threshold_diff is None:
 407            raise ValueError(
 408                "Cannot set low=True, hysteresis=True and threshold_diff=None! Please set threshold_diff."
 409            )
 410        if cut_annotated:
 411            masked_intervals_dict = self.get_annotated_intervals()
 412        else:
 413            masked_intervals_dict = None
 414        print("Valleys found:")
 415        for v_id in predicted:
 416            for clip_id in predicted[v_id].keys():
 417                if predicted_error is not None:
 418                    error_mask = predicted_error[v_id][clip_id][1, :] < error_threshold
 419                    if min_frames_error is not None:
 420                        output, indices, counts = torch.unique_consecutive(
 421                            error_mask, return_inverse=True, return_counts=True
 422                        )
 423                        wrong_indices = torch.where(
 424                            output * (counts < min_frames_error)
 425                        )[0]
 426                        if len(wrong_indices) > 0:
 427                            for i in wrong_indices:
 428                                error_mask[indices == i] = False
 429                else:
 430                    error_mask = None
 431                if masked_intervals_dict is not None:
 432                    masked_intervals = masked_intervals_dict[v_id][clip_id]
 433                else:
 434                    masked_intervals = None
 435                if not hysteresis:
 436                    res_indices_start, res_indices_end = apply_threshold(
 437                        predicted[v_id][clip_id][main_class, :],
 438                        threshold,
 439                        low,
 440                        error_mask,
 441                        min_frames,
 442                        smooth_interval,
 443                        masked_intervals,
 444                    )
 445                elif threshold_diff is not None:
 446                    if low:
 447                        soft_threshold = threshold + threshold_diff
 448                    else:
 449                        soft_threshold = threshold - threshold_diff
 450                    res_indices_start, res_indices_end = apply_threshold_hysteresis(
 451                        predicted[v_id][clip_id][main_class, :],
 452                        soft_threshold,
 453                        threshold,
 454                        low,
 455                        error_mask,
 456                        min_frames,
 457                        smooth_interval,
 458                        masked_intervals,
 459                    )
 460                else:
 461                    res_indices_start, res_indices_end = apply_threshold_max(
 462                        predicted[v_id][clip_id],
 463                        threshold,
 464                        main_class,
 465                        error_mask,
 466                        min_frames,
 467                        smooth_interval,
 468                        masked_intervals,
 469                    )
 470                start = self.input_store.get_clip_start(v_id, clip_id)
 471                result[v_id] += [
 472                    [i + start, j + start, clip_id]
 473                    for i, j in zip(res_indices_start, res_indices_end)
 474                    if self.input_store.get_visibility(
 475                        v_id, clip_id, i, j, visibility_min_score
 476                    )
 477                    > visibility_min_frac
 478                ]
 479            result[v_id] = sorted(result[v_id])
 480            print(f"    {v_id}: {len(result[v_id])}")
 481        return dict(result)
 482
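# --- Example (not part of the source): searching for low-confidence intervals ---
# A minimal sketch; `predicted` is a placeholder tensor of class probabilities with shape
# (#samples, #classes, #frames), e.g. the output of a trained model on this dataset.
valleys = dataset.find_valleys(
    predicted,
    threshold=0.4,
    min_frames=15,
    main_class=1,
    low=True,            # keep intervals where the class probability stays below the threshold
    cut_annotated=True,  # ignore intervals that already have annotation
)
for video_id, intervals in valleys.items():
    print(video_id, intervals[:3])  # [start, end, clip_id] triplets
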
 483    def valleys_union(self, valleys_list) -> Dict:
 484        """
 485        Find the union of a list of valleys dictionaries
 486
 487        Parameters
 488        ----------
 489        valleys_list : list
 490            a list of valleys dictionaries
 491
 492        Returns
 493        -------
 494        union : dict
 495            a new valleys dictionary with the union of the input intervals
 496        """
 497
 498        valleys_list = [x for x in valleys_list if x is not None]
 499        if len(valleys_list) == 1:
 500            return valleys_list[0]
 501        elif len(valleys_list) == 0:
 502            return {}
 503        union = {}
 504        keys_list = [set(valleys.keys()) for valleys in valleys_list]
 505        keys = set.union(*keys_list)
 506        for v_id in keys:
 507            res = []
 508            clips_list = [
 509                set([x[-1] for x in valleys[v_id]]) for valleys in valleys_list
 510            ]
 511            clips = set.union(*clips_list)
 512            for clip_id in clips:
 513                clip_intervals = [
 514                    x
 515                    for valleys in valleys_list
 516                    for x in valleys[v_id]
 517                    if x[-1] == clip_id
 518                ]
 519                v_len = self.input_store.get_clip_length(v_id, clip_id)
 520                arr = torch.zeros(v_len)
 521                for start, end, _ in clip_intervals:
 522                    arr[start:end] += 1
 523                output, indices, counts = torch.unique_consecutive(
 524                    arr > 0, return_inverse=True, return_counts=True
 525                )
 526                long_indices = torch.where(output)[0]
 527                res += [
 528                    (
 529                        (indices == i).nonzero(as_tuple=True)[0][0].item(),
 530                        (indices == i).nonzero(as_tuple=True)[0][-1].item(),
 531                        clip_id,
 532                    )
 533                    for i in long_indices
 534                ]
 535            union[v_id] = res
 536        return union
 537
 538    def valleys_intersection(self, valleys_list) -> Dict:
 539        """
 540        Find the intersection of two valleys dictionaries
 541
 542        Parameters
 543        ----------
 544        valleys_list : list
 545            a list of valleys dictionaries
 546
 547        Returns
 548        -------
 549        intersection : dict
 550            a new valleys dictionary with the intersection of the input intervals
 551        """
 552
 553        valleys_list = [x for x in valleys_list if x is not None]
 554        if len(valleys_list) == 1:
 555            return valleys_list[0]
 556        elif len(valleys_list) == 0:
 557            return {}
 558        intersection = {}
 559        keys_list = [set(valleys.keys()) for valleys in valleys_list]
 560        keys = set.intersection(*keys_list)
 561        for v_id in keys:
 562            res = []
 563            clips_list = [
 564                set([x[-1] for x in valleys[v_id]]) for valleys in valleys_list
 565            ]
 566            clips = set.intersection(*clips_list)
 567            for clip_id in clips:
 568                clip_intervals = [
 569                    x
 570                    for valleys in valleys_list
 571                    for x in valleys[v_id]
 572                    if x[-1] == clip_id
 573                ]
 574                v_len = self.input_store.get_clip_length(v_id, clip_id)
 575                arr = torch.zeros(v_len)
 576                for start, end, _ in clip_intervals:
 577                    arr[start:end] += 1
 578                output, indices, counts = torch.unique_consecutive(
 579                    arr, return_inverse=True, return_counts=True
 580                )
 581                long_indices = torch.where(output == 2)[0]
 582                res += [
 583                    (
 584                        (indices == i).nonzero(as_tuple=True)[0][0].item(),
 585                        (indices == i).nonzero(as_tuple=True)[0][-1].item(),
 586                        clip_id,
 587                    )
 588                    for i in long_indices
 589                ]
 590            intersection[v_id] = res
 591        return intersection
 592
 593    def partition_train_test_val(
 594        self,
 595        use_test: float = 0,
 596        split_path: str = None,
 597        method: str = "random",
 598        val_frac: float = 0,
 599        test_frac: float = 0,
 600        save_split: bool = False,
 601        normalize: bool = False,
 602        skip_normalization_keys: List = None,
 603        stats: Dict = None,
 604    ) -> Tuple:
 605        """
 606        Partition the dataset into three new datasets
 607
 608        Parameters
 609        ----------
 610        use_test : float, default 0
 611            The fraction of the test dataset to be used in training without labels
 612        split_path : str, optional
 613            The path to load the split information from (if `'file'` method is used) and to save it to
 614            (if `'save_split'` is `True`)
 615        method : {'random', 'random:test-from-name', 'random:test-from-name:{name}',
 616            'val-from-name:{val_name}:test-from-name:{test_name}',
 617            'random:equalize:segments', 'random:equalize:videos',
 618            'folders', 'time', 'time:strict', 'leave-one-out:{n}', 'file'}
 619            The partitioning method:
 620            - `'random'`: sort videos into subsets randomly,
 621            - `'random:test-from-name'` (or `'random:test-from-name:{name}'`): sort videos into training and validation
 622                subsets randomly and create
 623                the test subset from the video ids that start with a specific substring (`'test'` by default, or `name`
 624                if provided),
 625            - `'random:equalize:segments'` and `'random:equalize:videos'`: sort videos into subsets randomly but
 626                making sure that for the rarest classes at least `0.8 * val_frac` of the videos/segments that contain
 627                occurrences of the class get into the validation subset and `0.8 * test_frac` get into the test subset;
 628                this is ensured for all classes in order of increasing number of occurrences until the validation and
 629                test subsets are full
 630            - `'val-from-name:{val_name}:test-from-name:{test_name}'`: create the validation and test
 631                subsets from the video ids that start with specific substrings (`val_name` for validation
 632                and `test_name` for test) and sort all other videos into the training subset
 633            - `'folders'`: read videos from folders named *test*, *train* and *val* into corresponding subsets,
 634            - `'time'`: split each video into training, validation and test subsequences,
 635            - `'time:strict'`: split each video into validation, test and training subsequences
 636                and throw out the last segments in validation and test (to get rid of overlaps),
 637            - `'file'`: split according to a split file,
            - `'leave-one-out:{n}'`: use the n-th video (in sorted order) for validation and the rest for training.
 638        val_frac : float, default 0
 639            The fraction of the dataset to be used in validation
 640        test_frac : float, default 0
 641            The fraction of the dataset to be used in test
 642        save_split : bool, default False
 643            Save a split file if True
        normalize : bool, default False
            if True, compute normalization statistics on the training subset and apply them to all three datasets
        skip_normalization_keys : list, optional
            a list of feature key names to exclude from normalization
        stats : dict, optional
            precomputed normalization statistics to set instead of computing them from the training subset
 644
 645        Returns
 646        -------
 647        train_dataset : BehaviorDataset
 648            train dataset
 649
 650        test_dataset : BehaviorDataset
 651            test dataset
 652
 653        val_dataset : BehaviorDataset
 654            validation dataset
 655        """
 656
 657        train_indices, test_indices, val_indices = self._partition_indices(
 658            split_path=split_path,
 659            method=method,
 660            val_frac=val_frac,
 661            test_frac=test_frac,
 662            save_split=save_split,
 663        )
 664        val_dataset = self._create_new_dataset(val_indices)
 665        test_dataset = self._create_new_dataset(test_indices)
 666        train_dataset = self._create_new_dataset(
 667            train_indices, ssl_indices=test_indices[: int(len(test_indices) * use_test)]
 668        )
 669        print("Number of samples:")
 670        print(f"    validation:")
 671        print(f"      {val_dataset.count_classes()}")
 672        print(f"    training:")
 673        print(f"      {train_dataset.count_classes()}")
 674        print(f"    test:")
 675        print(f"      {test_dataset.count_classes()}")
 676        if normalize:
 677            if stats is None:
 678                print("Computing normalization statistics...")
 679                stats = train_dataset.get_normalization_stats(skip_normalization_keys)
 680            else:
 681                print("Setting loaded normalization statistics...")
 682            train_dataset.set_normalization_stats(stats)
 683            val_dataset.set_normalization_stats(stats)
 684            test_dataset.set_normalization_stats(stats)
 685        return train_dataset, test_dataset, val_dataset
 686
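# --- Example (not part of the source): a typical partitioning call ---
# A minimal sketch; the fractions and the split path are placeholders. Note the return order.
train_dataset, test_dataset, val_dataset = dataset.partition_train_test_val(
    method="random",
    val_frac=0.2,
    test_frac=0.1,
    save_split=True,
    split_path="split.txt",
    normalize=True,
)
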
 687    def class_weights(self, proportional=False) -> Union[List, Dict]:
 688        """
 689        Calculate class weights in inverse proportion to the number of samples

        Parameters
        ----------
        proportional : bool, default False
            if True, use the largest class count as the numerator instead of the total number of samples

 690        Returns
 691        -------
 692        weights: list | dict
 693            a list of class weights for exclusive annotation, or a dictionary with keys 0 and 1 (negative and
            positive per-class weights) for non-exclusive annotation
 694        """
 695
 696        items = sorted(
 697            [
 698                (k, v)
 699                for k, v in self.annotation_store.count_classes().items()
 700                if k != -100
 701            ]
 702        )
 703        if self.annotation_store.annotation_class() == "exclusive_classification":
 704            if not proportional:
 705                numerator = len(self.annotation_store)
 706            else:
 707                numerator = max([x[1] for x in items])
 708            weights = [numerator / (v + 1e-7) for _, v in items]
 709        else:
 710            items_zero = sorted(
 711                [
 712                    (k, v)
 713                    for k, v in self.annotation_store.count_classes(zeros=True).items()
 714                    if k != -100
 715                ]
 716            )
 717            if not proportional:
 718                numerators = defaultdict(lambda: len(self.annotation_store))
 719            else:
 720                numerators = {
 721                    item_one[0]: max(item_one[1], item_zero[1])
 722                    for item_one, item_zero in zip(items, items_zero)
 723                }
 724            weights = {}
 725            weights[0] = [numerators[k] / (v + 1e-7) for k, v in items_zero]
 726            weights[1] = [numerators[k] / (v + 1e-7) for k, v in items]
 727        return weights
 728
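# --- Example (not part of the source): passing class weights to a loss function ---
# A minimal sketch for the exclusive-classification case, where class_weights() returns a list.
import torch

weights = torch.tensor(dataset.class_weights(proportional=False), dtype=torch.float)
criterion = torch.nn.CrossEntropyLoss(weight=weights, ignore_index=-100)
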
 729    def boundary_class_weight(self):
        """Get the ratio of frames where the annotation value changes (action boundaries) to all other frames."""
 730        if self.annotation_type != "none":
 731            f = self.annotation_store.data.flatten()
 732            _, inv = torch.unique_consecutive(f, return_inverse=True)
 733            boundary = torch.cat([torch.tensor([0]), torch.diff(inv)]).reshape(
 734                self.annotation_store.data.shape
 735            )
 736            boundary[..., 0] = 0
 737            cnt = Counter(boundary.flatten().numpy())
 738            return cnt[1] / cnt[0]
 739        else:
 740            return 0
 741
 742    def count_classes(self, bouts: bool = False) -> Dict:
 743        """
 744        Get a class counter dictionary
 745
 746        Parameters
 747        ----------
 748        bouts : bool, default False
 749            if `True`, instead of frame counts segment counts are returned
 750
 751        Returns
 752        -------
 753        count_dictionary : dict
 754            a dictionary with class indices as keys and frame or bout counts as values
 755        """
 756
 757        return self.annotation_store.count_classes(bouts=bouts)
 758
 759    def behaviors_dict(self) -> Dict:
 760        """
 761        Get a behavior dictionary
 762
 763        Returns
 764        -------
 765        dict
 766            behavior dictionary
 767        """
 768
 769        return self.annotation_store.behaviors_dict()
 770
 771    def bodyparts_order(self) -> List:
        """Get the list of body part names from the input store (raises `RuntimeError` if not implemented)."""
 772        try:
 773            return self.input_store.get_bodyparts()
 774        except:
 775            raise RuntimeError(
 776                f"The {self.input_type} input store does not have bodyparts implemented!"
 777            )
 778
 779    def features_shape(self) -> Dict:
 780        """
 781        Get the shapes of the input features
 782
 783        Returns
 784        -------
 785        shapes : Dict
 786            a dictionary with the shapes of the features
 787        """
 788
 789        sample = self.input_store[0]
 790        shapes = {k: v.shape for k, v in sample.items()}
 791        # for key, value in shapes.items():
 792        #     print(f'{key}: {value}')
 793        return shapes
 794
 795    def num_classes(self) -> int:
 796        """
 797        Get the number of classes in the data
 798
 799        Returns
 800        -------
 801        num_classes : int
 802            the number of classes
 803        """
 804
 805        return len(self.annotation_store.behaviors_dict())
 806
 807    def len_segment(self) -> int:
 808        """
 809        Get the segment length in the data
 810
 811        Returns
 812        -------
 813        len_segment : int
 814            the segment length
 815        """
 816
 817        sample = self.input_store[0]
 818        key = list(sample.keys())[0]
 819        return sample[key].shape[-1]
 820
 821    def set_ssl_transformations(self, ssl_transformations: List) -> None:
 822        """
 823        Set new SSL transformations
 824
 825        Parameters
 826        ----------
 827        ssl_transformations : list
 828            a list of functions that take a sample feature dictionary as input and output ssl_inputs and ssl_targets
 829            lists
 830        """
 831
 832        self.ssl_transformations = ssl_transformations
 833
 834    @classmethod
 835    def new(cls, *args, **kwargs):
 836        """
 837        Create a new object of the same class
 838
 839        Returns
 840        -------
 841        new_instance: BehaviorDataset
 842            a new instance of the same class
 843        """
 844
 845        return cls(*args, **kwargs)
 846
 847    @classmethod
 848    def get_parameters(cls, data_type: str, annotation_type: str) -> List:
 849        """
 850        Get parameters necessary for initialization
 851
 852        Parameters
 853        ----------
 854        data_type : str
 855            the data type
 856        annotation_type : str
 857            the annotation type

        Returns
        -------
        parameters : list
            a list of parameter names accepted at initialization
 858        """
 859
 860        input_features = options.input_stores[data_type].get_parameters()
 861        annotation_features = options.annotation_stores[
 862            annotation_type
 863        ].get_parameters()
 864        self_features = inspect.getfullargspec(cls.__init__).args
 865        return self_features + input_features + annotation_features
 866
 867    @staticmethod
 868    def data_types() -> List:
 869        """
 870        List available data types
 871
 872        Returns
 873        -------
 874        data_types : list
 875            available data types
 876        """
 877
 878        return list(options.input_stores.keys())
 879
 880    @staticmethod
 881    def annotation_types() -> List:
 882        """
 883        List available annotation types
 884
 885        Returns
 886        -------
 887        annotation_types : list
 888            available annotation types
 889        """
 890
 891        return list(options.annotation_stores.keys())
 892
 893    def _get_SSL_targets(self, input: Dict) -> Tuple[List, List]:
 894        """
 895        Get the SSL inputs and targets from a sample dictionary
 896        """
 897
 898        ssl_inputs = []
 899        ssl_targets = []
 900        for transform in self.ssl_transformations:
 901            ssl_input, ssl_target = transform(copy(input))
 902            ssl_inputs.append(ssl_input)
 903            ssl_targets.append(ssl_target)
 904        return ssl_inputs, ssl_targets
 905
 906    def _create_new_dataset(self, indices: List, ssl_indices: List = None):
 907        """
 908        Create a subsample of the dataset, with samples at ssl_indices losing the annotation
 909        """
 910
 911        if ssl_indices is None:
 912            ssl_indices = []
 913        input_store = self.input_store.create_subsample(indices, ssl_indices)
 914        annotation_store = self.annotation_store.create_subsample(indices, ssl_indices)
 915        new = self.new(
 916            data_type=self.input_type,
 917            annotation_type=self.annotation_type,
 918            ssl_transformations=self.ssl_transformations,
 919            annotation_store=annotation_store,
 920            input_store=input_store,
 921            ids=list(indices) + list(ssl_indices),
 922            recompute_annotation=False,
 923        )
 924        return new
 925
 926    def _load_input_store(self, data_type: str, key_objects: Tuple) -> InputStore:
 927        """
 928        Load input store from key objects
 929        """
 930
 931        input_store = options.input_stores[data_type](key_objects=key_objects)
 932        return input_store
 933
 934    def _load_annotation_store(
 935        self, annotation_type: str, key_objects: Tuple
 936    ) -> AnnotationStore:
 937        """
 938        Load annotation store from key objects
 939        """
 940
 941        annotation_store = options.annotation_stores[annotation_type](
 942            key_objects=key_objects
 943        )
 944        return annotation_store
 945
 946    def _get_input_store(self, data_type: str, data_parameters: Dict) -> InputStore:
 947        """
 948        Create input store from parameters
 949        """
 950
 951        data_parameters["key_objects"] = None
 952        input_store = options.input_stores[data_type](**data_parameters)
 953        return input_store
 954
 955    def _get_annotation_store(
 956        self, annotation_type: str, data_parameters: Dict
 957    ) -> AnnotationStore:
 958        """
 959        Create annotation store from parameters
 960        """
 961
 962        annotation_store = options.annotation_stores[annotation_type](**data_parameters)
 963        return annotation_store
 964
 965    def set_indexing_parameters(self, unlabeled: bool, tag: int) -> None:
 966        """
 967        Set the parameters that change the subset that is returned at `__getitem__`
 968
 969        Parameters
 970        ----------
 971        unlabeled : bool
 972            a pseudolabeling parameter; return only unlabeled samples if `True`, only labeled if `False` and
 973            all if `None`
 974        tag : int
 975            if not `None`, only samples with this meta tag will be returned
 976        """
 977
 978        if unlabeled != self.return_unlabeled:
 979            self.annotation_indices = self.annotation_store.get_indices(unlabeled)
 980            self.return_unlabeled = unlabeled
 981        if tag != self.tag:
 982            self.input_indices = self.input_store.get_indices(tag)
 983            self.tag = tag
 984        self.indices = [x for x in self.annotation_indices if x in self.input_indices]
 985
 986    def _get_idx(self, index: int) -> int:
 987        """
 988        Get index in full dataset
 989        """
 990
 991        return self.indices[index]
 992
 993        # return self.annotation_store.get_idx(
 994        #     index, return_unlabeled=self.return_unlabeled
 995        # )
 996
 997    def _partition_indices(
 998        self,
 999        split_path: str = None,
1000        method: str = "random",
1001        val_frac: float = 0,
1002        test_frac: float = 0,
1003        save_split: bool = False,
1004    ) -> Tuple[List, List, List]:
1005        """
1006        Partition indices into train, validation, test subsets
1007        """
1008
1009        if self.mask is not None:
1010            val_indices = self.mask["val_ids"]
1011            train_indices = [x for x in range(len(self)) if x not in val_indices]
1012            test_indices = []
1013        elif method == "random":
1014            videos = np.array(self.input_store.get_video_id_order())
1015            all_videos = list(set(videos))
1016            if len(all_videos) == 1:
1017                warnings.warn(
1018                    "There is only one video in the dataset, so train/val/test split is done on segments; "
1019                    'that might lead to overlaps, please consider using "time" or "time:strict" as the '
1020                    "partitioning method instead"
1021                )
1022                # Quick fix for single video: the problem with this is that the segments can overlap
1023                # length = int(self.input_store.get_original_coordinates()[-1][1])    # number of segments
1024                length = len(self.input_store.get_original_coordinates())
1025                val_len = int(val_frac * length)
1026                test_len = int(test_frac * length)
1027                all_indices = np.random.choice(np.arange(length), length, replace=False)
1028                val_indices = all_indices[:val_len]
1029                test_indices = all_indices[val_len : val_len + test_len]
1030                train_indices = all_indices[val_len + test_len :]
1031                coords = self.input_store.get_original_coordinates()
1032                if save_split:
1033                    self._save_partition(
1034                        coords[train_indices],
1035                        coords[val_indices],
1036                        coords[test_indices],
1037                        split_path,
1038                        coords=True,
1039                    )
1040            else:
1041                length = len(all_videos)
1042                val_len = int(val_frac * length)
1043                test_len = int(test_frac * length)
1044                validation = all_videos[:val_len]
1045                test = all_videos[val_len : val_len + test_len]
1046                training = all_videos[val_len + test_len :]
1047                train_indices = np.where(np.isin(videos, training))[0]
1048                val_indices = np.where(np.isin(videos, validation))[0]
1049                test_indices = np.where(np.isin(videos, test))[0]
1050                if save_split:
1051                    self._save_partition(training, validation, test, split_path)
1052        elif method.startswith("random:equalize"):
1053            counter = self.count_classes()
1054            counter = sorted(list([(v, k) for k, v in counter.items()]))
1055            classes = [x[1] for x in counter]
1056            indicator = {c: [] for c in classes}
1057            if method.endswith("videos"):
1058                videos = np.array(self.input_store.get_video_id_order())
1059                all_videos = list(set(videos))
1060                total_len = len(all_videos)
1061                for video_id in all_videos:
1062                    video_coords = np.where(videos == video_id)[0]
1063                    ann = torch.cat(
1064                        [self.annotation_store[i] for i in video_coords], dim=-1
1065                    )
1066                    for c in classes:
1067                        if self.annotation_class() == "nonexclusive_classification":
1068                            indicator[c].append(torch.sum(ann[c] == 1) > 0)
1069                        elif self.annotation_class() == "exclusive_classification":
1070                            indicator[c].append(torch.sum(ann == c) > 0)
1071                        else:
1072                            raise ValueError(
1073                                f"The random:equalize partition method is not implemented "
1074                                f"for the {self.annotation_class()} annotation class"
1075                            )
1076            elif method.endswith("segments"):
1077                total_len = len(self)
1078                for ann in self.annotation_store:
1079                    for c in classes:
1080                        if self.annotation_class() == "nonexclusive_classification":
1081                            indicator[c].append(torch.sum(ann[c] == 1) > 0)
1082                        elif self.annotation_class() == "exclusive_classification":
1083                            indicator[c].append(torch.sum(ann == c) > 0)
1084                        else:
1085                            raise ValueError(
1086                                f"The random:equalize partition method is not implemented "
1087                                f"for the {self.annotation_class()} annotation class"
1088                            )
1089            else:
1090                values = []
1091                for v in options.partition_methods.values():
1092                    values += v
1093                raise ValueError(
1094                    f"The {method} partition method is not recognized; please choose from {values}"
1095                )
1096            val_indices = []
1097            test_indices = []
1098            for c in classes:
1099                indicator[c] = np.array(indicator[c])
1100                ind = np.where(indicator[c])[0]
1101                np.random.shuffle(ind)
1102                c_sum = len(ind)
1103                in_val = np.sum(indicator[c][val_indices])
1104                in_test = np.sum(indicator[c][test_indices])
1105                while (
1106                    len(val_indices) < val_frac * total_len
1107                    and in_val < val_frac * c_sum * 0.8
1108                ):
1109                    first, ind = ind[0], ind[1:]
1110                    val_indices = list(set(val_indices).union({first}))
1111                    in_val = np.sum(indicator[c][val_indices])
1112                while (
1113                    len(test_indices) < test_frac * total_len
1114                    and in_test < test_frac * c_sum * 0.8
1115                ):
1116                    first, ind = ind[0], ind[1:]
1117                    test_indices = list(set(test_indices).union({first}))
1118                    in_test = np.sum(indicator[c][test_indices])
1119            if len(val_indices) < int(val_frac * total_len):
1120                left_val = int(val_frac * total_len) - len(val_indices)
1121            else:
1122                left_val = 0
1123            if len(test_indices) < int(test_frac * total_len):
1124                left_test = int(test_frac * total_len) - len(test_indices)
1125            else:
1126                left_test = 0
1127            indicator = np.ones(total_len)
1128            indicator[val_indices] = 0
1129            indicator[test_indices] = 0
1130            ind = np.where(indicator)[0]
1131            np.random.shuffle(ind)
1132            val_indices += list(ind[:left_val])
1133            test_indices += list(ind[left_val : left_val + left_test])
1134            train_indices = list(ind[left_val + left_test :])
1135            if save_split:
1136                if method.endswith("segments"):
1137                    coords = self.input_store.get_original_coordinates()
1138                    self._save_partition(
1139                        coords[train_indices],
1140                        coords[val_indices],
1141                        coords[test_indices],
1142                        split_path,
1143                        coords=True,
1144                    )
1145                else:
1146                    all_videos = np.array(all_videos)
1147                    validation = all_videos[val_indices]
1148                    test = all_videos[test_indices]
1149                    training = all_videos[train_indices]
1150                    self._save_partition(training, validation, test, split_path)
1151        elif method.startswith("random:test-from-name"):
1152            split = method.split(":")
1153            if len(split) > 2:
1154                test_name = split[-1]
1155            else:
1156                test_name = "test"
1157            videos = np.array(self.input_store.get_video_id_order())
1158            all_videos = list(set(videos))
1159            test = []
1160            train_videos = []
1161            for x in all_videos:
1162                if x.startswith(test_name):
1163                    test.append(x)
1164                else:
1165                    train_videos.append(x)
1166            length = len(train_videos)
1167            val_len = int(val_frac * length)
1168            validation = train_videos[:val_len]
1169            training = train_videos[val_len:]
1170            train_indices = np.where(np.isin(videos, training))[0]
1171            val_indices = np.where(np.isin(videos, validation))[0]
1172            test_indices = np.where(np.isin(videos, test))[0]
1173            if save_split:
1174                self._save_partition(training, validation, test, split_path)
1175        elif method.startswith("val-from-name"):
1176            split = method.split(":")
1177            if split[2] != "test-from-name":
1178                raise ValueError(
1179                    f"The {method} partition method is not recognized, please choose from {options.partition_methods}"
1180                )
1181            val_name = split[1]
1182            test_name = split[-1]
1183            videos = np.array(self.input_store.get_video_id_order())
1184            all_videos = list(set(videos))
1185            test = []
1186            validation = []
1187            training = []
1188            for x in all_videos:
1189                if x.startswith(test_name):
1190                    test.append(x)
1191                elif x.startswith(val_name):
1192                    validation.append(x)
1193                else:
1194                    training.append(x)
1195            train_indices = np.where(np.isin(videos, training))[0]
1196            val_indices = np.where(np.isin(videos, validation))[0]
1197            test_indices = np.where(np.isin(videos, test))[0]
1198        elif method == "folders":
1199            folders = np.array(self.input_store.get_folder_order())
1200            videos = np.array(self.input_store.get_video_id_order())
1201            train_indices = np.where(np.isin(folders, ["training", "train"]))[0]
1202            if np.sum(np.isin(folders, ["validation", "val"])) > 0:
1203                val_indices = np.where(np.isin(folders, ["validation", "val"]))[0]
1204            else:
1205                train_videos = list(set(videos[train_indices]))
1206                val_len = int(val_frac * len(train_videos))
1207                validation = train_videos[:val_len]
1208                training = train_videos[val_len:]
1209                train_indices = np.where(np.isin(videos, training))[0]
1210                val_indices = np.where(np.isin(videos, validation))[0]
1211            test_indices = np.where(folders == "test")[0]
1212            if save_split:
1213                self._save_partition(
1214                    list(set(videos[train_indices])),
1215                    list(set(videos[val_indices])),
1216                    list(set(videos[test_indices])),
1217                    split_path,
1218                )
1219        elif method.startswith("leave-one-out"):
1220            n = int(method.split(":")[-1])
1221            videos = np.array(self.input_store.get_video_id_order())
1222            all_videos = sorted(list(set(videos)))
1223            validation = [all_videos.pop(n)]
1224            training = all_videos
1225            train_indices = np.where(np.isin(videos, training))[0]
1226            val_indices = np.where(np.isin(videos, validation))[0]
1227            test_indices = np.array([])
1228        elif method.startswith("time"):
1229            if method.endswith("strict"):
1230                len_segment = self.len_segment()
1231                # TODO: change this
1232                step = self.input_store.step
1233                num_removed = len_segment // step
1234            else:
1235                num_removed = 0
1236            videos = np.array(self.input_store.get_video_id_order())
1237            all_videos = set(videos)
1238            train_indices = []
1239            val_indices = []
1240            test_indices = []
1241            start = 0
1242            if len(method.split(":")) > 1 and method.split(":")[1] == "start-from":
1243                start = float(method.split(":")[2])
1244            for video_id in all_videos:
1245                video_indices = np.where(videos == video_id)[0]
1246                val_len = int(val_frac * len(video_indices))
1247                test_len = int(test_frac * len(video_indices))
1248                start_pos = int(start * len(video_indices))
1249                all_ind = np.ones(len(video_indices))
1250                val_indices += list(video_indices[start_pos : start_pos + val_len])
1251                all_ind[start_pos : start_pos + val_len] = 0
1252                if start_pos + val_len > len(video_indices):
1253                    p = start_pos + val_len - len(video_indices)
1254                    val_indices += list(video_indices[:p])
1255                    all_ind[:p] = 0
1256                else:
1257                    p = start_pos + val_len
1258                test_indices += list(video_indices[p : p + test_len])
1259                all_ind[p : p + test_len] = 0
1260                if p + test_len > len(video_indices):
1261                    p = test_len + p - len(video_indices)
1262                    test_indices += list(video_indices[:p])
1263                    all_ind[:p] = 0
1264                train_indices += list(video_indices[all_ind > 0])
1265                for _ in range(num_removed):
1266                    if len(val_indices) > 0:
1267                        val_indices.pop(-1)
1268                    if len(test_indices) > 0:
1269                        test_indices.pop(-1)
1270                    if start > 0 and len(train_indices) > 0:
1271                        train_indices.pop(-1)
1272        elif method == "file":
1273            if split_path is None:
1274                raise ValueError(
1275                    'You need to either set split_path or change partition method ("file" requires a file)'
1276                )
1277            with open(split_path) as f:
1278                train_line = f.readline()
1279                line = f.readline()
1280                while not line.startswith("Validation") and not line.startswith(
1281                    "Validataion"  # also match the misspelled header
1282                ):
1283                    line = f.readline()
1284                if line.startswith("Validation"):
1285                    validation = []
1286                    test = []
1287                    while True:
1288                        line = f.readline()
1289                        if line.startswith("Test") or len(line) == 0:
1290                            break
1291                        validation.append(line.rstrip())
1292                    while True:
1293                        line = f.readline()
1294                        if len(line) == 0:
1295                            break
1296                        test.append(line.rstrip())
1297                    type = train_line[9:-2]
1298                else:
1299                    line = f.readline()
1300                    validation = line.rstrip(",\n ").split(", ")
1301                    test = []
1302                    type = "videos"
1303            if type == "videos":
1304                videos = np.array(self.input_store.get_video_id_order())
1305                val_indices = np.where(np.isin(videos, validation))[0]
1306                test_indices = np.where(np.isin(videos, test))[0]
1307            elif type == "coords":
1308                coords = self.input_store.get_original_coordinates()
1309                video_ids = self.input_store.get_video_id_order()
1310                clip_ids = [self.input_store.get_clip_id(coord) for coord in coords]
1311                starts, ends = zip(
1312                    *[self.input_store.get_clip_start_end(coord) for coord in coords]
1313                )
1314                coords = np.array(
1315                    [
1316                        f"{video_id}---{clip_id}---{start}-{end}"
1317                        for video_id, clip_id, start, end in zip(
1318                            video_ids, clip_ids, starts, ends
1319                        )
1320                    ]
1321                )
1322                val_indices = np.where(np.isin(coords, validation))[0]
1323                test_indices = np.where(np.isin(coords, test))[0]
1324            else:
1325                raise ValueError("The split path has unrecognized format!")
1326            all = np.ones(len(self))
1327            all[val_indices] = 0
1328            all[test_indices] = 0
1329            train_indices = np.where(all)[0]
1330        else:
1331            raise ValueError(
1332                f"The {method} partition is not recognized, please choose from {options.partition_methods}"
1333            )
1334        return sorted(train_indices), sorted(test_indices), sorted(val_indices)
1335
1336    def _save_partition(
1337        self,
1338        training: List,
1339        validation: List,
1340        test: List,
1341        split_path: str,
1342        coords: bool = False,
1343    ) -> None:
1344        """
1345        Save a split file
1346        """
1347
1348        if coords:
1349            name = "coords"
1350            training_coords = []
1351            val_coords = []
1352            test_coords = []
1353            for coord in training:
1354                video_id = self.input_store.get_video_id(coord)
1355                clip_id = self.input_store.get_clip_id(coord)
1356                start, end = self.input_store.get_clip_start_end(coord)
1357                training_coords.append(f"{video_id}---{clip_id}---{start}-{end}")
1358            for coord in validation:
1359                video_id = self.input_store.get_video_id(coord)
1360                clip_id = self.input_store.get_clip_id(coord)
1361                start, end = self.input_store.get_clip_start_end(coord)
1362                val_coords.append(f"{video_id}---{clip_id}---{start}-{end}")
1363            for coord in test:
1364                video_id = self.input_store.get_video_id(coord)
1365                clip_id = self.input_store.get_clip_id(coord)
1366                start, end = self.input_store.get_clip_start_end(coord)
1367                test_coords.append(f"{video_id}---{clip_id}---{start}-{end}")
1368            training, validation, test = training_coords, val_coords, test_coords
1369        else:
1370            name = "videos"
1371        if split_path is not None:
1372            with open(split_path, "w") as f:
1373                f.write(f"Training {name}:\n")
1374                for x in training:
1375                    f.write(x + "\n")
1376                f.write(f"Validation {name}:\n")
1377                for x in validation:
1378                    f.write(x + "\n")
1379                f.write(f"Test {name}:\n")
1380                for x in test:
1381                    f.write(x + "\n")
1382
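For reference, a sketch of the plain-text layout that `_save_partition` produces with `coords=False` and that the `"file"` partition method parses; the video ids are made up:

    # hypothetical split file, e.g. written to "split.txt"
    split_text = (
        "Training videos:\n"
        "video_01\n"
        "video_02\n"
        "Validation videos:\n"
        "video_03\n"
        "Test videos:\n"
        "video_04\n"
    )
    with open("split.txt", "w") as f:
        f.write(split_text)

    # the same file can then be consumed with method="file" and split_path="split.txt"
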
1383    def _get_intervals_from_ind(self, frame_indices: np.ndarray):
1384        """
1385        Get a list of intervals from a list of frame indices
1386
1387        Example: `[0, 1, 2, 5, 6, 8] -> [[0, 3], [5, 7], [8, 9]]`.
1388
1389        Parameters
1390        ----------
1391        frame_indices : np.ndarray
1392            a list of frame indices
1393
1394        Returns
1395        -------
1396        intervals : list
1397            a list of interval boundaries
1398        """
1399
1400        masked_intervals = []
1401        breaks = np.where(np.diff(frame_indices) != 1)[0]
1402        if len(frame_indices) > 0:
1403            start = frame_indices[0]
1404            for k in breaks:
1405                masked_intervals.append([start, frame_indices[k] + 1])
1406                start = frame_indices[k + 1]
1407            masked_intervals.append([start, frame_indices[-1] + 1])
1408        return masked_intervals
1409
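A standalone sketch of the same grouping logic applied to the example from the docstring (runs of consecutive frame indices collapsed into `[start, end)` intervals):

    import numpy as np

    frame_indices = np.array([0, 1, 2, 5, 6, 8])
    intervals = []
    breaks = np.where(np.diff(frame_indices) != 1)[0]
    start = frame_indices[0]
    for k in breaks:
        intervals.append([int(start), int(frame_indices[k]) + 1])
        start = frame_indices[k + 1]
    intervals.append([int(start), int(frame_indices[-1]) + 1])
    print(intervals)  # [[0, 3], [5, 7], [8, 9]]
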
1410    def get_intervals(self) -> Tuple[dict, Optional[list]]:
1411        """
1412        Get a list of intervals covered by the dataset in the original coordinates
1413
1414        Returns
1415        -------
1416        intervals : dict
1417            a nested dictionary where first-level keys are video ids, second-level keys are clip ids and
1418            values are lists of the intervals in `[start, end]` format
1419        """
1420
1421        counter = defaultdict(lambda: {})
1422        coordinates = self.input_store.get_original_coordinates()
1423        for coords in coordinates:
1424            l = self.input_store.get_clip_length_from_coords(coords)
1425            video_name = self.input_store.get_video_id(coords)
1426            clip_id = self.input_store.get_clip_id(coords)
1427            start, end = self.input_store.get_clip_start_end(coords)
1428            if clip_id not in counter[video_name]:
1429                counter[video_name][clip_id] = np.zeros(l)
1430            counter[video_name][clip_id][start:end] = 1
1431        result = {video_name: {} for video_name in counter}
1432        for video_name in counter:
1433            for clip_id in counter[video_name]:
1434                result[video_name][clip_id] = self._get_intervals_from_ind(
1435                    np.where(counter[video_name][clip_id])[0]
1436                )
1437        return result, self.ids
1438
1439    def get_unannotated_intervals(self, first_intervals=None) -> Dict:
1440        """
1441        Get a list of intervals in the original coordinates where there is no annotation
1442
1443        Returns
1444        -------
1445        intervals : dict
1446            a nested dictionary where first-level keys are video ids, second-level keys are clip ids and
1447            values are lists of the intervals in `[start, end]` format
1448        """
1449
1450        counter_value = 2
1451        if first_intervals is None:
1452            first_intervals = defaultdict(lambda: defaultdict(lambda: []))
1453            counter_value = 1
1454        counter = defaultdict(lambda: {})
1455        coordinates = self.input_store.get_original_coordinates()
1456        for i, coords in enumerate(coordinates):
1457            l = self.input_store.get_clip_length_from_coords(coords)
1458            ann = self.annotation_store[i]
1459            if (
1460                self.annotation_store.annotation_class()
1461                == "nonexclusive_classification"
1462            ):
1463                ann = ann[0, :]
1464            video_name = self.input_store.get_video_id(coords)
1465            clip_id = self.input_store.get_clip_id(coords)
1466            start, end = self.input_store.get_clip_start_end(coords)
1467            if clip_id not in counter[video_name]:
1468                counter[video_name][clip_id] = np.ones(l)
1469            counter[video_name][clip_id][start:end] = (ann[: end - start] == -100).int()
1470        result = {video_name: {} for video_name in counter}
1471        for video_name in counter:
1472            for clip_id in counter[video_name]:
1473                for start, end in first_intervals[video_name][clip_id]:
1474                    counter[video_name][clip_id][start:end] += 1
1475                result[video_name][clip_id] = self._get_intervals_from_ind(
1476                    np.where(counter[video_name][clip_id] == counter_value)[0]
1477                )
1478        return result
1479
1480    def get_annotated_intervals(self) -> Dict:
1481        """
1482        Get a list of intervals in the original coordinates where there is annotation
1483
1484        Returns
1485        -------
1486        intervals : dict
1487            a nested dictionary where first-level keys are video ids, second-level keys are clip ids and
1488            values are lists of the intervals in `[start, end]` format
1489        """
1490
1491        if self.annotation_type == "none":
1492            return {}
1493        counter_value = 1
1494        counter = defaultdict(lambda: {})
1495        coordinates = self.input_store.get_original_coordinates()
1496        for i, coords in enumerate(coordinates):
1497            l = self.input_store.get_clip_length_from_coords(coords)
1498            ann = self.annotation_store[i]
1499            video_name = self.input_store.get_video_id(coords)
1500            clip_id = self.input_store.get_clip_id(coords)
1501            start, end = self.input_store.get_clip_start_end(coords)
1502            if clip_id not in counter[video_name]:
1503                counter[video_name][clip_id] = np.zeros(l)
1504            if (
1505                self.annotation_store.annotation_class()
1506                == "nonexclusive_classification"
1507            ):
1508                counter[video_name][clip_id][start:end] = (
1509                    torch.sum(ann[:, : end - start] != -100, dim=0) > 0
1510                ).int()
1511            else:
1512                counter[video_name][clip_id][start:end] = (
1513                    ann[: end - start] != -100
1514                ).int()
1515        result = {video_name: {} for video_name in counter}
1516        for video_name in counter:
1517            for clip_id in counter[video_name]:
1518                result[video_name][clip_id] = self._get_intervals_from_ind(
1519                    np.where(counter[video_name][clip_id] == counter_value)[0]
1520                )
1521        return result
1522
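A sketch of how the interval queries above might be combined to inspect coverage, assuming `dataset` was built with annotations:

    covered, _ids = dataset.get_intervals()        # everything covered by the samples
    annotated = dataset.get_annotated_intervals()  # frames whose labels are not -100
    unannotated = dataset.get_unannotated_intervals()

    for video_id, clips in annotated.items():
        for clip_id, intervals in clips.items():
            for start, end in intervals:
                print(video_id, clip_id, start, end)
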
1523    def get_ids(self) -> Dict:
1524        """
1525        Get a dictionary of all clip ids in the dataset
1526
1527        Returns
1528        -------
1529        ids : dict
1530            a dictionary where keys are video ids and values are lists of clip ids
1531        """
1532
1533        coordinates = self.input_store.get_original_coordinates()
1534        video_ids = np.array(self.input_store.get_video_id_order())
1535        id_set = set(video_ids)
1536        result = {}
1537        for video_id in id_set:
1538            coords = coordinates[video_ids == video_id]
1539            clip_ids = list({self.input_store.get_clip_id(c) for c in coords})
1540            result[video_id] = clip_ids
1541        return result
1542
1543    def get_len(self, video_id: str, clip_id: str) -> int:
1544        """
1545        Get the length of a specific clip
1546
1547        Parameters
1548        ----------
1549        video_id : str
1550            the video id
1551        clip_id : str
1552            the clip id
1553
1554        Returns
1555        -------
1556        length : int
1557            the length
1558        """
1559
1560        return self.input_store.get_clip_length(video_id, clip_id)
1561
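A sketch that walks over every clip in the dataset and queries its length:

    for video_id, clip_ids in dataset.get_ids().items():
        for clip_id in clip_ids:
            n_frames = dataset.get_len(video_id, clip_id)
            print(f"{video_id}/{clip_id}: {n_frames} frames")
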
1562    def get_confusion_matrix(
1563        self, prediction: torch.Tensor, confusion_type: str = "recall"
1564    ) -> Tuple[ndarray, list, Optional[str]]:
1565        """
1566        Get a confusion matrix
1567
1568        Parameters
1569        ----------
1570        prediction : torch.Tensor
1571            a tensor of predicted class probabilities of shape `(#samples, #classes, #frames)`
1572        confusion_type : {"recall", "precision"}
1573            for datasets with non-exclusive annotation, if `confusion_type` is `"recall"`, only false positives
1574            are taken into account, and if `confusion_type` is `"precision"`, only false negatives
1575
1576        Returns
1577        -------
1578        confusion_matrix : np.ndarray
1579            a confusion matrix of shape `(#classes, #classes)` where `A[i, j] = F_ij/N_i`, `F_ij` is the number of
1580            frames that have the i-th label in the ground truth and a false positive j-th label in the prediction,
1581            `N_i` is the number of frames that have the i-th label in the ground truth
1582        classes : list
1583            a list of classes
1584        """
1585
1586        behaviors_dict = self.annotation_store.behaviors_dict()
1587        num_behaviors = len(behaviors_dict)
1588        confusion_matrix = np.zeros((num_behaviors, num_behaviors))
1589        if self.annotation_store.annotation_class() == "exclusive_classification":
1590            exclusive = True
1591            confusion_type = None
1592        elif self.annotation_store.annotation_class() == "nonexclusive_classification":
1593            exclusive = False
1594        else:
1595            raise RuntimeError(
1596                f"The {self.annotation_store.annotation_class()} annotation class is not recognized!"
1597            )
1598        for ann, p in zip(self.annotation_store, prediction):
1599            if exclusive:
1600                class_prediction = torch.max(p, dim=0)[1]
1601                for i in behaviors_dict.keys():
1602                    for j in behaviors_dict.keys():
1603                        confusion_matrix[i, j] += int(
1604                            torch.sum(class_prediction[ann == i] == j)
1605                        )
1606            else:
1607                class_prediction = (p > 0.5).int()
1608                for i in behaviors_dict.keys():
1609                    for j in behaviors_dict.keys():
1610                        if confusion_type == "recall":
1611                            pred = deepcopy(class_prediction[j])
1612                            if i != j:
1613                                pred[ann[j] == 1] = 0
1614                            confusion_matrix[i, j] += int(torch.sum(pred[ann[i] == 1]))
1615                        elif confusion_type == "precision":
1616                            annotation = deepcopy(ann[j])
1617                            if i != j:
1618                                annotation[class_prediction[j] == 1] = 0
1619                            confusion_matrix[i, j] += int(
1620                                torch.sum(annotation[class_prediction[i] == 1])
1621                            )
1622                        else:
1623                            raise ValueError(
1624                                f"The {confusion_type} type is not recognized; please choose from ['recall', 'precision']"
1625                            )
1626        counter = self.annotation_store.count_classes()
1627        for i in behaviors_dict.keys():
1628            if counter[i] != 0:
1629                if confusion_type == "recall" or confusion_type is None:
1630                    confusion_matrix[i, :] /= counter[i]
1631                else:
1632                    confusion_matrix[:, i] /= counter[i]
1633        return confusion_matrix, list(behaviors_dict.values()), confusion_type
1634
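A sketch of computing a recall-style confusion matrix from model probabilities; `prediction` is a hypothetical tensor of shape `(#samples, #classes, #frames)` ordered like the dataset samples:

    matrix, class_names, used_type = dataset.get_confusion_matrix(prediction, confusion_type="recall")
    print(class_names)
    print(matrix.round(3))
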
1635    def annotation_class(self) -> str:
1636        """
1637        Get the type of annotation ('exclusive_classification', 'nonexclusive_classification', more coming soon)
1638
1639        Returns
1640        -------
1641        annotation_class : str
1642            the type of annotation
1643        """
1644
1645        return self.annotation_store.annotation_class()
1646
1647    def set_normalization_stats(self, stats: Dict) -> None:
1648        """
1649        Set the stats to normalize data at runtime
1650
1651        Parameters
1652        ----------
1653        stats : dict
1654            a nested dictionary where first-level keys are feature key names, second-level keys are 'mean' and 'std'
1655            and values are the statistics in `torch` tensors of shape `(#features, 1)`
1656        """
1657
1658        self.stats = stats
1659
1660    def get_min_max_frames(self, video_id) -> Tuple[Dict, Dict]:
1661        coords = self.input_store.get_original_coordinates()
1662        clips = set(
1663            [
1664                self.input_store.get_clip_id(c)
1665                for c in coords
1666                if self.input_store.get_video_id(c) == video_id
1667            ]
1668        )
1669        min_frames = {}
1670        max_frames = {}
1671        for clip in clips:
1672            start = self.input_store.get_clip_start(video_id, clip)
1673            end = start + self.input_store.get_clip_length(video_id, clip)
1674            min_frames[clip] = start
1675            max_frames[clip] = end - 1
1676        return min_frames, max_frames
1677
1678    def get_normalization_stats(self, skip_keys=None) -> Dict:
1679        """
1680        Get mean and standard deviation for each key
1681
1682        Returns
1683        -------
1684        stats : dict
1685            a nested dictionary where first-level keys are feature key names, second-level keys are 'mean' and 'std'
1686            and values are the statistics in `torch` tensors of shape `(#features, 1)`
1687        """
1688
1689        stats = defaultdict(lambda: {})
1690        sums = defaultdict(lambda: 0)
1691        if skip_keys is None:
1692            skip_keys = []
1693        counter = defaultdict(lambda: 0)
1694        for sample in tqdm(self):
1695            for key, value in sample["input"].items():
1696                key_name = key.split("---")[0]
1697                if key_name not in skip_keys:
1698                    sums[key_name] += value[:, value.sum(0) != 0].sum(-1)
1699                counter[key_name] += torch.sum(value.sum(0) != 0)
1700        for key, value in sums.items():
1701            stats[key]["mean"] = (value / counter[key]).unsqueeze(-1)
1702        sums = defaultdict(lambda: 0)
1703        for sample in tqdm(self):
1704            for key, value in sample["input"].items():
1705                key_name = key.split("---")[0]
1706                if key_name not in skip_keys:
1707                    sums[key_name] += (
1708                        (value[:, value.sum(0) != 0] - stats[key_name]["mean"]) ** 2
1709                    ).sum(-1)
1710        for key, value in sums.items():
1711            stats[key]["std"] = np.sqrt(value.unsqueeze(-1) / counter[key])
1712        return stats
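
One plausible hand-off of the normalization statistics between splits, assuming `train_dataset` and `val_dataset` already exist (the key to skip is a made-up example):

    stats = train_dataset.get_normalization_stats(skip_keys=["loaded"])  # "loaded" is a placeholder key name
    train_dataset.set_normalization_stats(stats)
    val_dataset.set_normalization_stats(stats)  # reuse the training statistics at validation time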

A generalized dataset class

Data and annotation are stored in separate InputStore and AnnotationStore objects; the dataset class manages their interactions.

BehaviorDataset( data_type: str, annotation_type: str = 'none', ssl_transformations: List = None, saved_data_path: str = None, input_store: dlc2action.data.base_store.InputStore = None, annotation_store: dlc2action.data.base_store.AnnotationStore = None, only_load_annotated: bool = False, recompute_annotation: bool = False, ids: List = None, **data_parameters)
 44    def __init__(
 45        self,
 46        data_type: str,
 47        annotation_type: str = "none",
 48        ssl_transformations: List = None,
 49        saved_data_path: str = None,
 50        input_store: InputStore = None,
 51        annotation_store: AnnotationStore = None,
 52        only_load_annotated: bool = False,
 53        recompute_annotation: bool = False,
 54        # mask: str = None,
 55        ids: List = None,
 56        **data_parameters,
 57    ) -> None:
 58        """
 59        Parameters
 60        ----------
 61        data_type : str
 62            the data type (see available types by running BehaviorDataset.data_types())
 63        annotation_type : str
 64            the annotation type (see available types by running BehaviorDataset.annotation_types())
 65        ssl_transformations : list
 66            a list of functions that take a sample dictionary as input and return an (ssl input, ssl target) tuple
 67        saved_data_path : str
 68            the path to a pre-computed pickled dataset
 69        input_store : InputStore
 70            a pre-computed input store
 71        annotation_store : AnnotationStore
 72            a precomputed annotation store
 73        only_load_annotated : bool
 74            if True, the input files that don't have a matching annotation file will be disregarded
 75        *data_parameters : dict
 76            parameters to initialize the input and annotation stores
 77        """
 78
 79        mask = None
 80        if len(data_parameters) == 0:
 81            recompute_annotation = False
 82        feature_extraction = data_parameters.get("feature_extraction")
 83        if feature_extraction is not None and not issubclass(
 84            options.input_stores[data_type],
 85            options.feature_extractors[feature_extraction].input_store_class,
 86        ):
 87            raise ValueError(
 88                f"The {feature_extraction} feature extractor does not work with "
 89                f"the {data_type} data type, please choose a suclass of "
 90                f"{options.feature_extractors[feature_extraction].input_store_class}"
 91            )
 92        if ssl_transformations is None:
 93            ssl_transformations = []
 94        self.ssl_transformations = ssl_transformations
 95        self.input_type = data_type
 96        self.annotation_type = annotation_type
 97        self.stats = None
 98        if mask is not None:
 99            with open(mask, "rb") as f:
100                self.mask = pickle.load(f)
101        else:
102            self.mask = None
103        self.ids = ids
104        self.tag = None
105        self.return_unlabeled = None
106        # load saved key objects for annotation and input if they exist
107        input_key_objects, annotation_key_objects = None, None
108        if saved_data_path is not None:
109            if os.path.exists(saved_data_path):
110                with open(saved_data_path, "rb") as f:
111                    input_key_objects, annotation_key_objects = pickle.load(f)
112        # if the input or the annotation store need to be created, generate the common video order
113        if len(data_parameters) > 0:
114            input_files = options.input_stores[data_type].get_file_ids(
115                **data_parameters
116            )
117            annotation_files = options.annotation_stores[annotation_type].get_file_ids(
118                **data_parameters
119            )
120            if only_load_annotated:
121                data_parameters["video_order"] = [
122                    x for x in input_files if x in annotation_files
123                ]
124            else:
125                data_parameters["video_order"] = input_files
126            if len(data_parameters["video_order"]) == 0:
127                raise RuntimeError(
128                    "The length of file list is 0! Please check your data parameters!"
129                )
130        data_parameters["mask"] = self.mask
131        # load or create the input store
132        ok = False
133        if input_store is not None:
134            self.input_store = input_store
135            ok = True
136        elif input_key_objects is not None:
137            try:
138                self.input_store = self._load_input_store(data_type, input_key_objects)
139                ok = True
140            except:
141                warnings.warn("Loading input store from key objects failed")
142        if not ok:
143            self.input_store = self._get_input_store(
144                data_type, deepcopy(data_parameters)
145            )
146        # get the objects needed to create the annotation store (like a clip length dictionary)
147        annotation_objects = self.input_store.get_annotation_objects()
148        data_parameters.update(annotation_objects)
149        # load or create the annotation store
150        ok = False
151        if annotation_store is not None:
152            self.annotation_store = annotation_store
153            ok = True
154        elif (
155            annotation_key_objects is not None
156            and mask is None
157            and not recompute_annotation
158        ):
159            try:
160                self.annotation_store = self._load_annotation_store(
161                    annotation_type, annotation_key_objects
162                )
163                ok = True
164            except:
165                warnings.warn("Loading annotation store from key objects failed")
166        if not ok:
167            self.annotation_store = self._get_annotation_store(
168                annotation_type, deepcopy(data_parameters)
169            )
170        if (
171            mask is None
172            and annotation_type != "none"
173            and not recompute_annotation
174            and (
175                self.annotation_store.get_original_coordinates()
176                != self.input_store.get_original_coordinates()
177            ).any()
178        ):
179            raise RuntimeError(
180                "The clip orders in the annotation store and input store are different!"
181            )
182        # filter the data based on data parameters
183        # print(f"1 {self.annotation_store.get_original_coordinates().shape=}")
184        # print(f"1 {self.input_store.get_original_coordinates().shape=}")
185        to_remove = self.annotation_store.filtered_indices()
186        if len(to_remove) > 0:
187            print(
188                f"Filtering {100 * len(to_remove) / len(self.annotation_store):.2f}% of samples"
189            )
190        if len(self.input_store) == len(self.annotation_store):
191            self.input_store.remove(to_remove)
192        self.annotation_store.remove(to_remove)
193        self.input_indices = list(range(len(self.input_store)))
194        self.annotation_indices = list(range(len(self.input_store)))
195        self.indices = list(range(len(self.input_store)))
196        # print(f'{data_parameters["video_order"]=}')
197        # print(f"{self.annotation_store.get_original_coordinates().shape=}")
198        # print(f"{self.input_store.get_original_coordinates().shape=}")
199        # count = 0
200        # for i, (x, y) in enumerate(zip(
201        #     self.annotation_store.get_original_coordinates(),
202        #     self.input_store.get_original_coordinates(),
203        # )):
204        #     if (x != y).any():
205        #         count += 1
206        #         print({i})
207        #         print(f"ann: {x}")
208        #         print(f"inp: {y}")
209        #         print("\n")
210        #     if count > 50:
211        #         break
212        if annotation_type != "none" and (
213            self.annotation_store.get_original_coordinates().shape
214            != self.input_store.get_original_coordinates().shape
215            or (
216                self.annotation_store.get_original_coordinates()
217                != self.input_store.get_original_coordinates()
218            ).any()
219        ):
220            raise RuntimeError(
221                "The clip orders in the annotation store and input store are different!"
222            )

Parameters

data_type : str
    the data type (see available types by running BehaviorDataset.data_types())
annotation_type : str
    the annotation type (see available types by running BehaviorDataset.annotation_types())
ssl_transformations : list
    a list of functions that take a sample dictionary as input and return an (ssl input, ssl target) tuple
saved_data_path : str
    the path to a pre-computed pickled dataset
input_store : InputStore
    a pre-computed input store
annotation_store : AnnotationStore
    a precomputed annotation store
only_load_annotated : bool
    if True, the input files that don't have a matching annotation file will be disregarded
*data_parameters : dict
    parameters to initialize the input and annotation stores

def get_tags(self) -> List:
255    def get_tags(self) -> List:
256        """
257        Get a list of all meta tags
258
259        Returns
260        -------
261        tags: List
262            a list of unique meta tag values
263        """
264
265        return self.input_store.get_tags()

Get a list of all meta tags

Returns

tags : List
    a list of unique meta tag values

def save(self, save_path: str) -> None:
267    def save(self, save_path: str) -> None:
268        """
 269        Save the dataset to a pickled file
270
271        Parameters
272        ----------
273        save_path : str
274            the path where the pickled file will be stored
275        """
276
277        input_obj = self.input_store.key_objects()
278        annotation_obj = self.annotation_store.key_objects()
279        with open(save_path, "wb") as f:
280            pickle.dump((input_obj, annotation_obj), f)

Save the dataset to a pickled file

Parameters

save_path : str
    the path where the pickled file will be stored
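
A sketch of the save/reload round trip using the `saved_data_path` constructor argument; the path and type names are placeholders:

    from dlc2action.data.dataset import BehaviorDataset

    dataset.save("dataset_cache.pickle")

    # later: rebuild the stores from the pickled key objects instead of re-reading raw files
    reloaded = BehaviorDataset(
        data_type="dlc_track",    # placeholder; must match the original dataset
        annotation_type="csv",    # placeholder
        saved_data_path="dataset_cache.pickle",
    )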

def to_ram(self) -> None:
282    def to_ram(self) -> None:
283        """
284        Transfer the dataset to RAM
285        """
286
287        self.input_store.to_ram()
288        self.annotation_store.to_ram()

Transfer the dataset to RAM

def generate_full_length_gt(self) -> Dict:
290    def generate_full_length_gt(self) -> Dict:
291        if self.annotation_class() == "exclusive_classification":
292            gt = torch.zeros((len(self), self.len_segment()))
293        else:
294            gt = torch.zeros(
295                (len(self), len(self.behaviors_dict()), self.len_segment())
296            )
297        for i in range(len(self)):
298            gt[i] = self.annotation_store[i]
299        return self.generate_full_length_prediction(gt)
def generate_full_length_prediction(self, predicted: torch.Tensor) -> Dict:
301    def generate_full_length_prediction(self, predicted: torch.Tensor) -> Dict:
302        """
303        Map predictions for the equal-length pieces to predictions for the original data
304
305        Probabilities are averaged over predictions on overlapping intervals.
306
307        Parameters
308        ----------
309        predicted: torch.Tensor
310            a tensor of predicted probabilities of shape `(N, #classes, #frames)`
311
312        Returns
313        -------
314        full_length_prediction : dict
315            a nested dictionary where first-level keys are video ids, second-level keys are clip ids and values are
316            averaged probability tensors
317        """
318
319        result = defaultdict(lambda: {})
320        counter = defaultdict(lambda: {})
321        coordinates = self.input_store.get_original_coordinates()
322        for coords, prediction in zip(coordinates, predicted):
323            l = self.input_store.get_clip_length_from_coords(coords)
324            video_name = self.input_store.get_video_id(coords)
325            clip_id = self.input_store.get_clip_id(coords)
326            start, end = self.input_store.get_clip_start_end(coords)
327            if clip_id not in result[video_name].keys():
328                result[video_name][clip_id] = torch.zeros(*prediction.shape[:-1], l)
329                counter[video_name][clip_id] = torch.zeros(*prediction.shape[:-1], l)
330            result[video_name][clip_id][..., start:end] += (
331                prediction.squeeze()[..., : end - start].detach().cpu()
332            )
333            counter[video_name][clip_id][..., start:end] += 1
334        for video_name in result:
335            for clip_id in result[video_name]:
336                result[video_name][clip_id] /= counter[video_name][clip_id]
337                result[video_name][clip_id][counter[video_name][clip_id] == 0] = -100
338        result = dict(result)
339        return result

Map predictions for the equal-length pieces to predictions for the original data

Probabilities are averaged over predictions on overlapping intervals.

Parameters

predicted : torch.Tensor
    a tensor of predicted probabilities of shape (N, #classes, #frames)

Returns

full_length_prediction : dict
    a nested dictionary where first-level keys are video ids, second-level keys are clip ids and values are averaged probability tensors
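
A sketch of stitching per-segment outputs back onto the original timeline; here `predicted` is a hypothetical probability tensor ordered like the dataset samples:

    import torch

    predicted = torch.rand(len(dataset), dataset.num_classes(), dataset.len_segment())
    full_length = dataset.generate_full_length_prediction(predicted)
    for video_id, clips in full_length.items():
        for clip_id, probs in clips.items():
            print(video_id, clip_id, tuple(probs.shape))  # (#classes, total clip length)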

def find_valleys( self, predicted: Union[torch.Tensor, Dict], threshold: float = 0.5, min_frames: int = 0, visibility_min_score: float = 0, visibility_min_frac: float = 0, main_class: int = 1, low: bool = True, predicted_error: torch.Tensor = None, error_threshold: float = 0.5, hysteresis: bool = False, threshold_diff: float = None, min_frames_error: int = None, smooth_interval: int = 1, cut_annotated: bool = False) -> Dict:
341    def find_valleys(
342        self,
343        predicted: Union[torch.Tensor, Dict],
344        threshold: float = 0.5,
345        min_frames: int = 0,
346        visibility_min_score: float = 0,
347        visibility_min_frac: float = 0,
348        main_class: int = 1,
349        low: bool = True,
350        predicted_error: torch.Tensor = None,
351        error_threshold: float = 0.5,
352        hysteresis: bool = False,
353        threshold_diff: float = None,
354        min_frames_error: int = None,
355        smooth_interval: int = 1,
356        cut_annotated: bool = False,
357    ) -> Dict:
358        """
 359        Find the intervals where the probability of a certain class is below or above a certain threshold
360
361        Parameters
362        ----------
363        predicted : torch.Tensor | dict
364            either a tensor of predictions for the data prompts or the output of
365            `BehaviorDataset.generate_full_length_prediction`
366        threshold : float, default 0.5
 367            the main threshold
368        min_frames : int, default 0
369            the minimum length of the intervals
370        visibility_min_score : float, default 0
371            the minimum visibility score in the intervals
372        visibility_min_frac : float, default 0
 373            fraction of the interval that has to have the visibility score larger than visibility_min_score
374        main_class : int, default 1
375            the index of the class the function is inspecting
376        low : bool, default True
 377            if True, the probability in the intervals has to be below the threshold, and if False, it has to be above
378        predicted_error : torch.Tensor, optional
379            a tensor of error predictions for the data prompts
380        error_threshold : float, default 0.5
381            maximum possible probability of error at the intervals
382        hysteresis: bool, default False
 383            if True, the function will apply a hysteresis threshold with the soft threshold defined by threshold_diff
 384        threshold_diff: float, optional
 385            the difference between the soft and hard thresholds if hysteresis is used; if hysteresis is True, low is False and threshold_diff is None, the soft threshold condition is set to the main_class having a larger probability than other classes
386        min_frames_error: int, optional
387            if not None, the intervals will only be considered where the error probability is below error_threshold at at least min_frames_error consecutive frames
388
389        Returns
390        -------
391        valleys : dict
392            a dictionary where keys are video ids and values are lists of (start, end, individual name) tuples that denote the chosen intervals
393        """
394
395        result = defaultdict(lambda: [])
396        if type(predicted) is not dict:
397            predicted = self.generate_full_length_prediction(predicted)
398        if predicted_error is not None:
399            predicted_error = self.generate_full_length_prediction(predicted_error)
400        elif min_frames_error is not None and min_frames_error != 0:
401            # warnings.warn(
402            #     f"The min_frames_error parameter is set to {min_frames_error} but no error prediction "
403            #     f"is given! Setting min_frames_error to 0."
404            # )
405            min_frames_error = 0
406        if low and hysteresis and threshold_diff is None:
407            raise ValueError(
408                "Cannot set low=True, hysteresis=True and threshold_diff=None! Please set threshold_diff."
409            )
410        if cut_annotated:
411            masked_intervals_dict = self.get_annotated_intervals()
412        else:
413            masked_intervals_dict = None
414        print("Valleys found:")
415        for v_id in predicted:
416            for clip_id in predicted[v_id].keys():
417                if predicted_error is not None:
418                    error_mask = predicted_error[v_id][clip_id][1, :] < error_threshold
419                    if min_frames_error is not None:
420                        output, indices, counts = torch.unique_consecutive(
421                            error_mask, return_inverse=True, return_counts=True
422                        )
423                        wrong_indices = torch.where(
424                            output * (counts < min_frames_error)
425                        )[0]
426                        if len(wrong_indices) > 0:
427                            for i in wrong_indices:
428                                error_mask[indices == i] = False
429                else:
430                    error_mask = None
431                if masked_intervals_dict is not None:
432                    masked_intervals = masked_intervals_dict[v_id][clip_id]
433                else:
434                    masked_intervals = None
435                if not hysteresis:
436                    res_indices_start, res_indices_end = apply_threshold(
437                        predicted[v_id][clip_id][main_class, :],
438                        threshold,
439                        low,
440                        error_mask,
441                        min_frames,
442                        smooth_interval,
443                        masked_intervals,
444                    )
445                elif threshold_diff is not None:
446                    if low:
447                        soft_threshold = threshold + threshold_diff
448                    else:
449                        soft_threshold = threshold - threshold_diff
450                    res_indices_start, res_indices_end = apply_threshold_hysteresis(
451                        predicted[v_id][clip_id][main_class, :],
452                        soft_threshold,
453                        threshold,
454                        low,
455                        error_mask,
456                        min_frames,
457                        smooth_interval,
458                        masked_intervals,
459                    )
460                else:
461                    res_indices_start, res_indices_end = apply_threshold_max(
462                        predicted[v_id][clip_id],
463                        threshold,
464                        main_class,
465                        error_mask,
466                        min_frames,
467                        smooth_interval,
468                        masked_intervals,
469                    )
470                start = self.input_store.get_clip_start(v_id, clip_id)
471                result[v_id] += [
472                    [i + start, j + start, clip_id]
473                    for i, j in zip(res_indices_start, res_indices_end)
474                    if self.input_store.get_visibility(
475                        v_id, clip_id, i, j, visibility_min_score
476                    )
477                    > visibility_min_frac
478                ]
479            result[v_id] = sorted(result[v_id])
480            print(f"    {v_id}: {len(result[v_id])}")
481        return dict(result)

Find the intervals where the probability of a certain class is below or above a certain threshold

Parameters

predicted : torch.Tensor | dict
    either a tensor of predictions for the data prompts or the output of BehaviorDataset.generate_full_length_prediction
threshold : float, default 0.5
    the main (hard) threshold
min_frames : int, default 0
    the minimum length of the intervals
visibility_min_score : float, default 0
    the minimum visibility score in the intervals
visibility_min_frac : float, default 0
    the fraction of the interval that needs to have a visibility score larger than visibility_min_score
main_class : int, default 1
    the index of the class the function is inspecting
low : bool, default True
    if True, the probability in the intervals has to be below the threshold, and if False, it has to be above
predicted_error : torch.Tensor, optional
    a tensor of error predictions for the data prompts
error_threshold : float, default 0.5
    the maximum possible probability of error at the intervals
hysteresis : bool, default False
    if True, the function will apply hysteresis thresholding with the soft threshold defined by threshold_diff
threshold_diff : float, optional
    the difference between the soft and hard thresholds if hysteresis is used; if hysteresis is True, low is False and threshold_diff is None, the soft threshold condition is set to the main_class having a larger probability than the other classes
min_frames_error : int, optional
    if not None, intervals are only considered where the error probability stays below error_threshold for at least min_frames_error consecutive frames

Returns

valleys : dict
    a dictionary where keys are video ids and values are lists of (start, end, individual name) tuples that denote the chosen intervals
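
A minimal usage sketch of the method documented above; it is assumed here that the method is exposed as `find_valleys` (the signature line falls outside this excerpt) and that `dataset` and `prediction` already exist.

    # A sketch, assuming `dataset` is an existing BehaviorDataset and `prediction`
    # holds class probabilities for its samples; the method name `find_valleys`
    # is an assumption, as the signature line is not shown above.
    valleys = dataset.find_valleys(
        predicted=prediction,
        threshold=0.4,
        min_frames=15,
        main_class=1,
        low=True,            # look for intervals where the class probability is low
        hysteresis=True,
        threshold_diff=0.1,  # soft threshold = 0.4 + 0.1 because low=True
    )
    for video_id, intervals in valleys.items():
        for start, end, clip_id in intervals:
            print(f"{video_id}/{clip_id}: frames {start}-{end}")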

def valleys_union(self, valleys_list) -> Dict:
483    def valleys_union(self, valleys_list) -> Dict:
484        """
485        Find the union of a list of valleys dictionaries
486
487        Parameters
488        ----------
489        valleys_list : list
490            a list of valleys dictionaries
491
492        Returns
493        -------
494        union : dict
495            a new valleys dictionary with the union of the input intervals
496        """
497
498        valleys_list = [x for x in valleys_list if x is not None]
499        if len(valleys_list) == 1:
500            return valleys_list[0]
501        elif len(valleys_list) == 0:
502            return {}
503        union = {}
504        keys_list = [set(valleys.keys()) for valleys in valleys_list]
505        keys = set.union(*keys_list)
506        for v_id in keys:
507            res = []
508            clips_list = [
509                set([x[-1] for x in valleys[v_id]]) for valleys in valleys_list
510            ]
511            clips = set.union(*clips_list)
512            for clip_id in clips:
513                clip_intervals = [
514                    x
515                    for valleys in valleys_list
516                    for x in valleys[v_id]
517                    if x[-1] == clip_id
518                ]
519                v_len = self.input_store.get_clip_length(v_id, clip_id)
520                arr = torch.zeros(v_len)
521                for start, end, _ in clip_intervals:
522                    arr[start:end] += 1
523                output, indices, counts = torch.unique_consecutive(
524                    arr > 0, return_inverse=True, return_counts=True
525                )
526                long_indices = torch.where(output)[0]
527                res += [
528                    (
529                        (indices == i).nonzero(as_tuple=True)[0][0].item(),
530                        (indices == i).nonzero(as_tuple=True)[0][-1].item(),
531                        clip_id,
532                    )
533                    for i in long_indices
534                ]
535            union[v_id] = res
536        return union

Find the union of a list of valleys dictionaries

Parameters

valleys_list : list a list of valleys dictionaries

Returns

union : dict a new valleys dictionary with the union of the input intervals
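
A hedged sketch of combining two valleys dictionaries, for example produced by the valley search above for two different models; the variable names are assumptions.

    # A sketch, assuming valleys_a and valleys_b are valleys dictionaries
    # returned by the valley search above.
    combined = dataset.valleys_union([valleys_a, valleys_b])
    # `combined` keeps every interval covered by at least one of the inputs,
    # merged per video id and clip id.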

def valleys_intersection(self, valleys_list) -> Dict:
538    def valleys_intersection(self, valleys_list) -> Dict:
539        """
540        Find the intersection of two valleys dictionaries
541
542        Parameters
543        ----------
544        valleys_list : list
545            a list of valleys dictionaries
546
547        Returns
548        -------
549        intersection : dict
550            a new valleys dictionary with the intersection of the input intervals
551        """
552
553        valleys_list = [x for x in valleys_list if x is not None]
554        if len(valleys_list) == 1:
555            return valleys_list[0]
556        elif len(valleys_list) == 0:
557            return {}
558        intersection = {}
559        keys_list = [set(valleys.keys()) for valleys in valleys_list]
560        keys = set.intersection(*keys_list)
561        for v_id in keys:
562            res = []
563            clips_list = [
564                set([x[-1] for x in valleys[v_id]]) for valleys in valleys_list
565            ]
566            clips = set.intersection(*clips_list)
567            for clip_id in clips:
568                clip_intervals = [
569                    x
570                    for valleys in valleys_list
571                    for x in valleys[v_id]
572                    if x[-1] == clip_id
573                ]
574                v_len = self.input_store.get_clip_length(v_id, clip_id)
575                arr = torch.zeros(v_len)
576                for start, end, _ in clip_intervals:
577                    arr[start:end] += 1
578                output, indices, counts = torch.unique_consecutive(
579                    arr, return_inverse=True, return_counts=True
580                )
581                long_indices = torch.where(output == 2)[0]
582                res += [
583                    (
584                        (indices == i).nonzero(as_tuple=True)[0][0].item(),
585                        (indices == i).nonzero(as_tuple=True)[0][-1].item(),
586                        clip_id,
587                    )
588                    for i in long_indices
589                ]
590            intersection[v_id] = res
591        return intersection

Find the intersection of two valleys dictionaries

Parameters

valleys_list : list a list of valleys dictionaries

Returns

intersection : dict a new valleys dictionary with the intersection of the input intervals
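
Correspondingly, a hedged sketch of intersecting the same two dictionaries; the implementation looks for frames covered exactly twice, so it is intended for pairs of inputs.

    # A sketch: only the frames present in both inputs survive.
    common = dataset.valleys_intersection([valleys_a, valleys_b])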

def partition_train_test_val( self, use_test: float = 0, split_path: str = None, method: str = 'random', val_frac: float = 0, test_frac: float = 0, save_split: bool = False, normalize: bool = False, skip_normalization_keys: List = None, stats: Dict = None) -> Tuple:
593    def partition_train_test_val(
594        self,
595        use_test: float = 0,
596        split_path: str = None,
597        method: str = "random",
598        val_frac: float = 0,
599        test_frac: float = 0,
600        save_split: bool = False,
601        normalize: bool = False,
602        skip_normalization_keys: List = None,
603        stats: Dict = None,
604    ) -> Tuple:
605        """
606        Partition the dataset into three new datasets
607
608        Parameters
609        ----------
610        use_test : float, default 0
611            The fraction of the test dataset to be used in training without labels
612        split_path : str, optional
613            The path to load the split information from (if `'file'` method is used) and to save it to
614            (if `'save_split'` is `True`)
615        method : {'random', 'random:test-from-name', 'random:test-from-name:{name}',
616            'val-from-name:{val_name}:test-from-name:{test_name}',
617            'random:equalize:segments', 'random:equalize:videos',
618            'folders', 'time', 'time:strict', 'file'}
619            The partitioning method:
620            - `'random'`: sort videos into subsets randomly,
621            - `'random:test-from-name'` (or `'random:test-from-name:{name}'`): sort videos into training and validation
622                subsets randomly and create
623                the test subset from the video ids that start with a specific substring (`'test'` by default, or `name`
624                if provided),
625            - `'random:equalize:segments'` and `'random:equalize:videos'`: sort videos into subsets randomly but
626                making sure that for the rarest classes at least `0.8 * val_frac` of the videos/segments that contain
627                occurrences of the class get into the validation subset and `0.8 * test_frac` get into the test subset;
628                this is ensured for all classes in order of increasing number of occurrences until the validation and
629                test subsets are full
630            - `'val-from-name:{val_name}:test-from-name:{test_name}'`: create the validation and test
631                subsets from the video ids that start with specific substrings (`val_name` for validation
632                and `test_name` for test) and sort all other videos into the training subset
633            - `'folders'`: read videos from folders named *test*, *train* and *val* into corresponding subsets,
634            - `'time'`: split each video into training, validation and test subsequences,
635            - `'time:strict'`: split each video into validation, test and training subsequences
636                and throw out the last segments in validation and test (to get rid of overlaps),
637            - `'file'`: split according to a split file.
638        val_frac : float, default 0
639            The fraction of the dataset to be used in validation
640        test_frac : float, default 0
641            The fraction of the dataset to be used in test
642        save_split : bool, default False
643            Save a split file if True
644
645        Returns
646        -------
647        train_dataset : BehaviorDataset
648            train dataset
649
650        test_dataset : BehaviorDataset
651            test dataset
652
653        val_dataset : BehaviorDataset
654            validation dataset
655        """
656
657        train_indices, test_indices, val_indices = self._partition_indices(
658            split_path=split_path,
659            method=method,
660            val_frac=val_frac,
661            test_frac=test_frac,
662            save_split=save_split,
663        )
664        val_dataset = self._create_new_dataset(val_indices)
665        test_dataset = self._create_new_dataset(test_indices)
666        train_dataset = self._create_new_dataset(
667            train_indices, ssl_indices=test_indices[: int(len(test_indices) * use_test)]
668        )
669        print("Number of samples:")
670        print(f"    validation:")
671        print(f"      {val_dataset.count_classes()}")
672        print(f"    training:")
673        print(f"      {train_dataset.count_classes()}")
674        print(f"    test:")
675        print(f"      {test_dataset.count_classes()}")
676        if normalize:
677            if stats is None:
678                print("Computing normalization statistics...")
679                stats = train_dataset.get_normalization_stats(skip_normalization_keys)
680            else:
681                print("Setting loaded normalization statistics...")
682            train_dataset.set_normalization_stats(stats)
683            val_dataset.set_normalization_stats(stats)
684            test_dataset.set_normalization_stats(stats)
685        return train_dataset, test_dataset, val_dataset

Partition the dataset into three new datasets

Parameters

use_test : float, default 0
    The fraction of the test dataset to be used in training without labels
split_path : str, optional
    The path to load the split information from (if 'file' method is used) and to save it to (if 'save_split' is True)
method : {'random', 'random:test-from-name', 'random:test-from-name:{name}', 'val-from-name:{val_name}:test-from-name:{test_name}', 'random:equalize:segments', 'random:equalize:videos', 'folders', 'time', 'time:strict', 'file'}
    The partitioning method:
    - 'random': sort videos into subsets randomly,
    - 'random:test-from-name' (or 'random:test-from-name:{name}'): sort videos into training and validation subsets randomly and create the test subset from the video ids that start with a specific substring ('test' by default, or name if provided),
    - 'random:equalize:segments' and 'random:equalize:videos': sort videos into subsets randomly while making sure that for the rarest classes at least 0.8 * val_frac of the videos/segments that contain occurrences of the class get into the validation subset and 0.8 * test_frac get into the test subset; this is ensured for all classes in order of increasing number of occurrences until the validation and test subsets are full,
    - 'val-from-name:{val_name}:test-from-name:{test_name}': create the validation and test subsets from the video ids that start with specific substrings (val_name for validation and test_name for test) and sort all other videos into the training subset,
    - 'folders': read videos from folders named test, train and val into corresponding subsets,
    - 'time': split each video into training, validation and test subsequences,
    - 'time:strict': split each video into validation, test and training subsequences and throw out the last segments in validation and test (to get rid of overlaps),
    - 'file': split according to a split file.
val_frac : float, default 0
    The fraction of the dataset to be used in validation
test_frac : float, default 0
    The fraction of the dataset to be used in test
save_split : bool, default False
    Save a split file if True

Returns

train_dataset : BehaviorDataset train dataset

test_dataset : BehaviorDataset test dataset

val_dataset : BehaviorDataset validation dataset
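
A hedged usage sketch, assuming `dataset` is an existing BehaviorDataset; note that, as the source shows, the tuple is returned in (train, test, val) order and the normalization statistics computed on the training subset are shared with the other two subsets.

    # A sketch: random split with runtime normalization.
    train_ds, test_ds, val_ds = dataset.partition_train_test_val(
        method="random",
        val_frac=0.2,
        test_frac=0.1,
        split_path="split.txt",  # hypothetical file, used with save_split and method="file"
        save_split=True,
        normalize=True,
    )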

def class_weights(self, proportional=False) -> List:
687    def class_weights(self, proportional=False) -> List:
688        """
689        Calculate class weights in inverse proportion to number of samples
690        Returns
691        -------
692        weights: list | dict
693            a list of class weights (a dictionary of two such lists for non-exclusive annotation)
694        """
695
696        items = sorted(
697            [
698                (k, v)
699                for k, v in self.annotation_store.count_classes().items()
700                if k != -100
701            ]
702        )
703        if self.annotation_store.annotation_class() == "exclusive_classification":
704            if not proportional:
705                numerator = len(self.annotation_store)
706            else:
707                numerator = max([x[1] for x in items])
708            weights = [numerator / (v + 1e-7) for _, v in items]
709        else:
710            items_zero = sorted(
711                [
712                    (k, v)
713                    for k, v in self.annotation_store.count_classes(zeros=True).items()
714                    if k != -100
715                ]
716            )
717            if not proportional:
718                numerators = defaultdict(lambda: len(self.annotation_store))
719            else:
720                numerators = {
721                    item_one[0]: max(item_one[1], item_zero[1])
722                    for item_one, item_zero in zip(items, items_zero)
723                }
724            weights = {}
725            weights[0] = [numerators[k] / (v + 1e-7) for k, v in items_zero]
726            weights[1] = [numerators[k] / (v + 1e-7) for k, v in items]
727        return weights

Calculate class weights in inverse proportion to number of samples

Returns

weights: list | dict a list of class weights (a dictionary of two such lists for non-exclusive annotation)
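
A hedged sketch of feeding the weights to a loss function for the exclusive classification case; `train_ds` is assumed to be an existing dataset, and the choice of loss is purely illustrative.

    import torch

    weights = train_ds.class_weights(proportional=False)
    # exclusive classification: `weights` is a list with one weight per class
    criterion = torch.nn.CrossEntropyLoss(weight=torch.tensor(weights))
    # non-exclusive classification: `weights` is instead a {0: [...], 1: [...]} dictionary
    # with weights for the negative and positive state of every class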

def boundary_class_weight(self)
729    def boundary_class_weight(self):
730        if self.annotation_type != "none":
731            f = self.annotation_store.data.flatten()
732            _, inv = torch.unique_consecutive(f, return_inverse=True)
733            boundary = torch.cat([torch.tensor([0]), torch.diff(inv)]).reshape(
734                self.annotation_store.data.shape
735            )
736            boundary[..., 0] = 0
737            cnt = Counter(boundary.flatten().numpy())
738            return cnt[1] / cnt[0]
739        else:
740            return 0
def count_classes(self, bouts: bool = False) -> Dict:
742    def count_classes(self, bouts: bool = False) -> Dict:
743        """
744        Get a class counter dictionary
745
746        Parameters
747        ----------
748        bouts : bool, default False
749            if `True`, instead of frame counts segment counts are returned
750
751        Returns
752        -------
753        count_dictionary : dict
754            a dictionary with class indices as keys and frame or bout counts as values
755        """
756
757        return self.annotation_store.count_classes(bouts=bouts)

Get a class counter dictionary

Parameters

bouts : bool, default False if True, instead of frame counts segment counts are returned

Returns

count_dictionary : dict a dictionary with class indices as keys and frame or bout counts as values
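
A small sketch of inspecting the label distribution, combined with `behaviors_dict` (documented below) to translate class indices into names; `dataset` is assumed to exist.

    frame_counts = dataset.count_classes()           # frames per class index
    bout_counts = dataset.count_classes(bouts=True)  # continuous segments per class index
    names = dataset.behaviors_dict()                 # class index -> behavior name
    for idx, n_frames in frame_counts.items():
        print(names.get(idx, idx), n_frames, "frames,", bout_counts.get(idx, 0), "bouts")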

def behaviors_dict(self) -> Dict:
759    def behaviors_dict(self) -> Dict:
760        """
761        Get a behavior dictionary
762
763        Returns
764        -------
765        dict
766            behavior dictionary
767        """
768
769        return self.annotation_store.behaviors_dict()

Get a behavior dictionary

Returns

dict behavior dictionary

def bodyparts_order(self) -> List:
771    def bodyparts_order(self) -> List:
772        try:
773            return self.input_store.get_bodyparts()
774        except:
775            raise RuntimeError(
776                f"The {self.input_type} input store does not have bodyparts implemented!"
777            )
def features_shape(self) -> Dict:
779    def features_shape(self) -> Dict:
780        """
781        Get the shapes of the input features
782
783        Returns
784        -------
785        shapes : Dict
786            a dictionary with the shapes of the features
787        """
788
789        sample = self.input_store[0]
790        shapes = {k: v.shape for k, v in sample.items()}
791        # for key, value in shapes.items():
792        #     print(f'{key}: {value}')
793        return shapes

Get the shapes of the input features

Returns

shapes : Dict a dictionary with the shapes of the features

def num_classes(self) -> int:
795    def num_classes(self) -> int:
796        """
797        Get the number of classes in the data
798
799        Returns
800        -------
801        num_classes : int
802            the number of classes
803        """
804
805        return len(self.annotation_store.behaviors_dict())

Get the number of classes in the data

Returns

num_classes : int the number of classes

def len_segment(self) -> int:
807    def len_segment(self) -> int:
808        """
809        Get the segment length in the data
810
811        Returns
812        -------
813        len_segment : int
814            the segment length
815        """
816
817        sample = self.input_store[0]
818        key = list(sample.keys())[0]
819        return sample[key].shape[-1]

Get the segment length in the data

Returns

len_segment : int the segment length
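
A hedged sketch of querying the dataset dimensions a model typically needs to be configured with; the feature key names in the returned dictionary depend on the feature extractor.

    shapes = dataset.features_shape()   # e.g. {"<feature key>": (n_features, len_segment), ...}
    n_classes = dataset.num_classes()
    segment_len = dataset.len_segment()
    print(shapes, n_classes, segment_len)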

def set_ssl_transformations(self, ssl_transformations: List) -> None:
821    def set_ssl_transformations(self, ssl_transformations: List) -> None:
822        """
823        Set new SSL transformations
824
825        Parameters
826        ----------
827        ssl_transformations : list
828            a list of functions that take a sample feature dictionary as input and output ssl_inputs and ssl_targets
829            lists
830        """
831
832        self.ssl_transformations = ssl_transformations

Set new SSL transformations

Parameters

ssl_transformations : list a list of functions that take a sample feature dictionary as input and output ssl_inputs and ssl_targets lists
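
A hedged sketch of a custom SSL transformation: it takes a sample feature dictionary and returns an (ssl input, ssl target) pair; the frame-masking scheme below is purely illustrative.

    import torch

    def masking_ssl(sample: dict):
        # illustrative transformation: hide 20% of the frames in the input
        # and keep the untouched features as the reconstruction target
        ssl_input = {k: v.clone() for k, v in sample.items()}
        ssl_target = {k: v.clone() for k, v in sample.items()}
        for value in ssl_input.values():
            mask = torch.rand(value.shape[-1]) < 0.2
            value[..., mask] = 0
        return ssl_input, ssl_target

    dataset.set_ssl_transformations([masking_ssl])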

@classmethod
def new(cls, *args, **kwargs)
834    @classmethod
835    def new(cls, *args, **kwargs):
836        """
837        Create a new object of the same class
838
839        Returns
840        -------
841        new_instance: BehaviorDataset
842            a new instance of the same class
843        """
844
845        return cls(*args, **kwargs)

Create a new object of the same class

Returns

new_instance: BehaviorDataset a new instance of the same class

@classmethod
def get_parameters(cls, data_type: str, annotation_type: str) -> List:
847    @classmethod
848    def get_parameters(cls, data_type: str, annotation_type: str) -> List:
849        """
850        Get parameters necessary for initialization
851
852        Parameters
853        ----------
854        data_type : str
855            the data type
856        annotation_type : str
857            the annotation type
858        """
859
860        input_features = options.input_stores[data_type].get_parameters()
861        annotation_features = options.annotation_stores[
862            annotation_type
863        ].get_parameters()
864        self_features = inspect.getfullargspec(cls.__init__).args
865        return self_features + input_features + annotation_features

Get parameters necessary for initialization

Parameters

data_type : str the data type annotation_type : str the annotation type

@staticmethod
def data_types() -> List:
867    @staticmethod
868    def data_types() -> List:
869        """
870        List available data types
871
872        Returns
873        -------
874        data_types : list
875            available data types
876        """
877
878        return list(options.input_stores.keys())

List available data types

Returns

data_types : list available data types

@staticmethod
def annotation_types() -> List:
880    @staticmethod
881    def annotation_types() -> List:
882        """
883        List available annotation types
884
885        Returns
886        -------
887        annotation_types : list
888            available annotation types
889        """
890
891        return list(options.annotation_stores.keys())

List available annotation types

Returns

annotation_types : list available annotation types
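
A hedged sketch of discovering what the installed configuration supports before constructing a dataset; the type names passed to `get_parameters` are hypothetical placeholders, not real registry keys.

    print(BehaviorDataset.data_types())         # available input data types
    print(BehaviorDataset.annotation_types())   # available annotation types
    # hypothetical type names, for illustration only:
    params = BehaviorDataset.get_parameters("some_data_type", "some_annotation_type")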

def set_indexing_parameters(self, unlabeled: bool, tag: int) -> None:
965    def set_indexing_parameters(self, unlabeled: bool, tag: int) -> None:
966        """
967        Set the parameters that change the subset that is returned at `__getitem__`
968
969        Parameters
970        ----------
971        unlabeled : bool
972            a pseudolabeling parameter; return only unlabeled samples if `True`, only labeled if `False` and
973            all if `None`
974        tag : int
975            if not `None`, only samples with this meta tag will be returned
976        """
977
978        if unlabeled != self.return_unlabeled:
979            self.annotation_indices = self.annotation_store.get_indices(unlabeled)
980            self.return_unlabeled = unlabeled
981        if tag != self.tag:
982            self.input_indices = self.input_store.get_indices(tag)
983            self.tag = tag
984        self.indices = [x for x in self.annotation_indices if x in self.input_indices]

Set the parameters that change the subset that is returned at __getitem__

Parameters

unlabeled : bool a pseudolabeling parameter; return only unlabeled samples if True, only labeled if False and all if None tag : int if not None, only samples with this meta tag will be returned
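
A hedged sketch of switching the dataset to unlabeled samples only, e.g. for a pseudolabeling pass; `dataset` is assumed to exist.

    dataset.set_indexing_parameters(unlabeled=True, tag=None)
    print(f"{len(dataset.indices)} samples will be returned by __getitem__")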

def get_intervals(self) -> Tuple[dict, Optional[list]]:
1410    def get_intervals(self) -> Tuple[dict, Optional[list]]:
1411        """
1412        Get a list of intervals covered by the dataset in the original coordinates
1413
1414        Returns
1415        -------
1416        intervals : dict
1417            a nested dictionary where first-level keys are video ids, second-level keys are clip ids and
1418            values are lists of the intervals in `[start, end]` format
1419        """
1420
1421        counter = defaultdict(lambda: {})
1422        coordinates = self.input_store.get_original_coordinates()
1423        for coords in coordinates:
1424            l = self.input_store.get_clip_length_from_coords(coords)
1425            video_name = self.input_store.get_video_id(coords)
1426            clip_id = self.input_store.get_clip_id(coords)
1427            start, end = self.input_store.get_clip_start_end(coords)
1428            if clip_id not in counter[video_name]:
1429                counter[video_name][clip_id] = np.zeros(l)
1430            counter[video_name][clip_id][start:end] = 1
1431        result = {video_name: {} for video_name in counter}
1432        for video_name in counter:
1433            for clip_id in counter[video_name]:
1434                result[video_name][clip_id] = self._get_intervals_from_ind(
1435                    np.where(counter[video_name][clip_id])[0]
1436                )
1437        return result, self.ids

Get a list of intervals covered by the dataset in the original coordinates

Returns

intervals : dict a nested dictionary where first-level keys are video ids, second-level keys are clip ids and values are lists of the intervals in [start, end] format

def get_unannotated_intervals(self, first_intervals=None) -> Dict:
1439    def get_unannotated_intervals(self, first_intervals=None) -> Dict:
1440        """
1441        Get a list of intervals in the original coordinates where there is no annotation
1442
1443        Returns
1444        -------
1445        intervals : dict
1446            a nested dictionary where first-level keys are video ids, second-level keys are clip ids and
1447            values are lists of the intervals in `[start, end]` format
1448        """
1449
1450        counter_value = 2
1451        if first_intervals is None:
1452            first_intervals = defaultdict(lambda: defaultdict(lambda: []))
1453            counter_value = 1
1454        counter = defaultdict(lambda: {})
1455        coordinates = self.input_store.get_original_coordinates()
1456        for i, coords in enumerate(coordinates):
1457            l = self.input_store.get_clip_length_from_coords(coords)
1458            ann = self.annotation_store[i]
1459            if (
1460                self.annotation_store.annotation_class()
1461                == "nonexclusive_classification"
1462            ):
1463                ann = ann[0, :]
1464            video_name = self.input_store.get_video_id(coords)
1465            clip_id = self.input_store.get_clip_id(coords)
1466            start, end = self.input_store.get_clip_start_end(coords)
1467            if clip_id not in counter[video_name]:
1468                counter[video_name][clip_id] = np.ones(l)
1469            counter[video_name][clip_id][start:end] = (ann[: end - start] == -100).int()
1470        result = {video_name: {} for video_name in counter}
1471        for video_name in counter:
1472            for clip_id in counter[video_name]:
1473                for start, end in first_intervals[video_name][clip_id]:
1474                    counter[video_name][clip_id][start:end] += 1
1475                result[video_name][clip_id] = self._get_intervals_from_ind(
1476                    np.where(counter[video_name][clip_id] == counter_value)[0]
1477                )
1478        return result

Get a list of intervals in the original coordinates where there is no annotation

Returns

intervals : dict a nested dictionary where first-level keys are video ids, second-level keys are clip ids and values are lists of the intervals in [start, end] format

def get_annotated_intervals(self) -> Dict:
1480    def get_annotated_intervals(self) -> Dict:
1481        """
1482        Get a list of intervals in the original coordinates where there is annotation
1483
1484        Returns
1485        -------
1486        intervals : dict
1487            a nested dictionary where first-level keys are video ids, second-level keys are clip ids and
1488            values are lists of the intervals in `[start, end]` format
1489        """
1490
1491        if self.annotation_type == "none":
1492            return []
1493        counter_value = 1
1494        counter = defaultdict(lambda: {})
1495        coordinates = self.input_store.get_original_coordinates()
1496        for i, coords in enumerate(coordinates):
1497            l = self.input_store.get_clip_length_from_coords(coords)
1498            ann = self.annotation_store[i]
1499            video_name = self.input_store.get_video_id(coords)
1500            clip_id = self.input_store.get_clip_id(coords)
1501            start, end = self.input_store.get_clip_start_end(coords)
1502            if clip_id not in counter[video_name]:
1503                counter[video_name][clip_id] = np.zeros(l)
1504            if (
1505                self.annotation_store.annotation_class()
1506                == "nonexclusive_classification"
1507            ):
1508                counter[video_name][clip_id][start:end] = (
1509                    torch.sum(ann[:, : end - start] != -100, dim=0) > 0
1510                ).int()
1511            else:
1512                counter[video_name][clip_id][start:end] = (
1513                    ann[: end - start] != -100
1514                ).int()
1515        result = {video_name: {} for video_name in counter}
1516        for video_name in counter:
1517            for clip_id in counter[video_name]:
1518                result[video_name][clip_id] = self._get_intervals_from_ind(
1519                    np.where(counter[video_name][clip_id] == counter_value)[0]
1520                )
1521        return result

Get a list of intervals in the original coordinates where there is annotation

Returns

intervals : dict a nested dictionary where first-level keys are video ids, second-level keys are clip ids and values are lists of the intervals in [start, end] format
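
A hedged sketch of comparing the overall coverage of the dataset with its annotated and unannotated parts; `dataset` is assumed to exist.

    covered, ids = dataset.get_intervals()
    annotated = dataset.get_annotated_intervals()
    unannotated = dataset.get_unannotated_intervals()
    for video_id, clips in covered.items():
        for clip_id, intervals in clips.items():
            print(video_id, clip_id, "covered:", intervals)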

def get_ids(self) -> Dict:
1523    def get_ids(self) -> Dict:
1524        """
1525        Get a dictionary of all clip ids in the dataset
1526
1527        Returns
1528        -------
1529        ids : dict
1530            a dictionary where keys are video ids and values are lists of clip ids
1531        """
1532
1533        coordinates = self.input_store.get_original_coordinates()
1534        video_ids = np.array(self.input_store.get_video_id_order())
1535        id_set = set(video_ids)
1536        result = {}
1537        for video_id in id_set:
1538            coords = coordinates[video_ids == video_id]
1539            clip_ids = list({self.input_store.get_clip_id(c) for c in coords})
1540            result[video_id] = clip_ids
1541        return result

Get a dictionary of all clip ids in the dataset

Returns

ids : dict a dictionary where keys are video ids and values are lists of clip ids

def get_len(self, video_id: str, clip_id: str) -> int:
1543    def get_len(self, video_id: str, clip_id: str) -> int:
1544        """
1545        Get the length of a specific clip
1546
1547        Parameters
1548        ----------
1549        video_id : str
1550            the video id
1551        clip_id : str
1552            the clip id
1553
1554        Returns
1555        -------
1556        length : int
1557            the length
1558        """
1559
1560        return self.input_store.get_clip_length(video_id, clip_id)

Get the length of a specific clip

Parameters

video_id : str the video id clip_id : str the clip id

Returns

length : int the length

def get_confusion_matrix( self, prediction: torch.Tensor, confusion_type: str = 'recall') -> Tuple[numpy.ndarray, list]:
1562    def get_confusion_matrix(
1563        self, prediction: torch.Tensor, confusion_type: str = "recall"
1564    ) -> Tuple[ndarray, list]:
1565        """
1566        Get a confusion matrix
1567
1568        Parameters
1569        ----------
1570        prediction : torch.Tensor
1571            a tensor of predicted class probabilities of shape `(#samples, #classes, #frames)`
1572        confusion_type : {"recall", "precision"}
1573            for datasets with non-exclusive annotation, if `confusion_type` is `"recall"`, only false positives are taken
1574            into account, and if `confusion_type` is `"precision"`, only false negatives
1575
1576        Returns
1577        -------
1578        confusion_matrix : np.ndarray
1579            a confusion matrix of shape `(#classes, #classes)` where `A[i, j] = F_ij/N_i`, `F_ij` is the number of
1580            frames that have the i-th label in the ground truth and a false positive j-th label in the prediction,
1581            `N_i` is the number of frames that have the i-th label in the ground truth
1582        classes : list
1583            a list of classes
1584        """
1585
1586        behaviors_dict = self.annotation_store.behaviors_dict()
1587        num_behaviors = len(behaviors_dict)
1588        confusion_matrix = np.zeros((num_behaviors, num_behaviors))
1589        if self.annotation_store.annotation_class() == "exclusive_classification":
1590            exclusive = True
1591            confusion_type = None
1592        elif self.annotation_store.annotation_class() == "nonexclusive_classification":
1593            exclusive = False
1594        else:
1595            raise RuntimeError(
1596                f"The {self.annotation_store.annotation_class()} annotation class is not recognized!"
1597            )
1598        for ann, p in zip(self.annotation_store, prediction):
1599            if exclusive:
1600                class_prediction = torch.max(p, dim=0)[1]
1601                for i in behaviors_dict.keys():
1602                    for j in behaviors_dict.keys():
1603                        confusion_matrix[i, j] += int(
1604                            torch.sum(class_prediction[ann == i] == j)
1605                        )
1606            else:
1607                class_prediction = (p > 0.5).int()
1608                for i in behaviors_dict.keys():
1609                    for j in behaviors_dict.keys():
1610                        if confusion_type == "recall":
1611                            pred = deepcopy(class_prediction[j])
1612                            if i != j:
1613                                pred[ann[j] == 1] = 0
1614                            confusion_matrix[i, j] += int(torch.sum(pred[ann[i] == 1]))
1615                        elif confusion_type == "precision":
1616                            annotation = deepcopy(ann[j])
1617                            if i != j:
1618                                annotation[class_prediction[j] == 1] = 0
1619                            confusion_matrix[i, j] += int(
1620                                torch.sum(annotation[class_prediction[i] == 1])
1621                            )
1622                        else:
1623                            raise ValueError(
1624                                f"The {confusion_type} type is not recognized; please choose from ['recall', 'precision']"
1625                            )
1626        counter = self.annotation_store.count_classes()
1627        for i in behaviors_dict.keys():
1628            if counter[i] != 0:
1629                if confusion_type == "recall" or confusion_type is None:
1630                    confusion_matrix[i, :] /= counter[i]
1631                else:
1632                    confusion_matrix[:, i] /= counter[i]
1633        return confusion_matrix, list(behaviors_dict.values()), confusion_type

Get a confusion matrix

Parameters

prediction : torch.Tensor a tensor of predicted class probabilities of shape (#samples, #classes, #frames) confusion_type : {"recall", "precision"} for datasets with non-exclusive annotation, if confusion_type is "recall", only false positives are taken into account, and if confusion_type is "precision", only false negatives

Returns

confusion_matrix : np.ndarray a confusion matrix of shape (#classes, #classes) where A[i, j] = F_ij/N_i, F_ij is the number of frames that have the i-th label in the ground truth and a false positive j-th label in the prediction, N_i is the number of frames that have the i-th label in the ground truth classes : list a list of classes confusion_type : str | None the confusion type that was applied ("recall", "precision", or None for exclusive classification)
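
A hedged sketch, assuming `prediction` is a tensor of class probabilities of shape (#samples, #classes, #frames) aligned with the dataset order; as the source shows, the confusion type that was actually applied is returned as a third element.

    matrix, class_names, used_type = dataset.get_confusion_matrix(
        prediction, confusion_type="recall"
    )
    for name, row in zip(class_names, matrix):
        print(name, [round(float(x), 2) for x in row])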

def annotation_class(self) -> str:
1635    def annotation_class(self) -> str:
1636        """
1637        Get the type of annotation ('exclusive_classification', 'nonexclusive_classification', more coming soon)
1638
1639        Returns
1640        -------
1641        annotation_class : str
1642            the type of annotation
1643        """
1644
1645        return self.annotation_store.annotation_class()

Get the type of annotation ('exclusive_classification', 'nonexclusive_classification', more coming soon)

Returns

annotation_class : str the type of annotation

def set_normalization_stats(self, stats: Dict) -> None:
1647    def set_normalization_stats(self, stats: Dict) -> None:
1648        """
1649        Set the stats to normalize data at runtime
1650
1651        Parameters
1652        ----------
1653        stats : dict
1654            a nested dictionary where first-level keys are feature key names, second-level keys are 'mean' and 'std'
1655            and values are the statistics in `torch` tensors of shape `(#features, 1)`
1656        """
1657
1658        self.stats = stats

Set the stats to normalize data at runtime

Parameters

stats : dict a nested dictionary where first-level keys are feature key names, second-level keys are 'mean' and 'std' and values are the statistics in torch tensors of shape (#features, 1)

def get_min_max_frames(self, video_id) -> Tuple[Dict, Dict]:
1660    def get_min_max_frames(self, video_id) -> Tuple[Dict, Dict]:
1661        coords = self.input_store.get_original_coordinates()
1662        clips = set(
1663            [
1664                self.input_store.get_clip_id(c)
1665                for c in coords
1666                if self.input_store.get_video_id(c) == video_id
1667            ]
1668        )
1669        min_frames = {}
1670        max_frames = {}
1671        for clip in clips:
1672            start = self.input_store.get_clip_start(video_id, clip)
1673            end = start + self.input_store.get_clip_length(video_id, clip)
1674            min_frames[clip] = start
1675            max_frames[clip] = end - 1
1676        return min_frames, max_frames
def get_normalization_stats(self, skip_keys=None) -> Dict:
1678    def get_normalization_stats(self, skip_keys=None) -> Dict:
1679        """
1680        Get mean and standard deviation for each key
1681
1682        Returns
1683        -------
1684        stats : dict
1685            a nested dictionary where first-level keys are feature key names, second-level keys are 'mean' and 'std'
1686            and values are the statistics in `torch` tensors of shape `(#features, 1)`
1687        """
1688
1689        stats = defaultdict(lambda: {})
1690        sums = defaultdict(lambda: 0)
1691        if skip_keys is None:
1692            skip_keys = []
1693        counter = defaultdict(lambda: 0)
1694        for sample in tqdm(self):
1695            for key, value in sample["input"].items():
1696                key_name = key.split("---")[0]
1697                if key_name not in skip_keys:
1698                    sums[key_name] += value[:, value.sum(0) != 0].sum(-1)
1699                counter[key_name] += torch.sum(value.sum(0) != 0)
1700        for key, value in sums.items():
1701            stats[key]["mean"] = (value / counter[key]).unsqueeze(-1)
1702        sums = defaultdict(lambda: 0)
1703        for sample in tqdm(self):
1704            for key, value in sample["input"].items():
1705                key_name = key.split("---")[0]
1706                if key_name not in skip_keys:
1707                    sums[key_name] += (
1708                        (value[:, value.sum(0) != 0] - stats[key_name]["mean"]) ** 2
1709                    ).sum(-1)
1710        for key, value in sums.items():
1711            stats[key]["std"] = np.sqrt(value.unsqueeze(-1) / counter[key])
1712        return stats

Get mean and standard deviation for each key

Returns

stats : dict a nested dictionary where first-level keys are feature key names, second-level keys are 'mean' and 'std' and values are the statistics in torch tensors of shape (#features, 1)
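
A hedged sketch of computing the statistics once on the training subset and reusing them on the other subsets, mirroring what `partition_train_test_val` does when `normalize=True`; the skipped key name is an illustrative placeholder.

    stats = train_ds.get_normalization_stats(skip_keys=["some_feature_key"])
    val_ds.set_normalization_stats(stats)
    test_ds.set_normalization_stats(stats)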