johnuwaishe committed on
Commit
bf1baa6
·
verified ·
1 Parent(s): 3fe9f4e

Upload 2 files

Browse files
Files changed (2) hide show
  1. .api/config.json +21 -0
  2. .api/handler.py +119 -0
.api/config.json ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
{
  "task": "text-generation",
  "framework": "pytorch",
  "runtime": "transformers",
  "model_id": "johnuwaishe/Nigerian-health-llama-7b",
  "revision": "main",
  "handler_path": "handler.py",
  "requirements": [
    "torch>=2.0.0",
    "transformers>=4.37.0",
    "accelerate>=0.27.0"
  ],
  "parameters": {
    "max_new_tokens": 512,
    "temperature": 0.7,
    "top_p": 0.95,
    "top_k": 50,
    "repetition_penalty": 1.1,
    "do_sample": true
  }
}
.api/handler.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
from threading import Thread
from typing import Dict, Iterator, List

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
6
class EndpointHandler:
    """Inference-endpoint handler for a causal-LM text-generation model.

    Loads the model and tokenizer from *path* and serves both one-shot
    (``__call__``) and streaming (``stream``) generation requests.
    """

    def __init__(self, path: str = ""):
        """Load tokenizer and model from *path* (local dir or hub repo id)."""
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            # NOTE(review): assumes bf16-capable hardware — confirm on target GPU
            torch_dtype=torch.bfloat16,
            device_map="auto",
            trust_remote_code=True,
        )
        self.model.eval()

    @staticmethod
    def _extract_request(data: Dict):
        """Pop ``("inputs", "parameters")`` off the request payload.

        Falls back to the whole payload as the prompt when "inputs" is
        absent (preserves the original permissive behavior).
        """
        inputs = data.pop("inputs", data)
        parameters = data.pop("parameters", {})
        return inputs, parameters

    def _generation_kwargs(self, parameters: Dict) -> Dict:
        """Merge request parameters with the endpoint defaults."""
        return {
            "max_new_tokens": parameters.get("max_new_tokens", 100),
            "temperature": parameters.get("temperature", 0.7),
            "top_p": parameters.get("top_p", 0.95),
            "top_k": parameters.get("top_k", 50),
            "repetition_penalty": parameters.get("repetition_penalty", 1.1),
            "do_sample": parameters.get("do_sample", True),
            # LLaMA-style checkpoints ship no pad token; reuse EOS for padding.
            "pad_token_id": self.tokenizer.eos_token_id,
        }

    def _encode(self, inputs):
        """Tokenize *inputs* and move the ids onto the model's device."""
        input_ids = self.tokenizer(inputs, return_tensors="pt").input_ids
        return input_ids.to(self.model.device)

    def __call__(self, data: Dict) -> Dict:
        """Handle a one-shot generation request.

        Args:
            data: ``{"inputs": str, "parameters": {max_new_tokens, temperature,
                top_p, top_k, repetition_penalty, do_sample}}`` — all
                parameters optional.

        Returns:
            ``{"generated_text": str}``. Note the decoded text includes the
            prompt (``skip_special_tokens`` only drops special tokens).
        """
        inputs, parameters = self._extract_request(data)
        input_ids = self._encode(inputs)

        with torch.no_grad():
            outputs = self.model.generate(
                input_ids,
                **self._generation_kwargs(parameters),
            )

        generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        return {"generated_text": generated_text}

    def stream(self, data: Dict) -> Iterator[Dict]:
        """Handle a streaming generation request.

        Same request format as ``__call__``; yields
        ``{"token": {"text": str}}`` chunks as the model produces them.
        """
        inputs, parameters = self._extract_request(data)
        input_ids = self._encode(inputs)

        streamer = TextIteratorStreamer(self.tokenizer, skip_special_tokens=True)

        # generate() blocks, so it runs in a worker thread while this
        # generator drains the streamer on the caller's thread.
        generation_kwargs = dict(
            input_ids=input_ids,
            streamer=streamer,
            **self._generation_kwargs(parameters),
        )
        thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
        thread.start()

        try:
            for text in streamer:
                yield {"token": {"text": text}}
        finally:
            # Reap the worker even if the consumer abandons the stream early.
            thread.join()