/**
 * Copyright (C) Mellanox Technologies Ltd. 2017-2019.  ALL RIGHTS RESERVED.
 * Copyright (C) NVIDIA Corporation. 2019.  ALL RIGHTS RESERVED.
 * See file LICENSE for terms.
 */

#include "sockcm_ep.h"

/* NOTE(review): this header list was reconstructed from the symbols used
 * directly in this file (async handlers, socket helpers, fcntl flags,
 * close(), EBUSY) - confirm against the build. */
#include <ucs/async/async.h>
#include <ucs/sys/sock.h>
#include <ucs/sys/sys.h>
#include <errno.h>
#include <fcntl.h>
#include <unistd.h>


/* On top of the generic UCT_CB_FLAGS_CHECK, reject callback flags that do
 * not include UCT_CB_FLAG_ASYNC. Returns from the enclosing function. */
#define UCT_SOCKCM_CB_FLAGS_CHECK(_flags) \
    do { \
        UCT_CB_FLAGS_CHECK(_flags); \
        if (!((_flags) & UCT_CB_FLAG_ASYNC)) { \
            return UCS_ERR_UNSUPPORTED; \
        } \
    } while (0)


/**
 * Allocate the endpoint's socket context and create a SOCK_STREAM socket
 * with the same address family as ep->remote_addr.
 *
 * @param [in] ep  Endpoint whose remote_addr is already filled in.
 *
 * @return UCS_OK on success, UCS_ERR_NO_MEMORY if the context cannot be
 *         allocated, or the status returned by ucs_socket_create().
 *         On failure ep->sock_id_ctx is NULL.
 */
ucs_status_t uct_sockcm_ep_set_sock_id(uct_sockcm_ep_t *ep)
{
    ucs_status_t status;
    struct sockaddr *dest_addr = NULL;

    ep->sock_id_ctx = ucs_malloc(sizeof(*ep->sock_id_ctx), "client sock_id_ctx");
    if (ep->sock_id_ctx == NULL) {
        return UCS_ERR_NO_MEMORY;
    }

    dest_addr = (struct sockaddr *) &(ep->remote_addr);

    status = ucs_socket_create(dest_addr->sa_family, SOCK_STREAM,
                               &ep->sock_id_ctx->sock_fd);
    if (status != UCS_OK) {
        ucs_debug("unable to create client socket for sockcm");
        ucs_free(ep->sock_id_ctx);
        /* do not leave a dangling pointer in the endpoint */
        ep->sock_id_ctx = NULL;
        return status;
    }

    return UCS_OK;
}

/**
 * Close the socket held by @a sock_id_ctx and release the context itself.
 */
void uct_sockcm_ep_put_sock_id(uct_sockcm_ctx_t *sock_id_ctx)
{
    close(sock_id_ctx->sock_fd);
    ucs_free(sock_id_ctx);
}

/**
 * Send the client's connection parameters (private data produced by the
 * user's pack callback) over the connected socket.
 *
 * @return UCS_OK on success; the status of ucs_sockaddr_get_ifname() or
 *         ucs_socket_send() on failure; UCS_ERR_IO_ERROR if the pack
 *         callback reports an error (negative length).
 */
ucs_status_t uct_sockcm_ep_send_client_info(uct_sockcm_ep_t *ep)
{
    uct_sockcm_iface_t *iface = ucs_derived_of(ep->super.super.iface,
                                               uct_sockcm_iface_t);
    ucs_status_t status;
    uct_sockcm_conn_param_t conn_param;
    char dev_name[UCT_DEVICE_NAME_MAX];

    memset(&conn_param, 0, sizeof(uct_sockcm_conn_param_t));

    /* get interface name associated with the connected client fd; use that
     * for pack_cb */
    status = ucs_sockaddr_get_ifname(ep->sock_id_ctx->sock_fd, dev_name,
                                     UCT_DEVICE_NAME_MAX);
    if (UCS_OK != status) {
        goto out;
    }

    /* pack_cb is left NULL by the constructor when the user did not pass
     * UCT_EP_PARAM_FIELD_SOCKADDR_PACK_CB; send an empty private data
     * section (length 0 from the memset above) instead of calling NULL */
    if (ep->pack_cb != NULL) {
        conn_param.length = ep->pack_cb(ep->pack_cb_arg, dev_name,
                                        (void*)conn_param.private_data);
        if (conn_param.length < 0) {
            ucs_error("sockcm client (iface=%p, ep = %p) failed to fill "
                      "private data. status: %s",
                      iface, ep,
                      ucs_status_string((ucs_status_t)conn_param.length));
            status = UCS_ERR_IO_ERROR;
            goto out;
        }
    }

    ucs_assert(conn_param.length <= UCT_SOCKCM_PRIV_DATA_LEN);

    status = ucs_socket_send(ep->sock_id_ctx->sock_fd, &conn_param,
                             sizeof(uct_sockcm_conn_param_t), NULL, NULL);

out:
    return status;
}

