Blob Blame History Raw
/*
 * Copyright (C) Advanced Micro Devices, Inc. 2019. ALL RIGHTS RESERVED.
 * See file LICENSE for terms.
 */

#include "rocm_ipc_ep.h"
#include "rocm_ipc_iface.h"
#include "rocm_ipc_md.h"

#include <uct/rocm/base/rocm_base.h>

/*
 * Endpoint constructor.
 *
 * Initializes the base endpoint on the interface taken from params->iface,
 * records the peer's process id (read from params->iface_addr, which is
 * interpreted as a pid_t) and creates the per-endpoint remote memory-handle
 * cache, named "dest:<pid>".
 *
 * Returns UCS_OK on success, or the status reported by
 * uct_rocm_ipc_create_cache() on failure.
 */
static UCS_CLASS_INIT_FUNC(uct_rocm_ipc_ep_t, const uct_ep_params_t *params)
{
    uct_rocm_ipc_iface_t *iface = ucs_derived_of(params->iface, uct_rocm_ipc_iface_t);
    char target_name[64];
    ucs_status_t status;

    UCS_CLASS_CALL_SUPER_INIT(uct_base_ep_t, &iface->super);

    self->remote_pid = *(const pid_t*)params->iface_addr;

    /* Reuse the stored pid instead of re-reading the address through a cast
     * that drops const; the explicit int cast matches the %d conversion. */
    snprintf(target_name, sizeof(target_name), "dest:%d",
             (int)self->remote_pid);

    status = uct_rocm_ipc_create_cache(&self->remote_memh_cache, target_name);
    if (status != UCS_OK) {
        ucs_error("could not create rocm ipc cache: %s",
                  ucs_status_string(status));
        return status;
    }

    return UCS_OK;
}

/*
 * Endpoint destructor: releases the remote memory-handle cache that the
 * constructor created in self->remote_memh_cache.
 */
static UCS_CLASS_CLEANUP_FUNC(uct_rocm_ipc_ep_t)
{
    uct_rocm_ipc_destroy_cache(self->remote_memh_cache);
}

/*
 * UCS class boilerplate for uct_rocm_ipc_ep_t with uct_base_ep_t as its
 * parent. The NEW/DELETE macros presumably generate the create/destroy entry
 * points exposed as uct_ep_t - their expansion is not visible in this file.
 */
UCS_CLASS_DEFINE(uct_rocm_ipc_ep_t, uct_base_ep_t);
UCS_CLASS_DEFINE_NEW_FUNC(uct_rocm_ipc_ep_t, uct_ep_t, const uct_ep_params_t *);
UCS_CLASS_DEFINE_DELETE_FUNC(uct_rocm_ipc_ep_t, uct_ep_t);

/*
 * Data-path trace helper: forwards to ucs_trace_data() with the caller's
 * format and arguments, followed by " to <remote_addr>(<rkey>)". The remote
 * address is printed in hex (PRIx64) and the rkey with "%+ld".
 *
 * NOTE(review): callers in this file pass a uct_rkey_t that is also cast to a
 * uct_rocm_ipc_key_t pointer, so the rkey is presumably a pointer-sized value
 * being printed as a signed decimal - confirm this is the intended format.
 */
#define uct_rocm_ipc_trace_data(_remote_addr, _rkey, _fmt, ...) \
    ucs_trace_data(_fmt " to %"PRIx64"(%+ld)", ## __VA_ARGS__, (_remote_addr), \
                   (_rkey))

