Deconstructing AlphaGo: Exploring Monte Carlo Tree Search

What do you do when there are too many possibilities?

In what's since become a watershed moment in AI, DeepMind's AlphaGo defeated 18-time Go champion Lee Sedol in 2016, about a decade earlier than expected. DeepMind would then go on to produce a flurry of advancements: first AlphaGo Zero, which taught itself instead of learning from humans, and then AlphaZero, which extended the approach to other board games like chess and shogi. Common to all of these systems was an approach known as Monte Carlo tree search, or MCTS for short, but when I first tried to learn about it, I couldn't find any good resources. To fix this, I've written this blog post, which will hopefully clear up some of the confusion for you too. I'll start with the basic idea, and then formalize it as needed.

When the Action Space is Too Big

Let's say you've just wrapped up a chess game, and you're using the state-of-the-art computer program Stockfish to get feedback. What Stockfish does under the hood is to list all possible responses to a move (in what's known as the action space), and then their responses, and so on until it's searched deep enough to provide a good answer. Since most continuations of a position are bad (for instance, if your queen is attacked, you probably won't move anything other than the queen), Stockfish uses a technique called alpha-beta pruning to stop searching lines that it knows are suboptimal. This, combined with an extensive array of heuristics to eliminate moves that are likely to be poor, is what makes Stockfish the undisputed champion at computer chess.

[Figure: a four-level minimax tree for the position after 1. e4 e5 2. Nf3 Nc6, with each edge labeled by a plausible move. The principal variation is highlighted from root to leaf, and pruned branches are drawn as faded dashed edges.]
Searching after 1. e4 e5 2. Nf3 Nc6 with alpha-beta
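To make the pruning idea concrete, here's a minimal sketch of textbook alpha-beta on a toy tree. The nested-list encoding (a leaf is a number, an inner node is a list of children) and the `visited` bookkeeping are my own assumptions for illustration; this is not how Stockfish actually represents or searches positions.

```python
import math

def alphabeta(node, alpha=-math.inf, beta=math.inf, maximizing=True, visited=None):
    """Textbook alpha-beta on a toy tree: a leaf is a number, an inner node is a list."""
    if not isinstance(node, list):
        if visited is not None:
            visited.append(node)  # record which leaves were actually evaluated
        return node
    best = -math.inf if maximizing else math.inf
    for child in node:
        score = alphabeta(child, alpha, beta, not maximizing, visited)
        if maximizing:
            best = max(best, score)
            alpha = max(alpha, best)
        else:
            best = min(best, score)
            beta = min(beta, best)
        if alpha >= beta:
            break  # the other side would never allow this line, so stop searching it
    return best

# Toy tree (made-up values): three candidate moves, each with two replies.
tree = [[3, 5], [2, 9], [0, 1]]
seen = []
value = alphabeta(tree, visited=seen)
```

Here `value` comes out as 3, and `seen` ends up as `[3, 5, 2, 0]`: the leaves 9 and 1 are never evaluated, because by then we already know those branches can't beat the first one.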

Enter Go: unlike chess, which has about 30 or so legal moves in each position, Go has about 250, so even listing all the possible continuations becomes computationally infeasible. (Formally, we'd say that the branching factor is too high.) If you can't list the moves, you can't try them; and if you can't try them, alpha-beta can't be sure that they're suboptimal and safe to prune. We need to do something different.
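To get a feel for the gap, here's a back-of-the-envelope sketch that counts move sequences using the rough branching factors above. The function name and the choice of depth 4 are just assumptions for illustration, and real games don't have a constant number of moves per position.

```python
def sequences_at_depth(branching_factor, depth):
    """Number of distinct move sequences of a given length, assuming a constant branching factor."""
    return branching_factor ** depth

chess_lines = sequences_at_depth(30, 4)   # 810_000
go_lines = sequences_at_depth(250, 4)     # 3_906_250_000
```

Just four moves deep, Go already has several thousand times as many lines to consider as chess, and the gap widens with every extra move.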

Building the Tree Step-by-Step

You may have noticed a central tension in alpha-beta: we enumerate moves so that we can try them, and then immediately discard them when we find out that they're bad. What if we could somehow avoid considering those bad moves to begin with? The problem is that alpha-beta is "blind": it doesn't know that a move is bad until it's tried it, unlike a great human player, who can instinctively tell what moves are promising. What we need is some search strategy (in jargon, a policy) that assigns scores to moves a priori, and then tries out moves accordingly (neglecting moves that are too poor to even consider). Instead of pruning the giant tree of all possible continuations, we can grow the tree by iteratively picking new moves, evaluating the resulting positions, and using those evaluations to update what the best moves are. This process is all there is to MCTS.

Stage 1: Selection

Before we explore new moves, we need to pick some position along the game tree as our starting point. Aptly, selecting the right position (or node) in the tree is called selection, and is done using three quantities:

  • the node value, based on previous results from having walked along this position,
  • the visit count, or how many times we've walked this node before, and
  • the policy value, or the initial assessment of how good this node is

Starting from the root node (that is, the node at the top of the tree, corresponding to no moves searched yet), we combine these three quantities using a formula called pUCT (more on that below) to walk down the tree step by step until we reach a leaf node, that is, a node that has no edges leading out of it.

[Figure: the selection stage on the same Ruy Lopez tree, with the pruned branches removed. The selected path descends from the root through Bb5 and a6 to the leaf Bxc6, which is marked as the node to expand.]
The selection phase after 1. e4 e5 2. Nf3 Nc6
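The walk down the tree can be sketched in a few lines. Everything here is an assumption for illustration: the `Node` class, the `select_leaf` name, and especially the `score` callable, which stands in for whatever formula ranks the moves (the toy scores below are made-up numbers, not real evaluations).

```python
class Node:
    """Hypothetical tree node; `children` maps a move name to the resulting Node."""
    def __init__(self, name):
        self.name = name
        self.children = {}

def select_leaf(root, score):
    """Follow the highest-scoring move at each node until we reach a node with no children."""
    node, path = root, [root]
    while node.children:
        best_move = max(node.children, key=lambda move: score(node, move))
        node = node.children[best_move]
        path.append(node)
    return node, path

# Toy tree: root -> {Bb5, Bc4}, and Bb5 -> {a6, Nf6}
root = Node("root")
for move in ("Bb5", "Bc4"):
    root.children[move] = Node(move)
for move in ("a6", "Nf6"):
    root.children["Bb5"].children[move] = Node(move)

toy_scores = {"Bb5": 0.6, "Bc4": 0.4, "a6": 0.7, "Nf6": 0.3}  # made-up numbers
leaf, path = select_leaf(root, lambda node, move: toy_scores[move])
```

I return the path along with the leaf because the last stage needs to revisit every node we passed through.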

Notice how the tree is asymmetric: since we choose to search the moves we want, some lines get searched deeper than others. In particular, we haven't searched the lines that were pruned! (Technically, this assumes our policy tracks the alpha-beta search exactly, but it's a good intuition in practice.)

Stage 2: Expansion

Once we've picked a leaf node, the natural next step is to consider the moves we can play from it. This step is known as expansion: for the chosen leaf node, we expand the tree by adding nodes for all the continuations one move deeper. (Older versions of MCTS, based on the UCB1 algorithm, only expand one node at a time.)

[Figure: the expansion stage. The selected leaf Bxc6 is expanded into all of its children, the two recaptures dxc6 and bxc6, and each new edge carries a prior p supplied by the policy head.]
The expansion phase after 1. e4 e5 2. Nf3 Nc6

Note that in some implementations, we might only expand after having visited the node enough times! In this case, we jump straight to the next step.
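As a rough sketch, expanding every continuation at once might look like the following. The `Node` fields and the `expand` helper are hypothetical names of my own, and the priors are made-up numbers standing in for whatever the policy would output.

```python
class Node:
    """Hypothetical tree node holding the three quantities used during selection."""
    def __init__(self, prior=0.0):
        self.prior = prior        # policy value
        self.visits = 0           # visit count
        self.total_value = 0.0    # running total used to compute the node value
        self.children = {}

def expand(leaf, priors):
    """Add one child per legal move, each starting with zero visits and its prior."""
    for move, prior in priors.items():
        leaf.children[move] = Node(prior=prior)

leaf = Node()
expand(leaf, {"dxc6": 0.7, "bxc6": 0.3})  # made-up priors for the two recaptures
```

Note that the new children start with no visits and no value of their own; only their priors are filled in at this point.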

Stage 3: Simulation

From our newly expanded leaf node (not its children!), we then simulate the outcome of the game by playing random moves until the end (explaining the name Monte Carlo), or, as is more common these days, substitute the simulation with a neural network that approximates the end result directly. (The original version of AlphaGo used a tunable mix of both these approaches, but this was later dropped in favor of deep learning.) When we do this, our leaf node gets assigned a value (the node value from earlier) representing the expected utility from playing it, which is then stored for future reuse.

[Figure: the evaluation stage. Instead of a rollout, the network $f_{\theta}$ scores the just-expanded leaf Bxc6 directly, producing a leaf value $v$.]
The simulation phase after 1. e4 e5 2. Nf3 Nc6
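Chess and Go are too big for a short example, so here's a sketch of the random-playout flavor of simulation on a toy game I've made up for illustration: players alternate taking 1 or 2 stones from a pile, and whoever takes the last stone wins. The `rollout` name and the game itself are assumptions, not anything AlphaGo used.

```python
import random

def rollout(stones, rng):
    """Play uniformly random moves in the toy take-1-or-2 game until the pile is empty.
    Returns +1 if the player to move at the start takes the last stone, else -1."""
    player = 0  # 0 = the player to move at the start
    while True:
        take = rng.choice([1, 2]) if stones >= 2 else 1
        stones -= take
        if stones == 0:
            return 1 if player == 0 else -1
        player = 1 - player

rng = random.Random(0)  # seeded so the run is repeatable
results = [rollout(5, rng) for _ in range(1000)]
estimate = sum(results) / len(results)  # Monte Carlo estimate of the position's value
```

A single playout is very noisy, which is why the average over many playouts is what gets used as the value; a neural network evaluation replaces this whole loop with one function call.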

Stage 4: Backpropagation

Once we have our node value, we then recursively use it to update the values of all the nodes in the path we took. Specifically, we add our node value to the node values of the previous nodes, and increment their visit counts by one (allowing us to compute the average value when needed).

[Figure: the backup stage. The leaf value v is propagated up the selection path from Bxc6 through a6 and Bb5 to the root, and each node on the path has its visit count N and mean value Q updated.]
The backpropagation phase after 1. e4 e5 2. Nf3 Nc6
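The update itself is tiny. Here's a sketch, with a hypothetical `Node` class and `backpropagate` helper of my own naming. One loud caveat: I'm following the description above literally and adding the same value at every node, whereas two-player implementations typically flip the sign of the value at each level, since a good result for one side is a bad one for the other.

```python
class Node:
    """Hypothetical tree node tracking a visit count and a running total of values."""
    def __init__(self):
        self.visits = 0
        self.total_value = 0.0

    def mean_value(self):
        # Average value so far; 0.0 for an unvisited node (a convention assumed here).
        return self.total_value / self.visits if self.visits else 0.0

def backpropagate(path, value):
    """Add `value` to every node on the path and bump its visit count."""
    for node in reversed(path):
        node.visits += 1
        node.total_value += value

path = [Node(), Node(), Node()]    # root, child, leaf
backpropagate(path, 0.5)           # first iteration reached the leaf
backpropagate(path[:2], -1.0)      # a second iteration stopped one level higher
```

After these two updates the root has been visited twice with a mean value of -0.25, while the leaf has one visit and a mean of 0.5.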

Congrats! You've successfully completed one iteration of Monte Carlo tree search!

Formalizing the Algorithm

Setup and Notation

When formalizing any bit of code, it's always good practice to give names to what we're working with. Given any node, which we'll call $s$ (for state), and any possible move from it, which we'll call $a$ (for action), let's say that

  • the initial estimate for how good $a$ is at $s$ is called $P(s,a)$,
  • the number of times we've tried $a$ at $s$ is called $N(s,a)$,
  • the running total of accumulated utilities is called $W(s,a)$, and
  • the average utility, computed from the previous two quantities, is called $Q(s,a)$

Notice how these values map cleanly onto the values I've introduced earlier: the node value is $Q(s,a)$, the visit count is $N(s,a)$, and the policy value is $P(s,a)$. (As I've defined it, the value $Q(s,a)$ really should be called the edge value, but since each node only has one edge going into it, we can safely store edge values in the nodes.)
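In code, this notation might translate into a small record per edge. The `EdgeStats` name is my own, and returning 0 for $Q$ before any visits is a convention I'm assuming here rather than something the definitions force.

```python
from dataclasses import dataclass

@dataclass
class EdgeStats:
    """Hypothetical container for the statistics of one (s, a) pair."""
    P: float        # initial estimate of how good the move is
    N: int = 0      # number of times the move has been tried
    W: float = 0.0  # running total of accumulated utilities

    @property
    def Q(self):
        # Average utility; defined as 0.0 before the first visit (an assumption).
        return self.W / self.N if self.N else 0.0

edge = EdgeStats(P=0.4)
edge.N += 2
edge.W += 1.0
```

Storing $W$ and deriving $Q$ on demand means an update only ever has to add to two counters.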

Selection with pUCT

All that's left is to describe the formula we use to actually select nodes to traverse. Properly, it's this behemoth

$$ a^{*}=\arg\max_{a}\left[Q(s,a)+c_{\text{pUCT}}\cdot P(s,a)\cdot\frac{\sqrt{\sum_{b}{N(s,b)}}}{1+N(s,a)}\right] $$

so let's break it down step by step. The expression in the brackets consists of two terms: $Q(s,a)$, which corresponds to the expected utility of making this move, and the giant term involving $P(s,a)$, which corresponds to our initial assessment. Let's say (unrealistically) that $P$ is always zero, so we are just selecting the move with the highest $Q$. (Technically, this is impossible, as we require $P$ to be a probability distribution.) This corresponds to the reinforcement learning idea of exploitation: picking the best move of the ones we've already seen. Now let's suppose $Q$ was always zero, so that we're maximizing the term on the right. This is the prior assessment of a move, weighted by a factor inversely proportional to the fraction of visits spent on it (corresponding to the idea of exploration).
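Here's the formula as a short sketch, with the statistics for one position passed in as plain dictionaries keyed by move. The function names, the default $c_{\text{pUCT}}=1.5$, and all the numbers in the example are assumptions picked purely for illustration.

```python
import math

def puct_scores(Q, P, N, c_puct=1.5):
    """Score every move at one node: Q(s,a) + c * P(s,a) * sqrt(sum_b N(s,b)) / (1 + N(s,a))."""
    total_visits = sum(N.values())
    return {
        a: Q[a] + c_puct * P[a] * math.sqrt(total_visits) / (1 + N[a])
        for a in Q
    }

def select_move(Q, P, N, c_puct=1.5):
    """Pick the move with the highest pUCT score."""
    scores = puct_scores(Q, P, N, c_puct)
    return max(scores, key=scores.get)

# Made-up statistics for three candidate moves.
Q = {"Bb5": 0.2, "Bc4": 0.1, "d4": 0.0}
P = {"Bb5": 0.5, "Bc4": 0.3, "d4": 0.2}
N = {"Bb5": 8, "Bc4": 1, "d4": 0}

explore_choice = select_move(Q, P, N)              # "d4": never tried, so its bonus is large
exploit_choice = select_move(Q, P, N, c_puct=0.0)  # "Bb5": highest Q once the bonus is switched off
```

Setting `c_puct` to zero is a handy way to see the pure-exploitation extreme from the paragraph above without having to zero out $P$.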

But let's look at this factor in more detail, since it can seem rather arbitrary. Our visit fraction is simply $\frac{N(s,a)}{\sum_{b}N(s,b)}$, so inverting it gives $\frac{\sum_{b}N(s,b)}{N(s,a)}$. But there's a problem: what if we've never taken action $a$? This would lead to a division by zero, so it's common practice in machine learning to add a small constant $\epsilon$ to the denominator for numerical stability, say $1\times 10^{-8}$. However, dividing by this would still give an unreasonably large number, and we'd like our quantities to remain sensible! Instead, we can view the expansion into a node (from stage 2) as a visit, even though we've never actually walked down that path, which is where the $1+N(s,a)$ in the denominator comes from. (The astute reader might be wondering why we don't do this for the numerator. Truth is, this formula was largely arrived at empirically, and there isn't a simple first-principles explanation for it.)

That leaves the puzzle of the square root. For that, let's suppose we didn't have a square root. In that case, as we kept searching, our inverse visit fraction would approach some constant, as the computer figures out the optimal visit fractions for each node. But if we search for longer (that is, both $N(s,a)$ and $\sum_{b}N(s,b)$ are higher), we can be more confident in our (empirical) values for $Q$, so it makes more sense to rely on them instead, and just pick the move that maximizes that. (This is an instance of the famous exploration-exploitation tradeoff in reinforcement learning.) By taking the square root, we cause this fraction to decay to zero, placing greater emphasis on $Q$ just as we'd like. $c_{\text{pUCT}}$ is just a constant used to control the transition.
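A quick numerical check makes the effect of the square root visible. In this sketch I assume a move that keeps receiving a fixed quarter of all visits (an arbitrary choice for illustration) and compare the factor with and without the square root as the search grows; the function names are my own.

```python
import math

def factor_with_sqrt(total_visits, visit_fraction):
    """sqrt(sum_b N(s,b)) / (1 + N(s,a)) for a move getting a fixed share of the visits."""
    n_a = visit_fraction * total_visits
    return math.sqrt(total_visits) / (1 + n_a)

def factor_without_sqrt(total_visits, visit_fraction):
    """The same factor with the square root removed from the numerator."""
    n_a = visit_fraction * total_visits
    return total_visits / (1 + n_a)

totals = [100, 10_000, 1_000_000]
with_sqrt = [factor_with_sqrt(t, 0.25) for t in totals]
without_sqrt = [factor_without_sqrt(t, 0.25) for t in totals]
```

Without the square root the factor settles near 4 (the inverse of the one-quarter visit share) no matter how long we search, while the square-root version keeps shrinking toward zero, handing control over to $Q$.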

Training the Networks

I mentioned that it's common to swap out the random simulation for a neural network, so all that's left is training it. Let's call our network $f_{\theta}$: it takes in a state $s$ and returns two things: a node value estimate $v$ and policy scores $\mathbf{p}$ (represented as a vector over all moves). Tuning the value estimate is easy: we simply regress it against the game result $z$ via mean squared error:

$$ \mathcal{L}_{\text{value}}=(z-v)^{2} $$

where we set $z=1$ for a win, $-1$ for a loss, and $0$ for a draw. But what about the policy? For that, observe that the visit counts $N$ are a product of both the raw policy $P$ and the empirical values $Q$, and so give us a better picture of the moves that are actually worth looking at. After all, the most visited moves are the ones pUCT found most promising, and pUCT is inherently more than just $P$. Therefore, we can train the policy against the (normalized) visit counts $\boldsymbol{\pi}$ via the standard cross-entropy loss (actions are discrete, so MSE is not appropriate here):

$$ \mathcal{L}_{\text{policy}}=-\boldsymbol{\pi}^{\top}\log\mathbf{p} $$

Lastly, we add a small weight decay penalty to prevent the network from overfitting, giving the final loss

$$ \mathcal{L}=(z-v)^{2}-\boldsymbol{\pi}^{\top}\log\mathbf{p}+c\left\|\theta\right\|^{2} $$
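To see the pieces fit together, here's a sketch that evaluates this loss for a single training example using plain Python lists. The helper names, the two-move setup, and every number below (including the weight decay constant $c$) are assumptions for illustration; a real implementation would compute this over batches in a deep learning framework.

```python
import math

def visit_policy(visit_counts):
    """Normalize raw visit counts into the training target pi."""
    total = sum(visit_counts)
    return [n / total for n in visit_counts]

def total_loss(z, v, pi, p, theta, c=1e-4):
    """(z - v)^2 - pi^T log p + c * ||theta||^2 for one example."""
    value_loss = (z - v) ** 2
    # Terms with pi_i = 0 contribute nothing, so they're skipped.
    policy_loss = -sum(t * math.log(q) for t, q in zip(pi, p) if t > 0)
    weight_penalty = c * sum(w * w for w in theta)
    return value_loss + policy_loss + weight_penalty

pi = visit_policy([30, 10])  # [0.75, 0.25]
loss = total_loss(z=1.0, v=0.5, pi=pi, p=[0.5, 0.5], theta=[1.0, 2.0], c=0.1)
```

With these numbers the three terms are $0.25$, $\ln 2$, and $0.5$ respectively, so you can check each part of the formula by hand.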

Looking Ahead: Formalism and Extensions

As it stands, our investigation is still a bit unsatisfying. I've explained the pUCT formula to you, but I haven't convincingly told you why we pick this formula over anything else. Sure, the discovery may have been empirical, but hasn't there been any theory that followed? In fact, there has! And while it's much more technical, it's deeply satisfying, and will be the subject of my sequel post. Thanks for reading!