Blob Blame History Raw
/*
 * Copyright 2011 Ian Kent <raven@themaw.net>
 * Copyright 2011 Red Hat, Inc.
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 2 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
*/

#include <sys/types.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <arpa/nameser.h>
#include <stdlib.h>
#include <string.h>
#include <resolv.h>
#include <netdb.h>
#include <ldap.h>
#include <sys/param.h>
#include <errno.h>
#include <endian.h>

#include "automount.h"
#include "dclist.h"

/*
 * Ceiling, in seconds, for the lifetime of a domain controller list:
 * get_dc_list() starts from this value and lowers it to the smallest
 * non-zero SRV record ttl it sees.
 */
#define MAX_TTL		(60*60) /* 1 hour */

/*
 * Fixed fields of a DNS resource record, in the order get_data_offset()
 * reads them from the message (owner name excluded).
 */
struct rr {
	unsigned int type;	/* record type, compared against T_SRV */
	unsigned int class;	/* record class (read but not otherwise used here) */
	unsigned long ttl;	/* time to live, copied into struct srv_rr */
	unsigned int len;	/* length in bytes of the record data that follows */
};

/*
 * One decoded SRV record as produced by parse_srv_rr().
 */
struct srv_rr {
	const char *name;	/* target host; strdup()ed, released by free_srv_rrs() */
	unsigned int priority;	/* sort key used by cmp() */
	unsigned int weight;	/* secondary sort key used by cmp() */
	unsigned int port;	/* appended to the URI when greater than zero */
	unsigned long ttl;	/* copied from the enclosing record's ttl */
};

/*
 * Held by get_dc_list() around its call to get_srv_rrs().
 * NOTE(review): presumably serialises use of the resolver - confirm.
 */
static pthread_mutex_t dclist_mutex = PTHREAD_MUTEX_INITIALIZER;

/* Take dclist_mutex; a non-zero pthread status is passed to fatal(). */
static void dclist_mutex_lock(void)
{
	int rc;

	rc = pthread_mutex_lock(&dclist_mutex);
	if (rc != 0)
		fatal(rc);
}

/* Release dclist_mutex; a non-zero pthread status is passed to fatal(). */
static void dclist_mutex_unlock(void)
{
	int rc;

	rc = pthread_mutex_unlock(&dclist_mutex);
	if (rc != 0)
		fatal(rc);
}

/*
 * Send a DNS SRV query for name and hand back the raw response.
 *
 * The answer buffer starts at PACKETSZ bytes. Whenever the reported
 * response length reaches or exceeds the buffer size the answer may
 * have been cut short, so the buffer is grown by PACKETSZ and the
 * query is repeated.
 *
 * On success the malloc()ed response is stored in *packet (the caller
 * must free() it) and its length, which is always smaller than the
 * allocation, is returned. On failure -1 is returned and *packet is
 * left untouched.
 */
static int do_srv_query(unsigned int logopt, char *name, u_char **packet)
{
	unsigned int buf_len = PACKETSZ;
	char ebuf[MAX_ERR_BUF];
	u_char *buf;
	int len;

	while (1) {
		buf = malloc(buf_len);
		if (!buf) {
			char *estr = strerror_r(errno, ebuf, MAX_ERR_BUF);
			error(logopt, "malloc: %s", estr);
			return -1;
		}

		len = res_query(name, C_IN, T_SRV, buf, buf_len);
		if (len < 0) {
			char *estr = strerror_r(errno, ebuf, MAX_ERR_BUF);
			error(logopt, "Failed to resolve %s (%s)", name, estr);
			free(buf);
			return -1;
		}

		/*
		 * Only accept an answer that is strictly smaller than the
		 * buffer. A length equal to the buffer size may mean the
		 * answer was truncated; a larger one (NOTE(review): some
		 * resolver implementations appear to report the untruncated
		 * size) must never be returned because the caller uses it
		 * to locate the end of the buffer.
		 */
		if ((unsigned int) len < buf_len)
			break;

		/* These shouldn't be too large, bump by PACKETSZ only */
		buf_len += PACKETSZ;
		free(buf);
	}

	*packet = buf;

	return len;
}

/*
 * Return what dn_expand() reports for the domain name located at start
 * inside the message buffer..end; the expanded text itself is discarded.
 */
static int get_name_len(u_char *buffer, u_char *start, u_char *end)
{
	char scratch[MAXDNAME];	/* expansion target, never read back */
	int consumed;

	consumed = dn_expand(buffer, end, start, scratch, MAXDNAME);

	return consumed;
}

