File size: 592 Bytes
105f49d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
import torch
import torch.nn as nn
import torch.nn.functional as F

# TODO: define (or import) the U2Net model class here; load_model() below
# depends on it. The reference implementation lives in the original
# U-2-Net repository (xuebinqin/U-2-Net) and should be imported from there.

# Known pretrained checkpoints, keyed by the name accepted by load_model().
_MODEL_URLS = {
    "u2net": "https://github.com/xuebinqin/U-2-Net/raw/master/models/u2net.pth",
    "u2net_human_seg": "https://github.com/xuebinqin/U-2-Net/raw/master/models/u2net_human_seg.pth",
}


def load_model(model_name, *, map_location="cpu"):
    """Construct a U2Net and load the pretrained weights for *model_name*.

    Args:
        model_name: Key of the checkpoint to load; one of the names in
            ``_MODEL_URLS`` ("u2net" or "u2net_human_seg").
        map_location: Forwarded to ``torch.hub.load_state_dict_from_url``.
            Defaults to "cpu" so checkpoints saved from GPU tensors can
            still be loaded on machines without CUDA; the weights are then
            copied into the model's own parameters by ``load_state_dict``.

    Returns:
        The U2Net instance with the downloaded state dict applied. No
        ``.eval()`` call is made here; switch modes as needed before inference.

    Raises:
        KeyError: If *model_name* is not a known checkpoint name.
    """
    # Validate the name before constructing the model or touching the network,
    # and give a message that lists the valid choices.
    try:
        url = _MODEL_URLS[model_name]
    except KeyError:
        known = ", ".join(sorted(_MODEL_URLS))
        raise KeyError(
            f"unknown model {model_name!r}; expected one of: {known}"
        ) from None

    # NOTE(review): U2Net is not defined in this file; it must be provided
    # (defined or imported) at module level for this call to succeed.
    model = U2Net()
    state_dict = torch.hub.load_state_dict_from_url(url, map_location=map_location)
    model.load_state_dict(state_dict)
    return model