OGS
LookupTable.cpp
Go to the documentation of this file.
1// SPDX-FileCopyrightText: Copyright (c) OpenGeoSys Community (opengeosys.org)
2// SPDX-License-Identifier: BSD-3-Clause
3
4#include "LookupTable.h"
5
6#include <unordered_set>
7
8#include "BaseLib/Algorithm.h"
9
10namespace ProcessLib
11{
12namespace ComponentTransport
13{
/// Restricts \p vec1 to the ids that also occur in \p vec2.
///
/// On return \p vec1 holds every id present in both inputs exactly once,
/// ordered by its first occurrence in \p vec2.
static void intersection(std::vector<std::size_t>& vec1,
                         std::vector<std::size_t> const& vec2)
{
    // Ids of the original vec1 that have not been matched yet.
    std::unordered_set<std::size_t> remaining(vec1.cbegin(), vec1.cend());
    vec1.clear();

    for (std::size_t const id : vec2)
    {
        // erase() reports how many keys it removed: non-zero means the id
        // was a candidate, and removing it ensures a repeated id in vec2 is
        // not emitted a second time.
        if (remaining.erase(id) != 0)
        {
            vec1.push_back(id);
        }
    }
}
28
29std::pair<double, double> Field::getBoundingSeedPoints(double const value) const
30{
31 if (seed_points.size() < 2)
32 {
33 OGS_FATAL("The lookup table for seed_points has less then two values.");
34 }
35
36 auto lower =
37 std::lower_bound(seed_points.cbegin(), seed_points.cend(), value);
38 if (lower == seed_points.begin())
39 {
40 WARN("The interpolation point is below the lower bound.");
41 return std::make_pair(seed_points[0], seed_points[1]);
42 }
43 if (lower == seed_points.end())
44 {
45 WARN("The interpolation point is above the upper bound.");
46 return std::make_pair(seed_points[seed_points.size() - 2],
47 seed_points[seed_points.size() - 1]);
48 }
49
50 auto const upper = lower--;
51 return std::make_pair(*lower, *upper);
52}
53
54void LookupTable::lookup(std::vector<GlobalVector*> const& x,
55 std::vector<GlobalVector*> const& x_prev,
56 std::size_t const n_nodes) const
57{
58 using EntryInput = std::vector<double>;
59
60 for (std::size_t node_id = 0; node_id < n_nodes; ++node_id)
61 {
62 std::vector<InterpolationPoint> interpolation_points;
63 EntryInput base_entry_input;
64 for (auto const& input_field : input_fields)
65 {
66 // process id and variable id are equivalent in the case the
67 // staggered coupling scheme is adopted.
68 auto const process_id = input_field.variable_id;
69 auto const& variable_name = input_field.name;
70 double input_field_value =
71 variable_name.find("_prev") == std::string::npos
72 ? x[process_id]->get(node_id)
73 : x_prev[process_id]->get(node_id);
74 input_field_value =
75 (std::abs(input_field_value) + input_field_value) / 2;
76 auto bounding_seed_points =
77 input_field.getBoundingSeedPoints(input_field_value);
78
79 InterpolationPoint interpolation_point{bounding_seed_points,
80 input_field_value};
81 interpolation_points.push_back(interpolation_point);
82
83 base_entry_input.push_back(bounding_seed_points.first);
84 }
85
86 auto const base_entry_id = getTableEntryID(base_entry_input);
87
88 // collect bounding entry ids
89 EntryInput bounding_entry_input{base_entry_input};
90 std::vector<std::size_t> bounding_entry_ids;
91 for (std::size_t i = 0; i < interpolation_points.size(); ++i)
92 {
93 bounding_entry_input[i] =
94 interpolation_points[i].bounding_points.second;
95 bounding_entry_ids.push_back(getTableEntryID(bounding_entry_input));
96 bounding_entry_input[i] =
97 interpolation_points[i].bounding_points.first;
98 }
99
100 for (auto const& input_field : input_fields)
101 {
102 if (input_field.name.find("_prev") != std::string::npos)
103 {
104 continue;
105 }
106
107 auto const output_field_name = input_field.name + "_new";
108 auto const base_value =
109 tabular_data.at(output_field_name)[base_entry_id];
110 auto new_value = base_value;
111
112 // linear interpolation
113 for (std::size_t i = 0; i < interpolation_points.size(); ++i)
114 {
115 auto const interpolation_point_value =
116 tabular_data.at(output_field_name)[bounding_entry_ids[i]];
117 auto const slope =
118 (interpolation_point_value - base_value) /
119 (interpolation_points[i].bounding_points.second -
120 interpolation_points[i].bounding_points.first);
121
122 new_value +=
123 slope * (interpolation_points[i].value -
124 interpolation_points[i].bounding_points.first);
125 }
126
127 x[input_field.variable_id]->set(node_id, new_value);
128 }
129 }
130}
131
133 std::vector<double> const& entry_input) const
134{
135 std::vector<std::size_t> intersected_vec =
138 input_fields[0].point_id_groups[BaseLib::findIndex(
139 input_fields[0].seed_points, entry_input[0])];
140
144 for (std::size_t i = 1; i < input_fields.size(); ++i)
145 {
146 std::vector<std::size_t> const vec =
147 input_fields[i].point_id_groups[BaseLib::findIndex(
148 input_fields[i].seed_points, entry_input[i])];
149
150 intersection(intersected_vec, vec);
151 }
152
153 return intersected_vec[0];
154}
155} // namespace ComponentTransport
156} // namespace ProcessLib
#define OGS_FATAL(...)
Definition Error.h:19
void WARN(fmt::format_string< Args... > fmt, Args &&... args)
Definition Logging.h:34
std::size_t findIndex(Container const &container, typename Container::value_type const &element)
Definition Algorithm.h:236
static void intersection(std::vector< std::size_t > &vec1, std::vector< std::size_t > const &vec2)
std::pair< double, double > getBoundingSeedPoints(double const value) const
std::vector< double > const seed_points
Definition LookupTable.h:37
void lookup(std::vector< GlobalVector * > const &x, std::vector< GlobalVector * > const &x_prev, std::size_t const n_nodes) const
std::size_t getTableEntryID(std::vector< double > const &entry_input) const
std::map< std::string, std::vector< double > > const tabular_data
Definition LookupTable.h:59
std::vector< Field > const input_fields
Definition LookupTable.h:58