#include "hcoll/api/hcoll_dte.h"
#include "mpiimpl.h"
#include "hcoll_dtypes.h"
/* Defined elsewhere. The commit hook in this file calls hcoll_initialize()
 * while this flag is still 0. */
extern int hcoll_initialized;
/* Defined elsewhere. When 0, the commit and free hooks in this file return
 * MPI_SUCCESS without touching the datatype. */
extern int hcoll_enable;
/*
 * Translate a predefined MPI pair datatype into the corresponding hcoll DTE
 * representation.
 *
 * datatype: MPI datatype handle to look up.
 *
 * Returns the matching DTE_* constant, or DTE_ZERO when the handle is not one
 * of the pair types listed here (or when a Fortran pair type reports a size
 * other than the two handled below).
 *
 * The Fortran pair types are only considered when HAVE_FORTRAN_BINDING is
 * defined and HCOLL_API >= HCOLL_VERSION(3,7).  The size scratch variable is
 * declared inside those cases so that builds without them do not carry an
 * unused local.
 */
static dte_data_representation_t mpi_predefined_derived_2_hcoll(MPI_Datatype datatype)
{
    switch (datatype) {
        case MPI_FLOAT_INT:
            return DTE_FLOAT_INT;
        case MPI_DOUBLE_INT:
            return DTE_DOUBLE_INT;
        case MPI_LONG_INT:
            return DTE_LONG_INT;
        case MPI_SHORT_INT:
            return DTE_SHORT_INT;
        case MPI_LONG_DOUBLE_INT:
            return DTE_LONG_DOUBLE_INT;
        case MPI_2INT:
            return DTE_2INT;
#ifdef HAVE_FORTRAN_BINDING
#if HCOLL_API >= HCOLL_VERSION(3,7)
        /* NOTE(review): if MPIR_Datatype_get_size_macro reports the size of
         * the whole pair rather than of a single element, the 4/8 cases below
         * look off by a factor of two -- confirm before relying on them. */
        case MPI_2INTEGER: {
            MPI_Aint size;

            MPIR_Datatype_get_size_macro(datatype, size);
            switch (size) {
                case 4:
                    return DTE_2INT;
                case 8:
                    return DTE_2INT64;
                default:
                    return DTE_ZERO;
            }
        }
        /* Both Fortran floating-point pair types use the same size-based
         * selection, so they share one body. */
        case MPI_2REAL:
        case MPI_2DOUBLE_PRECISION: {
            MPI_Aint size;

            MPIR_Datatype_get_size_macro(datatype, size);
            switch (size) {
                case 4:
                    return DTE_2FLOAT32;
                case 8:
                    return DTE_2FLOAT64;
                default:
                    return DTE_ZERO;
            }
        }
#endif
#endif
        default:
            break;
    }
    return DTE_ZERO;
}
/*
 * Resolve the hcoll DTE representation for an MPI datatype.
 *
 * datatype: MPI datatype handle.
 * count:    not used by this function.
 * mode:     when equal to TRY_FIND_DERIVED (and HCOLL_API >= HCOLL_VERSION(3,6)),
 *           non-builtin handles are looked up as well.
 *
 * Builtin handles are converted with mpi_dtype_2_dte_dtype().  For other
 * handles in TRY_FIND_DERIVED mode, the predefined pair types are tried
 * first and, failing that, the representation stored on the datatype object
 * (dev.hcoll_datatype) is returned.  In every remaining case the result is
 * DTE_ZERO.
 */
