import gradio as gr
import os
import threading
import arrow
import time
import argparse
import logging
from dataclasses import dataclass

import torch
import sentencepiece as spm
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import GPTNeoXForCausalLM, GPTNeoXConfig
from transformers.generation.streamers import BaseStreamer
from huggingface_hub import hf_hub_download, login

logger = logging.getLogger()
logger.setLevel("INFO")

gr_interface = None

VERSION = "0.1.0"


@dataclass
class DefaultArgs:
    hf_model_name_or_path: str = None
    hf_tokenizer_name_or_path: str = None
    spm_model_path: str = None
    env: str = "dev"
    port: int = 7860
    make_public: bool = False


if os.getenv("RUNNING_ON_HF_SPACE"):
    # Running on a Hugging Face Space: take configuration from environment variables.
    login(token=os.getenv("HF_TOKEN"))
    hf_repo = os.getenv("HF_MODEL_REPO")
    args = DefaultArgs()
    args.hf_model_name_or_path = hf_repo
    args.hf_tokenizer_name_or_path = os.path.join(hf_repo, "tokenizer")
    args.spm_model_path = hf_hub_download(repo_id=hf_repo, filename="sentencepiece.model")
else:
    # Running locally: take configuration from command-line arguments.
    parser = argparse.ArgumentParser(description="")
    parser.add_argument("--hf_model_name_or_path", type=str, required=True)
    parser.add_argument("--hf_tokenizer_name_or_path", type=str, required=False)
    parser.add_argument("--spm_model_path", type=str, required=True)
    parser.add_argument("--env", type=str, default="dev")
    parser.add_argument("--port", type=int, default=7860)
    parser.add_argument("--make_public", action='store_true')
    args = parser.parse_args()


def load_model(
    model_dir,
):
    config = GPTNeoXConfig.from_pretrained(model_dir)
    config.is_decoder = True
    model = GPTNeoXForCausalLM.from_pretrained(model_dir, config=config, torch_dtype=torch.bfloat16)

    if torch.cuda.is_available():
        model = model.to("cuda:0")

    return model


logging.info("Loading model")
model = load_model(args.hf_model_name_or_path)
sp = spm.SentencePieceProcessor(model_file=args.spm_model_path)
logging.info("Finished loading model")

tokenizer = AutoTokenizer.from_pretrained(
    args.hf_model_name_or_path,
    subfolder="tokenizer",
    use_fast=False,
)


class TokenizerStreamer(BaseStreamer):
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer
        self.num_invoked = 0
        self.prompt = ""
        self.generated_text = ""
        self.ended = False

    def put(self, t: torch.Tensor):
        d = t.dim()
        if d == 1:
            pass
        elif d == 2:
            t = t[0]
        else:
            raise NotImplementedError

        # Move the tokens to CPU before converting; generation may run on GPU.
        t = [int(x) for x in t.cpu().numpy()]
        text = tokenizer.decode(t)

        if text in [tokenizer.bos_token, tokenizer.eos_token]:
            text = ""

        # The first call receives the prompt tokens; record them but do not emit them.
        if self.num_invoked == 0:
            self.prompt = text
            self.num_invoked += 1
            return

        self.generated_text += text
        logging.debug(f"[streamer]: {self.generated_text}")

    def end(self):
        self.ended = True


INPUT_PROMPT = """以下は、タスクを説明する指示と、文脈のある入力の組み合わせです。要求を適切に満たす応答を書きなさい。

### 指示:
{instruction}

### 入力:
{input}

### 応答:
"""

NO_INPUT_PROMPT = """以下は、タスクを説明する指示と、文脈のある入力の組み合わせです。要求を適切に満たす応答を書きなさい。

### 指示:
{instruction}

### 応答:
"""


def postprocess_output(output):
    output = output\
        .split('### 応答:')[1]\
        .split('###')[0]\
        .split('##')[0]\
        .lstrip(tokenizer.bos_token)\
        .rstrip(tokenizer.eos_token)\
        .replace("###", "")\
        .strip()
    return output


def generate(
    prompt,
    max_new_tokens,
    temperature,
    repetition_penalty,
    do_sample,
    no_repeat_ngram_size,
):
    log = dict(locals())
    logging.debug(log)

    input_text = NO_INPUT_PROMPT.format(instruction=prompt)
    input_ids = tokenizer.encode(input_text, add_special_tokens=False, return_tensors="pt")

    streamer = TokenizerStreamer(tokenizer=tokenizer)

    # input_ids has shape (1, seq_len); budget the remaining context window for generation.
    max_possible_new_tokens = model.config.max_position_embeddings - input_ids.shape[1]
    max_possible_new_tokens = min(max_possible_new_tokens, max_new_tokens)

    # Run generation in a background thread so we can stream partial output to the UI.
    thr = threading.Thread(target=model.generate, args=(), kwargs=dict(
        input_ids=input_ids.to(model.device),
        do_sample=do_sample,
        temperature=temperature,
        repetition_penalty=repetition_penalty,
        no_repeat_ngram_size=no_repeat_ngram_size,
        max_new_tokens=max_possible_new_tokens,
        pad_token_id=tokenizer.pad_token_id,
        bos_token_id=tokenizer.bos_token_id,
        eos_token_id=tokenizer.eos_token_id,
        bad_words_ids=[[tokenizer.unk_token_id]],
        streamer=streamer,
    ))
    thr.start()

    while not streamer.ended:
        time.sleep(0.05)
        yield streamer.generated_text

    # TODO: optimize for final few tokens
    gen = streamer.generated_text
    log.update(dict(
        generation=gen,
        version=VERSION,
        time=str(arrow.now("+09:00")),
    ))
    logging.info(log)
    yield gen


def process_feedback(
    rating,
    prompt,
    generation,
    max_new_tokens,
    temperature,
    repetition_penalty,
    do_sample,
    no_repeat_ngram_size,
):
    log = dict(locals())
    log.update(dict(
        time=str(arrow.now("+09:00")),
        version=VERSION,
    ))
    logging.info(log)


if gr_interface:
    gr_interface.close(verbose=False)

with gr.Blocks() as gr_interface:
    with gr.Row():
        gr.Markdown(f"# 日本語 StableLM Tuned Pre-Alpha ({VERSION})")
        # gr.Markdown(f"バージョン:{VERSION}")
    with gr.Row():
        gr.Markdown("この言語モデルは Stability AI Japan が開発した初期バージョンの日本語モデルです。モデルは「プロンプト」に入力した聞きたいことに対して、それらしい応答をすることができます。")
    with gr.Row():
        # left panel
        with gr.Column(scale=1):
            # generation params
            with gr.Box():
                gr.Markdown("パラメータ")

                # hidden default params
                do_sample = gr.Checkbox(True, label="Do Sample", info="サンプリング生成", visible=True)
                no_repeat_ngram_size = gr.Slider(0, 10, value=3, step=1, label="No Repeat Ngram Size", visible=False)

                # visible params
                max_new_tokens = gr.Slider(
                    128,
                    min(512, model.config.max_position_embeddings),
                    value=128,
                    step=128,
                    label="max tokens",
                    info="生成するトークンの最大数を指定する",
                )
                temperature = gr.Slider(
                    0, 1,
                    value=0.1,
                    step=0.05,
                    label="temperature",
                    info="低い値は出力をより集中させて決定論的にする")
                repetition_penalty = gr.Slider(
                    1, 1.5,
                    value=1.2,
                    step=0.05,
                    label="frequency penalty",
                    info="高い値はAIが繰り返す可能性を減少させる")

            # grouping params for easier reference
            gr_params = [
                max_new_tokens,
                temperature,
                repetition_penalty,
                do_sample,
                no_repeat_ngram_size,
            ]

        # right panel
        with gr.Column(scale=2):
            # user input block
            with gr.Box():
                textbox_prompt = gr.Textbox(
                    label="プロンプト",
                    placeholder="日本の首都は?",
                    interactive=True,
                    lines=5,
                    value="",
                )

            with gr.Box():
                with gr.Row():
                    btn_stop = gr.Button(value="キャンセル", variant="secondary")
                    btn_submit = gr.Button(value="実行", variant="primary")

            # model output block
            with gr.Box():
                textbox_generation = gr.Textbox(
                    label="生成結果",
                    lines=5,
                    value="",
                )

            # rating block
            with gr.Row():
                gr.Markdown("より良い言語モデルを皆様に提供できるよう、生成品質についてのご意見をお聞かせください。")
            with gr.Box():
                with gr.Row():
                    rating_options = [
                        "最悪",
                        "不合格",
                        "中立",
                        "合格",
                        "最高",
                    ]
                    btn_ratings = [gr.Button(value=v) for v in rating_options]

            # TODO: we might not need this for sharing with close groups
            # with gr.Box():
            #     gr.Markdown("TODO:For more feedback link for google form")

    # event handling
    inputs = [textbox_prompt] + gr_params

    click_event = btn_submit.click(generate, inputs, textbox_generation, queue=True)
    btn_stop.click(None, None, None, cancels=click_event, queue=False)

    for btn_rating in btn_ratings:
        btn_rating.click(process_feedback, [btn_rating, textbox_prompt, textbox_generation] + gr_params, queue=False)

gr_interface.queue(max_size=32, concurrency_count=2)
gr_interface.launch(server_port=args.port, share=args.make_public)
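# Illustrative local invocation, assuming this script is saved as app.py and that
# a converted GPT-NeoX checkpoint and SentencePiece model exist at the paths shown
# (file names below are placeholders, not files shipped with this repository).
# The flags mirror the argparse definitions near the top of the file:
#
#   python app.py \
#       --hf_model_name_or_path ./japanese-stablelm-tuned \
#       --spm_model_path ./sentencepiece.model \
#       --port 7860
#
# Add --make_public to request a temporary public Gradio share link.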