Skip to content

Commit

Permalink
Fix abort task/yide (#386)
Browse files Browse the repository at this point in the history
* fix abort task

* add test case

* fix bug

add log
  • Loading branch information
shikimoe authored Jan 20, 2021
1 parent 4afdfb2 commit 1ead53e
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import javax.inject.Singleton;
import java.time.OffsetDateTime;
import java.time.temporal.ChronoUnit;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
Expand Down Expand Up @@ -198,16 +199,35 @@ private void notifyFinished(Long attemptId, TaskRunStatus status, OperatorReport
}

private boolean killTaskAttempt(Long attemptId) {
logger.debug("going to abort taskAttempt,attemptId = {}", attemptId);
if (!workerPool.containsKey(attemptId)) {
return false;
Optional<TaskAttempt> taskAttemptOptional = taskRunDao.fetchAttemptById(attemptId);
if (!taskAttemptOptional.isPresent()) {
logger.warn("taskAttempt not found,attemptId = {}", attemptId);
return false;
}
TaskAttempt taskAttempt = taskAttemptOptional.get();
if (taskAttempt.getStatus() == TaskRunStatus.QUEUED) {
logger.debug("taskAttempt to be abort has add to queue,attemptId = {}", attemptId);
Iterator<TaskAttempt> iterator = taskAttemptQueue.iterator();
while (iterator.hasNext()){
TaskAttempt queuedTaskAttempt = iterator.next();
if (queuedTaskAttempt.getId().equals(attemptId)) {
iterator.remove();
logger.debug("remove taskAttempt from queue,attemptId = {}", attemptId);
break;
}
}
}
miscService.changeTaskAttemptStatus(attemptId, TaskRunStatus.ABORTED);
} else {
HeartBeatMessage message = workerPool.get(attemptId);
Worker worker = workerFactory.getWorker(message);
worker.killTask(true);
Thread thread = new Thread(new WaitAbort(attemptId));
thread.start();
return true;
}
return true;
}

private ExecCommand buildExecCommand(TaskAttempt attempt) {
Expand Down Expand Up @@ -335,15 +355,22 @@ public void run() {
logger.error("taskAttemptId = {} acquire worker token failed", taskAttempt.getId());
throw ExceptionUtils.wrapIfChecked(e);
}
TaskAttempt taskAttemptToRun = taskRunDao.fetchAttemptById(taskAttempt.getId()).get();
if(taskAttemptToRun.getStatus().isFinished()){
logger.info("taskAttemptToRun is finished,attemptId = {},status = {}",taskAttemptToRun.getId(),taskAttemptToRun.getStatus().name());
workerToken.release();
logger.debug("taskAttemptId = {} release worker token, current size = {}", taskAttemptToRun.getId(), workerToken.availablePermits());
return;
}
try {
workerPool.put(taskAttempt.getId(), initHeartBeatByTaskAttempt(taskAttempt));
workerPool.put(taskAttemptToRun.getId(), initHeartBeatByTaskAttempt(taskAttemptToRun));
//taskAttempt 已经启动(重启恢复),则只加入workerPool监听心跳,正常入队和超时则重新启动
if (taskAttempt.getStatus().equals(TaskRunStatus.QUEUED) || taskAttempt.getStatus().equals(TaskRunStatus.ERROR)) {
ExecCommand command = buildExecCommand(taskAttempt);
if (taskAttemptToRun.getStatus().equals(TaskRunStatus.QUEUED) || taskAttemptToRun.getStatus().equals(TaskRunStatus.ERROR)) {
ExecCommand command = buildExecCommand(taskAttemptToRun);
startWorker(command);
}
} catch (Exception e) {
logger.error("taskAttemptId = {} could start worker ", taskAttempt.getId(), e);
logger.error("taskAttemptId = {} could start worker ", taskAttemptToRun.getId(), e);
workerToken.release();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
import java.io.IOException;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.Semaphore;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
Expand Down Expand Up @@ -898,6 +899,48 @@ public void testStop_attempt_abort_throws_exception() throws IOException {
assertThat(workerToken.availablePermits(), is(8));
}

@Test
public void abortTaskAttemptInQueue(){
Reflect.on(executor).set("workerToken",new Semaphore(0));
//prepare
TaskAttempt taskAttempt = prepareAttempt(TestOperator1.class);

executor.submit(taskAttempt);

//verify
TaskAttempt saved = taskRunDao.fetchAttemptById(taskAttempt.getId()).get();
assertThat(saved.getStatus(),is(TaskRunStatus.QUEUED));
LinkedBlockingQueue<TaskAttempt> taskAttemptQueue = Reflect.on(executor).field("taskAttemptQueue").get();
assertThat(taskAttemptQueue,hasSize(1));
executor.cancel(taskAttempt.getId());
awaitUntilAttemptAbort(taskAttempt.getId());
// events
assertStatusProgress(taskAttempt.getId(),
TaskRunStatus.CREATED,
TaskRunStatus.QUEUED,
TaskRunStatus.ABORTED);
taskAttemptQueue = Reflect.on(executor).field("taskAttemptQueue").get();
assertThat(taskAttemptQueue,hasSize(0));


}

@Test
public void abortTaskAttemptCreated(){
//prepare
TaskAttempt taskAttempt = prepareAttempt(TestOperator1.class);
//verify
TaskAttempt saved = taskRunDao.fetchAttemptById(taskAttempt.getId()).get();
assertThat(saved.getStatus(),is(TaskRunStatus.CREATED));
executor.cancel(taskAttempt.getId());
awaitUntilAttemptAbort(taskAttempt.getId());
// events
assertStatusProgress(taskAttempt.getId(),
TaskRunStatus.CREATED,
TaskRunStatus.ABORTED);

}

private TaskAttempt prepareAttempt(Class<? extends KunOperator> operatorClass) {
return prepareAttempt(operatorClass, operatorClass.getSimpleName());
}
Expand Down

0 comments on commit 1ead53e

Please sign in to comment.