wli1995 committed on
Commit
79d1704
·
verified ·
1 Parent(s): 2d40d52

add run_api.sh

Browse files
Files changed (2) hide show
  1. gradio_demo.py +136 -0
  2. run_qwen2.5_1.5b_ctx_ax650_api.sh +15 -0
gradio_demo.py ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ import gradio as gr
3
+ import requests
4
+ import json
5
+
6
# Base URL of the API server. Every request in this file is built as
# f"{API_URL}/api/<endpoint>", so adjust the host and port to match the
# machine that runs the API server.
API_URL = "http://10.168.232.93:8000"
8
+
9
+
10
def reset_chat(system_prompt):
    """
    Start a new conversation by POSTing to the /api/reset endpoint.

    Args:
        system_prompt: Optional system prompt. When non-empty it is sent as
            the "system_prompt" field of the JSON body; otherwise an empty
            JSON object is sent.

    Returns:
        A ``(chat_history, input_text)`` pair for the Gradio outputs. On
        success both are empty (``[]`` and ``""``). On failure the history
        holds a single entry describing the error so it shows up in the chat.
    """
    payload = {"system_prompt": system_prompt} if system_prompt else {}
    try:
        # Bound the connect phase so an unreachable server fails fast instead
        # of hanging the UI handler; the read phase stays unbounded, matching
        # the /api/generate call in this file.
        response = requests.post(
            f"{API_URL}/api/reset", json=payload, timeout=(3.05, None)
        )
        response.raise_for_status()
    except requests.RequestException as e:
        # Surface the failure in the chat window rather than raising.
        return [("Error resetting chat:", str(e))], ""
    # On successful reset, clear chat history and input.
    return [], ""
27
+
28
+
29
def stream_generate(history, message, temperature, repetition_penalty, top_p, top_k):
    """
    Send the user message to /api/generate and stream the reply into the chat.

    The message and sampling parameters are POSTed to /api/generate, then
    /api/generate_provider is polled; each poll returns a JSON object whose
    "response" text is appended to the last bot message until "done" is true.

    Args:
        history: Current chat history as a list of (user, bot) pairs; ``None``
            is treated as an empty history.
        message: The user's message, sent as the "prompt" field.
        temperature: Sent as "temperature".
        repetition_penalty: Sent as "repetition_penalty".
        top_p: Sent as "top-p".
        top_k: Sent as "top-k".

    Yields:
        ``(history, "")`` pairs: the updated chat history and an empty string
        that clears the input box. On a request error the error text is
        written into the last bot message and the generator stops.
    """
    # Build a new list so the caller's history object is never mutated.
    history = list(history or []) + [(message, "")]
    yield history, ""
    payload = {
        "prompt": message,
        "temperature": temperature,
        "repetition_penalty": repetition_penalty,
        "top-p": top_p,
        "top-k": top_k,
    }
    try:
        response = requests.post(
            f"{API_URL}/api/generate", json=payload, timeout=(3.05, None)
        )
        response.raise_for_status()
    except requests.RequestException as e:
        history[-1] = (message, f"Error: {str(e)}")
        yield history, ""
        return
    time.sleep(0.1)

    while True:
        time.sleep(0.01)
        try:
            # A timeout keeps a stalled server from freezing the stream forever.
            response = requests.get(
                f"{API_URL}/api/generate_provider", timeout=(3.05, 30)
            )
            response.raise_for_status()
            data = response.json()
        except (requests.RequestException, ValueError) as e:
            # Keep whatever was already streamed and append the error to it.
            partial = history[-1][1]
            error_text = f"Error: {str(e)}"
            history[-1] = (
                message,
                f"{partial}\n\n{error_text}" if partial else error_text,
            )
            yield history, ""
            return
        chunk = data.get("response", "") or ""
        done = data.get("done", False)
        # Only skip truly empty chunks: whitespace-only chunks (newlines,
        # spaces) are real output and dropping them would glue text together.
        if chunk:
            history[-1] = (message, history[-1][1] + chunk)
            yield history, ""
        # Check "done" after consuming the chunk so text delivered together
        # with the final poll is not lost.
        if done:
            break

    print("end")
69
+
70
+
71
def stop_generate():
    """
    Ask the server to stop the current generation via GET /api/stop.

    This is best-effort: any request failure (including an HTTP error status)
    is printed and otherwise ignored, so the UI handler never raises.
    """
    try:
        # Bounded so that a dead server cannot hang the Stop button.
        response = requests.get(f"{API_URL}/api/stop", timeout=(3.05, 10))
        response.raise_for_status()
    except requests.RequestException as e:
        print(e)
