diff --git a/spec/compiler/macro/macro_expander_spec.cr b/spec/compiler/macro/macro_expander_spec.cr index 5017830b4e88eb5a22dde9d9b8bb1cb47efdd5ae..15213c7222335fc576d4a5843ebe26f1fe7e827d 100644 --- a/spec/compiler/macro/macro_expander_spec.cr +++ b/spec/compiler/macro/macro_expander_spec.cr @@ -54,6 +54,10 @@ describe "MacroExpander" do assert_macro "", %({{{1, 2, 3}}}), [] of ASTNode, %({1, 2, 3}) end + it "expands macro with string interpolation" do + assert_macro "", "{{ \"hello\#{1 == 1}world\" }}", [] of ASTNode, "hellotrueworld" + end + it "expands macro with var sustitution" do assert_macro "x", "{{x}}", [Var.new("hello")] of ASTNode, "hello" end @@ -202,6 +206,42 @@ describe "MacroExpander" do assert_macro "", %({{[1, 2, 3].empty?}}), [] of ASTNode, "false" end + it "executes array join" do + assert_macro "", %({{[1, 2, 3].join ", "}}), [] of ASTNode, "1, 2, 3" + end + + it "executes array join with strings" do + assert_macro "", %({{["a", "b"].join ", "}}), [] of ASTNode, "a, b" + end + + it "executes array map" do + assert_macro "", %({{[1, 2, 3].map { |e| e == 2 }}}), [] of ASTNode, "[false, true, false]" + end + + it "executes array map with arg" do + assert_macro "x", %({{x.map { |e| e }}}), [ArrayLiteral.new([Call.new(nil, "hello")] of ASTNode)] of ASTNode, "[hello]" + end + + it "executes array select" do + assert_macro "", %({{[1, 2, 3].select { |e| e == 1 }}}), [] of ASTNode, "[1]" + end + + it "executes array any? (true)" do + assert_macro "", %({{[1, 2, 3].any? { |e| e == 1 }}}), [] of ASTNode, "true" + end + + it "executes array any? (false)" do + assert_macro "", %({{[1, 2, 3].any? { |e| e == 4 }}}), [] of ASTNode, "false" + end + + it "executes array all? (true)" do + assert_macro "", %({{[1, 1, 1].all? { |e| e == 1 }}}), [] of ASTNode, "true" + end + + it "executes array all? (false)" do + assert_macro "", %({{[1, 2, 1].all? { |e| e == 1 }}}), [] of ASTNode, "false" + end + it "executes hash length" do assert_macro "", %({{{a: 1, b: 3}.length}}), [] of ASTNode, "2" end diff --git a/src/compiler/crystal/ast.cr b/src/compiler/crystal/ast.cr index ed118a4fe176e0db8862b82d31d647044a8d9ac1..42e87d0640da152b5e85eacb1f50f2d512744b94 100644 --- a/src/compiler/crystal/ast.cr +++ b/src/compiler/crystal/ast.cr @@ -1972,6 +1972,22 @@ module Crystal TupleIndexer.new(index) end end + + # Ficticious node to wrap a call inside a macro + class MacroCallWrapper < ASTNode + property call + + def initialize(@call) + end + + def to_macro_id + call.to_macro_id + end + + def clone_without_location + self + end + end end require "to_s" diff --git a/src/compiler/crystal/macro_expander.cr b/src/compiler/crystal/macro_expander.cr index ce90bf7569d7a36da123346feb0f77852f9f95b1..0e6d22e7d498ca3e12d98e405acfc3a0bb4bb62a 100644 --- a/src/compiler/crystal/macro_expander.cr +++ b/src/compiler/crystal/macro_expander.cr @@ -16,10 +16,12 @@ module Crystal end class MacroVisitor < Visitor + getter last + def self.new(mod, a_macro, call) vars = {} of String => ASTNode a_macro.args.zip(call.args) do |macro_arg, call_arg| - vars[macro_arg.name] = call_arg + vars[macro_arg.name] = call_arg.to_macro_var end new(mod, vars) @@ -30,6 +32,15 @@ module Crystal @last = Nop.new end + def define_var(name, value) + @vars[name] = value + end + + def accept(node) + node.accept self + @last + end + def visit(node : Expressions) node.expressions.each &.accept self false @@ -56,6 +67,20 @@ module Crystal end end + def visit(node : StringInterpolation) + @last = StringLiteral.new(String.build do |str| + node.expressions.each do |exp| + if exp.is_a?(StringLiteral) + str << exp.value + else + exp.accept self + str << @last.to_macro_id + end + end + end) + false + end + def visit(node : MacroIf) node.cond.accept self @@ -128,7 +153,7 @@ module Crystal @last end - @last = receiver.interpret(node.name, args) + @last = receiver.interpret(node.name, args, node.block, self) else # no receiver: special calls execute_special_call node @@ -215,6 +240,10 @@ module Crystal @last = node end + def visit(node : MacroCallWrapper) + @last = node + end + def visit(node : ASTNode) node.raise "can't execute this in a macro" end @@ -230,11 +259,15 @@ module Crystal to_s end + def to_macro_var + self + end + def truthy? true end - def interpret(method, args) + def interpret(method, args, block, interpreter) case method when "stringify" unless args.length == 0 @@ -260,14 +293,21 @@ module Crystal end end - def interpret_argumentless_method(method, args) - unless args.length == 0 - raise "wrong number of arguments for #{method} (#{args.length} for 0)" - end - + def interpret_argless_method(method, args) + interpret_check_args_length method, args, 0 yield end + def interpret_one_arg_method(method, args) + interpret_check_args_length method, args, 1 + yield args.first + end + + def interpret_check_args_length(method, args, length) + unless args.length == length + raise "wrong number of arguments for #{method} (#{args.length} for #{length})" + end + end end class NilLiteral @@ -291,7 +331,7 @@ module Crystal end class NumberLiteral - def interpret(method, args) + def interpret(method, args, block, interpreter) case method when ">" compare_to(args.first) { |me, other| me > other } @@ -324,16 +364,16 @@ module Crystal @value end - def interpret(method, args) + def interpret(method, args, block, interpreter) case method when "downcase" - interpret_argumentless_method(method, args) { StringLiteral.new(@value.downcase) } + interpret_argless_method(method, args) { StringLiteral.new(@value.downcase) } when "empty?" - interpret_argumentless_method(method, args) { BoolLiteral.new(@value.empty?) } + interpret_argless_method(method, args) { BoolLiteral.new(@value.empty?) } when "length" - interpret_argumentless_method(method, args) { NumberLiteral.new(@value.length, :i32) } + interpret_argless_method(method, args) { NumberLiteral.new(@value.length, :i32) } when "lines" - interpret_argumentless_method(method, args) { create_array_literal_from_values(@value.lines) } + interpret_argless_method(method, args) { create_array_literal_from_values(@value.lines) } when "split" case args.length when 0 @@ -349,9 +389,9 @@ module Crystal raise "wrong number of arguments for split (#{args.length} for 0, 1)" end when "strip" - interpret_argumentless_method(method, args) { StringLiteral.new(@value.strip) } + interpret_argless_method(method, args) { StringLiteral.new(@value.strip) } when "upcase" - interpret_argumentless_method(method, args) { StringLiteral.new(@value.upcase) } + interpret_argless_method(method, args) { StringLiteral.new(@value.upcase) } else super end @@ -363,12 +403,64 @@ module Crystal end class ArrayLiteral - def interpret(method, args) + def interpret(method, args, block, interpreter) case method + when "any?" + interpret_argless_method(method, args) do + raise "any? expects a block" unless block + + block_arg = block.args.first? + + BoolLiteral.new(elements.any? do |elem| + block_value = interpreter.accept elem.to_macro_var + interpreter.define_var(block_arg.name, block_value) if block_arg + interpreter.accept(block.body).truthy? + end) + end + when "all?" + interpret_argless_method(method, args) do + raise "all? expects a block" unless block + + block_arg = block.args.first? + + BoolLiteral.new(elements.all? do |elem| + block_value = interpreter.accept elem.to_macro_var + interpreter.define_var(block_arg.name, block_value) if block_arg + interpreter.accept(block.body).truthy? + end) + end when "empty?" - interpret_argumentless_method(method, args) { BoolLiteral.new(elements.empty?) } + interpret_argless_method(method, args) { BoolLiteral.new(elements.empty?) } + when "join" + interpret_one_arg_method(method, args) do |arg| + StringLiteral.new(elements.map(&.to_macro_id).join arg.to_macro_id) + end when "length" - interpret_argumentless_method(method, args) { NumberLiteral.new(elements.length, :i32) } + interpret_argless_method(method, args) { NumberLiteral.new(elements.length, :i32) } + when "map" + interpret_argless_method(method, args) do + raise "map expects a block" unless block + + block_arg = block.args.first? + + ArrayLiteral.new(elements.map do |elem| + block_value = interpreter.accept elem.to_macro_var + interpreter.define_var(block_arg.name, block_value) if block_arg + interpreter.accept block.body + end) + end + when "select" + interpret_argless_method(method, args) do + raise "select expects a block" unless block + + block_arg = block.args.first? + + ArrayLiteral.new(elements.select do |elem| + block_value = interpreter.accept elem.to_macro_var + interpreter.define_var(block_arg.name, block_value) if block_arg + interpreter.accept(block.body).truthy? + end) + end when "[]" case args.length when 1 @@ -394,12 +486,12 @@ module Crystal end class HashLiteral - def interpret(method, args) + def interpret(method, args, block, interpreter) case method when "empty?" - interpret_argumentless_method(method, args) { BoolLiteral.new(keys.empty?) } + interpret_argless_method(method, args) { BoolLiteral.new(keys.empty?) } when "length" - interpret_argumentless_method(method, args) { NumberLiteral.new(keys.length, :i32) } + interpret_argless_method(method, args) { NumberLiteral.new(keys.length, :i32) } when "[]" case args.length when 1 @@ -421,12 +513,12 @@ module Crystal end class TupleLiteral - def interpret(method, args) + def interpret(method, args, block, interpreter) case method when "empty?" - interpret_argumentless_method(method, args) { BoolLiteral.new(elements.empty?) } + interpret_argless_method(method, args) { BoolLiteral.new(elements.empty?) } when "length" - interpret_argumentless_method(method, args) { NumberLiteral.new(elements.length, :i32) } + interpret_argless_method(method, args) { NumberLiteral.new(elements.length, :i32) } when "[]" case args.length when 1 @@ -471,6 +563,10 @@ module Crystal to_s end end + + def to_macro_var + MacroCallWrapper.new(self) + end end class InstanceVar diff --git a/src/compiler/crystal/to_s.cr b/src/compiler/crystal/to_s.cr index 8c7179169f5618aeee6d6bc2b816fa871b208aff..4f869ba3ac8b11d5480bd323984826cd81ab6af8 100644 --- a/src/compiler/crystal/to_s.cr +++ b/src/compiler/crystal/to_s.cr @@ -301,6 +301,11 @@ module Crystal !(node.obj && node.args.empty?) || node.block_arg end + def visit(node : MacroCallWrapper) + @str << node.call.to_macro_id + false + end + def keyword(str) str end diff --git a/src/compiler/crystal/transformer.cr b/src/compiler/crystal/transformer.cr index 287150c98ea4e026b702c3dcd62e630f50fc0cf7..2f46c6218fc192ce6087d01fc2c5b93fadf54f28 100644 --- a/src/compiler/crystal/transformer.cr +++ b/src/compiler/crystal/transformer.cr @@ -525,6 +525,10 @@ module Crystal node end + def transform(node : MacroCallWrapper) + node + end + def transform_many(exps) exps.map! { |exp| exp.transform(self) } if exps end