#!/usr/bin/env python
# SPDX-License-Identifier: GPL-2.0-or-later
#
# Copyright (C) 2018 - 2019 Red Hat, Inc.
#

# nm-wg-set: modify an existing WireGuard connection profile.
#
#   $ nm-wg-set [id|uuid|interface] ID [wg-args...]
#
# The arguments to set the parameters are like the set parameters from `man 8 wg`.
# For example:
#
#   $ nm-wg-set wg0 peer wN8G5HpphoXOGkiXTgBPyr9BhrRm2z9JEI6BiH6fB0g= preshared-key <(wg genpsk)
#
# extra, script specific arguments:
#   - private-key-flags
#   - preshared-key-flags
#
# Note that the arguments have some similarities to the `wg set` command. But this
# script only modifies the connection profile in NetworkManager. It doesn't (re)activate
# the profile and thus the changes only result in the configuration of the kernel interface
# after activating the profile. Use `nmcli connection up` for that.
#
# The example script does not support creating or deleting the WireGuard profile itself. It also
# does not support modifying other settings of the connection profile, like the IP address configuration.
# For that also use nmcli. For example:
#
#       PROFILE=wg0
#
#       # create the WireGuard profile with nmcli
#       PRIVKEY_FILE=/tmp/wg.key
#       (umask 077; rm -f "$PRIVKEY_FILE"; wg genkey > "$PRIVKEY_FILE")
#       IFNAME=wg0
#       PUBKEY=$(wg pubkey < "$PRIVKEY_FILE")
#       IP4ADDR=192.168.99.5/24
#       IP4GW=192.168.99.1
#       nmcli connection delete id "$PROFILE"
#       nmcli connection add \
#           type wireguard \
#           con-name "$PROFILE" \
#           ifname "$IFNAME" \
#           connection.stable-id "$PROFILE-$PUBKEY" \
#           ipv4.method manual \
#           ipv4.addresses "$IP4ADDR" \
#           ipv4.gateway "$IP4GW" \
#           ipv4.never-default yes \
#           ipv6.method link-local \
#           wireguard.listen-port 0 \
#           wireguard.fwmark 0 \
#           wireguard.private-key '' \
#           wireguard.private-key-flags 0
#       nmcli connection up \
#           id "$PROFILE" \
#           passwd-file <(echo "wireguard.private-key:$(cat "$PRIVKEY_FILE")")
#
#       # modify the WireGuard profile with the script
#       nm-wg-set id "$PROFILE" $WG_ARGS

import sys
import os

import gi

gi.require_version("NM", "1.0")
from gi.repository import NM


class MyError(Exception):
    """Error raised by this script for invalid arguments or unusable key files."""


def pr(v):
    """Debugging aid: pretty-print `v` to stdout (indent 4, depth 5, width 60)."""
    from pprint import PrettyPrinter

    printer = PrettyPrinter(indent=4, width=60, depth=5)
    printer.pprint(v)


###############################################################################


def connection_is_wireguard(conn):
    """Tell whether `conn` is a WireGuard profile.

    The result is truthy only if the profile has a connection setting whose
    type is the WireGuard setting name and it also carries a WireGuard
    setting. Callers should only rely on the truthiness of the result.
    """
    s_con = conn.get_setting(NM.SettingConnection)
    if not s_con:
        # no connection setting at all: hand back the falsy value as-is
        return s_con
    if s_con.get_connection_type() != NM.SETTING_WIREGUARD_SETTING_NAME:
        return False
    return conn.get_setting(NM.SettingWireGuard)


def connection_to_str(conn):
    """Return a one-line description of `conn` for messages and sorting.

    The format is `"<id>" (<uuid><extra>)`, where <extra> is the interface
    name for WireGuard profiles (empty if none is set) or the connection
    type for all other profiles.
    """
    if not connection_is_wireguard(conn):
        con_type = conn.get_setting(NM.SettingConnection).get_connection_type()
        suffix = ", type: %s" % (con_type,)
    else:
        ifname = conn.get_setting(NM.SettingConnection).get_interface_name()
        suffix = ', interface: "%s"' % (ifname,) if ifname else ""

    return '"%s" (%s%s)' % (conn.get_id(), conn.get_uuid(), suffix)


