diff --git a/net/mptcp/protocol.c b/net/mptcp/protocol.c
index 628cd60c9d0f93ac79d6f0b415cc08cbfae8ad44..667e153e6e240f588bc0f2cac68148c0728f0fc5 100644
--- a/net/mptcp/protocol.c
+++ b/net/mptcp/protocol.c
@@ -3404,9 +3404,20 @@ static int mptcp_stream_connect(struct socket *sock, struct sockaddr *uaddr,
 	struct mptcp_sock *msk = mptcp_sk(sock->sk);
 	struct mptcp_subflow_context *subflow;
 	struct socket *ssock;
-	int err;
+	int err = -EINVAL;
 
 	lock_sock(sock->sk);
+	if (uaddr) {
+		if (addr_len < sizeof(uaddr->sa_family))
+			goto unlock;
+
+		if (uaddr->sa_family == AF_UNSPEC) {
+			err = mptcp_disconnect(sock->sk, flags);
+			sock->state = err ? SS_DISCONNECTING : SS_UNCONNECTED;
+			goto unlock;
+		}
+	}
+
 	if (sock->state != SS_UNCONNECTED && msk->subflow) {
 		/* pending connection or invalid state, let existing subflow
 		 * cope with that
@@ -3416,10 +3427,8 @@ static int mptcp_stream_connect(struct socket *sock, struct sockaddr *uaddr,
 	}
 
 	ssock = __mptcp_nmpc_socket(msk);
-	if (!ssock) {
-		err = -EINVAL;
+	if (!ssock)
 		goto unlock;
-	}
 
 	mptcp_token_destroy(msk);
 	inet_sk_state_store(sock->sk, TCP_SYN_SENT);