Blob Blame History Raw
/**
* Copyright (C) Mellanox Technologies Ltd. 2019.  ALL RIGHTS RESERVED.
*
* See file LICENSE for terms.
*/

#include "tcp_sockcm_ep.h"
#include <ucs/sys/sock.h>
#include <ucs/async/async.h>
#include <ucs/arch/bitops.h>
#include <ucs/sys/string.h>


/**
 * Resolve the tcp sockcm connection manager that the given endpoint uses.
 *
 * The endpoint's base iface pointer is treated as the 'super.iface' member of
 * a uct_tcp_sockcm_t, and the enclosing object is returned.
 *
 * @param cep  Endpoint whose connection manager is requested.
 *
 * @return Pointer to the owning uct_tcp_sockcm_t.
 */
static UCS_F_ALWAYS_INLINE
uct_tcp_sockcm_t *uct_tcp_sockcm_ep_get_cm(uct_tcp_sockcm_ep_t *cep)
{
    uct_tcp_sockcm_t *tcp_sockcm;

    tcp_sockcm = ucs_container_of(cep->super.super.super.iface,
                                  uct_tcp_sockcm_t, super.iface);
    return tcp_sockcm;
}

/**
 * Disconnect entry point of the tcp sockcm endpoint.
 *
 * Not implemented: both arguments are ignored.
 *
 * @param ep     Endpoint to disconnect (unused).
 * @param flags  Disconnect flags (unused).
 *
 * @return Always UCS_ERR_NOT_IMPLEMENTED.
 */
ucs_status_t uct_tcp_sockcm_ep_disconnect(uct_ep_h ep, unsigned flags)
{
    const ucs_status_t status = UCS_ERR_NOT_IMPLEMENTED;

    return status;
}

/**
 * Reset the send/recv progress bookkeeping of the endpoint: both the expected
 * message length and the current offset go back to zero. The data buffer
 * itself is left untouched.
 *
 * @param cep  Endpoint whose communication context is reset.
 */
static void uct_tcp_sockcm_ep_init_comm_ctx(uct_tcp_sockcm_ep_t *cep)
{
    cep->comm_ctx.length = 0;
    cep->comm_ctx.offset = 0;
}

/**
 * React to the remote peer closing the connection.
 *
 * Resets the communication context and reports the failure to the user through
 * the connect callback that matches the endpoint's side (server or client).
 * On the client side the callback receives remote data with an empty
 * field mask.
 *
 * @param cep     Endpoint whose peer disconnected.
 * @param status  Error status to report; must not be UCS_OK.
 */
static void uct_tcp_sockcm_ep_handle_disconnect(uct_tcp_sockcm_ep_t *cep,
                                                ucs_status_t status)
{
    uct_cm_remote_data_t empty_remote_data;

    ucs_debug("ep %p (fd=%d): remote peer disconnected", cep, cep->fd);

    /* any partially sent/received message is discarded */
    uct_tcp_sockcm_ep_init_comm_ctx(cep);

    ucs_assert(status != UCS_OK);

    if (cep->state & UCT_TCP_SOCKCM_EP_ON_SERVER) {
        uct_cm_ep_server_connect_cb(&cep->super, status);
        return;
    }

    /* not a server endpoint, so it has to be a client one */
    ucs_assert(cep->state & UCT_TCP_SOCKCM_EP_ON_CLIENT);
    empty_remote_data.field_mask = 0;
    uct_cm_ep_client_connect_cb(&cep->super, &empty_remote_data, status);

    /* TODO handle disconnect if the ep already invoked the connect_cb */
}

/**
 * Check whether the message currently being sent or received is complete.
 *
 * @param cep  Endpoint to check; its expected message length must already be
 *             known (non-zero).
 *
 * @return Non-zero if the offset reached the expected length, 0 otherwise.
 */