def connections_wg(connections):
    """Return the WireGuard profiles among `connections`, sorted by description."""
    wg_only = filter(connection_is_wireguard, connections)
    return sorted(wg_only, key=connection_to_str)


def connections_find(connections, con_spec, con_id):
    """Find the profiles that `con_id` refers to.

    `con_spec` selects what `con_id` is compared against: "id", "interface"
    or "uuid". With `con_spec` None all three are tried. Each profile is
    listed at most once and the result is sorted by its description.
    """

    def has_id(c):
        return con_id == c.get_id()

    def has_interface(c):
        s_con = c.get_setting(NM.SettingConnection)
        return s_con and con_id == s_con.get_interface_name()

    def has_uuid(c):
        return con_id == c.get_uuid()

    candidates = sorted(connections, key=connection_to_str)
    matches = []
    # the passes run in this fixed order: id, interface, uuid
    for spec, predicate in (
        ("id", has_id),
        ("interface", has_interface),
        ("uuid", has_uuid),
    ):
        if con_spec is not None and con_spec != spec:
            continue
        for c in candidates:
            if predicate(c) and c not in matches:
                matches.append(c)
    matches.sort(key=connection_to_str)
    return matches


def print_hint(nm_client):
    """Print advice on creating a profile and list the existing WireGuard profiles."""
    print("Maybe you want to create a profile first with")
    print("  nmcli connection add type wireguard ifname wg0 $MORE_ARGS")
    wg_profiles = connections_wg(nm_client.get_connections())
    if not wg_profiles:
        return
    print("Or edit one of the following WireGuard profiles:")
    for profile in wg_profiles:
        print("  - %s" % (connection_to_str(profile),))


###############################################################################


def argv_get_one(argv, idx, type_ctor=None, topic=None):
    """Return `argv[idx]`, optionally converted with `type_ctor`.

    argv:      list of command line arguments.
    idx:       index of the argument to fetch.
    type_ctor: optional callable applied to the raw argument; any exception
               it raises is reported as an invalid argument.
    topic:     optional label for error messages. An integer is taken as an
               index into `argv` (the keyword the value belongs to), anything
               else is used verbatim.

    Raises MyError if the argument is missing or cannot be converted.
    """
    if topic is not None:
        # local import, so the block does not depend on file-level imports
        import numbers

        try:
            return argv_get_one(argv, idx, type_ctor, None)
        except MyError as e:
            # numbers.Integral covers int on Python 3 and int/long on
            # Python 2; the bare name `long` does not exist on Python 3.
            if isinstance(topic, numbers.Integral):
                topic = argv[topic]
            # Format the exception itself: `e.message` is gone on Python 3.
            raise MyError('error for "%s": %s' % (topic, e))

    try:
        v = argv[idx]
    except (IndexError, TypeError):
        raise MyError("missing argument")
    if type_ctor is not None:
        try:
            v = type_ctor(v)
        except Exception as e:
            raise MyError('invalid argument "%s" (%s)' % (v, e))
    return v


###############################################################################


def arg_parse_secret_flags(arg):
    """Convert a secret-flags argument into an NM.SettingSecretFlags value.

    Accepts one of the nicks "none", "not-saved", "not-required" or
    "agent-owned", or otherwise a decimal number. Surrounding whitespace is
    ignored. Raises MyError for anything else.
    """
    try:
        text = arg.strip()
        by_nick = {
            "none": NM.SettingSecretFlags.NONE,
            "not-saved": NM.SettingSecretFlags.NOT_SAVED,
            "not-required": NM.SettingSecretFlags.NOT_REQUIRED,
            "agent-owned": NM.SettingSecretFlags.AGENT_OWNED,
        }
        if text in by_nick:
            return by_nick[text]
        # not a known nick: fall back to the numeric form
        return NM.SettingSecretFlags(int(text))
    except Exception:
        raise MyError('invalid secret flags "%s"' % (arg))


