Skip to content

Commit 8f4b128

Browse files
authored
Merge pull request #97 from ReactionMechanismGenerator/adjoint_optimize
Speed up Adjoint Sensitivities
2 parents 6eeadc2 + 019e33d commit 8f4b128

4 files changed

Lines changed: 211 additions & 36 deletions

File tree

src/Domain.jl

Lines changed: 102 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ using Parameters
22
using LinearAlgebra
33
using StaticArrays
44
using Calculus
5-
using SmoothingSplines
65
using DiffEqBase
76
using ForwardDiff
87
using Tracker
@@ -37,6 +36,7 @@ export AbstractVariableKDomain
3736
jacuptodate::MArray{Tuple{1},Bool,1,1}=MVector(false)
3837
t::MArray{Tuple{1},W2,1,1}=MVector(0.0)
3938
p::Array{W,1}
39+
thermovariabledict::Dict{String,Int64}
4040
end
4141
function ConstantTPDomain(;phase::E2,initialconds::Dict{X,X2},constantspecies::Array{X3,1}=Array{String,1}(),
4242
sparse::Bool=false,sensitivity::Bool=false) where {E<:Real,E2<:AbstractPhase,Q<:AbstractInterface,W<:Real,X,X2,X3}
@@ -97,7 +97,7 @@ function ConstantTPDomain(;phase::E2,initialconds::Dict{X,X2},constantspecies::A
9797
end
9898
rxnarray = getreactionindices(phase)
9999
return ConstantTPDomain(phase,[phase.species[1].index,phase.species[end].index,phase.species[end].index+1],[1,length(phase.species)+length(phase.reactions)],constspcinds,
100-
T,P,kfs,krevs,efficiencyinds,Gs,rxnarray,mu,diffs,jacobian,sensitivity,false,MVector(false),MVector(0.0),p), y0, p
100+
T,P,kfs,krevs,efficiencyinds,Gs,rxnarray,mu,diffs,jacobian,sensitivity,false,MVector(false),MVector(0.0),p, Dict("V"=>phase.species[end].index+1)), y0, p
101101
end
102102
export ConstantTPDomain
103103

@@ -114,6 +114,7 @@ export ConstantTPDomain
114114
jacuptodate::MArray{Tuple{1},Bool,1,1}=MVector(false)
115115
t::MArray{Tuple{1},W2,1,1}=MVector(0.0)
116116
p::Array{W,1}
117+
thermovariabledict::Dict{String,Int64}
117118
end
118119
function ConstantVDomain(;phase::Z,initialconds::Dict{X,E},constantspecies::Array{X2,1}=Array{String,1}(),
119120
sparse::Bool=false,sensitivity::Bool=false) where {E,X,X2,Z<:IdealGas,Q<:AbstractInterface}
@@ -164,7 +165,7 @@ function ConstantVDomain(;phase::Z,initialconds::Dict{X,E},constantspecies::Arra
164165
end
165166
rxnarray = getreactionindices(phase)
166167
return ConstantVDomain(phase,[phase.species[1].index,phase.species[end].index,phase.species[end].index+1,phase.species[end].index+2],[1,length(phase.species)+length(phase.reactions)],constspcinds,
167-
V,efficiencyinds,rxnarray,jacobian,sensitivity,MVector(false),MVector(0.0),p), y0, p
168+
V,efficiencyinds,rxnarray,jacobian,sensitivity,MVector(false),MVector(0.0),p,Dict("T"=>phase.species[end].index+1,"P"=>phase.species[end].index+2)), y0, p
168169
end
169170
export ConstantVDomain
170171

