sams-tom committed on
Commit
a7b33df
·
verified ·
1 Parent(s): f3a7eb0

Add custom model definitions (model_definitions.py)

Browse files
Files changed (1) hide show
  1. model_definitions.py +137 -0
model_definitions.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torchvision.models import resnet50, ResNet50_Weights
4
+ import torch.nn.functional as F
5
+ from huggingface_hub import PyTorchModelHubMixin # Import the mixin
6
+
7
+ # --- Custom Model Definitions ---
8
+
9
class Identity(nn.Module):
    """Pass-through module: whatever tensor goes in comes straight back out."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` itself, without copying or transforming it."""
        return x
12
+
13
class AdditiveAttention(nn.Module):
    """Gate a projected input with softmax weights derived from an additive score.

    Despite the parameter name, there is a single input: the "query", "key"
    and "value" are three separate linear projections of that same tensor.

    Args:
        d_model: Size of the last dimension of the incoming tensor.
        hidden_dim: Output size of every projection, and therefore of the
            tensor returned by ``forward``.
    """

    def __init__(self, d_model: int, hidden_dim: int = 128):
        super().__init__()
        self.query_projection = nn.Linear(d_model, hidden_dim)
        self.key_projection = nn.Linear(d_model, hidden_dim)
        self.value_projection = nn.Linear(d_model, hidden_dim)
        # Maps the tanh-combined projections to one score per hidden unit.
        self.attention_mechanism = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, query: torch.Tensor) -> torch.Tensor:
        """Return the value projection scaled element-wise by the attention weights.

        The weights are a softmax taken over ``dim=1``.
        NOTE(review): callers in this file appear to pass 2-D
        ``(batch, d_model)`` tensors, in which case dim 1 is the hidden
        axis - confirm before feeding inputs of a different rank.
        """
        projected_q = self.query_projection(query)
        projected_k = self.key_projection(query)
        projected_v = self.value_projection(query)

        scores = self.attention_mechanism(torch.tanh(projected_q + projected_k))
        gates = F.softmax(scores, dim=1)

        # Element-wise gating, not a weighted sum over positions.
        return projected_v * gates
