Skip to content

Commit 068eb49

Browse files
committed
Setting up preconditioner within Reactor function
1 parent 3c35141 commit 068eb49

1 file changed

Lines changed: 53 additions & 6 deletions

File tree

src/Reactor.jl

Lines changed: 53 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,28 +16,52 @@ struct Reactor{D,Q} <: AbstractReactor
1616
forwardsensitivities::Bool
1717
end
1818

19-
function Reactor(domain::T,y0::Array{W,1},tspan::Tuple,interfaces::Z=[];p::X=DiffEqBase.NullParameters(),forwardsensitivities=false,forwarddiff=false,modelingtoolkit=false) where {T<:AbstractDomain,W<:Real,Z<:AbstractArray,X}
19+
function Reactor(domain::T,y0::Array{T1,1},tspan::Tuple,interfaces::Z=[];p::X=DiffEqBase.NullParameters(),forwardsensitivities=false,forwarddiff=false,modelingtoolkit=false,tau=1e-3) where {T<:AbstractDomain,T1<:Real,Z<:AbstractArray,X}
2020
dydt(dy::X,y::T,p::V,t::Q) where {X,T,Q,V} = dydtreactor!(dy,y,t,domain,interfaces,p=p)
2121
jacy!(J::Q2,y::T,p::V,t::Q) where {Q2,T,Q,V} = jacobiany!(J,y,p,t,domain,interfaces,nothing)
2222
jacyforwarddiff!(J::Q2,y::T,p::V,t::Q) where {Q2,T,Q,V} = jacobianyforwarddiff!(J,y,p,t,domain,interfaces,nothing)
2323
jacp!(J::Q2,y::T,p::V,t::Q) where {Q2,T,Q,V} = jacobianp!(J,y,p,t,domain,interfaces,nothing)
2424
jacpforwarddiff!(J::Q2,y::T,p::V,t::Q) where {Q2,T,Q,V} = jacobianpforwarddiff!(J,y,p,t,domain,interfaces,nothing)
2525