def _arg_parse_int(arg, vmin, vmax, key, base=0):
    try:
        v = int(arg, base)
        if v >= vmin and vmax <= 0xFFFFFFFF:
            return v
    except:
        raise MyError('invalid %s "%s"' % (key, arg))
    raise MyError("%s out of range" % (key))


def arg_parse_listen_port(arg):
    """Parse a "listen-port" value via _arg_parse_int with the bounds 0 and 0xFFFF."""
    return _arg_parse_int(arg, vmin=0, vmax=0xFFFF, key="listen-port")


def arg_parse_fwmark(arg):
    """Parse a "fwmark" value via _arg_parse_int with the bounds 0 and 0xFFFFFFFF."""
    return _arg_parse_int(arg, vmin=0, vmax=0xFFFFFFFF, key="fwmark", base=0)


def arg_parse_persistent_keep_alive(arg):
    """Parse a "persistent-keepalive" value (seconds) via _arg_parse_int.

    The upper bound is 0xFFFF rather than 0xFFFFFFFF: wg(8) documents the
    interval as at most 65535.
    NOTE(review): libnm's peer setter presumably takes a 16-bit value as
    well, so larger numbers could not be stored anyway -- confirm.
    """
    return _arg_parse_int(arg, 0, 0xFFFF, "persistent-keepalive")


def arg_parse_allowed_ips(arg):
    """Split a comma separated "allowed-ips" argument into a list of entries.

    Whitespace around entries is dropped and empty entries are skipped.
    Raises MyError for the first entry that NM.WireGuardPeer rejects.
    """
    entries = [part for part in (s.strip() for s in arg.split(",")) if part]
    # a throw-away peer serves as the parser/validator for the entries
    scratch_peer = NM.WireGuardPeer()
    for entry in entries:
        if not scratch_peer.append_allowed_ip(entry, False):
            raise MyError('invalid allowed-ip "%s"' % (entry))
    return entries


###############################################################################


def secret_flags_to_string(flags):
    """Format secret flags as their number, followed by the nick if there is one.

    Values matching exactly one of the four known flags give
    "<number> (<nick>)"; any other value gives just "<number>".
    """
    nicks = {
        NM.SettingSecretFlags.NONE: "none",
        NM.SettingSecretFlags.NOT_SAVED: "not-saved",
        NM.SettingSecretFlags.NOT_REQUIRED: "not-required",
        NM.SettingSecretFlags.AGENT_OWNED: "agent-owned",
    }
    nick = nicks.get(flags)
    number = str(int(flags))
    return number if nick is None else "%s (%s)" % (number, nick)


def secret_to_string(secret):
    """Return `secret` in the form used for display.

    Secrets are only revealed when the environment variable WG_HIDE_KEYS is
    set to "never"; otherwise the placeholder "(hidden)" is returned. A
    missing/empty secret is shown as the empty string.
    """
    reveal = os.environ.get("WG_HIDE_KEYS", "") == "never"
    if not reveal:
        return "(hidden)"
    return secret if secret else ""


def val_to_str(val):
    """Format `val` for display.

    The three NM.Ternary values become "default", "true" and "false";
    everything else is returned as repr(val).
    """
    ternary_names = (
        (NM.Ternary.DEFAULT, "default"),
        (NM.Ternary.TRUE, "true"),
        (NM.Ternary.FALSE, "false"),
    )
    for ternary, text in ternary_names:
        if val == ternary:
            return text
    return repr(val)


###############################################################################


