diff --git a/spec/compiler/type_inference/def_overload_spec.cr b/spec/compiler/type_inference/def_overload_spec.cr
index 4769b8c71ed5ad5593a5b236c61846dfb2def24d..973e37ec9c5b2ac5c2a25a71f32960f397fc2ffe 100755
--- a/spec/compiler/type_inference/def_overload_spec.cr
+++ b/spec/compiler/type_inference/def_overload_spec.cr
@@ -532,4 +532,24 @@ describe "Type inference: def overload" do
       ",
       "no overload matches"
   end
+
+  it "gets free variable from union restriction" do
+    assert_type("
+      def foo(x : Nil | U)
+        U
+      end
+
+      foo(1 || nil)
+      ") { int32.metaclass }
+  end
+
+  it "gets free variable from union restriction without a union" do
+    assert_type("
+      def foo(x : Nil | U)
+        U
+      end
+
+      foo(1)
+      ") { int32.metaclass }
+  end
 end
diff --git a/src/compiler/crystal/type_inference/restrictions.cr b/src/compiler/crystal/type_inference/restrictions.cr
index 9cb659169b1833236c327106fbe4532c19e9f0c3..17f50ef183c007f3a91f39344dfe47637ae9431f 100644
--- a/src/compiler/crystal/type_inference/restrictions.cr
+++ b/src/compiler/crystal/type_inference/restrictions.cr
@@ -201,6 +201,20 @@ module Crystal
       self == type || union_types.any? &.is_restriction_of?(type, owner)
     end
 
+    def restrict(other : Union, owner, type_lookup, free_vars)
+      types = [] of Type
+      other.types.each do |other_type|
+        self.union_types.each do |type|
+          restricted = type.restrict(other_type, owner, type_lookup, free_vars)
+          if restricted
+            types << restricted
+            break
+          end
+        end
+      end
+      program.type_merge_union_of(types)
+    end
+
     def restrict(other : Type | Generic, owner, type_lookup, free_vars)
       types = [] of Type
       union_types.each do |type|