Skip to content

Commit

Permalink
Refactor PythonVersion#parse
Browse files Browse the repository at this point in the history
Instead of passing in a granular list of allowed versions, just have two separate methods for target versions versus sources versions, since that's how it's used anyway.

For simplicity in the PythonVersion class, don't intercept the IllegalArgumentException or allow the argument to be null. That can be the caller's responsibility.

Also do automated formatting fixes required by presubmit.

Work toward #6583.

RELNOTES: None
PiperOrigin-RevId: 223533178
  • Loading branch information
brandjon authored and Copybara-Service committed Nov 30, 2018
1 parent f195d55 commit 1f2298c
Show file tree
Hide file tree
Showing 5 changed files with 160 additions and 65 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -95,14 +95,10 @@ public PyCommon(RuleContext ruleContext) {
}

public void initCommon(PythonVersion defaultVersion) {
this.sourcesVersion =
getPythonVersionAttr(ruleContext, "srcs_version", PythonVersion.getAllValues());

this.sourcesVersion = getSrcsVersionAttr(ruleContext);
this.version = ruleContext.getFragment(PythonConfiguration.class)
.getPythonVersion(defaultVersion);

transitivePythonSources = collectTransitivePythonSources();

this.transitivePythonSources = collectTransitivePythonSources();
checkSourceIsCompatible(this.version, this.sourcesVersion, ruleContext.getLabel());
}

Expand Down Expand Up @@ -199,25 +195,41 @@ public static StructImpl createSourceProvider(
}

public PythonVersion getDefaultPythonVersion() {
return ruleContext.getRule()
.isAttrDefined("default_python_version", Type.STRING)
? getPythonVersionAttr(
ruleContext, "default_python_version", PythonVersion.PY2, PythonVersion.PY3)
: null;
}

