titicacine/chinese-text-rating-model
这是一个基于BERT的文本评分预测模型,可以根据输入的文本内容预测1-10分的评分。
项目结构
merged_models/
├── config.json # BERT模型配置文件
├── merge_models.py # 模型合并脚本
├── predict.py # 预测功能实现
├── pytorch_model.bin # 模型权重文件
├── tokenizer_config.json # 分词器配置
└── vocab.txt # 词表文件
环境要求
- Python 3.6+
- PyTorch
- Transformers
- NumPy
- scikit-learn
安装依赖
pip install torch transformers numpy scikit-learn
使用方法
首先确保您已经下载了完整的模型文件夹,包括所有必要的配置文件和模型权重。
在您的Python代码中导入必要的库并加载模型:
import torch
from transformers import BertTokenizer
from merge_models import SimpleRatingPredictor
# 设置设备
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 加载模型和分词器
model_path = './merged_models' # 指向模型文件夹的路径
tokenizer = BertTokenizer.from_pretrained(model_path)
model = SimpleRatingPredictor(model_path)
# 加载模型权重
state_dict = torch.load(f'{model_path}/pytorch_model.bin', map_location=device)
model.load_state_dict(state_dict)
model.to(device)
- 使用模型进行预测:
from predict import predict_rating
# 示例文本
text = "这是一个需要评分的文本"
# 进行预测
rating = predict_rating(text, model, tokenizer, device)
print(f"预测评分: {rating:.1f}")
评估指标
模型使用以下指标进行评估:
- MSE (均方误差)
- RMSE (均方根误差)
- MAE (平均绝对误差)
批量预测
如果需要对多个文本进行批量预测,可以使用以下代码:
from predict import load_test_data, evaluate_model
# 加载测试数据
test_file = 'test.jsonl'
texts, true_points = load_test_data(test_file)
# 进行预测
predictions = []
for i, text in enumerate(texts, 1):
rating = predict_rating(text, model, tokenizer, device)
predictions.append(rating)
if i % 10 == 0:
print(f"已完成 {i}/{len(texts)} 条预测")
# 评估模型性能
metrics = evaluate_model(predictions, true_points)
print("\n模型评估结果:")
for metric_name, value in metrics.items():
print(f"{metric_name}: {value:.4f}")
注意事项
- 预测结果范围为1-10分
- 输入文本会被自动截断到最大长度512个token
- 建议使用GPU进行预测以获得更好的性能
模型说明
该模型基于BERT-base-Chinese预训练模型,添加了一个回归头用于评分预测。模型结构包括:
- BERT编码层
- Dropout层(dropout_rate=0.3)
- 线性层 (768->128)
- LayerNorm层
- GELU激活函数
- 最终输出层 (128->1)
- Downloads last month
- 14
Inference Providers
NEW
This model isn't deployed by any Inference Provider.
🙋
Ask for provider support
Evaluation results
- MSEself-reported0.850
- RMSEself-reported0.920
- MAEself-reported0.760