diff --git a/drivers/vdpa/vdpa.c b/drivers/vdpa/vdpa.c
index 9846c9de4bfa2007f1d78ef3348bdac0004607e3..1ea525433a5ca17781c2f8180a2d6ae6c77842ce 100644
--- a/drivers/vdpa/vdpa.c
+++ b/drivers/vdpa/vdpa.c
@@ -393,7 +393,7 @@ static void vdpa_get_config_unlocked(struct vdpa_device *vdev,
 	 * If it does happen we assume a legacy guest.
 	 */
 	if (!vdev->features_valid)
-		vdpa_set_features(vdev, 0, true);
+		vdpa_set_features_unlocked(vdev, 0);
 	ops->get_config(vdev, offset, buf, len);
 }
 
diff --git a/drivers/vhost/vdpa.c b/drivers/vhost/vdpa.c
index 851539807bc9b53f5d18943fdba81608aa1cba79..ec5249e8c32d9d31efd707163632d8253200aa59 100644
--- a/drivers/vhost/vdpa.c
+++ b/drivers/vhost/vdpa.c
@@ -286,7 +286,7 @@ static long vhost_vdpa_set_features(struct vhost_vdpa *v, u64 __user *featurep)
 	if (copy_from_user(&features, featurep, sizeof(features)))
 		return -EFAULT;
 
-	if (vdpa_set_features(vdpa, features, false))
+	if (vdpa_set_features(vdpa, features))
 		return -EINVAL;
 
 	return 0;
diff --git a/drivers/virtio/virtio_vdpa.c b/drivers/virtio/virtio_vdpa.c
index 7767a7f0119b25196a3cf37ac52f542b100d2e88..76504559bc25f9c8d3fd73f35a23c99c1fca5983 100644
--- a/drivers/virtio/virtio_vdpa.c
+++ b/drivers/virtio/virtio_vdpa.c
@@ -317,7 +317,7 @@ static int virtio_vdpa_finalize_features(struct virtio_device *vdev)
 	/* Give virtio_ring a chance to accept features. */
 	vring_transport_features(vdev);
 
-	return vdpa_set_features(vdpa, vdev->features, false);
+	return vdpa_set_features(vdpa, vdev->features);
 }
 
 static const char *virtio_vdpa_bus_name(struct virtio_device *vdev)
diff --git a/include/linux/vdpa.h b/include/linux/vdpa.h
index 2de442ececae47d543f1aa03e636b667900c3b16..721089bb4c8490b685ec356f35cf66e7b60f93e0 100644
--- a/include/linux/vdpa.h
+++ b/include/linux/vdpa.h
@@ -401,18 +401,24 @@ static inline int vdpa_reset(struct vdpa_device *vdev)
 	return ret;
 }
 
-static inline int vdpa_set_features(struct vdpa_device *vdev, u64 features, bool locked)
+static inline int vdpa_set_features_unlocked(struct vdpa_device *vdev, u64 features)
 {
 	const struct vdpa_config_ops *ops = vdev->config;
 	int ret;
 
-	if (!locked)
-		mutex_lock(&vdev->cf_mutex);
-
 	vdev->features_valid = true;
 	ret = ops->set_driver_features(vdev, features);
-	if (!locked)
-		mutex_unlock(&vdev->cf_mutex);
+
+	return ret;
+}
+
+static inline int vdpa_set_features(struct vdpa_device *vdev, u64 features)
+{
+	int ret;
+
+	mutex_lock(&vdev->cf_mutex);
+	ret = vdpa_set_features_unlocked(vdev, features);
+	mutex_unlock(&vdev->cf_mutex);
 
 	return ret;
 }