vish26 committed
Commit 173da7a · verified
1 Parent(s): 69d845c

Update app/inference.py

Files changed (1):
  1. app/inference.py +4 -4
app/inference.py CHANGED
@@ -20,11 +20,11 @@ def generate_image_with_clip_score(prompt, num_inference_steps=50, guidance_scal
     model_path = os.path.join(model_path, "epoch_4")
 
     unet = UNet2DConditionModel.from_pretrained(os.path.join(model_path, "unet"), use_safetensors=True).to(device)
-    text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
-    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
-    vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse").to(device)
+    text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32",cache_dir="/tmp/huggingface").to(device)
+    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32",cache_dir="/tmp/huggingface")
+    vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse",cache_dir="/tmp/huggingface").to(device)
     scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule="linear")
-    clip = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
+    clip = CLIPModel.from_pretrained("openai/clip-vit-base-patch32",cache_dir="/tmp/huggingface").to(device)
 
     checkpoint = torch.load(os.path.join(model_path, "training_state.pth"), map_location=device)
     unet.load_state_dict(checkpoint['model_state_dict'])
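
The change adds a cache_dir argument so downloaded weights are stored under /tmp/huggingface, presumably because the default cache location is not writable in the deployment environment (a common situation in containerized Spaces). A minimal sketch of an equivalent approach is shown below, assuming the standard HF_HOME environment variable; this sketch is illustrative and not part of the commit.

# Minimal sketch (not part of the commit): redirect the Hugging Face cache
# globally via HF_HOME instead of passing cache_dir to every from_pretrained call.
import os

# Set before importing transformers/diffusers so the cache path is resolved
# from the new location; /tmp/huggingface mirrors the path used in the commit.
os.environ["HF_HOME"] = "/tmp/huggingface"

from transformers import CLIPTokenizer

# With HF_HOME pointing at a writable directory, no cache_dir argument is needed.
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")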