dlc2action.model.motionbert

  1#
  2# Copyright 2020-present by A. Mathis Group and contributors. All rights reserved.
  3#
  4# This project and all its files are licensed under GNU AGPLv3 or later version. A copy is included in dlc2action/LICENSE.AGPL.
  5#
  6# Incorporates code adapted from MotionBERT by Walter0807
  7# Original work Copyright (c) 2023 Walter0807
  8# Source: https://github.com/Walter0807/MotionBERT
  9# Originally licensed under Apache License Version 2.0, 2023
 10# Combined work licensed under GNU AGPLv3
 11#
 12import torch
 13import torch.nn as nn
 14import torch.nn.functional as F
 15from dlc2action.model.base_model import Model
 16from einops import rearrange
 17from dlc2action.model.motionbert_modules import DSTformer
 18from functools import partial
 19
 20
 21def load_backbone(args):
 22    model_backbone = DSTformer(dim_in=args['dim'], dim_out=args["dim"], dim_feat=args["dim_feat"], dim_rep=args["dim_rep"],
 23                                   depth=args["depth"], num_heads=args["num_heads"], mlp_ratio=args["mlp_ratio"], norm_layer=partial(nn.LayerNorm, eps=1e-6),
 24                                   maxlen=args["maxlen"], num_joints=args["num_joints"])
 25    return model_backbone
 26
 27
 28class ActionHeadClassification(nn.Module):
 29    def __init__(self, dropout_ratio=0., dim_rep=512, num_classes=60, num_joints=17, hidden_dim=2048):
 30        super(ActionHeadClassification, self).__init__()
 31        self.dropout = nn.Dropout(p=dropout_ratio)
 32        self.bn = nn.BatchNorm1d(hidden_dim, momentum=0.1)
 33        self.relu = nn.ReLU(inplace=True)
 34        self.fc1 = nn.Linear(dim_rep*num_joints, hidden_dim)
 35        self.fc2 = nn.Linear(hidden_dim, num_classes)
 36
 37    def forward(self, feat):
 38        '''
 39            Input: (N, M, T, J, C)
 40        '''
 41        N, M, T, J, C = feat.shape
 42        feat = self.dropout(feat)
 43        feat = feat.permute(0, 1, 3, 4, 2)      # (N, M, T, J, C) -> (N, M, J, C, T)
 44        feat = feat.mean(dim=-1)
 45        feat = feat.reshape(N, M, -1)           # (N, M, J*C)
 46        feat = feat.mean(dim=1)
 47        feat = self.fc1(feat)
 48        feat = self.bn(feat)
 49        feat = self.relu(feat)
 50        feat = self.fc2(feat)
 51        return feat
 52
 53class ActionHeadSegmentation(nn.Module):
 54    def __init__(self, input_dim, dropout_ratio=0., num_classes=60, hidden_dim=128):
 55        super(ActionHeadSegmentation, self).__init__()
 56        self.dropout = nn.Dropout(p=dropout_ratio)
 57        self.bn = nn.BatchNorm1d(hidden_dim, momentum=0.1)
 58        self.relu = nn.ReLU(inplace=True)
 59        self.fc1 = nn.Linear(input_dim, hidden_dim)
 60        self.fc2 = nn.Linear(hidden_dim, num_classes)
 61
 62    def forward(self, feat):
 63        '''
 64            Input: (N, M, T, F)
 65        '''
 66        # print('FEAT before', feat.shape)
 67        feat = self.dropout(feat)
 68        feat = self.fc1(feat)
 69        feat = rearrange(feat, "n t f -> n f t")
 70        feat = self.bn(feat)
 71        feat = rearrange(feat, "n f t -> n t f")
 72        feat = self.relu(feat)
 73        feat = self.fc2(feat)
 74        out = rearrange(feat, "n t f -> n f t")
 75        # print('FEAT after -out', feat.shape)
 76        return out
 77
 78class ActionHeadEmbed(nn.Module):
 79    def __init__(self, dropout_ratio=0., dim_rep=512, num_joints=17, hidden_dim=2048):
 80        super(ActionHeadEmbed, self).__init__()
 81        self.dropout = nn.Dropout(p=dropout_ratio)
 82        self.fc1 = nn.Linear(dim_rep*num_joints, hidden_dim)
 83    def forward(self, feat):
 84        '''
 85            Input: (N, M, T, J, C)
 86        '''
 87        N, M, T, J, C = feat.shape
 88        feat = self.dropout(feat)
 89        feat = feat.permute(0, 1, 3, 4, 2)      # (N, M, T, J, C) -> (N, M, J, C, T)
 90        feat = feat.mean(dim=-1)
 91        feat = feat.reshape(N, M, -1)           # (N, M, J*C)
 92        feat = feat.mean(dim=1)
 93        feat = self.fc1(feat)
 94        feat = F.normalize(feat, dim=-1)
 95        return feat
 96
 97class ActionNet(nn.Module):
 98    def __init__(self, backbone, channels):
 99        super(ActionNet, self).__init__()
