Spaces:
Runtime error
Runtime error
from segment_anything import sam_model_registry, SamPredictor | |
from .common import * | |
# Local path where the SAM checkpoint is stored. MODEL_FOLDER is not defined
# in this file; presumably it comes from the `.common` star import — confirm.
MODEL_PATH = f'{MODEL_FOLDER}/sam/sam_vit_h_4b8939.pth'
# Remote URL the checkpoint is fetched from.
DOWNLOAD_URL = 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth'

# Pre-download: this call runs as a side effect of importing the module,
# before load_model() is ever invoked.
download_file(DOWNLOAD_URL, MODEL_PATH)
def load_model(device=None):
    """Build a SAM ``vit_h`` predictor from the checkpoint at ``MODEL_PATH``.

    Args:
        device: Device string passed to ``sam.to(device=...)``. When left as
            ``None``, ``"cuda"`` is used if ``torch.cuda.is_available()``
            returns true and ``"cpu"`` otherwise, so the function no longer
            fails on hosts without a GPU.

    Returns:
        A ``SamPredictor`` wrapping the loaded model.
    """
    # Imported locally so this function does not rely on `.common`
    # re-exporting torch through its star import.
    import torch

    print("Loading model: SAM")
    # Repeats the import-time call; presumably a no-op when the checkpoint is
    # already on disk — confirm against download_file's implementation.
    download_file(DOWNLOAD_URL, MODEL_PATH)

    # MODEL_PATH points at the ViT-H checkpoint, so the registry key is fixed.
    model_type = "vit_h"
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    sam = sam_model_registry[model_type](checkpoint=MODEL_PATH)
    sam.to(device=device)
    sam_predictor = SamPredictor(sam)
    print("SAM loaded")
    return sam_predictor