资讯详情

深度学习系列33:有标签的CGAN:Pix2Pix/Pix2PixHD/cycleGAN

1. 从GAN到CGAN

GAN训练数据没有标签。如果我们想进行标签训练,我们需要使用它CGAN。 对于图像,我们不仅要让输出图像真实,还要让输出图像符合标签c。Discriminator输入被改为同时输入c和x,输出需要做两件事,一是判断x是否是真实图片,二是x和c是否匹配。 在以下两种情况下,虽然左侧输出图片清晰,但不符合c;右侧输出图片不真实。所以D的输出在两种情况下都是0。 在这里插入图片描述

让我们来看看简单的示例代码:

import torch import torch.nn as nn import torch.nn.functional as F import torch.optim as optim import numpy as np import matplotlib.pyplot as plt import torchvision from torchvision import transforms from torch.utils import data import os import glob from PIL import Image   # 独热编码 # 输入x代表默认torchvision返回的类比值,class_count类别值为10 def one_hot(x, class_count=10):     return torch.eye(class_count)[x, :]  # 切片选择,第一维选择x,第二维全要     transform =transforms.Compose([transforms.ToTensor(),                                transforms.Normalize(0.5, 0.5)])   dataset = torchvision.datasets.MNIST('data',                                      train=True,                                      transform=transform,                                      target_transform=one_hot,                                      download=False) dataloader = data.DataLoader(dataset, batch_size=64, shuffle=True)     # 定义生成器 class Generator(nn.Module):     def __init__(self):         super(Generator, self).__init__()         self.linear1 = nn.Linear(10, 128 * 7 * 7)         self.bn1 = nn.BatchNorm1d(128 * 7 * 7)         self.linear2 = nn.Linear(100, 128 * 7 * 7)         self.bn2 = nn.BatchNorm1d(128 * 7 * 7)         self.deconv1 = nn.ConvTranspose2d(256, 128,                                           kernel_size=(3, 3),                                           padding=1)         self.bn3 = nn.BatchNorm2d(128)         self.deconv2 = nn.ConvTranspose2d(128, 64,                                           kernel_size=(4, 4),                                           stride=2,                                           padding=1)         self.bn4 = nn.BatchNorm2d(64)         self.deconv3 = nn.ConvTranspose2d(64, 1,                                           kernel_size=(4, 4),                                           stride=2,                                           padding=1)       def forward(self, x1, x2):         x1 = F.relu(self.linear1(x1))         x1 = self.bn1(x1)         x1 = x1.view(-1, 128, 7, 7)         x2 = F.relu(self.linear2(x2))         x2 = self.bn2(x2)         x2 = x2.view(-1, 128, 7, 7)         x = torch.cat([x1, x2], axis=1)         x = F.relu(self.deconv1(x))         x = self.bn3(x)         x = F.relu(self.deconv2(x))         x = self.bn4(x)         x = torch.tanh(self.deconv3(x))         return x   # 定义判别器 # input:1,28,28的图片以及长度为10的condition class Discriminator(nn.Module):     def __init__(self):         super(Discriminator, self).__init__()         self.linear = nn.Linear(10, 1*28*28)         self.conv1 = nn.Conv2d(2, 64, kernel_size=3, stride=2)         self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=2)         self.bn = nn.BatchNorm2d(128)         self.fc = nn.Linear(128*6*6, 1) # 输出概率值       def forward(self, x1, x2):         x1 =F.leaky_relu(self.linear(x1))         x1 = x1.view(-1, 1, 28, 28)         x = torch.cat([x1, x2], axis=1)         x = F.dropout2d(F.leaky_relu(self.conv1(x)))         x = F.dropout2d(F.leaky_relu(self.conv2(x)))         x = self.bn(x)         x = x.view(-1, 128*6*6)         x = torch.sigmoid(self.fc(x))         return x   # 初始化模型 device = 'cuda' if torch.cuda.is_available() else 'cpu' gen = Generator().to(device) dis = Discriminator().to(device)   # 计算损失函数 loss_function = torch.nn.BCELoss()   # 定义优化器 d_optim = torch.optim.Adam(dis.parameters(), lr=1e-5) g_optim = torch.optim.Adam(gen.parameters(), lr=1e-4)     # 定义可视化函数 def generate_and_save_images(model, epoch, label_input, noise_input):     predictions = np.squeeze(model(label_input, noise_input).cpu().numpy())     fig = plt.figure(figsize=(4, 4))     for i in range(predictions.shape[0]):         plt.subplot(4, 4, i   1)         plt.imshow((predictions[i]   1) / 2, cmap='gray')         plt.axis("off")     plt.show() noise_seed = torch.randn(16, 100, device=device)   label_seed = torch.randint(0, 10, size=(16,)) label_seed_onehot = one_hot(label_seed).to(device) print(label_seed) # print(label_seed_onehot)   # 开始训练 D_loss = [] G_loss = [] # 训练循环 for epoch in range(150):     d_epoch_loss = 0     g_epoch_loss = 0     count = len(dataloader.dataset)     # 迭代所有数据集     for step, (img, label) in enumerate(dataloader):         img = img.to(device)         label = label.to(device)         size = mg.shape[0]
        random_noise = torch.randn(size, 100, device=device)
 
        d_optim.zero_grad()
 
        real_output = dis(label, img)
        d_real_loss = loss_function(real_output,
                                    torch.ones_like(real_output, device=device)
                                    )
        d_real_loss.backward() #求解梯度
 
        # 得到判别器在生成图像上的损失
        gen_img = gen(label,random_noise)
        fake_output = dis(label, gen_img.detach())  # 判别器输入生成的图片,f_o是对生成图片的预测结果
        d_fake_loss = loss_function(fake_output,
                                    torch.zeros_like(fake_output, device=device))
        d_fake_loss.backward()
 
        d_loss = d_real_loss + d_fake_loss
        d_optim.step()  # 优化
 
        # 得到生成器的损失
        g_optim.zero_grad()
        fake_output = dis(label, gen_img)
        g_loss = loss_function(fake_output,
                               torch.ones_like(fake_output, device=device))
        g_loss.backward()
        g_optim.step()
 
        with torch.no_grad():
            d_epoch_loss += d_loss.item()
            g_epoch_loss += g_loss.item()
    with torch.no_grad():
        d_epoch_loss /= count
        g_epoch_loss /= count
        D_loss.append(d_epoch_loss)
        G_loss.append(g_epoch_loss)
        if epoch % 10 == 0:
            print('Epoch:', epoch)
            generate_and_save_images(gen, epoch, label_seed_onehot, noise_seed)

