Fenix @develop
 
Loading...
Searching...
No Matches
mpi.hpp
1#ifndef FENIX_TASKS_MPI_HPP
2#define FENIX_TASKS_MPI_HPP
3
4#include <type_traits>
5#include <utility>
6#include <vector>
7#include <mpi.h>
8#include "task.hpp"
9
10namespace fenix::util {
11template <typename T>
12MPI_Datatype datatype();
13
14template <typename T>
15MPI_Datatype datatype(T&& t) {
16 return datatype<T>();
17}
18
19template <typename T>
20constexpr int count(T&& t, int in_count);
21}
22
23namespace fenix::tasks::mpi {
24// C++ type corresponding to MPI_Datatype index pairs
25template <typename T>
26struct Indexed {
27 static_assert(std::is_trivially_copyable_v<T>);
28 T value;
29 int index;
30};
31
32using MPITask = Task<Status>;
33
34template <typename T>
35MPITask recv(T* b, int n, MPI_Datatype d, int r, int t, MPI_Comm c) {
36 MPI_Request request;
37 Status ret = MPI_Irecv(b, n, d, r, t, c, &request);
38 if (ret) ret = co_await request;
39 co_return ret;
40}
41template <typename T>
42auto recv(T* b, int n, int r, int t, MPI_Comm c) {
43 return recv(b, util::count(b, n), util::datatype(b), r, t, c);
44}
45template <typename T>
46auto recv(T& b, int r, int t, MPI_Comm c) {
47 return recv(&b, 1, r, t, c);
48}
49template <typename T, typename A>
50auto recv(std::vector<T, A>& v, int r, int t, MPI_Comm c) {
51 return recv(v.data(), v.size(), r, t, c);
52}
53
54template <typename T>
55MPITask send(const T* b, int n, MPI_Datatype d, int r, int t, MPI_Comm c) {
56 MPI_Request request;
57 Status ret = MPI_Isend(b, n, d, r, t, c, &request);
58 if (ret) ret = co_await request;
59 co_return ret;
60}
61template <typename T>
62auto send(const T* b, int n, int r, int t, MPI_Comm c) {
63 return send(b, util::count(b, n), util::datatype(b), r, t, c);
64}
65template <typename T>
66auto send(const T& b, int r, int t, MPI_Comm c) {
67 return send(&b, 1, r, t, c);
68}
69template <typename T, typename A>
70auto send(const std::vector<T>& v, int r, int t, MPI_Comm c) {
71 return send(v.data(), v.size(), r, t, c);
72}
73
74template <typename ST, typename RT>
75MPITask sendrecv(
76 const ST* sb, int sn, MPI_Datatype sd, int sr, int st,
77 RT* rb, int rn, MPI_Datatype rd, int rr, int rt, MPI_Comm c
78) {
79 auto recv_task = recv(rb, rn, rd, rr, rt, c);
80 // ensure lazily-evaluated recv_task actually begins
81 recv_task.resume();
82 co_await send(sb, sn, sd, sr, st, c);
83 co_return co_await recv_task;
84}
85template <typename ST, typename RT>
86auto sendrecv(
87 const ST* sb, int sn, int sr, int st,
88 RT* rb, int rn, int rr, int rt, MPI_Comm c
89) {
90 return sendrecv(
91 sb, util::count(sb, sn), util::datatype(sb), sr, st,
92 rb, util::count(rb, rn), util::datatype(rb), rr, rt, c
93 );
94}
95template <typename ST, typename RT>
96auto sendrecv(
97 const ST& sb, int sr, int st,
98 RT& rb, int rr, int rt, MPI_Comm c
99) {
100 return sendrecv(&sb, 1, sr, st, &rb, 1, rr, rt, c);
101}
102template <typename ST, typename SA, typename RT, typename RA>
103auto sendrecv(
104 const std::vector<ST, SA>& sv, int sr, int st,
105 std::vector<RT, RA>& rv, int rr, int rt, MPI_Comm c
106) {
107 return sendrecv(&sv[0], sv.size(), sr, st, &rv[0], rv.size(), rr, rt, c);
108}
109
110template <typename T>
111MPITask allreduce(
112 const void* sb, T& rb, int n, MPI_Datatype d, MPI_Op o, MPI_Comm c
113) {
114 MPI_Request request;
115 Status ret = MPI_Iallreduce(sb, &rb, n, d, o, c, &request);
116 if (ret) ret = co_await request;
117 co_return ret;
118}
119template <typename T>
120auto allreduce(const T* sb, T& rb, int n, MPI_Op o, MPI_Comm c) {
121 return allreduce(sb, rb, util::count(sb, n), util::datatype(sb), o, c);
122}
123template <typename T>
124auto allreduce(const T& sb, T& rb, MPI_Op o, MPI_Comm c) {
125 return allreduce(&sb, rb, 1, o, c);
126}
127template <typename T, typename A>
128auto allreduce(const std::vector<T, A>& sv, T& rb, MPI_Op o, MPI_Comm c) {
129 return allreduce(&sv[0], rb, sv.size(), o, c);
130}
131
132template <typename T>
133MPITask reduce(
134 const void* sb, T& rb, int n, MPI_Datatype d, MPI_Op o, int r, MPI_Comm c
135) {
136 MPI_Request request;
137 Status ret = MPI_Ireduce(sb, &rb, n, d, o, r, c, &request);
138 if (ret) ret = co_await request;
139 co_return ret;
140}
141template <typename T>
142auto reduce(const T* sb, T& rb, int n, MPI_Op o, int r, MPI_Comm c) {
143 return reduce(sb, rb, util::count(sb, n), util::datatype(sb), o, r, c);
144}
145template <typename T>
146auto reduce(const T& sb, T& rb, MPI_Op o, int r, MPI_Comm c) {
147 return reduce(&sb, rb, 1, o, r, c);
148}
149template <typename T, typename A>
150auto reduce(const std::vector<T, A>& sv, T& rb, MPI_Op o, int r, MPI_Comm c) {
151 return reduce(&sv[0], rb, sv.size(), o, r, c);
152}
153
154template <typename T>
155MPITask bcast(T* b, int n, MPI_Datatype d, int r, MPI_Comm c) {
156 MPI_Request request;
157 Status ret = MPI_Ibcast(b, n, d, r, c, &request);
158 if (ret) ret = co_await request;
159 co_return ret;
160}
161template <typename T>
162auto bcast(T* b, int n, int r, MPI_Comm c) {
163 return bcast(b, util::count(b, n), util::datatype(b), r, c);
164}
165template <typename T>
166auto bcast(T& b, int r, MPI_Comm c) {
167 return bcast(b, 1, r, c);
168}
169template <typename T, typename A>
170auto bcast(std::vector<T, A>& v, int r, MPI_Comm c) {
171 return bcast(&v[0], v.size(), r, c);
172}
173
174inline MPITask probe(int src, int tag, MPI_Comm comm) {
175 int found;
176 Status ret;
177 do {
178 int found;
179 ret = MPI_Iprobe(src, tag, comm, &found, ret);
180 if (found || !ret) co_return ret;
181 co_await std::suspend_always{};
182 } while (true);
183}
184} // namespace fenix::tasks::mpi
185
186namespace fenix::util {
187
188#define MPI_TASK_TYPE(u, r, ...) \
189 if constexpr (std::is_same_v<u, __VA_ARGS__>) return r;
190
191template <typename T>
192MPI_Datatype datatype() {
193 using namespace fenix::tasks::mpi;
194 using U = std::remove_cv_t<std::remove_pointer_t<std::decay_t<T>>>;
195 static_assert(std::is_trivially_copyable_v<U>);
196 // clang-format off
197 MPI_TASK_TYPE(U, MPI_CHAR, char);
198 MPI_TASK_TYPE(U, MPI_FLOAT, float);
199 MPI_TASK_TYPE(U, MPI_DOUBLE, double);
200 MPI_TASK_TYPE(U, MPI_SHORT, short);
201 MPI_TASK_TYPE(U, MPI_UNSIGNED_SHORT, unsigned short);
202 MPI_TASK_TYPE(U, MPI_INT, int);
203 MPI_TASK_TYPE(U, MPI_UNSIGNED, unsigned int);
204 MPI_TASK_TYPE(U, MPI_LONG, long);
205 MPI_TASK_TYPE(U, MPI_UNSIGNED_LONG, unsigned long);
206 MPI_TASK_TYPE(U, MPI_LOGICAL, bool);
207 MPI_TASK_TYPE(U, MPI_FLOAT_INT, Indexed<float>);
208 MPI_TASK_TYPE(U, MPI_DOUBLE_INT, Indexed<double>);
209 MPI_TASK_TYPE(U, MPI_LONG_INT, Indexed<long>);
210 MPI_TASK_TYPE(U, MPI_2INT, Indexed<int>);
211 MPI_TASK_TYPE(U, MPI_SHORT_INT, Indexed<short>);
212 MPI_TASK_TYPE(U, MPI_LONG_DOUBLE_INT, Indexed<long double>);
213 // clang-format on
214
215 // Technically sketch to just make this MPI_BYTE, but only when heterogenenous
216 // so we'll cross that bridge when we get there. Convenient for trivial custom
217 // types for now
218 return MPI_BYTE;
219}
220
221#undef MPI_TASK_TYPE
222
223template <typename T>
224constexpr int count(T&& t, int in_count) {
225 if (datatype<T>() == MPI_BYTE) {
226 return in_count * sizeof(std::remove_pointer_t<std::decay_t<T>>);
227 }
228 return in_count;
229}
230} // namespace fenix::util
231
232#endif // FENIX_TASKS_MPI_HPP
Definition task.hpp:16
Definition mpi_util.hpp:115
Definition mpi.hpp:26