diff --git a/spec/compiler/codegen/class_spec.cr b/spec/compiler/codegen/class_spec.cr
index 64102c63c7665dcff6b73c463d6c6c1b74a31a06..7ff697d028f5296e42c62b357cc26c316c72a2c2 100755
--- a/spec/compiler/codegen/class_spec.cr
+++ b/spec/compiler/codegen/class_spec.cr
@@ -235,4 +235,21 @@ describe "Code gen: class" do
       a
       ").to_i.should eq(2)
   end
+
+  it "assigns type to reference union type" do
+    run("
+      class Foo
+        def initialize(@x)
+        end
+        def x=(@x); end
+      end
+
+      class Bar; end
+      class Baz; end
+
+      f = Foo.new(Bar.new)
+      f.x = Baz.new
+      1
+      ").to_i.should eq(1)
+  end
 end
diff --git a/spec/compiler/codegen/def_spec.cr b/spec/compiler/codegen/def_spec.cr
index 4e2d9487976ef9ff6e99db14b8bd98aafcc6b8c9..ce4677f517693a63e7928e9e386598c5f8a6c7ac 100755
--- a/spec/compiler/codegen/def_spec.cr
+++ b/spec/compiler/codegen/def_spec.cr
@@ -338,4 +338,15 @@ describe "Code gen: def" do
       foo.nil?
       ").to_b.should be_true
   end
+
+  it "codegens dispatch with nilable reference union type" do
+    run("
+      struct Nil; def object_id; 0_u64; end; end
+      class Foo; end
+      class Bar; end
+
+      f = 1 == 1 ? nil : (Foo.new || Bar.new)
+      f.object_id
+      ").to_i.should eq(0)
+  end
 end
diff --git a/spec/compiler/codegen/is_a_spec.cr b/spec/compiler/codegen/is_a_spec.cr
index 0677dbd3b210b75fbc71a68b40a1dab68216f447..59df33adfb4585b93b7a6c758bcabe4d6b422aa4 100755
--- a/spec/compiler/codegen/is_a_spec.cr
+++ b/spec/compiler/codegen/is_a_spec.cr
@@ -377,4 +377,40 @@ describe "Codegen: is_a?" do
       end
       ").to_i.should eq(1)
   end
+
+  it "codegens is_a? from nilable reference union type to nil" do
+    run("
+      class Foo
+      end
+
+      class Bar
+      end
+
+      a = Foo.new || Bar.new || nil
+      if a.is_a?(Nil)
+        b = a
+        1
+      else
+        2
+      end
+      ").to_i.should eq(2)
+  end
+
+  it "codegens is_a? from nilable reference union type to type" do
+    run("
+      class Foo
+      end
+
+      class Bar
+      end
+
+      a = Foo.new || Bar.new || nil
+      if a.is_a?(Foo)
+        b = a
+        1
+      else
+        2
+      end
+      ").to_i.should eq(1)
+  end
 end
diff --git a/src/compiler/crystal/codegen.cr b/src/compiler/crystal/codegen.cr
index 7e573055b09ebda15c5273fbbcf584b2ef328b43..64c000ed6e3b3221a9b79650d976c725a8fbd285 100644
--- a/src/compiler/crystal/codegen.cr
+++ b/src/compiler/crystal/codegen.cr
@@ -846,7 +846,9 @@ module Crystal
         case break_type
         when NilableType
           context.break_table.not_nil!.add insert_block, to_nilable(@last, break_type, control_expression_type(node))
-        when UnionType
+        when NilableReferenceUnionType
+          context.break_table.not_nil!.add insert_block, to_nilable(@last, break_type, control_expression_type(node))
+        when MixedUnionType
           break_union = context.break_union.not_nil!
           assign(break_union, break_type, control_expression_type(node), @last)
         else
@@ -915,6 +917,10 @@ module Crystal
         case node.type
         when NilableType
           value = @codegen.to_nilable(value, node.type, type)
+        when NilableReferenceUnionType
+          value = @codegen.to_nilable(value, node.type, type)
+        when ReferenceUnionType
+          value = cast_to value, node.type
         when HierarchyType
           value = cast_to value, node.type
         else
@@ -1349,7 +1355,7 @@ module Crystal
               context.vars = context.vars.dup
               get_exception_fun = main_fun(GET_EXCEPTION_NAME)
               exception_ptr = call get_exception_fun, [bit_cast(unwind_ex_obj, type_of(get_exception_fun.get_param(0)))]
-              exception = int2ptr exception_ptr, LLVMTyper::HIERARCHY_LLVM_TYPE
+              exception = int2ptr exception_ptr, LLVMTyper::TYPE_ID_POINTER
               context.vars[a_rescue_name] = LLVMVar.new(exception, a_rescue.type, true)
             end
 
@@ -1807,7 +1813,25 @@ module Crystal
       case type
       when NilableType
         @builder.select null_pointer?(value), int(@mod.nil.type_id), int(type.not_nil_type.type_id)
-      when UnionType
+      when ReferenceUnionType
+        load(value)
+      when NilableReferenceUnionType
+        nil_block, not_nil_block, exit_block = new_blocks ["nil", "not_nil", "exit"]
+        phi_table = LLVM::PhiTable.new
+
+        cond null_pointer?(value), nil_block, not_nil_block
+
+        position_at_end nil_block
+        phi_table.add insert_block, int(@mod.nil.type_id)
+        br exit_block
+
+        position_at_end not_nil_block
+        phi_table.add insert_block, load(value)
+        br exit_block
+
+        position_at_end exit_block
+        phi LLVM::Int32, phi_table
+      when MixedUnionType
         load(union_type_id(value))
       when HierarchyType
         load(value)
@@ -1833,11 +1857,11 @@ module Crystal
         llvm_false
       when BoolType
         @last
-      when NilableType, PointerInstanceType
-        not_null_pointer? @last
       when TypeDefType
         codegen_cond type.typedef
-      when UnionType
+      when NilableType, NilableReferenceUnionType, PointerInstanceType
+        not_null_pointer? @last
+      when MixedUnionType
         has_nil = type.union_types.any? &.nil_type?
         has_bool = type.union_types.any? &.bool_type?
 
@@ -1888,32 +1912,44 @@ module Crystal
       end
     end
 
-    def assign_distinct(target_pointer, target_type : HierarchyTypeMetaclass, value_type : Metaclass, value)
+    def assign_distinct(target_pointer, target_type : NilableType, value_type : Type, value)
+      store to_nilable(value, target_type, value_type), target_pointer
+    end
+
+    def assign_distinct(target_pointer, target_type : ReferenceUnionType, value_type : ReferenceUnionType, value)
       store value, target_pointer
     end
 
-    def assign_distinct(target_pointer, target_type : NilableType, value_type : Type, value)
+    def assign_distinct(target_pointer, target_type : ReferenceUnionType, value_type : HierarchyType, value)
+      store value, target_pointer
+    end
+
+    def assign_distinct(target_pointer, target_type : ReferenceUnionType, value_type : Type, value)
+      store cast_to(value, target_type), target_pointer
+    end
+
+    def assign_distinct(target_pointer, target_type : NilableReferenceUnionType, value_type : Type, value)
       store to_nilable(value, target_type, value_type), target_pointer
     end
 
-    def assign_distinct(target_pointer, target_type : UnionType, value_type : UnionType, value)
+    def assign_distinct(target_pointer, target_type : MixedUnionType, value_type : MixedUnionType, value)
       casted_value = cast_to_pointer value, target_type
       store load(casted_value), target_pointer
     end
 
-    def assign_distinct(target_pointer, target_type : UnionType, value_type : NilableType, value)
+    def assign_distinct(target_pointer, target_type : MixedUnionType, value_type : NilableType, value)
       store_in_union target_pointer, value_type, value
     end
 
-    def assign_distinct(target_pointer, target_type : UnionType, value_type : VoidType, value)
+    def assign_distinct(target_pointer, target_type : MixedUnionType, value_type : VoidType, value)
       store int(value_type.type_id), union_type_id(target_pointer)
     end
 
-    def assign_distinct(target_pointer, target_type : UnionType, value_type : Type, value)
+    def assign_distinct(target_pointer, target_type : MixedUnionType, value_type : Type, value)
       store_in_union target_pointer, value_type, to_rhs(value, value_type)
     end
 
-    def assign_distinct(target_pointer, target_type : HierarchyType, value_type : UnionType, value)
+    def assign_distinct(target_pointer, target_type : HierarchyType, value_type : MixedUnionType, value)
       casted_value = cast_to_pointer(union_value(value), target_type)
       store load(casted_value), target_pointer
     end
@@ -1922,8 +1958,12 @@ module Crystal
       store cast_to(value, target_type), target_pointer
     end
 
+    def assign_distinct(target_pointer, target_type : HierarchyTypeMetaclass, value_type : Metaclass, value)
+      store value, target_pointer
+    end
+
     def assign_distinct(target_pointer, target_type : Type, value_type : Type, value)
-      raise "Bug: trying to assign #{target_type} = #{value_type}"
+      raise "Bug: trying to assign #{target_type} <- #{value_type}"
     end
 
     def cast_value(value, to_type, from_type, already_loaded = false)
@@ -1936,14 +1976,6 @@ module Crystal
       end
     end
 
-    def cast_value_distinct(value, to_type, from_type : NilableType, already_loaded)
-      if to_type.nil_type?
-        value = llvm_nil
-      else
-        already_loaded ? value : load value
-      end
-    end
-
     def cast_value_distinct(value, to_type, from_type : Metaclass | GenericClassInstanceMetaclass | HierarchyTypeMetaclass, already_loaded)
       value
     end
@@ -1952,30 +1984,76 @@ module Crystal
       already_loaded ? value : load value
     end
 
-    def cast_value_distinct(value, to_type : UnionType, from_type : HierarchyType, already_loaded)
-      # This happens if the restriction is a UnionType:
+    def cast_value_distinct(value, to_type : MixedUnionType, from_type : HierarchyType, already_loaded)
+      # This happens if the restriction is a union:
       # we keep each of the union types as the result, we don't fully merge
       union_ptr = alloca llvm_type(to_type)
       store_in_union union_ptr, from_type, (already_loaded ? value : load value)
       union_ptr
     end
 
-    def cast_value_distinct(value, to_type : UnionType, from_type : UnionType, already_loaded)
+    def cast_value_distinct(value, to_type : ReferenceUnionType, from_type : HierarchyType, already_loaded)
+      # This happens if the restriction is a union:
+      # we keep each of the union types as the result, we don't fully merge
+      already_loaded ? value : load value
+    end
+
+    def cast_value_distinct(value, to_type : NilType, from_type : NilableType, already_loaded)
+      llvm_nil
+    end
+
+    def cast_value_distinct(value, to_type : Type, from_type : NilableType, already_loaded)
+      already_loaded ? value : load value
+    end
+
+    def cast_value_distinct(value, to_type : ReferenceUnionType, from_type : ReferenceUnionType, already_loaded)
+      already_loaded ? value : load value
+    end
+
+    def cast_value_distinct(value, to_type : HierarchyType, from_type : ReferenceUnionType, already_loaded)
+      already_loaded ? value : load value
+    end
+
+    def cast_value_distinct(value, to_type : Type, from_type : ReferenceUnionType, already_loaded)
+      cast_to (already_loaded ? value : load value), to_type
+    end
+
+    def cast_value_distinct(value, to_type : HierarchyType, from_type : NilableReferenceUnionType, already_loaded)
+      already_loaded ? value : load value
+    end
+
+    def cast_value_distinct(value, to_type : ReferenceUnionType, from_type : NilableReferenceUnionType, already_loaded)
+      already_loaded ? value : load value
+    end
+
+    def cast_value_distinct(value, to_type : NilableType, from_type : NilableReferenceUnionType, already_loaded)
+      cast_to (already_loaded ? value : load value), to_type
+    end
+
+    def cast_value_distinct(value, to_type : NilType, from_type : NilableReferenceUnionType, already_loaded)
+      llvm_nil
+    end
+
+    def cast_value_distinct(value, to_type : Type, from_type : NilableReferenceUnionType, already_loaded)
+      cast_to (already_loaded ? value : load value), to_type
+    end
+
+    def cast_value_distinct(value, to_type : MixedUnionType, from_type : MixedUnionType, already_loaded)
       cast_to_pointer value, to_type
     end
 
-    def cast_value_distinct(value, to_type : NilableType, from_type : UnionType, already_loaded)
+    def cast_value_distinct(value, to_type : NilableType, from_type : MixedUnionType, already_loaded)
       load cast_to_pointer(union_value(value), to_type)
     end
 
-    def cast_value_distinct(value, to_type : Type, from_type : UnionType, already_loaded)
+    def cast_value_distinct(value, to_type : Type, from_type : MixedUnionType, already_loaded)
       value_ptr = union_value(value)
       value = cast_to_pointer(value_ptr, to_type)
       to_lhs value, to_type
     end
 
     def cast_value_distinct(value, to_type : Type, from_type : Type, already_loaded)
-      raise "Bug: trying to cast #{to_type} = #{from_type}"
+      raise "Bug: trying to cast #{to_type} <- #{from_type}"
     end
 
     def match_any_type_id(type, type_id)
@@ -2033,6 +2111,15 @@ module Crystal
       end
     end
 
+    def llvm_self_ptr
+      type = context.type
+      if type.is_a?(HierarchyType)
+        cast_to llvm_self, type.base_type
+      else
+        llvm_self
+      end
+    end
+
     def codegen_return(exp_type, fun_type)
       case fun_type
       when VoidType
@@ -2045,6 +2132,10 @@ module Crystal
         yield load(return_union)
       when NilableType
         yield to_nilable(@last, fun_type, exp_type)
+      when NilableReferenceUnionType
+        yield to_nilable(@last, fun_type, exp_type)
+      when ReferenceUnionType
+        yield cast_to @last, fun_type
       when HierarchyType
         yield cast_to @last, fun_type
       else
@@ -2052,15 +2143,6 @@ module Crystal
       end
     end
 
-    def llvm_self_ptr
-      type = context.type
-      if type.is_a?(HierarchyType)
-        cast_to llvm_self, type.base_type
-      else
-        llvm_self
-      end
-    end
-
     def type_module(type)
       return @main_mod if @single_module
 
@@ -2222,8 +2304,24 @@ module Crystal
       type.passed_by_value? ? load ptr : ptr
     end
 
-    def to_nilable(ptr, to_type, from_type)
-      (from_type || @mod.nil).nil_type? ? LLVM.null(llvm_type(to_type)) : ptr
+    def to_nilable(ptr, to_type : NilableType, from_type : NilType | Nil)
+      LLVM.null(llvm_type(to_type))
+    end
+
+    def to_nilable(ptr, to_type : NilableType, from_type : Type)
+      ptr
+    end
+
+    def to_nilable(ptr, to_type : NilableReferenceUnionType, from_type : NilType | Nil)
+      LLVM.null(llvm_type(to_type))
+    end
+
+    def to_nilable(ptr, to_type : NilableReferenceUnionType, from_type : Type)
+      cast_to ptr, to_type
+    end
+
+    def to_nilable(ptr, to_type : Type, from_type : Type)
+      raise "Bug: to_nilable with #{to_type} <- #{from_type}"
     end
 
     def union_type_id(union_pointer)
diff --git a/src/compiler/crystal/codegen/llvm_typer.cr b/src/compiler/crystal/codegen/llvm_typer.cr
index 81ade86ed526a809e5386622f0b038038948dbed..20630e2707282532c99f8f56cddb68818cbce324 100644
--- a/src/compiler/crystal/codegen/llvm_typer.cr
+++ b/src/compiler/crystal/codegen/llvm_typer.cr
@@ -3,7 +3,7 @@ require "llvm"
 
 module Crystal
   class LLVMTyper
-    HIERARCHY_LLVM_TYPE = LLVM.pointer_type(LLVM::Int32)
+    TYPE_ID_POINTER = LLVM.pointer_type(LLVM::Int32)
 
     getter landing_pad_type
 
@@ -87,7 +87,19 @@ module Crystal
       LLVM.array_type(pointed_type, (type.size as NumberLiteral).value.to_i)
     end
 
-    def create_llvm_type(type : UnionType)
+    def create_llvm_type(type : NilableType)
+      llvm_type type.not_nil_type
+    end
+
+    def create_llvm_type(type : ReferenceUnionType)
+      TYPE_ID_POINTER
+    end
+
+    def create_llvm_type(type : NilableReferenceUnionType)
+      TYPE_ID_POINTER
+    end
+
+    def create_llvm_type(type : MixedUnionType)
       LLVM.struct_type(type.llvm_name) do |a_struct|
         @cache[type] = a_struct
 
@@ -112,9 +124,6 @@ module Crystal
       end
     end
 
-    def create_llvm_type(type : NilableType)
-      llvm_type type.not_nil_type
-    end
 
     def create_llvm_type(type : CStructType)
       llvm_struct_type(type)
@@ -129,7 +138,7 @@ module Crystal
     end
 
     def create_llvm_type(type : HierarchyType)
-      HIERARCHY_LLVM_TYPE
+      TYPE_ID_POINTER
     end
 
     def create_llvm_type(type : FunType)
@@ -223,10 +232,6 @@ module Crystal
       end
     end
 
-    def create_llvm_arg_type(type : UnionType)
-      LLVM.pointer_type llvm_type(type)
-    end
-
     def create_llvm_arg_type(type : CStructType)
       LLVM.pointer_type llvm_type(type)
     end
@@ -235,8 +240,8 @@ module Crystal
       LLVM.pointer_type llvm_type(type)
     end
 
-    def create_llvm_arg_type(type : NilableType)
-      llvm_type(type)
+    def create_llvm_arg_type(type : MixedUnionType)
+      LLVM.pointer_type llvm_type(type)
     end
 
     def create_llvm_arg_type(type : AliasType)
diff --git a/src/compiler/crystal/program.cr b/src/compiler/crystal/program.cr
index 4da4fdd414acb9717e35b103905c2302e6580809..edd3660cd268c7c365af34a7ab30a9cdfd6eb4ea 100644
--- a/src/compiler/crystal/program.cr
+++ b/src/compiler/crystal/program.cr
@@ -185,18 +185,35 @@ module Crystal
         types.first
       else
         types_ids = types.map(&.type_id).sort!
+        @unions[types_ids] ||= make_union_type(types, types_ids)
+      end
+    end
 
-        if types_ids.length == 2 && types_ids[0] == 0 # NilType has type_id == 0
+    def make_union_type(types, types_ids)
+      # NilType has type_id == 0
+      if types_ids.first == 0
+        # Check if it's a Nilable type
+        if types.length == 2
           nil_index = types.index(&.nil_type?).not_nil!
           other_index = 1 - nil_index
           other_type = types[other_index]
-          if other_type.class? && !other_type.struct?
-            return @unions[types_ids] ||= NilableType.new(self, other_type)
+          if other_type.reference_like? && !other_type.hierarchy?
+            return NilableType.new(self, other_type)
           end
         end
 
-        @unions[types_ids] ||= UnionType.new(self, types)
+        if types.all? &.reference_like?
+          return NilableReferenceUnionType.new(self, types)
+        else
+          return MixedUnionType.new(self, types)
+        end
       end
+
+      if types.all? &.reference_like?
+        return ReferenceUnionType.new(self, types)
+      end
+
+      MixedUnionType.new(self, types)
     end
 
     def fun_of(types : Array)
diff --git a/src/compiler/crystal/types.cr b/src/compiler/crystal/types.cr
index 46f76e2a7542c8a23a3878967250672616564d97..c3e7ae616aa95327a8de4b94f1d45f1dc51fd028 100644
--- a/src/compiler/crystal/types.cr
+++ b/src/compiler/crystal/types.cr
@@ -122,6 +122,10 @@ module Crystal
       false
     end
 
+    def reference_like?
+      false
+    end
+
     def hierarchy_type
       self
     end
@@ -881,6 +885,10 @@ module Crystal
       true
     end
 
+    def reference_like?
+      !struct?
+    end
+
     def declare_instance_var(name, type)
       ivar = Var.new(name, type)
       ivar.bind_to ivar
@@ -984,6 +992,10 @@ module Crystal
     def nil_type?
       true
     end
+
+    def reference_like?
+      true
+    end
   end
 
   class VoidType < PrimitiveType
@@ -1177,6 +1189,10 @@ module Crystal
       true
     end
 
+    def reference_like?
+      !struct?
+    end
+
     def metaclass
       @metaclass ||= GenericClassInstanceMetaclass.new(program, self)
     end
@@ -1264,6 +1280,10 @@ module Crystal
       true
     end
 
+    def reference_like?
+      false
+    end
+
     def allocated
       true
     end
@@ -1308,6 +1328,10 @@ module Crystal
       var.type.primitive_like?
     end
 
+    def reference_like?
+      false
+    end
+
     def to_s
       "#{var.type}[#{size}]"
     end
@@ -1789,7 +1813,8 @@ module Crystal
     end
   end
 
-  class UnionType < Type
+  # Base class for union types.
+  abstract class UnionType < Type
     include MultiType
 
     getter :program
@@ -1812,10 +1837,6 @@ module Crystal
       nil
     end
 
-    def passed_by_value?
-      true
-    end
-
     def includes_type?(other_type)
       union_types.any? &.includes_type?(other_type)
     end
@@ -1898,23 +1919,45 @@ module Crystal
     end
   end
 
+  # A union type that has two types: Nil and another Reference type.
+  # Can be represented as a maybe-null pointer where the type id
+  # of the type that is not nil is known at compile time.
   class NilableType < UnionType
-    getter :not_nil_type
-
-    def initialize(@program, @not_nil_type)
-      super(@program, [@program.nil, @not_nil_type] of Type)
+    def initialize(@program, not_nil_type)
+      super(@program, [@program.nil, not_nil_type] of Type)
     end
 
     def nilable?
       true
     end
 
-    def passed_by_value?
-      false
+    def not_nil_type
+      @union_types.last
     end
 
     def to_s
-      "#{@not_nil_type}?"
+      "#{not_nil_type}?"
+    end
+  end
+
+  # A union type that has Nil and other reference-like types.
+  # Can be represented as a maybe-null pointer but the type id is
+  # not known at compile time.
+  class NilableReferenceUnionType < UnionType
+  end
+
+  # A union type that doesn't have nil, and all types are reference-like.
+  # Can be represented as a never-null pointer.
+  class ReferenceUnionType < UnionType
+  end
+
+  # A union type that doesn't match any of the previous definitions,
+  # so it can contain Nil with primitive types, or Reference types with
+  # primitives types.
+  # Must be represented as a union.
+  class MixedUnionType < UnionType
+    def passed_by_value?
+      true
     end
   end
 
@@ -2068,6 +2111,10 @@ module Crystal
       true
     end
 
+    def reference_like?
+      true
+    end
+
     def cover
       if base_type.abstract
         cover = [] of Type