File size: 583 Bytes
da1e12f
bfd34e9
 
 
 
 
 
 
 
 
 
da1e12f
bfd34e9
 
da1e12f
bfd34e9
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
import torch
from segment_anything import sam_model_registry, SamPredictor
from .common import *

# File name of the ViT-H Segment Anything checkpoint. The local cache path and
# the remote download URL both end in this same name.
_SAM_CHECKPOINT_NAME = 'sam_vit_h_4b8939.pth'
MODEL_PATH = f'{MODEL_FOLDER}/sam/{_SAM_CHECKPOINT_NAME}'
DOWNLOAD_URL = f'https://dl.fbaipublicfiles.com/segment_anything/{_SAM_CHECKPOINT_NAME}'

# Side effect: importing this module fetches the checkpoint into MODEL_PATH.
# NOTE(review): download_file comes from .common; presumably it skips files
# that already exist — confirm.
download_file(DOWNLOAD_URL, MODEL_PATH)


def load_model(device='cuda:0'):
    """Build a SAM ViT-H predictor on the requested device.

    The checkpoint is downloaded to MODEL_PATH again if needed, then loaded
    through the ``"vit_h"`` builder in ``sam_model_registry``. The model is
    moved to ``device`` and wrapped in a ``SamPredictor``.

    Args:
        device: Target device for the model, passed to ``.to(device=...)``.
            Defaults to ``'cuda:0'``.

    Returns:
        A ``SamPredictor`` wrapping the loaded model.
    """
    print("Loading model: SAM")
    # Make sure the checkpoint is present before building the model.
    # NOTE(review): download_file comes from .common; presumably it does
    # nothing when the file is already on disk — confirm.
    download_file(DOWNLOAD_URL, MODEL_PATH)

    build_vit_h = sam_model_registry["vit_h"]
    model = build_vit_h(checkpoint=MODEL_PATH)
    model.to(device=device)

    predictor = SamPredictor(model)
    print("SAM loaded")
    return predictor