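"""Gradio Space: chat UI that turns natural-language questions into SPARQL.

Each user message is first screened by is_valid(), and only if the model
judges it answerable is gen_sparql() called to produce the query; both
steps prompt the same LoRA-tuned Qwen3 checkpoint."""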
import os
import json

import spaces
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import pyparseit
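
# LoRA fine-tunes of Qwen3 for SPARQL generation; the 0.6B variant is
# active, with the larger 4B variant kept commented out as an alternative.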
#model_name = "daniel-dona/sparql-model-era-lora-128-qwen3-4b"
model_name = "daniel-dona/sparql-model-era-lora-128-qwen3-0.6b"
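
# Debug output; note that dumping os.environ will also log any secrets
# configured for the Space.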
print(os.environ)
print("Cuda?", torch.cuda.is_available())
prompt_valid = open("/home/user/app/templates/prompt_valid.txt").read()
prompt_sparql = open("/home/user/app/templates/prompt_sparql.txt").read()
system = open("/home/user/app/templates/system1.txt").read()
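
# spaces.GPU allocates a ZeroGPU device only for the duration of the call,
# which is presumably why the model is (re)loaded inside the handler rather
# than once at import time.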
@spaces.GPU
def generate(messages):
    # load the tokenizer and the model
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype="auto",
        device_map="auto"
    )

    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=False  # Switches between thinking and non-thinking modes. Default is True.
    )
    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

    # conduct text completion
    generated_ids = model.generate(
        **model_inputs,
        max_new_tokens=4096
    )
    output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist()
    content = tokenizer.decode(output_ids, skip_special_tokens=True).strip("\n")

    return content
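
# Validity gate: the model is prompted to reply with a fenced JSON block
# (e.g. {"valid": true}); pyparseit pulls the fenced blocks out of the
# markdown reply and the last one is parsed.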
def is_valid(message):
    valid = False
    messages = [{"role": "system", "content": system}]

    print("NLQ:", message)
    prompt = prompt_valid.replace("%nlq", message)
    print("Prompt:", prompt)
    messages.append({"role": "user", "content": prompt})

    generation = generate(messages)
    print("Generated:", generation)

    blocks = pyparseit.parse_markdown_string(generation)
    if len(blocks) >= 1:
        try:
            valid = json.loads(blocks[-1].content)["valid"]  # Last block
        except Exception as e:
            print(e)

    return valid
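
# Query generation follows the same pattern: the last fenced block of the
# reply is taken as the query and re-wrapped as a ```sparql block for
# rendering in the chat window.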
def gen_sparql(message):
    sparql = "```sparql\n[code]```"  # Fallback shown if no block is parsed
    messages = [{"role": "system", "content": system}]

    print("NLQ:", message)
    prompt = prompt_sparql.replace("%nlq", message)
    print("Prompt:", prompt)
    messages.append({"role": "user", "content": prompt})

    generation = generate(messages)
    print("Generated:", generation)

    blocks = pyparseit.parse_markdown_string(generation)
    if len(blocks) >= 1:
        try:
            sparql = f"```sparql\n{blocks[-1].content}\n```"  # Last block
        except Exception as e:
            print(e)

    return sparql
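
# Chat handler: generation is gated behind the validity check so off-topic
# requests get a refusal instead of an arbitrary query.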
def respond(
    message,
    history: list[dict],
):
    if is_valid(message):
        return gen_sparql(message)
    else:
        return "Unable to generate SPARQL from your request; try rephrasing it."
demo = gr.ChatInterface(
    respond,
    type="messages",
    title="SPARQL generator"
)
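
# queue() enables Gradio's request queue so concurrent requests are
# queued rather than handled in parallel by the GPU-backed worker.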
if __name__ == "__main__":
    demo.queue().launch()