diff --git a/pkg/globalmanager/controllers/lifelonglearning/lifelonglearningjob.go b/pkg/globalmanager/controllers/lifelonglearning/lifelonglearningjob.go index 174a3c768..f17b1a04e 100644 --- a/pkg/globalmanager/controllers/lifelonglearning/lifelonglearningjob.go +++ b/pkg/globalmanager/controllers/lifelonglearning/lifelonglearningjob.go @@ -20,6 +20,7 @@ import ( "context" "encoding/json" "fmt" + "strconv" "strings" "time" @@ -542,6 +543,37 @@ func IsJobFinished(j *sednav1.LifelongLearningJob) bool { return false } +// hasCompletedInitialTraining reports whether jobConditions contains at least one train-stage condition whose type is completed. +func (c *Controller) hasCompletedInitialTraining(jobConditions []sednav1.LLJobCondition) bool { + for i := 0; i < len(jobConditions); i++ { + jobCond := jobConditions[i] + if jobCond.Stage == sednav1.LLJobTrain && jobCond.Type == sednav1.LLJobStageCondCompleted { + return true + } + } + return false +} + +func (c *Controller) getCloudKBIndex(jobConditions []sednav1.LLJobCondition) string { + for i := len(jobConditions) - 1; i >= 0; i-- { + jobCond := jobConditions[i] + var cond ConditionData + if jobCond.Stage == sednav1.LLJobTrain && jobCond.Type == sednav1.LLJobStageCondCompleted { + if err := (&cond).Unmarshal([]byte(jobCond.Data)); err != nil { + continue + } + + if cond.Output == nil || len(cond.Output.Models) == 0 { + continue + } + + model := cond.Output.Models[0] + return model.GetURL() + } + } + return "" +} + func (c *Controller) createPod(job *sednav1.LifelongLearningJob, podtype sednav1.LLJobStage) (err error) { ctx := context.Background() var podTemplate *v1.PodTemplateSpec @@ -571,8 +603,10 @@ func (c *Controller) createPod(job *sednav1.LifelongLearningJob, podtype sednav1 return err } + jobConditions := job.Status.Conditions + // get all url for train and eval from data in condition - condDataStr := job.Status.Conditions[len(job.Status.Conditions)-1].Data + condDataStr := jobConditions[len(job.Status.Conditions)-1].Data klog.V(2).Infof("lifelonglearning 
job %v/%v data condition:%s", job.Namespace, job.Name, condDataStr) var cond ConditionData (&cond).Unmarshal([]byte(condDataStr)) @@ -598,13 +632,19 @@ func (c *Controller) createPod(job *sednav1.LifelongLearningJob, podtype sednav1 podTemplate = &job.Spec.TrainSpec.Template // Env parameters for train + hasCompletedInitialTraining := c.hasCompletedInitialTraining(jobConditions) + workerParam.Env = map[string]string{ - "NAMESPACE": job.Namespace, - "JOB_NAME": job.Name, - "WORKER_NAME": "train-worker-" + utilrand.String(5), + "NAMESPACE": job.Namespace, + "JOB_NAME": job.Name, + "WORKER_NAME": "train-worker-" + utilrand.String(5), + "HAS_COMPLETED_INITIAL_TRAINING": strconv.FormatBool(hasCompletedInitialTraining), + "LC_SERVER": c.cfg.LC.Server, + "KB_SERVER": c.cfg.KB.Server, + } - "LC_SERVER": c.cfg.LC.Server, - "KB_SERVER": c.cfg.KB.Server, + if hasCompletedInitialTraining { + workerParam.Env["CLOUD_KB_INDEX"] = c.getCloudKBIndex(jobConditions) } workerParam.Mounts = append(workerParam.Mounts,