dte_data_representation_t mpi_dtype_2_hcoll_dtype(MPI_Datatype datatype, int count, const int mode)
{
    /* Built-in type: direct conversion, nothing else to consider. */
    if (HANDLE_GET_KIND(datatype) == HANDLE_KIND_BUILTIN) {
        return mpi_dtype_2_dte_dtype(datatype);
    }

    dte_data_representation_t result = DTE_ZERO;

#if HCOLL_API >= HCOLL_VERSION(3,6)
    if (mode == TRY_FIND_DERIVED) {
        /* Predefined pair types take precedence. */
        result = mpi_predefined_derived_2_hcoll(datatype);
        if (HCOL_DTE_IS_ZERO(result)) {
            /* Not predefined: use the mapping stored on the datatype object. */
            MPIR_Datatype *type_obj;

            MPIR_Datatype_get_ptr(datatype, type_obj);
            result = (dte_data_representation_t) type_obj->dev.hcoll_datatype;
        }
    }
#endif

    /* We always fall back, don't even think about forcing it! */
    /* XXX Fix me
     * if (HCOL_DTE_IS_ZERO(dte_data_rep) && TRY_FIND_DERIVED == mode
     *     && !mca_coll_hcoll_component.hcoll_datatype_fallback) {
     *     dte_data_rep = DTE_ZERO;
     *     dte_data_rep.rep.in_line_rep.data_handle.in_line.in_line = 0;
     *     dte_data_rep.rep.in_line_rep.data_handle.pointer_to_handle = (uint64_t) &datatype;
     * }
     */
    return result;
}
/*
 * Datatype commit hook: attach an hcoll representation to a datatype.
 * Per the original note, this will only get called once.
 *
 * dtype_p: datatype object being committed; dev.hcoll_datatype is filled in.
 *
 * Returns MPI_SUCCESS, or MPI_ERR_OTHER when hcoll initialization or
 * hcoll_create_mpi_type() fails.  Does nothing beyond the (lazy)
 * initialization when hcoll is disabled.
 */
int hcoll_type_commit_hook(MPIR_Datatype * dtype_p)
{
    int status;

    /* Bring hcoll up on first use. */
    if (!hcoll_initialized) {
        if (hcoll_initialize()) {
            return MPI_ERR_OTHER;
        }
    }

    if (!hcoll_enable) {
        return MPI_SUCCESS;
    }

    /* Predefined pair types map straight to a DTE constant. */
    dtype_p->dev.hcoll_datatype = mpi_predefined_derived_2_hcoll(dtype_p->handle);
    if (!HCOL_DTE_IS_ZERO(dtype_p->dev.hcoll_datatype)) {
        return MPI_SUCCESS;
    }

    /* Otherwise ask hcoll to build a representation from the handle. */
    dtype_p->dev.hcoll_datatype = DTE_ZERO;
    status = hcoll_create_mpi_type((void *) (intptr_t) dtype_p->handle,
                                   &dtype_p->dev.hcoll_datatype);
    if (status != HCOLL_SUCCESS) {
        return MPI_ERR_OTHER;
    }

    /* A zero representation after a successful create takes an extra
     * reference on the datatype; the free hook drops it again. */
    if (HCOL_DTE_IS_ZERO(dtype_p->dev.hcoll_datatype)) {
        MPIR_Datatype_add_ref_if_not_builtin(dtype_p->handle);
    }

    return MPI_SUCCESS;
}
/*
 * Datatype free hook: tear down the hcoll representation of a datatype.
 *
 * dtype_p: datatype object being freed; dev.hcoll_datatype is reset to
 *          DTE_ZERO on success.
 *
 * Returns MPI_SUCCESS, or MPI_ERR_OTHER when hcoll_dt_destroy() fails.
 * Does nothing when hcoll is disabled.
 */
int hcoll_type_free_hook(MPIR_Datatype * dtype_p)
{
    int status;

    if (!hcoll_enable) {
        return MPI_SUCCESS;
    }

    /* Counterpart of the add_ref done in the commit hook for a zero
     * representation. */
    if (HCOL_DTE_IS_ZERO(dtype_p->dev.hcoll_datatype)) {
        MPIR_Datatype_release_if_not_builtin(dtype_p->handle);
    }

    status = hcoll_dt_destroy(dtype_p->dev.hcoll_datatype);
    if (status != HCOLL_SUCCESS) {
        return MPI_ERR_OTHER;
    }

    dtype_p->dev.hcoll_datatype = DTE_ZERO;
    return MPI_SUCCESS;
}