2. Pix2pix:像素级别转换

这里是尝试地址:https://affinelayer.com/pixsrv/ 使用Pix2Pix神经网络模型实现论文中预定义的任务:黑白简笔画到彩图、平面房屋到立体房屋和航拍图到地图等功能: Pix2pixgan本质上是一个cgan,图片x作为此cGAN的条件, 需要输入到G和D中。 G的输入是x(x 是需要转换的图片),输出是生成的图片G(x)。 D则需要分辨出{x,G(x)}和{x, y}。 这里的生成器模型我们采用U-Net: 在pix2pix中,作者就是把L1 loss 和GAN loss相结合使用,因为作者认为L1 loss 可以恢复图像的低频部分,而GAN loss可以恢复图像的高频部分。判别器使用patchGAN。

我们看一些代码说明:

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.down1 = Downsample(3, 64)
        self.down2 = Downsample(64, 128)
        self.down3 = Downsample(128, 256)
        self.down4 = Downsample(256, 512)
        self.down5 = Downsample(512, 512)
        self.down6 = Downsample(512, 512)
 
        self.up1 = Upsample(512, 512)
        self.up2 = Upsample(1024, 512)
        self.up3 = Upsample(1024, 256)
        self.up4 = Upsample(512, 128)
        self.up5 = Upsample(256, 64)
 
        self.last = nn.ConvTranspose2d(128, 3,
                                       kernel_size=3,
                                       stride=2,
                                       padding=1,
                                       output_padding=1)
 
    def forward(self, x):
        x1 = self.down1(x, is_bn=False)  # torch.Size([8, 64, 128, 128])
        x2 = self.down2(x1)  # torch.Size([8, 128, 64, 64])
        x3 = self.down3(x2)  # torch.Size([8, 256, 32, 32])
        x4 = self.down4(x3)  # torch.Size([8, 512, 16, 16])
        x5 = self.down5(x4)  # torch.Size([8, 512, 8, 8])
        x6 = self.down6(x5)  # torch.Size([8, 512, 4, 4])
 
        x6 = self.up1(x6, is_drop=True)  # torch.Size([8, 512, 8, 8])
        x6 = torch.cat([x5, x6], dim=1)  # torch.Size([8, 1024, 8, 8])
 
        x6 = self.up2(x6, is_drop=True)  # torch.Size([8, 512, 16, 16])
        x6 = torch.cat([x4, x6], dim=1)  # torch.Size([8, 1024, 16, 16])
 
        x6 = self.up3(x6, is_drop=True)
        x6 = torch.cat([x3, x6], dim=1)
 
        x6 = self.up4(x6)
        x6 = torch.cat([x2, x6], dim=1)
 
        x6 = self.up5(x6)
        x6 = torch.cat([x1, x6], dim=1)
 
        x6 = torch.tanh(self.last(x6))
        return x6

