Commit 53f8aa7
Parent(s): b97ddc6
hf changes
app.py CHANGED
@@ -5,16 +5,14 @@ import torch
 from tqdm.auto import tqdm
 from PIL import Image
 import gradio as gr
-#from IPython.display import display
 
-tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14"
-text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14"
+tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
+text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
 
-
-
-unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet", torch_dtype=torch.float16)
+vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-ema")
+unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet")
 
-beta_start,beta_end = 0.00085,0.012
+beta_start, beta_end = 0.00085, 0.012
 height = 512
 width = 512
 num_inference_steps = 70
@@ -22,42 +20,40 @@ guidance_scale = 7.5
 batch_size = 1
 scheduler = LMSDiscreteScheduler(beta_start=beta_start, beta_end=beta_end, beta_schedule="scaled_linear", num_train_timesteps=1000)
 
-#prompt = ["a photograph of an astronaut riding a horse"]
-
 def text_enc(prompts, maxlen=None):
-    if maxlen is None:
+    if maxlen is None:
+        maxlen = tokenizer.model_max_length
     inp = tokenizer(prompts, padding="max_length", max_length=maxlen, truncation=True, return_tensors="pt")
-    input_ids = inp.input_ids
+    input_ids = inp.input_ids
+    input_ids = input_ids.to(torch.int)
     return text_encoder(input_ids)[0]
 
 def do_both(prompts):
     def mk_img(t):
         image = (t/2+0.5).clamp(0,1).detach().cpu().permute(1, 2, 0).numpy()
         return Image.fromarray((image*255).round().astype("uint8"))
-
+
     def mk_samples(prompts, g=7.5, seed=100, steps=70):
         bs = len(prompts)
         text = text_enc(prompts)
         uncond = text_enc([""] * bs, text.shape[1])
         emb = torch.cat([uncond, text])
-        if seed:
-
+        if seed:
+            torch.manual_seed(seed)
+
         latents = torch.randn((bs, unet.config.in_channels, height//8, width//8))
         scheduler.set_timesteps(steps)
         latents = latents.float() * scheduler.init_noise_sigma
-
+
         for i,ts in enumerate(tqdm(scheduler.timesteps)):
            inp = scheduler.scale_model_input(torch.cat([latents] * 2), ts)
            with torch.no_grad(): u,t = unet(inp, ts, encoder_hidden_states=emb).sample.chunk(2)
            pred = u + g*(t-u)
            latents = scheduler.step(pred, ts, latents).prev_sample
-
+
         with torch.no_grad(): return vae.decode(1 / 0.18215 * latents).sample
+
     images = mk_samples([prompts])
     for img in images: return(mk_img(img))
-
-
-# images = mk_samples(prompt)
-#iface = gr.Interface(fn=do_both, inputs=gr.inputs.Textbox(lines=2, label="Enter text prompt"), outputs=gr.outputs.Image(type="numpy", label="Generated Image")).launch()
-gr.Interface(do_both, gr.Text(), gr.Image(), title = 'Stable Diffusion model from scratch').launch(share = True, debug = True)
-# for img in images: display(mk_img(img))
+
+gr.Interface(do_both, gr.Text(), gr.Image(), title='Stable Diffusion model from scratch').launch(share=True, debug=True)
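
Both hunks start at line 5 of app.py, so the import block itself is never shown; only the "import torch" context in the first hunk header is visible. Below is a minimal sketch of what the first few lines presumably contain, inferred from the classes the script uses; it is an assumption, not part of this commit, and the exact import lines in the repository may differ.

# Hypothetical reconstruction of app.py lines 1-4 (inferred, not shown in the diff)
import torch
from transformers import CLIPTokenizer, CLIPTextModel
from diffusers import AutoencoderKL, UNet2DConditionModel, LMSDiscreteScheduler

With imports along those lines in place, the new right-hand side should at least parse: the old tokenizer and text_encoder lines appear to be missing their closing parentheses, which the new lines add. For a quick sanity check outside the Gradio interface, do_both can also be called directly once the script's definitions are loaded (the prompt below is just the example string from the removed comment):

# Example invocation, assuming app.py's models are already loaded; slow without a GPU.
img = do_both("a photograph of an astronaut riding a horse")
img.save("astronaut.png")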