Statics methods for MC ensembles

2017 September 5
by Daniel Lakeland

Recently I've been working on problems on which Stan struggles. The typical situation is that I have a large model with many parameters, most of which represent some measurement for an individual item: for example a Public Use Microdata Area in the Census, the behavior of some individual who has been surveyed, the expression level of some gene, or what have you. Then there are a small number of parameters that describe the group-level properties of these individual measurements taken as an ensemble (a hierarchical model).

The symptoms of this struggle are that Stan adapts its step size to something very small, uses long trajectories, maxes out its treedepth, takes many iterations to get into the typical set, and then generally takes a long time to sample, and probably has divergences anyway. Often I kill the sampler before it even really adapts because I see the step size has crashed down to 1e-7 or the like and each iteration is taking tens of minutes to an hour.

Now, when we have better diagnostics for divergences, perhaps this kind of problem will be remedied somewhat. But in the meantime, solving dynamics problems is known to be hard, typically a lot harder than solving statics problems. For example, if you want to put electrons on the surface of a sphere in a configuration that minimizes total potential energy, you can do so a lot more easily computationally than you can solve the dynamics of those electrons when each one is given a velocity constrained to the sphere... Solving the dynamics problem requires you to take small timesteps and move all the electrons along complicated curves. Solving the statics problem just requires you to simultaneously perturb all the positions in an energy-minimizing way.

So, one thought I had was to solve an optimization problem on ensembles.

Here's one way to think about it: suppose p(x) is a probability density on a high dimensional space X. We want to create a sample of, say, N points in the X space; let N = 100 or so for example.

One thing we know is that if \{x_i\} is any set of points in the X space, then for K a nonstandard (infinitely large) number of steps, \{x_i'\} = \{\mathrm{RWMH}(x_i,K)\} is an ensemble in equilibrium with this p(x) distribution, where RWMH(x_i,K) means K random walk Metropolis-Hastings updates starting from the point x_i.

Furthermore, once in equilibrium, it stays in equilibrium even for K=1. Of course for K=1 it is possible that each point is rejected in the MCMC update step, and so the ensemble is also unchanged. But as both the size of the ensemble increases and K increases, this non-movement becomes dramatically improbable. In particular, we can choose something like an isotropic gaussian kernel with sufficiently small standard deviation that we get something like a 10% chance of acceptance at each point, and with 100 points, the probability that the ensemble is unmoved after 1 update per point is 0.9^100 ~ 3e-5. Of course, the distance moved may not be very large, but we will have some movement of at least some of the points. If we make K = 10 or so, we'll get some movement of most of the points.
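To make the K=1 update concrete, here is a minimal sketch in plain Python. Everything named here is an assumption for illustration: the target `log_p` is a stand-in standard normal, the kernel width of 0.5 is arbitrary, and `rwmh` and `prob_ensemble_unmoved` are hypothetical helper names. The last line just evaluates (1 - a)^N for a 10% acceptance rate and 100 points.

```python
import math
import random


def log_p(x):
    """Hypothetical target: unnormalized log density of a standard normal."""
    return -0.5 * sum(xi * xi for xi in x)


def rwmh(x, k, log_density, step, rng):
    """Apply k random walk Metropolis updates to one point; return the new point."""
    current = list(x)
    current_lp = log_density(current)
    for _ in range(k):
        # Isotropic Gaussian proposal centered on the current point.
        proposal = [xi + rng.gauss(0.0, step) for xi in current]
        proposal_lp = log_density(proposal)
        # Symmetric proposal, so the acceptance ratio is p(proposal) / p(current).
        if math.log(1.0 - rng.random()) < proposal_lp - current_lp:
            current, current_lp = proposal, proposal_lp
    return current


def prob_ensemble_unmoved(accept_rate, n_points):
    """Chance that every one of n_points single updates is rejected."""
    return (1.0 - accept_rate) ** n_points


rng = random.Random(0)
dim, n_points = 20, 100
# Start the ensemble in equilibrium by drawing it from the stand-in target.
ensemble = [[rng.gauss(0.0, 1.0) for _ in range(dim)] for _ in range(n_points)]
updated = [rwmh(x, 1, log_p, 0.5, rng) for x in ensemble]
n_moved = sum(1 for old, new in zip(ensemble, updated) if old != new)
print("points moved after K=1:", n_moved)
print("P(unmoved) at 10% acceptance:", prob_ensemble_unmoved(0.10, 100))
```

The acceptance rate you actually get depends on the kernel width and the dimension, so the 10% figure is something to tune for, not something this sketch guarantees.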

However, we want to guarantee that we do get large movement relative to our usually terrible starting point. One way to get movement is to just continue moving each particle along a line:

For each particle:

x(t) = x(0) + t (RWMH(x(0),K) - x(0))

When mapped across the ensemble, this can be thought of as transport of the ensemble in the direction of an acceptable random perturbation (acceptable because the RWMH accepted that update).
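The transport rule is just linear interpolation and extrapolation per particle. A small sketch, with made-up start and end points standing in for x(0) and RWMH(x(0),K), and hypothetical helper names `transport` and `transport_ensemble`:

```python
def transport(x0, x1, t):
    """Point on the line through x0 (t=0) and the accepted RWMH point x1 (t=1)."""
    return [a + t * (b - a) for a, b in zip(x0, x1)]


def transport_ensemble(starts, ends, t):
    """Move every particle the same parameter distance t along its own line."""
    return [transport(x0, x1, t) for x0, x1 in zip(starts, ends)]


# Hypothetical two-particle ensemble in 2D with made-up RWMH end points.
starts = [[0.0, 0.0], [1.0, 1.0]]
ends = [[0.5, -0.5], [1.0, 2.0]]
print(transport_ensemble(starts, ends, 0.0))  # reproduces starts
print(transport_ensemble(starts, ends, 1.0))  # reproduces ends
print(transport_ensemble(starts, ends, 3.0))  # extrapolates past the RWMH points
```

Note that a particle whose RWMH updates were all rejected has x1 equal to x0 and so does not move for any t.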

How far should we allow this transport to go? Obviously if t goes to infinity, some of the particles go to infinite distances and since our probability density is normalizable this means some of the particles go well outside the high probability region.

One particular expectation that is of interest is the ensemble entropy, the sample average of -\log(p(x_i)), which estimates the entropy of p when the x_i are drawn from p:

\frac{1}{N} \sum_{i=1}^N -\log(p(x_i))

This has the advantage that the function f(x) = -\log(p(x)) has domain equal to the domain of the density, and is therefore sensitive to all the regions of space. Furthermore, because p(x) goes to zero far from the high probability region, the function -\log(p(x)) grows without bound there, and so as t increases the entropy of the ensemble eventually increases.

The result is that if the ensemble is out of equilibrium, then as a function of increasing t the ensemble entropy should first decrease (because the RWMH accepts higher probability transitions and many of the transitions will be towards higher probability), then after a while become somewhat constant, and then, as t goes on, it has to increase again. The region of space near the entropy minimum is what we want. Yet, we don't want to just send all the points to the mode... that would be the lowest entropy ensemble possible.
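Here is a rough sketch of that decrease-then-increase shape. It assumes a standard normal stand-in target, a deliberately bad start near (5, ..., 5), and takes the ensemble entropy to be the plain sample average of -log p(x_i); all function names and tuning values are hypothetical.

```python
import math
import random


def neg_log_p(x):
    """-log p(x) for a standard normal in len(x) dimensions (stand-in target)."""
    return 0.5 * sum(xi * xi for xi in x) + 0.5 * len(x) * math.log(2.0 * math.pi)


def rwmh(x, k, step, rng):
    """k random walk Metropolis updates with an isotropic Gaussian kernel."""
    current, current_e = list(x), neg_log_p(x)
    for _ in range(k):
        proposal = [xi + rng.gauss(0.0, step) for xi in current]
        proposal_e = neg_log_p(proposal)
        if math.log(1.0 - rng.random()) < current_e - proposal_e:
            current, current_e = proposal, proposal_e
    return current


def ensemble_entropy(points):
    """Sample average of -log p over the ensemble."""
    return sum(neg_log_p(x) for x in points) / len(points)


def entropy_along_path(starts, ends, t):
    """Ensemble entropy after transporting every particle to parameter t."""
    moved = [[a + t * (b - a) for a, b in zip(x0, x1)] for x0, x1 in zip(starts, ends)]
    return ensemble_entropy(moved)


rng = random.Random(1)
dim, n_points = 10, 100
# Deliberately bad start: every particle sits near (5, 5, ..., 5), far from the mode.
starts = [[5.0 + rng.gauss(0.0, 0.1) for _ in range(dim)] for _ in range(n_points)]
ends = [rwmh(x, 10, 0.3, rng) for x in starts]

ts = [0.0, 0.5, 1.0, 2.0, 4.0, 8.0, 16.0, 32.0, 64.0, 128.0, 256.0, 512.0, 1024.0]
curve = [entropy_along_path(starts, ends, t) for t in ts]
for t, e in zip(ts, curve):
    print(f"t = {t:7.1f}   ensemble entropy = {e:12.3f}")
```

For this Gaussian stand-in the curve is exactly quadratic in t, so it has a single minimum; for a general p(x) the shape near the minimum is only approximately parabolic.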

This whole concept is related to concentration of measure. The entropy of an ensemble of 100 randomly chosen iid points is a function of 100 random variables and is therefore more or less constant under different realizations. The region of constant entropy is called the typical set and contains essentially 100% of the probability mass.
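A quick numerical check of the concentration claim, again assuming a standard normal stand-in (here in 50 dimensions) and the sample average of -log p(x_i) as the ensemble entropy. The exact entropy of a d-dimensional standard normal, (d/2)(1 + log(2 pi)), is used as the reference value.

```python
import math
import random


def neg_log_p(x):
    """-log p(x) for a standard normal in len(x) dimensions (stand-in target)."""
    return 0.5 * sum(xi * xi for xi in x) + 0.5 * len(x) * math.log(2.0 * math.pi)


def ensemble_entropy(points):
    """Sample average of -log p over the ensemble."""
    return sum(neg_log_p(x) for x in points) / len(points)


rng = random.Random(2)
dim, n_points, n_realizations = 50, 100, 20
exact_entropy = 0.5 * dim * (1.0 + math.log(2.0 * math.pi))

# Draw several independent iid ensembles and record each one's entropy.
estimates = []
for _ in range(n_realizations):
    ensemble = [[rng.gauss(0.0, 1.0) for _ in range(dim)] for _ in range(n_points)]
    estimates.append(ensemble_entropy(ensemble))

# An ensemble collapsed onto the mode has a much lower value of the same statistic.
mode_value = ensemble_entropy([[0.0] * dim for _ in range(n_points)])

print("exact entropy:    ", round(exact_entropy, 3))
print("smallest estimate:", round(min(estimates), 3))
print("largest estimate: ", round(max(estimates), 3))
print("all points at mode:", round(mode_value, 3))
```

The spread between the smallest and largest estimates should be small compared with the gap down to the all-at-the-mode value, which is the sense in which an ensemble at the mode is far from typical.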

Now, suppose that \{x_i\} starts in equilibrium. Then, as t increases, ensemble entropy may decrease or increase initially, but the change will not be dramatic, and as t increases through time eventually entropy increases again. Since p(x) is typically smooth, the ensemble entropy as a function of t is also smooth. If it comes to a minimum somewhere, then in the vicinity of that minimum it acts like a parabola. How far should we go? Suppose that the ensemble entropy has a central limit type theorem such that when the ensemble size is large enough the ensemble entropy is distributed according to Normal(e,s(N)) for e the actual entropy of the distribution, and s some decreasing function of ensemble size N. This manifests as the trajectory having a parabolic minimum entropy. If we sample t with weights according to \exp(-e(x(t))) where e is the ensemble entropy then we should come into equilibrium (intuitively, I have no idea if this works in reality). Note that initially we are going to always go forward in time, but at equilibrium we need to project t both forward and backward.
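One way the weighted choice of t could be sketched is on a discrete grid that extends both backward and forward. This is only an illustration of the idea, not a validated sampler: the grid, the standard normal stand-in target, the kernel width, and the helper names `entropy_at` and `sample_t` are all assumptions, and the ensemble entropy is taken as the sample average of -log p(x_i).

```python
import math
import random


def neg_log_p(x):
    """-log p(x) for a standard normal in len(x) dimensions (stand-in target)."""
    return 0.5 * sum(xi * xi for xi in x) + 0.5 * len(x) * math.log(2.0 * math.pi)


def rwmh(x, k, step, rng):
    """k random walk Metropolis updates with an isotropic Gaussian kernel."""
    current, current_e = list(x), neg_log_p(x)
    for _ in range(k):
        proposal = [xi + rng.gauss(0.0, step) for xi in current]
        proposal_e = neg_log_p(proposal)
        if math.log(1.0 - rng.random()) < current_e - proposal_e:
            current, current_e = proposal, proposal_e
    return current


def entropy_at(starts, ends, t):
    """Ensemble entropy e(x(t)) after transporting every particle to parameter t."""
    total = 0.0
    for x0, x1 in zip(starts, ends):
        total += neg_log_p([a + t * (b - a) for a, b in zip(x0, x1)])
    return total / len(starts)


def sample_t(starts, ends, t_grid, rng):
    """Pick t from t_grid with probability proportional to exp(-e(x(t)))."""
    energies = [entropy_at(starts, ends, t) for t in t_grid]
    e_min = min(energies)
    # Subtract the minimum before exponentiating; this rescales all weights
    # by the same factor and keeps the largest weight at exactly 1.
    weights = [math.exp(e_min - e) for e in energies]
    chosen = rng.choices(t_grid, weights=weights, k=1)[0]
    return chosen, weights


rng = random.Random(3)
dim, n_points = 10, 100
# Start in equilibrium: iid draws from the stand-in target.
starts = [[rng.gauss(0.0, 1.0) for _ in range(dim)] for _ in range(n_points)]
ends = [rwmh(x, 10, 0.3, rng) for x in starts]

t_grid = [0.5 * i for i in range(-8, 9)]  # t from -4 to 4, both directions
chosen_t, weights = sample_t(starts, ends, t_grid, rng)
print("chosen t:", chosen_t)
```

Whether repeating this step actually leaves the ensemble in equilibrium is exactly the open question in the text; the sketch only shows the mechanics of the weighting.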

Some questions: why doesn't this collapse to finding the mode?

The reason is that at equilibrium each individual random walk metropolis step goes either up or down in the distribution equally, so some points are heading into the tail, and others into the core. If we get close to the mode, the RWMH has more ways to go away from the mode and so we will get directions that mostly point away from it. If we get too far from the mode, RWMH will tend to ratchet upwards and give us a direction back into the typical set.

Is this any better than random walk?

I don't know. Intuitively, we're using a random walk to get a random perturbation direction that at equilibrium leaves the ensemble in equilibrium. So in the direction of this perturbation, we're moving along the typical set manifold. We do so in a way that more or less equally weights all the points we visited that were in the typical set (i.e. if the entropy is constant along our curve, then we are equally likely to go to any of those points). It seems likely that this devolves to a random walk when the typical set is highly curved, and works more like HMC when the typical set is elongated. Now, Stan detects when there is a lot of curvature and drops the step size, but it then keeps the same step size for all the trajectories, so the computation required is based on the worst case. Perhaps this other technique for following the typical set as an ensemble is more robust, allowing a larger step size when it's warranted? Also, it doesn't require calculating gradients. In some sense, the RWMH step finds us a direction which is pointed along the typical set, and so has zero ensemble entropy gradient on average.

The nice thing is, I have been working on some example problems in Julia, so I think I have most of the requirements in place to test this on some example problems. I'll report back.

 

One Response
  1. ojm
    September 7, 2017

    Do you have a 'minimal' working example of a problem that Stan struggles with?

    I'm interested in collecting a few of these examples and trying out some alternatives.
