Blob Blame History Raw
From 42f2abe76fcd9fad7dad3777b3aa9a368cf85d6e Mon Sep 17 00:00:00 2001
From: David Fisher <ddf1991@gmail.com>
Date: Wed, 25 Sep 2013 17:07:30 -0700
Subject: [PATCH 1/1] feat(activation): add socket activation

Checks for sockets on startup and uses them if available. Uses the
proper (socket activated) port it's listening on in its advertised URL.
---
 activation.go  | 97 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 etcd.go        |  9 ++++++
 etcd_server.go |  4 +--
 raft_server.go |  4 +--
 4 files changed, 110 insertions(+), 4 deletions(-)
 create mode 100644 activation.go

diff --git a/activation.go b/activation.go
new file mode 100644
index 0000000..43707f0
--- /dev/null
+++ b/activation.go
@@ -0,0 +1,97 @@
+package main
+
+import (
+	"github.com/coreos/go-systemd/activation"
+	"net"
+	"net/http"
+	"crypto/tls"
+	"net/url"
+)
+
+var activatedSockets []net.Listener
+
+const expectedSockets = 2
+
+const (
+	etcdSock int = iota
+	raftSock
+)
+
+func init() {
+	files := activation.Files()
+	if files == nil || len(files) == 0 {
+		// no socket activation attempted
+		activatedSockets = nil
+	} else if len(files) == expectedSockets {
+		// socket activation
+		activatedSockets = make([]net.Listener, len(files))
+		for i, f := range files {
+			var err error
+			activatedSockets[i], err = net.FileListener(f)
+			if err != nil {
+				fatal("socket activation failure: ", err)
+			}
+		}
+	} else {
+		// socket activation attempted with incorrect number of sockets
+		activatedSockets = nil
+		fatalf("socket activation failure: %d sockets received, %d expected.", len(files), expectedSockets)
+	}
+}
+
+func socketActivated() bool {
+	return activatedSockets != nil
+}
+
+func ActivateListenAndServe(srv *http.Server, sockno int) error {
+	if !socketActivated() {
+		return srv.ListenAndServe()
+	} else {
+		return srv.Serve(activatedSockets[sockno])
+	}
+}
+
+func ActivateListenAndServeTLS(srv *http.Server, sockno int, certFile, keyFile string) error {
+	if !socketActivated() {
+		return srv.ListenAndServeTLS(certFile, keyFile)
+	} else {
+		config := &tls.Config{}
+		if srv.TLSConfig != nil {
+			*config = *srv.TLSConfig
+		}
+		if config.NextProtos == nil {
+			config.NextProtos = []string{"http/1.1"}
+		}
+
+		var err error
+		config.Certificates = make([]tls.Certificate, 1)
+		config.Certificates[0], err = tls.LoadX509KeyPair(certFile, keyFile)
+		if err != nil {
+			return err
+		}
+
+		tlsListener := tls.NewListener(activatedSockets[sockno], config)
+		return srv.Serve(tlsListener)
+	}
+}
+
+func getActivatedPort(sockno int) string {
+	activatedAddr := activatedSockets[sockno].Addr().String()
+	_, port, err := net.SplitHostPort(activatedAddr)
+	if err != nil {
+		fatal(err)
+	}
+	return port
+}
+
+func useActivatedPort(staticURL string, sockno int) string {
+	port := getActivatedPort(sockno)
+
+	static, err := url.Parse(staticURL)
+	host, _, err := net.SplitHostPort(static.Host)
+	if err != nil {
+		fatal(err)
+	}
+
+	return (&url.URL{Host: net.JoinHostPort(host, port), Scheme:static.Scheme}).String()
+}
diff --git a/etcd.go b/etcd.go
index c1b6b9e..49426b6 100644
--- a/etcd.go
+++ b/etcd.go
@@ -223,6 +223,15 @@ func main() {
 
 	info := getInfo(dirPath)
 
+	// Used socket activated port in advertised URLs, if applicable
+	if socketActivated() {
+		info.RaftURL = useActivatedPort(info.RaftURL, raftSock)
+		info.EtcdURL = useActivatedPort(info.EtcdURL, etcdSock)
+
+		info.RaftListenHost = ":" + getActivatedPort(raftSock)
+		info.EtcdListenHost = ":" + getActivatedPort(etcdSock)
+	}
+
 	// Create etcd key-value store
 	etcdStore = store.CreateStore(maxSize)
 	snapConf = newSnapshotConf()
diff --git a/etcd_server.go b/etcd_server.go
index d72c1b7..81a8d6d 100644
--- a/etcd_server.go
+++ b/etcd_server.go
@@ -50,8 +50,8 @@ func (e *etcdServer) ListenAndServe() {
 	infof("etcd server [name %s, listen on %s, advertised url %s]", e.name, e.Server.Addr, e.url)
 
 	if e.tlsConf.Scheme == "http" {
-		fatal(e.Server.ListenAndServe())
+		fatal(ActivateListenAndServe(&e.Server, etcdSock))
 	} else {
-		fatal(e.Server.ListenAndServeTLS(e.tlsInfo.CertFile, e.tlsInfo.KeyFile))
+		fatal(ActivateListenAndServeTLS(&e.Server,etcdSock,  e.tlsInfo.CertFile, e.tlsInfo.KeyFile))
 	}
 }
diff --git a/raft_server.go b/raft_server.go
index 580a565..7da281b 100644
--- a/raft_server.go
+++ b/raft_server.go
@@ -190,9 +190,9 @@ func (r *raftServer) startTransport(scheme string, tlsConf tls.Config) {
 	raftMux.HandleFunc("/etcdURL", EtcdURLHttpHandler)
 
 	if scheme == "http" {
-		fatal(server.ListenAndServe())
+		fatal(ActivateListenAndServe(server, raftSock))
 	} else {
-		fatal(server.ListenAndServeTLS(r.tlsInfo.CertFile, r.tlsInfo.KeyFile))
+		fatal(ActivateListenAndServeTLS(server, raftSock, r.tlsInfo.CertFile, r.tlsInfo.KeyFile))
 	}
 
 }
-- 
1.8.3.1