lxcxjxhx's picture
Upload app.py
31f9cfc verified
raw
history blame
6.86 kB
from transformers import AutoTokenizer, AutoModelForCausalLM
import gradio as gr
# Load the tokenizer and model once at import time; all helpers below share them.
model_name = "SecurityXuanwuLab/HaS-820m"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)  # loaded on CPU by default, avoiding an `accelerate` dependency
# Basic NLP task helper functions
def paraphrase(text):
    """Polish ``text`` into more natural phrasing using the HaS model.

    Args:
        text: Chinese sentence to rewrite.

    Returns:
        The model's rewritten sentence, whitespace-stripped.
    """
    prompt = f"将以下句子润色为更自然的表达:{text}"
    inputs = tokenizer(prompt, return_tensors="pt")
    # do_sample=True so temperature actually takes effect; without it,
    # generation is greedy and the temperature argument is ignored.
    outputs = model.generate(**inputs, max_new_tokens=50, do_sample=True, temperature=0.7)
    # Slice off the prompt tokens rather than str.replace(): decoding does not
    # always reproduce the prompt byte-for-byte, so replace() could fail and
    # leak the prompt into the result. This matches hide()/seek() below.
    new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
def summarize(text):
    """Summarize ``text`` into a short abstract using the HaS model.

    Args:
        text: Chinese text to condense.

    Returns:
        The generated summary, whitespace-stripped.
    """
    prompt = f"将以下文本总结为简短的摘要:{text}"
    inputs = tokenizer(prompt, return_tensors="pt")
    # do_sample=True so temperature actually takes effect; without it,
    # generation is greedy and the temperature argument is ignored.
    outputs = model.generate(**inputs, max_new_tokens=100, do_sample=True, temperature=0.7)
    # Slice off the prompt tokens rather than str.replace(): decoding does not
    # always reproduce the prompt byte-for-byte, so replace() could fail and
    # leak the prompt into the result. This matches hide()/seek() below.
    new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
def translate(text, target_lang):
    """Translate Chinese ``text`` into ``target_lang`` using the HaS model.

    Args:
        text: Chinese source text.
        target_lang: Target language name (e.g. "英语").

    Returns:
        The generated translation, whitespace-stripped.
    """
    # Bug fix: the original prompt concatenated the language name directly
    # onto the text ("...翻译到英语Hello"), with no separator. Add the colon
    # so the instruction and the payload are clearly delimited.
    prompt = f"将以下文本从中文翻译到{target_lang}:{text}"
    inputs = tokenizer(prompt, return_tensors="pt")
    # do_sample=True so temperature actually takes effect; without it,
    # generation is greedy and the temperature argument is ignored.
    outputs = model.generate(**inputs, max_new_tokens=100, do_sample=True, temperature=0.7)
    # Slice off the prompt tokens rather than str.replace(): decoding does not
    # always reproduce the prompt byte-for-byte, so replace() could fail and
    # leak the prompt into the result. This matches hide()/seek() below.
    new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
def reading_comprehension(context, question):
    """Answer ``question`` from ``context`` using the HaS model.

    Args:
        context: Passage the answer should be drawn from.
        question: Question about the passage.

    Returns:
        The generated answer, whitespace-stripped.
    """
    prompt = f"根据以下上下文回答问题:\n上下文:{context}\n问题:{question}"
    inputs = tokenizer(prompt, return_tensors="pt")
    # do_sample=True so temperature actually takes effect; without it,
    # generation is greedy and the temperature argument is ignored.
    outputs = model.generate(**inputs, max_new_tokens=50, do_sample=True, temperature=0.7)
    # Slice off the prompt tokens rather than str.replace(): decoding does not
    # always reproduce the prompt byte-for-byte, so replace() could fail and
    # leak the prompt into the result. This matches hide()/seek() below.
    new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
