Decision Maker
gpjax.decision_making.decision_maker
AbstractDecisionMaker
dataclass
Bases: ABC
AbstractDecisionMaker is an abstract base class which handles the core decision making
loop, in which we sequentially decide on points at which to query our function of interest.
The decision making loop is split into two key steps, `ask` and `tell`. The `ask` step is
typically used to decide which point to query next. The `tell` step is typically used to
update models and datasets with newly queried points. These steps can be combined in a
'run' loop which alternates between asking which point to query next and telling the
decision maker about the newly queried point, having evaluated the black-box function
of interest at this point.
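The ask/tell split described above can be illustrated with a toy sketch in plain Python. This is not GPJax code: `ToyDecisionMaker`, `RandomSearchDecisionMaker`, and the one-dimensional `objective` are hypothetical names chosen for illustration, and plain lists stand in for `Dataset` objects.

```python
import random
from abc import ABC, abstractmethod


class ToyDecisionMaker(ABC):
    """Hypothetical stand-in for an ask/tell decision maker (not the GPJax class)."""

    def __init__(self, lower, upper, seed=0):
        self.lower, self.upper = lower, upper
        self.rng = random.Random(seed)
        self.xs, self.ys = [], []  # plain lists standing in for a dataset

    @abstractmethod
    def ask(self):
        """Decide which point to query next."""

    def tell(self, x, y):
        """Record a newly observed (x, y) pair."""
        self.xs.append(x)
        self.ys.append(y)


class RandomSearchDecisionMaker(ToyDecisionMaker):
    def ask(self):
        # Simplest possible strategy: sample uniformly from the search interval.
        return self.rng.uniform(self.lower, self.upper)


def objective(x):
    # Hypothetical black-box function of interest.
    return (x - 0.3) ** 2


dm = RandomSearchDecisionMaker(lower=0.0, upper=1.0)
for _ in range(5):
    x = dm.ask()       # ask: choose the next query point
    y = objective(x)   # evaluate the black-box function there
    dm.tell(x, y)      # tell: store the new observation
```

A model-based decision maker would replace the uniform sampling in `ask` with a choice informed by the data gathered so far, and would refit its model inside `tell`.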
Attributes:
Name | Type | Description |
---|---|---|
search_space | AbstractSearchSpace | Search space over which we can evaluate the |
posterior_handlers | Dict[str, PosteriorHandler] | Dictionary of posterior handlers, which are used to update posteriors throughout the decision making loop. Note that the word |
datasets | Dict[str, Dataset] | Dictionary of datasets, which are augmented with observations throughout the decision making loop. In a typical setup they are also used to update the posteriors, using the |
key | KeyArray | JAX random key, used to generate random numbers. |
batch_size | int | Number of points to query at each step of the decision making loop. Note that |
post_ask | List[Callable] | List of functions to be executed after each ask step. |
post_tell | List[Callable] | List of functions to be executed after each tell step. |
search_space: AbstractSearchSpace
instance-attribute
posterior_handlers: Dict[str, PosteriorHandler]
instance-attribute
datasets: Dict[str, Dataset]
instance-attribute
key: KeyArray
instance-attribute
batch_size: int
instance-attribute
post_ask: List[Callable]
instance-attribute
post_tell: List[Callable]
instance-attribute
__post_init__()
At initialisation we check that the posterior handlers and datasets are consistent (i.e. have the same tags), and then initialise the posteriors, optimizing them using the corresponding datasets.
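As a rough illustration of the consistency check, the sketch below compares the tags (dictionary keys) of the two mappings. The function name `check_consistent_tags`, the choice of `ValueError`, and the `"OBJECTIVE"` tag are assumptions made for this example; the check performed by the library may be written differently.

```python
def check_consistent_tags(posterior_handlers, datasets):
    """Raise ValueError unless both dictionaries share exactly the same keys."""
    handler_tags = set(posterior_handlers)
    dataset_tags = set(datasets)
    if handler_tags != dataset_tags:
        # Symmetric difference: tags present in one mapping but not the other.
        mismatched = sorted(handler_tags ^ dataset_tags)
        raise ValueError(f"Tags without a counterpart: {mismatched}")


# Strings and lists stand in for real handler and dataset objects.
check_consistent_tags({"OBJECTIVE": "handler"}, {"OBJECTIVE": [1.0, 2.0]})
```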
ask(key: KeyArray) -> Float[Array, 'B D']
abstractmethod
tell(observation_datasets: Mapping[str, Dataset], key: KeyArray)
run(n_steps: int, black_box_function_evaluator: FunctionEvaluator) -> Mapping[str, Dataset]
Run the decision making loop continuously for `n_steps`. This is broken down
into three main steps:
1. Call the `ask` method to get the point to be queried next.
2. Call the `black_box_function_evaluator` to evaluate the black box functions
of interest at the point chosen to be queried.
3. Call the `tell` method to update the datasets and posteriors with the newly
observed data.
In addition to this, after the `ask` step, the functions in the `post_ask` list
are executed, taking as arguments the decision maker and the point chosen to be
queried next. Similarly, after the `tell` step, the functions in the `post_tell`
list are executed, taking the decision maker as the sole argument.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
n_steps | int | Number of steps to run the decision making loop for. | required |
black_box_function_evaluator | FunctionEvaluator | Function evaluator which evaluates the black box functions of interest at supplied points. | required |
Returns:
Type | Description |
---|---|
Mapping[str, Dataset] | Dictionary of datasets containing the observations made throughout the decision making loop, as well as the initial data supplied when initialising the |
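The three steps of `run` and the two hook lists can be sketched as follows. This is a simplified illustration under stated assumptions, not the library's implementation: `GridDecisionMaker` and the standalone `run` function are hypothetical, random keys are omitted, and a list of `(x, y)` pairs stands in for the returned datasets.

```python
class GridDecisionMaker:
    """Hypothetical decision maker that queries a fixed list of candidates in order."""

    def __init__(self, candidates):
        self.candidates = list(candidates)
        self.data = []  # list of (x, y) pairs standing in for a dataset

    def ask(self):
        return self.candidates[len(self.data)]

    def tell(self, x, y):
        self.data.append((x, y))


def run(decision_maker, n_steps, evaluate, post_ask=(), post_tell=()):
    """Alternate ask, evaluate, and tell for n_steps, firing hooks after ask and tell."""
    for _ in range(n_steps):
        query = decision_maker.ask()
        for hook in post_ask:
            hook(decision_maker, query)  # receives the decision maker and the query point
        observation = evaluate(query)
        decision_maker.tell(query, observation)
        for hook in post_tell:
            hook(decision_maker)         # receives the decision maker only
    return decision_maker.data


asked = []
dm = GridDecisionMaker([0.0, 0.5, 1.0])
history = run(
    dm,
    n_steps=3,
    evaluate=lambda x: x ** 2,
    post_ask=[lambda d, q: asked.append(q)],  # e.g. log each query point
    post_tell=[lambda d: None],               # e.g. a plotting callback
)
```

Passing hooks as plain callables keeps logging and plotting out of the loop itself, which is the role the `post_ask` and `post_tell` lists play in the description above.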
UtilityDrivenDecisionMaker
dataclass
Bases: AbstractDecisionMaker
UtilityDrivenDecisionMaker class which handles the core decision making loop in a typical model-based decision making setup. In this setup we use surrogate model(s) for the function(s) of interest, and define a utility function (often called the 'acquisition function' in the context of Bayesian optimisation) which characterises how useful it would be to query a given point within the search space given the data we have observed so far. This can then be used to decide which point(s) to query next.
The decision making loop is split into two key steps, `ask` and `tell`. The `ask`
step forms a `UtilityFunction` from the current `posteriors` and `datasets` and
returns the point which maximises it. It also stores the formed utility function
under the attribute `self.current_utility_functions` so that it can be called,
for instance for plotting, after the `ask` function has been called. The `tell`
step adds a newly queried point to the `datasets` and updates the `posteriors`.
This can be run as a typical ask-tell loop, or the `run` method can be used to run
the decision making loop for a fixed number of steps. Moreover, the `run` method
executes the functions in `post_ask` and `post_tell` after each ask and tell step
respectively. This enables the user to add custom functionality, such as the ability
to plot values of interest during the optimization process.
Attributes:
Name | Type | Description |
---|---|---|
utility_function_builder | AbstractUtilityFunctionBuilder | Object which builds utility functions from posteriors and datasets, to decide where to query next. In a typical Bayesian optimisation setup the point chosen to be queried next is the point which maximizes the utility function. |
utility_maximizer | AbstractUtilityMaximizer | Object which maximizes utility functions over the search space. |
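The division of labour between a utility function builder and a utility maximizer can be sketched with two small functions. Everything here is a hypothetical stand-in: `build_utility` uses a pure-exploration utility (distance to the nearest observed point) rather than a surrogate-model posterior, and `maximise_on_grid` uses exhaustive grid search rather than a real optimizer.

```python
def build_utility(observed_xs):
    """Hypothetical utility builder: reward points far from anything observed."""
    def utility(x):
        return min(abs(x - seen) for seen in observed_xs)
    return utility


def maximise_on_grid(utility, lower, upper, n_grid=101):
    """Hypothetical maximiser: exhaustive search over an evenly spaced grid."""
    grid = [lower + (upper - lower) * i / (n_grid - 1) for i in range(n_grid)]
    return max(grid, key=utility)


observed_xs = [0.0, 1.0]
current_utility = build_utility(observed_xs)  # kept around, e.g. for plotting
next_query = maximise_on_grid(current_utility, 0.0, 1.0)
```

With observations at both ends of the interval, the utility peaks midway between them, so the maximiser proposes the midpoint as the next query.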
search_space: AbstractSearchSpace
instance-attribute
posterior_handlers: Dict[str, PosteriorHandler]
instance-attribute
datasets: Dict[str, Dataset]
instance-attribute
key: KeyArray
instance-attribute
batch_size: int
instance-attribute
post_ask: List[Callable]
instance-attribute
post_tell: List[Callable]
instance-attribute
utility_function_builder: AbstractUtilityFunctionBuilder
instance-attribute
utility_maximizer: AbstractUtilityMaximizer
instance-attribute
tell(observation_datasets: Mapping[str, Dataset], key: KeyArray)
run(n_steps: int, black_box_function_evaluator: FunctionEvaluator) -> Mapping[str, Dataset]
Run the decision making loop continuously for `n_steps`. This is broken down
into three main steps:
1. Call the `ask` method to get the point to be queried next.
2. Call the `black_box_function_evaluator` to evaluate the black box functions
of interest at the point chosen to be queried.
3. Call the `tell` method to update the datasets and posteriors with the newly
observed data.
In addition to this, after the `ask` step, the functions in the `post_ask` list
are executed, taking as arguments the decision maker and the point chosen to be
queried next. Similarly, after the `tell` step, the functions in the `post_tell`
list are executed, taking the decision maker as the sole argument.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
n_steps | int | Number of steps to run the decision making loop for. | required |
black_box_function_evaluator | FunctionEvaluator | Function evaluator which evaluates the black box functions of interest at supplied points. | required |
Returns:
Type | Description |
---|---|
Mapping[str, Dataset] | Dictionary of datasets containing the observations made throughout the decision making loop, as well as the initial data supplied when initialising the |
__post_init__()
ask(key: KeyArray) -> Float[Array, 'B D']
Get updated utility function(s) and return the point(s) which maximise it/them. This
method also stores the utility function(s) in `self.current_utility_functions` so
that they can be accessed after the `ask` function has been called. This is useful
for non-deterministic utility functions, which may differ between calls to `ask`
due to the splitting of `self.key`.
Note that in general `SinglePointUtilityFunction`s are only capable of
generating one point to be queried at each iteration of the decision making loop
(i.e. `self.batch_size` must be 1). However, Thompson sampling can be used in a
batched setting by drawing a batch of different samples from the GP posterior.
This is done by calling `build_utility_function` with different keys
sequentially, and optimising each of these individual samples in sequence in
order to obtain `self.batch_size` points to query next.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
key | KeyArray | JAX PRNG key for controlling random state. | required |
Returns:
Type | Description |
---|---|
Float[Array, 'B D'] | Point(s) to be queried next. |
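The batched Thompson sampling procedure can be sketched in plain Python. This is an analogy under stated assumptions rather than the library's code: integer seeds from the standard library's `random` module play the role of split JAX keys, a random combination of sinusoids stands in for a GP posterior sample, and `build_sample_utility` and `ask_batch` are hypothetical names.

```python
import math
import random


def build_sample_utility(seed):
    """Hypothetical builder: one random function draw per seed.

    A random combination of fixed sinusoids stands in for a posterior sample.
    """
    rng = random.Random(seed)
    weights = [rng.gauss(0.0, 1.0) for _ in range(3)]

    def utility(x):
        return sum(w * math.sin((k + 1) * math.pi * x) for k, w in enumerate(weights))

    return utility


def ask_batch(seed, batch_size, grid):
    """Draw batch_size different samples in sequence and maximise each on the grid."""
    rng = random.Random(seed)
    utilities, points = [], []
    for _ in range(batch_size):
        sub_seed = rng.getrandbits(32)  # plays the role of a freshly split key
        utility = build_sample_utility(sub_seed)
        utilities.append(utility)       # stored so they can be inspected later
        points.append(max(grid, key=utility))
    return points, utilities


grid = [i / 50 for i in range(51)]
points, utilities = ask_batch(seed=0, batch_size=4, grid=grid)
```

Because each sample is built from a different sub-seed, the maximisers generally differ, which is what lets a single-point utility yield a batch of query points. Returning the sampled utilities alongside the points mirrors the idea of keeping them accessible after `ask` has been called.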