diff --git a/app/com/linkedin/drelephant/spark/fetchers/SparkRestClient.scala b/app/com/linkedin/drelephant/spark/fetchers/SparkRestClient.scala
index 48adb9d78..4a6112248 100644
--- a/app/com/linkedin/drelephant/spark/fetchers/SparkRestClient.scala
+++ b/app/com/linkedin/drelephant/spark/fetchers/SparkRestClient.scala
@@ -68,22 +68,17 @@ class SparkRestClient(sparkConf: SparkConf) {
 
     // Limit scope of async.
     async {
-      val lastAttemptId = applicationInfo.attempts.maxBy { _.startTime }.attemptId
-      lastAttemptId match {
-        case Some(attemptId) => {
-          val attemptTarget = appTarget.path(attemptId)
-          val futureJobDatas = async { getJobDatas(attemptTarget) }
-          val futureStageDatas = async { getStageDatas(attemptTarget) }
-          val futureExecutorSummaries = async { getExecutorSummaries(attemptTarget) }
-          SparkRestDerivedData(
-            applicationInfo,
-            await(futureJobDatas),
-            await(futureStageDatas),
-            await(futureExecutorSummaries)
-          )
-        }
-        case None => throw new IllegalArgumentException("Spark REST API has no attempt information")
-      }
+      val lastAttemptId = applicationInfo.attempts.maxBy {_.startTime}.attemptId
+      val attemptTarget = lastAttemptId.map(appTarget.path).getOrElse(appTarget)
+      val futureJobDatas = async { getJobDatas(attemptTarget) }
+      val futureStageDatas = async { getStageDatas(attemptTarget) }
+      val futureExecutorSummaries = async { getExecutorSummaries(attemptTarget) }
+      SparkRestDerivedData(
+        applicationInfo,
+        await(futureJobDatas),
+        await(futureStageDatas),
+        await(futureExecutorSummaries)
+      )
     }
   }
 
diff --git a/test/com/linkedin/drelephant/spark/fetchers/SparkRestClientTest.scala b/test/com/linkedin/drelephant/spark/fetchers/SparkRestClientTest.scala
index 7f325739d..a346f9225 100644
--- a/test/com/linkedin/drelephant/spark/fetchers/SparkRestClientTest.scala
+++ b/test/com/linkedin/drelephant/spark/fetchers/SparkRestClientTest.scala
@@ -45,17 +45,17 @@ class SparkRestClientTest extends AsyncFunSpec with Matchers {
       an[IllegalArgumentException] should be thrownBy(new SparkRestClient(new SparkConf()))
     }
 
-    it("returns the desired data from the Spark REST API") {
+    it("returns the desired data from the Spark REST API for cluster mode application") {
       import ExecutionContext.Implicits.global
       val fakeJerseyServer = new FakeJerseyServer() {
         override def configure(): Application = super.configure() match {
           case resourceConfig: ResourceConfig =>
             resourceConfig
-              .register(classOf[FetchDataFixtures.ApiResource])
-              .register(classOf[FetchDataFixtures.ApplicationResource])
-              .register(classOf[FetchDataFixtures.JobsResource])
-              .register(classOf[FetchDataFixtures.StagesResource])
-              .register(classOf[FetchDataFixtures.ExecutorsResource])
+              .register(classOf[FetchClusterModeDataFixtures.ApiResource])
+              .register(classOf[FetchClusterModeDataFixtures.ApplicationResource])
+              .register(classOf[FetchClusterModeDataFixtures.JobsResource])
+              .register(classOf[FetchClusterModeDataFixtures.StagesResource])
+              .register(classOf[FetchClusterModeDataFixtures.ExecutorsResource])
           case config => config
         }
       }
@@ -67,9 +67,9 @@ class SparkRestClientTest extends AsyncFunSpec with Matchers {
       val sparkConf = new SparkConf().set("spark.yarn.historyServer.address", s"${historyServerUri.getHost}:${historyServerUri.getPort}")
       val sparkRestClient = new SparkRestClient(sparkConf)
 
-      sparkRestClient.fetchData(FetchDataFixtures.APP_ID) map { restDerivedData =>
-        restDerivedData.applicationInfo.id should be(FetchDataFixtures.APP_ID)
-        restDerivedData.applicationInfo.name should be(FetchDataFixtures.APP_NAME)
+      sparkRestClient.fetchData(FetchClusterModeDataFixtures.APP_ID) map { restDerivedData =>
+        restDerivedData.applicationInfo.id should be(FetchClusterModeDataFixtures.APP_ID)
+        restDerivedData.applicationInfo.name should be(FetchClusterModeDataFixtures.APP_NAME)
         restDerivedData.jobDatas should not be(None)
         restDerivedData.stageDatas should not be(None)
         restDerivedData.executorSummaries should not be(None)
@@ -78,6 +78,40 @@ class SparkRestClientTest extends AsyncFunSpec with Matchers {
         assertion
       }
     }
+
+    it("returns the desired data from the Spark REST API for client mode application") {
+      import ExecutionContext.Implicits.global
+      val fakeJerseyServer = new FakeJerseyServer() {
+        override def configure(): Application = super.configure() match {
+          case resourceConfig: ResourceConfig =>
+            resourceConfig
+              .register(classOf[FetchClientModeDataFixtures.ApiResource])
+              .register(classOf[FetchClientModeDataFixtures.ApplicationResource])
+              .register(classOf[FetchClientModeDataFixtures.JobsResource])
+              .register(classOf[FetchClientModeDataFixtures.StagesResource])
+              .register(classOf[FetchClientModeDataFixtures.ExecutorsResource])
+          case config => config
+        }
+      }
+
+      fakeJerseyServer.setUp()
+
+      val historyServerUri = fakeJerseyServer.target.getUri
+
+      val sparkConf = new SparkConf().set("spark.yarn.historyServer.address", s"${historyServerUri.getHost}:${historyServerUri.getPort}")
+      val sparkRestClient = new SparkRestClient(sparkConf)
+
+      sparkRestClient.fetchData(FetchClusterModeDataFixtures.APP_ID) map { restDerivedData =>
+        restDerivedData.applicationInfo.id should be(FetchClusterModeDataFixtures.APP_ID)
+        restDerivedData.applicationInfo.name should be(FetchClusterModeDataFixtures.APP_NAME)
+        restDerivedData.jobDatas should not be(None)
+        restDerivedData.stageDatas should not be(None)
+        restDerivedData.executorSummaries should not be(None)
+      } andThen { case assertion: Try[Assertion] =>
+        fakeJerseyServer.tearDown()
+        assertion
+      }
+    }
   }
 }
 
@@ -115,7 +149,7 @@ object SparkRestClientTest {
     override def getContext(cls: Class[_]): ObjectMapper = objectMapper
   }
 
-  object FetchDataFixtures {
+  object FetchClusterModeDataFixtures {
     val APP_ID = "application_1"
     val APP_NAME = "app"
 
@@ -174,6 +208,65 @@ object SparkRestClientTest {
     }
   }
 
+  object FetchClientModeDataFixtures {
+    val APP_ID = "application_1"
+    val APP_NAME = "app"
+
+    @Path("/api/v1")
+    class ApiResource {
+      @Path("applications/{appId}")
+      def getApplication(): ApplicationResource = new ApplicationResource()
+
+      @Path("applications/{appId}/jobs")
+      def getJobs(): JobsResource = new JobsResource()
+
+      @Path("applications/{appId}/stages")
+      def getStages(): StagesResource = new StagesResource()
+
+      @Path("applications/{appId}/executors")
+      def getExecutors(): ExecutorsResource = new ExecutorsResource()
+    }
+
+    @Produces(Array(MediaType.APPLICATION_JSON))
+    class ApplicationResource {
+      @GET
+      def getApplication(@PathParam("appId") appId: String): ApplicationInfo = {
+        val t2 = System.currentTimeMillis
+        val t1 = t2 - 1
+        val duration = 8000000L
+        new ApplicationInfo(
+          APP_ID,
+          APP_NAME,
+          Seq(
+            newFakeApplicationAttemptInfo(None, startTime = new Date(t2 - duration), endTime = new Date(t2)),
+            newFakeApplicationAttemptInfo(None, startTime = new Date(t1 - duration), endTime = new Date(t1))
+          )
+        )
+      }
+    }
+
+    @Produces(Array(MediaType.APPLICATION_JSON))
+    class JobsResource {
+      @GET
+      def getJobs(@PathParam("appId") appId: String): Seq[JobData] =
+        Seq.empty
+    }
+
+    @Produces(Array(MediaType.APPLICATION_JSON))
+    class StagesResource {
+      @GET
+      def getStages(@PathParam("appId") appId: String): Seq[StageData] =
+        Seq.empty
+    }
+
+    @Produces(Array(MediaType.APPLICATION_JSON))
+    class ExecutorsResource {
+      @GET
+      def getExecutors(@PathParam("appId") appId: String): Seq[ExecutorSummary] =
+        Seq.empty
+    }
+  }
+
   def newFakeApplicationAttemptInfo(
     attemptId: Option[String],
     startTime: Date,
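For context: in cluster mode the Spark history server reports application attempts with an attemptId, so the REST client addresses applications/{appId}/{attemptId}; in client mode attemptId is None, and the patch above falls back to the application-level target instead of throwing IllegalArgumentException. Below is a minimal, self-contained Scala sketch of that Option handling; RestTarget, Attempt, and latestTarget are hypothetical stand-ins for the Jersey WebTarget and REST API types, not code from this patch.

// Hypothetical stand-in for javax.ws.rs.client.WebTarget: path() appends one URI segment.
case class RestTarget(uri: String) {
  def path(segment: String): RestTarget = RestTarget(s"$uri/$segment")
}

// Hypothetical stand-in for ApplicationAttemptInfo: client mode yields attemptId = None.
case class Attempt(attemptId: Option[String], startTime: Long)

object TargetSelectionSketch {
  // Pick the most recent attempt; if it carries an ID (cluster mode), address the
  // attempt-specific endpoint, otherwise (client mode) keep the application target.
  def latestTarget(appTarget: RestTarget, attempts: Seq[Attempt]): RestTarget = {
    val lastAttemptId = attempts.maxBy(_.startTime).attemptId
    lastAttemptId.map(appTarget.path).getOrElse(appTarget)
  }

  def main(args: Array[String]): Unit = {
    val app = RestTarget("/api/v1/applications/application_1")
    // Cluster mode: prints /api/v1/applications/application_1/2
    println(latestTarget(app, Seq(Attempt(Some("1"), 1L), Attempt(Some("2"), 2L))).uri)
    // Client mode: prints /api/v1/applications/application_1
    println(latestTarget(app, Seq(Attempt(None, 1L), Attempt(None, 2L))).uri)
  }
}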