#!/usr/bin/env python
#
# (C) 2014 by Ana Rey Botello <anarey@gmail.com>
#
# Based on iptables-test.py:
# (C) 2012 by Pablo Neira Ayuso <pablo@netfilter.org>
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# Thanks to the Outreach Program for Women (OPW) for sponsoring this test
# infrastructure.

from __future__ import print_function
import sys
import os
import argparse
import signal
import json
import traceback
import tempfile

# Absolute path of the directory containing this script.
TESTS_PATH = os.path.dirname(os.path.abspath(__file__))
# Put ../../py/ (relative to this script) first on the module search path so
# the "from nftables import Nftables" below resolves there before any other
# copy (presumably the in-tree Python bindings — confirm).
sys.path.insert(0, os.path.join(TESTS_PATH, '../../py/'))
# Fixed timezone for the whole run. NOTE(review): presumably so that
# time-dependent listings are reproducible — confirm.
os.environ['TZ'] = 'UTC-2'

from nftables import Nftables

# Test subdirectories. NOTE(review): the names match nftables family
# keywords; presumably one directory per family — confirm against main.
TESTS_DIRECTORY = ["any", "arp", "bridge", "inet", "ip", "ip6"]
# Path of the test log file.
LOGFILE = "/tmp/nftables-test.log"
# Log file object; execute_cmd() writes every command and its stderr here.
log_file = None
# Tables created for the test file being processed (see table_create()).
table_list = []
# Declared chains, looked up by name through chain_get_by_name().
chain_list = []
# Named sets: set name -> set of element strings added to it
# (filled by set_add_elements(), consumed by set_delete()).
all_set = dict()
# Objects that obj_delete() removes from each table on cleanup.
obj_list = []
# Set to 1 by signal_handler() when a signal is delivered.
signal_received = 0


class Colors:
    """ANSI escape sequences used to color diagnostic messages.

    Every sequence collapses to the empty string when stdout is not a
    terminal, so redirected output contains no escape codes.
    """
    _enabled = sys.stdout.isatty()
    HEADER = '\033[95m' if _enabled else ''
    GREEN = '\033[92m' if _enabled else ''
    YELLOW = '\033[93m' if _enabled else ''
    RED = '\033[91m' if _enabled else ''
    ENDC = '\033[0m' if _enabled else ''
    # Helper flag only; not part of the class interface.
    del _enabled


class Chain:
    """Class that represents a chain declared by a test file.

    Attributes:
        name: chain name.
        config: text placed inside "{ ...; }" when the chain is created
                (may be empty, in which case no braces are emitted).
        lineno: 0-based line index in the test file, used in messages.
    """

    def __init__(self, name, config, lineno):
        self.name = name
        self.config = config
        self.lineno = lineno

    def __eq__(self, other):
        # Comparing with an unrelated object (None, a string, ...) must not
        # raise AttributeError on other.__dict__; let Python fall back.
        if not isinstance(other, Chain):
            return NotImplemented
        return self.__dict__ == other.__dict__

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__.
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result

    def __str__(self):
        return "%s" % self.name


class Table:
    """Class that represents a table declared by a test file.

    Attributes:
        family: nftables family keyword of the table.
        name: table name.
        chains: names of the chains to create in this table.
    """

    def __init__(self, family, name, chains):
        self.family = family
        self.name = name
        self.chains = chains

    def __eq__(self, other):
        # Unrelated objects compare unequal instead of raising
        # AttributeError on other.__dict__.
        if not isinstance(other, Table):
            return NotImplemented
        return self.__dict__ == other.__dict__

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__.
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result

    def __str__(self):
        # "family name", the form nft commands take after "list table" etc.
        return "%s %s" % (self.family, self.name)


class Set:
    """Class that represents a named set declared by a test file.

    Attributes:
        family, table: where the set lives; set_add() fills them in for each
                       table it adds the set to.
        name: set name.
        type: set key type, emitted after "type" in the "add set" command.
        timeout: text inserted verbatim after "type <type>;" (presumably an
                 optional " timeout ...;" clause or "" — confirm with caller).
        flags: flag list emitted as "flags <flags>; " unless empty.
    """

    def __init__(self, family, table, name, type, timeout, flags):
        self.family = family
        self.table = table
        self.name = name
        self.type = type
        self.timeout = timeout
        self.flags = flags

    def __eq__(self, other):
        # Unrelated objects compare unequal instead of raising
        # AttributeError on other.__dict__.
        if not isinstance(other, Set):
            return NotImplemented
        return self.__dict__ == other.__dict__

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__.
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result


class Obj:
    """Class that represents an object declared by a test file.

    Attributes:
        table, family: where the object lives; obj_add() fills them in for
                       each table it adds the object to.
        name: object name.
        type: object keyword used in "add/list/delete <type> ..." commands.
        spcf: specification text appended verbatim to the "add" command.
    """

    def __init__(self, table, family, name, type, spcf):
        self.table = table
        self.family = family
        self.name = name
        self.type = type
        self.spcf = spcf

    def __eq__(self, other):
        # Unrelated objects compare unequal instead of raising
        # AttributeError on other.__dict__.
        if not isinstance(other, Obj):
            return NotImplemented
        return self.__dict__ == other.__dict__

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__.
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result


def print_msg(reason, errstr, filename=None, lineno=None, color=None):
    '''
    Prints a message with nice colors, indicating file and line number.

    :param reason: message text
    :param errstr: severity tag such as "ERROR:", printed in @color
    :param filename: file the message refers to; if empty, no location is
                     printed
    :param lineno: 0-based line index, printed 1-based. Index 0 (the first
                   line) is a valid location and is printed too.
    :param color: escape sequence used for @errstr; None means uncolored
    '''
    color_errstr = "%s%s%s" % (color or "", errstr, Colors.ENDC)
    # isinstance() rather than truthiness: lineno 0 is a real line, while
    # callers pass None or "" when there is no line to report.
    if filename and isinstance(lineno, int):
        sys.stderr.write("%s: %s line %d: %s\n" %
                         (filename, color_errstr, lineno + 1, reason))
    else:
        sys.stderr.write("%s %s\n" % (color_errstr, reason))
    sys.stderr.flush() # So that the message stay in the right place.


def print_error(reason, filename=None, lineno=None):
    '''Writes @reason to stderr tagged with a red "ERROR:".'''
    print_msg(reason, "ERROR:", filename=filename, lineno=lineno,
              color=Colors.RED)


def print_warning(reason, filename=None, lineno=None):
    '''Writes @reason to stderr tagged with a yellow "WARNING:".'''
    print_msg(reason, "WARNING:", filename=filename, lineno=lineno,
              color=Colors.YELLOW)

def print_info(reason, filename=None, lineno=None):
    '''Writes @reason to stderr tagged with a green "INFO:".'''
    print_msg(reason, "INFO:", filename=filename, lineno=lineno,
              color=Colors.GREEN)

