Scan
vscan
Scan with verbose output.
This is based on code from this excellent blog post.
Example
```python
>>> import jax.numpy as jnp
>>> def f(carry, x):
...     return carry + x, carry + x
>>> init = 0
>>> xs = jnp.arange(10)
>>> vscan(f, init, xs)
(Array(45, dtype=int32), Array([ 0,  1,  3,  6, 10, 15, 21, 28, 36, 45], dtype=int32))
```
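Without the progress bar, the example is an ordinary `jax.lax.scan` whose per-step outputs are running sums. Assuming vscan forwards its arguments to `jax.lax.scan` (its signature suggests this, but that is an assumption), the plain call below should produce the same pair, which a check against `jnp.cumsum` confirms.

```python
import jax.numpy as jnp
from jax import lax


def f(carry, x):
    # Both the new carry and the per-step output are the running sum.
    return carry + x, carry + x


xs = jnp.arange(10)
final, ys = lax.scan(f, 0, xs)

# The outputs are stacked into one array (the cumulative sum of xs),
# and the final carry is the total.
assert int(final) == int(xs.sum())
assert (ys == jnp.cumsum(xs)).all()
```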
Parameters:
- f (Callable[[Carry, X], Tuple[Carry, Y]]): A function that takes a carry and an input and returns a tuple of a new carry and an output.
- init (Carry): The initial carry.
- xs (X): The inputs, scanned along their leading axis.
- length (Optional[int], default: None): The number of iterations. If None, it is inferred from the leading axis of `xs`.
- reverse (bool, default: False): Whether to scan in reverse.
- unroll (int, default: 1): The number of loop iterations to unroll into a single step of the compiled loop.
- log_rate (int, default: 10): How often, in iterations, the progress bar is updated.
- log_value (bool, default: True): Whether to log the value of the objective function.
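A minimal sketch of how a scan can report progress every `log_rate` iterations, assuming the approach is a `jax.lax.scan` whose body triggers a host callback through `jax.debug.callback`. The name `verbose_scan_sketch`, the `log` hook, and the message format are hypothetical; this is not vscan's actual implementation and it does not model `log_value`.

```python
import jax
import jax.numpy as jnp
from jax import lax


def verbose_scan_sketch(f, init, xs, log_rate=10, log=print):
    """Hypothetical sketch: lax.scan that reports progress every `log_rate` steps."""
    n = jax.tree_util.tree_leaves(xs)[0].shape[0]  # number of iterations

    def report(i):
        # Runs on the host; `i` arrives as a concrete array.
        log(f"step {int(i)}/{n}")

    def body(carry, x):
        i, inner = carry
        # Only fire the host callback on every `log_rate`-th iteration.
        lax.cond(
            i % log_rate == 0,
            lambda j: jax.debug.callback(report, j),
            lambda j: None,
            i,
        )
        inner, y = f(inner, x)
        return (i + 1, inner), y

    # Thread an iteration counter alongside the user's carry.
    (_, final), ys = lax.scan(body, (jnp.int32(0), init), xs)
    jax.effects_barrier()  # wait until all pending host callbacks have run
    return final, ys
```

Guarding the callback with `lax.cond` keeps the host round-trip off most iterations; debug callbacks are not ordered by default, so a caller that collects messages should not rely on their arrival order.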
Returns
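Tuple[Carry, Y]: A tuple of the final carry and the outputs, stacked along a new leading axis of size `length` (an array, as in the example, rather than a Python list).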
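To make the return value and the `length`/`reverse` parameters concrete, here is a pure-Python model of `jax.lax.scan` semantics, assuming vscan follows them. `scan_reference` is a hypothetical helper for illustration only: it handles a single array output (not pytrees) and ignores `unroll` and logging, which do not change the result.

```python
import numpy as np


def scan_reference(f, init, xs, length=None, reverse=False):
    """Hypothetical pure-Python model of lax.scan (no JIT, no progress bar)."""
    if xs is None:
        xs = [None] * length  # run `length` steps with no per-step input
    else:
        xs = list(xs)
    order = reversed(range(len(xs))) if reverse else range(len(xs))
    carry = init
    ys = [None] * len(xs)
    for i in order:
        # With reverse=True, output i still lines up with input i.
        carry, ys[i] = f(carry, xs[i])
    # Outputs are stacked along a new leading axis, not returned as a list.
    return carry, np.stack(ys)
```

With `reverse=True` the running sum becomes a suffix sum, and `xs=None` with an explicit `length` runs a fixed number of steps.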