OGS
MPI.h
Go to the documentation of this file.
1// SPDX-FileCopyrightText: Copyright (c) OpenGeoSys Community (opengeosys.org)
2// SPDX-License-Identifier: BSD-3-Clause
3
4#pragma once
5
#include <algorithm>
#include <concepts>
#include <cstddef>
#include <exception>
#include <limits>
#include <span>
#include <type_traits>
#include <typeinfo>
#include <utility>
#include <vector>

#include "Algorithm.h"
#include "DemangleTypeInfo.h"
#include "Error.h"
12
13#ifdef USE_PETSC
14#include <mpi.h>
15#endif
16
17namespace BaseLib::MPI
18{
19
20#ifdef USE_PETSC
21extern MPI_Comm OGS_COMM_WORLD;
22#endif
23
/// Scope object taking the program's command line arguments.
///
/// Only the declarations are visible in this header; constructor and
/// destructor are defined out of line. The type is declared regardless of
/// whether USE_PETSC is defined.
///
/// NOTE(review): presumably the constructor initializes and the destructor
/// finalizes the MPI environment -- confirm against the definitions in
/// MPI.cpp.
struct Setup
{
    /// \param argc presumably the argument count as received by main().
    /// \param argv presumably the argument vector as received by main().
    Setup(int argc, char* argv[]);

    ~Setup();
};
30
31#ifdef USE_PETSC
32struct Mpi
33{
36 {
37 int mpi_init;
38 MPI_Initialized(&mpi_init);
39 if (mpi_init != 1)
40 {
41 OGS_FATAL("MPI is not initialized.");
42 }
43 MPI_Comm_size(communicator, &size);
44 MPI_Comm_rank(communicator, &rank);
45 }
46
47 MPI_Comm communicator;
48 int size;
49 int rank;
50};
51
52template <typename T>
53constexpr MPI_Datatype mpiType()
54{
55 using U = std::remove_const_t<T>;
56 if constexpr (std::is_same_v<U, bool>)
57 {
58 return MPI_C_BOOL;
59 }
60 if constexpr (std::is_same_v<U, char>)
61 {
62 return MPI_CHAR;
63 }
64 if constexpr (std::is_same_v<U, double>)
65 {
66 return MPI_DOUBLE;
67 }
68 if constexpr (std::is_same_v<U, float>)
69 {
70 return MPI_FLOAT;
71 }
72 if constexpr (std::is_same_v<U, int>)
73 {
74 return MPI_INT;
75 }
76 if constexpr (std::is_same_v<U, std::size_t>)
77 {
78 return MPI_UNSIGNED_LONG;
79 }
80 if constexpr (std::is_same_v<U, unsigned int>)
81 {
82 return MPI_UNSIGNED;
83 }
84}
85
86template <typename T>
87static std::vector<T> allgather(T const& value, Mpi const& mpi)
88{
89 std::vector<T> result(mpi.size);
90
91 MPI_Allgather(&value, 1, mpiType<T>(), result.data(), 1, mpiType<T>(),
92 mpi.communicator);
93
94 return result;
95}
96
97template <typename T>
98static std::vector<T> allgather(std::vector<T> const& vector, Mpi const& mpi)
99{
100 std::size_t const size = vector.size();
101 // Flat in memory over all ranks;
102 std::vector<T> result(mpi.size * size);
103
104 MPI_Allgather(vector.data(), size, mpiType<T>(), result.data(), size,
105 mpiType<T>(), mpi.communicator);
106
107 return result;
108}
109
110template <typename T>
111static T allreduce(T const& value, MPI_Op const& mpi_op, Mpi const& mpi)
112{
113 T result{};
114
115 MPI_Allreduce(&value, &result, 1, mpiType<T>(), mpi_op, mpi.communicator);
116 return result;
117}
118
119template <typename T>
120static std::vector<T> allreduce(std::vector<T> const& vector,
121 MPI_Op const& mpi_op, Mpi const& mpi)
122{
123 std::size_t const size = vector.size();
124 std::vector<T> result(vector.size());
125
126 MPI_Allreduce(vector.data(), result.data(), size, mpiType<T>(), mpi_op,
127 mpi.communicator);
128 return result;
129}
130
131template <typename T>
132static void allreduceInplace(std::vector<T>& vector,
133 MPI_Op const& mpi_op,
134 Mpi const& mpi)
135{
136 MPI_Allreduce(MPI_IN_PLACE,
137 vector.data(),
138 vector.size(),
139 mpiType<T>(),
140 mpi_op,
141 mpi.communicator);
142}
143
146template <typename T>
147static std::vector<int> allgatherv(
148 std::span<T> const send_buffer,
149 std::vector<std::remove_const_t<T>>& receive_buffer,
150 Mpi const& mpi)
151{
152 // Determine the number of elements to send
153 int const size = static_cast<int>(send_buffer.size());
154
155 // Gather sizes from all ranks
156 std::vector<int> const sizes = allgather(size, mpi);
157
158 // Compute offsets based on counts
159 std::vector<int> const offsets = BaseLib::sizesToOffsets(sizes);
160
161 // Resize receive buffer to hold all gathered data
162 receive_buffer.resize(offsets.back());
163
164 MPI_Allgatherv(send_buffer.data(), size, mpiType<T>(),
165 receive_buffer.data(), sizes.data(), offsets.data(),
166 mpiType<T>(), mpi.communicator);
167
168 return offsets;
169}
170#endif
171
#ifdef USE_PETSC
/// Tells whether \c val is true on at least one rank.
///
/// Combines the values of all ranks of \c mpi.communicator with a logical OR
/// (MPI_LOR) via allreduce().
///
/// \param val the value of the calling rank.
/// \param mpi communicator information; defaults to OGS_COMM_WORLD.
static inline bool anyOf(bool const val, Mpi const& mpi = Mpi{OGS_COMM_WORLD})
{
    return allreduce(val, MPI_LOR, mpi);
}
#else
/// Tells whether \c val is true on at least one rank.
///
/// Without USE_PETSC no communication takes place and \c val is returned
/// unchanged.
///
/// \param val the value of the calling process.
static inline bool anyOf(bool const val)
{
    return val;
}
#endif
187
190static inline bool allOf(bool const val
191#ifdef USE_PETSC
192 ,
193 Mpi const& mpi = Mpi{OGS_COMM_WORLD}
194#endif
195)
196{
197 return !anyOf(!val
198#ifdef USE_PETSC
199 ,
200 mpi
201#endif
202 );
203}
204
211template <typename BaseException>
212 requires std::derived_from<BaseException, std::exception> &&
213 ( // The used ctor excludes std::exception itself
214 !std::same_as<BaseException, std::exception>)
215class AnotherMPIRankThrew : public BaseException
216{
217public:
218 using BaseException::BaseException;
219
221 : BaseException{"Another MPI rank threw an exception."}
222 {
223 }
224};
225
235template <typename Exception>
236void allRanksThrowOrNone([[maybe_unused]] std::exception_ptr const& exception,
237 auto&& warning_callback)
238{
239 // std::exception would lead to duplicate catch clauses below and would not
240 // work together with the current implementation of AnotherMPIRankThrew.
241 static_assert(!std::is_same_v<Exception, std::exception>);
242
243 bool const exception_was_thrown = anyOf(exception != nullptr);
244
245 [[unlikely]] if (exception_was_thrown)
246 {
247 if (exception)
248 {
249 try
250 {
251 std::rethrow_exception(exception);
252 }
253 catch (Exception const&)
254 {
255 // OK. Argument exception is derived from class Exception.
256 throw;
257 }
258 catch (std::exception const& e)
259 {
260 warning_callback(
261 "An exception was thrown on this MPI rank, but it's not "
262 "derived from {}, but rather of type {}",
265 typeid(e).name() /* demangle the runtime type of e */));
266 throw;
267 }
268 catch (...)
269 {
270 warning_callback(
271 "An exception was thrown on this MPI rank, but it's not "
272 "derived from std::exception.");
273 throw;
274 }
275 }
276
278 }
279}
280
282template <typename Exception>
283void allRanksThrowOrNone([[maybe_unused]] std::exception_ptr const& exception)
284{
285 auto warning_callback =
286 []<typename... Args>(fmt::format_string<Args...> fmt, Args&&... args)
287 { WARN(fmt, std::forward<Args>(args)...); };
288
289 allRanksThrowOrNone<Exception>(exception, warning_callback);
290}
291
292} // namespace BaseLib::MPI
#define OGS_FATAL(...)
Definition Error.h:19
void WARN(fmt::format_string< Args... > fmt, Args &&... args)
Definition Logging.h:34
static void allreduceInplace(std::vector< T > &vector, MPI_Op const &mpi_op, Mpi const &mpi)
Definition MPI.h:132
static bool anyOf(bool const val, Mpi const &mpi=Mpi{OGS_COMM_WORLD})
Definition MPI.h:174
MPI_Comm OGS_COMM_WORLD
Definition MPI.cpp:9
static std::vector< int > allgatherv(std::span< T > const send_buffer, std::vector< std::remove_const_t< T > > &receive_buffer, Mpi const &mpi)
Definition MPI.h:147
static bool allOf(bool const val, Mpi const &mpi=Mpi{OGS_COMM_WORLD})
Definition MPI.h:190
static T allreduce(T const &value, MPI_Op const &mpi_op, Mpi const &mpi)
Definition MPI.h:111
void allRanksThrowOrNone(std::exception_ptr const &exception, auto &&warning_callback)
Definition MPI.h:236
constexpr MPI_Datatype mpiType()
Definition MPI.h:53
static std::vector< T > allgather(T const &value, Mpi const &mpi)
Definition MPI.h:87
std::string demangle(const char *mangled_name)
std::string typeToString()
std::vector< ranges::range_value_t< R > > sizesToOffsets(R const &sizes)
Definition Algorithm.h:276
Definition AABB.h:277
MPI_Comm communicator
Definition MPI.h:47
Mpi(MPI_Comm const communicator=OGS_COMM_WORLD)
Definition MPI.h:34
Setup(int argc, char *argv[])
Definition MPI.cpp:15