diff --git a/net/sched/sch_cake.c b/net/sched/sch_cake.c
index 1bfdfacb6be8a5d693593c90c08e4fd027d135f3..2116d7d4f4c07879c9dd272abe802fa0f1fcdd29 100644
--- a/net/sched/sch_cake.c
+++ b/net/sched/sch_cake.c
@@ -497,6 +497,52 @@ static bool cobalt_queue_empty(struct cobalt_vars *vars,
 	return down;
 }
 
+static __be16 cake_skb_proto(const struct sk_buff *skb)
+{
+	unsigned int offset = skb_mac_offset(skb) + sizeof(struct ethhdr);
+	__be16 proto = skb->protocol;
+	struct vlan_hdr vhdr, *vh;
+
+	while (proto == htons(ETH_P_8021Q) || proto == htons(ETH_P_8021AD)) {
+		vh = skb_header_pointer(skb, offset, sizeof(vhdr), &vhdr);
+		if (!vh)
+			break;
+
+		proto = vh->h_vlan_encapsulated_proto;
+		offset += sizeof(vhdr);
+	}
+
+	return proto;
+}
+
+static int cake_set_ce(struct sk_buff *skb)
+{
+	int wlen = skb_network_offset(skb);
+
+	switch (cake_skb_proto(skb)) {
+	case htons(ETH_P_IP):
+		wlen += sizeof(struct iphdr);
+		if (!pskb_may_pull(skb, wlen) ||
+		    skb_try_make_writable(skb, wlen))
+			return 0;
+
+		return IP_ECN_set_ce(ip_hdr(skb));
+
+	case htons(ETH_P_IPV6):
+		wlen += sizeof(struct ipv6hdr);
+		if (!pskb_may_pull(skb, wlen) ||
+		    skb_try_make_writable(skb, wlen))
+			return 0;
+
+		return IP6_ECN_set_ce(skb, ipv6_hdr(skb));
+
+	default:
+		return 0;
+	}
+
+	return 0;
+}
+
 /* Call this with a freshly dequeued packet for possible congestion marking.
  * Returns true as an instruction to drop the packet, false for delivery.
  */
@@ -549,7 +595,7 @@ static bool cobalt_should_drop(struct cobalt_vars *vars,
 
 	if (next_due && vars->dropping) {
 		/* Use ECN mark if possible, otherwise drop */
-		drop = !(vars->ecn_marked = INET_ECN_set_ce(skb));
+		drop = !(vars->ecn_marked = cake_set_ce(skb));
 
 		vars->count++;
 		if (!vars->count)
@@ -592,7 +638,7 @@ static bool cake_update_flowkeys(struct flow_keys *keys,
 	bool rev = !skb->_nfct, upd = false;
 	__be32 ip;
 
-	if (skb_protocol(skb, true) != htons(ETH_P_IP))
+	if (cake_skb_proto(skb) != htons(ETH_P_IP))
 		return false;
 
 	if (!nf_ct_get_tuple_skb(&tuple, skb))
@@ -1557,7 +1603,7 @@ static u8 cake_handle_diffserv(struct sk_buff *skb, bool wash)
 	u16 *buf, buf_;
 	u8 dscp;
 
-	switch (skb_protocol(skb, true)) {
+	switch (cake_skb_proto(skb)) {
 	case htons(ETH_P_IP):
 		buf = skb_header_pointer(skb, offset, sizeof(buf_), &buf_);
 		if (unlikely(!buf))