/*
 * Decode the fixed part of the resource record found at start.
 *
 * The owner name is skipped, then type, class, ttl and data length are
 * read into *rr. Returns the number of bytes between start and the
 * first byte of the record data, or -1 if the name cannot be expanded
 * or the fixed fields would extend past end.
 */
static int get_data_offset(u_char *buffer,
			   u_char *start, u_char *end,
			   struct rr *rr)
{
	u_char *cp = start;
	int name_len;

	name_len = get_name_len(buffer, start, end);
	if (name_len < 0)
		return -1;
	cp += name_len;

	/*
	 * type (2) + class (2) + ttl (4) + length (2) bytes are about to
	 * be read; refuse a message that ends before them.
	 */
	if (end - cp < RRFIXEDSZ)
		return -1;

	GETSHORT(rr->type, cp);
	GETSHORT(rr->class, cp);
	GETLONG(rr->ttl, cp);
	GETSHORT(rr->len, cp);

	return (cp - start);
}

/*
 * Decode the data of an SRV record that begins at start.
 *
 * Priority, weight and port are read from the first six bytes, the ttl
 * is taken from rr, and the target name that follows is expanded and
 * strdup()ed into srv->name (released later by free_srv_rrs()).
 *
 * Returns srv on success. Returns NULL, after logging, when the record
 * is too short, the name cannot be expanded, or memory runs out.
 */
static struct srv_rr *parse_srv_rr(unsigned int logopt,
				   u_char *buffer, u_char *start, u_char *end,
				   struct rr *rr, struct srv_rr *srv)
{
	enum { SRV_FIXED_LEN = 6 };	/* priority + weight + port */
	u_char *cp = start;
	char ebuf[MAX_ERR_BUF];
	char tmp[MAXDNAME];
	int len;

	/* Do not read the fixed fields beyond the end of the message */
	if (end - start < SRV_FIXED_LEN) {
		error(logopt, "srv record data too short");
		return NULL;
	}

	GETSHORT(srv->priority, cp);
	GETSHORT(srv->weight, cp);
	GETSHORT(srv->port, cp);
	srv->ttl = rr->ttl;

	len = dn_expand(buffer, end, cp, tmp, MAXDNAME);
	if (len < 0) {
		error(logopt, "failed to expand name");
		return NULL;
	}
	srv->name = strdup(tmp);
	if (!srv->name) {
		char *estr = strerror_r(errno, ebuf, MAX_ERR_BUF);
		error(logopt, "strdup: %s", estr);
		return NULL;
	}

	return srv;
}

/*
 * Ordering used to sort SRV records: lower priority value first and,
 * within one priority, higher weight first.
 *
 * The result depends only on the two keys and is antisymmetric
 * (cmp(a, b) == -cmp(b, a)), as qsort() requires. An earlier shortcut
 * that returned 0 whenever a->weight was zero broke that property; a
 * zero weight now simply sorts after any non-zero weight.
 */
static int cmp(struct srv_rr *a, struct srv_rr *b)
{
	if (a->priority != b->priority)
		return (a->priority < b->priority) ? -1 : 1;

	if (a->weight == b->weight)
		return 0;

	return (a->weight > b->weight) ? -1 : 1;
}

static void free_srv_rrs(struct srv_rr *dcs, unsigned int count)
{
	int i;

	for (i = 0; i < count; i++) {
		if (dcs[i].name)
			free((void *) dcs[i].name);
	}
	free(dcs);
}

/*
 * Query DNS for the SRV records of name and return them sorted by cmp().
 *
 * On success 1 is returned, *dcs points to a malloc()ed array of
 * *dcs_count records and the caller owns the array together with the
 * name strings inside it. On any failure (query error, malformed
 * answer, no SRV records, out of memory) 0 is returned and neither
 * *dcs nor *dcs_count is written.
 */