/**
 * Issue an asynchronous HSA copy between a local ROCm buffer and a buffer in
 * the remote process described by @a key.
 *
 * Only the first IOV entry is consumed: both the length and the local buffer
 * are taken from iov[0].
 *
 * @param tl_ep        Endpoint; supplies the remote memory-handle cache and,
 *                     through its interface, the signal pool and signal queue.
 * @param remote_addr  Address in the remote process. For a non-empty transfer
 *                     the range [remote_addr, remote_addr + length) must lie
 *                     inside [key->address, key->address + key->length).
 * @param iov          Local buffer descriptor (only iov[0] is read).
 * @param key          Remote memory key used to map the peer's memory.
 * @param comp         Completion stored in the signal descriptor that is
 *                     pushed on the interface signal queue.
 * @param is_put       Non-zero copies local -> remote, zero copies
 *                     remote -> local.
 *
 * @return UCS_OK for a zero-length transfer, UCS_INPROGRESS when the copy has
 *         been issued, UCS_ERR_INVALID_PARAM if the remote range is outside
 *         the key, UCS_ERR_INVALID_ADDR if the local buffer is not ROCm
 *         memory, UCS_ERR_NO_MEMORY if no signal descriptor is available,
 *         UCS_ERR_IO_ERROR if the HSA copy could not be issued, or the error
 *         returned by the memory-handle cache.
 */
ucs_status_t uct_rocm_ipc_ep_zcopy(uct_ep_h tl_ep,
                                   uint64_t remote_addr,
                                   const uct_iov_t *iov,
                                   uct_rocm_ipc_key_t *key,
                                   uct_completion_t *comp,
                                   int is_put)
{
    uct_rocm_ipc_ep_t *ep = ucs_derived_of(tl_ep, uct_rocm_ipc_ep_t);
    hsa_status_t status;
    hsa_agent_t local_agent;
    size_t size = uct_iov_get_length(iov);
    ucs_status_t ret;
    void *base_addr, *local_addr = iov->buffer;
    uct_rocm_ipc_iface_t *iface = ucs_derived_of(tl_ep->iface, uct_rocm_ipc_iface_t);
    void *remote_base_addr, *remote_copy_addr;
    void *dst_addr, *src_addr;
    uct_rocm_ipc_signal_desc_t *rocm_ipc_signal;
    uint64_t remote_offset;

    /* no data to deliver */
    if (size == 0) {
        return UCS_OK;
    }

    /* Range check written with subtractions only, so that neither
     * remote_addr + size nor key->address + key->length can wrap around and
     * let an out-of-range request through. */
    if ((remote_addr < key->address) ||
        ((remote_addr - key->address) > key->length) ||
        (size > (key->length - (remote_addr - key->address)))) {
        ucs_error("remote addr %lx/%lx out of range %lx/%lx",
                  remote_addr, size, key->address, key->length);
        return UCS_ERR_INVALID_PARAM;
    }

    remote_offset = remote_addr - key->address;

    status = uct_rocm_base_get_ptr_info(local_addr, size, &base_addr,
                                        NULL, &local_agent);
    if (status != HSA_STATUS_SUCCESS) {
        ucs_error("local addr %p/%lx is not ROCM memory", local_addr, size);
        return UCS_ERR_INVALID_ADDR;
    }

    ret = uct_rocm_ipc_cache_map_memhandle((void *)ep->remote_memh_cache, key,
                                           &remote_base_addr);
    if (ret != UCS_OK) {
        ucs_error("fail to attach ipc mem %p %d", (void *)key->address, ret);
        return ret;
    }

    /* Translate the remote address into the locally mapped view */
    remote_copy_addr = UCS_PTR_BYTE_OFFSET(remote_base_addr, remote_offset);
    if (is_put) {
        dst_addr = remote_copy_addr;
        src_addr = local_addr;
    } else {
        dst_addr = local_addr;
        src_addr = remote_copy_addr;
    }

    rocm_ipc_signal = ucs_mpool_get(&iface->signal_pool);
    if (rocm_ipc_signal == NULL) {
        /* The pool may be exhausted; do not dereference a NULL descriptor */
        ucs_error("failed to allocate rocm ipc signal descriptor");
        return UCS_ERR_NO_MEMORY;
    }

    /* Arm the signal with 1 before handing it to the copy */
    hsa_signal_store_screlease(rocm_ipc_signal->signal, 1);

    status = hsa_amd_memory_async_copy(dst_addr, local_agent,
                                       src_addr, local_agent,
                                       size, 0, NULL,
                                       rocm_ipc_signal->signal);
    if (status != HSA_STATUS_SUCCESS) {
        ucs_error("copy error");
        ucs_mpool_put(rocm_ipc_signal);
        return UCS_ERR_IO_ERROR;
    }

    /* Queue the descriptor on the interface signal queue */
    rocm_ipc_signal->comp        = comp;
    rocm_ipc_signal->mapped_addr = remote_base_addr;
    ucs_queue_push(&iface->signal_queue, &rocm_ipc_signal->queue);

    ucs_trace("rocm async copy issued :%p remote:%p, local:%p  len:%ld",
              rocm_ipc_signal, (void *)remote_addr, local_addr, size);

    return UCS_INPROGRESS;
}