@@ -181,6 +182,7 @@ export ConstantVDomain
181182
jacuptodate::MArray{Tuple{1},Bool,1,1}=MVector(false)
182183
t::MArray{Tuple{1},W2,1,1}=MVector(0.0)
183184
p::Array{W,1}
185+
thermovariabledict::Dict{String,Int64}
184186
end
185187
function ConstantPDomain(;phase::Z,initialconds::Dict{X,E},constantspecies::Array{X2,1}=Array{String,1}(),
186188
sparse::Bool=false,sensitivity::Bool=false) where {E,X,X2,Z<:IdealGas,Q<:AbstractInterface}
@@ -231,7 +233,7 @@ function ConstantPDomain(;phase::Z,initialconds::Dict{X,E},constantspecies::Arra
231233
end
232234
rxnarray = getreactionindices(phase)
233235
return ConstantPDomain(phase,[phase.species[1].index,phase.species[end].index,phase.species[end].index+1,phase.species[end].index+2],[1,length(phase.species)+length(phase.reactions)],constspcinds,
234-
P,efficiencyinds,rxnarray,jacobian,sensitivity,MVector(false),MVector(0.0),p), y0, p
236+
P,efficiencyinds,rxnarray,jacobian,sensitivity,MVector(false),MVector(0.0),p,Dict("T"=>phase.species[end].index+1,"V"=>phase.species[end].index+2)), y0, p
235237
end
236238
export ConstantPDomain
237239

@@ -249,6 +251,7 @@ export ConstantPDomain
249251
jacuptodate::MArray{Tuple{1},Bool,1,1}=MVector(false)
250252
t::MArray{Tuple{1},W2,1,1}=MVector(0.0)
251253
p::Array{W,1}
254+
thermovariabledict::Dict{String,Int64}
252255
end
253256
function ParametrizedTPDomain(;phase::Z,initialconds::Dict{X,Any},constantspecies::Array{X2,1}=Array{String,1}(),
254257
sparse::Bool=false,sensitivity::Bool=false) where {X,X2,Z<:IdealGas,Q<:AbstractInterface}
@@ -311,7 +314,7 @@ function ParametrizedTPDomain(;phase::Z,initialconds::Dict{X,Any},constantspecie
311314
end
312315
rxnarray = getreactionindices(phase)
313316
return ParametrizedTPDomain(phase,[phase.species[1].index,phase.species[end].index,phase.species[end].index+1],[1,length(phase.species)+length(phase.reactions)],constspcinds,
314-
Tfcn,Pfcn,efficiencyinds,rxnarray,jacobian,sensitivity,MVector(false),MVector(0.0),p), y0, p
317+
Tfcn,Pfcn,efficiencyinds,rxnarray,jacobian,sensitivity,MVector(false),MVector(0.0),p,Dict("V"=>phase.species[end].index+1)), y0, p
315318
end
316319
export ParametrizedTPDomain
317320

@@ -328,6 +331,7 @@ export ParametrizedTPDomain
328331
jacuptodate::MArray{Tuple{1},Bool,1,1}=MVector(false)
329332
t::MArray{Tuple{1},W2,1,1}=MVector(0.0)
330333
p::Array{W,1}
334+
thermovariabledict::Dict{String,Int64}
331335
end
332336
function ParametrizedVDomain(;phase::Z,initialconds::Dict{X,Any},constantspecies::Array{X2,1}=Array{String,1}(),
333337
sparse::Bool=false,sensitivity::Bool=false) where {X,X2,E<:Real,Z<:IdealGas,Q<:AbstractInterface}
@@ -387,7 +391,7 @@ function ParametrizedVDomain(;phase::Z,initialconds::Dict{X,Any},constantspecies
387391
end
388392
rxnarray = getreactionindices(phase)
389393
return ParametrizedVDomain(phase,[phase.species[1].index,phase.species[end].index,phase.species[end].index+1,phase.species[end].index+2],[1,length(phase.species)+length(phase.reactions)],constspcinds,
390-
Vfcn,efficiencyinds,rxnarray,jacobian,sensitivity,MVector(false),MVector(0.0),p), y0, p
394+
Vfcn,efficiencyinds,rxnarray,jacobian,sensitivity,MVector(false),MVector(0.0),p,Dict("T"=>phase.species[end].index+1,"P"=>phase.species[end].index+2)), y0, p
391395
end
392396
export ParametrizedVDomain
393397

