Skip to content

Commit

Permalink
Support explicit return types in ProcLiterals (#11402)
Browse files Browse the repository at this point in the history
A `ProcLiteral` with an explicit return type can now be used to type instance and class variables:

```crystal
class Foo
  getter foo = ->(x : Int32) : Int32 {
    return x // 2 if x.even? # okay
    x * 3 + 1
  }

  # current way:
  getter bar = Proc(Int32, Int32).new { |x|
    next x // 2 if x.even?
    x * 3 + 1
  }
end
```

The corresponding macro method accessor is `ProcLiteral#return_type`.
  • Loading branch information
HertzDevil authored Nov 15, 2021
1 parent e4c3c75 commit fe810b3
Show file tree
Hide file tree
Showing 16 changed files with 158 additions and 6 deletions.
8 changes: 8 additions & 0 deletions spec/compiler/codegen/proc_spec.cr
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,14 @@ describe "Code gen: proc" do
run("f = ->(x : Int32) { x &+ 1 }; f.call(41)").to_i.should eq(42)
end

it "call proc literal with return type" do
run(<<-CR).to_b.should be_true
f = -> : Int32 | Float64 { 1 }
x = f.call
x.is_a?(Int32) && x == 1
CR
end

it "call proc pointer" do
run("def foo; 1; end; x = ->foo; x.call").to_i.should eq(1)
end
Expand Down
7 changes: 7 additions & 0 deletions spec/compiler/formatter/formatter_spec.cr
Original file line number Diff line number Diff line change
Expand Up @@ -780,13 +780,20 @@ describe Crystal::Formatter do
assert_format "foo = 1\n->foo.bar=(Int32)"
assert_format "foo = 1\n->foo.[](Int32)"
assert_format "foo = 1\n->foo.[]=(Int32)"

assert_format "->{ x }"
assert_format "->{\nx\n}", "->{\n x\n}"
assert_format "->do\nx\nend", "->do\n x\nend"
assert_format "->( ){ x }", "->{ x }"
assert_format "->() do x end", "->do x end"
assert_format "->( x , y ) { x }", "->(x, y) { x }"
assert_format "->( x : Int32 , y ) { x }", "->(x : Int32, y) { x }"
assert_format "->{}"

assert_format "-> : Int32 {}"
assert_format "->\n:\nInt32\n{\n}", "-> : Int32 {\n}"
assert_format "->( x )\n:\nInt32 { }", "->(x) : Int32 {}"
assert_format "->: Int32 do\nx\nend", "-> : Int32 do\n x\nend"

{:+, :-, :*, :/, :^, :>>, :<<, :|, :&, :&+, :&-, :&*, :&**}.each do |sym|
assert_format ":#{sym}"
Expand Down
5 changes: 5 additions & 0 deletions spec/compiler/macro/macro_methods_spec.cr
Original file line number Diff line number Diff line change
Expand Up @@ -2081,6 +2081,11 @@ module Crystal
it "executes args" do
assert_macro %({{x.args}}), "[z]", {x: ProcLiteral.new(Def.new("->", [Arg.new("z")]))}
end

it "executes return_type" do
assert_macro %({{x.return_type}}), "Int32", {x: ProcLiteral.new(Def.new("->", return_type: "Int32".path))}
assert_macro %({{x.return_type}}), "", {x: ProcLiteral.new(Def.new("->"))}
end
end

describe "proc pointer methods" do
Expand Down
9 changes: 9 additions & 0 deletions spec/compiler/parser/parser_spec.cr
Original file line number Diff line number Diff line change
Expand Up @@ -1318,6 +1318,15 @@ module Crystal
it_parses "x = 1; ->{ x }", [Assign.new("x".var, 1.int32), ProcLiteral.new(Def.new("->", body: "x".var))]
it_parses "f ->{ a do\n end\n }", Call.new(nil, "f", ProcLiteral.new(Def.new("->", body: Call.new(nil, "a", block: Block.new))))

it_parses "-> : Int32 { }", ProcLiteral.new(Def.new("->", return_type: "Int32".path))
it_parses "->\n:\nInt32\n{\n}", ProcLiteral.new(Def.new("->", return_type: "Int32".path))
it_parses "->() : Int32 { }", ProcLiteral.new(Def.new("->", return_type: "Int32".path))
it_parses "->() : Int32 do end", ProcLiteral.new(Def.new("->", return_type: "Int32".path))
it_parses "->(x : Int32) : Int32 { }", ProcLiteral.new(Def.new("->", [Arg.new("x", restriction: "Int32".path)], return_type: "Int32".path))

assert_syntax_error "-> :Int32 { }", "a space is mandatory between ':' and return type"
assert_syntax_error "->() :Int32 { }", "a space is mandatory between ':' and return type"

%w(foo foo= foo? foo!).each do |method|
it_parses "->#{method}", ProcPointer.new(nil, method)
it_parses "foo = 1; ->foo.#{method}", [Assign.new("foo".var, 1.int32), ProcPointer.new("foo".var, method)]
Expand Down
2 changes: 2 additions & 0 deletions spec/compiler/parser/to_s_spec.cr
Original file line number Diff line number Diff line change
Expand Up @@ -188,4 +188,6 @@ describe "ASTNode#to_s" do
expect_to_s %[他.说("你好")]
expect_to_s %[他.说 = "你好"]
expect_to_s %[あ.い, う.え.お = 1, 2]
expect_to_s "-> : Int32 do\nend"
expect_to_s "->(x : Int32, y : Bool) : Char do\n 'a'\nend"
end
16 changes: 16 additions & 0 deletions spec/compiler/semantic/instance_var_spec.cr
Original file line number Diff line number Diff line change
Expand Up @@ -775,6 +775,22 @@ describe "Semantic: instance var" do
)) { named_tuple_of({"x": int32, "y": string}) }
end

