# watermelon / train.py
# Uploaded by Xalphinions using huggingface_hub (commit 48e5328, verified).
import os
import time
import torch, torchaudio, torchvision
from torch.utils.data import Dataset, DataLoader
# from torch.utils.tensorboard import SummaryWriter
import numpy as np
# Report the versions of the core libraries in use.
for _lib_name, _lib in (
    ("PyTorch", torch),
    ("Torchaudio", torchaudio),
    ("Torchvision", torchvision),
):
    print(f"\033[92mINFO\033[0m: {_lib_name} version: {_lib.__version__}")
del _lib_name, _lib

# Device selection: prefer CUDA, then Apple MPS, and fall back to the CPU.
if torch.cuda.is_available():
    _backend = "cuda"
elif torch.backends.mps.is_available():
    _backend = "mps"
else:
    _backend = "cpu"
device = torch.device(_backend)
del _backend
print(f"\033[92mINFO\033[0m: Using device: {device}")

# Hyperparameters.
batch_size = 1  # samples per optimisation step
epochs = 20  # full passes over the training split

# Directory that receives the per-epoch checkpoints (created if missing).
os.makedirs("./models/", exist_ok=True)
class PreprocessedDataset(Dataset):
    """Dataset over pre-processed samples stored as individual ``.pt`` files.

    Every entry directly inside ``data_dir`` whose name ends in ``.pt`` is
    treated as one sample holding a ``(mfcc, image, label)`` tuple written
    with ``torch.save``.
    """

    def __init__(self, data_dir):
        """Index the ``.pt`` files found in *data_dir*.

        Args:
            data_dir: Directory containing the sample files. It is listed
                non-recursively.
        """
        self.data_dir = data_dir
        # os.listdir order is filesystem-dependent. Sorting makes sample
        # indices (and any seeded random_split over them) reproducible.
        self.samples = sorted(
            os.path.join(data_dir, f) for f in os.listdir(data_dir) if f.endswith(".pt")
        )

    def __len__(self):
        """Return the number of indexed sample files."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load sample *idx* and return ``(mfcc, image, label)``.

        ``mfcc`` and ``image`` are cast to float32. ``label`` is returned
        exactly as stored.
        """
        sample_path = self.samples[idx]
        # Always load onto the CPU. DataLoader collation expects CPU tensors,
        # and samples saved from a GPU session would otherwise fail to load on
        # a CPU-only machine. Batches are moved to the device by the caller.
        mfcc, image, label = torch.load(sample_path, map_location="cpu")
        return mfcc.float(), image.float(), label
class WatermelonModel(torch.nn.Module):
    """Two-branch regressor that fuses audio (MFCC) and image features.

    An LSTM encodes the MFCC sequence and a ResNet-50 encodes the image. Each
    branch is projected to 128 features, the two embeddings are concatenated,
    and a small MLP produces one scalar per sample.
    """

    def __init__(self, pretrained=True, mfcc_input_size=376):
        """Build the model.

        Args:
            pretrained: Initialise the ResNet-50 backbone with the ImageNet
                weights that torchvision's legacy ``pretrained=True`` loaded
                (``IMAGENET1K_V1``). Pass ``False`` to skip the download, e.g.
                when a checkpoint is loaded right afterwards.
            mfcc_input_size: Size of the feature vector at each MFCC time step
                (the last dimension of ``mfcc``). The default of 376 is the
                value the script was written for.
        """
        super().__init__()
        # LSTM for audio features.
        self.lstm = torch.nn.LSTM(
            input_size=mfcc_input_size, hidden_size=64, num_layers=2, batch_first=True
        )
        # Project the 64-dim LSTM output to 128 dims for merging.
        self.lstm_fc = torch.nn.Linear(64, 128)

        # ResNet-50 for image features.
        weights_enum = getattr(torchvision.models, "ResNet50_Weights", None)
        if weights_enum is not None:
            # torchvision >= 0.13 deprecates `pretrained=` in favour of
            # `weights=`. IMAGENET1K_V1 is exactly what `pretrained=True`
            # loaded (DEFAULT would silently switch to the V2 weights).
            weights = weights_enum.IMAGENET1K_V1 if pretrained else None
            self.resnet = torchvision.models.resnet50(weights=weights)
        else:
            # Older torchvision without the multi-weight API.
            self.resnet = torchvision.models.resnet50(pretrained=pretrained)
        # Replace the classifier head with a 128-dim projection for merging.
        self.resnet.fc = torch.nn.Linear(self.resnet.fc.in_features, 128)

        # Fully connected head on the concatenated 128 + 128 features.
        self.fc1 = torch.nn.Linear(256, 64)
        self.fc2 = torch.nn.Linear(64, 1)
        self.relu = torch.nn.ReLU()

    def forward(self, mfcc, image):
        """Predict one scalar per sample.

        Args:
            mfcc: Float tensor of shape ``(batch, seq_len, mfcc_input_size)``
                (the LSTM is ``batch_first``).
            image: Image batch for ResNet-50, presumably ``(batch, 3, H, W)``.
                TODO: confirm against the preprocessing step.

        Returns:
            Tensor of shape ``(batch, 1)``.
        """
        # LSTM branch: keep only the output at the final time step.
        lstm_output, _ = self.lstm(mfcc)
        lstm_output = self.lstm_fc(lstm_output[:, -1, :])
        # ResNet branch. Its replaced fc head already yields 128 dims.
        resnet_output = self.resnet(image)
        # Concatenate both embeddings and regress.
        merged = torch.cat((lstm_output, resnet_output), dim=1)
        hidden = self.relu(self.fc1(merged))
        return self.fc2(hidden)
