Spaces:
Sleeping
Sleeping
import os | |
import time | |
import torch, torchaudio, torchvision | |
from torch.utils.data import Dataset, DataLoader | |
# from torch.utils.tensorboard import SummaryWriter | |
import numpy as np | |
# Report the versions of the core libraries (useful when reproducing results).
for _lib_label, _lib_module in (
    ("PyTorch", torch),
    ("Torchaudio", torchaudio),
    ("Torchvision", torchvision),
):
    print(f"\033[92mINFO\033[0m: {_lib_label} version: {_lib_module.__version__}")
del _lib_label, _lib_module
# Device selection: prefer CUDA, then Apple MPS, otherwise fall back to CPU.
if torch.cuda.is_available():
    _backend = "cuda"
elif torch.backends.mps.is_available():
    _backend = "mps"
else:
    _backend = "cpu"
device = torch.device(_backend)
del _backend
print(f"\033[92mINFO\033[0m: Using device: {device}")
# Hyperparameters
batch_size, epochs = 1, 20
# Checkpoint directory; created up front so the per-epoch torch.save calls can write into it.
os.makedirs("./models/", exist_ok=True)
class PreprocessedDataset(Dataset):
    """Dataset over preprocessed ``(mfcc, image, label)`` samples stored as ``.pt`` files.

    Every file directly inside ``data_dir`` whose name ends in ``.pt`` is one
    sample, loaded lazily with ``torch.load`` on access. The file list is
    sorted so that sample indices are stable across runs and filesystems
    (``os.listdir`` returns entries in arbitrary order).

    Args:
        data_dir: Directory containing the ``.pt`` sample files (not searched
            recursively).
    """

    def __init__(self, data_dir):
        self.data_dir = data_dir
        self.samples = sorted(
            os.path.join(data_dir, name)
            for name in os.listdir(data_dir)
            if name.endswith(".pt")
        )

    def __len__(self):
        """Return the number of ``.pt`` sample files found."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(mfcc, image, label)``.

        ``mfcc`` and ``image`` are cast to ``float32``; ``label`` is returned
        exactly as stored. Tensors are mapped to CPU so files written from a
        GPU process can be read on any machine; callers move batches to
        their target device themselves.
        """
        # NOTE: torch.load unpickles the file; only point this at trusted data.
        mfcc, image, label = torch.load(self.samples[idx], map_location="cpu")
        return mfcc.float(), image.float(), label
class WatermelonModel(torch.nn.Module):
    """Two-branch regressor fusing audio (MFCC) and image features into one scalar.

    * Audio branch: 2-layer batch-first LSTM (hidden size 64). The output at
      the last time step is projected to 128 dims.
    * Image branch: ResNet-50 initialised from ImageNet weights, with its
      classifier head replaced by a 128-dim projection.
    * Head: concat (256) -> Linear(64) -> ReLU -> Linear(1).

    Args:
        mfcc_features: Size of each time step fed to the LSTM (its
            ``input_size``). Defaults to 376, the original hard-coded value.
    """

    def __init__(self, mfcc_features=376):
        super().__init__()
        # LSTM for audio features
        self.lstm = torch.nn.LSTM(
            input_size=mfcc_features, hidden_size=64, num_layers=2, batch_first=True
        )
        # Convert LSTM output to 128-dim for merging
        self.lstm_fc = torch.nn.Linear(64, 128)
        # ResNet-50 for image features
        self.resnet = self._pretrained_resnet50()
        # Convert ResNet output to 128-dim for merging
        self.resnet.fc = torch.nn.Linear(self.resnet.fc.in_features, 128)
        # Fully connected layers for final prediction
        self.fc1 = torch.nn.Linear(256, 64)
        self.fc2 = torch.nn.Linear(64, 1)
        self.relu = torch.nn.ReLU()

    @staticmethod
    def _pretrained_resnet50():
        """Build ResNet-50 with the same ImageNet weights as ``pretrained=True``.

        torchvision >= 0.13 deprecates ``pretrained=`` in favour of
        ``weights=``. ``IMAGENET1K_V1`` is the weight set that
        ``pretrained=True`` maps to; ``DEFAULT`` would pick a different one.
        Older torchvision has no ``ResNet50_Weights``, so fall back there.
        """
        weights_enum = getattr(torchvision.models, "ResNet50_Weights", None)
        if weights_enum is None:
            return torchvision.models.resnet50(pretrained=True)
        return torchvision.models.resnet50(weights=weights_enum.IMAGENET1K_V1)

    def forward(self, mfcc, image):
        """Predict one scalar per sample.

        Args:
            mfcc: Batch-first sequence ``(batch, time, mfcc_features)``; the
                ``[:, -1, :]`` indexing below requires the batched 3-D form.
            image: Image batch for ResNet-50, presumably ``(batch, 3, H, W)``.
                TODO confirm against the preprocessing step.

        Returns:
            Tensor of shape ``(batch, 1)``.
        """
        # LSTM branch
        lstm_output, _ = self.lstm(mfcc)
        lstm_output = lstm_output[:, -1, :]  # Use the output of the last time step
        lstm_output = self.lstm_fc(lstm_output)
        # ResNet branch
        resnet_output = self.resnet(image)
        # Concatenate LSTM and ResNet outputs
        merged = torch.cat((lstm_output, resnet_output), dim=1)
        # Fully connected layers
        output = self.relu(self.fc1(merged))
        output = self.fc2(output)
        return output