static int uct_tcp_sockcm_ep_is_tx_rx_done(uct_tcp_sockcm_ep_t *cep)
{
    int is_done;

    ucs_assert((cep->comm_ctx.length != 0));

    is_done = (cep->comm_ctx.offset == cep->comm_ctx.length);
    return is_done;
}

/**
 * Make progress on sending the client's pending message (header + private
 * data) over the endpoint's socket, using a non-blocking send.
 *
 * Must be called on a connected client endpoint that still has bytes left to
 * send. Once the whole message has been sent, the endpoint is marked as
 * UCT_TCP_SOCKCM_EP_DATA_SENT, the communication context is reset and the
 * fd's event handler is switched to wait for read events (the peer's reply).
 *
 * @param cep  Client endpoint to progress.
 *
 * @return UCS_OK if the send made progress or would block;
 *         UCS_ERR_NOT_CONNECTED if the peer disconnected (the disconnect is
 *         reported through the connect callback);
 *         another error status if the send failed or if the event handler
 *         could not be switched to read events.
 */
ucs_status_t uct_tcp_sockcm_ep_progress_send(uct_tcp_sockcm_ep_t *cep)
{
    ucs_status_t status;
    size_t sent_length;

    ucs_assert(ucs_test_all_flags(cep->state, UCT_TCP_SOCKCM_EP_ON_CLIENT |
                                              UCT_TCP_SOCKCM_EP_CONNECTED));
    ucs_assert(cep->comm_ctx.offset < cep->comm_ctx.length);

    /* in: number of bytes left to send, out: number of bytes actually sent */
    sent_length = cep->comm_ctx.length - cep->comm_ctx.offset;

    status = ucs_socket_send_nb(cep->fd,
                                UCS_PTR_BYTE_OFFSET(cep->comm_ctx.buf,
                                                    cep->comm_ctx.offset),
                                &sent_length, NULL, NULL);
    if ((status != UCS_OK) && (status != UCS_ERR_NO_PROGRESS)) {
        if (status == UCS_ERR_NOT_CONNECTED) {
            uct_tcp_sockcm_ep_handle_disconnect(cep, status);
        } else {
            ucs_error("ep %p failed to send client's data (len=%zu offset=%zu)",
                      cep, cep->comm_ctx.length, cep->comm_ctx.offset);
        }
        return status;
    }

    cep->comm_ctx.offset += sent_length;
    ucs_assert(cep->comm_ctx.offset <= cep->comm_ctx.length);
    cep->state           |= UCT_TCP_SOCKCM_EP_SENDING;

    if (!uct_tcp_sockcm_ep_is_tx_rx_done(cep)) {
        /* more data is left to send, will continue on the next write event */
        return UCS_OK;
    }

    cep->state |= UCT_TCP_SOCKCM_EP_DATA_SENT;
    uct_tcp_sockcm_ep_init_comm_ctx(cep);

    /* wait for a reply from the peer. if the handler cannot be re-armed for
     * read events the reply would never be noticed, so the failure must reach
     * the caller instead of being masked by UCS_OK */
    status = ucs_async_modify_handler(cep->fd, UCS_EVENT_SET_EVREAD);
    if (status != UCS_OK) {
        ucs_error("failed to modify %d event handler to "
                  "UCS_EVENT_SET_EVREAD: %s", cep->fd,
                  ucs_status_string(status));
    }

    return status;
}

/**
 * Pack the user's private data into the endpoint's buffer (right after the
 * private data header) and start sending it to the peer.
 *
 * The user's pack callback is given the name of the local network interface
 * that is associated with the endpoint's fd.
 *
 * @param cep  Endpoint whose private data should be packed and sent.
 *
 * @return The status of uct_tcp_sockcm_ep_progress_send() on success;
 *         the error from resolving the interface name;
 *         the (negative) value returned by the pack callback, as a status;
 *         UCS_ERR_EXCEEDS_LIMIT if the callback reported more data than the
 *         connection manager's priv_data_len.
 */