def wg_read_private_key(privkey_file):
    """Read a base64 encoded 32 byte key from the file `privkey_file`.

    Returns the key re-encoded as a base64 text string without surrounding
    whitespace. Raises MyError if the file cannot be read or does not
    contain exactly 32 bytes of base64 encoded data.

    The script uses this for both "private-key" and "preshared-key" files.
    """
    import base64

    try:
        with open(privkey_file, "r") as f:
            data = f.read()
        # b64decode/b64encode replace decodestring/encodestring, which were
        # removed in Python 3.9. b64decode accepts ASCII text and, like the
        # old call, skips characters outside the base64 alphabet (such as
        # the trailing newline).
        bdata = base64.b64decode(data)
        if len(bdata) != 32:
            raise Exception("not 32 bytes base64 encoded")
        key = base64.b64encode(bdata)
        if not isinstance(key, str):
            # Python 3 returns bytes; hand text back to the callers.
            key = key.decode("ascii")
        return key.strip()
    except Exception as e:
        # Format the exception itself: `e.message` is gone on Python 3.
        raise MyError('failed to read private key "%s": %s' % (privkey_file, e))


def wg_peer_is_valid(peer, msg=None):
    """Validate `peer` and raise MyError if it is rejected.

    peer: the NM.WireGuardPeer to check via peer.is_valid(True, True).
    msg:  optional replacement for the error text; without it the message
          of the GLib error is used.
    """
    try:
        peer.is_valid(True, True)
    except gi.repository.GLib.Error as err:
        text = err.message if msg is None else msg
        raise MyError("%s" % (text))


###############################################################################


def do_get(nm_client, connection):
    """Print the WireGuard configuration of `connection` to stdout.

    nm_client:  the NM.Client (currently unused here).
    connection: the WireGuard profile to show.

    Prints one "key: value" line per property, followed by one group of
    "peer[N]...." lines per configured peer.
    """
    # Work on the profile that was passed in. Previously the body read a
    # module-level `conn` variable and ignored the parameter.
    conn = connection
    s_con = conn.get_setting(NM.SettingConnection)
    s_wg = conn.get_setting(NM.SettingWireGuard)

    # Secrets are rendered through secret_to_string(), which decides
    # whether the actual value or a placeholder is shown.

    print("interface:                    %s" % (s_con.get_interface_name()))
    print("uuid:                         %s" % (conn.get_uuid()))
    print("id:                           %s" % (conn.get_id()))
    print(
        "private-key:                  %s" % (secret_to_string(s_wg.get_private_key()))
    )
    print(
        "private-key-flags:            %s"
        % (secret_flags_to_string(s_wg.get_private_key_flags()))
    )
    print("listen-port:                  %s" % (s_wg.get_listen_port()))
    print("fwmark:                       0x%x" % (s_wg.get_fwmark()))
    print("peer-routes:                  %s" % (val_to_str(s_wg.get_peer_routes())))
    print(
        "ip4-auto-default-route:       %s"
        % (val_to_str(s_wg.get_ip4_auto_default_route()))
    )
    print(
        "ip6-auto-default-route:       %s"
        % (val_to_str(s_wg.get_ip6_auto_default_route()))
    )
    for i in range(s_wg.get_peers_len()):
        peer = s_wg.get_peer(i)
        print("peer[%d].public-key:           %s" % (i, peer.get_public_key()))
        print(
            "peer[%d].preshared-key:        %s"
            % (i, secret_to_string(peer.get_preshared_key()))
        )
        print(
            "peer[%d].preshared-key-flags:  %s"
            % (i, secret_flags_to_string(peer.get_preshared_key_flags()))
        )
        print(
            "peer[%d].endpoint:             %s"
            % (i, peer.get_endpoint() if peer.get_endpoint() else "")
        )
        print(
            "peer[%d].persistent-keepalive: %s" % (i, peer.get_persistent_keepalive())
        )
        print(
            "peer[%d].allowed-ips:          %s"
            % (
                i,
                ",".join(
                    [peer.get_allowed_ip(j) for j in range(peer.get_allowed_ips_len())]
                ),
            )
        )


