/*
* Copyright (C) Advanced Micro Devices, Inc. 2019. ALL RIGHTS RESERVED.
* See file LICENSE for terms.
*/
#include "rocm_ipc_md.h"
#include <uct/rocm/base/rocm_base.h>
/*
 * Configuration field table for the ROCm IPC memory domain.
 *
 * The single real entry has an empty name and embeds the generic
 * uct_md_config_table at the offset of the 'super' member of
 * uct_rocm_ipc_md_config_t. The trailing {NULL} entry is presumably the
 * end-of-table sentinel expected by the UCS config parser.
 */
static ucs_config_field_t uct_rocm_ipc_md_config_table[] = {
{"", "", NULL,
ucs_offsetof(uct_rocm_ipc_md_config_t, super),
UCS_CONFIG_TYPE_TABLE(uct_md_config_table)},
{NULL}
};
/*
 * Fill in the attributes of the ROCm IPC memory domain.
 *
 * All values are constants; the 'md' argument is not consulted.
 *
 * @param md       Memory domain being queried (unused).
 * @param md_attr  Attribute structure to populate.
 *
 * @return UCS_OK always.
 */
static ucs_status_t uct_rocm_ipc_md_query(uct_md_h md, uct_md_attr_t *md_attr)
{
    /* Set every bit of the local CPU mask */
    memset(&md_attr->local_cpus, 0xff, sizeof(md_attr->local_cpus));

    /* A packed remote key is a verbatim copy of the key structure */
    md_attr->rkey_packed_size     = sizeof(uct_rocm_ipc_key_t);

    /* Capabilities: registration of ROCm memory that requires an rkey;
     * allocation limit and detectable memory types are both zero */
    md_attr->cap.flags            = UCT_MD_FLAG_REG | UCT_MD_FLAG_NEED_RKEY;
    md_attr->cap.max_alloc        = 0;
    md_attr->cap.max_reg          = ULONG_MAX;
    md_attr->cap.detect_mem_types = 0;
    md_attr->cap.access_mem_type  = UCS_MEMORY_TYPE_ROCM;
    md_attr->cap.reg_mem_types    = UCS_BIT(UCS_MEMORY_TYPE_ROCM);

    /* TODO: get accurate number */
    md_attr->reg_cost.growth      = 0;
    md_attr->reg_cost.overhead    = 9e-9;

    return UCS_OK;
}
/*
 * Pack a memory handle into a remote-key buffer.
 *
 * The packed representation is a byte-for-byte copy of the
 * uct_rocm_ipc_key_t referenced by 'memh'.
 *
 * @param md           Memory domain (unused).
 * @param memh         Memory handle; interpreted as a uct_rocm_ipc_key_t.
 * @param rkey_buffer  Destination buffer; must have room for
 *                     sizeof(uct_rocm_ipc_key_t) bytes.
 *
 * @return UCS_OK always.
 */
static ucs_status_t uct_rocm_ipc_mkey_pack(uct_md_h md, uct_mem_h memh,
                                           void *rkey_buffer)
{
    const uct_rocm_ipc_key_t *key = (const uct_rocm_ipc_key_t *)memh;

    /* rkey_buffer is an opaque byte buffer whose alignment cannot be verified
     * here, so copy bytewise instead of storing through a struct pointer */
    memcpy(rkey_buffer, key, sizeof(*key));

    return UCS_OK;
}
/*
 * Build an IPC key for the ROCm allocation that contains a user range.
 *
 * Looks up the allocation containing [address, address + length) via
 * uct_rocm_base_get_ptr_info() and creates an HSA IPC handle for it. Note
 * that the resulting key describes the base pointer and size reported by
 * that lookup, not the 'address'/'length' passed in.
 *
 * @param address  Start of the user range.
 * @param length   Length of the user range, in bytes.
 * @param key      Filled with the IPC handle, base address, size and device
 *                 number on success. 'address', 'length' and 'dev_num' are
 *                 left untouched on failure.
 *
 * @return HSA_STATUS_SUCCESS on success, otherwise the HSA status returned by
 *         the failing call (an error is logged).
 */
