DANN & GRL
Swift Lv6

域自适应是指在目标域与源域的数据分布不同但任务相同下的迁移学习,从而将模型在源域上的良好性能迁移到目标域上,极大地缓解目标域标签缺失严重导致模型性能受损的问题。

介绍一篇经典工作 DANN

模型结构

model

在训练阶段需要预测如下两个任务:

  • 实现源域数据集准确分类,即图像分类误差的最小化,这与正常分类任务保持一致
  • 实现源域和目标域准确分类,即域分类器的误差最小化。而特征提取器的目标是最大化域分类误差,使得域分类器无法分辨数据是来自源域还是目标域,从而让特征提取器学习到域不变特征(domain-invariant)。也就是说特征提取器和域分类器的目标是相反的
    • 本质上就是让特征提取器不要过拟合源域,要学习出源域和目标域的泛化特征
    • 这两个网络对抗训练,DANN通过GRL层使特征提取器更新的梯度与域判别器的梯度相反,构造出了类似于GAN的对抗损失,又通过该层避免了GAN的两阶段训练过程,提升模型训练稳定性

GRL

GRL是作用在特征提取器上的,对其参数梯度取反。

具体实现如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
class ReverseLayerF(Function):

@staticmethod
def forward(ctx, x, alpha):
ctx.alpha = alpha

return x.view_as(x)

@staticmethod
def backward(ctx, grad_output):
output = grad_output.neg() * ctx.alpha

return output, None

调用如下:

1
2
3
4
5
6
7
8
9
def forward(self, input_data, alpha):
input_data = input_data.expand(input_data.data.shape[0], 3, 28, 28)
feature = self.feature(input_data)
feature = feature.view(-1, 50 * 4 * 4)
reverse_feature = ReverseLayerF.apply(feature, alpha)
class_output = self.class_classifier(feature)
domain_output = self.domain_classifier(reverse_feature)

return class_output, domain_output


参考

Powered by Hexo & Theme Keep
Unique Visitor Page View