# "Spaces: / Sleeping / Sleeping" — Hugging Face Spaces status banner captured
# during extraction; not part of the original source file.
# Third-party web framework imports.
from fastapi import FastAPI, Request, Response, Header
from fastapi.staticfiles import StaticFiles
from fastapi.middleware.cors import CORSMiddleware
from fastapi.templating import Jinja2Templates
from fastapi.responses import JSONResponse, FileResponse
from pydantic import BaseModel
import torch
# from transformers import AutoTokenizer, AutoModelForCausalLM

# Project-local modules.
from logger import custom_log as clog  # for logging error
from model import Mallam
import config as conf

# Standard library.
import os
import time
# Select the compute device once at import time: first CUDA GPU when
# available, otherwise fall back to CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
# # print(device)
class MallamRequest(BaseModel):
    """Request body schema for the Mallam chat API."""
    # Free-form user prompt text.
    text: str
def create_app():
    """Create and configure the Mallam FastAPI application.

    Sets up static file serving, CORS, a health-checking/timing HTTP
    middleware, a shutdown hook, and the route handlers.

    Returns:
        FastAPI: the fully configured application instance.
    """
    app = FastAPI(
        title='Mallam FastAPI',
        description=conf.DESC,
    )
    # Static assets and HTML templates both live under the project root
    # (conf.DIR['PATH'][0]).
    app.mount("/static", StaticFiles(directory=os.path.join(conf.DIR['PATH'][0], "static/")), name="static")
    app.add_middleware(
        CORSMiddleware,
        allow_origins=conf.CORS['ORIGINS'],
        allow_methods=conf.CORS['METHODS'],
        allow_headers=conf.CORS['HEADERS'],
    )
    templates = Jinja2Templates(directory=os.path.join(conf.DIR['PATH'][0], "templates/"))

    # NOTE(review): the route/middleware decorators were lost when this file
    # was recovered (the dangling `summ = ...` assignments suggest they read
    # e.g. `@app.get(..., summary=summ)` originally). They have been
    # reconstructed below — confirm paths and HTTP methods against the
    # original deployment.

    # middleware for API timing (before/after request)
    @app.middleware("http")
    async def add_process_time_header(request: Request, call_next):
        """Reject every request while the LLM is unhealthy; otherwise time it."""
        start_time = time.time()
        try:
            # Fail fast before doing any work if the model backend is down.
            await Mallam.health_check_llm()
        except Exception as e:
            # clog('error', str(e))
            print('MiddlewareError : '+str(e))
            return Response('Error : '+str(e))
        response = await call_next(request)
        process_time = time.time() - start_time
        print('Time taken to execute code: '+str(process_time)+' sec' )
        # Expose server-side processing time to the client.
        response.headers["Execution-Time"] = str(process_time)
        return response

    # close all es/async_es connection
    async def on_shutdown():
        # Currently a no-op placeholder for connection cleanup.
        pass
    app.add_event_handler("shutdown", on_shutdown)

    summ = 'Test API'

    @app.get("/", summary=summ)  # TODO(review): confirm reconstructed route path
    async def root(request: Request):
        '''
        Return html page if API is running
        '''
        try:
            await Mallam.health_check_llm()
        except Exception as e:
            return templates.TemplateResponse("error.html", {"request": request, "err": 'Error : '+str(e)})
        return templates.TemplateResponse("nyan.html", {"request": request})

    # MALLAM_API - GPT
    summ = 'Generative Pre-trained Transformer (GPT). Chat function'

    @app.get("/gpt", summary=summ)  # TODO(review): confirm reconstructed route path/method
    async def gpt(request: Request, text=Header(None, convert_underscores=False)):
        '''
        Chat API
        ```
        text = "tanaman pendamping ape yg boleh di tanam dengan tomato ceri?"
        ```
        '''
        # Measure the total elapsed time
        total_start_time = time.time()
        try:
            # Empty/missing header: nothing to generate.
            if not text:
                return JSONResponse("Empty string!")
            # load model & tokenizer
            tokenizer, model = await Mallam.load_model_token()
            messages = [
                {'role': 'user', 'content': text}
            ]
            # parse msg to prompt
            prompt = await Mallam.parse_chat(messages)
            # tokenize prompt
            inputs = await Mallam.tokenize(tokenizer, prompt)
            # GPT
            r = await Mallam.gpt(inputs, model)
            # decoding
            res = await Mallam.decode(tokenizer, r)
            # Calculate total elapsed time
            total_end_time = time.time()
            print(f"Total time elapsed for {device}: {total_end_time - total_start_time:.2f} seconds.")
            print(res)
            return JSONResponse(res)
        except Exception as e:
            # Best-effort error reporting: the error text is returned to the
            # caller instead of an HTTP error status (original behavior kept).
            print(e)
            return JSONResponse("gpt() : "+str(e))

    @app.get("/favicon.ico")  # TODO(review): confirm reconstructed route path
    async def favicon():
        '''
        For favicon path
        '''
        file_path = os.path.join(conf.DIR['PATH'][0], "static/favicon.ico")
        return FileResponse(path=file_path, headers={"Content-Disposition": "attachment; filename=favicon.ico"})

    summ = 'Return url path if path not exist in API'

    # NOTE(review): the original registration for this catch-all handler is
    # unknown (likely a 404 exception handler or a `/{path:path}` route);
    # re-register it once the original decorator is recovered.
    async def path(request: Request):
        '''
        Return url path if path not exist in API
        '''
        # return JSONResponse( {
        #     "string" : "Path : '" +request.url.path+ "' does not exist!",
        # })
        return templates.TemplateResponse("error.html", {"request": request, "err": 'ErrorPath : Path : "' +request.url.path+ '" does not exist!'})

    return app