# HaS-820m demo app (Hugging Face Spaces)
from transformers import AutoTokenizer, AutoModelForCausalLM
import gradio as gr

# Load the tokenizer and model once at import time so every request reuses them.
model_name = "SecurityXuanwuLab/HaS-820m"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Loaded on CPU by default, which avoids a hard dependency on `accelerate`.
model = AutoModelForCausalLM.from_pretrained(model_name)
# 定义基本功能函数 | |
def paraphrase(text):
    """Rewrite *text* into a more natural phrasing using the language model."""
    prompt = f"将以下句子润色为更自然的表达:{text}"
    inputs = tokenizer(prompt, return_tensors="pt")
    # do_sample=True makes `temperature` effective; without it generate()
    # runs greedy decoding and silently ignores the temperature setting.
    outputs = model.generate(**inputs, max_new_tokens=50, temperature=0.7, do_sample=True)
    # Slice off the prompt tokens instead of str.replace(): decoding does not
    # always round-trip to the exact prompt string, so replace() can fail.
    completion = outputs[0][inputs["input_ids"].shape[1]:]
    return tokenizer.decode(completion, skip_special_tokens=True).strip()
def summarize(text):
    """Condense *text* into a short summary using the language model."""
    prompt = f"将以下文本总结为简短的摘要:{text}"
    inputs = tokenizer(prompt, return_tensors="pt")
    # do_sample=True makes `temperature` effective; without it generate()
    # runs greedy decoding and silently ignores the temperature setting.
    outputs = model.generate(**inputs, max_new_tokens=100, temperature=0.7, do_sample=True)
    # Slice off the prompt tokens instead of str.replace(): decoding does not
    # always round-trip to the exact prompt string, so replace() can fail.
    completion = outputs[0][inputs["input_ids"].shape[1]:]
    return tokenizer.decode(completion, skip_special_tokens=True).strip()
def translate(text, target_lang):
    """Translate *text* from Chinese into *target_lang* using the language model."""
    prompt = f"将以下文本从中文翻译到{target_lang}:{text}"
    inputs = tokenizer(prompt, return_tensors="pt")
    # do_sample=True makes `temperature` effective; without it generate()
    # runs greedy decoding and silently ignores the temperature setting.
    outputs = model.generate(**inputs, max_new_tokens=100, temperature=0.7, do_sample=True)
    # Slice off the prompt tokens instead of str.replace(): decoding does not
    # always round-trip to the exact prompt string, so replace() can fail.
    completion = outputs[0][inputs["input_ids"].shape[1]:]
    return tokenizer.decode(completion, skip_special_tokens=True).strip()
def reading_comprehension(context, question):
    """Answer *question* from *context* using the language model."""
    prompt = f"根据以下上下文回答问题:\n上下文:{context}\n问题:{question}"
    inputs = tokenizer(prompt, return_tensors="pt")
    # do_sample=True makes `temperature` effective; without it generate()
    # runs greedy decoding and silently ignores the temperature setting.
    outputs = model.generate(**inputs, max_new_tokens=50, temperature=0.7, do_sample=True)
    # Slice off the prompt tokens instead of str.replace(): decoding does not
    # always round-trip to the exact prompt string, so replace() can fail.
    completion = outputs[0][inputs["input_ids"].shape[1]:]
    return tokenizer.decode(completion, skip_special_tokens=True).strip()
def classify(text):
    """Classify the sentiment of *text* (positive vs. negative) via a prompt."""
    prompt = f"判断以下文本的情感是积极还是消极:{text}。情感是"
    inputs = tokenizer(prompt, return_tensors="pt")
    # do_sample=True makes `temperature` effective; without it generate()
    # runs greedy decoding and silently ignores the temperature setting.
    outputs = model.generate(**inputs, max_new_tokens=10, temperature=0.7, do_sample=True)
    # Slice off the prompt tokens instead of str.replace(): decoding does not
    # always round-trip to the exact prompt string, so replace() can fail.
    completion = outputs[0][inputs["input_ids"].shape[1]:]
    return tokenizer.decode(completion, skip_special_tokens=True).strip()
# 定义 HaS 隐私保护功能 | |
def hide(text):
    """Anonymize *text*: ask the HaS model to paraphrase away private details.

    Returns only the newly generated continuation (prompt tokens sliced off).
    """
    hide_template = """<s>Paraphrase the text:%s\n\n"""
    input_text = hide_template % text
    inputs = tokenizer(input_text, return_tensors="pt")
    # max_new_tokens budgets generation itself; the original max_length=100
    # also counted the prompt tokens, starving long inputs of any budget.
    outputs = model.generate(**inputs, max_new_tokens=100)
    pred = outputs[0][len(inputs['input_ids'][0]):]
    return tokenizer.decode(pred, skip_special_tokens=True)
