tahsinhasem committed
Commit 9463fa8 · verified · 1 Parent(s): 896e66e

Use distillgpt

Files changed (1)
  1. main.py +49 -43
main.py CHANGED
@@ -6,49 +6,49 @@ from transformers import pipeline
 from transformers import AutoTokenizer, AutoModelForCausalLM
 
 
-# Load pre-trained tokenizer and model (Works)
-# model_name = "distilgpt2"
-# tokenizer = AutoTokenizer.from_pretrained(model_name)
-# model = AutoModelForCausalLM.from_pretrained(model_name)
+# Load pre-trained tokenizer and model (Works)
+model_name = "distilgpt2"
+tokenizer = AutoTokenizer.from_pretrained(model_name)
+model = AutoModelForCausalLM.from_pretrained(model_name)
 
-# # Example usage: Generate text
-# prompt = "The quick brown fox"
-# input_ids = tokenizer.encode(prompt, return_tensors="pt")
-# output = model.generate(input_ids, max_length=50, num_return_sequences=1)
-# generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
+# Example usage: Generate text
+prompt = "The quick brown fox"
+input_ids = tokenizer.encode(prompt, return_tensors="pt")
+output = model.generate(input_ids, max_length=50, num_return_sequences=1)
+generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
 
-# print(generated_text)
+print(generated_text)
 
 
-import transformers
-import torch
-import logging
+# import transformers
+# import torch
+# import logging
 
-model_id = "deepcogito/cogito-v1-preview-llama-3B"
+# model_id = "deepcogito/cogito-v1-preview-llama-3B"
 
-pipeline = transformers.pipeline(
-    "text-generation",
-    model=model_id,
-    model_kwargs={"torch_dtype": torch.bfloat16},
-    device_map="auto",
-)
+# pipeline = transformers.pipeline(
+#     "text-generation",
+#     model=model_id,
+#     model_kwargs={"torch_dtype": torch.bfloat16},
+#     device_map="auto",
+# )
 
 
-print("Pipeline loaded")
-logging.info("Pipeline loaded")
+# print("Pipeline loaded")
+# logging.info("Pipeline loaded")
 
-messages = [
-    {"role": "system", "content": "You are a pirate chatbot who always responds in pirate speak!"},
-    {"role": "user", "content": "Give me a short introduction to LLMs."},
-]
+# messages = [
+#     {"role": "system", "content": "You are a pirate chatbot who always responds in pirate speak!"},
+#     {"role": "user", "content": "Give me a short introduction to LLMs."},
+# ]
 
-outputs = pipeline(
-    messages,
-    max_new_tokens=512,
-)
+# outputs = pipeline(
+#     messages,
+#     max_new_tokens=512,
+# )
 
-logging.info("Generated text")
-print(outputs[0]["generated_text"][-1])
+# logging.info("Generated text")
+# print(outputs[0]["generated_text"][-1])
 
 
 app = FastAPI()
@@ -62,21 +62,27 @@ class Item(BaseModel):
 
 @app.post("/generate/")
 async def generate_text(item: Item):
-    messages = [
-        {"role": "system", "content": "You are a pirate chatbot who always responds in pirate speak!"},
-        {"role": "user", "content": "Give me a short introduction to LLMs."},
-    ]
+    # messages = [
+    #     {"role": "system", "content": "You are a pirate chatbot who always responds in pirate speak!"},
+    #     {"role": "user", "content": "Give me a short introduction to LLMs."},
+    # ]
 
-    outputs = pipeline(
-        messages,
-        max_new_tokens=512,
-    )
+    # outputs = pipeline(
+    #     messages,
+    #     max_new_tokens=512,
+    # )
 
-    logging.info("request got")
+    # logging.info("request got")
 
-    resp = outputs[0]["generated_text"][-1]
+    # resp = outputs[0]["generated_text"][-1]
 
-    logging.info("Response generated")
+    # logging.info("Response generated")
+
+    input_ids = tokenizer.encode(item.prompt, return_tensors="pt")
+    output = model.generate(input_ids, max_length=50, num_return_sequences=1)
+    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
+
+    resp = generated_text
 
     return {"response": resp}
 
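
A minimal sketch of exercising the new /generate/ endpoint after this commit, assuming the file above is importable as main and that Item declares only the prompt field the handler reads (the model is not shown in this hunk, so both are assumptions):

from fastapi.testclient import TestClient

from main import app  # the FastAPI app defined above; assumes main.py is on the path

client = TestClient(app)

# Item is assumed to expose just a `prompt` field, per item.prompt in the handler.
response = client.post("/generate/", json={"prompt": "The quick brown fox"})
print(response.status_code)          # 200 once distilgpt2 has loaded
print(response.json()["response"])   # prompt plus completion, capped at 50 tokens total by max_length=50

TestClient drives the app in-process, so this check does not need a running uvicorn server.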