# Source captured from a (Hugging Face) Spaces page whose status read "Build error".
import contextlib
import functools
import logging
import os
import threading
import uuid

import ffmpeg
import librosa
import numpy as np
import requests
import soundfile as sf
import whisper
from fastapi import BackgroundTasks, FastAPI, Form, HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import FileResponse
from reportlab.lib.pagesizes import letter
from reportlab.lib.utils import simpleSplit
from reportlab.pdfgen import canvas
from transformers import pipeline
# Hugging Face cache locations. They can be pinned through the environment,
# e.g. HF_HOME=/app/.cache/huggingface and
# TRANSFORMERS_CACHE=/app/.cache/huggingface.
# Log the effective values for debugging. Use .get() so that an unset variable
# does not raise KeyError at import time and prevent the app from starting.
# (These messages only appear if logging is configured at INFO level or lower.)
logging.info("HF_HOME: %s", os.environ.get("HF_HOME", "<unset>"))
logging.info("TRANSFORMERS_CACHE: %s", os.environ.get("TRANSFORMERS_CACHE", "<unset>"))
# The FastAPI application object. Its title, description and version are the
# metadata shown in the auto-generated API documentation.
app = FastAPI(
    version="1.0",
    title="Dubsway Media Analyzer",
    description=(
        "Analyze video/audio content from URLs and generate detailed PDFs."
    ),
)
# The Whisper model and the summarization pipeline are expensive to build, so
# each is created once per process and reused across requests. Neither shared
# instance is documented as thread-safe, so loading and inference on them are
# serialized with this lock. Download/ffmpeg work still runs concurrently.
_MODEL_LOCK = threading.Lock()


@functools.lru_cache(maxsize=None)
def _get_whisper_model(name: str = "base"):
    """Load (once) and return the Whisper model called *name*."""
    return whisper.load_model(name)


@functools.lru_cache(maxsize=None)
def _get_summarizer(model: str = "facebook/bart-large-cnn"):
    """Build (once) and return the summarization pipeline for *model*."""
    return pipeline("summarization", model=model)


def _remove_quietly(path: str) -> None:
    """Delete *path*, ignoring the case where it was never created."""
    with contextlib.suppress(FileNotFoundError):
        os.remove(path)


def _download(media_url: str, dest_path: str, timeout: float = 60.0) -> None:
    """Stream *media_url* into *dest_path*; raise HTTP 400 on a non-200 reply."""
    # NOTE(review): media_url is client-supplied, so this makes the server fetch
    # arbitrary URLs (SSRF risk). Restrict allowed hosts if exposed publicly.
    # The timeout keeps a stalled server from hanging the worker forever.
    with requests.get(media_url, stream=True, timeout=timeout) as response:
        if response.status_code != 200:
            raise HTTPException(status_code=400, detail="Failed to download media file.")
        with open(dest_path, "wb") as f:
            for chunk in response.iter_content(chunk_size=8192):
                f.write(chunk)


def _transcribe(audio_path: str) -> str:
    """Return the Whisper transcription of the audio file at *audio_path*."""
    with sf.SoundFile(audio_path) as audio_file:
        audio_data = audio_file.read(dtype="float32")
        sample_rate = audio_file.samplerate
    # ffmpeg already writes 16 kHz, but resample defensively if it did not.
    if sample_rate != 16000:
        audio_data = librosa.resample(audio_data, orig_sr=sample_rate, target_sr=16000)
    with _MODEL_LOCK:
        result = _get_whisper_model().transcribe(audio=np.asarray(audio_data, dtype=np.float32))
    return result["text"]


def _explain(transcription: str, detailed: bool) -> str:
    """Summarize *transcription* when *detailed*; otherwise return it unchanged."""
    if not detailed or not transcription.strip():
        # Nothing to summarize, or the caller opted out of the summary.
        return transcription
    with _MODEL_LOCK:
        # truncation=True: BART accepts a bounded number of input tokens, and
        # long transcripts would otherwise make the pipeline fail.
        summary = _get_summarizer()(
            transcription, max_length=1024, min_length=256, do_sample=False, truncation=True
        )
    return summary[0]["summary_text"]


