
Scan

gpjax.scan

Carry = TypeVar('Carry') module-attribute
X = TypeVar('X') module-attribute
Y = TypeVar('Y') module-attribute
__all__ = ['vscan'] module-attribute
vscan(f: Callable[[Carry, X], Tuple[Carry, Y]], init: Carry, xs: X, length: Optional[int] = None, reverse: Optional[bool] = False, unroll: Optional[int] = 1, log_rate: Optional[int] = 10, log_value: Optional[bool] = True) -> Tuple[Carry, Shaped[Array, ...]]

Scan with verbose output.

This is based on code from this excellent blog post.
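The pure-Python sketch below shows what a scan with periodic progress reporting computes. It assumes vscan follows the documented semantics of `jax.lax.scan`: the carry is threaded through each step, and outputs stay in input order even when `reverse=True`. `scan_reference` is a hypothetical name. It is not JIT-compiled, and it prints to stdout instead of driving a progress bar.

```python
from typing import Callable, List, Optional, Sequence, Tuple, TypeVar

Carry = TypeVar("Carry")
X = TypeVar("X")
Y = TypeVar("Y")


def scan_reference(
    f: Callable[[Carry, X], Tuple[Carry, Y]],
    init: Carry,
    xs: Sequence[X],
    length: Optional[int] = None,
    reverse: bool = False,
    log_rate: int = 10,
    log_value: bool = True,
) -> Tuple[Carry, List[Y]]:
    """Hypothetical eager stand-in for vscan: no JIT, prints instead of a bar."""
    n = len(xs)
    if length is not None and length != n:
        raise ValueError(f"length={length} does not match len(xs)={n}")
    # Visit indices back to front when reverse=True.
    order = range(n - 1, -1, -1) if reverse else range(n)
    carry = init
    ys: list = [None] * n
    for step, i in enumerate(order, start=1):
        carry, y = f(carry, xs[i])
        # Store each output at the index of the input that produced it,
        # so a reversed scan still returns outputs in input order.
        ys[i] = y
        # Report every `log_rate` steps and on the final step.
        if step % log_rate == 0 or step == n:
            message = f"step {step}/{n}"
            if log_value:
                message += f", value={y}"
            print(message)
    return carry, ys


carry, ys = scan_reference(lambda c, x: (c + x, c + x), 0, list(range(10)))
print(carry, ys)  # 45 [0, 1, 3, 6, 10, 15, 21, 28, 36, 45]
```

The sketch writes to `ys[i]` instead of appending. This mirrors the `reverse` contract of `jax.lax.scan`, which is equivalent to reversing the leading axes of both `xs` and `ys`.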

Example:

    >>> import jax.numpy as jnp
    >>> from gpjax.scan import vscan
    >>>
    >>> def f(carry, x):
    ...     return carry + x, carry + x
    >>> init = 0
    >>> xs = jnp.arange(10)
    >>> vscan(f, init, xs)
    (45, DeviceArray([ 0,  1,  3,  6, 10, 15, 21, 28, 36, 45], dtype=int32))
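The expected output above is the running sum of `0, 1, ..., 9`. The carry accumulates each input, and every step emits the updated carry. You can check this with the standard library alone, without JAX:

```python
from itertools import accumulate

# Each step emits carry + x, i.e. the running sum of the inputs so far.
running = list(accumulate(range(10)))
final_carry = running[-1]
print(final_carry, running)  # 45 [0, 1, 3, 6, 10, 15, 21, 28, 36, 45]
```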

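Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| f | Callable[[Carry, X], Tuple[Carry, Y]] | A function that takes a carry and an input and returns a tuple of the new carry and an output. | required |
| init | Carry | The initial carry. | required |
| xs | X | The inputs to scan over, along their leading axis. | required |
| length | Optional[int] | The number of iterations. If None, it is inferred from the leading axis of xs. | None |
| reverse | Optional[bool] | Whether to scan in reverse. | False |
| unroll | Optional[int] | The number of scan iterations to unroll within a single iteration of the underlying loop. | 1 |
| log_rate | Optional[int] | The number of iterations between progress-bar updates. | 10 |
| log_value | Optional[bool] | Whether to show the value of the objective function in the progress bar. | True |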
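The sketch below shows how `log_rate` might translate into progress-bar refreshes. The source does not confirm the exact behavior, so this is an assumption: the bar advances by `log_rate` once every `log_rate` iterations, and any leftover `length % log_rate` iterations are added on the final step. `update_schedule` is a hypothetical helper and is not part of gpjax.

```python
from typing import List, Tuple


def update_schedule(length: int, log_rate: int) -> List[Tuple[int, int]]:
    """Hypothetical: (iteration index, progress increment) pairs for a scan.

    Assumes a refresh of size `log_rate` after every `log_rate` iterations,
    plus one final refresh covering any remaining iterations.
    """
    if length < 0 or log_rate < 1:
        raise ValueError("need length >= 0 and log_rate >= 1")
    schedule = [(i, log_rate) for i in range(length) if (i + 1) % log_rate == 0]
    remainder = length % log_rate
    if remainder:
        schedule.append((length - 1, remainder))
    return schedule


print(update_schedule(25, 10))  # [(9, 10), (19, 10), (24, 5)]
```

Under this assumption, the increments always sum to `length`. A larger `log_rate` means fewer refreshes.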
Returns:

Tuple[Carry, Shaped[Array, ...]]: A tuple of the final carry and the outputs of every iteration, stacked along a new leading axis.
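The outputs are stacked across iterations. If `f` emits a tuple at each step, each element gets its own stacked sequence, which is how `jax.lax.scan` treats pytree outputs. This assumes vscan returns those outputs unchanged. The pure-Python sketch below performs that transposition; `stack_outputs` is a hypothetical helper, and lists stand in for arrays.

```python
from typing import List, Sequence, Tuple, TypeVar

A = TypeVar("A")
B = TypeVar("B")


def stack_outputs(per_step: Sequence[Tuple[A, B]]) -> Tuple[List[A], List[B]]:
    """Hypothetical: turn per-step (a, b) outputs into (all a's, all b's)."""
    firsts = [a for a, _ in per_step]
    seconds = [b for _, b in per_step]
    return firsts, seconds


# Two iterations, each emitting (loss, step_size).
losses, step_sizes = stack_outputs([(0.9, 0.1), (0.5, 0.1)])
print(losses, step_sizes)  # [0.9, 0.5] [0.1, 0.1]
```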