Spaces:
Runtime error
Runtime error
Commit
·
b19bf93
1
Parent(s):
c0741a6
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import random
|
3 |
+
|
4 |
+
import requests
|
5 |
+
|
6 |
+
# Template
|
7 |
+
title = "A conversation with Gandalf (GPTJ-6B) 🧙"
|
8 |
+
description = ""
|
9 |
+
article = """<img src='http://www.simoninithomas.com/test/gandalf.jpg', alt="Gandalf"/>"""
|
10 |
+
theme="huggingface"
|
11 |
+
examples = [[0.9, 1.1, 50, "Hey Gandalf! How are you?"], [0.9, 1.1, 50, "Hey Gandalf, why you didn't use the great eagles to fly Frodo to Mordor?"]]
|
12 |
+
|
13 |
+
# GPT-J-6B API
|
14 |
+
API_URL = "https://api-inference.huggingface.co/models/EleutherAI/gpt-j-6B"
|
15 |
+
def query(payload):
|
16 |
+
response = requests.post(API_URL, json=payload)
|
17 |
+
return response.json()
|
18 |
+
context_setup = "The following is a conversation with Gandalf, the mage of 'the Lord of the Rings'"
|
19 |
+
context=context_setup
|
20 |
+
interlocutor_names = ["Human", "Gandalf"]
|
21 |
+
|
22 |
+
# Builds the prompt from what previously happened
|
23 |
+
def build_prompt(conversation, context):
|
24 |
+
prompt = context + "\n"
|
25 |
+
for user_msg, resp_msg in conversation:
|
26 |
+
line = "\n- " + interlocutor_names[0] + ":" + user_msg
|
27 |
+
prompt += line
|
28 |
+
line = "\n- " + interlocutor_names[1] + ":" + resp_msg
|
29 |
+
prompt += line
|
30 |
+
prompt += ""
|
31 |
+
return prompt
|
32 |
+
|
33 |
+
# Attempt to recognize what the model said, if it used the correct format
|
34 |
+
def clean_chat_output(txt, prompt):
|
35 |
+
delimiter = "\n- "+interlocutor_names[0]
|
36 |
+
output = txt.replace(prompt, '')
|
37 |
+
output = output[:output.find(delimiter)]
|
38 |
+
return output
|
39 |
+
|
40 |
+
|
41 |
+
def chat(top_p, temperature, max_new_tokens, message):
|
42 |
+
history = gr.get_state() or []
|
43 |
+
history.append((message, ""))
|
44 |
+
gr.set_state(history)
|
45 |
+
conversation = history
|
46 |
+
prompt = build_prompt(conversation, context)
|
47 |
+
|
48 |
+
# Build JSON
|
49 |
+
json_ = {"inputs": prompt,
|
50 |
+
"parameters":
|
51 |
+
{
|
52 |
+
"top_p": top_p,
|
53 |
+
"temperature": temperature,
|
54 |
+
"max_new_tokens": max_new_tokens,
|
55 |
+
"return_full_text": False
|
56 |
+
}}
|
57 |
+
|
58 |
+
output = query(json_)
|
59 |
+
output = output[0]['generated_text']
|
60 |
+
answer = clean_chat_output(output, prompt)
|
61 |
+
response = answer
|
62 |
+
history[-1] = (message, response)
|
63 |
+
gr.set_state(history)
|
64 |
+
html = "<div class='chatbot'>"
|
65 |
+
for user_msg, resp_msg in history:
|
66 |
+
html += f"<div class='user_msg'>{user_msg}</div>"
|
67 |
+
html += f"<div class='resp_msg'>{resp_msg}</div>"
|
68 |
+
html += "</div>"
|
69 |
+
return html
|
70 |
+
|
71 |
+
iface = gr.Interface(
|
72 |
+
chat,
|
73 |
+
[
|
74 |
+
gr.inputs.Slider(minimum=0.5, maximum=1, step=0.05, default=0.9, label="top_p"),
|
75 |
+
gr.inputs.Slider(minimum=0.5, maximum=1.5, step=0.1, default=1.1, label="temperature"),
|
76 |
+
gr.inputs.Slider(minimum=20, maximum=250, step=10, default=50, label="max_new_tokens"),
|
77 |
+
"text",
|
78 |
+
],
|
79 |
+
"html", css="""
|
80 |
+
.chatbox {display:flex;flex-direction:column}
|
81 |
+
.user_msg, .resp_msg {padding:4px;margin-bottom:4px;border-radius:4px;width:80%}
|
82 |
+
.user_msg {background-color:cornflowerblue;color:white;align-self:start}
|
83 |
+
.resp_msg {background-color:lightgray;align-self:self-end}
|
84 |
+
""", allow_screenshot=True,
|
85 |
+
allow_flagging=True,
|
86 |
+
title=title,
|
87 |
+
article=article,
|
88 |
+
theme=theme,
|
89 |
+
examples=examples)
|
90 |
+
|
91 |
+
if __name__ == "__main__":
|
92 |
+
iface.launch()
|