File size: 3,676 Bytes
81e6960
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import os
import gradio as gr
from huggingface_hub import login
from diffusers import FluxPipeline
import torch
from PIL import Image
import fitz  # PyMuPDF pour la gestion des PDF

def load_pdf(pdf_path):
    """Extrait le texte d'un fichier PDF"""
    if pdf_path is None:
        return None
    text = ""
    try:
        doc = fitz.open(pdf_path)
        for page in doc:
            text += page.get_text()
        doc.close()
        return text
    except Exception as e:
        print(f"Erreur lors de la lecture du PDF: {str(e)}")
        return None

class FluxGenerator:
    def __init__(self):
        self.token = os.getenv('Authentification_HF')
        if not self.token:
            raise ValueError("Token d'authentification HuggingFace non trouvé")
        login(self.token)
        self.pipeline = None
        self.load_model()

    def load_model(self):
        """Charge le modèle FLUX avec des paramètres optimisés"""
        try:
            print("Chargement du modèle FLUX...")
            self.pipeline = FluxPipeline.from_pretrained(
                "black-forest-labs/FLUX.1-schnell",
                revision="refs/pr/1",
                torch_dtype=torch.bfloat16
            )
            self.pipeline.enable_model_cpu_offload()
            self.pipeline.tokenizer.add_prefix_space = False
            print("Modèle FLUX chargé avec succès!")
        except Exception as e:
            print(f"Erreur lors du chargement du modèle: {str(e)}")
            raise

    def generate_image(self, prompt, reference_image=None, pdf_file=None):
        """Génère une image à partir d'un prompt et optionnellement une référence"""
        try:
            # Si un PDF est fourni, ajoute son contenu au prompt
            if pdf_file is not None:
                pdf_text = load_pdf(pdf_file)
                if pdf_text:
                    prompt = f"{prompt}\nContexte du PDF:\n{pdf_text}"

            # Génération de l'image
            image = self.pipeline(
                prompt=prompt,
                num_inference_steps=30,
                guidance_scale=0.0,
                max_sequence_length=256,
                generator=torch.Generator("cpu").manual_seed(0)
            ).images[0]

            return image

        except Exception as e:
            print(f"Erreur lors de la génération de l'image: {str(e)}")
            return None

# Instance globale du générateur
generator = FluxGenerator()

def generate(prompt, reference_file):
    """Fonction de génération pour l'interface Gradio"""
    try:
        # Détermine si le fichier de référence est une image ou un PDF
        if reference_file is not None:
            file_type = reference_file.name.split('.')[-1].lower()
            if file_type in ['pdf']:
                return generator.generate_image(prompt, pdf_file=reference_file.name)
            elif file_type in ['png', 'jpg', 'jpeg']:
                return generator.generate_image(prompt, reference_image=reference_file.name)
        
        # Génération sans référence
        return generator.generate_image(prompt)
    
    except Exception as e:
        print(f"Erreur: {str(e)}")
        return None

# Interface Gradio simple
demo = gr.Interface(
    fn=generate,
    inputs=[
        gr.Textbox(label="Prompt", placeholder="Décrivez l'image que vous souhaitez générer..."),
        gr.File(label="Image ou PDF de référence (optionnel)", type="file")
    ],
    outputs=gr.Image(label="Image générée"),
    title="Test du modèle FLUX",
    description="Interface simple pour tester la génération d'images avec FLUX"
)

if __name__ == "__main__":
    demo.launch()