pytreeclass

Visualize, create, and operate on pytrees in the most intuitive way possible.

APACHE-2.0 License

Downloads
2.6K
Stars
38
Committers
2

Bot releases are hidden (Show)

pytreeclass - v0.9.2 Latest Release

Published by ASEM000 about 1 year ago

v0.9.2

Changes:

  • change threads_count in apply parallel kwargs to max_workers

Full Changelog: https://github.com/ASEM000/pytreeclass/compare/v0.9.1...v0.9.2

pytreeclass -

Published by ASEM000 about 1 year ago

v0.9.1

Additions:

  • Add parallel mapping option in AtIndexer. This enables myriad of tasks, like reading a pytree of image file names.
# benchmarking serial vs sequential image read
# on mac m1 cpu with image of size 512x512x3
import pytreeclass as tc
from matplotlib.pyplot import imread
paths = ["lenna.png"] * 10
indexer = tc.AtIndexer(paths)
%timeit indexer[...].apply(imread,parallel=True)  # parallel
# 24.9 ms ± 938 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
%timeit indexer[...].apply(imread)  # not parallel
# # 84.8 ms ± 453 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
pytreeclass - v0.9

Published by ASEM000 about 1 year ago

V0.9

Breaking changes:

  • To simplify the API the following will be removed:
    1. tree_repr_with_trace
    2. tree_map_with_trace
    3. tree_flatten_with_trace
    4. tree_leaves_with_trace

What's Changed

Full Changelog: https://github.com/ASEM000/pytreeclass/compare/v0.8.0...v0.9

pytreeclass - v0.8.0

Published by ASEM000 about 1 year ago

V0.8

Additions:

  • Add on_getattr in field to apply function on __getattr__

Breaking changes:

  • Rename callbacks in field to on_setattr to match attrs and better reflect its functionality.

These changes enable:

  1. stricter data validation on instance values, as in the following example:

    on_setattr ensure the value is of certain type (e.g.integer) during initialization, and on_getattr, ensure the value is of certain type (e.g. integer) whenever its accessed.

    
    import pytreeclass as pytc
    import jax
    
    def assert_int(x):
        assert isinstance(x, int), "must be an int"
        return x
    
    @pytc.autoinit
    class Tree(pytc.TreeClass):
        a: int = pytc.field(on_getattr=[assert_int], on_setattr=[assert_int])
    
        def __call__(self, x):
            # enusre `a` is an int before using it in computation by calling `assert_int`
            a: int = self.a
            return a + x
    
    tree = Tree(a=1)
    print(tree(1.0))  # 2.0
    tree = jax.tree_map(lambda x: x + 0.0, tree)  # make `a` a float
    tree(1.0)  # AssertionError: must be an int
    
  2. Frozen field without using tree_mask/tree_unmask

    The following shows a pattern where the value is frozen on __setattr__ and unfrozen whenever accessed, this ensures that jax transformation does not see the value. the following example showcase this functionality

    import pytreeclass as pytc
    import jax
    
    @pytc.autoinit
    class Tree(pytc.TreeClass):
        frozen_a : int = pytc.field(on_getattr=[pytc.unfreeze], on_setattr=[pytc.freeze])
    
        def __call__(self, x):
            return self.frozen_a + x
    
    tree = Tree(frozen_a=1)  # 1 is non-jaxtype
    # can be used in jax transformations
    
    @jax.jit
    def f(tree, x):
        return tree(x)
    
    f(tree, 1.0)  # 2.0
    
    grads = jax.grad(f)(tree, 1.0)  # Tree(frozen_a=#1)
    

    Compared with other libraies that implements static_field, this pattern has lower overhead and does not alter tree_flatten/tree_unflatten methods of the tree.

  3. Easier way to create a buffer (non-trainable array)

    Just use jax.lax.stop_gradient in on_getattr

    import pytreeclass as pytc
    import jax
    import jax.numpy as jnp
    
    def assert_array(x):
        assert isinstance(x, jax.Array)
        return x
    
    @pytc.autoinit
    class Tree(pytc.TreeClass):
        buffer: jax.Array = pytc.field(on_getattr=[jax.lax.stop_gradient],on_setattr=[assert_array])
        def __call__(self, x):
            return self.buffer**x
        
    tree = Tree(buffer=jnp.array([1.0, 2.0, 3.0]))
    tree(2.0)  # Array([1., 4., 9.], dtype=float32)
    @jax.jit
    def f(tree, x):
        return jnp.sum(tree(x))
    
    f(tree, 1.0)  # Array([1., 2., 3.], dtype=float32)
    print(jax.grad(f)(tree, 1.0))  # Tree(buffer=[0. 0. 0.])
    
