#!/usr/bin/env python
#
# (C) 2014 by Ana Rey Botello <anarey@gmail.com>
#
# Based on iptables-test.py:
# (C) 2012 by Pablo Neira Ayuso <pablo@netfilter.org>
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# Thanks to the Outreach Program for Women (OPW) for sponsoring this test
# infrastructure.
from __future__ import print_function
import sys
import os
import argparse
import signal
import json
import traceback
import tempfile
# Directory containing this script; test files are resolved relative to it.
TESTS_PATH = os.path.dirname(os.path.abspath(__file__))
# Put the in-tree 'py' directory first on the module search path so the
# 'nftables' import below picks it up.
sys.path.insert(0, os.path.join(TESTS_PATH, '../../py/'))
# Force a fixed timezone for this process.
# NOTE(review): presumably to keep time-dependent output stable - confirm.
os.environ['TZ'] = 'UTC-2'
from nftables import Nftables
# Sub-directories scanned for '*.t' files when no file is given explicitly.
TESTS_DIRECTORY = ["any", "arp", "bridge", "inet", "ip", "ip6"]
# Path of the log file opened by main().
LOGFILE = "/tmp/nftables-test.log"
# Open handle on LOGFILE; assigned in main(), written by execute_cmd().
log_file = None
# Table objects created for the test file currently being run.
table_list = []
# Chain objects declared by the test file currently being run.
chain_list = []
# Maps set name -> Python set of the elements added to it so far.
all_set = dict()
# Obj instances that were added successfully.
obj_list = []
# Set to 1 by signal_handler(); polled by run_test_file().
signal_received = 0
class Colors:
    """ANSI color escape sequences used to decorate messages.

    Every attribute is the empty string when stdout is not a terminal, so
    redirected output stays free of escape codes.
    """
    # Evaluated once, while the class body is executed.
    _use_color = sys.stdout.isatty()

    HEADER = '\033[95m' if _use_color else ''
    GREEN = '\033[92m' if _use_color else ''
    YELLOW = '\033[93m' if _use_color else ''
    RED = '\033[91m' if _use_color else ''
    ENDC = '\033[0m' if _use_color else ''

    # Keep the helper out of the class namespace.
    del _use_color
class Chain:
    """Class that represents a chain

    name   -- chain name, used verbatim when building nft commands
    config -- chain specification placed between braces on 'add chain';
              may be empty
    lineno -- index of the test-file line that declared the chain, used
              when reporting messages
    """

    def __init__(self, name, config, lineno):
        self.name = name
        self.config = config
        self.lineno = lineno

    def __eq__(self, other):
        # Comparing against None or any unrelated object must not raise.
        if not isinstance(other, Chain):
            return NotImplemented
        return self.__dict__ == other.__dict__

    def __str__(self):
        return "%s" % self.name
class Table:
    """Class that represents a table

    family -- address family of the table
    name   -- table name
    chains -- list of chain names requested for this table
    """

    def __init__(self, family, name, chains):
        self.family = family
        self.name = name
        self.chains = chains

    def __eq__(self, other):
        # Comparing against None or any unrelated object must not raise.
        if not isinstance(other, Table):
            return NotImplemented
        return self.__dict__ == other.__dict__

    def __str__(self):
        # "<family> <name>" is the form interpolated into nft commands.
        return "%s %s" % (self.family, self.name)
class Set:
    """Class that represents a set

    family  -- address family of the owning table
    table   -- name of the owning table
    name    -- set name
    type    -- set type specification
    timeout -- ready-made "timeout ...;" fragment, or an empty string
    flags   -- set flags, or an empty string
    """

    def __init__(self, family, table, name, type, timeout, flags):
        self.family = family
        self.table = table
        self.name = name
        self.type = type
        self.timeout = timeout
        self.flags = flags

    def __eq__(self, other):
        # Comparing against None or any unrelated object must not raise.
        if not isinstance(other, Set):
            return NotImplemented
        return self.__dict__ == other.__dict__
class Obj:
    """Class that represents an object

    table  -- name of the owning table
    family -- address family of the owning table
    name   -- object name
    type   -- object type keyword(s) used in nft commands
    spcf   -- remaining object specification appended to the 'add' command
    """

    def __init__(self, table, family, name, type, spcf):
        self.table = table
        self.family = family
        self.name = name
        self.type = type
        self.spcf = spcf

    def __eq__(self, other):
        # Comparing against None or any unrelated object must not raise.
        if not isinstance(other, Obj):
            return NotImplemented
        return self.__dict__ == other.__dict__
def print_msg(reason, errstr, filename=None, lineno=None, color=None):
    '''
    Prints a message with nice colors, indicating file and line number.

    :param reason: text of the message
    :param errstr: severity tag printed in color (e.g. "ERROR:")
    :param filename: file the message refers to; omitted when empty/None
    :param lineno: zero-based line index; printed as lineno + 1
    :param color: escape sequence put in front of errstr; None means none
    '''
    color_errstr = "%s%s%s" % (color or "", errstr, Colors.ENDC)
    # Compare against None explicitly: index 0 (the first line of a file) is
    # a valid position and must still be reported with its file name.
    if filename and lineno is not None:
        sys.stderr.write("%s: %s line %d: %s\n" %
                         (filename, color_errstr, lineno + 1, reason))
    else:
        sys.stderr.write("%s %s\n" % (color_errstr, reason))
    sys.stderr.flush()  # So that the message stay in the right place.
def print_error(reason, filename=None, lineno=None):
    '''
    Reports reason as an error (red "ERROR:" tag) through print_msg().
    '''
    print_msg(reason, "ERROR:", filename=filename, lineno=lineno,
              color=Colors.RED)
def print_warning(reason, filename=None, lineno=None):
    '''
    Reports reason as a warning (yellow "WARNING:" tag) through print_msg().
    '''
    print_msg(reason, "WARNING:", filename=filename, lineno=lineno,
              color=Colors.YELLOW)
def print_info(reason, filename=None, lineno=None):
    '''
    Reports reason as information (green "INFO:" tag) through print_msg().
    '''
    print_msg(reason, "INFO:", filename=filename, lineno=lineno,
              color=Colors.GREEN)
