Progress Bar
gpjax.progress_bar
__all__ = ['progress_bar']
progress_bar(num_iters: int, log_rate: int) -> Callable
Progress bar decorator for the body function of a jax.lax.scan.
Example:
>>> import jax
>>> import jax.numpy as jnp
>>> from gpjax.progress_bar import progress_bar
>>>
>>> carry = jnp.array(0.0)
>>> iteration_numbers = jnp.arange(100)
>>>
>>> @progress_bar(num_iters=iteration_numbers.shape[0], log_rate=10)
... def body_func(carry, x):
...     return carry, x
...
>>> carry, _ = jax.lax.scan(body_func, carry, iteration_numbers)
Adapted from this excellent blog post.
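To illustrate the general technique, here is a minimal sketch of a progress-bar decorator for a scan body. It is hypothetical and not GPJax's implementation: the names simple_progress_bar and print_progress are made up for this example, the scanned input x is assumed to be the integer iteration number, and a plain text line is printed through jax.debug.callback instead of drawing a tqdm bar.

```python
import jax
import jax.numpy as jnp


def print_progress(iter_num: int, num_iters: int) -> None:
    # Hypothetical default host-side reporter: prints a 1-based step counter.
    print(f"step {iter_num + 1}/{num_iters}")


def simple_progress_bar(num_iters: int, log_rate: int, report=print_progress):
    """Hypothetical decorator that reports progress from a jax.lax.scan body.

    Assumes the scanned input x is the integer iteration number.
    """

    def _host_report(iter_num) -> None:
        # Runs on the host; iter_num arrives as a concrete NumPy scalar.
        report(int(iter_num), num_iters)

    def _log(iter_num) -> None:
        jax.debug.callback(_host_report, iter_num)

    def _skip(iter_num) -> None:
        pass

    def decorator(body_fun):
        def wrapped(carry, x):
            # Report every log_rate steps and on the final step.
            is_due = (x % log_rate == 0) | (x == num_iters - 1)
            jax.lax.cond(is_due, _log, _skip, x)
            return body_fun(carry, x)

        return wrapped

    return decorator


iteration_numbers = jnp.arange(100)


@simple_progress_bar(num_iters=iteration_numbers.shape[0], log_rate=10)
def body_func(carry, x):
    return carry + x, x


total, _ = jax.lax.scan(body_func, jnp.zeros((), dtype=jnp.int32), iteration_numbers)
```

Wrapping the callback in jax.lax.cond means the compiled loop only calls out to the host on steps that are due for a report, which is the role log_rate plays in this sketch.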
A possible future improvement would be a general-purpose verbose scan that replaces jax.lax.scan, takes the same arguments as jax.lax.scan, and also prints a progress bar.
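One way such a verbose scan might look is sketched below. The name verbose_scan and its log_rate and report parameters are hypothetical and not part of GPJax. The sketch accepts f, init, xs and length as jax.lax.scan does (other scan options such as reverse and unroll are omitted), pairs each element of xs with its iteration index, and reports progress from the host via jax.debug.callback inside jax.lax.cond.

```python
import jax
import jax.numpy as jnp


def _print_progress(step: int, total: int) -> None:
    # Hypothetical default host-side reporter: prints a 1-based step counter.
    print(f"step {step + 1}/{total}")


def verbose_scan(f, init, xs, length=None, log_rate=10, report=_print_progress):
    """Hypothetical wrapper taking jax.lax.scan's (f, init, xs, length) arguments."""
    if length is None:
        # Infer the number of steps from the leading axis of xs, as scan does.
        length = jax.tree_util.tree_leaves(xs)[0].shape[0]

    def _host_report(step) -> None:
        # Runs on the host; step arrives as a concrete NumPy scalar.
        report(int(step), length)

    def _log(step) -> None:
        jax.debug.callback(_host_report, step)

    def _skip(step) -> None:
        pass

    def body(carry, inputs):
        step, x = inputs
        # Report every log_rate steps and on the final step.
        is_due = (step % log_rate == 0) | (step == length - 1)
        jax.lax.cond(is_due, _log, _skip, step)
        # f keeps its original (carry, x) -> (carry, y) signature.
        return f(carry, x)

    # Scan over (iteration index, xs) pairs instead of xs alone.
    return jax.lax.scan(body, init, (jnp.arange(length), xs), length=length)


def step_fn(carry, x):
    return carry + x, carry


total, history = verbose_scan(step_fn, jnp.float32(0.0), jnp.ones(30), log_rate=10)
```

Generating the iteration index inside the wrapper means f can scan over arbitrary xs rather than over a sequence of iteration numbers.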