Spaces:
Runtime error
Runtime error
File size: 9,138 Bytes
7a11626 10528ca 7a11626 10528ca 7a11626 10528ca c255c40 7a11626 10528ca cfc39b8 7a11626 10528ca 7a11626 0016575 10528ca 0b4917d 10528ca c7f38a3 5426b53 c7f38a3 10528ca 4bd20e7 10528ca 4bd20e7 10528ca 7a11626 10528ca 7a11626 10528ca 7a11626 10528ca 7a11626 10528ca 7a11626 10528ca 7a11626 10528ca 7a11626 10528ca 7a11626 10528ca 5ecf930 10528ca 4bd20e7 c255c40 4bd20e7 c255c40 10528ca 4bd20e7 10528ca 4bd20e7 10528ca 4bd20e7 10528ca 4bd20e7 c255c40 4bd20e7 10528ca 4bd20e7 10528ca 4bd20e7 10528ca 4bd20e7 10528ca 4bd20e7 7a11626 4bd20e7 7a11626 4bd20e7 10528ca 4bd20e7 10528ca |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 |
import numpy as np
import time
from pathlib import Path
import torch
import imageio
from my.utils import tqdm
from my.utils.seed import seed_everything
from run_img_sampling import SD, StableDiffusion
from misc import torch_samps_to_imgs
from pose import PoseConfig
from run_nerf import VoxConfig
from voxnerf.utils import every
from voxnerf.vis import stitch_vis, bad_vis as nerf_vis
from run_sjc import render_one_view, tsr_stats
from highres_final_vis import highres_render_one_view
import gradio as gr
import gc
import os
# Torch device the voxel model is moved to in `submit`, both before training
# and again before the final high-resolution render.
device_glb = torch.device("cuda")
def vis_routine(y, depth):
    """Build the visualisation outputs for one rendered view.

    Returns a 3-tuple ``(pane, im, depth)``:

    * ``pane``  -- result of ``nerf_vis(y, depth, final_H=256)``;
    * ``im``    -- first element of ``torch_samps_to_imgs(y)``;
    * ``depth`` -- ``depth`` moved to the CPU and converted to a NumPy array.
    """
    overview_pane = nerf_vis(y, depth, final_H=256)
    converted_images = torch_samps_to_imgs(y)
    first_image = converted_images[0]
    depth_array = depth.cpu().numpy()
    return overview_pane, first_image, depth_array