76
+
77
+ # Build the Gradio interface optimized for PC with spacious layout
78
+ # custom_css = """
79
+ # .gradio-container {
80
+ # max-width: 1400px;
81
+ # margin: auto;
82
+ # padding: 20px;
83
+ # }
84
+ # .gradio-container > * {
85
+ # margin-bottom: 20px;
86
+ # }
87
+ # #chatbox .overflow-y-auto {
88
+ # height: 600px !important;
89
+ # }
90
+ # """
91
+
92
# Build the Gradio interface with a two-column layout: chat on the left,
# sampling parameters on the right.
with gr.Blocks(theme=gr.themes.Soft(font="Consolas"), fill_width=True) as demo:
    gr.Markdown("<h2 style='text-align:center;'>🚀 Chatbot Demo with Axera API Backend</h2>")

    # A Row wraps the two main areas (left and right).
    with gr.Row():
        # Left: main chat area (3/4 of the width).
        with gr.Column(scale=3):
            system_prompt = gr.Textbox(label="System Prompt", placeholder="Optional system prompt", lines=2, value="You are Qwen, created by Alibaba Cloud. You are a helpful assistant.")
            reset_button = gr.Button("🔄 Reset Chat")
            chatbot = gr.Chatbot(elem_id="chatbox", label="Axera Chat", height=500)
            user_input = gr.Textbox(label="Your Message", placeholder="Type your message here...", lines=2)
            with gr.Row():
                send_button = gr.Button("➡️ Send", variant="primary")
                stop_button = gr.Button("🛑 Stop", variant="stop")

        # Right: sampling parameter controls (1/4 of the width).
        with gr.Column(scale=1):
            temperature = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.7, label="Temperature")
            repetition_penalty = gr.Slider(minimum=1.0, maximum=2.0, step=0.01, value=1.0, label="Repetition Penalty")
            top_p = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.9, label="Top-p Sampling")
            top_k = gr.Slider(minimum=0, maximum=100, step=1, value=40, label="Top-k Sampling")

    # Wire up events: reset clears chat and input.
    reset_button.click(fn=reset_chat, inputs=system_prompt, outputs=[chatbot, user_input])

    # The Send button and the Enter key trigger the same streaming handler,
    # so both share one definition of its inputs and outputs.
    generate_inputs = [chatbot, user_input, temperature, repetition_penalty, top_p, top_k]
    generate_outputs = [chatbot, user_input]

    # Send streams the chat and clears the input.
    send_button.click(fn=stream_generate, inputs=generate_inputs, outputs=generate_outputs)

    stop_button.click(fn=stop_generate)

    # Allow the Enter key to send.
    user_input.submit(fn=stream_generate, inputs=generate_inputs, outputs=generate_outputs)

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)  # adjust as needed
run_qwen2.5_1.5b_ctx_ax650_api.sh ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/sh
# Launch ./main_api_ax650 with the model files found under
# qwen2.5-1.5b-ctx-ax650/. Run this from the directory that contains both the
# binary and that model directory, since all paths are relative.
#
# NOTE(review): --url_tokenizer_model points at a tokenizer HTTP service on
# port 12345; presumably it must already be running — confirm before launch.
#
# Extra arguments given to this script are appended to the command line, so
# optional flags can be supplied without editing the file, for example:
#   --system_prompt "Your name is Xiaozhi (allen), and you are a harmless AI assistant. Today (April 1) in Shenzhen it is overcast, April Fools' Day, with temperatures between 14°C and 19°C and a light breeze."
#   --kvcache_path "./kvcache"
./main_api_ax650 \
  --template_filename_axmodel "qwen2.5-1.5b-ctx-ax650/qwen2_p128_l%d_together.axmodel" \
  --axmodel_num 28 \
  --url_tokenizer_model "http://0.0.0.0:12345" \
  --filename_post_axmodel "qwen2.5-1.5b-ctx-ax650/qwen2_post.axmodel" \
  --filename_tokens_embed "qwen2.5-1.5b-ctx-ax650/model.embed_tokens.weight.bfloat16.bin" \
  --tokens_embed_num 151936 \
  --tokens_embed_size 1536 \
  --use_mmap_load_embed 1 \
  --live_print 1 \
  "$@"
15
+