2 minute read

VQ-VAE Model


코드를 통해 위 그림의 모델을 확인해보았다.
가장 먼저 train_vq.py를 확인해 VQ-VAE를 어떻게 훈련시키는지 확인해 봤다.

net = vqvae.HumanVQVAE(args, ## use args to define different parameters in different quantizers
                       args.nb_code,
                       args.code_dim,
                       args.output_emb_width,
                       args.down_t,
                       args.stride_t,
                       args.width,
                       args.depth,
                       args.dilation_growth_rate,
                       args.vq_act,
                       args.vq_norm)

net이라는 네트워크를 HumanVQVAE를 통해 만드는데 args는 nb_code, code_dim, output_emb_width, down_t, stride_t, width, depth, dilation_growth_rate, vq_act, vq_norm을 통해 만들어진다. 대부분의 args는 training code를 통해 확인할 수 있는데

python3 train_vq.py \
--batch-size 256 \
--lr 2e-4 \
--total-iter 300000 \
--lr-scheduler 200000 \
--nb-code 512 \
--down-t 2 \
--depth 3 \
--dilation-growth-rate 3 \
--out-dir output \
--dataname t2m \
--vq-act relu \
--quantizer ema_reset \
--loss-vel 0.5 \
--recons-loss l1_smooth \
--exp-name VQVAE

nb_code는 512, down_t는 2, depth는 3, dilation_growth_rate는 3, vq_act는 relu로 구성되었다. 나머지 부분의 경우 options/option_vq.py에 적혀있다. code_dim은 512, output_emb_width은 512, stride_t는 2, width는 512, vq_norm는 None을 보여주고 있다.
가장 기본적으로 VQ-VAE network를 거칠 때 어떤 module들이 거치는지 확인하기 위해 forward를 확인해 봤다.

    def forward(self, x):
        # x => [batch, motion_frame (64), 263]
        x_in = self.preprocess(x)
        # x_in => [batch, 263, 64]
        # Encode
        x_encoder = self.encoder(x_in)
        # x_encoder => [256,512,16]
        ## quantization
        x_quantized, loss, perplexity  = self.quantizer(x_encoder)

        ## decoder
        x_decoder = self.decoder(x_quantized)
        x_out = self.postprocess(x_decoder)
        return x_out, loss, perplexity

기존적으로 preprocess를 거치는데

    def preprocess(self, x):
        # (bs, T, Jx3) -> (bs, Jx3, T)
        x = x.permute(0,2,1).float()
        return x

간단히 생각해서 dim을 바꿔준다고 생각하면 될 것 같다. 다시 말해서 conv1D를 거칠 때 channel을 joint로 바꾸고 feature를 sequence로 바꾼다고 생각하면 될 것 같다.
그 다음으로 encoder를 거치는데 encoder의 구조는 다음과 같다.

Sequential(
  (0): Conv1d(263, 512, kernel_size=(3,), stride=(1,), padding=(1,))
  (1): ReLU()
  (2): Sequential(
    (0): Conv1d(512, 512, kernel_size=(4,), stride=(2,), padding=(1,))
    (1): Resnet1D(
      (model): Sequential(
        (0): ResConv1DBlock(
          (norm1): Identity()
          (norm2): Identity()
          (activation1): ReLU()
          (activation2): ReLU()
          (conv1): Conv1d(512, 512, kernel_size=(3,), stride=(1,), padding=(9,), dilation=(9,))
          (conv2): Conv1d(512, 512, kernel_size=(1,), stride=(1,)))
        (1): ResConv1DBlock(
          (norm1): Identity()
          (norm2): Identity()
          (activation1): ReLU()
          (activation2): ReLU()
          (conv1): Conv1d(512, 512, kernel_size=(3,), stride=(1,), padding=(3,), dilation=(3,))
          (conv2): Conv1d(512, 512, kernel_size=(1,), stride=(1,)))
        (2): ResConv1DBlock(
          (norm1): Identity()
          (norm2): Identity()
          (activation1): ReLU()
          (activation2): ReLU()
          (conv1): Conv1d(512, 512, kernel_size=(3,), stride=(1,), padding=(1,))
          (conv2): Conv1d(512, 512, kernel_size=(1,), stride=(1,))))))
  (3): Sequential(
    (0): Conv1d(512, 512, kernel_size=(4,), stride=(2,), padding=(1,))
    (1): Resnet1D(
      (model): Sequential(
        (0): ResConv1DBlock(
          (norm1): Identity()
          (norm2): Identity()
          (activation1): ReLU()
          (activation2): ReLU()
          (conv1): Conv1d(512, 512, kernel_size=(3,), stride=(1,), padding=(9,), dilation=(9,))
          (conv2): Conv1d(512, 512, kernel_size=(1,), stride=(1,)))
        (1): ResConv1DBlock(
          (norm1): Identity()
          (norm2): Identity()
          (activation1): ReLU()
          (activation2): ReLU()
          (conv1): Conv1d(512, 512, kernel_size=(3,), stride=(1,), padding=(3,), dilation=(3,))
          (conv2): Conv1d(512, 512, kernel_size=(1,), stride=(1,)))
        (2): ResConv1DBlock(
          (norm1): Identity()
          (norm2): Identity()
          (activation1): ReLU()
          (activation2): ReLU()
          (conv1): Conv1d(512, 512, kernel_size=(3,), stride=(1,), padding=(1,))
          (conv2): Conv1d(512, 512, kernel_size=(1,), stride=(1,))))))
  (4): Conv1d(512, 512, kernel_size=(3,), stride=(1,), padding=(1,)))

kernel size가 3이고 stride 1, padding 1일 때 같은 feature가 나오게 되고 kernel_size 1, stride 1일때도 같은 feature map이 나오게 된다.
하지만 self.encoder(x_in)를 거치면 [256, 263, 64] => [256, 512, 16]이 나오게 되는데 Conv1d(512, 512, kernel_size=(4,), stride=(2,), padding=(1,))에서 feature size가 1/2로 줄어들기 때문이다.
이렇게 줄어든 feature map은 self.quantizer를 들어가게 되는데

        if args.quantizer == "ema_reset":
            self.quantizer = QuantizeEMAReset(nb_code, code_dim, args)
        elif args.quantizer == "orig":
            self.quantizer = Quantizer(nb_code, code_dim, 1.0)
        elif args.quantizer == "ema":
            self.quantizer = QuantizeEMA(nb_code, code_dim, args)
        elif args.quantizer == "reset":
            self.quantizer = QuantizeReset(nb_code, code_dim, args)

총 4개의 quantizer가 있고 --quantizer ema_reset \를 통해 quantizer는 ema_reset으로 예측할 수 있다. 따라서 models/quantize_cnn.py에서 ema_reset을 보면

Categories:

Updated:

Leave a comment