Intro

What made GPT-3.5 and GPT-4 pull so far ahead of the competition? Since "Open"AI's closed source models are hard to dissect, we really have no way of knowing for sure. However, a technique known as Mixture of Experts (MoE) seems to be exactly the kind of thing that put them ahead of the game. Recently, Mistral AI released their own model, Mixtral 8x7B, which uses MoE to beat Llama 2 70B as well as GPT-3.5 on many benchmarks. It suggests that MoE might be what was missing from LLMs.

What is it?

What exactly is Mixture of Experts? Simply put, it's just choosing the right parameters for the job. Imagine you were given either a Math, English, or History problem, and to solve it you could ask a Math, English, or History teacher for help. If you were given a History problem, you would obviously bring it to your History teacher. Believe it or not, this is almost exactly what Mixture of Experts does. You just acted as what is called the "router" in MoE: you routed the input (the problem) to the correct network (the teacher).

Now, how can we translate this to language modeling? The simplest approach would be to just create entirely different models for each specific task, i.e. one for reading comprehension, one for math, and so on. Research has already shown that smaller language models focused on a single task can outperform more general LLMs, so it makes sense that multiple smaller models combined could potentially outperform a single larger LLM. This would mean that instead of one 80 billion parameter LLM, you could theoretically train eight 10 billion parameter LLMs on eight different subjects and switch between them using some kind of gating network, as sketched below.
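To make that naive version concrete, here is a toy sketch of the idea. Everything in it is hypothetical: the "models" are just placeholder functions, and the routing rule is hand-written keyword matching rather than a learned gating network. The learned version is what the rest of this article builds up to.

    # Hypothetical, hand-written routing between stand-in "task-specific models".
    def route_to_specialist(problem, specialists):
        for subject, model in specialists.items():
            if subject in problem.lower():      # a crude, manual "gating network"
                return model(problem)
        return specialists["general"](problem)

    specialists = {
        "math":    lambda p: f"[math model answers: {p}]",
        "history": lambda p: f"[history model answers: {p}]",
        "general": lambda p: f"[general model answers: {p}]",
    }
    print(route_to_specialist("a history question about Rome", specialists))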

Of course, it is not actually that simple. Rather than manually creating several task-specific models, MoE uses a kind of sparse layer that replaces the normal feedforward layers and learns its own experts. Each of these MoE layers has its own gating network that decides which experts to use during inference.
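As a rough sketch of where such a layer sits, here is what a transformer block might look like with its feedforward sub-layer swapped out for MoE. This assumes an MoE module like the one implemented later in this article (constructed with input_size, output_size, hidden_size, and num_experts, and operating on 2-D [tokens, features] batches); the wiring is illustrative rather than a production implementation.

    import torch.nn as nn

    class TransformerBlockWithMoE(nn.Module):
        def __init__(self, d_model, n_heads, num_experts):
            super().__init__()
            self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
            self.norm1 = nn.LayerNorm(d_model)
            self.norm2 = nn.LayerNorm(d_model)
            # the usual dense feedforward sub-layer is replaced by a sparse MoE layer
            self.moe = MoE(d_model, d_model, hidden_size=4 * d_model, num_experts=num_experts)

        def forward(self, x):                      # x: [batch, seq_len, d_model]
            h = self.norm1(x)
            x = x + self.attn(h, h, h)[0]          # the attention sub-layer is untouched
            b, t, d = x.shape
            h = self.norm2(x).reshape(b * t, d)    # flatten so the MoE layer routes per token
            return x + self.moe(h).reshape(b, t, d)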

In the MoE architecture, the experts are not trained individually; they are trained together as an ensemble. The gating network chooses an expert for each input, and backpropagation then tweaks both the gating network and the chosen expert based on how well the combined output performed.

The other interesting thing about these expert layers is that they tend to learn things quite different from what we might expect. Rather than each expert specializing in a subject, experiments have shown that they pick up on grammar and syntax more than anything else. For example, one expert may specialize in nouns, while another may specialize in verbs.

The Math

The gating network's weight matrix, W_gate, is applied to our input x. This gives a set of logits that represent how well-suited each expert is for the given input. We then apply top-k gating: if k=2, for example, we take the two highest logits. A softmax is applied to just those top-k logits, and the resulting gate values are multiplied by the outputs of their respective experts. Of course, if an expert doesn't make the cut, we don't need to compute its output at all, which saves inference time as well as computational resources.

    logits = x * W_gate
    gate = softmax(topk(logits))

    y = Σ_i gate_i * expert_i(x)

    (If gate_i is zero, then we don't need to compute expert_i(x).)
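For a concrete feel of the numbers, here is a tiny runnable illustration of top-k gating with k=2 and four experts; the logits are made up purely for demonstration.

    import torch

    logits = torch.tensor([1.2, -0.5, 3.0, 0.1])   # x * W_gate for one token (made-up values)
    topk_vals, topk_idx = logits.topk(2)            # keep the two highest-scoring experts
    gate = torch.softmax(topk_vals, dim=0)          # normalize only over the selected experts

    print(topk_idx)  # tensor([2, 0]) -> experts 2 and 0 were selected
    print(gate)      # their mixing weights sum to 1; every other expert gets a weight of 0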

One of the biggest issues that you may encounter with MoE is load balancing -- how do you ensure that each expert is actually used? The gating network might choose only one expert the entire time, leaving the rest untouched. This creates a feedback loop: only that one expert gets trained, so the gating network learns to prefer it over all of the untrained ones; and because the untrained experts are never chosen, they never get trained and never become useful.

One method of solving this is noisy top-k gating. Noisy top-k gating is pretty easy to understand -- it just adds some tunable noise on top of the logits so that the router doesn't always pick the same expert during training.

    logits = x * W_gate + StandardNormal() * Softplus(x * W_noise)

Combined, these tricks allow Mixture of Experts to be incredibly efficient yet powerful. However, it isn't completely perfect. Although the compute per token is greatly reduced, since most of the dense layers aren't even being run, all of the parameters still need to be loaded into memory. For example, Mixtral 8x7B has roughly 47 billion parameters in total, but only about 13 billion of them are active for any given token (two experts out of eight). The issue is that you still need to keep all of those parameters in RAM even though each token only uses a fraction of them. Still, it's a greatly improved architecture.
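To make the trade-off concrete, here is a back-of-the-envelope sketch. The sizes below are hypothetical stand-ins chosen only to land in a Mixtral-like ballpark, not the model's actual configuration:

    # compute scales with top_k, memory scales with num_experts
    num_experts, top_k = 8, 2
    shared_params = 2.0e9       # hypothetical attention/embedding parameters used by every token
    params_per_expert = 5.6e9   # hypothetical size of one expert's feedforward weights

    stored = shared_params + num_experts * params_per_expert   # must all be loaded into memory
    active = shared_params + top_k * params_per_expert         # actually used for each token

    print(f"stored: {stored / 1e9:.1f}B, active per token: {active / 1e9:.1f}B")
    # stored: 46.8B, active per token: 13.2B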

The other great thing about Mixture of Experts is its affinity for hardware parallelization. The premise is that each expert gets its own worker, and these workers can live on different GPUs. Each token is then sent to the worker that holds its selected expert. Of course, this also introduces the overhead that comes with splitting data across multiple workers.
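Here is a rough sketch of that premise (hypothetical setup: one tiny nn.Linear standing in for each expert, and no real all-to-all communication): each expert lives on its own device, and a token is shipped to whichever device hosts the expert chosen for it.

    import torch
    import torch.nn as nn

    n_experts = 4
    if torch.cuda.is_available():
        devices = [f"cuda:{i % torch.cuda.device_count()}" for i in range(n_experts)]
    else:
        devices = ["cpu"] * n_experts          # fall back to CPU-only for illustration
    experts = [nn.Linear(16, 16).to(dev) for dev in devices]

    def run_on_expert(token, expert_idx):
        # move the token to the expert's device, run the expert, bring the result back
        out = experts[expert_idx](token.to(devices[expert_idx]))
        return out.to(token.device)

    y = run_on_expert(torch.randn(1, 16), expert_idx=2)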

Code

Implementing a simple version of Mixture of Experts (here in PyTorch) is actually quite straightforward.

1. Make an Expert layer. These will be the "experts" the gating network chooses from. Most of the time, each expert is just a small feedforward network made of dense layers.

    import torch
    import torch.nn as nn

    class ExpertLayer(nn.Module):
        """A single expert: a small two-layer feedforward network."""
        def __init__(self, input_size, output_size, hidden_size=4):
            super(ExpertLayer, self).__init__()
            self.fc1 = nn.Linear(input_size, hidden_size)
            self.fc2 = nn.Linear(hidden_size, output_size)
            self.relu = nn.ReLU()

        def forward(self, x):
            x = self.fc1(x)
            x = self.relu(x)
            x = self.fc2(x)
            x = self.relu(x)
            return x
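A quick sanity check with arbitrary toy sizes (nothing about these numbers is special):

    # a single expert is just a tiny MLP mapping 16-dim inputs to 8-dim outputs
    expert = ExpertLayer(input_size=16, output_size=8, hidden_size=32)
    print(expert(torch.randn(4, 16)).shape)  # torch.Size([4, 8])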

2. Define the MoE layer. We will need to keep track of our experts as well as the parameters that control gating and noise.

    class MoE(nn.Module):
        def __init__(self, input_size, output_size, hidden_size, num_experts, topk=2):
            super(MoE, self).__init__()
            self.input_size = input_size
            self.output_size = output_size
            self.hidden_size = hidden_size
            self.num_experts = num_experts
            self.topk = topk

            # the pool of experts the gating network will choose from
            self.experts = nn.ModuleList(
                [ExpertLayer(input_size, output_size, hidden_size) for _ in range(num_experts)]
            )
            # learnable gating and noise weights, one column per expert
            self.w_gate = nn.Parameter(torch.zeros(input_size, num_experts), requires_grad=True)
            self.w_noise = nn.Parameter(torch.zeros(input_size, num_experts), requires_grad=True)
            self.softmax = nn.Softmax(dim=1)
            self.softplus = nn.Softplus()
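At this point we can already instantiate the layer and see where the parameters live. With toy sizes (chosen arbitrarily for illustration), almost all of the parameters sit inside the experts, while the gating and noise weights are comparatively tiny:

    moe = MoE(input_size=16, output_size=8, hidden_size=32, num_experts=4, topk=2)
    expert_params = sum(p.numel() for p in moe.experts.parameters())
    gate_params = moe.w_gate.numel() + moe.w_noise.numel()
    print(expert_params, gate_params)  # 3232 128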

3. Top-k gating. This uses our gating weights to produce logits, adds the learned noise, and then applies a softmax to the top-k logits (k=2 by default).

    def topKGating(self, x):
        # multiply x by the gating weights to get one logit per expert
        logits = x @ self.w_gate

        # add tunable noise during training so the router explores every expert
        if self.training:
            raw_noise_stddev = x @ self.w_noise
            noise_stddev = self.softplus(raw_noise_stddev) + 1e-2
            logits = logits + torch.randn_like(logits) * noise_stddev

        # find the top-k logits and normalize them with a softmax
        topk_gate, topk_index = logits.topk(self.topk, dim=1)
        topk_gate = self.softmax(topk_gate)

        # re-insert the top-k gate values into a dense [batch, num_experts] matrix,
        # leaving every unselected expert with a gate of zero
        zeros = torch.zeros_like(logits, requires_grad=True)
        gates = zeros.scatter(1, topk_index, topk_gate)

        return gates
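Once topKGating is part of the class, a quick check (again with toy sizes) confirms that the gating really is sparse: every row has exactly topk nonzero entries, and those entries sum to 1.

    moe = MoE(input_size=16, output_size=8, hidden_size=32, num_experts=4, topk=2)
    gates = moe.topKGating(torch.randn(5, 16))
    print(gates.shape)         # torch.Size([5, 4]) -- one gate value per (sample, expert) pair
    print((gates > 0).sum(1))  # tensor([2, 2, 2, 2, 2]) -- exactly topk experts per sample
    print(gates.sum(1))        # each row sums to (roughly) 1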

4. Finally, the forward function. You might notice some extra machinery here: it splits the batch that is passed into forward into per-expert mini-batches, because each input in the batch may be routed to different experts. This dispatching adds overhead, and each expert ends up processing only a slice of the batch, which makes batching less efficient than in a dense model.

    def forward(self, x):
        gates = self.topKGating(x)

        # group the inputs by expert (dispatch logic adapted from
        # https://github.com/davidmrau/mixture-of-experts)
        sorted_experts, index_sorted_experts = torch.nonzero(gates).sort(0)
        # drop the batch indices, keeping only each entry's expert index
        _, expert_index = sorted_experts.split(1, dim=1)
        # get the corresponding batch index for each expert assignment
        batch_index = torch.nonzero(gates)[index_sorted_experts[:, 1], 0]
        # number of samples routed to each expert
        part_sizes = (gates > 0).sum(0).tolist()
        # gather the nonzero gate value for each (sample, expert) pair
        gates_exp = gates[batch_index.flatten()]
        nonzero_gates = torch.gather(gates_exp, 1, expert_index)

        # build each expert's mini-batch and run the experts on their own slices
        exp_inputs = x[batch_index].squeeze(1)
        exp_inputs = torch.split(exp_inputs, part_sizes, dim=0)
        expert_outputs = [self.experts[i](exp_inputs[i]) for i in range(self.num_experts)]

        # weight each expert output by its gate value
        stitched = torch.cat(expert_outputs, 0)
        stitched = stitched.mul(nonzero_gates)

        # scatter-add the weighted outputs back to their original positions in the batch
        zeros = torch.zeros(gates.size(0), expert_outputs[-1].size(1),
                            requires_grad=True, device=stitched.device)
        combined = zeros.index_add(0, batch_index, stitched.float())

        return combined
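With all of the pieces in place, the layer behaves like any other nn.Module. Here it is run on random data with toy sizes; gradients flow through both the selected experts and the gating weights.

    moe = MoE(input_size=16, output_size=8, hidden_size=32, num_experts=4, topk=2)
    x = torch.randn(10, 16)
    y = moe(x)
    print(y.shape)  # torch.Size([10, 8])

    loss = y.pow(2).mean()   # dummy loss just to demonstrate training
    loss.backward()
    print(moe.w_gate.grad.shape)  # torch.Size([16, 4])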

Conclusion

The frequency and effectiveness of new contributions to AI make it an incredibly fast-growing and improving field. Each of these contributions builds atop the others and makes some incredible things possible. However, I think the real enemy of progress will inevitably be the closed-source nature of the larger companies; after all, their billions of dollars in AI research and training have to pay for themselves somehow. In the end, it's up to the generosity of smaller organizations and companies to keep open source on par with closed source. Luckily, with newer and more potent architectures like MoE, that doesn't seem so impossible. Whether closed source or open source comes out on top remains to be seen; for now, we get to sit back and watch the battle.