diff --git a/libs/llvm/lib_llvm.cr b/libs/llvm/lib_llvm.cr index 228be640076a716e3f1f8b2855727be02cae45fd..7511862d4bef57530812ae1aab8606d1117b5daf 100644 --- a/libs/llvm/lib_llvm.cr +++ b/libs/llvm/lib_llvm.cr @@ -153,6 +153,7 @@ lib LibLLVM("`llvm-config --libs --ldflags`") end fun add_attribute = LLVMAddAttribute(arg : ValueRef, attr : Int32) + fun add_instr_attribute = LLVMAddInstrAttribute(instr : ValueRef, index : UInt32, attr : Attribute) fun add_clause = LLVMAddClause(lpad : ValueRef, clause_val : ValueRef) fun add_function = LLVMAddFunction(module : ModuleRef, name : UInt8*, type : TypeRef) : ValueRef fun add_function_attr = LLVMAddFunctionAttr(fn : ValueRef, pa : Int32); diff --git a/spec/compiler/codegen/fun_spec.cr b/spec/compiler/codegen/fun_spec.cr index 143763568457a22704c45b0e0bac803f29eaba4c..211bb86d21b2f199f5c0a9a5a29fa03fab70e6fc 100644 --- a/spec/compiler/codegen/fun_spec.cr +++ b/spec/compiler/codegen/fun_spec.cr @@ -155,4 +155,15 @@ describe "Code gen: fun" do g.call(10, 20) ").to_i.should eq(11) end + + it "calls fun pointer with union (passed by value) arg" do + run(" + struct Number + def abs; self; end + end + + f = ->(x : Int32 | Float64) { x.abs } + f.call(1 || 1.5).to_i + ").to_i.should eq(1) + end end diff --git a/src/compiler/crystal/codegen.cr b/src/compiler/crystal/codegen.cr index 4ab396c7021d047fe581164f91c6a26c30f1828c..20a9cebf4e35aa79ab9adca829b4331cdbd14616 100644 --- a/src/compiler/crystal/codegen.cr +++ b/src/compiler/crystal/codegen.cr @@ -662,7 +662,7 @@ module Crystal end def codegen_primitive_fun_call(node, target_def, call_args) - codegen_call_or_invoke(node, call_args[0], call_args[1 .. -1], true, target_def.type) + codegen_call_or_invoke(node, target_def, nil, call_args[0], call_args[1 .. -1], true, target_def.type) end def codegen_primitive_pointer_diff(node, target_def, call_args) @@ -1574,7 +1574,7 @@ module Crystal end raise_fun = main_fun(RAISE_NAME) - codegen_call_or_invoke(node, raise_fun, [bit_cast(unwind_ex_obj, type_of(raise_fun.get_param(0)))], true, @mod.no_return) + codegen_call_or_invoke(node, nil, nil, raise_fun, [bit_cast(unwind_ex_obj, type_of(raise_fun.get_param(0)))], true, @mod.no_return) end if node_ensure @@ -1713,12 +1713,12 @@ module Crystal accept arg def_arg = target_def.args[i]? - if def_arg - call_args << downcast(@last, def_arg.type, arg.type, true) - else - # Def argument might be missing if it's a variadic call - call_args << @last - end + call_arg = @last + + # Def argument might be missing if it's a variadic call + call_arg = downcast(call_arg, def_arg.type, arg.type, true) if def_arg + + call_args << call_arg end end @@ -1851,10 +1851,10 @@ module Crystal end func = target_def_fun(target_def, self_type) - codegen_call_or_invoke(node, func, call_args, target_def.raises, target_def.type) + codegen_call_or_invoke(node, target_def, self_type, func, call_args, target_def.raises, target_def.type) end - def codegen_call_or_invoke(node, func, call_args, raises, type) + def codegen_call_or_invoke(node, target_def, self_type, func, call_args, raises, type) if raises && (handler = @exception_handlers.try &.last?) invoke_out_block = new_block "invoke_out" @last = @builder.invoke func, call_args, invoke_out_block, handler.catch_block @@ -1863,6 +1863,7 @@ module Crystal @last = call func, call_args end + set_call_by_val_attributes node, target_def, self_type emit_debug_metadata node, @last if @debug case type @@ -1881,6 +1882,25 @@ module Crystal @last end + def set_call_by_val_attributes(node, target_def, self_type) + # We don't want by_val in C functions + return if target_def.is_a?(External) + + arg_offset = 1 + if node.is_a?(Call) + args = node.args + arg_offset += 1 if node.obj.try(&.type.passed_as_self?) || self_type.try(&.passed_as_self?) + else + args = target_def.try(&.args) + end + + args.try &.each_with_index do |arg, i| + if arg.type.passed_by_value? + LibLLVM.add_instr_attribute(@last, (i + arg_offset).to_u32, LibLLVM::Attribute::ByVal) + end + end + end + def emit_debug_metadata(node, value) # if value.is_a?(LibLLVM::ValueRef) && !LLVM.constant?(value) && !value.is_a?(LibLLVM::BasicBlockRef) if md = dbg_metadata(node)