File size: 17,745 Bytes
e547b24
 
 
 
 
 
d6d5545
e547b24
 
fb64c67
 
d6d5545
 
e547b24
 
 
fb64c67
d6d5545
e547b24
fb64c67
 
d6d5545
 
fb64c67
 
e547b24
fb64c67
a2f84cc
fb64c67
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1b6cf25
d6d5545
1b6cf25
d6d5545
1b6cf25
 
 
 
 
d6d5545
1b6cf25
 
 
d6d5545
fb64c67
d6d5545
 
fb64c67
 
1b6cf25
fb64c67
e547b24
1b6cf25
 
fb64c67
 
93a3ca8
fb64c67
 
 
1b6cf25
 
 
93a3ca8
fb64c67
1b6cf25
 
 
 
 
 
 
 
 
 
 
 
 
93a3ca8
 
 
1b6cf25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93a3ca8
 
1b6cf25
 
 
 
 
 
93a3ca8
 
1b6cf25
d6d5545
1b6cf25
 
 
93a3ca8
 
 
 
1b6cf25
 
9dc4469
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fb64c67
e547b24
fb64c67
1b6cf25
d6d5545
9dc4469
 
2b5053c
4d6cbec
e547b24
 
fb64c67
e547b24
fb64c67
1b6cf25
 
 
fb64c67
 
 
1b6cf25
fb64c67
e547b24
1b6cf25
 
 
 
93a3ca8
 
1b6cf25
fb64c67
 
 
 
 
1b6cf25
 
 
93a3ca8
 
1b6cf25
fb64c67
 
 
 
 
1b6cf25
fb64c67
 
 
 
1b6cf25
 
 
 
93a3ca8
 
1b6cf25
 
fb64c67
d6d5545
877174d
fb64c67
a2f84cc
fb64c67
 
57e6777
 
fb64c67
 
 
 
 
d6d5545
1b6cf25
 
93a3ca8
 
1b6cf25
fb64c67
 
 
1b6cf25
 
 
93a3ca8
 
1b6cf25
 
 
d6d5545
fb64c67
1b6cf25
 
 
93a3ca8
 
1b6cf25
fb64c67
 
 
d6d5545
fb64c67
1b6cf25
 
 
fb64c67
 
 
1b6cf25
 
fb64c67
 
1b6cf25
 
 
 
fb64c67
1b6cf25
 
 
 
 
93a3ca8
 
1b6cf25
 
d6d5545
1b6cf25
 
 
d6d5545
 
 
 
 
1b6cf25
fb64c67
e547b24
1b6cf25
 
 
93a3ca8
 
1b6cf25
 
d6d5545
e547b24
fb64c67
d6d5545
e547b24
02f8cfa
4d6cbec
02f8cfa
 
73f7edc
b64fadc
 
 
d6d5545
 
 
fb64c67
e547b24
 
d6d5545
02f8cfa
 
d6d5545
02f8cfa
 
 
d6d5545
 
 
 
 
 
02f8cfa
 
d6d5545
 
 
 
 
 
 
4d6cbec
 
 
620782f
9e6fc26
fb64c67
e547b24
d6d5545
02f8cfa
 
fb64c67
 
02f8cfa
0075dfa
fb64c67
 
 
d6d5545
fb64c67
 
 
 
 
e547b24
d6d5545
fb64c67
 
 
1b6cf25
d6d5545
 
fb64c67
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
import gradio as gr
import requests
import io
import random
import os
import time
from PIL import Image, UnidentifiedImageError # Added UnidentifiedImageError
from deep_translator import GoogleTranslator
import json
import uuid
from urllib.parse import quote
import traceback # For detailed error logging
from openai import OpenAI, RateLimitError, APIConnectionError, AuthenticationError # Added OpenAI errors

# Project by Nymbo

# --- Constants ---
API_URL = "https://api-inference.huggingface.co/models/stabilityai/stable-diffusion-3.5-large-turbo"
API_TOKEN = os.getenv("HF_READ_TOKEN")
if not API_TOKEN:
    # Deliberately not fatal at import: query() checks `headers` and returns a
    # configuration error to the user instead of calling the API without a token.
    print("WARNING: HF_READ_TOKEN environment variable not set. API calls may fail.")
