In [1]:
import logging
import os

import torch
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

import pyro

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
smoke_test = ('CI' in os.environ)
assert pyro.__version__.startswith('1.9.1')

pyro.enable_validation(True)
pyro.set_rng_seed(1)
logging.basicConfig(format='%(message)s', level=logging.INFO)

# Set matplotlib settings
%matplotlib inline
plt.style.use('default')

In [15]:
#!pip install graphviz

Collecting graphviz
  Using cached graphviz-0.20.3-py3-none-any.whl.metadata (12 kB)
Downloading graphviz-0.20.3-py3-none-any.whl (47 kB)
Installing collected packages: graphviz
Successfully installed graphviz-0.20.3


In [3]:
# Load the ruggedness dataset from a local CSV, decoded as Latin-1 (ISO-8859-1).
DATA_URL = "rugged_data.csv"
data = pd.read_csv(DATA_URL, encoding="ISO-8859-1")

# Keep only what the regression needs: the Africa indicator, the terrain ruggedness
# index and the year-2000 GDP column (rgdppc_2000, presumably per capita).
model_columns = ["cont_africa", "rugged", "rgdppc_2000"]
df = data[model_columns]

In [4]:
# Keep nations with a finite GDP value and replace GDP by its natural log.
# Built from `data` (not from `df`) and via .assign so the cell is idempotent:
# re-running it can no longer apply np.log a second time. It also never writes
# into a filtered slice, which can raise pandas' SettingWithCopyWarning.
finite_gdp = np.isfinite(data["rgdppc_2000"])
df = (
    data.loc[finite_gdp, ["cont_africa", "rugged", "rgdppc_2000"]]
    .assign(rgdppc_2000=lambda frame: np.log(frame["rgdppc_2000"]))
)

In [5]:
# Move the cleaned frame into a float32 tensor (one row per nation) and split its
# columns, in df's column order, into the three model inputs.
train = torch.tensor(df.to_numpy(), dtype=torch.float)
is_cont_africa, ruggedness, log_gdp = (train[:, col] for col in range(3))

In [6]:
fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(12, 6), sharey=True)
african_nations = df[df["cont_africa"] == 1]
non_african_nations = df[df["cont_africa"] == 0]

# Same scatter on both panels: left = non-African nations, right = African nations.
for panel, nations, title in zip(ax,
                                 (non_african_nations, african_nations),
                                 ("Non African Nations", "African Nations")):
    sns.scatterplot(x=nations["rugged"], y=nations["rgdppc_2000"], ax=panel)
    panel.set(xlabel="Terrain Ruggedness Index",
              ylabel="log GDP (2000)",
              title=title)

In [7]:
# Write the raw-data scatter plot to disk. Relies on `fig` still pointing at the
# figure built in the previous cell -- re-run that cell first if cells ran out of order.
fig.savefig("rugged.png")

In [11]:
import pyro.distributions as dist
import pyro.distributions.constraints as constraints

def simple_model(is_cont_africa, ruggedness, log_gdp=None):
    """Non-Bayesian linear regression of log GDP on ruggedness, Africa membership
    and their interaction.

    Every coefficient is a point estimate stored with ``pyro.param``: ``a``,
    ``bA``, ``bR`` and ``bAR`` start from a standard-normal draw, while ``sigma``
    starts at 1 and is constrained to be positive.

    Args:
        is_cont_africa: per-nation 0/1 indicator of being in Africa.
        ruggedness: per-nation terrain ruggedness; its length sets the plate size.
        log_gdp: observed log GDP, or None to sample it from the likelihood.

    Returns:
        The value of the ``obs`` site (``log_gdp`` when observed, a draw otherwise).
    """
    def scalar_coef(site):
        # Unconstrained scalar parameter, initialised from a standard normal on first use.
        return pyro.param(site, lambda: torch.randn(()))

    a, b_a, b_r, b_ar = (scalar_coef(site) for site in ("a", "bA", "bR", "bAR"))
    sigma = pyro.param("sigma", lambda: torch.ones(()), constraint=constraints.positive)

    # Summed left to right in the order a + b_a*A + b_r*R + b_ar*A*R.
    africa_term = b_a * is_cont_africa
    rugged_term = b_r * ruggedness
    interaction_term = b_ar * is_cont_africa * ruggedness
    mean = a + africa_term + rugged_term + interaction_term

    with pyro.plate("data", len(ruggedness)):
        return pyro.sample("obs", dist.Normal(mean, sigma), obs=log_gdp)



