Blob Blame History Raw
/**
* Copyright (C) Mellanox Technologies Ltd. 2001-2017.  ALL RIGHTS RESERVED.
* Copyright (C) Advanced Micro Devices, Inc. 2019.  ALL RIGHTS RESERVED.
* See file LICENSE for terms.
*/
#include <ucm/api/ucm.h>
#include <common/test.h>
#include <hip_runtime.h>
#include <cstring>

/* Last UCM_EVENT_MEM_TYPE_ALLOC / UCM_EVENT_MEM_TYPE_FREE payloads seen by the
 * callbacks below (only the mem_type.address/size/mem_type fields are written).
 * They are file-scope so the free-standing callbacks can store into them and
 * the rocm_hooks fixture can inspect them afterwards. */
static ucm_event_t alloc_event, free_event;

/* UCM handler for UCM_EVENT_MEM_TYPE_ALLOC: remember the address, size and
 * memory type of the reported allocation in alloc_event so the test can check
 * them. event_type and arg are not used. */
static void rocm_mem_alloc_callback(ucm_event_type_t event_type,
                                    ucm_event_t *event, void *arg)
{
    const ucm_event_t *src = event;
    ucm_event_t *dst       = &alloc_event;

    dst->mem_type.mem_type = src->mem_type.mem_type;
    dst->mem_type.size     = src->mem_type.size;
    dst->mem_type.address  = src->mem_type.address;
}

/* UCM handler for UCM_EVENT_MEM_TYPE_FREE: remember the address, size and
 * memory type of the reported release in free_event so the test can check
 * them. event_type and arg are not used. */
static void rocm_mem_free_callback(ucm_event_type_t event_type,
                                   ucm_event_t *event, void *arg)
{
    const ucm_event_t *src = event;
    ucm_event_t *dst       = &free_event;

    dst->mem_type.mem_type = src->mem_type.mem_type;
    dst->mem_type.size     = src->mem_type.size;
    dst->mem_type.address  = src->mem_type.address;
}

class rocm_hooks : public ucs::test {
protected:

    /* Skip when no usable ROCm device is present; otherwise start from a clean
     * event state and register the UCM memory-type alloc/free callbacks. */
    virtual void init() {
        int dev_count = 0; /* well-defined even if hipGetDeviceCount() fails */
        ucs_status_t result;
        hipError_t ret;
        ucs::test::init();

        ret = hipGetDeviceCount(&dev_count);
        if ((ret != hipSuccess) || (dev_count < 1)) {
            UCS_TEST_SKIP_R("no ROCm device detected");
        }

        if (hipSetDevice(0) != hipSuccess) {
            UCS_TEST_SKIP_R("can't set ROCm device");
        }

        /* alloc_event/free_event are file-scope and outlive a single test;
         * clear them so a check cannot be satisfied by an event that was
         * recorded by an earlier test */
        memset(&alloc_event, 0, sizeof(alloc_event));
        memset(&free_event, 0, sizeof(free_event));

        /* install memory hooks */
        result = ucm_set_event_handler(UCM_EVENT_MEM_TYPE_ALLOC, 0,
                                       rocm_mem_alloc_callback,
                                       reinterpret_cast<void*>(this));
        ASSERT_UCS_OK(result);

        result = ucm_set_event_handler(UCM_EVENT_MEM_TYPE_FREE, 0,
                                       rocm_mem_free_callback,
                                       reinterpret_cast<void*>(this));
        ASSERT_UCS_OK(result);
    }

    /* Unregister the callbacks installed by init(), then run base cleanup. */
    virtual void cleanup() {
        ucm_unset_event_handler(UCM_EVENT_MEM_TYPE_ALLOC,
                                rocm_mem_alloc_callback,
                                reinterpret_cast<void*>(this));
        ucm_unset_event_handler(UCM_EVENT_MEM_TYPE_FREE,
                                rocm_mem_free_callback,
                                reinterpret_cast<void*>(this));
        ucs::test::cleanup();
    }

    /* Verify that the most recent alloc event reported exactly ptr, size and
     * expect_mem_type. On success the recorded event is cleared, so every
     * allocation must produce its own event to pass its own check. */
    void check_mem_alloc_events(void *ptr, size_t size,
                                int expect_mem_type = UCS_MEMORY_TYPE_ROCM)  {
        ASSERT_EQ(ptr, alloc_event.mem_type.address);
        ASSERT_EQ(size, alloc_event.mem_type.size);
        ASSERT_EQ(expect_mem_type, alloc_event.mem_type.mem_type);
        memset(&alloc_event, 0, sizeof(alloc_event));
    }

