File size: 6,776 Bytes
7388dd9
0f407c0
 
 
 
 
 
 
7388dd9
 
0f407c0
7388dd9
 
 
 
0f407c0
7388dd9
 
 
 
0f407c0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7388dd9
0f407c0
 
 
7388dd9
0f407c0
7388dd9
 
 
 
 
 
a6bd83c
7388dd9
 
 
65dd249
 
 
 
 
 
 
 
 
 
 
6e63c31
 
 
 
 
 
 
 
 
0f407c0
 
 
9d225f0
e372698
6e63c31
 
65dd249
 
 
 
7388dd9
 
 
 
 
 
 
 
 
 
 
 
0f407c0
 
 
 
 
7388dd9
0f407c0
 
 
 
 
 
7388dd9
 
0f407c0
 
7388dd9
 
 
0f407c0
7388dd9
0f407c0
 
 
7388dd9
 
 
 
 
0f407c0
 
 
 
 
 
7388dd9
0f407c0
 
 
7388dd9
0f407c0
7388dd9
 
 
 
0f407c0
 
7388dd9
0f407c0
7388dd9
0f407c0
 
 
 
 
 
 
 
 
 
 
 
 
 
7388dd9
0f407c0
 
 
 
30f2c19
0f407c0
 
7388dd9
 
0f407c0
 
 
6e63c31
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
##TEST125 J'en peu plus de FLUX
import os
import gradio as gr
from huggingface_hub import login
from diffusers import FluxPipeline
import torch
from PIL import Image
import fitz  # PyMuPDF pour la gestion des PDF
import gc  # Pour le garbage collector
import psutil  # Pour monitorer la mémoire

# Configuration globale pour réduire l'utilisation de la mémoire
torch.set_default_device("cpu")
torch.set_num_threads(2)  # Limite le nombre de threads CPU
torch.set_grad_enabled(False)  # Désactive complètement le calcul des gradients

def get_memory_usage():
    """Retourne l'utilisation actuelle de la mémoire en GB"""
    process = psutil.Process(os.getpid())
    return process.memory_info().rss / 1024 / 1024 / 1024  # Conversion en GB

def load_pdf(pdf_path):
    """Traite le texte d'un fichier PDF"""
    if pdf_path is None:
        return None
    text = ""
    try:
        doc = fitz.open(pdf_path)
        for page in doc:
            text += page.get_text()
        doc.close()
        return text
    except Exception as e:
        print(f"Erreur lors de la lecture du PDF: {str(e)}")
        return None

class FluxGenerator:
    def __init__(self):
        self.token = os.getenv('Authentification_HF')
        if not self.token:
            raise ValueError("Token d'authentification HuggingFace non trouvé")
        login(self.token)
        self.pipeline = None
        self.device = "cpu"
        self.load_model()

    def load_model(self):
        """Charge le modèle FLUX avec des paramètres optimisés pour faible mémoire"""
        try:
            print("Chargement du modèle FLUX avec optimisations mémoire...")
            print(f"Mémoire utilisée avant chargement: {get_memory_usage():.2f} GB")

            # Configuration pour minimiser l'utilisation de la mémoire
            model_kwargs = {
                "low_cpu_mem_usage": True,
                "torch_dtype": torch.float8,  # Utilise float16 pour réduire la mémoire
                "use_safetensors": True,  # Utilise safetensors pour un chargement plus efficace
            }

            
#Erreur   Erreur lors du chargement du modèle: `device_map` must be a string.          
#
#            self.pipeline = FluxPipeline.from_pretrained(
#                "black-forest-labs/FLUX.1-schnell",
#                revision="refs/pr/1",
#                device_map={"": self.device},
#                **model_kwargs
#            )

            
#ERREUR `add_prefix_space`. The tokenizer needs to be converted from the slow tokenizers
#            self.pipeline = FluxPipeline.from_pretrained(
#                "black-forest-labs/FLUX.1-schnell",
#                revision="refs/pr/1",
#                device_map="balanced",  # Utilise une chaîne de caractères au lieu d'un dictionnaire
#                **model_kwargs
#            )        
            
            
            self.pipeline = FluxPipeline.from_pretrained(
                "black-forest-labs/FLUX.1-schnell",
                revision="refs/pr/1",
                device_map="balanced",
                torch_dtype=torch.float8,
                use_safetensors=True
)            
            
            
            
            

            # Optimisations supplémentaires
            self.pipeline.enable_sequential_cpu_offload()  # Décharge les composants non utilisés
            self.pipeline.enable_attention_slicing(slice_size=1)  # Réduit l'utilisation de la mémoire pendant l'inférence

            # Force le garbage collector
            gc.collect()
            torch.cuda.empty_cache() if torch.cuda.is_available() else None

            print(f"Mémoire utilisée après chargement: {get_memory_usage():.2f} GB")
            print("Modèle FLUX chargé avec succès en mode basse consommation!")

        except Exception as e:
            print(f"Erreur lors du chargement du modèle: {str(e)}")
            raise

    def generate_image(self, prompt, reference_image=None, pdf_file=None):
        """Génère une image avec paramètres optimisés pour la mémoire"""
        try:
            if pdf_file is not None:
                pdf_text = load_pdf(pdf_file)
                if pdf_text:
                    prompt = f"{prompt}\nContexte du PDF:\n{pdf_text}"

            # Paramètres optimisés pour réduire l'utilisation de la mémoire
            with torch.no_grad():
                image = self.pipeline(
                    prompt=prompt,
                    num_inference_steps=4,  # Minimum d'étapes pour économiser la mémoire
                    height=512,  # Taille réduite
                    width=512,   # Taille réduite
                    guidance_scale=0.0,
                    max_sequence_length=128,  # Réduit la longueur de séquence
                    generator=torch.Generator(device=self.device).manual_seed(0)
                ).images[0]

                # Force le nettoyage de la mémoire après génération
                gc.collect()
                torch.cuda.empty_cache() if torch.cuda.is_available() else None

                return image

        except Exception as e:
            print(f"Erreur lors de la génération de l'image: {str(e)}")
            return None

# Instance globale du générateur
generator = None  # On initialise plus tard pour économiser la mémoire

def generate(prompt, reference_file):
    """Fonction de génération pour l'interface Gradio"""
    global generator
    try:
        # Initialisation tardive du générateur
        if generator is None:
            generator = FluxGenerator()

        # Gestion du fichier de référence
        if reference_file is not None:
            if isinstance(reference_file, dict):
                file_path = reference_file.name
            else:
                file_path = reference_file
                
            file_type = file_path.split('.')[-1].lower()
            if file_type in ['pdf']:
                return generator.generate_image(prompt, pdf_file=file_path)
            elif file_type in ['png', 'jpg', 'jpeg']:
                return generator.generate_image(prompt, reference_image=file_path)

        return generator.generate_image(prompt)

    except Exception as e:
        print(f"Erreur détaillée: {str(e)}")
        return None

# Interface Gradio minimaliste
demo = gr.Interface(
    fn=generate,
    inputs=[
        gr.Textbox(label="Prompt", placeholder="Décrivez l'image que vous souhaitez générer..."),
        gr.File(label="Image ou PDF de référence (optionnel)", type="binary")
    ],
    outputs=gr.Image(label="Image générée"),
    title="FLUX (Mode économique)",
    description="Génération d'images optimisée pour systèmes à ressources limitées"
)

if __name__ == "__main__":
    demo.launch()