@@ -404,6 +408,7 @@ export ParametrizedVDomain
404408
jacuptodate::MArray{Tuple{1},Bool,1,1}=MVector(false)
405409
t::MArray{Tuple{1},W2,1,1}=MVector(0.0)
406410
p::Array{W,1}
411+
thermovariabledict::Dict{String,Int64}
407412
end
408413
function ParametrizedPDomain(;phase::Z,initialconds::Dict{X,Any},constantspecies::Array{X2,1}=Array{String,1}(),
409414
sparse::Bool=false,sensitivity::Bool=false) where {X,X2,E<:Real,Z<:IdealGas,Q<:AbstractInterface}
@@ -463,7 +468,7 @@ function ParametrizedPDomain(;phase::Z,initialconds::Dict{X,Any},constantspecies
463468
end
464469
rxnarray = getreactionindices(phase)
465470
return ParametrizedPDomain(phase,[phase.species[1].index,phase.species[end].index,phase.species[end].index+1,phase.species[end].index+2],[1,length(phase.species)+length(phase.reactions)],constspcinds,
466-
Pfcn,efficiencyinds,rxnarray,jacobian,sensitivity,MVector(false),MVector(0.0),p), y0, p
471+
Pfcn,efficiencyinds,rxnarray,jacobian,sensitivity,MVector(false),MVector(0.0),p,Dict("T"=>phase.species[end].index+1,"V"=>phase.species[end].index+2)), y0, p
467472
end
468473
export ParametrizedPDomain
469474

@@ -488,6 +493,7 @@ export ParametrizedPDomain
488493
jacuptodate::MArray{Tuple{1},Bool,1,1}=MVector(false)
489494
t::MArray{Tuple{1},W2,1,1}=MVector(0.0)
490495
p::Array{W,1}
496+
thermovariabledict::Dict{String,Int64}
491497
end
492498
function ConstantTVDomain(;phase::Z,initialconds::Dict{X,E},constantspecies::Array{X2,1}=Array{String,1}(),
493499
sparse=false,sensitivity=false) where {E,X,X2, Z<:AbstractPhase,Q<:AbstractInterface,W<:Real}
@@ -545,7 +551,7 @@ function ConstantTVDomain(;phase::Z,initialconds::Dict{X,E},constantspecies::Arr
545551
end
546552
rxnarray = getreactionindices(phase)
547553
return ConstantTVDomain(phase,[phase.species[1].index,phase.species[end].index],[1,length(phase.species)+length(phase.reactions)],constspcinds,
548-
T,V,kfs,krevs,kfsnondiff,efficiencyinds,Gs,rxnarray,mu,diffs,jacobian,sensitivity,false,MVector(false),MVector(0.0),p), y0, p
554+
T,V,kfs,krevs,kfsnondiff,efficiencyinds,Gs,rxnarray,mu,diffs,jacobian,sensitivity,false,MVector(false),MVector(0.0),p,Dict{String,Int64}()), y0, p
549555
end
550556
export ConstantTVDomain
551557

@@ -563,6 +569,7 @@ export ConstantTVDomain
563569
jacuptodate::MArray{Tuple{1},Bool,1,1}=MVector(false)
564570
t::MArray{Tuple{1},W2,1,1}=MVector(0.0)
565571
p::Array{W,1}
572+
thermovariabledict::Dict{String,Int64}
566573
end
567574
function ParametrizedTConstantVDomain(;phase::IdealDiluteSolution,initialconds::Dict{X,X3},constantspecies::Array{X2,1}=Array{String,1}(),
568575
sparse::Bool=false,sensitivity::Bool=false) where {X,X2,X3,Q<:AbstractInterface}
@@ -614,7 +621,7 @@ function ParametrizedTConstantVDomain(;phase::IdealDiluteSolution,initialconds::
614621
end
615622
rxnarray = getreactionindices(phase)
616623
return ParametrizedTConstantVDomain(phase,[phase.species[1].index,phase.species[end].index],[1,length(phase.species)+length(phase.reactions)],constspcinds,
617-
Tfcn,V,efficiencyinds,rxnarray,jacobian,sensitivity,MVector(false),MVector(0.0),p), y0, p
624+
Tfcn,V,efficiencyinds,rxnarray,jacobian,sensitivity,MVector(false),MVector(0.0),p,Dict{String,Int64}()), y0, p
618625
end
619626
export ParametrizedTConstantVDomain
620627