ucs_status_t uct_tcp_sockcm_ep_send_priv_data(uct_tcp_sockcm_ep_t *cep)
{
    uct_tcp_sockcm_priv_data_hdr_t *hdr =
            (uct_tcp_sockcm_priv_data_hdr_t*)cep->comm_ctx.buf;
    char ifname_str[UCT_DEVICE_NAME_MAX];
    ssize_t packed_length;
    ucs_status_t status;

    /* get interface name associated with the connected client fd */
    status = ucs_sockaddr_get_ifname(cep->fd, ifname_str, sizeof(ifname_str));
    if (UCS_OK != status) {
        return status;
    }

    /* the payload is packed right behind the header */
    packed_length = cep->super.priv_pack_cb(cep->super.user_data, ifname_str,
                                            hdr + 1);
    if (packed_length < 0) {
        /* a negative value is an error status reported by the callback */
        ucs_assert(packed_length > UCS_ERR_LAST);
        status = (ucs_status_t)packed_length;
        ucs_error("tcp_sockcm private data pack function failed with error: %s",
                  ucs_status_string(status));
        return status;
    }

    if (packed_length > (uct_tcp_sockcm_ep_get_cm(cep)->priv_data_len)) {
        ucs_error("tcp_sockcm private data pack function returned %zd "
                  "(max: %zu)", packed_length,
                  uct_tcp_sockcm_ep_get_cm(cep)->priv_data_len);
        return UCS_ERR_EXCEEDS_LIMIT;
    }

    /* the message on the wire is the header followed by the packed data */
    hdr->length          = packed_length;
    cep->comm_ctx.length = sizeof(*hdr) + hdr->length;

    return uct_tcp_sockcm_ep_progress_send(cep);
}

/**
 * Deliver a received connection request to the user on the server side.
 *
 * Builds the remote data (the peer's socket address as the device address and
 * the private data stored behind the header in the endpoint's buffer), removes
 * the endpoint from the list it is linked on and invokes the listener's
 * connection request callback with the endpoint as the conn_request handle.
 *
 * @param cep  Server-side endpoint holding the fully received message.
 *
 * @return UCS_OK if the callback was invoked, otherwise the error from
 *         resolving the local interface name or the peer's address.
 */
static ucs_status_t uct_tcp_sockcm_ep_server_invoke_conn_req_cb(uct_tcp_sockcm_ep_t *cep)
{
    struct sockaddr_storage        peer_addr = {0};
    socklen_t                      peer_addr_len;
    char                           peer_addr_str[UCS_SOCKADDR_STRING_LEN];
    char                           local_ifname[UCT_DEVICE_NAME_MAX];
    uct_tcp_sockcm_priv_data_hdr_t *priv_hdr;
    uct_cm_remote_data_t           remote_data;
    ucs_status_t                   status;

    /* get the local interface name associated with the connected fd */
    status = ucs_sockaddr_get_ifname(cep->fd, local_ifname,
                                     sizeof(local_ifname));
    if (UCS_OK != status) {
        return status;
    }

    /* get the device address of the remote peer associated with the connected fd */
    status = ucs_socket_getpeername(cep->fd, &peer_addr, &peer_addr_len);
    if (status != UCS_OK) {
        return status;
    }

    /* the received message is laid out as a header followed by the payload */
    priv_hdr = (uct_tcp_sockcm_priv_data_hdr_t *)cep->comm_ctx.buf;

    remote_data.field_mask            = UCT_CM_REMOTE_DATA_FIELD_DEV_ADDR        |
                                        UCT_CM_REMOTE_DATA_FIELD_DEV_ADDR_LENGTH |
                                        UCT_CM_REMOTE_DATA_FIELD_CONN_PRIV_DATA  |
                                        UCT_CM_REMOTE_DATA_FIELD_CONN_PRIV_DATA_LENGTH;
    remote_data.dev_addr              = (uct_device_addr_t *)&peer_addr;
    remote_data.dev_addr_length       = peer_addr_len;
    remote_data.conn_priv_data        = priv_hdr + 1;
    remote_data.conn_priv_data_length = priv_hdr->length;

    ucs_debug("fd %d: remote_data: (field_mask=%zu) dev_addr: %s (length=%zu), "
              "conn_priv_data_length=%zu", cep->fd, remote_data.field_mask,
              ucs_sockaddr_str((const struct sockaddr*)remote_data.dev_addr,
                               peer_addr_str, UCS_SOCKADDR_STRING_LEN),
              remote_data.dev_addr_length, remote_data.conn_priv_data_length);

    /* from this point the endpoint is handed to the user as the conn_request:
     * the user is expected to pass it to uct_ep_create(), so it becomes the
     * user's responsibility and is unlinked from the list here. */
    ucs_list_del(&cep->list);
    cep->listener->conn_request_cb(&cep->listener->super,
                                   cep->listener->user_data, local_ifname, cep,
                                   &remote_data);

    return UCS_OK;
}

