dlc2action.task.universal_task

Training and inference.

   1#
   2# Copyright 2020-present by A. Mathis Group and contributors. All rights reserved.
   3#
   4# This project and all its files are licensed under GNU AGPLv3 or later version. 
   5# A copy is included in dlc2action/LICENSE.AGPL.
   6#
   7"""Training and inference."""
   8
   9import os
  10import random
  11import warnings
  12from collections import defaultdict
  13from collections.abc import Iterable
  14from copy import copy, deepcopy
  15from math import ceil, exp, floor, pi, sqrt
  16from random import randint
  17from time import time
  18from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
  19
  20import numpy as np
  21import torch
  22from dlc2action.data.dataset import BehaviorDataset
  23from dlc2action.metric.base_metric import Metric
  24from dlc2action.model.base_model import LoadedModel, Model
  25from dlc2action.transformer.base_transformer import EmptyTransformer, Transformer
  26from dlc2action.options import dlc2action_colormaps
  27from matplotlib import cm
  28from matplotlib import pyplot as plt
  29from matplotlib import rc
  30from optuna import TrialPruned
  31from optuna.trial import Trial
  32from torch import nn
  33from torch.optim import Adam, Optimizer
  34from torch.utils.data import DataLoader
  35from tqdm import tqdm
  36
  37
  38class MyDataParallel(nn.DataParallel):
  39    """A wrapper for nn.DataParallel that allows to call methods of the model."""
  40
  41    def __init__(self, *args, **kwargs):
  42        """Initialize the class."""
  43        super().__init__(*args, **kwargs)
  44        self.process_labels = self.module.process_labels
  45
  46    def freeze_feature_extractor(self):
  47        """Freeze the feature extractor."""
  48        self.module.freeze_feature_extractor()
  49
  50    def unfreeze_feature_extractor(self):
  51        """Unfreeze the feature extractor."""
  52        self.module.unfreeze_feature_extractor()
  53
  54    def transform_labels(self, device):
  55        """Transform labels to the device."""
  56        return self.module.transform_labels(device)
  57
  58    def logit_scale(self):
  59        """Return the logit scale of the model."""
  60        return self.module.logit_scale()
  61
  62    def main_task_off(self):
  63        """Turn off the main task training."""
  64        self.module.main_task_off()
  65
  66    def state_dict(self, *args, **kwargs):
  67        """Return the state of the module."""
  68        return self.module.state_dict(*args, **kwargs)
  69
  70    def ssl_on(self):
  71        """Turn on SSL training."""
  72        self.module.ssl_on()
  73
  74    def ssl_off(self):
  75        """Turn off SSL training."""
  76        self.module.ssl_off()
  77
  78    def extract_features(self, x, start=0):
  79        """Extract features from the model."""
  80        return self.module.extract_features(x, start)
  81
  82
  83class Task:
  84    """A universal trainer class that performs training, evaluation and prediction for all types of tasks and data."""
  85
  86    def __init__(
  87        self,
  88        train_dataloader: DataLoader,
  89        model: Union[nn.Module, Model],
  90        loss: Callable[[torch.Tensor, torch.Tensor], float],
  91        num_epochs: int = 0,
  92        transformer: Transformer = None,
  93        ssl_losses: List = None,
  94        ssl_weights: List = None,
  95        lr: float = 1e-3,
  96        weight_decay: float = 0,
  97        metrics: Dict = None,
  98        val_dataloader: DataLoader = None,
  99        test_dataloader: DataLoader = None,
 100        optimizer: Optimizer = None,
 101        device: str = "cuda",
 102        verbose: bool = True,
 103        log_file: Union[str, None] = None,
 104        augment_train: int = 1,
 105        augment_val: int = 0,
 106        validation_interval: int = 1,
 107        predict_function: Union[Callable[[torch.Tensor], torch.Tensor], None] = None,
 108        primary_predict_function: Callable = None,
 109        exclusive: bool = True,
 110        ignore_tags: bool = True,
 111        threshold: float = 0.5,
 112        model_save_path: str = None,
 113        model_save_epochs: int = 5,
 114        pseudolabel: bool = False,
 115        pseudolabel_start: int = 100,
 116        correction_interval: int = 2,
 117        pseudolabel_alpha_f: float = 3,
 118        alpha_growth_stop: int = 600,
 119        parallel: bool = False,
 120        skip_metrics: List = None,
 121    ) -> None:
 122        """Initialize the class.
 123
 124        Parameters
 125        ----------
 126        train_dataloader : torch.utils.data.DataLoader
 127            a training dataloader
 128        model : dlc2action.model.base_model.Model
 129            a model
 130        loss : callable
 131            a loss function
 132        num_epochs : int, default 0
 133            the number of epochs
 134        transformer : dlc2action.transformer.base_transformer.Transformer, optional
 135            a transformer
 136        ssl_losses : list, optional
 137            a list of SSL losses
 138        ssl_weights : list, optional
  139            a list of SSL weights (if not provided, initialized to 1)
 140        lr : float, default 1e-3
 141            learning rate
 142        weight_decay : float, default 0
 143            weight decay
  144        metrics : dict, optional
  145            a dictionary of metric functions
 146        val_dataloader : torch.utils.data.DataLoader, optional
 147            a validation dataloader
 148        test_dataloader : torch.utils.data.DataLoader, optional
 149            a test dataloader
 150        optimizer : torch.optim.Optimizer, optional
 151            an optimizer (`Adam` by default)
 152        device : str, default 'cuda'
 153            the device to train the model on
 154        verbose : bool, default True
 155            if `True`, the process is described in standard output
 156        log_file : str, optional
 157            the path to a text file where the process will be logged
  158        augment_train : {1, 0}, default 1
  159            the number of augmentations to apply during training
 160        augment_val : int, default 0
 161            number of augmentations to apply at validation
 162        validation_interval : int, default 1
 163            every time this number of epochs passes, validation metrics are computed
 164        predict_function : callable, optional
 165            a function that maps probabilities to class predictions (if not provided, a default is generated)
 166        primary_predict_function : callable, optional
  167            a function that maps model output to probabilities (if not provided, softmax is used for exclusive
  168            classification and sigmoid otherwise)
 168        exclusive : bool, default True
 169            set to False for multi-label classification
  170        ignore_tags : bool, default True
 171            if `True`, samples with different meta tags will be mixed in batches
 172        threshold : float, default 0.5
  173            the threshold used by the default prediction function for multi-label classification
 174        model_save_path : str, optional
 175            the path to the folder where model checkpoints will be saved (checkpoints will not be saved if the path
 176            is not provided)
 177        model_save_epochs : int, default 5
 178            the interval for saving the model checkpoints (the last epoch is always saved)
 179        pseudolabel : bool, default False
 180            if True, the pseudolabeling procedure will be applied
 181        pseudolabel_start : int, default 100
 182            pseudolabeling starts after this epoch
  183        correction_interval : int, default 2
 184            after this number of epochs, if the pseudolabeling is on, the model is trained on the labeled data and
 185            new pseudolabels are generated
 186        pseudolabel_alpha_f : float, default 3
 187            the maximum value of pseudolabeling alpha
 188        alpha_growth_stop : int, default 600
 189            pseudolabeling alpha stops growing after this epoch
 190        parallel : bool, default False
 191            if True, the model is trained on multiple GPUs
 192        skip_metrics : list, optional
 193            a list of metrics to skip
 194
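             Examples
             --------
             A minimal construction sketch (not from the original code; `train_dl` and `val_dl` are hypothetical
             `torch.utils.data.DataLoader` objects over `BehaviorDataset` instances and `model` is a
             `dlc2action.model.base_model.Model`):

             >>> import torch
             >>> from dlc2action.task.universal_task import Task
             >>> task = Task(
             ...     train_dataloader=train_dl,
             ...     model=model,
             ...     loss=torch.nn.CrossEntropyLoss(),
             ...     num_epochs=100,
             ...     val_dataloader=val_dl,
             ...     lr=1e-3,
             ...     exclusive=True,
             ...     device="auto",
             ... )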
 195        """
 196        # pseudolabeling might be buggy right now -- not using it!
 197        if skip_metrics is None:
 198            skip_metrics = []
 199        self.train_dataloader = train_dataloader
 200        self.val_dataloader = val_dataloader
 201        self.test_dataloader = test_dataloader
 202        self.transformer = transformer
 203        self.num_epochs = num_epochs
 204        self.skip_metrics = skip_metrics
 205        self.verbose = verbose
 206        self.augment_train = int(augment_train)
 207        self.augment_val = int(augment_val)
 208        self.ignore_tags = ignore_tags
 209        self.validation_interval = int(validation_interval)
 210        self.log_file = log_file
 211        self.loss = loss
 212        self.model_save_path = model_save_path
 213        self.model_save_epochs = model_save_epochs
 214        self.epoch = 0
 215
 216        if metrics is None:
 217            metrics = {}
 218        self.metrics = metrics
 219
 220        if optimizer is None:
 221            optimizer = Adam
 222
  223        if ssl_weights is None:
  224            ssl_weights = [1 for _ in (ssl_losses or [])]
  225        if not isinstance(ssl_weights, Iterable):
  226            ssl_weights = [ssl_weights for _ in (ssl_losses or [])]
 227        self.ssl_weights = ssl_weights
 228
 229        self.optimizer_class = optimizer
 230        self.lr = lr
 231        self.weight_decay = weight_decay
 232        if not isinstance(model, Model):
 233            self.model = LoadedModel(model=model)
 234        else:
 235            self.set_model(model)
 236        self.parallel = parallel
 237
 238        if self.transformer is None:
 239            self.augment_val = 0
 240            self.augment_train = 0
 241            self.transformer = EmptyTransformer()
 242
 243        if self.augment_train > 1:
 244            warnings.warn(
 245                'The "augment_train" parameter is too large -> setting it to 1.'
 246            )
 247            self.augment_train = 1
 248
 249        try:
 250            if device == "auto":
 251                device = "cuda" if torch.cuda.is_available() else "cpu"
 252            self.device = torch.device(device)
  253        except Exception as exc:
  254            raise ValueError(f'The format of the device "{device}" is incorrect') from exc
 255
 256        if ssl_losses is None:
 257            self.ssl_losses = [lambda x, y: 0]
 258        else:
 259            self.ssl_losses = ssl_losses
 260
 261        if primary_predict_function is None:
 262            if exclusive:
  263                primary_predict_function = lambda x: torch.softmax(x, dim=1)
 264            else:
 265                primary_predict_function = lambda x: torch.sigmoid(x)
 266        self.primary_predict_function = primary_predict_function
 267
 268        if predict_function is None:
 269            if exclusive:
 270                self.predict_function = lambda x: torch.max(x.data, 1)[1]
 271            else:
 272                self.predict_function = lambda x: (x > threshold).int()
 273        else:
 274            self.predict_function = predict_function
 275
 276        self.pseudolabel = pseudolabel
 277        self.alpha_f = pseudolabel_alpha_f
 278        self.T2 = alpha_growth_stop
 279        self.T1 = pseudolabel_start
 280        self.t = correction_interval
 281        if self.T2 <= self.T1:
 282            raise ValueError(
 283                f"The pseudolabel_start parameter has to be smaller than alpha_growth_stop; got "
 284                f"{pseudolabel_start=} and {alpha_growth_stop=}"
 285            )
 286        self.decision_thresholds = [0.5 for x in self.behaviors_dict()]
 287
 288    def save_checkpoint(self, checkpoint_path: str) -> None:
 289        """Save a general checkpoint.
 290
 291        Parameters
 292        ----------
 293        checkpoint_path : str
 294            the path where the checkpoint will be saved
 295
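             Examples
             --------
             A usage sketch (the path is hypothetical; `task` is assumed to be an existing `Task` instance):

             >>> task.save_checkpoint("checkpoints/epoch_050.pt")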
 296        """
 297        torch.save(
 298            {
 299                "epoch": self.epoch,
 300                "model_state_dict": self.model.state_dict(),
 301                "optimizer_state_dict": self.optimizer.state_dict(),
 302            },
 303            checkpoint_path,
 304        )
 305
 306    def load_from_checkpoint(
 307        self, checkpoint_path, only_model: bool = False, load_strict: bool = True
 308    ) -> None:
 309        """Load from a checkpoint.
 310
 311        Parameters
 312        ----------
 313        checkpoint_path : str
 314            the path to the checkpoint
 315        only_model : bool, default False
 316            if `True`, only the model state dictionary will be loaded (and not the epoch and the optimizer state
 317            dictionary)
 318        load_strict : bool, default True
 319            if `True`, any inconsistencies in state dictionaries are regarded as errors
 320
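             Examples
             --------
             A usage sketch (the path is hypothetical):

             >>> task.load_from_checkpoint("checkpoints/epoch_050.pt", only_model=True)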
 321        """
 322        if checkpoint_path is None:
 323            return
 324        checkpoint = torch.load(
 325            checkpoint_path, map_location=self.device, weights_only=False
 326        )
 327        self.model.to(self.device)
 328        self.model.load_state_dict(checkpoint["model_state_dict"], strict=load_strict)
 329        if not only_model:
 330            self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
 331            self.epoch = checkpoint["epoch"]
 332
 333    def save_model(self, save_path: str) -> None:
 334        """Save the model state dictionary.
 335
 336        Parameters
 337        ----------
 338        save_path : str
 339            the path where the state will be saved
 340
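             Examples
             --------
             A usage sketch (the file name is hypothetical):

             >>> task.save_model("model_final.pt")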
 341        """
 342        torch.save(self.model.state_dict(), save_path)
 343        print("saved the model successfully")
 344
 345    def _apply_predict_functions(self, predicted):
 346        """Map from model output to prediction."""
 347        predicted = self.primary_predict_function(predicted)
 348        predicted = self.predict_function(predicted)
 349        return predicted
 350
 351    def _get_prediction(
 352        self,
 353        main_input: Dict,
 354        tags: torch.Tensor,
 355        ssl_inputs: List = None,
 356        ssl_targets: List = None,
 357        augment_n: int = 0,
 358        embedding: bool = False,
 359        subsample: List = None,
 360    ) -> Tuple:
 361        """Get the prediction of `self.model` for input averaged over `augment_n` augmentations."""
 362        if augment_n == 0:
 363            augment_n = 1
 364            augment = False
 365        else:
 366            augment = True
 367        model_input, ssl_inputs, ssl_targets = self.transformer.transform(
 368            deepcopy(main_input),
 369            ssl_inputs=ssl_inputs,
 370            ssl_targets=ssl_targets,
 371            augment=augment,
 372            subsample=subsample,
 373        )
 374        if not embedding:
  376            predicted, ssl_predicted = self.model(model_input, ssl_inputs, tag=tags)
 380        else:
 381            predicted = self.model.extract_features(model_input)
 382            ssl_predicted = None
 383        if self.parallel and predicted is not None and len(predicted.shape) == 4:
 384            predicted = predicted.reshape(
 385                (-1, model_input.shape[0], *predicted.shape[2:])
 386            )
 387        if predicted is not None and augment_n > 1:
 388            self.model.ssl_off()
 389            for i in range(augment_n - 1):
 390                model_input, *_ = self.transformer.transform(
 391                    deepcopy(main_input), augment=augment
 392                )
 393                if not embedding:
 394                    pred, _ = self.model(model_input, None, tag=tags)
 395                else:
 396                    pred = self.model.extract_features(model_input)
 397                predicted += pred.detach()
 398            self.model.ssl_on()
 399            predicted /= augment_n
 400        if self.model.process_labels:
 401            class_embedding = self.model.transform_labels(predicted.device)
 402            beh_dict = self.behaviors_dict()
 403            class_embedding = torch.cat(
 404                [class_embedding[beh_dict[k]] for k in sorted(beh_dict.keys())], 0
 405            )
 406            predicted = {
 407                "video": predicted,
 408                "text": class_embedding,
 409                "logit_scale": self.model.logit_scale(),
 410                "device": predicted.device,
 411            }
 412        return predicted, ssl_predicted, ssl_targets
 413
 414    def _ssl_loss(self, ssl_predicted, ssl_targets):
 415        """Apply SSL losses."""
 416        if self.ssl_losses is None or ssl_predicted is None or ssl_targets is None:
 417            return []
 418
 419        ssl_loss = []
 420        for loss, predicted, target in zip(self.ssl_losses, ssl_predicted, ssl_targets):
 421            ssl_loss.append(loss(predicted, target))
 422        return ssl_loss
 423
 424    def _loss_function(
 425        self,
 426        batch: Dict,
 427        augment_n: int,
 428        temporal_subsampling_size: int = None,
 429        skip_metrics: List = None,
 430    ) -> Tuple[float, float]:
 431        """Calculate the loss function and the metric for a dataloader batch.
 432
  433        The predictions are averaged over `augment_n` augmentations.
 434
 435        """
 436        if "target" not in batch or torch.isnan(batch["target"]).all():
 437            raise ValueError("Cannot compute loss function with nan targets!")
 438        main_input = {k: v.to(self.device) for k, v in batch["input"].items()}
 439        main_target, ssl_targets, ssl_inputs = None, None, None
 440        main_target = batch["target"].to(self.device)
 441        if "ssl_targets" in batch:
 442            ssl_targets = [
 443                (
 444                    {k: v.to(self.device) for k, v in x.items()}
 445                    if isinstance(x, dict)
 446                    else None
 447                )
 448                for x in batch["ssl_targets"]
 449            ]
 450        if "ssl_inputs" in batch:
 451            ssl_inputs = [
 452                (
 453                    {k: v.to(self.device) for k, v in x.items()}
 454                    if isinstance(x, dict)
 455                    else None
 456                )
 457                for x in batch["ssl_inputs"]
 458            ]
 459        if temporal_subsampling_size is not None:
 460            subsample = sorted(
 461                random.sample(
 462                    range(main_target.shape[-1]),
 463                    int(temporal_subsampling_size * main_target.shape[-1]),
 464                )
 465            )
 466            main_target = main_target[..., subsample]
 467        else:
 468            subsample = None
 469
 470        if self.ignore_tags:
 471            tag = None
 472        else:
 473            tag = batch.get("tag")
 474        predicted, ssl_predicted, ssl_targets = self._get_prediction(
 475            main_input, tag, ssl_inputs, ssl_targets, augment_n, subsample=subsample
 476        )
 477        del main_input, ssl_inputs
 478        return self._compute(
 479            ssl_predicted,
 480            ssl_targets,
 481            predicted,
 482            main_target,
 483            tag=batch.get("tag"),
 484            skip_metrics=skip_metrics,
 485        )
 486
 487    def _compute(
 488        self,
 489        ssl_predicted: List,
 490        ssl_targets: List,
 491        predicted: torch.Tensor,
 492        main_target: torch.Tensor,
 493        tag: Any = None,
 494        skip_loss: bool = False,
 495        apply_primary_function: bool = True,
 496        skip_metrics: List = None,
 497    ) -> Tuple[float, float]:
 498        """Compute the losses and metrics from predictions."""
 499        if skip_metrics is None:
 500            skip_metrics = []
 501        if not skip_loss:
 502            ssl_losses = self._ssl_loss(ssl_predicted, ssl_targets)
 503            if predicted is not None:
 504                loss = self.loss(predicted, main_target)
 505            else:
 506                loss = 0
 507        else:
 508            ssl_losses, loss = [], 0
 509
 510        if predicted is not None:
 511            if isinstance(predicted, dict):
 512                predicted = {
 513                    k: v.detach()
 514                    for k, v in predicted.items()
 515                    if isinstance(v, torch.Tensor)
 516                }
 517            else:
 518                predicted = predicted.detach()
 519            if apply_primary_function:
 520                predicted = self.primary_predict_function(predicted)
 521            predicted_transformed = self.predict_function(predicted)
 522
 523            for name, metric_function in self.metrics.items():
 524                if name not in skip_metrics:
 525                    if metric_function.needs_raw_data:
 526                        metric_function.update(predicted, main_target, tag)
 527                    else:
 528                        metric_function.update(
 529                            predicted_transformed,
 530                            main_target,
 531                            tag,
 532                        )
 533        return loss, ssl_losses
 534
 535    def _calculate_metrics(self) -> Dict:
 536        """Calculate the final values of epoch metrics."""
 537        epoch_metrics = {}
 538        for metric_name, metric in self.metrics.items():
 539            m = metric.calculate()
 540            if type(m) is dict:
 541                for k, v in m.items():
 542                    if type(v) is torch.Tensor:
 543                        v = v.item()
 544                    epoch_metrics[f"{metric_name}_{k}"] = v
 545            else:
 546                if type(m) is torch.Tensor:
 547                    m = m.item()
 548                epoch_metrics[metric_name] = m
 549            metric.reset()
 551        return epoch_metrics
 552
 553    def _run_epoch(
 554        self,
 555        dataloader: DataLoader,
 556        mode: str,
 557        augment_n: int,
 558        verbose: bool = False,
 559        unlabeled: bool = None,
 560        alpha: float = 1,
 561        temporal_subsampling_size: int = None,
 562    ) -> Tuple:
 563        """Run one epoch on dataloader.
 564
  565        The predictions are averaged over `augment_n` augmentations.
 566        Use "train" mode for training and "val" mode for evaluation.
 567
 568        """
 569        if mode == "train":
 570            self.model.train()
 571        elif mode == "val":
 572            self.model.eval()
 574        else:
 575            raise ValueError(
 576                f'Mode {mode} is not recognized, please choose either "train" for training or "val" for validation'
 577            )
 578        if self.ignore_tags:
 579            tags = [None]
 580        else:
 581            tags = dataloader.dataset.get_tags()
 582        epoch_loss = 0
 583        epoch_ssl_loss = defaultdict(lambda: 0)
 584        data_len = 0
 585        set_pars = dataloader.dataset.set_indexing_parameters
 586        skip_metrics = self.skip_metrics if mode == "train" else None
 587        for tag in tags:
 588            set_pars(unlabeled=unlabeled, tag=tag)
 589            data_len += len(dataloader)
 590            if verbose:
 591                dataloader = tqdm(dataloader)
 592            for batch in dataloader:
 593                loss, ssl_losses = self._loss_function(
 594                    batch,
 595                    augment_n,
 596                    temporal_subsampling_size=temporal_subsampling_size,
 597                    skip_metrics=skip_metrics,
 598                )
 599                if loss != 0:
 600                    loss = loss * alpha
 601                    epoch_loss += loss.item()
 602                for i, (ssl_loss, weight) in enumerate(
 603                    zip(ssl_losses, self.ssl_weights)
 604                ):
 605                    if ssl_loss != 0:
 606                        epoch_ssl_loss[i] += ssl_loss.item()
 607                        loss = loss + weight * ssl_loss
 608                if mode == "train":
 609                    self.optimizer.zero_grad()
 610                    if loss.requires_grad:
 611                        loss.backward()
 612                    self.optimizer.step()
 613
 614        epoch_loss = epoch_loss / data_len
 615        epoch_ssl_loss = {k: v / data_len for k, v in epoch_ssl_loss.items()}
 616        epoch_metrics = self._calculate_metrics()
 617
 618        return epoch_loss, epoch_ssl_loss, epoch_metrics
 619
 620    def train(
 621        self,
 622        trial: Trial = None,
 623        optimized_metric: str = None,
 624        to_ram: bool = False,
 625        autostop_interval: int = 30,
 626        autostop_threshold: float = 0.001,
 627        autostop_metric: str = None,
 628        main_task_on: bool = True,
 629        ssl_on: bool = True,
 630        temporal_subsampling_size: int = None,
 631        loading_bar: bool = False,
 632    ) -> Tuple:
 633        """Train the task and return a log of epoch-average loss and metric.
 634
  635        You can use the autostop parameters to stop training early when the metrics are not improving. It will be
 636        stopped if the average value of `autostop_metric` over the last `autostop_interval` epochs is smaller than
 637        the average over the previous `autostop_interval` epochs + `autostop_threshold`. For example, if the
 638        current epoch is 120 and `autostop_interval` is 50, the averages over epochs 70-120 and 20-70 will be compared.
 639
 640        Parameters
 641        ----------
 642        trial : Trial
 643            an `optuna` trial (for hyperparameter searches)
 644        optimized_metric : str
 645            the name of the metric being optimized (for hyperparameter searches)
 646        to_ram : bool, default False
 647            if `True`, the dataset will be loaded in RAM (this speeds up the calculations but can lead to crashes
 648            if the dataset is too large)
  649        autostop_interval : int, default 30
  650            the number of epochs to average the autostop metric over
  651        autostop_threshold : float, default 0.001
  652            the autostop difference threshold
  653        autostop_metric : str, optional
  654            the autostop metric (can be any of the tracked metrics or `'loss'`)
 655        main_task_on : bool, default True
 656            if `False`, the main task (action segmentation) will not be used in training
 657        ssl_on : bool, default True
 658            if `False`, the SSL task will not be used in training
 659        temporal_subsampling_size : int, optional
 660            if not `None`, the temporal subsampling will be used in training with the given size
 661        loading_bar : bool, default False
 662            if `True`, a loading bar will be displayed
 663
 664        Returns
 665        -------
 666        loss_log: list
 667            a list of float loss function values for each epoch
 668        metrics_log: dict
 669            a dictionary of metric value logs (first-level keys are 'train' and 'val', second-level keys are metric
 670            names, values are lists of function values)
 671
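             Examples
             --------
             A training sketch (not from the original code; assumes a `Task` instance `task` with a validation
             dataloader configured, and an `f1` metric, which is an assumed metric name):

             >>> loss_log, metrics_log = task.train(
             ...     autostop_metric="loss",
             ...     autostop_interval=30,
             ... )
             >>> loss_log["train"][-1]          # training loss of the last epoch
             >>> metrics_log["val"]["f1"][-1]   # validation F1 of the last epoch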
 672        """
 673        if self.parallel and not isinstance(self.model, nn.DataParallel):
 674            self.model = MyDataParallel(self.model)
 675        self.model.to(self.device)
 676        assert autostop_metric in [None, "loss"] + list(self.metrics)
 677        autostop_interval //= self.validation_interval
 678        if trial is not None and optimized_metric is None:
 679            raise ValueError(
 680                "You need to provide the optimized metric name (optimized_metric parameter) "
 681                "for optuna pruning to work!"
 682            )
 683        if to_ram:
 684            print("transferring datasets to RAM...")
 685            self.train_dataloader.dataset.to_ram()
 686            if self.val_dataloader is not None and len(self.val_dataloader) > 0:
 687                self.val_dataloader.dataset.to_ram()
 688        loss_log = {"train": [], "val": []}
 689        metrics_log = {"train": defaultdict(lambda: []), "val": defaultdict(lambda: [])}
 690        if not main_task_on:
 691            self.model.main_task_off()
 692        if not ssl_on:
 693            self.model.ssl_off()
 694        while self.epoch < self.num_epochs:
 695            self.epoch += 1
 696            unlabeled = None
 697            alpha = 1
 698            if self.pseudolabel:
 699                if self.epoch >= self.T1:
 700                    unlabeled = (self.epoch - self.T1) % self.t != 0
 701                    if unlabeled:
 702                        alpha = self._alpha(self.epoch)
 703                else:
 704                    unlabeled = False
 705            epoch_loss, epoch_ssl_loss, epoch_metrics = self._run_epoch(
 706                dataloader=self.train_dataloader,
 707                mode="train",
 708                augment_n=self.augment_train,
 709                unlabeled=unlabeled,
 710                alpha=alpha,
 711                temporal_subsampling_size=temporal_subsampling_size,
 712                verbose=loading_bar,
 713            )
 714            loss_log["train"].append(epoch_loss)
 715            epoch_string = f"[epoch {self.epoch}]"
 716            if self.pseudolabel:
 717                if unlabeled:
 718                    epoch_string += " (unlabeled)"
 719                else:
 720                    epoch_string += " (labeled)"
 721            epoch_string += f": loss {epoch_loss:.4f}"
 722
 723            if len(epoch_ssl_loss) != 0:
 724                for key, value in sorted(epoch_ssl_loss.items()):
 725                    metrics_log["train"][f"ssl_loss_{key}"].append(value)
 726                    epoch_string += f", ssl_loss_{key} {value:.4f}"
 727
 728            for metric_name, metric_value in sorted(epoch_metrics.items()):
 729                if metric_name not in self.skip_metrics:
 730                    if isinstance(metric_value, list):
 731                        metric_value = torch.mean(torch.Tensor(metric_value))
 732                    epoch_string += f", {metric_name} {metric_value:.3f}"
 733                    metrics_log["train"][metric_name].append(metric_value)
 734
 735            if (
 736                self.val_dataloader is not None
 737                and self.epoch % self.validation_interval == 0
 738            ):
 739                with torch.no_grad():
 740                    epoch_string += "\n"
 741                    (
 742                        val_epoch_loss,
 743                        val_epoch_ssl_loss,
 744                        val_epoch_metrics,
 745                    ) = self._run_epoch(
 746                        dataloader=self.val_dataloader,
 747                        mode="val",
 748                        augment_n=self.augment_val,
 749                    )
 750                    loss_log["val"].append(val_epoch_loss)
 751                    epoch_string += f"validation: loss {val_epoch_loss:.4f}"
 752
 753                    if len(val_epoch_ssl_loss) != 0:
 754                        for key, value in sorted(val_epoch_ssl_loss.items()):
 755                            metrics_log["val"][f"ssl_loss_{key}"].append(value)
 756                            epoch_string += f", ssl_loss_{key} {value:.4f}"
 757
 758                    for metric_name, metric_value in sorted(val_epoch_metrics.items()):
 759                        if isinstance(metric_value, list):
 760                            metric_value = torch.mean(torch.Tensor(metric_value))
 761                        metrics_log["val"][metric_name].append(metric_value)
 762                        epoch_string += f", {metric_name} {metric_value:.3f}"
 763
 764                if trial is not None:
 765                    if optimized_metric not in metrics_log["val"]:
 766                        raise ValueError(
 767                            f"The {optimized_metric} metric set for optimization is not being logged!"
 768                        )
 769                    trial.report(metrics_log["val"][optimized_metric][-1], self.epoch)
 770                    if trial.should_prune():
 771                        raise TrialPruned()
 772
 773            if self.verbose:
 774                print(epoch_string)
 775
 776            if self.log_file is not None:
 777                with open(self.log_file, "a") as f:
 778                    f.write(epoch_string + "\n")
 779
 780            save_condition = (
 781                (self.model_save_epochs != 0)
 782                and (self.epoch % self.model_save_epochs == 0)
 783            ) or (self.epoch == self.num_epochs)
 784
 785            if self.epoch > 0 and save_condition and self.model_save_path is not None:
 786                epoch_s = str(self.epoch).zfill(len(str(self.num_epochs)))
 787                self.save_checkpoint(
 788                    os.path.join(self.model_save_path, f"epoch{epoch_s}.pt")
 789                )
 790
 791            if self.pseudolabel and self.epoch >= self.T1 and not unlabeled:
 792                self._set_pseudolabels()
 793
 794            if autostop_metric == "loss":
 795                if len(loss_log["val"]) > autostop_interval * 2:
 796                    if (
 797                        np.mean(loss_log["val"][-autostop_interval:])
 798                        < np.mean(
 799                            loss_log["val"][-2 * autostop_interval : -autostop_interval]
 800                        )
 801                        + autostop_threshold
 802                    ):
 803                        break
 804            elif autostop_metric in metrics_log["val"]:
 805                if len(metrics_log["val"][autostop_metric]) > autostop_interval * 2:
 806                    if (
 807                        np.mean(
 808                            metrics_log["val"][autostop_metric][-autostop_interval:]
 809                        )
 810                        < np.mean(
 811                            metrics_log["val"][autostop_metric][
 812                                -2 * autostop_interval : -autostop_interval
 813                            ]
 814                        )
 815                        + autostop_threshold
 816                    ):
 817                        break
 818
 819        metrics_log = {k: dict(v) for k, v in metrics_log.items()}
 820
 821        return loss_log, metrics_log
 822
 823    def evaluate_prediction(
 824        self,
 825        prediction: Union[torch.Tensor, Dict],
 826        data: Union[DataLoader, BehaviorDataset, str] = None,
 827        batch_size: int = 32,
 828        indices: list = None,
 829    ) -> Tuple:
 830        """Compute metrics for a prediction.
 831
 832        Parameters
 833        ----------
  834        prediction : torch.Tensor | dict
  835            the prediction (either a tensor or a nested full-length prediction dictionary)
  836        data : torch.utils.data.DataLoader | dlc2action.data.dataset.BehaviorDataset, optional
  837            the data the prediction was made for (if not provided, the validation dataset is used)
  838        batch_size : int, default 32
  839            the batch size
             indices : list, optional
                 a list of class indices used to reorder the prediction channels
 840
 841        Returns
 842        -------
 843        loss : float
 844            the average value of the loss function
 845        metric : dict
 846            a dictionary of average values of metric functions
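
             Examples
             --------
             A usage sketch (evaluates probabilities predicted for the default validation data):

             >>> prediction = task.predict(raw_output=True)
             >>> _, metrics = task.evaluate_prediction(prediction)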
 847        """
 848
  849        if type(data) is not DataLoader:
  850            dataset = self._get_dataset(data)
  851            data = DataLoader(dataset, shuffle=False, batch_size=batch_size)
             else:
                 dataset = data.dataset
 852        epoch_loss = 0
 853        if isinstance(prediction, dict):
 854            num_classes = len(self.behaviors_dict())
 855            length = dataset.len_segment()
 856            coords = dataset.annotation_store.get_original_coordinates()
 857            for batch in data:
 858                main_target = batch["target"]
 859                pr_coords = coords[batch["index"]]
 860                predicted = torch.zeros((len(pr_coords), num_classes, length))
 861                for i, c in enumerate(pr_coords):
 862                    video_id = dataset.input_store.get_video_id(c)
 863                    clip_id = dataset.input_store.get_clip_id(c)
 864                    start, end = dataset.input_store.get_clip_start_end(c)
 865                    beh_ind = list(prediction[video_id]["classes"].keys())
 866                    pred_tmp = prediction[video_id][clip_id][beh_ind, :]
 867                    predicted[i, :, : end - start] = pred_tmp[:, start:end]
 868                self._compute(
 869                    [],
 870                    [],
 871                    predicted,
 872                    main_target,
 873                    skip_loss=True,
 874                    tag=batch.get("tag"),
 875                    apply_primary_function=False,
 876                )
 877        else:
 878            for batch in data:
 879                main_target = batch["target"]
 880                predicted = prediction[batch["index"]]
  881                if indices is not None:
 882                    indices_new = [indices.index(i) for i in range(len(indices))]
 883                    predicted = predicted[:, indices_new]
 884                self._compute(
 885                    [],
 886                    [],
 887                    predicted,
 888                    main_target,
 889                    skip_loss=True,
 890                    tag=batch.get("tag"),
 891                    apply_primary_function=False,
 892                )
 893        epoch_metrics = self._calculate_metrics()
 900        return epoch_loss, epoch_metrics
 901
 902    def evaluate(
 903        self,
 904        data: Union[DataLoader, BehaviorDataset, str] = None,
 905        augment_n: int = 0,
 906        batch_size: int = 32,
 907        verbose: bool = True,
 908    ) -> Tuple:
 909        """Evaluate the Task model.
 910
 911        Parameters
 912        ----------
 913        data : torch.utils.data.DataLoader | dlc2action.data.dataset.BehaviorDataset, optional
 914            the data to evaluate on (if not provided, evaluate on the Task validation dataset)
 915        augment_n : int, default 0
 916            the number of augmentations to average results over
 917        batch_size : int, default 32
 918            the batch size
 919        verbose : bool, default True
 920            if True, the process is reported to standard output
 921
 922        Returns
 923        -------
 924        loss : float
 925            the average value of the loss function
 926        ssl_loss : float
 927            the average value of the SSL loss function
 928        metric : dict
 929            a dictionary of average values of metric functions
 930
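             Examples
             --------
             A usage sketch (evaluates on the validation dataloader by default):

             >>> loss, ssl_loss, metrics = task.evaluate(augment_n=0, batch_size=32)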
 931        """
 932        if self.parallel and not isinstance(self.model, nn.DataParallel):
 933            self.model = MyDataParallel(self.model)
 934        self.model.to(self.device)
 935        if type(data) is not DataLoader:
 936            data = self._get_dataset(data)
 937            data = DataLoader(data, shuffle=False, batch_size=batch_size)
 938        with torch.no_grad():
 939            epoch_loss, epoch_ssl_loss, epoch_metrics = self._run_epoch(
 940                dataloader=data, mode="val", augment_n=augment_n, verbose=verbose
 941            )
 942        val_string = f"loss {epoch_loss:.4f}"
 943        for metric_name, metric_value in sorted(epoch_metrics.items()):
 944            val_string += f", {metric_name} {metric_value:.3f}"
 945        print(val_string)
 946        return epoch_loss, epoch_ssl_loss, epoch_metrics
 947
 948    def predict(
 949        self,
 950        data: Union[DataLoader, BehaviorDataset, str] = None,
 951        raw_output: bool = False,
 952        apply_primary_function: bool = True,
 953        augment_n: int = 0,
 954        batch_size: int = 32,
 955        train_mode: bool = False,
 956        to_ram: bool = False,
 957        embedding: bool = False,
 958    ) -> torch.Tensor:
 959        """Make a prediction with the Task model.
 960
 961        Parameters
 962        ----------
 963        data : torch.utils.data.DataLoader | dlc2action.data.dataset.BehaviorDataset | str, optional
 964            the data to evaluate on (if not provided, evaluate on the Task validation dataset)
 965        raw_output : bool, default False
 966            if `True`, the raw predicted probabilities are returned
 967        apply_primary_function : bool, default True
 968            if `True`, the primary predict function is applied (to map the model output into a shape corresponding to
 969            the input)
 970        augment_n : int, default 0
 971            the number of augmentations to average results over
 972        batch_size : int, default 32
 973            the batch size
 974        train_mode : bool, default False
 975            if `True`, the model is used in training mode (affects dropout and batch normalization layers)
 976        to_ram : bool, default False
 977            if `True`, the dataset will be loaded in RAM (this speeds up the calculations but can lead to crashes
 978            if the dataset is too large)
 979        embedding : bool, default False
 980            if `True`, the output of feature extractor is returned, ignoring the prediction module of the model
 981
 982        Returns
 983        -------
 984        prediction : torch.Tensor
 985            a prediction for the input data
 986
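             Examples
             --------
             A usage sketch (class predictions for the validation set; pass `raw_output=True` to get
             probabilities instead):

             >>> classes = task.predict("val", augment_n=5)
             >>> probabilities = task.predict("val", raw_output=True, apply_primary_function=True)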
 987        """
 988        if self.parallel and not isinstance(self.model, nn.DataParallel):
 989            self.model = MyDataParallel(self.model)
 990        self.model.to(self.device)
 991        if train_mode:
 992            self.model.train()
 993        else:
 994            self.model.eval()
 995        output = []
 996        if embedding:
 997            raw_output = True
 998            apply_primary_function = True
 999        if type(data) is not DataLoader:
1000            data = self._get_dataset(data)
1001            if to_ram:
1002                print("transferring dataset to RAM...")
1003                data.to_ram()
1004            data = DataLoader(data, shuffle=False, batch_size=batch_size)
1005        self.model.ssl_off()
1006        with torch.no_grad():
1007            for batch in tqdm(data):
1008                input = {k: v.to(self.device) for k, v in batch["input"].items()}
1009                predicted, _, _ = self._get_prediction(
1010                    input,
1011                    batch.get("tag"),
1012                    augment_n=augment_n,
1013                    embedding=embedding,
1014                )
1015                if apply_primary_function:
1016                    predicted = self.primary_predict_function(predicted)
1017                if not raw_output:
1018                    predicted = self.predict_function(predicted)
1019                output.append(predicted.detach().cpu())
1020        self.model.ssl_on()
1021        output = torch.cat(output).detach()
1022        return output
1023
1024    def dataset(self, mode="train") -> BehaviorDataset:
1025        """Get a dataset.
1026
1027        Parameters
1028        ----------
 1029        mode : {'train', 'val', 'test'}
1030            the dataset to get
1031
1032        Returns
1033        -------
1034        dataset : dlc2action.data.dataset.BehaviorDataset
1035            the dataset
1036
1037        """
1038        dataloader = self.dataloader(mode)
1039        if dataloader is None:
1040            raise ValueError("The length of the dataloader is 0!")
1041        return dataloader.dataset
1042
1043    def dataloader(self, mode: str = "train") -> DataLoader:
1044        """Get a dataloader.
1045
1046        Parameters
1047        ----------
 1048        mode : {'train', 'val', 'test'}
1049            the dataset to get
1050
1051        Returns
1052        -------
1053        dataloader : torch.utils.data.DataLoader
1054            the dataloader
1055
1056        """
1057        if mode == "train":
1058            return self.train_dataloader
1059        elif mode == "val":
1060            return self.val_dataloader
1061        elif mode == "test":
1062            return self.test_dataloader
1063        else:
1064            raise ValueError(
1065                f'The {mode} mode is not recognized, please choose from "train", "val" or "test"'
1066            )
1067
1068    def _get_dataset(self, dataset):
1069        """Get a dataset from a dataloader, a string ('train', 'test' or 'val') or `None` (default)."""
1070        if dataset is None:
1071            dataset = self.dataset("val")
1072        elif dataset in ["train", "val", "test"]:
1073            dataset = self.dataset(dataset)
1074        elif type(dataset) is DataLoader:
1075            dataset = dataset.dataset
1076        if type(dataset) is BehaviorDataset:
1077            return dataset
1078        else:
1079            raise TypeError(f"The {type(dataset)} type of dataset is not recognized!")
1080
1081    def _get_dataloader(self, dataset):
1082        """Get a dataloader from a dataset, a string ('train', 'test' or 'val') or `None` (default)."""
1083        if dataset is None:
1084            dataset = self.dataloader("val")
1085        elif dataset in ["train", "val", "test"]:
1086            dataset = self.dataloader(dataset)
1087            if dataset is None:
1088                raise ValueError(f"The length of the dataloader is 0!")
1089        elif type(dataset) is BehaviorDataset:
1090            dataset = DataLoader(dataset)
1091        if type(dataset) is DataLoader:
1092            return dataset
1093        else:
1094            raise TypeError(f"The {type(dataset)} type of dataset is not recognized!")
1095
1096    def generate_full_length_prediction(
1097        self, dataset=None, batch_size=32, augment_n=10
1098    ):
1099        """Compile a prediction for the original input sequences.
1100
1101        Parameters
1102        ----------
1103        dataset : BehaviorDataset, optional
1104            the dataset to generate a prediction for (if `None`, generate for the validation dataset)
1105        batch_size : int, default 32
1106            the batch size
1107        augment_n : int, default 10
1108            the number of augmentations to average results over
1109
1110        Returns
1111        -------
1112        prediction : dict
1113            a nested dictionary where first level keys are video ids, second level keys are clip ids and values
1114            are prediction tensors
1115
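             Examples
             --------
             A usage sketch (iterates over the nested output dictionary):

             >>> prediction = task.generate_full_length_prediction("val", batch_size=32, augment_n=10)
             >>> for video_id, video_dict in prediction.items():
             ...     for clip_id, pred in video_dict.items():
             ...         print(video_id, clip_id, pred.shape)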
1116        """
1117        dataset = self._get_dataset(dataset)
1118        if not isinstance(dataset, BehaviorDataset):
1119            raise TypeError(
1120                f"The dataset parameter has to be either None, string, "
1121                f"BehaviorDataset or Dataloader, got {type(dataset)}"
1122            )
1123        predicted = self.predict(
1124            dataset,
1125            raw_output=True,
1126            apply_primary_function=True,
1127            augment_n=augment_n,
1128            batch_size=batch_size,
1129        )
1130        predicted = dataset.generate_full_length_prediction(predicted)
1131        predicted = {
1132            v_id: {
1133                clip_id: self._apply_predict_functions(v.unsqueeze(0)).squeeze()
1134                for clip_id, v in video_dict.items()
1135            }
1136            for v_id, video_dict in predicted.items()
1137        }
1138        return predicted
1139
1140    def generate_submission(
1141        self, frame_number_map_file, dataset=None, batch_size=32, augment_n=10
1142    ):
1143        """Generate a MABe-22 style submission dictionary.
1144
1145        Parameters
1146        ----------
1147        frame_number_map_file : str
1148            the path to the frame number map file
1149        dataset : BehaviorDataset, optional
1150            the dataset to generate a prediction for (if `None`, generate for the validation dataset)
1151        batch_size : int, default 32
1152            the batch size
1153        augment_n : int, default 10
1154            the number of augmentations to average results over
1155
1156        Returns
1157        -------
1158        submission : dict
1159            a dictionary with frame number mapping and embeddings
1160
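             Examples
             --------
             A usage sketch (the frame map path is hypothetical; only applicable to MABe-22 style datasets):

             >>> submission = task.generate_submission("frame_number_map.npy", dataset="val")
             >>> submission["embeddings"].shape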
1161        """
1162        dataset = self._get_dataset(dataset)
1163        if not isinstance(dataset, BehaviorDataset):
1164            raise TypeError(
1165                f"The dataset parameter has to be either None, string, "
1166                f"BehaviorDataset or Dataloader, got {type(dataset)}"
1167            )
1168        predicted = self.predict(
1169            dataset,
1170            raw_output=True,
1171            apply_primary_function=True,
1172            augment_n=augment_n,
1173            batch_size=batch_size,
1174            embedding=True,
1175        )
1176        predicted = dataset.generate_full_length_prediction(predicted)
1177        frame_map = np.load(frame_number_map_file, allow_pickle=True).item()
1178        length = frame_map[list(frame_map.keys())[-1]][1]
1179        embeddings = None
1180        for video_id in list(predicted.keys()):
1181            split = video_id.split("--")
1182            if len(split) != 2 or len(predicted[video_id]) > 1:
1183                raise RuntimeError(
1184                    "Generating submissions is only implemented for the mabe22 dataset!"
1185                )
1186            if split[1] not in frame_map:
1187                raise RuntimeError(f"The {split[1]} video is not in the frame map file")
1188            v_id = split[1]
1189            clip_id = list(predicted[video_id].keys())[0]
1190            if embeddings is None:
1191                embeddings = np.zeros((length, predicted[video_id][clip_id].shape[0]))
1192            start, end = frame_map[v_id]
1193            embeddings[start:end, :] = predicted[video_id][clip_id].T
1194            predicted.pop(video_id)
1195        predicted = {
1196            "frame_number_map": frame_map,
1197            "embeddings": embeddings.astype(np.float32),
1198        }
1199        return predicted
1200
1201    def _get_intervals(self, tensor: torch.Tensor) -> torch.Tensor:
1202        """Get a list of True group beginning and end indices from a boolean tensor."""
1203        output, indices = torch.unique_consecutive(tensor, return_inverse=True)
1204        true_indices = torch.where(output)[0]
1205        starts = torch.tensor(
1206            [(indices == i).nonzero(as_tuple=True)[0][0] for i in true_indices]
1207        )
1208        ends = torch.tensor(
1209            [(indices == i).nonzero(as_tuple=True)[0][-1] + 1 for i in true_indices]
1210        )
1211        return torch.stack([starts, ends]).T
1212
1213    def _smooth(self, tensor: torch.Tensor, smooth_interval: int = 1) -> torch.Tensor:
1214        """Get rid of jittering in a non-exclusive classification tensor.
1215
 1216        First, remove intervals of 0 no longer than `smooth_interval`. Then, remove intervals of 1 no longer than
 1217        `smooth_interval`.
1218
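             Examples
             --------
             An illustration of the intended effect on a single-column non-exclusive tensor of shape
             (#frames, #classes) (a sketch, not from the original code; `task` is an existing `Task` instance):

             >>> import torch
             >>> x = torch.tensor([[1], [1], [0], [1], [1], [1], [0], [0], [0], [1]])
             >>> task._smooth(x, smooth_interval=1).squeeze().tolist()
             [1, 1, 1, 1, 1, 1, 0, 0, 0, 0]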
1219        """
1220        if len(tensor.shape) > 1:
 1221            for c in range(tensor.shape[1]):
1222                intervals = self._get_intervals(tensor[:, c] == 0)
1223                interval_lengths = torch.tensor(
1224                    [interval[1] - interval[0] for interval in intervals]
1225                )
1226                short_intervals = intervals[interval_lengths <= smooth_interval]
1227                for start, end in short_intervals:
1228                    tensor[start:end, c] = 1
1229                intervals = self._get_intervals(tensor[:, c] == 1)
1230                interval_lengths = torch.tensor(
1231                    [interval[1] - interval[0] for interval in intervals]
1232                )
1233                short_intervals = intervals[interval_lengths <= smooth_interval]
1234                for start, end in short_intervals:
1235                    tensor[start:end, c] = 0
1236        else:
1237            for c in tensor.unique():
1238                intervals = self._get_intervals(tensor == c)
1239                interval_lengths = torch.tensor(
1240                    [interval[1] - interval[0] for interval in intervals]
1241                )
1242                short_intervals = intervals[interval_lengths <= smooth_interval]
1243                for start, end in short_intervals:
1244                    if start == 0:
 1245                        tensor[start:end] = tensor[end]
1246                    else:
1247                        tensor[start:end] = tensor[start - 1]
1248        return tensor
1249
1250    def _visualize_results_single(
1251        self,
1252        behavior: str,
1253        save_path: str = None,
1254        add_legend: bool = True,
1255        ground_truth: bool = True,
1256        hide_axes: bool = False,
1257        width: int = 10,
1258        whole_video: bool = False,
1259        transparent: bool = False,
1260        dataset: BehaviorDataset = None,
1261        smooth_interval: int = 0,
1262        title: str = None,
1263    ):
1264        """Visualize random predictions.
1265
1266        Parameters
1267        ----------
1268        behavior : str
1269            the behavior to visualize
1270        save_path : str, optional
1271            the path where the plot will be saved
1272        add_legend : bool, default True
1273            if `True`, legend will be added to the plot
1274        ground_truth : bool, default True
1275            if `True`, ground truth will be added to the plot
 1278        hide_axes : bool, default False
 1279            if `True`, the axes will be hidden on the plot
 1282        width : int, default 10
 1283            the width of the plot
1284        whole_video : bool, default False
1285            if `True`, whole videos are plotted instead of segments
1286        transparent : bool, default False
1287            if `True`, the background on the plot is transparent
1288        dataset : BehaviorDataset, optional
1289            the dataset to make the prediction for (if not provided, the validation dataset is used)
1292        smooth_interval : int, default 0
1293            the interval to smooth the predictions over
1294        title : str, optional
1295            the title of the plot
1296
1297        """
1298        if title is None:
1299            title = ""
1300        dataset = self._get_dataset(dataset)
1301        inverse_dict = {v: k for k, v in dataset.behaviors_dict().items()}
1302        label_ind = inverse_dict[behavior]
1303        labels = {1: behavior, -100: "unknown"}
1304        label_keys = [1, -100]
1305        color_list = ["blue", "gray"]
1306        if whole_video:
1307            predicted = self.generate_full_length_prediction(dataset)
1308            keys = list(predicted.keys())
1309        counter = 0
1310        if whole_video:
1311            max_iter = len(keys) * 5
1312        else:
1313            max_iter = len(dataset) * 5
1314        ok = False
1315        while not ok:
1316            counter += 1
1317            if counter > max_iter:
1318                raise RuntimeError(
1319                    "Plotting is taking too many iterations; you should probably make some of the parameters less restrictive"
1320                )
1321            if whole_video:
1322                i = randint(0, len(keys) - 1)
1323                prediction = predicted[keys[i]]
1324                keys_i = list(prediction.keys())
1325                j = randint(0, len(keys_i) - 1)
1326                full_p = prediction[keys_i[j]]
1327                prediction = prediction[keys_i[j]][label_ind]
1328            else:
1329                dataloader = DataLoader(dataset)
1330                i = randint(0, len(dataloader) - 1)
1331                for num, batch in enumerate(dataloader):
1332                    if num == i:
1333                        break
1334                input_data = {k: v.to(self.device) for k, v in batch["input"].items()}
1335                prediction, *_ = self._get_prediction(
1336                    input_data, batch.get("tag"), augment_n=5
1337                )
1338                prediction = self._apply_predict_functions(prediction)
1339                j = randint(0, len(prediction) - 1)
1340                full_p = prediction[j]
1341                prediction = prediction[j][label_ind]
1342            classes = [x for x in torch.unique(prediction) if int(x) in label_keys]
1343            ok = 1 in classes
1344        fig, ax = plt.subplots(figsize=(width, 2))
1345        for c in classes:
1346            c_i = label_keys.index(int(c))
1347            output, indices, counts = torch.unique_consecutive(
1348                prediction == c, return_inverse=True, return_counts=True
1349            )
1350            long_indices = torch.where(output)[0]
1351            res_indices_start = [
1352                (indices == i).nonzero(as_tuple=True)[0][0].item() for i in long_indices
1353            ]
1354            res_indices_end = [
1355                (indices == i).nonzero(as_tuple=True)[0][-1].item() + 1
1356                for i in long_indices
1357            ]
1358            res_indices_len = [
1359                end - start for start, end in zip(res_indices_start, res_indices_end)
1360            ]
1361            ax.broken_barh(
1362                list(zip(res_indices_start, res_indices_len)),
1363                (0, 1),
1364                label=labels[int(c)],
1365                facecolors=color_list[c_i],
1366            )
1367        if ground_truth:
1368            gt = batch["target"][j][label_ind].to(self.device)
1369            classes_gt = [x for x in torch.unique(gt) if int(x) in label_keys]
1370            for c in classes_gt:
1371                c_i = label_keys.index(int(c))
1372                if c in classes:
1373                    behavior = None
1374                else:
1375                    behavior = labels[int(c)]
1376                output, indices, counts = torch.unique_consecutive(
1377                    gt == c, return_inverse=True, return_counts=True
1378                )
1379                long_indices = torch.where(output * (counts > 5))[0]
1380                res_indices_start = [
1381                    (indices == i).nonzero(as_tuple=True)[0][0].item()
1382                    for i in long_indices
1383                ]
1384                res_indices_end = [
1385                    (indices == i).nonzero(as_tuple=True)[0][-1].item() + 1
1386                    for i in long_indices
1387                ]
1388                res_indices_len = [
1389                    end - start
1390                    for start, end in zip(res_indices_start, res_indices_end)
1391                ]
1392                ax.broken_barh(
1393                    list(zip(res_indices_start, res_indices_len)),
1394                    (1.5, 1),
1395                    facecolors=color_list[c_i],
1396                    label=behavior,
1397                )
1398        self._compute(
1399            main_target=batch["target"][j].unsqueeze(0).to(self.device),
1400            predicted=full_p.unsqueeze(0).to(self.device),
1401            ssl_targets=[],
1402            ssl_predicted=[],
1403            skip_loss=True,
1404        )
1405        metrics = self._calculate_metrics()
1406        if smooth_interval > 0:
1407            smoothed = self._smooth(full_p, smooth_interval=smooth_interval)[
1408                label_ind, :
1409            ]
1410            for c in classes:
1411                c_i = label_keys.index(int(c))
1412                output, indices, counts = torch.unique_consecutive(
1413                    smoothed == c, return_inverse=True, return_counts=True
1414                )
1415                long_indices = torch.where(output)[0]
1416                res_indices_start = [
1417                    (indices == i).nonzero(as_tuple=True)[0][0].item()
1418                    for i in long_indices
1419                ]
1420                res_indices_end = [
1421                    (indices == i).nonzero(as_tuple=True)[0][-1].item() + 1
1422                    for i in long_indices
1423                ]
1424                res_indices_len = [
1425                    end - start
1426                    for start, end in zip(res_indices_start, res_indices_end)
1427                ]
1428                ax.broken_barh(
1429                    list(zip(res_indices_start, res_indices_len)),
1430                    (3, 1),
1431                    label=labels[int(c)],
1432                    facecolors=color_list[c_i],
1433                )
1434        keys = list(metrics.keys())
1435        for key in keys:
1436            if key.split("_")[-1] != (str(label_ind)):
1437                metrics.pop(key)
1438        title = [title]
1439        for key, value in metrics.items():
1440            title.append(f"{'_'.join(key.split('_')[: -1])}: {value:.2f}")
1441        title = ", ".join(title)
1442        if not ground_truth:
1443            ax.axes.yaxis.set_visible(False)
1444        else:
1445            ax.set_yticks([0.5, 2])
1446            ax.set_yticklabels(["prediction", "ground truth"])
1447        if add_legend:
1448            ax.legend()
1449        if hide_axes:
1450            plt.axis("off")
1451        plt.title(title)
1452        plt.xlim((0, len(prediction)))
1453        if save_path is not None:
1454            plt.savefig(save_path, transparent=transparent)
1455        plt.show()
1456
1457    def visualize_results(
1458        self,
1459        save_path: str = None,
1460        add_legend: bool = True,
1461        ground_truth: bool = True,
1462        colormap: str = "viridis",
1463        hide_axes: bool = False,
1464        min_classes: int = 1,
1465        width: int = 10,
1466        whole_video: bool = False,
1467        transparent: bool = False,
1468        dataset: Union[BehaviorDataset, DataLoader, str, None] = None,
1469        drop_classes: Set = None,
1470        search_classes: Set = None,
1471        smooth_interval_prediction: int = 0,
1472        font_size: float = None,
1473        num_plots: int = 1,
1474        window_size: int = 400,
1475    ):
1476        """Visualize random predictions.
1477
1478        Parameters
1479        ----------
1480        save_path : str, optional
1481            the path where the plot will be saved
1482        add_legend : bool, default True
1483            if `True`, legend will be added to the plot
1484        ground_truth : bool, default True
1485            if `True`, ground truth will be added to the plot
1486        colormap : str, default 'viridis'
1487            the `matplotlib` colormap to use (or `'dlc2action'` for the built-in dlc2action colormap)
1488        hide_axes : bool, default False
1489            if `True`, the axes will be hidden on the plot
1490        min_classes : int, default 1
1491            the minimum number of classes in a displayed interval
1492        width : int, default 10
1493            the width of the plot
1494        whole_video : bool, default False
1495            if `True`, whole videos are plotted instead of segments
1496        transparent : bool, default False
1497            if `True`, the background on the plot is transparent
1498        dataset : BehaviorDataset | DataLoader | str | None, optional
1499            the dataset to make the prediction for (if not provided, the validation dataset is used)
1500        drop_classes : set, optional
1501            a set of class names to not be displayed
1502        search_classes : set, optional
1503            if given, only intervals where at least one of the classes is in ground truth will be shown
1504        smooth_interval_prediction : int, default 0
1505            if greater than 0, prediction bouts shorter than this number of frames are removed before plotting
            font_size : float, optional
                if given, the font size to use in the plot
            num_plots : int, default 1
                the number of consecutive windows of the chosen clip to plot and save
            window_size : int, default 400
                the width of each plotted window, in frames (ignored if `whole_video` is `True`)
1506
1507        """
1508        if drop_classes is None:
1509            drop_classes = []
1510        dataset = self._get_dataset(dataset)
1511        if dataset.annotation_class() != "exclusive_classification":
1512            raise NotImplementedError(
1513                "Results visualisation is only implemented for exclusive classification datasets!"
1514            )
1515        labels = {
1516            k: v for k, v in dataset.behaviors_dict().items() if v not in drop_classes
1517        }
1518        labels.update({-100: "unknown"})
1519        label_keys = sorted([int(x) for x in labels.keys()])
1520        if search_classes is None:
1521            ok = True
1522        else:
1523            ok = False
1524        classes = []
1525        predicted = self.generate_full_length_prediction(dataset)
1526        keys = list(predicted.keys())
1527        counter = 0
1528        max_iter = len(keys) * 2
1529        while len(classes) < min_classes or not ok:
1530            counter += 1
1531            if counter > max_iter:
1532                raise RuntimeError(
1533                    "Plotting is taking too many iterations; you should probably make some of the parameters less restrictive"
1534                )
1535            i = randint(0, len(keys) - 1)
1536            prediction = predicted[keys[i]]
1537            keys_i = list(prediction.keys())
1538            j = randint(0, len(keys_i) - 1)
1539            prediction = prediction[keys_i[j]]
1540            key1 = keys[i]
1541            key2 = keys_i[j]
1542
1543            if smooth_interval_prediction > 0:
1544                unsmoothed_prediction = deepcopy(prediction)
1545                prediction = self._smooth(prediction, smooth_interval_prediction)
1546                height = 3
1547            else:
1548                height = 2
1549            classes = [
1550                labels[int(x)] for x in torch.unique(prediction) if x in label_keys
1551            ]
1552            if search_classes is not None:
1553                ok = any([x in classes for x in search_classes])
1554        fig, ax = plt.subplots(figsize=(width, height))
1555        cmap = cm.get_cmap(colormap) if colormap != "dlc2action" else None
1556        color_list = (
1557            [cmap(c) for c in np.linspace(0, 1, len(labels))]
1558            if colormap != "dlc2action"
1559            else dlc2action_colormaps["default"]
1560        )
1561
1562        def _plot_prediction(prediction, height, set_labels=True):
1563            for c in label_keys:
1564                c_i = label_keys.index(int(c))
1565                output, indices, counts = torch.unique_consecutive(
1566                    prediction == c, return_inverse=True, return_counts=True
1567                )
1568                long_indices = torch.where(output)[0]
1569                if len(long_indices) == 0:
1570                    continue
1571                res_indices_start = [
1572                    (indices == i).nonzero(as_tuple=True)[0][0].item()
1573                    for i in long_indices
1574                ]
1575                res_indices_end = [
1576                    (indices == i).nonzero(as_tuple=True)[0][-1].item() + 1
1577                    for i in long_indices
1578                ]
1579                res_indices_len = [
1580                    end - start
1581                    for start, end in zip(res_indices_start, res_indices_end)
1582                ]
1583                if set_labels:
1584                    label = labels[int(c)]
1585                else:
1586                    label = None
1587                ax.broken_barh(
1588                    list(zip(res_indices_start, res_indices_len)),
1589                    (height, 1),
1590                    label=label,
1591                    facecolors=color_list[c_i],
1592                )
1593
1594        if smooth_interval_prediction > 0:
1595            _plot_prediction(unsmoothed_prediction, 0)
1596            _plot_prediction(prediction, 1.5, set_labels=False)
1597            gt_height = 3
1598        else:
1599            _plot_prediction(prediction, 0)
1600            gt_height = 1.5
1601        if ground_truth:
1602            gt = dataset.generate_full_length_gt()[key1][key2]
1603            for c in label_keys:
1604                c_i = label_keys.index(int(c))
1605                if labels[int(c)] in classes:
1606                    label = None
1607                else:
1608                    label = labels[int(c)]
1609                output, indices, counts = torch.unique_consecutive(
1610                    gt == c, return_inverse=True, return_counts=True
1611                )
1612                long_indices = torch.where(output)[0]
1613                if len(long_indices) == 0:
1614                    continue
1615                res_indices_start = [
1616                    (indices == i).nonzero(as_tuple=True)[0][0].item()
1617                    for i in long_indices
1618                ]
1619                res_indices_end = [
1620                    (indices == i).nonzero(as_tuple=True)[0][-1].item() + 1
1621                    for i in long_indices
1622                ]
1623                res_indices_len = [
1624                    end - start
1625                    for start, end in zip(res_indices_start, res_indices_end)
1626                ]
1627                ax.broken_barh(
1628                    list(zip(res_indices_start, res_indices_len)),
1629                    (gt_height, 1),
1630                    facecolors=color_list[c_i] if labels[int(c)] != "unknown" else "gray",
1631                    label=label,
1632                )
1633        if not ground_truth:
1634            ax.axes.yaxis.set_visible(False)
1635        else:
1636            if smooth_interval_prediction > 0:
1637                ax.set_yticks([0.5, 2, 3.5])
1638                ax.set_yticklabels(["prediction", "smoothed", "ground truth"])
1639            else:
1640                ax.set_yticks([0.5, 2])
1641                ax.set_yticklabels(["prediction", "ground truth"])
1642        if add_legend:
1643            ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))
1644        if hide_axes:
1645            # plt.axis("off")
1646            plt.box(False)
1647        if font_size is not None:
1648            font = {"size": font_size}
1649            rc("font", **font)
1650        plt.title(f"{key1} -- {key2}")
1651        plt.tight_layout()
1652        for seed in range(num_plots):
1653            if not whole_video:
1654                ax.set_xlim((seed * window_size, (seed + 1) * window_size))
1655            if save_path is not None:
1656                plt.savefig(
1657                    save_path.replace(".svg", f"_{seed}_{key1} -- {key2}.svg"), transparent=transparent
1658                )
1659                print(f"Saved in {save_path.replace('.svg', f'_{seed}_{key1} -- {key2}.svg')}")
1660        plt.show()
1661        plt.close()
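        # Usage sketch (hypothetical argument values; `task` is a Task instance):
        #   task.visualize_results(save_path="segments.svg", min_classes=2,
        #                          smooth_interval_prediction=5, num_plots=3)
        # This picks a random clip with at least two predicted behaviors, plots the raw
        # and smoothed predictions against the ground truth, and saves three consecutive
        # 400-frame windows of the plot.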
1662
1663    def set_ssl_transformations(self, ssl_transformations):
1664        """Set SSL transformations.
1665
1666        Parameters
1667        ----------
1668        ssl_transformations : list
1669            a list of callable SSL transformations
1670
1671        """
1672        self.train_dataloader.dataset.set_ssl_transformations(ssl_transformations)
1673        if self.val_dataloader is not None:
1674            self.val_dataloader.dataset.set_ssl_transformations(ssl_transformations)
1675
1676    def set_ssl_losses(self, ssl_losses: list) -> None:
1677        """Set SSL losses.
1678
1679        Parameters
1680        ----------
1681        ssl_losses : list
1682            a list of callable SSL losses
1683
1684        """
1685        self.ssl_losses = ssl_losses
1686
1687    def set_log(self, log: str) -> None:
1688        """Set the log file.
1689
1690        Parameters
1691        ----------
1692        log: str
1693            the new log file path
1694
1695        """
1696        self.log_file = log
1697
1698    def set_keep_target_none(self, keep_target_none: List) -> None:
1699        """Set the keep_target_none parameter of the transformer.
1700
1701        Parameters
1702        ----------
1703        keep_target_none : list
1704            a list of bool values
1705
1706        """
1707        self.transformer.keep_target_none = keep_target_none
1708
1709    def set_generate_ssl_input(self, generate_ssl_input: list) -> None:
1710        """Set the generate_ssl_input parameter of the transformer.
1711
1712        Parameters
1713        ----------
1714        generate_ssl_input : list
1715            a list of bool values
1716
1717        """
1718        self.transformer.generate_ssl_input = generate_ssl_input
1719
1720    def set_model(self, model: Model) -> None:
1721        """Set a new model.
1722
1723        Parameters
1724        ----------
1725        model: Model
1726            the new model
1727
1728        """
1729        self.epoch = 0
1730        self.model = model
1731        self.optimizer = self.optimizer_class(
1732            model.parameters(), lr=self.lr, weight_decay=self.weight_decay
1733        )
1734        if self.model.process_labels:
1735            self.model.set_behaviors(list(self.behaviors_dict().values()))
1736
1737    def set_dataloaders(
1738        self,
1739        train_dataloader: DataLoader,
1740        val_dataloader: DataLoader = None,
1741        test_dataloader: DataLoader = None,
1742    ) -> None:
1743        """Set new dataloaders.
1744
1745        Parameters
1746        ----------
1747        train_dataloader: torch.utils.data.DataLoader
1748            the new train dataloader
1749        val_dataloader : torch.utils.data.DataLoader
1750            the new validation dataloader
1751        test_dataloader : torch.utils.data.DataLoader
1752            the new test dataloader
1753
1754        """
1755        self.train_dataloader = train_dataloader
1756        self.val_dataloader = val_dataloader
1757        self.test_dataloader = test_dataloader
1758
1759    def set_loss(self, loss: Callable) -> None:
1760        """Set new loss function.
1761
1762        Parameters
1763        ----------
1764        loss: callable
1765            the new loss function
1766
1767        """
1768        self.loss = loss
1769
1770    def set_metrics(self, metrics: dict) -> None:
1771        """Set new metrics.
1772
1773        Parameters
1774        ----------
1775        metrics : dict
1776            the new metric dictionary
1777
1778        """
1779        self.metrics = metrics
1780
1781    def set_transformer(self, transformer: Transformer) -> None:
1782        """Set a new transformer.
1783
1784        Parameters
1785        ----------
1786        transformer: Transformer
1787            the new transformer
1788
1789        """
1790        self.transformer = transformer
1791
1792    def set_predict_functions(
1793        self, primary_predict_function: Callable, predict_function: Callable
1794    ) -> None:
1795        """Set new predict functions.
1796
1797        Parameters
1798        ----------
1799        primary_predict_function : callable
1800            the new primary predict function
1801        predict_function : callable
1802            the new predict function
1803
1804        """
1805        self.primary_predict_function = primary_predict_function
1806        self.predict_function = predict_function
1807
1808    def _set_pseudolabels(self):
1809        """Set pseudolabels."""
1810        self.train_dataloader.dataset.set_unlabeled(True)
1811        predicted = self.predict(
1812            data=self.dataset("train"),
1813            raw_output=False,
1814            augment_n=self.augment_val,
1815            ssl_off=True,
1816        )
1817        self.train_dataloader.dataset.set_annotation(predicted.detach())
1818
1819    def _alpha(self, epoch):
1820        """Get the current pseudolabeling alpha parameter."""
1821        if epoch <= self.T1:
1822            return 0
1823        elif epoch < self.T2:
1824            return self.alpha_f * (epoch - self.T1) / (self.T2 - self.T1)
1825        else:
1826            return self.alpha_f
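        # The resulting pseudolabel weight schedule (sketch, matching the branches above):
        #   alpha(epoch) = 0                                     for epoch <= T1
        #   alpha(epoch) = alpha_f * (epoch - T1) / (T2 - T1)    for T1 < epoch < T2
        #   alpha(epoch) = alpha_f                               for epoch >= T2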
1827
1828    def count_classes(self, bouts: bool = False) -> Dict:
1829        """Get a dictionary of class counts in different modes.
1830
1831        Parameters
1832        ----------
1833        bouts : bool, default False
1834            if `True`, segment (bout) counts are returned instead of frame counts
1835
1836        Returns
1837        -------
1838        class_counts : dict
1839            a dictionary where first-level keys are "train", "val" and "test", second-level keys are
1840            class names and values are class counts (in frames, or in segments if `bouts` is `True`)
1841
1842        """
1843        class_counts = {}
1844        for x in ["train", "val", "test"]:
1845            try:
1846                class_counts[x] = self.dataset(x).count_classes(bouts)
1847            except ValueError:
1848                class_counts[x] = {k: 0 for k in self.behaviors_dict().keys()}
1849        return class_counts
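        # Example of the returned structure (illustrative behavior names and counts):
        #   {"train": {"walking": 1200, "grooming": 340},
        #    "val": {"walking": 310, "grooming": 95},
        #    "test": {"walking": 0, "grooming": 0}}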
1850
1851    def behaviors_dict(self) -> Dict:
1852        """Get a behavior dictionary.
1853
1854        Keys are label indices and values are label names.
1855
1856        Returns
1857        -------
1858        behaviors_dict : dict
1859            behavior dictionary
1860
1861        """
1862        return self.dataset().behaviors_dict()
1863
1864    def update_parameters(self, parameters: Dict) -> None:
1865        """Update training parameters from a dictionary.
1866
1867        Parameters
1868        ----------
1869        parameters : dict
1870            the update dictionary
1871
1872        """
1873        self.lr = parameters.get("lr", self.lr)
1874        self.parallel = parameters.get("parallel", self.parallel)
1875        self.optimizer = self.optimizer_class(self.model.parameters(), lr=self.lr, weight_decay=self.weight_decay)
1876        self.verbose = parameters.get("verbose", self.verbose)
1877        self.device = parameters.get("device", self.device)
1878        if self.device == "auto":
1879            self.device = "cuda" if torch.cuda.is_available() else "cpu"
1880        self.augment_train = int(parameters.get("augment_train", self.augment_train))
1881        self.augment_val = int(parameters.get("augment_val", self.augment_val))
1882        ssl_weights = parameters.get("ssl_weights", self.ssl_weights)
1883        if ssl_weights is None:
1884            ssl_weights = [1 for _ in self.ssl_losses]
1885        if not isinstance(ssl_weights, Iterable):
1886            ssl_weights = [ssl_weights for _ in self.ssl_losses]
1887        self.ssl_weights = ssl_weights
1888        self.num_epochs = parameters.get("num_epochs", self.num_epochs)
1889        self.model_save_epochs = parameters.get(
1890            "model_save_epochs", self.model_save_epochs
1891        )
1892        self.model_save_path = parameters.get("model_save_path", self.model_save_path)
1893        self.pseudolabel = parameters.get("pseudolabel", self.pseudolabel)
1894        self.T1 = int(parameters.get("pseudolabel_start", self.T1))
1895        self.T2 = int(parameters.get("alpha_growth_stop", self.T2))
1896        self.t = int(parameters.get("correction_interval", self.t))
1897        self.alpha_f = parameters.get("pseudolabel_alpha_f", self.alpha_f)
1898        self.log_file = parameters.get("log_file", self.log_file)
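        # Usage sketch (hypothetical values); keys that are absent from the dictionary
        # keep their current values:
        #   task.update_parameters({"lr": 1e-4, "device": "auto", "num_epochs": 50})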
1899
1900    def generate_uncertainty_score(
1901        self,
1902        classes: List,
1903        augment_n: int = 0,
1904        batch_size: int = 32,
1905        method: str = "least_confidence",
1906        predicted: torch.Tensor = None,
1907        behaviors_dict: Dict = None,
1908    ) -> Dict:
1909        """Generate frame-wise scores for active learning.
1910
1911        Parameters
1912        ----------
1913        classes : list
1914            a list of class names or indices; their confidence scores will be computed separately and stacked
1915        augment_n : int, default 0
1916            the number of augmentations to average over
1917        batch_size : int, default 32
1918            the batch size
1919        method : {"least_confidence", "entropy", "random"}
1920            the method used to calculate the scores from the probability predictions (`"least_confidence"`: `1 - p_i` if
1921            `p_i > 0.5` or `p_i` if `p_i <= 0.5`; `"entropy"`: `- p_i * log(p_i) - (1 - p_i) * log(1 - p_i)`; `"random"`: uniformly random scores)
1922        predicted : torch.Tensor, default None
1923            if not `None`, the predicted tensor to use instead of predicting from the model
1924        behaviors_dict : dict, default None
1925            if not `None`, the behaviors dictionary to use instead of the one from the dataset
1926
1927        Returns
1928        -------
1929        score_dicts : dict
1930            a nested dictionary where first level keys are video ids, second level keys are clip ids and values
1931            are score tensors
1932
1933        """
1934        dataset = self.dataset("train")
1935        if behaviors_dict is None:
1936            behaviors_dict = self.behaviors_dict()
1937        if not isinstance(dataset, BehaviorDataset):
1938            raise TypeError(
1939                f"The dataset parameter has to be either None, string, "
1940                f"BehaviorDataset or Dataloader, got {type(dataset)}"
1941            )
1942        if predicted is None:
1943            predicted = self.predict(
1944                dataset,
1945                raw_output=True,
1946                apply_primary_function=True,
1947                augment_n=augment_n,
1948                batch_size=batch_size,
1949            )
1950        predicted = dataset.generate_full_length_prediction(predicted)
1951        if isinstance(classes[0], str):
1952            behaviors_dict_inverse = {v: k for k, v in behaviors_dict.items()}
1953            classes = [behaviors_dict_inverse[c] for c in classes]
1954        for v_id, v in predicted.items():
1955            for clip_id, vv in v.items():
1956                if method == "least_confidence":
1957                    predicted[v_id][clip_id][vv > 0.5] = 1 - vv[vv > 0.5]
1958                elif method == "entropy":
1959                    predicted[v_id][clip_id][vv != -100] = (
1960                        -vv * torch.log(vv) - (1 - vv) * torch.log(1 - vv)
1961                    )[vv != -100]
1962                elif method == "random":
1963                    predicted[v_id][clip_id] = torch.rand(vv.shape)
1964                else:
1965                    raise ValueError(
1966                        f"The {method} method is not recognized; please choose from ['least_confidence', 'entropy', 'random']"
1967                    )
1968                predicted[v_id][clip_id][vv == -100] = 0
1969
1970        predicted = {
1971            v_id: {clip_id: v[classes, :] for clip_id, v in video_dict.items()}
1972            for v_id, video_dict in predicted.items()
1973        }
1974        return predicted
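        # Worked example for a single frame with probability p = 0.9 (illustrative):
        #   least_confidence: 1 - 0.9 = 0.1 (confident frames get low scores)
        #   entropy: -0.9 * log(0.9) - 0.1 * log(0.1) ≈ 0.325 (natural logarithm)
        # Frames labeled -100 (unknown) always get a score of 0.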
1975
1976    def generate_bald_score(
1977        self,
1978        classes: List,
1979        augment_n: int = 0,
1980        batch_size: int = 32,
1981        num_models: int = 10,
1982        kernel_size: int = 11,
1983    ) -> Dict:
1984        """Generate frame-wise Bayesian Active Learning by Disagreement scores for active learning.
1985
1986        Parameters
1987        ----------
1988        classes : list
1989            a list of class names or indices; their confidence scores will be computed separately and stacked
1990        augment_n : int, default 0
1991            the number of augmentations to average over
1992        batch_size : int, default 32
1993            the batch size
1994        num_models : int, default 10
1995            the number of dropout masks to apply
1996        kernel_size : int, default 11
1997            the size of the smoothing Gaussian kernel
1998
1999        Returns
2000        -------
2001        score_dicts : dict
2002            a nested dictionary where first level keys are video ids, second level keys are clip ids and values
2003            are score tensors
2004
2005        """
2006        dataset = self.dataset("train")
2007        dataset = self._get_dataset(dataset)
2008        if not isinstance(dataset, BehaviorDataset):
2009            raise TypeError(
2010                f"The dataset parameter has to be either None, string, "
2011                f"BehaviorDataset or Dataloader, got {type(dataset)}"
2012            )
2013        predictions = []
2014        for _ in range(num_models):
2015            predicted = self.predict(
2016                dataset,
2017                raw_output=True,
2018                apply_primary_function=True,
2019                augment_n=augment_n,
2020                batch_size=batch_size,
2021                train_mode=True,
2022            )
2023            predicted = dataset.generate_full_length_prediction(predicted)
2024            if isinstance(classes[0], str):
2025                behaviors_dict_inverse = {
2026                    v: k for k, v in self.behaviors_dict().items()
2027                }
2028                classes = [behaviors_dict_inverse[c] for c in classes]
2029            for v_id, v in predicted.items():
2030                for clip_id, vv in v.items():
2031                    vv[vv != -100] = (vv[vv != -100] > 0.5).int().float()
2032                    predicted[v_id][clip_id] = vv
2033            predicted = {
2034                v_id: {clip_id: v[classes, :] for clip_id, v in video_dict.items()}
2035                for v_id, video_dict in predicted.items()
2036            }
2037            predictions.append(predicted)
2038        result = {v_id: {} for v_id in predictions[0]}
2039        r = range(-int(kernel_size / 2), int(kernel_size / 2) + 1)
2040        gauss = [1 / (1 * sqrt(2 * pi)) * exp(-float(x) ** 2 / (2 * 1**2)) for x in r]
2041        gauss = [x / sum(gauss) for x in gauss]
2042        kernel = torch.FloatTensor([[gauss]])
2043        for v_id in predictions[0]:
2044            for clip_id in predictions[0][v_id]:
2045                consensus = (
2046                    (
2047                        torch.mean(
2048                            torch.stack([x[v_id][clip_id] for x in predictions]), dim=0
2049                        )
2050                        > 0.5
2051                    )
2052                    .int()
2053                    .float()
2054                )
2055                consensus[predictions[0][v_id][clip_id] == -100] = -100
2056                result[v_id][clip_id] = torch.zeros(consensus.shape)
2057                for x in predictions:
2058                    result[v_id][clip_id] += (x[v_id][clip_id] != consensus).int()
2059                result[v_id][clip_id] = result[v_id][clip_id] * 2 / num_models
2060                res = torch.zeros(result[v_id][clip_id].shape)
2061                for i in range(len(classes)):
2062                    res[i, floor(kernel_size // 2) : -floor(kernel_size // 2)] = (
2063                        torch.nn.functional.conv1d(
2064                            result[v_id][clip_id][i, :].unsqueeze(0).unsqueeze(0),
2065                            kernel,
2066                        )[0, ...]
2067                    )
2068                result[v_id][clip_id] = res
2069        return result
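        # Sketch of the resulting score (matching the code above): for every frame, the
        # score is twice the fraction of dropout models that disagree with the consensus
        # prediction, smoothed along time with a sigma-1 Gaussian kernel of length
        # `kernel_size`.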
2070
2071    def get_normalization_stats(self) -> Optional[Dict]:
2072        """Get the normalization statistics of the dataset.
2073
2074        Returns
2075        -------
2076        stats : dict
2077            a dictionary containing the mean and standard deviation of the dataset
2078
2079        """
2080        return self.train_dataloader.dataset.stats
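
A minimal usage sketch (illustrative only: the dataloaders, model, loss and metrics below
are placeholders, not objects defined in this module):

    from dlc2action.task.universal_task import Task

    task = Task(
        train_dataloader=train_dl,  # torch.utils.data.DataLoader over a BehaviorDataset
        val_dataloader=val_dl,
        model=my_model,  # a dlc2action.model.base_model.Model instance
        loss=my_loss,  # a callable mapping (prediction, target) to a scalar loss
        metrics=my_metrics,  # a dictionary of dlc2action.metric.base_metric.Metric objects
        num_epochs=100,
        lr=1e-3,
        device="auto",
        model_save_path="checkpoints",
    )
    loss_log, metrics_log = task.train()
    task.save_model("model_final.pt")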
class Task:
  84class Task:
  85    """A universal trainer class that performs training, evaluation and prediction for all types of tasks and data."""
  86
  87    def __init__(
  88        self,
  89        train_dataloader: DataLoader,
  90        model: Union[nn.Module, Model],
  91        loss: Callable[[torch.Tensor, torch.Tensor], float],
  92        num_epochs: int = 0,
  93        transformer: Transformer = None,
  94        ssl_losses: List = None,
  95        ssl_weights: List = None,
  96        lr: float = 1e-3,
  97        weight_decay: float = 0,
  98        metrics: Dict = None,
  99        val_dataloader: DataLoader = None,
 100        test_dataloader: DataLoader = None,
 101        optimizer: Optimizer = None,
 102        device: str = "cuda",
 103        verbose: bool = True,
 104        log_file: Union[str, None] = None,
 105        augment_train: int = 1,
 106        augment_val: int = 0,
 107        validation_interval: int = 1,
 108        predict_function: Union[Callable[[torch.Tensor], torch.Tensor], None] = None,
 109        primary_predict_function: Callable = None,
 110        exclusive: bool = True,
 111        ignore_tags: bool = True,
 112        threshold: float = 0.5,
 113        model_save_path: str = None,
 114        model_save_epochs: int = 5,
 115        pseudolabel: bool = False,
 116        pseudolabel_start: int = 100,
 117        correction_interval: int = 2,
 118        pseudolabel_alpha_f: float = 3,
 119        alpha_growth_stop: int = 600,
 120        parallel: bool = False,
 121        skip_metrics: List = None,
 122    ) -> None:
 123        """Initialize the class.
 124
 125        Parameters
 126        ----------
 127        train_dataloader : torch.utils.data.DataLoader
 128            a training dataloader
 129        model : dlc2action.model.base_model.Model
 130            a model
 131        loss : callable
 132            a loss function
 133        num_epochs : int, default 0
 134            the number of epochs
 135        transformer : dlc2action.transformer.base_transformer.Transformer, optional
 136            a transformer
 137        ssl_losses : list, optional
 138            a list of SSL losses
 139        ssl_weights : list, optional
 140            a list of SSL weights (if not provided initializes to 1)
 141        lr : float, default 1e-3
 142            learning rate
 143        weight_decay : float, default 0
 144            weight decay
 145        metrics : dict, optional
 146            a list of metric functions
 147        val_dataloader : torch.utils.data.DataLoader, optional
 148            a validation dataloader
 149        test_dataloader : torch.utils.data.DataLoader, optional
 150            a test dataloader
 151        optimizer : torch.optim.Optimizer, optional
 152            an optimizer (`Adam` by default)
 153        device : str, default 'cuda'
 154            the device to train the model on
 155        verbose : bool, default True
 156            if `True`, the process is described in standard output
 157        log_file : str, optional
 158            the path to a text file where the process will be logged
 159        augment_train : {1, 0}
 160            number of augmentations to apply at training
 161        augment_val : int, default 0
 162            number of augmentations to apply at validation
 163        validation_interval : int, default 1
 164            every time this number of epochs passes, validation metrics are computed
 165        predict_function : callable, optional
 166            a function that maps probabilities to class predictions (if not provided, a default is generated)
 167        primary_predict_function : callable, optional
 168            a function that maps model output to probabilities (if not provided, initialized as identity)
 169        exclusive : bool, default True
 170            set to False for multi-label classification
 171        ignore_tags : bool, default False
 172            if `True`, samples with different meta tags will be mixed in batches
 173        threshold : float, default 0.5
 174            the threshold used for multi-label classification default prediction function
 175        model_save_path : str, optional
 176            the path to the folder where model checkpoints will be saved (checkpoints will not be saved if the path
 177            is not provided)
 178        model_save_epochs : int, default 5
 179            the interval for saving the model checkpoints (the last epoch is always saved)
 180        pseudolabel : bool, default False
 181            if True, the pseudolabeling procedure will be applied
 182        pseudolabel_start : int, default 100
 183            pseudolabeling starts after this epoch
 184        correction_interval : int, default 1
 185            after this number of epochs, if the pseudolabeling is on, the model is trained on the labeled data and
 186            new pseudolabels are generated
 187        pseudolabel_alpha_f : float, default 3
 188            the maximum value of pseudolabeling alpha
 189        alpha_growth_stop : int, default 600
 190            pseudolabeling alpha stops growing after this epoch
 191        parallel : bool, default False
 192            if True, the model is trained on multiple GPUs
 193        skip_metrics : list, optional
 194            a list of metrics to skip
 195
 196        """
 197        # pseudolabeling might be buggy right now -- not using it!
 198        if skip_metrics is None:
 199            skip_metrics = []
 200        self.train_dataloader = train_dataloader
 201        self.val_dataloader = val_dataloader
 202        self.test_dataloader = test_dataloader
 203        self.transformer = transformer
 204        self.num_epochs = num_epochs
 205        self.skip_metrics = skip_metrics
 206        self.verbose = verbose
 207        self.augment_train = int(augment_train)
 208        self.augment_val = int(augment_val)
 209        self.ignore_tags = ignore_tags
 210        self.validation_interval = int(validation_interval)
 211        self.log_file = log_file
 212        self.loss = loss
 213        self.model_save_path = model_save_path
 214        self.model_save_epochs = model_save_epochs
 215        self.epoch = 0
 216
 217        if metrics is None:
 218            metrics = {}
 219        self.metrics = metrics
 220
 221        if optimizer is None:
 222            optimizer = Adam
 223
 224        if ssl_weights is None:
 225            ssl_weights = [1 for _ in ssl_losses]
 226        if not isinstance(ssl_weights, Iterable):
 227            ssl_weights = [ssl_weights for _ in ssl_losses]
 228        self.ssl_weights = ssl_weights
 229
 230        self.optimizer_class = optimizer
 231        self.lr = lr
 232        self.weight_decay = weight_decay
 233        if not isinstance(model, Model):
 234            self.model = LoadedModel(model=model)
 235        else:
 236            self.set_model(model)
 237        self.parallel = parallel
 238
 239        if self.transformer is None:
 240            self.augment_val = 0
 241            self.augment_train = 0
 242            self.transformer = EmptyTransformer()
 243
 244        if self.augment_train > 1:
 245            warnings.warn(
 246                'The "augment_train" parameter is too large -> setting it to 1.'
 247            )
 248            self.augment_train = 1
 249
 250        try:
 251            if device == "auto":
 252                device = "cuda" if torch.cuda.is_available() else "cpu"
 253            self.device = torch.device(device)
 254        except:
 255            raise ("The format of the device is incorrect")
 256
 257        if ssl_losses is None:
 258            self.ssl_losses = [lambda x, y: 0]
 259        else:
 260            self.ssl_losses = ssl_losses
 261
 262        if primary_predict_function is None:
 263            if exclusive:
 264                primary_predict_function = lambda x: nn.Softmax(x, dim=1)
 265            else:
 266                primary_predict_function = lambda x: torch.sigmoid(x)
 267        self.primary_predict_function = primary_predict_function
 268
 269        if predict_function is None:
 270            if exclusive:
 271                self.predict_function = lambda x: torch.max(x.data, 1)[1]
 272            else:
 273                self.predict_function = lambda x: (x > threshold).int()
 274        else:
 275            self.predict_function = predict_function
 276
 277        self.pseudolabel = pseudolabel
 278        self.alpha_f = pseudolabel_alpha_f
 279        self.T2 = alpha_growth_stop
 280        self.T1 = pseudolabel_start
 281        self.t = correction_interval
 282        if self.T2 <= self.T1:
 283            raise ValueError(
 284                f"The pseudolabel_start parameter has to be smaller than alpha_growth_stop; got "
 285                f"{pseudolabel_start=} and {alpha_growth_stop=}"
 286            )
 287        self.decision_thresholds = [0.5 for x in self.behaviors_dict()]
 288
 289    def save_checkpoint(self, checkpoint_path: str) -> None:
 290        """Save a general checkpoint.
 291
 292        Parameters
 293        ----------
 294        checkpoint_path : str
 295            the path where the checkpoint will be saved
 296
 297        """
 298        torch.save(
 299            {
 300                "epoch": self.epoch,
 301                "model_state_dict": self.model.state_dict(),
 302                "optimizer_state_dict": self.optimizer.state_dict(),
 303            },
 304            checkpoint_path,
 305        )
 306
 307    def load_from_checkpoint(
 308        self, checkpoint_path, only_model: bool = False, load_strict: bool = True
 309    ) -> None:
 310        """Load from a checkpoint.
 311
 312        Parameters
 313        ----------
 314        checkpoint_path : str
 315            the path to the checkpoint
 316        only_model : bool, default False
 317            if `True`, only the model state dictionary will be loaded (and not the epoch and the optimizer state
 318            dictionary)
 319        load_strict : bool, default True
 320            if `True`, any inconsistencies in state dictionaries are regarded as errors
 321
 322        """
 323        if checkpoint_path is None:
 324            return
 325        checkpoint = torch.load(
 326            checkpoint_path, map_location=self.device, weights_only=False
 327        )
 328        self.model.to(self.device)
 329        self.model.load_state_dict(checkpoint["model_state_dict"], strict=load_strict)
 330        if not only_model:
 331            self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
 332            self.epoch = checkpoint["epoch"]
 333
 334    def save_model(self, save_path: str) -> None:
 335        """Save the model state dictionary.
 336
 337        Parameters
 338        ----------
 339        save_path : str
 340            the path where the state will be saved
 341
 342        """
 343        torch.save(self.model.state_dict(), save_path)
 344        print("saved the model successfully")
 345
 346    def _apply_predict_functions(self, predicted):
 347        """Map from model output to prediction."""
 348        predicted = self.primary_predict_function(predicted)
 349        predicted = self.predict_function(predicted)
 350        return predicted
 351
 352    def _get_prediction(
 353        self,
 354        main_input: Dict,
 355        tags: torch.Tensor,
 356        ssl_inputs: List = None,
 357        ssl_targets: List = None,
 358        augment_n: int = 0,
 359        embedding: bool = False,
 360        subsample: List = None,
 361    ) -> Tuple:
 362        """Get the prediction of `self.model` for input averaged over `augment_n` augmentations."""
 363        if augment_n == 0:
 364            augment_n = 1
 365            augment = False
 366        else:
 367            augment = True
 368        model_input, ssl_inputs, ssl_targets = self.transformer.transform(
 369            deepcopy(main_input),
 370            ssl_inputs=ssl_inputs,
 371            ssl_targets=ssl_targets,
 372            augment=augment,
 373            subsample=subsample,
 374        )
 375        if not embedding:
 376            # t1 = time()
 377            predicted, ssl_predicted = self.model(model_input, ssl_inputs, tag=tags)
 378            # t2 = time()
 379            # with open("/home/andy/Documents/EPFLsmartkitchenrecording/baselines/DLC2Action_baselines/inference_times/body_hand_eyes_2405.txt", "a") as f:
 380            # f.write(f"Time: {t2 - t1}\n")
 381        else:
 382            predicted = self.model.extract_features(model_input)
 383            ssl_predicted = None
 384        if self.parallel and predicted is not None and len(predicted.shape) == 4:
 385            predicted = predicted.reshape(
 386                (-1, model_input.shape[0], *predicted.shape[2:])
 387            )
 388        if predicted is not None and augment_n > 1:
 389            self.model.ssl_off()
 390            for i in range(augment_n - 1):
 391                model_input, *_ = self.transformer.transform(
 392                    deepcopy(main_input), augment=augment
 393                )
 394                if not embedding:
 395                    pred, _ = self.model(model_input, None, tag=tags)
 396                else:
 397                    pred = self.model.extract_features(model_input)
 398                predicted += pred.detach()
 399            self.model.ssl_on()
 400            predicted /= augment_n
 401        if self.model.process_labels:
 402            class_embedding = self.model.transform_labels(predicted.device)
 403            beh_dict = self.behaviors_dict()
 404            class_embedding = torch.cat(
 405                [class_embedding[beh_dict[k]] for k in sorted(beh_dict.keys())], 0
 406            )
 407            predicted = {
 408                "video": predicted,
 409                "text": class_embedding,
 410                "logit_scale": self.model.logit_scale(),
 411                "device": predicted.device,
 412            }
 413        return predicted, ssl_predicted, ssl_targets
 414
 415    def _ssl_loss(self, ssl_predicted, ssl_targets):
 416        """Apply SSL losses."""
 417        if self.ssl_losses is None or ssl_predicted is None or ssl_targets is None:
 418            return []
 419
 420        ssl_loss = []
 421        for loss, predicted, target in zip(self.ssl_losses, ssl_predicted, ssl_targets):
 422            ssl_loss.append(loss(predicted, target))
 423        return ssl_loss
 424
 425    def _loss_function(
 426        self,
 427        batch: Dict,
 428        augment_n: int,
 429        temporal_subsampling_size: int = None,
 430        skip_metrics: List = None,
 431    ) -> Tuple[float, float]:
 432        """Calculate the loss function and the metric for a dataloader batch.
 433
 434        Averaging the predictions over augment_n augmentations.
 435
 436        """
 437        if "target" not in batch or torch.isnan(batch["target"]).all():
 438            raise ValueError("Cannot compute loss function with nan targets!")
 439        main_input = {k: v.to(self.device) for k, v in batch["input"].items()}
 440        main_target, ssl_targets, ssl_inputs = None, None, None
 441        main_target = batch["target"].to(self.device)
 442        if "ssl_targets" in batch:
 443            ssl_targets = [
 444                (
 445                    {k: v.to(self.device) for k, v in x.items()}
 446                    if isinstance(x, dict)
 447                    else None
 448                )
 449                for x in batch["ssl_targets"]
 450            ]
 451        if "ssl_inputs" in batch:
 452            ssl_inputs = [
 453                (
 454                    {k: v.to(self.device) for k, v in x.items()}
 455                    if isinstance(x, dict)
 456                    else None
 457                )
 458                for x in batch["ssl_inputs"]
 459            ]
 460        if temporal_subsampling_size is not None:
 461            subsample = sorted(
 462                random.sample(
 463                    range(main_target.shape[-1]),
 464                    int(temporal_subsampling_size * main_target.shape[-1]),
 465                )
 466            )
 467            main_target = main_target[..., subsample]
 468        else:
 469            subsample = None
 470
 471        if self.ignore_tags:
 472            tag = None
 473        else:
 474            tag = batch.get("tag")
 475        predicted, ssl_predicted, ssl_targets = self._get_prediction(
 476            main_input, tag, ssl_inputs, ssl_targets, augment_n, subsample=subsample
 477        )
 478        del main_input, ssl_inputs
 479        return self._compute(
 480            ssl_predicted,
 481            ssl_targets,
 482            predicted,
 483            main_target,
 484            tag=batch.get("tag"),
 485            skip_metrics=skip_metrics,
 486        )
 487
 488    def _compute(
 489        self,
 490        ssl_predicted: List,
 491        ssl_targets: List,
 492        predicted: torch.Tensor,
 493        main_target: torch.Tensor,
 494        tag: Any = None,
 495        skip_loss: bool = False,
 496        apply_primary_function: bool = True,
 497        skip_metrics: List = None,
 498    ) -> Tuple[float, float]:
 499        """Compute the losses and metrics from predictions."""
 500        if skip_metrics is None:
 501            skip_metrics = []
 502        if not skip_loss:
 503            ssl_losses = self._ssl_loss(ssl_predicted, ssl_targets)
 504            if predicted is not None:
 505                loss = self.loss(predicted, main_target)
 506            else:
 507                loss = 0
 508        else:
 509            ssl_losses, loss = [], 0
 510
 511        if predicted is not None:
 512            if isinstance(predicted, dict):
 513                predicted = {
 514                    k: v.detach()
 515                    for k, v in predicted.items()
 516                    if isinstance(v, torch.Tensor)
 517                }
 518            else:
 519                predicted = predicted.detach()
 520            if apply_primary_function:
 521                predicted = self.primary_predict_function(predicted)
 522            predicted_transformed = self.predict_function(predicted)
 523
 524            for name, metric_function in self.metrics.items():
 525                if name not in skip_metrics:
 526                    if metric_function.needs_raw_data:
 527                        metric_function.update(predicted, main_target, tag)
 528                    else:
 529                        metric_function.update(
 530                            predicted_transformed,
 531                            main_target,
 532                            tag,
 533                        )
 534        return loss, ssl_losses
 535
 536    def _calculate_metrics(self) -> Dict:
 537        """Calculate the final values of epoch metrics."""
 538        epoch_metrics = {}
 539        for metric_name, metric in self.metrics.items():
 540            m = metric.calculate()
 541            if type(m) is dict:
 542                for k, v in m.items():
 543                    if type(v) is torch.Tensor:
 544                        v = v.item()
 545                    epoch_metrics[f"{metric_name}_{k}"] = v
 546            else:
 547                if type(m) is torch.Tensor:
 548                    m = m.item()
 549                epoch_metrics[metric_name] = m
 550            metric.reset()
 551        # print("calculate metric", epoch_metrics)
 552        return epoch_metrics
 553
 554    def _run_epoch(
 555        self,
 556        dataloader: DataLoader,
 557        mode: str,
 558        augment_n: int,
 559        verbose: bool = False,
 560        unlabeled: bool = None,
 561        alpha: float = 1,
 562        temporal_subsampling_size: int = None,
 563    ) -> Tuple:
 564        """Run one epoch on dataloader.
 565
 566        Averaging the predictions over augment_n augmentations.
 567        Use "train" mode for training and "val" mode for evaluation.
 568
 569        """
 570        if mode == "train":
 571            self.model.train()
 572        elif mode == "val":
 573            self.model.eval()
 574            pass
 575        else:
 576            raise ValueError(
 577                f'Mode {mode} is not recognized, please choose either "train" for training or "val" for validation'
 578            )
 579        if self.ignore_tags:
 580            tags = [None]
 581        else:
 582            tags = dataloader.dataset.get_tags()
 583        epoch_loss = 0
 584        epoch_ssl_loss = defaultdict(lambda: 0)
 585        data_len = 0
 586        set_pars = dataloader.dataset.set_indexing_parameters
 587        skip_metrics = self.skip_metrics if mode == "train" else None
 588        for tag in tags:
 589            set_pars(unlabeled=unlabeled, tag=tag)
 590            data_len += len(dataloader)
 591            if verbose:
 592                dataloader = tqdm(dataloader)
 593            for batch in dataloader:
 594                loss, ssl_losses = self._loss_function(
 595                    batch,
 596                    augment_n,
 597                    temporal_subsampling_size=temporal_subsampling_size,
 598                    skip_metrics=skip_metrics,
 599                )
 600                if loss != 0:
 601                    loss = loss * alpha
 602                    epoch_loss += loss.item()
 603                for i, (ssl_loss, weight) in enumerate(
 604                    zip(ssl_losses, self.ssl_weights)
 605                ):
 606                    if ssl_loss != 0:
 607                        epoch_ssl_loss[i] += ssl_loss.item()
 608                        loss = loss + weight * ssl_loss
 609                if mode == "train":
 610                    self.optimizer.zero_grad()
 611                    if loss.requires_grad:
 612                        loss.backward()
 613                    self.optimizer.step()
 614
 615        epoch_loss = epoch_loss / data_len
 616        epoch_ssl_loss = {k: v / data_len for k, v in epoch_ssl_loss.items()}
 617        epoch_metrics = self._calculate_metrics()
 618
 619        return epoch_loss, epoch_ssl_loss, epoch_metrics
 620
 621    def train(
 622        self,
 623        trial: Trial = None,
 624        optimized_metric: str = None,
 625        to_ram: bool = False,
 626        autostop_interval: int = 30,
 627        autostop_threshold: float = 0.001,
 628        autostop_metric: str = None,
 629        main_task_on: bool = True,
 630        ssl_on: bool = True,
 631        temporal_subsampling_size: int = None,
 632        loading_bar: bool = False,
 633    ) -> Tuple:
 634        """Train the task and return a log of epoch-average loss and metric.
 635
 636        You can use the autostop parameters to stop training early when the monitored value stops improving. Training
 637        is stopped if the average value of `autostop_metric` over the last `autostop_interval` epochs is smaller than
 638        the average over the previous `autostop_interval` epochs plus `autostop_threshold`. For example, if the
 639        current epoch is 120 and `autostop_interval` is 50, the averages over epochs 70-120 and 20-70 are compared.
 640
 641        Parameters
 642        ----------
 643        trial : Trial
 644            an `optuna` trial (for hyperparameter searches)
 645        optimized_metric : str
 646            the name of the metric being optimized (for hyperparameter searches)
 647        to_ram : bool, default False
 648            if `True`, the dataset will be loaded in RAM (this speeds up the calculations but can lead to crashes
 649            if the dataset is too large)
 650        autostop_interval : int, default 30
 651            the number of epochs to average the autostop metric over
 652        autostop_threshold : float, default 0.001
 653            the autostop difference threshold
 654        autostop_metric : str, optional
 655            the autostop metric (can be any one of the tracked metrics or `'loss'`)
 656        main_task_on : bool, default True
 657            if `False`, the main task (action segmentation) will not be used in training
 658        ssl_on : bool, default True
 659            if `False`, the SSL task will not be used in training
 660        temporal_subsampling_size : int, optional
 661            if not `None`, the temporal subsampling will be used in training with the given size
 662        loading_bar : bool, default False
 663            if `True`, a loading bar will be displayed
 664
 665        Returns
 666        -------
 667        loss_log : dict
 668            a dictionary of loss value logs (keys are 'train' and 'val', values are lists of epoch-average loss values)
 669        metrics_log: dict
 670            a dictionary of metric value logs (first-level keys are 'train' and 'val', second-level keys are metric
 671            names, values are lists of function values)
 672
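            Examples
            --------
            A minimal sketch (assuming `task` is an already constructed `Task` instance; the variable name is
            illustrative):

            >>> loss_log, metrics_log = task.train()
            >>> loss_log["val"][-1]  # epoch-average validation loss of the last epoch
            >>> sorted(metrics_log["val"].keys())  # the tracked validation metrics
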
 673        """
 674        if self.parallel and not isinstance(self.model, nn.DataParallel):
 675            self.model = MyDataParallel(self.model)
 676        self.model.to(self.device)
 677        assert autostop_metric in [None, "loss"] + list(self.metrics)
 678        autostop_interval //= self.validation_interval
 679        if trial is not None and optimized_metric is None:
 680            raise ValueError(
 681                "You need to provide the optimized metric name (optimized_metric parameter) "
 682                "for optuna pruning to work!"
 683            )
 684        if to_ram:
 685            print("transferring datasets to RAM...")
 686            self.train_dataloader.dataset.to_ram()
 687            if self.val_dataloader is not None and len(self.val_dataloader) > 0:
 688                self.val_dataloader.dataset.to_ram()
 689        loss_log = {"train": [], "val": []}
 690        metrics_log = {"train": defaultdict(lambda: []), "val": defaultdict(lambda: [])}
 691        if not main_task_on:
 692            self.model.main_task_off()
 693        if not ssl_on:
 694            self.model.ssl_off()
 695        while self.epoch < self.num_epochs:
 696            self.epoch += 1
 697            unlabeled = None
 698            alpha = 1
 699            if self.pseudolabel:
 700                if self.epoch >= self.T1:
 701                    unlabeled = (self.epoch - self.T1) % self.t != 0
 702                    if unlabeled:
 703                        alpha = self._alpha(self.epoch)
 704                else:
 705                    unlabeled = False
 706            epoch_loss, epoch_ssl_loss, epoch_metrics = self._run_epoch(
 707                dataloader=self.train_dataloader,
 708                mode="train",
 709                augment_n=self.augment_train,
 710                unlabeled=unlabeled,
 711                alpha=alpha,
 712                temporal_subsampling_size=temporal_subsampling_size,
 713                verbose=loading_bar,
 714            )
 715            loss_log["train"].append(epoch_loss)
 716            epoch_string = f"[epoch {self.epoch}]"
 717            if self.pseudolabel:
 718                if unlabeled:
 719                    epoch_string += " (unlabeled)"
 720                else:
 721                    epoch_string += " (labeled)"
 722            epoch_string += f": loss {epoch_loss:.4f}"
 723
 724            if len(epoch_ssl_loss) != 0:
 725                for key, value in sorted(epoch_ssl_loss.items()):
 726                    metrics_log["train"][f"ssl_loss_{key}"].append(value)
 727                    epoch_string += f", ssl_loss_{key} {value:.4f}"
 728
 729            for metric_name, metric_value in sorted(epoch_metrics.items()):
 730                if metric_name not in self.skip_metrics:
 731                    if isinstance(metric_value, list):
 732                        metric_value = torch.mean(torch.Tensor(metric_value))
 733                    epoch_string += f", {metric_name} {metric_value:.3f}"
 734                    metrics_log["train"][metric_name].append(metric_value)
 735
 736            if (
 737                self.val_dataloader is not None
 738                and self.epoch % self.validation_interval == 0
 739            ):
 740                with torch.no_grad():
 741                    epoch_string += "\n"
 742                    (
 743                        val_epoch_loss,
 744                        val_epoch_ssl_loss,
 745                        val_epoch_metrics,
 746                    ) = self._run_epoch(
 747                        dataloader=self.val_dataloader,
 748                        mode="val",
 749                        augment_n=self.augment_val,
 750                    )
 751                    loss_log["val"].append(val_epoch_loss)
 752                    epoch_string += f"validation: loss {val_epoch_loss:.4f}"
 753
 754                    if len(val_epoch_ssl_loss) != 0:
 755                        for key, value in sorted(val_epoch_ssl_loss.items()):
 756                            metrics_log["val"][f"ssl_loss_{key}"].append(value)
 757                            epoch_string += f", ssl_loss_{key} {value:.4f}"
 758
 759                    for metric_name, metric_value in sorted(val_epoch_metrics.items()):
 760                        if isinstance(metric_value, list):
 761                            metric_value = torch.mean(torch.Tensor(metric_value))
 762                        metrics_log["val"][metric_name].append(metric_value)
 763                        epoch_string += f", {metric_name} {metric_value:.3f}"
 764
 765                if trial is not None:
 766                    if optimized_metric not in metrics_log["val"]:
 767                        raise ValueError(
 768                            f"The {optimized_metric} metric set for optimization is not being logged!"
 769                        )
 770                    trial.report(metrics_log["val"][optimized_metric][-1], self.epoch)
 771                    if trial.should_prune():
 772                        raise TrialPruned()
 773
 774            if self.verbose:
 775                print(epoch_string)
 776
 777            if self.log_file is not None:
 778                with open(self.log_file, "a") as f:
 779                    f.write(epoch_string + "\n")
 780
 781            save_condition = (
 782                (self.model_save_epochs != 0)
 783                and (self.epoch % self.model_save_epochs == 0)
 784            ) or (self.epoch == self.num_epochs)
 785
 786            if self.epoch > 0 and save_condition and self.model_save_path is not None:
 787                epoch_s = str(self.epoch).zfill(len(str(self.num_epochs)))
 788                self.save_checkpoint(
 789                    os.path.join(self.model_save_path, f"epoch{epoch_s}.pt")
 790                )
 791
 792            if self.pseudolabel and self.epoch >= self.T1 and not unlabeled:
 793                self._set_pseudolabels()
 794
 795            if autostop_metric == "loss":
 796                if len(loss_log["val"]) > autostop_interval * 2:
 797                    if (
 798                        np.mean(loss_log["val"][-autostop_interval:])
 799                        < np.mean(
 800                            loss_log["val"][-2 * autostop_interval : -autostop_interval]
 801                        )
 802                        + autostop_threshold
 803                    ):
 804                        break
 805            elif autostop_metric in metrics_log["val"]:
 806                if len(metrics_log["val"][autostop_metric]) > autostop_interval * 2:
 807                    if (
 808                        np.mean(
 809                            metrics_log["val"][autostop_metric][-autostop_interval:]
 810                        )
 811                        < np.mean(
 812                            metrics_log["val"][autostop_metric][
 813                                -2 * autostop_interval : -autostop_interval
 814                            ]
 815                        )
 816                        + autostop_threshold
 817                    ):
 818                        break
 819
 820        metrics_log = {k: dict(v) for k, v in metrics_log.items()}
 821
 822        return loss_log, metrics_log
 823
 824    def evaluate_prediction(
 825        self,
 826        prediction: Union[torch.Tensor, Dict],
 827        data: Union[DataLoader, BehaviorDataset, str] = None,
 828        batch_size: int = 32,
 829        indices: list = None,
 830    ) -> Tuple:
 831        """Compute metrics for a prediction.
 832
 833        Parameters
 834        ----------
 835        prediction : torch.Tensor | dict
 836            the prediction (a probability tensor or a nested full-length prediction dictionary)
 837        data : torch.utils.data.DataLoader | dlc2action.data.dataset.BehaviorDataset | str, optional
 838            the data the prediction was made for (if not provided, take the validation dataset)
 839        batch_size : int, default 32
 840            the batch size
            indices : list, optional
                if not `None`, a list mapping the prediction channels to class indices; the prediction is reordered accordingly
 841
 842        Returns
 843        -------
 844        loss : float
 845            the average value of the loss function (always 0 here, since the loss is not computed for external predictions)
 846        metric : dict
 847            a dictionary of average values of metric functions
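
            Examples
            --------
            A minimal sketch (assuming `task` is a trained `Task` instance; the names are illustrative):

            >>> prediction = task.predict("val", raw_output=True, apply_primary_function=True)
            >>> _, metrics = task.evaluate_prediction(prediction, data="val")
            >>> sorted(metrics.keys())
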
 848        """
 849
 850        if type(data) is not DataLoader:
 851            dataset = self._get_dataset(data)
 852            data = DataLoader(dataset, shuffle=False, batch_size=batch_size)
            else:
                dataset = data.dataset
 853        epoch_loss = 0
 854        if isinstance(prediction, dict):
 855            num_classes = len(self.behaviors_dict())
 856            length = dataset.len_segment()
 857            coords = dataset.annotation_store.get_original_coordinates()
 858            for batch in data:
 859                main_target = batch["target"]
 860                pr_coords = coords[batch["index"]]
 861                predicted = torch.zeros((len(pr_coords), num_classes, length))
 862                for i, c in enumerate(pr_coords):
 863                    video_id = dataset.input_store.get_video_id(c)
 864                    clip_id = dataset.input_store.get_clip_id(c)
 865                    start, end = dataset.input_store.get_clip_start_end(c)
 866                    beh_ind = list(prediction[video_id]["classes"].keys())
 867                    pred_tmp = prediction[video_id][clip_id][beh_ind, :]
 868                    predicted[i, :, : end - start] = pred_tmp[:, start:end]
 869                self._compute(
 870                    [],
 871                    [],
 872                    predicted,
 873                    main_target,
 874                    skip_loss=True,
 875                    tag=batch.get("tag"),
 876                    apply_primary_function=False,
 877                )
 878        else:
 879            for batch in data:
 880                main_target = batch["target"]
 881                predicted = prediction[batch["index"]]
 882                if indices is not None:
 883                    indices_new = [indices.index(i) for i in range(len(indices))]
 884                    predicted = predicted[:, indices_new]
 885                self._compute(
 886                    [],
 887                    [],
 888                    predicted,
 889                    main_target,
 890                    skip_loss=True,
 891                    tag=batch.get("tag"),
 892                    apply_primary_function=False,
 893                )
 894        epoch_metrics = self._calculate_metrics()
 895        # strings = [
 896        #     f"{metric_name} {metric_value:.3f}"
 897        #     for metric_name, metric_value in epoch_metrics.items()
 898        # ]
 899        # val_string = ", ".join(sorted(strings))
 900        # print(val_string)
 901        return epoch_loss, epoch_metrics
 902
 903    def evaluate(
 904        self,
 905        data: Union[DataLoader, BehaviorDataset, str] = None,
 906        augment_n: int = 0,
 907        batch_size: int = 32,
 908        verbose: bool = True,
 909    ) -> Tuple:
 910        """Evaluate the Task model.
 911
 912        Parameters
 913        ----------
 914        data : torch.utils.data.DataLoader | dlc2action.data.dataset.BehaviorDataset, optional
 915            the data to evaluate on (if not provided, evaluate on the Task validation dataset)
 916        augment_n : int, default 0
 917            the number of augmentations to average results over
 918        batch_size : int, default 32
 919            the batch size
 920        verbose : bool, default True
 921            if True, the process is reported to standard output
 922
 923        Returns
 924        -------
 925        loss : float
 926            the average value of the loss function
 927        ssl_loss : float
 928            the average value of the SSL loss function
 929        metric : dict
 930            a dictionary of average values of metric functions
 931
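            Examples
            --------
            A minimal sketch (assuming `task` is a trained `Task` instance; the name is illustrative):

            >>> loss, ssl_loss, metrics = task.evaluate("val", augment_n=0)
            >>> metrics  # a dictionary of metric values, e.g. {"f1": 0.7, ...} depending on the configured metrics
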
 932        """
 933        if self.parallel and not isinstance(self.model, nn.DataParallel):
 934            self.model = MyDataParallel(self.model)
 935        self.model.to(self.device)
 936        if type(data) is not DataLoader:
 937            data = self._get_dataset(data)
 938            data = DataLoader(data, shuffle=False, batch_size=batch_size)
 939        with torch.no_grad():
 940            epoch_loss, epoch_ssl_loss, epoch_metrics = self._run_epoch(
 941                dataloader=data, mode="val", augment_n=augment_n, verbose=verbose
 942            )
 943        val_string = f"loss {epoch_loss:.4f}"
 944        for metric_name, metric_value in sorted(epoch_metrics.items()):
 945            val_string += f", {metric_name} {metric_value:.3f}"
 946        print(val_string)
 947        return epoch_loss, epoch_ssl_loss, epoch_metrics
 948
 949    def predict(
 950        self,
 951        data: Union[DataLoader, BehaviorDataset, str] = None,
 952        raw_output: bool = False,
 953        apply_primary_function: bool = True,
 954        augment_n: int = 0,
 955        batch_size: int = 32,
 956        train_mode: bool = False,
 957        to_ram: bool = False,
 958        embedding: bool = False,
 959    ) -> torch.Tensor:
 960        """Make a prediction with the Task model.
 961
 962        Parameters
 963        ----------
 964        data : torch.utils.data.DataLoader | dlc2action.data.dataset.BehaviorDataset | str, optional
 965            the data to evaluate on (if not provided, evaluate on the Task validation dataset)
 966        raw_output : bool, default False
 967            if `True`, the raw predicted probabilities are returned
 968        apply_primary_function : bool, default True
 969            if `True`, the primary predict function is applied (to map the model output into a shape corresponding to
 970            the input)
 971        augment_n : int, default 0
 972            the number of augmentations to average results over
 973        batch_size : int, default 32
 974            the batch size
 975        train_mode : bool, default False
 976            if `True`, the model is used in training mode (affects dropout and batch normalization layers)
 977        to_ram : bool, default False
 978            if `True`, the dataset will be loaded in RAM (this speeds up the calculations but can lead to crashes
 979            if the dataset is too large)
 980        embedding : bool, default False
 981            if `True`, the output of feature extractor is returned, ignoring the prediction module of the model
 982
 983        Returns
 984        -------
 985        prediction : torch.Tensor
 986            a prediction for the input data
 987
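            Examples
            --------
            A minimal sketch (assuming `task` is a trained `Task` instance; the name is illustrative):

            >>> probabilities = task.predict("val", raw_output=True)  # class probabilities
            >>> labels = task.predict("val", raw_output=False)  # output of the predict function
            >>> embeddings = task.predict("val", embedding=True)  # feature extractor output
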
 988        """
 989        if self.parallel and not isinstance(self.model, nn.DataParallel):
 990            self.model = MyDataParallel(self.model)
 991        self.model.to(self.device)
 992        if train_mode:
 993            self.model.train()
 994        else:
 995            self.model.eval()
 996        output = []
 997        if embedding:
 998            raw_output = True
 999            apply_primary_function = True
1000        if type(data) is not DataLoader:
1001            data = self._get_dataset(data)
1002            if to_ram:
1003                print("transferring dataset to RAM...")
1004                data.to_ram()
1005            data = DataLoader(data, shuffle=False, batch_size=batch_size)
1006        self.model.ssl_off()
1007        with torch.no_grad():
1008            for batch in tqdm(data):
1009                input = {k: v.to(self.device) for k, v in batch["input"].items()}
1010                predicted, _, _ = self._get_prediction(
1011                    input,
1012                    batch.get("tag"),
1013                    augment_n=augment_n,
1014                    embedding=embedding,
1015                )
1016                if apply_primary_function:
1017                    predicted = self.primary_predict_function(predicted)
1018                if not raw_output:
1019                    predicted = self.predict_function(predicted)
1020                output.append(predicted.detach().cpu())
1021        self.model.ssl_on()
1022        output = torch.cat(output).detach()
1023        return output
1024
1025    def dataset(self, mode="train") -> BehaviorDataset:
1026        """Get a dataset.
1027
1028        Parameters
1029        ----------
1030        mode : {'train', 'val', 'test'}
1031            the dataset to get
1032
1033        Returns
1034        -------
1035        dataset : dlc2action.data.dataset.BehaviorDataset
1036            the dataset
1037
1038        """
1039        dataloader = self.dataloader(mode)
1040        if dataloader is None:
1041            raise ValueError("The length of the dataloader is 0!")
1042        return dataloader.dataset
1043
1044    def dataloader(self, mode: str = "train") -> DataLoader:
1045        """Get a dataloader.
1046
1047        Parameters
1048        ----------
1049        mode : {'train', 'val', 'test'}
1050            the dataset to get
1051
1052        Returns
1053        -------
1054        dataloader : torch.utils.data.DataLoader
1055            the dataloader
1056
1057        """
1058        if mode == "train":
1059            return self.train_dataloader
1060        elif mode == "val":
1061            return self.val_dataloader
1062        elif mode == "test":
1063            return self.test_dataloader
1064        else:
1065            raise ValueError(
1066                f'The {mode} mode is not recognized, please choose from "train", "val" or "test"'
1067            )
1068
1069    def _get_dataset(self, dataset):
1070        """Get a dataset from a dataloader, a string ('train', 'test' or 'val') or `None` (default)."""
1071        if dataset is None:
1072            dataset = self.dataset("val")
1073        elif dataset in ["train", "val", "test"]:
1074            dataset = self.dataset(dataset)
1075        elif type(dataset) is DataLoader:
1076            dataset = dataset.dataset
1077        if type(dataset) is BehaviorDataset:
1078            return dataset
1079        else:
1080            raise TypeError(f"The {type(dataset)} type of dataset is not recognized!")
1081
1082    def _get_dataloader(self, dataset):
1083        """Get a dataloader from a dataset, a string ('train', 'test' or 'val') or `None` (default)."""
1084        if dataset is None:
1085            dataset = self.dataloader("val")
1086        elif dataset in ["train", "val", "test"]:
1087            dataset = self.dataloader(dataset)
1088            if dataset is None:
1089                raise ValueError("The length of the dataloader is 0!")
1090        elif type(dataset) is BehaviorDataset:
1091            dataset = DataLoader(dataset)
1092        if type(dataset) is DataLoader:
1093            return dataset
1094        else:
1095            raise TypeError(f"The {type(dataset)} type of dataset is not recognized!")
1096
1097    def generate_full_length_prediction(
1098        self, dataset=None, batch_size=32, augment_n=10
1099    ):
1100        """Compile a prediction for the original input sequences.
1101
1102        Parameters
1103        ----------
1104        dataset : BehaviorDataset, optional
1105            the dataset to generate a prediction for (if `None`, generate for the validation dataset)
1106        batch_size : int, default 32
1107            the batch size
1108        augment_n : int, default 10
1109            the number of augmentations to average results over
1110
1111        Returns
1112        -------
1113        prediction : dict
1114            a nested dictionary where first level keys are video ids, second level keys are clip ids and values
1115            are prediction tensors
1116
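            Examples
            --------
            A minimal sketch (assuming `task` is a trained `Task` instance; the names are illustrative):

            >>> prediction = task.generate_full_length_prediction("val", augment_n=5)
            >>> for video_id, video_dict in prediction.items():
            ...     for clip_id, tensor in video_dict.items():
            ...         print(video_id, clip_id, tensor.shape)
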
1117        """
1118        dataset = self._get_dataset(dataset)
1119        if not isinstance(dataset, BehaviorDataset):
1120            raise TypeError(
1121                f"The dataset parameter has to be either None, string, "
1122                f"BehaviorDataset or Dataloader, got {type(dataset)}"
1123            )
1124        predicted = self.predict(
1125            dataset,
1126            raw_output=True,
1127            apply_primary_function=True,
1128            augment_n=augment_n,
1129            batch_size=batch_size,
1130        )
1131        predicted = dataset.generate_full_length_prediction(predicted)
1132        predicted = {
1133            v_id: {
1134                clip_id: self._apply_predict_functions(v.unsqueeze(0)).squeeze()
1135                for clip_id, v in video_dict.items()
1136            }
1137            for v_id, video_dict in predicted.items()
1138        }
1139        return predicted
1140
1141    def generate_submission(
1142        self, frame_number_map_file, dataset=None, batch_size=32, augment_n=10
1143    ):
1144        """Generate a MABe-22 style submission dictionary.
1145
1146        Parameters
1147        ----------
1148        frame_number_map_file : str
1149            the path to the frame number map file
1150        dataset : BehaviorDataset, optional
1151            the dataset to generate a prediction for (if `None`, generate for the validation dataset)
1152        batch_size : int, default 32
1153            the batch size
1154        augment_n : int, default 10
1155            the number of augmentations to average results over
1156
1157        Returns
1158        -------
1159        submission : dict
1160            a dictionary with frame number mapping and embeddings
1161
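            Examples
            --------
            A minimal sketch (assuming `task` is a trained `Task` instance; the file names are illustrative):

            >>> import numpy as np
            >>> submission = task.generate_submission("frame_number_map.npy", dataset="val")
            >>> np.save("submission.npy", submission, allow_pickle=True)
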
1162        """
1163        dataset = self._get_dataset(dataset)
1164        if not isinstance(dataset, BehaviorDataset):
1165            raise TypeError(
1166                f"The dataset parameter has to be either None, string, "
1167                f"BehaviorDataset or Dataloader, got {type(dataset)}"
1168            )
1169        predicted = self.predict(
1170            dataset,
1171            raw_output=True,
1172            apply_primary_function=True,
1173            augment_n=augment_n,
1174            batch_size=batch_size,
1175            embedding=True,
1176        )
1177        predicted = dataset.generate_full_length_prediction(predicted)
1178        frame_map = np.load(frame_number_map_file, allow_pickle=True).item()
1179        length = frame_map[list(frame_map.keys())[-1]][1]
1180        embeddings = None
1181        for video_id in list(predicted.keys()):
1182            split = video_id.split("--")
1183            if len(split) != 2 or len(predicted[video_id]) > 1:
1184                raise RuntimeError(
1185                    "Generating submissions is only implemented for the mabe22 dataset!"
1186                )
1187            if split[1] not in frame_map:
1188                raise RuntimeError(f"The {split[1]} video is not in the frame map file")
1189            v_id = split[1]
1190            clip_id = list(predicted[video_id].keys())[0]
1191            if embeddings is None:
1192                embeddings = np.zeros((length, predicted[video_id][clip_id].shape[0]))
1193            start, end = frame_map[v_id]
1194            embeddings[start:end, :] = predicted[video_id][clip_id].T
1195            predicted.pop(video_id)
1196        predicted = {
1197            "frame_number_map": frame_map,
1198            "embeddings": embeddings.astype(np.float32),
1199        }
1200        return predicted
1201
1202    def _get_intervals(self, tensor: torch.Tensor) -> torch.Tensor:
1203        """Get a list of True group beginning and end indices from a boolean tensor."""
1204        output, indices = torch.unique_consecutive(tensor, return_inverse=True)
1205        true_indices = torch.where(output)[0]
1206        starts = torch.tensor(
1207            [(indices == i).nonzero(as_tuple=True)[0][0] for i in true_indices]
1208        )
1209        ends = torch.tensor(
1210            [(indices == i).nonzero(as_tuple=True)[0][-1] + 1 for i in true_indices]
1211        )
1212        return torch.stack([starts, ends]).T
1213
1214    def _smooth(self, tensor: torch.Tensor, smooth_interval: int = 1) -> torch.Tensor:
1215        """Get rid of jittering in a non-exclusive classification tensor.
1216
1217        First, remove intervals of 0 shorter than `smooth_interval`. Then, remove intervals of 1 shorter than
1218        `smooth_interval`.
1219
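            Examples
            --------
            For a one-dimensional class index tensor, short blips are merged into the surrounding label
            (assuming `task` is a `Task` instance; the name is illustrative):

            >>> import torch
            >>> task._smooth(torch.tensor([0, 0, 1, 0, 0]), smooth_interval=1)
            tensor([0, 0, 0, 0, 0])
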
1220        """
1221        if len(tensor.shape) > 1:
1222            for c in range(tensor.shape[1]):
1223                intervals = self._get_intervals(tensor[:, c] == 0)
1224                interval_lengths = torch.tensor(
1225                    [interval[1] - interval[0] for interval in intervals]
1226                )
1227                short_intervals = intervals[interval_lengths <= smooth_interval]
1228                for start, end in short_intervals:
1229                    tensor[start:end, c] = 1
1230                intervals = self._get_intervals(tensor[:, c] == 1)
1231                interval_lengths = torch.tensor(
1232                    [interval[1] - interval[0] for interval in intervals]
1233                )
1234                short_intervals = intervals[interval_lengths <= smooth_interval]
1235                for start, end in short_intervals:
1236                    tensor[start:end, c] = 0
1237        else:
1238            for c in tensor.unique():
1239                intervals = self._get_intervals(tensor == c)
1240                interval_lengths = torch.tensor(
1241                    [interval[1] - interval[0] for interval in intervals]
1242                )
1243                short_intervals = intervals[interval_lengths <= smooth_interval]
1244                for start, end in short_intervals:
1245                    if start == 0:
1246                        tensor[start:end] = tensor[end + 1]
1247                    else:
1248                        tensor[start:end] = tensor[start - 1]
1249        return tensor
1250
1251    def _visualize_results_single(
1252        self,
1253        behavior: str,
1254        save_path: str = None,
1255        add_legend: bool = True,
1256        ground_truth: bool = True,
1257        hide_axes: bool = False,
1258        width: int = 10,
1259        whole_video: bool = False,
1260        transparent: bool = False,
1261        dataset: BehaviorDataset = None,
1262        smooth_interval: int = 0,
1263        title: str = None,
1264    ):
1265        """Visualize random predictions for a single behavior.
1266
1267        Parameters
1268        ----------
1269        behavior : str
1270            the behavior to visualize
1271        save_path : str, optional
1272            the path where the plot will be saved
1273        add_legend : bool, default True
1274            if `True`, legend will be added to the plot
1275        ground_truth : bool, default True
1276            if `True`, ground truth will be added to the plot
1279        hide_axes : bool, default False
1280            if `True`, the axes will be hidden on the plot
1283        width : int, default 10
1284            the width of the plot
1285        whole_video : bool, default False
1286            if `True`, whole videos are plotted instead of segments
1287        transparent : bool, default False
1288            if `True`, the background on the plot is transparent
1289        dataset : BehaviorDataset, optional
1290            the dataset to make the prediction for (if not provided, the validation dataset is used)
1293        smooth_interval : int, default 0
1294            the interval to smooth the predictions over
1295        title : str, optional
1296            the title of the plot
1297
1298        """
1299        if title is None:
1300            title = ""
1301        dataset = self._get_dataset(dataset)
1302        inverse_dict = {v: k for k, v in dataset.behaviors_dict().items()}
1303        label_ind = inverse_dict[behavior]
1304        labels = {1: behavior, -100: "unknown"}
1305        label_keys = [1, -100]
1306        color_list = ["blue", "gray"]
1307        if whole_video:
1308            predicted = self.generate_full_length_prediction(dataset)
1309            keys = list(predicted.keys())
1310        counter = 0
1311        if whole_video:
1312            max_iter = len(keys) * 5
1313        else:
1314            max_iter = len(dataset) * 5
1315        ok = False
1316        while not ok:
1317            counter += 1
1318            if counter > max_iter:
1319                raise RuntimeError(
1320                    "Plotting is taking too many iterations; you should probably make some of the parameters less restrictive"
1321                )
1322            if whole_video:
1323                i = randint(0, len(keys) - 1)
1324                prediction = predicted[keys[i]]
1325                keys_i = list(prediction.keys())
1326                j = randint(0, len(keys_i) - 1)
1327                full_p = prediction[keys_i[j]]
1328                prediction = prediction[keys_i[j]][label_ind]
1329            else:
1330                dataloader = DataLoader(dataset)
1331                i = randint(0, len(dataloader) - 1)
1332                for num, batch in enumerate(dataloader):
1333                    if num == i:
1334                        break
1335                input_data = {k: v.to(self.device) for k, v in batch["input"].items()}
1336                prediction, *_ = self._get_prediction(
1337                    input_data, batch.get("tag"), augment_n=5
1338                )
1339                prediction = self._apply_predict_functions(prediction)
1340                j = randint(0, len(prediction) - 1)
1341                full_p = prediction[j]
1342                prediction = prediction[j][label_ind]
1343            classes = [x for x in torch.unique(prediction) if int(x) in label_keys]
1344            ok = 1 in classes
1345        fig, ax = plt.subplots(figsize=(width, 2))
1346        for c in classes:
1347            c_i = label_keys.index(int(c))
1348            output, indices, counts = torch.unique_consecutive(
1349                prediction == c, return_inverse=True, return_counts=True
1350            )
1351            long_indices = torch.where(output)[0]
1352            res_indices_start = [
1353                (indices == i).nonzero(as_tuple=True)[0][0].item() for i in long_indices
1354            ]
1355            res_indices_end = [
1356                (indices == i).nonzero(as_tuple=True)[0][-1].item() + 1
1357                for i in long_indices
1358            ]
1359            res_indices_len = [
1360                end - start for start, end in zip(res_indices_start, res_indices_end)
1361            ]
1362            ax.broken_barh(
1363                list(zip(res_indices_start, res_indices_len)),
1364                (0, 1),
1365                label=labels[int(c)],
1366                facecolors=color_list[c_i],
1367            )
1368        if ground_truth:
1369            gt = batch["target"][j][label_ind].to(self.device)
1370            classes_gt = [x for x in torch.unique(gt) if int(x) in label_keys]
1371            for c in classes_gt:
1372                c_i = label_keys.index(int(c))
1373                if c in classes:
1374                    behavior = None
1375                else:
1376                    behavior = labels[int(c)]
1377                output, indices, counts = torch.unique_consecutive(
1378                    gt == c, return_inverse=True, return_counts=True
1379                )
1380                long_indices = torch.where(output * (counts > 5))[0]
1381                res_indices_start = [
1382                    (indices == i).nonzero(as_tuple=True)[0][0].item()
1383                    for i in long_indices
1384                ]
1385                res_indices_end = [
1386                    (indices == i).nonzero(as_tuple=True)[0][-1].item() + 1
1387                    for i in long_indices
1388                ]
1389                res_indices_len = [
1390                    end - start
1391                    for start, end in zip(res_indices_start, res_indices_end)
1392                ]
1393                ax.broken_barh(
1394                    list(zip(res_indices_start, res_indices_len)),
1395                    (1.5, 1),
1396                    facecolors=color_list[c_i],
1397                    label=behavior,
1398                )
1399        self._compute(
1400            main_target=batch["target"][j].unsqueeze(0).to(self.device),
1401            predicted=full_p.unsqueeze(0).to(self.device),
1402            ssl_targets=[],
1403            ssl_predicted=[],
1404            skip_loss=True,
1405        )
1406        metrics = self._calculate_metrics()
1407        if smooth_interval > 0:
1408            smoothed = self._smooth(full_p, smooth_interval=smooth_interval)[
1409                label_ind, :
1410            ]
1411            for c in classes:
1412                c_i = label_keys.index(int(c))
1413                output, indices, counts = torch.unique_consecutive(
1414                    smoothed == c, return_inverse=True, return_counts=True
1415                )
1416                long_indices = torch.where(output)[0]
1417                res_indices_start = [
1418                    (indices == i).nonzero(as_tuple=True)[0][0].item()
1419                    for i in long_indices
1420                ]
1421                res_indices_end = [
1422                    (indices == i).nonzero(as_tuple=True)[0][-1].item() + 1
1423                    for i in long_indices
1424                ]
1425                res_indices_len = [
1426                    end - start
1427                    for start, end in zip(res_indices_start, res_indices_end)
1428                ]
1429                ax.broken_barh(
1430                    list(zip(res_indices_start, res_indices_len)),
1431                    (3, 1),
1432                    label=labels[int(c)],
1433                    facecolors=color_list[c_i],
1434                )
1435        keys = list(metrics.keys())
1436        for key in keys:
1437            if key.split("_")[-1] != (str(label_ind)):
1438                metrics.pop(key)
1439        title = [title]
1440        for key, value in metrics.items():
1441            title.append(f"{'_'.join(key.split('_')[: -1])}: {value:.2f}")
1442        title = ", ".join(title)
1443        if not ground_truth:
1444            ax.axes.yaxis.set_visible(False)
1445        else:
1446            ax.set_yticks([0.5, 2])
1447            ax.set_yticklabels(["prediction", "ground truth"])
1448        if add_legend:
1449            ax.legend()
1450        if hide_axes:
1451            plt.axis("off")
1452        plt.title(title)
1453        plt.xlim((0, len(prediction)))
1454        if save_path is not None:
1455            plt.savefig(save_path, transparent=transparent)
1456        plt.show()
1457
1458    def visualize_results(
1459        self,
1460        save_path: str = None,
1461        add_legend: bool = True,
1462        ground_truth: bool = True,
1463        colormap: str = "viridis",
1464        hide_axes: bool = False,
1465        min_classes: int = 1,
1466        width: int = 10,
1467        whole_video: bool = False,
1468        transparent: bool = False,
1469        dataset: Union[BehaviorDataset, DataLoader, str, None] = None,
1470        drop_classes: Set = None,
1471        search_classes: Set = None,
1472        smooth_interval_prediction: int = None,
1473        font_size: float = None,
1474        num_plots: int = 1,
1475        window_size: int = 400,
1476    ):
1477        """Visualize random predictions.
1478
1479        Parameters
1480        ----------
1481        save_path : str, optional
1482            the path where the plot will be saved
1483        add_legend : bool, default True
1484            if `True`, legend will be added to the plot
1485        ground_truth : bool, default True
1486            if `True`, ground truth will be added to the plot
1487        colormap : str, default 'viridis'
1488            the `matplotlib` colormap to use ('dlc2action' selects the package's default color list)
1489        hide_axes : bool, default False
1490            if `True`, the axes will be hidden on the plot
1491        min_classes : int, default 1
1492            the minimum number of classes in a displayed interval
1493        width : int, default 10
1494            the width of the plot
1495        whole_video : bool, default False
1496            if `True`, whole videos are plotted instead of segments
1497        transparent : bool, default False
1498            if `True`, the background on the plot is transparent
1499        dataset : BehaviorDataset | DataLoader | str | None, optional
1500            the dataset to make the prediction for (if not provided, the validation dataset is used)
1501        drop_classes : set, optional
1502            a set of class names to not be displayed
1503        search_classes : set, optional
1504            if given, only intervals where at least one of the classes is in ground truth will be shown
1505        smooth_interval_prediction : int, optional
1506            if given, prediction intervals of this length or shorter will be smoothed out before plotting
            font_size : float, optional
                if given, the font size to use on the plot
            num_plots : int, default 1
                the number of windows to plot and (if `save_path` is given) save
            window_size : int, default 400
                the width of each plotted window (in frames) when `whole_video` is `False`
1507
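            Examples
            --------
            A minimal sketch (assuming `task` is a trained `Task` instance working with an exclusive
            classification dataset; the file path is illustrative):

            >>> task.visualize_results(
            ...     save_path="segments.svg",
            ...     colormap="viridis",
            ...     smooth_interval_prediction=5,
            ...     num_plots=3,
            ...     window_size=400,
            ... )
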
1508        """
1509        if drop_classes is None:
1510            drop_classes = []
            if smooth_interval_prediction is None:
                smooth_interval_prediction = 0
1511        dataset = self._get_dataset(dataset)
1512        if dataset.annotation_class() != "exclusive_classification":
1513            raise NotImplementedError(
1514                "Results visualisation is only implemented for exclusive classification datasets!"
1515            )
1516        labels = {
1517            k: v for k, v in dataset.behaviors_dict().items() if v not in drop_classes
1518        }
1519        labels.update({-100: "unknown"})
1520        label_keys = sorted([int(x) for x in labels.keys()])
1521        if search_classes is None:
1522            ok = True
1523        else:
1524            ok = False
1525        classes = []
1526        predicted = self.generate_full_length_prediction(dataset)
1527        keys = list(predicted.keys())
1528        counter = 0
1529        max_iter = len(keys) * 2
1530        while len(classes) < min_classes or not ok:
1531            counter += 1
1532            if counter > max_iter:
1533                raise RuntimeError(
1534                    "Plotting is taking too many iterations; you should probably make some of the parameters less restrictive"
1535                )
1536            i = randint(0, len(keys) - 1)
1537            prediction = predicted[keys[i]]
1538            keys_i = list(prediction.keys())
1539            j = randint(0, len(keys_i) - 1)
1540            prediction = prediction[keys_i[j]]
1541            key1 = keys[i]
1542            key2 = keys_i[j]
1543
1544            if smooth_interval_prediction > 0:
1545                unsmoothed_prediction = deepcopy(prediction)
1546                prediction = self._smooth(prediction, smooth_interval_prediction)
1547                height = 3
1548            else:
1549                height = 2
1550            classes = [
1551                labels[int(x)] for x in torch.unique(prediction) if x in label_keys
1552            ]
1553            if search_classes is not None:
1554                ok = any([x in classes for x in search_classes])
1555        fig, ax = plt.subplots(figsize=(width, height))
1556        cmap = cm.get_cmap(colormap) if colormap != "dlc2action" else None
1557        color_list = (
1558            [cmap(c) for c in np.linspace(0, 1, len(labels))]
1559            if colormap != "dlc2action"
1560            else dlc2action_colormaps["default"]
1561        )
1562
1563        def _plot_prediction(prediction, height, set_labels=True):
1564            for c in label_keys:
1565                c_i = label_keys.index(int(c))
1566                output, indices, counts = torch.unique_consecutive(
1567                    prediction == c, return_inverse=True, return_counts=True
1568                )
1569                long_indices = torch.where(output)[0]
1570                if len(long_indices) == 0:
1571                    continue
1572                res_indices_start = [
1573                    (indices == i).nonzero(as_tuple=True)[0][0].item()
1574                    for i in long_indices
1575                ]
1576                res_indices_end = [
1577                    (indices == i).nonzero(as_tuple=True)[0][-1].item() + 1
1578                    for i in long_indices
1579                ]
1580                res_indices_len = [
1581                    end - start
1582                    for start, end in zip(res_indices_start, res_indices_end)
1583                ]
1584                if set_labels:
1585                    label = labels[int(c)]
1586                else:
1587                    label = None
1588                ax.broken_barh(
1589                    list(zip(res_indices_start, res_indices_len)),
1590                    (height, 1),
1591                    label=label,
1592                    facecolors=color_list[c_i],
1593                )
1594
1595        if smooth_interval_prediction > 0:
1596            _plot_prediction(unsmoothed_prediction, 0)
1597            _plot_prediction(prediction, 1.5, set_labels=False)
1598            gt_height = 3
1599        else:
1600            _plot_prediction(prediction, 0)
1601            gt_height = 1.5
1602        if ground_truth:
1603            gt = dataset.generate_full_length_gt()[key1][key2]
1604            for c in label_keys:
1605                c_i = label_keys.index(int(c))
1606                if labels[int(c)] in classes:
1607                    label = None
1608                else:
1609                    label = labels[int(c)]
1610                output, indices, counts = torch.unique_consecutive(
1611                    gt == c, return_inverse=True, return_counts=True
1612                )
1613                long_indices = torch.where(output)[0]
1614                if len(long_indices) == 0:
1615                    continue
1616                res_indices_start = [
1617                    (indices == i).nonzero(as_tuple=True)[0][0].item()
1618                    for i in long_indices
1619                ]
1620                res_indices_end = [
1621                    (indices == i).nonzero(as_tuple=True)[0][-1].item() + 1
1622                    for i in long_indices
1623                ]
1624                res_indices_len = [
1625                    end - start
1626                    for start, end in zip(res_indices_start, res_indices_end)
1627                ]
1628                ax.broken_barh(
1629                    list(zip(res_indices_start, res_indices_len)),
1630                    (gt_height, 1),
1631                    facecolors=color_list[c_i] if labels[int(c)] != "unknown" else "gray",
1632                    label=label,
1633                )
1634        if not ground_truth:
1635            ax.axes.yaxis.set_visible(False)
1636        else:
1637            if smooth_interval_prediction > 0:
1638                ax.set_yticks([0.5, 2, 3.5])
1639                ax.set_yticklabels(["prediction", "smoothed", "ground truth"])
1640            else:
1641                ax.set_yticks([0.5, 2])
1642                ax.set_yticklabels(["prediction", "ground truth"])
1643        if add_legend:
1644            ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))
1645        if hide_axes:
1646            # plt.axis("off")
1647            plt.box(False)
1648        if font_size is not None:
1649            font = {"size": font_size}
1650            rc("font", **font)
1651        plt.title(f"{key1} -- {key2}")
1652        plt.tight_layout()
1653        for seed in range(num_plots):
1654            if not whole_video:
1655                ax.set_xlim((seed * window_size, (seed + 1) * window_size))
1656            if save_path is not None:
1657                plt.savefig(
1658                    save_path.replace(".svg", f"_{seed}_{key1} -- {key2}.svg"), transparent=transparent
1659                )
1660                print(f"Saved in {save_path.replace('.svg', f'_{seed}_{key1} -- {key2}.svg')}")
1661        plt.show()
1662        plt.close()
1663
1664    def set_ssl_transformations(self, ssl_transformations):
1665        """Set SSL transformations.
1666
1667        Parameters
1668        ----------
1669        ssl_transformations : list
1670            a list of callable SSL transformations
1671
1672        """
1673        self.train_dataloader.dataset.set_ssl_transformations(ssl_transformations)
1674        if self.val_dataloader is not None:
1675            self.val_dataloader.dataset.set_ssl_transformations(ssl_transformations)
1676
1677    def set_ssl_losses(self, ssl_losses: list) -> None:
1678        """Set SSL losses.
1679
1680        Parameters
1681        ----------
1682        ssl_losses : list
1683            a list of callable SSL losses
1684
1685        """
1686        self.ssl_losses = ssl_losses
1687
1688    def set_log(self, log: str) -> None:
1689        """Set the log file.
1690
1691        Parameters
1692        ----------
1693        log: str
1694            the new log file path
1695
1696        """
1697        self.log_file = log
1698
1699    def set_keep_target_none(self, keep_target_none: List) -> None:
1700        """Set the keep_target_none parameter of the transformer.
1701
1702        Parameters
1703        ----------
1704        keep_target_none : list
1705            a list of bool values
1706
1707        """
1708        self.transformer.keep_target_none = keep_target_none
1709
1710    def set_generate_ssl_input(self, generate_ssl_input: list) -> None:
1711        """Set the generate_ssl_input parameter of the transformer.
1712
1713        Parameters
1714        ----------
1715        generate_ssl_input : list
1716            a list of bool values
1717
1718        """
1719        self.transformer.generate_ssl_input = generate_ssl_input
1720
1721    def set_model(self, model: Model) -> None:
1722        """Set a new model.
1723
1724        Parameters
1725        ----------
1726        model: Model
1727            the new model
1728
1729        """
1730        self.epoch = 0
1731        self.model = model
1732        self.optimizer = self.optimizer_class(
1733            model.parameters(), lr=self.lr, weight_decay=self.weight_decay
1734        )
1735        if self.model.process_labels:
1736            self.model.set_behaviors(list(self.behaviors_dict().values()))
1737
1738    def set_dataloaders(
1739        self,
1740        train_dataloader: DataLoader,
1741        val_dataloader: DataLoader = None,
1742        test_dataloader: DataLoader = None,
1743    ) -> None:
1744        """Set new dataloaders.
1745
1746        Parameters
1747        ----------
1748        train_dataloader: torch.utils.data.DataLoader
1749            the new train dataloader
1750        val_dataloader : torch.utils.data.DataLoader
1751            the new validation dataloader
1752        test_dataloader : torch.utils.data.DataLoader
1753            the new test dataloader
1754
1755        """
1756        self.train_dataloader = train_dataloader
1757        self.val_dataloader = val_dataloader
1758        self.test_dataloader = test_dataloader
1759
1760    def set_loss(self, loss: Callable) -> None:
1761        """Set new loss function.
1762
1763        Parameters
1764        ----------
1765        loss: callable
1766            the new loss function
1767
1768        """
1769        self.loss = loss
1770
1771    def set_metrics(self, metrics: dict) -> None:
1772        """Set new metrics.
1773
1774        Parameters
1775        ----------
1776        metrics : dict
1777            the new metric dictionary
1778
1779        """
1780        self.metrics = metrics
1781
1782    def set_transformer(self, transformer: Transformer) -> None:
1783        """Set a new transformer.
1784
1785        Parameters
1786        ----------
1787        transformer: Transformer
1788            the new transformer
1789
1790        """
1791        self.transformer = transformer
1792
1793    def set_predict_functions(
1794        self, primary_predict_function: Callable, predict_function: Callable
1795    ) -> None:
1796        """Set new predict functions.
1797
1798        Parameters
1799        ----------
1800        primary_predict_function : callable
1801            the new primary predict function
1802        predict_function : callable
1803            the new predict function
1804
1805        """
1806        self.primary_predict_function = primary_predict_function
1807        self.predict_function = predict_function
1808
1809    def _set_pseudolabels(self):
1810        """Set pseudolabels."""
1811        self.train_dataloader.dataset.set_unlabeled(True)
1812        predicted = self.predict(
1813            data=self.dataset("train"),
1814            raw_output=False,
1815            augment_n=self.augment_val,
1816            ssl_off=True,
1817        )
1818        self.train_dataloader.dataset.set_annotation(predicted.detach())
1819
1820    def _alpha(self, epoch):
1821        """Get the current pseudolabeling alpha parameter."""
1822        if epoch <= self.T1:
1823            return 0
1824        elif epoch < self.T2:
1825            return self.alpha_f * (epoch - self.T1) / (self.T2 - self.T1)
1826        else:
1827            return self.alpha_f
1828
1829    def count_classes(self, bouts: bool = False) -> Dict:
1830        """Get a dictionary of class counts in different modes.
1831
1832        Parameters
1833        ----------
1834        bouts : bool, default False
1835            if `True`, bout (segment) counts are returned instead of frame counts
1836
1837        Returns
1838        -------
1839        class_counts : dict
1840            a dictionary where first-level keys are "train", "val" and "test", second-level keys are
1841            class names and values are class counts (in frames or, if `bouts` is `True`, in segments)
1842
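            Examples
            --------
            A minimal sketch (assuming `task` is a `Task` instance; the name is illustrative):

            >>> frame_counts = task.count_classes()
            >>> bout_counts = task.count_classes(bouts=True)
            >>> frame_counts["train"]  # e.g. {0: 1500, 1: 320, ...}
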
1843        """
1844        class_counts = {}
1845        for x in ["train", "val", "test"]:
1846            try:
1847                class_counts[x] = self.dataset(x).count_classes(bouts)
1848            except ValueError:
1849                class_counts[x] = {k: 0 for k in self.behaviors_dict().keys()}
1850        return class_counts
1851
1852    def behaviors_dict(self) -> Dict:
1853        """Get a behavior dictionary.
1854
1855        Keys are label indices and values are label names.
1856
1857        Returns
1858        -------
1859        behaviors_dict : dict
1860            behavior dictionary
1861
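            Examples
            --------
            A minimal sketch (assuming `task` is a `Task` instance; the behavior names are illustrative):

            >>> task.behaviors_dict()
            {0: 'grooming', 1: 'rearing', 2: 'walking'}
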
1862        """
1863        return self.dataset().behaviors_dict()
1864
1865    def update_parameters(self, parameters: Dict) -> None:
1866        """Update training parameters from a dictionary.
1867
1868        Parameters
1869        ----------
1870        parameters : dict
1871            the update dictionary
1872
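            Examples
            --------
            Only the keys present in the dictionary are updated (assuming `task` is a `Task` instance; the
            values are illustrative):

            >>> task.update_parameters({"lr": 1e-4, "device": "auto", "num_epochs": 100})
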
1873        """
1874        self.lr = parameters.get("lr", self.lr)
1875        self.parallel = parameters.get("parallel", self.parallel)
1876        self.optimizer = self.optimizer_class(self.model.parameters(), lr=self.lr)
1877        self.verbose = parameters.get("verbose", self.verbose)
1878        self.device = parameters.get("device", self.device)
1879        if self.device == "auto":
1880            self.device = "cuda" if torch.cuda.is_available() else "cpu"
1881        self.augment_train = int(parameters.get("augment_train", self.augment_train))
1882        self.augment_val = int(parameters.get("augment_val", self.augment_val))
1883        ssl_weights = parameters.get("ssl_weights", self.ssl_weights)
1884        if ssl_weights is None:
1885            ssl_weights = [1 for _ in self.ssl_losses]
1886        if not isinstance(ssl_weights, Iterable):
1887            ssl_weights = [ssl_weights for _ in self.ssl_losses]
1888        self.ssl_weights = ssl_weights
1889        self.num_epochs = parameters.get("num_epochs", self.num_epochs)
1890        self.model_save_epochs = parameters.get(
1891            "model_save_epochs", self.model_save_epochs
1892        )
1893        self.model_save_path = parameters.get("model_save_path", self.model_save_path)
1894        self.pseudolabel = parameters.get("pseudolabel", self.pseudolabel)
1895        self.T1 = int(parameters.get("pseudolabel_start", self.T1))
1896        self.T2 = int(parameters.get("alpha_growth_stop", self.T2))
1897        self.t = int(parameters.get("correction_interval", self.t))
1898        self.alpha_f = parameters.get("pseudolabel_alpha_f", self.alpha_f)
1899        self.log_file = parameters.get("log_file", self.log_file)
1900
1901    def generate_uncertainty_score(
1902        self,
1903        classes: List,
1904        augment_n: int = 0,
1905        batch_size: int = 32,
1906        method: str = "least_confidence",
1907        predicted: torch.Tensor = None,
1908        behaviors_dict: Dict = None,
1909    ) -> Dict:
1910        """Generate frame-wise scores for active learning.
1911
1912        Parameters
1913        ----------
1914        classes : list
1915            a list of class names or indices; their confidence scores will be computed separately and stacked
1916        augment_n : int, default 0
1917            the number of augmentations to average over
1918        batch_size : int, default 32
1919            the batch size
1920        method : {"least_confidence", "entropy", "random"}
1921            the method used to calculate the scores from the probability predictions (`"least_confidence"`: `1 - p_i` if
1922            `p_i > 0.5` or `p_i` if `p_i < 0.5`; `"entropy"`: `- p_i * log(p_i) - (1 - p_i) * log(1 - p_i)`; `"random"`: uniform random scores)
1923        predicted : torch.Tensor, default None
1924            if not `None`, the predicted tensor to use instead of predicting from the model
1925        behaviors_dict : dict, default None
1926            if not `None`, the behaviors dictionary to use instead of the one from the dataset
1927
1928        Returns
1929        -------
1930        score_dicts : dict
1931            a nested dictionary where first level keys are video ids, second level keys are clip ids and values
1932            are score tensors
1933
1934        """
1935        dataset = self.dataset("train")
1936        if behaviors_dict is None:
1937            behaviors_dict = self.behaviors_dict()
1938        if not isinstance(dataset, BehaviorDataset):
1939            raise TypeError(
1940                f"The dataset parameter has to be either None, string, "
1941                f"BehaviorDataset or Dataloader, got {type(dataset)}"
1942            )
1943        if predicted is None:
1944            predicted = self.predict(
1945                dataset,
1946                raw_output=True,
1947                apply_primary_function=True,
1948                augment_n=augment_n,
1949                batch_size=batch_size,
1950            )
1951        predicted = dataset.generate_full_length_prediction(predicted)
1952        if isinstance(classes[0], str):
1953            behaviors_dict_inverse = {v: k for k, v in behaviors_dict.items()}
1954            classes = [behaviors_dict_inverse[c] for c in classes]
1955        for v_id, v in predicted.items():
1956            for clip_id, vv in v.items():
1957                if method == "least_confidence":
1958                    predicted[v_id][clip_id][vv > 0.5] = 1 - vv[vv > 0.5]
1959                elif method == "entropy":
1960                    predicted[v_id][clip_id][vv != -100] = (
1961                        -vv * torch.log(vv) - (1 - vv) * torch.log(1 - vv)
1962                    )[vv != -100]
1963                elif method == "random":
1964                    predicted[v_id][clip_id] = torch.rand(vv.shape)
1965                else:
1966                    raise ValueError(
1967                        f"The {method} method is not recognized; please choose from ['least_confidence', 'entropy', 'random']"
1968                    )
1969                predicted[v_id][clip_id][vv == -100] = 0
1970
1971        predicted = {
1972            v_id: {clip_id: v[classes, :] for clip_id, v in video_dict.items()}
1973            for v_id, video_dict in predicted.items()
1974        }
1975        return predicted
1976
1977    def generate_bald_score(
1978        self,
1979        classes: List,
1980        augment_n: int = 0,
1981        batch_size: int = 32,
1982        num_models: int = 10,
1983        kernel_size: int = 11,
1984    ) -> Dict:
1985        """Generate frame-wise Bayesian Active Learning by Disagreement scores for active learning.
1986
1987        Parameters
1988        ----------
1989        classes : list
1990            a list of class names or indices; their confidence scores will be computed separately and stacked
1991        augment_n : int, default 0
1992            the number of augmentations to average over
1993        batch_size : int, default 32
1994            the batch size
1995        num_models : int, default 10
1996            the number of dropout masks to apply
1997        kernel_size : int, default 11
1998            the size of the smoothing gaussian kernel
1999
2000        Returns
2001        -------
2002        score_dicts : dict
2003            a nested dictionary where first level keys are video ids, second level keys are clip ids and values
2004            are score tensors
2005
2006        """
2007        dataset = self.dataset("train")
2008        dataset = self._get_dataset(dataset)
2009        if not isinstance(dataset, BehaviorDataset):
2010            raise TypeError(
2011                f"The dataset parameter has to be either None, string, "
2012                f"BehaviorDataset or Dataloader, got {type(dataset)}"
2013            )
2014        predictions = []
2015        for _ in range(num_models):
2016            predicted = self.predict(
2017                dataset,
2018                raw_output=True,
2019                apply_primary_function=True,
2020                augment_n=augment_n,
2021                batch_size=batch_size,
2022                train_mode=True,
2023            )
2024            predicted = dataset.generate_full_length_prediction(predicted)
2025            if isinstance(classes[0], str):
2026                behaviors_dict_inverse = {
2027                    v: k for k, v in self.behaviors_dict().items()
2028                }
2029                classes = [behaviors_dict_inverse[c] for c in classes]
2030            for v_id, v in predicted.items():
2031                for clip_id, vv in v.items():
2032                    vv[vv != -100] = (vv[vv != -100] > 0.5).int().float()
2033                    predicted[v_id][clip_id] = vv
2034            predicted = {
2035                v_id: {clip_id: v[classes, :] for clip_id, v in video_dict.items()}
2036                for v_id, video_dict in predicted.items()
2037            }
2038            predictions.append(predicted)
2039        result = {v_id: {} for v_id in predictions[0]}
2040        r = range(-int(kernel_size / 2), int(kernel_size / 2) + 1)
2041        gauss = [1 / (1 * sqrt(2 * pi)) * exp(-float(x) ** 2 / (2 * 1**2)) for x in r]
2042        gauss = [x / sum(gauss) for x in gauss]
2043        kernel = torch.FloatTensor([[gauss]])
2044        for v_id in predictions[0]:
2045            for clip_id in predictions[0][v_id]:
2046                consensus = (
2047                    (
2048                        torch.mean(
2049                            torch.stack([x[v_id][clip_id] for x in predictions]), dim=0
2050                        )
2051                        > 0.5
2052                    )
2053                    .int()
2054                    .float()
2055                )
2056                consensus[predictions[0][v_id][clip_id] == -100] = -100
2057                result[v_id][clip_id] = torch.zeros(consensus.shape)
2058                for x in predictions:
2059                    result[v_id][clip_id] += (x[v_id][clip_id] != consensus).int()
2060                result[v_id][clip_id] = result[v_id][clip_id] * 2 / num_models
2061                res = torch.zeros(result[v_id][clip_id].shape)
2062                for i in range(len(classes)):
2063                    res[i, floor(kernel_size // 2) : -floor(kernel_size // 2)] = (
2064                        torch.nn.functional.conv1d(
2065                            result[v_id][clip_id][i, :].unsqueeze(0).unsqueeze(0),
2066                            kernel,
2067                        )[0, ...]
2068                    )
2069                result[v_id][clip_id] = res
2070        return result
2071
2072    def get_normalization_stats(self) -> Optional[Dict]:
2073        """Get the normalization statistics of the dataset.
2074
2075        Returns
2076        -------
2077        stats : dict
2078            a dictionary containing the mean and standard deviation of the dataset
2079
2080        """
2081        return self.train_dataloader.dataset.stats
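
The active-learning helpers in this module (`generate_uncertainty_score` and `generate_bald_score`) can be called on a trained task. A minimal sketch, assuming `task` is an already trained `Task` instance; the class indices below are illustrative:

    # frame-wise uncertainty scores on the training set for classes 0 and 1
    scores = task.generate_uncertainty_score(classes=[0, 1], augment_n=0, method="entropy")

    # Monte-Carlo-dropout disagreement (BALD) scores for the same classes
    bald_scores = task.generate_bald_score(classes=[0, 1], num_models=10, kernel_size=11)

Both calls return nested dictionaries keyed by video id and clip id, as described in the docstrings above.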

A universal trainer class that performs training, evaluation and prediction for all types of tasks and data.

Task( train_dataloader: torch.utils.data.dataloader.DataLoader, model: Union[torch.nn.modules.module.Module, dlc2action.model.base_model.Model], loss: Callable[[torch.Tensor, torch.Tensor], float], num_epochs: int = 0, transformer: dlc2action.transformer.base_transformer.Transformer = None, ssl_losses: List = None, ssl_weights: List = None, lr: float = 0.001, weight_decay: float = 0, metrics: Dict = None, val_dataloader: torch.utils.data.dataloader.DataLoader = None, test_dataloader: torch.utils.data.dataloader.DataLoader = None, optimizer: torch.optim.optimizer.Optimizer = None, device: str = 'cuda', verbose: bool = True, log_file: Optional[str] = None, augment_train: int = 1, augment_val: int = 0, validation_interval: int = 1, predict_function: Optional[Callable[[torch.Tensor], torch.Tensor]] = None, primary_predict_function: Callable = None, exclusive: bool = True, ignore_tags: bool = True, threshold: float = 0.5, model_save_path: str = None, model_save_epochs: int = 5, pseudolabel: bool = False, pseudolabel_start: int = 100, correction_interval: int = 2, pseudolabel_alpha_f: float = 3, alpha_growth_stop: int = 600, parallel: bool = False, skip_metrics: List = None)
 87    def __init__(
 88        self,
 89        train_dataloader: DataLoader,
 90        model: Union[nn.Module, Model],
 91        loss: Callable[[torch.Tensor, torch.Tensor], float],
 92        num_epochs: int = 0,
 93        transformer: Transformer = None,
 94        ssl_losses: List = None,
 95        ssl_weights: List = None,
 96        lr: float = 1e-3,
 97        weight_decay: float = 0,
 98        metrics: Dict = None,
 99        val_dataloader: DataLoader = None,
100        test_dataloader: DataLoader = None,
101        optimizer: Optimizer = None,
102        device: str = "cuda",
103        verbose: bool = True,
104        log_file: Union[str, None] = None,
105        augment_train: int = 1,
106        augment_val: int = 0,
107        validation_interval: int = 1,
108        predict_function: Union[Callable[[torch.Tensor], torch.Tensor], None] = None,
109        primary_predict_function: Callable = None,
110        exclusive: bool = True,
111        ignore_tags: bool = True,
112        threshold: float = 0.5,
113        model_save_path: str = None,
114        model_save_epochs: int = 5,
115        pseudolabel: bool = False,
116        pseudolabel_start: int = 100,
117        correction_interval: int = 2,
118        pseudolabel_alpha_f: float = 3,
119        alpha_growth_stop: int = 600,
120        parallel: bool = False,
121        skip_metrics: List = None,
122    ) -> None:
123        """Initialize the class.
124
125        Parameters
126        ----------
127        train_dataloader : torch.utils.data.DataLoader
128            a training dataloader
129        model : dlc2action.model.base_model.Model
130            a model
131        loss : callable
132            a loss function
133        num_epochs : int, default 0
134            the number of epochs
135        transformer : dlc2action.transformer.base_transformer.Transformer, optional
136            a transformer
137        ssl_losses : list, optional
138            a list of SSL losses
139        ssl_weights : list, optional
140            a list of SSL weights (if not provided initializes to 1)
141        lr : float, default 1e-3
142            learning rate
143        weight_decay : float, default 0
144            weight decay
145        metrics : dict, optional
146            a list of metric functions
147        val_dataloader : torch.utils.data.DataLoader, optional
148            a validation dataloader
149        test_dataloader : torch.utils.data.DataLoader, optional
150            a test dataloader
151        optimizer : torch.optim.Optimizer, optional
152            an optimizer (`Adam` by default)
153        device : str, default 'cuda'
154            the device to train the model on
155        verbose : bool, default True
156            if `True`, the process is described in standard output
157        log_file : str, optional
158            the path to a text file where the process will be logged
159        augment_train : {1, 0}
160            number of augmentations to apply at training
161        augment_val : int, default 0
162            number of augmentations to apply at validation
163        validation_interval : int, default 1
164            every time this number of epochs passes, validation metrics are computed
165        predict_function : callable, optional
166            a function that maps probabilities to class predictions (if not provided, a default is generated)
167        primary_predict_function : callable, optional
168            a function that maps model output to probabilities (if not provided, initialized as identity)
169        exclusive : bool, default True
170            set to False for multi-label classification
171        ignore_tags : bool, default True
172            if `True`, samples with different meta tags will be mixed in batches
173        threshold : float, default 0.5
174            the threshold used for multi-label classification default prediction function
175        model_save_path : str, optional
176            the path to the folder where model checkpoints will be saved (checkpoints will not be saved if the path
177            is not provided)
178        model_save_epochs : int, default 5
179            the interval for saving the model checkpoints (the last epoch is always saved)
180        pseudolabel : bool, default False
181            if True, the pseudolabeling procedure will be applied
182        pseudolabel_start : int, default 100
183            pseudolabeling starts after this epoch
184        correction_interval : int, default 2
185            after this number of epochs, if the pseudolabeling is on, the model is trained on the labeled data and
186            new pseudolabels are generated
187        pseudolabel_alpha_f : float, default 3
188            the maximum value of pseudolabeling alpha
189        alpha_growth_stop : int, default 600
190            pseudolabeling alpha stops growing after this epoch
191        parallel : bool, default False
192            if True, the model is trained on multiple GPUs
193        skip_metrics : list, optional
194            a list of metrics to skip
195
196        """
197        # pseudolabeling might be buggy right now -- not using it!
198        if skip_metrics is None:
199            skip_metrics = []
200        self.train_dataloader = train_dataloader
201        self.val_dataloader = val_dataloader
202        self.test_dataloader = test_dataloader
203        self.transformer = transformer
204        self.num_epochs = num_epochs
205        self.skip_metrics = skip_metrics
206        self.verbose = verbose
207        self.augment_train = int(augment_train)
208        self.augment_val = int(augment_val)
209        self.ignore_tags = ignore_tags
210        self.validation_interval = int(validation_interval)
211        self.log_file = log_file
212        self.loss = loss
213        self.model_save_path = model_save_path
214        self.model_save_epochs = model_save_epochs
215        self.epoch = 0
216
217        if metrics is None:
218            metrics = {}
219        self.metrics = metrics
220
221        if optimizer is None:
222            optimizer = Adam
223
224        if ssl_weights is None:
225            ssl_weights = [1 for _ in (ssl_losses or [0])]  # guard against ssl_losses=None
226        if not isinstance(ssl_weights, Iterable):
227            ssl_weights = [ssl_weights for _ in (ssl_losses or [0])]
228        self.ssl_weights = ssl_weights
229
230        self.optimizer_class = optimizer
231        self.lr = lr
232        self.weight_decay = weight_decay
233        if not isinstance(model, Model):
234            self.model = LoadedModel(model=model)
235        else:
236            self.set_model(model)
237        self.parallel = parallel
238
239        if self.transformer is None:
240            self.augment_val = 0
241            self.augment_train = 0
242            self.transformer = EmptyTransformer()
243
244        if self.augment_train > 1:
245            warnings.warn(
246                'The "augment_train" parameter is too large -> setting it to 1.'
247            )
248            self.augment_train = 1
249
250        try:
251            if device == "auto":
252                device = "cuda" if torch.cuda.is_available() else "cpu"
253            self.device = torch.device(device)
254        except Exception:
255            raise ValueError("The format of the device is incorrect")
256
257        if ssl_losses is None:
258            self.ssl_losses = [lambda x, y: 0]
259        else:
260            self.ssl_losses = ssl_losses
261
262        if primary_predict_function is None:
263            if exclusive:
264                primary_predict_function = lambda x: torch.softmax(x, dim=1)
265            else:
266                primary_predict_function = lambda x: torch.sigmoid(x)
267        self.primary_predict_function = primary_predict_function
268
269        if predict_function is None:
270            if exclusive:
271                self.predict_function = lambda x: torch.max(x.data, 1)[1]
272            else:
273                self.predict_function = lambda x: (x > threshold).int()
274        else:
275            self.predict_function = predict_function
276
277        self.pseudolabel = pseudolabel
278        self.alpha_f = pseudolabel_alpha_f
279        self.T2 = alpha_growth_stop
280        self.T1 = pseudolabel_start
281        self.t = correction_interval
282        if self.T2 <= self.T1:
283            raise ValueError(
284                f"The pseudolabel_start parameter has to be smaller than alpha_growth_stop; got "
285                f"{pseudolabel_start=} and {alpha_growth_stop=}"
286            )
287        self.decision_thresholds = [0.5 for x in self.behaviors_dict()]

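Example: a minimal construction sketch. Here `train_loader`, `val_loader`, `my_model` and `my_loss` are hypothetical placeholders for objects created elsewhere (dataloaders over a `BehaviorDataset`, a `Model` or `nn.Module`, and a loss callable); only the keyword arguments mirror the signature above.

    from dlc2action.task.universal_task import Task

    # train_loader, val_loader, my_model and my_loss are assumed to exist
    task = Task(
        train_dataloader=train_loader,
        val_dataloader=val_loader,
        model=my_model,
        loss=my_loss,
        num_epochs=100,
        lr=1e-3,
        exclusive=True,          # single-label (softmax) classification
        device="auto",           # resolves to "cuda" if available, else "cpu"
        model_save_path="checkpoints",
        model_save_epochs=10,
    )
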
def save_checkpoint(self, checkpoint_path: str) -> None:
289    def save_checkpoint(self, checkpoint_path: str) -> None:
290        """Save a general checkpoint.
291
292        Parameters
293        ----------
294        checkpoint_path : str
295            the path where the checkpoint will be saved
296
297        """
298        torch.save(
299            {
300                "epoch": self.epoch,
301                "model_state_dict": self.model.state_dict(),
302                "optimizer_state_dict": self.optimizer.state_dict(),
303            },
304            checkpoint_path,
305        )

def load_from_checkpoint( self, checkpoint_path, only_model: bool = False, load_strict: bool = True) -> None:
307    def load_from_checkpoint(
308        self, checkpoint_path, only_model: bool = False, load_strict: bool = True
309    ) -> None:
310        """Load from a checkpoint.
311
312        Parameters
313        ----------
314        checkpoint_path : str
315            the path to the checkpoint
316        only_model : bool, default False
317            if `True`, only the model state dictionary will be loaded (and not the epoch and the optimizer state
318            dictionary)
319        load_strict : bool, default True
320            if `True`, any inconsistencies in state dictionaries are regarded as errors
321
322        """
323        if checkpoint_path is None:
324            return
325        checkpoint = torch.load(
326            checkpoint_path, map_location=self.device, weights_only=False
327        )
328        self.model.to(self.device)
329        self.model.load_state_dict(checkpoint["model_state_dict"], strict=load_strict)
330        if not only_model:
331            self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
332            self.epoch = checkpoint["epoch"]

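Example (a sketch, assuming `task` is a `Task` instance that has already been trained, so that its optimizer exists, and that "checkpoints" is a placeholder directory):

    # write a full checkpoint (epoch, model and optimizer state)
    task.save_checkpoint("checkpoints/epoch050.pt")

    # restore everything later
    task.load_from_checkpoint("checkpoints/epoch050.pt")

    # or restore only the weights, e.g. for fine-tuning with a fresh optimizer
    task.load_from_checkpoint("checkpoints/epoch050.pt", only_model=True, load_strict=False)
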
def save_model(self, save_path: str) -> None:
334    def save_model(self, save_path: str) -> None:
335        """Save the model state dictionary.
336
337        Parameters
338        ----------
339        save_path : str
340            the path where the state will be saved
341
342        """
343        torch.save(self.model.state_dict(), save_path)
344        print("saved the model successfully")

def train( self, trial: optuna.trial._trial.Trial = None, optimized_metric: str = None, to_ram: bool = False, autostop_interval: int = 30, autostop_threshold: float = 0.001, autostop_metric: str = None, main_task_on: bool = True, ssl_on: bool = True, temporal_subsampling_size: int = None, loading_bar: bool = False) -> Tuple:
621    def train(
622        self,
623        trial: Trial = None,
624        optimized_metric: str = None,
625        to_ram: bool = False,
626        autostop_interval: int = 30,
627        autostop_threshold: float = 0.001,
628        autostop_metric: str = None,
629        main_task_on: bool = True,
630        ssl_on: bool = True,
631        temporal_subsampling_size: int = None,
632        loading_bar: bool = False,
633    ) -> Tuple:
634        """Train the task and return a log of epoch-average loss and metric.
635
636        You can use the autostop parameters to finish training when the parameters are not improving. It will be
637        stopped if the average value of `autostop_metric` over the last `autostop_interval` epochs is smaller than
638        the average over the previous `autostop_interval` epochs + `autostop_threshold`. For example, if the
639        current epoch is 120 and `autostop_interval` is 50, the averages over epochs 70-120 and 20-70 will be compared.
640
641        Parameters
642        ----------
643        trial : Trial
644            an `optuna` trial (for hyperparameter searches)
645        optimized_metric : str
646            the name of the metric being optimized (for hyperparameter searches)
647        to_ram : bool, default False
648            if `True`, the dataset will be loaded in RAM (this speeds up the calculations but can lead to crashes
649            if the dataset is too large)
650        autostop_interval : int, default 30
651            the number of epochs to average the autostop metric over
652        autostop_threshold : float, default 0.001
653            the autostop difference threshold
654        autostop_metric : str, optional
655            the autostop metric (can be any one of the tracked metrics or `'loss'`)
656        main_task_on : bool, default True
657            if `False`, the main task (action segmentation) will not be used in training
658        ssl_on : bool, default True
659            if `False`, the SSL task will not be used in training
660        temporal_subsampling_size : int, optional
661            if not `None`, the temporal subsampling will be used in training with the given size
662        loading_bar : bool, default False
663            if `True`, a loading bar will be displayed
664
665        Returns
666        -------
667        loss_log: dict
668            a dictionary of loss logs (keys are 'train' and 'val', values are lists of float loss values per epoch)
669        metrics_log: dict
670            a dictionary of metric value logs (first-level keys are 'train' and 'val', second-level keys are metric
671            names, values are lists of function values)
672
673        """
674        if self.parallel and not isinstance(self.model, nn.DataParallel):
675            self.model = MyDataParallel(self.model)
676        self.model.to(self.device)
677        assert autostop_metric in [None, "loss"] + list(self.metrics)
678        autostop_interval //= self.validation_interval
679        if trial is not None and optimized_metric is None:
680            raise ValueError(
681                "You need to provide the optimized metric name (optimized_metric parameter) "
682                "for optuna pruning to work!"
683            )
684        if to_ram:
685            print("transferring datasets to RAM...")
686            self.train_dataloader.dataset.to_ram()
687            if self.val_dataloader is not None and len(self.val_dataloader) > 0:
688                self.val_dataloader.dataset.to_ram()
689        loss_log = {"train": [], "val": []}
690        metrics_log = {"train": defaultdict(lambda: []), "val": defaultdict(lambda: [])}
691        if not main_task_on:
692            self.model.main_task_off()
693        if not ssl_on:
694            self.model.ssl_off()
695        while self.epoch < self.num_epochs:
696            self.epoch += 1
697            unlabeled = None
698            alpha = 1
699            if self.pseudolabel:
700                if self.epoch >= self.T1:
701                    unlabeled = (self.epoch - self.T1) % self.t != 0
702                    if unlabeled:
703                        alpha = self._alpha(self.epoch)
704                else:
705                    unlabeled = False
706            epoch_loss, epoch_ssl_loss, epoch_metrics = self._run_epoch(
707                dataloader=self.train_dataloader,
708                mode="train",
709                augment_n=self.augment_train,
710                unlabeled=unlabeled,
711                alpha=alpha,
712                temporal_subsampling_size=temporal_subsampling_size,
713                verbose=loading_bar,
714            )
715            loss_log["train"].append(epoch_loss)
716            epoch_string = f"[epoch {self.epoch}]"
717            if self.pseudolabel:
718                if unlabeled:
719                    epoch_string += " (unlabeled)"
720                else:
721                    epoch_string += " (labeled)"
722            epoch_string += f": loss {epoch_loss:.4f}"
723
724            if len(epoch_ssl_loss) != 0:
725                for key, value in sorted(epoch_ssl_loss.items()):
726                    metrics_log["train"][f"ssl_loss_{key}"].append(value)
727                    epoch_string += f", ssl_loss_{key} {value:.4f}"
728
729            for metric_name, metric_value in sorted(epoch_metrics.items()):
730                if metric_name not in self.skip_metrics:
731                    if isinstance(metric_value, list):
732                        metric_value = torch.mean(torch.Tensor(metric_value))
733                    epoch_string += f", {metric_name} {metric_value:.3f}"
734                    metrics_log["train"][metric_name].append(metric_value)
735
736            if (
737                self.val_dataloader is not None
738                and self.epoch % self.validation_interval == 0
739            ):
740                with torch.no_grad():
741                    epoch_string += "\n"
742                    (
743                        val_epoch_loss,
744                        val_epoch_ssl_loss,
745                        val_epoch_metrics,
746                    ) = self._run_epoch(
747                        dataloader=self.val_dataloader,
748                        mode="val",
749                        augment_n=self.augment_val,
750                    )
751                    loss_log["val"].append(val_epoch_loss)
752                    epoch_string += f"validation: loss {val_epoch_loss:.4f}"
753
754                    if len(val_epoch_ssl_loss) != 0:
755                        for key, value in sorted(val_epoch_ssl_loss.items()):
756                            metrics_log["val"][f"ssl_loss_{key}"].append(value)
757                            epoch_string += f", ssl_loss_{key} {value:.4f}"
758
759                    for metric_name, metric_value in sorted(val_epoch_metrics.items()):
760                        if isinstance(metric_value, list):
761                            metric_value = torch.mean(torch.Tensor(metric_value))
762                        metrics_log["val"][metric_name].append(metric_value)
763                        epoch_string += f", {metric_name} {metric_value:.3f}"
764
765                if trial is not None:
766                    if optimized_metric not in metrics_log["val"]:
767                        raise ValueError(
768                            f"The {optimized_metric} metric set for optimization is not being logged!"
769                        )
770                    trial.report(metrics_log["val"][optimized_metric][-1], self.epoch)
771                    if trial.should_prune():
772                        raise TrialPruned()
773
774            if self.verbose:
775                print(epoch_string)
776
777            if self.log_file is not None:
778                with open(self.log_file, "a") as f:
779                    f.write(epoch_string + "\n")
780
781            save_condition = (
782                (self.model_save_epochs != 0)
783                and (self.epoch % self.model_save_epochs == 0)
784            ) or (self.epoch == self.num_epochs)
785
786            if self.epoch > 0 and save_condition and self.model_save_path is not None:
787                epoch_s = str(self.epoch).zfill(len(str(self.num_epochs)))
788                self.save_checkpoint(
789                    os.path.join(self.model_save_path, f"epoch{epoch_s}.pt")
790                )
791
792            if self.pseudolabel and self.epoch >= self.T1 and not unlabeled:
793                self._set_pseudolabels()
794
795            if autostop_metric == "loss":
796                if len(loss_log["val"]) > autostop_interval * 2:
797                    if (
798                        np.mean(loss_log["val"][-autostop_interval:])
799                        < np.mean(
800                            loss_log["val"][-2 * autostop_interval : -autostop_interval]
801                        )
802                        + autostop_threshold
803                    ):
804                        break
805            elif autostop_metric in metrics_log["val"]:
806                if len(metrics_log["val"][autostop_metric]) > autostop_interval * 2:
807                    if (
808                        np.mean(
809                            metrics_log["val"][autostop_metric][-autostop_interval:]
810                        )
811                        < np.mean(
812                            metrics_log["val"][autostop_metric][
813                                -2 * autostop_interval : -autostop_interval
814                            ]
815                        )
816                        + autostop_threshold
817                    ):
818                        break
819
820        metrics_log = {k: dict(v) for k, v in metrics_log.items()}
821
822        return loss_log, metrics_log

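Example (a sketch, assuming `task` is a configured `Task` instance with a validation dataloader and a nonzero `num_epochs`):

    loss_log, metrics_log = task.train(
        autostop_metric="loss",
        autostop_interval=30,
        autostop_threshold=1e-3,
    )
    print(loss_log["val"][-1])        # last validation loss
    print(list(metrics_log["val"]))   # names of the tracked validation metrics
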
def evaluate_prediction( self, prediction: Union[torch.Tensor, Dict], data: Union[torch.utils.data.dataloader.DataLoader, dlc2action.data.dataset.BehaviorDataset, str] = None, batch_size: int = 32, indices: list = None) -> Tuple:
824    def evaluate_prediction(
825        self,
826        prediction: Union[torch.Tensor, Dict],
827        data: Union[DataLoader, BehaviorDataset, str] = None,
828        batch_size: int = 32,
829        indices: list = None,
830    ) -> Tuple:
831        """Compute metrics for a prediction.
832
833        Parameters
834        ----------
835        prediction : torch.Tensor
836            the prediction
837        data : torch.utils.data.DataLoader | dlc2action.data.dataset.BehaviorDataset, optional
838            the data the prediction was made for (if not provided, take the validation dataset)
839        batch_size : int, default 32
840            the batch size
841
842        Returns
843        -------
844        loss : float
845            the average value of the loss function (currently always 0, since the loss is skipped here and only metrics are computed)
846        metric : dict
847            a dictionary of average values of metric functions
848        """
849
850        if type(data) is not DataLoader:
851            data = DataLoader(self._get_dataset(data), shuffle=False, batch_size=batch_size)
852        dataset = data.dataset
853        epoch_loss = 0
854        if isinstance(prediction, dict):
855            num_classes = len(self.behaviors_dict())
856            length = dataset.len_segment()
857            coords = dataset.annotation_store.get_original_coordinates()
858            for batch in data:
859                main_target = batch["target"]
860                pr_coords = coords[batch["index"]]
861                predicted = torch.zeros((len(pr_coords), num_classes, length))
862                for i, c in enumerate(pr_coords):
863                    video_id = dataset.input_store.get_video_id(c)
864                    clip_id = dataset.input_store.get_clip_id(c)
865                    start, end = dataset.input_store.get_clip_start_end(c)
866                    beh_ind = list(prediction[video_id]["classes"].keys())
867                    pred_tmp = prediction[video_id][clip_id][beh_ind, :]
868                    predicted[i, :, : end - start] = pred_tmp[:, start:end]
869                self._compute(
870                    [],
871                    [],
872                    predicted,
873                    main_target,
874                    skip_loss=True,
875                    tag=batch.get("tag"),
876                    apply_primary_function=False,
877                )
878        else:
879            for batch in data:
880                main_target = batch["target"]
881                predicted = prediction[batch["index"]]
882                if indices is not None:
883                    indices_new = [indices.index(i) for i in range(len(indices))]
884                    predicted = predicted[:, indices_new]
885                self._compute(
886                    [],
887                    [],
888                    predicted,
889                    main_target,
890                    skip_loss=True,
891                    tag=batch.get("tag"),
892                    apply_primary_function=False,
893                )
894        epoch_metrics = self._calculate_metrics()
895        # strings = [
896        #     f"{metric_name} {metric_value:.3f}"
897        #     for metric_name, metric_value in epoch_metrics.items()
898        # ]
899        # val_string = ", ".join(sorted(strings))
900        # print(val_string)
901        return epoch_loss, epoch_metrics

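Example (a sketch, assuming `task` is a trained `Task` instance): raw probabilities produced by `predict` can be scored against the validation annotations.

    probabilities = task.predict("val", raw_output=True, apply_primary_function=True)
    _, metric_values = task.evaluate_prediction(probabilities, data="val")
    print(metric_values)
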
def evaluate( self, data: Union[torch.utils.data.dataloader.DataLoader, dlc2action.data.dataset.BehaviorDataset, str] = None, augment_n: int = 0, batch_size: int = 32, verbose: bool = True) -> Tuple:
903    def evaluate(
904        self,
905        data: Union[DataLoader, BehaviorDataset, str] = None,
906        augment_n: int = 0,
907        batch_size: int = 32,
908        verbose: bool = True,
909    ) -> Tuple:
910        """Evaluate the Task model.
911
912        Parameters
913        ----------
914        data : torch.utils.data.DataLoader | dlc2action.data.dataset.BehaviorDataset, optional
915            the data to evaluate on (if not provided, evaluate on the Task validation dataset)
916        augment_n : int, default 0
917            the number of augmentations to average results over
918        batch_size : int, default 32
919            the batch size
920        verbose : bool, default True
921            if True, the process is reported to standard output
922
923        Returns
924        -------
925        loss : float
926            the average value of the loss function
927        ssl_loss : float
928            the average value of the SSL loss function
929        metric : dict
930            a dictionary of average values of metric functions
931
932        """
933        if self.parallel and not isinstance(self.model, nn.DataParallel):
934            self.model = MyDataParallel(self.model)
935        self.model.to(self.device)
936        if type(data) is not DataLoader:
937            data = self._get_dataset(data)
938            data = DataLoader(data, shuffle=False, batch_size=batch_size)
939        with torch.no_grad():
940            epoch_loss, epoch_ssl_loss, epoch_metrics = self._run_epoch(
941                dataloader=data, mode="val", augment_n=augment_n, verbose=verbose
942            )
943        val_string = f"loss {epoch_loss:.4f}"
944        for metric_name, metric_value in sorted(epoch_metrics.items()):
945            val_string += f", {metric_name} {metric_value:.3f}"
946        print(val_string)
947        return epoch_loss, epoch_ssl_loss, epoch_metrics

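Example (a sketch, assuming `task` is a trained `Task` instance with a validation dataloader):

    loss, ssl_loss, metric_values = task.evaluate(data="val", augment_n=0, batch_size=32)
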
def predict( self, data: Union[torch.utils.data.dataloader.DataLoader, dlc2action.data.dataset.BehaviorDataset, str] = None, raw_output: bool = False, apply_primary_function: bool = True, augment_n: int = 0, batch_size: int = 32, train_mode: bool = False, to_ram: bool = False, embedding: bool = False) -> torch.Tensor:
 949    def predict(
 950        self,
 951        data: Union[DataLoader, BehaviorDataset, str] = None,
 952        raw_output: bool = False,
 953        apply_primary_function: bool = True,
 954        augment_n: int = 0,
 955        batch_size: int = 32,
 956        train_mode: bool = False,
 957        to_ram: bool = False,
 958        embedding: bool = False,
 959    ) -> torch.Tensor:
 960        """Make a prediction with the Task model.
 961
 962        Parameters
 963        ----------
 964        data : torch.utils.data.DataLoader | dlc2action.data.dataset.BehaviorDataset | str, optional
 965            the data to evaluate on (if not provided, evaluate on the Task validation dataset)
 966        raw_output : bool, default False
 967            if `True`, the raw predicted probabilities are returned
 968        apply_primary_function : bool, default True
 969            if `True`, the primary predict function is applied (to map the model output into a shape corresponding to
 970            the input)
 971        augment_n : int, default 0
 972            the number of augmentations to average results over
 973        batch_size : int, default 32
 974            the batch size
 975        train_mode : bool, default False
 976            if `True`, the model is used in training mode (affects dropout and batch normalization layers)
 977        to_ram : bool, default False
 978            if `True`, the dataset will be loaded in RAM (this speeds up the calculations but can lead to crashes
 979            if the dataset is too large)
 980        embedding : bool, default False
 981            if `True`, the output of feature extractor is returned, ignoring the prediction module of the model
 982
 983        Returns
 984        -------
 985        prediction : torch.Tensor
 986            a prediction for the input data
 987
 988        """
 989        if self.parallel and not isinstance(self.model, nn.DataParallel):
 990            self.model = MyDataParallel(self.model)
 991        self.model.to(self.device)
 992        if train_mode:
 993            self.model.train()
 994        else:
 995            self.model.eval()
 996        output = []
 997        if embedding:
 998            raw_output = True
 999            apply_primary_function = True
1000        if type(data) is not DataLoader:
1001            data = self._get_dataset(data)
1002            if to_ram:
1003                print("transferring dataset to RAM...")
1004                data.to_ram()
1005            data = DataLoader(data, shuffle=False, batch_size=batch_size)
1006        self.model.ssl_off()
1007        with torch.no_grad():
1008            for batch in tqdm(data):
1009                input = {k: v.to(self.device) for k, v in batch["input"].items()}
1010                predicted, _, _ = self._get_prediction(
1011                    input,
1012                    batch.get("tag"),
1013                    augment_n=augment_n,
1014                    embedding=embedding,
1015                )
1016                if apply_primary_function:
1017                    predicted = self.primary_predict_function(predicted)
1018                if not raw_output:
1019                    predicted = self.predict_function(predicted)
1020                output.append(predicted.detach().cpu())
1021        self.model.ssl_on()
1022        output = torch.cat(output).detach()
1023        return output

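Example (a sketch, assuming `task` is a trained `Task` instance and a test dataloader was passed to the constructor):

    # hard class predictions for the test set
    labels = task.predict("test", raw_output=False, augment_n=0)

    # raw per-frame probabilities, averaged over 5 augmentations
    probabilities = task.predict(
        "test", raw_output=True, apply_primary_function=True, augment_n=5
    )
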
def dataset(self, mode='train') -> dlc2action.data.dataset.BehaviorDataset:
1025    def dataset(self, mode="train") -> BehaviorDataset:
1026        """Get a dataset.
1027
1028        Parameters
1029        ----------
1030        mode : {'train', 'val', 'test'}
1031            the dataset to get
1032
1033        Returns
1034        -------
1035        dataset : dlc2action.data.dataset.BehaviorDataset
1036            the dataset
1037
1038        """
1039        dataloader = self.dataloader(mode)
1040        if dataloader is None:
1041            raise ValueError("The length of the dataloader is 0!")
1042        return dataloader.dataset

def dataloader(self, mode: str = 'train') -> torch.utils.data.dataloader.DataLoader:
1044    def dataloader(self, mode: str = "train") -> DataLoader:
1045        """Get a dataloader.
1046
1047        Parameters
1048        ----------
1049        mode : {'train', 'val', 'test'}
1050            the dataset to get
1051
1052        Returns
1053        -------
1054        dataloader : torch.utils.data.DataLoader
1055            the dataloader
1056
1057        """
1058        if mode == "train":
1059            return self.train_dataloader
1060        elif mode == "val":
1061            return self.val_dataloader
1062        elif mode == "test":
1063            return self.test_dataloader
1064        else:
1065            raise ValueError(
1066                f'The {mode} mode is not recognized, please choose from "train", "val" or "test"'
1067            )

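Example (a sketch, assuming `task` is an existing `Task` instance with a validation dataloader):

    train_set = task.dataset("train")      # the underlying BehaviorDataset
    val_loader = task.dataloader("val")    # the validation DataLoader
    print(task.behaviors_dict())           # {class index: behavior name}
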
def generate_full_length_prediction(self, dataset=None, batch_size=32, augment_n=10):
1097    def generate_full_length_prediction(
1098        self, dataset=None, batch_size=32, augment_n=10
1099    ):
1100        """Compile a prediction for the original input sequences.
1101
1102        Parameters
1103        ----------
1104        dataset : BehaviorDataset, optional
1105            the dataset to generate a prediction for (if `None`, generate for the validation dataset)
1106        batch_size : int, default 32
1107            the batch size
1108        augment_n : int, default 10
1109            the number of augmentations to average results over
1110
1111        Returns
1112        -------
1113        prediction : dict
1114            a nested dictionary where first level keys are video ids, second level keys are clip ids and values
1115            are prediction tensors
1116
1117        """
1118        dataset = self._get_dataset(dataset)
1119        if not isinstance(dataset, BehaviorDataset):
1120            raise TypeError(
1121                f"The dataset parameter has to be either None, string, "
1122                f"BehaviorDataset or Dataloader, got {type(dataset)}"
1123            )
1124        predicted = self.predict(
1125            dataset,
1126            raw_output=True,
1127            apply_primary_function=True,
1128            augment_n=augment_n,
1129            batch_size=batch_size,
1130        )
1131        predicted = dataset.generate_full_length_prediction(predicted)
1132        predicted = {
1133            v_id: {
1134                clip_id: self._apply_predict_functions(v.unsqueeze(0)).squeeze()
1135                for clip_id, v in video_dict.items()
1136            }
1137            for v_id, video_dict in predicted.items()
1138        }
1139        return predicted

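Example (a sketch, assuming `task` is a trained `Task` instance):

    full_prediction = task.generate_full_length_prediction(dataset="val", augment_n=10)
    for video_id, clips in full_prediction.items():
        for clip_id, scores in clips.items():
            print(video_id, clip_id, tuple(scores.shape))
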
def generate_submission( self, frame_number_map_file, dataset=None, batch_size=32, augment_n=10):
1141    def generate_submission(
1142        self, frame_number_map_file, dataset=None, batch_size=32, augment_n=10
1143    ):
1144        """Generate a MABe-22 style submission dictionary.
1145
1146        Parameters
1147        ----------
1148        frame_number_map_file : str
1149            the path to the frame number map file
1150        dataset : BehaviorDataset, optional
1151            the dataset to generate a prediction for (if `None`, generate for the validation dataset)
1152        batch_size : int, default 32
1153            the batch size
1154        augment_n : int, default 10
1155            the number of augmentations to average results over
1156
1157        Returns
1158        -------
1159        submission : dict
1160            a dictionary with frame number mapping and embeddings
1161
1162        """
1163        dataset = self._get_dataset(dataset)
1164        if not isinstance(dataset, BehaviorDataset):
1165            raise TypeError(
1166                f"The dataset parameter has to be either None, string, "
1167                f"BehaviorDataset or Dataloader, got {type(dataset)}"
1168            )
1169        predicted = self.predict(
1170            dataset,
1171            raw_output=True,
1172            apply_primary_function=True,
1173            augment_n=augment_n,
1174            batch_size=batch_size,
1175            embedding=True,
1176        )
1177        predicted = dataset.generate_full_length_prediction(predicted)
1178        frame_map = np.load(frame_number_map_file, allow_pickle=True).item()
1179        length = frame_map[list(frame_map.keys())[-1]][1]
1180        embeddings = None
1181        for video_id in list(predicted.keys()):
1182            split = video_id.split("--")
1183            if len(split) != 2 or len(predicted[video_id]) > 1:
1184                raise RuntimeError(
1185                    "Generating submissions is only implemented for the mabe22 dataset!"
1186                )
1187            if split[1] not in frame_map:
1188                raise RuntimeError(f"The {split[1]} video is not in the frame map file")
1189            v_id = split[1]
1190            clip_id = list(predicted[video_id].keys())[0]
1191            if embeddings is None:
1192                embeddings = np.zeros((length, predicted[video_id][clip_id].shape[0]))
1193            start, end = frame_map[v_id]
1194            embeddings[start:end, :] = predicted[video_id][clip_id].T
1195            predicted.pop(video_id)
1196        predicted = {
1197            "frame_number_map": frame_map,
1198            "embeddings": embeddings.astype(np.float32),
1199        }
1200        return predicted

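Example (a sketch, assuming `task` is a trained `Task` instance on a MABe-22 style dataset; "frame_number_map.npy" and "submission.npy" are placeholder paths):

    import numpy as np

    submission = task.generate_submission(
        frame_number_map_file="frame_number_map.npy",
        dataset="test",
        augment_n=10,
    )
    np.save("submission.npy", submission)
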
def visualize_results( self, save_path: str = None, add_legend: bool = True, ground_truth: bool = True, colormap: str = 'viridis', hide_axes: bool = False, min_classes: int = 1, width: int = 10, whole_video: bool = False, transparent: bool = False, dataset: Union[dlc2action.data.dataset.BehaviorDataset, torch.utils.data.dataloader.DataLoader, str, NoneType] = None, drop_classes: Set = None, search_classes: Set = None, smooth_interval_prediction: int = None, font_size: float = None, num_plots: int = 1, window_size: int = 400):
1458    def visualize_results(
1459        self,
1460        save_path: str = None,
1461        add_legend: bool = True,
1462        ground_truth: bool = True,
1463        colormap: str = "viridis",
1464        hide_axes: bool = False,
1465        min_classes: int = 1,
1466        width: int = 10,
1467        whole_video: bool = False,
1468        transparent: bool = False,
1469        dataset: Union[BehaviorDataset, DataLoader, str, None] = None,
1470        drop_classes: Set = None,
1471        search_classes: Set = None,
1472        smooth_interval_prediction: int = None,
1473        font_size: float = None,
1474        num_plots: int = 1,
1475        window_size: int = 400,
1476    ):
1477        """Visualize random predictions.
1478
1479        Parameters
1480        ----------
1481        save_path : str, optional
1482            the path where the plot will be saved
1483        add_legend : bool, default True
1484            if `True`, legend will be added to the plot
1485        ground_truth : bool, default True
1486            if `True`, ground truth will be added to the plot
1487        colormap : str, default 'viridis'
1488            the `matplotlib` colormap to use
1489        hide_axes : bool, default False
1490            if `True`, the axes will be hidden on the plot
1491        min_classes : int, default 1
1492            the minimum number of classes in a displayed interval
1493        width : float, default 10
1494            the width of the plot
1495        whole_video : bool, default False
1496            if `True`, whole videos are plotted instead of segments
1497        transparent : bool, default False
1498            if `True`, the background on the plot is transparent
1499        dataset : BehaviorDataset | DataLoader | str | None, optional
1500            the dataset to make the prediction for (if not provided, the validation dataset is used)
1501        drop_classes : set, optional
1502            a set of class names to not be displayed
1503        search_classes : set, optional
1504            if given, only intervals where at least one of the classes is in ground truth will be shown
1505        smooth_interval_prediction : int, optional
1506            if given, the prediction will be smoothed with a moving average of the given size
1507
1508        """
1509        if drop_classes is None:
1510            drop_classes = []
1511        dataset = self._get_dataset(dataset)
1512        if dataset.annotation_class() != "exclusive_classification":
1513            raise NotImplementedError(
1514                "Results visualisation is only implemented for exclusive classification datasets!"
1515            )
1516        labels = {
1517            k: v for k, v in dataset.behaviors_dict().items() if v not in drop_classes
1518        }
1519        labels.update({-100: "unknown"})
1520        label_keys = sorted([int(x) for x in labels.keys()])
1521        if search_classes is None:
1522            ok = True
1523        else:
1524            ok = False
1525        classes = []
1526        predicted = self.generate_full_length_prediction(dataset)
1527        keys = list(predicted.keys())
1528        counter = 0
1529        max_iter = len(keys) * 2
1530        while len(classes) < min_classes or not ok:
1531            counter += 1
1532            if counter > max_iter:
1533                raise RuntimeError(
1534                    "Plotting is taking too many iterations; you should probably make some of the parameters less restrictive"
1535                )
1536            i = randint(0, len(keys) - 1)
1537            prediction = predicted[keys[i]]
1538            keys_i = list(prediction.keys())
1539            j = randint(0, len(keys_i) - 1)
1540            prediction = prediction[keys_i[j]]
1541            key1 = keys[i]
1542            key2 = keys_i[j]
1543
1544            if smooth_interval_prediction:  # None means no smoothing
1545                unsmoothed_prediction = deepcopy(prediction)
1546                prediction = self._smooth(prediction, smooth_interval_prediction)
1547                height = 3
1548            else:
1549                height = 2
1550            classes = [
1551                labels[int(x)] for x in torch.unique(prediction) if x in label_keys
1552            ]
1553            if search_classes is not None:
1554                ok = any([x in classes for x in search_classes])
1555        fig, ax = plt.subplots(figsize=(width, height))
1556        cmap = cm.get_cmap(colormap) if colormap != "dlc2action" else None
1557        color_list = (
1558            [cmap(c) for c in np.linspace(0, 1, len(labels))]
1559            if colormap != "dlc2action"
1560            else dlc2action_colormaps["default"]
1561        )
1562
1563        def _plot_prediction(prediction, height, set_labels=True):
1564            for c in label_keys:
1565                c_i = label_keys.index(int(c))
1566                output, indices, counts = torch.unique_consecutive(
1567                    prediction == c, return_inverse=True, return_counts=True
1568                )
1569                long_indices = torch.where(output)[0]
1570                if len(long_indices) == 0:
1571                    continue
1572                res_indices_start = [
1573                    (indices == i).nonzero(as_tuple=True)[0][0].item()
1574                    for i in long_indices
1575                ]
1576                res_indices_end = [
1577                    (indices == i).nonzero(as_tuple=True)[0][-1].item() + 1
1578                    for i in long_indices
1579                ]
1580                res_indices_len = [
1581                    end - start
1582                    for start, end in zip(res_indices_start, res_indices_end)
1583                ]
1584                if set_labels:
1585                    label = labels[int(c)]
1586                else:
1587                    label = None
1588                ax.broken_barh(
1589                    list(zip(res_indices_start, res_indices_len)),
1590                    (height, 1),
1591                    label=label,
1592                    facecolors=color_list[c_i],
1593                )
1594
1595        if smooth_interval_prediction:
1596            _plot_prediction(unsmoothed_prediction, 0)
1597            _plot_prediction(prediction, 1.5, set_labels=False)
1598            gt_height = 3
1599        else:
1600            _plot_prediction(prediction, 0)
1601            gt_height = 1.5
1602        if ground_truth:
1603            gt = dataset.generate_full_length_gt()[key1][key2]
1604            for c in label_keys:
1605                c_i = label_keys.index(int(c))
1606                if labels[int(c)] in classes:
1607                    label = None
1608                else:
1609                    label = labels[int(c)]
1610                output, indices, counts = torch.unique_consecutive(
1611                    gt == c, return_inverse=True, return_counts=True
1612                )
1613                long_indices = torch.where(output)[0]
1614                if len(long_indices) == 0:
1615                    continue
1616                res_indices_start = [
1617                    (indices == i).nonzero(as_tuple=True)[0][0].item()
1618                    for i in long_indices
1619                ]
1620                res_indices_end = [
1621                    (indices == i).nonzero(as_tuple=True)[0][-1].item() + 1
1622                    for i in long_indices
1623                ]
1624                res_indices_len = [
1625                    end - start
1626                    for start, end in zip(res_indices_start, res_indices_end)
1627                ]
1628                ax.broken_barh(
1629                    list(zip(res_indices_start, res_indices_len)),
1630                    (gt_height, 1),
1631                    facecolors=color_list[c_i] if labels[int(c)] != "unknown" else "gray",
1632                    label=label,
1633                )
1634        if not ground_truth:
1635            ax.axes.yaxis.set_visible(False)
1636        else:
1637            if smooth_interval_prediction > 0:
1638                ax.set_yticks([0.5, 2, 3.5])
1639                ax.set_yticklabels(["prediction", "smoothed", "ground truth"])
1640            else:
1641                ax.set_yticks([0.5, 2])
1642                ax.set_yticklabels(["prediction", "ground truth"])
1643        if add_legend:
1644            ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))
1645        if hide_axes:
1646            # plt.axis("off")
1647            plt.box(False)
1648        if font_size is not None:
1649            font = {"size": font_size}
1650            rc("font", **font)
1651        plt.title(f"{key1} -- {key2}")
1652        plt.tight_layout()
1653        for seed in range(num_plots):
1654            if not whole_video:
1655                ax.set_xlim((seed * window_size, (seed + 1) * window_size))
1656            if save_path is not None:
1657                plt.savefig(
1658                    save_path.replace(".svg", f"_{seed}_{key1} -- {key2}.svg"), transparent=transparent
1659                )
1660                print(f"Saved in {save_path.replace('.svg', f'_{seed}_{key1} -- {key2}.svg')}")
1661        plt.show()
1662        plt.close()

Visualize random predictions.

Parameters

save_path : str, optional
    the path where the plot will be saved
add_legend : bool, default True
    if `True`, legend will be added to the plot
ground_truth : bool, default True
    if `True`, ground truth will be added to the plot
colormap : str, default 'viridis'
    the `matplotlib` colormap to use
hide_axes : bool, default False
    if `True`, the axes will be hidden on the plot
min_classes : int, default 1
    the minimum number of classes in a displayed interval
width : float, default 10
    the width of the plot
whole_video : bool, default False
    if `True`, whole videos are plotted instead of segments
transparent : bool, default False
    if `True`, the background on the plot is transparent
dataset : BehaviorDataset | DataLoader | str | None, optional
    the dataset to make the prediction for (if not provided, the validation dataset is used)
drop_classes : set, optional
    a set of class names to not be displayed
search_classes : set, optional
    if given, only intervals where at least one of the classes is in ground truth will be shown
smooth_interval_prediction : int, optional
    if given, the prediction will be smoothed with a moving average of the given size
font_size : float, optional
    if given, the font size for the plot text
num_plots : int, default 1
    the number of consecutive windows to plot (and save, if `save_path` is given)
window_size : int, default 400
    the width of each plotted window, in frames (ignored if `whole_video` is `True`)
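
A minimal usage sketch (the `task` object, file name, and class name below are assumptions, not part of the API above):

    # Hypothetical example: plot two random validation windows and save them as SVG files.
    task.visualize_results(
        save_path="example.svg",        # saved as example_<window>_<video> -- <clip>.svg
        colormap="viridis",
        min_classes=2,                  # only show intervals with at least two classes
        search_classes={"walking"},     # "walking" is a hypothetical class name
        num_plots=2,
        window_size=400,
    )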

def set_ssl_transformations(self, ssl_transformations):
1664    def set_ssl_transformations(self, ssl_transformations):
1665        """Set SSL transformations.
1666
1667        Parameters
1668        ----------
1669        ssl_transformations : list
1670            a list of callable SSL transformations
1671
1672        """
1673        self.train_dataloader.dataset.set_ssl_transformations(ssl_transformations)
1674        if self.val_dataloader is not None:
1675            self.val_dataloader.dataset.set_ssl_transformations(ssl_transformations)

Set SSL transformations.

Parameters

ssl_transformations : list
    a list of callable SSL transformations

def set_ssl_losses(self, ssl_losses: list) -> None:
1677    def set_ssl_losses(self, ssl_losses: list) -> None:
1678        """Set SSL losses.
1679
1680        Parameters
1681        ----------
1682        ssl_losses : list
1683            a list of callable SSL losses
1684
1685        """
1686        self.ssl_losses = ssl_losses

Set SSL losses.

Parameters

ssl_losses : list
    a list of callable SSL losses

def set_log(self, log: str) -> None:
1688    def set_log(self, log: str) -> None:
1689        """Set the log file.
1690
1691        Parameters
1692        ----------
1693        log: str
1694            the new log file path
1695
1696        """
1697        self.log_file = log

Set the log file.

Parameters

log : str
    the new log file path

def set_keep_target_none(self, keep_target_none: List) -> None:
1699    def set_keep_target_none(self, keep_target_none: List) -> None:
1700        """Set the keep_target_none parameter of the transformer.
1701
1702        Parameters
1703        ----------
1704        keep_target_none : list
1705            a list of bool values
1706
1707        """
1708        self.transformer.keep_target_none = keep_target_none

Set the keep_target_none parameter of the transformer.

Parameters

keep_target_none : list
    a list of bool values

def set_generate_ssl_input(self, generate_ssl_input: list) -> None:
1710    def set_generate_ssl_input(self, generate_ssl_input: list) -> None:
1711        """Set the generate_ssl_input parameter of the transformer.
1712
1713        Parameters
1714        ----------
1715        generate_ssl_input : list
1716            a list of bool values
1717
1718        """
1719        self.transformer.generate_ssl_input = generate_ssl_input

Set the generate_ssl_input parameter of the transformer.

Parameters

generate_ssl_input : list
    a list of bool values

def set_model(self, model: dlc2action.model.base_model.Model) -> None:
1721    def set_model(self, model: Model) -> None:
1722        """Set a new model.
1723
1724        Parameters
1725        ----------
1726        model: Model
1727            the new model
1728
1729        """
1730        self.epoch = 0
1731        self.model = model
1732        self.optimizer = self.optimizer_class(
1733            model.parameters(), lr=self.lr, weight_decay=self.weight_decay
1734        )
1735        if self.model.process_labels:
1736            self.model.set_behaviors(list(self.behaviors_dict().values()))

Set a new model.

Parameters

model : Model
    the new model

def set_dataloaders( self, train_dataloader: torch.utils.data.dataloader.DataLoader, val_dataloader: torch.utils.data.dataloader.DataLoader = None, test_dataloader: torch.utils.data.dataloader.DataLoader = None) -> None:
1738    def set_dataloaders(
1739        self,
1740        train_dataloader: DataLoader,
1741        val_dataloader: DataLoader = None,
1742        test_dataloader: DataLoader = None,
1743    ) -> None:
1744        """Set new dataloaders.
1745
1746        Parameters
1747        ----------
1748        train_dataloader: torch.utils.data.DataLoader
1749            the new train dataloader
1750        val_dataloader : torch.utils.data.DataLoader
1751            the new validation dataloader
1752        test_dataloader : torch.utils.data.DataLoader
1753            the new test dataloader
1754
1755        """
1756        self.train_dataloader = train_dataloader
1757        self.val_dataloader = val_dataloader
1758        self.test_dataloader = test_dataloader

Set new dataloaders.

Parameters

train_dataloader : torch.utils.data.DataLoader
    the new train dataloader
val_dataloader : torch.utils.data.DataLoader
    the new validation dataloader
test_dataloader : torch.utils.data.DataLoader
    the new test dataloader

def set_loss(self, loss: Callable) -> None:
1760    def set_loss(self, loss: Callable) -> None:
1761        """Set new loss function.
1762
1763        Parameters
1764        ----------
1765        loss: callable
1766            the new loss function
1767
1768        """
1769        self.loss = loss

Set new loss function.

Parameters

loss : callable
    the new loss function

def set_metrics(self, metrics: dict) -> None:
1771    def set_metrics(self, metrics: dict) -> None:
1772        """Set new metrics.
1773
1774        Parameters
1775        ----------
1776        metrics : dict
1777            the new metric dictionary
1778
1779        """
1780        self.metrics = metrics

Set new metrics.

Parameters

metrics : dict
    the new metric dictionary

def set_transformer( self, transformer: dlc2action.transformer.base_transformer.Transformer) -> None:
1782    def set_transformer(self, transformer: Transformer) -> None:
1783        """Set a new transformer.
1784
1785        Parameters
1786        ----------
1787        transformer: Transformer
1788            the new transformer
1789
1790        """
1791        self.transformer = transformer

Set a new transformer.

Parameters

transformer : Transformer
    the new transformer

def set_predict_functions( self, primary_predict_function: Callable, predict_function: Callable) -> None:
1793    def set_predict_functions(
1794        self, primary_predict_function: Callable, predict_function: Callable
1795    ) -> None:
1796        """Set new predict functions.
1797
1798        Parameters
1799        ----------
1800        primary_predict_function : callable
1801            the new primary predict function
1802        predict_function : callable
1803            the new predict function
1804
1805        """
1806        self.primary_predict_function = primary_predict_function
1807        self.predict_function = predict_function

Set new predict functions.

Parameters

primary_predict_function : callable
    the new primary predict function
predict_function : callable
    the new predict function

def count_classes(self, bouts: bool = False) -> Dict:
1829    def count_classes(self, bouts: bool = False) -> Dict:
1830        """Get a dictionary of class counts in different modes.
1831
1832        Parameters
1833        ----------
1834        bouts : bool, default False
1835            if `True`, instead of frame counts segment counts are returned
1836
1837        Returns
1838        -------
1839        class_counts : dict
1840            a dictionary where first-level keys are "train", "val" and "test", second-level keys are
1841            class names and values are class counts (in frames)
1842
1843        """
1844        class_counts = {}
1845        for x in ["train", "val", "test"]:
1846            try:
1847                class_counts[x] = self.dataset(x).count_classes(bouts)
1848            except ValueError:
1849                class_counts[x] = {k: 0 for k in self.behaviors_dict().keys()}
1850        return class_counts

Get a dictionary of class counts in different modes.

Parameters

bouts : bool, default False
    if `True`, segment counts are returned instead of frame counts

Returns

class_counts : dict
    a dictionary where first-level keys are "train", "val" and "test", second-level keys are
    class names and values are class counts (in frames, or in segments if `bouts` is `True`)
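
For illustration, the returned structure looks like this (class names and counts below are made up):

    {
        "train": {"walking": 120340, "grooming": 45210, "other": 310000},
        "val": {"walking": 15890, "grooming": 6021, "other": 41200},
        "test": {"walking": 0, "grooming": 0, "other": 0},  # zeros when no test split is available
    }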

def behaviors_dict(self) -> Dict:
1852    def behaviors_dict(self) -> Dict:
1853        """Get a behavior dictionary.
1854
1855        Keys are label indices and values are label names.
1856
1857        Returns
1858        -------
1859        behaviors_dict : dict
1860            behavior dictionary
1861
1862        """
1863        return self.dataset().behaviors_dict()

Get a behavior dictionary.

Keys are label indices and values are label names.

Returns

behaviors_dict : dict
    behavior dictionary
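
For example, the returned dictionary maps label indices to label names (hypothetical values):

    {0: "walking", 1: "grooming", 2: "rearing"}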

def update_parameters(self, parameters: Dict) -> None:
1865    def update_parameters(self, parameters: Dict) -> None:
1866        """Update training parameters from a dictionary.
1867
1868        Parameters
1869        ----------
1870        parameters : dict
1871            the update dictionary
1872
1873        """
1874        self.lr = parameters.get("lr", self.lr)
1875        self.parallel = parameters.get("parallel", self.parallel)
1876        self.optimizer = self.optimizer_class(self.model.parameters(), lr=self.lr)
1877        self.verbose = parameters.get("verbose", self.verbose)
1878        self.device = parameters.get("device", self.device)
1879        if self.device == "auto":
1880            self.device = "cuda" if torch.cuda.is_available() else "cpu"
1881        self.augment_train = int(parameters.get("augment_train", self.augment_train))
1882        self.augment_val = int(parameters.get("augment_val", self.augment_val))
1883        ssl_weights = parameters.get("ssl_weights", self.ssl_weights)
1884        if ssl_weights is None:
1885            ssl_weights = [1 for _ in self.ssl_losses]
1886        if not isinstance(ssl_weights, Iterable):
1887            ssl_weights = [ssl_weights for _ in self.ssl_losses]
1888        self.ssl_weights = ssl_weights
1889        self.num_epochs = parameters.get("num_epochs", self.num_epochs)
1890        self.model_save_epochs = parameters.get(
1891            "model_save_epochs", self.model_save_epochs
1892        )
1893        self.model_save_path = parameters.get("model_save_path", self.model_save_path)
1894        self.pseudolabel = parameters.get("pseudolabel", self.pseudolabel)
1895        self.T1 = int(parameters.get("pseudolabel_start", self.T1))
1896        self.T2 = int(parameters.get("alpha_growth_stop", self.T2))
1897        self.t = int(parameters.get("correction_interval", self.t))
1898        self.alpha_f = parameters.get("pseudolabel_alpha_f", self.alpha_f)
1899        self.log_file = parameters.get("log_file", self.log_file)

Update training parameters from a dictionary.

Parameters

parameters : dict
    the update dictionary
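
A usage sketch (the `task` instance and the chosen values are assumptions); only the keys present in the dictionary are updated, all other parameters keep their current values:

    # Hypothetical update: lower the learning rate and extend training.
    task.update_parameters(
        {
            "lr": 1e-4,
            "num_epochs": 50,
            "device": "auto",        # resolved to "cuda" if available, otherwise "cpu"
            "log_file": "training.log",
        }
    )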

def generate_uncertainty_score( self, classes: List, augment_n: int = 0, batch_size: int = 32, method: str = 'least_confidence', predicted: torch.Tensor = None, behaviors_dict: Dict = None) -> Dict:
1901    def generate_uncertainty_score(
1902        self,
1903        classes: List,
1904        augment_n: int = 0,
1905        batch_size: int = 32,
1906        method: str = "least_confidence",
1907        predicted: torch.Tensor = None,
1908        behaviors_dict: Dict = None,
1909    ) -> Dict:
1910        """Generate frame-wise scores for active learning.
1911
1912        Parameters
1913        ----------
1914        classes : list
1915            a list of class names or indices; their confidence scores will be computed separately and stacked
1916        augment_n : int, default 0
1917            the number of augmentations to average over
1918        batch_size : int, default 32
1919            the batch size
1920        method : {"least_confidence", "entropy", "random"}
1921            the method used to calculate the scores from the probability predictions (`"least_confidence"`: `1 - p_i` if
1922            `p_i > 0.5` or `p_i` if `p_i < 0.5`; `"entropy"`: `- p_i * log(p_i) - (1 - p_i) * log(1 - p_i)`)
1923        predicted : torch.Tensor, default None
1924            if not `None`, the predicted tensor to use instead of predicting from the model
1925        behaviors_dict : dict, default None
1926            if not `None`, the behaviors dictionary to use instead of the one from the dataset
1927
1928        Returns
1929        -------
1930        score_dicts : dict
1931            a nested dictionary where first level keys are video ids, second level keys are clip ids and values
1932            are score tensors
1933
1934        """
1935        dataset = self.dataset("train")
1936        if behaviors_dict is None:
1937            behaviors_dict = self.behaviors_dict()
1938        if not isinstance(dataset, BehaviorDataset):
1939            raise TypeError(
1940                f"The dataset parameter has to be either None, string, "
1941                f"BehaviorDataset or Dataloader, got {type(dataset)}"
1942            )
1943        if predicted is None:
1944            predicted = self.predict(
1945                dataset,
1946                raw_output=True,
1947                apply_primary_function=True,
1948                augment_n=augment_n,
1949                batch_size=batch_size,
1950            )
1951        predicted = dataset.generate_full_length_prediction(predicted)
1952        if isinstance(classes[0], str):
1953            behaviors_dict_inverse = {v: k for k, v in behaviors_dict.items()}
1954            classes = [behaviors_dict_inverse[c] for c in classes]
1955        for v_id, v in predicted.items():
1956            for clip_id, vv in v.items():
1957                if method == "least_confidence":
1958                    predicted[v_id][clip_id][vv > 0.5] = 1 - vv[vv > 0.5]
1959                elif method == "entropy":
1960                    predicted[v_id][clip_id][vv != -100] = (
1961                        -vv * torch.log(vv) - (1 - vv) * torch.log(1 - vv)
1962                    )[vv != -100]
1963                elif method == "random":
1964                    predicted[v_id][clip_id] = torch.rand(vv.shape)
1965                else:
1966                    raise ValueError(
1967                        f"The {method} method is not recognized; please choose from ['least_confidence', 'entropy', 'random']"
1968                    )
1969                predicted[v_id][clip_id][vv == -100] = 0
1970
1971        predicted = {
1972            v_id: {clip_id: v[classes, :] for clip_id, v in video_dict.items()}
1973            for v_id, video_dict in predicted.items()
1974        }
1975        return predicted

Generate frame-wise scores for active learning.

Parameters

classes : list
    a list of class names or indices; their confidence scores will be computed separately and stacked
augment_n : int, default 0
    the number of augmentations to average over
batch_size : int, default 32
    the batch size
method : {"least_confidence", "entropy", "random"}
    the method used to calculate the scores from the probability predictions
    (`"least_confidence"`: `1 - p_i` if `p_i > 0.5` or `p_i` if `p_i < 0.5`;
    `"entropy"`: `- p_i * log(p_i) - (1 - p_i) * log(1 - p_i)`;
    `"random"`: a uniformly random score for every frame)
predicted : torch.Tensor, default None
    if not `None`, the predicted tensor to use instead of predicting from the model
behaviors_dict : dict, default None
    if not `None`, the behaviors dictionary to use instead of the one from the dataset

Returns

score_dicts : dict
    a nested dictionary where first level keys are video ids, second level keys are clip ids
    and values are score tensors
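
The two scoring rules can be illustrated on a toy probability tensor, mirroring the loop in the source above (a standalone sketch, independent of any trained model):

    import torch

    p = torch.tensor([0.1, 0.4, 0.6, 0.9, -100.0])  # -100 marks frames with no prediction
    valid = p != -100

    # "least_confidence": 1 - p where p > 0.5, p otherwise
    lc = p.clone()
    lc[p > 0.5] = 1 - p[p > 0.5]
    lc[~valid] = 0

    # "entropy": -p * log(p) - (1 - p) * log(1 - p), kept only for valid frames
    ent = torch.zeros_like(p)
    ent[valid] = (-p * torch.log(p) - (1 - p) * torch.log(1 - p))[valid]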

def generate_bald_score( self, classes: List, augment_n: int = 0, batch_size: int = 32, num_models: int = 10, kernel_size: int = 11) -> Dict:
1977    def generate_bald_score(
1978        self,
1979        classes: List,
1980        augment_n: int = 0,
1981        batch_size: int = 32,
1982        num_models: int = 10,
1983        kernel_size: int = 11,
1984    ) -> Dict:
1985        """Generate frame-wise Bayesian Active Learning by Disagreement scores for active learning.
1986
1987        Parameters
1988        ----------
1989        classes : list
1990            a list of class names or indices; their confidence scores will be computed separately and stacked
1991        augment_n : int, default 0
1992            the number of augmentations to average over
1993        batch_size : int, default 32
1994            the batch size
1995        num_models : int, default 10
1996            the number of dropout masks to apply
1997        kernel_size : int, default 11
1998            the size of the smoothing gaussian kernel
1999
2000        Returns
2001        -------
2002        score_dicts : dict
2003            a nested dictionary where first level keys are video ids, second level keys are clip ids and values
2004            are score tensors
2005
2006        """
2007        dataset = self.dataset("train")
2008        dataset = self._get_dataset(dataset)
2009        if not isinstance(dataset, BehaviorDataset):
2010            raise TypeError(
2011                f"The dataset parameter has to be either None, string, "
2012                f"BehaviorDataset or Dataloader, got {type(dataset)}"
2013            )
2014        predictions = []
2015        for _ in range(num_models):
2016            predicted = self.predict(
2017                dataset,
2018                raw_output=True,
2019                apply_primary_function=True,
2020                augment_n=augment_n,
2021                batch_size=batch_size,
2022                train_mode=True,
2023            )
2024            predicted = dataset.generate_full_length_prediction(predicted)
2025            if isinstance(classes[0], str):
2026                behaviors_dict_inverse = {
2027                    v: k for k, v in self.behaviors_dict().items()
2028                }
2029                classes = [behaviors_dict_inverse[c] for c in classes]
2030            for v_id, v in predicted.items():
2031                for clip_id, vv in v.items():
2032                    vv[vv != -100] = (vv[vv != -100] > 0.5).int().float()
2033                    predicted[v_id][clip_id] = vv
2034            predicted = {
2035                v_id: {clip_id: v[classes, :] for clip_id, v in video_dict.items()}
2036                for v_id, video_dict in predicted.items()
2037            }
2038            predictions.append(predicted)
2039        result = {v_id: {} for v_id in predictions[0]}
2040        r = range(-int(kernel_size / 2), int(kernel_size / 2) + 1)
2041        gauss = [1 / (1 * sqrt(2 * pi)) * exp(-float(x) ** 2 / (2 * 1**2)) for x in r]
2042        gauss = [x / sum(gauss) for x in gauss]
2043        kernel = torch.FloatTensor([[gauss]])
2044        for v_id in predictions[0]:
2045            for clip_id in predictions[0][v_id]:
2046                consensus = (
2047                    (
2048                        torch.mean(
2049                            torch.stack([x[v_id][clip_id] for x in predictions]), dim=0
2050                        )
2051                        > 0.5
2052                    )
2053                    .int()
2054                    .float()
2055                )
2056                consensus[predictions[0][v_id][clip_id] == -100] = -100
2057                result[v_id][clip_id] = torch.zeros(consensus.shape)
2058                for x in predictions:
2059                    result[v_id][clip_id] += (x[v_id][clip_id] != consensus).int()
2060                result[v_id][clip_id] = result[v_id][clip_id] * 2 / num_models
2061                res = torch.zeros(result[v_id][clip_id].shape)
2062                for i in range(len(classes)):
2063                    res[i, floor(kernel_size // 2) : -floor(kernel_size // 2)] = (
2064                        torch.nn.functional.conv1d(
2065                            result[v_id][clip_id][i, :].unsqueeze(0).unsqueeze(0),
2066                            kernel,
2067                        )[0, ...]
2068                    )
2069                result[v_id][clip_id] = res
2070        return result

Generate frame-wise Bayesian Active Learning by Disagreement scores for active learning.

Parameters

classes : list
    a list of class names or indices; their confidence scores will be computed separately and stacked
augment_n : int, default 0
    the number of augmentations to average over
batch_size : int, default 32
    the batch size
num_models : int, default 10
    the number of dropout masks to apply
kernel_size : int, default 11
    the size of the smoothing gaussian kernel

Returns

score_dicts : dict
    a nested dictionary where first level keys are video ids, second level keys are clip ids
    and values are score tensors
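
A usage sketch (the `task` instance and class names are assumptions): the nested score dictionary can be flattened to rank clips by mean disagreement and pick candidates for annotation.

    # Hypothetical active-learning step: rank clips by mean BALD disagreement.
    scores = task.generate_bald_score(
        classes=["walking", "grooming"],  # hypothetical class names
        num_models=10,
        kernel_size=11,
    )
    ranking = sorted(
        (
            (v_id, clip_id, s.mean().item())
            for v_id, clips in scores.items()
            for clip_id, s in clips.items()
        ),
        key=lambda x: x[2],
        reverse=True,
    )
    print(ranking[:5])  # the clips the dropout ensemble disagrees on most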

def get_normalization_stats(self) -> Optional[Dict]:
2072    def get_normalization_stats(self) -> Optional[Dict]:
2073        """Get the normalization statistics of the dataset.
2074
2075        Returns
2076        -------
2077        stats : dict
2078            a dictionary containing the mean and standard deviation of the dataset
2079
2080        """
2081        return self.train_dataloader.dataset.stats

Get the normalization statistics of the dataset.

Returns

stats : dict
    a dictionary containing the mean and standard deviation of the dataset