Skip to content

Commit 30d3756

Browse files
authored
Merge pull request #185 from ReactionMechanismGenerator/precondition
Use preconditioner if sparsity is larger than an empirical threshold
2 parents 2053510 + b1eeaaa commit 30d3756

2 files changed

Lines changed: 131 additions & 9 deletions

File tree

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ DiffEqSensitivity = "41bf760c-e81c-5289-8e54-58b1f1f8abe2"
1313
FastGaussQuadrature = "442a2c76-b920-505d-bb47-c5924d526838"
1414
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
1515
Images = "916415d5-f1e6-5110-898d-aaa5f9f070e0"
16+
IncompleteLU = "40713840-3770-5561-ab4c-a76e7d0d7895"
1617
IterTools = "c8e1da08-722c-5040-9ed9-7db0dc04731e"
1718
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1819
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
@@ -46,6 +47,7 @@ DiffEqBase = "6"
4647
DiffEqSensitivity = "6"
4748
ForwardDiff = "0.10"
4849
Images = "0.23"
50+
IncompleteLU = "0.2.0"
4951
IterTools = "1.3.0"
5052
LsqFit = "0.12"
5153
ModelingToolkit = "3"

src/Reactor.jl

Lines changed: 129 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,38 +3,68 @@ using DiffEqBase
33
using ForwardDiff
44
using Sundials
55
using ModelingToolkit
6+
using IncompleteLU
7+
using LinearAlgebra
8+
using SparseArrays
69
abstract type AbstractReactor end
710
export AbstractReactor
811

9-
struct Reactor{D,Q} <: AbstractReactor
12+
struct Reactor{D,Q,F1,F2,F3} <: AbstractReactor
1013
domain::D
1114
ode::ODEProblem
1215
recommendedsolver::Q
1316
forwardsensitivities::Bool
17+
precsundials::F1 #function to calculate preconditioner for Sundials solvers
18+
psetupsundials::F2 #function to compute preconditioner \ residue for Sundials solvers
19+
precsjulia::F3 #function to calculate preconditioner for Julia solvers
1420
end
1521

16-
function Reactor(domain::T,y0::Array{W,1},tspan::Tuple,interfaces::Z=[];p::X=DiffEqBase.NullParameters(),forwardsensitivities=false,forwarddiff=false,modelingtoolkit=false) where {T<:AbstractDomain,W<:Real,Z<:AbstractArray,X}
22+
function Reactor(domain::T,y0::Array{T1,1},tspan::Tuple,interfaces::Z=[];p::X=DiffEqBase.NullParameters(),forwardsensitivities=false,forwarddiff=false,modelingtoolkit=false,tau=1e-3) where {T<:AbstractDomain,T1<:Real,Z<:AbstractArray,X}
1723
dydt(dy::X,y::T,p::V,t::Q) where {X,T,Q,V} = dydtreactor!(dy,y,t,domain,interfaces,p=p)
1824
jacy!(J::Q2,y::T,p::V,t::Q) where {Q2,T,Q,V} = jacobiany!(J,y,p,t,domain,interfaces,nothing)
1925
jacyforwarddiff!(J::Q2,y::T,p::V,t::Q) where {Q2,T,Q,V} = jacobianyforwarddiff!(J,y,p,t,domain,interfaces,nothing)
2026
jacp!(J::Q2,y::T,p::V,t::Q) where {Q2,T,Q,V} = jacobianp!(J,y,p,t,domain,interfaces,nothing)
2127
jacpforwarddiff!(J::Q2,y::T,p::V,t::Q) where {Q2,T,Q,V} = jacobianpforwarddiff!(J,y,p,t,domain,interfaces,nothing)
2228

