Chao Xu
commited on
Commit
·
6c1250a
1
Parent(s):
0e93edd
empty cache
Browse files- sam_utils.py +1 -0
- zero123_utils.py +8 -9
sam_utils.py
CHANGED
@@ -103,5 +103,6 @@ def sam_out_nosave(predictor, input_image, *bbox_sliders):
|
|
103 |
out_image_bbox = out_image.copy()
|
104 |
out_image[:, :, 3] = mask.astype(np.uint8) * 255
|
105 |
out_image_bbox[:, :, 3] = masks_bbox[-1].astype(np.uint8) * 255 # np.argmax(scores_bbox)
|
|
|
106 |
return Image.fromarray(out_image_bbox, mode='RGBA')
|
107 |
cv2.imwrite(save_path, cv2.cvtColor(out_image_bbox, cv2.COLOR_RGBA2BGRA))
|
|
|
103 |
out_image_bbox = out_image.copy()
|
104 |
out_image[:, :, 3] = mask.astype(np.uint8) * 255
|
105 |
out_image_bbox[:, :, 3] = masks_bbox[-1].astype(np.uint8) * 255 # np.argmax(scores_bbox)
|
106 |
+
torch.cuda.empty_cache()
|
107 |
return Image.fromarray(out_image_bbox, mode='RGBA')
|
108 |
cv2.imwrite(save_path, cv2.cvtColor(out_image_bbox, cv2.COLOR_RGBA2BGRA))
|
zero123_utils.py
CHANGED
@@ -61,9 +61,9 @@ def init_model(device, ckpt):
|
|
61 |
return models
|
62 |
|
63 |
@torch.no_grad()
|
64 |
-
def sample_model_batch(model, sampler, input_im, xs, ys, n_samples=4, precision='
|
65 |
precision_scope = autocast if precision == 'autocast' else nullcontext
|
66 |
-
with precision_scope(
|
67 |
with model.ema_scope():
|
68 |
c = model.get_learned_conditioning(input_im).tile(n_samples, 1, 1)
|
69 |
T = []
|
@@ -98,7 +98,9 @@ def sample_model_batch(model, sampler, input_im, xs, ys, n_samples=4, precision=
|
|
98 |
print(samples_ddim.shape)
|
99 |
# samples_ddim = torch.nn.functional.interpolate(samples_ddim, 64, mode='nearest', antialias=False)
|
100 |
x_samples_ddim = model.decode_first_stage(samples_ddim)
|
101 |
-
|
|
|
|
|
102 |
|
103 |
|
104 |
def predict_stage1(model, sampler, input_img_path, save_path_8, adjust_set=[], device="cuda"):
|
@@ -118,7 +120,7 @@ def predict_stage1(model, sampler, input_img_path, save_path_8, adjust_set=[], d
|
|
118 |
for stage1_idx in range(len(x_samples_ddims_8)):
|
119 |
if adjust_set != [] and stage1_idx not in adjust_set:
|
120 |
continue
|
121 |
-
x_sample = 255.0 * rearrange(x_samples_ddims_8[stage1_idx].
|
122 |
Image.fromarray(x_sample.astype(np.uint8)).save(os.path.join(save_path_8, '%d.png'%(stage1_idx)))
|
123 |
del x_samples_ddims_8
|
124 |
del input_im
|
@@ -148,7 +150,7 @@ def predict_stage1_gradio(model, raw_im, save_path = "", adjust_set=[], device="
|
|
148 |
for stage1_idx in range(len(delta_x_1_8)):
|
149 |
if adjust_set != [] and stage1_idx not in adjust_set:
|
150 |
continue
|
151 |
-
x_sample = 255.0 * rearrange(x_samples_ddims_8[sample_idx].
|
152 |
out_image = Image.fromarray(x_sample.astype(np.uint8))
|
153 |
ret_imgs.append(out_image)
|
154 |
if save_path:
|
@@ -177,16 +179,13 @@ def infer_stage_2(model, save_path_stage1, save_path_stage2, delta_x_2, delta_y_
|
|
177 |
input_im_init = input_im_init / 255.0
|
178 |
input_im = transforms.ToTensor()(input_im_init).unsqueeze(0).to(device)
|
179 |
input_im = input_im * 2 - 1
|
180 |
-
print("debug input device", input_im.device)
|
181 |
-
print("debug model device", model.device)
|
182 |
# infer stage 2
|
183 |
sampler = DDIMSampler(model)
|
184 |
-
print("debug sampler device", sampler.device)
|
185 |
# sampler.to(device)
|
186 |
# stage2_in = x_samples_ddims[stage1_idx][None, ...].to(device) * 2 - 1
|
187 |
x_samples_ddims_stage2 = sample_model_batch(model, sampler, input_im, delta_x_2, delta_y_2, n_samples=len(delta_x_2), ddim_steps=ddim_steps, scale=scale)
|
188 |
for stage2_idx in range(len(delta_x_2)):
|
189 |
-
x_sample_stage2 = 255.0 * rearrange(x_samples_ddims_stage2[stage2_idx].
|
190 |
Image.fromarray(x_sample_stage2.astype(np.uint8)).save(os.path.join(save_path_stage2, '%d_%d.png'%(stage1_idx, stage2_idx)))
|
191 |
del input_im
|
192 |
del sampler
|
|
|
61 |
return models
|
62 |
|
63 |
@torch.no_grad()
|
64 |
+
def sample_model_batch(model, sampler, input_im, xs, ys, n_samples=4, precision='autocast', ddim_eta=1.0, ddim_steps=75, scale=3.0, h=256, w=256):
|
65 |
precision_scope = autocast if precision == 'autocast' else nullcontext
|
66 |
+
with precision_scope("cuda"):
|
67 |
with model.ema_scope():
|
68 |
c = model.get_learned_conditioning(input_im).tile(n_samples, 1, 1)
|
69 |
T = []
|
|
|
98 |
print(samples_ddim.shape)
|
99 |
# samples_ddim = torch.nn.functional.interpolate(samples_ddim, 64, mode='nearest', antialias=False)
|
100 |
x_samples_ddim = model.decode_first_stage(samples_ddim)
|
101 |
+
ret_imgs = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0).cpu()
|
102 |
+
del cond, c, x_samples_ddim, samples_ddim, uc
|
103 |
+
return ret_imgs
|
104 |
|
105 |
|
106 |
def predict_stage1(model, sampler, input_img_path, save_path_8, adjust_set=[], device="cuda"):
|
|
|
120 |
for stage1_idx in range(len(x_samples_ddims_8)):
|
121 |
if adjust_set != [] and stage1_idx not in adjust_set:
|
122 |
continue
|
123 |
+
x_sample = 255.0 * rearrange(x_samples_ddims_8[stage1_idx].numpy(), 'c h w -> h w c')
|
124 |
Image.fromarray(x_sample.astype(np.uint8)).save(os.path.join(save_path_8, '%d.png'%(stage1_idx)))
|
125 |
del x_samples_ddims_8
|
126 |
del input_im
|
|
|
150 |
for stage1_idx in range(len(delta_x_1_8)):
|
151 |
if adjust_set != [] and stage1_idx not in adjust_set:
|
152 |
continue
|
153 |
+
x_sample = 255.0 * rearrange(x_samples_ddims_8[sample_idx].numpy(), 'c h w -> h w c')
|
154 |
out_image = Image.fromarray(x_sample.astype(np.uint8))
|
155 |
ret_imgs.append(out_image)
|
156 |
if save_path:
|
|
|
179 |
input_im_init = input_im_init / 255.0
|
180 |
input_im = transforms.ToTensor()(input_im_init).unsqueeze(0).to(device)
|
181 |
input_im = input_im * 2 - 1
|
|
|
|
|
182 |
# infer stage 2
|
183 |
sampler = DDIMSampler(model)
|
|
|
184 |
# sampler.to(device)
|
185 |
# stage2_in = x_samples_ddims[stage1_idx][None, ...].to(device) * 2 - 1
|
186 |
x_samples_ddims_stage2 = sample_model_batch(model, sampler, input_im, delta_x_2, delta_y_2, n_samples=len(delta_x_2), ddim_steps=ddim_steps, scale=scale)
|
187 |
for stage2_idx in range(len(delta_x_2)):
|
188 |
+
x_sample_stage2 = 255.0 * rearrange(x_samples_ddims_stage2[stage2_idx].numpy(), 'c h w -> h w c')
|
189 |
Image.fromarray(x_sample_stage2.astype(np.uint8)).save(os.path.join(save_path_stage2, '%d_%d.png'%(stage1_idx, stage2_idx)))
|
190 |
del input_im
|
191 |
del sampler
|