/**
 * Handle the completion of receiving a whole message from the peer.
 *
 * Marks the endpoint as UCT_TCP_SOCKCM_EP_DATA_RECEIVED, resets the
 * communication context, passes the connection request to the user and then
 * switches the fd's event handler to wait for write events.
 *
 * @param cep  Endpoint that finished receiving.
 *
 * @return UCS_OK on success, otherwise the error from invoking the connection
 *         request callback or from modifying the event handler.
 */
ucs_status_t uct_tcp_sockcm_ep_handle_data_received(uct_tcp_sockcm_ep_t *cep)
{
    ucs_status_t status;

    cep->state |= UCT_TCP_SOCKCM_EP_DATA_RECEIVED;

    /* only offset/length are reset, the received bytes stay in the buffer */
    uct_tcp_sockcm_ep_init_comm_ctx(cep);

    status = uct_tcp_sockcm_ep_server_invoke_conn_req_cb(cep);
    if (status != UCS_OK) {
        return status;
    }

    status = ucs_async_modify_handler(cep->fd, UCS_EVENT_SET_EVWRITE);
    if (status != UCS_OK) {
        ucs_error("failed to modify %d event handler to UCS_EVENT_SET_EVWRITE: %s",
                  cep->fd, ucs_status_string(status));
    }

    return status;
}

/**
 * Perform one non-blocking receive into the endpoint's buffer, continuing from
 * the current offset, and advance the offset by the amount received.
 *
 * At most the remaining capacity of the buffer (header size plus the
 * connection manager's priv_data_len, minus the current offset) is requested.
 *
 * @param cep  Endpoint to receive on.
 *
 * @return UCS_OK if data was received or the receive would block;
 *         UCS_ERR_NOT_CONNECTED if the peer disconnected (the disconnect is
 *         reported through the connect callback);
 *         another error status if the receive failed.
 */
static ucs_status_t uct_tcp_sockcm_ep_recv_nb(uct_tcp_sockcm_ep_t *cep)
{
    void *recv_ptr = UCS_PTR_BYTE_OFFSET(cep->comm_ctx.buf,
                                         cep->comm_ctx.offset);
    ucs_status_t status;
    size_t recv_length;

    /* in: free space left in the buffer, out: number of bytes received */
    recv_length = uct_tcp_sockcm_ep_get_cm(cep)->priv_data_len +
                  sizeof(uct_tcp_sockcm_priv_data_hdr_t) - cep->comm_ctx.offset;

    status = ucs_socket_recv_nb(cep->fd, recv_ptr, &recv_length, NULL, NULL);
    if ((status == UCS_OK) || (status == UCS_ERR_NO_PROGRESS)) {
        cep->comm_ctx.offset += recv_length;

        /* the expected length is known only after the header was parsed */
        ucs_assertv((cep->comm_ctx.length ?
                     cep->comm_ctx.offset <= cep->comm_ctx.length : 1), "%zu > %zu",
                    cep->comm_ctx.offset, cep->comm_ctx.length);
        return UCS_OK;
    }

    if (status == UCS_ERR_NOT_CONNECTED) {
        uct_tcp_sockcm_ep_handle_disconnect(cep, status);
    } else {
        ucs_error("ep %p (fd=%d) failed to recv client's data (offset=%zu)",
                  cep, cep->fd, cep->comm_ctx.offset);
    }

    return status;
}

