OGS
NumericalDifferentiation.h
Go to the documentation of this file.
1// SPDX-FileCopyrightText: Copyright (c) OpenGeoSys Community (opengeosys.org)
2// SPDX-License-Identifier: BSD-3-Clause
3
4#pragma once
5
6#include <Eigen/Core>
7#include <tuple>
8#include <utility>
9
10#include "BaseLib/StrongType.h"
11
12namespace NumLib
13{
17
18namespace detail
19{
/// Trait used to tell scalar function arguments from vectorial ones.
///
/// The primary template classifies every type as a scalar, i.e. its nested
/// \c value is \c true. Types that must be treated as vectors opt out via
/// specializations.
template <typename T>
struct IsScalar : std::bool_constant<true>
{
};
24
25template <int N>
26struct IsScalar<Eigen::Matrix<double, N, 1, Eigen::ColMajor, N, 1>>
27 : std::false_type
28{
29};
30
31template <std::size_t IndexInTuple, typename Tuple>
32double getScalarOrVectorComponent(Tuple const& tuple, Eigen::Index component)
33{
34 auto const& value = std::get<IndexInTuple>(tuple);
35
36 if constexpr (IsScalar<std::remove_cvref_t<decltype(value)>>::value)
37 {
38 return value;
39 }
40 else
41 {
42 return value[component];
43 }
44}
45
51{
52 template <typename Function, typename TupleOfArgs,
53 typename PerturbationStrategy, std::size_t PerturbedArgIdx,
54 std::size_t... AllArgIdcs>
55 auto operator()(Function const& f, TupleOfArgs const& args,
56 PerturbationStrategy const& pert_strat,
57 std::integral_constant<std::size_t, PerturbedArgIdx>,
58 Eigen::Index const perturbed_arg_component,
59 std::index_sequence<AllArgIdcs...>) const
60 {
61 auto const value_plus = f(pert_strat.perturbIf(
62 std::bool_constant<PerturbedArgIdx == AllArgIdcs>{},
63 std::get<AllArgIdcs>(args), 1.0, perturbed_arg_component)...);
64
65 auto const value_minus = f(pert_strat.perturbIf(
66 std::bool_constant<PerturbedArgIdx == AllArgIdcs>{},
67 std::get<AllArgIdcs>(args), -1.0, perturbed_arg_component)...);
68
69 auto const pert = pert_strat.getPerturbation(
71 args, perturbed_arg_component));
72
73 // decltype enforces evaluation of Eigen expressions
74 decltype(value_plus) deriv = (value_plus - value_minus) / (2 * pert);
75
76 return deriv;
77 }
78};
79
84template <typename Value>
86{
87 explicit ComputeDerivativeWrtOneScalar_FD(Value&& unperturbed_value)
88 : unperturbed_value_{std::move(unperturbed_value)}
89 {
90 }
91
92 template <typename Function, typename TupleOfArgs,
93 typename PerturbationStrategy, std::size_t PerturbedArgIdx,
94 std::size_t... AllArgIdcs>
95 Value operator()(Function const& f, TupleOfArgs const& args,
96 PerturbationStrategy const& pert_strat,
97 std::integral_constant<std::size_t, PerturbedArgIdx>,
98 Eigen::Index const perturbed_arg_component,
99 std::index_sequence<AllArgIdcs...>) const
100 {
101 auto const value_plus = f(pert_strat.perturbIf(
102 std::bool_constant<PerturbedArgIdx == AllArgIdcs>{},
103 std::get<AllArgIdcs>(args), 1.0, perturbed_arg_component)...);
104
105 auto const pert = pert_strat.getPerturbation(
107 args, perturbed_arg_component));
108
109 return (value_plus - unperturbed_value_) / pert;
110 }
111
112private:
114};
115
119{
121 MinimumPerturbation const& min_pert)
122 : rel_eps_{*rel_eps}, min_pert_{*min_pert}
123 {
124 }
125
126 double getPerturbation(double const value) const
127 {
128 auto const pert = std::abs(value) * rel_eps_;
129
130 if (std::abs(pert) >= std::abs(min_pert_))
131 {
132 return pert;
133 }
134
135 return min_pert_;
136 }
137
138 template <typename T>
139 static T const& perturbIf(std::false_type, T const& value,
140 double const /*plus_or_minus*/,
141 Eigen::Index /*comp*/)
142 {
143 return value;
144 }
145
146 double perturbIf(std::true_type, double value, double const plus_or_minus,
147 Eigen::Index /*comp*/) const
148 {
149 return value + plus_or_minus * getPerturbation(value);
150 }
151
152 template <int N>
153 Eigen::Vector<double, N> perturbIf(
154 std::true_type,
155 Eigen::Matrix<double, N, 1, Eigen::ColMajor, N, 1> const& vec,
156 double const plus_or_minus,
157 Eigen::Index comp) const
158 {
159 Eigen::Vector<double, N> vec_pert = vec;
160 vec_pert[comp] += plus_or_minus * getPerturbation(vec[comp]);
161 return vec_pert;
162 }
163
164private:
165 double rel_eps_;
166 double min_pert_;
167};
168} // namespace detail
169
174{
175 template <typename Function, typename... Args>
177 Function const& /*f*/, Args const&... /*args*/)
178 {
179 return {};
180 }
181};
182
187{
188 template <typename Function, typename... Args>
189 static auto createDByDScalar(Function const& f, Args const&... args)
190 {
192 }
193};
194
195// TODO better call it NumericalDifferentiationAlgorithm?
200template <typename DerivativeStrategy>
202{
204 MinimumPerturbation const& min_pert)
205 : pert_strat_{rel_eps, min_pert}
206 {
207 }
208
209 template <typename Function, typename... Args>
210 auto operator()(Function const& f, Args const&... args) const
211 {
212 auto const d_by_dScalar =
213 DerivativeStrategy::createDByDScalar(f, args...);
214
215 // TODO also return value from the function, not only the derivatives?
216 return differentiate(f,
217 std::forward_as_tuple(args...),
218 d_by_dScalar,
219 std::make_index_sequence<sizeof...(Args)>{});
220 }
221
222private:
223 template <typename Function, typename TupleOfArgs, typename DByDScalar,
224 std::size_t... AllArgIdcs>
225 auto differentiate(Function const& f, TupleOfArgs const& args,
226 DByDScalar const& d_by_dScalar,
227 std::index_sequence<AllArgIdcs...> all_arg_idcs) const
228 {
230 detail::IsScalar<std::remove_cvref_t<
231 std::tuple_element_t<AllArgIdcs, TupleOfArgs>>>{},
232 f, args, d_by_dScalar,
233 std::integral_constant<std::size_t, AllArgIdcs>{},
234 all_arg_idcs)... /* "for each function argument" */};
235 }
236
237 // scalar case
238 template <typename Function, typename TupleOfArgs, typename DByDScalar,
239 std::size_t... AllArgIdcs, std::size_t PerturbedArgIdx>
241 std::true_type /* is_scalar */, Function const& f,
242 TupleOfArgs const& args, DByDScalar const& d_by_dScalar,
243 std::integral_constant<std::size_t, PerturbedArgIdx> perturbed_arg_idx,
244 std::index_sequence<AllArgIdcs...> all_arg_idcs) const
245 {
246 constexpr Eigen::Index component_does_not_matter = -1;
247
248 return d_by_dScalar(f, args, pert_strat_, perturbed_arg_idx,
249 component_does_not_matter, all_arg_idcs);
250 }
251
252 // vectorial case
253 template <typename Function, typename TupleOfArgs, typename DByDScalar,
254 std::size_t... AllArgIdcs, std::size_t PerturbedArgIdx>
256 std::false_type /* is_scalar */, Function const& f,
257 TupleOfArgs const& args, DByDScalar const& d_by_dScalar,
258 std::integral_constant<std::size_t, PerturbedArgIdx> perturbed_arg_idx,
259 std::index_sequence<AllArgIdcs...> all_arg_idcs) const
260 {
261 using VectorialArg = std::remove_cvref_t<
262 std::tuple_element_t<PerturbedArgIdx, TupleOfArgs>>;
263 constexpr int N = VectorialArg::RowsAtCompileTime;
264
265 static_assert(N != Eigen::Dynamic);
266 static_assert(VectorialArg::ColsAtCompileTime == 1,
267 "Row vectors are not supported, yet. If you implement "
268 "support for them, make sure to test your implementation "
269 "thoroughly.");
270
272 f, args, d_by_dScalar,
273 std::make_integer_sequence<Eigen::Index, N>{}, perturbed_arg_idx,
274 all_arg_idcs);
275 }
276
277 template <typename Function, typename TupleOfArgs, typename DByDScalar,
278 Eigen::Index... PerturbedArgComponents, std::size_t... AllArgIdcs,
279 std::size_t PerturbedArgIdx>
281 Function const& f, TupleOfArgs const& args,
282 DByDScalar const& d_by_dScalar,
283 std::integer_sequence<Eigen::Index, PerturbedArgComponents...>,
284 std::integral_constant<std::size_t, PerturbedArgIdx> perturbed_arg_idx,
285 std::index_sequence<AllArgIdcs...> all_arg_idcs) const
286 {
287 return std::array{
288 d_by_dScalar(f, args, pert_strat_, perturbed_arg_idx,
289 PerturbedArgComponents, all_arg_idcs)...
290 /* "for each component of the vectorial function argument being
291 perturbed" */
292 };
293 }
294
296};
297
298} // namespace NumLib
double getScalarOrVectorComponent(Tuple const &tuple, Eigen::Index component)
BaseLib::StrongType< double, struct RelativeEpsilonTag > RelativeEpsilon
BaseLib::StrongType< double, struct MinimumPerturbationTag > MinimumPerturbation
static detail::ComputeDerivativeWrtOneScalar_CD createDByDScalar(Function const &, Args const &...)
static auto createDByDScalar(Function const &f, Args const &... args)
auto differentiateWrtScalarOrVectorialArgument(std::false_type, Function const &f, TupleOfArgs const &args, DByDScalar const &d_by_dScalar, std::integral_constant< std::size_t, PerturbedArgIdx > perturbed_arg_idx, std::index_sequence< AllArgIdcs... > all_arg_idcs) const
detail::DefaultPerturbationStrategy pert_strat_
auto operator()(Function const &f, Args const &... args) const
NumericalDerivative(RelativeEpsilon const &rel_eps, MinimumPerturbation const &min_pert)
auto differentiateWrtScalarOrVectorialArgument(std::true_type, Function const &f, TupleOfArgs const &args, DByDScalar const &d_by_dScalar, std::integral_constant< std::size_t, PerturbedArgIdx > perturbed_arg_idx, std::index_sequence< AllArgIdcs... > all_arg_idcs) const
auto differentiateWrtAllVectorComponents(Function const &f, TupleOfArgs const &args, DByDScalar const &d_by_dScalar, std::integer_sequence< Eigen::Index, PerturbedArgComponents... >, std::integral_constant< std::size_t, PerturbedArgIdx > perturbed_arg_idx, std::index_sequence< AllArgIdcs... > all_arg_idcs) const
auto differentiate(Function const &f, TupleOfArgs const &args, DByDScalar const &d_by_dScalar, std::index_sequence< AllArgIdcs... > all_arg_idcs) const
auto operator()(Function const &f, TupleOfArgs const &args, PerturbationStrategy const &pert_strat, std::integral_constant< std::size_t, PerturbedArgIdx >, Eigen::Index const perturbed_arg_component, std::index_sequence< AllArgIdcs... >) const
Value operator()(Function const &f, TupleOfArgs const &args, PerturbationStrategy const &pert_strat, std::integral_constant< std::size_t, PerturbedArgIdx >, Eigen::Index const perturbed_arg_component, std::index_sequence< AllArgIdcs... >) const
static T const & perturbIf(std::false_type, T const &value, double const, Eigen::Index)
double perturbIf(std::true_type, double value, double const plus_or_minus, Eigen::Index) const
DefaultPerturbationStrategy(RelativeEpsilon const &rel_eps, MinimumPerturbation const &min_pert)
Eigen::Vector< double, N > perturbIf(std::true_type, Eigen::Matrix< double, N, 1, Eigen::ColMajor, N, 1 > const &vec, double const plus_or_minus, Eigen::Index comp) const