100        self.backbone = backbone
101        self.channels = channels
102        # if version=='class':
103        #     self.head = ActionHeadClassification(dropout_ratio=dropout_ratio, dim_rep=dim_rep, num_classes=num_classes, num_joints=num_joints)
104        # elif version=='embed':
105        #     self.head = ActionHeadEmbed(dropout_ratio=dropout_ratio, dim_rep=dim_rep, hidden_dim=hidden_dim, num_joints=num_joints)
106        # else:
107        #     raise Exception('Version Error.')
108
109    def forward(self, x):
110        '''
111            Input: (N, M x T x J x 3)
112        '''
113        #print('BEFORE', x.shape)
114        x = rearrange(x, 'n (j c) t -> n t j c', c=self.channels)
115        # print('AFTER', x.shape)
116        # N, T, J, C = x.shape
117        feat = self.backbone.get_representation(x)
118        # feat = feat.reshape([N, 1, T, self.feat_J, -1])      # (N, M, T, J, C)
119        # out = self.head(feat)
120        feat = rearrange(feat, "n t j c -> n t (j c)")
121        return feat
122
123
class MotionBERT(Model):
    """
    An implementation of MotionBERT

    A DSTformer backbone wrapped in `ActionNet` serves as the feature
    extractor and `ActionHeadSegmentation` as the predictor.
    """

    def __init__(
        self,
        dim_feat,
        dim_rep,
        depth,
        num_heads,
        mlp_ratio,
        len_segment,
        num_joints,
        num_classes,
        input_dims,
        state_dict_path=None,
        ssl_constructors=None,
        ssl_types=None,
        ssl_modules=None,
    ):
        """
        Initialize the model.

        Parameters
        ----------
        dim_feat, dim_rep, depth, num_heads, mlp_ratio
            backbone hyperparameters; each is cast to `int` and forwarded to
            `load_backbone` (the string `"dim_feat"` for `dim_rep` means
            "use the value of `dim_feat`")
        len_segment
            cast to `int` and forwarded to the backbone as `maxlen`
        num_joints
            number of joints; must be positive and divide the total input
            dimension
        num_classes
            number of output classes of the prediction head
        input_dims : dict
            mapping whose values are indexable; the first element of every
            value is summed to obtain the total input dimension
        state_dict_path, ssl_constructors, ssl_types, ssl_modules
            forwarded unchanged to the base `Model` initializer

        Raises
        ------
        ValueError
            if `num_joints` is not positive or does not divide the total
            input dimension
        """
        if dim_rep == "dim_feat":
            dim_rep = dim_feat
        input_dims = int(sum(s[0] for s in input_dims.values()))
        print('input_dims', input_dims)
        print('num_joints', num_joints)
        num_joints = int(num_joints)
        # Explicit check instead of `assert`, which is stripped under `python -O`
        # and would let a wrong per-joint channel count through silently.
        if num_joints <= 0 or input_dims % num_joints != 0:
            raise ValueError(
                f"input_dims ({input_dims}) must be divisible by a positive "
                f"num_joints ({num_joints})"
            )
        channels = input_dims // num_joints
        args = {
            "dim_feat": int(dim_feat),
            "dim_rep": int(dim_rep),
            "depth": int(depth),
            "num_heads": int(num_heads),
            "mlp_ratio": int(mlp_ratio),
            "maxlen": int(len_segment),
            "num_joints": num_joints,
            "dim": channels,
        }
        # Size of the flattened feature vector per time step (joints * dim_rep).
        self.f_shape = args["dim_rep"] * args["num_joints"]
        self.params = {
            "backbone": load_backbone(args),
            "channels": channels,
        }
        self.num_classes = num_classes
        # NOTE(review): the attributes above are set before the base initializer
        # runs, presumably because it calls `_feature_extractor` / `_predictor`
        # -- confirm against `Model.__init__`.
        super().__init__(ssl_constructors, ssl_modules, ssl_types, state_dict_path)

    def _feature_extractor(self):
        """Create the `ActionNet` feature extractor around the stored backbone."""
        return ActionNet(**self.params)

    def _predictor(self) -> torch.nn.Module:
        """Create the segmentation head operating on the flattened features."""
        return ActionHeadSegmentation(dropout_ratio=0.1, input_dim=self.f_shape, num_classes=self.num_classes)

    def features_shape(self) -> torch.Size:
        """Return the size of the feature extractor output per time step."""
        # NOTE(review): this returns the plain int stored in `self.f_shape`,
        # not a `torch.Size` as annotated -- confirm what callers expect.
        return self.f_shape
