Loading [MathJax]/jax/output/HTML-CSS/config.js
OGS
MPI.h
Go to the documentation of this file.
1
#pragma once

#include <algorithm>
#include <cstddef>
#include <limits>
#include <span>
#include <type_traits>
#include <vector>

#include "Algorithm.h"
#include "Error.h"

#ifdef USE_PETSC
#include <mpi.h>
#endif
20
namespace BaseLib::MPI
{

#ifdef USE_PETSC
/// Default communicator for the helpers below (default argument of Mpi and
/// anyOf). Defined in MPI.cpp.
extern MPI_Comm OGS_COMM_WORLD;
#endif

/// Scope guard for the MPI environment: calls MPI_Init() on construction and
/// MPI_Finalize() on destruction when built with USE_PETSC; otherwise it does
/// nothing.
///
/// Copying and moving are disabled because a second instance would call
/// MPI_Finalize() a second time.
struct Setup
{
    /// NOTE(review): argc/argv are taken by value, so any change MPI_Init()
    /// makes to argc through the pointer is not visible to the caller.
    Setup(int argc, char* argv[])
    {
#ifdef USE_PETSC
        MPI_Init(&argc, &argv);
#else
        (void)argc;
        (void)argv;
#endif  // USE_PETSC
    }

    Setup(Setup const&) = delete;
    Setup& operator=(Setup const&) = delete;
    Setup(Setup&&) = delete;
    Setup& operator=(Setup&&) = delete;

    ~Setup()
    {
#ifdef USE_PETSC
        MPI_Finalize();
#endif  // USE_PETSC
    }
};

#ifdef USE_PETSC
/// A communicator together with this process' size and rank within it.
/// Construction aborts via OGS_FATAL if MPI has not been initialized.
struct Mpi
{
    Mpi(MPI_Comm const communicator = OGS_COMM_WORLD)
        : communicator(communicator)
    {
        int mpi_init = 0;
        MPI_Initialized(&mpi_init);
        // MPI only guarantees a nonzero ("true") flag, so compare with zero.
        if (mpi_init == 0)
        {
            OGS_FATAL("MPI is not initialized.");
        }
        MPI_Comm_size(communicator, &size);
        MPI_Comm_rank(communicator, &rank);
    }

