T2M-GPT 코드 분석 (4) VQ-VAE 모델 확인
VQ-VAE Model
코드를 통해 위 그림의 모델을 확인해보았다.
가장 먼저 train_vq.py
를 확인해 VQ-VAE를 어떻게 훈련시키는지 확인해 봤다.
net = vqvae.HumanVQVAE(args, ## use args to define different parameters in different quantizers
args.nb_code,
args.code_dim,
args.output_emb_width,
args.down_t,
args.stride_t,
args.width,
args.depth,
args.dilation_growth_rate,
args.vq_act,
args.vq_norm)
net이라는 네트워크를 HumanVQVAE를 통해 만드는데 args는 nb_code, code_dim, output_emb_width, down_t, stride_t, width, depth, dilation_growth_rate, vq_act, vq_norm을 통해 만들어진다. 대부분의 args는 training code를 통해 확인할 수 있는데
python3 train_vq.py \
--batch-size 256 \
--lr 2e-4 \
--total-iter 300000 \
--lr-scheduler 200000 \
--nb-code 512 \
--down-t 2 \
--depth 3 \
--dilation-growth-rate 3 \
--out-dir output \
--dataname t2m \
--vq-act relu \
--quantizer ema_reset \
--loss-vel 0.5 \
--recons-loss l1_smooth \
--exp-name VQVAE
nb_code는 512, down_t는 2, depth는 3, dilation_growth_rate는 3, vq_act는 relu로 구성되었다. 나머지 부분의 경우 options/option_vq.py
에 적혀있다.
code_dim은 512, output_emb_width은 512, stride_t는 2, width는 512, vq_norm는 None을 보여주고 있다.
가장 기본적으로 VQ-VAE network를 거칠 때 어떤 module들이 거치는지 확인하기 위해 forward를 확인해 봤다.
def forward(self, x):
# x => [batch, motion_frame (64), 263]
x_in = self.preprocess(x)
# x_in => [batch, 263, 64]
# Encode
x_encoder = self.encoder(x_in)
# x_encoder => [256,512,16]
## quantization
x_quantized, loss, perplexity = self.quantizer(x_encoder)
## decoder
x_decoder = self.decoder(x_quantized)
x_out = self.postprocess(x_decoder)
return x_out, loss, perplexity
기존적으로 preprocess를 거치는데
def preprocess(self, x):
# (bs, T, Jx3) -> (bs, Jx3, T)
x = x.permute(0,2,1).float()
return x
간단히 생각해서 dim을 바꿔준다고 생각하면 될 것 같다. 다시 말해서 conv1D를 거칠 때 channel을 joint로 바꾸고 feature를 sequence로 바꾼다고 생각하면 될 것 같다.
그 다음으로 encoder를 거치는데 encoder의 구조는 다음과 같다.
Sequential(
(0): Conv1d(263, 512, kernel_size=(3,), stride=(1,), padding=(1,))
(1): ReLU()
(2): Sequential(
(0): Conv1d(512, 512, kernel_size=(4,), stride=(2,), padding=(1,))
(1): Resnet1D(
(model): Sequential(
(0): ResConv1DBlock(
(norm1): Identity()
(norm2): Identity()
(activation1): ReLU()
(activation2): ReLU()
(conv1): Conv1d(512, 512, kernel_size=(3,), stride=(1,), padding=(9,), dilation=(9,))
(conv2): Conv1d(512, 512, kernel_size=(1,), stride=(1,)))
(1): ResConv1DBlock(
(norm1): Identity()
(norm2): Identity()
(activation1): ReLU()
(activation2): ReLU()
(conv1): Conv1d(512, 512, kernel_size=(3,), stride=(1,), padding=(3,), dilation=(3,))
(conv2): Conv1d(512, 512, kernel_size=(1,), stride=(1,)))
(2): ResConv1DBlock(
(norm1): Identity()
(norm2): Identity()
(activation1): ReLU()
(activation2): ReLU()
(conv1): Conv1d(512, 512, kernel_size=(3,), stride=(1,), padding=(1,))
(conv2): Conv1d(512, 512, kernel_size=(1,), stride=(1,))))))
(3): Sequential(
(0): Conv1d(512, 512, kernel_size=(4,), stride=(2,), padding=(1,))
(1): Resnet1D(
(model): Sequential(
(0): ResConv1DBlock(
(norm1): Identity()
(norm2): Identity()
(activation1): ReLU()
(activation2): ReLU()
(conv1): Conv1d(512, 512, kernel_size=(3,), stride=(1,), padding=(9,), dilation=(9,))
(conv2): Conv1d(512, 512, kernel_size=(1,), stride=(1,)))
(1): ResConv1DBlock(
(norm1): Identity()
(norm2): Identity()
(activation1): ReLU()
(activation2): ReLU()
(conv1): Conv1d(512, 512, kernel_size=(3,), stride=(1,), padding=(3,), dilation=(3,))
(conv2): Conv1d(512, 512, kernel_size=(1,), stride=(1,)))
(2): ResConv1DBlock(
(norm1): Identity()
(norm2): Identity()
(activation1): ReLU()
(activation2): ReLU()
(conv1): Conv1d(512, 512, kernel_size=(3,), stride=(1,), padding=(1,))
(conv2): Conv1d(512, 512, kernel_size=(1,), stride=(1,))))))
(4): Conv1d(512, 512, kernel_size=(3,), stride=(1,), padding=(1,)))
kernel size가 3이고 stride 1, padding 1일 때 같은 feature가 나오게 되고 kernel_size 1, stride 1일때도 같은 feature map이 나오게 된다.
하지만 self.encoder(x_in)
를 거치면 [256, 263, 64] => [256, 512, 16]이 나오게 되는데 Conv1d(512, 512, kernel_size=(4,), stride=(2,), padding=(1,))
에서 feature size가 1/2로 줄어들기 때문이다.
이렇게 줄어든 feature map은 self.quantizer
를 들어가게 되는데
if args.quantizer == "ema_reset":
self.quantizer = QuantizeEMAReset(nb_code, code_dim, args)
elif args.quantizer == "orig":
self.quantizer = Quantizer(nb_code, code_dim, 1.0)
elif args.quantizer == "ema":
self.quantizer = QuantizeEMA(nb_code, code_dim, args)
elif args.quantizer == "reset":
self.quantizer = QuantizeReset(nb_code, code_dim, args)
총 4개의 quantizer가 있고 --quantizer ema_reset \
를 통해 quantizer는 ema_reset으로 예측할 수 있다. 따라서 models/quantize_cnn.py
에서 ema_reset을 보면
Leave a comment