AbstractPhil committed on
Commit a2f6c58 · 1 Parent(s): dd4aeba
Files changed (1)
  1. app.py +14 -161
app.py CHANGED
@@ -11,7 +11,7 @@ from typing import List, Dict, Optional, Any
 from datetime import datetime
 import gradio as gr
 import spaces  # required for ZeroGPU
-from transformers import AutoTokenizer, AutoModelForCausalLM, StoppingCriteria, StoppingCriteriaList
+from transformers import AutoTokenizer, AutoModelForCausalLM
 
 # Import Harmony components
 try:
@@ -47,7 +47,7 @@ ZEROGPU = os.getenv("ZEROGPU", os.getenv("ZERO_GPU", "0")) == "1"
 LOAD_4BIT = os.getenv("LOAD_4BIT", "0") == "1"
 
 # Harmony channels for CoT
-REQUIRED_CHANNELS = ["analysis", "final"]
+REQUIRED_CHANNELS = ["analysis", "commentary", "final"]
 
 # HF Auth - properly handle multiple token env var names
 HF_TOKEN: Optional[str] = (
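
Note: the expanded REQUIRED_CHANNELS list matches Harmony's three assistant channels: "analysis" carries the chain of thought, "commentary" carries tool/preamble traffic, and "final" holds the user-facing answer. Below is a minimal sketch of how those channels could be recovered from completion tokens with openai_harmony; the helper name is illustrative and Message attribute details may vary slightly between package versions.

from openai_harmony import HarmonyEncodingName, Role, load_harmony_encoding

encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)

def split_channels(completion_tokens):
    # Parse the assistant completion back into Harmony messages and group
    # their text by channel ("analysis", "commentary", "final").
    messages = encoding.parse_messages_from_completion_tokens(completion_tokens, Role.ASSISTANT)
    channels = {}
    for msg in messages:
        text = "".join(getattr(part, "text", "") for part in msg.content)
        channels.setdefault(msg.channel, []).append(text)
    return {ch: "\n".join(parts) for ch, parts in channels.items()}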
@@ -138,7 +138,6 @@ def _load_model_on(device_map: Optional[str]) -> AutoModelForCausalLM:
     # peft_kwargs["subfolder"] = ADAPTER_SUBFOLDER
     # model = PeftModel.from_pretrained(model, ADAPTER_ID, is_trainable=False, **peft_kwargs)
 
-
     model.eval()
     # Ensure a valid pad_token_id is set; some OSS checkpoints reuse eos as pad
     if getattr(model.config, "pad_token_id", None) is None:
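
The retained comment explains why the pad token is patched after loading: several open checkpoints ship without a pad token and reuse EOS instead, which makes generate() warn about ambiguous padding. The body of that branch is not shown in this hunk; a hypothetical sketch of the fallback would be:

# Illustrative only; the assignment below is an assumption about the elided body.
if getattr(model.config, "pad_token_id", None) is None:
    model.config.pad_token_id = (
        tokenizer.pad_token_id
        if tokenizer.pad_token_id is not None
        else tokenizer.eos_token_id
    )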
@@ -190,15 +189,7 @@ def create_harmony_prompt(messages: List[Dict[str, str]], reasoning_effort: str
             )
 
         convo = Conversation.from_messages(harmony_messages)
-        rendered = harmony_encoding.render_conversation_for_completion(convo, Role.ASSISTANT)
-        # Ensure assistant header includes a final channel + message start to avoid 'assistantassistant...' loops
-        try:
-            _tail = tokenizer.decode(list(rendered)[-64:], skip_special_tokens=False)
-            if '<|channel|>final<|message|>' not in _tail:
-                rendered = list(rendered) + tokenizer.encode('<|channel|>final<|message|>', add_special_tokens=False)
-        except Exception:
-            rendered = list(rendered)
-        return rendered
+        return harmony_encoding.render_conversation_for_completion(convo, Role.ASSISTANT)
 
     # Fallback: tokenizer chat template -> string prompt
     if not messages or messages[0].get("role") != "system":
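
This hunk drops the manual <|channel|>final<|message|> tail patch and returns whatever the Harmony renderer produces; the renderer already ends the prompt at the opening of the assistant turn, so the extra splice is unnecessary. A minimal sketch of the renderer call, assuming the openai_harmony package (the message construction here is illustrative, not the exact code in create_harmony_prompt):

from openai_harmony import (
    Conversation, HarmonyEncodingName, Message, Role, load_harmony_encoding,
)

encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
convo = Conversation.from_messages([
    Message.from_role_and_content(Role.SYSTEM, "You are Mirel."),
    Message.from_role_and_content(Role.USER, "What is 2+2?"),
])
# Token ids for the rendered conversation plus the opening of the assistant
# turn; they can be fed straight to model.generate(...) as input_ids.
prefill_ids = encoding.render_conversation_for_completion(convo, Role.ASSISTANT)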
@@ -282,7 +273,7 @@ def build_bias_from_tokens(tokenizer, mapping: Dict[str, float]) -> torch.Tensor
            for t in tid:
                if isinstance(t, int) and t >= 0:
                    bias[t] += float(w) / max(1, len(tid))
-        elif isinstance(tid, int) and t >= 0:
+        elif isinstance(tid, int) and tid >= 0:
            bias[tid] += float(w)
    return bias
 
@@ -295,12 +286,6 @@ class RoseGuidedLogits(torch.nn.Module):
     def forward(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
         return scores + self.alpha * self.bias_vec.to(scores.device)
 
-class StopOnTokens(StoppingCriteria):
-    def __init__(self, stop_ids: List[int]):
-        self.stop_ids = set(int(s) for s in (stop_ids or []))
-    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs):
-        return int(input_ids[0, -1]) in self.stop_ids
-
 @spaces.GPU(duration=120)
 def zerogpu_generate(full_prompt,
                      gen_kwargs: Dict[str, Any],
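
RoseGuidedLogits, kept in the surrounding context, adds a scaled bias vector to the scores at every decoding step, so it plugs into generate() as an ordinary logits processor. Below is a sketch of that wiring; the RoseGuidedLogits constructor arguments and the example token weights are assumptions rather than the exact call in app.py:

from transformers import LogitsProcessorList

# build_bias_from_tokens returns a vocab-sized tensor of per-token weights.
bias_vec = build_bias_from_tokens(tokenizer, {"lighthouse": 2.0, "storm": -1.0})

# An nn.Module with forward(input_ids, scores) is callable like a logits processor.
logits_processor = LogitsProcessorList([RoseGuidedLogits(bias_vec, alpha=0.8)])

out_ids = model.generate(**inputs, logits_processor=logits_processor, max_new_tokens=64)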
@@ -325,42 +310,21 @@ def zerogpu_generate(full_prompt,
 
         # Tokenize / prepare inputs
         device = next(model.parameters()).device
-        if HARMONY_AVAILABLE and not isinstance(full_prompt, str):
-            # Accept list/tuple or any iterable of ints from openai_harmony
-            try:
-                token_list = list(full_prompt)
-            except TypeError:
-                token_list = list(getattr(full_prompt, "ids", getattr(full_prompt, "token_ids", [])))
-            if not token_list:
-                raise ValueError("Harmony prompt produced no tokens")
-            input_ids = torch.tensor([token_list], dtype=torch.long, device=device)
+        if HARMONY_AVAILABLE and isinstance(full_prompt, list):
+            input_ids = torch.tensor([full_prompt], dtype=torch.long, device=device)
             attention_mask = torch.ones_like(input_ids, dtype=torch.long, device=device)
             inputs = {"input_ids": input_ids, "attention_mask": attention_mask}
             prompt_len = input_ids.shape[1]
         else:
             enc = tokenizer(full_prompt, return_tensors="pt")
-            inputs = {k: v.to(device) for k, v in enc.items()}
+            inputs = enc.to(device)
             prompt_len = int(inputs["input_ids"].shape[1])
+        # Guarantee attention_mask exists; avoids pad==eos ambiguity warnings
         if "attention_mask" not in inputs:
             inputs["attention_mask"] = torch.ones_like(inputs["input_ids"], dtype=torch.long, device=device)
-
-        # Prepare stopping
-        sc = None
-        if HARMONY_AVAILABLE and HARMONY_STOP_IDS:
-            sc = StoppingCriteriaList([StopOnTokens(HARMONY_STOP_IDS)])
-
         # Generate
-        # Disallow degenerate header loops
-        bad_words_ids = None
-        try:
-            _B = []
-            for s in ("assistantassistant", "assistant", "<|assistant|>"):
-                ids = tokenizer.encode(s, add_special_tokens=False)
-                if ids:
-                    _B.append(ids)
-            bad_words_ids = _B if _B else None
-        except Exception:
-            pass
+        # Build EOS list: use ONLY Harmony assistant-action stops (per OpenAI docs)
+        eos_ids = HARMONY_STOP_IDS if HARMONY_AVAILABLE else tokenizer.eos_token_id
 
         out_ids = model.generate(
             **inputs,
@@ -370,12 +334,11 @@ def zerogpu_generate(full_prompt,
             top_k=(int(gen_kwargs.get("top_k")) if gen_kwargs.get("top_k") and int(gen_kwargs.get("top_k")) > 0 else None),
             max_new_tokens=int(gen_kwargs.get("max_new_tokens", MAX_DEF)),
             pad_token_id=model.config.pad_token_id,
-            eos_token_id=tokenizer.eos_token_id,
-            bad_words_ids=bad_words_ids,
+            eos_token_id=eos_ids,
             logits_processor=logits_processor,
-            repetition_penalty=float(gen_kwargs.get("repetition_penalty", 1.2)),
-            no_repeat_ngram_size=int(gen_kwargs.get("no_repeat_ngram_size", 8)),
-            stopping_criteria=sc,
+            repetition_penalty=float(gen_kwargs.get("repetition_penalty", 1.1)),
+            no_repeat_ngram_size=int(gen_kwargs.get("no_repeat_ngram_size", 6)),
+            min_new_tokens=1,
         )
 
         # Extract generated tokens only
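
The net effect of the last two hunks is to stop generation on Harmony's own assistant-action stop tokens instead of a custom StoppingCriteria plus bad_words_ids: transformers accepts either an int or a list of ids for eos_token_id, so the stop set can be passed directly. A sketch, under the assumption that HARMONY_STOP_IDS is derived from the Harmony encoding elsewhere in app.py:

# Harmony's stop set covers tokens such as <|return|> and <|call|> that end an
# assistant action; passing them as eos_token_id lets generate() halt natively.
HARMONY_STOP_IDS = harmony_encoding.stop_tokens_for_assistant_actions()

eos_ids = HARMONY_STOP_IDS if HARMONY_AVAILABLE else tokenizer.eos_token_id
out_ids = model.generate(
    **inputs,
    eos_token_id=eos_ids,                     # int or list of ints both work
    pad_token_id=model.config.pad_token_id,
    max_new_tokens=256,
)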
@@ -421,93 +384,6 @@ def zerogpu_generate(full_prompt,
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
 
-# -----------------------
-# GPU Debug: Harmony Inspector
-# -----------------------
-@spaces.GPU(duration=120)
-def zerogpu_generate_debug(full_prompt, gen_kwargs: Dict[str, Any]) -> Dict[str, Any]:
-    """Minimal GPU path to run a single prompt and return Harmony-parsed output
-    along with short token previews for debugging. Does not use Rose for clarity."""
-    model = None
-    try:
-        model = _load_model_on("auto")
-        device = next(model.parameters()).device
-
-        # Prepare inputs (tokens if Harmony renderer used, else string -> encode)
-        if HARMONY_AVAILABLE and not isinstance(full_prompt, str):
-            token_list = list(full_prompt)
-            if not token_list:
-                raise ValueError("Harmony prompt produced no tokens")
-            input_ids = torch.tensor([token_list], dtype=torch.long, device=device)
-            attention_mask = torch.ones_like(input_ids, dtype=torch.long, device=device)
-            inputs = {"input_ids": input_ids, "attention_mask": attention_mask}
-            prompt_len = input_ids.shape[1]
-        else:
-            enc = tokenizer(full_prompt, return_tensors="pt")
-            inputs = {k: v.to(device) for k, v in enc.items()}
-            if "attention_mask" not in inputs:
-                inputs["attention_mask"] = torch.ones_like(inputs["input_ids"], dtype=torch.long, device=device)
-            prompt_len = int(inputs["input_ids"].shape[1])
-
-        # Harmony stop via stopping criteria
-        sc = StoppingCriteriaList([StopOnTokens(HARMONY_STOP_IDS)]) if (HARMONY_AVAILABLE and HARMONY_STOP_IDS) else None
-
-        out_ids = model.generate(
-            **inputs,
-            do_sample=bool(gen_kwargs.get("do_sample", True)),
-            temperature=float(gen_kwargs.get("temperature", 0.7)),
-            top_p=float(gen_kwargs.get("top_p", 0.9)),
-            top_k=(int(gen_kwargs.get("top_k")) if gen_kwargs.get("top_k") and int(gen_kwargs.get("top_k")) > 0 else None),
-            max_new_tokens=int(gen_kwargs.get("max_new_tokens", MAX_DEF)),
-            pad_token_id=model.config.pad_token_id,
-            eos_token_id=tokenizer.eos_token_id,
-            bad_words_ids=bad_words_ids,
-            stopping_criteria=sc,
-            repetition_penalty=float(gen_kwargs.get("repetition_penalty", 1.15)),
-            no_repeat_ngram_size=int(gen_kwargs.get("no_repeat_ngram_size", 6)),
-        )
-
-        out_list = out_ids[0].tolist()
-        gen_ids = out_list[prompt_len:]
-        # Truncate at first Harmony stop token if present
-        if HARMONY_AVAILABLE and HARMONY_STOP_IDS:
-            for sid in HARMONY_STOP_IDS:
-                if sid in gen_ids:
-                    gen_ids = gen_ids[:gen_ids.index(sid)]
-                    break
-
-        # Parse channels
-        if HARMONY_AVAILABLE:
-            try:
-                channels = parse_harmony_response(gen_ids)
-            except Exception:
-                decoded = tokenizer.decode(gen_ids, skip_special_tokens=False)
-                channels = {"final": extract_final_channel_fallback(decoded), "raw": decoded}
-        else:
-            decoded = tokenizer.decode(gen_ids, skip_special_tokens=False)
-            channels = {"final": extract_final_channel_fallback(decoded), "raw": decoded}
-
-        # Small previews (avoid flooding logs/UI)
-        preview = {
-            "prompt_len": int(prompt_len),
-            "stop_ids": list(HARMONY_STOP_IDS) if HARMONY_AVAILABLE else [],
-            "gen_len": int(len(gen_ids)),
-            "gen_ids_head": gen_ids[:48],
-            "decoded_head": tokenizer.decode(gen_ids[:256], skip_special_tokens=False),
-            "channels": channels,
-        }
-        return preview
-    except Exception as e:
-        return {"error": f"{type(e).__name__}: {e}"}
-    finally:
-        try:
-            del model
-        except Exception:
-            pass
-        gc.collect()
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
-
 # -----------------------
 # Gradio handlers
 # -----------------------
@@ -605,21 +481,6 @@ def generate_response(message: str, history: List[List[str]], system_prompt: str
     except Exception as e:
         return f"[Error] {type(e).__name__}: {str(e)}"
 
-# -----------------------
-# Extra handler: Harmony Inspector wrapper
-# -----------------------
-
-def harmony_inspect_handler(user_prompt: str, system_prompt: str, reasoning_effort: str):
-    try:
-        msgs = [{"role": "system", "content": system_prompt or SYSTEM_DEF}, {"role": "user", "content": user_prompt or "What is 2+2?"}]
-        prompt = create_harmony_prompt(msgs, reasoning_effort)
-        return zerogpu_generate_debug(
-            prompt,
-            {"do_sample": True, "temperature": 0.7, "top_p": 0.9, "top_k": 0, "max_new_tokens": MAX_DEF}
-        )
-    except Exception as e:
-        return {"error": f"{type(e).__name__}: {e}"}
-
 # -----------------------
 # UI
 # -----------------------
@@ -681,13 +542,6 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
                 value=""
             )
 
-    # --- Harmony Inspector UI ---
-    with gr.Accordion("Harmony Inspector", open=False):
-        debug_prompt = gr.Textbox(label="Debug prompt", value="What is 2+2? Reply with just the number.")
-        run_debug = gr.Button("Run Harmony Inspect")
-        debug_out = gr.JSON(label="Parsed Harmony output", value={})
-        run_debug.click(harmony_inspect_handler, inputs=[debug_prompt, system_prompt, reasoning_effort], outputs=[debug_out])
-
     # Chat interface - using only valid parameters
     chat = gr.ChatInterface(
         fn=generate_response,
@@ -697,7 +551,6 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
            do_sample, seed, rose_enable, rose_alpha, rose_score,
            rose_tokens, rose_json, show_thinking, reasoning_effort
        ],
-
        title="Chat with Mirel",
        description="A chain-of-thought model using Harmony format",
        examples=[
 