
Implementation of the Triangle Multiplicative module, used in AlphaFold2 as an efficient way to mix rows or columns of a 2D feature map, as a standalone package for Pytorch

MIT License


Triangle Multiplicative Module - Pytorch

$ pip install triangle-multiplicative-module


import torch
from triangle_multiplicative_module import TriangleMultiplicativeModule

model = TriangleMultiplicativeModule(
    dim = 64,            # feature map dimension
    hidden_dim = 128,    # intermediate dimension size
    mix = 'outgoing'     # either 'ingoing' or 'outgoing'
)

fmap = torch.randn(1, 256, 256, 64)     # (batch, n, n, dim) pairwise feature map
mask = torch.ones(1, 256, 256).bool()   # (batch, n, n), True where a pair is valid

out = model(fmap, mask = mask) # (1, 256, 256, 64)
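To make the 'outgoing' / 'ingoing' distinction concrete, here is a minimal NumPy sketch of the triangle multiplicative update as described in the AlphaFold2 supplementary information, without a batch dimension. It is not this package's implementation: the names triangle_multiply, layer_norm and the params dict keys are hypothetical, the weights are plain random matrices, layer norm has no learned scale or shift, and `hidden` plays the role of `hidden_dim` above.

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    # normalize over the last (feature) axis; no learned scale/shift (simplification)
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def triangle_multiply(z, mask, params, mix='outgoing'):
    """Hypothetical sketch. z: (n, n, dim) pair features, mask: (n, n) bool.
    params: dict of weight matrices with hypothetical key names."""
    assert mix in ('ingoing', 'outgoing')
    z = layer_norm(z)
    m = mask[..., None].astype(z.dtype)  # (n, n, 1), 0.0 for padded pairs

    # gated left/right projections to the hidden size, masked so that
    # padded pairs contribute nothing to the contraction below
    a = sigmoid(z @ params['a_gate']) * (z @ params['a_proj']) * m
    b = sigmoid(z @ params['b_gate']) * (z @ params['b_proj']) * m

    if mix == 'outgoing':
        # pair (i, j) sums over edges i -> k and j -> k
        t = np.einsum('ikd,jkd->ijd', a, b)
    else:
        # pair (i, j) sums over edges k -> i and k -> j
        t = np.einsum('kid,kjd->ijd', a, b)

    g = sigmoid(z @ params['out_gate'])          # (n, n, dim) output gate
    return g * (layer_norm(t) @ params['out_proj'])  # back to (n, n, dim)

# usage with random weights
rng = np.random.default_rng(0)
dim, hidden, n = 8, 16, 5
params = {k: rng.standard_normal((dim, hidden))
          for k in ('a_gate', 'a_proj', 'b_gate', 'b_proj')}
params['out_gate'] = rng.standard_normal((dim, dim))
params['out_proj'] = rng.standard_normal((hidden, dim))

z = rng.standard_normal((n, n, dim))
mask = np.ones((n, n), dtype=bool)
print(triangle_multiply(z, mask, params, mix='outgoing').shape)  # (5, 5, 8)
```

The two modes differ only in which index of the projected features is summed over. Because the mask zeroes the projections before the einsum, changing the features of a fully masked position leaves the outputs for the valid pairs unchanged.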