# 判别器
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator,self).__init__()
        self.down1 = Downsample(6, 64)
        self.down2 = Downsample(64, 128)
        self.down3 = Downsample(128, 256)
        self.conv = nn.Conv2d(256, 512, 3, 1, 1)
        self.bn = nn.BatchNorm2d(512)
        self.last = nn.Conv2d(512, 1, 3, 1)
 
    def forward(self, anno, img):
        x = torch.cat([anno, img], dim=1)  # batch*6*H*W
        x = self.down1(x, is_bn=False)
        x = self.down2(x)
        x = F.dropout2d(self.down3(x))
        x = F.dropout2d(F.leaky_relu(self.conv(x)))
        x = F.dropout2d(self.bn(x))
        x = torch.sigmoid(self.last(x))
        return x

3. Pix2PixHD

在pix2pix的基础上,增加了一个“从糙到精生成器(coarse-to-fine generator)”、一个多尺度鉴别器架构和一个健壮的对抗学习目标函数。 1)生成器部分提高分辨率:将生成器U-net拆分成两个子网络G1和G2进行训练:前者输入和输出的分辨率保持一致(如 1024 x 512),后者输出尺寸(2048x1024)是输入尺寸(1024x512)的4倍(长宽各两倍)。如果想要得到更高分辨率的图像,只需要增加更多的局部增强网络即可(如 G={G1,G2,G3}) 2)判别器部分将深度改为宽度:使用三个相同结构的判别器,分别处理不同尺寸的输入。 3)损失函数更稳健:除了PatchGAN的损失,还加上了样本与GT使用判别器网络和VGG16网络提取特征后进行的Element-wise loss 4)输入加入高频特征向量,例如图像的边缘信息,与输入的语义标签连接到一起作为输入。 5)额外学习一个Feature encoder网络,可以将原图转化为features,用来控制图像的颜色、纹理信息。

4. CycleGAN:风格转换

