tahsinhasem committed on
Commit 8d5fa35 · verified · 1 Parent(s): d236bfb

Update main.py

Files changed (1)
  1. main.py +11 -5
main.py CHANGED
@@ -19,11 +19,17 @@ model = AutoModelForCausalLM.from_pretrained(
 
 # Example usage: Generate text
 prompt = "The quick brown fox"
-input_ids = tokenizer.encode(prompt, return_tensors="pt")
-output = model.generate(input_ids, max_length=50, num_return_sequences=1)
-generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
 
-print(generated_text)
+inputs = tokenizer(prompt, return_tensors="pt", padding=True, return_attention_mask=True, ).to(model.device)
+outputs = model.generate(
+    **inputs,
+    max_new_tokens=100,
+    pad_token_id=tokenizer.eos_token_id # Set this to suppress warning
+)
+
+resp = tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+print(resp)
 
 
 app = FastAPI()
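For context, a minimal standalone sketch of the generation flow this hunk lands on. "gpt2" is a stand-in assumption here; the real checkpoint name is truncated out of the hunk header above.

# Sketch of the new generation flow, assuming a "gpt2"-style checkpoint.
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

# Many causal LMs (GPT-2 included) ship without a pad token; reusing EOS
# makes padding=True valid and matches pad_token_id=eos_token_id below.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

prompt = "The quick brown fox"
inputs = tokenizer(prompt, return_tensors="pt", padding=True,
                   return_attention_mask=True).to(model.device)
outputs = model.generate(
    **inputs,
    max_new_tokens=100,                   # budget for generated tokens only
    pad_token_id=tokenizer.eos_token_id,  # suppresses the missing-pad-token warning
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

Note the switch from max_length=50 to max_new_tokens=100: max_length caps prompt plus generation combined, so a long prompt could leave little or no room for output, while max_new_tokens budgets only the newly generated tokens.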
 
@@ -53,7 +59,7 @@ async def generate_text(item: Item):
 
     # logging.info("Response generated")
 
-    inputs = tokenizer(prompt, return_tensors="pt", padding=True, return_attention_mask=True, ).to(model.device)
+    inputs = tokenizer(item.prompt, return_tensors="pt", padding=True, return_attention_mask=True, ).to(model.device)
 
 
     # input_ids = tokenizer.encode(item.prompt, return_tensors="pt")
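The second hunk fixes a real bug: the endpoint was tokenizing the module-level prompt instead of the request's item.prompt. A hedged sketch of exercising the fixed handler with FastAPI's TestClient; the /generate path and the single-field Item payload are assumptions, since neither appears in this diff.

# Hypothetical smoke test for the endpoint; adjust the route and payload
# to match main.py's actual @app.post decorator and Item model.
from fastapi.testclient import TestClient

from main import app  # assumes this file is importable as main

client = TestClient(app)
response = client.post("/generate", json={"prompt": "The quick brown fox"})
print(response.status_code, response.json())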