

Implementation of Soft MoE (Mixture of Experts), proposed by Brain's Vision team, in PyTorch.

This MoE has only been made to work with non-autoregressive encoders. However, some recent text-to-image models have started using MoE with great results, so it may be a fit there.

If anyone has ideas for how to make it work autoregressively, let me know (through email or discussions). I meditated on it but can't think of a good way. The other issue with the slot scheme is that the routing cost grows quadratically as sequence length increases (much like attention).
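To make the quadratic growth concrete: assuming the number of slots is tied to sequence length (as with the default of seq_len // num_experts slots per expert), every token receives one routing logit per slot, so the logit matrix has roughly seq_len² entries. The helper `routing_logit_count` below is hypothetical and only does this arithmetic. With a fixed number of slots the count would instead grow linearly.

```python
def routing_logit_count(seq_len: int, num_experts: int) -> int:
    """Hypothetical helper: number of token-to-slot routing logits,
    assuming slots_per_expert = seq_len // num_experts."""
    slots_per_expert = seq_len // num_experts
    total_slots = num_experts * slots_per_expert  # ~= seq_len
    # every token is scored against every slot of every expert
    return seq_len * total_slots


for n in (256, 512, 1024, 2048):
    print(n, routing_logit_count(n, num_experts=4))
# 256 65536
# 512 262144
# 1024 1048576
# 2048 4194304
```

Each doubling of the sequence length quadruples the number of routing logits.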


$ pip install soft-moe-pytorch


import torch
from soft_moe_pytorch import SoftMoE

moe = SoftMoE(
    dim = 512,         # model dimensions
    seq_len = 1024,    # max sequence length (will automatically calculate number of slots as seq_len // num_experts) - you can also set num_slots directly
    num_experts = 4    # number of experts (they suggest the number of experts should be high enough that each one gets only 1 slot. I wonder if that is the weakness of the paper?)
)

x = torch.randn(1, 1024, 512)

out = moe(x) + x # (1, 1024, 512) - use in place of the feedforward at a given layer of a transformer (the residual connection is shown here too)
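For intuition about what the layer computes, here is a minimal NumPy sketch of Soft MoE routing as described in the paper: token-to-slot logits are softmax-normalized over tokens to form dispatch weights (each slot is a weighted average of tokens) and over slots to form combine weights (each output token is a weighted average of expert outputs). This is not this library's implementation. `soft_moe`, `phi`, and the toy linear experts are hypothetical, the normalization of inputs and slot parameters proposed in the paper is omitted, and NumPy is used instead of PyTorch only so the sketch runs standalone.

```python
import numpy as np


def softmax(z: np.ndarray, axis: int) -> np.ndarray:
    z = z - z.max(axis=axis, keepdims=True)  # subtract max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def soft_moe(x, phi, expert_weights):
    """Hypothetical sketch. x: (m, d) tokens; phi: (d, n * p) slot parameters;
    expert_weights: list of n (d, d) matrices, one toy linear expert each."""
    num_experts, dim = len(expert_weights), x.shape[-1]
    logits = x @ phi                      # (m, n * p) token-slot affinities
    dispatch = softmax(logits, axis=0)    # normalize over tokens, per slot
    combine = softmax(logits, axis=1)     # normalize over slots, per token
    slots = dispatch.T @ x                # (n * p, d) slot inputs
    slots = slots.reshape(num_experts, -1, dim)  # (n, p, d): p slots per expert
    outputs = np.stack([s @ w for s, w in zip(slots, expert_weights)])  # (n, p, d)
    return combine @ outputs.reshape(-1, dim)    # (m, d)


rng = np.random.default_rng(0)
m, d, n = 16, 8, 4
p = m // n  # slots per expert, mirroring seq_len // num_experts
x = rng.standard_normal((m, d))
phi = rng.standard_normal((d, n * p))
experts = [rng.standard_normal((d, d)) for _ in range(n)]
out = soft_moe(x, phi, experts) + x  # residual connection, as in the usage example
print(out.shape)  # (16, 8)
```

Because both softmaxes are taken over the same logit matrix, every token contributes to every slot and every slot contributes to every token, so there is no discrete top-k routing and no token dropping.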

For an improvised variant that uses dynamic slots, so that the number of slots ~= sequence length, just import DynamicSlotsSoftMoE instead

import torch
from soft_moe_pytorch import DynamicSlotsSoftMoE

# sequence length or number of slots need not be specified

moe = DynamicSlotsSoftMoE(
    dim = 512,         # model dimensions
    num_experts = 4,   # number of experts
    geglu = True       # use GEGLU-gated feedforwards for the experts
)

x = torch.randn(1, 1023, 512)

out = moe(x) + x # (1, 1023, 512)
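One plausible way to make the slot count track the sequence length is sketched below in NumPy. It assumes the sequence is padded to a multiple of `num_experts`, each consecutive group of `num_experts` tokens is averaged (ignoring padding), and each group mean is projected to `num_experts` slot embeddings. `dynamic_slot_embeds` and `proj` are hypothetical names, and this is an illustration of the idea, not the DynamicSlotsSoftMoE source.

```python
import numpy as np


def dynamic_slot_embeds(x, proj, num_experts):
    """Hypothetical sketch. x: (m, d) tokens; proj: (d, num_experts * d).
    Returns (num_slots, d) slot embeddings with num_slots ~= m."""
    m, d = x.shape
    pad = (-m) % num_experts                              # tokens needed to reach a multiple
    x = np.concatenate([x, np.zeros((pad, d))])           # zero-pad the sequence
    mask = np.concatenate([np.ones(m), np.zeros(pad)])    # 1 marks a real token
    groups = x.reshape(-1, num_experts, d)                # (g, e, d) consecutive groups
    counts = mask.reshape(-1, num_experts).sum(axis=1, keepdims=True)
    group_means = groups.sum(axis=1) / counts             # mean over real tokens only
    slots = group_means @ proj                            # (g, e * d)
    return slots.reshape(-1, d)                           # (g * e, d)


rng = np.random.default_rng(0)
seq_len, d, e = 1023, 8, 4  # 1023 tokens, as in the example above
x = rng.standard_normal((seq_len, d))
proj = rng.standard_normal((d, e * d))
slots = dynamic_slot_embeds(x, proj, e)
print(slots.shape)  # (1024, 8)
```

With 1023 tokens and 4 experts, padding to 1024 yields 256 groups and therefore 1024 slots, one per padded token. In a full layer, slot embeddings like these would stand in for fixed learned slot parameters when computing the token-to-slot logits.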

Todo
  • address the limitation of the number of slots being fixed; think about a way to make the number of slots dynamic, based on sequence length
  • once variable sequence length is handled in the distributed setting, add it to the dynamic soft MoE
  • the dispatch and combine tensors could also be split and moved into the Experts class to better distribute work


