Will Held
committed on
Commit · b7ca827
1 Parent(s): 259eb63
Add Stream
modeling_diva.py +65 -0
modeling_diva.py CHANGED
@@ -243,3 +243,68 @@ class DiVAModel(PreTrainedModel):
         return self.tokenizer.decode(outs, skip_special_tokens=True).replace(
             "<|eot_id|>", ""
         )
+
+    def generate_stream(
+        self, audio, text_prompt, do_sample=False, logits_processor=None,
+        max_new_tokens=128, temperature=1.0,
+    ):
+        # Encode the raw audio with the Whisper encoder.
+        inputs = self.processor(audio, return_tensors="pt", sampling_rate=16_000)
+        input_features = inputs.input_features.to(self.speech_encoder_device)
+        hidden_states = self.whisper_encoder(input_features=input_features)[
+            "last_hidden_state"
+        ]
+        # Project the speech features into the decoder's embedding space.
+        virt_tokens = self.connector(
+            hidden_states,
+            output_device=self.llama_decoder.model.embed_tokens.weight.device,
+        ).squeeze()
+
+        if text_prompt is not None and text_prompt != "":
+            user_prompt_text = torch.tensor(
+                self.tokenizer(text_prompt, add_special_tokens=False)["input_ids"],
+                device=self.pre_user_suffix.device,
+            )
+            prefix = torch.cat(
+                [self.pre_user_suffix, user_prompt_text, self.prefix], axis=0
+            )
+        else:
+            prefix = self.prefix
+        prefix_embed = self.llama_decoder.model.embed_tokens(prefix)
+        suffix = self.final_header
+        suffix_embed = self.llama_decoder.model.embed_tokens(suffix)
+        inputs_embeds = torch.cat(
+            [prefix_embed, virt_tokens, suffix_embed], axis=0
+        ).unsqueeze(0)
+        outs = []
+        outputs = None
+        # 128009 is the Llama 3 <|eot_id|> token id; generation stops on it.
+        greedy = 1
+        while greedy != 128009 and len(outs) < max_new_tokens:
+            # Reuse the KV cache so each step only feeds the newest embedding.
+            past_key_values = outputs.past_key_values if outputs else None
+            outputs = self.llama_decoder(
+                inputs_embeds=inputs_embeds.to(
+                    self.llama_decoder.model.embed_tokens.weight.device
+                ).half(),
+                return_dict=True,
+                output_hidden_states=True,
+                past_key_values=past_key_values,
+            )
+            next_token_logits = outputs.logits[-1, -1, :]
+
+            if logits_processor:
+                local_outs = torch.tensor(outs) if outs != [] else suffix
+                local_outs = local_outs.reshape(1, -1)
+                next_token_logits = logits_processor(
+                    local_outs,
+                    next_token_logits.reshape(1, -1),
+                )
+                next_token_logits = next_token_logits.flatten()
+            if do_sample:
+                logits = next_token_logits / temperature
+                probs = F.softmax(logits, dim=-1)
+                greedy = torch.multinomial(probs, num_samples=1)[0]
+            else:
+                greedy = next_token_logits.argmax()
+            outs.append(greedy)
+            next_embed = self.llama_decoder.model.embed_tokens(greedy.reshape(1, 1))
+            inputs_embeds = next_embed
+            # Yield the full decoded text so far after every step.
+            yield self.tokenizer.decode(outs).replace("<|eot_id|>", "")
+        return self.tokenizer.decode(outs).replace("<|eot_id|>", "")
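
For reference, a minimal sketch of how the new streaming generator could be consumed. The checkpoint id, audio file, and loading via AutoModel with trust_remote_code are illustrative assumptions, not part of this commit; any transformers logits processor that accepts (input_ids, scores) fits the logits_processor hook.

# Illustrative sketch only: the checkpoint id, audio path, and loading via
# AutoModel(trust_remote_code=True) are assumptions, not part of this commit.
import librosa
from transformers import AutoModel, LogitsProcessorList, RepetitionPenaltyLogitsProcessor

model = AutoModel.from_pretrained(
    "WillHeld/DiVA-llama-3-v0-8b", trust_remote_code=True  # hypothetical checkpoint id
)
speech, _ = librosa.load("question.wav", sr=16_000)  # generate_stream expects 16 kHz audio

# Optional: any processor matching the (input_ids, scores) call convention.
processors = LogitsProcessorList([RepetitionPenaltyLogitsProcessor(1.2)])

# Each yield is the full decoded text so far, so print only the new suffix.
printed = 0
for partial in model.generate_stream(speech, "Answer briefly.", logits_processor=processors):
    print(partial[printed:], end="", flush=True)
    printed = len(partial)
print()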