Skip to content

Commit b10f27e

Browse files
committed
Add iterator interface
1 parent bc76368 commit b10f27e

File tree

3 files changed

+143
-44
lines changed

3 files changed

+143
-44
lines changed

src/multivariate/optimize/interface.jl

Lines changed: 28 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -53,85 +53,91 @@ promote_objtype(method::ZerothOrderOptimizer, x, autodiff::Symbol, inplace::Bool
5353
promote_objtype(method::FirstOrderOptimizer, x, autodiff::Symbol, inplace::Bool, td::TwiceDifferentiable) = td
5454
promote_objtype(method::SecondOrderOptimizer, x, autodiff::Symbol, inplace::Bool, td::TwiceDifferentiable) = td
5555

56+
for optimize in [:optimize, :optimizing]
57+
@eval begin
58+
5659
# if no method or options are present
57-
function optimize(f, initial_x::AbstractArray; inplace = true, autodiff = :finite, kwargs...)
60+
function $optimize(f, initial_x::AbstractArray; inplace = true, autodiff = :finite, kwargs...)
5861
method = fallback_method(f)
5962
checked_kwargs, method = check_kwargs(kwargs, method)
6063
d = promote_objtype(method, initial_x, autodiff, inplace, f)
6164
add_default_opts!(checked_kwargs, method)
6265

6366
options = Options(; checked_kwargs...)
64-
optimize(d, initial_x, method, options)
67+
$optimize(d, initial_x, method, options)
6568
end
66-
function optimize(f, g, initial_x::AbstractArray; inplace = true, autodiff = :finite, kwargs...)
69+
function $optimize(f, g, initial_x::AbstractArray; inplace = true, autodiff = :finite, kwargs...)
6770

6871
method = fallback_method(f, g)
6972
checked_kwargs, method = check_kwargs(kwargs, method)
7073
d = promote_objtype(method, initial_x, autodiff, inplace, f, g)
7174
add_default_opts!(checked_kwargs, method)
7275

7376
options = Options(; checked_kwargs...)
74-
optimize(d, initial_x, method, options)
77+
$optimize(d, initial_x, method, options)
7578
end
76-
function optimize(f, g, h, initial_x::AbstractArray; inplace = true, autodiff = :finite, kwargs...)
79+
function $optimize(f, g, h, initial_x::AbstractArray; inplace = true, autodiff = :finite, kwargs...)
7780

7881
method = fallback_method(f, g, h)
7982
checked_kwargs, method = check_kwargs(kwargs, method)
8083
d = promote_objtype(method, initial_x, autodiff, inplace, f, g, h)
8184
add_default_opts!(checked_kwargs, method)
8285

8386
options = Options(; checked_kwargs...)
84-
optimize(d, initial_x, method, options)
87+
$optimize(d, initial_x, method, options)
8588
end
8689

8790
# no method supplied with objective
88-
function optimize(d::T, initial_x::AbstractArray, options::Options) where T<:AbstractObjective
89-
optimize(d, initial_x, fallback_method(d), options)
91+
function $optimize(d::T, initial_x::AbstractArray, options::Options) where T<:AbstractObjective
92+
$optimize(d, initial_x, fallback_method(d), options)
9093
end
9194
# no method supplied with inplace and autodiff keywords becauase objective is not supplied
92-
function optimize(f, initial_x::AbstractArray, options::Options; inplace = true, autodiff = :finite)
95+
function $optimize(f, initial_x::AbstractArray, options::Options; inplace = true, autodiff = :finite)
9396
method = fallback_method(f)
9497
d = promote_objtype(method, initial_x, autodiff, inplace, f)
95-
optimize(d, initial_x, method, options)
98+
$optimize(d, initial_x, method, options)
9699
end
97-
function optimize(f, g, initial_x::AbstractArray, options::Options; inplace = true, autodiff = :finite)
100+
function $optimize(f, g, initial_x::AbstractArray, options::Options; inplace = true, autodiff = :finite)
98101

99102
method = fallback_method(f, g)
100103
d = promote_objtype(method, initial_x, autodiff, inplace, f, g)
101-
optimize(d, initial_x, method, options)
104+
$optimize(d, initial_x, method, options)
102105
end
103-
function optimize(f, g, h, initial_x::AbstractArray{T}, options::Options; inplace = true, autodiff = :finite) where {T}
106+
function $optimize(f, g, h, initial_x::AbstractArray{T}, options::Options; inplace = true, autodiff = :finite) where {T}
104107

105108
method = fallback_method(f, g, h)
106109
d = promote_objtype(method, initial_x, autodiff, inplace, f, g, h)
107110

108-
optimize(d, initial_x, method, options)
111+
$optimize(d, initial_x, method, options)
109112
end
110113

111114
# potentially everything is supplied (besides caches)
112-
function optimize(f, initial_x::AbstractArray, method::AbstractOptimizer,
115+
function $optimize(f, initial_x::AbstractArray, method::AbstractOptimizer,
113116
options::Options = Options(;default_options(method)...); inplace = true, autodiff = :finite)
114117

115118
d = promote_objtype(method, initial_x, autodiff, inplace, f)
116-
optimize(d, initial_x, method, options)
119+
$optimize(d, initial_x, method, options)
117120
end
118-
function optimize(f, g, initial_x::AbstractArray, method::AbstractOptimizer,
121+
function $optimize(f, g, initial_x::AbstractArray, method::AbstractOptimizer,
119122
options::Options = Options(;default_options(method)...); inplace = true, autodiff = :finite)
120123

121124
d = promote_objtype(method, initial_x, autodiff, inplace, f, g)
122125

123-
optimize(d, initial_x, method, options)
126+
$optimize(d, initial_x, method, options)
124127
end
125-
function optimize(f, g, h, initial_x::AbstractArray{T}, method::AbstractOptimizer,
128+
function $optimize(f, g, h, initial_x::AbstractArray{T}, method::AbstractOptimizer,
126129
options::Options = Options(;default_options(method)...); inplace = true, autodiff = :finite) where T
127130

128131
d = promote_objtype(method, initial_x, autodiff, inplace, f, g, h)
129132

130-
optimize(d, initial_x, method, options)
133+
$optimize(d, initial_x, method, options)
131134
end
132135

133-
function optimize(d::D, initial_x::AbstractArray, method::SecondOrderOptimizer,
136+
function $optimize(d::D, initial_x::AbstractArray, method::SecondOrderOptimizer,
134137
options::Options = Options(;default_options(method)...); autodiff = :finite, inplace = true) where {D <: Union{NonDifferentiable, OnceDifferentiable}}
135138
d = promote_objtype(method, initial_x, autodiff, inplace, d)
136-
optimize(d, initial_x, method, options)
139+
$optimize(d, initial_x, method, options)
137140
end
141+
142+
end # eval
143+
end # for

src/multivariate/optimize/optimize.jl

Lines changed: 109 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -27,36 +27,73 @@ function initial_convergence(d, state, method::AbstractOptimizer, initial_x, opt
2727
end
2828
initial_convergence(d, state, method::ZerothOrderOptimizer, initial_x, options) = false
2929

30-
function optimize(d::D, initial_x::Tx, method::M,
31-
options::Options{T, TCallback} = Options(;default_options(method)...),
32-
state = initial_state(method, options, d, initial_x)) where {D<:AbstractObjective, M<:AbstractOptimizer, Tx <: AbstractArray, T, TCallback}
33-
if length(initial_x) == 1 && typeof(method) <: NelderMead
34-
error("You cannot use NelderMead for univariate problems. Alternatively, use either interval bound univariate optimization, or another method such as BFGS or Newton.")
35-
end
30+
struct OptimIterator{D <: AbstractObjective, M <: AbstractOptimizer, Tx <: AbstractArray, O <: Options, S}
31+
d::D
32+
initial_x::Tx
33+
method::M
34+
options::O
35+
state::S
36+
end
37+
38+
Base.IteratorSize(::Type{<:OptimIterator}) = Base.SizeUnknown()
39+
Base.IteratorEltype(::Type{<:OptimIterator}) = Base.HasEltype()
40+
Base.eltype(::Type{<:OptimIterator}) = IteratorState
41+
42+
@with_kw struct IteratorState{IT <: OptimIterator, TR <: OptimizationTrace}
43+
# Put `OptimIterator` in iterator state so that `OptimizationResults` can
44+
# be constructed from `IteratorState`.
45+
iter::IT
46+
47+
t0::Float64
48+
tr::TR
49+
tracing::Bool
50+
stopped::Bool
51+
stopped_by_callback::Bool
52+
stopped_by_time_limit::Bool
53+
f_limit_reached::Bool
54+
g_limit_reached::Bool
55+
h_limit_reached::Bool
56+
x_converged::Bool
57+
f_converged::Bool
58+
f_increased::Bool
59+
counter_f_tol::Int
60+
g_converged::Bool
61+
converged::Bool
62+
iteration::Int
63+
ls_success::Bool
64+
end
65+
66+
function Base.iterate(iter::OptimIterator, istate = nothing)
67+
@unpack d, initial_x, method, options, state = iter
68+
if istate === nothing
69+
t0 = time() # Initial time stamp used to control early stopping by options.time_limit
70+
71+
tr = OptimizationTrace{typeof(value(d)), typeof(method)}()
72+
tracing = options.store_trace || options.show_trace || options.extended_trace || options.callback != nothing
73+
stopped, stopped_by_callback, stopped_by_time_limit = false, false, false
74+
f_limit_reached, g_limit_reached, h_limit_reached = false, false, false
75+
x_converged, f_converged, f_increased, counter_f_tol = false, false, false, 0
3676

37-
t0 = time() # Initial time stamp used to control early stopping by options.time_limit
77+
g_converged = initial_convergence(d, state, method, initial_x, options)
78+
converged = g_converged
3879

39-
tr = OptimizationTrace{typeof(value(d)), typeof(method)}()
40-
tracing = options.store_trace || options.show_trace || options.extended_trace || options.callback != nothing
41-
stopped, stopped_by_callback, stopped_by_time_limit = false, false, false
42-
f_limit_reached, g_limit_reached, h_limit_reached = false, false, false
43-
x_converged, f_converged, f_increased, counter_f_tol = false, false, false, 0
80+
# prepare iteration counter (used to make "initial state" trace entry)
81+
iteration = 0
4482

45-
g_converged = initial_convergence(d, state, method, initial_x, options)
46-
converged = g_converged
83+
options.show_trace && print_header(method)
84+
trace!(tr, d, state, iteration, method, options, time()-t0)
85+
ls_success::Bool = true
86+
else
87+
@unpack_IteratorState istate
4788

48-
# prepare iteration counter (used to make "initial state" trace entry)
49-
iteration = 0
89+
!converged && !stopped && iteration < options.iterations || return nothing
5090

51-
options.show_trace && print_header(method)
52-
trace!(tr, d, state, iteration, method, options, time()-t0)
53-
ls_success::Bool = true
54-
while !converged && !stopped && iteration < options.iterations
5591
iteration += 1
5692

5793
ls_failed = update_state!(d, state, method)
5894
if !ls_success
59-
break # it returns true if it's forced by something in update! to stop (eg dx_dg == 0.0 in BFGS, or linesearch errors)
95+
# it returns true if it's forced by something in update! to stop (eg dx_dg == 0.0 in BFGS, or linesearch errors)
96+
return nothing
6097
end
6198
update_g!(d, state, method) # TODO: Should this be `update_fg!`?
6299

@@ -85,7 +122,35 @@ function optimize(d::D, initial_x::Tx, method::M,
85122
stopped_by_time_limit || f_limit_reached || g_limit_reached || h_limit_reached
86123
stopped = true
87124
end
88-
end # while
125+
end
126+
127+
new_istate = IteratorState(
128+
iter,
129+
t0,
130+
tr,
131+
tracing,
132+
stopped,
133+
stopped_by_callback,
134+
stopped_by_time_limit,
135+
f_limit_reached,
136+
g_limit_reached,
137+
h_limit_reached,
138+
x_converged,
139+
f_converged,
140+
f_increased,
141+
counter_f_tol,
142+
g_converged,
143+
converged,
144+
iteration,
145+
ls_success,
146+
)
147+
148+
return new_istate, new_istate
149+
end
150+
151+
function OptimizationResults(istate::IteratorState)
152+
@unpack_IteratorState istate
153+
@unpack d, initial_x, method, options, state = iter
89154

90155
after_while!(d, state, method, options)
91156

@@ -94,6 +159,9 @@ function optimize(d::D, initial_x::Tx, method::M,
94159
Tf = typeof(value(d))
95160
f_incr_pick = f_increased && !options.allow_f_increases
96161

162+
T = (_tmp(::Options{T}) where T = T)(options)
163+
Tx = typeof(initial_x)
164+
97165
return MultivariateOptimizationResults{typeof(method),T,Tx,typeof(x_abschange(state)),Tf,typeof(tr), Bool}(method,
98166
initial_x,
99167
pick_best_x(f_incr_pick, state),
@@ -120,3 +188,22 @@ function optimize(d::D, initial_x::Tx, method::M,
120188
h_calls(d),
121189
!ls_success)
122190
end
191+
192+
function optimizing(d::D, initial_x::Tx, method::M,
193+
options::Options = Options(;default_options(method)...),
194+
state = initial_state(method, options, d, initial_x)) where {D<:AbstractObjective, M<:AbstractOptimizer, Tx <: AbstractArray}
195+
if length(initial_x) == 1 && typeof(method) <: NelderMead
196+
error("You cannot use NelderMead for univariate problems. Alternatively, use either interval bound univariate optimization, or another method such as BFGS or Newton.")
197+
end
198+
return OptimIterator(d, initial_x, method, options, state)
199+
end
200+
201+
function optimize(d::D, initial_x::Tx, method::M,
202+
options::Options = Options(;default_options(method)...),
203+
state = initial_state(method, options, d, initial_x)) where {D<:AbstractObjective, M<:AbstractOptimizer, Tx <: AbstractArray}
204+
local istate
205+
for istate′ in optimizing(d, initial_x, method, options, state)
206+
istate = istate′
207+
end
208+
return OptimizationResults(istate)
209+
end

test/general/api.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,12 @@
144144
res_extended_nm = Optim.optimize(f, g!, initial_x, NelderMead(), options_extended_nm)
145145
@test haskey(Optim.trace(res_extended_nm)[1].metadata,"centroid")
146146
@test haskey(Optim.trace(res_extended_nm)[1].metadata,"step_type")
147+
148+
local istate
149+
for istate′ in Optim.optimizing(f, initial_x, BFGS())
150+
istate = istate′
151+
end
152+
@test Optim.OptimizationResults(istate) isa Optim.MultivariateOptimizationResults
147153
end
148154

149155
# Test univariate API

0 commit comments

Comments
 (0)