def load_backbone(args):
22def load_backbone(args):
23    model_backbone = DSTformer(dim_in=args['dim'], dim_out=args["dim"], dim_feat=args["dim_feat"], dim_rep=args["dim_rep"],
24                                   depth=args["depth"], num_heads=args["num_heads"], mlp_ratio=args["mlp_ratio"], norm_layer=partial(nn.LayerNorm, eps=1e-6),
25                                   maxlen=args["maxlen"], num_joints=args["num_joints"])
26    return model_backbone
class ActionHeadClassification(torch.nn.modules.module.Module):
29class ActionHeadClassification(nn.Module):
30    def __init__(self, dropout_ratio=0., dim_rep=512, num_classes=60, num_joints=17, hidden_dim=2048):
31        super(ActionHeadClassification, self).__init__()
32        self.dropout = nn.Dropout(p=dropout_ratio)
33        self.bn = nn.BatchNorm1d(hidden_dim, momentum=0.1)
34        self.relu = nn.ReLU(inplace=True)
35        self.fc1 = nn.Linear(dim_rep*num_joints, hidden_dim)
36        self.fc2 = nn.Linear(hidden_dim, num_classes)
37
38    def forward(self, feat):
39        '''
40            Input: (N, M, T, J, C)
41        '''
42        N, M, T, J, C = feat.shape
43        feat = self.dropout(feat)
44        feat = feat.permute(0, 1, 3, 4, 2)      # (N, M, T, J, C) -> (N, M, J, C, T)
45        feat = feat.mean(dim=-1)
46        feat = feat.reshape(N, M, -1)           # (N, M, J*C)
47        feat = feat.mean(dim=1)
48        feat = self.fc1(feat)
49        feat = self.bn(feat)
50        feat = self.relu(feat)
51        feat = self.fc2(feat)
52        return feat

Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::

import torch.nn as nn
import torch.nn.functional as F


class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will also have their parameters converted when you call to(), etc.

As per the example above, an __init__() call to the parent class must be made before assignment on the child.

:ivar training: Boolean represents whether this module is in training or evaluation mode. :vartype training: bool

ActionHeadClassification( dropout_ratio=0.0, dim_rep=512, num_classes=60, num_joints=17, hidden_dim=2048)
30    def __init__(self, dropout_ratio=0., dim_rep=512, num_classes=60, num_joints=17, hidden_dim=2048):
31        super(ActionHeadClassification, self).__init__()
32        self.dropout = nn.Dropout(p=dropout_ratio)
33        self.bn = nn.BatchNorm1d(hidden_dim, momentum=0.1)
34        self.relu = nn.ReLU(inplace=True)
35        self.fc1 = nn.Linear(dim_rep*num_joints, hidden_dim)
36        self.fc2 = nn.Linear(hidden_dim, num_classes)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

dropout
bn
relu
fc1
fc2
def forward(self, feat):
38    def forward(self, feat):
39        '''
40            Input: (N, M, T, J, C)
41        '''
42        N, M, T, J, C = feat.shape
43        feat = self.dropout(feat)
44        feat = feat.permute(0, 1, 3, 4, 2)      # (N, M, T, J, C) -> (N, M, J, C, T)
45        feat = feat.mean(dim=-1)
46        feat = feat.reshape(N, M, -1)           # (N, M, J*C)
47        feat = feat.mean(dim=1)
48        feat = self.fc1(feat)
49        feat = self.bn(feat)
50        feat = self.relu(feat)
51        feat = self.fc2(feat)
52        return feat

