MisterAI commited on
Commit
bfdb818
·
verified ·
1 Parent(s): 789b115

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +128 -0
app.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ###TEST02 JUSTE CHARGER FLUX-SCHNELL
2
+ ###
3
+ ###
4
+ import os
5
+ import gradio as gr
6
+ from huggingface_hub import login
7
+ from diffusers import FluxPipeline
8
+ import torch
9
+ from PIL import Image
10
+ import fitz # PyMuPDF pour la gestion des PDF
11
+ import sentencepiece
12
+
13
+ # Force l'utilisation du CPU pour tout PyTorch
14
+ torch.set_default_device("cpu")
15
+
16
+ def load_pdf(pdf_path):
17
+ """Traite le texte d'un fichier PDF"""
18
+ if pdf_path is None:
19
+ return None
20
+ text = ""
21
+ try:
22
+ doc = fitz.open(pdf_path)
23
+ for page in doc:
24
+ text += page.get_text()
25
+ doc.close()
26
+ return text
27
+ except Exception as e:
28
+ print(f"Erreur lors de la lecture du PDF: {str(e)}")
29
+ return None
30
+
31
+ class FluxGenerator:
32
+ def __init__(self):
33
+ self.token = os.getenv('Authentification_HF')
34
+ if not self.token:
35
+ raise ValueError("Token d'authentification HuggingFace non trouvé")
36
+ login(self.token)
37
+ self.pipeline = None
38
+ self.device = "cpu" # Force l'utilisation du CPU
39
+ self.load_model()
40
+
41
+ def load_model(self):
42
+ """Charge le modèle FLUX avec des paramètres optimisés pour CPU"""
43
+ try:
44
+ print("Chargement du modèle FLUX sur CPU...")
45
+ # Configuration spécifique pour CPU
46
+ torch.set_grad_enabled(False) # Désactive le calcul des gradients
47
+
48
+ self.pipeline = FluxPipeline.from_pretrained(
49
+ "black-forest-labs/FLUX.1-schnell",
50
+ revision="refs/pr/1",
51
+ torch_dtype=torch.float32, # Utilise float32 au lieu de bfloat16 pour meilleure compatibilité CPU
52
+ device_map={"": self.device} # Force tous les composants sur CPU
53
+ )
54
+
55
+ # Désactive les optimisations GPU
56
+ self.pipeline.to(self.device)
57
+ print(f"Utilisation forcée du CPU")
58
+ print("Modèle FLUX chargé avec succès!")
59
+
60
+ except Exception as e:
61
+ print(f"Erreur lors du chargement du modèle: {str(e)}")
62
+ raise
63
+
64
+ def generate_image(self, prompt, reference_image=None, pdf_file=None):
65
+ """Génère une image à partir d'un prompt et optionnellement une référence"""
66
+ try:
67
+ # Si un PDF est fourni, ajoute son contenu au prompt
68
+ if pdf_file is not None:
69
+ pdf_text = load_pdf(pdf_file)
70
+ if pdf_text:
71
+ prompt = f"{prompt}\nContexte du PDF:\n{pdf_text}"
72
+
73
+ # Configuration pour génération sur CPU
74
+ with torch.no_grad(): # Désactive le calcul des gradients pendant la génération
75
+ image = self.pipeline(
76
+ prompt=prompt,
77
+ num_inference_steps=20, # Réduit le nombre d'étapes pour accélérer sur CPU
78
+ guidance_scale=0.0,
79
+ max_sequence_length=256,
80
+ generator=torch.Generator(device=self.device).manual_seed(0)
81
+ ).images[0]
82
+
83
+ return image
84
+
85
+ except Exception as e:
86
+ print(f"Erreur lors de la génération de l'image: {str(e)}")
87
+ return None
88
+
89
+ # Instance globale du générateur
90
+ generator = FluxGenerator()
91
+
92
+ def generate(prompt, reference_file):
93
+ """Fonction de génération pour l'interface Gradio"""
94
+ try:
95
+ # Gestion du fichier de référence
96
+ if reference_file is not None:
97
+ if isinstance(reference_file, dict): # Si le fichier est fourni par Gradio
98
+ file_path = reference_file.name
99
+ else: # Si c'est un chemin direct
100
+ file_path = reference_file
101
+
102
+ file_type = file_path.split('.')[-1].lower()
103
+ if file_type in ['pdf']:
104
+ return generator.generate_image(prompt, pdf_file=file_path)
105
+ elif file_type in ['png', 'jpg', 'jpeg']:
106
+ return generator.generate_image(prompt, reference_image=file_path)
107
+
108
+ # Génération sans référence
109
+ return generator.generate_image(prompt)
110
+
111
+ except Exception as e:
112
+ print(f"Erreur détaillée: {str(e)}")
113
+ return None
114
+
115
+ # Interface Gradio simple
116
+ demo = gr.Interface(
117
+ fn=generate,
118
+ inputs=[
119
+ gr.Textbox(label="Prompt", placeholder="Décrivez l'image que vous souhaitez générer..."),
120
+ gr.File(label="Image ou PDF de référence (optionnel)", type="file")
121
+ ],
122
+ outputs=gr.Image(label="Image générée"),
123
+ title="Test du modèle FLUX (CPU)",
124
+ description="Interface simple pour tester la génération d'images avec FLUX (optimisé pour CPU)"
125
+ )
126
+
127
+ if __name__ == "__main__":
128
+ demo.launch()