Skip to content

Commit

Permalink
Rename Py.Union to Py.UnionType
Browse files Browse the repository at this point in the history
This is more consistent with the naming we have in `J` and also avoids conflict with `typing.Union`.
  • Loading branch information
knutwannheden committed Oct 16, 2024
1 parent d99cb39 commit 4e2d5da
Show file tree
Hide file tree
Showing 14 changed files with 79 additions and 80 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -284,13 +284,13 @@ public Py.YieldFrom visitYieldFrom(Py.YieldFrom yieldFrom, ReceiverContext ctx)
}

@Override
public Py.Union visitUnion(Py.Union union, ReceiverContext ctx) {
union = union.withId(ctx.receiveNonNullValue(union.getId(), UUID.class));
union = union.withPrefix(ctx.receiveNonNullNode(union.getPrefix(), PythonReceiver::receiveSpace));
union = union.withMarkers(ctx.receiveNonNullNode(union.getMarkers(), ctx::receiveMarkers));
union = union.getPadding().withTypes(ctx.receiveNonNullNodes(union.getPadding().getTypes(), PythonReceiver::receiveRightPaddedTree));
union = union.withType(ctx.receiveValue(union.getType(), JavaType.class));
return union;
public Py.UnionType visitUnionType(Py.UnionType unionType, ReceiverContext ctx) {
unionType = unionType.withId(ctx.receiveNonNullValue(unionType.getId(), UUID.class));
unionType = unionType.withPrefix(ctx.receiveNonNullNode(unionType.getPrefix(), PythonReceiver::receiveSpace));
unionType = unionType.withMarkers(ctx.receiveNonNullNode(unionType.getMarkers(), ctx::receiveMarkers));
unionType = unionType.getPadding().withTypes(ctx.receiveNonNullNodes(unionType.getPadding().getTypes(), PythonReceiver::receiveRightPaddedTree));
unionType = unionType.withType(ctx.receiveValue(unionType.getType(), JavaType.class));
return unionType;
}

@Override
Expand Down Expand Up @@ -1310,8 +1310,8 @@ public <T> T create(Class<T> type, ReceiverContext ctx) {
);
}