def color_differences(rule, other, color):
    '''
    Returns rule with the part that differs from other wrapped in color.

    The longest common prefix and, after removing it, the longest common
    suffix are left uncolored; whatever remains in between is emitted as
    color + text + Colors.ENDC.

    :param rule: string to decorate
    :param other: string to compare against
    :param color: escape sequence inserted before the differing part
    '''
    # Length of the common prefix. Counting explicitly (instead of reusing a
    # for-loop index) keeps the last character uncolored when rule is a
    # prefix of other.
    prefix = 0
    limit = min(len(rule), len(other))
    while prefix < limit and rule[prefix] == other[prefix]:
        prefix += 1

    out = rule[:prefix]
    rule = rule[prefix:]
    other = other[prefix:]

    # Length of the common suffix of what is left.
    suffix = 0
    limit = min(len(rule), len(other))
    while suffix < limit and rule[-1 - suffix] == other[-1 - suffix]:
        suffix += 1

    diff_end = len(rule) - suffix
    if diff_end > 0:
        out += color + rule[:diff_end] + Colors.ENDC
    out += rule[diff_end:]
    return out
def print_differences_warning(filename, lineno, rule1, rule2, cmd):
    '''
    Warns that rule1 and rule2 differ, highlighting the differing parts.

    :param cmd: command that produced the mismatch, quoted in the message
    '''
    highlighted = (
        color_differences(rule1, rule2, Colors.YELLOW),
        color_differences(rule2, rule1, Colors.YELLOW),
    )
    print_warning("'%s': '%s' mismatches '%s'" % ((cmd,) + highlighted),
                  filename, lineno)
def print_differences_error(filename, lineno, cmd):
    '''
    Reports that the ruleset listing obtained after cmd is broken.
    '''
    print_error("'%s': Listing is broken." % cmd, filename, lineno)
def table_exist(table, filename, lineno):
    '''
    Tells whether the table can be listed, i.e. whether it exists.
    '''
    return execute_cmd("list table %s" % table, filename, lineno) == 0
def table_flush(table, filename, lineno):
    '''
    Flushes the table and returns the command string that was executed.
    The command's exit code is not checked.
    '''
    flush_cmd = "flush table %s" % table
    execute_cmd(flush_cmd, filename, lineno)
    return flush_cmd
def table_create(table, filename, lineno):
    '''
    Creates the table and every chain it requests.

    The table is recorded in table_list and removed again if creating it
    fails. Returns 0 on success and -1 if the table could not be created;
    problems with individual chains are reported but do not change the
    return value.
    '''
    # Refuse to reuse a table that is already there.
    if table_exist(table, filename, lineno):
        print_error("Table %s already exists" % table, filename, lineno)
        return -1

    table_list.append(table)

    add_cmd = "add table %s" % table
    if execute_cmd(add_cmd, filename, lineno) != 0:
        print_error("Cannot " + add_cmd, filename, lineno)
        table_list.remove(table)
        return -1

    # Make sure the table really shows up after adding it.
    if not table_exist(table, filename, lineno):
        table_list.remove(table)
        print_error("I have just added the table %s "
                    "but it does not exist. Giving up!" % table,
                    filename, lineno)
        return -1

    for chain_name in table.chains:
        chain = chain_get_by_name(chain_name)
        if chain is not None:
            chain_create(chain, table, filename)
            continue
        print_error("The chain %s requested by table %s "
                    "does not exist." % (chain_name, table),
                    filename, lineno)

    return 0
def table_delete(table, filename=None, lineno=None):
    '''
    Deletes the table and verifies that it is gone.

    Returns 0 on success, -1 (after printing an error) otherwise.
    '''
    failure = None

    if not table_exist(table, filename, lineno):
        failure = "Table %s does not exist but I added it before." % table
    else:
        delete_cmd = "delete table %s" % table
        if execute_cmd(delete_cmd, filename, lineno) != 0:
            failure = "%s: I cannot delete table %s. Giving up!" % (delete_cmd,
                                                                    table)
        elif table_exist(table, filename, lineno):
            failure = "I have just deleted the table %s " \
                      "but it still exists." % table

    if failure is None:
        return 0
    print_error(failure, filename, lineno)
    return -1
def chain_exist(chain, table, filename):
    '''
    Tells whether the chain can be listed in the table. Messages are
    attributed to the line that declared the chain.
    '''
    list_cmd = "list chain %s %s" % (table, chain)
    return execute_cmd(list_cmd, filename, chain.lineno) == 0
def chain_create(chain, table, filename):
    '''
    Creates the chain in the table and verifies that it exists afterwards.

    Returns 0 on success, -1 (after printing an error) otherwise.
    '''
    where = chain.lineno

    if chain_exist(chain, table, filename):
        print_error("This chain '%s' exists in %s. I cannot create "
                    "two chains with same name." % (chain, table),
                    filename, where)
        return -1

    add_cmd = "add chain %s %s" % (table, chain)
    # The chain specification, if any, goes between braces.
    if chain.config:
        add_cmd += " { %s; }" % chain.config

    if execute_cmd(add_cmd, filename, where) != 0:
        print_error("I cannot create the chain '%s'" % chain,
                    filename, where)
        return -1

    if chain_exist(chain, table, filename):
        return 0

    print_error("I have added the chain '%s' "
                "but it does not exist in %s" % (chain, table),
                filename, where)
    return -1
def chain_delete(chain, table, filename=None, lineno=None):
    '''
    Flushes the chain, deletes it and verifies that it is gone.

    Returns 0 on success, -1 (after printing an error) otherwise.
    '''
    if not chain_exist(chain, table, filename):
        print_error("The chain %s does not exist in %s. "
                    "I cannot delete it." % (chain, table),
                    filename, lineno)
        return -1

    # Flush first, then delete; stop at the first command that fails.
    for verb in ("flush", "delete"):
        cmd = "%s chain %s %s" % (verb, table, chain)
        if execute_cmd(cmd, filename, lineno) != 0:
            print_error("I cannot " + cmd, filename, lineno)
            return -1

    if chain_exist(chain, table, filename):
        print_error("The chain %s exists in %s. "
                    "I cannot delete this chain" % (chain, table),
                    filename, lineno)
        return -1

    return 0
def chain_get_by_name(name):
    '''
    Returns the first chain in chain_list called name, or None if there is
    no such chain.
    '''
    return next((known for known in chain_list if known.name == name), None)
def set_add(s, test_result, filename, lineno):
    '''
    Adds the set to every table in table_list.

    s.table and s.family are overwritten with each table in turn.
    Returns 0 on success, -1 (after printing an error) otherwise.

    :param test_result: expected outcome of the add command, "ok" or "fail"
    '''
    if not table_list:
        print_error("Missing table to add rule", filename, lineno)
        return -1

    for table in table_list:
        s.table, s.family = table.name, table.family

        if _set_exist(s, filename, lineno):
            print_error("Set %s already exists in %s" % (s.name, table),
                        filename, lineno)
            return -1

        flags = s.flags if s.flags == "" else "flags %s; " % s.flags
        cmd = "add set %s %s { type %s;%s %s}" % (table, s.name, s.type, s.timeout, flags)
        succeeded = execute_cmd(cmd, filename, lineno) == 0

        if (succeeded and test_result == "fail") or \
                (not succeeded and test_result == "ok"):
            print_error("%s: I cannot add the set %s" % (cmd, s.name),
                        filename, lineno)
            return -1

        # NOTE(review): an expected failure ("fail" with a non-zero exit code)
        # also reaches this check and is then reported as an error; obj_add()
        # treats that case as success - confirm which one is intended.
        if not _set_exist(s, filename, lineno):
            print_error("I have just added the set %s to "
                        "the table %s but it does not exist" % (s.name, table),
                        filename, lineno)
            return -1

    return 0
