from gen import get_answer,get_state
import torch


def load_state(train_state_path, layer=32, n_embd=2560):
    train_state = torch.load(pth_file_path, map_location=torch.device('cpu'))
    state = [None] * (layer * 3)
    for i in range(layer):
        state[i*3+0]=torch.zeros(n_embd,).to(dtype=torch.bfloat16,device='cuda')
        state[i*3+1]=train_state[f'blocks.{i}.att.time_state'].to(dtype=torch.float,device='cuda')
        state[i*3+2]=torch.zeros(n_embd,).to(dtype=torch.bfloat16,device='cuda')
    return state

def get_instruction():
    """返回固定的指令内容"""
    return "根据input中的input和entity_types,帮助用户找到文本中每种entity_types的实体,标明实体类型并且简单描述。然后给找到实体之间的关系,并且描述这段关系以及对关系强度打分。 避免使用诸如\"其他\"或\"未知\"的通用实体类型。 非常重要的是:不要生成冗余或重叠的实体类型和关系。用JSON格式输出。"

def get_content(input_text):
    """输入内容文本,返回格式化的content部分"""
    return f"'{{'input': '{input_text}'}}"

def get_entity_types(entity_list):
    """
    输入实体类型列表,返回格式化的entity_types部分
    
    Args:
        entity_list: 可以是字符串列表 ['领域', '专家', '任务'] 
                    或者是字符串 '领域, 专家, 任务'
    """
    if isinstance(entity_list, str):
        # 如果是字符串,按逗号分割
        entity_list = [item.strip() for item in entity_list.split(',')]
    
    # 不带引号的格式(和原数据一致)
    entity_str = ', '.join(entity_list)
    return f"{{'entity_types': [{entity_str}]}}"

def generate_prompt(content, entity_types):
    """
    生成完整的prompt
    
    Args:
        content: 输入的文本内容
        entity_types: 实体类型列表或字符串
        
    Returns:
        完整的prompt字符串
    """
    instruction = get_instruction()
    content_part = get_content(content)
    entity_types_part = get_entity_types(entity_types)
    input_list_str = f'["content": {content_part}, "entity_types": {entity_types_part}]'
    # 按照指定格式拼接
    prompt = (
        f"{input_list_str}\n\n"
        f"User: Act as a specialized AI for Knowledge Graph construction. Your task is to extract entities and their relationships from the provided input, based on the given entity_types provided in above content.\nStructure your output as a single, valid JSON object with two top-level keys: entities and relationships.\nentities: A list of objects. Each object must have:\nentity: The exact name of the entity.\ndescription: A brief, context-based summary of the entity.\nrelationships: A list of objects. Each object must have:\nsource: The name of the source entity.\ntarget: The name of the target entity.\nrelationship: A concise description of their connection.\nCritical Rules:\nStrict Typing: Use only the provided entity types. Do not invent types or use generics like \"Other\".\nNo Redundancy: Do not create duplicate or reciprocal relationships (e.g., if A acquired B exists, do not add B was acquired by A).\nYour response must be only the JSON object.\n\n"
        f"Assistant:"
    )
    
    return prompt

  content1 = "根据我国的监狱法令,为了协助监狱囚犯改过自新和重新融入社会,监禁期至少四个星期的囚犯可在服刑至少14天后转入居家宵禁计划,在家服满剩余的刑期"
  entity_types1 = ["法律法规", "人物类别", "时间条件", "政策措施"]
  
  ctx = generate_prompt(content1, entity_types1)

  pth_file_path = "/home/rwkv/models/triplets1/rwkv-0.pth"


  tt_state = load_state(pth_file_path)
  
  print(ctx)
  res1 = get_answer(ctx,state=tt_state)
  print('train_state :',res1)
Downloads last month

-

Downloads are not tracked for this model. How to track
Inference Providers NEW
This model isn't deployed by any Inference Provider. 🙋 Ask for provider support