dlc2action.data.input_store

Specific realisations of dlc2action.data.base_store.InputStore are defined here.

#
# Copyright 2020-present by A. Mathis Group and contributors. All rights reserved.
#
# This project and all its files are licensed under GNU AGPLv3 or later version.
# A copy is included in dlc2action/LICENSE.AGPL.
#
"""Specific realisations of `dlc2action.data.base_store.InputStore` are defined here."""

import mimetypes
import os
import pickle
from abc import abstractmethod
from collections import defaultdict
from typing import Dict, Iterable, List, Optional, Set, Tuple, Union

import numpy as np
import pandas as pd
import torch
from p_tqdm import p_map
from tqdm import tqdm

from dlc2action import options
from dlc2action.data.base_store import PoseInputStore
from dlc2action.utils import TensorDict, strip_prefix, strip_suffix



class GeneralInputStore(PoseInputStore):
    """A generalized realization of a `PoseInputStore`.

    Assumes the following file structure:
    ```
    data_path
    ├── video1DLC1000.pickle
    ├── video2DLC400.pickle
    ├── video1_features.pt
    └── video2_features.pt
    ```
    Here `data_suffix` is `{'DLC1000.pickle', 'DLC400.pickle'}` and `feature_suffix` (optional) is `'_features.pt'`.
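
    With this layout the video ids are `'video1'` and `'video2'`. A minimal
    sketch of listing them (assuming a concrete subclass of this store, here
    hypothetically named `MyInputStore`):
    ```
    ids = MyInputStore.get_file_ids(
        data_path="data_path",
        data_suffix={"DLC1000.pickle", "DLC400.pickle"},
        feature_suffix="_features.pt",
    )
    # ids == ["video1", "video2"]
    ```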
    """

    data_suffix = None

    def __init__(
        self,
        video_order: List = None,
        data_path: Union[Set, str] = None,
        file_paths: Set = None,
        data_suffix: Union[Set, str] = None,
        data_prefix: Union[Set, str] = None,
        feature_suffix: str = None,
        convert_int_indices: bool = True,
        feature_save_path: str = None,
        canvas_shape: List = None,
        len_segment: int = 128,
        overlap: int = 0,
        feature_extraction: str = "kinematic",
        ignored_clips: List = None,
        ignored_bodyparts: List = None,
        default_agent_name: str = "ind0",
        key_objects: Tuple = None,
        likelihood_threshold: float = 0,
        num_cpus: int = None,
        frame_limit: int = 1,
        normalize: bool = False,
        feature_extraction_pars: Dict = None,
        centered: bool = False,
        transpose_features: bool = False,
        *args,
        **kwargs,
    ) -> None:
        """Initialize a store.

        Parameters
        ----------
        video_order : list, optional
            a list of video ids that should be processed in the same order (not passed if creating from key objects)
        data_path : str | set, optional
            the path to the folder where the pose and feature files are stored or a set of such paths
            (not passed if creating from key objects or from `file_paths`)
        file_paths : set, optional
            a set of string paths to the pose and feature files
            (not passed if creating from key objects or from `data_path`)
        data_suffix : str | set, optional
            the suffix or the set of suffixes such that the pose files are named {video_id}{data_suffix}
            (not passed if creating from key objects or if irrelevant for the dataset)
        data_prefix : str | set, optional
            the prefix or the set of prefixes such that the pose files for different video views of the same
            clip are named {prefix}{sep}{video_id}{data_suffix} (not passed if creating from key objects
            or if irrelevant for the dataset)
        feature_suffix : str | set, optional
            the suffix or the set of suffixes such that the additional feature files are named
            {video_id}{feature_suffix} (and placed at the data_path folder)
        convert_int_indices : bool, default True
            if `True`, convert any integer key `i` in feature files to `'ind{i}'`
        feature_save_path : str, optional
            the path to the folder where pre-processed files are stored (not passed if creating from key objects)
        canvas_shape : List, default [1, 1]
            the canvas size where the pose is defined
        len_segment : int, default 128
            the length of the segments in which the data should be cut (in frames)
        overlap : int, default 0
            the length of the overlap between neighboring segments (in frames)
        feature_extraction : str, default 'kinematic'
            the feature extraction method (see options.feature_extractors for available options)
        ignored_clips : list, optional
            list of strings of clip ids to ignore
        ignored_bodyparts : list, optional
            list of strings of bodypart names to ignore
        default_agent_name : str, default 'ind0'
            the agent name used as default in the pose files for a single agent
        key_objects : tuple, optional
            a tuple of key objects
        likelihood_threshold : float, default 0
            coordinate values with likelihoods less than this value will be set to 'unknown'
        num_cpus : int, optional
            the number of cpus to use in data processing
        frame_limit : int, default 1
            clips shorter than this number of frames will be ignored
        normalize : bool, default False
            whether to normalize the pose
        feature_extraction_pars : dict, optional
            parameters of the feature extractor
        centered : bool, default False
            whether the pose is centered
        transpose_features : bool, default False
            whether to transpose the features
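
        Note that `overlap` values smaller than 1 are interpreted as a fraction
        of `len_segment`: with `len_segment=128`, passing `overlap=0.5` is
        equivalent to passing `overlap=64`.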

        """
        super().__init__()
        self.loaded_max = 0
        if feature_extraction_pars is None:
            feature_extraction_pars = {}
        if ignored_clips is None:
            ignored_clips = []
        self.bodyparts = []
        self.visibility = None
        self.normalize = normalize

        if canvas_shape is None:
            canvas_shape = [1, 1]
        if isinstance(data_suffix, str):
            data_suffix = [data_suffix]
        if isinstance(data_prefix, str):
            data_prefix = [data_prefix]
        if isinstance(data_path, str):
            data_path = [data_path]
        if isinstance(feature_suffix, str):
            feature_suffix = [feature_suffix]

        self.video_order = video_order
        self.centered = centered
        self.feature_extraction = feature_extraction
        self.len_segment = int(len_segment)
        self.data_suffices = data_suffix
        self.data_prefixes = data_prefix
        self.feature_suffix = feature_suffix
        self.convert_int_indices = convert_int_indices
        if isinstance(overlap, str):
            overlap = float(overlap)
        if overlap < 1:
            overlap = overlap * self.len_segment
        self.overlap = int(overlap)
        self.canvas_shape = canvas_shape
        self.default_agent_name = default_agent_name
        self.feature_save_path = feature_save_path
        self.likelihood_threshold = likelihood_threshold
        self.num_cpus = num_cpus
        self.frame_limit = frame_limit
        self.transpose = transpose_features

        self.ram = False
        self.min_frames = {}
        self.original_coordinates = np.array([])

        self.file_paths = self._get_file_paths(file_paths, data_path)

        self.extractor = options.feature_extractors[self.feature_extraction](
            self,
            **feature_extraction_pars,
        )

        self.canvas_center = np.array(canvas_shape) // 2

        self.ignored_clips = ignored_clips
        if ignored_bodyparts is not None:
            self.ignored_bodyparts = ignored_bodyparts
        else:
            self.ignored_bodyparts = []

        self.step = self.len_segment - self.overlap
        if self.step < 0:
            raise ValueError(
                f"The overlap value ({self.overlap}) cannot be larger than len_segment ({self.len_segment})"
            )

        if self.feature_save_path is None and data_path is not None:
            self.feature_save_path = os.path.join(data_path[0], "trimmed")

        if key_objects is None and self.video_order is not None:
            print("Computing input features...")
            self.data = self._load_data()
        elif key_objects is not None:
            self.load_from_key_objects(key_objects)

    def __getitem__(self, ind: int) -> Dict:
        """Get a single item from the dataset."""
        prompt = self.data[ind]
        if not self.ram:
            with open(prompt, "rb") as f:
                prompt = pickle.load(f)
        return prompt

    def __len__(self) -> int:
        """Get the length of the dataset."""
        if self.data is None:
            raise RuntimeError("The input store data has not been initialized!")
        return len(self.data)

    @classmethod
    def _get_file_paths(cls, file_paths: Set, data_path: Union[str, Set]) -> List:
        """Get a set of relevant files.

        Parameters
        ----------
        file_paths : set
            a set of filepaths to include
        data_path : str | set
            the path to a folder that contains relevant files (a single path or a set)

        Returns
        -------
        file_paths : list
            a list of relevant file paths (input and feature files that follow the dataset naming pattern)

        """
        if file_paths is None:
            file_paths = []
        file_paths = list(file_paths)
        if data_path is not None:
            if isinstance(data_path, str):
                data_path = [data_path]
            for folder in data_path:
                file_paths += [os.path.join(folder, x) for x in os.listdir(folder)]
        return file_paths

    def get_folder(self, video_id: str) -> str:
        """Get the input folder that the file with this video id was read from.

        Parameters
        ----------
        video_id : str
            the video id

        Returns
        -------
        folder : str
            the path to the directory that contains the input file associated with the video id

        """
        for file in self.file_paths:
            if (
                strip_prefix(
                    strip_suffix(os.path.basename(file), self.data_suffices),
                    self.data_prefixes,
                )
                == video_id
            ):
                return os.path.dirname(file)

    def remove(self, indices: List) -> None:
        """Remove the samples corresponding to indices.

        Parameters
        ----------
        indices : list
            a list of integer indices to remove

        """
        if len(indices) > 0:
            mask = np.ones(len(self.original_coordinates))
            mask[indices] = 0
            mask = mask.astype(bool)
            for file in self.data[~mask]:
                os.remove(file)
            self.original_coordinates = self.original_coordinates[mask]
            self.data = self.data[mask]
            if self.metadata is not None:
                self.metadata = self.metadata[mask]

    def key_objects(self) -> Tuple:
        """Return a tuple of the key objects necessary to re-create the Store.

        Returns
        -------
        key_objects : tuple
            a tuple of key objects
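
        Together with `load_from_key_objects`, this enables a save and restore
        round trip (a sketch, assuming `store` is an initialized instance):

        >>> key_objects = store.key_objects()
        >>> restored = type(store)(key_objects=key_objects)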

        """
        for k, v in self.min_frames.items():
            self.min_frames[k] = dict(v)
        for k, v in self.max_frames.items():
            self.max_frames[k] = dict(v)
        return (
            self.original_coordinates,
            dict(self.min_frames),
            dict(self.max_frames),
            self.data,
            self.visibility,
            self.step,
            self.file_paths,
            self.len_segment,
            self.metadata,
        )

    def load_from_key_objects(self, key_objects: Tuple) -> None:
        """Load the information from a tuple of key objects.

        Parameters
        ----------
        key_objects : tuple
            a tuple of key objects

        """
        (
            self.original_coordinates,
            self.min_frames,
            self.max_frames,
            self.data,
            self.visibility,
            self.step,
            self.file_paths,
            self.len_segment,
            self.metadata,
        ) = key_objects

    def to_ram(self) -> None:
        """Transfer the data samples to RAM if they were previously stored as file paths."""
        if self.ram:
            return

        if os.name != "nt":
            data = p_map(
                lambda x: self[x], list(range(len(self))), num_cpus=self.num_cpus
            )
        else:
            print(
                "Multiprocessing is not supported on Windows, loading files sequentially."
            )
            data = [self[i] for i in tqdm(range(len(self)))]
        self.data = TensorDict(data)
        self.ram = True

    def get_original_coordinates(self) -> np.ndarray:
        """Return the original coordinates array.

        Returns
        -------
        np.ndarray
            an array that contains the coordinates of the data samples in original input data (video id, clip id,
            start frame)
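
        Each element is a pair of `'{video_id}---{clip_id}'` and the segment
        index. A sketch of decoding one element (assuming `store` is an
        initialized instance):

        >>> coords = store.get_original_coordinates()[0]
        >>> video_id = store.get_video_id(coords)
        >>> clip_id = store.get_clip_id(coords)
        >>> start, end = store.get_clip_start_end(coords)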

        """
        return self.original_coordinates

    def create_subsample(self, indices: List, ssl_indices: List = None):
        """Create a new store that contains a subsample of the data.

        Parameters
        ----------
        indices : list
            the indices to be included in the subsample
        ssl_indices : list, optional
            the indices to be included in the subsample without the annotation data

        """
        if ssl_indices is None:
            ssl_indices = []
        new = self.new()
        new.original_coordinates = self.original_coordinates[indices + ssl_indices]
        new.min_frames = self.min_frames
        new.max_frames = self.max_frames
        new.data = self.data[indices + ssl_indices]
        new.visibility = self.visibility
        new.step = self.step
        new.file_paths = self.file_paths
        new.len_segment = self.len_segment
        if self.metadata is None:
            new.metadata = None
        else:
            new.metadata = self.metadata[indices + ssl_indices]
        return new

    def get_video_id(self, coords: Tuple) -> str:
        """Get the video id from an element of original coordinates.

        Parameters
        ----------
        coords : tuple
            an element of the original coordinates array

        Returns
        -------
        video_id: str
            the id of the video that the coordinates point to

        """
        video_name = coords[0].split("---")[0]
        return video_name

    def get_clip_id(self, coords: Tuple) -> str:
        """Get the clip id from an element of original coordinates.

        Parameters
        ----------
        coords : tuple
            an element of the original coordinates array

        Returns
        -------
        clip_id : str
            the id of the clip that the coordinates point to

        """
        clip_id = coords[0].split("---")[1]
        return clip_id

    def get_clip_length(self, video_id: str, clip_id: str) -> int:
        """Get the clip length from the id.

        For composite clip ids (parts joined with `+`), the length of the
        frame range shared by all the parts is returned.

        Parameters
        ----------
        video_id : str
            the video id
        clip_id : str
            the clip id

        Returns
        -------
        clip_length : int
            the length of the clip

        """
        inds = clip_id.split("+")
        max_frame = min([self.max_frames[video_id][x] for x in inds])
        min_frame = max([self.min_frames[video_id][x] for x in inds])
        return max_frame - min_frame + 1

    def get_clip_start_end(self, coords: Tuple) -> Tuple[int, int]:
        """Get the clip start and end frames from an element of original coordinates.

        Parameters
        ----------
        coords : tuple
            an element of original coordinates array

        Returns
        -------
        start : int
            the start frame of the clip that the coordinates point to
        end : int
            the end frame of the clip that the coordinates point to

        """
        length = self.get_clip_length_from_coords(coords)
        i = coords[1]
        start = int(i) * self.step
        end = min(start + self.len_segment, length)
        return start, end

    def get_clip_start(self, video_name: str, clip_id: str) -> int:
        """Get the clip start frame from the video id and the clip id.

        Parameters
        ----------
        video_name : str
            the video id
        clip_id : str
            the clip id

        Returns
        -------
        clip_start : int
            the start frame of the clip

        """
        return max(
            [self.min_frames[video_name][clip_id_k] for clip_id_k in clip_id.split("+")]
        )

    def get_visibility(
        self, video_id: str, clip_id: str, start: int, end: int, score: float
    ) -> float:
        """Get the fraction of the frames that have a visibility score above a given threshold.

        For example, in the case of keypoint data the visibility score can be the number of identified keypoints.

        Parameters
        ----------
        video_id : str
            the video id of the frames
        clip_id : str
            the clip id of the frames
        start : int
            the start frame
        end : int
            the end frame
        score : float
            the visibility score threshold

        Returns
        -------
        frac_visible : float
            the fraction of frames with visibility above the threshold
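
        For example, `store.get_visibility("video1", "ind0", 0, 128, 0.5)` would
        return the fraction of the first 128 frames in which more than half of
        the body parts are visible (a hypothetical call, assuming `store` is an
        initialized instance).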

        """
        s = 0
        for ind_k in clip_id.split("+"):
            s += np.sum(self.visibility[video_id][ind_k][start:end] > score) / (
                end - start
            )
        return s / len(clip_id.split("+"))

    def get_annotation_objects(self) -> Dict:
        """Get a dictionary of objects necessary to create a `BehaviorStore`.

        Returns
        -------
        annotation_objects : dict
            a dictionary of objects to be passed to the BehaviorStore constructor where the keys are the names of
            the objects

        """
        min_frames = self.min_frames
        max_frames = self.max_frames
        num_bp = self.visibility
        return {
            "min_frames": min_frames,
            "max_frames": max_frames,
            "visibility": num_bp,
        }

    @classmethod
    def get_file_ids(
        cls,
        data_suffix: Union[Set, str] = None,
        data_path: Union[Set, str] = None,
        data_prefix: Union[Set, str] = None,
        file_paths: Set = None,
        feature_suffix: Set = None,
        *args,
        **kwargs,
    ) -> List:
        """Get file ids.

        Process data parameters and return a list of ids of the videos that should
        be processed by the `__init__` function.

        Parameters
        ----------
        data_suffix : set | str, optional
            the suffix (or a set of suffixes) of the input data files
        data_path : set | str, optional
            the path to the folder where the pose and feature files are stored or a set of such paths
            (not passed if creating from key objects or from `file_paths`)
        data_prefix : set | str, optional
            the prefix or the set of prefixes such that the pose files for different video views of the same
            clip are named {prefix}{sep}{video_id}{data_suffix} (not passed if creating from key objects
            or if irrelevant for the dataset)
        file_paths : set, optional
            a set of string paths to the pose and feature files
        feature_suffix : str | set, optional
            the suffix or the set of suffixes such that the additional feature files are named
            {video_id}{feature_suffix} (and placed at the `data_path` folder or at `file_paths`)

        Returns
        -------
        video_ids : list
            a list of video file ids

        """
        if data_suffix is None:
            if cls.data_suffix is not None:
                data_suffix = cls.data_suffix
            else:
                raise ValueError("Cannot get video ids without the data suffix!")
        if feature_suffix is None:
            feature_suffix = []
        if data_prefix is None:
            data_prefix = ""
        if isinstance(data_suffix, str):
            data_suffix = [data_suffix]
        data_suffix = tuple(data_suffix)
        if not isinstance(data_prefix, str):
            data_prefix = tuple(data_prefix)
        if isinstance(feature_suffix, str):
            feature_suffix = [feature_suffix]
        if file_paths is None:
            file_paths = []
        if data_path is not None:
            if isinstance(data_path, str):
                data_path = [data_path]
            file_paths = []
            for folder in data_path:
                file_paths += [os.path.join(folder, x) for x in os.listdir(folder)]
        basenames = [os.path.basename(f) for f in file_paths]
        ids = set()
        for f in file_paths:
            if f.endswith(data_suffix) and os.path.basename(f).startswith(data_prefix):
                bn = os.path.basename(f)
                video_id = strip_prefix(strip_suffix(bn, data_suffix), data_prefix)
                if all([video_id + s in basenames for s in feature_suffix]):
                    ids.add(video_id)
        ids = sorted(ids)
        return ids

    def get_bodyparts(self) -> List:
        """Get a list of bodypart names.

        Returns
        -------
        bodyparts : list
            a list of string or integer body part names

        """
        return [x for x in self.bodyparts if x not in self.ignored_bodyparts]

    def get_coords(self, data_dict: Dict, clip_id: str, bodypart: str) -> np.ndarray:
        """Get the coordinates array of a specific bodypart in a specific clip.

        Parameters
        ----------
        data_dict : dict
            the data dictionary (passed to feature extractor)
        clip_id : str
            the clip id
        bodypart : str
            the name of the body part

        Returns
        -------
        coords : np.ndarray
            the coordinates array of shape (#timesteps, #coordinates)

        """
        columns = [x for x in data_dict[clip_id].columns if x != "likelihood"]
        xy_coord = (
            data_dict[clip_id]
            .xs(bodypart, axis=0, level=1, drop_level=False)[columns]
            .values
        )
        return xy_coord

    def get_n_frames(self, data_dict: Dict, clip_id: str) -> int:
        """Get the length of the clip.

        Parameters
        ----------
        data_dict : dict
            the data dictionary (passed to feature extractor)
        clip_id : str
            the clip id

        Returns
        -------
        n_frames : int
            the length of the clip

        """
        if clip_id in data_dict:
            return len(data_dict[clip_id].groupby(level=0))
        else:
            return min(
                [len(data_dict[ind_k].groupby(level=0)) for ind_k in clip_id.split("+")]
            )

    def _filter(self, data_dict: Dict) -> Tuple[Dict, Dict, Dict]:
        """Apply filters to a data dictionary + normalize the values and generate frame index dictionaries.

        The filters include filling nan values, applying length and likelihood thresholds and removing
        ignored clip ids.

        """
        new_data_dict = {}
        keys = list(data_dict.keys())
        for key in keys:
            if key == "loaded":
                continue
            coord = data_dict.pop(key)
            if key in self.ignored_clips:
                continue
            num_frames = len(coord.index.unique(level=0))
            if num_frames < self.frame_limit:
                continue
            if "likelihood" in coord.columns:
                columns = list(coord.columns)
                columns.remove("likelihood")
                coord.loc[coord["likelihood"] < self.likelihood_threshold, columns] = (
                    np.nan
                )
            if not isinstance(self.centered, Iterable):
                self.centered = [
                    bool(self.centered)
                    for dim in ["x", "y", "z"]
                    if dim in coord.columns
                ]
            for i, dim in enumerate(["x", "y", "z"]):
                if dim in coord.columns:
                    if self.centered[i]:
                        coord[dim] = coord[dim] + self.canvas_shape[i] // 2
                    # coord.loc[coord[dim] < -self.canvas_shape[i] * 3 // 2, dim] = np.nan
                    # coord.loc[coord[dim] > self.canvas_shape[i] * 3 // 2, dim] = np.nan
            coord = coord.sort_index(level=0)
            for bp in coord.index.unique(level=1):
                coord.loc[coord.index.isin([bp], level=1)] = coord[
                    coord.index.isin([bp], level=1)
                ].interpolate()
            dims = [x for x in coord.columns if x != "likelihood"]
            mask = ~coord[dims[0]].isna()
            for dim in dims[1:]:
                mask = mask & (~coord[dim].isna())
            mean = coord.loc[mask].groupby(level=0).mean()
            for frame in set(coord.index.get_level_values(0)):
                if frame not in mean.index:
                    mean.loc[frame] = [np.nan for _ in mean.columns]
            mean = mean.interpolate()
            mean[mean.isna()] = 0
            for dim in coord.columns:
                if dim == "likelihood":
                    continue
                coord.loc[coord[dim].isna(), dim] = mean.loc[
                    coord.loc[coord[dim].isna()].index.get_level_values(0)
                ][dim].to_numpy()
            if np.sum(self.canvas_shape) > 0:
                for i, dim in enumerate(["x", "y", "z"]):
                    if dim in coord.columns:
                        coord[dim] = (
                            coord[dim] - self.canvas_shape[i] // 2
                        ) / self.canvas_shape[0]
            new_data_dict[key] = coord
        max_frames = {}
        min_frames = {}
        for key, value in new_data_dict.items():
            max_frames[key] = max(value.index.unique(0))
            min_frames[key] = min(value.index.unique(0))
        if "loaded" in data_dict:
            new_data_dict["loaded"] = data_dict["loaded"]
        return new_data_dict, min_frames, max_frames

    def _get_files_from_ids(self):
        """Group the pose files by video id, using each data prefix at most once."""
        files = defaultdict(lambda: [])
        used_prefixes = defaultdict(lambda: [])
        for f in self.file_paths:
            if f.endswith(tuple(self.data_suffices)):
                bn = os.path.basename(f)
                video_id = strip_prefix(
                    strip_suffix(bn, self.data_suffices), self.data_prefixes
                )
                ok = True
                if self.data_prefixes is not None:
                    for p in self.data_prefixes:
                        if bn.startswith(p):
                            if p not in used_prefixes[video_id]:
                                used_prefixes[video_id].append(p)
                            else:
                                ok = False
                            break
                if not ok:
                    continue
                files[video_id].append(f)
        files = [files[x] for x in self.video_order]
        return files

    def _make_trimmed_data(self, keypoint_dict: Dict) -> Tuple[List, Dict, List]:
        """Cut a keypoint dictionary into overlapping pieces of equal length."""
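        # For example, with `len_segment=128` and `overlap=0` (so `step=128`),
        # a 300-frame clip is cut into segments that start at frames 0, 128 and
        # 256, and the last segment is zero-padded to the full length.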
        X = []
        original_coordinates = []
        lengths = defaultdict(lambda: {})
        os.makedirs(self.feature_save_path, exist_ok=True)
        order = sorted(list(keypoint_dict.keys()))
        for v_id in order:
            keypoints = keypoint_dict[v_id]
            v_len = min([len(x) for x in keypoints.values()])
            sp = np.arange(0, v_len, self.step)
            pad = sp[-1] + self.len_segment - v_len
            video_id, clip_id = v_id.split("---")
            for key in keypoints:
                if len(keypoints[key]) > v_len:
                    keypoints[key] = keypoints[key][:v_len]
                if len(keypoints[key].shape) == 2:
                    keypoints[key] = np.pad(keypoints[key], ((0, pad), (0, 0)))
                else:
                    keypoints[key] = np.pad(
                        keypoints[key], ((0, pad), (0, 0), (0, 0), (0, 0))
                    )
            for i, start in enumerate(sp):
                sample_dict = {}
                original_coordinates.append((v_id, i))
                for key in keypoints:
                    sample_dict[key] = keypoints[key][start : start + self.len_segment]
                    arr = np.asarray(sample_dict[key], dtype=np.float32)
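                    # move the time axis to the last position, so that each
                    # value is stored as a (..., len_segment)-shaped tensor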
                    tensor = (
                        torch.from_numpy(arr)
                        .permute(*range(1, arr.ndim), 0)
                        .contiguous()
                    )
                    sample_dict[key] = tensor

                name = os.path.join(self.feature_save_path, f"{v_id}_{start}.pickle")
                X.append(name)
                lengths[video_id][clip_id] = v_len
                with open(name, "wb") as f:
                    pickle.dump(sample_dict, f)
        return X, dict(lengths), original_coordinates

    def _load_saved_features(self, video_id: str):
        """Load a saved feature file with arrays of shape `(#frames, #features)`."""
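        # Each feature file is expected to hold a dictionary that maps clip ids
        # to arrays or tensors of shape (#frames, #features). A sketch of saving
        # a compatible file (hypothetical names and sizes):
        #
        #     torch.save({"ind0": torch.randn(900, 32)}, "video1_features.pt")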
        basenames = [os.path.basename(x) for x in self.file_paths]
        loaded_features_cat = []
        self.feature_suffix = sorted(self.feature_suffix)
        for feature_suffix in self.feature_suffix:
            i = basenames.index(os.path.basename(video_id) + feature_suffix)
            path = self.file_paths[i]
            if not os.path.exists(path):
                raise RuntimeError(f"Did not find a feature file for {video_id}!")
            extension = feature_suffix.split(".")[-1]
            if extension in ["pickle", "pkl"]:
                with open(path, "rb") as f:
                    loaded_features = pickle.load(f)
            elif extension in ["pt", "pth"]:
                loaded_features = torch.load(path)
            elif extension == "npy":
                try:
                    loaded_features = np.load(path, allow_pickle=True).item()
                except ValueError:
                    loaded_features = np.load(path, allow_pickle=True)
                    loaded_features = {
                        "features": loaded_features,
                        "min_frames": {video_id: 0},
                        "max_frames": {video_id: len(loaded_features)},
                        "video_tag": video_id,
                    }
            else:
                raise ValueError(
                    f"Found feature file in an unrecognized format: .{extension}. \n "
                    "Please save with torch (as .pt or .pth), numpy (as .npy) or pickle (as .pickle or .pkl)."
                )
            loaded_features_cat.append(loaded_features)
        keys = list(loaded_features_cat[0].keys())
        loaded_features = {}
        for k in keys:
            if k in ["min_frames", "max_frames", "video_tag"]:
                loaded_features[k] = loaded_features_cat[0][k]
                continue
            features = []
            for x in loaded_features_cat:
                if not isinstance(x[k], torch.Tensor):
                    features.append(torch.from_numpy(x[k]))
                else:
                    features.append(x[k])
            a = torch.cat(features)
            if self.transpose:
                a = a.T
            loaded_features[k] = a
        return loaded_features

    def get_likelihood(
        self, data_dict: Dict, clip_id: str, bodypart: str
    ) -> Union[np.ndarray, None]:
        """Get the likelihood values.

        Parameters
        ----------
        data_dict : dict
            the data dictionary
        clip_id : str
            the clip id
        bodypart : str
            the name of the body part

        Returns
        -------
        likelihoods : np.ndarray | None
            `None` if the dataset doesn't have likelihoods or an array of shape (#timesteps)

        """
        if "likelihood" in data_dict[clip_id].columns:
            likelihood = (
                data_dict[clip_id]
                .xs(bodypart, axis=0, level=1, drop_level=False)
                .values[:, -1]
            )
            return likelihood
        else:
            return None

    def _get_video_metadata(self, metadata_list: Optional[List]):
        """Make a single metadata dictionary from a list of dictionaries received from different data prefixes."""
        if metadata_list is None:
            return None
        else:
            return metadata_list[0]

    def get_indices(self, tag: int) -> List:
        """Get a list of indices of samples that have a specific meta tag.

        Parameters
        ----------
        tag : int
            the meta tag for the subsample (`None` for the whole dataset)

        Returns
        -------
        indices : list
            a list of indices that meet the criteria

        """
        if tag is None:
            return list(range(len(self.data)))
        else:
            return list(np.where(self.metadata == tag)[0])

    def get_tags(self) -> List:
        """Get a list of all meta tags.

        Returns
        -------
        tags : list
            a list of unique meta tag values

        """
        if self.metadata is None:
            return [None]
        else:
            return list(np.unique(self.metadata))

    def get_tag(self, idx: int) -> Union[int, None]:
        """Return a tag object corresponding to an index.

        Tags can carry meta information (like annotator id) and are accepted by models that require
        that information. When a tag is `None`, it is not passed to the model.

        Parameters
        ----------
        idx : int
            the index

        Returns
        -------
        tag : int
            the tag object

        """
        if self.metadata is None or idx is None:
            return None
        else:
            return self.metadata[idx]

    @abstractmethod
    def _load_data(self) -> np.ndarray:
        """Load input data and generate data prompts."""


class FileInputStore(GeneralInputStore):
    """An implementation of `dlc2action.data.InputStore` for datasets where each input data file corresponds to one video."""

    def _count_bodyparts(
        self, data: Dict, stripped_name: str, max_frames: Dict
    ) -> Dict:
        """Create a visibility score dictionary (with a score from 0 to 1 assigned to each frame of each clip)."""
        result = {stripped_name: {}}
        prefixes = list(data.keys())
        for ind in data[prefixes[0]]:
            res = 0
            for _, data_dict in data.items():
                num_bp = len(data_dict[ind].index.unique(level=1))
                coords = (
                    data_dict[ind].values.reshape(
                        -1, num_bp, len(data_dict[ind].columns)
                    )[: max_frames[ind], :, 0]
                    != 0
                )
                res = np.sum(coords, axis=1) + res
            result[stripped_name][ind] = (res / len(prefixes)) / coords.shape[1]
        return result

    def _generate_features(self, data: Dict, video_id: str) -> Dict:
        """Generate features from the raw coordinates."""
        features = defaultdict(lambda: {})
        loaded_common = []

        for prefix, data_dict in data.items():
            if prefix == "":
                prefix = None
            if "loaded" in data_dict:
                # loaded_common.append(torch.tensor(data_dict.pop("loaded")))
                loaded_common.append(torch.from_numpy(data_dict.pop("loaded")))
            key_features = self.extractor.extract_features(
                data_dict, video_id, prefix=prefix
            )
            for f_key in key_features:
                features[f_key].update(key_features[f_key])
        if len(loaded_common) > 0:
            if len(loaded_common) == 1:
                loaded_common = loaded_common[0]
            else:
                loaded_common = torch.cat(loaded_common, dim=1)
        else:
            loaded_common = None
        if self.feature_suffix is not None:
            loaded_features = self._load_saved_features(video_id)
            for clip_id, feature_tensor in loaded_features.items():
                if not isinstance(feature_tensor, torch.Tensor):
                    feature_tensor = torch.tensor(feature_tensor)
                if self.convert_int_indices and (
                    isinstance(clip_id, int) or isinstance(clip_id, np.integer)
                ):
                    clip_id = f"ind{clip_id}"
                key1 = f"{os.path.basename(video_id)}---{clip_id}"
                if key1 in features:
                    try:
                        key2 = list(features[key1].keys())[0]
                        n_frames = features[key1][key2].shape[0]
                        if feature_tensor.shape[0] != n_frames:
                            n = feature_tensor.shape[0] - n_frames
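                            # if the frame counts are far apart but the other
                            # axis matches, the tensor is most likely transposed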
                            if (
                                abs(n) > 2
                                and abs(feature_tensor.shape[1] - n_frames) <= 2
                            ):
                                feature_tensor = feature_tensor.T
                            # If off by <=2 frames, just clip the end
                            elif n > 0 and n <= 2:
                                feature_tensor = feature_tensor[:n_frames, :]
                            elif n < 0 and n >= -2:
                                filler = feature_tensor[-2:-1, :]
                                for i in range(n_frames - feature_tensor.shape[0]):
                                    feature_tensor = torch.cat(
                                        [feature_tensor, filler], 0
                                    )
                            else:
                                raise RuntimeError(
                                    f"Number of frames in precomputed features with shape"
                                    f" {feature_tensor.shape} is inconsistent with generated features!"
                                )
                        if loaded_common is not None:
                            if feature_tensor.shape[0] == loaded_common.shape[0]:
                                feature_tensor = torch.cat(
                                    [feature_tensor, loaded_common], dim=1
                                )
                            elif feature_tensor.shape[0] == loaded_common.shape[1]:
                                feature_tensor = torch.cat(
                                    [feature_tensor.T, loaded_common], dim=1
                                )
                            else:
                                raise ValueError(
                                    "The features from the data file and from the feature file have a different number of frames!"
                                )
                        features[key1]["loaded"] = feature_tensor
                    except ValueError:
                        raise RuntimeError(
                            "Individuals in precomputed features are inconsistent "
                            "with generated features"
                        )
        elif loaded_common is not None:
            for key in features:
                features[key]["loaded"] = loaded_common
        return features

    def _load_data(self) -> np.ndarray:
        """Load input data and generate data prompts."""
        if self.video_order is None:
            return None

        files = defaultdict(lambda: [])
        for f in self.file_paths:
            if f.endswith(tuple(self.data_suffices)):
                bn = os.path.basename(f)
                video_id = strip_prefix(
                    strip_suffix(bn, self.data_suffices), self.data_prefixes
                )
                files[video_id].append(f)
        files = [files[x] for x in self.video_order]

        def make_data_dictionary(filenames):
            data = {}
            stored_maxes = defaultdict(lambda: [])
            min_frames, max_frames = {}, {}
            name = strip_suffix(filenames[0], self.data_suffices)
            name = os.path.basename(name)
            stripped_name = strip_prefix(name, self.data_prefixes)
            metadata_list = []
            for filename in filenames:
                name = strip_suffix(filename, self.data_suffices)
                name = os.path.basename(name)
                prefix = strip_suffix(name, [stripped_name])
                data_new, tag = self._open_data(filename, self.default_agent_name)
                data_new, min_frames, max_frames = self._filter(data_new)
                data[prefix] = data_new
                for key, val in max_frames.items():
                    stored_maxes[key].append(val)
                metadata_list.append(tag)
            video_tag = self._get_video_metadata(metadata_list)
            sample_df = list(list(data.values())[0].values())[0]
            self.bodyparts = sorted(list(sample_df.index.unique(1)))
            smallest_maxes = dict.fromkeys(stored_maxes)
            for key, val in stored_maxes.items():
                smallest_maxes[key] = np.amin(val)
            data_dict = self._generate_features(data, stripped_name)
            bp_dict = self._count_bodyparts(
                data=data, stripped_name=stripped_name, max_frames=smallest_maxes
            )
            min_frames = {stripped_name: min_frames}  # name is e.g. 20190707T1126-1226
            max_frames = {stripped_name: max_frames}
            names, lengths, coords = self._make_trimmed_data(data_dict)
            return names, lengths, coords, bp_dict, min_frames, max_frames, video_tag

        if os.name != "nt":
            dict_list = p_map(make_data_dictionary, files, num_cpus=self.num_cpus)
        else:
            print(
                "Multiprocessing is not supported on Windows, loading files sequentially."
            )
            dict_list = [make_data_dictionary(f) for f in tqdm(files)]

        self.visibility = {}
        self.min_frames = {}
        self.max_frames = {}
        self.original_coordinates = []
        self.metadata = []
        X = []
        for (
            names,
            lengths,
            coords,
            bp_dictionary,
            min_frames,
            max_frames,
            metadata,
        ) in dict_list:
            X += names
            self.original_coordinates += coords
            self.visibility.update(bp_dictionary)
            self.min_frames.update(min_frames)
            self.max_frames.update(max_frames)
            if metadata is not None:
                self.metadata += metadata
        del dict_list
        if len(self.metadata) != len(self.original_coordinates):
            self.metadata = None
        else:
            self.metadata = np.array(self.metadata)

        self.min_frames = dict(self.min_frames)
        self.max_frames = dict(self.max_frames)
        self.original_coordinates = np.array(self.original_coordinates)
        return np.array(X)

    @abstractmethod
    def _open_data(
        self, filename: str, default_clip_name: str
    ) -> Tuple[Dict, Optional[Dict]]:
        """Load the keypoints from filename and organize them in a dictionary.

        In `data_dictionary`, the keys are clip ids and the values are `pandas` dataframes with two-level indices.
        The first level is the frame numbers and the second is the body part names. The dataframes should have from
        two to four columns labeled `"x"`, `"y"` and (optionally) `"z"` and `"likelihood"`. Each frame should have
        information on all the body parts. You don't have to filter the data in any way or fill the nans, it will
        be done automatically.

        Parameters
        ----------
        filename : str
            path to the pose file
        default_clip_name : str
            the name to assign to a clip if it does not have a name in the raw data

        Returns
        -------
        data dictionary : dict
            a dictionary where the keys are clip ids and the values are keypoint dataframes (see above for details)
        metadata_dictionary : dict
            a dictionary where the keys are clip ids and the values are metadata objects (can be any additional information,
            like the annotator tag; for no metadata pass `None`)
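
        Examples
        --------
        A minimal sketch of a valid return value for a two-frame clip with two
        body parts (hypothetical values):

        >>> import numpy as np
        >>> import pandas as pd
        >>> index = pd.MultiIndex.from_product([range(2), ["nose", "tail"]])
        >>> data = pd.DataFrame(
        ...     np.random.rand(4, 3), index=index, columns=["x", "y", "likelihood"]
        ... )
        >>> data_dictionary = {"ind0": data}
        >>> metadata_dictionary = None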

        """


class SequenceInputStore(GeneralInputStore):
    """An implementation of `dlc2action.data.InputStore` for datasets where input data files correspond to multiple videos."""

    def _count_bodyparts(
        self, data: Dict, stripped_name: str, max_frames: Dict
    ) -> Dict:
        """Create a visibility score dictionary (with a score from 0 to 1 assigned to each frame of each clip)."""
        result = {stripped_name: {}}
        for ind in data.keys():
            num_bp = len(data[ind].index.unique(level=1))
            coords = (
                data[ind].values.reshape(-1, num_bp, len(data[ind].columns))[
                    : max_frames[ind], :, 0
                ]
                != 0
            )
            res = np.sum(coords, axis=1)
            result[stripped_name][ind] = res / coords.shape[1]
        return result

    def _generate_features(self, data: Dict, name: str) -> Dict:
        """Generate features for an individual."""
        features = self.extractor.extract_features(data, name, prefix=None)
        if self.feature_suffix is not None:
            loaded_features = self._load_saved_features(name)
            for clip_id, feature_tensor in loaded_features.items():
                if not isinstance(feature_tensor, torch.Tensor):
                    feature_tensor = torch.tensor(feature_tensor)
                if self.convert_int_indices and (
                    isinstance(clip_id, int) or isinstance(clip_id, np.integer)
                ):
                    clip_id = f"ind{clip_id}"
                key1 = f"{os.path.basename(name)}---{clip_id}"
                if key1 in features:
                    try:
                        key2 = list(features[key1].keys())[0]
                        n_frames = features[key1][key2].shape[0]
                        if feature_tensor.shape[0] != n_frames:
                            n = feature_tensor.shape[0] - n_frames
                            if (
                                abs(n) > 2
                                and abs(feature_tensor.shape[1] - n_frames) <= 2
                            ):
                                feature_tensor = feature_tensor.T
                            # If off by <=2 frames, just clip the end
                            elif n > 0 and n <= 2:
                                feature_tensor = feature_tensor[:n_frames, :]
                            elif n < 0 and n >= -2:
                                filler = feature_tensor[-2:-1, :]
                                for i in range(n_frames - feature_tensor.shape[0]):
                                    feature_tensor = torch.cat(
                                        [feature_tensor, filler], 0
                                    )
                            else:
                                raise RuntimeError(
                                    f"Number of frames in precomputed features with shape"
                                    f" {feature_tensor.shape} is inconsistent with generated features!"
                                )
                        features[key1]["loaded"] = feature_tensor
                    except ValueError:
                        raise RuntimeError(
                            "Individuals in precomputed features are inconsistent "
                            "with generated features"
                        )
        return features

    def _load_data(self) -> np.ndarray:
        """Load input data and generate data prompts."""
        if self.video_order is None:
            return None

        files = []
        for f in self.file_paths:
            if os.path.basename(f) in self.video_order:
                files.append(f)

        def make_data_dictionary(seq_tuple):
            loaded_features = None
            seq_id, sequence = seq_tuple
            data, tag = self._get_data(seq_id, sequence, self.default_agent_name)
            if "loaded" in data.keys():
                loaded_features = data.pop("loaded")
            data, min_frames, max_frames = self._filter(data)
            sample_df = list(data.values())[0]
            self.bodyparts = sorted(list(sample_df.index.unique(1)))
            data_dict = self._generate_features(data, seq_id)
            if loaded_features is not None:
                for key in data_dict.keys():
                    data_dict[key]["loaded"] = loaded_features
            bp_dict = self._count_bodyparts(
                data=data, stripped_name=seq_id, max_frames=max_frames
            )
            min_frames = {seq_id: min_frames}  # name is e.g. 20190707T1126-1226
            max_frames = {seq_id: max_frames}
            names, lengths, coords = self._make_trimmed_data(data_dict)
            return names, lengths, coords, bp_dict, min_frames, max_frames, tag

        seq_tuples = []
        for file in files:
            opened = self._open_file(file)
            seq_tuples += opened
        if os.name != "nt":
            dict_list = p_map(
                make_data_dictionary, sorted(seq_tuples), num_cpus=self.num_cpus
            )
        else:
            print(
                "Multiprocessing is not supported on Windows, loading files sequentially."
            )
            dict_list = [make_data_dictionary(t) for t in tqdm(sorted(seq_tuples))]

        self.visibility = {}
        self.min_frames = {}
        self.max_frames = {}
        self.original_coordinates = []
        self.metadata = []
        X = []
        for (
            names,
            lengths,
            coords,
            bp_dictionary,
            min_frames,
            max_frames,
            tag,
        ) in dict_list:
            X += names
            self.original_coordinates += coords
            self.visibility.update(bp_dictionary)
            self.min_frames.update(min_frames)
            self.max_frames.update(max_frames)
            if tag is not None:
                self.metadata += [tag for _ in names]
        del dict_list

        if len(self.metadata) != len(self.original_coordinates):
            self.metadata = None
        else:
            self.metadata = np.array(self.metadata)
        self.min_frames = dict(self.min_frames)
        self.max_frames = dict(self.max_frames)
        self.original_coordinates = np.array(self.original_coordinates)
        return np.array(X)

1351    @classmethod
1352    def get_file_ids(
1353        cls,
1354        filenames: Set = None,
1355        data_path: Union[str, Set] = None,
1356        file_paths: Set = None,
1357        *args,
1358        **kwargs,
1359    ) -> List:
1360        """Get file ids.
1361
1362        Process data parameters and return a list of ids of the videos that should
1363        be processed by the `__init__` function.
1364
1365        Parameters
1366        ----------
1367        filenames : set, optional
1368            a set of string filenames to search for (only basenames, not the whole paths)
1369        data_path : str | set, optional
1370            the path to the folder where the pose and feature files are stored or a set of such paths
1371            (not passed if creating from key objects or from `file_paths`)
1372        file_paths : set, optional
1373            a set of string paths to the pose and feature files
1374            (not passed if creating from key objects or from `data_path`)
1375
1376        Returns
1377        -------
1378        video_ids : list
1379            a list of video file ids
1380
1381        """
1382        if file_paths is None:
1383            file_paths = []
1384        if data_path is not None:
1385            if isinstance(data_path, str):
1386                data_path = [data_path]
1387            file_paths = []
1388            for folder in data_path:
1389                file_paths += [os.path.join(folder, x) for x in os.listdir(folder)]
1390        ids = set()
1391        for f in file_paths:
1392            if os.path.basename(f) in filenames:
1393                ids.add(os.path.basename(f))
1394        ids = sorted(ids)
1395        return ids
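
    # Example (illustrative): with data_path="/data/poses" containing
    # "video1DLC1000.pickle", calling get_file_ids with
    # filenames={"video1DLC1000.pickle"} returns ["video1DLC1000.pickle"].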
1396
1397    @abstractmethod
1398    def _open_file(self, filename: str) -> List:
1399        """Open a file and make a list of sequences.
1400
1401        The sequence objects should contain information about all clips in one video. The sequences and
1402        video ids will be processed in the `_get_data` function.
1403
1404        Parameters
1405        ----------
1406        filename : str
1407            the name of the file
1408
1409        Returns
1410        -------
1411        video_tuples : list
1412            a list of video tuples: `(video_id, sequence)`
1413
1414        """
1415
1416    @abstractmethod
1417    def _get_data(
1418        self, video_id: str, sequence, default_agent_name: str
1419    ) -> Tuple[Dict, Optional[Dict]]:
1420        """Get the keypoint dataframes from a sequence.
1421
1422        The sequences and video ids are generated in the `_open_file` function.
1423        In `data_dictionary`, the keys are clip ids and the values are `pandas` dataframes with two-level indices.
1424        The first level is the frame numbers and the second is the body part names. The dataframes should have from
1425        two to four columns labeled `"x"`, `"y"` and (optionally) `"z"` and `"likelihood"`. Each frame should have
1426        information on all the body parts. You don't have to filter the data in any way or fill the nans, it will
1427        be done automatically.
1428
1429        Parameters
1430        ----------
1431        video_id : str
1432            the video id
1433        sequence
1434            an object containing information about all clips in one video
1435        default_agent_name : str
1436            the default agent name
1437
1438        Returns
1439        -------
1440        data dictionary : dict
1441            a dictionary where the keys are clip ids and the values are keypoint dataframes (see above for details)
1442        metadata_dictionary : dict
1443            a dictionary where the keys are clip ids and the values are metadata objects (can be any additional information,
1444            like the annotator tag; for no metadata pass `None`)
1445
1446        """
1447
1448
1449class DLCTrackStore(FileInputStore):
1450    """DLC track data.
1451
1452    Assumes the following file structure:
1453    ```
1454    data_path
1455    ├── video1DLC1000.pickle
1456    ├── video2DLC400.pickle
1457    ├── video1_features.pt
1458    └── video2_features.pt
1459    ```
1460    Here `data_suffix` is `{'DLC1000.pickle', 'DLC400.pickle'}` and `feature_suffix` (optional) is `'_features.pt'`.
1461
1462    The feature files should be dictionaries where keys are clip IDs (e.g. animal names) and values are
1463    feature values (arrays of shape `(#frames, #features)`). If the arrays are shaped as `(#features, #frames)`,
1464    set `transpose_features` to `True`.
1465
1466    The files can be saved with `numpy.save()` (with `.npy` extension), `torch.save()` (with `.pt` extension) or
1467    with `pickle.dump()` (with `.pickle` or `.pkl` extension).
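
        A compatible feature file could be created like this (a minimal sketch;
        the clip ID `'ind0'` and the array sizes are made-up):
        ```
        import numpy as np
        import torch

        features = {"ind0": np.random.rand(300, 16)}  # (#frames, #features)
        torch.save(features, "video1_features.pt")
        ```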
1468    """
1469
1470    def _open_data(
1471        self, filename: str, default_agent_name: str
1472    ) -> Tuple[Dict, Optional[Dict]]:
1473        """Load the keypoints from filename and organize them in a dictionary.
1474
1475        In `data_dictionary`, the keys are clip ids and the values are `pandas` dataframes with two-level indices.
1476        The first level is the frame numbers and the second is the body part names. The dataframes should have from
1477        two to four columns labeled `"x"`, `"y"` and (optionally) `"z"` and `"likelihood"`. Each frame should have
1478        information on all the body parts. You don't have to filter the data in any way or fill the nans, it will
1479        be done automatically.
1480
1481        Parameters
1482        ----------
1483        filename : str
1484            path to the pose file
1485        default_agent_name : str
1486            the default agent name
1487
1488        Returns
1489        -------
1490        data dictionary : dict
1491            a dictionary where the keys are clip ids and the values are keypoint dataframes (see above for details)
1492        metadata_dictionary : dict
1493            a dictionary where the keys are clip ids and the values are metadata objects (can be any additional information,
1494            like the annotator tag; for no metadata pass `None`)
1495
1496        """
1497        if filename.endswith("h5"):
1498            temp = pd.read_hdf(filename)
1499            temp = temp.droplevel("scorer", axis=1)
1500        else:
1501            temp = pd.read_csv(filename, header=[1, 2])
1502            temp.columns.names = ["bodyparts", "coords"]
1503        if "individuals" not in temp.columns.names:
1504            old_idx = temp.columns.to_frame()
1505            old_idx.insert(0, "individuals", self.default_agent_name)
1506            temp.columns = pd.MultiIndex.from_frame(old_idx)
1507        df = temp.stack(["individuals", "bodyparts"], future_stack=True)
1508        idx = pd.MultiIndex.from_product(
1509            [df.index.levels[0], df.index.levels[1], df.index.levels[2]],
1510            names=df.index.names,
1511        )
1512        df = df.reindex(idx).fillna(value=0)
1513        animals = sorted(list(df.index.levels[1]))
1514        dic = {}
1515        for ind in animals:
1516            coord = df.iloc[df.index.get_level_values(1) == ind].droplevel(1)
1517            coord = coord[["x", "y", "likelihood"]]
1518            dic[ind] = coord
1519
1520        return dic, None
1521
1522
1523class DLCTrackletStore(FileInputStore):
1524    """DLC tracklet data.
1525
1526    Assumes the following file structure:
1527    ```
1528    data_path
1529    ├── video1DLC1000.pickle
1530    ├── video2DLC400.pickle
1531    ├── video1_features.pt
1532    └── video2_features.pt
1533    ```
1534    Here `data_suffix` is `{'DLC1000.pickle', 'DLC400.pickle'}` and `feature_suffix` (optional) is `'_features.pt'`.
1535
1536    The feature files should be dictionaries where keys are clip IDs (e.g. animal names) and values are
1537    feature values (arrays of shape `(#frames, #features)`). If the arrays are shaped as `(#features, #frames)`,
1538    set `transpose_features` to `True`.
1539
1540    The files can be saved with `numpy.save()` (with `.npy` extension), `torch.save()` (with `.pt` extension) or
1541    with `pickle.dump()` (with `.pickle` or `.pkl` extension).
1542    """
1543
1544    def _open_data(
1545        self, filename: str, default_agent_name: str
1546    ) -> Tuple[Dict, Optional[Dict]]:
1547        """Load the keypoints from filename and organize them in a dictionary.
1548
1549        In `data_dictionary`, the keys are clip ids and the values are `pandas` dataframes with two-level indices.
1550        The first level is the frame numbers and the second is the body part names. The dataframes should have from
1551        two to four columns labeled `"x"`, `"y"` and (optionally) `"z"` and `"likelihood"`. Each frame should have
1552        information on all the body parts. You don't have to filter the data in any way or fill the nans, it will
1553        be done automatically.
1554
1555        Parameters
1556        ----------
1557        filename : str
1558            path to the pose file
1559        default_agent_name : str
1560            the default agent name
1561
1562        Returns
1563        -------
1564        data dictionary : dict
1565            a dictionary where the keys are clip ids and the values are keypoint dataframes (see above for details)
1566        metadata_dictionary : dict
1567            a dictionary where the keys are clip ids and the values are metadata objects (can be any additional information,
1568            like the annotator tag; for no metadata pass `None`)
1569
1570        """
1571        output = {}
1572        with open(filename, "rb") as f:
1573            data_p = pickle.load(f)
1574        header = data_p["header"]
1575        bodyparts = header.unique("bodyparts")
1576
1577        keys = sorted([key for key in data_p.keys() if key != "header"])
1578        min_frames = defaultdict(lambda: 10**5)
1579        max_frames = defaultdict(lambda: 0)
1580        for tr_id in keys:
1581            coords = {}
1582            fr_i = int(list(data_p[tr_id].keys())[0][5:]) - 1
1583            for frame in data_p[tr_id]:
1584                count = 0
1585                while int(frame[5:]) > fr_i + 1:
1586                    count += 1
1587                    fr_i = fr_i + 1
1588                    if count <= 3:
1589                        for bp, name in enumerate(bodyparts):
1590                            coords[(fr_i, name)] = coords[(fr_i - 1, name)]
1591                    else:
1592                        for bp, name in enumerate(bodyparts):
1593                            coords[(fr_i, name)] = np.zeros(
1594                                coords[(fr_i - 1, name)].shape
1595                            )
1596                fr_i = int(frame[5:])
1597                if fr_i > max_frames[f"ind{tr_id}"]:
1598                    max_frames[f"ind{tr_id}"] = fr_i
1599                if fr_i < min_frames[f"ind{tr_id}"]:
1600                    min_frames[f"ind{tr_id}"] = fr_i
1601                for bp, name in enumerate(bodyparts):
1602                    coords[(fr_i, name)] = data_p[tr_id][frame][bp][:3]
1603
1604            output[f"ind{tr_id}"] = pd.DataFrame(
1605                data=coords, index=["x", "y", "likelihood"]
1606            ).T
1607        return output, None
1608
1609
1610class CalMS21InputStore(SequenceInputStore):
1611    """CalMS21 data.
1612
1613    Use the `'random:test_from_name:{name}'` and `'val-from-name:{val_name}:test-from-name:{test_name}'`
1614    partitioning methods with `'train'`, `'test'` and `'unlabeled'` names to separate into train, test and validation
1615    subsets according to the original files. For example, with `'val-from-name:test:test-from-name:unlabeled'`
1616    the data from the test file will go into the validation subset and the unlabeled files will form the test subset.
1617
1618    Assumes the following file structure:
1619    ```
1620    data_path
1621    ├── calms21_task1_train.npy
1622    ├── calms21_task1_test.npy
1623    ├── calms21_task1_train_features.npy
1624    ├── calms21_task1_test_features.npy
1625    ├── calms21_unlabeled_videos_part1.npy
1626    ├── calms21_unlabeled_videos_part2.npy
1627    ├── calms21_unlabeled_videos_part3.npy
1628    └── calms21_unlabeled_videos_part4.npy
1629    ```
1630    """
1631
1632    def __init__(
1633        self,
1634        video_order: List = None,
1635        data_path: Union[Set, str] = None,
1636        file_paths: Set = None,
1637        task_n: int = 1,
1638        include_task1: bool = True,
1639        feature_save_path: str = None,
1640        len_segment: int = 128,
1641        overlap: int = 0,
1642        feature_extraction: str = "kinematic",
1643        key_objects: Dict = None,
1644        treba_files: bool = False,
1645        num_cpus: int = None,
1646        feature_extraction_pars: Dict = None,
1647        *args,
1648        **kwargs,
1649    ) -> None:
1650        """Initialize a store.
1651
1652        Parameters
1653        ----------
1654        video_order : list, optional
1655            a list of video ids that should be processed in the same order (not passed if creating from key objects)
1656        data_path : str | set, optional
1657            the path to the folder where the pose and feature files are stored or a set of such paths
1658            (not passed if creating from key objects or from `file_paths`)
1659        file_paths : set, optional
1660            a set of string paths to the pose and feature files
1661            (not passed if creating from key objects or from `data_path`)
1662        task_n : {1, 2, 3}
1663            the index of the CalMS21 challenge task
1664        include_task1 : bool, default True
1665            include task 1 data to training set
1666        feature_save_path : str, optional
1667            the path to the folder where pre-processed files are stored (not passed if creating from key objects)
1668        len_segment : int, default 128
1669            the length of the segments in which the data should be cut (in frames)
1670        overlap : int, default 0
1671            the length of the overlap between neighboring segments (in frames)
1672        feature_extraction : str, default 'kinematic'
1673            the feature extraction method (see options.feature_extractors for available options)
1676        key_objects : tuple, optional
1677            a tuple of key objects
1678        treba_files : bool, default False
1679            if `True`, TREBA feature files will be loaded
1680        num_cpus : int, optional
1681            the number of cpus to use in data processing
1682        feature_extraction_pars : dict, optional
1683            parameters of the feature extractor
1684
1685        """
1686        self.task_n = int(task_n)
1687        self.include_task1 = include_task1
1688        self.treba_files = treba_files
1689        if feature_extraction_pars is not None:
1690            feature_extraction_pars["interactive"] = True
1691
1692        super().__init__(
1693            video_order,
1694            data_path,
1695            file_paths,
1696            data_prefix=None,
1697            feature_suffix=None,
1698            convert_int_indices=False,
1699            feature_save_path=feature_save_path,
1700            canvas_shape=[1024, 570],
1701            len_segment=len_segment,
1702            overlap=overlap,
1703            feature_extraction=feature_extraction,
1704            ignored_clips=None,
1705            ignored_bodyparts=None,
1706            default_agent_name="ind0",
1707            key_objects=key_objects,
1708            likelihood_threshold=0,
1709            num_cpus=num_cpus,
1710            frame_limit=1,
1711            feature_extraction_pars=feature_extraction_pars,
1712        )
1713
1714    @classmethod
1715    def get_file_ids(
1716        cls,
1717        task_n: int = 1,
1718        include_task1: bool = False,
1719        treba_files: bool = False,
1720        data_path: Union[str, Set] = None,
1721        file_paths=None,
1722        *args,
1723        **kwargs,
1724    ) -> Iterable:
1725        """Get file ids.
1726
1727        Process data parameters and return a list of ids of the videos that should
1728        be processed by the `__init__` function.
1729
1730        Parameters
1731        ----------
1732        task_n : {1, 2, 3}
1733            the index of the CalMS21 challenge task
1734        include_task1 : bool, default False
1735            if `True`, the training file of the task 1 will be loaded
1736        treba_files : bool, default False
1737            if `True`, the TREBA feature files will be loaded
1738        filenames : set, optional
1739            a set of string filenames to search for (only basenames, not the whole paths)
1740        data_path : str | set, optional
1741            the path to the folder where the pose and feature files are stored or a set of such paths
1742            (not passed if creating from key objects or from `file_paths`)
1743        file_paths : set, optional
1744            a set of string paths to the pose and feature files
1745            (not passed if creating from key objects or from `data_path`)
1746
1747        Returns
1748        -------
1749        video_ids : list
1750            a list of video file ids
1751
1752        """
1753        task_n = int(task_n)
1754        if task_n == 1:
1755            include_task1 = False
1756        files = []
1757        if treba_files:
1758            postfix = "_features"
1759        else:
1760            postfix = ""
1761        files.append(f"calms21_task{task_n}_train{postfix}.npy")
1762        files.append(f"calms21_task{task_n}_test{postfix}.npy")
1763        if include_task1:
1764            files.append(f"calms21_task1_train{postfix}.npy")
1765        for i in range(1, 5):
1766            files.append(f"calms21_unlabeled_videos_part{i}{postfix}.npy")
1767        filenames = set(files)
1768        return SequenceInputStore.get_file_ids(filenames, data_path, file_paths)
1769
1770    def _open_file(self, filename: str) -> List:
1771        """Open a file and make a list of sequences.
1772
1773        The sequence objects should contain information about all clips in one video. The sequences and
1774        video ids will be processed in the `_get_data` function.
1775
1776        Parameters
1777        ----------
1778        filename : str
1779            the name of the file
1780
1781        Returns
1782        -------
1783        video_tuples : list
1784            a list of video tuples: `(video_id, sequence)`
1785
1786        """
1787        if os.path.basename(filename).startswith("calms21_unlabeled_videos"):
1788            mode = "unlabeled"
1789        elif os.path.basename(filename).startswith(f"calms21_task{self.task_n}_test"):
1790            mode = "test"
1791        else:
1792            mode = "train"
1793        data_dict = np.load(filename, allow_pickle=True).item()
1794        data = {}
1795        keys = list(data_dict.keys())
1796        for key in keys:
1797            data.update(data_dict[key])
1798            data_dict.pop(key)
1799        dict_list = [(f'{mode}--{k.split("/")[-1]}', v) for k, v in data.items()]
1800        return dict_list
1801
1802    def _get_data(
1803        self, video_id: str, sequence, default_agent_name: str
1804    ) -> Tuple[Dict, Optional[Dict]]:
1805        """Get the keypoint dataframes from a sequence.
1806
1807        The sequences and video ids are generated in the `_open_file` function.
1808        In `data_dictionary`, the keys are clip ids and the values are `pandas` dataframes with two-level indices.
1809        The first level is the frame numbers and the second is the body part names. The dataframes should have from
1810        two to four columns labeled `"x"`, `"y"` and (optionally) `"z"` and `"likelihood"`. Each frame should have
1811        information on all the body parts. You don't have to filter the data in any way or fill the nans, it will
1812        be done automatically.
1813
1814        Parameters
1815        ----------
1816        video_id : str
1817            the video id
1818        sequence
1819            an object containing information about all clips in one video
1820        default_agent_name : str
1821            the name of the default agent
1822
1823        Returns
1824        -------
1825        data dictionary : dict
1826            a dictionary where the keys are clip ids and the values are keypoint dataframes (see above for details)
1827        metadata_dictionary : dict
1828            a dictionary where the keys are clip ids and the values are metadata objects (can be any additional information,
1829            like the annotator tag; for no metadata pass `None`)
1830
1831        """
1832        if "metadata" in sequence:
1833            annotator = sequence["metadata"]["annotator-id"]
1834        else:
1835            annotator = 0
1836        bodyparts = [
1837            "nose",
1838            "left ear",
1839            "right ear",
1840            "neck",
1841            "left hip",
1842            "right hip",
1843            "tail",
1844        ]
1845        columns = ["x", "y"]
1846        if "keypoints" in sequence:
1847            sequence = sequence["keypoints"]
1848            index = pd.MultiIndex.from_product([range(sequence.shape[0]), bodyparts])
1849            data = {
1850                "mouse1": pd.DataFrame(
1851                    data=(sequence[:, 0, :, :]).transpose((0, 2, 1)).reshape(-1, 2),
1852                    columns=columns,
1853                    index=index,
1854                ),
1855                "mouse2": pd.DataFrame(
1856                    data=(sequence[:, 1, :, :]).transpose((0, 2, 1)).reshape(-1, 2),
1857                    columns=columns,
1858                    index=index,
1859                ),
1860            }
1861        else:
1862            sequence = sequence["features"]
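            # Note: the first 28 columns are the two mice's 7 keypoints
            # in x and y (2 mice * 2 coordinates * 7 bodyparts = 28); the trailing
            # 32 columns are learned TREBA features, kept under the "loaded" key.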
1863            mice = sequence[:, :-32].reshape((-1, 2, 2, 7))
1864            index = pd.MultiIndex.from_product([range(mice.shape[0]), bodyparts])
1865            data = {
1866                "mouse1": pd.DataFrame(
1867                    data=(mice[:, 0, :, :]).transpose((0, 2, 1)).reshape(-1, 2),
1868                    columns=columns,
1869                    index=index,
1870                ),
1871                "mouse2": pd.DataFrame(
1872                    data=(mice[:, 1, :, :]).transpose((0, 2, 1)).reshape(-1, 2),
1873                    columns=columns,
1874                    index=index,
1875                ),
1876                "loaded": sequence[:, -32:],
1877            }
1878        # metadata = {k: annotator for k in data.keys()}
1879        metadata = annotator
1880        return data, metadata
1881
1882
1883class Numpy3DInputStore(FileInputStore):
1884    """3D data.
1885
1886    Assumes the data files to be `numpy` arrays saved in `.npy` format with shape `(#frames, #keypoints, 3)`.
1887
1888    Assumes the following file structure:
1889    ```
1890    data_path
1891    ├── video1_suffix1.npy
1892    ├── video2_suffix2.npy
1893    ├── video1_features.pt
1894    └── video2_features.pt
1895    ```
1896    Here `data_suffix` is `{'_suffix1.npy', '_suffix2.npy'}` and `feature_suffix` (optional) is `'_features.pt'`.
1897
1898    The feature files should be dictionaries where keys are clip IDs (e.g. animal names) and values are
1899    feature values (arrays of shape `(#frames, #features)`). If the arrays are shaped as `(#features, #frames)`,
1900    set `transpose_features` to `True`.
1901
1902    The files can be saved with `numpy.save()` (with `.npy` extension), `torch.save()` (with `.pt` extension) or
1903    with `pickle.dump()` (with `.pickle` or `.pkl` extension).
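
        A pose file compatible with this store could be generated like this
        (an illustrative sketch; the array sizes are arbitrary):
        ```
        import numpy as np

        pose = np.random.rand(300, 12, 3)  # (#frames, #keypoints, 3)
        np.save("video1_suffix1.npy", pose)
        ```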
1904    """
1905
1906    def __init__(
1907        self,
1908        video_order: List = None,
1909        data_path: Union[Set, str] = None,
1910        file_paths: Set = None,
1911        data_suffix: Union[Set, str] = None,
1912        data_prefix: Union[Set, str] = None,
1913        feature_suffix: Union[Set, str] = None,
1914        convert_int_indices: bool = True,
1915        feature_save_path: str = None,
1916        canvas_shape: List = None,
1917        len_segment: int = 128,
1918        overlap: int = 0,
1919        feature_extraction: str = "kinematic",
1920        ignored_clips: List = None,
1921        ignored_bodyparts: List = None,
1922        default_agent_name: str = "ind0",
1923        key_objects: Dict = None,
1924        likelihood_threshold: float = 0,
1925        num_cpus: int = None,
1926        frame_limit: int = 1,
1927        feature_extraction_pars: Dict = None,
1928        centered: bool = False,
1929        **kwargs,
1930    ) -> None:
1931        """Initialize a store.
1932
1933        Parameters
1934        ----------
1935        video_order : list, optional
1936            a list of video ids that should be processed in the same order (not passed if creating from key objects)
1937        data_path : str | set, optional
1938            the path to the folder where the pose and feature files are stored or a set of such paths
1939            (not passed if creating from key objects or from `file_paths`)
1940        file_paths : set, optional
1941            a set of string paths to the pose and feature files
1942            (not passed if creating from key objects or from `data_path`)
1943        data_suffix : str | set, optional
1944            the suffix or the set of suffices such that the pose files are named {video_id}{data_suffix}
1945            (not passed if creating from key objects or if irrelevant for the dataset)
1946        data_prefix : str | set, optional
1947            the prefix or the set of prefixes such that the pose files for different video views of the same
1948            clip are named {prefix}{sep}{video_id}{data_suffix} (not passed if creating from key objects
1949            or if irrelevant for the dataset)
1950        feature_suffix : str | set, optional
1951            the suffix or the set of suffices such that the additional feature files are named
1952            {video_id}{feature_suffix} (and placed at the data_path folder)
1953        convert_int_indices : bool, default True
1954            if `True`, convert any integer key `i` in feature files to `'ind{i}'`
1955        feature_save_path : str, optional
1956            the path to the folder where pre-processed files are stored (not passed if creating from key objects)
1957        canvas_shape : List, default [1, 1]
1958            the canvas size where the pose is defined
1959        len_segment : int, default 128
1960            the length of the segments in which the data should be cut (in frames)
1961        overlap : int, default 0
1962            the length of the overlap between neighboring segments (in frames)
1963        feature_extraction : str, default 'kinematic'
1964            the feature extraction method (see options.feature_extractors for available options)
1965        ignored_clips : list, optional
1966            list of strings of clip ids to ignore
1967        ignored_bodyparts : list, optional
1968            list of strings of bodypart names to ignore
1969        default_agent_name : str, default 'ind0'
1970            the agent name used as default in the pose files for a single agent
1971        key_objects : tuple, optional
1972            a tuple of key objects
1973        likelihood_threshold : float, default 0
1974            coordinate values with likelihoods less than this value will be set to 'unknown'
1975        num_cpus : int, optional
1976            the number of cpus to use in data processing
1977        frame_limit : int, default 1
1978            clips shorter than this number of frames will be ignored
1979        feature_extraction_pars : dict, optional
1980            parameters of the feature extractor
1981        centered : bool, default False
1982            if `True`, the pose is centered at the center of mass of the body
1983
1984        """
1985        super().__init__(
1986            video_order,
1987            data_path,
1988            file_paths,
1989            data_suffix=data_suffix,
1990            data_prefix=data_prefix,
1991            feature_suffix=feature_suffix,
1992            convert_int_indices=convert_int_indices,
1993            feature_save_path=feature_save_path,
1994            canvas_shape=canvas_shape,
1995            len_segment=len_segment,
1996            overlap=overlap,
1997            feature_extraction=feature_extraction,
1998            ignored_clips=ignored_clips,
1999            ignored_bodyparts=ignored_bodyparts,
2000            default_agent_name=default_agent_name,
2001            key_objects=key_objects,
2002            likelihood_threshold=likelihood_threshold,
2003            num_cpus=num_cpus,
2004            frame_limit=frame_limit,
2005            feature_extraction_pars=feature_extraction_pars,
2006            centered=centered,
2007        )
2008
2009    def _open_data(
2010        self, filename: str, default_clip_name: str
2011    ) -> Tuple[Dict, Optional[Dict]]:
2012        """Load the keypoints from filename and organize them in a dictionary.
2013
2014        In `data_dictionary`, the keys are clip ids and the values are `pandas` dataframes with two-level indices.
2015        The first level is the frame numbers and the second is the body part names. The dataframes should have from
2016        two to four columns labeled `"x"`, `"y"` and (optionally) `"z"` and `"likelihood"`. Each frame should have
2017        information on all the body parts. You don't have to filter the data in any way or fill the nans, it will
2018        be done automatically.
2019
2020        Parameters
2021        ----------
2022        filename : str
2023            path to the pose file
2024        default_clip_name : str
2025            the name to assign to a clip if it does not have a name in the raw data
2026
2027        Returns
2028        -------
2029        data dictionary : dict
2030            a dictionary where the keys are clip ids and the values are keypoint dataframes (see above for details)
2031        metadata_dictionary : dict
2032            a dictionary where the keys are clip ids and the values are metadata objects (can be any additional information,
2033            like the annotator tag; for no metadata pass `None`)
2034        """
2035        data = np.load(filename, allow_pickle=True)
2036        bodyparts = [str(i) for i in range(data.shape[1])]
2037        clip_id = self.default_agent_name
2038        columns = ["x", "y", "z"]
2039        index = pd.MultiIndex.from_product([range(data.shape[0]), bodyparts])
2040        data_dict = {
2041            clip_id: pd.DataFrame(
2042                data=data.reshape(-1, 3), columns=columns, index=index
2043            )
2044        }
2045        return data_dict, None
2046
2047
2048class LoadedFeaturesInputStore(GeneralInputStore):
2049    """Non-pose feature files.
2050
2051    The feature files should be dictionaries where keys are clip IDs (e.g. animal names) and values are
2052    feature values (arrays of shape `(#frames, #features)`). If the arrays are shaped as `(#features, #frames)`,
2053    set `transpose_features` to `True`.
2054
2055    The files can be saved with `numpy.save()` (with `.npy` extension), `torch.save()` (with `.pt` extension) or
2056    with `pickle.dump()` (with `.pickle` or `.pkl` extension).
2057
2058    Assumes the following file structure:
2059    ```
2060    data_path
2061    ├── video1_features.pt
2062    └── video2_features.pt
2063    ```
2064    Here `feature_suffix` (optional) is `'_features.pt'`.
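
        For instance, a compatible file could be written with `pickle` (a minimal
        sketch; the clip ID and array sizes are made-up):
        ```
        import pickle

        import numpy as np

        features = {"mouse1": np.random.rand(900, 64)}  # (#frames, #features)
        with open("video1_features.pickle", "wb") as f:
            pickle.dump(features, f)
        ```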
2065    """
2066
2067    def __init__(
2068        self,
2069        video_order: List = None,
2070        data_path: Union[Set, str] = None,
2071        file_paths: Set = None,
2072        feature_suffix: Union[Set, str] = None,
2073        convert_int_indices: bool = True,
2074        feature_save_path: str = None,
2075        len_segment: int = 128,
2076        overlap: int = 0,
2077        ignored_clips: List = None,
2078        key_objects: Dict = None,
2079        num_cpus: int = None,
2080        frame_limit: int = 1,
2081        transpose_features: bool = False,
2082        **kwargs,
2083    ) -> None:
2084        """Initialize a store.
2085
2086        Parameters
2087        ----------
2088        video_order : list, optional
2089            a list of video ids that should be processed in the same order (not passed if creating from key objects)
2090        data_path : str | set, optional
2091            the path to the folder where the pose and feature files are stored or a set of such paths
2092            (not passed if creating from key objects or from `file_paths`)
2093        file_paths : set, optional
2094            a set of string paths to the pose and feature files
2095            (not passed if creating from key objects or from `data_path`)
2096        feature_suffix : str | set, optional
2097            the suffix or the set of suffices such that the additional feature files are named
2098            {video_id}{feature_suffix} (and placed at the data_path folder)
2099        convert_int_indices : bool, default True
2100            if `True`, convert any integer key `i` in feature files to `'ind{i}'`
2101        feature_save_path : str, optional
2102            the path to the folder where pre-processed files are stored (not passed if creating from key objects)
2103        len_segment : int, default 128
2104            the length of the segments in which the data should be cut (in frames)
2105        overlap : int, default 0
2106            the length of the overlap between neighboring segments (in frames)
2107        ignored_clips : list, optional
2108            list of strings of clip ids to ignore
2111        key_objects : tuple, optional
2112            a tuple of key objects
2113        num_cpus : int, optional
2114            the number of cpus to use in data processing
2115        frame_limit : int, default 1
2116            clips shorter than this number of frames will be ignored
2117        transpose_features : bool, default False
2118            if `True`, the feature arrays are assumed to be shaped `(#features, #frames)` and will be transposed on loading
2119
2120        """
2121        super().__init__(
2122            video_order,
2123            data_path,
2124            file_paths,
2125            feature_suffix=feature_suffix,
2126            convert_int_indices=convert_int_indices,
2127            feature_save_path=feature_save_path,
2128            len_segment=len_segment,
2129            overlap=overlap,
2130            ignored_clips=ignored_clips,
2131            key_objects=key_objects,
2132            num_cpus=num_cpus,
2133            frame_limit=frame_limit,
2134            transpose_features=transpose_features,
2135        )
2136
2137    def get_visibility(
2138        self, video_id: str, clip_id: str, start: int, end: int, score: float
2139    ) -> float:
2140        """Get the fraction of frames in the clip that have a visibility score higher than a given threshold.
2141
2142        For example, in the case of keypoint data the visibility score can be the number of identified keypoints.
2143
2144        Parameters
2145        ----------
2146        video_id : str
2147            the video id of the frames
2148        clip_id : str
2149            the clip id of the frames
2150        start : int
2151            the start frame
2152        end : int
2153            the end frame
2154        score : float
2155            the visibility score threshold
2156
2157        Returns
2158        -------
2159        frac_visible: float
2160            the fraction of frames with visibility above the threshold
2161
2162        """
2163        return 1
2164
2165    def _generate_features(
2166        self, video_id: str
2167    ) -> Tuple[Dict, Dict, Dict, Union[str, int]]:
2168        """Generate features from the raw coordinates."""
2169        features = defaultdict(lambda: {})
2170        loaded_features = self._load_saved_features(video_id)
2171        min_frames = None
2172        max_frames = None
2173        video_tag = None
2174        for clip_id, feature_tensor in loaded_features.items():
2175            if clip_id == "max_frames":
2176                max_frames = feature_tensor
2177            elif clip_id == "min_frames":
2178                min_frames = feature_tensor
2179            elif clip_id == "video_tag":
2180                video_tag = feature_tensor
2181            else:
2182                if not isinstance(feature_tensor, torch.Tensor):
2183                    feature_tensor = torch.tensor(feature_tensor)
2184                if self.convert_int_indices and (
2185                    isinstance(clip_id, int) or isinstance(clip_id, np.integer)
2186                ):
2187                    clip_id = f"ind{clip_id}"
2188                key = f"{os.path.basename(video_id)}---{clip_id}"
2189                features[key]["loaded"] = feature_tensor
2190        if min_frames is None:
2191            min_frames = {}
2192            for key, value in features.items():
2193                video_id, clip_id = key.split("---")
2194                min_frames[clip_id] = 0
2195        if max_frames is None:
2196            max_frames = {}
2197            for key, value in features.items():
2198                video_id, clip_id = key.split("---")
2199                max_frames[clip_id] = value["loaded"].shape[0] - 1 + min_frames[clip_id]
2200        return features, min_frames, max_frames, video_tag
2201
2202    def _load_data(self) -> np.array:
2203        """Load input data and generate data prompts."""
2204        if self.video_order is None:
2205            return None
2206
2207        files = []
2208        for video_id in self.video_order:
2209            for f in self.file_paths:
2210                if f.endswith(tuple(self.feature_suffix)):
2211                    bn = os.path.basename(f)
2212                    if video_id == strip_suffix(bn, self.feature_suffix):
2213                        files.append(f)
2214
2215        def make_data_dictionary(filename):
2216            name = strip_suffix(filename, self.feature_suffix)
2217            name = os.path.basename(name)
2218            data_dict, min_frames, max_frames, video_tag = self._generate_features(name)
2219            bp_dict = defaultdict(lambda: {})
2220            for key, value in data_dict.items():
2221                video_id, clip_id = key.split("---")
2222                bp_dict[video_id][clip_id] = 1
2223            min_frames = {name: min_frames}  # name is e.g. 20190707T1126-1226
2224            max_frames = {name: max_frames}
2225            names, lengths, coords = self._make_trimmed_data(data_dict)
2226            return names, lengths, coords, bp_dict, min_frames, max_frames, video_tag
2227
2228        if os.name != "nt":
2229            dict_list = p_map(make_data_dictionary, files, num_cpus=self.num_cpus)
2230        else:
2231            print(
2232                "Multiprocessing is not supported on Windows, loading files sequentially."
2233            )
2234            dict_list = [make_data_dictionary(f) for f in tqdm(files)]
2235
2236        self.visibility = {}
2237        self.min_frames = {}
2238        self.max_frames = {}
2239        self.original_coordinates = []
2240        self.metadata = []
2241        X = []
2242        for (
2243            names,
2244            lengths,
2245            coords,
2246            bp_dictionary,
2247            min_frames,
2248            max_frames,
2249            metadata,
2250        ) in dict_list:
2251            X += names
2252            self.original_coordinates += coords
2253            self.visibility.update(bp_dictionary)
2254            self.min_frames.update(min_frames)
2255            self.max_frames.update(max_frames)
2256            if metadata is not None:
2257                self.metadata += metadata
2258        del dict_list
2259        if len(self.metadata) != len(self.original_coordinates):
2260            self.metadata = None
2261        else:
2262            self.metadata = np.array(self.metadata)
2263
2264        self.min_frames = dict(self.min_frames)
2265        self.max_frames = dict(self.max_frames)
2266        self.original_coordinates = np.array(self.original_coordinates)
2267        return np.array(X)
2268
2269    @classmethod
2270    def get_file_ids(
2271        cls,
2272        data_path: Union[Set, str] = None,
2273        file_paths: Set = None,
2274        feature_suffix: Set = None,
2275        *args,
2276        **kwargs,
2277    ) -> List:
2278        """Get file ids.
2279
2280        Process data parameters and return a list of ids of the videos that should
2281        be processed by the `__init__` function.
2282
2283        Parameters
2284        ----------
2287        data_path : set | str, optional
2288            the path to the folder where the pose and feature files are stored or a set of such paths
2289            (not passed if creating from key objects or from `file_paths`)
2294        file_paths : set, optional
2295            a set of string paths to the pose and feature files
2296        feature_suffix : str | set, optional
2297            the suffix or the set of suffices such that the additional feature files are named
2298            {video_id}{feature_suffix} (and placed at the `data_path` folder or at `file_paths`)
2299
2300        Returns
2301        -------
2302        video_ids : list
2303            a list of video file ids
2304
2305        """
2306        if feature_suffix is None:
2307            feature_suffix = []
2308        if isinstance(feature_suffix, str):
2309            feature_suffix = [feature_suffix]
2310        feature_suffix = tuple(feature_suffix)
2311        if file_paths is None:
2312            file_paths = []
2313        if data_path is not None:
2314            if isinstance(data_path, str):
2315                data_path = [data_path]
2316            file_paths = []
2317            for folder in data_path:
2318                file_paths += [os.path.join(folder, x) for x in os.listdir(folder)]
2319        ids = set()
2320        for f in file_paths:
2321            if f.endswith(feature_suffix):
2322                bn = os.path.basename(f)
2323                video_id = strip_suffix(bn, feature_suffix)
2324                ids.add(video_id)
2325        ids = sorted(ids)
2326        return ids
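
    # Example (illustrative): for a folder holding "video1_features.pt" and
    # "video2_features.pt", calling get_file_ids with data_path="/data" and
    # feature_suffix="_features.pt" returns ["video1", "video2"].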
2327
2328
2329class SIMBAInputStore(FileInputStore):
2330    """SIMBA paper format data.
2331
2332    Assumes the following file structure:
2333
2334    ```
2335    data_path
2336    ├── Video1.csv
2337    ...
2338    └── Video9.csv
2339    ```
2340    Here `data_suffix` is `.csv`.
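
    Judging by the parser in `_open_data`, the pose columns are expected to be
    named `{bodypart}_{animal}_{coordinate}`, where the coordinate is one of
    `x`, `y`, `z` or `p` (likelihood). A made-up two-frame example:
    ```
    scorer,Nose_mouse1_x,Nose_mouse1_y,Nose_mouse1_p
    0,102.3,54.1,0.98
    1,103.0,54.8,0.97
    ```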
2341    """
2342
2343    def __init__(
2344        self,
2345        video_order: List = None,
2346        data_path: Union[Set, str] = None,
2347        file_paths: Set = None,
2348        data_prefix: Union[Set, str] = None,
2349        feature_suffix: str = None,
2350        feature_save_path: str = None,
2351        canvas_shape: List = None,
2352        len_segment: int = 128,
2353        overlap: int = 0,
2354        feature_extraction: str = "kinematic",
2355        ignored_clips: List = None,
2356        ignored_bodyparts: List = None,
2357        key_objects: Tuple = None,
2358        likelihood_threshold: float = 0,
2359        num_cpus: int = None,
2360        normalize: bool = False,
2361        feature_extraction_pars: Dict = None,
2362        centered: bool = False,
2363        data_suffix: str = None,
2364        use_features: bool = False,
2365        *args,
2366        **kwargs,
2367    ) -> None:
2368        """Initialize a store.
2369
2370        Parameters
2371        ----------
2372        video_order : list, optional
2373            a list of video ids that should be processed in the same order (not passed if creating from key objects)
2374        data_path : str | set, optional
2375            the path to the folder where the pose and feature files are stored or a set of such paths
2376            (not passed if creating from key objects or from `file_paths`)
2377        file_paths : set, optional
2378            a set of string paths to the pose and feature files
2379            (not passed if creating from key objects or from `data_path`)
2380        data_suffix : str | set, optional
2381            the suffix or the set of suffices such that the pose files are named {video_id}{data_suffix}
2382            (not passed if creating from key objects or if irrelevant for the dataset)
2383        data_prefix : str | set, optional
2384            the prefix or the set of prefixes such that the pose files for different video views of the same
2385            clip are named {prefix}{sep}{video_id}{data_suffix} (not passed if creating from key objects
2386            or if irrelevant for the dataset)
2387        feature_suffix : str | set, optional
2388            the suffix or the set of suffices such that the additional feature files are named
2389            {video_id}{feature_suffix} (and placed at the data_path folder)
2390        feature_save_path : str, optional
2391            the path to the folder where pre-processed files are stored (not passed if creating from key objects)
2392        canvas_shape : List, default [1, 1]
2393            the canvas size where the pose is defined
2394        len_segment : int, default 128
2395            the length of the segments in which the data should be cut (in frames)
2396        overlap : int, default 0
2397            the length of the overlap between neighboring segments (in frames)
2398        feature_extraction : str, default 'kinematic'
2399            the feature extraction method (see options.feature_extractors for available options)
2400        ignored_clips : list, optional
2401            list of strings of clip ids to ignore
2402        ignored_bodyparts : list, optional
2403            list of strings of bodypart names to ignore
2404        key_objects : tuple, optional
2405            a tuple of key objects
2406        likelihood_threshold : float, default 0
2407            coordinate values with likelihoods less than this value will be set to 'unknown'
2408        num_cpus : int, optional
2409            the number of cpus to use in data processing
2410        normalize : bool, default False
2411            whether to normalize the pose
2412        feature_extraction_pars : dict, optional
2413            parameters of the feature extractor
2414        centered : bool, default False
2415            whether the pose is centered at the object of interest
2416        use_features : bool, default False
2417            whether to use features
2418
2419        """
2420        self.use_features = use_features
2421        if feature_extraction_pars is not None:
2422            feature_extraction_pars["interactive"] = True
2423        super().__init__(
2424            video_order=video_order,
2425            data_path=data_path,
2426            file_paths=file_paths,
2427            data_suffix=data_suffix,
2428            data_prefix=data_prefix,
2429            feature_suffix=feature_suffix,
2430            convert_int_indices=False,
2431            feature_save_path=feature_save_path,
2432            canvas_shape=canvas_shape,
2433            len_segment=len_segment,
2434            overlap=overlap,
2435            feature_extraction=feature_extraction,
2436            ignored_clips=ignored_clips,
2437            ignored_bodyparts=ignored_bodyparts,
2438            default_agent_name="",
2439            key_objects=key_objects,
2440            likelihood_threshold=likelihood_threshold,
2441            num_cpus=num_cpus,
2442            min_frames=0,
2443            normalize=normalize,
2444            feature_extraction_pars=feature_extraction_pars,
2445            centered=centered,
2446        )
2447
2448    def _open_data(
2449        self, filename: str, default_clip_name: str
2450    ) -> Tuple[Dict, Optional[Dict]]:
2451        """Load the keypoints from `filename` and organize them in a dictionary."""
2452        if torch.cuda.is_available():
                torch.cuda.empty_cache()
2453        data = pd.read_csv(filename)
2454        output = {}
2455        column_dict = {"x": "x", "y": "y", "z": "z", "p": "likelihood"}
2456        columns = [x for x in data.columns if x.split("_")[-1] in column_dict]
2457        animals = sorted(set([x.split("_")[-2] for x in columns]))
2458        coords = sorted(set([x.split("_")[-1] for x in columns]))
2459        names = sorted(set(["_".join(x.split("_")[:-2]) for x in columns]))
2460        for animal in animals:
2461            data_dict = {}
2462            for i, row in data.iterrows():
2463                for col_name in names:
2464                    data_dict[(i, col_name)] = [
2465                        row[f"{col_name}_{animal}_{coord}"] for coord in coords
2466                    ]
2467            output[animal] = pd.DataFrame(data_dict).T
2468            output[animal].columns = [column_dict[x] for x in coords]
2469        if self.use_features:
2470            columns_to_avoid = [
2471                x
2472                for x in data.columns
2473                if x.split("_")[-1] in column_dict
2474                or x.split("_")[-1].startswith("prediction")
2475            ]
2476            columns_to_avoid += ["scorer", "frames", "video_no"]
2477            output["loaded"] = (
2478                data[[x for x in data.columns if x not in columns_to_avoid]]
2479                .interpolate()
2480                .values
2481            )
2482        return output, None
2483
2484
2485class ESKTrackStore(FileInputStore):
2486    """DLC track data from the EPFL Smart Kitchen dataset, with a configurable subset of keypoints.
2487
2488    Assumes the following file structure:
2489    ```
2490    data_path
2491    ├── video1DLC1000.pickle
2492    ├── video2DLC400.pickle
2493    ├── video1_features.npy
2494    └── video2_features.npy
2495    ```
2496    Here `data_suffix` is `{'DLC1000.pickle', 'DLC400.pickle'}` and `feature_suffix` (optional) is `'_features.npy'`.
2497
2498    The feature files should be dictionaries where keys are clip IDs (e.g. animal names) and values are
2499    feature values (arrays of shape `(#frames, #features)`). If the arrays are shaped as `(#features, #frames)`,
2500    set `transpose_features` to `True`.
2501
2502    The files can be saved with `numpy.save()` (with `.npy` extension), `torch.save()` (with `.pt` extension) or
2503    with `pickle.dump()` (with `.pickle` or `.pkl` extension).
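
    A usage sketch (the paths are hypothetical; see `get_kpt_ind` for the
    available keypoint subsets):
    ```
    store = ESKTrackStore(
        keypoint_type="body_hands",  # body and hand keypoints only
        data_path="/data/esk",
        data_suffix=".h5",
    )
    ```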
2504    """
2505
2506    def __init__(
2507        self,
2508        keypoint_type: str = "hand",
2509        *args,
2510        **kwargs,
2511    ):
2512        """Initialize a store."""
2513        self.keypoint_type = keypoint_type
2514        self.num_body_kpts = 17
2515        self.num_hand_kpts = 42
2516        self.num_eye_kpts = 10
2517
2518        super().__init__(
2519            *args,
2520            **kwargs,
2521        )
2522
2523    def get_kpt_names(self):
2524        """Return the array of keypoint names in canonical order."""
2525        kpt_names = [
2526            "nose",
2527            "left_eye",
2528            "right_eye",
2529            "left_ear",
2530            "right_ear",
2531            "left_shoulder",
2532            "right_shoulder",
2533            "left_elbow",
2534            "right_elbow",
2535            "left_wrist",
2536            "right_wrist",
2537            "left_hip",
2538            "right_hip",
2539            "left_knee",
2540            "right_knee",
2541            "left_ankle",
2542            "right_ankle",
2543        ]
2544
2545        kpt_names = (
2546            kpt_names
2547            + [
2548                f"hand_{i}"
2549                for i in range(
2550                    self.num_body_kpts, self.num_hand_kpts + self.num_body_kpts
2551                )
2552            ]
2553            + [f"eye_gaze_{i}" for i in range(self.num_eye_kpts)]
2554        )
2555        return np.array(kpt_names)
2556
2557    def get_kpt_ind(self, default_num):
2558        """Get the indices of the keypoints to use and whether the default (full) set was selected."""
2559        body_ind = list(range(17))
2560        count = len(body_ind)
2561        hands_ind = list(range(count, count + self.num_hand_kpts))
2562        count += len(hands_ind)
2563        eye_ind = list(range(count, count + self.num_eye_kpts))
2564        body_wo_arm_ind = [0, 1, 2, 3, 4, 11, 12, 13, 14, 15, 16]
2565        default_ind = list(range(default_num))
2566        switcher = {
2567            "body": body_ind,
2568            "hands": hands_ind,
2569            "eyes": eye_ind,
2570            "body_hands": body_ind + hands_ind,
2571            "body_eyes": body_ind + eye_ind,
2572            "hands_eyes": hands_ind + eye_ind,
2573            "body_wo_arm": body_wo_arm_ind,
2574        }
2575        return (
2576            switcher.get(self.keypoint_type, default_ind),
2577            self.keypoint_type not in switcher,
2578        )
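
    # Example (illustrative): with `keypoint_type="body_wo_arm"` this returns the
    # 11 body indices without the arm keypoints and `False` (a known subset).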
2579
2580    def _open_data(
2581        self, filename: str, default_agent_name: str
2582    ) -> Tuple[Dict, Optional[Dict]]:
2583        """Load the keypoints from filename and organize them in a dictionary.
2584
2585        In `data_dictionary`, the keys are clip ids and the values are `pandas` dataframes with two-level indices.
2586        The first level is the frame numbers and the second is the body part names. The dataframes should have from
2587        two to four columns labeled `"x"`, `"y"` and (optionally) `"z"` and `"likelihood"`. Each frame should have
2588        information on all the body parts. You don't have to filter the data in any way or fill the nans, it will
2589        be done automatically.
2590
2591        Parameters
2592        ----------
2593        filename : str
2594            path to the pose file
2595        default_agent_name : str
2596            the default agent name
2597
2598        Returns
2599        -------
2600        data dictionary : dict
2601            a dictionary where the keys are clip ids and the values are keypoint dataframes (see above for details)
2602        metadata_dictionary : dict
2603            a dictionary where the keys are clip ids and the values are metadata objects (can be any additional information,
2604            like the annotator tag; for no metadata pass `None`)
2605
2606        """
2607        if filename.endswith("h5"):
2608            temp = pd.read_hdf(filename)
2609            temp = temp.droplevel("scorer", axis=1)
2610        elif filename.endswith(".csv"):
2611            temp = pd.read_csv(filename, header=[1, 2])
2612            temp.columns.names = ["bodyparts", "coords"]
2613        else:
2614            raise TypeError("Invalid file type, please use .csv or .h5")
2615
2616        if "individuals" not in temp.columns.names:
2617            old_idx = temp.columns.to_frame()
2618            old_idx.insert(0, "individuals", self.default_agent_name)
2619            temp.columns = pd.MultiIndex.from_frame(old_idx)
2620
2621        df = temp.stack(["individuals", "bodyparts"], future_stack=True)
2622        idx = pd.MultiIndex.from_product(
2623            [df.index.levels[0], df.index.levels[1], df.index.levels[2]],
2624            names=df.index.names,
2625        )
2626        df = df.reindex(idx).fillna(value=0)
2627        animals = sorted(list(df.index.levels[1]))
2628        dic = {}
2629        default_num = len(df.index.levels[2])
2630        kpt_ind, is_special = self.get_kpt_ind(default_num)
2631        kpt_names = self.get_kpt_names()
2632        for ind in animals:
2633            coord = df.iloc[df.index.get_level_values(1) == ind].droplevel(1)
2634            coord = coord[["x", "y", "z", "likelihood"]]
2635            if not is_special:
2636                coord = coord.loc[(slice(None), kpt_names[kpt_ind]), :]
2637            dic[ind] = coord
2638
2639        return dic, None
class GeneralInputStore(dlc2action.data.base_store.PoseInputStore):

A generalized realization of a `PoseInputStore`.

Assumes the following file structure:

```
data_path
├── video1DLC1000.pickle
├── video2DLC400.pickle
├── video1_features.pt
└── video2_features.pt
```

Here `data_suffix` is `{'DLC1000.pickle', 'DLC400.pickle'}` and `feature_suffix` (optional) is `'_features.pt'`.
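
With this layout, video ids are recovered by stripping the known suffixes from the file names. A schematic sketch of that matching (not the library's exact helper, which lives in `dlc2action.utils`):

```
def strip_known_suffix(name, suffixes):
    """Remove the first matching suffix from a file name."""
    for s in suffixes:
        if name.endswith(s):
            return name[: -len(s)]
    return name

ids = {
    strip_known_suffix(f, {"DLC1000.pickle", "DLC400.pickle"})
    for f in ["video1DLC1000.pickle", "video2DLC400.pickle"]
}
assert ids == {"video1", "video2"}
```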

GeneralInputStore(
    video_order: List = None,
    data_path: Union[Set, str] = None,
    file_paths: Set = None,
    data_suffix: Union[Set, str] = None,
    data_prefix: Union[Set, str] = None,
    feature_suffix: str = None,
    convert_int_indices: bool = True,
    feature_save_path: str = None,
    canvas_shape: List = None,
    len_segment: int = 128,
    overlap: int = 0,
    feature_extraction: str = 'kinematic',
    ignored_clips: List = None,
    ignored_bodyparts: List = None,
    default_agent_name: str = 'ind0',
    key_objects: Tuple = None,
    likelihood_threshold: float = 0,
    num_cpus: int = None,
    frame_limit: int = 1,
    normalize: bool = False,
    feature_extraction_pars: Dict = None,
    centered: bool = False,
    transpose_features: bool = False,
    *args,
    **kwargs,
)

Initialize a store.

Parameters
----------
video_order : list, optional
    a list of video ids that should be processed in the same order (not passed if creating from key objects)
data_path : str | set, optional
    the path to the folder where the pose and feature files are stored or a set of such paths
    (not passed if creating from key objects or from `file_paths`)
file_paths : set, optional
    a set of string paths to the pose and feature files (not passed if creating from key objects or from `data_path`)
data_suffix : str | set, optional
    the suffix or the set of suffices such that the pose files are named {video_id}{data_suffix}
    (not passed if creating from key objects or if irrelevant for the dataset)
data_prefix : str | set, optional
    the prefix or the set of prefixes such that the pose files for different video views of the same
    clip are named {prefix}{sep}{video_id}{data_suffix} (not passed if creating from key objects
    or if irrelevant for the dataset)
feature_suffix : str | set, optional
    the suffix or the set of suffices such that the additional feature files are named
    {video_id}{feature_suffix} (and placed in the `data_path` folder)
convert_int_indices : bool, default True
    if `True`, convert any integer key `i` in feature files to `'ind{i}'`
feature_save_path : str, optional
    the path to the folder where pre-processed files are stored (not passed if creating from key objects)
canvas_shape : List, default [1, 1]
    the canvas size where the pose is defined
len_segment : int, default 128
    the length of the segments in which the data should be cut (in frames)
overlap : int, default 0
    the length of the overlap between neighboring segments (in frames)
feature_extraction : str, default 'kinematic'
    the feature extraction method (see `options.feature_extractors` for available options)
ignored_clips : list, optional
    list of strings of clip ids to ignore
ignored_bodyparts : list, optional
    list of strings of bodypart names to ignore
default_agent_name : str, default 'ind0'
    the agent name used as default in the pose files for a single agent
key_objects : tuple, optional
    a tuple of key objects
likelihood_threshold : float, default 0
    coordinate values with likelihoods less than this value will be set to 'unknown'
num_cpus : int, optional
    the number of cpus to use in data processing
frame_limit : int, default 1
    clips shorter than this number of frames will be ignored
normalize : bool, default False
    whether to normalize the pose
feature_extraction_pars : dict, optional
    parameters of the feature extractor
centered : bool, default False
    whether the pose is centered
transpose_features : bool, default False
    whether to transpose the features

data_suffix = None

Instance attributes set in `__init__`:

loaded_max
bodyparts
visibility
normalize
video_order
centered
feature_extraction
len_segment
data_suffices
data_prefixes
feature_suffix
convert_int_indices
overlap
canvas_shape
default_agent_name
feature_save_path
likelihood_threshold
num_cpus
frame_limit
transpose
ram
min_frames
original_coordinates
file_paths
extractor
canvas_center
step
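
Note that an `overlap` value below 1 is interpreted as a fraction of `len_segment` rather than a number of frames, and the sampling step is the difference of the two. A worked example of that arithmetic (values made up):

```
len_segment = 128

overlap = 0.5  # below 1, so treated as a fraction of the segment length
if overlap < 1:
    overlap = overlap * len_segment  # 0.5 * 128 = 64 frames
overlap = int(overlap)

step = len_segment - overlap  # segments are sampled every 64 frames
assert step == 64
```
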
def get_folder(self, video_id: str) -> str:

Get the input folder that the file with this video id was read from.

Parameters
----------
video_id : str
    the video id

Returns
-------
folder : str
    the path to the directory that contains the input file associated with the video id

def remove(self, indices: List) -> None:

Remove the samples corresponding to indices.

Parameters
----------
indices : list
    a list of integer indices to remove

def key_objects(self) -> Tuple:

Return a tuple of the key objects necessary to re-create the Store.

Returns
-------
key_objects : tuple
    a tuple of key objects

def load_from_key_objects(self, key_objects: Tuple) -> None:

Load the information from a tuple of key objects.

Parameters
----------
key_objects : tuple
    a tuple of key objects
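
Together with `key_objects`, this enables a simple save/restore round trip. A sketch, assuming a concrete subclass `MyInputStore` (a hypothetical name) and made-up paths:

```
import pickle

store = MyInputStore(video_order=ids, data_path="/path/to/data")

# Save the key objects to disk...
with open("store_key_objects.pickle", "wb") as f:
    pickle.dump(store.key_objects(), f)

# ...and later re-create an equivalent store from them.
with open("store_key_objects.pickle", "rb") as f:
    restored = MyInputStore(key_objects=pickle.load(f))
```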

def to_ram(self) -> None:

Transfer the data samples to RAM if they were previously stored as file paths.

def get_original_coordinates(self) -> numpy.ndarray:

Return the original coordinates array.

Returns
-------
np.ndarray
    an array that contains the coordinates of the data samples in the original input data
    (video id, clip id, start frame)

def create_subsample(self, indices: List, ssl_indices: List = None):

Create a new store that contains a subsample of the data.

Parameters
----------
indices : list
    the indices to be included in the subsample
ssl_indices : list, optional
    the indices to be included in the subsample without the annotation data

Returns
-------
new : GeneralInputStore
    a new store containing the subsampled data
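
For instance, a random train/validation split could be built on top of this method; a sketch with an arbitrary 80/20 split, assuming `store` is an initialized input store:

```
import numpy as np

n = len(store)
perm = np.random.permutation(n)
cut = int(0.8 * n)
train_store = store.create_subsample(perm[:cut].tolist())
val_store = store.create_subsample(perm[cut:].tolist())
```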

def get_video_id(self, coords: Tuple) -> str:

Get the video id from an element of original coordinates.

Parameters
----------
coords : tuple
    an element of the original coordinates array

Returns
-------
video_id : str
    the id of the video that the coordinates point to

def get_clip_id(self, coords: Tuple) -> str:

Get the clip id from an element of original coordinates.

Parameters
----------
coords : tuple
    an element of the original coordinates array

Returns
-------
clip_id : str
    the id of the clip that the coordinates point to
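
Both `get_video_id` and `get_clip_id` rely on the `'{video_id}---{clip_id}'` naming convention used in the original coordinates; for example (made-up ids):

```
coords = ("video1---ind0", 3)  # (combined id, segment index)

video_id = coords[0].split("---")[0]  # "video1"
clip_id = coords[0].split("---")[1]  # "ind0"
```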

def get_clip_length(self, video_id: str, clip_id: str) -> int:

Get the clip length from the id.

Parameters
----------
video_id : str
    the video id
clip_id : str
    the clip id

Returns
-------
clip_length : int
    the length of the clip
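
For composite clip ids (individual ids joined with `'+'`), the length is the overlap of the individual frame ranges. A worked example with made-up frame ranges:

```
min_frames = {"video1": {"ind0": 0, "ind1": 10}}
max_frames = {"video1": {"ind0": 99, "ind1": 89}}

inds = "ind0+ind1".split("+")
max_frame = min(max_frames["video1"][x] for x in inds)  # 89
min_frame = max(min_frames["video1"][x] for x in inds)  # 10
assert max_frame - min_frame + 1 == 80  # frames shared by both clips
```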

def get_clip_start_end(self, coords: Tuple) -> Tuple[int, int]:
454    def get_clip_start_end(self, coords: Tuple) -> Tuple[int, int]:
455        """Get the clip start and end frames from an element of original coordinates.
456
457        Parameters
458        ----------
459        coords : tuple
460            an element of original coordinates array
461
462        Returns
463        -------
464        start : int
465            the start frame of the clip that the coordinates point to
466        end : int
467            the end frame of the clip that the coordinates point to
468
469        """
470        l = self.get_clip_length_from_coords(coords)
471        i = coords[1]
472        start = int(i) * self.step
473        end = min(start + self.len_segment, l)
474        return start, end

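The start and end follow from the segment index, the sliding-window step and the segment length, with the last segment clipped to the clip length. A sketch with illustrative values (`step` equals `len_segment` when `overlap` is 0):

```
len_segment, step, clip_length = 128, 128, 300

for i in range(3):
    start = i * step
    end = min(start + len_segment, clip_length)
    print(i, start, end)  # 0 0 128 / 1 128 256 / 2 256 300
```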

def get_clip_start(self, video_name: str, clip_id: str) -> int:
476    def get_clip_start(self, video_name: str, clip_id: str) -> int:
477        """Get the clip start frame from the video id and the clip id.
478
479        Parameters
480        ----------
481        video_name : str
482            the video id
483        clip_id : str
484            the clip id
485
486        Returns
487        -------
488        clip_start : int
489            the start frame of the clip
490
491        """
492        return max(
493            [self.min_frames[video_name][clip_id_k] for clip_id_k in clip_id.split("+")]
494        )


def get_visibility( self, video_id: str, clip_id: str, start: int, end: int, score: float) -> float:
496    def get_visibility(
497        self, video_id: str, clip_id: str, start: int, end: int, score: float
498    ) -> float:
499        """Get the fraction of frames that have a visibility score above a given threshold.
500
501        For example, in the case of keypoint data the visibility score can be the number of identified keypoints.
502
503        Parameters
504        ----------
505        video_id : str
506            the video id of the frames
507        clip_id : str
508            the clip id of the frames
509        start : int
510            the start frame
511        end : int
512            the end frame
513        score : float
514            the visibility score threshold
515
516        Returns
517        -------
518        frac_visible: float
519            the fraction of frames with visibility above the threshold
520
521        """
522        s = 0
523        for ind_k in clip_id.split("+"):
524            s += np.sum(self.visibility[video_id][ind_k][start:end] > score) / (
525                end - start
526            )
527        return s / len(clip_id.split("+"))

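A sketch of the computation for a single (non-composite) clip, with a hypothetical per-frame visibility record:

```
import numpy as np

visibility = {"video1": {"ind0": np.array([0.0, 0.5, 0.9, 1.0, 0.2])}}

start, end, score = 0, 5, 0.4
frac = np.sum(visibility["video1"]["ind0"][start:end] > score) / (end - start)
print(frac)  # 0.6: three of the five frames exceed the threshold
```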

def get_annotation_objects(self) -> Dict:
529    def get_annotation_objects(self) -> Dict:
530        """Get a dictionary of objects necessary to create a `BehaviorStore`.
531
532        Returns
533        -------
534        annotation_objects : dict
535            a dictionary of objects to be passed to the BehaviorStore constructor where the keys are the names of
536            the objects
537
538        """
539        min_frames = self.min_frames
540        max_frames = self.max_frames
541        num_bp = self.visibility
542        return {
543            "min_frames": min_frames,
544            "max_frames": max_frames,
545            "visibility": num_bp,
546        }


@classmethod
def get_file_ids( cls, data_suffix: Union[Set, str] = None, data_path: Union[Set, str] = None, data_prefix: Union[Set, str] = None, file_paths: Set = None, feature_suffix: Set = None, *args, **kwargs) -> List:
548    @classmethod
549    def get_file_ids(
550        cls,
551        data_suffix: Union[Set, str] = None,
552        data_path: Union[Set, str] = None,
553        data_prefix: Union[Set, str] = None,
554        file_paths: Set = None,
555        feature_suffix: Set = None,
556        *args,
557        **kwargs,
558    ) -> List:
559        """Get file ids.
560
561        Process data parameters and return a list of ids of the videos that should
562        be processed by the `__init__` function.
563
564        Parameters
565        ----------
566        data_suffix : set | str, optional
567            the suffix (or a set of suffixes) of the input data files
568        data_path : set | str, optional
569            the path to the folder where the pose and feature files are stored or a set of such paths
570            (not passed if creating from key objects or from `file_paths`)
571        data_prefix : set | str, optional
572            the prefix or the set of prefixes such that the pose files for different video views of the same
573            clip are named {prefix}{sep}{video_id}{data_suffix} (not passed if creating from key objects
574            or if irrelevant for the dataset)
575        file_paths : set, optional
576            a set of string paths to the pose and feature files
577        feature_suffix : str | set, optional
578            the suffix or the set of suffixes such that the additional feature files are named
579            {video_id}{feature_suffix} (and placed at the `data_path` folder or at `file_paths`)
580
581        Returns
582        -------
583        video_ids : list
584            a list of video file ids
585
586        """
587        if data_suffix is None:
588            if cls.data_suffix is not None:
589                data_suffix = cls.data_suffix
590            else:
591                raise ValueError("Cannot get video ids without the data suffix!")
592        if feature_suffix is None:
593            feature_suffix = []
594        if data_prefix is None:
595            data_prefix = ""
596        if isinstance(data_suffix, str):
597            data_suffix = [data_suffix]
598        else:
599            data_suffix = [x for x in data_suffix]
600        data_suffix = tuple(data_suffix)
601        if not isinstance(data_prefix, str):
602            data_prefix = tuple(data_prefix)
605        if isinstance(feature_suffix, str):
606            feature_suffix = [feature_suffix]
607        if file_paths is None:
608            file_paths = []
609        if data_path is not None:
610            if isinstance(data_path, str):
611                data_path = [data_path]
612            file_paths = []
613            for folder in data_path:
614                file_paths += [os.path.join(folder, x) for x in os.listdir(folder)]
615        basenames = [os.path.basename(f) for f in file_paths]
616        ids = set()
617        for f in file_paths:
618            if f.endswith(data_suffix) and os.path.basename(f).startswith(data_prefix):
619                bn = os.path.basename(f)
620                video_id = strip_prefix(strip_suffix(bn, data_suffix), data_prefix)
621                if all([video_id + s in basenames for s in feature_suffix]):
622                    ids.add(video_id)
623        ids = sorted(ids)
624        return ids

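A hedged usage sketch, matching the file structure from the class docstring (the path is illustrative):

```
ids = GeneralInputStore.get_file_ids(
    data_suffix={"DLC1000.pickle", "DLC400.pickle"},
    data_path="/path/to/data_path",
    feature_suffix="_features.pt",
)
print(ids)  # e.g. ["video1", "video2"]
```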

def get_bodyparts(self) -> List:
626    def get_bodyparts(self) -> List:
627        """Get a list of bodypart names.
628
635
636        Returns
637        -------
638        bodyparts : list
639            a list of string or integer body part names
640
641        """
642        return [x for x in self.bodyparts if x not in self.ignored_bodyparts]


def get_coords(self, data_dict: Dict, clip_id: str, bodypart: str) -> numpy.ndarray:
644    def get_coords(self, data_dict: Dict, clip_id: str, bodypart: str) -> np.ndarray:
645        """Get the coordinates array of a specific bodypart in a specific clip.
646
647        Parameters
648        ----------
649        data_dict : dict
650            the data dictionary (passed to feature extractor)
651        clip_id : str
652            the clip id
653        bodypart : str
654            the name of the body part
655
656        Returns
657        -------
658        coords : np.ndarray
659            the coordinates array of shape (#timesteps, #coordinates)
660
661        """
662        columns = [x for x in data_dict[clip_id].columns if x != "likelihood"]
663        xy_coord = (
664            data_dict[clip_id]
665            .xs(bodypart, axis=0, level=1, drop_level=False)[columns]
666            .values
667        )
668        return xy_coord

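The `data_dict` values are the two-level dataframes described in `_open_data` below; a minimal sketch of the extraction (names and sizes are illustrative):

```
import pandas as pd

index = pd.MultiIndex.from_product([range(3), ["nose", "tail"]])
df = pd.DataFrame(
    {"x": range(6), "y": range(6), "likelihood": [1.0] * 6}, index=index
)
data_dict = {"ind0": df}

columns = [x for x in data_dict["ind0"].columns if x != "likelihood"]
coords = data_dict["ind0"].xs("nose", axis=0, level=1, drop_level=False)[columns].values
print(coords.shape)  # (3, 2): #timesteps x #coordinates
```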

def get_n_frames(self, data_dict: Dict, clip_id: str) -> int:
670    def get_n_frames(self, data_dict: Dict, clip_id: str) -> int:
671        """Get the length of the clip.
672
673        Parameters
674        ----------
675        data_dict : dict
676            the data dictionary (passed to feature extractor)
677        clip_id : str
678            the clip id
679
680        Returns
681        -------
682        n_frames : int
683            the length of the clip
684
685        """
686        if clip_id in data_dict:
687            return len(data_dict[clip_id].groupby(level=0))
688        else:
689            return min(
690                [len(data_dict[ind_k].groupby(level=0)) for ind_k in clip_id.split("+")]
691            )


def get_likelihood( self, data_dict: Dict, clip_id: str, bodypart: str) -> Optional[numpy.ndarray]:
883    def get_likelihood(
884        self, data_dict: Dict, clip_id: str, bodypart: str
885    ) -> Union[np.ndarray, None]:
886        """Get the likelihood values.
887
888        Parameters
889        ----------
890        data_dict : dict
891            the data dictionary
892        clip_id : str
893            the clip id
894        bodypart : str
895            the name of the body part
896
897        Returns
898        -------
899        likelihoods: np.ndarray | None
900            `None` if the dataset doesn't have likelihoods or an array of shape (#timestamps)
901
902        """
903        if "likelihood" in data_dict[clip_id].columns:
904            likelihood = (
905                data_dict[clip_id]
906                .xs(bodypart, axis=0, level=1, drop_level=False)
907                .values[:, -1]
908            )
909            return likelihood
910        else:
911            return None


def get_indices(self, tag: int) -> List:
920    def get_indices(self, tag: int) -> List:
921        """Get a list of indices of samples that have a specific meta tag.
922
923        Parameters
924        ----------
925        tag : int
926            the meta tag for the subsample (`None` for the whole dataset)
927
928        Returns
929        -------
930        indices : list
931            a list of indices that meet the criteria
932
933        """
934        if tag is None:
935            return list(range(len(self.data)))
936        else:
937            return list(np.where(self.metadata == tag)[0])


def get_tags(self) -> List:
939    def get_tags(self) -> List:
940        """Get a list of all meta tags.
941
942        Returns
943        -------
944        tags: List
945            a list of unique meta tag values
946
947        """
948        if self.metadata is None:
949            return [None]
950        else:
951            return list(np.unique(self.metadata))


def get_tag(self, idx: int) -> Optional[int]:
953    def get_tag(self, idx: int) -> Union[int, None]:
954        """Return a tag object corresponding to an index.
955
956        Tags can carry meta information (like annotator id) and are accepted by models that require
957        that information. When a tag is `None`, it is not passed to the model.
958
959        Parameters
960        ----------
961        idx : int
962            the index
963
964        Returns
965        -------
966        tag : int
967            the tag object
968
969        """
970        if self.metadata is None or idx is None:
971            return None
972        else:
973            return self.metadata[idx]

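A sketch of how the metadata methods interact, with a hypothetical per-sample tag array (e.g. annotator ids):

```
import numpy as np

metadata = np.array([0, 0, 1, 1, 1])

indices = np.where(metadata == 1)[0]  # as in get_indices(1)
print(indices)      # [2 3 4]
print(metadata[3])  # as in get_tag(3) -> 1
```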

class FileInputStore(GeneralInputStore):
 980class FileInputStore(GeneralInputStore):
 981    """An implementation of `dlc2action.data.base_store.InputStore` for datasets where each input data file corresponds to one video."""
 982
 983    def _count_bodyparts(
 984        self, data: Dict, stripped_name: str, max_frames: Dict
 985    ) -> Dict:
 986        """Create a visibility score dictionary (with a score from 0 to 1 assigned to each frame of each clip)."""
 987        result = {stripped_name: {}}
 988        prefixes = list(data.keys())
 989        for ind in data[prefixes[0]]:
 990            res = 0
 991            for _, data_dict in data.items():
 992                num_bp = len(data_dict[ind].index.unique(level=1))
 993                coords = (
 994                    data_dict[ind].values.reshape(
 995                        -1, num_bp, len(data_dict[ind].columns)
 996                    )[: max_frames[ind], :, 0]
 997                    != 0
 998                )
 999                res = np.sum(coords, axis=1) + res
1000            result[stripped_name][ind] = (res / len(prefixes)) / coords.shape[1]
1001        return result
1002
1003    def _generate_features(self, data: Dict, video_id: str) -> Dict:
1004        """Generate features from the raw coordinates."""
1005        features = defaultdict(lambda: {})
1006        loaded_common = []
1007
1008        for prefix, data_dict in data.items():
1009            if prefix == "":
1010                prefix = None
1011            if "loaded" in data_dict:
1012                # loaded_common.append(torch.tensor(data_dict.pop("loaded")))
1013                loaded_common.append(torch.from_numpy(data_dict.pop("loaded")))
1014            key_features = self.extractor.extract_features(
1015                data_dict, video_id, prefix=prefix
1016            )
1017            for f_key in key_features:
1018                features[f_key].update(key_features[f_key])
1019        if len(loaded_common) > 0:
1020            if len(loaded_common) == 1:
1021                loaded_common = loaded_common[0]
1022            else:
1023                loaded_common = torch.cat(loaded_common, dim=1)
1024        else:
1025            loaded_common = None
1026        if self.feature_suffix is not None:
1027            loaded_features = self._load_saved_features(video_id)
1028            for clip_id, feature_tensor in loaded_features.items():
1029                if not isinstance(feature_tensor, torch.Tensor):
1030                    feature_tensor = torch.tensor(feature_tensor)
1031                if self.convert_int_indices and (
1032                    isinstance(clip_id, int) or isinstance(clip_id, np.integer)
1033                ):
1034                    clip_id = f"ind{clip_id}"
1035                key1 = f"{os.path.basename(video_id)}---{clip_id}"
1036                if key1 in features:
1037                    try:
1038                        key2 = list(features[key1].keys())[0]
1039                        n_frames = features[key1][key2].shape[0]
1040                        if feature_tensor.shape[0] != n_frames:
1041                            n = feature_tensor.shape[0] - n_frames
1042                            if (
1043                                abs(n) > 2
1044                                and abs(feature_tensor.shape[1] - n_frames) <= 2
1045                            ):
1046                                feature_tensor = feature_tensor.T
1047                            # If off by <=2 frames, just clip the end
1048                            elif n > 0 and n <= 2:
1049                                feature_tensor = feature_tensor[:n_frames, :]
1050                            elif n < 0 and n >= -2:
1051                                filler = feature_tensor[-2:-1, :]
1052                                for i in range(n_frames - feature_tensor.shape[0]):
1053                                    feature_tensor = torch.cat(
1054                                        [feature_tensor, filler], 0
1055                                    )
1056                            else:
1057                                raise RuntimeError(
1058                                    f"Number of frames in precomputed features with shape"
1059                                    f" {feature_tensor.shape} is inconsistent with generated features!"
1060                                )
1061                        if loaded_common is not None:
1062                            if feature_tensor.shape[0] == loaded_common.shape[0]:
1063                                feature_tensor = torch.cat(
1064                                    [feature_tensor, loaded_common], dim=1
1065                                )
1066                            elif feature_tensor.shape[0] == loaded_common.shape[1]:
1067                                feature_tensor = torch.cat(
1068                                    [feature_tensor.T, loaded_common], dim=1
1069                                )
1070                            else:
1071                                raise ValueError(
1072                                    "The features from the data file and from the feature file have a different number of frames!"
1073                                )
1074                        features[key1]["loaded"] = feature_tensor
1075                    except ValueError:
1076                        raise RuntimeError(
1077                            "Individuals in precomputed features are inconsistent "
1078                            "with generated features"
1079                        )
1080        elif loaded_common is not None:
1081            for key in features:
1082                features[key]["loaded"] = loaded_common
1083        return features
1084
1085    def _load_data(self) -> np.ndarray:
1086        """Load input data and generate data prompts."""
1087        if self.video_order is None:
1088            return None
1089
1090        files = defaultdict(lambda: [])
1091        for f in self.file_paths:
1092            if f.endswith(tuple([x for x in self.data_suffices])):
1093                bn = os.path.basename(f)
1094                video_id = strip_prefix(
1095                    strip_suffix(bn, self.data_suffices), self.data_prefixes
1096                )
1097                files[video_id].append(f)
1098        files = [files[x] for x in self.video_order]
1099
1100        def make_data_dictionary(filenames):
1101            data = {}
1102            stored_maxes = defaultdict(lambda: [])
1103            min_frames, max_frames = {}, {}
1104            name = strip_suffix(filenames[0], self.data_suffices)
1105            name = os.path.basename(name)
1106            stripped_name = strip_prefix(name, self.data_prefixes)
1107            metadata_list = []
1108            for filename in filenames:
1109                name = strip_suffix(filename, self.data_suffices)
1110                name = os.path.basename(name)
1111                prefix = strip_suffix(name, [stripped_name])
1112                data_new, tag = self._open_data(filename, self.default_agent_name)
1113                data_new, min_frames, max_frames = self._filter(data_new)
1114                data[prefix] = data_new
1115                for key, val in max_frames.items():
1116                    stored_maxes[key].append(val)
1117                metadata_list.append(tag)
1118            video_tag = self._get_video_metadata(metadata_list)
1119            sample_df = list(list(data.values())[0].values())[0]
1120            self.bodyparts = sorted(list(sample_df.index.unique(1)))
1121            smallest_maxes = dict.fromkeys(stored_maxes)
1122            for key, val in stored_maxes.items():
1123                smallest_maxes[key] = np.amin(val)
1124            data_dict = self._generate_features(data, stripped_name)
1125            bp_dict = self._count_bodyparts(
1126                data=data, stripped_name=stripped_name, max_frames=smallest_maxes
1127            )
1128            min_frames = {stripped_name: min_frames}  # name is e.g. 20190707T1126-1226
1129            max_frames = {stripped_name: max_frames}
1130            names, lengths, coords = self._make_trimmed_data(data_dict)
1131            return names, lengths, coords, bp_dict, min_frames, max_frames, video_tag
1132
1133        if os.name != "nt":
1134            dict_list = p_map(make_data_dictionary, files, num_cpus=self.num_cpus)
1135        else:
1136            print(
1137                "Multiprocessing is not supported on Windows, loading files sequentially."
1138            )
1139            dict_list = tqdm([make_data_dictionary(f) for f in files])
1140
1141        self.visibility = {}
1142        self.min_frames = {}
1143        self.max_frames = {}
1144        self.original_coordinates = []
1145        self.metadata = []
1146        X = []
1147        for (
1148            names,
1149            lengths,
1150            coords,
1151            bp_dictionary,
1152            min_frames,
1153            max_frames,
1154            metadata,
1155        ) in dict_list:
1156            X += names
1157            self.original_coordinates += coords
1158            self.visibility.update(bp_dictionary)
1159            self.min_frames.update(min_frames)
1160            self.max_frames.update(max_frames)
1161            if metadata is not None:
1162                self.metadata += metadata
1163        del dict_list
1164        if len(self.metadata) != len(self.original_coordinates):
1165            self.metadata = None
1166        else:
1167            self.metadata = np.array(self.metadata)
1168
1169        self.min_frames = dict(self.min_frames)
1170        self.max_frames = dict(self.max_frames)
1171        self.original_coordinates = np.array(self.original_coordinates)
1172        return np.array(X)
1173
1174    @abstractmethod
1175    def _open_data(
1176        self, filename: str, default_clip_name: str
1177    ) -> Tuple[Dict, Optional[Dict]]:
1178        """Load the keypoints from filename and organize them in a dictionary.
1179
1180        In `data_dictionary`, the keys are clip ids and the values are `pandas` dataframes with two-level indices.
1181        The first level is the frame numbers and the second is the body part names. The dataframes should have from
1182        two to four columns labeled `"x"`, `"y"` and (optionally) `"z"` and `"likelihood"`. Each frame should have
1183        information on all the body parts. You don't have to filter the data in any way or fill the nans, it will
1184        be done automatically.
1185
1186        Parameters
1187        ----------
1188        filename : str
1189            path to the pose file
1190        default_clip_name : str
1191            the name to assign to a clip if it does not have a name in the raw data
1192
1193        Returns
1194        -------
1195        data dictionary : dict
1196            a dictionary where the keys are clip ids and the values are keypoint dataframes (see above for details)
1197        metadata_dictionary : dict
1198            a dictionary where the keys are clip ids and the values are metadata objects (can be any additional information,
1199            like the annotator tag; for no metadata pass `None`)
1200
1201        """

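To subclass `FileInputStore`, only `_open_data` needs to be implemented. A hedged sketch of the return value it is expected to produce (all names and sizes are illustrative):

```
import numpy as np
import pandas as pd

n_frames, bodyparts = 4, ["nose", "tail"]
index = pd.MultiIndex.from_product([range(n_frames), bodyparts])
df = pd.DataFrame(
    np.zeros((n_frames * len(bodyparts), 3)),
    columns=["x", "y", "likelihood"],
    index=index,
)
data_dictionary = {"ind0": df}  # keys are clip ids
metadata_dictionary = None  # or e.g. a per-clip annotator tag
```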

class SequenceInputStore(GeneralInputStore):
1204class SequenceInputStore(GeneralInputStore):
1205    """An implementation of `dlc2action.data.base_store.InputStore` for datasets where input data files correspond to multiple videos."""
1206
1207    def _count_bodyparts(
1208        self, data: Dict, stripped_name: str, max_frames: Dict
1209    ) -> Dict:
1210        """Create a visibility score dictionary (with a score from 0 to 1 assigned to each frame of each clip)."""
1211        result = {stripped_name: {}}
1212        for ind in data.keys():
1213            num_bp = len(data[ind].index.unique(level=1))
1214            coords = (
1215                data[ind].values.reshape(-1, num_bp, len(data[ind].columns))[
1216                    : max_frames[ind], :, 0
1217                ]
1218                != 0
1219            )
1220            res = np.sum(coords, axis=1)
1221            result[stripped_name][ind] = res / coords.shape[1]
1222        return result
1223
1224    def _generate_features(self, data: Dict, name: str) -> Dict:
1225        """Generate features for an individual."""
1226        features = self.extractor.extract_features(data, name, prefix=None)
1227        if self.feature_suffix is not None:
1228            loaded_features = self._load_saved_features(name)
1229            for clip_id, feature_tensor in loaded_features.items():
1230                if not isinstance(feature_tensor, torch.Tensor):
1231                    feature_tensor = torch.tensor(feature_tensor)
1232                if self.convert_int_indices and (
1233                    isinstance(clip_id, int) or isinstance(clip_id, np.integer)
1234                ):
1235                    clip_id = f"ind{clip_id}"
1236                key1 = f"{os.path.basename(name)}---{clip_id}"
1237                if key1 in features:
1238                    try:
1239                        key2 = list(features[key1].keys())[0]
1240                        n_frames = features[key1][key2].shape[0]
1241                        if feature_tensor.shape[0] != n_frames:
1242                            n = feature_tensor.shape[0] - n_frames
1243                            if (
1244                                abs(n) > 2
1245                                and abs(feature_tensor.shape[1] - n_frames) <= 2
1246                            ):
1247                                feature_tensor = feature_tensor.T
1248                            # If off by <=2 frames, just clip the end
1249                            elif n > 0 and n <= 2:
1250                                feature_tensor = feature_tensor[:n_frames, :]
1251                            elif n < 0 and n >= -2:
1252                                filler = feature_tensor[-2:-1, :]
1253                                for i in range(n_frames - feature_tensor.shape[0]):
1254                                    feature_tensor = torch.cat(
1255                                        [feature_tensor, filler], 0
1256                                    )
1257                            else:
1258                                raise RuntimeError(
1259                                    "Number of frames in precomputed "
1260                                    "features with shape "
1261                                    f"{feature_tensor.shape} is inconsistent "
1262                                    "with generated features!"
1263                                )
1264                        features[key1]["loaded"] = feature_tensor
1265                    except ValueError:
1266                        raise RuntimeError(
1267                            "Individuals in precomputed "
1268                            "features are inconsistent "
1269                            "with generated "
1270                            "features"
1271                        )
1272        return features
1273
1274    def _load_data(self) -> np.ndarray:
1275        """Load input data and generate data prompts."""
1276        if self.video_order is None:
1277            return None
1278
1279        files = []
1280        for f in self.file_paths:
1281            if os.path.basename(f) in self.video_order:
1282                files.append(f)
1283
1284        def make_data_dictionary(seq_tuple):
1285            loaded_features = None
1286            seq_id, sequence = seq_tuple
1287            data, tag = self._get_data(seq_id, sequence, self.default_agent_name)
1288            if "loaded" in data.keys():
1289                loaded_features = data.pop("loaded")
1290            data, min_frames, max_frames = self._filter(data)
1291            sample_df = list(data.values())[0]
1292            self.bodyparts = sorted(list(sample_df.index.unique(1)))
1293            data_dict = self._generate_features(data, seq_id)
1294            if loaded_features is not None:
1295                for key in data_dict.keys():
1296                    data_dict[key]["loaded"] = loaded_features
1297            bp_dict = self._count_bodyparts(
1298                data=data, stripped_name=seq_id, max_frames=max_frames
1299            )
1300            min_frames = {seq_id: min_frames}  # name is e.g. 20190707T1126-1226
1301            max_frames = {seq_id: max_frames}
1302            names, lengths, coords = self._make_trimmed_data(data_dict)
1303            return names, lengths, coords, bp_dict, min_frames, max_frames, tag
1304
1305        seq_tuples = []
1306        for file in files:
1307            opened = self._open_file(file)
1308            seq_tuples += opened
1309        if os.name != "nt":
1310            dict_list = p_map(
1311                make_data_dictionary, sorted(seq_tuples), num_cpus=self.num_cpus
1312            )
1313        else:
1314            print(
1315                "Multiprocessing is not supported on Windows, loading files sequentially."
1316            )
1317            dict_list = tqdm([make_data_dictionary(s) for s in sorted(seq_tuples)])
1318
1319        self.visibility = {}
1320        self.min_frames = {}
1321        self.max_frames = {}
1322        self.original_coordinates = []
1323        self.metadata = []
1324        X = []
1325        for (
1326            names,
1327            lengths,
1328            coords,
1329            bp_dictionary,
1330            min_frames,
1331            max_frames,
1332            tag,
1333        ) in dict_list:
1334            X += names
1335            self.original_coordinates += coords
1336            self.visibility.update(bp_dictionary)
1337            self.min_frames.update(min_frames)
1338            self.max_frames.update(max_frames)
1339            if tag is not None:
1340                self.metadata += [tag for _ in names]
1341        del dict_list
1342
1343        if len(self.metadata) != len(self.original_coordinates):
1344            self.metadata = None
1345        else:
1346            self.metadata = np.array(self.metadata)
1347        self.min_frames = dict(self.min_frames)
1348        self.max_frames = dict(self.max_frames)
1349        self.original_coordinates = np.array(self.original_coordinates)
1350        return np.array(X)
1351
1352    @classmethod
1353    def get_file_ids(
1354        cls,
1355        filenames: Set = None,
1356        data_path: Union[str, Set] = None,
1357        file_paths: Set = None,
1358        *args,
1359        **kwargs,
1360    ) -> List:
1361        """Get file ids.
1362
1363        Process data parameters and return a list of ids of the videos that should
1364        be processed by the `__init__` function.
1365
1366        Parameters
1367        ----------
1368        filenames : set, optional
1369            a set of string filenames to search for (only basenames, not the whole paths)
1370        data_path : str | set, optional
1371            the path to the folder where the pose and feature files are stored or a set of such paths
1372            (not passed if creating from key objects or from `file_paths`)
1373        file_paths : set, optional
1374            a set of string paths to the pose and feature files
1375            (not passed if creating from key objects or from `data_path`)
1376
1377        Returns
1378        -------
1379        video_ids : list
1380            a list of video file ids
1381
1382        """
1383        if file_paths is None:
1384            file_paths = []
1385        if data_path is not None:
1386            if isinstance(data_path, str):
1387                data_path = [data_path]
1388            file_paths = []
1389            for folder in data_path:
1390                file_paths += [os.path.join(folder, x) for x in os.listdir(folder)]
1391        ids = set()
1392        for f in file_paths:
1393            if os.path.basename(f) in filenames:
1394                ids.add(os.path.basename(f))
1395        ids = sorted(ids)
1396        return ids
1397
1398    @abstractmethod
1399    def _open_file(self, filename: str) -> List:
1400        """Open a file and make a list of sequences.
1401
1402        The sequence objects should contain information about all clips in one video. The sequences and
1403        video ids will be processed in the `_get_data` function.
1404
1405        Parameters
1406        ----------
1407        filename : str
1408            the name of the file
1409
1410        Returns
1411        -------
1412        video_tuples : list
1413            a list of video tuples: `(video_id, sequence)`
1414
1415        """
1416
1417    @abstractmethod
1418    def _get_data(
1419        self, video_id: str, sequence, default_agent_name: str
1420    ) -> Tuple[Dict, Optional[Dict]]:
1421        """Get the keypoint dataframes from a sequence.
1422
1423        The sequences and video ids are generated in the `_open_file` function.
1424        In `data_dictionary`, the keys are clip ids and the values are `pandas` dataframes with two-level indices.
1425        The first level is the frame numbers and the second is the body part names. The dataframes should have from
1426        two to four columns labeled `"x"`, `"y"` and (optionally) `"z"` and `"likelihood"`. Each frame should have
1427        information on all the body parts. You don't have to filter the data in any way or fill the nans, it will
1428        be done automatically.
1429
1430        Parameters
1431        ----------
1432        video_id : str
1433            the video id
1434        sequence
1435            an object containing information about all clips in one video
1436        default_agent_name : str
1437            the default agent name
1438
1439        Returns
1440        -------
1441        data dictionary : dict
1442            a dictionary where the keys are clip ids and the values are keypoint dataframes (see above for details)
1443        metadata_dictionary : dict
1444            a dictionary where the keys are clip ids and the values are metadata objects (can be any additional information,
1445            like the annotator tag; for no metadata pass `None`)
1446
1447        """

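To subclass `SequenceInputStore`, the `_open_file` and `_get_data` methods need to be implemented. A hedged sketch of the `_open_file` contract (the raw dictionary stands in for whatever the real file contains):

```
def _open_file_sketch(filename):
    # a real subclass would parse `filename` here
    raw = {"videoA": {"keypoints": None}, "videoB": {"keypoints": None}}
    return [(video_id, sequence) for video_id, sequence in raw.items()]
```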

class DLCTrackStore(FileInputStore):
1450class DLCTrackStore(FileInputStore):
1451    """DLC track data.
1452
1453    Assumes the following file structure:
1454    ```
1455    data_path
1456    ├── video1DLC1000.pickle
1457    ├── video2DLC400.pickle
1458    ├── video1_features.pt
1459    └── video2_features.pt
1460    ```
1461    Here `data_suffix` is `{'DLC1000.pickle', 'DLC400.pickle'}` and `feature_suffix` (optional) is `'_features.pt'`.
1462
1463    The feature files should be dictionaries where keys are clip IDs (e.g. animal names) and values are
1464    feature values (arrays of shape `(#frames, #features)`). If the arrays are shaped as `(#features, #frames)`,
1465    set `transpose_features` to `True`.
1466
1467    The files can be saved with `numpy.save()` (with `.npy` extension), `torch.save()` (with `.pt` extension) or
1468    with `pickle.dump()` (with `.pickle` or `.pkl` extension).
1469    """
1470
1471    def _open_data(
1472        self, filename: str, default_agent_name: str
1473    ) -> Tuple[Dict, Optional[Dict]]:
1474        """Load the keypoints from filename and organize them in a dictionary.
1475
1476        In `data_dictionary`, the keys are clip ids and the values are `pandas` dataframes with two-level indices.
1477        The first level is the frame numbers and the second is the body part names. The dataframes should have from
1478        two to four columns labeled `"x"`, `"y"` and (optionally) `"z"` and `"likelihood"`. Each frame should have
1479        information on all the body parts. You don't have to filter the data in any way or fill the nans, it will
1480        be done automatically.
1481
1482        Parameters
1483        ----------
1484        filename : str
1485            path to the pose file
1486        default_agent_name : str
1487            the default agent name
1488
1489        Returns
1490        -------
1491        data dictionary : dict
1492            a dictionary where the keys are clip ids and the values are keypoint dataframes (see above for details)
1493        metadata_dictionary : dict
1494            a dictionary where the keys are clip ids and the values are metadata objects (can be any additional information,
1495            like the annotator tag; for no metadata pass `None`)
1496
1497        """
1498        if filename.endswith("h5"):
1499            temp = pd.read_hdf(filename)
1500            temp = temp.droplevel("scorer", axis=1)
1501        else:
1502            temp = pd.read_csv(filename, header=[1, 2])
1503            temp.columns.names = ["bodyparts", "coords"]
1504        if "individuals" not in temp.columns.names:
1505            old_idx = temp.columns.to_frame()
1506            old_idx.insert(0, "individuals", self.default_agent_name)
1507            temp.columns = pd.MultiIndex.from_frame(old_idx)
1508        df = temp.stack(["individuals", "bodyparts"], future_stack=True)
1509        idx = pd.MultiIndex.from_product(
1510            [df.index.levels[0], df.index.levels[1], df.index.levels[2]],
1511            names=df.index.names,
1512        )
1513        df = df.reindex(idx).fillna(value=0)
1514        animals = sorted(list(df.index.levels[1]))
1515        dic = {}
1516        for ind in animals:
1517            coord = df.iloc[df.index.get_level_values(1) == ind].droplevel(1)
1518            coord = coord[["x", "y", "likelihood"]]
1519            dic[ind] = coord
1520
1521        return dic, None

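A sketch of writing a feature file in the format described above, assuming a dictionary of `(#frames, #features)` arrays keyed by clip id:

```
import numpy as np
import torch

features = {"ind0": np.random.rand(300, 16).astype(np.float32)}
torch.save(features, "video1_features.pt")  # matches feature_suffix='_features.pt'
```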

class DLCTrackletStore(FileInputStore):
1524class DLCTrackletStore(FileInputStore):
1525    """DLC tracklet data.
1526
1527    Assumes the following file structure:
1528    ```
1529    data_path
1530    ├── video1DLC1000.pickle
1531    ├── video2DLC400.pickle
1532    ├── video1_features.pt
1533    └── video2_features.pt
1534    ```
1535    Here `data_suffix` is `{'DLC1000.pickle', 'DLC400.pickle'}` and `feature_suffix` (optional) is `'_features.pt'`.
1536
1537    The feature files should be dictionaries where keys are clip IDs (e.g. animal names) and values are
1538    feature values (arrays of shape `(#frames, #features)`). If the arrays are shaped as `(#features, #frames)`,
1539    set `transpose_features` to `True`.
1540
1541    The files can be saved with `numpy.save()` (with `.npy` extension), `torch.save()` (with `.pt` extension) or
1542    with `pickle.dump()` (with `.pickle` or `.pkl` extension).
1543    """
1544
1545    def _open_data(
1546        self, filename: str, default_agent_name: str
1547    ) -> Tuple[Dict, Optional[Dict]]:
1548        """Load the keypoints from filename and organize them in a dictionary.
1549
1550        In `data_dictionary`, the keys are clip ids and the values are `pandas` dataframes with two-level indices.
1551        The first level is the frame numbers and the second is the body part names. The dataframes should have from
1552        two to four columns labeled `"x"`, `"y"` and (optionally) `"z"` and `"likelihood"`. Each frame should have
1553        information on all the body parts. You don't have to filter the data in any way or fill the nans, it will
1554        be done automatically.
1555
1556        Parameters
1557        ----------
1558        filename : str
1559            path to the pose file
1560        default_agent_name : str
1561            the default agent name
1562
1563        Returns
1564        -------
1565        data dictionary : dict
1566            a dictionary where the keys are clip ids and the values are keypoint dataframes (see above for details)
1567        metadata_dictionary : dict
1568            a dictionary where the keys are clip ids and the values are metadata objects (can be any additional information,
1569            like the annotator tag; for no metadata pass `None`)
1570
1571        """
1572        output = {}
1573        with open(filename, "rb") as f:
1574            data_p = pickle.load(f)
1575        header = data_p["header"]
1576        bodyparts = header.unique("bodyparts")
1577
1578        keys = sorted([key for key in data_p.keys() if key != "header"])
1579        min_frames = defaultdict(lambda: 10**5)
1580        max_frames = defaultdict(lambda: 0)
1581        for tr_id in keys:
1582            coords = {}
1583            fr_i = int(list(data_p[tr_id].keys())[0][5:]) - 1
1584            for frame in data_p[tr_id]:
1585                count = 0
1586                while int(frame[5:]) > fr_i + 1:
1587                    count += 1
1588                    fr_i = fr_i + 1
1589                    if count <= 3:
1590                        for bp, name in enumerate(bodyparts):
1591                            coords[(fr_i, name)] = coords[(fr_i - 1, name)]
1592                    else:
1593                        for bp, name in enumerate(bodyparts):
1594                            coords[(fr_i, name)] = np.zeros(
1595                                coords[(fr_i - 1, name)].shape
1596                            )
1597                fr_i = int(frame[5:])
1598                if fr_i > max_frames[f"ind{tr_id}"]:
1599                    max_frames[f"ind{tr_id}"] = fr_i
1600                if fr_i < min_frames[f"ind{tr_id}"]:
1601                    min_frames[f"ind{tr_id}"] = fr_i
1602                for bp, name in enumerate(bodyparts):
1603                    coords[(fr_i, name)] = data_p[tr_id][frame][bp][:3]
1604
1605            output[f"ind{tr_id}"] = pd.DataFrame(
1606                data=coords, index=["x", "y", "likelihood"]
1607            ).T
1608        return output, None

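The tracklet loader above fills gaps inside a tracklet: up to three consecutive missing frames repeat the last seen coordinates, and longer gaps are zero-filled. A sketch of the rule with illustrative values:

```
import numpy as np

last_seen = np.array([10.0, 20.0, 0.9])  # x, y, likelihood
gap_length = 5
filled = [last_seen if i < 3 else np.zeros(3) for i in range(gap_length)]
# the first three missing frames copy last_seen, the remaining two are zeros
```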

class CalMS21InputStore(SequenceInputStore):
1611class CalMS21InputStore(SequenceInputStore):
1612    """CalMS21 data.
1613
1614    Use the `'random:test_from_name:{name}'` and `'val-from-name:{val_name}:test-from-name:{test_name}'`
1615    partitioning methods with `'train'`, `'test'` and `'unlabeled'` names to separate into train, test and validation
1616    subsets according to the original files. For example, with `'val-from-name:test:test-from-name:unlabeled'`
1617    the data from the test file will go into validation and the unlabeled files will be the test.
1618
1619    Assumes the following file structure:
1620    ```
1621    data_path
1622    ├── calms21_task1_train.npy
1623    ├── calms21_task1_test.npy
1624    ├── calms21_task1_train_features.npy
1625    ├── calms21_task1_test_features.npy
1626    ├── calms21_unlabeled_videos_part1.npy
1627    ├── calms21_unlabeled_videos_part2.npy
1628    ├── calms21_unlabeled_videos_part3.npy
1629    └── calms21_unlabeled_videos_part4.npy
1630    ```
1631    """
1632
1633    def __init__(
1634        self,
1635        video_order: List = None,
1636        data_path: Union[Set, str] = None,
1637        file_paths: Set = None,
1638        task_n: int = 1,
1639        include_task1: bool = True,
1640        feature_save_path: str = None,
1641        len_segment: int = 128,
1642        overlap: int = 0,
1643        feature_extraction: str = "kinematic",
1644        key_objects: Dict = None,
1645        treba_files: bool = False,
1646        num_cpus: int = None,
1647        feature_extraction_pars: Dict = None,
1648        *args,
1649        **kwargs,
1650    ) -> None:
1651        """Initialize a store.
1652
1653        Parameters
1654        ----------
1655        video_order : list, optional
1656            a list of video ids that should be processed in the same order (not passed if creating from key objects)
1657        data_path : str | set, optional
1658            the path to the folder where the pose and feature files are stored or a set of such paths
1659            (not passed if creating from key objects or from `file_paths`)
1660        file_paths : set, optional
1661            a set of string paths to the pose and feature files
1662            (not passed if creating from key objects or from `data_path`)
1663        task_n : [1, 2]
1664            the number of the task
1665        include_task1 : bool, default True
1666            include task 1 data in the training set
1667        feature_save_path : str, optional
1668            the path to the folder where pre-processed files are stored (not passed if creating from key objects)
1669        len_segment : int, default 128
1670            the length of the segments in which the data should be cut (in frames)
1671        overlap : int, default 0
1672            the length of the overlap between neighboring segments (in frames)
1673        feature_extraction : str, default 'kinematic'
1674            the feature extraction method (see options.feature_extractors for available options)
1677        key_objects : tuple, optional
1678            a tuple of key objects
1679        treba_files : bool, default False
1680            if `True`, TREBA feature files will be loaded
1681        num_cpus : int, optional
1682            the number of cpus to use in data processing
1683        feature_extraction_pars : dict, optional
1684            parameters of the feature extractor
1685
1686        """
1687        self.task_n = int(task_n)
1688        self.include_task1 = include_task1
1689        self.treba_files = treba_files
1690        if feature_extraction_pars is not None:
1691            feature_extraction_pars["interactive"] = True
1692
1693        super().__init__(
1694            video_order,
1695            data_path,
1696            file_paths,
1697            data_prefix=None,
1698            feature_suffix=None,
1699            convert_int_indices=False,
1700            feature_save_path=feature_save_path,
1701            canvas_shape=[1024, 570],
1702            len_segment=len_segment,
1703            overlap=overlap,
1704            feature_extraction=feature_extraction,
1705            ignored_clips=None,
1706            ignored_bodyparts=None,
1707            default_agent_name="ind0",
1708            key_objects=key_objects,
1709            likelihood_threshold=0,
1710            num_cpus=num_cpus,
1711            frame_limit=1,
1712            feature_extraction_pars=feature_extraction_pars,
1713        )
1714
1715    @classmethod
1716    def get_file_ids(
1717        cls,
1718        task_n: int = 1,
1719        include_task1: bool = False,
1720        treba_files: bool = False,
1721        data_path: Union[str, Set] = None,
1722        file_paths=None,
1723        *args,
1724        **kwargs,
1725    ) -> Iterable:
1726        """Get file ids.
1727
1728        Process data parameters and return a list of ids of the videos that should
1729        be processed by the `__init__` function.
1730
1731        Parameters
1732        ----------
1733        task_n : {1, 2, 3}
1734            the index of the CalMS21 challenge task
1735        include_task1 : bool, default False
1736            if `True`, the training file of task 1 will be loaded
1737        treba_files : bool, default False
1738            if `True`, the TREBA feature files will be loaded
1741        data_path : str | set, optional
1742            the path to the folder where the pose and feature files are stored or a set of such paths
1743            (not passed if creating from key objects or from `file_paths`)
1744        file_paths : set, optional
1745            a set of string paths to the pose and feature files
1746            (not passed if creating from key objects or from `data_path`)
1747
1748        Returns
1749        -------
1750        video_ids : list
1751            a list of video file ids
1752
1753        """
1754        task_n = int(task_n)
1755        if task_n == 1:
1756            include_task1 = False
1757        files = []
1758        if treba_files:
1759            postfix = "_features"
1760        else:
1761            postfix = ""
1762        files.append(f"calms21_task{task_n}_train{postfix}.npy")
1763        files.append(f"calms21_task{task_n}_test{postfix}.npy")
1764        if include_task1:
1765            files.append(f"calms21_task1_train{postfix}.npy")
1766        for i in range(1, 5):
1767            files.append(f"calms21_unlabeled_videos_part{i}{postfix}.npy")
1768        filenames = set(files)
1769        return SequenceInputStore.get_file_ids(filenames, data_path, file_paths)
1770
1771    def _open_file(self, filename: str) -> List:
1772        """Open a file and make a list of sequences.
1773
1774        The sequence objects should contain information about all clips in one video. The sequences and
1775        video ids will be processed in the `_get_data` function.
1776
1777        Parameters
1778        ----------
1779        filename : str
1780            the name of the file
1781
1782        Returns
1783        -------
1784        video_tuples : list
1785            a list of video tuples: `(video_id, sequence)`
1786
1787        """
1788        if os.path.basename(filename).startswith("calms21_unlabeled_videos"):
1789            mode = "unlabeled"
1790        elif os.path.basename(filename).startswith(f"calms21_task{self.task_n}_test"):
1791            mode = "test"
1792        else:
1793            mode = "train"
1794        data_dict = np.load(filename, allow_pickle=True).item()
1795        data = {}
1796        keys = list(data_dict.keys())
1797        for key in keys:
1798            data.update(data_dict[key])
1799            data_dict.pop(key)
1800        dict_list = [(f'{mode}--{k.split("/")[-1]}', v) for k, v in data.items()]
1801        return dict_list
1802
1803    def _get_data(
1804        self, video_id: str, sequence, default_agent_name: str
1805    ) -> Tuple[Dict, Optional[Dict]]:
1806        """Get the keypoint dataframes from a sequence.
1807
1808        The sequences and video ids are generated in the `_open_file` function.
1809        In `data_dictionary`, the keys are clip ids and the values are `pandas` dataframes with two-level indices.
1810        The first level is the frame numbers and the second is the body part names. The dataframes should have from
1811        two to four columns labeled `"x"`, `"y"` and (optionally) `"z"` and `"likelihood"`. Each frame should have
1812        information on all the body parts. You don't have to filter the data in any way or fill the nans, it will
1813        be done automatically.
1814
1815        Parameters
1816        ----------
1817        video_id : str
1818            the video id
1819        sequence
1820            an object containing information about all clips in one video
1821        default_agent_name
1822            the name of the default agent
1823
1824        Returns
1825        -------
1826        data dictionary : dict
1827            a dictionary where the keys are clip ids and the values are keypoint dataframes (see above for details)
1828        metadata_dictionary : dict
1829            a dictionary where the keys are clip ids and the values are metadata objects (can be any additional information,
1830            like the annotator tag; for no metadata pass `None`)
1831
1832        """
1833        if "metadata" in sequence:
1834            annotator = sequence["metadata"]["annotator-id"]
1835        else:
1836            annotator = 0
1837        bodyparts = [
1838            "nose",
1839            "left ear",
1840            "right ear",
1841            "neck",
1842            "left hip",
1843            "right hip",
1844            "tail",
1845        ]
1846        columns = ["x", "y"]
1847        if "keypoints" in sequence:
1848            sequence = sequence["keypoints"]
1849            index = pd.MultiIndex.from_product([range(sequence.shape[0]), bodyparts])
1850            data = {
1851                "mouse1": pd.DataFrame(
1852                    data=(sequence[:, 0, :, :]).transpose((0, 2, 1)).reshape(-1, 2),
1853                    columns=columns,
1854                    index=index,
1855                ),
1856                "mouse2": pd.DataFrame(
1857                    data=(sequence[:, 1, :, :]).transpose((0, 2, 1)).reshape(-1, 2),
1858                    columns=columns,
1859                    index=index,
1860                ),
1861            }
1862        else:
1863            sequence = sequence["features"]
1864            mice = sequence[:, :-32].reshape((-1, 2, 2, 7))
1865            index = pd.MultiIndex.from_product([range(mice.shape[0]), bodyparts])
1866            data = {
1867                "mouse1": pd.DataFrame(
1868                    data=(mice[:, 0, :, :]).transpose((0, 2, 1)).reshape(-1, 2),
1869                    columns=columns,
1870                    index=index,
1871                ),
1872                "mouse2": pd.DataFrame(
1873                    data=(mice[:, 1, :, :]).transpose((0, 2, 1)).reshape(-1, 2),
1874                    columns=columns,
1875                    index=index,
1876                ),
1877                "loaded": sequence[:, -32:],
1878            }
1879        # metadata = {k: annotator for k in data.keys()}
1880        metadata = annotator
1881        return data, metadata
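
The keypoint dataframe format that `_get_data` (and `_open_data` in the file-based stores) is expected to return can be sketched as follows; the frame count, body part names and values here are illustrative:

```
import numpy as np
import pandas as pd

# Two-level index: frame numbers x body part names.
n_frames, bodyparts = 100, ["nose", "tail"]
index = pd.MultiIndex.from_product([range(n_frames), bodyparts])
# One row per (frame, body part); columns "x" and "y", optionally "z" and "likelihood".
df = pd.DataFrame(
    np.random.rand(n_frames * len(bodyparts), 2), columns=["x", "y"], index=index
)
```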

CalMS21 data.

Use the 'random:test_from_name:{name}' and 'val-from-name:{val_name}:test-from-name:{test_name}' partitioning methods with the 'train', 'test' and 'unlabeled' names to split the data into train, test and validation subsets according to the original files. For example, with 'val-from-name:test:test-from-name:unlabeled' the data from the test file goes into the validation subset and the unlabeled files form the test subset.

Assumes the following file structure:

data_path
├── calms21_task1_train.npy
├── calms21_task1_test.npy
├── calms21_task1_train_features.npy
├── calms21_task1_test_features.npy
├── calms21_unlabeled_videos_part1.npy
├── calms21_unlabeled_videos_part2.npy
├── calms21_unlabeled_videos_part3.npy
└── calms21_unlabeled_videos_part4.npy
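
As a quick orientation, the raw files can be inspected the way `_open_file` reads them; this is a minimal sketch with an illustrative path, and the exact key layout may vary between files:

```
import numpy as np

# Load a CalMS21 file the way _open_file does (path is illustrative).
data_dict = np.load("calms21_task1_train.npy", allow_pickle=True).item()
for group in data_dict.values():  # top-level groups are merged by _open_file
    for seq_id, seq in group.items():
        # each sequence holds 'keypoints' (or 'features' in TREBA files)
        print(seq_id, seq["keypoints"].shape)  # (#frames, 2 mice, 2 coordinates, 7 body parts)
```
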
CalMS21InputStore( video_order: List = None, data_path: Union[Set, str] = None, file_paths: Set = None, task_n: int = 1, include_task1: bool = True, feature_save_path: str = None, len_segment: int = 128, overlap: int = 0, feature_extraction: str = 'kinematic', key_objects: Dict = None, treba_files: bool = False, num_cpus: int = None, feature_extraction_pars: Dict = None, *args, **kwargs)
1633    def __init__(
1634        self,
1635        video_order: List = None,
1636        data_path: Union[Set, str] = None,
1637        file_paths: Set = None,
1638        task_n: int = 1,
1639        include_task1: bool = True,
1640        feature_save_path: str = None,
1641        len_segment: int = 128,
1642        overlap: int = 0,
1643        feature_extraction: str = "kinematic",
1644        key_objects: Dict = None,
1645        treba_files: bool = False,
1646        num_cpus: int = None,
1647        feature_extraction_pars: Dict = None,
1648        *args,
1649        **kwargs,
1650    ) -> None:
1651        """Initialize a store.
1652
1653        Parameters
1654        ----------
1655        video_order : list, optional
1656            a list of video ids that should be processed in the same order (not passed if creating from key objects)
1657        data_path : str | set, optional
1658            the path to the folder where the pose and feature files are stored or a set of such paths
1659            (not passed if creating from key objects or from `file_paths`)
1660        file_paths : set, optional
1661            a set of string paths to the pose and feature files
1662            (not passed if creating from key objects or from `data_path`)
1663        task_n : {1, 2, 3}
1664            the index of the CalMS21 challenge task
1665        include_task1 : bool, default True
1666            if `True`, include the task 1 data in the training set
1667        feature_save_path : str, optional
1668            the path to the folder where pre-processed files are stored (not passed if creating from key objects)
1669        len_segment : int, default 128
1670            the length of the segments in which the data should be cut (in frames)
1671        overlap : int, default 0
1672            the length of the overlap between neighboring segments (in frames)
1673        feature_extraction : str, default 'kinematic'
1674            the feature extraction method (see options.feature_extractors for available options)
1677        key_objects : tuple, optional
1678            a tuple of key objects
1679        treba_files : bool, default False
1680            if `True`, TREBA feature files will be loaded
1681        num_cpus : int, optional
1682            the number of cpus to use in data processing
1683        feature_extraction_pars : dict, optional
1684            parameters of the feature extractor
1685
1686        """
1687        self.task_n = int(task_n)
1688        self.include_task1 = include_task1
1689        self.treba_files = treba_files
1690        if feature_extraction_pars is not None:
1691            feature_extraction_pars["interactive"] = True
1692
1693        super().__init__(
1694            video_order,
1695            data_path,
1696            file_paths,
1697            data_prefix=None,
1698            feature_suffix=None,
1699            convert_int_indices=False,
1700            feature_save_path=feature_save_path,
1701            canvas_shape=[1024, 570],
1702            len_segment=len_segment,
1703            overlap=overlap,
1704            feature_extraction=feature_extraction,
1705            ignored_clips=None,
1706            ignored_bodyparts=None,
1707            default_agent_name="ind0",
1708            key_objects=key_objects,
1709            likelihood_threshold=0,
1710            num_cpus=num_cpus,
1711            frame_limit=1,
1712            feature_extraction_pars=feature_extraction_pars,
1713        )

Initialize a store.

Parameters

video_order : list, optional
    a list of video ids that should be processed in the same order (not passed if creating from key objects)
data_path : str | set, optional
    the path to the folder where the pose and feature files are stored or a set of such paths (not passed if creating from key objects or from file_paths)
file_paths : set, optional
    a set of string paths to the pose and feature files (not passed if creating from key objects or from data_path)
task_n : {1, 2, 3}
    the index of the CalMS21 challenge task
include_task1 : bool, default True
    if True, include the task 1 data in the training set
feature_save_path : str, optional
    the path to the folder where pre-processed files are stored (not passed if creating from key objects)
len_segment : int, default 128
    the length of the segments in which the data should be cut (in frames)
overlap : int, default 0
    the length of the overlap between neighboring segments (in frames)
feature_extraction : str, default 'kinematic'
    the feature extraction method (see options.feature_extractors for available options)
key_objects : tuple, optional
    a tuple of key objects
treba_files : bool, default False
    if True, TREBA feature files will be loaded
num_cpus : int, optional
    the number of cpus to use in data processing
feature_extraction_pars : dict, optional
    parameters of the feature extractor
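
A minimal construction sketch, assuming the data lives under an illustrative /data/calms21 folder:

```
from dlc2action.data.input_store import CalMS21InputStore

video_ids = CalMS21InputStore.get_file_ids(task_n=1, data_path="/data/calms21")
store = CalMS21InputStore(
    video_order=video_ids,
    data_path="/data/calms21",
    task_n=1,
    feature_save_path="/data/calms21_processed",  # illustrative output folder
)
```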

task_n
include_task1
treba_files
@classmethod
def get_file_ids( cls, task_n: int = 1, include_task1: bool = False, treba_files: bool = False, data_path: Union[str, Set] = None, file_paths=None, *args, **kwargs) -> Iterable:
1715    @classmethod
1716    def get_file_ids(
1717        cls,
1718        task_n: int = 1,
1719        include_task1: bool = False,
1720        treba_files: bool = False,
1721        data_path: Union[str, Set] = None,
1722        file_paths=None,
1723        *args,
1724        **kwargs,
1725    ) -> Iterable:
1726        """Get file ids.
1727
1728        Process data parameters and return a list of ids of the videos that should
1729        be processed by the `__init__` function.
1730
1731        Parameters
1732        ----------
1733        task_n : {1, 2, 3}
1734            the index of the CalMS21 challenge task
1735        include_task1 : bool, default False
1736            if `True`, the task 1 training file will also be loaded
1737        treba_files : bool, default False
1738            if `True`, the TREBA feature files will be loaded
1741        data_path : str | set, optional
1742            the path to the folder where the pose and feature files are stored or a set of such paths
1743            (not passed if creating from key objects or from `file_paths`)
1744        file_paths : set, optional
1745            a set of string paths to the pose and feature files
1746            (not passed if creating from key objects or from `data_path`)
1747
1748        Returns
1749        -------
1750        video_ids : list
1751            a list of video file ids
1752
1753        """
1754        task_n = int(task_n)
1755        if task_n == 1:
1756            include_task1 = False
1757        files = []
1758        if treba_files:
1759            postfix = "_features"
1760        else:
1761            postfix = ""
1762        files.append(f"calms21_task{task_n}_train{postfix}.npy")
1763        files.append(f"calms21_task{task_n}_test{postfix}.npy")
1764        if include_task1:
1765            files.append(f"calms21_task1_train{postfix}.npy")
1766        for i in range(1, 5):
1767            files.append(f"calms21_unlabeled_videos_part{i}{postfix}.npy")
1768        filenames = set(files)
1769        return SequenceInputStore.get_file_ids(filenames, data_path, file_paths)

Get file ids.

Process data parameters and return a list of ids of the videos that should be processed by the __init__ function.

Parameters

task_n : {1, 2, 3}
    the index of the CalMS21 challenge task
include_task1 : bool, default False
    if True, the task 1 training file will also be loaded
treba_files : bool, default False
    if True, the TREBA feature files will be loaded
data_path : str | set, optional
    the path to the folder where the pose and feature files are stored or a set of such paths (not passed if creating from key objects or from file_paths)
file_paths : set, optional
    a set of string paths to the pose and feature files (not passed if creating from key objects or from data_path)

Returns

video_ids : list
    a list of video file ids
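
For example (illustrative path), listing the task 2 files together with the task 1 training data:

```
ids = CalMS21InputStore.get_file_ids(
    task_n=2,
    include_task1=True,
    data_path="/data/calms21",
)
```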

class Numpy3DInputStore(FileInputStore):
1884class Numpy3DInputStore(FileInputStore):
1885    """3D data.
1886
1887    Assumes the data files to be `numpy` arrays saved in `.npy` format with shape `(#frames, #keypoints, 3)`.
1888
1889    Assumes the following file structure:
1890    ```
1891    data_path
1892    ├── video1_suffix1.npy
1893    ├── video2_suffix2.npy
1894    ├── video1_features.pt
1895    └── video2_features.pt
1896    ```
1897    Here `data_suffix` is `{'_suffix1.npy', '_suffix2.npy'}` and `feature_suffix` (optional) is `'_features.pt'`.
1898
1899    The feature files should be dictionaries where keys are clip IDs (e.g. animal names) and values are
1900    feature values (arrays of shape `(#frames, #features)`). If the arrays are shaped as `(#features, #frames)`,
1901    set `transpose_features` to `True`.
1902
1903    The files can be saved with `numpy.save()` (with `.npy` extension), `torch.save()` (with `.pt` extension) or
1904    with `pickle.dump()` (with `.pickle` or `.pkl` extension).
1905    """
1906
1907    def __init__(
1908        self,
1909        video_order: List = None,
1910        data_path: Union[Set, str] = None,
1911        file_paths: Set = None,
1912        data_suffix: Union[Set, str] = None,
1913        data_prefix: Union[Set, str] = None,
1914        feature_suffix: Union[Set, str] = None,
1915        convert_int_indices: bool = True,
1916        feature_save_path: str = None,
1917        canvas_shape: List = None,
1918        len_segment: int = 128,
1919        overlap: int = 0,
1920        feature_extraction: str = "kinematic",
1921        ignored_clips: List = None,
1922        ignored_bodyparts: List = None,
1923        default_agent_name: str = "ind0",
1924        key_objects: Dict = None,
1925        likelihood_threshold: float = 0,
1926        num_cpus: int = None,
1927        frame_limit: int = 1,
1928        feature_extraction_pars: Dict = None,
1929        centered: bool = False,
1930        **kwargs,
1931    ) -> None:
1932        """Initialize a store.
1933
1934        Parameters
1935        ----------
1936        video_order : list, optional
1937            a list of video ids that should be processed in the same order (not passed if creating from key objects)
1938        data_path : str | set, optional
1939            the path to the folder where the pose and feature files are stored or a set of such paths
1940            (not passed if creating from key objects or from `file_paths`)
1941        file_paths : set, optional
1942            a set of string paths to the pose and feature files
1943            (not passed if creating from key objects or from `data_path`)
1944        data_suffix : str | set, optional
1945            the suffix or the set of suffices such that the pose files are named {video_id}{data_suffix}
1946            (not passed if creating from key objects or if irrelevant for the dataset)
1947        data_prefix : str | set, optional
1948            the prefix or the set of prefixes such that the pose files for different video views of the same
1949            clip are named {prefix}{sep}{video_id}{data_suffix} (not passed if creating from key objects
1950            or if irrelevant for the dataset)
1951        feature_suffix : str | set, optional
1952            the suffix or the set of suffices such that the additional feature files are named
1953            {video_id}{feature_suffix} (and placed at the data_path folder)
1954        convert_int_indices : bool, default True
1955            if `True`, convert any integer key `i` in feature files to `'ind{i}'`
1956        feature_save_path : str, optional
1957            the path to the folder where pre-processed files are stored (not passed if creating from key objects)
1958        canvas_shape : List, default [1, 1]
1959            the canvas size where the pose is defined
1960        len_segment : int, default 128
1961            the length of the segments in which the data should be cut (in frames)
1962        overlap : int, default 0
1963            the length of the overlap between neighboring segments (in frames)
1964        feature_extraction : str, default 'kinematic'
1965            the feature extraction method (see options.feature_extractors for available options)
1966        ignored_clips : list, optional
1967            list of strings of clip ids to ignore
1968        ignored_bodyparts : list, optional
1969            list of strings of bodypart names to ignore
1970        default_agent_name : str, default 'ind0'
1971            the agent name used as default in the pose files for a single agent
1972        key_objects : tuple, optional
1973            a tuple of key objects
1974        likelihood_threshold : float, default 0
1975            coordinate values with likelihoods less than this value will be set to 'unknown'
1976        num_cpus : int, optional
1977            the number of cpus to use in data processing
1978        frame_limit : int, default 1
1979            clips shorter than this number of frames will be ignored
1980        feature_extraction_pars : dict, optional
1981            parameters of the feature extractor
1982        centered : bool, default False
1983            if `True`, the pose is centered at the center of mass of the body
1984
1985        """
1986        super().__init__(
1987            video_order,
1988            data_path,
1989            file_paths,
1990            data_suffix=data_suffix,
1991            data_prefix=data_prefix,
1992            feature_suffix=feature_suffix,
1993            convert_int_indices=convert_int_indices,
1994            feature_save_path=feature_save_path,
1995            canvas_shape=canvas_shape,
1996            len_segment=len_segment,
1997            overlap=overlap,
1998            feature_extraction=feature_extraction,
1999            ignored_clips=ignored_clips,
2000            ignored_bodyparts=ignored_bodyparts,
2001            default_agent_name=default_agent_name,
2002            key_objects=key_objects,
2003            likelihood_threshold=likelihood_threshold,
2004            num_cpus=num_cpus,
2005            frame_limit=frame_limit,
2006            feature_extraction_pars=feature_extraction_pars,
2007            centered=centered,
2008        )
2009
2010    def _open_data(
2011        self, filename: str, default_clip_name: str
2012    ) -> Tuple[Dict, Optional[Dict]]:
2013        """Load the keypoints from filename and organize them in a dictionary.
2014
2015        In `data_dictionary`, the keys are clip ids and the values are `pandas` dataframes with two-level indices.
2016        The first level is the frame numbers and the second is the body part names. The dataframes should have from
2017        two to four columns labeled `"x"`, `"y"` and (optionally) `"z"` and `"likelihood"`. Each frame should have
2018        information on all the body parts. You don't have to filter the data in any way or fill the nans, it will
2019        be done automatically.
2020
2021        Parameters
2022        ----------
2023        filename : str
2024            path to the pose file
2025        default_clip_name : str
2026            the name to assign to a clip if it does not have a name in the raw data
2027
2028        Returns
2029        -------
2030        data dictionary : dict
2031            a dictionary where the keys are clip ids and the values are keypoint dataframes (see above for details)
2032        metadata_dictionary : dict
2033            a dictionary where the keys are clip ids and the values are metadata objects (can be any additional information,
2034            like the annotator tag; for no metadata pass `None`)
2035        """
2036        data = np.load(filename, allow_pickle=True)
2037        bodyparts = [str(i) for i in range(data.shape[1])]
2038        clip_id = self.default_agent_name
2039        columns = ["x", "y", "z"]
2040        index = pd.MultiIndex.from_product([range(data.shape[0]), bodyparts])
2041        data_dict = {
2042            clip_id: pd.DataFrame(
2043                data=data.reshape(-1, 3), columns=columns, index=index
2044            )
2045        }
2046        return data_dict, None

3D data.

Assumes the data files to be numpy arrays saved in .npy format with shape (#frames, #keypoints, 3).

Assumes the following file structure:

data_path
├── video1_suffix1.npy
├── video2_suffix2.npy
├── video1_features.pt
└── video2_features.pt

Here data_suffix is {'_suffix1.npy', '_suffix2.npy'} and feature_suffix (optional) is '_features.pt'.

The feature files should be dictionaries where keys are clip IDs (e.g. animal names) and values are feature values (arrays of shape (#frames, #features)). If the arrays are shaped as (#features, #frames), set transpose_features to True.

The files can be saved with numpy.save() (with .npy extension), torch.save() (with .pt extension) or with pickle.dump() (with .pickle or .pkl extension).
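
A minimal sketch of producing a compatible pose file; the file name, frame count and keypoint count are illustrative:

```
import numpy as np

# 300 frames, 12 keypoints, 3 coordinates per keypoint
pose = np.random.rand(300, 12, 3)
np.save("video1_suffix1.npy", pose)
```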

Numpy3DInputStore( video_order: List = None, data_path: Union[Set, str] = None, file_paths: Set = None, data_suffix: Union[Set, str] = None, data_prefix: Union[Set, str] = None, feature_suffix: Union[Set, str] = None, convert_int_indices: bool = True, feature_save_path: str = None, canvas_shape: List = None, len_segment: int = 128, overlap: int = 0, feature_extraction: str = 'kinematic', ignored_clips: List = None, ignored_bodyparts: List = None, default_agent_name: str = 'ind0', key_objects: Dict = None, likelihood_threshold: float = 0, num_cpus: int = None, frame_limit: int = 1, feature_extraction_pars: Dict = None, centered: bool = False, **kwargs)
1907    def __init__(
1908        self,
1909        video_order: List = None,
1910        data_path: Union[Set, str] = None,
1911        file_paths: Set = None,
1912        data_suffix: Union[Set, str] = None,
1913        data_prefix: Union[Set, str] = None,
1914        feature_suffix: Union[Set, str] = None,
1915        convert_int_indices: bool = True,
1916        feature_save_path: str = None,
1917        canvas_shape: List = None,
1918        len_segment: int = 128,
1919        overlap: int = 0,
1920        feature_extraction: str = "kinematic",
1921        ignored_clips: List = None,
1922        ignored_bodyparts: List = None,
1923        default_agent_name: str = "ind0",
1924        key_objects: Dict = None,
1925        likelihood_threshold: float = 0,
1926        num_cpus: int = None,
1927        frame_limit: int = 1,
1928        feature_extraction_pars: Dict = None,
1929        centered: bool = False,
1930        **kwargs,
1931    ) -> None:
1932        """Initialize a store.
1933
1934        Parameters
1935        ----------
1936        video_order : list, optional
1937            a list of video ids that should be processed in the same order (not passed if creating from key objects)
1938        data_path : str | set, optional
1939            the path to the folder where the pose and feature files are stored or a set of such paths
1940            (not passed if creating from key objects or from `file_paths`)
1941        file_paths : set, optional
1942            a set of string paths to the pose and feature files
1943            (not passed if creating from key objects or from `data_path`)
1944        data_suffix : str | set, optional
1945            the suffix or the set of suffices such that the pose files are named {video_id}{data_suffix}
1946            (not passed if creating from key objects or if irrelevant for the dataset)
1947        data_prefix : str | set, optional
1948            the prefix or the set of prefixes such that the pose files for different video views of the same
1949            clip are named {prefix}{sep}{video_id}{data_suffix} (not passed if creating from key objects
1950            or if irrelevant for the dataset)
1951        feature_suffix : str | set, optional
1952            the suffix or the set of suffices such that the additional feature files are named
1953            {video_id}{feature_suffix} (and placed at the data_path folder)
1954        convert_int_indices : bool, default True
1955            if `True`, convert any integer key `i` in feature files to `'ind{i}'`
1956        feature_save_path : str, optional
1957            the path to the folder where pre-processed files are stored (not passed if creating from key objects)
1958        canvas_shape : List, default [1, 1]
1959            the canvas size where the pose is defined
1960        len_segment : int, default 128
1961            the length of the segments in which the data should be cut (in frames)
1962        overlap : int, default 0
1963            the length of the overlap between neighboring segments (in frames)
1964        feature_extraction : str, default 'kinematic'
1965            the feature extraction method (see options.feature_extractors for available options)
1966        ignored_clips : list, optional
1967            list of strings of clip ids to ignore
1968        ignored_bodyparts : list, optional
1969            list of strings of bodypart names to ignore
1970        default_agent_name : str, default 'ind0'
1971            the agent name used as default in the pose files for a single agent
1972        key_objects : tuple, optional
1973            a tuple of key objects
1974        likelihood_threshold : float, default 0
1975            coordinate values with likelihoods less than this value will be set to 'unknown'
1976        num_cpus : int, optional
1977            the number of cpus to use in data processing
1978        frame_limit : int, default 1
1979            clips shorter than this number of frames will be ignored
1980        feature_extraction_pars : dict, optional
1981            parameters of the feature extractor
1982        centered : bool, default False
1983            if `True`, the pose is centered at the center of mass of the body
1984
1985        """
1986        super().__init__(
1987            video_order,
1988            data_path,
1989            file_paths,
1990            data_suffix=data_suffix,
1991            data_prefix=data_prefix,
1992            feature_suffix=feature_suffix,
1993            convert_int_indices=convert_int_indices,
1994            feature_save_path=feature_save_path,
1995            canvas_shape=canvas_shape,
1996            len_segment=len_segment,
1997            overlap=overlap,
1998            feature_extraction=feature_extraction,
1999            ignored_clips=ignored_clips,
2000            ignored_bodyparts=ignored_bodyparts,
2001            default_agent_name=default_agent_name,
2002            key_objects=key_objects,
2003            likelihood_threshold=likelihood_threshold,
2004            num_cpus=num_cpus,
2005            frame_limit=frame_limit,
2006            feature_extraction_pars=feature_extraction_pars,
2007            centered=centered,
2008        )

Initialize a store.

Parameters

video_order : list, optional
    a list of video ids that should be processed in the same order (not passed if creating from key objects)
data_path : str | set, optional
    the path to the folder where the pose and feature files are stored or a set of such paths (not passed if creating from key objects or from file_paths)
file_paths : set, optional
    a set of string paths to the pose and feature files (not passed if creating from key objects or from data_path)
data_suffix : str | set, optional
    the suffix or the set of suffices such that the pose files are named {video_id}{data_suffix} (not passed if creating from key objects or if irrelevant for the dataset)
data_prefix : str | set, optional
    the prefix or the set of prefixes such that the pose files for different video views of the same clip are named {prefix}{sep}{video_id}{data_suffix} (not passed if creating from key objects or if irrelevant for the dataset)
feature_suffix : str | set, optional
    the suffix or the set of suffices such that the additional feature files are named {video_id}{feature_suffix} (and placed at the data_path folder)
convert_int_indices : bool, default True
    if True, convert any integer key i in feature files to 'ind{i}'
feature_save_path : str, optional
    the path to the folder where pre-processed files are stored (not passed if creating from key objects)
canvas_shape : List, default [1, 1]
    the canvas size where the pose is defined
len_segment : int, default 128
    the length of the segments in which the data should be cut (in frames)
overlap : int, default 0
    the length of the overlap between neighboring segments (in frames)
feature_extraction : str, default 'kinematic'
    the feature extraction method (see options.feature_extractors for available options)
ignored_clips : list, optional
    list of strings of clip ids to ignore
ignored_bodyparts : list, optional
    list of strings of bodypart names to ignore
default_agent_name : str, default 'ind0'
    the agent name used as default in the pose files for a single agent
key_objects : tuple, optional
    a tuple of key objects
likelihood_threshold : float, default 0
    coordinate values with likelihoods less than this value will be set to 'unknown'
num_cpus : int, optional
    the number of cpus to use in data processing
frame_limit : int, default 1
    clips shorter than this number of frames will be ignored
feature_extraction_pars : dict, optional
    parameters of the feature extractor
centered : bool, default False
    if True, the pose is centered at the center of mass of the body
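
A minimal construction sketch with illustrative paths and suffixes:

```
store = Numpy3DInputStore(
    video_order=["video1", "video2"],
    data_path="/data/poses3d",
    data_suffix="_suffix1.npy",
    feature_suffix="_features.pt",
)
```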

class LoadedFeaturesInputStore(GeneralInputStore):
2049class LoadedFeaturesInputStore(GeneralInputStore):
2050    """Non-pose feature files.
2051
2052    The feature files should be dictionaries where keys are clip IDs (e.g. animal names) and values are
2053    feature values (arrays of shape `(#frames, #features)`). If the arrays are shaped as `(#features, #frames)`,
2054    set `transpose_features` to `True`.
2055
2056    The files can be saved with `numpy.save()` (with `.npy` extension), `torch.save()` (with `.pt` extension) or
2057    with `pickle.dump()` (with `.pickle` or `.pkl` extension).
2058
2059    Assumes the following file structure:
2060    ```
2061    data_path
2062    ├── video1_features.pt
2063    └── video2_features.pt
2064    ```
2065    Here `feature_suffix` (optional) is `'_features.pt'`.
2066    """
2067
2068    def __init__(
2069        self,
2070        video_order: List = None,
2071        data_path: Union[Set, str] = None,
2072        file_paths: Set = None,
2073        feature_suffix: Union[Set, str] = None,
2074        convert_int_indices: bool = True,
2075        feature_save_path: str = None,
2076        len_segment: int = 128,
2077        overlap: int = 0,
2078        ignored_clips: List = None,
2079        key_objects: Dict = None,
2080        num_cpus: int = None,
2081        frame_limit: int = 1,
2082        transpose_features: bool = False,
2083        **kwargs,
2084    ) -> None:
2085        """Initialize a store.
2086
2087        Parameters
2088        ----------
2089        video_order : list, optional
2090            a list of video ids that should be processed in the same order (not passed if creating from key objects)
2091        data_path : str | set, optional
2092            the path to the folder where the pose and feature files are stored or a set of such paths
2093            (not passed if creating from key objects or from `file_paths`)
2094        file_paths : set, optional
2095            a set of string paths to the pose and feature files
2096            (not passed if creating from key objects or from `data_path`)
2097        feature_suffix : str | set, optional
2098            the suffix or the set of suffices such that the additional feature files are named
2099            {video_id}{feature_suffix} (and placed at the data_path folder)
2100        convert_int_indices : bool, default True
2101            if `True`, convert any integer key `i` in feature files to `'ind{i}'`
2102        feature_save_path : str, optional
2103            the path to the folder where pre-processed files are stored (not passed if creating from key objects)
2104        len_segment : int, default 128
2105            the length of the segments in which the data should be cut (in frames)
2106        overlap : int, default 0
2107            the length of the overlap between neighboring segments (in frames)
2108        ignored_clips : list, optional
2109            list of strings of clip ids to ignore
2112        key_objects : tuple, optional
2113            a tuple of key objects
2114        num_cpus : int, optional
2115            the number of cpus to use in data processing
2116        frame_limit : int, default 1
2117            clips shorter than this number of frames will be ignored
2118        transpose_features : bool, default False
2119            if `True`, the feature arrays are assumed to be of shape `(#features, #frames)` and will be transposed
2120
2121        """
2122        super().__init__(
2123            video_order,
2124            data_path,
2125            file_paths,
2126            feature_suffix=feature_suffix,
2127            convert_int_indices=convert_int_indices,
2128            feature_save_path=feature_save_path,
2129            len_segment=len_segment,
2130            overlap=overlap,
2131            ignored_clips=ignored_clips,
2132            key_objects=key_objects,
2133            num_cpus=num_cpus,
2134            frame_limit=frame_limit,
2135            transpose_features=transpose_features,
2136        )
2137
2138    def get_visibility(
2139        self, video_id: str, clip_id: str, start: int, end: int, score: float
2140    ) -> float:
2141        """Get the fraction of the frames that have a visibility score better than a hard_threshold.
2142
2143        For example, in the case of keypoint data the visibility score can be the number of identified keypoints.
2144
2145        Parameters
2146        ----------
2147        video_id : str
2148            the video id of the frames
2149        clip_id : str
2150            the clip id of the frames
2151        start : int
2152            the start frame
2153        end : int
2154            the end frame
2155        score : float
2156            the visibility score hard_threshold
2157
2158        Returns
2159        -------
2160        frac_visible: float
2161            the fraction of frames with visibility above the hard_threshold
2162
2163        """
2164        return 1
2165
2166    def _generate_features(
2167        self, video_id: str
2168    ) -> Tuple[Dict, Dict, Dict, Union[str, int]]:
2169        """Generate features from the raw coordinates."""
2170        features = defaultdict(lambda: {})
2171        loaded_features = self._load_saved_features(video_id)
2172        min_frames = None
2173        max_frames = None
2174        video_tag = None
2175        for clip_id, feature_tensor in loaded_features.items():
2176            if clip_id == "max_frames":
2177                max_frames = feature_tensor
2178            elif clip_id == "min_frames":
2179                min_frames = feature_tensor
2180            elif clip_id == "video_tag":
2181                video_tag = feature_tensor
2182            else:
2183                if not isinstance(feature_tensor, torch.Tensor):
2184                    feature_tensor = torch.tensor(feature_tensor)
2185                if self.convert_int_indices and (
2186                    isinstance(clip_id, int) or isinstance(clip_id, np.integer)
2187                ):
2188                    clip_id = f"ind{clip_id}"
2189                key = f"{os.path.basename(video_id)}---{clip_id}"
2190                features[key]["loaded"] = feature_tensor
2191        if min_frames is None:
2192            min_frames = {}
2193            for key, value in features.items():
2194                video_id, clip_id = key.split("---")
2195                min_frames[clip_id] = 0
2196        if max_frames is None:
2197            max_frames = {}
2198            for key, value in features.items():
2199                video_id, clip_id = key.split("---")
2200                max_frames[clip_id] = value["loaded"].shape[0] - 1 + min_frames[clip_id]
2201        return features, min_frames, max_frames, video_tag
2202
2203    def _load_data(self) -> np.array:
2204        """Load input data and generate data prompts."""
2205        if self.video_order is None:
2206            return None
2207
2208        files = []
2209        for video_id in self.video_order:
2210            for f in self.file_paths:
2211                if f.endswith(tuple(self.feature_suffix)):
2212                    bn = os.path.basename(f)
2213                    if video_id == strip_suffix(bn, self.feature_suffix):
2214                        files.append(f)
2215
2216        def make_data_dictionary(filename):
2217            name = strip_suffix(filename, self.feature_suffix)
2218            name = os.path.basename(name)
2219            data_dict, min_frames, max_frames, video_tag = self._generate_features(name)
2220            bp_dict = defaultdict(lambda: {})
2221            for key, value in data_dict.items():
2222                video_id, clip_id = key.split("---")
2223                bp_dict[video_id][clip_id] = 1
2224            min_frames = {name: min_frames}  # name is e.g. 20190707T1126-1226
2225            max_frames = {name: max_frames}
2226            names, lengths, coords = self._make_trimmed_data(data_dict)
2227            return names, lengths, coords, bp_dict, min_frames, max_frames, video_tag
2228
2229        if os.name != "nt":
2230            dict_list = p_map(make_data_dictionary, files, num_cpus=self.num_cpus)
2231        else:
2232            print(
2233                "Multiprocessing is not supported on Windows, loading files sequentially."
2234            )
2235            dict_list = tqdm([make_data_dictionary(f) for f in files])
2236
2237        self.visibility = {}
2238        self.min_frames = {}
2239        self.max_frames = {}
2240        self.original_coordinates = []
2241        self.metadata = []
2242        X = []
2243        for (
2244            names,
2245            lengths,
2246            coords,
2247            bp_dictionary,
2248            min_frames,
2249            max_frames,
2250            metadata,
2251        ) in dict_list:
2252            X += names
2253            self.original_coordinates += coords
2254            self.visibility.update(bp_dictionary)
2255            self.min_frames.update(min_frames)
2256            self.max_frames.update(max_frames)
2257            if metadata is not None:
2258                self.metadata += metadata
2259        del dict_list
2260        if len(self.metadata) != len(self.original_coordinates):
2261            self.metadata = None
2262        else:
2263            self.metadata = np.array(self.metadata)
2264
2265        self.min_frames = dict(self.min_frames)
2266        self.max_frames = dict(self.max_frames)
2267        self.original_coordinates = np.array(self.original_coordinates)
2268        return np.array(X)
2269
2270    @classmethod
2271    def get_file_ids(
2272        cls,
2273        data_path: Union[Set, str] = None,
2274        file_paths: Set = None,
2275        feature_suffix: Set = None,
2276        *args,
2277        **kwargs,
2278    ) -> List:
2279        """Get file ids.
2280
2281        Process data parameters and return a list of ids of the videos that should
2282        be processed by the __init__ function.
2283
2284        Parameters
2285        ----------
2288        data_path : set | str, optional
2289            the path to the folder where the pose and feature files are stored or a set of such paths
2290            (not passed if creating from key objects or from `file_paths`)
2295        file_paths : set, optional
2296            a set of string paths to the pose and feature files
2297        feature_suffix : str | set, optional
2298            the suffix or the set of suffices such that the additional feature files are named
2299            {video_id}{feature_suffix} (and placed at the `data_path` folder or at `file_paths`)
2300
2301        Returns
2302        -------
2303        video_ids : list
2304            a list of video file ids
2305
2306        """
2307        if feature_suffix is None:
2308            feature_suffix = []
2309        if isinstance(feature_suffix, str):
2310            feature_suffix = [feature_suffix]
2311        feature_suffix = tuple(feature_suffix)
2312        if file_paths is None:
2313            file_paths = []
2314        if data_path is not None:
2315            if isinstance(data_path, str):
2316                data_path = [data_path]
2317            file_paths = []
2318            for folder in data_path:
2319                file_paths += [os.path.join(folder, x) for x in os.listdir(folder)]
2320        ids = set()
2321        for f in file_paths:
2322            if f.endswith(feature_suffix):
2323                bn = os.path.basename(f)
2324                video_id = strip_suffix(bn, feature_suffix)
2325                ids.add(video_id)
2326        ids = sorted(ids)
2327        return ids

Non-pose feature files.

The feature files should be dictionaries where keys are clip IDs (e.g. animal names) and values are feature values (arrays of shape (#frames, #features)). If the arrays are shaped as (#features, #frames), set transpose_features to True.

The files can be saved with numpy.save() (with .npy extension), torch.save() (with .pt extension) or with pickle.dump() (with .pickle or .pkl extension).

Assumes the following file structure:

data_path
├── video1_features.pt
└── video2_features.pt

Here feature_suffix (optional) is '_features.pt'.
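
A minimal sketch of saving a compatible feature file; the clip ID, shapes and file name are illustrative:

```
import torch

# keys are clip IDs, values are (#frames, #features) arrays
features = {"ind0": torch.rand(300, 64)}
torch.save(features, "video1_features.pt")
```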

LoadedFeaturesInputStore( video_order: List = None, data_path: Union[Set, str] = None, file_paths: Set = None, feature_suffix: Union[Set, str] = None, convert_int_indices: bool = True, feature_save_path: str = None, len_segment: int = 128, overlap: int = 0, ignored_clips: List = None, key_objects: Dict = None, num_cpus: int = None, frame_limit: int = 1, transpose_features: bool = False, **kwargs)
2068    def __init__(
2069        self,
2070        video_order: List = None,
2071        data_path: Union[Set, str] = None,
2072        file_paths: Set = None,
2073        feature_suffix: Union[Set, str] = None,
2074        convert_int_indices: bool = True,
2075        feature_save_path: str = None,
2076        len_segment: int = 128,
2077        overlap: int = 0,
2078        ignored_clips: List = None,
2079        key_objects: Dict = None,
2080        num_cpus: int = None,
2081        frame_limit: int = 1,
2082        transpose_features: bool = False,
2083        **kwargs,
2084    ) -> None:
2085        """Initialize a store.
2086
2087        Parameters
2088        ----------
2089        video_order : list, optional
2090            a list of video ids that should be processed in the same order (not passed if creating from key objects)
2091        data_path : str | set, optional
2092            the path to the folder where the pose and feature files are stored or a set of such paths
2093            (not passed if creating from key objects or from `file_paths`)
2094        file_paths : set, optional
2095            a set of string paths to the pose and feature files
2096            (not passed if creating from key objects or from `data_path`)
2097        feature_suffix : str | set, optional
2098            the suffix or the set of suffices such that the additional feature files are named
2099            {video_id}{feature_suffix} (and placed at the data_path folder)
2100        convert_int_indices : bool, default True
2101            if `True`, convert any integer key `i` in feature files to `'ind{i}'`
2102        feature_save_path : str, optional
2103            the path to the folder where pre-processed files are stored (not passed if creating from key objects)
2104        len_segment : int, default 128
2105            the length of the segments in which the data should be cut (in frames)
2106        overlap : int, default 0
2107            the length of the overlap between neighboring segments (in frames)
2108        ignored_clips : list, optional
2109            list of strings of clip ids to ignore
2112        key_objects : tuple, optional
2113            a tuple of key objects
2114        num_cpus : int, optional
2115            the number of cpus to use in data processing
2116        frame_limit : int, default 1
2117            clips shorter than this number of frames will be ignored
2118        transpose_features : bool, default False
2119            if `True`, the feature arrays are assumed to be of shape `(#features, #frames)` and will be transposed
2120
2121        """
2122        super().__init__(
2123            video_order,
2124            data_path,
2125            file_paths,
2126            feature_suffix=feature_suffix,
2127            convert_int_indices=convert_int_indices,
2128            feature_save_path=feature_save_path,
2129            len_segment=len_segment,
2130            overlap=overlap,
2131            ignored_clips=ignored_clips,
2132            key_objects=key_objects,
2133            num_cpus=num_cpus,
2134            frame_limit=frame_limit,
2135            transpose_features=transpose_features,
2136        )

Initialize a store.

Parameters

video_order : list, optional
    a list of video ids that should be processed in the same order (not passed if creating from key objects)
data_path : str | set, optional
    the path to the folder where the pose and feature files are stored or a set of such paths (not passed if creating from key objects or from file_paths)
file_paths : set, optional
    a set of string paths to the pose and feature files (not passed if creating from key objects or from data_path)
feature_suffix : str | set, optional
    the suffix or the set of suffices such that the additional feature files are named {video_id}{feature_suffix} (and placed at the data_path folder)
convert_int_indices : bool, default True
    if True, convert any integer key i in feature files to 'ind{i}'
feature_save_path : str, optional
    the path to the folder where pre-processed files are stored (not passed if creating from key objects)
len_segment : int, default 128
    the length of the segments in which the data should be cut (in frames)
overlap : int, default 0
    the length of the overlap between neighboring segments (in frames)
ignored_clips : list, optional
    list of strings of clip ids to ignore
key_objects : tuple, optional
    a tuple of key objects
num_cpus : int, optional
    the number of cpus to use in data processing
frame_limit : int, default 1
    clips shorter than this number of frames will be ignored
transpose_features : bool, default False
    if True, the feature arrays are assumed to be of shape (#features, #frames) and will be transposed
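
A minimal construction sketch with an illustrative path and suffix:

```
store = LoadedFeaturesInputStore(
    video_order=["video1", "video2"],
    data_path="/data/features",
    feature_suffix="_features.pt",
)
```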

def get_visibility( self, video_id: str, clip_id: str, start: int, end: int, score: float) -> float:
2138    def get_visibility(
2139        self, video_id: str, clip_id: str, start: int, end: int, score: float
2140    ) -> float:
2141        """Get the fraction of the frames that have a visibility score better than a hard_threshold.
2142
2143        For example, in the case of keypoint data the visibility score can be the number of identified keypoints.
2144
2145        Parameters
2146        ----------
2147        video_id : str
2148            the video id of the frames
2149        clip_id : str
2150            the clip id of the frames
2151        start : int
2152            the start frame
2153        end : int
2154            the end frame
2155        score : float
2156            the visibility score hard_threshold
2157
2158        Returns
2159        -------
2160        frac_visible: float
2161            the fraction of frames with visibility above the hard_threshold
2162
2163        """
2164        return 1

Get the fraction of the frames that have a visibility score better than a hard_threshold.

For example, in the case of keypoint data the visibility score can be the number of identified keypoints.

Parameters

video_id : str
    the video id of the frames
clip_id : str
    the clip id of the frames
start : int
    the start frame
end : int
    the end frame
score : float
    the visibility score hard_threshold

Returns

frac_visible : float
    the fraction of frames with visibility above the hard_threshold

@classmethod
def get_file_ids( cls, data_path: Union[Set, str] = None, file_paths: Set = None, feature_suffix: Set = None, *args, **kwargs) -> List:
2270    @classmethod
2271    def get_file_ids(
2272        cls,
2273        data_path: Union[Set, str] = None,
2274        file_paths: Set = None,
2275        feature_suffix: Set = None,
2276        *args,
2277        **kwargs,
2278    ) -> List:
2279        """Get file ids.
2280
2281        Process data parameters and return a list of ids of the videos that should
2282        be processed by the __init__ function.
2283
2284        Parameters
2285        ----------
2288        data_path : set | str, optional
2289            the path to the folder where the pose and feature files are stored or a set of such paths
2290            (not passed if creating from key objects or from `file_paths`)
2295        file_paths : set, optional
2296            a set of string paths to the pose and feature files
2297        feature_suffix : str | set, optional
2298            the suffix or the set of suffices such that the additional feature files are named
2299            {video_id}{feature_suffix} (and placed at the `data_path` folder or at `file_paths`)
2300
2301        Returns
2302        -------
2303        video_ids : list
2304            a list of video file ids
2305
2306        """
2307        if feature_suffix is None:
2308            feature_suffix = []
2309        if isinstance(feature_suffix, str):
2310            feature_suffix = [feature_suffix]
2311        feature_suffix = tuple(feature_suffix)
2312        if file_paths is None:
2313            file_paths = []
2314        if data_path is not None:
2315            if isinstance(data_path, str):
2316                data_path = [data_path]
2317            file_paths = []
2318            for folder in data_path:
2319                file_paths += [os.path.join(folder, x) for x in os.listdir(folder)]
2320        ids = set()
2321        for f in file_paths:
2322            if f.endswith(feature_suffix):
2323                bn = os.path.basename(f)
2324                video_id = strip_suffix(bn, feature_suffix)
2325                ids.add(video_id)
2326        ids = sorted(ids)
2327        return ids

Get file ids.

Process data parameters and return a list of ids of the videos that should be processed by the __init__ function.

Parameters

data_path : set | str, optional
    the path to the folder where the pose and feature files are stored or a set of such paths (not passed if creating from key objects or from file_paths)
file_paths : set, optional
    a set of string paths to the pose and feature files
feature_suffix : str | set, optional
    the suffix or the set of suffices such that the additional feature files are named {video_id}{feature_suffix} (and placed at the data_path folder or at file_paths)

Returns

video_ids : list
    a list of video file ids
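
For example (illustrative path and suffix), discovering the video ids from saved feature files:

```
ids = LoadedFeaturesInputStore.get_file_ids(
    data_path="/data/features",
    feature_suffix="_features.pt",
)
# e.g. ['video1', 'video2']
```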

class SIMBAInputStore(FileInputStore):
2330class SIMBAInputStore(FileInputStore):
2331    """SIMBA paper format data.
2332
2333    Assumes the following file structure:
2334
2335    ```
2336    data_path
2337    ├── Video1.csv
2338    ...
2339    └── Video9.csv
2340    ```
2341    Here `data_suffix` is `'.csv'`.
2342    """
2343
2344    def __init__(
2345        self,
2346        video_order: List = None,
2347        data_path: Union[Set, str] = None,
2348        file_paths: Set = None,
2349        data_prefix: Union[Set, str] = None,
2350        feature_suffix: str = None,
2351        feature_save_path: str = None,
2352        canvas_shape: List = None,
2353        len_segment: int = 128,
2354        overlap: int = 0,
2355        feature_extraction: str = "kinematic",
2356        ignored_clips: List = None,
2357        ignored_bodyparts: List = None,
2358        key_objects: Tuple = None,
2359        likelihood_threshold: float = 0,
2360        num_cpus: int = None,
2361        normalize: bool = False,
2362        feature_extraction_pars: Dict = None,
2363        centered: bool = False,
2364        data_suffix: str = None,
2365        use_features: bool = False,
2366        *args,
2367        **kwargs,
2368    ) -> None:
2369        """Initialize a store.
2370
2371        Parameters
2372        ----------
2373        video_order : list, optional
2374            a list of video ids that should be processed in the same order (not passed if creating from key objects)
2375        data_path : str | set, optional
2376            the path to the folder where the pose and feature files are stored or a set of such paths
2377            (not passed if creating from key objects or from `file_paths`)
2378        file_paths : set, optional
2379            a set of string paths to the pose and feature files
2380            (not passed if creating from key objects or from `data_path`)
2381        data_suffix : str | set, optional
2382            the suffix or the set of suffices such that the pose files are named {video_id}{data_suffix}
2383            (not passed if creating from key objects or if irrelevant for the dataset)
2384        data_prefix : str | set, optional
2385            the prefix or the set of prefixes such that the pose files for different video views of the same
2386            clip are named {prefix}{sep}{video_id}{data_suffix} (not passed if creating from key objects
2387            or if irrelevant for the dataset)
2388        feature_suffix : str | set, optional
2389            the suffix or the set of suffices such that the additional feature files are named
2390            {video_id}{feature_suffix} (and placed at the data_path folder)
2391        feature_save_path : str, optional
2392            the path to the folder where pre-processed files are stored (not passed if creating from key objects)
2393        canvas_shape : List, default [1, 1]
2394            the canvas size where the pose is defined
2395        len_segment : int, default 128
2396            the length of the segments in which the data should be cut (in frames)
2397        overlap : int, default 0
2398            the length of the overlap between neighboring segments (in frames)
2399        feature_extraction : str, default 'kinematic'
2400            the feature extraction method (see options.feature_extractors for available options)
2401        ignored_clips : list, optional
2402            list of strings of clip ids to ignore
2403        ignored_bodyparts : list, optional
2404            list of strings of bodypart names to ignore
2405        key_objects : tuple, optional
2406            a tuple of key objects
2407        likelihood_threshold : float, default 0
2408            coordinate values with likelihoods less than this value will be set to 'unknown'
2409        num_cpus : int, optional
2410            the number of cpus to use in data processing
2411        normalize : bool, default False
2412            whether to normalize the pose
2413        feature_extraction_pars : dict, optional
2414            parameters of the feature extractor
2415        centered : bool, default False
2416            whether the pose is centered at the object of interest
2417        use_features : bool, default False
2418            whether to use features
2419
2420        """
2421        self.use_features = use_features
2422        if feature_extraction_pars is not None:
2423            feature_extraction_pars["interactive"] = True
2424        super().__init__(
2425            video_order=video_order,
2426            data_path=data_path,
2427            file_paths=file_paths,
2428            data_suffix=data_suffix,
2429            data_prefix=data_prefix,
2430            feature_suffix=feature_suffix,
2431            convert_int_indices=False,
2432            feature_save_path=feature_save_path,
2433            canvas_shape=canvas_shape,
2434            len_segment=len_segment,
2435            overlap=overlap,
2436            feature_extraction=feature_extraction,
2437            ignored_clips=ignored_clips,
2438            ignored_bodyparts=ignored_bodyparts,
2439            default_agent_name="",
2440            key_objects=key_objects,
2441            likelihood_threshold=likelihood_threshold,
2442            num_cpus=num_cpus,
2443            min_frames=0,
2444            normalize=normalize,
2445            feature_extraction_pars=feature_extraction_pars,
2446            centered=centered,
2447        )
2448
2449    def _open_data(
2450        self, filename: str, default_clip_name: str
2451    ) -> Tuple[Dict, Optional[Dict]]:
2452
2453        torch.cuda.empty_cache() if torch.cuda.is_available() else None
2454        data = pd.read_csv(filename)
2455        output = {}
2456        column_dict = {"x": "x", "y": "y", "z": "z", "p": "likelihood"}
2457        columns = [x for x in data.columns if x.split("_")[-1] in column_dict]
2458        animals = sorted(set([x.split("_")[-2] for x in columns]))
2459        coords = sorted(set([x.split("_")[-1] for x in columns]))
2460        names = sorted(set(["_".join(x.split("_")[:-2]) for x in columns]))
2461        for animal in animals:
2462            data_dict = {}
2463            for i, row in data.iterrows():
2464                for col_name in names:
2465                    data_dict[(i, col_name)] = [
2466                        row[f"{col_name}_{animal}_{coord}"] for coord in coords
2467                    ]
2468            output[animal] = pd.DataFrame(data_dict).T
2469            output[animal].columns = [column_dict[x] for x in coords]
2470        if self.use_features:
2471            columns_to_avoid = [
2472                x
2473                for x in data.columns
2474                if x.split("_")[-1] in column_dict
2475                or x.split("_")[-1].startswith("prediction")
2476            ]
2477            columns_to_avoid += ["scorer", "frames", "video_no"]
2478            output["loaded"] = (
2479                data[[x for x in data.columns if x not in columns_to_avoid]]
2480                .interpolate()
2481                .values
2482            )
2483        return output, None
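The parsing in `_open_data()` above relies on SIMBA's flat column naming, `{bodypart}_{animal}_{coord}`. A minimal sketch with a hypothetical two-frame table shows how the animal, coordinate and bodypart names are recovered from the column labels:

    import pandas as pd

    # hypothetical SIMBA-style table: columns are named {bodypart}_{animal}_{coord}
    data = pd.DataFrame(
        {
            "nose_mouse1_x": [10.0, 11.0],
            "nose_mouse1_y": [20.0, 21.0],
            "nose_mouse1_p": [0.9, 0.8],
        }
    )
    column_dict = {"x": "x", "y": "y", "z": "z", "p": "likelihood"}
    columns = [c for c in data.columns if c.split("_")[-1] in column_dict]
    print(sorted({c.split("_")[-2] for c in columns}))             # ['mouse1']
    print(sorted({c.split("_")[-1] for c in columns}))             # ['p', 'x', 'y']
    print(sorted({"_".join(c.split("_")[:-2]) for c in columns}))  # ['nose']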

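A hypothetical instantiation of the store, assuming a folder of SIMBA csv files; the paths and video ids are placeholders, and all keyword arguments appear in the signature above:

    store = SIMBAInputStore(
        video_order=["Video1", "Video2"],          # hypothetical video ids
        data_path="/data/simba",                   # hypothetical folder with the csv files
        data_suffix=".csv",
        feature_save_path="/data/simba_features",  # hypothetical output folder
        len_segment=128,
        overlap=0,
        use_features=True,  # also load the precomputed SIMBA feature columns
    )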
class ESKTrackStore(FileInputStore):
    """DLC track data from the EPFL Smart Kitchen that allows choosing a specific set of keypoints.

    Assumes the following file structure:
    ```
    data_path
    ├── video1DLC1000.h5
    ├── video2DLC400.h5
    ├── video1_features.npy
    └── video2_features.npy
    ```
    Here `data_suffix` is `{'DLC1000.h5', 'DLC400.h5'}` and `feature_suffix` (optional) is `'_features.npy'`.
    The pose files have to be in `.h5` or `.csv` format (see `_open_data()`).

    The feature files should be dictionaries where keys are clip IDs (e.g. animal names) and values are
    feature values (arrays of shape `(#frames, #features)`). If the arrays are shaped as `(#features, #frames)`,
    set `transpose_features` to `True`.

    The files can be saved with `numpy.save()` (with `.npy` extension), `torch.save()` (with `.pt` extension) or
    with `pickle.dump()` (with `.pickle` or `.pkl` extension).
    """

    def __init__(
        self,
        keypoint_type: str = "hand",
        *args,
        **kwargs,
    ):
        """Initialize a store."""
        self.keypoint_type = keypoint_type
        self.num_body_kpts = 17
        self.num_hand_kpts = 42
        self.num_eye_kpts = 10

        super().__init__(
            *args,
            **kwargs,
        )

    def get_kpt_names(self):
        """Return the array of all keypoint names (body, hand and eye gaze)."""
        kpt_names = [
            "nose",
            "left_eye",
            "right_eye",
            "left_ear",
            "right_ear",
            "left_shoulder",
            "right_shoulder",
            "left_elbow",
            "right_elbow",
            "left_wrist",
            "right_wrist",
            "left_hip",
            "right_hip",
            "left_knee",
            "right_knee",
            "left_ankle",
            "right_ankle",
        ]

        kpt_names = (
            kpt_names
            + [
                f"hand_{i}"
                for i in range(
                    self.num_body_kpts, self.num_hand_kpts + self.num_body_kpts
                )
            ]
            + [f"eye_gaze_{i}" for i in range(self.num_eye_kpts)]
        )
        return np.array(kpt_names)

    def get_kpt_ind(self, default_num):
        """Get the indices of the keypoints to be used."""
        body_ind = list(range(self.num_body_kpts))
        count = len(body_ind)
        hands_ind = list(range(count, count + self.num_hand_kpts))
        count += len(hands_ind)
        eye_ind = list(range(count, count + self.num_eye_kpts))
        body_wo_arm_ind = [0, 1, 2, 3, 4, 11, 12, 13, 14, 15, 16]
        default_ind = list(range(default_num))
        switcher = {
            "body": body_ind,
            "hands": hands_ind,
            "eyes": eye_ind,
            "body_hands": body_ind + hands_ind,
            "body_eyes": body_ind + eye_ind,
            "hands_eyes": hands_ind + eye_ind,
            "body_wo_arm": body_wo_arm_ind,
        }
        # the second element is True when the type is unknown and all keypoints are kept
        return (
            switcher.get(self.keypoint_type, default_ind),
            self.keypoint_type not in switcher,
        )

    def _open_data(
        self, filename: str, default_agent_name: str
    ) -> Tuple[Dict, Optional[Dict]]:
        """Load the keypoints from filename and organize them in a dictionary.

        In `data_dictionary`, the keys are clip ids and the values are `pandas` dataframes with two-level indices.
        The first level is the frame numbers and the second is the body part names. The dataframes should have from
        two to four columns labeled `"x"`, `"y"` and (optionally) `"z"` and `"likelihood"`. Each frame should have
        information on all the body parts. You don't have to filter the data in any way or fill the NaNs; it will
        be done automatically.

        Parameters
        ----------
        filename : str
            the path to the pose file
        default_agent_name : str
            the default agent name

        Returns
        -------
        data_dictionary : dict
            a dictionary where the keys are clip ids and the values are keypoint dataframes (see above for details)
        metadata_dictionary : dict
            a dictionary where the keys are clip ids and the values are metadata objects (can be any additional
            information, like the annotator tag; for no metadata pass `None`)

        """
        if filename.endswith(".h5"):
            temp = pd.read_hdf(filename)
            temp = temp.droplevel("scorer", axis=1)
        elif filename.endswith(".csv"):
            temp = pd.read_csv(filename, header=[1, 2])
            temp.columns.names = ["bodyparts", "coords"]
        else:
            raise TypeError("Invalid file type, please use .csv or .h5")

        if "individuals" not in temp.columns.names:
            old_idx = temp.columns.to_frame()
            old_idx.insert(0, "individuals", self.default_agent_name)
            temp.columns = pd.MultiIndex.from_frame(old_idx)

        df = temp.stack(["individuals", "bodyparts"], future_stack=True)
        idx = pd.MultiIndex.from_product(
            [df.index.levels[0], df.index.levels[1], df.index.levels[2]],
            names=df.index.names,
        )
        df = df.reindex(idx).fillna(value=0)
        animals = sorted(list(df.index.levels[1]))
        dic = {}
        default_num = len(df.index.levels[2])
        kpt_ind, is_special = self.get_kpt_ind(default_num)
        kpt_names = self.get_kpt_names()
        for ind in animals:
            coord = df.iloc[df.index.get_level_values(1) == ind].droplevel(1)
            coord = coord[["x", "y", "z", "likelihood"]]
            if not is_special:
                coord = coord.loc[(slice(None), kpt_names[kpt_ind]), :]
            dic[ind] = coord

        return dic, None
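A hypothetical instantiation that keeps only the hand keypoints; the paths and suffixes are placeholders. `keypoint_type` should be one of the keys of the `switcher` dictionary in `get_kpt_ind()` (`"body"`, `"hands"`, `"eyes"`, `"body_hands"`, `"body_eyes"`, `"hands_eyes"`, `"body_wo_arm"`); any other value keeps all keypoints:

    store = ESKTrackStore(
        keypoint_type="hands",
        video_order=["video1"],          # hypothetical video id
        data_path="/data/esk",           # hypothetical folder with the pose files
        data_suffix="DLC1000.h5",        # hypothetical suffix
        feature_suffix="_features.npy",
    )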

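`_open_data()` expects the standard DLC column layout. A minimal sketch, with hypothetical values, that builds an equivalent in-memory table and applies the same re-indexing steps (`future_stack=True` requires pandas >= 2.1):

    import pandas as pd

    # hypothetical single-animal DLC-style table with (bodyparts, coords) column levels
    columns = pd.MultiIndex.from_product(
        [["nose", "left_eye"], ["x", "y", "likelihood"]],
        names=["bodyparts", "coords"],
    )
    temp = pd.DataFrame([[1.0, 2.0, 0.9, 3.0, 4.0, 0.8]], columns=columns)

    # add the "individuals" level, as _open_data() does for single-animal files
    old_idx = temp.columns.to_frame()
    old_idx.insert(0, "individuals", "ind0")
    temp.columns = pd.MultiIndex.from_frame(old_idx)

    df = temp.stack(["individuals", "bodyparts"], future_stack=True)
    print(df)  # rows indexed by (frame, individual, bodypart); columns x / y / likelihood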