headers = {"Authorization": f"Bearer {API_TOKEN}"} if API_TOKEN else {}
timeout = 100  # seconds for API call timeout

IMAGE_DIR = "temp_generated_images"  # Directory to store temporary images
ARINTELI_REDIRECT_BASE = "https://arinteli.com/app/"  # Your redirector URL

# --- Ensure temporary directory exists ---
try:
    os.makedirs(IMAGE_DIR, exist_ok=True)
except OSError as e:
    print(f"ERROR: Could not create directory {IMAGE_DIR}: {e}")
    # Fatal: every generated image is saved into this directory. gr.Error is
    # meant for event handlers, so raise a plain exception at startup and chain
    # the OSError so the real cause shows in the traceback.
    raise RuntimeError(f"Fatal Error: Cannot create temporary image directory: {e}") from e
print(f"Confirmed temporary image directory exists: {IMAGE_DIR}")

# --- Get Absolute Path for allowed_paths ---
# This needs to be done *before* calling launch(); Gradio only serves files
# from directories passed in allowed_paths.
absolute_image_dir = os.path.abspath(IMAGE_DIR)
print(f"Absolute path for allowed_paths: {absolute_image_dir}")

# --- OpenAI Client ---
# Moderation is optional. If no client can be built, openai_client stays None
# and query() skips the moderation step rather than failing.
openai_api_key = os.getenv("OPENAI_API_KEY")
openai_client = None
if not openai_api_key:
    print("WARNING: OPENAI_API_KEY environment variable not set or empty. Moderation will be skipped.")
else:
    try:
        # OpenAI() reads OPENAI_API_KEY from the environment by itself.
        openai_client = OpenAI()
        print("OpenAI client initialized successfully.")
    except Exception as init_err:
        print(f"ERROR: Failed to initialize OpenAI client: {init_err}. Moderation will be skipped.")
        openai_client = None

