Skip to content

Commit

Permalink
Add generic type alias type parameter variance annotation
Browse files Browse the repository at this point in the history
  • Loading branch information
soutaro committed Nov 11, 2021
1 parent f5d0c11 commit 62285a3
Show file tree
Hide file tree
Showing 7 changed files with 199 additions and 27 deletions.
25 changes: 22 additions & 3 deletions lib/rbs/validator.rb
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@ module RBS
class Validator
attr_reader :env
attr_reader :resolver
attr_reader :definition_builder

def initialize(env:, resolver:)
@env = env
@resolver = resolver
@definition_builder = DefinitionBuilder.new(env: env)
end

def absolute_type(type, context:)
Expand Down Expand Up @@ -57,15 +59,32 @@ def validate_type(type, context:)
end

def validate_type_alias(entry:)
if type_alias_dependency.circular_definition?(entry.decl.name)
type_name = entry.decl.name

if type_alias_dependency.circular_definition?(type_name)
location = entry.decl.location or raise
raise RecursiveTypeAliasError.new(alias_names: [entry.decl.name], location: location)
raise RecursiveTypeAliasError.new(alias_names: [type_name], location: location)
end

if diagnostic = type_alias_regularity.nonregular?(entry.decl.name)
if diagnostic = type_alias_regularity.nonregular?(type_name)
location = entry.decl.location or raise
raise NonregularTypeAliasError.new(diagnostic: diagnostic, location: location)
end

unless entry.decl.type_params.empty?
calculator = VarianceCalculator.new(builder: definition_builder)
result = calculator.in_type_alias(name: type_name)
if set = result.incompatible?(entry.decl.type_params)
set.each do |param_name|
param = entry.decl.type_params[param_name] or raise
raise InvalidVarianceAnnotationError.new(
type_name: type_name,
param: param,
location: entry.decl.type.location
)
end
end
end
end

def type_alias_dependency
Expand Down
76 changes: 52 additions & 24 deletions lib/rbs/variance_calculator.rb
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,21 @@ def compatible?(var, with_annotation:)
false
end
end

def incompatible?(params)
# @type set: Hash[Symbol]
set = Set[]

params.each do |param|
unless compatible?(param.name, with_annotation: param.variance)
set << param.name
end
end

unless set.empty?
set
end
end
end

attr_reader :builder
Expand All @@ -69,19 +84,12 @@ def env
def in_method_type(method_type:, variables:)
result = Result.new(variables: variables)

method_type.type.each_param do |param|
type(param.type, result: result, context: :contravariant)
end
function(method_type.type, result: result, context: :covariant)

if block = method_type.block
block.type.each_param do |param|
type(param.type, result: result, context: :covariant)
end
type(block.type.return_type, result: result, context: :contravariant)
function(block.type, result: result, context: :contravariant)
end

type(method_type.type.return_type, result: result, context: :covariant)

result
end

Expand All @@ -97,6 +105,14 @@ def in_inherit(name:, args:, variables:)
end
end

def in_type_alias(name:)
decl = env.alias_decls[name].decl or raise
variables = decl.type_params.each.map(&:name)
Result.new(variables: variables).tap do |result|
type(decl.type, result: result, context: :covariant)
end
end

def type(type, result:, context:)
case type
when Types::Variable
Expand All @@ -110,7 +126,7 @@ def type(type, result:, context:)
result.invariant(type.name)
end
end
when Types::ClassInstance, Types::Interface
when Types::ClassInstance, Types::Interface, Types::Alias
NoTypeFoundError.check!(type.name,
env: env,
location: type.location)
Expand All @@ -120,6 +136,8 @@ def type(type, result:, context:)
env.class_decls[type.name].type_params
when Types::Interface
env.interface_decls[type.name].decl.type_params
when Types::Alias
env.alias_decls[type.name].decl.type_params
end

type.args.each.with_index do |ty, i|
Expand All @@ -130,26 +148,36 @@ def type(type, result:, context:)
when :covariant
type(ty, result: result, context: context)
when :contravariant
# @type var con: variance
con = case context
when :invariant
:invariant
when :covariant
:contravariant
when :contravariant
:covariant
else
raise
end
type(ty, result: result, context: con)
type(ty, result: result, context: negate(context))
end
end
when Types::Tuple, Types::Record, Types::Union, Types::Intersection
# Covariant types
when Types::Proc
function(type.type, result: result, context: context)
else
type.each_type do |ty|
type(ty, result: result, context: context)
end
end
end

def function(type, result:, context:)
type.each_param do |param|
type(param.type, result: result, context: negate(context))
end
type(type.return_type, result: result, context: context)
end

def negate(variance)
case variance
when :invariant
:invariant
when :covariant
:contravariant
when :contravariant
:covariant
else
raise
end
end
end
end
2 changes: 2 additions & 0 deletions sig/validator.rbs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ module RBS

attr_reader resolver: TypeNameResolver

attr_reader definition_builder: DefinitionBuilder

attr_reader type_alias_dependency: TypeAliasDependency

