File size: 4,666 Bytes
bfdb818
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f8318bc
bfdb818
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
###TEST02 JUSTE CHARGER FLUX-SCHNELL
###
###
import os
import gradio as gr
from huggingface_hub import login
from diffusers import FluxPipeline
import torch
from PIL import Image
import fitz  # PyMuPDF pour la gestion des PDF
import sentencepiece

# Force l'utilisation du CPU pour tout PyTorch
torch.set_default_device("cpu")

def load_pdf(pdf_path):
    """Traite le texte d'un fichier PDF"""
    if pdf_path is None:
        return None
    text = ""
    try:
        doc = fitz.open(pdf_path)
        for page in doc:
            text += page.get_text()
        doc.close()
        return text
    except Exception as e:
        print(f"Erreur lors de la lecture du PDF: {str(e)}")
        return None

class FluxGenerator:
    def __init__(self):
        self.token = os.getenv('Authentification_HF')
        if not self.token:
            raise ValueError("Token d'authentification HuggingFace non trouvé")
        login(self.token)
        self.pipeline = None
        self.device = "cpu"  # Force l'utilisation du CPU
        self.load_model()

    def load_model(self):
        """Charge le modèle FLUX avec des paramètres optimisés pour CPU"""
        try:
            print("Chargement du modèle FLUX sur CPU...")
            # Configuration spécifique pour CPU
            torch.set_grad_enabled(False)  # Désactive le calcul des gradients
            
            self.pipeline = FluxPipeline.from_pretrained(
                "black-forest-labs/FLUX.1-schnell",
                revision="refs/pr/1",
                torch_dtype=torch.float32,  # Utilise float32 au lieu de bfloat16 pour meilleure compatibilité CPU
                device_map={"auto": self.device}  # Force tous les composants sur CPU
            )
            
            # Désactive les optimisations GPU
            self.pipeline.to(self.device)
            print(f"Utilisation forcée du CPU")
            print("Modèle FLUX chargé avec succès!")
            
        except Exception as e:
            print(f"Erreur lors du chargement du modèle: {str(e)}")
            raise

    def generate_image(self, prompt, reference_image=None, pdf_file=None):
        """Génère une image à partir d'un prompt et optionnellement une référence"""
        try:
            # Si un PDF est fourni, ajoute son contenu au prompt
            if pdf_file is not None:
                pdf_text = load_pdf(pdf_file)
                if pdf_text:
                    prompt = f"{prompt}\nContexte du PDF:\n{pdf_text}"

            # Configuration pour génération sur CPU
            with torch.no_grad():  # Désactive le calcul des gradients pendant la génération
                image = self.pipeline(
                    prompt=prompt,
                    num_inference_steps=20,  # Réduit le nombre d'étapes pour accélérer sur CPU
                    guidance_scale=0.0,
                    max_sequence_length=256,
                    generator=torch.Generator(device=self.device).manual_seed(0)
                ).images[0]

            return image

        except Exception as e:
            print(f"Erreur lors de la génération de l'image: {str(e)}")
            return None

# Instance globale du générateur
generator = FluxGenerator()

def generate(prompt, reference_file):
    """Fonction de génération pour l'interface Gradio"""
    try:
        # Gestion du fichier de référence
        if reference_file is not None:
            if isinstance(reference_file, dict):  # Si le fichier est fourni par Gradio
                file_path = reference_file.name
            else:  # Si c'est un chemin direct
                file_path = reference_file
                
            file_type = file_path.split('.')[-1].lower()
            if file_type in ['pdf']:
                return generator.generate_image(prompt, pdf_file=file_path)
            elif file_type in ['png', 'jpg', 'jpeg']:
                return generator.generate_image(prompt, reference_image=file_path)

        # Génération sans référence
        return generator.generate_image(prompt)

    except Exception as e:
        print(f"Erreur détaillée: {str(e)}")
        return None

# Interface Gradio simple
demo = gr.Interface(
    fn=generate,
    inputs=[
        gr.Textbox(label="Prompt", placeholder="Décrivez l'image que vous souhaitez générer..."),
        gr.File(label="Image ou PDF de référence (optionnel)", type="file")
    ],
    outputs=gr.Image(label="Image générée"),
    title="Test du modèle FLUX (CPU)",
    description="Interface simple pour tester la génération d'images avec FLUX (optimisé pour CPU)"
)

if __name__ == "__main__":
    demo.launch()