Tranception - Pytorch (wip)

Implementation of Tranception, an attention network paired with retrieval that is SOTA for protein fitness prediction. The Transformer architecture is inspired by Primer and uses ALiBi relative positional encoding.
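For reference, ALiBi replaces learned positional embeddings with a per-head linear penalty on query-key distance, added to the attention logits. Below is a minimal, hypothetical sketch of the bias computation (not this package's internals; it assumes a power-of-two head count for the standard slope schedule):

import torch

def alibi_bias(heads, seq_len):
    # geometric slope schedule from the ALiBi paper: 2^(-8i/heads) for i = 1..heads
    slopes = torch.tensor([2 ** (-8 * (i + 1) / heads) for i in range(heads)])
    pos = torch.arange(seq_len)
    rel = pos[None, :] - pos[:, None]          # (seq, seq), key position minus query position
    return slopes[:, None, None] * rel[None]   # (heads, seq, seq), grows more negative leftward

bias = alibi_bias(8, 512)  # add to attention scores before the causal mask and softmax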

Install

$ pip install tranception-pytorch

Usage

import torch
from tranception_pytorch import Tranception

model = Tranception(
    dim = 512,        # model dimension
    depth = 6,        # number of transformer blocks
    heads = 8,        # attention heads per block
    dim_head = 64     # dimension per attention head
)

# a batch of one sequence of 512 amino acid token ids (vocabulary of 21)
amino_acids = torch.randint(0, 21, (1, 512))

logits = model(amino_acids) # (1, 512, 21)
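In the paper, fitness is predicted from the autoregressive log-likelihood the model assigns to a sequence. Here is a minimal sketch of how one might turn the logits above into per-sequence scores; the score_sequence helper is an assumption for illustration, not part of this package's API:

import torch.nn.functional as F

def score_sequence(model, amino_acids):
    # hypothetical helper: logits at position t are treated as predicting token t + 1
    logits = model(amino_acids)                       # (batch, seq, 21)
    log_probs = F.log_softmax(logits[:, :-1], dim = -1)
    targets = amino_acids[:, 1:]                      # labels shifted left by one
    token_ll = log_probs.gather(-1, targets[..., None]).squeeze(-1)
    return token_ll.sum(dim = -1)                     # per-sequence log-likelihood

fitness = score_sequence(model, amino_acids)          # (1,)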

Todo

  • grouped heads with customizable depthwise convolutions (for variable k-mers), as well as grouped ALiBi positional bias (see the sketch after this list)
  • figure out attention to retrieved sequences (looks like axial attention?)
  • play around with ProteinGym, and start betting on Hugging Face's accelerate
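The first item refers to a Primer-style depthwise convolution applied per group of attention heads, with a different kernel size per group so each group mixes over a different k-mer width. A hypothetical sketch of the idea (names, kernel sizes, and layout are assumptions, not this package's code):

import torch
from torch import nn

class GroupedDepthwiseConv(nn.Module):
    # hypothetical module: splits heads into groups, each with its own k-mer width
    def __init__(self, dim_head, heads, kernel_sizes = (3, 5, 7, 9)):
        super().__init__()
        assert heads % len(kernel_sizes) == 0
        group_heads = heads // len(kernel_sizes)
        self.group_dim = dim_head * group_heads
        self.convs = nn.ModuleList([
            nn.Conv1d(self.group_dim, self.group_dim, k, padding = k - 1, groups = self.group_dim)
            for k in kernel_sizes
        ])

    def forward(self, x):
        # x: (batch, heads * dim_head, seq) - e.g. queries or keys, heads folded into channels
        outs = []
        for conv, chunk in zip(self.convs, x.split(self.group_dim, dim = 1)):
            outs.append(conv(chunk)[..., :x.shape[-1]])  # trim the right pad to stay causal
        return torch.cat(outs, dim = 1)

Each group's convolution is depthwise (groups equals channels), mirroring how Primer applies a narrow depthwise conv over queries, keys, and values.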

Citations

@article{Notin2022TranceptionPF,
  title   = {Tranception: protein fitness prediction with autoregressive transformers and inference-time retrieval},
  author  = {Pascal Notin and Mafalda Dias and Jonathan Frazer and Javier Marchena-Hurtado and Aidan N. Gomez and Debora S. Marks and Yarin Gal},
  journal = {ArXiv},
  year    = {2022},
  volume  = {abs/2205.13760}
}