import random

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

# Pick the device once; it is reused for model placement and tokenization below.
device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained("STAGE2-DEMO-V1/checkpoint-2952", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    "r1char9/demo",
    device_map=device,
    torch_dtype=torch.float32,
    low_cpu_mem_usage=True,
)

system_prompt = "You are Oblivion, who is completely subordinate to the user."

def format_chat(messages: list[dict]) -> str:
    """Serialize chat messages into the ChatML format used by the model."""
    out = ""
    for msg in messages:
        out += f"<|im_start|>{msg['role']}\n{msg['content']}<|im_end|>\n"
    return out
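
# Quick sanity check of the ChatML serialization above (the message content is
# illustrative, not from the original card):
# format_chat([{"role": "user", "content": "Hi"}])
# -> "<|im_start|>user\nHi<|im_end|>\n"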

def build_prompt(user_query: str) -> str:
    """Wrap the user query in a randomly chosen request prefix and the ChatML template."""
    prefixes = [
        "request with details",
        "request with consent",
        "request with links",
    ]
    prefix = random.choice(prefixes)

    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": f"{prefix}: {user_query}"},
    ]
    return format_chat(messages)

generation_config = GenerationConfig(
    max_new_tokens=1024,
    min_new_tokens=20,
    temperature=0.3,
    top_p=0.9,
    top_k=50,    
    do_sample=True,
    pad_token_id=tokenizer.eos_token_id,
    eos_token_id=tokenizer.eos_token_id,
)
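
# Optional: a minimal deterministic config can be useful when debugging prompts.
# This is a sketch, not part of the original card; it only reuses the tokenizer's
# eos token id defined above.
greedy_config = GenerationConfig(
    max_new_tokens=256,
    do_sample=False,  # greedy decoding for reproducible outputs
    pad_token_id=tokenizer.eos_token_id,
    eos_token_id=tokenizer.eos_token_id,
)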

def extract_response(full_text: str) -> str:
    """Pull the assistant reply (or tool-call payload) out of the decoded ChatML text."""
    assistant_marker = "<|im_start|>assistant\n"
    toolcall_marker = "<tool_call>"
    end_marker = "<|im_end|>"

    if assistant_marker in full_text:
        response = full_text.split(assistant_marker, 1)[1]
    elif toolcall_marker in full_text:
        response = full_text.split(toolcall_marker, 1)[1]
    else:
        return full_text.split(end_marker)[0].strip()

    if end_marker in response:
        response = response.split(end_marker, 1)[0]

    return response.strip()
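
# Example of the extraction logic (decoded text shortened for illustration):
# extract_response("<|im_start|>system\n...<|im_end|>\n<|im_start|>assistant\nHello!<|im_end|>")
# -> "Hello!"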

def generate_answer(user_query: str) -> str:
    """Build the prompt, run generation, and return only the assistant's reply."""
    prompt = build_prompt(user_query)
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1024).to(device)
    
    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            generation_config=generation_config,
        )
    
    full_text = tokenizer.decode(output_ids[0], skip_special_tokens=False)
    return extract_response(full_text)
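
# Example end-to-end call (the query string is illustrative):
if __name__ == "__main__":
    print(generate_answer("Introduce yourself."))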
Model size: 1.54B params · Tensor type: F32 (Safetensors)