In [19]:
def simple_Bayesian_model(is_cont_africa, ruggedness, log_gdp=None):
    """Bayesian version of the ruggedness regression.

    The coefficients are latent ``pyro.sample`` sites with priors instead of
    point-estimated params:
    a ~ Normal(0, 10); bA, bR, bAR ~ Normal(0, 1); sigma ~ Uniform(0, 10).

    Args:
        is_cont_africa: per-nation 0/1 indicator of being in Africa.
        ruggedness: per-nation terrain ruggedness; its length sets the plate size.
        log_gdp: observed log GDP, or None to sample it from the likelihood.

    Returns:
        The value of the ``obs`` site (``log_gdp`` when observed, a draw otherwise).
    """
    a = pyro.sample("a", dist.Normal(0., 10.))
    b_a, b_r, b_ar = (pyro.sample(site, dist.Normal(0., 1.)) for site in ("bA", "bR", "bAR"))
    sigma = pyro.sample("sigma", dist.Uniform(0., 10.))

    # Summed left to right in the order a + b_a*A + b_r*R + b_ar*A*R.
    africa_term = b_a * is_cont_africa
    rugged_term = b_r * ruggedness
    interaction_term = b_ar * is_cont_africa * ruggedness
    mean = a + africa_term + rugged_term + interaction_term

    with pyro.plate("data", len(ruggedness)):
        return pyro.sample("obs", dist.Normal(mean, sigma), obs=log_gdp)



In [20]:
def custom_guide(is_cont_africa, ruggedness, log_gdp=None):
    """Hand-written mean-field variational guide for ``simple_Bayesian_model``.

    Each latent site gets its own independent distribution:
    - a Normal for ``a``;
    - a Normal for each of the three slopes, which share the length-3
      ``weights_loc``/``weights_scale`` params;
    - a LogNormal with a fixed 0.05 scale for ``sigma``.

    The data arguments are unused; they only mirror the model's signature
    because SVI calls model and guide with the same arguments.

    Returns:
        dict with keys "a", "b_a", "b_r", "b_ar", "sigma". NOTE: these keys are
        snake_case while the corresponding sample sites are "bA", "bR", "bAR".
    """
    a_loc = pyro.param('a_loc', lambda: torch.tensor(0.))
    a_scale = pyro.param('a_scale', lambda: torch.tensor(1.), constraint=constraints.positive)
    sigma_loc = pyro.param('sigma_loc', lambda: torch.tensor(0.))
    weights_loc = pyro.param('weights_loc', lambda: torch.randn(3))
    weights_scale = pyro.param('weights_scale', lambda: torch.ones(3), constraint=constraints.positive)

    draws = {"a": pyro.sample("a", dist.Normal(a_loc, a_scale))}
    # (sample-site name, returned key) per slope, indexed into the shared weight vectors.
    slope_sites = (("bA", "b_a"), ("bR", "b_r"), ("bAR", "b_ar"))
    for idx, (site, key) in enumerate(slope_sites):
        draws[key] = pyro.sample(site, dist.Normal(weights_loc[idx], weights_scale[idx]))
    # NOTE(review): LogNormal has support (0, inf) while the model's prior is Uniform(0, 10).
    # With validation enabled, a draw above 10 would presumably be rejected by the model's
    # log_prob. This is practically unreachable with loc near 0 and scale 0.05.
    draws["sigma"] = pyro.sample("sigma", dist.LogNormal(sigma_loc, torch.tensor(0.05)))  # fixed scale for simplicity
    return draws

In [39]:
# Objective + optimizer for the mean-field run: Trace_ELBO minimised with Adam at lr=0.005.
elbo = pyro.infer.Trace_ELBO()
adam = pyro.optim.Adam({"lr": 0.005})
svi = pyro.infer.SVI(
    simple_Bayesian_model,
    custom_guide,
    optim=adam,
    loss=elbo,
)


In [42]:
%%time
pyro.clear_param_store()

losses = []
for step in range(10000 if not smoke_test else 2):  # Consider running for more steps.
    loss = svi.step(is_cont_africa, ruggedness, log_gdp)
    losses.append(loss)
    if step % 100 == 0:
        logging.info("Elbo loss: {}".format(loss))