/* Map a connection state to its printable name; aborts on unknown values. */
static const char*
uct_sockcm_ep_conn_state_str(uct_sockcm_ep_conn_state_t state)
{
    switch (state) {
    case UCT_SOCKCM_EP_CONN_STATE_SOCK_CONNECTING:
        return "UCT_SOCKCM_EP_CONN_STATE_SOCK_CONNECTING";
    case UCT_SOCKCM_EP_CONN_STATE_INFO_SENT:
        return "UCT_SOCKCM_EP_CONN_STATE_INFO_SENT";
    case UCT_SOCKCM_EP_CONN_STATE_CLOSED:
        return "UCT_SOCKCM_EP_CONN_STATE_CLOSED";
    case UCT_SOCKCM_EP_CONN_STATE_CONNECTED:
        return "UCT_SOCKCM_EP_CONN_STATE_CONNECTED";
    default:
        ucs_fatal("invalid sockcm endpoint state %d", state);
    }
}

/**
 * Move @a ep to @a conn_state with @a status, under ep->ops_mutex.
 *
 * A transition is ignored once the endpoint is already CLOSED with a non-OK
 * status, so a failure is not handled twice. Entering CLOSED reports the
 * failure via uct_sockcm_ep_set_failed(). Pending operations are then
 * completed with @a status.
 */
static void uct_sockcm_change_state(uct_sockcm_ep_t *ep,
                                    uct_sockcm_ep_conn_state_t conn_state,
                                    ucs_status_t status)
{
    uct_sockcm_iface_t *iface = ucs_derived_of(ep->super.super.iface,
                                               uct_sockcm_iface_t);

    pthread_mutex_lock(&ep->ops_mutex);
    ucs_debug("changing ep with status %s from state %s to state %s, status %s",
              ucs_status_string(ep->status),
              uct_sockcm_ep_conn_state_str(ep->conn_state),
              uct_sockcm_ep_conn_state_str(conn_state),
              ucs_status_string(status));
    if ((ep->status != UCS_OK) &&
        (ep->conn_state == UCT_SOCKCM_EP_CONN_STATE_CLOSED)) {
        /* Do not handle failure twice for closed EP */
        pthread_mutex_unlock(&ep->ops_mutex);
        return;
    }
    ep->status     = status;
    ep->conn_state = conn_state;
    if (conn_state == UCT_SOCKCM_EP_CONN_STATE_CLOSED) {
        uct_sockcm_ep_set_failed(&iface->super.super, &ep->super.super, status);
    }
    uct_sockcm_ep_invoke_completions(ep, status);
    pthread_mutex_unlock(&ep->ops_mutex);
}

/**
 * Async handler step for SOCK_CONNECTING: verify the non-blocking connect
 * finished, send the client info, and re-arm the fd for EVREAD to wait for
 * the server's reply. On any failure, close the endpoint and disable events
 * on the fd.
 */
