
Progress Bar


__all__ = ['progress_bar'] (module attribute)
progress_bar(num_iters: int, log_rate: int) -> Callable

Progress bar decorator for the body function of a jax.lax.scan.


    >>> import jax.numpy as jnp
    >>> import jax
    >>> carry = jnp.array(0.0)
    >>> iteration_numbers = jnp.arange(100)
    >>> @progress_bar(num_iters=iteration_numbers.shape[0], log_rate=10)
    ... def body_func(carry, x):
    ...     return carry, x
    >>> carry, _ = jax.lax.scan(body_func, carry, iteration_numbers)

Adapted from this excellent blog post.
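One way such a decorator can work is sketched below. This is a minimal sketch, not this module's actual implementation. It assumes, as in the example above, that the scanned values are the iteration numbers, so `x` inside the body is the current iteration index. The wrapper checks that index inside the compiled scan and uses `jax.lax.cond` so that a host callback (`jax.debug.callback`) only fires every `log_rate` steps and on the final step. `progress_bar_sketch` and its `report` parameter are hypothetical names, and a plain text line stands in for a real progress bar.

```python
from typing import Callable, Optional

import jax
import jax.numpy as jnp


def progress_bar_sketch(
    num_iters: int,
    log_rate: int,
    report: Optional[Callable[[int], None]] = None,
) -> Callable:
    """Hypothetical sketch of a progress decorator for a jax.lax.scan body.

    ASSUMPTION: the scanned `xs` are the iteration numbers (e.g. jnp.arange(n)),
    so `x` inside the body is the current iteration index.
    """
    if report is None:
        # Default host-side reporter: a text line instead of a real progress bar.
        def report(i: int) -> None:
            print(f"Iteration {i + 1}/{num_iters}")

    def _host_report(iter_num) -> None:
        # Runs on the host and receives a concrete (non-traced) array value.
        report(int(iter_num))

    def decorator(body_func: Callable) -> Callable:
        def wrapped(carry, x):
            # Decide on-device whether this step should be reported:
            # every `log_rate` iterations, plus the final iteration.
            should_log = (x % log_rate == 0) | (x == num_iters - 1)
            # Only the taken branch runs, so the host callback fires
            # on logging steps only.
            jax.lax.cond(
                should_log,
                lambda i: jax.debug.callback(_host_report, i),
                lambda i: None,
                x,
            )
            return body_func(carry, x)

        return wrapped

    return decorator
```

Keeping the `log_rate` check inside `jax.lax.cond` means the compiled loop only calls back to the host on logging steps; calling back on every iteration and filtering in Python would be simpler but adds a host round trip per step. Host callbacks may run asynchronously, so `jax.effects_barrier()` can be used to wait for all pending reports.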

In future, it might be nice to create a general-purpose verbose scan that can be used in place of jax.lax.scan: it would take the same arguments as jax.lax.scan but also print a progress bar.
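Such a verbose scan could look roughly like the hypothetical `verbose_scan` below. This is a sketch under the assumption that only the `f`, `init`, `xs` and `length` arguments of jax.lax.scan need forwarding; `log_rate` and `report` are illustrative additions, not part of any existing API. Pairing each element of `xs` with `jnp.arange(length)` internally means callers no longer need to scan over the iteration numbers themselves.

```python
from typing import Callable, Optional

import jax
import jax.numpy as jnp


def verbose_scan(
    f: Callable,
    init,
    xs=None,
    length: Optional[int] = None,
    log_rate: int = 10,
    report: Optional[Callable[[int, int], None]] = None,
):
    """Hypothetical drop-in for jax.lax.scan that reports progress.

    ASSUMPTION: only f, init, xs and length are forwarded to jax.lax.scan;
    log_rate and report are illustrative extras.
    """
    if length is None:
        # Infer the number of steps from the leading axis of xs, as scan does.
        leaves = jax.tree_util.tree_leaves(xs)
        if not leaves:
            raise ValueError("verbose_scan needs either xs or length")
        length = leaves[0].shape[0]

    if report is None:
        # Default host-side reporter: a text line instead of a real progress bar.
        def report(i: int, n: int) -> None:
            print(f"Iteration {i + 1}/{n}")

    def _host_report(iter_num) -> None:
        # Runs on the host with a concrete iteration index.
        report(int(iter_num), length)

    def body(carry, inputs):
        # Each step receives (iteration index, original element of xs).
        i, x = inputs
        should_log = (i % log_rate == 0) | (i == length - 1)
        jax.lax.cond(
            should_log,
            lambda j: jax.debug.callback(_host_report, j),
            lambda j: None,
            i,
        )
        return f(carry, x)

    # xs may be None; (indices, None) is still a valid pytree for scan.
    return jax.lax.scan(body, init, (jnp.arange(length), xs), length=length)
```

For example, `verbose_scan(step, jnp.array(0.0), jnp.ones(25), log_rate=10)` would report iterations 1, 11, 21 and 25 of 25 while otherwise returning the same `(carry, ys)` pair as `jax.lax.scan(step, jnp.array(0.0), jnp.ones(25))`.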