def color_differences(rule, other, color):
    '''
    Returns @rule with the part that differs from @other wrapped in @color.

    The longest common prefix and the longest common suffix (not
    overlapping the prefix) are left uncolored; whatever remains of @rule
    in between is highlighted. If @rule is entirely shared with @other
    (e.g. it is a prefix of @other), it is returned unchanged.
    '''
    shortest = min(len(rule), len(other))

    # Length of the common prefix.
    prefix = 0
    while prefix < shortest and rule[prefix] == other[prefix]:
        prefix += 1

    # Length of the common suffix, limited so it cannot overlap the prefix
    # in either string.
    limit = shortest - prefix
    suffix = 0
    while suffix < limit and rule[-1 - suffix] == other[-1 - suffix]:
        suffix += 1

    tail_start = len(rule) - suffix
    middle = rule[prefix:tail_start]
    if middle:
        middle = color + middle + Colors.ENDC
    return rule[:prefix] + middle + rule[tail_start:]

def print_differences_warning(filename, lineno, rule1, rule2, cmd):
    '''
    Warns that @rule1 and @rule2 differ for @cmd. In each rule, the part
    that does not match the other is highlighted in yellow.
    '''
    first, second = [color_differences(mine, theirs, Colors.YELLOW)
                     for mine, theirs in ((rule1, rule2), (rule2, rule1))]
    print_warning("'%s': '%s' mismatches '%s'" % (cmd, first, second),
                  filename, lineno)


def print_differences_error(filename, lineno, cmd):
    '''Reports that the listing produced after @cmd could not be parsed.'''
    print_error("'%s': Listing is broken." % cmd, filename, lineno)


def table_exist(table, filename, lineno):
    '''
    Returns True when "list table @table" succeeds (exit code 0), i.e. the
    table exists; False otherwise.
    '''
    return execute_cmd("list table %s" % table, filename, lineno) == 0


def table_flush(table, filename, lineno):
    '''
    Runs "flush table @table" (the result code is not checked) and returns
    the command string that was executed.
    '''
    flush_cmd = "flush table %s" % (table,)
    execute_cmd(flush_cmd, filename, lineno)
    return flush_cmd


def table_create(table, filename, lineno):
    '''
    Creates @table, checks that it is listed, then creates its chains.

    The table is appended to table_list before the "add table" command and
    removed again if creation or the follow-up check fails. Chain names
    that were never declared are reported but do not stop the remaining
    chains from being created; chain_create() results are not checked.

    Returns 0 on success, -1 if the table could not be created.
    '''
    def fail(reason):
        print_error(reason, filename, lineno)
        return -1

    # Refuse to reuse a table left over from an earlier run.
    if table_exist(table, filename, lineno):
        return fail("Table %s already exists" % table)

    table_list.append(table)

    add_cmd = "add table %s" % table
    if execute_cmd(add_cmd, filename, lineno) != 0:
        table_list.remove(table)
        return fail("Cannot " + add_cmd)

    if not table_exist(table, filename, lineno):
        table_list.remove(table)
        return fail("I have just added the table %s "
                    "but it does not exist. Giving up!" % table)

    for chain_name in table.chains:
        chain = chain_get_by_name(chain_name)
        if chain is not None:
            chain_create(chain, table, filename)
            continue
        print_error("The chain %s requested by table %s "
                    "does not exist." % (chain_name, table), filename, lineno)

    return 0


def table_delete(table, filename=None, lineno=None):
    '''
    Deletes @table after checking that it exists, then checks that it is
    gone. Every failure is reported with print_error().

    Returns 0 on success, -1 on failure.
    '''
    def fail(reason):
        print_error(reason, filename, lineno)
        return -1

    if not table_exist(table, filename, lineno):
        return fail("Table %s does not exist but I added it before." % table)

    delete_cmd = "delete table %s" % table
    if execute_cmd(delete_cmd, filename, lineno) != 0:
        return fail("%s: I cannot delete table %s. Giving up!"
                    % (delete_cmd, table))

    if table_exist(table, filename, lineno):
        return fail("I have just deleted the table %s "
                    "but it still exists." % table)

    return 0


def chain_exist(chain, table, filename):
    '''
    Returns True when "list chain @table @chain" succeeds; command output is
    attributed to the chain's own declaration line.
    '''
    list_cmd = "list chain %s %s" % (table, chain)
    return execute_cmd(list_cmd, filename, chain.lineno) == 0


def chain_create(chain, table, filename):
    '''
    Adds @chain to @table and checks that it is listed afterwards.

    The chain's config, if any, is passed as " { <config>; }". All messages
    refer to the chain's declaration line.

    Returns 0 on success, -1 if the chain already exists or cannot be added.
    '''
    where = chain.lineno

    if chain_exist(chain, table, filename):
        print_error("This chain '%s' exists in %s. I cannot create "
                    "two chains with same name." % (chain, table),
                    filename, where)
        return -1

    add_cmd = "add chain %s %s" % (table, chain)
    if chain.config:
        add_cmd = "%s { %s; }" % (add_cmd, chain.config)

    if execute_cmd(add_cmd, filename, where) != 0:
        print_error("I cannot create the chain '%s'" % chain, filename, where)
        return -1

    if chain_exist(chain, table, filename):
        return 0

    print_error("I have added the chain '%s' "
                "but it does not exist in %s" % (chain, table),
                filename, where)
    return -1


def chain_delete(chain, table, filename=None, lineno=None):
    '''
    Flushes and then deletes @chain from @table, and checks that it is gone.

    Returns 0 on success, -1 on the first failing step.
    '''
    if not chain_exist(chain, table, filename):
        print_error("The chain %s does not exist in %s. "
                    "I cannot delete it." % (chain, table), filename, lineno)
        return -1

    # A chain must be empty before it can be deleted, hence flush first.
    for action in ("flush", "delete"):
        step_cmd = "%s chain %s %s" % (action, table, chain)
        if execute_cmd(step_cmd, filename, lineno) != 0:
            print_error("I cannot " + step_cmd, filename, lineno)
            return -1

    if not chain_exist(chain, table, filename):
        return 0

    print_error("The chain %s exists in %s. "
                "I cannot delete this chain" % (chain, table), filename, lineno)
    return -1


def chain_get_by_name(name):
    '''Returns the first chain in chain_list whose name is @name, or None.'''
    return next((candidate for candidate in chain_list
                 if candidate.name == name), None)


def set_add(s, test_result, filename, lineno):
    '''
    Adds the named set @s to every table in table_list.

    @s.table and @s.family are overwritten with each table in turn. The
    result of each "add set" is checked against @test_result ("ok" or
    "fail"), and afterwards the set must be listable.

    Returns 0 on success, -1 on the first problem.
    '''
    def fail(reason):
        print_error(reason, filename, lineno)
        return -1

    if not table_list:
        return fail("Missing table to add rule")

    # s.flags does not change inside the loop, so format it once.
    flags = "flags %s; " % s.flags if s.flags != "" else s.flags

    for table in table_list:
        s.table, s.family = table.name, table.family
        if _set_exist(s, filename, lineno):
            return fail("Set %s already exists in %s" % (s.name, table))

        add_cmd = "add set %s %s { type %s;%s %s}" % (table, s.name, s.type,
                                                     s.timeout, flags)
        succeeded = execute_cmd(add_cmd, filename, lineno) == 0

        if (succeeded and test_result == "fail") or \
                (not succeeded and test_result == "ok"):
            return fail("%s: I cannot add the set %s" % (add_cmd, s.name))

        if not _set_exist(s, filename, lineno):
            return fail("I have just added the set %s to "
                        "the table %s but it does not exist" % (s.name, table))

    return 0