@@ -638,6 +645,7 @@ export ParametrizedTConstantVDomain
638645
jacuptodate::MArray{Tuple{1},Bool,1,1}=MVector(false)
639646
t::MArray{Tuple{1},W2,1,1}=MVector(0.0)
640647
p::Array{W,1}
648+
thermovariabledict::Dict{String,Int64}
641649
end
642650
function ConstantTADomain(;phase::E2,initialconds::Dict{X,X2},constantspecies::Array{X3,1}=Array{String,1}(),
643651
sparse::Bool=false,sensitivity::Bool=false,stationary::Bool=false) where {E<:Real,E2<:AbstractPhase,W<:Real,X,X2,X3}
@@ -686,7 +694,7 @@ function ConstantTADomain(;phase::E2,initialconds::Dict{X,X2},constantspecies::A
686694
end
687695
rxnarray = getreactionindices(phase)
688696
return ConstantTADomain(phase,[phase.species[1].index,phase.species[end].index],[1,length(phase.species)+length(phase.reactions)],constspcinds,
689-
T,A,kfs,krevs,efficiencyinds,Gs,rxnarray,mu,jacobian,sensitivity,false,MVector(false),MVector(0.0),stationary), y0, p
697+
T,A,kfs,krevs,efficiencyinds,Gs,rxnarray,mu,jacobian,sensitivity,false,MVector(false),MVector(0.0),p,Dict{String,Int64}()), y0, p
690698
end
691699
export ConstantTADomain
692700

@@ -742,7 +750,7 @@ end
742750
else
743751
d.Gs = d.p[1:length(d.phase.species)].+p[d.parameterindexes[1]-1+1:d.parameterindexes[1]-1+length(d.phase.species)]
744752
end
745-
krevs = getkfkrevs(d.phase,d.T,d.P,C,N,ns,d.Gs,d.diffusivity,V=V;kfs=d.kfs)[2]
753+
krevs = getkfkrevs(d.phase,d.T,d.P,C,N,ns,d.Gs,d.diffusivity,V;kfs=d.kfs)[2]
746754
for ind in d.efficiencyinds #efficiency related rates may have changed
747755
d.kfs[ind],d.krevs[ind] = getkfkrev(d.phase.reactions[ind],d.phase,d.T,d.P,C,N,ns,d.Gs,d.diffusivity,V;f=kfps[ind])
748756
end
@@ -2298,31 +2306,93 @@ function getreactionindices(ig::Q) where {Q<:AbstractPhase}
22982306
end
22992307
export getreactionindices
23002308