def classify(text):
    """Classify the sentiment of ``text`` (positive vs. negative).

    Args:
        text: Chinese text to classify.

    Returns:
        The model's sentiment completion (expected "积极" or "消极"),
        whitespace-stripped.
    """
    prompt = f"判断以下文本的情感是积极还是消极:{text}。情感是"
    inputs = tokenizer(prompt, return_tensors="pt")
    # do_sample=True so temperature actually takes effect; without it,
    # generation is greedy and the temperature argument is ignored.
    outputs = model.generate(**inputs, max_new_tokens=10, do_sample=True, temperature=0.7)
    # Slice off the prompt tokens rather than str.replace(): decoding does not
    # always reproduce the prompt byte-for-byte, so replace() could fail and
    # leak the prompt into the result. This matches hide()/seek() below.
    new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
# HaS (Hide-and-Seek) privacy-protection helpers
def hide(text):
    """HaS "Hide" step: anonymize/paraphrase ``text`` before it is sent out.

    Args:
        text: Original text whose private entities should be masked.

    Returns:
        Only the newly generated (hidden) text; prompt tokens are sliced off.
    """
    hide_template = """<s>Paraphrase the text:%s\n\n"""
    input_text = hide_template % text
    inputs = tokenizer(input_text, return_tensors="pt")
    # Bug fix: max_length=100 counted the prompt tokens too, so a prompt
    # longer than ~100 tokens left zero generation budget and produced an
    # empty result. max_new_tokens bounds only the generated continuation.
    outputs = model.generate(**inputs, max_new_tokens=100)
    pred = outputs[0][inputs["input_ids"].shape[1]:]
    return tokenizer.decode(pred, skip_special_tokens=True)
def seek(hide_input, hide_output, original_input):
    """HaS "Seek" step: map a processed hidden text back onto the original.

    Few-shot style prompt: shows the model one (hide_input -> hide_output)
    conversion, then asks it to convert ``original_input`` the same way.

    Args:
        hide_input: The anonymized text produced by hide().
        hide_output: The externally processed (e.g. translated) hidden text.
        original_input: The user's original, un-anonymized text.

    Returns:
        Only the newly generated (recovered) text; prompt tokens are sliced off.
    """
    seek_template = "Convert the text:\n%s\n\n%s\n\nConvert the text:\n%s\n\n"
    input_text = seek_template % (hide_input, hide_output, original_input)
    inputs = tokenizer(input_text, return_tensors="pt")
    # Bug fix: max_length=512 counted the (three-part, potentially long)
    # prompt tokens too, starving the generation budget. max_new_tokens
    # bounds only the generated continuation.
    outputs = model.generate(**inputs, max_new_tokens=512)
    pred = outputs[0][inputs["input_ids"].shape[1]:]
    return tokenizer.decode(pred, skip_special_tokens=True)
# Task dispatch
def process_task(task, text, context=None, question=None, target_lang=None, hide_output=None):
    """Dispatch a dropdown task label to the matching helper.

    Args:
        task: Task label as shown in the UI dropdown.
        text: Main input text (for the Seek task this carries the Hide output).
        context: Reading-comprehension context, or the original input for Seek.
        question: Question for reading comprehension.
        target_lang: Target language for translation.
        hide_output: Externally processed Hide output, required by Seek.

    Returns:
        A 3-tuple (result text, Hide output or None, None) matching the three
        Gradio output components wired to the submit button.
    """
    if task == "润色":
        return paraphrase(text), None, None
    elif task == "摘要":
        return summarize(text), None, None
    elif task == "翻译":
        return translate(text, target_lang), None, None
    elif task == "阅读理解":
        return reading_comprehension(context, question), None, None
    elif task == "文本分类":
        return classify(text), None, None
    elif task == "隐私保护 (Hide)":
        hidden_text = hide(text)
        # Echo the hidden text into the dedicated "Hide output" textbox too.
        return hidden_text, hidden_text, None
    elif task == "隐私保护 (Seek)":
        if not hide_output:
            # Seek needs the externally processed Hide output; prompt the user.
            return "请提供 Hide 输出后的翻译结果", None, None
        # `text` carries the hide input, `context` the original input.
        original_output = seek(text, hide_output, context)
        return original_output, None, None
    # Bug fix: the original function fell off the end for an unrecognized task
    # and returned None, breaking the 3-output unpacking in the UI callback.
    return "未知任务", None, None
