OGS
MPI.h
Go to the documentation of this file.
1// SPDX-FileCopyrightText: Copyright (c) OpenGeoSys Community (opengeosys.org)
2// SPDX-License-Identifier: BSD-3-Clause
3
4#pragma once
5
6#include <algorithm>
7
8#include "Algorithm.h"
9#include "Error.h"
10
11#ifdef USE_PETSC
12#include <mpi.h>
13#endif
14
15namespace BaseLib::MPI
16{
17
18#ifdef USE_PETSC
19extern MPI_Comm OGS_COMM_WORLD;
20#endif
21
/// Scope guard for the MPI runtime.
///
/// When built with USE_PETSC the constructor calls MPI_Init and the
/// destructor calls MPI_Finalize. Without USE_PETSC both are no-ops, so the
/// same calling code works for serial builds.
struct Setup
{
    /// \param argc argument count, forwarded to MPI_Init (ignored without
    ///             USE_PETSC).
    /// \param argv argument vector, forwarded to MPI_Init (ignored without
    ///             USE_PETSC).
    Setup(int argc, char* argv[])
    {
#ifdef USE_PETSC
        MPI_Init(&argc, &argv);
#else
        (void)argc;
        (void)argv;
#endif  // USE_PETSC
    }

    // A copy would run the destructor a second time and therefore call
    // MPI_Finalize twice; forbid copying (this also suppresses moves).
    Setup(Setup const&) = delete;
    Setup& operator=(Setup const&) = delete;

    ~Setup()
    {
#ifdef USE_PETSC
        MPI_Finalize();
#endif  // USE_PETSC
    }
};
41
42#ifdef USE_PETSC
43struct Mpi
44{
47 {
48 int mpi_init;
49 MPI_Initialized(&mpi_init);
50 if (mpi_init != 1)
51 {
52 OGS_FATAL("MPI is not initialized.");
53 }
54 MPI_Comm_size(communicator, &size);
55 MPI_Comm_rank(communicator, &rank);
56 }
57
58 MPI_Comm communicator;
59 int size;
60 int rank;
61};
62
63template <typename T>
64constexpr MPI_Datatype mpiType()
65{
66 using U = std::remove_const_t<T>;
67 if constexpr (std::is_same_v<U, bool>)
68 {
69 return MPI_C_BOOL;
70 }
71 if constexpr (std::is_same_v<U, char>)
72 {
73 return MPI_CHAR;
74 }
75 if constexpr (std::is_same_v<U, double>)
76 {
77 return MPI_DOUBLE;
78 }
79 if constexpr (std::is_same_v<U, float>)
80 {
81 return MPI_FLOAT;
82 }
83 if constexpr (std::is_same_v<U, int>)
84 {
85 return MPI_INT;
86 }
87 if constexpr (std::is_same_v<U, std::size_t>)
88 {
89 return MPI_UNSIGNED_LONG;
90 }
91 if constexpr (std::is_same_v<U, unsigned int>)
92 {
93 return MPI_UNSIGNED;
94 }
95}
96
97template <typename T>
98static std::vector<T> allgather(T const& value, Mpi const& mpi)
99{
100 std::vector<T> result(mpi.size);
101
102 MPI_Allgather(&value, 1, mpiType<T>(), result.data(), 1, mpiType<T>(),
103 mpi.communicator);
104
105 return result;
106}
107
108template <typename T>
109static std::vector<T> allgather(std::vector<T> const& vector, Mpi const& mpi)
110{
111 std::size_t const size = vector.size();
112 // Flat in memory over all ranks;
113 std::vector<T> result(mpi.size * size);
114
115 MPI_Allgather(vector.data(), size, mpiType<T>(), result.data(), size,
116 mpiType<T>(), mpi.communicator);
117
118 return result;
119}
120
121template <typename T>
122static T allreduce(T const& value, MPI_Op const& mpi_op, Mpi const& mpi)
123{
124 T result{};
125
126 MPI_Allreduce(&value, &result, 1, mpiType<T>(), mpi_op, mpi.communicator);
127 return result;
128}
129
130template <typename T>
131static std::vector<T> allreduce(std::vector<T> const& vector,
132 MPI_Op const& mpi_op, Mpi const& mpi)
133{
134 std::size_t const size = vector.size();
135 std::vector<T> result(vector.size());
136
137 MPI_Allreduce(vector.data(), result.data(), size, mpiType<T>(), mpi_op,
138 mpi.communicator);
139 return result;
140}
141
142template <typename T>
143static void allreduceInplace(std::vector<T>& vector,
144 MPI_Op const& mpi_op,
145 Mpi const& mpi)
146{
147 MPI_Allreduce(MPI_IN_PLACE,
148 vector.data(),
149 vector.size(),
150 mpiType<T>(),
151 mpi_op,
152 mpi.communicator);
153}
154
157template <typename T>
158static std::vector<int> allgatherv(
159 std::span<T> const send_buffer,
160 std::vector<std::remove_const_t<T>>& receive_buffer,
161 Mpi const& mpi)
162{
163 // Determine the number of elements to send
164 int const size = static_cast<int>(send_buffer.size());
165
166 // Gather sizes from all ranks
167 std::vector<int> const sizes = allgather(size, mpi);
168
169 // Compute offsets based on counts
170 std::vector<int> const offsets = BaseLib::sizesToOffsets(sizes);
171
172 // Resize receive buffer to hold all gathered data
173 receive_buffer.resize(offsets.back());
174
175 MPI_Allgatherv(send_buffer.data(), size, mpiType<T>(),
176 receive_buffer.data(), sizes.data(), offsets.data(),
177 mpiType<T>(), mpi.communicator);
178
179 return offsets;
180}
181#endif
182
/// Tells whether \c val is true on at least one process.
///
/// With USE_PETSC the flag is combined over the communicator using
/// allreduce() with MPI_LOR; without USE_PETSC there is only the local
/// value, which is returned unchanged.
#ifdef USE_PETSC
static inline bool anyOf(bool const val,
                         Mpi const& mpi = Mpi{OGS_COMM_WORLD})
{
    return allreduce(val, MPI_LOR, mpi);
}
#else
static inline bool anyOf(bool const val)
{
    return val;
}
#endif
198
199} // namespace BaseLib::MPI
#define OGS_FATAL(...)
Definition Error.h:19
static void allreduceInplace(std::vector< T > &vector, MPI_Op const &mpi_op, Mpi const &mpi)
Definition MPI.h:143
static bool anyOf(bool const val, Mpi const &mpi=Mpi{OGS_COMM_WORLD})
Definition MPI.h:185
MPI_Comm OGS_COMM_WORLD
Definition MPI.cpp:9
static std::vector< int > allgatherv(std::span< T > const send_buffer, std::vector< std::remove_const_t< T > > &receive_buffer, Mpi const &mpi)
Definition MPI.h:158
static T allreduce(T const &value, MPI_Op const &mpi_op, Mpi const &mpi)
Definition MPI.h:122
constexpr MPI_Datatype mpiType()
Definition MPI.h:64
static std::vector< T > allgather(T const &value, Mpi const &mpi)
Definition MPI.h:98
std::vector< ranges::range_value_t< R > > sizesToOffsets(R const &sizes)
Definition Algorithm.h:276
MPI_Comm communicator
Definition MPI.h:58
Mpi(MPI_Comm const communicator=OGS_COMM_WORLD)
Definition MPI.h:45
Setup(int argc, char *argv[])
Definition MPI.h:24