# --- Function to query the API and return the generated image and download link ---
def query(prompt, negative_prompt, steps=4, cfg_scale=0, seed=-1, width=1024, height=1024):
    # Basic Input Validation
    if not prompt or not prompt.strip():
        print("WARNING: Empty prompt received.\n") # Add newline for separation
        return None, "<p style='color: orange; text-align: center;'>Please enter a prompt.</p>"

    # Store original prompt for logging
    original_prompt = prompt

    # Translation
    translated_prompt = prompt # Start with original
    try:
        translated_prompt = GoogleTranslator(source='auto', target='en').translate(prompt)
    except Exception as e:
        print(f"WARNING: Translation failed. Using original prompt.")
        print(f"  Error: {e}")
        print(f"  Prompt: '{original_prompt}'\n") # Add newline
        translated_prompt = prompt # Ensure it's the original if translation fails

    # --- OpenAI Moderation Check ---
    if openai_client:
        try:
            mod_response = openai_client.moderations.create(
                model="omni-moderation-latest",
                input=translated_prompt
            )
            result = mod_response.results[0]

            if result.categories.sexual_minors:
                print("BLOCKED:")
                print(f"  Reason: sexual/minors")
                print(f"  Prompt: '{original_prompt}'")
                if translated_prompt != original_prompt:
                    print(f"  Translated: '{translated_prompt}'")
                print("") # Add newline
                return None, "<p style='color: red; text-align: center;'>Prompt violates safety guidelines. Generation blocked.</p>"

        except AuthenticationError:
             print("BLOCKED:")
             print(f"  Reason: OpenAI Auth Error")
             print(f"  Prompt: '{original_prompt}'\n") # Add newline
             return None, "<p style='color: red; text-align: center;'>Safety check failed. Generation blocked.</p>"
        except RateLimitError:
             print("BLOCKED:")
             print(f"  Reason: OpenAI Rate Limit")
             print(f"  Prompt: '{original_prompt}'\n") # Add newline
             return None, "<p style='color: red; text-align: center;'>Safety check failed. Please try again later.</p>"
        except APIConnectionError as e:
             print("BLOCKED:")
             print(f"  Reason: OpenAI Connection Error")
             print(f"  Prompt: '{original_prompt}'")
             if translated_prompt != original_prompt:
                 print(f"  Translated: '{translated_prompt}'")
             print(f"  Error: {e}\n") # Add newline
             return None, "<p style='color: red; text-align: center;'>Safety check failed. Please try again later.</p>"
        except Exception as e:
            print("BLOCKED:")
            print(f"  Reason: OpenAI Unexpected Error")
            print(f"  Prompt: '{original_prompt}'")
            if translated_prompt != original_prompt:
                print(f"  Translated: '{translated_prompt}'")
            print(f"  Error: {e}\n") # Add newline
            traceback.print_exc() # Keep traceback for unexpected errors
            return None, "<p style='color: red; text-align: center;'>An unexpected error occurred during safety check. Generation blocked.</p>"
    else:
        print(f"WARNING: OpenAI client not available. Skipping moderation.")
        print(f"  Prompt: '{original_prompt}'") # Log original prompt even if skipping
        if translated_prompt != original_prompt:
             print(f"  (Would use translated: '{translated_prompt}')")
        print("") # Add newline

    # --- Proceed with Generation ---
    # Define a threshold for modifying prompts based on sexual content score
    SEXUAL_SCORE_THRESHOLD = 0.6
    prompt_suffix = " | ultra detail, ultra elaboration, ultra quality, perfect" # Base quality enhancers
    negative_prompt_addition = "" # Start with no additions

    # Check moderation scores if available and adjust prompts
    if openai_client and 'mod_response' in locals(): # Check if moderation ran and response exists
        result = mod_response.results[0]
        if hasattr(result, 'category_scores') and result.category_scores.sexual > SEXUAL_SCORE_THRESHOLD:
            print(f"INFO: High 'sexual' score detected ({result.category_scores.sexual:.4f}). Modifying prompts.")
            prompt_suffix += "adult face, mature face, grown-up, of age, " 
            negative_prompt_addition = "child, minor, underage, teenager, "

    final_prompt = f"{prompt_suffix}{translated_prompt}"
    final_negative_prompt = f"{negative_prompt_addition}{negative_prompt}".strip(', ') # Combine and clean up potential leading comma

    payload = {
        "inputs": final_prompt,
        "parameters": {
            "width": width, "height": height, "num_inference_steps": steps,
            "negative_prompt": final_negative_prompt, # Use the potentially modified negative prompt
            "guidance_scale": cfg_scale,
            "seed": seed if seed != -1 else random.randint(1, 1000000000),
        }
    }

    # API Call Section
    try:
        if not headers:
             print("FAILED:")
             print(f"  Reason: HF Token Missing")
             print(f"  Prompt: '{original_prompt}'\n") # Add newline
             return None, "<p style='color: red; text-align: center;'>Configuration Error: API Token missing.</p>"

        response = requests.post(API_URL, headers=headers, json=payload, timeout=timeout)
        response.raise_for_status()

        image_bytes = response.content
        if not image_bytes or len(image_bytes) < 100:
             print("FAILED:")
             print(f"  Reason: Invalid Image Data (Empty/Small)")
             print(f"  Prompt: '{original_prompt}'")
             if translated_prompt != original_prompt:
                 print(f"  Translated: '{translated_prompt}'")
             print(f"  Length: {len(image_bytes)}\n") # Add newline
             return None, "<p style='color: red; text-align: center;'>API returned invalid image data.</p>"

        try:
            image = Image.open(io.BytesIO(image_bytes))
        except UnidentifiedImageError as img_err:
             print("FAILED:")
             print(f"  Reason: Image Processing Error")
             print(f"  Prompt: '{original_prompt}'")
             if translated_prompt != original_prompt:
                 print(f"  Translated: '{translated_prompt}'")
             print(f"  Error: {img_err}\n") # Add newline
             return None, "<p style='color: red; text-align: center;'>Failed to process image data from API.</p>"

        # --- Save image and create download link ---
        filename = f"{int(time.time())}_{uuid.uuid4().hex[:8]}.png"
        save_path = os.path.join(IMAGE_DIR, filename)
        absolute_save_path = os.path.abspath(save_path)

        try:
            image.save(save_path, "PNG")

            if not os.path.exists(save_path) or os.path.getsize(save_path) < 100:
                 print("FAILED:")
                 print(f"  Reason: Image Save Verification Error")
                 print(f"  Prompt: '{original_prompt}'")
                 if translated_prompt != original_prompt:
                      print(f"  Translated: '{translated_prompt}'")
                 print(f"  Path: '{save_path}'\n") # Add newline
                 return image, "<p style='color: red; text-align: center;'>Internal Error: Failed to confirm image file save.</p>"

            space_name = "greendra-stable-diffusion-3-5-large-serverless"
            relative_file_url = f"gradio_api/file={save_path}"
            encoded_file_url = quote(relative_file_url)
            arinteli_url = f"{ARINTELI_REDIRECT_BASE}?download_url={encoded_file_url}&space_name={space_name}"

            download_html = (
                f'<div style="text-align: right; margin-right: -8px;">'
                f'<a href="{arinteli_url}" target="_blank" style="background: #3dd49f; color: white; padding: 7px 25px; border-radius: 6px; text-decoration: none;">'
                f'Download Image'
                f'</a>'
                f'</div>'
            )

            # *** SUCCESS LOG ***
            print("SUCCESS:")
            print(f"  Prompt: '{original_prompt}'")
            if translated_prompt != original_prompt:
                 print(f"  Translated: '{translated_prompt}'")
            print(f"  {arinteli_url}\n") # Add newline
            return image, download_html

        except (OSError, IOError) as save_err:
            print("FAILED:")
            print(f"  Reason: Image Save IO Error")
            print(f"  Prompt: '{original_prompt}'")
            if translated_prompt != original_prompt:
                 print(f"  Translated: '{translated_prompt}'")
            print(f"  Path: '{save_path}'")
            print(f"  Error: {save_err}\n") # Add newline
            traceback.print_exc()
            return image, f"<p style='color: red; text-align: center;'>Internal Error: Failed to save image file.</p>"
        except Exception as e:
            print("FAILED:")
            print(f"  Reason: Link Creation/Save Unexpected Error")
            print(f"  Prompt: '{original_prompt}'")
            if translated_prompt != original_prompt:
                 print(f"  Translated: '{translated_prompt}'")
            print(f"  Error: {e}\n") # Add newline
            traceback.print_exc()
            return image, "<p style='color: red; text-align: center;'>Internal Error creating download link.</p>"

    # --- Exception Handling for API Call ---
    except requests.exceptions.Timeout:
        print("FAILED:")
        print(f"  Reason: HF API Timeout")
        print(f"  Prompt: '{original_prompt}'\n") # Add newline
        return None, "<p style='color: red; text-align: center;'>Request timed out. The model is taking too long.</p>"
    except requests.exceptions.HTTPError as e:
        status_code = e.response.status_code
        error_text = e.response.text
        error_data = {}
        try:
            error_data = e.response.json()
            parsed_error = error_data.get('error', error_text)
            if isinstance(parsed_error, dict) and 'message' in parsed_error: error_text = parsed_error['message']
            elif isinstance(parsed_error, list): error_text = "; ".join(map(str, parsed_error))
            else: error_text = str(parsed_error)
        except json.JSONDecodeError:
            pass

        print("FAILED:")
        print(f"  Reason: HF API HTTP Error {status_code}")
        print(f"  Prompt: '{original_prompt}'")
        if translated_prompt != original_prompt:
             print(f"  Translated: '{translated_prompt}'")
        print(f"  Details: '{error_text[:200]}'\n") # Add newline

        # User-facing messages remain the same
        if status_code == 503:
             estimated_time = error_data.get("estimated_time") if isinstance(error_data, dict) else None
             error_message = f"Model is loading (503), please wait." + (f" Est. time: {estimated_time:.1f}s." if estimated_time else "") + " Try again."
        elif status_code == 400: error_message = f"Bad Request (400): Check parameters."
        elif status_code == 422: error_message = f"Validation Error (422): Input invalid."
        elif status_code == 401 or status_code == 403: error_message = f"Authorization Error ({status_code}): Check API Token."
        elif status_code == 429: error_message = f"Rate Limit Error (429): Too many requests. Try again later."
        else: error_message = f"API Error ({status_code}). Please try again."

        return None, f"<p style='color: red; text-align: center;'>{error_message}</p>"
    except Exception as e:
        print("FAILED:")
        print(f"  Reason: Unexpected Error During Generation")
        print(f"  Prompt: '{original_prompt}'")
        if translated_prompt != original_prompt:
             print(f"  Translated: '{translated_prompt}'")
        print(f"  Error: {e}\n") # Add newline
        traceback.print_exc()
        return None, f"<p style='color: red; text-align: center;'>An unexpected error occurred. Please check logs.</p>"


