Blob Blame History Raw
#
# Copyright (c) 2020 Red Hat, Inc.
#
# This file is part of nmstate
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 2.1 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
#

from contextlib import contextmanager
import importlib
import logging
from operator import itemgetter
from operator import attrgetter
import os
import pkgutil

from libnmstate import validator
from libnmstate.error import NmstateError
from libnmstate.error import NmstateValueError
from libnmstate.nm import NetworkManagerPlugin
from libnmstate.schema import DNS
from libnmstate.schema import Interface
from libnmstate.schema import Route
from libnmstate.schema import RouteRule

from .plugin import NmstatePlugin
from .state import merge_dict


@contextmanager
def plugin_context():
    """
    Load all plugins and yield them ordered by ascending priority.

    If the managed block raises (including KeyboardInterrupt), every plugin
    whose ``checkpoint`` attribute is truthy gets ``rollback_checkpoint()``
    called. Afterwards the original exception is re-raised. A failing rollback
    is only logged, so it never hides the exception that triggered it.
    Every plugin is unloaded on exit, whether or not an error occurred.
    """
    loaded = _load_plugins()
    try:
        # Plugins with the lowest priority act first.
        loaded.sort(key=attrgetter("priority"))
        yield loaded
    except (Exception, KeyboardInterrupt):
        for plugin in filter(attrgetter("checkpoint"), loaded):
            try:
                plugin.rollback_checkpoint()
            except Exception as e:
                # Raising here would replace the exception being handled;
                # record the rollback failure and keep going instead.
                logging.error(f"Rollback failed with error {e}")
        raise
    finally:
        for plugin in loaded:
            plugin.unload()


def show_with_plugins(plugins, include_status_data=None):
    """
    Build the current network state report from the given plugins.

    Each plugin refreshes its content first. When ``include_status_data`` is
    truthy, the union of plugin capabilities is stored under "capabilities".
    Interfaces are merged from every interface-capable plugin. Routes, route
    rules and DNS client config come from whichever plugin
    ``_find_plugin_for_capability`` selects for that capability, and the key
    is omitted when none is selected. The report is schema-validated before
    it is returned.
    """
    for plugin in plugins:
        plugin.refresh_content()

    report = {}
    if include_status_data:
        report["capabilities"] = plugins_capabilities(plugins)
    report[Interface.KEY] = _get_interface_info_from_plugins(plugins)

    # (capability, report key, name of the plugin getter to call)
    single_source_sections = (
        (NmstatePlugin.PLUGIN_CAPABILITY_ROUTE, Route.KEY, "get_routes"),
        (
            NmstatePlugin.PLUGIN_CAPABILITY_ROUTE_RULE,
            RouteRule.KEY,
            "get_route_rules",
        ),
        (NmstatePlugin.PLUGIN_CAPABILITY_DNS, DNS.KEY, "get_dns_client_config"),
    )
    for capability, section_key, getter_name in single_source_sections:
        provider = _find_plugin_for_capability(plugins, capability)
        if provider:
            report[section_key] = getattr(provider, getter_name)()

    validator.schema_validate(report)
    return report


def plugins_capabilities(plugins):
    """
    Return the union of every plugin's ``capabilities`` as a list.

    Duplicates are removed; the order of the returned list is unspecified.
    """
    merged = set().union(*(plugin.capabilities for plugin in plugins))
    return list(merged)


def _load_plugins():
    """
    Return the built-in NetworkManager plugin followed by any external
    Python plugins that could be loaded.
    """
    return [NetworkManagerPlugin(), *_load_external_py_plugins()]


def _load_external_py_plugins():
    """
    Load external plugins from the folder given by the NMSTATE_PLUGIN_DIR
    environment variable. If the variable is unset or empty, use the
    'plugins' folder next to this file.

    Every top-level module in that folder whose name starts with
    "nmstate_plugin_" is executed from ``<plugin_dir>/<name>.py``. Its
    ``NMSTATE_PLUGIN`` attribute is called with no arguments to build the
    plugin instance. A module that fails to load or instantiate is skipped
    with a warning, so one broken plugin does not block the others.

    Returns:
        list: the plugin instances, in the order pkgutil reported the modules.
    """
    # ``importlib.util`` is a submodule: a plain ``import importlib`` does not
    # guarantee it is loaded, so import it explicitly rather than relying on
    # another module having done so.
    import importlib.util

    plugins = []
    plugin_dir = os.environ.get("NMSTATE_PLUGIN_DIR")
    if not plugin_dir:
        plugin_dir = os.path.join(
            os.path.dirname(os.path.realpath(__file__)), "plugins"
        )

    for _, name, _ in pkgutil.iter_modules([plugin_dir]):
        if not name.startswith("nmstate_plugin_"):
            continue
        try:
            spec = importlib.util.spec_from_file_location(
                name, os.path.join(plugin_dir, f"{name}.py")
            )
            plugin_module = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(plugin_module)
            plugins.append(plugin_module.NMSTATE_PLUGIN())
        except Exception as error:
            # Best effort: a broken third-party plugin must not stop nmstate.
            logging.warning(f"Failed to load plugin {name}: {error}")

    return plugins


