diff --git a/net/sched/sch_cake.c b/net/sched/sch_cake.c index 1bfdfacb6be8a5d693593c90c08e4fd027d135f3..2116d7d4f4c07879c9dd272abe802fa0f1fcdd29 100644 --- a/net/sched/sch_cake.c +++ b/net/sched/sch_cake.c @@ -497,6 +497,52 @@ static bool cobalt_queue_empty(struct cobalt_vars *vars, return down; } +static __be16 cake_skb_proto(const struct sk_buff *skb) +{ + unsigned int offset = skb_mac_offset(skb) + sizeof(struct ethhdr); + __be16 proto = skb->protocol; + struct vlan_hdr vhdr, *vh; + + while (proto == htons(ETH_P_8021Q) || proto == htons(ETH_P_8021AD)) { + vh = skb_header_pointer(skb, offset, sizeof(vhdr), &vhdr); + if (!vh) + break; + + proto = vh->h_vlan_encapsulated_proto; + offset += sizeof(vhdr); + } + + return proto; +} + +static int cake_set_ce(struct sk_buff *skb) +{ + int wlen = skb_network_offset(skb); + + switch (cake_skb_proto(skb)) { + case htons(ETH_P_IP): + wlen += sizeof(struct iphdr); + if (!pskb_may_pull(skb, wlen) || + skb_try_make_writable(skb, wlen)) + return 0; + + return IP_ECN_set_ce(ip_hdr(skb)); + + case htons(ETH_P_IPV6): + wlen += sizeof(struct ipv6hdr); + if (!pskb_may_pull(skb, wlen) || + skb_try_make_writable(skb, wlen)) + return 0; + + return IP6_ECN_set_ce(skb, ipv6_hdr(skb)); + + default: + return 0; + } + + return 0; +} + /* Call this with a freshly dequeued packet for possible congestion marking. * Returns true as an instruction to drop the packet, false for delivery. */ @@ -549,7 +595,7 @@ static bool cobalt_should_drop(struct cobalt_vars *vars, if (next_due && vars->dropping) { /* Use ECN mark if possible, otherwise drop */ - drop = !(vars->ecn_marked = INET_ECN_set_ce(skb)); + drop = !(vars->ecn_marked = cake_set_ce(skb)); vars->count++; if (!vars->count) @@ -592,7 +638,7 @@ static bool cake_update_flowkeys(struct flow_keys *keys, bool rev = !skb->_nfct, upd = false; __be32 ip; - if (skb_protocol(skb, true) != htons(ETH_P_IP)) + if (cake_skb_proto(skb) != htons(ETH_P_IP)) return false; if (!nf_ct_get_tuple_skb(&tuple, skb)) @@ -1557,7 +1603,7 @@ static u8 cake_handle_diffserv(struct sk_buff *skb, bool wash) u16 *buf, buf_; u8 dscp; - switch (skb_protocol(skb, true)) { + switch (cake_skb_proto(skb)) { case htons(ETH_P_IP): buf = skb_header_pointer(skb, offset, sizeof(buf_), &buf_); if (unlikely(!buf))