g-mlp-gpt

GPT, but made only out of MLPs

MIT License

GPT - gMLP

This repository attempts to crack long-context autoregressive language modeling (GPT) using variations of gMLPs. Specifically, it contains a variant that applies gMLP over local sliding windows. The hope is to stretch a single GPU far enough to train context lengths of 4096 and above, efficiently and well.
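To make the idea of a local sliding window concrete, here is a minimal pure-Python sketch of which positions a token may read from under a causal window. The helper names `visible_positions` and `count_interactions` are hypothetical and are not part of this package; the repository's actual windowing scheme may differ (for example, it may group tokens into fixed blocks), so treat this as an illustration of the general concept only.

```python
def visible_positions(i, window):
    """Indices that position i may read from under a causal window.

    Hypothetical helper for illustration: position i sees itself and
    at most (window - 1) positions before it, never any later position.
    """
    start = max(0, i - window + 1)
    return list(range(start, i + 1))


def count_interactions(seq_len, window):
    """Total number of (query, key) pairs across the whole sequence."""
    return sum(min(i + 1, window) for i in range(seq_len))


if __name__ == "__main__":
    print(visible_positions(5, 3))          # positions visible to token 5
    print(count_interactions(4096, 4096))   # full causal context
    print(count_interactions(4096, 128))    # window of 128
```

Under these assumptions, a window of 128 over 4096 tokens involves roughly 16 times fewer token pairs than full causal context, which is the kind of saving that makes longer sequences plausible on a single GPU.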

You can also add the "tiny" attention (as described in the paper) with the attn_dim keyword argument, which sets the dimension of the single attention head (64 is recommended). Pass a tuple to use a different dimension for each layer.
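For readers unfamiliar with single-head attention, the following is a rough pure-Python sketch of causal scaled dot-product attention, where the length of each query/key vector plays the role of `attn_dim`. The function name `tiny_causal_attention` is hypothetical; this is not the package's implementation, and it omits the learned query/key/value projections as well as the way the paper combines the attention output with the gating unit.

```python
import math


def tiny_causal_attention(q, k, v):
    """Single-head scaled dot-product attention with a causal mask.

    q, k: lists of n vectors, each of length attn_dim.
    v:    list of n value vectors.
    Illustrative only; projections and gating are left out.
    """
    d = len(q[0])
    out = []
    for i in range(len(q)):
        # Score position i against positions 0..i only (causal mask).
        scores = [
            sum(a * b for a, b in zip(q[i], k[j])) / math.sqrt(d)
            for j in range(i + 1)
        ]
        # Numerically stable softmax over the visible positions.
        m = max(scores)
        weights = [math.exp(s - m) for s in scores]
        total = sum(weights)
        weights = [w / total for w in weights]
        # Weighted average of the visible value vectors.
        out.append([
            sum(w * v[j][c] for j, w in enumerate(weights))
            for c in range(len(v[0]))
        ])
    return out
```

With all-zero queries and keys the weights are uniform, so each output is simply the running mean of the value vectors seen so far, which is a convenient sanity check.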

Install

$ pip install g-mlp-gpt

Usage

import torch
from g_mlp_gpt import gMLPGPT

model = gMLPGPT(
    num_tokens = 20000,
    dim = 512,
    depth = 4,
    seq_len = 1024,
    window = (128, 256, 512, 1024) # window sizes for each depth
)

x = torch.randint(0, 20000, (1, 1000))
logits = model(x) # (1, 1000, 20000)

16k context length

import torch
from torch import nn    # needed for nn.Tanh below
from g_mlp_gpt import gMLPGPT

model = gMLPGPT(
    num_tokens = 20000,
    dim = 512,
    seq_len = 16384,
    reversible = True,    # reversible networks
    act = nn.Tanh(),      # tanh activation for spatial gating
    depth = 12,
    window = (
        128,
        128,
        256,
        512,
        1024,
        1024,
        (2048, 2),    # window size of 2048, axial of 2
        (2048, 2),
        (4096, 4),
        (4096, 4),
        (8192, 8),    # window size of 8192, axial of 8
        (8192, 8)
    )
).cuda()

x = torch.randint(0, 20000, (1, 16384)).cuda()
logits = model(x) # (1, 16384, 20000)

Citations

@misc{liu2021pay,
    title   = {Pay Attention to MLPs}, 
    author  = {Hanxiao Liu and Zihang Dai and David R. So and Quoc V. Le},
    year    = {2021},
    eprint  = {2105.08050},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}