OzoneAsai commited on
Commit
4b4149f
·
1 Parent(s): ab9c8f2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -32
app.py CHANGED
@@ -6,59 +6,56 @@ os.system("pip install transformers torch psutil")
6
 
7
  # コマンドの実行結果を取得する(stdoutとstderrは出力されない)
8
  result = os.system("pip install transformers")
 
 
9
 
 
 
10
 
11
- import os
12
- from transformers import AutoModel, AutoTokenizer, trainer_utils, AutoConfig
 
 
 
 
13
  import gradio as gr
14
  import psutil
15
 
16
  device = "cpu"
17
# --- One-time checkpoint download and model load (runs at import time) ---
model_directory = "./"
model_path = "./pytorch_model.bin"   # local cache of the checkpoint
conf_path = "./config.json"          # matching model config

if not os.path.exists(model_path):   # download only when the checkpoint is missing
    model_url = "https://huggingface.co/Tanrei/GPTSAN-japanese/resolve/main/pytorch_model.bin"
    os.system(f"wget -O {model_path} {model_url}")
    conf_url = "https://huggingface.co/Tanrei/GPTSAN-japanese/resolve/main/config.json"
    os.system(f"wget -O {conf_path} {conf_url}")

config = AutoConfig.from_pretrained(conf_path)
model = AutoModel.from_pretrained(model_directory, config=config).to(device)
# FIX: the repo id must be a string literal — the original bare
# Tanrei/GPTSAN-japanese token was a NameError at import time.
tokenizer = AutoTokenizer.from_pretrained("Tanrei/GPTSAN-japanese")
trainer_utils.set_seed(30)  # pin RNG so repeated generations are reproducible
30
 
31
  def get_memory_usage():
32
  process = psutil.Process()
33
  memory_usage = process.memory_info().rss / 1024 / 1024 # メモリ使用量をMB単位で取得
34
  return f"Memory Usage: {memory_usage:.2f} MB"
35
-
36
def generate_text(input_text, repetition_count):
    """Generate text with GPTSAN, re-feeding the model its own output.

    Runs ``repetition_count`` rounds; each round generates up to 3 new
    tokens and the full token sequence is fed back as the next input.

    Args:
        input_text: Prompt used as the prefix for the first round.
        repetition_count: Number of generation rounds (coerced to int —
            Gradio's Number widget delivers a float).

    Returns:
        The decoded token sequence after the final round.
    """
    rounds = int(repetition_count)  # gr.inputs.Number yields a float; range() needs an int
    x_token = tokenizer("", prefix_text=input_text, return_tensors="pt")
    input_ids = x_token.input_ids.to(device)
    # NOTE(review): token_type_ids is built once and reused across rounds even
    # as input_ids grows — preserved from the original; confirm against GPTSAN docs.
    token_type_ids = x_token.token_type_ids.to(device)
    for _ in range(rounds):
        # generate() returns the prompt plus up to 3 new tokens; the whole
        # sequence becomes the next round's input.
        gen_token = model.generate(input_ids, token_type_ids=token_type_ids, max_new_tokens=3)
        input_ids = gen_token.detach()
    # FIX: the original decoded torch.cat(generated_tokens) — `torch` was never
    # imported in this file, and concatenating the batched 2-D outputs repeats
    # the prompt once per round. Each round's output already contains all prior
    # text, so decoding the final sequence yields the intended result.
    return tokenizer.decode(input_ids[0])
51
-
 
 
 
52
  # Gradio UI wiring: prompt textbox + repetition count -> generated text.
  # NOTE(review): gr.inputs/gr.outputs is the deprecated Gradio 1.x/2.x API.
  input_text = gr.inputs.Textbox(lines=5, label="Input Text")
53
- repetition_count = gr.inputs.Number(default=1, label="Repetition Count")
54
  output_text = gr.outputs.Textbox(label="Generated Text")
55
 
56
  interface = gr.Interface(
57
  fn=generate_text,
58
- inputs=[input_text, repetition_count],
59
  outputs=output_text,
60
  title=get_memory_usage(),  # evaluated once at startup — the title is static, not live
61
  description="Enter a prompt in Japanese to generate text."
62
  )
