# This-and-That / scripts/generate_sam_this_that.py
# Commit 59b2a81 by HikariDawn777: "feat: initial push"
import os, sys, shutil
import cv2
import numpy as np
import matplotlib.pyplot as plt
from segment_anything import SamAutomaticMaskGenerator, SamPredictor, sam_model_registry
def show_anns(anns, *, rng=None):
    """Paint SAM automatic-mask annotations onto a white canvas.

    Each annotation gets a random color. Annotations are painted from the
    largest ``'area'`` to the smallest, so smaller masks stay visible on top
    of larger ones they overlap.

    Args:
        anns: sequence of dicts, each with a boolean ``'segmentation'`` mask
            (all masks the same height/width) and a numeric ``'area'``, e.g.
            the output of ``SamAutomaticMaskGenerator.generate``.
        rng: optional random source with a ``random(size)`` method (e.g.
            ``np.random.default_rng(seed)``) for reproducible colors.
            Defaults to the global ``np.random`` state.

    Returns:
        A float array of shape (H, W, 3) with values in [0, 255]. Pixels not
        covered by any mask are 255 (white). Returns ``None`` when ``anns`` is
        empty, so callers must check before writing the result to disk.
    """
    if len(anns) == 0:
        return None
    # Pure NumPy: no pyplot figure/axis is needed to build the image.
    color_source = np.random if rng is None else rng
    sorted_anns = sorted(anns, key=lambda ann: ann['area'], reverse=True)
    height, width = sorted_anns[0]['segmentation'].shape[:2]
    img = np.ones((height, width, 3))
    for ann in sorted_anns:
        # Later (smaller) masks overwrite earlier (larger) ones where they overlap.
        img[ann['segmentation']] = color_source.random(3)
    return img * 255
def show_mask(mask, random_color=False):
    """Turn a single mask into an RGBA overlay scaled to [0, 255].

    Masked pixels get the chosen color with alpha 0.6; unmasked pixels are all
    zeros. The color is a fixed blue (30, 144, 255) unless ``random_color`` is
    set, in which case the RGB part is drawn from ``np.random``.

    ``mask`` is reshaped to (H, W, 1) using its last two dimensions, so any
    leading dimensions must have size 1. Returns a float (H, W, 4) array.
    """
    if random_color:
        rgb = np.random.random(3)
    else:
        rgb = np.array([30/255, 144/255, 255/255])
    rgba = np.append(rgb, 0.6)
    height, width = mask.shape[-2:]
    overlay = mask.reshape(height, width, 1) * rgba.reshape(1, 1, -1)
    return overlay * 255
def show_points(coords, labels, ax, marker_size=375):
    """Draw prompt points on ``ax`` as star markers with a white edge.

    Points whose label equals 1 are drawn in green (edge width 1.25), then
    points whose label equals 0 are drawn in red (edge width 1). ``coords`` is
    indexed by the boolean masks ``labels == 1`` / ``labels == 0`` and then by
    columns 0 and 1, so it is expected to be a 2-D array whose first two
    columns are x and y, with ``labels`` aligned to its rows.
    """
    styles = (
        (1, 'green', 1.25),
        (0, 'red', 1),
    )
    # Select both groups before drawing anything.
    groups = [(coords[labels == value], tint, width) for value, tint, width in styles]
    for points, tint, width in groups:
        ax.scatter(
            points[:, 0],
            points[:, 1],
            color=tint,
            marker='*',
            s=marker_size,
            edgecolor='white',
            linewidth=width,
        )
if __name__ == "__main__":
    # For every sample sub-directory of `input_parent_folder`, segment the object
    # under the first point listed in `data.txt` on `im_0.jpg` with SAM, and save
    # the mask overlay as `first_contact0.png` in that sub-directory.
    input_parent_folder = "validation_tmp"

    # Init SAM for segmentation task
    model_type = "vit_h"
    weight_path = "pretrained/sam_vit_h_4b8939.pth"
    sam = sam_model_registry[model_type](checkpoint=weight_path).to(device="cuda")
    sam_predictor = SamPredictor(sam)
    mask_generator = SamAutomaticMaskGenerator(sam)

    image = None  # last successfully loaded image; reused by the "SAM all" preview
    # Iterate the folder
    for sub_dir_name in sorted(os.listdir(input_parent_folder)):
        sub_dir = os.path.join(input_parent_folder, sub_dir_name)
        if not os.path.isdir(sub_dir):
            continue  # ignore stray files next to the sample folders
        print("We are processing ", sub_dir_name)
        ref_img_path = os.path.join(sub_dir, 'im_0.jpg')
        data_txt_path = os.path.join(sub_dir, 'data.txt')

        # Read the image (OpenCV decodes to BGR; convert to RGB for SAM)
        bgr_image = cv2.imread(ref_img_path)
        if bgr_image is None:
            print("Skipping", sub_dir_name, "- cannot read", ref_img_path)
            continue
        image = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)

        # Read the positive point: only the first line of data.txt is used,
        # laid out as "<frame_idx> <horizontal> <vertical>".
        with open(data_txt_path, 'r') as data_file:
            first_line = data_file.readline()
        fields = first_line.split()  # tolerant of repeated/trailing whitespace
        if not fields:
            print("Skipping", sub_dir_name, "- no point in", data_txt_path)
            continue
        _frame_idx, horizontal, vertical = fields
        positive_point_cords = np.array([[int(float(horizontal)), int(float(vertical))]])
        positive_point_labels = np.ones(len(positive_point_cords))
        print(positive_point_cords)

        # Set the SAM predictor on this image and prompt it with the point
        sam_predictor.set_image(image)
        masks, _scores, _logits = sam_predictor.predict(
            point_coords=positive_point_cords,  # Only positive points here
            point_labels=positive_point_labels,
            multimask_output=False,
        )

        # Visualize
        # NOTE(review): show_mask builds an RGBA overlay, but cv2.imwrite treats the
        # channels as BGRA, so red/blue are swapped in the saved file. Left as-is in
        # case downstream consumers depend on the stored channel values.
        mask_img = show_mask(masks[0])
        cv2.imwrite(os.path.join(sub_dir, "first_contact0.png"), mask_img)

    # SAM all: automatic whole-image segmentation preview. sam_all.png is a single
    # fixed path, so only the last image's result can survive. Run it once, after
    # the loop, on the last loaded image instead of recomputing it per folder.
    if image is not None:
        sam_all = mask_generator.generate(image)
        all_sam_imgs = show_anns(sam_all)
        if all_sam_imgs is not None:  # show_anns returns None when no masks exist
            cv2.imwrite("sam_all.png", all_sam_imgs)