    MPI_Comm communicator;
    int size = 0;
    int rank = 0;
};

namespace detail
{
/// Converts a container size to the `int` element count MPI expects.
/// Aborts via OGS_FATAL if the size does not fit into an `int`.
inline int toMpiCount(std::size_t const size)
{
    if (size > static_cast<std::size_t>(std::numeric_limits<int>::max()))
    {
        OGS_FATAL("Buffer size exceeds the maximum MPI element count.");
    }
    return static_cast<int>(size);
}
}  // namespace detail

/// Returns the predefined MPI datatype matching T (top-level const ignored).
/// Instantiating it with a type that has no mapping is a compile-time error.
template <typename T>
constexpr MPI_Datatype mpiType()
{
    using U = std::remove_const_t<T>;
    if constexpr (std::is_same_v<U, bool>)
    {
        return MPI_C_BOOL;
    }
    else if constexpr (std::is_same_v<U, char>)
    {
        return MPI_CHAR;
    }
    else if constexpr (std::is_same_v<U, double>)
    {
        return MPI_DOUBLE;
    }
    else if constexpr (std::is_same_v<U, float>)
    {
        return MPI_FLOAT;
    }
    else if constexpr (std::is_same_v<U, int>)
    {
        return MPI_INT;
    }
    else if constexpr (std::is_same_v<U, long>)
    {
        return MPI_LONG;
    }
    else if constexpr (std::is_same_v<U, long long>)
    {
        return MPI_LONG_LONG;
    }
    else if constexpr (std::is_same_v<U, unsigned int>)
    {
        return MPI_UNSIGNED;
    }
    // std::size_t is one of the unsigned fundamental types, so mapping those
    // by type (not by the size_t alias) is correct on LP64 and LLP64 alike.
    else if constexpr (std::is_same_v<U, unsigned long>)
    {
        return MPI_UNSIGNED_LONG;
    }
    else if constexpr (std::is_same_v<U, unsigned long long>)
    {
        return MPI_UNSIGNED_LONG_LONG;
    }
    else
    {
        // Dependent on U, so it only fires when this branch is instantiated.
        static_assert(!std::is_same_v<U, U>,
                      "mpiType: no MPI datatype mapping for this type.");
    }
}

/// Gathers one value from every rank of mpi.communicator.
/// @return vector of mpi.size entries; entry r holds rank r's value.
template <typename T>
static std::vector<T> allgather(T const& value, Mpi const& mpi)
{
    std::vector<T> result(mpi.size);

    MPI_Allgather(&value, 1, mpiType<T>(), result.data(), 1, mpiType<T>(),
                  mpi.communicator);

    return result;
}

/// Gathers a vector from every rank. Every rank must pass a vector of the
/// same length (MPI_Allgather uses a single count for all ranks).
/// @return flat vector of mpi.size * vector.size() entries; rank r's block
///         starts at r * vector.size().
template <typename T>
static std::vector<T> allgather(std::vector<T> const& vector, Mpi const& mpi)
{
    int const count = detail::toMpiCount(vector.size());
    std::vector<T> result(static_cast<std::size_t>(mpi.size) * vector.size());

    MPI_Allgather(vector.data(), count, mpiType<T>(), result.data(), count,
                  mpiType<T>(), mpi.communicator);

    return result;
}

/// Reduces a single value over all ranks with mpi_op (e.g. MPI_SUM) and
/// returns the result on every rank.
template <typename T>
static T allreduce(T const& value, MPI_Op const& mpi_op, Mpi const& mpi)
{
    T result{};

    MPI_Allreduce(&value, &result, 1, mpiType<T>(), mpi_op, mpi.communicator);
    return result;
}

/// Element-wise reduction of a vector over all ranks with mpi_op. Every rank
/// must pass a vector of the same length.
/// @return a new vector with the reduced values (input left unchanged).
template <typename T>
static std::vector<T> allreduce(std::vector<T> const& vector,
                                MPI_Op const& mpi_op, Mpi const& mpi)
{
    int const count = detail::toMpiCount(vector.size());
    std::vector<T> result(vector.size());

    MPI_Allreduce(vector.data(), result.data(), count, mpiType<T>(), mpi_op,
                  mpi.communicator);
    return result;
}

/// In-place variant of the vector allreduce: the reduced values overwrite
/// `vector` (via MPI_IN_PLACE).
template <typename T>
static void allreduceInplace(std::vector<T>& vector,
                             MPI_Op const& mpi_op,
                             Mpi const& mpi)
{
    MPI_Allreduce(MPI_IN_PLACE,
                  vector.data(),
                  detail::toMpiCount(vector.size()),
                  mpiType<T>(),
                  mpi_op,
                  mpi.communicator);
}

/// Gathers variable-length buffers from all ranks into receive_buffer,
/// concatenated in rank order. receive_buffer is resized to the total length.
/// @return the offsets computed by BaseLib::sizesToOffsets from the per-rank
///         sizes; the last entry is the total number of gathered elements.
template <typename T>
static std::vector<int> allgatherv(
    std::span<T> const send_buffer,
    std::vector<std::remove_const_t<T>>& receive_buffer,
    Mpi const& mpi)
{
    // Number of elements this rank contributes.
    int const size = detail::toMpiCount(send_buffer.size());

    // Every rank learns every other rank's contribution size.
    std::vector<int> const sizes = allgather(size, mpi);

    // Displacement of each rank's block in the receive buffer.
    std::vector<int> const offsets = BaseLib::sizesToOffsets(sizes);

    receive_buffer.resize(static_cast<std::size_t>(offsets.back()));

    MPI_Allgatherv(send_buffer.data(), size, mpiType<T>(),
                   receive_buffer.data(), sizes.data(), offsets.data(),
                   mpiType<T>(), mpi.communicator);

    return offsets;
}
#endif

/// Returns true if `val` is true on at least one rank (logical OR over
/// mpi.communicator). Without USE_PETSC this simply returns `val`.
static inline bool anyOf(bool const val
#ifdef USE_PETSC
                         ,
                         Mpi const& mpi = Mpi{OGS_COMM_WORLD}
#endif
)
{
#ifdef USE_PETSC
    return allreduce(val, MPI_LOR, mpi);
#else
    return val;
#endif
}

}  // namespace BaseLib::MPI
#define OGS_FATAL(...)
Definition Error.h:26
static void allreduceInplace(std::vector< T > &vector, MPI_Op const &mpi_op, Mpi const &mpi)
Definition MPI.h:149
static bool anyOf(bool const val, Mpi const &mpi=Mpi{OGS_COMM_WORLD})
Definition MPI.h:191
MPI_Comm OGS_COMM_WORLD
Definition MPI.cpp:15
static std::vector< int > allgatherv(std::span< T > const send_buffer, std::vector< std::remove_const_t< T > > &receive_buffer, Mpi const &mpi)
Definition MPI.h:164
static T allreduce(T const &value, MPI_Op const &mpi_op, Mpi const &mpi)
Definition MPI.h:128
constexpr MPI_Datatype mpiType()
Definition MPI.h:70
static std::vector< T > allgather(T const &value, Mpi const &mpi)
Definition MPI.h:104
std::vector< ranges::range_value_t< R > > sizesToOffsets(R const &sizes)
Definition Algorithm.h:283
MPI_Comm communicator
Definition MPI.h:64
Mpi(MPI_Comm const communicator=OGS_COMM_WORLD)
Definition MPI.h:51
Setup(int argc, char *argv[])
Definition MPI.h:30