pix2pixGAN有一个明显的缺点就是,在进行训练的时候必须提供成对的数据集。比如当我们想生成梵高风格的画时,梵高本人画的作品肯定是相对较少的,这个时候就可以考虑使用cycleGAN。cycleGAN适用于非配对的图像到图像转换: 其原理可以概括为将一类图片转成成另一类图片,比如,现有两个样本空间X、Y,我们希望把X空间中的样本转换成Y空间中的样本。这种转换只是风格上的转换,实际X Y 的内容是不一样的。实际的目标就是学习从X到Y的映射,假设该映射为F,它就对应着GAN中的生成器,F就可以将X中的图片A转换为Y中的图片F(A)。 为了实现这个过程,我们需要两个生成器 G_AB 和 G_BA: 首先是单向loss的组成: 判别 loss: 判别器 D_B 是用来判断输入的图片是否是真实的 B 图片,这个流程和GAN是一致的。 生成 loss:生成器用来重建图片 a,目的是希望生成的图片 G_BA(G_AB(a)) 和原图 a 尽可能的相似,那么可以很简单的采取 L1 loss 或者 L2 loss。除了GAN loss,还包含如下loss: ① cycle-loss:也就是循环一致损失。因为网络需要保证生成的图像必须保留有原 始图像的特性,所以如果我们使用生成器GA-B生成一张假图像,那么要能够使用另外一个生成器 GB-A来努力恢复成原始图像。此过程必须满足循环一致性 ② 等价loss:我们要求 G A B ( b ) = b G_{AB}(b)=b GAB​(b)=b,以及 G B A ( a ) = a G_{BA}(a)=a GBA​(a)=a。

下面来看下示例代码: 获取苹果橙子数据:

