diff --git a/kernel/bpf/helpers.c b/kernel/bpf/helpers.c
index 229396172026a4ecc035c5fa7fffcbb9ddda24a3..5241ba671c5a3cadf6c92e617a8cb13e806be541 100644
--- a/kernel/bpf/helpers.c
+++ b/kernel/bpf/helpers.c
@@ -2734,7 +2734,7 @@ __bpf_kfunc int bpf_wq_start(struct bpf_wq *wq, unsigned int flags)
 }
 
 __bpf_kfunc int bpf_wq_set_callback_impl(struct bpf_wq *wq,
-					 int (callback_fn)(void *map, int *key, struct bpf_wq *wq),
+					 int (callback_fn)(void *map, int *key, void *value),
 					 unsigned int flags,
 					 void *aux__ign)
 {
diff --git a/tools/testing/selftests/bpf/bpf_experimental.h b/tools/testing/selftests/bpf/bpf_experimental.h
index eede6fc2ccb4043d78591046a0e2bc7c09c014d5..828556cdc2f0d461c994821655b31f817a114553 100644
--- a/tools/testing/selftests/bpf/bpf_experimental.h
+++ b/tools/testing/selftests/bpf/bpf_experimental.h
@@ -552,7 +552,7 @@ extern void bpf_iter_css_destroy(struct bpf_iter_css *it) __weak __ksym;
 extern int bpf_wq_init(struct bpf_wq *wq, void *p__map, unsigned int flags) __weak __ksym;
 extern int bpf_wq_start(struct bpf_wq *wq, unsigned int flags) __weak __ksym;
 extern int bpf_wq_set_callback_impl(struct bpf_wq *wq,
-		int (callback_fn)(void *map, int *key, struct bpf_wq *wq),
+		int (callback_fn)(void *map, int *key, void *value),
 		unsigned int flags__k, void *aux__ign) __ksym;
 #define bpf_wq_set_callback(timer, cb, flags) \
 	bpf_wq_set_callback_impl(timer, cb, flags, NULL)
diff --git a/tools/testing/selftests/bpf/progs/wq.c b/tools/testing/selftests/bpf/progs/wq.c
index 49e712acbf60042f91e2ff4f49cd7920d358426f..f8d3ae0c29aeb3da58c604f6a2eb6aeb3c90e365 100644
--- a/tools/testing/selftests/bpf/progs/wq.c
+++ b/tools/testing/selftests/bpf/progs/wq.c
@@ -32,6 +32,7 @@ struct {
 } hmap_malloc SEC(".maps");
 
 struct elem {
+	int ok_offset;
 	struct bpf_wq w;
 };
 
@@ -53,7 +54,7 @@ __u32 ok;
 __u32 ok_sleepable;
 
 static int test_elem_callback(void *map, int *key,
-		int (callback_fn)(void *map, int *key, struct bpf_wq *wq))
+		int (callback_fn)(void *map, int *key, void *value))
 {
 	struct elem init = {}, *val;
 	struct bpf_wq *wq;
@@ -70,6 +71,8 @@ static int test_elem_callback(void *map, int *key,
 	if (!val)
 		return -2;
 
+	val->ok_offset = *key;
+
 	wq = &val->w;
 	if (bpf_wq_init(wq, map, 0) != 0)
 		return -3;
@@ -84,7 +87,7 @@ static int test_elem_callback(void *map, int *key,
 }
 
 static int test_hmap_elem_callback(void *map, int *key,
-		int (callback_fn)(void *map, int *key, struct bpf_wq *wq))
+		int (callback_fn)(void *map, int *key, void *value))
 {
 	struct hmap_elem init = {}, *val;
 	struct bpf_wq *wq;
@@ -114,7 +117,7 @@ static int test_hmap_elem_callback(void *map, int *key,
 }
 
 /* callback for non sleepable workqueue */
-static int wq_callback(void *map, int *key, struct bpf_wq *work)
+static int wq_callback(void *map, int *key, void *value)
 {
 	bpf_kfunc_common_test();
 	ok |= (1 << *key);
@@ -122,10 +125,16 @@ static int wq_callback(void *map, int *key, struct bpf_wq *work)
 }
 
 /* callback for sleepable workqueue */
-static int wq_cb_sleepable(void *map, int *key, struct bpf_wq *work)
+static int wq_cb_sleepable(void *map, int *key, void *value)
 {
+	struct elem *data = (struct elem *)value;
+	int offset = data->ok_offset;
+
+	if (*key != offset)
+		return 0;
+
 	bpf_kfunc_call_test_sleepable();
-	ok_sleepable |= (1 << *key);
+	ok_sleepable |= (1 << offset);
 	return 0;
 }
 
diff --git a/tools/testing/selftests/bpf/progs/wq_failures.c b/tools/testing/selftests/bpf/progs/wq_failures.c
index 4cbdb425f223d3354619a612c486432b321e59ed..25b51a72fe0fe6d0f1253e22e204aaf6c1946fc6 100644
--- a/tools/testing/selftests/bpf/progs/wq_failures.c
+++ b/tools/testing/selftests/bpf/progs/wq_failures.c
@@ -28,14 +28,14 @@ struct {
 } lru SEC(".maps");
 
 /* callback for non sleepable workqueue */
-static int wq_callback(void *map, int *key, struct bpf_wq *work)
+static int wq_callback(void *map, int *key, void *value)
 {
 	bpf_kfunc_common_test();
 	return 0;
 }
 
 /* callback for sleepable workqueue */
-static int wq_cb_sleepable(void *map, int *key, struct bpf_wq *work)
+static int wq_cb_sleepable(void *map, int *key, void *value)
 {
 	bpf_kfunc_call_test_sleepable();
 	return 0;