def do_set(nm_client, conn, argv):
    """Apply `wg set`-like arguments to the profile `conn` and commit it.

    nm_client: the NM.Client (currently unused here).
    conn:      the WireGuard profile to modify.
    argv:      the remaining command line arguments. Before the first
               "peer" keyword they address the interface ("private-key",
               "private-key-flags", "listen-port", "fwmark"). After
               "peer <public-key>" they address that peer ("remove",
               "preshared-key", "preshared-key-flags", "endpoint",
               "persistent-keepalive", "allowed-ips").

    This function does not return: it prints the outcome and exits the
    process, with status 0 on success and 1 on invalid arguments or when
    the commit fails.
    """
    s_wg = conn.get_setting(NM.SettingWireGuard)
    # State of the peer currently being edited. It is written back to s_wg
    # when the next "peer" keyword or the end of argv is reached.
    peer = None
    peer_remove = False
    peer_idx = None
    peer_secret_flags = None

    try:
        idx = 0
        while True:
            # Flush the pending peer before starting another one or
            # leaving the loop.
            if peer and (idx >= len(argv) or argv[idx] == "peer"):
                if peer_remove:
                    pp_peer, pp_idx = s_wg.get_peer_by_public_key(peer.get_public_key())
                    if pp_peer:
                        s_wg.remove_peer(pp_idx)
                else:
                    if peer_secret_flags is not None:
                        peer.set_preshared_key_flags(peer_secret_flags)
                    wg_peer_is_valid(peer)
                    if peer_idx is None:
                        s_wg.append_peer(peer)
                    else:
                        s_wg.set_peer(peer, peer_idx)
                peer = None
                peer_remove = False
                peer_idx = None
                peer_secret_flags = None

            if idx >= len(argv):
                break

            if not peer and argv[idx] == "private-key":
                key = argv_get_one(argv, idx + 1, None, idx)
                # an empty file name clears the key
                if key == "":
                    s_wg.set_property(NM.SETTING_WIREGUARD_PRIVATE_KEY, None)
                else:
                    s_wg.set_property(
                        NM.SETTING_WIREGUARD_PRIVATE_KEY, wg_read_private_key(key)
                    )
                idx += 2
                continue
            if not peer and argv[idx] == "private-key-flags":
                s_wg.set_property(
                    NM.SETTING_WIREGUARD_PRIVATE_KEY_FLAGS,
                    argv_get_one(argv, idx + 1, arg_parse_secret_flags, idx),
                )
                idx += 2
                continue
            if not peer and argv[idx] == "listen-port":
                s_wg.set_property(
                    NM.SETTING_WIREGUARD_LISTEN_PORT,
                    argv_get_one(argv, idx + 1, arg_parse_listen_port, idx),
                )
                idx += 2
                continue
            if not peer and argv[idx] == "fwmark":
                s_wg.set_property(
                    NM.SETTING_WIREGUARD_FWMARK,
                    argv_get_one(argv, idx + 1, arg_parse_fwmark, idx),
                )
                idx += 2
                continue
            if argv[idx] == "peer":
                public_key = argv_get_one(argv, idx + 1, None, idx)
                peer, peer_idx = s_wg.get_peer_by_public_key(public_key)
                if peer:
                    # edit a copy of the existing peer; it replaces the
                    # original at peer_idx when flushed
                    peer = peer.new_clone(True)
                else:
                    peer_idx = None
                    peer = NM.WireGuardPeer()
                    peer.set_public_key(public_key, True)
                    wg_peer_is_valid(peer, 'public key "%s" is invalid' % (public_key))
                peer_remove = False
                idx += 2
                continue
            if peer and argv[idx] == "remove":
                peer_remove = True
                idx += 1
                continue
            if peer and argv[idx] == "preshared-key":
                psk = argv_get_one(argv, idx + 1, None, idx)
                # NOTE(review): the flags are only adjusted here when
                # "preshared-key-flags" was already given for this peer,
                # which overrides that explicit choice. Possibly
                # `is None` was intended -- confirm before changing.
                if psk == "":
                    peer.set_preshared_key(None, True)
                    if peer_secret_flags is not None:
                        peer_secret_flags = NM.SettingSecretFlags.NOT_REQUIRED
                else:
                    peer.set_preshared_key(wg_read_private_key(psk), True)
                    if peer_secret_flags is not None:
                        peer_secret_flags = NM.SettingSecretFlags.NONE
                idx += 2
                continue
            if peer and argv[idx] == "preshared-key-flags":
                peer_secret_flags = argv_get_one(
                    argv, idx + 1, arg_parse_secret_flags, idx
                )
                idx += 2
                continue
            if peer and argv[idx] == "endpoint":
                peer.set_endpoint(argv_get_one(argv, idx + 1, None, idx), True)
                idx += 2
                continue
            if peer and argv[idx] == "persistent-keepalive":
                peer.set_persistent_keepalive(
                    argv_get_one(argv, idx + 1, arg_parse_persistent_keep_alive, idx)
                )
                idx += 2
                continue
            if peer and argv[idx] == "allowed-ips":
                allowed_ips = list(
                    argv_get_one(argv, idx + 1, arg_parse_allowed_ips, idx)
                )
                # the given list replaces the previous allowed-ips
                peer.clear_allowed_ips()
                for aip in allowed_ips:
                    peer.append_allowed_ip(aip, False)
                del allowed_ips
                idx += 2
                continue

            raise MyError('invalid argument "%s"' % (argv[idx]))
    except MyError as e:
        # Format the exception itself: `e.message` is gone on Python 3.
        print("Error: %s" % (e,))
        sys.exit(1)

    try:
        conn.commit_changes(True, None)
    except Exception as e:
        print("failure to commit connection: %s" % (e))
        sys.exit(1)

    print("Success")
    sys.exit(0)


