diff --git a/packages/pyright-internal/src/analyzer/checker.ts b/packages/pyright-internal/src/analyzer/checker.ts index 9097cfb726b5..d74e323ae578 100644 --- a/packages/pyright-internal/src/analyzer/checker.ts +++ b/packages/pyright-internal/src/analyzer/checker.ts @@ -17,6 +17,7 @@ import { DiagnosticLevel } from '../common/configOptions'; import { assert } from '../common/debug'; import { Diagnostic, DiagnosticAddendum } from '../common/diagnostic'; import { DiagnosticRule } from '../common/diagnosticRules'; +import { getFileExtension } from '../common/pathUtils'; import { PythonVersion, versionToString } from '../common/pythonVersion'; import { TextRange } from '../common/textRange'; import { Localizer } from '../localization/localize'; @@ -3348,7 +3349,8 @@ export class Checker extends ParseTreeWalker { // If a non-protocol class explicitly inherits from a protocol class, this method // verifies that any class or instance variables declared but not assigned - // in the protocol class are implemented in the subclass. + // in the protocol class are implemented in the subclass. It also checks that any + // empty functions declared in the protocol are implemented in the subclass. private _validateProtocolCompatibility(classType: ClassType, errorNode: ClassNode) { if (ClassType.isProtocolClass(classType)) { return; @@ -3366,27 +3368,37 @@ export class Checker extends ParseTreeWalker { protocolSymbols.forEach((member, name) => { const decls = member.symbol.getDeclarations(); - // We care only about variables, not functions or other declaration types. - if (decls.length === 0 || decls[0].type !== DeclarationType.Variable) { + if (decls.length === 0 || !isClass(member.classType)) { return; } - if (!isClass(member.classType)) { - return; - } - - // If none of the declarations involve assignments, assume it's - // not implemented in the protocol. - if (!decls.some((decl) => decl.type === DeclarationType.Variable && !!decl.inferredTypeSource)) { - // This is a variable declaration that is not implemented in the - // protocol base class. Make sure it's implemented in the derived class. - if (!classType.details.fields.has(name)) { - diagAddendum.addMessage( - Localizer.DiagnosticAddendum.missingProtocolMember().format({ - name, - classType: member.classType.details.name, - }) - ); + if (decls[0].type === DeclarationType.Variable) { + // If none of the declarations involve assignments, assume it's + // not implemented in the protocol. + if (!decls.some((decl) => decl.type === DeclarationType.Variable && !!decl.inferredTypeSource)) { + // This is a variable declaration that is not implemented in the + // protocol base class. Make sure it's implemented in the derived class. + if (!classType.details.fields.has(name)) { + diagAddendum.addMessage( + Localizer.DiagnosticAddendum.missingProtocolMember().format({ + name, + classType: member.classType.details.name, + }) + ); + } + } + } else if (decls[0].type === DeclarationType.Function) { + if (ParseTreeUtils.isSuiteEmpty(decls[0].node.suite) && decls[0]) { + if (getFileExtension(decls[0].path).toLowerCase() !== '.pyi') { + if (!classType.details.fields.has(name)) { + diagAddendum.addMessage( + Localizer.DiagnosticAddendum.missingProtocolMember().format({ + name, + classType: member.classType.details.name, + }) + ); + } + } } } }); diff --git a/packages/pyright-internal/src/tests/samples/classes8.py b/packages/pyright-internal/src/tests/samples/classes8.py index 53c0b64591b6..2773d4a65556 100644 --- a/packages/pyright-internal/src/tests/samples/classes8.py +++ b/packages/pyright-internal/src/tests/samples/classes8.py @@ -3,18 +3,34 @@ # it does appear within the stdlib typeshed stubs (see os.scandir). from os import DirEntry -from typing import AnyStr, ContextManager, Iterator +from types import TracebackType +from typing import AnyStr, ContextManager, Iterator, Type +from typing_extensions import Self class _ScandirIterator( Iterator[DirEntry[AnyStr]], ContextManager["_ScandirIterator[AnyStr]"] ): + def __iter__(self) -> Self: + ... + def __next__(self) -> DirEntry[AnyStr]: ... def close(self) -> None: ... + def __enter__(self) -> Self: + ... + + def __exit__( + self, + __exc_type: Type[BaseException] | None, + __exc_value: BaseException | None, + __traceback: TracebackType | None, + ) -> bool | None: + ... + def scandir(path: AnyStr) -> _ScandirIterator[AnyStr]: ... diff --git a/packages/pyright-internal/src/tests/samples/genericTypes10.py b/packages/pyright-internal/src/tests/samples/genericTypes10.py index dda42af67a80..e95b49d15469 100644 --- a/packages/pyright-internal/src/tests/samples/genericTypes10.py +++ b/packages/pyright-internal/src/tests/samples/genericTypes10.py @@ -1,7 +1,7 @@ # This sample tests that a Generic base class overrides the type parameter # ordering of other type parameters. -from typing import Container, Generic, Iterable, Mapping, Protocol, TypeVar +from typing import Container, Generic, Iterable, Iterator, Mapping, Protocol, TypeVar _T1 = TypeVar("_T1") _T2 = TypeVar( @@ -16,6 +16,9 @@ def __init__(self, a: _T1, b: _T2): def foo(self, a: _T1, b: _T2) -> _T2: return b + def __iter__(self) -> Iterator[int]: + ... + a: Foo[int, str] = Foo(2, "") b: str = a.foo(4, "") diff --git a/packages/pyright-internal/src/tests/samples/genericTypes11.py b/packages/pyright-internal/src/tests/samples/genericTypes11.py index d1f35214d527..7f45d027bfd0 100644 --- a/packages/pyright-internal/src/tests/samples/genericTypes11.py +++ b/packages/pyright-internal/src/tests/samples/genericTypes11.py @@ -4,12 +4,12 @@ # pyright: strict -from typing import Callable, Iterator, TypeVar +from typing import Callable, Iterator, Protocol, TypeVar -_T = TypeVar("_T") +_T = TypeVar("_T", covariant=True) -class Foo(Iterator[_T]): +class Foo(Iterator[_T], Protocol): pass diff --git a/packages/pyright-internal/src/tests/samples/paramSpec24.py b/packages/pyright-internal/src/tests/samples/paramSpec24.py index 4725a0f2997c..938b8fd50b97 100644 --- a/packages/pyright-internal/src/tests/samples/paramSpec24.py +++ b/packages/pyright-internal/src/tests/samples/paramSpec24.py @@ -6,16 +6,20 @@ from typing_extensions import Self, Concatenate, ParamSpec from typing import Any, Callable, TypeVar, Protocol, Generic, overload -T = TypeVar("T", covariant=True) +T = TypeVar("T") O = TypeVar("O") P = ParamSpec("P") class _callable_cache(Protocol[P, T]): foo: int = 0 + val: T + + def __init__(self, val: T) -> None: + self.val = val def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T: - ... + return self.val class _wrapped_cache(_callable_cache[P, T], Generic[O, P, T]): diff --git a/packages/pyright-internal/src/tests/samples/protocol27.py b/packages/pyright-internal/src/tests/samples/protocol27.py index de743d2c36a9..39bfdf6408a2 100644 --- a/packages/pyright-internal/src/tests/samples/protocol27.py +++ b/packages/pyright-internal/src/tests/samples/protocol27.py @@ -1,6 +1,6 @@ # This sample tests the logic that validates that a concrete class that # explicitly derives from a protocol class implements the variables -# defined in the protocol. +# and functions defined in the protocol. from typing import ClassVar, Protocol @@ -51,3 +51,14 @@ def __init__(self): self.im1 = 3 self.im10 = 10 self.cm11 = 3 + + +class Protocol5(Protocol): + def foo(self) -> int: + ... + + +# This should generate an error because "foo" is +# not implemented. +class Concrete5(Protocol5): + pass diff --git a/packages/pyright-internal/src/tests/typeEvaluator2.test.ts b/packages/pyright-internal/src/tests/typeEvaluator2.test.ts index d3911f86b94a..df0a0b3f20cb 100644 --- a/packages/pyright-internal/src/tests/typeEvaluator2.test.ts +++ b/packages/pyright-internal/src/tests/typeEvaluator2.test.ts @@ -1009,7 +1009,7 @@ test('Protocol26', () => { test('Protocol27', () => { const analysisResults = TestUtils.typeAnalyzeSampleFiles(['protocol27.py']); - TestUtils.validateResults(analysisResults, 2); + TestUtils.validateResults(analysisResults, 3); }); test('TypedDict1', () => {