int get_srv_rrs(unsigned int logopt,
		char *name, struct srv_rr **dcs, unsigned int *dcs_count)
{
	struct srv_rr *srvs = NULL;
	unsigned int srv_num = 0;
	HEADER *header;
	u_char *packet;
	u_char *start;
	u_char *end;
	unsigned int count;
	unsigned int i;
	int len;
	char ebuf[MAX_ERR_BUF];

	len = do_srv_query(logopt, name, &packet);
	if (len < 0)
		return 0;

	header = (HEADER *) packet;
	start = packet + sizeof(HEADER);
	end = packet + len;

	/* Skip over question */
	len = get_name_len(packet, start, end);
	if (len < 0) {
		error(logopt, "failed to get name length");
		goto error_out;
	}

	start += len + QFIXEDSZ;

	count = ntohs(header->ancount);

	debug(logopt, "%d records returned in the answer section", count);

	if (count <= 0) {
		error(logopt, "no records found in answers section");
		goto error_out;
	}

	/* Zeroed so that unused slots carry a NULL name */
	srvs = calloc(count, sizeof(struct srv_rr));
	if (!srvs) {
		char *estr = strerror_r(errno, ebuf, MAX_ERR_BUF);
		error(logopt, "malloc: %s", estr);
		goto error_out;
	}

	for (i = 0; i < count && (start < end); i++) {
		struct srv_rr srv;
		struct srv_rr *psrv;
		struct rr rr;
		int data_offset;	/* signed: get_data_offset() may return -1 */

		memset(&rr, 0, sizeof(struct rr));

		data_offset = get_data_offset(packet, start, end, &rr);
		if (data_offset <= 0) {
			error(logopt, "failed to get start of data");
			goto error_out;
		}
		start += data_offset;

		/* The record data has to fit in what is left of the packet */
		if (end - start < (int) rr.len) {
			error(logopt, "resource record data exceeds packet");
			goto error_out;
		}

		/*
		 * Records of other types are ignored, but their data
		 * must still be stepped over below to reach the next record.
		 */
		if (rr.type == T_SRV) {
			psrv = parse_srv_rr(logopt,
					    packet, start, end, &rr, &srv);
			if (psrv)
				srvs[srv_num++] = *psrv;
		}

		start += rr.len;
	}
	free(packet);

	if (!srv_num) {
		error(logopt, "no srv resource records found");
		goto error_srvs;
	}

	qsort(srvs, srv_num, sizeof(struct srv_rr),
		(int (*)(const void *, const void *)) cmp);

	*dcs = srvs;
	*dcs_count = srv_num;

	return 1;

error_out:
	free(packet);
error_srvs:
	if (srvs)
		free_srv_rrs(srvs, srv_num);
	return 0;
}

/*
 * Return a malloc()ed copy of uri in which every ',' that is not
 * preceded by a backslash is replaced by "%2c", and every backslash
 * escape "\x" is reduced to the bare character x (so "\," becomes a
 * literal ','). A backslash that is the very last character has nothing
 * to escape and is copied unchanged.
 *
 * Neither pass steps beyond the terminating NUL, including when the
 * input ends in a backslash or in an escaped character.
 *
 * Returns NULL if memory cannot be allocated; the caller frees the
 * result.
 */
static char *escape_dn_commas(const char *uri)
{
	const char *src;
	char *escaped, *dst;
	size_t len = 0;

	/* First pass: work out exactly how long the result will be */
	for (src = uri; *src; src++) {
		if (*src == '\\' && src[1]) {
			src++;		/* the pair collapses to one character */
			len++;
		} else if (*src == ',') {
			len += 3;	/* "%2c" */
		} else {
			len++;
		}
	}

	escaped = malloc(len + 1);
	if (!escaped)
		return NULL;

	/* Second pass: build the result */
	dst = escaped;
	for (src = uri; *src; src++) {
		if (*src == '\\' && src[1]) {
			*dst++ = *++src;
		} else if (*src == ',') {
			memcpy(dst, "%2c", 3);
			dst += 3;
		} else {
			*dst++ = *src;
		}
	}
	*dst = '\0';

	return escaped;
}

/*
 * Release a dclist along with the uri string it holds.
 * dclist itself must not be NULL.
 */
void free_dclist(struct dclist *dclist)
{
	/* free() ignores NULL, so an unset uri needs no special case */
	free((void *) dclist->uri);
	free(dclist);
}

/*
 * Work out the DNS domain of the local host.
 *
 * The host name from gethostname() is resolved with AI_CANONNAME and
 * everything after the first '.' of the canonical name is returned as
 * a strdup()ed string the caller must free().
 *
 * Returns NULL when the host name cannot be obtained or resolved (both
 * are logged), when the canonical name is missing, contains no '.', or
 * has nothing after the '.', or when strdup() fails.
 */
