MisterAI commited on
Commit
0f407c0
·
verified ·
1 Parent(s): 7b4d814

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +305 -0
app.py ADDED
@@ -0,0 +1,305 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ###TEST03 JUSTE CHARGER FLUX-SCHNELL
2
+ ###https://huggingface.co/spaces/black-forest-labs/FLUX.1-schnell/blob/main/app.py
3
+ ###
4
+ import os
5
+ import gradio as gr
6
+ from huggingface_hub import login
7
+ from diffusers import FluxPipeline
8
+ import torch
9
+ from PIL import Image
10
+ import fitz # PyMuPDF pour la gestion des PDF
11
+ import sentencepiece
12
+
13
+ import numpy as np
14
+ import random
15
+ import spaces
16
+
17
+
18
+
19
+
20
+
21
+ #
22
+ #import gradio as gr
23
+ #import numpy as np
24
+ #import random
25
+ #import spaces
26
+ #import torch
27
+ #from diffusers import DiffusionPipeline
28
+ #
29
+ #dtype = torch.bfloat16
30
+ #device = "cuda" if torch.cuda.is_available() else "cpu"
31
+ #
32
+ #pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=dtype).to(device)
33
+ #
34
+ #MAX_SEED = np.iinfo(np.int32).max
35
+ #MAX_IMAGE_SIZE = 2048
36
+ #
37
+ #@spaces.GPU()
38
+ #def infer(prompt, seed=42, randomize_seed=False, width=1024, height=1024, num_inference_steps=4, progress=gr.Progress(track_tqdm=True)):
39
+ # if randomize_seed:
40
+ # seed = random.randint(0, MAX_SEED)
41
+ # generator = torch.Generator().manual_seed(seed)
42
+ # image = pipe(
43
+ # prompt = prompt,
44
+ # width = width,
45
+ # height = height,
46
+ # num_inference_steps = num_inference_steps,
47
+ # generator = generator,
48
+ # guidance_scale=0.0
49
+ # ).images[0]
50
+ # return image, seed
51
+ #
52
+ #examples = [
53
+ # "a tiny astronaut hatching from an egg on the moon",
54
+ # "a cat holding a sign that says hello world",
55
+ # "an anime illustration of a wiener schnitzel",
56
+ #]
57
+ #
58
+ #css="""
59
+ ##col-container {
60
+ # margin: 0 auto;
61
+ # max-width: 520px;
62
+ #}
63
+ #"""
64
+ #
65
+ #with gr.Blocks(css=css) as demo:
66
+ #
67
+ # with gr.Column(elem_id="col-container"):
68
+ # gr.Markdown(f"""# FLUX.1 [schnell]
69
+ #12B param rectified flow transformer distilled from [FLUX.1 [pro]](https://blackforestlabs.ai/) for 4 step generation
70
+ #[[blog](https://blackforestlabs.ai/announcing-black-forest-labs/)] [[model](https://huggingface.co/black-forest-labs/FLUX.1-schnell)]
71
+ # """)
72
+ #
73
+ # with gr.Row():
74
+ #
75
+ # prompt = gr.Text(
76
+ # label="Prompt",
77
+ # show_label=False,
78
+ # max_lines=1,
79
+ # placeholder="Enter your prompt",
80
+ # container=False,
81
+ # )
82
+ #
83
+ # run_button = gr.Button("Run", scale=0)
84
+ #
85
+ # result = gr.Image(label="Result", show_label=False)
86
+ #
87
+ # with gr.Accordion("Advanced Settings", open=False):
88
+ #
89
+ # seed = gr.Slider(
90
+ # label="Seed",
91
+ # minimum=0,
92
+ # maximum=MAX_SEED,
93
+ # step=1,
94
+ # value=0,
95
+ # )
96
+ #
97
+ # randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
98
+ #
99
+ # with gr.Row():
100
+ #
101
+ # width = gr.Slider(
102
+ # label="Width",
103
+ # minimum=256,
104
+ # maximum=MAX_IMAGE_SIZE,
105
+ # step=32,
106
+ # value=1024,
107
+ # )
108
+ #
109
+ # height = gr.Slider(
110
+ # label="Height",
111
+ # minimum=256,
112
+ # maximum=MAX_IMAGE_SIZE,
113
+ # step=32,
114
+ # value=1024,
115
+ # )
116
+ #
117
+ # with gr.Row():
118
+ #
119
+ #
120
+ # num_inference_steps = gr.Slider(
121
+ # label="Number of inference steps",
122
+ # minimum=1,
123
+ # maximum=50,
124
+ # step=1,
125
+ # value=4,
126
+ # )
127
+ #
128
+ # gr.Examples(
129
+ # examples = examples,
130
+ # fn = infer,
131
+ # inputs = [prompt],
132
+ # outputs = [result, seed],
133
+ # cache_examples="lazy"
134
+ # )
135
+ #
136
+ # gr.on(
137
+ # triggers=[run_button.click, prompt.submit],
138
+ # fn = infer,
139
+ # inputs = [prompt, seed, randomize_seed, width, height, num_inference_steps],
140
+ # outputs = [result, seed]
141
+ # )
142
+ #
143
+ #demo.launch()
144
+ #
145
+ #
146
+
147
+
148
+
149
+
150
+
151
+
152
+
153
+
154
+
155
+
156
+
157
+
158
+
159
+
160
+
161
+
162
+
163
+
164
+
165
+
166
+
167
+
168
+
169
+
170
+
171
+
172
+ # Force l'utilisation du CPU pour tout PyTorch
173
+ #torch.set_default_device("cpu")
174
+
175
+
176
+ #dtype = torch.bfloat16
177
+ device = "cuda" if torch.cuda.is_available() else "cpu"
178
+ #
179
+ #pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=dtype).to(device)
180
+
181
+
182
+
183
+
184
+
185
+
186
+
187
+
188
+
189
+
190
+
191
+
192
+ def load_pdf(pdf_path):
193
+ """Traite le texte d'un fichier PDF"""
194
+ if pdf_path is None:
195
+ return None
196
+ text = ""
197
+ try:
198
+ doc = fitz.open(pdf_path)
199
+ for page in doc:
200
+ text += page.get_text()
201
+ doc.close()
202
+ return text
203
+ except Exception as e:
204
+ print(f"Erreur lors de la lecture du PDF: {str(e)}")
205
+ return None
206
+
207
+ class FluxGenerator:
208
+ def __init__(self):
209
+ self.token = os.getenv('Authentification_HF')
210
+ if not self.token:
211
+ raise ValueError("Token d'authentification HuggingFace non trouvé")
212
+ login(self.token)
213
+ self.pipeline = None
214
+ self.device = "cpu" # Force l'utilisation du CPU
215
+ self.load_model()
216
+
217
+ def load_model(self):
218
+ """Charge le modèle FLUX avec des paramètres optimisés pour CPU"""
219
+ try:
220
+ print("Chargement du modèle FLUX sur CPU...")
221
+ # Configuration spécifique pour CPU
222
+ torch.set_grad_enabled(False) # Désactive le calcul des gradients
223
+
224
+ self.pipeline = FluxPipeline.from_pretrained(
225
+ "black-forest-labs/FLUX.1-schnell",
226
+ revision="refs/pr/1",
227
+ torch_dtype=torch.float32 # Utilise float32 au lieu de bfloat16 pour meilleure compatibilité CPU
228
+ )
229
+ # device_map={"cpu": self.device} # Force tous les composants sur CPU
230
+ # )device
231
+
232
+ # Désactive les optimisations GPU
233
+ self.pipeline.to(self.device)
234
+ print(f"Utilisation forcée du CPU")
235
+ print("Modèle FLUX chargé avec succès!")
236
+
237
+ except Exception as e:
238
+ print(f"Erreur lors du chargement du modèle: {str(e)}")
239
+ raise
240
+
241
+ def generate_image(self, prompt, reference_image=None, pdf_file=None):
242
+ """Génère une image à partir d'un prompt et optionnellement une référence"""
243
+ try:
244
+ # Si un PDF est fourni, ajoute son contenu au prompt
245
+ if pdf_file is not None:
246
+ pdf_text = load_pdf(pdf_file)
247
+ if pdf_text:
248
+ prompt = f"{prompt}\nContexte du PDF:\n{pdf_text}"
249
+
250
+ # Configuration pour génération sur CPU
251
+ with torch.no_grad(): # Désactive le calcul des gradients pendant la génération
252
+ image = self.pipeline(
253
+ prompt=prompt,
254
+ num_inference_steps=20, # Réduit le nombre d'étapes pour accélérer sur CPU
255
+ guidance_scale=0.0,
256
+ max_sequence_length=256,
257
+ generator=torch.Generator(device=self.device).manual_seed(0)
258
+ ).images[0]
259
+
260
+ return image
261
+
262
+ except Exception as e:
263
+ print(f"Erreur lors de la génération de l'image: {str(e)}")
264
+ return None
265
+
266
+ # Instance globale du générateur
267
+ generator = FluxGenerator()
268
+
269
+ def generate(prompt, reference_file):
270
+ """Fonction de génération pour l'interface Gradio"""
271
+ try:
272
+ # Gestion du fichier de référence
273
+ if reference_file is not None:
274
+ if isinstance(reference_file, dict): # Si le fichier est fourni par Gradio
275
+ file_path = reference_file.name
276
+ else: # Si c'est un chemin direct
277
+ file_path = reference_file
278
+
279
+ file_type = file_path.split('.')[-1].lower()
280
+ if file_type in ['pdf']:
281
+ return generator.generate_image(prompt, pdf_file=file_path)
282
+ elif file_type in ['png', 'jpg', 'jpeg']:
283
+ return generator.generate_image(prompt, reference_image=file_path)
284
+
285
+ # Génération sans référence
286
+ return generator.generate_image(prompt)
287
+
288
+ except Exception as e:
289
+ print(f"Erreur détaillée: {str(e)}")
290
+ return None
291
+
292
+ # Interface Gradio simple
293
+ demo = gr.Interface(
294
+ fn=generate,
295
+ inputs=[
296
+ gr.Textbox(label="Prompt", placeholder="Décrivez l'image que vous souhaitez générer..."),
297
+ gr.File(label="Image ou PDF de référence (optionnel)", type="file")
298
+ ],
299
+ outputs=gr.Image(label="Image générée"),
300
+ title="Test du modèle FLUX (CPU)",
301
+ description="Interface simple pour tester la génération d'images avec FLUX (optimisé pour CPU)"
302
+ )
303
+
304
+ if __name__ == "__main__":
305
+ demo.launch()