# --- CSS Styling ---
# Custom CSS passed to gr.Blocks below; the #ids match elem_id values on components.
css = """
#app-container {
    max-width: 800px;
    margin-left: auto;
    margin-right: auto;
}
textarea:focus {
    background: #0d1117 !important;
}
#download-link-container p { /* Style the link container */
    margin-top: 10px; /* Add some space above the link */
    font-size: 0.9em; /* Slightly smaller text for the link message */
}
"""

# --- Build the Gradio UI with Blocks ---
# Layout is defined purely by the order and nesting of the `with` blocks and
# component constructors below. Reordering statements changes the page layout.
# The theme is referenced by name (presumably a theme hosted on the HF Hub).
with gr.Blocks(theme='Nymbo/Nymbo_Theme', css=css) as app:
    with gr.Column(elem_id="app-container"):
        # --- Input Components ---
        with gr.Row():
            with gr.Column(elem_id="prompt-container"):
                with gr.Row():
                    text_prompt = gr.Textbox(
                        label="Prompt",
                        placeholder="Enter a prompt here",
                        lines=2,
                        elem_id="prompt-text-input"
                    )
                with gr.Row():
                    # Collapsed by default; the defaults here mirror query()'s keyword defaults.
                    with gr.Accordion("Advanced Settings", open=False):
                        negative_prompt = gr.Textbox(
                            label="Negative Prompt",
                            placeholder="What should not be in the image",
                            value="(deformed, distorted, disfigured), poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, (mutated hands and fingers), disconnected limbs, mutation, mutated, ugly, disgusting, blurry, amputation, misspellings, typos",
                            lines=3,
                            elem_id="negative-prompt-text-input"
                        )
                        with gr.Row():
                            width = gr.Slider(label="Width", value=1024, minimum=64, maximum=1216, step=32)
                            height = gr.Slider(label="Height", value=1024, minimum=64, maximum=1216, step=32)
                        steps = gr.Slider(label="Sampling steps", value=4, minimum=1, maximum=8, step=1)
                        cfg = gr.Slider(label="CFG Scale (guidance_scale)", value=0, minimum=0, maximum=10, step=1)
                        # -1 is the sentinel query() replaces with a random seed.
                        seed = gr.Slider(label="Seed", value=-1, minimum=-1, maximum=1000000000, step=1, info="Set to -1 for random seed")

        # --- Action Button ---
        with gr.Row():
            text_button = gr.Button("Run", variant='primary', elem_id="gen-button")

        # --- Output Components ---
        # The built-in download button is hidden; downloads go through the HTML
        # link that query() renders into download_link_display instead.
        with gr.Row():
            image_output = gr.Image(type="pil", label="Image Output", elem_id="gallery", show_label=False, show_download_button=False, show_share_button=False, show_fullscreen_button=False)
        with gr.Row():
             download_link_display = gr.HTML(elem_id="download-link-container")

        # --- Event Listener ---
        # `inputs` are passed positionally, so their order must match
        # query(prompt, negative_prompt, steps, cfg_scale, seed, width, height);
        # query returns (image, html), matching `outputs`.
        text_button.click(
            query,
            inputs=[text_prompt, negative_prompt, steps, cfg, seed, width, height],
            outputs=[image_output, download_link_display]
        )

# --- Launch the Gradio app ---
# allowed_paths whitelists the temp image directory so Gradio may serve the
# saved PNGs that query()'s download links point at (gradio_api/file=...).
print("Starting Gradio app...")
app.launch(
    show_api=False,
    share=False,
    allowed_paths=[absolute_image_dir],
    # server_name="0.0.0.0"
)