File size: 2,177 Bytes
3193f76
 
 
 
 
c6b3a35
3193f76
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ed3abd9
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import re
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import gradio as gr

# Hub repository id used for both the tokenizer and the model weights below.
MODEL_ID = "thoughtcast/marketing-spiked-lassie-experiment"

# ─────────────────── Load model (fp16 on GPU) ───────────────────────────
# Both objects are created once at import time and shared by every call to
# the chat function defined later in this file.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    # Placement is delegated to the library.
    # NOTE(review): intended to land on the available GPU — confirm on the
    # target hardware, since "auto" does not by itself guarantee a GPU.
    device_map="auto",          # puts it on the available GPU
    # Weights are loaded as half precision.
    torch_dtype=torch.float16   # fp16 = small & fast enough for T4-small
)

# ─────────────────── Post-processing helper ─────────────────────────────
def _shorten(txt: str) -> str:
    """return only first sentence & dedupe immediate repeat words."""
    txt = txt.split("<assistant>\n")[-1].strip()

    # keep up to first sentence-ending punctuation
    m = re.search(r"[.!?]", txt)
    if m:
        txt = txt[: m.end()]

    words = txt.split()
    deduped = [words[0]] + [
        w for i, w in enumerate(words[1:]) if w.lower() != words[i].lower()
    ]
    return " ".join(deduped)

# ─────────────────── Chat function ──────────────────────────────────────
def _build_prompt(message, history):
    """Render earlier turns plus the new message in <user>/<assistant> form."""
    parts = []
    for turn in history:
        if isinstance(turn, dict):
            # Role/content dicts (Gradio "messages"-style history): one entry
            # per message. Roles other than user/assistant have no tag in this
            # prompt format, so they are left out.
            role = turn.get("role")
            if role in ("user", "assistant"):
                parts.append(f"<{role}>\n{turn.get('content', '')}\n")
        else:
            # Pair-style history: one (user, assistant) tuple per exchange.
            user_text, bot_text = turn
            parts.append(f"<user>\n{user_text}\n<assistant>\n{bot_text}\n")
    parts.append(f"<user>\n{message}\n<assistant>\n")
    return "".join(parts)


def chat(message, history=None):
    """Generate a short reply to ``message`` given the conversation so far.

    Args:
        message: The latest user message.
        history: Earlier turns, either as ``(user, assistant)`` pairs or as
            ``{"role": ..., "content": ...}`` dicts. ``None`` (the default)
            means there are no earlier turns.

    Returns:
        The model's reply after post-processing by ``_shorten``.
    """
    # ``None`` instead of a mutable ``[]`` default, so no list object is
    # shared between calls.
    prompt = _build_prompt(message, history or [])

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    output = model.generate(
        **inputs,
        max_new_tokens=80,
        do_sample=True,
        temperature=0.6,
        top_p=0.85,
        repetition_penalty=1.25,
        no_repeat_ngram_size=3,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.eos_token_id,
    )

    # The returned sequence starts with the prompt tokens; decode only what
    # was generated after them so the reply does not depend on the prompt
    # markers surviving the tokenizer round-trip.
    prompt_length = inputs["input_ids"].shape[1]
    decoded = tokenizer.decode(output[0][prompt_length:], skip_special_tokens=True)
    return _shorten(decoded)

# Wrap `chat` in Gradio's chat UI and start serving it.
gr.ChatInterface(
    fn=chat,
    title="Marketing Lassie 🐾 (Trained on Lassie's Website Marketing Information in FAQ / Conversational Form C=8, S_c=3, S=1)",
).launch()