Exercises: Variance Reduction by Reparametrization

import numpy as np
from matplotlib import pyplot as plt
from scipy.stats import norm

%matplotlib inline

Introduction

We want to compute the gradient of an expectation:

$ \nabla_\phi \mathbb E_{q_\phi(z)} \left[ f(z) \right] $

Here, we assume that$ f(z) $ has no dependency w.r.t.$ \phi $.

Score Function Estimator

Crude MC estimator

$ \begin{align} \nabla_\phi \mathbb E_{q_\phi(z)} \left[ f(z) \right] &= \mathbb E_{q_\phi(z)} \left[ f(z) \nabla_\phi \log q_\phi(z)\right]\\ & = \int f(z) q_\phi(z) \nabla_\phi \log q_\phi(z) dz \\ \end{align} $

The integral can be estimated by Monte-Carlo with sampling of the hidden variables$ z^{(i)} $ from$ q_\phi(z) $ (notation:$ z^{(i)}\sim q_\phi(z) $):

$ \int f(z) q_\phi(z) \nabla_\phi \log q_\phi(z) dz \approx \frac{1}{n} \sum_{i=1}^n f(z^{(i)})\nabla_\phi \log q_\phi(z^{(i)}) $

The problem of the crude MC estimator is that it has high variance.

Example

Let's keep the things simple. We use a Gaussian for our distribution$ q_{\phi}(z^{(i)}) $ with$ \phi = \{ \mu, \sigma^2\} $:

$ \mathcal N (z \mid \mu, \sigma^2) = \frac{1}{\sqrt{2\pi \sigma^2}} \exp(-\frac{(z-\mu)^2}{2 \sigma^2}) $

As function we use$ f(z)=z^2 $.

loc=mu=10. 
scale=3.
sigma2 = scale**2
size=100

def sample_z_from_q(loc=loc, scale=scale, size=size):
    return np.random.normal(loc=loc, scale=scale, size=size)
def f(z):
    return z**2

Show that the derivatives w.r.t.$ \mu $ and$ \sigma^2 $ are:

$ \frac{\partial}{\partial \mu} \log \mathcal N (z \mid \mu, \sigma^2) = \frac{(z-\mu)}{\sigma^2}\\ $

$ \frac{\partial}{\partial \sigma^2} \log \mathcal N (z \mid \mu, \sigma^2) = - \frac{1}{2 \sigma^2}+ \frac{(z-\mu)^2}{2 (\sigma^2)^2}\\ $

$ \begin{align} \frac{\partial}{\partial \mu} \log \mathcal N (z \mid \mu, \sigma^2) &= \frac{\partial}{\partial \mu} \log \left(\frac{1}{\sqrt{2\pi \sigma^2}} \exp(-\frac{(z-\mu)^2}{2 \sigma^2})\right)\\ &= \frac{\partial}{\partial \mu} \left(\log \left(\frac{1}{\sqrt{2\pi \sigma^2}} \right)+ \left(\log \exp(-\frac{(z-\mu)^2}{2 \sigma^2})\right)\right)\\ &= \frac{\partial}{\partial \mu} \log \left(\frac{1}{\sqrt{2\pi \sigma^2}} \right)+ \frac{\partial}{\partial \mu} \left(\log \exp(-\frac{(z-\mu)^2}{2 \sigma^2})\right)\\ &= 0 - \frac{\partial}{\partial \mu}\frac{(z-\mu)^2}{2 \sigma^2}\\ &= \frac{(z-\mu)}{\sigma^2}\\ \end{align} $

$ \begin{align} \frac{\partial}{\partial \sigma^2} \log \mathcal N (z \mid \mu, \sigma^2) &= \frac{\partial}{\partial \sigma^2} \log \left(\frac{1}{\sqrt{2\pi \sigma^2}} \exp(-\frac{(z-\mu)^2}{2\sigma^2})\right) \\ &= \frac{\partial}{\partial \sigma^2} \left(\log \left(\frac{1}{\sqrt{2\pi \sigma^2}} \right)+ \left(\log \exp(-\frac{(z-\mu)^2}{2\sigma^2})\right)\right)\\ &= \frac{\partial}{\partial \sigma^2} \log \left(\frac{1}{\sqrt{2\pi \sigma^2}} \right)+ \frac{\partial}{\partial \sigma^2} \left(\log \exp(-\frac{(z-\mu)^2}{2\sigma^2})\right)\\ &= - \frac{\partial}{\partial \sigma^2} \log \left(\sqrt{2\pi \sigma^2} \right)+ \frac{\partial}{\partial \sigma^2} \left(\log \exp(-\frac{(z-\mu)^2}{2 \sigma^2})\right)\\ &= - \frac{1}{2 \sigma^2}+ \frac{(z-\mu)^2}{2 (\sigma^2)^2}\\ \end{align} $

def dq_dmu(z, mu=mu, sigma2=sigma2):
    return (z - mu)/sigma2

def dq_dsigma2(z, mu=mu, sigma2=sigma2):
    return - 1/(2*sigma2) + (z-mu)**2/sigma2**2

Exercise: Implement the Crude MC estimator:

$ \begin{align} \frac{\partial}{\partial \mu} \mathbb E_{\mathcal N (z \mid \mu, \sigma^2) } \left[ f(z) \right] &= \mathbb E_{\mathcal N (z \mid \mu, \sigma^2) } \left[ f(z) \frac{\partial}{\partial \mu} \log \mathcal N (z \mid \mu, \sigma^2)\right]\\ & \approx \frac{1}{n} \sum_{i=1}^n f(z^{(i)}) \frac{\partial}{\partial \mu} \log \mathcal N (z \mid \mu, \sigma^2) = \frac{1}{n} \sum_{i=1}^n f(z^{(i)}) \frac{(z^{(i)}-\mu)}{\sigma^2} \end{align} $

