Arifzyn19 committed
Commit 69172f9 · 1 Parent(s): d819961

Update: app.py and requirement

Files changed (2)
  1. app.py +90 -7
  2. requirements.txt +6 -4
app.py CHANGED
@@ -2,18 +2,66 @@ import torch
  from fastapi import FastAPI
  from pydantic import BaseModel
  from transformers import AutoModelForCausalLM, AutoTokenizer
+ import os
+ import gc

  app = FastAPI()

- model_id = "mistralai/Mistral-7B-Instruct-v0.1" # example model
- tokenizer = AutoTokenizer.from_pretrained(model_id)
- model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32)
+ # Model configuration
+ model_id = "mistralai/Mistral-7B-Instruct-v0.1"
+ model_dir = "model_cache"  # Directory used to cache the model files
+
+ # Global variables that hold the model and tokenizer
+ tokenizer = None
+ model = None
+
+ def load_model():
+     """Load or download the model when it is needed."""
+     global tokenizer, model
+
+     # Check whether the model has already been loaded
+     if tokenizer is None or model is None:
+         print(f"Loading model {model_id}...")
+
+         # Create the cache directory if it does not exist yet
+         os.makedirs(model_dir, exist_ok=True)
+
+         # Free memory if a previous model is still loaded
+         if model is not None:
+             del model
+             torch.cuda.empty_cache()
+             gc.collect()
+
+         # Load the tokenizer, reusing the local cache
+         tokenizer = AutoTokenizer.from_pretrained(
+             model_id,
+             cache_dir=model_dir,
+             use_fast=True
+         )
+
+         # Load the model with caching and memory-saving settings
+         device_map = "auto" if torch.cuda.is_available() else None
+
+         model = AutoModelForCausalLM.from_pretrained(
+             model_id,
+             cache_dir=model_dir,
+             torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+             low_cpu_mem_usage=True,
+             device_map=device_map
+         )
+
+         print("Model loaded successfully!")
+

  class ChatRequest(BaseModel):
      messages: list

+
  @app.post("/chat")
  async def chat(req: ChatRequest):
+     # Make sure the model is loaded before it is used
+     load_model()
+
      prompt = ""
      for msg in req.messages:
          role = msg['role']
@@ -23,13 +71,48 @@ async def chat(req: ChatRequest):

      # Encode the prompt
      inputs = tokenizer(prompt, return_tensors="pt")
-     inputs = {key: value.to(model.device) for key, value in inputs.items()}
-
+
+     # Move the inputs to the same device as the model
+     if hasattr(model, 'device'):
+         inputs = {key: value.to(model.device) for key, value in inputs.items()}
+
+     # Use more suitable generation parameters
+     generation_config = {
+         'max_new_tokens': 500,
+         'temperature': 0.7,
+         'top_p': 0.9,
+         'do_sample': True,
+         'pad_token_id': tokenizer.eos_token_id
+     }
+
      # Generate a response
-     output = model.generate(inputs['input_ids'], max_new_tokens=100)
+     with torch.no_grad():
+         output = model.generate(
+             inputs['input_ids'],
+             **generation_config
+         )

      # Decode the output
      result = tokenizer.decode(output[0], skip_special_tokens=True)

      # Return the response, removing the prompt part
-     return {"response": result.replace(prompt, "").strip()}
+     return {"response": result.replace(prompt, "").strip()}
+
+
+ @app.get("/model-status")
+ async def model_status():
+     if model is None:
+         return {"status": "not_loaded", "model_id": model_id}
+     return {"status": "loaded", "model_id": model_id}
+
+
+ @app.post("/load-model")
+ async def force_load_model():
+     load_model()
+     return {"status": "success", "message": f"Model {model_id} loaded successfully"}
+
+
+ # Run with uvicorn when executed directly
+ if __name__ == "__main__":
+     import uvicorn
+     uvicorn.run(app, host="0.0.0.0", port=7860)  # Port 7860 is the default port on HF Spaces
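
For reference, below is a minimal sketch of how the updated /chat endpoint could be exercised once the Space is running. The base URL, port, and message format are assumptions made for illustration (the port matches the uvicorn.run() call above); they are not part of the commit.

# Hypothetical client call for the updated /chat endpoint (illustration only).
# Assumes the app is reachable at http://localhost:7860 and that each message
# carries "role" and "content" keys, as the prompt-building loop suggests.
import json
import urllib.request

payload = {"messages": [{"role": "user", "content": "Hello, who are you?"}]}

request = urllib.request.Request(
    "http://localhost:7860/chat",
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
    method="POST",
)

with urllib.request.urlopen(request) as response:
    print(json.load(response)["response"])  # the endpoint returns {"response": "..."}
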
requirements.txt CHANGED
@@ -1,4 +1,6 @@
- fastapi
- uvicorn[standard]
- transformers
- torch
+ fastapi==0.110.0
+ uvicorn==0.27.1
+ transformers==4.38.1
+ torch>=2.0.0
+ pydantic==2.6.1
+ accelerate==0.25.0
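
As an optional sanity check (not part of this commit), the snippet below verifies that the pinned versions above are what actually got installed in the Space; it uses only the standard library, and the package list mirrors the exact pins in requirements.txt.

# Hypothetical startup check: confirm the exactly pinned dependencies resolved
# to the expected versions. torch is pinned with >=, so it is not checked here.
from importlib.metadata import version, PackageNotFoundError

expected = {
    "fastapi": "0.110.0",
    "uvicorn": "0.27.1",
    "transformers": "4.38.1",
    "pydantic": "2.6.1",
    "accelerate": "0.25.0",
}

for package, pinned in expected.items():
    try:
        installed = version(package)
    except PackageNotFoundError:
        print(f"{package}: not installed")
        continue
    status = "OK" if installed == pinned else f"mismatch (pinned {pinned})"
    print(f"{package}: {installed} {status}")
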