Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion tensorflow_probability/python/optimizer/bfgs.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ def minimize(value_and_gradients_function,
validate_args=True,
max_line_search_iterations=50,
f_absolute_tolerance=0,
line_search_kwargs=None,
name=None):
"""Applies the BFGS algorithm to minimize a differentiable function.

Expand Down Expand Up @@ -177,6 +178,11 @@ def quadratic_loss_and_gradient(x):
f_absolute_tolerance: Scalar `Tensor` of real dtype. If the absolute change
in the objective value between one iteration and the next is smaller
than this value, the algorithm is stopped.
line_search_kwargs: (Optional) Python `dict` of extra keyword arguments to
forward to the underlying `hager_zhang` line search, allowing the caller
to tune line search hyper-parameters such as `sufficient_decrease_param`
or `curvature_param`. See `tfp.optimizer.linesearch.hager_zhang` for the
full list of accepted parameters.
name: (Optional) Python str. The name prefixed to the ops created by this
function. If not supplied, the default name 'minimize' is used.

Expand Down Expand Up @@ -286,7 +292,8 @@ def _body(state):
next_state = bfgs_utils.line_search_step(
current_state, value_and_gradients_function, actual_search_direction,
tolerance, f_relative_tolerance, x_tolerance, stopping_condition,
max_line_search_iterations, f_absolute_tolerance)
max_line_search_iterations, f_absolute_tolerance,
line_search_kwargs=line_search_kwargs)

# Update the inverse Hessian if needed and continue.
return [_update_inv_hessian(current_state, next_state)]
Expand Down
26 changes: 26 additions & 0 deletions tensorflow_probability/python/optimizer/bfgs_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -645,6 +645,32 @@ def quadratic_with_spike(x):
self.assertAllTrue(results.failed)
self.assertAllFinite(results.position)

def test_line_search_kwargs_are_forwarded(self):
"""`line_search_kwargs` reach `hager_zhang` and affect the search."""
@_make_val_and_grad_fn
def rosenbrock(coord):
x, y = coord[0], coord[1]
return (1 - x)**2 + 100 * (y - x**2)**2

start = tf.constant([-1.2, 1.0])
default_results = self.evaluate(
bfgs.minimize(rosenbrock, initial_position=start, tolerance=1e-5))
small_step_results = self.evaluate(
bfgs.minimize(
rosenbrock, initial_position=start, tolerance=1e-5,
line_search_kwargs=dict(initial_step_size=0.1)))

# Both runs should still converge to the Rosenbrock minimum.
self.assertTrue(default_results.converged)
self.assertTrue(small_step_results.converged)
self.assertArrayNear(small_step_results.position,
np.array([1.0, 1.0]), 1e-5)
# The default starts each line search at a unit step. Shrinking that to
# 0.1 must change the line search trajectory, and therefore the total
# number of objective evaluations.
self.assertNotEqual(default_results.num_objective_evaluations,
small_step_results.num_objective_evaluations)


if __name__ == '__main__':
test_util.main()
17 changes: 13 additions & 4 deletions tensorflow_probability/python/optimizer/bfgs_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,8 @@ def _is_negative_inf(x):

def line_search_step(state, value_and_gradients_function, search_direction,
grad_tolerance, f_relative_tolerance, x_tolerance,
stopping_condition, max_iterations, f_absolute_tolerance):
stopping_condition, max_iterations, f_absolute_tolerance,
line_search_kwargs=None):
"""Performs the line search step of the BFGS search procedure.

Uses hager_zhang line search procedure to compute a suitable step size
Expand Down Expand Up @@ -182,6 +183,12 @@ def line_search_step(state, value_and_gradients_function, search_direction,
iterations of the hager_zhang line search algorithm
f_absolute_tolerance: Scalar `Tensor` of real dtype. Specifies the tolerance
for the absolute change in the objective value.
line_search_kwargs: (Optional) Python `dict` of extra keyword arguments to
forward to the underlying `hager_zhang` line search, allowing the caller
to tune line search hyper-parameters such as `initial_step_size`,
`sufficient_decrease_param` or `curvature_param`. Keys that collide with
arguments controlled by the outer optimization loop (`value_at_zero`,
`converged`, `max_iterations`) will raise a `TypeError`.

Returns:
A copy of the input state with the following fields updated:
Expand All @@ -206,12 +213,14 @@ def line_search_step(state, value_and_gradients_function, search_direction,
df=derivative_at_start_pt,
full_gradient=state.objective_gradient)
inactive = state.failed | state.converged
ls_kwargs = dict(line_search_kwargs) if line_search_kwargs else {}
ls_kwargs.setdefault('initial_step_size', _broadcast(1, state.position))
ls_result = hager_zhang(
line_search_value_grad_func,
initial_step_size=_broadcast(1, state.position),
value_at_zero=val_0,
converged=inactive,
max_iterations=max_iterations) # No search needed for these.
converged=inactive, # No search needed for these.
max_iterations=max_iterations,
**ls_kwargs)

state_after_ls = update_fields(
state,
Expand Down