pytreeclass - v0.7.0

Published by ASEM000 about 1 year ago

Changelog

v0.7

  • Remove .at as an alias for __getitem__ when specifying a path entry for where in AtIndexer. This leads to less verbose style.

Example:


>>> tree = {"level1_0": {"level2_0": 100, "level2_1": 200}, "level1_1": 300}
>>> tree = pytc.AtIndexer(tree)

>>> # Before:
>>> # style 1 (with at):
>>> tree.at["level1_0"].at["level2_0", "level2_1"].get()
{'level1_0': {'level2_0': 100, 'level2_1': 200}, 'level1_1': None}
>>> # style 2 (no at):
>>> tree["level1_0"]["level2_0", "level2_1"].get()

>>> # After
>>> # only style 2 is valid
>>> tree["level1_0"]["level2_0", "level2_1"].get()

For TreeClass

at is specified once for each change

@pytc.autoinit
class Tree(pytc.TreeClass):
    a: float = 1.0
    b: tuple[float, float] = (2.0, 3.0)
    c: jax.Array = jnp.array([4.0, 5.0, 6.0])

    def __call__(self, x):
        return self.a + self.b[0] + self.c + x


tree = Tree()
mask = jax.tree_map(lambda x: x > 5, tree)
tree = tree\
       .at["a"].set(100.0)\
-       .at["b"].at[0].set(10.0)\
+      .at["b"][0].set(10.0)\
       .at[mask].set(100.0)
pytreeclass - v0.6.0post0

Published by ASEM000 about 1 year ago

Changelog

v0.6.0post0

  • using tree_{repr,str} with an object containing cyclic references will raise RecursionError instead of displaying cyclicref.
pytreeclass - v0.6.0

Published by ASEM000 about 1 year ago

v0.6.0

  • Allow nested mutations using .at[method](*args, **kwargs).
    After the change, inner methods can mutate copied new instances at any level not just the top level.
    a motivation for this is to experiment with lazy initialization scheme, where inner layers need to mutate their inner state. see the example below for flax-like lazy initialization as descriped here

    
    import pytreeclass as pytc
    import jax.random as jr
    from typing import Any
    import jax
    import jax.numpy as jnp
    from typing import Callable, TypeVar
    
    T = TypeVar("T")
    
    @pytc.autoinit
    class LazyLinear(pytc.TreeClass):
        outdim: int
        weight_init: Callable[..., T] = jax.nn.initializers.glorot_normal()
        bias_init: Callable[..., T] = jax.nn.initializers.zeros
    
        def param(self, name: str, init_func: Callable[..., T], *args) -> T:
            if name not in vars(self):
                setattr(self, name, init_func(*args))
            return vars(self)[name]
    
        def __call__(self, x: jax.Array, *, key: jr.KeyArray = jr.PRNGKey(0)):
            w = self.param("weight", self.weight_init, key, (x.shape[-1], self.outdim))
            y = x @ w
            if self.bias_init is not None:
                b = self.param("bias", self.bias_init, key, (self.outdim,))
                return y + b
            return y
    
    
    @pytc.autoinit
    class StackedLinear(pytc.TreeClass):
        l1: LazyLinear = LazyLinear(outdim=10)
        l2: LazyLinear = LazyLinear(outdim=1)
    
        def call(self, x: jax.Array):
            return self.l2(jax.nn.relu(self.l1(x)))
    
    lazy_layer = StackedLinear()
    print(repr(lazy_layer))
    # StackedLinear(
    #   l1=LazyLinear(
    #     outdim=10, 
    #     weight_init=init(key, shape, dtype), 
    #     bias_init=zeros(key, shape, dtype)
    #   ), 
    #   l2=LazyLinear(
    #     outdim=1, 
    #     weight_init=init(key, shape, dtype), 
    #     bias_init=zeros(key, shape, dtype)
    #   )
    # )
    
    _, materialized_layer = lazy_layer.at["call"](jnp.ones((1, 5)))
    materialized_layer
    # StackedLinear(
    #   l1=LazyLinear(
    #     outdim=10, 
    #     weight_init=init(key, shape, dtype), 
    #     bias_init=zeros(key, shape, dtype), 
    #     weight=f32[5,10](μ=-0.04, σ=0.32, ∈[-0.74,0.63]), 
    #     bias=f32[10](μ=0.00, σ=0.00, ∈[0.00,0.00])
    #   ), 
    #   l2=LazyLinear(
    #     outdim=1, 
    #     weight_init=init(key, shape, dtype), 
    #     bias_init=zeros(key, shape, dtype), 
    #     weight=f32[10,1](μ=-0.07, σ=0.23, ∈[-0.34,0.34]), 
    #     bias=f32[1](μ=0.00, σ=0.00, ∈[0.00,0.00])
    #   )
    # )
    
    materialized_layer(jnp.ones((1, 5)))
    # Array([[0.16712935]], dtype=float32)
    

