/**
* Copyright (C) Mellanox Technologies Ltd. 2001-2019. ALL RIGHTS RESERVED.
*
* See file LICENSE for terms.
*/
#ifndef GTEST_MEM_BUFFER_H_
#define GTEST_MEM_BUFFER_H_
#include <ucs/memory/memory_type.h>
#include <stdint.h>
#include <string>
#include <vector>
/**
* Wrapper and utility functions for memory type buffers, i.e. buffers which are
* not necessarily allocated in host memory, such as CUDA, ROCm, etc.
*/
class mem_buffer {
public:
    /* return the list of memory types supported by this utility class */
    static std::vector<ucs_memory_type_t> supported_mem_types();

    /* allocate buffer of a given memory type */
    static void *allocate(size_t size, ucs_memory_type_t mem_type);

    /* release buffer of a given memory type (presumably one obtained from
     * allocate() with the same mem_type) */
    static void release(void *ptr, ucs_memory_type_t mem_type);

    /* fill pattern, generated from 'seed', in a host-accessible buffer */
    static void pattern_fill(void *buffer, size_t length, uint64_t seed);

    /* check pattern, generated from 'seed', in a host-accessible buffer */
    static void pattern_check(const void *buffer, size_t length, uint64_t seed);

    /* check pattern in a host-accessible buffer, take seed from 1st word */
    static void pattern_check(const void *buffer, size_t length);

    /* fill pattern in a memtype buffer */
    static void pattern_fill(void *buffer, size_t length, uint64_t seed,
                             ucs_memory_type_t mem_type);

    /* check pattern in a memtype buffer */
    static void pattern_check(const void *buffer, size_t length, uint64_t seed,
                              ucs_memory_type_t mem_type);

    /* copy from host memory to memtype buffer */
    static void copy_to(void *dst, const void *src, size_t length,
                        ucs_memory_type_t dst_mem_type);

    /* copy from memtype buffer to host memory */
    static void copy_from(void *dst, const void *src, size_t length,
                          ucs_memory_type_t src_mem_type);

    /* compare memtype buffer with host memory, return true if equal */
    static bool compare(const void *expected, const void *buffer,
                        size_t length, ucs_memory_type_t mem_type);

    /* return the string name of a memory type */
    static std::string mem_type_name(ucs_memory_type_t mem_type);

    /* create an object holding a buffer of 'size' of memory type 'mem_type'
     * (NOTE(review): the definition is elsewhere; presumably it allocates via
     * allocate() and the destructor releases via release() - confirm) */
    mem_buffer(size_t size, ucs_memory_type_t mem_type);
    virtual ~mem_buffer();

    /* memory type of the held buffer */
    ucs_memory_type_t mem_type() const;

    /* pointer to the held buffer; not necessarily host-accessible */
    void *ptr() const;

    /* size of the held buffer */
    size_t size() const;

private:
    /* Non-copyable. The object holds a raw buffer pointer which the destructor
     * presumably releases, so the implicitly generated copy constructor would
     * let two objects release the same buffer (double free). Declared but
     * intentionally left undefined - the C++98-compatible equivalent of
     * "= delete". (Copy assignment is already implicitly deleted because of
     * the const members; it is declared here to make the intent explicit.) */
    mem_buffer(const mem_buffer &);
    mem_buffer &operator=(const mem_buffer &);

    /* NOTE(review): judging by the name, reports an unsupported memory type
     * and aborts - confirm against the definition */
    static void abort_wrong_mem_type(ucs_memory_type_t mem_type);

    /* presumably computes the next pattern word from the previous one */
    static uint64_t pat(uint64_t prev);

    const ucs_memory_type_t m_mem_type;
    void * const            m_ptr;
    const size_t            m_size;
};
#endif