diff --git a/include/expression.h b/include/expression.h
index ee726aa..2e41aa0 100644
--- a/include/expression.h
+++ b/include/expression.h
@@ -460,6 +460,7 @@ extern int set_to_intervals(struct list_head *msgs, struct set *set,
 			    struct expr *init, bool add,
 			    unsigned int debug_mask, bool merge,
 			    struct output_ctx *octx);
+extern void concat_range_aggregate(struct expr *set);
 extern void interval_map_decompose(struct expr *set);
 
 extern struct expr *get_set_intervals(const struct set *set,
diff --git a/include/rule.h b/include/rule.h
index c03b0b8..626973e 100644
--- a/include/rule.h
+++ b/include/rule.h
@@ -372,6 +372,11 @@ static inline bool set_is_interval(uint32_t set_flags)
 	return set_flags & NFT_SET_INTERVAL;
 }
 
+static inline bool set_is_non_concat_range(struct set *s)
+{
+	return (s->flags & NFT_SET_INTERVAL) && s->desc.field_count <= 1;
+}
+
 #include <statistic.h>
 
 struct counter {
diff --git a/src/evaluate.c b/src/evaluate.c
index 58f458d..0c84816 100644
--- a/src/evaluate.c
+++ b/src/evaluate.c
@@ -136,6 +136,11 @@ static int byteorder_conversion(struct eval_ctx *ctx, struct expr **expr,
 
 	if ((*expr)->byteorder == byteorder)
 		return 0;
+
+	/* Conversion for EXPR_CONCAT is handled for single composing ranges */
+	if ((*expr)->etype == EXPR_CONCAT)
+		return 0;
+
 	if (expr_basetype(*expr)->type != TYPE_INTEGER)
 		return expr_error(ctx->msgs, *expr,
 					 "Byteorder mismatch: expected %s, got %s",
diff --git a/src/netlink.c b/src/netlink.c
index 83d863c..e0ba903 100644
--- a/src/netlink.c
+++ b/src/netlink.c
@@ -98,10 +98,11 @@ struct nftnl_expr *alloc_nft_expr(const char *name)
 static struct nftnl_set_elem *alloc_nftnl_setelem(const struct expr *set,
 						  const struct expr *expr)
 {
-	const struct expr *elem, *key, *data;
+	const struct expr *elem, *data;
 	struct nftnl_set_elem *nlse;
 	struct nft_data_linearize nld;
 	struct nftnl_udata_buf *udbuf = NULL;
+	struct expr *key;
 
 	nlse = nftnl_set_elem_alloc();
 	if (nlse == NULL)
@@ -119,6 +120,16 @@ static struct nftnl_set_elem *alloc_nftnl_setelem(const struct expr *set,
 
 	netlink_gen_data(key, &nld);
 	nftnl_set_elem_set(nlse, NFTNL_SET_ELEM_KEY, &nld.value, nld.len);
+
+	if (set->set_flags & NFT_SET_INTERVAL && expr->key->field_count > 1) {
+		key->flags |= EXPR_F_INTERVAL_END;
+		netlink_gen_data(key, &nld);
+		key->flags &= ~EXPR_F_INTERVAL_END;
+
+		nftnl_set_elem_set(nlse, NFTNL_SET_ELEM_KEY_END, &nld.value,
+				   nld.len);
+	}
+
 	if (elem->timeout)
 		nftnl_set_elem_set_u64(nlse, NFTNL_SET_ELEM_TIMEOUT,
 				       elem->timeout);
@@ -186,28 +197,65 @@ void netlink_gen_raw_data(const mpz_t value, enum byteorder byteorder,
 	data->len = len;
 }
 
+/* Export v into data using the byte order and bit length of i, and return the
+ * padded length of the field, in bytes, as given by netlink_padded_len().
+ */
+static int netlink_export_pad(unsigned char *data, const mpz_t v,
+			      const struct expr *i)
+{
+	mpz_export_data(data, v, i->byteorder,
+			div_round_up(i->len, BITS_PER_BYTE));
+
+	return netlink_padded_len(i->len) / BITS_PER_BYTE;
+}
+
+/* Export one field of a concatenation into data: the start of the field, or,
+ * if end is set, its end. Ranges and prefixes have distinct start and end
+ * values, plain values are exported as they are. Return the bytes written.
+ */
+static int netlink_gen_concat_data_expr(int end, const struct expr *i,
+					unsigned char *data)
+{
+	switch (i->etype) {
+	case EXPR_RANGE:
+		i = end ? i->right : i->left;
+		break;
+	case EXPR_PREFIX:
+		if (end) {
+			int count;
+			mpz_t v;
+
+			mpz_init_bitmask(v, i->len - i->prefix_len);
+			mpz_add(v, i->prefix->value, v);
+			count = netlink_export_pad(data, v, i);
+			mpz_clear(v);
+			return count;
+		}
+		return netlink_export_pad(data, i->prefix->value, i);
+	case EXPR_VALUE:
+		break;
+	default:
+		BUG("invalid expression type '%s' in set", expr_ops(i)->name);
+	}
+
+	return netlink_export_pad(data, i->value, i);
+}
+
 static void netlink_gen_concat_data(const struct expr *expr,
 				    struct nft_data_linearize *nld)
 {
+	unsigned int len = expr->len / BITS_PER_BYTE, offset = 0;
+	int end = expr->flags & EXPR_F_INTERVAL_END;
+	unsigned char data[len];
 	const struct expr *i;
-	unsigned int len, offset;
-
-	len = expr->len / BITS_PER_BYTE;
-	if (1) {
-		unsigned char data[len];
-
-		memset(data, 0, sizeof(data));
-		offset = 0;
-		list_for_each_entry(i, &expr->expressions, list) {
-			assert(i->etype == EXPR_VALUE);
-			mpz_export_data(data + offset, i->value, i->byteorder,
-					div_round_up(i->len, BITS_PER_BYTE));
-			offset += netlink_padded_len(i->len) / BITS_PER_BYTE;
-		}
 
-		memcpy(nld->value, data, len);
-		nld->len = len;
-	}
+	memset(data, 0, len);
+
+	list_for_each_entry(i, &expr->expressions, list)
+		offset += netlink_gen_concat_data_expr(end, i, data + offset);
+
+	memcpy(nld->value, data, len);
+	nld->len = len;
 }
 
 static void netlink_gen_constant_data(const struct expr *expr,
@@ -812,6 +860,7 @@ int netlink_delinearize_setelem(struct nftnl_set_elem *nlse,
 	if (nftnl_set_elem_is_set(nlse, NFTNL_SET_ELEM_FLAGS))
 		flags = nftnl_set_elem_get_u32(nlse, NFTNL_SET_ELEM_FLAGS);
 
+key_end:
 	key = netlink_alloc_value(&netlink_location, &nld);
 	datatype_set(key, set->key->dtype);
 	key->byteorder	= set->key->byteorder;
@@ -880,6 +929,15 @@ int netlink_delinearize_setelem(struct nftnl_set_elem *nlse,
 	}
 out:
 	compound_expr_add(set->init, expr);
+
+	if (!(flags & NFT_SET_ELEM_INTERVAL_END) &&
+	    nftnl_set_elem_is_set(nlse, NFTNL_SET_ELEM_KEY_END)) {
+		flags |= NFT_SET_ELEM_INTERVAL_END;
+		nld.value = nftnl_set_elem_get(nlse, NFTNL_SET_ELEM_KEY_END,
+					       &nld.len);
+		goto key_end;
+	}
+
 	return 0;
 }
 
@@ -918,15 +976,16 @@ int netlink_list_setelems(struct netlink_ctx *ctx, const struct handle *h,
 	set->init = set_expr_alloc(&internal_location, set);
 	nftnl_set_elem_foreach(nls, list_setelem_cb, ctx);
 
-	if (!(set->flags & NFT_SET_INTERVAL))
+	if (set->flags & NFT_SET_INTERVAL && set->desc.field_count > 1)
+		concat_range_aggregate(set->init);
+	else if (set->flags & NFT_SET_INTERVAL)
+		interval_map_decompose(set->init);
+	else
 		list_expr_sort(&ctx->set->init->expressions);
 
 	nftnl_set_free(nls);
 	ctx->set = NULL;
 
-	if (set->flags & NFT_SET_INTERVAL)
-		interval_map_decompose(set->init);
-
 	return 0;
 }
 
@@ -935,6 +994,7 @@ int netlink_get_setelem(struct netlink_ctx *ctx, const struct handle *h,
 			struct set *set, struct expr *init)
 {
 	struct nftnl_set *nls, *nls_out = NULL;
+	int err = 0;
 
 	nls = nftnl_set_alloc();
 	if (nls == NULL)
@@ -958,18 +1018,18 @@ int netlink_get_setelem(struct netlink_ctx *ctx, const struct handle *h,
 	set->init = set_expr_alloc(loc, set);
 	nftnl_set_elem_foreach(nls_out, list_setelem_cb, ctx);
 
-	if (!(set->flags & NFT_SET_INTERVAL))
+	if (set->flags & NFT_SET_INTERVAL && set->desc.field_count > 1)
+		concat_range_aggregate(set->init);
+	else if (set->flags & NFT_SET_INTERVAL)
+		err = get_set_decompose(table, set);
+	else
 		list_expr_sort(&ctx->set->init->expressions);
 
 	nftnl_set_free(nls);
 	nftnl_set_free(nls_out);
 	ctx->set = NULL;
 
-	if (set->flags & NFT_SET_INTERVAL &&
-	    get_set_decompose(table, set) < 0)
-		return -1;
-
-	return 0;
+	return err;
 }
 
 void netlink_dump_obj(struct nftnl_obj *nln, struct netlink_ctx *ctx)
diff --git a/src/parser_bison.y b/src/parser_bison.y
index 0fd9b94..ea83f52 100644
--- a/src/parser_bison.y
+++ b/src/parser_bison.y
@@ -3551,7 +3551,6 @@ range_rhs_expr		:	basic_rhs_expr	DASH	basic_rhs_expr
 
 multiton_rhs_expr	:	prefix_rhs_expr
 			|	range_rhs_expr
-			|	wildcard_expr
 			;
 
 map_expr		:	concat_expr	MAP	rhs_expr
@@ -3645,7 +3644,7 @@ set_elem_option		:	TIMEOUT			time_spec
 			;
 
 set_lhs_expr		:	concat_rhs_expr
-			|	multiton_rhs_expr
+			|	wildcard_expr
 			;
 
 set_rhs_expr		:	concat_rhs_expr
@@ -3898,7 +3897,7 @@ list_rhs_expr		:	basic_rhs_expr		COMMA		basic_rhs_expr
 			;
 
 rhs_expr		:	concat_rhs_expr		{ $$ = $1; }
-			|	multiton_rhs_expr	{ $$ = $1; }
+			|	wildcard_expr		{ $$ = $1; }
 			|	set_expr		{ $$ = $1; }
 			|	set_ref_symbol_expr	{ $$ = $1; }
 			;
@@ -3939,7 +3938,17 @@ basic_rhs_expr		:	inclusive_or_rhs_expr
 			;
 
 concat_rhs_expr		:	basic_rhs_expr
-			|	concat_rhs_expr	DOT	basic_rhs_expr
+			|	multiton_rhs_expr
+			|	concat_rhs_expr		DOT	multiton_rhs_expr
+			{
+				struct location rhs[] = {
+					[1]	= @2,
+					[2]	= @3,
+				};
+
+				$$ = handle_concat_expr(&@$, $$, $1, $3, rhs);
+			}
+			|	concat_rhs_expr		DOT	basic_rhs_expr
 			{
 				struct location rhs[] = {
 					[1]	= @2,
diff --git a/src/rule.c b/src/rule.c
index 4669577..e18237b 100644
--- a/src/rule.c
+++ b/src/rule.c
@@ -1512,7 +1512,8 @@ static int __do_add_setelems(struct netlink_ctx *ctx, struct set *set,
 		return -1;
 
 	if (set->init != NULL &&
-	    set->flags & NFT_SET_INTERVAL) {
+	    set->flags & NFT_SET_INTERVAL &&
+	    set->desc.field_count <= 1) {
 		interval_map_decompose(expr);
 		list_splice_tail_init(&expr->expressions, &set->init->expressions);
 		set->init->size += expr->size;
@@ -1533,7 +1534,7 @@ static int do_add_setelems(struct netlink_ctx *ctx, struct cmd *cmd,
 	table = table_lookup(h, &ctx->nft->cache);
 	set = set_lookup(table, h->set.name);
 
-	if (set->flags & NFT_SET_INTERVAL &&
+	if (set_is_non_concat_range(set) &&
 	    set_to_intervals(ctx->msgs, set, init, true,
 			     ctx->nft->debug_mask, set->automerge,
 			     &ctx->nft->output) < 0)
@@ -1548,7 +1549,7 @@ static int do_add_set(struct netlink_ctx *ctx, const struct cmd *cmd,
 	struct set *set = cmd->set;
 
 	if (set->init != NULL) {
-		if (set->flags & NFT_SET_INTERVAL &&
+		if (set_is_non_concat_range(set) &&
 		    set_to_intervals(ctx->msgs, set, set->init, true,
 				     ctx->nft->debug_mask, set->automerge,
 				     &ctx->nft->output) < 0)
@@ -1634,7 +1635,7 @@ static int do_delete_setelems(struct netlink_ctx *ctx, struct cmd *cmd)
 	table = table_lookup(h, &ctx->nft->cache);
 	set = set_lookup(table, h->set.name);
 
-	if (set->flags & NFT_SET_INTERVAL &&
+	if (set_is_non_concat_range(set) &&
 	    set_to_intervals(ctx->msgs, set, expr, false,
 			     ctx->nft->debug_mask, set->automerge,
 			     &ctx->nft->output) < 0)
@@ -2488,7 +2489,7 @@ static int do_get_setelems(struct netlink_ctx *ctx, struct cmd *cmd,
 	set = set_lookup(table, cmd->handle.set.name);
 
 	/* Create a list of elements based of what we got from command line. */
-	if (set->flags & NFT_SET_INTERVAL)
+	if (set_is_non_concat_range(set))
 		init = get_set_intervals(set, cmd->expr);
 	else
 		init = cmd->expr;
@@ -2501,7 +2502,7 @@ static int do_get_setelems(struct netlink_ctx *ctx, struct cmd *cmd,
 	if (err >= 0)
 		__do_list_set(ctx, cmd, table, new_set);
 
-	if (set->flags & NFT_SET_INTERVAL)
+	if (set_is_non_concat_range(set))
 		expr_free(init);
 
 	set_free(new_set);
diff --git a/src/segtree.c b/src/segtree.c
index 7217dbc..e859f84 100644
--- a/src/segtree.c
+++ b/src/segtree.c
@@ -652,6 +652,11 @@ struct expr *get_set_intervals(const struct set *set, const struct expr *init)
 			set_elem_add(set, new_init, i->key->value,
 				     i->flags, i->byteorder);
 			break;
+		case EXPR_CONCAT:
+			compound_expr_add(new_init, expr_clone(i));
+			i->flags |= EXPR_F_INTERVAL_END;
+			compound_expr_add(new_init, expr_clone(i));
+			break;
 		default:
 			range_expr_value_low(low, i);
 			set_elem_add(set, new_init, low, 0, i->byteorder);
@@ -823,6 +828,9 @@ static int expr_value_cmp(const void *p1, const void *p2)
 	struct expr *e2 = *(void * const *)p2;
 	int ret;
 
+	if (expr_value(e1)->etype == EXPR_CONCAT)
+		return -1;
+
 	ret = mpz_cmp(expr_value(e1)->value, expr_value(e2)->value);
 	if (ret == 0) {
 		if (e1->flags & EXPR_F_INTERVAL_END)
@@ -834,6 +842,120 @@ static int expr_value_cmp(const void *p1, const void *p2)
 	return ret;
 }
 
+/* Given start and end elements of a range, check if it can be represented as
+ * a single netmask, and if so, how long, by returning zero or a positive value.
+ * A negative value means the range cannot be expressed as a single netmask.
+ *
+ * start and end are copied at full precision: an mpz_t can be wider than an
+ * unsigned long, and truncating the operands with mpz_get_ui() would compare
+ * only their least significant bits.
+ */
+static int range_mask_len(const mpz_t start, const mpz_t end, unsigned int len)
+{
+	mpz_t tmp_start, tmp_end;
+	int ret;
+
+	mpz_init_set(tmp_start, start);
+	mpz_init_set(tmp_end, end);
+
+	while (mpz_cmp(tmp_start, tmp_end) <= 0 &&
+		!mpz_tstbit(tmp_start, 0) && mpz_tstbit(tmp_end, 0) &&
+		len--) {
+		mpz_fdiv_q_2exp(tmp_start, tmp_start, 1);
+		mpz_fdiv_q_2exp(tmp_end, tmp_end, 1);
+	}
+
+	ret = !mpz_cmp(tmp_start, tmp_end) ? (int)len : -1;
+
+	mpz_clear(tmp_start);
+	mpz_clear(tmp_end);
+
+	return ret;
+}
+
+/* Given a set with two elements (start and end), transform them into a
+ * concatenation of ranges. That is, from a list of start expressions and a list
+ * of end expressions, form a list of start - end expressions.
+ */
+void concat_range_aggregate(struct expr *set)
+{
+	struct expr *i, *start = NULL, *end, *r1, *r2, *next, *r1_next, *tmp;
+	struct list_head *r2_next;
+	int prefix_len, free_r1;
+	mpz_t range, p;
+
+	list_for_each_entry_safe(i, next, &set->expressions, list) {
+		if (!start) {
+			start = i;
+			continue;
+		}
+		end = i;
+
+		/* Walk over r1 (start expression) and r2 (end) in parallel,
+		 * form ranges between corresponding r1 and r2 expressions,
+		 * store them by replacing r2 expressions, and free r1
+		 * expressions.
+		 */
+		r2 = list_first_entry(&expr_value(end)->expressions,
+				      struct expr, list);
+		list_for_each_entry_safe(r1, r1_next,
+					 &expr_value(start)->expressions,
+					 list) {
+			mpz_init(range);
+			mpz_init(p);
+
+			r2_next = r2->list.next;
+			free_r1 = 0;
+
+			if (!mpz_cmp(r1->value, r2->value)) {
+				free_r1 = 1;
+				goto next;
+			}
+
+			mpz_sub(range, r2->value, r1->value);
+			mpz_sub_ui(range, range, 1);
+			mpz_and(p, r1->value, range);
+
+			/* Check if we are forced, or if it's anyway preferable,
+			 * to express the range as two points instead of a
+			 * netmask.
+			 */
+			prefix_len = range_mask_len(r1->value, r2->value,
+						    r1->len);
+			if (prefix_len < 0 ||
+			    !(r1->dtype->flags & DTYPE_F_PREFIX)) {
+				tmp = range_expr_alloc(&r1->location, r1,
+						       r2);
+
+				list_replace(&r2->list, &tmp->list);
+				r2_next = tmp->list.next;
+			} else {
+				tmp = prefix_expr_alloc(&r1->location, r1,
+							prefix_len);
+				tmp->len = r2->len;
+
+				list_replace(&r2->list, &tmp->list);
+				r2_next = tmp->list.next;
+				expr_free(r2);
+			}
+
+next:
+			mpz_clear(p);
+			mpz_clear(range);
+
+			r2 = list_entry(r2_next, typeof(*r2), list);
+			compound_expr_remove(start, r1);
+
+			if (free_r1)
+				expr_free(r1);
+		}
+
+		compound_expr_remove(set, start);
+		expr_free(start);
+		start = NULL;
+	}
+}
+
 void interval_map_decompose(struct expr *set)
 {
 	struct expr **elements, **ranges;