def get_crude_mc_estimates(size):
    estimates = []
    for i in range(10000):
        estimate = crude_mc_estimate(size=size, f=f, dq_dmu=dq_dmu)
        estimates.append(estimate)
    return np.array(estimates)
estimates = get_crude_mc_estimates(size=100)
plt.hist(estimates);
#plt.xlim(15, 25)
cell-output
estimates = get_crude_mc_estimates(size=10000)
plt.hist(estimates);
plt.xlim(15, 25)
cell-output

Reparametrization for variance reduction

$ z = g_\phi(\epsilon) $

$ \nabla_\phi \mathbb E_{q_\phi(z)} \left[ f(z) \right] = \mathbb E_{p(\epsilon)} [\nabla_\phi f (g_\phi(\epsilon))] $ $ p(\epsilon) $ is a (fixed) distribution with no dependence on$ \phi $.

Here: -$ p(\epsilon) = \mathcal N(0, 1) $ -$ g_{\mu, \sigma^2}(\epsilon) = \mu + \epsilon \sigma $

Exercise

  • Implement the reparametrized MC estimator

$ \frac{\partial}{\partial \mu} \mathbb E_{\mathcal N (z \mid \mu, \sigma^2) } \left[ f(z) \right] = \mathbb E_{p(\epsilon)} \left[\frac{\partial}{\partial \mu} f (g_{\mu, \sigma^2}(\epsilon))\right] $

  • Find a mathematical expression for$ \frac{\partial f}{\partial \mu} $ and use it in the estimator.
  • Sample$ \epsilon \sim \mathcal N(0, 1) $
def sample_epsilon_from_p(size=size): # no dependence on mu and sigma2
    return np.random.normal(loc=0, scale=1, size=size)

$ \nabla_\phi f (g_\phi(\epsilon)) $

$ g = \mu + \epsilon \sigma $

$ z = g_{\mu, \sigma^2}(\epsilon) $

$ f(z) = f(g) = g^2 $

$ \frac{\partial f}{\partial \mu} = \frac{\partial f}{\partial g}\frac{\partial g}{\partial \mu} = 2 g $

def get_reparametrized_mc_estimates(size):
    estimates = []
    for i in range(10000):
        estimate= reparametrized_mc_estimate(size)
        estimates.append(estimate)
    return np.array(estimates)
def g(epsilon, mu=mu, sigma2=sigma2):
    return mu + epsilon * np.sqrt(sigma2)
def reparametrized_mc_estimate(size):
    epsilons = sample_epsilon_from_p(size)
    return (2 * g(epsilons)).mean()
# here we need just 100 samples to get a similar result as for 10000 "crude" mc samples
estimates = get_reparametrized_mc_estimates(100)
plt.hist(estimates);
plt.xlim(15, 25)
cell-output
#### for 10000 samples
estimates = get_reparametrized_mc_estimates(10000)
plt.hist(estimates);
plt.xlim(15, 25)
cell-output

Comparision between crude MC and reparametrized MC

crude_mc_var = []
reparam_mc_var = []
ns = np.array([10,30,90,300,600,900, 1500, 3000])
for i in ns:
    crude_mc_estimates = get_crude_mc_estimates(size=i)
    reparam_mc_estimates = get_reparametrized_mc_estimates(size=i)
    crude_mc_var.append(crude_mc_estimates.var(ddof=1))
    reparam_mc_var.append(reparam_mc_estimates.var(ddof=1))
crude_mc_var = np.array(crude_mc_var)
reparam_mc_var = np.array(reparam_mc_var)
plt.figure(figsize=(10, 8))
plt.loglog(ns, crude_mc_var, "bo", label="crude mc")
plt.loglog(ns, crude_mc_var, "b-")
plt.loglog(ns, reparam_mc_var, "go", label="reparametrized mc")
plt.loglog(ns, reparam_mc_var, "g-")
plt.xlabel("number of samples")
plt.ylabel("(empirical) variance of the estimators")
plt.legend();
cell-output

As you can see from the plot there is a strait line in a log-log plot:

$ \log y = a + b \log x $

therefore (with$ c = \exp(a) $): $ y = c x^b $

We know:

$ var(estimate) = \frac{c}{n} $

so be must be$ -1 $. Let's check this:

b_crude = (np.log(crude_mc_var[-1]) - np.log(crude_mc_var[0]))/(np.log(ns[-1])-np.log(ns[0]))
b_crude
b_reparm = (np.log(reparam_mc_var[-1]) - np.log(reparam_mc_var[0]))/(np.log(ns[-1])-np.log(ns[0]))
b_reparm

$ a = \log y - b \log x $

# with b \approx 1
as_reparam = np.log(reparam_mc_var) + np.log(ns) 
cs_reparam = np.exp(as_reparam)
cs_reparam.mean()
# with b \approx 1
as_crude = np.log(crude_mc_var) + np.log(ns) 
cs_crude= np.exp(as_crude)
cs_crude.mean()

$ var(estimate\_reparam) = \frac{35.8}{n} $

$ var(estimate\_crude) = \frac{2648}{n} $

$ \eta_{VR} =\frac{\text{var}\left(\hat H_n\right)-\text{var}\left(\bar H_n\right)}{\text{var}\left(\hat H_n\right)} $

eta = (2648-35.8)/2648
eta
# Here we need approx. 1.35% of the samples to get the same 
# variance with our variance reduction technique 
35.8/2648

Literature

  • Diederik P. Kingma, Max Welling: Auto-Encoding Variational Bayes,
  • Diederik P. Kingma, Max Welling: Efficient Gradient-based Inference through Transformation between Bayes Nets and Neural Nets,