Scan
vscan
vscan(
    f: Callable[[Carry, X], Tuple[Carry, Y]],
    init: Carry,
    xs: X,
    length: Optional[int] = None,
    reverse: Optional[bool] = False,
    unroll: Optional[int] = 1,
    log_rate: Optional[int] = 10,
    log_value: Optional[bool] = True,
) -> Tuple[Carry, Shaped[Array, ...]]
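Scan with verbose output: a variant of `jax.lax.scan` (taking the same `f`, `init`, `xs`, `length`, `reverse`, and `unroll` arguments) that logs progress while the loop runs.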
This is based on code from this excellent blog post.
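One way to build a progress-reporting scan is sketched below. This is a minimal sketch under stated assumptions, not necessarily how `vscan` itself is implemented: the helper name `scan_with_progress` is hypothetical, `xs` is assumed to be a single array scanned along its leading axis, and `log_rate` is interpreted here as "log every `log_rate` iterations". It threads an iteration counter through `jax.lax.scan` and uses `jax.debug.callback` inside `jax.lax.cond` so the host is only called on the logging steps.

```python
import jax
import jax.numpy as jnp


def scan_with_progress(f, init, xs, log_rate=10):
    # Hypothetical helper for illustration; not vscan's actual implementation.
    # Assumes `xs` is a single array scanned along its leading axis.
    num_steps = xs.shape[0]

    def report(i):
        # Runs on the host with a concrete iteration index.
        print(f"step {int(i)} / {num_steps}")

    def wrapped(carry, inputs):
        i, x = inputs
        # Call back to the host only on every `log_rate`-th iteration.
        jax.lax.cond(
            i % log_rate == 0,
            lambda j: jax.debug.callback(report, j),
            lambda j: None,
            i,
        )
        return f(carry, x)

    # Scan over (iteration index, input) pairs so each step knows its index.
    steps = jnp.arange(num_steps)
    return jax.lax.scan(wrapped, init, (steps, xs))


carry, ys = scan_with_progress(
    lambda c, x: (c + x, c + x),
    jnp.int32(0),
    jnp.arange(10, dtype=jnp.int32),
    log_rate=5,
)
# Prints a line for step 0 and for step 5 (order not guaranteed).
jax.effects_barrier()  # wait for pending host callbacks to finish
```

Debug callbacks may run asynchronously, so printed lines can appear after the scan returns; `jax.effects_barrier()` blocks until they have completed. Gating the callback with `jax.lax.cond` keeps host round-trips to one every `log_rate` steps instead of one per step.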
Example
```python
>>> import jax.numpy as jnp
>>>
>>> def f(carry, x):
...     return carry + x, carry + x
>>>
>>> init = 0
>>> xs = jnp.arange(10)
>>> vscan(f, init, xs)
(Array(45, dtype=int32), Array([ 0,  1,  3,  6, 10, 15, 21, 28, 36, 45], dtype=int32))
```
Parameters:
- f (Callable[[Carry, X], Tuple[Carry, Y]]): A function that takes a carry and an input and returns a tuple of a new carry and an output.
- init (Carry): The initial carry.
- xs (X): The inputs to scan over; each step receives one slice along the leading axis.
- length (Optional[int], default: None): The number of iterations. If None, it is inferred from the leading axis of `xs`.
- reverse (bool, default: False): Whether to scan in reverse.
- unroll (int, default: 1): How many scan iterations to unroll within a single iteration of the underlying loop.
- log_rate (int, default: 10): The rate at which the progress bar is updated.
- log_value (bool, default: True): Whether to log the value of the objective function.
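The `length`, `reverse`, and `unroll` parameters mirror those of `jax.lax.scan`. Assuming `vscan` forwards them unchanged (an assumption based on the matching signature), their effect can be sketched with `jax.lax.scan` directly. Note that with `reverse=True` the inputs are consumed from the end, but the outputs stay aligned with the indices of `xs`; `unroll` changes only how the loop is compiled, not the result.

```python
import jax
import jax.numpy as jnp


def running_sum(carry, x):
    # Carry the running total and also emit it as the per-step output.
    total = carry + x
    return total, total


xs = jnp.arange(1, 5, dtype=jnp.int32)  # [1, 2, 3, 4]
init = jnp.int32(0)

# Forward scan: processes xs[0], xs[1], ... in order.
fwd_carry, fwd_ys = jax.lax.scan(running_sum, init, xs)

# reverse=True: processes xs[-1] first; ys is still indexed like xs.
rev_carry, rev_ys = jax.lax.scan(running_sum, init, xs, reverse=True)

# unroll=2: two iterations per loop body; same numerical result.
_, unrolled_ys = jax.lax.scan(running_sum, init, xs, unroll=2)


def count(carry, _):
    # Ignores the (absent) input and counts iterations.
    return carry + 1, carry


# With xs=None, `length` must be given and sets the number of iterations.
n_carry, n_ys = jax.lax.scan(count, init, None, length=3)
```

Here `fwd_ys` is `[1, 3, 6, 10]`, while `rev_ys` is `[10, 9, 7, 4]`: the last input is added first, and its result lands at the last position.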
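Returns:
Tuple[Carry, Shaped[Array, ...]]: A tuple of the final carry and the per-step outputs, stacked along a new leading axis (in the example above, a single array of length 10).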
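Assuming `vscan` returns outputs the way `jax.lax.scan` does (which the single stacked `Array` in the example suggests), per-step outputs are not collected into a Python list: each leaf of the output pytree gains a new leading axis of size equal to the number of steps. A sketch using `jax.lax.scan` directly, with a dict as the per-step output:

```python
import jax
import jax.numpy as jnp


def step(carry, x):
    # The per-step output is a dict (a pytree), not a single array.
    new_carry = carry + x
    return new_carry, {"total": new_carry, "squared": x * x}


# Four steps, each consuming one row of shape (2,).
xs = jnp.arange(8, dtype=jnp.float32).reshape(4, 2)
init = jnp.zeros(2, dtype=jnp.float32)

final, ys = jax.lax.scan(step, init, xs)
# Each leaf of `ys` has shape (4, 2): a leading axis of length 4 is added
# to the per-step shape (2,).
```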