Decision Maker
AbstractDecisionMaker
dataclass
AbstractDecisionMaker(search_space, posterior_handlers, datasets, key, batch_size, post_ask, post_tell)
Bases: ABC
AbstractDecisionMaker is an abstract base class which handles the core decision making loop, in which we sequentially decide at which points to query our function of interest.
The decision making loop is split into two key steps, ask and tell. The ask step is typically used to decide which point to query next. The tell step is typically used to update models and datasets with newly queried points. These steps can be combined in a run loop. The loop alternates between asking which point to query next and, once the black-box function of interest has been evaluated at that point, telling the decision maker about the new observation.
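The ask-tell split can be illustrated with a minimal sketch. ToyDecisionMaker, SweepDecisionMaker and the quadratic f are hypothetical stand-ins written for this example, not part of GPJax. Datasets here are plain lists of (x, y) pairs rather than Dataset objects, and no posteriors or PRNG keys are involved.

```python
from abc import ABC, abstractmethod
from typing import Dict, List


class ToyDecisionMaker(ABC):
    """Hypothetical, stripped-down stand-in for AbstractDecisionMaker."""

    def __init__(self) -> None:
        # Observations keyed by tag, mirroring the datasets attribute.
        self.datasets: Dict[str, List[tuple]] = {"OBJECTIVE": []}

    @abstractmethod
    def ask(self) -> List[float]:
        """Decide which point to query next."""

    def tell(self, observations: Dict[str, List[tuple]]) -> None:
        """Append newly observed (x, y) pairs to the dataset with the same tag."""
        for tag, new_points in observations.items():
            self.datasets[tag].extend(new_points)


class SweepDecisionMaker(ToyDecisionMaker):
    """Hypothetical subclass that queries 0.0, 0.1, 0.2, ... in order."""

    def ask(self) -> List[float]:
        return [0.1 * len(self.datasets["OBJECTIVE"])]


def f(x: List[float]) -> float:
    # Hypothetical black-box function of interest.
    return (x[0] - 0.3) ** 2


dm = SweepDecisionMaker()
for _ in range(5):
    x = dm.ask()                         # ask: choose the next query point
    dm.tell({"OBJECTIVE": [(x, f(x))]})  # tell: report the new observation
```

Because ask is abstract, instantiating ToyDecisionMaker directly raises a TypeError. This mirrors how ABC is used as the base of AbstractDecisionMaker.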
Attributes:
- search_space (AbstractSearchSpace): Search space over which we can evaluate the function(s) of interest.
- posterior_handlers (Dict[str, PosteriorHandler]): Dictionary of posterior handlers, which are used to update posteriors throughout the decision making loop. Note that the word posteriors is used for consistency with GPJax, but these objects are typically referred to as models in the model-based decision making literature. Tags are used to distinguish between posteriors. In a typical Bayesian optimisation setup one of the tags will be OBJECTIVE, defined in decision_making.utils.
- datasets (Dict[str, Dataset]): Dictionary of datasets, which are augmented with observations throughout the decision making loop. In a typical setup they are also used to update the posteriors, using the posterior_handlers. Tags are used to distinguish datasets, and correspond to tags in posterior_handlers.
- key (KeyArray): JAX random key, used to generate random numbers.
- batch_size (int): Number of points to query at each step of the decision making loop. Note that SinglePointUtilityFunctions are only capable of generating one point to be queried at each iteration of the decision making loop, so batch_size must be 1 when using them (Thompson sampling is the exception; see UtilityDrivenDecisionMaker.ask).
- post_ask (List[Callable]): List of functions to be executed after each ask step.
- post_tell (List[Callable]): List of functions to be executed after each tell step.
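The attributes above are tied together by tags: each dataset should have a posterior handler under the same tag. The sketch below shows one way such a consistency check could look. ToyDecisionMakerConfig is hypothetical, and the handlers are plain callables rather than PosteriorHandler objects. The string value of OBJECTIVE is an assumption. The source does not state whether AbstractDecisionMaker itself performs these checks.

```python
from dataclasses import dataclass, field
from typing import Callable, Dict, List

OBJECTIVE = "OBJECTIVE"  # assumed tag value; the real tag lives in decision_making.utils


@dataclass
class ToyDecisionMakerConfig:
    """Hypothetical container mirroring a few AbstractDecisionMaker attributes."""

    posterior_handlers: Dict[str, Callable]
    datasets: Dict[str, list]
    batch_size: int = 1
    post_ask: List[Callable] = field(default_factory=list)
    post_tell: List[Callable] = field(default_factory=list)

    def __post_init__(self) -> None:
        # Every posterior handler needs a dataset with the same tag, and vice versa.
        if set(self.posterior_handlers) != set(self.datasets):
            raise ValueError("posterior_handlers and datasets must share the same tags")
        if self.batch_size < 1:
            raise ValueError("batch_size must be a positive integer")


cfg = ToyDecisionMakerConfig(
    posterior_handlers={OBJECTIVE: lambda dataset, key: None},
    datasets={OBJECTIVE: []},
)
```

Validating in __post_init__ makes a mismatched tag fail at construction time rather than partway through the decision making loop.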
ask
abstractmethod
Get the point(s) to be queried next.
Parameters:
- key (KeyArray): JAX PRNG key for controlling random state.
Returns:
- Float[Array, 'B D']: Point(s) to be queried next, where B is the batch size and D is the dimension of the search space.
tell
Add newly observed data to datasets and update the corresponding posteriors.
Parameters:
- observation_datasets (Mapping[str, Dataset]): Dictionary of datasets containing new observations. Tags are used to distinguish datasets, and correspond to tags in posterior_handlers and self.datasets.
- key (KeyArray): JAX PRNG key for controlling random state.
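A minimal sketch of the tell step follows. It assumes each dataset supports concatenation via + and that each posterior handler can be modelled as a callable taking the augmented dataset and a key. ToyDataset, the handler signature and the integer key are assumptions made for illustration. They are not the GPJax Dataset or PosteriorHandler APIs.

```python
from dataclasses import dataclass
from typing import Callable, Dict, List


@dataclass
class ToyDataset:
    """Hypothetical stand-in for a Dataset: inputs X paired with outputs y."""

    X: List[List[float]]
    y: List[float]

    def __add__(self, other: "ToyDataset") -> "ToyDataset":
        # Concatenate observations, keeping the existing ones first.
        return ToyDataset(self.X + other.X, self.y + other.y)


def tell(
    datasets: Dict[str, ToyDataset],
    posteriors: Dict[str, object],
    posterior_handlers: Dict[str, Callable[[ToyDataset, int], object]],
    observation_datasets: Dict[str, ToyDataset],
    key: int,
) -> None:
    """Augment each tagged dataset, then refit the posterior sharing its tag."""
    for tag, new_data in observation_datasets.items():
        if tag not in datasets:
            raise KeyError(f"no dataset registered under tag {tag!r}")
        datasets[tag] = datasets[tag] + new_data
        # The integer 'key' stands in for a JAX PRNG key passed to the handler.
        posteriors[tag] = posterior_handlers[tag](datasets[tag], key)


datasets = {"OBJECTIVE": ToyDataset(X=[[0.0]], y=[1.0])}
posteriors: Dict[str, object] = {}
# Toy "handler": the refitted posterior just records how many points it saw.
handlers = {"OBJECTIVE": lambda data, key: {"n_train": len(data.y)}}
tell(datasets, posteriors, handlers, {"OBJECTIVE": ToyDataset(X=[[0.5]], y=[0.2])}, key=0)
# datasets["OBJECTIVE"].X == [[0.0], [0.5]]; posteriors["OBJECTIVE"] == {"n_train": 2}
```

Raising on an unknown tag is a design choice of this sketch: a mistyped tag then fails loudly instead of silently creating a new dataset.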
run
Run the decision making loop for n_steps iterations. Each iteration is broken down into three main steps:
1. Call the ask method to get the point to be queried next.
2. Call the black_box_function_evaluator to evaluate the black-box functions of interest at the point chosen to be queried.
3. Call the tell method to update the datasets and posteriors with the newly observed data.
In addition to this, after the ask step, the functions in the post_ask list are executed, taking as arguments the decision maker and the point chosen to be queried next. Similarly, after the tell step, the functions in the post_tell list are executed, taking the decision maker as the sole argument.
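The run loop and its hooks can be sketched as follows. CountingDecisionMaker, run and the evaluator lambda are hypothetical. The sketch assumes the evaluator returns a dictionary keyed by tag, and it omits PRNG keys for brevity. The hooks follow the calling convention described above: post_ask hooks receive the decision maker and the query point, and post_tell hooks receive only the decision maker.

```python
from typing import Callable, Dict, List


class CountingDecisionMaker:
    """Hypothetical decision maker exposing ask, tell and the two hook lists."""

    def __init__(self, post_ask: List[Callable], post_tell: List[Callable]) -> None:
        self.post_ask = post_ask
        self.post_tell = post_tell
        self.observations: List[Dict[str, float]] = []

    def ask(self) -> List[float]:
        # Query 0.0, 1.0, 2.0, ... in turn.
        return [float(len(self.observations))]

    def tell(self, observation: Dict[str, float]) -> None:
        self.observations.append(observation)


def run(
    dm: CountingDecisionMaker,
    n_steps: int,
    black_box_function_evaluator: Callable,
) -> None:
    """Alternate ask -> evaluate -> tell, executing the hooks after ask and after tell."""
    for _ in range(n_steps):
        query_point = dm.ask()
        for hook in dm.post_ask:
            hook(dm, query_point)  # post_ask hooks get the decision maker and the point
        observation = black_box_function_evaluator(query_point)
        dm.tell(observation)
        for hook in dm.post_tell:
            hook(dm)  # post_tell hooks get only the decision maker


asked_points: List[List[float]] = []
sizes_after_tell: List[int] = []
dm = CountingDecisionMaker(
    post_ask=[lambda dm, x: asked_points.append(x)],
    post_tell=[lambda dm: sizes_after_tell.append(len(dm.observations))],
)
# The evaluator maps a query point to observations keyed by tag (assumed format).
run(dm, n_steps=3, black_box_function_evaluator=lambda x: {"OBJECTIVE": x[0] ** 2})
# asked_points == [[0.0], [1.0], [2.0]]; sizes_after_tell == [1, 2, 3]
```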
Parameters:
- n_steps (int): Number of steps to run the decision making loop for.
- black_box_function_evaluator (FunctionEvaluator): Function evaluator which evaluates the black-box functions of interest at supplied points.
Returns:
UtilityDrivenDecisionMaker
dataclass
UtilityDrivenDecisionMaker(search_space, posterior_handlers, datasets, key, batch_size, post_ask, post_tell, utility_function_builder, utility_maximizer)
Bases: AbstractDecisionMaker
UtilityDrivenDecisionMaker class which handles the core decision making loop in a typical model-based decision making setup. In this setup we use surrogate model(s) for the function(s) of interest, and define a utility function (often called the 'acquisition function' in the context of Bayesian optimisation) which characterises how useful it would be to query a given point within the search space given the data we have observed so far. This can then be used to decide which point(s) to query next.
The decision making loop is split into two key steps, ask and tell. The ask step forms a UtilityFunction from the current posteriors and datasets and returns the point which maximises it. It also stores the formed utility function under the attribute self.current_utility_functions so that it can be called after the ask function has been called, for instance for plotting. The tell step adds a newly queried point to the datasets and updates the posteriors.
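A sketch of the ask step of a utility-driven decision maker follows. The distance-based utility, grid_maximizer and ToyUtilityDrivenDecisionMaker are all hypothetical. The utility simply prefers points far from existing observations (a pure-exploration heuristic, not a GP acquisition function), and maximisation is an exhaustive search over a fixed grid. The formed utility is kept in current_utility_functions so that it can be inspected after ask returns, for example for plotting.

```python
from typing import Callable, List

Point = List[float]


def build_distance_utility(observed_X: List[Point]) -> Callable[[Point], float]:
    """Hypothetical utility builder: score a point by its Euclidean distance
    to the nearest observation (pure exploration, no surrogate model)."""

    def utility(x: Point) -> float:
        return min(sum((a - b) ** 2 for a, b in zip(x, obs)) ** 0.5 for obs in observed_X)

    return utility


def grid_maximizer(utility: Callable[[Point], float], grid: List[Point]) -> Point:
    """Hypothetical maximizer: return the grid point with the largest utility."""
    return max(grid, key=utility)


class ToyUtilityDrivenDecisionMaker:
    """Hypothetical decision maker: build a utility, store it, return its maximiser."""

    def __init__(self, observed_X: List[Point], grid: List[Point]) -> None:
        self.observed_X = observed_X
        self.grid = grid
        self.current_utility_functions: List[Callable[[Point], float]] = []

    def ask(self) -> Point:
        utility = build_distance_utility(self.observed_X)
        # Replace (not extend) the stored utilities so they reflect the latest ask.
        self.current_utility_functions = [utility]
        return grid_maximizer(utility, self.grid)


dm = ToyUtilityDrivenDecisionMaker(observed_X=[[0.0], [0.4]], grid=[[i / 10] for i in range(11)])
next_point = dm.ask()  # [1.0]: the grid point farthest from both observations
```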
This can be run as a typical ask-tell loop, or the run method can be used to run the decision making loop for a fixed number of steps. Moreover, the run method executes the functions in post_ask and post_tell after each ask and tell step respectively. This enables the user to add custom functionality, such as the ability to plot values of interest during the optimization process.
Attributes:
- utility_function_builder (AbstractUtilityFunctionBuilder): Object which builds utility functions from posteriors and datasets, to decide where to query next. In a typical Bayesian optimisation setup the point chosen to be queried next is the point which maximizes the utility function.
- utility_maximizer (AbstractUtilityMaximizer): Object which maximizes utility functions over the search space.
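To illustrate what a utility maximizer does, here is a gradient-free random-search sketch over a box-shaped search space. random_search_maximize, its bounds format and the integer seed are hypothetical and chosen for brevity. The actual AbstractUtilityMaximizer implementations may use a different strategy, which the source does not specify.

```python
import random
from typing import Callable, List, Sequence, Tuple


def random_search_maximize(
    utility: Callable[[List[float]], float],
    bounds: Sequence[Tuple[float, float]],
    n_candidates: int = 1000,
    seed: int = 0,
) -> List[float]:
    """Hypothetical maximizer: sample candidates uniformly inside a box search
    space and return the one with the highest utility (a length-D point)."""
    rng = random.Random(seed)
    candidates = [
        [rng.uniform(low, high) for low, high in bounds] for _ in range(n_candidates)
    ]
    return max(candidates, key=utility)


# Maximise a toy concave utility, peaked at (0.25, -0.5), over the box [0, 1] x [-1, 0].
best = random_search_maximize(
    lambda x: -((x[0] - 0.25) ** 2) - (x[1] + 0.5) ** 2,
    bounds=[(0.0, 1.0), (-1.0, 0.0)],
    n_candidates=2000,
    seed=0,
)
```

Separating the maximizer from the utility builder, as the attributes above do, means the same search strategy can be reused for any utility function.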
tell
Add newly observed data to datasets and update the corresponding posteriors.
Parameters:
- observation_datasets (Mapping[str, Dataset]): Dictionary of datasets containing new observations. Tags are used to distinguish datasets, and correspond to tags in posterior_handlers and self.datasets.
- key (KeyArray): JAX PRNG key for controlling random state.
run
Run the decision making loop for n_steps iterations. Each iteration is broken down into three main steps:
1. Call the ask method to get the point to be queried next.
2. Call the black_box_function_evaluator to evaluate the black-box functions of interest at the point chosen to be queried.
3. Call the tell method to update the datasets and posteriors with the newly observed data.
In addition to this, after the ask step, the functions in the post_ask list are executed, taking as arguments the decision maker and the point chosen to be queried next. Similarly, after the tell step, the functions in the post_tell list are executed, taking the decision maker as the sole argument.
Parameters:
- n_steps (int): Number of steps to run the decision making loop for.
- black_box_function_evaluator (FunctionEvaluator): Function evaluator which evaluates the black-box functions of interest at supplied points.
Returns:
ask
Get updated utility function(s) and return the point(s) which maximise it/them. This method also stores the utility function(s) in self.current_utility_functions so that they can be accessed after the ask function has been called. This is useful for non-deterministic utility functions, which may differ between calls to ask due to the splitting of self.key.
Note that in general SinglePointUtilityFunctions are only capable of generating one point to be queried at each iteration of the decision making loop (i.e. self.batch_size must be 1). However, Thompson sampling can be used in a batched setting by drawing a batch of different samples from the GP posterior. This is done by calling build_utility_function sequentially with different keys, and optimising each of these individual samples in turn to obtain self.batch_size points to query next.
Parameters:
-
key
(KeyArray
) βJAX PRNG key for controlling random state.
Returns:
- Float[Array, 'B D']: Point(s) to be queried next.