public static PythonVersion getPythonVersionAttr(RuleContext ruleContext,
String attrName, PythonVersion... allowed) {
String stringAttr = ruleContext.attributes().get(attrName, Type.STRING);
PythonVersion version = PythonVersion.parse(stringAttr, allowed);
if (version != null) {
return version;
return ruleContext.getRule().isAttrDefined("default_python_version", Type.STRING)
? getPythonVersionAttr(ruleContext)
: null;
}

/** Returns the parsed value of the "srcs_version" attribute. */
private static PythonVersion getSrcsVersionAttr(RuleContext ruleContext) {
String attrValue = ruleContext.attributes().get("srcs_version", Type.STRING);
try {
return PythonVersion.parseSrcsValue(attrValue);
} catch (IllegalArgumentException ex) {
// Should already have been disallowed in the rule.
ruleContext.attributeError(
"srcs_version",
String.format(
"'%s' is not a valid value. Expected one of: %s",
attrValue, Joiner.on(", ").join(PythonVersion.getAllStrings())));
return PythonVersion.getDefaultSrcsValue();
}
}

/** Returns the parsed value of the "default_python_version" attribute. */
private static PythonVersion getPythonVersionAttr(RuleContext ruleContext) {
String attrValue = ruleContext.attributes().get("default_python_version", Type.STRING);
try {
return PythonVersion.parseTargetValue(attrValue);
} catch (IllegalArgumentException ex) {
// Should already have been disallowed in the rule.
ruleContext.attributeError(
"default_python_version",
String.format(
"'%s' is not a valid value. Expected one of: %s",
attrValue, Joiner.on(", ").join(PythonVersion.getTargetStrings())));
return PythonVersion.getDefaultTargetValue();
}
// Should already have been disallowed in the rule.
ruleContext.attributeError(attrName,
"'" + stringAttr + "' is not a valid value. Expected one of: " + Joiner.on(", ")
.join(allowed));
return PythonVersion.getDefaultTargetValue();
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,7 @@
import com.google.devtools.build.lib.syntax.Type;
import com.google.devtools.build.lib.util.FileType;

/**
* Rule definitions for Python rules.
*/
/** Rule definitions for Python rules. */
public class PyRuleClasses {

public static final FileType PYTHON_SOURCE = FileType.of(".py", ".py3");
Expand All @@ -34,12 +32,21 @@ public class PyRuleClasses {
* <p>Since this is a configuration transition, this propagates to the rules' transitive deps.
*/
public static final RuleTransitionFactory DEFAULT_PYTHON_VERSION_TRANSITION =
(rule) ->
new PythonVersionTransition(
// In case of a parse error, this will return null, which means that the transition
// would use the hard-coded default (PythonVersion#getDefaultTargetValue). But the
// attribute is already validated to allow only PythonVersion#getTargetStrings anyway.
PythonVersion.parse(
RawAttributeMapper.of(rule).get("default_python_version", Type.STRING),
PythonVersion.getAllValues()));
(rule) -> {
String attrDefault = RawAttributeMapper.of(rule).get("default_python_version", Type.STRING);
// It should be a target value ("PY2" or "PY3"), and if not that should be caught by
// attribute validation. But just in case, we'll treat an invalid value as null (which means
// "use the hard-coded default version") rather than propagate an unchecked exception in
// this context.
PythonVersion version = null;
// Should be non-null because this transition shouldn't be used on rules without the attr.
if (attrDefault != null) {
try {
version = PythonVersion.parseTargetValue(attrDefault);
} catch (IllegalArgumentException ex) {
// Parsing error.
}
}
return new PythonVersionTransition(version);
};
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@

package com.google.devtools.build.lib.rules.python;

import static com.google.common.collect.Iterables.transform;

import com.google.common.base.Functions;
import com.google.common.collect.ImmutableList;
import java.util.Arrays;
Expand Down Expand Up @@ -71,23 +69,25 @@ public enum PythonVersion {
*/
PY3ONLY;

private static Iterable<String> convertToStrings(PythonVersion[] values) {
return transform(ImmutableList.copyOf(values), Functions.toStringFunction());
private static ImmutableList<String> convertToStrings(PythonVersion[] values) {
return Arrays.stream(values)
.map(Functions.toStringFunction())
.collect(ImmutableList.toImmutableList());
}

private static final PythonVersion[] allValues =
new PythonVersion[] {PY2, PY3, PY2AND3, PY2ONLY, PY3ONLY};

private static final Iterable<String> allStrings = convertToStrings(allValues);
private static final ImmutableList<String> ALL_STRINGS = convertToStrings(allValues);

private static final PythonVersion[] targetValues = new PythonVersion[] {PY2, PY3};

private static final Iterable<String> targetStrings = convertToStrings(targetValues);
private static final ImmutableList<String> TARGET_STRINGS = convertToStrings(targetValues);

private static final PythonVersion[] nonConversionValues =
new PythonVersion[] {PY2AND3, PY2ONLY, PY3ONLY};

private static final Iterable<String> nonConversionStrings =
private static final ImmutableList<String> NON_CONVERSION_STRINGS =
convertToStrings(nonConversionValues);

private static final PythonVersion DEFAULT_TARGET_VALUE = PY2;
Expand All @@ -100,8 +100,8 @@ public static PythonVersion[] getAllValues() {
}

/** Returns an iterable of all values as strings. */
public static Iterable<String> getAllStrings() {
return allStrings;
public static ImmutableList<String> getAllStrings() {
return ALL_STRINGS;
}

/** Returns all values representing a specific version, as a new array. */
Expand All @@ -110,8 +110,8 @@ public static PythonVersion[] getTargetValues() {
}

/** Returns an iterable of all values representing a specific version, as strings. */
public static Iterable<String> getTargetStrings() {
return targetStrings;
public static ImmutableList<String> getTargetStrings() {
return TARGET_STRINGS;
}

/**
Expand All @@ -126,8 +126,8 @@ public static PythonVersion[] getNonConversionValues() {
* Returns all values that do not imply running a transpiler to convert between versions, as
* strings.
*/
public static Iterable<String> getNonConversionStrings() {
return nonConversionStrings;
public static ImmutableList<String> getNonConversionStrings() {
return NON_CONVERSION_STRINGS;
}

/** Returns the Python version to use if not otherwise specified by a flag or attribute. */
Expand All @@ -142,26 +142,26 @@ public static PythonVersion getDefaultSrcsValue() {
return DEFAULT_SRCS_VALUE;
}

// TODO(brandjon): Refactor this into parseTargetValue and parseSourcesValue methods. Throw
// IllegalArgumentException on bad values instead of returning null, and modify callers to
// tolerate the exception.
/**
* Converts the string to PythonVersion, if it is one of the allowed values. Returns null if the
* input is not valid.
* Converts the string to a target {@code PythonVersion} value (case-sensitive).
*
* @throws IllegalArgumentException if the string is not "PY2" or "PY3".
*/
public static PythonVersion parse(String str, PythonVersion... allowed) {
if (str == null) {
return null;
}
try {
PythonVersion version = PythonVersion.valueOf(str);
if (Arrays.asList(allowed).contains(version)) {
return version;
}
return null;
} catch (IllegalArgumentException e) {
return null;
public static PythonVersion parseTargetValue(String str) {
if (!TARGET_STRINGS.contains(str)) {
throw new IllegalArgumentException(
String.format("'%s' is not a valid Python major version. Expected 'PY2' or 'PY3'.", str));
}
return PythonVersion.valueOf(str);
}

/**
* Converts the string to a sources {@code PythonVersion} value (case-sensitive).
*
* @throws IllegalArgumentException if the string is not an enum name.
*/
public static PythonVersion parseSrcsValue(String str) {
return PythonVersion.valueOf(str);
}
}

12 changes: 12 additions & 0 deletions src/test/java/com/google/devtools/build/lib/rules/python/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ test_suite(
":PyLibraryConfiguredTargetTest",
":PyTestConfiguredTargetTest",
":PythonConfigurationTest",
":PythonVersionTest",
],
)

Expand Down Expand Up @@ -105,3 +106,14 @@ java_test(
"//third_party:truth",
],
)

java_test(
name = "PythonVersionTest",
srcs = ["PythonVersionTest.java"],
deps = [
"//src/main/java/com/google/devtools/build/lib:python-rules",
"//src/test/java/com/google/devtools/build/lib:testutil",
"//third_party:junit4",
"//third_party:truth",
],
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
// Copyright 2018 The Bazel Authors. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package com.google.devtools.build.lib.rules.python;

import static com.google.common.truth.Truth.assertThat;
import static com.google.devtools.build.lib.testutil.MoreAsserts.assertThrows;

import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;

/** Tests for {@link PythonVersion}. */
@RunWith(JUnit4.class)
public class PythonVersionTest {

@Test
public void parseTargetValue() {
assertThat(PythonVersion.parseTargetValue("PY2")).isEqualTo(PythonVersion.PY2);

IllegalArgumentException expected =
assertThrows(
IllegalArgumentException.class, () -> PythonVersion.parseTargetValue("PY2AND3"));
assertThat(expected).hasMessageThat().contains("not a valid Python major version");

expected =
assertThrows(
IllegalArgumentException.class,
() -> PythonVersion.parseTargetValue("not an enum value"));
assertThat(expected).hasMessageThat().contains("not a valid Python major version");

expected =
assertThrows(IllegalArgumentException.class, () -> PythonVersion.parseTargetValue("py2"));
assertThat(expected).hasMessageThat().contains("not a valid Python major version");
}

@Test
public void parseSrcsValue() {
assertThat(PythonVersion.parseSrcsValue("PY2")).isEqualTo(PythonVersion.PY2);

assertThat(PythonVersion.parseSrcsValue("PY2AND3")).isEqualTo(PythonVersion.PY2AND3);

IllegalArgumentException expected =
assertThrows(
IllegalArgumentException.class,
() -> PythonVersion.parseSrcsValue("not an enum value"));
assertThat(expected).hasMessageThat().contains("No enum constant");

expected =
assertThrows(IllegalArgumentException.class, () -> PythonVersion.parseSrcsValue("py2"));
assertThat(expected).hasMessageThat().contains("No enum constant");
}
}

0 comments on commit 1f2298c

Please sign in to comment.