    /* Verify that the most recent free event reported ptr and expect_mem_type.
     * size is accepted for symmetry with check_mem_alloc_events but is NOT
     * compared; callers pass 0 for managed/pitched memory, which suggests the
     * size reported on free is not reliable (NOTE(review): confirm against
     * the UCM ROCm hook implementation). On success the recorded event is
     * cleared so a later check cannot match it again. */
    void check_mem_free_events(void *ptr, size_t size,
                               int expect_mem_type = UCS_MEMORY_TYPE_ROCM) {
        ASSERT_EQ(ptr, free_event.mem_type.address);
        ASSERT_EQ(expect_mem_type, free_event.mem_type.mem_type);
        memset(&free_event, 0, sizeof(free_event));
    }

};

UCS_TEST_F(rocm_hooks, test_hipMem_Alloc_Free) {
    /* a small and a large allocation, each released immediately */
    const size_t one_shot_sizes[] = {64, 256 * UCS_MBYTE};
    const size_t pair_size        = 1 * UCS_MBYTE;
    void *first, *second;
    hipError_t ret;

    for (size_t idx = 0;
         idx < sizeof(one_shot_sizes) / sizeof(one_shot_sizes[0]); ++idx) {
        const size_t len = one_shot_sizes[idx];

        ret = hipMalloc(&first, len);
        ASSERT_EQ(ret, hipSuccess);
        check_mem_alloc_events(first, len);

        ret = hipFree(first);
        ASSERT_EQ(ret, hipSuccess);
        check_mem_free_events(first, len);
    }

    /* two live allocations, released in the reverse order of allocation */
    ret = hipMalloc(&first, pair_size);
    ASSERT_EQ(ret, hipSuccess);
    check_mem_alloc_events(first, pair_size);

    ret = hipMalloc(&second, pair_size);
    ASSERT_EQ(ret, hipSuccess);
    check_mem_alloc_events(second, pair_size);

    ret = hipFree(second);
    ASSERT_EQ(ret, hipSuccess);
    check_mem_free_events(second, pair_size);

    ret = hipFree(first);
    ASSERT_EQ(ret, hipSuccess);
    check_mem_free_events(first, pair_size);
}

UCS_TEST_F(rocm_hooks, test_hipMallocManaged) {
    const size_t length = 64;
    void *managed;
    hipError_t ret;

    /* the alloc event must carry the ROCm-managed memory type */
    ret = hipMallocManaged(&managed, length);
    ASSERT_EQ(ret, hipSuccess);
    check_mem_alloc_events(managed, length, UCS_MEMORY_TYPE_ROCM_MANAGED);

    /* size is not compared by the free check, so 0 is passed */
    ret = hipFree(managed);
    ASSERT_EQ(ret, hipSuccess);
    check_mem_free_events(managed, 0, UCS_MEMORY_TYPE_ROCM_MANAGED);
}

UCS_TEST_F(rocm_hooks, test_hipMallocPitch) {
    const size_t width  = 4; /* bytes per row requested */
    const size_t height = 8; /* number of rows */
    hipError_t ret;
    void * dptr;
    size_t pitch;

    ret = hipMallocPitch(&dptr, &pitch, width, height);
    ASSERT_EQ(ret, hipSuccess);

    /* The row pitch is chosen by the runtime and depends on the device's
     * alignment requirements, so derive the expected allocation size from
     * the returned pitch instead of assuming a fixed 128-byte pitch. */
    ASSERT_GE(pitch, width);
    check_mem_alloc_events(dptr, pitch * height);

    ret = hipFree(dptr);
    ASSERT_EQ(ret, hipSuccess);
    check_mem_free_events(dptr, 0);
}

UCS_TEST_F(rocm_hooks, test_hip_Malloc_Free) {
    /* single allocations (small, then large), each freed right away */
    const size_t sizes[]  = {64, 256 * UCS_MBYTE};
    const size_t pair_len = 1 * UCS_MBYTE;
    void *mem, *mem_other;
    hipError_t ret;

    for (size_t k = 0; k < sizeof(sizes) / sizeof(sizes[0]); ++k) {
        ret = hipMalloc(&mem, sizes[k]);
        ASSERT_EQ(ret, hipSuccess);
        check_mem_alloc_events(mem, sizes[k]);

        ret = hipFree(mem);
        ASSERT_EQ(ret, hipSuccess);
        check_mem_free_events(mem, sizes[k]);
    }

    /* two outstanding allocations, hipFree'd in reverse order */
    ret = hipMalloc(&mem, pair_len);
    ASSERT_EQ(ret, hipSuccess);
    check_mem_alloc_events(mem, pair_len);

    ret = hipMalloc(&mem_other, pair_len);
    ASSERT_EQ(ret, hipSuccess);
    check_mem_alloc_events(mem_other, pair_len);

    ret = hipFree(mem_other);
    ASSERT_EQ(ret, hipSuccess);
    check_mem_free_events(mem_other, pair_len);

    ret = hipFree(mem);
    ASSERT_EQ(ret, hipSuccess);
    check_mem_free_events(mem, pair_len);

    /* freeing a NULL pointer must still succeed with the hooks installed */
    ret = hipFree(NULL);
    ASSERT_EQ(ret, hipSuccess);
}