def seek(hide_input, hide_output, original_input):
    """Recover the real result: map *hide_output* back onto *original_input*.

    Uses a one-shot prompt (hide_input -> hide_output demonstration, then the
    original input) so the model converts the anonymized answer back.
    Returns only the newly generated continuation (prompt tokens sliced off).
    """
    seek_template = "Convert the text:\n%s\n\n%s\n\nConvert the text:\n%s\n\n"
    input_text = seek_template % (hide_input, hide_output, original_input)
    inputs = tokenizer(input_text, return_tensors="pt")
    # max_new_tokens budgets generation itself; the original max_length=512
    # also counted the prompt tokens, starving long inputs of any budget.
    outputs = model.generate(**inputs, max_new_tokens=512)
    pred = outputs[0][len(inputs['input_ids'][0]):]
    return tokenizer.decode(pred, skip_special_tokens=True)
# 处理任务 | |
def process_task(task, text, context=None, question=None, target_lang=None, hide_output=None):
    """Dispatch a UI task to the matching model helper.

    Returns a 3-tuple matching the Gradio output components:
    (main output text, Hide output text or None, note update or None).
    """
    if task == "润色":
        return paraphrase(text), None, None
    elif task == "摘要":
        return summarize(text), None, None
    elif task == "翻译":
        return translate(text, target_lang), None, None
    elif task == "阅读理解":
        return reading_comprehension(context, question), None, None
    elif task == "文本分类":
        return classify(text), None, None
    elif task == "隐私保护 (Hide)":
        hidden_text = hide(text)
        # Echo the anonymized text into the dedicated Hide output box too.
        return hidden_text, hidden_text, None
    elif task == "隐私保护 (Seek)":
        if not hide_output:
            return "请提供 Hide 输出后的翻译结果", None, None
        # text carries the hide input; context carries the original input.
        original_output = seek(text, hide_output, context)
        return original_output, None, None
    # Original code fell through and returned None here, which Gradio reports
    # as an error; return an explicit message instead.
    return "请选择一个任务", None, None
# Hide the default Gradio footer.
css = """
footer {
    visibility: hidden;
}
"""

# Gradio interface
with gr.Blocks(title="腾讯玄武产品体验 - HaS-820m", css=css) as demo:
    gr.Markdown("# 腾讯玄武产品体验 - HaS-820m")
    gr.Markdown("支持润色、摘要、翻译、阅读理解、文本分类及隐私保护(Hide 和 Seek)。")
    task = gr.Dropdown(
        choices=["润色", "摘要", "翻译", "阅读理解", "文本分类", "隐私保护 (Hide)", "隐私保护 (Seek)"],
        label="选择任务"
    )
    # Input components; the extra ones start hidden and are toggled per task.
    text_input = gr.Textbox(label="输入文本", lines=5)
    context_input = gr.Textbox(label="上下文(阅读理解或 Seek 的原始输入)", lines=5, visible=False)
    question_input = gr.Textbox(label="问题(阅读理解)", visible=False)
    lang_input = gr.Dropdown(choices=["英语", "法语", "西班牙语"], label="目标语言(翻译)", visible=False)
    hide_output_input = gr.Textbox(label="Hide 输出后的翻译结果(Seek)", lines=3, visible=False)
    # Output components
    output = gr.Textbox(label="输出结果", lines=5)
    hidden_text_output = gr.Textbox(label="Hide 输出(隐私保护)", lines=3, visible=False)
    note = gr.Markdown("**注意**:隐私保护 (Seek) 需要手动输入 Hide 输出后的翻译结果(当前无外部 API 支持)。")

    # Dynamically show the inputs each task needs.
    def update_inputs(task):
        """Return visibility/label updates for (text_input, context_input,
        question_input, lang_input, hide_output_input, hidden_text_output),
        in that order, for the selected *task*.
        """
        if task == "阅读理解":
            return (gr.update(visible=False), gr.update(visible=True, label="上下文"),
                    gr.update(visible=True), gr.update(visible=False), gr.update(visible=False),
                    gr.update(visible=False))
        elif task == "翻译":
            # Reset the label: it may still read "Hide 输出" after a Seek task.
            return (gr.update(visible=True, label="输入文本"), gr.update(visible=False),
                    gr.update(visible=False), gr.update(visible=True), gr.update(visible=False),
                    gr.update(visible=False))
        elif task == "隐私保护 (Hide)":
            return (gr.update(visible=True, label="输入文本"), gr.update(visible=False),
                    gr.update(visible=False), gr.update(visible=False), gr.update(visible=False),
                    gr.update(visible=True))
        elif task == "隐私保护 (Seek)":
            return (gr.update(visible=True, label="Hide 输出"), gr.update(visible=True, label="原始输入"),
                    gr.update(visible=False), gr.update(visible=False), gr.update(visible=True),
                    gr.update(visible=False))
        else:
            # Default (润色 / 摘要 / 文本分类): a single text box, label restored.
            return (gr.update(visible=True, label="输入文本"), gr.update(visible=False),
                    gr.update(visible=False), gr.update(visible=False), gr.update(visible=False),
                    gr.update(visible=False))

    task.change(
        fn=update_inputs,
        inputs=task,
        outputs=[text_input, context_input, question_input, lang_input, hide_output_input, hidden_text_output]
    )

    # Submit button
    submit_btn = gr.Button("提交")
    submit_btn.click(
        fn=process_task,
        inputs=[task, text_input, context_input, question_input, lang_input, hide_output_input],
        outputs=[output, hidden_text_output, note]
    )

demo.launch()