OGS
MPI.h
Go to the documentation of this file.
1
10#pragma once
11
#include <algorithm>
#include <cstddef>
#include <limits>
#include <span>
#include <type_traits>
#include <vector>

#include "Algorithm.h"
#include "Error.h"

#ifdef USE_PETSC
#include <mpi.h>
#endif
20
21namespace BaseLib::MPI
22{
23
/// Scope guard around MPI_Init() / MPI_Finalize().
///
/// With USE_PETSC defined, constructing a Setup initializes MPI and destroying
/// it finalizes MPI. Without USE_PETSC, both are no-ops. MPI may be
/// initialized and finalized only once per process, so create a single
/// instance.
///
/// Copying and moving are disabled. A second object owning the same MPI
/// session would call MPI_Finalize() a second time.
///
/// NOTE(review): argc/argv are taken by value. Any MPI-specific arguments
/// MPI_Init() strips are therefore not removed from the caller's copies.
struct Setup
{
    Setup(int argc, char* argv[])
    {
#ifdef USE_PETSC
        MPI_Init(&argc, &argv);
#else
        (void)argc;
        (void)argv;
#endif  // USE_PETSC
    }

    Setup(Setup const&) = delete;
    Setup(Setup&&) = delete;
    Setup& operator=(Setup const&) = delete;
    Setup& operator=(Setup&&) = delete;

    ~Setup()
    {
#ifdef USE_PETSC
        MPI_Finalize();
#endif  // USE_PETSC
    }
};
43
44#ifdef USE_PETSC
45struct Mpi
46{
47 Mpi(MPI_Comm const communicator = MPI_COMM_WORLD)
49 {
50 int mpi_init;
51 MPI_Initialized(&mpi_init);
52 if (mpi_init != 1)
53 {
54 OGS_FATAL("MPI is not initialized.");
55 }
56 MPI_Comm_size(communicator, &size);
57 MPI_Comm_rank(communicator, &rank);
58 }
59
60 MPI_Comm communicator;
61 int size;
62 int rank;
63};
64
65template <typename T>
66constexpr MPI_Datatype mpiType()
67{
68 using U = std::remove_const_t<T>;
69 if constexpr (std::is_same_v<U, bool>)
70 {
71 return MPI_C_BOOL;
72 }
73 if constexpr (std::is_same_v<U, char>)
74 {
75 return MPI_CHAR;
76 }
77 if constexpr (std::is_same_v<U, double>)
78 {
79 return MPI_DOUBLE;
80 }
81 if constexpr (std::is_same_v<U, float>)
82 {
83 return MPI_FLOAT;
84 }
85 if constexpr (std::is_same_v<U, int>)
86 {
87 return MPI_INT;
88 }
89 if constexpr (std::is_same_v<U, std::size_t>)
90 {
91 return MPI_UNSIGNED_LONG;
92 }
93 if constexpr (std::is_same_v<U, unsigned int>)
94 {
95 return MPI_UNSIGNED;
96 }
97}
98
99template <typename T>
100static std::vector<T> allgather(T const& value, Mpi const& mpi)
101{
102 std::vector<T> result(mpi.size);
103
104 MPI_Allgather(&value, 1, mpiType<T>(), result.data(), 1, mpiType<T>(),
105 mpi.communicator);
106
107 return result;
108}
109
110template <typename T>
111static std::vector<T> allgather(std::vector<T> const& vector, Mpi const& mpi)
112{
113 std::size_t const size = vector.size();
114 // Flat in memory over all ranks;
115 std::vector<T> result(mpi.size * size);
116
117 MPI_Allgather(vector.data(), size, mpiType<T>(), result.data(), size,
118 mpiType<T>(), mpi.communicator);
119
120 return result;
121}
122
123template <typename T>
124static T allreduce(T const& value, MPI_Op const& mpi_op, Mpi const& mpi)
125{
126 T result{};
127
128 MPI_Allreduce(&value, &result, 1, mpiType<T>(), mpi_op, mpi.communicator);
129 return result;
130}
131
132template <typename T>
133static std::vector<T> allreduce(std::vector<T> const& vector,
134 MPI_Op const& mpi_op, Mpi const& mpi)
135{
136 std::size_t const size = vector.size();
137 std::vector<T> result(vector.size());
138
139 MPI_Allreduce(vector.data(), result.data(), size, mpiType<T>(), mpi_op,
140 mpi.communicator);
141 return result;
142}
143
144template <typename T>
145static void allreduceInplace(std::vector<T>& vector,
146 MPI_Op const& mpi_op,
147 Mpi const& mpi)
148{
149 MPI_Allreduce(MPI_IN_PLACE,
150 vector.data(),
151 vector.size(),
152 mpiType<T>(),
153 mpi_op,
154 mpi.communicator);
155}
156
159template <typename T>
160static std::vector<int> allgatherv(
161 std::span<T> const send_buffer,
162 std::vector<std::remove_const_t<T>>& receive_buffer,
163 Mpi const& mpi)
164{
165 // Determine the number of elements to send
166 int const size = static_cast<int>(send_buffer.size());
167
168 // Gather sizes from all ranks
169 std::vector<int> const sizes = allgather(size, mpi);
170
171 // Compute offsets based on counts
172 std::vector<int> const offsets = BaseLib::sizesToOffsets(sizes);
173
174 // Resize receive buffer to hold all gathered data
175 receive_buffer.resize(offsets.back());
176
177 MPI_Allgatherv(send_buffer.data(), size, mpiType<T>(),
178 receive_buffer.data(), sizes.data(), offsets.data(),
179 mpiType<T>(), mpi.communicator);
180
181 return offsets;
182}
183#endif
184
/// Returns the minimum of \p val over all ranks of MPI_COMM_WORLD.
///
/// Without USE_PETSC (no MPI), the value is returned unchanged.
static inline int reduceMin(int const val)
{
#ifndef USE_PETSC
    return val;
#else
    Mpi const world{MPI_COMM_WORLD};
    return allreduce(val, MPI_MIN, world);
#endif
}
195
196} // namespace BaseLib::MPI
#define OGS_FATAL(...)
Definition Error.h:26
static int reduceMin(int const val)
Definition MPI.h:187
static void allreduceInplace(std::vector< T > &vector, MPI_Op const &mpi_op, Mpi const &mpi)
Definition MPI.h:145
static std::vector< int > allgatherv(std::span< T > const send_buffer, std::vector< std::remove_const_t< T > > &receive_buffer, Mpi const &mpi)
Definition MPI.h:160
static T allreduce(T const &value, MPI_Op const &mpi_op, Mpi const &mpi)
Definition MPI.h:124
constexpr MPI_Datatype mpiType()
Definition MPI.h:66
static std::vector< T > allgather(T const &value, Mpi const &mpi)
Definition MPI.h:100
std::vector< ranges::range_value_t< R > > sizesToOffsets(R const &sizes)
Definition Algorithm.h:283
Mpi(MPI_Comm const communicator=MPI_COMM_WORLD)
Definition MPI.h:47
MPI_Comm communicator
Definition MPI.h:60
Setup(int argc, char *argv[])
Definition MPI.h:26