Full Changelog: https://github.com/ASEM000/PyTreeClass/compare/v0.5...v0.6.0

pytreeclass - v0.5.0post0

Published by ASEM000 about 1 year ago

Fix __init_subclass__ not accepting arguments. Bug introduced since v0.5

Full Changelog: https://github.com/ASEM000/PyTreeClass/compare/v0.5...v0.0.5post0

pytreeclass - v0.5.0

Published by ASEM000 about 1 year ago

Changelog

PyTreeClass v0.5

Breaking changes

Auto generation of __init__ method from type hints is decoupled from TreeClass

Alternatives

Use:

  1. Preferably decorate with pytreeclass.autoinit with pytreeclass.field as field specifier. as pytreeclass.field has more features (e.g. callbacks, multiple argument kind selection) and the init generation is cached compared to dataclasses.
  2. decorate with dataclasses.dataclass with dataclasses.field as field specifier. however :
    1. Must set fronzen=False because the __setattr__, __delattr__ is handled by TreeClass
    2. Optionally repr=False to be handled by TreeClass
    3. Optionally eq=hash=False as it is handled by TreeClass

Before

import jax.tree_util as jtu
import pytreeclass as pytc
import dataclasses as dc

class Tree(pytc.TreeClass):
    a: int = 1

jtu.tree_leaves(Tree())
# [1]

After

Equivalent behavior when decorating with either:

  1. @pytreeclass.autoinit
  2. @dataclasses.dataclass
import jax.tree_util as jtu
import pytreeclass as pytc

@pytc.autoinit
class Tree(pytc.TreeClass):
    a: int = 1

jtu.tree_leaves(Tree())
# [1]

This change aims to fix the ambiguity of using the dataclass mental model in the following siutations:

  1. subclassing. previously, using TreeClass as a base class is equivalent to decorating the class with dataclasses.dataclass, however this is a bit challenging to understand as demonstrated in the next example:

    import pytreeclass as pytc
    import dataclasses as dc
    
    class A(pytc.TreeClass):
        def ___init__(self, a:int):
            self.a = a
    
    class B(A):
        ...
    
    

    When instantiating B(a=...), an error will be raised, because using TreeClass is equivalent of decorating all classes with @dataclass, which synthesize the __init__ method based on the fields.
    Since no fields (e.g. type hinted values) then the synthesized __init__ method .

    The previous code is equivalent to this code.

    @dc.dataclass
    class A:
        def __init__(self, a:int):
            self.a = a
    @dc.dataclass
    class B:
        ...
    
  2. dataclass_transform does not play nicely with user created __init__ see 1, 2

leafwise_transform is decoupled from TreeClass.

instead decorate the class with pytreeclass.leafwise.

pytreeclass - v0.4

Published by ASEM000 over 1 year ago

Changelog

PyTreeClass v0.4

Changes

  1. User-provided re.Pattern is used to match keys with regex pattern instead of using RegexKey

    Example:

    import pytreeclass as pytc
    import re 
    
    tree = {"l1":1, "l2":2, "b":3}
    tree = pytc.AtIndexer(tree)
    tree.at[re.compile("l.*")].get()
    # {'b': None, 'l1': 1, 'l2': 2}
    

