OGS
MPI.h
Go to the documentation of this file.
1
10#pragma once
11
12#include <algorithm>
13
14#include "Algorithm.h"
15#include "Error.h"
16
17#ifdef USE_PETSC
18#include <mpi.h>
19#endif
20
21namespace BaseLib::MPI
22{
23
/// Scope guard for the MPI runtime.
///
/// With USE_PETSC defined, the constructor calls MPI_Init() and the destructor
/// calls MPI_Finalize(). Without USE_PETSC both are no-ops, so the same calling
/// code works in serial builds.
///
/// NOTE(review): the type is copyable; a copy would call MPI_Finalize() a
/// second time on destruction. Keep exactly one instance alive per process.
struct Setup
{
    /// Initializes MPI with the program's command line arguments.
    ///
    /// \param argc Argument count as received by main().
    /// \param argv Argument vector as received by main().
    Setup(int argc, char* argv[])
    {
#ifdef USE_PETSC
        MPI_Init(&argc, &argv);
#else
        // Serial build: nothing to initialize, only silence unused warnings.
        (void)argc;
        (void)argv;
#endif  // USE_PETSC
    }

    /// Finalizes MPI (only if built with USE_PETSC).
    ~Setup()
    {
#ifdef USE_PETSC
        MPI_Finalize();
#endif  // USE_PETSC
    }
};
43
44#ifdef USE_PETSC
45struct Mpi
46{
47 Mpi(MPI_Comm const communicator = MPI_COMM_WORLD)
49 {
50 int mpi_init;
51 MPI_Initialized(&mpi_init);
52 if (mpi_init != 1)
53 {
54 OGS_FATAL("MPI is not initialized.");
55 }
56 MPI_Comm_size(communicator, &size);
57 MPI_Comm_rank(communicator, &rank);
58 }
59
60 MPI_Comm communicator;
61 int size;
62 int rank;
63};
64
65template <typename T>
66constexpr MPI_Datatype mpiType()
67{
68 using U = std::remove_const_t<T>;
69 if constexpr (std::is_same_v<U, bool>)
70 {
71 return MPI_C_BOOL;
72 }
73 if constexpr (std::is_same_v<U, char>)
74 {
75 return MPI_CHAR;
76 }
77 if constexpr (std::is_same_v<U, double>)
78 {
79 return MPI_DOUBLE;
80 }
81 if constexpr (std::is_same_v<U, float>)
82 {
83 return MPI_FLOAT;
84 }
85 if constexpr (std::is_same_v<U, int>)
86 {
87 return MPI_INT;
88 }
89 if constexpr (std::is_same_v<U, std::size_t>)
90 {
91 return MPI_UNSIGNED_LONG;
92 }
93 if constexpr (std::is_same_v<U, unsigned int>)
94 {
95 return MPI_UNSIGNED;
96 }
97}
98
99template <typename T>
100static std::vector<T> allgather(T const& value, Mpi const& mpi)
101{
102 std::vector<T> result(mpi.size);
103
104 result[mpi.rank] = value;
105
106 MPI_Allgather(&result[mpi.rank], 1, mpiType<T>(), result.data(), 1,
107 mpiType<T>(), mpi.communicator);
108
109 return result;
110}
111
112template <typename T>
113static std::vector<T> allgather(std::vector<T> const& vector, Mpi const& mpi)
114{
115 std::size_t const size = vector.size();
116 // Flat in memory over all ranks;
117 std::vector<T> result(mpi.size * size);
118
119 std::copy_n(vector.begin(), size, &result[mpi.rank * size]);
120
121 MPI_Allgather(&result[mpi.rank * size], size, mpiType<T>(), result.data(),
122 size, mpiType<T>(), mpi.communicator);
123
124 return result;
125}
126
127template <typename T>
128static T allreduce(T const& value, MPI_Op const& mpi_op, Mpi const& mpi)
129{
130 T result{};
131
132 MPI_Allreduce(&value, &result, 1, mpiType<T>(), mpi_op, mpi.communicator);
133 return result;
134}
135
136template <typename T>
137static std::vector<T> allreduce(std::vector<T> const& vector,
138 MPI_Op const& mpi_op, Mpi const& mpi)
139{
140 std::size_t const size = vector.size();
141 std::vector<T> result(vector.size());
142
143 MPI_Allreduce(vector.data(), result.data(), size, mpiType<T>(), mpi_op,
144 mpi.communicator);
145 return result;
146}
147
148template <typename T>
149static void allreduceInplace(std::vector<T>& vector,
150 MPI_Op const& mpi_op,
151 Mpi const& mpi)
152{
153 MPI_Allreduce(MPI_IN_PLACE,
154 vector.data(),
155 vector.size(),
156 mpiType<T>(),
157 mpi_op,
158 mpi.communicator);
159}
160
163template <typename T>
164static std::vector<int> allgatherv(
165 std::span<T> const send_buffer,
166 std::vector<std::remove_const_t<T>>& receive_buffer,
167 Mpi const& mpi)
168{
169 // Determine the number of elements to send
170 int const size = static_cast<int>(send_buffer.size());
171
172 // Gather sizes from all ranks
173 std::vector<int> const sizes = allgather(size, mpi);
174
175 // Compute offsets based on counts
176 std::vector<int> const offsets = BaseLib::sizesToOffsets(sizes);
177
178 // Resize receive buffer to hold all gathered data
179 receive_buffer.resize(offsets.back());
180
181 MPI_Allgatherv(send_buffer.data(), size, mpiType<T>(),
182 receive_buffer.data(), sizes.data(), offsets.data(),
183 mpiType<T>(), mpi.communicator);
184
185 return offsets;
186}
187#endif
188
/// Returns the minimum of \c val over all ranks of MPI_COMM_WORLD.
///
/// In builds without USE_PETSC there is no MPI reduction and \c val is
/// returned unchanged.
///
/// \param val The calling rank's value.
/// \return The reduced minimum, or \c val itself in builds without
///         USE_PETSC.
static inline int reduceMin(int const val)
{
#ifndef USE_PETSC
    return val;
#else
    Mpi const world{MPI_COMM_WORLD};
    return allreduce(val, MPI_MIN, world);
#endif
}
199
200} // namespace BaseLib::MPI
#define OGS_FATAL(...)
Definition Error.h:26
static int reduceMin(int const val)
Definition MPI.h:191
static void allreduceInplace(std::vector< T > &vector, MPI_Op const &mpi_op, Mpi const &mpi)
Definition MPI.h:149
static std::vector< int > allgatherv(std::span< T > const send_buffer, std::vector< std::remove_const_t< T > > &receive_buffer, Mpi const &mpi)
Definition MPI.h:164
static T allreduce(T const &value, MPI_Op const &mpi_op, Mpi const &mpi)
Definition MPI.h:128
constexpr MPI_Datatype mpiType()
Definition MPI.h:66
static std::vector< T > allgather(T const &value, Mpi const &mpi)
Definition MPI.h:100
std::vector< ranges::range_value_t< R > > sizesToOffsets(R const &sizes)
Definition Algorithm.h:283
Mpi(MPI_Comm const communicator=MPI_COMM_WORLD)
Definition MPI.h:47
MPI_Comm communicator
Definition MPI.h:60
Setup(int argc, char *argv[])
Definition MPI.h:26