OGS
CVodeSolver.cpp
Go to the documentation of this file.
1
10#include "CVodeSolver.h"
11
12#include <cvode/cvode.h> /* prototypes for CVODE fcts., consts. */
13#include <cvode/cvode_dense.h> /* prototype for CVDense */
14#include <nvector/nvector_serial.h> /* serial N_Vector types, fcts., macros */
15#include <sundials/sundials_dense.h> /* definitions DlsMat DENSE_ELEM */
16#include <sundials/sundials_types.h> /* definition of type realtype */
17
18#include <cassert>
19
20#include "BaseLib/ConfigTree.h"
21#include "BaseLib/Error.h"
22#include "BaseLib/Logging.h"
23
26
32void check_error(std::string const& f_name, int const error_flag)
33{
34 if (error_flag != CV_SUCCESS)
35 {
36 OGS_FATAL("CVodeSolver: {:s} failed with error flag {:d}.", f_name,
37 error_flag);
38 }
39}
40
42void printStats(void* cvode_mem)
43{
44 long int nst = 0, nfe = 0, nsetups = 0, nje = 0, nfeLS = 0, nni = 0,
45 ncfn = 0, netf = 0, nge = 0;
46
47 check_error("CVodeGetNumSteps", CVodeGetNumSteps(cvode_mem, &nst));
48 check_error("CVodeGetNumRhsEvals", CVodeGetNumRhsEvals(cvode_mem, &nfe));
49 check_error("CVodeGetNumLinSolvSetups",
50 CVodeGetNumLinSolvSetups(cvode_mem, &nsetups));
51 check_error("CVodeGetNumErrTestFails",
52 CVodeGetNumErrTestFails(cvode_mem, &netf));
53 check_error("CVodeGetNumNonlinSolvIters",
54 CVodeGetNumNonlinSolvIters(cvode_mem, &nni));
55 check_error("CVodeGetNumNonlinSolvConvFails",
56 CVodeGetNumNonlinSolvConvFails(cvode_mem, &ncfn));
57 check_error("CVDlsGetNumJacEvals", CVDlsGetNumJacEvals(cvode_mem, &nje));
58 check_error("CVDlsGetNumRhsEvals", CVDlsGetNumRhsEvals(cvode_mem, &nfeLS));
59 check_error("CVodeGetNumGEvals", CVodeGetNumGEvals(cvode_mem, &nge));
60
61 DBUG("Sundials CVode solver. Statistics:");
62 DBUG(
63 "nst = {:<0d} nfe = {:<0d} nsetups = {:<0d} nfeLS = {:<0d} nje = {:d}",
64 nst, nfe, nsetups, nfeLS, nje);
65 DBUG("nni = {:<0d} ncfn = {:<0d} netf = {:<0d} nge = {:d}\n", nni,
66 ncfn, netf, nge);
67}
68
70
71namespace MathLib
72{
73namespace ODE
74{
77
84class CVodeSolverImpl final
85{
86 static_assert(std::is_same_v<realtype, double>,
87 "CVode's realtype is not the same as double");
88
89public:
91 unsigned const num_equations);
92
93 void setFunction(std::unique_ptr<detail::FunctionHandles>&& f);
94
95 void preSolve();
96 bool solve(const double t_end);
97
98 double const* getSolution() const { return NV_DATA_S(y_); }
99 double getTime() const { return t_; }
100 void getYDot(const double t, double const* const y, double* const y_dot);
101 void setTolerance(const double* abstol, const double reltol);
102 void setTolerance(const double abstol, const double reltol);
103 void setIC(const double t0, double const* const y0);
104
106
107private:
108 N_Vector y_ = nullptr;
109
110 realtype t_;
111
112 N_Vector abstol_ = nullptr;
113 realtype reltol_;
114
115 unsigned num_equations_;
117
120 std::unique_ptr<detail::FunctionHandles> f_;
121
124
126 int nonlinear_solver_iteration_ = CV_FUNCTIONAL;
127};
128
130
132 const unsigned num_equations)
133{
134 if (auto const param =
136 config.getConfigParameterOptional<std::string>(
137 "linear_multistep_method"))
138 {
139 DBUG("setting linear multistep method (config: {:s})", param->c_str());
140
141 if (*param == "Adams")
142 {
143 linear_multistep_method_ = CV_ADAMS;
144 }
145 else if (*param == "BDF")
146 {
148 }
149 else
150 {
151 OGS_FATAL("unknown linear multistep method: {:s}", param->c_str());
152 }
153 }
154
155 if (auto const param =
157 config.getConfigParameterOptional<std::string>(
158 "nonlinear_solver_iteration"))
159 {
160 DBUG("setting nonlinear solver iteration (config: {:s})",
161 param->c_str());
162
163 if (*param == "Functional")
164 {
165 nonlinear_solver_iteration_ = CV_FUNCTIONAL;
166 }
167 else if (*param == "Newton")
168 {
169 nonlinear_solver_iteration_ = CV_NEWTON;
170 }
171 else
172 {
173 OGS_FATAL("unknown nonlinear solver iteration: {:s}",
174 param->c_str());
175 }
176 }
177
178 y_ = N_VNew_Serial(num_equations);
179 abstol_ = N_VNew_Serial(num_equations);
180 num_equations_ = num_equations;
181
182 cvode_mem_ =
184
185 if (cvode_mem_ == nullptr || y_ == nullptr || abstol_ == nullptr)
186 {
187 OGS_FATAL("couldn't allocate storage for CVode solver.");
188 }
189
190 auto f_wrapped = [](const realtype t, const N_Vector y, N_Vector ydot,
191 void* function_handles) -> int
192 {
193 bool successful =
194 static_cast<detail::FunctionHandles*>(function_handles)
195 ->call(t, NV_DATA_S(y), NV_DATA_S(ydot));
196 return successful ? 0 : 1;
197 };
198
199 check_error("CVodeInit", CVodeInit(cvode_mem_, f_wrapped, 0.0, y_));
200}
201
202void CVodeSolverImpl::setTolerance(const double* abstol, const double reltol)
203{
204 for (unsigned i = 0; i < num_equations_; ++i)
205 {
206 NV_Ith_S(abstol_, i) = abstol[i];
207 }
208
209 reltol_ = reltol;
210}
211
212void CVodeSolverImpl::setTolerance(const double abstol, const double reltol)
213{
214 for (unsigned i = 0; i < num_equations_; ++i)
215 {
216 NV_Ith_S(abstol_, i) = abstol;
217 }
218
219 reltol_ = reltol;
220}
221
222void CVodeSolverImpl::setFunction(std::unique_ptr<detail::FunctionHandles>&& f)
223{
224 f_ = std::move(f);
225 assert(num_equations_ == f_->getNumberOfEquations());
226}
227
228void CVodeSolverImpl::setIC(const double t0, double const* const y0)
229{
230 for (unsigned i = 0; i < num_equations_; ++i)
231 {
232 NV_Ith_S(y_, i) = y0[i];
233 }
234
235 t_ = t0;
236}
237
239{
240 assert(f_ != nullptr && "ode function handle was not provided");
241
242 // sets initial conditions
243 check_error("CVodeReInit", CVodeReInit(cvode_mem_, t_, y_));
244
245 check_error("CVodeSetUserData",
246 CVodeSetUserData(cvode_mem_, static_cast<void*>(f_.get())));
247
248 /* Call CVodeSVtolerances to specify the scalar relative tolerance
249 * and vector absolute tolerances */
250 check_error("CVodeSVtolerances",
251 CVodeSVtolerances(cvode_mem_, reltol_, abstol_));
252
253 /* Call CVDense to specify the CVDENSE dense linear solver */
254 check_error("CVDense", CVDense(cvode_mem_, num_equations_));
255
256 if (f_->hasJacobian())
257 {
258 auto df_wrapped = [](const long N, const realtype t, const N_Vector y,
259 const N_Vector ydot, const DlsMat jac,
260 void* function_handles, N_Vector /*tmp1*/,
261 N_Vector /*tmp2*/, N_Vector /*tmp3*/
262 ) -> int
263 {
264 (void)N; // prevent warnings during non-debug build
265 auto* fh = static_cast<detail::FunctionHandles*>(function_handles);
266 assert(N == fh->getNumberOfEquations());
267
268 // Caution: by calling the DENSE_COL() macro we assume that matrices
269 // are stored contiguously in memory!
270 // See also the header files sundials_direct.h and cvode_direct.h in
271 // the Sundials source code. The comments about the macro DENSE_COL
272 // in those files indicate that matrices are stored column-wise.
273 bool successful = fh->callJacobian(t, NV_DATA_S(y), NV_DATA_S(ydot),
274 DENSE_COL(jac, 0));
275 return successful ? 0 : 1;
276 };
277
278 check_error("CVDlsSetDenseJacFn",
279 CVDlsSetDenseJacFn(cvode_mem_, df_wrapped));
280 }
281}
282
283bool CVodeSolverImpl::solve(const double t_end)
284{
285 realtype t_reached;
286 check_error("CVode solve",
287 CVode(cvode_mem_, t_end, y_, &t_reached, CV_NORMAL));
288 t_ = t_reached;
289
290 // check_error asserts that t_end == t_reached and that solving the ODE
291 // went fine. Otherwise the program will be aborted. Therefore, we don't
292 // have to check manually for errors here and can always safely return true.
293 return true;
294}
295
296void CVodeSolverImpl::getYDot(const double t, double const* const y,
297 double* const y_dot)
298{
299 assert(f_ != nullptr);
300 f_->call(t, y, y_dot);
301}
302
304{
306
307 N_VDestroy_Serial(y_);
308 N_VDestroy_Serial(abstol_);
309 CVodeFree(&cvode_mem_);
310}
311
313 unsigned const num_equations)
314 : impl_{new CVodeSolverImpl{config, num_equations}}
315{
316}
317
318void CVodeSolver::setTolerance(const double* abstol, const double reltol)
319{
320 impl_->setTolerance(abstol, reltol);
321}
322
323void CVodeSolver::setTolerance(const double abstol, const double reltol)
324{
325 impl_->setTolerance(abstol, reltol);
326}
327
328void CVodeSolver::setFunction(std::unique_ptr<detail::FunctionHandles>&& f)
329{
330 impl_->setFunction(std::move(f));
331}
332
333void CVodeSolver::setIC(const double t0, double const* const y0)
334{
335 impl_->setIC(t0, y0);
336}
337
339{
340 impl_->preSolve();
341}
342
343bool CVodeSolver::solve(const double t_end)
344{
345 return impl_->solve(t_end);
346}
347
348double const* CVodeSolver::getSolution() const
349{
350 return impl_->getSolution();
351}
352
353void CVodeSolver::getYDot(const double t, double const* const y,
354 double* const y_dot) const
355{
356 impl_->getYDot(t, y, y_dot);
357}
358
360{
361 return impl_->getTime();
362}
363
// Defined here, where CVodeSolverImpl is a complete type, rather than in the
// header. Presumably this is so that std::unique_ptr<CVodeSolverImpl> (impl_)
// can instantiate its deleter, which is the usual pimpl requirement. Verify
// against CVodeSolver.h.
CVodeSolver::~CVodeSolver() = default;
365
366} // namespace ODE
367} // namespace MathLib
#define OGS_FATAL(...)
Definition Error.h:26
void DBUG(fmt::format_string< Args... > fmt, Args &&... args)
Definition Logging.h:30
std::optional< T > getConfigParameterOptional(std::string const &param) const
CVodeSolverImpl(BaseLib::ConfigTree const &config, unsigned const num_equations)
std::unique_ptr< detail::FunctionHandles > f_
N_Vector abstol_
current time
void * cvode_mem_
CVode's internal memory.
void setFunction(std::unique_ptr< detail::FunctionHandles > &&f)
double const * getSolution() const
realtype reltol_
Relative tolerance.
N_Vector y_
The solution vector.
void getYDot(const double t, double const *const y, double *const y_dot)
int linear_multistep_method_
The multistep method used for solving the ODE.
void setTolerance(const double *abstol, const double reltol)
unsigned num_equations_
Number of equations in the ODE system.
void setIC(const double t0, double const *const y0)
bool solve(const double t_end)
int nonlinear_solver_iteration_
Either solve via fixed-point iteration or via Newton-Raphson method.
std::unique_ptr< CVodeSolverImpl > impl_
pimpl idiom.
Definition CVodeSolver.h:73
void setFunction(std::unique_ptr< detail::FunctionHandles > &&f)
bool solve(const double t_end)
void setTolerance(double const *const abstol, const double reltol)
void setIC(const double t0, double const *const y0)
void getYDot(const double t, double const *const y, double *const y_dot) const
double const * getSolution() const
CVodeSolver(BaseLib::ConfigTree const &config, unsigned const num_equations)
void printStats(void *cvode_mem)
Prints some statistics about an ODE solver run.
void check_error(std::string const &f_name, int const error_flag)
static const double t