zee2221 committed on
Commit 70605a3 · 1 Parent(s): 9466714

Update app.py

Files changed (1)
  1. app.py +92 -103
app.py CHANGED
@@ -1,107 +1,96 @@
- import gradio as gr
- import requests
- import json
- import os
-
- from spaces_info import description, examples, initial_prompt_value
-
- API_URL = os.getenv("API_URL")
- HF_API_TOKEN = os.getenv("HF_API_TOKEN")
-
-
- def query(payload):
-     print(payload)
-     response = requests.request("POST", API_URL, json=payload, headers={"Authorization": f"Bearer {HF_API_TOKEN}"})
-     print(response)
-     return json.loads(response.content.decode("utf-8"))
-
-
- def inference(input_sentence, max_length, sample_or_greedy, seed=42):
-     if sample_or_greedy == "Sample":
-         parameters = {
-             "max_new_tokens": max_length,
-             "top_p": 0.9,
-             "do_sample": True,
-             "seed": seed,
-             "early_stopping": False,
-             "length_penalty": 0.0,
-             "eos_token_id": None,
-         }
-     else:
-         parameters = {
-             "max_new_tokens": max_length,
-             "do_sample": False,
-             "seed": seed,
-             "early_stopping": False,
-             "length_penalty": 0.0,
-             "eos_token_id": None,
-         }

-     payload = {"inputs": input_sentence, "parameters": parameters, "options": {"use_cache": False}}

-     data = query(payload)

-     if "error" in data:
-         return (None, None, f"<span style='color:red'>ERROR: {data['error']} </span>")
-
-     generation = data[0]["generated_text"].split(input_sentence, 1)[1]
-     return (
-         before_prompt
-         + input_sentence
-         + prompt_to_generation
-         + generation
-         + after_generation,
-         data[0]["generated_text"],
-         "",
-     )
-
-
- if __name__ == "__main__":
-     demo = gr.Blocks()
-     with demo:
-         with gr.Row():
-             gr.Markdown(value=description)
-         with gr.Row():
-             with gr.Column():
-                 text = gr.Textbox(
-                     label="Input",
-                     value=" ",  # should be set to " " when plugged into a real API
-                 )
-                 tokens = gr.Slider(1, 64, value=32, step=1, label="Tokens to generate")
-                 sampling = gr.Radio(
-                     ["Sample", "Greedy"], label="Sample or greedy", value="Sample"
-                 )
-                 sampling2 = gr.Radio(
-                     ["Sample 1", "Sample 2", "Sample 3", "Sample 4", "Sample 5"],
-                     value="Sample 1",
-                     label="Sample other generations (only work in 'Sample' mode)",
-                     type="index",
-                 )
-
-                 with gr.Row():
-                     submit = gr.Button("Submit")
-                     load_image = gr.Button("Generate Image")
-             with gr.Column():
-                 text_error = gr.Markdown(label="Log information")
-                 text_out = gr.Textbox(label="Output")
-                 display_out = gr.HTML(label="Image")
-                 display_out.set_event_trigger(
-                     "load",
-                     fn=None,
-                     inputs=None,
-                     outputs=None,
-                     no_target=True,
-                     js=js_load_script,
-                 )
-         with gr.Row():
-             gr.Examples(examples=examples, inputs=[text, tokens, sampling, sampling2])
-
-         submit.click(
-             inference,
-             inputs=[text, tokens, sampling, sampling2],
-             outputs=[display_out, text_out, text_error],
-         )
-
-         load_image.click(fn=None, inputs=None, outputs=None, _js=js_save)

-     demo.launch()
 
+ #!/usr/bin/env python
+ # coding: utf-8
+
+ # In[ ]:
+
+
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ import torch
+ import gradio as gr
+ import re
+
+ def cleaning_history_tuple(history):
+     # Flatten the list of (user, assistant) tuples and strip the newlines and
+     # <p> tags that gradio's chatbot rendering can introduce, one turn per line.
+     s = sum(history, ())
+     s = list(s)
+     s2 = ""
+     for i in s:
+         i = re.sub("\n", '', i)
+         i = re.sub("<p>", '', i)
+         i = re.sub("</p>", '', i)
+         s2 = s2 + i + '\n'
+     return s2
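+ # Example (hypothetical history, not from the commit):
+ #   cleaning_history_tuple([("User: Hi", "Assistant: Hello!")])
+ #   -> "User: Hi\nAssistant: Hello!\n"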
+
+ def ai_output(string1, string2):
+     # string1 is the prompt that was fed to the model; string2 is the full decoded
+     # output, which echoes the prompt. Slice the prompt off, then cut the continuation
+     # at the next "A:" or "User" marker so only the assistant's reply remains.
+     a1 = len(string1)
+     string3 = string2[a1:]
+     sub1 = "A:"
+     sub2 = "User"
+     try:
+         try:
+             idx1 = string3.index(sub1)
+             return string3[:idx1]
+         except ValueError:  # str.index raises ValueError if the marker is absent
+             idx1 = string3.index(sub2)
+             return string3[:idx1]
+     except ValueError:
+         return string3
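+ # Worked example (hypothetical strings, not from the commit):
+ #   string1 = "Prompt User: Hi\nAssistant: "
+ #   string2 = "Prompt User: Hi\nAssistant: Hello there! User: ..."
+ #   ai_output(string1, string2) -> "Hello there! "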
+
+ model4 = AutoModelForCausalLM.from_pretrained("bigscience/bloom-3b")
+ tokenizer4 = AutoTokenizer.from_pretrained("bigscience/bloom-3b")
+
+ # The Interface below reads a `prompt` variable that this commit never defines;
+ # a neutral default is assumed here so the app can start.
+ prompt = "The following is a conversation between a User and a helpful Assistant."
+
+ def predict(input, initial_prompt, temperature=0.7, top_p=1, top_k=5, max_tokens=64, no_repeat_ngram_size=1, num_beams=6, do_sample=True, history=[]):
+
+     # Rebuild the running transcript from the chat history, then append the new turn.
+     s = cleaning_history_tuple(history)
+     s = s + "\n" + "User: " + input + "\n" + "Assistant: "
+     s2 = initial_prompt + " " + s
+
+     input_ids = tokenizer4.encode(str(s2), return_tensors="pt")
+     response = model4.generate(input_ids, min_length=10,
+                                max_new_tokens=int(max_tokens),
+                                top_k=int(top_k),
+                                top_p=float(top_p),
+                                temperature=float(temperature),
+                                no_repeat_ngram_size=int(no_repeat_ngram_size),
+                                num_beams=int(num_beams),
+                                # The textbox delivers do_sample as a string and
+                                # bool("False") is True, so parse it explicitly.
+                                do_sample=(str(do_sample).strip().lower() == "true"),
+                                )
+
+     response2 = tokenizer4.decode(response[0])
+     print("Response after decoding tokenizer: ", response2)
+     print("\n\n")
+     response3 = ai_output(s2, response2)
+
+     input = "User: " + input
+     response3 = "Assistant: " + response3
+     history.append((input, response3))
+
+     return history, history
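+ # Shape of the assembled prompt s2 (hypothetical two-turn exchange):
+ #   <initial_prompt> User: Hi
+ #   Assistant: Hello!
+ #
+ #   User: <new input>
+ #   Assistant: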
+
+ # Earlier, plain "text"-input version of the interface, kept for reference:
+ # gr.Interface(fn=predict, title="BLOOM-3b",
+ #              inputs=["text", "text", "text", "text", "text", "text", "text", "text", "text", 'state'],
+ #              outputs=["chatbot", 'state']).launch()
+
+
+ gr.Interface(inputs=[gr.Textbox(label="input", lines=1, value=""),
+                      gr.Textbox(label="initial_prompt", lines=1, value=prompt),
+                      gr.Textbox(label="temperature", lines=1, value=0.7),
+                      gr.Textbox(label="top_p", lines=1, value=1),
+                      gr.Textbox(label="top_k", lines=1, value=5),
+                      gr.Textbox(label="max_tokens", lines=1, value=64),
+                      gr.Textbox(label="no_repeat_ngram_size", lines=1, value=1),
+                      gr.Textbox(label="num_beams", lines=1, value=6),
+                      gr.Textbox(label="do_sample", lines=1, value="True"), 'state'],
+              fn=predict, title="BLOOM-3b", outputs=["chatbot", 'state'],
+              ).launch()
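+ # Minimal local sketch (hypothetical, not part of the app): drive predict()
+ # directly, threading the (user, assistant) history list through by hand.
+ #
+ #   history = []
+ #   history, _ = predict("Hello, who are you?", prompt,
+ #                        0.7, 1, 5, 64, 1, 6, "True", history)
+ #   print(history[-1][1])  # -> "Assistant: ..."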