def evaluate_model(model, test_loader, criterion, device=None):
    """Evaluate *model* on every batch of *test_loader* without gradients.

    Args:
        model: Module called as ``model(mfcc, image)``. Its output is assumed
            to be ``(batch, 1)`` and is compared against labels reshaped to
            ``(batch, 1)`` float32.
        test_loader: Iterable of ``(mfcc, image, label)`` batches, e.g. a
            ``DataLoader``.
        criterion: Loss callable ``criterion(output, label)``. Its value is
            assumed to be a per-batch mean (as with ``torch.nn.MSELoss()``'s
            default reduction). It is re-weighted by batch size so the result
            is a true per-sample average even when the last batch is short.
        device: Device to move each batch to. Defaults to the device of the
            model's first parameter (CPU if the model has no parameters).

    Returns:
        ``(avg_loss, avg_mae, predictions, labels)``. The averages are over
        samples (``nan`` if the loader yields no samples). ``predictions`` and
        ``labels`` are flat NumPy arrays in loader order.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()
    total_loss = 0.0
    abs_error_sum = 0.0
    n_seen = 0
    prediction_chunks = []
    label_chunks = []
    # The first few (prediction, label) pairs, printed below as a sanity check.
    debug_samples = []
    max_debug_samples = 5

    with torch.no_grad():
        for mfcc, image, label in test_loader:
            mfcc, image, label = mfcc.to(device), image.to(device), label.to(device)
            output = model(mfcc, image)
            label = label.reshape(-1, 1).float()
            batch_n = label.shape[0]

            # Flatten instead of calling `.item()` so batches larger than one
            # also contribute per-sample debug pairs.
            remaining = max_debug_samples - len(debug_samples)
            if remaining > 0:
                debug_samples.extend(
                    zip(
                        output.reshape(-1)[:remaining].tolist(),
                        label.reshape(-1)[:remaining].tolist(),
                    )
                )

            # Sample-weighted accumulation of the loss and absolute error.
            total_loss += criterion(output, label).item() * batch_n
            abs_error_sum += torch.abs(output - label).sum().item()
            n_seen += batch_n

            # Keep predictions and labels for additional analysis.
            prediction_chunks.append(output.detach().cpu().numpy().reshape(-1))
            label_chunks.append(label.cpu().numpy().reshape(-1))

    if n_seen:
        avg_loss = total_loss / n_seen
        avg_mae = abs_error_sum / n_seen
    else:
        # An empty loader has no meaningful average (avoids ZeroDivisionError).
        avg_loss = avg_mae = float("nan")

    if prediction_chunks:
        all_predictions = np.concatenate(prediction_chunks)
        all_labels = np.concatenate(label_chunks)
    else:
        all_predictions = np.array([], dtype=np.float32)
        all_labels = np.array([], dtype=np.float32)

    print("\nDEBUG SAMPLES (Prediction, Label):")
    for i, (pred, label) in enumerate(debug_samples):
        print(f"Sample {i+1}: Prediction = {pred:.4f}, Label = {label:.4f}, Difference = {abs(pred-label):.4f}")
    return avg_loss, avg_mae, all_predictions, all_labels
class _NullWriter:
    """No-op stand-in for ``SummaryWriter`` when TensorBoard is unavailable."""

    def add_scalar(self, *args, **kwargs):
        pass

    def add_histogram(self, *args, **kwargs):
        pass

    def close(self):
        pass


def _make_summary_writer(log_dir):
    """Return a TensorBoard ``SummaryWriter`` for *log_dir*, or a ``_NullWriter``.

    Imported lazily so the script still runs, without logging, when
    TensorBoard is not installed.
    """
    try:
        from torch.utils.tensorboard import SummaryWriter
    except ImportError as err:
        print(f"\033[93mWARN\033[0m: TensorBoard unavailable ({err}); metrics will not be logged")
        return _NullWriter()
    return SummaryWriter(log_dir)


def _train_one_epoch(model, loader, criterion, optimizer, writer, global_step):
    """Run one optimisation pass of *model* over *loader*.

    Returns ``(mean_batch_loss, global_step)``. The mean is over the batches
    that completed (``nan`` if none did). An exception aborts the rest of the
    epoch and is reported rather than raised, keeping the script's best-effort
    behaviour.
    """
    model.train()
    running_loss = 0.0
    n_batches = 0
    try:
        for mfcc, image, label in loader:
            mfcc, image, label = mfcc.to(device), image.to(device), label.to(device)
            optimizer.zero_grad()
            output = model(mfcc, image)
            loss = criterion(output, label.view(-1, 1).float())
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            n_batches += 1
            writer.add_scalar("Training Loss", loss.item(), global_step)
            global_step += 1
    except Exception as e:
        print(f"\033[91mERR!\033[0m: {e}")
    mean_loss = running_loss / n_batches if n_batches else float("nan")
    return mean_loss, global_step


def _validate(model, loader, criterion):
    """Return the mean per-batch loss of *model* on *loader*.

    Returns ``inf`` when validation raises or *loader* is empty, rather than a
    misleading 0.0, so such an epoch can never be picked as the best
    checkpoint.
    """
    model.eval()
    val_loss = 0.0
    n_batches = 0
    with torch.no_grad():
        try:
            for mfcc, image, label in loader:
                mfcc, image, label = mfcc.to(device), image.to(device), label.to(device)
                output = model(mfcc, image)
                # Cast labels to float exactly as the training step does.
                loss = criterion(output, label.view(-1, 1).float())
                val_loss += loss.item()
                n_batches += 1
        except Exception as e:
            print(f"\033[91mERR!\033[0m: {e}")
            return float("inf")
    return val_loss / n_batches if n_batches else float("inf")


def train_model():
    """Train ``WatermelonModel`` on ``./processed/`` and report test metrics.

    The dataset is split 70/20/rest into train/validation/test. The model is
    trained for the module-level ``epochs`` with Adam (lr=1e-3) and MSE loss,
    and a checkpoint is saved under ``models/`` every epoch. The checkpoint
    with the lowest validation loss is then reloaded and evaluated on the test
    split. Metrics go to stdout and, when TensorBoard is installed, to
    ``runs/``.

    Raises:
        ValueError: if ``./processed/`` holds fewer than 2 ``.pt`` samples,
            which would leave the training split empty.
    """
    # Dataset loading.
    data_dir = "./processed/"
    dataset = PreprocessedDataset(data_dir)
    n_samples = len(dataset)

    train_size = int(0.7 * n_samples)
    if train_size == 0:
        # An empty training split would make the shuffling DataLoader
        # (RandomSampler) fail with an unhelpful error.
        raise ValueError(f"Need at least 2 .pt samples in {data_dir}, found {n_samples}")
    val_size = int(0.2 * n_samples)
    # The remainder (always >= 1 sample) becomes the test split.
    test_size = n_samples - train_size - val_size

    # Check the label range on the first few samples.
    sample_labels = [dataset[i][2] for i in range(min(10, n_samples))]
    print("\nLABEL RANGE CHECK:")
    print(f"Sample labels: {sample_labels}")
    print(f"Min label: {min(sample_labels)}, Max label: {max(sample_labels)}")

    train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(
        dataset, [train_size, val_size, test_size]
    )
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    model = WatermelonModel().to(device)
    # Loss function and optimizer.
    criterion = torch.nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    # TensorBoard (or a no-op writer). Always closed, even if training raises.
    writer = _make_summary_writer("runs/")
    try:
        global_step = 0
        print(f"\033[92mINFO\033[0m: Training model for {epochs} epochs")
        print(f"\033[92mINFO\033[0m: Training samples: {len(train_dataset)}")
        print(f"\033[92mINFO\033[0m: Validation samples: {len(val_dataset)}")
        print(f"\033[92mINFO\033[0m: Test samples: {len(test_dataset)}")
        print(f"\033[92mINFO\033[0m: Batch size: {batch_size}")
        best_val_loss = float('inf')
        best_model_path = None

        # Training loop.
        for epoch in range(epochs):
            print(f"\033[92mINFO\033[0m: Training epoch ({epoch+1}/{epochs})")
            train_loss, global_step = _train_one_epoch(
                model, train_loader, criterion, optimizer, writer, global_step
            )
            # Validation phase.
            avg_val_loss = _validate(model, val_loader, criterion)
            # Log the validation loss.
            writer.add_scalar("Validation Loss", avg_val_loss, epoch)
            print(
                f"Epoch [{epoch+1}/{epochs}], Training Loss: {train_loss:.4f}, "
                f"Validation Loss: {avg_val_loss:.4f}"
            )
            # Save a model checkpoint for this epoch.
            timestamp = time.strftime("%Y%m%d-%H%M%S")
            model_path = f"models/model_{epoch+1}_{timestamp}.pt"
            torch.save(model.state_dict(), model_path)
            # Track the best model by validation loss (inf/nan never qualify).
            if avg_val_loss < best_val_loss:
                best_val_loss = avg_val_loss
                best_model_path = model_path
                print(f"\033[92mINFO\033[0m: New best model saved with validation loss: {best_val_loss:.4f}")
            print(
                f"\033[92mINFO\033[0m: Model checkpoint epoch [{epoch+1}/{epochs}] saved: {model_path}"
            )
        print(f"\033[92mINFO\033[0m: Training complete")

        # Load the best model for testing, if validation selected one.
        if best_model_path is None:
            print("\033[93mWARN\033[0m: No checkpoint was selected by validation loss; testing the final weights")
        else:
            print(f"\033[92mINFO\033[0m: Loading best model from {best_model_path} for testing")
            model.load_state_dict(torch.load(best_model_path, map_location=device))

        # Evaluate on the test set.
        test_loss, test_mae, predictions, labels = evaluate_model(model, test_loader, criterion)
        abs_errors = np.abs(predictions - labels)
        max_error = np.max(abs_errors)
        min_error = np.min(abs_errors)
        print("\n" + "="*50)
        print("TEST RESULTS:")
        print(f"Test Loss (MSE): {test_loss:.4f}")
        print(f"Mean Absolute Error: {test_mae:.4f}")
        print(f"Maximum Absolute Error: {max_error:.4f}")
        print(f"Minimum Absolute Error: {min_error:.4f}")
        # Add the test results to TensorBoard.
        writer.add_scalar("Test/MSE", test_loss, 0)
        writer.add_scalar("Test/MAE", test_mae, 0)
        writer.add_scalar("Test/Max_Error", max_error, 0)
        writer.add_scalar("Test/Min_Error", min_error, 0)
        # Histogram of absolute errors.
        writer.add_histogram("Test/Absolute_Errors", abs_errors, 0)
        print("="*50)
    finally:
        writer.close()


if __name__ == "__main__":
    train_model()