tahsinhasem committed on
Commit 9886add · verified · 1 Parent(s): 7066dd4

Update main.py

Files changed (1)
  1. main.py +21 -36
main.py CHANGED
@@ -9,7 +9,12 @@ from transformers import AutoTokenizer, AutoModelForCausalLM
 # Load pre-trained tokenizer and model (Works)
 model_name = "gpt2"
 tokenizer = AutoTokenizer.from_pretrained(model_name)
-model = AutoModelForCausalLM.from_pretrained(model_name)
+model = AutoModelForCausalLM.from_pretrained(
+    model_name,
+    device_map="auto",   # place the weights on an accelerator automatically when one is available
+    torch_dtype="auto"   # load in the dtype the checkpoint was saved with
+)
+
 
 # Example usage: Generate text
 prompt = "The quick brown fox"
@@ -20,37 +25,6 @@ generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
 print(generated_text)
 
 
-# import transformers
-# import torch
-# import logging
-
-# model_id = "deepcogito/cogito-v1-preview-llama-3B"
-
-# pipeline = transformers.pipeline(
-#     "text-generation",
-#     model=model_id,
-#     model_kwargs={"torch_dtype": torch.bfloat16},
-#     device_map="auto",
-# )
-
-
-# print("Pipeline loaded")
-# logging.info("Pipeline loaded")
-
-# messages = [
-#     {"role": "system", "content": "You are a pirate chatbot who always responds in pirate speak!"},
-#     {"role": "user", "content": "Give me a short introduction to LLMs."},
-# ]
-
-# outputs = pipeline(
-#     messages,
-#     max_new_tokens=512,
-# )
-
-# logging.info("Generated text")
-# print(outputs[0]["generated_text"][-1])
-
-
 app = FastAPI()
 
 class EchoMessage(BaseModel):
@@ -78,11 +52,22 @@ async def generate_text(item: Item):
 
     # logging.info("Response generated")
 
-    input_ids = tokenizer.encode(item.prompt, return_tensors="pt")
-    output = model.generate(input_ids, max_length=50, num_return_sequences=1)
-    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
-
-    resp = generated_text
+    # Tokenize with an explicit attention mask and move the tensors to the model's device.
+    inputs = tokenizer(item.prompt, return_tensors="pt", return_attention_mask=True).to(model.device)
+
+    outputs = model.generate(
+        **inputs,
+        max_new_tokens=100,
+        pad_token_id=tokenizer.eos_token_id  # GPT-2 has no pad token; set this to suppress the warning
+    )
+
+    resp = tokenizer.decode(outputs[0], skip_special_tokens=True)
 
     return {"response": resp}
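
For a quick end-to-end check of the updated endpoint, here is a minimal client sketch. It assumes the handler is mounted at POST /generate and that Item declares a single prompt: str field; neither the route decorator nor the Item definition appears in this diff, so both names are assumptions.

# Minimal smoke test for the updated generation endpoint (a sketch, not part of the commit).
# Assumptions not visible in this diff: the route is POST /generate and
# Item is declared as `class Item(BaseModel): prompt: str`.
import requests

r = requests.post(
    "http://localhost:8000/generate",       # assumed host, port, and route
    json={"prompt": "The quick brown fox"},
    timeout=60,
)
r.raise_for_status()
print(r.json()["response"])

Note that the decoded response still begins with the prompt, because outputs[0] holds the input ids followed by the newly generated ones; decoding only outputs[0][inputs["input_ids"].shape[1]:] would return just the continuation.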