
Scan

vscan

vscan(f, init, xs, length=None, reverse=False, unroll=1, log_rate=10, log_value=True)

Scan with verbose output: a scan that displays a progress bar while it runs.

This is based on code from this excellent blog post.
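The blog post's code is not reproduced on this page. The sketch below is a rough illustration of the general idea, not `vscan`'s actual implementation. It wraps `jax.lax.scan` and uses `jax.debug.callback` to send the step index to the host, which prints every `log_rate`-th step. The name `verbose_scan_sketch`, its message format, and the choice to make `log_value` print the carry are all hypothetical.

```python
import jax
import jax.numpy as jnp


def verbose_scan_sketch(f, init, xs, log_rate=10, log_value=True):
    """Hypothetical minimal verbose scan built on jax.lax.scan.

    Assumes `xs` is a single array, so the number of steps is xs.shape[0].
    Prints every `log_rate`-th step (and the final step) from the host.
    """
    n = xs.shape[0]

    def _log(i, carry):
        # Host-side function: `i` and `carry` arrive as concrete NumPy values.
        i = int(i)
        if i % log_rate == 0 or i == n - 1:
            msg = f"step {i + 1}/{n}"
            if log_value:
                # Assumption: "value" is taken to mean the current carry.
                msg += f"  carry={carry}"
            print(msg)

    def body(carry, inp):
        i, x = inp
        new_carry, y = f(carry, x)
        # Hand the step index and new carry to the host while the scan runs.
        jax.debug.callback(_log, i, new_carry)
        return new_carry, y

    # Scan over (step index, input) pairs so the body knows its position.
    return jax.lax.scan(body, init, (jnp.arange(n), xs))
```

This sketch invokes the callback on every step and filters on the host to keep it short. A leaner version could guard the callback with `jax.lax.cond` so the host is contacted only on logging steps. `jax.debug.callback` is unordered by default, so it also accepts `ordered=True` if message order matters.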

Example

```python
>>> import jax.numpy as jnp
>>> def f(carry, x):
...     return carry + x, carry + x
>>> init = 0
>>> xs = jnp.arange(10)
>>> vscan(f, init, xs)
(Array(45, dtype=int32), Array([ 0,  1,  3,  6, 10, 15, 21, 28, 36, 45], dtype=int32))
```
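To check the expected output by hand, here is a pure-Python model of the forward scan in the example, with no JAX required. `scan_reference` is a hypothetical helper and is not part of this library. It threads the carry through `f` and collects each output in a list, whereas the JAX call above returns the outputs stacked into a single array.

```python
def scan_reference(f, init, xs):
    """Hypothetical pure-Python model of what a forward scan computes."""
    carry = init
    ys = []
    for x in xs:
        # Thread the carry through f and collect each per-step output.
        carry, y = f(carry, x)
        ys.append(y)
    return carry, ys


def f(carry, x):
    # Same step function as the example: running sum as both carry and output.
    return carry + x, carry + x


final, outputs = scan_reference(f, 0, range(10))
```

Here `final` is `45` and `outputs` is `[0, 1, 3, 6, 10, 15, 21, 28, 36, 45]`, which matches the values in the example.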

Parameters:

  • f (Callable[[Carry, X], Tuple[Carry, Y]]) –

    A function that takes in a carry and an input and returns a tuple of a new carry and an output.

  • init (Carry) –

    The initial carry.

  • xs (X) –

    The inputs to scan over. At each step, `f` receives one slice of `xs` taken along its leading axis.

  • length (Optional[int], default: None ) –

    The number of scan steps. If None, it is inferred from the leading axis of `xs`.

  • reverse (bool, default: False ) –

    Whether to scan in reverse.

  • unroll (int, default: 1 ) –

    The number of iterations to unroll.

  • log_rate (int, default: 10 ) –

    How often the progress bar is updated.

  • log_value (bool, default: True ) –

    Whether to log the value of the objective function.
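The `length`, `reverse`, and `unroll` parameters have the same names as the corresponding parameters of `jax.lax.scan`. This page does not confirm that `vscan` forwards them unchanged, but if it does, their effect can be seen by calling `jax.lax.scan` directly with the example's step function:

```python
import jax
import jax.numpy as jnp


def cumsum_step(carry, x):
    # Same step function as the example: running sum as both carry and output.
    return carry + x, carry + x


xs = jnp.arange(10)
init = jnp.array(0)

# reverse=True consumes xs from the last element to the first,
# but each output is still stored at its original index.
carry_rev, ys_rev = jax.lax.scan(cumsum_step, init, xs, reverse=True)

# unroll changes only how the loop is compiled, not the result.
carry_unr, ys_unr = jax.lax.scan(cumsum_step, init, xs, unroll=2)


def count_step(carry, _):
    # With xs=None, the step function receives None as its input.
    return carry + 1, carry


# When xs is None, length must be given explicitly.
carry_len, ys_len = jax.lax.scan(count_step, init, None, length=5)
```

In this example, `ys_rev` is `[45, 45, 44, 42, 39, 35, 30, 24, 17, 9]`, `ys_unr` is identical to the unrolled-free result, and the counting scan returns carry `5` with outputs `[0, 1, 2, 3, 4]`.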

Returns

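Tuple[Carry, list[Y]]: A tuple of the final carry and the per-step outputs. As the example shows, the outputs are returned stacked along a leading axis (as an `Array`), not as a Python list.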
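If `vscan` returns the same structure as `jax.lax.scan`, which the example output suggests, then each leaf of `Y` is stacked along a new leading axis. A step function that emits a tuple therefore produces a tuple of arrays. The following check calls `jax.lax.scan` directly; the step function is illustrative only:

```python
import jax
import jax.numpy as jnp


def step(carry, x):
    new_carry = carry + x
    # The per-step output Y is a pytree: a (running_sum, doubled_input) pair.
    return new_carry, (new_carry, 2 * x)


final, (sums, doubled) = jax.lax.scan(step, jnp.array(0), jnp.arange(4))
```

After this runs, `final` is `6`, `sums` is `[0, 1, 3, 6]`, and `doubled` is `[0, 2, 4, 6]`. Each of `sums` and `doubled` is an array of shape `(4,)`, with one entry per step.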