# To run: funix main.py
import os
import typing

from funix import funix
from funix.hint import HTML
from transformers import AutoModelForCausalLM, AutoTokenizer

low_memory = True  # set to True to run on mobile devices
hf_token = os.environ.get("HF_TOKEN")

ku_gpt_tokenizer = AutoTokenizer.from_pretrained("ku-nlp/gpt2-medium-japanese-char")
chj_gpt_tokenizer = AutoTokenizer.from_pretrained("TURX/chj-gpt2", token=hf_token)
wakagpt_tokenizer = AutoTokenizer.from_pretrained("TURX/wakagpt", token=hf_token)
ku_gpt_model = AutoModelForCausalLM.from_pretrained("ku-nlp/gpt2-medium-japanese-char")
chj_gpt_model = AutoModelForCausalLM.from_pretrained("TURX/chj-gpt2", token=hf_token)
wakagpt_model = AutoModelForCausalLM.from_pretrained("TURX/wakagpt", token=hf_token)
print("Models loaded successfully.")

model_name_map = {
    "Kyoto University GPT-2 (Modern)": "ku-gpt2",
    "CHJ GPT-2 (Classical)": "chj-gpt2",
    "Waka GPT": "wakagpt",
}

# WakaGPT prompt tags: [仮名] kana-only, [原文] original text, [整形] aligned form
waka_type_map = {
    "kana": "[仮名]",
    "original": "[原文]",
    "aligned": "[整形]",
}

def home():
    # Funix landing page: intentionally renders nothing.
    return

def __generate(tokenizer: AutoTokenizer, model: AutoModelForCausalLM, prompt: str,
               do_sample: bool, num_beams: int, num_beam_groups: int, max_new_tokens: int,
               temperature: float, top_k: int, top_p: float, repetition_penalty: float,
               num_return_sequences: int) -> list[str]:
    global low_memory
    inputs = tokenizer(prompt, return_tensors="pt").input_ids
    outputs = model.generate(inputs, low_memory=low_memory, do_sample=do_sample, num_beams=num_beams,
                             num_beam_groups=num_beam_groups, max_new_tokens=max_new_tokens,
                             temperature=temperature, top_k=top_k, top_p=top_p,
                             repetition_penalty=repetition_penalty, num_return_sequences=num_return_sequences)
    # batch_decode yields one string per returned sequence, so this is a list, not a str
    return tokenizer.batch_decode(outputs, skip_special_tokens=True)
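
# A minimal usage sketch of the helper above (argument values are illustrative,
# not the Space's UI defaults):
#   __generate(ku_gpt_tokenizer, ku_gpt_model, "こんにちは", True, 1, 1, 20, 0.7, 50, 0.95, 1.0, 1)
# returns a list with one decoded continuation of the prompt.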

def prompt(prompt: str = "こんにちは、",
           model_type: typing.Literal["Kyoto University GPT-2 (Modern)", "CHJ GPT-2 (Classical)", "Waka GPT"] = "Kyoto University GPT-2 (Modern)",
           do_sample: bool = True, num_beams: int = 1, num_beam_groups: int = 1, max_new_tokens: int = 100,
           temperature: float = 0.7, top_k: int = 50, top_p: float = 0.95, repetition_penalty: float = 1.0,
           num_return_sequences: int = 1) -> HTML:
    model_name = model_name_map[model_type]
    if model_name == "ku-gpt2":
        tokenizer = ku_gpt_tokenizer
        model = ku_gpt_model
    elif model_name == "chj-gpt2":
        tokenizer = chj_gpt_tokenizer
        model = chj_gpt_model
    elif model_name == "wakagpt":
        tokenizer = wakagpt_tokenizer
        model = wakagpt_model
    else:
        raise NotImplementedError(f"Unsupported model: {model_name}")
    generated = __generate(tokenizer, model, prompt, do_sample, num_beams, num_beam_groups,
                           max_new_tokens, temperature, top_k, top_p, repetition_penalty,
                           num_return_sequences)
    return HTML("".join(f"<p>{i}</p>" for i in generated))

def waka(preface: str = "", author: str = "",
         first_line: str = "あききぬと―めにはさやかに―みえねとも",
         type: typing.Literal["Kana", "Original", "Aligned"] = "Kana", remaining_lines: int = 2,
         do_sample: bool = True, num_beams: int = 1, num_beam_groups: int = 1, temperature: float = 0.7,
         top_k: int = 50, top_p: float = 0.95, repetition_penalty: float = 1.0,
         num_return_sequences: int = 1) -> HTML:
    # "―" is used here as the separator between the five segments of a waka;
    # the exact separator character was mangled in transit, so this is an assumption.
    waka_prompt = ""
    if preface:
        waka_prompt += "[詞書] " + preface + "\n"  # [詞書] = preface tag
    if author:
        waka_prompt += "[作者] " + author + "\n"  # [作者] = author tag
    token_counts = [5, 7, 5, 7, 7]  # kana per segment of a 5-7-5-7-7 waka
    max_new_tokens = sum(token_counts[-remaining_lines:])
    first_line = first_line.strip()
    # add separators
    if type.lower() in ["kana", "aligned"]:
        if first_line == "":
            max_new_tokens += 4  # four separators between the five segments
        else:
            if first_line[-1] != "―":
                first_line += "―"
            max_new_tokens += remaining_lines - 1  # remaining separators
    waka_prompt += waka_type_map[type.lower()] + " " + first_line
    info = f"""
Prompt: {waka_prompt}<br>
Max New Tokens: {max_new_tokens}<br>
"""
    generated = __generate(wakagpt_tokenizer, wakagpt_model, waka_prompt, do_sample, num_beams,
                           num_beam_groups, max_new_tokens, temperature, top_k, top_p,
                           repetition_penalty, num_return_sequences)
    removed = 0
    checked_generated = []
    if type.lower() == "kana":
        # keep only well-formed 5-7-5-7-7 poems; count the rest as malformed
        def check(seq):
            nonlocal removed
            poem = first_line + seq[len(waka_prompt) - 1:]  # drop the echoed prompt
            parts = poem.split("―")
            if len(parts) == 5 and all(len(part) == token_counts[i] for i, part in enumerate(parts)):
                checked_generated.append(poem)
            else:
                removed += 1
        for i in generated:
            check(i)
    else:
        checked_generated = [first_line + i[len(waka_prompt) - 1:] for i in generated]
    generated = [f"<p>{i}</p>" for i in checked_generated]
    return HTML(info + f"Removed Malformed: {removed}<br>Results:<br>{''.join(generated)}")
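
# Worked example of the token budget above: with the default first_line (three
# segments supplied, remaining_lines=2, type="Kana"), max_new_tokens starts at
# sum(token_counts[-2:]) = 7 + 7 = 14, plus remaining_lines - 1 = 1 separator → 15.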

if __name__ == "__main__":
    print(prompt("こんにちは", "Kyoto University GPT-2 (Modern)", num_beams=5, num_return_sequences=5))
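    # Hypothetical extra check (not in the original file): complete the default
    # waka opening and show the assembled prompt plus the filtered results.
    print(waka(remaining_lines=2, num_beams=5, num_return_sequences=5))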