def set_add_elements(set_element, set_name, state, filename, lineno):
    '''
    Adds the elements @set_element to set @set_name in every table of
    table_list, expecting the command to succeed when @state is "ok" and to
    fail when it is "fail".

    After each successful "ok" addition the elements are also recorded in
    all_set[set_name], which set_delete() uses to remove them later.

    Returns 0 on success, -1 on the first unexpected result.
    '''
    if not table_list:
        reason = "Missing table to add rules"
        print_error(reason, filename, lineno)
        return -1

    for table in table_list:
        # Check if set exists. set_exist() is always queried; the missing
        # set only counts as an error when the addition should succeed.
        if (not set_exist(set_name, table, filename, lineno) or
                    set_name not in all_set) and state == "ok":
            reason = "I cannot add an element to the set %s " \
                     "since it does not exist." % set_name
            print_error(reason, filename, lineno)
            return -1

        element = ", ".join(set_element)
        cmd = "add element %s %s { %s }" % (table, set_name, element)
        ret = execute_cmd(cmd, filename, lineno)

        if (state == "fail" and ret == 0) or (state == "ok" and ret != 0):
            # Report the expectation that was actually violated (the old
            # message always said "should have failed").
            if state == "fail":
                test_state = "This rule should have failed."
            else:
                test_state = "This rule should not have failed."
            reason = cmd + ": " + test_state
            print_error(reason, filename, lineno)
            return -1

        # Add element into all_set.
        if ret == 0 and state == "ok":
            for e in set_element:
                all_set[set_name].add(e)

    return 0


def set_delete_elements(set_element, set_name, table, filename=None,
                        lineno=None):
    '''
    Deletes each element of @set_element from set @set_name in @table, one
    "delete element" command per element.

    Returns 0 on success, -1 at the first element that cannot be deleted.
    '''
    for item in set_element:
        delete_cmd = "delete element %s %s { %s }" % (table, set_name, item)
        if execute_cmd(delete_cmd, filename, lineno) == 0:
            continue
        print_error("I cannot delete element %s "
                    "from the set %s" % (item, set_name), filename, lineno)
        return -1
    return 0


def set_delete(table, filename=None, lineno=None):
    '''
    Deletes every set registered in all_set from @table. Each set's
    recorded elements are deleted first.

    Returns 0 on success, -1 as soon as a set is missing or still exists
    after the "delete set" command.
    '''
    for name in all_set:
        if not set_exist(name, table, filename, lineno):
            print_error("The set %s does not exist, "
                        "I cannot delete it" % name, filename, lineno)
            return -1

        # set_delete_elements() reports its own failures; they do not stop
        # the set itself from being deleted.
        set_delete_elements(all_set[name], name, table, filename, lineno)

        rc = execute_cmd("delete set %s %s" % (table, name), filename, lineno)
        if rc != 0 or set_exist(name, table, filename, lineno):
            print_error("Cannot remove the set " + name, filename, lineno)
            return -1

    return 0


def set_exist(set_name, table, filename, lineno):
    '''Returns True when "list set @table @set_name" succeeds.'''
    list_cmd = "list set %s %s" % (table, set_name)
    return execute_cmd(list_cmd, filename, lineno) == 0


def _set_exist(s, filename, lineno):
    '''
    Returns True when the set object @s can be listed, using its own
    family, table and name ("list set <family> <table> <name>").
    '''
    list_cmd = "list set %s %s %s" % (s.family, s.table, s.name)
    return execute_cmd(list_cmd, filename, lineno) == 0


def set_check_element(rule1, rule2):
    '''
    Check if element exists in anonymous sets.

    Returns True when @rule1 and @rule2 are identical except for the order
    of the elements inside their anonymous sets ("{ a, b }"); spaces inside
    the braces are ignored. Every anonymous set in the rules is compared
    this way, not just the first one. Returns False if either rule has no
    "{ ... }" pair or the text outside the sets differs.
    '''
    pos1 = rule1.find("{")
    pos2 = rule2.find("{")

    if rule1[:pos1] != rule2[:pos2]:
        return False

    end1 = rule1.find("}")
    end2 = rule2.find("}")

    if -1 in (pos1, pos2, end1, end2):
        return False

    elems1 = sorted(rule1[pos1 + 1:end1].replace(" ", "").split(","))
    elems2 = sorted(rule2[pos2 + 1:end2].replace(" ", "").split(","))
    if elems1 != elems2:
        return False

    rest1 = rule1[end1:]
    rest2 = rule2[end2:]
    if rest1 == rest2:
        return True

    # The remainders may still be equivalent if they contain further
    # anonymous sets whose elements are merely reordered. Drop the closing
    # brace just matched and compare the rest recursively.
    if "{" in rest1 and "{" in rest2:
        return set_check_element(rest1[1:], rest2[1:])

    return False


def obj_add(o, test_result, filename, lineno):
    '''
    Adds the object @o to every table in table_list.

    @o.table and @o.family are overwritten with each table in turn. For each
    table the "add" result must match @test_result ("ok" or "fail"), and
    afterwards the object must exist exactly when the addition was expected
    to succeed.

    Earlier versions returned after the first table, so only that table got
    the object, although obj_delete() later removes it from every table.

    Returns 0 when every table behaved as expected, -1 otherwise.
    '''
    if not table_list:
        reason = "Missing table to add rule"
        print_error(reason, filename, lineno)
        return -1

    for table in table_list:
        o.table = table.name
        o.family = table.family
        obj_handle = o.type + " " + o.name
        if _obj_exist(o, filename, lineno):
            reason = "The %s already exists in %s" % (obj_handle, table)
            print_error(reason, filename, lineno)
            return -1

        cmd = "add %s %s %s %s" % (o.type, table, o.name, o.spcf)
        ret = execute_cmd(cmd, filename, lineno)

        if (ret == 0 and test_result == "fail") or \
                (ret != 0 and test_result == "ok"):
            reason = "%s: I cannot add the %s" % (cmd, obj_handle)
            print_error(reason, filename, lineno)
            return -1

        exist = _obj_exist(o, filename, lineno)

        if exist and test_result != "ok":
            reason = "I added the %s to the table %s " \
                     "but it should have failed" % (obj_handle, table)
            print_error(reason, filename, lineno)
            return -1

        if not exist and test_result != "fail":
            reason = "I have just added the %s to " \
                     "the table %s but it does not exist" % (obj_handle, table)
            print_error(reason, filename, lineno)
            return -1

        # Outcome matched the expectation for this table; go on.

    return 0