Deprecations

  1. RegexKey is deprecated. use re compiled patterns instead.
  2. tree_indent is deprecated. use tree_diagram(tree).replace(...) to replace the edges characters with spaces.

New features

  1. Add tree_mask, tree_unmask to freeze/unfreeze tree leaves based on a callable/boolean pytree mask. defaults to masking non-inexact types by frozen wrapper.

    Example: Pass non-jax types through jax transformation without error.

    # pass non-differentiable values to `jax.grad`
    import pytreeclass as pytc
    import jax
    @jax.grad
    def square(tree):
        tree = pytc.tree_unmask(tree)
        return tree[0]**2
    tree = (1., 2)  # contains a non-differentiable node
    square(pytc.tree_mask(tree))
    # (Array(2., dtype=float32, weak_type=True), #2)
    
  2. Support extending match keys by adding abstract base class BaseKey. check docstring for example

  3. Support multi-index by any acceptable form. e.g. boolean pytree, key, int, or BaseKey instance

    Example:

    
    import pytreeclass as pytc
    tree = {"l1":1, "l2":2, "b":3}
    tree = pytc.AtIndexer(tree)
    tree.at["l1","l2"].get()
    # {'b': None, 'l1': 1, 'l2': 2}
    
    
  4. add scan to AtIndexer to carry a state while applying a function.

    Example:

    
    import pytreeclass as pytc
    def scan_func(leaf, state):
        # increase the state by 1 for each function call
        return leaf**2, state+1
    
    tree = {"l1": 1, "l2": 2, "b": 3}
    tree = pytc.AtIndexer(tree)
    tree, state = tree.at["l1", "l2"].scan(scan_func, 0)
    state
    # 2
    tree
    # {'b': 3, 'l1': 1, 'l2': 4}
    
    
  5. tree_summary improvements.

    • Add size column to tree_summary.
    • add def_count to dispatch count rule for type.
    • add def_size to dispatch size rule for type.
    • add def_type to dispatch type display.

    Example:

    
    import pytreeclass as pytc
    import jax.numpy as jnp
    
    x = jnp.ones((5, 5))
    
    print(pytc.tree_summary([1, 2, 3, x]))
    # ┌────┬────────┬─────┬───────┐
    # │Name│Type    │Count│Size   │
    # ├────┼────────┼─────┼───────┤
    # │[0] │int     │1    │       │
    # ├────┼────────┼─────┼───────┤
    # │[1] │int     │1    │       │
    # ├────┼────────┼─────┼───────┤
    # │[2] │int     │1    │       │
    # ├────┼────────┼─────┼───────┤
    # │[3] │f32[5,5]│25   │100.00B│
    # ├────┼────────┼─────┼───────┤
    # │Σ   │list    │28   │100.00B│
    # └────┴────────┴─────┴───────┘
    
    # make list display its number of elements
    # in the type row
    @pytc.tree_summary.def_type(list)
    def _(_: list) -> str:
        return f"List[{len(_)}]"
    
    print(pytc.tree_summary([1, 2, 3, x]))
    # ┌────┬────────┬─────┬───────┐
    # │Name│Type    │Count│Size   │
    # ├────┼────────┼─────┼───────┤
    # │[0] │int     │1    │       │
    # ├────┼────────┼─────┼───────┤
    # │[1] │int     │1    │       │
    # ├────┼────────┼─────┼───────┤
    # │[2] │int     │1    │       │
    # ├────┼────────┼─────┼───────┤
    # │[3] │f32[5,5]│25   │100.00B│
    # ├────┼────────┼─────┼───────┤
    # │Σ   │List[4] │28   │100.00B│
    # └────┴────────┴─────┴───────┘
    
    
  6. Export pytrees to dot language using tree_graph

    # define custom style for a node by dispatching on the value
    # the defined function should return a dict of attributes
    # that will be passed to graphviz.
    import pytreeclass as pytc
    tree = [1, 2, dict(a=3)]
    @pytc.tree_graph.def_nodestyle(list)
    def _(_) -> dict[str, str]:
        return dict(shape="circle", style="filled", fillcolor="lightblue")
    dot_graph = graphviz.Source(pytc.tree_graph(tree))
    dot_graph
    

    image

  7. Add variable position arguments and variable keyword arguments to pytc.field kind

    import pytreeclass as pytc
    
    
    class Tree(pytc.TreeClass):
        a: int = pytc.field(kind="VAR_POS")
        b: int = pytc.field(kind="POS_ONLY")
        c: int = pytc.field(kind="VAR_KW")
        d: int
        e: int = pytc.field(kind="KW_ONLY")
    
    
    Tree.__init__
    # <function __main__.Tree.__init__(self, b: int, /, d: int, *a: int, e: int, **c: int) -> None>
    