def evaluate_model(model, test_loader, criterion, device=None):
    """Evaluate ``model`` on ``test_loader`` and return regression metrics.

    The first five ``(prediction, label)`` pairs are printed as a quick
    sanity check.

    Args:
        model: Module called as ``model(mfcc, image)``. Its output is compared
            against the label reshaped to ``(batch, 1)`` and cast to float.
        test_loader: Iterable yielding ``(mfcc, image, label)`` batches.
        criterion: Loss applied to ``(output, label)``. Assumed to return a
            per-batch mean (as ``torch.nn.MSELoss()`` does by default).
        device: Device to move batches to. Defaults to the device of the
            model's first parameter, or CPU if the model has no parameters.

    Returns:
        ``(avg_loss, avg_mae, predictions, labels)``, where:

        * The averages are taken per sample (not per batch).
        * ``predictions`` and ``labels`` are flat NumPy arrays in loader order.
        * Both averages are ``nan`` when the loader yields no samples.
    """
    if device is None:
        # Follow the model: batches must live where its parameters do.
        first_param = next(iter(model.parameters()), None)
        device = first_param.device if first_param is not None else torch.device("cpu")
    model.eval()
    loss_total = 0.0
    abs_err_total = 0.0
    n_samples = 0
    pred_chunks = []
    label_chunks = []
    debug_samples = []  # first few (prediction, label) pairs, for debugging
    with torch.no_grad():
        for mfcc, image, label in test_loader:
            mfcc, image, label = mfcc.to(device), image.to(device), label.to(device)
            output = model(mfcc, image)
            label = label.view(-1, 1).float()
            batch_n = label.shape[0]
            # The criterion and the MAE are batch means. Weight them by batch
            # size so the final averages stay per-sample when the last batch is short.
            loss_total += criterion(output, label).item() * batch_n
            abs_err_total += torch.abs(output - label).mean().item() * batch_n
            n_samples += batch_n
            batch_preds = output.cpu().numpy().reshape(-1)
            batch_labels = label.cpu().numpy().reshape(-1)
            pred_chunks.append(batch_preds)
            label_chunks.append(batch_labels)
            # Collect per element: .item() on a whole batch only works for batch size 1.
            for pred, target in zip(batch_preds, batch_labels):
                if len(debug_samples) >= 5:
                    break
                debug_samples.append((float(pred), float(target)))
    avg_loss = loss_total / n_samples if n_samples else float("nan")
    avg_mae = abs_err_total / n_samples if n_samples else float("nan")
    all_predictions = (
        np.concatenate(pred_chunks) if pred_chunks else np.array([], dtype=np.float32)
    )
    all_labels = (
        np.concatenate(label_chunks) if label_chunks else np.array([], dtype=np.float32)
    )
    # Print debug samples
    print("\nDEBUG SAMPLES (Prediction, Label):")
    for i, (pred, target) in enumerate(debug_samples):
        print(f"Sample {i+1}: Prediction = {pred:.4f}, Label = {target:.4f}, Difference = {abs(pred-target):.4f}")
    return avg_loss, avg_mae, all_predictions, all_labels
