Visualize, create, and operate on pytrees in the most intuitive way possible.
APACHE-2.0 License
Change threads_count in apply parallel kwargs to max_workers.
Full Changelog: https://github.com/ASEM000/pytreeclass/compare/v0.9.1...v0.9.2
Parallel apply (parallel=True) in AtIndexer. This enables a myriad of tasks, like reading a pytree of image file names.
# benchmarking parallel vs sequential image read
# on mac m1 cpu with image of size 512x512x3
# run in IPython/Jupyter (%timeit is an IPython magic)
import pytreeclass as tc
from matplotlib.pyplot import imread
paths = ["lenna.png"] * 10
indexer = tc.AtIndexer(paths)
%timeit indexer[...].apply(imread, parallel=True)  # parallel
# 24.9 ms ± 938 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
%timeit indexer[...].apply(imread)  # not parallel
# 84.8 ms ± 453 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
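The speedup above comes from running the per-leaf function concurrently. A minimal sketch of the idea in plain Python, assuming a flat list of leaves and a thread pool from concurrent.futures. parallel_apply is a hypothetical helper, not pytreeclass's implementation, and its max_workers argument only mirrors the renamed keyword by analogy:

```python
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Callable, List, Optional, Sequence


def parallel_apply(
    func: Callable[[Any], Any],
    leaves: Sequence[Any],
    max_workers: Optional[int] = None,
) -> List[Any]:
    """Apply `func` to every leaf on a thread pool, keeping the input order."""
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        # Executor.map yields results in input order, not completion order.
        return list(executor.map(func, leaves))


print(parallel_apply(len, ["a.png", "bb.png"], max_workers=2))  # [5, 6]
```

Threads tend to help when the mapped function spends its time outside the interpreter lock (e.g. file I/O or decoders implemented in C). For pure-Python CPU work, a process pool would be the usual alternative.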
Published by ASEM000 about 1 year ago
tree_repr_with_trace
tree_map_with_trace
tree_flatten_with_trace
tree_leaves_with_trace
Full Changelog: https://github.com/ASEM000/pytreeclass/compare/v0.8.0...v0.9
Published by ASEM000 about 1 year ago
Add on_getattr in field to apply a function on __getattr__.
Rename callbacks in field to on_setattr, to match attrs and better reflect its functionality.
These changes enable:
Stricter data validation on instance values, as in the following example: on_setattr ensures the value is of a certain type (e.g. integer) during initialization, and on_getattr ensures the value is of a certain type (e.g. integer) whenever it is accessed.
import pytreeclass as pytc
import jax

def assert_int(x):
    assert isinstance(x, int), "must be an int"
    return x

@pytc.autoinit
class Tree(pytc.TreeClass):
    a: int = pytc.field(on_getattr=[assert_int], on_setattr=[assert_int])

    def __call__(self, x):
        # ensure `a` is an int before using it in computation by calling `assert_int`
        a: int = self.a
        return a + x

tree = Tree(a=1)
print(tree(1.0))  # 2.0
tree = jax.tree_map(lambda x: x + 0.0, tree)  # make `a` a float
tree(1.0)  # AssertionError: must be an int
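The on_getattr/on_setattr hooks behave like callbacks on attribute access. The following plain-Python sketch reproduces that behavior with a data descriptor. CallbackField is a hypothetical name. Writing straight to __dict__ is only an assumed stand-in for the way a tree rebuilt by jax.tree_map receives its new leaf without passing through on_setattr:

```python
from typing import Any, Callable, Sequence


class CallbackField:
    """Hypothetical descriptor: runs callbacks when the attribute is set or read."""

    def __init__(self, on_getattr: Sequence[Callable] = (), on_setattr: Sequence[Callable] = ()):
        self.on_getattr = tuple(on_getattr)
        self.on_setattr = tuple(on_setattr)

    def __set_name__(self, owner: type, name: str) -> None:
        self.name = name

    def __get__(self, instance: Any, owner: Any = None) -> Any:
        if instance is None:
            return self
        value = instance.__dict__[self.name]
        for func in self.on_getattr:  # validate/transform on every read
            value = func(value)
        return value

    def __set__(self, instance: Any, value: Any) -> None:
        for func in self.on_setattr:  # validate/transform on every write
            value = func(value)
        instance.__dict__[self.name] = value


def assert_int(x):
    assert isinstance(x, int), "must be an int"
    return x


class Tree:
    a = CallbackField(on_getattr=[assert_int], on_setattr=[assert_int])

    def __init__(self, a):
        self.a = a  # goes through CallbackField.__set__


tree = Tree(a=1)
print(tree.a)  # 1
# Simulate a rebuild that bypasses __setattr__ (assumed, not pytreeclass internals).
tree.__dict__["a"] = 1.0
try:
    tree.a
except AssertionError as err:
    print("read rejected:", err)  # read rejected: must be an int
```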
Frozen field without using tree_mask/tree_unmask
The following shows a pattern where the value is frozen on __setattr__ and unfrozen whenever it is accessed. This ensures that jax transformations do not see the value. The following example showcases this functionality:
import pytreeclass as pytc
import jax

@pytc.autoinit
class Tree(pytc.TreeClass):
    frozen_a: int = pytc.field(on_getattr=[pytc.unfreeze], on_setattr=[pytc.freeze])

    def __call__(self, x):
        return self.frozen_a + x

tree = Tree(frozen_a=1)  # 1 is non-jaxtype
# can be used in jax transformations
@jax.jit
def f(tree, x):
    return tree(x)

f(tree, 1.0)  # 2.0
grads = jax.grad(f)(tree, 1.0)  # Tree(frozen_a=#1)
Compared with other libraries that implement static_field, this pattern has lower overhead and does not alter the tree_flatten/tree_unflatten methods of the tree.
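The following plain-Python sketch illustrates the freeze-on-write, unfreeze-on-read pattern. It shows that the stored value is wrapped while attribute access returns the raw value. Frozen, freeze, unfreeze, and FrozenField are hypothetical stand-ins, not pytc.freeze/pytc.unfreeze. In pytreeclass, the flattening of the wrapped value is what hides it from jax, and this sketch does not model that part:

```python
from typing import Any


class Frozen:
    """Hypothetical wrapper marking a value as hidden from transformations."""

    __slots__ = ("value",)

    def __init__(self, value: Any):
        self.value = value

    def __repr__(self) -> str:
        return f"#{self.value!r}"


def freeze(x: Any) -> Any:
    return x if isinstance(x, Frozen) else Frozen(x)  # never double-wrap


def unfreeze(x: Any) -> Any:
    return x.value if isinstance(x, Frozen) else x


class FrozenField:
    """Hypothetical descriptor: store frozen, hand back unfrozen."""

    def __set_name__(self, owner: type, name: str) -> None:
        self.name = name

    def __get__(self, instance: Any, owner: Any = None) -> Any:
        if instance is None:
            return self
        return unfreeze(instance.__dict__[self.name])

    def __set__(self, instance: Any, value: Any) -> None:
        instance.__dict__[self.name] = freeze(value)


class Tree:
    frozen_a = FrozenField()

    def __init__(self, frozen_a):
        self.frozen_a = frozen_a

    def __call__(self, x):
        return self.frozen_a + x


tree = Tree(frozen_a=1)
print(vars(tree))  # {'frozen_a': #1}  (what a flattener would see)
print(tree(1.0))   # 2.0               (what the method sees)
```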
Easier way to create a buffer (non-trainable array): just use jax.lax.stop_gradient in on_getattr.
import pytreeclass as pytc
import jax
import jax.numpy as jnp

def assert_array(x):
    assert isinstance(x, jax.Array)
    return x

@pytc.autoinit
class Tree(pytc.TreeClass):
    buffer: jax.Array = pytc.field(on_getattr=[jax.lax.stop_gradient], on_setattr=[assert_array])

    def __call__(self, x):
        return self.buffer**x

tree = Tree(buffer=jnp.array([1.0, 2.0, 3.0]))
tree(2.0)  # Array([1., 4., 9.], dtype=float32)

@jax.jit
def f(tree, x):
    return jnp.sum(tree(x))

f(tree, 1.0)  # Array(6., dtype=float32)
print(jax.grad(f)(tree, 1.0))  # Tree(buffer=[0. 0. 0.])
Published by ASEM000 about 1 year ago
Remove .at as an alias for __getitem__ when specifying a path entry for where in AtIndexer. This leads to a less verbose style.
Example:
>>> tree = {"level1_0": {"level2_0": 100, "level2_1": 200}, "level1_1": 300}
>>> tree = pytc.AtIndexer(tree)
>>> # Before:
>>> # style 1 (with at):
>>> tree.at["level1_0"].at["level2_0", "level2_1"].get()
{'level1_0': {'level2_0': 100, 'level2_1': 200}, 'level1_1': None}
>>> # style 2 (no at):
>>> tree["level1_0"]["level2_0", "level2_1"].get()
>>> # After
>>> # only style 2 is valid
>>> tree["level1_0"]["level2_0", "level2_1"].get()
For TreeClass, at is specified once for each change:
import jax
import jax.numpy as jnp
import pytreeclass as pytc

@pytc.autoinit
class Tree(pytc.TreeClass):
    a: float = 1.0
    b: tuple[float, float] = (2.0, 3.0)
    c: jax.Array = jnp.array([4.0, 5.0, 6.0])

    def __call__(self, x):
        return self.a + self.b[0] + self.c + x

tree = Tree()
mask = jax.tree_map(lambda x: x > 5, tree)
tree = tree\
    .at["a"].set(100.0)\
-   .at["b"].at[0].set(10.0)\
+   .at["b"][0].set(10.0)\
    .at[mask].set(100.0)
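Each .at[...] step in the chain narrows the path, and .set returns a new tree rather than mutating the old one. The following is a minimal copy-on-write sketch of path-based updates on plain dicts, lists, and tuples. set_at is a hypothetical helper that does not handle masks or TreeClass instances:

```python
from typing import Any, Sequence


def set_at(tree: Any, path: Sequence[Any], value: Any) -> Any:
    """Return a copy of `tree` with the node at `path` replaced by `value`."""
    if not path:
        return value
    key, rest = path[0], path[1:]
    if isinstance(tree, dict):
        new = dict(tree)  # shallow copy; untouched branches are shared
        new[key] = set_at(tree[key], rest, value)
        return new
    if isinstance(tree, (list, tuple)):
        items = list(tree)
        items[key] = set_at(tree[key], rest, value)
        return type(tree)(items)  # plain list/tuple only (not namedtuples)
    raise TypeError(f"cannot index into {type(tree).__name__}")


tree = {"a": 1.0, "b": (2.0, 3.0)}
new_tree = set_at(set_at(tree, ["a"], 100.0), ["b", 0], 10.0)
print(new_tree)  # {'a': 100.0, 'b': (10.0, 3.0)}
print(tree)      # {'a': 1.0, 'b': (2.0, 3.0)}
```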
Published by ASEM000 about 1 year ago
tree_{repr,str} with an object containing cyclic references will raise RecursionError instead of displaying cyclicref.
Published by ASEM000 about 1 year ago
Allow nested mutations using .at[method](*args, **kwargs).
After the change, inner methods can mutate copied new instances at any level, not just the top level.
A motivation for this is to experiment with a lazy initialization scheme, where inner layers need to mutate their inner state. See the example below for flax-like lazy initialization as described here.
import pytreeclass as pytc
import jax.random as jr
import jax
import jax.numpy as jnp
from typing import Callable, TypeVar

T = TypeVar("T")

@pytc.autoinit
class LazyLinear(pytc.TreeClass):
    outdim: int
    weight_init: Callable[..., T] = jax.nn.initializers.glorot_normal()
    bias_init: Callable[..., T] = jax.nn.initializers.zeros

    def param(self, name: str, init_func: Callable[..., T], *args) -> T:
        if name not in vars(self):
            setattr(self, name, init_func(*args))
        return vars(self)[name]

    def __call__(self, x: jax.Array, *, key: jr.KeyArray = jr.PRNGKey(0)):
        w = self.param("weight", self.weight_init, key, (x.shape[-1], self.outdim))
        y = x @ w
        if self.bias_init is not None:
            b = self.param("bias", self.bias_init, key, (self.outdim,))
            return y + b
        return y

@pytc.autoinit
class StackedLinear(pytc.TreeClass):
    l1: LazyLinear = LazyLinear(outdim=10)
    l2: LazyLinear = LazyLinear(outdim=1)

    def call(self, x: jax.Array):
        return self.l2(jax.nn.relu(self.l1(x)))

lazy_layer = StackedLinear()
print(repr(lazy_layer))
# StackedLinear(
#   l1=LazyLinear(
#     outdim=10,
#     weight_init=init(key, shape, dtype),
#     bias_init=zeros(key, shape, dtype)
#   ),
#   l2=LazyLinear(
#     outdim=1,
#     weight_init=init(key, shape, dtype),
#     bias_init=zeros(key, shape, dtype)
#   )
# )
_, materialized_layer = lazy_layer.at["call"](jnp.ones((1, 5)))
materialized_layer
# StackedLinear(
#   l1=LazyLinear(
#     outdim=10,
#     weight_init=init(key, shape, dtype),
#     bias_init=zeros(key, shape, dtype),
#     weight=f32[5,10](μ=-0.04, σ=0.32, ∈[-0.74,0.63]),
#     bias=f32[10](μ=0.00, σ=0.00, ∈[0.00,0.00])
#   ),
#   l2=LazyLinear(
#     outdim=1,
#     weight_init=init(key, shape, dtype),
#     bias_init=zeros(key, shape, dtype),
#     weight=f32[10,1](μ=-0.07, σ=0.23, ∈[-0.34,0.34]),
#     bias=f32[1](μ=0.00, σ=0.00, ∈[0.00,0.00])
#   )
# )
# StackedLinear defines `call`, not `__call__`, so call the method explicitly
materialized_layer.call(jnp.ones((1, 5)))
# Array([[0.16712935]], dtype=float32)
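The key idea is that the method runs on a copy. Mutations made inside it, at any nesting level, land in the returned instance and leave the original untouched. The following plain-Python sketch works under that assumption. LazyScale, Stack, and call_on_copy are hypothetical, and copy.deepcopy stands in for however pytreeclass actually copies the tree:

```python
import copy
from typing import Any, Callable


class LazyScale:
    """Hypothetical layer whose `scale` parameter is created on first call."""

    def __init__(self, init_value: float = 2.0):
        self.init_value = init_value

    def param(self, name: str, init_func: Callable[[], Any]) -> Any:
        # Create the parameter only if it does not exist yet, then return it.
        if name not in vars(self):
            setattr(self, name, init_func())
        return vars(self)[name]

    def __call__(self, x: float) -> float:
        return self.param("scale", lambda: self.init_value) * x


class Stack:
    """Hypothetical container holding a lazy layer one level down."""

    def __init__(self):
        self.inner = LazyScale(3.0)

    def call(self, x: float) -> float:
        return self.inner(x)


def call_on_copy(obj: Any, method: str, *args: Any, **kwargs: Any):
    """Call `method` on a deep copy of `obj`; return (output, mutated copy)."""
    new_obj = copy.deepcopy(obj)
    output = getattr(new_obj, method)(*args, **kwargs)
    return output, new_obj


lazy = Stack()
out, materialized = call_on_copy(lazy, "call", 2.0)
print(out)                                # 6.0
print("scale" in vars(lazy.inner))        # False: the original is untouched
print(vars(materialized.inner)["scale"])  # 3.0
```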
Full Changelog: https://github.com/ASEM000/PyTreeClass/compare/v0.5...v0.6.0
Published by ASEM000 about 1 year ago
Fix __init_subclass__ not accepting arguments. Bug introduced in v0.5.
Full Changelog: https://github.com/ASEM000/PyTreeClass/compare/v0.5...v0.0.5post0
Published by ASEM000 about 1 year ago
PyTreeClass v0.5
Generation of the __init__ method from type hints is decoupled from TreeClass.
Alternatives
Use one of:
1. pytreeclass.autoinit with pytreeclass.field as the field specifier. pytreeclass.field has more features (e.g. callbacks, multiple argument kind selection), and its init generation is cached, compared to dataclasses.
2. dataclasses.dataclass with dataclasses.field as the field specifier. However, use:
- frozen=False, because __setattr__ and __delattr__ are handled by TreeClass
- repr=False, so that repr is handled by TreeClass
- eq=False, as equality and hashing are handled by TreeClass
import jax.tree_util as jtu
import pytreeclass as pytc
# Before v0.5: subclassing TreeClass alone generated __init__ from the type hints
class Tree(pytc.TreeClass):
    a: int = 1
jtu.tree_leaves(Tree())
# [1]
Equivalent behavior is obtained by decorating with either @pytreeclass.autoinit or @dataclasses.dataclass:
import jax.tree_util as jtu
import pytreeclass as pytc
@pytc.autoinit
class Tree(pytc.TreeClass):
    a: int = 1
jtu.tree_leaves(Tree())
# [1]
This change aims to fix the ambiguity of using the dataclass mental model in the following situations:
Subclassing. Previously, using TreeClass as a base class was equivalent to decorating the class with dataclasses.dataclass; however, this is a bit challenging to understand, as demonstrated in the next example:
import pytreeclass as pytc
import dataclasses as dc

class A(pytc.TreeClass):
    def __init__(self, a: int):
        self.a = a

class B(A):
    ...
When instantiating B(a=...), an error will be raised, because using TreeClass was equivalent to decorating all classes with @dataclass, which synthesizes the __init__ method based on the fields.
Since B has no fields (i.e. type-hinted values), the synthesized __init__ method takes no arguments and shadows the __init__ inherited from A.
The previous code is equivalent to this code:
@dc.dataclass
class A:
    def __init__(self, a: int):
        self.a = a

@dc.dataclass
class B(A):
    ...
dataclass_transform does not play nicely with a user-created __init__; see 1, 2.
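The subclassing pitfall can be reproduced with the standard library alone. According to the dataclasses documentation, a class that defines __init__ itself keeps it. A decorated subclass without its own __init__, however, gets a freshly synthesized one built from its (empty) field list:

```python
import dataclasses as dc


@dc.dataclass
class A:
    # An __init__ defined in the class body is kept; dataclass does not replace it.
    def __init__(self, a: int):
        self.a = a


@dc.dataclass
class B(A):
    # No fields and no __init__ of its own, so dataclass synthesizes
    # __init__(self), which shadows A.__init__.
    ...


print(A(a=1).a)  # 1
try:
    B(a=1)
except TypeError as err:
    print("TypeError:", err)  # the message mentions the unexpected keyword argument 'a'
```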
leafwise_transform is decoupled from TreeClass; instead, decorate the class with pytreeclass.leafwise.
Published by ASEM000 over 1 year ago
PyTreeClass v0.4
A user-provided re.Pattern is used to match keys with a regex pattern, instead of RegexKey.
Example:
import pytreeclass as pytc
import re
tree = {"l1":1, "l2":2, "b":3}
tree = pytc.AtIndexer(tree)
tree.at[re.compile("l.*")].get()
# {'b': None, 'l1': 1, 'l2': 2}
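The following sketch shows regex-based key selection on a flat dict, following the get() output above by replacing non-matching values with None. get_by_pattern is hypothetical. Whether pytreeclass uses fullmatch or match semantics is an assumption here; both agree for "l.*":

```python
import re
from typing import Any, Dict


def get_by_pattern(tree: Dict[str, Any], pattern: re.Pattern) -> Dict[str, Any]:
    """Keep values whose key fully matches `pattern`; replace the rest with None."""
    return {key: (value if pattern.fullmatch(key) else None) for key, value in tree.items()}


tree = {"l1": 1, "l2": 2, "b": 3}
print(get_by_pattern(tree, re.compile("l.*")))  # {'l1': 1, 'l2': 2, 'b': None}
```

The sketch keeps insertion order, whereas the output above is key-sorted because jax flattens dicts in sorted key order.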
RegexKey is deprecated; use re compiled patterns instead.
tree_indent is deprecated; use tree_diagram(tree).replace(...) to replace the edge characters with spaces.
Add tree_mask and tree_unmask to freeze/unfreeze tree leaves based on a callable/boolean pytree mask. By default, non-inexact types are masked with a frozen wrapper.
Example: pass non-jax types through jax transformations without error.
# pass non-differentiable values to `jax.grad`
import pytreeclass as pytc
import jax

@jax.grad
def square(tree):
    tree = pytc.tree_unmask(tree)
    return tree[0]**2

tree = (1., 2)  # contains a non-differentiable node
square(pytc.tree_mask(tree))
# (Array(2., dtype=float32, weak_type=True), #2)
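The masking idea can be shown without jax. Leaves that fail an "inexact type" check get wrapped, so code that only looks at bare leaves skips them, and unmasking restores the originals. Frozen, is_nondiff, mask, and unmask are hypothetical stand-ins for pytreeclass's frozen wrapper, is_nondiff, tree_mask, and tree_unmask. Floats and complex numbers stand in for jax's inexact dtypes:

```python
from typing import Any, Callable


class Frozen:
    """Hypothetical wrapper marking a leaf as masked."""

    def __init__(self, value: Any):
        self.value = value

    def __repr__(self) -> str:
        return f"#{self.value!r}"


def is_nondiff(x: Any) -> bool:
    # Stand-in rule: only floats and complex numbers count as differentiable.
    return not isinstance(x, (float, complex))


def mask(tree: Any, cond: Callable[[Any], bool] = is_nondiff) -> Any:
    """Wrap every leaf for which `cond` is true; recurse into containers."""
    if isinstance(tree, (list, tuple)):
        return type(tree)(mask(x, cond) for x in tree)
    if isinstance(tree, dict):
        return {k: mask(v, cond) for k, v in tree.items()}
    return Frozen(tree) if cond(tree) else tree


def unmask(tree: Any) -> Any:
    """Undo `mask` by unwrapping every Frozen leaf."""
    if isinstance(tree, (list, tuple)):
        return type(tree)(unmask(x) for x in tree)
    if isinstance(tree, dict):
        return {k: unmask(v) for k, v in tree.items()}
    return tree.value if isinstance(tree, Frozen) else tree


masked = mask((1.0, 2))
print(masked)          # (1.0, #2)
print(unmask(masked))  # (1.0, 2)
```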
Add the abstract base class BaseKey to support extending match keys; check its docstring for an example.
Support multi-indexing with any acceptable form, e.g. a boolean pytree, key, int, or BaseKey instance.
Example:
import pytreeclass as pytc
tree = {"l1":1, "l2":2, "b":3}
tree = pytc.AtIndexer(tree)
tree.at["l1","l2"].get()
# {'b': None, 'l1': 1, 'l2': 2}
Add scan to AtIndexer to carry a state while applying a function.
Example:
import pytreeclass as pytc
def scan_func(leaf, state):
    # increase the state by 1 for each function call
    return leaf**2, state+1
tree = {"l1": 1, "l2": 2, "b": 3}
tree = pytc.AtIndexer(tree)
tree, state = tree.at["l1", "l2"].scan(scan_func, 0)
state
# 2
tree
# {'b': 3, 'l1': 1, 'l2': 4}
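The following plain-Python sketch shows the scan semantics. The function receives each selected leaf together with the current state and returns the new leaf and the updated state. scan_keys is a hypothetical helper that visits keys in the order given, whereas pytreeclass visits leaves in flattening order:

```python
from typing import Any, Callable, Dict, Iterable, Tuple


def scan_keys(
    tree: Dict[str, Any],
    keys: Iterable[str],
    func: Callable[[Any, Any], Tuple[Any, Any]],
    state: Any,
) -> Tuple[Dict[str, Any], Any]:
    """Thread `state` through func(leaf, state) -> (new_leaf, new_state) over `keys`."""
    new_tree = dict(tree)  # leave the input untouched
    for key in keys:
        new_tree[key], state = func(new_tree[key], state)
    return new_tree, state


def scan_func(leaf, state):
    # square the leaf and count how many leaves were visited
    return leaf**2, state + 1


tree, state = scan_keys({"l1": 1, "l2": 2, "b": 3}, ["l1", "l2"], scan_func, 0)
print(state)  # 2
print(tree)   # {'l1': 1, 'l2': 4, 'b': 3}
```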
tree_summary improvements:
- tree_summary.def_count to dispatch the count rule for a type.
- tree_summary.def_size to dispatch the size rule for a type.
- tree_summary.def_type to dispatch the type display.
Example:
import pytreeclass as pytc
import jax.numpy as jnp
x = jnp.ones((5, 5))
print(pytc.tree_summary([1, 2, 3, x]))
# ┌────┬────────┬─────┬───────┐
# │Name│Type    │Count│Size   │
# ├────┼────────┼─────┼───────┤
# │[0] │int     │1    │       │
# ├────┼────────┼─────┼───────┤
# │[1] │int     │1    │       │
# ├────┼────────┼─────┼───────┤
# │[2] │int     │1    │       │
# ├────┼────────┼─────┼───────┤
# │[3] │f32[5,5]│25   │100.00B│
# ├────┼────────┼─────┼───────┤
# │Σ   │list    │28   │100.00B│
# └────┴────────┴─────┴───────┘
# make list display its number of elements
# in the type row
@pytc.tree_summary.def_type(list)
def _(node: list) -> str:
    return f"List[{len(node)}]"

print(pytc.tree_summary([1, 2, 3, x]))
# ┌────┬────────┬─────┬───────┐
# │Name│Type    │Count│Size   │
# ├────┼────────┼─────┼───────┤
# │[0] │int     │1    │       │
# ├────┼────────┼─────┼───────┤
# │[1] │int     │1    │       │
# ├────┼────────┼─────┼───────┤
# │[2] │int     │1    │       │
# ├────┼────────┼─────┼───────┤
# │[3] │f32[5,5]│25   │100.00B│
# ├────┼────────┼─────┼───────┤
# │Σ   │List[4] │28   │100.00B│
# └────┴────────┴─────┴───────┘
Export pytrees to the dot language using tree_graph:
# define a custom style for a node by dispatching on its type
# the defined function should return a dict of attributes
# that will be passed to graphviz.
import graphviz
import pytreeclass as pytc
tree = [1, 2, dict(a=3)]
@pytc.tree_graph.def_nodestyle(list)
def _(_) -> dict[str, str]:
    return dict(shape="circle", style="filled", fillcolor="lightblue")
dot_graph = graphviz.Source(pytc.tree_graph(tree))
dot_graph
Add variable positional arguments and variable keyword arguments to the kind option of pytc.field:
import pytreeclass as pytc
class Tree(pytc.TreeClass):
    a: int = pytc.field(kind="VAR_POS")
    b: int = pytc.field(kind="POS_ONLY")
    c: int = pytc.field(kind="VAR_KW")
    d: int
    e: int = pytc.field(kind="KW_ONLY")
Tree.__init__
# <function __main__.Tree.__init__(self, b: int, /, d: int, *a: int, e: int, **c: int) -> None>
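The resulting signature follows Python's parameter ordering rules, which can be reproduced with inspect.Signature. The following sketch assumes a mapping from kind names to inspect.Parameter kinds inferred from the example; kind=None stands for a plain field. build_init_signature is hypothetical and is not how pytreeclass builds __init__:

```python
import inspect
from typing import List, Optional, Tuple

# Assumed mapping from the kind strings in the example to inspect's kinds.
KINDS = {
    "POS_ONLY": inspect.Parameter.POSITIONAL_ONLY,
    "VAR_POS": inspect.Parameter.VAR_POSITIONAL,
    "KW_ONLY": inspect.Parameter.KEYWORD_ONLY,
    "VAR_KW": inspect.Parameter.VAR_KEYWORD,
}


def build_init_signature(fields: List[Tuple[str, type, Optional[str]]]) -> inspect.Signature:
    """Build an __init__ signature, ordering parameters by kind as Python requires."""
    params = [inspect.Parameter("self", inspect.Parameter.POSITIONAL_ONLY)]
    for name, annotation, kind in fields:
        # kind=None is treated as an ordinary positional-or-keyword field.
        param_kind = KINDS.get(kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
        params.append(inspect.Parameter(name, param_kind, annotation=annotation))
    # Parameter kinds are ordered enums; sort() is stable, so fields of the
    # same kind keep their declaration order.
    params.sort(key=lambda p: p.kind)
    return inspect.Signature(params, return_annotation=None)


fields = [("a", int, "VAR_POS"), ("b", int, "POS_ONLY"), ("c", int, "VAR_KW"), ("d", int, None), ("e", int, "KW_ONLY")]
print(build_init_signature(fields))
# (self, b: int, /, d: int, *a: int, e: int, **c: int) -> None
```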
This release introduces extensive functools.singledispatch usage to enable greater customization:
- {freeze,unfreeze,is_nondiff}.def_type to define how to freeze a type, how to unfreeze it, and whether it is considered nondiff or not. These rules are used by these functions and by tree_mask/tree_unmask.
- tree_graph.def_nodestyle and tree_summary.def_{count,type,size} for pretty-printing customization.
- BaseKey.def_alias to define type alias usage inside AtIndexer/.at.
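The def_type-style hooks follow the functools.singledispatch registration pattern. The following minimal sketch of that pattern uses a hypothetical type_name function as a stand-in for tree_summary's type column. Note that registering a rule on list also covers list subclasses:

```python
from functools import singledispatch


@singledispatch
def type_name(node) -> str:
    # Fallback rule for unregistered types: the plain class name.
    return type(node).__name__


@type_name.register(list)
def _(node: list) -> str:
    # Custom rule for lists (and list subclasses): include the element count.
    return f"List[{len(node)}]"


print(type_name(1))          # int
print(type_name([1, 2, 3]))  # List[3]
print(type_name((1, 2)))     # tuple
```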
Published by ASEM000 over 1 year ago
Example: update all leaves whose names start with weight_
import pytreeclass as pytc
class Tree(pytc.TreeClass):
    weight_1: float = 1.0
    weight_2: float = 2.0
    weight_3: float = 3.0
    bias: float = 0.0
tree = Tree()
tree.at[pytc.RegexKey(r"weight_.*")].set(100.0)
# Tree(weight_1=100.0, weight_2=100.0, weight_3=100.0, bias=0.0)
Full Changelog: https://github.com/ASEM000/PyTreeClass/compare/v0.3.7...v0.3.8
Published by ASEM000 over 1 year ago
Published by ASEM000 over 1 year ago
Full Changelog: https://github.com/ASEM000/PyTreeClass/compare/v0.3.5...v0.3.6
Published by ASEM000 over 1 year ago
Partial (used for bcmap) for jax-compatible partial functions, with support for positional-argument placeholders:
>>> import pytreeclass as pytc
>>> def f(a, b, c):
... print(f"a: {a}, b: {b}, c: {c}")
... return a + b + c
>>> # positional arguments using `...` placeholder
>>> f_a = pytc.Partial(f, ..., 2, 3)
>>> f_a(1)
a: 1, b: 2, c: 3
6
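The following plain-Python sketch shows the ... placeholder semantics: each Ellipsis in the stored arguments is filled, left to right, by the next call-time positional argument. PlaceholderPartial is hypothetical and makes no attempt at jax compatibility:

```python
from typing import Any, Callable


class PlaceholderPartial:
    """Hypothetical partial where `...` marks positions filled at call time."""

    def __init__(self, func: Callable[..., Any], *args: Any, **kwargs: Any):
        self.func = func
        self.args = args
        self.kwargs = kwargs

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        remaining = iter(args)
        filled = []
        for arg in self.args:
            if arg is Ellipsis:
                # Fill each placeholder, left to right, with the next call argument.
                try:
                    arg = next(remaining)
                except StopIteration:
                    raise TypeError("not enough arguments to fill the placeholders") from None
            filled.append(arg)
        filled.extend(remaining)  # extra positional arguments go at the end
        return self.func(*filled, **{**self.kwargs, **kwargs})


def f(a, b, c):
    return a + b + c


print(PlaceholderPartial(f, ..., 2, 3)(1))  # 6
print(PlaceholderPartial(f, 1, ..., 3)(2))  # 6
```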
Use pytc.AtIndexer instead of tree_indexer.
Published by ASEM000 over 1 year ago
Published by ASEM000 over 1 year ago
fields
by @ASEM000 in https://github.com/ASEM000/PyTreeClass/pull/60
Motivation
Full Changelog: https://github.com/ASEM000/PyTreeClass/compare/v0.3.1...v0.3.2
Published by ASEM000 over 1 year ago
flax and any type registered with the jax key registry, out of the box.
Full Changelog: https://github.com/ASEM000/PyTreeClass/compare/v0.3.0...v0.3.1
Published by ASEM000 over 1 year ago
For better typing, the decorator-based approach is changed to a subclassing approach.
Old:
@functools.partial(pytc.treeclass, leafwise=True)
class Tree:
    pass
New:
class Tree(pytc.TreeClass, leafwise=True):
    pass
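The subclassing form works because Python forwards class keyword arguments such as leafwise=True to the base class's __init_subclass__. The following is a minimal sketch of that mechanism. Base and its leafwise handling are hypothetical and say nothing about what TreeClass does with the flag:

```python
class Base:
    """Hypothetical base class that reads a `leafwise` class keyword."""

    leafwise: bool = False

    def __init_subclass__(cls, leafwise: bool = False, **kwargs):
        # Pass any other class keywords up the MRO so cooperative bases still see them.
        super().__init_subclass__(**kwargs)
        cls.leafwise = leafwise


class Tree(Base, leafwise=True):
    pass


class Plain(Base):
    pass


print(Tree.leafwise, Plain.leafwise)  # True False
```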
is_treeclass is removed; use isinstance(..., TreeClass) instead.
pytc.fields is removed.
Full Changelog: https://github.com/ASEM000/PyTreeClass/compare/v0.2.8...v0.3.0
Published by ASEM000 over 1 year ago
Full Changelog: https://github.com/ASEM000/PyTreeClass/compare/v0.2.7...v0.2.8
Published by ASEM000 over 1 year ago
Full Changelog: https://github.com/ASEM000/PyTreeClass/compare/v0.2.6...v0.2.7