# Custom stylesheet handed to `gr.Blocks(css=css)` when the demo UI is built.
css = '''
.instruction{position: absolute; top: 0;right: 0;margin-top: 0px !important}
.arrow{position: absolute;top: 0;right: -110px;margin-top: -8px !important}
#component-4, #component-3, #component-10{min-height: 0}
.duplicate-button img{margin: 0}
'''
with gr.Blocks(css=css) as demo:
    # ---- title / notice -------------------------------------------------
    gr.Markdown('# [Score Jacobian Chaining](https://github.com/pals-ttic/sjc): Lifting Pretrained 2D Diffusion Models for 3D Generation')
    # The HTML text below is kept verbatim (including its lack of leading
    # whitespace) so the rendered notice is unchanged.
    gr.HTML('''
<div class="gr-prose" style="max-width: 80%">
<h2>Attention - This Space takes over 30min to run!</h2>
<p>If the Queue is too long you can run locally or duplicate the Space and run it on your own profile using a (paid) private T4 GPU for training. As each T4 costs US$0.60/h, it should cost < US$1 to train most models using default settings! <a style='display:inline-block' href='https://huggingface.co/spaces/MirageML/sjc?duplicate=true'><img src='https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAABAAAAAQCAYAAAAf8/9hAAAAAXNSR0IArs4c6QAAAP5JREFUOE+lk7FqAkEURY+ltunEgFXS2sZGIbXfEPdLlnxJyDdYB62sbbUKpLbVNhyYFzbrrA74YJlh9r079973psed0cvUD4A+4HoCjsA85X0Dfn/RBLBgBDxnQPfAEJgBY+A9gALA4tcbamSzS4xq4FOQAJgCDwV2CPKV8tZAJcAjMMkUe1vX+U+SMhfAJEHasQIWmXNN3abzDwHUrgcRGmYcgKe0bxrblHEB4E/pndMazNpSZGcsZdBlYJcEL9Afo75molJyM2FxmPgmgPqlWNLGfwZGG6UiyEvLzHYDmoPkDDiNm9JR9uboiONcBXrpY1qmgs21x1QwyZcpvxt9NS09PlsPAAAAAElFTkSuQmCC&logoWidth=14' alt='Duplicate Space'></a></p>
</div>
''')

    # ---- inputs ---------------------------------------------------------
    prompt = gr.Textbox(label="Prompt", max_lines=1, value="A high quality photo of a delicious burger")
    iters = gr.Slider(label="Iters", minimum=100, maximum=20000, value=10000, step=100)
    seed = gr.Slider(label="Seed", minimum=0, maximum=2147483647, step=1, randomize=True)
    button = gr.Button('Generate')

    # ---- outputs --------------------------------------------------------
    image = gr.Image(label="image", visible=True)
    # depth = gr.Image(label="depth", visible=True)
    video = gr.Video(label="video", visible=False)
    logs = gr.Textbox(label="logging")

    def submit(prompt, iters, seed):
        """Optimise a voxel NeRF for ``prompt`` and stream progress to the UI.

        This is a generator event handler. It runs three phases:

        1. ``iters`` optimisation steps; after every step the current view is
           decoded and yielded as an image together with a log line.
        2. A test-pose sweep rendered with ``highres_render_one_view``; every
           frame is yielded as an image and collected.
        3. The collected frames are written to ``/tmp/tmp.mp4`` and a final
           update swaps the image for the video.

        Args:
            prompt: Text prompt forwarded to the ``SD`` model config.
            iters: Number of optimisation steps (value of the "Iters" slider).
            seed: Seed forwarded to ``seed_everything`` ("Seed" slider).

        Yields:
            dict mapping the ``image``, ``video`` and ``logs`` components to
            their new values / ``gr.update`` objects.
        """
        start_t = time.time()

        # Slider values may be delivered as floats; range() and the pose
        # sampler need a real integer step count, and seeding needs an int.
        n_steps = int(iters)
        seed_everything(int(seed))

        # cfgs = {'gddpm': {'model': 'm_lsun_256', 'lsun_cat': 'bedroom', 'imgnet_cat': -1}, 'sd': {'variant': 'v1', 'v2_highres': False, 'prompt': 'A high quality photo of a delicious burger', 'scale': 100.0, 'precision': 'autocast'}, 'lr': 0.05, 'n_steps': 10000, 'emptiness_scale': 10, 'emptiness_weight': 10000, 'emptiness_step': 0.5, 'emptiness_multiplier': 20.0, 'depth_weight': 0, 'var_red': True}
        pose = PoseConfig(rend_hw=64, FoV=60.0, R=1.5)
        poser = pose.make()
        sd_model = SD(variant='v1', v2_highres=False, prompt=prompt, scale=100.0, precision='autocast')
        model = sd_model.make()
        vox = VoxConfig(
            model_type="V_SD", grid_size=100, density_shift=-1.0, c=4,
            blend_bg_texture=True, bg_texture_hw=4,
            bbox_len=1.0)
        vox = vox.make()

        # Fixed hyper-parameters (same values as the commented-out cfgs above).
        lr = 0.05
        emptiness_scale = 10
        emptiness_weight = 10000
        emptiness_step = 0.5
        emptiness_multiplier = 20.0
        depth_weight = 0
        var_red = True

        assert model.samps_centered()
        _, target_H, target_W = model.data_shape()
        bs = 1
        aabb = vox.aabb.T.cpu().numpy()
        vox = vox.to(device_glb)
        opt = torch.optim.Adamax(vox.opt_params(), lr=lr)

        H, W = poser.H, poser.W
        Ks, poses, prompt_prefixes = poser.sample_train(n_steps)

        # Candidate noise levels: model.us without its first 30 / last 10 entries.
        ts = model.us[30:-10]

        # Not used below. Kept on purpose: drawing it advances the torch RNG,
        # so removing it would change the result obtained for a given seed.
        same_noise = torch.randn(1, 4, H, W, device=model.device).repeat(bs, 1, 1, 1)

        # ---- phase 1: optimisation ---------------------------------------
        with tqdm(total=n_steps) as pbar:
            for i in range(n_steps):
                p = f"{prompt_prefixes[i]} {model.prompt}"
                score_conds = model.prompts_emb([p])

                y, depth, ws = render_one_view(vox, aabb, H, W, Ks[i], poses[i], return_w=True)

                # Only non-StableDiffusion models need the render resized to
                # the model's data shape.
                if not isinstance(model, StableDiffusion):
                    y = torch.nn.functional.interpolate(y, (target_H, target_W), mode='bilinear')

                opt.zero_grad()

                # The denoiser is only evaluated, never differentiated.
                with torch.no_grad():
                    chosen_σs = np.random.choice(ts, bs, replace=False)
                    chosen_σs = chosen_σs.reshape(-1, 1, 1, 1)
                    chosen_σs = torch.as_tensor(chosen_σs, device=model.device, dtype=torch.float32)
                    # chosen_σs = us[i]

                    noise = torch.randn(bs, *y.shape[1:], device=model.device)

                    zs = y + chosen_σs * noise
                    Ds = model.denoise(zs, chosen_σs, **score_conds)

                    # var_red selects (Ds - y) instead of (Ds - zs).
                    if var_red:
                        grad = (Ds - y) / chosen_σs
                    else:
                        grad = (Ds - zs) / chosen_σs

                    grad = grad.mean(0, keepdim=True)

                # Push the externally computed gradient through the renderer.
                y.backward(-grad, retain_graph=True)

                if depth_weight > 0:
                    center_depth = depth[7:-7, 7:-7]
                    border_depth_mean = (depth.sum() - center_depth.sum()) / (64*64-50*50)
                    center_depth_mean = center_depth.mean()
                    depth_diff = center_depth_mean - border_depth_mean
                    depth_loss = - torch.log(depth_diff + 1e-12)
                    depth_loss = depth_weight * depth_loss
                    depth_loss.backward(retain_graph=True)

                emptiness_loss = torch.log(1 + emptiness_scale * ws).mean()
                emptiness_loss = emptiness_weight * emptiness_loss
                # From step emptiness_step * n_steps onwards the loss is scaled up.
                if emptiness_step * n_steps <= i:
                    emptiness_loss *= emptiness_multiplier
                emptiness_loss.backward()

                opt.step()

                # metric.put_scalars()

                # Build the UI update with gradients off, but yield only after
                # leaving the context: grad mode is thread-local, and a
                # generator suspended inside no_grad() would leave gradients
                # disabled on the thread that ran this step.
                with torch.no_grad():
                    if isinstance(model, StableDiffusion):
                        y = model.decode(y)
                    _, img, depth = vis_routine(y, depth)
                    step_log = f"Steps: {i}/{n_steps}: \n" + str(tsr_stats(y))

                yield {
                    image: gr.update(value=img, visible=True),
                    video: gr.update(visible=False),
                    logs: step_log,
                }

                # TODO: Output pane, img and depth to Gradio

                pbar.update()
                pbar.set_description(p)

        # TODO: Save Checkpoint

        # ---- phase 2: test-pose sweep --------------------------------------
        n_frames = 200
        factor = 4
        with torch.no_grad():
            H, W = poser.H, poser.W
            vox.eval()
            K, poses = poser.sample_test(n_frames)
            poses = poses[60:]  # skip the full overhead view; not interesting
            aabb = vox.aabb.T.cpu().numpy()
            vox = vox.to(device_glb)

        all_images = []
        for frame_idx in tqdm(range(len(poses))):
            # Same pattern as above: compute under no_grad, yield outside it.
            with torch.no_grad():
                y, depth = highres_render_one_view(vox, aabb, H, W, K, poses[frame_idx], f=factor)
                if isinstance(model, StableDiffusion):
                    y = model.decode(y)
                _, img, depth = vis_routine(y, depth)
                frame_log = str(tsr_stats(y))

            # Save img to output
            all_images.append(img)

            yield {
                image: gr.update(value=img, visible=True),
                video: gr.update(visible=False),
                logs: frame_log,
            }

        # ---- phase 3: write the video and show it --------------------------
        output_video = "/tmp/tmp.mp4"
        imageio.mimwrite(output_video, all_images, quality=8, fps=10)
        end_t = time.time()

        yield {
            image: gr.update(value=img, visible=False),
            video: gr.update(value=output_video, visible=True),
            logs: f"Generation Finished in {(end_t - start_t)/ 60:.4f} minutes!",
        }

    # Run `submit` when the button is pressed; its yielded dicts update the
    # three output components listed here.
    button.click(
        submit,
        [prompt, iters, seed],
        [image, video, logs]
    )
# Process one job at a time (concurrency_count=1): running several
# generations at once would exhaust GPU memory.
# NOTE(review): `concurrency_count` is the queue argument this file was
# written against; confirm it still exists in the installed Gradio version.
demo.queue(concurrency_count=1)
demo.launch()