jkorstad commited on
Commit
c55444e
·
verified ·
1 Parent(s): 3f50134

Update tools/image_generator.py

Browse files
Files changed (1) hide show
  1. tools/image_generator.py +12 -3
tools/image_generator.py CHANGED
@@ -5,26 +5,35 @@ import requests
5
  import base64
6
  from PIL import Image
7
  from io import BytesIO
 
8
 
9
  class ImageGeneratorTool(Tool):
10
  name = "image_generator"
11
  description = "Generate an image based on a textual description using HuggingFace's API"
12
  inputs = {'description': {'type': 'string', 'description': 'Text description of the image to generate'}}
13
- output_type = "string" # Changed from "bytes" to "string"
14
 
15
  def __init__(self, *args, **kwargs):
16
  super().__init__(*args, **kwargs)
17
  self.api_token = os.environ.get('HF_TOKEN')
18
  if not self.api_token:
19
  raise ValueError("HF_TOKEN environment variable not found in Space secrets")
20
- self.api_url = "https://api-inference.huggingface.co/models/prompthero/openjourney"
21
  self.headers = {"Authorization": f"Bearer {self.api_token}"}
22
 
23
  def forward(self, description: str) -> str:
24
  try:
25
  payload = {"inputs": description}
 
 
26
  response = requests.post(self.api_url, headers=self.headers, json=payload)
27
- response.raise_for_status()
 
 
 
 
 
 
28
 
29
  # Convert the image bytes to base64 string
30
  image_base64 = base64.b64encode(response.content).decode('utf-8')
 
5
  import base64
6
  from PIL import Image
7
  from io import BytesIO
8
+ import time
9
 
10
  class ImageGeneratorTool(Tool):
11
  name = "image_generator"
12
  description = "Generate an image based on a textual description using HuggingFace's API"
13
  inputs = {'description': {'type': 'string', 'description': 'Text description of the image to generate'}}
14
+ output_type = "string"
15
 
16
  def __init__(self, *args, **kwargs):
17
  super().__init__(*args, **kwargs)
18
  self.api_token = os.environ.get('HF_TOKEN')
19
  if not self.api_token:
20
  raise ValueError("HF_TOKEN environment variable not found in Space secrets")
21
+ self.api_url = "https://api-inference.huggingface.co/models/runwayml/stable-diffusion-v1-5"
22
  self.headers = {"Authorization": f"Bearer {self.api_token}"}
23
 
24
  def forward(self, description: str) -> str:
25
  try:
26
  payload = {"inputs": description}
27
+
28
+ # First request might return a loading message
29
  response = requests.post(self.api_url, headers=self.headers, json=payload)
30
+
31
+ # If model is loading, wait and retry
32
+ if response.status_code == 503:
33
+ # Wait for 20 seconds
34
+ time.sleep(20)
35
+ response = requests.post(self.api_url, headers=self.headers, json=payload)
36
+ response.raise_for_status()
37
 
38
  # Convert the image bytes to base64 string
39
  image_base64 = base64.b64encode(response.content).decode('utf-8')