AINovice2005 commited on
Commit
eefcbf8
·
verified ·
1 Parent(s): b5e0ea6

Add a Basic Inference Script for the model

Browse files
Files changed (1) hide show
  1. README.md +34 -0
README.md CHANGED
@@ -48,6 +48,40 @@ python ./inference.py --model_type fast
48
  ```
49
  > **Note:** The inference script will automatically download `meta-llama/Meta-Llama-3.1-8B-Instruct` model files. If you encounter network issues, you can download these files ahead of time and place them in the appropriate cache directory to avoid download failures during inference.
50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  ## Gradio Demo
52
 
53
  We also provide a Gradio demo for interactive image generation. You can run the demo with:
 
48
  ```
49
  > **Note:** The inference script will automatically download `meta-llama/Meta-Llama-3.1-8B-Instruct` model files. If you encounter network issues, you can download these files ahead of time and place them in the appropriate cache directory to avoid download failures during inference.
50
 
51
+ ## Basic Inference Script
52
+
53
+ ```python
54
+ import torch
55
+ from transformers import PreTrainedTokenizerFast, LlamaForCausalLM
56
+ from diffusers import HiDreamImagePipeline
57
+ tokenizer_4 = PreTrainedTokenizerFast.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")
58
+ text_encoder_4 = LlamaForCausalLM.from_pretrained(
59
+ "meta-llama/Meta-Llama-3.1-8B-Instruct",
60
+ output_hidden_states=True,
61
+ output_attentions=True,
62
+ torch_dtype=torch.bfloat16,
63
+ )
64
+
65
+ pipe = HiDreamImagePipeline.from_pretrained(
66
+ "HiDream-ai/HiDream-I1-Dev", # "HiDream-ai/HiDream-I1-Dev" | "HiDream-ai/HiDream-I1-Fast"
67
+ tokenizer_4=tokenizer_4,
68
+ text_encoder_4=text_encoder_4,
69
+ torch_dtype=torch.bfloat16,
70
+ )
71
+
72
+ pipe = pipe.to('cuda')
73
+
74
+ image = pipe(
75
+ 'A cat holding a sign that says "HiDream.ai".',
76
+ height=1024,
77
+ width=1024,
78
+ guidance_scale=5.0, # 0.0 for Dev&Fast
79
+ num_inference_steps=50, # 28 for Dev and 16 for Fast
80
+ generator=torch.Generator("cuda").manual_seed(0),
81
+ ).images[0]
82
+ image.save("output.png")
83
+ ```
84
+
85
  ## Gradio Demo
86
 
87
  We also provide a Gradio demo for interactive image generation. You can run the demo with: