File size: 4,625 Bytes
486275c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9bd9d98
486275c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
from fastapi import FastAPI, Request, Response, Header
from fastapi.staticfiles import StaticFiles
from fastapi.middleware.cors import CORSMiddleware
from fastapi.templating import Jinja2Templates
from fastapi.responses import JSONResponse, FileResponse

from logger import custom_log as clog #for logging error
from pydantic import BaseModel

import torch
# from transformers import AutoTokenizer, AutoModelForCausalLM

from model import Mallam
import config as conf
import os, time

# Select the first CUDA GPU when one is visible to torch, else run on CPU.
# NOTE(review): `device` is only used for logging in this file — the model
# placement itself happens inside `Mallam`; confirm they agree.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
# # print(device)

class MallamRequest(BaseModel):
    """Request body schema carrying the user's chat text.

    NOTE(review): no route in this file references this model — the
    ``/chat/`` endpoint reads ``text`` from a request header instead.
    Confirm whether it is used elsewhere before removing.
    """
    text: str

def create_app():
    """Build and configure the Mallam FastAPI application.

    Mounts the static directory, enables CORS from ``config``, registers
    the Jinja2 template engine, installs a timing + LLM-health-check
    middleware, and declares all routes (landing page, ``/chat/``,
    favicon, and a catch-all for unknown paths).

    Returns:
        FastAPI: the fully configured application instance.
    """
    app = FastAPI(
        title='Mallam FastAPI',
        description=conf.DESC,
    )
    app.mount("/static", StaticFiles(directory=os.path.join(conf.DIR['PATH'][0], "static/")), name="static")
    app.add_middleware(
        CORSMiddleware,
        allow_origins=conf.CORS['ORIGINS'],
        allow_methods=conf.CORS['METHODS'],
        allow_headers=conf.CORS['HEADERS'],
    )

    templates = Jinja2Templates(directory=os.path.join(conf.DIR['PATH'][0], "templates/"))

    # Middleware: verify the LLM is reachable before handling any request,
    # and report the total handling time in an Execution-Time header.
    @app.middleware("http")
    async def add_process_time_header(request: Request, call_next):
        start_time = time.time()
        try:
            await Mallam.health_check_llm()
        except Exception as e:
            # clog('error', str(e))
            print('MiddlewareError : '+str(e))
            # BUG FIX: the original returned a default 200 response on a
            # failed health check; signal the outage with 503 instead.
            return Response('Error : '+str(e), status_code=503)

        response = await call_next(request)
        process_time = time.time() - start_time
        print('Time taken to execute code: '+str(process_time)+' sec' )
        response.headers["Execution-Time"] = str(process_time)
        return response

    # Shutdown hook placeholder: close all es/async_es connections.
    async def on_shutdown():
        pass
    app.add_event_handler("shutdown", on_shutdown)

    summ = 'Test API'
    @app.get('/', summary=summ)
    async def root(request: Request):
        '''
        Return html page if API is running
        '''
        try:
            await Mallam.health_check_llm()
        except Exception as e:
            return templates.TemplateResponse("error.html", {"request": request, "err": 'Error : '+str(e)})

        return templates.TemplateResponse("nyan.html", {"request": request})

    # MALLAM_API - GPT
    summ = 'Generative Pre-trained Transformer (GPT). Chat function'
    @app.post('/chat/', tags=["GPT"], summary=summ)
    async def gpt(request: Request, text=Header(None, convert_underscores=False)):
        '''
        Chat API

        ```
        text = "tanaman pendamping ape yg boleh di tanam dengan tomato ceri?"
        ```
        '''
        # Measure the total elapsed time
        total_start_time = time.time()

        try:
            if not text:
                # BUG FIX: flag the missing input as a client error
                # instead of replying with HTTP 200.
                return JSONResponse("Empty string!", status_code=400)

            # load model & tokenizer
            tokenizer, model = await Mallam.load_model_token()

            messages = [
                {'role': 'user', 'content': text}
            ]

            # parse msg to prompt
            prompt = await Mallam.parse_chat(messages)

            # tokenize prompt
            inputs = await Mallam.tokenize(tokenizer, prompt)

            # GPT
            r = await Mallam.gpt(inputs, model)

            # decoding
            res = await Mallam.decode(tokenizer, r)

            # Calculate total elapsed time
            total_end_time = time.time()
            print(f"Total time elapsed for {device}: {total_end_time - total_start_time:.2f} seconds.")
            print(res)
            return JSONResponse(res)

        except Exception as e:
            print(e)
            # BUG FIX: surface the failure as a server error rather
            # than a 200 with an error string body.
            return JSONResponse("gpt() : "+str(e), status_code=500)

    @app.get('/favicon.ico')
    async def favicon():
        '''
        For favicon path
        '''
        file_path = os.path.join(conf.DIR['PATH'][0], "static/favicon.ico")
        return FileResponse(path=file_path, headers={"Content-Disposition": "attachment; filename=favicon.ico"})

    summ = 'Return url path if path not exist in API'
    @app.get('/{full_path:path}', summary=summ)
    async def path(request: Request):
        '''
        Return url path if path not exist in API
        '''
        # BUG FIX: removed a no-op `request.url.path` statement and dead
        # commented-out code; an unknown path now answers 404, not 200.
        return templates.TemplateResponse(
            "error.html",
            {"request": request, "err": 'ErrorPath : Path : "' + request.url.path + '" does not exist!'},
            status_code=404,
        )

    return app