from fastapi import FastAPI, Request, Response, Header
from fastapi.staticfiles import StaticFiles
from fastapi.middleware.cors import CORSMiddleware
from fastapi.templating import Jinja2Templates
from fastapi.responses import JSONResponse, FileResponse
from logger import custom_log as clog  # for logging error
from pydantic import BaseModel
import torch
# from transformers import AutoTokenizer, AutoModelForCausalLM
from model import Mallam
import config as conf
import os
import time

# Pick the first GPU when available; only used in the timing log line below.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
# print(device)


class MallamRequest(BaseModel):
    """Request body schema for the Mallam chat endpoint."""
    text: str


def create_app():
    """Build and return the configured FastAPI application.

    Sets up static file serving, CORS, Jinja2 templates, a per-request
    timing + LLM health-check middleware, and the chat / favicon /
    catch-all routes.

    Returns:
        FastAPI: the fully configured application instance.
    """
    app = FastAPI(
        title='Mallam FastAPI',
        description=conf.DESC,
    )
    app.mount(
        "/static",
        StaticFiles(directory=os.path.join(conf.DIR['PATH'][0], "static/")),
        name="static",
    )
    app.add_middleware(
        CORSMiddleware,
        allow_origins=conf.CORS['ORIGINS'],
        allow_methods=conf.CORS['METHODS'],
        allow_headers=conf.CORS['HEADERS'],
    )
    templates = Jinja2Templates(
        directory=os.path.join(conf.DIR['PATH'][0], "templates/"))

    # middleware for API timing (before/after request)
    @app.middleware("http")
    async def add_process_time_header(request: Request, call_next):
        """Time every request and short-circuit when the LLM is unhealthy."""
        start_time = time.time()
        try:
            await Mallam.health_check_llm()
        except Exception as e:
            # clog('error', str(e))
            print('MiddlewareError : ' + str(e))
            # FIX: a failed health check previously returned an implicit
            # HTTP 200; report it as 503 Service Unavailable instead.
            return Response('Error : ' + str(e), status_code=503)
        response = await call_next(request)
        process_time = time.time() - start_time
        print('Time taken to execute code: ' + str(process_time) + ' sec')
        response.headers["Execution-Time"] = str(process_time)
        return response

    # close all es/async_es connection
    async def on_shutdown():
        pass

    app.add_event_handler("shutdown", on_shutdown)

    summ = 'Test API'

    @app.get('/', summary=summ)
    async def root(request: Request):
        '''
        Return html page if API is running
        '''
        try:
            await Mallam.health_check_llm()
        except Exception as e:
            return templates.TemplateResponse(
                "error.html",
                {"request": request, "err": 'Error : ' + str(e)})
        return templates.TemplateResponse("nyan.html", {"request": request})

    # MALLAM_API - GPT
    summ = 'Generative Pre-trained Transformer (GPT). Chat function'

    @app.post('/chat/', tags=["GPT"], summary=summ)
    async def gpt(request: Request,
                  text=Header(None, convert_underscores=False)):
        '''
        Chat API
        ```
        text = "tanaman pendamping ape yg boleh di tanam dengan tomato ceri?"
        ```
        '''
        # Measure the total elapsed time
        total_start_time = time.time()
        try:
            if not text:
                # FIX: missing/empty input is a client error — was an
                # implicit HTTP 200, now 400 Bad Request.
                return JSONResponse("Empty string!", status_code=400)
            # load model & tokenizer
            tokenizer, model = await Mallam.load_model_token()
            messages = [
                {'role': 'user', 'content': text}
            ]
            # parse msg to prompt
            prompt = await Mallam.parse_chat(messages)
            # tokenize prompt
            inputs = await Mallam.tokenize(tokenizer, prompt)
            # GPT
            r = await Mallam.gpt(inputs, model)
            # decoding
            res = await Mallam.decode(tokenizer, r)
            # Calculate total elapsed time
            total_end_time = time.time()
            print(f"Total time elapsed for {device}: "
                  f"{total_end_time - total_start_time:.2f} seconds.")
            print(res)
            return JSONResponse(res)
        except Exception as e:
            print(e)
            # FIX: handler failures previously returned HTTP 200;
            # surface them as 500 Internal Server Error.
            return JSONResponse("gpt() : " + str(e), status_code=500)

    @app.get('/favicon.ico')
    async def favicon():
        ''' For favicon path '''
        file_path = os.path.join(conf.DIR['PATH'][0], "static/favicon.ico")
        return FileResponse(
            path=file_path,
            headers={"Content-Disposition": "attachment; filename=favicon.ico"})

    summ = 'Return url path if path not exist in API'

    @app.get('/{full_path:path}', summary=summ)
    async def path(request: Request):
        '''
        Return url path if path not exist in API
        '''
        # FIX: removed a no-op `request.url.path` expression statement and
        # the commented-out dead code that preceded the real return.
        return templates.TemplateResponse(
            "error.html",
            {"request": request,
             "err": 'ErrorPath : Path : "' + request.url.path
                    + '" does not exist!'})

    return app