Blob Blame History Raw
"""JSON Communication

This module implements a client/server communication method based on JSON
serialization. It uses unix-domain-datagram-sockets and provides a simple
unicast message transmission.
"""


import array
import contextlib
import errno
import json
import os
import socket
from typing import Any
from typing import Optional
from .types import PathLike


class FdSet:
    """File-Descriptor Set

    This object wraps an array of file-descriptors. Unlike a normal integer
    array, this object owns the file-descriptors and therefore closes them once
    the object is released.

    File-descriptor sets are initialized once. From then one, the only allowed
    operation is to query it for information, or steal file-descriptors from
    it. If you close a set, all remaining file-descriptors are closed and
    removed from the set. It will then be an empty set.
    """

    _fds = array.array("i")

    def __init__(self, *, rawfds):
        for i in rawfds:
            if not isinstance(i, int) or i < 0:
                raise ValueError()

        self._fds = rawfds

    def __del__(self):
        self.close()

    def close(self):
        """Close All Entries

        This closes all stored file-descriptors and clears the set. Once this
        returns, the set will be empty. It is safe to call this multiple times.
        Note that a set is automatically closed when it is garbage collected.
        """

        for i in self._fds:
            if i >= 0:
                os.close(i)

        self._fds = array.array("i")

    @classmethod
    def from_list(cls, l: list):
        """Create new Set from List

        This creates a new file-descriptor set initialized to the same entries
        as in the given list. This consumes the file-descriptors. The caller
        must not assume ownership anymore.
        """

        fds = array.array("i")
        fds.fromlist(l)
        return cls(rawfds=fds)

    def __len__(self):
        return len(self._fds)

    def __getitem__(self, key: Any):
        if self._fds[key] < 0:
            raise IndexError
        return self._fds[key]

    def steal(self, key: Any):
        """Steal Entry

        Retrieve the entry at the given position, but drop it from the internal
        file-descriptor set. The caller will now own the file-descriptor and it
        can no longer be accessed through the set.

        Note that this does not reshuffle the set. All indices stay constant.
        """

        v = self[key]
        self._fds[key] = -1
        return v


