# NOTE: The lines below are Hugging Face Spaces page residue (space status,
# file size, commit hashes, gutter line numbers) captured during extraction.
# Space status: Sleeping | File size: 2,941 Bytes
# Commits: a950115 93355f1 5ec57ad 43cfc3a 09dbf6d cf636f6 8481c28 019b8b7
import os
import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import story_generator
from diffusers import DiffusionPipeline
import io
import base64
from PIL import Image
app = FastAPI()

# Set Hugging Face cache directories to /tmp (fixes cache write errors on
# Spaces, whose default cache location is not writable).
# NOTE(review): these would ideally be set before importing diffusers /
# transformers so the libraries read them at import time — confirm the
# current placement still takes effect for from_pretrained below.
os.environ["HF_HOME"] = "/tmp/huggingface"
os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface"
os.environ["HF_HUB_CACHE"] = "/tmp/huggingface"

# Load a public image-generation model (no auth token needed).
# sdxl-turbo is picked for speed; fp16 when a GPU is available, fp32 on CPU.
IMAGE_MODEL = "stabilityai/sdxl-turbo"  # Fastest model for public access
pipeline = DiffusionPipeline.from_pretrained(
    IMAGE_MODEL,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
).to("cuda" if torch.cuda.is_available() else "cpu")
# Define the input request format
# Request body for POST /generate_story_questions_images.
class StoryRequest(BaseModel):
    # Free-form story theme (forwarded to story_generator)
    theme: str
    # Target reading level (forwarded to story_generator)
    reading_level: str
@app.post("/generate_story_questions_images")
def generate_story_questions_images(request: StoryRequest):
    """Generate a story, comprehension questions, and one image per sentence.

    Returns a dict with the echoed theme and reading level, the story text,
    the generated questions, and a list of {"sentence", "image"} entries
    where "image" is a base64-encoded PNG.

    Raises:
        HTTPException(500): if story generation is unavailable or any step fails.
    """
    try:
        print(f"π Generating story for theme: {request.theme} and level: {request.reading_level}")

        # Generate the story and questions via the local story_generator module;
        # guard against a missing function so the error message is actionable.
        if not hasattr(story_generator, "generate_story_and_questions"):
            raise HTTPException(status_code=500, detail="Story generation function not found in story_generator.py")
        story_result = story_generator.generate_story_and_questions(request.theme, request.reading_level)
        story_text = story_result["story"]
        questions = story_result["questions"]

        # Split the story into sentences for image generation.
        # NOTE(review): a naive split on ". " misses "!", "?", and final periods.
        story_sentences = story_text.strip().split(". ")

        # Generate an image for each non-trivial sentence.
        images = []
        for sentence in story_sentences:
            if len(sentence) > 5:  # Avoid empty/near-empty fragments
                print(f"πΌοΈ Generating image for: {sentence}")
                image = pipeline(prompt=sentence, num_inference_steps=5).images[0]

                # Convert the PIL image to base64 PNG for JSON transport.
                img_byte_arr = io.BytesIO()
                image.save(img_byte_arr, format="PNG")
                img_base64 = base64.b64encode(img_byte_arr.getvalue()).decode("utf-8")
                images.append({"sentence": sentence, "image": img_base64})

        # Return the full response payload.
        return {
            "theme": request.theme,
            "reading_level": request.reading_level,
            "story": story_text,
            "questions": questions,
            "images": images,
        }
    except HTTPException:
        # Fix: re-raise deliberate HTTP errors unchanged instead of letting the
        # blanket handler below re-wrap them with a mangled detail string.
        raise
    except Exception as e:
        print(f"β Error generating story, questions, or images: {e}")
        raise HTTPException(status_code=500, detail=str(e))
# Welcome message at root
@app.get("/")
def home():
    """Root endpoint: returns a short usage hint pointing at the main route."""
    greeting = "π Welcome to the Story, Question & Image Generation API! Use /generate_story_questions_images"
    return {"message": greeting}