OzoneAsai commited on
Commit
f7c185e
·
1 Parent(s): 2d45144

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -32
app.py CHANGED
@@ -1,41 +1,30 @@
1
  import os
2
  os.system("pip install transformers")
3
- from transformers import AutoModel, AutoTokenizer, trainer_utils
4
  import gradio as gr
 
 
5
 
6
- def generate(prefix_text):
7
- device = "cuda"
8
-
9
- model = AutoModel.from_pretrained("Tanrei/GPTSAN-japanese").to(device)
10
-
11
- tokenizer = AutoTokenizer.from_pretrained("Tanrei/GPTSAN-japanese")
12
-
13
- x_token = tokenizer("", prefix_text=prefix_text, return_tensors="pt")
14
 
 
 
 
15
  trainer_utils.set_seed(30)
16
-
17
  input_ids = x_token.input_ids.to(device)
18
-
19
- token_type_ids = x_token.token_type_ids.to(device)
20
-
21
- gen_token = model.generate(input_ids, token_type_ids=token_type_ids, max_new_tokens=50)
22
-
23
  return tokenizer.decode(gen_token[0])
24
 
25
-
26
- def main():
27
- # テキストエリアを追加
28
- text_input = gr.Textbox()
29
-
30
- # 送信ボタンを追加
31
- submit_button = gr.Button("送信")
32
-
33
- # テキストエリアと送信ボタンを連携
34
- submit_button.on_click(lambda: gr.show(generate(text_input.value), title="テキスト生成"))
35
-
36
- # アプリを起動
37
- gr.Interface(main).launch()
38
-
39
-
40
- if __name__ == "__main__":
41
- main()
 
1
  import os
2
  os.system("pip install transformers")
 
3
  import gradio as gr
4
+ from transformers import AutoModel, AutoTokenizer, trainer_utils
5
+ import torch
6
 
7
+ # Load the model and tokenizer
8
+ device = "cuda" if torch.cuda.is_available() else "cpu"
9
+ model = AutoModel.from_pretrained("Tanrei/GPTSAN-japanese").to(device)
10
+ tokenizer = AutoTokenizer.from_pretrained("Tanrei/GPTSAN-japanese")
 
 
 
 
11
 
12
+ # Define the text generation function
13
+ def generate_text(input_text):
14
+ x_token = tokenizer(input_text, return_tensors="pt")
15
  trainer_utils.set_seed(30)
 
16
  input_ids = x_token.input_ids.to(device)
17
+ gen_token = model.generate(input_ids, max_new_tokens=50)
 
 
 
 
18
  return tokenizer.decode(gen_token[0])
19
 
20
+ # Create the Gradio interface
21
+ iface = gr.Interface(
22
+ fn=generate_text,
23
+ inputs="text",
24
+ outputs="text",
25
+ title="Japanese Text Generation",
26
+ description="Enter a prompt in Japanese to generate text."
27
+ )
28
+
29
+ # Launch the interface
30
+ iface.launch()