# smol-vision / smolvlm.py
import torch
from peft import LoraConfig, prepare_model_for_kbit_training, get_peft_model
from transformers import AutoProcessor, BitsAndBytesConfig, Idefics3ForConditionalGeneration
from transformers import TrainingArguments, Trainer
from datasets import load_dataset
import os
from PIL import Image
from transformers.image_utils import load_image
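# Fine-tuning mode flags: USE_QLORA trains LoRA adapters on a 4-bit quantized base,
# USE_LORA trains DoRA adapters on the unquantized model, and with both False the
# full model is fine-tuned. SMOL switches between SmolVLM-Base and Idefics3-8B-Llama3.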
USE_LORA = False
USE_QLORA = True
SMOL = True
model_id = "HuggingFaceTB/SmolVLM-Base" if SMOL else "HuggingFaceM4/Idefics3-8B-Llama3"
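# The processor bundles the image processor and the tokenizer (including the chat template).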
processor = AutoProcessor.from_pretrained(model_id)
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "1, 4"
if USE_QLORA or USE_LORA:
    # LoRA adapters on the attention and MLP projection layers.
    lora_config = LoraConfig(
        r=8,
        lora_alpha=8,
        lora_dropout=0.1,
        target_modules=['down_proj', 'o_proj', 'k_proj', 'q_proj', 'gate_proj', 'up_proj', 'v_proj'],
        use_dora=False if USE_QLORA else True,  # DoRA only for plain LoRA
        init_lora_weights="gaussian"
    )
    lora_config.inference_mode = False
    if USE_QLORA:
        # 4-bit NF4 quantization with double quantization; compute in bfloat16.
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16
        )

    model = Idefics3ForConditionalGeneration.from_pretrained(
        model_id,
        quantization_config=bnb_config if USE_QLORA else None,
        _attn_implementation="flash_attention_2",
        device_map="auto"
    )
    model.add_adapter(lora_config)
    model.enable_adapters()
    model = prepare_model_for_kbit_training(model)
    model = get_peft_model(model, lora_config)
    print(model.get_nb_trainable_parameters())
else:
    # Full fine-tuning: load the model in bfloat16 on a single device.
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    model = Idefics3ForConditionalGeneration.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
        _attn_implementation="flash_attention_2",
    ).to(DEVICE)

    # To fine-tune only the LLM, freeze the vision encoder.
    for param in model.model.vision_model.parameters():
        param.requires_grad = False
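# Load a small VQAv2 subset; train_test_split keeps 20% of the validation split
# as the training set so the example stays quick.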
ds = load_dataset('merve/vqav2-small', trust_remote_code=True)
split_ds = ds["validation"].train_test_split(test_size=0.8)
train_ds = split_ds["train"]
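# Token id of the <image> placeholder, used below to mask image tokens out of the loss.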
image_token_id = processor.tokenizer.additional_special_tokens_ids[
    processor.tokenizer.additional_special_tokens.index("<image>")
]
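# Collator: render each (image, question, answer) triple through the chat template,
# batch-tokenize with padding, and build labels with pad and image tokens ignored (-100).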
def collate_fn(examples):
    texts = []
    images = []
    for example in examples:
        image = example["image"]
        if image.mode != 'RGB':
            image = image.convert('RGB')
        question = example["question"]
        answer = example["multiple_choice_answer"]
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "Answer briefly."},
                    {"type": "image"},
                    {"type": "text", "text": question}
                ]
            },
            {
                "role": "assistant",
                "content": [
                    {"type": "text", "text": answer}
                ]
            }
        ]
        text = processor.apply_chat_template(messages, add_generation_prompt=False)
        texts.append(text.strip())
        images.append([image])

    batch = processor(text=texts, images=images, return_tensors="pt", padding=True)
    # Labels are the input ids with padding and image placeholder tokens masked out.
    labels = batch["input_ids"].clone()
    labels[labels == processor.tokenizer.pad_token_id] = -100
    labels[labels == image_token_id] = -100
    batch["labels"] = labels

    return batch
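# Training setup: one epoch with gradient accumulation; checkpoints and logs are
# written to ./<model_name>-vqav2 and pushed to the Hub under the same name.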
model_name = model_id.split("/")[-1]
training_args = TrainingArguments(
    num_train_epochs=1,
    per_device_train_batch_size=8,
    gradient_accumulation_steps=4,
    warmup_steps=50,
    learning_rate=1e-4,
    weight_decay=0.01,
    logging_steps=25,
    save_strategy="steps",
    save_steps=250,
    save_total_limit=1,
    optim="paged_adamw_8bit",  # paged 8-bit AdamW pairs with QLoRA; use adamw_hf otherwise
    bf16=True,  # bfloat16 training, matching bnb_4bit_compute_dtype
    output_dir=f"./{model_name}-vqav2",
    hub_model_id=f"{model_name}-vqav2",
    report_to="tensorboard",
    remove_unused_columns=False,  # keep raw columns so collate_fn can read them
    gradient_checkpointing=True
)
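# Hand everything to the Trainer; this minimal script skips evaluation.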
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=collate_fn,
    train_dataset=train_ds,
)
trainer.train()
trainer.push_to_hub()
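# --- Optional sanity check (a minimal sketch, not part of the original script) ---
# After training, you can spot-check the adapter with a single generation. The image
# URL below is a placeholder; the prompt mirrors the chat format used in collate_fn.
# Uncomment to run:
#
# test_image = load_image("https://example.com/sample.jpg")  # placeholder image URL
# messages = [
#     {
#         "role": "user",
#         "content": [
#             {"type": "text", "text": "Answer briefly."},
#             {"type": "image"},
#             {"type": "text", "text": "What is shown in the image?"}
#         ]
#     }
# ]
# prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
# inputs = processor(text=prompt, images=[test_image], return_tensors="pt").to(model.device)
# generated_ids = model.generate(**inputs, max_new_tokens=64)
# print(processor.batch_decode(generated_ids, skip_special_tokens=True)[0])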