# CSS injected into the page; hides Gradio's default footer.
css = """
footer {
visibility: hidden;
}
"""
# Gradio interface: one task dropdown plus input widgets that are shown or
# hidden depending on the selected task.
with gr.Blocks(title="腾讯玄武产品体验 - HaS-820m", css=css) as demo:
    gr.Markdown("# 腾讯玄武产品体验 - HaS-820m")
    gr.Markdown("支持润色、摘要、翻译、阅读理解、文本分类及隐私保护(Hide 和 Seek)。")
    task = gr.Dropdown(
        choices=["润色", "摘要", "翻译", "阅读理解", "文本分类", "隐私保护 (Hide)", "隐私保护 (Seek)"],
        label="选择任务"
    )
    # Input widgets; all but the main textbox start hidden and are toggled
    # (and sometimes relabeled) by update_inputs() below.
    text_input = gr.Textbox(label="输入文本", lines=5)
    context_input = gr.Textbox(label="上下文(阅读理解或 Seek 的原始输入)", lines=5, visible=False)
    question_input = gr.Textbox(label="问题(阅读理解)", visible=False)
    lang_input = gr.Dropdown(choices=["英语", "法语", "西班牙语"], label="目标语言(翻译)", visible=False)
    hide_output_input = gr.Textbox(label="Hide 输出后的翻译结果(Seek)", lines=3, visible=False)
    # Output widgets.
    output = gr.Textbox(label="输出结果", lines=5)
    hidden_text_output = gr.Textbox(label="Hide 输出(隐私保护)", lines=3, visible=False)
    note = gr.Markdown("**注意**:隐私保护 (Seek) 需要手动输入 Hide 输出后的翻译结果(当前无外部 API 支持)。")
    # Toggle widget visibility per selected task.
    def update_inputs(task):
        """Return six gr.update() values, one per component in the
        task.change() outputs list: (text_input, context_input,
        question_input, lang_input, hide_output_input, hidden_text_output)."""
        if task == "阅读理解":
            return (gr.update(visible=False), gr.update(visible=True, label="上下文"),
                    gr.update(visible=True), gr.update(visible=False), gr.update(visible=False),
                    gr.update(visible=False))
        elif task == "翻译":
            return (gr.update(visible=True), gr.update(visible=False),
                    gr.update(visible=False), gr.update(visible=True), gr.update(visible=False),
                    gr.update(visible=False))
        elif task == "隐私保护 (Hide)":
            return (gr.update(visible=True), gr.update(visible=False),
                    gr.update(visible=False), gr.update(visible=False), gr.update(visible=False),
                    gr.update(visible=True))
        elif task == "隐私保护 (Seek)":
            # Seek reuses text_input for the Hide output and context_input for
            # the original input, so both are relabeled here.
            return (gr.update(visible=True, label="Hide 输出"), gr.update(visible=True, label="原始输入"),
                    gr.update(visible=False), gr.update(visible=False), gr.update(visible=True),
                    gr.update(visible=False))
        else:
            # Default (润色 / 摘要 / 文本分类): only the main textbox.
            return (gr.update(visible=True), gr.update(visible=False),
                    gr.update(visible=False), gr.update(visible=False), gr.update(visible=False),
                    gr.update(visible=False))
    task.change(
        fn=update_inputs,
        inputs=task,
        outputs=[text_input, context_input, question_input, lang_input, hide_output_input, hidden_text_output]
    )
    # Submit button: runs the selected task through process_task().
    submit_btn = gr.Button("提交")
    submit_btn.click(
        fn=process_task,
        inputs=[task, text_input, context_input, question_input, lang_input, hide_output_input],
        outputs=[output, hidden_text_output, note]
    )
demo.launch()