- Tutorial
- Visualizing the target distribution
- Setting up the normalizing flow
- Training the normalizing flow
- Sampling from the trained normalizing flow
- Acknowledgements
- Author
Normalizing flows can approximate complex probability distributions by applying a sequence of invertible transformations to a simpler base distribution, such as a multivariate Gaussian. If the base distribution has a joint probability density function (PDF) $p_Z(z)$, the sequence of invertible transformations $f_1, f_2, \ldots, f_K$ is applied to the original variable $z$ such that

$$x = f(z) = f_K \circ f_{K-1} \circ \cdots \circ f_1(z),$$

where $x$ is the transformed variable and $f$ is the composition of all the individual transformations $f_i$. This flexible framework allows for the learning of complex probability distributions that might not be obtainable by other methods.
Additionally, one of the benefits of normalizing flows is efficient sampling from complex distributions: samples are easily drawn from the base distribution, such as a Gaussian, and passing them through the normalizing flow transformation yields samples from the learned distribution.
The transformed joint PDF can be easily calculated from the trained normalizing flow. The normalization condition for a PDF dictates that, for both $z$ and $x$,

$$\int p_Z(z)\, dz = \int p_X(x)\, dx = 1. \tag{1}$$

Because $x = f(z)$, the second integral can be expressed in terms of $z$ as

$$\int p_X(x)\, dx = \int p_X\big(f(z)\big) \left|\det \frac{\partial f}{\partial z}\right| dz = 1, \tag{2}$$

where $\frac{\partial f}{\partial z}$ is the matrix of first derivatives whose determinant is the Jacobian of the transformation function $f$. With the correct assumptions about the integration domain, comparison of Equations 1 and 2 yields

$$p_X(x) = p_Z\big(f^{-1}(x)\big) \left|\det \frac{\partial f}{\partial z}\right|^{-1}, \tag{3}$$
which can be easily calculated from the trained normalizing flow using automatic differentiation. See these resources [1] [2] [3] for more details about the transformation of random variables.
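As a concrete illustration of Equation 3, the short sketch below applies the change-of-variables formula to a toy invertible map using PyTorch's automatic differentiation. The map `f` and the point `z` here are made up purely for illustration; a trained normalizing flow performs this bookkeeping internally when you evaluate its log probability.

import torch
from torch.distributions import MultivariateNormal

# Toy invertible map f: R^2 -> R^2, standing in for a trained flow (illustrative only)
def f(z):
    return torch.stack([z[0] + z[1]**3, 2.0 * z[1]])

base = MultivariateNormal(torch.zeros(2), torch.eye(2))  # base density p_Z

z = torch.tensor([0.5, -1.0])
x = f(z)  # transformed variable

# Matrix of first derivatives of f at z, computed by automatic differentiation
J = torch.autograd.functional.jacobian(f, z)
_, log_abs_det = torch.linalg.slogdet(J)

# Equation 3 in log form: log p_X(x) = log p_Z(z) - log|det(df/dz)|
log_p_x = base.log_prob(z) - log_abs_det
print(log_p_x.item())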
Tutorial
This tutorial will demonstrate the use of the normflows Python package (see the normalizing-flows GitHub repository) for a simple example using a normalizing flow to approximate an unknown target distribution. An IPython notebook containing the full code in this tutorial can be found here. For this example, the 2D target distribution will consist of a random variable $a$ sampled from a uniform distribution, and another random variable $b$ sampled from a Normal distribution with mean $a$ and variance $a$:

$$a \sim \mathrm{Uniform}(1, 2), \tag{4}$$

$$b \sim \mathcal{N}(\mu = a, \sigma^2 = a). \tag{5}$$

Stated differently, we want the normalizing flow approximation $q(a, b) \approx p(a, b)$. It is assumed that samples from the target distribution can be easily obtained.

This tutorial will begin by visualizing the target distribution $p(a, b)$, followed by construction and training of the normalizing flow on samples from $p(a, b)$.
import normflows as nf
import torch
import numpy as np
import matplotlib.pyplot as plt
Visualizing the target distribution
To help us get an idea of what this target distribution looks like, we will visualize it in a couple of different ways. First, we’ll draw $N$ samples of $(a, b)$ pairs from the distributions defined in Equations 4 and 5 and visualize the resulting 2D scatter plot.
N = 10000
sampled_a_vals = np.random.uniform(1, 2, N) # Uniform distribution
sampled_b_vals = np.random.normal(sampled_a_vals, np.sqrt(sampled_a_vals)) # Normal distribution with mean a and variance a
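The scatter plot itself takes only a few lines of matplotlib; the styling below is illustrative and may differ from the figure in the original notebook.

# Scatter plot of the sampled (a, b) pairs
plt.figure(figsize=(6, 6))
plt.scatter(sampled_a_vals, sampled_b_vals, s=2, alpha=0.3)
plt.xlabel('a')
plt.ylabel('b')
plt.title('Samples from the target distribution p(a, b)')
plt.show()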
We see that, as expected, both the mean and the standard deviation of the $b$ distribution increase with increasing $a$. Note that the sampled points appear to have the highest density for low $a$. We can see the same thing by calculating the expected probability density analytically, using the definition of conditional probability, $p(a, b) = p(a)\, p(b \mid a)$.
def p_a(a):
    """Probability density function of a ~ Uniform[1, 2]. Trivial function in this case, but included for consistency."""
    return np.where((a >= 1) & (a <= 2), 1.0, 0.0)

def p_b_given_a(b, a):
    """Probability density function of b ~ Normal(a, a), p(b|a).

    Args:
        b (np.ndarray): Values of b.
        a (np.ndarray): Values of a, the mean and variance of b.
    """
    return (1 / np.sqrt(2 * np.pi * a)) * np.exp(-0.5 * (b - a)**2 / a)

def p_a_b(ab):
    """Joint probability density p(a, b).

    Args:
        ab (np.ndarray): 2D array where each row is a pair (a, b).
    """
    a = ab[:, 0]
    b = ab[:, 1]
    return p_a(a) * p_b_given_a(b, a)
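With these functions, the analytic joint density can be evaluated on a grid and shown as a heatmap for comparison with the scatter plot. The grid ranges below are illustrative choices, not taken from the original notebook.

# Evaluate the analytic joint density p(a, b) on a grid and plot it
a_grid = np.linspace(0.5, 2.5, 200)
b_grid = np.linspace(-2, 8, 200)
A, B = np.meshgrid(a_grid, b_grid)
density = p_a_b(np.column_stack([A.ravel(), B.ravel()])).reshape(A.shape)

plt.figure(figsize=(6, 6))
plt.pcolormesh(A, B, density, shading='auto')
plt.xlabel('a')
plt.ylabel('b')
plt.title('Analytic joint density p(a, b)')
plt.colorbar(label='p(a, b)')
plt.show()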
In addition to visualizing the full 2D distribution, we can look at just the $a$ and $b$ target distributions by marginalizing out the other variable. Because $a$ does not depend on $b$, the marginalized distribution for $a$ is just the uniform distribution $p(a)$. The marginal distribution of $b$, however, is more challenging to obtain because it depends on $a$ in a non-trivial way. If you think about it, as $a$ increases from 1 to 2, the distribution for $b$ will drift and spread out, meaning that the total marginal distribution of $b$ is a combination of all these different Normal distributions for different values of $a$. The marginal distribution can be calculated by

$$p(b) = \int p(a, b)\, da = \int_{1}^{2} p(a)\, p(b \mid a)\, da.$$
This integral is challenging to solve analytically, but can be calculated straightforwardly via numerical integration. Below are shown the marginal distributions for $a$ and $b$.
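One way to carry out that numerical integration is sketched here with a simple trapezoidal rule over a grid of $a$ values; the grid and plotting details are illustrative.

# Marginal p(b) by numerically integrating p(a) * p(b|a) over a in [1, 2]
a_grid = np.linspace(1, 2, 1000)   # support of a
b_vals = np.linspace(-2, 8, 400)   # range of b values to evaluate

p_b = np.array([np.trapz(p_a(a_grid) * p_b_given_a(b, a_grid), a_grid) for b in b_vals])

plt.plot(b_vals, p_b)
plt.xlabel('b')
plt.ylabel('p(b)')
plt.title('Marginal distribution of b')
plt.show()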
Setting up the normalizing flow
Now that we understand our target distributions, we can begin to set up our normalizing flow to approximate $p(a, b)$.
Here, we use the normflows Python package, which utilizes PyTorch. Setting up a normalizing flow includes deciding on the number and type of flow layers and the base distribution. The base distribution is the initial choice of probability density, which the normalizing flow will transform over the course of training. These choices are problem-specific.
# Define NF architecture
torch.manual_seed(0)

K = 16  # number of repeated blocks
latent_size = 2  # number of input channels

flows = []
for i in range(K):
    param_map = nf.nets.MLP([1, 64, 64, 2], init_zeros=True)
    flows += [nf.flows.AffineCouplingBlock(param_map)]
    flows += [nf.flows.LULinearPermute(latent_size)]

base = nf.distributions.DiagGaussian(2, trainable=False)  # Base distribution

model = nf.NormalizingFlow(q0=base, flows=flows)
Training the normalizing flow
The next step is to train the normalizing flow on samples of $(a, b)$ from the true distribution. Here the loss function is chosen to be the negative log likelihood. Using the Adam optimizer, the normalizing flow is trained for 1000 epochs, with 512 samples of the true distribution per epoch.
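The training loop below calls two helper functions that are defined in the accompanying notebook but not shown in this snippet: `sample_ab`, which draws a batch of $(a, b)$ samples from the target distribution as a PyTorch tensor on the chosen device, and `save_info`, which records the loss (the `loss_hist` list used in the plot further down). Their exact implementations may differ; a minimal sketch under those assumptions:

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

loss_hist = []  # filled in by save_info, used later for the loss plot

def sample_ab(num_samples, device):
    """Draw (a, b) samples from the target distribution as a (num_samples, 2) tensor."""
    a = np.random.uniform(1, 2, num_samples)
    b = np.random.normal(a, np.sqrt(a))
    return torch.tensor(np.column_stack([a, b]), dtype=torch.float32, device=device)

def save_info(it, model, loss):
    """Record the training loss (the notebook version may also save intermediate models)."""
    loss_hist.append(loss.detach().cpu().item())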
# Train NF
epochs = 1000
num_samples = 2 ** 9  # 512 samples per iteration
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)

for it in range(epochs):
    optimizer.zero_grad()
    x = sample_ab(num_samples, device)  # Get training samples
    loss = -1 * model.log_prob(x).mean()  # Compute loss (negative log likelihood)
    # Do backprop and optimizer step
    if ~(torch.isnan(loss) | torch.isinf(loss)):  # Check for NaNs or infs
        loss.backward()
        optimizer.step()
    save_info(it, model, loss)
Following training, we can see that in a small number of iterations the distribution learned by the normalizing flow was able to effectively capture the primary features of the target distribution. The error plot on the right shows where the true and learned distributions disagree.
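To make this kind of comparison, the learned density can be evaluated on a grid with `model.log_prob` and plotted next to the analytic joint density from earlier; the grid ranges below are an illustrative choice.

# Evaluate the learned density on a grid for comparison with the target
a_grid = np.linspace(0.5, 2.5, 200)
b_grid = np.linspace(-2, 8, 200)
A, B = np.meshgrid(a_grid, b_grid)
grid = torch.tensor(np.column_stack([A.ravel(), B.ravel()]),
                    dtype=torch.float32, device=device)

model.eval()
with torch.no_grad():
    log_q = model.log_prob(grid)
q = torch.exp(log_q).cpu().numpy().reshape(A.shape)

plt.figure(figsize=(6, 6))
plt.pcolormesh(A, B, q, shading='auto')
plt.xlabel('a')
plt.ylabel('b')
plt.title('Learned density q(a, b)')
plt.colorbar(label='q(a, b)')
plt.show()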
To get an idea of how quickly the normalizing flow converged to this approximation, we can look at the training loss over epochs. We see that the loss for the normalizing flow solution has plateaued by the end of training. In fact, even if we train for 10x as long the learned distribution does not improve significantly.
# Plot loss
plt.figure(figsize=(6, 6))
plt.plot(loss_hist, label='loss')
plt.title('Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.yscale('log')
plt.show()
By looking at probability densities at intermediate epochs, we can see that the approximate normalizing flow solution achieves reasonable accuracy early in training.
From the loss curve, we see that the largest changes happen at the very beginning of training. In the subplots below, by looking at some of the first few epochs we can see the probability density migrating from the original Gaussian centered at $(0, 0)$ towards the region with the target density.
As another perspective on the normalizing flow solution, we can look at the marginalized probability densities for $a$ and $b$ and compare the true and approximate solutions, in red and black respectively.
Sampling from the trained normalizing flow
Once the normalizing flow has been trained, sampling from it is trivial.
model.eval()  # Set model to evaluation mode
samples, log_prob = model.sample(num_samples=1000)  # Sample from the trained normalizing flow
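The returned samples are PyTorch tensors; moving them to NumPy lets us overlay them on the target samples drawn earlier for a quick visual check (plotting details are illustrative).

samples_np = samples.detach().cpu().numpy()

plt.figure(figsize=(6, 6))
plt.scatter(sampled_a_vals, sampled_b_vals, s=2, alpha=0.2, label='target samples')
plt.scatter(samples_np[:, 0], samples_np[:, 1], s=2, alpha=0.4, label='flow samples')
plt.xlabel('a')
plt.ylabel('b')
plt.legend()
plt.show()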
Acknowledgements
Some of the code in this tutorial was adapted from example scripts in the normalizing-flows repository: https://github.com/VincentStimper/normalizing-flows
Stimper et al., (2023). normflows: A PyTorch Package for Normalizing Flows. Journal of Open Source Software, 8(86), 5361, https://doi.org/10.21105/joss.05361
Author
Jay Spendlove
PhD student, Arizona State University