OGS
LookupTable.cpp
Go to the documentation of this file.
1
11#include "LookupTable.h"
12
13#include <unordered_set>
14
15#include "BaseLib/Algorithm.h"
16
17namespace ProcessLib
18{
19namespace ComponentTransport
20{
/// Collects the values that occur in both \c vec1 and \c vec2.
///
/// The result follows the order in which the values appear in \c vec2, and
/// every common value is reported only once, even if it is repeated in
/// \c vec2.
///
/// \param vec1 first list of ids; it is turned into a hash set for lookup.
/// \param vec2 second list of ids; it determines the order of the result.
/// \return the ids contained in both lists.
static std::vector<std::size_t> intersection(
    std::vector<std::size_t> const& vec1, std::vector<std::size_t> const& vec2)
{
    std::unordered_set<std::size_t> remaining(vec1.begin(), vec1.end());
    std::vector<std::size_t> common;
    for (auto const candidate : vec2)
    {
        // erase() returns the number of removed keys, so it serves as the
        // membership test and, at the same time, prevents the same id from
        // being reported a second time.
        if (remaining.erase(candidate) != 0)
        {
            common.push_back(candidate);
        }
    }
    return common;
}
36
37std::pair<double, double> Field::getBoundingSeedPoints(double const value) const
38{
39 auto it = std::lower_bound(seed_points.cbegin(), seed_points.cend(), value);
40 if (it == seed_points.begin())
41 {
42 WARN("The interpolation point is below the lower bound.");
43 auto const nx = std::next(it);
44 return std::make_pair(*it, *nx);
45 }
46 if (it == seed_points.end())
47 {
48 WARN("The interpolation point is above the upper bound.");
49 std::advance(it, -1);
50 }
51
52 auto const pv = std::prev(it);
53 return std::make_pair(*pv, *it);
54}
55
56void LookupTable::lookup(std::vector<GlobalVector*> const& x,
57 std::vector<GlobalVector*> const& x_prev,
58 std::size_t const n_nodes) const
59{
60 using EntryInput = std::vector<double>;
61
62 for (std::size_t node_id = 0; node_id < n_nodes; ++node_id)
63 {
64 std::vector<InterpolationPoint> interpolation_points;
65 EntryInput base_entry_input;
66 for (auto const& input_field : input_fields)
67 {
68 // process id and variable id are equilvalent in the case the
69 // staggered coupling scheme is adopted.
70 auto const process_id = input_field.variable_id;
71 auto const& variable_name = input_field.name;
72 double input_field_value =
73 variable_name.find("_prev") == std::string::npos
74 ? x[process_id]->get(node_id)
75 : x_prev[process_id]->get(node_id);
76 input_field_value =
77 (std::abs(input_field_value) + input_field_value) / 2;
78 auto bounding_seed_points =
79 input_field.getBoundingSeedPoints(input_field_value);
80
81 InterpolationPoint interpolation_point{bounding_seed_points,
82 input_field_value};
83 interpolation_points.push_back(interpolation_point);
84
85 base_entry_input.push_back(bounding_seed_points.first);
86 }
87
88 auto const base_entry_id = getTableEntryID(base_entry_input);
89
90 // collect bounding entry ids
91 EntryInput bounding_entry_input{base_entry_input};
92 std::vector<std::size_t> bounding_entry_ids;
93 for (std::size_t i = 0; i < interpolation_points.size(); ++i)
94 {
95 bounding_entry_input[i] =
96 interpolation_points[i].bounding_points.second;
97 bounding_entry_ids.push_back(getTableEntryID(bounding_entry_input));
98 bounding_entry_input[i] =
99 interpolation_points[i].bounding_points.first;
100 }
101
102 for (auto const& input_field : input_fields)
103 {
104 if (input_field.name.find("_prev") != std::string::npos)
105 {
106 continue;
107 }
108
109 auto const output_field_name = input_field.name + "_new";
110 auto const base_value =
111 tabular_data.at(output_field_name)[base_entry_id];
112 auto new_value = base_value;
113
114 // linear interpolation
115 for (std::size_t i = 0; i < interpolation_points.size(); ++i)
116 {
117 auto const interpolation_point_value =
118 tabular_data.at(output_field_name)[bounding_entry_ids[i]];
119 auto const slope =
120 (interpolation_point_value - base_value) /
121 (interpolation_points[i].bounding_points.second -
122 interpolation_points[i].bounding_points.first);
123
124 new_value +=
125 slope * (interpolation_points[i].value -
126 interpolation_points[i].bounding_points.first);
127 }
128
129 x[input_field.variable_id]->set(node_id, new_value);
130 }
131 }
132}
133
135 std::vector<double> const& entry_input) const
136{
137 std::vector<std::size_t> temp_vec;
138
139 std::vector<std::size_t> intersected_vec =
142 input_fields[0].point_id_groups[BaseLib::findIndex(
143 input_fields[0].seed_points, entry_input[0])];
144
148 for (std::size_t i = 1; i < input_fields.size(); ++i)
149 {
150 std::vector<std::size_t> const vec =
151 input_fields[i].point_id_groups[BaseLib::findIndex(
152 input_fields[i].seed_points, entry_input[i])];
153
154 temp_vec = intersection(intersected_vec, vec);
155
156 std::swap(intersected_vec, temp_vec);
157 temp_vec.clear();
158 }
159
160 return intersected_vec[0];
161}
162} // namespace ComponentTransport
163} // namespace ProcessLib
void WARN(char const *fmt, Args const &... args)
Definition: Logging.h:39
std::size_t findIndex(Container const &container, typename Container::value_type const &element)
Definition: Algorithm.h:287
void set(PETScVector &x, PetscScalar const a)
Definition: LinAlg.cpp:32
static std::vector< std::size_t > intersection(std::vector< std::size_t > const &vec1, std::vector< std::size_t > const &vec2)
Definition: LookupTable.cpp:21
std::pair< double, double > getBoundingSeedPoints(double const value) const
Definition: LookupTable.cpp:37
std::vector< double > const seed_points
Definition: LookupTable.h:44
void lookup(std::vector< GlobalVector * > const &x, std::vector< GlobalVector * > const &x_prev, std::size_t const n_nodes) const
Definition: LookupTable.cpp:56
std::size_t getTableEntryID(std::vector< double > const &entry_input) const
std::map< std::string, std::vector< double > > const tabular_data
Definition: LookupTable.h:66
std::vector< Field > const input_fields
Definition: LookupTable.h:65