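# Gradio demo for the universeTBD/astrollama model, served as a Hugging Face Space.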
from threading import Thread

import gradio as gr
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    AutoConfig,
    TextIteratorStreamer,
)

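# Model checkpoint, its context window size, and the inference device.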
MODEL_ID = "universeTBD/astrollama"
WINDOW_SIZE = 4096
DEVICE = "cuda"

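# Load the config, tokenizer, and model. `load_in_4bit=True` quantizes the
# weights with bitsandbytes so the model fits in limited GPU memory (it
# requires the `bitsandbytes` package to be installed).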
config = AutoConfig.from_pretrained(pretrained_model_name_or_path=MODEL_ID)
tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path=MODEL_ID
)
model = AutoModelForCausalLM.from_pretrained(
    pretrained_model_name_or_path=MODEL_ID,
    config=config,
    device_map="auto",
    use_safetensors=True,
    trust_remote_code=True,
    load_in_4bit=True,
    torch_dtype=torch.bfloat16,
)


def generate_text(prompt: str,
                  max_new_tokens: int = 512,
                  temperature: float = 0.5,
                  top_p: float = 0.95,
                  top_k: int = 50) -> str:
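    """Generate a completion for `prompt`, clamping the sampling parameters
    to safe ranges before calling `model.generate`."""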
    # Encode the prompt
    inputs = tokenizer([prompt],
                       return_tensors="pt",
                       add_special_tokens=False).to(DEVICE)

    # Prepare arguments for generation: cap max_new_tokens so that prompt
    # plus completion fit inside the model's context window (and never let
    # it drop to zero or below, which would make generate() fail)
    input_length = inputs["input_ids"].shape[-1]
    max_new_tokens = max(1, min(max_new_tokens, WINDOW_SIZE - input_length))

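    # Clamp the sampling parameters to the ranges the UI sliders expose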
    if temperature >= 1.0:
        temperature = 0.99
    elif temperature <= 0.0:
        temperature = 0.01
    if top_p > 1.0 or top_p <= 0.0:
        top_p = 1.0
    if top_k <= 0:
        top_k = 100

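    # Stream decoded tokens as they are produced; skip_prompt drops the
    # echoed input and skip_special_tokens hides BOS/EOS markers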
    streamer = TextIteratorStreamer(tokenizer,
                                    timeout=10.0,
                                    skip_prompt=True,
                                    skip_special_tokens=True)
    generation_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
    )

    # Run generation on a background thread and collect the streamed text
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    generated_text = prompt
    for new_text in streamer:
        generated_text += new_text
    thread.join()
    return generated_text


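# Build the Gradio UI: a prompt box plus sliders for the sampling controls.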
demo = gr.Interface(
    fn=generate_text,
    inputs=[
        # Prompt
        gr.Textbox(
            label="Prompt",
            container=False,
            show_label=False,
            placeholder="Enter some text...",
            lines=10,
            scale=10,
        ),
        gr.Slider(
            label="Maximum new tokens",
            minimum=1,
            maximum=4096,
            step=1,
            value=1024,
        ),
        gr.Slider(
            label="Temperature",
            minimum=0.01,
            maximum=0.99,
            step=0.01,
            value=0.5,
        ),
        gr.Slider(
            label="Top-p (for sampling)",
            minimum=0.05,
            maximum=1.0,
            step=0.05,
            value=0.95,
        ),
        gr.Slider(
            label="Top-k (for sampling)",
            minimum=1,
            maximum=1000,
            step=1,
            value=100,
        ),
    ],
    outputs=[
        gr.Textbox(
            container=False,
            show_label=False,
            placeholder="Generated output...",
            scale=10,
            lines=10,
        )
    ],
)

demo.queue(max_size=20).launch()