def set_add_elements(set_element, set_name, state, filename, lineno):
    '''
    Adds elements to the set.

    The elements are added to the set in every table of table_list and, on
    success, recorded in all_set[set_name].

    :param set_element: list of element strings
    :param set_name: name of the target set
    :param state: expected outcome of the add command, "ok" or "fail"
    :returns: 0 on success, -1 (after printing an error) otherwise
    '''
    if not table_list:
        reason = "Missing table to add rules"
        print_error(reason, filename, lineno)
        return -1

    for table in table_list:
        # Check if set exists.
        if (not set_exist(set_name, table, filename, lineno) or
                set_name not in all_set) and state == "ok":
            reason = "I cannot add an element to the set %s " \
                     "since it does not exist." % set_name
            print_error(reason, filename, lineno)
            return -1

        element = ", ".join(set_element)
        cmd = "add element %s %s { %s }" % (table, set_name, element)
        ret = execute_cmd(cmd, filename, lineno)

        if (state == "fail" and ret == 0) or (state == "ok" and ret != 0):
            # Name the outcome that was actually expected, mirroring the
            # messages rule_add() prints for the same two situations.
            if state == "fail":
                test_state = "This rule should have failed."
            else:
                test_state = "This rule should not have failed."
            reason = cmd + ": " + test_state
            print_error(reason, filename, lineno)
            return -1

        # Add element into all_set.
        if ret == 0 and state == "ok":
            all_set[set_name].update(set_element)

    return 0
def set_delete_elements(set_element, set_name, table, filename=None,
                        lineno=None):
    '''
    Deletes the given elements from the set, one command per element.

    Stops at the first failure. Returns 0 on success, -1 otherwise.
    '''
    for member in set_element:
        delete_cmd = "delete element %s %s { %s }" % (table, set_name, member)
        if execute_cmd(delete_cmd, filename, lineno) == 0:
            continue
        print_error("I cannot delete element %s "
                    "from the set %s" % (member, set_name),
                    filename, lineno)
        return -1

    return 0
def set_delete(table, filename=None, lineno=None):
    '''
    Deletes every set recorded in all_set, together with its elements, from
    the table.

    Returns 0 on success, -1 (after printing an error) otherwise.
    '''
    for set_name, members in all_set.items():
        # The set must be there before it can be removed.
        if not set_exist(set_name, table, filename, lineno):
            print_error("The set %s does not exist, "
                        "I cannot delete it" % set_name,
                        filename, lineno)
            return -1

        # Empty the set first; the outcome of this step is not checked.
        set_delete_elements(members, set_name, table, filename, lineno)

        delete_cmd = "delete set %s %s" % (table, set_name)
        failed = execute_cmd(delete_cmd, filename, lineno) != 0

        # A set that is still listed after deletion counts as a failure too.
        if failed or set_exist(set_name, table, filename, lineno):
            print_error("Cannot remove the set " + set_name,
                        filename, lineno)
            return -1

    return 0
def set_exist(set_name, table, filename, lineno):
    '''
    Tells whether the named set can be listed in the table.
    '''
    return execute_cmd("list set %s %s" % (table, set_name),
                       filename, lineno) == 0
def _set_exist(s, filename, lineno):
    '''
    Tells whether the set s can be listed, using the family, table and name
    stored in the Set object itself.
    '''
    list_cmd = "list set %s %s %s" % (s.family, s.table, s.name)
    return execute_cmd(list_cmd, filename, lineno) == 0
def set_check_element(rule1, rule2):
    '''
    Compares two rules that contain an anonymous set, ignoring the order of
    the set elements and the blanks between them.

    Returns True only if the text before the first "{" matches, both rules
    have a "{" and a "}", the element lists are equal once sorted, and the
    text from the first "}" onwards matches.
    '''
    open1 = rule1.find("{")
    open2 = rule2.find("{")
    if rule1[:open1] != rule2[:open2]:
        return False

    close1 = rule1.find("}")
    close2 = rule2.find("}")
    if -1 in (open1, open2, close1, close2):
        return False

    elems1 = sorted(rule1[open1 + 1:close1].replace(" ", "").split(","))
    elems2 = sorted(rule2[open2 + 1:close2].replace(" ", "").split(","))
    return elems1 == elems2 and rule1[close1:] == rule2[close2:]
def obj_add(o, test_result, filename, lineno):
    '''
    Adds the object and checks the outcome against test_result.

    o.table and o.family are overwritten with the table being processed.
    Returns 0 when the outcome matches the expectation, -1 (after printing
    an error) otherwise.

    NOTE(review): every path through the loop body returns, so only the
    first entry of table_list is ever processed - confirm this is intended.

    :param test_result: expected outcome of the add command, "ok" or "fail"
    '''
    if not table_list:
        print_error("Missing table to add rule", filename, lineno)
        return -1

    for table in table_list:
        o.table = table.name
        o.family = table.family
        obj_handle = o.type + " " + o.name

        if _obj_exist(o, filename, lineno):
            print_error("The %s already exists in %s" % (obj_handle, table),
                        filename, lineno)
            return -1

        cmd = "add %s %s %s %s" % (o.type, table, o.name, o.spcf)
        succeeded = execute_cmd(cmd, filename, lineno) == 0

        if (succeeded and test_result == "fail") or \
                (not succeeded and test_result == "ok"):
            print_error("%s: I cannot add the %s" % (cmd, obj_handle),
                        filename, lineno)
            return -1

        present = _obj_exist(o, filename, lineno)

        # Expected outcomes: present after "ok", absent after "fail".
        if present and test_result == "ok":
            return 0
        if not present and test_result == "fail":
            return 0

        if present:
            reason = "I added the %s to the table %s " \
                     "but it should have failed" % (obj_handle, table)
        else:
            reason = "I have just added the %s to " \
                     "the table %s but it does not exist" % (obj_handle, table)
        print_error(reason, filename, lineno)
        return -1