if (type == Py.Union.class) {
return (T) new Py.Union(
if (type == Py.UnionType.class) {
return (T) new Py.UnionType(
ctx.receiveNonNullValue(null, UUID.class),
ctx.receiveNonNullNode(null, PythonReceiver::receiveSpace),
ctx.receiveNonNullNode(null, ctx::receiveMarkers),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -267,13 +267,13 @@ public Py.YieldFrom visitYieldFrom(Py.YieldFrom yieldFrom, SenderContext ctx) {
}

@Override
public Py.Union visitUnion(Py.Union union, SenderContext ctx) {
ctx.sendValue(union, Py.Union::getId);
ctx.sendNode(union, Py.Union::getPrefix, PythonSender::sendSpace);
ctx.sendNode(union, Py.Union::getMarkers, ctx::sendMarkers);
ctx.sendNodes(union, e -> e.getPadding().getTypes(), PythonSender::sendRightPadded, e -> e.getElement().getId());
ctx.sendTypedValue(union, Py.Union::getType);
return union;
public Py.UnionType visitUnionType(Py.UnionType unionType, SenderContext ctx) {
ctx.sendValue(unionType, Py.UnionType::getId);
ctx.sendNode(unionType, Py.UnionType::getPrefix, PythonSender::sendSpace);
ctx.sendNode(unionType, Py.UnionType::getMarkers, ctx::sendMarkers);
ctx.sendNodes(unionType, e -> e.getPadding().getTypes(), PythonSender::sendRightPadded, e -> e.getElement().getId());
ctx.sendTypedValue(unionType, Py.UnionType::getType);
return unionType;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -158,8 +158,8 @@ public Py.TypeHintedExpression visitTypeHintedExpression(Py.TypeHintedExpression
}

@Override
public Py.Union visitUnion(Py.Union union, P p) {
return (Py.Union) super.visitUnion(union, p);
public Py.UnionType visitUnionType(Py.UnionType unionType, P p) {
return (Py.UnionType) super.visitUnionType(unionType, p);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -393,17 +393,17 @@ public J visitTrailingElseWrapper(Py.TrailingElseWrapper ogWrapper, P p) {
return wrapper;
}

public J visitUnion(Py.Union union, P p) {
Py.Union u = union;
u = u.withPrefix(visitSpace(u.getPrefix(), PySpace.Location.UNION_PREFIX, p));
public J visitUnionType(Py.UnionType unionType, P p) {
Py.UnionType u = unionType;
u = u.withPrefix(visitSpace(u.getPrefix(), PySpace.Location.UNION_TYPE_PREFIX, p));
u = u.withMarkers(visitMarkers(u.getMarkers(), p));
Expression temp = (Expression) visitExpression(u, p);
if (!(temp instanceof Py.Union)) {
if (!(temp instanceof Py.UnionType)) {
return temp;
} else {
u = (Py.Union) temp;
u = (Py.UnionType) temp;
}
u = u.getPadding().withTypes(ListUtils.map(u.getPadding().getTypes(), e -> visitRightPadded(e, PyRightPadded.Location.UNION_TYPE, p)));
u = u.getPadding().withTypes(ListUtils.map(u.getPadding().getTypes(), e -> visitRightPadded(e, PyRightPadded.Location.UNION_TYPE_TYPE, p)));
return u;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -640,11 +640,11 @@ public J visitTypeHintedExpression(Py.TypeHintedExpression expr, PrintOutputCapt
}

@Override
public J visitUnion(Py.Union union, PrintOutputCapture<P> p) {
beforeSyntax(union, PySpace.Location.UNION_PREFIX, p);
visitRightPadded(union.getPadding().getTypes(), PyRightPadded.Location.UNION_TYPE, "|", p);
afterSyntax(union, p);
return union;
public J visitUnionType(Py.UnionType unionType, PrintOutputCapture<P> p) {
beforeSyntax(unionType, PySpace.Location.UNION_TYPE_PREFIX, p);
visitRightPadded(unionType.getPadding().getTypes(), PyRightPadded.Location.UNION_TYPE_TYPE, "|", p);
afterSyntax(unionType, p);
return unionType;
}

@SuppressWarnings("SameParameterValue")
Expand Down
12 changes: 6 additions & 6 deletions rewrite-python/src/main/java/org/openrewrite/python/tree/Py.java
Original file line number Diff line number Diff line change
Expand Up @@ -1359,7 +1359,7 @@ public CoordinateBuilder.Expression getCoordinates() {
@RequiredArgsConstructor
@AllArgsConstructor(access = AccessLevel.PRIVATE)
@Data
final class Union implements Py, Expression, TypeTree {
final class UnionType implements Py, Expression, TypeTree {

@Nullable
@NonFinal
Expand All @@ -1381,7 +1381,7 @@ public List<Expression> getTypes() {
return JRightPadded.getElements(types);
}

public Union withTypes(List<Expression> types) {
public UnionType withTypes(List<Expression> types) {
return getPadding().withTypes(JRightPadded.withElements(this.types, types));
}

Expand All @@ -1391,7 +1391,7 @@ public Union withTypes(List<Expression> types) {

@Override
public <P> J acceptPython(PythonVisitor<P> v, P p) {
return v.visitUnion(this, p);
return v.visitUnionType(this, p);
}

@Transient
Expand All @@ -1417,14 +1417,14 @@ public Padding getPadding() {

@RequiredArgsConstructor
public static class Padding {
private final Union t;
private final UnionType t;

public List<JRightPadded<Expression>> getTypes() {
return t.types;
}

public Union withTypes(List<JRightPadded<Expression>> types) {
return t.types == types ? t : new Union(t.id, t.prefix, t.markers, types, t.type);
public UnionType withTypes(List<JRightPadded<Expression>> types) {
return t.types == types ? t : new UnionType(t.id, t.prefix, t.markers, types, t.type);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ public enum Location {
SLICE_EXPRESSION_STEP(PySpace.Location.SLICE_STEP_SUFFIX),
SLICE_EXPRESSION_STOP(PySpace.Location.SLICE_STOP_SUFFIX),
TOP_LEVEL_STATEMENT_SUFFIX(PySpace.Location.TOP_LEVEL_STATEMENT),
UNION_TYPE(PySpace.Location.UNION_ELEMENT_SUFFIX),
UNION_TYPE_TYPE(PySpace.Location.UNION_ELEMENT_SUFFIX),
VARIABLE_SCOPE_ELEMENT(PySpace.Location.VARIABLE_SCOPE_NAME_SUFFIX),
;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -369,7 +369,7 @@ public enum Location {
TYPE_HINTED_EXPRESSION_PREFIX,
TYPE_HINT_PREFIX,
UNION_ELEMENT_SUFFIX,
UNION_PREFIX,
UNION_TYPE_PREFIX,
VARIABLE_SCOPE_NAME_SUFFIX,
VARIABLE_SCOPE_PREFIX,
YIELD_FROM_PREFIX,
Expand Down
2 changes: 1 addition & 1 deletion rewrite/rewrite/python/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
'TrailingElseWrapper',
'TypeHint',
'TypeHintedExpression',
'Union',
'UnionType',
'VariableScope',
'YieldFrom',
]
2 changes: 1 addition & 1 deletion rewrite/rewrite/python/_parser_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1714,7 +1714,7 @@ def __convert_type_hint(self, node) -> Optional[TypeTree]:
# FIXME consider flattening nested unions
left = self.__pad_right(self.__convert_internal(node.left, self.__convert_type_hint), self.__source_before('|'))
right = self.__pad_right(self.__convert_internal(node.right, self.__convert_type_hint), Space.EMPTY)
return py.Union(
return py.UnionType(
random_id(),
prefix,
Markers.EMPTY,
Expand Down
18 changes: 9 additions & 9 deletions rewrite/rewrite/python/remote/receiver.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,13 +202,13 @@ def visit_yield_from(self, yield_from: YieldFrom, ctx: ReceiverContext) -> J:
yield_from = yield_from.with_type(ctx.receive_value(yield_from.type, JavaType))
return yield_from

def visit_union(self, union: Union, ctx: ReceiverContext) -> J:
union = union.with_id(ctx.receive_value(union.id, UUID))
union = union.with_prefix(ctx.receive_node(union.prefix, PythonReceiver.receive_space))
union = union.with_markers(ctx.receive_node(union.markers, ctx.receive_markers))
union = union.padding.with_types(ctx.receive_nodes(union.padding.types, PythonReceiver.receive_right_padded_tree))
union = union.with_type(ctx.receive_value(union.type, JavaType))
return union
def visit_union_type(self, union_type: UnionType, ctx: ReceiverContext) -> J:
union_type = union_type.with_id(ctx.receive_value(union_type.id, UUID))
union_type = union_type.with_prefix(ctx.receive_node(union_type.prefix, PythonReceiver.receive_space))
union_type = union_type.with_markers(ctx.receive_node(union_type.markers, ctx.receive_markers))
union_type = union_type.padding.with_types(ctx.receive_nodes(union_type.padding.types, PythonReceiver.receive_right_padded_tree))
union_type = union_type.with_type(ctx.receive_value(union_type.type, JavaType))
return union_type

def visit_variable_scope(self, variable_scope: VariableScope, ctx: ReceiverContext) -> J:
variable_scope = variable_scope.with_id(ctx.receive_value(variable_scope.id, UUID))
Expand Down Expand Up @@ -1088,8 +1088,8 @@ def create(self, type: str, ctx: ReceiverContext) -> Tree:
ctx.receive_value(None, JavaType)
)

if type in ["rewrite.python.tree.Union", "org.openrewrite.python.tree.Py$Union"]:
return Union(
if type in ["rewrite.python.tree.UnionType", "org.openrewrite.python.tree.Py$UnionType"]:
return UnionType(
ctx.receive_value(None, UUID),
ctx.receive_node(None, PythonReceiver.receive_space),
ctx.receive_node(None, ctx.receive_markers),
Expand Down
14 changes: 7 additions & 7 deletions rewrite/rewrite/python/remote/sender.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,13 +196,13 @@ def visit_yield_from(self, yield_from: YieldFrom, ctx: SenderContext) -> J:
ctx.send_typed_value(yield_from, attrgetter('_type'))
return yield_from

def visit_union(self, union: Union, ctx: SenderContext) -> J:
ctx.send_value(union, attrgetter('_id'))
ctx.send_node(union, attrgetter('_prefix'), PythonSender.send_space)
ctx.send_node(union, attrgetter('_markers'), ctx.send_markers)
ctx.send_nodes(union, attrgetter('_types'), PythonSender.send_right_padded, lambda t: t.element.id)
ctx.send_typed_value(union, attrgetter('_type'))
return union
def visit_union_type(self, union_type: UnionType, ctx: SenderContext) -> J:
ctx.send_value(union_type, attrgetter('_id'))
ctx.send_node(union_type, attrgetter('_prefix'), PythonSender.send_space)
ctx.send_node(union_type, attrgetter('_markers'), ctx.send_markers)
ctx.send_nodes(union_type, attrgetter('_types'), PythonSender.send_right_padded, lambda t: t.element.id)
ctx.send_typed_value(union_type, attrgetter('_type'))
return union_type

def visit_variable_scope(self, variable_scope: VariableScope, ctx: SenderContext) -> J:
ctx.send_value(variable_scope, attrgetter('_id'))
Expand Down
24 changes: 12 additions & 12 deletions rewrite/rewrite/python/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -1419,14 +1419,14 @@ def accept_python(self, v: PythonVisitor[P], p: P) -> J:

# noinspection PyShadowingBuiltins,PyShadowingNames,DuplicatedCode
@dataclass(frozen=True, eq=False)
class Union(Py, Expression, TypeTree):
class UnionType(Py, Expression, TypeTree):
_id: UUID

@property
def id(self) -> UUID:
return self._id

def with_id(self, id: UUID) -> Union:
def with_id(self, id: UUID) -> UnionType:
return self if id is self._id else replace(self, _id=id)

_prefix: Space
Expand All @@ -1435,7 +1435,7 @@ def with_id(self, id: UUID) -> Union:
def prefix(self) -> Space:
return self._prefix

def with_prefix(self, prefix: Space) -> Union:
def with_prefix(self, prefix: Space) -> UnionType:
return self if prefix is self._prefix else replace(self, _prefix=prefix)

_markers: Markers
Expand All @@ -1444,7 +1444,7 @@ def with_prefix(self, prefix: Space) -> Union:
def markers(self) -> Markers:
return self._markers

def with_markers(self, markers: Markers) -> Union:
def with_markers(self, markers: Markers) -> UnionType:
return self if markers is self._markers else replace(self, _markers=markers)

_types: List[JRightPadded[Expression]]
Expand All @@ -1453,7 +1453,7 @@ def with_markers(self, markers: Markers) -> Union:
def types(self) -> List[Expression]:
return JRightPadded.get_elements(self._types)

def with_types(self, types: List[Expression]) -> Union:
def with_types(self, types: List[Expression]) -> UnionType:
return self.padding.with_types(JRightPadded.with_elements(self._types, types))

_type: Optional[JavaType]
Expand All @@ -1462,38 +1462,38 @@ def with_types(self, types: List[Expression]) -> Union:
def type(self) -> Optional[JavaType]:
return self._type

def with_type(self, type: Optional[JavaType]) -> Union:
def with_type(self, type: Optional[JavaType]) -> UnionType:
return self if type is self._type else replace(self, _type=type)

@dataclass
class PaddingHelper:
_t: Union
_t: UnionType

@property
def types(self) -> List[JRightPadded[Expression]]:
return self._t._types

def with_types(self, types: List[JRightPadded[Expression]]) -> Union:
def with_types(self, types: List[JRightPadded[Expression]]) -> UnionType:
return self._t if self._t._types is types else replace(self._t, _types=types)

_padding: weakref.ReferenceType[PaddingHelper] = None

@property
def padding(self) -> PaddingHelper:
p: Union.PaddingHelper
p: UnionType.PaddingHelper
if self._padding is None:
p = Union.PaddingHelper(self)
p = UnionType.PaddingHelper(self)
object.__setattr__(self, '_padding', weakref.ref(p))
else:
p = self._padding()
# noinspection PyProtectedMember
if p is None or p._t != self:
p = Union.PaddingHelper(self)
p = UnionType.PaddingHelper(self)
object.__setattr__(self, '_padding', weakref.ref(p))
return p

def accept_python(self, v: PythonVisitor[P], p: P) -> J:
return v.visit_union(self, p)
return v.visit_union_type(self, p)

# noinspection PyShadowingBuiltins,PyShadowingNames,DuplicatedCode
@dataclass(frozen=True, eq=False)
Expand Down
25 changes: 12 additions & 13 deletions rewrite/rewrite/python/visitor.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
from typing import cast, TypeVar, Union

from rewrite import SourceFile, TreeVisitor
from .extensions import *
from .support_types import *
from .tree import *
from typing import cast, TypeVar
import typing
from rewrite.java import JavaVisitor

# noinspection DuplicatedCode
Expand Down Expand Up @@ -193,15 +192,15 @@ def visit_yield_from(self, yield_from: YieldFrom, p: P) -> J:
yield_from = yield_from.with_expression(self.visit_and_cast(yield_from.expression, Expression, p))
return yield_from

def visit_union(self, union: Union, p: P) -> J:
union = union.with_prefix(self.visit_space(union.prefix, PySpace.Location.UNION_PREFIX, p))
temp_expression = cast(Expression, self.visit_expression(union, p))
if not isinstance(temp_expression, Union):
def visit_union_type(self, union_type: UnionType, p: P) -> J:
union_type = union_type.with_prefix(self.visit_space(union_type.prefix, PySpace.Location.UNION_TYPE_PREFIX, p))
temp_expression = cast(Expression, self.visit_expression(union_type, p))
if not isinstance(temp_expression, UnionType):
return temp_expression
union = cast(Union, temp_expression)
union = union.with_markers(self.visit_markers(union.markers, p))
union = union.padding.with_types([self.visit_right_padded(v, PyRightPadded.Location.UNION_TYPES, p) for v in union.padding.types])
return union
union_type = cast(UnionType, temp_expression)
union_type = union_type.with_markers(self.visit_markers(union_type.markers, p))
union_type = union_type.padding.with_types([self.visit_right_padded(v, PyRightPadded.Location.UNION_TYPE_TYPES, p) for v in union_type.padding.types])
return union_type

def visit_variable_scope(self, variable_scope: VariableScope, p: P) -> J:
variable_scope = variable_scope.with_prefix(self.visit_space(variable_scope.prefix, PySpace.Location.VARIABLE_SCOPE_PREFIX, p))
Expand Down Expand Up @@ -305,12 +304,12 @@ def visit_slice(self, slice: Slice, p: P) -> J:
slice = slice.padding.with_step(self.visit_right_padded(slice.padding.step, PyRightPadded.Location.SLICE_STEP, p))
return slice

def visit_container(self, container: Optional[JContainer[J2]], loc: typing.Union[PyContainer.Location, JContainer.Location], p: P) -> JContainer[J2]:
def visit_container(self, container: Optional[JContainer[J2]], loc: Union[PyContainer.Location, JContainer.Location], p: P) -> JContainer[J2]:
if isinstance(loc, JContainer.Location):
return super().visit_container(container, loc, p)
return extensions.visit_container(self, container, loc, p)

def visit_right_padded(self, right: Optional[JRightPadded[T]], loc: typing.Union[PyRightPadded.Location, JRightPadded.Location], p: P) -> Optional[JRightPadded[T]]:
def visit_right_padded(self, right: Optional[JRightPadded[T]], loc: Union[PyRightPadded.Location, JRightPadded.Location], p: P) -> Optional[JRightPadded[T]]:
if isinstance(loc, JRightPadded.Location):
return super().visit_right_padded(right, loc, p)
return extensions.visit_right_padded(self, right, loc, p)
Expand All @@ -320,7 +319,7 @@ def visit_left_padded(self, left: Optional[JLeftPadded[T]], loc: PyLeftPadded.Lo
return super().visit_left_padded(left, loc, p)
return extensions.visit_left_padded(self, left, loc, p)

def visit_space(self, space: Optional[Space], loc: Optional[typing.Union[PySpace.Location, Space.Location]], p: P) -> Space:
def visit_space(self, space: Optional[Space], loc: Optional[Union[PySpace.Location, Space.Location]], p: P) -> Space:
if isinstance(loc, Space.Location) or loc is None:
return super().visit_space(space, loc, p)
return extensions.visit_space(self, space, loc, p)

0 comments on commit 4e2d5da

Please sign in to comment.