caball21 commited on
Commit
93e1ff2
·
verified ·
1 Parent(s): fc03d28

Create sam2_model_stub.py

Browse files
Files changed (1) hide show
  1. sam2_model_stub.py +45 -0
sam2_model_stub.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # sam2_model_stub.py
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+ class SAM2Hierarchical(nn.Module):
8
+ def __init__(self, num_classes=6, in_channels=3, backbone="vit_b", freeze_backbone=True, use_cls_head=True):
9
+ super().__init__()
10
+ self.use_cls_head = use_cls_head
11
+
12
+ # Minimal vision backbone stub (fake transformer or CNN)
13
+ self.backbone = nn.Sequential(
14
+ nn.Conv2d(in_channels, 64, kernel_size=3, stride=2, padding=1),
15
+ nn.BatchNorm2d(64),
16
+ nn.ReLU(inplace=True),
17
+ nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
18
+ nn.BatchNorm2d(128),
19
+ nn.ReLU(inplace=True)
20
+ )
21
+
22
+ # Segmentation head stub
23
+ self.segmentation_head = nn.Sequential(
24
+ nn.Conv2d(128, 64, kernel_size=3, padding=1),
25
+ nn.ReLU(inplace=True),
26
+ nn.Conv2d(64, num_classes, kernel_size=1)
27
+ )
28
+
29
+ # Optional classification head
30
+ if self.use_cls_head:
31
+ self.cls_head = nn.Linear(128, num_classes)
32
+
33
+ if freeze_backbone:
34
+ for param in self.backbone.parameters():
35
+ param.requires_grad = False
36
+
37
+ def forward(self, x):
38
+ features = self.backbone(x)
39
+ logits = self.segmentation_head(features)
40
+
41
+ if self.use_cls_head:
42
+ # Just return segmentation output; inference only cares about logits
43
+ return logits
44
+
45
+ return logits