from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM
import uvicorn
import os

# Read the Hugging Face access token from the environment and fail fast if it is missing.
token = os.getenv("HUGGINGFACE_TOKEN")
if token is None:
    raise RuntimeError("Hugging Face token is missing. Please set the 'HUGGINGFACE_TOKEN' environment variable.")

# Load the pre-trained tokenizer and model.
model_name = "microsoft/Phi-4-mini-instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name, token=token)
tokenizer.pad_token = tokenizer.eos_token  # reuse EOS for padding so inputs can be padded
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",   # place layers on available accelerators, falling back to CPU
    torch_dtype="auto",  # use the dtype the checkpoint was saved in
    token=token,
)

# Smoke test at startup: generate once to verify the model loads and responds.
prompt = "<|system|>You are a helpful assistant<|end|><|user|>What is the capital of France?<|end|><|assistant|>"
inputs = tokenizer(prompt, return_tensors="pt", padding=True, return_attention_mask=True).to(model.device)
outputs = model.generate(
    **inputs,
    max_new_tokens=100,
    pad_token_id=tokenizer.eos_token_id,  # set explicitly to suppress the padding warning
)
resp = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(resp)
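
# The hand-built prompt above mirrors Phi-4's chat format. The same prompt can
# be produced from a message list via the tokenizer's chat template; a sketch,
# assuming this checkpoint ships a chat template (most -instruct checkpoints do):
#
#   messages = [
#       {"role": "system", "content": "You are a helpful assistant"},
#       {"role": "user", "content": "What is the capital of France?"},
#   ]
#   prompt = tokenizer.apply_chat_template(
#       messages, tokenize=False, add_generation_prompt=True
#   )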

app = FastAPI()

class EchoMessage(BaseModel):
    message: str

class Item(BaseModel):
    prompt: str

@app.post("/generate/")
def generate_text(item: Item):
    # Plain `def` rather than `async def`: FastAPI runs sync endpoints in a
    # worker thread, so the blocking model.generate call does not stall the
    # event loop.
    inp = f"<|system|>You are a helpful assistant<|end|><|user|>{item.prompt}<|end|><|assistant|>"
    inputs = tokenizer(inp, return_tensors="pt", padding=True, return_attention_mask=True).to(model.device)
    outputs = model.generate(
        **inputs,
        max_new_tokens=100,
        pad_token_id=tokenizer.eos_token_id,  # set explicitly to suppress the padding warning
    )
    # skip_special_tokens=False keeps the <|system|>/<|user|>/<|assistant|>/<|end|>
    # markers in the output so the client can split out the assistant turn. Note
    # that the decoded text also includes the prompt itself.
    resp = tokenizer.decode(outputs[0], skip_special_tokens=False)
    return {"response": resp}

@app.get("/")
async def home():
    return {"msg": "hey"}

@app.post("/echo/")
async def echo(echo_msg: EchoMessage):
    return {"msg": echo_msg.message}