Spaces · Running on Zero
AbstractPhil committed · Commit a2f6c58 · Parent(s): dd4aeba
Commit message: yes
app.py
CHANGED
@@ -11,7 +11,7 @@ from typing import List, Dict, Optional, Any
 from datetime import datetime
 import gradio as gr
 import spaces  # required for ZeroGPU
-from transformers import AutoTokenizer, AutoModelForCausalLM
+from transformers import AutoTokenizer, AutoModelForCausalLM

 # Import Harmony components
 try:
@@ -47,7 +47,7 @@ ZEROGPU = os.getenv("ZEROGPU", os.getenv("ZERO_GPU", "0")) == "1"
 LOAD_4BIT = os.getenv("LOAD_4BIT", "0") == "1"

 # Harmony channels for CoT
-REQUIRED_CHANNELS = ["analysis", "final"]
+REQUIRED_CHANNELS = ["analysis", "commentary", "final"]

 # HF Auth - properly handle multiple token env var names
 HF_TOKEN: Optional[str] = (
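For context on the channel list above: Harmony-formatted completions tag each assistant message with a channel (analysis, commentary, final), and only the final channel is surfaced to the user. A rough, illustrative sketch of splitting decoded output by those markers; the helper name and logic are hypothetical and not part of this commit:

# Illustrative only; app.py's real parser/fallback is not shown in this diff.
def split_channels(decoded: str) -> dict:
    """Map channel name -> text using Harmony markers <|channel|>NAME<|message|>...<|end|>."""
    out = {}
    for name in ("analysis", "commentary", "final"):
        marker = f"<|channel|>{name}<|message|>"
        if marker not in decoded:
            continue
        text = decoded.split(marker, 1)[1]
        for stop in ("<|end|>", "<|return|>"):  # Harmony message terminators
            if stop in text:
                text = text.split(stop, 1)[0]
        out[name] = text.strip()
    return out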
@@ -138,7 +138,6 @@ def _load_model_on(device_map: Optional[str]) -> AutoModelForCausalLM:
     # peft_kwargs["subfolder"] = ADAPTER_SUBFOLDER
     # model = PeftModel.from_pretrained(model, ADAPTER_ID, is_trainable=False, **peft_kwargs)

-
     model.eval()
     # Ensure a valid pad_token_id is set; some OSS checkpoints reuse eos as pad
     if getattr(model.config, "pad_token_id", None) is None:
@@ -190,15 +189,7 @@ def create_harmony_prompt(messages: List[Dict[str, str]], reasoning_effort: str
     )

     convo = Conversation.from_messages(harmony_messages)
-
-    # Ensure assistant header includes a final channel + message start to avoid 'assistantassistant...' loops
-    try:
-        _tail = tokenizer.decode(list(rendered)[-64:], skip_special_tokens=False)
-        if '<|channel|>final<|message|>' not in _tail:
-            rendered = list(rendered) + tokenizer.encode('<|channel|>final<|message|>', add_special_tokens=False)
-    except Exception:
-        rendered = list(rendered)
-    return rendered
+    return harmony_encoding.render_conversation_for_completion(convo, Role.ASSISTANT)

     # Fallback: tokenizer chat template -> string prompt
     if not messages or messages[0].get("role") != "system":
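The new return path above delegates prompt construction entirely to the Harmony renderer. A minimal sketch of that flow with the openai_harmony package (message contents are placeholders; in the Space, harmony_messages is built from the chat history):

from openai_harmony import (
    HarmonyEncodingName, load_harmony_encoding,
    Conversation, Message, Role, SystemContent, DeveloperContent,
)

harmony_encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)

convo = Conversation.from_messages([
    Message.from_role_and_content(Role.SYSTEM, SystemContent.new()),
    Message.from_role_and_content(Role.DEVELOPER, DeveloperContent.new().with_instructions("Answer concisely.")),
    Message.from_role_and_content(Role.USER, "What is 2+2?"),
])

# A list of token ids that already ends with the assistant header,
# so generation continues directly into the assistant's channels.
prompt_ids = harmony_encoding.render_conversation_for_completion(convo, Role.ASSISTANT)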
@@ -282,7 +273,7 @@ def build_bias_from_tokens(tokenizer, mapping: Dict[str, float]) -> torch.Tensor
         for t in tid:
             if isinstance(t, int) and t >= 0:
                 bias[t] += float(w) / max(1, len(tid))
-        elif isinstance(tid, int) and
+        elif isinstance(tid, int) and tid >= 0:
             bias[tid] += float(w)
     return bias

@@ -295,12 +286,6 @@ class RoseGuidedLogits(torch.nn.Module):
     def forward(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
         return scores + self.alpha * self.bias_vec.to(scores.device)

-class StopOnTokens(StoppingCriteria):
-    def __init__(self, stop_ids: List[int]):
-        self.stop_ids = set(int(s) for s in (stop_ids or []))
-    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs):
-        return int(input_ids[0, -1]) in self.stop_ids
-
 @spaces.GPU(duration=120)
 def zerogpu_generate(full_prompt,
                      gen_kwargs: Dict[str, Any],
@@ -325,42 +310,21 @@ def zerogpu_generate(full_prompt,

     # Tokenize / prepare inputs
     device = next(model.parameters()).device
-    if HARMONY_AVAILABLE and
-
-        try:
-            token_list = list(full_prompt)
-        except TypeError:
-            token_list = list(getattr(full_prompt, "ids", getattr(full_prompt, "token_ids", [])))
-        if not token_list:
-            raise ValueError("Harmony prompt produced no tokens")
-        input_ids = torch.tensor([token_list], dtype=torch.long, device=device)
+    if HARMONY_AVAILABLE and isinstance(full_prompt, list):
+        input_ids = torch.tensor([full_prompt], dtype=torch.long, device=device)
         attention_mask = torch.ones_like(input_ids, dtype=torch.long, device=device)
         inputs = {"input_ids": input_ids, "attention_mask": attention_mask}
         prompt_len = input_ids.shape[1]
     else:
         enc = tokenizer(full_prompt, return_tensors="pt")
-        inputs =
+        inputs = enc.to(device)
         prompt_len = int(inputs["input_ids"].shape[1])
+        # Guarantee attention_mask exists; avoids pad==eos ambiguity warnings
         if "attention_mask" not in inputs:
             inputs["attention_mask"] = torch.ones_like(inputs["input_ids"], dtype=torch.long, device=device)
-
-    # Prepare stopping
-    sc = None
-    if HARMONY_AVAILABLE and HARMONY_STOP_IDS:
-        sc = StoppingCriteriaList([StopOnTokens(HARMONY_STOP_IDS)])
-
     # Generate
-    #
-
-    try:
-        _B = []
-        for s in ("assistantassistant", "assistant", "<|assistant|>"):
-            ids = tokenizer.encode(s, add_special_tokens=False)
-            if ids:
-                _B.append(ids)
-        bad_words_ids = _B if _B else None
-    except Exception:
-        pass
+    # Build EOS list: use ONLY Harmony assistant-action stops (per OpenAI docs)
+    eos_ids = HARMONY_STOP_IDS if HARMONY_AVAILABLE else tokenizer.eos_token_id

     out_ids = model.generate(
         **inputs,
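The eos_ids line above replaces the old StoppingCriteria path: instead of a custom stopper, the Harmony assistant-action stop tokens are handed directly to generate() as eos_token_id. A sketch of where those ids typically come from and how the completion is parsed back into channels (variable names follow the surrounding code; the exact HARMONY_STOP_IDS construction in app.py is an assumption):

# Assumed: harmony_encoding as rendered above; model/inputs as in zerogpu_generate.
HARMONY_STOP_IDS = harmony_encoding.stop_tokens_for_assistant_actions()  # e.g. <|return|>, <|call|>

out_ids = model.generate(
    **inputs,
    max_new_tokens=256,
    eos_token_id=HARMONY_STOP_IDS,           # stop on any Harmony assistant-action token
    pad_token_id=model.config.pad_token_id,
)
gen_ids = out_ids[0, inputs["input_ids"].shape[1]:].tolist()

# Parse the raw completion tokens back into channel-tagged messages
# (attribute access follows openai_harmony's Message/TextContent types).
messages = harmony_encoding.parse_messages_from_completion_tokens(gen_ids, Role.ASSISTANT)
final_text = "".join(part.text for m in messages if m.channel == "final" for part in m.content)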
@@ -370,12 +334,11 @@ def zerogpu_generate(full_prompt,
         top_k=(int(gen_kwargs.get("top_k")) if gen_kwargs.get("top_k") and int(gen_kwargs.get("top_k")) > 0 else None),
         max_new_tokens=int(gen_kwargs.get("max_new_tokens", MAX_DEF)),
         pad_token_id=model.config.pad_token_id,
-        eos_token_id=
-        bad_words_ids=bad_words_ids,
+        eos_token_id=eos_ids,
         logits_processor=logits_processor,
-        repetition_penalty=float(gen_kwargs.get("repetition_penalty", 1.
-        no_repeat_ngram_size=int(gen_kwargs.get("no_repeat_ngram_size",
-
+        repetition_penalty=float(gen_kwargs.get("repetition_penalty", 1.1)),
+        no_repeat_ngram_size=int(gen_kwargs.get("no_repeat_ngram_size", 6)),
+        min_new_tokens=1,
     )

     # Extract generated tokens only
@@ -421,93 +384,6 @@ def zerogpu_generate(full_prompt,
     if torch.cuda.is_available():
         torch.cuda.empty_cache()

-# -----------------------
-# GPU Debug: Harmony Inspector
-# -----------------------
-@spaces.GPU(duration=120)
-def zerogpu_generate_debug(full_prompt, gen_kwargs: Dict[str, Any]) -> Dict[str, Any]:
-    """Minimal GPU path to run a single prompt and return Harmony-parsed output
-    along with short token previews for debugging. Does not use Rose for clarity."""
-    model = None
-    try:
-        model = _load_model_on("auto")
-        device = next(model.parameters()).device
-
-        # Prepare inputs (tokens if Harmony renderer used, else string -> encode)
-        if HARMONY_AVAILABLE and not isinstance(full_prompt, str):
-            token_list = list(full_prompt)
-            if not token_list:
-                raise ValueError("Harmony prompt produced no tokens")
-            input_ids = torch.tensor([token_list], dtype=torch.long, device=device)
-            attention_mask = torch.ones_like(input_ids, dtype=torch.long, device=device)
-            inputs = {"input_ids": input_ids, "attention_mask": attention_mask}
-            prompt_len = input_ids.shape[1]
-        else:
-            enc = tokenizer(full_prompt, return_tensors="pt")
-            inputs = {k: v.to(device) for k, v in enc.items()}
-            if "attention_mask" not in inputs:
-                inputs["attention_mask"] = torch.ones_like(inputs["input_ids"], dtype=torch.long, device=device)
-            prompt_len = int(inputs["input_ids"].shape[1])
-
-        # Harmony stop via stopping criteria
-        sc = StoppingCriteriaList([StopOnTokens(HARMONY_STOP_IDS)]) if (HARMONY_AVAILABLE and HARMONY_STOP_IDS) else None
-
-        out_ids = model.generate(
-            **inputs,
-            do_sample=bool(gen_kwargs.get("do_sample", True)),
-            temperature=float(gen_kwargs.get("temperature", 0.7)),
-            top_p=float(gen_kwargs.get("top_p", 0.9)),
-            top_k=(int(gen_kwargs.get("top_k")) if gen_kwargs.get("top_k") and int(gen_kwargs.get("top_k")) > 0 else None),
-            max_new_tokens=int(gen_kwargs.get("max_new_tokens", MAX_DEF)),
-            pad_token_id=model.config.pad_token_id,
-            eos_token_id=tokenizer.eos_token_id,
-            bad_words_ids=bad_words_ids,
-            stopping_criteria=sc,
-            repetition_penalty=float(gen_kwargs.get("repetition_penalty", 1.15)),
-            no_repeat_ngram_size=int(gen_kwargs.get("no_repeat_ngram_size", 6)),
-        )
-
-        out_list = out_ids[0].tolist()
-        gen_ids = out_list[prompt_len:]
-        # Truncate at first Harmony stop token if present
-        if HARMONY_AVAILABLE and HARMONY_STOP_IDS:
-            for sid in HARMONY_STOP_IDS:
-                if sid in gen_ids:
-                    gen_ids = gen_ids[:gen_ids.index(sid)]
-                    break
-
-        # Parse channels
-        if HARMONY_AVAILABLE:
-            try:
-                channels = parse_harmony_response(gen_ids)
-            except Exception:
-                decoded = tokenizer.decode(gen_ids, skip_special_tokens=False)
-                channels = {"final": extract_final_channel_fallback(decoded), "raw": decoded}
-        else:
-            decoded = tokenizer.decode(gen_ids, skip_special_tokens=False)
-            channels = {"final": extract_final_channel_fallback(decoded), "raw": decoded}
-
-        # Small previews (avoid flooding logs/UI)
-        preview = {
-            "prompt_len": int(prompt_len),
-            "stop_ids": list(HARMONY_STOP_IDS) if HARMONY_AVAILABLE else [],
-            "gen_len": int(len(gen_ids)),
-            "gen_ids_head": gen_ids[:48],
-            "decoded_head": tokenizer.decode(gen_ids[:256], skip_special_tokens=False),
-            "channels": channels,
-        }
-        return preview
-    except Exception as e:
-        return {"error": f"{type(e).__name__}: {e}"}
-    finally:
-        try:
-            del model
-        except Exception:
-            pass
-        gc.collect()
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
-
 # -----------------------
 # Gradio handlers
 # -----------------------
@@ -605,21 +481,6 @@ def generate_response(message: str, history: List[List[str]], system_prompt: str
     except Exception as e:
         return f"[Error] {type(e).__name__}: {str(e)}"

-# -----------------------
-# Extra handler: Harmony Inspector wrapper
-# -----------------------
-
-def harmony_inspect_handler(user_prompt: str, system_prompt: str, reasoning_effort: str):
-    try:
-        msgs = [{"role": "system", "content": system_prompt or SYSTEM_DEF}, {"role": "user", "content": user_prompt or "What is 2+2?"}]
-        prompt = create_harmony_prompt(msgs, reasoning_effort)
-        return zerogpu_generate_debug(
-            prompt,
-            {"do_sample": True, "temperature": 0.7, "top_p": 0.9, "top_k": 0, "max_new_tokens": MAX_DEF}
-        )
-    except Exception as e:
-        return {"error": f"{type(e).__name__}: {e}"}
-
 # -----------------------
 # UI
 # -----------------------
@@ -681,13 +542,6 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
             value=""
         )

-    # --- Harmony Inspector UI ---
-    with gr.Accordion("Harmony Inspector", open=False):
-        debug_prompt = gr.Textbox(label="Debug prompt", value="What is 2+2? Reply with just the number.")
-        run_debug = gr.Button("Run Harmony Inspect")
-        debug_out = gr.JSON(label="Parsed Harmony output", value={})
-        run_debug.click(harmony_inspect_handler, inputs=[debug_prompt, system_prompt, reasoning_effort], outputs=[debug_out])
-
     # Chat interface - using only valid parameters
     chat = gr.ChatInterface(
         fn=generate_response,
@@ -697,7 +551,6 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
             do_sample, seed, rose_enable, rose_alpha, rose_score,
             rose_tokens, rose_json, show_thinking, reasoning_effort
         ],
-
         title="Chat with Mirel",
         description="A chain-of-thought model using Harmony format",
         examples=[