static void uct_sockcm_handle_sock_connect(uct_sockcm_ep_t *ep)
{
    char sockaddr_str[UCS_SOCKADDR_STRING_LEN];
    int fd = ep->sock_id_ctx->sock_fd;
    ucs_status_t status;

    if (!ucs_socket_is_connected(fd)) {
        ucs_error("failed to connect to %s",
                  ucs_sockaddr_str((struct sockaddr*)&ep->remote_addr,
                                   sockaddr_str, sizeof(sockaddr_str)));
        uct_sockcm_change_state(ep, UCT_SOCKCM_EP_CONN_STATE_CLOSED,
                                UCS_ERR_UNREACHABLE);
        goto err;
    }

    status = uct_sockcm_ep_send_client_info(ep);
    if (status != UCS_OK) {
        ucs_error("failed to send client info: %s", ucs_status_string(status));
        uct_sockcm_change_state(ep, UCT_SOCKCM_EP_CONN_STATE_CLOSED, status);
        goto err;
    }

    ep->conn_state = UCT_SOCKCM_EP_CONN_STATE_INFO_SENT;

    /* Call current handler when server responds to sent message */
    if (UCS_OK != ucs_async_modify_handler(fd, UCS_EVENT_SET_EVREAD)) {
        ucs_error("failed to modify async handler for fd %d", fd);
        uct_sockcm_change_state(ep, UCT_SOCKCM_EP_CONN_STATE_CLOSED,
                                UCS_ERR_IO_ERROR);
        goto err;
    }

    return;

err:
    status = ucs_async_modify_handler(fd, 0);
    if (status != UCS_OK) {
        ucs_debug("unable to modify handler");
    }
}

/**
 * Async handler step for INFO_SENT: read the server's one-byte notification.
 * Returns early (waiting for the next event) when no data is ready yet;
 * otherwise removes the fd handler and moves to CONNECTED on accept, or to
 * CLOSED on reject or receive error.
 */
static void uct_sockcm_handle_info_sent(uct_sockcm_ep_t *ep)
{
    ucs_status_t status;
    size_t recv_len;
    char notif_val;

    recv_len = sizeof(notif_val);
    status   = ucs_socket_recv_nb(ep->sock_id_ctx->sock_fd, &notif_val,
                                  &recv_len, NULL, NULL);
    if (UCS_ERR_NO_PROGRESS == status) {
        /* will call recv again when ready */
        return;
    }

    ucs_async_remove_handler(ep->sock_id_ctx->sock_fd, 0);

    if (UCS_OK != status) {
        /* receive notif failed, close the connection */
        uct_sockcm_change_state(ep, UCT_SOCKCM_EP_CONN_STATE_CLOSED, status);
        return;
    }

    if (notif_val == UCT_SOCKCM_IFACE_NOTIFY_ACCEPT) {
        ucs_debug("event_handler OK after accept");
        uct_sockcm_change_state(ep, UCT_SOCKCM_EP_CONN_STATE_CONNECTED, UCS_OK);
    } else {
        ucs_debug("event_handler REJECTED after reject");
        uct_sockcm_change_state(ep, UCT_SOCKCM_EP_CONN_STATE_CLOSED,
                                UCS_ERR_REJECTED);
    }
}

/* Async fd callback: dispatch on the endpoint's current connection state. */
static void uct_sockcm_ep_event_handler(int fd, void *arg)
{
    uct_sockcm_ep_t *ep = (uct_sockcm_ep_t *) arg;

    switch (ep->conn_state) {
    case UCT_SOCKCM_EP_CONN_STATE_SOCK_CONNECTING:
        uct_sockcm_handle_sock_connect(ep);
        break;
    case UCT_SOCKCM_EP_CONN_STATE_INFO_SENT:
        uct_sockcm_handle_info_sent(ep);
        break;
    case UCT_SOCKCM_EP_CONN_STATE_CONNECTED:
        if (UCS_OK != ucs_async_modify_handler(fd, 0)) {
            ucs_warn("unable to turn off event notifications on %d", fd);
        }
        uct_sockcm_change_state(ep, UCT_SOCKCM_EP_CONN_STATE_CONNECTED, UCS_OK);
        break;
    case UCT_SOCKCM_EP_CONN_STATE_CLOSED:
    default:
        ucs_debug("handling closed/default state, ep %p fd %d", ep, fd);
        uct_sockcm_change_state(ep, UCT_SOCKCM_EP_CONN_STATE_CLOSED,
                                UCS_ERR_IO_ERROR);
        break;
    }
}

/**
 * Client-side endpoint constructor: copy the remote address, create a
 * non-blocking socket, start connect(), and register an EVWRITE async handler
 * that drives the connection state machine. Server ifaces are rejected with
 * UCS_ERR_UNSUPPORTED; a missing sockaddr yields UCS_ERR_INVALID_PARAM.
 */