def _find_plugin_for_capability(plugins, capability):
    """
    Return the plugin with specified capability and highest priority.

    Only plugins listing ``capability`` in ``plugin_capabilities`` are
    considered. When several share the highest priority, the one appearing
    first in ``plugins`` wins. Returns None when no plugin has the
    capability, so callers can skip that part of the state.
    """
    # max() returns the first of several equally maximal items, which keeps
    # the "first plugin wins on a tie" behaviour of a strict ">" scan.
    return max(
        (
            plugin
            for plugin in plugins
            if capability in plugin.plugin_capabilities
        ),
        key=attrgetter("priority"),
        default=None,
    )


def _get_interface_info_from_plugins(plugins):
    """
    Merge the interfaces reported by every interface-capable plugin.

    Each interface dict is tagged with its plugin's priority while merging.
    When two plugins report the same interface name, the entry from the
    plugin with the strictly higher priority becomes the kept entry, and
    ``merge_dict()`` is called with the kept entry first and the other entry
    second. On equal priority, the entry seen first is kept. The priority tag
    is removed again before returning.

    Returns:
        list: the merged interface dicts, sorted by interface name.
    """
    priority_tag = "_plugin_priority"
    merged = {}
    iface_plugins = (
        plugin
        for plugin in plugins
        if NmstatePlugin.PLUGIN_CAPABILITY_IFACE in plugin.plugin_capabilities
    )
    for plugin in iface_plugins:
        for iface in plugin.get_interfaces():
            iface[priority_tag] = plugin.priority
            name = iface[Interface.NAME]
            known = merged.get(name)
            if known is None:
                merged[name] = iface
            elif plugin.priority > known[priority_tag]:
                merge_dict(iface, known)
                merged[name] = iface
            else:
                merge_dict(known, iface)

    # Drop the internal bookkeeping before handing the data out.
    for iface in merged.values():
        del iface[priority_tag]

    return sorted(merged.values(), key=itemgetter(Interface.NAME))


def create_checkpoints(plugins, timeout):
    """
    Ask every plugin to create a checkpoint and return them as one string.

    Plugins whose ``create_checkpoint(timeout)`` returns a falsy value are
    left out. The result has the format::

        plugin.name|<checkpoint_path>|plugin.name|<checkpoint_path>|...

    An empty string is returned when no plugin created a checkpoint.
    """
    created = (
        (plugin, plugin.create_checkpoint(timeout)) for plugin in plugins
    )
    return "|".join(
        f"{plugin.name}|{checkpoint}"
        for plugin, checkpoint in created
        if checkpoint
    )


def destroy_checkpoints(plugins, checkpoints):
    """
    Destroy the checkpoints encoded in ``checkpoints`` on matching plugins.

    ``checkpoints`` uses the format produced by create_checkpoints(). When it
    is empty, every plugin is called with a ``None`` checkpoint.
    """
    checkpoint_index = _parse_checkpoints(checkpoints)
    _checkpoint_action(plugins, checkpoint_index, "destroy")


def rollback_checkpoints(plugins, checkpoints):
    """
    Roll back to the checkpoints encoded in ``checkpoints`` on matching
    plugins.

    ``checkpoints`` uses the format produced by create_checkpoints(). When it
    is empty, every plugin is called with a ``None`` checkpoint.
    """
    checkpoint_index = _parse_checkpoints(checkpoints)
    _checkpoint_action(plugins, checkpoint_index, "rollback")


def _checkpoint_action(plugins, checkpoint_index, action):
    """
    Run the checkpoint ``action`` on each relevant plugin.

    ``checkpoint_index`` maps plugin name to checkpoint. When it is falsy,
    every plugin is called with ``None``. Otherwise only the plugins named in
    it are called, each with its own checkpoint. An ``action`` of "destroy"
    calls ``destroy_checkpoint()``; any other value calls
    ``rollback_checkpoint()``.

    Every plugin is attempted even if an earlier one fails. Afterwards a
    single failure is re-raised unchanged, and several failures are reported
    together as one NmstateError.
    """
    failures = []
    for plugin in plugins:
        if checkpoint_index:
            if plugin.name not in checkpoint_index:
                continue
            checkpoint = checkpoint_index[plugin.name]
        else:
            checkpoint = None
        try:
            handler = (
                plugin.destroy_checkpoint
                if action == "destroy"
                else plugin.rollback_checkpoint
            )
            handler(checkpoint)
        except (Exception, KeyboardInterrupt) as failure:
            failures.append(failure)

    if not failures:
        return
    if len(failures) == 1:
        raise failures[0]
    raise NmstateError(
        "Got multiple exception during checkpoint " f"{action}: {failures}"
    )


def _parse_checkpoints(checkpoints):
    """
    Return a dict mapping plugin name to checkpoint
    """
    if not checkpoints:
        return None
    parsed = checkpoints.split("|")
    if len(parsed) % 2:
        raise NmstateValueError("Invalid format of checkpoint")
    checkpoint_index = {}
    for plugin_name, checkpoint in zip(parsed[0::2], parsed[1::2]):
        checkpoint_index[plugin_name] = checkpoint