"""Loader for the Segment Anything (SAM) ViT-H predictor.

Importing this module triggers a download of the ViT-H checkpoint via
``download_file``; ``load_model()`` then builds a ``SamPredictor`` from it.
"""
import torch
from segment_anything import sam_model_registry, SamPredictor

from .common import *

# `MODEL_FOLDER` and `download_file` are provided by the `.common` star import.
MODEL_PATH = f'{MODEL_FOLDER}/sam/sam_vit_h_4b8939.pth'
DOWNLOAD_URL = 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth'

# pre-download at import time
download_file(DOWNLOAD_URL, MODEL_PATH)


def load_model(device=None):
    """Build a SAM ViT-H predictor from the checkpoint at ``MODEL_PATH``.

    Args:
        device: Target device passed to ``sam.to(device=...)`` (e.g. ``"cuda"``,
            ``"cuda:1"``, ``"cpu"``). When ``None`` (the default), ``"cuda"`` is
            used if ``torch.cuda.is_available()`` and ``"cpu"`` otherwise.

    Returns:
        A ``SamPredictor`` wrapping the loaded model on ``device``.
    """
    print ("Loading model: SAM")
    # Called again here in case the file was removed after import.
    # NOTE(review): presumably download_file is a no-op when MODEL_PATH
    # already exists — confirm in .common.
    download_file(DOWNLOAD_URL, MODEL_PATH)

    # Must match the checkpoint above (sam_vit_h_*.pth holds ViT-H weights).
    model_type = "vit_h"

    if device is None:
        # Previously hard-coded to "cuda", which fails on hosts without a
        # CUDA-capable GPU or torch build; fall back to CPU there instead.
        device = "cuda" if torch.cuda.is_available() else "cpu"

    sam = sam_model_registry[model_type](checkpoint=MODEL_PATH)
    sam.to(device=device)
    sam_predictor = SamPredictor(sam)
    print ("SAM loaded")
    return sam_predictor