def obj_delete(table, filename=None, lineno=None):
    '''
    Deletes every object in obj_list from @table and checks each one is gone.

    Returns 0 on success, -1 at the first object that is missing or
    survives its "delete" command.
    '''
    for obj in obj_list:
        handle = obj.type + " " + obj.name

        if not obj_exist(obj, table, filename, lineno):
            print_error("The %s does not exist, I cannot delete it" % handle,
                        filename, lineno)
            return -1

        rc = execute_cmd("delete %s %s %s" % (obj.type, table, obj.name),
                         filename, lineno)
        if rc != 0 or obj_exist(obj, table, filename, lineno):
            print_error("Cannot remove the " + handle, filename, lineno)
            return -1

    return 0


def obj_exist(o, table, filename, lineno):
    '''Returns True when "list <type> @table <name>" succeeds for object @o.'''
    list_cmd = "list %s %s %s" % (o.type, table, o.name)
    return execute_cmd(list_cmd, filename, lineno) == 0


def _obj_exist(o, filename, lineno):
    '''
    Returns True when object @o can be listed, using its own family and
    table ("list <type> <family> <table> <name>").
    '''
    list_cmd = "list %s %s %s %s" % (o.type, o.family, o.table, o.name)
    return execute_cmd(list_cmd, filename, lineno) == 0


def output_clean(pre_output, chain):
    '''
    Extracts the rule text of @chain from the ruleset listing @pre_output.

    Inside the chain's braces, the listing is split on ';'. The rule is
    taken to be the third field, after the two fields that precede it in
    the base chain header.

    Anonymous sets bring their own braces. When one is found, the text
    after its closing brace is placed behind a synthetic "NAME {;;" header
    and processed recursively, so several sets can be handled.

    Returns the rule with tabs and newlines removed. Returns "" if the chain
    name is not found or the listing does not have the expected shape.
    Earlier versions raised IndexError or recursed forever on malformed
    listings.
    '''
    pos_chain = pre_output.find(chain.name)
    if pos_chain == -1:
        return ""
    output_intermediate = pre_output[pos_chain:]
    brace_start = output_intermediate.find("{")
    brace_end = output_intermediate.find("}")
    if brace_start == -1 or brace_end == -1:
        return ""
    pre_rule = output_intermediate[brace_start:brace_end]

    if "{" in pre_rule[1:]:  # this rule has an anonymous set
        fields = pre_rule[1:].replace("\t", "").replace("\n", "").split(";")
        if len(fields) < 3:
            return ""
        set_text = fields[2].strip() + "}"
        remainder = output_clean(chain.name + " {;;" +
                                 output_intermediate[brace_end + 1:], chain)
        if not remainder:
            return set_text
        return set_text + " " + remainder

    fields = pre_rule.split(";")
    if len(fields) < 3:
        return ""
    return fields[2].replace("\t", "").replace("\n", "").strip()


def payload_check_elems_to_set(elems):
    '''
    Splits @elems on "[end]" markers and returns the stripped pieces as a
    set. At the first duplicate piece an error is printed and the pieces
    collected so far are returned.
    '''
    seen = set()
    pieces = [chunk.strip() for chunk in elems.split('[end]')]
    for idx, piece in enumerate(pieces):
        if piece in seen:
            print_error("duplicate", piece, idx)
            break
        seen.add(piece)
    return seen


def payload_check_set_elems(want, got):
    '''
    Compares two set-element payload lines while ignoring element order.

    Returns 0 unless both lines contain "element" and "[end]"; otherwise
    returns whether their "[end]"-separated pieces form the same set.
    '''
    for payload_line in (want, got):
        if 'element' not in payload_line or '[end]' not in payload_line:
            return 0
    return payload_check_elems_to_set(want) == payload_check_elems_to_set(got)


def payload_check(payload_buffer, file, cmd):
    '''
    Compares the expected payload lines in @payload_buffer with the lines
    read from the start of @file.

    A mismatch is tolerated when neither line contains '[' or neither
    contains ']', or when the lines are equal set-element dumps. Any other
    mismatch prints a warning and returns 0.

    Returns False for an empty @payload_buffer, otherwise whether at least
    one line matched exactly.
    '''
    file.seek(0, 0)

    if not payload_buffer:
        return False

    matched = 0
    for idx, expected in enumerate(payload_buffer):
        actual = file.readline()

        if expected == actual:
            matched += 1
            continue

        unbracketed = ('[' not in expected and '[' not in actual) or \
                      (']' not in expected and ']' not in actual)
        if unbracketed or payload_check_set_elems(expected, actual):
            continue

        print_differences_warning(file.name, idx, expected.strip(),
                                  actual.strip(), cmd)
        return 0

    return matched > 0


def json_dump_normalize(json_string, human_readable = False):
    '''
    Re-serializes @json_string with sorted keys so that equal documents
    compare equal as strings. With @human_readable the output is indented
    by 4 spaces and has no trailing spaces after commas.
    '''
    dump_options = {"sort_keys": True}
    if human_readable:
        dump_options.update(indent=4, separators=(',', ': '))
    return json.dumps(json.loads(json_string), **dump_options)

def json_validate(json_string):
    '''
    Runs nftables.json_validate() on the parsed @json_string.

    Validation errors are printed together with their traceback and are
    not raised; a JSON syntax error from json.loads() still propagates.
    '''
    document = json.loads(json_string)
    try:
        nftables.json_validate(document)
    except Exception:
        for message in ("schema validation failed for input '%s'" % json_string,
                        traceback.format_exc()):
            print_error(message)

def _append_payload_got(got_path, rule_text, payload_log):
    '''
    Appends the netlink payload captured in @payload_log to the file
    @got_path, preceded by a "# <rule_text>" header line.

    Returns @got_path (the name of the file that was written).
    '''
    with open(got_path, 'a') as gotf:
        payload_log.seek(0, 0)
        gotf.write("# %s\n" % rule_text)
        while True:
            line = payload_log.readline()
            if line == "":
                break
            gotf.write(line)
    return got_path


def _json_rule_expr(json_listing):
    '''
    Returns the "expr" list of the first rule in the JSON ruleset listing
    @json_listing, serialized with sorted keys.

    The rule's "handle" member is deleted first. Raises KeyError if the
    listing contains no rule.
    '''
    listing = json.loads(json_listing)
    for item in listing["nftables"]:
        if "rule" in item:
            del(item["rule"]["handle"])
            listing = item["rule"]
            break
    return json.dumps(listing["expr"], sort_keys = True)


