TL;DR
We introduce Mixture of Tokens, a novel, fully-differentiable Transformer architecture that allows for increasing the number of model parameters while keeping the computation cost constant.
- It avoids problems typical for Mixture of Experts architectures
- It is compatible with causal and masked Large Language Models
- Our PoC model achieves the same performance as the baseline Transformer with \(3\times\) wall-clock speedup and \(4\times\) FLOPS reduction
- See the initial arXiv version for reference and citation info
Introduction
Mixture of Experts (MoE) architectures have recently garnered considerable attention for their ability to increase the size of Transformer models while keeping the computational cost of training and inference constant. The most successful MoE approaches achieve this by activating only subsets of a very large feed-forward layer for each processed token (the alternative subsets of parameters are often called experts).
This technique comes at a cost, though: the operation of choosing the most suitable experts for a given token is discrete, and learning discrete choices is difficult; the models are known to suffer from issues including training instability and expert under- and overload. While some of these problems can be alleviated with, e.g., auxiliary losses or a reduced initialization scale, existing MoE techniques remain more difficult and less intuitive to train than their dense counterparts.
Aiming to avoid these problems, we propose Mixture of Tokens: a new, fully-differentiable type of architecture that retains all efficiency benefits of MoE while avoiding the aforementioned problems. Instead of routing tokens to experts, it mixes tokens from different examples before feeding them to experts, in effect allowing the model to learn from all token-expert combinations simultaneously. Importantly, mixing can be disabled to prevent different sequences from being mixed during inference. Crucially, this technique is fully compatible with both masked and causal LLM training.
Motivation
Scaling Language Models
Large language models based on Transformers currently make up one of the most active fields in Machine Learning, exhibiting human-level performance in a variety of tasks. This is in large part due to their scaling properties: (Kaplan et al., 2020; Hoffmann et al., 2022) showed that an increase in model size results in a predictable increase in performance. This scaling leads to an ever-growing demand for computational resources, with their effective utilization often deemed one of the critical challenges of the field (Rae et al., 2022; Jaszczur et al., 2021; Nawrot et al., 2022).
Mixture of Experts
How can we increase the model size without additional computational cost? Mixture of Experts does this by replacing the feed-forward layer standard in Transformer architectures with a (potentially very large) set of experts, together with a small network often called a controller¹. The (trainable) controller matches tokens and experts so that each token is processed by only a small subset of experts.
Similarly to vanilla Transformers, the performance of MoE models also scales with parameter count (Clark et al., 2022). For a more detailed background and explanation of variants of the MoE architecture, see Background.
Limitations of current approaches
While the performance of the huge-parameter-count MoE architectures is impressive, they come with an entirely new set of challenges during both training and inference. The most notable include:
Training instability. Multiple studies (Fedus et al., 2022; Du et al., 2022; Mustafa et al., 2022) report difficulties in training MoE models due to instabilities. This is likely due to the nature of the technique: the operation of choosing the top-k most relevant tokens/experts is discrete, and thus small changes in controller weights can have disproportionate effects on controller decisions. We hypothesize that existing techniques for training the controller with gradient descent, while somewhat effective, do not entirely solve this problem. (Jaszczur et al., 2021) reported training stability improvements from using a weighted average of expert outputs.
Load imbalance. Typically, in MoE we set a maximum capacity for each expert. However, we cannot efficiently constrain the routing network to assign tokens in a perfectly balanced way. This leads to token dropping (when some tokens are not processed by any expert) and mode collapse (when the controller sends almost all tokens to a few experts).
Information leak. Some of the most successful MoE methods process tokens from different positions in a sequence together (i.e., by comparing the scores of all tokens in a batch). This introduces an information leak and hinders their use in autoregressive decoding.
Our technique is as stable as a vanilla Transformer because the network is fully differentiable, and no discrete choices are made during training. As every expert receives the same number of tokens, the issue of load imbalance is side-stepped as well. Finally, our technique is fully compatible with autoregressive decoding. See a detailed explanation of the technique in Method.
Background
In the context of language models, Mixture of Experts was originally proposed in (Shazeer et al., 2017). The basic idea is as follows: instead of processing all tokens with the standard feed-forward layer, each token is routed to a small subset of multiple experts. The technique was further simplified by (Fedus et al., 2022), who proposed the Switch Transformer, which sends each token only to the single expert with the highest controller score. This allowed them to train a 1.6T-parameter model with a T5 architecture at the FLOPS cost of an equivalent 1.4B-parameter vanilla Transformer. In both cases, auxiliary losses are needed to encourage exploration and mitigate load imbalance across experts.
More recently, (Zhou et al., 2022) proposed Expert Choice, where, in contrast to Switch, each expert chooses which token to process. This results in a tradeoff: on the one hand, each expert receives the same number of tokens, sidestepping the load-balancing issue; on the other hand, different tokens might be attended to by varying numbers of experts, and some tokens might not be chosen by any expert. Both approaches, as well as a standard feed-forward Transformer layer, are illustrated in the diagram below.
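To make the two routing schemes concrete, here is a minimal sketch (our illustration with hypothetical shapes, not the implementations from the cited papers):

```python
import torch

tokens = torch.randn(8, 64)           # 8 tokens, hidden size 64 (hypothetical)
controller = torch.nn.Linear(64, 4)   # produces a score per (token, expert) pair
scores = controller(tokens)           # (num_tokens, num_experts)

# Token choice (Switch-style): each token is routed to its single best expert.
token_to_expert = scores.argmax(dim=-1)             # (num_tokens,)

# Expert choice: each expert picks its top-k tokens (here k = 2); a token may be
# picked by several experts, or by none at all.
expert_to_tokens = scores.topk(k=2, dim=0).indices  # (k, num_experts)
```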
A number of works try to improve the stability and quality of the controller, including methods based on reinforcement learning (Bengio et al., 2015), routing by hashing (Roller et al., 2021), optimal transport (Clark et al., 2022), and more (Dai et al., 2022; Chi et al., 2022). (Lewis et al., 2021) address the load-balancing problem with linear programming, while (Riquelme et al., 2021) try to achieve this by learning to drop unimportant tokens.
Concurrently to our work, (Puigcerver et al., 2023) proposed a continuous variant of Mixture of Experts for the Vision Transformer, limited to encoder-only models where patches are mixed only within each image. Another approach that avoids discrete operations in MoE by merging experts was presented in (Muqeeth et al., 2023).
Method
Let’s say we have a Transformer model, and we would like to increase the parameter count in the feed-forward layers without increasing model runtime. One of the ways to achieve this is to activate only a subset of parameters for a given token - this gives rise to Mixture of Experts architectures. Another way would be to somehow merge the tokens and process them together. This idea lies at the heart of Mixture of Tokens.
Because the mixing operation is continuous, the architecture does not experience the problems present in MoE. This diagram shows an intuitive comparison of Mixture of Tokens with Mixture of Experts.
In the final design, similar to Multi-Head Attention, we sacrifice a single, big representation for multiple, independent, smaller ones: we divide the large feed-forward layer into experts and send a separate mixture to each expert. The resulting network is illustrated here.
How to Mix Tokens
In order to mix tokens for a given group, we need importance weights for each token. To get those, we send each token through the controller (a standard linear layer) and calculate a softmax over the resulting token scores. Note that the weights are calculated independently for each expert - as visible in the diagram.
Now that we have importance weights, we simply multiply each token by its importance weight and add all of them together.
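Concretely (in our notation), writing the group's tokens as \(x_1, \dots, x_g\) and the controller's score of token \(x_i\) for expert \(e\) as \(c_e(x_i)\), the mixing step is:

\[
w_{e,i} = \frac{\exp\big(c_e(x_i)\big)}{\sum_{j=1}^{g} \exp\big(c_e(x_j)\big)},
\qquad
\tilde{x}_e = \sum_{i=1}^{g} w_{e,i}\, x_i .
\]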
How to Redistribute the Mixed Tokens
Once every mixed token is processed by its respective expert, we redistribute the processed tokens according to the importance weights we calculated before. See the equations below for details.
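With \(f_e\) denoting expert \(e\) and \(E\) the number of experts, the output for the \(i\)-th token of the group combines the expert outputs using the same weights:

\[
y_i = \sum_{e=1}^{E} w_{e,i}\, f_e\big(\tilde{x}_e\big).
\]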
How to Group Tokens in MoT
The only question left is how to group tokens - we will show the grouping scheme for autoregressive decoding. While it would be natural to mix tokens within sequences, it would be very inefficient: in vanilla Transformers, nothing is recomputed for tokens already present in the decoded sequence. This allows for very efficient (FLOP-wise) decoding, and we want to keep it that way. Thus, in order to run computations for any given token only once, we group tokens across sequences, i.e., according to position in a sequence. The diagram illustrates the grouping scheme.
While the maximum size of the group is limited by the batch size (the number of sequences), note that the two are not coupled: we can always make groups smaller than the batch size.
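As a small illustration of the grouping (hypothetical shapes; the split at the end is just one possible way to form groups smaller than the batch size):

```python
import torch

# Hypothetical activations: a batch of 4 sequences, 16 tokens each, hidden size 64.
x = torch.randn(4, 16, 64)

# Group across sequences by position: group p contains the p-th token of every
# sequence, so each of the 16 groups holds 4 tokens (one per sequence).
groups = x.transpose(0, 1)               # (seq_len, group_size, d_model) = (16, 4, 64)

# Groups smaller than the batch size: split each position-group further.
smaller_groups = groups.split(2, dim=1)  # two groups of 2 tokens per position
```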
Algorithm Summary
The algorithm for computing the output of a Mixture of Tokens layer is as follows:
1. Group tokens by position in the sequence.
2. Independently for each group:
   a. For each expert independently, calculate importance weights for the tokens in the group.
   b. For each expert independently, calculate the mixed token as the weighted sum of the group's tokens.
   c. Process each mixed token with its respective expert.
   d. Redistribute the expert outputs to the original tokens using the weights from step 2a.
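Below is a minimal PyTorch sketch of these steps (our simplified reconstruction of the description above, not the exact training code; the expert shape, the group size being equal to the batch size, and all names are assumptions of this sketch):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MixtureOfTokens(nn.Module):
    """Simplified sketch of a Mixture of Tokens layer (illustrative only)."""

    def __init__(self, d_model: int, d_ff: int, n_experts: int):
        super().__init__()
        self.controller = nn.Linear(d_model, n_experts)  # one score per (token, expert)
        # Each expert is a small feed-forward block; here d_ff is split evenly
        # across experts (an assumption of this sketch).
        self.experts = nn.ModuleList(
            nn.Sequential(
                nn.Linear(d_model, d_ff // n_experts),
                nn.ReLU(),
                nn.Linear(d_ff // n_experts, d_model),
            )
            for _ in range(n_experts)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, seq_len, d_model). Step 1: group tokens by position,
        # so each group holds the tokens at one position across the batch.
        groups = x.transpose(0, 1)                      # (seq_len, batch, d_model)

        # Step 2a: importance weights, computed independently for each expert
        # (softmax over the tokens within each group).
        scores = self.controller(groups)                # (seq_len, batch, n_experts)
        weights = F.softmax(scores, dim=1)

        out = torch.zeros_like(groups)
        for e, expert in enumerate(self.experts):
            w = weights[..., e].unsqueeze(-1)           # (seq_len, batch, 1)
            # Step 2b: one mixed token per group = weighted sum of the group's tokens.
            mixed = (w * groups).sum(dim=1)             # (seq_len, d_model)
            # Step 2c: process the mixed token with this expert.
            processed = expert(mixed)                   # (seq_len, d_model)
            # Step 2d: redistribute the expert output using the same weights.
            out = out + w * processed.unsqueeze(1)      # (seq_len, batch, d_model)
        return out.transpose(0, 1)                      # (batch, seq_len, d_model)
```

For example, `MixtureOfTokens(d_model=64, d_ff=256, n_experts=4)` applied to a tensor of shape `(2, 16, 64)` returns a tensor of the same shape.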
Experiments
Experimental setup
For the baseline, we train a standard GPT-like model on the language modeling task using cross-entropy loss on the C4 dataset (Raffel et al., 2019). Our model replaces all feed-forward layers with Mixture of Tokens layers.
In our proof-of-concept experiments, we train a decoder-only Transformer model with the following hyperparameters:
For the model implementing Mixture of Tokens we choose the following hyperparameters:
When training both the baseline and the Mixture of Tokens models, we use the following setup:
The learning rate was tuned separately for both our model and the baseline:
Results
Our technique shows very promising results, reducing the number of training steps required to match the baseline by a factor of 4, which translates into a roughly \(3\times\) wall-clock speedup.
Next steps
Scaling Up
Our preliminary experiments suggest that Mixture of Tokens might work even better for larger model sizes. In the upcoming weeks, we aim to prepare a comprehensive comparison of larger models and compare Mixture of Tokens with Mixture of Experts methods.
From Mixture of Tokens to Mixture of Experts
How do we get from MoT to MoE? Assume that the controller in a Mixture of Tokens layer decided to mix in a very particular way: for a given group, it concentrated the entire weight on just one token. In this extreme case, each expert would receive a single, unmixed token. This would make the Mixture of Tokens forward pass equivalent to the Expert Choice described in Background.
This scenario has its advantages: in the default Mixture of Tokens setup for autoregressive training, tokens are aggregated across the batch dimension. During decoding, however, this setup allows information to be exchanged between different examples. This could be undesirable in some use cases, e.g., when different examples in the same batch come from different users in an industrial setting, raising potential privacy concerns.
How could we make the controller focus on a single example? One way is to add a temperature parameter to the softmax operation used by the controller. A low temperature forces the weight distribution to concentrate; in the limit, as the temperature approaches 0, all of the weight falls on the token with the highest controller score.
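A minimal sketch of this idea (our illustration; the exact parameterization used in our experiments may differ):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TemperatureController(nn.Module):
    """Controller with a learnable softmax temperature (illustrative sketch)."""

    def __init__(self, d_model: int, n_experts: int):
        super().__init__()
        self.score = nn.Linear(d_model, n_experts)
        self.log_temperature = nn.Parameter(torch.zeros(1))  # temperature = 1 at init

    def forward(self, group: torch.Tensor) -> torch.Tensor:
        # group: (group_size, d_model); returns per-expert mixing weights.
        temperature = self.log_temperature.exp()  # keeps the temperature positive
        scores = self.score(group) / temperature
        # As the temperature approaches 0, the weights approach a one-hot
        # distribution concentrated on the token with the highest score.
        return F.softmax(scores, dim=0)           # (group_size, n_experts)
```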
Interestingly, simply allowing the temperature parameter to be learnable for the controller in a Mixture of Tokens layer encourages this phenomenon.
As expected, this results in the controller focusing more on one token. We measured this by monitoring the entropy of weights produced by the controller (averaged over all token groups and all experts).
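Such a measurement can be sketched as follows (hypothetical helper; `weights` are the softmax-normalized mixing weights with shape `(num_groups, group_size, n_experts)`):

```python
import torch


def mean_weight_entropy(weights: torch.Tensor, eps: float = 1e-9) -> torch.Tensor:
    # weights: (num_groups, group_size, n_experts), normalized over group_size.
    entropy = -(weights * (weights + eps).log()).sum(dim=1)  # (num_groups, n_experts)
    return entropy.mean()  # average over all token groups and all experts
```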
However, this comes at the cost of model performance.
We expect that allowing the temperature to be learned only at the end of training is a very promising direction for “private” autoregressive decoding: that way, we would retain all the benefits of training with a high rate of token mixing while preventing token mixing during inference.
Conclusions
We have presented preliminary results showing the promise of Mixture of Tokens: it improves training stability compared with MoE approaches and decreases training time \(3\times\) compared to the vanilla Transformer. We expect even greater improvements in larger models; more thorough experiments are underway, and we plan to release a paper with more results in the coming weeks. In the meantime, you can contact us with any feedback at llm.random.team@gmail.com.
Acknowledgements
We would like to express sincere gratitude to Piotr Padlewski and Tomasz Trzciński for valuable feedback and Dagmara Rudzińska for amazing support with graphic design.
Citation Information
Please cite the arXiv version of this work.
References
Footnotes
The term router is also commonly used in the MoE literature.↩︎