diff --git a/src/main/java/org/openrewrite/java/testing/jmockit/JMockitBlockRewriter.java b/src/main/java/org/openrewrite/java/testing/jmockit/JMockitBlockRewriter.java index 0e9829b4d..5544bb5d3 100644 --- a/src/main/java/org/openrewrite/java/testing/jmockit/JMockitBlockRewriter.java +++ b/src/main/java/org/openrewrite/java/testing/jmockit/JMockitBlockRewriter.java @@ -75,7 +75,7 @@ boolean isRewriteFailed() { this.newExpectations = newExpectations; this.bodyStatementIndex = bodyStatementIndex; this.blockType = blockType; - nextStatementCoordinates = newExpectations.getCoordinates().replace(); + this.nextStatementCoordinates = newExpectations.getCoordinates().replace(); } J.Block rewriteMethodBody() { @@ -125,37 +125,30 @@ private void rewriteMethodInvocation(List statementsToRewrite) { final MockInvocationResults mockInvocationResults = buildMockInvocationResults(statementsToRewrite); if (mockInvocationResults == null) { // invalid block, cannot rewrite - rewriteFailed = true; + this.rewriteFailed = true; return; } J.MethodInvocation invocation = (J.MethodInvocation) statementsToRewrite.get(0); boolean hasResults = !mockInvocationResults.getResults().isEmpty(); + boolean hasTimes = mockInvocationResults.hasAnyTimes(); if (hasResults) { - rewriteResult(invocation, mockInvocationResults.getResults()); + rewriteResult(invocation, mockInvocationResults.getResults(), hasTimes); } - if (blockType == NonStrictExpectations) { - // no verify for NonStrictExpectations + if (!hasResults && !hasTimes && (this.blockType == JMockitBlockType.Expectations || this.blockType == Verifications)) { + rewriteVerify(invocation, null, ""); return; } - - boolean hasTimes = false; if (mockInvocationResults.getTimes() != null) { - hasTimes = true; rewriteVerify(invocation, mockInvocationResults.getTimes(), "times"); } if (mockInvocationResults.getMinTimes() != null) { - hasTimes = true; rewriteVerify(invocation, mockInvocationResults.getMinTimes(), "atLeast"); } if (mockInvocationResults.getMaxTimes() != null) { - hasTimes = true; rewriteVerify(invocation, mockInvocationResults.getMaxTimes(), "atMost"); } - if (!hasResults && !hasTimes) { - rewriteVerify(invocation, null, ""); - } } private void removeBlock() { @@ -163,18 +156,15 @@ private void removeBlock() { .javaParser(JavaParser.fromJavaVersion()) .build() .apply(new Cursor(visitor.getCursor(), methodBody), nextStatementCoordinates); - if (bodyStatementIndex == 0) { - nextStatementCoordinates = methodBody.getCoordinates().firstStatement(); - } else { - setNextStatementCoordinates(0); - } + setNextStatementCoordinates(0); } - private void rewriteResult(J.MethodInvocation invocation, List results) { - String template = getWhenTemplate(results); + private void rewriteResult(J.MethodInvocation invocation, List results, boolean hasTimes) { + boolean lenient = this.blockType == NonStrictExpectations && !hasTimes; + String template = getWhenTemplate(results, lenient); if (template == null) { // invalid template, cannot rewrite - rewriteFailed = true; + this.rewriteFailed = true; return; } @@ -182,17 +172,16 @@ private void rewriteResult(J.MethodInvocation invocation, List resul templateParams.add(invocation); templateParams.addAll(results); this.rewriteFailed = !rewriteTemplate(template, templateParams, nextStatementCoordinates); - if (!this.rewriteFailed) { - this.rewriteFailed = true; - setNextStatementCoordinates(++numStatementsAdded); - // do this last making sure rewrite worked and specify hasReference=false because framework cannot find static - // reference for when method invocation when lenient is added. - boolean hasReferencesForWhen = true; - if (this.blockType == NonStrictExpectations) { - visitor.maybeAddImport(MOCKITO_IMPORT_FQN_PREFX, "lenient"); - hasReferencesForWhen = false; - } - visitor.maybeAddImport(MOCKITO_IMPORT_FQN_PREFX, "when", hasReferencesForWhen); + if (this.rewriteFailed) { + return; + } + + setNextStatementCoordinates(++numStatementsAdded); + // do this last making sure rewrite worked and specify onlyifReferenced=false because framework cannot find static + // reference for when method invocation when another static mockit reference is added + visitor.maybeAddImport(MOCKITO_IMPORT_FQN_PREFX, "when", false); + if (lenient) { + visitor.maybeAddImport(MOCKITO_IMPORT_FQN_PREFX, "lenient"); } } @@ -218,29 +207,37 @@ private void rewriteVerify(J.MethodInvocation invocation, @Nullable Expression t verifyCoordinates = methodBody.getCoordinates().lastStatement(); } this.rewriteFailed = !rewriteTemplate(verifyTemplate, templateParams, verifyCoordinates); - if (!this.rewriteFailed) { - if (this.blockType == Verifications) { - setNextStatementCoordinates(++numStatementsAdded); // for Expectations, verify statements added to end of method - } + if (this.rewriteFailed) { + return; + } - // do this last making sure rewrite worked and specify hasReference=false because in verify case framework - // cannot find the static reference - visitor.maybeAddImport(MOCKITO_IMPORT_FQN_PREFX, "verify", false); - if (!verificationMode.isEmpty()) { - visitor.maybeAddImport(MOCKITO_IMPORT_FQN_PREFX, verificationMode); - } + if (this.blockType == Verifications) { + setNextStatementCoordinates(++numStatementsAdded); // for Expectations, verify statements added to end of method + } + + // do this last making sure rewrite worked and specify onlyifReferenced=false because framework cannot find the + // static reference to verify when another static mockit reference is added + visitor.maybeAddImport(MOCKITO_IMPORT_FQN_PREFX, "verify", false); + if (!verificationMode.isEmpty()) { + visitor.maybeAddImport(MOCKITO_IMPORT_FQN_PREFX, verificationMode); } } private void setNextStatementCoordinates(int numStatementsAdded) { + if (numStatementsAdded <= 0 && bodyStatementIndex == 0) { + nextStatementCoordinates = methodBody.getCoordinates().firstStatement(); + return; + } + // the next statement coordinates are directly after the most recently written statement, calculated by // subtracting the removed jmockit block - int nextStatementIdx = bodyStatementIndex + numStatementsAdded - 1; - if (nextStatementIdx >= this.methodBody.getStatements().size()) { - rewriteFailed = true; - } else { - this.nextStatementCoordinates = this.methodBody.getStatements().get(nextStatementIdx).getCoordinates().after(); + int lastStatementIdx = bodyStatementIndex + numStatementsAdded - 1; + if (lastStatementIdx >= this.methodBody.getStatements().size()) { + this.rewriteFailed = true; + return; } + + this.nextStatementCoordinates = this.methodBody.getStatements().get(lastStatementIdx).getCoordinates().after(); } private boolean rewriteTemplate(String template, List templateParams, JavaCoordinates @@ -258,10 +255,10 @@ private boolean rewriteTemplate(String template, List templateParams, Ja return methodBody.getStatements().size() > numStatementsBefore; } - private @Nullable String getWhenTemplate(List results) { + private @Nullable String getWhenTemplate(List results, boolean lenient) { boolean buildingResults = false; StringBuilder templateBuilder = new StringBuilder(); - if (this.blockType == NonStrictExpectations) { + if (lenient) { templateBuilder.append(LENIENT_TEMPLATE_PREFIX); } templateBuilder.append(WHEN_TEMPLATE_PREFIX); @@ -422,5 +419,9 @@ private static class MockInvocationResults { private void addResult(Expression result) { results.add(result); } + + private boolean hasAnyTimes() { + return times != null || minTimes != null || maxTimes != null; + } } } diff --git a/src/test/java/org/openrewrite/java/testing/jmockit/JMockitExpectationsToMockitoTest.java b/src/test/java/org/openrewrite/java/testing/jmockit/JMockitExpectationsToMockitoTest.java index fc03f7aef..fcc8bece2 100644 --- a/src/test/java/org/openrewrite/java/testing/jmockit/JMockitExpectationsToMockitoTest.java +++ b/src/test/java/org/openrewrite/java/testing/jmockit/JMockitExpectationsToMockitoTest.java @@ -30,6 +30,61 @@ public void defaults(RecipeSpec spec) { setDefaultParserSettings(spec); } + @DocumentExample + @Test + void whenTimesAndResult() { + //language=java + rewriteRun( + java( + """ + import mockit.Expectations; + import mockit.Mocked; + import mockit.integration.junit5.JMockitExtension; + import org.junit.jupiter.api.extension.ExtendWith; + + import static org.junit.jupiter.api.Assertions.assertEquals; + + @ExtendWith(JMockitExtension.class) + class MyTest { + @Mocked + Object myObject; + + void test() { + new Expectations() {{ + myObject.toString(); + result = "foo"; + times = 2; + }}; + assertEquals("foo", myObject.toString()); + assertEquals("foo", myObject.toString()); + } + } + """, + """ + import org.junit.jupiter.api.extension.ExtendWith; + import org.mockito.Mock; + import org.mockito.junit.jupiter.MockitoExtension; + + import static org.junit.jupiter.api.Assertions.assertEquals; + import static org.mockito.Mockito.*; + + @ExtendWith(MockitoExtension.class) + class MyTest { + @Mock + Object myObject; + + void test() { + when(myObject.toString()).thenReturn("foo"); + assertEquals("foo", myObject.toString()); + assertEquals("foo", myObject.toString()); + verify(myObject, times(2)).toString(); + } + } + """ + ) + ); + } + @DocumentExample @Test void whenNoResultNoTimes() { @@ -77,7 +132,6 @@ void test() { ); } - @DocumentExample @Test void whenNoResultNoTimesNoArgs() { //language=java @@ -124,6 +178,56 @@ void test() { ); } + @Test + void whenHasResultNoTimes() { + //language=java + rewriteRun( + java( + """ + import mockit.Expectations; + import mockit.Mocked; + import mockit.integration.junit5.JMockitExtension; + import org.junit.jupiter.api.extension.ExtendWith; + + import static org.junit.jupiter.api.Assertions.assertEquals; + + @ExtendWith(JMockitExtension.class) + class MyTest { + @Mocked + Object myObject; + + void test() { + new Expectations() {{ + myObject.toString(); + result = "foo"; + }}; + assertEquals("foo", myObject.toString()); + } + } + """, + """ + import org.junit.jupiter.api.extension.ExtendWith; + import org.mockito.Mock; + import org.mockito.junit.jupiter.MockitoExtension; + + import static org.junit.jupiter.api.Assertions.assertEquals; + import static org.mockito.Mockito.when; + + @ExtendWith(MockitoExtension.class) + class MyTest { + @Mock + Object myObject; + + void test() { + when(myObject.toString()).thenReturn("foo"); + assertEquals("foo", myObject.toString()); + } + } + """ + ) + ); + } + @Test void whenNullResult() { //language=java diff --git a/src/test/java/org/openrewrite/java/testing/jmockit/JMockitNonStrictExpectationsToMockitoTest.java b/src/test/java/org/openrewrite/java/testing/jmockit/JMockitNonStrictExpectationsToMockitoTest.java index 6abe677a7..bc186bb48 100644 --- a/src/test/java/org/openrewrite/java/testing/jmockit/JMockitNonStrictExpectationsToMockitoTest.java +++ b/src/test/java/org/openrewrite/java/testing/jmockit/JMockitNonStrictExpectationsToMockitoTest.java @@ -1083,4 +1083,106 @@ void test() { ) ); } + + @Test + void whenTimes() { + //language=java + rewriteRun( + java( + """ + import mockit.NonStrictExpectations; + import mockit.Mocked; + import mockit.integration.junit4.JMockit; + import org.junit.runner.RunWith; + + @RunWith(JMockit.class) + class MyTest { + @Mocked + Object myObject; + + void test() { + new NonStrictExpectations() {{ + myObject.wait(anyLong, anyInt); + times = 3; + }}; + myObject.wait(10L, 10); + myObject.wait(10L, 10); + myObject.wait(10L, 10); + } + } + """, + """ + import org.junit.runner.RunWith; + import org.mockito.Mock; + import org.mockito.junit.MockitoJUnitRunner; + + import static org.mockito.Mockito.*; + + @RunWith(MockitoJUnitRunner.class) + class MyTest { + @Mock + Object myObject; + + void test() { + myObject.wait(10L, 10); + myObject.wait(10L, 10); + myObject.wait(10L, 10); + verify(myObject, times(3)).wait(anyLong(), anyInt()); + } + } + """ + ) + ); + } + + @Test + void whenTimesAndResult() { + //language=java + rewriteRun( + java( + """ + import mockit.NonStrictExpectations; + import mockit.Mocked; + import mockit.integration.junit4.JMockit; + import org.junit.runner.RunWith; + + @RunWith(JMockit.class) + class MyTest { + @Mocked + Object myObject; + + void test() { + new NonStrictExpectations() {{ + myObject.toString(); + result = "foo"; + times = 2; + }}; + myObject.toString(); + myObject.toString(); + } + } + """, + """ + import org.junit.runner.RunWith; + import org.mockito.Mock; + import org.mockito.junit.MockitoJUnitRunner; + + import static org.mockito.Mockito.*; + + @RunWith(MockitoJUnitRunner.class) + class MyTest { + @Mock + Object myObject; + + void test() { + when(myObject.toString()).thenReturn("foo"); + myObject.toString(); + myObject.toString(); + verify(myObject, times(2)).toString(); + } + } + """ + ) + ); + } }