def _analyze_to_pdf(
    media_url: str, detailed: bool, video_path: str, audio_path: str, pdf_path: str
) -> None:
    """Run download -> audio extraction -> transcription -> PDF synchronously.

    The intermediate media files are always removed, even if a step fails.
    """
    try:
        _download(media_url, video_path)
        # Extract mono 16 kHz audio, the input format Whisper expects.
        ffmpeg.input(video_path).output(audio_path, ac=1, ar=16000).run(overwrite_output=True)
        transcription = _transcribe(audio_path)
        explanation = _explain(transcription, detailed)
        generate_pdf(pdf_path, transcription, explanation)
    finally:
        _remove_quietly(video_path)
        _remove_quietly(audio_path)


# NOTE(review): the original handler had no route decorator and was therefore
# never reachable. The path below is a placeholder; align it with clients.
@app.post("/analyze")
async def analyze_media(
    media_url: str = Form(...),
    detailed: bool = Form(default=True)
):
    """
    Analyze a video/audio from a given CDN URL and generate a detailed PDF report.

    The blocking work (download, ffmpeg, model inference) runs in a worker thread
    so the event loop stays responsive. The PDF is deleted after it is sent.

    Args:
        media_url: URL of the video/audio file.
        detailed: Whether to include detailed explanations in the report.

    Returns:
        A FileResponse streaming the generated PDF.

    Raises:
        HTTPException: 400 if the media cannot be downloaded, 500 for any other failure.
    """
    unique_id = str(uuid.uuid4())
    video_path = f"temp_{unique_id}.mp4"
    audio_path = f"temp_audio_{unique_id}.wav"
    pdf_path = f"analysis_{unique_id}.pdf"
    try:
        await run_in_threadpool(
            _analyze_to_pdf, media_url, detailed, video_path, audio_path, pdf_path
        )
    except HTTPException:
        # Preserve deliberate client errors (e.g. 400) instead of re-wrapping them as 500.
        _remove_quietly(pdf_path)
        raise
    except Exception as e:
        logging.exception("Error analyzing media from %s", media_url)
        _remove_quietly(pdf_path)
        raise HTTPException(status_code=500, detail=f"Error analyzing media: {e}") from e

    # Delete the report once the response has been fully sent.
    cleanup = BackgroundTasks()
    cleanup.add_task(_remove_quietly, pdf_path)
    return FileResponse(
        pdf_path,
        media_type="application/pdf",
        filename=f"analysis_{unique_id}.pdf",
        background=cleanup,
    )
def generate_pdf(pdf_path: str, transcription: str, explanation: str) -> None:
    """
    Generate a PDF containing the transcription and detailed explanation.

    Text is word-wrapped to the printable width and continues onto new pages as
    needed, so long transcripts neither run off the right edge nor overlap the
    explanation section.

    Args:
        pdf_path: Path to save the PDF.
        transcription: The transcription text.
        explanation: The detailed explanation text.
    """
    c = canvas.Canvas(pdf_path, pagesize=letter)
    width, height = letter
    margin = 72  # one inch, in PDF points
    text_width = width - 2 * margin
    body_font, body_size, body_leading = "Helvetica", 10, 12
    y = height - margin  # baseline of the next line to draw

    def draw(line: str, font: str, size: int, advance: float) -> None:
        """Draw *line* at the cursor, starting a new page first if it would cross the bottom margin."""
        nonlocal y
        if y < margin:
            c.showPage()
            y = height - margin
        # showPage() discards font settings, so set the font for every line.
        c.setFont(font, size)
        c.drawString(margin, y, line)
        y -= advance

    draw("Media Analysis Report", "Helvetica-Bold", 16, 36)
    for heading, body in (
        ("Transcription:", transcription),
        ("Detailed Explanation:", explanation),
    ):
        draw(heading, "Helvetica", 12, 18)
        # simpleSplit wraps on words using real font metrics, and honours embedded newlines.
        for line in simpleSplit(body or "", body_font, body_size, text_width):
            draw(line, body_font, body_size, body_leading)
        y -= 18  # gap before the next section
    c.save()
if __name__ == "__main__":
    # Development entry point: serve the app with uvicorn on every network
    # interface, port 8000.
    from uvicorn import run as serve

    serve(app, port=8000, host="0.0.0.0")