diff --git a/spec/compiler/codegen/macro_spec.cr b/spec/compiler/codegen/macro_spec.cr index a7aaf027fee610d948d0bd8a4a13328c5f6da099..db12b833fdbad15f473c97fd37fb36c83180a623 100755 --- a/spec/compiler/codegen/macro_spec.cr +++ b/spec/compiler/codegen/macro_spec.cr @@ -288,4 +288,51 @@ describe "Code gen: macro" do foo(1) )).to_i.should eq(3) end + + it "does inherited macro" do + run(" + class Foo + macro inherited + $x = 1 + end + end + + class Bar < Foo + end + + $x + ").to_i.should eq(1) + end + + it "does included macro" do + run(" + module Foo + macro included + $x = 1 + end + end + + class Bar + include Foo + end + + $x + ").to_i.should eq(1) + end + + it "does extended macro" do + run(" + module Foo + macro extended + $x = 1 + end + end + + class Bar + extend Foo + end + + $x + ").to_i.should eq(1) + end end diff --git a/spec/compiler/type_inference/macro_spec.cr b/spec/compiler/type_inference/macro_spec.cr index df0b192b2824d7fc32302be687edd46240e0b713..90b15fc85342430b8ac6d514272d13e40b631b41 100755 --- a/spec/compiler/type_inference/macro_spec.cr +++ b/spec/compiler/type_inference/macro_spec.cr @@ -133,4 +133,57 @@ describe "Type inference: macro" do foo(1) ), "wrong number of arguments for macro 'foo' (1 for 0)" end + + it "does inherited macro" do + assert_type(" + class Foo + macro inherited + def self.{{@name.downcase.id}} + 1 + end + end + end + + class Bar < Foo + end + + Bar.bar + ") { int32 } + end + + it "does included macro" do + assert_type(" + module Foo + macro included + def self.{{@name.downcase.id}} + 1 + end + end + end + + class Bar + include Foo + end + + Bar.bar + ") { int32 } + end + + it "does extended macro" do + assert_type(" + module Foo + macro extended + def self.{{@name.downcase.id}} + 1 + end + end + end + + class Bar + extend Foo + end + + Bar.bar + ") { int32 } + end end diff --git a/src/compiler/crystal/codegen.cr b/src/compiler/crystal/codegen.cr index 6e3cfd0b003bd1733b49203f5d917aff5d7ae9dc..475896149a69b2f65aebf69aad7adc62ae8ce2e2 100644 --- a/src/compiler/crystal/codegen.cr +++ b/src/compiler/crystal/codegen.cr @@ -394,6 +394,8 @@ module Crystal end def visit(node : ClassDef) + node.runtime_initializers.try &.each &.accept self + accept node.body @last = llvm_nil false @@ -431,11 +433,15 @@ module Crystal end def visit(node : Include) + node.runtime_initializers.try &.each &.accept self + @last = llvm_nil false end def visit(node : Extend) + node.runtime_initializers.try &.each &.accept self + @last = llvm_nil false end diff --git a/src/compiler/crystal/macros.cr b/src/compiler/crystal/macros.cr index ee82fa7dd5fe7ce06d748b5e7cb8a97194a1052b..b5881ff356f46d5bee69dc0aad771a94dc43eb5c 100644 --- a/src/compiler/crystal/macros.cr +++ b/src/compiler/crystal/macros.cr @@ -1,5 +1,17 @@ module Crystal class Program + def push_def_macro(def) + @def_macros << def + end + + def expand_macro(scope : Type, a_macro, call) + @macro_expander.expand scope, a_macro, call + end + + def expand_macro(scope : Type, node) + @macro_expander.expand scope, node + end + def expand_def_macros until @def_macros.empty? def_macro = @def_macros.pop diff --git a/src/compiler/crystal/program.cr b/src/compiler/crystal/program.cr index 5bacb93e09f9f9c36c139a808620aa0df73b759c..02ca32a1b41137c2ecc4f8dad1e730461ed34daf 100644 --- a/src/compiler/crystal/program.cr +++ b/src/compiler/crystal/program.cr @@ -137,18 +137,6 @@ module Crystal flags end - def push_def_macro(def) - @def_macros << def - end - - def expand_macro(scope : Type, a_macro, call) - @macro_expander.expand scope, a_macro, call - end - - def expand_macro(scope : Type, node) - @macro_expander.expand scope, node - end - def self.exec(command) Pipe.open(command, "r") do |pipe| pipe.gets.try &.strip diff --git a/src/compiler/crystal/type_inference.cr b/src/compiler/crystal/type_inference.cr index 57cef96e5332eacac73765362108086cb82f2e55..f0db48df281bfdbdc7cf73acc1c89ac9449c1fda 100644 --- a/src/compiler/crystal/type_inference.cr +++ b/src/compiler/crystal/type_inference.cr @@ -434,7 +434,12 @@ module Crystal end def visit(node : Macro) - current_type.metaclass.add_macro node + begin + current_type.metaclass.add_macro node + rescue ex + node.raise ex.message + end + node.set_type @mod.nil false end @@ -894,6 +899,8 @@ module Crystal end end + created_new_type = false + if type unless type.is_a?(ClassType) node.raise "#{name} is not a #{node.struct ? "struct" : "class"}, it's a #{type.type_desc}" @@ -911,7 +918,7 @@ module Crystal node_superclass.not_nil!.raise "#{superclass} is not a class, it's a #{superclass.type_desc}" end - needs_force_add_subclass = true + created_new_type = true if type_vars = node.type_vars type = GenericClassType.new @mod, scope, name, superclass, type_vars, false else @@ -923,10 +930,15 @@ module Crystal end @types.push type + + if created_new_type + run_hooks(superclass.metaclass, type, :inherited, node) + end + node.body.accept self @types.pop - if needs_force_add_subclass + if created_new_type raise "Bug" unless type.is_a?(InheritableClass) type.force_add_subclass end @@ -936,6 +948,21 @@ module Crystal false end + def run_hooks(type_with_hooks, current_type, kind, node) + hooks = type_with_hooks.hooks + return unless hooks + + hooks.each do |hook| + next if hook.kind != kind + + expanded = expand_macro(hook.macro, node) do + @mod.expand_macro current_type, hook.macro.body + end + expanded.accept self + node.add_runtime_initializer(expanded) + end + end + def visit(node : ModuleDef) if node.name.names.length == 1 && !node.name.global scope = current_type @@ -980,7 +1007,7 @@ module Crystal end def visit(node : Include) - include_in current_type, node.name + include_in current_type, node, :included node.type = @mod.nil @@ -988,7 +1015,7 @@ module Crystal end def visit(node : Extend) - include_in current_type.metaclass, node.name + include_in current_type.metaclass, node, :extended node.type = @mod.nil @@ -1925,7 +1952,8 @@ module Crystal false end - def include_in(current_type, node_name) + def include_in(current_type, node, kind) + node_name = node.name if node_name.is_a?(Generic) type = lookup_path_type(node_name.name) else @@ -1946,7 +1974,7 @@ module Crystal end mapping = Hash.zip(type.type_vars, node_name.type_vars) - current_type.include IncludedGenericModule.new(@mod, type, current_type, mapping) + module_to_include = IncludedGenericModule.new(@mod, type, current_type, mapping) else if type.is_a?(GenericModuleType) if current_type.is_a?(GenericType) @@ -1959,17 +1987,21 @@ module Crystal type.type_vars.zip(current_type.type_vars) do |type_var, current_type_var| mapping[type_var] = Path.new([current_type_var]) end - - current_type.include IncludedGenericModule.new(@mod, type, current_type, mapping) + module_to_include = IncludedGenericModule.new(@mod, type, current_type, mapping) else node_name.raise "#{type} is a generic module" end else - current_type.include type + module_to_include = type end end - rescue ex - node_name.raise ex.message + + begin + current_type.include module_to_include + run_hooks type.metaclass, current_type, kind, node + rescue ex + node_name.raise ex.message + end end def process_struct_or_union_def(node, klass) diff --git a/src/compiler/crystal/type_inference/ast.cr b/src/compiler/crystal/type_inference/ast.cr index 4dd728f70af19d6b24e030c82d81d46955a33ca0..cc8353da14cfcacbe64c754bd9c157a88422dbdb 100644 --- a/src/compiler/crystal/type_inference/ast.cr +++ b/src/compiler/crystal/type_inference/ast.cr @@ -445,4 +445,25 @@ module Crystal include ExpandableNode end {% end %} + + module RuntimeInitializable + getter runtime_initializers + + def add_runtime_initializer(node) + initializers = @runtime_initializers ||= [] of ASTNode + initializers << node + end + end + + class ClassDef + include RuntimeInitializable + end + + class Include + include RuntimeInitializable + end + + class Extend + include RuntimeInitializable + end end diff --git a/src/compiler/crystal/types.cr b/src/compiler/crystal/types.cr index 24b19d880a9109df6e07755986a2e09ebb611ccb..75b96aab2e9388923eb5098600db88366d656847 100644 --- a/src/compiler/crystal/types.cr +++ b/src/compiler/crystal/types.cr @@ -297,6 +297,10 @@ module Crystal raise "Bug: #{self} doesn't implement macros" end + def hooks + nil + end + def add_macro(a_def) raise "Bug: #{self} doesn't implement add_macro" end @@ -619,10 +623,12 @@ module Crystal make_named_tuple DefKey, [restrictions, yields] make_named_tuple SortedDefKey, [name, length, yields] + make_named_tuple Hook, [kind, :macro] getter defs getter sorted_defs getter macros + getter hooks def add_def(a_def) a_def.owner = self @@ -677,6 +683,15 @@ module Crystal end def add_macro(a_def) + case a_def.name + when "inherited" + return add_hook :inherited, a_def + when "included" + return add_hook :included, a_def + when "extended" + return add_hook :extended, a_def + end + macros = (@macros ||= {} of String => Hash(Int32, Macro)) hash = (macros[a_def.name] ||= {} of Int32 => Macro) @@ -687,6 +702,15 @@ module Crystal end end + def add_hook(kind, a_def) + if a_def.args.length != 0 + raise "macro '#{kind}' cannot have arguments" + end + + hooks = @hooks ||= [] of Hook + hooks << Hook.new(kind, a_def) + end + def filter_by_responds_to(name) has_def?(name) ? self : nil end