import json
import random

import numpy as np
import torch
from torch import nn
from torch.utils.tensorboard import SummaryWriter

from model.diffusion import ConditionedUnet
from tools import create_key


class Discriminator(nn.Module):
    def __init__(self, label_emb_dim):
        super(Discriminator, self).__init__()
        # Convolutional feature extractor for the 4-channel input maps
        self.conv_layers = nn.Sequential(
            nn.Conv2d(4, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.AdaptiveAvgPool2d(1),  # adaptive pooling to a fixed 512-dim vector
            nn.Flatten()
        )
        # Projection of the text/condition embedding
        self.text_embedding = nn.Sequential(
            nn.Linear(label_emb_dim, 512),
            nn.LeakyReLU(0.2, inplace=True)
        )
        # Final fully connected layer; the two 512s come from the feature maps and the text embedding
        self.fc = nn.Linear(512 + 512, 1)

    def forward(self, x, text_emb):
        x = self.conv_layers(x)
        text_emb = self.text_embedding(text_emb)
        combined = torch.cat((x, text_emb), dim=1)
        output = self.fc(combined)  # raw logit; meant to be paired with BCEWithLogitsLoss
        return output


def evaluate_GAN(device, generator, discriminator, iterator, encodes2embeddings_mapping):
    generator.to(device)
    discriminator.to(device)
    generator.eval()
    discriminator.eval()
    real_accs = []
    fake_accs = []
    with torch.no_grad():
        for i in range(100):
            data, attributes = next(iter(iterator))
            conditions = [encodes2embeddings_mapping[create_key(attribute)] for attribute in attributes]
            selected_conditions = [random.choice(conditions_of_one_sample)
                                   for conditions_of_one_sample in conditions]
            selected_conditions = torch.stack(selected_conditions).float().to(device)

            # Move data and labels to the device
            real_images = data.to(device)
            labels = selected_conditions

            # Generate noise and fake images (the generator is conditioned on the labels,
            # matching its usage in train_GAN)
            noise = torch.randn_like(real_images)
            fake_images = generator(noise, labels)

            # Evaluate the discriminator; it outputs raw logits, so threshold at 0
            real_preds = discriminator(real_images, labels).reshape(-1)
            fake_preds = discriminator(fake_images, labels).reshape(-1)

            real_acc = (real_preds > 0).float().mean().item()  # accuracy on real images
            fake_acc = (fake_preds < 0).float().mean().item()  # accuracy on generated images

            real_accs.append(real_acc)
            fake_accs.append(fake_acc)

    # Average accuracies over all sampled batches
    average_real_acc = sum(real_accs) / len(real_accs)
    average_fake_acc = sum(fake_accs) / len(fake_accs)
    return average_real_acc, average_fake_acc


def get_Generator(model_Config, load_pretrain=False, model_name=None, device="cpu"):
    generator = ConditionedUnet(**model_Config)
    print(f"Model initialized, size: {sum(p.numel() for p in generator.parameters() if p.requires_grad)}")
    generator.to(device)
    if load_pretrain:
        print(f"Loading weights from models/{model_name}_generator.pth")
        checkpoint = torch.load(f'models/{model_name}_generator.pth', map_location=device)
        generator.load_state_dict(checkpoint['model_state_dict'])
    generator.eval()
    return generator


def get_Discriminator(model_Config, load_pretrain=False, model_name=None, device="cpu"):
    discriminator = Discriminator(**model_Config)
    print(f"Model initialized, size: {sum(p.numel() for p in discriminator.parameters() if p.requires_grad)}")
    discriminator.to(device)
    if load_pretrain:
        print(f"Loading weights from models/{model_name}_discriminator.pth")
        checkpoint = torch.load(f'models/{model_name}_discriminator.pth', map_location=device)
        discriminator.load_state_dict(checkpoint['model_state_dict'])
    discriminator.eval()
    return discriminator
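
# A minimal smoke-test sketch for the Discriminator above; it is not called by
# the training code. The 4x64x64 input size and label_emb_dim=128 are
# illustrative assumptions, not the project's real configuration (any spatial
# size surviving the four stride-2 convolutions works, thanks to the adaptive pooling).
def _sanity_check_discriminator(label_emb_dim=128, batch_size=8):
    """Run a random batch through a fresh Discriminator and check the output shape."""
    disc = Discriminator(label_emb_dim)
    x = torch.randn(batch_size, 4, 64, 64)        # dummy 4-channel feature maps
    emb = torch.randn(batch_size, label_emb_dim)  # dummy condition embeddings
    logits = disc(x, emb)                         # raw logits, shape (batch_size, 1)
    assert logits.shape == (batch_size, 1)
    return logits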
def train_GAN(device, init_model_name, unetConfig, BATCH_SIZE, lr_G, lr_D, max_iter, iterator,
              load_pretrain, encodes2embeddings_mapping, save_steps, unconditional_condition,
              uncondition_rate, save_model_name=None):
    if save_model_name is None:
        save_model_name = init_model_name

    def save_model_hyperparameter(model_name, unetConfig, BATCH_SIZE, model_size, current_iter, current_loss):
        # Copy the config so repeated calls do not keep mutating the shared dict
        model_hyperparameter = dict(unetConfig)
        model_hyperparameter["BATCH_SIZE"] = BATCH_SIZE
        model_hyperparameter["lr_G"] = lr_G
        model_hyperparameter["lr_D"] = lr_D
        model_hyperparameter["model_size"] = model_size
        model_hyperparameter["current_iter"] = current_iter
        model_hyperparameter["current_loss"] = current_loss
        with open(f"models/hyperparameters/{model_name}_GAN.json", "w") as json_file:
            json.dump(model_hyperparameter, json_file, ensure_ascii=False, indent=4)

    generator = ConditionedUnet(**unetConfig)
    discriminator = Discriminator(unetConfig["label_emb_dim"])
    generator_size = sum(p.numel() for p in generator.parameters() if p.requires_grad)
    discriminator_size = sum(p.numel() for p in discriminator.parameters() if p.requires_grad)
    print(f"Generator trainable parameters: {generator_size}, "
          f"discriminator trainable parameters: {discriminator_size}")
    generator.to(device)
    discriminator.to(device)
    optimizer_G = torch.optim.Adam(filter(lambda p: p.requires_grad, generator.parameters()),
                                   lr=lr_G, amsgrad=False)
    optimizer_D = torch.optim.Adam(filter(lambda p: p.requires_grad, discriminator.parameters()),
                                   lr=lr_D, amsgrad=False)

    if load_pretrain:
        print(f"Loading weights from models/{init_model_name}_generator.pth")
        checkpoint = torch.load(f'models/{init_model_name}_generator.pth')
        generator.load_state_dict(checkpoint['model_state_dict'])
        optimizer_G.load_state_dict(checkpoint['optimizer_state_dict'])
        print(f"Loading weights from models/{init_model_name}_discriminator.pth")
        checkpoint = torch.load(f'models/{init_model_name}_discriminator.pth')
        discriminator.load_state_dict(checkpoint['model_state_dict'])
        optimizer_D.load_state_dict(checkpoint['optimizer_state_dict'])
    else:
        print("Model initialized.")

    if max_iter == 0:
        print("Returning model directly.")
        return generator, discriminator, optimizer_G, optimizer_D

    train_loss_G, train_loss_D = [], []
    writer = SummaryWriter(f'runs/{save_model_name}_GAN')
    # average_real_acc, average_fake_acc = evaluate_GAN(device, generator, discriminator, iterator,
    #                                                   encodes2embeddings_mapping)
    # print(f"average_real_acc, average_fake_acc: {average_real_acc, average_fake_acc}")
    criterion = nn.BCEWithLogitsLoss()
    generator.train()

    for i in range(max_iter):
        # Re-creates the iterator each step; with a shuffling DataLoader this samples a fresh batch
        data, attributes = next(iter(iterator))
        conditions = [encodes2embeddings_mapping[create_key(attribute)] for attribute in attributes]
        unconditional_condition_copy = torch.tensor(unconditional_condition,
                                                    dtype=torch.float32).to(device).detach()
        # With probability uncondition_rate, replace a sample's condition with the unconditional
        # embedding; otherwise pick one of its candidate conditions at random
        selected_conditions = [unconditional_condition_copy if random.random() < uncondition_rate
                               else random.choice(conditions_of_one_sample)
                               for conditions_of_one_sample in conditions]
        batch_size = len(selected_conditions)
        selected_conditions = torch.stack(selected_conditions).float().to(device)

        # Move data and labels to the device
        real_images = data.to(device)
        labels = selected_conditions

        # Target labels for real and fake samples
        real_labels = torch.ones(batch_size, 1).to(device)
        fake_labels = torch.zeros(batch_size, 1).to(device)

        # ========== Train the discriminator ==========
        optimizer_D.zero_grad()

        # Discriminator loss on real images
        outputs_real = discriminator(real_images, labels)
        loss_D_real = criterion(outputs_real, real_labels)

        # Generate fake images
        noise = torch.randn_like(real_images)
        fake_images = generator(noise, labels)

        # Discriminator loss on fake images (detached so no gradients flow into the generator)
        outputs_fake = discriminator(fake_images.detach(), labels)
        loss_D_fake = criterion(outputs_fake, fake_labels)

        # Backpropagate and update
        loss_D = loss_D_real + loss_D_fake
        loss_D.backward()
        optimizer_D.step()

        # ========== Train the generator ==========
        optimizer_G.zero_grad()

        # Generator loss: fool the discriminator into predicting "real"
        outputs_fake = discriminator(fake_images, labels)
        loss_G = criterion(outputs_fake, real_labels)

        # Backpropagate and update
        loss_G.backward()
        optimizer_G.step()

        train_loss_G.append(loss_G.item())
        train_loss_D.append(loss_D.item())

        # Read the global step counter from the first parameter's Adam state
        state = optimizer_G.state_dict()['state']
        step = int(state[list(state.keys())[0]]['step'].item())

        if (i + 1) % 100 == 0:
            print('%d step' % step)
        if (i + 1) % save_steps == 0:
            current_loss_D = np.mean(train_loss_D[-save_steps:])
            current_loss_G = np.mean(train_loss_G[-save_steps:])
            print('current_loss_G: %.5f' % current_loss_G)
            print('current_loss_D: %.5f' % current_loss_D)
            writer.add_scalar("current_loss_G", current_loss_G, step)
            writer.add_scalar("current_loss_D", current_loss_D, step)
            torch.save({
                'model_state_dict': generator.state_dict(),
                'optimizer_state_dict': optimizer_G.state_dict(),
            }, f'models/{save_model_name}_generator.pth')
            save_model_hyperparameter(save_model_name, unetConfig, BATCH_SIZE,
                                      generator_size, step, current_loss_G)
            torch.save({
                'model_state_dict': discriminator.state_dict(),
                'optimizer_state_dict': optimizer_D.state_dict(),
            }, f'models/{save_model_name}_discriminator.pth')
            save_model_hyperparameter(save_model_name, unetConfig, BATCH_SIZE,
                                      discriminator_size, step, current_loss_D)
        if step % 10000 == 0:
            torch.save({
                'model_state_dict': generator.state_dict(),
                'optimizer_state_dict': optimizer_G.state_dict(),
            }, f'models/history/{save_model_name}_{step}_generator.pth')
            torch.save({
                'model_state_dict': discriminator.state_dict(),
                'optimizer_state_dict': optimizer_D.state_dict(),
            }, f'models/history/{save_model_name}_{step}_discriminator.pth')
    return generator, discriminator, optimizer_G, optimizer_D
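

# Example invocation sketch. The unetConfig below is hypothetical: the real
# keys depend on ConditionedUnet's signature, and only "label_emb_dim" is
# confirmed by its use in train_GAN. With max_iter=0 the call returns the
# freshly initialized models without touching the (here omitted) data iterator.
if __name__ == "__main__":
    unetConfig = {"label_emb_dim": 128}  # plus whatever ConditionedUnet actually requires
    generator, discriminator, optimizer_G, optimizer_D = train_GAN(
        device="cuda" if torch.cuda.is_available() else "cpu",
        init_model_name="demo",
        unetConfig=unetConfig,
        BATCH_SIZE=16,
        lr_G=1e-4,
        lr_D=1e-4,
        max_iter=0,                      # 0 returns initialized models directly
        iterator=None,                   # normally a DataLoader yielding (data, attributes)
        load_pretrain=False,
        encodes2embeddings_mapping={},   # create_key(attribute) -> list of condition embeddings
        save_steps=500,
        unconditional_condition=[0.0] * 128,
        uncondition_rate=0.1,
    )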