Blob Blame History Raw
/*
 * Copyright (C) 2015 Red Hat, Inc.
 *
 * This file is part of the device-mapper multipath userspace tools.
 *
 * This program is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Lesser General Public License
 * as published by the Free Software Foundation; either version 2.1
 * of the License, or (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */

#include <stdlib.h>
#include <unistd.h>
#include <stdio.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <sys/un.h>
#include <poll.h>
#include <string.h>
#include <errno.h>
#include <fcntl.h>

#include "mpath_cmd.h"

/*
 * Keep reading until it's all read.
 *
 * Reads up to len bytes from the socket fd into buf, waiting for data
 * with poll() before every recv(). timeout is handed to poll() unchanged
 * (milliseconds) and applies to each wait separately, not to the whole
 * transfer.
 *
 * Returns the number of bytes stored in buf. This is less than len only
 * if recv() reported end-of-stream (returned 0) before len bytes arrived.
 * Returns -1 with errno set on failure; errno is ETIMEDOUT if poll()
 * timed out.
 */
static ssize_t read_all(int fd, void *buf, size_t len, unsigned int timeout)
{
	size_t total = 0;
	ssize_t n;
	int ret;
	struct pollfd pfd;

	while (len) {
		pfd.fd = fd;
		pfd.events = POLLIN;
		pfd.revents = 0;
		ret = poll(&pfd, 1, timeout);
		if (!ret) {
			errno = ETIMEDOUT;
			return -1;
		}
		if (ret < 0) {
			if (errno == EINTR)
				continue;
			return -1;
		}
		/*
		 * poll() reported an event. If it is not POLLIN it must be
		 * POLLERR, POLLHUP or POLLNVAL, which stay asserted: polling
		 * again would return immediately and spin forever. Fall
		 * through to recv() instead, so that it reports the actual
		 * error (or end-of-stream) to the caller.
		 */
		if (!(pfd.revents & (POLLIN | POLLERR | POLLHUP | POLLNVAL)))
			continue;
		n = recv(fd, buf, len, 0);
		if (n < 0) {
			if ((errno == EINTR) || (errno == EAGAIN))
				continue;
			return -1;
		}
		if (!n)
			return total;
		buf = n + (char *)buf;
		len -= n;
		total += n;
	}
	return total;
}

/*
 * Push len bytes from buf out through the socket fd, retrying send()
 * for as long as it makes progress or fails with EINTR/EAGAIN.
 *
 * MSG_NOSIGNAL is passed to every send() call.
 *
 * Returns the number of bytes actually sent. A value smaller than len
 * means send() either returned 0 or failed with some other error.
 */
static size_t write_all(int fd, const void *buf, size_t len)
{
	const char *cursor = buf;
	size_t remaining = len;

	while (remaining > 0) {
		ssize_t sent = send(fd, cursor, remaining, MSG_NOSIGNAL);

		if (sent > 0) {
			cursor += sent;
			remaining -= (size_t)sent;
			continue;
		}
		/* transient failure: try the same chunk again */
		if (sent < 0 && (errno == EINTR || errno == EAGAIN))
			continue;
		/* hard error, or send() accepted nothing: give up */
		break;
	}
	return len - remaining;
}

/*
 * Open a stream socket and connect it to the DEFAULT_SOCKET address.
 *
 * The socket name is stored after a leading NUL byte in sun_path
 * (the Linux abstract-namespace address form).
 *
 * If nonblocking is non-zero, O_NONBLOCK is set on the descriptor for
 * the duration of the connect() call only; the original file status
 * flags are put back before the descriptor is returned.
 *
 * Returns the connected descriptor, or -1 with errno set by the failing
 * socket()/connect() call.
 */
int __mpath_connect(int nonblocking)
{
	struct sockaddr_un addr;
	size_t addrlen;
	int saved_flags = 0;
	int sock;

	memset(&addr, 0, sizeof(addr));
	addr.sun_family = AF_LOCAL;
	addr.sun_path[0] = '\0';
	strncpy(&addr.sun_path[1], DEFAULT_SOCKET, sizeof(addr.sun_path) - 1);

	/* family field + leading NUL + name, capped at the struct size */
	addrlen = sizeof(sa_family_t) + 1 + strlen(DEFAULT_SOCKET);
	if (addrlen > sizeof(struct sockaddr_un))
		addrlen = sizeof(struct sockaddr_un);

	sock = socket(AF_LOCAL, SOCK_STREAM, 0);
	if (sock == -1)
		return -1;

	if (nonblocking) {
		saved_flags = fcntl(sock, F_GETFL, 0);
		if (saved_flags != -1)
			(void)fcntl(sock, F_SETFL, saved_flags | O_NONBLOCK);
	}

	if (connect(sock, (struct sockaddr *)&addr, addrlen) != -1) {
		/* connected: drop the temporary O_NONBLOCK again */
		if (nonblocking && saved_flags != -1)
			(void)fcntl(sock, F_SETFL, saved_flags);
		return sock;
	}

	/* close() may overwrite errno; keep the one from connect() */
	{
		int saved_errno = errno;

		close(sock);
		errno = saved_errno;
	}
	return -1;
}

