Upload inference.py
inference.py +226 -0
inference.py
ADDED
@@ -0,0 +1,226 @@
import torch
import tiktoken
import json
from typing import Dict, Optional, Tuple

# Model Architecture Classes
class Config:
    """Hyperparameters for the decoder-only transformer."""
    def __init__(self):
        self.vocab_size = 100283  # cl100k_base vocabulary plus the six chat special tokens added below
        self.max_position_embeddings = 1024
        self.hidden_size = 768
        self.num_layers = 6
        self.num_heads = 12
        self.intermediate_size = 3072
        self.dropout = 0.1


class AttentionHead(torch.nn.Module):
    """A single head of scaled dot-product attention."""

    def __init__(self, config: Config):
        super().__init__()
        self.head_dim = config.hidden_size // config.num_heads
        self.query = torch.nn.Linear(config.hidden_size, self.head_dim)
        self.key = torch.nn.Linear(config.hidden_size, self.head_dim)
        self.value = torch.nn.Linear(config.hidden_size, self.head_dim)

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        Q = self.query(x)
        K = self.key(x)
        V = self.value(x)

        scale = Q.size(-1) ** 0.5
        scores = torch.matmul(Q, K.transpose(-2, -1)) / scale

        # Block attention to masked (future) positions.
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        attention = torch.nn.functional.softmax(scores, dim=-1)
        return torch.matmul(attention, V)


class MultiHeadAttention(torch.nn.Module):
    """Runs the attention heads in parallel and projects the concatenated output."""

    def __init__(self, config: Config):
        super().__init__()
        self.heads = torch.nn.ModuleList([AttentionHead(config) for _ in range(config.num_heads)])
        self.linear = torch.nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = torch.nn.Dropout(config.dropout)

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        heads = [head(x, mask) for head in self.heads]
        multihead = torch.cat(heads, dim=-1)
        return self.dropout(self.linear(multihead))


class TransformerBlock(torch.nn.Module):
    """Post-LayerNorm block: self-attention and feed-forward sublayers with residual connections."""

    def __init__(self, config: Config):
        super().__init__()
        self.attention = MultiHeadAttention(config)
        self.norm1 = torch.nn.LayerNorm(config.hidden_size)
        self.norm2 = torch.nn.LayerNorm(config.hidden_size)
        self.feed_forward = torch.nn.Sequential(
            torch.nn.Linear(config.hidden_size, config.intermediate_size),
            torch.nn.GELU(),
            torch.nn.Linear(config.intermediate_size, config.hidden_size),
            torch.nn.Dropout(config.dropout)
        )

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        attended = self.attention(x, mask)
        x = self.norm1(x + attended)
        fed_forward = self.feed_forward(x)
        return self.norm2(x + fed_forward)


class SmallLanguageModel(torch.nn.Module):
    """GPT-style decoder-only language model: token and learned position embeddings,
    a stack of transformer blocks, a final LayerNorm, and a linear LM head."""

    def __init__(self, config: Config):
        super().__init__()
        self.config = config
        self.token_embedding = torch.nn.Embedding(config.vocab_size, config.hidden_size)
        self.position_embedding = torch.nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.transformer_blocks = torch.nn.ModuleList([TransformerBlock(config) for _ in range(config.num_layers)])
        self.dropout = torch.nn.Dropout(config.dropout)
        self.ln_f = torch.nn.LayerNorm(config.hidden_size)
        self.head = torch.nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.apply(self._init_weights)

    def _init_weights(self, module):
        # GPT-style initialization: N(0, 0.02) weights, zero biases, unit LayerNorm gain.
        if isinstance(module, (torch.nn.Linear, torch.nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if isinstance(module, torch.nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, torch.nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def get_causal_mask(self, size: int) -> torch.Tensor:
        # True on and below the diagonal: each position may attend only to itself and the past.
        mask = torch.triu(torch.ones(size, size), diagonal=1).bool()
        return ~mask

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        b, t = input_ids.size()
        positions = torch.arange(0, t, dtype=torch.long, device=input_ids.device)
        mask = self.get_causal_mask(t).to(input_ids.device)
        token_embeddings = self.token_embedding(input_ids)
        position_embeddings = self.position_embedding(positions)
        x = self.dropout(token_embeddings + position_embeddings)
        for block in self.transformer_blocks:
            x = block(x, mask)
        x = self.ln_f(x)
        logits = self.head(x)
        return logits


# Text Generator Class
class TextGenerator:
    """Autoregressive sampling with temperature, top-k, and nucleus (top-p) filtering."""

    def __init__(self, model, tokenizer):
        self.model = model
        self.model.eval()
        self.tokenizer = tokenizer
        # Generate on whichever device the model weights live on.
        self.device = next(model.parameters()).device

    @torch.no_grad()
    def generate(
        self,
        prompt: str,
        max_length: int = 100,
        temperature: float = 0.7,
        top_k: int = 50,
        top_p: float = 0.9
    ) -> Dict[str, str]:
        try:
            input_ids = torch.tensor(self.tokenizer.encode(
                prompt,
                allowed_special={'<user>', '</user>', '<assistant>', '</assistant>', '<system>', '</system>'}
            )).unsqueeze(0).to(self.device)

            for _ in range(max_length):
                # Truncate the context to the model's maximum sequence length.
                if input_ids.size(1) > self.model.config.max_position_embeddings:
                    input_ids = input_ids[:, -self.model.config.max_position_embeddings:]

                logits = self.model(input_ids)
                next_token_logits = logits[:, -1, :] / temperature

                # Top-k filtering: keep only the k highest-scoring tokens.
                if top_k > 0:
                    values, _ = torch.topk(next_token_logits, top_k)
                    min_value = values[:, -1].unsqueeze(-1)
                    next_token_logits = torch.where(
                        next_token_logits < min_value,
                        torch.tensor(float('-inf')).to(self.device),
                        next_token_logits
                    )

                # Nucleus (top-p) filtering: drop the tail once cumulative probability exceeds top_p.
                if top_p < 1.0:
                    sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
                    cumulative_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
                    sorted_indices_to_remove = cumulative_probs > top_p
                    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                    sorted_indices_to_remove[..., 0] = 0
                    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
                    next_token_logits = next_token_logits.masked_fill(indices_to_remove, float('-inf'))

                # Sample the next token and append it to the running sequence.
                probs = torch.nn.functional.softmax(next_token_logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
                input_ids = torch.cat((input_ids, next_token), dim=1)

            generated_text = self.tokenizer.decode(input_ids[0].tolist())
            return {
                "status": "success",
                "generated_text": generated_text,
                "prompt": prompt,
                "max_length": max_length,
                "temperature": temperature,
                "top_k": top_k,
                "top_p": top_p
            }

        except Exception as e:
            return {
                "status": "error",
                "error_message": str(e),
                "prompt": prompt
            }


# Helper Function to Load Model and Tokenizer
def load_model_and_tokenizer(checkpoint_path: str) -> Tuple[SmallLanguageModel, tiktoken.Encoding]:
    config = Config()
    # Extend the cl100k_base tokenizer with chat-role special tokens.
    cl100k_base = tiktoken.get_encoding("cl100k_base")
    tokenizer = tiktoken.Encoding(
        name="cl100k_xml",
        pat_str=cl100k_base._pat_str,
        mergeable_ranks=cl100k_base._mergeable_ranks,
        special_tokens={
            **cl100k_base._special_tokens,
            "<user>": 100277, "</user>": 100278,
            "<assistant>": 100279, "</assistant>": 100280,
            "<system>": 100281, "</system>": 100282
        }
    )
    config.vocab_size = tokenizer.n_vocab

    # Restore the trained weights from the checkpoint.
    model = SmallLanguageModel(config)
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)

    return model, tokenizer


# Main Function for Inference
def generate(
    checkpoint_path: str,
    prompt: str,
    max_length: int = 100,
    temperature: float = 0.7,
    top_k: int = 50,
    top_p: float = 0.9
) -> Dict[str, str]:
    global device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load model and tokenizer
    model, tokenizer = load_model_and_tokenizer(checkpoint_path)

    # Generate text
    generator = TextGenerator(model, tokenizer)
    result = generator.generate(
        prompt=prompt,
        max_length=max_length,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p
    )

    return result
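
A minimal usage sketch of the module-level generate() entry point. The checkpoint file name is a placeholder, and the chat-tag prompt format is an assumption based on the special tokens the tokenizer defines, not something the upload documents:

from inference import generate

result = generate(
    checkpoint_path="checkpoint.pt",  # placeholder: a training checkpoint containing 'model_state_dict'
    prompt="<user>Write a haiku about the sea.</user><assistant>",  # assumed chat-tag prompt format
    max_length=100,
    temperature=0.7,
    top_k=50,
    top_p=0.9
)

if result["status"] == "success":
    print(result["generated_text"])
else:
    print("Generation failed:", result["error_message"])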