Standalone Product Key Memory module in PyTorch, for augmenting Transformer models
MIT License
```bash
$ pip install product-key-memory
```
Replace the feedforward layers in a Transformer with the following:

```python
import torch
from product_key_memory import PKM

pkm = PKM(
    dim = 512,
    heads = 4,
    dim_head = 128,   # keep at 128 for best results
    num_keys = 256,   # number of subkeys; the number of values will be num_keys ** 2
    topk = 32         # number of top-scoring subkeys to select
)

x = torch.randn(1, 1024, 512)
mask = torch.ones((1, 1024)).bool()
values = pkm(x, input_mask = mask) # (1, 1024, 512)
```
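For intuition on why there are `num_keys ** 2` values: following Lample et al. (2019), each query is split into two halves, each half is scored against its own set of `num_keys` sub-keys, the best `topk` sub-keys are kept per half, and the `topk * topk` combinations are re-ranked by their summed score. Only `2 * num_keys` dot products per query are needed to address `num_keys ** 2` values. The sketch below is a minimal pure-Python illustration for a single query and a single head; `product_key_topk` and `dot` are hypothetical names, and this is not the library's implementation, which is batched, multi-head and written in PyTorch.

```python
import heapq
import random


def dot(a, b):
    """Plain dot product of two equal-length lists."""
    return sum(x * y for x, y in zip(a, b))


def product_key_topk(query, subkeys1, subkeys2, topk):
    """Hypothetical sketch of a product-key lookup for one query.

    query    : list of 2 * d floats
    subkeys1 : num_keys sub-keys of length d, scored against query[:d]
    subkeys2 : num_keys sub-keys of length d, scored against query[d:]
    Returns the topk (score, value_index) pairs, best first, where
    value_index = i * num_keys + j indexes a table of num_keys ** 2 values.
    """
    d = len(query) // 2
    q1, q2 = query[:d], query[d:]
    num_keys = len(subkeys2)

    # score each query half against its own sub-key set (2 * num_keys dot products)
    scores1 = [(dot(q1, k), i) for i, k in enumerate(subkeys1)]
    scores2 = [(dot(q2, k), j) for j, k in enumerate(subkeys2)]

    # keep only the best topk sub-keys per half
    best1 = heapq.nlargest(topk, scores1)
    best2 = heapq.nlargest(topk, scores2)

    # combine the topk * topk candidates with additive scores and re-rank
    candidates = [(s1 + s2, i * num_keys + j) for s1, i in best1 for s2, j in best2]
    return heapq.nlargest(topk, candidates)


random.seed(0)
d, num_keys, topk = 4, 16, 5
query = [random.gauss(0, 1) for _ in range(2 * d)]
subkeys1 = [[random.gauss(0, 1) for _ in range(d)] for _ in range(num_keys)]
subkeys2 = [[random.gauss(0, 1) for _ in range(d)] for _ in range(num_keys)]

for score, index in product_key_topk(query, subkeys1, subkeys2, topk):
    print(f"value {index:3d}  score {score:+.3f}")
```

Because the scores are additive, any pair in the overall top `topk` must have both of its halves in their per-half top `topk`, so (ignoring ties) this returns the same values as an exhaustive search over all `num_keys ** 2` pairs.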
To use a different learning rate for the value parameters of the product key memory, use the following helper function:

```python
from torch.optim import Adam
from product_key_memory import fetch_pkm_value_parameters

# `model` is your root model; this helper finds all PKM modules inside it and
# separates their embedding bag (value) weights from the rest of the parameters
pkm_parameters, other_parameters = fetch_pkm_value_parameters(model)

optim = Adam([
    {'params': other_parameters},
    {'params': pkm_parameters, 'lr': 1e-2}
], lr=1e-3)
```
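The helper's job is to split the model's parameters into two disjoint groups so the optimizer can give the memory values their own learning rate. Below is a minimal sketch of that idea, assuming each memory layer keeps its values in an `nn.EmbeddingBag` attribute named `values`; `ToyMemory` and `split_memory_value_params` are hypothetical stand-ins, not the library's code.

```python
from torch import nn
from torch.optim import Adam


class ToyMemory(nn.Module):
    """Hypothetical stand-in for a PKM layer; only `values` matters here."""

    def __init__(self, num_values, dim):
        super().__init__()
        self.query_proj = nn.Linear(dim, dim)
        # value table, assumed here to be an embedding bag named `values`
        self.values = nn.EmbeddingBag(num_values, dim, mode='sum')


def split_memory_value_params(model, memory_cls=ToyMemory):
    """Return (memory value weights, every other parameter) for `model`."""
    value_params = [m.values.weight for m in model.modules() if isinstance(m, memory_cls)]
    value_ids = {id(p) for p in value_params}
    other_params = [p for p in model.parameters() if id(p) not in value_ids]
    return value_params, other_params


model = nn.Sequential(nn.Linear(16, 16), ToyMemory(num_values=64, dim=16), nn.Linear(16, 16))
pkm_params, other_params = split_memory_value_params(model)

optim = Adam([
    {'params': other_params},               # uses the default lr=1e-3 below
    {'params': pkm_params, 'lr': 1e-2},     # larger lr for the memory values
], lr=1e-3)
```

Parameters are matched by identity, so the two groups never overlap; PyTorch optimizers raise an error if the same parameter appears in more than one group.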
Or, if the product key memory values are the only parameters that need a different learning rate:

```python
from torch.optim import Adam
from product_key_memory import fetch_optimizer_parameters

# builds the list of parameter groups automatically,
# with the learning rate for the PKM values set to 1e-2
parameters = fetch_optimizer_parameters(model)
optim = Adam(parameters, lr=1e-3)
```
Special thanks go to Aran for encouraging me to look into this, and to Madison May for his educational blog post, which helped me understand this better.
Todo:
- Offer stochasticity with annealed Gumbel noise (I have seen dramatic effects from this in the vector-quantization setting)
- Offer smaller value dimensions, with the per-head outputs concatenated and linearly combined (as in multi-head attention)
- Catch up on the latest literature on product key memories, if any
- Instead of additive scores, try multiplicative scores using coordinate descent routing
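Annealed Gumbel noise, as commonly used in vector quantization, means perturbing the selection scores with Gumbel noise whose scale decays over training, so early top-k selection explores other keys and later selection becomes deterministic. A minimal pure-Python sketch, assuming a linear annealing schedule; `gumbel_noise`, `noise_scale` and `noisy_topk` are hypothetical names, and this does not describe how the library implements (or would implement) the idea.

```python
import math
import random


def gumbel_noise(rng):
    """Sample standard Gumbel noise as -log(-log(u)) with u in (0, 1)."""
    u = rng.random()
    while u == 0.0:  # random() may return 0.0, and log(0) is undefined
        u = rng.random()
    return -math.log(-math.log(u))


def noise_scale(step, total_steps, start=1.0, end=0.0):
    """Linearly anneal the noise scale from start to end over total_steps."""
    frac = min(step / total_steps, 1.0)
    return start + (end - start) * frac


def noisy_topk(scores, k, scale, rng):
    """Indices of the k largest scores after adding scale * Gumbel noise."""
    perturbed = [s + scale * gumbel_noise(rng) for s in scores]
    return sorted(range(len(scores)), key=lambda i: perturbed[i], reverse=True)[:k]


rng = random.Random(0)
scores = [0.9, 0.5, 0.4, 0.1, -0.2, -1.0]
for step in (0, 500, 1000):
    scale = noise_scale(step, total_steps=1000)
    print(step, scale, noisy_topk(scores, k=2, scale=scale, rng=rng))
```

Once the scale reaches 0, the selection reduces to the ordinary top-k over the raw scores.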
```bibtex
@misc{lample2019large,
    title   = {Large Memory Layers with Product Keys},
    author  = {Guillaume Lample and Alexandre Sablayrolles and Marc'Aurelio Ranzato and Ludovic Denoyer and Hervé Jégou},
    year    = {2019},
    eprint  = {1907.05242},
    archivePrefix = {arXiv}
}
```

```bibtex
@misc{liu2020evolving,
    title   = {Evolving Normalization-Activation Layers},
    author  = {Hanxiao Liu and Andrew Brock and Karen Simonyan and Quoc V. Le},
    year    = {2020},
    eprint  = {2004.02967},
    archivePrefix = {arXiv}
}
```

```bibtex
@article{Shen2023ASO,
    title   = {A Study on ReLU and Softmax in Transformer},
    author  = {Kai Shen and Junliang Guo and Xuejiao Tan and Siliang Tang and Rui Wang and Jiang Bian},
    journal = {ArXiv},
    year    = {2023},
    volume  = {abs/2302.06461},
    url     = {https://api.semanticscholar.org/CorpusID:256827573}
}
```

```bibtex
@article{Csordas2023ApproximatingTF,
    title   = {Approximating Two-Layer Feedforward Networks for Efficient Transformers},
    author  = {R{\'o}bert Csord{\'a}s and Kazuki Irie and J{\"u}rgen Schmidhuber},
    journal = {ArXiv},
    year    = {2023},
    volume  = {abs/2310.10837},
    url     = {https://api.semanticscholar.org/CorpusID:264172384}
}
```