Elbo loss: 9056.972932815552
Elbo loss: 6805.127711117268
Elbo loss: 1834.645792722702
Elbo loss: 1636.7026161551476
Elbo loss: 947.4445533156395
Elbo loss: 1086.7198023796082
Elbo loss: 755.236984372139
Elbo loss: 705.0246660709381
Elbo loss: 498.345064163208
Elbo loss: 571.2897637486458
Elbo loss: 569.0600973963737
Elbo loss: 627.2955424785614
Elbo loss: 575.6702669858932
Elbo loss: 600.1962569355965
Elbo loss: 510.9926045835018
Elbo loss: 581.2685261964798
Elbo loss: 601.3131709694862
Elbo loss: 525.7656850814819
Elbo loss: 561.4776883721352
Elbo loss: 518.7561558485031
Elbo loss: 505.4377267360687
Elbo loss: 498.3325790166855
Elbo loss: 516.0533139705658
Elbo loss: 576.8596050739288
Elbo loss: 532.4986196756363
Elbo loss: 540.2751200795174
Elbo loss: 486.9754716157913
Elbo loss: 533.1615126132965
Elbo loss: 521.3427265286446
Elbo loss: 472.76445269584656
Elbo loss: 473.88119626045227
Elbo loss: 492.6560118198395
Elbo loss: 512.5479438900948
Elbo loss: 471.1250134110451
Elbo loss: 4

CPU times: user 36.7 s, sys: 177 ms, total: 36.9 s
Wall time: 37.2 s


In [49]:
# Training curve of whichever run last filled `losses` (the mean-field run in linear order).
elbo_fig, elbo_ax = plt.subplots(figsize=(5, 2))
elbo_ax.plot(losses)
elbo_ax.set(xlabel="SVI step", ylabel="ELBO loss")
elbo_fig.savefig("rugged_elbo.png")

In [44]:
# Show the learned variational parameters. The param store's items() already yields
# each parameter in its constrained (user-facing) space, so the loop value is used
# directly instead of a second pyro.param(name) lookup.
for name, value in pyro.get_param_store().items():
    print(name, value.data.cpu().numpy())

a_loc 9.153103
a_scale 0.073490456
sigma_loc -0.055203907
weights_loc [-1.8690879  -0.18687144  0.33032727]
weights_scale [0.13680099 0.04291384 0.07860292]


In [45]:
# Draw 800 joint samples from the fitted mean-field guide in one vectorized call:
# the plate (dim=-1) gives every guide sample site a batch dimension of size 800.
with pyro.plate(name="samples", size=800, dim=-1):
    samples = custom_guide(is_cont_africa, ruggedness, log_gdp=None)

samples.keys()

dict_keys(['a', 'b_a', 'b_r', 'b_ar', 'sigma'])

In [46]:
# Slope of log GDP w.r.t. ruggedness: b_r outside Africa, b_r + b_ar inside Africa
# (the interaction term only contributes when cont_africa == 1).
gamma_outside_africa = samples["b_r"]
gamma_within_africa = gamma_outside_africa + samples["b_ar"]


In [58]:
fig, slope_ax = plt.subplots(figsize=(10, 6))

# One density histogram (+KDE) per group on the same axes; color=None keeps seaborn's default.
slope_hists = (
    (gamma_within_africa, "African nations", None),
    (gamma_outside_africa, "Non-African nations", "orange"),
)
for gamma, label, color in slope_hists:
    sns.histplot(gamma.detach().cpu().numpy(), ax=slope_ax, kde=True, stat="density",
                 label=label, color=color)

fig.suptitle("Density of Slope : log(GDP) vs. Terrain Ruggedness")
slope_ax.set_xlabel("Slope of regression line")
slope_ax.legend()
fig.savefig("rugged_coefficients.png")

In [71]:
# Posterior predictive: latents are drawn from the fitted custom guide, then "obs" is
# re-sampled from the model (log_gdp=None leaves it unobserved).
# NOTE: `svi_samples` is rebound later in the notebook to a dict of NumPy guide draws.
predictive = pyro.infer.Predictive(
    simple_Bayesian_model,
    guide=custom_guide,
    num_samples=800,
)
svi_samples = predictive(is_cont_africa, ruggedness, log_gdp=None)
svi_gdp = svi_samples["obs"]

In [59]:
def _order_stat(draws, q):
    """Empirical q-quantile of `draws` along dim 0, returned as a NumPy array.

    Takes the k-th smallest draw with k = int(len(draws) * q). `torch.kthvalue`
    counts k from 1, so k is clamped to at least 1; otherwise small sample counts
    (fewer than 20 draws for q=0.05) would make kthvalue raise.
    """
    k = max(1, int(len(draws) * q))
    return draws.kthvalue(k, dim=0)[0].detach().cpu().numpy()


# One row per nation: inputs, posterior-predictive mean, 5th/95th percentiles, observed value.
predictions = pd.DataFrame({
    "cont_africa": is_cont_africa.cpu().numpy(),
    "rugged": ruggedness.cpu().numpy(),
    "y_mean": svi_gdp.mean(0).detach().cpu().numpy(),
    "y_perc_5": _order_stat(svi_gdp, 0.05),
    "y_perc_95": _order_stat(svi_gdp, 0.95),
    "true_gdp": log_gdp.cpu().numpy(),
})
african_nations = predictions[predictions["cont_africa"] == 1].sort_values(by=["rugged"])
non_african_nations = predictions[predictions["cont_africa"] == 0].sort_values(by=["rugged"])

fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(12, 6), sharey=True)
fig.suptitle("Posterior predictive distribution with 90% CI", fontsize=16)

# Identical layout for both panels: mean line, 5-95% band, observed points.
for panel, nations, title in zip(ax,
                                 (non_african_nations, african_nations),
                                 ("Non African Nations", "African Nations")):
    panel.plot(nations["rugged"], nations["y_mean"])
    panel.fill_between(nations["rugged"], nations["y_perc_5"], nations["y_perc_95"], alpha=0.5)
    panel.plot(nations["rugged"], nations["true_gdp"], "o")
    panel.set(xlabel="Terrain Ruggedness Index", ylabel="log GDP (2000)", title=title)
fig.savefig("rugged_predictive.png")

In [61]:
# NOTE(review): mean-field AutoNormal guide. It is not used in any of the visible cells
# (training below uses `mvn_guide`); kept for reference, consider removing it.
auto_guide = pyro.infer.autoguide.AutoNormal(simple_Bayesian_model)

In [None]:
# Full-rank guide: one multivariate Normal over the model's latent sites, so it can
# represent posterior correlations between coefficients. It is trained in the next cell.
# NOTE: this cell shows `In [None]` (not executed in the saved run) although later cells
# use `mvn_guide`, which points to out-of-order execution. Run it before training.
mvn_guide = pyro.infer.autoguide.AutoMultivariateNormal(simple_Bayesian_model)

In [64]:
%%time
pyro.clear_param_store()
svi = pyro.infer.SVI(simple_Bayesian_model,
                     mvn_guide,
                     pyro.optim.Adam({"lr": 0.02}),
                     pyro.infer.Trace_ELBO())

losses = []
for step in range(1000 if not smoke_test else 2):
    loss = svi.step(is_cont_africa, ruggedness, log_gdp)
    losses.append(loss)
    if step % 100 == 0:
        logging.info("Elbo loss: {}".format(loss))



Elbo loss: 248.77821505069733
Elbo loss: 247.36815518140793
Elbo loss: 246.61216068267822
Elbo loss: 247.67040252685547
Elbo loss: 249.9051572084427
Elbo loss: 249.81669574975967
Elbo loss: 248.0096670985222
Elbo loss: 248.1366507411003
Elbo loss: 247.68013435602188
Elbo loss: 248.20837092399597


CPU times: user 22.2 s, sys: 162 ms, total: 22.3 s
Wall time: 3.73 s


In [65]:
# Training curve of the full-rank (AutoMultivariateNormal) run.
mvn_elbo_fig, mvn_elbo_ax = plt.subplots(figsize=(5, 2))
mvn_elbo_ax.plot(losses)
mvn_elbo_ax.set(xlabel="SVI step", ylabel="ELBO loss")
mvn_elbo_fig.savefig("rugged_MVN_ELBO.png")


In [68]:
# Draw 800 joint samples from the fitted full-rank guide in one vectorized call
# (plate dim=-1 gives each latent site a batch dimension of size 800).
with pyro.plate(name="samples", size=800, dim=-1):
    mvn_samples = mvn_guide(is_cont_africa, ruggedness)

print(mvn_samples.keys())


dict_keys(['a', 'bA', 'bR', 'bAR', 'sigma'])


In [69]:
# Ruggedness slope under the full-rank guide: bR outside Africa, bR + bAR inside.
mvn_gamma_outside_africa = mvn_samples["bR"]
mvn_gamma_within_africa = mvn_gamma_outside_africa + mvn_samples["bAR"]

# Interface note: reuse guide samples for prediction by passing them to Predictive
# via the posterior_samples keyword argument instead of passing the guide.
# The draws must not contain "obs"; presumably it would otherwise be held fixed
# rather than re-sampled.
assert "obs" not in mvn_samples
mvn_predictive = pyro.infer.Predictive(
    model=simple_Bayesian_model,
    posterior_samples=mvn_samples,
)
mvn_predictive_samples = mvn_predictive(is_cont_africa, ruggedness, log_gdp=None)
mvn_gdp = mvn_predictive_samples["obs"]

In [74]:
# NumPy copies of both guides' draws for plotting. Note: this rebinds `svi_samples`,
# which earlier held the Predictive output.
svi_samples, svi_mvn_samples = (
    {site: draws.detach().cpu().numpy() for site, draws in guide_draws.items()}
    for guide_draws in (samples, mvn_samples)
)


