# document-vqa-v2 / main.py
# Author: MJobe — last change: "Update main.py" (commit 246ff82, verified)
import fitz
import io
from fastapi import FastAPI, File, UploadFile, Form, HTTPException
from fastapi.responses import JSONResponse
from transformers import pipeline
from PIL import Image
from io import BytesIO
from starlette.middleware import Middleware
from starlette.middleware.cors import CORSMiddleware
from pdf2image import convert_from_bytes
from pydub import AudioSegment
import numpy as np
import json
import torchaudio
import torch
from pydub import AudioSegment
import speech_recognition as sr
import logging
import asyncio
from concurrent.futures import ThreadPoolExecutor
import re
from pydantic import BaseModel
from typing import List, Dict, Any
# Human-readable API overview rendered on the interactive docs page.
description = """
## Image-based Document QA
This API performs document question answering using a LayoutLMv2-based model.
### Endpoints:
- **POST /uploadfile/:** Upload an image file to extract text and answer provided questions.
- **POST /pdfQA/:** Provide a PDF file to extract text and answer provided questions.
- **POST /generate_code/:** Generate code from a text prompt.
"""

# Text-generation pipeline used by /generate_code/. It is built once at module
# import time so every request reuses the same loaded model.
code_generation_model = pipeline('text-generation', model='codeparrot/codeparrot-small')

# Create the application exactly once, with the docs UI served at "/".
# A second FastAPI() instance would silently drop any middleware registered on
# the first one. CORS is registered on this instance further down the module.
app = FastAPI(docs_url="/", description=description)
@app.post("/generate_code/", description="Generate code based on the provided prompt.")
def generate_code(prompt: str = Form(...)):
    """Generate a code continuation for ``prompt`` with the module's text-generation pipeline.

    Declared as a plain ``def`` (not ``async def``) on purpose. The pipeline call
    is synchronous, blocking model inference with nothing to await. FastAPI runs
    sync handlers in its worker threadpool, so a slow generation no longer
    stalls the event loop for every other request.

    Args:
        prompt: Form field containing the text the model should continue.

    Returns:
        ``{"generated_code": <text>}`` on success. If generation raises, returns
        a ``JSONResponse`` with status 500 whose body is an error string.
    """
    try:
        # NOTE(review): in transformers, max_length is understood to count the
        # prompt tokens too. Consider max_new_tokens if long prompts are expected.
        outputs = code_generation_model(prompt, max_length=200, num_return_sequences=1)
        # One dict per returned sequence; exactly one sequence was requested.
        generated_code_text = outputs[0]['generated_text']
    except Exception as e:
        # Request boundary: keep the full traceback in the server logs and
        # report a 500 to the client.
        logging.getLogger(__name__).exception("Code generation failed")
        return JSONResponse(content=f"Error generating code: {str(e)}", status_code=500)
    return {"generated_code": generated_code_text}
# Register CORS handling on the application that serves the routes above.
# NOTE(review): a wildcard origin combined with allow_credentials=True is very
# permissive. Starlette reflects the caller's Origin header in that case, so any
# site could make credentialed requests. Narrow `origins` if cookies or auth
# headers are ever relied on.
origins = ["*"]  # or specify your list of allowed origins
_cors_settings = {
    "allow_origins": origins,
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_settings)