2301-
"""
2302-
fit a cubic spline to data and return a function evaluating that spline
2303-
"""
2304-
function getspline(xs,vals;s=1e-10)
2305-
smspl = fit(SmoothingSpline,xs,vals,s)
2306-
F(x::T) where {T} = predict(smspl,x)
2307-
return F
2308-
end
2309+
@inline function getsensspcsrxns(domain::D,ind::Int64) where {D<:AbstractDomain}
2310+
sensspcinds = Array{Int64,1}()
2311+
sensrxninds = Array{Int64,1}()
2312+
for rxnind in 1:size(domain.rxnarray)[2]
2313+
if ind in @inbounds @view domain.rxnarray[:,rxnind]
2314+
for spcind in @inbounds @view domain.rxnarray[:,rxnind]
2315+
if !(spcind in sensspcinds) && (spcind !== 0)
2316+
push!(sensspcinds,spcind)
2317+
end
2318+
end
2319+
push!(sensrxninds,rxnind)
2320+
end
2321+
end
2322+
2323+
sensrxns = Array{ElementaryReaction,1}(undef,length(sensrxninds))
2324+
sensspcs = Array{Species,1}(undef,length(sensspcinds))
2325+
sensspcnames = Array{String,1}(undef,length(sensspcinds))
2326+
senstooriginspcind = Array{Int64,1}(undef,length(sensspcinds))
2327+
senstooriginrxnind = Array{Int64,1}(undef,length(sensrxninds))
2328+
for (i,spcind) in enumerate(sensspcinds)
2329+
spc = domain.phase.species[spcind]
2330+
sensspcnames[i] = spc.name
2331+
@inbounds sensspcs[i] = Species(
2332+
name=spc.name,
2333+
index=i,
2334+
inchi=spc.inchi,
2335+
smiles=spc.smiles,
2336+
thermo=spc.thermo,
2337+
atomnums=spc.atomnums,
2338+
bondnum=spc.bondnum,
2339+
diffusion=spc.diffusion,
2340+
radius=spc.radius,
2341+
radicalelectrons=spc.radicalelectrons,
2342+
)
2343+
@inbounds senstooriginspcind[i] = spcind
2344+
end
2345+
2346+
for (i, rxnind) in enumerate(sensrxninds)
2347+
rxn = domain.phase.reactions[rxnind]
2348+
reactants = Array{Species,1}()
2349+
reactantinds = Array{Int64,1}()
2350+
@simd for reactant in rxn.reactants
2351+
ind = findfirst(isequal(reactant.name),sensspcnames)
2352+
@inbounds push!(reactants,sensspcs[ind])
2353+
push!(reactantinds,ind)
2354+
end
2355+
products = Array{Species,1}()
2356+
productinds = Array{Int64,1}()
2357+
@simd for product in rxn.products
2358+
ind = findfirst(isequal(product.name),sensspcnames)
2359+
@inbounds push!(products,sensspcs[ind])
2360+
push!(productinds,ind)
2361+
end
23092362

2310-
function getthermovariableindex(domain::Union{ConstantTPDomain,ParametrizedTPDomain},target::String)
2311-
return domain.indexes[3]
2363+
@inbounds sensrxns[i] = ElementaryReaction(
2364+
index=i,
2365+
reactants=SVector(reactants...),
2366+
reactantinds=SVector(reactantinds...),
2367+
products=SVector(products...),
2368+
productinds=SVector(productinds...),
2369+
kinetics=rxn.kinetics,
2370+
radicalchange=rxn.radicalchange,
2371+
pairs=rxn.pairs
2372+
)
2373+
@inbounds senstooriginrxnind[i] = rxnind
2374+
end
2375+
2376+
return sensspcs,sensrxns,sensspcnames,senstooriginspcind,senstooriginrxnind
23122377
end
23132378

2314-
function getthermovariableindex(domain::Union{ConstantVDomain,ParametrizedVDomain},target::String)
2315-
if target == "T"
2316-
return domain.indexes[3]
2317-
elseif target == "P"
2318-
return domain.indexes[4]
2379+
@inline function getsensdomain(domain::D,ind::Int64) where {D<:AbstractDomain}
2380+
2381+
sensspcs,sensrxns,sensspcnames,senstooriginspcind,senstooriginrxnind = getsensspcsrxns(domain,ind)
2382+
2383+
initialconds = Dict{String,Float64}()
2384+
2385+
for fn in fieldnames(typeof(domain))
2386+
if fn in (:T, :P, :V)
2387+
initialconds["$fn"] = getfield(domain,fn)
2388+
end
23192389
end
2320-
end
2390+
2391+
d = Symbol(split(repr(typeof(domain)),"{")[1])
23212392

