appvoid committed on
Commit
d754671
·
verified ·
1 Parent(s): 4813a03

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -121
app.py CHANGED
@@ -5,124 +5,27 @@ from typing import Iterator
5
  import gradio as gr
6
  import spaces
7
  import torch
8
- from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
9
-
10
- MAX_MAX_NEW_TOKENS = 2048
11
- DEFAULT_MAX_NEW_TOKENS = 1024
12
- MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "2048"))
13
-
14
- DESCRIPTION = """\
15
- # palmer-004
16
- """
17
-
18
- if not torch.cuda.is_available():
19
- DESCRIPTION += "\n<p>Running on CPU πŸ₯Ά This demo does not work on CPU.</p>"
20
-
21
- if torch.cuda.is_available():
22
- model_id = "appvoid/palmer-004"
23
- model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", load_in_4bit=False)
24
- tokenizer = AutoTokenizer.from_pretrained(model_id)
25
- tokenizer.use_default_system_prompt = True
26
-
27
-
28
- @spaces.GPU
29
- def generate(
30
- message: str,
31
- chat_history: list[tuple[str, str]],
32
- system_prompt: str,
33
- max_new_tokens: int = 1024,
34
- temperature: float = 0.6,
35
- top_p: float = 0.9,
36
- top_k: int = 50,
37
- repetition_penalty: float = 1.2,
38
- ) -> Iterator[str]:
39
- conversation = []
40
- if system_prompt:
41
- conversation.append({"role": "system", "content": system_prompt})
42
- for user, assistant in chat_history:
43
- conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
44
- conversation.append({"role": "user", "content": message})
45
-
46
- input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
47
- if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
48
- input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
49
- gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
50
- input_ids = input_ids.to(model.device)
51
-
52
- streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
53
- generate_kwargs = dict(
54
- {"input_ids": input_ids},
55
- streamer=streamer,
56
- max_new_tokens=max_new_tokens,
57
- do_sample=True,
58
- top_p=top_p,
59
- top_k=top_k,
60
- temperature=temperature,
61
- num_beams=1,
62
- repetition_penalty=repetition_penalty,
63
- )
64
- t = Thread(target=model.generate, kwargs=generate_kwargs)
65
- t.start()
66
-
67
- outputs = []
68
- for text in streamer:
69
- outputs.append(text)
70
- yield "".join(outputs)
71
-
72
-
73
- chat_interface = gr.ChatInterface(
74
- fn=generate,
75
- additional_inputs=[
76
- gr.Textbox(label="System prompt", lines=6),
77
- gr.Slider(
78
- label="Max new tokens",
79
- minimum=1,
80
- maximum=MAX_MAX_NEW_TOKENS,
81
- step=1,
82
- value=DEFAULT_MAX_NEW_TOKENS,
83
- ),
84
- gr.Slider(
85
- label="Temperature",
86
- minimum=0.1,
87
- maximum=4.0,
88
- step=0.1,
89
- value=0.6,
90
- ),
91
- gr.Slider(
92
- label="Top-p (nucleus sampling)",
93
- minimum=0.05,
94
- maximum=1.0,
95
- step=0.05,
96
- value=0.9,
97
- ),
98
- gr.Slider(
99
- label="Top-k",
100
- minimum=1,
101
- maximum=1000,
102
- step=1,
103
- value=50,
104
- ),
105
- gr.Slider(
106
- label="Repetition penalty",
107
- minimum=1.0,
108
- maximum=2.0,
109
- step=0.05,
110
- value=1.2,
111
- ),
112
- ],
113
- stop_btn=None,
114
- examples=[
115
- ["Hello there! How are you doing?"],
116
- ["Can you explain briefly to me what is the Python programming language?"],
117
- ["Explain the plot of Cinderella in a sentence."],
118
- ["How many hours does it take a man to eat a Helicopter?"],
119
- ["Write a 100-word article on 'Benefits of Open-Source in AI research'"],
120
- ],
121
- )
122
-
123
- with gr.Blocks(css="style.css") as demo:
124
- gr.Markdown(DESCRIPTION)
125
- chat_interface.render()
126
-
127
- if __name__ == "__main__":
128
- demo.queue(max_size=20).launch(share=True)
 
5
  import gradio as gr
6
  import spaces
7
  import torch
8
+ from transformers import AutoModelForCausalLM, AutoTokenizer
9
+
10
# Model setup: load the GPT-2 checkpoint used by the text-completion demo.
# Both loads hit the Hugging Face hub cache on first run.
checkpoint = "gpt2"
model = AutoModelForCausalLM.from_pretrained(checkpoint)
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
13
@spaces.GPU  # BUG FIX: was `@space.GPU`; the module is imported as `spaces`
def text_generation(input_text, seed):
    """Generate a sampled text completion for *input_text*.

    Args:
        input_text: Prompt string to complete.
        seed: RNG seed for reproducible sampling.
            Max value accepted by torch.manual_seed: 18446744073709551615.

    Returns:
        list[str]: batch-decoded generations (batch size 1, so one string).
    """
    input_ids = tokenizer(input_text, return_tensors="pt").input_ids
    # gr.Number delivers a float; torch.manual_seed expects an integer seed.
    torch.manual_seed(int(seed))
    # Sampled (non-greedy) decoding, capped at 100 total tokens incl. prompt.
    outputs = model.generate(input_ids, do_sample=True, max_length=100)
    return tokenizer.batch_decode(outputs, skip_special_tokens=True)
20
title = "palmer demo"
description = "Text completion app by appvoid"

# Build and launch the demo UI.
# FIX: the `gr.inputs` / `gr.outputs` namespaces were removed in Gradio 3.x;
# the top-level components are the supported replacements
# (`default=` became `value=`; output Textbox no longer takes type="auto").
gr.Interface(
    text_generation,
    inputs=[
        gr.Textbox(lines=2, label="Enter input text"),
        gr.Number(value=10, label="Enter seed number"),
    ],
    outputs=[gr.Textbox(label="Text Generated")],
    title=title,
    description=description,
    # NOTE(review): string theme names are deprecated in Gradio 4 —
    # confirm the installed Gradio version before relying on this.
    theme="huggingface",
).launch()