From 4fb32b775bf5f71abf46b6efe936c019aab100b4 Mon Sep 17 00:00:00 2001
From: Ary Borenszweig <aborenszweig@manas.com.ar>
Date: Sun, 8 Jun 2014 09:07:29 -0300
Subject: [PATCH] Allow passing wrapper struct and pointer of wrapper struct to
 C

---
 spec/compiler/codegen/def_spec.cr           |  4 +--
 spec/compiler/codegen/lib_spec.cr           | 33 +++++++++++++++++++++
 spec/compiler/codegen/while_spec.cr         |  5 +---
 spec/compiler/type_inference/lib_spec.cr    | 33 +++++++++++++++++++++
 spec/spec_helper.cr                         |  6 ++++
 src/compiler/crystal/codegen.cr             |  4 +++
 src/compiler/crystal/type_inference/call.cr |  2 ++
 src/compiler/crystal/types.cr               | 27 +++++++++++++++++
 8 files changed, 107 insertions(+), 7 deletions(-)

diff --git a/spec/compiler/codegen/def_spec.cr b/spec/compiler/codegen/def_spec.cr
index 5a85363a24..8be14b2644 100755
--- a/spec/compiler/codegen/def_spec.cr
+++ b/spec/compiler/codegen/def_spec.cr
@@ -52,9 +52,7 @@ describe "Code gen: def" do
   end
 
   it "builds infinite recursive function" do
-    node = parse "def foo; foo; end; foo"
-    result = infer_type node
-    result.program.build result.node, Program::BuildOptions.single_module
+    build "def foo; foo; end; foo"
   end
 
   it "unifies all calls to same def" do
diff --git a/spec/compiler/codegen/lib_spec.cr b/spec/compiler/codegen/lib_spec.cr
index b3addb997f..09885b50a9 100644
--- a/spec/compiler/codegen/lib_spec.cr
+++ b/spec/compiler/codegen/lib_spec.cr
@@ -26,4 +26,37 @@ describe "Code gen: lib" do
       foo
     ")
   end
