Abstract
Improving the efficiency of algorithms for fundamental computations can have a widespread impact, as it can affect the overall speed of a large amount of computations. Matrix multiplication is one such primitive task, occurring in many systems—from neural networks to scientific computing routines. The automatic discovery of algorithms using machine learning offers the prospect of reaching beyond human intuition and outperforming the current best humandesigned algorithms. However, automating the algorithm discovery procedure is intricate, as the space of possible algorithms is enormous. Here we report a deep reinforcement learning approach based on AlphaZero^{1} for discovering efficient and provably correct algorithms for the multiplication of arbitrary matrices. Our agent, AlphaTensor, is trained to play a singleplayer game where the objective is finding tensor decompositions within a finite factor space. AlphaTensor discovered algorithms that outperform the stateoftheart complexity for many matrix sizes. Particularly relevant is the case of 4 × 4 matrices in a finite field, where AlphaTensor’s algorithm improves on Strassen’s twolevel algorithm for the first time, to our knowledge, since its discovery 50 years ago^{2}. We further showcase the flexibility of AlphaTensor through different usecases: algorithms with stateoftheart complexity for structured matrix multiplication and improved practical efficiency by optimizing matrix multiplication for runtime on specific hardware. Our results highlight AlphaTensor’s ability to accelerate the process of algorithmic discovery on a range of problems, and to optimize for different criteria.
Main
We focus on the fundamental task of matrix multiplication, and use deep reinforcement learning (DRL) to search for provably correct and efficient matrix multiplication algorithms. This algorithm discovery process is particularly amenable to automation because a rich space of matrix multiplication algorithms can be formalized as lowrank decompositions of a specific threedimensional (3D) tensor^{2}, called the matrix multiplication tensor^{3,4,5,6,7}. This space of algorithms contains the standard matrix multiplication algorithm and recursive algorithms such as Strassen’s^{2}, as well as the (unknown) asymptotically optimal algorithm. Although an important body of work aims at characterizing the complexity of the asymptotically optimal algorithm^{8,9,10,11,12}, this does not yield practical algorithms^{5}. We focus here on practical matrix multiplication algorithms, which correspond to explicit lowrank decompositions of the matrix multiplication tensor. In contrast to twodimensional matrices, for which efficient polynomialtime algorithms computing the rank have existed for over two centuries^{13}, finding lowrank decompositions of 3D tensors (and beyond) is NPhard^{14} and is also hard in practice. In fact, the search space is so large that even the optimal algorithm for multiplying two 3 × 3 matrices is still unknown. Nevertheless, in a longstanding research effort, matrix multiplication algorithms have been discovered by attacking this tensor decomposition problem using human search^{2,15,16}, continuous optimization^{17,18,19} and combinatorial search^{20}. These approaches often rely on humandesigned heuristics, which are probably suboptimal. We instead use DRL to learn to recognize and generalize over patterns in tensors, and use the learned agent to predict efficient decompositions.
We formulate the matrix multiplication algorithm discovery procedure (that is, the tensor decomposition problem) as a singleplayer game, called TensorGame. At each step of TensorGame, the player selects how to combine different entries of the matrices to multiply. A score is assigned based on the number of selected operations required to reach the correct multiplication result. This is a challenging game with an enormous action space (more than 10^{12} actions for most interesting cases) that is much larger than that of traditional board games such as chess and Go (hundreds of actions). To solve TensorGame and find efficient matrix multiplication algorithms, we develop a DRL agent, AlphaTensor. AlphaTensor is built on AlphaZero^{1,21}, where a neural network is trained to guide a planning procedure searching for efficient matrix multiplication algorithms. Our framework uses a single agent to decompose matrix multiplication tensors of various sizes, yielding transfer of learned decomposition techniques across various tensors. To address the challenging nature of the game, AlphaTensor uses a specialized neural network architecture, exploits symmetries of the problem and makes use of synthetic training games.
AlphaTensor scales to a substantially larger algorithm space than what is within reach for either human or combinatorial search. In fact, AlphaTensor discovers from scratch many provably correct matrix multiplication algorithms that improve over existing algorithms in terms of number of scalar multiplications. We also adapt the algorithm discovery procedure to finite fields, and improve over Strassen’s twolevel algorithm for multiplying 4 × 4 matrices for the first time, to our knowledge, since its inception in 1969. AlphaTensor also discovers a diverse set of algorithms—up to thousands for each size—showing that the space of matrix multiplication algorithms is richer than previously thought. We also exploit the diversity of discovered factorizations to improve stateoftheart results for large matrix multiplication sizes. Through different usecases, we highlight AlphaTensor’s flexibility and wide applicability: AlphaTensor discovers efficient algorithms for structured matrix multiplication improving over known results, and finds efficient matrix multiplication algorithms tailored to specific hardware, by optimizing for actual runtime. These algorithms multiply large matrices faster than humandesigned algorithms on the same hardware.
Algorithms as tensor decomposition
As matrix multiplication (A, B) ↦ AB is bilinear (that is, linear in both arguments), it can be fully represented by a 3D tensor: see Fig. 1a for how to represent the 2 × 2 matrix multiplication operation as a 3D tensor of size 4 × 4 × 4, and refs. ^{3,5,7} for more details. We write \({{\mathscr{T}}}_{n}\) for the tensor describing n × n matrix multiplication. The tensor \({{\mathscr{T}}}_{n}\) is fixed (that is, it is independent of the matrices to be multiplied), has entries in {0, 1}, and is of size n^{2} × n^{2} × n^{2}. More generally, we use \({{\mathscr{T}}}_{n,m,p}\) to describe the rectangular matrix multiplication operation of size n × m with m × p (note that \({{\mathscr{T}}}_{n}={{\mathscr{T}}}_{n,n,n}\)). By a decomposition of \({{\mathscr{T}}}_{n}\) into R rankone terms, we mean
where ⊗ denotes the outer (tensor) product, and u^{(r)}, v^{(r)} and w^{(r)} are all vectors. If a tensor \({\mathscr{T}}\) can be decomposed into R rankone terms, we say the rank of \({\mathscr{T}}\) is at most R, or \(\,\text{Rank}\,({\mathscr{T}}\,)\le R\). This is a natural extension from the matrix rank, where a matrix is decomposed into \({\sum }_{r=1}^{R}{{\bf{u}}}^{(r)}\otimes {{\bf{v}}}^{(r)}\).
A decomposition of \({{\mathscr{T}}}_{n}\) into R rankone terms provides an algorithm for multiplying arbitrary n × n matrices using R scalar multiplications (see Algorithm 1). We refer to Fig. 1b,c for an example algorithm multiplying 2 × 2 matrices with R = 7 (Strassen’s algorithm).
Crucially, Algorithm 1 can be used to multiply block matrices. By using this algorithm recursively, one can multiply matrices of arbitrary size, with the rank R controlling the asymptotic complexity of the algorithm. In particular, N × N matrices can be multiplied with asymptotic complexity \({\mathcal{O}}({N}^{{\log }_{n}(R)})\); see ref. ^{5} for more details.
DRL for algorithm discovery
We cast the problem of finding efficient matrix multiplication algorithms as a reinforcement learning problem, modelling the environment as a singleplayer game, TensorGame. The game state after step t is described by a tensor \({{\mathscr{S}}}_{t}\), which is initially set to the target tensor we wish to decompose: \({{\mathscr{S}}}_{0}={{\mathscr{T}}}_{n}\). In each step t of the game, the player selects a triplet (u^{(t)}, v^{(t)}, w^{(t)}), and the tensor \({{\mathscr{S}}}_{t}\) is updated by subtracting the resulting rankone tensor: \({{\mathscr{S}}}_{t}\leftarrow {{\mathscr{S}}}_{t1}{{\bf{u}}}^{(t)}\otimes {{\bf{v}}}^{(t)}\otimes {{\bf{w}}}^{(t)}\). The goal of the player is to reach the zero tensor \({{\mathscr{S}}}_{t}={\bf{0}}\) by applying the smallest number of moves. When the player reaches the zero tensor, the sequence of selected factors satisfies \({{\mathscr{T}}}_{n}={\sum }_{t=1}^{R}{{\bf{u}}}^{(t)}\otimes {{\bf{v}}}^{(t)}\otimes {{\bf{w}}}^{(t)}\) (where R denotes the number of moves), which guarantees the correctness of the resulting matrix multiplication algorithm. To avoid playing unnecessarily long games, we limit the number of steps to a maximum value, R_{limit}.
For every step taken, we provide a reward of −1 to encourage finding the shortest path to the zero tensor. If the game terminates with a nonzero tensor (after R_{limit} steps), the agent receives an additional terminal reward equal to \(\gamma ({{\mathscr{S}}}_{{R}_{\text{limit}}})\), where \(\gamma ({{\mathscr{S}}}_{{R}_{\text{limit}}})\) is an upper bound on the rank of the terminal tensor. Although this reward optimizes for rank (and hence for the complexity of the resulting algorithm), other reward schemes can be used to optimize other properties, such as practical runtime (see ‘Algorithm discovery results’). Besides, as our aim is to find exact matrix multiplication algorithms, we constrain {u^{(t)}, v^{(t)}, w^{(t)}} to have entries in a userspecified discrete set of coefficients F (for example, F = {−2, −1, 0, 1, 2}). Such discretization is common practice to avoid issues with the finite precision of floating points^{15,18,20}.
To play TensorGame, we propose AlphaTensor (Fig. 2), an agent based on AlphaZero^{1}, which achieved tabula rasa superhuman performance in the classical board games of Go, chess and shogi, and on its extension to handle large action spaces Sampled AlphaZero^{21}. Similarly to AlphaZero, AlphaTensor uses a deep neural network to guide a Monte Carlo tree search (MCTS) planning procedure. The network takes as input a state (that is, a tensor \({{\mathscr{S}}}_{t}\) to decompose), and outputs a policy and a value. The policy provides a distribution over potential actions. As the set of potential actions (u^{(t)}, v^{(t)}, w^{(t)}) in each step is enormous, we rely on sampling actions rather than enumerating them^{21,22}. The value provides an estimate of the distribution z of returns (cumulative reward) starting from the current state \({{\mathscr{S}}}_{t}\). With the above reward scheme, the distribution z models the agent’s belief about the rank of the tensor \({{\mathscr{S}}}_{t}\). To play a game, AlphaTensor starts from the target tensor (\({{\mathscr{T}}}_{n}\)) and uses the MCTS planner at each step to choose the next action. Finished games are used as feedback to the network to improve the network parameters.
Overcoming the challenges posed by TensorGame—namely, an enormous action space, and game states described by large 3D tensors representing an abstract mathematical operation—requires multiple advances. All these components, described briefly below, substantially improve the overall performance over a plain AlphaZero agent (see Methods and Supplementary Information for details).
Neural network architecture
We propose a transformerbased^{23} architecture that incorporates inductive biases for tensor inputs. We first project the S × S × S input tensor into three S × S grids of feature vectors by using linear layers applied to the three cyclic transpositions of the tensor. The main part of the model comprises a sequence of attention operations, each applied to a set of features belonging to a pair of grids (Extended Data Figs. 3 and 4). This generalizes axial attention^{24} to multiple grids, and is both more efficient and yields better results than naive selfattention. The proposed architecture, which disregards the order of rows and columns in the grids, is inspired by the invariance of the tensor rank to slice reordering. The final feature representation of the three matrices is passed both to the policy head (an autoregressive model) and the value head (a multilayer perceptron).
Synthetic demonstrations
Although tensor decomposition is NPhard, the inverse task of constructing the tensor from its rankone factors is elementary. Hence, we generate a large dataset of tensorfactorization pairs (synthetic demonstrations) by first sampling factors \({\{({{\bf{u}}}^{(r)},{{\bf{v}}}^{(r)},{{\bf{w}}}^{(r)})\}}_{r=1}^{R}\) at random, and then constructing the tensor \({\mathscr{D}}={\sum }_{r=1}^{R}{{\bf{u}}}^{(r)}\otimes {{\bf{v}}}^{(r)}\otimes {{\bf{w}}}^{(r)}\). We train the network on a mixture of supervised loss (that is, to imitate synthetic demonstrations) and standard reinforcement learning loss (that is, learning to decompose a target tensor \({{\mathscr{T}}}_{n}\)) (Fig. 2). This mixed training strategy—training on the target tensor and random tensors— substantially outperforms each training strategy separately. This is despite randomly generated tensors having different properties from the target tensors.
Change of basis
\({{\mathscr{T}}}_{n}\) (Fig. 1a) is the tensor representing the matrix multiplication bilinear operation in the canonical basis. The same bilinear operation can be expressed in other bases, resulting in other tensors. These different tensors are equivalent: they have the same rank, and decompositions obtained in a custom basis can be mapped to the canonical basis, hence obtaining a practical algorithm of the form in Algorithm 1. We leverage this observation by sampling a random change of basis at the beginning of every game, applying it to \({{\mathscr{T}}}_{n}\), and letting AlphaTensor play the game in that basis (Fig. 2). This crucial step injects diversity into the games played by the agent.
Data augmentation
From every played game, we can extract additional tensorfactorization pairs for training the network. Specifically, as factorizations are order invariant (owing to summation), we build an additional tensorfactorization training pair by swapping a random action with the last action from each finished game.
Algorithm discovery results
Discovery of matrix multiplication algorithms
We train a single AlphaTensor agent to find matrix multiplication algorithms for matrix sizes n × m with m × p, where n, m, p ≤ 5. At the beginning of each game, we sample uniformly a triplet (n, m, p) and train AlphaTensor to decompose the tensor \({{\mathscr{T}}}_{n,m,p}\). Although we consider tensors of fixed size (\({{\mathscr{T}}}_{n,m,p}\) has size nm × mp × pn), the discovered algorithms can be applied recursively to multiply matrices of arbitrary size. We use AlphaTensor to find matrix multiplication algorithms over different arithmetics—namely, modular arithmetic (that is, multiplying matrices in the quotient ring \({{\mathbb{Z}}}_{2}\)), and standard arithmetic (that is, multiplying matrices in \({\mathbb{R}}\)).
Figure 3 (left) shows the complexity (that is, rank) of the algorithms discovered by AlphaTensor. AlphaTensor rediscovers the best algorithms known for multiplying matrices (for example, Strassen’s^{2} and Laderman’s^{15} algorithms). More importantly, AlphaTensor improves over the best algorithms known for several matrix sizes. In particular, AlphaTensor finds an algorithm for multiplying 4 × 4 matrices using 47 multiplications in \({{\mathbb{Z}}}_{2}\), thereby outperforming Strassen’s twolevel algorithm^{2}, which involves 7^{2} = 49 multiplications. By applying this algorithm recursively, one obtains a practical matrix multiplication algorithm in \({{\mathbb{Z}}}_{2}\) with complexity \({\mathcal{O}}({N}^{2.778})\). Moreover, AlphaTensor discovers efficient algorithms for multiplying matrices in standard arithmetic; for example, AlphaTensor finds a rank76 decomposition of \({{\mathscr{T}}}_{4,5,5}\), improving over the previous stateoftheart complexity of 80 multiplications. See Extended Data Figs. 1 and 2 for examples.
AlphaTensor generates a large database of matrix multiplication algorithms—up to thousands of algorithms for each size. We exploit this rich space of algorithms by combining them recursively, with the aim of decomposing larger matrix multiplication tensors. We refer to refs. ^{25,26} and Appendix H in Supplementary Information for more details. Using this approach, we improve over the stateoftheart results for more than 70 matrix multiplication tensors (with n, m, p ≤ 12). See Fig. 3 (right) and Extended Data Table 1 for the results.
A crucial aspect of AlphaTensor is its ability to learn to transfer knowledge between targets (despite providing no prior knowledge on their relationship). By training one agent to decompose various tensors, AlphaTensor shares learned strategies among these, thereby improving the overall performance (see Supplementary Information for analysis). Finally, it is noted that AlphaTensor scales beyond current computational approaches for decomposing tensors. For example, to our knowledge, no previous approach was able to handle \({{\mathscr{T}}}_{4}\), which has an action space 10^{10} times larger than \({{\mathscr{T}}}_{3}\). Our agent goes beyond this limit, discovering decompositions matching or surpassing stateoftheart for large tensors such as \({{\mathscr{T}}}_{5}\).
Analysing the symmetries of matrix multiplication algorithms
From a mathematical standpoint, the diverse algorithms discovered by AlphaTensor show that the space is richer than previously known. For example, while the only known rank49 factorization decomposing \({{\mathscr{T}}}_{4}={{\mathscr{T}}}_{2}\otimes {{\mathscr{T}}}_{2}\) before this paper conforms to the product structure (that is, it uses the factorization of \({{\mathscr{T}}}_{2}\) twice, which we refer to as Strassensquare^{2}), AlphaTensor finds more than 14,000 nonequivalent factorizations (with standard arithmetic) that depart from this scheme, and have different properties (such as matrix ranks and sparsity—see Supplementary Information). By nonequivalent, we mean that it is not possible to obtain one from another by applying a symmetry transformation (such as permuting the factors). Such properties of matrix multiplication tensors are of great interest, as these tensors represent fundamental objects in algebraic complexity theory^{3,5,7}. The study of matrix multiplication symmetries can also provide insight into the asymptotic complexity of matrix multiplication^{5}. By exploring this rich space of algorithms, we believe that AlphaTensor will be useful for generating results and guiding mathematical research. See Supplementary Information for proofs and details on the symmetries of factorizations.
Beyond standard matrix multiplication
Tensors can represent any bilinear operation, such as structured matrix multiplication, polynomial multiplication or more custom bilinear operations used in machine learning^{27,28}. We demonstrate here a usecase where AlphaTensor finds a stateoftheart algorithm for multiplying an n x n skewsymmetric matrix with a vector of length n. Figure 4a shows the obtained decompositions for small instance sizes n. We observe a pattern that we generalize to arbitrary n, and prove that this yields a general algorithm for the skewsymmetric matrixvector product (Fig. 4b). This algorithm, which uses \((n1)(n+2)/2 \sim \frac{1}{2}{n}^{2}\) multiplications (where ∼ indicates asymptotic similarity), outperforms the previously known algorithms using asymptotically n^{2} multiplications^{29}, and is asymptotically optimal. See Supplementary Information for a proof, and for another usecase showing AlphaTensor’s ability to rediscover the Fourier basis (see also Extended Data Table 2). This shows that AlphaTensor can be applied to custom bilinear operations, and yield efficient algorithms leveraging the problem structure.
Rapid tailored algorithm discovery
We show a usecase where AlphaTensor finds practically efficient matrix multiplication algorithms, tailored to specific hardware, with zero prior hardware knowledge. To do so, we modify the reward of AlphaTensor: we provide an additional reward at the terminal state (after the agent found a correct algorithm) equal to the negative of the runtime of the algorithm when benchmarked on the target hardware. That is, we set \({r}_{t}^{{\prime} }={r}_{t}+\lambda {b}_{t}\), where r_{t} is the reward scheme described in ‘DRL for algorithm discovery’, b_{t} is the benchmarking reward (nonzero only at the terminal state) and λ is a userspecified coefficient. Aside from the different reward, the exact same formulation of TensorGame is used.
We train AlphaTensor to search for efficient algorithms to multiply 4 × 4 block matrices, and focus on square matrix multiplication of size 8,192 (each block is hence of size 2,048) to define the benchmarking reward. AlphaTensor searches for the optimal way of combining the 16 square blocks of the input matrices on the considered hardware. We do not apply the 4 × 4 algorithm recursively, to leverage the efficient implementation of matrix multiplication on moderatesize matrices (2,048 × 2,048 in this case). We study two hardware devices commonly used in machine learning and scientific computing: an Nvidia V100 graphics processing unit (GPU) and a Google tensor processing unit (TPU) v2. The factorization obtained by AlphaTensor is transformed into JAX^{30} code, which is compiled (just in time) before benchmarking.
Figure 5a,b shows the efficiency of the AlphaTensordiscovered algorithms on the GPU and the TPU, respectively. AlphaTensor discovers algorithms that outperform the Strassensquare algorithm, which is a fast algorithm for large square matrices^{31,32}. Although the discovered algorithm has the same theoretical complexity as Strassensquare, it outperforms it in practice, as it is optimized for the considered hardware. Interestingly, AlphaTensor finds algorithms with a larger number of additions compared with Strassensquare (or equivalently, denser decompositions), but the discovered algorithms generate individual operations that can be efficiently fused by the specific XLA^{33} grouping procedure and thus are more tailored towards the compiler stack we use. The algorithms found by AlphaTensor also provide gains on matrix sizes larger than what they were optimized for. Finally, Fig. 5c shows the importance of tailoring to particular hardware, as algorithms optimized for one hardware do not perform as well on other hardware.
Discussion
Trained from scratch, AlphaTensor discovers matrix multiplication algorithms that are more efficient than existing human and computerdesigned algorithms. Despite improving over known algorithms, we note that a limitation of AlphaTensor is the need to predefine a set of potential factor entries F, which discretizes the search space but can possibly lead to missing out on efficient algorithms. An interesting direction for future research is to adapt AlphaTensor to search for F. One important strength of AlphaTensor is its flexibility to support complex stochastic and nondifferentiable rewards (from the tensor rank to practical efficiency on specific hardware), in addition to finding algorithms for custom operations in a wide variety of spaces (such as finite fields). We believe this will spur applications of AlphaTensor towards designing algorithms that optimize metrics that we did not consider here, such as numerical stability or energy usage.
The discovery of matrix multiplication algorithms has farreaching implications, as matrix multiplication sits at the core of many computational tasks, such as matrix inversion, computing the determinant and solving linear systems, to name a few^{7}. We also note that our methodology can be extended to tackle related primitive mathematical problems, such as computing other notions of rank (for example, border rank—see Supplementary Information), and NPhard matrix factorization problems (for example, nonnegative factorization). By tackling a core NPhard computational problem in mathematics using DRL—the computation of tensor ranks—AlphaTensor demonstrates the viability of DRL in addressing difficult mathematical problems, and potentially assisting mathematicians in discoveries.
Methods
TensorGame
TensorGame is played as follows. The start position \({{\mathscr{S}}}_{0}\) of the game corresponds to the tensor \({\mathscr{T}}\) representing the bilinear operation of interest, expressed in some basis. In each step t of the game, the player writes down three vectors (u^{(t)}, v^{(t)}, w^{(t)}), which specify the rank1 tensor u^{(t)} ⊗ v^{(t)} ⊗ w^{(t)}, and the state of the game is updated by subtracting the newly written down factor:
The game ends when the state reaches the zero tensor, \({{\mathscr{S}}}_{R}={\bf{0}}\). This means that the factors written down throughout the game form a factorization of the start tensor \({{\mathscr{S}}}_{0}\), that is, \({{\mathscr{S}}}_{0}={\sum }_{t=1}^{R}{{\bf{u}}}^{(t)}\otimes {{\bf{v}}}^{(t)}\otimes {{\bf{w}}}^{(t)}\). This factorization is then scored. For example, when optimizing for asymptotic time complexity the score is −R, and when optimizing for practical runtime the algorithm corresponding to the factorization \({\{({{\bf{u}}}^{(t)},{{\bf{v}}}^{(t)},{{\bf{w}}}^{(t)})\}}_{t=1}^{R}\) is constructed (see Algorithm 1) and then benchmarked on the fly (see Supplementary Information).
In practice, we also impose a limit R_{limit} on the maximum number of moves in the game, so that a weak player is not stuck in unnecessarily (or even infinitely) long games. When a game ends because it has run out of moves, a penalty score is given so that it is never advantageous to deliberately exhaust the move limit. For example, when optimizing for asymptotic time complexity, this penalty is derived from an upper bound on the tensor rank of the final residual tensor \({{\mathscr{S}}}_{{R}_{\text{limit}}}\). This upper bound on the tensor rank is obtained by summing the matrix ranks of the slices of the tensor.
TensorGame over rings
We say that the decomposition of \({{\mathscr{T}}}_{n}\) in equation (1) is in a ring \({\mathcal{E}}\) (defining the arithmetic operations) if each of the factors u^{(t)}, v^{(t)} and w^{(t)} has entries belonging to the set \({\mathcal{E}}\), and additions and multiplications are interpreted according to \({\mathcal{E}}\). The tensor rank depends, in general, on the ring. At each step of TensorGame, the additions and multiplications in equation (2) are interpreted in \({\mathcal{E}}\). For example, when working in \({{\mathbb{Z}}}_{2}\), (in this case, the factors u^{(t)}, v^{(t)} and w^{(t)} live in F = {0, 1}), a modulo 2 operation is applied after each state update (equation (2)).
We note that integervalued decompositions u^{(t)}, v^{(t)} and w^{(t)} lead to decompositions in arbitrary rings \({\mathcal{E}}\). Hence, provided F only contains integers, algorithms we find in standard arithmetic apply more generally to any ring.
AlphaTensor
AlphaTensor builds on AlphaZero^{1} and its extension Sampled AlphaZero^{21}, combining a deep neural network with a samplebased MCTS search algorithm.
The deep neural network, f_{θ}(s) = (π, z) parameterized by θ, takes as input the current state s of the game and outputs a probability distribution π(⋅∣s) over actions and z(⋅∣s) over returns (sum of future rewards) G. The parameters θ of the deep neural network are trained by reinforcement learning from selfplay games and synthetic demonstrations. Selfplay games are played by actors, running a samplebased MCTS search at every state s_{t} encountered in the game. The MCTS search returns an improved probability distribution over moves from which an action a_{t} is selected and applied to the environment. The subtree under a_{t} is reused for the subsequent search at s_{t+1}. At the end of the game, a return G is obtained and the trajectory is sent to the learner to update the neural network parameters θ. The distribution over returns z(⋅∣s_{t}) is learned through distributional reinforcement learning using the quantile regression distributional loss^{34}, and the network policy π(⋅∣s_{t}) is updated using a Kullback–Leibler divergence loss, to maximize its similarity to the search policy for selfplay games or to the next action for synthetic demonstrations. We use the Adam optimizer^{35} with decoupled weight decay^{36} to optimize the parameters θ of the neural network.
Samplebased MCTS search
The samplebased MCTS search is very similar to the one described in Sampled AlphaZero. Specifically, the search consists of a series of simulated trajectories of TensorGame that are aggregated in a tree. The search tree therefore consists of nodes representing states and edges representing actions. Each stateaction pair (s, a) stores a set of statistics \(N(s,a),Q(s,a),\hat{\pi }(s,a)\), where N(s, a) is the visit count, Q(s, a) is the action value and \(\hat{\pi }(s,a)\) is the empirical policy probability. Each simulation traverses the tree from the root state s_{0} until a leaf state s_{L} is reached by recursively selecting in each state s an action a that has not been frequently explored, has high empirical policy probability and high value. Concretely, actions within the tree are selected by maximizing over the probabilistic upper confidence tree bound^{21,37}
where c(s) is an exploration factor controlling the influence of the empirical policy \(\hat{\pi }(s,a)\) relative to the values Q(s, a) as nodes are visited more often. In addition, a transposition table is used to recombine different action sequences if they reach the exact same tensor. This can happen particularly often in TensorGame as actions are commutative. Finally, when a leaf state s_{L} is reached, it is evaluated by the neural network, which returns K actions {a_{i}} sampled from π(a∣s_{L}), alongside the empirical distribution \(\hat{\pi }(a {s}_{{\rm{L}}})=\frac{1}{K}{\sum }_{i}{\delta }_{a,{a}_{i}}\) and a value v(s_{L}) constructed from z(⋅∣s_{L}). Differently from AlphaZero and Sampled AlphaZero, we chose v not to be the mean of the distribution of returns z(⋅∣s_{L}) as is usual in most reinforcement learning agents, but instead to be a riskseeking value, leveraging the facts that TensorGame is a deterministic environment and that we are primarily interested in finding the best trajectory possible. The visit counts and values on the simulated trajectory are then updated in a backward pass as in Sampled AlphaZero.
Policy improvement
After simulating N(s) trajectories from state s using MCTS, the normalized visit counts of the actions at the root of the search tree N(s, a)/N(s) form a samplebased improved policy. Differently from AlphaZero and Sampled AlphaZero, we use an adaptive temperature scheme to smooth the normalized visit counts distribution as some states can accumulate an order of magnitude more visits than others because of subtree reuse and transposition table. Concretely, we define the improved policy as \({\mathcal{I}}\hat{\pi }(s,a)={N}^{1/\tau (s)}(s,a)/{\sum }_{b}{N}^{1/\tau (s)}(s,b)\) where \(\tau (s)=\log N(s)/\log \bar{N}\,{\rm{if}}\,N > \bar{N}\) and 1 otherwise, with \(\bar{N}\) being a hyperparameter. For training, we use \({\mathcal{I}}\hat{\pi }\) directly as a target for the network policy π. For acting, we additionally discard all actions that have a value lower than the value of the most visited action, and sample proportionally to \({\mathcal{I}}\hat{\pi }\) among those remaining highvalue actions.
Learning one agent for multiple target tensors
We train a single agent to decompose the different tensors \({{\mathscr{T}}}_{n,m,p}\) in a given arithmetic (standard or modular). As the network works with fixedsize inputs, we pad all tensors (with zeros) to the size of the largest tensor we consider (\({{\mathscr{T}}}_{5}\), of size 25 × 25 × 25). At the beginning of each game, we sample uniformly at random a target \({{\mathscr{T}}}_{n,m,p}\), and play TensorGame. Training a single agent on different targets leads to better results thanks to the transfer between targets. All our results reported in Fig. 3 are obtained using multiple runs of this multitarget setting. We also train a single agent to decompose tensors in both arithmetics. Owing to learned transfer between the two arithmetics, this agent discovers a different distribution of algorithms (of the same ranks) in standard arithmetic than the agent trained on standard arithmetic only, thereby increasing the overall diversity of discovered algorithms.
Synthetic demonstrations
The synthetic demonstrations buffer contains tensorfactorization pairs, where the factorizations \({\{({{\bf{u}}}^{(r)},{{\bf{v}}}^{(r)},{{\bf{w}}}^{(r)})\}}_{r=1}^{R}\) are first generated at random, after which the tensor \({\mathscr{D}}={\sum }_{r=1}^{R}{{\bf{u}}}^{(r)}\otimes {{\bf{v}}}^{(r)}\otimes {{\bf{w}}}^{(r)}\) is formed. We create a dataset containing 5 million such tensorfactorization pairs. Each element in the factors is sampled independently and identically distributed (i.i.d.) from a given categorical distribution over F (all possible values that can be taken). We discarded instances whose decompositions were clearly suboptimal (contained a factor with u = 0, v = 0, or w = 0).
In addition to these synthetic demonstrations, we further add to the demonstration buffer previous games that have achieved large scores to reinforce the good moves made by the agent in these games.
Change of basis
The rank of a bilinear operation does not depend on the basis in which the tensor representing it is expressed, and for any invertible matrices A, B and C we have \({\rm{Rank}}\,({\mathscr{T}})={\rm{Rank}}\,({{\mathscr{T}}}^{({\bf{A}},{\bf{B}},{\bf{C}})})\), where \({{\mathscr{T}}}^{({\bf{A}},{\bf{B}},{\bf{C}})}\) is the tensor after change of basis given by
Hence, exhibiting a rankR decomposition of the matrix multiplication tensor \({{\mathscr{T}}}_{n}\) expressed in any basis proves that the product of two n × n matrices can be computed using R scalar multiplications. Moreover, it is straightforward to convert such a rankR decomposition into a rankR decomposition in the canonical basis, thus yielding a practical algorithm of the form shown in Algorithm 1. We leverage this observation by expressing the matrix multiplication tensor \({{\mathscr{T}}}_{n}\) in a large number of randomly generated bases (typically 100,000) in addition to the canonical basis, and letting AlphaTensor play games in all bases in parallel.
This approach has three appealing properties: (1) it provides a natural exploration mechanism as playing games in different bases automatically injects diversity into the games played by the agent; (2) it exploits properties of the problem as the agent need not succeed in all bases—it is sufficient to find a lowrank decomposition in any of the bases; (3) it enlarges coverage of the algorithm space because a decomposition with entries in a finite set F = {−2, −1, 0, 1, 2} found in a different basis need not have entries in the same set when converted back into the canonical basis.
In full generality, a basis change for a 3D tensor of size S × S × S is specified by three invertible S × S matrices A, B and C. However, in our procedure, we sample bases at random and impose two restrictions: (1) A = B = C, as this performed better in early experiments, and (2) unimodularity (\(\det {\bf{A}}\in \{1,+1\}\)), which ensures that after converting an integral factorization into the canonical basis it still contains integer entries only (this is for representational convenience and numerical stability of the resulting algorithm). See Supplementary Information for the exact algorithm.
Signed permutations
In addition to playing (and training on) games in different bases, we also utilize a data augmentation mechanism whenever the neural network is queried in a new MCTS node. At acting time, when the network is queried, we transform the input tensor by applying a change of basis—where the change of basis matrix is set to a random signed permutation. We then query the network on this transformed input tensor, and finally invert the transformation in the network’s policy predictions. Although this data augmentation procedure can be applied with any generic change of basis matrix (that is, it is not restricted to signed permutation matrices), we use signed permutations mainly for computational efficiency. At training time, whenever the neural network is trained on an (input, policy targets, value target) triplet (Fig. 2), we apply a randomly chosen signed permutation to both the input and the policy targets, and train the network on this transformed triplet. In practice, we sample 100 signed permutations at the beginning of an experiment, and use them thereafter.
Action canonicalization
For any λ_{1}, λ_{2}, λ_{3} ∈ {−1, +1} such that λ_{1}λ_{2}λ_{3} = 1, the actions (λ_{1}u, λ_{2}v, λ_{3}w) and (u, v, w) are equivalent because they lead to the same rankone tensor (λ_{1}u) ⊗ (λ_{2}v) ⊗ (λ_{3}w) = u ⊗ v ⊗ w. To prevent the network from wasting capacity on predicting multiple equivalent actions, during training we always present targets (u, v, w) for the policy head in a canonical form, defined as having the first nonzero element of u and the first nonzero element of v strictly positive. This is well defined because u or v cannot be all zeros (if they are to be part of a minimal rank decomposition), and for any (u, v, w) there are unique λ_{1}, λ_{2}, λ_{3} ∈ {−1, +1} (with λ_{1}λ_{2}λ_{3} = 1) that transform it into canonical form. In case the network predicts multiple equivalent actions anyway, we merge them together (summing their empirical policy probabilities) before inserting them into the MCTS tree.
Training regime
We train AlphaTensor on a TPU v3, with a total batch size of 2,048. We use 64 TPU cores, and train for 600,000 iterations. On the actor side, the games are played on standalone TPU v4, and we use 1,600 actors. In practice, the procedure takes a week to converge.
Neural network
The architecture is composed of a torso, followed by a policy head that predicts a distribution over actions, and a value head that predicts a distribution of the returns from the current state (see Extended Data Fig. 3).
Input
The input to the network contains all the relevant information of the current state and is composed of a list of tensors and a list of scalars. The most important piece of information is the current 3D tensor \({{\mathscr{S}}}_{t}\) of size S × S × S. (For simplicity, in the description here we assume that all the three dimensions of the tensor are equal in size. The generalization to different sizes is straightforward.) In addition, the model is given access to the last h actions (h being a hyperparameter usually set to 7), represented as h rank1 tensors that are concatenated to the input. The list of scalars includes the time index t of the current action (where 0 ≤ t < R_{limit}).
Torso
The torso of the network is in charge of mapping both scalars and tensors from the input to a representation that is useful to both policy and value heads. Its architecture is based on a modification of transformers^{23}, and its main signature is that it operates over three S × S grids projected from the S × S × S input tensors. Each grid represents two out of the three modes of the tensor. Defining the modes of the tensor as \({\mathcal{U}},{\mathcal{V}},{\mathcal{W}}\), the rows and columns of the first grid are associated to \({\mathcal{U}}\) and \({\mathcal{V}}\), respectively, the rows and columns of the second grid are associated to \({\mathcal{W}}\) and \({\mathcal{U}}\), and the rows and columns of the third grid are associated to \({\mathcal{V}}\) and \({\mathcal{W}}\). Each element of each grid is a feature vector, and its initial value is given by the elements of the input tensors along the grid’s missing mode. These feature vectors are enriched by concatenating an S × S × 1 linear projection from the scalars. This is followed by a linear layer projecting these feature vectors into a 512dimensional space.
The rest of the torso is a sequence of attentionbased blocks with the objective of propagating information between the three grids. Each of those blocks has three stages, one for every pair of grids. In each stage, the grids involved are concatenated, and axial attention^{24} is performed over the columns. It is noted that in each stage we perform in parallel S selfattention operations of 2S elements in each. The representation sent to the policy head corresponds to the 3S^{2 }512dimensional feature vectors produced by the last layer of the torso. A detailed description of the structure of the torso is specified in Extended Data Fig. 4 (top) and Appendix A.1.1 in Supplementary Information.
Policy head
The policy head uses the transformer architecture^{23} to model an autoregressive policy. Factors are decomposed into k tokens of dimensionality d such that k × d = 3S. The transformer conditions on the tokens already generated and crossattends to the features produced by the torso. At training time, we use teacherforcing, that is, the ground truth actions are decomposed into tokens and taken as inputs into the causal transformer in such a way that the prediction of a token depends only on the previous tokens. At inference time, K actions are sampled from the head. The feature representation before the last linear layer of the initial step (that is, the only step that is not conditioned on the ground truth) is used as an input to the value head, described below. Details of the architecture are presented in Extended Data Fig. 4 (centre) and Appendix A.1.2 in Supplementary Information.
Value head
The value head is composed of a fourlayer multilayer perceptron whose last layer produces q outputs corresponding to the \(\frac{1}{2q},\frac{3}{2q},\ldots \frac{2q1}{2q}\) quantiles. In this way, the value head predicts the distribution of returns from this state in the form of values predicted for the aforementioned quantiles^{34}. At inference time, we encourage the agent to be riskseeking by using the average of the predicted values for quantiles over 75%. A detailed description of the value head is presented in Extended Data Fig. 4 (bottom) and Appendix A.1.3 in Supplementary Information.
Related work
The quest for efficient matrix multiplication algorithms started with Strassen’s breakthrough in ref. ^{2}, which showed that one can multiply 2 × 2 matrices using 7 scalar multiplications, leading to an algorithm of complexity \({\mathcal{O}}({n}^{2.81})\). This led to the development of a very active field of mathematics attracting worldwide interest, which studies the asymptotic complexity of matrix multiplication (see refs. ^{3,4,5,6}). So far, the best known complexity for matrix multiplication is \({\mathcal{O}}({n}^{2.37286})\) (ref. ^{12}), which improves over ref. ^{11}, and builds on top of fundamental results in the field^{8,9,10}. However, this does not yield practical algorithms, as such approaches become advantageous only for astronomical matrix sizes. Hence, a significant body of work aims at exhibiting explicit factorizations of matrix multiplication tensors, as these factorizations provide practical algorithms. After Strassen’s breakthrough showing that \(\text{rank}\,({{\mathscr{T}}}_{2})\le 7\), efficient algorithms for larger matrix sizes were found^{15,16,18,26,38}. Most notably, Laderman showed in ref. ^{15} that 3 × 3 matrix multiplications can be performed with 23 scalar multiplications. In addition to providing individual lowrank factorizations, an important research direction aims at understanding the space of matrix multiplication algorithms—as opposed to exhibiting individual lowrank factorizations—by studying the symmetry groups and diversity of factorizations (see ref. ^{5} and references therein). For example, the symmetries of 2 × 2 matrix multiplication were studied in refs. ^{39,40,41,42}, where Strassen’s algorithm was shown to be essentially unique. The case of 3 × 3 was studied in ref. ^{43}, whereas a symmetric factorization for all n is provided in ref. ^{44}.
On the computational front, continuous optimization has been the main workhorse for decomposing tensors^{17,45,46}, and in particular matrix multiplication tensors. Such continuous optimization procedures (for example, alternating least squares), however, yield approximate solutions, which correspond to inexact matrix multiplication algorithms with floating point operations. To circumvent this issue, regularization procedures have been proposed, such as ref. ^{18}, to extract exact decompositions. Unfortunately, such approaches often require substantial human intervention and expertise to decompose large tensors. A different line of attack was explored in refs. ^{47,48}, based on learning the continuous weights of a twolayer network that mimics the structure of the matrix multiplication operation. This method, which is trained through supervised learning of matrix multiplication examples, finds approximate solutions to 2 × 2 and 3 × 3 matrix multiplications. In ref. ^{48}, a quantization procedure is further used to obtain an exact decomposition for 2 × 2. Unlike continuous optimizationbased approaches, AlphaTensor directly produces algorithms from the desired set of valid algorithms, and is flexible in that it allows us to optimize a wide range of (even nondifferentiable) objectives. This unlocks tackling broader settings (for example, optimization in finite fields, optimization of runtime), as well as larger problems (for example, \({{\mathscr{T}}}_{4}\) and \({{\mathscr{T}}}_{5}\)) than those previously considered. Different from continuous optimization, a boolean satisfiability (SAT) based formulation of the problem of decomposing 3 × 3 matrix multiplication was recently proposed in ref. ^{20}, which adds thousands of new decompositions of rank 23 to the list of known 3 × 3 factorizations. The approach relies on a stateoftheart SAT solving procedure, where several assumptions and simplifications are made on the factorizations to reduce the search space. As is, this approach is, however, unlikely to scale to larger tensors, as the search space grows very quickly with the size.
On the practical implementation front, ref. ^{31} proposed several ideas to speed up implementation of fast matrix multiplication algorithms on central processing units (CPUs). Different fast algorithms are then compared and benchmarked, and the potential speedup of such algorithms is shown against standard multiplication. Other works focused on getting the maximal performance out of a particular fast matrix multiplication algorithm (Strassen’s algorithm with one or two levels of recursion) on a CPU^{32} or a GPU^{49}. These works show that, despite popular belief, such algorithms are of practical value. We see writing a custom lowlevel implementation of a given algorithm to be distinct from the focus of this paper—developing new efficient algorithms—and we believe that the algorithms we discovered can further benefit from a more efficient implementation by experts.
Beyond matrix multiplication and bilinear operations, a growing amount of research studies the use of optimization and machine learning to improve the efficiency of computational operations. There are three levels of abstractions at which this can be done: (1) in the hardware design, for example, chip floor planning^{50}, (2) at the hardware–software interface, for example, program superoptimization of a reference implementation for specific hardware^{51}, and (3) on the algorithmic level, for example, program induction^{52}, algorithm selection^{53} or metalearning^{54}. Our work focuses on the algorithmic level of abstraction, although AlphaTensor is also flexible to discover efficient algorithms for specific hardware. Different from previous works, we focus on discovering matrix multiplication algorithms that are provably correct, without requiring initial reference implementations. We conclude by relating our work broadly to existing reinforcement learning methods for scientific discovery. Within mathematics, reinforcement learning was applied, for example, to theorem proving^{55,56,57,58}, and to finding counterexamples refuting conjectures in combinatorics and graph theory^{59}. Reinforcement learning was further shown to be useful in many areas in science, such as molecular design^{60,61} and synthesis^{62} and optimizing quantum dynamics^{63}.
Data availability
The data used to train the system were generated synthetically according to the procedures explained in the paper. The algorithms discovered by AlphaTensor are available for download at https://github.com/deepmind/alphatensor.
Code availability
An interactive notebook with code to check the nonequivalence of algorithms is provided. Moreover, the fast algorithms from the ‘Algorithm discovery results’ section on a GPU and a TPU are provided. These are available for download at https://github.com/deepmind/alphatensor. A full description of the AlphaZero algorithm that this work is based on is available in ref. ^{1}, and the specific neural network architecture we use is described using pseudocode in the Supplementary Information.
References
Silver, D. et al. A general reinforcement learning algorithm that masters chess, shogi, and Go through selfplay. Science 362, 1140–1144 (2018).
Strassen, V. Gaussian elimination is not optimal. Numer. Math. 13, 354–356 (1969).
Bürgisser, P., Clausen, M. & Shokrollahi, A. Algebraic Complexity Theory Vol. 315 (Springer Science & Business Media, 2013).
Bläser, M. Fast matrix multiplication. Theory Comput. 5, 1–60 (2013).
Landsberg, J. M. Geometry and Complexity Theory 169 (Cambridge Univ. Press, 2017).
Pan, V. Y. Fast feasible and unfeasible matrix multiplication. Preprint at https://arxiv.org/abs/1804.04102 (2018).
Lim, L.H. Tensors in computations. Acta Numer. 30, 555–764 (2021).
Schönhage, A. Partial and total matrix multiplication. SIAM J. Comput. 10, 434–455 (1981).
Coppersmith, D. & Winograd, S. Matrix multiplication via arithmetic progressions. In ACM Symposium on Theory of Computing 1–6 (ACM, 1987).
Strassen, V. The asymptotic spectrum of tensors and the exponent of matrix multiplication. In 27th Annual Symposium on Foundations of Computer Science 49–54 (IEEE, 1986).
Le Gall, F. Powers of tensors and fast matrix multiplication. In International Symposium on Symbolic and Algebraic Computation 296–303 (ACM, 2014).
Alman, J. & Williams, V. V. A refined laser method and faster matrix multiplication. In ACMSIAM Symposium on Discrete Algorithms 522–539 (SIAM, 2021).
Gauss, C. F. Theoria Motus Corporum Coelestium in Sectionibus Conicis Solum Ambientium (Perthes and Besser, 1809).
Hillar, C. J. & Lim, L.H. Most tensor problems are NPhard. J. ACM 60, 1–39 (2013).
Laderman, J. D. A noncommutative algorithm for multiplying 3 × 3 matrices using 23 multiplications. Bull. Am. Math. Soc. 82, 126–128 (1976).
Hopcroft, J. E. & Kerr, L. R. On minimizing the number of multiplications necessary for matrix multiplication. SIAM J. Appl. Math. 20, 30–36 (1971).
Vervliet, N., Debals, O., Sorber, L., Van Barel, M. & De Lathauwer, L. Tensorlab 3.0 (2016); https://www.tensorlab.net/
Smirnov, A. V. The bilinear complexity and practical algorithms for matrix multiplication. Comput. Math. Math. Phys. 53, 1781–1795 (2013).
Sedoglavic, A. & Smirnov, A. V. The tensor rank of 5x5 matrices multiplication is bounded by 98 and its border rank by 89. In Proc. 2021 on International Symposium on Symbolic and Algebraic Computation 345–351 (ACM, 2021).
Heule, M. J., Kauers, M. & Seidl, M. New ways to multiply 3 × 3matrices. J. Symb. Comput. 104, 899–916 (2021).
Hubert, T. et al. Learning and planning in complex action spaces. In International Conference on Machine Learning 4476–4486 (PMLR, 2021).
Zhang, W. & Dietterich, T. G. A reinforcement learning approach to jobshop scheduling. In International Joint Conferences on Artificial Intelligence Vol. 95, 1114–1120 (Morgan Kaufmann Publishers, 1995).
Vaswani, A. Attention is all you need. In International Conference on Neural Information Processing Systems Vol 30, 5998–6008 (Curran Associates, 2017).
Ho, J., Kalchbrenner, N., Weissenborn, D. & Salimans, T. Axial attention in multidimensional transformers. Preprint at https://arxiv.org/abs/1912.12180 (2019).
Drevet, C.É., Islam, M. N. & Schost, É. Optimization techniques for small matrix multiplication. Theor. Comput. Sci. 412, 2219–2236 (2011).
Sedoglavic, A. A noncommutative algorithm for multiplying (7 × 7) matrices using 250 multiplications. Preprint at https://arxiv.org/abs/1712.07935 (2017).
Battaglia, P. W. et al. Relational inductive biases, deep learning, and graph networks. Preprint at https://arxiv.org/abs/1806.01261 (2018).
Balog, M., van Merriënboer, B., Moitra, S., Li, Y. & Tarlow, D. Fast training of sparse graph neural networks on dense hardware. Preprint at https://arxiv.org/abs/1906.11786 (2019).
Ye, K. & Lim, L.H. Fast structured matrix computations: tensor rank and Cohn–Umans method. Found. Comput. Math. 18, 45–95 (2018).
Bradbury, J. et al. JAX: composable transformations of Python+NumPy programs. GitHub http://github.com/google/jax (2018).
Benson, A. R. & Ballard, G. A framework for practical parallel fast matrix multiplication. ACM SIGPLAN Not. 50, 42–53 (2015).
Huang, J., Smith, T. M., Henry, G. M. & Van De Geijn, R. A. Strassen’s algorithm reloaded. In International Conference for High Performance Computing, Networking, Storage and Analysis 690–701 (IEEE, 2016).
Abadi, M. et al. Tensorflow: a system for largescale machine learning. In USENIX Symposium On Operating Systems Design And Implementation 265–283 (USENIX, 2016).
Dabney, W., Rowland, M., Bellemare, M. & Munos, R. Distributional reinforcement learning with quantile regression. In AAAI Conference on Artificial Intelligence Vol. 32, 2892–2901 (AAAI Press, 2018).
Kingma, D. P., & Ba, J. Adam: a method for stochastic optimization. In International Conference on Learning Representations (ICLR) (2015).
Loshchilov, I. & Hutter, F. Decoupled weight decay regularization. In International Conference on Learning Representations (ICLR) (2019).
Silver, D. et al. Mastering the game of Go with deep neural networks and tree search. Nature 529, 484–489 (2016).
Sedoglavic, A. A noncommutative algorithm for multiplying 5x5 matrices using 99 multiplications. Preprint at https://arxiv.org/abs/1707.06860 (2017).
de Groote, H. F. On varieties of optimal algorithms for the computation of bilinear mappings II. optimal algorithms for 2 × 2matrix multiplication. Theor. Comput. Sci. 7, 127–148 (1978).
Burichenko, V. P. On symmetries of the Strassen algorithm. Preprint at https://arxiv.org/abs/1408.6273 (2014).
Chiantini, L., Ikenmeyer, C., Landsberg, J. M. & Ottaviani, G. The geometry of rank decompositions of matrix multiplication I: 2 × 2 matrices. Exp. Math. 28, 322–327 (2019).
Grochow, J. A. & Moore, C. Designing Strassen’s algorithm. Preprint at https://arxiv.org/abs/1708.09398 (2017).
Ballard, G., Ikenmeyer, C., Landsberg, J. M. & Ryder, N. The geometry of rank decompositions of matrix multiplication II: 3 × 3 matrices. J. Pure Appl. Algebra 223, 3205–3224 (2019).
Grochow, J. A. & Moore, C. Matrix multiplication algorithms from group orbits. Preprint at https://arxiv.org/abs/1612.01527 (2016).
Kolda, T. G. & Bader, B. W. Tensor decompositions and applications. SIAM Rev. 51, 455–500 (2009).
Bernardi, A., Brachat, J., Comon, P. & Mourrain, B. General tensor decomposition, moment matrices and applications. J. Symb. Comput. 52, 51–71 (2013).
Elser, V. A network that learns Strassen multiplication. J. Mach. Learn. Res. 17, 3964–3976 (2016).
Tschannen, M., Khanna, A. & Anandkumar, A, StrassenNets: deep learning with a multiplication budget. In International Conference on Machine Learning 4985–4994 (PMLR, 2018).
Huang, J., Yu, C. D. & Geijn, R. A. V. D. Strassen’s algorithm reloaded on GPUs. ACM Trans. Math. Softw. 46, 1–22 (2020).
Mirhoseini, A. et al. A graph placement methodology for fast chip design. Nature 594, 207–212 (2021).
Bunel, R., Desmaison, A., Kohli, P., Torr, P. H. & Kumar, M. P. Learning to superoptimize programs. In International Conference on Learning Representations (ICLR) (2017).
Li, Y., Gimeno, F., Kohli, P. & Vinyals, O. Strong generalization and efficiency in neural programs. Preprint at https://arxiv.org/abs/2007.03629 (2020).
Lagoudakis, M. G. et al. Algorithm selection using reinforcement learning. In International Conference on Machine Learning 511–518 (Morgan Kaufmann Publishers, 2000).
Schmidhuber, J. Evolutionary Principles in SelfReferential Learning. On Learning now to Learn: The MetaMetaMeta...Hook. Diploma thesis, Technische Univ. Munchen (1987).
Kaliszyk, C., Urban, J., Michalewski, H. & Olšák, M. Reinforcement learning of theorem proving. In International Conference on Neural Information Processing Systems 8836–8847 (Curran Associates, 2018).
Piotrowski, B. & Urban, J. ATPboost: learning premise selection in binary setting with ATP feedback. In International Joint Conference on Automated Reasoning 566–574 (Springer, 2018).
Bansal, K., Loos, S., Rabe, M., Szegedy, C. & Wilcox, S. HOList: an environment for machine learning of higher order logic theorem proving. In International Conference on Machine Learning 454–463 (PMLR, 2019).
Zombori, Z., Urban, J. & Brown, C. E. Prolog technology reinforcement learning prover. In International Joint Conference on Automated Reasoning 489–507 (Springer, 2020).
Wagner, A. Z. Constructions in combinatorics via neural networks. Preprint at https://arxiv.org/abs/2104.14516 (2021).
Popova, M., Isayev, O. & Tropsha, A. Deep reinforcement learning for de novo drug design. Sci. Adv. 4, eaap7885 (2018).
Zhou, Z., Kearnes, S., Li, L., Zare, R. N. & Riley, P. Optimization of molecules via deep reinforcement learning. Sci. Rep. 9, 10752 (2019).
Segler, M. H., Preuss, M. & Waller, M. P. Planning chemical syntheses with deep neural networks and symbolic AI. Nature 555, 604–610 (2018).
Dalgaard, M., Motzoi, F., Sørensen, J. J. & Sherson, J. Global optimization of quantum dynamics with AlphaZero deep exploration. npj Quantum Inf. 6, 6 (2020).
Fast matrix multiplication algorithms catalogue. Université de Lille https://fmm.univlille.fr/ (2021).
Acknowledgements
We thank O. Fawzi, H. Fawzi, C. Ikenmeyer, J. Ellenberg, C. Umans and A. Wigderson for the inspiring discussions on the use of machine learning for maths; A. Davies, A. Gaunt, P. Mudigonda, R. Bunel and O. Ronneberger for their advice on early drafts of the paper; A. Ruderman, M. Bauer, R. Leblond, R. Kabra and B. Winckler for participating in a hackathon at the early stages of the project; D. Visentin, R. Tanburn and S. Noury for sharing their expertise on TPUs; P. Wang and R. Zhao for their help on benchmarking algorithms; G. Holland, A. Pierce, N. Lambert and C. Meyer for assistance coordinating the research; and our colleagues at DeepMind for encouragement and support.
Author information
Authors and Affiliations
Contributions
A.F. conceived the project, with support from B.R.P. and P.K.; T.H., A.H. and J.S. developed the initial AlphaZero codebase, and B.R.P., M. Balog, A.F., A.N., F.J.R.R. and G.S. developed an early supervised network prototype. A.H., T.H., B.R.P., M. Barekatain and J.S. designed the network architecture used in the paper. T.H., J.S., A.H., M. Barekatain, A.F., M. Balog and F.J.R.R. developed the tensor decomposition environment and data generation pipeline, and A.H., T.H., M. Barekatain, M. Balog, B.R.P., F.J.R.R. and A.N. analysed the experimental results and algorithms discovered by AlphaTensor. A.N., A.F. and T.H. developed the benchmarking pipeline and experiments, and B.R.P., F.J.R.R. and A.N. extended the approach to structured tensors. A.F., B.R.P., G.S. and A.N. proved the results in the paper. D.S., D.H. and P.K. contributed technical advice and ideas. A.F., M. Balog, B.R.P., F.J.R.R., A.N. and T.H. wrote the paper. These authors contributed equally, and are listed alphabetically by last name after the corresponding author: A.F., M. Balog, A.H., T.H., B.R.P. These authors contributed equally, and are listed alphabetically by last name: M. Barekatain, A.N., F.J.R.R., J.S. and G.S.
Corresponding author
Ethics declarations
Competing interests
The authors of the paper are planning to file a patent application relating to subject matter contained in this paper in the name of DeepMind Technologies Limited.
Peer review
Peer review information
Nature thanks Grey Ballard, Jordan Ellenberg, LekHeng Lim, Michael Littman and the other, anonymous, reviewer(s) for their contribution to the peer review of this work.
Additional information
Publisher’s note Springer Nature remains neutral with regard to jurisdictional claims in published maps and institutional affiliations.
Extended data figures and tables
Extended Data Fig. 1 Algorithm for multiplying 4 × 4 matrices in modular arithmetic (\({{\mathbb{Z}}}_{2}\)) with 47 multiplications.
This outperforms the twolevel Strassen’s algorithm, which involves 7^{2} = 49 multiplications.
Extended Data Fig. 2 Algorithm for multiplying 4 × 5 by 5 × 5 matrices in standard arithmetic with 76 multiplications.
This outperforms the previously best known algorithm, which involves 80 multiplications.
Extended Data Fig. 3 AlphaTensor’s network architecture.
The network takes as input the list of tensors containing the current state and previous history of actions, and a list of scalars, such as the time index of the current action. It produces two kinds of outputs: one representing the value, and the other inducing a distribution over the action space from which we can sample from. The architecture of the network is accordingly designed to have a common torso, and two heads, the value and the policy heads. c is set to 512 in all experiments.
Extended Data Fig. 4 Detailed view of AlphaTensor’s architecture, included torso, policy and value head.
We refer to Algorithms A.1A.11 in Supplementary Information for the details of each component.
Supplementary information
Rights and permissions
Open Access This article is licensed under a Creative Commons Attribution 4.0 International License, which permits use, sharing, adaptation, distribution and reproduction in any medium or format, as long as you give appropriate credit to the original author(s) and the source, provide a link to the Creative Commons license, and indicate if changes were made. The images or other third party material in this article are included in the article’s Creative Commons license, unless indicated otherwise in a credit line to the material. If material is not included in the article’s Creative Commons license and your intended use is not permitted by statutory regulation or exceeds the permitted use, you will need to obtain permission directly from the copyright holder. To view a copy of this license, visit http://creativecommons.org/licenses/by/4.0/.
About this article
Cite this article
Fawzi, A., Balog, M., Huang, A. et al. Discovering faster matrix multiplication algorithms with reinforcement learning. Nature 610, 47–53 (2022). https://doi.org/10.1038/s41586022051724
Received:
Accepted:
Published:
Issue Date:
DOI: https://doi.org/10.1038/s41586022051724
This article is cited by

Reservoir computing with memristors
Nature Electronics (2022)
Comments
By submitting a comment you agree to abide by our Terms and Community Guidelines. If you find something abusive or that does not comply with our terms or guidelines please flag it as inappropriate.