2322-
function getthermovariableindex(domain::Union{ConstantPDomain,ParametrizedPDomain},target::String)
2323-
if target == "T"
2324-
return domain.indexes[3]
2325-
elseif target == "V"
2326-
return domain.indexes[4]
2393+
if isa(domain.phase,IdealGas)
2394+
return eval(d)(phase=IdealGas(sensspcs,sensrxns,name="phase"),initialconds=initialconds)[1],sensspcnames,senstooriginspcind,senstooriginrxnind
2395+
else
2396+
return eval(d)(phase=IdealDiluteSolution(sensspcs,sensrxns,domain.phase.solvent,name="phase"),initialconds=initialconds)[1],sensspcnames,senstooriginspcind,senstooriginrxnind
23272397
end
23282398
end

src/PhaseState.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,7 @@ export getkfkrev
192192
return kfs,krev
193193
end
194194

195-
@inline function getkfkrevs(phase::U,T::W1,P::W2,C::W3,N::W4,ns::Q1,Gs::Q2,diffs::Q3,V::W5;kfs::W6=nothing) where {U<:AbstractPhase,W6,W5<:Real,W1<:Real,W2<:Real,W3<:Real,W4<:Real, Q1<:AbstractArray,Q2<:Union{ReverseDiff.TrackedArray,Tracker.TrackedArray},Q3<:AbstractArray} #autodiff p
195+
@inline function getkfkrevs(phase::U,T::W1,P::W2,C::W3,N::W4,ns::Q1,Gs::Q2,diffs::Q3,V::W5;kfs::W6=nothing) where {U<:AbstractPhase,W6,W5<:Real,W1<:Real,W2<:Real,W3<:Real,W4<:Real, Q1<:AbstractArray,Q2,Q3<:AbstractArray} #autodiff p
196196
if !phase.diffusionlimited && kfs === nothing
197197
kfs = getkfs(phase,T,P,C,ns,V)
198198
krev = @fastmath kfs./getKcs(phase,T,Gs)

src/Simulation.jl

