dlc2action.task.universal_task

Training and inference

#
# Copyright 2020-2022 by A. Mathis Group and contributors. All rights reserved.
#
# This project and all its files are licensed under GNU AGPLv3 or later version. A copy is included in dlc2action/LICENSE.AGPL.
#
"""
Training and inference
"""

from typing import Callable, Dict, Union, Tuple, List, Set, Any, Optional

import numpy as np
from torch.optim import Optimizer, Adam
from torch import nn
from torch.utils.data import DataLoader
import torch
from collections.abc import Iterable
from tqdm import tqdm
import warnings
from random import randint
from matplotlib import pyplot as plt
from matplotlib import cm
from collections import defaultdict
from dlc2action.transformer.base_transformer import Transformer, EmptyTransformer
from dlc2action.data.dataset import BehaviorDataset
from dlc2action.model.base_model import Model, LoadedModel
from dlc2action.metric.base_metric import Metric
import os
import random
from copy import deepcopy, copy
from optuna.trial import Trial
from optuna import TrialPruned
from math import pi, sqrt, exp, floor, ceil


class MyDataParallel(nn.DataParallel):
    """A `nn.DataParallel` wrapper that exposes the custom methods of the wrapped model"""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.process_labels = self.module.process_labels

    def freeze_feature_extractor(self):
        self.module.freeze_feature_extractor()

    def unfreeze_feature_extractor(self):
        self.module.unfreeze_feature_extractor()

    def transform_labels(self, device):
        return self.module.transform_labels(device)

    def logit_scale(self):
        return self.module.logit_scale()

    def main_task_off(self):
        self.module.main_task_off()

    def state_dict(self, *args, **kwargs):
        return self.module.state_dict(*args, **kwargs)

    def ssl_on(self):
        self.module.ssl_on()

    def ssl_off(self):
        self.module.ssl_off()

    def extract_features(self, x, start=0):
        return self.module.extract_features(x, start)


class Task:
    """
    A universal trainer class that performs training, evaluation and prediction for all types of tasks and data
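
    Examples
    --------
    A minimal usage sketch; the dataloaders, model and loss here stand in for
    objects you would construct with the rest of `dlc2action`:

    >>> task = Task(
    ...     train_dataloader=train_dl,
    ...     model=my_model,
    ...     loss=nn.CrossEntropyLoss(),
    ...     num_epochs=10,
    ...     val_dataloader=val_dl,
    ...     device="auto",
    ... )
    >>> loss_log, metrics_log = task.train()
    >>> prediction = task.predict("val", augment_n=5)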
    """

    def __init__(
        self,
        train_dataloader: DataLoader,
        model: Union[nn.Module, Model],
        loss: Callable[[torch.Tensor, torch.Tensor], float],
        num_epochs: int = 0,
        transformer: Transformer = None,
        ssl_losses: List = None,
        ssl_weights: List = None,
        lr: float = 1e-3,
        metrics: Dict = None,
        val_dataloader: DataLoader = None,
        test_dataloader: DataLoader = None,
        optimizer: Optimizer = None,
        device: str = "cuda",
        verbose: bool = True,
        log_file: Union[str, None] = None,
        augment_train: int = 1,
        augment_val: int = 0,
        validation_interval: int = 1,
        predict_function: Union[Callable[[torch.Tensor], torch.Tensor], None] = None,
        primary_predict_function: Callable = None,
        exclusive: bool = True,
        ignore_tags: bool = True,
        threshold: float = 0.5,
        model_save_path: str = None,
        model_save_epochs: int = 5,
        pseudolabel: bool = False,
        pseudolabel_start: int = 100,
        correction_interval: int = 2,
        pseudolabel_alpha_f: float = 3,
        alpha_growth_stop: int = 600,
        parallel: bool = False,
        skip_metrics: List = None,
    ) -> None:
        """
        Parameters
        ----------
        train_dataloader : torch.utils.data.DataLoader
            a training dataloader
        model : dlc2action.model.base_model.Model
            a model
        loss : callable
            a loss function
        num_epochs : int, default 0
            the number of epochs
        transformer : dlc2action.transformer.base_transformer.Transformer, optional
            a transformer
        ssl_losses : list, optional
            a list of SSL losses
        ssl_weights : list, optional
            a list of SSL weights (if not provided initializes to 1)
        lr : float, default 1e-3
            learning rate
        metrics : dict, optional
            a dictionary of metric functions
        val_dataloader : torch.utils.data.DataLoader, optional
            a validation dataloader
        test_dataloader : torch.utils.data.DataLoader, optional
            a test dataloader
        optimizer : torch.optim.Optimizer, optional
            an optimizer (`Adam` by default)
        device : str, default 'cuda'
            the device to train the model on
        verbose : bool, default True
            if `True`, the process is described in standard output
        log_file : str, optional
            the path to a text file where the process will be logged
        augment_train : {0, 1}, default 1
            the number of augmentations to apply at training
        augment_val : int, default 0
            the number of augmentations to apply at validation
        validation_interval : int, default 1
            every time this number of epochs passes, validation metrics are computed
        predict_function : callable, optional
            a function that maps probabilities to class predictions (if not provided, a default is generated)
        primary_predict_function : callable, optional
            a function that maps model output to probabilities (if not provided, initialized as identity)
        exclusive : bool, default True
            set to `False` for multi-label classification
        ignore_tags : bool, default True
            if `True`, samples with different meta tags will be mixed in batches
        threshold : float, default 0.5
            the threshold used for the default prediction function in multi-label classification
        model_save_path : str, optional
            the path to the folder where model checkpoints will be saved (checkpoints will not be saved if the path
            is not provided)
        model_save_epochs : int, default 5
            the interval for saving the model checkpoints (the last epoch is always saved)
        pseudolabel : bool, default False
            if `True`, the pseudolabeling procedure will be applied
        pseudolabel_start : int, default 100
            pseudolabeling starts after this epoch
        correction_interval : int, default 2
            after this number of epochs, if the pseudolabeling is on, the model is trained on the labeled data and
            new pseudolabels are generated
        pseudolabel_alpha_f : float, default 3
            the maximum value of pseudolabeling alpha
        alpha_growth_stop : int, default 600
            pseudolabeling alpha stops growing after this epoch
        parallel : bool, default False
            if `True`, the model is wrapped in `nn.DataParallel` to run on multiple GPUs
        skip_metrics : list, optional
            a list of metric names that are skipped when computing metrics on the training set
        """

        # pseudolabeling might be buggy right now -- not using it!
        if skip_metrics is None:
            skip_metrics = []
        self.train_dataloader = train_dataloader
        self.val_dataloader = val_dataloader
        self.test_dataloader = test_dataloader
        self.transformer = transformer
        self.num_epochs = num_epochs
        self.skip_metrics = skip_metrics
        self.verbose = verbose
        self.augment_train = int(augment_train)
        self.augment_val = int(augment_val)
        self.ignore_tags = ignore_tags
        self.validation_interval = int(validation_interval)
        self.log_file = log_file
        self.loss = loss
        self.model_save_path = model_save_path
        self.model_save_epochs = model_save_epochs
        self.epoch = 0

        if metrics is None:
            metrics = {}
        self.metrics = metrics

        if optimizer is None:
            optimizer = Adam

        if ssl_weights is None:
            ssl_weights = [1 for _ in (ssl_losses or [])]
        if not isinstance(ssl_weights, Iterable):
            ssl_weights = [ssl_weights for _ in (ssl_losses or [])]
        self.ssl_weights = ssl_weights

        self.optimizer_class = optimizer
        self.lr = lr
        if not isinstance(model, Model):
            self.model = LoadedModel(model=model)
        else:
            self.set_model(model)
        self.parallel = parallel

        if self.transformer is None:
            self.augment_val = 0
            self.augment_train = 0
            self.transformer = EmptyTransformer()

        if self.augment_train > 1:
            warnings.warn(
                'The "augment_train" parameter is too large -> setting it to 1.'
            )
            self.augment_train = 1

        try:
            if device == "auto":
                device = "cuda" if torch.cuda.is_available() else "cpu"
            self.device = torch.device(device)
        except Exception:
            raise ValueError(f"The format of the device ({device}) is incorrect")

        if ssl_losses is None:
            self.ssl_losses = [lambda x, y: 0]
        else:
            self.ssl_losses = ssl_losses

        if primary_predict_function is None:
            if exclusive:
                primary_predict_function = lambda x: torch.softmax(x, dim=1)
            else:
                primary_predict_function = lambda x: torch.sigmoid(x)
        self.primary_predict_function = primary_predict_function

        if predict_function is None:
            if exclusive:
                self.predict_function = lambda x: torch.max(x.data, 1)[1]
            else:
                self.predict_function = lambda x: (x > threshold).int()
        else:
            self.predict_function = predict_function

        self.pseudolabel = pseudolabel
        self.alpha_f = pseudolabel_alpha_f
        self.T2 = alpha_growth_stop
        self.T1 = pseudolabel_start
        self.t = correction_interval
        if self.T2 <= self.T1:
            raise ValueError(
                f"The pseudolabel_start parameter has to be smaller than alpha_growth_stop; got "
                f"{pseudolabel_start=} and {alpha_growth_stop=}"
            )
        self.decision_thresholds = [0.5 for _ in self.behaviors_dict()]

    def save_checkpoint(self, checkpoint_path: str) -> None:
        """
        Save a general checkpoint

        Parameters
        ----------
        checkpoint_path : str
            the path where the checkpoint will be saved
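
        Examples
        --------
        A sketch, assuming `task` is a `Task` instance and the target folder exists:

        >>> task.save_checkpoint("checkpoints/last.pt")  # doctest: +SKIP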
        """

        torch.save(
            {
                "epoch": self.epoch,
                "model_state_dict": self.model.state_dict(),
                "optimizer_state_dict": self.optimizer.state_dict(),
            },
            checkpoint_path,
        )

    def load_from_checkpoint(
        self, checkpoint_path, only_model: bool = False, load_strict: bool = True
    ) -> None:
        """
        Load from a checkpoint

        Parameters
        ----------
        checkpoint_path : str
            the path to the checkpoint
        only_model : bool, default False
            if `True`, only the model state dictionary will be loaded (and not the epoch and the optimizer state
            dictionary)
        load_strict : bool, default True
            if `True`, any inconsistencies in state dictionaries are regarded as errors
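
        Examples
        --------
        A sketch of a save/load round trip (assuming `task` is a `Task` instance
        and the path is a placeholder):

        >>> task.save_checkpoint("checkpoints/last.pt")  # doctest: +SKIP
        >>> task.load_from_checkpoint("checkpoints/last.pt", only_model=True)  # doctest: +SKIP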
        """

        if checkpoint_path is None:
            return
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        self.model.to(self.device)
        self.model.load_state_dict(checkpoint["model_state_dict"], strict=load_strict)
        if not only_model:
            self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
            self.epoch = checkpoint["epoch"]

    def save_model(self, save_path: str) -> None:
        """
        Save the model state dictionary

        Parameters
        ----------
        save_path : str
            the path where the state will be saved
        """

        torch.save(self.model.state_dict(), save_path)
        print("saved the model successfully")

    def _apply_predict_functions(self, predicted):
        """
        Map from model output to prediction
        """

        predicted = self.primary_predict_function(predicted)
        predicted = self.predict_function(predicted)
        return predicted

    def _get_prediction(
        self,
        main_input: Dict,
        tags: torch.Tensor,
        ssl_inputs: List = None,
        ssl_targets: List = None,
        augment_n: int = 0,
        embedding: bool = False,
        subsample: List = None,
    ) -> Tuple:
        """
        Get the prediction of `self.model` for the input averaged over `augment_n` augmentations
        """

        if augment_n == 0:
            augment_n = 1
            augment = False
        else:
            augment = True
        model_input, ssl_inputs, ssl_targets = self.transformer.transform(
            deepcopy(main_input),
            ssl_inputs=ssl_inputs,
            ssl_targets=ssl_targets,
            augment=augment,
            subsample=subsample,
        )
        if not embedding:
            predicted, ssl_predicted = self.model(model_input, ssl_inputs, tag=tags)
        else:
            predicted = self.model.extract_features(model_input)
            ssl_predicted = None
        if self.parallel and predicted is not None and len(predicted.shape) == 4:
            predicted = predicted.reshape(
                (-1, model_input.shape[0], *predicted.shape[2:])
            )
        if predicted is not None and augment_n > 1:
            self.model.ssl_off()
            for i in range(augment_n - 1):
                model_input, *_ = self.transformer.transform(
                    deepcopy(main_input), augment=augment
                )
                if not embedding:
                    pred, _ = self.model(model_input, None, tag=tags)
                else:
                    pred = self.model.extract_features(model_input)
                predicted += pred.detach()
            self.model.ssl_on()
            predicted /= augment_n
        if self.model.process_labels:
            class_embedding = self.model.transform_labels(predicted.device)
            beh_dict = self.behaviors_dict()
            class_embedding = torch.cat(
                [class_embedding[beh_dict[k]] for k in sorted(beh_dict.keys())], 0
            )
            predicted = {
                "video": predicted,
                "text": class_embedding,
                "logit_scale": self.model.logit_scale(),
                "device": predicted.device,
            }
        return predicted, ssl_predicted, ssl_targets

    def _ssl_loss(self, ssl_predicted, ssl_targets):
        """
        Apply SSL losses
        """

        ssl_loss = []
        for loss, predicted, target in zip(self.ssl_losses, ssl_predicted, ssl_targets):
            ssl_loss.append(loss(predicted, target))
        return ssl_loss

    def _loss_function(
        self,
        batch: Dict,
        augment_n: int,
        temporal_subsampling_size: float = None,
        skip_metrics: List = None,
    ) -> Tuple[float, float]:
        """
        Calculate the loss function and the metric for a dataloader batch

        The predictions are averaged over `augment_n` augmentations.
        """

        if "target" not in batch or torch.isnan(batch["target"]).all():
            raise ValueError("Cannot compute loss function with nan targets!")
        main_input = {k: v.to(self.device) for k, v in batch["input"].items()}
        main_target, ssl_targets, ssl_inputs = None, None, None
        main_target = batch["target"].to(self.device)
        if "ssl_targets" in batch:
            ssl_targets = [
                {k: v.to(self.device) for k, v in x.items()}
                if isinstance(x, dict)
                else None
                for x in batch["ssl_targets"]
            ]
        if "ssl_inputs" in batch:
            ssl_inputs = [
                {k: v.to(self.device) for k, v in x.items()}
                if isinstance(x, dict)
                else None
                for x in batch["ssl_inputs"]
            ]
        if temporal_subsampling_size is not None:
            subsample = sorted(
                random.sample(
                    range(main_target.shape[-1]),
                    int(temporal_subsampling_size * main_target.shape[-1]),
                )
            )
            main_target = main_target[..., subsample]
        else:
            subsample = None

        if self.ignore_tags:
            tag = None
        else:
            tag = batch.get("tag")
        predicted, ssl_predicted, ssl_targets = self._get_prediction(
            main_input, tag, ssl_inputs, ssl_targets, augment_n, subsample=subsample
        )
        del main_input, ssl_inputs
        return self._compute(
            ssl_predicted,
            ssl_targets,
            predicted,
            main_target,
            tag=batch.get("tag"),
            skip_metrics=skip_metrics,
        )

    def _compute(
        self,
        ssl_predicted: List,
        ssl_targets: List,
        predicted: torch.Tensor,
        main_target: torch.Tensor,
        tag: Any = None,
        skip_loss: bool = False,
        apply_primary_function: bool = True,
        skip_metrics: List = None,
    ) -> Tuple[float, float]:
        """
        Compute the losses and metrics from predictions
        """

        if skip_metrics is None:
            skip_metrics = []
        if not skip_loss:
            ssl_losses = self._ssl_loss(ssl_predicted, ssl_targets)
            if predicted is not None:
                loss = self.loss(predicted, main_target)
            else:
                loss = 0
        else:
            ssl_losses, loss = [], 0

        if predicted is not None:
            if isinstance(predicted, dict):
                predicted = {
                    k: v.detach()
                    for k, v in predicted.items()
                    if isinstance(v, torch.Tensor)
                }
            else:
                predicted = predicted.detach()
            if apply_primary_function:
                predicted = self.primary_predict_function(predicted)
            predicted_transformed = self.predict_function(predicted)
            for name, metric_function in self.metrics.items():
                if name not in skip_metrics:
                    if metric_function.needs_raw_data:
                        metric_function.update(predicted, main_target, tag)
                    else:
                        metric_function.update(
                            predicted_transformed,
                            main_target,
                            tag,
                        )
        return loss, ssl_losses

    def _calculate_metrics(self) -> Dict:
        """
        Calculate the final values of epoch metrics
        """

        epoch_metrics = {}
        for metric_name, metric in self.metrics.items():
            m = metric.calculate()
            if type(m) is dict:
                for k, v in m.items():
                    if type(v) is torch.Tensor:
                        v = v.item()
                    epoch_metrics[f"{metric_name}_{k}"] = v
            else:
                if type(m) is torch.Tensor:
                    m = m.item()
                epoch_metrics[metric_name] = m
            metric.reset()
        return epoch_metrics

    def _run_epoch(
        self,
        dataloader: DataLoader,
        mode: str,
        augment_n: int,
        verbose: bool = False,
        unlabeled: bool = None,
        alpha: float = 1,
        temporal_subsampling_size: float = None,
    ) -> Tuple:
        """
        Run one epoch on a dataloader

        The predictions are averaged over `augment_n` augmentations.
        Use "train" mode for training and "val" mode for evaluation.
        """

        if mode == "train":
            self.model.train()
        elif mode == "val":
            self.model.eval()
        else:
            raise ValueError(
                f'Mode {mode} is not recognized, please choose either "train" for training or "val" for validation'
            )
        if self.ignore_tags:
            tags = [None]
        else:
            tags = dataloader.dataset.get_tags()
        epoch_loss = 0
        epoch_ssl_loss = defaultdict(lambda: 0)
        data_len = 0
        set_pars = dataloader.dataset.set_indexing_parameters
        skip_metrics = self.skip_metrics if mode == "train" else None
        for tag in tags:
            set_pars(unlabeled=unlabeled, tag=tag)
            data_len += len(dataloader)
            if verbose:
                dataloader = tqdm(dataloader)
            for batch in dataloader:
                loss, ssl_losses = self._loss_function(
                    batch,
                    augment_n,
                    temporal_subsampling_size=temporal_subsampling_size,
                    skip_metrics=skip_metrics,
                )
                if loss != 0:
                    loss = loss * alpha
                    epoch_loss += loss.item()
                for i, (ssl_loss, weight) in enumerate(
                    zip(ssl_losses, self.ssl_weights)
                ):
                    if ssl_loss != 0:
                        epoch_ssl_loss[i] += ssl_loss.item()
                        loss = loss + weight * ssl_loss
                if mode == "train":
                    self.optimizer.zero_grad()
                    if loss.requires_grad:
                        loss.backward()
                    self.optimizer.step()

        epoch_loss = epoch_loss / data_len
        epoch_ssl_loss = {k: v / data_len for k, v in epoch_ssl_loss.items()}
        epoch_metrics = self._calculate_metrics()

        return epoch_loss, epoch_ssl_loss, epoch_metrics

    def train(
        self,
        trial: Trial = None,
        optimized_metric: str = None,
        to_ram: bool = False,
        autostop_interval: int = 30,
        autostop_threshold: float = 0.001,
        autostop_metric: str = None,
        main_task_on: bool = True,
        ssl_on: bool = True,
        temporal_subsampling_size: float = None,
        loading_bar: bool = False,
    ) -> Tuple:
        """
        Train the task and return a log of epoch-average loss and metric

        You can use the autostop parameters to finish training when the metrics are not improving. Training will be
        stopped if the average value of `autostop_metric` over the last `autostop_interval` epochs is smaller than
        the average over the previous `autostop_interval` epochs + `autostop_threshold`. For example, if the
        current epoch is 120 and `autostop_interval` is 50, the averages over epochs 70-120 and 20-70 will be compared.

        Parameters
        ----------
        trial : Trial
            an `optuna` trial (for hyperparameter searches)
        optimized_metric : str
            the name of the metric being optimized (for hyperparameter searches)
        to_ram : bool, default False
            if `True`, the dataset will be loaded in RAM (this speeds up the calculations but can lead to crashes
            if the dataset is too large)
        autostop_interval : int, default 30
            the number of epochs to average the autostop metric over
        autostop_threshold : float, default 0.001
            the autostop difference threshold
        autostop_metric : str, optional
            the autostop metric (can be any one of the tracked metrics or `'loss'`)
        main_task_on : bool, default True
            if `False`, the main task (action segmentation) will not be used in training
        ssl_on : bool, default True
            if `False`, the SSL task will not be used in training
        temporal_subsampling_size : float, optional
            the fraction of frames to randomly keep in each segment at every training epoch (no subsampling if not
            provided)
        loading_bar : bool, default False
            if `True`, a progress bar is displayed for each training epoch

        Returns
        -------
        loss_log: list
            a list of float loss function values for each epoch
        metrics_log: dict
            a dictionary of metric value logs (first-level keys are 'train' and 'val', second-level keys are metric
            names, values are lists of function values)
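
        Examples
        --------
        A minimal sketch (assuming `task` is a configured `Task`; `"f1"` stands in
        for whatever metric you are tracking):

        >>> loss_log, metrics_log = task.train(temporal_subsampling_size=0.75)
        >>> val_losses = loss_log["val"]  # per-epoch validation losses
        >>> val_f1 = metrics_log["val"]["f1"]  # doctest: +SKIP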
        """

        if self.parallel and not isinstance(self.model, nn.DataParallel):
            self.model = MyDataParallel(self.model)
        self.model.to(self.device)
        assert autostop_metric in [None, "loss"] + list(self.metrics)
        autostop_interval //= self.validation_interval
        if trial is not None and optimized_metric is None:
            raise ValueError(
                "You need to provide the optimized metric name (optimized_metric parameter) "
                "for optuna pruning to work!"
            )
        if to_ram:
            print("transferring datasets to RAM...")
            self.train_dataloader.dataset.to_ram()
            if self.val_dataloader is not None and len(self.val_dataloader) > 0:
                self.val_dataloader.dataset.to_ram()
        loss_log = {"train": [], "val": []}
        metrics_log = {"train": defaultdict(lambda: []), "val": defaultdict(lambda: [])}
        if not main_task_on:
            self.model.main_task_off()
        if not ssl_on:
            self.model.ssl_off()
        while self.epoch < self.num_epochs:
            self.epoch += 1
            unlabeled = None
            alpha = 1
            if self.pseudolabel:
                if self.epoch >= self.T1:
                    unlabeled = (self.epoch - self.T1) % self.t != 0
                    if unlabeled:
                        alpha = self._alpha(self.epoch)
                else:
                    unlabeled = False
            epoch_loss, epoch_ssl_loss, epoch_metrics = self._run_epoch(
                dataloader=self.train_dataloader,
                mode="train",
                augment_n=self.augment_train,
                unlabeled=unlabeled,
                alpha=alpha,
                temporal_subsampling_size=temporal_subsampling_size,
                verbose=loading_bar,
            )
            loss_log["train"].append(epoch_loss)
            epoch_string = f"[epoch {self.epoch}]"
            if self.pseudolabel:
                if unlabeled:
                    epoch_string += " (unlabeled)"
                else:
                    epoch_string += " (labeled)"
            epoch_string += f": loss {epoch_loss:.4f}"

            if len(epoch_ssl_loss) != 0:
                for key, value in sorted(epoch_ssl_loss.items()):
                    metrics_log["train"][f"ssl_loss_{key}"].append(value)
                    epoch_string += f", ssl_loss_{key} {value:.4f}"

            for metric_name, metric_value in sorted(epoch_metrics.items()):
                if metric_name not in self.skip_metrics:
                    metrics_log["train"][metric_name].append(metric_value)
                    epoch_string += f", {metric_name} {metric_value:.3f}"

            if (
                self.val_dataloader is not None
                and self.epoch % self.validation_interval == 0
            ):
                with torch.no_grad():
                    epoch_string += "\n"
                    (
                        val_epoch_loss,
                        val_epoch_ssl_loss,
                        val_epoch_metrics,
                    ) = self._run_epoch(
                        dataloader=self.val_dataloader,
                        mode="val",
                        augment_n=self.augment_val,
                    )
                    loss_log["val"].append(val_epoch_loss)
                    epoch_string += f"validation: loss {val_epoch_loss:.4f}"

                    if len(val_epoch_ssl_loss) != 0:
                        for key, value in sorted(val_epoch_ssl_loss.items()):
                            metrics_log["val"][f"ssl_loss_{key}"].append(value)
                            epoch_string += f", ssl_loss_{key} {value:.4f}"

                    for metric_name, metric_value in sorted(val_epoch_metrics.items()):
                        metrics_log["val"][metric_name].append(metric_value)
                        epoch_string += f", {metric_name} {metric_value:.3f}"

                if trial is not None:
                    if optimized_metric not in metrics_log["val"]:
                        raise ValueError(
                            f"The {optimized_metric} metric set for optimization is not being logged!"
                        )
                    trial.report(metrics_log["val"][optimized_metric][-1], self.epoch)
                    if trial.should_prune():
                        raise TrialPruned()

            if self.verbose:
                print(epoch_string)

            if self.log_file is not None:
                with open(self.log_file, "a") as f:
                    f.write(epoch_string + "\n")

            save_condition = (
                (self.model_save_epochs != 0)
                and (self.epoch % self.model_save_epochs == 0)
            ) or (self.epoch == self.num_epochs)

            if self.epoch > 0 and save_condition and self.model_save_path is not None:
                epoch_s = str(self.epoch).zfill(len(str(self.num_epochs)))
                self.save_checkpoint(
                    os.path.join(self.model_save_path, f"epoch{epoch_s}.pt")
                )

            if self.pseudolabel and self.epoch >= self.T1 and not unlabeled:
                self._set_pseudolabels()

            if autostop_metric == "loss":
                if len(loss_log["val"]) > autostop_interval * 2:
                    if (
                        np.mean(loss_log["val"][-autostop_interval:])
                        < np.mean(
                            loss_log["val"][-2 * autostop_interval : -autostop_interval]
                        )
                        + autostop_threshold
                    ):
                        break
            elif autostop_metric in metrics_log["val"]:
                if len(metrics_log["val"][autostop_metric]) > autostop_interval * 2:
                    if (
                        np.mean(
                            metrics_log["val"][autostop_metric][-autostop_interval:]
                        )
                        < np.mean(
                            metrics_log["val"][autostop_metric][
                                -2 * autostop_interval : -autostop_interval
                            ]
                        )
                        + autostop_threshold
                    ):
                        break

        metrics_log = {k: dict(v) for k, v in metrics_log.items()}

        return loss_log, metrics_log

    def evaluate_prediction(
        self,
        prediction: Union[torch.Tensor, Dict],
        data: Union[DataLoader, BehaviorDataset, str] = None,
        batch_size: int = 32,
    ) -> Tuple:
        """
        Compute metrics for a prediction

        Parameters
        ----------
        prediction : torch.Tensor | dict
            the prediction
        data : torch.utils.data.DataLoader | dlc2action.data.dataset.BehaviorDataset, optional
            the data the prediction was made for (if not provided, the validation dataset is used)
        batch_size : int, default 32
            the batch size

        Returns
        -------
        loss : float
            the average value of the loss function
        metric : dict
            a dictionary of average values of metric functions
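
        Examples
        --------
        A sketch (assuming `task` is a trained `Task`); `predict` with
        `raw_output=True` returns the probabilities that this method expects:

        >>> prediction = task.predict("val", raw_output=True)
        >>> _, metrics = task.evaluate_prediction(prediction, data="val")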
        """

        if type(data) is not DataLoader:
            dataset = self._get_dataset(data)
            data = DataLoader(dataset, shuffle=False, batch_size=batch_size)
        else:
            dataset = data.dataset
        epoch_loss = 0
        if isinstance(prediction, dict):
            num_classes = len(self.behaviors_dict())
            length = dataset.len_segment()
            coords = dataset.annotation_store.get_original_coordinates()
            for batch in data:
                main_target = batch["target"]
                pr_coords = coords[batch["index"]]
                predicted = torch.zeros((len(pr_coords), num_classes, length))
                for i, c in enumerate(pr_coords):
                    video_id = dataset.input_store.get_video_id(c)
                    clip_id = dataset.input_store.get_clip_id(c)
                    start, end = dataset.input_store.get_clip_start_end(c)
                    predicted[i, :, : end - start] = prediction[video_id][clip_id][
                        :, start:end
                    ]
                self._compute(
                    [],
                    [],
                    predicted,
                    main_target,
                    skip_loss=True,
                    tag=batch.get("tag"),
                    apply_primary_function=False,
                )
        else:
            for batch in data:
                main_target = batch["target"]
                predicted = prediction[batch["index"]]
                self._compute(
                    [],
                    [],
                    predicted,
                    main_target,
                    skip_loss=True,
                    tag=batch.get("tag"),
                    apply_primary_function=False,
                )
        epoch_metrics = self._calculate_metrics()
        strings = [
            f"{metric_name} {metric_value:.3f}"
            for metric_name, metric_value in epoch_metrics.items()
        ]
        val_string = ", ".join(sorted(strings))
        print(val_string)
        return epoch_loss, epoch_metrics

    def evaluate(
        self,
        data: Union[DataLoader, BehaviorDataset, str] = None,
        augment_n: int = 0,
        batch_size: int = 32,
        verbose: bool = True,
    ) -> Tuple:
        """
        Evaluate the Task model

        Parameters
        ----------
        data : torch.utils.data.DataLoader | dlc2action.data.dataset.BehaviorDataset, optional
            the data to evaluate on (if not provided, evaluate on the Task validation dataset)
        augment_n : int, default 0
            the number of augmentations to average results over
        batch_size : int, default 32
            the batch size
        verbose : bool, default True
            if `True`, the process is reported to standard output

        Returns
        -------
        loss : float
            the average value of the loss function
        ssl_loss : float
            the average value of the SSL loss function
        metric : dict
            a dictionary of average values of metric functions
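
        Examples
        --------
        A sketch (assuming `task` is a trained `Task`):

        >>> loss, ssl_loss, metrics = task.evaluate("val", augment_n=0)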
        """

        if self.parallel and not isinstance(self.model, nn.DataParallel):
            self.model = MyDataParallel(self.model)
        self.model.to(self.device)
        if type(data) is not DataLoader:
            data = self._get_dataset(data)
            data = DataLoader(data, shuffle=False, batch_size=batch_size)
        with torch.no_grad():
            epoch_loss, epoch_ssl_loss, epoch_metrics = self._run_epoch(
                dataloader=data, mode="val", augment_n=augment_n, verbose=verbose
            )
        val_string = f"loss {epoch_loss:.4f}"
        for metric_name, metric_value in sorted(epoch_metrics.items()):
            val_string += f", {metric_name} {metric_value:.3f}"
        print(val_string)
        return epoch_loss, epoch_ssl_loss, epoch_metrics

    def predict(
        self,
        data: Union[DataLoader, BehaviorDataset, str] = None,
        raw_output: bool = False,
        apply_primary_function: bool = True,
        augment_n: int = 0,
        batch_size: int = 32,
        train_mode: bool = False,
        to_ram: bool = False,
        embedding: bool = False,
    ) -> torch.Tensor:
        """
        Make a prediction with the Task model

        Parameters
        ----------
        data : torch.utils.data.DataLoader | dlc2action.data.dataset.BehaviorDataset | str, optional
            the data to predict on (if not provided, the Task validation dataset is used)
        raw_output : bool, default False
            if `True`, the raw predicted probabilities are returned
        apply_primary_function : bool, default True
            if `True`, the primary predict function is applied (to map the model output into a shape corresponding to
            the input)
        augment_n : int, default 0
            the number of augmentations to average results over
        batch_size : int, default 32
            the batch size
        train_mode : bool, default False
            if `True`, the model is used in training mode (affects dropout and batch normalization layers)
        to_ram : bool, default False
            if `True`, the dataset will be loaded in RAM (this speeds up the calculations but can lead to crashes
            if the dataset is too large)
        embedding : bool, default False
            if `True`, the output of the feature extractor is returned, ignoring the prediction module of the model

        Returns
        -------
        prediction : torch.Tensor
            a prediction for the input data
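
        Examples
        --------
        A sketch (assuming `task` is a trained `Task`):

        >>> classes = task.predict("val")  # class predictions
        >>> probabilities = task.predict("val", raw_output=True)
        >>> features = task.predict("val", embedding=True)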
        """

        if self.parallel and not isinstance(self.model, nn.DataParallel):
            self.model = MyDataParallel(self.model)
        self.model.to(self.device)
        if train_mode:
            self.model.train()
        else:
            self.model.eval()
        output = []
        if embedding:
            raw_output = True
            apply_primary_function = True
        if type(data) is not DataLoader:
            data = self._get_dataset(data)
            if to_ram:
                print("transferring dataset to RAM...")
                data.to_ram()
            data = DataLoader(data, shuffle=False, batch_size=batch_size)
        self.model.ssl_off()
        with torch.no_grad():
            for batch in tqdm(data):
                main_input = {k: v.to(self.device) for k, v in batch["input"].items()}
                predicted, _, _ = self._get_prediction(
                    main_input,
                    batch.get("tag"),
                    augment_n=augment_n,
                    embedding=embedding,
                )
                if apply_primary_function:
                    predicted = self.primary_predict_function(predicted)
                if not raw_output:
                    predicted = self.predict_function(predicted)
                output.append(predicted.detach().cpu())
        self.model.ssl_on()
        output = torch.cat(output).detach()
        return output

    def dataset(self, mode="train") -> BehaviorDataset:
        """
        Get a dataset

        Parameters
        ----------
        mode : {'train', 'val', 'test'}
            the dataset to get

        Returns
        -------
        dataset : dlc2action.data.dataset.BehaviorDataset
            the dataset
        """

        dataloader = self.dataloader(mode)
        if dataloader is None:
            raise ValueError("The length of the dataloader is 0!")
        return dataloader.dataset

    def dataloader(self, mode: str = "train") -> DataLoader:
        """
        Get a dataloader

        Parameters
        ----------
        mode : {'train', 'val', 'test'}
            the dataset to get

        Returns
        -------
        dataloader : torch.utils.data.DataLoader
            the dataloader
        """

        if mode == "train":
            return self.train_dataloader
        elif mode == "val":
            return self.val_dataloader
        elif mode == "test":
            return self.test_dataloader
        else:
            raise ValueError(
                f'The {mode} mode is not recognized, please choose from "train", "val" or "test"'
            )

    def _get_dataset(self, dataset):
        """
        Get a dataset from a dataloader, a string ('train', 'test' or 'val') or `None` (default)
        """

        if dataset is None:
            dataset = self.dataset("val")
        elif dataset in ["train", "val", "test"]:
            dataset = self.dataset(dataset)
        elif type(dataset) is DataLoader:
            dataset = dataset.dataset
        if type(dataset) is BehaviorDataset:
            return dataset
        else:
            raise TypeError(f"The {type(dataset)} type of dataset is not recognized!")

    def _get_dataloader(self, dataset):
        """
        Get a dataloader from a dataset, a string ('train', 'test' or 'val') or `None` (default)
        """

        if dataset is None:
            dataset = self.dataloader("val")
        elif dataset in ["train", "val", "test"]:
            dataset = self.dataloader(dataset)
            if dataset is None:
                raise ValueError("The length of the dataloader is 0!")
        elif type(dataset) is BehaviorDataset:
            dataset = DataLoader(dataset)
        if type(dataset) is DataLoader:
            return dataset
        else:
            raise TypeError(f"The {type(dataset)} type of dataset is not recognized!")

    def generate_full_length_prediction(
        self, dataset=None, batch_size=32, augment_n=10
    ):
        """
        Compile a prediction for the original input sequences

        Parameters
        ----------
        dataset : BehaviorDataset, optional
            the dataset to generate a prediction for (if `None`, generate for the validation dataset)
        batch_size : int, default 32
            the batch size
        augment_n : int, default 10
            the number of augmentations to average results over

        Returns
        -------
        prediction : dict
            a nested dictionary where first-level keys are video ids, second-level keys are clip ids and values
            are prediction tensors
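
        Examples
        --------
        A sketch (assuming `task` is a trained `Task`; `"video_0"` and `"clip_0"`
        are placeholder ids):

        >>> prediction = task.generate_full_length_prediction("val")
        >>> tensor = prediction["video_0"]["clip_0"]  # doctest: +SKIP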
        """

        dataset = self._get_dataset(dataset)
        if not isinstance(dataset, BehaviorDataset):
            raise TypeError(
                f"The dataset parameter has to be either None, string, "
                f"BehaviorDataset or Dataloader, got {type(dataset)}"
            )
        predicted = self.predict(
            dataset,
            raw_output=True,
            apply_primary_function=True,
            augment_n=augment_n,
            batch_size=batch_size,
        )
        predicted = dataset.generate_full_length_prediction(predicted)
        predicted = {
            v_id: {
                clip_id: self._apply_predict_functions(v.unsqueeze(0)).squeeze()
                for clip_id, v in video_dict.items()
            }
            for v_id, video_dict in predicted.items()
        }
        return predicted

    def generate_submission(
        self, frame_number_map_file, dataset=None, batch_size=32, augment_n=10
    ):
        """
        Generate a MABe-22 style submission dictionary

        Parameters
        ----------
        frame_number_map_file : str
            the path to the frame number map file
        dataset : BehaviorDataset, optional
            the dataset to generate a prediction for (if `None`, generate for the validation dataset)
        batch_size : int, default 32
            the batch size
        augment_n : int, default 10
            the number of augmentations to average results over

        Returns
        -------
        submission : dict
            a dictionary with frame number mapping and embeddings
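
        Examples
        --------
        A sketch (assuming `task` is a trained `Task` and `frame_number_map.npy`
        is a placeholder path):

        >>> submission = task.generate_submission("frame_number_map.npy", dataset="test")
        >>> np.save("submission.npy", submission)  # doctest: +SKIP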
        """

        dataset = self._get_dataset(dataset)
        if not isinstance(dataset, BehaviorDataset):
            raise TypeError(
                f"The dataset parameter has to be either None, string, "
                f"BehaviorDataset or Dataloader, got {type(dataset)}"
            )
        predicted = self.predict(
            dataset,
            raw_output=True,
            apply_primary_function=True,
            augment_n=augment_n,
            batch_size=batch_size,
            embedding=True,
        )
        predicted = dataset.generate_full_length_prediction(predicted)
        frame_map = np.load(frame_number_map_file, allow_pickle=True).item()
        length = frame_map[list(frame_map.keys())[-1]][1]
        embeddings = None
        for video_id in list(predicted.keys()):
            split = video_id.split("--")
            if len(split) != 2 or len(predicted[video_id]) > 1:
                raise RuntimeError(
                    "Generating submissions is only implemented for the mabe22 dataset!"
                )
            if split[1] not in frame_map:
                raise RuntimeError(f"The {split[1]} video is not in the frame map file")
            v_id = split[1]
            clip_id = list(predicted[video_id].keys())[0]
            if embeddings is None:
                embeddings = np.zeros((length, predicted[video_id][clip_id].shape[0]))
            start, end = frame_map[v_id]
            embeddings[start:end, :] = predicted[video_id][clip_id].T
            predicted.pop(video_id)
        predicted = {
            "frame_number_map": frame_map,
            "embeddings": embeddings.astype(np.float32),
        }
        return predicted

    def _get_intervals(self, tensor: torch.Tensor) -> torch.Tensor:
        """
        Get a list of `True` group beginning and end indices from a boolean tensor
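
        For example (a sketch; this is an internal helper):

        >>> task._get_intervals(torch.tensor([True, True, False, True]))
        tensor([[0, 2],
                [3, 4]])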
        """

        output, indices = torch.unique_consecutive(tensor, return_inverse=True)
        true_indices = torch.where(output)[0]
        starts = torch.tensor(
            [(indices == i).nonzero(as_tuple=True)[0][0] for i in true_indices]
        )
        ends = torch.tensor(
            [(indices == i).nonzero(as_tuple=True)[0][-1] + 1 for i in true_indices]
        )
        return torch.stack([starts, ends]).T

    def _smooth(self, tensor: torch.Tensor, smooth_interval: int = 1) -> torch.Tensor:
        """
        Get rid of jittering in a non-exclusive classification tensor

        First, remove intervals of 0 shorter than `smooth_interval`. Then, remove intervals of 1 shorter than
        `smooth_interval`.
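
        For example (a sketch; this is an internal helper):

        >>> task._smooth(torch.tensor([0, 0, 1, 0, 0]), smooth_interval=1)
        tensor([0, 0, 0, 0, 0])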
        """

        if len(tensor.shape) > 1:
            for c in range(tensor.shape[1]):
                intervals = self._get_intervals(tensor[:, c] == 0)
                interval_lengths = torch.tensor(
                    [interval[1] - interval[0] for interval in intervals]
                )
                short_intervals = intervals[interval_lengths <= smooth_interval]
                for start, end in short_intervals:
                    tensor[start:end, c] = 1
                intervals = self._get_intervals(tensor[:, c] == 1)
                interval_lengths = torch.tensor(
                    [interval[1] - interval[0] for interval in intervals]
                )
                short_intervals = intervals[interval_lengths <= smooth_interval]
                for start, end in short_intervals:
                    tensor[start:end, c] = 0
        else:
            for c in tensor.unique():
                intervals = self._get_intervals(tensor == c)
                interval_lengths = torch.tensor(
                    [interval[1] - interval[0] for interval in intervals]
                )
                short_intervals = intervals[interval_lengths <= smooth_interval]
                for start, end in short_intervals:
                    if start == 0:
                        tensor[start:end] = tensor[end]
                    else:
                        tensor[start:end] = tensor[start - 1]
        return tensor

    def _visualize_results_label(
        self,
        label: str,
        save_path: str = None,
        add_legend: bool = True,
        ground_truth: bool = True,
        hide_axes: bool = False,
        width: int = 10,
        whole_video: bool = False,
        transparent: bool = False,
        dataset: BehaviorDataset = None,
        smooth_interval: int = 0,
        title: str = None,
    ):
        """
        Visualize random predictions for a label

        Parameters
        ----------
        label : str
            the name of the behavior class to visualize
        save_path : str, optional
            the path where the plot will be saved
        add_legend : bool, default True
            if `True`, a legend will be added to the plot
        ground_truth : bool, default True
            if `True`, ground truth will be added to the plot
        hide_axes : bool, default False
            if `True`, the axes will be hidden on the plot
        width : int, default 10
            the width of the plot
        whole_video : bool, default False
            if `True`, whole videos are plotted instead of segments
        transparent : bool, default False
            if `True`, the background on the plot is transparent
        dataset : BehaviorDataset, optional
            the dataset to make the prediction for (if not provided, the validation dataset is used)
        smooth_interval : int, default 0
            if larger than 0, the predictions are smoothed with this interval and plotted as well
        title : str, optional
            the title of the plot
        """

        if title is None:
            title = ""
        dataset = self._get_dataset(dataset)
        inverse_dict = {v: k for k, v in dataset.behaviors_dict().items()}
        label_ind = inverse_dict[label]
        labels = {1: label, -100: "unknown"}
        label_keys = [1, -100]
        color_list = ["blue", "gray"]
        if whole_video:
            predicted = self.generate_full_length_prediction(dataset)
            keys = list(predicted.keys())
        counter = 0
        if whole_video:
            max_iter = len(keys) * 5
        else:
            max_iter = len(dataset) * 5
        ok = False
        while not ok:
            counter += 1
            if counter > max_iter:
                raise RuntimeError(
                    "Plotting is taking too many iterations; you should probably make some of the parameters less restrictive"
                )
            if whole_video:
                i = randint(0, len(keys) - 1)
                prediction = predicted[keys[i]]
                keys_i = list(prediction.keys())
                j = randint(0, len(keys_i) - 1)
                full_p = prediction[keys_i[j]]
                prediction = prediction[keys_i[j]][label_ind]
            else:
                dataloader = DataLoader(dataset)
                i = randint(0, len(dataloader) - 1)
                for num, batch in enumerate(dataloader):
                    if num == i:
                        break
                input_data = {k: v.to(self.device) for k, v in batch["input"].items()}
                prediction, *_ = self._get_prediction(
                    input_data, batch.get("tag"), augment_n=5
                )
                prediction = self._apply_predict_functions(prediction)
                j = randint(0, len(prediction) - 1)
                full_p = prediction[j]
                prediction = prediction[j][label_ind]
            classes = [x for x in torch.unique(prediction) if int(x) in label_keys]
            ok = 1 in classes
        fig, ax = plt.subplots(figsize=(width, 2))
        for c in classes:
            c_i = label_keys.index(int(c))
            output, indices, counts = torch.unique_consecutive(
                prediction == c, return_inverse=True, return_counts=True
            )
            long_indices = torch.where(output)[0]
            res_indices_start = [
                (indices == i).nonzero(as_tuple=True)[0][0].item() for i in long_indices
            ]
            res_indices_end = [
                (indices == i).nonzero(as_tuple=True)[0][-1].item() + 1
                for i in long_indices
            ]
            res_indices_len = [
                end - start for start, end in zip(res_indices_start, res_indices_end)
            ]
            ax.broken_barh(
                list(zip(res_indices_start, res_indices_len)),
                (0, 1),
                label=labels[int(c)],
                facecolors=color_list[c_i],
            )
        if ground_truth:
            gt = batch["target"][j][label_ind].to(self.device)
            classes_gt = [x for x in torch.unique(gt) if int(x) in label_keys]
            for c in classes_gt:
                c_i = label_keys.index(int(c))
                if c in classes:
                    label = None
                else:
                    label = labels[int(c)]
                output, indices, counts = torch.unique_consecutive(
                    gt == c, return_inverse=True, return_counts=True
                )
                long_indices = torch.where(output * (counts > 5))[0]
                res_indices_start = [
                    (indices == i).nonzero(as_tuple=True)[0][0].item()
                    for i in long_indices
                ]
                res_indices_end = [
                    (indices == i).nonzero(as_tuple=True)[0][-1].item() + 1
                    for i in long_indices
                ]
                res_indices_len = [
                    end - start
                    for start, end in zip(res_indices_start, res_indices_end)
                ]
                ax.broken_barh(
                    list(zip(res_indices_start, res_indices_len)),
                    (1.5, 1),
                    facecolors=color_list[c_i],
                    label=label,
                )
        self._compute(
            main_target=batch["target"][j].unsqueeze(0).to(self.device),
            predicted=full_p.unsqueeze(0).to(self.device),
            ssl_targets=[],
            ssl_predicted=[],
            skip_loss=True,
        )
        metrics = self._calculate_metrics()
        if smooth_interval > 0:
            smoothed = self._smooth(full_p, smooth_interval=smooth_interval)[
                label_ind, :
            ]
            for c in classes:
                c_i = label_keys.index(int(c))
                output, indices, counts = torch.unique_consecutive(
                    smoothed == c, return_inverse=True, return_counts=True
                )
                long_indices = torch.where(output)[0]
                res_indices_start = [
                    (indices == i).nonzero(as_tuple=True)[0][0].item()
                    for i in long_indices
                ]
                res_indices_end = [
                    (indices == i).nonzero(as_tuple=True)[0][-1].item() + 1
                    for i in long_indices
                ]
                res_indices_len = [
                    end - start
                    for start, end in zip(res_indices_start, res_indices_end)
                ]
                ax.broken_barh(
                    list(zip(res_indices_start, res_indices_len)),
                    (3, 1),
                    label=labels[int(c)],
                    facecolors=color_list[c_i],
                )
        keys = list(metrics.keys())
        for key in keys:
            if key.split("_")[-1] != str(label_ind):
                metrics.pop(key)
        title = [title]
        for key, value in metrics.items():
            title.append(f"{'_'.join(key.split('_')[: -1])}: {value:.2f}")
        title = ", ".join(title)
        if not ground_truth:
            ax.axes.yaxis.set_visible(False)
        else:
            ax.set_yticks([0.5, 2])
            ax.set_yticklabels(["prediction", "ground truth"])
        if add_legend:
            ax.legend()
        if hide_axes:
            plt.axis("off")
        plt.title(title)
        plt.xlim((0, len(prediction)))
        if save_path is not None:
1435            plt.savefig(save_path, transparent=transparent)
1436        plt.show()
1437
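The interval-extraction idiom above (and in the methods below) is dense, so here is a minimal standalone sketch of how the `torch.unique_consecutive` output is turned into the `(start, length)` pairs that `ax.broken_barh` expects; the toy `prediction` tensor is hypothetical:

import torch

prediction = torch.tensor([0, 1, 1, 0, 0, 1, 1, 1, 0])  # two bouts of class 1
mask = prediction == 1
output, indices, counts = torch.unique_consecutive(
    mask, return_inverse=True, return_counts=True
)
long_indices = torch.where(output)[0]  # consecutive runs where the mask is True
starts = [(indices == i).nonzero(as_tuple=True)[0][0].item() for i in long_indices]
ends = [(indices == i).nonzero(as_tuple=True)[0][-1].item() + 1 for i in long_indices]
print(list(zip(starts, [e - s for s, e in zip(starts, ends)])))  # [(1, 2), (5, 3)]
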
1438    def visualize_results(
1439        self,
1440        save_path: str = None,
1441        add_legend: bool = True,
1442        ground_truth: bool = True,
1443        colormap: str = "viridis",
1444        hide_axes: bool = False,
1445        min_classes: int = 1,
1446        width: int = 10,
1447        whole_video: bool = False,
1448        transparent: bool = False,
1449        dataset: Union[BehaviorDataset, DataLoader, str, None] = None,
1450        drop_classes: Set = None,
1451        search_classes: Set = None,
1452        num_samples: int = 1,
1453        smooth_interval_prediction: int = 0,
1454        behavior_name: str = None,
1455    ):
1456        """
1457        Visualize random predictions
1458
1459        Parameters
1460        ----------
1461        save_path : str, optional
1462            the path where the plot will be saved
1463        add_legend : bool, default True
1464            if `True`, legend will be added to the plot
1465        ground_truth : bool, default True
1466            if `True`, ground truth will be added to the plot
1467        colormap : str, default 'viridis'
1468            the `matplotlib` colormap to use
1469        hide_axes : bool, default False
1470            if `True`, the axes will be hidden on the plot
1471        min_classes : int, default 1
1472            the minimum number of classes in a displayed interval
1473        width : int, default 10
1474            the width of the plot
1475        whole_video : bool, default False
1476            if `True`, whole videos are plotted instead of segments
1477        transparent : bool, default False
1478            if `True`, the background on the plot is transparent
1479        dataset : BehaviorDataset | DataLoader | str | None, optional
1480            the dataset to make the prediction for (if not provided, the validation dataset is used)
1481        drop_classes : set, optional
1482            a set of class names to not be displayed
1483        search_classes : set, optional
1484            if given, only intervals where at least one of the classes is in ground truth will be shown
1485        """
1486
1487        if drop_classes is None:
1488            drop_classes = []
1489        dataset = self._get_dataset(dataset)
1490        if dataset.annotation_class() == "exclusive_classification":
1491            exclusive = True
1492        elif dataset.annotation_class() == "nonexclusive_classification":
1493            exclusive = False
1494        else:
1495            raise NotImplementedError(
1496                f"Results visualisation is not implemented for {dataset.annotation_class()} datasets!"
1497            )
1498        if exclusive:
1499            labels = {
1500                k: v
1501                for k, v in dataset.behaviors_dict().items()
1502                if v not in drop_classes
1503            }
1504            labels.update({-100: "unknown"})
1505        else:
1506            inverse_dict = {v: k for k, v in dataset.behaviors_dict().items()}
1507            if behavior_name is None:
1508                behavior_name = list(inverse_dict.keys())[0]
1509            label_ind = inverse_dict[behavior_name]
1510            labels = {1: behavior_name, -100: "unknown"}
1511        label_keys = sorted([int(x) for x in labels.keys()])
1512        if search_classes is None:
1513            ok = True
1514        else:
1515            ok = False
1516        classes = []
1517        if whole_video:
1518            predicted = self.generate_full_length_prediction(dataset)
1519            keys = list(predicted.keys())
1520        counter = 0
1521        if whole_video:
1522            max_iter = len(keys) * 2
1523        else:
1524            max_iter = len(dataset) * 2
1525        while len(classes) < min_classes or not ok:
1526            counter += 1
1527            if counter > max_iter:
1528                raise RuntimeError(
1529                    "Plotting is taking too many iterations; you should probably make some of the parameters less restrictive"
1530                )
1531            if whole_video:
1532                i = randint(0, len(keys) - 1)
1533                prediction = predicted[keys[i]]
1534                keys_i = list(prediction.keys())
1535                j = randint(0, len(keys_i) - 1)
1536                prediction = prediction[keys_i[j]]
1537                key1 = keys[i]
1538                key2 = keys_i[j]
1539            else:
1540                dataloader = DataLoader(dataset)
1541                i = randint(0, len(dataloader) - 1)
1542                for num, batch in enumerate(dataloader):
1543                    if num == i:
1544                        break
1545            input_data = {k: v.to(self.device) for k, v in batch["input"].items()}
1546            prediction, *_ = self._get_prediction(
1547                input_data, batch.get("tag"), augment_n=5
1548                )
1549                prediction = self._apply_predict_functions(prediction)
1550                j = randint(0, len(prediction) - 1)
1551                prediction = prediction[j]
1552            if not exclusive:
1553                prediction = prediction[label_ind]
1554            if smooth_interval_prediction > 0:
1555                unsmoothed_prediction = deepcopy(prediction)
1556                prediction = self._smooth(prediction, smooth_interval_prediction)
1557                height = 3
1558            else:
1559                height = 2
1560            classes = [
1561                labels[int(x)] for x in torch.unique(prediction) if x in label_keys
1562            ]
1563            if search_classes is not None:
1564                ok = any([x in classes for x in search_classes])
1565        fig, ax = plt.subplots(figsize=(width, height))
1566        color_list = cm.get_cmap(colormap, lut=len(labels)).colors
1567
1568        def _plot_prediction(prediction, height, set_labels=True):
1569            for c in label_keys:
1570                c_i = label_keys.index(int(c))
1571                output, indices, counts = torch.unique_consecutive(
1572                    prediction == c, return_inverse=True, return_counts=True
1573                )
1574                long_indices = torch.where(output)[0]
1575                if len(long_indices) == 0:
1576                    continue
1577                res_indices_start = [
1578                    (indices == i).nonzero(as_tuple=True)[0][0].item()
1579                    for i in long_indices
1580                ]
1581                res_indices_end = [
1582                    (indices == i).nonzero(as_tuple=True)[0][-1].item() + 1
1583                    for i in long_indices
1584                ]
1585                res_indices_len = [
1586                    end - start
1587                    for start, end in zip(res_indices_start, res_indices_end)
1588                ]
1589                if set_labels:
1590                    label = labels[int(c)]
1591                else:
1592                    label = None
1593                ax.broken_barh(
1594                    list(zip(res_indices_start, res_indices_len)),
1595                    (height, 1),
1596                    label=label,
1597                    facecolors=color_list[c_i],
1598                )
1599
1600        if smooth_interval_prediction > 0:
1601            _plot_prediction(unsmoothed_prediction, 0)
1602            _plot_prediction(prediction, 1.5, set_labels=False)
1603            gt_height = 3
1604        else:
1605            _plot_prediction(prediction, 0)
1606            gt_height = 1.5
1607        if ground_truth:
1608            if not whole_video:
1609                gt = batch["target"][j].to(self.device)
1610            else:
1611                gt = dataset.generate_full_length_gt()[key1][key2]
1612            for c in label_keys:
1613                c_i = label_keys.index(int(c))
1614                if labels[int(c)] in classes:
1615                    label = None
1616                else:
1617                    label = labels[int(c)]
1618                output, indices, counts = torch.unique_consecutive(
1619                    gt == c, return_inverse=True, return_counts=True
1620                )
1621                long_indices = torch.where(output)[0]
1622                if len(long_indices) == 0:
1623                    continue
1624                res_indices_start = [
1625                    (indices == i).nonzero(as_tuple=True)[0][0].item()
1626                    for i in long_indices
1627                ]
1628                res_indices_end = [
1629                    (indices == i).nonzero(as_tuple=True)[0][-1].item() + 1
1630                    for i in long_indices
1631                ]
1632                res_indices_len = [
1633                    end - start
1634                    for start, end in zip(res_indices_start, res_indices_end)
1635                ]
1636                ax.broken_barh(
1637                    list(zip(res_indices_start, res_indices_len)),
1638                    (gt_height, 1),
1639                    facecolors=color_list[c_i] if labels[int(c)] != "unknown" else "gray",
1640                    label=label,
1641                )
1642        if not ground_truth:
1643            ax.axes.yaxis.set_visible(False)
1644        else:
1645            if smooth_interval_prediction > 0:
1646                ax.set_yticks([0.5, 2, 3.5])
1647                ax.set_yticklabels(["prediction", "smoothed", "ground truth"])
1648            else:
1649                ax.set_yticks([0.5, 2])
1650                ax.set_yticklabels(["prediction", "ground truth"])
1651        if add_legend:
1652            ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))
1653        if hide_axes:
1654            plt.axis("off")
1655        if save_path is not None:
1656            plt.savefig(save_path, transparent=transparent)
1657        plt.show()
1658
1659    def set_ssl_transformations(self, ssl_transformations):
1660        self.train_dataloader.dataset.set_ssl_transformations(ssl_transformations)
1661        if self.val_dataloader is not None:
1662            self.val_dataloader.dataset.set_ssl_transformations(ssl_transformations)
1663
1664    def set_ssl_losses(self, ssl_losses: list) -> None:
1665        """
1666        Set SSL losses
1667
1668        Parameters
1669        ----------
1670        ssl_losses : list
1671            a list of callable SSL losses
1672        """
1673
1674        self.ssl_losses = ssl_losses
1675
1676    def set_log(self, log: str) -> None:
1677        """
1678        Set the log file
1679
1680        Parameters
1681        ----------
1682        log: str
1683            the new log file path
1684        """
1685
1686        self.log_file = log
1687
1688    def set_keep_target_none(self, keep_target_none: List) -> None:
1689        """
1690        Set the keep_target_none parameter of the transformer
1691
1692        Parameters
1693        ----------
1694        keep_target_none : list
1695            a list of bool values
1696        """
1697
1698        self.transformer.keep_target_none = keep_target_none
1699
1700    def set_generate_ssl_input(self, generate_ssl_input: list) -> None:
1701        """
1702        Set the generate_ssl_input parameter of the transformer
1703
1704        Parameters
1705        ----------
1706        generate_ssl_input : list
1707            a list of bool values
1708        """
1709
1710        self.transformer.generate_ssl_input = generate_ssl_input
1711
1712    def set_model(self, model: Model) -> None:
1713        """
1714        Set a new model
1715
1716        Parameters
1717        ----------
1718        model: Model
1719            the new model
1720        """
1721
1722        self.epoch = 0
1723        self.model = model
1724        self.optimizer = self.optimizer_class(model.parameters(), lr=self.lr)
1725        if self.model.process_labels:
1726            self.model.set_behaviors(list(self.behaviors_dict().values()))
1727
1728    def set_dataloaders(
1729        self,
1730        train_dataloader: DataLoader,
1731        val_dataloader: DataLoader = None,
1732        test_dataloader: DataLoader = None,
1733    ) -> None:
1734        """
1735        Set new dataloaders
1736
1737        Parameters
1738        ----------
1739        train_dataloader: torch.utils.data.DataLoader
1740            the new train dataloader
1741        val_dataloader : torch.utils.data.DataLoader
1742            the new validation dataloader
1743        test_dataloader : torch.utils.data.DataLoader
1744            the new test dataloader
1745        """
1746
1747        self.train_dataloader = train_dataloader
1748        self.val_dataloader = val_dataloader
1749        self.test_dataloader = test_dataloader
1750
1751    def set_loss(self, loss: Callable) -> None:
1752        """
1753        Set new loss function
1754
1755        Parameters
1756        ----------
1757        loss: callable
1758            the new loss function
1759        """
1760
1761        self.loss = loss
1762
1763    def set_metrics(self, metrics: dict) -> None:
1764        """
1765        Set new metric
1766
1767        Parameters
1768        ----------
1769        metrics : dict
1770            the new metric dictionary
1771        """
1772
1773        self.metrics = metrics
1774
1775    def set_transformer(self, transformer: Transformer) -> None:
1776        """
1777        Set a new transformer
1778
1779        Parameters
1780        ----------
1781        transformer: Transformer
1782            the new transformer
1783        """
1784
1785        self.transformer = transformer
1786
1787    def set_predict_functions(
1788        self, primary_predict_function: Callable, predict_function: Callable
1789    ) -> None:
1790        """
1791        Set new predict functions
1792
1793        Parameters
1794        ----------
1795        primary_predict_function : callable
1796            the new primary predict function
1797        predict_function : callable
1798            the new predict function
1799        """
1800
1801        self.primary_predict_function = primary_predict_function
1802        self.predict_function = predict_function
1803
1804    def _set_pseudolabels(self):
1805        """
1806        Set pseudolabels
1807        """
1808
1809        self.train_dataloader.dataset.set_unlabeled(True)
1810        predicted = self.predict(
1811            data=self.dataset("train"),
1812            raw_output=False,
1813            augment_n=self.augment_val,
1814            ssl_off=True,
1815        )
1816        self.train_dataloader.dataset.set_annotation(predicted.detach())
1817
1818    def _alpha(self, epoch):
1819        """
1820        Get the current pseudolabeling alpha parameter
1821        """
1822
1823        if epoch <= self.T1:
1824            return 0
1825        elif epoch < self.T2:
1826            return self.alpha_f * (epoch - self.T1) / (self.T2 - self.T1)
1827        else:
1828            return self.alpha_f
1829
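A quick sketch of the resulting schedule with the class defaults (`pseudolabel_start=100`, `alpha_growth_stop=600`, `pseudolabel_alpha_f=3`); the standalone function below simply restates the method:

def alpha(epoch, T1=100, T2=600, alpha_f=3.0):
    # zero before T1, linear ramp between T1 and T2, constant alpha_f afterwards
    if epoch <= T1:
        return 0.0
    if epoch < T2:
        return alpha_f * (epoch - T1) / (T2 - T1)
    return alpha_f

assert alpha(100) == 0.0 and alpha(350) == 1.5 and alpha(600) == 3.0
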
1830    def count_classes(self, bouts: bool = False) -> Dict:
1831        """
1832        Get a dictionary of class counts in different modes
1833
1834        Parameters
1835        ----------
1836        bouts : bool, default False
1837            if `True`, segment counts are returned instead of frame counts
1838
1839        Returns
1840        -------
1841        class_counts : dict
1842            a dictionary where first-level keys are "train", "val" and "test", second-level keys are
1843            class names and values are class counts (in frames, or in segments if `bouts` is `True`)
1844        """
1845
1846        class_counts = {}
1847        for x in ["train", "val", "test"]:
1848            try:
1849                class_counts[x] = self.dataset(x).count_classes(bouts)
1850            except ValueError:
1851                class_counts[x] = {k: 0 for k in self.behaviors_dict().keys()}
1852        return class_counts
1853
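For example (the `task` instance and behavior names here are hypothetical):

counts = task.count_classes(bouts=True)
# counts -> {"train": {"walking": 52, "grooming": 11}, "val": {...}, "test": {...}}
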
1854    def behaviors_dict(self) -> Dict:
1855        """
1856        Get a behavior dictionary
1857
1858        Keys are label indices and values are label names.
1859
1860        Returns
1861        -------
1862        behaviors_dict : dict
1863            behavior dictionary
1864        """
1865
1866        return self.dataset().behaviors_dict()
1867
1868    def update_parameters(self, parameters: Dict) -> None:
1869        """
1870        Update training parameters from a dictionary
1871
1872        Parameters
1873        ----------
1874        parameters : dict
1875            the update dictionary
1876        """
1877
1878        self.lr = parameters.get("lr", self.lr)
1879        self.parallel = parameters.get("parallel", self.parallel)
1880        self.optimizer = self.optimizer_class(self.model.parameters(), lr=self.lr)
1881        self.verbose = parameters.get("verbose", self.verbose)
1882        self.device = parameters.get("device", self.device)
1883        if self.device == "auto":
1884            self.device = "cuda" if torch.cuda.is_available() else "cpu"
1885        self.augment_train = int(parameters.get("augment_train", self.augment_train))
1886        self.augment_val = int(parameters.get("augment_val", self.augment_val))
1887        ssl_weights = parameters.get("ssl_weights", self.ssl_weights)
1888        if ssl_weights is None:
1889            ssl_weights = [1 for _ in self.ssl_losses]
1890        if not isinstance(ssl_weights, Iterable):
1891            ssl_weights = [ssl_weights for _ in self.ssl_losses]
1892        self.ssl_weights = ssl_weights
1893        self.num_epochs = parameters.get("num_epochs", self.num_epochs)
1894        self.model_save_epochs = parameters.get(
1895            "model_save_epochs", self.model_save_epochs
1896        )
1897        self.model_save_path = parameters.get("model_save_path", self.model_save_path)
1898        self.pseudolabel = parameters.get("pseudolabel", self.pseudolabel)
1899        self.T1 = int(parameters.get("pseudolabel_start", self.T1))
1900        self.T2 = int(parameters.get("alpha_growth_stop", self.T2))
1901        self.t = int(parameters.get("correction_interval", self.t))
1902        self.alpha_f = parameters.get("pseudolabel_alpha_f", self.alpha_f)
1903        self.log_file = parameters.get("log_file", self.log_file)
1904
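A hypothetical update dictionary; keys that are missing keep their current values, and the optimizer is re-created with the (possibly new) learning rate:

task.update_parameters({"lr": 1e-4, "device": "auto", "num_epochs": 50})
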
1905    def generate_uncertainty_score(
1906        self,
1907        classes: List,
1908        augment_n: int = 0,
1909        batch_size: int = 32,
1910        method: str = "least_confidence",
1911        predicted: torch.Tensor = None,
1912        behaviors_dict: Dict = None,
1913    ) -> Dict:
1914        """
1915        Generate frame-wise scores for active learning
1916
1917        Parameters
1918        ----------
1919        classes : list
1920            a list of class names or indices; their confidence scores will be computed separately and stacked
1921        augment_n : int, default 0
1922            the number of augmentations to average over
1923        batch_size : int, default 32
1924            the batch size
1925        method : {"least_confidence", "entropy", "random"}
1926            the method used to calculate the scores from the probability predictions (`"least_confidence"`: `1 - p_i` if
1927            `p_i > 0.5` or `p_i` if `p_i < 0.5`; `"entropy"`: `- p_i * log(p_i) - (1 - p_i) * log(1 - p_i)`; `"random"`: a uniform random score)
1928
1929        Returns
1930        -------
1931        score_dicts : dict
1932            a nested dictionary where first level keys are video ids, second level keys are clip ids and values
1933            are score tensors
1934        """
1935
1936        dataset = self.dataset("train")
1937        if behaviors_dict is None:
1938            behaviors_dict = self.behaviors_dict()
1939        if not isinstance(dataset, BehaviorDataset):
1940            raise TypeError(
1941                f"The dataset parameter has to be either None, string, "
1942                f"BehaviorDataset or Dataloader, got {type(dataset)}"
1943            )
1944        if predicted is None:
1945            predicted = self.predict(
1946                dataset,
1947                raw_output=True,
1948                apply_primary_function=True,
1949                augment_n=augment_n,
1950                batch_size=batch_size,
1951            )
1952        predicted = dataset.generate_full_length_prediction(predicted)
1953        if isinstance(classes[0], str):
1954            behaviors_dict_inverse = {v: k for k, v in behaviors_dict.items()}
1955            classes = [behaviors_dict_inverse[c] for c in classes]
1956        for v_id, v in predicted.items():
1957            for clip_id, vv in v.items():
1958                if method == "least_confidence":
1959                    predicted[v_id][clip_id][vv > 0.5] = 1 - vv[vv > 0.5]
1960                elif method == "entropy":
1961                    predicted[v_id][clip_id][vv != -100] = (
1962                        -vv * torch.log(vv) - (1 - vv) * torch.log(1 - vv)
1963                    )[vv != -100]
1964                elif method == "random":
1965                    predicted[v_id][clip_id] = torch.rand(vv.shape)
1966                else:
1967                    raise ValueError(
1968                        f"The {method} method is not recognized; please choose from ['least_confidence', 'entropy', 'random']"
1969                    )
1970                predicted[v_id][clip_id][vv == -100] = 0
1971
1972        predicted = {
1973            v_id: {clip_id: v[classes, :] for clip_id, v in video_dict.items()}
1974            for v_id, video_dict in predicted.items()
1975        }
1976        return predicted
1977
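A toy illustration of the two main scoring rules on a vector of probabilities (the tensor `p` is hypothetical):

import torch

p = torch.tensor([0.1, 0.4, 0.6, 0.9])
least_confidence = torch.where(p > 0.5, 1 - p, p)         # -> approx. [0.1, 0.4, 0.4, 0.1]
entropy = -p * torch.log(p) - (1 - p) * torch.log(1 - p)  # largest near p = 0.5

Both rules assign the highest uncertainty to frames with probabilities close to 0.5, which is what makes them useful for picking frames to annotate next.
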
1978    def generate_bald_score(
1979        self,
1980        classes: List,
1981        augment_n: int = 0,
1982        batch_size: int = 32,
1983        num_models: int = 10,
1984        kernel_size: int = 11,
1985    ) -> Dict:
1986        """
1987        Generate frame-wise Bayesian Active Learning by Disagreement scores for active learning
1988
1989        Parameters
1990        ----------
1991        classes : list
1992            a list of class names or indices; their confidence scores will be computed separately and stacked
1993        augment_n : int, default 0
1994            the number of augmentations to average over
1995        batch_size : int, default 32
1996            the batch size
1997        num_models : int, default 10
1998            the number of dropout masks to apply
1999        kernel_size : int, default 11
2000            the size of the smoothing gaussian kernel
2001
2002        Returns
2003        -------
2004        score_dicts : dict
2005            a nested dictionary where first level keys are video ids, second level keys are clip ids and values
2006            are score tensors
2007        """
2008
2009        dataset = self.dataset("train")
2010        dataset = self._get_dataset(dataset)
2011        if not isinstance(dataset, BehaviorDataset):
2012            raise TypeError(
2013                f"The dataset parameter has to be either None, string, "
2014                f"BehaviorDataset or Dataloader, got {type(dataset)}"
2015            )
2016        predictions = []
2017        for _ in range(num_models):
2018            predicted = self.predict(
2019                dataset,
2020                raw_output=True,
2021                apply_primary_function=True,
2022                augment_n=augment_n,
2023                batch_size=batch_size,
2024                train_mode=True,
2025            )
2026            predicted = dataset.generate_full_length_prediction(predicted)
2027            if isinstance(classes[0], str):
2028                behaviors_dict_inverse = {
2029                    v: k for k, v in self.behaviors_dict().items()
2030                }
2031                classes = [behaviors_dict_inverse[c] for c in classes]
2032            for v_id, v in predicted.items():
2033                for clip_id, vv in v.items():
2034                    vv[vv != -100] = (vv[vv != -100] > 0.5).int().float()
2035                    predicted[v_id][clip_id] = vv
2036            predicted = {
2037                v_id: {clip_id: v[classes, :] for clip_id, v in video_dict.items()}
2038                for v_id, video_dict in predicted.items()
2039            }
2040            predictions.append(predicted)
2041        result = {v_id: {} for v_id in predictions[0]}
2042        r = range(-int(kernel_size / 2), int(kernel_size / 2) + 1)
2043        gauss = [1 / (1 * sqrt(2 * pi)) * exp(-float(x) ** 2 / (2 * 1**2)) for x in r]
2044        gauss = [x / sum(gauss) for x in gauss]
2045        kernel = torch.FloatTensor([[gauss]])
2046        for v_id in predictions[0]:
2047            for clip_id in predictions[0][v_id]:
2048                consensus = (
2049                    (
2050                        torch.mean(
2051                            torch.stack([x[v_id][clip_id] for x in predictions]), dim=0
2052                        )
2053                        > 0.5
2054                    )
2055                    .int()
2056                    .float()
2057                )
2058                consensus[predictions[0][v_id][clip_id] == -100] = -100
2059                result[v_id][clip_id] = torch.zeros(consensus.shape)
2060                for x in predictions:
2061                    result[v_id][clip_id] += (x[v_id][clip_id] != consensus).int()
2062                result[v_id][clip_id] = result[v_id][clip_id] * 2 / num_models
2063                res = torch.zeros(result[v_id][clip_id].shape)
2064                for i in range(len(classes)):
2065                    res[
2066                        i, floor(kernel_size // 2) : -floor(kernel_size // 2)
2067                    ] = torch.nn.functional.conv1d(
2068                        result[v_id][clip_id][i, :].unsqueeze(0).unsqueeze(0), kernel
2069                    )[
2070                        0, ...
2071                    ]
2072                result[v_id][clip_id] = res
2073        return result
2074
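For reference, this is how the smoothing kernel above is built: a normalized Gaussian with sigma 1, arranged in `conv1d` layout (a standalone sketch of the same arithmetic):

import torch
from math import pi, sqrt, exp

kernel_size = 11
r = range(-(kernel_size // 2), kernel_size // 2 + 1)
gauss = [1 / sqrt(2 * pi) * exp(-(x**2) / 2) for x in r]
gauss = [x / sum(gauss) for x in gauss]  # normalize so the weights sum to 1
kernel = torch.FloatTensor([[gauss]])    # shape (1, 1, kernel_size) for conv1d

The per-frame score itself is the fraction of dropout models that disagree with the consensus prediction (scaled by 2), smoothed with this kernel along the time axis.
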
2075    def get_normalization_stats(self) -> Optional[Dict]:
2076        return self.train_dataloader.dataset.stats
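
A minimal end-to-end sketch of using `Task` (the loaders, model and parameter values here are hypothetical; in practice they come from the dlc2action dataset, model and loss modules):

import torch

task = Task(
    train_dataloader=train_dl,           # a torch DataLoader over a BehaviorDataset
    model=my_model,                      # a dlc2action Model
    loss=torch.nn.BCEWithLogitsLoss(),   # any (prediction, target) -> loss callable
    num_epochs=100,
    exclusive=False,                     # multi-label classification
    device="auto",
)
logs = task.train()
predictions = task.predict(task.dataset("val"), raw_output=False)
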
class Task:
  70class Task:
  71    """
  72    A universal trainer class that performs training, evaluation and prediction for all types of tasks and data
  73    """
  74
  75    def __init__(
  76        self,
  77        train_dataloader: DataLoader,
  78        model: Union[nn.Module, Model],
  79        loss: Callable[[torch.Tensor, torch.Tensor], float],
  80        num_epochs: int = 0,
  81        transformer: Transformer = None,
  82        ssl_losses: List = None,
  83        ssl_weights: List = None,
  84        lr: float = 1e-3,
  85        metrics: Dict = None,
  86        val_dataloader: DataLoader = None,
  87        test_dataloader: DataLoader = None,
  88        optimizer: Optimizer = None,
  89        device: str = "cuda",
  90        verbose: bool = True,
  91        log_file: Union[str, None] = None,
  92        augment_train: int = 1,
  93        augment_val: int = 0,
  94        validation_interval: int = 1,
  95        predict_function: Union[Callable[[torch.Tensor], torch.Tensor], None] = None,
  96        primary_predict_function: Callable = None,
  97        exclusive: bool = True,
  98        ignore_tags: bool = True,
  99        threshold: float = 0.5,
 100        model_save_path: str = None,
 101        model_save_epochs: int = 5,
 102        pseudolabel: bool = False,
 103        pseudolabel_start: int = 100,
 104        correction_interval: int = 2,
 105        pseudolabel_alpha_f: float = 3,
 106        alpha_growth_stop: int = 600,
 107        parallel: bool = False,
 108        skip_metrics: List = None,
 109    ) -> None:
 110        """
 111        Parameters
 112        ----------
 113        train_dataloader : torch.utils.data.DataLoader
 114            a training dataloader
 115        model : dlc2action.model.base_model.Model
 116            a model
 117        loss : callable
 118            a loss function
 119        num_epochs : int, default 0
 120            the number of epochs
 121        transformer : dlc2action.transformer.base_transformer.Transformer, optional
 122            a transformer
 123        ssl_losses : list, optional
 124            a list of SSL losses
 125        ssl_weights : list, optional
 126            a list of SSL weights (if not provided, all weights are set to 1)
 127        lr : float, default 1e-3
 128            learning rate
 129        metrics : dict, optional
 130            a list of metric functions
 131        val_dataloader : torch.utils.data.DataLoader, optional
 132            a validation dataloader
 133        optimizer : torch.optim.Optimizer, optional
 134            an optimizer (`Adam` by default)
 135        device : str, default 'cuda'
 136            the device to train the model on
 137        verbose : bool, default True
 138            if `True`, the process is described in standard output
 139        log_file : str, optional
 140            the path to a text file where the process will be logged
 141        augment_train : {1, 0}
 142            number of augmentations to apply at training
 143        augment_val : int, default 0
 144            number of augmentations to apply at validation
 145        validation_interval : int, default 1
 146            every time this number of epochs passes, validation metrics are computed
 147        predict_function : callable, optional
 148            a function that maps probabilities to class predictions (if not provided, a default is generated)
 149        primary_predict_function : callable, optional
 150            a function that maps model output to probabilities (if not provided, softmax or sigmoid is used, depending on `exclusive`)
 151        exclusive : bool, default True
 152            set to False for multi-label classification
 153        ignore_tags : bool, default True
 154            if `True`, samples with different meta tags will be mixed in batches
 155        threshold : float, default 0.5
 156            the threshold used for multi-label classification default prediction function
 157        model_save_path : str, optional
 158            the path to the folder where model checkpoints will be saved (checkpoints will not be saved if the path
 159            is not provided)
 160        model_save_epochs : int, default 5
 161            the interval for saving the model checkpoints (the last epoch is always saved)
 162        pseudolabel : bool, default False
 163            if True, the pseudolabeling procedure will be applied
 164        pseudolabel_start : int, default 100
 165            pseudolabeling starts after this epoch
 166        correction_interval : int, default 2
 167            after this number of epochs, if the pseudolabeling is on, the model is trained on the labeled data and
 168            new pseudolabels are generated
 169        pseudolabel_alpha_f : float, default 3
 170            the maximum value of pseudolabeling alpha
 171        alpha_growth_stop : int, default 600
 172            pseudolabeling alpha stops growing after this epoch
 173        """
 174
 175        # pseudolabeling might be buggy right now -- not using it!
 176        if skip_metrics is None:
 177            skip_metrics = []
 178        self.train_dataloader = train_dataloader
 179        self.val_dataloader = val_dataloader
 180        self.test_dataloader = test_dataloader
 181        self.transformer = transformer
 182        self.num_epochs = num_epochs
 183        self.skip_metrics = skip_metrics
 184        self.verbose = verbose
 185        self.augment_train = int(augment_train)
 186        self.augment_val = int(augment_val)
 187        self.ignore_tags = ignore_tags
 188        self.validation_interval = int(validation_interval)
 189        self.log_file = log_file
 190        self.loss = loss
 191        self.model_save_path = model_save_path
 192        self.model_save_epochs = model_save_epochs
 193        self.epoch = 0
 194
 195        if metrics is None:
 196            metrics = {}
 197        self.metrics = metrics
 198
 199        if optimizer is None:
 200            optimizer = Adam
 201
 202        if ssl_weights is None:
 203            ssl_weights = [1 for _ in ssl_losses]
 204        if not isinstance(ssl_weights, Iterable):
 205            ssl_weights = [ssl_weights for _ in ssl_losses]
 206        self.ssl_weights = ssl_weights
 207
 208        self.optimizer_class = optimizer
 209        self.lr = lr
 210        if not isinstance(model, Model):
 211            self.model = LoadedModel(model=model)
 212        else:
 213            self.set_model(model)
 214        self.parallel = parallel
 215
 216        if self.transformer is None:
 217            self.augment_val = 0
 218            self.augment_train = 0
 219            self.transformer = EmptyTransformer()
 220
 221        if self.augment_train > 1:
 222            warnings.warn(
 223                'The "augment_train" parameter is too large -> setting it to 1.'
 224            )
 225            self.augment_train = 1
 226
 227        try:
 228            if device == "auto":
 229                device = "cuda" if torch.cuda.is_available() else "cpu"
 230            self.device = torch.device(device)
 231        except Exception:
 232            raise ValueError("The format of the device is incorrect")
 233
 234        if ssl_losses is None:
 235            self.ssl_losses = [lambda x, y: 0]
 236        else:
 237            self.ssl_losses = ssl_losses
 238
 239        if primary_predict_function is None:
 240            if exclusive:
 241                primary_predict_function = lambda x: torch.softmax(x, dim=1)
 242            else:
 243                primary_predict_function = lambda x: torch.sigmoid(x)
 244        self.primary_predict_function = primary_predict_function
 245
 246        if predict_function is None:
 247            if exclusive:
 248                self.predict_function = lambda x: torch.max(x.data, 1)[1]
 249            else:
 250                self.predict_function = lambda x: (x > threshold).int()
 251        else:
 252            self.predict_function = predict_function
 253
 254        self.pseudolabel = pseudolabel
 255        self.alpha_f = pseudolabel_alpha_f
 256        self.T2 = alpha_growth_stop
 257        self.T1 = pseudolabel_start
 258        self.t = correction_interval
 259        if self.T2 <= self.T1:
 260            raise ValueError(
 261                f"The pseudolabel_start parameter has to be smaller than alpha_growth_stop; got "
 262                f"{pseudolabel_start=} and {alpha_growth_stop=}"
 263            )
 264        self.decision_thresholds = [0.5 for x in self.behaviors_dict()]
 265
 266    def save_checkpoint(self, checkpoint_path: str) -> None:
 267        """
 268        Save a general checkpoint
 269
 270        Parameters
 271        ----------
 272        checkpoint_path : str
 273            the path where the checkpoint will be saved
 274        """
 275
 276        torch.save(
 277            {
 278                "epoch": self.epoch,
 279                "model_state_dict": self.model.state_dict(),
 280                "optimizer_state_dict": self.optimizer.state_dict(),
 281            },
 282            checkpoint_path,
 283        )
 284
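A hypothetical round trip (the path is arbitrary): save a checkpoint after training, then restore the model weights, optimizer state and epoch counter later:

task.save_checkpoint("checkpoints/task.pt")
task.load_from_checkpoint("checkpoints/task.pt")  # pass only_model=True to skip the optimizer and epoch
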
 285    def load_from_checkpoint(
 286        self, checkpoint_path, only_model: bool = False, load_strict: bool = True
 287    ) -> None:
 288        """
 289        Load from a checkpoint
 290
 291        Parameters
 292        ----------
 293        checkpoint_path : str
 294            the path to the checkpoint
 295        only_model : bool, default False
 296            if `True`, only the model state dictionary will be loaded (and not the epoch and the optimizer state
 297            dictionary)
 298        load_strict : bool, default True
 299            if `True`, any inconsistencies in state dictionaries are regarded as errors
 300        """
 301
 302        if checkpoint_path is None:
 303            return
 304        checkpoint = torch.load(checkpoint_path, map_location=self.device)
 305        self.model.to(self.device)
 306        self.model.load_state_dict(checkpoint["model_state_dict"], strict=load_strict)
 307        if not only_model:
 308            self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
 309            self.epoch = checkpoint["epoch"]
 310
 311    def save_model(self, save_path: str) -> None:
 312        """
 313        Save the model state dictionary
 314
 315        Parameters
 316        ----------
 317        save_path : str
 318            the path where the state will be saved
 319        """
 320
 321        torch.save(self.model.state_dict(), save_path)
 322        print("saved the model successfully")
 323
 324    def _apply_predict_functions(self, predicted):
 325        """
 326        Map from model output to prediction
 327        """
 328
 329        predicted = self.primary_predict_function(predicted)
 330        predicted = self.predict_function(predicted)
 331        return predicted
 332
 333    def _get_prediction(
 334        self,
 335        main_input: Dict,
 336        tags: torch.Tensor,
 337        ssl_inputs: List = None,
 338        ssl_targets: List = None,
 339        augment_n: int = 0,
 340        embedding: bool = False,
 341        subsample: List = None,
 342    ) -> Tuple:
 343        """
 344        Get the prediction of `self.model` for input averaged over `augment_n` augmentations
 345        """
 346
 347        if augment_n == 0:
 348            augment_n = 1
 349            augment = False
 350        else:
 351            augment = True
 352        model_input, ssl_inputs, ssl_targets = self.transformer.transform(
 353            deepcopy(main_input),
 354            ssl_inputs=ssl_inputs,
 355            ssl_targets=ssl_targets,
 356            augment=augment,
 357            subsample=subsample,
 358        )
 359        if not embedding:
 360            predicted, ssl_predicted = self.model(model_input, ssl_inputs, tag=tags)
 361        else:
 362            predicted = self.model.extract_features(model_input)
 363            ssl_predicted = None
 364        if self.parallel and predicted is not None and len(predicted.shape) == 4:
 365            predicted = predicted.reshape(
 366                (-1, model_input.shape[0], *predicted.shape[2:])
 367            )
 368        if predicted is not None and augment_n > 1:
 369            self.model.ssl_off()
 370            for i in range(augment_n - 1):
 371                model_input, *_ = self.transformer.transform(
 372                    deepcopy(main_input), augment=augment
 373                )
 374                if not embedding:
 375                    pred, _ = self.model(model_input, None, tag=tags)
 376                else:
 377                    pred = self.model.extract_features(model_input)
 378                predicted += pred.detach()
 379            self.model.ssl_on()
 380            predicted /= augment_n
 381        if self.model.process_labels:
 382            class_embedding = self.model.transform_labels(predicted.device)
 383            beh_dict = self.behaviors_dict()
 384            class_embedding = torch.cat(
 385                [class_embedding[beh_dict[k]] for k in sorted(beh_dict.keys())], 0
 386            )
 387            predicted = {
 388                "video": predicted,
 389                "text": class_embedding,
 390                "logit_scale": self.model.logit_scale(),
 391                "device": predicted.device,
 392            }
 393        return predicted, ssl_predicted, ssl_targets
 394
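Conceptually, the augmentation averaging above amounts to the following sketch (the names are hypothetical; the real code also handles SSL heads, tags and label embeddings):

outputs = [model(transform(x)) for _ in range(augment_n)]  # transform is stochastic
averaged = sum(outputs) / augment_n
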
 395    def _ssl_loss(self, ssl_predicted, ssl_targets):
 396        """
 397        Apply SSL losses
 398        """
 399
 400        ssl_loss = []
 401        for loss, predicted, target in zip(self.ssl_losses, ssl_predicted, ssl_targets):
 402            ssl_loss.append(loss(predicted, target))
 403        return ssl_loss
 404
 405    def _loss_function(
 406        self,
 407        batch: Dict,
 408        augment_n: int,
 409        temporal_subsampling_size: int = None,
 410        skip_metrics: List = None,
 411    ) -> Tuple[float, float]:
 412        """
 413        Calculate the loss function and the metric for a dataloader batch
 414
 415        Averaging the predictions over augment_n augmentations.
 416        """
 417
 418        if "target" not in batch or torch.isnan(batch["target"]).all():
 419            raise ValueError("Cannot compute loss function with nan targets!")
 420        main_input = {k: v.to(self.device) for k, v in batch["input"].items()}
 421        main_target, ssl_targets, ssl_inputs = None, None, None
 422        main_target = batch["target"].to(self.device)
 423        if "ssl_targets" in batch:
 424            ssl_targets = [
 425                {k: v.to(self.device) for k, v in x.items()}
 426                if isinstance(x, dict)
 427                else None
 428                for x in batch["ssl_targets"]
 429            ]
 430        if "ssl_inputs" in batch:
 431            ssl_inputs = [
 432                {k: v.to(self.device) for k, v in x.items()}
 433                if isinstance(x, dict)
 434                else None
 435                for x in batch["ssl_inputs"]
 436            ]
 437        if temporal_subsampling_size is not None:
 438            subsample = sorted(
 439                random.sample(
 440                    range(main_target.shape[-1]),
 441                    int(temporal_subsampling_size * main_target.shape[-1]),
 442                )
 443            )
 444            main_target = main_target[..., subsample]
 445        else:
 446            subsample = None
 447
 448        if self.ignore_tags:
 449            tag = None
 450        else:
 451            tag = batch.get("tag")
 452        predicted, ssl_predicted, ssl_targets = self._get_prediction(
 453            main_input, tag, ssl_inputs, ssl_targets, augment_n, subsample=subsample
 454        )
 455        del main_input, ssl_inputs
 456        return self._compute(
 457            ssl_predicted,
 458            ssl_targets,
 459            predicted,
 460            main_target,
 461            tag=batch.get("tag"),
 462            skip_metrics=skip_metrics,
 463        )
 464
 465    def _compute(
 466        self,
 467        ssl_predicted: List,
 468        ssl_targets: List,
 469        predicted: torch.Tensor,
 470        main_target: torch.Tensor,
 471        tag: Any = None,
 472        skip_loss: bool = False,
 473        apply_primary_function: bool = True,
 474        skip_metrics: List = None,
 475    ) -> Tuple[float, float]:
 476        """
 477        Compute the losses and metrics from predictions
 478        """
 479
 480        if skip_metrics is None:
 481            skip_metrics = []
 482        if not skip_loss:
 483            ssl_losses = self._ssl_loss(ssl_predicted, ssl_targets)
 484            if predicted is not None:
 485                loss = self.loss(predicted, main_target)
 486            else:
 487                loss = 0
 488        else:
 489            ssl_losses, loss = [], 0
 490
 491        if predicted is not None:
 492            if isinstance(predicted, dict):
 493                predicted = {
 494                    k: v.detach()
 495                    for k, v in predicted.items()
 496                    if isinstance(v, torch.Tensor)
 497                }
 498            else:
 499                predicted = predicted.detach()
 500            if apply_primary_function:
 501                predicted = self.primary_predict_function(predicted)
 502            predicted_transformed = self.predict_function(predicted)
 503            for name, metric_function in self.metrics.items():
 504                if name not in skip_metrics:
 505                    if metric_function.needs_raw_data:
 506                        metric_function.update(predicted, main_target, tag)
 507                    else:
 508                        metric_function.update(
 509                            predicted_transformed,
 510                            main_target,
 511                            tag,
 512                        )
 513        return loss, ssl_losses
 514
 515    def _calculate_metrics(self) -> Dict:
 516        """
 517        Calculate the final values of epoch metrics
 518        """
 519
 520        epoch_metrics = {}
 521        for metric_name, metric in self.metrics.items():
 522            m = metric.calculate()
 523            if type(m) is dict:
 524                for k, v in m.items():
 525                    if type(v) is torch.Tensor:
 526                        v = v.item()
 527                    epoch_metrics[f"{metric_name}_{k}"] = v
 528            else:
 529                if type(m) is torch.Tensor:
 530                    m = m.item()
 531                epoch_metrics[metric_name] = m
 532            metric.reset()
 533        return epoch_metrics
 534
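For instance, if a metric object returns a per-class dictionary, the keys are flattened with the metric name as a prefix (the toy values below are hypothetical):

m = {0: 0.8, 1: 0.6}  # e.g. what an "f1" metric's calculate() might return
epoch_metrics = {f"f1_{k}": v for k, v in m.items()}
# epoch_metrics -> {"f1_0": 0.8, "f1_1": 0.6}
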
 535    def _run_epoch(
 536        self,
 537        dataloader: DataLoader,
 538        mode: str,
 539        augment_n: int,
 540        verbose: bool = False,
 541        unlabeled: bool = None,
 542        alpha: float = 1,
 543        temporal_subsampling_size: int = None,
 544    ) -> Tuple:
 545        """
 546        Run one epoch on dataloader
 547
 548        Averaging the predictions over augment_n augmentations.
 549        Use "train" mode for training and "val" mode for evaluation.
 550        """
 551
 552        if mode == "train":
 553            self.model.train()
 554        elif mode == "val":
 555            self.model.eval()
 557        else:
 558            raise ValueError(
 559                f'Mode {mode} is not recognized, please choose either "train" for training or "val" for validation'
 560            )
 561        if self.ignore_tags:
 562            tags = [None]
 563        else:
 564            tags = dataloader.dataset.get_tags()
 565        epoch_loss = 0
 566        epoch_ssl_loss = defaultdict(lambda: 0)
 567        data_len = 0
 568        set_pars = dataloader.dataset.set_indexing_parameters
 569        skip_metrics = self.skip_metrics if mode == "train" else None
 570        for tag in tags:
 571            set_pars(unlabeled=unlabeled, tag=tag)
 572            data_len += len(dataloader)
 573            batches = tqdm(dataloader) if verbose else dataloader
 574            for batch in batches:
 576                loss, ssl_losses = self._loss_function(
 577                    batch,
 578                    augment_n,
 579                    temporal_subsampling_size=temporal_subsampling_size,
 580                    skip_metrics=skip_metrics,
 581                )
 582                if loss != 0:
 583                    loss = loss * alpha
 584                    epoch_loss += loss.item()
 585                for i, (ssl_loss, weight) in enumerate(
 586                    zip(ssl_losses, self.ssl_weights)
 587                ):
 588                    if ssl_loss != 0:
 589                        epoch_ssl_loss[i] += ssl_loss.item()
 590                        loss = loss + weight * ssl_loss
 591                if mode == "train":
 592                    self.optimizer.zero_grad()
 593                    if isinstance(loss, torch.Tensor) and loss.requires_grad:
 594                        loss.backward()
 595                    self.optimizer.step()
 596
 597        epoch_loss = epoch_loss / data_len
 598        epoch_ssl_loss = {k: v / data_len for k, v in epoch_ssl_loss.items()}
 599        epoch_metrics = self._calculate_metrics()
 600
 601        return epoch_loss, epoch_ssl_loss, epoch_metrics
 602
 603    def train(
 604        self,
 605        trial: Trial = None,
 606        optimized_metric: str = None,
 607        to_ram: bool = False,
 608        autostop_interval: int = 30,
 609        autostop_threshold: float = 0.001,
 610        autostop_metric: str = None,
 611        main_task_on: bool = True,
 612        ssl_on: bool = True,
 613        temporal_subsampling_size: int = None,
 614        loading_bar: bool = False,
 615    ) -> Tuple:
 616        """
 617        Train the task and return a log of epoch-average loss and metric
 618
 619        You can use the autostop parameters to stop training early when the monitored value stops improving. Training
 620        will be stopped if the average value of `autostop_metric` over the last `autostop_interval` epochs is smaller than
 621        the average over the previous `autostop_interval` epochs plus `autostop_threshold`. For example, if the
 622        current epoch is 120 and `autostop_interval` is 50, the averages over epochs 70-120 and 20-70 are compared.
 623
 624        Parameters
 625        ----------
 626        trial : Trial
 627            an `optuna` trial (for hyperparameter searches)
 628        optimized_metric : str
 629            the name of the metric being optimized (for hyperparameter searches)
 630        to_ram : bool, default False
 631            if `True`, the dataset will be loaded in RAM (this speeds up the calculations but can lead to crashes
 632            if the dataset is too large)
 633        autostop_interval : int, default 30
 634            the number of epochs to average the autostop metric over
 635        autostop_threshold : float, default 0.001
 636            the autostop difference threshold
 637        autostop_metric : str, optional
 638            the autostop metric (can be any one of the tracked metrics or `'loss'`)
 639        main_task_on : bool, default True
 640            if `False`, the main task (action segmentation) will not be used in training
 641        ssl_on : bool, default True
 642            if `False`, the SSL task will not be used in training
            temporal_subsampling_size : int, optional
                if provided, the training segments are temporally subsampled to this number of frames
            loading_bar : bool, default False
                if `True`, a progress bar is displayed for the training epochs
 643
 644        Returns
 645        -------
 646        loss_log : dict
 647            a dictionary of loss value logs (keys are 'train' and 'val', values are lists of epoch-average loss values)
 648        metrics_log: dict
 649            a dictionary of metric value logs (first-level keys are 'train' and 'val', second-level keys are metric
 650            names, values are lists of function values)
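
            Examples
            --------
            A minimal usage sketch (assuming `task` is a configured `Task` instance and
            `"f1"` is one of its tracked metric names; both names are illustrative):

            >>> loss_log, metrics_log = task.train()
            >>> loss_log["val"][-1]  # validation loss at the last validation epoch
            >>> metrics_log["val"]["f1"]  # per-epoch validation F1 values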
 651        """
 652
 653        if self.parallel and not isinstance(self.model, nn.DataParallel):
 654            self.model = MyDataParallel(self.model)
 655        self.model.to(self.device)
 656        assert autostop_metric in [None, "loss"] + list(self.metrics)
 657        autostop_interval //= self.validation_interval
 658        if trial is not None and optimized_metric is None:
 659            raise ValueError(
 660                "You need to provide the optimized metric name (optimized_metric parameter) "
 661                "for optuna pruning to work!"
 662            )
 663        if to_ram:
 664            print("transferring datasets to RAM...")
 665            self.train_dataloader.dataset.to_ram()
 666            if self.val_dataloader is not None and len(self.val_dataloader) > 0:
 667                self.val_dataloader.dataset.to_ram()
 668        loss_log = {"train": [], "val": []}
 669        metrics_log = {"train": defaultdict(list), "val": defaultdict(list)}
 670        if not main_task_on:
 671            self.model.main_task_off()
 672        if not ssl_on:
 673            self.model.ssl_off()
 674        while self.epoch < self.num_epochs:
 675            self.epoch += 1
 676            unlabeled = None
 677            alpha = 1
 678            if self.pseudolabel:
 679                if self.epoch >= self.T1:
 680                    unlabeled = (self.epoch - self.T1) % self.t != 0
 681                    if unlabeled:
 682                        alpha = self._alpha(self.epoch)
 683                else:
 684                    unlabeled = False
 685            epoch_loss, epoch_ssl_loss, epoch_metrics = self._run_epoch(
 686                dataloader=self.train_dataloader,
 687                mode="train",
 688                augment_n=self.augment_train,
 689                unlabeled=unlabeled,
 690                alpha=alpha,
 691                temporal_subsampling_size=temporal_subsampling_size,
 692                verbose=loading_bar,
 693            )
 694            loss_log["train"].append(epoch_loss)
 695            epoch_string = f"[epoch {self.epoch}]"
 696            if self.pseudolabel:
 697                if unlabeled:
 698                    epoch_string += " (unlabeled)"
 699                else:
 700                    epoch_string += " (labeled)"
 701            epoch_string += f": loss {epoch_loss:.4f}"
 702
 703            if len(epoch_ssl_loss) != 0:
 704                for key, value in sorted(epoch_ssl_loss.items()):
 705                    metrics_log["train"][f"ssl_loss_{key}"].append(value)
 706                    epoch_string += f", ssl_loss_{key} {value:.4f}"
 707
 708            for metric_name, metric_value in sorted(epoch_metrics.items()):
 709                if metric_name not in self.skip_metrics:
 710                    metrics_log["train"][metric_name].append(metric_value)
 711                    epoch_string += f", {metric_name} {metric_value:.3f}"
 712
 713            if (
 714                self.val_dataloader is not None
 715                and self.epoch % self.validation_interval == 0
 716            ):
 717                with torch.no_grad():
 718                    epoch_string += "\n"
 719                    (
 720                        val_epoch_loss,
 721                        val_epoch_ssl_loss,
 722                        val_epoch_metrics,
 723                    ) = self._run_epoch(
 724                        dataloader=self.val_dataloader,
 725                        mode="val",
 726                        augment_n=self.augment_val,
 727                    )
 728                    loss_log["val"].append(val_epoch_loss)
 729                    epoch_string += f"validation: loss {val_epoch_loss:.4f}"
 730
 731                    if len(val_epoch_ssl_loss) != 0:
 732                        for key, value in sorted(val_epoch_ssl_loss.items()):
 733                            metrics_log["val"][f"ssl_loss_{key}"].append(value)
 734                            epoch_string += f", ssl_loss_{key} {value:.4f}"
 735
 736                    for metric_name, metric_value in sorted(val_epoch_metrics.items()):
 737                        metrics_log["val"][metric_name].append(metric_value)
 738                        epoch_string += f", {metric_name} {metric_value:.3f}"
 739
 740                if trial is not None:
 741                    if optimized_metric not in metrics_log["val"]:
 742                        raise ValueError(
 743                            f"The {optimized_metric} metric set for optimization is not being logged!"
 744                        )
 745                    trial.report(metrics_log["val"][optimized_metric][-1], self.epoch)
 746                    if trial.should_prune():
 747                        raise TrialPruned()
 748
 749            if self.verbose:
 750                print(epoch_string)
 751
 752            if self.log_file is not None:
 753                with open(self.log_file, "a") as f:
 754                    f.write(epoch_string + "\n")
 755
 756            save_condition = (
 757                (self.model_save_epochs != 0)
 758                and (self.epoch % self.model_save_epochs == 0)
 759            ) or (self.epoch == self.num_epochs)
 760
 761            if self.epoch > 0 and save_condition and self.model_save_path is not None:
 762                epoch_s = str(self.epoch).zfill(len(str(self.num_epochs)))
 763                self.save_checkpoint(
 764                    os.path.join(self.model_save_path, f"epoch{epoch_s}.pt")
 765                )
 766
 767            if self.pseudolabel and self.epoch >= self.T1 and not unlabeled:
 768                self._set_pseudolabels()
 769
 770            if autostop_metric == "loss":
 771                if len(loss_log["val"]) > autostop_interval * 2:
 772                    if (
 773                        np.mean(loss_log["val"][-autostop_interval:])
 774                        < np.mean(
 775                            loss_log["val"][-2 * autostop_interval : -autostop_interval]
 776                        )
 777                        + autostop_threshold
 778                    ):
 779                        break
 780            elif autostop_metric in metrics_log["val"]:
 781                if len(metrics_log["val"][autostop_metric]) > autostop_interval * 2:
 782                    if (
 783                        np.mean(
 784                            metrics_log["val"][autostop_metric][-autostop_interval:]
 785                        )
 786                        < np.mean(
 787                            metrics_log["val"][autostop_metric][
 788                                -2 * autostop_interval : -autostop_interval
 789                            ]
 790                        )
 791                        + autostop_threshold
 792                    ):
 793                        break
 794
 795        metrics_log = {k: dict(v) for k, v in metrics_log.items()}
 796
 797        return loss_log, metrics_log
 798
 799    def evaluate_prediction(
 800        self,
 801        prediction: Union[torch.Tensor, Dict],
 802        data: Union[DataLoader, BehaviorDataset, str] = None,
 803        batch_size: int = 32,
 804    ) -> Tuple:
 805        """
 806        Compute metrics for a prediction
 807
 808        Parameters
 809        ----------
 810        prediction : torch.Tensor | dict
 811            the prediction, either a tensor aligned with the dataset samples or a full-length prediction dictionary
 812        data : torch.utils.data.DataLoader | dlc2action.data.dataset.BehaviorDataset, optional
 813            the data the prediction was made for (if not provided, take the validation dataset)
 814        batch_size : int, default 32
 815            the batch size
 816
 817        Returns
 818        -------
 819        loss : float
 820            the average value of the loss function (always 0 here, since the loss is not recomputed for external predictions)
 821        metric : dict
 822            a dictionary of average values of metric functions
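
            Examples
            --------
            A hedged sketch: the prediction is assumed to come from `Task.predict` with raw
            probability output for the same data split.

            >>> prediction = task.predict("val", raw_output=True, apply_primary_function=True)
            >>> _, metrics = task.evaluate_prediction(prediction, data="val")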
 823        """
 824
 825        dataset = self._get_dataset(data)
 826        if type(data) is not DataLoader:
 827            data = DataLoader(dataset, shuffle=False, batch_size=batch_size)
 828        epoch_loss = 0
 829        if isinstance(prediction, dict):
 830            num_classes = len(self.behaviors_dict())
 831            length = dataset.len_segment()
 832            coords = dataset.annotation_store.get_original_coordinates()
 833            for batch in data:
 834                main_target = batch["target"]
 835                pr_coords = coords[batch["index"]]
 836                predicted = torch.zeros((len(pr_coords), num_classes, length))
 837                for i, c in enumerate(pr_coords):
 838                    video_id = dataset.input_store.get_video_id(c)
 839                    clip_id = dataset.input_store.get_clip_id(c)
 840                    start, end = dataset.input_store.get_clip_start_end(c)
 841                    predicted[i, :, : end - start] = prediction[video_id][clip_id][
 842                        :, start:end
 843                    ]
 844                self._compute(
 845                    [],
 846                    [],
 847                    predicted,
 848                    main_target,
 849                    skip_loss=True,
 850                    tag=batch.get("tag"),
 851                    apply_primary_function=False,
 852                )
 853        else:
 854            for batch in data:
 855                main_target = batch["target"]
 856                predicted = prediction[batch["index"]]
 857                self._compute(
 858                    [],
 859                    [],
 860                    predicted,
 861                    main_target,
 862                    skip_loss=True,
 863                    tag=batch.get("tag"),
 864                    apply_primary_function=False,
 865                )
 866        epoch_metrics = self._calculate_metrics()
 867        strings = [
 868            f"{metric_name} {metric_value:.3f}"
 869            for metric_name, metric_value in epoch_metrics.items()
 870        ]
 871        val_string = ", ".join(sorted(strings))
 872        print(val_string)
 873        return epoch_loss, epoch_metrics
 874
 875    def evaluate(
 876        self,
 877        data: Union[DataLoader, BehaviorDataset, str] = None,
 878        augment_n: int = 0,
 879        batch_size: int = 32,
 880        verbose: bool = True,
 881    ) -> Tuple:
 882        """
 883        Evaluate the Task model
 884
 885        Parameters
 886        ----------
 887        data : torch.utils.data.DataLoader | dlc2action.data.dataset.BehaviorDataset, optional
 888            the data to evaluate on (if not provided, evaluate on the Task validation dataset)
 889        augment_n : int, default 0
 890            the number of augmentations to average results over
 891        batch_size : int, default 32
 892            the batch size
 893        verbose : bool, default True
 894            if `True`, the process is reported to standard output
 895
 896        Returns
 897        -------
 898        loss : float
 899            the average value of the loss function
 900        ssl_loss : dict
 901            a dictionary of average SSL loss function values
 902        metric : dict
 903            a dictionary of average values of metric functions
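
            Examples
            --------
            An illustrative call on the validation set, without augmentation:

            >>> loss, ssl_loss, metrics = task.evaluate("val", augment_n=0)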
 904        """
 905
 906        if self.parallel and not isinstance(self.model, nn.DataParallel):
 907            self.model = MyDataParallel(self.model)
 908        self.model.to(self.device)
 909        if type(data) is not DataLoader:
 910            data = self._get_dataset(data)
 911            data = DataLoader(data, shuffle=False, batch_size=batch_size)
 912        with torch.no_grad():
 913            epoch_loss, epoch_ssl_loss, epoch_metrics = self._run_epoch(
 914                dataloader=data, mode="val", augment_n=augment_n, verbose=verbose
 915            )
 916        val_string = f"loss {epoch_loss:.4f}"
 917        for metric_name, metric_value in sorted(epoch_metrics.items()):
 918            val_string += f", {metric_name} {metric_value:.3f}"
 919        print(val_string)
 920        return epoch_loss, epoch_ssl_loss, epoch_metrics
 921
 922    def predict(
 923        self,
 924        data: Union[DataLoader, BehaviorDataset, str] = None,
 925        raw_output: bool = False,
 926        apply_primary_function: bool = True,
 927        augment_n: int = 0,
 928        batch_size: int = 32,
 929        train_mode: bool = False,
 930        to_ram: bool = False,
 931        embedding: bool = False,
 932    ) -> torch.Tensor:
 933        """
 934        Make a prediction with the Task model
 935
 936        Parameters
 937        ----------
 938        data : torch.utils.data.DataLoader | dlc2action.data.dataset.BehaviorDataset | str, optional
 939            the data to predict on (if not provided, the prediction is made for the Task validation dataset)
 940        raw_output : bool, default False
 941            if `True`, the raw predicted probabilities are returned
 942        apply_primary_function : bool, default True
 943            if `True`, the primary predict function is applied (to map the model output into a shape corresponding to
 944            the input)
 945        augment_n : int, default 0
 946            the number of augmentations to average results over
 947        batch_size : int, default 32
 948            the batch size
 949        train_mode : bool, default False
 950            if `True`, the model is used in training mode (affects dropout and batch normalization layers)
 951        to_ram : bool, default False
 952            if `True`, the dataset will be loaded in RAM (this speeds up the calculations but can lead to crashes
 953            if the dataset is too large)
 954        embedding : bool, default False
 955            if `True`, the output of the feature extractor is returned, ignoring the prediction module of the model
 956
 957        Returns
 958        -------
 959        prediction : torch.Tensor
 960            a prediction for the input data
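
            Examples
            --------
            Two illustrative calls (assuming `task` is a configured `Task` instance):

            >>> probabilities = task.predict("val", raw_output=True, apply_primary_function=True)
            >>> prediction = task.predict("val")  # thresholded or argmax output, depending on the task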
 961        """
 962
 963        if self.parallel and not isinstance(self.model, nn.DataParallel):
 964            self.model = MyDataParallel(self.model)
 965        self.model.to(self.device)
 966        if train_mode:
 967            self.model.train()
 968        else:
 969            self.model.eval()
 970        output = []
 971        if embedding:
 972            raw_output = True
 973            apply_primary_function = True
 974        if type(data) is not DataLoader:
 975            data = self._get_dataset(data)
 976            if to_ram:
 977                print("transferring dataset to RAM...")
 978                data.to_ram()
 979            data = DataLoader(data, shuffle=False, batch_size=batch_size)
 980        self.model.ssl_off()
 981        with torch.no_grad():
 982            for batch in tqdm(data):
 983                input_data = {k: v.to(self.device) for k, v in batch["input"].items()}
 984                predicted, _, _ = self._get_prediction(
 985                    input_data,
 986                    batch.get("tag"),
 987                    augment_n=augment_n,
 988                    embedding=embedding,
 989                )
 990                if apply_primary_function:
 991                    predicted = self.primary_predict_function(predicted)
 992                if not raw_output:
 993                    predicted = self.predict_function(predicted)
 994                output.append(predicted.detach().cpu())
 995        self.model.ssl_on()
 996        output = torch.cat(output).detach()
 997        return output
 998
 999    def dataset(self, mode="train") -> BehaviorDataset:
1000        """
1001        Get a dataset
1002
1003        Parameters
1004        ----------
1005        mode : {'train', 'val', 'test'}
1006            the dataset to get
1007
1008        Returns
1009        -------
1010        dataset : dlc2action.data.dataset.BehaviorDataset
1011            the dataset
1012        """
1013
1014        dataloader = self.dataloader(mode)
1015        if dataloader is None:
1016            raise ValueError(f"The {mode} dataloader is empty or not set!")
1017        return dataloader.dataset
1018
1019    def dataloader(self, mode: str = "train") -> DataLoader:
1020        """
1021        Get a dataloader
1022
1023        Parameters
1024        ----------
1025        mode : {'train', 'val', 'test'}
1026            the dataset to get
1027
1028        Returns
1029        -------
1030        dataloader : torch.utils.data.DataLoader
1031            the dataloader
1032        """
1033
1034        if mode == "train":
1035            return self.train_dataloader
1036        elif mode == "val":
1037            return self.val_dataloader
1038        elif mode == "test":
1039            return self.test_dataloader
1040        else:
1041            raise ValueError(
1042                f'The {mode} mode is not recognized, please choose from "train", "val" or "test"'
1043            )
1044
1045    def _get_dataset(self, dataset):
1046        """
1047        Get a dataset from a dataloader, a string ('train', 'test' or 'val') or `None` (default)
1048        """
1049
1050        if dataset is None:
1051            dataset = self.dataset("val")
1052        elif dataset in ["train", "val", "test"]:
1053            dataset = self.dataset(dataset)
1054        elif type(dataset) is DataLoader:
1055            dataset = dataset.dataset
1056        if type(dataset) is BehaviorDataset:
1057            return dataset
1058        else:
1059            raise TypeError(f"The {type(dataset)} type of dataset is not recognized!")
1060
1061    def _get_dataloader(self, dataset):
1062        """
1063        Get a dataloader from a dataset, a string ('train', 'test' or 'val') or `None` (default)
1064        """
1065
1066        if dataset is None:
1067            dataset = self.dataloader("val")
1068        elif dataset in ["train", "val", "test"]:
1069            dataset = self.dataloader(dataset)
1070            if dataset is None:
1071                raise ValueError("The length of the dataloader is 0!")
1072        elif type(dataset) is BehaviorDataset:
1073            dataset = DataLoader(dataset)
1074        if type(dataset) is DataLoader:
1075            return dataset
1076        else:
1077            raise TypeError(f"The {type(dataset)} type of dataset is not recognized!")
1078
1079    def generate_full_length_prediction(
1080        self, dataset=None, batch_size=32, augment_n=10
1081    ):
1082        """
1083        Compile a prediction for the original input sequences
1084
1085        Parameters
1086        ----------
1087        dataset : BehaviorDataset, optional
1088            the dataset to generate a prediction for (if `None`, generate for the validation dataset)
1089        batch_size : int, default 32
1090            the batch size
1091        augment_n : int, default 10
1092            the number of augmentations to average results over
1093
1094        Returns
1095        -------
1096        prediction : dict
1097            a nested dictionary where first level keys are video ids, second level keys are clip ids and values
1098            are prediction tensors
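
            Examples
            --------
            An illustrative call (the video and clip ids are assumptions):

            >>> prediction = task.generate_full_length_prediction("val", augment_n=5)
            >>> prediction["video1"]["clip1"]  # a tensor of shape (#classes, #frames)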
1099        """
1100
1101        dataset = self._get_dataset(dataset)
1102        if not isinstance(dataset, BehaviorDataset):
1103            raise TypeError(
1104                f"The dataset parameter has to be either None, string, "
1105                f"BehaviorDataset or Dataloader, got {type(dataset)}"
1106            )
1107        predicted = self.predict(
1108            dataset,
1109            raw_output=True,
1110            apply_primary_function=True,
1111            augment_n=augment_n,
1112            batch_size=batch_size,
1113        )
1114        predicted = dataset.generate_full_length_prediction(predicted)
1115        predicted = {
1116            v_id: {
1117                clip_id: self._apply_predict_functions(v.unsqueeze(0)).squeeze()
1118                for clip_id, v in video_dict.items()
1119            }
1120            for v_id, video_dict in predicted.items()
1121        }
1122        return predicted
1123
1124    def generate_submission(
1125        self, frame_number_map_file, dataset=None, batch_size=32, augment_n=10
1126    ):
1127        """
1128        Generate a MABe-22 style submission dictionary
1129
1130        Parameters
1131        ----------
            frame_number_map_file : str
                the path to the frame number map file
1132        dataset : BehaviorDataset, optional
1133            the dataset to generate a prediction for (if `None`, generate for the validation dataset)
1134        batch_size : int, default 32
1135            the batch size
1136        augment_n : int, default 10
1137            the number of augmentations to average results over
1138
1139        Returns
1140        -------
1141        submission : dict
1142            a dictionary with frame number mapping and embeddings
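
            Examples
            --------
            A hedged sketch (the file names are illustrative):

            >>> submission = task.generate_submission("frame_number_map.npy", dataset="test")
            >>> np.save("submission.npy", submission)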
1143        """
1144
1145        dataset = self._get_dataset(dataset)
1146        if not isinstance(dataset, BehaviorDataset):
1147            raise TypeError(
1148                f"The dataset parameter has to be either None, string, "
1149                f"BehaviorDataset or Dataloader, got {type(dataset)}"
1150            )
1151        predicted = self.predict(
1152            dataset,
1153            raw_output=True,
1154            apply_primary_function=True,
1155            augment_n=augment_n,
1156            batch_size=batch_size,
1157            embedding=True,
1158        )
1159        predicted = dataset.generate_full_length_prediction(predicted)
1160        frame_map = np.load(frame_number_map_file, allow_pickle=True).item()
1161        length = frame_map[list(frame_map.keys())[-1]][1]
1162        embeddings = None
1163        for video_id in list(predicted.keys()):
1164            split = video_id.split("--")
1165            if len(split) != 2 or len(predicted[video_id]) > 1:
1166                raise RuntimeError(
1167                    "Generating submissions is only implemented for the mabe22 dataset!"
1168                )
1169            if split[1] not in frame_map:
1170                raise RuntimeError(f"The {split[1]} video is not in the frame map file")
1171            v_id = split[1]
1172            clip_id = list(predicted[video_id].keys())[0]
1173            if embeddings is None:
1174                embeddings = np.zeros((length, predicted[video_id][clip_id].shape[0]))
1175            start, end = frame_map[v_id]
1176            embeddings[start:end, :] = predicted[video_id][clip_id].T
1177            predicted.pop(video_id)
1178        predicted = {
1179            "frame_number_map": frame_map,
1180            "embeddings": embeddings.astype(np.float32),
1181        }
1182        return predicted
1183
1184    def _get_intervals(self, tensor: torch.Tensor) -> torch.Tensor:
1185        """
1186        Get a list of True group beginning and end indices from a boolean tensor
1187        """
1188
1189        output, indices = torch.unique_consecutive(tensor, return_inverse=True)
1190        true_indices = torch.where(output)[0]
1191        starts = torch.tensor(
1192            [(indices == i).nonzero(as_tuple=True)[0][0] for i in true_indices]
1193        )
1194        ends = torch.tensor(
1195            [(indices == i).nonzero(as_tuple=True)[0][-1] + 1 for i in true_indices]
1196        )
1197        return torch.stack([starts, ends]).T
1198
1199    def _smooth(self, tensor: torch.Tensor, smooth_interval: int = 1) -> torch.Tensor:
1200        """
1201        Get rid of jittering in a non-exclusive classification tensor
1202
1203        First, remove intervals of 0 shorter than `smooth_interval`. Then, remove intervals of 1 shorter than
1204        `smooth_interval`.
1205        """
1206
1207        if len(tensor.shape) > 1:
1208            for c in range(tensor.shape[1]):
1209                intervals = self._get_intervals(tensor[:, c] == 0)
1210                interval_lengths = torch.tensor(
1211                    [interval[1] - interval[0] for interval in intervals]
1212                )
1213                short_intervals = intervals[interval_lengths <= smooth_interval]
1214                for start, end in short_intervals:
1215                    tensor[start:end, c] = 1
1216                intervals = self._get_intervals(tensor[:, c] == 1)
1217                interval_lengths = torch.tensor(
1218                    [interval[1] - interval[0] for interval in intervals]
1219                )
1220                short_intervals = intervals[interval_lengths <= smooth_interval]
1221                for start, end in short_intervals:
1222                    tensor[start:end, c] = 0
1223        else:
1224            for c in tensor.unique():
1225                intervals = self._get_intervals(tensor == c)
1226                interval_lengths = torch.tensor(
1227                    [interval[1] - interval[0] for interval in intervals]
1228                )
1229                short_intervals = intervals[interval_lengths <= smooth_interval]
1230                for start, end in short_intervals:
1231                    if start == 0:
1232                        tensor[start:end] = tensor[end + 1]
1233                    else:
1234                        tensor[start:end] = tensor[start - 1]
1235        return tensor
1236
1237    def _visualize_results_label(
1238        self,
1239        label: str,
1240        save_path: str = None,
1241        add_legend: bool = True,
1242        ground_truth: bool = True,
1243        hide_axes: bool = False,
1244        width: int = 10,
1245        whole_video: bool = False,
1246        transparent: bool = False,
1247        dataset: BehaviorDataset = None,
1248        smooth_interval: int = 0,
1249        title: str = None,
1250    ):
1251        """
1252        Visualize random predictions
1253
1254        Parameters
1255        ----------
1256        save_path : str, optional
1257            the path where the plot will be saved
1258        add_legend : bool, default True
1259            if `True`, legend will be added to the plot
1260        ground_truth : bool, default True
1261            if `True`, ground truth will be added to the plot
1262        colormap : str, default 'viridis'
1263            the `matplotlib` colormap to use
1264        hide_axes : bool, default True
1265            if `True`, the axes will be hidden on the plot
1266        min_classes : int, default 1
1267            the minimum number of classes in a displayed interval
1268        width : float, default 10
1269            the width of the plot
1270        whole_video : bool, default False
1271            if `True`, whole videos are plotted instead of segments
1272        transparent : bool, default False
1273            if `True`, the background on the plot is transparent
1274        dataset : BehaviorDataset, optional
1275            the dataset to make the prediction for (if not provided, the validation dataset is used)
1276        drop_classes : set, optional
1277            a set of class names to not be displayed
1278        """
1279
1280        if title is None:
1281            title = ""
1282        dataset = self._get_dataset(dataset)
1283        inverse_dict = {v: k for k, v in dataset.behaviors_dict().items()}
1284        label_ind = inverse_dict[label]
1285        labels = {1: label, -100: "unknown"}
1286        label_keys = [1, -100]
1287        color_list = ["blue", "gray"]
1288        if whole_video:
1289            predicted = self.generate_full_length_prediction(dataset)
1290            keys = list(predicted.keys())
1291        counter = 0
1292        if whole_video:
1293            max_iter = len(keys) * 5
1294        else:
1295            max_iter = len(dataset) * 5
1296        ok = False
1297        while not ok:
1298            counter += 1
1299            if counter > max_iter:
1300                raise RuntimeError(
1301                    "Plotting is taking too many iterations; you should probably make some of the parameters less restrictive"
1302                )
1303            if whole_video:
1304                i = randint(0, len(keys) - 1)
1305                prediction = predicted[keys[i]]
1306                keys_i = list(prediction.keys())
1307                j = randint(0, len(keys_i) - 1)
1308                full_p = prediction[keys_i[j]]
1309                prediction = prediction[keys_i[j]][label_ind]
1310            else:
1311                dataloader = DataLoader(dataset)
1312                i = randint(0, len(dataloader) - 1)
1313                for num, batch in enumerate(dataloader):
1314                    if num == i:
1315                        break
1316                input_data = {k: v.to(self.device) for k, v in batch["input"].items()}
1317                prediction, *_ = self._get_prediction(
1318                    input_data, batch.get("tag"), augment_n=5
1319                )
1320                prediction = self._apply_predict_functions(prediction)
1321                j = randint(0, len(prediction) - 1)
1322                full_p = prediction[j]
1323                prediction = prediction[j][label_ind]
1324            classes = [x for x in torch.unique(prediction) if int(x) in label_keys]
1325            ok = 1 in classes
1326        fig, ax = plt.subplots(figsize=(width, 2))
1327        for c in classes:
1328            c_i = label_keys.index(int(c))
1329            output, indices, counts = torch.unique_consecutive(
1330                prediction == c, return_inverse=True, return_counts=True
1331            )
1332            long_indices = torch.where(output)[0]
1333            res_indices_start = [
1334                (indices == i).nonzero(as_tuple=True)[0][0].item() for i in long_indices
1335            ]
1336            res_indices_end = [
1337                (indices == i).nonzero(as_tuple=True)[0][-1].item() + 1
1338                for i in long_indices
1339            ]
1340            res_indices_len = [
1341                end - start for start, end in zip(res_indices_start, res_indices_end)
1342            ]
1343            ax.broken_barh(
1344                list(zip(res_indices_start, res_indices_len)),
1345                (0, 1),
1346                label=labels[int(c)],
1347                facecolors=color_list[c_i],
1348            )
1349        if ground_truth:
1350            gt = batch["target"][j][label_ind].to(self.device)
1351            classes_gt = [x for x in torch.unique(gt) if int(x) in label_keys]
1352            for c in classes_gt:
1353                c_i = label_keys.index(int(c))
1354                if c in classes:
1355                    label = None
1356                else:
1357                    label = labels[int(c)]
1358                output, indices, counts = torch.unique_consecutive(
1359                    gt == c, return_inverse=True, return_counts=True
1360                )
1361                long_indices = torch.where(output * (counts > 5))[0]
1362                res_indices_start = [
1363                    (indices == i).nonzero(as_tuple=True)[0][0].item()
1364                    for i in long_indices
1365                ]
1366                res_indices_end = [
1367                    (indices == i).nonzero(as_tuple=True)[0][-1].item() + 1
1368                    for i in long_indices
1369                ]
1370                res_indices_len = [
1371                    end - start
1372                    for start, end in zip(res_indices_start, res_indices_end)
1373                ]
1374                ax.broken_barh(
1375                    list(zip(res_indices_start, res_indices_len)),
1376                    (1.5, 1),
1377                    facecolors=color_list[c_i],
1378                    label=label,
1379                )
1380        self._compute(
1381            main_target=batch["target"][j].unsqueeze(0).to(self.device),
1382            predicted=full_p.unsqueeze(0).to(self.device),
1383            ssl_targets=[],
1384            ssl_predicted=[],
1385            skip_loss=True,
1386        )
1387        metrics = self._calculate_metrics()
1388        if smooth_interval > 0:
1389            smoothed = self._smooth(full_p, smooth_interval=smooth_interval)[
1390                label_ind, :
1391            ]
1392            for c in classes:
1393                c_i = label_keys.index(int(c))
1394                output, indices, counts = torch.unique_consecutive(
1395                    smoothed == c, return_inverse=True, return_counts=True
1396                )
1397                long_indices = torch.where(output)[0]
1398                res_indices_start = [
1399                    (indices == i).nonzero(as_tuple=True)[0][0].item()
1400                    for i in long_indices
1401                ]
1402                res_indices_end = [
1403                    (indices == i).nonzero(as_tuple=True)[0][-1].item() + 1
1404                    for i in long_indices
1405                ]
1406                res_indices_len = [
1407                    end - start
1408                    for start, end in zip(res_indices_start, res_indices_end)
1409                ]
1410                ax.broken_barh(
1411                    list(zip(res_indices_start, res_indices_len)),
1412                    (3, 1),
1413                    label=labels[int(c)],
1414                    facecolors=color_list[c_i],
1415                )
1416        keys = list(metrics.keys())
1417        for key in keys:
1418            if key.split("_")[-1] != str(label_ind):
1419                metrics.pop(key)
1420        title = [title]
1421        for key, value in metrics.items():
1422            title.append(f"{'_'.join(key.split('_')[: -1])}: {value:.2f}")
1423        title = ", ".join(title)
1424        if not ground_truth:
1425            ax.axes.yaxis.set_visible(False)
1426        else:
1427            ax.set_yticks([0.5, 2])
1428            ax.set_yticklabels(["prediction", "ground truth"])
1429        if add_legend:
1430            ax.legend()
1431        if hide_axes:
1432            plt.axis("off")
1433        plt.title(title)
1434        plt.xlim((0, len(prediction)))
1435        if save_path is not None:
1436            plt.savefig(save_path, transparent=transparent)
1437        plt.show()
1438
1439    def visualize_results(
1440        self,
1441        save_path: str = None,
1442        add_legend: bool = True,
1443        ground_truth: bool = True,
1444        colormap: str = "viridis",
1445        hide_axes: bool = False,
1446        min_classes: int = 1,
1447        width: int = 10,
1448        whole_video: bool = False,
1449        transparent: bool = False,
1450        dataset: Union[BehaviorDataset, DataLoader, str, None] = None,
1451        drop_classes: Set = None,
1452        search_classes: Set = None,
1453        num_samples: int = 1,
1454        smooth_interval_prediction: int = 0,
1455        behavior_name: str = None,
1456    ):
1457        """
1458        Visualize random predictions
1459
1460        Parameters
1461        ----------
1462        save_path : str, optional
1463            the path where the plot will be saved
1464        add_legend : bool, default True
1465            if `True`, legend will be added to the plot
1466        ground_truth : bool, default True
1467            if `True`, ground truth will be added to the plot
1468        colormap : str, default 'viridis'
1469            the `matplotlib` colormap to use
1470        hide_axes : bool, default False
1471            if `True`, the axes will be hidden on the plot
1472        min_classes : int, default 1
1473            the minimum number of classes in a displayed interval
1474        width : float, default 10
1475            the width of the plot
1476        whole_video : bool, default False
1477            if `True`, whole videos are plotted instead of segments
1478        transparent : bool, default False
1479            if `True`, the background on the plot is transparent
1480        dataset : BehaviorDataset | DataLoader | str | None, optional
1481            the dataset to make the prediction for (if not provided, the validation dataset is used)
1482        drop_classes : set, optional
1483            a set of class names to not be displayed
1484        search_classes : set, optional
1485            if given, only intervals where at least one of the classes is in ground truth will be shown
            smooth_interval_prediction : int, default 0
                if greater than 0, the predictions are smoothed over intervals of this length and displayed alongside the raw version
            behavior_name : str, optional
                for non-exclusive classification, the name of the behavior to display (if not provided, the first behavior is used)
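
            Examples
            --------
            An illustrative call (the class name is an assumption):

            >>> task.visualize_results(search_classes={"running"}, smooth_interval_prediction=10)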
1486        """
1487
1488        if drop_classes is None:
1489            drop_classes = []
1490        dataset = self._get_dataset(dataset)
1491        if dataset.annotation_class() == "exclusive_classification":
1492            exclusive = True
1493        elif dataset.annotation_class() == "nonexclusive_classification":
1494            exclusive = False
1495        else:
1496            raise NotImplementedError(
1497                f"Results visualisation is not implemented for {dataset.annotation_class() } datasets!"
1498            )
1499        if exclusive:
1500            labels = {
1501                k: v
1502                for k, v in dataset.behaviors_dict().items()
1503                if v not in drop_classes
1504            }
1505            labels.update({-100: "unknown"})
1506        else:
1507            inverse_dict = {v: k for k, v in dataset.behaviors_dict().items()}
1508            if behavior_name is None:
1509                behavior_name = list(inverse_dict.keys())[0]
1510            label_ind = inverse_dict[behavior_name]
1511            labels = {1: behavior_name, -100: "unknown"}
1512        label_keys = sorted([int(x) for x in labels.keys()])
1513        if search_classes is None:
1514            ok = True
1515        else:
1516            ok = False
1517        classes = []
1518        if whole_video:
1519            predicted = self.generate_full_length_prediction(dataset)
1520            keys = list(predicted.keys())
1521        counter = 0
1522        if whole_video:
1523            max_iter = len(keys) * 2
1524        else:
1525            max_iter = len(dataset) * 2
1526        while len(classes) < min_classes or not ok:
1527            counter += 1
1528            if counter > max_iter:
1529                raise RuntimeError(
1530                    "Plotting is taking too many iterations; you should probably make some of the parameters less restrictive"
1531                )
1532            if whole_video:
1533                i = randint(0, len(keys) - 1)
1534                prediction = predicted[keys[i]]
1535                keys_i = list(prediction.keys())
1536                j = randint(0, len(keys_i) - 1)
1537                prediction = prediction[keys_i[j]]
1538                key1 = keys[i]
1539                key2 = keys_i[j]
1540            else:
1541                dataloader = DataLoader(dataset)
1542                i = randint(0, len(dataloader) - 1)
1543                for num, batch in enumerate(dataloader):
1544                    if num == i:
1545                        break
1546                input_data = {k: v.to(self.device) for k, v in batch["input"].items()}
1547                prediction, *_ = self._get_prediction(
1548                    input_data, batch.get("tag"), augment_n=5
1549                )
1550                prediction = self._apply_predict_functions(prediction)
1551                j = randint(0, len(prediction) - 1)
1552                prediction = prediction[j]
1553            if not exclusive:
1554                prediction = prediction[label_ind]
1555            if smooth_interval_prediction > 0:
1556                unsmoothed_prediction = deepcopy(prediction)
1557                prediction = self._smooth(prediction, smooth_interval_prediction)
1558                height = 3
1559            else:
1560                height = 2
1561            classes = [
1562                labels[int(x)] for x in torch.unique(prediction) if x in label_keys
1563            ]
1564            if search_classes is not None:
1565                ok = any([x in classes for x in search_classes])
1566        fig, ax = plt.subplots(figsize=(width, height))
1567        color_list = cm.get_cmap(colormap, lut=len(labels)).colors
1568
1569        def _plot_prediction(prediction, height, set_labels=True):
1570            for c in label_keys:
1571                c_i = label_keys.index(int(c))
1572                output, indices, counts = torch.unique_consecutive(
1573                    prediction == c, return_inverse=True, return_counts=True
1574                )
1575                long_indices = torch.where(output)[0]
1576                if len(long_indices) == 0:
1577                    continue
1578                res_indices_start = [
1579                    (indices == i).nonzero(as_tuple=True)[0][0].item()
1580                    for i in long_indices
1581                ]
1582                res_indices_end = [
1583                    (indices == i).nonzero(as_tuple=True)[0][-1].item() + 1
1584                    for i in long_indices
1585                ]
1586                res_indices_len = [
1587                    end - start
1588                    for start, end in zip(res_indices_start, res_indices_end)
1589                ]
1590                if set_labels:
1591                    label = labels[int(c)]
1592                else:
1593                    label = None
1594                ax.broken_barh(
1595                    list(zip(res_indices_start, res_indices_len)),
1596                    (height, 1),
1597                    label=label,
1598                    facecolors=color_list[c_i],
1599                )
1600
1601        if smooth_interval_prediction > 0:
1602            _plot_prediction(unsmoothed_prediction, 0)
1603            _plot_prediction(prediction, 1.5, set_labels=False)
1604            gt_height = 3
1605        else:
1606            _plot_prediction(prediction, 0)
1607            gt_height = 1.5
1608        if ground_truth:
1609            if not whole_video:
1610                gt = batch["target"][j].to(self.device)
1611            else:
1612                gt = dataset.generate_full_length_gt()[key1][key2]
1613            for c in label_keys:
1614                c_i = label_keys.index(int(c))
1615                if labels[int(c)] in classes:
1616                    label = None
1617                else:
1618                    label = labels[int(c)]
1619                output, indices, counts = torch.unique_consecutive(
1620                    gt == c, return_inverse=True, return_counts=True
1621                )
1622                long_indices = torch.where(output)[0]
1623                if len(long_indices) == 0:
1624                    continue
1625                res_indices_start = [
1626                    (indices == i).nonzero(as_tuple=True)[0][0].item()
1627                    for i in long_indices
1628                ]
1629                res_indices_end = [
1630                    (indices == i).nonzero(as_tuple=True)[0][-1].item() + 1
1631                    for i in long_indices
1632                ]
1633                res_indices_len = [
1634                    end - start
1635                    for start, end in zip(res_indices_start, res_indices_end)
1636                ]
1637                ax.broken_barh(
1638                    list(zip(res_indices_start, res_indices_len)),
1639                    (gt_height, 1),
1640                    facecolors=color_list[c_i] if int(c) != -100 else "gray",
1641                    label=label,
1642                )
1643        if not ground_truth:
1644            ax.axes.yaxis.set_visible(False)
1645        else:
1646            if smooth_interval_prediction > 0:
1647                ax.set_yticks([0.5, 2, 3.5])
1648                ax.set_yticklabels(["prediction", "smoothed", "ground truth"])
1649            else:
1650                ax.set_yticks([0.5, 2])
1651                ax.set_yticklabels(["prediction", "ground truth"])
1652        if add_legend:
1653            ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))
1654        if hide_axes:
1655            plt.axis("off")
1656        if save_path is not None:
1657            plt.savefig(save_path, transparent=transparent)
1658        plt.show()
1659
1660    def set_ssl_transformations(self, ssl_transformations):
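            """
            Set new SSL transformations

            Parameters
            ----------
            ssl_transformations : list
                a list of SSL transformation functions to set on the train and validation datasets
            """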
1661        self.train_dataloader.dataset.set_ssl_transformations(ssl_transformations)
1662        if self.val_dataloader is not None:
1663            self.val_dataloader.dataset.set_ssl_transformations(ssl_transformations)
1664
1665    def set_ssl_losses(self, ssl_losses: list) -> None:
1666        """
1667        Set SSL losses
1668
1669        Parameters
1670        ----------
1671        ssl_losses : list
1672            a list of callable SSL losses
1673        """
1674
1675        self.ssl_losses = ssl_losses
1676
1677    def set_log(self, log: str) -> None:
1678        """
1679        Set the log file
1680
1681        Parameters
1682        ----------
1683        log: str
1684            the new log file path
1685        """
1686
1687        self.log_file = log
1688
1689    def set_keep_target_none(self, keep_target_none: List) -> None:
1690        """
1691        Set the keep_target_none parameter of the transformer
1692
1693        Parameters
1694        ----------
1695        keep_target_none : list
1696            a list of bool values
1697        """
1698
1699        self.transformer.keep_target_none = keep_target_none
1700
1701    def set_generate_ssl_input(self, generate_ssl_input: list) -> None:
1702        """
1703        Set the generate_ssl_input parameter of the transformer
1704
1705        Parameters
1706        ----------
1707        generate_ssl_input : list
1708            a list of bool values
1709        """
1710
1711        self.transformer.generate_ssl_input = generate_ssl_input
1712
1713    def set_model(self, model: Model) -> None:
1714        """
1715        Set a new model
1716
1717        Parameters
1718        ----------
1719        model: Model
1720            the new model
1721        """
1722
1723        self.epoch = 0
1724        self.model = model
1725        self.optimizer = self.optimizer_class(model.parameters(), lr=self.lr)
1726        if self.model.process_labels:
1727            self.model.set_behaviors(list(self.behaviors_dict().values()))
1728
1729    def set_dataloaders(
1730        self,
1731        train_dataloader: DataLoader,
1732        val_dataloader: DataLoader = None,
1733        test_dataloader: DataLoader = None,
1734    ) -> None:
1735        """
1736        Set new dataloaders
1737
1738        Parameters
1739        ----------
1740        train_dataloader: torch.utils.data.DataLoader
1741            the new train dataloader
1742        val_dataloader : torch.utils.data.DataLoader
1743            the new validation dataloader
1744        test_dataloader : torch.utils.data.DataLoader
1745            the new test dataloader
1746        """
1747
1748        self.train_dataloader = train_dataloader
1749        self.val_dataloader = val_dataloader
1750        self.test_dataloader = test_dataloader
1751
1752    def set_loss(self, loss: Callable) -> None:
1753        """
1754        Set new loss function
1755
1756        Parameters
1757        ----------
1758        loss: callable
1759            the new loss function
1760        """
1761
1762        self.loss = loss
1763
1764    def set_metrics(self, metrics: dict) -> None:
1765        """
1766        Set new metrics
1767
1768        Parameters
1769        ----------
1770        metrics : dict
1771            the new metric dictionary
1772        """
1773
1774        self.metrics = metrics
1775
1776    def set_transformer(self, transformer: Transformer) -> None:
1777        """
1778        Set a new transformer
1779
1780        Parameters
1781        ----------
1782        transformer: Transformer
1783            the new transformer
1784        """
1785
1786        self.transformer = transformer
1787
1788    def set_predict_functions(
1789        self, primary_predict_function: Callable, predict_function: Callable
1790    ) -> None:
1791        """
1792        Set new predict functions
1793
1794        Parameters
1795        ----------
1796        primary_predict_function : callable
1797            the new primary predict function
1798        predict_function : callable
1799            the new predict function
1800        """
1801
1802        self.primary_predict_function = primary_predict_function
1803        self.predict_function = predict_function
1804
1805    def _set_pseudolabels(self):
1806        """
1807        Set pseudolabels
1808        """
1809
1810        self.train_dataloader.dataset.set_unlabeled(True)
1811        predicted = self.predict(
1812            data=self.dataset("train"),
1813            raw_output=False,
1814            augment_n=self.augment_val,
1816        )
1817        self.train_dataloader.dataset.set_annotation(predicted.detach())
1818
1819    def _alpha(self, epoch):
1820        """
1821        Get the current pseudolabeling alpha parameter

            The schedule is 0 until epoch `T1`, grows linearly towards `alpha_f` between epochs `T1` and `T2`, and stays at `alpha_f` afterwards.
1822        """
1823
1824        if epoch <= self.T1:
1825            return 0
1826        elif epoch < self.T2:
1827            return self.alpha_f * (epoch - self.T1) / (self.T2 - self.T1)
1828        else:
1829            return self.alpha_f
1830
1831    def count_classes(self, bouts: bool = False) -> Dict:
1832        """
1833        Get a dictionary of class counts in different modes
1834
1835        Parameters
1836        ----------
1837        bouts : bool, default False
1838            if `True`, segment (bout) counts are returned instead of frame counts
1839
1840        Returns
1841        -------
1842        class_counts : dict
1843            a dictionary where first-level keys are "train", "val" and "test", second-level keys are
1844            class names and values are class counts (in frames)
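
            Examples
            --------
            An illustrative call (the behavior names in the comment are assumptions):

            >>> task.count_classes()["train"]  # e.g. {"running": 5000, "grooming": 1200}
            >>> task.count_classes(bouts=True)["val"]  # segment counts instead of frame counts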
1845        """
1846
1847        class_counts = {}
1848        for x in ["train", "val", "test"]:
1849            try:
1850                class_counts[x] = self.dataset(x).count_classes(bouts)
1851            except ValueError:
1852                class_counts[x] = {k: 0 for k in self.behaviors_dict().keys()}
1853        return class_counts
1854
1855    def behaviors_dict(self) -> Dict:
1856        """
1857        Get a behavior dictionary
1858
1859        Keys are label indices and values are label names.
1860
1861        Returns
1862        -------
1863        behaviors_dict : dict
1864            behavior dictionary
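
            Examples
            --------
            An illustrative output (the behavior names are assumptions):

            >>> task.behaviors_dict()  # e.g. {0: "running", 1: "grooming"}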
1865        """
1866
1867        return self.dataset().behaviors_dict()
1868
1869    def update_parameters(self, parameters: Dict) -> None:
1870        """
1871        Update training parameters from a dictionary
1872
1873        Parameters
1874        ----------
1875        parameters : dict
1876            the update dictionary
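
            Examples
            --------
            A sketch of a partial update; keys missing from the dictionary keep their current values:

            >>> task.update_parameters({"lr": 1e-4, "num_epochs": 100, "device": "auto"})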
1877        """
1878
1879        self.lr = parameters.get("lr", self.lr)
1880        self.parallel = parameters.get("parallel", self.parallel)
1881        self.optimizer = self.optimizer_class(self.model.parameters(), lr=self.lr)
1882        self.verbose = parameters.get("verbose", self.verbose)
1883        self.device = parameters.get("device", self.device)
1884        if self.device == "auto":
1885            self.device = "cuda" if torch.cuda.is_available() else "cpu"
1886        self.augment_train = int(parameters.get("augment_train", self.augment_train))
1887        self.augment_val = int(parameters.get("augment_val", self.augment_val))
1888        ssl_weights = parameters.get("ssl_weights", self.ssl_weights)
1889        if ssl_weights is None:
1890            ssl_weights = [1 for _ in self.ssl_losses]
1891        if not isinstance(ssl_weights, Iterable):
1892            ssl_weights = [ssl_weights for _ in self.ssl_losses]
1893        self.ssl_weights = ssl_weights
1894        self.num_epochs = parameters.get("num_epochs", self.num_epochs)
1895        self.model_save_epochs = parameters.get(
1896            "model_save_epochs", self.model_save_epochs
1897        )
1898        self.model_save_path = parameters.get("model_save_path", self.model_save_path)
1899        self.pseudolabel = parameters.get("pseudolabel", self.pseudolabel)
1900        self.T1 = int(parameters.get("pseudolabel_start", self.T1))
1901        self.T2 = int(parameters.get("alpha_growth_stop", self.T2))
1902        self.t = int(parameters.get("correction_interval", self.t))
1903        self.alpha_f = parameters.get("pseudolabel_alpha_f", self.alpha_f)
1904        self.log_file = parameters.get("log_file", self.log_file)
1905
1906    def generate_uncertainty_score(
1907        self,
1908        classes: List,
1909        augment_n: int = 0,
1910        batch_size: int = 32,
1911        method: str = "least_confidence",
1912        predicted: torch.Tensor = None,
1913        behaviors_dict: Dict = None,
1914    ) -> Dict:
1915        """
1916        Generate frame-wise scores for active learning
1917
1918        Parameters
1919        ----------
1920        classes : list
1921            a list of class names or indices; their confidence scores will be computed separately and stacked
1922        augment_n : int, default 0
1923            the number of augmentations to average over
1924        batch_size : int, default 32
1925            the batch size
1926        method : {"least_confidence", "entropy", "random"}
1927            the method used to calculate the scores from the probability predictions (`"least_confidence"`: `1 - p_i` if
1928            `p_i > 0.5` or `p_i` if `p_i < 0.5`; `"entropy"`: `- p_i * log(p_i) - (1 - p_i) * log(1 - p_i)`; `"random"`: uniform random scores)
1929
1930        Returns
1931        -------
1932        score_dicts : dict
1933            a nested dictionary where first level keys are video ids, second level keys are clip ids and values
1934            are score tensors
1935        """
1936
1937        dataset = self.dataset("train")
1938        if behaviors_dict is None:
1939            behaviors_dict = self.behaviors_dict()
1940        if not isinstance(dataset, BehaviorDataset):
1941            raise TypeError(
1942                f"The dataset parameter has to be either None, string, "
1943                f"BehaviorDataset or Dataloader, got {type(dataset)}"
1944            )
1945        if predicted is None:
1946            predicted = self.predict(
1947                dataset,
1948                raw_output=True,
1949                apply_primary_function=True,
1950                augment_n=augment_n,
1951                batch_size=batch_size,
1952            )
1953        predicted = dataset.generate_full_length_prediction(predicted)
1954        if isinstance(classes[0], str):
1955            behaviors_dict_inverse = {v: k for k, v in behaviors_dict.items()}
1956            classes = [behaviors_dict_inverse[c] for c in classes]
1957        for v_id, v in predicted.items():
1958            for clip_id, vv in v.items():
1959                if method == "least_confidence":
1960                    predicted[v_id][clip_id][vv > 0.5] = 1 - vv[vv > 0.5]
1961                elif method == "entropy":
1962                    predicted[v_id][clip_id][vv != -100] = (
1963                        -vv * torch.log(vv) - (1 - vv) * torch.log(1 - vv)
1964                    )[vv != -100]
1965                elif method == "random":
1966                    predicted[v_id][clip_id] = torch.rand(vv.shape)
1967                else:
1968                    raise ValueError(
1969                        f"The {method} method is not recognized; please choose from ['least_confidence', 'entropy', 'random']"
1970                    )
1971                predicted[v_id][clip_id][vv == -100] = 0
1972
1973        predicted = {
1974            v_id: {clip_id: v[classes, :] for clip_id, v in video_dict.items()}
1975            for v_id, video_dict in predicted.items()
1976        }
1977        return predicted
1978
1979    def generate_bald_score(
1980        self,
1981        classes: List,
1982        augment_n: int = 0,
1983        batch_size: int = 32,
1984        num_models: int = 10,
1985        kernel_size: int = 11,
1986    ) -> Dict:
1987        """
1988        Generate frame-wise Bayesian Active Learning by Disagreement scores for active learning
1989
1990        Parameters
1991        ----------
1992        classes : list
1993            a list of class names or indices; their confidence scores will be computed separately and stacked
1994        augment_n : int, default 0
1995            the number of augmentations to average over
1996        batch_size : int, default 32
1997            the batch size
1998        num_models : int, default 10
1999            the number of dropout masks to apply
2000        kernel_size : int, default 11
2001            the size of the smoothing gaussian kernel
2002
2003        Returns
2004        -------
2005        score_dicts : dict
2006            a nested dictionary where first level keys are video ids, second level keys are clip ids and values
2007            are score tensors
2008        """
2009
2010        dataset = self.dataset("train")
2011        dataset = self._get_dataset(dataset)
2012        if not isinstance(dataset, BehaviorDataset):
2013            raise TypeError(
2014                f"The dataset parameter has to be either None, string, "
2015                f"BehaviorDataset or Dataloader, got {type(dataset)}"
2016            )
2017        predictions = []
2018        for _ in range(num_models):
2019            predicted = self.predict(
2020                dataset,
2021                raw_output=True,
2022                apply_primary_function=True,
2023                augment_n=augment_n,
2024                batch_size=batch_size,
2025                train_mode=True,
2026            )
2027            predicted = dataset.generate_full_length_prediction(predicted)
2028            if isinstance(classes[0], str):
2029                behaviors_dict_inverse = {
2030                    v: k for k, v in self.behaviors_dict().items()
2031                }
2032                classes = [behaviors_dict_inverse[c] for c in classes]
2033            for v_id, v in predicted.items():
2034                for clip_id, vv in v.items():
2035                    vv[vv != -100] = (vv[vv != -100] > 0.5).int().float()
2036                    predicted[v_id][clip_id] = vv
2037            predicted = {
2038                v_id: {clip_id: v[classes, :] for clip_id, v in video_dict.items()}
2039                for v_id, video_dict in predicted.items()
2040            }
2041            predictions.append(predicted)
2042        result = {v_id: {} for v_id in predictions[0]}
2043        r = range(-int(kernel_size / 2), int(kernel_size / 2) + 1)
2044        gauss = [1 / sqrt(2 * pi) * exp(-float(x) ** 2 / 2) for x in r]  # unit-variance gaussian
2045        gauss = [x / sum(gauss) for x in gauss]  # normalize the kernel to sum to one
2046        kernel = torch.FloatTensor([[gauss]])
2047        for v_id in predictions[0]:
2048            for clip_id in predictions[0][v_id]:
2049                consensus = (
2050                    (
2051                        torch.mean(
2052                            torch.stack([x[v_id][clip_id] for x in predictions]), dim=0
2053                        )
2054                        > 0.5
2055                    )
2056                    .int()
2057                    .float()
2058                )
2059                consensus[predictions[0][v_id][clip_id] == -100] = -100
2060                result[v_id][clip_id] = torch.zeros(consensus.shape)
2061                for x in predictions:
2062                    result[v_id][clip_id] += (x[v_id][clip_id] != consensus).int()
2063                result[v_id][clip_id] = result[v_id][clip_id] * 2 / num_models
2064                res = torch.zeros(result[v_id][clip_id].shape)
2065                for i in range(len(classes)):
2066                    res[
2067                        i, floor(kernel_size // 2) : -floor(kernel_size // 2)
2068                    ] = torch.nn.functional.conv1d(
2069                        result[v_id][clip_id][i, :].unsqueeze(0).unsqueeze(0), kernel
2070                    )[
2071                        0, ...
2072                    ]
2073                result[v_id][clip_id] = res
2074        return result
2075
2076    def get_normalization_stats(self) -> Optional[Dict]:
2077        return self.train_dataloader.dataset.stats  # the training dataset's normalization statistics (if any)
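
For example, the active-learning scores defined above can be generated like this (a minimal usage sketch; `task` is an assumed trained `Task` instance and `"grooming"` an assumed behavior name):

    scores = task.generate_uncertainty_score(
        classes=["grooming"], augment_n=2, method="entropy"
    )
    bald_scores = task.generate_bald_score(classes=["grooming"], num_models=10)
    # both are nested {video_id: {clip_id: score_tensor}} dictionaries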

Task( train_dataloader: torch.utils.data.dataloader.DataLoader, model: Union[torch.nn.modules.module.Module, dlc2action.model.base_model.Model], loss: Callable[[torch.Tensor, torch.Tensor], float], num_epochs: int = 0, transformer: dlc2action.transformer.base_transformer.Transformer = None, ssl_losses: List = None, ssl_weights: List = None, lr: float = 0.001, metrics: Dict = None, val_dataloader: torch.utils.data.dataloader.DataLoader = None, test_dataloader: torch.utils.data.dataloader.DataLoader = None, optimizer: torch.optim.optimizer.Optimizer = None, device: str = 'cuda', verbose: bool = True, log_file: Optional[str] = None, augment_train: int = 1, augment_val: int = 0, validation_interval: int = 1, predict_function: Optional[Callable[[torch.Tensor], torch.Tensor]] = None, primary_predict_function: Callable = None, exclusive: bool = True, ignore_tags: bool = True, threshold: float = 0.5, model_save_path: str = None, model_save_epochs: int = 5, pseudolabel: bool = False, pseudolabel_start: int = 100, correction_interval: int = 2, pseudolabel_alpha_f: float = 3, alpha_growth_stop: int = 600, parallel: bool = False, skip_metrics: List = None)
 75    def __init__(
 76        self,
 77        train_dataloader: DataLoader,
 78        model: Union[nn.Module, Model],
 79        loss: Callable[[torch.Tensor, torch.Tensor], float],
 80        num_epochs: int = 0,
 81        transformer: Transformer = None,
 82        ssl_losses: List = None,
 83        ssl_weights: List = None,
 84        lr: float = 1e-3,
 85        metrics: Dict = None,
 86        val_dataloader: DataLoader = None,
 87        test_dataloader: DataLoader = None,
 88        optimizer: Optimizer = None,
 89        device: str = "cuda",
 90        verbose: bool = True,
 91        log_file: Union[str, None] = None,
 92        augment_train: int = 1,
 93        augment_val: int = 0,
 94        validation_interval: int = 1,
 95        predict_function: Union[Callable[[torch.Tensor], torch.Tensor], None] = None,
 96        primary_predict_function: Callable = None,
 97        exclusive: bool = True,
 98        ignore_tags: bool = True,
 99        threshold: float = 0.5,
100        model_save_path: str = None,
101        model_save_epochs: int = 5,
102        pseudolabel: bool = False,
103        pseudolabel_start: int = 100,
104        correction_interval: int = 2,
105        pseudolabel_alpha_f: float = 3,
106        alpha_growth_stop: int = 600,
107        parallel: bool = False,
108        skip_metrics: List = None,
109    ) -> None:
110        """
111        Parameters
112        ----------
113        train_dataloader : torch.utils.data.DataLoader
114            a training dataloader
115        model : dlc2action.model.base_model.Model
116            a model
117        loss : callable
118            a loss function
119        num_epochs : int, default 0
120            the number of epochs
121        transformer : dlc2action.transformer.base_transformer.Transformer, optional
122            a transformer
123        ssl_losses : list, optional
124            a list of SSL losses
125        ssl_weights : list, optional
126            a list of SSL weights (if not provided, initialized to 1)
127        lr : float, default 1e-3
128            learning rate
129        metrics : dict, optional
130            a dictionary of metric functions
131        val_dataloader : torch.utils.data.DataLoader, optional
132            a validation dataloader
133        optimizer : torch.optim.Optimizer, optional
134            an optimizer (`Adam` by default)
135        device : str, default 'cuda'
136            the device to train the model on
137        verbose : bool, default True
138            if `True`, the process is described in standard output
139        log_file : str, optional
140            the path to a text file where the process will be logged
141        augment_train : {1, 0}
142            number of augmentations to apply at training
143        augment_val : int, default 0
144            number of augmentations to apply at validation
145        validation_interval : int, default 1
146            every time this number of epochs passes, validation metrics are computed
147        predict_function : callable, optional
148            a function that maps probabilities to class predictions (if not provided, a default is generated)
149        primary_predict_function : callable, optional
150            a function that maps model output to probabilities (if not provided, softmax or sigmoid is used, depending on `exclusive`)
151        exclusive : bool, default True
152            set to False for multi-label classification
153        ignore_tags : bool, default True
154            if `True`, samples with different meta tags will be mixed in batches
155        threshold : float, default 0.5
156            the threshold used for multi-label classification default prediction function
157        model_save_path : str, optional
158            the path to the folder where model checkpoints will be saved (checkpoints will not be saved if the path
159            is not provided)
160        model_save_epochs : int, default 5
161            the interval for saving the model checkpoints (the last epoch is always saved)
162        pseudolabel : bool, default False
163            if True, the pseudolabeling procedure will be applied
164        pseudolabel_start : int, default 100
165            pseudolabeling starts after this epoch
166        correction_interval : int, default 2
167            after this number of epochs, if the pseudolabeling is on, the model is trained on the labeled data and
168            new pseudolabels are generated
169        pseudolabel_alpha_f : float, default 3
170            the maximum value of pseudolabeling alpha
171        alpha_growth_stop : int, default 600
172            pseudolabeling alpha stops growing after this epoch
173        """
174
175        # pseudolabeling might be buggy right now -- not using it!
176        if skip_metrics is None:
177            skip_metrics = []
178        self.train_dataloader = train_dataloader
179        self.val_dataloader = val_dataloader
180        self.test_dataloader = test_dataloader
181        self.transformer = transformer
182        self.num_epochs = num_epochs
183        self.skip_metrics = skip_metrics
184        self.verbose = verbose
185        self.augment_train = int(augment_train)
186        self.augment_val = int(augment_val)
187        self.ignore_tags = ignore_tags
188        self.validation_interval = int(validation_interval)
189        self.log_file = log_file
190        self.loss = loss
191        self.model_save_path = model_save_path
192        self.model_save_epochs = model_save_epochs
193        self.epoch = 0
194
195        if metrics is None:
196            metrics = {}
197        self.metrics = metrics
198
199        if optimizer is None:
200            optimizer = Adam
201
202        if ssl_weights is None:
203            ssl_weights = [1 for _ in ssl_losses]
204        if not isinstance(ssl_weights, Iterable):
205            ssl_weights = [ssl_weights for _ in ssl_losses]
206        self.ssl_weights = ssl_weights
207
208        self.optimizer_class = optimizer
209        self.lr = lr
210        if not isinstance(model, Model):
211            self.model = LoadedModel(model=model)
212        else:
213            self.set_model(model)
214        self.parallel = parallel
215
216        if self.transformer is None:
217            self.augment_val = 0
218            self.augment_train = 0
219            self.transformer = EmptyTransformer()
220
221        if self.augment_train > 1:
222            warnings.warn(
223                'The "augment_train" parameter is too large -> setting it to 1.'
224            )
225            self.augment_train = 1
226
227        try:
228            if device == "auto":
229                device = "cuda" if torch.cuda.is_available() else "cpu"
230            self.device = torch.device(device)
231        except Exception:
232            raise ValueError("The format of the device is incorrect")
233
234        if ssl_losses is None:
235            self.ssl_losses = [lambda x, y: 0]
236        else:
237            self.ssl_losses = ssl_losses
238
239        if primary_predict_function is None:
240            if exclusive:
241                primary_predict_function = lambda x: torch.softmax(x, dim=1)
242            else:
243                primary_predict_function = lambda x: torch.sigmoid(x)
244        self.primary_predict_function = primary_predict_function
245
246        if predict_function is None:
247            if exclusive:
248                self.predict_function = lambda x: torch.max(x.data, 1)[1]
249            else:
250                self.predict_function = lambda x: (x > threshold).int()
251        else:
252            self.predict_function = predict_function
253
254        self.pseudolabel = pseudolabel
255        self.alpha_f = pseudolabel_alpha_f
256        self.T2 = alpha_growth_stop
257        self.T1 = pseudolabel_start
258        self.t = correction_interval
259        if self.T2 <= self.T1:
260            raise ValueError(
261                f"The pseudolabel_start parameter has to be smaller than alpha_growth_stop; got "
262                f"{pseudolabel_start=} and {alpha_growth_stop=}"
263            )
264        self.decision_thresholds = [0.5 for x in self.behaviors_dict()]

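A minimal construction sketch (the dataloaders, model and loss objects below are assumed to exist; every keyword mirrors a parameter documented above):

    task = Task(
        train_dataloader=train_dl,         # a DataLoader over a BehaviorDataset
        val_dataloader=val_dl,
        model=my_model,                    # a dlc2action Model instance
        loss=torch.nn.CrossEntropyLoss(),
        num_epochs=100,
        lr=1e-3,
        exclusive=True,                    # one behavior label per frame
        device="auto",                     # resolves to "cuda" when available
        model_save_path="checkpoints",
    )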

def save_checkpoint(self, checkpoint_path: str) -> None:
266    def save_checkpoint(self, checkpoint_path: str) -> None:
267        """
268        Save a general checkpoint
269
270        Parameters
271        ----------
272        checkpoint_path : str
273            the path where the checkpoint will be saved
274        """
275
276        torch.save(
277            {
278                "epoch": self.epoch,
279                "model_state_dict": self.model.state_dict(),
280                "optimizer_state_dict": self.optimizer.state_dict(),
281            },
282            checkpoint_path,
283        )

def load_from_checkpoint( self, checkpoint_path, only_model: bool = False, load_strict: bool = True) -> None:
285    def load_from_checkpoint(
286        self, checkpoint_path, only_model: bool = False, load_strict: bool = True
287    ) -> None:
288        """
289        Load from a checkpoint
290
291        Parameters
292        ----------
293        checkpoint_path : str
294            the path to the checkpoint
295        only_model : bool, default False
296            if `True`, only the model state dictionary will be loaded (and not the epoch and the optimizer state
297            dictionary)
298        load_strict : bool, default True
299            if `True`, any inconsistencies in state dictionaries are regarded as errors
300        """
301
302        if checkpoint_path is None:
303            return
304        checkpoint = torch.load(checkpoint_path, map_location=self.device)
305        self.model.to(self.device)
306        self.model.load_state_dict(checkpoint["model_state_dict"], strict=load_strict)
307        if not only_model:
308            self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
309            self.epoch = checkpoint["epoch"]

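A round-trip sketch (the checkpoint path is illustrative):

    task.save_checkpoint("checkpoints/epoch050.pt")                        # epoch + model + optimizer
    task.load_from_checkpoint("checkpoints/epoch050.pt")                   # restore the full training state
    task.load_from_checkpoint("checkpoints/epoch050.pt", only_model=True)  # restore the weights only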

def save_model(self, save_path: str) -> None:
311    def save_model(self, save_path: str) -> None:
312        """
313        Save the model state dictionary
314
315        Parameters
316        ----------
317        save_path : str
318            the path where the state will be saved
319        """
320
321        torch.save(self.model.state_dict(), save_path)
322        print("saved the model successfully")

def train( self, trial: optuna.trial._trial.Trial = None, optimized_metric: str = None, to_ram: bool = False, autostop_interval: int = 30, autostop_threshold: float = 0.001, autostop_metric: str = None, main_task_on: bool = True, ssl_on: bool = True, temporal_subsampling_size: int = None, loading_bar: bool = False) -> Tuple:
603    def train(
604        self,
605        trial: Trial = None,
606        optimized_metric: str = None,
607        to_ram: bool = False,
608        autostop_interval: int = 30,
609        autostop_threshold: float = 0.001,
610        autostop_metric: str = None,
611        main_task_on: bool = True,
612        ssl_on: bool = True,
613        temporal_subsampling_size: int = None,
614        loading_bar: bool = False,
615    ) -> Tuple:
616        """
617        Train the task and return a log of epoch-average loss and metric
618
619        You can use the autostop parameters to finish training when the parameters are not improving. It will be
620        stopped if the average value of `autostop_metric` over the last `autostop_interval` epochs is smaller than
621        the average over the previous `autostop_interval` epochs + `autostop_threshold`. For example, if the
622        current epoch is 120 and `autostop_interval` is 50, the averages over epochs 70-120 and 20-70 will be compared.
623
624        Parameters
625        ----------
626        trial : Trial
627            an `optuna` trial (for hyperparameter searches)
628        optimized_metric : str
629            the name of the metric being optimized (for hyperparameter searches)
630        to_ram : bool, default False
631            if `True`, the dataset will be loaded in RAM (this speeds up the calculations but can lead to crashes
632            if the dataset is too large)
633        autostop_interval : int, default 30
634            the number of epochs to average the autostop metric over
635        autostop_threshold : float, default 0.001
636            the autostop difference threshold
637        autostop_metric : str, optional
638            the autostop metric (can be any one of the tracked metrics or `'loss'`)
639        main_task_on : bool, default True
640            if `False`, the main task (action segmentation) will not be used in training
641        ssl_on : bool, default True
642            if `False`, the SSL task will not be used in training
643
644        Returns
645        -------
646        loss_log: dict
647            a dictionary of loss value logs (keys are `'train'` and `'val'`, values are lists of epoch-average losses)
648        metrics_log: dict
649            a dictionary of metric value logs (first-level keys are 'train' and 'val', second-level keys are metric
650            names, values are lists of function values)
651        """
652
653        if self.parallel and not isinstance(self.model, nn.DataParallel):
654            self.model = MyDataParallel(self.model)
655        self.model.to(self.device)
656        assert autostop_metric in [None, "loss"] + list(self.metrics)
657        autostop_interval //= self.validation_interval
658        if trial is not None and optimized_metric is None:
659            raise ValueError(
660                "You need to provide the optimized metric name (optimized_metric parameter) "
661                "for optuna pruning to work!"
662            )
663        if to_ram:
664            print("transferring datasets to RAM...")
665            self.train_dataloader.dataset.to_ram()
666            if self.val_dataloader is not None and len(self.val_dataloader) > 0:
667                self.val_dataloader.dataset.to_ram()
668        loss_log = {"train": [], "val": []}
669        metrics_log = {"train": defaultdict(lambda: []), "val": defaultdict(lambda: [])}
670        if not main_task_on:
671            self.model.main_task_off()
672        if not ssl_on:
673            self.model.ssl_off()
674        while self.epoch < self.num_epochs:
675            self.epoch += 1
676            unlabeled = None
677            alpha = 1
678            if self.pseudolabel:
679                if self.epoch >= self.T1:
680                    unlabeled = (self.epoch - self.T1) % self.t != 0
681                    if unlabeled:
682                        alpha = self._alpha(self.epoch)
683                else:
684                    unlabeled = False
685            epoch_loss, epoch_ssl_loss, epoch_metrics = self._run_epoch(
686                dataloader=self.train_dataloader,
687                mode="train",
688                augment_n=self.augment_train,
689                unlabeled=unlabeled,
690                alpha=alpha,
691                temporal_subsampling_size=temporal_subsampling_size,
692                verbose=loading_bar,
693            )
694            loss_log["train"].append(epoch_loss)
695            epoch_string = f"[epoch {self.epoch}]"
696            if self.pseudolabel:
697                if unlabeled:
698                    epoch_string += " (unlabeled)"
699                else:
700                    epoch_string += " (labeled)"
701            epoch_string += f": loss {epoch_loss:.4f}"
702
703            if len(epoch_ssl_loss) != 0:
704                for key, value in sorted(epoch_ssl_loss.items()):
705                    metrics_log["train"][f"ssl_loss_{key}"].append(value)
706                    epoch_string += f", ssl_loss_{key} {value:.4f}"
707
708            for metric_name, metric_value in sorted(epoch_metrics.items()):
709                if metric_name not in self.skip_metrics:
710                    metrics_log["train"][metric_name].append(metric_value)
711                    epoch_string += f", {metric_name} {metric_value:.3f}"
712
713            if (
714                self.val_dataloader is not None
715                and self.epoch % self.validation_interval == 0
716            ):
717                with torch.no_grad():
718                    epoch_string += "\n"
719                    (
720                        val_epoch_loss,
721                        val_epoch_ssl_loss,
722                        val_epoch_metrics,
723                    ) = self._run_epoch(
724                        dataloader=self.val_dataloader,
725                        mode="val",
726                        augment_n=self.augment_val,
727                    )
728                    loss_log["val"].append(val_epoch_loss)
729                    epoch_string += f"validation: loss {val_epoch_loss:.4f}"
730
731                    if len(val_epoch_ssl_loss) != 0:
732                        for key, value in sorted(val_epoch_ssl_loss.items()):
733                            metrics_log["val"][f"ssl_loss_{key}"].append(value)
734                            epoch_string += f", ssl_loss_{key} {value:.4f}"
735
736                    for metric_name, metric_value in sorted(val_epoch_metrics.items()):
737                        metrics_log["val"][metric_name].append(metric_value)
738                        epoch_string += f", {metric_name} {metric_value:.3f}"
739
740                if trial is not None:
741                    if optimized_metric not in metrics_log["val"]:
742                        raise ValueError(
743                            f"The {optimized_metric} metric set for optimization is not being logged!"
744                        )
745                    trial.report(metrics_log["val"][optimized_metric][-1], self.epoch)
746                    if trial.should_prune():
747                        raise TrialPruned()
748
749            if self.verbose:
750                print(epoch_string)
751
752            if self.log_file is not None:
753                with open(self.log_file, "a") as f:
754                    f.write(epoch_string + "\n")
755
756            save_condition = (
757                (self.model_save_epochs != 0)
758                and (self.epoch % self.model_save_epochs == 0)
759            ) or (self.epoch == self.num_epochs)
760
761            if self.epoch > 0 and save_condition and self.model_save_path is not None:
762                epoch_s = str(self.epoch).zfill(len(str(self.num_epochs)))
763                self.save_checkpoint(
764                    os.path.join(self.model_save_path, f"epoch{epoch_s}.pt")
765                )
766
767            if self.pseudolabel and self.epoch >= self.T1 and not unlabeled:
768                self._set_pseudolabels()
769
770            if autostop_metric == "loss":
771                if len(loss_log["val"]) > autostop_interval * 2:
772                    if (
773                        np.mean(loss_log["val"][-autostop_interval:])
774                        < np.mean(
775                            loss_log["val"][-2 * autostop_interval : -autostop_interval]
776                        )
777                        + autostop_threshold
778                    ):
779                        break
780            elif autostop_metric in metrics_log["val"]:
781                if len(metrics_log["val"][autostop_metric]) > autostop_interval * 2:
782                    if (
783                        np.mean(
784                            metrics_log["val"][autostop_metric][-autostop_interval:]
785                        )
786                        < np.mean(
787                            metrics_log["val"][autostop_metric][
788                                -2 * autostop_interval : -autostop_interval
789                            ]
790                        )
791                        + autostop_threshold
792                    ):
793                        break
794
795        metrics_log = {k: dict(v) for k, v in metrics_log.items()}
796
797        return loss_log, metrics_log

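A usage sketch with early stopping on the validation loss (`"f1"` is an assumed metric name, used for illustration only):

    loss_log, metrics_log = task.train(
        autostop_metric="loss",
        autostop_interval=30,
        autostop_threshold=1e-3,
    )
    train_losses = loss_log["train"]            # one value per epoch
    val_f1 = metrics_log["val"].get("f1", [])   # one value per validation_interval epochs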

def evaluate_prediction( self, prediction: Union[torch.Tensor, Dict], data: Union[torch.utils.data.dataloader.DataLoader, dlc2action.data.dataset.BehaviorDataset, str] = None, batch_size: int = 32) -> Tuple:
799    def evaluate_prediction(
800        self,
801        prediction: Union[torch.Tensor, Dict],
802        data: Union[DataLoader, BehaviorDataset, str] = None,
803        batch_size: int = 32,
804    ) -> Tuple:
805        """
806        Compute metrics for a prediction
807
808        Parameters
809        ----------
810        prediction : torch.Tensor | dict
811            the prediction (a tensor of predicted probabilities or a full-length prediction dictionary)
812        data : torch.utils.data.DataLoader | dlc2action.data.dataset.BehaviorDataset, optional
813            the data the prediction was made for (if not provided, take the validation dataset)
814        batch_size : int, default 32
815            the batch size
816
817        Returns
818        -------
819        loss : float
820            the average value of the loss function (currently always 0, since the loss is not computed for precomputed predictions)
821        metric : dict
822            a dictionary of average values of metric functions
823        """
824
825        if type(data) is not DataLoader:
826            data = DataLoader(self._get_dataset(data), shuffle=False, batch_size=batch_size)
827        dataset = data.dataset
828        epoch_loss = 0
829        if isinstance(prediction, dict):
830            num_classes = len(self.behaviors_dict())
831            length = dataset.len_segment()
832            coords = dataset.annotation_store.get_original_coordinates()
833            for batch in data:
834                main_target = batch["target"]
835                pr_coords = coords[batch["index"]]
836                predicted = torch.zeros((len(pr_coords), num_classes, length))
837                for i, c in enumerate(pr_coords):
838                    video_id = dataset.input_store.get_video_id(c)
839                    clip_id = dataset.input_store.get_clip_id(c)
840                    start, end = dataset.input_store.get_clip_start_end(c)
841                    predicted[i, :, : end - start] = prediction[video_id][clip_id][
842                        :, start:end
843                    ]
844                self._compute(
845                    [],
846                    [],
847                    predicted,
848                    main_target,
849                    skip_loss=True,
850                    tag=batch.get("tag"),
851                    apply_primary_function=False,
852                )
853        else:
854            for batch in data:
855                main_target = batch["target"]
856                predicted = prediction[batch["index"]]
857                self._compute(
858                    [],
859                    [],
860                    predicted,
861                    main_target,
862                    skip_loss=True,
863                    tag=batch.get("tag"),
864                    apply_primary_function=False,
865                )
866        epoch_metrics = self._calculate_metrics()
867        strings = [
868            f"{metric_name} {metric_value:.3f}"
869            for metric_name, metric_value in epoch_metrics.items()
870        ]
871        val_string = ", ".join(sorted(strings))
872        print(val_string)
873        return epoch_loss, epoch_metrics

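A sketch of scoring a precomputed prediction against the validation data, assuming the prediction dictionary comes from `generate_full_length_prediction`:

    prediction = task.generate_full_length_prediction(dataset="val")
    _, metrics = task.evaluate_prediction(prediction, data="val")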

def evaluate( self, data: Union[torch.utils.data.dataloader.DataLoader, dlc2action.data.dataset.BehaviorDataset, str] = None, augment_n: int = 0, batch_size: int = 32, verbose: bool = True) -> Tuple:
875    def evaluate(
876        self,
877        data: Union[DataLoader, BehaviorDataset, str] = None,
878        augment_n: int = 0,
879        batch_size: int = 32,
880        verbose: bool = True,
881    ) -> Tuple:
882        """
883        Evaluate the Task model
884
885        Parameters
886        ----------
887        data : torch.utils.data.DataLoader | dlc2action.data.dataset.BehaviorDataset, optional
888            the data to evaluate on (if not provided, evaluate on the Task validation dataset)
889        augment_n : int, default 0
890            the number of augmentations to average results over
891        batch_size : int, default 32
892            the batch size
893        verbose : bool, default True
894            if True, the process is reported to standard output
895
896        Returns
897        -------
898        loss : float
899            the average value of the loss function
900        ssl_loss : float
901            the average value of the SSL loss function
902        metric : dict
903            a dictionary of average values of metric functions
904        """
905
906        if self.parallel and not isinstance(self.model, nn.DataParallel):
907            self.model = MyDataParallel(self.model)
908        self.model.to(self.device)
909        if type(data) is not DataLoader:
910            data = self._get_dataset(data)
911            data = DataLoader(data, shuffle=False, batch_size=batch_size)
912        with torch.no_grad():
913            epoch_loss, epoch_ssl_loss, epoch_metrics = self._run_epoch(
914                dataloader=data, mode="val", augment_n=augment_n, verbose=verbose
915            )
916        val_string = f"loss {epoch_loss:.4f}"
917        for metric_name, metric_value in sorted(epoch_metrics.items()):
918            val_string += f", {metric_name} {metric_value:.3f}"
919        print(val_string)
920        return epoch_loss, epoch_ssl_loss, epoch_metrics

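A minimal sketch, assuming the `"val"` string is resolved to the Task validation dataset as in `dataset()`:

    loss, ssl_loss, metrics = task.evaluate(data="val", augment_n=2)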

def predict( self, data: Union[torch.utils.data.dataloader.DataLoader, dlc2action.data.dataset.BehaviorDataset, str] = None, raw_output: bool = False, apply_primary_function: bool = True, augment_n: int = 0, batch_size: int = 32, train_mode: bool = False, to_ram: bool = False, embedding: bool = False) -> torch.Tensor:
922    def predict(
923        self,
924        data: Union[DataLoader, BehaviorDataset, str] = None,
925        raw_output: bool = False,
926        apply_primary_function: bool = True,
927        augment_n: int = 0,
928        batch_size: int = 32,
929        train_mode: bool = False,
930        to_ram: bool = False,
931        embedding: bool = False,
932    ) -> torch.Tensor:
933        """
934        Make a prediction with the Task model
935
936        Parameters
937        ----------
938        data : torch.utils.data.DataLoader | dlc2action.data.dataset.BehaviorDataset | str, optional
939            the data to evaluate on (if not provided, evaluate on the Task validation dataset)
940        raw_output : bool, default False
941            if `True`, the raw predicted probabilities are returned
942        apply_primary_function : bool, default True
943            if `True`, the primary predict function is applied (to map the model output into a shape corresponding to
944            the input)
945        augment_n : int, default 0
946            the number of augmentations to average results over
947        batch_size : int, default 32
948            the batch size
949        train_mode : bool, default False
950            if `True`, the model is used in training mode (affects dropout and batch normalization layers)
951        to_ram : bool, default False
952            if `True`, the dataset will be loaded in RAM (this speeds up the calculations but can lead to crashes
953            if the dataset is too large)
954        embedding : bool, default False
955            if `True`, the output of feature extractor is returned, ignoring the prediction module of the model
956
957        Returns
958        -------
959        prediction : torch.Tensor
960            a prediction for the input data
961        """
962
963        if self.parallel and not isinstance(self.model, nn.DataParallel):
964            self.model = MyDataParallel(self.model)
965        self.model.to(self.device)
966        if train_mode:
967            self.model.train()
968        else:
969            self.model.eval()
970        output = []
971        if embedding:
972            raw_output = True
973            apply_primary_function = True
974        if type(data) is not DataLoader:
975            data = self._get_dataset(data)
976            if to_ram:
977                print("transferring dataset to RAM...")
978                data.to_ram()
979            data = DataLoader(data, shuffle=False, batch_size=batch_size)
980        self.model.ssl_off()
981        with torch.no_grad():
982            for batch in tqdm(data):
983                input = {k: v.to(self.device) for k, v in batch["input"].items()}
984                predicted, _, _ = self._get_prediction(
985                    input,
986                    batch.get("tag"),
987                    augment_n=augment_n,
988                    embedding=embedding,
989                )
990                if apply_primary_function:
991                    predicted = self.primary_predict_function(predicted)
992                if not raw_output:
993                    predicted = self.predict_function(predicted)
994                output.append(predicted.detach().cpu())
995        self.model.ssl_on()
996        output = torch.cat(output).detach()
997        return output

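A usage sketch of the three main output modes (the `"val"` string and the `augment_n` value are illustrative):

    probabilities = task.predict(data="val", raw_output=True, augment_n=5)
    class_predictions = task.predict(data="val")           # thresholded / argmax classes
    embeddings = task.predict(data="val", embedding=True)  # feature extractor output only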

def dataset(self, mode='train') -> dlc2action.data.dataset.BehaviorDataset:
 999    def dataset(self, mode="train") -> BehaviorDataset:
1000        """
1001        Get a dataset
1002
1003        Parameters
1004        ----------
1005        mode : {'train', 'val', 'test'}
1006            the dataset to get
1007
1008        Returns
1009        -------
1010        dataset : dlc2action.data.dataset.BehaviorDataset
1011            the dataset
1012        """
1013
1014        dataloader = self.dataloader(mode)
1015        if dataloader is None:
1016            raise ValueError(f"The {mode} dataloader does not exist!")
1017        return dataloader.dataset

def dataloader(self, mode: str = 'train') -> torch.utils.data.dataloader.DataLoader:
1019    def dataloader(self, mode: str = "train") -> DataLoader:
1020        """
1021        Get a dataloader
1022
1023        Parameters
1024        ----------
1025        mode : {'train', 'val', 'test'}
1026            the dataset to get
1027
1028        Returns
1029        -------
1030        dataloader : torch.utils.data.DataLoader
1031            the dataloader
1032        """
1033
1034        if mode == "train":
1035            return self.train_dataloader
1036        elif mode == "val":
1037            return self.val_dataloader
1038        elif mode == "test":
1039            return self.test_dataloader
1040        else:
1041            raise ValueError(
1042                f'The {mode} mode is not recognized, please choose from "train", "val" or "test"'
1043            )

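For instance:

    train_dataset = task.dataset("train")   # the underlying BehaviorDataset
    val_loader = task.dataloader("val")     # the corresponding DataLoader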

def generate_full_length_prediction(self, dataset=None, batch_size=32, augment_n=10)
1079    def generate_full_length_prediction(
1080        self, dataset=None, batch_size=32, augment_n=10
1081    ):
1082        """
1083        Compile a prediction for the original input sequences
1084
1085        Parameters
1086        ----------
1087        dataset : BehaviorDataset, optional
1088            the dataset to generate a prediction for (if `None`, generate for the validation dataset)
1089        batch_size : int, default 32
1090            the batch size
1091        augment_n : int, default 10
1092            the number of augmentations to average results over
1093
1094        Returns
1095        -------
1096        prediction : dict
1097            a nested dictionary where first level keys are video ids, second level keys are clip ids and values
1098            are prediction tensors
1099        """
1100
1101        dataset = self._get_dataset(dataset)
1102        if not isinstance(dataset, BehaviorDataset):
1103            raise TypeError(
1104                f"The dataset parameter has to be either None, string, "
1105                f"BehaviorDataset or Dataloader, got {type(dataset)}"
1106            )
1107        predicted = self.predict(
1108            dataset,
1109            raw_output=True,
1110            apply_primary_function=True,
1111            augment_n=augment_n,
1112            batch_size=batch_size,
1113        )
1114        predicted = dataset.generate_full_length_prediction(predicted)
1115        predicted = {
1116            v_id: {
1117                clip_id: self._apply_predict_functions(v.unsqueeze(0)).squeeze()
1118                for clip_id, v in video_dict.items()
1119            }
1120            for v_id, video_dict in predicted.items()
1121        }
1122        return predicted

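A sketch of iterating over the result (the `(n_behaviors, n_frames)` tensor shape is an assumption based on how predictions are indexed elsewhere in this module):

    prediction = task.generate_full_length_prediction(dataset="val", augment_n=10)
    for video_id, clips in prediction.items():
        for clip_id, scores in clips.items():
            print(video_id, clip_id, scores.shape)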

def generate_submission( self, frame_number_map_file, dataset=None, batch_size=32, augment_n=10)
1124    def generate_submission(
1125        self, frame_number_map_file, dataset=None, batch_size=32, augment_n=10
1126    ):
1127        """
1128        Generate a MABe-22 style submission dictionary
1129
1130        Parameters
1131        ----------
1132        dataset : BehaviorDataset, optional
1133            the dataset to generate a prediction for (if `None`, generate for the validation dataset)
1134        batch_size : int, default 32
1135            the batch size
1136        augment_n : int, default 10
1137            the number of augmentations to average results over
1138
1139        Returns
1140        -------
1141        submission : dict
1142            a dictionary with frame number mapping and embeddings
1143        """
1144
1145        dataset = self._get_dataset(dataset)
1146        if not isinstance(dataset, BehaviorDataset):
1147            raise TypeError(
1148                f"The dataset parameter has to be either None, string, "
1149                f"BehaviorDataset or Dataloader, got {type(dataset)}"
1150            )
1151        predicted = self.predict(
1152            dataset,
1153            raw_output=True,
1154            apply_primary_function=True,
1155            augment_n=augment_n,
1156            batch_size=batch_size,
1157            embedding=True,
1158        )
1159        predicted = dataset.generate_full_length_prediction(predicted)
1160        frame_map = np.load(frame_number_map_file, allow_pickle=True).item()
1161        length = frame_map[list(frame_map.keys())[-1]][1]
1162        embeddings = None
1163        for video_id in list(predicted.keys()):
1164            split = video_id.split("--")
1165            if len(split) != 2 or len(predicted[video_id]) > 1:
1166                raise RuntimeError(
1167                    "Generating submissions is only implemented for the mabe22 dataset!"
1168                )
1169            if split[1] not in frame_map:
1170                raise RuntimeError(f"The {split[1]} video is not in the frame map file")
1171            v_id = split[1]
1172            clip_id = list(predicted[video_id].keys())[0]
1173            if embeddings is None:
1174                embeddings = np.zeros((length, predicted[video_id][clip_id].shape[0]))
1175            start, end = frame_map[v_id]
1176            embeddings[start:end, :] = predicted[video_id][clip_id].T
1177            predicted.pop(video_id)
1178        predicted = {
1179            "frame_number_map": frame_map,
1180            "embeddings": embeddings.astype(np.float32),
1181        }
1182        return predicted

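A sketch for a MABe-22 style workflow (the file names are illustrative):

    import numpy as np

    submission = task.generate_submission(
        frame_number_map_file="frame_number_map.npy",
        dataset="test",
        augment_n=10,
    )
    np.save("submission.npy", submission, allow_pickle=True)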

def visualize_results( self, save_path: str = None, add_legend: bool = True, ground_truth: bool = True, colormap: str = 'viridis', hide_axes: bool = False, min_classes: int = 1, width: int = 10, whole_video: bool = False, transparent: bool = False, dataset: Union[dlc2action.data.dataset.BehaviorDataset, torch.utils.data.dataloader.DataLoader, str, NoneType] = None, drop_classes: Set = None, search_classes: Set = None, num_samples: int = 1, smooth_interval_prediction: int = None, behavior_name: str = None)
1439    def visualize_results(
1440        self,
1441        save_path: str = None,
1442        add_legend: bool = True,
1443        ground_truth: bool = True,
1444        colormap: str = "viridis",
1445        hide_axes: bool = False,
1446        min_classes: int = 1,
1447        width: int = 10,
1448        whole_video: bool = False,
1449        transparent: bool = False,
1450        dataset: Union[BehaviorDataset, DataLoader, str, None] = None,
1451        drop_classes: Set = None,
1452        search_classes: Set = None,
1453        num_samples: int = 1,
1454        smooth_interval_prediction: int = None,
1455        behavior_name: str = None,
1456    ):
1457        """
1458        Visualize random predictions
1459
1460        Parameters
1461        ----------
1462        save_path : str, optional
1463            the path where the plot will be saved
1464        add_legend : bool, default True
1465            if `True`, legend will be added to the plot
1466        ground_truth : bool, default True
1467            if `True`, ground truth will be added to the plot
1468        colormap : str, default 'viridis'
1469            the `matplotlib` colormap to use
1470        hide_axes : bool, default False
1471            if `True`, the axes will be hidden on the plot
1472        min_classes : int, default 1
1473            the minimum number of classes in a displayed interval
1474        width : int, default 10
1475            the width of the plot
1476        whole_video : bool, default False
1477            if `True`, whole videos are plotted instead of segments
1478        transparent : bool, default False
1479            if `True`, the background on the plot is transparent
1480        dataset : BehaviorDataset | DataLoader | str | None, optional
1481            the dataset to make the prediction for (if not provided, the validation dataset is used)
1482        drop_classes : set, optional
1483            a set of class names to not be displayed
1484        search_classes : set, optional
1485            if given, only intervals where at least one of the classes is in ground truth will be shown
1486        """
1487
1488        if drop_classes is None:
1489            drop_classes = []
1490        dataset = self._get_dataset(dataset)
1491        if dataset.annotation_class() == "exclusive_classification":
1492            exclusive = True
1493        elif dataset.annotation_class() == "nonexclusive_classification":
1494            exclusive = False
1495        else:
1496            raise NotImplementedError(
1497                f"Results visualization is not implemented for {dataset.annotation_class()} datasets!"
1498            )
1499        if exclusive:
1500            labels = {
1501                k: v
1502                for k, v in dataset.behaviors_dict().items()
1503                if v not in drop_classes
1504            }
1505            labels.update({-100: "unknown"})
1506        else:
1507            inverse_dict = {v: k for k, v in dataset.behaviors_dict().items()}
1508            if behavior_name is None:
1509                behavior_name = list(inverse_dict.keys())[0]
1510            label_ind = inverse_dict[behavior_name]
1511            labels = {1: behavior_name, -100: "unknown"}
1512        label_keys = sorted([int(x) for x in labels.keys()])
1513        if search_classes is None:
1514            ok = True
1515        else:
1516            ok = False
1517        classes = []
1518        if whole_video:
1519            predicted = self.generate_full_length_prediction(dataset)
1520            keys = list(predicted.keys())
1521        counter = 0
1522        if whole_video:
1523            max_iter = len(keys) * 2
1524        else:
1525            max_iter = len(dataset) * 2
1526        while len(classes) < min_classes or not ok:
1527            counter += 1
1528            if counter > max_iter:
1529                raise RuntimeError(
1530                    "Plotting is taking too many iterations; you should probably make some of the parameters less restrictive"
1531                )
1532            if whole_video:
1533                i = randint(0, len(keys) - 1)
1534                prediction = predicted[keys[i]]
1535                keys_i = list(prediction.keys())
1536                j = randint(0, len(keys_i) - 1)
1537                prediction = prediction[keys_i[j]]
1538                key1 = keys[i]
1539                key2 = keys_i[j]
1540            else:
1541                dataloader = DataLoader(dataset)
1542                i = randint(0, len(dataloader) - 1)
1543                for num, batch in enumerate(dataloader):
1544                    if num == i:
1545                        break
1546                input = {k: v.to(self.device) for k, v in batch["input"].items()}
1547                prediction, *_ = self._get_prediction(
1548                    input, batch.get("tag"), augment_n=5
1549                )
1550                prediction = self._apply_predict_functions(prediction)
1551                j = randint(0, len(prediction) - 1)
1552                prediction = prediction[j]
1553            if not exclusive:
1554                prediction = prediction[label_ind]
1555            if smooth_interval_prediction > 0:
1556                unsmoothed_prediction = deepcopy(prediction)
1557                prediction = self._smooth(prediction, smooth_interval_prediction)
1558                height = 3
1559            else:
1560                height = 2
1561            classes = [
1562                labels[int(x)] for x in torch.unique(prediction) if x in label_keys
1563            ]
1564            if search_classes is not None:
1565                ok = any([x in classes for x in search_classes])
1566        fig, ax = plt.subplots(figsize=(width, height))
1567        color_list = cm.get_cmap(colormap, lut=len(labels)).colors
1568
1569        def _plot_prediction(prediction, height, set_labels=True):
1570            for c in label_keys:
1571                c_i = label_keys.index(int(c))
1572                output, indices, counts = torch.unique_consecutive(
1573                    prediction == c, return_inverse=True, return_counts=True
1574                )
1575                long_indices = torch.where(output)[0]
1576                if len(long_indices) == 0:
1577                    continue
1578                res_indices_start = [
1579                    (indices == i).nonzero(as_tuple=True)[0][0].item()
1580                    for i in long_indices
1581                ]
1582                res_indices_end = [
1583                    (indices == i).nonzero(as_tuple=True)[0][-1].item() + 1
1584                    for i in long_indices
1585                ]
1586                res_indices_len = [
1587                    end - start
1588                    for start, end in zip(res_indices_start, res_indices_end)
1589                ]
1590                if set_labels:
1591                    label = labels[int(c)]
1592                else:
1593                    label = None
1594                ax.broken_barh(
1595                    list(zip(res_indices_start, res_indices_len)),
1596                    (height, 1),
1597                    label=label,
1598                    facecolors=color_list[c_i],
1599                )
1600
1601        if smooth_interval_prediction > 0:
1602            _plot_prediction(unsmoothed_prediction, 0)
1603            _plot_prediction(prediction, 1.5, set_labels=False)
1604            gt_height = 3
1605        else:
1606            _plot_prediction(prediction, 0)
1607            gt_height = 1.5
1608        if ground_truth:
1609            if not whole_video:
1610                gt = batch["target"][j].to(self.device)
1611            else:
1612                gt = dataset.generate_full_length_gt()[key1][key2]
1613            for c in label_keys:
1614                c_i = label_keys.index(int(c))
1615                if labels[int(c)] in classes:
1616                    label = None
1617                else:
1618                    label = labels[int(c)]
1619                output, indices, counts = torch.unique_consecutive(
1620                    gt == c, return_inverse=True, return_counts=True
1621                )
1622                long_indices = torch.where(output)[0]
1623                if len(long_indices) == 0:
1624                    continue
1625                res_indices_start = [
1626                    (indices == i).nonzero(as_tuple=True)[0][0].item()
1627                    for i in long_indices
1628                ]
1629                res_indices_end = [
1630                    (indices == i).nonzero(as_tuple=True)[0][-1].item() + 1
1631                    for i in long_indices
1632                ]
1633                res_indices_len = [
1634                    end - start
1635                    for start, end in zip(res_indices_start, res_indices_end)
1636                ]
1637                ax.broken_barh(
1638                    list(zip(res_indices_start, res_indices_len)),
1639                    (gt_height, 1),
1640                    facecolors=color_list[c_i] if labels[int(c)] != "unknown" else "gray",
1641                    label=label,
1642                )
1643        if not ground_truth:
1644            ax.axes.yaxis.set_visible(False)
1645        else:
1646            if smooth_interval_prediction > 0:
1647                ax.set_yticks([0.5, 2, 3.5])
1648                ax.set_yticklabels(["prediction", "smoothed", "ground truth"])
1649            else:
1650                ax.set_yticks([0.5, 2])
1651                ax.set_yticklabels(["prediction", "ground truth"])
1652        if add_legend:
1653            ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))
1654        if hide_axes:
1655            plt.axis("off")
1656        if save_path is not None:
1657            plt.savefig(save_path, transparent=transparent)
1658        plt.show()

Visualize random predictions

Parameters

save_path : str, optional
    the path where the plot will be saved
add_legend : bool, default True
    if True, legend will be added to the plot
ground_truth : bool, default True
    if True, ground truth will be added to the plot
colormap : str, default 'Accent'
    the matplotlib colormap to use
hide_axes : bool, default True
    if True, the axes will be hidden on the plot
min_classes : int, default 1
    the minimum number of classes in a displayed interval
width : float, default 10
    the width of the plot
whole_video : bool, default False
    if True, whole videos are plotted instead of segments
transparent : bool, default False
    if True, the background on the plot is transparent
dataset : BehaviorDataset | DataLoader | str | None, optional
    the dataset to make the prediction for (if not provided, the validation dataset is used)
drop_classes : set, optional
    a set of class names to not be displayed
search_classes : set, optional
    if given, only intervals where at least one of the classes is in ground truth will be shown
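
A minimal usage sketch (not part of the source): task is assumed to be a trained Task instance, the behavior names and file path are illustrative, and the method name visualize_results is an assumption to be checked against the signature above.

    # hypothetical call; parameter names follow the docstring above
    task.visualize_results(
        save_path="sample_prediction.png",    # save the figure to disk
        whole_video=False,                    # sample a random segment, not a full video
        min_classes=2,                        # resample until at least 2 classes appear
        search_classes={"stretch", "groom"},  # illustrative class names
        colormap="Accent",
    )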

def set_ssl_transformations(self, ssl_transformations)
1660    def set_ssl_transformations(self, ssl_transformations):
1661        self.train_dataloader.dataset.set_ssl_transformations(ssl_transformations)
1662        if self.val_dataloader is not None:
1663            self.val_dataloader.dataset.set_ssl_transformations(ssl_transformations)
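
Set new SSL transformations for the train dataset and, if present, the validation dataset.
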
def set_ssl_losses(self, ssl_losses: list) -> None:
1665    def set_ssl_losses(self, ssl_losses: list) -> None:
1666        """
1667        Set SSL losses
1668
1669        Parameters
1670        ----------
1671        ssl_losses : list
1672            a list of callable SSL losses
1673        """
1674
1675        self.ssl_losses = ssl_losses

Set SSL losses

Parameters

ssl_losses : list
    a list of callable SSL losses

def set_log(self, log: str) -> None:
1677    def set_log(self, log: str) -> None:
1678        """
1679        Set the log file
1680
1681        Parameters
1682        ----------
1683        log: str
1684            the new log file path
1685        """
1686
1687        self.log_file = log

Set the log file

Parameters

log: str
    the new log file path

def set_keep_target_none(self, keep_target_none: List) -> None:
1689    def set_keep_target_none(self, keep_target_none: List) -> None:
1690        """
1691        Set the keep_target_none parameter of the transformer
1692
1693        Parameters
1694        ----------
1695        keep_target_none : list
1696            a list of bool values
1697        """
1698
1699        self.transformer.keep_target_none = keep_target_none

Set the keep_target_none parameter of the transformer

Parameters

keep_target_none : list
    a list of bool values

def set_generate_ssl_input(self, generate_ssl_input: list) -> None:
1701    def set_generate_ssl_input(self, generate_ssl_input: list) -> None:
1702        """
1703        Set the generate_ssl_input parameter of the transformer
1704
1705        Parameters
1706        ----------
1707        generate_ssl_input : list
1708            a list of bool values
1709        """
1710
1711        self.transformer.generate_ssl_input = generate_ssl_input

Set the generate_ssl_input parameter of the transformer

Parameters

generate_ssl_input : list
    a list of bool values

def set_model(self, model: dlc2action.model.base_model.Model) -> None:
1713    def set_model(self, model: Model) -> None:
1714        """
1715        Set a new model
1716
1717        Parameters
1718        ----------
1719        model: Model
1720            the new model
1721        """
1722
1723        self.epoch = 0
1724        self.model = model
1725        self.optimizer = self.optimizer_class(model.parameters(), lr=self.lr)
1726        if self.model.process_labels:
1727            self.model.set_behaviors(list(self.behaviors_dict().values()))

Set a new model

Parameters

model: Model
    the new model

def set_dataloaders( self, train_dataloader: torch.utils.data.dataloader.DataLoader, val_dataloader: torch.utils.data.dataloader.DataLoader = None, test_dataloader: torch.utils.data.dataloader.DataLoader = None) -> None:
1729    def set_dataloaders(
1730        self,
1731        train_dataloader: DataLoader,
1732        val_dataloader: DataLoader = None,
1733        test_dataloader: DataLoader = None,
1734    ) -> None:
1735        """
1736        Set new dataloaders
1737
1738        Parameters
1739        ----------
1740        train_dataloader: torch.utils.data.DataLoader
1741            the new train dataloader
1742        val_dataloader : torch.utils.data.DataLoader
1743            the new validation dataloader
1744        test_dataloader : torch.utils.data.DataLoader
1745            the new test dataloader
1746        """
1747
1748        self.train_dataloader = train_dataloader
1749        self.val_dataloader = val_dataloader
1750        self.test_dataloader = test_dataloader

Set new dataloaders

Parameters

train_dataloader: torch.utils.data.DataLoader
    the new train dataloader
val_dataloader : torch.utils.data.DataLoader
    the new validation dataloader
test_dataloader : torch.utils.data.DataLoader
    the new test dataloader

def set_loss(self, loss: Callable) -> None:
1752    def set_loss(self, loss: Callable) -> None:
1753        """
1754        Set new loss function
1755
1756        Parameters
1757        ----------
1758        loss: callable
1759            the new loss function
1760        """
1761
1762        self.loss = loss

Set new loss function

Parameters

loss: callable
    the new loss function

def set_metrics(self, metrics: dict) -> None:
1764    def set_metrics(self, metrics: dict) -> None:
1765        """
1766        Set new metrics
1767
1768        Parameters
1769        ----------
1770        metrics : dict
1771            the new metric dictionary
1772        """
1773
1774        self.metrics = metrics

Set new metrics

Parameters

metrics : dict
    the new metric dictionary

def set_transformer( self, transformer: dlc2action.transformer.base_transformer.Transformer) -> None:
1776    def set_transformer(self, transformer: Transformer) -> None:
1777        """
1778        Set a new transformer
1779
1780        Parameters
1781        ----------
1782        transformer: Transformer
1783            the new transformer
1784        """
1785
1786        self.transformer = transformer

Set a new transformer

Parameters

transformer: Transformer
    the new transformer

def set_predict_functions( self, primary_predict_function: Callable, predict_function: Callable) -> None:
1788    def set_predict_functions(
1789        self, primary_predict_function: Callable, predict_function: Callable
1790    ) -> None:
1791        """
1792        Set new predict functions
1793
1794        Parameters
1795        ----------
1796        primary_predict_function : callable
1797            the new primary predict function
1798        predict_function : callable
1799            the new predict function
1800        """
1801
1802        self.primary_predict_function = primary_predict_function
1803        self.predict_function = predict_function

Set new predict functions

Parameters

primary_predict_function : callable
    the new primary predict function
predict_function : callable
    the new predict function

def count_classes(self, bouts: bool = False) -> Dict:
1831    def count_classes(self, bouts: bool = False) -> Dict:
1832        """
1833        Get a dictionary of class counts in different modes
1834
1835        Parameters
1836        ----------
1837        bouts : bool, default False
1838            if `True`, segment counts are returned instead of frame counts
1839
1840        Returns
1841        -------
1842        class_counts : dict
1843            a dictionary where first-level keys are "train", "val" and "test", second-level keys are
1844            class names and values are class counts (in frames, or in bouts if `bouts` is `True`)
1845        """
1846
1847        class_counts = {}
1848        for x in ["train", "val", "test"]:
1849            try:
1850                class_counts[x] = self.dataset(x).count_classes(bouts)
1851            except ValueError:
1852                class_counts[x] = {k: 0 for k in self.behaviors_dict().keys()}
1853        return class_counts

Get a dictionary of class counts in different modes

Parameters

bouts : bool, default False
    if True, segment counts are returned instead of frame counts

Returns

class_counts : dict
    a dictionary where first-level keys are "train", "val" and "test", second-level keys are class names and values are class counts (in frames, or in bouts if bouts is True)
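
A short usage sketch (assumed setup: task is an initialized Task instance; the printed values are illustrative, not real output):

    # inspect the class balance across splits
    frame_counts = task.count_classes()           # per-frame counts
    bout_counts = task.count_classes(bouts=True)  # per-segment counts
    for mode in ("train", "val", "test"):
        print(mode, frame_counts[mode])
    # behaviors_dict (next method) maps label indices to names,
    # e.g. {0: "walking", 1: "grooming"} -- illustrative values
    print(task.behaviors_dict())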

def behaviors_dict(self) -> Dict:
1855    def behaviors_dict(self) -> Dict:
1856        """
1857        Get a behavior dictionary
1858
1859        Keys are label indices and values are label names.
1860
1861        Returns
1862        -------
1863        behaviors_dict : dict
1864            behavior dictionary
1865        """
1866
1867        return self.dataset().behaviors_dict()

Get a behavior dictionary

Keys are label indices and values are label names.

Returns

behaviors_dict : dict
    behavior dictionary

def update_parameters(self, parameters: Dict) -> None:
1869    def update_parameters(self, parameters: Dict) -> None:
1870        """
1871        Update training parameters from a dictionary
1872
1873        Parameters
1874        ----------
1875        parameters : dict
1876            the update dictionary
1877        """
1878
1879        self.lr = parameters.get("lr", self.lr)
1880        self.parallel = parameters.get("parallel", self.parallel)
1881        self.optimizer = self.optimizer_class(self.model.parameters(), lr=self.lr)
1882        self.verbose = parameters.get("verbose", self.verbose)
1883        self.device = parameters.get("device", self.device)
1884        if self.device == "auto":
1885            self.device = "cuda" if torch.cuda.is_available() else "cpu"
1886        self.augment_train = int(parameters.get("augment_train", self.augment_train))
1887        self.augment_val = int(parameters.get("augment_val", self.augment_val))
1888        ssl_weights = parameters.get("ssl_weights", self.ssl_weights)
1889        if ssl_weights is None:
1890            ssl_weights = [1 for _ in self.ssl_losses]
1891        if not isinstance(ssl_weights, Iterable):
1892            ssl_weights = [ssl_weights for _ in self.ssl_losses]
1893        self.ssl_weights = ssl_weights
1894        self.num_epochs = parameters.get("num_epochs", self.num_epochs)
1895        self.model_save_epochs = parameters.get(
1896            "model_save_epochs", self.model_save_epochs
1897        )
1898        self.model_save_path = parameters.get("model_save_path", self.model_save_path)
1899        self.pseudolabel = parameters.get("pseudolabel", self.pseudolabel)
1900        self.T1 = int(parameters.get("pseudolabel_start", self.T1))
1901        self.T2 = int(parameters.get("alpha_growth_stop", self.T2))
1902        self.t = int(parameters.get("correction_interval", self.t))
1903        self.alpha_f = parameters.get("pseudolabel_alpha_f", self.alpha_f)
1904        self.log_file = parameters.get("log_file", self.log_file)

Update training parameters from a dictionary

Parameters

parameters : dict
    the update dictionary
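
For illustration, a sketch of a typical update; the keys follow the code above and the values are arbitrary:

    task.update_parameters(
        {
            "lr": 5e-4,        # also rebuilds the optimizer with the new rate
            "device": "auto",  # resolved to "cuda" if available, else "cpu"
            "num_epochs": 50,
            "augment_val": 1,
        }
    )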

def generate_uncertainty_score( self, classes: List, augment_n: int = 0, batch_size: int = 32, method: str = 'least_confidence', predicted: torch.Tensor = None, behaviors_dict: Dict = None) -> Dict:
1906    def generate_uncertainty_score(
1907        self,
1908        classes: List,
1909        augment_n: int = 0,
1910        batch_size: int = 32,
1911        method: str = "least_confidence",
1912        predicted: torch.Tensor = None,
1913        behaviors_dict: Dict = None,
1914    ) -> Dict:
1915        """
1916        Generate frame-wise scores for active learning
1917
1918        Parameters
1919        ----------
1920        classes : list
1921            a list of class names or indices; their confidence scores will be computed separately and stacked
1922        augment_n : int, default 0
1923            the number of augmentations to average over
1924        batch_size : int, default 32
1925            the batch size
1926        method : {"least_confidence", "entropy", "random"}
1927            the method used to calculate the scores from the probability predictions (`"least_confidence"`: `1 - p_i` if
1928            `p_i > 0.5` or `p_i` if `p_i < 0.5`; `"entropy"`: `- p_i * log(p_i) - (1 - p_i) * log(1 - p_i)`; `"random"`: uniform random scores)
1929
1930        Returns
1931        -------
1932        score_dicts : dict
1933            a nested dictionary where first level keys are video ids, second level keys are clip ids and values
1934            are score tensors
1935        """
1936
1937        dataset = self.dataset("train")
1938        if behaviors_dict is None:
1939            behaviors_dict = self.behaviors_dict()
1940        if not isinstance(dataset, BehaviorDataset):
1941            raise TypeError(
1942                f"The dataset parameter has to be either None, string, "
1943                f"BehaviorDataset or Dataloader, got {type(dataset)}"
1944            )
1945        if predicted is None:
1946            predicted = self.predict(
1947                dataset,
1948                raw_output=True,
1949                apply_primary_function=True,
1950                augment_n=augment_n,
1951                batch_size=batch_size,
1952            )
1953        predicted = dataset.generate_full_length_prediction(predicted)
1954        if isinstance(classes[0], str):
1955            behaviors_dict_inverse = {v: k for k, v in behaviors_dict.items()}
1956            classes = [behaviors_dict_inverse[c] for c in classes]
1957        for v_id, v in predicted.items():
1958            for clip_id, vv in v.items():
1959                if method == "least_confidence":
1960                    predicted[v_id][clip_id][vv > 0.5] = 1 - vv[vv > 0.5]
1961                elif method == "entropy":
1962                    predicted[v_id][clip_id][vv != -100] = (
1963                        -vv * torch.log(vv) - (1 - vv) * torch.log(1 - vv)
1964                    )[vv != -100]
1965                elif method == "random":
1966                    predicted[v_id][clip_id] = torch.rand(vv.shape)
1967                else:
1968                    raise ValueError(
1969                        f"The {method} method is not recognized; please choose from ['least_confidence', 'entropy', 'random']"
1970                    )
1971                predicted[v_id][clip_id][vv == -100] = 0
1972
1973        predicted = {
1974            v_id: {clip_id: v[classes, :] for clip_id, v in video_dict.items()}
1975            for v_id, video_dict in predicted.items()
1976        }
1977        return predicted

Generate frame-wise scores for active learning

Parameters

classes : list
    a list of class names or indices; their confidence scores will be computed separately and stacked
augment_n : int, default 0
    the number of augmentations to average over
batch_size : int, default 32
    the batch size
method : {"least_confidence", "entropy", "random"}
    the method used to calculate the scores from the probability predictions ("least_confidence": 1 - p_i if p_i > 0.5 or p_i if p_i < 0.5; "entropy": - p_i * log(p_i) - (1 - p_i) * log(1 - p_i); "random": uniform random scores)
predicted : torch.Tensor, optional
    precomputed probability predictions (if not provided, they are generated with the predict method)
behaviors_dict : dict, optional
    a behavior dictionary to use (by default the task's own behaviors_dict)

Returns

score_dicts : dict
    a nested dictionary where first level keys are video ids, second level keys are clip ids and values are score tensors
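
For intuition, the two documented scoring rules can be reproduced on a toy tensor; a standalone sketch, separate from the method itself:

    import torch

    def least_confidence(p):
        # 1 - p where p > 0.5, p otherwise: peaks at p = 0.5
        return torch.where(p > 0.5, 1 - p, p)

    def entropy(p):
        # binary entropy: - p * log(p) - (1 - p) * log(1 - p)
        return -p * torch.log(p) - (1 - p) * torch.log(1 - p)

    p = torch.tensor([0.1, 0.5, 0.9])
    print(least_confidence(p))  # tensor([0.1000, 0.5000, 0.1000])
    print(entropy(p))           # maximal (~0.693) at p = 0.5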

def generate_bald_score( self, classes: List, augment_n: int = 0, batch_size: int = 32, num_models: int = 10, kernel_size: int = 11) -> Dict:
1979    def generate_bald_score(
1980        self,
1981        classes: List,
1982        augment_n: int = 0,
1983        batch_size: int = 32,
1984        num_models: int = 10,
1985        kernel_size: int = 11,
1986    ) -> Dict:
1987        """
1988        Generate frame-wise Bayesian Active Learning by Disagreement scores for active learning
1989
1990        Parameters
1991        ----------
1992        classes : list
1993            a list of class names or indices; their confidence scores will be computed separately and stacked
1994        augment_n : int, default 0
1995            the number of augmentations to average over
1996        batch_size : int, default 32
1997            the batch size
1998        num_models : int, default 10
1999            the number of dropout masks to apply
2000        kernel_size : int, default 11
2001            the size of the smoothing gaussian kernel
2002
2003        Returns
2004        -------
2005        score_dicts : dict
2006            a nested dictionary where first level keys are video ids, second level keys are clip ids and values
2007            are score tensors
2008        """
2009
2010        dataset = self.dataset("train")
2011        dataset = self._get_dataset(dataset)
2012        if not isinstance(dataset, BehaviorDataset):
2013            raise TypeError(
2014                f"The dataset parameter has to be either None, string, "
2015                f"BehaviorDataset or Dataloader, got {type(dataset)}"
2016            )
2017        predictions = []
2018        for _ in range(num_models):
2019            predicted = self.predict(
2020                dataset,
2021                raw_output=True,
2022                apply_primary_function=True,
2023                augment_n=augment_n,
2024                batch_size=batch_size,
2025                train_mode=True,
2026            )
2027            predicted = dataset.generate_full_length_prediction(predicted)
2028            if isinstance(classes[0], str):
2029                behaviors_dict_inverse = {
2030                    v: k for k, v in self.behaviors_dict().items()
2031                }
2032                classes = [behaviors_dict_inverse[c] for c in classes]
2033            for v_id, v in predicted.items():
2034                for clip_id, vv in v.items():
2035                    vv[vv != -100] = (vv[vv != -100] > 0.5).int().float()
2036                    predicted[v_id][clip_id] = vv
2037            predicted = {
2038                v_id: {clip_id: v[classes, :] for clip_id, v in video_dict.items()}
2039                for v_id, video_dict in predicted.items()
2040            }
2041            predictions.append(predicted)
2042        result = {v_id: {} for v_id in predictions[0]}
2043        r = range(-int(kernel_size / 2), int(kernel_size / 2) + 1)
2044        gauss = [1 / (1 * sqrt(2 * pi)) * exp(-float(x) ** 2 / (2 * 1**2)) for x in r]
2045        gauss = [x / sum(gauss) for x in gauss]
2046        kernel = torch.FloatTensor([[gauss]])
2047        for v_id in predictions[0]:
2048            for clip_id in predictions[0][v_id]:
2049                consensus = (
2050                    (
2051                        torch.mean(
2052                            torch.stack([x[v_id][clip_id] for x in predictions]), dim=0
2053                        )
2054                        > 0.5
2055                    )
2056                    .int()
2057                    .float()
2058                )
2059                consensus[predictions[0][v_id][clip_id] == -100] = -100
2060                result[v_id][clip_id] = torch.zeros(consensus.shape)
2061                for x in predictions:
2062                    result[v_id][clip_id] += (x[v_id][clip_id] != consensus).int()
2063                result[v_id][clip_id] = result[v_id][clip_id] * 2 / num_models
2064                res = torch.zeros(result[v_id][clip_id].shape)
2065                for i in range(len(classes)):
2066                    res[
2067                        i, floor(kernel_size // 2) : -floor(kernel_size // 2)
2068                    ] = torch.nn.functional.conv1d(
2069                        result[v_id][clip_id][i, :].unsqueeze(0).unsqueeze(0), kernel
2070                    )[
2071                        0, ...
2072                    ]
2073                result[v_id][clip_id] = res
2074        return result

Generate frame-wise Bayesian Active Learning by Disagreement scores for active learning

Parameters

classes : list
    a list of class names or indices; their confidence scores will be computed separately and stacked
augment_n : int, default 0
    the number of augmentations to average over
batch_size : int, default 32
    the batch size
num_models : int, default 10
    the number of dropout masks to apply
kernel_size : int, default 11
    the size of the smoothing gaussian kernel

Returns

score_dicts : dict
    a nested dictionary where first level keys are video ids, second level keys are clip ids and values are score tensors
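
The core of the computation (disagreement with the consensus of several stochastic forward passes, followed by Gaussian smoothing) can be reproduced on toy data; a standalone sketch, using padding instead of the source's zeroed edges:

    import torch
    from math import exp, pi, sqrt

    # ten sets of binarized predictions for one clip: models x classes x frames
    preds = (torch.rand(10, 2, 100) > 0.5).float()
    consensus = (preds.mean(dim=0) > 0.5).float()
    # fraction of models disagreeing with the consensus, scaled by 2 as in the source
    disagreement = (preds != consensus).float().sum(dim=0) * 2 / preds.shape[0]

    # normalized Gaussian kernel (sigma = 1), built as in the source
    kernel_size = 11
    gauss = [exp(-x ** 2 / 2) / sqrt(2 * pi)
             for x in range(-(kernel_size // 2), kernel_size // 2 + 1)]
    total = sum(gauss)
    kernel = torch.tensor([[[g / total for g in gauss]]])

    # smooth each class channel with a 1D convolution
    smoothed = torch.nn.functional.conv1d(
        disagreement.unsqueeze(1), kernel, padding=kernel_size // 2
    ).squeeze(1)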

def get_normalization_stats(self) -> Optional[Dict]:
2076    def get_normalization_stats(self) -> Optional[Dict]:
2077        return self.train_dataloader.dataset.stats
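
Get the normalization statistics (if any) that were computed for the training dataset.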