def rule_add(rule, filename, lineno, force_all_family_option, filename_path):
    '''
    Adds a rule to every chain of every table in table_list and checks it.

    For each table/chain the table is flushed and the rule is added. Its
    return code must match the expected state rule[1] ("ok"/"fail"). For
    rules expected to succeed, the following checks run:
      - the netlink debug payload is compared with the expected payload from
        "<filename_path>.payload.<family>" (preferred) or
        "<filename_path>.payload"; mismatches are appended to
        "<filename_path>.payload.got";
      - the listed rule is compared with the expected output (rule[2] if
        present, otherwise rule[0]); on a mismatch the listed form is
        re-added and its payload checked as well;
      - if JSON testing is enabled, the JSON form of the rule is added and
        its payload and JSON listing are checked; generated data is
        appended to "<filename_path>.json*.got" files.

    :param rule: (rule text, expected state[, expected listing])
    :param filename: test file name, used in messages
    :param lineno: line in the test file, used in messages
    :param force_all_family_option: keep going after a failure instead of
                                    returning at the first one
    :param filename_path: path prefix of the expected-output files
    :returns: [ret, warnings, errors, unit tests executed]
    '''
    # TODO Check if a rule is added correctly.
    ret = warning = error = unit_tests = 0

    if not table_list or not chain_list:
        reason = "Missing table or chain to add rule."
        print_error(reason, filename, lineno)
        return [-1, warning, error, unit_tests]

    if rule[1].strip() == "ok":
        # The generic payload file gets its own variable. payload_log is
        # reused below for the per-chain netlink dumps, which used to break
        # the "no payload information at all" check after the first table.
        payload_expected = None
        try:
            payload_file = open("%s.payload" % filename_path)
            payload_expected = payload_find_expected(payload_file, rule[0])
        except Exception:
            payload_file = None

        if enable_json_option:
            try:
                json_log = open("%s.json" % filename_path)
                json_input = json_find_expected(json_log, rule[0])
            except Exception:
                json_input = None

            if not json_input:
                print_error("did not find JSON equivalent for rule '%s'"
                            % rule[0])
            else:
                try:
                    json_input = json_dump_normalize(json_input)
                except ValueError:
                    reason = "Invalid JSON syntax in rule: %s" % json_input
                    print_error(reason)
                    return [-1, warning, error, unit_tests]

            try:
                json_log = open("%s.json.output" % filename_path)
                json_expected = json_find_expected(json_log, rule[0])
            except Exception:
                # will use json_input for comparison
                json_expected = None

            if json_expected:
                try:
                    json_expected = json_dump_normalize(json_expected)
                except ValueError:
                    reason = "Invalid JSON syntax in expected output: %s" % json_expected
                    print_error(reason)
                    return [-1, warning, error, unit_tests]

    for table in table_list:
        if rule[1].strip() == "ok":
            # A family-specific payload takes precedence over the generic one.
            table_payload_expected = None
            try:
                table_payload_file = open("%s.payload.%s" % (filename_path, table.family))
                table_payload_expected = payload_find_expected(table_payload_file, rule[0])
            except Exception:
                if not payload_file:
                    print_error("did not find any payload information",
                                filename_path)
                elif not payload_expected:
                    print_error("did not find payload information for "
                                "rule '%s'" % rule[0], payload_file.name, 1)
            if not table_payload_expected:
                table_payload_expected = payload_expected

        for table_chain in table.chains:
            chain = chain_get_by_name(table_chain)
            unit_tests += 1
            table_flush(table, filename, lineno)

            payload_log = tempfile.TemporaryFile(mode="w+")

            # Add rule and check return code
            cmd = "add rule %s %s %s" % (table, chain, rule[0])
            ret = execute_cmd(cmd, filename, lineno, payload_log, debug="netlink")

            state = rule[1].rstrip()
            if (ret in [0,134] and state == "fail") or (ret != 0 and state == "ok"):
                if state == "fail":
                    test_state = "This rule should have failed."
                else:
                    test_state = "This rule should not have failed."
                reason = cmd + ": " + test_state
                print_error(reason, filename, lineno)
                ret = -1
                error += 1
                if not force_all_family_option:
                    return [ret, warning, error, unit_tests]

            if state == "fail" and ret != 0:
                ret = 0
                continue

            if ret != 0:
                continue

            # Check for matching payload
            if state == "ok" and not payload_check(table_payload_expected,
                                                   payload_log, cmd):
                error += 1
                got_name = _append_payload_got("%s.payload.got" % filename_path,
                                               rule[0], payload_log)
                print_warning("Wrote payload for rule %s" % rule[0],
                              got_name, 1)

            # Check for matching ruleset listing
            numeric_proto_old = nftables.set_numeric_proto_output(True)
            stateless_old = nftables.set_stateless_output(True)
            list_cmd = 'list table %s' % table
            rc, pre_output, err = nftables.cmd(list_cmd)
            nftables.set_numeric_proto_output(numeric_proto_old)
            nftables.set_stateless_output(stateless_old)

            output = pre_output.split(";")
            if len(output) < 2:
                reason = cmd + ": Listing is broken."
                print_error(reason, filename, lineno)
                ret = -1
                error += 1
                if not force_all_family_option:
                    return [ret, warning, error, unit_tests]
                continue

            rule_output = output_clean(pre_output, chain)
            retest_output = False
            if len(rule) == 3:
                teoric_exit = rule[2]
                retest_output = True
            else:
                teoric_exit = rule[0]

            if rule_output.rstrip() != teoric_exit.rstrip():
                if rule[0].find("{") != -1:  # anonymous sets
                    if not set_check_element(teoric_exit.rstrip(),
                                         rule_output.rstrip()):
                        warning += 1
                        retest_output = True
                        print_differences_warning(filename, lineno,
                                                  teoric_exit.rstrip(),
                                                  rule_output, cmd)
                        if not force_all_family_option:
                            return [ret, warning, error, unit_tests]
                else:
                    if len(rule_output) <= 0:
                        error += 1
                        print_differences_error(filename, lineno, cmd)
                        if not force_all_family_option:
                            return [ret, warning, error, unit_tests]

                    warning += 1
                    retest_output = True
                    print_differences_warning(filename, lineno,
                                              teoric_exit.rstrip(),
                                              rule_output, cmd)

                    if not force_all_family_option:
                        return [ret, warning, error, unit_tests]

            if retest_output:
                table_flush(table, filename, lineno)

                # Capture the replayed rule's payload in a fresh file.
                # Appending to the old one made payload_check(), which reads
                # from the start, compare the first "add rule" again.
                payload_log = tempfile.TemporaryFile(mode="w+")

                # Add rule and check return code
                cmd = "add rule %s %s %s" % (table, chain, rule_output.rstrip())
                ret = execute_cmd(cmd, filename, lineno, payload_log, debug="netlink")

                if ret != 0:
                    test_state = "Replaying rule failed."
                    reason = cmd + ": " + test_state
                    print_warning(reason, filename, lineno)
                    ret = -1
                    error += 1
                    if not force_all_family_option:
                        return [ret, warning, error, unit_tests]
                # Check for matching payload
                elif not payload_check(table_payload_expected,
                                       payload_log, cmd):
                    error += 1

            if not enable_json_option:
                continue

            # Generate JSON equivalent for rule if not found
            if not json_input:
                json_old = nftables.set_json_output(True)
                rc, json_output, err = nftables.cmd(list_cmd)
                nftables.set_json_output(json_old)

                json_input = _json_rule_expr(json_output)

                got_name = "%s.json.got" % filename_path
                with open(got_name, 'a') as gotf:
                    jdump = json_dump_normalize(json_input, True)
                    gotf.write("# %s\n%s\n\n" % (rule[0], jdump))
                print_warning("Wrote JSON equivalent for rule %s" % rule[0],
                              got_name, 1)

            table_flush(table, filename, lineno)
            payload_log = tempfile.TemporaryFile(mode="w+")

            # Add rule in JSON format
            cmd = json.dumps({ "nftables": [{ "add": { "rule": {
                    "family": table.family,
                    "table": table.name,
                    "chain": chain.name,
                    "expr": json.loads(json_input),
            }}}]})

            if enable_json_schema:
                json_validate(cmd)

            json_old = nftables.set_json_output(True)
            ret = execute_cmd(cmd, filename, lineno, payload_log, debug="netlink")
            nftables.set_json_output(json_old)

            if ret != 0:
                reason = "Failed to add JSON equivalent rule"
                print_error(reason, filename, lineno)
                # Reported as an ERROR, so count it like one.
                error += 1
                continue

            # Check for matching payload
            if not payload_check(table_payload_expected, payload_log, cmd):
                error += 1
                got_name = _append_payload_got("%s.json.payload.got" % filename_path,
                                               rule[0], payload_log)
                print_warning("Wrote JSON payload for rule %s" % rule[0],
                              got_name, 1)

            # Check for matching ruleset listing
            numeric_proto_old = nftables.set_numeric_proto_output(True)
            stateless_old = nftables.set_stateless_output(True)
            json_old = nftables.set_json_output(True)
            rc, json_output, err = nftables.cmd(list_cmd)
            nftables.set_json_output(json_old)
            nftables.set_numeric_proto_output(numeric_proto_old)
            nftables.set_stateless_output(stateless_old)

            if enable_json_schema:
                json_validate(json_output)

            json_output = _json_rule_expr(json_output)

            if not json_expected and json_output != json_input:
                print_differences_warning(filename, lineno,
                                          json_input, json_output, cmd)
                error += 1
                got_name = "%s.json.output.got" % filename_path
                with open(got_name, 'a') as gotf:
                    jdump = json_dump_normalize(json_output, True)
                    gotf.write("# %s\n%s\n\n" % (rule[0], jdump))
                print_warning("Wrote JSON output for rule %s" % rule[0],
                              got_name, 1)
                # prevent further warnings and .got file updates
                json_expected = json_output
            elif json_expected and json_output != json_expected:
                print_differences_warning(filename, lineno,
                                          json_expected, json_output, cmd)
                error += 1

    return [ret, warning, error, unit_tests]


