From d8f8d9d558d04ae402713d0f62a6c53faa3f9d25 Mon Sep 17 00:00:00 2001
From: Ary Borenszweig <aborenszweig@manas.com.ar>
Date: Sun, 29 Jun 2014 11:05:10 -0300
Subject: [PATCH] Added inherited, included and extended hook macros

---
 spec/compiler/codegen/macro_spec.cr        | 47 ++++++++++++++++++
 spec/compiler/type_inference/macro_spec.cr | 53 ++++++++++++++++++++
 src/compiler/crystal/codegen.cr            |  6 +++
 src/compiler/crystal/macros.cr             | 12 +++++
 src/compiler/crystal/program.cr            | 12 -----
 src/compiler/crystal/type_inference.cr     | 56 +++++++++++++++++-----
 src/compiler/crystal/type_inference/ast.cr | 21 ++++++++
 src/compiler/crystal/types.cr              | 24 ++++++++++
 8 files changed, 207 insertions(+), 24 deletions(-)

diff --git a/spec/compiler/codegen/macro_spec.cr b/spec/compiler/codegen/macro_spec.cr
index a7aaf027fe..db12b833fd 100755
--- a/spec/compiler/codegen/macro_spec.cr
+++ b/spec/compiler/codegen/macro_spec.cr
@@ -288,4 +288,51 @@ describe "Code gen: macro" do
       foo(1)
       )).to_i.should eq(3)
   end