def obj_delete(table, filename=None, lineno=None):
    '''
    Deletes every object recorded in obj_list from the table.

    Returns 0 on success, -1 (after printing an error) otherwise.
    '''
    for o in obj_list:
        obj_handle = o.type + " " + o.name

        # The object must be there before it can be removed.
        if not obj_exist(o, table, filename, lineno):
            print_error("The %s does not exist, I cannot delete it" % obj_handle,
                        filename, lineno)
            return -1

        delete_cmd = "delete %s %s %s" % (o.type, table, o.name)
        failed = execute_cmd(delete_cmd, filename, lineno) != 0

        # An object that is still listed after deletion is a failure too.
        if failed or obj_exist(o, table, filename, lineno):
            print_error("Cannot remove the " + obj_handle, filename, lineno)
            return -1

    return 0
def obj_exist(o, table, filename, lineno):
    '''
    Tells whether the object o can be listed in the given table.
    '''
    list_cmd = "list %s %s %s" % (o.type, table, o.name)
    return execute_cmd(list_cmd, filename, lineno) == 0
def _obj_exist(o, filename, lineno):
    '''
    Tells whether the object o can be listed, using the family and table
    stored in the Obj itself.
    '''
    list_cmd = "list %s %s %s %s" % (o.type, o.family, o.table, o.name)
    return execute_cmd(list_cmd, filename, lineno) == 0
def output_clean(pre_output, chain):
    '''
    Extracts the rule of the given chain from a table listing.

    The text between the chain's opening brace and the next closing brace
    is split at ";" and the third field is taken as the rule; the first two
    fields are expected to hold the chain specification. If the rule itself
    contains braces (an anonymous set), the part after the first closing
    brace is processed recursively and appended.

    :param pre_output: output of "list table"
    :param chain: object whose 'name' attribute is the chain to look for
    :returns: the rule text, or "" if the chain or the rule is not found
    '''
    pos_chain = pre_output.find(chain.name)
    if pos_chain == -1:
        return ""

    output_intermediate = pre_output[pos_chain:]
    brace_start = output_intermediate.find("{")
    brace_end = output_intermediate.find("}")
    pre_rule = output_intermediate[brace_start:brace_end]

    # Drop the opening brace and all layout characters.
    body = pre_rule[1:].replace("\t", "").replace("\n", "")
    fields = body.strip().split(";")
    if len(fields) < 3:
        # No rule after the two leading fields: report "nothing found"
        # instead of failing with an IndexError.
        return ""

    if "{" not in body:
        return fields[2].strip()

    # This rule has a set: the first "}" seen above closed the set, not the
    # chain. Restore it, then parse what follows as if it were a chain body
    # whose two leading fields are empty.
    set_part = fields[2].strip() + "}"
    remainder = output_clean(chain.name + " {;;" +
                             output_intermediate[brace_end + 1:], chain)
    if not remainder:
        return set_part
    return set_part + " " + remainder
def payload_check_elems_to_set(elems):
    '''
    Splits elems at every "[end]" marker and returns the stripped pieces as
    a set. Processing stops at the first duplicate piece, which is reported
    through print_error(); the pieces collected so far are returned.
    '''
    seen = set()
    for index, chunk in enumerate(elems.split('[end]')):
        piece = chunk.strip()
        if piece in seen:
            print_error("duplicate", piece, index)
            break
        seen.add(piece)
    return seen
def payload_check_set_elems(want, got):
    '''
    Compares two payload lines that list set elements, ignoring the order
    of the elements.

    Returns 0 unless both lines contain "element" and "[end]"; otherwise
    returns whether both lines yield the same set of elements.
    '''
    for text in (want, got):
        if 'element' not in text or '[end]' not in text:
            return 0

    return payload_check_elems_to_set(want) == payload_check_elems_to_set(got)
def payload_check(payload_buffer, file, cmd):
    '''
    Compares the expected payload lines with the content of file.

    file is rewound first. Lines are compared pairwise; a differing pair is
    tolerated when neither line contains "[", when neither contains "]", or
    when both list the same set elements. The first other difference is
    reported as a warning and ends the check.

    :param payload_buffer: list of expected lines
    :param file: file object holding the generated payload
    :param cmd: command quoted in the warning
    :returns: a false value if nothing is expected, on a mismatch, or if no
              line matched exactly; True otherwise
    '''
    file.seek(0, 0)
    if not payload_buffer:
        return False

    exact_matches = 0
    for lineno, want_line in enumerate(payload_buffer):
        got_line = file.readline()

        if got_line == want_line:
            exact_matches += 1
            continue

        no_opening = '[' not in want_line and '[' not in got_line
        no_closing = ']' not in want_line and ']' not in got_line
        if no_opening or no_closing or \
                payload_check_set_elems(want_line, got_line):
            continue

        print_differences_warning(file.name, lineno, want_line.strip(),
                                  got_line.strip(), cmd)
        return 0

    return exact_matches > 0
def json_dump_normalize(json_string, human_readable = False):
    '''
    Parses json_string and serializes it again with sorted keys, so that
    equivalent documents compare equal as strings.

    :param human_readable: indent the result by four spaces
    '''
    dump_options = {"sort_keys": True}
    if human_readable:
        dump_options.update(indent=4, separators=(',', ': '))
    return json.dumps(json.loads(json_string), **dump_options)
def json_validate(json_string):
    '''
    Validates json_string with nftables.json_validate().

    A validation failure is reported together with the traceback; it is not
    propagated. Text that is not valid JSON still raises from json.loads().
    '''
    parsed = json.loads(json_string)
    try:
        nftables.json_validate(parsed)
    except Exception:
        failure = "schema validation failed for input '%s'" % json_string
        print_error(failure)
        print_error(traceback.format_exc())
