Bayesian Linear Regression

via variational EM (mean field approximation)

%matplotlib inline
import numpy as np
from matplotlib import pyplot as plt
import scipy.stats
np.random.seed(4)
import daft
def plot_glm():
    # Graphical model: hyperparameter alpha and noise precision tau are
    # fixed nodes; theta is latent; the y^(i) are observed given the x^(i).
    pgm = daft.PGM([5.3, 4.05], origin=[-0.3, -0.3], aspect=1.)
    pgm.add_node(daft.Node("alpha", r"$\alpha$", 2.5, 3, fixed=True))
    pgm.add_node(daft.Node("tau", r"$\tau$", 3.5, 2.2, fixed=True))

    pgm.add_node(daft.Node("theta", r"$\theta$", 2.5, 2.2))
    # Data.
    pgm.add_node(daft.Node("xi", r"$\vec x^{(i)}$", 1.5, 1, fixed=True))
    pgm.add_node(daft.Node("yi", r"$y^{(i)}$", 2.5, 1, observed=True))

    pgm.add_node(daft.Node("x", r"$\vec x$", 4.5, 1, fixed=True))
    pgm.add_node(daft.Node("y", r"$y$", 3.5, 1))

    # Add in the edges.
    pgm.add_edge("alpha", "theta")
    pgm.add_edge("theta", "yi")
    pgm.add_edge("xi", "yi")
    pgm.add_edge("tau", "yi")

    pgm.add_edge("x", "y")
    pgm.add_edge("theta", "y")
    pgm.add_edge("tau", "y")
    # And a plate.
    pgm.add_plate(daft.Plate([1., .4, 2, 1.1], label=r"$i = 1, \ldots, m$",
                             shift=-0.1))

    pgm.render()
plot_glm()

We will show how to estimate the regression parameters of a simple linear model.

We can restate the linear model: $ y = \theta_0 + \theta_1 x + \epsilon $

as sampling from a probability distribution

$ y \sim \mathcal N(\theta_0 + \theta_1 x, \tau^{-1}) $

with $ \tau = 1/\sigma^2 $.
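To make the generative view concrete, here is a minimal sketch that draws one $y$ for a given $x$; theta0_demo, theta1_demo and tau_demo are made-up illustrative values, not part of the model fit:

# minimal sketch of the generative model; the *_demo values are
# hypothetical, chosen only for illustration
theta0_demo, theta1_demo, tau_demo = 1.0, 2.0, 0.25
x_demo = 0.5
y_demo = np.random.normal(loc=theta0_demo + theta1_demo * x_demo,
                          scale=1/np.sqrt(tau_demo))  # std = tau**(-1/2)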

Likelihood: $ p(Y \mid X, \theta; \tau) = \prod_{i=1}^m \sqrt{\frac{\tau}{2\pi }} \exp \left(-\frac{\tau}{2}(\theta_0 + \theta_1 x_i - y_i)^2\right) $

Log-Likelihood: $ \log p(Y \mid X, \theta; \tau) = m \log \sqrt{\frac{\tau}{2\pi }} - \sum_{i=1}^m \left(\frac{\tau}{2}(\theta_0 + \theta_1 x_i - y_i)^2\right) $
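The log-likelihood translates directly into code. This helper (log_likelihood, introduced here only for illustration) is not used by the EM updates below:

def log_likelihood(x, y, theta0, theta1, tau):
    # log-likelihood of the linear-Gaussian model above;
    # illustrative helper, not needed for the EM updates
    m = x.shape[0]
    residuals = theta0 + theta1*x - y
    return m * 0.5 * np.log(tau / (2*np.pi)) - 0.5 * tau * (residuals**2).sum()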

For simplicity we will use the improper prior

$ p(\theta) = const. $

Here:

  • we want to be Bayesian about the parameters $ \theta $, so we need an approximation of the posterior $ p(\theta \mid \mathcal D) $;
  • the noise has a parameter $ \tau $, for which we only want a point estimate.

Variational EM

Mean field approximation $ q(\theta) = q(\theta_0)\,q(\theta_1) $, where each factor $ q(\theta_k) $ is a univariate Gaussian.

E-Step:

Loop until convergence: $ \log q({\theta}_k) = \mathbb E_{q_{-k}} \left[ \log{\hat p( {\theta} \mid {\mathcal D} )} \right] + const. $ where $\hat p$ denotes the unnormalized posterior:

$ \begin{align} \hat p( {\theta} \mid X, Y ) &= p( Y \mid X, \theta; \tau ) p(\theta)\\ &= p( Y \mid X, \theta; \tau ) p(\theta_0,\theta_1)\\ \end{align} $

$ \begin{align} \log \hat p( {\theta} \mid X, Y ) &= \log p( Y \mid X, \theta; \tau ) + \log p(\theta_0,\theta_1)\\ &=\log p( Y \mid X, \theta; \tau ) + const.\\ \end{align} $

Mean field approximation: $ \begin{align} \log q({\theta}_0) &= \mathbb E_{q_{\theta_1}} \left[ \log{\hat p( {\theta} \mid X, Y )} \right] + const.\\ &= \mathbb E_{q_{\theta_1}} \left[ \log p( Y \mid X, \theta; \tau ) + \log p(\theta_0,\theta_1) \right] + const.\\ &= \mathbb E_{q_{\theta_1}} \left[- \sum_{i=1}^m \left(\frac{\tau}{2}(\theta_0 + \theta_1 x_i - y_i)^2\right) \right] + const.\\ &= - \sum_{i=1}^m \left(\frac{\tau}{2}(\theta_0 + \mathbb E_{q_{\theta_1}} \left[\theta_1\right] x_i - y_i)^2\right) + const.\\ \end{align} $

In the last step the variance of $q_{\theta_1}$ contributes only terms that do not depend on $\theta_0$, so they are absorbed into the constant.

with $ \gamma_1 = \mathbb E_{q_{\theta_1}}[\theta_1] $:

$ \begin{align} \log q({\theta}_0) &= - \sum_{i=1}^m \left(\frac{\tau}{2}(\theta_0 + \gamma_1 x_i - y_i)^2\right) + const.\\ \end{align} $

$ \begin{align} q({\theta}_0) &= C \exp \left( - \frac{\tau}{2} \sum_{i=1}^m (\theta_0 + \gamma_1 x_i - y_i)^2\right) \\ &= C \exp \left( - \frac{\tau}{2} \sum_{i=1}^m \left(\theta_0^2 + 2 \theta_0 (\gamma_1 x_i - y_i) + (\gamma_1 x_i - y_i)^2\right)\right) \\ &= C \exp \left( - \frac{\tau}{2} \left(m \theta_0^2 + 2 \theta_0 \sum_{i=1}^m (\gamma_1 x_i - y_i) + \sum_{i=1}^m (\gamma_1 x_i - y_i)^2\right) \right)\\ &= C \exp \left( - \frac{\tau}{2} \left(m \theta_0^2 + 2 \theta_0 m a + m a^2 + \sum_{i=1}^m (\gamma_1 x_i - y_i)^2 - m a^2\right) \right)\\ &= C' \exp \left( - \frac{\tau}{2} \left(m \theta_0^2 + 2 \theta_0 m a + m a^2\right) \right)\\ &= C' \exp \left( - \frac{\tau m }{2} \left( \theta_0^2 + 2 \theta_0 a + a^2\right) \right)\\ &= C' \exp \left( - \frac{\tau m }{2} \left( \theta_0 + a\right)^2 \right)\\ &= \mathcal N\left(-a , (\tau m)^{-1}\right) \end{align} $

with $ m a = \sum_{i=1}^m (\gamma_1 x_i - y_i) $, i.e. $ a = \frac{1}{m}\sum_{i=1}^m (\gamma_1 x_i - y_i) $

and $ \gamma_0 = \mathbb E_{q_{\theta_0}}[\theta_0] = -a $.

def get_q0_param(x, y, gamma1, tau):
    # closed-form update for q(theta_0): returns (mean, precision)
    m = x.shape[0]
    a = gamma1 * x.mean() - y.mean()  # a = (1/m) * sum_i (gamma1*x_i - y_i)
    return -a, tau * m

Analogously for $ \theta_1 $:

$ \begin{align} q({\theta}_1) &= D \exp \left( - \frac{\tau}{2} \sum_{i=1}^m (\gamma_0 + \theta_1 x_i - y_i)^2\right) \\ &= D \exp \left( - \frac{\tau}{2} \sum_{i=1}^m ( \theta_1 x_i + \gamma_0 - y_i)^2\right) \\ &= D \exp \left( - \frac{\tau}{2} \sum_{i=1}^m \left(\theta_1^2 x_i^2 + 2 \theta_1 x_i (\gamma_0 - y_i) + (\gamma_0 - y_i)^2\right)\right) \\ &= D \exp \left( - \frac{\tau}{2} \left(m \theta_1^2 \bar {x^2} + 2 \theta_1 \sum_{i=1}^m x_i (\gamma_0 - y_i) + const \right) \right)\\ &= D' \exp \left( - \frac{\tau m \bar {x^2}}{2} \left( \theta_1^2 + 2 \theta_1 b + b^2 \right) \right)\\ &= D' \exp \left( - \frac{\tau m \bar {x^2}}{2} \left( \theta_1 + b\right)^2 \right)\\ &= \mathcal N\left(-b , (\tau m \bar {x^2})^{-1}\right) \end{align} $

with $ \bar{x^2} = \frac{1}{m}\sum_{i=1}^m x_i^2 $, $ b = \frac{\sum_{i=1}^m x_i (\gamma_0 - y_i)}{m \bar {x^2}} $ and $ \gamma_1 = \mathbb E_{q_{\theta_1}}[\theta_1] = -b $.

def get_q1_param(x, y, gamma0, tau):
    # closed-form update for q(theta_1): returns (mean, precision)
    x_square_sum = (x**2).sum()  # = m * mean(x^2)
    b = x.dot(gamma0 - y) / x_square_sum
    return -b, tau * x_square_sum
# observed data
m = 6
a = 20       # true slope
b = 4        # true intercept
sigma = 4.3  # true noise std
x = np.linspace(0, 1, m)
y = a*x + b + np.random.normal(0, sigma, m)
plt.plot(x, y, 'xb')
def e_step(x, y, tau, mean0=1, mean1=1):
    # coordinate ascent on the mean-field factors; a fixed number of
    # iterations stands in for a proper convergence check
    for _ in range(1000):
        mean0, prec0 = get_q0_param(x, y, gamma1=mean1, tau=tau)
        mean1, prec1 = get_q1_param(x, y, gamma0=mean0, tau=tau)
    return mean0, prec0, mean1, prec1
tau = 1/sigma**2  # for this first test of the E-step, use the true precision
print(tau)
mean0, prec0, mean1, prec1 = e_step(x, y, tau, mean0=1, mean1=1)
mean1, 1/np.sqrt(prec1)  # q(theta_1): mean and std
mean0, 1/np.sqrt(prec0)  # q(theta_0): mean and std

M-Step

$ \log p(Y, \theta \mid X; \tau) = \log p(Y \mid X, \theta; \tau) + \log p(\theta) = \log p(Y \mid X, \theta; \tau) + const. $

$ \begin{align} \arg\max_{\tau} \mathbb E_{q({\theta_0}),q({\theta_1})}\left[\log p(Y, \theta \mid X; \tau) \right] &= \arg\max_{\tau} \mathbb E_{q({\theta_0}),q({\theta_1})}\left[ \log p(Y \mid X, \theta; \tau) \right]\\ &= \arg\max_{\tau} \mathbb E_{q({\theta_0}),q({\theta_1})}\left[ m \log \sqrt{\frac{\tau}{2\pi }} - \sum_{i=1}^m \left(\frac{\tau}{2}(\theta_0 + \theta_1 x_i - y_i)^2\right) \right]\\ &\approx \arg\max_{\tau} \left[ m \log \sqrt{\frac{\tau}{2\pi }} - \sum_{i=1}^m \left(\frac{\tau}{2}(\mathbb E_{q({\theta_0})}[\theta_0] + \mathbb E_{q({\theta_1})}[\theta_1] x_i - y_i)^2\right) \right]\\ &= \arg\max_{\tau} \left[ m \log \sqrt{\frac{\tau}{2\pi }} - \sum_{i=1}^m \left(\frac{\tau}{2}( \gamma_0 + \gamma_1 x_i - y_i)^2\right) \right]\\ \end{align} $

In the third line we approximate $\mathbb E_q\left[(\theta_0 + \theta_1 x_i - y_i)^2\right]$ by $(\mathbb E_q[\theta_0] + \mathbb E_q[\theta_1] x_i - y_i)^2$, i.e. we neglect the variances of $q$; this keeps the update in closed form.

Setting the derivative to zero: $ \begin{align} 0 &= \frac{\partial}{\partial \tau} \left( m \log \left(\frac{\tau}{2\pi }\right)^{1/2} - \sum_{i=1}^m \left(\frac{\tau}{2}( \gamma_0 + \gamma_1 x_i - y_i)^2\right) \right)\\ 0 &= \frac{m}{\left(\frac{\tau}{2\pi }\right)^{1/2}} \frac{1}{2}\left(\frac{\tau}{2\pi }\right)^{-1/2} \frac{1}{2\pi} - \sum_{i=1}^m \left(\frac{1}{2}( \gamma_0 + \gamma_1 x_i - y_i)^2\right) \\ 0 &= \frac{m}{\left(\frac{\tau}{2\pi }\right) } \frac{1}{2} \frac{1}{2\pi} - \sum_{i=1}^m \left(\frac{1}{2}( \gamma_0 + \gamma_1 x_i - y_i)^2\right) \\ 0 &= \frac{m}{ 2 \tau} - \sum_{i=1}^m \left(\frac{1}{2}( \gamma_0 + \gamma_1 x_i - y_i)^2\right) \\ \frac{m}{ \tau} &= \sum_{i=1}^m ( \gamma_0 + \gamma_1 x_i - y_i)^2 \\ \tau &= \frac{m}{\sum_{i=1}^m ( \gamma_0 + \gamma_1 x_i - y_i)^2} \end{align} $

def m_step_tau(x, y, gamma0, gamma1):
    # tau = m / sum_i residual_i^2 = 1 / (mean squared residual)
    return 1/((gamma0 + gamma1*x - y)**2).mean()
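As an illustrative sanity check (not part of the algorithm): plugging in the true slope and intercept, the update is just one over the mean squared residual, which for a large sample would be close to $1/\sigma^2$; with only a handful of points it is a noisy estimate.

# illustrative sanity check: with the true slope a and intercept b,
# the tau update should be roughly 1/sigma**2 (noisy for only m = 6 points)
print(m_step_tau(x, y, gamma0=b, gamma1=a), 1/sigma**2)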
# MC approximation of the negative log-likelihood, averaged over q
# TODO: variance reduction possible?
def approx_neg_log_likelihood(x, y, gamma0, gamma1, prec0, prec1, tau, size=100):
    # draw random samples of theta from the variational posterior q
    theta0s = np.random.normal(loc=gamma0, scale=1/np.sqrt(prec0), size=size)
    theta1s = np.random.normal(loc=gamma1, scale=1/np.sqrt(prec1), size=size)
    y_pre = theta0s + np.outer(x, theta1s)  # predictions, shape (len(x), size)
    ll = 0.
    for i in range(size):
        ll += scipy.stats.norm.logpdf(y, loc=y_pre[:, i], scale=1/np.sqrt(tau)).mean()
    return -ll / size

gamma0 = 1
gamma1 = 1
tau = .001

for _ in range(10):
    gamma0, prec0, gamma1, prec1 = e_step(x, y, tau, gamma0, gamma1)
    tau = m_step_tau(x, y, gamma0, gamma1)
    print(approx_neg_log_likelihood(x, y, gamma0, gamma1, prec0, prec1, tau, size=100))
gamma0, prec0, gamma1, prec1
# theta 0: mean and std
gamma0, 1/np.sqrt(prec0)
# theta 1: mean and std
gamma1, 1/np.sqrt(prec1)
# estimated noise (std) of the data
print("true noise value: ", sigma)
print("ML estimate of the noise value: ", 1/np.sqrt(tau))

Plots: the data with regression lines sampled from $q(\theta)$, followed by the marginals $q(\theta_0)$ and $q(\theta_1)$

# random samples of theta from the variational posterior q
size = 100
theta0s = np.random.normal(loc=gamma0, scale=1/np.sqrt(prec0), size=size)
theta1s = np.random.normal(loc=gamma1, scale=1/np.sqrt(prec1), size=size)
plt.plot(x, y, "b*");
xp = np.array([x.min(), x.max()])
# one faint red line per posterior sample
_ = plt.plot(xp, theta1s*xp[:, None] + theta0s, c='red', alpha=0.1)
t0s = np.arange(2, 8, 0.01)
p0 = scipy.stats.norm.pdf(t0s, loc=gamma0, scale=1/np.sqrt(prec0))
plt.plot(t0s, p0)
plt.title(r"$q(\theta_0)$")
plt.xlabel("theta0")
plt.ylabel("q(theta0)")
print("true intercept: ", b)
t1s = np.arange(10, 30, 0.01)
p1 = scipy.stats.norm.pdf(t1s, loc=gamma1, scale=1/np.sqrt(prec1))
plt.plot(t1s, p1)
plt.title(r"$q(\theta_1)$")
plt.xlabel("theta1")
plt.ylabel("q(theta1)")
print("true slope: ", a)