Conventional variational autoencoders fail to model correlations between data points due to their use of factorized priors. Amortized Gaussian process inference through GP-VAEs has led to significant improvements in this regard, but is still inhibited by the intrinsic complexity of exact GP inference. We improve the scalability of these methods through principled sparse inference approaches. We propose a new scalable GP-VAE model that outperforms existing approaches in terms of runtime and memory footprint, is easy to implement, and allows for joint end-to-end optimization of all components.
Introduction

Variational autoencoders (VAEs) are among the most widely used models in representation learning and generative modeling (Kingma & Welling, 2013, 2019; Rezende et al., 2014). As VAEs typically make use of factorized priors, they fall short when modeling correlations between different data points. However, more expressive priors that capture such correlations enable useful applications. Casale et al. (2018), for instance, showed that by modeling prior correlations between the data, one can generate a rotated image of a digit based on rotations of the same digit at different angles.

Gaussian process VAEs (GP-VAEs) have been designed to overcome this shortcoming (Casale et al., 2018). These models introduce a Gaussian process (GP) prior over the latent variables that correlates pairs of latent variables through a kernel function. While GP-VAEs have outperformed standard VAEs on many tasks (Casale et al., 2018; Pearce, 2020), combining GPs and VAEs brings along fundamental computational challenges. On the one hand, neural networks reveal their full power in conjunction with large datasets, making mini-batching a practical necessity. GPs, on the other hand, are traditionally restricted to medium-scale datasets due to their unfavorable scaling. In GP-VAEs, these contradictory demands must be reconciled, preferably by reducing the O(N^3) complexity of exact GP inference, where N is the number of data points (the first sketch below makes this bottleneck concrete).

Despite recent attempts to improve the scalability of GP-VAE models by using specifically designed kernels and inference methods (Casale et al., 2018; Fortuin et al., 2020), a generic way to scale these models, regardless of data type or kernel choice, has remained elusive. This limits current GP-VAE implementations to small-scale datasets. In this work, we introduce the first generically scalable method for training GP-VAEs based on inducing points. We thereby improve the computational complexity from O(N^3) to O(bm^2 + m^3), where m is the number of inducing points and b is the batch size.
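To make the computational bottleneck concrete, the following is a minimal NumPy sketch (not the implementation used in this work) of a GP prior over a single latent channel. The kernel choice (RBF), the input range, and all variable names are illustrative assumptions; the point is that exact inference must factorize the dense N x N kernel matrix, which is the source of the O(N^3) cost.

```python
import numpy as np

def rbf_kernel(x, y, lengthscale=1.0, variance=1.0):
    """RBF kernel k(x, y) = variance * exp(-(x - y)^2 / (2 * lengthscale^2))."""
    sq_dists = (x[:, None] - y[None, :]) ** 2
    return variance * np.exp(-0.5 * sq_dists / lengthscale ** 2)

N = 500
x = np.linspace(0.0, 10.0, N)   # auxiliary inputs, e.g. time stamps or angles
K = rbf_kernel(x, x)            # dense N x N prior covariance over the latents

# Exact GP inference must factorize this N x N matrix -- the O(N^3) step.
L = np.linalg.cholesky(K + 1e-6 * np.eye(N))  # jitter for numerical stability

# One latent channel drawn from the correlated prior z ~ N(0, K):
z = L @ np.random.randn(N)
```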
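Likewise, the following hedged sketch illustrates the generic inducing-point idea behind the stated O(bm^2 + m^3) complexity, not this paper's exact inference scheme: per mini-batch of size b, only an m x m kernel matrix is factorized, and a Nyström approximation stands in for the full kernel. The inducing locations and names are again illustrative assumptions.

```python
import numpy as np
from scipy.linalg import solve_triangular

def rbf_kernel(x, y, lengthscale=1.0, variance=1.0):
    sq_dists = (x[:, None] - y[None, :]) ** 2
    return variance * np.exp(-0.5 * sq_dists / lengthscale ** 2)

m, b = 20, 64
x_batch = np.random.uniform(0.0, 10.0, size=b)  # inputs of one mini-batch
x_ind = np.linspace(0.0, 10.0, m)               # m inducing locations

K_mm = rbf_kernel(x_ind, x_ind) + 1e-6 * np.eye(m)  # m x m (jittered)
K_bm = rbf_kernel(x_batch, x_ind)                   # b x m cross-covariance

L = np.linalg.cholesky(K_mm)                 # O(m^3), independent of N
A = solve_triangular(L, K_bm.T, lower=True)  # O(b m^2) triangular solve
K_bb_approx = A.T @ A                        # Nystrom: K_bm K_mm^{-1} K_mb
```

No quantity of size N is ever formed, so the per-batch cost stays at O(bm^2 + m^3) regardless of the dataset size.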