from gen import get_answer,get_state
import torch
def load_state(train_state_path, layer=32, n_embd=2560):
train_state = torch.load(pth_file_path, map_location=torch.device('cpu'))
state = [None] * (layer * 3)
for i in range(layer):
state[i*3+0]=torch.zeros(n_embd,).to(dtype=torch.bfloat16,device='cuda')
state[i*3+1]=train_state[f'blocks.{i}.att.time_state'].to(dtype=torch.float,device='cuda')
state[i*3+2]=torch.zeros(n_embd,).to(dtype=torch.bfloat16,device='cuda')
return state
def get_instruction():
"""返回固定的指令内容"""
return "根据input中的input和entity_types,帮助用户找到文本中每种entity_types的实体,标明实体类型并且简单描述。然后给找到实体之间的关系,并且描述这段关系以及对关系强度打分。 避免使用诸如\"其他\"或\"未知\"的通用实体类型。 非常重要的是:不要生成冗余或重叠的实体类型和关系。用JSON格式输出。"
def get_content(input_text):
"""输入内容文本,返回格式化的content部分"""
return f"'{{'input': '{input_text}'}}"
def get_entity_types(entity_list):
"""
输入实体类型列表,返回格式化的entity_types部分
Args:
entity_list: 可以是字符串列表 ['领域', '专家', '任务']
或者是字符串 '领域, 专家, 任务'
"""
if isinstance(entity_list, str):
# 如果是字符串,按逗号分割
entity_list = [item.strip() for item in entity_list.split(',')]
# 不带引号的格式(和原数据一致)
entity_str = ', '.join(entity_list)
return f"{{'entity_types': [{entity_str}]}}"
def generate_prompt(content, entity_types):
"""
生成完整的prompt
Args:
content: 输入的文本内容
entity_types: 实体类型列表或字符串
Returns:
完整的prompt字符串
"""
instruction = get_instruction()
content_part = get_content(content)
entity_types_part = get_entity_types(entity_types)
input_list_str = f'["content": {content_part}, "entity_types": {entity_types_part}]'
# 按照指定格式拼接
prompt = (
f"{input_list_str}\n\n"
f"User: Act as a specialized AI for Knowledge Graph construction. Your task is to extract entities and their relationships from the provided input, based on the given entity_types provided in above content.\nStructure your output as a single, valid JSON object with two top-level keys: entities and relationships.\nentities: A list of objects. Each object must have:\nentity: The exact name of the entity.\ndescription: A brief, context-based summary of the entity.\nrelationships: A list of objects. Each object must have:\nsource: The name of the source entity.\ntarget: The name of the target entity.\nrelationship: A concise description of their connection.\nCritical Rules:\nStrict Typing: Use only the provided entity types. Do not invent types or use generics like \"Other\".\nNo Redundancy: Do not create duplicate or reciprocal relationships (e.g., if A acquired B exists, do not add B was acquired by A).\nYour response must be only the JSON object.\n\n"
f"Assistant:"
)
return prompt
content1 = "根据我国的监狱法令,为了协助监狱囚犯改过自新和重新融入社会,监禁期至少四个星期的囚犯可在服刑至少14天后转入居家宵禁计划,在家服满剩余的刑期"
entity_types1 = ["法律法规", "人物类别", "时间条件", "政策措施"]
ctx = generate_prompt(content1, entity_types1)
pth_file_path = "/home/rwkv/models/triplets1/rwkv-0.pth"
tt_state = load_state(pth_file_path)
print(ctx)
res1 = get_answer(ctx,state=tt_state)
print('train_state :',res1)
Inference Providers
NEW
This model isn't deployed by any Inference Provider.
🙋
Ask for provider support