Implementation of an Attention layer where each head can attend to more than just one token, using coordinate descent to pick topk. Perhaps the number of tokens to attend to can even be learned.
If these experiments fail, I will use the repo for a few other ideas, among them getting coordinate descent routing working for autoregressive transformers.
Ongoing experiments
Update: I don't think the improvements are worth it. Memory usage also becomes impractical as the number of iterations goes up. I'll keep playing around with topk attention though, because it bothers me that softmax becomes a bottleneck for tokens far in the future, especially as sequence lengths go above 8k.
Update: With a kernel written in Triton, it is a bit more viable, but memory usage is still too high if the number of iterations is high.
Update: By doing recomputes in segments of iterations, it is now feasible, if it were to actually yield any improvements.
$ pip install coordinate-descent-attention
import torch
from coordinate_descent_attention import Transformer
model = Transformer(
    num_tokens = 256,
    dim = 512,
    depth = 2,
    seq_len = 2048,
    dim_head = 64,
    heads = 8,
    attn_use_coor_descent = True   # set to True to switch from softmax to coordinate descent on qk similarity matrix
).cuda()
x = torch.randint(0, 256, (1, 2048)).cuda()
logits = model(x)
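For intuition about what "coordinate descent to pick topk" can look like, here is a minimal, dependency-free sketch of a soft top-k over a single row of similarity scores. It is not this repo's implementation (which operates on batched tensors and has a Triton kernel); the function name `coor_descent_topk`, its parameters, and the defaults are assumptions made for illustration. The idea is to alternate two updates: a shared shift `a` that makes the weights sum to `k`, and a per-token offset `b` that clips any weight above 1.

```python
import math

def logsumexp(xs):
    # numerically stable log(sum(exp(x))) for a list of floats
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))

def coor_descent_topk(scores, k, n_iters=50, eps=0.1):
    """Soft top-k over one row of scores (illustrative sketch only).

    Returns weights in [0, 1] whose sum approaches k as n_iters grows.
    """
    a = 0.0
    b = [-s for s in scores]            # start with s + b == 0 for every token
    constant = eps * math.log(k)

    for _ in range(n_iters):
        # a-step: shared shift so that the weights would sum to k
        a = constant - eps * logsumexp([(s + bi) / eps for s, bi in zip(scores, b)])
        # b-step: per-token offset that caps every weight at 1
        b = [-max(s + a, 0.0) for s in scores]

    return [math.exp((s + a + bi) / eps) for s, bi in zip(scores, b)]

row = [2.0, 1.0, 0.5, -1.0]             # hypothetical qk similarities for one query
weights = coor_descent_topk(row, k=2, n_iters=100)
```

In this sketch the two highest-scoring tokens end up with weights at 1 and the rest are pushed toward 0, so each head can attend to roughly `k` tokens instead of concentrating on one. A smaller `eps` gives a sharper selection but needs more iterations, which is consistent with the memory concern noted in the updates above.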
@article{Wright2015CoordinateDA,
title = {Coordinate descent algorithms},
author = {Stephen J. Wright},
journal = {Mathematical Programming},
year = {2015},
volume = {151},
pages = {3-34}
}
@inproceedings{Gupta2021MemoryefficientTV,
title = {Memory-efficient Transformers via Top-k Attention},
author = {Ankit Gupta and Guy Dar and Shaya Goodman and David Ciprut and Jonathan Berant},
booktitle = {SUSTAINLP},
year = {2021}
}
@article{Zhao2019ExplicitST,
title = {Explicit Sparse Transformer: Concentrated Attention Through Explicit Selection},
author = {Guangxiang Zhao and Junyang Lin and Zhiyuan Zhang and Xuancheng Ren and Qi Su and Xu Sun},
journal = {ArXiv},
year = {2019},
volume = {abs/1912.11637}
}
@article{Schmitzer2016StabilizedSS,
title = {Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems},
author = {Bernhard Schmitzer},
journal = {ArXiv},
year = {2016},
volume = {abs/1610.06519}
}