In [6]:
import numpy as np
import scipy.linalg as la
import matplotlib.pyplot as plt

## Felsenstein's pruning algorithm

We assume a continuous time Markov chain (CTMC) operating along the branches of the tree. The CTMC is defined by a rate matrix $Q$.

In this example, we will use the [Jukes-Cantor model](https://en.wikipedia.org/wiki/Models_of_DNA_evolution#JC69_model_(Jukes_and_Cantor_1969)) on DNA nucleotides $\mathcal{B} = \{A, C, G, T\}$.

This means that the rate matrix is defined by the instantaneous rate of mutation: $Q_{ij} = \mu/4$ for $i \neq j$ and some $\mu > 0$, and $Q_{ii} = -\sum_{j \neq i} Q_{ij}$, which in this case yields $Q_{ii} = -3 \mu /4$.

We have a length associated with each branch of the tree. We will focus on computing a likelihood table for node $u$ with children $v, w$ (a cherry). The branch lengths for edge $(u,v)$ and $(u,w)$ are denoted $t_{uv}, t_{uw}$.

To obtain the transition matrix along an edge with branch length $t$, we perform matrix exponentiation:

$$P(t) = \exp(Qt).$$

The stationary distribution (or equilibrium) of the CTMC is denoted by $\pi$. It is the solution of $\pi Q = 0$ subject to $\sum_{j \in \mathcal{B}} \pi_j = 1$.

$\pi$ is a vector with 4 components, one for each DNA nucleotide base: $\pi = (\pi_A, \pi_C, \pi_G, \pi_T)$. Under the Jukes-Cantor model, $\pi_A = \pi_C = \pi_G = \pi_T = 0.25$.

It is also helpful to think of the stationary distribution in terms of $P(\infty) = \lim_{t \to \infty} \exp(Qt)$. Note that $P(\infty)$ is a matrix in which every row equals $\pi$.

In evolutionary biology, it is common (and reasonable) to assume that the process has run for a very long time and has reached stationarity, so the state at the root is distributed according to $\pi$.

In [7]:
mu = 1.2
Q = (mu/4) * np.ones((4,4))
np.fill_diagonal(Q, -3*mu/4)
print(Q.sum(1))
print(Q)
pi = np.ones(4)/4
print(pi)

[1.11022302e-16 1.11022302e-16 5.55111512e-17 0.00000000e+00]
[[-0.9  0.3  0.3  0.3]
 [ 0.3 -0.9  0.3  0.3]
 [ 0.3  0.3 -0.9  0.3]
 [ 0.3  0.3  0.3 -0.9]]
[0.25 0.25 0.25 0.25]
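
The two characterizations of $\pi$ given earlier can be checked numerically. The following is a minimal sketch that rebuilds $Q$ so it runs on its own; `t_large = 50.0` is an arbitrary illustrative value standing in for $t \to \infty$, and the names `residual` and `P_inf` are introduced here only for illustration.

```python
import numpy as np
import scipy.linalg as la

mu = 1.2
Q = (mu / 4) * np.ones((4, 4))
np.fill_diagonal(Q, -3 * mu / 4)
pi = np.ones(4) / 4

# pi is stationary if pi Q = 0 (up to floating-point error).
residual = pi @ Q

# For large t, every row of exp(Qt) should approach pi.
# t_large is an arbitrary stand-in for "t -> infinity".
t_large = 50.0
P_inf = la.expm(Q * t_large)

print(np.allclose(residual, 0.0))
print(np.allclose(P_inf, np.tile(pi, (4, 1))))
```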


In [8]:
t_uv = 1.2
t_uw = 0.8

In [9]:
P_uv = la.expm(Q*t_uv)
P_uw = la.expm(Q*t_uw)
print(P_uv)
print(P_uw)

[[0.42769582 0.19076806 0.19076806 0.19076806]
 [0.19076806 0.42769582 0.19076806 0.19076806]
 [0.19076806 0.19076806 0.42769582 0.19076806]
 [0.19076806 0.19076806 0.19076806 0.42769582]]
[[0.53716966 0.15427678 0.15427678 0.15427678]
 [0.15427678 0.53716966 0.15427678 0.15427678]
 [0.15427678 0.15427678 0.53716966 0.15427678]
 [0.15427678 0.15427678 0.15427678 0.53716966]]


In [10]:
print(P_uv.sum(1))
print(P_uw.sum(1))

[1. 1. 1. 1.]
[1. 1. 1. 1.]
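
As a cross-check on the matrix exponential, the Jukes-Cantor model with this parameterization has a closed form: $P_{ii}(t) = \tfrac{1}{4} + \tfrac{3}{4} e^{-\mu t}$ and $P_{ij}(t) = \tfrac{1}{4} - \tfrac{1}{4} e^{-\mu t}$ for $i \neq j$. The sketch below compares it against `la.expm`; the helper name `jc69_transition` is hypothetical and not part of the notebook.

```python
import numpy as np
import scipy.linalg as la

def jc69_transition(mu, t):
    """Closed-form Jukes-Cantor transition matrix for rate mu and branch length t."""
    decay = np.exp(-mu * t)
    same = 0.25 + 0.75 * decay   # probability of ending in the starting base
    diff = 0.25 - 0.25 * decay   # probability of ending in one specific other base
    P = np.full((4, 4), diff)
    np.fill_diagonal(P, same)
    return P

mu = 1.2
Q = (mu / 4) * np.ones((4, 4))
np.fill_diagonal(Q, -3 * mu / 4)

# Compare against the numerical matrix exponential for both branch lengths.
for t in (1.2, 0.8):
    print(t, np.allclose(jc69_transition(mu, t), la.expm(Q * t)))
```

The closed form avoids a general-purpose matrix exponential, which may matter when many branch lengths are evaluated; for other rate matrices `la.expm` remains the generic route.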


## Base case (initialization)

To initialize Felsenstein's pruning algorithm, we fill out the likelihood table for each of the observed nodes.

Let $Y_u, Y_v, Y_w$ denote the sequences at nodes $u,v,w$.

Here, $Y_u$ is unobserved while $Y_v, Y_w$ are observed. So, the likelihood tables for $v,w$ can be filled out by indicating the nucleotide that was observed at each site (locus): the entry for the observed base is 1 and all other entries are 0.


In [11]:
y_v = "AAACCGTCA"
y_w = "AACCCGTCT"
L = len(y_v)
L

9

In [12]:
D_v = np.zeros((4, L))
D_w = np.zeros((4, L))
D_u = np.zeros((4, L))

In [13]:
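def ch_to_idx(ch):
    """Map a nucleotide character to its row index (A, C, G, T -> 0, 1, 2, 3)."""
    if ch == "A":
        return 0
    elif ch == "C":
        return 1
    elif ch == "G":
        return 2
    elif ch == "T":
        return 3
    # Fail loudly instead of silently returning None for an unexpected character.
    raise ValueError(f"unexpected nucleotide: {ch!r}")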

In [14]:
y_v[1]

'A'

In [15]:
# One-hot encode the observed base at every site of both leaves.
for l in range(L):
    i = ch_to_idx(y_v[l])
    D_v[i, l] = 1
    i = ch_to_idx(y_w[l])
    D_w[i, l] = 1

In [16]:
print(y_v)
print(D_v)

AAACCGTCA
[[1. 1. 1. 0. 0. 0. 0. 0. 1.]
 [0. 0. 0. 1. 1. 0. 0. 1. 0.]
 [0. 0. 0. 0. 0. 1. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 1. 0. 0.]]


In [17]:
print(y_w)
print(D_w)


AACCCGTCT
[[1. 1. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 1. 1. 1. 0. 0. 1. 0.]
 [0. 0. 0. 0. 0. 1. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 1. 0. 1.]]
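
The same leaf tables can be built without an explicit loop over sites. This is a sketch of one possible vectorized alternative, not the notebook's method; `ALPHABET` and `one_hot_table` are hypothetical names, and the row order A, C, G, T is assumed to match `ch_to_idx`.

```python
import numpy as np

ALPHABET = "ACGT"  # hypothetical constant; order assumed to match ch_to_idx

def one_hot_table(seq):
    """Return a 4 x len(seq) table with a 1 at the observed base of each site."""
    # str.index raises ValueError if a character is not in ALPHABET.
    idx = np.array([ALPHABET.index(ch) for ch in seq])
    # Column l of the result is the idx[l]-th standard basis vector.
    return np.eye(4)[:, idx]

D_v = one_hot_table("AAACCGTCA")
D_w = one_hot_table("AACCCGTCT")
print(D_v)
print(D_w)
```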


## Pruning recursion

The likelihood table for $u$ is to be filled in using the likelihood tables for $v, w$.

Let $Y_{\lfloor u \rfloor}$ denote the set of observed sequences below node $u$. In the case with a cherry, it's just $Y_{\lfloor u \rfloor} = (Y_v, Y_w)$.

Note that the sequence $Y_u$ at node $u$ is unobserved, so we marginalize over its values to compute the marginal likelihood of the observations, $p(Y_{\lfloor u \rfloor})$.

The likelihood table for $u$ at site $l$ is defined as

$$D_u^l[i] = p(Y_{\lfloor u \rfloor}^l | Y_u^l = i),$$

for $i \in \mathcal{B}$. We can derive this in terms of the likelihood tables for $v,w$,

\begin{align}
    p(Y_{\lfloor u \rfloor}^l | Y_u^l = i) &= \left(\sum_j p(Y_v^l = j | Y_u^l = i) p(Y_{\lfloor v \rfloor}^l | Y_v^l = j)\right) \times \\
    & \left(\sum_j p(Y_w^l = j | Y_u^l = i) p(Y_{\lfloor w \rfloor}^l | Y_w^l = j)\right).
\end{align}

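Since $p(Y_v^l = j | Y_u^l = i) = P(t_{uv})_{ij}$ and $p(Y_{\lfloor v \rfloor}^l | Y_v^l = j) = D_v^l[j]$, each sum is a matrix-vector product. In matrix form, with $\odot$ denoting the elementwise product:

$$D_u = (P(t_{uv}) D_v) \odot (P(t_{uw}) D_w).$$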

To obtain the marginal likelihood at site $l$, we marginalize over the root state by taking a dot product with the stationary distribution:

$$p(Y_{\lfloor u \rfloor}^l) = \sum_{j \in \mathcal{B}} p(Y_u^l = j) D_u[j,l] = (\pi^T D_u)_l.$$

Finally, we assume that sites evolve independently, so the marginal likelihood of the full sequences is the product of the per-site marginal likelihoods:

$$p(Y_{\lfloor u \rfloor}) = \prod_l p(Y_{\lfloor u \rfloor}^l).$$

In [18]:
D_u = (P_uv @ D_v) * (P_uw @ D_w)

In [23]:
log_lik = np.sum(np.log(pi.T @ D_u))
print(log_lik)

-23.458518484870616
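
To sanity-check the pruning step, the sketch below wraps it in a function and compares it with a brute-force sum over the root state at each site. It is self-contained and reuses the values from this notebook; the function names `one_hot`, `cherry_log_lik`, and `brute_force_log_lik` are hypothetical and introduced only for this check.

```python
import numpy as np
import scipy.linalg as la

BASES = "ACGT"

def one_hot(seq):
    """4 x len(seq) indicator table of the observed bases."""
    return np.eye(4)[:, [BASES.index(ch) for ch in seq]]

def cherry_log_lik(y_v, y_w, Q, pi, t_uv, t_uw):
    """Log-likelihood of two leaf sequences under a cherry rooted at u."""
    P_uv, P_uw = la.expm(Q * t_uv), la.expm(Q * t_uw)
    D_u = (P_uv @ one_hot(y_v)) * (P_uw @ one_hot(y_w))  # pruning step
    return np.sum(np.log(pi @ D_u))  # marginalize the root, then combine sites

def brute_force_log_lik(y_v, y_w, Q, pi, t_uv, t_uw):
    """Same quantity, summing explicitly over the root state at each site."""
    P_uv, P_uw = la.expm(Q * t_uv), la.expm(Q * t_uw)
    total = 0.0
    for a, b in zip(y_v, y_w):
        j, k = BASES.index(a), BASES.index(b)
        site = sum(pi[i] * P_uv[i, j] * P_uw[i, k] for i in range(4))
        total += np.log(site)
    return total

mu = 1.2
Q = (mu / 4) * np.ones((4, 4))
np.fill_diagonal(Q, -3 * mu / 4)
pi = np.ones(4) / 4

log_lik = cherry_log_lik("AAACCGTCA", "AACCCGTCT", Q, pi, 1.2, 0.8)
brute = brute_force_log_lik("AAACCGTCA", "AACCCGTCT", Q, pi, 1.2, 0.8)
print(log_lik, brute)
```

Summing log-likelihoods per site, rather than multiplying raw probabilities, keeps the result from underflowing on longer sequences.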


## Top-down pass

We will first convert the phylogenetic tree $G=(V, E)$ into a factor graph and then derive the message passing algorithm.

### Factor graph construction

The factor graph is an undirected bipartite graph $G'=(V, F, E')$ that comprises two distinct sets of nodes: $V$ denotes variable nodes and $F$ denotes factor nodes. In this graph, nodes of the same type are never connected; every edge $e \in E'$ links a variable node to a factor node.

1. Variable nodes: the factor graph inherits its variable nodes from the original graph $G$.

2. Factor nodes: for each edge $e = (u,v) \in E$, we create a factor node $f_{uv}$.

3. Edges: for each factor $f_{uv}$, add the edges $\{u, f_{uv}\}$ and $\{f_{uv}, v\}$ to $E'$ (the graph is undirected, so the ordering within an edge does not matter, as the set notation indicates). Each of these factors encodes the conditional probability $f_{uv}(y_u, y_v) = p(Y_v = y_v | Y_u = y_u)$, where $u$ is the parent of $v$.

Additionally, for each leaf node $w \in V$ we create a unary factor $f_w$ and add the edge $\{w, f_w\}$ to $E'$. These factors are indicator functions, $f_w(i) = 1[y_w = i]$, which clamp the leaf to its observed value $y_w$ (recall that we observe sequences at the leaf nodes).
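
The construction can be sketched as a small data structure. This is one possible representation under stated assumptions, not a prescribed one: `build_factor_graph`, the string factor names such as `"f_uv"`, and the dictionary layout are all hypothetical choices made for illustration.

```python
from collections import defaultdict

def build_factor_graph(tree_edges, leaves):
    """Build the bipartite factor graph of a rooted tree.

    tree_edges: list of (parent, child) pairs; leaves: iterable of leaf names.
    Returns (factors, neighbors): factors maps a factor name to the tuple of
    variables it touches; neighbors maps every node to its set of neighbors.
    """
    factors = {}
    for parent, child in tree_edges:
        factors[f"f_{parent}{child}"] = (parent, child)  # pairwise factor per tree edge
    for leaf in leaves:
        factors[f"f_{leaf}"] = (leaf,)                   # unary indicator factor per leaf

    # Edges only ever join a variable node and a factor node.
    neighbors = defaultdict(set)
    for name, variables in factors.items():
        for var in variables:
            neighbors[name].add(var)
            neighbors[var].add(name)
    return factors, dict(neighbors)

# The cherry used throughout: root u with leaf children v and w.
factors, neighbors = build_factor_graph([("u", "v"), ("u", "w")], ["v", "w"])
print(factors)
print(neighbors["u"])
```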

### Message passing

The message passing (or sum-product) algorithm on a factor graph consists of sending messages along the edges. There are two types of messages: variable to factor and factor to variable.

Let $n(v), n(f)$ denote the neighbors of a variable node $v$ and a factor node $f$ respectively. 

The message from a variable node $v$ to a factor $f$ is given by $m_{v \to f}(y_v) = \prod_{h \in n(v) \setminus \{f\}} m_{h \to v}(y_v)$. Essentially, the node collects the messages from all of its neighbors except $f$ and forwards their product to $f$.

The message from a factor node $f$ to a variable node $v$ is given by $m_{f \to v}(y_v) = \sum_{y_{n(f) \setminus \{v\}}} f(y_{n(f)}) \prod_{u \in n(f) \setminus \{v\}} m_{u \to f}(y_u)$. Essentially, we multiply the factor by its incoming messages and sum over all of the possible values taken by the variable nodes attached to $f$, except $v$.

Note that we write $f$ as a function of all of its neighboring variables $y_{n(f)}$. In practice, a pairwise factor $f_{uv}$ depends only on $y_u$ and $y_v$, and a unary factor attached to a leaf node $w$ depends only on $y_w$.
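
The two message types can be traced by hand on the cherry for a single site. The sketch below is illustrative only: it assumes observed bases A at $v$ and C at $w$, and it places a unary prior factor $\pi$ on the root $u$, which is an assumption not stated in the construction above. The variable names (`m_fuv_to_u` and so on) are hypothetical.

```python
import numpy as np
import scipy.linalg as la

mu, t_uv, t_uw = 1.2, 1.2, 0.8
Q = (mu / 4) * np.ones((4, 4))
np.fill_diagonal(Q, -3 * mu / 4)
pi = np.ones(4) / 4
f_uv = la.expm(Q * t_uv)  # f_uv[y_u, y_v] = p(Y_v = y_v | Y_u = y_u)
f_uw = la.expm(Q * t_uw)

# One site with observed bases A at v and C at w (index order A, C, G, T).
f_v = np.array([1.0, 0.0, 0.0, 0.0])
f_w = np.array([0.0, 1.0, 0.0, 0.0])

# Upward pass. A leaf's only other neighbor is its unary factor, so its
# variable-to-factor message is just the indicator.
m_v_to_fuv = f_v
m_w_to_fuw = f_w
m_fuv_to_u = f_uv @ m_v_to_fuv        # factor-to-variable: sum over y_v
m_fuw_to_u = f_uw @ m_w_to_fuw        # factor-to-variable: sum over y_w
D_u_site = m_fuv_to_u * m_fuw_to_u    # one column of the pruning table D_u

# Downward pass toward v. Assumption: a unary prior factor pi sits on u.
m_u_to_fuv = pi * m_fuw_to_u          # product of all messages into u except from f_uv
m_fuv_to_v = f_uv.T @ m_u_to_fuv      # factor-to-variable: sum over y_u

# Both routes give the same site likelihood.
site_lik = pi @ D_u_site
print(site_lik, np.sum(f_v * m_fuv_to_v))
```

Under these assumptions, the product of the upward messages arriving at $u$ coincides with the pruning recursion, and the downward message lets the same site likelihood be read off at a leaf.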
