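"""modal/engine.py for AutoReadmeAgent.

Defines a Modal stub that bakes the base model into a CUDA container image and
serves streaming chat-style text generation through the HFEngine class.
"""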
import os
import asyncio
from queue import Empty
from typing import List, Union
# Modal primitives for defining the image, stub, GPU config, and container lifecycle hooks
from modal import Image, Stub, gpu, method, enter, exit
# Constants defined in config.py
from config import (
    MODEL_DIR,
    BASE_MODEL,
    STUB_NAME,
    NUM_CONCURRENT_REQUESTS,
    TIMEOUT,
    GPU_COUNT,
)
# Define our GPU Config
if BASE_MODEL == "mistralai/Mistral-7B-Instruct-v0.1":
    GPU_CONFIG = gpu.A100(count=GPU_COUNT, memory=80)
else:
    GPU_CONFIG = gpu.Any(count=GPU_COUNT)
stub = Stub(name=STUB_NAME)
def download_model_to_folder():
    from transformers.utils import move_cache
    from huggingface_hub import snapshot_download

    os.makedirs(MODEL_DIR, exist_ok=True)
    snapshot_download(
        BASE_MODEL,
        local_dir=MODEL_DIR,
        ignore_patterns=["*.pt"],  # Using safetensors
    )
    move_cache()
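# Container image: CUDA base with Python 3.10, the project requirements, and the
# model weights pre-downloaded at image build time via run_function.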
HF_DOCKER_IMAGE = (
    Image.from_registry("nvidia/cuda:12.1.0-base-ubuntu22.04", add_python="3.10")
    .pip_install_from_requirements("./requirements.txt")
    .env({"HF_HUB_ENABLE_HF_TRANSFER": "1"})
    .run_function(download_model_to_folder)
)
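# Remote engine class: one model per container. Modal keeps idle containers alive for
# TIMEOUT seconds and allows up to NUM_CONCURRENT_REQUESTS concurrent inputs per container.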
@stub.cls(
    gpu=GPU_CONFIG,
    timeout=TIMEOUT,
    container_idle_timeout=TIMEOUT,
    allow_concurrent_inputs=NUM_CONCURRENT_REQUESTS,
    image=HF_DOCKER_IMAGE,
)
class HFEngine:
    model_name_or_path: str = MODEL_DIR
    device: str = "cuda"

    @enter()
    def start_engine(self):
        from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_name_or_path, trust_remote_code=True
        ).to(self.device)
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path, trust_remote_code=True)
        self.streamer = TextIteratorStreamer(self.tokenizer)
        return self
    @exit()
    def terminate_engine(self):
        import gc
        import torch

        del self.model
        torch.cuda.synchronize()
        gc.collect()
    @method()
    async def stream(self, chat_input: Union[str, List[dict]], generation_kwargs: dict):
        from threading import Thread

        # Accept either a plain prompt string or a full chat-message list.
        if isinstance(chat_input, str):
            chat_input = [{"role": "user", "content": chat_input}]
        input_ids = self.tokenizer.apply_chat_template(
            conversation=chat_input, tokenize=True, return_tensors="pt"
        ).to(self.device)

        # Run generation in a background thread so tokens can be yielded as the
        # TextIteratorStreamer receives them.
        gen_kwargs = dict(
            input_ids=input_ids,
            streamer=self.streamer,
            pad_token_id=self.tokenizer.eos_token_id,
            **generation_kwargs,
        )
        thread = Thread(target=self.model.generate, kwargs=gen_kwargs)
        thread.start()

        for next_token in self.streamer:
            try:
                if next_token is not None:
                    yield next_token
            except Empty:
                # If the token queue is momentarily empty, yield control to the event loop.
                await asyncio.sleep(0.001)
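

# --- Usage sketch (illustrative, not part of the original file) ---
# A minimal local entrypoint assuming Modal's `@stub.local_entrypoint()` decorator and
# the `.remote_gen()` call style for generator methods; the prompt and generation
# settings below are placeholder values.
@stub.local_entrypoint()
def main(prompt: str = "Summarize this repository in one paragraph."):
    engine = HFEngine()
    # Stream tokens from the remote container and print them as they arrive.
    for token in engine.stream.remote_gen(
        chat_input=prompt,
        generation_kwargs={"max_new_tokens": 256, "do_sample": True, "temperature": 0.7},
    ):
        print(token, end="", flush=True)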