def rule_add(rule, filename, lineno, force_all_family_option, filename_path):
    '''
    Adds a rule

    The rule is added to every chain of every table in table_list. For each
    combination the exit code is checked against the expected state, the
    netlink debug output is compared with the expected payload, and the
    listed rule is compared with the expected output. With the JSON option
    enabled the same checks are repeated with the rule's JSON equivalent.

    :param rule: fields of the test line split at ";": rule text, expected
                 state ("ok"/"fail") and optionally the expected output
    :param filename: name of the test file (used for messages)
    :param lineno: line index of the test (used for messages)
    :param force_all_family_option: keep going after a failure instead of
                                    returning early
    :param filename_path: path of the test file; ".payload", ".json" and
                          similar suffixes are appended to find its
                          companion files
    :returns: list [ret, warning, error, unit_tests]
    '''
    # TODO Check if a rule is added correctly.
    ret = warning = error = unit_tests = 0

    if not table_list or not chain_list:
        reason = "Missing table or chain to add rule."
        print_error(reason, filename, lineno)
        return [-1, warning, error, unit_tests]

    # Expected payload and JSON data are only looked up for rules that are
    # supposed to be accepted.
    if rule[1].strip() == "ok":
        payload_expected = None
        try:
            payload_log = open("%s.payload" % filename_path)
            payload_expected = payload_find_expected(payload_log, rule[0])
        except:
            # No generic payload file; a per-family file may still exist.
            payload_log = None

        if enable_json_option:
            try:
                json_log = open("%s.json" % filename_path)
                json_input = json_find_expected(json_log, rule[0])
            except:
                json_input = None

            if not json_input:
                print_error("did not find JSON equivalent for rule '%s'"
                            % rule[0])
            else:
                try:
                    json_input = json_dump_normalize(json_input)
                except ValueError:
                    reason = "Invalid JSON syntax in rule: %s" % json_input
                    print_error(reason)
                    return [-1, warning, error, unit_tests]

            try:
                json_log = open("%s.json.output" % filename_path)
                json_expected = json_find_expected(json_log, rule[0])
            except:
                # will use json_input for comparison
                json_expected = None

            if json_expected:
                try:
                    json_expected = json_dump_normalize(json_expected)
                except ValueError:
                    reason = "Invalid JSON syntax in expected output: %s" % json_expected
                    print_error(reason)
                    return [-1, warning, error, unit_tests]

    for table in table_list:
        if rule[1].strip() == "ok":
            # A payload file specific to the table's family takes
            # precedence over the generic one.
            table_payload_expected = None
            try:
                payload_log = open("%s.payload.%s" % (filename_path, table.family))
                table_payload_expected = payload_find_expected(payload_log, rule[0])
            except:
                if not payload_log:
                    print_error("did not find any payload information",
                                filename_path)
                elif not payload_expected:
                    print_error("did not find payload information for "
                                "rule '%s'" % rule[0], payload_log.name, 1)
            if not table_payload_expected:
                table_payload_expected = payload_expected

        for table_chain in table.chains:
            chain = chain_get_by_name(table_chain)
            unit_tests += 1
            table_flush(table, filename, lineno)

            # Temporary file receiving the output of the add command.
            payload_log = tempfile.TemporaryFile(mode="w+")

            # Add rule and check return code
            cmd = "add rule %s %s %s" % (table, chain, rule[0])
            ret = execute_cmd(cmd, filename, lineno, payload_log, debug="netlink")

            state = rule[1].rstrip()
            # NOTE(review): 134 presumably is the exit status of an aborted
            # process (128 + SIGABRT) - confirm.
            if (ret in [0,134] and state == "fail") or (ret != 0 and state == "ok"):
                if state == "fail":
                    test_state = "This rule should have failed."
                else:
                    test_state = "This rule should not have failed."
                reason = cmd + ": " + test_state
                print_error(reason, filename, lineno)
                ret = -1
                error += 1
                if not force_all_family_option:
                    return [ret, warning, error, unit_tests]

            # A rule that was expected to fail has nothing left to check.
            if state == "fail" and ret != 0:
                ret = 0
                continue

            if ret != 0:
                continue

            # Check for matching payload
            if state == "ok" and not payload_check(table_payload_expected,
                                                   payload_log, cmd):
                error += 1
                # Save the payload actually produced next to the test file.
                gotf = open("%s.payload.got" % filename_path, 'a')
                payload_log.seek(0, 0)
                gotf.write("# %s\n" % rule[0])
                while True:
                    line = payload_log.readline()
                    if line == "":
                        break
                    gotf.write(line)
                gotf.close()
                print_warning("Wrote payload for rule %s" % rule[0],
                              gotf.name, 1)

            # Check for matching ruleset listing
            # The output options are switched on only for this listing and
            # restored to their previous values right after it.
            numeric_proto_old = nftables.set_numeric_proto_output(True)
            stateless_old = nftables.set_stateless_output(True)
            list_cmd = 'list table %s' % table
            rc, pre_output, err = nftables.cmd(list_cmd)
            nftables.set_numeric_proto_output(numeric_proto_old)
            nftables.set_stateless_output(stateless_old)

            output = pre_output.split(";")
            if len(output) < 2:
                reason = cmd + ": Listing is broken."
                print_error(reason, filename, lineno)
                ret = -1
                error += 1
                if not force_all_family_option:
                    return [ret, warning, error, unit_tests]
                continue

            rule_output = output_clean(pre_output, chain)
            retest_output = False
            # A third field holds the expected listing when it differs from
            # the rule as written; such output is replayed further below.
            if len(rule) == 3:
                teoric_exit = rule[2]
                retest_output = True
            else:
                teoric_exit = rule[0]

            if rule_output.rstrip() != teoric_exit.rstrip():
                if rule[0].find("{") != -1:  # anonymous sets
                    # Set elements may be listed in a different order.
                    if not set_check_element(teoric_exit.rstrip(),
                                             rule_output.rstrip()):
                        warning += 1
                        retest_output = True
                        print_differences_warning(filename, lineno,
                                                  teoric_exit.rstrip(),
                                                  rule_output, cmd)
                        if not force_all_family_option:
                            return [ret, warning, error, unit_tests]
                else:
                    if len(rule_output) <= 0:
                        error += 1
                        print_differences_error(filename, lineno, cmd)
                        if not force_all_family_option:
                            return [ret, warning, error, unit_tests]

                    warning += 1
                    retest_output = True
                    print_differences_warning(filename, lineno,
                                              teoric_exit.rstrip(),
                                              rule_output, cmd)

                    if not force_all_family_option:
                        return [ret, warning, error, unit_tests]

            # Feed the listed rule back in; it has to be accepted and has
            # to produce the same payload.
            if retest_output:
                table_flush(table, filename, lineno)

                # Add rule and check return code
                cmd = "add rule %s %s %s" % (table, chain, rule_output.rstrip())
                ret = execute_cmd(cmd, filename, lineno, payload_log, debug="netlink")

                if ret != 0:
                    test_state = "Replaying rule failed."
                    reason = cmd + ": " + test_state
                    print_warning(reason, filename, lineno)
                    ret = -1
                    error += 1
                    if not force_all_family_option:
                        return [ret, warning, error, unit_tests]
                # Check for matching payload
                elif not payload_check(table_payload_expected,
                                       payload_log, cmd):
                    error += 1

            if not enable_json_option:
                continue

            # Generate JSON equivalent for rule if not found
            if not json_input:
                json_old = nftables.set_json_output(True)
                rc, json_output, err = nftables.cmd(list_cmd)
                nftables.set_json_output(json_old)

                json_output = json.loads(json_output)
                # Use the first rule in the listing, without its handle.
                for item in json_output["nftables"]:
                    if "rule" in item:
                        del(item["rule"]["handle"])
                        json_output = item["rule"]
                        break
                json_input = json.dumps(json_output["expr"], sort_keys = True)

                gotf = open("%s.json.got" % filename_path, 'a')
                jdump = json_dump_normalize(json_input, True)
                gotf.write("# %s\n%s\n\n" % (rule[0], jdump))
                gotf.close()
                print_warning("Wrote JSON equivalent for rule %s" % rule[0],
                              gotf.name, 1)

            table_flush(table, filename, lineno)
            payload_log = tempfile.TemporaryFile(mode="w+")

            # Add rule in JSON format
            cmd = json.dumps({ "nftables": [{ "add": { "rule": {
                    "family": table.family,
                    "table": table.name,
                    "chain": chain.name,
                    "expr": json.loads(json_input),
            }}}]})

            if enable_json_schema:
                json_validate(cmd)

            json_old = nftables.set_json_output(True)
            ret = execute_cmd(cmd, filename, lineno, payload_log, debug="netlink")
            nftables.set_json_output(json_old)

            if ret != 0:
                reason = "Failed to add JSON equivalent rule"
                print_error(reason, filename, lineno)
                continue

            # Check for matching payload
            if not payload_check(table_payload_expected, payload_log, cmd):
                error += 1
                # Save the payload produced by the JSON variant of the rule.
                gotf = open("%s.json.payload.got" % filename_path, 'a')
                payload_log.seek(0, 0)
                gotf.write("# %s\n" % rule[0])
                while True:
                    line = payload_log.readline()
                    if line == "":
                        break
                    gotf.write(line)
                gotf.close()
                print_warning("Wrote JSON payload for rule %s" % rule[0],
                              gotf.name, 1)

            # Check for matching ruleset listing
            numeric_proto_old = nftables.set_numeric_proto_output(True)
            stateless_old = nftables.set_stateless_output(True)
            json_old = nftables.set_json_output(True)
            rc, json_output, err = nftables.cmd(list_cmd)
            nftables.set_json_output(json_old)
            nftables.set_numeric_proto_output(numeric_proto_old)
            nftables.set_stateless_output(stateless_old)

            if enable_json_schema:
                json_validate(json_output)

            json_output = json.loads(json_output)
            # Use the first rule in the listing, without its handle.
            for item in json_output["nftables"]:
                if "rule" in item:
                    del(item["rule"]["handle"])
                    json_output = item["rule"]
                    break
            json_output = json.dumps(json_output["expr"], sort_keys = True)

            # Without dedicated expected output, the listing has to match
            # the JSON that was fed in.
            if not json_expected and json_output != json_input:
                print_differences_warning(filename, lineno,
                                          json_input, json_output, cmd)
                error += 1
                gotf = open("%s.json.output.got" % filename_path, 'a')
                jdump = json_dump_normalize(json_output, True)
                gotf.write("# %s\n%s\n\n" % (rule[0], jdump))
                gotf.close()
                print_warning("Wrote JSON output for rule %s" % rule[0],
                              gotf.name, 1)
                # prevent further warnings and .got file updates
                json_expected = json_output
            elif json_expected and json_output != json_expected:
                print_differences_warning(filename, lineno,
                                          json_expected, json_output, cmd)
                error += 1

    return [ret, warning, error, unit_tests]