/**
 * Continue receiving a message whose expected length is already known, and
 * handle its completion once all of it has arrived.
 *
 * @param cep  Endpoint to progress.
 *
 * @return The error from the non-blocking receive, if it failed; the status
 *         of uct_tcp_sockcm_ep_handle_data_received() if the message was
 *         completed; UCS_OK otherwise.
 */
ucs_status_t uct_tcp_sockcm_ep_progress_recv(uct_tcp_sockcm_ep_t *cep)
{
    ucs_status_t status = uct_tcp_sockcm_ep_recv_nb(cep);

    if ((status == UCS_OK) && uct_tcp_sockcm_ep_is_tx_rx_done(cep)) {
        status = uct_tcp_sockcm_ep_handle_data_received(cep);
    }

    return status;
}

/**
 * Receive the beginning of a message from the peer.
 *
 * Until a complete header is available nothing else is done. Once the header
 * is in the buffer, the expected message length is derived from it, the
 * endpoint is marked as UCT_TCP_SOCKCM_EP_RECEIVING and, if the whole message
 * has already arrived, its completion is handled.
 *
 * @param cep  Endpoint to receive on.
 *
 * @return The error from the non-blocking receive, if it failed; the status
 *         of uct_tcp_sockcm_ep_handle_data_received() if the message was
 *         completed; UCS_OK otherwise.
 */
ucs_status_t uct_tcp_sockcm_ep_recv(uct_tcp_sockcm_ep_t *cep)
{
    uct_tcp_sockcm_priv_data_hdr_t *hdr;
    ucs_status_t status;

    status = uct_tcp_sockcm_ep_recv_nb(cep);
    if ((status != UCS_OK) || (cep->comm_ctx.offset < sizeof(*hdr))) {
        /* either a failure, or the header has not fully arrived yet */
        return status;
    }

    /* the header tells how much private data follows it */
    hdr                  = (uct_tcp_sockcm_priv_data_hdr_t *)cep->comm_ctx.buf;
    cep->comm_ctx.length = sizeof(*hdr) + hdr->length;
    ucs_assertv(cep->comm_ctx.offset <= cep->comm_ctx.length , "%zu > %zu",
                cep->comm_ctx.offset, cep->comm_ctx.length);

    cep->state          |= UCT_TCP_SOCKCM_EP_RECEIVING;

    if (!uct_tcp_sockcm_ep_is_tx_rx_done(cep)) {
        return status;
    }

    return uct_tcp_sockcm_ep_handle_data_received(cep);
}

/**
 * Server-side part of the endpoint initialization: stores the user's server
 * connect callback and flags the endpoint as a server-side one.
 *
 * @param cep     Endpoint being initialized.
 * @param params  Endpoint creation parameters supplying the server callback.
 *
 * @return Always UCS_OK.
 */
static ucs_status_t uct_tcp_sockcm_ep_server_init(uct_tcp_sockcm_ep_t *cep,
                                                  const uct_ep_params_t *params)
{
    cep->super.server.connect_cb = params->sockaddr_connect_cb.server;
    cep->state                  |= UCT_TCP_SOCKCM_EP_ON_SERVER;

    return UCS_OK;
}