/*
 * Convenience wrapper: calls __mpath_connect() with nonblocking set
 * to 0, i.e. the connect() is performed without O_NONBLOCK.
 *
 * Returns the value of __mpath_connect(): a connected socket
 * descriptor, or -1 with errno set on failure.
 */
int mpath_connect(void)
{
	return __mpath_connect(0);
}

/*
 * Close the descriptor fd.
 *
 * Returns the result of close(fd): 0 on success, -1 with errno set
 * on failure.
 */
int mpath_disconnect(int fd)
{
	return close(fd);
}

/*
 * Read the size_t length header that precedes a reply on fd.
 *
 * timeout is forwarded to read_all().
 *
 * Returns the announced reply length (always > 0 and < MAX_REPLY_LEN),
 * or -1 with errno set:
 *   - whatever read_all() set if the read itself failed,
 *   - EIO if fewer than sizeof(size_t) bytes were received,
 *   - ERANGE if the announced length is 0 or >= MAX_REPLY_LEN.
 */
ssize_t mpath_recv_reply_len(int fd, unsigned int timeout)
{
	size_t reply_len;
	ssize_t nread = read_all(fd, &reply_len, sizeof(reply_len), timeout);

	if (nread < 0)
		return nread;

	/* stream ended in the middle of the header */
	if (nread != sizeof(reply_len)) {
		errno = EIO;
		return -1;
	}

	if (reply_len == 0 || reply_len >= MAX_REPLY_LEN) {
		errno = ERANGE;
		return -1;
	}

	return reply_len;
}

/*
 * Read exactly len bytes of reply payload from fd into reply and force
 * the last byte to NUL, so that reply is always a terminated string on
 * success.
 *
 * reply must point to at least len bytes. timeout is forwarded to
 * read_all().
 *
 * Returns 0 on success, or -1 with errno set:
 *   - EINVAL if len is 0 (there would be no room for the terminator),
 *   - whatever read_all() set if the read itself failed,
 *   - EIO if fewer than len bytes were received.
 */
int mpath_recv_reply_data(int fd, char *reply, size_t len,
			  unsigned int timeout)
{
	ssize_t ret;

	/*
	 * With len == 0 the terminator store below would index
	 * reply[len - 1] == reply[SIZE_MAX], far outside the buffer.
	 */
	if (len == 0) {
		errno = EINVAL;
		return -1;
	}

	ret = read_all(fd, reply, len, timeout);
	if (ret < 0)
		return -1;
	if ((size_t)ret != len) {
		errno = EIO;
		return -1;
	}
	reply[len - 1] = '\0';
	return 0;
}

/*
 * Receive one complete reply from fd: first the length header (via
 * mpath_recv_reply_len()), then the payload (via
 * mpath_recv_reply_data()) into a freshly malloc()ed buffer.
 *
 * On success returns 0 and stores the buffer in *reply; the caller owns
 * it and must free() it. On failure returns -1 with errno set and
 * *reply set to NULL. timeout is forwarded to both receive steps.
 */
int mpath_recv_reply(int fd, char **reply, unsigned int timeout)
{
	ssize_t len;
	char *buf;

	*reply = NULL;
	len = mpath_recv_reply_len(fd, timeout);
	if (len <= 0)
		return (int)len;

	buf = malloc(len);
	if (!buf)
		return -1;

	if (mpath_recv_reply_data(fd, buf, len, timeout) != 0) {
		/*
		 * free() is not guaranteed to leave errno alone on every
		 * libc; keep the error from the failed receive for the
		 * caller.
		 */
		int err = errno;

		free(buf);
		errno = err;
		return -1;
	}

	/* hand the buffer over only once it holds a complete reply */
	*reply = buf;
	return 0;
}

/*
 * Send a command string over fd.
 *
 * The message is a size_t length header followed by that many bytes of
 * payload. For a non-NULL cmd the length is strlen(cmd) + 1, so the
 * terminating NUL is transmitted too. A NULL cmd sends a header of 0
 * and no payload.
 *
 * Returns 0 if everything was written, -1 if write_all() came up short
 * on either the header or the payload.
 */
int mpath_send_cmd(int fd, const char *cmd)
{
	size_t len = cmd ? strlen(cmd) + 1 : 0;

	if (write_all(fd, &len, sizeof(len)) != sizeof(len))
		return -1;

	/* nothing follows an empty header */
	if (len == 0)
		return 0;

	return write_all(fd, cmd, len) == len ? 0 : -1;
}

/*
 * Send cmd over fd with mpath_send_cmd() and, if that succeeds, wait
 * for the answer with mpath_recv_reply().
 *
 * Returns -1 if sending failed (in that case *reply is not touched),
 * otherwise the result of mpath_recv_reply(fd, reply, timeout).
 */
int mpath_process_cmd(int fd, const char *cmd, char **reply,
		      unsigned int timeout)
{
	return mpath_send_cmd(fd, cmd) != 0
		? -1
		: mpath_recv_reply(fd, reply, timeout);
}