static char *getdnsdomainname(unsigned int logopt)
{
	struct addrinfo hints, *ni;
	char name[MAXDNAME + 1];
	char buf[MAX_ERR_BUF];
	char *dnsdomain = NULL;
	char *dot;
	int ret;

	/* name is one byte larger than the size passed, so it stays terminated */
	memset(name, 0, sizeof(name));
	if (gethostname(name, MAXDNAME) == -1) {
		char *estr = strerror_r(errno, buf, MAX_ERR_BUF);
		error(logopt, "gethostname: %s", estr);
		return NULL;
	}

	memset(&hints, 0, sizeof(hints));
	hints.ai_flags = AI_CANONNAME;
	hints.ai_family = AF_UNSPEC;
	hints.ai_socktype = SOCK_DGRAM;

	ret = getaddrinfo(name, NULL, &hints, &ni);
	if (ret) {
		error(logopt,
		      "hostname lookup for %s failed: %s",
		      name, gai_strerror(ret));
		return NULL;
	}

	/*
	 * Locate the first '.' with strchr() so that a name without a
	 * dot is never scanned past its terminating NUL.
	 */
	if (ni->ai_canonname) {
		dot = strchr(ni->ai_canonname, '.');
		if (dot && dot[1])
			dnsdomain = strdup(dot + 1);
	}

	freeaddrinfo(ni);

	return dnsdomain;
}

/*
 * Build a space separated list of LDAP server URIs for uri from the
 * DNS SRV records of its domain.
 *
 * If uri is exactly "ldap:///" or "ldaps:///" the dn is derived from
 * the DNS domain of the local host, otherwise uri is used as given.
 * The dn of the parsed URL is converted to a domain, the SRV records of
 * "_ldap._tcp.<domain>" are looked up and each one is appended to the
 * list as "<scheme>://<host>[:<port>]".
 *
 * Returns a newly allocated dclist whose uri holds the list and whose
 * expire is the current monotonic time plus the smallest non-zero
 * record ttl seen (at most MAX_TTL); release it with free_dclist().
 * Returns NULL on any failure.
 */
struct dclist *get_dc_list(unsigned int logopt, const char *uri)
{
	LDAPURLDesc *ludlist = NULL;
	LDAPURLDesc *lud;
	unsigned int min_ttl = MAX_TTL;
	struct dclist *dclist = NULL;
	char buf[MAX_ERR_BUF];
	char *dn_uri, *esc_uri;
	char *domain = NULL;	/* NULL whenever no domain string is owned */
	char *list = NULL;
	int ret;

	if (strcmp(uri, "ldap:///") && strcmp(uri, "ldaps:///")) {
		dn_uri = strdup(uri);
		if (!dn_uri) {
			char *estr = strerror_r(errno, buf, MAX_ERR_BUF);
			error(logopt, "strdup: %s", estr);
			return NULL;
		}
	} else {
		char *dnsdomain;
		char *hdn;

		dnsdomain = getdnsdomainname(logopt);
		if (!dnsdomain) {
			error(logopt, "failed to get dns domainname");
			return NULL;
		}

		if (ldap_domain2dn(dnsdomain, &hdn) || hdn == NULL) {
			error(logopt,
			      "Could not turn domain \"%s\" into a dn\n",
			      dnsdomain);
			free(dnsdomain);
			return NULL;
		}
		free(dnsdomain);

		dn_uri = malloc(strlen(uri) + strlen(hdn) + 1);
		if (!dn_uri) {
			char *estr = strerror_r(errno, buf, MAX_ERR_BUF);
			error(logopt, "malloc: %s", estr);
			ber_memfree(hdn);
			return NULL;
		}

		strcpy(dn_uri, uri);
		strcat(dn_uri, hdn);
		ber_memfree(hdn);
	}

	esc_uri = escape_dn_commas(dn_uri);
	if (!esc_uri) {
		error(logopt, "Could not escape commas in uri %s", dn_uri);
		free(dn_uri);
		return NULL;
	}

	ret = ldap_url_parse(esc_uri, &ludlist);
	if (ret != LDAP_URL_SUCCESS) {
		error(logopt, "Could not parse uri %s (%d)", dn_uri, ret);
		free(esc_uri);
		free(dn_uri);
		return NULL;
	}

	free(esc_uri);

	if (!ludlist) {
		error(logopt, "No dn found in uri %s", dn_uri);
		free(dn_uri);
		return NULL;
	}

	free(dn_uri);

