Skip to content

Commit

Permalink
Rewrite attr_* into def methods
Browse files Browse the repository at this point in the history
  • Loading branch information
amomchilov committed Jun 7, 2024
1 parent 556e965 commit d321f86
Show file tree
Hide file tree
Showing 4 changed files with 495 additions and 0 deletions.
1 change: 1 addition & 0 deletions lib/rbi.rb
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
require "rbi/rewriters/nest_non_public_methods"
require "rbi/rewriters/group_nodes"
require "rbi/rewriters/remove_known_definitions"
require "rbi/rewriters/replace_attributes_with_methods"
require "rbi/rewriters/sort_nodes"
require "rbi/parser"
require "rbi/printer"
Expand Down
56 changes: 56 additions & 0 deletions lib/rbi/model.rb
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,16 @@ def replace(node)
self.parent_tree = nil
end

sig { params(nodes: T::Enumerable[Node]).void }
def replace_with_multiple(nodes)
tree = parent_tree
raise unless tree

# Does this work?
nodes.each { |node| tree << node }
detach
end

sig { returns(T.nilable(Scope)) }
def parent_scope
parent = T.let(parent_tree, T.nilable(Tree))
Expand Down Expand Up @@ -1153,6 +1163,43 @@ def initialize(
block&.call(self)
end

sig do
params(
params: T::Array[SigParam],
return_type: T.nilable(String),
is_abstract: T::Boolean,
is_override: T::Boolean,
is_overridable: T::Boolean,
is_final: T::Boolean,
type_params: T::Array[String],
checked: T.nilable(Symbol),
loc: T.nilable(Loc),
).returns(Sig)
end
def new_with(
params: @params,
return_type: @return_type,
is_abstract: @is_abstract,
is_override: @is_override,
is_overridable: @is_overridable,
is_final: @is_final,
type_params: @type_params.dup,
checked: @checked,
loc: @loc.dup
)
Sig.new(
params: params,
return_type: return_type,
is_abstract: is_abstract,
is_override: is_override,
is_overridable: is_overridable,
is_final: is_final,
type_params: type_params,
checked: checked,
loc: loc,
)
end

sig { params(param: SigParam).void }
def <<(param)
@params << param
Expand All @@ -1171,6 +1218,15 @@ def ==(other)
is_override == other.is_override && is_overridable == other.is_overridable && is_final == other.is_final &&
type_params == other.type_params && checked == other.checked
end

sig { override.returns(String) }
def inspect
io = StringIO.new

Printer.new(out: io, indent: 0).visit(self)

io.string.chomp
end
end

class SigParam < NodeWithComments
Expand Down
146 changes: 146 additions & 0 deletions lib/rbi/rewriters/replace_attributes_with_methods.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
# typed: strict
# frozen_string_literal: true

module RBI
module TempHelpers
extend T::Sig
end

module Rewriters
class ReplaceAttributesWithMethods < Visitor
extend T::Sig
include TempHelpers

sig { override.params(node: T.nilable(Node)).void }
def visit(node)
return unless node

case node
when Tree
node.nodes.dup.each do |child|
visit(child)
next unless (attr = child).is_a?(Attr)

new_methods = convert_to_methods(attr)

child.replace_with_multiple(new_methods)
end
end
end

private

sig { params(attr: Attr).returns(T::Array[Method]) }
def convert_to_methods(attr)
sig, attribute_type = parse_sig_of(attr)

case attr
when AttrReader then convert_attr_reader_to_methods(attr, sig, attribute_type)
when AttrWriter then convert_attr_writer_to_methods(attr, sig, attribute_type)
when AttrAccessor then convert_attr_accessor_to_methods(attr, sig, attribute_type)
else raise NotImplementedError, "Unknown attribute type: #{attr.class}"
end
end

sig { params(attr: AttrReader, sig: T.nilable(Sig), attribute_type: T.nilable(String)).returns(T::Array[Method]) }
def convert_attr_reader_to_methods(attr, sig, attribute_type)
attr.names.map do |name|
create_getter_method(name.to_s, sig, attr.visibility, attr.loc, attr.comments)
end
end

sig { params(attr: AttrWriter, sig: T.nilable(Sig), attribute_type: T.nilable(String)).returns(T::Array[Method]) }
def convert_attr_writer_to_methods(attr, sig, attribute_type)
attr.names.map do |name|
create_setter_method(name.to_s, sig, attribute_type, attr.visibility, attr.loc, attr.comments)
end
end

sig do
params(attr: AttrAccessor, sig: T.nilable(Sig), attribute_type: T.nilable(String)).returns(T::Array[Method])
end
def convert_attr_accessor_to_methods(attr, sig, attribute_type)
readers = attr.names.flat_map do |name|
create_getter_method(name.to_s, sig, attr.visibility, attr.loc, attr.comments)
end

writers = attr.names.map do |name|
create_setter_method(name.to_s, sig, attribute_type, attr.visibility, attr.loc, attr.comments)
end

readers + writers
end

sig { params(attr: Attr).returns([T.nilable(Sig), T.nilable(String)]) }
def parse_sig_of(attr)
raise "Attributes cannot have more than 1 sig" if 1 < attr.sigs.count

sig = attr.sigs.first
return [nil, nil] unless sig

attribute_type = case attr
when AttrReader, AttrAccessor then sig.return_type
when AttrWriter then sig.params.first&.type
end

[sig, attribute_type]
end

sig do
params(
name: String,
sig: T.nilable(Sig),
visibility: Visibility,
loc: T.nilable(Loc),
comments: T::Array[Comment],
).returns(Method)
end
def create_getter_method(name, sig, visibility, loc, comments)
Method.new(
name,
params: [],
visibility: visibility,
sigs: sig ? [sig] : [],
loc: loc,
comments: comments,
)
end

sig do
params(
name: String,
sig: T.nilable(Sig),
attribute_type: T.nilable(String),
visibility: Visibility,
loc: T.nilable(Loc),
comments: T::Array[Comment],
).returns(Method)
end
def create_setter_method(name, sig, attribute_type, visibility, loc, comments) # rubocop:disable Metrics/ParameterLists
sig = if sig # Modify the original sig to correct the name, and remove the return type
params = attribute_type ? [SigParam.new(name, attribute_type)] : []
sig.new_with(params: params, return_type: "void")
end

Method.new(
"#{name}=",
params: [ReqParam.new(name)],
visibility: visibility,
sigs: sig ? [sig] : sig,
loc: loc,
comments: comments,
)
end
end
end

class Tree
extend T::Sig

sig { void }
def replace_attributes_with_methods!
visitor = Rewriters::ReplaceAttributesWithMethods.new
visitor.visit(self)
end
end
end
Loading

0 comments on commit d321f86

Please sign in to comment.