/**
 * Client-side part of the endpoint initialization.
 *
 * Flags the endpoint as a client-side one, stores the user's client connect
 * callback, creates a non-blocking stream socket, starts connecting it to the
 * server address given in the parameters and registers the fd for write
 * events on the worker's async context.
 *
 * @param cep     Endpoint being initialized.
 * @param params  Endpoint creation parameters supplying the server address and
 *                the client callback.
 *
 * @return UCS_OK on success. On failure an error status is returned and the
 *         socket, if it was created, is closed.
 */
static ucs_status_t uct_tcp_sockcm_ep_client_init(uct_tcp_sockcm_ep_t *cep,
                                                  const uct_ep_params_t *params)
{
    uct_tcp_sockcm_t *tcp_sockcm       = uct_tcp_sockcm_ep_get_cm(cep);
    const struct sockaddr *server_addr = params->sockaddr->addr;
    char server_addr_str[UCS_SOCKADDR_STRING_LEN];
    ucs_async_context_t *async;
    ucs_status_t status;

    cep->state                  |= UCT_TCP_SOCKCM_EP_ON_CLIENT;
    cep->super.client.connect_cb = params->sockaddr_connect_cb.client;

    status = ucs_socket_create(server_addr->sa_family, SOCK_STREAM, &cep->fd);
    if (status != UCS_OK) {
        /* no socket was created, nothing to release */
        return status;
    }

    /* Set the fd to non-blocking mode. (so that connect() won't be blocking) */
    status = ucs_sys_fcntl_modfl(cep->fd, O_NONBLOCK, 0);
    if (status != UCS_OK) {
        status = UCS_ERR_IO_ERROR;
        goto err_close_socket;
    }

    /* try to connect to the server; the connection may still be in progress
     * when this call returns */
    status = ucs_socket_connect(cep->fd, server_addr);
    if (UCS_STATUS_IS_ERR(status)) {
        goto err_close_socket;
    }
    ucs_assert((status == UCS_OK) || (status == UCS_INPROGRESS));

    /* get notified when the socket becomes writable */
    async  = tcp_sockcm->super.iface.worker->async;
    status = ucs_async_set_event_handler(async->mode, cep->fd,
                                         UCS_EVENT_SET_EVWRITE,
                                         uct_tcp_sa_data_handler, cep, async);
    if (status != UCS_OK) {
        goto err_close_socket;
    }

    ucs_debug("created a TCP SOCKCM endpoint (fd=%d) on tcp cm %p, "
              "remote addr: %s", cep->fd, tcp_sockcm,
              ucs_sockaddr_str(server_addr, server_addr_str,
                               UCS_SOCKADDR_STRING_LEN));

    return status;

err_close_socket:
    close(cep->fd);
    return status;
}

/**
 * Constructor of the tcp sockcm endpoint.
 *
 * Initializes the base connection-manager endpoint, resets the endpoint's
 * state and communication context, allocates the send/recv buffer (private
 * data header followed by up to priv_data_len bytes of private data) and then
 * runs the client-side or the server-side initialization, depending on
 * whether a remote sockaddr was provided in the parameters.
 *
 * @param params  Endpoint creation parameters.
 *
 * @return UCS_OK on success, UCS_ERR_NO_MEMORY if the buffer could not be
 *         allocated, or the error returned by the side-specific
 *         initialization. On failure the buffer is released.
 */
