# Source code for optimism.MinimizeScalar

from collections import namedtuple
from optimism.JaxConfig import *
#from jax.lax import while_loop


# Solver configuration record: `tol` is the convergence tolerance on the
# gradient magnitude, `max_iters` caps the number of outer iterations.
Settings = namedtuple('Settings', ['tol', 'max_iters'])


def get_settings(tol=1e-8, max_iters=25):
    """Construct a Settings record for minimize_scalar.

    Parameters
    ----------
    tol : float, optional
        Convergence tolerance on the gradient magnitude.
    max_iters : int, optional
        Maximum number of outer iterations.

    Returns
    -------
    Settings
        Named tuple holding the two solver parameters.
    """
    return Settings(tol=tol, max_iters=max_iters)
def minimize_scalar(objective, x0, diffArgs, nondiffArgs, settings):
    """Minimize a scalar function of a scalar variable.

    Uses a Newton step when the local curvature is positive and a
    line-searched gradient-descent step otherwise. Iteration stops when
    ``|grad| <= settings.tol`` or ``settings.max_iters`` is reached.

    Parameters
    ----------
    objective : callable
        Called as ``objective(x, *diffArgs, *nondiffArgs)``; must return a
        scalar and be differentiable in ``x`` (derivatives taken with
        ``jacfwd``/``value_and_grad``).
    x0 : scalar
        Initial guess.
    diffArgs : tuple
        Extra arguments that are differentiated through.
    nondiffArgs : tuple
        Extra arguments that are treated as constants.
    settings : Settings
        Solver parameters, see ``get_settings``.

    Returns
    -------
    scalar
        The approximate minimizer of the objective.

    Raises
    ------
    TypeError
        If ``diffArgs`` or ``nondiffArgs`` is not a tuple.
    """
    if not isinstance(diffArgs, tuple):
        msg = "diffArgs argument to optimism.MinimizeScalar.minimize_scalar must be a tuple, got {}"
        raise TypeError(msg.format(diffArgs))
    # BUG FIX: this previously re-checked diffArgs, letting a non-tuple
    # nondiffArgs slip past validation.
    if not isinstance(nondiffArgs, tuple):
        msg = "nondiffArgs argument to optimism.MinimizeScalar.minimize_scalar must be a tuple, got {}"
        raise TypeError(msg.format(nondiffArgs))

    # define function with args
    def F(x):
        return objective(x, *diffArgs, *nondiffArgs)

    G = jacfwd(F)            # first derivative of F
    GH = value_and_grad(G)   # returns (first derivative, second derivative)

    tol = settings.tol
    max_iters = settings.max_iters

    def conditional(carry):
        # Continue while the gradient magnitude exceeds tol and the
        # iteration budget is not exhausted.
        x, f, g, h, alpha, i = carry
        resNorm = np.abs(g)
        return (resNorm > tol) & (i < max_iters)

    def body(carry):
        x, f, g, h, alpha, i = carry
        print("h=", h)
        if h > 0:
            # positive curvature, use newton step
            p = -g/h
            alpha0 = 1.0
            alpha = line_search_backtrack(x, f, g, p, alpha0, F)
        else:
            # negative curvature, use gradient descent; reuse the previous
            # step length as the starting guess for the line search
            p = -g
            print("alpha initial=", alpha)
            alpha = line_search_bidirectional(x, f, g, p, alpha, F)
        x += alpha*p
        print("i", i, "x", x)
        g, h = GH(x)
        f = F(x)
        return (x, f, g, h, alpha, i + 1)

    f0 = F(x0)
    df0, ddf0 = GH(x0)
    alpha0 = 1.0
    xStar, f, df, _, _, iters = while_loop(conditional, body,
                                           (x0, f0, df0, ddf0, alpha0, 0))
    print("\n----\niters taken=", iters)
    print("df=", df)
    print("objective = ", f)
    return xStar
def line_search_bidirectional(x, f, g, p, alpha, F):
    """Line search that grows or shrinks the step depending on where it starts.

    If the incoming step length ``alpha`` already satisfies the Armijo
    sufficient-decrease test, forward-track (grow) it; otherwise
    back-track (shrink) it.

    Parameters
    ----------
    x, f, g : current point, objective value, and gradient.
    p : search direction.
    alpha : initial step length guess.
    F : scalar objective of one variable.

    Returns
    -------
    The accepted step length.
    """
    c = 0.01
    cgp = c*g*p  # Armijo slope term
    initialStepLengthSufficient = F(x + alpha*p) < f + cgp*alpha
    print("fwd track = ", initialStepLengthSufficient)
    if initialStepLengthSufficient:
        return line_search_forwardtrack(x, f, g, p, alpha, F)
    return line_search_backtrack(x, f, g, p, alpha, F)
def line_search_backtrack(x, f, g, p, alpha, F):
    """Backtracking line search enforcing the Armijo sufficient-decrease test.

    Repeatedly shrinks ``alpha`` by a fixed factor until
    ``F(x + alpha*p) <= f + c*g*p*alpha`` holds or 20 shrink steps have
    been taken.

    Returns
    -------
    The accepted (possibly shrunken) step length.
    """
    cutback = 0.2  # shrink factor per rejected step
    c = 0.01       # Armijo sufficient-decrease constant
    cgp = c*g*p

    def notAccepted(state):
        step, count = state
        return (F(x + step*p) > f + cgp*step) & (count < 20)

    def shrink(state):
        step, count = state
        return step * cutback, count + 1

    alpha, _ = while_loop(notAccepted, shrink, (alpha, 0))
    print("alpha=", alpha)
    return alpha
def line_search_forwardtrack(x, f, g, p, alpha, F):
    """Forward-tracking line search: grow the step while it keeps working.

    Repeatedly multiplies ``alpha`` by a fixed growth factor while the
    *grown* step still satisfies the sufficient-decrease test
    ``F(x + growth*alpha*p) < f + c*g*p*alpha`` (note the right-hand side
    uses the current, un-grown alpha), up to 20 growth steps.

    Returns
    -------
    The accepted (possibly enlarged) step length.
    """
    growth = 1.0/0.2  # inverse of the backtracking cutback factor
    c = 0.01          # Armijo sufficient-decrease constant
    cgp = c*g*p

    def canGrow(state):
        step, count = state
        return (F(x + growth*step*p) < f + cgp*step) & (count < 20)

    def grow(state):
        step, count = state
        return step * growth, count + 1

    alpha, _ = while_loop(canGrow, grow, (alpha, 0))
    print("alpha=", alpha)
    return alpha
def while_loop(cond_fun, body_fun, init_val):
    """Pure-Python stand-in for ``jax.lax.while_loop``.

    Applies ``body_fun`` to the carried value while ``cond_fun`` returns
    true, starting from ``init_val``, and returns the final carry.
    """
    state = init_val
    while True:
        if not cond_fun(state):
            return state
        state = body_fun(state)