attr_reader type_alias_regularity: TypeAliasRegularity
Expand Down
50 changes: 50 additions & 0 deletions sig/variance_calculator.rbs
Original file line number Diff line number Diff line change
@@ -1,7 +1,47 @@
module RBS
# Calculate the use variances of type variables in declaration.
#
# ```rb
# calculator = VarianceCalculator.new(builder: builder)
#
# # Calculates variances in a method type
# result = calculator.in_method_type(method_type: method_type, variables: variables)
#
# # Calculates variances in a inheritance/mixin/...
# result = calculator.in_inherit(name: name, args: args, variables: variables)
#
# # Calculates variances in a type alias
# result = calculator.in_type_alias(name: name, args: args, variables: variables)
# ```
#
# See `RBS::VarianceCaluculator::Result` for information recorded in the `Result` object.
#
class VarianceCalculator
type variance = :unused | :covariant | :contravariant | :invariant

# Result contains the set of type variables and it's variance in a occurrence.
#
# ```rb
# # Enumerates recorded type variables
# result.each do |name, variance|
# # name is the name of a type variable
# # variance is one of :unused | :covariant | :contravariant | :invariant
# end
# ```
#
# You can test with `compatible?` method if the type variable occurrences are compatible with specified (annotated) variance.
#
# ```rb
# # When T is declared as `out T`
# result.compatible?(:T, with_annotation: :covariant)
#
# # When T is declared as `in T`
# result.compatible?(:T, with_annotation: :contravariant)
#
# # When T is declared as `T`
# result.compatible?(:T, with_annotation: :invariant)
# ```
#
class Result
attr_reader result: Hash[Symbol, variance]

Expand All @@ -18,6 +58,8 @@ module RBS
def include?: (Symbol) -> bool

def compatible?: (Symbol, with_annotation: variance) -> bool

def incompatible?: (AST::Declarations::ModuleTypeParams) -> Set[Symbol]?
end

attr_reader builder: DefinitionBuilder
Expand All @@ -30,6 +72,14 @@ module RBS

def in_inherit: (name: TypeName, args: Array[Types::t], variables: Array[Symbol]) -> Result

def in_type_alias: (name: TypeName) -> Result

private

def type: (Types::t, result: Result, context: variance) -> void

def function: (Types::Function, result: Result, context: variance) -> void

def negate: (variance) -> variance
end
end
30 changes: 30 additions & 0 deletions test/rbs/signature_parsing_test.rb
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,36 @@ class Foo[A]
end
end

def test_type_alias_generic_variance
Parser.parse_signature(<<RBS).yield_self do |decls|
type x[T] = ^(T) -> void
type y[unchecked out T] = ^(T) -> void
RBS
assert_equal 2, decls.size

decls[0].tap do |type_decl|
assert_instance_of Declarations::Alias, type_decl

type_decl.type_params.params[0].tap do |param|
assert_equal :T, param.name
assert_equal :invariant, param.variance
refute_predicate param, :skip_validation
end
end

decls[1].tap do |type_decl|
assert_instance_of Declarations::Alias, type_decl

type_decl.type_params.params[0].tap do |param|
assert_equal :T, param.name
assert_equal :covariant, param.variance
assert_predicate param, :skip_validation
end
end
end
end

def test_constant
Parser.parse_signature("FOO: untyped").yield_self do |decls|
assert_equal 1, decls.size
Expand Down
37 changes: 37 additions & 0 deletions test/rbs/variance_calculator_test.rb
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,43 @@ module Bar[out X, in Y, Z]
end
end

def test_alias_generics
SignatureManager.new do |manager|
manager.files[Pathname("foo.rbs")] = <<EOF
type a[T] = T
type b[T, S] = ^(T) -> S
type c[T, S] = Foo[T, S]
type d[T] = Foo[T, T]
class Foo[in T, out S]
end
EOF
manager.build do |env|
builder = DefinitionBuilder.new(env: env)
calculator = VarianceCalculator.new(builder: builder)

calculator.in_type_alias(name: TypeName("::a")).tap do |result|
assert_equal({ T: :covariant }, result.result)
end

calculator.in_type_alias(name: TypeName("::b")).tap do |result|
assert_equal({ T: :contravariant, S: :covariant }, result.result)
end

calculator.in_type_alias(name: TypeName("::c")).tap do |result|
assert_equal({ T: :contravariant, S: :covariant }, result.result)
end

calculator.in_type_alias(name: TypeName("::d")).tap do |result|
assert_equal({ T: :invariant }, result.result)
end
end
end
end

def test_result
result = VarianceCalculator::Result.new(variables: [:A, :B, :C])
result.covariant(:A)
Expand Down
6 changes: 6 additions & 0 deletions test/validator_test.rb
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,8 @@ def test_generic_type_aliases
type foo[T] = [T, foo[T]]
type bar[T] = [bar[T?]]
type baz[out T] = ^(T) -> void
EOF

manager.build do |env|
Expand All @@ -133,6 +135,10 @@ def test_generic_type_aliases
assert_raises RBS::NonregularTypeAliasError do
validator.validate_type_alias(entry: env.alias_decls[type_name("::bar")])
end

assert_raises RBS::InvalidVarianceAnnotationError do
validator.validate_type_alias(entry: env.alias_decls[type_name("::baz")])
end
end
end
end
Expand Down

0 comments on commit 62285a3

Please sign in to comment.