敵対的生成ネットワーク(GAN)についてなるべくわかりやすく解説

敵対的生成ネットワーク(GAN)についてなるべくわかりやすく解説

今回は,人工知能界隈でとても賑わいをみせている敵対的生成ネットワーク(GAN)について解説をしようと思います.GANも深層学習(ディープラーニング)を利用した方法になります.GANを使用すると以下のようにとてもリアルな画像をコンピュータで自動で生成することができます

GANによって生成された画像例.
引用https://research.nvidia.com/sites/default/files/pubs/2017-10_Progressive-Growing-of/karras2018iclr-paper.pdf
GANによって生成された画像例.
引用https://arxiv.org/pdf/1809.11096.pdf

敵対的生成ネットワーク(Generative Adversarial Networks; GAN)の仕組み

GANの基本概念はずばり,「偽札を作る者とそれを見破る者」です.

例えば,偽札を作る者(Gと呼ぶことにします)とそれを見破る者(Dと呼ぶことにします)がいたとしましょう.最初の頃は,Gが作る偽札は完成度が低く,Dに簡単に偽札であると見破れられてしまうかと思います.

しかし,Gも何とかDを騙そうと頑張って技術を磨いていけば,本物のお札に近い完成度の偽札を徐々に作れるようになってくるはずです.そうなってくると,Dの方も本物のお札についてより詳しくなる必要があり,Gが作成した偽札と本物のお札を見分けるための技術を身につける必要が出てきます.

Dがレベルアップすると,Gも同様に技術をさらに磨くことでよりリアルな偽札を作る必要が出てきます.そして,Gがレベルアップしたら,Dもまたレベルアップ...というようなループが続くと思います.

そしてある日気づくと,Gが生成する偽札はいつの間にかかなり本物のお札に近いクォリティになっているはずです.これと全く同じことを2つのニューラルネットワーク同士(GとD)で実行します.

スポンサーリンク

GANの全体像

GANの全体像は以下の通りです.CNN Gがランダムな情報から画像を生成し,それを別のCNN Dが画像分類問題として解く,という形式になっています.GはDの分類精度を下げるようにパラメータを更新し,DはGが作成した画像を偽物と,本物の画像に対しては本物と正しく分類できるようにパラメータの更新を行います.

画像を生成するニューラルネットワーク G

画像を生成する(上の例でいう偽札製造者)ニューラルネットワークGは,Deconvolution処理を搭載したCNNを用いることが多いです.

Deconvolution処理は,フィルタを用いて入力画像を拡大することが可能な畳み込み処理です.以下の例では,青色で表された\(3\times 3\)画素の入力特徴マップを周囲や間にゼロを足してサイズを拡大した後でフィルタを適用し、緑色で表された\(5\times 5\)画素の特徴マップを出力しています.

Deconvolutionの例.青が入力画像,緑が出力画像を表している.画像引用https://github.com/vdumoulin/conv_arithmetic

このDeconvolution処理を用いることで,徐々に画像を拡大しつつ,Dを騙せるようなリアルな画像生成を行います.例えば,Deconvolutionを3つ用いて,\(32\times 32\)のカラー画像を生成するには,以下のような3層のCNNを用いることで実現することができます.

CNNの入力にはランダムに生成したノイズ画像を入力することが多いです.

生成した画像を後述するCNN Dが本物であると(誤って)分類するように,CNN Gのフィルタの重みを誤差逆伝播法によって調整します.Dが間違えて本物と分類するようにパラメータの更新を行っていくと,次第にリアルな画像が生成されていくようになります.

本物と偽物を見分けるニューラルネットワーク D

Gが生成した偽の画像と本物の画像とを見分けるニューラルネットワークは,通常の画像分類で用いるCNNを使用します.以下に,CNNの例を示します.

具体的には,Gが生成した画像に対してはfake(=0),本物の画像に対してはreal(=1)を最終層で出力するようにCNNのパラメータを調整します(2クラス画像分類を行っているだけ).

GANの学習

GANの学習では,

  • Dのパラメータ更新
  • Gのパラメータ更新

を交互に行います.

Dの学習

Dの学習には,本物の画像とGが生成した偽の画像の2つを用います.本物の画像を入力したときには\(1\)を出力,Gが生成した画像を入力したときには\(0\)を出力するようにします.より具体的には,Dの最終層はシグモイド関数にし,Dを学習するための損失関数はバイナリクロスエントロピー関数を用います.

$$Loss_{D}=\mathbb{E}_{{\bf x}\sim p({\bf x})}[t{\rm log}D({\bf x})]+\mathbb{E}_{{\bf z}\sim p({\bf z})}[(1-t){\rm log}(1-D(G({\bf z})))]$$

Dに関しては上記の損失関数を最大化するように誤差逆伝播法によってパラメータを更新します.\(D(\cdot)\) はネットワークDの出力を示しています.

具体的には,本物の画像(\(t=1\))に対しては,\(\mathbb{E}_{{\bf x}\sim p({\bf x})}[t{\rm log}D({\bf x})]\)を最大化,偽物の画像(\(t=0\))に対しては\(\mathbb{E}_{{\bf z}\sim p({\bf z})}[(1-t){\rm log}(1-D(G({\bf z})))]\)を最大化するようにCNNのパラメータの更新を行います.

Gの学習

Gの目標は,Gが生成した画像をDに本物であると認識させることです.これを達成するための損失関数は以下のようになります.

$$Loss_{G}=\mathbb{E}_{{\bf z}\sim p({\bf z})}[{\rm log}(D(G({\bf z})))]$$

上記の損失関数を最大化するように,Gのパラメータを更新します.これはつまり,Gが生成した偽の画像に対してDの出力が\(1\)(=本物)となるように,Gを訓練することに相当します.

まとめ

今回は,現在人工知能界隈で流行っている敵対的生成ネットワーク(GAN)の解説をしてみましたが,いかがだったでしょうか.2つのニューラルネットワークを互いに競わせる,といった点が大きな特徴になります.この学習方法によって,ネットワークGは次第にリアルな画像を生成することができるようになります.

次回は,GANの実装について説明しようかと思います.

タイトルとURLをコピーしました