/**
 * Zero-copy PUT: copy the local buffer described by @a iov to @a remote_addr
 * in the peer process.
 *
 * @a rkey is interpreted as a pointer to a uct_rocm_ipc_key_t. The transfer is
 * delegated to uct_rocm_ipc_ep_zcopy() with is_put = 1; the PUT/ZCOPY
 * statistic and the data trace are emitted with the total IOV length
 * regardless of the returned status.
 *
 * @return The status returned by uct_rocm_ipc_ep_zcopy().
 */
ucs_status_t uct_rocm_ipc_ep_put_zcopy(uct_ep_h tl_ep, const uct_iov_t *iov, size_t iovcnt,
                                       uint64_t remote_addr, uct_rkey_t rkey,
                                       uct_completion_t *comp)
{
    uct_rocm_ipc_key_t *rocm_key = (uct_rocm_ipc_key_t *)rkey;
    ucs_status_t status;

    /* Last argument selects the direction: 1 == local -> remote */
    status = uct_rocm_ipc_ep_zcopy(tl_ep, remote_addr, iov, rocm_key, comp, 1);

    UCT_TL_EP_STAT_OP(ucs_derived_of(tl_ep, uct_base_ep_t), PUT, ZCOPY,
                      uct_iov_total_length(iov, iovcnt));

    uct_rocm_ipc_trace_data(remote_addr, rkey, "PUT_ZCOPY [length %zu]",
                            uct_iov_total_length(iov, iovcnt));

    return status;
}

/**
 * Zero-copy GET: copy from @a remote_addr in the peer process into the local
 * buffer described by @a iov.
 *
 * @a rkey is interpreted as a pointer to a uct_rocm_ipc_key_t. The transfer is
 * delegated to uct_rocm_ipc_ep_zcopy() with is_put = 0; the GET/ZCOPY
 * statistic and the data trace are emitted with the total IOV length
 * regardless of the returned status.
 *
 * @return The status returned by uct_rocm_ipc_ep_zcopy().
 */
ucs_status_t uct_rocm_ipc_ep_get_zcopy(uct_ep_h tl_ep, const uct_iov_t *iov, size_t iovcnt,
                                       uint64_t remote_addr, uct_rkey_t rkey,
                                       uct_completion_t *comp)
{
    uct_rocm_ipc_key_t *rocm_key = (uct_rocm_ipc_key_t *)rkey;
    ucs_status_t status;

    /* Last argument selects the direction: 0 == remote -> local */
    status = uct_rocm_ipc_ep_zcopy(tl_ep, remote_addr, iov, rocm_key, comp, 0);

    UCT_TL_EP_STAT_OP(ucs_derived_of(tl_ep, uct_base_ep_t), GET, ZCOPY,
                      uct_iov_total_length(iov, iovcnt));

    uct_rocm_ipc_trace_data(remote_addr, rkey, "GET_ZCOPY [length %zu]",
                            uct_iov_total_length(iov, iovcnt));

    return status;
}