From 2f5046b54523e9b04707da25d00118700953ebd7 Mon Sep 17 00:00:00 2001 From: Eric Traut Date: Tue, 15 Mar 2022 16:38:40 -0600 Subject: [PATCH] Added support for the use of `Concatenate` as a type argument for a generic type alias that accepts a ParamSpec. --- .../src/analyzer/typeEvaluator.ts | 42 ++++++++++++++++--- .../src/tests/samples/paramSpec13.py | 40 +++++++++++++++++- .../src/tests/typeEvaluator4.test.ts | 2 +- 3 files changed, 77 insertions(+), 7 deletions(-) diff --git a/packages/pyright-internal/src/analyzer/typeEvaluator.ts b/packages/pyright-internal/src/analyzer/typeEvaluator.ts index cbf876ca8c0e..e52b70128e28 100644 --- a/packages/pyright-internal/src/analyzer/typeEvaluator.ts +++ b/packages/pyright-internal/src/analyzer/typeEvaluator.ts @@ -5776,6 +5776,8 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions const diag = new DiagnosticAddendum(); typeParameters.forEach((param, index) => { if (param.details.isParamSpec && index < typeArgs.length) { + const typeArgType = typeArgs[index].type; + if (typeArgs[index].typeList) { const functionType = FunctionType.createInstantiable('', '', '', FunctionTypeFlags.ParamSpecValue); TypeBase.setSpecialForm(functionType); @@ -5790,9 +5792,39 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions }); canAssignTypeToTypeVar(param, functionType, diag, typeVarMap); - } else if (isParamSpec(typeArgs[index].type)) { - canAssignTypeToTypeVar(param, convertToInstance(typeArgs[index].type), diag, typeVarMap); - } else if (isEllipsisType(typeArgs[index].type)) { + } else if (isParamSpec(typeArgType)) { + canAssignTypeToTypeVar(param, convertToInstance(typeArgType), diag, typeVarMap); + } else if (isInstantiableClass(typeArgType) && ClassType.isBuiltIn(typeArgType, 'Concatenate')) { + const concatTypeArgs = typeArgType.typeArguments; + const functionType = FunctionType.createInstance('', '', '', FunctionTypeFlags.None); + + if (concatTypeArgs && concatTypeArgs.length > 0) { + concatTypeArgs.forEach((typeArg, index) => { + if (index === concatTypeArgs.length - 1) { + // Add a position-only separator + FunctionType.addParameter(functionType, { + category: ParameterCategory.Simple, + isNameSynthesized: false, + type: UnknownType.create(), + }); + + if (isParamSpec(typeArg)) { + functionType.details.paramSpec = typeArg; + } + } else { + FunctionType.addParameter(functionType, { + category: ParameterCategory.Simple, + name: `__p${index}`, + isNameSynthesized: true, + hasDeclaredType: true, + type: typeArg, + }); + } + }); + } + + canAssignTypeToTypeVar(param, functionType, diag, typeVarMap); + } else if (isEllipsisType(typeArgType)) { const functionType = FunctionType.createInstantiable( '', '', @@ -20506,7 +20538,7 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions if (existingEntry) { // Verify that the existing entry matches the new entry. if ( - !existingEntry.paramSpec && + existingEntry.paramSpec === srcType.details.paramSpec && existingEntry.parameters.length === parameters.length && !existingEntry.parameters.some((existingParam, index) => { const newParam = parameters[index]; @@ -20533,7 +20565,7 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions typeVarScopeId: srcType.details.typeVarScopeId, flags: srcType.details.flags, docString: srcType.details.docString, - paramSpec: undefined, + paramSpec: srcType.details.paramSpec, }); } return true; diff --git a/packages/pyright-internal/src/tests/samples/paramSpec13.py b/packages/pyright-internal/src/tests/samples/paramSpec13.py index 60137aab6b21..e9647b7c9ccf 100644 --- a/packages/pyright-internal/src/tests/samples/paramSpec13.py +++ b/packages/pyright-internal/src/tests/samples/paramSpec13.py @@ -1,7 +1,17 @@ # This sample tests cases where a ParamSpec is used as a type parameter # for a generic type alias, a generic function, and a generic class. -from typing import Callable, Concatenate, Generic, List, ParamSpec, TypeVar +import asyncio +from typing import ( + Any, + Callable, + Concatenate, + Coroutine, + Generic, + List, + ParamSpec, + TypeVar, +) _P = ParamSpec("_P") @@ -72,3 +82,31 @@ def remote(func: Callable[_P, _R]) -> RemoteFunction[_P, _R]: v4 = remote(func2) reveal_type(v4, expected_text="RemoteFunction[(a: str, b: List[int]), str]") + + +Coro = Coroutine[Any, Any, _T] +CoroFunc = Callable[_P, Coro[_T]] + + +class ClassA: + ... + + +CheckFunc = CoroFunc[Concatenate[ClassA, _P], bool] + + +async def my_check_func(obj: ClassA, a: int, b: str) -> bool: + print(a, b) + return str(a) == b + + +async def takes_check_func( + check_func: CheckFunc[_P], *args: _P.args, **kwargs: _P.kwargs +): + await check_func(ClassA(), *args, **kwargs) + + +asyncio.run(takes_check_func(my_check_func, 1, "2")) + +# This should generate an error because the signature doesn't match. +asyncio.run(takes_check_func(my_check_func, 1, 2)) diff --git a/packages/pyright-internal/src/tests/typeEvaluator4.test.ts b/packages/pyright-internal/src/tests/typeEvaluator4.test.ts index bb092b9a7851..48aed51eac2f 100644 --- a/packages/pyright-internal/src/tests/typeEvaluator4.test.ts +++ b/packages/pyright-internal/src/tests/typeEvaluator4.test.ts @@ -775,7 +775,7 @@ test('ParamSpec13', () => { configOptions.defaultPythonVersion = PythonVersion.V3_10; const results = TestUtils.typeAnalyzeSampleFiles(['paramSpec13.py'], configOptions); - TestUtils.validateResults(results, 5); + TestUtils.validateResults(results, 6); }); test('ParamSpec14', () => {