VQ-VAE implementation
Merge Request Description
Summary:
Implementation of the vector quantized VAE (VQ-VAE) architecture (https://arxiv.org/abs/1711.00937) as a third model option.
This architecture features a discrete latent space, which may help in interpretability. There is a fixed number of embedding vectors (num_embeddings argument) whose dimensionality can be chosen with the embedding_dim argument. The encoder produces a number (equal to latent_dim) of vectors of dimension embedding_dim, and those are then mapped to the closest embedding vectors. This is essentially a quantization step. There is then a quantization loss that punishes how far the predicted vectors deviate from the embedding vectors, and this loss has an additional hyperparameter, commit_weight. The loss is extensively explained in https://huggingface.co/blog/ariG23498/understand-vq.
I tested performance mainly on the parametric vase dataset from https://gitlab.renkulab.io/luis.salamanca/aixd-tutorial. In reconstruction error, it is mostly better than VAE, although one has to play a bit with the 4 hyperparameters latent_dim, num_embeddings, embedding_dim, and commit_weight.
I tried to reuse as much code as possible. The CondVQVAEModel inherits from the CondAEModel, and it was only necessary to override a few inherited functions.
I expanded the mlmodel tests to cover all three models now.
Lastly, I refactored the _step function of the CondAEModel to remove the "step_ae" argument. This leads to better inheritance, as the CondVAEModel and the CondVQVAEModel no longer have to override the _step function.
Changes:
- added src/aixd/mlmodel/architecture/cond_vqvae_model.py containing the CondVQVAEModel class
- added VQEncoder class to src/aixd/mlmodel/architecture/encoders.py
- added VQLoss class to src/aixd/mlmodel/architecture/losses.py
- adapted tests/test_ml_model.py to test all three model classes CondAEModel, CondVAEModel, CondVQVAEModel
- removed the step_ae argument from the _step function in src/aixd/mlmodel/architecture/cond_ae_model.py and refactored CondVAEModel and CondVQVAEModel for improved inheritance
Checklist
-
I ran all tests on my computer and it's all green (i.e. invoke test
). -
I ran lint on my computer and there are no errors (i.e. invoke lint
). -
I have added tests that prove my fix is effective or that my feature works. -
I have added necessary documentation (if appropriate)