Input: (N, M, T, J, C)

class ActionHeadSegmentation(torch.nn.modules.module.Module):
54class ActionHeadSegmentation(nn.Module):
55    def __init__(self, input_dim, dropout_ratio=0., num_classes=60, hidden_dim=128):
56        super(ActionHeadSegmentation, self).__init__()
57        self.dropout = nn.Dropout(p=dropout_ratio)
58        self.bn = nn.BatchNorm1d(hidden_dim, momentum=0.1)
59        self.relu = nn.ReLU(inplace=True)
60        self.fc1 = nn.Linear(input_dim, hidden_dim)
61        self.fc2 = nn.Linear(hidden_dim, num_classes)
62
63    def forward(self, feat):
64        '''
65            Input: (N, M, T, F)
66        '''
67        # print('FEAT before', feat.shape)
68        feat = self.dropout(feat)
69        feat = self.fc1(feat)
70        feat = rearrange(feat, "n t f -> n f t")
71        feat = self.bn(feat)
72        feat = rearrange(feat, "n f t -> n t f")
73        feat = self.relu(feat)
74        feat = self.fc2(feat)
75        out = rearrange(feat, "n t f -> n f t")
76        # print('FEAT after -out', feat.shape)
77        return out

Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::

import torch.nn as nn
import torch.nn.functional as F


class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will also have their parameters converted when you call to(), etc.

As per the example above, an __init__() call to the parent class must be made before assignment on the child.

:ivar training: Boolean represents whether this module is in training or evaluation mode. :vartype training: bool

ActionHeadSegmentation(input_dim, dropout_ratio=0.0, num_classes=60, hidden_dim=128)
55    def __init__(self, input_dim, dropout_ratio=0., num_classes=60, hidden_dim=128):
56        super(ActionHeadSegmentation, self).__init__()
57        self.dropout = nn.Dropout(p=dropout_ratio)
58        self.bn = nn.BatchNorm1d(hidden_dim, momentum=0.1)
59        self.relu = nn.ReLU(inplace=True)
60        self.fc1 = nn.Linear(input_dim, hidden_dim)
61        self.fc2 = nn.Linear(hidden_dim, num_classes)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

dropout
bn
relu
fc1
fc2
def forward(self, feat):
63    def forward(self, feat):
64        '''
65            Input: (N, M, T, F)
66        '''
67        # print('FEAT before', feat.shape)
68        feat = self.dropout(feat)
69        feat = self.fc1(feat)
70        feat = rearrange(feat, "n t f -> n f t")
71        feat = self.bn(feat)
72        feat = rearrange(feat, "n f t -> n t f")
73        feat = self.relu(feat)
74        feat = self.fc2(feat)
75        out = rearrange(feat, "n t f -> n f t")
76        # print('FEAT after -out', feat.shape)
77        return out

Input: (N, T, F)

class ActionHeadEmbed(torch.nn.modules.module.Module):
79class ActionHeadEmbed(nn.Module):
80    def __init__(self, dropout_ratio=0., dim_rep=512, num_joints=17, hidden_dim=2048):
81        super(ActionHeadEmbed, self).__init__()
82        self.dropout = nn.Dropout(p=dropout_ratio)
83        self.fc1 = nn.Linear(dim_rep*num_joints, hidden_dim)
84    def forward(self, feat):
85        '''
86            Input: (N, M, T, J, C)
87        '''
88        N, M, T, J, C = feat.shape
89        feat = self.dropout(feat)
90        feat = feat.permute(0, 1, 3, 4, 2)      # (N, M, T, J, C) -> (N, M, J, C, T)
91        feat = feat.mean(dim=-1)
92        feat = feat.reshape(N, M, -1)           # (N, M, J*C)
93        feat = feat.mean(dim=1)
94        feat = self.fc1(feat)
95        feat = F.normalize(feat, dim=-1)
96        return feat

Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::

import torch.nn as nn
import torch.nn.functional as F


class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will also have their parameters converted when you call to(), etc.

As per the example above, an __init__() call to the parent class must be made before assignment on the child.

:ivar training: Boolean represents whether this module is in training or evaluation mode. :vartype training: bool

