diff --git a/include/linux/tnum.h b/include/linux/tnum.h
index 1c3948a1d6ad9b8b2fe4870c2d1321ea277adf66..3c13240077b87a70eb73559fcb8f17dd41a6da3f 100644
--- a/include/linux/tnum.h
+++ b/include/linux/tnum.h
@@ -106,6 +106,10 @@ int tnum_sbin(char *str, size_t size, struct tnum a);
 struct tnum tnum_subreg(struct tnum a);
 /* Returns the tnum with the lower 32-bit subreg cleared */
 struct tnum tnum_clear_subreg(struct tnum a);
+/* Returns the tnum with the lower 32-bit subreg in *reg* set to the lower
+ * 32-bit subreg in *subreg*
+ */
+struct tnum tnum_with_subreg(struct tnum reg, struct tnum subreg);
 /* Returns the tnum with the lower 32-bit subreg set to value */
 struct tnum tnum_const_subreg(struct tnum a, u32 value);
 /* Returns true if 32-bit subreg @a is a known constant*/
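
(Aside, not part of the patch: a tnum tracks per-bit knowledge of a register as a (value, mask) pair -- a bit is unknown iff its mask bit is set, otherwise the corresponding value bit holds its contents. Below is a minimal standalone sketch of the semantics tnum_with_subreg() is documented to have; the direct bit-splice and the names here are my own illustration, not the kernel implementation, which composes tnum_clear_subreg(), tnum_subreg() and tnum_or() as shown in the tnum.c hunk that follows.)

#include <assert.h>
#include <stdint.h>

struct tnum { uint64_t value; uint64_t mask; };

/* Keep the upper 32 bits of @reg and splice in the lower 32 bits of
 * @subreg, preserving per-bit known/unknown state on both halves.
 */
static struct tnum tnum_with_subreg(struct tnum reg, struct tnum subreg)
{
	const uint64_t lo = 0xffffffffULL;

	return (struct tnum){
		.value = (reg.value & ~lo) | (subreg.value & lo),
		.mask  = (reg.mask  & ~lo) | (subreg.mask  & lo),
	};
}

int main(void)
{
	/* upper half known to be 0xdeadbeef, lower half fully unknown */
	struct tnum reg    = { 0xdeadbeef00000000ULL, 0x00000000ffffffffULL };
	/* lower half known to be 0x1234 */
	struct tnum subreg = { 0x1234ULL, 0 };
	struct tnum res    = tnum_with_subreg(reg, subreg);

	assert(res.value == 0xdeadbeef00001234ULL && res.mask == 0);
	return 0;
}
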
diff --git a/kernel/bpf/tnum.c b/kernel/bpf/tnum.c
index 3d7127f439a142848c9c392c75dcb967d005a0f9..f4c91c9b27d7f175daf30c11afb9ef3c650a547f 100644
--- a/kernel/bpf/tnum.c
+++ b/kernel/bpf/tnum.c
@@ -208,7 +208,12 @@ struct tnum tnum_clear_subreg(struct tnum a)
 	return tnum_lshift(tnum_rshift(a, 32), 32);
 }
 
+struct tnum tnum_with_subreg(struct tnum reg, struct tnum subreg)
+{
+	return tnum_or(tnum_clear_subreg(reg), tnum_subreg(subreg));
+}
+
 struct tnum tnum_const_subreg(struct tnum a, u32 value)
 {
-	return tnum_or(tnum_clear_subreg(a), tnum_const(value));
+	return tnum_with_subreg(a, tnum_const(value));
 }
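
(Aside, not part of the patch: the BPF_JEQ arm in the verifier hunk below leans on tnum_intersect(). A standalone sketch of its behavior, mirroring my reading of kernel/bpf/tnum.c -- assuming both inputs describe the same runtime value, a bit known on either side is known in the result:)

#include <assert.h>
#include <stdint.h>

struct tnum { uint64_t value; uint64_t mask; };

/* Sketch of tnum_intersect(): keep every bit that either side knows */
static struct tnum tnum_intersect(struct tnum a, struct tnum b)
{
	uint64_t v = a.value | b.value;
	uint64_t mu = a.mask & b.mask;

	return (struct tnum){ v & ~mu, mu };
}

int main(void)
{
	/* a: low nibble known to be 0x4, high nibble unknown */
	struct tnum a = { 0x04, 0xf0 };
	/* b: high nibble known to be 0x30, low nibble unknown */
	struct tnum b = { 0x30, 0x0f };
	struct tnum r = tnum_intersect(a, b);

	/* each side contributes its known bits: r is exactly 0x34 */
	assert(r.value == 0x34 && r.mask == 0);
	return 0;
}
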
diff --git a/kernel/bpf/verifier.c b/kernel/bpf/verifier.c
index 9ae6eae1347164423583c200745c78d41d07d24a..39ce141c55d36b48d406a7dc11c280d48580adc5 100644
--- a/kernel/bpf/verifier.c
+++ b/kernel/bpf/verifier.c
@@ -14453,218 +14453,186 @@ static int is_branch_taken(struct bpf_reg_state *reg1, struct bpf_reg_state *reg
 	return is_scalar_branch_taken(reg1, reg2, opcode, is_jmp32);
 }
 
-/* Adjusts the register min/max values in the case that the dst_reg and
- * src_reg are both SCALAR_VALUE registers (or we are simply doing a BPF_K
- * check, in which case we havea fake SCALAR_VALUE representing insn->imm).
- * Technically we can do similar adjustments for pointers to the same object,
- * but we don't support that right now.
+/* Opcode that corresponds to a *false* branch condition.
+ * E.g., if r1 < r2, then the reverse (false) condition is r1 >= r2
  */
-static void reg_set_min_max(struct bpf_reg_state *true_reg1,
-			    struct bpf_reg_state *true_reg2,
-			    struct bpf_reg_state *false_reg1,
-			    struct bpf_reg_state *false_reg2,
-			    u8 opcode, bool is_jmp32)
+static u8 rev_opcode(u8 opcode)
 {
-	struct tnum false_32off, false_64off;
-	struct tnum true_32off, true_64off;
-	u64 uval;
-	u32 uval32;
-	s64 sval;
-	s32 sval32;
-
-	/* If either register is a pointer, we can't learn anything about its
-	 * variable offset from the compare (unless they were a pointer into
-	 * the same object, but we don't bother with that).
+	switch (opcode) {
+	case BPF_JEQ:		return BPF_JNE;
+	case BPF_JNE:		return BPF_JEQ;
+	/* JSET doesn't have its reverse opcode in BPF, so add
+	 * BPF_X flag to denote the reverse of that operation
 	 */
-	if (false_reg1->type != SCALAR_VALUE || false_reg2->type != SCALAR_VALUE)
-		return;
-
-	/* we expect right-hand registers (src ones) to be constants, for now */
-	if (!is_reg_const(false_reg2, is_jmp32)) {
-		opcode = flip_opcode(opcode);
-		swap(true_reg1, true_reg2);
-		swap(false_reg1, false_reg2);
+	case BPF_JSET:		return BPF_JSET | BPF_X;
+	case BPF_JSET | BPF_X:	return BPF_JSET;
+	case BPF_JGE:		return BPF_JLT;
+	case BPF_JGT:		return BPF_JLE;
+	case BPF_JLE:		return BPF_JGT;
+	case BPF_JLT:		return BPF_JGE;
+	case BPF_JSGE:		return BPF_JSLT;
+	case BPF_JSGT:		return BPF_JSLE;
+	case BPF_JSLE:		return BPF_JSGT;
+	case BPF_JSLT:		return BPF_JSGE;
+	default:		return 0;
 	}
-	if (!is_reg_const(false_reg2, is_jmp32))
-		return;
+}
 
-	false_32off = tnum_subreg(false_reg1->var_off);
-	false_64off = false_reg1->var_off;
-	true_32off = tnum_subreg(true_reg1->var_off);
-	true_64off = true_reg1->var_off;
-	uval = false_reg2->var_off.value;
-	uval32 = (u32)tnum_subreg(false_reg2->var_off).value;
-	sval = (s64)uval;
-	sval32 = (s32)uval32;
+/* Refine range knowledge for <reg1> <op> <reg2> conditional operation. */
+static void regs_refine_cond_op(struct bpf_reg_state *reg1, struct bpf_reg_state *reg2,
+				u8 opcode, bool is_jmp32)
+{
+	struct tnum t;
+	u64 val;
 
+again:
 	switch (opcode) {
-	/* JEQ/JNE comparison doesn't change the register equivalence.
-	 *
-	 * r1 = r2;
-	 * if (r1 == 42) goto label;
-	 * ...
-	 * label: // here both r1 and r2 are known to be 42.
-	 *
-	 * Hence when marking register as known preserve it's ID.
-	 */
 	case BPF_JEQ:
 		if (is_jmp32) {
-			__mark_reg32_known(true_reg1, uval32);
-			true_32off = tnum_subreg(true_reg1->var_off);
+			reg1->u32_min_value = max(reg1->u32_min_value, reg2->u32_min_value);
+			reg1->u32_max_value = min(reg1->u32_max_value, reg2->u32_max_value);
+			reg1->s32_min_value = max(reg1->s32_min_value, reg2->s32_min_value);
+			reg1->s32_max_value = min(reg1->s32_max_value, reg2->s32_max_value);
+			reg2->u32_min_value = reg1->u32_min_value;
+			reg2->u32_max_value = reg1->u32_max_value;
+			reg2->s32_min_value = reg1->s32_min_value;
+			reg2->s32_max_value = reg1->s32_max_value;
+
+			t = tnum_intersect(tnum_subreg(reg1->var_off), tnum_subreg(reg2->var_off));
+			reg1->var_off = tnum_with_subreg(reg1->var_off, t);
+			reg2->var_off = tnum_with_subreg(reg2->var_off, t);
 		} else {
-			___mark_reg_known(true_reg1, uval);
-			true_64off = true_reg1->var_off;
+			reg1->umin_value = max(reg1->umin_value, reg2->umin_value);
+			reg1->umax_value = min(reg1->umax_value, reg2->umax_value);
+			reg1->smin_value = max(reg1->smin_value, reg2->smin_value);
+			reg1->smax_value = min(reg1->smax_value, reg2->smax_value);
+			reg2->umin_value = reg1->umin_value;
+			reg2->umax_value = reg1->umax_value;
+			reg2->smin_value = reg1->smin_value;
+			reg2->smax_value = reg1->smax_value;
+
+			reg1->var_off = tnum_intersect(reg1->var_off, reg2->var_off);
+			reg2->var_off = reg1->var_off;
 		}
 		break;
 	case BPF_JNE:
-		if (is_jmp32) {
-			__mark_reg32_known(false_reg1, uval32);
-			false_32off = tnum_subreg(false_reg1->var_off);
-		} else {
-			___mark_reg_known(false_reg1, uval);
-			false_64off = false_reg1->var_off;
-		}
+		/* we don't derive any new information for inequality yet */
 		break;
 	case BPF_JSET:
+		if (!is_reg_const(reg2, is_jmp32))
+			swap(reg1, reg2);
+		if (!is_reg_const(reg2, is_jmp32))
+			break;
+		val = reg_const_value(reg2, is_jmp32);
+		/* BPF_JSET (i.e., TRUE branch, *not* BPF_JSET | BPF_X)
+		 * requires a single bit to learn something useful. E.g., if we
+		 * know that `r1 & 0x3` is true, then which bits (0, 1, or both)
+		 * are actually set? We can learn something definite only if
+		 * it's a single-bit value to begin with.
+		 *
+		 * BPF_JSET | BPF_X (i.e., negation of BPF_JSET) doesn't have
+		 * this restriction. I.e., !(r1 & 0x3) means neither bit 0 nor
+		 * bit 1 is set, which we can readily use in adjustments.
+		 */
+		if (!is_power_of_2(val))
+			break;
 		if (is_jmp32) {
-			false_32off = tnum_and(false_32off, tnum_const(~uval32));
-			if (is_power_of_2(uval32))
-				true_32off = tnum_or(true_32off,
-						     tnum_const(uval32));
+			t = tnum_or(tnum_subreg(reg1->var_off), tnum_const(val));
+			reg1->var_off = tnum_with_subreg(reg1->var_off, t);
 		} else {
-			false_64off = tnum_and(false_64off, tnum_const(~uval));
-			if (is_power_of_2(uval))
-				true_64off = tnum_or(true_64off,
-						     tnum_const(uval));
+			reg1->var_off = tnum_or(reg1->var_off, tnum_const(val));
 		}
 		break;
-	case BPF_JGE:
-	case BPF_JGT:
-	{
+	case BPF_JSET | BPF_X: /* reverse of BPF_JSET, see rev_opcode() */
+		if (!is_reg_const(reg2, is_jmp32))
+			swap(reg1, reg2);
+		if (!is_reg_const(reg2, is_jmp32))
+			break;
+		val = reg_const_value(reg2, is_jmp32);
 		if (is_jmp32) {
-			u32 false_umax = opcode == BPF_JGT ? uval32  : uval32 - 1;
-			u32 true_umin = opcode == BPF_JGT ? uval32 + 1 : uval32;
-
-			false_reg1->u32_max_value = min(false_reg1->u32_max_value,
-						       false_umax);
-			true_reg1->u32_min_value = max(true_reg1->u32_min_value,
-						      true_umin);
+			t = tnum_and(tnum_subreg(reg1->var_off), tnum_const(~val));
+			reg1->var_off = tnum_with_subreg(reg1->var_off, t);
 		} else {
-			u64 false_umax = opcode == BPF_JGT ? uval    : uval - 1;
-			u64 true_umin = opcode == BPF_JGT ? uval + 1 : uval;
-
-			false_reg1->umax_value = min(false_reg1->umax_value, false_umax);
-			true_reg1->umin_value = max(true_reg1->umin_value, true_umin);
+			reg1->var_off = tnum_and(reg1->var_off, tnum_const(~val));
 		}
 		break;
-	}
-	case BPF_JSGE:
-	case BPF_JSGT:
-	{
+	case BPF_JLE:
 		if (is_jmp32) {
-			s32 false_smax = opcode == BPF_JSGT ? sval32    : sval32 - 1;
-			s32 true_smin = opcode == BPF_JSGT ? sval32 + 1 : sval32;
-
-			false_reg1->s32_max_value = min(false_reg1->s32_max_value, false_smax);
-			true_reg1->s32_min_value = max(true_reg1->s32_min_value, true_smin);
+			reg1->u32_max_value = min(reg1->u32_max_value, reg2->u32_max_value);
+			reg2->u32_min_value = max(reg1->u32_min_value, reg2->u32_min_value);
 		} else {
-			s64 false_smax = opcode == BPF_JSGT ? sval    : sval - 1;
-			s64 true_smin = opcode == BPF_JSGT ? sval + 1 : sval;
-
-			false_reg1->smax_value = min(false_reg1->smax_value, false_smax);
-			true_reg1->smin_value = max(true_reg1->smin_value, true_smin);
+			reg1->umax_value = min(reg1->umax_value, reg2->umax_value);
+			reg2->umin_value = max(reg1->umin_value, reg2->umin_value);
 		}
 		break;
-	}
-	case BPF_JLE:
 	case BPF_JLT:
-	{
 		if (is_jmp32) {
-			u32 false_umin = opcode == BPF_JLT ? uval32  : uval32 + 1;
-			u32 true_umax = opcode == BPF_JLT ? uval32 - 1 : uval32;
-
-			false_reg1->u32_min_value = max(false_reg1->u32_min_value,
-						       false_umin);
-			true_reg1->u32_max_value = min(true_reg1->u32_max_value,
-						      true_umax);
+			reg1->u32_max_value = min(reg1->u32_max_value, reg2->u32_max_value - 1);
+			reg2->u32_min_value = max(reg1->u32_min_value + 1, reg2->u32_min_value);
 		} else {
-			u64 false_umin = opcode == BPF_JLT ? uval    : uval + 1;
-			u64 true_umax = opcode == BPF_JLT ? uval - 1 : uval;
-
-			false_reg1->umin_value = max(false_reg1->umin_value, false_umin);
-			true_reg1->umax_value = min(true_reg1->umax_value, true_umax);
+			reg1->umax_value = min(reg1->umax_value, reg2->umax_value - 1);
+			reg2->umin_value = max(reg1->umin_value + 1, reg2->umin_value);
 		}
 		break;
-	}
 	case BPF_JSLE:
+		if (is_jmp32) {
+			reg1->s32_max_value = min(reg1->s32_max_value, reg2->s32_max_value);
+			reg2->s32_min_value = max(reg1->s32_min_value, reg2->s32_min_value);
+		} else {
+			reg1->smax_value = min(reg1->smax_value, reg2->smax_value);
+			reg2->smin_value = max(reg1->smin_value, reg2->smin_value);
+		}
+		break;
 	case BPF_JSLT:
-	{
 		if (is_jmp32) {
-			s32 false_smin = opcode == BPF_JSLT ? sval32    : sval32 + 1;
-			s32 true_smax = opcode == BPF_JSLT ? sval32 - 1 : sval32;
-
-			false_reg1->s32_min_value = max(false_reg1->s32_min_value, false_smin);
-			true_reg1->s32_max_value = min(true_reg1->s32_max_value, true_smax);
+			reg1->s32_max_value = min(reg1->s32_max_value, reg2->s32_max_value - 1);
+			reg2->s32_min_value = max(reg1->s32_min_value + 1, reg2->s32_min_value);
 		} else {
-			s64 false_smin = opcode == BPF_JSLT ? sval    : sval + 1;
-			s64 true_smax = opcode == BPF_JSLT ? sval - 1 : sval;
-
-			false_reg1->smin_value = max(false_reg1->smin_value, false_smin);
-			true_reg1->smax_value = min(true_reg1->smax_value, true_smax);
+			reg1->smax_value = min(reg1->smax_value, reg2->smax_value - 1);
+			reg2->smin_value = max(reg1->smin_value + 1, reg2->smin_value);
 		}
 		break;
-	}
+	case BPF_JGE:
+	case BPF_JGT:
+	case BPF_JSGE:
+	case BPF_JSGT:
+		/* just reuse LE/LT logic above */
+		opcode = flip_opcode(opcode);
+		swap(reg1, reg2);
+		goto again;
 	default:
 		return;
 	}
-
-	if (is_jmp32) {
-		false_reg1->var_off = tnum_or(tnum_clear_subreg(false_64off),
-					     tnum_subreg(false_32off));
-		true_reg1->var_off = tnum_or(tnum_clear_subreg(true_64off),
-					    tnum_subreg(true_32off));
-		reg_bounds_sync(false_reg1);
-		reg_bounds_sync(true_reg1);
-	} else {
-		false_reg1->var_off = false_64off;
-		true_reg1->var_off = true_64off;
-		reg_bounds_sync(false_reg1);
-		reg_bounds_sync(true_reg1);
-	}
-}
-
-/* Regs are known to be equal, so intersect their min/max/var_off */
-static void __reg_combine_min_max(struct bpf_reg_state *src_reg,
-				  struct bpf_reg_state *dst_reg)
-{
-	src_reg->umin_value = dst_reg->umin_value = max(src_reg->umin_value,
-							dst_reg->umin_value);
-	src_reg->umax_value = dst_reg->umax_value = min(src_reg->umax_value,
-							dst_reg->umax_value);
-	src_reg->smin_value = dst_reg->smin_value = max(src_reg->smin_value,
-							dst_reg->smin_value);
-	src_reg->smax_value = dst_reg->smax_value = min(src_reg->smax_value,
-							dst_reg->smax_value);
-	src_reg->var_off = dst_reg->var_off = tnum_intersect(src_reg->var_off,
-							     dst_reg->var_off);
-	reg_bounds_sync(src_reg);
-	reg_bounds_sync(dst_reg);
 }
 
-static void reg_combine_min_max(struct bpf_reg_state *true_src,
-				struct bpf_reg_state *true_dst,
-				struct bpf_reg_state *false_src,
-				struct bpf_reg_state *false_dst,
-				u8 opcode)
+/* Adjusts the register min/max values in the case that the dst_reg and
+ * src_reg are both SCALAR_VALUE registers (or we are simply doing a BPF_K
+ * check, in which case we have a fake SCALAR_VALUE representing insn->imm).
+ * Technically we can do similar adjustments for pointers to the same object,
+ * but we don't support that right now.
+ */
+static void reg_set_min_max(struct bpf_reg_state *true_reg1,
+			    struct bpf_reg_state *true_reg2,
+			    struct bpf_reg_state *false_reg1,
+			    struct bpf_reg_state *false_reg2,
+			    u8 opcode, bool is_jmp32)
 {
-	switch (opcode) {
-	case BPF_JEQ:
-		__reg_combine_min_max(true_src, true_dst);
-		break;
-	case BPF_JNE:
-		__reg_combine_min_max(false_src, false_dst);
-		break;
-	}
+	/* If either register is a pointer, we can't learn anything about its
+	 * variable offset from the compare (unless they were a pointer into
+	 * the same object, but we don't bother with that).
+	 */
+	if (false_reg1->type != SCALAR_VALUE || false_reg2->type != SCALAR_VALUE)
+		return;
+
+	/* fallthrough (FALSE) branch */
+	regs_refine_cond_op(false_reg1, false_reg2, rev_opcode(opcode), is_jmp32);
+	reg_bounds_sync(false_reg1);
+	reg_bounds_sync(false_reg2);
+
+	/* jump (TRUE) branch */
+	regs_refine_cond_op(true_reg1, true_reg2, opcode, is_jmp32);
+	reg_bounds_sync(true_reg1);
+	reg_bounds_sync(true_reg2);
 }
 
 static void mark_ptr_or_null_reg(struct bpf_func_state *state,
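
(Aside, not part of the patch: a key property of regs_refine_cond_op() above is that it refines *both* registers even when neither is constant. A worked sketch of the 64-bit BPF_JLT arm, using hypothetical helper names and plain u64 bounds only -- the kernel also tracks signed and 32-bit bounds and tnums:)

#include <assert.h>
#include <stdint.h>

struct range { uint64_t umin, umax; };

static uint64_t min_u64(uint64_t a, uint64_t b) { return a < b ? a : b; }
static uint64_t max_u64(uint64_t a, uint64_t b) { return a > b ? a : b; }

/* Taken branch of `if (r1 < r2)`: r1 must stay below r2's maximum,
 * and r2 must stay above r1's minimum -- the same formulas as the
 * BPF_JLT arm above.
 */
static void refine_jlt(struct range *r1, struct range *r2)
{
	r1->umax = min_u64(r1->umax, r2->umax - 1);
	r2->umin = max_u64(r1->umin + 1, r2->umin);
}

int main(void)
{
	struct range r1 = { 40, 100 }, r2 = { 10, 60 };

	refine_jlt(&r1, &r2);
	/* both registers tighten: r1 = [40, 59], r2 = [41, 60] */
	assert(r1.umax == 59 && r2.umin == 41);
	return 0;
}

On the fallthrough branch, reg_set_min_max() would instead apply rev_opcode(BPF_JLT) == BPF_JGE, which the GE/GT arm flips back to BPF_JLE with the registers swapped.
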
@@ -14961,22 +14929,12 @@ static int check_cond_jmp_op(struct bpf_verifier_env *env,
 		reg_set_min_max(&other_branch_regs[insn->dst_reg],
 				&other_branch_regs[insn->src_reg],
 				dst_reg, src_reg, opcode, is_jmp32);
-
-		if (dst_reg->type == SCALAR_VALUE &&
-		    src_reg->type == SCALAR_VALUE &&
-		    !is_jmp32 && (opcode == BPF_JEQ || opcode == BPF_JNE)) {
-			/* Comparing for equality, we can combine knowledge */
-			reg_combine_min_max(&other_branch_regs[insn->src_reg],
-					    &other_branch_regs[insn->dst_reg],
-					    src_reg, dst_reg, opcode);
-		}
 	} else /* BPF_SRC(insn->code) == BPF_K */ {
 		reg_set_min_max(&other_branch_regs[insn->dst_reg],
 				src_reg /* fake one */,
 				dst_reg, src_reg /* same fake one */,
 				opcode, is_jmp32);
 	}
-
 	if (BPF_SRC(insn->code) == BPF_X &&
 	    src_reg->type == SCALAR_VALUE && src_reg->id &&
 	    !WARN_ON_ONCE(src_reg->id != other_branch_regs[insn->src_reg].id)) {
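
(Aside, not part of the patch: to make the JSET comment in the hunk above concrete, here is a standalone sketch, with the same toy tnum model and a hypothetical helper name, of why the *false* branch of `if (r1 & val)` is informative for any val while the *true* branch needs a single-bit val:)

#include <assert.h>
#include <stdint.h>

struct tnum { uint64_t value; uint64_t mask; };

/* False branch of `if (r1 & val)`: every bit in val is known zero,
 * regardless of how many bits val has -- mirrors the
 * BPF_JSET | BPF_X arm (tnum_and with tnum_const(~val)).
 */
static struct tnum refine_jset_false(struct tnum r, uint64_t val)
{
	r.value &= ~val;
	r.mask &= ~val;
	return r;
}

int main(void)
{
	struct tnum r = { 0, ~0ULL };	/* fully unknown register */

	/* !(r & 0x3) clears bits 0 and 1 even though 0x3 isn't a
	 * single bit; the true branch of (r & 0x3) would only say
	 * "bit 0 or bit 1 (or both) is set", which a tnum can't
	 * encode -- hence the is_power_of_2() gate on the BPF_JSET arm.
	 */
	r = refine_jset_false(r, 0x3);
	assert((r.value & 0x3) == 0 && (r.mask & 0x3) == 0);
	return 0;
}
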