diff --git a/net/ipv4/tcp_input.c b/net/ipv4/tcp_input.c
index e0f54b9be85021c76790697329d3c4b0a4d04b97..ff9ab3d01ced89570903d3a9f649a637c5e07a90 100644
--- a/net/ipv4/tcp_input.c
+++ b/net/ipv4/tcp_input.c
@@ -5998,6 +5998,11 @@ static bool tcp_validate_incoming(struct sock *sk, struct sk_buff *skb,
 	 * RFC 5961 4.2 : Send a challenge ack
 	 */
 	if (th->syn) {
+		if (sk->sk_state == TCP_SYN_RECV && sk->sk_socket && th->ack &&
+		    TCP_SKB_CB(skb)->seq + 1 == TCP_SKB_CB(skb)->end_seq &&
+		    TCP_SKB_CB(skb)->seq + 1 == tp->rcv_nxt &&
+		    TCP_SKB_CB(skb)->ack_seq == tp->snd_nxt)
+			goto pass;
 syn_challenge:
 		if (syn_inerr)
 			TCP_INC_STATS(sock_net(sk), TCP_MIB_INERRS);
@@ -6007,6 +6012,7 @@ static bool tcp_validate_incoming(struct sock *sk, struct sk_buff *skb,
 		goto discard;
 	}
 
+pass:
 	bpf_skops_parse_hdr(sk, skb);
 
 	return true;
@@ -6813,6 +6819,9 @@ tcp_rcv_state_process(struct sock *sk, struct sk_buff *skb)
 		tcp_fast_path_on(tp);
 		if (sk->sk_shutdown & SEND_SHUTDOWN)
 			tcp_shutdown(sk, SEND_SHUTDOWN);
+
+		if (sk->sk_socket)
+			goto consume;
 		break;
 
 	case TCP_FIN_WAIT1: {
diff --git a/tools/testing/selftests/net/tcp_ao/self-connect.c b/tools/testing/selftests/net/tcp_ao/self-connect.c
index e154d9e198a993c45d16a71133d106046f83d3f9..a5698b0a37188ece1490242f8f7d71f81c9b6b85 100644
--- a/tools/testing/selftests/net/tcp_ao/self-connect.c
+++ b/tools/testing/selftests/net/tcp_ao/self-connect.c
@@ -30,8 +30,6 @@ static void setup_lo_intf(const char *lo_intf)
 static void tcp_self_connect(const char *tst, unsigned int port,
 			     bool different_keyids, bool check_restore)
 {
-	uint64_t before_challenge_ack, after_challenge_ack;
-	uint64_t before_syn_challenge, after_syn_challenge;
 	struct tcp_ao_counters before_ao, after_ao;
 	uint64_t before_aogood, after_aogood;
 	struct netstat *ns_before, *ns_after;
@@ -62,8 +60,6 @@ static void tcp_self_connect(const char *tst, unsigned int port,
 
 	ns_before = netstat_read();
 	before_aogood = netstat_get(ns_before, "TCPAOGood", NULL);
-	before_challenge_ack = netstat_get(ns_before, "TCPChallengeACK", NULL);
-	before_syn_challenge = netstat_get(ns_before, "TCPSYNChallenge", NULL);
 	if (test_get_tcp_ao_counters(sk, &before_ao))
 		test_error("test_get_tcp_ao_counters()");
 
@@ -82,8 +78,6 @@ static void tcp_self_connect(const char *tst, unsigned int port,
 
 	ns_after = netstat_read();
 	after_aogood = netstat_get(ns_after, "TCPAOGood", NULL);
-	after_challenge_ack = netstat_get(ns_after, "TCPChallengeACK", NULL);
-	after_syn_challenge = netstat_get(ns_after, "TCPSYNChallenge", NULL);
 	if (test_get_tcp_ao_counters(sk, &after_ao))
 		test_error("test_get_tcp_ao_counters()");
 	if (!check_restore) {
@@ -98,18 +92,6 @@ static void tcp_self_connect(const char *tst, unsigned int port,
 		close(sk);
 		return;
 	}
-	if (after_challenge_ack <= before_challenge_ack ||
-	    after_syn_challenge <= before_syn_challenge) {
-		/*
-		 * It's also meant to test simultaneous open, so check
-		 * these counters as well.
-		 */
-		test_fail("%s: Didn't challenge SYN or ACK: %zu <= %zu OR %zu <= %zu",
-			  tst, after_challenge_ack, before_challenge_ack,
-			  after_syn_challenge, before_syn_challenge);
-		close(sk);
-		return;
-	}
 
 	if (test_tcp_ao_counters_cmp(tst, &before_ao, &after_ao, TEST_CNT_GOOD)) {
 		close(sk);