ActionHeadEmbed(dropout_ratio=0.0, dim_rep=512, num_joints=17, hidden_dim=2048)
80    def __init__(self, dropout_ratio=0., dim_rep=512, num_joints=17, hidden_dim=2048):
81        super(ActionHeadEmbed, self).__init__()
82        self.dropout = nn.Dropout(p=dropout_ratio)
83        self.fc1 = nn.Linear(dim_rep*num_joints, hidden_dim)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

dropout
fc1
def forward(self, feat):
84    def forward(self, feat):
85        '''
86            Input: (N, M, T, J, C)
87        '''
88        N, M, T, J, C = feat.shape
89        feat = self.dropout(feat)
90        feat = feat.permute(0, 1, 3, 4, 2)      # (N, M, T, J, C) -> (N, M, J, C, T)
91        feat = feat.mean(dim=-1)
92        feat = feat.reshape(N, M, -1)           # (N, M, J*C)
93        feat = feat.mean(dim=1)
94        feat = self.fc1(feat)
95        feat = F.normalize(feat, dim=-1)
96        return feat

Input: (N, M, T, J, C)

class ActionNet(torch.nn.modules.module.Module):
 98class ActionNet(nn.Module):
 99    def __init__(self, backbone, channels):
100        super(ActionNet, self).__init__()
101        self.backbone = backbone
102        self.channels = channels
103        # if version=='class':
104        #     self.head = ActionHeadClassification(dropout_ratio=dropout_ratio, dim_rep=dim_rep, num_classes=num_classes, num_joints=num_joints)
105        # elif version=='embed':
106        #     self.head = ActionHeadEmbed(dropout_ratio=dropout_ratio, dim_rep=dim_rep, hidden_dim=hidden_dim, num_joints=num_joints)
107        # else:
108        #     raise Exception('Version Error.')
109
110    def forward(self, x):
111        '''
112            Input: (N, M x T x J x 3)
113        '''
114        #print('BEFORE', x.shape)
115        x = rearrange(x, 'n (j c) t -> n t j c', c=self.channels)
116        # print('AFTER', x.shape)
117        # N, T, J, C = x.shape
118        feat = self.backbone.get_representation(x)
119        # feat = feat.reshape([N, 1, T, self.feat_J, -1])      # (N, M, T, J, C)
120        # out = self.head(feat)
121        feat = rearrange(feat, "n t j c -> n t (j c)")
122        return feat

Base class for all neural network modules.

Your models should also subclass this class.

Modules can also contain other Modules, allowing them to be nested in a tree structure. You can assign the submodules as regular attributes::

import torch.nn as nn
import torch.nn.functional as F


