titicacine/chinese-text-rating-model

这是一个基于BERT的文本评分预测模型,可以根据输入的文本内容预测1-10分的评分。

项目结构

merged_models/
├── config.json           # BERT模型配置文件
├── merge_models.py       # 模型合并脚本
├── predict.py           # 预测功能实现
├── pytorch_model.bin    # 模型权重文件
├── tokenizer_config.json # 分词器配置
└── vocab.txt            # 词表文件

环境要求

  • Python 3.6+
  • PyTorch
  • Transformers
  • NumPy
  • scikit-learn

安装依赖

pip install torch transformers numpy scikit-learn

使用方法

  1. 首先确保您已经下载了完整的模型文件夹,包括所有必要的配置文件和模型权重。

  2. 在您的Python代码中导入必要的库并加载模型:

import torch
from transformers import BertTokenizer
from merge_models import SimpleRatingPredictor

# 设置设备
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# 加载模型和分词器
model_path = './merged_models'  # 指向模型文件夹的路径
tokenizer = BertTokenizer.from_pretrained(model_path)
model = SimpleRatingPredictor(model_path)

# 加载模型权重
state_dict = torch.load(f'{model_path}/pytorch_model.bin', map_location=device)
model.load_state_dict(state_dict)
model.to(device)
  1. 使用模型进行预测:
from predict import predict_rating

# 示例文本
text = "这是一个需要评分的文本"

# 进行预测
rating = predict_rating(text, model, tokenizer, device)
print(f"预测评分: {rating:.1f}")

评估指标

模型使用以下指标进行评估:

  • MSE (均方误差)
  • RMSE (均方根误差)
  • MAE (平均绝对误差)

批量预测

如果需要对多个文本进行批量预测,可以使用以下代码:

from predict import load_test_data, evaluate_model

# 加载测试数据
test_file = 'test.jsonl'
texts, true_points = load_test_data(test_file)

# 进行预测
predictions = []
for i, text in enumerate(texts, 1):
    rating = predict_rating(text, model, tokenizer, device)
    predictions.append(rating)
    if i % 10 == 0:
        print(f"已完成 {i}/{len(texts)} 条预测")

# 评估模型性能
metrics = evaluate_model(predictions, true_points)
print("\n模型评估结果:")
for metric_name, value in metrics.items():
    print(f"{metric_name}: {value:.4f}")

注意事项

  1. 预测结果范围为1-10分
  2. 输入文本会被自动截断到最大长度512个token
  3. 建议使用GPU进行预测以获得更好的性能

模型说明

该模型基于BERT-base-Chinese预训练模型,添加了一个回归头用于评分预测。模型结构包括:

  • BERT编码层
  • Dropout层(dropout_rate=0.3)
  • 线性层 (768->128)
  • LayerNorm层
  • GELU激活函数
  • 最终输出层 (128->1)
Downloads last month
14
Inference Providers NEW
This model isn't deployed by any Inference Provider. 🙋 Ask for provider support