This release introduces lots of functools.singledispatch usage, to enable greater customization.

  • {freeze,unfreeze,is_nondiff}.def_type to define how to freeze a type, how to unfreeze it and whether it is considred nondiff or not. these rules are used by these functions and tree_mask/tree_unmask.
  • tree_graph.def_nodestyle, tree_summary.def_{count,type,size} for pretty printing customization
  • BaseKey.def_alias to define type alias usage inside AtIndexer/.at
  • Internally, most of the pretty printing is using dispatching to define repr/str rules for each instance type.
pytreeclass - v0.3.8

Published by ASEM000 over 1 year ago

What's Changed

Example:

Update all leaves starting with weight_


import pytreeclass as pytc

class Tree(pytc.TreeClass):
    weight_1: float = 1.0
    weight_2: float = 2.0
    weight_3: float = 3.0
    bias: float = 0.0

tree = Tree()

tree.at[pytc.RegexKey(r"weight_.*")].set(100.0)
# Tree(weight_1=100.0, weight_2=100.0, weight_3=100.0, bias=0.0)

Full Changelog: https://github.com/ASEM000/PyTreeClass/compare/v0.3.7...v0.3.8

pytreeclass - v0.3.7

Published by ASEM000 over 1 year ago

pytreeclass - v0.3.6

Published by ASEM000 over 1 year ago

pytreeclass - v0.3.4

Published by ASEM000 over 1 year ago

What's Changed

  • Exposed internal Partial (used for bcmap) for jaxable partial functions with positional arguments support
      >>> import pytreeclass as pytc
      >>> def f(a, b, c):
      ...     print(f"a: {a}, b: {b}, c: {c}")
      ...     return a + b + c

      >>> # positional arguments using `...` placeholder
      >>> f_a = pytc.Partial(f, ..., 2, 3)
      >>> f_a(1)
      a: 1, b: 2, c: 3
      6

pytreeclass - v0.3.3

Published by ASEM000 over 1 year ago

  • Fix empty vars bug
pytreeclass - v0.3.2

Published by ASEM000 over 1 year ago

What's Changed

Motivation

  • Enable using `.at[''] on all instance variables, not only the type hinted / pytreeclass wrapped instance variables

Full Changelog: https://github.com/ASEM000/PyTreeClass/compare/v0.3.1...v0.3.2

pytreeclass - v0.3.1

Published by ASEM000 over 1 year ago

What's Changed

Full Changelog: https://github.com/ASEM000/PyTreeClass/compare/v0.3.0...v0.3.1

pytreeclass - v0.3.0

Published by ASEM000 over 1 year ago

Breaking change v0.3

For better typing, the decorator-based approach is changed to subclassing approach.

Old

@functools.partial(pytc.treeclass, leafwise=True)
class Tree:
  pass

New


class Tree(pytc.TreeClass, leafwise=True):
  pass

  • is_treeclass is removed use isinstance(..., TreeClass) instead
  • pytc.fields is removed.

Full Changelog: https://github.com/ASEM000/PyTreeClass/compare/v0.2.8...v0.3.0

pytreeclass - v0.2.8

Published by ASEM000 over 1 year ago

What's Changed

Full Changelog: https://github.com/ASEM000/PyTreeClass/compare/v0.2.7...v0.2.8

pytreeclass - v0.2.7

Published by ASEM000 over 1 year ago

What's Changed

Full Changelog: https://github.com/ASEM000/PyTreeClass/compare/v0.2.6...v0.2.7