class Model(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

Submodules assigned in this way will be registered, and will also have their parameters converted when you call to(), etc.

As per the example above, an __init__() call to the parent class must be made before assignment on the child.

:ivar training: Boolean represents whether this module is in training or evaluation mode. :vartype training: bool

ActionNet(backbone, channels)
 99    def __init__(self, backbone, channels):
100        super(ActionNet, self).__init__()
101        self.backbone = backbone
102        self.channels = channels
103        # if version=='class':
104        #     self.head = ActionHeadClassification(dropout_ratio=dropout_ratio, dim_rep=dim_rep, num_classes=num_classes, num_joints=num_joints)
105        # elif version=='embed':
106        #     self.head = ActionHeadEmbed(dropout_ratio=dropout_ratio, dim_rep=dim_rep, hidden_dim=hidden_dim, num_joints=num_joints)
107        # else:
108        #     raise Exception('Version Error.')

Initialize internal Module state, shared by both nn.Module and ScriptModule.

backbone
channels
def forward(self, x):
110    def forward(self, x):
111        '''
112            Input: (N, M x T x J x 3)
113        '''
114        #print('BEFORE', x.shape)
115        x = rearrange(x, 'n (j c) t -> n t j c', c=self.channels)
116        # print('AFTER', x.shape)
117        # N, T, J, C = x.shape
118        feat = self.backbone.get_representation(x)
119        # feat = feat.reshape([N, 1, T, self.feat_J, -1])      # (N, M, T, J, C)
120        # out = self.head(feat)
121        feat = rearrange(feat, "n t j c -> n t (j c)")
122        return feat

Input: (N, J x C, T), where C is the number of channels per joint

class MotionBERT(dlc2action.model.base_model.Model):
125class MotionBERT(Model):
126    """
127    An implementation of MotionBERT
128    """
129
130    def __init__(
131        self,
132        dim_feat,
133        dim_rep,
134        depth,
135        num_heads,
136        mlp_ratio,
137        len_segment,
138        num_joints,
139        num_classes,
140        input_dims,
141        state_dict_path=None,
142        ssl_constructors=None,
143        ssl_types=None,
144        ssl_modules=None,
145    ):
146        if dim_rep == "dim_feat":
147            dim_rep = dim_feat
148        input_dims = int(sum([s[0] for s in input_dims.values()]))
149        print('input_dims', input_dims)
150        print('num_joints', num_joints)
151        assert input_dims % num_joints == 0
152        args = {
153            "dim_feat": int(dim_feat),
154            "dim_rep": int(dim_rep),
155            "depth": int(depth),
156            "num_heads": int(num_heads),
157            "mlp_ratio": int(mlp_ratio),
158            "maxlen": int(len_segment),
159            "num_joints": int(num_joints),
160            "dim": int(input_dims // num_joints),
161        }
162        self.f_shape = args["dim_rep"] * args["num_joints"]
163        self.params = {
164            "backbone": load_backbone(args),
165            "channels": int(input_dims // num_joints),
166        }
167        self.num_classes = num_classes
168        super().__init__(ssl_constructors, ssl_modules, ssl_types, state_dict_path)
169
170    def _feature_extractor(self):
171        return ActionNet(**self.params)
172
173    def _predictor(self) -> torch.nn.Module:
174        #return ActionHeadSegmentation(dropout_ratio=0.5, input_dim=self.f_shape, num_classes=4)
175        return ActionHeadSegmentation(dropout_ratio=0.1, input_dim=self.f_shape, num_classes=self.num_classes)
176
177    def features_shape(self) -> torch.Size:
178        return self.f_shape

An implementation of MotionBERT

MotionBERT( dim_feat, dim_rep, depth, num_heads, mlp_ratio, len_segment, num_joints, num_classes, input_dims, state_dict_path=None, ssl_constructors=None, ssl_types=None, ssl_modules=None)
130    def __init__(
131        self,
132        dim_feat,
133        dim_rep,
134        depth,
135        num_heads,
136        mlp_ratio,
137        len_segment,
138        num_joints,
139        num_classes,
140        input_dims,
141        state_dict_path=None,
142        ssl_constructors=None,
143        ssl_types=None,
144        ssl_modules=None,
145    ):
146        if dim_rep == "dim_feat":
147            dim_rep = dim_feat
148        input_dims = int(sum([s[0] for s in input_dims.values()]))
149        print('input_dims', input_dims)
150        print('num_joints', num_joints)
151        assert input_dims % num_joints == 0
152        args = {
153            "dim_feat": int(dim_feat),
154            "dim_rep": int(dim_rep),
155            "depth": int(depth),
156            "num_heads": int(num_heads),
157            "mlp_ratio": int(mlp_ratio),
158            "maxlen": int(len_segment),
159            "num_joints": int(num_joints),
160            "dim": int(input_dims // num_joints),
161        }
162        self.f_shape = args["dim_rep"] * args["num_joints"]
163        self.params = {
164            "backbone": load_backbone(args),
165            "channels": int(input_dims // num_joints),
166        }
167        self.num_classes = num_classes
168        super().__init__(ssl_constructors, ssl_modules, ssl_types, state_dict_path)

Initialize the model.

Parameters

ssl_constructors : list, optional — a list of SSL constructors that build the necessary SSL modules; ssl_modules : list, optional — a list of torch.nn.Module instances that will serve as SSL modules; ssl_types : list, optional — a list of string SSL types; state_dict_path : str, optional — path to the model state dictionary to load; strict : bool, default False — when True, the state dictionary will only be loaded if the current and the loaded architecture are the same; otherwise missing or extra keys, as well as shape inconsistencies, are ignored; prompt_function : callable, optional — a function that takes a list of strings and returns a string prompt

f_shape
params
num_classes
def features_shape(self) -> torch.Size:
177    def features_shape(self) -> torch.Size:
178        return self.f_shape

Get the shape of feature extractor output.

Returns

feature_shape : torch.Size shape of feature extractor output