static hsa_status_t uct_rocm_ipc_pack_key(void *address, size_t length,
                                          uct_rocm_ipc_key_t *key)
{
    hsa_status_t status;
    hsa_agent_t agent;
    void *base_ptr;
    size_t size;

    status = uct_rocm_base_get_ptr_info(address, length, &base_ptr, &size,
                                        &agent);
    if (status != HSA_STATUS_SUCCESS) {
        /* 'length' is a size_t, so it must be printed with the 'z' modifier */
        ucs_error("pack none ROCM ptr %p/%zx", address, length);
        return status;
    }

    status = hsa_amd_ipc_memory_create(base_ptr, size, &key->ipc);
    if (status != HSA_STATUS_SUCCESS) {
        ucs_error("Failed to create ipc for %p/%zx", address, length);
        return status;
    }

    key->address = (uintptr_t)base_ptr;
    key->length  = size;
    key->dev_num = uct_rocm_base_get_dev_num(agent);

    return HSA_STATUS_SUCCESS;
}
/*
 * Register a memory range with the ROCm IPC memory domain.
 *
 * Allocates a uct_rocm_ipc_key_t, fills it through uct_rocm_ipc_pack_key()
 * and hands it back as the memory handle.
 *
 * @param md       Memory domain (unused).
 * @param address  Start of the range to register.
 * @param length   Length of the range, in bytes.
 * @param flags    Registration flags (unused).
 * @param memh_p   Receives the new memory handle on success; not modified on
 *                 failure.
 *
 * @return UCS_OK on success, UCS_ERR_NO_MEMORY if the key could not be
 *         allocated, UCS_ERR_INVALID_ADDR if building the key failed.
 */
static ucs_status_t uct_rocm_ipc_mem_reg(uct_md_h md, void *address, size_t length,
                                         unsigned flags, uct_mem_h *memh_p)
{
    uct_rocm_ipc_key_t *ipc_key = ucs_malloc(sizeof(uct_rocm_ipc_key_t),
                                             "uct_rocm_ipc_key_t");

    if (ipc_key == NULL) {
        ucs_error("Failed to allocate memory for uct_rocm_ipc_key_t");
        return UCS_ERR_NO_MEMORY;
    }

    if (uct_rocm_ipc_pack_key(address, length, ipc_key) == HSA_STATUS_SUCCESS) {
        *memh_p = ipc_key;
        return UCS_OK;
    }

    /* Key construction failed: release the allocation before reporting */
    ucs_free(ipc_key);
    return UCS_ERR_INVALID_ADDR;
}
/*
 * Deregister memory from the ROCm IPC memory domain.
 *
 * Frees the memory handle passed in; no HSA call is made.
 *
 * @param md    Memory domain (unused).
 * @param memh  Memory handle to release.
 *
 * @return UCS_OK always.
 */
static ucs_status_t uct_rocm_ipc_mem_dereg(uct_md_h md, uct_mem_h memh)
{
    ucs_free(memh);
    return UCS_OK;
}
/*
 * Open the ROCm IPC memory domain.
 *
 * Nothing is allocated: every call returns the address of the same
 * function-local static uct_md_t, whose 'close' operation is an empty
 * function.
 *
 * @param component      Owning component (unused).
 * @param md_name        Memory domain name (unused).
 * @param uct_md_config  Memory domain configuration (unused).
 * @param md_p           Receives the memory domain handle.
 *
 * @return UCS_OK always.
 */