def cleanup_on_exit():
    '''
    Removes everything created so far: chains, sets, objects and finally
    the tables themselves, for every table in table_list.
    '''
    for table in table_list:
        for table_chain in table.chains:
            chain = chain_get_by_name(table_chain)
            # A table may request a chain that was never declared; there is
            # nothing to delete then, and passing None on would raise and
            # abort the rest of the cleanup.
            if chain is None:
                continue
            chain_delete(chain, table, "", "")
        if all_set:
            set_delete(table)
        if obj_list:
            obj_delete(table)
        table_delete(table)
def signal_handler(signal, frame):
    '''
    Signal handler: only records that a signal arrived by setting the
    module-level flag signal_received to 1. Both arguments are ignored.
    '''
    global signal_received
    signal_received = int(True)
def execute_cmd(cmd, filename, lineno, stdout_log=False, debug=False):
    '''
    Executes a command through the nftables bindings, logs its output and
    returns the command exit code.

    :param cmd: string with the command to be executed
    :param filename: name of the file tested (not used by this function)
    :param lineno: line number being tested (not used by this function)
    :param stdout_log: redirect stdout to this file instead of global log_file
    :param debug: temporarily set these debug flags
    '''
    global log_file
    print("command: {}".format(cmd), file=log_file)
    if debug_option:
        print(cmd)

    if debug:
        debug_old = nftables.get_debug()
        nftables.set_debug(debug)

    try:
        ret, out, err = nftables.cmd(cmd)
    finally:
        # Restore the previous debug flags even if the command raised, so
        # they cannot leak into later commands.
        if debug:
            nftables.set_debug(debug_old)

    if not stdout_log:
        stdout_log = log_file

    stdout_log.write(out)
    stdout_log.flush()
    # Error output always goes to the global log.
    log_file.write(err)
    log_file.flush()

    return ret
def print_result(filename, tests, warning, error):
    '''
    Returns the one-line summary for a test file. Despite its name this
    function does not print anything.
    '''
    return "%s: %s unit tests, %s error, %s warning" % (
        str(filename), str(tests), str(error), str(warning))
def print_result_all(filename, tests, warning, error, unit_tests):
    '''
    Returns the one-line summary for a test file, including the total
    number of tests executed. Despite its name this function does not print
    anything.
    '''
    pieces = (
        "%s: %s unit tests" % (str(filename), str(tests)),
        "%s total test executed" % str(unit_tests),
        "%s error" % str(error),
        "%s warning" % str(warning),
    )
    return ", ".join(pieces)
def table_process(table_line, filename, lineno):
    '''
    Parses a "family;name;chain[,chain...]" table line and creates the
    table. Returns the result of table_create().
    '''
    fields = table_line.split(";")
    family, name, chain_names = fields[0], fields[1], fields[2].split(",")
    return table_create(Table(family, name, chain_names), filename, lineno)
def chain_process(chain_line, lineno):
    '''
    Parses a "name;config" chain line and records the chain in chain_list.
    Nothing is created at this point. Always returns 0.
    '''
    fields = chain_line.split(";")
    declared = Chain(fields[0], fields[1], lineno)
    chain_list.append(declared)
    return 0
