diff --git a/spec/compiler/codegen/proc_spec.cr b/spec/compiler/codegen/proc_spec.cr index d53ab169dad9..2c4998d116ce 100644 --- a/spec/compiler/codegen/proc_spec.cr +++ b/spec/compiler/codegen/proc_spec.cr @@ -9,6 +9,14 @@ describe "Code gen: proc" do run("f = ->(x : Int32) { x &+ 1 }; f.call(41)").to_i.should eq(42) end + it "call proc literal with return type" do + run(<<-CR).to_b.should be_true + f = -> : Int32 | Float64 { 1 } + x = f.call + x.is_a?(Int32) && x == 1 + CR + end + it "call proc pointer" do run("def foo; 1; end; x = ->foo; x.call").to_i.should eq(1) end diff --git a/spec/compiler/formatter/formatter_spec.cr b/spec/compiler/formatter/formatter_spec.cr index 90f347238de3..6520b56478d1 100644 --- a/spec/compiler/formatter/formatter_spec.cr +++ b/spec/compiler/formatter/formatter_spec.cr @@ -780,6 +780,7 @@ describe Crystal::Formatter do assert_format "foo = 1\n->foo.bar=(Int32)" assert_format "foo = 1\n->foo.[](Int32)" assert_format "foo = 1\n->foo.[]=(Int32)" + assert_format "->{ x }" assert_format "->{\nx\n}", "->{\n x\n}" assert_format "->do\nx\nend", "->do\n x\nend" @@ -787,6 +788,12 @@ describe Crystal::Formatter do assert_format "->() do x end", "->do x end" assert_format "->( x , y ) { x }", "->(x, y) { x }" assert_format "->( x : Int32 , y ) { x }", "->(x : Int32, y) { x }" + assert_format "->{}" + + assert_format "-> : Int32 {}" + assert_format "->\n:\nInt32\n{\n}", "-> : Int32 {\n}" + assert_format "->( x )\n:\nInt32 { }", "->(x) : Int32 {}" + assert_format "->: Int32 do\nx\nend", "-> : Int32 do\n x\nend" {:+, :-, :*, :/, :^, :>>, :<<, :|, :&, :&+, :&-, :&*, :&**}.each do |sym| assert_format ":#{sym}" diff --git a/spec/compiler/macro/macro_methods_spec.cr b/spec/compiler/macro/macro_methods_spec.cr index b52ae9e3b3f3..f7d450ce26e7 100644 --- a/spec/compiler/macro/macro_methods_spec.cr +++ b/spec/compiler/macro/macro_methods_spec.cr @@ -2081,6 +2081,11 @@ module Crystal it "executes args" do assert_macro %({{x.args}}), "[z]", {x: ProcLiteral.new(Def.new("->", [Arg.new("z")]))} end + + it "executes return_type" do + assert_macro %({{x.return_type}}), "Int32", {x: ProcLiteral.new(Def.new("->", return_type: "Int32".path))} + assert_macro %({{x.return_type}}), "", {x: ProcLiteral.new(Def.new("->"))} + end end describe "proc pointer methods" do diff --git a/spec/compiler/parser/parser_spec.cr b/spec/compiler/parser/parser_spec.cr index 8e4216e52ccb..2b14541c2e43 100644 --- a/spec/compiler/parser/parser_spec.cr +++ b/spec/compiler/parser/parser_spec.cr @@ -1318,6 +1318,15 @@ module Crystal it_parses "x = 1; ->{ x }", [Assign.new("x".var, 1.int32), ProcLiteral.new(Def.new("->", body: "x".var))] it_parses "f ->{ a do\n end\n }", Call.new(nil, "f", ProcLiteral.new(Def.new("->", body: Call.new(nil, "a", block: Block.new)))) + it_parses "-> : Int32 { }", ProcLiteral.new(Def.new("->", return_type: "Int32".path)) + it_parses "->\n:\nInt32\n{\n}", ProcLiteral.new(Def.new("->", return_type: "Int32".path)) + it_parses "->() : Int32 { }", ProcLiteral.new(Def.new("->", return_type: "Int32".path)) + it_parses "->() : Int32 do end", ProcLiteral.new(Def.new("->", return_type: "Int32".path)) + it_parses "->(x : Int32) : Int32 { }", ProcLiteral.new(Def.new("->", [Arg.new("x", restriction: "Int32".path)], return_type: "Int32".path)) + + assert_syntax_error "-> :Int32 { }", "a space is mandatory between ':' and return type" + assert_syntax_error "->() :Int32 { }", "a space is mandatory between ':' and return type" + %w(foo foo= foo? foo!).each do |method| it_parses "->#{method}", ProcPointer.new(nil, method) it_parses "foo = 1; ->foo.#{method}", [Assign.new("foo".var, 1.int32), ProcPointer.new("foo".var, method)] diff --git a/spec/compiler/parser/to_s_spec.cr b/spec/compiler/parser/to_s_spec.cr index 9b740a6ece96..cb35814d6e10 100644 --- a/spec/compiler/parser/to_s_spec.cr +++ b/spec/compiler/parser/to_s_spec.cr @@ -188,4 +188,6 @@ describe "ASTNode#to_s" do expect_to_s %[他.说("你好")] expect_to_s %[他.说 = "你好"] expect_to_s %[あ.い, う.え.お = 1, 2] + expect_to_s "-> : Int32 do\nend" + expect_to_s "->(x : Int32, y : Bool) : Char do\n 'a'\nend" end diff --git a/spec/compiler/semantic/instance_var_spec.cr b/spec/compiler/semantic/instance_var_spec.cr index e0fb8c83de5c..4300972205a9 100644 --- a/spec/compiler/semantic/instance_var_spec.cr +++ b/spec/compiler/semantic/instance_var_spec.cr @@ -775,6 +775,22 @@ describe "Semantic: instance var" do )) { named_tuple_of({"x": int32, "y": string}) } end + it "infers type from proc literal with return type" do + assert_type(<<-CR) { proc_of([int32, bool, string]) } + class Foo + def initialize + @x = ->(x : Int32, y : Bool) : String { "" } + end + + def x + @x + end + end + + Foo.new.x + CR + end + it "infers type from new expression" do assert_type(%( class Bar diff --git a/spec/compiler/semantic/proc_spec.cr b/spec/compiler/semantic/proc_spec.cr index 0d522a95ba48..8886d00cda62 100644 --- a/spec/compiler/semantic/proc_spec.cr +++ b/spec/compiler/semantic/proc_spec.cr @@ -21,6 +21,18 @@ describe "Semantic: proc" do assert_type("f = ->(x : Int32) { x }; f.call(1)", inject_primitives: true) { int32 } end + it "types proc literal with return type (1)" do + assert_type("->(x : Int32) : Int32 { x }") { proc_of(int32, int32) } + end + + it "types proc literal with return type (2)" do + assert_type("-> : Int32 | String { 1 }") { proc_of(union_of int32, string) } + end + + it "types proc call with return type" do + assert_type("x = -> : Int32 | String { 1 }; x.call()", inject_primitives: true) { union_of int32, string } + end + it "types proc pointer" do assert_type("def foo; 1; end; ->foo") { proc_of(int32) } end @@ -227,6 +239,14 @@ describe "Semantic: proc" do "can't cast Proc(Int32, Float64) to Proc(Float64, Float64)", inject_primitives: true end + it "errors if inferred return type doesn't match return type restriction (1)" do + assert_error "-> : Int32 { true }", "expected Proc to return Int32, not Bool" + end + + it "errors if inferred return type doesn't match return type restriction (2)" do + assert_error "->(x : Int32) : Int32 { x || 'a' }", "expected Proc to return Int32, not (Char | Int32)" + end + it "types proc literal hard type inference (1)" do assert_type(%( require "prelude" @@ -763,6 +783,13 @@ describe "Semantic: proc" do "can't use #{type} as a Proc argument type" end + it "disallows #{type} in proc return types" do + assert_error %( + -> : #{type} { } + ), + "can't use #{type} as a Proc argument type" + end + it "disallows #{type} in captured block" do assert_error %( def foo(&block : #{type} ->) @@ -795,6 +822,13 @@ describe "Semantic: proc" do "can't use Object as a Proc argument type" end + it "disallows Class in proc return types" do + assert_error %( + -> : Class { } + ), + "can't use Class as a Proc argument type" + end + it "disallows Class in captured block" do assert_error %( def foo(&block : Class ->) diff --git a/src/compiler/crystal/macros.cr b/src/compiler/crystal/macros.cr index 56fc0126f280..5f4ef695acea 100644 --- a/src/compiler/crystal/macros.cr +++ b/src/compiler/crystal/macros.cr @@ -1480,6 +1480,10 @@ module Crystal::Macros # Returns the body of this proc. def body : ASTNode end + + # Returns the return type of this proc, if specified. + def return_type : ASTNode | Nop + end end # A proc pointer, like `->my_var.some_method(String)` diff --git a/src/compiler/crystal/macros/methods.cr b/src/compiler/crystal/macros/methods.cr index 2d98f4a64264..80535fdc5575 100644 --- a/src/compiler/crystal/macros/methods.cr +++ b/src/compiler/crystal/macros/methods.cr @@ -1188,7 +1188,7 @@ module Crystal class ProcLiteral def interpret(method : String, args : Array(ASTNode), named_args : Hash(String, ASTNode)?, block : Crystal::Block?, interpreter : Crystal::MacroInterpreter, name_loc : Location?) case method - when "args", "body" + when "args", "body", "return_type" @def.interpret(method, args, named_args, block, interpreter, location) else super diff --git a/src/compiler/crystal/semantic/bindings.cr b/src/compiler/crystal/semantic/bindings.cr index 83a9f344787e..e745fe5109f4 100644 --- a/src/compiler/crystal/semantic/bindings.cr +++ b/src/compiler/crystal/semantic/bindings.cr @@ -465,6 +465,7 @@ module Crystal class ProcLiteral property? force_nil = false property expected_return_type : Type? + property? from_block = false def update(from = nil) return unless self.def.args.all? &.type? @@ -475,7 +476,7 @@ module Crystal expected_return_type = @expected_return_type if expected_return_type && !expected_return_type.nil_type? && !return_type.implements?(expected_return_type) - raise "expected block to return #{expected_return_type.devirtualize}, not #{return_type}" + raise "expected #{from_block? ? "block" : "Proc"} to return #{expected_return_type.devirtualize}, not #{return_type}" end types << (expected_return_type || return_type) diff --git a/src/compiler/crystal/semantic/call.cr b/src/compiler/crystal/semantic/call.cr index db6fda0bea70..2739eae08611 100644 --- a/src/compiler/crystal/semantic/call.cr +++ b/src/compiler/crystal/semantic/call.cr @@ -929,6 +929,7 @@ class Crystal::Call fun_literal = ProcLiteral.new(a_def).at(self) fun_literal.expected_return_type = output_type if output_type + fun_literal.from_block = true fun_literal.force_nil = true unless output fun_literal.accept parent_visitor end diff --git a/src/compiler/crystal/semantic/main_visitor.cr b/src/compiler/crystal/semantic/main_visitor.cr index e694eb2560ef..2d2fbafadaab 100644 --- a/src/compiler/crystal/semantic/main_visitor.cr +++ b/src/compiler/crystal/semantic/main_visitor.cr @@ -1148,6 +1148,17 @@ module Crystal meta_vars[arg.name] = meta_var end + if return_type = node.def.return_type + @in_type_args += 1 + return_type.accept self + @in_type_args -= 1 + check_not_a_constant(return_type) + + def_type = return_type.type + MainVisitor.check_type_allowed_as_proc_argument(node, def_type) + node.expected_return_type = def_type.virtual_type + end + node.bind_to node.def node.def.bind_to node.def.body node.def.vars = meta_vars diff --git a/src/compiler/crystal/semantic/type_guess_visitor.cr b/src/compiler/crystal/semantic/type_guess_visitor.cr index 2d4c954142af..6974073f2333 100644 --- a/src/compiler/crystal/semantic/type_guess_visitor.cr +++ b/src/compiler/crystal/semantic/type_guess_visitor.cr @@ -601,6 +601,32 @@ module Crystal end end + def guess_type(node : ProcLiteral) + output = node.def.return_type + return nil unless output + + types = nil + + node.def.args.each do |input| + restriction = input.restriction + return nil unless restriction + + input_type = lookup_type?(restriction) + return nil unless input_type + + types ||= [] of Type + types << input_type.virtual_type + end + + output_type = lookup_type?(output) + return nil unless output_type + + types ||= [] of Type + types << output_type.virtual_type + + program.proc_of(types) + end + def guess_type(node : Call) if expanded = node.expanded return guess_type(expanded) diff --git a/src/compiler/crystal/syntax/parser.cr b/src/compiler/crystal/syntax/parser.cr index f63e324f4383..289cd95bb699 100644 --- a/src/compiler/crystal/syntax/parser.cr +++ b/src/compiler/crystal/syntax/parser.cr @@ -1798,8 +1798,14 @@ module Crystal next_token_skip_space_or_newline - unless @token.type == :"{" || @token.type == :"(" || @token.keyword?(:do) - return parse_fun_pointer + case @token.type + when :SYMBOL + # -> :T { } + raise "a space is mandatory between ':' and return type", @token + when :"{", :"(", :":" + # do nothing + else + return parse_fun_pointer unless @token.keyword?(:do) end args = [] of Arg @@ -1817,6 +1823,16 @@ module Crystal next_token_skip_space_or_newline end + case @token.type + when :SYMBOL + # ->() :T { } + raise "a space is mandatory between ':' and return type", @token + when :":" + next_token_skip_space_or_newline + return_type = parse_bare_proc_type + skip_space_or_newline + end + with_lexical_var_scope do push_vars args @@ -1838,7 +1854,7 @@ module Crystal unexpected_token end - a_def = Def.new("->", args, body).at(location).at_end(end_location) + a_def = Def.new("->", args, body, return_type: return_type).at(location).at_end(end_location) ProcLiteral.new(a_def).at(location).at_end(end_location) end end diff --git a/src/compiler/crystal/syntax/to_s.cr b/src/compiler/crystal/syntax/to_s.cr index c5e03a2375b5..d576ebb359f6 100644 --- a/src/compiler/crystal/syntax/to_s.cr +++ b/src/compiler/crystal/syntax/to_s.cr @@ -616,6 +616,10 @@ module Crystal node.def.args.join(@str, ", ", &.accept self) @str << ')' end + if return_type = node.def.return_type + @str << " : " + return_type.accept self + end @str << ' ' @str << keyword("do") newline diff --git a/src/compiler/crystal/tools/formatter.cr b/src/compiler/crystal/tools/formatter.cr index b849449fb839..99e83a94f916 100644 --- a/src/compiler/crystal/tools/formatter.cr +++ b/src/compiler/crystal/tools/formatter.cr @@ -4131,7 +4131,15 @@ module Crystal next_token_skip_space_or_newline end - write " " unless a_def.args.empty? + if return_type = a_def.return_type + check :":" + write " : " + next_token_skip_space_or_newline + accept return_type + next_token_skip_space_or_newline + end + + write " " unless a_def.args.empty? && !return_type is_do = false if @token.keyword?(:do)