Skip to content

Commit

Permalink
use new GetTypesByMetadataName in ImmutabilityContext.Factory
Browse files Browse the repository at this point in the history
Closes: #851
  • Loading branch information
omsmith committed Jul 4, 2022
1 parent 0765985 commit 91e04d1
Showing 1 changed file with 20 additions and 32 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -98,13 +98,11 @@ internal static ImmutabilityContext Create(
additionalImmutableTypes = ImmutableHashSet<string>.Empty;
}

ImmutableDictionary<string, IAssemblySymbol> compilationAssemblies = GetCompilationAssemblies( compilation );

// Generate a dictionary of types that we have specifically determined
// should be considered Immutable by the Analyzer.
var extraImmutableTypesBuilder = ImmutableDictionary.CreateBuilder<INamedTypeSymbol, ImmutableTypeInfo>( SymbolEqualityComparer.Default );
foreach( ( string typeName, string qualifiedAssembly ) in DefaultExtraTypes ) {
INamedTypeSymbol type = GetTypeSymbol( compilationAssemblies, compilation, qualifiedAssembly, typeName );
INamedTypeSymbol type = GetTypeSymbol( compilation, qualifiedAssembly, typeName );

if( type == null ) {
continue;
Expand All @@ -119,7 +117,7 @@ internal static ImmutabilityContext Create(
}

foreach( string typeName in additionalImmutableTypes ) {
INamedTypeSymbol type = GetTypeSymbol( compilationAssemblies, compilation, qualifiedAssembly: default, typeName );
INamedTypeSymbol type = GetTypeSymbol( compilation, qualifiedAssembly: default, typeName );

if( type == null ) {
continue;
Expand All @@ -141,7 +139,7 @@ internal static ImmutabilityContext Create(
// have a return value which should be considered Immutable by the Analyzer.
var knownImmutableReturnsBuilder = ImmutableHashSet.CreateBuilder<IMethodSymbol>( SymbolEqualityComparer.Default );
foreach( ( string typeName, string methodName, string qualifiedAssembly ) in KnownImmutableReturningMethods ) {
INamedTypeSymbol type = GetTypeSymbol( compilationAssemblies, compilation, qualifiedAssembly, typeName );
INamedTypeSymbol type = GetTypeSymbol( compilation, qualifiedAssembly, typeName );

if( type == null ) {
continue;
Expand All @@ -168,45 +166,35 @@ internal static ImmutabilityContext Create(
);
}

private static ImmutableDictionary<string, IAssemblySymbol> GetCompilationAssemblies( Compilation compilation ) {
var builder = ImmutableDictionary.CreateBuilder<string, IAssemblySymbol>();

IAssemblySymbol compilationAssmebly = compilation.Assembly;

builder.Add( compilationAssmebly.Name, compilationAssmebly );

foreach( IModuleSymbol module in compilationAssmebly.Modules ) {
foreach( IAssemblySymbol assembly in module.ReferencedAssemblySymbols ) {
builder.Add( assembly.Name, assembly );
}
}

return builder.ToImmutable();
}

private static INamedTypeSymbol GetTypeSymbol(
ImmutableDictionary<string, IAssemblySymbol> compilationAssemblies,
Compilation compilation,
string qualifiedAssembly,
string typeName
) {
INamedTypeSymbol type;

if( string.IsNullOrEmpty( qualifiedAssembly ) ) {
type = compilation.GetTypeByMetadataName( typeName );
} else {
if( !compilationAssemblies.TryGetValue( qualifiedAssembly, out IAssemblySymbol assembly ) ) {
return null;
ImmutableArray<INamedTypeSymbol> types = compilation.GetTypesByMetadataName( typeName );

if( types.IsEmpty ) {
return null;
}

if( qualifiedAssembly == default ) {
if( types.Length > 1 ) {
throw new InvalidOperationException(
$"Found multiple {typeName} with no {nameof( qualifiedAssembly )} specified when building ImmutabilityContext."
);
}

type = assembly.GetTypeByMetadataName( typeName );
return types[ 0 ];
}

if( type == null || type.Kind == SymbolKind.ErrorType ) {
return null;
foreach( INamedTypeSymbol type in types ) {
if( type.ContainingAssembly.Name.Equals( qualifiedAssembly, StringComparison.Ordinal ) ) {
return type;
}
}

return type;
return null;
}

}
Expand Down

0 comments on commit 91e04d1

Please sign in to comment.