Scan
vscan
Scan with verbose output.
This is based on code from this excellent blog post.
Example
>>> import jax.numpy as jnp
>>> def f(carry, x):
...     return carry + x, carry + x
>>> init = 0
>>> xs = jnp.arange(10)
>>> vscan(f, init, xs)
(Array(45, dtype=int32), Array([ 0, 1, 3, 6, 10, 15, 21, 28, 36, 45], dtype=int32))
Parameters:
- f (Callable[[Carry, X], Tuple[Carry, Y]]) – A function that takes in a carry and an input and returns a tuple of a new carry and an output.
- init (Carry) – The initial carry.
- xs (X) – The inputs.
- length (Optional[int], default: None) – The length of the inputs. If None, then the length of the inputs is inferred.
- reverse (bool, default: False) – Whether to scan in reverse.
- unroll (int, default: 1) – The number of iterations to unroll.
- log_rate (int, default: 10) – The rate at which to log the progress bar.
- log_value (bool, default: True) – Whether to log the value of the objective function.
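To make the interplay of these parameters concrete, here is a minimal sketch of how a logging scan could be layered on top of jax.lax.scan. It is not the implementation of vscan: the name logged_scan is hypothetical, the sketch assumes that length, reverse and unroll are simply forwarded to jax.lax.scan, it assumes that log_rate means "log every log_rate-th step", and it prints a plain step counter instead of a progress bar. log_value is left out.

```python
import jax
import jax.numpy as jnp


def logged_scan(f, init, xs, length=None, reverse=False, unroll=1, log_rate=10):
    """Hypothetical stand-in for vscan: jax.lax.scan plus periodic printing."""
    # Infer the number of steps from the leading axis of xs when not given.
    if length is None:
        n = jax.tree_util.tree_leaves(xs)[0].shape[0]
    else:
        n = length
    message = f"step {{i}} of {n}"  # n is static, i is filled in at run time

    def step(carry, indexed_x):
        i, x = indexed_x
        new_carry, y = f(carry, x)
        # Print only on every log_rate-th step. lax.cond keeps the check
        # inside the traced computation, so it also works under jit.
        jax.lax.cond(
            i % log_rate == 0,
            lambda: jax.debug.print(message, i=i),
            lambda: None,
        )
        return new_carry, y

    # Pair each input with its step index so the step function can see it.
    return jax.lax.scan(
        step, init, (jnp.arange(n), xs),
        length=length, reverse=reverse, unroll=unroll,
    )


def f(carry, x):
    return carry + x, carry + x


final, ys = logged_scan(f, 0, jnp.arange(10), log_rate=5)
```

With reverse=True the same sketch visits the inputs from last to first, so the step index it prints counts down rather than up.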
Returns:
Tuple[Carry, list[Y]]: A tuple of the final carry and the outputs.
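The carry/output contract can be spelled out as a plain Python loop. The sketch below is a reference for the semantics only, written under the assumption that vscan follows the same contract as jax.lax.scan. The name reference_scan is hypothetical, and it returns a Python list where a JAX scan would return stacked arrays.

```python
def reference_scan(f, init, xs, reverse=False):
    """Plain-Python loop with the same carry/output contract as a scan."""
    indices = range(len(xs))
    if reverse:
        indices = reversed(indices)
    carry = init
    ys = [None] * len(xs)
    for i in indices:
        carry, y = f(carry, xs[i])
        ys[i] = y  # outputs stay aligned with their inputs, even in reverse
    return carry, ys


def f(carry, x):
    return carry + x, carry + x


final, outputs = reference_scan(f, 0, list(range(10)))
print(final)    # 45
print(outputs)  # [0, 1, 3, 6, 10, 15, 21, 28, 36, 45]
```

The first element of the result is the carry after the last step. The second holds one output per input, in input order.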