MuZero - part 1

MuZero: A high level view

The previous post on the AlphaZero (AZ) algorithm served mainly as a warm-up. It helped to build an intuition for Monte Carlo tree search (MCTS) and how this powerful heuristic meshes well with neural networks. What I find even more interesting, however, is MuZero (MZ; original paper). The MZ algorithm (or at least its pseudocode) has been out for a while, so I’ll keep the recap brief. Instead, I’d like to touch on some key differences with respect to AZ and how I decided to implement them. This post covers a simple version to start with, while the following post will focus on scaling it up. That should provide a good base to explore some potential MZ applications I’ve had in mind.

Overview

Crucially, MZ forgoes any direct access to an (environmental) simulator. Whereas AZ uses this simulator to search future trajectories via MCTS, MZ instead relies on a state embedding and performs MCTS in latent space. This has interesting implications for high-dimensional state spaces (e.g. in vision), since the search can run through a lower-dimensional embedding. The embedding itself is unconstrained: by training end-to-end, the neural net can use it in whatever way it sees fit.

From Schrittwieser et al., visualizing MZ’s search (a), environment interaction (b), and training (c).

The algorithm relies on three networks:

  1. A representation function $h$, which embeds the raw observation into a hidden state.
  2. A dynamics function $g$, which takes a hidden state and an action and predicts the next hidden state along with the reward.
  3. A prediction function $f$, which maps a hidden state to a policy and a value estimate.
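
To make this concrete, below is a minimal sketch of what these three networks could look like for a small board game such as tic-tac-toe. The class names match those used in the footnote, but the fully connected layers, their widths, and the hidden-state size of 32 are illustrative choices for this sketch rather than anything MZ prescribes (the versions in the actual code also expose a predict helper that the search code below relies on).

import torch
import torch.nn as nn

HIDDEN = 32  # illustrative hidden-state size
BOARD = 9    # tic-tac-toe: 9 cells, 9 possible actions

class RepresentationNet(nn.Module):
    """h: raw observation -> initial hidden state s^0."""
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(BOARD, 64), nn.ReLU(), nn.Linear(64, HIDDEN))

    def forward(self, obs):
        return self.net(obs)

class DynamicsNet(nn.Module):
    """g: (hidden state, one-hot action) -> next hidden state and reward."""
    def __init__(self):
        super().__init__()
        self.body = nn.Sequential(nn.Linear(HIDDEN + BOARD, 64), nn.ReLU())
        self.state_head = nn.Linear(64, HIDDEN)
        self.reward_head = nn.Linear(64, 1)

    def forward(self, state_action):
        x = self.body(state_action)
        return self.state_head(x), self.reward_head(x)

class PredictionNet(nn.Module):
    """f: hidden state -> policy logits and value."""
    def __init__(self):
        super().__init__()
        self.body = nn.Sequential(nn.Linear(HIDDEN, 64), nn.ReLU())
        self.policy_head = nn.Linear(64, BOARD)
        self.value_head = nn.Linear(64, 1)

    def forward(self, state):
        x = self.body(state)
        return self.policy_head(x), self.value_head(x)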

Interaction with the environment can then proceed as follows (a rough code sketch follows the list):

  1. Observe the current environment
  2. Generate its embedding with representation function $h$.
  3. Perform MCTS by recurrently relying on dynamics function $g$ for the roll-out of states and prediction function $f$ to do value and policy estimation.
  4. Take an action based on the visit counts generated by MCTS.
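
Roughly, one such self-play step could look as follows. Here `env` and `mcts_search` are hypothetical stand-ins for the environment wrapper and the actual search code, and the networks are the ones sketched above.

# Rough sketch of one interaction step; env and mcts_search are hypothetical stand-ins.
obs = torch.tensor(env.observe(), dtype=torch.float32).view(1, -1)       # 1. observe the environment
s0 = representation(obs)                                                 # 2. embed with h
visits = mcts_search(s0, dynamics, prediction, n_sim=50)                 # 3. MCTS in latent space via g and f
action = int(np.argmax(visits))                                          # 4. act on the visit counts
env.step(action)                                                         #    (sampling instead of argmax adds exploration)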

Episodes of such interactions are saved to allow for offline training. A trajectory is sampled from the replay buffer and the observation is first fed into $h$. We can then unroll the model for $K$ steps with $g$, at each step $k$ feeding in the previous hidden state $s^{k-1}$ and the real action $a_{t+k}$. The set of parameters $\{\theta_h, \theta_g, \theta_f\}$ is trained via back-prop, simultaneously and end-to-end1. To do so, the policy suggestion $p^k$ generated by $f$ is aligned with the policy $\pi_{t+k}$ produced by the MCTS procedure, while the value and reward predictions are trained such that $v^k \approx z_{t+k}$ and $r^k \approx u_{t+k}$. Here $z_{t+k}$ is the sample return, which for board games is simply the final outcome (in which case $u_{t+k}=0$ due to the absence of intermediate rewards). The latent state can thus come to represent whatever is relevant to achieve this. Finally, the complete loss, including L2 regularization, looks as follows:

\[\large l_t(\theta) = \sum^K_{k=0} l^p (\pi_{t+k},p_t^k) + \sum^K_{k=0} l^v(z_{t+k},v_t^k) + \sum_{k=1}^K l^r (u_{t+k},r_t^k) + c \Vert \theta \Vert ^2\]
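
As a sketch, the unrolled loss could be computed along these lines. I’m assuming a cross-entropy term for the policy and mean-squared errors for value and reward, plus hypothetical batch tensors (obs_t, target_pi, target_z, target_u, actions) that have already been gathered from the replay buffer; the networks and the optimizer opt are the ones from the footnote.

# Sketch of the K-step unrolled loss; the batch tensors are hypothetical placeholders.
opt.zero_grad()
s = representation(obs_t)                                        # s^0 from the observation at time t
loss = 0.0
for k in range(K + 1):
    p_logits, v = prediction(s)                                  # f: policy and value at step k
    log_p = F.log_softmax(p_logits, dim=1)
    loss += -(target_pi[k] * log_p).sum(dim=1).mean()            # l^p(pi_{t+k}, p^k)
    loss += F.mse_loss(v.squeeze(-1), target_z[k])               # l^v(z_{t+k}, v^k)
    if k < K:
        a_oh = F.one_hot(actions[k], num_classes=9).float()      # real action taken from the state at t+k
        s, r = dynamics(torch.cat([s, a_oh], dim=1))             # g: roll the latent state forward
        loss += F.mse_loss(r.squeeze(-1), target_u[k])           # l^r against the reward observed after that action
loss.backward()   # the L2 term comes from weight_decay in the optimizer
opt.step()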

Moving from A(Z) to (M)Z

Concretely, some important changes were made inside the MCTS. For example, each time we expand a node, we rely on the dynamics function $g$ to provide the associated next state:

# module-level imports used below
import numpy as np
import torch
import torch.nn.functional as F

def expand(self, turn, prior, dynamics):

    self.expanded = True
    # 1. restrict the prior to the legal moves and mix in exploration noise
    # 2. generate a child node per legal move via the dynamics function

    prior = prior[self.legal_moves]

    # add Dirichlet noise to the prior (exploration heuristic, see below)
    noise = np.random.dirichlet([self.dir_alpha]*len(prior))
    prior = prior*(1-self.dir_frac) + noise*self.dir_frac
    prior /= np.sum(prior)

    for i, move in enumerate(self.legal_moves):

        # roll out one step in latent space; the second output (the predicted reward) is not needed here
        a_oh = F.one_hot(torch.tensor(move), num_classes=9).float()
        next_state, _ = dynamics.predict(torch.cat([self.embedded_state, a_oh.view(1, -1)], dim=1))
        self.child[i] = Node(torch.tensor(next_state), turn*-1, self.legal_moves, self.dir_alpha, self.dir_frac)
        self.child[i].prior = prior[i].item()

Note that the Dirichlet noise is a heuristic that perturbs the prior. It serves as an exploration incentive: it spreads probability mass across the different actions and thereby promotes stochasticity in action selection.

Further, due to the lack of access to an explicit simulator, MZ cannot rely on it to know whether the tree search has reached a terminal state. While it receives the initial observation and can determine which moves are legal at that moment, it is no longer bound by the rules of the game once the search moves into latent space. This is resolved during training, where terminal states (and any states beyond them) are treated as absorbing states. Since planning beyond terminal states should generally result in poor performance, purely reward-based learning appears to address the issue sufficiently. In my training code this shows up when building the per-step target indices: an index stops advancing once its trajectory has hit an absorbing state.

# idx holds the sampled start indices into state_log (nine board cells per row);
# a row of all zeros is treated as an absorbing (post-terminal) state
idx_absorb = idx.copy()  # indices corrected for absorbing states
id_k_list = [idx.copy()]  # list of indices for each step k
NAB_list = []  # records whether a state is absorbing (inverse Boolean for easy indexing)

for k in range(self.K):

    # check for absorbing states at step k+1
    not_absorb = ((state_log[idx+1+k, :] == 0).sum(axis=1) != 9)

    if k == 0:
        idx_bool_either = not_absorb.copy()
    else:  # after k=0 we need to check whether the current or any preceding state is absorbing
        idx_bool_either = not_absorb & NAB_list[-1]

    idx_absorb[idx_bool_either] += 1  # only indices of non-absorbing states advance by 1

    NAB_list.append(idx_bool_either.copy())
    id_k_list.append(idx_absorb.copy())
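
The net effect is that once a trajectory reaches a terminal state, its index simply stops advancing: every unroll step past the end of the game keeps pointing at the same (absorbing) entry, which is exactly the treatment described above.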

Otherwise, MZ is a true successor to AZ, as it relies on many of the same principles. Feel free to check out the preliminary code here. With all the basic functionality in place, I’m looking forward to scaling this up in the upcoming weeks!

  1. PyTorch makes this very easy: you can pass the parameters of the different models (here dynamics, prediction, and representation) to a single optimizer as follows:

    dynamics = DynamicsNet().to(self.device)
    prediction = PredictionNet().to(self.device)
    representation = RepresentationNet().to(self.device)

    opt = torch.optim.SGD(
        list(dynamics.parameters()) + list(prediction.parameters()) + list(representation.parameters()),
        lr=lr,
        weight_decay=1e-5,
    )
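
     A single backward pass through the joint loss $l_t(\theta)$ followed by opt.step() then updates all three sets of parameters at once, which is what the end-to-end training above relies on.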