/* * Copyright 2011 Ian Kent * Copyright 2011 Red Hat, Inc. * * This program is free software; you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by * the Free Software Foundation, either version 2 of the License, or * (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License * along with this program. If not, see . */ #include #include #include #include #include #include #include #include #include #include #include #include #include "automount.h" #include "dclist.h" #define MAX_TTL (60*60) /* 1 hour */ struct rr { unsigned int type; unsigned int class; unsigned long ttl; unsigned int len; }; struct srv_rr { const char *name; unsigned int priority; unsigned int weight; unsigned int port; unsigned long ttl; }; static pthread_mutex_t dclist_mutex = PTHREAD_MUTEX_INITIALIZER; static void dclist_mutex_lock(void) { int status = pthread_mutex_lock(&dclist_mutex); if (status) fatal(status); return; } static void dclist_mutex_unlock(void) { int status = pthread_mutex_unlock(&dclist_mutex); if (status) fatal(status); return; } static int do_srv_query(unsigned int logopt, char *name, u_char **packet) { int len = PACKETSZ; unsigned int last_len = len; char ebuf[MAX_ERR_BUF]; u_char *buf; while (1) { buf = malloc(last_len); if (!buf) { char *estr = strerror_r(errno, ebuf, MAX_ERR_BUF); error(logopt, "malloc: %s", estr); return -1; } len = res_query(name, C_IN, T_SRV, buf, last_len); if (len < 0) { char *estr = strerror_r(errno, ebuf, MAX_ERR_BUF); error(logopt, "Failed to resolve %s (%s)", name, estr); free(buf); return -1; } if (len == last_len) { /* These shouldn't too large, bump by PACKETSZ only */ last_len += PACKETSZ; free(buf); continue; } break; } *packet = buf; return len; } static int get_name_len(u_char *buffer, u_char *start, u_char *end) { char tmp[MAXDNAME]; return dn_expand(buffer, end, start, tmp, MAXDNAME); } static int get_data_offset(u_char *buffer, u_char *start, u_char *end, struct rr *rr) { u_char *cp = start; int name_len; name_len = get_name_len(buffer, start, end); if (name_len < 0) return -1; cp += name_len; GETSHORT(rr->type, cp); GETSHORT(rr->class, cp); GETLONG(rr->ttl, cp); GETSHORT(rr->len, cp); return (cp - start); } static struct srv_rr *parse_srv_rr(unsigned int logopt, u_char *buffer, u_char *start, u_char *end, struct rr *rr, struct srv_rr *srv) { u_char *cp = start; char ebuf[MAX_ERR_BUF]; char tmp[MAXDNAME]; int len; GETSHORT(srv->priority, cp); GETSHORT(srv->weight, cp); GETSHORT(srv->port, cp); srv->ttl = rr->ttl; len = dn_expand(buffer, end, cp, tmp, MAXDNAME); if (len < 0) { error(logopt, "failed to expand name"); return NULL; } srv->name = strdup(tmp); if (!srv->name) { char *estr = strerror_r(errno, ebuf, MAX_ERR_BUF); error(logopt, "strdup: %s", estr); return NULL; } return srv; } static int cmp(struct srv_rr *a, struct srv_rr *b) { if (a->priority < b->priority) return -1; if (a->priority > b->priority) return 1; if (!a->weight || a->weight == b->weight) return 0; if (a->weight > b->weight) return -1; return 1; } static void free_srv_rrs(struct srv_rr *dcs, unsigned int count) { int i; for (i = 0; i < count; i++) { if (dcs[i].name) free((void *) dcs[i].name); } free(dcs); } int get_srv_rrs(unsigned int logopt, char *name, struct srv_rr **dcs, unsigned int *dcs_count) { struct srv_rr *srvs; unsigned int srv_num; HEADER *header; u_char *packet; u_char *start; u_char *end; unsigned int count; int i, len; char ebuf[MAX_ERR_BUF]; len = do_srv_query(logopt, name, &packet); if (len < 0) return 0; header = (HEADER *) packet; start = packet + sizeof(HEADER); end = packet + len; srvs = NULL; srv_num = 0; /* Skip over question */ len = get_name_len(packet, start, end); if (len < 0) { error(logopt, "failed to get name length"); goto error_out; } start += len + QFIXEDSZ; count = ntohs(header->ancount); debug(logopt, "%d records returned in the answer section", count); if (count <= 0) { error(logopt, "no records found in answers section"); goto error_out; } srvs = malloc(sizeof(struct srv_rr) * count); if (!srvs) { char *estr = strerror_r(errno, ebuf, MAX_ERR_BUF); error(logopt, "malloc: %s", estr); goto error_out; } memset(srvs, 0, sizeof(struct srv_rr) * count); srv_num = 0; for (i = 0; i < count && (start < end); i++) { unsigned int data_offset; struct srv_rr srv; struct srv_rr *psrv; struct rr rr; memset(&rr, 0, sizeof(struct rr)); data_offset = get_data_offset(packet, start, end, &rr); if (data_offset <= 0) { error(logopt, "failed to get start of data"); goto error_out; } start += data_offset; if (rr.type != T_SRV) continue; psrv = parse_srv_rr(logopt, packet, start, end, &rr, &srv); if (psrv) { memcpy(&srvs[srv_num], psrv, sizeof(struct srv_rr)); srv_num++; } start += rr.len; } free(packet); if (!srv_num) { error(logopt, "no srv resource records found"); goto error_srvs; } qsort(srvs, srv_num, sizeof(struct srv_rr), (int (*)(const void *, const void *)) cmp); *dcs = srvs; *dcs_count = srv_num; return 1; error_out: free(packet); error_srvs: if (srvs) free_srv_rrs(srvs, srv_num); return 0; } static char *escape_dn_commas(const char *uri) { size_t len = strlen(uri); char *new, *tmp, *ptr; ptr = (char *) uri; while (*ptr) { if (*ptr == '\\') ptr += 2; if (*ptr == ',') len += 2; ptr++; } new = malloc(len + 1); if (!new) return NULL; memset(new, 0, len + 1); ptr = (char *) uri; tmp = new; while (*ptr) { if (*ptr == '\\') { ptr++; *tmp++ = *ptr++; continue; } if (*ptr == ',') { strcpy(tmp, "%2c"); ptr++; tmp += 3; continue; } *tmp++ = *ptr++; } return new; } void free_dclist(struct dclist *dclist) { if (dclist->uri) free((void *) dclist->uri); free(dclist); } static char *getdnsdomainname(unsigned int logopt) { struct addrinfo hints, *ni; char name[MAXDNAME + 1]; char buf[MAX_ERR_BUF]; char *dnsdomain = NULL; char *ptr; int ret; memset(name, 0, sizeof(name)); if (gethostname(name, MAXDNAME) == -1) { char *estr = strerror_r(errno, buf, MAX_ERR_BUF); error(logopt, "gethostname: %s", estr); return NULL; } memset(&hints, 0, sizeof(hints)); hints.ai_flags = AI_CANONNAME; hints.ai_family = AF_UNSPEC; hints.ai_socktype = SOCK_DGRAM; ret = getaddrinfo(name, NULL, &hints, &ni); if (ret) { error(logopt, "hostname lookup for %s failed: %s", name, gai_strerror(ret)); return NULL; } ptr = ni->ai_canonname; while (*ptr && *ptr != '.') ptr++; if (*++ptr) dnsdomain = strdup(ptr); freeaddrinfo(ni); return dnsdomain; } struct dclist *get_dc_list(unsigned int logopt, const char *uri) { LDAPURLDesc *ludlist = NULL; LDAPURLDesc **ludp; unsigned int min_ttl = MAX_TTL; struct dclist *dclist = NULL;; char buf[MAX_ERR_BUF]; char *dn_uri, *esc_uri; char *domain; char *list; int ret; if (strcmp(uri, "ldap:///") && strcmp(uri, "ldaps:///")) { dn_uri = strdup(uri); if (!dn_uri) { char *estr = strerror_r(errno, buf, MAX_ERR_BUF); error(logopt, "strdup: %s", estr); return NULL; } } else { char *dnsdomain; char *hdn; dnsdomain = getdnsdomainname(logopt); if (!dnsdomain) { error(logopt, "failed to get dns domainname"); return NULL; } if (ldap_domain2dn(dnsdomain, &hdn) || hdn == NULL) { error(logopt, "Could not turn domain \"%s\" into a dn\n", dnsdomain); free(dnsdomain); return NULL; } free(dnsdomain); dn_uri = malloc(strlen(uri) + strlen(hdn) + 1); if (!dn_uri) { char *estr = strerror_r(errno, buf, MAX_ERR_BUF); error(logopt, "malloc: %s", estr); ber_memfree(hdn); return NULL; } strcpy(dn_uri, uri); strcat(dn_uri, hdn); ber_memfree(hdn); } esc_uri = escape_dn_commas(dn_uri); if (!esc_uri) { error(logopt, "Could not escape commas in uri %s", dn_uri); free(dn_uri); return NULL; } ret = ldap_url_parse(esc_uri, &ludlist); if (ret != LDAP_URL_SUCCESS) { error(logopt, "Could not parse uri %s (%d)", dn_uri, ret); free(esc_uri); free(dn_uri); return NULL; } free(esc_uri); if (!ludlist) { error(logopt, "No dn found in uri %s", dn_uri); free(dn_uri); return NULL; } free(dn_uri); dclist = malloc(sizeof(struct dclist)); if (!dclist) { char *estr = strerror_r(errno, buf, MAX_ERR_BUF); error(logopt, "malloc: %s", estr); ldap_free_urldesc(ludlist); return NULL; } memset(dclist, 0, sizeof(struct dclist)); list = NULL; for (ludp = &ludlist; *ludp != NULL;) { LDAPURLDesc *lud = *ludp; struct srv_rr *dcs = NULL; unsigned int numdcs = 0; size_t req_len, len; char *request = NULL; char *tmp; int i; if (!lud->lud_dn && !lud->lud_dn[0] && (!lud->lud_host || !lud->lud_host[0])) { *ludp = lud->lud_next; continue; } domain = NULL; if (ldap_dn2domain(lud->lud_dn, &domain) || domain == NULL) { error(logopt, "Could not turn dn \"%s\" into a domain", lud->lud_dn); *ludp = lud->lud_next; continue; } debug(logopt, "doing lookup of SRV RRs for domain %s", domain); req_len = sizeof("_ldap._tcp.") + strlen(domain); request = malloc(req_len); if (!request) { char *estr = strerror_r(errno, buf, MAX_ERR_BUF); error(logopt, "malloc: %s", estr); goto out_error; } ret = snprintf(request, req_len, "_ldap._tcp.%s", domain); if (ret >= req_len) { free(request); goto out_error; } dclist_mutex_lock(); ret = get_srv_rrs(logopt, request, &dcs, &numdcs); if (!ret | !dcs) { error(logopt, "DNS SRV query failed for domain %s", domain); dclist_mutex_unlock(); free(request); goto out_error; } dclist_mutex_unlock(); free(request); len = strlen(lud->lud_scheme); len += sizeof("://"); len *= numdcs; for (i = 0; i < numdcs; i++) { if (dcs[i].ttl > 0 && dcs[i].ttl < min_ttl) min_ttl = dcs[i].ttl; len += strlen(dcs[i].name); if (dcs[i].port > 0) len += sizeof(":65535"); } tmp = realloc(list, len); if (!tmp) { char *estr = strerror_r(errno, buf, MAX_ERR_BUF); error(logopt, "realloc: %s", estr); free_srv_rrs(dcs, numdcs); goto out_error; } if (!list) memset(tmp, 0, len); else strcat(tmp, " "); list = NULL; for (i = 0; i < numdcs; i++) { if (i > 0) strcat(tmp, " "); strcat(tmp, lud->lud_scheme); strcat(tmp, "://"); strcat(tmp, dcs[i].name); if (dcs[i].port > 0) { char port[7]; ret = snprintf(port, 7, ":%d", dcs[i].port); if (ret > 6) { error(logopt, "invalid port: %u", dcs[i].port); free_srv_rrs(dcs, numdcs); free(tmp); goto out_error; } strcat(tmp, port); } } list = tmp; *ludp = lud->lud_next; ber_memfree(domain); free_srv_rrs(dcs, numdcs); } ldap_free_urldesc(ludlist); if (!list) goto out_error; dclist->expire = monotonic_time(NULL) + min_ttl; dclist->uri = list; return dclist; out_error: if (list) free(list); if (domain) ber_memfree(domain); ldap_free_urldesc(ludlist); free_dclist(dclist); return NULL; }