In [82]:
fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(12, 6), sharey=True)

# bA vs bR under each guide; both panels use the mean-field names as axis labels.
scatter_panels = (
    (svi_samples, "b_a", "b_r", "Mean field guide"),
    (svi_mvn_samples, "bA", "bR", "Full rank guide"),
)
for panel, (draws, x_key, y_key, title) in zip(ax, scatter_panels):
    panel.scatter(draws[x_key], draws[y_key])
    panel.set(xlabel="b_a", ylabel="b_r", title=title)
fig.savefig("rugged_autoguide_comp.png")

In [84]:
fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(12, 6), sharey=True)

# bA vs bAR under each guide; both panels use the mean-field names as axis labels.
scatter_panels = (
    (svi_samples, "b_a", "b_ar", "Mean field guide"),
    (svi_mvn_samples, "bA", "bAR", "Full rank guide"),
)
for panel, (draws, x_key, y_key, title) in zip(ax, scatter_panels):
    panel.scatter(draws[x_key], draws[y_key])
    panel.set(xlabel="b_a", ylabel="b_ar", title=title)
fig.savefig("rugged_autoguide_comp2.png")

In [88]:
fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(12, 6), sharey=True)

# bR vs sigma under each guide; both panels use the mean-field names as axis labels.
scatter_panels = (
    (svi_samples, "b_r", "sigma", "Mean field guide"),
    (svi_mvn_samples, "bR", "sigma", "Full rank guide"),
)
for panel, (draws, x_key, y_key, title) in zip(ax, scatter_panels):
    panel.scatter(draws[x_key], draws[y_key])
    panel.set(xlabel="b_r", ylabel="sigma", title=title)
fig.savefig("rugged_autoguide_comp3.png")


In [89]:
fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(12, 6), sharey=True)

# bA vs sigma under each guide; both panels use the mean-field names as axis labels.
scatter_panels = (
    (svi_samples, "b_a", "sigma", "Mean field guide"),
    (svi_mvn_samples, "bA", "sigma", "Full rank guide"),
)
for panel, (draws, x_key, y_key, title) in zip(ax, scatter_panels):
    panel.scatter(draws[x_key], draws[y_key])
    panel.set(xlabel="b_a", ylabel="sigma", title=title)
fig.savefig("rugged_autoguide_comp4.png")


In [91]:
fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(12, 6))
fig.suptitle("Density of Slope : log(GDP) vs. Terrain Ruggedness")

# Left: mean-field guide, right: full-rank guide; identical limits keep them comparable.
slope_panels = (
    ("Mean field", gamma_within_africa, gamma_outside_africa),
    ("Full rank", mvn_gamma_within_africa, mvn_gamma_outside_africa),
)
for panel, (title, within, outside) in zip(axs, slope_panels):
    sns.histplot(within.detach().cpu().numpy(), ax=panel, kde=True, stat="density", label="African nations")
    sns.histplot(outside.detach().cpu().numpy(), ax=panel, kde=True, stat="density", color="orange", label="Non-African nations")
    panel.set(title=title, xlabel="Slope of regression line", xlim=(-0.6, 0.6), ylim=(0, 11))

# A single figure-level legend, built from the right panel's labelled artists.
handles, labels = axs[1].get_legend_handles_labels()
fig.legend(handles, labels, loc='upper right')
fig.savefig("rugged_mvn_hist.png")

In [92]:
# NOTE(review): interactive IPython help lookup left over from exploration. It only
# prints the Trace_ELBO signature/docstring shown below; consider removing it.
? pyro.infer.Trace_ELBO

[0;31mInit signature:[0m
 [0mpyro[0m[0;34m.[0m[0minfer[0m[0;34m.[0m[0mTrace_ELBO[0m[0;34m([0m[0;34m[0m
[0;34m[0m    [0mnum_particles[0m[0;34m=[0m[0;36m1[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mmax_plate_nesting[0m[0;34m=[0m[0minf[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mmax_iarange_nesting[0m[0;34m=[0m[0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mvectorize_particles[0m[0;34m=[0m[0;32mFalse[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mjit_options[0m[0;34m=[0m[0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mretain_graph[0m[0;34m=[0m[0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mtail_adaptive_beta[0m[0;34m=[0m[0;34m-[0m[0;36m1.0[0m[0;34m,[0m[0;34m[0m
[0;34m[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0;31mDocstring:[0m     
A trace implementation of ELBO-based SVI. The estimator is constructed
along the lines of references [1] and [2]. There are no restrictions on the
dependency structure of t