# (Removed: "Spaces: Sleeping" page-header residue from the Hugging Face Spaces file viewer — not part of the program.)
import os

# Cache env vars must be exported BEFORE the Hugging Face libraries are
# imported — several of them resolve their default cache locations at import
# time, so setting these afterwards can silently have no effect.
os.environ["HF_HOME"] = "/tmp/huggingface"
os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface"
os.environ["HF_HUB_CACHE"] = "/tmp/huggingface"

import base64
import io

import torch
from diffusers import DiffusionPipeline
from fastapi import FastAPI, HTTPException
from huggingface_hub import login
from PIL import Image
from pydantic import BaseModel

import story_generator  # Local module providing generate_story_and_questions()

app = FastAPI()

# Hugging Face authentication (only needed for private models).
# Read the token from the environment instead of hard-coding a placeholder —
# calling login() with an invalid placeholder token fails at startup.
HF_TOKEN = os.environ.get("HF_TOKEN", "")
if HF_TOKEN:
    login(token=HF_TOKEN)

# Fast, public text-to-image model; swap for "stabilityai/sdxl-lightning" if needed.
IMAGE_MODEL = "stabilityai/sdxl-turbo"

# Use fp16 on GPU for speed/memory; fall back to fp32 on CPU.
pipeline = DiffusionPipeline.from_pretrained(
    IMAGE_MODEL,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    token=HF_TOKEN or None,  # `use_auth_token` is deprecated; `token` is the current kwarg
).to("cuda" if torch.cuda.is_available() else "cpu")
# Input request schema for the story-generation endpoint.
class StoryRequest(BaseModel):
    """Request payload: the story's theme and the target reading level.

    Attributes are validated by pydantic on request parsing.
    """

    # theme: topic the story should be about (free-form text)
    theme: str
    # reading_level: target audience reading level (free-form text; semantics
    # defined by story_generator — TODO confirm accepted values)
    reading_level: str
# NOTE(review): the route decorator appears to have been lost when this file
# was reformatted — the root endpoint's message points clients at this path.
# Confirm the original path/method.
@app.post("/generate_story_questions_images")
def generate_story_questions_images(request: StoryRequest):
    """Generate a story, comprehension questions, and one image per sentence.

    Args:
        request: Theme and reading level for the story.

    Returns:
        dict with "theme", "reading_level", "story", "questions", and
        "images" — a list of {"sentence": str, "image": <base64 PNG>}.

    Raises:
        HTTPException: 500 if the story generator is missing or any step fails.
    """
    try:
        print(f"π Generating story for theme: {request.theme} and level: {request.reading_level}")

        # Fail fast with a clear message if the helper module lacks the expected API.
        if not hasattr(story_generator, "generate_story_and_questions"):
            raise HTTPException(status_code=500, detail="Story generation function not found in story_generator.py")

        story_result = story_generator.generate_story_and_questions(request.theme, request.reading_level)
        story_text = story_result["story"]
        questions = story_result["questions"]

        # Naive sentence split — good enough for one image prompt per sentence.
        story_sentences = story_text.strip().split(". ")

        images = []
        for sentence in story_sentences:
            if len(sentence) > 5:  # Skip empty/trivial fragments
                print(f"πΌοΈ Generating image for: {sentence}")
                # Few inference steps: sdxl-turbo is designed for low-step sampling,
                # keeping per-request latency manageable.
                image = pipeline(prompt=sentence, num_inference_steps=5).images[0]

                # Encode the PIL image as base64 PNG so it can travel in JSON.
                img_byte_arr = io.BytesIO()
                image.save(img_byte_arr, format="PNG")
                img_base64 = base64.b64encode(img_byte_arr.getvalue()).decode("utf-8")
                images.append({"sentence": sentence, "image": img_base64})

        return {
            "theme": request.theme,
            "reading_level": request.reading_level,
            "story": story_text,
            "questions": questions,
            "images": images,
        }
    except HTTPException:
        # Bug fix: previously the generic handler below caught the deliberate
        # HTTPException above and re-wrapped it, mangling its detail message.
        raise
    except Exception as e:
        print(f"β Error generating story, questions, or images: {e}")
        raise HTTPException(status_code=500, detail=str(e))
# NOTE(review): the "Welcome message at root" comment implies this was a GET /
# route; the decorator appears lost in reformatting — confirm and keep.
@app.get("/")
def home():
    """Root endpoint: return a short usage hint pointing at the main API route."""
    return {"message": "π Welcome to the Story, Question & Image Generation API! Use /generate_story_questions_images"}