diff --git a/spec/compiler/codegen/class_spec.cr b/spec/compiler/codegen/class_spec.cr index 64102c63c7665dcff6b73c463d6c6c1b74a31a06..7ff697d028f5296e42c62b357cc26c316c72a2c2 100755 --- a/spec/compiler/codegen/class_spec.cr +++ b/spec/compiler/codegen/class_spec.cr @@ -235,4 +235,21 @@ describe "Code gen: class" do a ").to_i.should eq(2) end + + it "assigns type to reference union type" do + run(" + class Foo + def initialize(@x) + end + def x=(@x); end + end + + class Bar; end + class Baz; end + + f = Foo.new(Bar.new) + f.x = Baz.new + 1 + ").to_i.should eq(1) + end end diff --git a/spec/compiler/codegen/def_spec.cr b/spec/compiler/codegen/def_spec.cr index 4e2d9487976ef9ff6e99db14b8bd98aafcc6b8c9..ce4677f517693a63e7928e9e386598c5f8a6c7ac 100755 --- a/spec/compiler/codegen/def_spec.cr +++ b/spec/compiler/codegen/def_spec.cr @@ -338,4 +338,15 @@ describe "Code gen: def" do foo.nil? ").to_b.should be_true end + + it "codegens dispatch with nilable reference union type" do + run(" + struct Nil; def object_id; 0_u64; end; end + class Foo; end + class Bar; end + + f = 1 == 1 ? nil : (Foo.new || Bar.new) + f.object_id + ").to_i.should eq(0) + end end diff --git a/spec/compiler/codegen/is_a_spec.cr b/spec/compiler/codegen/is_a_spec.cr index 0677dbd3b210b75fbc71a68b40a1dab68216f447..59df33adfb4585b93b7a6c758bcabe4d6b422aa4 100755 --- a/spec/compiler/codegen/is_a_spec.cr +++ b/spec/compiler/codegen/is_a_spec.cr @@ -377,4 +377,40 @@ describe "Codegen: is_a?" do end ").to_i.should eq(1) end + + it "codegens is_a? from nilable reference union type to nil" do + run(" + class Foo + end + + class Bar + end + + a = Foo.new || Bar.new || nil + if a.is_a?(Nil) + b = a + 1 + else + 2 + end + ").to_i.should eq(2) + end + + it "codegens is_a? from nilable reference union type to type" do + run(" + class Foo + end + + class Bar + end + + a = Foo.new || Bar.new || nil + if a.is_a?(Foo) + b = a + 1 + else + 2 + end + ").to_i.should eq(1) + end end diff --git a/src/compiler/crystal/codegen.cr b/src/compiler/crystal/codegen.cr index 7e573055b09ebda15c5273fbbcf584b2ef328b43..64c000ed6e3b3221a9b79650d976c725a8fbd285 100644 --- a/src/compiler/crystal/codegen.cr +++ b/src/compiler/crystal/codegen.cr @@ -846,7 +846,9 @@ module Crystal case break_type when NilableType context.break_table.not_nil!.add insert_block, to_nilable(@last, break_type, control_expression_type(node)) - when UnionType + when NilableReferenceUnionType + context.break_table.not_nil!.add insert_block, to_nilable(@last, break_type, control_expression_type(node)) + when MixedUnionType break_union = context.break_union.not_nil! assign(break_union, break_type, control_expression_type(node), @last) else @@ -915,6 +917,10 @@ module Crystal case node.type when NilableType value = @codegen.to_nilable(value, node.type, type) + when NilableReferenceUnionType + value = @codegen.to_nilable(value, node.type, type) + when ReferenceUnionType + value = cast_to value, node.type when HierarchyType value = cast_to value, node.type else @@ -1349,7 +1355,7 @@ module Crystal context.vars = context.vars.dup get_exception_fun = main_fun(GET_EXCEPTION_NAME) exception_ptr = call get_exception_fun, [bit_cast(unwind_ex_obj, type_of(get_exception_fun.get_param(0)))] - exception = int2ptr exception_ptr, LLVMTyper::HIERARCHY_LLVM_TYPE + exception = int2ptr exception_ptr, LLVMTyper::TYPE_ID_POINTER context.vars[a_rescue_name] = LLVMVar.new(exception, a_rescue.type, true) end @@ -1807,7 +1813,25 @@ module Crystal case type when NilableType @builder.select null_pointer?(value), int(@mod.nil.type_id), int(type.not_nil_type.type_id) - when UnionType + when ReferenceUnionType + load(value) + when NilableReferenceUnionType + nil_block, not_nil_block, exit_block = new_blocks ["nil", "not_nil", "exit"] + phi_table = LLVM::PhiTable.new + + cond null_pointer?(value), nil_block, not_nil_block + + position_at_end nil_block + phi_table.add insert_block, int(@mod.nil.type_id) + br exit_block + + position_at_end not_nil_block + phi_table.add insert_block, load(value) + br exit_block + + position_at_end exit_block + phi LLVM::Int32, phi_table + when MixedUnionType load(union_type_id(value)) when HierarchyType load(value) @@ -1833,11 +1857,11 @@ module Crystal llvm_false when BoolType @last - when NilableType, PointerInstanceType - not_null_pointer? @last when TypeDefType codegen_cond type.typedef - when UnionType + when NilableType, NilableReferenceUnionType, PointerInstanceType + not_null_pointer? @last + when MixedUnionType has_nil = type.union_types.any? &.nil_type? has_bool = type.union_types.any? &.bool_type? @@ -1888,32 +1912,44 @@ module Crystal end end - def assign_distinct(target_pointer, target_type : HierarchyTypeMetaclass, value_type : Metaclass, value) + def assign_distinct(target_pointer, target_type : NilableType, value_type : Type, value) + store to_nilable(value, target_type, value_type), target_pointer + end + + def assign_distinct(target_pointer, target_type : ReferenceUnionType, value_type : ReferenceUnionType, value) store value, target_pointer end - def assign_distinct(target_pointer, target_type : NilableType, value_type : Type, value) + def assign_distinct(target_pointer, target_type : ReferenceUnionType, value_type : HierarchyType, value) + store value, target_pointer + end + + def assign_distinct(target_pointer, target_type : ReferenceUnionType, value_type : Type, value) + store cast_to(value, target_type), target_pointer + end + + def assign_distinct(target_pointer, target_type : NilableReferenceUnionType, value_type : Type, value) store to_nilable(value, target_type, value_type), target_pointer end - def assign_distinct(target_pointer, target_type : UnionType, value_type : UnionType, value) + def assign_distinct(target_pointer, target_type : MixedUnionType, value_type : MixedUnionType, value) casted_value = cast_to_pointer value, target_type store load(casted_value), target_pointer end - def assign_distinct(target_pointer, target_type : UnionType, value_type : NilableType, value) + def assign_distinct(target_pointer, target_type : MixedUnionType, value_type : NilableType, value) store_in_union target_pointer, value_type, value end - def assign_distinct(target_pointer, target_type : UnionType, value_type : VoidType, value) + def assign_distinct(target_pointer, target_type : MixedUnionType, value_type : VoidType, value) store int(value_type.type_id), union_type_id(target_pointer) end - def assign_distinct(target_pointer, target_type : UnionType, value_type : Type, value) + def assign_distinct(target_pointer, target_type : MixedUnionType, value_type : Type, value) store_in_union target_pointer, value_type, to_rhs(value, value_type) end - def assign_distinct(target_pointer, target_type : HierarchyType, value_type : UnionType, value) + def assign_distinct(target_pointer, target_type : HierarchyType, value_type : MixedUnionType, value) casted_value = cast_to_pointer(union_value(value), target_type) store load(casted_value), target_pointer end @@ -1922,8 +1958,12 @@ module Crystal store cast_to(value, target_type), target_pointer end + def assign_distinct(target_pointer, target_type : HierarchyTypeMetaclass, value_type : Metaclass, value) + store value, target_pointer + end + def assign_distinct(target_pointer, target_type : Type, value_type : Type, value) - raise "Bug: trying to assign #{target_type} = #{value_type}" + raise "Bug: trying to assign #{target_type} <- #{value_type}" end def cast_value(value, to_type, from_type, already_loaded = false) @@ -1936,14 +1976,6 @@ module Crystal end end - def cast_value_distinct(value, to_type, from_type : NilableType, already_loaded) - if to_type.nil_type? - value = llvm_nil - else - already_loaded ? value : load value - end - end - def cast_value_distinct(value, to_type, from_type : Metaclass | GenericClassInstanceMetaclass | HierarchyTypeMetaclass, already_loaded) value end @@ -1952,30 +1984,76 @@ module Crystal already_loaded ? value : load value end - def cast_value_distinct(value, to_type : UnionType, from_type : HierarchyType, already_loaded) - # This happens if the restriction is a UnionType: + def cast_value_distinct(value, to_type : MixedUnionType, from_type : HierarchyType, already_loaded) + # This happens if the restriction is a union: # we keep each of the union types as the result, we don't fully merge union_ptr = alloca llvm_type(to_type) store_in_union union_ptr, from_type, (already_loaded ? value : load value) union_ptr end - def cast_value_distinct(value, to_type : UnionType, from_type : UnionType, already_loaded) + def cast_value_distinct(value, to_type : ReferenceUnionType, from_type : HierarchyType, already_loaded) + # This happens if the restriction is a union: + # we keep each of the union types as the result, we don't fully merge + already_loaded ? value : load value + end + + def cast_value_distinct(value, to_type : NilType, from_type : NilableType, already_loaded) + llvm_nil + end + + def cast_value_distinct(value, to_type : Type, from_type : NilableType, already_loaded) + already_loaded ? value : load value + end + + def cast_value_distinct(value, to_type : ReferenceUnionType, from_type : ReferenceUnionType, already_loaded) + already_loaded ? value : load value + end + + def cast_value_distinct(value, to_type : HierarchyType, from_type : ReferenceUnionType, already_loaded) + already_loaded ? value : load value + end + + def cast_value_distinct(value, to_type : Type, from_type : ReferenceUnionType, already_loaded) + cast_to (already_loaded ? value : load value), to_type + end + + def cast_value_distinct(value, to_type : HierarchyType, from_type : NilableReferenceUnionType, already_loaded) + already_loaded ? value : load value + end + + def cast_value_distinct(value, to_type : ReferenceUnionType, from_type : NilableReferenceUnionType, already_loaded) + already_loaded ? value : load value + end + + def cast_value_distinct(value, to_type : NilableType, from_type : NilableReferenceUnionType, already_loaded) + cast_to (already_loaded ? value : load value), to_type + end + + def cast_value_distinct(value, to_type : NilType, from_type : NilableReferenceUnionType, already_loaded) + llvm_nil + end + + def cast_value_distinct(value, to_type : Type, from_type : NilableReferenceUnionType, already_loaded) + cast_to (already_loaded ? value : load value), to_type + end + + def cast_value_distinct(value, to_type : MixedUnionType, from_type : MixedUnionType, already_loaded) cast_to_pointer value, to_type end - def cast_value_distinct(value, to_type : NilableType, from_type : UnionType, already_loaded) + def cast_value_distinct(value, to_type : NilableType, from_type : MixedUnionType, already_loaded) load cast_to_pointer(union_value(value), to_type) end - def cast_value_distinct(value, to_type : Type, from_type : UnionType, already_loaded) + def cast_value_distinct(value, to_type : Type, from_type : MixedUnionType, already_loaded) value_ptr = union_value(value) value = cast_to_pointer(value_ptr, to_type) to_lhs value, to_type end def cast_value_distinct(value, to_type : Type, from_type : Type, already_loaded) - raise "Bug: trying to cast #{to_type} = #{from_type}" + raise "Bug: trying to cast #{to_type} <- #{from_type}" end def match_any_type_id(type, type_id) @@ -2033,6 +2111,15 @@ module Crystal end end + def llvm_self_ptr + type = context.type + if type.is_a?(HierarchyType) + cast_to llvm_self, type.base_type + else + llvm_self + end + end + def codegen_return(exp_type, fun_type) case fun_type when VoidType @@ -2045,6 +2132,10 @@ module Crystal yield load(return_union) when NilableType yield to_nilable(@last, fun_type, exp_type) + when NilableReferenceUnionType + yield to_nilable(@last, fun_type, exp_type) + when ReferenceUnionType + yield cast_to @last, fun_type when HierarchyType yield cast_to @last, fun_type else @@ -2052,15 +2143,6 @@ module Crystal end end - def llvm_self_ptr - type = context.type - if type.is_a?(HierarchyType) - cast_to llvm_self, type.base_type - else - llvm_self - end - end - def type_module(type) return @main_mod if @single_module @@ -2222,8 +2304,24 @@ module Crystal type.passed_by_value? ? load ptr : ptr end - def to_nilable(ptr, to_type, from_type) - (from_type || @mod.nil).nil_type? ? LLVM.null(llvm_type(to_type)) : ptr + def to_nilable(ptr, to_type : NilableType, from_type : NilType | Nil) + LLVM.null(llvm_type(to_type)) + end + + def to_nilable(ptr, to_type : NilableType, from_type : Type) + ptr + end + + def to_nilable(ptr, to_type : NilableReferenceUnionType, from_type : NilType | Nil) + LLVM.null(llvm_type(to_type)) + end + + def to_nilable(ptr, to_type : NilableReferenceUnionType, from_type : Type) + cast_to ptr, to_type + end + + def to_nilable(ptr, to_type : Type, from_type : Type) + raise "Bug: to_nilable with #{to_type} <- #{from_type}" end def union_type_id(union_pointer) diff --git a/src/compiler/crystal/codegen/llvm_typer.cr b/src/compiler/crystal/codegen/llvm_typer.cr index 81ade86ed526a809e5386622f0b038038948dbed..20630e2707282532c99f8f56cddb68818cbce324 100644 --- a/src/compiler/crystal/codegen/llvm_typer.cr +++ b/src/compiler/crystal/codegen/llvm_typer.cr @@ -3,7 +3,7 @@ require "llvm" module Crystal class LLVMTyper - HIERARCHY_LLVM_TYPE = LLVM.pointer_type(LLVM::Int32) + TYPE_ID_POINTER = LLVM.pointer_type(LLVM::Int32) getter landing_pad_type @@ -87,7 +87,19 @@ module Crystal LLVM.array_type(pointed_type, (type.size as NumberLiteral).value.to_i) end - def create_llvm_type(type : UnionType) + def create_llvm_type(type : NilableType) + llvm_type type.not_nil_type + end + + def create_llvm_type(type : ReferenceUnionType) + TYPE_ID_POINTER + end + + def create_llvm_type(type : NilableReferenceUnionType) + TYPE_ID_POINTER + end + + def create_llvm_type(type : MixedUnionType) LLVM.struct_type(type.llvm_name) do |a_struct| @cache[type] = a_struct @@ -112,9 +124,6 @@ module Crystal end end - def create_llvm_type(type : NilableType) - llvm_type type.not_nil_type - end def create_llvm_type(type : CStructType) llvm_struct_type(type) @@ -129,7 +138,7 @@ module Crystal end def create_llvm_type(type : HierarchyType) - HIERARCHY_LLVM_TYPE + TYPE_ID_POINTER end def create_llvm_type(type : FunType) @@ -223,10 +232,6 @@ module Crystal end end - def create_llvm_arg_type(type : UnionType) - LLVM.pointer_type llvm_type(type) - end - def create_llvm_arg_type(type : CStructType) LLVM.pointer_type llvm_type(type) end @@ -235,8 +240,8 @@ module Crystal LLVM.pointer_type llvm_type(type) end - def create_llvm_arg_type(type : NilableType) - llvm_type(type) + def create_llvm_arg_type(type : MixedUnionType) + LLVM.pointer_type llvm_type(type) end def create_llvm_arg_type(type : AliasType) diff --git a/src/compiler/crystal/program.cr b/src/compiler/crystal/program.cr index 4da4fdd414acb9717e35b103905c2302e6580809..edd3660cd268c7c365af34a7ab30a9cdfd6eb4ea 100644 --- a/src/compiler/crystal/program.cr +++ b/src/compiler/crystal/program.cr @@ -185,18 +185,35 @@ module Crystal types.first else types_ids = types.map(&.type_id).sort! + @unions[types_ids] ||= make_union_type(types, types_ids) + end + end - if types_ids.length == 2 && types_ids[0] == 0 # NilType has type_id == 0 + def make_union_type(types, types_ids) + # NilType has type_id == 0 + if types_ids.first == 0 + # Check if it's a Nilable type + if types.length == 2 nil_index = types.index(&.nil_type?).not_nil! other_index = 1 - nil_index other_type = types[other_index] - if other_type.class? && !other_type.struct? - return @unions[types_ids] ||= NilableType.new(self, other_type) + if other_type.reference_like? && !other_type.hierarchy? + return NilableType.new(self, other_type) end end - @unions[types_ids] ||= UnionType.new(self, types) + if types.all? &.reference_like? + return NilableReferenceUnionType.new(self, types) + else + return MixedUnionType.new(self, types) + end end + + if types.all? &.reference_like? + return ReferenceUnionType.new(self, types) + end + + MixedUnionType.new(self, types) end def fun_of(types : Array) diff --git a/src/compiler/crystal/types.cr b/src/compiler/crystal/types.cr index 46f76e2a7542c8a23a3878967250672616564d97..c3e7ae616aa95327a8de4b94f1d45f1dc51fd028 100644 --- a/src/compiler/crystal/types.cr +++ b/src/compiler/crystal/types.cr @@ -122,6 +122,10 @@ module Crystal false end + def reference_like? + false + end + def hierarchy_type self end @@ -881,6 +885,10 @@ module Crystal true end + def reference_like? + !struct? + end + def declare_instance_var(name, type) ivar = Var.new(name, type) ivar.bind_to ivar @@ -984,6 +992,10 @@ module Crystal def nil_type? true end + + def reference_like? + true + end end class VoidType < PrimitiveType @@ -1177,6 +1189,10 @@ module Crystal true end + def reference_like? + !struct? + end + def metaclass @metaclass ||= GenericClassInstanceMetaclass.new(program, self) end @@ -1264,6 +1280,10 @@ module Crystal true end + def reference_like? + false + end + def allocated true end @@ -1308,6 +1328,10 @@ module Crystal var.type.primitive_like? end + def reference_like? + false + end + def to_s "#{var.type}[#{size}]" end @@ -1789,7 +1813,8 @@ module Crystal end end - class UnionType < Type + # Base class for union types. + abstract class UnionType < Type include MultiType getter :program @@ -1812,10 +1837,6 @@ module Crystal nil end - def passed_by_value? - true - end - def includes_type?(other_type) union_types.any? &.includes_type?(other_type) end @@ -1898,23 +1919,45 @@ module Crystal end end + # A union type that has two types: Nil and another Reference type. + # Can be represented as a maybe-null pointer where the type id + # of the type that is not nil is known at compile time. class NilableType < UnionType - getter :not_nil_type - - def initialize(@program, @not_nil_type) - super(@program, [@program.nil, @not_nil_type] of Type) + def initialize(@program, not_nil_type) + super(@program, [@program.nil, not_nil_type] of Type) end def nilable? true end - def passed_by_value? - false + def not_nil_type + @union_types.last end def to_s - "#{@not_nil_type}?" + "#{not_nil_type}?" + end + end + + # A union type that has Nil and other reference-like types. + # Can be represented as a maybe-null pointer but the type id is + # not known at compile time. + class NilableReferenceUnionType < UnionType + end + + # A union type that doesn't have nil, and all types are reference-like. + # Can be represented as a never-null pointer. + class ReferenceUnionType < UnionType + end + + # A union type that doesn't match any of the previous definitions, + # so it can contain Nil with primitive types, or Reference types with + # primitives types. + # Must be represented as a union. + class MixedUnionType < UnionType + def passed_by_value? + true end end @@ -2068,6 +2111,10 @@ module Crystal true end + def reference_like? + true + end + def cover if base_type.abstract cover = [] of Type