+
+  it "allows passing wrapper struct to c" do
+    build("
+      lib C
+        fun foo(x : Void*) : Int32
+      end
+
+      struct Wrapper
+        def initialize(@x)
+        end
+      end
+
+      w = Wrapper.new(Pointer(Void).null)
+      C.foo(w)
+      ")
+  end
+
+  it "allows passing pointer wrapper struct to c" do
+    build("
+      lib C
+        fun foo(x : Void**) : Int32
+      end
+
+      struct Wrapper
+        def initialize(@x)
+        end
+      end
+
+      w = Wrapper.new(Pointer(Void).null)
+      p = Pointer(Wrapper).null
+      C.foo(p)
+      ")
+  end
 end
diff --git a/spec/compiler/codegen/while_spec.cr b/spec/compiler/codegen/while_spec.cr
index f980233ce5..eb7f123a9d 100755
--- a/spec/compiler/codegen/while_spec.cr
+++ b/spec/compiler/codegen/while_spec.cr
@@ -27,10 +27,7 @@ describe "Codegen: while" do
   end
 
   it "codegens endless while" do
-    program = Program.new
-    node = parse "while true; end"
-    node = program.infer_type node
-    program.build node, Program::BuildOptions.single_module
+    build "while true; end"
   end
 
   it "codegens while with declared var 1" do
diff --git a/spec/compiler/type_inference/lib_spec.cr b/spec/compiler/type_inference/lib_spec.cr
index fe2c5ed752..0039500b0f 100755
--- a/spec/compiler/type_inference/lib_spec.cr
+++ b/spec/compiler/type_inference/lib_spec.cr
@@ -133,4 +133,37 @@ describe "Type inference: lib" do
     foo = lib_type.lookup_first_def("foo", false) as External
     foo.real_name.should eq("bar")
   end
+
+  it "allows passing wrapper struct to c" do
+    assert_type("
+      lib C
+        fun foo(x : Void*) : Int32
+      end
+
+      struct Wrapper
+        def initialize(@x)
+        end
+      end
+
+      w = Wrapper.new(Pointer(Void).null)
+      C.foo(w)
+      ") { int32 }
+  end
+
+  it "allows passing pointer of wrapper struct to c" do
+    assert_type("
+      lib C
+        fun foo(x : Void**) : Int32
+      end
+
+      struct Wrapper
+        def initialize(@x)
+        end
+      end
+
+      w = Wrapper.new(Pointer(Void).null)
+      p = Pointer(Wrapper).null
+      C.foo(p)
+      ") { int32 }
+  end
 end
diff --git a/spec/spec_helper.cr b/spec/spec_helper.cr
index ba4d6a42ee..297fe2e129 100644
--- a/spec/spec_helper.cr
+++ b/spec/spec_helper.cr
@@ -117,6 +117,12 @@ def parse(string)
   Parser.parse string
 end
 
+def build(code)
+  node = parse code
+  result = infer_type node
+  result.program.build result.node, Program::BuildOptions.single_module
+end
+
 def run(code)
   Program.new.run(code)
 end
diff --git a/src/compiler/crystal/codegen.cr b/src/compiler/crystal/codegen.cr
index b776d617c3..eee9fba319 100644
--- a/src/compiler/crystal/codegen.cr
+++ b/src/compiler/crystal/codegen.cr
@@ -1618,6 +1618,10 @@ module Crystal
           elsif is_external && def_arg && arg.type.nil_type? && (def_arg.type.pointer? || def_arg.type.fun?)
             # Nil to pointer
             call_arg = LLVM.null(llvm_c_type(def_arg.type))
+          elsif is_external && def_arg && arg.type.struct_wrapper_of?(def_arg.type)
+            call_arg = @builder.extract_value load(call_arg), 0
+          elsif is_external && def_arg && arg.type.pointer_struct_wrapper_of?(def_arg.type)
+            call_arg = bit_cast call_arg, llvm_type(def_arg.type)
           else
             # Def argument might be missing if it's a variadic call
             call_arg = downcast(call_arg, def_arg.type, arg.type, true) if def_arg
diff --git a/src/compiler/crystal/type_inference/call.cr b/src/compiler/crystal/type_inference/call.cr
index 264e7cd0be..fdcefafc9e 100644
--- a/src/compiler/crystal/type_inference/call.cr
+++ b/src/compiler/crystal/type_inference/call.cr
@@ -433,6 +433,8 @@ module Crystal
             # OK: string will be sent as UInt8
           elsif expected_type.is_a?(FunType) && actual_type.is_a?(FunType) && expected_type.return_type == mod.void && expected_type.arg_types == actual_type.arg_types
             # OK: fun will be cast to return void
+          elsif actual_type.struct_wrapper_of?(expected_type) || actual_type.pointer_struct_wrapper_of?(expected_type)
+            # OK: same memory layout
           else
             arg_name = typed_def_arg.name.length > 0 ? "'#{typed_def_arg.name}'" : "##{i + 1}"
             self_arg.raise "argument #{arg_name} of '#{full_name(obj_type)}' must be #{expected_type}, not #{actual_type}"
diff --git a/src/compiler/crystal/types.cr b/src/compiler/crystal/types.cr
index 41c6c70559..e84c62b2e6 100644
--- a/src/compiler/crystal/types.cr
+++ b/src/compiler/crystal/types.cr
@@ -187,6 +187,14 @@ module Crystal
       1
     end
 
+    def struct_wrapper_of?(type)
+      false
+    end
+
+    def pointer_struct_wrapper_of?(type)
+      false
+    end
+
     def lookup_def_instance(def_object_id, arg_types, block_type)
       raise "Bug: #{self} doesn't implement lookup_def_instance"
     end
@@ -920,6 +928,15 @@ module Crystal
       struct?
     end
 
+    def struct_wrapper_of?(type)
+      return false unless struct?
+
+      ivars = all_instance_vars
+      return false unless ivars.length == 1
+
+      ivars.first_value.type? == type
+    end
+
     def type_desc
       struct? ? "struct" : "class"
     end
@@ -1087,6 +1104,10 @@ module Crystal
       false
     end
 
+    def struct_wrapper_of?(type)
+      false
+    end
+
     def allocated
       true
     end
@@ -1493,6 +1514,12 @@ module Crystal
       false
     end
 
+    def pointer_struct_wrapper_of?(type)
+      return false unless type.is_a?(PointerInstanceType)
+
+      element_type.struct_wrapper_of?(type.element_type)
+    end
+
     def allocated
       true
     end
-- 
GitLab