dlc2action.task.universal_task
Training and inference.
1# 2# Copyright 2020-present by A. Mathis Group and contributors. All rights reserved. 3# 4# This project and all its files are licensed under GNU AGPLv3 or later version. 5# A copy is included in dlc2action/LICENSE.AGPL. 6# 7"""Training and inference.""" 8 9import os 10import random 11import warnings 12from collections import defaultdict 13from collections.abc import Iterable 14from copy import copy, deepcopy 15from math import ceil, exp, floor, pi, sqrt 16from random import randint 17from time import time 18from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union 19 20import numpy as np 21import torch 22from dlc2action.data.dataset import BehaviorDataset 23from dlc2action.metric.base_metric import Metric 24from dlc2action.model.base_model import LoadedModel, Model 25from dlc2action.transformer.base_transformer import EmptyTransformer, Transformer 26from dlc2action.options import dlc2action_colormaps 27from matplotlib import cm 28from matplotlib import pyplot as plt 29from matplotlib import rc 30from optuna import TrialPruned 31from optuna.trial import Trial 32from torch import nn 33from torch.optim import Adam, Optimizer 34from torch.utils.data import DataLoader 35from tqdm import tqdm 36 37 38class MyDataParallel(nn.DataParallel): 39 """A wrapper for nn.DataParallel that allows to call methods of the model.""" 40 41 def __init__(self, *args, **kwargs): 42 """Initialize the class.""" 43 super().__init__(*args, **kwargs) 44 self.process_labels = self.module.process_labels 45 46 def freeze_feature_extractor(self): 47 """Freeze the feature extractor.""" 48 self.module.freeze_feature_extractor() 49 50 def unfreeze_feature_extractor(self): 51 """Unfreeze the feature extractor.""" 52 self.module.unfreeze_feature_extractor() 53 54 def transform_labels(self, device): 55 """Transform labels to the device.""" 56 return self.module.transform_labels(device) 57 58 def logit_scale(self): 59 """Return the logit scale of the model.""" 60 return self.module.logit_scale() 61 62 def main_task_off(self): 63 """Turn off the main task training.""" 64 self.module.main_task_off() 65 66 def state_dict(self, *args, **kwargs): 67 """Return the state of the module.""" 68 return self.module.state_dict(*args, **kwargs) 69 70 def ssl_on(self): 71 """Turn on SSL training.""" 72 self.module.ssl_on() 73 74 def ssl_off(self): 75 """Turn off SSL training.""" 76 self.module.ssl_off() 77 78 def extract_features(self, x, start=0): 79 """Extract features from the model.""" 80 return self.module.extract_features(x, start) 81 82 83class Task: 84 """A universal trainer class that performs training, evaluation and prediction for all types of tasks and data.""" 85 86 def __init__( 87 self, 88 train_dataloader: DataLoader, 89 model: Union[nn.Module, Model], 90 loss: Callable[[torch.Tensor, torch.Tensor], float], 91 num_epochs: int = 0, 92 transformer: Transformer = None, 93 ssl_losses: List = None, 94 ssl_weights: List = None, 95 lr: float = 1e-3, 96 weight_decay: float = 0, 97 metrics: Dict = None, 98 val_dataloader: DataLoader = None, 99 test_dataloader: DataLoader = None, 100 optimizer: Optimizer = None, 101 device: str = "cuda", 102 verbose: bool = True, 103 log_file: Union[str, None] = None, 104 augment_train: int = 1, 105 augment_val: int = 0, 106 validation_interval: int = 1, 107 predict_function: Union[Callable[[torch.Tensor], torch.Tensor], None] = None, 108 primary_predict_function: Callable = None, 109 exclusive: bool = True, 110 ignore_tags: bool = True, 111 threshold: float = 0.5, 112 model_save_path: str = None, 113 model_save_epochs: int = 5, 114 pseudolabel: bool = False, 115 pseudolabel_start: int = 100, 116 correction_interval: int = 2, 117 pseudolabel_alpha_f: float = 3, 118 alpha_growth_stop: int = 600, 119 parallel: bool = False, 120 skip_metrics: List = None, 121 ) -> None: 122 """Initialize the class. 123 124 Parameters 125 ---------- 126 train_dataloader : torch.utils.data.DataLoader 127 a training dataloader 128 model : dlc2action.model.base_model.Model 129 a model 130 loss : callable 131 a loss function 132 num_epochs : int, default 0 133 the number of epochs 134 transformer : dlc2action.transformer.base_transformer.Transformer, optional 135 a transformer 136 ssl_losses : list, optional 137 a list of SSL losses 138 ssl_weights : list, optional 139 a list of SSL weights (if not provided initializes to 1) 140 lr : float, default 1e-3 141 learning rate 142 weight_decay : float, default 0 143 weight decay 144 metrics : dict, optional 145 a list of metric functions 146 val_dataloader : torch.utils.data.DataLoader, optional 147 a validation dataloader 148 test_dataloader : torch.utils.data.DataLoader, optional 149 a test dataloader 150 optimizer : torch.optim.Optimizer, optional 151 an optimizer (`Adam` by default) 152 device : str, default 'cuda' 153 the device to train the model on 154 verbose : bool, default True 155 if `True`, the process is described in standard output 156 log_file : str, optional 157 the path to a text file where the process will be logged 158 augment_train : {1, 0} 159 number of augmentations to apply at training 160 augment_val : int, default 0 161 number of augmentations to apply at validation 162 validation_interval : int, default 1 163 every time this number of epochs passes, validation metrics are computed 164 predict_function : callable, optional 165 a function that maps probabilities to class predictions (if not provided, a default is generated) 166 primary_predict_function : callable, optional 167 a function that maps model output to probabilities (if not provided, initialized as identity) 168 exclusive : bool, default True 169 set to False for multi-label classification 170 ignore_tags : bool, default False 171 if `True`, samples with different meta tags will be mixed in batches 172 threshold : float, default 0.5 173 the threshold used for multi-label classification default prediction function 174 model_save_path : str, optional 175 the path to the folder where model checkpoints will be saved (checkpoints will not be saved if the path 176 is not provided) 177 model_save_epochs : int, default 5 178 the interval for saving the model checkpoints (the last epoch is always saved) 179 pseudolabel : bool, default False 180 if True, the pseudolabeling procedure will be applied 181 pseudolabel_start : int, default 100 182 pseudolabeling starts after this epoch 183 correction_interval : int, default 1 184 after this number of epochs, if the pseudolabeling is on, the model is trained on the labeled data and 185 new pseudolabels are generated 186 pseudolabel_alpha_f : float, default 3 187 the maximum value of pseudolabeling alpha 188 alpha_growth_stop : int, default 600 189 pseudolabeling alpha stops growing after this epoch 190 parallel : bool, default False 191 if True, the model is trained on multiple GPUs 192 skip_metrics : list, optional 193 a list of metrics to skip 194 195 """ 196 # pseudolabeling might be buggy right now -- not using it! 197 if skip_metrics is None: 198 skip_metrics = [] 199 self.train_dataloader = train_dataloader 200 self.val_dataloader = val_dataloader 201 self.test_dataloader = test_dataloader 202 self.transformer = transformer 203 self.num_epochs = num_epochs 204 self.skip_metrics = skip_metrics 205 self.verbose = verbose 206 self.augment_train = int(augment_train) 207 self.augment_val = int(augment_val) 208 self.ignore_tags = ignore_tags 209 self.validation_interval = int(validation_interval) 210 self.log_file = log_file 211 self.loss = loss 212 self.model_save_path = model_save_path 213 self.model_save_epochs = model_save_epochs 214 self.epoch = 0 215 216 if metrics is None: 217 metrics = {} 218 self.metrics = metrics 219 220 if optimizer is None: 221 optimizer = Adam 222 223 if ssl_weights is None: 224 ssl_weights = [1 for _ in ssl_losses] 225 if not isinstance(ssl_weights, Iterable): 226 ssl_weights = [ssl_weights for _ in ssl_losses] 227 self.ssl_weights = ssl_weights 228 229 self.optimizer_class = optimizer 230 self.lr = lr 231 self.weight_decay = weight_decay 232 if not isinstance(model, Model): 233 self.model = LoadedModel(model=model) 234 else: 235 self.set_model(model) 236 self.parallel = parallel 237 238 if self.transformer is None: 239 self.augment_val = 0 240 self.augment_train = 0 241 self.transformer = EmptyTransformer() 242 243 if self.augment_train > 1: 244 warnings.warn( 245 'The "augment_train" parameter is too large -> setting it to 1.' 246 ) 247 self.augment_train = 1 248 249 try: 250 if device == "auto": 251 device = "cuda" if torch.cuda.is_available() else "cpu" 252 self.device = torch.device(device) 253 except: 254 raise ("The format of the device is incorrect") 255 256 if ssl_losses is None: 257 self.ssl_losses = [lambda x, y: 0] 258 else: 259 self.ssl_losses = ssl_losses 260 261 if primary_predict_function is None: 262 if exclusive: 263 primary_predict_function = lambda x: nn.Softmax(x, dim=1) 264 else: 265 primary_predict_function = lambda x: torch.sigmoid(x) 266 self.primary_predict_function = primary_predict_function 267 268 if predict_function is None: 269 if exclusive: 270 self.predict_function = lambda x: torch.max(x.data, 1)[1] 271 else: 272 self.predict_function = lambda x: (x > threshold).int() 273 else: 274 self.predict_function = predict_function 275 276 self.pseudolabel = pseudolabel 277 self.alpha_f = pseudolabel_alpha_f 278 self.T2 = alpha_growth_stop 279 self.T1 = pseudolabel_start 280 self.t = correction_interval 281 if self.T2 <= self.T1: 282 raise ValueError( 283 f"The pseudolabel_start parameter has to be smaller than alpha_growth_stop; got " 284 f"{pseudolabel_start=} and {alpha_growth_stop=}" 285 ) 286 self.decision_thresholds = [0.5 for x in self.behaviors_dict()] 287 288 def save_checkpoint(self, checkpoint_path: str) -> None: 289 """Save a general checkpoint. 290 291 Parameters 292 ---------- 293 checkpoint_path : str 294 the path where the checkpoint will be saved 295 296 """ 297 torch.save( 298 { 299 "epoch": self.epoch, 300 "model_state_dict": self.model.state_dict(), 301 "optimizer_state_dict": self.optimizer.state_dict(), 302 }, 303 checkpoint_path, 304 ) 305 306 def load_from_checkpoint( 307 self, checkpoint_path, only_model: bool = False, load_strict: bool = True 308 ) -> None: 309 """Load from a checkpoint. 310 311 Parameters 312 ---------- 313 checkpoint_path : str 314 the path to the checkpoint 315 only_model : bool, default False 316 if `True`, only the model state dictionary will be loaded (and not the epoch and the optimizer state 317 dictionary) 318 load_strict : bool, default True 319 if `True`, any inconsistencies in state dictionaries are regarded as errors 320 321 """ 322 if checkpoint_path is None: 323 return 324 checkpoint = torch.load( 325 checkpoint_path, map_location=self.device, weights_only=False 326 ) 327 self.model.to(self.device) 328 self.model.load_state_dict(checkpoint["model_state_dict"], strict=load_strict) 329 if not only_model: 330 self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) 331 self.epoch = checkpoint["epoch"] 332 333 def save_model(self, save_path: str) -> None: 334 """Save the model state dictionary. 335 336 Parameters 337 ---------- 338 save_path : str 339 the path where the state will be saved 340 341 """ 342 torch.save(self.model.state_dict(), save_path) 343 print("saved the model successfully") 344 345 def _apply_predict_functions(self, predicted): 346 """Map from model output to prediction.""" 347 predicted = self.primary_predict_function(predicted) 348 predicted = self.predict_function(predicted) 349 return predicted 350 351 def _get_prediction( 352 self, 353 main_input: Dict, 354 tags: torch.Tensor, 355 ssl_inputs: List = None, 356 ssl_targets: List = None, 357 augment_n: int = 0, 358 embedding: bool = False, 359 subsample: List = None, 360 ) -> Tuple: 361 """Get the prediction of `self.model` for input averaged over `augment_n` augmentations.""" 362 if augment_n == 0: 363 augment_n = 1 364 augment = False 365 else: 366 augment = True 367 model_input, ssl_inputs, ssl_targets = self.transformer.transform( 368 deepcopy(main_input), 369 ssl_inputs=ssl_inputs, 370 ssl_targets=ssl_targets, 371 augment=augment, 372 subsample=subsample, 373 ) 374 if not embedding: 375 # t1 = time() 376 predicted, ssl_predicted = self.model(model_input, ssl_inputs, tag=tags) 377 # t2 = time() 378 # with open("/home/andy/Documents/EPFLsmartkitchenrecording/baselines/DLC2Action_baselines/inference_times/body_hand_eyes_2405.txt", "a") as f: 379 # f.write(f"Time: {t2 - t1}\n") 380 else: 381 predicted = self.model.extract_features(model_input) 382 ssl_predicted = None 383 if self.parallel and predicted is not None and len(predicted.shape) == 4: 384 predicted = predicted.reshape( 385 (-1, model_input.shape[0], *predicted.shape[2:]) 386 ) 387 if predicted is not None and augment_n > 1: 388 self.model.ssl_off() 389 for i in range(augment_n - 1): 390 model_input, *_ = self.transformer.transform( 391 deepcopy(main_input), augment=augment 392 ) 393 if not embedding: 394 pred, _ = self.model(model_input, None, tag=tags) 395 else: 396 pred = self.model.extract_features(model_input) 397 predicted += pred.detach() 398 self.model.ssl_on() 399 predicted /= augment_n 400 if self.model.process_labels: 401 class_embedding = self.model.transform_labels(predicted.device) 402 beh_dict = self.behaviors_dict() 403 class_embedding = torch.cat( 404 [class_embedding[beh_dict[k]] for k in sorted(beh_dict.keys())], 0 405 ) 406 predicted = { 407 "video": predicted, 408 "text": class_embedding, 409 "logit_scale": self.model.logit_scale(), 410 "device": predicted.device, 411 } 412 return predicted, ssl_predicted, ssl_targets 413 414 def _ssl_loss(self, ssl_predicted, ssl_targets): 415 """Apply SSL losses.""" 416 if self.ssl_losses is None or ssl_predicted is None or ssl_targets is None: 417 return [] 418 419 ssl_loss = [] 420 for loss, predicted, target in zip(self.ssl_losses, ssl_predicted, ssl_targets): 421 ssl_loss.append(loss(predicted, target)) 422 return ssl_loss 423 424 def _loss_function( 425 self, 426 batch: Dict, 427 augment_n: int, 428 temporal_subsampling_size: int = None, 429 skip_metrics: List = None, 430 ) -> Tuple[float, float]: 431 """Calculate the loss function and the metric for a dataloader batch. 432 433 Averaging the predictions over augment_n augmentations. 434 435 """ 436 if "target" not in batch or torch.isnan(batch["target"]).all(): 437 raise ValueError("Cannot compute loss function with nan targets!") 438 main_input = {k: v.to(self.device) for k, v in batch["input"].items()} 439 main_target, ssl_targets, ssl_inputs = None, None, None 440 main_target = batch["target"].to(self.device) 441 if "ssl_targets" in batch: 442 ssl_targets = [ 443 ( 444 {k: v.to(self.device) for k, v in x.items()} 445 if isinstance(x, dict) 446 else None 447 ) 448 for x in batch["ssl_targets"] 449 ] 450 if "ssl_inputs" in batch: 451 ssl_inputs = [ 452 ( 453 {k: v.to(self.device) for k, v in x.items()} 454 if isinstance(x, dict) 455 else None 456 ) 457 for x in batch["ssl_inputs"] 458 ] 459 if temporal_subsampling_size is not None: 460 subsample = sorted( 461 random.sample( 462 range(main_target.shape[-1]), 463 int(temporal_subsampling_size * main_target.shape[-1]), 464 ) 465 ) 466 main_target = main_target[..., subsample] 467 else: 468 subsample = None 469 470 if self.ignore_tags: 471 tag = None 472 else: 473 tag = batch.get("tag") 474 predicted, ssl_predicted, ssl_targets = self._get_prediction( 475 main_input, tag, ssl_inputs, ssl_targets, augment_n, subsample=subsample 476 ) 477 del main_input, ssl_inputs 478 return self._compute( 479 ssl_predicted, 480 ssl_targets, 481 predicted, 482 main_target, 483 tag=batch.get("tag"), 484 skip_metrics=skip_metrics, 485 ) 486 487 def _compute( 488 self, 489 ssl_predicted: List, 490 ssl_targets: List, 491 predicted: torch.Tensor, 492 main_target: torch.Tensor, 493 tag: Any = None, 494 skip_loss: bool = False, 495 apply_primary_function: bool = True, 496 skip_metrics: List = None, 497 ) -> Tuple[float, float]: 498 """Compute the losses and metrics from predictions.""" 499 if skip_metrics is None: 500 skip_metrics = [] 501 if not skip_loss: 502 ssl_losses = self._ssl_loss(ssl_predicted, ssl_targets) 503 if predicted is not None: 504 loss = self.loss(predicted, main_target) 505 else: 506 loss = 0 507 else: 508 ssl_losses, loss = [], 0 509 510 if predicted is not None: 511 if isinstance(predicted, dict): 512 predicted = { 513 k: v.detach() 514 for k, v in predicted.items() 515 if isinstance(v, torch.Tensor) 516 } 517 else: 518 predicted = predicted.detach() 519 if apply_primary_function: 520 predicted = self.primary_predict_function(predicted) 521 predicted_transformed = self.predict_function(predicted) 522 523 for name, metric_function in self.metrics.items(): 524 if name not in skip_metrics: 525 if metric_function.needs_raw_data: 526 metric_function.update(predicted, main_target, tag) 527 else: 528 metric_function.update( 529 predicted_transformed, 530 main_target, 531 tag, 532 ) 533 return loss, ssl_losses 534 535 def _calculate_metrics(self) -> Dict: 536 """Calculate the final values of epoch metrics.""" 537 epoch_metrics = {} 538 for metric_name, metric in self.metrics.items(): 539 m = metric.calculate() 540 if type(m) is dict: 541 for k, v in m.items(): 542 if type(v) is torch.Tensor: 543 v = v.item() 544 epoch_metrics[f"{metric_name}_{k}"] = v 545 else: 546 if type(m) is torch.Tensor: 547 m = m.item() 548 epoch_metrics[metric_name] = m 549 metric.reset() 550 # print("calculate metric", epoch_metrics) 551 return epoch_metrics 552 553 def _run_epoch( 554 self, 555 dataloader: DataLoader, 556 mode: str, 557 augment_n: int, 558 verbose: bool = False, 559 unlabeled: bool = None, 560 alpha: float = 1, 561 temporal_subsampling_size: int = None, 562 ) -> Tuple: 563 """Run one epoch on dataloader. 564 565 Averaging the predictions over augment_n augmentations. 566 Use "train" mode for training and "val" mode for evaluation. 567 568 """ 569 if mode == "train": 570 self.model.train() 571 elif mode == "val": 572 self.model.eval() 573 pass 574 else: 575 raise ValueError( 576 f'Mode {mode} is not recognized, please choose either "train" for training or "val" for validation' 577 ) 578 if self.ignore_tags: 579 tags = [None] 580 else: 581 tags = dataloader.dataset.get_tags() 582 epoch_loss = 0 583 epoch_ssl_loss = defaultdict(lambda: 0) 584 data_len = 0 585 set_pars = dataloader.dataset.set_indexing_parameters 586 skip_metrics = self.skip_metrics if mode == "train" else None 587 for tag in tags: 588 set_pars(unlabeled=unlabeled, tag=tag) 589 data_len += len(dataloader) 590 if verbose: 591 dataloader = tqdm(dataloader) 592 for batch in dataloader: 593 loss, ssl_losses = self._loss_function( 594 batch, 595 augment_n, 596 temporal_subsampling_size=temporal_subsampling_size, 597 skip_metrics=skip_metrics, 598 ) 599 if loss != 0: 600 loss = loss * alpha 601 epoch_loss += loss.item() 602 for i, (ssl_loss, weight) in enumerate( 603 zip(ssl_losses, self.ssl_weights) 604 ): 605 if ssl_loss != 0: 606 epoch_ssl_loss[i] += ssl_loss.item() 607 loss = loss + weight * ssl_loss 608 if mode == "train": 609 self.optimizer.zero_grad() 610 if loss.requires_grad: 611 loss.backward() 612 self.optimizer.step() 613 614 epoch_loss = epoch_loss / data_len 615 epoch_ssl_loss = {k: v / data_len for k, v in epoch_ssl_loss.items()} 616 epoch_metrics = self._calculate_metrics() 617 618 return epoch_loss, epoch_ssl_loss, epoch_metrics 619 620 def train( 621 self, 622 trial: Trial = None, 623 optimized_metric: str = None, 624 to_ram: bool = False, 625 autostop_interval: int = 30, 626 autostop_threshold: float = 0.001, 627 autostop_metric: str = None, 628 main_task_on: bool = True, 629 ssl_on: bool = True, 630 temporal_subsampling_size: int = None, 631 loading_bar: bool = False, 632 ) -> Tuple: 633 """Train the task and return a log of epoch-average loss and metric. 634 635 You can use the autostop parameters to finish training when the parameters are not improving. It will be 636 stopped if the average value of `autostop_metric` over the last `autostop_interval` epochs is smaller than 637 the average over the previous `autostop_interval` epochs + `autostop_threshold`. For example, if the 638 current epoch is 120 and `autostop_interval` is 50, the averages over epochs 70-120 and 20-70 will be compared. 639 640 Parameters 641 ---------- 642 trial : Trial 643 an `optuna` trial (for hyperparameter searches) 644 optimized_metric : str 645 the name of the metric being optimized (for hyperparameter searches) 646 to_ram : bool, default False 647 if `True`, the dataset will be loaded in RAM (this speeds up the calculations but can lead to crashes 648 if the dataset is too large) 649 autostop_interval : int, default 50 650 the number of epochs to average the autostop metric over 651 autostop_threshold : float, default 0.001 652 the autostop difference threshold 653 autostop_metric : str, optional 654 the autostop metric (can be any one of the tracked metrics of `'loss'`) 655 main_task_on : bool, default True 656 if `False`, the main task (action segmentation) will not be used in training 657 ssl_on : bool, default True 658 if `False`, the SSL task will not be used in training 659 temporal_subsampling_size : int, optional 660 if not `None`, the temporal subsampling will be used in training with the given size 661 loading_bar : bool, default False 662 if `True`, a loading bar will be displayed 663 664 Returns 665 ------- 666 loss_log: list 667 a list of float loss function values for each epoch 668 metrics_log: dict 669 a dictionary of metric value logs (first-level keys are 'train' and 'val', second-level keys are metric 670 names, values are lists of function values) 671 672 """ 673 if self.parallel and not isinstance(self.model, nn.DataParallel): 674 self.model = MyDataParallel(self.model) 675 self.model.to(self.device) 676 assert autostop_metric in [None, "loss"] + list(self.metrics) 677 autostop_interval //= self.validation_interval 678 if trial is not None and optimized_metric is None: 679 raise ValueError( 680 "You need to provide the optimized metric name (optimized_metric parameter) " 681 "for optuna pruning to work!" 682 ) 683 if to_ram: 684 print("transferring datasets to RAM...") 685 self.train_dataloader.dataset.to_ram() 686 if self.val_dataloader is not None and len(self.val_dataloader) > 0: 687 self.val_dataloader.dataset.to_ram() 688 loss_log = {"train": [], "val": []} 689 metrics_log = {"train": defaultdict(lambda: []), "val": defaultdict(lambda: [])} 690 if not main_task_on: 691 self.model.main_task_off() 692 if not ssl_on: 693 self.model.ssl_off() 694 while self.epoch < self.num_epochs: 695 self.epoch += 1 696 unlabeled = None 697 alpha = 1 698 if self.pseudolabel: 699 if self.epoch >= self.T1: 700 unlabeled = (self.epoch - self.T1) % self.t != 0 701 if unlabeled: 702 alpha = self._alpha(self.epoch) 703 else: 704 unlabeled = False 705 epoch_loss, epoch_ssl_loss, epoch_metrics = self._run_epoch( 706 dataloader=self.train_dataloader, 707 mode="train", 708 augment_n=self.augment_train, 709 unlabeled=unlabeled, 710 alpha=alpha, 711 temporal_subsampling_size=temporal_subsampling_size, 712 verbose=loading_bar, 713 ) 714 loss_log["train"].append(epoch_loss) 715 epoch_string = f"[epoch {self.epoch}]" 716 if self.pseudolabel: 717 if unlabeled: 718 epoch_string += " (unlabeled)" 719 else: 720 epoch_string += " (labeled)" 721 epoch_string += f": loss {epoch_loss:.4f}" 722 723 if len(epoch_ssl_loss) != 0: 724 for key, value in sorted(epoch_ssl_loss.items()): 725 metrics_log["train"][f"ssl_loss_{key}"].append(value) 726 epoch_string += f", ssl_loss_{key} {value:.4f}" 727 728 for metric_name, metric_value in sorted(epoch_metrics.items()): 729 if metric_name not in self.skip_metrics: 730 if isinstance(metric_value, list): 731 metric_value = torch.mean(torch.Tensor(metric_value)) 732 epoch_string += f", {metric_name} {metric_value:.3f}" 733 metrics_log["train"][metric_name].append(metric_value) 734 735 if ( 736 self.val_dataloader is not None 737 and self.epoch % self.validation_interval == 0 738 ): 739 with torch.no_grad(): 740 epoch_string += "\n" 741 ( 742 val_epoch_loss, 743 val_epoch_ssl_loss, 744 val_epoch_metrics, 745 ) = self._run_epoch( 746 dataloader=self.val_dataloader, 747 mode="val", 748 augment_n=self.augment_val, 749 ) 750 loss_log["val"].append(val_epoch_loss) 751 epoch_string += f"validation: loss {val_epoch_loss:.4f}" 752 753 if len(val_epoch_ssl_loss) != 0: 754 for key, value in sorted(val_epoch_ssl_loss.items()): 755 metrics_log["val"][f"ssl_loss_{key}"].append(value) 756 epoch_string += f", ssl_loss_{key} {value:.4f}" 757 758 for metric_name, metric_value in sorted(val_epoch_metrics.items()): 759 if isinstance(metric_value, list): 760 metric_value = torch.mean(torch.Tensor(metric_value)) 761 metrics_log["val"][metric_name].append(metric_value) 762 epoch_string += f", {metric_name} {metric_value:.3f}" 763 764 if trial is not None: 765 if optimized_metric not in metrics_log["val"]: 766 raise ValueError( 767 f"The {optimized_metric} metric set for optimization is not being logged!" 768 ) 769 trial.report(metrics_log["val"][optimized_metric][-1], self.epoch) 770 if trial.should_prune(): 771 raise TrialPruned() 772 773 if self.verbose: 774 print(epoch_string) 775 776 if self.log_file is not None: 777 with open(self.log_file, "a") as f: 778 f.write(epoch_string + "\n") 779 780 save_condition = ( 781 (self.model_save_epochs != 0) 782 and (self.epoch % self.model_save_epochs == 0) 783 ) or (self.epoch == self.num_epochs) 784 785 if self.epoch > 0 and save_condition and self.model_save_path is not None: 786 epoch_s = str(self.epoch).zfill(len(str(self.num_epochs))) 787 self.save_checkpoint( 788 os.path.join(self.model_save_path, f"epoch{epoch_s}.pt") 789 ) 790 791 if self.pseudolabel and self.epoch >= self.T1 and not unlabeled: 792 self._set_pseudolabels() 793 794 if autostop_metric == "loss": 795 if len(loss_log["val"]) > autostop_interval * 2: 796 if ( 797 np.mean(loss_log["val"][-autostop_interval:]) 798 < np.mean( 799 loss_log["val"][-2 * autostop_interval : -autostop_interval] 800 ) 801 + autostop_threshold 802 ): 803 break 804 elif autostop_metric in metrics_log["val"]: 805 if len(metrics_log["val"][autostop_metric]) > autostop_interval * 2: 806 if ( 807 np.mean( 808 metrics_log["val"][autostop_metric][-autostop_interval:] 809 ) 810 < np.mean( 811 metrics_log["val"][autostop_metric][ 812 -2 * autostop_interval : -autostop_interval 813 ] 814 ) 815 + autostop_threshold 816 ): 817 break 818 819 metrics_log = {k: dict(v) for k, v in metrics_log.items()} 820 821 return loss_log, metrics_log 822 823 def evaluate_prediction( 824 self, 825 prediction: Union[torch.Tensor, Dict], 826 data: Union[DataLoader, BehaviorDataset, str] = None, 827 batch_size: int = 32, 828 indices: list = None, 829 ) -> Tuple: 830 """Compute metrics for a prediction. 831 832 Parameters 833 ---------- 834 prediction : torch.Tensor 835 the prediction 836 data : torch.utils.data.DataLoader | dlc2action.data.dataset.BehaviorDataset, optional 837 the data the prediction was made for (if not provided, take the validation dataset) 838 batch_size : int, default 32 839 the batch size 840 841 Returns 842 ------- 843 loss : float 844 the average value of the loss function 845 metric : dict 846 a dictionary of average values of metric functions 847 """ 848 849 if type(data) is not DataLoader: 850 dataset = self._get_dataset(data) 851 data = DataLoader(dataset, shuffle=False, batch_size=batch_size) 852 epoch_loss = 0 853 if isinstance(prediction, dict): 854 num_classes = len(self.behaviors_dict()) 855 length = dataset.len_segment() 856 coords = dataset.annotation_store.get_original_coordinates() 857 for batch in data: 858 main_target = batch["target"] 859 pr_coords = coords[batch["index"]] 860 predicted = torch.zeros((len(pr_coords), num_classes, length)) 861 for i, c in enumerate(pr_coords): 862 video_id = dataset.input_store.get_video_id(c) 863 clip_id = dataset.input_store.get_clip_id(c) 864 start, end = dataset.input_store.get_clip_start_end(c) 865 beh_ind = list(prediction[video_id]["classes"].keys()) 866 pred_tmp = prediction[video_id][clip_id][beh_ind, :] 867 predicted[i, :, : end - start] = pred_tmp[:, start:end] 868 self._compute( 869 [], 870 [], 871 predicted, 872 main_target, 873 skip_loss=True, 874 tag=batch.get("tag"), 875 apply_primary_function=False, 876 ) 877 else: 878 for batch in data: 879 main_target = batch["target"] 880 predicted = prediction[batch["index"]] 881 if not indices is None: 882 indices_new = [indices.index(i) for i in range(len(indices))] 883 predicted = predicted[:, indices_new] 884 self._compute( 885 [], 886 [], 887 predicted, 888 main_target, 889 skip_loss=True, 890 tag=batch.get("tag"), 891 apply_primary_function=False, 892 ) 893 epoch_metrics = self._calculate_metrics() 894 # strings = [ 895 # f"{metric_name} {metric_value:.3f}" 896 # for metric_name, metric_value in epoch_metrics.items() 897 # ] 898 # val_string = ", ".join(sorted(strings)) 899 # print(val_string) 900 return epoch_loss, epoch_metrics 901 902 def evaluate( 903 self, 904 data: Union[DataLoader, BehaviorDataset, str] = None, 905 augment_n: int = 0, 906 batch_size: int = 32, 907 verbose: bool = True, 908 ) -> Tuple: 909 """Evaluate the Task model. 910 911 Parameters 912 ---------- 913 data : torch.utils.data.DataLoader | dlc2action.data.dataset.BehaviorDataset, optional 914 the data to evaluate on (if not provided, evaluate on the Task validation dataset) 915 augment_n : int, default 0 916 the number of augmentations to average results over 917 batch_size : int, default 32 918 the batch size 919 verbose : bool, default True 920 if True, the process is reported to standard output 921 922 Returns 923 ------- 924 loss : float 925 the average value of the loss function 926 ssl_loss : float 927 the average value of the SSL loss function 928 metric : dict 929 a dictionary of average values of metric functions 930 931 """ 932 if self.parallel and not isinstance(self.model, nn.DataParallel): 933 self.model = MyDataParallel(self.model) 934 self.model.to(self.device) 935 if type(data) is not DataLoader: 936 data = self._get_dataset(data) 937 data = DataLoader(data, shuffle=False, batch_size=batch_size) 938 with torch.no_grad(): 939 epoch_loss, epoch_ssl_loss, epoch_metrics = self._run_epoch( 940 dataloader=data, mode="val", augment_n=augment_n, verbose=verbose 941 ) 942 val_string = f"loss {epoch_loss:.4f}" 943 for metric_name, metric_value in sorted(epoch_metrics.items()): 944 val_string += f", {metric_name} {metric_value:.3f}" 945 print(val_string) 946 return epoch_loss, epoch_ssl_loss, epoch_metrics 947 948 def predict( 949 self, 950 data: Union[DataLoader, BehaviorDataset, str] = None, 951 raw_output: bool = False, 952 apply_primary_function: bool = True, 953 augment_n: int = 0, 954 batch_size: int = 32, 955 train_mode: bool = False, 956 to_ram: bool = False, 957 embedding: bool = False, 958 ) -> torch.Tensor: 959 """Make a prediction with the Task model. 960 961 Parameters 962 ---------- 963 data : torch.utils.data.DataLoader | dlc2action.data.dataset.BehaviorDataset | str, optional 964 the data to evaluate on (if not provided, evaluate on the Task validation dataset) 965 raw_output : bool, default False 966 if `True`, the raw predicted probabilities are returned 967 apply_primary_function : bool, default True 968 if `True`, the primary predict function is applied (to map the model output into a shape corresponding to 969 the input) 970 augment_n : int, default 0 971 the number of augmentations to average results over 972 batch_size : int, default 32 973 the batch size 974 train_mode : bool, default False 975 if `True`, the model is used in training mode (affects dropout and batch normalization layers) 976 to_ram : bool, default False 977 if `True`, the dataset will be loaded in RAM (this speeds up the calculations but can lead to crashes 978 if the dataset is too large) 979 embedding : bool, default False 980 if `True`, the output of feature extractor is returned, ignoring the prediction module of the model 981 982 Returns 983 ------- 984 prediction : torch.Tensor 985 a prediction for the input data 986 987 """ 988 if self.parallel and not isinstance(self.model, nn.DataParallel): 989 self.model = MyDataParallel(self.model) 990 self.model.to(self.device) 991 if train_mode: 992 self.model.train() 993 else: 994 self.model.eval() 995 output = [] 996 if embedding: 997 raw_output = True 998 apply_primary_function = True 999 if type(data) is not DataLoader: 1000 data = self._get_dataset(data) 1001 if to_ram: 1002 print("transferring dataset to RAM...") 1003 data.to_ram() 1004 data = DataLoader(data, shuffle=False, batch_size=batch_size) 1005 self.model.ssl_off() 1006 with torch.no_grad(): 1007 for batch in tqdm(data): 1008 input = {k: v.to(self.device) for k, v in batch["input"].items()} 1009 predicted, _, _ = self._get_prediction( 1010 input, 1011 batch.get("tag"), 1012 augment_n=augment_n, 1013 embedding=embedding, 1014 ) 1015 if apply_primary_function: 1016 predicted = self.primary_predict_function(predicted) 1017 if not raw_output: 1018 predicted = self.predict_function(predicted) 1019 output.append(predicted.detach().cpu()) 1020 self.model.ssl_on() 1021 output = torch.cat(output).detach() 1022 return output 1023 1024 def dataset(self, mode="train") -> BehaviorDataset: 1025 """Get a dataset. 1026 1027 Parameters 1028 ---------- 1029 mode : {'train', 'val', 'test} 1030 the dataset to get 1031 1032 Returns 1033 ------- 1034 dataset : dlc2action.data.dataset.BehaviorDataset 1035 the dataset 1036 1037 """ 1038 dataloader = self.dataloader(mode) 1039 if dataloader is None: 1040 raise ValueError("The length of the dataloader is 0!") 1041 return dataloader.dataset 1042 1043 def dataloader(self, mode: str = "train") -> DataLoader: 1044 """Get a dataloader. 1045 1046 Parameters 1047 ---------- 1048 mode : {'train', 'val', 'test} 1049 the dataset to get 1050 1051 Returns 1052 ------- 1053 dataloader : torch.utils.data.DataLoader 1054 the dataloader 1055 1056 """ 1057 if mode == "train": 1058 return self.train_dataloader 1059 elif mode == "val": 1060 return self.val_dataloader 1061 elif mode == "test": 1062 return self.test_dataloader 1063 else: 1064 raise ValueError( 1065 f'The {mode} mode is not recognized, please choose from "train", "val" or "test"' 1066 ) 1067 1068 def _get_dataset(self, dataset): 1069 """Get a dataset from a dataloader, a string ('train', 'test' or 'val') or `None` (default).""" 1070 if dataset is None: 1071 dataset = self.dataset("val") 1072 elif dataset in ["train", "val", "test"]: 1073 dataset = self.dataset(dataset) 1074 elif type(dataset) is DataLoader: 1075 dataset = dataset.dataset 1076 if type(dataset) is BehaviorDataset: 1077 return dataset 1078 else: 1079 raise TypeError(f"The {type(dataset)} type of dataset is not recognized!") 1080 1081 def _get_dataloader(self, dataset): 1082 """Get a dataloader from a dataset, a string ('train', 'test' or 'val') or `None` (default).""" 1083 if dataset is None: 1084 dataset = self.dataloader("val") 1085 elif dataset in ["train", "val", "test"]: 1086 dataset = self.dataloader(dataset) 1087 if dataset is None: 1088 raise ValueError(f"The length of the dataloader is 0!") 1089 elif type(dataset) is BehaviorDataset: 1090 dataset = DataLoader(dataset) 1091 if type(dataset) is DataLoader: 1092 return dataset 1093 else: 1094 raise TypeError(f"The {type(dataset)} type of dataset is not recognized!") 1095 1096 def generate_full_length_prediction( 1097 self, dataset=None, batch_size=32, augment_n=10 1098 ): 1099 """Compile a prediction for the original input sequences. 1100 1101 Parameters 1102 ---------- 1103 dataset : BehaviorDataset, optional 1104 the dataset to generate a prediction for (if `None`, generate for the validation dataset) 1105 batch_size : int, default 32 1106 the batch size 1107 augment_n : int, default 10 1108 the number of augmentations to average results over 1109 1110 Returns 1111 ------- 1112 prediction : dict 1113 a nested dictionary where first level keys are video ids, second level keys are clip ids and values 1114 are prediction tensors 1115 1116 """ 1117 dataset = self._get_dataset(dataset) 1118 if not isinstance(dataset, BehaviorDataset): 1119 raise TypeError( 1120 f"The dataset parameter has to be either None, string, " 1121 f"BehaviorDataset or Dataloader, got {type(dataset)}" 1122 ) 1123 predicted = self.predict( 1124 dataset, 1125 raw_output=True, 1126 apply_primary_function=True, 1127 augment_n=augment_n, 1128 batch_size=batch_size, 1129 ) 1130 predicted = dataset.generate_full_length_prediction(predicted) 1131 predicted = { 1132 v_id: { 1133 clip_id: self._apply_predict_functions(v.unsqueeze(0)).squeeze() 1134 for clip_id, v in video_dict.items() 1135 } 1136 for v_id, video_dict in predicted.items() 1137 } 1138 return predicted 1139 1140 def generate_submission( 1141 self, frame_number_map_file, dataset=None, batch_size=32, augment_n=10 1142 ): 1143 """Generate a MABe-22 style submission dictionary. 1144 1145 Parameters 1146 ---------- 1147 frame_number_map_file : str 1148 the path to the frame number map file 1149 dataset : BehaviorDataset, optional 1150 the dataset to generate a prediction for (if `None`, generate for the validation dataset) 1151 batch_size : int, default 32 1152 the batch size 1153 augment_n : int, default 10 1154 the number of augmentations to average results over 1155 1156 Returns 1157 ------- 1158 submission : dict 1159 a dictionary with frame number mapping and embeddings 1160 1161 """ 1162 dataset = self._get_dataset(dataset) 1163 if not isinstance(dataset, BehaviorDataset): 1164 raise TypeError( 1165 f"The dataset parameter has to be either None, string, " 1166 f"BehaviorDataset or Dataloader, got {type(dataset)}" 1167 ) 1168 predicted = self.predict( 1169 dataset, 1170 raw_output=True, 1171 apply_primary_function=True, 1172 augment_n=augment_n, 1173 batch_size=batch_size, 1174 embedding=True, 1175 ) 1176 predicted = dataset.generate_full_length_prediction(predicted) 1177 frame_map = np.load(frame_number_map_file, allow_pickle=True).item() 1178 length = frame_map[list(frame_map.keys())[-1]][1] 1179 embeddings = None 1180 for video_id in list(predicted.keys()): 1181 split = video_id.split("--") 1182 if len(split) != 2 or len(predicted[video_id]) > 1: 1183 raise RuntimeError( 1184 "Generating submissions is only implemented for the mabe22 dataset!" 1185 ) 1186 if split[1] not in frame_map: 1187 raise RuntimeError(f"The {split[1]} video is not in the frame map file") 1188 v_id = split[1] 1189 clip_id = list(predicted[video_id].keys())[0] 1190 if embeddings is None: 1191 embeddings = np.zeros((length, predicted[video_id][clip_id].shape[0])) 1192 start, end = frame_map[v_id] 1193 embeddings[start:end, :] = predicted[video_id][clip_id].T 1194 predicted.pop(video_id) 1195 predicted = { 1196 "frame_number_map": frame_map, 1197 "embeddings": embeddings.astype(np.float32), 1198 } 1199 return predicted 1200 1201 def _get_intervals(self, tensor: torch.Tensor) -> torch.Tensor: 1202 """Get a list of True group beginning and end indices from a boolean tensor.""" 1203 output, indices = torch.unique_consecutive(tensor, return_inverse=True) 1204 true_indices = torch.where(output)[0] 1205 starts = torch.tensor( 1206 [(indices == i).nonzero(as_tuple=True)[0][0] for i in true_indices] 1207 ) 1208 ends = torch.tensor( 1209 [(indices == i).nonzero(as_tuple=True)[0][-1] + 1 for i in true_indices] 1210 ) 1211 return torch.stack([starts, ends]).T 1212 1213 def _smooth(self, tensor: torch.Tensor, smooth_interval: int = 1) -> torch.Tensor: 1214 """Get rid of jittering in a non-exclusive classification tensor. 1215 1216 First, remove intervals of 0 shorter than `smooth_interval`. Then, remove intervals of 1 shorter than 1217 `smooth_interval`. 1218 1219 """ 1220 if len(tensor.shape) > 1: 1221 for c in tensor.shape[1]: 1222 intervals = self._get_intervals(tensor[:, c] == 0) 1223 interval_lengths = torch.tensor( 1224 [interval[1] - interval[0] for interval in intervals] 1225 ) 1226 short_intervals = intervals[interval_lengths <= smooth_interval] 1227 for start, end in short_intervals: 1228 tensor[start:end, c] = 1 1229 intervals = self._get_intervals(tensor[:, c] == 1) 1230 interval_lengths = torch.tensor( 1231 [interval[1] - interval[0] for interval in intervals] 1232 ) 1233 short_intervals = intervals[interval_lengths <= smooth_interval] 1234 for start, end in short_intervals: 1235 tensor[start:end, c] = 0 1236 else: 1237 for c in tensor.unique(): 1238 intervals = self._get_intervals(tensor == c) 1239 interval_lengths = torch.tensor( 1240 [interval[1] - interval[0] for interval in intervals] 1241 ) 1242 short_intervals = intervals[interval_lengths <= smooth_interval] 1243 for start, end in short_intervals: 1244 if start == 0: 1245 tensor[start:end] = tensor[end + 1] 1246 else: 1247 tensor[start:end] = tensor[start - 1] 1248 return tensor 1249 1250 def _visualize_results_single( 1251 self, 1252 behavior: str, 1253 save_path: str = None, 1254 add_legend: bool = True, 1255 ground_truth: bool = True, 1256 hide_axes: bool = False, 1257 width: int = 10, 1258 whole_video: bool = False, 1259 transparent: bool = False, 1260 dataset: BehaviorDataset = None, 1261 smooth_interval: int = 0, 1262 title: str = None, 1263 ): 1264 """Visualize random predictions. 1265 1266 Parameters 1267 ---------- 1268 behavior : str 1269 the behavior to visualize 1270 save_path : str, optional 1271 the path where the plot will be saved 1272 add_legend : bool, default True 1273 if `True`, legend will be added to the plot 1274 ground_truth : bool, default True 1275 if `True`, ground truth will be added to the plot 1276 colormap : str, default 'viridis' 1277 the `matplotlib` colormap to use 1278 hide_axes : bool, default True 1279 if `True`, the axes will be hidden on the plot 1280 min_classes : int, default 1 1281 the minimum number of classes in a displayed interval 1282 width : float, default 10 1283 the width of the plot 1284 whole_video : bool, default False 1285 if `True`, whole videos are plotted instead of segments 1286 transparent : bool, default False 1287 if `True`, the background on the plot is transparent 1288 dataset : BehaviorDataset, optional 1289 the dataset to make the prediction for (if not provided, the validation dataset is used) 1290 drop_classes : set, optional 1291 a set of class names to not be displayed 1292 smooth_interval : int, default 0 1293 the interval to smooth the predictions over 1294 title : str, optional 1295 the title of the plot 1296 1297 """ 1298 if title is None: 1299 title = "" 1300 dataset = self._get_dataset(dataset) 1301 inverse_dict = {v: k for k, v in dataset.behaviors_dict().items()} 1302 label_ind = inverse_dict[behavior] 1303 labels = {1: behavior, -100: "unknown"} 1304 label_keys = [1, -100] 1305 color_list = ["blue", "gray"] 1306 if whole_video: 1307 predicted = self.generate_full_length_prediction(dataset) 1308 keys = list(predicted.keys()) 1309 counter = 0 1310 if whole_video: 1311 max_iter = len(keys) * 5 1312 else: 1313 max_iter = len(dataset) * 5 1314 ok = False 1315 while not ok: 1316 counter += 1 1317 if counter > max_iter: 1318 raise RuntimeError( 1319 "Plotting is taking too many iterations; you should probably make some of the parameters less restrictive" 1320 ) 1321 if whole_video: 1322 i = randint(0, len(keys) - 1) 1323 prediction = predicted[keys[i]] 1324 keys_i = list(prediction.keys()) 1325 j = randint(0, len(keys_i) - 1) 1326 full_p = prediction[keys_i[j]] 1327 prediction = prediction[keys_i[j]][label_ind] 1328 else: 1329 dataloader = DataLoader(dataset) 1330 i = randint(0, len(dataloader) - 1) 1331 for num, batch in enumerate(dataloader): 1332 if num == i: 1333 break 1334 input_data = {k: v.to(self.device) for k, v in batch["input"].items()} 1335 prediction, *_ = self._get_prediction( 1336 input_data, batch.get("tag"), augment_n=5 1337 ) 1338 prediction = self._apply_predict_functions(prediction) 1339 j = randint(0, len(prediction) - 1) 1340 full_p = prediction[j] 1341 prediction = prediction[j][label_ind] 1342 classes = [x for x in torch.unique(prediction) if int(x) in label_keys] 1343 ok = 1 in classes 1344 fig, ax = plt.subplots(figsize=(width, 2)) 1345 for c in classes: 1346 c_i = label_keys.index(int(c)) 1347 output, indices, counts = torch.unique_consecutive( 1348 prediction == c, return_inverse=True, return_counts=True 1349 ) 1350 long_indices = torch.where(output)[0] 1351 res_indices_start = [ 1352 (indices == i).nonzero(as_tuple=True)[0][0].item() for i in long_indices 1353 ] 1354 res_indices_end = [ 1355 (indices == i).nonzero(as_tuple=True)[0][-1].item() + 1 1356 for i in long_indices 1357 ] 1358 res_indices_len = [ 1359 end - start for start, end in zip(res_indices_start, res_indices_end) 1360 ] 1361 ax.broken_barh( 1362 list(zip(res_indices_start, res_indices_len)), 1363 (0, 1), 1364 label=labels[int(c)], 1365 facecolors=color_list[c_i], 1366 ) 1367 if ground_truth: 1368 gt = batch["target"][j][label_ind].to(self.device) 1369 classes_gt = [x for x in torch.unique(gt) if int(x) in label_keys] 1370 for c in classes_gt: 1371 c_i = label_keys.index(int(c)) 1372 if c in classes: 1373 behavior = None 1374 else: 1375 behavior = labels[int(c)] 1376 output, indices, counts = torch.unique_consecutive( 1377 gt == c, return_inverse=True, return_counts=True 1378 ) 1379 long_indices = torch.where(output * (counts > 5))[0] 1380 res_indices_start = [ 1381 (indices == i).nonzero(as_tuple=True)[0][0].item() 1382 for i in long_indices 1383 ] 1384 res_indices_end = [ 1385 (indices == i).nonzero(as_tuple=True)[0][-1].item() + 1 1386 for i in long_indices 1387 ] 1388 res_indices_len = [ 1389 end - start 1390 for start, end in zip(res_indices_start, res_indices_end) 1391 ] 1392 ax.broken_barh( 1393 list(zip(res_indices_start, res_indices_len)), 1394 (1.5, 1), 1395 facecolors=color_list[c_i], 1396 label=behavior, 1397 ) 1398 self._compute( 1399 main_target=batch["target"][j].unsqueeze(0).to(self.device), 1400 predicted=full_p.unsqueeze(0).to(self.device), 1401 ssl_targets=[], 1402 ssl_predicted=[], 1403 skip_loss=True, 1404 ) 1405 metrics = self._calculate_metrics() 1406 if smooth_interval > 0: 1407 smoothed = self._smooth(full_p, smooth_interval=smooth_interval)[ 1408 label_ind, : 1409 ] 1410 for c in classes: 1411 c_i = label_keys.index(int(c)) 1412 output, indices, counts = torch.unique_consecutive( 1413 smoothed == c, return_inverse=True, return_counts=True 1414 ) 1415 long_indices = torch.where(output)[0] 1416 res_indices_start = [ 1417 (indices == i).nonzero(as_tuple=True)[0][0].item() 1418 for i in long_indices 1419 ] 1420 res_indices_end = [ 1421 (indices == i).nonzero(as_tuple=True)[0][-1].item() + 1 1422 for i in long_indices 1423 ] 1424 res_indices_len = [ 1425 end - start 1426 for start, end in zip(res_indices_start, res_indices_end) 1427 ] 1428 ax.broken_barh( 1429 list(zip(res_indices_start, res_indices_len)), 1430 (3, 1), 1431 label=labels[int(c)], 1432 facecolors=color_list[c_i], 1433 ) 1434 keys = list(metrics.keys()) 1435 for key in keys: 1436 if key.split("_")[-1] != (str(label_ind)): 1437 metrics.pop(key) 1438 title = [title] 1439 for key, value in metrics.items(): 1440 title.append(f"{'_'.join(key.split('_')[: -1])}: {value:.2f}") 1441 title = ", ".join(title) 1442 if not ground_truth: 1443 ax.axes.yaxis.set_visible(False) 1444 else: 1445 ax.set_yticks([0.5, 2]) 1446 ax.set_yticklabels(["prediction", "ground truth"]) 1447 if add_legend: 1448 ax.legend() 1449 if hide_axes: 1450 plt.axis("off") 1451 plt.title(title) 1452 plt.xlim((0, len(prediction))) 1453 if save_path is not None: 1454 plt.savefig(save_path, transparent=transparent) 1455 plt.show() 1456 1457 def visualize_results( 1458 self, 1459 save_path: str = None, 1460 add_legend: bool = True, 1461 ground_truth: bool = True, 1462 colormap: str = "viridis", 1463 hide_axes: bool = False, 1464 min_classes: int = 1, 1465 width: int = 10, 1466 whole_video: bool = False, 1467 transparent: bool = False, 1468 dataset: Union[BehaviorDataset, DataLoader, str, None] = None, 1469 drop_classes: Set = None, 1470 search_classes: Set = None, 1471 smooth_interval_prediction: int = None, 1472 font_size: float = None, 1473 num_plots: int = 1, 1474 window_size:int =400 1475 ): 1476 """Visualize random predictions. 1477 1478 Parameters 1479 ---------- 1480 save_path : str, optional 1481 the path where the plot will be saved 1482 add_legend : bool, default True 1483 if `True`, legend will be added to the plot 1484 ground_truth : bool, default True 1485 if `True`, ground truth will be added to the plot 1486 colormap : str, default 'Accent' 1487 the `matplotlib` colormap to use 1488 hide_axes : bool, default True 1489 if `True`, the axes will be hidden on the plot 1490 min_classes : int, default 1 1491 the minimum number of classes in a displayed interval 1492 width : float, default 10 1493 the width of the plot 1494 whole_video : bool, default False 1495 if `True`, whole videos are plotted instead of segments 1496 transparent : bool, default False 1497 if `True`, the background on the plot is transparent 1498 dataset : BehaviorDataset | DataLoader | str | None, optional 1499 the dataset to make the prediction for (if not provided, the validation dataset is used) 1500 drop_classes : set, optional 1501 a set of class names to not be displayed 1502 search_classes : set, optional 1503 if given, only intervals where at least one of the classes is in ground truth will be shown 1504 smooth_interval_prediction : int, optional 1505 if given, the prediction will be smoothed with a moving average of the given size 1506 1507 """ 1508 if drop_classes is None: 1509 drop_classes = [] 1510 dataset = self._get_dataset(dataset) 1511 if dataset.annotation_class() != "exclusive_classification": 1512 raise NotImplementedError( 1513 "Results visualisation is only implemented for exclusive classification datasets!" 1514 ) 1515 labels = { 1516 k: v for k, v in dataset.behaviors_dict().items() if v not in drop_classes 1517 } 1518 labels.update({-100: "unknown"}) 1519 label_keys = sorted([int(x) for x in labels.keys()]) 1520 if search_classes is None: 1521 ok = True 1522 else: 1523 ok = False 1524 classes = [] 1525 predicted = self.generate_full_length_prediction(dataset) 1526 keys = list(predicted.keys()) 1527 counter = 0 1528 max_iter = len(keys) * 2 1529 while len(classes) < min_classes or not ok: 1530 counter += 1 1531 if counter > max_iter: 1532 raise RuntimeError( 1533 "Plotting is taking too many iterations; you should probably make some of the parameters less restrictive" 1534 ) 1535 i = randint(0, len(keys) - 1) 1536 prediction = predicted[keys[i]] 1537 keys_i = list(prediction.keys()) 1538 j = randint(0, len(keys_i) - 1) 1539 prediction = prediction[keys_i[j]] 1540 key1 = keys[i] 1541 key2 = keys_i[j] 1542 1543 if smooth_interval_prediction > 0: 1544 unsmoothed_prediction = deepcopy(prediction) 1545 prediction = self._smooth(prediction, smooth_interval_prediction) 1546 height = 3 1547 else: 1548 height = 2 1549 classes = [ 1550 labels[int(x)] for x in torch.unique(prediction) if x in label_keys 1551 ] 1552 if search_classes is not None: 1553 ok = any([x in classes for x in search_classes]) 1554 fig, ax = plt.subplots(figsize=(width, height)) 1555 cmap = cm.get_cmap(colormap) if colormap != "dlc2action" else None 1556 color_list = ( 1557 [cmap(c) for c in np.linspace(0, 1, len(labels))] 1558 if colormap != "dlc2action" 1559 else dlc2action_colormaps["default"] 1560 ) 1561 1562 def _plot_prediction(prediction, height, set_labels=True): 1563 for c in label_keys: 1564 c_i = label_keys.index(int(c)) 1565 output, indices, counts = torch.unique_consecutive( 1566 prediction == c, return_inverse=True, return_counts=True 1567 ) 1568 long_indices = torch.where(output)[0] 1569 if len(long_indices) == 0: 1570 continue 1571 res_indices_start = [ 1572 (indices == i).nonzero(as_tuple=True)[0][0].item() 1573 for i in long_indices 1574 ] 1575 res_indices_end = [ 1576 (indices == i).nonzero(as_tuple=True)[0][-1].item() + 1 1577 for i in long_indices 1578 ] 1579 res_indices_len = [ 1580 end - start 1581 for start, end in zip(res_indices_start, res_indices_end) 1582 ] 1583 if set_labels: 1584 label = labels[int(c)] 1585 else: 1586 label = None 1587 ax.broken_barh( 1588 list(zip(res_indices_start, res_indices_len)), 1589 (height, 1), 1590 label=label, 1591 facecolors=color_list[c_i], 1592 ) 1593 1594 if smooth_interval_prediction > 0: 1595 _plot_prediction(unsmoothed_prediction, 0) 1596 _plot_prediction(prediction, 1.5, set_labels=False) 1597 gt_height = 3 1598 else: 1599 _plot_prediction(prediction, 0) 1600 gt_height = 1.5 1601 if ground_truth: 1602 gt = dataset.generate_full_length_gt()[key1][key2] 1603 for c in label_keys: 1604 c_i = label_keys.index(int(c)) 1605 if labels[int(c)] in classes: 1606 label = None 1607 else: 1608 label = labels[int(c)] 1609 output, indices, counts = torch.unique_consecutive( 1610 gt == c, return_inverse=True, return_counts=True 1611 ) 1612 long_indices = torch.where(output)[0] 1613 if len(long_indices) == 0: 1614 continue 1615 res_indices_start = [ 1616 (indices == i).nonzero(as_tuple=True)[0][0].item() 1617 for i in long_indices 1618 ] 1619 res_indices_end = [ 1620 (indices == i).nonzero(as_tuple=True)[0][-1].item() + 1 1621 for i in long_indices 1622 ] 1623 res_indices_len = [ 1624 end - start 1625 for start, end in zip(res_indices_start, res_indices_end) 1626 ] 1627 ax.broken_barh( 1628 list(zip(res_indices_start, res_indices_len)), 1629 (gt_height, 1), 1630 facecolors=color_list[c_i] if c != "unknown" else "gray", 1631 label=label, 1632 ) 1633 if not ground_truth: 1634 ax.axes.yaxis.set_visible(False) 1635 else: 1636 if smooth_interval_prediction > 0: 1637 ax.set_yticks([0.5, 2, 3.5]) 1638 ax.set_yticklabels(["prediction", "smoothed", "ground truth"]) 1639 else: 1640 ax.set_yticks([0.5, 2]) 1641 ax.set_yticklabels(["prediction", "ground truth"]) 1642 if add_legend: 1643 ax.legend(loc="center left", bbox_to_anchor=(1, 0.5)) 1644 if hide_axes: 1645 # plt.axis("off") 1646 plt.box(False) 1647 if font_size is not None: 1648 font = {"size": font_size} 1649 rc("font", **font) 1650 plt.title(f"{key1} -- {key2}") 1651 plt.tight_layout() 1652 for seed in range(num_plots): 1653 if not whole_video: 1654 ax.set_xlim((seed * window_size, (seed + 1) * window_size)) 1655 if save_path is not None: 1656 plt.savefig( 1657 save_path.replace(".svg", f"_{seed}_{key1} -- {key2}.svg"), transparent=transparent 1658 ) 1659 print(f"Saved in {save_path.replace('.svg', f'_{seed}_{key1} -- {key2}.svg')}") 1660 plt.show() 1661 plt.close() 1662 1663 def set_ssl_transformations(self, ssl_transformations): 1664 """Set SSL transformations. 1665 1666 Parameters 1667 ---------- 1668 ssl_transformations : list 1669 a list of callable SSL transformations 1670 1671 """ 1672 self.train_dataloader.dataset.set_ssl_transformations(ssl_transformations) 1673 if self.val_dataloader is not None: 1674 self.val_dataloader.dataset.set_ssl_transformations(ssl_transformations) 1675 1676 def set_ssl_losses(self, ssl_losses: list) -> None: 1677 """Set SSL losses. 1678 1679 Parameters 1680 ---------- 1681 ssl_losses : list 1682 a list of callable SSL losses 1683 1684 """ 1685 self.ssl_losses = ssl_losses 1686 1687 def set_log(self, log: str) -> None: 1688 """Set the log file. 1689 1690 Parameters 1691 ---------- 1692 log: str 1693 the mew log file path 1694 1695 """ 1696 self.log_file = log 1697 1698 def set_keep_target_none(self, keep_target_none: List) -> None: 1699 """Set the keep_target_none parameter of the transformer. 1700 1701 Parameters 1702 ---------- 1703 keep_target_none : list 1704 a list of bool values 1705 1706 """ 1707 self.transformer.keep_target_none = keep_target_none 1708 1709 def set_generate_ssl_input(self, generate_ssl_input: list) -> None: 1710 """Set the generate_ssl_input parameter of the transformer. 1711 1712 Parameters 1713 ---------- 1714 generate_ssl_input : list 1715 a list of bool values 1716 1717 """ 1718 self.transformer.generate_ssl_input = generate_ssl_input 1719 1720 def set_model(self, model: Model) -> None: 1721 """Set a new model. 1722 1723 Parameters 1724 ---------- 1725 model: Model 1726 the new model 1727 1728 """ 1729 self.epoch = 0 1730 self.model = model 1731 self.optimizer = self.optimizer_class( 1732 model.parameters(), lr=self.lr, weight_decay=self.weight_decay 1733 ) 1734 if self.model.process_labels: 1735 self.model.set_behaviors(list(self.behaviors_dict().values())) 1736 1737 def set_dataloaders( 1738 self, 1739 train_dataloader: DataLoader, 1740 val_dataloader: DataLoader = None, 1741 test_dataloader: DataLoader = None, 1742 ) -> None: 1743 """Set new dataloaders. 1744 1745 Parameters 1746 ---------- 1747 train_dataloader: torch.utils.data.DataLoader 1748 the new train dataloader 1749 val_dataloader : torch.utils.data.DataLoader 1750 the new validation dataloader 1751 test_dataloader : torch.utils.data.DataLoader 1752 the new test dataloader 1753 1754 """ 1755 self.train_dataloader = train_dataloader 1756 self.val_dataloader = val_dataloader 1757 self.test_dataloader = test_dataloader 1758 1759 def set_loss(self, loss: Callable) -> None: 1760 """Set new loss function. 1761 1762 Parameters 1763 ---------- 1764 loss: callable 1765 the new loss function 1766 1767 """ 1768 self.loss = loss 1769 1770 def set_metrics(self, metrics: dict) -> None: 1771 """Set new metric. 1772 1773 Parameters 1774 ---------- 1775 metrics : dict 1776 the new metric dictionary 1777 1778 """ 1779 self.metrics = metrics 1780 1781 def set_transformer(self, transformer: Transformer) -> None: 1782 """Set a new transformer. 1783 1784 Parameters 1785 ---------- 1786 transformer: Transformer 1787 the new transformer 1788 1789 """ 1790 self.transformer = transformer 1791 1792 def set_predict_functions( 1793 self, primary_predict_function: Callable, predict_function: Callable 1794 ) -> None: 1795 """Set new predict functions. 1796 1797 Parameters 1798 ---------- 1799 primary_predict_function : callable 1800 the new primary predict function 1801 predict_function : callable 1802 the new predict function 1803 1804 """ 1805 self.primary_predict_function = primary_predict_function 1806 self.predict_function = predict_function 1807 1808 def _set_pseudolabels(self): 1809 """Set pseudolabels.""" 1810 self.train_dataloader.dataset.set_unlabeled(True) 1811 predicted = self.predict( 1812 data=self.dataset("train"), 1813 raw_output=False, 1814 augment_n=self.augment_val, 1815 ssl_off=True, 1816 ) 1817 self.train_dataloader.dataset.set_annotation(predicted.detach()) 1818 1819 def _alpha(self, epoch): 1820 """Get the current pseudolabeling alpha parameter.""" 1821 if epoch <= self.T1: 1822 return 0 1823 elif epoch < self.T2: 1824 return self.alpha_f * (epoch - self.T1) / (self.T2 - self.T1) 1825 else: 1826 return self.alpha_f 1827 1828 def count_classes(self, bouts: bool = False) -> Dict: 1829 """Get a dictionary of class counts in different modes. 1830 1831 Parameters 1832 ---------- 1833 bouts : bool, default False 1834 if `True`, instead of frame counts segment counts are returned 1835 1836 Returns 1837 ------- 1838 class_counts : dict 1839 a dictionary where first-level keys are "train", "val" and "test", second-level keys are 1840 class names and values are class counts (in frames) 1841 1842 """ 1843 class_counts = {} 1844 for x in ["train", "val", "test"]: 1845 try: 1846 class_counts[x] = self.dataset(x).count_classes(bouts) 1847 except ValueError: 1848 class_counts[x] = {k: 0 for k in self.behaviors_dict().keys()} 1849 return class_counts 1850 1851 def behaviors_dict(self) -> Dict: 1852 """Get a behavior dictionary. 1853 1854 Keys are label indices and values are label names. 1855 1856 Returns 1857 ------- 1858 behaviors_dict : dict 1859 behavior dictionary 1860 1861 """ 1862 return self.dataset().behaviors_dict() 1863 1864 def update_parameters(self, parameters: Dict) -> None: 1865 """Update training parameters from a dictionary. 1866 1867 Parameters 1868 ---------- 1869 parameters : dict 1870 the update dictionary 1871 1872 """ 1873 self.lr = parameters.get("lr", self.lr) 1874 self.parallel = parameters.get("parallel", self.parallel) 1875 self.optimizer = self.optimizer_class(self.model.parameters(), lr=self.lr) 1876 self.verbose = parameters.get("verbose", self.verbose) 1877 self.device = parameters.get("device", self.device) 1878 if self.device == "auto": 1879 self.device = "cuda" if torch.cuda.is_available() else "cpu" 1880 self.augment_train = int(parameters.get("augment_train", self.augment_train)) 1881 self.augment_val = int(parameters.get("augment_val", self.augment_val)) 1882 ssl_weights = parameters.get("ssl_weights", self.ssl_weights) 1883 if ssl_weights is None: 1884 ssl_weights = [1 for _ in self.ssl_losses] 1885 if not isinstance(ssl_weights, Iterable): 1886 ssl_weights = [ssl_weights for _ in self.ssl_losses] 1887 self.ssl_weights = ssl_weights 1888 self.num_epochs = parameters.get("num_epochs", self.num_epochs) 1889 self.model_save_epochs = parameters.get( 1890 "model_save_epochs", self.model_save_epochs 1891 ) 1892 self.model_save_path = parameters.get("model_save_path", self.model_save_path) 1893 self.pseudolabel = parameters.get("pseudolabel", self.pseudolabel) 1894 self.T1 = int(parameters.get("pseudolabel_start", self.T1)) 1895 self.T2 = int(parameters.get("alpha_growth_stop", self.T2)) 1896 self.t = int(parameters.get("correction_interval", self.t)) 1897 self.alpha_f = parameters.get("pseudolabel_alpha_f", self.alpha_f) 1898 self.log_file = parameters.get("log_file", self.log_file) 1899 1900 def generate_uncertainty_score( 1901 self, 1902 classes: List, 1903 augment_n: int = 0, 1904 batch_size: int = 32, 1905 method: str = "least_confidence", 1906 predicted: torch.Tensor = None, 1907 behaviors_dict: Dict = None, 1908 ) -> Dict: 1909 """Generate frame-wise scores for active learning. 1910 1911 Parameters 1912 ---------- 1913 classes : list 1914 a list of class names or indices; their confidence scores will be computed separately and stacked 1915 augment_n : int, default 0 1916 the number of augmentations to average over 1917 batch_size : int, default 32 1918 the batch size 1919 method : {"least_confidence", "entropy"} 1920 the method used to calculate the scores from the probability predictions (`"least_confidence"`: `1 - p_i` if 1921 `p_i > 0.5` or `p_i` if `p_i < 0.5`; `"entropy"`: `- p_i * log(p_i) - (1 - p_i) * log(1 - p_i)`) 1922 predicted : torch.Tensor, default None 1923 if not `None`, the predicted tensor to use instead of predicting from the model 1924 behaviors_dict : dict, default None 1925 if not `None`, the behaviors dictionary to use instead of the one from the dataset 1926 1927 Returns 1928 ------- 1929 score_dicts : dict 1930 a nested dictionary where first level keys are video ids, second level keys are clip ids and values 1931 are score tensors 1932 1933 """ 1934 dataset = self.dataset("train") 1935 if behaviors_dict is None: 1936 behaviors_dict = self.behaviors_dict() 1937 if not isinstance(dataset, BehaviorDataset): 1938 raise TypeError( 1939 f"The dataset parameter has to be either None, string, " 1940 f"BehaviorDataset or Dataloader, got {type(dataset)}" 1941 ) 1942 if predicted is None: 1943 predicted = self.predict( 1944 dataset, 1945 raw_output=True, 1946 apply_primary_function=True, 1947 augment_n=augment_n, 1948 batch_size=batch_size, 1949 ) 1950 predicted = dataset.generate_full_length_prediction(predicted) 1951 if isinstance(classes[0], str): 1952 behaviors_dict_inverse = {v: k for k, v in behaviors_dict.items()} 1953 classes = [behaviors_dict_inverse[c] for c in classes] 1954 for v_id, v in predicted.items(): 1955 for clip_id, vv in v.items(): 1956 if method == "least_confidence": 1957 predicted[v_id][clip_id][vv > 0.5] = 1 - vv[vv > 0.5] 1958 elif method == "entropy": 1959 predicted[v_id][clip_id][vv != -100] = ( 1960 -vv * torch.log(vv) - (1 - vv) * torch.log(1 - vv) 1961 )[vv != -100] 1962 elif method == "random": 1963 predicted[v_id][clip_id] = torch.rand(vv.shape) 1964 else: 1965 raise ValueError( 1966 f"The {method} method is not recognized; please choose from ['least_confidence', 'entropy']" 1967 ) 1968 predicted[v_id][clip_id][vv == -100] = 0 1969 1970 predicted = { 1971 v_id: {clip_id: v[classes, :] for clip_id, v in video_dict.items()} 1972 for v_id, video_dict in predicted.items() 1973 } 1974 return predicted 1975 1976 def generate_bald_score( 1977 self, 1978 classes: List, 1979 augment_n: int = 0, 1980 batch_size: int = 32, 1981 num_models: int = 10, 1982 kernel_size: int = 11, 1983 ) -> Dict: 1984 """Generate frame-wise Bayesian Active Learning by Disagreement scores for active learning. 1985 1986 Parameters 1987 ---------- 1988 classes : list 1989 a list of class names or indices; their confidence scores will be computed separately and stacked 1990 augment_n : int, default 0 1991 the number of augmentations to average over 1992 batch_size : int, default 32 1993 the batch size 1994 num_models : int, default 10 1995 the number of dropout masks to apply 1996 kernel_size : int, default 11 1997 the size of the smoothing gaussian kernel 1998 1999 Returns 2000 ------- 2001 score_dicts : dict 2002 a nested dictionary where first level keys are video ids, second level keys are clip ids and values 2003 are score tensors 2004 2005 """ 2006 dataset = self.dataset("train") 2007 dataset = self._get_dataset(dataset) 2008 if not isinstance(dataset, BehaviorDataset): 2009 raise TypeError( 2010 f"The dataset parameter has to be either None, string, " 2011 f"BehaviorDataset or Dataloader, got {type(dataset)}" 2012 ) 2013 predictions = [] 2014 for _ in range(num_models): 2015 predicted = self.predict( 2016 dataset, 2017 raw_output=True, 2018 apply_primary_function=True, 2019 augment_n=augment_n, 2020 batch_size=batch_size, 2021 train_mode=True, 2022 ) 2023 predicted = dataset.generate_full_length_prediction(predicted) 2024 if isinstance(classes[0], str): 2025 behaviors_dict_inverse = { 2026 v: k for k, v in self.behaviors_dict().items() 2027 } 2028 classes = [behaviors_dict_inverse[c] for c in classes] 2029 for v_id, v in predicted.items(): 2030 for clip_id, vv in v.items(): 2031 vv[vv != -100] = (vv[vv != -100] > 0.5).int().float() 2032 predicted[v_id][clip_id] = vv 2033 predicted = { 2034 v_id: {clip_id: v[classes, :] for clip_id, v in video_dict.items()} 2035 for v_id, video_dict in predicted.items() 2036 } 2037 predictions.append(predicted) 2038 result = {v_id: {} for v_id in predictions[0]} 2039 r = range(-int(kernel_size / 2), int(kernel_size / 2) + 1) 2040 gauss = [1 / (1 * sqrt(2 * pi)) * exp(-float(x) ** 2 / (2 * 1**2)) for x in r] 2041 gauss = [x / sum(gauss) for x in gauss] 2042 kernel = torch.FloatTensor([[gauss]]) 2043 for v_id in predictions[0]: 2044 for clip_id in predictions[0][v_id]: 2045 consensus = ( 2046 ( 2047 torch.mean( 2048 torch.stack([x[v_id][clip_id] for x in predictions]), dim=0 2049 ) 2050 > 0.5 2051 ) 2052 .int() 2053 .float() 2054 ) 2055 consensus[predictions[0][v_id][clip_id] == -100] = -100 2056 result[v_id][clip_id] = torch.zeros(consensus.shape) 2057 for x in predictions: 2058 result[v_id][clip_id] += (x[v_id][clip_id] != consensus).int() 2059 result[v_id][clip_id] = result[v_id][clip_id] * 2 / num_models 2060 res = torch.zeros(result[v_id][clip_id].shape) 2061 for i in range(len(classes)): 2062 res[i, floor(kernel_size // 2) : -floor(kernel_size // 2)] = ( 2063 torch.nn.functional.conv1d( 2064 result[v_id][clip_id][i, :].unsqueeze(0).unsqueeze(0), 2065 kernel, 2066 )[0, ...] 2067 ) 2068 result[v_id][clip_id] = res 2069 return result 2070 2071 def get_normalization_stats(self) -> Optional[Dict]: 2072 """Get the normalization statistics of the dataset. 2073 2074 Returns 2075 ------- 2076 stats : dict 2077 a dictionary containing the mean and standard deviation of the dataset 2078 2079 """ 2080 return self.train_dataloader.dataset.stats
39class MyDataParallel(nn.DataParallel): 40 """A wrapper for nn.DataParallel that allows to call methods of the model.""" 41 42 def __init__(self, *args, **kwargs): 43 """Initialize the class.""" 44 super().__init__(*args, **kwargs) 45 self.process_labels = self.module.process_labels 46 47 def freeze_feature_extractor(self): 48 """Freeze the feature extractor.""" 49 self.module.freeze_feature_extractor() 50 51 def unfreeze_feature_extractor(self): 52 """Unfreeze the feature extractor.""" 53 self.module.unfreeze_feature_extractor() 54 55 def transform_labels(self, device): 56 """Transform labels to the device.""" 57 return self.module.transform_labels(device) 58 59 def logit_scale(self): 60 """Return the logit scale of the model.""" 61 return self.module.logit_scale() 62 63 def main_task_off(self): 64 """Turn off the main task training.""" 65 self.module.main_task_off() 66 67 def state_dict(self, *args, **kwargs): 68 """Return the state of the module.""" 69 return self.module.state_dict(*args, **kwargs) 70 71 def ssl_on(self): 72 """Turn on SSL training.""" 73 self.module.ssl_on() 74 75 def ssl_off(self): 76 """Turn off SSL training.""" 77 self.module.ssl_off() 78 79 def extract_features(self, x, start=0): 80 """Extract features from the model.""" 81 return self.module.extract_features(x, start)
A wrapper for nn.DataParallel that allows to call methods of the model.
47 def freeze_feature_extractor(self): 48 """Freeze the feature extractor.""" 49 self.module.freeze_feature_extractor()
Freeze the feature extractor.
51 def unfreeze_feature_extractor(self): 52 """Unfreeze the feature extractor.""" 53 self.module.unfreeze_feature_extractor()
Unfreeze the feature extractor.
55 def transform_labels(self, device): 56 """Transform labels to the device.""" 57 return self.module.transform_labels(device)
Transform labels to the device.
59 def logit_scale(self): 60 """Return the logit scale of the model.""" 61 return self.module.logit_scale()
Return the logit scale of the model.
63 def main_task_off(self): 64 """Turn off the main task training.""" 65 self.module.main_task_off()
Turn off the main task training.
84class Task: 85 """A universal trainer class that performs training, evaluation and prediction for all types of tasks and data.""" 86 87 def __init__( 88 self, 89 train_dataloader: DataLoader, 90 model: Union[nn.Module, Model], 91 loss: Callable[[torch.Tensor, torch.Tensor], float], 92 num_epochs: int = 0, 93 transformer: Transformer = None, 94 ssl_losses: List = None, 95 ssl_weights: List = None, 96 lr: float = 1e-3, 97 weight_decay: float = 0, 98 metrics: Dict = None, 99 val_dataloader: DataLoader = None, 100 test_dataloader: DataLoader = None, 101 optimizer: Optimizer = None, 102 device: str = "cuda", 103 verbose: bool = True, 104 log_file: Union[str, None] = None, 105 augment_train: int = 1, 106 augment_val: int = 0, 107 validation_interval: int = 1, 108 predict_function: Union[Callable[[torch.Tensor], torch.Tensor], None] = None, 109 primary_predict_function: Callable = None, 110 exclusive: bool = True, 111 ignore_tags: bool = True, 112 threshold: float = 0.5, 113 model_save_path: str = None, 114 model_save_epochs: int = 5, 115 pseudolabel: bool = False, 116 pseudolabel_start: int = 100, 117 correction_interval: int = 2, 118 pseudolabel_alpha_f: float = 3, 119 alpha_growth_stop: int = 600, 120 parallel: bool = False, 121 skip_metrics: List = None, 122 ) -> None: 123 """Initialize the class. 124 125 Parameters 126 ---------- 127 train_dataloader : torch.utils.data.DataLoader 128 a training dataloader 129 model : dlc2action.model.base_model.Model 130 a model 131 loss : callable 132 a loss function 133 num_epochs : int, default 0 134 the number of epochs 135 transformer : dlc2action.transformer.base_transformer.Transformer, optional 136 a transformer 137 ssl_losses : list, optional 138 a list of SSL losses 139 ssl_weights : list, optional 140 a list of SSL weights (if not provided initializes to 1) 141 lr : float, default 1e-3 142 learning rate 143 weight_decay : float, default 0 144 weight decay 145 metrics : dict, optional 146 a list of metric functions 147 val_dataloader : torch.utils.data.DataLoader, optional 148 a validation dataloader 149 test_dataloader : torch.utils.data.DataLoader, optional 150 a test dataloader 151 optimizer : torch.optim.Optimizer, optional 152 an optimizer (`Adam` by default) 153 device : str, default 'cuda' 154 the device to train the model on 155 verbose : bool, default True 156 if `True`, the process is described in standard output 157 log_file : str, optional 158 the path to a text file where the process will be logged 159 augment_train : {1, 0} 160 number of augmentations to apply at training 161 augment_val : int, default 0 162 number of augmentations to apply at validation 163 validation_interval : int, default 1 164 every time this number of epochs passes, validation metrics are computed 165 predict_function : callable, optional 166 a function that maps probabilities to class predictions (if not provided, a default is generated) 167 primary_predict_function : callable, optional 168 a function that maps model output to probabilities (if not provided, initialized as identity) 169 exclusive : bool, default True 170 set to False for multi-label classification 171 ignore_tags : bool, default False 172 if `True`, samples with different meta tags will be mixed in batches 173 threshold : float, default 0.5 174 the threshold used for multi-label classification default prediction function 175 model_save_path : str, optional 176 the path to the folder where model checkpoints will be saved (checkpoints will not be saved if the path 177 is not provided) 178 model_save_epochs : int, default 5 179 the interval for saving the model checkpoints (the last epoch is always saved) 180 pseudolabel : bool, default False 181 if True, the pseudolabeling procedure will be applied 182 pseudolabel_start : int, default 100 183 pseudolabeling starts after this epoch 184 correction_interval : int, default 1 185 after this number of epochs, if the pseudolabeling is on, the model is trained on the labeled data and 186 new pseudolabels are generated 187 pseudolabel_alpha_f : float, default 3 188 the maximum value of pseudolabeling alpha 189 alpha_growth_stop : int, default 600 190 pseudolabeling alpha stops growing after this epoch 191 parallel : bool, default False 192 if True, the model is trained on multiple GPUs 193 skip_metrics : list, optional 194 a list of metrics to skip 195 196 """ 197 # pseudolabeling might be buggy right now -- not using it! 198 if skip_metrics is None: 199 skip_metrics = [] 200 self.train_dataloader = train_dataloader 201 self.val_dataloader = val_dataloader 202 self.test_dataloader = test_dataloader 203 self.transformer = transformer 204 self.num_epochs = num_epochs 205 self.skip_metrics = skip_metrics 206 self.verbose = verbose 207 self.augment_train = int(augment_train) 208 self.augment_val = int(augment_val) 209 self.ignore_tags = ignore_tags 210 self.validation_interval = int(validation_interval) 211 self.log_file = log_file 212 self.loss = loss 213 self.model_save_path = model_save_path 214 self.model_save_epochs = model_save_epochs 215 self.epoch = 0 216 217 if metrics is None: 218 metrics = {} 219 self.metrics = metrics 220 221 if optimizer is None: 222 optimizer = Adam 223 224 if ssl_weights is None: 225 ssl_weights = [1 for _ in ssl_losses] 226 if not isinstance(ssl_weights, Iterable): 227 ssl_weights = [ssl_weights for _ in ssl_losses] 228 self.ssl_weights = ssl_weights 229 230 self.optimizer_class = optimizer 231 self.lr = lr 232 self.weight_decay = weight_decay 233 if not isinstance(model, Model): 234 self.model = LoadedModel(model=model) 235 else: 236 self.set_model(model) 237 self.parallel = parallel 238 239 if self.transformer is None: 240 self.augment_val = 0 241 self.augment_train = 0 242 self.transformer = EmptyTransformer() 243 244 if self.augment_train > 1: 245 warnings.warn( 246 'The "augment_train" parameter is too large -> setting it to 1.' 247 ) 248 self.augment_train = 1 249 250 try: 251 if device == "auto": 252 device = "cuda" if torch.cuda.is_available() else "cpu" 253 self.device = torch.device(device) 254 except: 255 raise ("The format of the device is incorrect") 256 257 if ssl_losses is None: 258 self.ssl_losses = [lambda x, y: 0] 259 else: 260 self.ssl_losses = ssl_losses 261 262 if primary_predict_function is None: 263 if exclusive: 264 primary_predict_function = lambda x: nn.Softmax(x, dim=1) 265 else: 266 primary_predict_function = lambda x: torch.sigmoid(x) 267 self.primary_predict_function = primary_predict_function 268 269 if predict_function is None: 270 if exclusive: 271 self.predict_function = lambda x: torch.max(x.data, 1)[1] 272 else: 273 self.predict_function = lambda x: (x > threshold).int() 274 else: 275 self.predict_function = predict_function 276 277 self.pseudolabel = pseudolabel 278 self.alpha_f = pseudolabel_alpha_f 279 self.T2 = alpha_growth_stop 280 self.T1 = pseudolabel_start 281 self.t = correction_interval 282 if self.T2 <= self.T1: 283 raise ValueError( 284 f"The pseudolabel_start parameter has to be smaller than alpha_growth_stop; got " 285 f"{pseudolabel_start=} and {alpha_growth_stop=}" 286 ) 287 self.decision_thresholds = [0.5 for x in self.behaviors_dict()] 288 289 def save_checkpoint(self, checkpoint_path: str) -> None: 290 """Save a general checkpoint. 291 292 Parameters 293 ---------- 294 checkpoint_path : str 295 the path where the checkpoint will be saved 296 297 """ 298 torch.save( 299 { 300 "epoch": self.epoch, 301 "model_state_dict": self.model.state_dict(), 302 "optimizer_state_dict": self.optimizer.state_dict(), 303 }, 304 checkpoint_path, 305 ) 306 307 def load_from_checkpoint( 308 self, checkpoint_path, only_model: bool = False, load_strict: bool = True 309 ) -> None: 310 """Load from a checkpoint. 311 312 Parameters 313 ---------- 314 checkpoint_path : str 315 the path to the checkpoint 316 only_model : bool, default False 317 if `True`, only the model state dictionary will be loaded (and not the epoch and the optimizer state 318 dictionary) 319 load_strict : bool, default True 320 if `True`, any inconsistencies in state dictionaries are regarded as errors 321 322 """ 323 if checkpoint_path is None: 324 return 325 checkpoint = torch.load( 326 checkpoint_path, map_location=self.device, weights_only=False 327 ) 328 self.model.to(self.device) 329 self.model.load_state_dict(checkpoint["model_state_dict"], strict=load_strict) 330 if not only_model: 331 self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) 332 self.epoch = checkpoint["epoch"] 333 334 def save_model(self, save_path: str) -> None: 335 """Save the model state dictionary. 336 337 Parameters 338 ---------- 339 save_path : str 340 the path where the state will be saved 341 342 """ 343 torch.save(self.model.state_dict(), save_path) 344 print("saved the model successfully") 345 346 def _apply_predict_functions(self, predicted): 347 """Map from model output to prediction.""" 348 predicted = self.primary_predict_function(predicted) 349 predicted = self.predict_function(predicted) 350 return predicted 351 352 def _get_prediction( 353 self, 354 main_input: Dict, 355 tags: torch.Tensor, 356 ssl_inputs: List = None, 357 ssl_targets: List = None, 358 augment_n: int = 0, 359 embedding: bool = False, 360 subsample: List = None, 361 ) -> Tuple: 362 """Get the prediction of `self.model` for input averaged over `augment_n` augmentations.""" 363 if augment_n == 0: 364 augment_n = 1 365 augment = False 366 else: 367 augment = True 368 model_input, ssl_inputs, ssl_targets = self.transformer.transform( 369 deepcopy(main_input), 370 ssl_inputs=ssl_inputs, 371 ssl_targets=ssl_targets, 372 augment=augment, 373 subsample=subsample, 374 ) 375 if not embedding: 376 # t1 = time() 377 predicted, ssl_predicted = self.model(model_input, ssl_inputs, tag=tags) 378 # t2 = time() 379 # with open("/home/andy/Documents/EPFLsmartkitchenrecording/baselines/DLC2Action_baselines/inference_times/body_hand_eyes_2405.txt", "a") as f: 380 # f.write(f"Time: {t2 - t1}\n") 381 else: 382 predicted = self.model.extract_features(model_input) 383 ssl_predicted = None 384 if self.parallel and predicted is not None and len(predicted.shape) == 4: 385 predicted = predicted.reshape( 386 (-1, model_input.shape[0], *predicted.shape[2:]) 387 ) 388 if predicted is not None and augment_n > 1: 389 self.model.ssl_off() 390 for i in range(augment_n - 1): 391 model_input, *_ = self.transformer.transform( 392 deepcopy(main_input), augment=augment 393 ) 394 if not embedding: 395 pred, _ = self.model(model_input, None, tag=tags) 396 else: 397 pred = self.model.extract_features(model_input) 398 predicted += pred.detach() 399 self.model.ssl_on() 400 predicted /= augment_n 401 if self.model.process_labels: 402 class_embedding = self.model.transform_labels(predicted.device) 403 beh_dict = self.behaviors_dict() 404 class_embedding = torch.cat( 405 [class_embedding[beh_dict[k]] for k in sorted(beh_dict.keys())], 0 406 ) 407 predicted = { 408 "video": predicted, 409 "text": class_embedding, 410 "logit_scale": self.model.logit_scale(), 411 "device": predicted.device, 412 } 413 return predicted, ssl_predicted, ssl_targets 414 415 def _ssl_loss(self, ssl_predicted, ssl_targets): 416 """Apply SSL losses.""" 417 if self.ssl_losses is None or ssl_predicted is None or ssl_targets is None: 418 return [] 419 420 ssl_loss = [] 421 for loss, predicted, target in zip(self.ssl_losses, ssl_predicted, ssl_targets): 422 ssl_loss.append(loss(predicted, target)) 423 return ssl_loss 424 425 def _loss_function( 426 self, 427 batch: Dict, 428 augment_n: int, 429 temporal_subsampling_size: int = None, 430 skip_metrics: List = None, 431 ) -> Tuple[float, float]: 432 """Calculate the loss function and the metric for a dataloader batch. 433 434 Averaging the predictions over augment_n augmentations. 435 436 """ 437 if "target" not in batch or torch.isnan(batch["target"]).all(): 438 raise ValueError("Cannot compute loss function with nan targets!") 439 main_input = {k: v.to(self.device) for k, v in batch["input"].items()} 440 main_target, ssl_targets, ssl_inputs = None, None, None 441 main_target = batch["target"].to(self.device) 442 if "ssl_targets" in batch: 443 ssl_targets = [ 444 ( 445 {k: v.to(self.device) for k, v in x.items()} 446 if isinstance(x, dict) 447 else None 448 ) 449 for x in batch["ssl_targets"] 450 ] 451 if "ssl_inputs" in batch: 452 ssl_inputs = [ 453 ( 454 {k: v.to(self.device) for k, v in x.items()} 455 if isinstance(x, dict) 456 else None 457 ) 458 for x in batch["ssl_inputs"] 459 ] 460 if temporal_subsampling_size is not None: 461 subsample = sorted( 462 random.sample( 463 range(main_target.shape[-1]), 464 int(temporal_subsampling_size * main_target.shape[-1]), 465 ) 466 ) 467 main_target = main_target[..., subsample] 468 else: 469 subsample = None 470 471 if self.ignore_tags: 472 tag = None 473 else: 474 tag = batch.get("tag") 475 predicted, ssl_predicted, ssl_targets = self._get_prediction( 476 main_input, tag, ssl_inputs, ssl_targets, augment_n, subsample=subsample 477 ) 478 del main_input, ssl_inputs 479 return self._compute( 480 ssl_predicted, 481 ssl_targets, 482 predicted, 483 main_target, 484 tag=batch.get("tag"), 485 skip_metrics=skip_metrics, 486 ) 487 488 def _compute( 489 self, 490 ssl_predicted: List, 491 ssl_targets: List, 492 predicted: torch.Tensor, 493 main_target: torch.Tensor, 494 tag: Any = None, 495 skip_loss: bool = False, 496 apply_primary_function: bool = True, 497 skip_metrics: List = None, 498 ) -> Tuple[float, float]: 499 """Compute the losses and metrics from predictions.""" 500 if skip_metrics is None: 501 skip_metrics = [] 502 if not skip_loss: 503 ssl_losses = self._ssl_loss(ssl_predicted, ssl_targets) 504 if predicted is not None: 505 loss = self.loss(predicted, main_target) 506 else: 507 loss = 0 508 else: 509 ssl_losses, loss = [], 0 510 511 if predicted is not None: 512 if isinstance(predicted, dict): 513 predicted = { 514 k: v.detach() 515 for k, v in predicted.items() 516 if isinstance(v, torch.Tensor) 517 } 518 else: 519 predicted = predicted.detach() 520 if apply_primary_function: 521 predicted = self.primary_predict_function(predicted) 522 predicted_transformed = self.predict_function(predicted) 523 524 for name, metric_function in self.metrics.items(): 525 if name not in skip_metrics: 526 if metric_function.needs_raw_data: 527 metric_function.update(predicted, main_target, tag) 528 else: 529 metric_function.update( 530 predicted_transformed, 531 main_target, 532 tag, 533 ) 534 return loss, ssl_losses 535 536 def _calculate_metrics(self) -> Dict: 537 """Calculate the final values of epoch metrics.""" 538 epoch_metrics = {} 539 for metric_name, metric in self.metrics.items(): 540 m = metric.calculate() 541 if type(m) is dict: 542 for k, v in m.items(): 543 if type(v) is torch.Tensor: 544 v = v.item() 545 epoch_metrics[f"{metric_name}_{k}"] = v 546 else: 547 if type(m) is torch.Tensor: 548 m = m.item() 549 epoch_metrics[metric_name] = m 550 metric.reset() 551 # print("calculate metric", epoch_metrics) 552 return epoch_metrics 553 554 def _run_epoch( 555 self, 556 dataloader: DataLoader, 557 mode: str, 558 augment_n: int, 559 verbose: bool = False, 560 unlabeled: bool = None, 561 alpha: float = 1, 562 temporal_subsampling_size: int = None, 563 ) -> Tuple: 564 """Run one epoch on dataloader. 565 566 Averaging the predictions over augment_n augmentations. 567 Use "train" mode for training and "val" mode for evaluation. 568 569 """ 570 if mode == "train": 571 self.model.train() 572 elif mode == "val": 573 self.model.eval() 574 pass 575 else: 576 raise ValueError( 577 f'Mode {mode} is not recognized, please choose either "train" for training or "val" for validation' 578 ) 579 if self.ignore_tags: 580 tags = [None] 581 else: 582 tags = dataloader.dataset.get_tags() 583 epoch_loss = 0 584 epoch_ssl_loss = defaultdict(lambda: 0) 585 data_len = 0 586 set_pars = dataloader.dataset.set_indexing_parameters 587 skip_metrics = self.skip_metrics if mode == "train" else None 588 for tag in tags: 589 set_pars(unlabeled=unlabeled, tag=tag) 590 data_len += len(dataloader) 591 if verbose: 592 dataloader = tqdm(dataloader) 593 for batch in dataloader: 594 loss, ssl_losses = self._loss_function( 595 batch, 596 augment_n, 597 temporal_subsampling_size=temporal_subsampling_size, 598 skip_metrics=skip_metrics, 599 ) 600 if loss != 0: 601 loss = loss * alpha 602 epoch_loss += loss.item() 603 for i, (ssl_loss, weight) in enumerate( 604 zip(ssl_losses, self.ssl_weights) 605 ): 606 if ssl_loss != 0: 607 epoch_ssl_loss[i] += ssl_loss.item() 608 loss = loss + weight * ssl_loss 609 if mode == "train": 610 self.optimizer.zero_grad() 611 if loss.requires_grad: 612 loss.backward() 613 self.optimizer.step() 614 615 epoch_loss = epoch_loss / data_len 616 epoch_ssl_loss = {k: v / data_len for k, v in epoch_ssl_loss.items()} 617 epoch_metrics = self._calculate_metrics() 618 619 return epoch_loss, epoch_ssl_loss, epoch_metrics 620 621 def train( 622 self, 623 trial: Trial = None, 624 optimized_metric: str = None, 625 to_ram: bool = False, 626 autostop_interval: int = 30, 627 autostop_threshold: float = 0.001, 628 autostop_metric: str = None, 629 main_task_on: bool = True, 630 ssl_on: bool = True, 631 temporal_subsampling_size: int = None, 632 loading_bar: bool = False, 633 ) -> Tuple: 634 """Train the task and return a log of epoch-average loss and metric. 635 636 You can use the autostop parameters to finish training when the parameters are not improving. It will be 637 stopped if the average value of `autostop_metric` over the last `autostop_interval` epochs is smaller than 638 the average over the previous `autostop_interval` epochs + `autostop_threshold`. For example, if the 639 current epoch is 120 and `autostop_interval` is 50, the averages over epochs 70-120 and 20-70 will be compared. 640 641 Parameters 642 ---------- 643 trial : Trial 644 an `optuna` trial (for hyperparameter searches) 645 optimized_metric : str 646 the name of the metric being optimized (for hyperparameter searches) 647 to_ram : bool, default False 648 if `True`, the dataset will be loaded in RAM (this speeds up the calculations but can lead to crashes 649 if the dataset is too large) 650 autostop_interval : int, default 50 651 the number of epochs to average the autostop metric over 652 autostop_threshold : float, default 0.001 653 the autostop difference threshold 654 autostop_metric : str, optional 655 the autostop metric (can be any one of the tracked metrics of `'loss'`) 656 main_task_on : bool, default True 657 if `False`, the main task (action segmentation) will not be used in training 658 ssl_on : bool, default True 659 if `False`, the SSL task will not be used in training 660 temporal_subsampling_size : int, optional 661 if not `None`, the temporal subsampling will be used in training with the given size 662 loading_bar : bool, default False 663 if `True`, a loading bar will be displayed 664 665 Returns 666 ------- 667 loss_log: list 668 a list of float loss function values for each epoch 669 metrics_log: dict 670 a dictionary of metric value logs (first-level keys are 'train' and 'val', second-level keys are metric 671 names, values are lists of function values) 672 673 """ 674 if self.parallel and not isinstance(self.model, nn.DataParallel): 675 self.model = MyDataParallel(self.model) 676 self.model.to(self.device) 677 assert autostop_metric in [None, "loss"] + list(self.metrics) 678 autostop_interval //= self.validation_interval 679 if trial is not None and optimized_metric is None: 680 raise ValueError( 681 "You need to provide the optimized metric name (optimized_metric parameter) " 682 "for optuna pruning to work!" 683 ) 684 if to_ram: 685 print("transferring datasets to RAM...") 686 self.train_dataloader.dataset.to_ram() 687 if self.val_dataloader is not None and len(self.val_dataloader) > 0: 688 self.val_dataloader.dataset.to_ram() 689 loss_log = {"train": [], "val": []} 690 metrics_log = {"train": defaultdict(lambda: []), "val": defaultdict(lambda: [])} 691 if not main_task_on: 692 self.model.main_task_off() 693 if not ssl_on: 694 self.model.ssl_off() 695 while self.epoch < self.num_epochs: 696 self.epoch += 1 697 unlabeled = None 698 alpha = 1 699 if self.pseudolabel: 700 if self.epoch >= self.T1: 701 unlabeled = (self.epoch - self.T1) % self.t != 0 702 if unlabeled: 703 alpha = self._alpha(self.epoch) 704 else: 705 unlabeled = False 706 epoch_loss, epoch_ssl_loss, epoch_metrics = self._run_epoch( 707 dataloader=self.train_dataloader, 708 mode="train", 709 augment_n=self.augment_train, 710 unlabeled=unlabeled, 711 alpha=alpha, 712 temporal_subsampling_size=temporal_subsampling_size, 713 verbose=loading_bar, 714 ) 715 loss_log["train"].append(epoch_loss) 716 epoch_string = f"[epoch {self.epoch}]" 717 if self.pseudolabel: 718 if unlabeled: 719 epoch_string += " (unlabeled)" 720 else: 721 epoch_string += " (labeled)" 722 epoch_string += f": loss {epoch_loss:.4f}" 723 724 if len(epoch_ssl_loss) != 0: 725 for key, value in sorted(epoch_ssl_loss.items()): 726 metrics_log["train"][f"ssl_loss_{key}"].append(value) 727 epoch_string += f", ssl_loss_{key} {value:.4f}" 728 729 for metric_name, metric_value in sorted(epoch_metrics.items()): 730 if metric_name not in self.skip_metrics: 731 if isinstance(metric_value, list): 732 metric_value = torch.mean(torch.Tensor(metric_value)) 733 epoch_string += f", {metric_name} {metric_value:.3f}" 734 metrics_log["train"][metric_name].append(metric_value) 735 736 if ( 737 self.val_dataloader is not None 738 and self.epoch % self.validation_interval == 0 739 ): 740 with torch.no_grad(): 741 epoch_string += "\n" 742 ( 743 val_epoch_loss, 744 val_epoch_ssl_loss, 745 val_epoch_metrics, 746 ) = self._run_epoch( 747 dataloader=self.val_dataloader, 748 mode="val", 749 augment_n=self.augment_val, 750 ) 751 loss_log["val"].append(val_epoch_loss) 752 epoch_string += f"validation: loss {val_epoch_loss:.4f}" 753 754 if len(val_epoch_ssl_loss) != 0: 755 for key, value in sorted(val_epoch_ssl_loss.items()): 756 metrics_log["val"][f"ssl_loss_{key}"].append(value) 757 epoch_string += f", ssl_loss_{key} {value:.4f}" 758 759 for metric_name, metric_value in sorted(val_epoch_metrics.items()): 760 if isinstance(metric_value, list): 761 metric_value = torch.mean(torch.Tensor(metric_value)) 762 metrics_log["val"][metric_name].append(metric_value) 763 epoch_string += f", {metric_name} {metric_value:.3f}" 764 765 if trial is not None: 766 if optimized_metric not in metrics_log["val"]: 767 raise ValueError( 768 f"The {optimized_metric} metric set for optimization is not being logged!" 769 ) 770 trial.report(metrics_log["val"][optimized_metric][-1], self.epoch) 771 if trial.should_prune(): 772 raise TrialPruned() 773 774 if self.verbose: 775 print(epoch_string) 776 777 if self.log_file is not None: 778 with open(self.log_file, "a") as f: 779 f.write(epoch_string + "\n") 780 781 save_condition = ( 782 (self.model_save_epochs != 0) 783 and (self.epoch % self.model_save_epochs == 0) 784 ) or (self.epoch == self.num_epochs) 785 786 if self.epoch > 0 and save_condition and self.model_save_path is not None: 787 epoch_s = str(self.epoch).zfill(len(str(self.num_epochs))) 788 self.save_checkpoint( 789 os.path.join(self.model_save_path, f"epoch{epoch_s}.pt") 790 ) 791 792 if self.pseudolabel and self.epoch >= self.T1 and not unlabeled: 793 self._set_pseudolabels() 794 795 if autostop_metric == "loss": 796 if len(loss_log["val"]) > autostop_interval * 2: 797 if ( 798 np.mean(loss_log["val"][-autostop_interval:]) 799 < np.mean( 800 loss_log["val"][-2 * autostop_interval : -autostop_interval] 801 ) 802 + autostop_threshold 803 ): 804 break 805 elif autostop_metric in metrics_log["val"]: 806 if len(metrics_log["val"][autostop_metric]) > autostop_interval * 2: 807 if ( 808 np.mean( 809 metrics_log["val"][autostop_metric][-autostop_interval:] 810 ) 811 < np.mean( 812 metrics_log["val"][autostop_metric][ 813 -2 * autostop_interval : -autostop_interval 814 ] 815 ) 816 + autostop_threshold 817 ): 818 break 819 820 metrics_log = {k: dict(v) for k, v in metrics_log.items()} 821 822 return loss_log, metrics_log 823 824 def evaluate_prediction( 825 self, 826 prediction: Union[torch.Tensor, Dict], 827 data: Union[DataLoader, BehaviorDataset, str] = None, 828 batch_size: int = 32, 829 indices: list = None, 830 ) -> Tuple: 831 """Compute metrics for a prediction. 832 833 Parameters 834 ---------- 835 prediction : torch.Tensor 836 the prediction 837 data : torch.utils.data.DataLoader | dlc2action.data.dataset.BehaviorDataset, optional 838 the data the prediction was made for (if not provided, take the validation dataset) 839 batch_size : int, default 32 840 the batch size 841 842 Returns 843 ------- 844 loss : float 845 the average value of the loss function 846 metric : dict 847 a dictionary of average values of metric functions 848 """ 849 850 if type(data) is not DataLoader: 851 dataset = self._get_dataset(data) 852 data = DataLoader(dataset, shuffle=False, batch_size=batch_size) 853 epoch_loss = 0 854 if isinstance(prediction, dict): 855 num_classes = len(self.behaviors_dict()) 856 length = dataset.len_segment() 857 coords = dataset.annotation_store.get_original_coordinates() 858 for batch in data: 859 main_target = batch["target"] 860 pr_coords = coords[batch["index"]] 861 predicted = torch.zeros((len(pr_coords), num_classes, length)) 862 for i, c in enumerate(pr_coords): 863 video_id = dataset.input_store.get_video_id(c) 864 clip_id = dataset.input_store.get_clip_id(c) 865 start, end = dataset.input_store.get_clip_start_end(c) 866 beh_ind = list(prediction[video_id]["classes"].keys()) 867 pred_tmp = prediction[video_id][clip_id][beh_ind, :] 868 predicted[i, :, : end - start] = pred_tmp[:, start:end] 869 self._compute( 870 [], 871 [], 872 predicted, 873 main_target, 874 skip_loss=True, 875 tag=batch.get("tag"), 876 apply_primary_function=False, 877 ) 878 else: 879 for batch in data: 880 main_target = batch["target"] 881 predicted = prediction[batch["index"]] 882 if not indices is None: 883 indices_new = [indices.index(i) for i in range(len(indices))] 884 predicted = predicted[:, indices_new] 885 self._compute( 886 [], 887 [], 888 predicted, 889 main_target, 890 skip_loss=True, 891 tag=batch.get("tag"), 892 apply_primary_function=False, 893 ) 894 epoch_metrics = self._calculate_metrics() 895 # strings = [ 896 # f"{metric_name} {metric_value:.3f}" 897 # for metric_name, metric_value in epoch_metrics.items() 898 # ] 899 # val_string = ", ".join(sorted(strings)) 900 # print(val_string) 901 return epoch_loss, epoch_metrics 902 903 def evaluate( 904 self, 905 data: Union[DataLoader, BehaviorDataset, str] = None, 906 augment_n: int = 0, 907 batch_size: int = 32, 908 verbose: bool = True, 909 ) -> Tuple: 910 """Evaluate the Task model. 911 912 Parameters 913 ---------- 914 data : torch.utils.data.DataLoader | dlc2action.data.dataset.BehaviorDataset, optional 915 the data to evaluate on (if not provided, evaluate on the Task validation dataset) 916 augment_n : int, default 0 917 the number of augmentations to average results over 918 batch_size : int, default 32 919 the batch size 920 verbose : bool, default True 921 if True, the process is reported to standard output 922 923 Returns 924 ------- 925 loss : float 926 the average value of the loss function 927 ssl_loss : float 928 the average value of the SSL loss function 929 metric : dict 930 a dictionary of average values of metric functions 931 932 """ 933 if self.parallel and not isinstance(self.model, nn.DataParallel): 934 self.model = MyDataParallel(self.model) 935 self.model.to(self.device) 936 if type(data) is not DataLoader: 937 data = self._get_dataset(data) 938 data = DataLoader(data, shuffle=False, batch_size=batch_size) 939 with torch.no_grad(): 940 epoch_loss, epoch_ssl_loss, epoch_metrics = self._run_epoch( 941 dataloader=data, mode="val", augment_n=augment_n, verbose=verbose 942 ) 943 val_string = f"loss {epoch_loss:.4f}" 944 for metric_name, metric_value in sorted(epoch_metrics.items()): 945 val_string += f", {metric_name} {metric_value:.3f}" 946 print(val_string) 947 return epoch_loss, epoch_ssl_loss, epoch_metrics 948 949 def predict( 950 self, 951 data: Union[DataLoader, BehaviorDataset, str] = None, 952 raw_output: bool = False, 953 apply_primary_function: bool = True, 954 augment_n: int = 0, 955 batch_size: int = 32, 956 train_mode: bool = False, 957 to_ram: bool = False, 958 embedding: bool = False, 959 ) -> torch.Tensor: 960 """Make a prediction with the Task model. 961 962 Parameters 963 ---------- 964 data : torch.utils.data.DataLoader | dlc2action.data.dataset.BehaviorDataset | str, optional 965 the data to evaluate on (if not provided, evaluate on the Task validation dataset) 966 raw_output : bool, default False 967 if `True`, the raw predicted probabilities are returned 968 apply_primary_function : bool, default True 969 if `True`, the primary predict function is applied (to map the model output into a shape corresponding to 970 the input) 971 augment_n : int, default 0 972 the number of augmentations to average results over 973 batch_size : int, default 32 974 the batch size 975 train_mode : bool, default False 976 if `True`, the model is used in training mode (affects dropout and batch normalization layers) 977 to_ram : bool, default False 978 if `True`, the dataset will be loaded in RAM (this speeds up the calculations but can lead to crashes 979 if the dataset is too large) 980 embedding : bool, default False 981 if `True`, the output of feature extractor is returned, ignoring the prediction module of the model 982 983 Returns 984 ------- 985 prediction : torch.Tensor 986 a prediction for the input data 987 988 """ 989 if self.parallel and not isinstance(self.model, nn.DataParallel): 990 self.model = MyDataParallel(self.model) 991 self.model.to(self.device) 992 if train_mode: 993 self.model.train() 994 else: 995 self.model.eval() 996 output = [] 997 if embedding: 998 raw_output = True 999 apply_primary_function = True 1000 if type(data) is not DataLoader: 1001 data = self._get_dataset(data) 1002 if to_ram: 1003 print("transferring dataset to RAM...") 1004 data.to_ram() 1005 data = DataLoader(data, shuffle=False, batch_size=batch_size) 1006 self.model.ssl_off() 1007 with torch.no_grad(): 1008 for batch in tqdm(data): 1009 input = {k: v.to(self.device) for k, v in batch["input"].items()} 1010 predicted, _, _ = self._get_prediction( 1011 input, 1012 batch.get("tag"), 1013 augment_n=augment_n, 1014 embedding=embedding, 1015 ) 1016 if apply_primary_function: 1017 predicted = self.primary_predict_function(predicted) 1018 if not raw_output: 1019 predicted = self.predict_function(predicted) 1020 output.append(predicted.detach().cpu()) 1021 self.model.ssl_on() 1022 output = torch.cat(output).detach() 1023 return output 1024 1025 def dataset(self, mode="train") -> BehaviorDataset: 1026 """Get a dataset. 1027 1028 Parameters 1029 ---------- 1030 mode : {'train', 'val', 'test} 1031 the dataset to get 1032 1033 Returns 1034 ------- 1035 dataset : dlc2action.data.dataset.BehaviorDataset 1036 the dataset 1037 1038 """ 1039 dataloader = self.dataloader(mode) 1040 if dataloader is None: 1041 raise ValueError("The length of the dataloader is 0!") 1042 return dataloader.dataset 1043 1044 def dataloader(self, mode: str = "train") -> DataLoader: 1045 """Get a dataloader. 1046 1047 Parameters 1048 ---------- 1049 mode : {'train', 'val', 'test} 1050 the dataset to get 1051 1052 Returns 1053 ------- 1054 dataloader : torch.utils.data.DataLoader 1055 the dataloader 1056 1057 """ 1058 if mode == "train": 1059 return self.train_dataloader 1060 elif mode == "val": 1061 return self.val_dataloader 1062 elif mode == "test": 1063 return self.test_dataloader 1064 else: 1065 raise ValueError( 1066 f'The {mode} mode is not recognized, please choose from "train", "val" or "test"' 1067 ) 1068 1069 def _get_dataset(self, dataset): 1070 """Get a dataset from a dataloader, a string ('train', 'test' or 'val') or `None` (default).""" 1071 if dataset is None: 1072 dataset = self.dataset("val") 1073 elif dataset in ["train", "val", "test"]: 1074 dataset = self.dataset(dataset) 1075 elif type(dataset) is DataLoader: 1076 dataset = dataset.dataset 1077 if type(dataset) is BehaviorDataset: 1078 return dataset 1079 else: 1080 raise TypeError(f"The {type(dataset)} type of dataset is not recognized!") 1081 1082 def _get_dataloader(self, dataset): 1083 """Get a dataloader from a dataset, a string ('train', 'test' or 'val') or `None` (default).""" 1084 if dataset is None: 1085 dataset = self.dataloader("val") 1086 elif dataset in ["train", "val", "test"]: 1087 dataset = self.dataloader(dataset) 1088 if dataset is None: 1089 raise ValueError(f"The length of the dataloader is 0!") 1090 elif type(dataset) is BehaviorDataset: 1091 dataset = DataLoader(dataset) 1092 if type(dataset) is DataLoader: 1093 return dataset 1094 else: 1095 raise TypeError(f"The {type(dataset)} type of dataset is not recognized!") 1096 1097 def generate_full_length_prediction( 1098 self, dataset=None, batch_size=32, augment_n=10 1099 ): 1100 """Compile a prediction for the original input sequences. 1101 1102 Parameters 1103 ---------- 1104 dataset : BehaviorDataset, optional 1105 the dataset to generate a prediction for (if `None`, generate for the validation dataset) 1106 batch_size : int, default 32 1107 the batch size 1108 augment_n : int, default 10 1109 the number of augmentations to average results over 1110 1111 Returns 1112 ------- 1113 prediction : dict 1114 a nested dictionary where first level keys are video ids, second level keys are clip ids and values 1115 are prediction tensors 1116 1117 """ 1118 dataset = self._get_dataset(dataset) 1119 if not isinstance(dataset, BehaviorDataset): 1120 raise TypeError( 1121 f"The dataset parameter has to be either None, string, " 1122 f"BehaviorDataset or Dataloader, got {type(dataset)}" 1123 ) 1124 predicted = self.predict( 1125 dataset, 1126 raw_output=True, 1127 apply_primary_function=True, 1128 augment_n=augment_n, 1129 batch_size=batch_size, 1130 ) 1131 predicted = dataset.generate_full_length_prediction(predicted) 1132 predicted = { 1133 v_id: { 1134 clip_id: self._apply_predict_functions(v.unsqueeze(0)).squeeze() 1135 for clip_id, v in video_dict.items() 1136 } 1137 for v_id, video_dict in predicted.items() 1138 } 1139 return predicted 1140 1141 def generate_submission( 1142 self, frame_number_map_file, dataset=None, batch_size=32, augment_n=10 1143 ): 1144 """Generate a MABe-22 style submission dictionary. 1145 1146 Parameters 1147 ---------- 1148 frame_number_map_file : str 1149 the path to the frame number map file 1150 dataset : BehaviorDataset, optional 1151 the dataset to generate a prediction for (if `None`, generate for the validation dataset) 1152 batch_size : int, default 32 1153 the batch size 1154 augment_n : int, default 10 1155 the number of augmentations to average results over 1156 1157 Returns 1158 ------- 1159 submission : dict 1160 a dictionary with frame number mapping and embeddings 1161 1162 """ 1163 dataset = self._get_dataset(dataset) 1164 if not isinstance(dataset, BehaviorDataset): 1165 raise TypeError( 1166 f"The dataset parameter has to be either None, string, " 1167 f"BehaviorDataset or Dataloader, got {type(dataset)}" 1168 ) 1169 predicted = self.predict( 1170 dataset, 1171 raw_output=True, 1172 apply_primary_function=True, 1173 augment_n=augment_n, 1174 batch_size=batch_size, 1175 embedding=True, 1176 ) 1177 predicted = dataset.generate_full_length_prediction(predicted) 1178 frame_map = np.load(frame_number_map_file, allow_pickle=True).item() 1179 length = frame_map[list(frame_map.keys())[-1]][1] 1180 embeddings = None 1181 for video_id in list(predicted.keys()): 1182 split = video_id.split("--") 1183 if len(split) != 2 or len(predicted[video_id]) > 1: 1184 raise RuntimeError( 1185 "Generating submissions is only implemented for the mabe22 dataset!" 1186 ) 1187 if split[1] not in frame_map: 1188 raise RuntimeError(f"The {split[1]} video is not in the frame map file") 1189 v_id = split[1] 1190 clip_id = list(predicted[video_id].keys())[0] 1191 if embeddings is None: 1192 embeddings = np.zeros((length, predicted[video_id][clip_id].shape[0])) 1193 start, end = frame_map[v_id] 1194 embeddings[start:end, :] = predicted[video_id][clip_id].T 1195 predicted.pop(video_id) 1196 predicted = { 1197 "frame_number_map": frame_map, 1198 "embeddings": embeddings.astype(np.float32), 1199 } 1200 return predicted 1201 1202 def _get_intervals(self, tensor: torch.Tensor) -> torch.Tensor: 1203 """Get a list of True group beginning and end indices from a boolean tensor.""" 1204 output, indices = torch.unique_consecutive(tensor, return_inverse=True) 1205 true_indices = torch.where(output)[0] 1206 starts = torch.tensor( 1207 [(indices == i).nonzero(as_tuple=True)[0][0] for i in true_indices] 1208 ) 1209 ends = torch.tensor( 1210 [(indices == i).nonzero(as_tuple=True)[0][-1] + 1 for i in true_indices] 1211 ) 1212 return torch.stack([starts, ends]).T 1213 1214 def _smooth(self, tensor: torch.Tensor, smooth_interval: int = 1) -> torch.Tensor: 1215 """Get rid of jittering in a non-exclusive classification tensor. 1216 1217 First, remove intervals of 0 shorter than `smooth_interval`. Then, remove intervals of 1 shorter than 1218 `smooth_interval`. 1219 1220 """ 1221 if len(tensor.shape) > 1: 1222 for c in tensor.shape[1]: 1223 intervals = self._get_intervals(tensor[:, c] == 0) 1224 interval_lengths = torch.tensor( 1225 [interval[1] - interval[0] for interval in intervals] 1226 ) 1227 short_intervals = intervals[interval_lengths <= smooth_interval] 1228 for start, end in short_intervals: 1229 tensor[start:end, c] = 1 1230 intervals = self._get_intervals(tensor[:, c] == 1) 1231 interval_lengths = torch.tensor( 1232 [interval[1] - interval[0] for interval in intervals] 1233 ) 1234 short_intervals = intervals[interval_lengths <= smooth_interval] 1235 for start, end in short_intervals: 1236 tensor[start:end, c] = 0 1237 else: 1238 for c in tensor.unique(): 1239 intervals = self._get_intervals(tensor == c) 1240 interval_lengths = torch.tensor( 1241 [interval[1] - interval[0] for interval in intervals] 1242 ) 1243 short_intervals = intervals[interval_lengths <= smooth_interval] 1244 for start, end in short_intervals: 1245 if start == 0: 1246 tensor[start:end] = tensor[end + 1] 1247 else: 1248 tensor[start:end] = tensor[start - 1] 1249 return tensor 1250 1251 def _visualize_results_single( 1252 self, 1253 behavior: str, 1254 save_path: str = None, 1255 add_legend: bool = True, 1256 ground_truth: bool = True, 1257 hide_axes: bool = False, 1258 width: int = 10, 1259 whole_video: bool = False, 1260 transparent: bool = False, 1261 dataset: BehaviorDataset = None, 1262 smooth_interval: int = 0, 1263 title: str = None, 1264 ): 1265 """Visualize random predictions. 1266 1267 Parameters 1268 ---------- 1269 behavior : str 1270 the behavior to visualize 1271 save_path : str, optional 1272 the path where the plot will be saved 1273 add_legend : bool, default True 1274 if `True`, legend will be added to the plot 1275 ground_truth : bool, default True 1276 if `True`, ground truth will be added to the plot 1277 colormap : str, default 'viridis' 1278 the `matplotlib` colormap to use 1279 hide_axes : bool, default True 1280 if `True`, the axes will be hidden on the plot 1281 min_classes : int, default 1 1282 the minimum number of classes in a displayed interval 1283 width : float, default 10 1284 the width of the plot 1285 whole_video : bool, default False 1286 if `True`, whole videos are plotted instead of segments 1287 transparent : bool, default False 1288 if `True`, the background on the plot is transparent 1289 dataset : BehaviorDataset, optional 1290 the dataset to make the prediction for (if not provided, the validation dataset is used) 1291 drop_classes : set, optional 1292 a set of class names to not be displayed 1293 smooth_interval : int, default 0 1294 the interval to smooth the predictions over 1295 title : str, optional 1296 the title of the plot 1297 1298 """ 1299 if title is None: 1300 title = "" 1301 dataset = self._get_dataset(dataset) 1302 inverse_dict = {v: k for k, v in dataset.behaviors_dict().items()} 1303 label_ind = inverse_dict[behavior] 1304 labels = {1: behavior, -100: "unknown"} 1305 label_keys = [1, -100] 1306 color_list = ["blue", "gray"] 1307 if whole_video: 1308 predicted = self.generate_full_length_prediction(dataset) 1309 keys = list(predicted.keys()) 1310 counter = 0 1311 if whole_video: 1312 max_iter = len(keys) * 5 1313 else: 1314 max_iter = len(dataset) * 5 1315 ok = False 1316 while not ok: 1317 counter += 1 1318 if counter > max_iter: 1319 raise RuntimeError( 1320 "Plotting is taking too many iterations; you should probably make some of the parameters less restrictive" 1321 ) 1322 if whole_video: 1323 i = randint(0, len(keys) - 1) 1324 prediction = predicted[keys[i]] 1325 keys_i = list(prediction.keys()) 1326 j = randint(0, len(keys_i) - 1) 1327 full_p = prediction[keys_i[j]] 1328 prediction = prediction[keys_i[j]][label_ind] 1329 else: 1330 dataloader = DataLoader(dataset) 1331 i = randint(0, len(dataloader) - 1) 1332 for num, batch in enumerate(dataloader): 1333 if num == i: 1334 break 1335 input_data = {k: v.to(self.device) for k, v in batch["input"].items()} 1336 prediction, *_ = self._get_prediction( 1337 input_data, batch.get("tag"), augment_n=5 1338 ) 1339 prediction = self._apply_predict_functions(prediction) 1340 j = randint(0, len(prediction) - 1) 1341 full_p = prediction[j] 1342 prediction = prediction[j][label_ind] 1343 classes = [x for x in torch.unique(prediction) if int(x) in label_keys] 1344 ok = 1 in classes 1345 fig, ax = plt.subplots(figsize=(width, 2)) 1346 for c in classes: 1347 c_i = label_keys.index(int(c)) 1348 output, indices, counts = torch.unique_consecutive( 1349 prediction == c, return_inverse=True, return_counts=True 1350 ) 1351 long_indices = torch.where(output)[0] 1352 res_indices_start = [ 1353 (indices == i).nonzero(as_tuple=True)[0][0].item() for i in long_indices 1354 ] 1355 res_indices_end = [ 1356 (indices == i).nonzero(as_tuple=True)[0][-1].item() + 1 1357 for i in long_indices 1358 ] 1359 res_indices_len = [ 1360 end - start for start, end in zip(res_indices_start, res_indices_end) 1361 ] 1362 ax.broken_barh( 1363 list(zip(res_indices_start, res_indices_len)), 1364 (0, 1), 1365 label=labels[int(c)], 1366 facecolors=color_list[c_i], 1367 ) 1368 if ground_truth: 1369 gt = batch["target"][j][label_ind].to(self.device) 1370 classes_gt = [x for x in torch.unique(gt) if int(x) in label_keys] 1371 for c in classes_gt: 1372 c_i = label_keys.index(int(c)) 1373 if c in classes: 1374 behavior = None 1375 else: 1376 behavior = labels[int(c)] 1377 output, indices, counts = torch.unique_consecutive( 1378 gt == c, return_inverse=True, return_counts=True 1379 ) 1380 long_indices = torch.where(output * (counts > 5))[0] 1381 res_indices_start = [ 1382 (indices == i).nonzero(as_tuple=True)[0][0].item() 1383 for i in long_indices 1384 ] 1385 res_indices_end = [ 1386 (indices == i).nonzero(as_tuple=True)[0][-1].item() + 1 1387 for i in long_indices 1388 ] 1389 res_indices_len = [ 1390 end - start 1391 for start, end in zip(res_indices_start, res_indices_end) 1392 ] 1393 ax.broken_barh( 1394 list(zip(res_indices_start, res_indices_len)), 1395 (1.5, 1), 1396 facecolors=color_list[c_i], 1397 label=behavior, 1398 ) 1399 self._compute( 1400 main_target=batch["target"][j].unsqueeze(0).to(self.device), 1401 predicted=full_p.unsqueeze(0).to(self.device), 1402 ssl_targets=[], 1403 ssl_predicted=[], 1404 skip_loss=True, 1405 ) 1406 metrics = self._calculate_metrics() 1407 if smooth_interval > 0: 1408 smoothed = self._smooth(full_p, smooth_interval=smooth_interval)[ 1409 label_ind, : 1410 ] 1411 for c in classes: 1412 c_i = label_keys.index(int(c)) 1413 output, indices, counts = torch.unique_consecutive( 1414 smoothed == c, return_inverse=True, return_counts=True 1415 ) 1416 long_indices = torch.where(output)[0] 1417 res_indices_start = [ 1418 (indices == i).nonzero(as_tuple=True)[0][0].item() 1419 for i in long_indices 1420 ] 1421 res_indices_end = [ 1422 (indices == i).nonzero(as_tuple=True)[0][-1].item() + 1 1423 for i in long_indices 1424 ] 1425 res_indices_len = [ 1426 end - start 1427 for start, end in zip(res_indices_start, res_indices_end) 1428 ] 1429 ax.broken_barh( 1430 list(zip(res_indices_start, res_indices_len)), 1431 (3, 1), 1432 label=labels[int(c)], 1433 facecolors=color_list[c_i], 1434 ) 1435 keys = list(metrics.keys()) 1436 for key in keys: 1437 if key.split("_")[-1] != (str(label_ind)): 1438 metrics.pop(key) 1439 title = [title] 1440 for key, value in metrics.items(): 1441 title.append(f"{'_'.join(key.split('_')[: -1])}: {value:.2f}") 1442 title = ", ".join(title) 1443 if not ground_truth: 1444 ax.axes.yaxis.set_visible(False) 1445 else: 1446 ax.set_yticks([0.5, 2]) 1447 ax.set_yticklabels(["prediction", "ground truth"]) 1448 if add_legend: 1449 ax.legend() 1450 if hide_axes: 1451 plt.axis("off") 1452 plt.title(title) 1453 plt.xlim((0, len(prediction))) 1454 if save_path is not None: 1455 plt.savefig(save_path, transparent=transparent) 1456 plt.show() 1457 1458 def visualize_results( 1459 self, 1460 save_path: str = None, 1461 add_legend: bool = True, 1462 ground_truth: bool = True, 1463 colormap: str = "viridis", 1464 hide_axes: bool = False, 1465 min_classes: int = 1, 1466 width: int = 10, 1467 whole_video: bool = False, 1468 transparent: bool = False, 1469 dataset: Union[BehaviorDataset, DataLoader, str, None] = None, 1470 drop_classes: Set = None, 1471 search_classes: Set = None, 1472 smooth_interval_prediction: int = None, 1473 font_size: float = None, 1474 num_plots: int = 1, 1475 window_size:int =400 1476 ): 1477 """Visualize random predictions. 1478 1479 Parameters 1480 ---------- 1481 save_path : str, optional 1482 the path where the plot will be saved 1483 add_legend : bool, default True 1484 if `True`, legend will be added to the plot 1485 ground_truth : bool, default True 1486 if `True`, ground truth will be added to the plot 1487 colormap : str, default 'Accent' 1488 the `matplotlib` colormap to use 1489 hide_axes : bool, default True 1490 if `True`, the axes will be hidden on the plot 1491 min_classes : int, default 1 1492 the minimum number of classes in a displayed interval 1493 width : float, default 10 1494 the width of the plot 1495 whole_video : bool, default False 1496 if `True`, whole videos are plotted instead of segments 1497 transparent : bool, default False 1498 if `True`, the background on the plot is transparent 1499 dataset : BehaviorDataset | DataLoader | str | None, optional 1500 the dataset to make the prediction for (if not provided, the validation dataset is used) 1501 drop_classes : set, optional 1502 a set of class names to not be displayed 1503 search_classes : set, optional 1504 if given, only intervals where at least one of the classes is in ground truth will be shown 1505 smooth_interval_prediction : int, optional 1506 if given, the prediction will be smoothed with a moving average of the given size 1507 1508 """ 1509 if drop_classes is None: 1510 drop_classes = [] 1511 dataset = self._get_dataset(dataset) 1512 if dataset.annotation_class() != "exclusive_classification": 1513 raise NotImplementedError( 1514 "Results visualisation is only implemented for exclusive classification datasets!" 1515 ) 1516 labels = { 1517 k: v for k, v in dataset.behaviors_dict().items() if v not in drop_classes 1518 } 1519 labels.update({-100: "unknown"}) 1520 label_keys = sorted([int(x) for x in labels.keys()]) 1521 if search_classes is None: 1522 ok = True 1523 else: 1524 ok = False 1525 classes = [] 1526 predicted = self.generate_full_length_prediction(dataset) 1527 keys = list(predicted.keys()) 1528 counter = 0 1529 max_iter = len(keys) * 2 1530 while len(classes) < min_classes or not ok: 1531 counter += 1 1532 if counter > max_iter: 1533 raise RuntimeError( 1534 "Plotting is taking too many iterations; you should probably make some of the parameters less restrictive" 1535 ) 1536 i = randint(0, len(keys) - 1) 1537 prediction = predicted[keys[i]] 1538 keys_i = list(prediction.keys()) 1539 j = randint(0, len(keys_i) - 1) 1540 prediction = prediction[keys_i[j]] 1541 key1 = keys[i] 1542 key2 = keys_i[j] 1543 1544 if smooth_interval_prediction > 0: 1545 unsmoothed_prediction = deepcopy(prediction) 1546 prediction = self._smooth(prediction, smooth_interval_prediction) 1547 height = 3 1548 else: 1549 height = 2 1550 classes = [ 1551 labels[int(x)] for x in torch.unique(prediction) if x in label_keys 1552 ] 1553 if search_classes is not None: 1554 ok = any([x in classes for x in search_classes]) 1555 fig, ax = plt.subplots(figsize=(width, height)) 1556 cmap = cm.get_cmap(colormap) if colormap != "dlc2action" else None 1557 color_list = ( 1558 [cmap(c) for c in np.linspace(0, 1, len(labels))] 1559 if colormap != "dlc2action" 1560 else dlc2action_colormaps["default"] 1561 ) 1562 1563 def _plot_prediction(prediction, height, set_labels=True): 1564 for c in label_keys: 1565 c_i = label_keys.index(int(c)) 1566 output, indices, counts = torch.unique_consecutive( 1567 prediction == c, return_inverse=True, return_counts=True 1568 ) 1569 long_indices = torch.where(output)[0] 1570 if len(long_indices) == 0: 1571 continue 1572 res_indices_start = [ 1573 (indices == i).nonzero(as_tuple=True)[0][0].item() 1574 for i in long_indices 1575 ] 1576 res_indices_end = [ 1577 (indices == i).nonzero(as_tuple=True)[0][-1].item() + 1 1578 for i in long_indices 1579 ] 1580 res_indices_len = [ 1581 end - start 1582 for start, end in zip(res_indices_start, res_indices_end) 1583 ] 1584 if set_labels: 1585 label = labels[int(c)] 1586 else: 1587 label = None 1588 ax.broken_barh( 1589 list(zip(res_indices_start, res_indices_len)), 1590 (height, 1), 1591 label=label, 1592 facecolors=color_list[c_i], 1593 ) 1594 1595 if smooth_interval_prediction > 0: 1596 _plot_prediction(unsmoothed_prediction, 0) 1597 _plot_prediction(prediction, 1.5, set_labels=False) 1598 gt_height = 3 1599 else: 1600 _plot_prediction(prediction, 0) 1601 gt_height = 1.5 1602 if ground_truth: 1603 gt = dataset.generate_full_length_gt()[key1][key2] 1604 for c in label_keys: 1605 c_i = label_keys.index(int(c)) 1606 if labels[int(c)] in classes: 1607 label = None 1608 else: 1609 label = labels[int(c)] 1610 output, indices, counts = torch.unique_consecutive( 1611 gt == c, return_inverse=True, return_counts=True 1612 ) 1613 long_indices = torch.where(output)[0] 1614 if len(long_indices) == 0: 1615 continue 1616 res_indices_start = [ 1617 (indices == i).nonzero(as_tuple=True)[0][0].item() 1618 for i in long_indices 1619 ] 1620 res_indices_end = [ 1621 (indices == i).nonzero(as_tuple=True)[0][-1].item() + 1 1622 for i in long_indices 1623 ] 1624 res_indices_len = [ 1625 end - start 1626 for start, end in zip(res_indices_start, res_indices_end) 1627 ] 1628 ax.broken_barh( 1629 list(zip(res_indices_start, res_indices_len)), 1630 (gt_height, 1), 1631 facecolors=color_list[c_i] if c != "unknown" else "gray", 1632 label=label, 1633 ) 1634 if not ground_truth: 1635 ax.axes.yaxis.set_visible(False) 1636 else: 1637 if smooth_interval_prediction > 0: 1638 ax.set_yticks([0.5, 2, 3.5]) 1639 ax.set_yticklabels(["prediction", "smoothed", "ground truth"]) 1640 else: 1641 ax.set_yticks([0.5, 2]) 1642 ax.set_yticklabels(["prediction", "ground truth"]) 1643 if add_legend: 1644 ax.legend(loc="center left", bbox_to_anchor=(1, 0.5)) 1645 if hide_axes: 1646 # plt.axis("off") 1647 plt.box(False) 1648 if font_size is not None: 1649 font = {"size": font_size} 1650 rc("font", **font) 1651 plt.title(f"{key1} -- {key2}") 1652 plt.tight_layout() 1653 for seed in range(num_plots): 1654 if not whole_video: 1655 ax.set_xlim((seed * window_size, (seed + 1) * window_size)) 1656 if save_path is not None: 1657 plt.savefig( 1658 save_path.replace(".svg", f"_{seed}_{key1} -- {key2}.svg"), transparent=transparent 1659 ) 1660 print(f"Saved in {save_path.replace('.svg', f'_{seed}_{key1} -- {key2}.svg')}") 1661 plt.show() 1662 plt.close() 1663 1664 def set_ssl_transformations(self, ssl_transformations): 1665 """Set SSL transformations. 1666 1667 Parameters 1668 ---------- 1669 ssl_transformations : list 1670 a list of callable SSL transformations 1671 1672 """ 1673 self.train_dataloader.dataset.set_ssl_transformations(ssl_transformations) 1674 if self.val_dataloader is not None: 1675 self.val_dataloader.dataset.set_ssl_transformations(ssl_transformations) 1676 1677 def set_ssl_losses(self, ssl_losses: list) -> None: 1678 """Set SSL losses. 1679 1680 Parameters 1681 ---------- 1682 ssl_losses : list 1683 a list of callable SSL losses 1684 1685 """ 1686 self.ssl_losses = ssl_losses 1687 1688 def set_log(self, log: str) -> None: 1689 """Set the log file. 1690 1691 Parameters 1692 ---------- 1693 log: str 1694 the mew log file path 1695 1696 """ 1697 self.log_file = log 1698 1699 def set_keep_target_none(self, keep_target_none: List) -> None: 1700 """Set the keep_target_none parameter of the transformer. 1701 1702 Parameters 1703 ---------- 1704 keep_target_none : list 1705 a list of bool values 1706 1707 """ 1708 self.transformer.keep_target_none = keep_target_none 1709 1710 def set_generate_ssl_input(self, generate_ssl_input: list) -> None: 1711 """Set the generate_ssl_input parameter of the transformer. 1712 1713 Parameters 1714 ---------- 1715 generate_ssl_input : list 1716 a list of bool values 1717 1718 """ 1719 self.transformer.generate_ssl_input = generate_ssl_input 1720 1721 def set_model(self, model: Model) -> None: 1722 """Set a new model. 1723 1724 Parameters 1725 ---------- 1726 model: Model 1727 the new model 1728 1729 """ 1730 self.epoch = 0 1731 self.model = model 1732 self.optimizer = self.optimizer_class( 1733 model.parameters(), lr=self.lr, weight_decay=self.weight_decay 1734 ) 1735 if self.model.process_labels: 1736 self.model.set_behaviors(list(self.behaviors_dict().values())) 1737 1738 def set_dataloaders( 1739 self, 1740 train_dataloader: DataLoader, 1741 val_dataloader: DataLoader = None, 1742 test_dataloader: DataLoader = None, 1743 ) -> None: 1744 """Set new dataloaders. 1745 1746 Parameters 1747 ---------- 1748 train_dataloader: torch.utils.data.DataLoader 1749 the new train dataloader 1750 val_dataloader : torch.utils.data.DataLoader 1751 the new validation dataloader 1752 test_dataloader : torch.utils.data.DataLoader 1753 the new test dataloader 1754 1755 """ 1756 self.train_dataloader = train_dataloader 1757 self.val_dataloader = val_dataloader 1758 self.test_dataloader = test_dataloader 1759 1760 def set_loss(self, loss: Callable) -> None: 1761 """Set new loss function. 1762 1763 Parameters 1764 ---------- 1765 loss: callable 1766 the new loss function 1767 1768 """ 1769 self.loss = loss 1770 1771 def set_metrics(self, metrics: dict) -> None: 1772 """Set new metric. 1773 1774 Parameters 1775 ---------- 1776 metrics : dict 1777 the new metric dictionary 1778 1779 """ 1780 self.metrics = metrics 1781 1782 def set_transformer(self, transformer: Transformer) -> None: 1783 """Set a new transformer. 1784 1785 Parameters 1786 ---------- 1787 transformer: Transformer 1788 the new transformer 1789 1790 """ 1791 self.transformer = transformer 1792 1793 def set_predict_functions( 1794 self, primary_predict_function: Callable, predict_function: Callable 1795 ) -> None: 1796 """Set new predict functions. 1797 1798 Parameters 1799 ---------- 1800 primary_predict_function : callable 1801 the new primary predict function 1802 predict_function : callable 1803 the new predict function 1804 1805 """ 1806 self.primary_predict_function = primary_predict_function 1807 self.predict_function = predict_function 1808 1809 def _set_pseudolabels(self): 1810 """Set pseudolabels.""" 1811 self.train_dataloader.dataset.set_unlabeled(True) 1812 predicted = self.predict( 1813 data=self.dataset("train"), 1814 raw_output=False, 1815 augment_n=self.augment_val, 1816 ssl_off=True, 1817 ) 1818 self.train_dataloader.dataset.set_annotation(predicted.detach()) 1819 1820 def _alpha(self, epoch): 1821 """Get the current pseudolabeling alpha parameter.""" 1822 if epoch <= self.T1: 1823 return 0 1824 elif epoch < self.T2: 1825 return self.alpha_f * (epoch - self.T1) / (self.T2 - self.T1) 1826 else: 1827 return self.alpha_f 1828 1829 def count_classes(self, bouts: bool = False) -> Dict: 1830 """Get a dictionary of class counts in different modes. 1831 1832 Parameters 1833 ---------- 1834 bouts : bool, default False 1835 if `True`, instead of frame counts segment counts are returned 1836 1837 Returns 1838 ------- 1839 class_counts : dict 1840 a dictionary where first-level keys are "train", "val" and "test", second-level keys are 1841 class names and values are class counts (in frames) 1842 1843 """ 1844 class_counts = {} 1845 for x in ["train", "val", "test"]: 1846 try: 1847 class_counts[x] = self.dataset(x).count_classes(bouts) 1848 except ValueError: 1849 class_counts[x] = {k: 0 for k in self.behaviors_dict().keys()} 1850 return class_counts 1851 1852 def behaviors_dict(self) -> Dict: 1853 """Get a behavior dictionary. 1854 1855 Keys are label indices and values are label names. 1856 1857 Returns 1858 ------- 1859 behaviors_dict : dict 1860 behavior dictionary 1861 1862 """ 1863 return self.dataset().behaviors_dict() 1864 1865 def update_parameters(self, parameters: Dict) -> None: 1866 """Update training parameters from a dictionary. 1867 1868 Parameters 1869 ---------- 1870 parameters : dict 1871 the update dictionary 1872 1873 """ 1874 self.lr = parameters.get("lr", self.lr) 1875 self.parallel = parameters.get("parallel", self.parallel) 1876 self.optimizer = self.optimizer_class(self.model.parameters(), lr=self.lr) 1877 self.verbose = parameters.get("verbose", self.verbose) 1878 self.device = parameters.get("device", self.device) 1879 if self.device == "auto": 1880 self.device = "cuda" if torch.cuda.is_available() else "cpu" 1881 self.augment_train = int(parameters.get("augment_train", self.augment_train)) 1882 self.augment_val = int(parameters.get("augment_val", self.augment_val)) 1883 ssl_weights = parameters.get("ssl_weights", self.ssl_weights) 1884 if ssl_weights is None: 1885 ssl_weights = [1 for _ in self.ssl_losses] 1886 if not isinstance(ssl_weights, Iterable): 1887 ssl_weights = [ssl_weights for _ in self.ssl_losses] 1888 self.ssl_weights = ssl_weights 1889 self.num_epochs = parameters.get("num_epochs", self.num_epochs) 1890 self.model_save_epochs = parameters.get( 1891 "model_save_epochs", self.model_save_epochs 1892 ) 1893 self.model_save_path = parameters.get("model_save_path", self.model_save_path) 1894 self.pseudolabel = parameters.get("pseudolabel", self.pseudolabel) 1895 self.T1 = int(parameters.get("pseudolabel_start", self.T1)) 1896 self.T2 = int(parameters.get("alpha_growth_stop", self.T2)) 1897 self.t = int(parameters.get("correction_interval", self.t)) 1898 self.alpha_f = parameters.get("pseudolabel_alpha_f", self.alpha_f) 1899 self.log_file = parameters.get("log_file", self.log_file) 1900 1901 def generate_uncertainty_score( 1902 self, 1903 classes: List, 1904 augment_n: int = 0, 1905 batch_size: int = 32, 1906 method: str = "least_confidence", 1907 predicted: torch.Tensor = None, 1908 behaviors_dict: Dict = None, 1909 ) -> Dict: 1910 """Generate frame-wise scores for active learning. 1911 1912 Parameters 1913 ---------- 1914 classes : list 1915 a list of class names or indices; their confidence scores will be computed separately and stacked 1916 augment_n : int, default 0 1917 the number of augmentations to average over 1918 batch_size : int, default 32 1919 the batch size 1920 method : {"least_confidence", "entropy"} 1921 the method used to calculate the scores from the probability predictions (`"least_confidence"`: `1 - p_i` if 1922 `p_i > 0.5` or `p_i` if `p_i < 0.5`; `"entropy"`: `- p_i * log(p_i) - (1 - p_i) * log(1 - p_i)`) 1923 predicted : torch.Tensor, default None 1924 if not `None`, the predicted tensor to use instead of predicting from the model 1925 behaviors_dict : dict, default None 1926 if not `None`, the behaviors dictionary to use instead of the one from the dataset 1927 1928 Returns 1929 ------- 1930 score_dicts : dict 1931 a nested dictionary where first level keys are video ids, second level keys are clip ids and values 1932 are score tensors 1933 1934 """ 1935 dataset = self.dataset("train") 1936 if behaviors_dict is None: 1937 behaviors_dict = self.behaviors_dict() 1938 if not isinstance(dataset, BehaviorDataset): 1939 raise TypeError( 1940 f"The dataset parameter has to be either None, string, " 1941 f"BehaviorDataset or Dataloader, got {type(dataset)}" 1942 ) 1943 if predicted is None: 1944 predicted = self.predict( 1945 dataset, 1946 raw_output=True, 1947 apply_primary_function=True, 1948 augment_n=augment_n, 1949 batch_size=batch_size, 1950 ) 1951 predicted = dataset.generate_full_length_prediction(predicted) 1952 if isinstance(classes[0], str): 1953 behaviors_dict_inverse = {v: k for k, v in behaviors_dict.items()} 1954 classes = [behaviors_dict_inverse[c] for c in classes] 1955 for v_id, v in predicted.items(): 1956 for clip_id, vv in v.items(): 1957 if method == "least_confidence": 1958 predicted[v_id][clip_id][vv > 0.5] = 1 - vv[vv > 0.5] 1959 elif method == "entropy": 1960 predicted[v_id][clip_id][vv != -100] = ( 1961 -vv * torch.log(vv) - (1 - vv) * torch.log(1 - vv) 1962 )[vv != -100] 1963 elif method == "random": 1964 predicted[v_id][clip_id] = torch.rand(vv.shape) 1965 else: 1966 raise ValueError( 1967 f"The {method} method is not recognized; please choose from ['least_confidence', 'entropy']" 1968 ) 1969 predicted[v_id][clip_id][vv == -100] = 0 1970 1971 predicted = { 1972 v_id: {clip_id: v[classes, :] for clip_id, v in video_dict.items()} 1973 for v_id, video_dict in predicted.items() 1974 } 1975 return predicted 1976 1977 def generate_bald_score( 1978 self, 1979 classes: List, 1980 augment_n: int = 0, 1981 batch_size: int = 32, 1982 num_models: int = 10, 1983 kernel_size: int = 11, 1984 ) -> Dict: 1985 """Generate frame-wise Bayesian Active Learning by Disagreement scores for active learning. 1986 1987 Parameters 1988 ---------- 1989 classes : list 1990 a list of class names or indices; their confidence scores will be computed separately and stacked 1991 augment_n : int, default 0 1992 the number of augmentations to average over 1993 batch_size : int, default 32 1994 the batch size 1995 num_models : int, default 10 1996 the number of dropout masks to apply 1997 kernel_size : int, default 11 1998 the size of the smoothing gaussian kernel 1999 2000 Returns 2001 ------- 2002 score_dicts : dict 2003 a nested dictionary where first level keys are video ids, second level keys are clip ids and values 2004 are score tensors 2005 2006 """ 2007 dataset = self.dataset("train") 2008 dataset = self._get_dataset(dataset) 2009 if not isinstance(dataset, BehaviorDataset): 2010 raise TypeError( 2011 f"The dataset parameter has to be either None, string, " 2012 f"BehaviorDataset or Dataloader, got {type(dataset)}" 2013 ) 2014 predictions = [] 2015 for _ in range(num_models): 2016 predicted = self.predict( 2017 dataset, 2018 raw_output=True, 2019 apply_primary_function=True, 2020 augment_n=augment_n, 2021 batch_size=batch_size, 2022 train_mode=True, 2023 ) 2024 predicted = dataset.generate_full_length_prediction(predicted) 2025 if isinstance(classes[0], str): 2026 behaviors_dict_inverse = { 2027 v: k for k, v in self.behaviors_dict().items() 2028 } 2029 classes = [behaviors_dict_inverse[c] for c in classes] 2030 for v_id, v in predicted.items(): 2031 for clip_id, vv in v.items(): 2032 vv[vv != -100] = (vv[vv != -100] > 0.5).int().float() 2033 predicted[v_id][clip_id] = vv 2034 predicted = { 2035 v_id: {clip_id: v[classes, :] for clip_id, v in video_dict.items()} 2036 for v_id, video_dict in predicted.items() 2037 } 2038 predictions.append(predicted) 2039 result = {v_id: {} for v_id in predictions[0]} 2040 r = range(-int(kernel_size / 2), int(kernel_size / 2) + 1) 2041 gauss = [1 / (1 * sqrt(2 * pi)) * exp(-float(x) ** 2 / (2 * 1**2)) for x in r] 2042 gauss = [x / sum(gauss) for x in gauss] 2043 kernel = torch.FloatTensor([[gauss]]) 2044 for v_id in predictions[0]: 2045 for clip_id in predictions[0][v_id]: 2046 consensus = ( 2047 ( 2048 torch.mean( 2049 torch.stack([x[v_id][clip_id] for x in predictions]), dim=0 2050 ) 2051 > 0.5 2052 ) 2053 .int() 2054 .float() 2055 ) 2056 consensus[predictions[0][v_id][clip_id] == -100] = -100 2057 result[v_id][clip_id] = torch.zeros(consensus.shape) 2058 for x in predictions: 2059 result[v_id][clip_id] += (x[v_id][clip_id] != consensus).int() 2060 result[v_id][clip_id] = result[v_id][clip_id] * 2 / num_models 2061 res = torch.zeros(result[v_id][clip_id].shape) 2062 for i in range(len(classes)): 2063 res[i, floor(kernel_size // 2) : -floor(kernel_size // 2)] = ( 2064 torch.nn.functional.conv1d( 2065 result[v_id][clip_id][i, :].unsqueeze(0).unsqueeze(0), 2066 kernel, 2067 )[0, ...] 2068 ) 2069 result[v_id][clip_id] = res 2070 return result 2071 2072 def get_normalization_stats(self) -> Optional[Dict]: 2073 """Get the normalization statistics of the dataset. 2074 2075 Returns 2076 ------- 2077 stats : dict 2078 a dictionary containing the mean and standard deviation of the dataset 2079 2080 """ 2081 return self.train_dataloader.dataset.stats
A universal trainer class that performs training, evaluation and prediction for all types of tasks and data.
87 def __init__( 88 self, 89 train_dataloader: DataLoader, 90 model: Union[nn.Module, Model], 91 loss: Callable[[torch.Tensor, torch.Tensor], float], 92 num_epochs: int = 0, 93 transformer: Transformer = None, 94 ssl_losses: List = None, 95 ssl_weights: List = None, 96 lr: float = 1e-3, 97 weight_decay: float = 0, 98 metrics: Dict = None, 99 val_dataloader: DataLoader = None, 100 test_dataloader: DataLoader = None, 101 optimizer: Optimizer = None, 102 device: str = "cuda", 103 verbose: bool = True, 104 log_file: Union[str, None] = None, 105 augment_train: int = 1, 106 augment_val: int = 0, 107 validation_interval: int = 1, 108 predict_function: Union[Callable[[torch.Tensor], torch.Tensor], None] = None, 109 primary_predict_function: Callable = None, 110 exclusive: bool = True, 111 ignore_tags: bool = True, 112 threshold: float = 0.5, 113 model_save_path: str = None, 114 model_save_epochs: int = 5, 115 pseudolabel: bool = False, 116 pseudolabel_start: int = 100, 117 correction_interval: int = 2, 118 pseudolabel_alpha_f: float = 3, 119 alpha_growth_stop: int = 600, 120 parallel: bool = False, 121 skip_metrics: List = None, 122 ) -> None: 123 """Initialize the class. 124 125 Parameters 126 ---------- 127 train_dataloader : torch.utils.data.DataLoader 128 a training dataloader 129 model : dlc2action.model.base_model.Model 130 a model 131 loss : callable 132 a loss function 133 num_epochs : int, default 0 134 the number of epochs 135 transformer : dlc2action.transformer.base_transformer.Transformer, optional 136 a transformer 137 ssl_losses : list, optional 138 a list of SSL losses 139 ssl_weights : list, optional 140 a list of SSL weights (if not provided initializes to 1) 141 lr : float, default 1e-3 142 learning rate 143 weight_decay : float, default 0 144 weight decay 145 metrics : dict, optional 146 a list of metric functions 147 val_dataloader : torch.utils.data.DataLoader, optional 148 a validation dataloader 149 test_dataloader : torch.utils.data.DataLoader, optional 150 a test dataloader 151 optimizer : torch.optim.Optimizer, optional 152 an optimizer (`Adam` by default) 153 device : str, default 'cuda' 154 the device to train the model on 155 verbose : bool, default True 156 if `True`, the process is described in standard output 157 log_file : str, optional 158 the path to a text file where the process will be logged 159 augment_train : {1, 0} 160 number of augmentations to apply at training 161 augment_val : int, default 0 162 number of augmentations to apply at validation 163 validation_interval : int, default 1 164 every time this number of epochs passes, validation metrics are computed 165 predict_function : callable, optional 166 a function that maps probabilities to class predictions (if not provided, a default is generated) 167 primary_predict_function : callable, optional 168 a function that maps model output to probabilities (if not provided, initialized as identity) 169 exclusive : bool, default True 170 set to False for multi-label classification 171 ignore_tags : bool, default False 172 if `True`, samples with different meta tags will be mixed in batches 173 threshold : float, default 0.5 174 the threshold used for multi-label classification default prediction function 175 model_save_path : str, optional 176 the path to the folder where model checkpoints will be saved (checkpoints will not be saved if the path 177 is not provided) 178 model_save_epochs : int, default 5 179 the interval for saving the model checkpoints (the last epoch is always saved) 180 pseudolabel : bool, default False 181 if True, the pseudolabeling procedure will be applied 182 pseudolabel_start : int, default 100 183 pseudolabeling starts after this epoch 184 correction_interval : int, default 1 185 after this number of epochs, if the pseudolabeling is on, the model is trained on the labeled data and 186 new pseudolabels are generated 187 pseudolabel_alpha_f : float, default 3 188 the maximum value of pseudolabeling alpha 189 alpha_growth_stop : int, default 600 190 pseudolabeling alpha stops growing after this epoch 191 parallel : bool, default False 192 if True, the model is trained on multiple GPUs 193 skip_metrics : list, optional 194 a list of metrics to skip 195 196 """ 197 # pseudolabeling might be buggy right now -- not using it! 198 if skip_metrics is None: 199 skip_metrics = [] 200 self.train_dataloader = train_dataloader 201 self.val_dataloader = val_dataloader 202 self.test_dataloader = test_dataloader 203 self.transformer = transformer 204 self.num_epochs = num_epochs 205 self.skip_metrics = skip_metrics 206 self.verbose = verbose 207 self.augment_train = int(augment_train) 208 self.augment_val = int(augment_val) 209 self.ignore_tags = ignore_tags 210 self.validation_interval = int(validation_interval) 211 self.log_file = log_file 212 self.loss = loss 213 self.model_save_path = model_save_path 214 self.model_save_epochs = model_save_epochs 215 self.epoch = 0 216 217 if metrics is None: 218 metrics = {} 219 self.metrics = metrics 220 221 if optimizer is None: 222 optimizer = Adam 223 224 if ssl_weights is None: 225 ssl_weights = [1 for _ in ssl_losses] 226 if not isinstance(ssl_weights, Iterable): 227 ssl_weights = [ssl_weights for _ in ssl_losses] 228 self.ssl_weights = ssl_weights 229 230 self.optimizer_class = optimizer 231 self.lr = lr 232 self.weight_decay = weight_decay 233 if not isinstance(model, Model): 234 self.model = LoadedModel(model=model) 235 else: 236 self.set_model(model) 237 self.parallel = parallel 238 239 if self.transformer is None: 240 self.augment_val = 0 241 self.augment_train = 0 242 self.transformer = EmptyTransformer() 243 244 if self.augment_train > 1: 245 warnings.warn( 246 'The "augment_train" parameter is too large -> setting it to 1.' 247 ) 248 self.augment_train = 1 249 250 try: 251 if device == "auto": 252 device = "cuda" if torch.cuda.is_available() else "cpu" 253 self.device = torch.device(device) 254 except: 255 raise ("The format of the device is incorrect") 256 257 if ssl_losses is None: 258 self.ssl_losses = [lambda x, y: 0] 259 else: 260 self.ssl_losses = ssl_losses 261 262 if primary_predict_function is None: 263 if exclusive: 264 primary_predict_function = lambda x: nn.Softmax(x, dim=1) 265 else: 266 primary_predict_function = lambda x: torch.sigmoid(x) 267 self.primary_predict_function = primary_predict_function 268 269 if predict_function is None: 270 if exclusive: 271 self.predict_function = lambda x: torch.max(x.data, 1)[1] 272 else: 273 self.predict_function = lambda x: (x > threshold).int() 274 else: 275 self.predict_function = predict_function 276 277 self.pseudolabel = pseudolabel 278 self.alpha_f = pseudolabel_alpha_f 279 self.T2 = alpha_growth_stop 280 self.T1 = pseudolabel_start 281 self.t = correction_interval 282 if self.T2 <= self.T1: 283 raise ValueError( 284 f"The pseudolabel_start parameter has to be smaller than alpha_growth_stop; got " 285 f"{pseudolabel_start=} and {alpha_growth_stop=}" 286 ) 287 self.decision_thresholds = [0.5 for x in self.behaviors_dict()]
Initialize the class.
Parameters
train_dataloader : torch.utils.data.DataLoader
a training dataloader
model : dlc2action.model.base_model.Model
a model
loss : callable
a loss function
num_epochs : int, default 0
the number of epochs
transformer : dlc2action.transformer.base_transformer.Transformer, optional
a transformer
ssl_losses : list, optional
a list of SSL losses
ssl_weights : list, optional
a list of SSL weights (if not provided initializes to 1)
lr : float, default 1e-3
learning rate
weight_decay : float, default 0
weight decay
metrics : dict, optional
a list of metric functions
val_dataloader : torch.utils.data.DataLoader, optional
a validation dataloader
test_dataloader : torch.utils.data.DataLoader, optional
a test dataloader
optimizer : torch.optim.Optimizer, optional
an optimizer (Adam by default)
device : str, default 'cuda'
the device to train the model on
verbose : bool, default True
if True, the process is described in standard output
log_file : str, optional
the path to a text file where the process will be logged
augment_train : {1, 0}
number of augmentations to apply at training
augment_val : int, default 0
number of augmentations to apply at validation
validation_interval : int, default 1
every time this number of epochs passes, validation metrics are computed
predict_function : callable, optional
a function that maps probabilities to class predictions (if not provided, a default is generated)
primary_predict_function : callable, optional
a function that maps model output to probabilities (if not provided, initialized as identity)
exclusive : bool, default True
set to False for multi-label classification
ignore_tags : bool, default False
if True, samples with different meta tags will be mixed in batches
threshold : float, default 0.5
the threshold used for multi-label classification default prediction function
model_save_path : str, optional
the path to the folder where model checkpoints will be saved (checkpoints will not be saved if the path
is not provided)
model_save_epochs : int, default 5
the interval for saving the model checkpoints (the last epoch is always saved)
pseudolabel : bool, default False
if True, the pseudolabeling procedure will be applied
pseudolabel_start : int, default 100
pseudolabeling starts after this epoch
correction_interval : int, default 1
after this number of epochs, if the pseudolabeling is on, the model is trained on the labeled data and
new pseudolabels are generated
pseudolabel_alpha_f : float, default 3
the maximum value of pseudolabeling alpha
alpha_growth_stop : int, default 600
pseudolabeling alpha stops growing after this epoch
parallel : bool, default False
if True, the model is trained on multiple GPUs
skip_metrics : list, optional
a list of metrics to skip
289 def save_checkpoint(self, checkpoint_path: str) -> None: 290 """Save a general checkpoint. 291 292 Parameters 293 ---------- 294 checkpoint_path : str 295 the path where the checkpoint will be saved 296 297 """ 298 torch.save( 299 { 300 "epoch": self.epoch, 301 "model_state_dict": self.model.state_dict(), 302 "optimizer_state_dict": self.optimizer.state_dict(), 303 }, 304 checkpoint_path, 305 )
Save a general checkpoint.
Parameters
checkpoint_path : str the path where the checkpoint will be saved
307 def load_from_checkpoint( 308 self, checkpoint_path, only_model: bool = False, load_strict: bool = True 309 ) -> None: 310 """Load from a checkpoint. 311 312 Parameters 313 ---------- 314 checkpoint_path : str 315 the path to the checkpoint 316 only_model : bool, default False 317 if `True`, only the model state dictionary will be loaded (and not the epoch and the optimizer state 318 dictionary) 319 load_strict : bool, default True 320 if `True`, any inconsistencies in state dictionaries are regarded as errors 321 322 """ 323 if checkpoint_path is None: 324 return 325 checkpoint = torch.load( 326 checkpoint_path, map_location=self.device, weights_only=False 327 ) 328 self.model.to(self.device) 329 self.model.load_state_dict(checkpoint["model_state_dict"], strict=load_strict) 330 if not only_model: 331 self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) 332 self.epoch = checkpoint["epoch"]
Load from a checkpoint.
Parameters
checkpoint_path : str
the path to the checkpoint
only_model : bool, default False
if True, only the model state dictionary will be loaded (and not the epoch and the optimizer state
dictionary)
load_strict : bool, default True
if True, any inconsistencies in state dictionaries are regarded as errors
334 def save_model(self, save_path: str) -> None: 335 """Save the model state dictionary. 336 337 Parameters 338 ---------- 339 save_path : str 340 the path where the state will be saved 341 342 """ 343 torch.save(self.model.state_dict(), save_path) 344 print("saved the model successfully")
Save the model state dictionary.
Parameters
save_path : str the path where the state will be saved
621 def train( 622 self, 623 trial: Trial = None, 624 optimized_metric: str = None, 625 to_ram: bool = False, 626 autostop_interval: int = 30, 627 autostop_threshold: float = 0.001, 628 autostop_metric: str = None, 629 main_task_on: bool = True, 630 ssl_on: bool = True, 631 temporal_subsampling_size: int = None, 632 loading_bar: bool = False, 633 ) -> Tuple: 634 """Train the task and return a log of epoch-average loss and metric. 635 636 You can use the autostop parameters to finish training when the parameters are not improving. It will be 637 stopped if the average value of `autostop_metric` over the last `autostop_interval` epochs is smaller than 638 the average over the previous `autostop_interval` epochs + `autostop_threshold`. For example, if the 639 current epoch is 120 and `autostop_interval` is 50, the averages over epochs 70-120 and 20-70 will be compared. 640 641 Parameters 642 ---------- 643 trial : Trial 644 an `optuna` trial (for hyperparameter searches) 645 optimized_metric : str 646 the name of the metric being optimized (for hyperparameter searches) 647 to_ram : bool, default False 648 if `True`, the dataset will be loaded in RAM (this speeds up the calculations but can lead to crashes 649 if the dataset is too large) 650 autostop_interval : int, default 50 651 the number of epochs to average the autostop metric over 652 autostop_threshold : float, default 0.001 653 the autostop difference threshold 654 autostop_metric : str, optional 655 the autostop metric (can be any one of the tracked metrics of `'loss'`) 656 main_task_on : bool, default True 657 if `False`, the main task (action segmentation) will not be used in training 658 ssl_on : bool, default True 659 if `False`, the SSL task will not be used in training 660 temporal_subsampling_size : int, optional 661 if not `None`, the temporal subsampling will be used in training with the given size 662 loading_bar : bool, default False 663 if `True`, a loading bar will be displayed 664 665 Returns 666 ------- 667 loss_log: list 668 a list of float loss function values for each epoch 669 metrics_log: dict 670 a dictionary of metric value logs (first-level keys are 'train' and 'val', second-level keys are metric 671 names, values are lists of function values) 672 673 """ 674 if self.parallel and not isinstance(self.model, nn.DataParallel): 675 self.model = MyDataParallel(self.model) 676 self.model.to(self.device) 677 assert autostop_metric in [None, "loss"] + list(self.metrics) 678 autostop_interval //= self.validation_interval 679 if trial is not None and optimized_metric is None: 680 raise ValueError( 681 "You need to provide the optimized metric name (optimized_metric parameter) " 682 "for optuna pruning to work!" 683 ) 684 if to_ram: 685 print("transferring datasets to RAM...") 686 self.train_dataloader.dataset.to_ram() 687 if self.val_dataloader is not None and len(self.val_dataloader) > 0: 688 self.val_dataloader.dataset.to_ram() 689 loss_log = {"train": [], "val": []} 690 metrics_log = {"train": defaultdict(lambda: []), "val": defaultdict(lambda: [])} 691 if not main_task_on: 692 self.model.main_task_off() 693 if not ssl_on: 694 self.model.ssl_off() 695 while self.epoch < self.num_epochs: 696 self.epoch += 1 697 unlabeled = None 698 alpha = 1 699 if self.pseudolabel: 700 if self.epoch >= self.T1: 701 unlabeled = (self.epoch - self.T1) % self.t != 0 702 if unlabeled: 703 alpha = self._alpha(self.epoch) 704 else: 705 unlabeled = False 706 epoch_loss, epoch_ssl_loss, epoch_metrics = self._run_epoch( 707 dataloader=self.train_dataloader, 708 mode="train", 709 augment_n=self.augment_train, 710 unlabeled=unlabeled, 711 alpha=alpha, 712 temporal_subsampling_size=temporal_subsampling_size, 713 verbose=loading_bar, 714 ) 715 loss_log["train"].append(epoch_loss) 716 epoch_string = f"[epoch {self.epoch}]" 717 if self.pseudolabel: 718 if unlabeled: 719 epoch_string += " (unlabeled)" 720 else: 721 epoch_string += " (labeled)" 722 epoch_string += f": loss {epoch_loss:.4f}" 723 724 if len(epoch_ssl_loss) != 0: 725 for key, value in sorted(epoch_ssl_loss.items()): 726 metrics_log["train"][f"ssl_loss_{key}"].append(value) 727 epoch_string += f", ssl_loss_{key} {value:.4f}" 728 729 for metric_name, metric_value in sorted(epoch_metrics.items()): 730 if metric_name not in self.skip_metrics: 731 if isinstance(metric_value, list): 732 metric_value = torch.mean(torch.Tensor(metric_value)) 733 epoch_string += f", {metric_name} {metric_value:.3f}" 734 metrics_log["train"][metric_name].append(metric_value) 735 736 if ( 737 self.val_dataloader is not None 738 and self.epoch % self.validation_interval == 0 739 ): 740 with torch.no_grad(): 741 epoch_string += "\n" 742 ( 743 val_epoch_loss, 744 val_epoch_ssl_loss, 745 val_epoch_metrics, 746 ) = self._run_epoch( 747 dataloader=self.val_dataloader, 748 mode="val", 749 augment_n=self.augment_val, 750 ) 751 loss_log["val"].append(val_epoch_loss) 752 epoch_string += f"validation: loss {val_epoch_loss:.4f}" 753 754 if len(val_epoch_ssl_loss) != 0: 755 for key, value in sorted(val_epoch_ssl_loss.items()): 756 metrics_log["val"][f"ssl_loss_{key}"].append(value) 757 epoch_string += f", ssl_loss_{key} {value:.4f}" 758 759 for metric_name, metric_value in sorted(val_epoch_metrics.items()): 760 if isinstance(metric_value, list): 761 metric_value = torch.mean(torch.Tensor(metric_value)) 762 metrics_log["val"][metric_name].append(metric_value) 763 epoch_string += f", {metric_name} {metric_value:.3f}" 764 765 if trial is not None: 766 if optimized_metric not in metrics_log["val"]: 767 raise ValueError( 768 f"The {optimized_metric} metric set for optimization is not being logged!" 769 ) 770 trial.report(metrics_log["val"][optimized_metric][-1], self.epoch) 771 if trial.should_prune(): 772 raise TrialPruned() 773 774 if self.verbose: 775 print(epoch_string) 776 777 if self.log_file is not None: 778 with open(self.log_file, "a") as f: 779 f.write(epoch_string + "\n") 780 781 save_condition = ( 782 (self.model_save_epochs != 0) 783 and (self.epoch % self.model_save_epochs == 0) 784 ) or (self.epoch == self.num_epochs) 785 786 if self.epoch > 0 and save_condition and self.model_save_path is not None: 787 epoch_s = str(self.epoch).zfill(len(str(self.num_epochs))) 788 self.save_checkpoint( 789 os.path.join(self.model_save_path, f"epoch{epoch_s}.pt") 790 ) 791 792 if self.pseudolabel and self.epoch >= self.T1 and not unlabeled: 793 self._set_pseudolabels() 794 795 if autostop_metric == "loss": 796 if len(loss_log["val"]) > autostop_interval * 2: 797 if ( 798 np.mean(loss_log["val"][-autostop_interval:]) 799 < np.mean( 800 loss_log["val"][-2 * autostop_interval : -autostop_interval] 801 ) 802 + autostop_threshold 803 ): 804 break 805 elif autostop_metric in metrics_log["val"]: 806 if len(metrics_log["val"][autostop_metric]) > autostop_interval * 2: 807 if ( 808 np.mean( 809 metrics_log["val"][autostop_metric][-autostop_interval:] 810 ) 811 < np.mean( 812 metrics_log["val"][autostop_metric][ 813 -2 * autostop_interval : -autostop_interval 814 ] 815 ) 816 + autostop_threshold 817 ): 818 break 819 820 metrics_log = {k: dict(v) for k, v in metrics_log.items()} 821 822 return loss_log, metrics_log
Train the task and return a log of epoch-average loss and metric.
You can use the autostop parameters to finish training when the parameters are not improving. It will be
stopped if the average value of autostop_metric over the last autostop_interval epochs is smaller than
the average over the previous autostop_interval epochs + autostop_threshold. For example, if the
current epoch is 120 and autostop_interval is 50, the averages over epochs 70-120 and 20-70 will be compared.
Parameters
trial : Trial
an optuna trial (for hyperparameter searches)
optimized_metric : str
the name of the metric being optimized (for hyperparameter searches)
to_ram : bool, default False
if True, the dataset will be loaded in RAM (this speeds up the calculations but can lead to crashes
if the dataset is too large)
autostop_interval : int, default 50
the number of epochs to average the autostop metric over
autostop_threshold : float, default 0.001
the autostop difference threshold
autostop_metric : str, optional
the autostop metric (can be any one of the tracked metrics of 'loss')
main_task_on : bool, default True
if False, the main task (action segmentation) will not be used in training
ssl_on : bool, default True
if False, the SSL task will not be used in training
temporal_subsampling_size : int, optional
if not None, the temporal subsampling will be used in training with the given size
loading_bar : bool, default False
if True, a loading bar will be displayed
Returns
loss_log: list a list of float loss function values for each epoch metrics_log: dict a dictionary of metric value logs (first-level keys are 'train' and 'val', second-level keys are metric names, values are lists of function values)
824 def evaluate_prediction( 825 self, 826 prediction: Union[torch.Tensor, Dict], 827 data: Union[DataLoader, BehaviorDataset, str] = None, 828 batch_size: int = 32, 829 indices: list = None, 830 ) -> Tuple: 831 """Compute metrics for a prediction. 832 833 Parameters 834 ---------- 835 prediction : torch.Tensor 836 the prediction 837 data : torch.utils.data.DataLoader | dlc2action.data.dataset.BehaviorDataset, optional 838 the data the prediction was made for (if not provided, take the validation dataset) 839 batch_size : int, default 32 840 the batch size 841 842 Returns 843 ------- 844 loss : float 845 the average value of the loss function 846 metric : dict 847 a dictionary of average values of metric functions 848 """ 849 850 if type(data) is not DataLoader: 851 dataset = self._get_dataset(data) 852 data = DataLoader(dataset, shuffle=False, batch_size=batch_size) 853 epoch_loss = 0 854 if isinstance(prediction, dict): 855 num_classes = len(self.behaviors_dict()) 856 length = dataset.len_segment() 857 coords = dataset.annotation_store.get_original_coordinates() 858 for batch in data: 859 main_target = batch["target"] 860 pr_coords = coords[batch["index"]] 861 predicted = torch.zeros((len(pr_coords), num_classes, length)) 862 for i, c in enumerate(pr_coords): 863 video_id = dataset.input_store.get_video_id(c) 864 clip_id = dataset.input_store.get_clip_id(c) 865 start, end = dataset.input_store.get_clip_start_end(c) 866 beh_ind = list(prediction[video_id]["classes"].keys()) 867 pred_tmp = prediction[video_id][clip_id][beh_ind, :] 868 predicted[i, :, : end - start] = pred_tmp[:, start:end] 869 self._compute( 870 [], 871 [], 872 predicted, 873 main_target, 874 skip_loss=True, 875 tag=batch.get("tag"), 876 apply_primary_function=False, 877 ) 878 else: 879 for batch in data: 880 main_target = batch["target"] 881 predicted = prediction[batch["index"]] 882 if not indices is None: 883 indices_new = [indices.index(i) for i in range(len(indices))] 884 predicted = predicted[:, indices_new] 885 self._compute( 886 [], 887 [], 888 predicted, 889 main_target, 890 skip_loss=True, 891 tag=batch.get("tag"), 892 apply_primary_function=False, 893 ) 894 epoch_metrics = self._calculate_metrics() 895 # strings = [ 896 # f"{metric_name} {metric_value:.3f}" 897 # for metric_name, metric_value in epoch_metrics.items() 898 # ] 899 # val_string = ", ".join(sorted(strings)) 900 # print(val_string) 901 return epoch_loss, epoch_metrics
Compute metrics for a prediction.
Parameters
prediction : torch.Tensor the prediction data : torch.utils.data.DataLoader | dlc2action.data.dataset.BehaviorDataset, optional the data the prediction was made for (if not provided, take the validation dataset) batch_size : int, default 32 the batch size
Returns
loss : float the average value of the loss function metric : dict a dictionary of average values of metric functions
903 def evaluate( 904 self, 905 data: Union[DataLoader, BehaviorDataset, str] = None, 906 augment_n: int = 0, 907 batch_size: int = 32, 908 verbose: bool = True, 909 ) -> Tuple: 910 """Evaluate the Task model. 911 912 Parameters 913 ---------- 914 data : torch.utils.data.DataLoader | dlc2action.data.dataset.BehaviorDataset, optional 915 the data to evaluate on (if not provided, evaluate on the Task validation dataset) 916 augment_n : int, default 0 917 the number of augmentations to average results over 918 batch_size : int, default 32 919 the batch size 920 verbose : bool, default True 921 if True, the process is reported to standard output 922 923 Returns 924 ------- 925 loss : float 926 the average value of the loss function 927 ssl_loss : float 928 the average value of the SSL loss function 929 metric : dict 930 a dictionary of average values of metric functions 931 932 """ 933 if self.parallel and not isinstance(self.model, nn.DataParallel): 934 self.model = MyDataParallel(self.model) 935 self.model.to(self.device) 936 if type(data) is not DataLoader: 937 data = self._get_dataset(data) 938 data = DataLoader(data, shuffle=False, batch_size=batch_size) 939 with torch.no_grad(): 940 epoch_loss, epoch_ssl_loss, epoch_metrics = self._run_epoch( 941 dataloader=data, mode="val", augment_n=augment_n, verbose=verbose 942 ) 943 val_string = f"loss {epoch_loss:.4f}" 944 for metric_name, metric_value in sorted(epoch_metrics.items()): 945 val_string += f", {metric_name} {metric_value:.3f}" 946 print(val_string) 947 return epoch_loss, epoch_ssl_loss, epoch_metrics
Evaluate the Task model.
Parameters
data : torch.utils.data.DataLoader | dlc2action.data.dataset.BehaviorDataset, optional the data to evaluate on (if not provided, evaluate on the Task validation dataset) augment_n : int, default 0 the number of augmentations to average results over batch_size : int, default 32 the batch size verbose : bool, default True if True, the process is reported to standard output
Returns
loss : float the average value of the loss function ssl_loss : float the average value of the SSL loss function metric : dict a dictionary of average values of metric functions
949 def predict( 950 self, 951 data: Union[DataLoader, BehaviorDataset, str] = None, 952 raw_output: bool = False, 953 apply_primary_function: bool = True, 954 augment_n: int = 0, 955 batch_size: int = 32, 956 train_mode: bool = False, 957 to_ram: bool = False, 958 embedding: bool = False, 959 ) -> torch.Tensor: 960 """Make a prediction with the Task model. 961 962 Parameters 963 ---------- 964 data : torch.utils.data.DataLoader | dlc2action.data.dataset.BehaviorDataset | str, optional 965 the data to evaluate on (if not provided, evaluate on the Task validation dataset) 966 raw_output : bool, default False 967 if `True`, the raw predicted probabilities are returned 968 apply_primary_function : bool, default True 969 if `True`, the primary predict function is applied (to map the model output into a shape corresponding to 970 the input) 971 augment_n : int, default 0 972 the number of augmentations to average results over 973 batch_size : int, default 32 974 the batch size 975 train_mode : bool, default False 976 if `True`, the model is used in training mode (affects dropout and batch normalization layers) 977 to_ram : bool, default False 978 if `True`, the dataset will be loaded in RAM (this speeds up the calculations but can lead to crashes 979 if the dataset is too large) 980 embedding : bool, default False 981 if `True`, the output of feature extractor is returned, ignoring the prediction module of the model 982 983 Returns 984 ------- 985 prediction : torch.Tensor 986 a prediction for the input data 987 988 """ 989 if self.parallel and not isinstance(self.model, nn.DataParallel): 990 self.model = MyDataParallel(self.model) 991 self.model.to(self.device) 992 if train_mode: 993 self.model.train() 994 else: 995 self.model.eval() 996 output = [] 997 if embedding: 998 raw_output = True 999 apply_primary_function = True 1000 if type(data) is not DataLoader: 1001 data = self._get_dataset(data) 1002 if to_ram: 1003 print("transferring dataset to RAM...") 1004 data.to_ram() 1005 data = DataLoader(data, shuffle=False, batch_size=batch_size) 1006 self.model.ssl_off() 1007 with torch.no_grad(): 1008 for batch in tqdm(data): 1009 input = {k: v.to(self.device) for k, v in batch["input"].items()} 1010 predicted, _, _ = self._get_prediction( 1011 input, 1012 batch.get("tag"), 1013 augment_n=augment_n, 1014 embedding=embedding, 1015 ) 1016 if apply_primary_function: 1017 predicted = self.primary_predict_function(predicted) 1018 if not raw_output: 1019 predicted = self.predict_function(predicted) 1020 output.append(predicted.detach().cpu()) 1021 self.model.ssl_on() 1022 output = torch.cat(output).detach() 1023 return output
Make a prediction with the Task model.
Parameters
data : torch.utils.data.DataLoader | dlc2action.data.dataset.BehaviorDataset | str, optional
the data to evaluate on (if not provided, evaluate on the Task validation dataset)
raw_output : bool, default False
if True, the raw predicted probabilities are returned
apply_primary_function : bool, default True
if True, the primary predict function is applied (to map the model output into a shape corresponding to
the input)
augment_n : int, default 0
the number of augmentations to average results over
batch_size : int, default 32
the batch size
train_mode : bool, default False
if True, the model is used in training mode (affects dropout and batch normalization layers)
to_ram : bool, default False
if True, the dataset will be loaded in RAM (this speeds up the calculations but can lead to crashes
if the dataset is too large)
embedding : bool, default False
if True, the output of feature extractor is returned, ignoring the prediction module of the model
Returns
prediction : torch.Tensor a prediction for the input data
1025 def dataset(self, mode="train") -> BehaviorDataset: 1026 """Get a dataset. 1027 1028 Parameters 1029 ---------- 1030 mode : {'train', 'val', 'test} 1031 the dataset to get 1032 1033 Returns 1034 ------- 1035 dataset : dlc2action.data.dataset.BehaviorDataset 1036 the dataset 1037 1038 """ 1039 dataloader = self.dataloader(mode) 1040 if dataloader is None: 1041 raise ValueError("The length of the dataloader is 0!") 1042 return dataloader.dataset
Get a dataset.
Parameters
mode : {'train', 'val', 'test} the dataset to get
Returns
dataset : dlc2action.data.dataset.BehaviorDataset the dataset
1044 def dataloader(self, mode: str = "train") -> DataLoader: 1045 """Get a dataloader. 1046 1047 Parameters 1048 ---------- 1049 mode : {'train', 'val', 'test} 1050 the dataset to get 1051 1052 Returns 1053 ------- 1054 dataloader : torch.utils.data.DataLoader 1055 the dataloader 1056 1057 """ 1058 if mode == "train": 1059 return self.train_dataloader 1060 elif mode == "val": 1061 return self.val_dataloader 1062 elif mode == "test": 1063 return self.test_dataloader 1064 else: 1065 raise ValueError( 1066 f'The {mode} mode is not recognized, please choose from "train", "val" or "test"' 1067 )
Get a dataloader.
Parameters
mode : {'train', 'val', 'test} the dataset to get
Returns
dataloader : torch.utils.data.DataLoader the dataloader
1097 def generate_full_length_prediction( 1098 self, dataset=None, batch_size=32, augment_n=10 1099 ): 1100 """Compile a prediction for the original input sequences. 1101 1102 Parameters 1103 ---------- 1104 dataset : BehaviorDataset, optional 1105 the dataset to generate a prediction for (if `None`, generate for the validation dataset) 1106 batch_size : int, default 32 1107 the batch size 1108 augment_n : int, default 10 1109 the number of augmentations to average results over 1110 1111 Returns 1112 ------- 1113 prediction : dict 1114 a nested dictionary where first level keys are video ids, second level keys are clip ids and values 1115 are prediction tensors 1116 1117 """ 1118 dataset = self._get_dataset(dataset) 1119 if not isinstance(dataset, BehaviorDataset): 1120 raise TypeError( 1121 f"The dataset parameter has to be either None, string, " 1122 f"BehaviorDataset or Dataloader, got {type(dataset)}" 1123 ) 1124 predicted = self.predict( 1125 dataset, 1126 raw_output=True, 1127 apply_primary_function=True, 1128 augment_n=augment_n, 1129 batch_size=batch_size, 1130 ) 1131 predicted = dataset.generate_full_length_prediction(predicted) 1132 predicted = { 1133 v_id: { 1134 clip_id: self._apply_predict_functions(v.unsqueeze(0)).squeeze() 1135 for clip_id, v in video_dict.items() 1136 } 1137 for v_id, video_dict in predicted.items() 1138 } 1139 return predicted
Compile a prediction for the original input sequences.
Parameters
dataset : BehaviorDataset, optional
the dataset to generate a prediction for (if None, generate for the validation dataset)
batch_size : int, default 32
the batch size
augment_n : int, default 10
the number of augmentations to average results over
Returns
prediction : dict a nested dictionary where first level keys are video ids, second level keys are clip ids and values are prediction tensors
1141 def generate_submission( 1142 self, frame_number_map_file, dataset=None, batch_size=32, augment_n=10 1143 ): 1144 """Generate a MABe-22 style submission dictionary. 1145 1146 Parameters 1147 ---------- 1148 frame_number_map_file : str 1149 the path to the frame number map file 1150 dataset : BehaviorDataset, optional 1151 the dataset to generate a prediction for (if `None`, generate for the validation dataset) 1152 batch_size : int, default 32 1153 the batch size 1154 augment_n : int, default 10 1155 the number of augmentations to average results over 1156 1157 Returns 1158 ------- 1159 submission : dict 1160 a dictionary with frame number mapping and embeddings 1161 1162 """ 1163 dataset = self._get_dataset(dataset) 1164 if not isinstance(dataset, BehaviorDataset): 1165 raise TypeError( 1166 f"The dataset parameter has to be either None, string, " 1167 f"BehaviorDataset or Dataloader, got {type(dataset)}" 1168 ) 1169 predicted = self.predict( 1170 dataset, 1171 raw_output=True, 1172 apply_primary_function=True, 1173 augment_n=augment_n, 1174 batch_size=batch_size, 1175 embedding=True, 1176 ) 1177 predicted = dataset.generate_full_length_prediction(predicted) 1178 frame_map = np.load(frame_number_map_file, allow_pickle=True).item() 1179 length = frame_map[list(frame_map.keys())[-1]][1] 1180 embeddings = None 1181 for video_id in list(predicted.keys()): 1182 split = video_id.split("--") 1183 if len(split) != 2 or len(predicted[video_id]) > 1: 1184 raise RuntimeError( 1185 "Generating submissions is only implemented for the mabe22 dataset!" 1186 ) 1187 if split[1] not in frame_map: 1188 raise RuntimeError(f"The {split[1]} video is not in the frame map file") 1189 v_id = split[1] 1190 clip_id = list(predicted[video_id].keys())[0] 1191 if embeddings is None: 1192 embeddings = np.zeros((length, predicted[video_id][clip_id].shape[0])) 1193 start, end = frame_map[v_id] 1194 embeddings[start:end, :] = predicted[video_id][clip_id].T 1195 predicted.pop(video_id) 1196 predicted = { 1197 "frame_number_map": frame_map, 1198 "embeddings": embeddings.astype(np.float32), 1199 } 1200 return predicted
Generate a MABe-22 style submission dictionary.
Parameters
frame_number_map_file : str
the path to the frame number map file
dataset : BehaviorDataset, optional
the dataset to generate a prediction for (if None, generate for the validation dataset)
batch_size : int, default 32
the batch size
augment_n : int, default 10
the number of augmentations to average results over
Returns
submission : dict a dictionary with frame number mapping and embeddings
1458 def visualize_results( 1459 self, 1460 save_path: str = None, 1461 add_legend: bool = True, 1462 ground_truth: bool = True, 1463 colormap: str = "viridis", 1464 hide_axes: bool = False, 1465 min_classes: int = 1, 1466 width: int = 10, 1467 whole_video: bool = False, 1468 transparent: bool = False, 1469 dataset: Union[BehaviorDataset, DataLoader, str, None] = None, 1470 drop_classes: Set = None, 1471 search_classes: Set = None, 1472 smooth_interval_prediction: int = None, 1473 font_size: float = None, 1474 num_plots: int = 1, 1475 window_size:int =400 1476 ): 1477 """Visualize random predictions. 1478 1479 Parameters 1480 ---------- 1481 save_path : str, optional 1482 the path where the plot will be saved 1483 add_legend : bool, default True 1484 if `True`, legend will be added to the plot 1485 ground_truth : bool, default True 1486 if `True`, ground truth will be added to the plot 1487 colormap : str, default 'Accent' 1488 the `matplotlib` colormap to use 1489 hide_axes : bool, default True 1490 if `True`, the axes will be hidden on the plot 1491 min_classes : int, default 1 1492 the minimum number of classes in a displayed interval 1493 width : float, default 10 1494 the width of the plot 1495 whole_video : bool, default False 1496 if `True`, whole videos are plotted instead of segments 1497 transparent : bool, default False 1498 if `True`, the background on the plot is transparent 1499 dataset : BehaviorDataset | DataLoader | str | None, optional 1500 the dataset to make the prediction for (if not provided, the validation dataset is used) 1501 drop_classes : set, optional 1502 a set of class names to not be displayed 1503 search_classes : set, optional 1504 if given, only intervals where at least one of the classes is in ground truth will be shown 1505 smooth_interval_prediction : int, optional 1506 if given, the prediction will be smoothed with a moving average of the given size 1507 1508 """ 1509 if drop_classes is None: 1510 drop_classes = [] 1511 dataset = self._get_dataset(dataset) 1512 if dataset.annotation_class() != "exclusive_classification": 1513 raise NotImplementedError( 1514 "Results visualisation is only implemented for exclusive classification datasets!" 1515 ) 1516 labels = { 1517 k: v for k, v in dataset.behaviors_dict().items() if v not in drop_classes 1518 } 1519 labels.update({-100: "unknown"}) 1520 label_keys = sorted([int(x) for x in labels.keys()]) 1521 if search_classes is None: 1522 ok = True 1523 else: 1524 ok = False 1525 classes = [] 1526 predicted = self.generate_full_length_prediction(dataset) 1527 keys = list(predicted.keys()) 1528 counter = 0 1529 max_iter = len(keys) * 2 1530 while len(classes) < min_classes or not ok: 1531 counter += 1 1532 if counter > max_iter: 1533 raise RuntimeError( 1534 "Plotting is taking too many iterations; you should probably make some of the parameters less restrictive" 1535 ) 1536 i = randint(0, len(keys) - 1) 1537 prediction = predicted[keys[i]] 1538 keys_i = list(prediction.keys()) 1539 j = randint(0, len(keys_i) - 1) 1540 prediction = prediction[keys_i[j]] 1541 key1 = keys[i] 1542 key2 = keys_i[j] 1543 1544 if smooth_interval_prediction > 0: 1545 unsmoothed_prediction = deepcopy(prediction) 1546 prediction = self._smooth(prediction, smooth_interval_prediction) 1547 height = 3 1548 else: 1549 height = 2 1550 classes = [ 1551 labels[int(x)] for x in torch.unique(prediction) if x in label_keys 1552 ] 1553 if search_classes is not None: 1554 ok = any([x in classes for x in search_classes]) 1555 fig, ax = plt.subplots(figsize=(width, height)) 1556 cmap = cm.get_cmap(colormap) if colormap != "dlc2action" else None 1557 color_list = ( 1558 [cmap(c) for c in np.linspace(0, 1, len(labels))] 1559 if colormap != "dlc2action" 1560 else dlc2action_colormaps["default"] 1561 ) 1562 1563 def _plot_prediction(prediction, height, set_labels=True): 1564 for c in label_keys: 1565 c_i = label_keys.index(int(c)) 1566 output, indices, counts = torch.unique_consecutive( 1567 prediction == c, return_inverse=True, return_counts=True 1568 ) 1569 long_indices = torch.where(output)[0] 1570 if len(long_indices) == 0: 1571 continue 1572 res_indices_start = [ 1573 (indices == i).nonzero(as_tuple=True)[0][0].item() 1574 for i in long_indices 1575 ] 1576 res_indices_end = [ 1577 (indices == i).nonzero(as_tuple=True)[0][-1].item() + 1 1578 for i in long_indices 1579 ] 1580 res_indices_len = [ 1581 end - start 1582 for start, end in zip(res_indices_start, res_indices_end) 1583 ] 1584 if set_labels: 1585 label = labels[int(c)] 1586 else: 1587 label = None 1588 ax.broken_barh( 1589 list(zip(res_indices_start, res_indices_len)), 1590 (height, 1), 1591 label=label, 1592 facecolors=color_list[c_i], 1593 ) 1594 1595 if smooth_interval_prediction > 0: 1596 _plot_prediction(unsmoothed_prediction, 0) 1597 _plot_prediction(prediction, 1.5, set_labels=False) 1598 gt_height = 3 1599 else: 1600 _plot_prediction(prediction, 0) 1601 gt_height = 1.5 1602 if ground_truth: 1603 gt = dataset.generate_full_length_gt()[key1][key2] 1604 for c in label_keys: 1605 c_i = label_keys.index(int(c)) 1606 if labels[int(c)] in classes: 1607 label = None 1608 else: 1609 label = labels[int(c)] 1610 output, indices, counts = torch.unique_consecutive( 1611 gt == c, return_inverse=True, return_counts=True 1612 ) 1613 long_indices = torch.where(output)[0] 1614 if len(long_indices) == 0: 1615 continue 1616 res_indices_start = [ 1617 (indices == i).nonzero(as_tuple=True)[0][0].item() 1618 for i in long_indices 1619 ] 1620 res_indices_end = [ 1621 (indices == i).nonzero(as_tuple=True)[0][-1].item() + 1 1622 for i in long_indices 1623 ] 1624 res_indices_len = [ 1625 end - start 1626 for start, end in zip(res_indices_start, res_indices_end) 1627 ] 1628 ax.broken_barh( 1629 list(zip(res_indices_start, res_indices_len)), 1630 (gt_height, 1), 1631 facecolors=color_list[c_i] if c != "unknown" else "gray", 1632 label=label, 1633 ) 1634 if not ground_truth: 1635 ax.axes.yaxis.set_visible(False) 1636 else: 1637 if smooth_interval_prediction > 0: 1638 ax.set_yticks([0.5, 2, 3.5]) 1639 ax.set_yticklabels(["prediction", "smoothed", "ground truth"]) 1640 else: 1641 ax.set_yticks([0.5, 2]) 1642 ax.set_yticklabels(["prediction", "ground truth"]) 1643 if add_legend: 1644 ax.legend(loc="center left", bbox_to_anchor=(1, 0.5)) 1645 if hide_axes: 1646 # plt.axis("off") 1647 plt.box(False) 1648 if font_size is not None: 1649 font = {"size": font_size} 1650 rc("font", **font) 1651 plt.title(f"{key1} -- {key2}") 1652 plt.tight_layout() 1653 for seed in range(num_plots): 1654 if not whole_video: 1655 ax.set_xlim((seed * window_size, (seed + 1) * window_size)) 1656 if save_path is not None: 1657 plt.savefig( 1658 save_path.replace(".svg", f"_{seed}_{key1} -- {key2}.svg"), transparent=transparent 1659 ) 1660 print(f"Saved in {save_path.replace('.svg', f'_{seed}_{key1} -- {key2}.svg')}") 1661 plt.show() 1662 plt.close()
Visualize random predictions.
Parameters
save_path : str, optional
the path where the plot will be saved
add_legend : bool, default True
if True, legend will be added to the plot
ground_truth : bool, default True
if True, ground truth will be added to the plot
colormap : str, default 'Accent'
the matplotlib colormap to use
hide_axes : bool, default True
if True, the axes will be hidden on the plot
min_classes : int, default 1
the minimum number of classes in a displayed interval
width : float, default 10
the width of the plot
whole_video : bool, default False
if True, whole videos are plotted instead of segments
transparent : bool, default False
if True, the background on the plot is transparent
dataset : BehaviorDataset | DataLoader | str | None, optional
the dataset to make the prediction for (if not provided, the validation dataset is used)
drop_classes : set, optional
a set of class names to not be displayed
search_classes : set, optional
if given, only intervals where at least one of the classes is in ground truth will be shown
smooth_interval_prediction : int, optional
if given, the prediction will be smoothed with a moving average of the given size
1664 def set_ssl_transformations(self, ssl_transformations): 1665 """Set SSL transformations. 1666 1667 Parameters 1668 ---------- 1669 ssl_transformations : list 1670 a list of callable SSL transformations 1671 1672 """ 1673 self.train_dataloader.dataset.set_ssl_transformations(ssl_transformations) 1674 if self.val_dataloader is not None: 1675 self.val_dataloader.dataset.set_ssl_transformations(ssl_transformations)
Set SSL transformations.
Parameters
ssl_transformations : list a list of callable SSL transformations
1677 def set_ssl_losses(self, ssl_losses: list) -> None: 1678 """Set SSL losses. 1679 1680 Parameters 1681 ---------- 1682 ssl_losses : list 1683 a list of callable SSL losses 1684 1685 """ 1686 self.ssl_losses = ssl_losses
Set SSL losses.
Parameters
ssl_losses : list a list of callable SSL losses
1688 def set_log(self, log: str) -> None: 1689 """Set the log file. 1690 1691 Parameters 1692 ---------- 1693 log: str 1694 the mew log file path 1695 1696 """ 1697 self.log_file = log
Set the log file.
Parameters
log: str the mew log file path
1699 def set_keep_target_none(self, keep_target_none: List) -> None: 1700 """Set the keep_target_none parameter of the transformer. 1701 1702 Parameters 1703 ---------- 1704 keep_target_none : list 1705 a list of bool values 1706 1707 """ 1708 self.transformer.keep_target_none = keep_target_none
Set the keep_target_none parameter of the transformer.
Parameters
keep_target_none : list a list of bool values
1710 def set_generate_ssl_input(self, generate_ssl_input: list) -> None: 1711 """Set the generate_ssl_input parameter of the transformer. 1712 1713 Parameters 1714 ---------- 1715 generate_ssl_input : list 1716 a list of bool values 1717 1718 """ 1719 self.transformer.generate_ssl_input = generate_ssl_input
Set the generate_ssl_input parameter of the transformer.
Parameters
generate_ssl_input : list a list of bool values
1721 def set_model(self, model: Model) -> None: 1722 """Set a new model. 1723 1724 Parameters 1725 ---------- 1726 model: Model 1727 the new model 1728 1729 """ 1730 self.epoch = 0 1731 self.model = model 1732 self.optimizer = self.optimizer_class( 1733 model.parameters(), lr=self.lr, weight_decay=self.weight_decay 1734 ) 1735 if self.model.process_labels: 1736 self.model.set_behaviors(list(self.behaviors_dict().values()))
Set a new model.
Parameters
model: Model the new model
1738 def set_dataloaders( 1739 self, 1740 train_dataloader: DataLoader, 1741 val_dataloader: DataLoader = None, 1742 test_dataloader: DataLoader = None, 1743 ) -> None: 1744 """Set new dataloaders. 1745 1746 Parameters 1747 ---------- 1748 train_dataloader: torch.utils.data.DataLoader 1749 the new train dataloader 1750 val_dataloader : torch.utils.data.DataLoader 1751 the new validation dataloader 1752 test_dataloader : torch.utils.data.DataLoader 1753 the new test dataloader 1754 1755 """ 1756 self.train_dataloader = train_dataloader 1757 self.val_dataloader = val_dataloader 1758 self.test_dataloader = test_dataloader
Set new dataloaders.
Parameters
train_dataloader: torch.utils.data.DataLoader the new train dataloader val_dataloader : torch.utils.data.DataLoader the new validation dataloader test_dataloader : torch.utils.data.DataLoader the new test dataloader
1760 def set_loss(self, loss: Callable) -> None: 1761 """Set new loss function. 1762 1763 Parameters 1764 ---------- 1765 loss: callable 1766 the new loss function 1767 1768 """ 1769 self.loss = loss
Set new loss function.
Parameters
loss: callable the new loss function
1771 def set_metrics(self, metrics: dict) -> None: 1772 """Set new metric. 1773 1774 Parameters 1775 ---------- 1776 metrics : dict 1777 the new metric dictionary 1778 1779 """ 1780 self.metrics = metrics
Set new metric.
Parameters
metrics : dict the new metric dictionary
1782 def set_transformer(self, transformer: Transformer) -> None: 1783 """Set a new transformer. 1784 1785 Parameters 1786 ---------- 1787 transformer: Transformer 1788 the new transformer 1789 1790 """ 1791 self.transformer = transformer
Set a new transformer.
Parameters
transformer: Transformer the new transformer
1793 def set_predict_functions( 1794 self, primary_predict_function: Callable, predict_function: Callable 1795 ) -> None: 1796 """Set new predict functions. 1797 1798 Parameters 1799 ---------- 1800 primary_predict_function : callable 1801 the new primary predict function 1802 predict_function : callable 1803 the new predict function 1804 1805 """ 1806 self.primary_predict_function = primary_predict_function 1807 self.predict_function = predict_function
Set new predict functions.
Parameters
primary_predict_function : callable the new primary predict function predict_function : callable the new predict function
1829 def count_classes(self, bouts: bool = False) -> Dict: 1830 """Get a dictionary of class counts in different modes. 1831 1832 Parameters 1833 ---------- 1834 bouts : bool, default False 1835 if `True`, instead of frame counts segment counts are returned 1836 1837 Returns 1838 ------- 1839 class_counts : dict 1840 a dictionary where first-level keys are "train", "val" and "test", second-level keys are 1841 class names and values are class counts (in frames) 1842 1843 """ 1844 class_counts = {} 1845 for x in ["train", "val", "test"]: 1846 try: 1847 class_counts[x] = self.dataset(x).count_classes(bouts) 1848 except ValueError: 1849 class_counts[x] = {k: 0 for k in self.behaviors_dict().keys()} 1850 return class_counts
Get a dictionary of class counts in different modes.
Parameters
bouts : bool, default False
if True, instead of frame counts segment counts are returned
Returns
class_counts : dict a dictionary where first-level keys are "train", "val" and "test", second-level keys are class names and values are class counts (in frames)
1852 def behaviors_dict(self) -> Dict: 1853 """Get a behavior dictionary. 1854 1855 Keys are label indices and values are label names. 1856 1857 Returns 1858 ------- 1859 behaviors_dict : dict 1860 behavior dictionary 1861 1862 """ 1863 return self.dataset().behaviors_dict()
Get a behavior dictionary.
Keys are label indices and values are label names.
Returns
behaviors_dict : dict behavior dictionary
1865 def update_parameters(self, parameters: Dict) -> None: 1866 """Update training parameters from a dictionary. 1867 1868 Parameters 1869 ---------- 1870 parameters : dict 1871 the update dictionary 1872 1873 """ 1874 self.lr = parameters.get("lr", self.lr) 1875 self.parallel = parameters.get("parallel", self.parallel) 1876 self.optimizer = self.optimizer_class(self.model.parameters(), lr=self.lr) 1877 self.verbose = parameters.get("verbose", self.verbose) 1878 self.device = parameters.get("device", self.device) 1879 if self.device == "auto": 1880 self.device = "cuda" if torch.cuda.is_available() else "cpu" 1881 self.augment_train = int(parameters.get("augment_train", self.augment_train)) 1882 self.augment_val = int(parameters.get("augment_val", self.augment_val)) 1883 ssl_weights = parameters.get("ssl_weights", self.ssl_weights) 1884 if ssl_weights is None: 1885 ssl_weights = [1 for _ in self.ssl_losses] 1886 if not isinstance(ssl_weights, Iterable): 1887 ssl_weights = [ssl_weights for _ in self.ssl_losses] 1888 self.ssl_weights = ssl_weights 1889 self.num_epochs = parameters.get("num_epochs", self.num_epochs) 1890 self.model_save_epochs = parameters.get( 1891 "model_save_epochs", self.model_save_epochs 1892 ) 1893 self.model_save_path = parameters.get("model_save_path", self.model_save_path) 1894 self.pseudolabel = parameters.get("pseudolabel", self.pseudolabel) 1895 self.T1 = int(parameters.get("pseudolabel_start", self.T1)) 1896 self.T2 = int(parameters.get("alpha_growth_stop", self.T2)) 1897 self.t = int(parameters.get("correction_interval", self.t)) 1898 self.alpha_f = parameters.get("pseudolabel_alpha_f", self.alpha_f) 1899 self.log_file = parameters.get("log_file", self.log_file)
Update training parameters from a dictionary.
Parameters
parameters : dict the update dictionary
1901 def generate_uncertainty_score( 1902 self, 1903 classes: List, 1904 augment_n: int = 0, 1905 batch_size: int = 32, 1906 method: str = "least_confidence", 1907 predicted: torch.Tensor = None, 1908 behaviors_dict: Dict = None, 1909 ) -> Dict: 1910 """Generate frame-wise scores for active learning. 1911 1912 Parameters 1913 ---------- 1914 classes : list 1915 a list of class names or indices; their confidence scores will be computed separately and stacked 1916 augment_n : int, default 0 1917 the number of augmentations to average over 1918 batch_size : int, default 32 1919 the batch size 1920 method : {"least_confidence", "entropy"} 1921 the method used to calculate the scores from the probability predictions (`"least_confidence"`: `1 - p_i` if 1922 `p_i > 0.5` or `p_i` if `p_i < 0.5`; `"entropy"`: `- p_i * log(p_i) - (1 - p_i) * log(1 - p_i)`) 1923 predicted : torch.Tensor, default None 1924 if not `None`, the predicted tensor to use instead of predicting from the model 1925 behaviors_dict : dict, default None 1926 if not `None`, the behaviors dictionary to use instead of the one from the dataset 1927 1928 Returns 1929 ------- 1930 score_dicts : dict 1931 a nested dictionary where first level keys are video ids, second level keys are clip ids and values 1932 are score tensors 1933 1934 """ 1935 dataset = self.dataset("train") 1936 if behaviors_dict is None: 1937 behaviors_dict = self.behaviors_dict() 1938 if not isinstance(dataset, BehaviorDataset): 1939 raise TypeError( 1940 f"The dataset parameter has to be either None, string, " 1941 f"BehaviorDataset or Dataloader, got {type(dataset)}" 1942 ) 1943 if predicted is None: 1944 predicted = self.predict( 1945 dataset, 1946 raw_output=True, 1947 apply_primary_function=True, 1948 augment_n=augment_n, 1949 batch_size=batch_size, 1950 ) 1951 predicted = dataset.generate_full_length_prediction(predicted) 1952 if isinstance(classes[0], str): 1953 behaviors_dict_inverse = {v: k for k, v in behaviors_dict.items()} 1954 classes = [behaviors_dict_inverse[c] for c in classes] 1955 for v_id, v in predicted.items(): 1956 for clip_id, vv in v.items(): 1957 if method == "least_confidence": 1958 predicted[v_id][clip_id][vv > 0.5] = 1 - vv[vv > 0.5] 1959 elif method == "entropy": 1960 predicted[v_id][clip_id][vv != -100] = ( 1961 -vv * torch.log(vv) - (1 - vv) * torch.log(1 - vv) 1962 )[vv != -100] 1963 elif method == "random": 1964 predicted[v_id][clip_id] = torch.rand(vv.shape) 1965 else: 1966 raise ValueError( 1967 f"The {method} method is not recognized; please choose from ['least_confidence', 'entropy']" 1968 ) 1969 predicted[v_id][clip_id][vv == -100] = 0 1970 1971 predicted = { 1972 v_id: {clip_id: v[classes, :] for clip_id, v in video_dict.items()} 1973 for v_id, video_dict in predicted.items() 1974 } 1975 return predicted
Generate frame-wise scores for active learning.
Parameters
classes : list
a list of class names or indices; their confidence scores will be computed separately and stacked
augment_n : int, default 0
the number of augmentations to average over
batch_size : int, default 32
the batch size
method : {"least_confidence", "entropy"}
the method used to calculate the scores from the probability predictions ("least_confidence": 1 - p_i if
p_i > 0.5 or p_i if p_i < 0.5; "entropy": - p_i * log(p_i) - (1 - p_i) * log(1 - p_i))
predicted : torch.Tensor, default None
if not None, the predicted tensor to use instead of predicting from the model
behaviors_dict : dict, default None
if not None, the behaviors dictionary to use instead of the one from the dataset
Returns
score_dicts : dict a nested dictionary where first level keys are video ids, second level keys are clip ids and values are score tensors
1977 def generate_bald_score( 1978 self, 1979 classes: List, 1980 augment_n: int = 0, 1981 batch_size: int = 32, 1982 num_models: int = 10, 1983 kernel_size: int = 11, 1984 ) -> Dict: 1985 """Generate frame-wise Bayesian Active Learning by Disagreement scores for active learning. 1986 1987 Parameters 1988 ---------- 1989 classes : list 1990 a list of class names or indices; their confidence scores will be computed separately and stacked 1991 augment_n : int, default 0 1992 the number of augmentations to average over 1993 batch_size : int, default 32 1994 the batch size 1995 num_models : int, default 10 1996 the number of dropout masks to apply 1997 kernel_size : int, default 11 1998 the size of the smoothing gaussian kernel 1999 2000 Returns 2001 ------- 2002 score_dicts : dict 2003 a nested dictionary where first level keys are video ids, second level keys are clip ids and values 2004 are score tensors 2005 2006 """ 2007 dataset = self.dataset("train") 2008 dataset = self._get_dataset(dataset) 2009 if not isinstance(dataset, BehaviorDataset): 2010 raise TypeError( 2011 f"The dataset parameter has to be either None, string, " 2012 f"BehaviorDataset or Dataloader, got {type(dataset)}" 2013 ) 2014 predictions = [] 2015 for _ in range(num_models): 2016 predicted = self.predict( 2017 dataset, 2018 raw_output=True, 2019 apply_primary_function=True, 2020 augment_n=augment_n, 2021 batch_size=batch_size, 2022 train_mode=True, 2023 ) 2024 predicted = dataset.generate_full_length_prediction(predicted) 2025 if isinstance(classes[0], str): 2026 behaviors_dict_inverse = { 2027 v: k for k, v in self.behaviors_dict().items() 2028 } 2029 classes = [behaviors_dict_inverse[c] for c in classes] 2030 for v_id, v in predicted.items(): 2031 for clip_id, vv in v.items(): 2032 vv[vv != -100] = (vv[vv != -100] > 0.5).int().float() 2033 predicted[v_id][clip_id] = vv 2034 predicted = { 2035 v_id: {clip_id: v[classes, :] for clip_id, v in video_dict.items()} 2036 for v_id, video_dict in predicted.items() 2037 } 2038 predictions.append(predicted) 2039 result = {v_id: {} for v_id in predictions[0]} 2040 r = range(-int(kernel_size / 2), int(kernel_size / 2) + 1) 2041 gauss = [1 / (1 * sqrt(2 * pi)) * exp(-float(x) ** 2 / (2 * 1**2)) for x in r] 2042 gauss = [x / sum(gauss) for x in gauss] 2043 kernel = torch.FloatTensor([[gauss]]) 2044 for v_id in predictions[0]: 2045 for clip_id in predictions[0][v_id]: 2046 consensus = ( 2047 ( 2048 torch.mean( 2049 torch.stack([x[v_id][clip_id] for x in predictions]), dim=0 2050 ) 2051 > 0.5 2052 ) 2053 .int() 2054 .float() 2055 ) 2056 consensus[predictions[0][v_id][clip_id] == -100] = -100 2057 result[v_id][clip_id] = torch.zeros(consensus.shape) 2058 for x in predictions: 2059 result[v_id][clip_id] += (x[v_id][clip_id] != consensus).int() 2060 result[v_id][clip_id] = result[v_id][clip_id] * 2 / num_models 2061 res = torch.zeros(result[v_id][clip_id].shape) 2062 for i in range(len(classes)): 2063 res[i, floor(kernel_size // 2) : -floor(kernel_size // 2)] = ( 2064 torch.nn.functional.conv1d( 2065 result[v_id][clip_id][i, :].unsqueeze(0).unsqueeze(0), 2066 kernel, 2067 )[0, ...] 2068 ) 2069 result[v_id][clip_id] = res 2070 return result
Generate frame-wise Bayesian Active Learning by Disagreement scores for active learning.
Parameters
classes : list a list of class names or indices; their confidence scores will be computed separately and stacked augment_n : int, default 0 the number of augmentations to average over batch_size : int, default 32 the batch size num_models : int, default 10 the number of dropout masks to apply kernel_size : int, default 11 the size of the smoothing gaussian kernel
Returns
score_dicts : dict a nested dictionary where first level keys are video ids, second level keys are clip ids and values are score tensors
2072 def get_normalization_stats(self) -> Optional[Dict]: 2073 """Get the normalization statistics of the dataset. 2074 2075 Returns 2076 ------- 2077 stats : dict 2078 a dictionary containing the mean and standard deviation of the dataset 2079 2080 """ 2081 return self.train_dataloader.dataset.stats
Get the normalization statistics of the dataset.
Returns
stats : dict a dictionary containing the mean and standard deviation of the dataset