# Page-capture residue: Spaces status "Sleeping"; file size: 4,625 bytes; revisions 486275c, 9bd9d98, 486275c.
from fastapi import FastAPI, Request, Response, Header
from fastapi.staticfiles import StaticFiles
from fastapi.middleware.cors import CORSMiddleware
from fastapi.templating import Jinja2Templates
from fastapi.responses import JSONResponse, FileResponse
from logger import custom_log as clog #for logging error
from pydantic import BaseModel
import torch
# from transformers import AutoTokenizer, AutoModelForCausalLM
from model import Mallam
import config as conf
import os, time
# Torch device string used by this module: the first CUDA GPU when torch
# reports CUDA as available, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"
# Pydantic request schema with a single required string field, `text`.
# NOTE(review): no route visible in this file takes this model as a parameter
# (the chat endpoint reads `text` from a request header instead) -- confirm
# whether it is still used elsewhere before relying on it.
class MallamRequest(BaseModel):
    text: str  # required; must validate as a string
def create_app():
    """Build and configure the Mallam FastAPI application.

    The returned application has:

    * static files mounted at ``/static`` and Jinja2 templates loaded from
      ``static/`` and ``templates/`` under ``conf.DIR['PATH'][0]``;
    * CORS middleware configured from ``conf.CORS``;
    * an HTTP middleware that calls ``Mallam.health_check_llm()`` before
      every request and adds an ``Execution-Time`` response header;
    * routes ``GET /``, ``POST /chat/``, ``GET /favicon.ico`` and a
      catch-all ``GET`` route for any other path.

    Error responses keep their original bodies but carry a matching HTTP
    status code (503 / 400 / 500 / 404) instead of 200.

    Returns:
        The configured ``FastAPI`` instance.
    """
    base_dir = conf.DIR['PATH'][0]

    app = FastAPI(
        title='Mallam FastAPI',
        description=conf.DESC,
    )
    app.mount(
        "/static",
        StaticFiles(directory=os.path.join(base_dir, "static/")),
        name="static",
    )
    app.add_middleware(
        CORSMiddleware,
        allow_origins=conf.CORS['ORIGINS'],
        allow_methods=conf.CORS['METHODS'],
        allow_headers=conf.CORS['HEADERS'],
    )
    templates = Jinja2Templates(directory=os.path.join(base_dir, "templates/"))

    # Middleware for API timing (before/after request).
    @app.middleware("http")
    async def add_process_time_header(request: Request, call_next):
        """Health-check the LLM, run the request, and report its duration.

        If the health check raises, the request is not forwarded: a plain
        response containing the error text is returned with status 503.
        Otherwise the downstream response is returned with an
        ``Execution-Time`` header holding the elapsed seconds.
        """
        # perf_counter() is monotonic, so the measured duration cannot be
        # skewed by system clock adjustments the way time.time() can.
        start_time = time.perf_counter()
        try:
            await Mallam.health_check_llm()
        except Exception as e:
            print('MiddlewareError : ' + str(e))
            return Response('Error : ' + str(e), status_code=503)

        response = await call_next(request)
        process_time = time.perf_counter() - start_time
        print(f'Time taken to execute code: {process_time} sec')
        response.headers["Execution-Time"] = str(process_time)
        return response

    # Shutdown hook placeholder (originally meant for closing es/async_es
    # connections); it currently does nothing.
    async def on_shutdown():
        pass

    app.add_event_handler("shutdown", on_shutdown)

    # NOTE: FastAPI publishes route docstrings as the OpenAPI description,
    # so the handler docstrings below are kept verbatim.
    @app.get('/', summary='Test API')
    async def root(request: Request):
        '''
        Return html page if API is running
        '''
        try:
            await Mallam.health_check_llm()
        except Exception as e:
            # Same error page as before, now flagged as "service unavailable".
            return templates.TemplateResponse(
                "error.html",
                {"request": request, "err": 'Error : ' + str(e)},
                status_code=503,
            )
        return templates.TemplateResponse("nyan.html", {"request": request})

    # MALLAM_API - GPT
    # The prompt is read from the ``text`` request header (underscore
    # conversion disabled), not from the request body.
    @app.post(
        '/chat/',
        tags=["GPT"],
        summary='Generative Pre-trained Transformer (GPT). Chat function',
    )
    async def gpt(request: Request, text=Header(None, convert_underscores=False)):
        '''
        Chat API
        ```
        text = "tanaman pendamping ape yg boleh di tanam dengan tomato ceri?"
        ```
        '''
        # Measure the total elapsed time with a monotonic clock.
        total_start_time = time.perf_counter()

        # Missing or empty header: client error, nothing to generate from.
        if not text:
            return JSONResponse("Empty string!", status_code=400)

        try:
            # Load model & tokenizer.
            tokenizer, model = await Mallam.load_model_token()
            messages = [
                {'role': 'user', 'content': text},
            ]
            # Parse the message list into a prompt, tokenize it, generate,
            # then decode the generated output.
            prompt = await Mallam.parse_chat(messages)
            inputs = await Mallam.tokenize(tokenizer, prompt)
            r = await Mallam.gpt(inputs, model)
            res = await Mallam.decode(tokenizer, r)

            elapsed = time.perf_counter() - total_start_time
            print(f"Total time elapsed for {device}: {elapsed:.2f} seconds.")
            print(res)
            return JSONResponse(res)
        except Exception as e:
            # Boundary handler: report the failure to the client as a
            # server error instead of a 200 with an error string.
            print(e)
            return JSONResponse("gpt() : " + str(e), status_code=500)

    @app.get('/favicon.ico')
    async def favicon():
        '''
        For favicon path
        '''
        file_path = os.path.join(base_dir, "static/favicon.ico")
        return FileResponse(
            path=file_path,
            headers={"Content-Disposition": "attachment; filename=favicon.ico"},
        )

    # Catch-all: must stay the last registered route so it only matches
    # paths that no earlier route handled.
    @app.get('/{full_path:path}', summary='Return url path if path not exist in API')
    async def path(request: Request):
        '''
        Return url path if path not exist in API
        '''
        return templates.TemplateResponse(
            "error.html",
            {
                "request": request,
                "err": 'ErrorPath : Path : "' + request.url.path + '" does not exist!',
            },
            status_code=404,
        )

    return app