Normalizing flows for probability distribution reconstruction

Tutorial

Normalizing flows can approximate complex probability distributions by applying a sequence of invertible transformations to a simpler base distribution, such as a multivariate Gaussian. If the base distribution has a joint probability density function (PDF) $p(\mathbf{x})$, the sequence of invertible transformations $f_1, f_2, \ldots, f_k$ is applied to the original variable $\mathbf{x}$ such that

$$\mathbf{z} = f_k\left(f_{k-1}\left(\ldots f_1\left(\mathbf{x}\right)\right)\right) = g\left(\mathbf{x}\right),$$

where $\mathbf{z}$ is the transformed variable and $g = f_k \circ f_{k-1} \circ \cdots \circ f_1$ is the composition of all the transformations $f_i$. Because each $f_i$ is invertible, so is $g$. This flexible framework makes it possible to learn complex probability distributions that are difficult to capture with simple parametric families.

One benefit of normalizing flows is efficient sampling from complex distributions. Samples are cheap to draw from the base distribution (for example, a Gaussian), and passing them through the normalizing flow transformation $g(\mathbf{x})$ yields samples from the learned distribution.
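As a minimal sketch of this sampling mechanism, the code below uses a hypothetical affine map $g(\mathbf{x}) = A\mathbf{x} + \mathbf{s}$ in place of a trained flow (the matrix `A` and vector `shift` are arbitrary illustrative values, not part of the tutorial). Base samples from a standard Gaussian are pushed through $g$; because this particular map is linear, the resulting samples should have mean `shift` and covariance $AA^\top$.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical invertible map standing in for a trained flow: z = g(x) = A x + shift
A = np.array([[2.0, 0.0],
              [0.5, 1.5]])
shift = np.array([1.0, -1.0])

def g(x):
    # x has shape (n, 2); apply the map to every row
    return x @ A.T + shift

x = rng.standard_normal((100_000, 2))   # cheap samples from the base distribution N(0, I)
z = g(x)                                # samples from the transformed distribution

# For this linear map the transformed distribution is Gaussian with mean `shift`
# and covariance A A^T, so the sample statistics should approach those values.
print(z.mean(axis=0))
print(np.cov(z, rowvar=False))
```

A real flow replaces `g` with a learned nonlinear map, but the sampling procedure is the same: sample the base, then transform.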

The transformed joint PDF $p(\mathbf{z})$ can be easily calculated from the trained normalizing flow. The normalization condition for a PDF dictates that, for both $p(\mathbf{x})$ and $p(\mathbf{z})$,

$$\int \cdots \int dx_1 \cdots dx_N \; p\left(\mathbf{x}\right) = 1, \tag{1}$$

$$\int \cdots \int dz_1 \cdots dz_N \; p\left(\mathbf{z}\right) = 1.$$

Because $\mathbf{z} = g(\mathbf{x})$, the second integral can be expressed in terms of $\mathbf{x}$ by a change of variables:

$$\begin{aligned} 1 &= \int \cdots \int dz_1 \cdots dz_N \; p\left(\mathbf{z}\right) \\ &= \int \cdots \int dx_1 \cdots dx_N \left| \det \frac{d\mathbf{z}}{d\mathbf{x}} \right| p\left(g\left(\mathbf{x}\right)\right), \end{aligned} \tag{2}$$

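where $\frac{d\mathbf{z}}{d\mathbf{x}}$ is the Jacobian matrix of the transformation $g$ (the matrix of first partial derivatives $\partial z_i / \partial x_j$), and its determinant is the Jacobian determinant. Because the same change of variables holds for any region of integration, not only the whole space, comparing Equations 1 and 2 shows that the integrands must agree: $p(\mathbf{x}) = p(g(\mathbf{x})) \left|\det \frac{d\mathbf{z}}{d\mathbf{x}}\right|$. Using $\det \frac{d\mathbf{x}}{d\mathbf{z}} = \left(\det \frac{d\mathbf{z}}{d\mathbf{x}}\right)^{-1}$, this yields

$$p\left(\mathbf{z}\right) = p\left(\mathbf{x}\right) \left| \det \frac{d\mathbf{x}}{d\mathbf{z}} \right|,$$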

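which can be evaluated for the trained normalizing flow. Automatic differentiation can compute this Jacobian in general, but flow layers are usually designed so that the log-determinant is cheap to compute directly (coupling layers, for example, have triangular Jacobians, so the determinant is the product of the diagonal entries). See these resources [1] [2] [3] for more details about the transformation of random variables.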
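The change-of-variables formula can be checked numerically. The sketch below uses a hypothetical invertible affine map $\mathbf{z} = A\mathbf{x} + \mathbf{s}$ (illustrative values, not a flow from this tutorial), computes $\left|\det \frac{d\mathbf{x}}{d\mathbf{z}}\right|$ with PyTorch's `torch.autograd.functional.jacobian`, and compares the resulting log density with PyTorch's closed-form Gaussian log density for the same distribution.

```python
import math
import torch

# Hypothetical invertible map for illustration (not a normflows layer): z = g(x) = A x + shift
A = torch.tensor([[2.0, 0.0],
                  [0.5, 1.5]])        # lower triangular, det(A) = 3
shift = torch.tensor([1.0, -1.0])

def g_inv(z):
    # x = g^{-1}(z) = A^{-1} (z - shift)
    return torch.linalg.solve(A, z - shift)

def log_p_x(x):
    # Base density: standard 2D Gaussian N(0, I)
    return -0.5 * (x ** 2).sum() - math.log(2 * math.pi)

def log_p_z(z):
    # Change of variables: log p(z) = log p(x) + log|det dx/dz|, with x = g^{-1}(z)
    J = torch.autograd.functional.jacobian(g_inv, z)  # dx/dz via automatic differentiation
    return log_p_x(g_inv(z)) + torch.linalg.slogdet(J).logabsdet

z0 = torch.tensor([0.3, 0.7])
# Reference: an affine map of N(0, I) is Gaussian with mean `shift` and covariance A A^T
reference = torch.distributions.MultivariateNormal(shift, covariance_matrix=A @ A.T)
print(log_p_z(z0).item(), reference.log_prob(z0).item())  # the two values should agree
```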

Tutorial

This tutorial demonstrates the normflows Python package (see the normalizing-flows GitHub repository) on a simple example: using a normalizing flow to approximate an unknown target distribution. An IPython notebook containing the full code in this tutorial can be found here. For this example, the 2D target distribution $p(a,b)$ consists of a random variable $a$ drawn from a uniform distribution and a second random variable $b$ drawn from a Normal distribution whose mean and variance are both equal to $a$:

$$a \sim \text{Uniform}\left[1,2\right],$$

$$b \sim \text{Normal}\left(\mu=a, \sigma^2=a\right).$$

In other words, we want to learn a normalizing flow approximation $NF(a,b) \approx p(a,b)$. It is assumed that samples from the target distribution $p(a,b)$ can be easily obtained.

This tutorial will begin by visualizing the target distribution $p(a,b)$, followed by construction and training of the normalizing flow on samples from $p(a,b)$.

import normflows as nf
import torch
import numpy as np
import matplotlib.pyplot as plt

Visualizing the target distribution

To get an idea of what the target distribution $p(a,b)$ looks like, we will visualize it in a couple of different ways. First, we draw 10,000 pairs $(a,b)$ from the distributions defined above and visualize the resulting 2D scatter plot.

N = 10000
sampled_a_vals = np.random.uniform(1, 2, N) # Uniform distribution
sampled_b_vals = np.random.normal(sampled_a_vals, np.sqrt(sampled_a_vals)) # Normal distribution with mean a and variance a
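The plotting code for this figure is not shown in the post. A minimal sketch of such a scatter plot with matplotlib follows; the marker size, transparency, and the use of NumPy's `default_rng` generator (instead of the legacy `np.random` functions above) are our own choices. Note that `normal` takes the standard deviation as its scale, hence `np.sqrt(a)` for a variance of $a$.

```python
import numpy as np
import matplotlib.pyplot as plt

rng = np.random.default_rng(0)
N = 10000
sampled_a_vals = rng.uniform(1, 2, N)                                 # a ~ Uniform[1, 2]
sampled_b_vals = rng.normal(sampled_a_vals, np.sqrt(sampled_a_vals))  # b | a ~ Normal(a, a); scale = sqrt(a)

fig, ax = plt.subplots(figsize=(6, 6))
ax.scatter(sampled_a_vals, sampled_b_vals, s=2, alpha=0.3)  # small, translucent markers show density
ax.set_xlabel("a")
ax.set_ylabel("b")
ax.set_title("Samples from p(a, b)")
plt.show()
```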

As expected, both the mean and the standard deviation of $b$ increase with increasing $a$. Note that the sampled points are most densely packed at low $a$: every value of $a$ in $[1,2]$ is equally likely, but for small $a$ the corresponding values of $b$ are spread over a narrower range. We can see the same thing by calculating the joint distribution $p(a,b)$ analytically, using the product rule $p(a,b) = p(a)\,p(b \mid a)$.

def p_a(a):
    """Probability density function of a ~ Uniform[1,2]. Trivial function in this case, but included for consistency."""
    return np.where((a >= 1) & (a <= 2), 1, 0)

def p_b_given_a(b, a):
    """Probability density function of b ~ Normal(a, a), p(b|a). 
    
    Args:
        b (np.ndarray): Values of b.
        a (np.ndarray): Values of a, mean and variance of b.
    """
    return (1 / (np.sqrt(2 * np.pi * a))) * np.exp(-0.5 * ((b - a)**2 / a))

def p_a_b(ab):
    """Joint probability density p(a, b).
    
    Args:
        ab (np.ndarray): 2D array where each row is a tuple (a, b).
    """
    a = ab[:, 0]
    b = ab[:, 1]
    return p_a(a) * p_b_given_a(b, a)
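One way to visualize the analytic density is to evaluate it on a regular grid. The sketch below restates the joint density in vectorized form, with a guard so that values of $a$ outside $[1,2]$ return 0 instead of a NaN (the original `p_b_given_a` would take the square root of a negative number for $a < 0$), and checks with a Riemann sum that it integrates to approximately 1. The grid limits are arbitrary choices that cover the support of $a$ and the bulk of $b$.

```python
import numpy as np

def p_a_b(ab):
    """Joint density p(a, b) = p(a) p(b | a) for rows ab[i] = (a, b)."""
    a, b = ab[:, 0], ab[:, 1]
    in_support = (a >= 1) & (a <= 2)          # p(a) = 1 on [1, 2], else 0
    a_safe = np.where(in_support, a, 1.0)     # dummy value outside the support avoids sqrt/division issues
    p_b_given_a = np.exp(-0.5 * (b - a_safe) ** 2 / a_safe) / np.sqrt(2 * np.pi * a_safe)
    return np.where(in_support, p_b_given_a, 0.0)

# Regular grid over (a, b); limits chosen to cover [1, 2] in a and the bulk of b
a_grid = np.linspace(0.5, 2.5, 401)
b_grid = np.linspace(-6.0, 10.0, 801)
A, B = np.meshgrid(a_grid, b_grid, indexing="ij")
P = p_a_b(np.stack([A.ravel(), B.ravel()], axis=1)).reshape(A.shape)

# Riemann-sum check of the normalization condition (should be close to 1)
da = a_grid[1] - a_grid[0]
db = b_grid[1] - b_grid[0]
total = P.sum() * da * db
print(total)

# P can then be displayed with, e.g., matplotlib's plt.pcolormesh(A, B, P)
```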

In addition to visualizing the full 2D distribution, we can look at the target distributions of $a$ and $b$ individually by marginalizing out the other variable. Integrating $p(a,b) = p(a)\,p(b \mid a)$ over $b$ simply returns $p(a)$, because $p(b \mid a)$ integrates to 1 for every $a$; the marginal distribution of $a$ is therefore just $\text{Uniform}[1,2]$, i.e. $p(a) = 1$ on $[1,2]$. The marginal distribution of $b$, however, is more challenging to obtain because $b$ depends on $a$ in a non-trivial way. As $a$ increases from 1 to 2, the conditional distribution of $b$ drifts and spreads out, so the marginal distribution of $b$ is a mixture of all these Normal distributions, weighted by $p(a)$. By the law of total probability, the marginal distribution $p(b)$ is

$$p(b) = \int_1^2 da \; p(a)\, p(b \mid a) = \int_1^2 da \; \frac{1}{\sqrt{2 \pi a}} \exp\left(-\frac{\left(b-a\right)^2}{2a}\right).$$

This integral is challenging to solve analytically, but it can be evaluated straightforwardly by numerical integration. The marginal distributions for $a$ and $b$ are shown below.
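A minimal sketch of that numerical integration, assuming SciPy is available: `scipy.integrate.quad` integrates the conditional density over $a \in [1,2]$ for each value of $b$ on a grid (the grid limits are our choice). As sanity checks, the resulting marginal should integrate to about 1 and have mean $\mathbb{E}[b] = \mathbb{E}[a] = 1.5$.

```python
import numpy as np
from scipy.integrate import quad

def p_b_given_a(b, a):
    # Normal density with mean a and variance a, as defined earlier in the tutorial
    return np.exp(-0.5 * (b - a) ** 2 / a) / np.sqrt(2 * np.pi * a)

def p_b(b):
    # Law of total probability with p(a) = 1 on [1, 2]: integrate p(b | a) over a in [1, 2]
    value, _abs_err = quad(lambda a: p_b_given_a(b, a), 1.0, 2.0)
    return value

b_grid = np.linspace(-5.0, 9.0, 701)
p_b_vals = np.array([p_b(b) for b in b_grid])

# Sanity checks with Riemann sums on the grid
db = b_grid[1] - b_grid[0]
total = p_b_vals.sum() * db               # should be close to 1
mean_b = (b_grid * p_b_vals).sum() * db   # should be close to E[b] = E[a] = 1.5
print(total, mean_b)
```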

Setting up the normalizing flow

Now that we understand our target distributions, we can begin to set up our normalizing flow to approximate $p(a,b)$.

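Here, we use the normflows Python package, which is built on PyTorch. Setting up a normalizing flow involves choosing the number and type of flow layers and the base distribution. The base distribution is the simple initial density that the flow transforms; training adjusts the transformation so that the transformed density matches the target. These choices are problem-specific. In the code below, the flow stacks K = 16 blocks, each consisting of an affine coupling layer (whose scale and shift are computed by an MLP with two hidden layers of 64 units) followed by a learnable linear permutation (`LULinearPermute`). The base distribution is a fixed, non-trainable 2D diagonal Gaussian.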

# Define NF architecture
torch.manual_seed(0)
K = 16 # number of repeated blocks
latent_size = 2 # num input channels
flows = []
for i in range(K):
    param_map = nf.nets.MLP([1, 64, 64, 2], init_zeros=True)
    flows += [nf.flows.AffineCouplingBlock(param_map)]
    flows += [nf.flows.LULinearPermute(latent_size)]

base = nf.distributions.DiagGaussian(2, trainable=False) # Base distribution
model = nf.NormalizingFlow(q0=base, flows=flows)
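To make the role of the coupling layers more concrete, here is a simplified, hypothetical 2D affine coupling layer written in plain PyTorch; it illustrates the idea and is not the normflows implementation of `AffineCouplingBlock`. The first coordinate passes through unchanged and is fed to a small MLP (mirroring the `[1, 64, 64, 2]` layer sizes above) that outputs a log-scale and a shift for the second coordinate. This makes the layer easy to invert and its Jacobian triangular, so the log-determinant is just the log-scale; the sketch checks both properties.

```python
import torch
import torch.nn as nn

class SimpleAffineCoupling(nn.Module):
    """Hypothetical 2D affine coupling layer for illustration; not the normflows implementation."""

    def __init__(self, hidden=64):
        super().__init__()
        # Small MLP mapping the untouched coordinate x1 to (log_scale, shift) for x2
        self.net = nn.Sequential(
            nn.Linear(1, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden), nn.ReLU(),
            nn.Linear(hidden, 2),
        )

    def forward(self, x):
        x1, x2 = x[:, :1], x[:, 1:]
        log_s, t = self.net(x1).chunk(2, dim=1)
        z2 = x2 * torch.exp(log_s) + t       # only x2 is transformed
        log_det = log_s.sum(dim=1)           # triangular Jacobian: log|det| = sum of log-scales
        return torch.cat([x1, z2], dim=1), log_det

    def inverse(self, z):
        z1, z2 = z[:, :1], z[:, 1:]
        log_s, t = self.net(z1).chunk(2, dim=1)  # z1 == x1, so the same parameters are recovered
        x2 = (z2 - t) * torch.exp(-log_s)
        return torch.cat([z1, x2], dim=1)

torch.manual_seed(0)
layer = SimpleAffineCoupling()
x = torch.randn(5, 2)
z, log_det = layer(x)
print(torch.allclose(layer.inverse(z), x, atol=1e-5))   # invertibility check

# The cheap log-determinant matches the one obtained from the full Jacobian via autodiff
J = torch.autograd.functional.jacobian(lambda v: layer(v.unsqueeze(0))[0].squeeze(0), x[0])
print(torch.allclose(torch.linalg.slogdet(J).logabsdet, log_det[0], atol=1e-5))
```

Because a single coupling layer leaves one coordinate unchanged, flows alternate coupling layers with permutations or other mixing layers (here, `LULinearPermute`) so that every coordinate is eventually transformed.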

Training the normalizing flow

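The next step is to train the normalizing flow on samples of $(a,b)$ drawn from the target distribution. The loss function is the negative log likelihood of these samples under the flow, $-\frac{1}{n}\sum_i \log NF(a_i, b_i)$; minimizing it is equivalent to minimizing the forward KL divergence from the target distribution to the flow. Using the Adam optimizer (learning rate $10^{-3}$, weight decay $10^{-5}$), the normalizing flow is trained for 1000 iterations (referred to as epochs below), each of which draws a fresh batch of 512 samples from the target distribution.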

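# Train NF
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

def sample_ab(n, device):
    """Draw n samples (a, b) from the target: a ~ Uniform[1,2], b ~ Normal(a, a)."""
    a = 1 + torch.rand(n, device=device)                   # a ~ Uniform[1, 2]
    b = a + torch.sqrt(a) * torch.randn(n, device=device)  # std dev sqrt(a) gives variance a
    return torch.stack([a, b], dim=1)                      # shape (n, 2)

epochs = 1000
num_samples = 2 ** 9 # 512 samples per iteration
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)
loss_hist = [] # Training loss at each iteration, used for the loss plot

for it in range(epochs):
    optimizer.zero_grad()
    x = sample_ab(num_samples, device) # Get a fresh batch of training samples
    loss = -model.log_prob(x).mean() # Negative log likelihood
    # Do backprop and optimizer step only if the loss is finite (no NaNs or infs)
    if torch.isfinite(loss):
        loss.backward()
        optimizer.step()
    loss_hist.append(loss.item())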

After training, the distribution learned by the normalizing flow captures the main features of the target distribution, even with a relatively small number of iterations. The error plot on the right shows where the true and learned distributions disagree.
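One way to summarize the disagreement between the true and learned distributions in a single number is a Monte Carlo estimate of the forward KL divergence, $D_{\mathrm{KL}}(p \,\|\, q) = \mathbb{E}_{p}[\log p - \log q]$, computed from samples of the target $p$. The helper `mc_forward_kl` below is a hypothetical sketch; here it is checked on two Gaussian stand-ins, for which PyTorch's `kl_divergence` gives the exact value. For this tutorial, one would instead pass the log of the true joint density as `log_p`, the flow's `model.log_prob` as `log_q`, and samples drawn from the target distribution.

```python
import torch
from torch.distributions import MultivariateNormal, kl_divergence

def mc_forward_kl(log_p, log_q, samples):
    """Monte Carlo estimate of KL(p || q) = E_p[log p(x) - log q(x)] using samples x ~ p."""
    return (log_p(samples) - log_q(samples)).mean()

# Stand-in distributions with a closed-form KL, used only to check the estimator
torch.manual_seed(0)
p = MultivariateNormal(torch.tensor([1.5, 1.5]),
                       covariance_matrix=torch.tensor([[0.1, 0.08], [0.08, 1.6]]))
q = MultivariateNormal(torch.zeros(2), covariance_matrix=torch.eye(2))
x = p.sample((200_000,))

estimate = mc_forward_kl(p.log_prob, q.log_prob, x).item()
exact = kl_divergence(p, q).item()
print(estimate, exact)
```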

To see how quickly the normalizing flow converged to this approximation, we can look at the training loss over epochs. The loss has plateaued by the end of training; in fact, even training 10 times as long does not significantly improve the learned distribution.

# Plot loss
plt.figure(figsize=(6, 6))
plt.plot(loss_hist, label='loss')
plt.title('Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.yscale('log')
plt.show()
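The level at which the loss plateaus has a natural floor. Writing the expected loss as $-\mathbb{E}_p[\log q] = H[p] + D_{\mathrm{KL}}(p \,\|\, q)$ and using $D_{\mathrm{KL}} \ge 0$, the negative log likelihood cannot, in expectation, drop below the differential entropy $H[p]$ of the target. For this target, $H[p] = H[a] + \mathbb{E}_a\left[\tfrac{1}{2}\log(2\pi e a)\right]$ with $H[a] = 0$ for $\text{Uniform}[1,2]$. The sketch below computes this in closed form and checks it by Monte Carlo using exact target samples.

```python
import numpy as np

# Closed form: H[p] = H[a] + E_a[H[b | a]]
# H[a] = log(2 - 1) = 0 for a ~ Uniform[1, 2]
# H[b | a] = 0.5 * log(2*pi*e*a), and E[log a] = 2*log(2) - 1 for a ~ Uniform[1, 2]
H_analytic = 0.5 * np.log(2 * np.pi * np.e) + 0.5 * (2 * np.log(2) - 1)

# Monte Carlo check: H[p] = -E_p[log p(a, b)], estimated with exact samples from the target
rng = np.random.default_rng(0)
a = rng.uniform(1, 2, 200_000)
b = rng.normal(a, np.sqrt(a))
log_p = -0.5 * np.log(2 * np.pi * a) - 0.5 * (b - a) ** 2 / a   # log p(a) = 0 on [1, 2]
H_mc = -log_p.mean()

print(H_analytic, H_mc)   # both approximately 1.61 nats
```

So a well-trained flow's loss should settle near, and on average above, roughly 1.61 nats; individual mini-batch losses can fluctuate on either side of this value.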

By looking at probability densities at intermediate epochs, we can see that the approximate normalizing flow solution achieves reasonable accuracy early in training.

From the loss curve, we see that the largest changes happen at the very beginning of training. The subplots below show the first few epochs, in which the probability density migrates from the base Gaussian centered at $(0,0)$ toward the region where the target density is concentrated. Since `init_zeros=True` makes each affine coupling layer start out as the identity map, the flow's initial density is close to the base Gaussian.

As another view of the normalizing flow solution, we can compare the true and learned marginal probability densities for $a$ and $b$, shown in red and black, respectively.
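The marginals of the learned flow are not available in closed form, but they can be approximated by evaluating the learned density on a grid (for example, `torch.exp(model.log_prob(grid))` reshaped to the grid's shape) and summing out one variable. Below is a minimal sketch with a hypothetical helper `grid_marginals`; to make the result checkable, it uses a stand-in density made of two independent Gaussians whose marginals are known.

```python
import numpy as np

def grid_marginals(P, a_grid, b_grid):
    """Marginals of a 2D density sampled as P[i, j] = p(a_grid[i], b_grid[j]) on uniform grids."""
    da = a_grid[1] - a_grid[0]
    db = b_grid[1] - b_grid[0]
    p_a = P.sum(axis=1) * db   # integrate out b
    p_b = P.sum(axis=0) * da   # integrate out a
    return p_a, p_b

def normal_pdf(x, mu, sigma):
    return np.exp(-0.5 * ((x - mu) / sigma) ** 2) / (sigma * np.sqrt(2 * np.pi))

# Stand-in density with known marginals: independent Normal(1.5, 0.5^2) and Normal(1.5, 1.2^2)
a_grid = np.linspace(-2.0, 5.0, 701)
b_grid = np.linspace(-5.0, 8.0, 1301)
A, B = np.meshgrid(a_grid, b_grid, indexing="ij")
P = normal_pdf(A, 1.5, 0.5) * normal_pdf(B, 1.5, 1.2)

p_a, p_b = grid_marginals(P, a_grid, b_grid)
print(np.abs(p_a - normal_pdf(a_grid, 1.5, 0.5)).max())   # small
print(np.abs(p_b - normal_pdf(b_grid, 1.5, 1.2)).max())   # small
```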

Sampling from the trained normalizing flow

Once the normalizing flow has been trained, sampling from it is straightforward: points drawn from the base Gaussian are pushed through the learned transformation.

model.eval() # Set model to evaluation mode
with torch.no_grad(): # Gradients are not needed for sampling
    samples, log_prob = model.sample(num_samples=1000) # Draw samples and their log densities from the flow
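A simple quantitative check of the samples is to compare their first two moments with those of the target, which follow from the laws of total expectation and total variance: $\mathbb{E}[a] = 3/2$, $\mathrm{Var}(a) = 1/12$, $\mathbb{E}[b] = \mathbb{E}[a] = 3/2$, $\mathrm{Var}(b) = \mathbb{E}[\mathrm{Var}(b \mid a)] + \mathrm{Var}(\mathbb{E}[b \mid a]) = \mathbb{E}[a] + \mathrm{Var}(a) = 19/12$, and $\mathrm{Cov}(a,b) = \mathrm{Var}(a) = 1/12$. The sketch below (the helper `moment_errors` is hypothetical) uses exact target samples as a stand-in, so both errors should be near zero; with the flow, one would pass `samples.detach().cpu().numpy()`.

```python
import numpy as np

# Analytic moments of the target: a ~ Uniform[1, 2], b | a ~ Normal(a, a)
TARGET_MEAN = np.array([1.5, 1.5])
TARGET_COV = np.array([[1 / 12, 1 / 12],
                       [1 / 12, 19 / 12]])

def moment_errors(samples):
    """Return (mean error, covariance error) of an (n, 2) array of (a, b) samples."""
    samples = np.asarray(samples)
    mean_err = samples.mean(axis=0) - TARGET_MEAN
    cov_err = np.cov(samples, rowvar=False) - TARGET_COV
    return mean_err, cov_err

# Exact samples from the target as a stand-in for flow samples
rng = np.random.default_rng(0)
a = rng.uniform(1, 2, 200_000)
b = rng.normal(a, np.sqrt(a))
mean_err, cov_err = moment_errors(np.column_stack([a, b]))
print(mean_err)
print(cov_err)
```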

Acknowledgements

Some of the code in this tutorial was adapted from example scripts in the normalizing-flows repository: https://github.com/VincentStimper/normalizing-flows

Stimper et al., (2023). normflows: A PyTorch Package for Normalizing Flows. Journal of Open Source Software, 8(86), 5361, https://doi.org/10.21105/joss.05361

Author

Jay Spendlove

PhD student, Arizona State University

jcspendl@asu.edu