it "infers type from proc literal with return type" do
assert_type(<<-CR) { proc_of([int32, bool, string]) }
class Foo
def initialize
@x = ->(x : Int32, y : Bool) : String { "" }
end
def x
@x
end
end
Foo.new.x
CR
end

it "infers type from new expression" do
assert_type(%(
class Bar
Expand Down
34 changes: 34 additions & 0 deletions spec/compiler/semantic/proc_spec.cr
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,18 @@ describe "Semantic: proc" do
assert_type("f = ->(x : Int32) { x }; f.call(1)", inject_primitives: true) { int32 }
end

it "types proc literal with return type (1)" do
assert_type("->(x : Int32) : Int32 { x }") { proc_of(int32, int32) }
end

it "types proc literal with return type (2)" do
assert_type("-> : Int32 | String { 1 }") { proc_of(union_of int32, string) }
end

it "types proc call with return type" do
assert_type("x = -> : Int32 | String { 1 }; x.call()", inject_primitives: true) { union_of int32, string }
end

it "types proc pointer" do
assert_type("def foo; 1; end; ->foo") { proc_of(int32) }
end
Expand Down Expand Up @@ -227,6 +239,14 @@ describe "Semantic: proc" do
"can't cast Proc(Int32, Float64) to Proc(Float64, Float64)", inject_primitives: true
end

it "errors if inferred return type doesn't match return type restriction (1)" do
assert_error "-> : Int32 { true }", "expected Proc to return Int32, not Bool"
end

it "errors if inferred return type doesn't match return type restriction (2)" do
assert_error "->(x : Int32) : Int32 { x || 'a' }", "expected Proc to return Int32, not (Char | Int32)"
end

it "types proc literal hard type inference (1)" do
assert_type(%(
require "prelude"
Expand Down Expand Up @@ -763,6 +783,13 @@ describe "Semantic: proc" do
"can't use #{type} as a Proc argument type"
end

it "disallows #{type} in proc return types" do
assert_error %(
-> : #{type} { }
),
"can't use #{type} as a Proc argument type"
end

it "disallows #{type} in captured block" do
assert_error %(
def foo(&block : #{type} ->)
Expand Down Expand Up @@ -795,6 +822,13 @@ describe "Semantic: proc" do
"can't use Object as a Proc argument type"
end

it "disallows Class in proc return types" do
assert_error %(
-> : Class { }
),
"can't use Class as a Proc argument type"
end

it "disallows Class in captured block" do
assert_error %(
def foo(&block : Class ->)
Expand Down
4 changes: 4 additions & 0 deletions src/compiler/crystal/macros.cr
Original file line number Diff line number Diff line change
Expand Up @@ -1480,6 +1480,10 @@ module Crystal::Macros
# Returns the body of this proc.
def body : ASTNode
end

# Returns the return type of this proc, if specified.
def return_type : ASTNode | Nop
end
end

# A proc pointer, like `->my_var.some_method(String)`
Expand Down
2 changes: 1 addition & 1 deletion src/compiler/crystal/macros/methods.cr
Original file line number Diff line number Diff line change
Expand Up @@ -1188,7 +1188,7 @@ module Crystal
class ProcLiteral
def interpret(method : String, args : Array(ASTNode), named_args : Hash(String, ASTNode)?, block : Crystal::Block?, interpreter : Crystal::MacroInterpreter, name_loc : Location?)
case method
when "args", "body"
when "args", "body", "return_type"
@def.interpret(method, args, named_args, block, interpreter, location)
else
super
Expand Down
3 changes: 2 additions & 1 deletion src/compiler/crystal/semantic/bindings.cr
Original file line number Diff line number Diff line change
Expand Up @@ -465,6 +465,7 @@ module Crystal
class ProcLiteral
property? force_nil = false
property expected_return_type : Type?
property? from_block = false

def update(from = nil)
return unless self.def.args.all? &.type?
Expand All @@ -475,7 +476,7 @@ module Crystal

expected_return_type = @expected_return_type
if expected_return_type && !expected_return_type.nil_type? && !return_type.implements?(expected_return_type)
raise "expected block to return #{expected_return_type.devirtualize}, not #{return_type}"
raise "expected #{from_block? ? "block" : "Proc"} to return #{expected_return_type.devirtualize}, not #{return_type}"
end

types << (expected_return_type || return_type)
Expand Down
1 change: 1 addition & 0 deletions src/compiler/crystal/semantic/call.cr
Original file line number Diff line number Diff line change
Expand Up @@ -929,6 +929,7 @@ class Crystal::Call

fun_literal = ProcLiteral.new(a_def).at(self)
fun_literal.expected_return_type = output_type if output_type
fun_literal.from_block = true
fun_literal.force_nil = true unless output
fun_literal.accept parent_visitor
end
Expand Down
11 changes: 11 additions & 0 deletions src/compiler/crystal/semantic/main_visitor.cr
Original file line number Diff line number Diff line change
Expand Up @@ -1148,6 +1148,17 @@ module Crystal
meta_vars[arg.name] = meta_var
end

if return_type = node.def.return_type
@in_type_args += 1
return_type.accept self
@in_type_args -= 1
check_not_a_constant(return_type)

def_type = return_type.type
MainVisitor.check_type_allowed_as_proc_argument(node, def_type)
node.expected_return_type = def_type.virtual_type
end

node.bind_to node.def
node.def.bind_to node.def.body
node.def.vars = meta_vars
Expand Down
26 changes: 26 additions & 0 deletions src/compiler/crystal/semantic/type_guess_visitor.cr
Original file line number Diff line number Diff line change
Expand Up @@ -601,6 +601,32 @@ module Crystal
end
end

def guess_type(node : ProcLiteral)
output = node.def.return_type
return nil unless output

types = nil

node.def.args.each do |input|
restriction = input.restriction
return nil unless restriction

input_type = lookup_type?(restriction)
return nil unless input_type

types ||= [] of Type
types << input_type.virtual_type
end

output_type = lookup_type?(output)
return nil unless output_type

types ||= [] of Type
types << output_type.virtual_type

program.proc_of(types)
end

def guess_type(node : Call)
if expanded = node.expanded
return guess_type(expanded)
Expand Down
22 changes: 19 additions & 3 deletions src/compiler/crystal/syntax/parser.cr
Original file line number Diff line number Diff line change
Expand Up @@ -1798,8 +1798,14 @@ module Crystal

next_token_skip_space_or_newline

unless @token.type == :"{" || @token.type == :"(" || @token.keyword?(:do)
return parse_fun_pointer
case @token.type
when :SYMBOL
# -> :T { }
raise "a space is mandatory between ':' and return type", @token
when :"{", :"(", :":"
# do nothing
else
return parse_fun_pointer unless @token.keyword?(:do)
end

args = [] of Arg
Expand All @@ -1817,6 +1823,16 @@ module Crystal
next_token_skip_space_or_newline
end

case @token.type
when :SYMBOL
# ->() :T { }
raise "a space is mandatory between ':' and return type", @token
when :":"
next_token_skip_space_or_newline
return_type = parse_bare_proc_type
skip_space_or_newline
end

with_lexical_var_scope do
push_vars args

Expand All @@ -1838,7 +1854,7 @@ module Crystal
unexpected_token
end

a_def = Def.new("->", args, body).at(location).at_end(end_location)
a_def = Def.new("->", args, body, return_type: return_type).at(location).at_end(end_location)
ProcLiteral.new(a_def).at(location).at_end(end_location)
end
end
Expand Down
4 changes: 4 additions & 0 deletions src/compiler/crystal/syntax/to_s.cr
Original file line number Diff line number Diff line change
Expand Up @@ -616,6 +616,10 @@ module Crystal
node.def.args.join(@str, ", ", &.accept self)
@str << ')'
end
if return_type = node.def.return_type
@str << " : "
return_type.accept self
end
@str << ' '
@str << keyword("do")
newline
Expand Down
10 changes: 9 additions & 1 deletion src/compiler/crystal/tools/formatter.cr
Original file line number Diff line number Diff line change
Expand Up @@ -4131,7 +4131,15 @@ module Crystal
next_token_skip_space_or_newline
end

write " " unless a_def.args.empty?
if return_type = a_def.return_type
check :":"
write " : "
next_token_skip_space_or_newline
accept return_type
next_token_skip_space_or_newline
end

write " " unless a_def.args.empty? && !return_type

is_do = false
if @token.keyword?(:do)
Expand Down

0 comments on commit fe810b3

Please sign in to comment.