class Socket(contextlib.AbstractContextManager):
    """Communication Socket

    This socket object represents a communication channel. It allows sending
    and receiving JSON-encoded messages. It uses unix-domain sequenced-packet
    sockets as underlying transport.
    """

    _socket = None
    _unlink = None

    def __init__(self, sock, unlink):
        self._socket = sock
        self._unlink = unlink

    def __del__(self):
        self.close()

    def __exit__(self, exc_type, exc_value, exc_tb):
        self.close()
        return False

    @property
    def blocking(self):
        """Get the current blocking mode of the socket.

        This is related to the socket's timeout, i.e. if no time out is set
        the socket is in blocking mode; otherwise it is non-blocking.
        """
        timeout = self._socket.gettimeout()
        return timeout is not None

    @blocking.setter
    def blocking(self, value: bool):
        """Set the blocking mode of the socket."""
        self._socket.setblocking(value)

    def accept(self) -> Optional["Socket"]:
        """Accept a new connection on the socket.

        See python's `socket.accept` for more information.
        """
        # Since, in the kernel, for AF_UNIX, new connection requests,
        # i.e. clients connecting, are directly put on the receive
        # queue of the listener socket, accept here *should* always
        # return a socket and not block, even if the client meanwhile
        # disconnected; we don't rely on that kernel behavior though
        try:
            conn, _ = self._socket.accept()
        except (socket.timeout, BlockingIOError):
            return None
        return Socket(conn, None)

    def listen(self, backlog: Optional[int] = 2**16):
        """Enable accepting of incoming connections.

        See python's `socket.listen` for details.
        """

        # `Socket.listen` accepts an `int` or no argument, but not `None`
        args = [backlog] if backlog is not None else []
        self._socket.listen(*args)

    def close(self):
        """Close Socket

        Close the socket and all underlying resources. This can be called
        multiple times.
        """

        # close the socket if it is set
        if self._socket is not None:
            self._socket.close()
            self._socket = None

        # unlink the file-system entry, if pinned
        if self._unlink is not None:
            try:
                os.unlink(self._unlink[1], dir_fd=self._unlink[0])
            except OSError as e:
                if e.errno != errno.ENOENT:
                    raise

            os.close(self._unlink[0])
            self._unlink = None

    @classmethod
    def new_client(cls, connect_to: Optional[PathLike] = None):
        """Create Client

        Create a new client socket.

        Parameters
        ----------
        connect_to
            If not `None`, the client will use the specified address as the
            default destination for all send operations.
        """

        sock = None

        try:
            sock = socket.socket(socket.AF_UNIX, socket.SOCK_SEQPACKET)

            # Trigger an auto-bind. If you do not do this, you might end up with
            # an unbound unix socket, which cannot receive messages.
            # Alternatively, you can also set `SO_PASSCRED`, but this has
            # side-effects.
            sock.bind("")

            # Connect the socket. This has no effect other than specifying the
            # default destination for send operations.
            if connect_to is not None:
                sock.connect(os.fspath(connect_to))
        except:
            if sock is not None:
                sock.close()
            raise

        return cls(sock, None)

    @classmethod
    def new_server(cls, bind_to: PathLike):
        """Create Server

        Create a new listener socket. Returned socket is in non-blocking
        mode by default. See `blocking` property.

        Parameters
        ----------
        bind_to
            The socket-address to listen on for incoming client requests.
        """

        sock = None
        unlink = None
        path = os.path.split(bind_to)

        try:
            # We bind the socket and then open a directory-fd on the target
            # socket. This allows us to properly unlink the socket when the
            # server is closed. Note that sockets are never automatically
            # cleaned up on linux, nor can you bind to existing sockets.
            # We use a dirfd to guarantee this works even when you change
            # your mount points in-between.
            # Yeah, this is racy when mount-points change between the socket
            # creation and open. But then your entire socket creation is racy
            # as well. We do not guarantee atomicity, so you better make sure
            # you do not rely on it.
            sock = socket.socket(socket.AF_UNIX, socket.SOCK_SEQPACKET)
            sock.bind(os.fspath(bind_to))
            unlink = os.open(os.path.join(".", path[0]), os.O_CLOEXEC | os.O_PATH)
            sock.setblocking(False)
        except:
            if unlink is not None:
                os.close(unlink)
            if sock is not None:
                sock.close()
            raise

        return cls(sock, (unlink, path[1]))

    def fileno(self) -> int:
        assert self._socket is not None
        return self._socket.fileno()

    def recv(self):
        """Receive a Message

        This receives the next pending message from the socket. This operation
        is synchronous.

        A tuple consisting of the deserialized message payload, the auxiliary
        file-descriptor set, and the socket-address of the sender is returned.

        In case the peer closed the connection, A tuple of `None` values is
        returned.
        """

        # On `SOCK_SEQPACKET`, packets might be arbitrarily sized. There is no
        # hard-coded upper limit, since it is only restricted by the size of
        # the kernel write buffer on sockets (which itself can be modified via
        # sysctl). The only real maximum is probably something like 2^31-1,
        # since that is the maximum of that sysctl datatype.
        # Anyway, `MSG_TRUNC+MSG_PEEK` usually allows us to easily peek at the
        # incoming buffer. Unfortunately, the python `recvmsg()` wrapper
        # discards the return code and we cannot use that. Instead, we simply
        # loop until we know the size. This is slightly awkward, but seems fine
        # as long as you do not put this into a hot-path.
        size = 4096
        while True:
            peek = self._socket.recvmsg(size, 0, socket.MSG_PEEK)
            if not peek[0]:
                # Connection was closed
                return None, None, None
            if not (peek[2] & socket.MSG_TRUNC):
                break
            size *= 2

        # Fetch a packet from the socket. On linux, the maximum SCM_RIGHTS array
        # size is hard-coded to 253. This allows us to size the ancillary buffer
        # big enough to receive any possible message.
        fds = array.array("i")
        msg = self._socket.recvmsg(size, socket.CMSG_LEN(253 * fds.itemsize))

        # First thing we do is always to fetch the CMSG FDs into an FdSet. This
        # guarantees that we do not leak FDs in case the message handling fails
        # for other reasons.
        for level, ty, data in msg[1]:
            if level == socket.SOL_SOCKET and ty == socket.SCM_RIGHTS:
                assert len(data) % fds.itemsize == 0
                fds.frombytes(data)
        fdset = FdSet(rawfds=fds)

        # Check the returned message flags. If the message was truncated, we
        # have to discard it. This shouldn't happen, but there is no harm in
        # handling it. However, `CTRUNC` can happen, since it is also triggered
        # when LSMs reject FD transmission. Treat it the same as a parser error.
        flags = msg[2]
        if flags & (socket.MSG_TRUNC | socket.MSG_CTRUNC):
            raise BufferError

        try:
            payload = json.loads(msg[0])
        except json.JSONDecodeError:
            raise BufferError

        return (payload, fdset, msg[3])

    def send(self, payload: object, *, fds: Optional[list] = None):
        """Send Message

        Send a new message via this socket. This operation is synchronous. The
        maximum message size depends on the configured send-buffer on the
        socket. An `OSError` with `EMSGSIZE` is raised when it is exceeded.

        Parameters
        ----------
        payload
            A python object to serialize as JSON and send via this socket. See
            `json.dump()` for details about the serialization involved.
        destination
            The destination to send to. If `None`, the default destination is
            used (if none is set, this will raise an `OSError`).
        fds
            A list of file-descriptors to send with the message.

        Raises
        ------
        OSError
            If the socket cannot be written, a matching `OSError` is raised.
        TypeError
            If the payload cannot be serialized, a type error is raised.
        """

        serialized = json.dumps(payload).encode()
        cmsg = []
        if fds:
            cmsg.append((socket.SOL_SOCKET, socket.SCM_RIGHTS, array.array("i", fds)))

        n = self._socket.sendmsg([serialized], cmsg, 0)
        assert n == len(serialized)