def cleanup_on_exit():
    '''
    Tear down everything created for the current test file.

    Called from run_test_file after a signal was received. For every table
    it deletes, in order: the table's chains, its sets (only if any set is
    registered in all_set), its objects (only if obj_list is non-empty) and
    finally the table itself.
    '''
    for tbl in table_list:
        for chain_name in tbl.chains:
            doomed_chain = chain_get_by_name(chain_name)
            chain_delete(doomed_chain, tbl, "", "")
        if all_set:
            set_delete(tbl)
        if obj_list:
            obj_delete(tbl)
        table_delete(tbl)


def signal_handler(signal, frame):
    '''
    SIGINT/SIGTERM handler.

    It only raises the module-level signal_received flag. The per-line loop
    in run_test_file polls that flag and performs the cleanup and exit
    itself, outside of signal context.
    '''
    global signal_received
    signal_received = 1


def execute_cmd(cmd, filename, lineno, stdout_log=False, debug=False):
    '''
    Executes a command through libnftables, logs its output and returns the
    command exit code.

    :param cmd: string with the command to be executed
    :param filename: name of the file tested (used for print_error purposes)
    :param lineno: line number being tested (used for print_error purposes)
    :param stdout_log: redirect stdout to this file instead of global log_file
    :param debug: temporarily set these debug flags
    :returns: the return code reported by nftables.cmd()
    '''
    print("command: {}".format(cmd), file=log_file)
    if debug_option:
        print(cmd)

    if debug:
        debug_old = nftables.get_debug()
        nftables.set_debug(debug)
    try:
        ret, out, err = nftables.cmd(cmd)
    finally:
        # Restore the previous debug flags even if the command raised, so a
        # failure here cannot leak debug output into later commands.
        if debug:
            nftables.set_debug(debug_old)

    if not stdout_log:
        stdout_log = log_file

    # Command output goes to stdout_log; error text always goes to log_file.
    stdout_log.write(out)
    stdout_log.flush()
    log_file.write(err)
    log_file.flush()

    return ret


def print_result(filename, tests, warning, error):
    '''
    Build the one-line per-file summary string.

    Note the output lists errors before warnings, although the parameters
    take warning before error.
    '''
    fields = (filename, tests, error, warning)
    return "%s: %s unit tests, %s error, %s warning" % fields


def print_result_all(filename, tests, warning, error, unit_tests):
    '''
    Build the per-file summary string used with --force-family.

    Same as print_result, but it also reports how many unit tests were
    actually executed (unit_tests) across all families.
    '''
    fields = (filename, tests, unit_tests, error, warning)
    template = ("%s: %s unit tests, %s total test executed, "
                "%s error, %s warning")
    return template % fields


def table_process(table_line, filename, lineno):
    '''
    Handle a "*" line of the form "<family>;<name>;<chain>,<chain>,...".

    It builds a Table from the first three ';'-separated fields and hands
    it to table_create, returning table_create's result.
    '''
    fields = table_line.split(";")
    family = fields[0]
    name = fields[1]
    chain_names = fields[2].split(",")

    return table_create(Table(family, name, chain_names), filename, lineno)


def chain_process(chain_line, lineno):
    '''
    Handle a ":" line of the form "<chain name>;<chain config>".

    It records a Chain in the global chain_list and always returns 0.
    '''
    fields = chain_line.split(";")
    new_chain = Chain(fields[0], fields[1], lineno)
    chain_list.append(new_chain)
    return 0


def set_process(set_line, filename, lineno):
    '''
    Handle a "!" line: declare a set and add it via set_add.

    The definition is tokenised as
    "<name> <kw> <type> [. <type> ...] [timeout <t>] [flags <flags>]",
    where <kw> is skipped (presumably the "type" keyword).

    :param set_line: two-item sequence; [0] is the definition above and [1]
                     is the text after the ';' separator, passed to set_add
    :param filename: test file name (for error reporting)
    :param lineno: line number in the test file (for error reporting)
    :returns: the value returned by set_add; on 0 the set is registered in
              all_set with an empty element set
    '''
    test_result = set_line[1]
    timeout = ""

    # split() rather than split(" "): repeated, leading or trailing blanks
    # would otherwise produce empty tokens that were reported as bad flags.
    tokens = set_line[0].split()
    set_name = tokens[0]
    set_type = tokens[2]
    set_flags = ""

    # Concatenated types, e.g. "ipv4_addr . inet_service".
    i = 3
    while len(tokens) > i and tokens[i] == ".":
        set_type += " . " + tokens[i+1]
        i += 2

    # ">=" (not "==") so that a timeout may be followed by flags; before,
    # "timeout <t> flags <f>" was rejected as a bad flag.
    if len(tokens) >= i+2 and tokens[i] == "timeout":
        timeout = "timeout " + tokens[i+1] + ";"
        i += 2

    if len(tokens) == i+2 and tokens[i] == "flags":
        set_flags = tokens[i+1]
    elif len(tokens) != i:
        print_error(set_name + " bad flag: " + tokens[i], filename, lineno)

    s = Set("", "", set_name, set_type, timeout, set_flags)

    ret = set_add(s, test_result, filename, lineno)
    if ret == 0:
        all_set[set_name] = set()

    return ret