+
+  it "does inherited macro" do
+    run("
+      class Foo
+        macro inherited
+          $x = 1
+        end
+      end
+
+      class Bar < Foo
+      end
+
+      $x
+      ").to_i.should eq(1)
+  end
+
+  it "does included macro" do
+    run("
+      module Foo
+        macro included
+          $x = 1
+        end
+      end
+
+      class Bar
+        include Foo
+      end
+
+      $x
+      ").to_i.should eq(1)
+  end
+
+  it "does extended macro" do
+    run("
+      module Foo
+        macro extended
+          $x = 1
+        end
+      end
+
+      class Bar
+        extend Foo
+      end
+
+      $x
+      ").to_i.should eq(1)
+  end
 end
diff --git a/spec/compiler/type_inference/macro_spec.cr b/spec/compiler/type_inference/macro_spec.cr
index df0b192b28..90b15fc853 100755
--- a/spec/compiler/type_inference/macro_spec.cr
+++ b/spec/compiler/type_inference/macro_spec.cr
@@ -133,4 +133,57 @@ describe "Type inference: macro" do
       foo(1)
       ), "wrong number of arguments for macro 'foo' (1 for 0)"
   end
+
+  it "does inherited macro" do
+    assert_type("
+      class Foo
+        macro inherited
+          def self.{{@name.downcase.id}}
+            1
+          end
+        end
+      end
+
+      class Bar < Foo
+      end
+
+      Bar.bar
+      ") { int32 }
+  end
+
+  it "does included macro" do
+    assert_type("
+      module Foo
+        macro included
+          def self.{{@name.downcase.id}}
+            1
+          end
+        end
+      end
+
+      class Bar
+        include Foo
+      end
+
+      Bar.bar
+      ") { int32 }
+  end
+
+  it "does extended macro" do
+    assert_type("
+      module Foo
+        macro extended
+          def self.{{@name.downcase.id}}
+            1
+          end
+        end
+      end
+
+      class Bar
+        extend Foo
+      end
+
+      Bar.bar
+      ") { int32 }
+  end
 end
diff --git a/src/compiler/crystal/codegen.cr b/src/compiler/crystal/codegen.cr
index 6e3cfd0b00..475896149a 100644
--- a/src/compiler/crystal/codegen.cr
+++ b/src/compiler/crystal/codegen.cr
@@ -394,6 +394,8 @@ module Crystal
     end
 
     def visit(node : ClassDef)
+      node.runtime_initializers.try &.each &.accept self
+
       accept node.body
       @last = llvm_nil
       false
@@ -431,11 +433,15 @@ module Crystal
     end
 
     def visit(node : Include)
+      node.runtime_initializers.try &.each &.accept self
+
       @last = llvm_nil
       false
     end
 
     def visit(node : Extend)
+      node.runtime_initializers.try &.each &.accept self
+
       @last = llvm_nil
       false
     end
diff --git a/src/compiler/crystal/macros.cr b/src/compiler/crystal/macros.cr
index ee82fa7dd5..b5881ff356 100644
--- a/src/compiler/crystal/macros.cr
+++ b/src/compiler/crystal/macros.cr
@@ -1,5 +1,17 @@
 module Crystal
   class Program
+    def push_def_macro(def)
+      @def_macros << def
+    end
+
+    def expand_macro(scope : Type, a_macro, call)
+      @macro_expander.expand scope, a_macro, call
+    end
+
+    def expand_macro(scope : Type, node)
+      @macro_expander.expand scope, node
+    end
+
     def expand_def_macros
       until @def_macros.empty?
         def_macro = @def_macros.pop
diff --git a/src/compiler/crystal/program.cr b/src/compiler/crystal/program.cr
index 5bacb93e09..02ca32a1b4 100644
--- a/src/compiler/crystal/program.cr
+++ b/src/compiler/crystal/program.cr
@@ -137,18 +137,6 @@ module Crystal
       flags
     end
 
-    def push_def_macro(def)
-      @def_macros << def
-    end
-
-    def expand_macro(scope : Type, a_macro, call)
-      @macro_expander.expand scope, a_macro, call
-    end
-
-    def expand_macro(scope : Type, node)
-      @macro_expander.expand scope, node
-    end
-
     def self.exec(command)
       Pipe.open(command, "r") do |pipe|
         pipe.gets.try &.strip
diff --git a/src/compiler/crystal/type_inference.cr b/src/compiler/crystal/type_inference.cr
index 57cef96e53..f0db48df28 100644
--- a/src/compiler/crystal/type_inference.cr
+++ b/src/compiler/crystal/type_inference.cr
@@ -434,7 +434,12 @@ module Crystal
     end
 
     def visit(node : Macro)
-      current_type.metaclass.add_macro node
+      begin
+        current_type.metaclass.add_macro node
+      rescue ex
+        node.raise ex.message
+      end
+
       node.set_type @mod.nil
       false
     end
@@ -894,6 +899,8 @@ module Crystal
         end
       end
 
+      created_new_type = false
+
       if type
         unless type.is_a?(ClassType)
           node.raise "#{name} is not a #{node.struct ? "struct" : "class"}, it's a #{type.type_desc}"
@@ -911,7 +918,7 @@ module Crystal
           node_superclass.not_nil!.raise "#{superclass} is not a class, it's a #{superclass.type_desc}"
         end
 
-        needs_force_add_subclass = true
+        created_new_type = true
         if type_vars = node.type_vars
           type = GenericClassType.new @mod, scope, name, superclass, type_vars, false
         else
@@ -923,10 +930,15 @@ module Crystal
       end
 
       @types.push type
+
+      if created_new_type
+        run_hooks(superclass.metaclass, type, :inherited, node)
+      end
+
       node.body.accept self
       @types.pop
 
-      if needs_force_add_subclass
+      if created_new_type
         raise "Bug" unless type.is_a?(InheritableClass)
         type.force_add_subclass
       end
@@ -936,6 +948,21 @@ module Crystal
       false
     end
 
+    def run_hooks(type_with_hooks, current_type, kind, node)
+      hooks = type_with_hooks.hooks
+      return unless hooks
+
+      hooks.each do |hook|
+        next if hook.kind != kind
+
+        expanded = expand_macro(hook.macro, node) do
+          @mod.expand_macro current_type, hook.macro.body
+        end
+        expanded.accept self
+        node.add_runtime_initializer(expanded)
+      end
+    end
+
     def visit(node : ModuleDef)
       if node.name.names.length == 1 && !node.name.global
         scope = current_type
@@ -980,7 +1007,7 @@ module Crystal
     end
 
     def visit(node : Include)
-      include_in current_type, node.name
+      include_in current_type, node, :included
 
       node.type = @mod.nil
 
@@ -988,7 +1015,7 @@ module Crystal
     end
 
     def visit(node : Extend)
-      include_in current_type.metaclass, node.name
+      include_in current_type.metaclass, node, :extended
 
       node.type = @mod.nil
 
@@ -1925,7 +1952,8 @@ module Crystal
       false
     end
 
-    def include_in(current_type, node_name)
+    def include_in(current_type, node, kind)
+      node_name = node.name
       if node_name.is_a?(Generic)
         type = lookup_path_type(node_name.name)
       else
@@ -1946,7 +1974,7 @@ module Crystal
         end
 
         mapping = Hash.zip(type.type_vars, node_name.type_vars)
-        current_type.include IncludedGenericModule.new(@mod, type, current_type, mapping)
+        module_to_include = IncludedGenericModule.new(@mod, type, current_type, mapping)
       else
         if type.is_a?(GenericModuleType)
           if current_type.is_a?(GenericType)
@@ -1959,17 +1987,21 @@ module Crystal
             type.type_vars.zip(current_type.type_vars) do |type_var, current_type_var|
               mapping[type_var] = Path.new([current_type_var])
             end
-
-            current_type.include IncludedGenericModule.new(@mod, type, current_type, mapping)
+            module_to_include = IncludedGenericModule.new(@mod, type, current_type, mapping)
           else
             node_name.raise "#{type} is a generic module"
           end
         else
-          current_type.include type
+          module_to_include = type
         end
       end
-    rescue ex
-      node_name.raise ex.message
+
+      begin
+        current_type.include module_to_include
+        run_hooks type.metaclass, current_type, kind, node
+      rescue ex
+        node_name.raise ex.message
+      end
     end
 
     def process_struct_or_union_def(node, klass)
diff --git a/src/compiler/crystal/type_inference/ast.cr b/src/compiler/crystal/type_inference/ast.cr
index 4dd728f70a..cc8353da14 100644
--- a/src/compiler/crystal/type_inference/ast.cr
+++ b/src/compiler/crystal/type_inference/ast.cr
@@ -445,4 +445,25 @@ module Crystal
       include ExpandableNode
     end
   {% end %}
+
+  module RuntimeInitializable
+    getter runtime_initializers
+
+    def add_runtime_initializer(node)
+      initializers = @runtime_initializers ||= [] of ASTNode
+      initializers << node
+    end
+  end
+
+  class ClassDef
+    include RuntimeInitializable
+  end
+
+  class Include
+    include RuntimeInitializable
+  end
+
+  class Extend
+    include RuntimeInitializable
+  end
 end
diff --git a/src/compiler/crystal/types.cr b/src/compiler/crystal/types.cr
index 24b19d880a..75b96aab2e 100644
--- a/src/compiler/crystal/types.cr
+++ b/src/compiler/crystal/types.cr
@@ -297,6 +297,10 @@ module Crystal
       raise "Bug: #{self} doesn't implement macros"
     end
 
+    def hooks
+      nil
+    end
+
     def add_macro(a_def)
       raise "Bug: #{self} doesn't implement add_macro"
     end
@@ -619,10 +623,12 @@ module Crystal
 
     make_named_tuple DefKey, [restrictions, yields]
     make_named_tuple SortedDefKey, [name, length, yields]
+    make_named_tuple Hook, [kind, :macro]
 
     getter defs
     getter sorted_defs
     getter macros
+    getter hooks
 
     def add_def(a_def)
       a_def.owner = self
@@ -677,6 +683,15 @@ module Crystal
     end
 
     def add_macro(a_def)
+      case a_def.name
+      when "inherited"
+        return add_hook :inherited, a_def
+      when "included"
+        return add_hook :included, a_def
+      when "extended"
+        return add_hook :extended, a_def
+      end
+
       macros = (@macros ||= {} of String => Hash(Int32, Macro))
       hash = (macros[a_def.name] ||= {} of Int32 => Macro)
 
@@ -687,6 +702,15 @@ module Crystal
       end
     end
 
+    def add_hook(kind, a_def)
+      if a_def.args.length != 0
+        raise "macro '#{kind}' cannot have arguments"
+      end
+
+      hooks = @hooks ||= [] of Hook
+      hooks << Hook.new(kind, a_def)
+    end
+
     def filter_by_responds_to(name)
       has_def?(name) ? self : nil
     end
-- 
GitLab