From 049dd0cf8b9de48c9825c0f6aea370ceac5e43be Mon Sep 17 00:00:00 2001 From: Alexandre Terrasa Date: Wed, 31 Jul 2024 11:26:19 -0400 Subject: [PATCH] Model accept either String or Type when a RBI type is expected Signed-off-by: Alexandre Terrasa Co-authored-by: Ufuk Kayserilioglu --- lib/rbi/model.rb | 38 ++++++++++++++---------- lib/rbi/parser.rb | 2 +- lib/rbi/printer.rb | 14 ++++----- lib/rbi/rewriters/attr_to_methods.rb | 4 +-- test/rbi/model_test.rb | 43 ++++++++++++++++++++++++++++ 5 files changed, 76 insertions(+), 25 deletions(-) diff --git a/lib/rbi/model.rb b/lib/rbi/model.rb index f44cbbad..0c530019 100644 --- a/lib/rbi/model.rb +++ b/lib/rbi/model.rb @@ -574,7 +574,7 @@ def add_block_param(name) sig do params( params: T::Array[SigParam], - return_type: T.nilable(String), + return_type: T.any(String, Type), is_abstract: T::Boolean, is_override: T::Boolean, is_overridable: T::Boolean, @@ -586,7 +586,7 @@ def add_block_param(name) end def add_sig( params: [], - return_type: nil, + return_type: "void", is_abstract: false, is_override: false, is_overridable: false, @@ -928,8 +928,10 @@ def initialize(visibility, loc: nil, comments: []) @visibility = visibility end - sig { params(other: Visibility).returns(T::Boolean) } + sig { params(other: T.nilable(Object)).returns(T::Boolean) } def ==(other) + return false unless other.is_a?(Visibility) + visibility == other.visibility end @@ -1105,7 +1107,7 @@ class Sig < Node sig { returns(T::Array[SigParam]) } attr_reader :params - sig { returns(T.nilable(String)) } + sig { returns(T.any(Type, String)) } attr_accessor :return_type sig { returns(T::Boolean) } @@ -1120,7 +1122,7 @@ class Sig < Node sig do params( params: T::Array[SigParam], - return_type: T.nilable(String), + return_type: T.any(Type, String), is_abstract: T::Boolean, is_override: T::Boolean, is_overridable: T::Boolean, @@ -1133,7 +1135,7 @@ class Sig < Node end def initialize( params: [], - return_type: nil, + return_type: Type.void, is_abstract: false, is_override: false, is_overridable: false, @@ -1160,7 +1162,7 @@ def <<(param) @params << param end - sig { params(name: String, type: String).void } + sig { params(name: String, type: T.any(Type, String)).void } def add_param(name, type) @params << SigParam.new(name, type) end @@ -1169,7 +1171,7 @@ def add_param(name, type) def ==(other) return false unless other.is_a?(Sig) - params == other.params && return_type == other.return_type && is_abstract == other.is_abstract && + params == other.params && return_type.to_s == other.return_type.to_s && is_abstract == other.is_abstract && is_override == other.is_override && is_overridable == other.is_overridable && is_final == other.is_final && type_params == other.type_params && checked == other.checked end @@ -1179,12 +1181,15 @@ class SigParam < NodeWithComments extend T::Sig sig { returns(String) } - attr_reader :name, :type + attr_reader :name + + sig { returns(T.any(Type, String)) } + attr_reader :type sig do params( name: String, - type: String, + type: T.any(Type, String), loc: T.nilable(Loc), comments: T::Array[Comment], block: T.nilable(T.proc.params(node: SigParam).void), @@ -1199,7 +1204,7 @@ def initialize(name, type, loc: nil, comments: [], &block) sig { params(other: Object).returns(T::Boolean) } def ==(other) - other.is_a?(SigParam) && name == other.name && type == other.type + other.is_a?(SigParam) && name == other.name && type.to_s == other.type.to_s end end @@ -1229,7 +1234,10 @@ class TStructField < NodeWithComments abstract! sig { returns(String) } - attr_accessor :name, :type + attr_accessor :name + + sig { returns(T.any(Type, String)) } + attr_accessor :type sig { returns(T.nilable(String)) } attr_accessor :default @@ -1237,7 +1245,7 @@ class TStructField < NodeWithComments sig do params( name: String, - type: String, + type: T.any(Type, String), default: T.nilable(String), loc: T.nilable(Loc), comments: T::Array[Comment], @@ -1260,7 +1268,7 @@ class TStructConst < TStructField sig do params( name: String, - type: String, + type: T.any(Type, String), default: T.nilable(String), loc: T.nilable(Loc), comments: T::Array[Comment], @@ -1290,7 +1298,7 @@ class TStructProp < TStructField sig do params( name: String, - type: String, + type: T.any(Type, String), default: T.nilable(String), loc: T.nilable(Loc), comments: T::Array[Comment], diff --git a/lib/rbi/parser.rb b/lib/rbi/parser.rb index 541ed78e..1d311d40 100644 --- a/lib/rbi/parser.rb +++ b/lib/rbi/parser.rb @@ -839,7 +839,7 @@ def visit_call_node(node) end end when "void" - @current.return_type = nil + @current.return_type = "void" end visit(node.receiver) diff --git a/lib/rbi/printer.rb b/lib/rbi/printer.rb index 21e505cc..b41c78e1 100644 --- a/lib/rbi/printer.rb +++ b/lib/rbi/printer.rb @@ -611,7 +611,7 @@ def print_param_comment_leading_space(node, last:) def print_sig_param_comment_leading_space(node, last:) printn printt - print(" " * (node.name.size + node.type.size + 3)) + print(" " * (node.name.size + node.type.to_s.size + 3)) print(" ") unless last end @@ -654,10 +654,10 @@ def print_sig_as_line(node) print(").") end return_type = node.return_type - if node.return_type && node.return_type != "void" - print("returns(#{return_type})") - else + if node.return_type.to_s == "void" print("void") + else + print("returns(#{return_type})") end printn(" }") end @@ -707,10 +707,10 @@ def print_sig_as_block(node) print(".") if modifiers.any? || params.any? return_type = node.return_type - if return_type && return_type != "void" - print("returns(#{return_type})") - else + if return_type.to_s == "void" print("void") + else + print("returns(#{return_type})") end printn dedent diff --git a/lib/rbi/rewriters/attr_to_methods.rb b/lib/rbi/rewriters/attr_to_methods.rb index 0bf4257a..02ced33e 100644 --- a/lib/rbi/rewriters/attr_to_methods.rb +++ b/lib/rbi/rewriters/attr_to_methods.rb @@ -62,7 +62,7 @@ def convert_to_methods; end private - sig(:final) { returns([T.nilable(Sig), T.nilable(String)]) } + sig(:final) { returns([T.nilable(Sig), T.nilable(T.any(Type, String))]) } def parse_sig raise UnexpectedMultipleSigsError, self if 1 < sigs.count @@ -101,7 +101,7 @@ def create_getter_method(name, sig, visibility, loc, comments) params( name: String, sig: T.nilable(Sig), - attribute_type: T.nilable(String), + attribute_type: T.nilable(T.any(Type, String)), visibility: Visibility, loc: T.nilable(Loc), comments: T::Array[Comment], diff --git a/test/rbi/model_test.rb b/test/rbi/model_test.rb index ff8430f3..9d75752a 100644 --- a/test/rbi/model_test.rb +++ b/test/rbi/model_test.rb @@ -420,5 +420,48 @@ def test_model_nodes_as_strings mod << helper assert_equal("::Foo.foo!", helper.to_s) end + + # types + + def test_model_sig_builder_with_types + rbi = Tree.new do |tree| + tree << Method.new("foo") do |node| + node.add_param("x") + + node.add_sig do |sig| + sig.add_param("x", Type.untyped) + sig.return_type = Type.void + end + end + end + + assert_equal(<<~RBI, rbi.string) + sig { params(x: T.untyped).void } + def foo(x); end + RBI + end + + def test_model_sig_with_types + node = Sig.new + node << SigParam.new("x", Type.untyped) + node.return_type = Type.simple("Integer") + + assert_equal(<<~RBI, node.string) + sig { params(x: T.untyped).returns(Integer) } + RBI + end + + def test_t_struct_with_types + node = TStruct.new("MyStruct") + node << TStructConst.new("foo", Type.simple("Foo")) + node << TStructProp.new("bar", Type.simple("Bar")) + + assert_equal(<<~RBI, node.string) + class MyStruct < ::T::Struct + const :foo, Foo + prop :bar, Bar + end + RBI + end end end