Generation and GAN(初步认知)
Generation and GAN
文章目录
- Generation and GAN
- 参考
- 什么是Generation?
- 什么是GAN?
- GAN 的理论支撑?
- Discriminator
- **Generator**的**Divergence**
- JS divergence is not suitable
- Wasserstein Distance
- GAN的瓶颈
- Diversity - Mode Collapse
- Diversity - Mode Dropping
- Conditional Generation
- Cycle GAN
参考
- 李宏毅机器学习课程笔记:DeepLearning_LHY21_Notes
- (强推)李宏毅2021春机器学习课程40-43
- 开发者自述:我是这样学习 GAN 的
什么是Generation?
Generation就是模型通过学习一些数据,然后生成类似的数据。让机器看一些动物图片,然后自己来产生动物的图片,这就是生成。
这里引用李宏毅老师课上的话:network输出,不再是单一一个固定的东西,而变成了一个复杂的distribution,这种可以输出,一个distribution的network,我们就叫它generator
- 这里的distribution可以是上面提到的图片,也可以是语音、文字等
什么是GAN?
Generative Adversarial Network (GAN),百度百科等解释是这样的:
- 生成式对抗网络(GAN),模型通过框架中(至少)两个模块:生成模型(Generative Model)和判别模型(Discriminative Model)的互相博弈学习产生相当好的输出。
- 原始 GAN 理论中,并不要求 G 和 D 都是神经网络,只需要是能拟合相应生成和判别的函数即可。但实用中一般均使用深度神经网络作为 G 和 D 。一个优秀的GAN应用需要有良好的训练方法,否则可能由于神经网络模型的自由性而导致输出不理想。
关于生成对抗的解释,李老师用了一个枯叶蝶(=Generator)+比比鸟(=Discriminative)的例子
- 首先初代的波波会吃枯叶蝶的祖先,
- 然后枯叶蝶祖先进化,颜色变成棕色,这样可以增加存活几率
- 然后波波进化成比比鸟,可以认出棕色的枯叶蝶祖先,这样可以增加捕食效率
- 然后棕色枯叶蝶祖先超进化,有了叶子的纹路,这样就和枯叶更像了,生存几率上升
- 然后比比鸟进化成大比鸟。。。。。。。
GAN 的理论支撑?
- 上面说到,我们用Generator生成对象,用Discriminator来检测对象的逼真程度
也就是下图中所示( P G = 生 成 分 布 P d a t a = 真 实 分 布 P_G=生成分布 \qquad P_{data} = 真实分布 PG=生成分布Pdata=真实分布),我们用Discriminator来检测对象的逼真程度,也就是:$P_G $计算 P d a t a P_{data} Pdata之间的距离,最终目的是让Generator生成的$P_G $ 和真实的 P d a t a P_{data} Pdata之间的距离越短越好,也就是: G ∗ = arg min D i v ( P G , P d a t a ) G^* = \arg \min Div(P_G,P_{data}) G∗=argminDiv(PG,Pdata)
- 那么如何算这个DIV(Divergence)?GAN 告诉我们就是,你不需要知道 PG 跟 Pdata它们实际上的 Formulation 长什麼样子,只要能从 P G P_G PG 和 P d a t a P_{data} Pdata这两个 Distributions Sample 东西出来,就有办法算 Divergence,这个就是要靠 Discriminator 的力量
Discriminator
- Discriminator在我们根本不知道 PG 跟 Pdata,实际上完整的 Formulation 长什麼样子,就能估测出 Divergence
Discriminator是如何训练出的?
- 我们有一大堆的 Real Data,这个 Real Data 就是从 P d a t a P_{data} Pdata Sample 出来的结果
- 我们有一大堆 Generative 的 Data,Generative 的 Data,就可以看作是从 P G P_G PG Sample 出来的结果
根据 Real 的 Data 跟 Generative 的 Data,我会去训练一个 Discriminator,它的训练的目标是
- 看到 Real Data,就给它比较高的分数
- 看到这个 Generative 的 Data,就给它比较低的分数
- 其实就是要分辨好的图跟不好的图,分辨真的图跟生成的图,所以看到真的图给它高分,看到生成的图给它低分
实际以上的过程,可以把它写成式子,把它当做是一个 Optimization 的问题(最优化问题),这个 Optimization 的问题是这样子的(我们要 Maximize 的东西,我们会叫 Objective Function,如果 Minimize 我们就叫它 Loss Function):
- 我们现在要找一个 D,它可以 Maximize 这个 Objective Function
我们希望这个 Objective Function V越大越好
- 意味著我们希望这边的 D (Y) 越大越好,我们希望 Y 如果是从 P d a t a P_{data} Pdata Sample 出来的,它就要越大越好
- 我们希望说如果 Y 是从,这个 P G P_{G} PG Sample 出来的,它就要越小越好
- 事实上这个 Objective Function,它就是 Cross Entropy 乘一个负号,我们在寻找Maximize的时候,也等同雨寻找Corss Entropy的Minimize,也就是说这等同於是在训练一个 Classifier(其实Discriminator确实可以当做是一个 Classifier来train)
Generator的Divergence
有了Discriminator,我们对Generator的Divergence就有了下面的思路:
- 我们现在知道,我们只要训练一个 Discriminator,训练完以后,这个 Objective Function 的最大值,就是这个 Divergence,就跟这个 Divergence 有关
- 然后我们可以用D*替换掉Div
JS divergence is not suitable
JS Divergence 有个特性,是两个没有重叠的分布,JS Divergence 算出来,就永远都是 Log2,不管这两个分布长什麼样,所以两个分布只要没有重叠,算出来就一定是 log2
- 如果不能分辨出来bad的程度,则就不回说去从左侧逼近右侧,因此我们需要找到一个衡量两个 Distribution 的相似程度的方式
Wasserstein Distance
- Wasserstein Distance你在开一台推土机,那你把 P 想成是一堆土,把 Q 想成是你要把土堆放的目的地,那这个推土机把 P 这边的土,挪到 Q 所移动的平均距离,就是 Wasserstein Distance
- 但是对于复杂的分布来说,这种方法可能就有不同值:
- 于是为了让 Wasserstein Distance 只有一个值,所以这边 Wasserstein Distance 的定义是,穷举所有的 Moving Plans,然后看哪一个推土的方法,哪一个 Moving 的计划,可以让平均的距离最小,那个最小的值,才是 Wasserstein Distance
GAN的瓶颈
Diversity - Mode Collapse
- 假如Generator生成了一个可以完美骗过Discriminator的一个sample,然后所有的Generator生成的sample都围绕那个接近完美的sample即可骗过Discriminator,拿到一个很高的分,但是这样会导致生成的图趋向一致
Diversity - Mode Dropping
- 和上面的类似,只不过这里趋向一堆sample(但不是全部),因此这会导致生成的图总是那么几个
Conditional Generation
- 相比于此前只输入分布z的Generation来说,Conditional Generation还要输入条件:x,因此画出来的时候会与输入条件相关,至于下图中的red eyes怎么输入的,这里用RNN或者Transformer 的 Encoder即可
- 相应的,在训练Discriminator的时候也要加入条件,这样Discriminator才能将对应的条件和图片对应起来,然后Generation才能train出一个比较好的结果
Cycle GAN
对于那些未标注的信息,可以用Cycle GAN来train,比如对于一个人头像转换成动漫头像这个任务,可以这么做:
现在这边有三个Network
- 第一个generator,它的工作是把X转成Y
- 第二个generator,它的工作是要把Y还原回原来的X
- 那这个discriminator,它的工作仍然是要看,蓝色的这个generator它的输出,像不像是Y domain的图
- 那加入了这个橙色的从Y到X的generator以后,对於前面这个蓝色的generator来说,它就再也不能够随便乱做了,它就不能够随便產生乱七八糟,跟输入没有关係的人脸了(原来它可以直接对每个人脸输入输出同一个动漫头像,这就不是很合理,虽然Cycle GAN理论上也可能会发生这种情况,但在实际train的过程中,基本可以避免上述情况的发生)
还没有评论,来说两句吧...