Spaces:
Running
on
Zero
Running
on
Zero
import os, sys, shutil | |
import cv2 | |
import numpy as np | |
import matplotlib.pyplot as plt | |
from segment_anything import SamAutomaticMaskGenerator, SamPredictor, sam_model_registry | |
def show_anns(anns):
    """Paint every annotation mask with a random colour on a white canvas.

    Args:
        anns: Sequence of dicts, each providing an ``'area'`` value (used
            for ordering) and a ``'segmentation'`` array whose first two
            dimensions give the canvas height and width and which is used
            to index the canvas (presumably the boolean masks produced by
            ``SamAutomaticMaskGenerator.generate`` -- confirm with caller).

    Returns:
        A float array of shape ``(H, W, 3)`` with values in ``[0, 255]``,
        or ``None`` when ``anns`` is empty.
    """
    if not anns:
        return None

    # Largest regions first, so smaller regions are painted on top of them.
    ordered = sorted(anns, key=lambda ann: ann['area'], reverse=True)

    height, width = ordered[0]['segmentation'].shape[:2]
    canvas = np.ones((height, width, 3))  # pixels no mask covers stay white

    # The image is returned as an array, so no matplotlib axes are needed
    # here (the previous version fetched one without ever drawing on it).
    for ann in ordered:
        canvas[ann['segmentation']] = np.random.random(3)

    return canvas * 255
def show_mask(mask, random_color=False):
    """Turn a single mask into a four-channel overlay scaled to 0-255.

    Args:
        mask: Array whose last two dimensions are height and width and
            which holds exactly ``H * W`` elements (it is reshaped to
            ``(H, W, 1)``).
        random_color: When true, the first three channels are drawn from
            ``np.random.random``; otherwise the fixed colour
            ``(30, 144, 255)`` / 255 is used. The fourth channel is
            always 0.6.

    Returns:
        Array of shape ``(H, W, 4)``: ``mask * colour * 255``.
    """
    rgba = np.empty(4)
    rgba[3] = 0.6
    if random_color:
        rgba[:3] = np.random.random(3)
    else:
        rgba[:3] = (30/255, 144/255, 255/255)

    height, width = mask.shape[-2:]
    # Broadcast the (H, W, 1) mask against the (1, 1, 4) colour vector.
    overlay = mask.reshape(height, width, 1) * rgba[np.newaxis, np.newaxis, :]
    return overlay * 255
def show_points(coords, labels, ax, marker_size=375):
    """Scatter prompt points onto ``ax`` as star markers.

    Points whose label equals 1 are drawn in green, points whose label
    equals 0 in red; both get a white edge.

    Args:
        coords: Array indexed as ``coords[:, 0]`` (x) and ``coords[:, 1]``
            (y) after boolean selection by ``labels``.
        labels: Array compared element-wise against 1 and 0 to select rows
            of ``coords``.
        ax: Object exposing a matplotlib-style ``scatter`` method.
        marker_size: Value forwarded as the ``s`` argument of ``scatter``.
    """
    # (label value, face colour, edge line width) for each point class.
    point_styles = ((1, 'green', 1.25), (0, 'red', 1))
    for label_value, face_colour, edge_width in point_styles:
        selected = coords[labels == label_value]
        ax.scatter(
            selected[:, 0],
            selected[:, 1],
            color=face_colour,
            marker='*',
            s=marker_size,
            edgecolor='white',
            linewidth=edge_width,
        )
if __name__ == "__main__":
    # Batch driver: for every entry of ``input_parent_folder`` read
    # ``im_0.jpg`` and ``data.txt``, prompt SAM with the point on the first
    # line of ``data.txt`` and save the resulting mask overlay next to the
    # image, then run automatic mask generation on the same image.

    input_parent_folder = "validation_tmp"

    # Init SAM for segmentation task
    model_type = "vit_h"
    weight_path = "pretrained/sam_vit_h_4b8939.pth"
    sam = sam_model_registry[model_type](checkpoint=weight_path).to(device="cuda")
    sam_predictor = SamPredictor(sam)
    mask_generator = SamAutomaticMaskGenerator(sam)

    # Iterate the folder
    for sub_dir_name in sorted(os.listdir(input_parent_folder)):
        print("We are processing ", sub_dir_name)
        sub_dir_path = os.path.join(input_parent_folder, sub_dir_name)
        ref_img_path = os.path.join(sub_dir_path, 'im_0.jpg')
        data_txt_path = os.path.join(sub_dir_path, 'data.txt')

        # Read the image and process
        image = cv2.imread(ref_img_path)
        if image is None:
            # cv2.imread returns None instead of raising when the file is
            # missing or unreadable (e.g. a stray non-folder entry).
            print("Skipping ", sub_dir_name, ": cannot read ", ref_img_path)
            continue
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # Read the positive point; the context manager closes the file.
        with open(data_txt_path, 'r') as data_file:
            lines = data_file.readlines()

        for line in lines:
            # Each line is "<frame_idx> <horizontal> <vertical>".
            _frame_idx, horizontal, vertical = line.split()
            vertical, horizontal = int(float(vertical)), int(float(horizontal))

            positive_point_cords = np.array([[horizontal, vertical]])
            positive_point_labels = np.ones(len(positive_point_cords))
            print(positive_point_cords)

            # Set the SAM predictor
            sam_predictor.set_image(np.uint8(image))
            masks, _scores, _logits = sam_predictor.predict(
                point_coords=positive_point_cords,  # Only positive points here
                point_labels=positive_point_labels,
                multimask_output=False,
            )

            # Visualize
            mask_img = show_mask(masks[0])
            cv2.imwrite(os.path.join(sub_dir_path, "first_contact0.png"), mask_img)

            # Only the first line of data.txt is used.
            break

        # SAM all
        # NOTE(review): the original indentation was lost; this step is
        # assumed to run once per folder, so "sam_all.png" is overwritten on
        # every iteration -- confirm the intended nesting and output path.
        sam_all = mask_generator.generate(image)
        all_sam_imgs = show_anns(sam_all)
        if all_sam_imgs is not None:
            # show_anns returns None when no masks were generated.
            cv2.imwrite("sam_all.png", all_sam_imgs)