	dclist = calloc(1, sizeof(struct dclist));
	if (!dclist) {
		char *estr = strerror_r(errno, buf, MAX_ERR_BUF);
		error(logopt, "malloc: %s", estr);
		ldap_free_urldesc(ludlist);
		return NULL;
	}

	/*
	 * Walk the descriptors with a cursor of our own so that ludlist
	 * keeps pointing at the head and the whole chain can be freed
	 * once we are done with it.
	 */
	for (lud = ludlist; lud != NULL; lud = lud->lud_next) {
		struct srv_rr *dcs = NULL;
		unsigned int numdcs = 0;
		size_t req_len, len;
		char *request = NULL;
		char *tmp;
		unsigned int i;
		int has_dn, has_host;

		has_dn = lud->lud_dn && lud->lud_dn[0];
		has_host = lud->lud_host && lud->lud_host[0];

		/* Nothing to look up for this entry */
		if (!has_dn && !has_host)
			continue;

		domain = NULL;
		if (!has_dn ||
		    ldap_dn2domain(lud->lud_dn, &domain) || domain == NULL) {
			error(logopt,
			      "Could not turn dn \"%s\" into a domain",
			      has_dn ? lud->lud_dn : "");
			continue;
		}

		debug(logopt, "doing lookup of SRV RRs for domain %s", domain);

		/* sizeof() already accounts for the terminating NUL */
		req_len = sizeof("_ldap._tcp.") + strlen(domain);
		request = malloc(req_len);
		if (!request) {
			char *estr = strerror_r(errno, buf, MAX_ERR_BUF);
			error(logopt, "malloc: %s", estr);
			goto out_error;
		}

		ret = snprintf(request, req_len, "_ldap._tcp.%s", domain);
		if (ret < 0 || (size_t) ret >= req_len) {
			free(request);
			goto out_error;
		}

		dclist_mutex_lock();
		ret = get_srv_rrs(logopt, request, &dcs, &numdcs);
		dclist_mutex_unlock();
		free(request);
		if (!ret || !dcs) {
			error(logopt,
			      "DNS SRV query failed for domain %s", domain);
			goto out_error;
		}

		/*
		 * For each server: scheme, "://" and one more byte (from
		 * sizeof) for the following separator or final NUL, plus
		 * the host name and room for a ":port" suffix.
		 */
		len = strlen(lud->lud_scheme);
		len += sizeof("://");
		len *= numdcs;

		for (i = 0; i < numdcs; i++) {
			if (dcs[i].ttl > 0 && dcs[i].ttl < min_ttl)
				min_ttl = dcs[i].ttl;
			len += strlen(dcs[i].name);
			if (dcs[i].port > 0)
				len += sizeof(":65535");
		}

		/* Keep room for what earlier URLs added plus the joining space */
		if (list)
			len += strlen(list) + 1;

		tmp = realloc(list, len);
		if (!tmp) {
			char *estr = strerror_r(errno, buf, MAX_ERR_BUF);
			error(logopt, "realloc: %s", estr);
			free_srv_rrs(dcs, numdcs);
			goto out_error;
		}

		if (list)
			strcat(tmp, " ");
		else
			tmp[0] = '\0';

		/* tmp owns the buffer until this URL has been fully appended */
		list = NULL;
		for (i = 0; i < numdcs; i++) {
			if (i > 0)
				strcat(tmp, " ");
			strcat(tmp, lud->lud_scheme);
			strcat(tmp, "://");
			strcat(tmp, dcs[i].name);
			if (dcs[i].port > 0) {
				char port[7];
				ret = snprintf(port, 7, ":%u", dcs[i].port);
				if (ret < 0 || ret > 6) {
					error(logopt,
					      "invalid port: %u", dcs[i].port);
					free_srv_rrs(dcs, numdcs);
					free(tmp);
					goto out_error;
				}
				strcat(tmp, port);
			}
		}
		list = tmp;

		ber_memfree(domain);
		domain = NULL;
		free_srv_rrs(dcs, numdcs);
	}

	if (!list)
		goto out_error;

	ldap_free_urldesc(ludlist);

	dclist->expire = monotonic_time(NULL) + min_ttl;
	dclist->uri = list;

	return dclist;

out_error:
	if (list)
		free(list);
	if (domain)
		ber_memfree(domain);
	ldap_free_urldesc(ludlist);
	free_dclist(dclist);
	return NULL;
}