diff --git a/include/net/netlink.h b/include/net/netlink.h index e658d18afa67..4418b1981e31 100644 --- a/include/net/netlink.h +++ b/include/net/netlink.h @@ -325,6 +325,7 @@ struct nla_policy { struct netlink_range_validation_signed *range_signed; struct { s16 min, max; + u8 network_byte_order:1; }; int (*validate)(const struct nlattr *attr, struct netlink_ext_ack *extack); @@ -418,6 +419,14 @@ struct nla_policy { .type = NLA_ENSURE_INT_OR_BINARY_TYPE(tp), \ .validation_type = NLA_VALIDATE_MAX, \ .max = _max, \ + .network_byte_order = 0, \ +} + +#define NLA_POLICY_MAX_BE(tp, _max) { \ + .type = NLA_ENSURE_UINT_TYPE(tp), \ + .validation_type = NLA_VALIDATE_MAX, \ + .max = _max, \ + .network_byte_order = 1, \ } #define NLA_POLICY_MASK(tp, _mask) { \ diff --git a/lib/nlattr.c b/lib/nlattr.c index 86029ad5ead4..40f22b177d69 100644 --- a/lib/nlattr.c +++ b/lib/nlattr.c @@ -159,6 +159,31 @@ void nla_get_range_unsigned(const struct nla_policy *pt, } } +static u64 nla_get_attr_bo(const struct nla_policy *pt, + const struct nlattr *nla) +{ + switch (pt->type) { + case NLA_U16: + if (pt->network_byte_order) + return ntohs(nla_get_be16(nla)); + + return nla_get_u16(nla); + case NLA_U32: + if (pt->network_byte_order) + return ntohl(nla_get_be32(nla)); + + return nla_get_u32(nla); + case NLA_U64: + if (pt->network_byte_order) + return be64_to_cpu(nla_get_be64(nla)); + + return nla_get_u64(nla); + } + + WARN_ON_ONCE(1); + return 0; +} + static int nla_validate_range_unsigned(const struct nla_policy *pt, const struct nlattr *nla, struct netlink_ext_ack *extack, @@ -172,12 +197,10 @@ static int nla_validate_range_unsigned(const struct nla_policy *pt, value = nla_get_u8(nla); break; case NLA_U16: - value = nla_get_u16(nla); - break; case NLA_U32: - value = nla_get_u32(nla); - break; case NLA_U64: + value = nla_get_attr_bo(pt, nla); + break; case NLA_MSECS: value = nla_get_u64(nla); break;