UCS_CLASS_INIT_FUNC(uct_tcp_sockcm_ep_t, const uct_ep_params_t *params)
{
    ucs_status_t status;

    UCS_CLASS_CALL_SUPER_INIT(uct_cm_base_ep_t, params);

    uct_tcp_sockcm_ep_init_comm_ctx(self);
    self->state        = 0;
    self->comm_ctx.buf = ucs_malloc(uct_tcp_sockcm_ep_get_cm(self)->priv_data_len +
                                    sizeof(uct_tcp_sockcm_priv_data_hdr_t),
                                    "tcp_sockcm priv data");
    if (self->comm_ctx.buf == NULL) {
        ucs_error("failed to allocate memory for the ep's send/recv buf");
        return UCS_ERR_NO_MEMORY;
    }

    /* a remote sockaddr is provided only when creating a client endpoint */
    if (params->field_mask & UCT_EP_PARAM_FIELD_SOCKADDR) {
        status = uct_tcp_sockcm_ep_client_init(self, params);
    } else {
        status = uct_tcp_sockcm_ep_server_init(self, params);
    }

    if (status != UCS_OK) {
        /* when the init function of a class fails, the class framework cleans
         * up only the already-initialized super classes and not this class,
         * so the buffer has to be released here to avoid leaking it.
         * the pointer is cleared to keep any later release harmless. */
        ucs_free(self->comm_ctx.buf);
        self->comm_ctx.buf = NULL;
        return status;
    }

    ucs_debug("created an endpoint on tcp_sockcm %p id: %d state: %d",
              uct_tcp_sockcm_ep_get_cm(self), self->fd, self->state);

    return UCS_OK;
}

/**
 * Create (client side) or hand over (server side) a tcp sockcm endpoint.
 *
 * With UCT_EP_PARAM_FIELD_SOCKADDR set, a new client endpoint is allocated and
 * initialized. Otherwise, with UCT_EP_PARAM_FIELD_CONN_REQUEST set, the
 * conn_request is the endpoint that the listener already created, and it is
 * returned as is.
 *
 * @param params  Endpoint creation parameters.
 * @param ep_p    Filled with the resulting endpoint on success.
 *
 * @return UCS_OK on success, the error from creating the client endpoint, or
 *         UCS_ERR_INVALID_PARAM if neither of the two fields was provided.
 */
ucs_status_t uct_tcp_sockcm_ep_create(const uct_ep_params_t *params, uct_ep_h *ep_p)
{
    uct_tcp_sockcm_ep_t *server_ep;

    if (params->field_mask & UCT_EP_PARAM_FIELD_SOCKADDR) {
        /* create a new endpoint for the client side */
        return UCS_CLASS_NEW(uct_tcp_sockcm_ep_t, ep_p, params);
    }

    if (!(params->field_mask & UCT_EP_PARAM_FIELD_CONN_REQUEST)) {
        ucs_error("either UCT_EP_PARAM_FIELD_SOCKADDR or UCT_EP_PARAM_FIELD_CONN_REQUEST "
                  "has to be provided");
        return UCS_ERR_INVALID_PARAM;
    }

    /* the server's endpoint was already created by the listener, return it */
    server_ep = (uct_tcp_sockcm_ep_t*)(params->conn_request);
    *ep_p     = &server_ep->super.super.super;
    return UCS_OK;
}

UCS_CLASS_CLEANUP_FUNC(uct_tcp_sockcm_ep_t)
{
    uct_tcp_sockcm_t *tcp_sockcm = uct_tcp_sockcm_ep_get_cm(self);

    UCS_ASYNC_BLOCK(tcp_sockcm->super.iface.worker->async);

    ucs_free(self->comm_ctx.buf);

    ucs_async_remove_handler(self->fd, 1);

    if (self->fd != -1) {
        close(self->fd);
    }
    UCS_ASYNC_UNBLOCK(tcp_sockcm->super.iface.worker->async);
}

/* The direct super class must be the one whose constructor is invoked through
 * UCS_CLASS_CALL_SUPER_INIT() in this class's init function, i.e.
 * uct_cm_base_ep_t. Otherwise the class chain used for cleanup does not match
 * the chain that was actually initialized. */
UCS_CLASS_DEFINE(uct_tcp_sockcm_ep_t, uct_cm_base_ep_t);
UCS_CLASS_DEFINE_NEW_FUNC(uct_tcp_sockcm_ep_t, uct_ep_t, const uct_ep_params_t *);
UCS_CLASS_DEFINE_DELETE_FUNC(uct_tcp_sockcm_ep_t, uct_ep_t);