brandonbaek committed
Commit c4e9c17 · verified · 1 Parent(s): e9213f8

Upload inference.py

Files changed (1)
  1. inference.py +226 -0
inference.py ADDED
@@ -0,0 +1,226 @@
import torch
import tiktoken
from typing import Dict, Optional, Tuple

# Device used throughout for model weights and generation.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Model Architecture Classes
class Config:
    def __init__(self):
        self.vocab_size = 100283              # cl100k_base plus six custom XML tags
        self.max_position_embeddings = 1024   # maximum sequence length
        self.hidden_size = 768
        self.num_layers = 6
        self.num_heads = 12
        self.intermediate_size = 3072
        self.dropout = 0.1

class AttentionHead(torch.nn.Module):
    """A single scaled dot-product attention head."""
    def __init__(self, config: Config):
        super().__init__()
        self.head_dim = config.hidden_size // config.num_heads
        self.query = torch.nn.Linear(config.hidden_size, self.head_dim)
        self.key = torch.nn.Linear(config.hidden_size, self.head_dim)
        self.value = torch.nn.Linear(config.hidden_size, self.head_dim)

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        Q = self.query(x)
        K = self.key(x)
        V = self.value(x)

        # Scaled dot-product attention: softmax(QK^T / sqrt(d_k)) V
        scale = Q.size(-1) ** 0.5
        scores = torch.matmul(Q, K.transpose(-2, -1)) / scale

        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        attention = torch.nn.functional.softmax(scores, dim=-1)
        return torch.matmul(attention, V)

class MultiHeadAttention(torch.nn.Module):
    """Runs the attention heads in parallel and projects the concatenated output."""
    def __init__(self, config: Config):
        super().__init__()
        self.heads = torch.nn.ModuleList([AttentionHead(config) for _ in range(config.num_heads)])
        self.linear = torch.nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = torch.nn.Dropout(config.dropout)

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        heads = [head(x, mask) for head in self.heads]
        multihead = torch.cat(heads, dim=-1)
        return self.dropout(self.linear(multihead))

class TransformerBlock(torch.nn.Module):
    """Post-norm transformer block: attention and feed-forward, each with a residual connection."""
    def __init__(self, config: Config):
        super().__init__()
        self.attention = MultiHeadAttention(config)
        self.norm1 = torch.nn.LayerNorm(config.hidden_size)
        self.norm2 = torch.nn.LayerNorm(config.hidden_size)
        self.feed_forward = torch.nn.Sequential(
            torch.nn.Linear(config.hidden_size, config.intermediate_size),
            torch.nn.GELU(),
            torch.nn.Linear(config.intermediate_size, config.hidden_size),
            torch.nn.Dropout(config.dropout)
        )

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        attended = self.attention(x, mask)
        x = self.norm1(x + attended)
        fed_forward = self.feed_forward(x)
        return self.norm2(x + fed_forward)

class SmallLanguageModel(torch.nn.Module):
    """Decoder-only transformer language model with learned positional embeddings."""
    def __init__(self, config: Config):
        super().__init__()
        self.config = config
        self.token_embedding = torch.nn.Embedding(config.vocab_size, config.hidden_size)
        self.position_embedding = torch.nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.transformer_blocks = torch.nn.ModuleList([TransformerBlock(config) for _ in range(config.num_layers)])
        self.dropout = torch.nn.Dropout(config.dropout)
        self.ln_f = torch.nn.LayerNorm(config.hidden_size)
        self.head = torch.nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, (torch.nn.Linear, torch.nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if isinstance(module, torch.nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, torch.nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def get_causal_mask(self, size: int) -> torch.Tensor:
        # Lower-triangular mask: position i may attend to positions <= i.
        mask = torch.triu(torch.ones(size, size), diagonal=1).bool()
        return ~mask

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        b, t = input_ids.size()
        positions = torch.arange(0, t, dtype=torch.long, device=input_ids.device)
        mask = self.get_causal_mask(t).to(input_ids.device)
        token_embeddings = self.token_embedding(input_ids)
        position_embeddings = self.position_embedding(positions)
        x = self.dropout(token_embeddings + position_embeddings)
        for block in self.transformer_blocks:
            x = block(x, mask)
        x = self.ln_f(x)
        logits = self.head(x)
        return logits

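# Example: get_causal_mask(4) yields a lower-triangular boolean matrix, so each
# position can attend only to itself and earlier positions:
#   [[ True, False, False, False],
#    [ True,  True, False, False],
#    [ True,  True,  True, False],
#    [ True,  True,  True,  True]]
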
# Text Generator Class
class TextGenerator:
    def __init__(self, model, tokenizer):
        self.model = model
        self.model.eval()
        self.tokenizer = tokenizer

    @torch.no_grad()
    def generate(
        self,
        prompt: str,
        max_length: int = 100,
        temperature: float = 0.7,
        top_k: int = 50,
        top_p: float = 0.9
    ) -> Dict[str, str]:
        try:
            input_ids = torch.tensor(self.tokenizer.encode(
                prompt,
                allowed_special={'<user>', '</user>', '<assistant>', '</assistant>', '<system>', '</system>'}
            )).unsqueeze(0).to(device)

            for _ in range(max_length):
                # Truncate the context to the model's maximum sequence length.
                if input_ids.size(1) > self.model.config.max_position_embeddings:
                    input_ids = input_ids[:, -self.model.config.max_position_embeddings:]

                logits = self.model(input_ids)
                next_token_logits = logits[:, -1, :] / temperature

                # Top-k filtering: keep only the k highest-scoring tokens.
                if top_k > 0:
                    values, _ = torch.topk(next_token_logits, top_k)
                    min_value = values[:, -1].unsqueeze(-1)
                    next_token_logits = torch.where(
                        next_token_logits < min_value,
                        torch.tensor(float('-inf')).to(device),
                        next_token_logits
                    )

                # Top-p (nucleus) filtering: keep the smallest set of tokens whose
                # cumulative probability exceeds top_p.
                if top_p < 1.0:
                    sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
                    cumulative_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
                    sorted_indices_to_remove = cumulative_probs > top_p
                    # Shift right so the first token past the threshold is kept.
                    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                    sorted_indices_to_remove[..., 0] = 0
                    indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
                    next_token_logits = next_token_logits.masked_fill(indices_to_remove, float('-inf'))

                # Sample the next token and append it to the context.
                probs = torch.nn.functional.softmax(next_token_logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
                input_ids = torch.cat((input_ids, next_token), dim=1)

            generated_text = self.tokenizer.decode(input_ids[0].tolist())
            return {
                "status": "success",
                "generated_text": generated_text,
                "prompt": prompt,
                "max_length": max_length,
                "temperature": temperature,
                "top_k": top_k,
                "top_p": top_p
            }

        except Exception as e:
            return {
                "status": "error",
                "error_message": str(e),
                "prompt": prompt
            }

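# Worked example of the nucleus (top-p) filter above, with made-up logits:
# for next_token_logits = [2.0, 1.0, 0.5, -1.0] and top_p = 0.9, softmax gives
# probabilities of roughly [0.61, 0.22, 0.14, 0.03] and cumulative sums of
# roughly [0.61, 0.83, 0.97, 1.00]; after the right-shift, only the last token
# crosses the threshold, so it alone is masked to -inf before sampling.
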
# Helper Function to Load Model and Tokenizer
def load_model_and_tokenizer(checkpoint_path: str) -> Tuple[SmallLanguageModel, tiktoken.Encoding]:
    config = Config()
    # Extend the cl100k_base vocabulary with XML-style chat tags.
    cl100k_base = tiktoken.get_encoding("cl100k_base")
    tokenizer = tiktoken.Encoding(
        name="cl100k_xml",
        pat_str=cl100k_base._pat_str,
        mergeable_ranks=cl100k_base._mergeable_ranks,
        special_tokens={
            **cl100k_base._special_tokens,
            "<user>": 100277, "</user>": 100278,
            "<assistant>": 100279, "</assistant>": 100280,
            "<system>": 100281, "</system>": 100282
        }
    )
    config.vocab_size = tokenizer.n_vocab

    model = SmallLanguageModel(config)
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)

    return model, tokenizer

# Main Function for Inference
def generate(
    checkpoint_path: str,
    prompt: str,
    max_length: int = 100,
    temperature: float = 0.7,
    top_k: int = 50,
    top_p: float = 0.9
) -> Dict[str, str]:
    # Load model and tokenizer
    model, tokenizer = load_model_and_tokenizer(checkpoint_path)

    # Generate text
    generator = TextGenerator(model, tokenizer)
    result = generator.generate(
        prompt=prompt,
        max_length=max_length,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p
    )

    return result
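
A minimal usage sketch (the checkpoint path and prompt are illustrative; the
checkpoint file must contain a 'model_state_dict' entry, as
load_model_and_tokenizer expects):

from inference import generate

result = generate(
    checkpoint_path="checkpoints/model.pt",  # hypothetical path
    prompt="<user>Hello!</user><assistant>",
    max_length=50,
)
if result["status"] == "success":
    print(result["generated_text"])
else:
    print("Generation failed:", result["error_message"])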