敵対的生成ネットワーク(GAN)で画像生成を行うプログラムをなるべくわかりやすく解説

敵対的生成ネットワーク(GAN)で画像生成を行うプログラムをなるべくわかりやすく解説

今回の記事では,敵対的生成ネットワーク(GAN)を用いて,手書き数字(MNIST)の画像を自動生成するプログラムを紹介したいと思います.

使用するデータセット

今回は,こちらの記事でも使用しました手書き数字のデータセットであるMNISTを使います.MNISTは「0」から「9」までの手書き数字の画像で構成されており,合計で60,000枚あります.MNISTは人工知能に関する学会でもよく使用されているデータセットです.

MNIST

使用するプログラミング環境

プログラミング言語はPython,深層学習用ライブラリはPyTorchを使用します.まだこれらのインストールや設定が完了していない方はこちらの記事を参考に,設定をしてみて下さい.

スポンサーリンク

GANによるMNISTの自動生成

全体のプログラム

全体のプログラムは以下のようになります.プログラムはこちらでも公開していますので,適宜ダウンロードして使用してみて下さい.

from __future__ import print_function
import argparse
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils


parser = argparse.ArgumentParser()
parser.add_argument('--ngpu', type=int, default=0, help='number of GPUs to use')
opt = parser.parse_args()

# 学習後のネットワークを保存するフォルダ作成
try:
    os.makedirs('./models')
except OSError:
    pass
# 生成画像を保存するフォルダ作成
try:
    os.makedirs('./generated_images')
except OSError:
    pass

# 乱数のシード設定
manualSeed = random.randint(1, 10000)
print("Random Seed: ", manualSeed)
random.seed(manualSeed)
torch.manual_seed(manualSeed)

# GPUの設定
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
if use_cuda:
    ngpu = int(opt.ngpu)
    cudnn.benchmark = True

# データセットの設定
dataset = dset.MNIST(root='./', download=True,
                    transform=transforms.Compose([
                        transforms.Resize(64),
                        transforms.ToTensor(),
                        transforms.Normalize((0.5,), (0.5,)),]))
# 入力画像のチャネル数
nc=1
# バッチサイズ
batchsize = 64
# Dataloader
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batchsize, shuffle=True, num_workers=2)