def set_process(set_line, filename, lineno):
    '''
    Parses a set declaration and adds the set.

    set_line[0] holds the space-separated declaration: the set name, one
    token that is skipped, the type (further types joined by "."), and then
    optionally "timeout <value>" or "flags <value>". set_line[1] holds the
    expected result. On success an empty element set is registered in
    all_set.

    Returns the result of set_add().
    '''
    test_result = set_line[1]
    tokens = set_line[0].split(" ")
    set_name = tokens[0]

    # Collect the (possibly concatenated) type: "a . b . c".
    type_parts = [tokens[2]]
    pos = 3
    while pos < len(tokens) and tokens[pos] == ".":
        type_parts.append(tokens[pos + 1])
        pos += 2
    set_type = " . ".join(type_parts)

    timeout = ""
    if len(tokens) == pos + 2 and tokens[pos] == "timeout":
        timeout = "timeout " + tokens[pos + 1] + ";"
        pos += 2

    set_flags = ""
    if len(tokens) == pos + 2 and tokens[pos] == "flags":
        set_flags = tokens[pos + 1]
    elif len(tokens) != pos:
        # Leftover tokens are reported, but the set is still added.
        print_error(set_name + " bad flag: " + tokens[pos], filename, lineno)

    new_set = Set("", "", set_name, set_type, timeout, set_flags)
    ret = set_add(new_set, test_result, filename, lineno)
    if ret == 0:
        all_set[set_name] = set()

    return ret
def set_element_process(element_line, filename, lineno):
    '''
    Parses a "<set name> <elem>[,<elem>...]" line and adds the elements.

    element_line[0] is the declaration, element_line[1] the expected result.
    Returns the result of set_add_elements().
    '''
    expected, declaration = element_line[1], element_line[0]
    # Everything up to the first blank is the set name; the rest is a
    # comma-separated element list.
    cut = declaration.find(" ")
    return set_add_elements(declaration[cut:].split(","), declaration[:cut],
                            expected, filename, lineno)
def obj_process(obj_line, filename, lineno):
    '''
    Parses an object declaration and adds the object.

    obj_line[0] holds the space-separated declaration: the object name, one
    token that is skipped, the type, then the specification. obj_line[1]
    holds the expected result. On success the object is appended to
    obj_list.

    Returns the result of obj_add().
    '''
    test_result = obj_line[1]
    tokens = obj_line[0].split(" ")
    obj_name, obj_type = tokens[0], tokens[2]

    # "ct" objects carry a second type keyword; fold it into the type and
    # blank it out of the specification.
    if obj_type == "ct" and tokens[3] in ("helper", "timeout", "expectation"):
        obj_type = "ct " + tokens[3]
        tokens[3] = ""

    obj_spcf = " ".join(tokens[3:]) if len(tokens) > 3 else ""

    new_obj = Obj("", "", obj_name, obj_type, obj_spcf)
    ret = obj_add(new_obj, test_result, filename, lineno)
    if ret == 0:
        obj_list.append(new_obj)

    return ret
def payload_find_expected(payload_log, rule):
    '''
    Find the netlink payload that should be generated by given rule in
    payload_log

    The payload starts after the "# <rule>" line and ends with the first
    whitespace-only line, which is included in the result. When the end of
    the file is reached first, the file is rewound and whatever was
    collected so far is returned.

    :param payload_log: open file handle of the payload data
    :param rule: nft rule we are going to add
    :returns: list of payload lines, empty if the rule was not found
    '''
    wanted = rule.strip()
    collecting = False
    payload_buffer = []

    while True:
        line = payload_log.readline()
        if not line:
            # End of file.
            payload_log.seek(0, 0)
            return payload_buffer

        # "# <rule>" marks the start of the payload of a rule.
        if line.startswith("#") and line.strip()[2:] == wanted:
            collecting = True
            continue

        if not collecting:
            continue

        payload_buffer.append(line)
        if line.isspace():
            return payload_buffer
def json_find_expected(json_log, rule):
    '''
    Find the corresponding JSON for given rule

    The JSON text starts after the "# <rule>" line and ends with the first
    whitespace-only line. Its lines are stripped and concatenated. When the
    end of the file is reached first, the file is rewound and whatever was
    collected so far is returned.

    :param json_log: open file handle of the json data
    :param rule: nft rule we are going to add
    :returns: the JSON text as one string, empty if the rule was not found
    '''
    wanted = rule.strip()
    collecting = False
    pieces = []

    while True:
        line = json_log.readline()
        if not line:
            # End of file.
            json_log.seek(0, 0)
            break

        # "# <rule>" marks the start of the JSON of a rule.
        if line.startswith("#") and line.strip()[2:] == wanted:
            collecting = True
            continue

        if not collecting:
            continue

        pieces.append(line.rstrip("\n").strip())
        if line.isspace():
            break

    return "".join(pieces)
def run_test_file(filename, force_all_family_option, specific_file):
    '''
    Runs a test file

    The first character of each line selects what is done with it:
    "#" comment, "*" table, ":" chain, "!" set, "?" set elements,
    "%" object; any other line is a rule test. Everything the file created
    is removed again at the end and the module-level bookkeeping lists are
    reset.

    :param filename: name of the file with the test rules
    :param force_all_family_option: passed on to rule_add(); also selects
                                    the summary format
    :param specific_file: print a detailed summary instead of a plain "OK"
    :returns: list [tests, passed, warnings, errors, unit tests run]
    '''
    filename_path = os.path.join(TESTS_PATH, filename)
    tests = passed = total_unit_run = total_warning = total_error = 0

    # The with-statement closes the file on every way out, including the
    # sys.exit() below.
    with open(filename_path) as f:
        for lineno, line in enumerate(f):
            sys.stdout.flush()

            if signal_received == 1:
                print("\nSignal received. Cleaning up and Exitting...")
                cleanup_on_exit()
                sys.exit(0)

            if line.isspace():
                continue

            if line[0] == "#":  # Command-line
                continue

            if line[0] == '*':  # Table
                table_line = line.rstrip()[1:]
                ret = table_process(table_line, filename, lineno)
                if ret != 0:
                    break
                continue

            if line[0] == ":":  # Chain
                chain_line = line.rstrip()[1:]
                ret = chain_process(chain_line, lineno)
                if ret != 0:
                    break
                continue

            if line[0] == "!":  # Adds this set
                set_line = line.rstrip()[1:].split(";")
                ret = set_process(set_line, filename, lineno)
                tests += 1
                if ret == -1:
                    continue
                passed += 1
                continue

            if line[0] == "?":  # Adds elements in a set
                element_line = line.rstrip()[1:].split(";")
                ret = set_element_process(element_line, filename, lineno)
                tests += 1
                if ret == -1:
                    continue
                passed += 1
                continue

            if line[0] == "%":  # Adds this object
                # The object specification may itself contain ";" inside
                # braces, so split after the last "}" when there is one.
                brace = line.rfind("}")
                if brace < 0:
                    obj_line = line.rstrip()[1:].split(";")
                else:
                    obj_line = (line[1:brace+1], line[brace+2:].rstrip())

                ret = obj_process(obj_line, filename, lineno)
                tests += 1
                if ret == -1:
                    continue
                passed += 1
                continue

            # Rule
            rule = line.split(';')  # rule[1] Ok or FAIL
            if len(rule) == 1 or len(rule) > 3 or rule[1].rstrip() \
                    not in {"ok", "fail"}:
                reason = "Skipping malformed rule test. (%s)" % line.rstrip('\n')
                print_warning(reason, filename, lineno)
                continue

            if line[0] == "-":  # Run omitted lines
                if need_fix_option:
                    rule[0] = rule[0].rstrip()[1:].strip()
                else:
                    continue
            elif need_fix_option:
                # Only the omitted lines are run in this mode.
                continue

            result = rule_add(rule, filename, lineno, force_all_family_option,
                              filename_path)
            tests += 1
            ret = result[0]
            warning = result[1]
            total_warning += warning
            total_error += result[2]
            total_unit_run += result[3]

            if ret != 0:
                continue

            if warning == 0:  # All ok.
                passed += 1

    # Delete rules, sets, chains and tables
    for table in table_list:
        # We delete chains
        for table_chain in table.chains:
            chain = chain_get_by_name(table_chain)
            # A requested chain that was never declared was already
            # reported when the table was created; there is nothing to
            # delete, and passing None on would raise.
            if chain is None:
                continue
            chain_delete(chain, table, filename, lineno)

        # We delete sets.
        if all_set:
            ret = set_delete(table, filename, lineno)
            if ret != 0:
                reason = "There is a problem when we delete a set"
                print_error(reason, filename, lineno)

        # We delete tables.
        table_delete(table, filename, lineno)

    if specific_file:
        if force_all_family_option:
            print(print_result_all(filename, tests, total_warning, total_error,
                                   total_unit_run))
        else:
            print(print_result(filename, tests, total_warning, total_error))
    else:
        if tests == passed and tests > 0:
            print(filename + ": " + Colors.GREEN + "OK" + Colors.ENDC)

    # Reset the bookkeeping so that the next test file starts clean.
    del table_list[:]
    del chain_list[:]
    all_set.clear()
    # Objects disappeared together with their tables; forget them too, so a
    # later cleanup_on_exit() does not try to delete them again.
    del obj_list[:]
    return [tests, passed, total_warning, total_error, total_unit_run]