63
- interface.launch()
64
-
 
6
 
7
  # コマンドの実行結果を取得する(stdoutとstderrは出力されない)
8
  result = os.system("pip install transformers")
9
+ # NOTE(review): stray LLM-prompt text was committed here as a bare statement — in app.py it is a SyntaxError. Original (JP): "For the code below, set text generation to 3 tokens per step and feed the generated text back as input, repeating it x times."
10
+ import os
11
 
12
+ # コマンドを実行する
13
+ os.system("pip install transformers torch psutil")
14
 
15
+ # コマンドの実行結果を取得する(stdoutとstderrは出力されない)
16
+ result = os.system("pip install transformers")
17
+
18
+
19
+
20
+ from transformers import AutoModel, AutoTokenizer, trainer_utils
21
  import gradio as gr
22
  import psutil
23
 
24
  device = "cpu"
25
# Load GPTSAN-japanese straight from the Hugging Face Hub and pin the RNG
# so repeated runs produce the same generations.
_REPO_ID = "Tanrei/GPTSAN-japanese"
model = AutoModel.from_pretrained(_REPO_ID).to(device)
tokenizer = AutoTokenizer.from_pretrained(_REPO_ID)
trainer_utils.set_seed(30)
28
 
29
  def get_memory_usage():
      """Return this process's resident memory as a 'Memory Usage: NN.NN MB' string."""
30
  process = psutil.Process()
31
  memory_usage = process.memory_info().rss / 1024 / 1024 # resident set size, converted to MB
32
  return f"Memory Usage: {memory_usage:.2f} MB"
33
def generate_text(input_text, num_repeats):
    """Generate Japanese text with GPTSAN, feeding each result back as the next prompt.

    Performs one initial generation from ``input_text``, then ``num_repeats``
    feedback rounds where the accumulated text becomes the next prefix.

    Args:
        input_text: Prompt text used as the prefix for the first generation.
        num_repeats: Number of extra feedback rounds. Gradio's Number widget
            delivers a float, so it is coerced to int here.

    Returns:
        The accumulated text: the initial generation plus the decoded output
        of every feedback round.
    """
    # FIX: gr.inputs.Number passes a float and range(float) raises TypeError;
    # the previous revision cast this to int — restore that cast.
    rounds = int(num_repeats)

    # Initial generation from the user's prompt.
    x_token = tokenizer("", prefix_text=input_text, return_tensors="pt")
    input_ids = x_token.input_ids.to(device)
    token_type_ids = x_token.token_type_ids.to(device)
    gen_token = model.generate(input_ids, token_type_ids=token_type_ids, max_new_tokens=10)
    repeated_text = tokenizer.decode(gen_token[0])

    # Feedback loop: everything generated so far becomes the next prefix.
    for _ in range(rounds):
        x_token = tokenizer("", prefix_text=repeated_text, return_tensors="pt")
        input_ids = x_token.input_ids.to(device)
        token_type_ids = x_token.token_type_ids.to(device)
        gen_token = model.generate(input_ids, token_type_ids=token_type_ids, max_new_tokens=10)
        repeated_text += tokenizer.decode(gen_token[0])
    return repeated_text
48
+
49
+
50
+
51
  # Gradio UI wiring: prompt textbox + repeat count -> generated text.
  # NOTE(review): gr.inputs/gr.outputs is the deprecated Gradio 1.x/2.x API.
  input_text = gr.inputs.Textbox(lines=5, label="Input Text")
52
+ num_repeats = gr.inputs.Number(default=1, label="Number of Repeats")
53
  output_text = gr.outputs.Textbox(label="Generated Text")
54
 
55
  interface = gr.Interface(
56
  fn=generate_text,
57
+ inputs=[input_text, num_repeats],
58
  outputs=output_text,
59
  title=get_memory_usage(),  # evaluated once at startup — the title is static, not live
60
  description="Enter a prompt in Japanese to generate text."
61
  )