from fastapi import FastAPI
from pydantic import BaseModel
import uvicorn
from transformers import AutoTokenizer, AutoModelForCausalLM
# Load the pre-trained tokenizer and model
model_name = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token, so reuse EOS
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    torch_dtype="auto"
)
# Example usage: generate text once at startup as a quick smoke test
prompt = "The quick brown fox"
inputs = tokenizer(prompt, return_tensors="pt", padding=True, return_attention_mask=True).to(model.device)
outputs = model.generate(
    **inputs,
    max_new_tokens=100,
    pad_token_id=tokenizer.eos_token_id  # set explicitly to suppress the generation warning
)
resp = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(resp)
app = FastAPI()

class EchoMessage(BaseModel):
    """Request body for the echo endpoint."""
    message: str

class Item(BaseModel):
    """Request body for the text-generation endpoint."""
    prompt: str
# Route path assumed; the original file omitted the decorator, but one is
# required for FastAPI to expose the endpoint.
@app.post("/generate")
async def generate_text(item: Item):
    inputs = tokenizer(item.prompt, return_tensors="pt", padding=True, return_attention_mask=True).to(model.device)
    outputs = model.generate(
        **inputs,
        max_new_tokens=100,
        pad_token_id=tokenizer.eos_token_id  # set explicitly to suppress the generation warning
    )
    resp = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return {"response": resp}
# Route path assumed for the root health-check endpoint.
@app.get("/")
async def home():
    return {"msg": "hey"}
# Route path assumed for the echo endpoint.
@app.post("/echo")
async def echo(echo_msg: EchoMessage):
    return {"msg": echo_msg.message}