From 596efe8d6ea610e4a87ae4efcdff3d7d6120a5b3 Mon Sep 17 00:00:00 2001
From: Ary Borenszweig <aborenszweig@manas.com.ar>
Date: Sat, 3 May 2014 15:25:56 -0300
Subject: [PATCH] Add byval attribute to call arguments. Fixes #117.

---
 libs/llvm/lib_llvm.cr             |  1 +
 spec/compiler/codegen/fun_spec.cr | 11 +++++++++
 src/compiler/crystal/codegen.cr   | 40 +++++++++++++++++++++++--------
 3 files changed, 42 insertions(+), 10 deletions(-)

diff --git a/libs/llvm/lib_llvm.cr b/libs/llvm/lib_llvm.cr
index 228be64007..7511862d4b 100644
--- a/libs/llvm/lib_llvm.cr
+++ b/libs/llvm/lib_llvm.cr
@@ -153,6 +153,7 @@ lib LibLLVM("`llvm-config --libs --ldflags`")
   end
 
   fun add_attribute = LLVMAddAttribute(arg : ValueRef, attr : Int32)
+  fun add_instr_attribute = LLVMAddInstrAttribute(instr : ValueRef, index : UInt32, attr : Attribute)
   fun add_clause = LLVMAddClause(lpad : ValueRef, clause_val : ValueRef)
   fun add_function = LLVMAddFunction(module : ModuleRef, name : UInt8*, type : TypeRef) : ValueRef
   fun add_function_attr = LLVMAddFunctionAttr(fn : ValueRef, pa : Int32);
diff --git a/spec/compiler/codegen/fun_spec.cr b/spec/compiler/codegen/fun_spec.cr
index 1437635684..211bb86d21 100644
--- a/spec/compiler/codegen/fun_spec.cr
+++ b/spec/compiler/codegen/fun_spec.cr
@@ -155,4 +155,15 @@ describe "Code gen: fun" do
       g.call(10, 20)
       ").to_i.should eq(11)
   end
+
+  it "calls fun pointer with union (passed by value) arg" do
+    run("
+      struct Number
+        def abs; self; end
+      end
+
+      f = ->(x : Int32 | Float64) { x.abs }
+      f.call(1 || 1.5).to_i
+      ").to_i.should eq(1)
+  end
 end
diff --git a/src/compiler/crystal/codegen.cr b/src/compiler/crystal/codegen.cr
index 4ab396c702..20a9cebf4e 100644
--- a/src/compiler/crystal/codegen.cr
+++ b/src/compiler/crystal/codegen.cr
@@ -662,7 +662,7 @@ module Crystal
     end
 
     def codegen_primitive_fun_call(node, target_def, call_args)
-      codegen_call_or_invoke(node, call_args[0], call_args[1 .. -1], true, target_def.type)
+      codegen_call_or_invoke(node, target_def, nil, call_args[0], call_args[1 .. -1], true, target_def.type)
     end
 
     def codegen_primitive_pointer_diff(node, target_def, call_args)
@@ -1574,7 +1574,7 @@ module Crystal
         end
 
         raise_fun = main_fun(RAISE_NAME)
-        codegen_call_or_invoke(node, raise_fun, [bit_cast(unwind_ex_obj, type_of(raise_fun.get_param(0)))], true, @mod.no_return)
+        codegen_call_or_invoke(node, nil, nil, raise_fun, [bit_cast(unwind_ex_obj, type_of(raise_fun.get_param(0)))], true, @mod.no_return)
       end
 
       if node_ensure
@@ -1713,12 +1713,12 @@ module Crystal
           accept arg
 
           def_arg = target_def.args[i]?
-          if def_arg
-            call_args << downcast(@last, def_arg.type, arg.type, true)
-          else
-            # Def argument might be missing if it's a variadic call
-            call_args << @last
-          end
+          call_arg = @last
+
+          # Def argument might be missing if it's a variadic call
+          call_arg = downcast(call_arg, def_arg.type, arg.type, true) if def_arg
+
+          call_args << call_arg
         end
       end
 
@@ -1851,10 +1851,10 @@ module Crystal
       end
 
       func = target_def_fun(target_def, self_type)
-      codegen_call_or_invoke(node, func, call_args, target_def.raises, target_def.type)
+      codegen_call_or_invoke(node, target_def, self_type, func, call_args, target_def.raises, target_def.type)
     end
 
-    def codegen_call_or_invoke(node, func, call_args, raises, type)
+    def codegen_call_or_invoke(node, target_def, self_type, func, call_args, raises, type)
       if raises && (handler = @exception_handlers.try &.last?)
         invoke_out_block = new_block "invoke_out"
         @last = @builder.invoke func, call_args, invoke_out_block, handler.catch_block
@@ -1863,6 +1863,7 @@ module Crystal
         @last = call func, call_args
       end
 
+      set_call_by_val_attributes node, target_def, self_type
       emit_debug_metadata node, @last if @debug
 
       case type
@@ -1881,6 +1882,25 @@ module Crystal
       @last
     end
 
+    def set_call_by_val_attributes(node, target_def, self_type)
+      # We don't want by_val in C functions
+      return if target_def.is_a?(External)
+
+      arg_offset = 1
+      if node.is_a?(Call)
+        args = node.args
+        arg_offset += 1 if node.obj.try(&.type.passed_as_self?) || self_type.try(&.passed_as_self?)
+      else
+        args = target_def.try(&.args)
+      end
+
+      args.try &.each_with_index do |arg, i|
+        if arg.type.passed_by_value?
+          LibLLVM.add_instr_attribute(@last, (i + arg_offset).to_u32, LibLLVM::Attribute::ByVal)
+        end
+      end
+    end
+
     def emit_debug_metadata(node, value)
       # if value.is_a?(LibLLVM::ValueRef) && !LLVM.constant?(value) && !value.is_a?(LibLLVM::BasicBlockRef)
         if md = dbg_metadata(node)
-- 
GitLab