def main():
    '''
    Parses the command line, loads the nftables library, runs the selected
    test files and prints a summary.

    Returns early, after printing a message, when not running as root, when
    the library cannot be found, when the options are inconsistent or when
    the log file cannot be opened.
    '''
    global debug_option, need_fix_option, enable_json_option, enable_json_schema
    global nftables
    global log_file

    parser = argparse.ArgumentParser(description='Run nft tests')

    parser.add_argument('filenames', nargs='*', metavar='path/to/file.t',
                        help='Run only these tests')

    parser.add_argument('-d', '--debug', action='store_true', dest='debug',
                        help='enable debugging mode')

    parser.add_argument('-e', '--need-fix', action='store_true',
                        dest='need_fix_line', help='run rules that need a fix')

    parser.add_argument('-f', '--force-family', action='store_true',
                        dest='force_all_family',
                        help='keep testing all families on error')

    parser.add_argument('-H', '--host', action='store_true',
                        help='run tests against installed libnftables.so.1')

    parser.add_argument('-j', '--enable-json', action='store_true',
                        dest='enable_json',
                        help='test JSON functionality as well')

    parser.add_argument('-l', '--library', default=None,
                        help='path to libnftables.so.1, overrides --host')

    parser.add_argument('-s', '--schema', action='store_true',
                        dest='enable_schema',
                        help='verify json input/output against schema')

    parser.add_argument('-v', '--version', action='version',
                        version='1.0',
                        help='Print the version information')

    args = parser.parse_args()
    debug_option = args.debug
    need_fix_option = args.need_fix_line
    force_all_family_option = args.force_all_family
    enable_json_option = args.enable_json
    enable_json_schema = args.enable_schema
    specific_file = False

    signal.signal(signal.SIGINT, signal_handler)
    signal.signal(signal.SIGTERM, signal_handler)

    if os.getuid() != 0:
        print("You need to be root to run this, sorry")
        return

    # Change working directory to repository root
    os.chdir(TESTS_PATH + "/../..")

    # A bare library name (--host) is not checked for existence here; only
    # an explicit or in-tree path is.
    check_lib_path = True
    if args.library is None:
        if args.host:
            args.library = 'libnftables.so.1'
            check_lib_path = False
        else:
            args.library = 'src/.libs/libnftables.so.1'

    if check_lib_path and not os.path.exists(args.library):
        print("The nftables library at '%s' does not exist. "
              "You need to build the project." % args.library)
        return

    if args.enable_schema and not args.enable_json:
        # Name the option as it is actually spelled on the command line.
        print_error("Option --schema requires option --enable-json")
        return

    nftables = Nftables(sofile = args.library)

    test_files = files_ok = run_total = 0
    tests = passed = warnings = errors = 0
    try:
        log_file = open(LOGFILE, 'w')
        print_info("Log will be available at %s" % LOGFILE)
    except IOError:
        print_error("Cannot open log file %s" % LOGFILE)
        return

    file_list = []
    if args.filenames:
        file_list = args.filenames
        if len(args.filenames) == 1:
            specific_file = True
    else:
        for directory in TESTS_DIRECTORY:
            path = os.path.join(TESTS_PATH, directory)
            for root, dirs, files in os.walk(path):
                for f in files:
                    if f.endswith(".t"):
                        file_list.append(os.path.join(directory, f))

    try:
        for filename in file_list:
            result = run_test_file(filename, force_all_family_option, specific_file)
            file_tests = result[0]
            file_passed = result[1]
            file_warnings = result[2]
            file_errors = result[3]
            file_unit_run = result[4]

            test_files += 1

            if file_warnings == 0 and file_tests == file_passed:
                files_ok += 1
            if file_tests:
                tests += file_tests
                passed += file_passed
                errors += file_errors
                warnings += file_warnings
            if force_all_family_option:
                run_total += file_unit_run
    finally:
        # Make sure everything logged reaches the disk, also when a test
        # run exits early.
        log_file.close()

    if test_files == 0:
        print("No test files to run")
    else:
        if not specific_file:
            if force_all_family_option:
                print("%d test files, %d files passed, %d unit tests, " % (test_files, files_ok, tests))
                print("%d total executed, %d error, %d warning" % (run_total, errors,warnings))
            else:
                print("%d test files, %d files passed, %d unit tests, " % (test_files, files_ok, tests))
                print("%d error, %d warning" % (errors, warnings))
# Run the tests only when this file is executed as a script.
if __name__ == '__main__':
    main()