Janeka commited on
Commit
105f49d
·
verified ·
1 Parent(s): ec18915

Create u2net.py

Browse files
Files changed (1) hide show
  1. u2net.py +15 -0
u2net.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ # Simplified U2Net architecture would go here
6
+ # For actual implementation, you'd want to import from the original repo
7
+
8
+ def load_model(model_name):
9
+ model_urls = {
10
+ "u2net": "https://github.com/xuebinqin/U-2-Net/raw/master/models/u2net.pth",
11
+ "u2net_human_seg": "https://github.com/xuebinqin/U-2-Net/raw/master/models/u2net_human_seg.pth"
12
+ }
13
+ model = U2Net() # This would be the actual U2Net class
14
+ model.load_state_dict(torch.hub.load_state_dict_from_url(model_urls[model_name]))
15
+ return model