31
+
32
class ResNet50Custom(nn.Module, PyTorchModelHubMixin):
    """ImageNet-pretrained ResNet-50 with a replaced input stem and classifier head.

    Args:
        input_channels: Number of channels of the input images; the first
            convolution is rebuilt to accept this many channels.
        num_classes: Output size of the replacement fully connected layer.
        **kwargs: Extra entries that are stored verbatim in ``self.config``.
    """

    def __init__(self, input_channels: int, num_classes: int, **kwargs):
        super().__init__()

        # Kept as a plain dict so the constructor arguments can be serialised
        # alongside the weights (e.g. to config.json by the hub mixin).
        self.config = {
            "input_channels": input_channels,
            "num_classes": num_classes,
            **kwargs,
        }

        self.input_channels = input_channels

        self.model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)

        # The stem is a brand-new layer, so its weights are freshly
        # initialised rather than taken from the pretrained checkpoint.
        self.model.conv1 = nn.Conv2d(
            input_channels, 64, kernel_size=7, stride=2, padding=3, bias=False
        )

        # Record the backbone's feature width once, while ``fc`` is still a
        # Linear layer.  Reading ``self.model.fc.in_features`` lazily would
        # raise AttributeError whenever ``fc`` has been swapped for a module
        # without that attribute (such as Identity).
        self._feature_size = self.model.fc.in_features

        # Classification head used when this model runs as a standalone
        # classifier.
        self.model.fc = nn.Linear(self._feature_size, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the wrapped ResNet-50 on ``x`` and return its output."""
        return self.model(x)

    def get_feature_size(self) -> int:
        """Return the width of the features fed into the final ``fc`` layer."""
        return self._feature_size
58
+
59
+
60
class MultiModalModel(nn.Module, PyTorchModelHubMixin):
    """Three-branch classifier fusing image, bathymetry and side-scan-sonar inputs.

    Each modality goes through its own ``ResNet50Custom`` backbone and its own
    ``AdditiveAttention`` block; the three attended vectors are concatenated
    and passed through three linear layers to produce class scores.

    Args:
        image_input_channels: Channel count of the image input.
        bathy_input_channels: Channel count of the bathymetry input.
        sss_input_channels: Channel count of the side-scan-sonar input.
        num_classes: Number of output classes; cast with ``int()``.
        attention_type: Stored in ``self.config`` and on the instance only.
            It is not consulted anywhere in this class - the attention blocks
            are always ``AdditiveAttention`` regardless of its value.
        **kwargs: Extra entries stored verbatim in ``self.config``.
    """

    def __init__(self,
                 image_input_channels: int,
                 bathy_input_channels: int,
                 sss_input_channels: int,
                 num_classes: int,
                 attention_type: str = "scaled_dot_product",
                 **kwargs):
        super().__init__()

        # Kept as a plain dict so the constructor arguments can be serialised
        # alongside the weights (e.g. to config.json by the hub mixin).
        self.config = {
            "image_input_channels": image_input_channels,
            "bathy_input_channels": bathy_input_channels,
            "sss_input_channels": sss_input_channels,
            "num_classes": num_classes,
            "attention_type": attention_type,
            **kwargs,
        }

        # One backbone per modality.  Their own ``fc`` heads stay registered
        # (so state-dict keys are unchanged) but are skipped in ``forward``.
        self.image_model_feat = ResNet50Custom(input_channels=image_input_channels, num_classes=num_classes)
        self.bathy_model_feat = ResNet50Custom(input_channels=bathy_input_channels, num_classes=num_classes)
        self.sss_model_feat = ResNet50Custom(input_channels=sss_input_channels, num_classes=num_classes)

        # Width of the backbone features that feed each attention block.
        feature_dim = self.image_model_feat.get_feature_size()

        # Each attention block outputs ``attention_hidden_dim`` features, so
        # the concatenation of the three branches is 3 * 128 = 384 wide.
        attention_hidden_dim = 128
        self.attention_image = AdditiveAttention(feature_dim, hidden_dim=attention_hidden_dim)
        self.attention_bathy = AdditiveAttention(feature_dim, hidden_dim=attention_hidden_dim)
        self.attention_sss = AdditiveAttention(feature_dim, hidden_dim=attention_hidden_dim)

        # Classification head.  Layer sizes are left exactly as they were so
        # previously saved weights still load.
        self.fc = nn.Linear(3 * attention_hidden_dim, 1284)
        self.fc1 = nn.Linear(1284, 32)
        # ``int()`` either returns an int or raises, so no further type check
        # is needed after the cast.
        self.fc2 = nn.Linear(32, int(num_classes))
        self.attention_type = attention_type

    @staticmethod
    def _extract_features(extractor: nn.Module, x: torch.Tensor) -> torch.Tensor:
        """Run ``extractor``'s ResNet backbone up to, but excluding, its ``fc`` head.

        This mirrors the stage order of torchvision's ResNet forward pass and
        stops after the flatten.  It yields the same tensor as swapping ``fc``
        for an identity module, without mutating the module in the middle of a
        forward pass - so an exception cannot leave the extractor without its
        head, and concurrent or re-entrant calls cannot observe a
        half-swapped model.
        """
        backbone = extractor.model
        x = backbone.conv1(x)
        x = backbone.bn1(x)
        x = backbone.relu(x)
        x = backbone.maxpool(x)
        x = backbone.layer1(x)
        x = backbone.layer2(x)
        x = backbone.layer3(x)
        x = backbone.layer4(x)
        x = backbone.avgpool(x)
        return torch.flatten(x, 1)

    def forward(self, inputs: torch.Tensor, bathy_tensor: torch.Tensor, sss_image: torch.Tensor) -> torch.Tensor:
        """Classify one batch from its three modality tensors.

        Args:
            inputs: Image batch for the image branch.
            bathy_tensor: Bathymetry batch for the bathymetry branch.
            sss_image: Side-scan-sonar batch for the sonar branch.

        Returns:
            The output of the final linear layer (``fc2``), one row per sample.
        """
        # Per-modality backbone features (pre-``fc``).
        image_features = self._extract_features(self.image_model_feat, inputs)
        bathy_features = self._extract_features(self.bathy_model_feat, bathy_tensor)
        sss_features = self._extract_features(self.sss_model_feat, sss_image)

        # Per-modality attention.
        image_features_attended = self.attention_image(image_features)
        bathy_features_attended = self.attention_bathy(bathy_features)
        sss_features_attended = self.attention_sss(sss_features)

        # Fuse the three branches along the feature axis.
        combined_features = torch.cat(
            [image_features_attended, bathy_features_attended, sss_features_attended],
            dim=1,
        )

        # The three linear layers are applied back to back, with no
        # activation in between.
        hidden = self.fc(combined_features)
        hidden = self.fc1(hidden)
        return self.fc2(hidden)