# 画像を生成するネットワークG
class Generator(nn.Module):
    def __init__(self, ngpu, nz):
        super(Generator, self).__init__()
        self.ngpu = ngpu
        self.nz = nz
        self.nf = 64
        self.main = nn.Sequential(
            # ランダムノイズをDeconvolution layerに入力
            nn.ConvTranspose2d(self.nz, self.nf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(self.nf * 8),
            nn.ReLU(),
            # (nf*8) x 4 x 4
            nn.ConvTranspose2d(self.nf * 8, self.nf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(self.nf * 4),
            nn.ReLU(),
            # (nf*4) x 8 x 8
            nn.ConvTranspose2d(self.nf * 4, self.nf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(self.nf * 2),
            nn.ReLU(),
            # (nf*2) x 16 x 16
            nn.ConvTranspose2d(self.nf * 2, self.nf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(self.nf),
            nn.ReLU(),
            # (nf) x 32 x 32
            nn.ConvTranspose2d(self.nf, nc, 4, 2, 1, bias=False),
            nn.Tanh()
            # (nc) x 64 x 64の画像を出力
        )

    def forward(self, input):
        output = self.main(input)
        return output

# 本物と偽物を見分けるネットワークD
class Discriminator(nn.Module):
    def __init__(self, ngpu):
        super(Discriminator, self).__init__()
        self.ngpu = ngpu
        self.nf = 64
        self.main = nn.Sequential(
            # input size is (nc) x 64 x 64
            nn.Conv2d(nc, self.nf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # (nf) x 32 x 32
            nn.Conv2d(self.nf, self.nf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(self.nf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # (nf*2) x 16 x 16
            nn.Conv2d(self.nf * 2, self.nf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(self.nf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # (nf*4) x 8 x 8
            nn.Conv2d(self.nf * 4, self.nf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(self.nf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # (nf*8) x 4 x 4
            nn.Conv2d(self.nf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, input):
        output = self.main(input)
        return output.view(-1, 1).squeeze(1)


#ネットワークGに入力するランダムノイズ
nz = 100
fixed_noise = torch.randn(batchsize, nz, 1, 1, device=device)

# ネットワークGの宣言
netG = Generator(opt.ngpu, nz).to(device)

# ネットワークDの宣言
netD = Discriminator(opt.ngpu).to(device)

# 損失関数
criterion = nn.BCELoss()

# ラベルの定義(本物:1, 偽物:0)
real_label = 1
fake_label = 0

# Optimizerの設定
optimizerD = optim.Adam(netD.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=0.0002, betas=(0.5, 0.999))

# 画像生成の開始
epoch_num = 20
for epoch in range(epoch_num):
    for i, data in enumerate(dataloader, 0):
        ############################
        # (1) ネットワークDの学習(=log(D(x)) + log(1 - D(G(z)))の最大化)
        ###########################
        # 本物の画像の見分ける
        netD.zero_grad()
        real_cpu = data[0].to(device)
        batch_size = real_cpu.size(0)
        label = torch.full((batch_size,), real_label, device=device)
        output = netD(real_cpu)
        # 損失(loss)の計算
        errD_real = criterion(output, label)
        # 誤差逆伝搬
        errD_real.backward()
        D_x = output.mean().item()

        # 偽物の画像を見分ける
        noise = torch.randn(batch_size, nz, 1, 1, device=device)
        fake = netG(noise)
        label.fill_(fake_label)
        output = netD(fake.detach())
        # 損失(loss)の計算
        errD_fake = criterion(output, label)
        # 誤差逆伝播
        errD_fake.backward()
        D_G_z1 = output.mean().item()
        errD = errD_real + errD_fake
        # パラメータ更新
        optimizerD.step()

        ############################
        # (2) ネットワークGの学習(=log(D(G(z)))の最大化)
        ###########################
        netG.zero_grad()
        label.fill_(real_label)
        output = netD(fake)
        # 損失(loss)の計算
        errG = criterion(output, label)
        # 誤差逆伝播
        errG.backward()
        D_G_z2 = output.mean().item()
        # パラメータ更新
        optimizerG.step()

        print('[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f / %.4f' % (epoch, epoch_num, i, len(dataloader),
                 errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))

    # 生成画像の保存
    fake = netG(fixed_noise)
    vutils.save_image(fake.detach(), './generated_images/fake_samples_epoch_%03d.png' % (epoch), normalize=True)
    # ネットワークの保存
    torch.save(netG.state_dict(), './models/netG_epoch_%d.pth' % (epoch))
    torch.save(netD.state_dict(), './models/netD_epoch_%d.pth' % (epoch))

上記のプログラムを

python dcgan.py

のように,ターミナルで実行することで,MNISTの画像生成が始まります.

生成結果

初期状態

初期状態のネットワークGにランダムノイズを入力し,画像生成を行った結果は以下のようになります.まだネットワークGの学習を行っていないため,ランダムノイズを出力しています.

学習初期の生成結果

1 epoch後

ネットワークGとDを1epoch分学習したあとのGによる生成画像が以下のようになります.わずか1epochの学習でも何やら数字らしきものが写った画像を生成できていることがわかります.

1epoch学習後のGによる生成結果

数epoch後

数epoch学習後のGによる生成画像例を下図の左側に示します.右側の本物のMNIST画像と比べてわかるように,結構MNISTの画像に近い画像が生成できていることがわかるかと思います.今回は学習はこのあたりで止めてしまいましたが,さらに続けることでより本物に近い画像を生成することが可能です.

Gが生成したMNIST画像
本物のMNIST画像

画像生成ネットワークGのプログラム

画像生成を行うネットワークGのプログラム部分は以下のようになります.

今回は5層のCNNを使用しています.各畳み込み層では,こちらの記事で紹介しましたDeconvolution処理を行い,画像を拡大しつつ画像生成を行っています.具体的には,画像の大きさを2倍にしつつ,チャネル数は半分にする処理を5回行っています(最初の畳み込み層を除く).

nn.ConvTranspose2d(in_channel, out_channel, kernel_size, stride, padding, output_padding)がDeconvolution処理を表しているのですが,各パラメータを変更することで出力画像のサイズを変更することができます.具体的には,

$$H_{out}=(H_{in}-1)\times stride – 2 \times padding + kernel\_size + output\_padding$$

の式で,入力画像サイズ\(H_{in}\)から出力画像サイズ\(H_{out}\)を計算することができます.

# 画像を生成するネットワークG
class Generator(nn.Module):
    def __init__(self, ngpu, nz):
        super(Generator, self).__init__()
        self.ngpu = ngpu
        self.nz = nz
        self.nf = 64
        self.main = nn.Sequential(
            # ランダムノイズをDeconvolution layerに入力
            nn.ConvTranspose2d(self.nz, self.nf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(self.nf * 8),
            nn.ReLU(),
            # (nf*8) x 4 x 4
            nn.ConvTranspose2d(self.nf * 8, self.nf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(self.nf * 4),
            nn.ReLU(),
            # (nf*4) x 8 x 8
            nn.ConvTranspose2d(self.nf * 4, self.nf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(self.nf * 2),
            nn.ReLU(),
            # (nf*2) x 16 x 16
            nn.ConvTranspose2d(self.nf * 2, self.nf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(self.nf),
            nn.ReLU(),
            # (nf) x 32 x 32
            nn.ConvTranspose2d(self.nf, nc, 4, 2, 1, bias=False),
            nn.Tanh()
            # (nc) x 64 x 64の画像を出力
        )

    def forward(self, input):
        output = self.main(input)
        return output
スポンサーリンク

画像識別ネットワークD

本物の画像とGが生成した画像を見分けるネットワークDのプログラムは以下のようになります.

こちらも5層のCNNを用いています.最終層の活性化関数をシグモイド関数にすることでCNNの出力値の範囲を\(0\)から\(1\)にしています.そして,本物画像のときは\(1\)を,Gの生成画像のときは\(0\)を出力するようにパラメータの更新を行います.

また,今回は中間層の活性化関数にLeakly ReLUと呼ばれるものを使用してみました.

$$f(x) = \begin{cases} x & (x>0) \\ 0.01x & (x\le 0) \end{cases}$$

一概にどの活性化関数が1番良いかどうかは判断することが難しく,色々と変えながら試すしかないのが正直なところです.

# 本物と偽物を見分けるネットワークD
class Discriminator(nn.Module):
    def __init__(self, ngpu):
        super(Discriminator, self).__init__()
        self.ngpu = ngpu
        self.nf = 64
        self.main = nn.Sequential(
            # input size is (nc) x 64 x 64
            nn.Conv2d(nc, self.nf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # (nf) x 32 x 32
            nn.Conv2d(self.nf, self.nf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(self.nf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # (nf*2) x 16 x 16
            nn.Conv2d(self.nf * 2, self.nf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(self.nf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # (nf*4) x 8 x 8
            nn.Conv2d(self.nf * 4, self.nf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(self.nf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # (nf*8) x 4 x 4
            nn.Conv2d(self.nf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, input):
        output = self.main(input)
        return output.view(-1, 1).squeeze(1)
スポンサーリンク

ネットワークGの学習

ネットワークGの学習では,ネットワークDを騙すように訓練する必要があります.そのため,Gが生成した画像に対するラベルは本来は\(0\)(=偽物)なのですが,\(1\)(=本物)に変更し,Dが\(1\)を出力するようにGのパラメータを更新します.

        ############################
        # (2) ネットワークGの学習(=log(D(G(z)))の最大化)
        ###########################
        netG.zero_grad()
        label.fill_(real_label)  # fake labels are real for generator cost
        output = netD(fake)
        # 損失(loss)の計算
        errG = criterion(output, label)
        # 誤差逆伝播
        errG.backward()
        D_G_z2 = output.mean().item()
        # パラメータ更新
        optimizerG.step()

ネットワークDの学習

ネットワークDの学習では,通常の画像分類同様に,本物の画像に対してはラベル\(1\)(=本物),Gの生成画像に対しては\(0\)(=偽物)と出力するようにパラメータの更新を行います.

        ############################
        # (1) ネットワークDの学習(=log(D(x)) + log(1 - D(G(z)))の最大化)
        ###########################
        # 本物の画像の見分ける
        netD.zero_grad()
        real_cpu = data[0].to(device)
        batch_size = real_cpu.size(0)
        label = torch.full((batch_size,), real_label, device=device)
        output = netD(real_cpu)
        # 損失(loss)の計算
        errD_real = criterion(output, label)
        # 誤差逆伝搬
        errD_real.backward()
        D_x = output.mean().item()

        # 偽物の画像を見分ける
        noise = torch.randn(batch_size, nz, 1, 1, device=device)
        fake = netG(noise)
        label.fill_(fake_label)
        output = netD(fake.detach())
        # 損失(loss)の計算
        errD_fake = criterion(output, label)
        # 誤差逆伝播
        errD_fake.backward()
        D_G_z1 = output.mean().item()
        errD = errD_real + errD_fake
        # パラメータ更新
        optimizerD.step()

まとめ

今回はGANを用いてMNISTの画像を生成する実装例を解説してみました.GANについての解説はこちらの記事で行っていますので,よければ参考にしてみて下さい.

また今回のプログラムもこちらでダウンロードできます.

タイトルとURLをコピーしました