"""Minimal FastAPI service exposing GPT-2 text generation.

Loads the tokenizer and model once at import time and serves three routes:
POST /generate/ (text generation), GET / (health check), POST /echo/ (echo).
"""

from fastapi import FastAPI
from pydantic import BaseModel
import uvicorn
from transformers import AutoTokenizer, AutoModelForCausalLM

# --- Model setup (runs once at import) --------------------------------------
MODEL_NAME = "gpt2"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# GPT-2 ships without a pad token; reuse EOS so padded batches work.
tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    device_map="auto",
    torch_dtype="auto",
)


def _generate(prompt: str, max_new_tokens: int = 100) -> str:
    """Generate a continuation of *prompt* and return the full decoded text.

    Args:
        prompt: Text to continue.
        max_new_tokens: Upper bound on generated tokens (default 100,
            matching the original inline calls).

    Returns:
        The prompt plus the generated continuation, with special tokens
        stripped.
    """
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        padding=True,
        return_attention_mask=True,
    ).to(model.device)
    outputs = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        pad_token_id=tokenizer.eos_token_id,  # set explicitly to suppress warning
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)


app = FastAPI()


class EchoMessage(BaseModel):
    """Request body for POST /echo/."""

    message: str


class Item(BaseModel):
    """Request body for POST /generate/."""

    prompt: str


@app.post("/generate/")
def generate_text(item: Item):
    """Generate up to 100 new tokens continuing ``item.prompt``.

    Declared as a plain ``def`` (not ``async``) so FastAPI dispatches the
    blocking ``model.generate`` call to its threadpool instead of stalling
    the event loop.
    """
    return {"response": _generate(item.prompt)}


@app.get("/")
async def home():
    """Trivial liveness endpoint."""
    return {"msg": "hey"}


@app.post("/echo/")
async def echo(echo_msg: EchoMessage):
    """Echo the posted message back to the caller."""
    return {"msg": echo_msg.message}


if __name__ == "__main__":
    # The demo generation previously ran unconditionally at import time;
    # it is guarded here so importing this module has no inference side
    # effects (e.g. under test collection or `uvicorn app:app`).
    print(_generate("The quick brown fox"))
    uvicorn.run(app, host="0.0.0.0", port=8000)