Chao Xu committed
Commit · c0c3e1b · 1 Parent(s): 6c1250a

pruning

Files changed:
- sam_utils.py +3 -57
- zero123_utils.py +4 -4
sam_utils.py
CHANGED
@@ -1,14 +1,10 @@
 import os
 import numpy as np
 import torch
-# import matplotlib.pyplot as plt
-import cv2
 from PIL import Image
-# from PIL import Image
 import time
-from utils import find_image_file
 
-from segment_anything import sam_model_registry, SamPredictor
+from segment_anything import sam_model_registry, SamPredictor
 
 def sam_init(device_id=0):
     import inspect
@@ -22,60 +18,11 @@ def sam_init(device_id=0):
     sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
     sam.to(device=device)
     predictor = SamPredictor(sam)
-    # mask_generator = SamAutomaticMaskGenerator(sam)
     return predictor
 
-def sam_out(predictor, shape_dir):
-    image_path = os.path.join(shape_dir, find_image_file(shape_dir))
-    save_path = os.path.join(shape_dir, "image_sam.png")
-    bbox_path = os.path.join(shape_dir, "bbox.txt")
-    bbox = np.loadtxt(bbox_path, delimiter=',')
-    image = cv2.imread(image_path)
-    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
-
-    start_time = time.time()
-    predictor.set_image(image)
-
-    h, w, _ = image.shape
-    input_point = np.array([[h//2, w//2]])
-    input_label = np.array([1])
-
-    masks, scores, logits = predictor.predict(
-        point_coords=input_point,
-        point_labels=input_label,
-        multimask_output=True,
-    )
-
-    masks_bbox, scores_bbox, logits_bbox = predictor.predict(
-        box=bbox,
-        multimask_output=True
-    )
-
-    print(f"SAM Time: {time.time() - start_time:.3f}s")
-    opt_idx = np.argmax(scores)
-    mask = masks[opt_idx]
-    out_image = np.zeros((image.shape[0], image.shape[1], 4), dtype=np.uint8)
-    out_image[:, :, :3] = image
-    out_image_bbox = out_image.copy()
-    out_image[:, :, 3] = mask.astype(np.uint8) * 255
-    out_image_bbox[:, :, 3] = masks_bbox[-1].astype(np.uint8) * 255  # np.argmax(scores_bbox)
-    cv2.imwrite(save_path, cv2.cvtColor(out_image_bbox, cv2.COLOR_RGBA2BGRA))
-
-
-def convert_from_cv2_to_image(img: np.ndarray) -> Image:
-    return Image.fromarray(img)
-    # return Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGRA2RGBA))
-
-def convert_from_image_to_cv2(img: Image) -> np.ndarray:
-    return np.asarray(img)
-    # return cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
-
 def sam_out_nosave(predictor, input_image, *bbox_sliders):
-    # save_path = os.path.join(shape_dir, "image_sam.png")
-    # bbox_path = os.path.join(shape_dir, "bbox.txt")
-    # bbox = np.loadtxt(bbox_path, delimiter=',')
     bbox = np.array(bbox_sliders)
-    image = convert_from_image_to_cv2(input_image)
+    image = np.asarray(input_image)
 
     start_time = time.time()
     predictor.set_image(image)
@@ -104,5 +51,4 @@ def sam_out_nosave(predictor, input_image, *bbox_sliders):
     out_image[:, :, 3] = mask.astype(np.uint8) * 255
     out_image_bbox[:, :, 3] = masks_bbox[-1].astype(np.uint8) * 255  # np.argmax(scores_bbox)
     torch.cuda.empty_cache()
-    return Image.fromarray(out_image_bbox, mode='RGBA')
-    cv2.imwrite(save_path, cv2.cvtColor(out_image_bbox, cv2.COLOR_RGBA2BGRA))
+    return Image.fromarray(out_image_bbox, mode='RGBA')
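With this pruning, sam_utils.py keeps only sam_init and sam_out_nosave: segmentation now runs on an in-memory PIL image with a user-supplied bounding box and returns an RGBA PIL image whose alpha channel holds the mask, instead of reading and writing files in a shape directory. A minimal usage sketch under stated assumptions (the SAM checkpoint resolved inside sam_init is present; demo.png and the box coordinates are illustrative placeholders; the sliders are taken as x_min, y_min, x_max, y_max):

from PIL import Image
from sam_utils import sam_init, sam_out_nosave

predictor = sam_init(device_id=0)                      # load SAM onto cuda:0
input_image = Image.open("demo.png").convert("RGB")    # hypothetical input image
# illustrative bounding box in pixel coordinates
segmented = sam_out_nosave(predictor, input_image, 50, 50, 200, 200)
segmented.save("image_sam.png")                        # RGBA output; alpha channel is the mask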
zero123_utils.py
CHANGED
@@ -76,7 +76,7 @@ def sample_model_batch(model, sampler, input_im, xs, ys, n_samples=4, precision=
             cond = {}
             cond['c_crossattn'] = [c]
             # c_concat = model.encode_first_stage((input_im.to(c.device))).mode().detach()
-            cond['c_concat'] = [model.encode_first_stage((input_im.to(c.device))).mode().detach()
+            cond['c_concat'] = [model.encode_first_stage(input_im).mode().detach()
                                 .repeat(n_samples, 1, 1, 1)]
             if scale != 1.0:
                 uc = {}
@@ -99,7 +99,8 @@ def sample_model_batch(model, sampler, input_im, xs, ys, n_samples=4, precision=
             # samples_ddim = torch.nn.functional.interpolate(samples_ddim, 64, mode='nearest', antialias=False)
             x_samples_ddim = model.decode_first_stage(samples_ddim)
             ret_imgs = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0).cpu()
-            del cond, c, x_samples_ddim, samples_ddim, uc
+            del cond, c, x_samples_ddim, samples_ddim, uc, input_im
+            torch.cuda.empty_cache()
             return ret_imgs
 
 
@@ -126,6 +127,7 @@ def predict_stage1(model, sampler, input_img_path, save_path_8, adjust_set=[], d
     del input_im
     torch.cuda.empty_cache()
 
+@torch.no_grad()
 def predict_stage1_gradio(model, raw_im, save_path = "", adjust_set=[], device="cuda", ddim_steps=75, scale=3.0):
     # raw_im = raw_im.resize([256, 256], Image.LANCZOS)
     # input_im_init = preprocess_image(models, raw_im, preprocess=False)
@@ -157,7 +159,6 @@ def predict_stage1_gradio(model, raw_im, save_path = "", adjust_set=[], device="
             out_image.save(os.path.join(save_path, '%d.png'%(stage1_idx)))
             sample_idx += 1
     del x_samples_ddims_8
-    del input_im
     del sampler
     torch.cuda.empty_cache()
     return ret_imgs
@@ -188,7 +189,6 @@ def infer_stage_2(model, save_path_stage1, save_path_stage2, delta_x_2, delta_y_
         x_sample_stage2 = 255.0 * rearrange(x_samples_ddims_stage2[stage2_idx].numpy(), 'c h w -> h w c')
         Image.fromarray(x_sample_stage2.astype(np.uint8)).save(os.path.join(save_path_stage2, '%d_%d.png'%(stage1_idx, stage2_idx)))
     del input_im
-    del sampler
     del x_samples_ddims_stage2
     torch.cuda.empty_cache()
 
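The zero123_utils.py edits all serve one goal: release GPU memory promptly during inference by deleting intermediate tensors, calling torch.cuda.empty_cache(), and running the Gradio-facing stage-1 prediction under @torch.no_grad() so no autograd graph is kept alive. A small self-contained sketch of that pattern (run_inference and the toy model are illustrative, not part of the repo):

import torch
import torch.nn as nn

@torch.no_grad()                          # inference only: no autograd graph is retained
def run_inference(model: nn.Module, batch: torch.Tensor) -> torch.Tensor:
    feats = model(batch)                  # intermediate tensor, lives on the model's device
    result = feats.clamp(0.0, 1.0).cpu()  # move the final output off the GPU
    del feats, batch                      # drop references to the intermediates
    if torch.cuda.is_available():
        torch.cuda.empty_cache()          # hand cached blocks back to the CUDA allocator
    return result

model = nn.Linear(8, 8)
out = run_inference(model, torch.randn(4, 8))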