static UCS_CLASS_INIT_FUNC(uct_sockcm_ep_t, const uct_ep_params_t *params)
{
    const ucs_sock_addr_t *sockaddr = params->sockaddr;
    uct_sockcm_iface_t *iface = NULL;
    struct sockaddr *param_sockaddr = NULL;
    char ip_port_str[UCS_SOCKADDR_STRING_LEN];
    ucs_status_t status;
    size_t sockaddr_len;

    iface = ucs_derived_of(params->iface, uct_sockcm_iface_t);
    UCS_CLASS_CALL_SUPER_INIT(uct_base_ep_t, &iface->super);

    if (iface->is_server) {
        return UCS_ERR_UNSUPPORTED;
    }

    if (!(params->field_mask & UCT_EP_PARAM_FIELD_SOCKADDR)) {
        return UCS_ERR_INVALID_PARAM;
    }

    UCT_SOCKCM_CB_FLAGS_CHECK((params->field_mask &
                               UCT_EP_PARAM_FIELD_SOCKADDR_CB_FLAGS) ?
                              params->sockaddr_cb_flags : 0);

    self->pack_cb       = (params->field_mask &
                           UCT_EP_PARAM_FIELD_SOCKADDR_PACK_CB) ?
                          params->sockaddr_pack_cb : NULL;
    self->pack_cb_arg   = (params->field_mask &
                           UCT_EP_PARAM_FIELD_USER_DATA) ?
                          params->user_data : NULL;
    self->pack_cb_flags = (params->field_mask &
                           UCT_EP_PARAM_FIELD_SOCKADDR_CB_FLAGS) ?
                          params->sockaddr_cb_flags : 0;
    pthread_mutex_init(&self->ops_mutex, NULL);
    ucs_queue_head_init(&self->ops);

    param_sockaddr = (struct sockaddr *) sockaddr->addr;
    if (UCS_OK != ucs_sockaddr_sizeof(param_sockaddr, &sockaddr_len)) {
        ucs_error("sockcm ep: unknown remote sa_family=%d",
                  sockaddr->addr->sa_family);
        status = UCS_ERR_IO_ERROR;
        goto err;
    }

    memcpy(&self->remote_addr, param_sockaddr, sockaddr_len);

    self->slow_prog_id = UCS_CALLBACKQ_ID_NULL;

    status = uct_sockcm_ep_set_sock_id(self);
    if (status != UCS_OK) {
        goto err;
    }

    status = ucs_sys_fcntl_modfl(self->sock_id_ctx->sock_fd, O_NONBLOCK, 0);
    if (status != UCS_OK) {
        goto sock_err;
    }

    status = ucs_socket_connect(self->sock_id_ctx->sock_fd, param_sockaddr);
    if (UCS_STATUS_IS_ERR(status)) {
        self->conn_state = UCT_SOCKCM_EP_CONN_STATE_CLOSED;
        goto sock_err;
    }

    self->conn_state = UCT_SOCKCM_EP_CONN_STATE_SOCK_CONNECTING;
    self->status     = UCS_INPROGRESS;

    /* set ep->status before event handler call to avoid simultaneous writes
     * to state */
    status = ucs_async_set_event_handler(iface->super.worker->async->mode,
                                         self->sock_id_ctx->sock_fd,
                                         UCS_EVENT_SET_EVWRITE,
                                         uct_sockcm_ep_event_handler,
                                         self, iface->super.worker->async);
    if (status != UCS_OK) {
        goto sock_err;
    }

    ucs_debug("created an SOCKCM endpoint on iface %p, "
              "remote addr: %s", iface,
              ucs_sockaddr_str(param_sockaddr, ip_port_str,
                               UCS_SOCKADDR_STRING_LEN));
    return UCS_OK;

sock_err:
    uct_sockcm_ep_put_sock_id(self->sock_id_ctx);
err:
    ucs_debug("error in sock connect");
    pthread_mutex_destroy(&self->ops_mutex);

    return status;
}

/**
 * Endpoint destructor: with the worker's async context blocked, remove the
 * fd handler, close the socket, unregister any pending error-handling
 * progress callback and destroy the ops mutex. Operations still queued are
 * only reported with a warning.
 */
