MaskGIT-Masked Generative Image Transformer

MaskGIT: Masked Generative Image Transformer

TL;DR

在此之前,最好的transformer 模型做法是把图像先编码成token 序列,然后line-by-line 的进行解码,作者认为这种方式不是最优也不是最高效的。本文提出MaskGIT, 一种的图像合成领域的新范式,使用一种双向的transformer decoder. 在训练阶段,MaskGIT从所有方向上实行attention机制,从而预测随机masked token。在推理阶段,模型先同步生成图像的所有token,然后不断refine之前生成的结果。该方法相比之前的AR方法,有64x 的加速。

Untitled

Method

Untitled

  • MaskGit 的目标是提升二阶段的效果,一阶段和VQGAN方法保持一致

3.1 Masked Visual Token Modeling(MVTM)

  • $Y = [{y_i}]_{i=1}^N$ 把图像输入VQ Encoder,获得的latent tokens。 N表示把token matix reshape 后的矩阵长度。
  • $M=[{m_i}]_{i=1}^N$ 表示对应的binary mask
  • 在训练过程中,获得的latent token 被【Mask】token随机替换。采样过程是首先从0,1之间的均匀分布采样出一个ratio的值。然后从N个token 中采样出 $[\gamma * N]$ 个token
  • $Y_{\overline M}$ 表示把Mask apply 到token Y上

Untitled

Iterative Decoding

  • 图像压缩得到的所有token可以并行被decoder,这是提出的双向 self-attention所带来的feature。每个算法迭代,可以分为如下几个步骤

    • predict:在第t 次迭代,给定masked token,模型输出所有掩模位置的概率分布 - $p^{(t)} \in R^{N \times K}$
    • sample :对于每个mask 位置,根据codebook中的所有token的概率分布进行
    • mask schedule:根据上述的mask schedule function 来确定采样token 的单位
    • mask

    $$
    m_{i}^{t+1} = \begin{cases}
    1, & \text{if } c_i \geq \operatorname*{sorted}_{j} (c_j)[n] \
    0, & \text{otherwise}
    \end{cases}
    $$

    • 简单描述decoder过程,模型同时预测出所有的token,并且只保留置信度最高的一项。剩余的token 会被mask住并且重新预测。mask ratio逐步减小,直到T次迭代生成所有的token为止。

Mask Design

  • masking design对生成图像的质量影响很大,因此mask schedule function 很关键。在测试阶段,mask ratio 分别为 $0/T,1/T, …… ,(T-1)/T$ ti
  • 提出一种新的mask schedule function,首先从性质上它必须是有界且连续,其次要单调递减。这里分别分成了三个组的function来表示这些函数。• Linear /Concave / Convex 。后面有实验会比较

Experiment

  • Settings:
    • (1)One encoder / one decoder for each dataset, 1024 token codebook. 图像尺寸256x256。这个token 尺寸在512 x 512上的图也是适用的。
    • (2)24 layer, 8 attention head, 768 embed dimensions, 3072 hidden dimensions.
  • 相比BigGAN而言效果提升不少

Untitled

Untitled

Image Editing

  • 一些 inpainting / outpainting的应用

Untitled

Untitled

Thoughts

有限次数去逐渐decoder token, 替代了line-by-line的方式。