class _NullSummaryWriter:
    """No-op stand-in for the ``SummaryWriter`` methods ``train_model`` calls."""

    def add_scalar(self, *args, **kwargs):
        pass

    def add_histogram(self, *args, **kwargs):
        pass

    def close(self):
        pass


def _make_summary_writer(log_dir):
    """Return a TensorBoard ``SummaryWriter`` for ``log_dir``, or a no-op stub.

    ``SummaryWriter`` is not imported at module level, so it is imported here.
    If TensorBoard is not installed, training still runs and metrics are
    simply not logged.
    """
    try:
        from torch.utils.tensorboard import SummaryWriter
    except ImportError as err:
        print(f"\033[93mWARN\033[0m: TensorBoard unavailable ({err}); metrics will not be logged")
        return _NullSummaryWriter()
    return SummaryWriter(log_dir)


def _safe_mean(total, count):
    """Return ``total / count``, or ``nan`` when ``count`` is 0 (e.g. an empty split)."""
    return total / count if count else float("nan")


def _train_one_epoch(model, loader, criterion, optimizer, writer, global_step):
    """Run one training pass; return ``(summed batch loss, updated global_step)``.

    Errors are printed and end the epoch early instead of aborting training
    (best-effort, as in the original loop).
    """
    model.train()
    running_loss = 0.0
    try:
        for mfcc, image, label in loader:
            mfcc, image, label = mfcc.to(device), image.to(device), label.to(device)
            optimizer.zero_grad()
            output = model(mfcc, image)
            label = label.view(-1, 1).float()
            loss = criterion(output, label)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            writer.add_scalar("Training Loss", loss.item(), global_step)
            global_step += 1
    except Exception as e:
        print(f"\033[91mERR!\033[0m: {e}")
    return running_loss, global_step


def _validation_loss(model, loader, criterion):
    """Return the summed per-batch validation loss (errors are printed, not raised)."""
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        try:
            for mfcc, image, label in loader:
                mfcc, image, label = mfcc.to(device), image.to(device), label.to(device)
                output = model(mfcc, image)
                # Cast targets to float exactly like the training and test
                # paths, so the validation loss sees the same target dtype as
                # the loss being optimised.
                loss = criterion(output, label.view(-1, 1).float())
                val_loss += loss.item()
        except Exception as e:
            print(f"\033[91mERR!\033[0m: {e}")
    return val_loss


def _report_test_results(model, test_loader, criterion, writer):
    """Evaluate ``model`` on the test split, print a summary and log it to ``writer``."""
    test_loss, test_mae, predictions, labels = evaluate_model(model, test_loader, criterion)
    abs_errors = np.abs(predictions - labels)
    # np.max / np.min raise on an empty array, so guard the reductions.
    max_error = float(abs_errors.max()) if abs_errors.size else float("nan")
    min_error = float(abs_errors.min()) if abs_errors.size else float("nan")
    print("\n" + "="*50)
    print("TEST RESULTS:")
    print(f"Test Loss (MSE): {test_loss:.4f}")
    print(f"Mean Absolute Error: {test_mae:.4f}")
    print(f"Maximum Absolute Error: {max_error:.4f}")
    print(f"Minimum Absolute Error: {min_error:.4f}")
    # Add test results to TensorBoard
    writer.add_scalar("Test/MSE", test_loss, 0)
    writer.add_scalar("Test/MAE", test_mae, 0)
    writer.add_scalar("Test/Max_Error", max_error, 0)
    writer.add_scalar("Test/Min_Error", min_error, 0)
    # Histogram of absolute errors
    if abs_errors.size:
        writer.add_histogram("Test/Absolute_Errors", abs_errors, 0)
    print("="*50)


