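"""Conditional GAN built around a ConditionedUnet generator and a text-conditioned
discriminator: model construction, checkpoint loading, evaluation, and training."""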
import json
import random

import numpy as np
import torch
from torch import nn
from torch.utils.tensorboard import SummaryWriter

from model.diffusion import ConditionedUnet
from tools import create_key
class Discriminator(nn.Module):
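    """Conditional GAN discriminator.

    Scores a 4-channel input together with a text-condition embedding: convolutional
    features (512-d after adaptive pooling) are concatenated with a projected text
    embedding (512-d) and mapped to a single raw logit.
    """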
    def __init__(self, label_emb_dim):
        super().__init__()
        # Convolutional layers over the feature map
        self.conv_layers = nn.Sequential(
            nn.Conv2d(4, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.AdaptiveAvgPool2d(1),  # adaptive pooling: any input resolution maps to 512 features
            nn.Flatten()
        )
        # Text-embedding projection
        self.text_embedding = nn.Sequential(
            nn.Linear(label_emb_dim, 512),
            nn.LeakyReLU(0.2, inplace=True)
        )
        # Final fully connected layer; the two 512s come from the feature map and the text embedding
        self.fc = nn.Linear(512 + 512, 1)

    def forward(self, x, text_emb):
        x = self.conv_layers(x)
        text_emb = self.text_embedding(text_emb)
        combined = torch.cat((x, text_emb), dim=1)
        output = self.fc(combined)
        return output
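# Shape sanity check (a minimal sketch; the 4-channel 64x64 input and label_emb_dim=256
# are illustrative assumptions, not values fixed by this module):
#   d = Discriminator(label_emb_dim=256)
#   logits = d(torch.randn(8, 4, 64, 64), torch.randn(8, 256))  # -> (8, 1) raw logits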
def evaluate_GAN(device, generator, discriminator, iterator, encodes2embeddings_mapping):
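    """Estimate the discriminator's accuracy over 100 batches of real and generated samples.

    Returns a tuple (average_real_acc, average_fake_acc).
    """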
    generator.to(device)
    discriminator.to(device)
    generator.eval()
    discriminator.eval()
    real_accs = []
    fake_accs = []
    with torch.no_grad():
        for i in range(100):
            data, attributes = next(iter(iterator))
            conditions = [encodes2embeddings_mapping[create_key(attribute)] for attribute in attributes]
            selected_conditions = [random.choice(conditions_of_one_sample) for conditions_of_one_sample in conditions]
            selected_conditions = torch.stack(selected_conditions).float().to(device)
            # Move data and labels to the device
            real_images = data.to(device)
            labels = selected_conditions
            # Generate noise and fake images
            noise = torch.randn_like(real_images)
            fake_images = generator(noise, labels)
            # Evaluate the discriminator: it outputs raw logits (it is trained with
            # BCEWithLogitsLoss), so apply sigmoid before thresholding at 0.5
            real_preds = torch.sigmoid(discriminator(real_images, labels).reshape(-1))
            fake_preds = torch.sigmoid(discriminator(fake_images, labels).reshape(-1))
            real_acc = (real_preds > 0.5).float().mean().item()  # accuracy on real images
            fake_acc = (fake_preds < 0.5).float().mean().item()  # accuracy on generated images
            real_accs.append(real_acc)
            fake_accs.append(fake_acc)
    # Average accuracy over all evaluated batches
    average_real_acc = sum(real_accs) / len(real_accs)
    average_fake_acc = sum(fake_accs) / len(fake_accs)
    return average_real_acc, average_fake_acc
def get_Generator(model_Config, load_pretrain=False, model_name=None, device="cpu"):
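    """Build a ConditionedUnet generator, optionally restoring weights from
    models/{model_name}_generator.pth, and return it in eval mode."""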
    generator = ConditionedUnet(**model_Config)
    print(f"Model initialized, trainable parameters: {sum(p.numel() for p in generator.parameters() if p.requires_grad)}")
    generator.to(device)
    if load_pretrain:
        print(f"Loading weights from models/{model_name}_generator.pth")
        checkpoint = torch.load(f'models/{model_name}_generator.pth', map_location=device)
        generator.load_state_dict(checkpoint['model_state_dict'])
    generator.eval()
    return generator
def get_Discriminator(model_Config, load_pretrain=False, model_name=None, device="cpu"):
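    """Build a Discriminator, optionally restoring weights from
    models/{model_name}_discriminator.pth, and return it in eval mode."""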
    discriminator = Discriminator(**model_Config)
    print(f"Model initialized, trainable parameters: {sum(p.numel() for p in discriminator.parameters() if p.requires_grad)}")
    discriminator.to(device)
    if load_pretrain:
        print(f"Loading weights from models/{model_name}_discriminator.pth")
        checkpoint = torch.load(f'models/{model_name}_discriminator.pth', map_location=device)
        discriminator.load_state_dict(checkpoint['model_state_dict'])
    discriminator.eval()
    return discriminator
def train_GAN(device, init_model_name, unetConfig, BATCH_SIZE, lr_G, lr_D, max_iter, iterator, load_pretrain,
              encodes2embeddings_mapping, save_steps, unconditional_condition, uncondition_rate, save_model_name=None):
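    """Adversarially train the conditioned generator against the discriminator.

    Each iteration draws a batch from `iterator`, looks up text-condition embeddings via
    `encodes2embeddings_mapping`, and with probability `uncondition_rate` replaces a
    sample's condition with `unconditional_condition`. Checkpoints are written every
    `save_steps` iterations and archived every 10000 optimizer steps.
    """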
    if save_model_name is None:
        save_model_name = init_model_name
    def save_model_hyperparameter(model_name, unetConfig, BATCH_SIZE, model_size, current_iter, current_loss):
        # Copy the config so the caller's unetConfig is not polluted with training metadata
        model_hyperparameter = dict(unetConfig)
        model_hyperparameter["BATCH_SIZE"] = BATCH_SIZE
        model_hyperparameter["lr_G"] = lr_G
        model_hyperparameter["lr_D"] = lr_D
        model_hyperparameter["model_size"] = model_size
        model_hyperparameter["current_iter"] = current_iter
        model_hyperparameter["current_loss"] = current_loss
        with open(f"models/hyperparameters/{model_name}_GAN.json", "w") as json_file:
            json.dump(model_hyperparameter, json_file, ensure_ascii=False, indent=4)
    generator = ConditionedUnet(**unetConfig)
    discriminator = Discriminator(unetConfig["label_emb_dim"])
    generator_size = sum(p.numel() for p in generator.parameters() if p.requires_grad)
    discriminator_size = sum(p.numel() for p in discriminator.parameters() if p.requires_grad)
    print(f"Generator trainable parameters: {generator_size}, discriminator trainable parameters: {discriminator_size}")
    generator.to(device)
    discriminator.to(device)
    optimizer_G = torch.optim.Adam(filter(lambda p: p.requires_grad, generator.parameters()), lr=lr_G, amsgrad=False)
    optimizer_D = torch.optim.Adam(filter(lambda p: p.requires_grad, discriminator.parameters()), lr=lr_D, amsgrad=False)
    if load_pretrain:
        print(f"Loading weights from models/{init_model_name}_generator.pth")
        checkpoint = torch.load(f'models/{init_model_name}_generator.pth', map_location=device)
        generator.load_state_dict(checkpoint['model_state_dict'])
        optimizer_G.load_state_dict(checkpoint['optimizer_state_dict'])
        print(f"Loading weights from models/{init_model_name}_discriminator.pth")
        checkpoint = torch.load(f'models/{init_model_name}_discriminator.pth', map_location=device)
        discriminator.load_state_dict(checkpoint['model_state_dict'])
        optimizer_D.load_state_dict(checkpoint['optimizer_state_dict'])
    else:
        print("Model initialized.")
    if max_iter == 0:
        print("Returning model directly.")
        return generator, discriminator, optimizer_G, optimizer_D
    train_loss_G, train_loss_D = [], []
    writer = SummaryWriter(f'runs/{save_model_name}_GAN')
    # average_real_acc, average_fake_acc = evaluate_GAN(device, generator, discriminator, iterator, encodes2embeddings_mapping)
    # print(f"average_real_acc, average_fake_acc: {average_real_acc, average_fake_acc}")
    criterion = nn.BCEWithLogitsLoss()
    generator.train()
    discriminator.train()
    for i in range(max_iter):
        data, attributes = next(iter(iterator))
        conditions = [encodes2embeddings_mapping[create_key(attribute)] for attribute in attributes]
        unconditional_condition_copy = torch.tensor(unconditional_condition, dtype=torch.float32).to(device).detach()
        # With probability uncondition_rate, drop a sample's condition and use the unconditional embedding
        selected_conditions = [unconditional_condition_copy if random.random() < uncondition_rate else random.choice(
            conditions_of_one_sample) for conditions_of_one_sample in conditions]
        batch_size = len(selected_conditions)
        selected_conditions = torch.stack(selected_conditions).float().to(device)
        # Move data and labels to the device
        real_images = data.to(device)
        labels = selected_conditions
        # Targets for real and fake samples
        real_labels = torch.ones(batch_size, 1).to(device)
        fake_labels = torch.zeros(batch_size, 1).to(device)
        # ========== Train the discriminator ==========
        optimizer_D.zero_grad()
        # Discriminator loss on real images
        outputs_real = discriminator(real_images, labels)
        loss_D_real = criterion(outputs_real, real_labels)
        # Generate fake images
        noise = torch.randn_like(real_images)
        fake_images = generator(noise, labels)
        # Discriminator loss on fake images (detached so no generator gradients are computed here)
        outputs_fake = discriminator(fake_images.detach(), labels)
        loss_D_fake = criterion(outputs_fake, fake_labels)
        # Backpropagate and optimize
        loss_D = loss_D_real + loss_D_fake
        loss_D.backward()
        optimizer_D.step()
        # ========== Train the generator ==========
        optimizer_G.zero_grad()
        # Generator loss: fool the discriminator into predicting "real"
        outputs_fake = discriminator(fake_images, labels)
        loss_G = criterion(outputs_fake, real_labels)
        # Backpropagate and optimize
        loss_G.backward()
        optimizer_G.step()
        train_loss_G.append(loss_G.item())
        train_loss_D.append(loss_D.item())
        # Read the global step count from the first parameter's Adam state
        step = int(next(iter(optimizer_G.state_dict()['state'].values()))['step'])
        if (i + 1) % 100 == 0:
            print('%d step' % step)
        if (i + 1) % save_steps == 0:
            current_loss_D = np.mean(train_loss_D[-save_steps:])
            current_loss_G = np.mean(train_loss_G[-save_steps:])
            print('current_loss_G: %.5f' % current_loss_G)
            print('current_loss_D: %.5f' % current_loss_D)
            writer.add_scalar("current_loss_G", current_loss_G, step)
            writer.add_scalar("current_loss_D", current_loss_D, step)
            torch.save({
                'model_state_dict': generator.state_dict(),
                'optimizer_state_dict': optimizer_G.state_dict(),
            }, f'models/{save_model_name}_generator.pth')
            save_model_hyperparameter(save_model_name, unetConfig, BATCH_SIZE, generator_size, step, current_loss_G)
            torch.save({
                'model_state_dict': discriminator.state_dict(),
                'optimizer_state_dict': optimizer_D.state_dict(),
            }, f'models/{save_model_name}_discriminator.pth')
            # Note: this second call overwrites the same JSON file with the discriminator's size and loss
            save_model_hyperparameter(save_model_name, unetConfig, BATCH_SIZE, discriminator_size, step, current_loss_D)
        if step % 10000 == 0:
            # Archive a numbered snapshot of both models every 10000 optimizer steps
            torch.save({
                'model_state_dict': generator.state_dict(),
                'optimizer_state_dict': optimizer_G.state_dict(),
            }, f'models/history/{save_model_name}_{step}_generator.pth')
            torch.save({
                'model_state_dict': discriminator.state_dict(),
                'optimizer_state_dict': optimizer_D.state_dict(),
            }, f'models/history/{save_model_name}_{step}_discriminator.pth')
    return generator, discriminator, optimizer_G, optimizer_D
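# Minimal usage sketch (illustrative only: the config values, iterator, and embedding
# mapping below are hypothetical placeholders, not values defined by this module):
#   unetConfig = {"label_emb_dim": 256, ...}
#   generator, discriminator, opt_G, opt_D = train_GAN(
#       device="cuda", init_model_name="my_gan", unetConfig=unetConfig, BATCH_SIZE=16,
#       lr_G=1e-4, lr_D=4e-4, max_iter=100000, iterator=batch_iterator, load_pretrain=False,
#       encodes2embeddings_mapping=mapping, save_steps=500,
#       unconditional_condition=uncond_embedding, uncondition_rate=0.1)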