Introduction

This model generates titles for technology articles.

All intermediate checkpoints from step 100 to step 2200 are released for reference.
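Each checkpoint is stored in a directory named after its training step. The sketch below lists those directories and fetches a single one. It assumes the repository keeps them as `checkpoint-<step>` subfolders, which is inferred from `checkpoint-2000` in the usage example and from `--save_steps 100` in the training command. `released_checkpoints` and `download_checkpoint` are hypothetical helper names.

```python
import os

# Repository ID taken from this model card.
REPO_ID = "GanymedeNil/gemma-2b-technology-news-title-generation-lora"


def released_checkpoints(first=100, last=2200, every=100):
    """Directory names for every saved step, assuming a 'checkpoint-<step>' layout."""
    return [f"checkpoint-{step}" for step in range(first, last + 1, every)]


def download_checkpoint(step, repo_id=REPO_ID):
    """Hypothetical helper: download one checkpoint folder and return its local path."""
    # Imported lazily so released_checkpoints() works without huggingface_hub installed.
    from huggingface_hub import snapshot_download

    name = f"checkpoint-{step}"
    # allow_patterns restricts the download to files inside that one folder.
    local_root = snapshot_download(repo_id=repo_id, allow_patterns=[f"{name}/*"])
    return os.path.join(local_root, name)
```

`released_checkpoints()` yields 22 names, from `checkpoint-100` through `checkpoint-2200`. The path returned by `download_checkpoint(2000)` can be used as `peft_model_id` in the usage example.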

Usage

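from transformers import AutoModelForCausalLM, AutoTokenizer

# Local path to one of the released LoRA checkpoints (e.g. checkpoint-2000).
# Loading an adapter folder directly with AutoModelForCausalLM requires the `peft` package.
peft_model_id = "checkpoint-2000"
model = AutoModelForCausalLM.from_pretrained(peft_model_id, device_map="cuda")

tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")

content = "..."  # replace with the article text

# Fixed prompt format: keep it unchanged and only fill in {content}.
input_text = """
Generate a title for the article:

{content}

---
Title:
""".format(content=content)
encoding = tokenizer(input_text, return_tensors="pt").to("cuda")

# max_length counts prompt + generated tokens (8192 matches the training cutoff_len).
outputs = model.generate(**encoding, max_length=8192, temperature=0.2, do_sample=True)
# Keep only the newly generated tokens, i.e. drop the prompt.
generated_ids = outputs[:, encoding.input_ids.shape[1]:]
generated_texts = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
print(generated_texts[0])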
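Because the prompt format is fixed, you may prefer to wrap it in a function rather than edit the string by hand. The sketch below does that. `build_prompt` and `extract_title` are hypothetical helper names. Treating the first non-empty line of the sampled output as the title is an assumption about how to trim the output, not something this card specifies.

```python
# Same fixed prompt as in the usage example; only {content} is filled in.
PROMPT_TEMPLATE = """
Generate a title for the article:

{content}

---
Title:
"""


def build_prompt(content: str) -> str:
    """Insert the article text into the fixed prompt (braces inside content are safe)."""
    return PROMPT_TEMPLATE.format(content=content)


def extract_title(generated: str) -> str:
    """Assumed post-processing: return the first non-empty line of the generated text."""
    for line in generated.splitlines():
        line = line.strip()
        if line:
            return line
    return ""
```

With these helpers, `input_text = build_prompt(article)` replaces the inline template, and `extract_title(generated_texts[0])` trims the decoded output.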

Training data

Articles from linux-cn: https://huggingface.co/datasets/linux-cn/archive

Fine-tuning

The model was fine-tuned with LLaMA-Factory on a single A100 (80 GB) using the following command:

CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
    --stage sft \
    --do_train True \
    --model_name_or_path google/gemma-2b \
    --finetuning_type lora \
    --template default \
    --dataset title \
    --use_unsloth \
    --cutoff_len 8192 \
    --learning_rate 5e-05 \
    --num_train_epochs 10.0 \
    --max_samples 10000 \
    --per_device_train_batch_size 4 \
    --per_device_eval_batch_size 4 \
    --gradient_accumulation_steps 4 \
    --lr_scheduler_type cosine \
    --max_grad_norm 1.0 \
    --logging_steps 10 \
    --save_steps 100 \
    --eval_steps 100 \
    --evaluation_strategy steps \
    --warmup_steps 0 \
    --output_dir saves/Gemma-2B/lora/train_2024-03-01-04-36-32 \
    --bf16 True \
    --lora_rank 8 \
    --lora_dropout 0.1 \
    --lora_target q_proj,v_proj \
    --val_size 0.1 \
    --load_best_model_at_end True \
    --plot_loss True \
    --report_to "tensorboard"
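Several numbers follow directly from this command. The sketch below works them out:

- the effective batch size (`per_device_train_batch_size` × `gradient_accumulation_steps` × one GPU);
- roughly how many training examples the last released checkpoint has seen;
- the number of trainable LoRA parameters for rank 8 on `q_proj` and `v_proj`;
- the cosine learning-rate curve with `warmup_steps 0`.

The Gemma-2B dimensions (hidden size 2048, 18 layers, 8 query heads, 1 key/value head, head dim 256) are assumed from the published `google/gemma-2b` config. The schedule's total step count is left as a parameter because the card does not state the dataset size. The cosine formula is meant to mirror the shape of the standard Hugging Face cosine schedule, and it is not taken from LLaMA-Factory's code.

```python
import math

# Values copied from the LLaMA-Factory command above.
PER_DEVICE_BATCH = 4   # --per_device_train_batch_size
GRAD_ACCUM = 4         # --gradient_accumulation_steps
NUM_GPUS = 1           # CUDA_VISIBLE_DEVICES=0
LEARNING_RATE = 5e-5   # --learning_rate
LORA_RANK = 8          # --lora_rank

# Assumed google/gemma-2b architecture (from its published config).
HIDDEN = 2048
NUM_LAYERS = 18
NUM_HEADS = 8
NUM_KV_HEADS = 1
HEAD_DIM = 256


def effective_batch_size():
    """Examples consumed per optimizer step."""
    return PER_DEVICE_BATCH * GRAD_ACCUM * NUM_GPUS


def examples_seen(step):
    """Approximate training examples seen after `step` optimizer steps."""
    return step * effective_batch_size()


def lora_params(in_features, out_features, rank=LORA_RANK):
    """LoRA adds A (rank x in) and B (out x rank); no bias, dropout adds no weights."""
    return rank * (in_features + out_features)


def trainable_lora_params():
    """LoRA parameters for --lora_target q_proj,v_proj across all layers."""
    q = lora_params(HIDDEN, NUM_HEADS * HEAD_DIM)     # q_proj: 2048 -> 2048
    v = lora_params(HIDDEN, NUM_KV_HEADS * HEAD_DIM)  # v_proj: 2048 -> 256
    return NUM_LAYERS * (q + v)


def cosine_lr(step, total_steps, base_lr=LEARNING_RATE):
    """Cosine decay from base_lr to 0; warmup_steps is 0, so decay starts at step 0."""
    progress = step / total_steps
    return base_lr * 0.5 * (1.0 + math.cos(math.pi * progress))
```

With these values, each optimizer step consumes 16 examples, so checkpoint-2200 corresponds to roughly 35,200 training examples. The adapter adds 921,600 trainable parameters, a small fraction of the roughly 2.5B parameters in the base model.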

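For a detailed walkthrough of the whole process, see the following article:

Hands-on Google Gemma 2B Fine-Tuning (IT Technology News Title Generation) (original title: Google Gemma 2B 微调实战(IT科技新闻标题生成))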

Dataset used to train GanymedeNil/gemma-2b-technology-news-title-generation-lora