def set_element_process(element_line, filename, lineno):
    '''
    Handle a "?" line: add the listed elements to a declared set.

    :param element_line: two-item sequence; [0] is
                         "<set name> <elem>,<elem>,..." and [1] is the text
                         after the ';' separator, passed on as rule_state
    :param filename: test file name (for error reporting)
    :param lineno: line number in the test file (for error reporting)
    :returns: the value returned by set_add_elements
    '''
    rule_state = element_line[1]
    definition = element_line[0]
    space = definition.find(" ")
    if space < 0:
        # No blank at all: treat the whole text as the set name. Without
        # this, find()'s -1 would silently chop the last character off the
        # name and pass it along as an element.
        space = len(definition)
    set_name = definition[:space]
    # The slice deliberately keeps the leading blank of the element list.
    set_element = definition[space:].split(",")

    return set_add_elements(set_element, set_name, rule_state, filename, lineno)


def obj_process(obj_line, filename, lineno):
    '''
    Handle a "%" line: declare a stateful object and add it via obj_add.

    The definition is tokenised as "<name> <kw> <type> [<spec> ...]", where
    <kw> is skipped. For "ct" the next token may be a subtype (helper,
    timeout, expectation), which becomes part of the object type.

    :param obj_line: two-item sequence; [0] is the definition and [1] is the
                     text after the separator, passed to obj_add
    :param filename: test file name (for error reporting)
    :param lineno: line number in the test file (for error reporting)
    :returns: the value returned by obj_add; on 0 the object is appended to
              obj_list
    '''
    test_result = obj_line[1]

    # Plain split(" ") is kept on purpose: the spec is re-joined with single
    # blanks below and must round-trip unchanged.
    tokens = obj_line[0].split(" ")
    obj_name = tokens[0]
    obj_type = tokens[2]
    obj_spcf = ""

    # Fold a two-word "ct <subtype>" type. The subtype token is blanked, not
    # removed, so the spec keeps its original leading blank. The length
    # check avoids an IndexError on a bare "ct" type.
    if obj_type == "ct" and len(tokens) > 3 and \
            tokens[3] in ("helper", "timeout", "expectation"):
        obj_type = "ct " + tokens[3]
        tokens[3] = ""

    if len(tokens) > 3:
        obj_spcf = " ".join(tokens[3:])

    o = Obj("", "", obj_name, obj_type, obj_spcf)

    ret = obj_add(o, test_result, filename, lineno)
    if ret == 0:
        obj_list.append(o)

    return ret


def payload_find_expected(payload_log, rule):
    '''
    Find the netlink payload that should be generated by given rule in
    payload_log.

    Records in the log start with a "# <rule>" header line and end with a
    blank line. The lines following the matching header are returned as a
    list, including the terminating blank line. The header itself is never
    included, and other "#" lines inside the record are kept. If no header
    matches, the result is an empty list. If the file ends before a blank
    line, the lines gathered so far are returned. In both of those cases
    the file is rewound to its start.

    :param payload_log: open file handle of the payload data
    :param rule: nft rule we are going to add
    '''
    wanted = rule.strip()
    collected = []
    in_record = False

    line = payload_log.readline()
    while line:
        if line[0] == "#" and line.strip()[2:] == wanted:
            in_record = True
        elif in_record:
            collected.append(line)
            if line.isspace():
                return collected
        line = payload_log.readline()

    payload_log.seek(0, 0)
    return collected


def json_find_expected(json_log, rule):
    '''
    Find the corresponding JSON for given rule.

    Records in the log start with a "# <rule>" header line and end with a
    blank line. Every line after the matching header is stripped of
    surrounding whitespace, and the pieces are concatenated into a single
    string. If no header matches, the result is "". If the file ends before
    a blank line, the text gathered so far is returned. In both of those
    cases the file is rewound to its start.

    :param json_log: open file handle of the json data
    :param rule: nft rule we are going to add
    '''
    wanted = rule.strip()
    pieces = []
    in_record = False

    line = json_log.readline()
    while line:
        if line[0] == "#" and line.strip()[2:] == wanted:
            in_record = True
        elif in_record:
            pieces.append(line.rstrip("\n").strip())
            if line.isspace():
                return "".join(pieces)
        line = json_log.readline()

    json_log.seek(0, 0)
    return "".join(pieces)


def run_test_file(filename, force_all_family_option, specific_file):
    '''
    Runs a test file

    Each line is dispatched on its first character:
      '#' comment, '*' table, ':' chain, '!' set, '?' set elements,
      '%' object; anything else is a "<rule>;<ok|fail>[;...]" rule test.
    Lines prefixed with '-' are only run with --need-fix (and then only
    those lines). After the file is processed, the created chains, sets and
    tables are deleted and the global bookkeeping lists are reset.

    :param filename: name of the file with the test rules, relative to
                     TESTS_PATH
    :param force_all_family_option: passed to rule_add; also selects the
                                    detailed summary format
    :param specific_file: print the detailed summary line instead of only
                          an "OK" marker
    :returns: [tests, passed, total_warning, total_error, total_unit_run]
    '''
    filename_path = os.path.join(TESTS_PATH, filename)
    tests = passed = total_unit_run = total_warning = total_error = 0
    # Defined up front so the cleanup below never sees an unbound name.
    lineno = 0

    # "with" guarantees the file is closed on every exit path, including
    # the sys.exit() taken when a signal was received.
    with open(filename_path) as f:
        for lineno, line in enumerate(f):
            sys.stdout.flush()

            if signal_received == 1:
                print("\nSignal received. Cleaning up and Exitting...")
                cleanup_on_exit()
                sys.exit(0)

            if line.isspace():
                continue

            if line[0] == "#":  # Command-line
                continue

            if line[0] == '*':  # Table
                table_line = line.rstrip()[1:]
                ret = table_process(table_line, filename, lineno)
                if ret != 0:
                    break
                continue

            if line[0] == ":":  # Chain
                chain_line = line.rstrip()[1:]
                ret = chain_process(chain_line, lineno)
                if ret != 0:
                    break
                continue

            if line[0] == "!":  # Adds this set
                set_line = line.rstrip()[1:].split(";")
                ret = set_process(set_line, filename, lineno)
                tests += 1
                if ret == -1:
                    continue
                passed += 1
                continue

            if line[0] == "?":  # Adds elements in a set
                element_line = line.rstrip()[1:].split(";")
                ret = set_element_process(element_line, filename, lineno)
                tests += 1
                if ret == -1:
                    continue

                passed += 1
                continue

            if line[0] == "%":  # Adds this object
                # Object specs may contain ';' inside braces, so split at
                # the last '}' when there is one.
                brace = line.rfind("}")
                if brace < 0:
                    obj_line = line.rstrip()[1:].split(";")
                else:
                    obj_line = (line[1:brace+1], line[brace+2:].rstrip())

                ret = obj_process(obj_line, filename, lineno)
                tests += 1
                if ret == -1:
                    continue
                passed += 1
                continue

            # Rule
            rule = line.split(';')  # rule[1] Ok or FAIL
            if len(rule) == 1 or len(rule) > 3 or rule[1].rstrip() \
                    not in {"ok", "fail"}:
                reason = "Skipping malformed rule test. (%s)" % \
                    line.rstrip('\n')
                print_warning(reason, filename, lineno)
                continue

            if line[0] == "-":  # Run omitted lines
                if need_fix_option:
                    rule[0] = rule[0].rstrip()[1:].strip()
                else:
                    continue
            elif need_fix_option:
                continue

            result = rule_add(rule, filename, lineno, force_all_family_option,
                              filename_path)
            tests += 1
            ret = result[0]
            warning = result[1]
            total_warning += warning
            total_error += result[2]
            total_unit_run += result[3]

            if ret != 0:
                continue

            if warning == 0:  # All ok.
                passed += 1

    # Delete rules, sets, chains and tables
    for table in table_list:
        # We delete chains
        for table_chain in table.chains:
            chain = chain_get_by_name(table_chain)
            chain_delete(chain, table, filename, lineno)

        # We delete sets.
        if all_set:
            ret = set_delete(table, filename, lineno)
            if ret != 0:
                reason = "There is a problem when we delete a set"
                print_error(reason, filename, lineno)

        # We delete tables.
        table_delete(table, filename, lineno)

    if specific_file:
        if force_all_family_option:
            print(print_result_all(filename, tests, total_warning, total_error,
                                   total_unit_run))
        else:
            print(print_result(filename, tests, total_warning, total_error))
    else:
        if tests == passed and tests > 0:
            print(filename + ": " + Colors.GREEN + "OK" + Colors.ENDC)

    # Reset per-file state. obj_list must be reset too, like the others;
    # otherwise objects from this file linger and cleanup_on_exit would try
    # to delete them from the next file's tables.
    del table_list[:]
    del chain_list[:]
    all_set.clear()
    del obj_list[:]

    return [tests, passed, total_warning, total_error, total_unit_run]