26+
psetupsundials(p::T1, t::T2, u::T3, du::T4, jok::Bool, jcurPtr::T5, gamma::T6) where {T1,T2,T3,T4,T5,T6} = _psetup(p, t, u, du, jok, jcurPtr, gamma, jacy!, W::SparseMatrixCSC{Float64, Int64}, preccache::Base.RefValue{IncompleteLU.ILUFactorization{Float64, Int64}}, tau::Float64)
27+
precsundials(z::T1, r::T2, p::T3, t::T4, y::T5, fy::T6, gamma::T7, delta::T8, lr::T9) where {T1,T2,T3,T4,T5,T6,T7,T8,T9} = _prec(z, r, p, t, y, fy, gamma, delta, lr, preccache)
28+
precsjulia(W::T1,du::T2,u::T3,p::T4,t::T5,newW::T6,Plprev::T7,Prprev::T8,solverdata::T9) where {T1,T2,T3,T4,T5,T6,T7,T8,T9} = _precsjulia(W,du,u,p,t,newW,Plprev,Prprev,solverdata,tau)
29+
30+
# determine worst sparsity
31+
y0length = length(y0)
32+
J = spzeros(y0length,y0length)
33+
jacyforwarddiff!(J,NaN*ones(y0length),p,0.0)
34+
@. J.nzval = 1.0
35+
sparsity = 1.0 - length(J.nzval)/(y0length*y0length)
36+
37+
# preconditioner caches for Sundials solver
38+
W = spzeros(y0length,y0length)
39+
jacyforwarddiff!(W,y0,p,0.0)
40+
@. W.nzval = -1.0*W.nzval
41+
idxs = diagind(W)
42+
@inbounds @views @. W[idxs] = W[idxs] + 1
43+
prectmp = ilu(W, τ = tau)
44+
preccache = Ref(prectmp)
45+
2646
if (forwardsensitivities || !forwarddiff) && domain isa Union{ConstantTPDomain,ConstantVDomain,ConstantPDomain,ParametrizedTPDomain,ParametrizedVDomain,ParametrizedPDomain,ConstantTVDomain,ParametrizedTConstantVDomain,ConstantTAPhiDomain}
2747
if !forwardsensitivities
2848
odefcn = ODEFunction(dydt;jac=jacy!,paramjac=jacp!)
2949
else
3050
odefcn = ODEFunction(dydt;paramjac=jacp!)
3151
end
3252
else
33-
odefcn = ODEFunction(dydt;jac=jacyforwarddiff!,paramjac=jacpforwarddiff!)
53+
odefcn = ODEFunction(dydt;jac=jacyforwarddiff!,paramjac=jacpforwarddiff!,jac_prototype=float.(J)) #jac_prototype is not needed/used for Sundials solvers but maybe needed for Julia solvers
3454
end
3555
if forwardsensitivities
3656
ode = ODEForwardSensitivityProblem(odefcn,y0,tspan,p)
3757
recsolver = Sundials.CVODE_BDF(linear_solver=:GMRES)
3858
else
3959
ode = ODEProblem(odefcn,y0,tspan,p)
40-
recsolver = Sundials.CVODE_BDF()
60+
if sparsity > 0.8 #empirical threshold to use preconditioner
61+
recsolver = Sundials.CVODE_BDF(linear_solver=:GMRES,prec=precsundials,psetup=psetupsundials,prec_side=1)
62+
else
63+
recsolver = Sundials.CVODE_BDF()
64+
end
4165
end
4266
if modelingtoolkit
4367
sys = modelingtoolkitize(ode)
@@ -55,7 +79,7 @@ function Reactor(domain::T,y0::Array{W,1},tspan::Tuple,interfaces::Z=[];p::X=Dif
5579
end
5680
return Reactor(domain,ode,recsolver,forwardsensitivities)
5781
end
58-
function Reactor(domains::T,y0s::W,tspan::W2,interfaces::Z=Tuple(),ps::X=DiffEqBase.NullParameters();forwardsensitivities=false,modelingtoolkit=false) where {T<:Tuple,W<:Tuple,Z,X,W2}
82+
function Reactor(domains::T,y0s::W1,tspan::W2,interfaces::Z=Tuple(),ps::X=DiffEqBase.NullParameters();forwardsensitivities=false,modelingtoolkit=false,tau=1e-3) where {T<:Tuple,W1<:Tuple,Z,X,W2}
5983
#adjust indexing
6084
y0 = zeros(sum(length(y) for y in y0s))
6185
Nvars = 0
@@ -140,6 +164,25 @@ function Reactor(domains::T,y0s::W,tspan::W2,interfaces::Z=Tuple(),ps::X=DiffEqB
140164
jacy!(J::Q2,y::T,p::V,t::Q) where {Q2,T,Q,V} = jacobianyforwarddiff!(J,y,p,t,domains,interfaces,nothing)
141165
jacp!(J::Q2,y::T,p::V,t::Q) where {Q2,T,Q,V} = jacobianpforwarddiff!(J,y,p,t,domains,interfaces,nothing)
142166

167+
psetupsundials(p::T1, t::T2, u::T3, du::T4, jok::Bool, jcurPtr::T5, gamma::T6) where {T1,T2,T3,T4,T5,T6} = _psetup(p, t, u, du, jok, jcurPtr, gamma, jacy!, W::SparseMatrixCSC{Float64, Int64}, preccache::Base.RefValue{IncompleteLU.ILUFactorization{Float64, Int64}}, tau::Float64)
168+
precsundials(z::T1, r::T2, p::T3, t::T4, y::T5, fy::T6, gamma::T7, delta::T8, lr::T9) where {T1,T2,T3,T4,T5,T6,T7,T8,T9} = _prec(z, r, p, t, y, fy, gamma, delta, lr, preccache)
169+
precsjulia(W::T1,du::T2,u::T3,p::T4,t::T5,newW::T6,Plprev::T7,Prprev::T8,solverdata::T9) where {T1,T2,T3,T4,T5,T6,T7,T8,T9} = _precsjulia(W,du,u,p,t,newW,Plprev,Prprev,solverdata,tau)
170+
171+
# determine worst sparsity
172+
y0length = length(y0)
173+
J = spzeros(y0length,y0length)
174+
jacy!(J,NaN*ones(y0length),p,0.0)
175+
@. J.nzval = 1.0
176+
sparsity = 1.0 - length(J.nzval)/(y0length*y0length)
177+
178+
# preconditioner caches for Sundials solver
179+
W = spzeros(y0length,y0length)
180+
jacy!(W,y0,p,0.0)
181+
@. W.nzval = -1.0*W.nzval
182+
idxs = diagind(W)
183+
@inbounds @views @. W[idxs] = W[idxs] + 1
184+
prectmp = ilu(W, τ = tau)
185+
preccache = Ref(prectmp)
143186

144187
if forwardsensitivities
145188
odefcn = ODEFunction(dydt;paramjac=jacp!)
@@ -152,9 +195,13 @@ function Reactor(domains::T,y0s::W,tspan::W2,interfaces::Z=Tuple(),ps::X=DiffEqB
152195
ode = ODEForwardSensitivityProblem(odefcn,y0,tspan,p)
153196
end
154197
else
155-
odefcn = ODEFunction(dydt;jac=jacy!,paramjac=jacp!)
198+
odefcn = ODEFunction(dydt;jac=jacy!,paramjac=jacp!,jac_prototype=float.(J))
156199
ode = ODEProblem(odefcn,y0,tspan,p)
157-
recsolver = Sundials.CVODE_BDF()
200+
if sparsity > 0.8 #empirical threshold to use preconditioner
201+
recsolver = Sundials.CVODE_BDF(linear_solver=:GMRES,prec=precsundials,psetup=psetupsundials,prec_side=1)
202+
else
203+
recsolver = Sundials.CVODE_BDF()
204+
end
158205
if modelingtoolkit
159206
sys = modelingtoolkitize(ode)
160207
jac = eval(ModelingToolkit.generate_jacobian(sys)[2])

0 commit comments

Comments
 (0)