Skip to content

Commit

Permalink
Added check for a class that derives from a protocol class where the …
Browse files Browse the repository at this point in the history
…protocol declares a method with an empty implementation, and the subclass doesn't provide a concrete implementation of the same-named method.
  • Loading branch information
msfterictraut committed Jan 21, 2022
1 parent a45b8dc commit 44c98fe
Show file tree
Hide file tree
Showing 7 changed files with 74 additions and 28 deletions.
50 changes: 31 additions & 19 deletions packages/pyright-internal/src/analyzer/checker.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import { DiagnosticLevel } from '../common/configOptions';
import { assert } from '../common/debug';
import { Diagnostic, DiagnosticAddendum } from '../common/diagnostic';
import { DiagnosticRule } from '../common/diagnosticRules';
import { getFileExtension } from '../common/pathUtils';
import { PythonVersion, versionToString } from '../common/pythonVersion';
import { TextRange } from '../common/textRange';
import { Localizer } from '../localization/localize';
Expand Down Expand Up @@ -3348,7 +3349,8 @@ export class Checker extends ParseTreeWalker {

// If a non-protocol class explicitly inherits from a protocol class, this method
// verifies that any class or instance variables declared but not assigned
// in the protocol class are implemented in the subclass.
// in the protocol class are implemented in the subclass. It also checks that any
// empty functions declared in the protocol are implemented in the subclass.
private _validateProtocolCompatibility(classType: ClassType, errorNode: ClassNode) {
if (ClassType.isProtocolClass(classType)) {
return;
Expand All @@ -3366,27 +3368,37 @@ export class Checker extends ParseTreeWalker {
protocolSymbols.forEach((member, name) => {
const decls = member.symbol.getDeclarations();

// We care only about variables, not functions or other declaration types.
if (decls.length === 0 || decls[0].type !== DeclarationType.Variable) {
if (decls.length === 0 || !isClass(member.classType)) {
return;
}

if (!isClass(member.classType)) {
return;
}

// If none of the declarations involve assignments, assume it's
// not implemented in the protocol.
if (!decls.some((decl) => decl.type === DeclarationType.Variable && !!decl.inferredTypeSource)) {
// This is a variable declaration that is not implemented in the
// protocol base class. Make sure it's implemented in the derived class.
if (!classType.details.fields.has(name)) {
diagAddendum.addMessage(
Localizer.DiagnosticAddendum.missingProtocolMember().format({
name,
classType: member.classType.details.name,
})
);
if (decls[0].type === DeclarationType.Variable) {
// If none of the declarations involve assignments, assume it's
// not implemented in the protocol.
if (!decls.some((decl) => decl.type === DeclarationType.Variable && !!decl.inferredTypeSource)) {
// This is a variable declaration that is not implemented in the
// protocol base class. Make sure it's implemented in the derived class.
if (!classType.details.fields.has(name)) {
diagAddendum.addMessage(
Localizer.DiagnosticAddendum.missingProtocolMember().format({
name,
classType: member.classType.details.name,
})
);
}
}
} else if (decls[0].type === DeclarationType.Function) {
if (ParseTreeUtils.isSuiteEmpty(decls[0].node.suite) && decls[0]) {
if (getFileExtension(decls[0].path).toLowerCase() !== '.pyi') {
if (!classType.details.fields.has(name)) {
diagAddendum.addMessage(
Localizer.DiagnosticAddendum.missingProtocolMember().format({
name,
classType: member.classType.details.name,
})
);
}
}
}
}
});
Expand Down
18 changes: 17 additions & 1 deletion packages/pyright-internal/src/tests/samples/classes8.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,34 @@
# it does appear within the stdlib typeshed stubs (see os.scandir).

from os import DirEntry
from typing import AnyStr, ContextManager, Iterator
from types import TracebackType
from typing import AnyStr, ContextManager, Iterator, Type
from typing_extensions import Self


class _ScandirIterator(
Iterator[DirEntry[AnyStr]], ContextManager["_ScandirIterator[AnyStr]"]
):
def __iter__(self) -> Self:
...

def __next__(self) -> DirEntry[AnyStr]:
...

def close(self) -> None:
...

def __enter__(self) -> Self:
...

def __exit__(
self,
__exc_type: Type[BaseException] | None,
__exc_value: BaseException | None,
__traceback: TracebackType | None,
) -> bool | None:
...


def scandir(path: AnyStr) -> _ScandirIterator[AnyStr]:
...
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# This sample tests that a Generic base class overrides the type parameter
# ordering of other type parameters.

from typing import Container, Generic, Iterable, Mapping, Protocol, TypeVar
from typing import Container, Generic, Iterable, Iterator, Mapping, Protocol, TypeVar

_T1 = TypeVar("_T1")
_T2 = TypeVar(
Expand All @@ -16,6 +16,9 @@ def __init__(self, a: _T1, b: _T2):
def foo(self, a: _T1, b: _T2) -> _T2:
return b

def __iter__(self) -> Iterator[int]:
...


a: Foo[int, str] = Foo(2, "")
b: str = a.foo(4, "")
Expand Down
6 changes: 3 additions & 3 deletions packages/pyright-internal/src/tests/samples/genericTypes11.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@

# pyright: strict

from typing import Callable, Iterator, TypeVar
from typing import Callable, Iterator, Protocol, TypeVar

_T = TypeVar("_T")
_T = TypeVar("_T", covariant=True)


class Foo(Iterator[_T]):
class Foo(Iterator[_T], Protocol):
pass


Expand Down
8 changes: 6 additions & 2 deletions packages/pyright-internal/src/tests/samples/paramSpec24.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,20 @@
from typing_extensions import Self, Concatenate, ParamSpec
from typing import Any, Callable, TypeVar, Protocol, Generic, overload

T = TypeVar("T", covariant=True)
T = TypeVar("T")
O = TypeVar("O")
P = ParamSpec("P")


class _callable_cache(Protocol[P, T]):
foo: int = 0
val: T

def __init__(self, val: T) -> None:
self.val = val

def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T:
...
return self.val


class _wrapped_cache(_callable_cache[P, T], Generic[O, P, T]):
Expand Down
13 changes: 12 additions & 1 deletion packages/pyright-internal/src/tests/samples/protocol27.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# This sample tests the logic that validates that a concrete class that
# explicitly derives from a protocol class implements the variables
# defined in the protocol.
# and functions defined in the protocol.

from typing import ClassVar, Protocol

Expand Down Expand Up @@ -51,3 +51,14 @@ def __init__(self):
self.im1 = 3
self.im10 = 10
self.cm11 = 3


class Protocol5(Protocol):
def foo(self) -> int:
...


# This should generate an error because "foo" is
# not implemented.
class Concrete5(Protocol5):
pass
2 changes: 1 addition & 1 deletion packages/pyright-internal/src/tests/typeEvaluator2.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1009,7 +1009,7 @@ test('Protocol26', () => {
test('Protocol27', () => {
const analysisResults = TestUtils.typeAnalyzeSampleFiles(['protocol27.py']);

TestUtils.validateResults(analysisResults, 2);
TestUtils.validateResults(analysisResults, 3);
});

test('TypedDict1', () => {
Expand Down

0 comments on commit 44c98fe

Please sign in to comment.