def main():
    '''
    Entry point: parse the command line, load libnftables, run the selected
    test files and print an overall summary.

    Returns early (printing a message) when not run as root, when the
    library file is missing, when --schema is given without --enable-json,
    or when the log file cannot be opened.
    '''
    parser = argparse.ArgumentParser(description='Run nft tests')

    parser.add_argument('filenames', nargs='*', metavar='path/to/file.t',
                        help='Run only these tests')

    parser.add_argument('-d', '--debug', action='store_true', dest='debug',
                        help='enable debugging mode')

    parser.add_argument('-e', '--need-fix', action='store_true',
                        dest='need_fix_line', help='run rules that need a fix')

    parser.add_argument('-f', '--force-family', action='store_true',
                        dest='force_all_family',
                        help='keep testing all families on error')

    parser.add_argument('-H', '--host', action='store_true',
                        help='run tests against installed libnftables.so.1')

    parser.add_argument('-j', '--enable-json', action='store_true',
                        dest='enable_json',
                        help='test JSON functionality as well')

    parser.add_argument('-l', '--library', default=None,
                        help='path to libnftables.so.1, overrides --host')

    parser.add_argument('-s', '--schema', action='store_true',
                        dest='enable_schema',
                        help='verify json input/output against schema')

    parser.add_argument('-v', '--version', action='version',
                        version='1.0',
                        help='Print the version information')

    args = parser.parse_args()
    global debug_option, need_fix_option, enable_json_option, enable_json_schema
    debug_option = args.debug
    need_fix_option = args.need_fix_line
    force_all_family_option = args.force_all_family
    enable_json_option = args.enable_json
    enable_json_schema = args.enable_schema
    specific_file = False

    # The handler only sets a flag; run_test_file polls it between lines.
    signal.signal(signal.SIGINT, signal_handler)
    signal.signal(signal.SIGTERM, signal_handler)

    if os.getuid() != 0:
        print("You need to be root to run this, sorry")
        return

    # Change working directory to repository root
    os.chdir(TESTS_PATH + "/../..")

    check_lib_path = True
    if args.library is None:
        if args.host:
            # A bare soname is resolved by the dynamic loader, so there is
            # no path to check.
            args.library = 'libnftables.so.1'
            check_lib_path = False
        else:
            args.library = 'src/.libs/libnftables.so.1'

    if check_lib_path and not os.path.exists(args.library):
        print("The nftables library at '%s' does not exist. "
              "You need to build the project." % args.library)
        return

    if args.enable_schema and not args.enable_json:
        # The JSON switch is spelled -j/--enable-json; there is no --json.
        print_error("Option --schema requires option --enable-json")
        return

    global nftables
    nftables = Nftables(sofile = args.library)

    test_files = files_ok = run_total = 0
    tests = passed = warnings = errors = 0
    global log_file
    try:
        log_file = open(LOGFILE, 'w')
        print_info("Log will be available at %s" % LOGFILE)
    except IOError:
        print_error("Cannot open log file %s" % LOGFILE)
        return

    try:
        file_list = []
        if args.filenames:
            file_list = args.filenames
            if len(args.filenames) == 1:
                specific_file = True
        else:
            for directory in TESTS_DIRECTORY:
                path = os.path.join(TESTS_PATH, directory)
                for root, dirs, files in os.walk(path):
                    # Sort in place so os.walk descends in a stable order
                    # and runs are reproducible across filesystems.
                    dirs.sort()
                    for name in sorted(files):
                        if name.endswith(".t"):
                            # Keep the path relative to TESTS_PATH but
                            # include any nested subdirectory: joining only
                            # `directory` with the bare name would point at
                            # a non-existent file for nested tests.
                            full_path = os.path.join(root, name)
                            file_list.append(os.path.relpath(full_path,
                                                             TESTS_PATH))

        for filename in file_list:
            result = run_test_file(filename, force_all_family_option,
                                   specific_file)
            file_tests = result[0]
            file_passed = result[1]
            file_warnings = result[2]
            file_errors = result[3]
            file_unit_run = result[4]

            test_files += 1

            if file_warnings == 0 and file_tests == file_passed:
                files_ok += 1
            if file_tests:
                tests += file_tests
                passed += file_passed
                errors += file_errors
                warnings += file_warnings
            if force_all_family_option:
                run_total += file_unit_run

        if test_files == 0:
            print("No test files to run")
        elif not specific_file:
            print("%d test files, %d files passed, %d unit tests, " % (test_files, files_ok, tests))
            if force_all_family_option:
                print("%d total executed, %d error, %d warning" % (run_total, errors,warnings))
            else:
                print("%d error, %d warning" % (errors, warnings))
    finally:
        # Flush and release the log even if a test file raised or the run
        # was interrupted via sys.exit().
        log_file.close()

if __name__ == "__main__":
    # Run the test suite only when executed as a script, not on import.
    main()