static ucs_status_t
uct_rocm_ipc_md_open(uct_component_h component, const char *md_name,
                     const uct_md_config_t *uct_md_config, uct_md_h *md_p)
{
    static uct_md_ops_t rocm_ipc_md_ops = {
        .query              = uct_rocm_ipc_md_query,
        .mem_reg            = uct_rocm_ipc_mem_reg,
        .mem_dereg          = uct_rocm_ipc_mem_dereg,
        .mkey_pack          = uct_rocm_ipc_mkey_pack,
        .detect_memory_type = ucs_empty_function_return_unsupported,
        .close              = (uct_md_close_func_t)ucs_empty_function,
    };
    static uct_md_t rocm_ipc_md = {
        .component = &uct_rocm_ipc_component,
        .ops       = &rocm_ipc_md_ops,
    };

    *md_p = &rocm_ipc_md;
    return UCS_OK;
}
/*
 * Unpack a remote key from its packed representation.
 *
 * Allocates a uct_rocm_ipc_key_t and copies sizeof(uct_rocm_ipc_key_t) bytes
 * from 'rkey_buffer' into it. The allocation is returned through 'rkey_p'
 * (as an integer) and is freed by uct_rocm_ipc_rkey_release().
 *
 * @param component    Owning component (unused).
 * @param rkey_buffer  Packed remote key; read-only.
 * @param rkey_p       Receives the unpacked key.
 * @param handle_p     Always set to NULL on success.
 *
 * @return UCS_OK on success, UCS_ERR_NO_MEMORY if allocation fails (the
 *         outputs are not modified in that case).
 */
static ucs_status_t uct_rocm_ipc_rkey_unpack(uct_component_t *component,
                                             const void *rkey_buffer,
                                             uct_rkey_t *rkey_p, void **handle_p)
{
    uct_rocm_ipc_key_t *key;

    key = ucs_malloc(sizeof(*key), "uct_rocm_ipc_key_t");
    if (NULL == key) {
        ucs_error("Failed to allocate memory for uct_rocm_ipc_key_t");
        return UCS_ERR_NO_MEMORY;
    }

    /* Copy bytewise straight from the const buffer: this keeps the const
     * qualifier intact and makes no assumption about the buffer's alignment */
    memcpy(key, rkey_buffer, sizeof(*key));

    *handle_p = NULL;
    *rkey_p   = (uintptr_t)key;
    return UCS_OK;
}
/*
 * Release a remote key created by uct_rocm_ipc_rkey_unpack().
 *
 * @param component  Owning component (unused).
 * @param rkey       Remote key; holds the address of a heap-allocated key,
 *                   which is freed here.
 * @param handle     Must be NULL (checked with ucs_assert).
 *
 * @return UCS_OK always.
 */
static ucs_status_t uct_rocm_ipc_rkey_release(uct_component_t *component,
                                              uct_rkey_t rkey, void *handle)
{
    void *unpacked_key = (void *)rkey;

    ucs_assert(handle == NULL);
    ucs_free(unpacked_key);
    return UCS_OK;
}
/*
 * Component descriptor for the "rocm_ipc" transport component.
 *
 * Wires the memory-domain open and remote-key unpack/release callbacks
 * defined in this file into the UCT component interface. Connection-manager
 * open and rkey_ptr are mapped to a stub that reports "unsupported", and
 * memory-domain resources are enumerated by the shared ROCm base helper.
 */
uct_component_t uct_rocm_ipc_component = {
.query_md_resources = uct_rocm_base_query_md_resources,
.md_open = uct_rocm_ipc_md_open,
.cm_open = ucs_empty_function_return_unsupported,
.rkey_unpack = uct_rocm_ipc_rkey_unpack,
.rkey_ptr = ucs_empty_function_return_unsupported,
.rkey_release = uct_rocm_ipc_rkey_release,
.name = "rocm_ipc",
/* Memory-domain configuration: table and size of uct_rocm_ipc_md_config_t */
.md_config = {
.name = "ROCm-IPC memory domain",
.prefix = "ROCM_IPC_MD_",
.table = uct_rocm_ipc_md_config_table,
.size = sizeof(uct_rocm_ipc_md_config_t),
},
.cm_config = UCS_CONFIG_EMPTY_GLOBAL_LIST_ENTRY,
.tl_list = UCT_COMPONENT_TL_LIST_INITIALIZER(&uct_rocm_ipc_component),
.flags = 0
};
/* Register the component with UCT (see UCT_COMPONENT_REGISTER for details) */
UCT_COMPONENT_REGISTER(&uct_rocm_ipc_component);