static UCS_CLASS_CLEANUP_FUNC(uct_sockcm_ep_t)
{
    uct_sockcm_iface_t *iface = ucs_derived_of(self->super.super.iface,
                                               uct_sockcm_iface_t);

    ucs_debug("sockcm_ep %p: destroying", self);

    UCS_ASYNC_BLOCK(iface->super.worker->async);

    ucs_async_remove_handler(self->sock_id_ctx->sock_fd, 1);
    uct_sockcm_ep_put_sock_id(self->sock_id_ctx);

    uct_worker_progress_unregister_safe(&iface->super.worker->super,
                                        &self->slow_prog_id);

    pthread_mutex_destroy(&self->ops_mutex);
    if (!ucs_queue_is_empty(&self->ops)) {
        ucs_warn("destroying endpoint %p with not completed operations", self);
    }

    UCS_ASYNC_UNBLOCK(iface->super.worker->async);
}

UCS_CLASS_DEFINE(uct_sockcm_ep_t, uct_base_ep_t)
UCS_CLASS_DEFINE_NEW_FUNC(uct_sockcm_ep_t, uct_ep_t, const uct_ep_params_t *);
UCS_CLASS_DEFINE_DELETE_FUNC(uct_sockcm_ep_t, uct_ep_t);

/* One-shot progress callback: report the endpoint failure (with the status
 * saved in ep->status) from the worker's progress context. */
static unsigned uct_sockcm_client_err_handle_progress(void *arg)
{
    uct_sockcm_ep_t *sockcm_ep = arg;
    uct_sockcm_iface_t *iface = ucs_derived_of(sockcm_ep->super.super.iface,
                                               uct_sockcm_iface_t);

    ucs_trace_func("err_handle ep=%p", sockcm_ep);
    UCS_ASYNC_BLOCK(iface->super.worker->async);

    sockcm_ep->slow_prog_id = UCS_CALLBACKQ_ID_NULL;
    uct_set_ep_failed(&UCS_CLASS_NAME(uct_sockcm_ep_t), &sockcm_ep->super.super,
                      sockcm_ep->super.super.iface, sockcm_ep->status);

    UCS_ASYNC_UNBLOCK(iface->super.worker->async);
    return 0;
}

/**
 * Report an endpoint failure. If the iface error handler was registered with
 * UCT_CB_FLAG_ASYNC, uct_set_ep_failed() is invoked immediately; otherwise
 * the status is stored in the endpoint and a one-shot progress callback is
 * registered to report it later.
 */
void uct_sockcm_ep_set_failed(uct_iface_t *iface, uct_ep_h ep,
                              ucs_status_t status)
{
    uct_sockcm_iface_t *sockcm_iface = ucs_derived_of(iface, uct_sockcm_iface_t);
    uct_sockcm_ep_t *sockcm_ep       = ucs_derived_of(ep, uct_sockcm_ep_t);

    if (sockcm_iface->super.err_handler_flags & UCT_CB_FLAG_ASYNC) {
        uct_set_ep_failed(&UCS_CLASS_NAME(uct_sockcm_ep_t),
                          &sockcm_ep->super.super,
                          &sockcm_iface->super.super, status);
    } else {
        sockcm_ep->status = status;
        uct_worker_progress_register_safe(&sockcm_iface->super.worker->super,
                                          uct_sockcm_client_err_handle_progress,
                                          sockcm_ep, UCS_CALLBACKQ_FLAG_ONESHOT,
                                          &sockcm_ep->slow_prog_id);
    }
}

/**
 * Complete and free every queued operation of @a ep with @a status.
 * The caller must hold ep->ops_mutex (asserted in debug builds); the mutex is
 * released around each user completion callback and re-acquired afterwards.
 */
void uct_sockcm_ep_invoke_completions(uct_sockcm_ep_t *ep, ucs_status_t status)
{
    uct_sockcm_ep_op_t *op;

    ucs_assert(pthread_mutex_trylock(&ep->ops_mutex) == EBUSY);

    ucs_queue_for_each_extract(op, &ep->ops, queue_elem, 1) {
        pthread_mutex_unlock(&ep->ops_mutex);
        uct_invoke_completion(op->user_comp, status);
        ucs_free(op);
        pthread_mutex_lock(&ep->ops_mutex);
    }
}