Skip to content
Snippets Groups Projects

VQ-VAE implementation

Open Richard Danis requested to merge rdanis/aixd:vqvae into master

Merge Request Description

Summary:

Implementation of the vector quantized VAE (VQ-VAE) architecture (https://arxiv.org/abs/1711.00937) as a third model option.

This architecture features a discrete latent space, which may help in interpretability. There is a fixed number of embedding vectors (num_embeddings argument) whose dimensionality can be chosen with the embedding_dim argument. The encoder produces a number (equal to latent_dim) of vectors of dimension embedding_dim, and those are then mapped to the closest embedding vectors. This is essentially a quantization step. There is then a quantization loss that punishes how far the predicted vectors deviate from the embedding vectors, and this loss has an additional hyperparameter, commit_weight. The loss is extensively explained in https://huggingface.co/blog/ariG23498/understand-vq.

I tested performance mainly on the parametric vase dataset from https://gitlab.renkulab.io/luis.salamanca/aixd-tutorial. In reconstruction error, it is mostly better than VAE, although one has to play a bit with the 4 hyperparameters latent_dim, num_embeddings, embedding_dim, and commit_weight.

I tried to reuse as much code as possible. The CondVQVAEModel inherits from the CondAEModel, and it was only necessary to override a few inherited functions.

I expanded the mlmodel tests to cover all three models now.

Lastly, I refactored the _step function of the CondAEModel to remove the "step_ae" argument. This leads to better inheritance, as the CondVAEModel and the CondVQVAEModel no longer have to override the _step function.

Changes:

  • added src/aixd/mlmodel/architecture/cond_vqvae_model.py containing the CondVQVAEModel class
  • added VQEncoder class to src/aixd/mlmodel/architecture/encoders.py
  • added VQLoss class to src/aixd/mlmodel/architecture/losses.py
  • adapted tests/test_ml_model.py to test all three model classes CondAEModel, CondVAEModel, CondVQVAEModel
  • removed the step_ae argument from the _step function in src/aixd/mlmodel/architecture/cond_ae_model.py and refactored CondVAEModel and CondVQVAEModel for improved inheritance

Checklist

  • I ran all tests on my computer and it's all green (i.e. invoke test).
  • I ran lint on my computer and there are no errors (i.e. invoke lint).
  • I have added tests that prove my fix is effective or that my feature works.
  • I have added necessary documentation (if appropriate)
Edited by Richard Danis

Merge request reports

Loading
Loading

Activity

Filter activity
  • Approvals
  • Assignees & reviewers
  • Comments (from bots)
  • Comments (from users)
  • Commits & branches
  • Edits
  • Labels
  • Lock status
  • Mentions
  • Merge request status
  • Tracking
  • Loading
  • Loading
Please register or sign in to reply
Loading