###############################################################################

# Script entry point: select the profile named on the command line, then
# either show it (no further arguments) or modify it.
if __name__ == "__main__":

    nm_client = NM.Client.new(None)

    # `argv` aliases sys.argv, so the deletions below consume sys.argv too.
    argv = sys.argv
    del argv[0]

    # optional qualifier telling how the following name is to be matched
    con_spec = None
    if argv and argv[0] in ["id", "uuid", "interface"]:
        con_spec = argv[0]
        del argv[0]

    if not argv:
        print(
            "Requires an existing NetworkManager connection profile as first argument"
        )
        print(
            "Select it based on the connection ID, UUID, or interface-name (optionally qualify the selection with [id|uuid|interface])"
        )
        print_hint(nm_client)
        sys.exit(1)
    con_id = argv[0]
    del argv[0]

    connections = connections_find(nm_client.get_connections(), con_spec, con_id)
    if not connections:
        print(
            'No matching connection %s"%s" found.'
            % ((con_spec + " " if con_spec else ""), con_id)
        )
        print_hint(nm_client)
        sys.exit(1)
    if len(connections) > 1:
        print(
            'Connection %s"%s" is not unique (%s)'
            % (
                (con_spec + " " if con_spec else ""),
                con_id,
                ", ".join(["[" + connection_to_str(c) + "]" for c in connections]),
            )
        )
        if not con_spec:
            print("Maybe qualify the name with [id|uuid|interface]?")
        sys.exit(1)

    # Keep this module-level name: it is the profile the rest of the run
    # operates on.
    conn = connections[0]
    if not connection_is_wireguard(conn):
        print("Connection %s is not a WireGuard profile" % (connection_to_str(conn)))
        print("See available profiles with `nmcli connection show`")
        sys.exit(1)

    # Loading the secrets stays best-effort, but a failure is now reported
    # on stderr instead of being dropped silently.
    # NOTE(review): without the secrets loaded, a later commit presumably
    # writes the profile without them -- confirm.
    try:
        secrets = conn.get_secrets(NM.SETTING_WIREGUARD_SETTING_NAME)
        if secrets:
            conn.update_secrets(NM.SETTING_WIREGUARD_SETTING_NAME, secrets)
    except Exception as e:
        sys.stderr.write(
            "Warning: failed to load secrets of %s: %s\n" % (connection_to_str(conn), e)
        )

    if not argv:
        do_get(nm_client, conn)
    else:
        do_set(nm_client, conn, argv)