1#ifndef FENIX_TASKS_MPI_HPP
2#define FENIX_TASKS_MPI_HPP
// Forward declarations of the type/count helpers used by the MPI task
// wrappers; their definitions appear in a second `namespace fenix::util`
// block further down. The template headers of these declarations and the body
// of the argument-deducing datatype overload are missing from this listing.
10namespace fenix::util {
// Returns the MPI datatype corresponding to the template argument.
12MPI_Datatype datatype();
// Argument-deducing form, called by the wrappers as util::datatype(ptr);
// presumably forwards to datatype<T>() (body not shown — confirm).
15MPI_Datatype datatype(T&& t) {
// Returns the element count to hand to MPI for `in_count` objects of t's type.
20constexpr int count(T&& t,
int in_count);
// Coroutine wrappers that post nonblocking MPI operations (MPI_I*) and
// co_await the resulting MPI_Request.
23namespace fenix::tasks::mpi {
// Requires T to be trivially copyable. The enclosing template is not shown in
// this listing; presumably objects of T are transferred as raw memory — confirm.
27 static_assert(std::is_trivially_copyable_v<T>);
// Posts a nonblocking receive (MPI_Irecv) of `n` elements of datatype `d` into
// `b` from rank `r` with tag `t` on communicator `c`, then suspends on the
// resulting request. `b` must stay valid until the returned task completes
// (MPI nonblocking-receive semantics).
// NOTE(review): the template header, the declaration of `request` and the
// final co_return are not shown in this listing. Awaiting an MPI_Request relies
// on an awaiter defined elsewhere — confirm.
35MPITask recv(T* b,
int n, MPI_Datatype d,
int r,
int t, MPI_Comm c) {
37 Status ret = MPI_Irecv(b, n, d, r, t, c, &request);
// Wait only if posting succeeded; presumably Status converts to true on
// success — confirm against the Status definition.
38 if (ret) ret =
co_await request;
// Receives `n` objects of type T into `b`. The MPI count and datatype are
// derived from the pointee type via util::count / util::datatype.
42auto recv(T* b,
int n,
int r,
int t, MPI_Comm c) {
43 return recv(b, util::count(b, n), util::datatype(b), r, t, c);
// Receives a single object into `b` (forwards its address with a count of 1).
46auto recv(T& b,
int r,
int t, MPI_Comm c) {
47 return recv(&b, 1, r, t, c);
// Receives v.size() elements into the existing storage of `v`. The vector is
// not resized here, so the caller must size it beforehand. v.size() is
// implicitly narrowed to int.
49template <
typename T,
typename A>
50auto recv(std::vector<T, A>& v,
int r,
int t, MPI_Comm c) {
51 return recv(v.data(), v.size(), r, t, c);
// Posts a nonblocking send (MPI_Isend) of `n` elements of datatype `d` from
// `b` to rank `r` with tag `t` on communicator `c`, then suspends on the
// resulting request. `b` must stay valid and unmodified until the returned
// task completes (MPI nonblocking-send semantics).
// NOTE(review): the template header, the declaration of `request` and the
// final co_return are not shown in this listing.
55MPITask send(
const T* b,
int n, MPI_Datatype d,
int r,
int t, MPI_Comm c) {
57 Status ret = MPI_Isend(b, n, d, r, t, c, &request);
// Wait only if posting succeeded; presumably Status converts to true on
// success — confirm against the Status definition.
58 if (ret) ret =
co_await request;
// Sends `n` objects of type T from `b`. The MPI count and datatype are derived
// from the pointee type via util::count / util::datatype.
62auto send(
const T* b,
int n,
int r,
int t, MPI_Comm c) {
63 return send(b, util::count(b, n), util::datatype(b), r, t, c);
// Sends a single object `b` (forwards its address with a count of 1).
66auto send(
const T& b,
int r,
int t, MPI_Comm c) {
67 return send(&b, 1, r, t, c);
69template <
typename T,
typename A>
70auto send(
const std::vector<T>& v,
int r,
int t, MPI_Comm c) {
71 return send(v.data(), v.size(), r, t, c);
// Combined send/receive: creates the receive task for `rb`, awaits the send of
// `sb`, then awaits and returns the receive task's result.
// NOTE(review): the function-name line and the two lines between creating
// recv_task and awaiting the send are missing from this listing. If MPITask
// does not begin executing until it is awaited, MPI_Irecv would only be posted
// after the send completes, which can deadlock when both peers send large
// messages first — confirm how MPITask is started.
74template <
typename ST,
typename RT>
76 const ST* sb,
int sn, MPI_Datatype sd,
int sr,
int st,
77 RT* rb,
int rn, MPI_Datatype rd,
int rr,
int rt, MPI_Comm c
79 auto recv_task = recv(rb, rn, rd, rr, rt, c);
82 co_await send(sb, sn, sd, sr, st, c);
83 co_return co_await recv_task;
// sendrecv overload taking element counts only: derives the MPI counts and
// datatypes of both buffers via util::count / util::datatype. The
// function-name and call lines are missing from this listing; presumably it
// forwards to the datatype-taking sendrecv — confirm.
85template <
typename ST,
typename RT>
87 const ST* sb,
int sn,
int sr,
int st,
88 RT* rb,
int rn,
int rr,
int rt, MPI_Comm c
91 sb, util::count(sb, sn), util::datatype(sb), sr, st,
92 rb, util::count(rb, rn), util::datatype(rb), rr, rt, c
// Exchanges single objects: sends `sb` to rank `sr` (tag `st`) and receives
// one object into `rb` from rank `rr` (tag `rt`). The function-name line is
// missing from this listing.
95template <
typename ST,
typename RT>
97 const ST& sb,
int sr,
int st,
98 RT& rb,
int rr,
int rt, MPI_Comm c
100 return sendrecv(&sb, 1, sr, st, &rb, 1, rr, rt, c);
102template <
typename ST,
typename SA,
typename RT,
typename RA>
104 const std::vector<ST, SA>& sv,
int sr,
int st,
105 std::vector<RT, RA>& rv,
int rr,
int rt, MPI_Comm c
107 return sendrecv(&sv[0], sv.size(), sr, st, &rv[0], rv.size(), rr, rt, c);
// Nonblocking all-reduce (MPI_Iallreduce) of `n` elements of datatype `d` from
// `sb` with operation `o` on `c`; the result is written starting at `&rb`.
// NOTE(review): `rb` is a single T&, but MPI writes `n` elements at `&rb`, so
// for n > 1 `rb` must be the first element of a buffer holding at least n
// elements. Consider taking a T* instead. The function-name line, `request`
// declaration and co_return are missing from this listing.
112 const void* sb, T& rb,
int n, MPI_Datatype d, MPI_Op o, MPI_Comm c
115 Status ret = MPI_Iallreduce(sb, &rb, n, d, o, c, &request);
// Wait only if posting succeeded (per Status's boolean conversion — confirm).
116 if (ret) ret =
co_await request;
// All-reduces `n` objects of type T from `sb`. The MPI count and datatype are
// derived via util::count / util::datatype; the results land starting at
// `&rb`, so `rb` must head a buffer of at least `n` elements.
120auto allreduce(
const T* sb, T& rb,
int n, MPI_Op o, MPI_Comm c) {
121 return allreduce(sb, rb, util::count(sb, n), util::datatype(sb), o, c);
// All-reduces a single object `sb` into `rb` with operation `o`.
124auto allreduce(
const T& sb, T& rb, MPI_Op o, MPI_Comm c) {
125 return allreduce(&sb, rb, 1, o, c);
127template <
typename T,
typename A>
128auto allreduce(
const std::vector<T, A>& sv, T& rb, MPI_Op o, MPI_Comm c) {
129 return allreduce(&sv[0], rb, sv.size(), o, c);
// Nonblocking reduce (MPI_Ireduce) of `n` elements of datatype `d` from `sb`
// with operation `o` to root rank `r` on `c`. Results are written starting at
// `&rb`; per MPI semantics the receive buffer is significant only at the root.
// NOTE(review): `rb` is a single T&, but `n` elements are written at `&rb`, so
// for n > 1 it must head a buffer of at least n elements. The function-name
// line, `request` declaration and co_return are missing from this listing.
134 const void* sb, T& rb,
int n, MPI_Datatype d, MPI_Op o,
int r, MPI_Comm c
137 Status ret = MPI_Ireduce(sb, &rb, n, d, o, r, c, &request);
// Wait only if posting succeeded (per Status's boolean conversion — confirm).
138 if (ret) ret =
co_await request;
// Reduces `n` objects of type T from `sb` to root `r`. The MPI count and
// datatype are derived via util::count / util::datatype; `rb` must head a
// buffer of at least `n` elements on the root.
142auto reduce(
const T* sb, T& rb,
int n, MPI_Op o,
int r, MPI_Comm c) {
143 return reduce(sb, rb, util::count(sb, n), util::datatype(sb), o, r, c);
// Reduces a single object `sb` into `rb` on root rank `r`.
146auto reduce(
const T& sb, T& rb, MPI_Op o,
int r, MPI_Comm c) {
147 return reduce(&sb, rb, 1, o, r, c);
149template <
typename T,
typename A>
150auto reduce(
const std::vector<T, A>& sv, T& rb, MPI_Op o,
int r, MPI_Comm c) {
151 return reduce(&sv[0], rb, sv.size(), o, r, c);
// Nonblocking broadcast (MPI_Ibcast) of `n` elements of datatype `d` in `b`
// from root rank `r` on `c`, then suspends on the resulting request.
// NOTE(review): the template header, `request` declaration and co_return are
// missing from this listing.
155MPITask bcast(T* b,
int n, MPI_Datatype d,
int r, MPI_Comm c) {
157 Status ret = MPI_Ibcast(b, n, d, r, c, &request);
// Wait only if posting succeeded (per Status's boolean conversion — confirm).
158 if (ret) ret =
co_await request;
// Broadcasts `n` objects of type T in `b` from root `r`. The MPI count and
// datatype are derived via util::count / util::datatype.
162auto bcast(T* b,
int n,
int r, MPI_Comm c) {
163 return bcast(b, util::count(b, n), util::datatype(b), r, c);
166auto bcast(T& b,
int r, MPI_Comm c) {
167 return bcast(b, 1, r, c);
169template <
typename T,
typename A>
170auto bcast(std::vector<T, A>& v,
int r, MPI_Comm c) {
171 return bcast(&v[0], v.size(), r, c);
// Polls MPI_Iprobe for a message matching `src`/`tag` on `comm`. It returns
// `ret` as soon as a message is found or the probe reports failure, and
// otherwise suspends (std::suspend_always) so it can be polled again.
// NOTE(review): the declarations of `ret`/`found` and the enclosing loop
// (presumably an infinite loop around the probe) are missing from this
// listing. `ret` is passed as MPI_Iprobe's MPI_Status* argument, presumably
// through a conversion defined on Status — confirm.
174inline MPITask probe(
int src,
int tag, MPI_Comm comm) {
179 ret = MPI_Iprobe(src, tag, comm, &found, ret);
180 if (found || !ret)
co_return ret;
181 co_await std::suspend_always{};
// Definitions of the fenix::util helpers declared near the top of the header.
// MPI_TASK_TYPE(u, r, T...): used inside datatype() to return MPI datatype `r`
// when `u` is exactly the type given in the variadic arguments (presumably
// variadic so that type names containing commas can be passed). No #undef is
// visible in this listing, so the macro may leak to includers — verify.
186namespace fenix::util {
188#define MPI_TASK_TYPE(u, r, ...) \
189 if constexpr (std::is_same_v<u, __VA_ARGS__>) return r;
// Returns the MPI datatype for T after decaying it, removing one level of
// pointer and then removing cv-qualifiers. U must be trivially copyable.
// Only the first mappings are shown; the rest of the body, including any
// fallback, is missing from this listing. `char`, `signed char` and
// `unsigned char` are distinct types, so only plain `char` matches MPI_CHAR.
192MPI_Datatype datatype() {
193 using namespace fenix::tasks::mpi;
194 using U = std::remove_cv_t<std::remove_pointer_t<std::decay_t<T>>>;
195 static_assert(std::is_trivially_copyable_v<U>);
197 MPI_TASK_TYPE(U, MPI_CHAR,
char);
198 MPI_TASK_TYPE(U, MPI_FLOAT,
float);
199 MPI_TASK_TYPE(U, MPI_DOUBLE,
double);
200 MPI_TASK_TYPE(U, MPI_SHORT,
short);
201 MPI_TASK_TYPE(U, MPI_UNSIGNED_SHORT,
unsigned short);
202 MPI_TASK_TYPE(U, MPI_INT,
int);
203 MPI_TASK_TYPE(U, MPI_UNSIGNED,
unsigned int);
204 MPI_TASK_TYPE(U, MPI_LONG,
long);
205 MPI_TASK_TYPE(U, MPI_UNSIGNED_LONG,
unsigned long);
206 MPI_TASK_TYPE(U, MPI_LOGICAL,
bool);
// Element count to pass to MPI for `in_count` objects of t's type. When
// datatype<T>() returns MPI_BYTE, the count is multiplied by the element size
// (the other branch and the closing lines are missing from this listing).
// NOTE(review): datatype<T>() is not declared constexpr, and in some MPI
// implementations MPI_Datatype handles are pointers, so this comparison can
// never be a constant expression and the `constexpr` on count() is not usable
// at compile time. The size_t product is also narrowed to int on return and
// can overflow for large buffers — consider checking.
224constexpr int count(T&& t,
int in_count) {
225 if (datatype<T>() == MPI_BYTE) {
226 return in_count *
sizeof(std::remove_pointer_t<std::decay_t<T>>);
Definition mpi_util.hpp:115