dlc2action.task.universal_task
Training and inference
1# 2# Copyright 2020-2022 by A. Mathis Group and contributors. All rights reserved. 3# 4# This project and all its files are licensed under GNU AGPLv3 or later version. A copy is included in dlc2action/LICENSE.AGPL. 5# 6""" 7Training and inference 8""" 9 10from typing import Callable, Dict, Union, Tuple, List, Set, Any, Optional 11 12import numpy as np 13from torch.optim import Optimizer, Adam 14from torch import nn 15from torch.utils.data import DataLoader 16import torch 17from collections.abc import Iterable 18from tqdm import tqdm 19import warnings 20from random import randint 21from matplotlib import pyplot as plt 22from matplotlib import cm 23from collections import defaultdict 24from dlc2action.transformer.base_transformer import Transformer, EmptyTransformer 25from dlc2action.data.dataset import BehaviorDataset 26from dlc2action.model.base_model import Model, LoadedModel 27from dlc2action.metric.base_metric import Metric 28import os 29import random 30from copy import deepcopy, copy 31from optuna.trial import Trial 32from optuna import TrialPruned 33from math import pi, sqrt, exp, floor, ceil 34 35 36class MyDataParallel(nn.DataParallel): 37 def __init__(self, *args, **kwargs): 38 super().__init__(*args, **kwargs) 39 self.process_labels = self.module.process_labels 40 41 def freeze_feature_extractor(self): 42 self.module.freeze_feature_extractor() 43 44 def unfreeze_feature_extractor(self): 45 self.module.unfreeze_feature_extractor() 46 47 def transform_labels(self, device): 48 return self.module.transform_labels(device) 49 50 def logit_scale(self): 51 return self.module.logit_scale() 52 53 def main_task_off(self): 54 self.module.main_task_off() 55 56 def state_dict(self, *args, **kwargs): 57 return self.module.state_dict(*args, **kwargs) 58 59 def ssl_on(self): 60 self.module.ssl_on() 61 62 def ssl_off(self): 63 self.module.ssl_off() 64 65 def extract_features(self, x, start=0): 66 return self.module.extract_features(x, start) 67 68 69class Task: 70 """ 71 A universal trainer class that performs training, evaluation and prediction for all types of tasks and data 72 """ 73 74 def __init__( 75 self, 76 train_dataloader: DataLoader, 77 model: Union[nn.Module, Model], 78 loss: Callable[[torch.Tensor, torch.Tensor], float], 79 num_epochs: int = 0, 80 transformer: Transformer = None, 81 ssl_losses: List = None, 82 ssl_weights: List = None, 83 lr: float = 1e-3, 84 metrics: Dict = None, 85 val_dataloader: DataLoader = None, 86 test_dataloader: DataLoader = None, 87 optimizer: Optimizer = None, 88 device: str = "cuda", 89 verbose: bool = True, 90 log_file: Union[str, None] = None, 91 augment_train: int = 1, 92 augment_val: int = 0, 93 validation_interval: int = 1, 94 predict_function: Union[Callable[[torch.Tensor], torch.Tensor], None] = None, 95 primary_predict_function: Callable = None, 96 exclusive: bool = True, 97 ignore_tags: bool = True, 98 threshold: float = 0.5, 99 model_save_path: str = None, 100 model_save_epochs: int = 5, 101 pseudolabel: bool = False, 102 pseudolabel_start: int = 100, 103 correction_interval: int = 2, 104 pseudolabel_alpha_f: float = 3, 105 alpha_growth_stop: int = 600, 106 parallel: bool = False, 107 skip_metrics: List = None, 108 ) -> None: 109 """ 110 Parameters 111 ---------- 112 train_dataloader : torch.utils.data.DataLoader 113 a training dataloader 114 model : dlc2action.model.base_model.Model 115 a model 116 loss : callable 117 a loss function 118 num_epochs : int, default 0 119 the number of epochs 120 transformer : dlc2action.transformer.base_transformer.Transformer, optional 121 a transformer 122 ssl_losses : list, optional 123 a list of SSL losses 124 ssl_weights : list, optional 125 a list of SSL weights (if not provided initializes to 1) 126 lr : float, default 1e-3 127 learning rate 128 metrics : dict, optional 129 a list of metric functions 130 val_dataloader : torch.utils.data.DataLoader, optional 131 a validation dataloader 132 optimizer : torch.optim.Optimizer, optional 133 an optimizer (`Adam` by default) 134 device : str, default 'cuda' 135 the device to train the model on 136 verbose : bool, default True 137 if `True`, the process is described in standard output 138 log_file : str, optional 139 the path to a text file where the process will be logged 140 augment_train : {1, 0} 141 number of augmentations to apply at training 142 augment_val : int, default 0 143 number of augmentations to apply at validation 144 validation_interval : int, default 1 145 every time this number of epochs passes, validation metrics are computed 146 predict_function : callable, optional 147 a function that maps probabilities to class predictions (if not provided, a default is generated) 148 primary_predict_function : callable, optional 149 a function that maps model output to probabilities (if not provided, initialized as identity) 150 exclusive : bool, default True 151 set to False for multi-label classification 152 ignore_tags : bool, default False 153 if `True`, samples with different meta tags will be mixed in batches 154 threshold : float, default 0.5 155 the threshold used for multi-label classification default prediction function 156 model_save_path : str, optional 157 the path to the folder where model checkpoints will be saved (checkpoints will not be saved if the path 158 is not provided) 159 model_save_epochs : int, default 5 160 the interval for saving the model checkpoints (the last epoch is always saved) 161 pseudolabel : bool, default False 162 if True, the pseudolabeling procedure will be applied 163 pseudolabel_start : int, default 100 164 pseudolabeling starts after this epoch 165 correction_interval : int, default 1 166 after this number of epochs, if the pseudolabeling is on, the model is trained on the labeled data and 167 new pseudolabels are generated 168 pseudolabel_alpha_f : float, default 3 169 the maximum value of pseudolabeling alpha 170 alpha_growth_stop : int, default 600 171 pseudolabeling alpha stops growing after this epoch 172 """ 173 174 # pseudolabeling might be buggy right now -- not using it! 175 if skip_metrics is None: 176 skip_metrics = [] 177 self.train_dataloader = train_dataloader 178 self.val_dataloader = val_dataloader 179 self.test_dataloader = test_dataloader 180 self.transformer = transformer 181 self.num_epochs = num_epochs 182 self.skip_metrics = skip_metrics 183 self.verbose = verbose 184 self.augment_train = int(augment_train) 185 self.augment_val = int(augment_val) 186 self.ignore_tags = ignore_tags 187 self.validation_interval = int(validation_interval) 188 self.log_file = log_file 189 self.loss = loss 190 self.model_save_path = model_save_path 191 self.model_save_epochs = model_save_epochs 192 self.epoch = 0 193 194 if metrics is None: 195 metrics = {} 196 self.metrics = metrics 197 198 if optimizer is None: 199 optimizer = Adam 200 201 if ssl_weights is None: 202 ssl_weights = [1 for _ in ssl_losses] 203 if not isinstance(ssl_weights, Iterable): 204 ssl_weights = [ssl_weights for _ in ssl_losses] 205 self.ssl_weights = ssl_weights 206 207 self.optimizer_class = optimizer 208 self.lr = lr 209 if not isinstance(model, Model): 210 self.model = LoadedModel(model=model) 211 else: 212 self.set_model(model) 213 self.parallel = parallel 214 215 if self.transformer is None: 216 self.augment_val = 0 217 self.augment_train = 0 218 self.transformer = EmptyTransformer() 219 220 if self.augment_train > 1: 221 warnings.warn( 222 'The "augment_train" parameter is too large -> setting it to 1.' 223 ) 224 self.augment_train = 1 225 226 try: 227 if device == "auto": 228 device = "cuda" if torch.cuda.is_available() else "cpu" 229 self.device = torch.device(device) 230 except: 231 raise ("The format of the device is incorrect") 232 233 if ssl_losses is None: 234 self.ssl_losses = [lambda x, y: 0] 235 else: 236 self.ssl_losses = ssl_losses 237 238 if primary_predict_function is None: 239 if exclusive: 240 primary_predict_function = lambda x: nn.Softmax(x, dim=1) 241 else: 242 primary_predict_function = lambda x: torch.sigmoid(x) 243 self.primary_predict_function = primary_predict_function 244 245 if predict_function is None: 246 if exclusive: 247 self.predict_function = lambda x: torch.max(x.data, 1)[1] 248 else: 249 self.predict_function = lambda x: (x > threshold).int() 250 else: 251 self.predict_function = predict_function 252 253 self.pseudolabel = pseudolabel 254 self.alpha_f = pseudolabel_alpha_f 255 self.T2 = alpha_growth_stop 256 self.T1 = pseudolabel_start 257 self.t = correction_interval 258 if self.T2 <= self.T1: 259 raise ValueError( 260 f"The pseudolabel_start parameter has to be smaller than alpha_growth_stop; got " 261 f"{pseudolabel_start=} and {alpha_growth_stop=}" 262 ) 263 self.decision_thresholds = [0.5 for x in self.behaviors_dict()] 264 265 def save_checkpoint(self, checkpoint_path: str) -> None: 266 """ 267 Save a general checkpoint 268 269 Parameters 270 ---------- 271 checkpoint_path : str 272 the path where the checkpoint will be saved 273 """ 274 275 torch.save( 276 { 277 "epoch": self.epoch, 278 "model_state_dict": self.model.state_dict(), 279 "optimizer_state_dict": self.optimizer.state_dict(), 280 }, 281 checkpoint_path, 282 ) 283 284 def load_from_checkpoint( 285 self, checkpoint_path, only_model: bool = False, load_strict: bool = True 286 ) -> None: 287 """ 288 Load from a checkpoint 289 290 Parameters 291 ---------- 292 checkpoint_path : str 293 the path to the checkpoint 294 only_model : bool, default False 295 if `True`, only the model state dictionary will be loaded (and not the epoch and the optimizer state 296 dictionary) 297 load_strict : bool, default True 298 if `True`, any inconsistencies in state dictionaries are regarded as errors 299 """ 300 301 if checkpoint_path is None: 302 return 303 checkpoint = torch.load(checkpoint_path, map_location=self.device) 304 self.model.to(self.device) 305 self.model.load_state_dict(checkpoint["model_state_dict"], strict=load_strict) 306 if not only_model: 307 self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) 308 self.epoch = checkpoint["epoch"] 309 310 def save_model(self, save_path: str) -> None: 311 """ 312 Save the model state dictionary 313 314 Parameters 315 ---------- 316 save_path : str 317 the path where the state will be saved 318 """ 319 320 torch.save(self.model.state_dict(), save_path) 321 print("saved the model successfully") 322 323 def _apply_predict_functions(self, predicted): 324 """ 325 Map from model output to prediction 326 """ 327 328 predicted = self.primary_predict_function(predicted) 329 predicted = self.predict_function(predicted) 330 return predicted 331 332 def _get_prediction( 333 self, 334 main_input: Dict, 335 tags: torch.Tensor, 336 ssl_inputs: List = None, 337 ssl_targets: List = None, 338 augment_n: int = 0, 339 embedding: bool = False, 340 subsample: List = None, 341 ) -> Tuple: 342 """ 343 Get the prediction of `self.model` for input averaged over `augment_n` augmentations 344 """ 345 346 if augment_n == 0: 347 augment_n = 1 348 augment = False 349 else: 350 augment = True 351 model_input, ssl_inputs, ssl_targets = self.transformer.transform( 352 deepcopy(main_input), 353 ssl_inputs=ssl_inputs, 354 ssl_targets=ssl_targets, 355 augment=augment, 356 subsample=subsample, 357 ) 358 if not embedding: 359 predicted, ssl_predicted = self.model(model_input, ssl_inputs, tag=tags) 360 else: 361 predicted = self.model.extract_features(model_input) 362 ssl_predicted = None 363 if self.parallel and predicted is not None and len(predicted.shape) == 4: 364 predicted = predicted.reshape( 365 (-1, model_input.shape[0], *predicted.shape[2:]) 366 ) 367 if predicted is not None and augment_n > 1: 368 self.model.ssl_off() 369 for i in range(augment_n - 1): 370 model_input, *_ = self.transformer.transform( 371 deepcopy(main_input), augment=augment 372 ) 373 if not embedding: 374 pred, _ = self.model(model_input, None, tag=tags) 375 else: 376 pred = self.model.extract_features(model_input) 377 predicted += pred.detach() 378 self.model.ssl_on() 379 predicted /= augment_n 380 if self.model.process_labels: 381 class_embedding = self.model.transform_labels(predicted.device) 382 beh_dict = self.behaviors_dict() 383 class_embedding = torch.cat( 384 [class_embedding[beh_dict[k]] for k in sorted(beh_dict.keys())], 0 385 ) 386 predicted = { 387 "video": predicted, 388 "text": class_embedding, 389 "logit_scale": self.model.logit_scale(), 390 "device": predicted.device, 391 } 392 return predicted, ssl_predicted, ssl_targets 393 394 def _ssl_loss(self, ssl_predicted, ssl_targets): 395 """ 396 Apply SSL losses 397 """ 398 399 ssl_loss = [] 400 for loss, predicted, target in zip(self.ssl_losses, ssl_predicted, ssl_targets): 401 ssl_loss.append(loss(predicted, target)) 402 return ssl_loss 403 404 def _loss_function( 405 self, 406 batch: Dict, 407 augment_n: int, 408 temporal_subsampling_size: int = None, 409 skip_metrics: List = None, 410 ) -> Tuple[float, float]: 411 """ 412 Calculate the loss function and the metric for a dataloader batch 413 414 Averaging the predictions over augment_n augmentations. 415 """ 416 417 if "target" not in batch or torch.isnan(batch["target"]).all(): 418 raise ValueError("Cannot compute loss function with nan targets!") 419 main_input = {k: v.to(self.device) for k, v in batch["input"].items()} 420 main_target, ssl_targets, ssl_inputs = None, None, None 421 main_target = batch["target"].to(self.device) 422 if "ssl_targets" in batch: 423 ssl_targets = [ 424 {k: v.to(self.device) for k, v in x.items()} 425 if isinstance(x, dict) 426 else None 427 for x in batch["ssl_targets"] 428 ] 429 if "ssl_inputs" in batch: 430 ssl_inputs = [ 431 {k: v.to(self.device) for k, v in x.items()} 432 if isinstance(x, dict) 433 else None 434 for x in batch["ssl_inputs"] 435 ] 436 if temporal_subsampling_size is not None: 437 subsample = sorted( 438 random.sample( 439 range(main_target.shape[-1]), 440 int(temporal_subsampling_size * main_target.shape[-1]), 441 ) 442 ) 443 main_target = main_target[..., subsample] 444 else: 445 subsample = None 446 447 if self.ignore_tags: 448 tag = None 449 else: 450 tag = batch.get("tag") 451 predicted, ssl_predicted, ssl_targets = self._get_prediction( 452 main_input, tag, ssl_inputs, ssl_targets, augment_n, subsample=subsample 453 ) 454 del main_input, ssl_inputs 455 return self._compute( 456 ssl_predicted, 457 ssl_targets, 458 predicted, 459 main_target, 460 tag=batch.get("tag"), 461 skip_metrics=skip_metrics, 462 ) 463 464 def _compute( 465 self, 466 ssl_predicted: List, 467 ssl_targets: List, 468 predicted: torch.Tensor, 469 main_target: torch.Tensor, 470 tag: Any = None, 471 skip_loss: bool = False, 472 apply_primary_function: bool = True, 473 skip_metrics: List = None, 474 ) -> Tuple[float, float]: 475 """ 476 Compute the losses and metrics from predictions 477 """ 478 479 if skip_metrics is None: 480 skip_metrics = [] 481 if not skip_loss: 482 ssl_losses = self._ssl_loss(ssl_predicted, ssl_targets) 483 if predicted is not None: 484 loss = self.loss(predicted, main_target) 485 else: 486 loss = 0 487 else: 488 ssl_losses, loss = [], 0 489 490 if predicted is not None: 491 if isinstance(predicted, dict): 492 predicted = { 493 k: v.detach() 494 for k, v in predicted.items() 495 if isinstance(v, torch.Tensor) 496 } 497 else: 498 predicted = predicted.detach() 499 if apply_primary_function: 500 predicted = self.primary_predict_function(predicted) 501 predicted_transformed = self.predict_function(predicted) 502 for name, metric_function in self.metrics.items(): 503 if name not in skip_metrics: 504 if metric_function.needs_raw_data: 505 metric_function.update(predicted, main_target, tag) 506 else: 507 metric_function.update( 508 predicted_transformed, 509 main_target, 510 tag, 511 ) 512 return loss, ssl_losses 513 514 def _calculate_metrics(self) -> Dict: 515 """ 516 Calculate the final values of epoch metrics 517 """ 518 519 epoch_metrics = {} 520 for metric_name, metric in self.metrics.items(): 521 m = metric.calculate() 522 if type(m) is dict: 523 for k, v in m.items(): 524 if type(v) is torch.Tensor: 525 v = v.item() 526 epoch_metrics[f"{metric_name}_{k}"] = v 527 else: 528 if type(m) is torch.Tensor: 529 m = m.item() 530 epoch_metrics[metric_name] = m 531 metric.reset() 532 return epoch_metrics 533 534 def _run_epoch( 535 self, 536 dataloader: DataLoader, 537 mode: str, 538 augment_n: int, 539 verbose: bool = False, 540 unlabeled: bool = None, 541 alpha: float = 1, 542 temporal_subsampling_size: int = None, 543 ) -> Tuple: 544 """ 545 Run one epoch on dataloader 546 547 Averaging the predictions over augment_n augmentations. 548 Use "train" mode for training and "val" mode for evaluation. 549 """ 550 551 if mode == "train": 552 self.model.train() 553 elif mode == "val": 554 self.model.eval() 555 pass 556 else: 557 raise ValueError( 558 f'Mode {mode} is not recognized, please choose either "train" for training or "val" for validation' 559 ) 560 if self.ignore_tags: 561 tags = [None] 562 else: 563 tags = dataloader.dataset.get_tags() 564 epoch_loss = 0 565 epoch_ssl_loss = defaultdict(lambda: 0) 566 data_len = 0 567 set_pars = dataloader.dataset.set_indexing_parameters 568 skip_metrics = self.skip_metrics if mode == "train" else None 569 for tag in tags: 570 set_pars(unlabeled=unlabeled, tag=tag) 571 data_len += len(dataloader) 572 if verbose: 573 dataloader = tqdm(dataloader) 574 for batch in dataloader: 575 loss, ssl_losses = self._loss_function( 576 batch, 577 augment_n, 578 temporal_subsampling_size=temporal_subsampling_size, 579 skip_metrics=skip_metrics, 580 ) 581 if loss != 0: 582 loss = loss * alpha 583 epoch_loss += loss.item() 584 for i, (ssl_loss, weight) in enumerate( 585 zip(ssl_losses, self.ssl_weights) 586 ): 587 if ssl_loss != 0: 588 epoch_ssl_loss[i] += ssl_loss.item() 589 loss = loss + weight * ssl_loss 590 if mode == "train": 591 self.optimizer.zero_grad() 592 if loss.requires_grad: 593 loss.backward() 594 self.optimizer.step() 595 596 epoch_loss = epoch_loss / data_len 597 epoch_ssl_loss = {k: v / data_len for k, v in epoch_ssl_loss.items()} 598 epoch_metrics = self._calculate_metrics() 599 600 return epoch_loss, epoch_ssl_loss, epoch_metrics 601 602 def train( 603 self, 604 trial: Trial = None, 605 optimized_metric: str = None, 606 to_ram: bool = False, 607 autostop_interval: int = 30, 608 autostop_threshold: float = 0.001, 609 autostop_metric: str = None, 610 main_task_on: bool = True, 611 ssl_on: bool = True, 612 temporal_subsampling_size: int = None, 613 loading_bar: bool = False, 614 ) -> Tuple: 615 """ 616 Train the task and return a log of epoch-average loss and metric 617 618 You can use the autostop parameters to finish training when the parameters are not improving. It will be 619 stopped if the average value of `autostop_metric` over the last `autostop_interval` epochs is smaller than 620 the average over the previous `autostop_interval` epochs + `autostop_threshold`. For example, if the 621 current epoch is 120 and `autostop_interval` is 50, the averages over epochs 70-120 and 20-70 will be compared. 622 623 Parameters 624 ---------- 625 trial : Trial 626 an `optuna` trial (for hyperparameter searches) 627 optimized_metric : str 628 the name of the metric being optimized (for hyperparameter searches) 629 to_ram : bool, default False 630 if `True`, the dataset will be loaded in RAM (this speeds up the calculations but can lead to crashes 631 if the dataset is too large) 632 autostop_interval : int, default 50 633 the number of epochs to average the autostop metric over 634 autostop_threshold : float, default 0.001 635 the autostop difference threshold 636 autostop_metric : str, optional 637 the autostop metric (can be any one of the tracked metrics of `'loss'`) 638 main_task_on : bool, default True 639 if `False`, the main task (action segmentation) will not be used in training 640 ssl_on : bool, default True 641 if `False`, the SSL task will not be used in training 642 643 Returns 644 ------- 645 loss_log: list 646 a list of float loss function values for each epoch 647 metrics_log: dict 648 a dictionary of metric value logs (first-level keys are 'train' and 'val', second-level keys are metric 649 names, values are lists of function values) 650 """ 651 652 if self.parallel and not isinstance(self.model, nn.DataParallel): 653 self.model = MyDataParallel(self.model) 654 self.model.to(self.device) 655 assert autostop_metric in [None, "loss"] + list(self.metrics) 656 autostop_interval //= self.validation_interval 657 if trial is not None and optimized_metric is None: 658 raise ValueError( 659 "You need to provide the optimized metric name (optimized_metric parameter) " 660 "for optuna pruning to work!" 661 ) 662 if to_ram: 663 print("transferring datasets to RAM...") 664 self.train_dataloader.dataset.to_ram() 665 if self.val_dataloader is not None and len(self.val_dataloader) > 0: 666 self.val_dataloader.dataset.to_ram() 667 loss_log = {"train": [], "val": []} 668 metrics_log = {"train": defaultdict(lambda: []), "val": defaultdict(lambda: [])} 669 if not main_task_on: 670 self.model.main_task_off() 671 if not ssl_on: 672 self.model.ssl_off() 673 while self.epoch < self.num_epochs: 674 self.epoch += 1 675 unlabeled = None 676 alpha = 1 677 if self.pseudolabel: 678 if self.epoch >= self.T1: 679 unlabeled = (self.epoch - self.T1) % self.t != 0 680 if unlabeled: 681 alpha = self._alpha(self.epoch) 682 else: 683 unlabeled = False 684 epoch_loss, epoch_ssl_loss, epoch_metrics = self._run_epoch( 685 dataloader=self.train_dataloader, 686 mode="train", 687 augment_n=self.augment_train, 688 unlabeled=unlabeled, 689 alpha=alpha, 690 temporal_subsampling_size=temporal_subsampling_size, 691 verbose=loading_bar, 692 ) 693 loss_log["train"].append(epoch_loss) 694 epoch_string = f"[epoch {self.epoch}]" 695 if self.pseudolabel: 696 if unlabeled: 697 epoch_string += " (unlabeled)" 698 else: 699 epoch_string += " (labeled)" 700 epoch_string += f": loss {epoch_loss:.4f}" 701 702 if len(epoch_ssl_loss) != 0: 703 for key, value in sorted(epoch_ssl_loss.items()): 704 metrics_log["train"][f"ssl_loss_{key}"].append(value) 705 epoch_string += f", ssl_loss_{key} {value:.4f}" 706 707 for metric_name, metric_value in sorted(epoch_metrics.items()): 708 if metric_name not in self.skip_metrics: 709 metrics_log["train"][metric_name].append(metric_value) 710 epoch_string += f", {metric_name} {metric_value:.3f}" 711 712 if ( 713 self.val_dataloader is not None 714 and self.epoch % self.validation_interval == 0 715 ): 716 with torch.no_grad(): 717 epoch_string += "\n" 718 ( 719 val_epoch_loss, 720 val_epoch_ssl_loss, 721 val_epoch_metrics, 722 ) = self._run_epoch( 723 dataloader=self.val_dataloader, 724 mode="val", 725 augment_n=self.augment_val, 726 ) 727 loss_log["val"].append(val_epoch_loss) 728 epoch_string += f"validation: loss {val_epoch_loss:.4f}" 729 730 if len(val_epoch_ssl_loss) != 0: 731 for key, value in sorted(val_epoch_ssl_loss.items()): 732 metrics_log["val"][f"ssl_loss_{key}"].append(value) 733 epoch_string += f", ssl_loss_{key} {value:.4f}" 734 735 for metric_name, metric_value in sorted(val_epoch_metrics.items()): 736 metrics_log["val"][metric_name].append(metric_value) 737 epoch_string += f", {metric_name} {metric_value:.3f}" 738 739 if trial is not None: 740 if optimized_metric not in metrics_log["val"]: 741 raise ValueError( 742 f"The {optimized_metric} metric set for optimization is not being logged!" 743 ) 744 trial.report(metrics_log["val"][optimized_metric][-1], self.epoch) 745 if trial.should_prune(): 746 raise TrialPruned() 747 748 if self.verbose: 749 print(epoch_string) 750 751 if self.log_file is not None: 752 with open(self.log_file, "a") as f: 753 f.write(epoch_string + "\n") 754 755 save_condition = ( 756 (self.model_save_epochs != 0) 757 and (self.epoch % self.model_save_epochs == 0) 758 ) or (self.epoch == self.num_epochs) 759 760 if self.epoch > 0 and save_condition and self.model_save_path is not None: 761 epoch_s = str(self.epoch).zfill(len(str(self.num_epochs))) 762 self.save_checkpoint( 763 os.path.join(self.model_save_path, f"epoch{epoch_s}.pt") 764 ) 765 766 if self.pseudolabel and self.epoch >= self.T1 and not unlabeled: 767 self._set_pseudolabels() 768 769 if autostop_metric == "loss": 770 if len(loss_log["val"]) > autostop_interval * 2: 771 if ( 772 np.mean(loss_log["val"][-autostop_interval:]) 773 < np.mean( 774 loss_log["val"][-2 * autostop_interval : -autostop_interval] 775 ) 776 + autostop_threshold 777 ): 778 break 779 elif autostop_metric in metrics_log["val"]: 780 if len(metrics_log["val"][autostop_metric]) > autostop_interval * 2: 781 if ( 782 np.mean( 783 metrics_log["val"][autostop_metric][-autostop_interval:] 784 ) 785 < np.mean( 786 metrics_log["val"][autostop_metric][ 787 -2 * autostop_interval : -autostop_interval 788 ] 789 ) 790 + autostop_threshold 791 ): 792 break 793 794 metrics_log = {k: dict(v) for k, v in metrics_log.items()} 795 796 return loss_log, metrics_log 797 798 def evaluate_prediction( 799 self, 800 prediction: Union[torch.Tensor, Dict], 801 data: Union[DataLoader, BehaviorDataset, str] = None, 802 batch_size: int = 32, 803 ) -> Tuple: 804 """ 805 Compute metrics for a prediction 806 807 Parameters 808 ---------- 809 prediction : torch.Tensor 810 the prediction 811 data : torch.utils.data.DataLoader | dlc2action.data.dataset.BehaviorDataset, optional 812 the data the prediction was made for (if not provided, take the validation dataset) 813 batch_size : int, default 32 814 the batch size 815 816 Returns 817 ------- 818 loss : float 819 the average value of the loss function 820 metric : dict 821 a dictionary of average values of metric functions 822 """ 823 824 if type(data) is not DataLoader: 825 dataset = self._get_dataset(data) 826 data = DataLoader(dataset, shuffle=False, batch_size=batch_size) 827 epoch_loss = 0 828 if isinstance(prediction, dict): 829 num_classes = len(self.behaviors_dict()) 830 length = dataset.len_segment() 831 coords = dataset.annotation_store.get_original_coordinates() 832 for batch in data: 833 main_target = batch["target"] 834 pr_coords = coords[batch["index"]] 835 predicted = torch.zeros((len(pr_coords), num_classes, length)) 836 for i, c in enumerate(pr_coords): 837 video_id = dataset.input_store.get_video_id(c) 838 clip_id = dataset.input_store.get_clip_id(c) 839 start, end = dataset.input_store.get_clip_start_end(c) 840 predicted[i, :, : end - start] = prediction[video_id][clip_id][ 841 :, start:end 842 ] 843 self._compute( 844 [], 845 [], 846 predicted, 847 main_target, 848 skip_loss=True, 849 tag=batch.get("tag"), 850 apply_primary_function=False, 851 ) 852 else: 853 for batch in data: 854 main_target = batch["target"] 855 predicted = prediction[batch["index"]] 856 self._compute( 857 [], 858 [], 859 predicted, 860 main_target, 861 skip_loss=True, 862 tag=batch.get("tag"), 863 apply_primary_function=False, 864 ) 865 epoch_metrics = self._calculate_metrics() 866 strings = [ 867 f"{metric_name} {metric_value:.3f}" 868 for metric_name, metric_value in epoch_metrics.items() 869 ] 870 val_string = ", ".join(sorted(strings)) 871 print(val_string) 872 return epoch_loss, epoch_metrics 873 874 def evaluate( 875 self, 876 data: Union[DataLoader, BehaviorDataset, str] = None, 877 augment_n: int = 0, 878 batch_size: int = 32, 879 verbose: bool = True, 880 ) -> Tuple: 881 """ 882 Evaluate the Task model 883 884 Parameters 885 ---------- 886 data : torch.utils.data.DataLoader | dlc2action.data.dataset.BehaviorDataset, optional 887 the data to evaluate on (if not provided, evaluate on the Task validation dataset) 888 augment_n : int, default 0 889 the number of augmentations to average results over 890 batch_size : int, default 32 891 the batch size 892 verbose : bool, default True 893 if True, the process is reported to standard output 894 895 Returns 896 ------- 897 loss : float 898 the average value of the loss function 899 ssl_loss : float 900 the average value of the SSL loss function 901 metric : dict 902 a dictionary of average values of metric functions 903 """ 904 905 if self.parallel and not isinstance(self.model, nn.DataParallel): 906 self.model = MyDataParallel(self.model) 907 self.model.to(self.device) 908 if type(data) is not DataLoader: 909 data = self._get_dataset(data) 910 data = DataLoader(data, shuffle=False, batch_size=batch_size) 911 with torch.no_grad(): 912 epoch_loss, epoch_ssl_loss, epoch_metrics = self._run_epoch( 913 dataloader=data, mode="val", augment_n=augment_n, verbose=verbose 914 ) 915 val_string = f"loss {epoch_loss:.4f}" 916 for metric_name, metric_value in sorted(epoch_metrics.items()): 917 val_string += f", {metric_name} {metric_value:.3f}" 918 print(val_string) 919 return epoch_loss, epoch_ssl_loss, epoch_metrics 920 921 def predict( 922 self, 923 data: Union[DataLoader, BehaviorDataset, str] = None, 924 raw_output: bool = False, 925 apply_primary_function: bool = True, 926 augment_n: int = 0, 927 batch_size: int = 32, 928 train_mode: bool = False, 929 to_ram: bool = False, 930 embedding: bool = False, 931 ) -> torch.Tensor: 932 """ 933 Make a prediction with the Task model 934 935 Parameters 936 ---------- 937 data : torch.utils.data.DataLoader | dlc2action.data.dataset.BehaviorDataset | str, optional 938 the data to evaluate on (if not provided, evaluate on the Task validation dataset) 939 raw_output : bool, default False 940 if `True`, the raw predicted probabilities are returned 941 apply_primary_function : bool, default True 942 if `True`, the primary predict function is applied (to map the model output into a shape corresponding to 943 the input) 944 augment_n : int, default 0 945 the number of augmentations to average results over 946 batch_size : int, default 32 947 the batch size 948 train_mode : bool, default False 949 if `True`, the model is used in training mode (affects dropout and batch normalization layers) 950 to_ram : bool, default False 951 if `True`, the dataset will be loaded in RAM (this speeds up the calculations but can lead to crashes 952 if the dataset is too large) 953 embedding : bool, default False 954 if `True`, the output of feature extractor is returned, ignoring the prediction module of the model 955 956 Returns 957 ------- 958 prediction : torch.Tensor 959 a prediction for the input data 960 """ 961 962 if self.parallel and not isinstance(self.model, nn.DataParallel): 963 self.model = MyDataParallel(self.model) 964 self.model.to(self.device) 965 if train_mode: 966 self.model.train() 967 else: 968 self.model.eval() 969 output = [] 970 if embedding: 971 raw_output = True 972 apply_primary_function = True 973 if type(data) is not DataLoader: 974 data = self._get_dataset(data) 975 if to_ram: 976 print("transferring dataset to RAM...") 977 data.to_ram() 978 data = DataLoader(data, shuffle=False, batch_size=batch_size) 979 self.model.ssl_off() 980 with torch.no_grad(): 981 for batch in tqdm(data): 982 input = {k: v.to(self.device) for k, v in batch["input"].items()} 983 predicted, _, _ = self._get_prediction( 984 input, 985 batch.get("tag"), 986 augment_n=augment_n, 987 embedding=embedding, 988 ) 989 if apply_primary_function: 990 predicted = self.primary_predict_function(predicted) 991 if not raw_output: 992 predicted = self.predict_function(predicted) 993 output.append(predicted.detach().cpu()) 994 self.model.ssl_on() 995 output = torch.cat(output).detach() 996 return output 997 998 def dataset(self, mode="train") -> BehaviorDataset: 999 """ 1000 Get a dataset 1001 1002 Parameters 1003 ---------- 1004 mode : {'train', 'val', 'test} 1005 the dataset to get 1006 1007 Returns 1008 ------- 1009 dataset : dlc2action.data.dataset.BehaviorDataset 1010 the dataset 1011 """ 1012 1013 dataloader = self.dataloader(mode) 1014 if dataloader is None: 1015 raise ValueError("The length of the dataloader is 0!") 1016 return dataloader.dataset 1017 1018 def dataloader(self, mode: str = "train") -> DataLoader: 1019 """ 1020 Get a dataloader 1021 1022 Parameters 1023 ---------- 1024 mode : {'train', 'val', 'test} 1025 the dataset to get 1026 1027 Returns 1028 ------- 1029 dataloader : torch.utils.data.DataLoader 1030 the dataloader 1031 """ 1032 1033 if mode == "train": 1034 return self.train_dataloader 1035 elif mode == "val": 1036 return self.val_dataloader 1037 elif mode == "test": 1038 return self.test_dataloader 1039 else: 1040 raise ValueError( 1041 f'The {mode} mode is not recognized, please choose from "train", "val" or "test"' 1042 ) 1043 1044 def _get_dataset(self, dataset): 1045 """ 1046 Get a dataset from a dataloader, a string ('train', 'test' or 'val') or `None` (default) 1047 """ 1048 1049 if dataset is None: 1050 dataset = self.dataset("val") 1051 elif dataset in ["train", "val", "test"]: 1052 dataset = self.dataset(dataset) 1053 elif type(dataset) is DataLoader: 1054 dataset = dataset.dataset 1055 if type(dataset) is BehaviorDataset: 1056 return dataset 1057 else: 1058 raise TypeError(f"The {type(dataset)} type of dataset is not recognized!") 1059 1060 def _get_dataloader(self, dataset): 1061 """ 1062 Get a dataloader from a dataset, a string ('train', 'test' or 'val') or `None` (default) 1063 """ 1064 1065 if dataset is None: 1066 dataset = self.dataloader("val") 1067 elif dataset in ["train", "val", "test"]: 1068 dataset = self.dataloader(dataset) 1069 if dataset is None: 1070 raise ValueError(f"The length of the dataloader is 0!") 1071 elif type(dataset) is BehaviorDataset: 1072 dataset = DataLoader(dataset) 1073 if type(dataset) is DataLoader: 1074 return dataset 1075 else: 1076 raise TypeError(f"The {type(dataset)} type of dataset is not recognized!") 1077 1078 def generate_full_length_prediction( 1079 self, dataset=None, batch_size=32, augment_n=10 1080 ): 1081 """ 1082 Compile a prediction for the original input sequences 1083 1084 Parameters 1085 ---------- 1086 dataset : BehaviorDataset, optional 1087 the dataset to generate a prediction for (if `None`, generate for the validation dataset) 1088 batch_size : int, default 32 1089 the batch size 1090 augment_n : int, default 10 1091 the number of augmentations to average results over 1092 1093 Returns 1094 ------- 1095 prediction : dict 1096 a nested dictionary where first level keys are video ids, second level keys are clip ids and values 1097 are prediction tensors 1098 """ 1099 1100 dataset = self._get_dataset(dataset) 1101 if not isinstance(dataset, BehaviorDataset): 1102 raise TypeError( 1103 f"The dataset parameter has to be either None, string, " 1104 f"BehaviorDataset or Dataloader, got {type(dataset)}" 1105 ) 1106 predicted = self.predict( 1107 dataset, 1108 raw_output=True, 1109 apply_primary_function=True, 1110 augment_n=augment_n, 1111 batch_size=batch_size, 1112 ) 1113 predicted = dataset.generate_full_length_prediction(predicted) 1114 predicted = { 1115 v_id: { 1116 clip_id: self._apply_predict_functions(v.unsqueeze(0)).squeeze() 1117 for clip_id, v in video_dict.items() 1118 } 1119 for v_id, video_dict in predicted.items() 1120 } 1121 return predicted 1122 1123 def generate_submission( 1124 self, frame_number_map_file, dataset=None, batch_size=32, augment_n=10 1125 ): 1126 """ 1127 Generate a MABe-22 style submission dictionary 1128 1129 Parameters 1130 ---------- 1131 dataset : BehaviorDataset, optional 1132 the dataset to generate a prediction for (if `None`, generate for the validation dataset) 1133 batch_size : int, default 32 1134 the batch size 1135 augment_n : int, default 10 1136 the number of augmentations to average results over 1137 1138 Returns 1139 ------- 1140 submission : dict 1141 a dictionary with frame number mapping and embeddings 1142 """ 1143 1144 dataset = self._get_dataset(dataset) 1145 if not isinstance(dataset, BehaviorDataset): 1146 raise TypeError( 1147 f"The dataset parameter has to be either None, string, " 1148 f"BehaviorDataset or Dataloader, got {type(dataset)}" 1149 ) 1150 predicted = self.predict( 1151 dataset, 1152 raw_output=True, 1153 apply_primary_function=True, 1154 augment_n=augment_n, 1155 batch_size=batch_size, 1156 embedding=True, 1157 ) 1158 predicted = dataset.generate_full_length_prediction(predicted) 1159 frame_map = np.load(frame_number_map_file, allow_pickle=True).item() 1160 length = frame_map[list(frame_map.keys())[-1]][1] 1161 embeddings = None 1162 for video_id in list(predicted.keys()): 1163 split = video_id.split("--") 1164 if len(split) != 2 or len(predicted[video_id]) > 1: 1165 raise RuntimeError( 1166 "Generating submissions is only implemented for the mabe22 dataset!" 1167 ) 1168 if split[1] not in frame_map: 1169 raise RuntimeError(f"The {split[1]} video is not in the frame map file") 1170 v_id = split[1] 1171 clip_id = list(predicted[video_id].keys())[0] 1172 if embeddings is None: 1173 embeddings = np.zeros((length, predicted[video_id][clip_id].shape[0])) 1174 start, end = frame_map[v_id] 1175 embeddings[start:end, :] = predicted[video_id][clip_id].T 1176 predicted.pop(video_id) 1177 predicted = { 1178 "frame_number_map": frame_map, 1179 "embeddings": embeddings.astype(np.float32), 1180 } 1181 return predicted 1182 1183 def _get_intervals(self, tensor: torch.Tensor) -> torch.Tensor: 1184 """ 1185 Get a list of True group beginning and end indices from a boolean tensor 1186 """ 1187 1188 output, indices = torch.unique_consecutive(tensor, return_inverse=True) 1189 true_indices = torch.where(output)[0] 1190 starts = torch.tensor( 1191 [(indices == i).nonzero(as_tuple=True)[0][0] for i in true_indices] 1192 ) 1193 ends = torch.tensor( 1194 [(indices == i).nonzero(as_tuple=True)[0][-1] + 1 for i in true_indices] 1195 ) 1196 return torch.stack([starts, ends]).T 1197 1198 def _smooth(self, tensor: torch.Tensor, smooth_interval: int = 1) -> torch.Tensor: 1199 """ 1200 Get rid of jittering in a non-exclusive classification tensor 1201 1202 First, remove intervals of 0 shorter than `smooth_interval`. Then, remove intervals of 1 shorter than 1203 `smooth_interval`. 1204 """ 1205 1206 if len(tensor.shape) > 1: 1207 for c in tensor.shape[1]: 1208 intervals = self._get_intervals(tensor[:, c] == 0) 1209 interval_lengths = torch.tensor( 1210 [interval[1] - interval[0] for interval in intervals] 1211 ) 1212 short_intervals = intervals[interval_lengths <= smooth_interval] 1213 for start, end in short_intervals: 1214 tensor[start:end, c] = 1 1215 intervals = self._get_intervals(tensor[:, c] == 1) 1216 interval_lengths = torch.tensor( 1217 [interval[1] - interval[0] for interval in intervals] 1218 ) 1219 short_intervals = intervals[interval_lengths <= smooth_interval] 1220 for start, end in short_intervals: 1221 tensor[start:end, c] = 0 1222 else: 1223 for c in tensor.unique(): 1224 intervals = self._get_intervals(tensor == c) 1225 interval_lengths = torch.tensor( 1226 [interval[1] - interval[0] for interval in intervals] 1227 ) 1228 short_intervals = intervals[interval_lengths <= smooth_interval] 1229 for start, end in short_intervals: 1230 if start == 0: 1231 tensor[start:end] = tensor[end + 1] 1232 else: 1233 tensor[start:end] = tensor[start - 1] 1234 return tensor 1235 1236 def _visualize_results_label( 1237 self, 1238 label: str, 1239 save_path: str = None, 1240 add_legend: bool = True, 1241 ground_truth: bool = True, 1242 hide_axes: bool = False, 1243 width: int = 10, 1244 whole_video: bool = False, 1245 transparent: bool = False, 1246 dataset: BehaviorDataset = None, 1247 smooth_interval: int = 0, 1248 title: str = None, 1249 ): 1250 """ 1251 Visualize random predictions 1252 1253 Parameters 1254 ---------- 1255 save_path : str, optional 1256 the path where the plot will be saved 1257 add_legend : bool, default True 1258 if `True`, legend will be added to the plot 1259 ground_truth : bool, default True 1260 if `True`, ground truth will be added to the plot 1261 colormap : str, default 'viridis' 1262 the `matplotlib` colormap to use 1263 hide_axes : bool, default True 1264 if `True`, the axes will be hidden on the plot 1265 min_classes : int, default 1 1266 the minimum number of classes in a displayed interval 1267 width : float, default 10 1268 the width of the plot 1269 whole_video : bool, default False 1270 if `True`, whole videos are plotted instead of segments 1271 transparent : bool, default False 1272 if `True`, the background on the plot is transparent 1273 dataset : BehaviorDataset, optional 1274 the dataset to make the prediction for (if not provided, the validation dataset is used) 1275 drop_classes : set, optional 1276 a set of class names to not be displayed 1277 """ 1278 1279 if title is None: 1280 title = "" 1281 dataset = self._get_dataset(dataset) 1282 inverse_dict = {v: k for k, v in dataset.behaviors_dict().items()} 1283 label_ind = inverse_dict[label] 1284 labels = {1: label, -100: "unknown"} 1285 label_keys = [1, -100] 1286 color_list = ["blue", "gray"] 1287 if whole_video: 1288 predicted = self.generate_full_length_prediction(dataset) 1289 keys = list(predicted.keys()) 1290 counter = 0 1291 if whole_video: 1292 max_iter = len(keys) * 5 1293 else: 1294 max_iter = len(dataset) * 5 1295 ok = False 1296 while not ok: 1297 counter += 1 1298 if counter > max_iter: 1299 raise RuntimeError( 1300 "Plotting is taking too many iterations; you should probably make some of the parameters less restrictive" 1301 ) 1302 if whole_video: 1303 i = randint(0, len(keys) - 1) 1304 prediction = predicted[keys[i]] 1305 keys_i = list(prediction.keys()) 1306 j = randint(0, len(keys_i) - 1) 1307 full_p = prediction[keys_i[j]] 1308 prediction = prediction[keys_i[j]][label_ind] 1309 else: 1310 dataloader = DataLoader(dataset) 1311 i = randint(0, len(dataloader) - 1) 1312 for num, batch in enumerate(dataloader): 1313 if num == i: 1314 break 1315 input_data = {k: v.to(self.device) for k, v in batch["input"].items()} 1316 prediction, *_ = self._get_prediction( 1317 input_data, batch.get("tag"), augment_n=5 1318 ) 1319 prediction = self._apply_predict_functions(prediction) 1320 j = randint(0, len(prediction) - 1) 1321 full_p = prediction[j] 1322 prediction = prediction[j][label_ind] 1323 classes = [x for x in torch.unique(prediction) if int(x) in label_keys] 1324 ok = 1 in classes 1325 fig, ax = plt.subplots(figsize=(width, 2)) 1326 for c in classes: 1327 c_i = label_keys.index(int(c)) 1328 output, indices, counts = torch.unique_consecutive( 1329 prediction == c, return_inverse=True, return_counts=True 1330 ) 1331 long_indices = torch.where(output)[0] 1332 res_indices_start = [ 1333 (indices == i).nonzero(as_tuple=True)[0][0].item() for i in long_indices 1334 ] 1335 res_indices_end = [ 1336 (indices == i).nonzero(as_tuple=True)[0][-1].item() + 1 1337 for i in long_indices 1338 ] 1339 res_indices_len = [ 1340 end - start for start, end in zip(res_indices_start, res_indices_end) 1341 ] 1342 ax.broken_barh( 1343 list(zip(res_indices_start, res_indices_len)), 1344 (0, 1), 1345 label=labels[int(c)], 1346 facecolors=color_list[c_i], 1347 ) 1348 if ground_truth: 1349 gt = batch["target"][j][label_ind].to(self.device) 1350 classes_gt = [x for x in torch.unique(gt) if int(x) in label_keys] 1351 for c in classes_gt: 1352 c_i = label_keys.index(int(c)) 1353 if c in classes: 1354 label = None 1355 else: 1356 label = labels[int(c)] 1357 output, indices, counts = torch.unique_consecutive( 1358 gt == c, return_inverse=True, return_counts=True 1359 ) 1360 long_indices = torch.where(output * (counts > 5))[0] 1361 res_indices_start = [ 1362 (indices == i).nonzero(as_tuple=True)[0][0].item() 1363 for i in long_indices 1364 ] 1365 res_indices_end = [ 1366 (indices == i).nonzero(as_tuple=True)[0][-1].item() + 1 1367 for i in long_indices 1368 ] 1369 res_indices_len = [ 1370 end - start 1371 for start, end in zip(res_indices_start, res_indices_end) 1372 ] 1373 ax.broken_barh( 1374 list(zip(res_indices_start, res_indices_len)), 1375 (1.5, 1), 1376 facecolors=color_list[c_i], 1377 label=label, 1378 ) 1379 self._compute( 1380 main_target=batch["target"][j].unsqueeze(0).to(self.device), 1381 predicted=full_p.unsqueeze(0).to(self.device), 1382 ssl_targets=[], 1383 ssl_predicted=[], 1384 skip_loss=True, 1385 ) 1386 metrics = self._calculate_metrics() 1387 if smooth_interval > 0: 1388 smoothed = self._smooth(full_p, smooth_interval=smooth_interval)[ 1389 label_ind, : 1390 ] 1391 for c in classes: 1392 c_i = label_keys.index(int(c)) 1393 output, indices, counts = torch.unique_consecutive( 1394 smoothed == c, return_inverse=True, return_counts=True 1395 ) 1396 long_indices = torch.where(output)[0] 1397 res_indices_start = [ 1398 (indices == i).nonzero(as_tuple=True)[0][0].item() 1399 for i in long_indices 1400 ] 1401 res_indices_end = [ 1402 (indices == i).nonzero(as_tuple=True)[0][-1].item() + 1 1403 for i in long_indices 1404 ] 1405 res_indices_len = [ 1406 end - start 1407 for start, end in zip(res_indices_start, res_indices_end) 1408 ] 1409 ax.broken_barh( 1410 list(zip(res_indices_start, res_indices_len)), 1411 (3, 1), 1412 label=labels[int(c)], 1413 facecolors=color_list[c_i], 1414 ) 1415 keys = list(metrics.keys()) 1416 for key in keys: 1417 if key.split("_")[-1] != (str(label_ind)): 1418 metrics.pop(key) 1419 title = [title] 1420 for key, value in metrics.items(): 1421 title.append(f"{'_'.join(key.split('_')[: -1])}: {value:.2f}") 1422 title = ", ".join(title) 1423 if not ground_truth: 1424 ax.axes.yaxis.set_visible(False) 1425 else: 1426 ax.set_yticks([0.5, 2]) 1427 ax.set_yticklabels(["prediction", "ground truth"]) 1428 if add_legend: 1429 ax.legend() 1430 if hide_axes: 1431 plt.axis("off") 1432 plt.title(title) 1433 plt.xlim((0, len(prediction))) 1434 if save_path is not None: 1435 plt.savefig(save_path, transparent=transparent) 1436 plt.show() 1437 1438 def visualize_results( 1439 self, 1440 save_path: str = None, 1441 add_legend: bool = True, 1442 ground_truth: bool = True, 1443 colormap: str = "viridis", 1444 hide_axes: bool = False, 1445 min_classes: int = 1, 1446 width: int = 10, 1447 whole_video: bool = False, 1448 transparent: bool = False, 1449 dataset: Union[BehaviorDataset, DataLoader, str, None] = None, 1450 drop_classes: Set = None, 1451 search_classes: Set = None, 1452 num_samples: int = 1, 1453 smooth_interval_prediction: int = None, 1454 behavior_name: str = None, 1455 ): 1456 """ 1457 Visualize random predictions 1458 1459 Parameters 1460 ---------- 1461 save_path : str, optional 1462 the path where the plot will be saved 1463 add_legend : bool, default True 1464 if `True`, legend will be added to the plot 1465 ground_truth : bool, default True 1466 if `True`, ground truth will be added to the plot 1467 colormap : str, default 'Accent' 1468 the `matplotlib` colormap to use 1469 hide_axes : bool, default True 1470 if `True`, the axes will be hidden on the plot 1471 min_classes : int, default 1 1472 the minimum number of classes in a displayed interval 1473 width : float, default 10 1474 the width of the plot 1475 whole_video : bool, default False 1476 if `True`, whole videos are plotted instead of segments 1477 transparent : bool, default False 1478 if `True`, the background on the plot is transparent 1479 dataset : BehaviorDataset | DataLoader | str | None, optional 1480 the dataset to make the prediction for (if not provided, the validation dataset is used) 1481 drop_classes : set, optional 1482 a set of class names to not be displayed 1483 search_classes : set, optional 1484 if given, only intervals where at least one of the classes is in ground truth will be shown 1485 """ 1486 1487 if drop_classes is None: 1488 drop_classes = [] 1489 dataset = self._get_dataset(dataset) 1490 if dataset.annotation_class() == "exclusive_classification": 1491 exclusive = True 1492 elif dataset.annotation_class() == "nonexclusive_classification": 1493 exclusive = False 1494 else: 1495 raise NotImplementedError( 1496 f"Results visualisation is not implemented for {dataset.annotation_class() } datasets!" 1497 ) 1498 if exclusive: 1499 labels = { 1500 k: v 1501 for k, v in dataset.behaviors_dict().items() 1502 if v not in drop_classes 1503 } 1504 labels.update({-100: "unknown"}) 1505 else: 1506 inverse_dict = {v: k for k, v in dataset.behaviors_dict().items()} 1507 if behavior_name is None: 1508 behavior_name = list(inverse_dict.keys())[0] 1509 label_ind = inverse_dict[behavior_name] 1510 labels = {1: label, -100: "unknown"} 1511 label_keys = sorted([int(x) for x in labels.keys()]) 1512 if search_classes is None: 1513 ok = True 1514 else: 1515 ok = False 1516 classes = [] 1517 if whole_video: 1518 predicted = self.generate_full_length_prediction(dataset) 1519 keys = list(predicted.keys()) 1520 counter = 0 1521 if whole_video: 1522 max_iter = len(keys) * 2 1523 else: 1524 max_iter = len(dataset) * 2 1525 while len(classes) < min_classes or not ok: 1526 counter += 1 1527 if counter > max_iter: 1528 raise RuntimeError( 1529 "Plotting is taking too many iterations; you should probably make some of the parameters less restrictive" 1530 ) 1531 if whole_video: 1532 i = randint(0, len(keys) - 1) 1533 prediction = predicted[keys[i]] 1534 keys_i = list(prediction.keys()) 1535 j = randint(0, len(keys_i) - 1) 1536 prediction = prediction[keys_i[j]] 1537 key1 = keys[i] 1538 key2 = keys_i[j] 1539 else: 1540 dataloader = DataLoader(dataset) 1541 i = randint(0, len(dataloader) - 1) 1542 for num, batch in enumerate(dataloader): 1543 if num == i: 1544 break 1545 input = {k: v.to(self.device) for k, v in batch["input"].items()} 1546 prediction, *_ = self._get_prediction( 1547 input, batch.get("tag"), augment_n=5 1548 ) 1549 prediction = self._apply_predict_functions(prediction) 1550 j = randint(0, len(prediction) - 1) 1551 prediction = prediction[j] 1552 if not exclusive: 1553 prediction = prediction[label_ind] 1554 if smooth_interval_prediction > 0: 1555 unsmoothed_prediction = deepcopy(prediction) 1556 prediction = self._smooth(prediction, smooth_interval_prediction) 1557 height = 3 1558 else: 1559 height = 2 1560 classes = [ 1561 labels[int(x)] for x in torch.unique(prediction) if x in label_keys 1562 ] 1563 if search_classes is not None: 1564 ok = any([x in classes for x in search_classes]) 1565 fig, ax = plt.subplots(figsize=(width, height)) 1566 color_list = cm.get_cmap(colormap, lut=len(labels)).colors 1567 1568 def _plot_prediction(prediction, height, set_labels=True): 1569 for c in label_keys: 1570 c_i = label_keys.index(int(c)) 1571 output, indices, counts = torch.unique_consecutive( 1572 prediction == c, return_inverse=True, return_counts=True 1573 ) 1574 long_indices = torch.where(output)[0] 1575 if len(long_indices) == 0: 1576 continue 1577 res_indices_start = [ 1578 (indices == i).nonzero(as_tuple=True)[0][0].item() 1579 for i in long_indices 1580 ] 1581 res_indices_end = [ 1582 (indices == i).nonzero(as_tuple=True)[0][-1].item() + 1 1583 for i in long_indices 1584 ] 1585 res_indices_len = [ 1586 end - start 1587 for start, end in zip(res_indices_start, res_indices_end) 1588 ] 1589 if set_labels: 1590 label = labels[int(c)] 1591 else: 1592 label = None 1593 ax.broken_barh( 1594 list(zip(res_indices_start, res_indices_len)), 1595 (height, 1), 1596 label=label, 1597 facecolors=color_list[c_i], 1598 ) 1599 1600 if smooth_interval_prediction > 0: 1601 _plot_prediction(unsmoothed_prediction, 0) 1602 _plot_prediction(prediction, 1.5, set_labels=False) 1603 gt_height = 3 1604 else: 1605 _plot_prediction(prediction, 0) 1606 gt_height = 1.5 1607 if ground_truth: 1608 if not whole_video: 1609 gt = batch["target"][j].to(self.device) 1610 else: 1611 gt = dataset.generate_full_length_gt()[key1][key2] 1612 for c in label_keys: 1613 c_i = label_keys.index(int(c)) 1614 if labels[int(c)] in classes: 1615 label = None 1616 else: 1617 label = labels[int(c)] 1618 output, indices, counts = torch.unique_consecutive( 1619 gt == c, return_inverse=True, return_counts=True 1620 ) 1621 long_indices = torch.where(output)[0] 1622 if len(long_indices) == 0: 1623 continue 1624 res_indices_start = [ 1625 (indices == i).nonzero(as_tuple=True)[0][0].item() 1626 for i in long_indices 1627 ] 1628 res_indices_end = [ 1629 (indices == i).nonzero(as_tuple=True)[0][-1].item() + 1 1630 for i in long_indices 1631 ] 1632 res_indices_len = [ 1633 end - start 1634 for start, end in zip(res_indices_start, res_indices_end) 1635 ] 1636 ax.broken_barh( 1637 list(zip(res_indices_start, res_indices_len)), 1638 (gt_height, 1), 1639 facecolors=color_list[c_i] if c != "unknown" else "gray", 1640 label=label, 1641 ) 1642 if not ground_truth: 1643 ax.axes.yaxis.set_visible(False) 1644 else: 1645 if smooth_interval_prediction > 0: 1646 ax.set_yticks([0.5, 2, 3.5]) 1647 ax.set_yticklabels(["prediction", "smoothed", "ground truth"]) 1648 else: 1649 ax.set_yticks([0.5, 2]) 1650 ax.set_yticklabels(["prediction", "ground truth"]) 1651 if add_legend: 1652 ax.legend(loc="center left", bbox_to_anchor=(1, 0.5)) 1653 if hide_axes: 1654 plt.axis("off") 1655 if save_path is not None: 1656 plt.savefig(save_path, transparent=transparent) 1657 plt.show() 1658 1659 def set_ssl_transformations(self, ssl_transformations): 1660 self.train_dataloader.dataset.set_ssl_transformations(ssl_transformations) 1661 if self.val_dataloader is not None: 1662 self.val_dataloader.dataset.set_ssl_transformations(ssl_transformations) 1663 1664 def set_ssl_losses(self, ssl_losses: list) -> None: 1665 """ 1666 Set SSL losses 1667 1668 Parameters 1669 ---------- 1670 ssl_losses : list 1671 a list of callable SSL losses 1672 """ 1673 1674 self.ssl_losses = ssl_losses 1675 1676 def set_log(self, log: str) -> None: 1677 """ 1678 Set the log file 1679 1680 Parameters 1681 ---------- 1682 log: str 1683 the mew log file path 1684 """ 1685 1686 self.log_file = log 1687 1688 def set_keep_target_none(self, keep_target_none: List) -> None: 1689 """ 1690 Set the keep_target_none parameter of the transformer 1691 1692 Parameters 1693 ---------- 1694 keep_target_none : list 1695 a list of bool values 1696 """ 1697 1698 self.transformer.keep_target_none = keep_target_none 1699 1700 def set_generate_ssl_input(self, generate_ssl_input: list) -> None: 1701 """ 1702 Set the generate_ssl_input parameter of the transformer 1703 1704 Parameters 1705 ---------- 1706 generate_ssl_input : list 1707 a list of bool values 1708 """ 1709 1710 self.transformer.generate_ssl_input = generate_ssl_input 1711 1712 def set_model(self, model: Model) -> None: 1713 """ 1714 Set a new model 1715 1716 Parameters 1717 ---------- 1718 model: Model 1719 the new model 1720 """ 1721 1722 self.epoch = 0 1723 self.model = model 1724 self.optimizer = self.optimizer_class(model.parameters(), lr=self.lr) 1725 if self.model.process_labels: 1726 self.model.set_behaviors(list(self.behaviors_dict().values())) 1727 1728 def set_dataloaders( 1729 self, 1730 train_dataloader: DataLoader, 1731 val_dataloader: DataLoader = None, 1732 test_dataloader: DataLoader = None, 1733 ) -> None: 1734 """ 1735 Set new dataloaders 1736 1737 Parameters 1738 ---------- 1739 train_dataloader: torch.utils.data.DataLoader 1740 the new train dataloader 1741 val_dataloader : torch.utils.data.DataLoader 1742 the new validation dataloader 1743 test_dataloader : torch.utils.data.DataLoader 1744 the new test dataloader 1745 """ 1746 1747 self.train_dataloader = train_dataloader 1748 self.val_dataloader = val_dataloader 1749 self.test_dataloader = test_dataloader 1750 1751 def set_loss(self, loss: Callable) -> None: 1752 """ 1753 Set new loss function 1754 1755 Parameters 1756 ---------- 1757 loss: callable 1758 the new loss function 1759 """ 1760 1761 self.loss = loss 1762 1763 def set_metrics(self, metrics: dict) -> None: 1764 """ 1765 Set new metric 1766 1767 Parameters 1768 ---------- 1769 metrics : dict 1770 the new metric dictionary 1771 """ 1772 1773 self.metrics = metrics 1774 1775 def set_transformer(self, transformer: Transformer) -> None: 1776 """ 1777 Set a new transformer 1778 1779 Parameters 1780 ---------- 1781 transformer: Transformer 1782 the new transformer 1783 """ 1784 1785 self.transformer = transformer 1786 1787 def set_predict_functions( 1788 self, primary_predict_function: Callable, predict_function: Callable 1789 ) -> None: 1790 """ 1791 Set new predict functions 1792 1793 Parameters 1794 ---------- 1795 primary_predict_function : callable 1796 the new primary predict function 1797 predict_function : callable 1798 the new predict function 1799 """ 1800 1801 self.primary_predict_function = primary_predict_function 1802 self.predict_function = predict_function 1803 1804 def _set_pseudolabels(self): 1805 """ 1806 Set pseudolabels 1807 """ 1808 1809 self.train_dataloader.dataset.set_unlabeled(True) 1810 predicted = self.predict( 1811 data=self.dataset("train"), 1812 raw_output=False, 1813 augment_n=self.augment_val, 1814 ssl_off=True, 1815 ) 1816 self.train_dataloader.dataset.set_annotation(predicted.detach()) 1817 1818 def _alpha(self, epoch): 1819 """ 1820 Get the current pseudolabeling alpha parameter 1821 """ 1822 1823 if epoch <= self.T1: 1824 return 0 1825 elif epoch < self.T2: 1826 return self.alpha_f * (epoch - self.T1) / (self.T2 - self.T1) 1827 else: 1828 return self.alpha_f 1829 1830 def count_classes(self, bouts: bool = False) -> Dict: 1831 """ 1832 Get a dictionary of class counts in different modes 1833 1834 Parameters 1835 ---------- 1836 bouts : bool, default False 1837 if `True`, instead of frame counts segment counts are returned 1838 1839 Returns 1840 ------- 1841 class_counts : dict 1842 a dictionary where first-level keys are "train", "val" and "test", second-level keys are 1843 class names and values are class counts (in frames) 1844 """ 1845 1846 class_counts = {} 1847 for x in ["train", "val", "test"]: 1848 try: 1849 class_counts[x] = self.dataset(x).count_classes(bouts) 1850 except ValueError: 1851 class_counts[x] = {k: 0 for k in self.behaviors_dict().keys()} 1852 return class_counts 1853 1854 def behaviors_dict(self) -> Dict: 1855 """ 1856 Get a behavior dictionary 1857 1858 Keys are label indices and values are label names. 1859 1860 Returns 1861 ------- 1862 behaviors_dict : dict 1863 behavior dictionary 1864 """ 1865 1866 return self.dataset().behaviors_dict() 1867 1868 def update_parameters(self, parameters: Dict) -> None: 1869 """ 1870 Update training parameters from a dictionary 1871 1872 Parameters 1873 ---------- 1874 parameters : dict 1875 the update dictionary 1876 """ 1877 1878 self.lr = parameters.get("lr", self.lr) 1879 self.parallel = parameters.get("parallel", self.parallel) 1880 self.optimizer = self.optimizer_class(self.model.parameters(), lr=self.lr) 1881 self.verbose = parameters.get("verbose", self.verbose) 1882 self.device = parameters.get("device", self.device) 1883 if self.device == "auto": 1884 self.device = "cuda" if torch.cuda.is_available() else "cpu" 1885 self.augment_train = int(parameters.get("augment_train", self.augment_train)) 1886 self.augment_val = int(parameters.get("augment_val", self.augment_val)) 1887 ssl_weights = parameters.get("ssl_weights", self.ssl_weights) 1888 if ssl_weights is None: 1889 ssl_weights = [1 for _ in self.ssl_losses] 1890 if not isinstance(ssl_weights, Iterable): 1891 ssl_weights = [ssl_weights for _ in self.ssl_losses] 1892 self.ssl_weights = ssl_weights 1893 self.num_epochs = parameters.get("num_epochs", self.num_epochs) 1894 self.model_save_epochs = parameters.get( 1895 "model_save_epochs", self.model_save_epochs 1896 ) 1897 self.model_save_path = parameters.get("model_save_path", self.model_save_path) 1898 self.pseudolabel = parameters.get("pseudolabel", self.pseudolabel) 1899 self.T1 = int(parameters.get("pseudolabel_start", self.T1)) 1900 self.T2 = int(parameters.get("alpha_growth_stop", self.T2)) 1901 self.t = int(parameters.get("correction_interval", self.t)) 1902 self.alpha_f = parameters.get("pseudolabel_alpha_f", self.alpha_f) 1903 self.log_file = parameters.get("log_file", self.log_file) 1904 1905 def generate_uncertainty_score( 1906 self, 1907 classes: List, 1908 augment_n: int = 0, 1909 batch_size: int = 32, 1910 method: str = "least_confidence", 1911 predicted: torch.Tensor = None, 1912 behaviors_dict: Dict = None, 1913 ) -> Dict: 1914 """ 1915 Generate frame-wise scores for active learning 1916 1917 Parameters 1918 ---------- 1919 classes : list 1920 a list of class names or indices; their confidence scores will be computed separately and stacked 1921 augment_n : int, default 0 1922 the number of augmentations to average over 1923 batch_size : int, default 32 1924 the batch size 1925 method : {"least_confidence", "entropy"} 1926 the method used to calculate the scores from the probability predictions (`"least_confidence"`: `1 - p_i` if 1927 `p_i > 0.5` or `p_i` if `p_i < 0.5`; `"entropy"`: `- p_i * log(p_i) - (1 - p_i) * log(1 - p_i)`) 1928 1929 Returns 1930 ------- 1931 score_dicts : dict 1932 a nested dictionary where first level keys are video ids, second level keys are clip ids and values 1933 are score tensors 1934 """ 1935 1936 dataset = self.dataset("train") 1937 if behaviors_dict is None: 1938 behaviors_dict = self.behaviors_dict() 1939 if not isinstance(dataset, BehaviorDataset): 1940 raise TypeError( 1941 f"The dataset parameter has to be either None, string, " 1942 f"BehaviorDataset or Dataloader, got {type(dataset)}" 1943 ) 1944 if predicted is None: 1945 predicted = self.predict( 1946 dataset, 1947 raw_output=True, 1948 apply_primary_function=True, 1949 augment_n=augment_n, 1950 batch_size=batch_size, 1951 ) 1952 predicted = dataset.generate_full_length_prediction(predicted) 1953 if isinstance(classes[0], str): 1954 behaviors_dict_inverse = {v: k for k, v in behaviors_dict.items()} 1955 classes = [behaviors_dict_inverse[c] for c in classes] 1956 for v_id, v in predicted.items(): 1957 for clip_id, vv in v.items(): 1958 if method == "least_confidence": 1959 predicted[v_id][clip_id][vv > 0.5] = 1 - vv[vv > 0.5] 1960 elif method == "entropy": 1961 predicted[v_id][clip_id][vv != -100] = ( 1962 -vv * torch.log(vv) - (1 - vv) * torch.log(1 - vv) 1963 )[vv != -100] 1964 elif method == "random": 1965 predicted[v_id][clip_id] = torch.rand(vv.shape) 1966 else: 1967 raise ValueError( 1968 f"The {method} method is not recognized; please choose from ['least_confidence', 'entropy']" 1969 ) 1970 predicted[v_id][clip_id][vv == -100] = 0 1971 1972 predicted = { 1973 v_id: {clip_id: v[classes, :] for clip_id, v in video_dict.items()} 1974 for v_id, video_dict in predicted.items() 1975 } 1976 return predicted 1977 1978 def generate_bald_score( 1979 self, 1980 classes: List, 1981 augment_n: int = 0, 1982 batch_size: int = 32, 1983 num_models: int = 10, 1984 kernel_size: int = 11, 1985 ) -> Dict: 1986 """ 1987 Generate frame-wise Bayesian Active Learning by Disagreement scores for active learning 1988 1989 Parameters 1990 ---------- 1991 classes : list 1992 a list of class names or indices; their confidence scores will be computed separately and stacked 1993 augment_n : int, default 0 1994 the number of augmentations to average over 1995 batch_size : int, default 32 1996 the batch size 1997 num_models : int, default 10 1998 the number of dropout masks to apply 1999 kernel_size : int, default 11 2000 the size of the smoothing gaussian kernel 2001 2002 Returns 2003 ------- 2004 score_dicts : dict 2005 a nested dictionary where first level keys are video ids, second level keys are clip ids and values 2006 are score tensors 2007 """ 2008 2009 dataset = self.dataset("train") 2010 dataset = self._get_dataset(dataset) 2011 if not isinstance(dataset, BehaviorDataset): 2012 raise TypeError( 2013 f"The dataset parameter has to be either None, string, " 2014 f"BehaviorDataset or Dataloader, got {type(dataset)}" 2015 ) 2016 predictions = [] 2017 for _ in range(num_models): 2018 predicted = self.predict( 2019 dataset, 2020 raw_output=True, 2021 apply_primary_function=True, 2022 augment_n=augment_n, 2023 batch_size=batch_size, 2024 train_mode=True, 2025 ) 2026 predicted = dataset.generate_full_length_prediction(predicted) 2027 if isinstance(classes[0], str): 2028 behaviors_dict_inverse = { 2029 v: k for k, v in self.behaviors_dict().items() 2030 } 2031 classes = [behaviors_dict_inverse[c] for c in classes] 2032 for v_id, v in predicted.items(): 2033 for clip_id, vv in v.items(): 2034 vv[vv != -100] = (vv[vv != -100] > 0.5).int().float() 2035 predicted[v_id][clip_id] = vv 2036 predicted = { 2037 v_id: {clip_id: v[classes, :] for clip_id, v in video_dict.items()} 2038 for v_id, video_dict in predicted.items() 2039 } 2040 predictions.append(predicted) 2041 result = {v_id: {} for v_id in predictions[0]} 2042 r = range(-int(kernel_size / 2), int(kernel_size / 2) + 1) 2043 gauss = [1 / (1 * sqrt(2 * pi)) * exp(-float(x) ** 2 / (2 * 1**2)) for x in r] 2044 gauss = [x / sum(gauss) for x in gauss] 2045 kernel = torch.FloatTensor([[gauss]]) 2046 for v_id in predictions[0]: 2047 for clip_id in predictions[0][v_id]: 2048 consensus = ( 2049 ( 2050 torch.mean( 2051 torch.stack([x[v_id][clip_id] for x in predictions]), dim=0 2052 ) 2053 > 0.5 2054 ) 2055 .int() 2056 .float() 2057 ) 2058 consensus[predictions[0][v_id][clip_id] == -100] = -100 2059 result[v_id][clip_id] = torch.zeros(consensus.shape) 2060 for x in predictions: 2061 result[v_id][clip_id] += (x[v_id][clip_id] != consensus).int() 2062 result[v_id][clip_id] = result[v_id][clip_id] * 2 / num_models 2063 res = torch.zeros(result[v_id][clip_id].shape) 2064 for i in range(len(classes)): 2065 res[ 2066 i, floor(kernel_size // 2) : -floor(kernel_size // 2) 2067 ] = torch.nn.functional.conv1d( 2068 result[v_id][clip_id][i, :].unsqueeze(0).unsqueeze(0), kernel 2069 )[ 2070 0, ... 2071 ] 2072 result[v_id][clip_id] = res 2073 return result 2074 2075 def get_normalization_stats(self) -> Optional[Dict]: 2076 return self.train_dataloader.dataset.stats
37class MyDataParallel(nn.DataParallel): 38 def __init__(self, *args, **kwargs): 39 super().__init__(*args, **kwargs) 40 self.process_labels = self.module.process_labels 41 42 def freeze_feature_extractor(self): 43 self.module.freeze_feature_extractor() 44 45 def unfreeze_feature_extractor(self): 46 self.module.unfreeze_feature_extractor() 47 48 def transform_labels(self, device): 49 return self.module.transform_labels(device) 50 51 def logit_scale(self): 52 return self.module.logit_scale() 53 54 def main_task_off(self): 55 self.module.main_task_off() 56 57 def state_dict(self, *args, **kwargs): 58 return self.module.state_dict(*args, **kwargs) 59 60 def ssl_on(self): 61 self.module.ssl_on() 62 63 def ssl_off(self): 64 self.module.ssl_off() 65 66 def extract_features(self, x, start=0): 67 return self.module.extract_features(x, start)
Implements data parallelism at the module level.
This container parallelizes the application of the given module
by
splitting the input across the specified devices by chunking in the batch
dimension (other objects will be copied once per device). In the forward
pass, the module is replicated on each device, and each replica handles a
portion of the input. During the backwards pass, gradients from each replica
are summed into the original module.
The batch size should be larger than the number of GPUs used.
It is recommended to use ~torch.nn.parallel.DistributedDataParallel
,
instead of this class, to do multi-GPU training, even if there is only a single
node. See: :ref:cuda-nn-ddp-instead
and :ref:ddp
.
Arbitrary positional and keyword inputs are allowed to be passed into DataParallel but some types are specially handled. tensors will be scattered on dim specified (default 0). tuple, list and dict types will be shallow copied. The other types will be shared among different threads and can be corrupted if written to in the model's forward pass.
The parallelized module
must have its parameters and buffers on
device_ids[0]
before running this ~torch.nn.DataParallel
module.
In each forward, module
is replicated on each device, so any
updates to the running module in forward
will be lost. For example,
if module
has a counter attribute that is incremented in each
forward
, it will always stay at the initial value because the update
is done on the replicas which are destroyed after forward
. However,
~torch.nn.DataParallel
guarantees that the replica on
device[0]
will have its parameters and buffers sharing storage with
the base parallelized module
. So in-place updates to the
parameters or buffers on device[0]
will be recorded. E.g.,
~torch.nn.BatchNorm2d
and ~torch.nn.utils.spectral_norm
rely on this behavior to update the buffers.
Forward and backward hooks defined on module
and its submodules
will be invoked len(device_ids)
times, each with inputs located on
a particular device. Particularly, the hooks are only guaranteed to be
executed in correct order with respect to operations on corresponding
devices. For example, it is not guaranteed that hooks set via
~torch.nn.Module.register_forward_pre_hook
be executed before
all
len(device_ids)
~torch.nn.Module.forward
calls, but
that each such hook be executed before the corresponding
~torch.nn.Module.forward
call of that device.
When module
returns a scalar (i.e., 0-dimensional tensor) in
forward
, this wrapper will return a vector of length equal to
number of devices used in data parallelism, containing the result from
each device.
There is a subtlety in using the
pack sequence -> recurrent network -> unpack sequence
pattern in a
~torch.nn.Module
wrapped in ~torch.nn.DataParallel
.
See :ref:pack-rnn-unpack-with-data-parallelism
section in FAQ for
details.
Args: module (Module): module to be parallelized device_ids (list of int or torch.device): CUDA devices (default: all devices) output_device (int or torch.device): device location of output (default: device_ids[0])
Attributes: module (Module): the module to be parallelized
Example::
>>> net = torch.nn.DataParallel(model, device_ids=[0, 1, 2])
>>> output = net(input_var) # input_var can be on any device, including CPU
38 def __init__(self, *args, **kwargs): 39 super().__init__(*args, **kwargs) 40 self.process_labels = self.module.process_labels
Initializes internal Module state, shared by both nn.Module and ScriptModule.
Returns a dictionary containing a whole state of the module.
Both parameters and persistent buffers (e.g. running averages) are
included. Keys are corresponding parameter and buffer names.
Parameters and buffers set to None
are not included.
Currently state_dict()
also accepts positional arguments for
destination
, prefix
and keep_vars
in order. However,
this is being deprecated and keyword arguments will be enforced in
future releases.
Please avoid the use of argument destination
as it is not
designed for end-users.
Args:
destination (dict, optional): If provided, the state of module will
be updated into the dict and the same object is returned.
Otherwise, an OrderedDict
will be created and returned.
Default: None
.
prefix (str, optional): a prefix added to parameter and buffer
names to compose the keys in state_dict. Default: ''
.
keep_vars (bool, optional): by default the ~torch.Tensor
s
returned in the state dict are detached from autograd. If it's
set to True
, detaching will not be performed.
Default: False
.
Returns: dict: a dictionary containing a whole state of the module
Example::
>>> module.state_dict().keys()
['bias', 'weight']
Inherited Members
- torch.nn.parallel.data_parallel.DataParallel
- forward
- replicate
- scatter
- parallel_apply
- gather
- torch.nn.modules.module.Module
- dump_patches
- register_buffer
- register_parameter
- add_module
- register_module
- get_submodule
- get_parameter
- get_buffer
- get_extra_state
- set_extra_state
- apply
- cuda
- ipu
- xpu
- cpu
- type
- float
- double
- half
- bfloat16
- to_empty
- to
- register_backward_hook
- register_full_backward_hook
- register_forward_pre_hook
- register_forward_hook
- T_destination
- register_load_state_dict_post_hook
- load_state_dict
- parameters
- named_parameters
- buffers
- named_buffers
- children
- named_children
- modules
- named_modules
- train
- eval
- requires_grad_
- zero_grad
- extra_repr
70class Task: 71 """ 72 A universal trainer class that performs training, evaluation and prediction for all types of tasks and data 73 """ 74 75 def __init__( 76 self, 77 train_dataloader: DataLoader, 78 model: Union[nn.Module, Model], 79 loss: Callable[[torch.Tensor, torch.Tensor], float], 80 num_epochs: int = 0, 81 transformer: Transformer = None, 82 ssl_losses: List = None, 83 ssl_weights: List = None, 84 lr: float = 1e-3, 85 metrics: Dict = None, 86 val_dataloader: DataLoader = None, 87 test_dataloader: DataLoader = None, 88 optimizer: Optimizer = None, 89 device: str = "cuda", 90 verbose: bool = True, 91 log_file: Union[str, None] = None, 92 augment_train: int = 1, 93 augment_val: int = 0, 94 validation_interval: int = 1, 95 predict_function: Union[Callable[[torch.Tensor], torch.Tensor], None] = None, 96 primary_predict_function: Callable = None, 97 exclusive: bool = True, 98 ignore_tags: bool = True, 99 threshold: float = 0.5, 100 model_save_path: str = None, 101 model_save_epochs: int = 5, 102 pseudolabel: bool = False, 103 pseudolabel_start: int = 100, 104 correction_interval: int = 2, 105 pseudolabel_alpha_f: float = 3, 106 alpha_growth_stop: int = 600, 107 parallel: bool = False, 108 skip_metrics: List = None, 109 ) -> None: 110 """ 111 Parameters 112 ---------- 113 train_dataloader : torch.utils.data.DataLoader 114 a training dataloader 115 model : dlc2action.model.base_model.Model 116 a model 117 loss : callable 118 a loss function 119 num_epochs : int, default 0 120 the number of epochs 121 transformer : dlc2action.transformer.base_transformer.Transformer, optional 122 a transformer 123 ssl_losses : list, optional 124 a list of SSL losses 125 ssl_weights : list, optional 126 a list of SSL weights (if not provided initializes to 1) 127 lr : float, default 1e-3 128 learning rate 129 metrics : dict, optional 130 a list of metric functions 131 val_dataloader : torch.utils.data.DataLoader, optional 132 a validation dataloader 133 optimizer : torch.optim.Optimizer, optional 134 an optimizer (`Adam` by default) 135 device : str, default 'cuda' 136 the device to train the model on 137 verbose : bool, default True 138 if `True`, the process is described in standard output 139 log_file : str, optional 140 the path to a text file where the process will be logged 141 augment_train : {1, 0} 142 number of augmentations to apply at training 143 augment_val : int, default 0 144 number of augmentations to apply at validation 145 validation_interval : int, default 1 146 every time this number of epochs passes, validation metrics are computed 147 predict_function : callable, optional 148 a function that maps probabilities to class predictions (if not provided, a default is generated) 149 primary_predict_function : callable, optional 150 a function that maps model output to probabilities (if not provided, initialized as identity) 151 exclusive : bool, default True 152 set to False for multi-label classification 153 ignore_tags : bool, default False 154 if `True`, samples with different meta tags will be mixed in batches 155 threshold : float, default 0.5 156 the threshold used for multi-label classification default prediction function 157 model_save_path : str, optional 158 the path to the folder where model checkpoints will be saved (checkpoints will not be saved if the path 159 is not provided) 160 model_save_epochs : int, default 5 161 the interval for saving the model checkpoints (the last epoch is always saved) 162 pseudolabel : bool, default False 163 if True, the pseudolabeling procedure will be applied 164 pseudolabel_start : int, default 100 165 pseudolabeling starts after this epoch 166 correction_interval : int, default 1 167 after this number of epochs, if the pseudolabeling is on, the model is trained on the labeled data and 168 new pseudolabels are generated 169 pseudolabel_alpha_f : float, default 3 170 the maximum value of pseudolabeling alpha 171 alpha_growth_stop : int, default 600 172 pseudolabeling alpha stops growing after this epoch 173 """ 174 175 # pseudolabeling might be buggy right now -- not using it! 176 if skip_metrics is None: 177 skip_metrics = [] 178 self.train_dataloader = train_dataloader 179 self.val_dataloader = val_dataloader 180 self.test_dataloader = test_dataloader 181 self.transformer = transformer 182 self.num_epochs = num_epochs 183 self.skip_metrics = skip_metrics 184 self.verbose = verbose 185 self.augment_train = int(augment_train) 186 self.augment_val = int(augment_val) 187 self.ignore_tags = ignore_tags 188 self.validation_interval = int(validation_interval) 189 self.log_file = log_file 190 self.loss = loss 191 self.model_save_path = model_save_path 192 self.model_save_epochs = model_save_epochs 193 self.epoch = 0 194 195 if metrics is None: 196 metrics = {} 197 self.metrics = metrics 198 199 if optimizer is None: 200 optimizer = Adam 201 202 if ssl_weights is None: 203 ssl_weights = [1 for _ in ssl_losses] 204 if not isinstance(ssl_weights, Iterable): 205 ssl_weights = [ssl_weights for _ in ssl_losses] 206 self.ssl_weights = ssl_weights 207 208 self.optimizer_class = optimizer 209 self.lr = lr 210 if not isinstance(model, Model): 211 self.model = LoadedModel(model=model) 212 else: 213 self.set_model(model) 214 self.parallel = parallel 215 216 if self.transformer is None: 217 self.augment_val = 0 218 self.augment_train = 0 219 self.transformer = EmptyTransformer() 220 221 if self.augment_train > 1: 222 warnings.warn( 223 'The "augment_train" parameter is too large -> setting it to 1.' 224 ) 225 self.augment_train = 1 226 227 try: 228 if device == "auto": 229 device = "cuda" if torch.cuda.is_available() else "cpu" 230 self.device = torch.device(device) 231 except: 232 raise ("The format of the device is incorrect") 233 234 if ssl_losses is None: 235 self.ssl_losses = [lambda x, y: 0] 236 else: 237 self.ssl_losses = ssl_losses 238 239 if primary_predict_function is None: 240 if exclusive: 241 primary_predict_function = lambda x: nn.Softmax(x, dim=1) 242 else: 243 primary_predict_function = lambda x: torch.sigmoid(x) 244 self.primary_predict_function = primary_predict_function 245 246 if predict_function is None: 247 if exclusive: 248 self.predict_function = lambda x: torch.max(x.data, 1)[1] 249 else: 250 self.predict_function = lambda x: (x > threshold).int() 251 else: 252 self.predict_function = predict_function 253 254 self.pseudolabel = pseudolabel 255 self.alpha_f = pseudolabel_alpha_f 256 self.T2 = alpha_growth_stop 257 self.T1 = pseudolabel_start 258 self.t = correction_interval 259 if self.T2 <= self.T1: 260 raise ValueError( 261 f"The pseudolabel_start parameter has to be smaller than alpha_growth_stop; got " 262 f"{pseudolabel_start=} and {alpha_growth_stop=}" 263 ) 264 self.decision_thresholds = [0.5 for x in self.behaviors_dict()] 265 266 def save_checkpoint(self, checkpoint_path: str) -> None: 267 """ 268 Save a general checkpoint 269 270 Parameters 271 ---------- 272 checkpoint_path : str 273 the path where the checkpoint will be saved 274 """ 275 276 torch.save( 277 { 278 "epoch": self.epoch, 279 "model_state_dict": self.model.state_dict(), 280 "optimizer_state_dict": self.optimizer.state_dict(), 281 }, 282 checkpoint_path, 283 ) 284 285 def load_from_checkpoint( 286 self, checkpoint_path, only_model: bool = False, load_strict: bool = True 287 ) -> None: 288 """ 289 Load from a checkpoint 290 291 Parameters 292 ---------- 293 checkpoint_path : str 294 the path to the checkpoint 295 only_model : bool, default False 296 if `True`, only the model state dictionary will be loaded (and not the epoch and the optimizer state 297 dictionary) 298 load_strict : bool, default True 299 if `True`, any inconsistencies in state dictionaries are regarded as errors 300 """ 301 302 if checkpoint_path is None: 303 return 304 checkpoint = torch.load(checkpoint_path, map_location=self.device) 305 self.model.to(self.device) 306 self.model.load_state_dict(checkpoint["model_state_dict"], strict=load_strict) 307 if not only_model: 308 self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) 309 self.epoch = checkpoint["epoch"] 310 311 def save_model(self, save_path: str) -> None: 312 """ 313 Save the model state dictionary 314 315 Parameters 316 ---------- 317 save_path : str 318 the path where the state will be saved 319 """ 320 321 torch.save(self.model.state_dict(), save_path) 322 print("saved the model successfully") 323 324 def _apply_predict_functions(self, predicted): 325 """ 326 Map from model output to prediction 327 """ 328 329 predicted = self.primary_predict_function(predicted) 330 predicted = self.predict_function(predicted) 331 return predicted 332 333 def _get_prediction( 334 self, 335 main_input: Dict, 336 tags: torch.Tensor, 337 ssl_inputs: List = None, 338 ssl_targets: List = None, 339 augment_n: int = 0, 340 embedding: bool = False, 341 subsample: List = None, 342 ) -> Tuple: 343 """ 344 Get the prediction of `self.model` for input averaged over `augment_n` augmentations 345 """ 346 347 if augment_n == 0: 348 augment_n = 1 349 augment = False 350 else: 351 augment = True 352 model_input, ssl_inputs, ssl_targets = self.transformer.transform( 353 deepcopy(main_input), 354 ssl_inputs=ssl_inputs, 355 ssl_targets=ssl_targets, 356 augment=augment, 357 subsample=subsample, 358 ) 359 if not embedding: 360 predicted, ssl_predicted = self.model(model_input, ssl_inputs, tag=tags) 361 else: 362 predicted = self.model.extract_features(model_input) 363 ssl_predicted = None 364 if self.parallel and predicted is not None and len(predicted.shape) == 4: 365 predicted = predicted.reshape( 366 (-1, model_input.shape[0], *predicted.shape[2:]) 367 ) 368 if predicted is not None and augment_n > 1: 369 self.model.ssl_off() 370 for i in range(augment_n - 1): 371 model_input, *_ = self.transformer.transform( 372 deepcopy(main_input), augment=augment 373 ) 374 if not embedding: 375 pred, _ = self.model(model_input, None, tag=tags) 376 else: 377 pred = self.model.extract_features(model_input) 378 predicted += pred.detach() 379 self.model.ssl_on() 380 predicted /= augment_n 381 if self.model.process_labels: 382 class_embedding = self.model.transform_labels(predicted.device) 383 beh_dict = self.behaviors_dict() 384 class_embedding = torch.cat( 385 [class_embedding[beh_dict[k]] for k in sorted(beh_dict.keys())], 0 386 ) 387 predicted = { 388 "video": predicted, 389 "text": class_embedding, 390 "logit_scale": self.model.logit_scale(), 391 "device": predicted.device, 392 } 393 return predicted, ssl_predicted, ssl_targets 394 395 def _ssl_loss(self, ssl_predicted, ssl_targets): 396 """ 397 Apply SSL losses 398 """ 399 400 ssl_loss = [] 401 for loss, predicted, target in zip(self.ssl_losses, ssl_predicted, ssl_targets): 402 ssl_loss.append(loss(predicted, target)) 403 return ssl_loss 404 405 def _loss_function( 406 self, 407 batch: Dict, 408 augment_n: int, 409 temporal_subsampling_size: int = None, 410 skip_metrics: List = None, 411 ) -> Tuple[float, float]: 412 """ 413 Calculate the loss function and the metric for a dataloader batch 414 415 Averaging the predictions over augment_n augmentations. 416 """ 417 418 if "target" not in batch or torch.isnan(batch["target"]).all(): 419 raise ValueError("Cannot compute loss function with nan targets!") 420 main_input = {k: v.to(self.device) for k, v in batch["input"].items()} 421 main_target, ssl_targets, ssl_inputs = None, None, None 422 main_target = batch["target"].to(self.device) 423 if "ssl_targets" in batch: 424 ssl_targets = [ 425 {k: v.to(self.device) for k, v in x.items()} 426 if isinstance(x, dict) 427 else None 428 for x in batch["ssl_targets"] 429 ] 430 if "ssl_inputs" in batch: 431 ssl_inputs = [ 432 {k: v.to(self.device) for k, v in x.items()} 433 if isinstance(x, dict) 434 else None 435 for x in batch["ssl_inputs"] 436 ] 437 if temporal_subsampling_size is not None: 438 subsample = sorted( 439 random.sample( 440 range(main_target.shape[-1]), 441 int(temporal_subsampling_size * main_target.shape[-1]), 442 ) 443 ) 444 main_target = main_target[..., subsample] 445 else: 446 subsample = None 447 448 if self.ignore_tags: 449 tag = None 450 else: 451 tag = batch.get("tag") 452 predicted, ssl_predicted, ssl_targets = self._get_prediction( 453 main_input, tag, ssl_inputs, ssl_targets, augment_n, subsample=subsample 454 ) 455 del main_input, ssl_inputs 456 return self._compute( 457 ssl_predicted, 458 ssl_targets, 459 predicted, 460 main_target, 461 tag=batch.get("tag"), 462 skip_metrics=skip_metrics, 463 ) 464 465 def _compute( 466 self, 467 ssl_predicted: List, 468 ssl_targets: List, 469 predicted: torch.Tensor, 470 main_target: torch.Tensor, 471 tag: Any = None, 472 skip_loss: bool = False, 473 apply_primary_function: bool = True, 474 skip_metrics: List = None, 475 ) -> Tuple[float, float]: 476 """ 477 Compute the losses and metrics from predictions 478 """ 479 480 if skip_metrics is None: 481 skip_metrics = [] 482 if not skip_loss: 483 ssl_losses = self._ssl_loss(ssl_predicted, ssl_targets) 484 if predicted is not None: 485 loss = self.loss(predicted, main_target) 486 else: 487 loss = 0 488 else: 489 ssl_losses, loss = [], 0 490 491 if predicted is not None: 492 if isinstance(predicted, dict): 493 predicted = { 494 k: v.detach() 495 for k, v in predicted.items() 496 if isinstance(v, torch.Tensor) 497 } 498 else: 499 predicted = predicted.detach() 500 if apply_primary_function: 501 predicted = self.primary_predict_function(predicted) 502 predicted_transformed = self.predict_function(predicted) 503 for name, metric_function in self.metrics.items(): 504 if name not in skip_metrics: 505 if metric_function.needs_raw_data: 506 metric_function.update(predicted, main_target, tag) 507 else: 508 metric_function.update( 509 predicted_transformed, 510 main_target, 511 tag, 512 ) 513 return loss, ssl_losses 514 515 def _calculate_metrics(self) -> Dict: 516 """ 517 Calculate the final values of epoch metrics 518 """ 519 520 epoch_metrics = {} 521 for metric_name, metric in self.metrics.items(): 522 m = metric.calculate() 523 if type(m) is dict: 524 for k, v in m.items(): 525 if type(v) is torch.Tensor: 526 v = v.item() 527 epoch_metrics[f"{metric_name}_{k}"] = v 528 else: 529 if type(m) is torch.Tensor: 530 m = m.item() 531 epoch_metrics[metric_name] = m 532 metric.reset() 533 return epoch_metrics 534 535 def _run_epoch( 536 self, 537 dataloader: DataLoader, 538 mode: str, 539 augment_n: int, 540 verbose: bool = False, 541 unlabeled: bool = None, 542 alpha: float = 1, 543 temporal_subsampling_size: int = None, 544 ) -> Tuple: 545 """ 546 Run one epoch on dataloader 547 548 Averaging the predictions over augment_n augmentations. 549 Use "train" mode for training and "val" mode for evaluation. 550 """ 551 552 if mode == "train": 553 self.model.train() 554 elif mode == "val": 555 self.model.eval() 556 pass 557 else: 558 raise ValueError( 559 f'Mode {mode} is not recognized, please choose either "train" for training or "val" for validation' 560 ) 561 if self.ignore_tags: 562 tags = [None] 563 else: 564 tags = dataloader.dataset.get_tags() 565 epoch_loss = 0 566 epoch_ssl_loss = defaultdict(lambda: 0) 567 data_len = 0 568 set_pars = dataloader.dataset.set_indexing_parameters 569 skip_metrics = self.skip_metrics if mode == "train" else None 570 for tag in tags: 571 set_pars(unlabeled=unlabeled, tag=tag) 572 data_len += len(dataloader) 573 if verbose: 574 dataloader = tqdm(dataloader) 575 for batch in dataloader: 576 loss, ssl_losses = self._loss_function( 577 batch, 578 augment_n, 579 temporal_subsampling_size=temporal_subsampling_size, 580 skip_metrics=skip_metrics, 581 ) 582 if loss != 0: 583 loss = loss * alpha 584 epoch_loss += loss.item() 585 for i, (ssl_loss, weight) in enumerate( 586 zip(ssl_losses, self.ssl_weights) 587 ): 588 if ssl_loss != 0: 589 epoch_ssl_loss[i] += ssl_loss.item() 590 loss = loss + weight * ssl_loss 591 if mode == "train": 592 self.optimizer.zero_grad() 593 if loss.requires_grad: 594 loss.backward() 595 self.optimizer.step() 596 597 epoch_loss = epoch_loss / data_len 598 epoch_ssl_loss = {k: v / data_len for k, v in epoch_ssl_loss.items()} 599 epoch_metrics = self._calculate_metrics() 600 601 return epoch_loss, epoch_ssl_loss, epoch_metrics 602 603 def train( 604 self, 605 trial: Trial = None, 606 optimized_metric: str = None, 607 to_ram: bool = False, 608 autostop_interval: int = 30, 609 autostop_threshold: float = 0.001, 610 autostop_metric: str = None, 611 main_task_on: bool = True, 612 ssl_on: bool = True, 613 temporal_subsampling_size: int = None, 614 loading_bar: bool = False, 615 ) -> Tuple: 616 """ 617 Train the task and return a log of epoch-average loss and metric 618 619 You can use the autostop parameters to finish training when the parameters are not improving. It will be 620 stopped if the average value of `autostop_metric` over the last `autostop_interval` epochs is smaller than 621 the average over the previous `autostop_interval` epochs + `autostop_threshold`. For example, if the 622 current epoch is 120 and `autostop_interval` is 50, the averages over epochs 70-120 and 20-70 will be compared. 623 624 Parameters 625 ---------- 626 trial : Trial 627 an `optuna` trial (for hyperparameter searches) 628 optimized_metric : str 629 the name of the metric being optimized (for hyperparameter searches) 630 to_ram : bool, default False 631 if `True`, the dataset will be loaded in RAM (this speeds up the calculations but can lead to crashes 632 if the dataset is too large) 633 autostop_interval : int, default 50 634 the number of epochs to average the autostop metric over 635 autostop_threshold : float, default 0.001 636 the autostop difference threshold 637 autostop_metric : str, optional 638 the autostop metric (can be any one of the tracked metrics of `'loss'`) 639 main_task_on : bool, default True 640 if `False`, the main task (action segmentation) will not be used in training 641 ssl_on : bool, default True 642 if `False`, the SSL task will not be used in training 643 644 Returns 645 ------- 646 loss_log: list 647 a list of float loss function values for each epoch 648 metrics_log: dict 649 a dictionary of metric value logs (first-level keys are 'train' and 'val', second-level keys are metric 650 names, values are lists of function values) 651 """ 652 653 if self.parallel and not isinstance(self.model, nn.DataParallel): 654 self.model = MyDataParallel(self.model) 655 self.model.to(self.device) 656 assert autostop_metric in [None, "loss"] + list(self.metrics) 657 autostop_interval //= self.validation_interval 658 if trial is not None and optimized_metric is None: 659 raise ValueError( 660 "You need to provide the optimized metric name (optimized_metric parameter) " 661 "for optuna pruning to work!" 662 ) 663 if to_ram: 664 print("transferring datasets to RAM...") 665 self.train_dataloader.dataset.to_ram() 666 if self.val_dataloader is not None and len(self.val_dataloader) > 0: 667 self.val_dataloader.dataset.to_ram() 668 loss_log = {"train": [], "val": []} 669 metrics_log = {"train": defaultdict(lambda: []), "val": defaultdict(lambda: [])} 670 if not main_task_on: 671 self.model.main_task_off() 672 if not ssl_on: 673 self.model.ssl_off() 674 while self.epoch < self.num_epochs: 675 self.epoch += 1 676 unlabeled = None 677 alpha = 1 678 if self.pseudolabel: 679 if self.epoch >= self.T1: 680 unlabeled = (self.epoch - self.T1) % self.t != 0 681 if unlabeled: 682 alpha = self._alpha(self.epoch) 683 else: 684 unlabeled = False 685 epoch_loss, epoch_ssl_loss, epoch_metrics = self._run_epoch( 686 dataloader=self.train_dataloader, 687 mode="train", 688 augment_n=self.augment_train, 689 unlabeled=unlabeled, 690 alpha=alpha, 691 temporal_subsampling_size=temporal_subsampling_size, 692 verbose=loading_bar, 693 ) 694 loss_log["train"].append(epoch_loss) 695 epoch_string = f"[epoch {self.epoch}]" 696 if self.pseudolabel: 697 if unlabeled: 698 epoch_string += " (unlabeled)" 699 else: 700 epoch_string += " (labeled)" 701 epoch_string += f": loss {epoch_loss:.4f}" 702 703 if len(epoch_ssl_loss) != 0: 704 for key, value in sorted(epoch_ssl_loss.items()): 705 metrics_log["train"][f"ssl_loss_{key}"].append(value) 706 epoch_string += f", ssl_loss_{key} {value:.4f}" 707 708 for metric_name, metric_value in sorted(epoch_metrics.items()): 709 if metric_name not in self.skip_metrics: 710 metrics_log["train"][metric_name].append(metric_value) 711 epoch_string += f", {metric_name} {metric_value:.3f}" 712 713 if ( 714 self.val_dataloader is not None 715 and self.epoch % self.validation_interval == 0 716 ): 717 with torch.no_grad(): 718 epoch_string += "\n" 719 ( 720 val_epoch_loss, 721 val_epoch_ssl_loss, 722 val_epoch_metrics, 723 ) = self._run_epoch( 724 dataloader=self.val_dataloader, 725 mode="val", 726 augment_n=self.augment_val, 727 ) 728 loss_log["val"].append(val_epoch_loss) 729 epoch_string += f"validation: loss {val_epoch_loss:.4f}" 730 731 if len(val_epoch_ssl_loss) != 0: 732 for key, value in sorted(val_epoch_ssl_loss.items()): 733 metrics_log["val"][f"ssl_loss_{key}"].append(value) 734 epoch_string += f", ssl_loss_{key} {value:.4f}" 735 736 for metric_name, metric_value in sorted(val_epoch_metrics.items()): 737 metrics_log["val"][metric_name].append(metric_value) 738 epoch_string += f", {metric_name} {metric_value:.3f}" 739 740 if trial is not None: 741 if optimized_metric not in metrics_log["val"]: 742 raise ValueError( 743 f"The {optimized_metric} metric set for optimization is not being logged!" 744 ) 745 trial.report(metrics_log["val"][optimized_metric][-1], self.epoch) 746 if trial.should_prune(): 747 raise TrialPruned() 748 749 if self.verbose: 750 print(epoch_string) 751 752 if self.log_file is not None: 753 with open(self.log_file, "a") as f: 754 f.write(epoch_string + "\n") 755 756 save_condition = ( 757 (self.model_save_epochs != 0) 758 and (self.epoch % self.model_save_epochs == 0) 759 ) or (self.epoch == self.num_epochs) 760 761 if self.epoch > 0 and save_condition and self.model_save_path is not None: 762 epoch_s = str(self.epoch).zfill(len(str(self.num_epochs))) 763 self.save_checkpoint( 764 os.path.join(self.model_save_path, f"epoch{epoch_s}.pt") 765 ) 766 767 if self.pseudolabel and self.epoch >= self.T1 and not unlabeled: 768 self._set_pseudolabels() 769 770 if autostop_metric == "loss": 771 if len(loss_log["val"]) > autostop_interval * 2: 772 if ( 773 np.mean(loss_log["val"][-autostop_interval:]) 774 < np.mean( 775 loss_log["val"][-2 * autostop_interval : -autostop_interval] 776 ) 777 + autostop_threshold 778 ): 779 break 780 elif autostop_metric in metrics_log["val"]: 781 if len(metrics_log["val"][autostop_metric]) > autostop_interval * 2: 782 if ( 783 np.mean( 784 metrics_log["val"][autostop_metric][-autostop_interval:] 785 ) 786 < np.mean( 787 metrics_log["val"][autostop_metric][ 788 -2 * autostop_interval : -autostop_interval 789 ] 790 ) 791 + autostop_threshold 792 ): 793 break 794 795 metrics_log = {k: dict(v) for k, v in metrics_log.items()} 796 797 return loss_log, metrics_log 798 799 def evaluate_prediction( 800 self, 801 prediction: Union[torch.Tensor, Dict], 802 data: Union[DataLoader, BehaviorDataset, str] = None, 803 batch_size: int = 32, 804 ) -> Tuple: 805 """ 806 Compute metrics for a prediction 807 808 Parameters 809 ---------- 810 prediction : torch.Tensor 811 the prediction 812 data : torch.utils.data.DataLoader | dlc2action.data.dataset.BehaviorDataset, optional 813 the data the prediction was made for (if not provided, take the validation dataset) 814 batch_size : int, default 32 815 the batch size 816 817 Returns 818 ------- 819 loss : float 820 the average value of the loss function 821 metric : dict 822 a dictionary of average values of metric functions 823 """ 824 825 if type(data) is not DataLoader: 826 dataset = self._get_dataset(data) 827 data = DataLoader(dataset, shuffle=False, batch_size=batch_size) 828 epoch_loss = 0 829 if isinstance(prediction, dict): 830 num_classes = len(self.behaviors_dict()) 831 length = dataset.len_segment() 832 coords = dataset.annotation_store.get_original_coordinates() 833 for batch in data: 834 main_target = batch["target"] 835 pr_coords = coords[batch["index"]] 836 predicted = torch.zeros((len(pr_coords), num_classes, length)) 837 for i, c in enumerate(pr_coords): 838 video_id = dataset.input_store.get_video_id(c) 839 clip_id = dataset.input_store.get_clip_id(c) 840 start, end = dataset.input_store.get_clip_start_end(c) 841 predicted[i, :, : end - start] = prediction[video_id][clip_id][ 842 :, start:end 843 ] 844 self._compute( 845 [], 846 [], 847 predicted, 848 main_target, 849 skip_loss=True, 850 tag=batch.get("tag"), 851 apply_primary_function=False, 852 ) 853 else: 854 for batch in data: 855 main_target = batch["target"] 856 predicted = prediction[batch["index"]] 857 self._compute( 858 [], 859 [], 860 predicted, 861 main_target, 862 skip_loss=True, 863 tag=batch.get("tag"), 864 apply_primary_function=False, 865 ) 866 epoch_metrics = self._calculate_metrics() 867 strings = [ 868 f"{metric_name} {metric_value:.3f}" 869 for metric_name, metric_value in epoch_metrics.items() 870 ] 871 val_string = ", ".join(sorted(strings)) 872 print(val_string) 873 return epoch_loss, epoch_metrics 874 875 def evaluate( 876 self, 877 data: Union[DataLoader, BehaviorDataset, str] = None, 878 augment_n: int = 0, 879 batch_size: int = 32, 880 verbose: bool = True, 881 ) -> Tuple: 882 """ 883 Evaluate the Task model 884 885 Parameters 886 ---------- 887 data : torch.utils.data.DataLoader | dlc2action.data.dataset.BehaviorDataset, optional 888 the data to evaluate on (if not provided, evaluate on the Task validation dataset) 889 augment_n : int, default 0 890 the number of augmentations to average results over 891 batch_size : int, default 32 892 the batch size 893 verbose : bool, default True 894 if True, the process is reported to standard output 895 896 Returns 897 ------- 898 loss : float 899 the average value of the loss function 900 ssl_loss : float 901 the average value of the SSL loss function 902 metric : dict 903 a dictionary of average values of metric functions 904 """ 905 906 if self.parallel and not isinstance(self.model, nn.DataParallel): 907 self.model = MyDataParallel(self.model) 908 self.model.to(self.device) 909 if type(data) is not DataLoader: 910 data = self._get_dataset(data) 911 data = DataLoader(data, shuffle=False, batch_size=batch_size) 912 with torch.no_grad(): 913 epoch_loss, epoch_ssl_loss, epoch_metrics = self._run_epoch( 914 dataloader=data, mode="val", augment_n=augment_n, verbose=verbose 915 ) 916 val_string = f"loss {epoch_loss:.4f}" 917 for metric_name, metric_value in sorted(epoch_metrics.items()): 918 val_string += f", {metric_name} {metric_value:.3f}" 919 print(val_string) 920 return epoch_loss, epoch_ssl_loss, epoch_metrics 921 922 def predict( 923 self, 924 data: Union[DataLoader, BehaviorDataset, str] = None, 925 raw_output: bool = False, 926 apply_primary_function: bool = True, 927 augment_n: int = 0, 928 batch_size: int = 32, 929 train_mode: bool = False, 930 to_ram: bool = False, 931 embedding: bool = False, 932 ) -> torch.Tensor: 933 """ 934 Make a prediction with the Task model 935 936 Parameters 937 ---------- 938 data : torch.utils.data.DataLoader | dlc2action.data.dataset.BehaviorDataset | str, optional 939 the data to evaluate on (if not provided, evaluate on the Task validation dataset) 940 raw_output : bool, default False 941 if `True`, the raw predicted probabilities are returned 942 apply_primary_function : bool, default True 943 if `True`, the primary predict function is applied (to map the model output into a shape corresponding to 944 the input) 945 augment_n : int, default 0 946 the number of augmentations to average results over 947 batch_size : int, default 32 948 the batch size 949 train_mode : bool, default False 950 if `True`, the model is used in training mode (affects dropout and batch normalization layers) 951 to_ram : bool, default False 952 if `True`, the dataset will be loaded in RAM (this speeds up the calculations but can lead to crashes 953 if the dataset is too large) 954 embedding : bool, default False 955 if `True`, the output of feature extractor is returned, ignoring the prediction module of the model 956 957 Returns 958 ------- 959 prediction : torch.Tensor 960 a prediction for the input data 961 """ 962 963 if self.parallel and not isinstance(self.model, nn.DataParallel): 964 self.model = MyDataParallel(self.model) 965 self.model.to(self.device) 966 if train_mode: 967 self.model.train() 968 else: 969 self.model.eval() 970 output = [] 971 if embedding: 972 raw_output = True 973 apply_primary_function = True 974 if type(data) is not DataLoader: 975 data = self._get_dataset(data) 976 if to_ram: 977 print("transferring dataset to RAM...") 978 data.to_ram() 979 data = DataLoader(data, shuffle=False, batch_size=batch_size) 980 self.model.ssl_off() 981 with torch.no_grad(): 982 for batch in tqdm(data): 983 input = {k: v.to(self.device) for k, v in batch["input"].items()} 984 predicted, _, _ = self._get_prediction( 985 input, 986 batch.get("tag"), 987 augment_n=augment_n, 988 embedding=embedding, 989 ) 990 if apply_primary_function: 991 predicted = self.primary_predict_function(predicted) 992 if not raw_output: 993 predicted = self.predict_function(predicted) 994 output.append(predicted.detach().cpu()) 995 self.model.ssl_on() 996 output = torch.cat(output).detach() 997 return output 998 999 def dataset(self, mode="train") -> BehaviorDataset: 1000 """ 1001 Get a dataset 1002 1003 Parameters 1004 ---------- 1005 mode : {'train', 'val', 'test} 1006 the dataset to get 1007 1008 Returns 1009 ------- 1010 dataset : dlc2action.data.dataset.BehaviorDataset 1011 the dataset 1012 """ 1013 1014 dataloader = self.dataloader(mode) 1015 if dataloader is None: 1016 raise ValueError("The length of the dataloader is 0!") 1017 return dataloader.dataset 1018 1019 def dataloader(self, mode: str = "train") -> DataLoader: 1020 """ 1021 Get a dataloader 1022 1023 Parameters 1024 ---------- 1025 mode : {'train', 'val', 'test} 1026 the dataset to get 1027 1028 Returns 1029 ------- 1030 dataloader : torch.utils.data.DataLoader 1031 the dataloader 1032 """ 1033 1034 if mode == "train": 1035 return self.train_dataloader 1036 elif mode == "val": 1037 return self.val_dataloader 1038 elif mode == "test": 1039 return self.test_dataloader 1040 else: 1041 raise ValueError( 1042 f'The {mode} mode is not recognized, please choose from "train", "val" or "test"' 1043 ) 1044 1045 def _get_dataset(self, dataset): 1046 """ 1047 Get a dataset from a dataloader, a string ('train', 'test' or 'val') or `None` (default) 1048 """ 1049 1050 if dataset is None: 1051 dataset = self.dataset("val") 1052 elif dataset in ["train", "val", "test"]: 1053 dataset = self.dataset(dataset) 1054 elif type(dataset) is DataLoader: 1055 dataset = dataset.dataset 1056 if type(dataset) is BehaviorDataset: 1057 return dataset 1058 else: 1059 raise TypeError(f"The {type(dataset)} type of dataset is not recognized!") 1060 1061 def _get_dataloader(self, dataset): 1062 """ 1063 Get a dataloader from a dataset, a string ('train', 'test' or 'val') or `None` (default) 1064 """ 1065 1066 if dataset is None: 1067 dataset = self.dataloader("val") 1068 elif dataset in ["train", "val", "test"]: 1069 dataset = self.dataloader(dataset) 1070 if dataset is None: 1071 raise ValueError(f"The length of the dataloader is 0!") 1072 elif type(dataset) is BehaviorDataset: 1073 dataset = DataLoader(dataset) 1074 if type(dataset) is DataLoader: 1075 return dataset 1076 else: 1077 raise TypeError(f"The {type(dataset)} type of dataset is not recognized!") 1078 1079 def generate_full_length_prediction( 1080 self, dataset=None, batch_size=32, augment_n=10 1081 ): 1082 """ 1083 Compile a prediction for the original input sequences 1084 1085 Parameters 1086 ---------- 1087 dataset : BehaviorDataset, optional 1088 the dataset to generate a prediction for (if `None`, generate for the validation dataset) 1089 batch_size : int, default 32 1090 the batch size 1091 augment_n : int, default 10 1092 the number of augmentations to average results over 1093 1094 Returns 1095 ------- 1096 prediction : dict 1097 a nested dictionary where first level keys are video ids, second level keys are clip ids and values 1098 are prediction tensors 1099 """ 1100 1101 dataset = self._get_dataset(dataset) 1102 if not isinstance(dataset, BehaviorDataset): 1103 raise TypeError( 1104 f"The dataset parameter has to be either None, string, " 1105 f"BehaviorDataset or Dataloader, got {type(dataset)}" 1106 ) 1107 predicted = self.predict( 1108 dataset, 1109 raw_output=True, 1110 apply_primary_function=True, 1111 augment_n=augment_n, 1112 batch_size=batch_size, 1113 ) 1114 predicted = dataset.generate_full_length_prediction(predicted) 1115 predicted = { 1116 v_id: { 1117 clip_id: self._apply_predict_functions(v.unsqueeze(0)).squeeze() 1118 for clip_id, v in video_dict.items() 1119 } 1120 for v_id, video_dict in predicted.items() 1121 } 1122 return predicted 1123 1124 def generate_submission( 1125 self, frame_number_map_file, dataset=None, batch_size=32, augment_n=10 1126 ): 1127 """ 1128 Generate a MABe-22 style submission dictionary 1129 1130 Parameters 1131 ---------- 1132 dataset : BehaviorDataset, optional 1133 the dataset to generate a prediction for (if `None`, generate for the validation dataset) 1134 batch_size : int, default 32 1135 the batch size 1136 augment_n : int, default 10 1137 the number of augmentations to average results over 1138 1139 Returns 1140 ------- 1141 submission : dict 1142 a dictionary with frame number mapping and embeddings 1143 """ 1144 1145 dataset = self._get_dataset(dataset) 1146 if not isinstance(dataset, BehaviorDataset): 1147 raise TypeError( 1148 f"The dataset parameter has to be either None, string, " 1149 f"BehaviorDataset or Dataloader, got {type(dataset)}" 1150 ) 1151 predicted = self.predict( 1152 dataset, 1153 raw_output=True, 1154 apply_primary_function=True, 1155 augment_n=augment_n, 1156 batch_size=batch_size, 1157 embedding=True, 1158 ) 1159 predicted = dataset.generate_full_length_prediction(predicted) 1160 frame_map = np.load(frame_number_map_file, allow_pickle=True).item() 1161 length = frame_map[list(frame_map.keys())[-1]][1] 1162 embeddings = None 1163 for video_id in list(predicted.keys()): 1164 split = video_id.split("--") 1165 if len(split) != 2 or len(predicted[video_id]) > 1: 1166 raise RuntimeError( 1167 "Generating submissions is only implemented for the mabe22 dataset!" 1168 ) 1169 if split[1] not in frame_map: 1170 raise RuntimeError(f"The {split[1]} video is not in the frame map file") 1171 v_id = split[1] 1172 clip_id = list(predicted[video_id].keys())[0] 1173 if embeddings is None: 1174 embeddings = np.zeros((length, predicted[video_id][clip_id].shape[0])) 1175 start, end = frame_map[v_id] 1176 embeddings[start:end, :] = predicted[video_id][clip_id].T 1177 predicted.pop(video_id) 1178 predicted = { 1179 "frame_number_map": frame_map, 1180 "embeddings": embeddings.astype(np.float32), 1181 } 1182 return predicted 1183 1184 def _get_intervals(self, tensor: torch.Tensor) -> torch.Tensor: 1185 """ 1186 Get a list of True group beginning and end indices from a boolean tensor 1187 """ 1188 1189 output, indices = torch.unique_consecutive(tensor, return_inverse=True) 1190 true_indices = torch.where(output)[0] 1191 starts = torch.tensor( 1192 [(indices == i).nonzero(as_tuple=True)[0][0] for i in true_indices] 1193 ) 1194 ends = torch.tensor( 1195 [(indices == i).nonzero(as_tuple=True)[0][-1] + 1 for i in true_indices] 1196 ) 1197 return torch.stack([starts, ends]).T 1198 1199 def _smooth(self, tensor: torch.Tensor, smooth_interval: int = 1) -> torch.Tensor: 1200 """ 1201 Get rid of jittering in a non-exclusive classification tensor 1202 1203 First, remove intervals of 0 shorter than `smooth_interval`. Then, remove intervals of 1 shorter than 1204 `smooth_interval`. 1205 """ 1206 1207 if len(tensor.shape) > 1: 1208 for c in tensor.shape[1]: 1209 intervals = self._get_intervals(tensor[:, c] == 0) 1210 interval_lengths = torch.tensor( 1211 [interval[1] - interval[0] for interval in intervals] 1212 ) 1213 short_intervals = intervals[interval_lengths <= smooth_interval] 1214 for start, end in short_intervals: 1215 tensor[start:end, c] = 1 1216 intervals = self._get_intervals(tensor[:, c] == 1) 1217 interval_lengths = torch.tensor( 1218 [interval[1] - interval[0] for interval in intervals] 1219 ) 1220 short_intervals = intervals[interval_lengths <= smooth_interval] 1221 for start, end in short_intervals: 1222 tensor[start:end, c] = 0 1223 else: 1224 for c in tensor.unique(): 1225 intervals = self._get_intervals(tensor == c) 1226 interval_lengths = torch.tensor( 1227 [interval[1] - interval[0] for interval in intervals] 1228 ) 1229 short_intervals = intervals[interval_lengths <= smooth_interval] 1230 for start, end in short_intervals: 1231 if start == 0: 1232 tensor[start:end] = tensor[end + 1] 1233 else: 1234 tensor[start:end] = tensor[start - 1] 1235 return tensor 1236 1237 def _visualize_results_label( 1238 self, 1239 label: str, 1240 save_path: str = None, 1241 add_legend: bool = True, 1242 ground_truth: bool = True, 1243 hide_axes: bool = False, 1244 width: int = 10, 1245 whole_video: bool = False, 1246 transparent: bool = False, 1247 dataset: BehaviorDataset = None, 1248 smooth_interval: int = 0, 1249 title: str = None, 1250 ): 1251 """ 1252 Visualize random predictions 1253 1254 Parameters 1255 ---------- 1256 save_path : str, optional 1257 the path where the plot will be saved 1258 add_legend : bool, default True 1259 if `True`, legend will be added to the plot 1260 ground_truth : bool, default True 1261 if `True`, ground truth will be added to the plot 1262 colormap : str, default 'viridis' 1263 the `matplotlib` colormap to use 1264 hide_axes : bool, default True 1265 if `True`, the axes will be hidden on the plot 1266 min_classes : int, default 1 1267 the minimum number of classes in a displayed interval 1268 width : float, default 10 1269 the width of the plot 1270 whole_video : bool, default False 1271 if `True`, whole videos are plotted instead of segments 1272 transparent : bool, default False 1273 if `True`, the background on the plot is transparent 1274 dataset : BehaviorDataset, optional 1275 the dataset to make the prediction for (if not provided, the validation dataset is used) 1276 drop_classes : set, optional 1277 a set of class names to not be displayed 1278 """ 1279 1280 if title is None: 1281 title = "" 1282 dataset = self._get_dataset(dataset) 1283 inverse_dict = {v: k for k, v in dataset.behaviors_dict().items()} 1284 label_ind = inverse_dict[label] 1285 labels = {1: label, -100: "unknown"} 1286 label_keys = [1, -100] 1287 color_list = ["blue", "gray"] 1288 if whole_video: 1289 predicted = self.generate_full_length_prediction(dataset) 1290 keys = list(predicted.keys()) 1291 counter = 0 1292 if whole_video: 1293 max_iter = len(keys) * 5 1294 else: 1295 max_iter = len(dataset) * 5 1296 ok = False 1297 while not ok: 1298 counter += 1 1299 if counter > max_iter: 1300 raise RuntimeError( 1301 "Plotting is taking too many iterations; you should probably make some of the parameters less restrictive" 1302 ) 1303 if whole_video: 1304 i = randint(0, len(keys) - 1) 1305 prediction = predicted[keys[i]] 1306 keys_i = list(prediction.keys()) 1307 j = randint(0, len(keys_i) - 1) 1308 full_p = prediction[keys_i[j]] 1309 prediction = prediction[keys_i[j]][label_ind] 1310 else: 1311 dataloader = DataLoader(dataset) 1312 i = randint(0, len(dataloader) - 1) 1313 for num, batch in enumerate(dataloader): 1314 if num == i: 1315 break 1316 input_data = {k: v.to(self.device) for k, v in batch["input"].items()} 1317 prediction, *_ = self._get_prediction( 1318 input_data, batch.get("tag"), augment_n=5 1319 ) 1320 prediction = self._apply_predict_functions(prediction) 1321 j = randint(0, len(prediction) - 1) 1322 full_p = prediction[j] 1323 prediction = prediction[j][label_ind] 1324 classes = [x for x in torch.unique(prediction) if int(x) in label_keys] 1325 ok = 1 in classes 1326 fig, ax = plt.subplots(figsize=(width, 2)) 1327 for c in classes: 1328 c_i = label_keys.index(int(c)) 1329 output, indices, counts = torch.unique_consecutive( 1330 prediction == c, return_inverse=True, return_counts=True 1331 ) 1332 long_indices = torch.where(output)[0] 1333 res_indices_start = [ 1334 (indices == i).nonzero(as_tuple=True)[0][0].item() for i in long_indices 1335 ] 1336 res_indices_end = [ 1337 (indices == i).nonzero(as_tuple=True)[0][-1].item() + 1 1338 for i in long_indices 1339 ] 1340 res_indices_len = [ 1341 end - start for start, end in zip(res_indices_start, res_indices_end) 1342 ] 1343 ax.broken_barh( 1344 list(zip(res_indices_start, res_indices_len)), 1345 (0, 1), 1346 label=labels[int(c)], 1347 facecolors=color_list[c_i], 1348 ) 1349 if ground_truth: 1350 gt = batch["target"][j][label_ind].to(self.device) 1351 classes_gt = [x for x in torch.unique(gt) if int(x) in label_keys] 1352 for c in classes_gt: 1353 c_i = label_keys.index(int(c)) 1354 if c in classes: 1355 label = None 1356 else: 1357 label = labels[int(c)] 1358 output, indices, counts = torch.unique_consecutive( 1359 gt == c, return_inverse=True, return_counts=True 1360 ) 1361 long_indices = torch.where(output * (counts > 5))[0] 1362 res_indices_start = [ 1363 (indices == i).nonzero(as_tuple=True)[0][0].item() 1364 for i in long_indices 1365 ] 1366 res_indices_end = [ 1367 (indices == i).nonzero(as_tuple=True)[0][-1].item() + 1 1368 for i in long_indices 1369 ] 1370 res_indices_len = [ 1371 end - start 1372 for start, end in zip(res_indices_start, res_indices_end) 1373 ] 1374 ax.broken_barh( 1375 list(zip(res_indices_start, res_indices_len)), 1376 (1.5, 1), 1377 facecolors=color_list[c_i], 1378 label=label, 1379 ) 1380 self._compute( 1381 main_target=batch["target"][j].unsqueeze(0).to(self.device), 1382 predicted=full_p.unsqueeze(0).to(self.device), 1383 ssl_targets=[], 1384 ssl_predicted=[], 1385 skip_loss=True, 1386 ) 1387 metrics = self._calculate_metrics() 1388 if smooth_interval > 0: 1389 smoothed = self._smooth(full_p, smooth_interval=smooth_interval)[ 1390 label_ind, : 1391 ] 1392 for c in classes: 1393 c_i = label_keys.index(int(c)) 1394 output, indices, counts = torch.unique_consecutive( 1395 smoothed == c, return_inverse=True, return_counts=True 1396 ) 1397 long_indices = torch.where(output)[0] 1398 res_indices_start = [ 1399 (indices == i).nonzero(as_tuple=True)[0][0].item() 1400 for i in long_indices 1401 ] 1402 res_indices_end = [ 1403 (indices == i).nonzero(as_tuple=True)[0][-1].item() + 1 1404 for i in long_indices 1405 ] 1406 res_indices_len = [ 1407 end - start 1408 for start, end in zip(res_indices_start, res_indices_end) 1409 ] 1410 ax.broken_barh( 1411 list(zip(res_indices_start, res_indices_len)), 1412 (3, 1), 1413 label=labels[int(c)], 1414 facecolors=color_list[c_i], 1415 ) 1416 keys = list(metrics.keys()) 1417 for key in keys: 1418 if key.split("_")[-1] != (str(label_ind)): 1419 metrics.pop(key) 1420 title = [title] 1421 for key, value in metrics.items(): 1422 title.append(f"{'_'.join(key.split('_')[: -1])}: {value:.2f}") 1423 title = ", ".join(title) 1424 if not ground_truth: 1425 ax.axes.yaxis.set_visible(False) 1426 else: 1427 ax.set_yticks([0.5, 2]) 1428 ax.set_yticklabels(["prediction", "ground truth"]) 1429 if add_legend: 1430 ax.legend() 1431 if hide_axes: 1432 plt.axis("off") 1433 plt.title(title) 1434 plt.xlim((0, len(prediction))) 1435 if save_path is not None: 1436 plt.savefig(save_path, transparent=transparent) 1437 plt.show() 1438 1439 def visualize_results( 1440 self, 1441 save_path: str = None, 1442 add_legend: bool = True, 1443 ground_truth: bool = True, 1444 colormap: str = "viridis", 1445 hide_axes: bool = False, 1446 min_classes: int = 1, 1447 width: int = 10, 1448 whole_video: bool = False, 1449 transparent: bool = False, 1450 dataset: Union[BehaviorDataset, DataLoader, str, None] = None, 1451 drop_classes: Set = None, 1452 search_classes: Set = None, 1453 num_samples: int = 1, 1454 smooth_interval_prediction: int = None, 1455 behavior_name: str = None, 1456 ): 1457 """ 1458 Visualize random predictions 1459 1460 Parameters 1461 ---------- 1462 save_path : str, optional 1463 the path where the plot will be saved 1464 add_legend : bool, default True 1465 if `True`, legend will be added to the plot 1466 ground_truth : bool, default True 1467 if `True`, ground truth will be added to the plot 1468 colormap : str, default 'Accent' 1469 the `matplotlib` colormap to use 1470 hide_axes : bool, default True 1471 if `True`, the axes will be hidden on the plot 1472 min_classes : int, default 1 1473 the minimum number of classes in a displayed interval 1474 width : float, default 10 1475 the width of the plot 1476 whole_video : bool, default False 1477 if `True`, whole videos are plotted instead of segments 1478 transparent : bool, default False 1479 if `True`, the background on the plot is transparent 1480 dataset : BehaviorDataset | DataLoader | str | None, optional 1481 the dataset to make the prediction for (if not provided, the validation dataset is used) 1482 drop_classes : set, optional 1483 a set of class names to not be displayed 1484 search_classes : set, optional 1485 if given, only intervals where at least one of the classes is in ground truth will be shown 1486 """ 1487 1488 if drop_classes is None: 1489 drop_classes = [] 1490 dataset = self._get_dataset(dataset) 1491 if dataset.annotation_class() == "exclusive_classification": 1492 exclusive = True 1493 elif dataset.annotation_class() == "nonexclusive_classification": 1494 exclusive = False 1495 else: 1496 raise NotImplementedError( 1497 f"Results visualisation is not implemented for {dataset.annotation_class() } datasets!" 1498 ) 1499 if exclusive: 1500 labels = { 1501 k: v 1502 for k, v in dataset.behaviors_dict().items() 1503 if v not in drop_classes 1504 } 1505 labels.update({-100: "unknown"}) 1506 else: 1507 inverse_dict = {v: k for k, v in dataset.behaviors_dict().items()} 1508 if behavior_name is None: 1509 behavior_name = list(inverse_dict.keys())[0] 1510 label_ind = inverse_dict[behavior_name] 1511 labels = {1: label, -100: "unknown"} 1512 label_keys = sorted([int(x) for x in labels.keys()]) 1513 if search_classes is None: 1514 ok = True 1515 else: 1516 ok = False 1517 classes = [] 1518 if whole_video: 1519 predicted = self.generate_full_length_prediction(dataset) 1520 keys = list(predicted.keys()) 1521 counter = 0 1522 if whole_video: 1523 max_iter = len(keys) * 2 1524 else: 1525 max_iter = len(dataset) * 2 1526 while len(classes) < min_classes or not ok: 1527 counter += 1 1528 if counter > max_iter: 1529 raise RuntimeError( 1530 "Plotting is taking too many iterations; you should probably make some of the parameters less restrictive" 1531 ) 1532 if whole_video: 1533 i = randint(0, len(keys) - 1) 1534 prediction = predicted[keys[i]] 1535 keys_i = list(prediction.keys()) 1536 j = randint(0, len(keys_i) - 1) 1537 prediction = prediction[keys_i[j]] 1538 key1 = keys[i] 1539 key2 = keys_i[j] 1540 else: 1541 dataloader = DataLoader(dataset) 1542 i = randint(0, len(dataloader) - 1) 1543 for num, batch in enumerate(dataloader): 1544 if num == i: 1545 break 1546 input = {k: v.to(self.device) for k, v in batch["input"].items()} 1547 prediction, *_ = self._get_prediction( 1548 input, batch.get("tag"), augment_n=5 1549 ) 1550 prediction = self._apply_predict_functions(prediction) 1551 j = randint(0, len(prediction) - 1) 1552 prediction = prediction[j] 1553 if not exclusive: 1554 prediction = prediction[label_ind] 1555 if smooth_interval_prediction > 0: 1556 unsmoothed_prediction = deepcopy(prediction) 1557 prediction = self._smooth(prediction, smooth_interval_prediction) 1558 height = 3 1559 else: 1560 height = 2 1561 classes = [ 1562 labels[int(x)] for x in torch.unique(prediction) if x in label_keys 1563 ] 1564 if search_classes is not None: 1565 ok = any([x in classes for x in search_classes]) 1566 fig, ax = plt.subplots(figsize=(width, height)) 1567 color_list = cm.get_cmap(colormap, lut=len(labels)).colors 1568 1569 def _plot_prediction(prediction, height, set_labels=True): 1570 for c in label_keys: 1571 c_i = label_keys.index(int(c)) 1572 output, indices, counts = torch.unique_consecutive( 1573 prediction == c, return_inverse=True, return_counts=True 1574 ) 1575 long_indices = torch.where(output)[0] 1576 if len(long_indices) == 0: 1577 continue 1578 res_indices_start = [ 1579 (indices == i).nonzero(as_tuple=True)[0][0].item() 1580 for i in long_indices 1581 ] 1582 res_indices_end = [ 1583 (indices == i).nonzero(as_tuple=True)[0][-1].item() + 1 1584 for i in long_indices 1585 ] 1586 res_indices_len = [ 1587 end - start 1588 for start, end in zip(res_indices_start, res_indices_end) 1589 ] 1590 if set_labels: 1591 label = labels[int(c)] 1592 else: 1593 label = None 1594 ax.broken_barh( 1595 list(zip(res_indices_start, res_indices_len)), 1596 (height, 1), 1597 label=label, 1598 facecolors=color_list[c_i], 1599 ) 1600 1601 if smooth_interval_prediction > 0: 1602 _plot_prediction(unsmoothed_prediction, 0) 1603 _plot_prediction(prediction, 1.5, set_labels=False) 1604 gt_height = 3 1605 else: 1606 _plot_prediction(prediction, 0) 1607 gt_height = 1.5 1608 if ground_truth: 1609 if not whole_video: 1610 gt = batch["target"][j].to(self.device) 1611 else: 1612 gt = dataset.generate_full_length_gt()[key1][key2] 1613 for c in label_keys: 1614 c_i = label_keys.index(int(c)) 1615 if labels[int(c)] in classes: 1616 label = None 1617 else: 1618 label = labels[int(c)] 1619 output, indices, counts = torch.unique_consecutive( 1620 gt == c, return_inverse=True, return_counts=True 1621 ) 1622 long_indices = torch.where(output)[0] 1623 if len(long_indices) == 0: 1624 continue 1625 res_indices_start = [ 1626 (indices == i).nonzero(as_tuple=True)[0][0].item() 1627 for i in long_indices 1628 ] 1629 res_indices_end = [ 1630 (indices == i).nonzero(as_tuple=True)[0][-1].item() + 1 1631 for i in long_indices 1632 ] 1633 res_indices_len = [ 1634 end - start 1635 for start, end in zip(res_indices_start, res_indices_end) 1636 ] 1637 ax.broken_barh( 1638 list(zip(res_indices_start, res_indices_len)), 1639 (gt_height, 1), 1640 facecolors=color_list[c_i] if c != "unknown" else "gray", 1641 label=label, 1642 ) 1643 if not ground_truth: 1644 ax.axes.yaxis.set_visible(False) 1645 else: 1646 if smooth_interval_prediction > 0: 1647 ax.set_yticks([0.5, 2, 3.5]) 1648 ax.set_yticklabels(["prediction", "smoothed", "ground truth"]) 1649 else: 1650 ax.set_yticks([0.5, 2]) 1651 ax.set_yticklabels(["prediction", "ground truth"]) 1652 if add_legend: 1653 ax.legend(loc="center left", bbox_to_anchor=(1, 0.5)) 1654 if hide_axes: 1655 plt.axis("off") 1656 if save_path is not None: 1657 plt.savefig(save_path, transparent=transparent) 1658 plt.show() 1659 1660 def set_ssl_transformations(self, ssl_transformations): 1661 self.train_dataloader.dataset.set_ssl_transformations(ssl_transformations) 1662 if self.val_dataloader is not None: 1663 self.val_dataloader.dataset.set_ssl_transformations(ssl_transformations) 1664 1665 def set_ssl_losses(self, ssl_losses: list) -> None: 1666 """ 1667 Set SSL losses 1668 1669 Parameters 1670 ---------- 1671 ssl_losses : list 1672 a list of callable SSL losses 1673 """ 1674 1675 self.ssl_losses = ssl_losses 1676 1677 def set_log(self, log: str) -> None: 1678 """ 1679 Set the log file 1680 1681 Parameters 1682 ---------- 1683 log: str 1684 the mew log file path 1685 """ 1686 1687 self.log_file = log 1688 1689 def set_keep_target_none(self, keep_target_none: List) -> None: 1690 """ 1691 Set the keep_target_none parameter of the transformer 1692 1693 Parameters 1694 ---------- 1695 keep_target_none : list 1696 a list of bool values 1697 """ 1698 1699 self.transformer.keep_target_none = keep_target_none 1700 1701 def set_generate_ssl_input(self, generate_ssl_input: list) -> None: 1702 """ 1703 Set the generate_ssl_input parameter of the transformer 1704 1705 Parameters 1706 ---------- 1707 generate_ssl_input : list 1708 a list of bool values 1709 """ 1710 1711 self.transformer.generate_ssl_input = generate_ssl_input 1712 1713 def set_model(self, model: Model) -> None: 1714 """ 1715 Set a new model 1716 1717 Parameters 1718 ---------- 1719 model: Model 1720 the new model 1721 """ 1722 1723 self.epoch = 0 1724 self.model = model 1725 self.optimizer = self.optimizer_class(model.parameters(), lr=self.lr) 1726 if self.model.process_labels: 1727 self.model.set_behaviors(list(self.behaviors_dict().values())) 1728 1729 def set_dataloaders( 1730 self, 1731 train_dataloader: DataLoader, 1732 val_dataloader: DataLoader = None, 1733 test_dataloader: DataLoader = None, 1734 ) -> None: 1735 """ 1736 Set new dataloaders 1737 1738 Parameters 1739 ---------- 1740 train_dataloader: torch.utils.data.DataLoader 1741 the new train dataloader 1742 val_dataloader : torch.utils.data.DataLoader 1743 the new validation dataloader 1744 test_dataloader : torch.utils.data.DataLoader 1745 the new test dataloader 1746 """ 1747 1748 self.train_dataloader = train_dataloader 1749 self.val_dataloader = val_dataloader 1750 self.test_dataloader = test_dataloader 1751 1752 def set_loss(self, loss: Callable) -> None: 1753 """ 1754 Set new loss function 1755 1756 Parameters 1757 ---------- 1758 loss: callable 1759 the new loss function 1760 """ 1761 1762 self.loss = loss 1763 1764 def set_metrics(self, metrics: dict) -> None: 1765 """ 1766 Set new metric 1767 1768 Parameters 1769 ---------- 1770 metrics : dict 1771 the new metric dictionary 1772 """ 1773 1774 self.metrics = metrics 1775 1776 def set_transformer(self, transformer: Transformer) -> None: 1777 """ 1778 Set a new transformer 1779 1780 Parameters 1781 ---------- 1782 transformer: Transformer 1783 the new transformer 1784 """ 1785 1786 self.transformer = transformer 1787 1788 def set_predict_functions( 1789 self, primary_predict_function: Callable, predict_function: Callable 1790 ) -> None: 1791 """ 1792 Set new predict functions 1793 1794 Parameters 1795 ---------- 1796 primary_predict_function : callable 1797 the new primary predict function 1798 predict_function : callable 1799 the new predict function 1800 """ 1801 1802 self.primary_predict_function = primary_predict_function 1803 self.predict_function = predict_function 1804 1805 def _set_pseudolabels(self): 1806 """ 1807 Set pseudolabels 1808 """ 1809 1810 self.train_dataloader.dataset.set_unlabeled(True) 1811 predicted = self.predict( 1812 data=self.dataset("train"), 1813 raw_output=False, 1814 augment_n=self.augment_val, 1815 ssl_off=True, 1816 ) 1817 self.train_dataloader.dataset.set_annotation(predicted.detach()) 1818 1819 def _alpha(self, epoch): 1820 """ 1821 Get the current pseudolabeling alpha parameter 1822 """ 1823 1824 if epoch <= self.T1: 1825 return 0 1826 elif epoch < self.T2: 1827 return self.alpha_f * (epoch - self.T1) / (self.T2 - self.T1) 1828 else: 1829 return self.alpha_f 1830 1831 def count_classes(self, bouts: bool = False) -> Dict: 1832 """ 1833 Get a dictionary of class counts in different modes 1834 1835 Parameters 1836 ---------- 1837 bouts : bool, default False 1838 if `True`, instead of frame counts segment counts are returned 1839 1840 Returns 1841 ------- 1842 class_counts : dict 1843 a dictionary where first-level keys are "train", "val" and "test", second-level keys are 1844 class names and values are class counts (in frames) 1845 """ 1846 1847 class_counts = {} 1848 for x in ["train", "val", "test"]: 1849 try: 1850 class_counts[x] = self.dataset(x).count_classes(bouts) 1851 except ValueError: 1852 class_counts[x] = {k: 0 for k in self.behaviors_dict().keys()} 1853 return class_counts 1854 1855 def behaviors_dict(self) -> Dict: 1856 """ 1857 Get a behavior dictionary 1858 1859 Keys are label indices and values are label names. 1860 1861 Returns 1862 ------- 1863 behaviors_dict : dict 1864 behavior dictionary 1865 """ 1866 1867 return self.dataset().behaviors_dict() 1868 1869 def update_parameters(self, parameters: Dict) -> None: 1870 """ 1871 Update training parameters from a dictionary 1872 1873 Parameters 1874 ---------- 1875 parameters : dict 1876 the update dictionary 1877 """ 1878 1879 self.lr = parameters.get("lr", self.lr) 1880 self.parallel = parameters.get("parallel", self.parallel) 1881 self.optimizer = self.optimizer_class(self.model.parameters(), lr=self.lr) 1882 self.verbose = parameters.get("verbose", self.verbose) 1883 self.device = parameters.get("device", self.device) 1884 if self.device == "auto": 1885 self.device = "cuda" if torch.cuda.is_available() else "cpu" 1886 self.augment_train = int(parameters.get("augment_train", self.augment_train)) 1887 self.augment_val = int(parameters.get("augment_val", self.augment_val)) 1888 ssl_weights = parameters.get("ssl_weights", self.ssl_weights) 1889 if ssl_weights is None: 1890 ssl_weights = [1 for _ in self.ssl_losses] 1891 if not isinstance(ssl_weights, Iterable): 1892 ssl_weights = [ssl_weights for _ in self.ssl_losses] 1893 self.ssl_weights = ssl_weights 1894 self.num_epochs = parameters.get("num_epochs", self.num_epochs) 1895 self.model_save_epochs = parameters.get( 1896 "model_save_epochs", self.model_save_epochs 1897 ) 1898 self.model_save_path = parameters.get("model_save_path", self.model_save_path) 1899 self.pseudolabel = parameters.get("pseudolabel", self.pseudolabel) 1900 self.T1 = int(parameters.get("pseudolabel_start", self.T1)) 1901 self.T2 = int(parameters.get("alpha_growth_stop", self.T2)) 1902 self.t = int(parameters.get("correction_interval", self.t)) 1903 self.alpha_f = parameters.get("pseudolabel_alpha_f", self.alpha_f) 1904 self.log_file = parameters.get("log_file", self.log_file) 1905 1906 def generate_uncertainty_score( 1907 self, 1908 classes: List, 1909 augment_n: int = 0, 1910 batch_size: int = 32, 1911 method: str = "least_confidence", 1912 predicted: torch.Tensor = None, 1913 behaviors_dict: Dict = None, 1914 ) -> Dict: 1915 """ 1916 Generate frame-wise scores for active learning 1917 1918 Parameters 1919 ---------- 1920 classes : list 1921 a list of class names or indices; their confidence scores will be computed separately and stacked 1922 augment_n : int, default 0 1923 the number of augmentations to average over 1924 batch_size : int, default 32 1925 the batch size 1926 method : {"least_confidence", "entropy"} 1927 the method used to calculate the scores from the probability predictions (`"least_confidence"`: `1 - p_i` if 1928 `p_i > 0.5` or `p_i` if `p_i < 0.5`; `"entropy"`: `- p_i * log(p_i) - (1 - p_i) * log(1 - p_i)`) 1929 1930 Returns 1931 ------- 1932 score_dicts : dict 1933 a nested dictionary where first level keys are video ids, second level keys are clip ids and values 1934 are score tensors 1935 """ 1936 1937 dataset = self.dataset("train") 1938 if behaviors_dict is None: 1939 behaviors_dict = self.behaviors_dict() 1940 if not isinstance(dataset, BehaviorDataset): 1941 raise TypeError( 1942 f"The dataset parameter has to be either None, string, " 1943 f"BehaviorDataset or Dataloader, got {type(dataset)}" 1944 ) 1945 if predicted is None: 1946 predicted = self.predict( 1947 dataset, 1948 raw_output=True, 1949 apply_primary_function=True, 1950 augment_n=augment_n, 1951 batch_size=batch_size, 1952 ) 1953 predicted = dataset.generate_full_length_prediction(predicted) 1954 if isinstance(classes[0], str): 1955 behaviors_dict_inverse = {v: k for k, v in behaviors_dict.items()} 1956 classes = [behaviors_dict_inverse[c] for c in classes] 1957 for v_id, v in predicted.items(): 1958 for clip_id, vv in v.items(): 1959 if method == "least_confidence": 1960 predicted[v_id][clip_id][vv > 0.5] = 1 - vv[vv > 0.5] 1961 elif method == "entropy": 1962 predicted[v_id][clip_id][vv != -100] = ( 1963 -vv * torch.log(vv) - (1 - vv) * torch.log(1 - vv) 1964 )[vv != -100] 1965 elif method == "random": 1966 predicted[v_id][clip_id] = torch.rand(vv.shape) 1967 else: 1968 raise ValueError( 1969 f"The {method} method is not recognized; please choose from ['least_confidence', 'entropy']" 1970 ) 1971 predicted[v_id][clip_id][vv == -100] = 0 1972 1973 predicted = { 1974 v_id: {clip_id: v[classes, :] for clip_id, v in video_dict.items()} 1975 for v_id, video_dict in predicted.items() 1976 } 1977 return predicted 1978 1979 def generate_bald_score( 1980 self, 1981 classes: List, 1982 augment_n: int = 0, 1983 batch_size: int = 32, 1984 num_models: int = 10, 1985 kernel_size: int = 11, 1986 ) -> Dict: 1987 """ 1988 Generate frame-wise Bayesian Active Learning by Disagreement scores for active learning 1989 1990 Parameters 1991 ---------- 1992 classes : list 1993 a list of class names or indices; their confidence scores will be computed separately and stacked 1994 augment_n : int, default 0 1995 the number of augmentations to average over 1996 batch_size : int, default 32 1997 the batch size 1998 num_models : int, default 10 1999 the number of dropout masks to apply 2000 kernel_size : int, default 11 2001 the size of the smoothing gaussian kernel 2002 2003 Returns 2004 ------- 2005 score_dicts : dict 2006 a nested dictionary where first level keys are video ids, second level keys are clip ids and values 2007 are score tensors 2008 """ 2009 2010 dataset = self.dataset("train") 2011 dataset = self._get_dataset(dataset) 2012 if not isinstance(dataset, BehaviorDataset): 2013 raise TypeError( 2014 f"The dataset parameter has to be either None, string, " 2015 f"BehaviorDataset or Dataloader, got {type(dataset)}" 2016 ) 2017 predictions = [] 2018 for _ in range(num_models): 2019 predicted = self.predict( 2020 dataset, 2021 raw_output=True, 2022 apply_primary_function=True, 2023 augment_n=augment_n, 2024 batch_size=batch_size, 2025 train_mode=True, 2026 ) 2027 predicted = dataset.generate_full_length_prediction(predicted) 2028 if isinstance(classes[0], str): 2029 behaviors_dict_inverse = { 2030 v: k for k, v in self.behaviors_dict().items() 2031 } 2032 classes = [behaviors_dict_inverse[c] for c in classes] 2033 for v_id, v in predicted.items(): 2034 for clip_id, vv in v.items(): 2035 vv[vv != -100] = (vv[vv != -100] > 0.5).int().float() 2036 predicted[v_id][clip_id] = vv 2037 predicted = { 2038 v_id: {clip_id: v[classes, :] for clip_id, v in video_dict.items()} 2039 for v_id, video_dict in predicted.items() 2040 } 2041 predictions.append(predicted) 2042 result = {v_id: {} for v_id in predictions[0]} 2043 r = range(-int(kernel_size / 2), int(kernel_size / 2) + 1) 2044 gauss = [1 / (1 * sqrt(2 * pi)) * exp(-float(x) ** 2 / (2 * 1**2)) for x in r] 2045 gauss = [x / sum(gauss) for x in gauss] 2046 kernel = torch.FloatTensor([[gauss]]) 2047 for v_id in predictions[0]: 2048 for clip_id in predictions[0][v_id]: 2049 consensus = ( 2050 ( 2051 torch.mean( 2052 torch.stack([x[v_id][clip_id] for x in predictions]), dim=0 2053 ) 2054 > 0.5 2055 ) 2056 .int() 2057 .float() 2058 ) 2059 consensus[predictions[0][v_id][clip_id] == -100] = -100 2060 result[v_id][clip_id] = torch.zeros(consensus.shape) 2061 for x in predictions: 2062 result[v_id][clip_id] += (x[v_id][clip_id] != consensus).int() 2063 result[v_id][clip_id] = result[v_id][clip_id] * 2 / num_models 2064 res = torch.zeros(result[v_id][clip_id].shape) 2065 for i in range(len(classes)): 2066 res[ 2067 i, floor(kernel_size // 2) : -floor(kernel_size // 2) 2068 ] = torch.nn.functional.conv1d( 2069 result[v_id][clip_id][i, :].unsqueeze(0).unsqueeze(0), kernel 2070 )[ 2071 0, ... 2072 ] 2073 result[v_id][clip_id] = res 2074 return result 2075 2076 def get_normalization_stats(self) -> Optional[Dict]: 2077 return self.train_dataloader.dataset.stats
A universal trainer class that performs training, evaluation and prediction for all types of tasks and data
75 def __init__( 76 self, 77 train_dataloader: DataLoader, 78 model: Union[nn.Module, Model], 79 loss: Callable[[torch.Tensor, torch.Tensor], float], 80 num_epochs: int = 0, 81 transformer: Transformer = None, 82 ssl_losses: List = None, 83 ssl_weights: List = None, 84 lr: float = 1e-3, 85 metrics: Dict = None, 86 val_dataloader: DataLoader = None, 87 test_dataloader: DataLoader = None, 88 optimizer: Optimizer = None, 89 device: str = "cuda", 90 verbose: bool = True, 91 log_file: Union[str, None] = None, 92 augment_train: int = 1, 93 augment_val: int = 0, 94 validation_interval: int = 1, 95 predict_function: Union[Callable[[torch.Tensor], torch.Tensor], None] = None, 96 primary_predict_function: Callable = None, 97 exclusive: bool = True, 98 ignore_tags: bool = True, 99 threshold: float = 0.5, 100 model_save_path: str = None, 101 model_save_epochs: int = 5, 102 pseudolabel: bool = False, 103 pseudolabel_start: int = 100, 104 correction_interval: int = 2, 105 pseudolabel_alpha_f: float = 3, 106 alpha_growth_stop: int = 600, 107 parallel: bool = False, 108 skip_metrics: List = None, 109 ) -> None: 110 """ 111 Parameters 112 ---------- 113 train_dataloader : torch.utils.data.DataLoader 114 a training dataloader 115 model : dlc2action.model.base_model.Model 116 a model 117 loss : callable 118 a loss function 119 num_epochs : int, default 0 120 the number of epochs 121 transformer : dlc2action.transformer.base_transformer.Transformer, optional 122 a transformer 123 ssl_losses : list, optional 124 a list of SSL losses 125 ssl_weights : list, optional 126 a list of SSL weights (if not provided initializes to 1) 127 lr : float, default 1e-3 128 learning rate 129 metrics : dict, optional 130 a list of metric functions 131 val_dataloader : torch.utils.data.DataLoader, optional 132 a validation dataloader 133 optimizer : torch.optim.Optimizer, optional 134 an optimizer (`Adam` by default) 135 device : str, default 'cuda' 136 the device to train the model on 137 verbose : bool, default True 138 if `True`, the process is described in standard output 139 log_file : str, optional 140 the path to a text file where the process will be logged 141 augment_train : {1, 0} 142 number of augmentations to apply at training 143 augment_val : int, default 0 144 number of augmentations to apply at validation 145 validation_interval : int, default 1 146 every time this number of epochs passes, validation metrics are computed 147 predict_function : callable, optional 148 a function that maps probabilities to class predictions (if not provided, a default is generated) 149 primary_predict_function : callable, optional 150 a function that maps model output to probabilities (if not provided, initialized as identity) 151 exclusive : bool, default True 152 set to False for multi-label classification 153 ignore_tags : bool, default False 154 if `True`, samples with different meta tags will be mixed in batches 155 threshold : float, default 0.5 156 the threshold used for multi-label classification default prediction function 157 model_save_path : str, optional 158 the path to the folder where model checkpoints will be saved (checkpoints will not be saved if the path 159 is not provided) 160 model_save_epochs : int, default 5 161 the interval for saving the model checkpoints (the last epoch is always saved) 162 pseudolabel : bool, default False 163 if True, the pseudolabeling procedure will be applied 164 pseudolabel_start : int, default 100 165 pseudolabeling starts after this epoch 166 correction_interval : int, default 1 167 after this number of epochs, if the pseudolabeling is on, the model is trained on the labeled data and 168 new pseudolabels are generated 169 pseudolabel_alpha_f : float, default 3 170 the maximum value of pseudolabeling alpha 171 alpha_growth_stop : int, default 600 172 pseudolabeling alpha stops growing after this epoch 173 """ 174 175 # pseudolabeling might be buggy right now -- not using it! 176 if skip_metrics is None: 177 skip_metrics = [] 178 self.train_dataloader = train_dataloader 179 self.val_dataloader = val_dataloader 180 self.test_dataloader = test_dataloader 181 self.transformer = transformer 182 self.num_epochs = num_epochs 183 self.skip_metrics = skip_metrics 184 self.verbose = verbose 185 self.augment_train = int(augment_train) 186 self.augment_val = int(augment_val) 187 self.ignore_tags = ignore_tags 188 self.validation_interval = int(validation_interval) 189 self.log_file = log_file 190 self.loss = loss 191 self.model_save_path = model_save_path 192 self.model_save_epochs = model_save_epochs 193 self.epoch = 0 194 195 if metrics is None: 196 metrics = {} 197 self.metrics = metrics 198 199 if optimizer is None: 200 optimizer = Adam 201 202 if ssl_weights is None: 203 ssl_weights = [1 for _ in ssl_losses] 204 if not isinstance(ssl_weights, Iterable): 205 ssl_weights = [ssl_weights for _ in ssl_losses] 206 self.ssl_weights = ssl_weights 207 208 self.optimizer_class = optimizer 209 self.lr = lr 210 if not isinstance(model, Model): 211 self.model = LoadedModel(model=model) 212 else: 213 self.set_model(model) 214 self.parallel = parallel 215 216 if self.transformer is None: 217 self.augment_val = 0 218 self.augment_train = 0 219 self.transformer = EmptyTransformer() 220 221 if self.augment_train > 1: 222 warnings.warn( 223 'The "augment_train" parameter is too large -> setting it to 1.' 224 ) 225 self.augment_train = 1 226 227 try: 228 if device == "auto": 229 device = "cuda" if torch.cuda.is_available() else "cpu" 230 self.device = torch.device(device) 231 except: 232 raise ("The format of the device is incorrect") 233 234 if ssl_losses is None: 235 self.ssl_losses = [lambda x, y: 0] 236 else: 237 self.ssl_losses = ssl_losses 238 239 if primary_predict_function is None: 240 if exclusive: 241 primary_predict_function = lambda x: nn.Softmax(x, dim=1) 242 else: 243 primary_predict_function = lambda x: torch.sigmoid(x) 244 self.primary_predict_function = primary_predict_function 245 246 if predict_function is None: 247 if exclusive: 248 self.predict_function = lambda x: torch.max(x.data, 1)[1] 249 else: 250 self.predict_function = lambda x: (x > threshold).int() 251 else: 252 self.predict_function = predict_function 253 254 self.pseudolabel = pseudolabel 255 self.alpha_f = pseudolabel_alpha_f 256 self.T2 = alpha_growth_stop 257 self.T1 = pseudolabel_start 258 self.t = correction_interval 259 if self.T2 <= self.T1: 260 raise ValueError( 261 f"The pseudolabel_start parameter has to be smaller than alpha_growth_stop; got " 262 f"{pseudolabel_start=} and {alpha_growth_stop=}" 263 ) 264 self.decision_thresholds = [0.5 for x in self.behaviors_dict()]
Parameters
train_dataloader : torch.utils.data.DataLoader
a training dataloader
model : dlc2action.model.base_model.Model
a model
loss : callable
a loss function
num_epochs : int, default 0
the number of epochs
transformer : dlc2action.transformer.base_transformer.Transformer, optional
a transformer
ssl_losses : list, optional
a list of SSL losses
ssl_weights : list, optional
a list of SSL weights (if not provided initializes to 1)
lr : float, default 1e-3
learning rate
metrics : dict, optional
a list of metric functions
val_dataloader : torch.utils.data.DataLoader, optional
a validation dataloader
optimizer : torch.optim.Optimizer, optional
an optimizer (Adam
by default)
device : str, default 'cuda'
the device to train the model on
verbose : bool, default True
if True
, the process is described in standard output
log_file : str, optional
the path to a text file where the process will be logged
augment_train : {1, 0}
number of augmentations to apply at training
augment_val : int, default 0
number of augmentations to apply at validation
validation_interval : int, default 1
every time this number of epochs passes, validation metrics are computed
predict_function : callable, optional
a function that maps probabilities to class predictions (if not provided, a default is generated)
primary_predict_function : callable, optional
a function that maps model output to probabilities (if not provided, initialized as identity)
exclusive : bool, default True
set to False for multi-label classification
ignore_tags : bool, default False
if True
, samples with different meta tags will be mixed in batches
threshold : float, default 0.5
the threshold used for multi-label classification default prediction function
model_save_path : str, optional
the path to the folder where model checkpoints will be saved (checkpoints will not be saved if the path
is not provided)
model_save_epochs : int, default 5
the interval for saving the model checkpoints (the last epoch is always saved)
pseudolabel : bool, default False
if True, the pseudolabeling procedure will be applied
pseudolabel_start : int, default 100
pseudolabeling starts after this epoch
correction_interval : int, default 1
after this number of epochs, if the pseudolabeling is on, the model is trained on the labeled data and
new pseudolabels are generated
pseudolabel_alpha_f : float, default 3
the maximum value of pseudolabeling alpha
alpha_growth_stop : int, default 600
pseudolabeling alpha stops growing after this epoch
266 def save_checkpoint(self, checkpoint_path: str) -> None: 267 """ 268 Save a general checkpoint 269 270 Parameters 271 ---------- 272 checkpoint_path : str 273 the path where the checkpoint will be saved 274 """ 275 276 torch.save( 277 { 278 "epoch": self.epoch, 279 "model_state_dict": self.model.state_dict(), 280 "optimizer_state_dict": self.optimizer.state_dict(), 281 }, 282 checkpoint_path, 283 )
Save a general checkpoint
Parameters
checkpoint_path : str the path where the checkpoint will be saved
285 def load_from_checkpoint( 286 self, checkpoint_path, only_model: bool = False, load_strict: bool = True 287 ) -> None: 288 """ 289 Load from a checkpoint 290 291 Parameters 292 ---------- 293 checkpoint_path : str 294 the path to the checkpoint 295 only_model : bool, default False 296 if `True`, only the model state dictionary will be loaded (and not the epoch and the optimizer state 297 dictionary) 298 load_strict : bool, default True 299 if `True`, any inconsistencies in state dictionaries are regarded as errors 300 """ 301 302 if checkpoint_path is None: 303 return 304 checkpoint = torch.load(checkpoint_path, map_location=self.device) 305 self.model.to(self.device) 306 self.model.load_state_dict(checkpoint["model_state_dict"], strict=load_strict) 307 if not only_model: 308 self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) 309 self.epoch = checkpoint["epoch"]
Load from a checkpoint
Parameters
checkpoint_path : str
the path to the checkpoint
only_model : bool, default False
if True
, only the model state dictionary will be loaded (and not the epoch and the optimizer state
dictionary)
load_strict : bool, default True
if True
, any inconsistencies in state dictionaries are regarded as errors
311 def save_model(self, save_path: str) -> None: 312 """ 313 Save the model state dictionary 314 315 Parameters 316 ---------- 317 save_path : str 318 the path where the state will be saved 319 """ 320 321 torch.save(self.model.state_dict(), save_path) 322 print("saved the model successfully")
Save the model state dictionary
Parameters
save_path : str the path where the state will be saved
603 def train( 604 self, 605 trial: Trial = None, 606 optimized_metric: str = None, 607 to_ram: bool = False, 608 autostop_interval: int = 30, 609 autostop_threshold: float = 0.001, 610 autostop_metric: str = None, 611 main_task_on: bool = True, 612 ssl_on: bool = True, 613 temporal_subsampling_size: int = None, 614 loading_bar: bool = False, 615 ) -> Tuple: 616 """ 617 Train the task and return a log of epoch-average loss and metric 618 619 You can use the autostop parameters to finish training when the parameters are not improving. It will be 620 stopped if the average value of `autostop_metric` over the last `autostop_interval` epochs is smaller than 621 the average over the previous `autostop_interval` epochs + `autostop_threshold`. For example, if the 622 current epoch is 120 and `autostop_interval` is 50, the averages over epochs 70-120 and 20-70 will be compared. 623 624 Parameters 625 ---------- 626 trial : Trial 627 an `optuna` trial (for hyperparameter searches) 628 optimized_metric : str 629 the name of the metric being optimized (for hyperparameter searches) 630 to_ram : bool, default False 631 if `True`, the dataset will be loaded in RAM (this speeds up the calculations but can lead to crashes 632 if the dataset is too large) 633 autostop_interval : int, default 50 634 the number of epochs to average the autostop metric over 635 autostop_threshold : float, default 0.001 636 the autostop difference threshold 637 autostop_metric : str, optional 638 the autostop metric (can be any one of the tracked metrics of `'loss'`) 639 main_task_on : bool, default True 640 if `False`, the main task (action segmentation) will not be used in training 641 ssl_on : bool, default True 642 if `False`, the SSL task will not be used in training 643 644 Returns 645 ------- 646 loss_log: list 647 a list of float loss function values for each epoch 648 metrics_log: dict 649 a dictionary of metric value logs (first-level keys are 'train' and 'val', second-level keys are metric 650 names, values are lists of function values) 651 """ 652 653 if self.parallel and not isinstance(self.model, nn.DataParallel): 654 self.model = MyDataParallel(self.model) 655 self.model.to(self.device) 656 assert autostop_metric in [None, "loss"] + list(self.metrics) 657 autostop_interval //= self.validation_interval 658 if trial is not None and optimized_metric is None: 659 raise ValueError( 660 "You need to provide the optimized metric name (optimized_metric parameter) " 661 "for optuna pruning to work!" 662 ) 663 if to_ram: 664 print("transferring datasets to RAM...") 665 self.train_dataloader.dataset.to_ram() 666 if self.val_dataloader is not None and len(self.val_dataloader) > 0: 667 self.val_dataloader.dataset.to_ram() 668 loss_log = {"train": [], "val": []} 669 metrics_log = {"train": defaultdict(lambda: []), "val": defaultdict(lambda: [])} 670 if not main_task_on: 671 self.model.main_task_off() 672 if not ssl_on: 673 self.model.ssl_off() 674 while self.epoch < self.num_epochs: 675 self.epoch += 1 676 unlabeled = None 677 alpha = 1 678 if self.pseudolabel: 679 if self.epoch >= self.T1: 680 unlabeled = (self.epoch - self.T1) % self.t != 0 681 if unlabeled: 682 alpha = self._alpha(self.epoch) 683 else: 684 unlabeled = False 685 epoch_loss, epoch_ssl_loss, epoch_metrics = self._run_epoch( 686 dataloader=self.train_dataloader, 687 mode="train", 688 augment_n=self.augment_train, 689 unlabeled=unlabeled, 690 alpha=alpha, 691 temporal_subsampling_size=temporal_subsampling_size, 692 verbose=loading_bar, 693 ) 694 loss_log["train"].append(epoch_loss) 695 epoch_string = f"[epoch {self.epoch}]" 696 if self.pseudolabel: 697 if unlabeled: 698 epoch_string += " (unlabeled)" 699 else: 700 epoch_string += " (labeled)" 701 epoch_string += f": loss {epoch_loss:.4f}" 702 703 if len(epoch_ssl_loss) != 0: 704 for key, value in sorted(epoch_ssl_loss.items()): 705 metrics_log["train"][f"ssl_loss_{key}"].append(value) 706 epoch_string += f", ssl_loss_{key} {value:.4f}" 707 708 for metric_name, metric_value in sorted(epoch_metrics.items()): 709 if metric_name not in self.skip_metrics: 710 metrics_log["train"][metric_name].append(metric_value) 711 epoch_string += f", {metric_name} {metric_value:.3f}" 712 713 if ( 714 self.val_dataloader is not None 715 and self.epoch % self.validation_interval == 0 716 ): 717 with torch.no_grad(): 718 epoch_string += "\n" 719 ( 720 val_epoch_loss, 721 val_epoch_ssl_loss, 722 val_epoch_metrics, 723 ) = self._run_epoch( 724 dataloader=self.val_dataloader, 725 mode="val", 726 augment_n=self.augment_val, 727 ) 728 loss_log["val"].append(val_epoch_loss) 729 epoch_string += f"validation: loss {val_epoch_loss:.4f}" 730 731 if len(val_epoch_ssl_loss) != 0: 732 for key, value in sorted(val_epoch_ssl_loss.items()): 733 metrics_log["val"][f"ssl_loss_{key}"].append(value) 734 epoch_string += f", ssl_loss_{key} {value:.4f}" 735 736 for metric_name, metric_value in sorted(val_epoch_metrics.items()): 737 metrics_log["val"][metric_name].append(metric_value) 738 epoch_string += f", {metric_name} {metric_value:.3f}" 739 740 if trial is not None: 741 if optimized_metric not in metrics_log["val"]: 742 raise ValueError( 743 f"The {optimized_metric} metric set for optimization is not being logged!" 744 ) 745 trial.report(metrics_log["val"][optimized_metric][-1], self.epoch) 746 if trial.should_prune(): 747 raise TrialPruned() 748 749 if self.verbose: 750 print(epoch_string) 751 752 if self.log_file is not None: 753 with open(self.log_file, "a") as f: 754 f.write(epoch_string + "\n") 755 756 save_condition = ( 757 (self.model_save_epochs != 0) 758 and (self.epoch % self.model_save_epochs == 0) 759 ) or (self.epoch == self.num_epochs) 760 761 if self.epoch > 0 and save_condition and self.model_save_path is not None: 762 epoch_s = str(self.epoch).zfill(len(str(self.num_epochs))) 763 self.save_checkpoint( 764 os.path.join(self.model_save_path, f"epoch{epoch_s}.pt") 765 ) 766 767 if self.pseudolabel and self.epoch >= self.T1 and not unlabeled: 768 self._set_pseudolabels() 769 770 if autostop_metric == "loss": 771 if len(loss_log["val"]) > autostop_interval * 2: 772 if ( 773 np.mean(loss_log["val"][-autostop_interval:]) 774 < np.mean( 775 loss_log["val"][-2 * autostop_interval : -autostop_interval] 776 ) 777 + autostop_threshold 778 ): 779 break 780 elif autostop_metric in metrics_log["val"]: 781 if len(metrics_log["val"][autostop_metric]) > autostop_interval * 2: 782 if ( 783 np.mean( 784 metrics_log["val"][autostop_metric][-autostop_interval:] 785 ) 786 < np.mean( 787 metrics_log["val"][autostop_metric][ 788 -2 * autostop_interval : -autostop_interval 789 ] 790 ) 791 + autostop_threshold 792 ): 793 break 794 795 metrics_log = {k: dict(v) for k, v in metrics_log.items()} 796 797 return loss_log, metrics_log
Train the task and return a log of epoch-average loss and metric
You can use the autostop parameters to finish training when the parameters are not improving. It will be
stopped if the average value of autostop_metric
over the last autostop_interval
epochs is smaller than
the average over the previous autostop_interval
epochs + autostop_threshold
. For example, if the
current epoch is 120 and autostop_interval
is 50, the averages over epochs 70-120 and 20-70 will be compared.
Parameters
trial : Trial
an optuna
trial (for hyperparameter searches)
optimized_metric : str
the name of the metric being optimized (for hyperparameter searches)
to_ram : bool, default False
if True
, the dataset will be loaded in RAM (this speeds up the calculations but can lead to crashes
if the dataset is too large)
autostop_interval : int, default 50
the number of epochs to average the autostop metric over
autostop_threshold : float, default 0.001
the autostop difference threshold
autostop_metric : str, optional
the autostop metric (can be any one of the tracked metrics of 'loss'
)
main_task_on : bool, default True
if False
, the main task (action segmentation) will not be used in training
ssl_on : bool, default True
if False
, the SSL task will not be used in training
Returns
loss_log: list a list of float loss function values for each epoch metrics_log: dict a dictionary of metric value logs (first-level keys are 'train' and 'val', second-level keys are metric names, values are lists of function values)
799 def evaluate_prediction( 800 self, 801 prediction: Union[torch.Tensor, Dict], 802 data: Union[DataLoader, BehaviorDataset, str] = None, 803 batch_size: int = 32, 804 ) -> Tuple: 805 """ 806 Compute metrics for a prediction 807 808 Parameters 809 ---------- 810 prediction : torch.Tensor 811 the prediction 812 data : torch.utils.data.DataLoader | dlc2action.data.dataset.BehaviorDataset, optional 813 the data the prediction was made for (if not provided, take the validation dataset) 814 batch_size : int, default 32 815 the batch size 816 817 Returns 818 ------- 819 loss : float 820 the average value of the loss function 821 metric : dict 822 a dictionary of average values of metric functions 823 """ 824 825 if type(data) is not DataLoader: 826 dataset = self._get_dataset(data) 827 data = DataLoader(dataset, shuffle=False, batch_size=batch_size) 828 epoch_loss = 0 829 if isinstance(prediction, dict): 830 num_classes = len(self.behaviors_dict()) 831 length = dataset.len_segment() 832 coords = dataset.annotation_store.get_original_coordinates() 833 for batch in data: 834 main_target = batch["target"] 835 pr_coords = coords[batch["index"]] 836 predicted = torch.zeros((len(pr_coords), num_classes, length)) 837 for i, c in enumerate(pr_coords): 838 video_id = dataset.input_store.get_video_id(c) 839 clip_id = dataset.input_store.get_clip_id(c) 840 start, end = dataset.input_store.get_clip_start_end(c) 841 predicted[i, :, : end - start] = prediction[video_id][clip_id][ 842 :, start:end 843 ] 844 self._compute( 845 [], 846 [], 847 predicted, 848 main_target, 849 skip_loss=True, 850 tag=batch.get("tag"), 851 apply_primary_function=False, 852 ) 853 else: 854 for batch in data: 855 main_target = batch["target"] 856 predicted = prediction[batch["index"]] 857 self._compute( 858 [], 859 [], 860 predicted, 861 main_target, 862 skip_loss=True, 863 tag=batch.get("tag"), 864 apply_primary_function=False, 865 ) 866 epoch_metrics = self._calculate_metrics() 867 strings = [ 868 f"{metric_name} {metric_value:.3f}" 869 for metric_name, metric_value in epoch_metrics.items() 870 ] 871 val_string = ", ".join(sorted(strings)) 872 print(val_string) 873 return epoch_loss, epoch_metrics
Compute metrics for a prediction
Parameters
prediction : torch.Tensor the prediction data : torch.utils.data.DataLoader | dlc2action.data.dataset.BehaviorDataset, optional the data the prediction was made for (if not provided, take the validation dataset) batch_size : int, default 32 the batch size
Returns
loss : float the average value of the loss function metric : dict a dictionary of average values of metric functions
875 def evaluate( 876 self, 877 data: Union[DataLoader, BehaviorDataset, str] = None, 878 augment_n: int = 0, 879 batch_size: int = 32, 880 verbose: bool = True, 881 ) -> Tuple: 882 """ 883 Evaluate the Task model 884 885 Parameters 886 ---------- 887 data : torch.utils.data.DataLoader | dlc2action.data.dataset.BehaviorDataset, optional 888 the data to evaluate on (if not provided, evaluate on the Task validation dataset) 889 augment_n : int, default 0 890 the number of augmentations to average results over 891 batch_size : int, default 32 892 the batch size 893 verbose : bool, default True 894 if True, the process is reported to standard output 895 896 Returns 897 ------- 898 loss : float 899 the average value of the loss function 900 ssl_loss : float 901 the average value of the SSL loss function 902 metric : dict 903 a dictionary of average values of metric functions 904 """ 905 906 if self.parallel and not isinstance(self.model, nn.DataParallel): 907 self.model = MyDataParallel(self.model) 908 self.model.to(self.device) 909 if type(data) is not DataLoader: 910 data = self._get_dataset(data) 911 data = DataLoader(data, shuffle=False, batch_size=batch_size) 912 with torch.no_grad(): 913 epoch_loss, epoch_ssl_loss, epoch_metrics = self._run_epoch( 914 dataloader=data, mode="val", augment_n=augment_n, verbose=verbose 915 ) 916 val_string = f"loss {epoch_loss:.4f}" 917 for metric_name, metric_value in sorted(epoch_metrics.items()): 918 val_string += f", {metric_name} {metric_value:.3f}" 919 print(val_string) 920 return epoch_loss, epoch_ssl_loss, epoch_metrics
Evaluate the Task model
Parameters
data : torch.utils.data.DataLoader | dlc2action.data.dataset.BehaviorDataset, optional the data to evaluate on (if not provided, evaluate on the Task validation dataset) augment_n : int, default 0 the number of augmentations to average results over batch_size : int, default 32 the batch size verbose : bool, default True if True, the process is reported to standard output
Returns
loss : float the average value of the loss function ssl_loss : float the average value of the SSL loss function metric : dict a dictionary of average values of metric functions
922 def predict( 923 self, 924 data: Union[DataLoader, BehaviorDataset, str] = None, 925 raw_output: bool = False, 926 apply_primary_function: bool = True, 927 augment_n: int = 0, 928 batch_size: int = 32, 929 train_mode: bool = False, 930 to_ram: bool = False, 931 embedding: bool = False, 932 ) -> torch.Tensor: 933 """ 934 Make a prediction with the Task model 935 936 Parameters 937 ---------- 938 data : torch.utils.data.DataLoader | dlc2action.data.dataset.BehaviorDataset | str, optional 939 the data to evaluate on (if not provided, evaluate on the Task validation dataset) 940 raw_output : bool, default False 941 if `True`, the raw predicted probabilities are returned 942 apply_primary_function : bool, default True 943 if `True`, the primary predict function is applied (to map the model output into a shape corresponding to 944 the input) 945 augment_n : int, default 0 946 the number of augmentations to average results over 947 batch_size : int, default 32 948 the batch size 949 train_mode : bool, default False 950 if `True`, the model is used in training mode (affects dropout and batch normalization layers) 951 to_ram : bool, default False 952 if `True`, the dataset will be loaded in RAM (this speeds up the calculations but can lead to crashes 953 if the dataset is too large) 954 embedding : bool, default False 955 if `True`, the output of feature extractor is returned, ignoring the prediction module of the model 956 957 Returns 958 ------- 959 prediction : torch.Tensor 960 a prediction for the input data 961 """ 962 963 if self.parallel and not isinstance(self.model, nn.DataParallel): 964 self.model = MyDataParallel(self.model) 965 self.model.to(self.device) 966 if train_mode: 967 self.model.train() 968 else: 969 self.model.eval() 970 output = [] 971 if embedding: 972 raw_output = True 973 apply_primary_function = True 974 if type(data) is not DataLoader: 975 data = self._get_dataset(data) 976 if to_ram: 977 print("transferring dataset to RAM...") 978 data.to_ram() 979 data = DataLoader(data, shuffle=False, batch_size=batch_size) 980 self.model.ssl_off() 981 with torch.no_grad(): 982 for batch in tqdm(data): 983 input = {k: v.to(self.device) for k, v in batch["input"].items()} 984 predicted, _, _ = self._get_prediction( 985 input, 986 batch.get("tag"), 987 augment_n=augment_n, 988 embedding=embedding, 989 ) 990 if apply_primary_function: 991 predicted = self.primary_predict_function(predicted) 992 if not raw_output: 993 predicted = self.predict_function(predicted) 994 output.append(predicted.detach().cpu()) 995 self.model.ssl_on() 996 output = torch.cat(output).detach() 997 return output
Make a prediction with the Task model
Parameters
data : torch.utils.data.DataLoader | dlc2action.data.dataset.BehaviorDataset | str, optional
the data to evaluate on (if not provided, evaluate on the Task validation dataset)
raw_output : bool, default False
if True
, the raw predicted probabilities are returned
apply_primary_function : bool, default True
if True
, the primary predict function is applied (to map the model output into a shape corresponding to
the input)
augment_n : int, default 0
the number of augmentations to average results over
batch_size : int, default 32
the batch size
train_mode : bool, default False
if True
, the model is used in training mode (affects dropout and batch normalization layers)
to_ram : bool, default False
if True
, the dataset will be loaded in RAM (this speeds up the calculations but can lead to crashes
if the dataset is too large)
embedding : bool, default False
if True
, the output of feature extractor is returned, ignoring the prediction module of the model
Returns
prediction : torch.Tensor a prediction for the input data
999 def dataset(self, mode="train") -> BehaviorDataset: 1000 """ 1001 Get a dataset 1002 1003 Parameters 1004 ---------- 1005 mode : {'train', 'val', 'test} 1006 the dataset to get 1007 1008 Returns 1009 ------- 1010 dataset : dlc2action.data.dataset.BehaviorDataset 1011 the dataset 1012 """ 1013 1014 dataloader = self.dataloader(mode) 1015 if dataloader is None: 1016 raise ValueError("The length of the dataloader is 0!") 1017 return dataloader.dataset
Get a dataset
Parameters
mode : {'train', 'val', 'test} the dataset to get
Returns
dataset : dlc2action.data.dataset.BehaviorDataset the dataset
1019 def dataloader(self, mode: str = "train") -> DataLoader: 1020 """ 1021 Get a dataloader 1022 1023 Parameters 1024 ---------- 1025 mode : {'train', 'val', 'test} 1026 the dataset to get 1027 1028 Returns 1029 ------- 1030 dataloader : torch.utils.data.DataLoader 1031 the dataloader 1032 """ 1033 1034 if mode == "train": 1035 return self.train_dataloader 1036 elif mode == "val": 1037 return self.val_dataloader 1038 elif mode == "test": 1039 return self.test_dataloader 1040 else: 1041 raise ValueError( 1042 f'The {mode} mode is not recognized, please choose from "train", "val" or "test"' 1043 )
Get a dataloader
Parameters
mode : {'train', 'val', 'test} the dataset to get
Returns
dataloader : torch.utils.data.DataLoader the dataloader
1079 def generate_full_length_prediction( 1080 self, dataset=None, batch_size=32, augment_n=10 1081 ): 1082 """ 1083 Compile a prediction for the original input sequences 1084 1085 Parameters 1086 ---------- 1087 dataset : BehaviorDataset, optional 1088 the dataset to generate a prediction for (if `None`, generate for the validation dataset) 1089 batch_size : int, default 32 1090 the batch size 1091 augment_n : int, default 10 1092 the number of augmentations to average results over 1093 1094 Returns 1095 ------- 1096 prediction : dict 1097 a nested dictionary where first level keys are video ids, second level keys are clip ids and values 1098 are prediction tensors 1099 """ 1100 1101 dataset = self._get_dataset(dataset) 1102 if not isinstance(dataset, BehaviorDataset): 1103 raise TypeError( 1104 f"The dataset parameter has to be either None, string, " 1105 f"BehaviorDataset or Dataloader, got {type(dataset)}" 1106 ) 1107 predicted = self.predict( 1108 dataset, 1109 raw_output=True, 1110 apply_primary_function=True, 1111 augment_n=augment_n, 1112 batch_size=batch_size, 1113 ) 1114 predicted = dataset.generate_full_length_prediction(predicted) 1115 predicted = { 1116 v_id: { 1117 clip_id: self._apply_predict_functions(v.unsqueeze(0)).squeeze() 1118 for clip_id, v in video_dict.items() 1119 } 1120 for v_id, video_dict in predicted.items() 1121 } 1122 return predicted
Compile a prediction for the original input sequences
Parameters
dataset : BehaviorDataset, optional
the dataset to generate a prediction for (if None
, generate for the validation dataset)
batch_size : int, default 32
the batch size
augment_n : int, default 10
the number of augmentations to average results over
Returns
prediction : dict a nested dictionary where first level keys are video ids, second level keys are clip ids and values are prediction tensors
1124 def generate_submission( 1125 self, frame_number_map_file, dataset=None, batch_size=32, augment_n=10 1126 ): 1127 """ 1128 Generate a MABe-22 style submission dictionary 1129 1130 Parameters 1131 ---------- 1132 dataset : BehaviorDataset, optional 1133 the dataset to generate a prediction for (if `None`, generate for the validation dataset) 1134 batch_size : int, default 32 1135 the batch size 1136 augment_n : int, default 10 1137 the number of augmentations to average results over 1138 1139 Returns 1140 ------- 1141 submission : dict 1142 a dictionary with frame number mapping and embeddings 1143 """ 1144 1145 dataset = self._get_dataset(dataset) 1146 if not isinstance(dataset, BehaviorDataset): 1147 raise TypeError( 1148 f"The dataset parameter has to be either None, string, " 1149 f"BehaviorDataset or Dataloader, got {type(dataset)}" 1150 ) 1151 predicted = self.predict( 1152 dataset, 1153 raw_output=True, 1154 apply_primary_function=True, 1155 augment_n=augment_n, 1156 batch_size=batch_size, 1157 embedding=True, 1158 ) 1159 predicted = dataset.generate_full_length_prediction(predicted) 1160 frame_map = np.load(frame_number_map_file, allow_pickle=True).item() 1161 length = frame_map[list(frame_map.keys())[-1]][1] 1162 embeddings = None 1163 for video_id in list(predicted.keys()): 1164 split = video_id.split("--") 1165 if len(split) != 2 or len(predicted[video_id]) > 1: 1166 raise RuntimeError( 1167 "Generating submissions is only implemented for the mabe22 dataset!" 1168 ) 1169 if split[1] not in frame_map: 1170 raise RuntimeError(f"The {split[1]} video is not in the frame map file") 1171 v_id = split[1] 1172 clip_id = list(predicted[video_id].keys())[0] 1173 if embeddings is None: 1174 embeddings = np.zeros((length, predicted[video_id][clip_id].shape[0])) 1175 start, end = frame_map[v_id] 1176 embeddings[start:end, :] = predicted[video_id][clip_id].T 1177 predicted.pop(video_id) 1178 predicted = { 1179 "frame_number_map": frame_map, 1180 "embeddings": embeddings.astype(np.float32), 1181 } 1182 return predicted
Generate a MABe-22 style submission dictionary
Parameters
dataset : BehaviorDataset, optional
the dataset to generate a prediction for (if None
, generate for the validation dataset)
batch_size : int, default 32
the batch size
augment_n : int, default 10
the number of augmentations to average results over
Returns
submission : dict a dictionary with frame number mapping and embeddings
1439 def visualize_results( 1440 self, 1441 save_path: str = None, 1442 add_legend: bool = True, 1443 ground_truth: bool = True, 1444 colormap: str = "viridis", 1445 hide_axes: bool = False, 1446 min_classes: int = 1, 1447 width: int = 10, 1448 whole_video: bool = False, 1449 transparent: bool = False, 1450 dataset: Union[BehaviorDataset, DataLoader, str, None] = None, 1451 drop_classes: Set = None, 1452 search_classes: Set = None, 1453 num_samples: int = 1, 1454 smooth_interval_prediction: int = None, 1455 behavior_name: str = None, 1456 ): 1457 """ 1458 Visualize random predictions 1459 1460 Parameters 1461 ---------- 1462 save_path : str, optional 1463 the path where the plot will be saved 1464 add_legend : bool, default True 1465 if `True`, legend will be added to the plot 1466 ground_truth : bool, default True 1467 if `True`, ground truth will be added to the plot 1468 colormap : str, default 'Accent' 1469 the `matplotlib` colormap to use 1470 hide_axes : bool, default True 1471 if `True`, the axes will be hidden on the plot 1472 min_classes : int, default 1 1473 the minimum number of classes in a displayed interval 1474 width : float, default 10 1475 the width of the plot 1476 whole_video : bool, default False 1477 if `True`, whole videos are plotted instead of segments 1478 transparent : bool, default False 1479 if `True`, the background on the plot is transparent 1480 dataset : BehaviorDataset | DataLoader | str | None, optional 1481 the dataset to make the prediction for (if not provided, the validation dataset is used) 1482 drop_classes : set, optional 1483 a set of class names to not be displayed 1484 search_classes : set, optional 1485 if given, only intervals where at least one of the classes is in ground truth will be shown 1486 """ 1487 1488 if drop_classes is None: 1489 drop_classes = [] 1490 dataset = self._get_dataset(dataset) 1491 if dataset.annotation_class() == "exclusive_classification": 1492 exclusive = True 1493 elif dataset.annotation_class() == "nonexclusive_classification": 1494 exclusive = False 1495 else: 1496 raise NotImplementedError( 1497 f"Results visualisation is not implemented for {dataset.annotation_class() } datasets!" 1498 ) 1499 if exclusive: 1500 labels = { 1501 k: v 1502 for k, v in dataset.behaviors_dict().items() 1503 if v not in drop_classes 1504 } 1505 labels.update({-100: "unknown"}) 1506 else: 1507 inverse_dict = {v: k for k, v in dataset.behaviors_dict().items()} 1508 if behavior_name is None: 1509 behavior_name = list(inverse_dict.keys())[0] 1510 label_ind = inverse_dict[behavior_name] 1511 labels = {1: label, -100: "unknown"} 1512 label_keys = sorted([int(x) for x in labels.keys()]) 1513 if search_classes is None: 1514 ok = True 1515 else: 1516 ok = False 1517 classes = [] 1518 if whole_video: 1519 predicted = self.generate_full_length_prediction(dataset) 1520 keys = list(predicted.keys()) 1521 counter = 0 1522 if whole_video: 1523 max_iter = len(keys) * 2 1524 else: 1525 max_iter = len(dataset) * 2 1526 while len(classes) < min_classes or not ok: 1527 counter += 1 1528 if counter > max_iter: 1529 raise RuntimeError( 1530 "Plotting is taking too many iterations; you should probably make some of the parameters less restrictive" 1531 ) 1532 if whole_video: 1533 i = randint(0, len(keys) - 1) 1534 prediction = predicted[keys[i]] 1535 keys_i = list(prediction.keys()) 1536 j = randint(0, len(keys_i) - 1) 1537 prediction = prediction[keys_i[j]] 1538 key1 = keys[i] 1539 key2 = keys_i[j] 1540 else: 1541 dataloader = DataLoader(dataset) 1542 i = randint(0, len(dataloader) - 1) 1543 for num, batch in enumerate(dataloader): 1544 if num == i: 1545 break 1546 input = {k: v.to(self.device) for k, v in batch["input"].items()} 1547 prediction, *_ = self._get_prediction( 1548 input, batch.get("tag"), augment_n=5 1549 ) 1550 prediction = self._apply_predict_functions(prediction) 1551 j = randint(0, len(prediction) - 1) 1552 prediction = prediction[j] 1553 if not exclusive: 1554 prediction = prediction[label_ind] 1555 if smooth_interval_prediction > 0: 1556 unsmoothed_prediction = deepcopy(prediction) 1557 prediction = self._smooth(prediction, smooth_interval_prediction) 1558 height = 3 1559 else: 1560 height = 2 1561 classes = [ 1562 labels[int(x)] for x in torch.unique(prediction) if x in label_keys 1563 ] 1564 if search_classes is not None: 1565 ok = any([x in classes for x in search_classes]) 1566 fig, ax = plt.subplots(figsize=(width, height)) 1567 color_list = cm.get_cmap(colormap, lut=len(labels)).colors 1568 1569 def _plot_prediction(prediction, height, set_labels=True): 1570 for c in label_keys: 1571 c_i = label_keys.index(int(c)) 1572 output, indices, counts = torch.unique_consecutive( 1573 prediction == c, return_inverse=True, return_counts=True 1574 ) 1575 long_indices = torch.where(output)[0] 1576 if len(long_indices) == 0: 1577 continue 1578 res_indices_start = [ 1579 (indices == i).nonzero(as_tuple=True)[0][0].item() 1580 for i in long_indices 1581 ] 1582 res_indices_end = [ 1583 (indices == i).nonzero(as_tuple=True)[0][-1].item() + 1 1584 for i in long_indices 1585 ] 1586 res_indices_len = [ 1587 end - start 1588 for start, end in zip(res_indices_start, res_indices_end) 1589 ] 1590 if set_labels: 1591 label = labels[int(c)] 1592 else: 1593 label = None 1594 ax.broken_barh( 1595 list(zip(res_indices_start, res_indices_len)), 1596 (height, 1), 1597 label=label, 1598 facecolors=color_list[c_i], 1599 ) 1600 1601 if smooth_interval_prediction > 0: 1602 _plot_prediction(unsmoothed_prediction, 0) 1603 _plot_prediction(prediction, 1.5, set_labels=False) 1604 gt_height = 3 1605 else: 1606 _plot_prediction(prediction, 0) 1607 gt_height = 1.5 1608 if ground_truth: 1609 if not whole_video: 1610 gt = batch["target"][j].to(self.device) 1611 else: 1612 gt = dataset.generate_full_length_gt()[key1][key2] 1613 for c in label_keys: 1614 c_i = label_keys.index(int(c)) 1615 if labels[int(c)] in classes: 1616 label = None 1617 else: 1618 label = labels[int(c)] 1619 output, indices, counts = torch.unique_consecutive( 1620 gt == c, return_inverse=True, return_counts=True 1621 ) 1622 long_indices = torch.where(output)[0] 1623 if len(long_indices) == 0: 1624 continue 1625 res_indices_start = [ 1626 (indices == i).nonzero(as_tuple=True)[0][0].item() 1627 for i in long_indices 1628 ] 1629 res_indices_end = [ 1630 (indices == i).nonzero(as_tuple=True)[0][-1].item() + 1 1631 for i in long_indices 1632 ] 1633 res_indices_len = [ 1634 end - start 1635 for start, end in zip(res_indices_start, res_indices_end) 1636 ] 1637 ax.broken_barh( 1638 list(zip(res_indices_start, res_indices_len)), 1639 (gt_height, 1), 1640 facecolors=color_list[c_i] if c != "unknown" else "gray", 1641 label=label, 1642 ) 1643 if not ground_truth: 1644 ax.axes.yaxis.set_visible(False) 1645 else: 1646 if smooth_interval_prediction > 0: 1647 ax.set_yticks([0.5, 2, 3.5]) 1648 ax.set_yticklabels(["prediction", "smoothed", "ground truth"]) 1649 else: 1650 ax.set_yticks([0.5, 2]) 1651 ax.set_yticklabels(["prediction", "ground truth"]) 1652 if add_legend: 1653 ax.legend(loc="center left", bbox_to_anchor=(1, 0.5)) 1654 if hide_axes: 1655 plt.axis("off") 1656 if save_path is not None: 1657 plt.savefig(save_path, transparent=transparent) 1658 plt.show()
Visualize random predictions
Parameters
save_path : str, optional
the path where the plot will be saved
add_legend : bool, default True
if True
, legend will be added to the plot
ground_truth : bool, default True
if True
, ground truth will be added to the plot
colormap : str, default 'Accent'
the matplotlib
colormap to use
hide_axes : bool, default True
if True
, the axes will be hidden on the plot
min_classes : int, default 1
the minimum number of classes in a displayed interval
width : float, default 10
the width of the plot
whole_video : bool, default False
if True
, whole videos are plotted instead of segments
transparent : bool, default False
if True
, the background on the plot is transparent
dataset : BehaviorDataset | DataLoader | str | None, optional
the dataset to make the prediction for (if not provided, the validation dataset is used)
drop_classes : set, optional
a set of class names to not be displayed
search_classes : set, optional
if given, only intervals where at least one of the classes is in ground truth will be shown
1665 def set_ssl_losses(self, ssl_losses: list) -> None: 1666 """ 1667 Set SSL losses 1668 1669 Parameters 1670 ---------- 1671 ssl_losses : list 1672 a list of callable SSL losses 1673 """ 1674 1675 self.ssl_losses = ssl_losses
Set SSL losses
Parameters
ssl_losses : list a list of callable SSL losses
1677 def set_log(self, log: str) -> None: 1678 """ 1679 Set the log file 1680 1681 Parameters 1682 ---------- 1683 log: str 1684 the mew log file path 1685 """ 1686 1687 self.log_file = log
Set the log file
Parameters
log: str the mew log file path
1689 def set_keep_target_none(self, keep_target_none: List) -> None: 1690 """ 1691 Set the keep_target_none parameter of the transformer 1692 1693 Parameters 1694 ---------- 1695 keep_target_none : list 1696 a list of bool values 1697 """ 1698 1699 self.transformer.keep_target_none = keep_target_none
Set the keep_target_none parameter of the transformer
Parameters
keep_target_none : list a list of bool values
1701 def set_generate_ssl_input(self, generate_ssl_input: list) -> None: 1702 """ 1703 Set the generate_ssl_input parameter of the transformer 1704 1705 Parameters 1706 ---------- 1707 generate_ssl_input : list 1708 a list of bool values 1709 """ 1710 1711 self.transformer.generate_ssl_input = generate_ssl_input
Set the generate_ssl_input parameter of the transformer
Parameters
generate_ssl_input : list a list of bool values
1713 def set_model(self, model: Model) -> None: 1714 """ 1715 Set a new model 1716 1717 Parameters 1718 ---------- 1719 model: Model 1720 the new model 1721 """ 1722 1723 self.epoch = 0 1724 self.model = model 1725 self.optimizer = self.optimizer_class(model.parameters(), lr=self.lr) 1726 if self.model.process_labels: 1727 self.model.set_behaviors(list(self.behaviors_dict().values()))
Set a new model
Parameters
model: Model the new model
1729 def set_dataloaders( 1730 self, 1731 train_dataloader: DataLoader, 1732 val_dataloader: DataLoader = None, 1733 test_dataloader: DataLoader = None, 1734 ) -> None: 1735 """ 1736 Set new dataloaders 1737 1738 Parameters 1739 ---------- 1740 train_dataloader: torch.utils.data.DataLoader 1741 the new train dataloader 1742 val_dataloader : torch.utils.data.DataLoader 1743 the new validation dataloader 1744 test_dataloader : torch.utils.data.DataLoader 1745 the new test dataloader 1746 """ 1747 1748 self.train_dataloader = train_dataloader 1749 self.val_dataloader = val_dataloader 1750 self.test_dataloader = test_dataloader
Set new dataloaders
Parameters
train_dataloader: torch.utils.data.DataLoader the new train dataloader val_dataloader : torch.utils.data.DataLoader the new validation dataloader test_dataloader : torch.utils.data.DataLoader the new test dataloader
1752 def set_loss(self, loss: Callable) -> None: 1753 """ 1754 Set new loss function 1755 1756 Parameters 1757 ---------- 1758 loss: callable 1759 the new loss function 1760 """ 1761 1762 self.loss = loss
Set new loss function
Parameters
loss: callable the new loss function
1764 def set_metrics(self, metrics: dict) -> None: 1765 """ 1766 Set new metric 1767 1768 Parameters 1769 ---------- 1770 metrics : dict 1771 the new metric dictionary 1772 """ 1773 1774 self.metrics = metrics
Set new metric
Parameters
metrics : dict the new metric dictionary
1776 def set_transformer(self, transformer: Transformer) -> None: 1777 """ 1778 Set a new transformer 1779 1780 Parameters 1781 ---------- 1782 transformer: Transformer 1783 the new transformer 1784 """ 1785 1786 self.transformer = transformer
Set a new transformer
Parameters
transformer: Transformer the new transformer
1788 def set_predict_functions( 1789 self, primary_predict_function: Callable, predict_function: Callable 1790 ) -> None: 1791 """ 1792 Set new predict functions 1793 1794 Parameters 1795 ---------- 1796 primary_predict_function : callable 1797 the new primary predict function 1798 predict_function : callable 1799 the new predict function 1800 """ 1801 1802 self.primary_predict_function = primary_predict_function 1803 self.predict_function = predict_function
Set new predict functions
Parameters
primary_predict_function : callable the new primary predict function predict_function : callable the new predict function
1831 def count_classes(self, bouts: bool = False) -> Dict: 1832 """ 1833 Get a dictionary of class counts in different modes 1834 1835 Parameters 1836 ---------- 1837 bouts : bool, default False 1838 if `True`, instead of frame counts segment counts are returned 1839 1840 Returns 1841 ------- 1842 class_counts : dict 1843 a dictionary where first-level keys are "train", "val" and "test", second-level keys are 1844 class names and values are class counts (in frames) 1845 """ 1846 1847 class_counts = {} 1848 for x in ["train", "val", "test"]: 1849 try: 1850 class_counts[x] = self.dataset(x).count_classes(bouts) 1851 except ValueError: 1852 class_counts[x] = {k: 0 for k in self.behaviors_dict().keys()} 1853 return class_counts
Get a dictionary of class counts in different modes
Parameters
bouts : bool, default False
if True
, instead of frame counts segment counts are returned
Returns
class_counts : dict a dictionary where first-level keys are "train", "val" and "test", second-level keys are class names and values are class counts (in frames)
1855 def behaviors_dict(self) -> Dict: 1856 """ 1857 Get a behavior dictionary 1858 1859 Keys are label indices and values are label names. 1860 1861 Returns 1862 ------- 1863 behaviors_dict : dict 1864 behavior dictionary 1865 """ 1866 1867 return self.dataset().behaviors_dict()
Get a behavior dictionary
Keys are label indices and values are label names.
Returns
behaviors_dict : dict behavior dictionary
1869 def update_parameters(self, parameters: Dict) -> None: 1870 """ 1871 Update training parameters from a dictionary 1872 1873 Parameters 1874 ---------- 1875 parameters : dict 1876 the update dictionary 1877 """ 1878 1879 self.lr = parameters.get("lr", self.lr) 1880 self.parallel = parameters.get("parallel", self.parallel) 1881 self.optimizer = self.optimizer_class(self.model.parameters(), lr=self.lr) 1882 self.verbose = parameters.get("verbose", self.verbose) 1883 self.device = parameters.get("device", self.device) 1884 if self.device == "auto": 1885 self.device = "cuda" if torch.cuda.is_available() else "cpu" 1886 self.augment_train = int(parameters.get("augment_train", self.augment_train)) 1887 self.augment_val = int(parameters.get("augment_val", self.augment_val)) 1888 ssl_weights = parameters.get("ssl_weights", self.ssl_weights) 1889 if ssl_weights is None: 1890 ssl_weights = [1 for _ in self.ssl_losses] 1891 if not isinstance(ssl_weights, Iterable): 1892 ssl_weights = [ssl_weights for _ in self.ssl_losses] 1893 self.ssl_weights = ssl_weights 1894 self.num_epochs = parameters.get("num_epochs", self.num_epochs) 1895 self.model_save_epochs = parameters.get( 1896 "model_save_epochs", self.model_save_epochs 1897 ) 1898 self.model_save_path = parameters.get("model_save_path", self.model_save_path) 1899 self.pseudolabel = parameters.get("pseudolabel", self.pseudolabel) 1900 self.T1 = int(parameters.get("pseudolabel_start", self.T1)) 1901 self.T2 = int(parameters.get("alpha_growth_stop", self.T2)) 1902 self.t = int(parameters.get("correction_interval", self.t)) 1903 self.alpha_f = parameters.get("pseudolabel_alpha_f", self.alpha_f) 1904 self.log_file = parameters.get("log_file", self.log_file)
Update training parameters from a dictionary
Parameters
parameters : dict the update dictionary
1906 def generate_uncertainty_score( 1907 self, 1908 classes: List, 1909 augment_n: int = 0, 1910 batch_size: int = 32, 1911 method: str = "least_confidence", 1912 predicted: torch.Tensor = None, 1913 behaviors_dict: Dict = None, 1914 ) -> Dict: 1915 """ 1916 Generate frame-wise scores for active learning 1917 1918 Parameters 1919 ---------- 1920 classes : list 1921 a list of class names or indices; their confidence scores will be computed separately and stacked 1922 augment_n : int, default 0 1923 the number of augmentations to average over 1924 batch_size : int, default 32 1925 the batch size 1926 method : {"least_confidence", "entropy"} 1927 the method used to calculate the scores from the probability predictions (`"least_confidence"`: `1 - p_i` if 1928 `p_i > 0.5` or `p_i` if `p_i < 0.5`; `"entropy"`: `- p_i * log(p_i) - (1 - p_i) * log(1 - p_i)`) 1929 1930 Returns 1931 ------- 1932 score_dicts : dict 1933 a nested dictionary where first level keys are video ids, second level keys are clip ids and values 1934 are score tensors 1935 """ 1936 1937 dataset = self.dataset("train") 1938 if behaviors_dict is None: 1939 behaviors_dict = self.behaviors_dict() 1940 if not isinstance(dataset, BehaviorDataset): 1941 raise TypeError( 1942 f"The dataset parameter has to be either None, string, " 1943 f"BehaviorDataset or Dataloader, got {type(dataset)}" 1944 ) 1945 if predicted is None: 1946 predicted = self.predict( 1947 dataset, 1948 raw_output=True, 1949 apply_primary_function=True, 1950 augment_n=augment_n, 1951 batch_size=batch_size, 1952 ) 1953 predicted = dataset.generate_full_length_prediction(predicted) 1954 if isinstance(classes[0], str): 1955 behaviors_dict_inverse = {v: k for k, v in behaviors_dict.items()} 1956 classes = [behaviors_dict_inverse[c] for c in classes] 1957 for v_id, v in predicted.items(): 1958 for clip_id, vv in v.items(): 1959 if method == "least_confidence": 1960 predicted[v_id][clip_id][vv > 0.5] = 1 - vv[vv > 0.5] 1961 elif method == "entropy": 1962 predicted[v_id][clip_id][vv != -100] = ( 1963 -vv * torch.log(vv) - (1 - vv) * torch.log(1 - vv) 1964 )[vv != -100] 1965 elif method == "random": 1966 predicted[v_id][clip_id] = torch.rand(vv.shape) 1967 else: 1968 raise ValueError( 1969 f"The {method} method is not recognized; please choose from ['least_confidence', 'entropy']" 1970 ) 1971 predicted[v_id][clip_id][vv == -100] = 0 1972 1973 predicted = { 1974 v_id: {clip_id: v[classes, :] for clip_id, v in video_dict.items()} 1975 for v_id, video_dict in predicted.items() 1976 } 1977 return predicted
Generate frame-wise scores for active learning
Parameters
classes : list
a list of class names or indices; their confidence scores will be computed separately and stacked
augment_n : int, default 0
the number of augmentations to average over
batch_size : int, default 32
the batch size
method : {"least_confidence", "entropy"}
the method used to calculate the scores from the probability predictions ("least_confidence"
: 1 - p_i
if
p_i > 0.5
or p_i
if p_i < 0.5
; "entropy"
: - p_i * log(p_i) - (1 - p_i) * log(1 - p_i)
)
Returns
score_dicts : dict a nested dictionary where first level keys are video ids, second level keys are clip ids and values are score tensors
1979 def generate_bald_score( 1980 self, 1981 classes: List, 1982 augment_n: int = 0, 1983 batch_size: int = 32, 1984 num_models: int = 10, 1985 kernel_size: int = 11, 1986 ) -> Dict: 1987 """ 1988 Generate frame-wise Bayesian Active Learning by Disagreement scores for active learning 1989 1990 Parameters 1991 ---------- 1992 classes : list 1993 a list of class names or indices; their confidence scores will be computed separately and stacked 1994 augment_n : int, default 0 1995 the number of augmentations to average over 1996 batch_size : int, default 32 1997 the batch size 1998 num_models : int, default 10 1999 the number of dropout masks to apply 2000 kernel_size : int, default 11 2001 the size of the smoothing gaussian kernel 2002 2003 Returns 2004 ------- 2005 score_dicts : dict 2006 a nested dictionary where first level keys are video ids, second level keys are clip ids and values 2007 are score tensors 2008 """ 2009 2010 dataset = self.dataset("train") 2011 dataset = self._get_dataset(dataset) 2012 if not isinstance(dataset, BehaviorDataset): 2013 raise TypeError( 2014 f"The dataset parameter has to be either None, string, " 2015 f"BehaviorDataset or Dataloader, got {type(dataset)}" 2016 ) 2017 predictions = [] 2018 for _ in range(num_models): 2019 predicted = self.predict( 2020 dataset, 2021 raw_output=True, 2022 apply_primary_function=True, 2023 augment_n=augment_n, 2024 batch_size=batch_size, 2025 train_mode=True, 2026 ) 2027 predicted = dataset.generate_full_length_prediction(predicted) 2028 if isinstance(classes[0], str): 2029 behaviors_dict_inverse = { 2030 v: k for k, v in self.behaviors_dict().items() 2031 } 2032 classes = [behaviors_dict_inverse[c] for c in classes] 2033 for v_id, v in predicted.items(): 2034 for clip_id, vv in v.items(): 2035 vv[vv != -100] = (vv[vv != -100] > 0.5).int().float() 2036 predicted[v_id][clip_id] = vv 2037 predicted = { 2038 v_id: {clip_id: v[classes, :] for clip_id, v in video_dict.items()} 2039 for v_id, video_dict in predicted.items() 2040 } 2041 predictions.append(predicted) 2042 result = {v_id: {} for v_id in predictions[0]} 2043 r = range(-int(kernel_size / 2), int(kernel_size / 2) + 1) 2044 gauss = [1 / (1 * sqrt(2 * pi)) * exp(-float(x) ** 2 / (2 * 1**2)) for x in r] 2045 gauss = [x / sum(gauss) for x in gauss] 2046 kernel = torch.FloatTensor([[gauss]]) 2047 for v_id in predictions[0]: 2048 for clip_id in predictions[0][v_id]: 2049 consensus = ( 2050 ( 2051 torch.mean( 2052 torch.stack([x[v_id][clip_id] for x in predictions]), dim=0 2053 ) 2054 > 0.5 2055 ) 2056 .int() 2057 .float() 2058 ) 2059 consensus[predictions[0][v_id][clip_id] == -100] = -100 2060 result[v_id][clip_id] = torch.zeros(consensus.shape) 2061 for x in predictions: 2062 result[v_id][clip_id] += (x[v_id][clip_id] != consensus).int() 2063 result[v_id][clip_id] = result[v_id][clip_id] * 2 / num_models 2064 res = torch.zeros(result[v_id][clip_id].shape) 2065 for i in range(len(classes)): 2066 res[ 2067 i, floor(kernel_size // 2) : -floor(kernel_size // 2) 2068 ] = torch.nn.functional.conv1d( 2069 result[v_id][clip_id][i, :].unsqueeze(0).unsqueeze(0), kernel 2070 )[ 2071 0, ... 2072 ] 2073 result[v_id][clip_id] = res 2074 return result
Generate frame-wise Bayesian Active Learning by Disagreement scores for active learning
Parameters
classes : list a list of class names or indices; their confidence scores will be computed separately and stacked augment_n : int, default 0 the number of augmentations to average over batch_size : int, default 32 the batch size num_models : int, default 10 the number of dropout masks to apply kernel_size : int, default 11 the size of the smoothing gaussian kernel
Returns
score_dicts : dict a nested dictionary where first level keys are video ids, second level keys are clip ids and values are score tensors