# An experiment in using Tangent to autodiff Triton
This library is a proof of concept of automatic differentiation for Triton GPU code, using the Tangent source-to-source autodiff compiler. Remarkably, the whole project is only about 50 lines of code.
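For background, Tangent works by writing out the source of a brand-new Python function that computes the gradient. A minimal sketch of its standard API on plain Python (this example is mine, not from this repo; `tangent.grad` and its `verbose` flag come from Tangent itself):

```python
import tangent

def f(x):
    return x * x + 3.0 * x

# Tangent generates and compiles a new Python function for df/dx.
df = tangent.grad(f, verbose=1)  # verbose=1 prints the generated source
print(df(2.0))                   # 7.0, since d/dx (x^2 + 3x) = 2x + 3
```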
Here's how it works:
```python
# 1) Define a Triton mathematical function
def fn1(x, y):
    a = tl.exp(x)
    b = tl.log(tl.expand_dims(y, 1))
    c = a + b
    return a * b + tl.dot(c, c)

fn1_tt = triton.jit(fn1)
```
```python
# 2) Give the signature of its backward pass (the generated function will be printed)
def fn1_back(x, y, dz):
    pass

# 3) Call Tangent
fn1back_tt = grad(fn1, fn1_back, wrt=(0, 1))
```
The library also prints out the generated code, so you can inspect and debug it:
```python
def dfn1dxy(x, y, b_return=1.0):
    a = tl.exp(x)
    _b = tl.expand_dims(y, 1)
    b = tl.log(_b)
    c = a + b
    tl_dot_c_c = tl.dot(c, c)
    a_times_b = a * b
    bx = zeroslike(x)
    by = zeroslike(y)
    b_b = zeroslike(_b)
    bc = zeroslike(c)
    bb = zeroslike(b)
    ba = zeroslike(a)
    btl_dot_c_c = zeroslike(tl_dot_c_c)
    ba_times_b = zeroslike(a_times_b)

    # Grad of: c = a + b
    _ba_times_b = triton_unbroadcast(b_return, a_times_b.shape)
    _btl_dot_c_c = triton_unbroadcast(b_return, tl_dot_c_c.shape)
    ba_times_b = add_grad(ba_times_b, _ba_times_b)
    btl_dot_c_c = add_grad(btl_dot_c_c, _btl_dot_c_c)
    _ba2 = triton_unbroadcast(ba_times_b * b, a.shape)
    _bb2 = triton_unbroadcast(ba_times_b * a, b.shape)
    ba = add_grad(ba, _ba2)
    bb = add_grad(bb, _bb2)
    _bc = tl.trans(tl.dot(c, tl.trans(btl_dot_c_c)))
    _bc2 = tl.dot(tl.trans(c), btl_dot_c_c)
    bc = add_grad(bc, _bc)
    bc = add_grad(bc, _bc2)
    _ba = triton_unbroadcast(bc, a.shape)
    _bb = triton_unbroadcast(bc, b.shape)
    ba = add_grad(ba, _ba)
    bb = add_grad(bb, _bb)

    # Grad of: b = tl.log(tl.expand_dims(y, 1))
    _b_b = bb / _b
    b_b = add_grad(b_b, _b_b)
    __b = _b
    tl.static_assert(__b.shape[1] == 1)
    _by = tl.view(b_b, y.shape)
    by = add_grad(by, _by)

    # Grad of: a = tl.exp(x)
    _a = a
    _bx = _a * ba
    bx = add_grad(bx, _bx)
    return bx, by
```
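To make the generated code easier to follow: variables prefixed with `b` are gradients being accumulated in reverse order through the forward trace. Here is a hypothetical NumPy transcription of the same reverse-mode pass (the names are mine, purely for illustration; the sums are what `triton_unbroadcast` is undoing):

```python
import numpy as np

def fn1_np(x, y):
    a = np.exp(x)           # shape (16,)
    b = np.log(y[:, None])  # shape (16, 1)
    c = a + b               # broadcasts to (16, 16)
    return a * b + c @ c

def dfn1dxy_np(x, y, dz):
    a = np.exp(x)
    b = np.log(y[:, None])
    c = a + b
    # Grad of c @ c: c appears as both operands, so both terms accumulate.
    dc = dz @ c.T + c.T @ dz
    # Grad of a * b, plus the c = a + b contribution; sums undo broadcasting.
    da = (dz * b).sum(axis=0) + dc.sum(axis=0)
    db = (dz * a).sum(axis=1, keepdims=True) + dc.sum(axis=1, keepdims=True)
    dx = da * a                          # chain rule through exp
    dy = (db / y[:, None]).reshape(-1)   # through log and expand_dims
    return dx, dy
```

Comparing this against finite differences of `fn1_np` is an easy way to convince yourself the generated Triton code computes the right thing.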
You can also call the generated function directly from a full Triton program:
```python
# Boilerplate load and forward
@triton.jit
def tr_forward(X, Y, Z):
    r = tl.arange(0, 16)
    r2 = tl.arange(0, 16)
    x = tl.load(X + r)
    y = tl.load(Y + r2)
    z = fn1_tt(x, y)
    tl.store(Z + 16 * r2[:, None] + r, z)

# Boilerplate load and backward
@triton.jit
def tr_backward(X, Y, dX, dY, dZ):
    r = tl.arange(0, 16)
    r2 = tl.arange(0, 16)
    x = tl.load(X + r)
    y = tl.load(Y + r2)
    dz = tl.load(dZ + 16 * r2[:, None] + r)
    dx, dy = fn1back_tt(x, y, dz)
    tl.store(dX + r, dx)
    tl.store(dY + r2, dy)
```
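For completeness, a hypothetical host-side launcher (not from the repo; it assumes a CUDA device, and a single-program grid since the sizes are hard-coded to 16):

```python
import torch

x = torch.rand(16, device="cuda") + 1.0
y = torch.rand(16, device="cuda") + 1.0  # + 1.0 keeps log's input positive
z = torch.empty(16, 16, device="cuda")
tr_forward[(1,)](x, y, z)

dz = torch.ones(16, 16, device="cuda")
dx = torch.empty(16, device="cuda")
dy = torch.empty(16, device="cuda")
tr_backward[(1,)](x, y, dx, dy, dz)
```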
This should give the same answer as PyTorch:
```python
# Torch version for sanity check.
def torch_check(x, y):
    a = x.exp()
    b = y[:, None].log()
    c = a + b
    return a * b + c @ c

def test_run():
    check(tr_forward, tr_backward, torch_check,
          x_shape=(16,), y_shape=(16,), z_shape=(16, 16))
    print("check succeeded!")

test_run()
```
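The `check` helper ships with this repo; a plausible sketch of the comparison it performs (hypothetical code, assuming a CUDA device and the launch conventions above):

```python
import torch

def check(fwd, bwd, ref, x_shape, y_shape, z_shape):
    x = (torch.rand(*x_shape, device="cuda") + 1.0).requires_grad_()
    y = (torch.rand(*y_shape, device="cuda") + 1.0).requires_grad_()

    # Triton forward and backward, with dz = 1 everywhere.
    z = torch.empty(*z_shape, device="cuda")
    fwd[(1,)](x, y, z)
    dx = torch.empty(*x_shape, device="cuda")
    dy = torch.empty(*y_shape, device="cuda")
    bwd[(1,)](x, y, dx, dy, torch.ones(*z_shape, device="cuda"))

    # PyTorch reference via autograd.
    z_ref = ref(x, y)
    z_ref.backward(torch.ones_like(z_ref))

    torch.testing.assert_close(z, z_ref)
    torch.testing.assert_close(dx, x.grad)
    torch.testing.assert_close(dy, y.grad)
```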