Last time, I explained the core idea behind the machine learning algorithm known as Monte Carlo tree search, and mentioned that the core equation behind it (pUCT) was derived largely by experiment. However, there's since been progress in understanding it theoretically, and this sequel will focus on how one might derive it. I'll warn you, though: there's a lot of computation ahead before the satisfying result. (I largely follow the approach in Grill et al. (2020); for a deeper view, including generalizations to a broad family of algorithms, please have a read!)
Setup and Notation
In a slight change from last time (to something more in line with mainstream reinforcement learning literature), let's call each state (such as a board position) $x$ and each action (such as a move) $a$, and write the policy value (or just policy) of $a$ at $x$ as $\pi_{\theta}(a|x)$, where I'm using $\theta$ to indicate that the policy is a function of some neural network's parameters. Now let's say $\hat{\pi}$ is the empirical visit distribution, that is, the normalized visit counts for each action from a given state. (The use of the policy-style notation $\hat{\pi}$ is intentional: this will turn out to be the target to which we're optimizing our policy.)
$$ \hat{\pi}(a|x)=\frac{1+n(x,a)}{|\mathcal{A}|+\sum_{b}{n(x,b)}} $$
Unlike last time, we add an extra term of $|\mathcal{A}|$, the total number of actions available, to the denominator, which keeps the probabilities summing to one now that every count gets a $+1$. Moving forward, I'll adopt the shorthand $n_{a}$ for $n(x,a)$ when $x$ is clear from context, and $N(x)$ for the number of times we took any action from $x$ (that is, $\sum_{b}n_{b}$). Additionally, let's define the special constant $\lambda_{N}$ by
$$ \lambda_{N}=c\cdot\frac{\sqrt{\sum_{b}n_{b}}}{|\mathcal{A}|+\sum_{b}n_{b}} $$
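To make these two definitions concrete, here is a minimal Python sketch, assuming the visit counts for a single state are stored as a plain list indexed by action. The function names `visit_distribution` and `lambda_n` and the example counts are hypothetical, chosen only for illustration.

```python
import math


def visit_distribution(counts):
    """pi_hat(a) = (1 + n_a) / (|A| + sum_b n_b): counts with one pseudo-count per action."""
    total = len(counts) + sum(counts)
    return [(1 + n) / total for n in counts]


def lambda_n(counts, c):
    """lambda_N = c * sqrt(sum_b n_b) / (|A| + sum_b n_b)."""
    n_total = sum(counts)
    return c * math.sqrt(n_total) / (len(counts) + n_total)


counts = [3, 0, 1]                 # hypothetical visit counts for three actions
print(visit_distribution(counts))  # 4/7, 1/7, 2/7 as floats
print(lambda_n(counts, c=1.5))     # about 0.4286 (= 1.5 * sqrt(4) / 7)
```

Because of the pseudo-counts, every action keeps a nonzero share of $\hat{\pi}$ even before it has been visited, which matters later when we divide by $\hat{\pi}$.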
Now let's recall the pUCT formula from last time, which chooses the action $a^{\star}$ given by
$$ a^{\star}=\arg\max_{a}\left[Q(x,a)+c_{\text{pUCT}}\cdot\pi_{\theta}(a|x)\cdot\frac{\sqrt{\sum_{b}n_{b}}}{1+n_{a}}\right] $$
If I force $c=c_{\text{pUCT}}$, our two newly defined quantities let us rewrite this (specifically, $\lambda_{N}$ absorbs $c_{\text{pUCT}}$ and the square root, $\hat{\pi}$ supplies the $1+n_{a}$ in the denominator, and the two copies of $|\mathcal{A}|+\sum_{b}n_{b}$ cancel) as
$$ a^{\star}=\arg\max_{a}\left[Q(x,a)+\lambda_{N}\cdot\frac{\pi_{\theta}(a|x)}{\hat{\pi}(a|x)}\right] $$
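A quick numerical sanity check of this rewrite, as a Python sketch: it computes the exploration bonus both ways for every action and confirms they agree. The prior, the counts, and the value $1.25$ for $c_{\text{pUCT}}$ are arbitrary, hypothetical inputs.

```python
import math


def puct_bonus(prior, counts, a, c_puct):
    """Exploration term of pUCT: c_pUCT * pi_theta(a) * sqrt(sum_b n_b) / (1 + n_a)."""
    return c_puct * prior[a] * math.sqrt(sum(counts)) / (1 + counts[a])


def rewritten_bonus(prior, counts, a, c):
    """The same term written as lambda_N * pi_theta(a) / pi_hat(a)."""
    total = len(counts) + sum(counts)
    lam = c * math.sqrt(sum(counts)) / total
    pi_hat_a = (1 + counts[a]) / total
    return lam * prior[a] / pi_hat_a


prior = [0.5, 0.3, 0.2]  # hypothetical network policy pi_theta(.|x)
counts = [6, 2, 1]       # hypothetical visit counts n_a
for a in range(len(prior)):
    # The |A| + sum_b n_b factors cancel, so both forms agree up to rounding.
    assert math.isclose(puct_bonus(prior, counts, a, 1.25),
                        rewritten_bonus(prior, counts, a, 1.25))
```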
If we use boldface notation to form vectors over all actions, so that, say, $\mathbf{q}$ is the vector of $Q(x,a)$'s for all possible actions $a$ (with division between vectors taken elementwise), we can compactly write
$$ a^{\star}=\arg\max\left[\mathbf{q}+\lambda_{N}\frac{\boldsymbol{\pi_{\theta}}}{\boldsymbol{\hat{\pi}}}\right] $$
Seems pointless, but this will save me from writing the same thing over and over later!
Lastly, one useful quantity that will keep showing up is the Kullback-Leibler divergence between two distributions, which is defined as
$$ \mathrm{KL}[p,q]=\sum_{x}(p(x)\log p(x)-p(x)\log q(x)) $$
or the average difference between the log-probabilities of $p$ and $q$, weighted by $p$. With a little algebra, we can simplify this to the more familiar form (this is actually how KL-divergence is usually introduced, but it's less intuitive at a glance)
$$ \mathrm{KL}[p,q]=\sum_{x}p(x)\log\left(\frac{p(x)}{q(x)}\right) $$
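Both forms are easy to compare numerically. Below is a small Python sketch, assuming distributions are lists of probabilities over the same outcomes, with $q(x)>0$ wherever $p(x)>0$; terms with $p(x)=0$ are skipped, following the usual convention $0\log 0=0$. The example distributions are made up.

```python
import math


def kl_expanded(p, q):
    """KL[p, q] = sum_x (p(x) log p(x) - p(x) log q(x))."""
    return sum(px * math.log(px) - px * math.log(qx)
               for px, qx in zip(p, q) if px > 0)


def kl_ratio(p, q):
    """KL[p, q] = sum_x p(x) log(p(x) / q(x))."""
    return sum(px * math.log(px / qx) for px, qx in zip(p, q) if px > 0)


p = [0.5, 0.3, 0.2]
q = [0.4, 0.4, 0.2]
print(kl_expanded(p, q), kl_ratio(p, q))  # the two forms agree up to rounding
```

Note that the divergence is not symmetric: $\mathrm{KL}[p,q]$ and $\mathrm{KL}[q,p]$ generally differ, so the order of the arguments matters in what follows.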
A Slightly Different Optimization Problem
Let's say we're given a policy $\boldsymbol{\pi_{\theta}}$, some vector $\mathbf{q}$ of values, and an empirical visit distribution $\boldsymbol{\hat{\pi}}$. We can treat our visit distribution as a "quasi-policy" that just assesses moves based on how often we've visited them. (Of course, we can't really know this policy before searching a bit, but once we have it, it can still function as a means to assess moves.) Now suppose we'd like a policy that scores well under $\mathbf{q}$ but doesn't drift too far away from our initial policy. (Since $\boldsymbol{\hat{\pi}}$ is ultimately constructed algorithmically from $\boldsymbol{\pi_{\theta}}$, the two implicitly can't drift too far apart; this just makes that explicit.) Formally, this looks like solving the regularized policy optimization problem
$$ \bar{\boldsymbol{\pi}}=\arg\max_{\mathbf{y}}\left[\mathbf{q}^{\top}\mathbf{y}-\lambda\mathrm{KL}[\boldsymbol{\pi_{\theta}},\mathbf{y}]\right] $$
where $\mathbf{y}$ ranges over probability distributions on actions, and we're using the KL-divergence from earlier as a measure of difference between two policies. (Policies are probability distributions over actions, so this comparison is well defined.) Unlike our initial policy $\boldsymbol{\pi_{\theta}}$, $\boldsymbol{\hat{\pi}}$ keeps changing as we search more, so let's plug in $\mathbf{y}=\boldsymbol{\hat{\pi}}$ and look at how the objective would change were we to search a little bit more. (Technically, $n_{a}$ is discrete, but allowing it to take continuous values is what reveals the insight.) We'll start by expanding the inner expressions to their full forms:
$$ \begin{aligned} \frac{\partial}{\partial n_{a}}\Big(\mathbf{q}^{\top}&\boldsymbol{\hat{\pi}}-\lambda\mathrm{KL}[\boldsymbol{\pi_{\theta}},\boldsymbol{\hat{\pi}}]\Big)= \\ &\frac{\partial}{\partial n_{a}}\Bigg[\underbrace{\sum_{b}\left(q_{b}\cdot\frac{n_{b}+1}{|\mathcal{A}|+\sum_{c}n_{c}}\right)}_{A} \\ &+\underbrace{\lambda\sum_{b}\pi_{\theta}(b)\log\left(\frac{n_{b}+1}{(|\mathcal{A}|+\sum_{c}n_{c})\cdot\pi_{\theta}(b)}\right)}_{B}\Bigg] \end{aligned} $$
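Written as code, the expanded objective is just these two sums. The Python sketch below evaluates $A+B$ directly from visit counts and prints how much it changes when each action gets one more visit; it assumes $\boldsymbol{\pi_{\theta}}$ is a proper distribution, and the function name and all numbers are hypothetical.

```python
import math


def regularized_objective(q, prior, counts, lam):
    """q^T pi_hat - lam * KL[pi_theta, pi_hat], written as the two sums A + B above."""
    total = len(counts) + sum(counts)
    a_term = sum(qb * (nb + 1) / total for qb, nb in zip(q, counts))
    b_term = lam * sum(p * math.log((nb + 1) / (total * p))
                       for p, nb in zip(prior, counts) if p > 0)
    return a_term + b_term


q = [0.1, 0.4, -0.2]     # hypothetical action values
prior = [0.5, 0.3, 0.2]  # hypothetical network policy
counts = [4, 1, 1]       # hypothetical visit counts
base = regularized_objective(q, prior, counts, lam=0.5)
for a in range(len(q)):
    bumped = list(counts)
    bumped[a] += 1       # one more visit to action a
    print(a, regularized_objective(q, prior, bumped, lam=0.5) - base)
```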
Let's tackle $A$ first. We can push the derivative inside the sum to get
$$ \frac{\partial A}{\partial n_{a}}=\sum_{b}\frac{\partial}{\partial n_{a}}\left(q_{b}\cdot\frac{n_{b}+1}{|\mathcal{A}|+\sum_{c}n_{c}}\right) $$
Since $n_{b}$ only changes when $b=a$, we have $\frac{\partial n_{b}}{\partial n_{a}}=\delta_{ab}$, where $\delta_{ab}$ equals $1$ if $b=a$ and $0$ if not. (Some may recognize this as the Kronecker delta.) Using the quotient rule from calculus on each summand gives
$$ \begin{aligned} \frac{\partial}{\partial n_{a}}&\left(\frac{n_{b}+1}{|\mathcal{A}|+\sum_{c}n_{c}}\right) \\ &=\frac{\delta_{ab}(|\mathcal{A}|+\sum_{c}n_{c})-(n_{b}+1)\cdot 1}{(|\mathcal{A}|+\sum_{c}n_{c})^{2}} \\ &=\frac{\delta_{ab}}{|\mathcal{A}|+\sum_{c}n_{c}}-\frac{n_{b}+1}{(|\mathcal{A}|+\sum_{c}n_{c})^{2}} \end{aligned} $$
and if I multiply by $q_{b}$ and sum over $b$, the $\delta_{ab}$ in the first piece keeps only the $b=a$ term, giving
$$ \frac{\partial A}{\partial n_{a}}=\frac{q_{a}}{|\mathcal{A}|+\sum_{c}n_{c}}+\underbrace{\left(-\frac{\sum_{b}q_{b}(n_{b}+1)}{(|\mathcal{A}|+\sum_{c}n_{c})^{2}}\right)}_{c_{A}} $$
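where I've collected into $c_{A}$ everything that doesn't depend on which action $a$ we differentiate with respect to (it still depends on the counts, but it is the same for every action, which is all we'll need). Now moving onto $B$, I'll first rewrite the logarithm using basic algebra to make it easier: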
$$ B=\lambda\sum_{b}\pi_{\theta}(b)\left[\log(n_{b}+1)-\log\left(|\mathcal{A}|+\sum_{c}n_{c}\right)-\log\pi_{\theta}(b)\right]. $$
Let's take these three terms step by step. Working it through, we get that
- the first term is just $\frac{\partial}{\partial n_{a}}\log(n_{b}+1)=\frac{\delta_{ab}}{n_{b}+1}$, where I've used the same $\delta_{ab}$ as before since $n_{b}$ only changes when $b=a$,
- the second term is $\frac{\partial}{\partial n_{a}}\log\left(|\mathcal{A}|+\sum_{c}n_{c}\right)=\frac{1}{|\mathcal{A}|+\sum_{c}n_{c}}$, since $n_{a}$ appears exactly once in $\sum_{c}n_{c}$, so the derivative of the inner sum is just $1$, and
- the third term is $\frac{\partial}{\partial n_{a}}\log\pi_{\theta}(b)=0$, since our policy is fixed as we search.
Putting them all together, we get
$$ \frac{\partial B}{\partial n_{a}}=\lambda\sum_{b}\pi_{\theta}(b)\left[\frac{\delta_{ab}}{n_{b}+1}-\frac{1}{|\mathcal{A}|+\sum_{c}n_{c}}\right] $$
In the first term, $\delta_{ab}$ again collapses the sum down to its single $b=a$ term, $\frac{\pi_{\theta}(a)}{n_{a}+1}$, and the second term sums to just $\frac{1}{|\mathcal{A}|+\sum_{c}n_{c}}$ since policies (like $\pi_{\theta}$) are probability distributions and so sum to one. This leaves us with
$$ \frac{\partial B}{\partial n_{a}}=\frac{\lambda\pi_{\theta}(a)}{n_{a}+1}+\underbrace{\left(-\frac{\lambda}{|\mathcal{A}|+\sum_{c}n_{c}}\right)}_{c_{B}} $$
where I've again collected the part that is the same for every action into $c_{B}$.
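Because the derivation treats $n_{a}$ as continuous, both closed forms can be checked against central finite differences. This Python sketch does that for $A$ and $B$ separately; `prior` is assumed to sum to one (the step that produces $c_{B}$ relies on it), and the step size `h`, the function names, and the example values are arbitrary, hypothetical choices.

```python
import math


def term_a(q, counts):
    """A = sum_b q_b (n_b + 1) / (|A| + sum_c n_c), allowing real-valued counts."""
    total = len(counts) + sum(counts)
    return sum(qb * (nb + 1) / total for qb, nb in zip(q, counts))


def term_b(prior, counts, lam):
    """B = lam * sum_b pi_theta(b) log((n_b + 1) / ((|A| + sum_c n_c) pi_theta(b)))."""
    total = len(counts) + sum(counts)
    return lam * sum(p * math.log((nb + 1) / (total * p))
                     for p, nb in zip(prior, counts) if p > 0)


def grad_a(q, counts, a):
    """Closed form: q_a / (|A| + sum_c n_c) + c_A."""
    total = len(counts) + sum(counts)
    c_A = -sum(qb * (nb + 1) for qb, nb in zip(q, counts)) / total ** 2
    return q[a] / total + c_A


def grad_b(prior, counts, lam, a):
    """Closed form: lam * pi_theta(a) / (n_a + 1) + c_B."""
    total = len(counts) + sum(counts)
    c_B = -lam / total
    return lam * prior[a] / (counts[a] + 1) + c_B


def finite_difference(f, counts, a, h=1e-6):
    """Central difference of f(counts) with respect to counts[a]."""
    up, down = list(counts), list(counts)
    up[a] += h
    down[a] -= h
    return (f(up) - f(down)) / (2 * h)


q, prior, lam = [0.1, 0.4, -0.2], [0.5, 0.3, 0.2], 0.5
counts = [4.0, 1.0, 1.0]
for a in range(len(q)):
    # closed form vs. finite difference, first for A, then for B
    print(a,
          grad_a(q, counts, a), finite_difference(lambda n: term_a(q, n), counts, a),
          grad_b(prior, counts, lam, a), finite_difference(lambda n: term_b(prior, n, lam), counts, a))
```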
But now is when the magic happens. First, let's play the same trick I did in the setup to write the first term in terms of the visit distribution $\hat{\pi}$, giving
$$ \frac{\lambda\pi_{\theta}(a)}{n_{a}+1}=\frac{\lambda}{|\mathcal{A}|+\sum_{c}n_{c}}\cdot\frac{\pi_{\theta}(a)}{\hat{\pi}(a)} $$
If we add the two halves $A$ and $B$ together, combining $c_{A}$ and $c_{B}$ into a single constant $C$ along the way, we get the final closed form for the derivative of our regularized objective:
$$ \frac{\partial}{\partial n_{a}}\left(\mathbf{q}^{\top}\boldsymbol{\hat{\pi}}-\lambda\mathrm{KL}[\boldsymbol{\pi_{\theta}},\boldsymbol{\hat{\pi}}]\right)=\frac{1}{|\mathcal{A}|+\sum_{c}n_{c}}\left(q_{a}+\lambda\frac{\pi_{\theta}(a)}{\hat{\pi}(a)}\right)+\underbrace{(c_{A}+c_{B})}_{C} $$
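Putting both halves together, the Python sketch below compares the full closed form, for every action at once, against central finite differences of the whole objective. It is a numerical sanity check on arbitrary, hypothetical inputs (with $\boldsymbol{\pi_{\theta}}$ assumed to sum to one), not a substitute for the algebra.

```python
import math


def objective(q, prior, counts, lam):
    """q^T pi_hat - lam * KL[pi_theta, pi_hat], allowing real-valued counts."""
    total = len(counts) + sum(counts)
    pi_hat = [(n + 1) / total for n in counts]
    value = sum(qb * ph for qb, ph in zip(q, pi_hat))
    kl = sum(p * math.log(p / ph) for p, ph in zip(prior, pi_hat) if p > 0)
    return value - lam * kl


def gradient_closed_form(q, prior, counts, lam):
    """(q_a + lam * pi_theta(a) / pi_hat(a)) / (|A| + sum_c n_c) + C for every action a."""
    total = len(counts) + sum(counts)
    pi_hat = [(n + 1) / total for n in counts]
    c_A = -sum(qb * (nb + 1) for qb, nb in zip(q, counts)) / total ** 2
    c_B = -lam / total
    C = c_A + c_B  # identical for every action
    return [(qa + lam * p / ph) / total + C for qa, p, ph in zip(q, prior, pi_hat)]


def gradient_numeric(q, prior, counts, lam, h=1e-6):
    """Central finite differences of the objective, one coordinate n_a at a time."""
    grads = []
    for a in range(len(counts)):
        up, down = list(counts), list(counts)
        up[a] += h
        down[a] -= h
        grads.append((objective(q, prior, up, lam) - objective(q, prior, down, lam)) / (2 * h))
    return grads


q, prior, counts = [0.1, 0.4, -0.2], [0.5, 0.3, 0.2], [4.0, 1.0, 1.0]
print(gradient_closed_form(q, prior, counts, 0.5))
print(gradient_numeric(q, prior, counts, 0.5))
```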
If we updated the visit counts along this derivative, we would be incrementing the action that causes the objective to increase the most. (This is related to the idea of stochastic gradient ascent used in the family of reinforcement learning methods known as policy gradients; its mirror image, stochastic gradient descent, is the backbone of modern deep learning. Of course, $n_{a}$ is discrete, so a step in this direction is relatively large, and the derivative may change along that step.) But look: the added constant $C$ is the same for every action, so removing it doesn't change which action comes out on top, and the factor up front, $\frac{1}{|\mathcal{A}|+\sum_{c}n_{c}}$, is positive (since it's built from counts) and also shared by every action, so removing it doesn't change things either. Dropping both, we get
$$ \arg\max_{a}\frac{\partial}{\partial n_{a}}\left(\mathbf{q}^{\top}\boldsymbol{\hat{\pi}}-\lambda\mathrm{KL}[\boldsymbol{\pi_{\theta}},\boldsymbol{\hat{\pi}}]\right)=\arg\max_{a}\left[\mathbf{q}+\lambda\frac{\boldsymbol{\pi_{\theta}}}{\boldsymbol{\hat{\pi}}}\right] $$
and so, after making the final step of setting $\lambda=\lambda_{N}$ from the start, we've recovered pUCT. Specifically, taking the action that maximizes pUCT is the same as taking a greedy step on this regularized policy optimization problem, which goes a long way toward explaining why pUCT performs so well in practice.
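To see the equivalence in action, this Python sketch picks an action two ways, once with the usual pUCT score and once as the coordinate with the largest derivative of the regularized objective (with $\lambda=\lambda_{N}$ and $c=c_{\text{pUCT}}$), and checks that they agree on random inputs. The function names, the random ranges, and the seed are illustrative assumptions.

```python
import math
import random


def puct_action(q, prior, counts, c_puct):
    """argmax_a of Q(x,a) + c_pUCT * pi_theta(a|x) * sqrt(sum_b n_b) / (1 + n_a)."""
    sqrt_n = math.sqrt(sum(counts))
    scores = [qa + c_puct * p * sqrt_n / (1 + n) for qa, p, n in zip(q, prior, counts)]
    return max(range(len(scores)), key=scores.__getitem__)


def regularized_step_action(q, prior, counts, c):
    """argmax_a of d/dn_a of q^T pi_hat - lambda_N * KL[pi_theta, pi_hat]."""
    total = len(counts) + sum(counts)
    lam = c * math.sqrt(sum(counts)) / total  # lambda_N
    pi_hat = [(n + 1) / total for n in counts]
    c_A = -sum(qb * (nb + 1) for qb, nb in zip(q, counts)) / total ** 2
    c_B = -lam / total
    grads = [(qa + lam * p / ph) / total + c_A + c_B for qa, p, ph in zip(q, prior, pi_hat)]
    return max(range(len(grads)), key=grads.__getitem__)


rng = random.Random(0)
for _ in range(1000):
    k = rng.randint(2, 6)                              # number of actions
    q = [rng.uniform(-1.0, 1.0) for _ in range(k)]
    weights = [rng.random() + 1e-3 for _ in range(k)]
    total_w = sum(weights)
    prior = [w / total_w for w in weights]             # a random proper distribution
    counts = [rng.randint(0, 20) for _ in range(k)]
    assert puct_action(q, prior, counts, 1.25) == regularized_step_action(q, prior, counts, 1.25)
```

Because the derivative differs from the pUCT score only by a positive scale and a shift shared across actions, the two argmaxes coincide whenever there is no exact tie.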
What's Next: Extensions to MCTS
Having covered the core of how Monte Carlo tree search works, I'd next like to explore how it's been extended and improved over the years. Specifically, I plan to at least cover different regularizers for the optimization objective (going beyond KL-divergence), and planning in action spaces so large that even MCTS becomes impractical. See you then!