def train_model():
    """Train ``WatermelonModel`` on ``./processed/`` and evaluate the best checkpoint.

    Steps:
      1. Load every ``.pt`` sample in ``./processed/``. Print the labels of
         (up to) the first 10 samples as a range check.
      2. Randomly split the data 70% train / 20% validation / remainder test.
      3. Train for ``epochs`` epochs with Adam (lr=1e-3) and MSE loss. Save a
         checkpoint under ``models/`` after every epoch and remember the one
         with the lowest validation loss.
      4. Reload that checkpoint (or the last one, if none improved) and report
         test metrics.

    Scalars and an error histogram are logged to TensorBoard under ``runs/``
    when TensorBoard is installed.

    Raises:
        ValueError: if ``./processed/`` contains no ``.pt`` samples.
    """
    # Load the dataset
    data_dir = "./processed/"
    dataset = PreprocessedDataset(data_dir)
    n_samples = len(dataset)
    if n_samples == 0:
        raise ValueError(f"No .pt samples found in {data_dir}")
    # Check label range on the first few samples
    all_labels = [dataset[i][2] for i in range(min(10, n_samples))]
    print("\nLABEL RANGE CHECK:")
    print(f"Sample labels: {all_labels}")
    print(f"Min label: {min(all_labels)}, Max label: {max(all_labels)}")
    train_size = int(0.7 * n_samples)
    val_size = int(0.2 * n_samples)
    test_size = n_samples - train_size - val_size
    train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(
        dataset, [train_size, val_size, test_size]
    )
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    model = WatermelonModel().to(device)
    # Loss function and optimizer
    criterion = torch.nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    writer = _make_summary_writer("runs/")
    try:
        global_step = 0
        print(f"\033[92mINFO\033[0m: Training model for {epochs} epochs")
        print(f"\033[92mINFO\033[0m: Training samples: {len(train_dataset)}")
        print(f"\033[92mINFO\033[0m: Validation samples: {len(val_dataset)}")
        print(f"\033[92mINFO\033[0m: Test samples: {len(test_dataset)}")
        print(f"\033[92mINFO\033[0m: Batch size: {batch_size}")
        best_val_loss = float('inf')
        best_model_path = None
        last_model_path = None
        # Training loop
        for epoch in range(epochs):
            print(f"\033[92mINFO\033[0m: Training epoch ({epoch+1}/{epochs})")
            running_loss, global_step = _train_one_epoch(
                model, train_loader, criterion, optimizer, writer, global_step
            )
            # Validation phase
            avg_val_loss = _safe_mean(
                _validation_loss(model, val_loader, criterion), len(val_loader)
            )
            writer.add_scalar("Validation Loss", avg_val_loss, epoch)
            print(
                f"Epoch [{epoch+1}/{epochs}], Training Loss: {_safe_mean(running_loss, len(train_loader)):.4f}, "
                f"Validation Loss: {avg_val_loss:.4f}"
            )
            # Save a checkpoint for this epoch
            timestamp = time.strftime("%Y%m%d-%H%M%S")
            model_path = f"models/model_{epoch+1}_{timestamp}.pt"
            torch.save(model.state_dict(), model_path)
            last_model_path = model_path
            # Track the best model by validation loss
            if avg_val_loss < best_val_loss:
                best_val_loss = avg_val_loss
                best_model_path = model_path
                print(f"\033[92mINFO\033[0m: New best model saved with validation loss: {best_val_loss:.4f}")
            print(
                f"\033[92mINFO\033[0m: Model checkpoint epoch [{epoch+1}/{epochs}] saved: {model_path}"
            )
        print(f"\033[92mINFO\033[0m: Training complete")
        # No epoch beat +inf (e.g. nan losses from an empty validation split):
        # fall back to the most recent checkpoint instead of torch.load(None).
        if best_model_path is None:
            best_model_path = last_model_path
        if best_model_path is not None:
            print(f"\033[92mINFO\033[0m: Loading best model from {best_model_path} for testing")
            model.load_state_dict(torch.load(best_model_path, map_location=device))
        _report_test_results(model, test_loader, criterion, writer)
    finally:
        # Flush and close the writer even if training or evaluation fails.
        writer.close()
def _main():
    """Script entry point: run the full train / checkpoint / test pipeline."""
    train_model()


if __name__ == "__main__":
    _main()