OGS
LookupTable.cpp
Go to the documentation of this file.
1
11#include "LookupTable.h"
12
13#include <unordered_set>
14
15#include "BaseLib/Algorithm.h"
16
17namespace ProcessLib
18{
19namespace ComponentTransport
20{
/// Restricts \p vec1 to the values that also occur in \p vec2.
///
/// After the call, \p vec1 holds every value present in both inputs exactly
/// once, ordered by its first appearance in \p vec2. The original order and
/// multiplicity of \p vec1 are not preserved.
static void intersection(std::vector<std::size_t>& vec1,
                         std::vector<std::size_t> const& vec2)
{
    std::unordered_set<std::size_t> remaining(vec1.begin(), vec1.end());
    vec1.clear();

    for (std::size_t const candidate : vec2)
    {
        // erase() returns how many elements it removed. A non-zero result
        // therefore means "was common and not yet emitted", and removing it
        // keeps later duplicates in vec2 from being appended again.
        if (remaining.erase(candidate) != 0)
        {
            vec1.push_back(candidate);
        }
    }
}
35
36std::pair<double, double> Field::getBoundingSeedPoints(double const value) const
37{
38 if (seed_points.size() < 2)
39 {
40 OGS_FATAL("The lookup table for seed_points has less then two values.");
41 }
42
43 auto lower =
44 std::lower_bound(seed_points.cbegin(), seed_points.cend(), value);
45 if (lower == seed_points.begin())
46 {
47 WARN("The interpolation point is below the lower bound.");
48 return std::make_pair(seed_points[0], seed_points[1]);
49 }
50 if (lower == seed_points.end())
51 {
52 WARN("The interpolation point is above the upper bound.");
53 return std::make_pair(seed_points[seed_points.size() - 2],
54 seed_points[seed_points.size() - 1]);
55 }
56
57 auto const upper = lower--;
58 return std::make_pair(*lower, *upper);
59}
60
61void LookupTable::lookup(std::vector<GlobalVector*> const& x,
62 std::vector<GlobalVector*> const& x_prev,
63 std::size_t const n_nodes) const
64{
65 using EntryInput = std::vector<double>;
66
67 for (std::size_t node_id = 0; node_id < n_nodes; ++node_id)
68 {
69 std::vector<InterpolationPoint> interpolation_points;
70 EntryInput base_entry_input;
71 for (auto const& input_field : input_fields)
72 {
73 // process id and variable id are equivalent in the case the
74 // staggered coupling scheme is adopted.
75 auto const process_id = input_field.variable_id;
76 auto const& variable_name = input_field.name;
77 double input_field_value =
78 variable_name.find("_prev") == std::string::npos
79 ? x[process_id]->get(node_id)
80 : x_prev[process_id]->get(node_id);
81 input_field_value =
82 (std::abs(input_field_value) + input_field_value) / 2;
83 auto bounding_seed_points =
84 input_field.getBoundingSeedPoints(input_field_value);
85
86 InterpolationPoint interpolation_point{bounding_seed_points,
87 input_field_value};
88 interpolation_points.push_back(interpolation_point);
89
90 base_entry_input.push_back(bounding_seed_points.first);
91 }
92
93 auto const base_entry_id = getTableEntryID(base_entry_input);
94
95 // collect bounding entry ids
96 EntryInput bounding_entry_input{base_entry_input};
97 std::vector<std::size_t> bounding_entry_ids;
98 for (std::size_t i = 0; i < interpolation_points.size(); ++i)
99 {
100 bounding_entry_input[i] =
101 interpolation_points[i].bounding_points.second;
102 bounding_entry_ids.push_back(getTableEntryID(bounding_entry_input));
103 bounding_entry_input[i] =
104 interpolation_points[i].bounding_points.first;
105 }
106
107 for (auto const& input_field : input_fields)
108 {
109 if (input_field.name.find("_prev") != std::string::npos)
110 {
111 continue;
112 }
113
114 auto const output_field_name = input_field.name + "_new";
115 auto const base_value =
116 tabular_data.at(output_field_name)[base_entry_id];
117 auto new_value = base_value;
118
119 // linear interpolation
120 for (std::size_t i = 0; i < interpolation_points.size(); ++i)
121 {
122 auto const interpolation_point_value =
123 tabular_data.at(output_field_name)[bounding_entry_ids[i]];
124 auto const slope =
125 (interpolation_point_value - base_value) /
126 (interpolation_points[i].bounding_points.second -
127 interpolation_points[i].bounding_points.first);
128
129 new_value +=
130 slope * (interpolation_points[i].value -
131 interpolation_points[i].bounding_points.first);
132 }
133
134 x[input_field.variable_id]->set(node_id, new_value);
135 }
136 }
137}
138
140 std::vector<double> const& entry_input) const
141{
142 std::vector<std::size_t> intersected_vec =
145 input_fields[0].point_id_groups[BaseLib::findIndex(
146 input_fields[0].seed_points, entry_input[0])];
147
151 for (std::size_t i = 1; i < input_fields.size(); ++i)
152 {
153 std::vector<std::size_t> const vec =
154 input_fields[i].point_id_groups[BaseLib::findIndex(
155 input_fields[i].seed_points, entry_input[i])];
156
157 intersection(intersected_vec, vec);
158 }
159
160 return intersected_vec[0];
161}
162} // namespace ComponentTransport
163} // namespace ProcessLib
#define OGS_FATAL(...)
Definition Error.h:26
void WARN(fmt::format_string< Args... > fmt, Args &&... args)
Definition Logging.h:40
std::size_t findIndex(Container const &container, typename Container::value_type const &element)
Definition Algorithm.h:238
static void intersection(std::vector< std::size_t > &vec1, std::vector< std::size_t > const &vec2)
std::pair< double, double > getBoundingSeedPoints(double const value) const
std::vector< double > const seed_points
Definition LookupTable.h:44
void lookup(std::vector< GlobalVector * > const &x, std::vector< GlobalVector * > const &x_prev, std::size_t const n_nodes) const
std::size_t getTableEntryID(std::vector< double > const &entry_input) const
std::map< std::string, std::vector< double > > const tabular_data
Definition LookupTable.h:66
std::vector< Field > const input_fields
Definition LookupTable.h:65