Scan
gpjax.scan
Module attributes:
Carry = TypeVar('Carry')
X = TypeVar('X')
Y = TypeVar('Y')
__all__ = ['vscan']
vscan(f: Callable[[Carry, X], Tuple[Carry, Y]], init: Carry, xs: X, length: Optional[int] = None, reverse: Optional[bool] = False, unroll: Optional[int] = 1, log_rate: Optional[int] = 10, log_value: Optional[bool] = True) -> Tuple[Carry, Shaped[Array, ...]]
Scan with verbose output: a scan in the style of jax.lax.scan that reports its progress on a progress bar while it runs.
This is based on code from this excellent blog post.
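The general idea behind a verbose scan can be sketched as follows. The hypothetical `verbose_scan` threads an iteration index through `jax.lax.scan` alongside the inputs and, every `log_rate` iterations, uses `jax.debug.callback` to send the current output to the host, where it is printed. This is a minimal sketch that assumes a plain `print` in place of a progress bar; it is not GPJax's implementation.

```python
import jax
import jax.numpy as jnp


def verbose_scan(f, init, xs, log_rate=10):
    """Hypothetical sketch: lax.scan that prints f's output every `log_rate` steps."""
    n = jax.tree_util.tree_leaves(xs)[0].shape[0]  # number of iterations

    def report(i, y):
        # Executed on the host with concrete (NumPy) values.
        print(f"iteration {int(i)}: value {y}")

    def log(i, y):
        # Stage a host callback from inside the traced computation.
        jax.debug.callback(report, i, y)

    def wrapped(carry, inputs):
        i, x = inputs  # iteration index travels alongside the original input
        carry, y = f(carry, x)
        # Only call back to the host on every `log_rate`-th iteration.
        jax.lax.cond(i % log_rate == 0, log, lambda i, y: None, i, y)
        return carry, y

    return jax.lax.scan(wrapped, init, (jnp.arange(n), xs))


# Cumulative sum that reports on iterations 0 and 5.
carry, ys = verbose_scan(lambda c, x: (c + x, c + x), jnp.asarray(0), jnp.arange(10), log_rate=5)
```

Gating the callback with `jax.lax.cond` keeps host round trips to one every `log_rate` steps instead of one per step, which is presumably the reason a rate parameter exists at all.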
Example:
>>> import jax.numpy as jnp
>>> from gpjax.scan import vscan
>>>
>>> def f(carry, x):
...     return carry + x, carry + x
>>> init = 0
>>> xs = jnp.arange(10)
>>> vscan(f, init, xs)
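Assuming vscan follows jax.lax.scan semantics, the example above computes a running sum. The plain-Python loop below traces the same computation step by step (without a progress bar), so the expected carry and outputs are easy to check.

```python
# Plain-Python trace of the vscan example, assuming jax.lax.scan semantics.
def f(carry, x):
    return carry + x, carry + x

carry = 0
ys = []
for x in range(10):
    carry, y = f(carry, x)  # new carry and this step's output
    ys.append(y)

print(carry)  # 45
print(ys)     # [0, 1, 3, 6, 10, 15, 21, 28, 36, 45]
```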
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| f | Callable[[Carry, X], Tuple[Carry, Y]] | A function that takes a carry and an input and returns a tuple of the new carry and an output. | required |
| init | Carry | The initial carry. | required |
| xs | X | The inputs, scanned over along their leading axis. | required |
| length | Optional[int] | The number of iterations. If None, it is inferred from the leading axis of xs. | None |
| reverse | bool | Whether to scan in reverse. | False |
| unroll | int | The number of loop iterations to unroll. | 1 |
| log_rate | int | How often, in iterations, the progress bar is updated. | 10 |
| log_value | bool | Whether to display the value of the objective function on the progress bar. | True |
Returns:
Tuple[Carry, Shaped[Array, ...]]: A tuple of the final carry and the outputs of f, stacked along a new leading axis.
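For reference, the behaviour documented above (length inference, reverse, log_rate, log_value and stacked outputs) can be sketched in plain Python. `scan_reference` is a hypothetical helper, not part of GPJax; it assumes vscan follows jax.lax.scan conventions and replaces the progress bar with `print`. `unroll` is omitted because in jax.lax.scan it only affects how the loop is compiled, not the result.

```python
import numpy as np


def scan_reference(f, init, xs, length=None, reverse=False, log_rate=10, log_value=True):
    """Hypothetical pure-Python sketch of scan semantics with periodic logging.

    Assumes xs is a sequence or NumPy array indexed along its leading axis
    (or None when length is given) and that each output y is array-like.
    """
    if xs is None:
        if length is None:
            raise ValueError("length is required when xs is None")
        xs = [None] * length  # scan for `length` steps without inputs
    elif length is None:
        length = len(xs)  # infer the number of steps from the leading axis
    elif length != len(xs):
        raise ValueError("length does not match the leading axis of xs")

    # Visit indices back to front when reverse=True.
    order = range(length - 1, -1, -1) if reverse else range(length)
    carry, ys = init, [None] * length
    for step, i in enumerate(order):
        carry, ys[i] = f(carry, xs[i])
        if step % log_rate == 0:  # stand-in for a progress-bar update
            msg = f"step {step + 1}/{length}"
            if log_value:
                msg += f", value {ys[i]}"
            print(msg)

    # Outputs keep the order of xs and are stacked along a new leading axis.
    return carry, np.stack(ys)


carry, ys = scan_reference(lambda c, x: (c + x, c + x), 0, np.arange(4), reverse=True)
print(carry, ys)  # 6 [6 6 5 3]
```

With reverse=True the loop runs from the last input to the first, but each output is stored at the position of its input, matching the jax.lax.scan description of reverse as reversing the leading axes of both xs and the outputs.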