29+
psetupsundials(p::T1, t::T2, u::T3, du::T4, jok::Bool, jcurPtr::T5, gamma::T6) where {T1,T2,T3,T4,T5,T6} = _psetup(p, t, u, du, jok, jcurPtr, gamma, jacy!, W::SparseMatrixCSC{Float64, Int64}, preccache::Base.RefValue{IncompleteLU.ILUFactorization{Float64, Int64}}, tau::Float64)
30+
precsundials(z::T1, r::T2, p::T3, t::T4, y::T5, fy::T6, gamma::T7, delta::T8, lr::T9) where {T1,T2,T3,T4,T5,T6,T7,T8,T9} = _prec(z, r, p, t, y, fy, gamma, delta, lr, preccache)
31+
precsjulia(W::T1,du::T2,u::T3,p::T4,t::T5,newW::T6,Plprev::T7,Prprev::T8,solverdata::T9) where {T1,T2,T3,T4,T5,T6,T7,T8,T9} = _precsjulia(W,du,u,p,t,newW,Plprev,Prprev,solverdata,tau)
32+
33+
# determine worst sparsity
34+
y0length = length(y0)
35+
J = spzeros(y0length,y0length)
36+
jacyforwarddiff!(J,NaN*ones(y0length),p,0.0)
37+
@. J.nzval = 1.0
38+
sparsity = 1.0 - length(J.nzval)/(y0length*y0length)
39+
40+
# preconditioner caches for Sundials solver
41+
W = spzeros(y0length,y0length)
42+
jacyforwarddiff!(W,y0,p,0.0)
43+
@. W.nzval = -1.0*W.nzval
44+
idxs = diagind(W)
45+
@inbounds @views @. W[idxs] = W[idxs] + 1
46+
prectmp = ilu(W, τ = tau)
47+
preccache = Ref(prectmp)
48+
2349
if (forwardsensitivities || !forwarddiff) && domain isa Union{ConstantTPDomain,ConstantVDomain,ConstantPDomain,ParametrizedTPDomain,ParametrizedVDomain,ParametrizedPDomain,ConstantTVDomain,ParametrizedTConstantVDomain,ConstantTAPhiDomain}
2450
if !forwardsensitivities
2551
odefcn = ODEFunction(dydt;jac=jacy!,paramjac=jacp!)
2652
else
2753
odefcn = ODEFunction(dydt;paramjac=jacp!)
2854
end
2955
else
30-
odefcn = ODEFunction(dydt;jac=jacyforwarddiff!,paramjac=jacpforwarddiff!)
56+
odefcn = ODEFunction(dydt;jac=jacyforwarddiff!,paramjac=jacpforwarddiff!,jac_prototype=float.(J)) #jac_prototype is not needed/used for Sundials solvers but maybe needed for Julia solvers
3157
end
3258
if forwardsensitivities
3359
ode = ODEForwardSensitivityProblem(odefcn,y0,tspan,p)
3460
recsolver = Sundials.CVODE_BDF(linear_solver=:GMRES)
3561
else
3662
ode = ODEProblem(odefcn,y0,tspan,p)
37-
recsolver = Sundials.CVODE_BDF()
63+
if sparsity > 0.8 #empirical threshold to use preconditioner
64+
recsolver = Sundials.CVODE_BDF(linear_solver=:GMRES,prec=precsundials,psetup=psetupsundials,prec_side=1)
65+
else
66+
recsolver = Sundials.CVODE_BDF()
67+
end
3868
end
3969
if modelingtoolkit
4070
sys = modelingtoolkitize(ode)
@@ -50,9 +80,9 @@ function Reactor(domain::T,y0::Array{W,1},tspan::Tuple,interfaces::Z=[];p::X=Dif
5080
ode = ODEProblem(odefcn,y0,tspan,p)
5181
end
5282
end
53-
return Reactor(domain,ode,recsolver,forwardsensitivities)
83+
return Reactor(domain,ode,recsolver,forwardsensitivities,precsundials,psetupsundials,precsjulia)
5484
end
55-
function Reactor(domains::T,y0s::W,tspan::W2,interfaces::Z=Tuple(),ps::X=DiffEqBase.NullParameters();forwardsensitivities=false,modelingtoolkit=false) where {T<:Tuple,W<:Tuple,Z,X,W2}
85+
function Reactor(domains::T,y0s::W1,tspan::W2,interfaces::Z=Tuple(),ps::X=DiffEqBase.NullParameters();forwardsensitivities=false,modelingtoolkit=false,tau=1e-3) where {T<:Tuple,W1<:Tuple,Z,X,W2}
5686
#adjust indexing
5787
y0 = zeros(sum(length(y) for y in y0s))
5888
Nvars = 0
@@ -137,6 +167,25 @@ function Reactor(domains::T,y0s::W,tspan::W2,interfaces::Z=Tuple(),ps::X=DiffEqB
137167
jacy!(J::Q2,y::T,p::V,t::Q) where {Q2,T,Q,V} = jacobianyforwarddiff!(J,y,p,t,domains,interfaces,nothing)
138168
jacp!(J::Q2,y::T,p::V,t::Q) where {Q2,T,Q,V} = jacobianpforwarddiff!(J,y,p,t,domains,interfaces,nothing)
139169

170+
psetupsundials(p::T1, t::T2, u::T3, du::T4, jok::Bool, jcurPtr::T5, gamma::T6) where {T1,T2,T3,T4,T5,T6} = _psetup(p, t, u, du, jok, jcurPtr, gamma, jacy!, W::SparseMatrixCSC{Float64, Int64}, preccache::Base.RefValue{IncompleteLU.ILUFactorization{Float64, Int64}}, tau::Float64)
171+
precsundials(z::T1, r::T2, p::T3, t::T4, y::T5, fy::T6, gamma::T7, delta::T8, lr::T9) where {T1,T2,T3,T4,T5,T6,T7,T8,T9} = _prec(z, r, p, t, y, fy, gamma, delta, lr, preccache)
172+
precsjulia(W::T1,du::T2,u::T3,p::T4,t::T5,newW::T6,Plprev::T7,Prprev::T8,solverdata::T9) where {T1,T2,T3,T4,T5,T6,T7,T8,T9} = _precsjulia(W,du,u,p,t,newW,Plprev,Prprev,solverdata,tau)
173+
174+
# determine worst sparsity
175+
y0length = length(y0)
176+
J = spzeros(y0length,y0length)
177+
jacy!(J,NaN*ones(y0length),p,0.0)
178+
@. J.nzval = 1.0
179+
sparsity = 1.0 - length(J.nzval)/(y0length*y0length)
180+
181+
# preconditioner caches for Sundials solver
182+
W = spzeros(y0length,y0length)
183+
jacy!(W,y0,p,0.0)
184+
@. W.nzval = -1.0*W.nzval
185+
idxs = diagind(W)
186+
@inbounds @views @. W[idxs] = W[idxs] + 1
187+
prectmp = ilu(W, τ = tau)
188+
preccache = Ref(prectmp)
140189

141190
if forwardsensitivities
142191
odefcn = ODEFunction(dydt;paramjac=jacp!)
@@ -149,20 +198,91 @@ function Reactor(domains::T,y0s::W,tspan::W2,interfaces::Z=Tuple(),ps::X=DiffEqB
149198
ode = ODEForwardSensitivityProblem(odefcn,y0,tspan,p)
150199
end
151200
else
152-
odefcn = ODEFunction(dydt;jac=jacy!,paramjac=jacp!)
201+
odefcn = ODEFunction(dydt;jac=jacy!,paramjac=jacp!,jac_prototype=float.(J))
153202
ode = ODEProblem(odefcn,y0,tspan,p)
154-
recsolver = Sundials.CVODE_BDF()
203+
if sparsity > 0.8 #empirical threshold to use preconditioner
204+
recsolver = Sundials.CVODE_BDF(linear_solver=:GMRES,prec=precsundials,psetup=psetupsundials,prec_side=1)
205+
else
206+
recsolver = Sundials.CVODE_BDF()
207+
end
155208
if modelingtoolkit
156209
sys = modelingtoolkitize(ode)
157210
jac = eval(ModelingToolkit.generate_jacobian(sys)[2])
158211
odefcn = ODEFunction(dydt;jac=jac,paramjac=jacp!)
159212
ode = ODEProblem(odefcn,y0,tspan,p)
160213
end
161214
end
162-
return Reactor(domains,ode,recsolver,forwardsensitivities),y0,p
215+
return Reactor(domains,ode,recsolver,forwardsensitivities,precsundials,psetupsundials,precsjulia),y0,p
163216
end
164217
export Reactor
165218

219+
#preconditioner related functions
220+
@inline function _psetupsundials(p::T1, t::T2, u::T3, du::T4, jok::Bool, jcurPtr::T5, gamma::T6, jac!::T7, W::T8, preccache::T9, tau::T10) where {T1,T2,T3,T4,T5,T6,T7,T8,T9,T10}
221+
"""
222+
Update preconditioner when Jacobian needs to be updated for Sundials solvers. Credit to tutorial of DifferentialEquations.jl.
223+
p: the parameters
224+
t: the current independent variable
225+
u: the current state
226+
du: the current f(u,p,t)
227+
jok: a bool indicating whether the Jacobian needs to be updated
228+
jcurPtr: a reference to an Int for whether the Jacobian was updated. jcurPtr[]=true should be set if the Jacobian was updated, and jcurPtr[]=false should be set if the Jacobian was not updated.
229+
gamma: the gamma of W = M - gamma*J
230+
"""
231+
if jok
232+
@. W = 0.0
233+
jac!(W,u,p,t)
234+
jcurPtr[] = true
235+
236+
# W = I - gamma*J
237+
@. W.nzval = -gamma*W.nzval
238+
idxs = diagind(W)
239+
@inbounds @views @. W[idxs] = W[idxs] + 1
240+
241+
# Build preconditioner on W
242+
preccache[] = ilu(W, τ = tau)
243+
end
244+
nothing
245+
end
246+
@inline function _precsundials(z::T1, r::T2, p::T3, t::T4, y::T5, fy::T6, gamma::T7, delta::T8, lr::T9, preccache::T10) where {T1,T2,T3,T4,T5,T6,T7,T8,T9,T10}
247+
"""
248+
Compute preccache \\ r in-place and store the result in z for Sundials solver. Credit to tutorial of DifferentialEquations.jl.
249+
z: the computed output vector
250+
r: the right-hand side vector of the linear system
251+
p: the parameters
252+
t: the current independent variable
253+
du: the current value of f(u,p,t)
254+
gamma: the gamma of W = M - gamma*J
255+
delta: the iterative method tolerance
256+
lr: a flag for whether lr=1 (left) or lr=2 (right) preconditioning
257+
preccache: preconditioner cache
258+
"""
259+
ldiv!(z,preccache[],r)
260+
end
261+
@inline function _precsjulia(W::T1,du::T2,u::T3,p::T4,t::T5,newW::T6,Plprev::T7,Prprev::T8,solverdata::T9,tau::T10) where {T1,T2,T3,T4,T5,T6,T7,T8,T9,T10}
262+
"""
263+
Update preconditioner when Jacobian needs to be updated for Julia solvers. Credit to tutorial of DifferentialEquations.jl.
264+
W: I - gamma*J or I/gamma - J depending on the algorithm.
265+
Commonly be a WOperator type defined by OrdinaryDiffEq.jl. It is a lazy representation of the operator
266+
Users can construct the W-matrix on demand by calling convert(AbstractMatrix,W) to receive an AbstractMatrix matching the jac_prototype.
267+
du: the current ODE derivative
268+
u: the current ODE state
269+
p: the ODE parameters
270+
t: the current ODE time
271+
newW: a Bool which specifies whether the W matrix has been updated since the last call to precs.
272+
It is recommended that this is checked to only update the preconditioner when newW == true.
273+
Plprev: the previous Pl.
274+
Prprev: the previous Pr.
275+
solverdata: Optional extra data the solvers can give to the precs function. Solver-dependent and subject to change.
276+
"""
277+
if newW === nothing || newW
278+
Pl = ilu(convert(AbstractMatrix,W), τ = tau)
279+
else
280+
Pl = Plprev
281+
end
282+
Pl,nothing
283+
end
284+
285+
166286
@inline function getrate(rxn::T,cs::Array{W,1},kfs::Array{Q,1},krevs::Array{Q,1}) where {T<:AbstractReaction,Q,W<:Real}
167287
Nreact = length(rxn.reactantinds)
168288
Nprod = length(rxn.productinds)

0 commit comments

Comments
 (0)