# 加载训练数据
apples_path = glob.glob('data/trainA/*.jpg')
oranges_path = glob.glob('data/trainB/*.jpg')
 
 
transform = transforms.Compose([transforms.ToTensor(),  # 0-1归一化
                                transforms.Normalize(0.5, 0.5),  # -1,1])
 
class AppleOrangeDataset(data.Dataset):
    def __init__(self, img_path):
        self.img_path = img_path
 
    def __getitem__(self, index):
        img_path = self.img_path[index]
        pil_img = Image.open(img_path)
        pil_img = transform(pil_img)
        return pil_img
    def __len__(self):
        return len(self.img_path)
 
apple_dataset = AppleOrangeDataset(apples_path)
oranges_dataset = AppleOrangeDataset(oranges_path)

基于Unet结构定义上 / 下采样模块,接着定义生成器:

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.down1 = Downsample(3, 64)
        self.down2 = Downsample(64, 128)
        self.down3 = Downsample(128, 256)
        self.down4 = Downsample(256, 512)
        self.down5 = Downsample(512, 512)
        self.down6 = Downsample(512, 512)
 
        self.up1 = Upsample(512, 512)
        self.up2 = Upsample(1024, 512)
        self.up3 = Upsample(1024, 256)
        self.up4 = Upsample(512, 128)
        self.up5 = Upsample(256, 64)
 
        self.last = nn.ConvTranspose2d(128, 3,
                                       kernel_size=3,
                                       stride=2,
                                       padding=1,
                                       output_padding=1)
 
    def forward(self, x):
        x1 = self.down1(x, is_bn=False)  # torch.Size([8, 64, 128, 128])
        x2 = self.down2(x1)  # torch.Size([8, 128, 64, 64])
        x3 = self.down3(x2)  # torch.Size([8, 256, 32, 32])
        x4 = self.down4(x3)  # torch.Size([8, 512, 16, 16])
        x5 = self.down5(x4)  # torch.Size([8, 512, 8, 8])
        x6 = self.down6(x5)  # torch.Size([8, 512, 4, 4])
 
        x6 = self.up1(x6, is_drop=True)  # torch.Size([8, 512, 8, 8])
        x6 = torch.cat([x5, x6], dim=1)  # torch.Size([8, 1024, 8, 8])
 
        x6 = self.up2(x6, is_drop=True)  # torch.Size([8, 512, 16, 16])
        x6 = torch.cat([x4, x6], dim=1)  # torch.Size([8, 1024, 16, 16])
 
        x6 = self.up3(x6, is_drop=True)
        x6 = torch.cat([x3, x6], dim=1)
 
        x6 = self.up4(x6)
        x6 = torch.cat([x2, x6], dim=1)
 
        x6 = self.up5(x6)
        x6 = torch.cat([x1, x6], dim=1)
 
        x6 = torch.tanh(self.last(x6))
        return x6

接下来是鉴别器:

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.down1 = Downsample(3, 64)             # 128
        self.down2 = Downsample(64, 128)           # 64
        self.last = nn.Conv2d(128, 1, 3)
 
    def forward(self, img):
        x = self.down1(img)
        x = self.down2(x)
        x = torch.sigmoid(self.last(x))
        return x

我们需要定义两个生成器和两个鉴别器:


gen_AB = Generator().to(device)
gen_BA = Generator().to(device)
dis_A = Discriminator().to(device)
dis_B = Discriminator().to(device)

# 同时对两个生成器进行优化
gen_optimizer = torch.optim.Adam(itertools.chain(gen_AB.parameters(), gen_BA.parameters()),
                                 lr=2e-4, betas=(0.5, 0.999))
dis_A_optimizer = torch.optim.Adam(dis_A.parameters(), lr=2e-4, betas=(0.5, 0.999))
dis_B_optimizer = torch.optim.Adam(dis_B.parameters(), lr=2e-4, betas=(0.5, 0.999))

训练过程如下:

D_loss = []  # 记录训练过程中判别器loss变化
G_loss = []  # 记录训练过程中生成器loss变化
 
# 开始训练
for epoch in range(50):
    D_epoch_loss = 0
    G_epoch_loss = 0
    for step, (real_A, real_B) in enumerate(zip(apples_dl, oranges_dl)):
        # GAN 训练
        gen_optimizer.zero_grad()
 
        # identity loss
        same_B = gen_AB(real_B)
        identity_B_loss = l1loss_fn(same_B, real_B)
        same_A = gen_BA(real_A)
        identity_A_loss = l1loss_fn(same_A, real_A)
 
        # GAN loss
        fake_B = gen_AB(real_A)
        D_pred_fake_B = dis_B(fake_B)
        gan_loss_AB = bceloss_fn(D_pred_fake_B,
                                 torch.ones_like(D_pred_fake_B, device=device))
 
        fake_A = gen_BA(real_B)
        D_pred_fake_A = dis_A(fake_A)
        gan_loss_BA = bceloss_fn(D_pred_fake_A,
                                 torch.ones_like(D_pred_fake_A, device=device))
 
        # cycle consistanse loss
        recovered_A = gen_BA(fake_B)
        cycle_loss_ABA = l1loss_fn(recovered_A, real_A)
 
        recovered_B = gen_AB(fake_A)
        cycle_loss_BAB = l1loss_fn(recovered_B, real_B)
 
        # total_loss
        g_loss = (identity_B_loss + identity_A_loss + gan_loss_AB + gan_loss_BA
                  + cycle_loss_ABA + cycle_loss_BAB)
 
        g_loss.backward()
        gen_optimizer.step()
 
        # dis_A 训练
        dis_A_optimizer.zero_grad()
        dis_A_real_output = dis_A(real_A)  # 判别器输入真实图片
        dis_A_real_loss = bceloss_fn(dis_A_real_output,
                                     torch.ones_like(dis_A_real_output, device=device))
 
        dis_A_fake_output = dis_A(fake_A.detach())  # 判别器输入生成图片
        dis_A_fake_loss = bceloss_fn(dis_A_fake_output,
                                     torch.zeros_like(dis_A_fake_output, device=device))
 
        dis_A_loss = (dis_A_real_loss + dis_A_fake_loss) * 0.5
 
        dis_A_loss.backward()
        dis_A_optimizer.step()
 
        # dis_B 训练
        dis_B_optimizer.zero_grad()
        dis_B_real_output = dis_B(real_B)  # 判别器输入真实图片
        dis_B_real_loss = bceloss_fn(dis_B_real_output,
                                     torch.ones_like(dis_B_real_output, device=device))
 
        dis_B_fake_output = dis_B(fake_B.detach())  # 判别器输入生成图片
        dis_B_fake_loss = bceloss_fn(dis_B_fake_output,
                                     torch.zeros_like(dis_B_fake_output, device=device))
 
        dis_B_loss = (dis_B_real_loss + dis_B_fake_loss) * 0.5
 
        dis_B_loss.backward()
        dis_B_optimizer.step()

标签: g3高频电连接器

锐单商城拥有海量元器件数据手册IC替代型号,打造 电子元器件IC百科大全!

锐单商城 - 一站式电子元器件采购平台