markhristov committed · Commit 53f8aa7 · 1 Parent(s): b97ddc6

hf changes

Files changed (1):
  app.py (+18, -22)
app.py CHANGED
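Both hunks below start at line 5 of app.py, so the file's unchanged import header (lines 1–4) does not appear in this diff. Judging from the hunk context ("import torch") and the class names used further down, it presumably looks roughly like the sketch here; the exact grouping and line breaks are an assumption, not part of the commit.

# Assumed contents of the unchanged app.py header (lines 1-4), inferred from the
# names referenced in the diff below; not part of this commit.
import torch
from transformers import CLIPTokenizer, CLIPTextModel
from diffusers import AutoencoderKL, UNet2DConditionModel, LMSDiscreteScheduler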
@@ -5,16 +5,14 @@ import torch
 from tqdm.auto import tqdm
 from PIL import Image
 import gradio as gr
-#from IPython.display import display
 
-tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14", torch_dtype=torch.float16)
-text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14", torch_dtype=torch.float16)
+tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
+text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
 
-# Here we use a different VAE to the original release, which has been fine-tuned for more steps
-vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-ema", torch_dtype=torch.float16)
-unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet", torch_dtype=torch.float16)
+vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-ema")
+unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet")
 
-beta_start,beta_end = 0.00085,0.012
+beta_start, beta_end = 0.00085, 0.012
 height = 512
 width = 512
 num_inference_steps = 70
@@ -22,42 +20,40 @@ guidance_scale = 7.5
 batch_size = 1
 scheduler = LMSDiscreteScheduler(beta_start=beta_start, beta_end=beta_end, beta_schedule="scaled_linear", num_train_timesteps=1000)
 
-#prompt = ["a photograph of an astronaut riding a horse"]
-
 def text_enc(prompts, maxlen=None):
-    if maxlen is None: maxlen = tokenizer.model_max_length
+    if maxlen is None:
+        maxlen = tokenizer.model_max_length
     inp = tokenizer(prompts, padding="max_length", max_length=maxlen, truncation=True, return_tensors="pt")
-    input_ids = inp.input_ids.to(torch.long)
+    input_ids = inp.input_ids
+    input_ids = input_ids.to(torch.int)
     return text_encoder(input_ids)[0]
 
 def do_both(prompts):
     def mk_img(t):
         image = (t/2+0.5).clamp(0,1).detach().cpu().permute(1, 2, 0).numpy()
         return Image.fromarray((image*255).round().astype("uint8"))
-
+
     def mk_samples(prompts, g=7.5, seed=100, steps=70):
         bs = len(prompts)
         text = text_enc(prompts)
         uncond = text_enc([""] * bs, text.shape[1])
         emb = torch.cat([uncond, text])
-        if seed: torch.manual_seed(seed)
-
+        if seed:
+            torch.manual_seed(seed)
+
         latents = torch.randn((bs, unet.config.in_channels, height//8, width//8))
         scheduler.set_timesteps(steps)
         latents = latents.float() * scheduler.init_noise_sigma
-
+
         for i,ts in enumerate(tqdm(scheduler.timesteps)):
             inp = scheduler.scale_model_input(torch.cat([latents] * 2), ts)
             with torch.no_grad(): u,t = unet(inp, ts, encoder_hidden_states=emb).sample.chunk(2)
             pred = u + g*(t-u)
             latents = scheduler.step(pred, ts, latents).prev_sample
-
+
         with torch.no_grad(): return vae.decode(1 / 0.18215 * latents).sample
+
     images = mk_samples([prompts])
     for img in images: return(mk_img(img))
-
-# do_both(prompt)
-# images = mk_samples(prompt)
-#iface = gr.Interface(fn=do_both, inputs=gr.inputs.Textbox(lines=2, label="Enter text prompt"), outputs=gr.outputs.Image(type="numpy", label="Generated Image")).launch()
-gr.Interface(do_both, gr.Text(), gr.Image(), title = 'Stable Diffusion model from scratch').launch(share = True, debug = True)
-# for img in images: display(mk_img(img))
+
+gr.Interface(do_both, gr.Text(), gr.Image(), title='Stable Diffusion model from scratch').launch(share=True, debug=True)
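To try the updated Space locally, installing torch, transformers, diffusers, gradio, pillow, and tqdm (the set implied by the imports; versions are not pinned by this commit) and running `python app.py` should start the Gradio interface, since `gr.Interface(...).launch(share=True, debug=True)` is called at module level and blocks while serving.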