Lines changed: 62 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -141,10 +141,48 @@ based alternative algorithm is slower, but avoids this concern.
141141
function getadjointsensitivities(bsol::Q,target::String,solver::W;sensalg::W2=InterpolatingAdjoint(autojacvec=ReverseDiffVJP(false)),abstol::Float64=1e-6,reltol::Float64=1e-3,kwargs...) where {Q,W,W2}
142142
@assert target in bsol.names || target in ["T","V","P"]
143143
if target in ["T","V","P"]
144-
ind = getthermovariableindex(bsol.domain,target)
144+
if haskey(bsol.domain.thermovariabledict, target)
145+
ind = bsol.domain.thermovariabledict[target]
146+
else
147+
throw(error("$(bsol.domain) doesn't have $target in its thermovariables"))
148+
end
145149
else
146150
ind = findfirst(isequal(target),bsol.names)
151+
sensdomain,sensspcnames,senstooriginspcind,senstooriginrxnind = getsensdomain(bsol.domain,ind)
152+
if :thermovariabledict in fieldnames(typeof(bsol.domain))
153+
yinds = vcat(senstooriginspcind,collect(values(bsol.domain.thermovariabledict)))
154+
else
155+
yinds = vcat(senstooriginspcind)
156+
end
157+
pinds = vcat(senstooriginspcind,length(bsol.domain.phase.species).+senstooriginrxnind)
158+
ind = findfirst(isequal(target),sensspcnames)
159+
end
160+
161+
function sensg(y::X,p::Array{Y,1},t::Z) where {Q,V,X,Y<:Float64,Z}
162+
sensy = y[yinds]
163+
sensp = p[pinds]
164+
dy = similar(sensy,length(sensy))
165+
return dydtreactor!(dy,sensy,t,sensdomain,[],p=sensp)[ind]
166+
end
167+
function sensg(y::Array{X,1},p::Y,t::Z) where {Q,V,X<:Float64,Y,Z}
168+
sensy = y[yinds]
169+
sensp = p[pinds]
170+
dy = similar(sensp,length(sensy))
171+
return dydtreactor!(dy,sensy,t,sensdomain,[],p=sensp)[ind]
172+
end
173+
function sensg(y::Array{X,1},p::Array{Y,1},t::Z) where {Q,V,X<:Float64,Y<:Float64,Z}
174+
sensy = y[yinds]
175+
sensp = p[pinds]
176+
dy = similar(sensy,length(sensy))
177+
return dydtreactor!(dy,sensy,t,sensdomain,[],p=sensp)[ind]
178+
end
179+
function sensg(y::Array{X,1},p::Array{Y,1},t::Z) where {Q,V,X<:ForwardDiff.Dual,Y<:ForwardDiff.Dual,Z}
180+
sensy = y[yinds]
181+
sensp = p[pinds]
182+
dy = similar(sensy,length(sensy))
183+
return dydtreactor!(dy,sensy,t,sensdomain,[],p=sensp)[ind]
147184
end
185+
148186
function g(y::X,p::Array{Y,1},t::Z) where {Q,V,X,Y<:Float64,Z}
149187
dy = similar(y,length(y))
150188
return dydtreactor!(dy,y,t,bsol.domain,[],p=p)[ind]
@@ -161,12 +199,33 @@ function getadjointsensitivities(bsol::Q,target::String,solver::W;sensalg::W2=In
161199
dy = similar(y,length(y))
162200
return dydtreactor!(dy,y,t,bsol.domain,[],p=p)[ind]
163201
end
202+
203+
dsensgdu(out, y, p, t) = ForwardDiff.gradient!(out, y -> sensg(y, p, t), y)
204+
dsensgdp(out, y, p, t) = ForwardDiff.gradient!(out, p -> sensg(y, p, t), p)
164205
dgdu(out, y, p, t) = ForwardDiff.gradient!(out, y -> g(y, p, t), y)
165206
dgdp(out, y, p, t) = ForwardDiff.gradient!(out, p -> g(y, p, t), p)
166-
du0,dpadj = adjoint_sensitivities(bsol.sol,solver,g,nothing,(dgdu,dgdp);sensealg=sensalg,abstol=abstol,reltol=reltol,kwargs...)
207+
dsensgdurevdiff(out, y, p, t) = ReverseDiff.gradient!(out, y -> sensg(y, p, t), y)
208+
dsensgdprevdiff(out, y, p, t) = ReverseDiff.gradient!(out, p -> sensg(y, p, t), p)
209+
dgdurevdiff(out, y, p, t) = ReverseDiff.gradient!(out, y -> g(y, p, t), y)
210+
dgdprevdiff(out, y, p, t) = ReverseDiff.gradient!(out, p -> g(y, p, t), p)
211+
212+
pethane = 160
213+
if length(bsol.domain.p)<= pethane
214+
if target in ["T","V","P"]
215+
du0,dpadj = adjoint_sensitivities(bsol.sol,solver,g,nothing,(dgdu,dgdp);sensealg=sensalg,abstol=abstol,reltol=reltol,kwargs...)
216+
else
217+
du0,dpadj = adjoint_sensitivities(bsol.sol,solver,sensg,nothing,(dsensgdu,dsensgdp);sensealg=sensalg,abstol=abstol,reltol=reltol,kwargs...)
218+
end
219+
else
220+
if target in ["T","V","P"]
221+
du0,dpadj = adjoint_sensitivities(bsol.sol,solver,g,nothing,(dgdurevdiff,dgdprevdiff);sensealg=sensalg,abstol=abstol,reltol=reltol,kwargs...)
222+
else
223+
du0,dpadj = adjoint_sensitivities(bsol.sol,solver,sensg,nothing,(dsensgdurevdiff,dsensgdprevdiff);sensealg=sensalg,abstol=abstol,reltol=reltol,kwargs...)
224+
end
225+
end
167226
dpadj[length(bsol.domain.phase.species)+1:end] .*= bsol.domain.p[length(bsol.domain.phase.species)+1:end]
168227
if !(target in ["T","V","P"])
169-
dpadj ./= bsol.sol(bsol.sol.t[end])[ind]
228+
dpadj ./= bsol.sol(bsol.sol.t[end])[senstooriginspcind[ind]]
170229
end
171230
return dpadj
172231
end

0 commit comments

Comments
 (0)