Spaces:
Sleeping
Sleeping
Update tools/image_generator.py
Browse files- tools/image_generator.py +12 -3
tools/image_generator.py
CHANGED
|
@@ -5,26 +5,35 @@ import requests
|
|
| 5 |
import base64
|
| 6 |
from PIL import Image
|
| 7 |
from io import BytesIO
|
|
|
|
| 8 |
|
| 9 |
class ImageGeneratorTool(Tool):
|
| 10 |
name = "image_generator"
|
| 11 |
description = "Generate an image based on a textual description using HuggingFace's API"
|
| 12 |
inputs = {'description': {'type': 'string', 'description': 'Text description of the image to generate'}}
|
| 13 |
-
output_type = "string"
|
| 14 |
|
| 15 |
def __init__(self, *args, **kwargs):
|
| 16 |
super().__init__(*args, **kwargs)
|
| 17 |
self.api_token = os.environ.get('HF_TOKEN')
|
| 18 |
if not self.api_token:
|
| 19 |
raise ValueError("HF_TOKEN environment variable not found in Space secrets")
|
| 20 |
-
self.api_url = "https://api-inference.huggingface.co/models/
|
| 21 |
self.headers = {"Authorization": f"Bearer {self.api_token}"}
|
| 22 |
|
| 23 |
def forward(self, description: str) -> str:
|
| 24 |
try:
|
| 25 |
payload = {"inputs": description}
|
|
|
|
|
|
|
| 26 |
response = requests.post(self.api_url, headers=self.headers, json=payload)
|
| 27 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
|
| 29 |
# Convert the image bytes to base64 string
|
| 30 |
image_base64 = base64.b64encode(response.content).decode('utf-8')
|
|
|
|
| 5 |
import base64
|
| 6 |
from PIL import Image
|
| 7 |
from io import BytesIO
|
| 8 |
+
import time
|
| 9 |
|
| 10 |
class ImageGeneratorTool(Tool):
|
| 11 |
name = "image_generator"
|
| 12 |
description = "Generate an image based on a textual description using HuggingFace's API"
|
| 13 |
inputs = {'description': {'type': 'string', 'description': 'Text description of the image to generate'}}
|
| 14 |
+
output_type = "string"
|
| 15 |
|
| 16 |
def __init__(self, *args, **kwargs):
|
| 17 |
super().__init__(*args, **kwargs)
|
| 18 |
self.api_token = os.environ.get('HF_TOKEN')
|
| 19 |
if not self.api_token:
|
| 20 |
raise ValueError("HF_TOKEN environment variable not found in Space secrets")
|
| 21 |
+
self.api_url = "https://api-inference.huggingface.co/models/runwayml/stable-diffusion-v1-5"
|
| 22 |
self.headers = {"Authorization": f"Bearer {self.api_token}"}
|
| 23 |
|
| 24 |
def forward(self, description: str) -> str:
|
| 25 |
try:
|
| 26 |
payload = {"inputs": description}
|
| 27 |
+
|
| 28 |
+
# First request might return a loading message
|
| 29 |
response = requests.post(self.api_url, headers=self.headers, json=payload)
|
| 30 |
+
|
| 31 |
+
# If model is loading, wait and retry
|
| 32 |
+
if response.status_code == 503:
|
| 33 |
+
# Wait for 20 seconds
|
| 34 |
+
time.sleep(20)
|
| 35 |
+
response = requests.post(self.api_url, headers=self.headers, json=payload)
|
| 36 |
+
response.raise_for_status()
|
| 37 |
|
| 38 |
# Convert the image bytes to base64 string
|
| 39 |
image_base64 = base64.b64encode(response.content).decode('utf-8')
|