diff --git a/scheduler.go b/scheduler.go index a5ae6b7e..ad2f8db7 100644 --- a/scheduler.go +++ b/scheduler.go @@ -535,6 +535,57 @@ func (s *scheduler) NewJob(jobDefinition JobDefinition, task Task, options ...Jo return s.addOrUpdateJob(uuid.Nil, jobDefinition, task, options) } +func (s *scheduler) verifyVariadic(taskFunc reflect.Value, tsk task, variadicStart int) error { + if err := s.verifyNonVariadic(taskFunc, tsk, variadicStart); err != nil { + return err + } + parameterType := taskFunc.Type().In(variadicStart).Elem().Kind() + if parameterType == reflect.Interface || parameterType == reflect.Pointer { + parameterType = reflect.Indirect(reflect.ValueOf(taskFunc.Type().In(variadicStart))).Kind() + } + + for i := variadicStart; i < len(tsk.parameters); i++ { + argumentType := reflect.TypeOf(tsk.parameters[i]).Kind() + if argumentType == reflect.Interface || argumentType == reflect.Pointer { + argumentType = reflect.TypeOf(tsk.parameters[i]).Elem().Kind() + } + if argumentType != parameterType { + return ErrNewJobWrongTypeOfParameters + } + } + return nil +} + +func (s *scheduler) verifyNonVariadic(taskFunc reflect.Value, tsk task, length int) error { + for i := 0; i < length; i++ { + t1 := reflect.TypeOf(tsk.parameters[i]).Kind() + if t1 == reflect.Interface || t1 == reflect.Pointer { + t1 = reflect.TypeOf(tsk.parameters[i]).Elem().Kind() + } + t2 := reflect.New(taskFunc.Type().In(i)).Elem().Kind() + if t2 == reflect.Interface || t2 == reflect.Pointer { + t2 = reflect.Indirect(reflect.ValueOf(taskFunc.Type().In(i))).Kind() + } + if t1 != t2 { + return ErrNewJobWrongTypeOfParameters + } + } + return nil +} + +func (s *scheduler) verifyParameterType(taskFunc reflect.Value, tsk task) error { + isVariadic := taskFunc.Type().IsVariadic() + if isVariadic { + variadicStart := taskFunc.Type().NumIn() - 1 + return s.verifyVariadic(taskFunc, tsk, variadicStart) + } + expectedParameterLength := taskFunc.Type().NumIn() + if len(tsk.parameters) != expectedParameterLength { + return ErrNewJobWrongNumberOfParameters + } + return s.verifyNonVariadic(taskFunc, tsk, expectedParameterLength) +} + func (s *scheduler) addOrUpdateJob(id uuid.UUID, definition JobDefinition, taskWrapper Task, options []JobOption) (Job, error) { j := internalJob{} if id == uuid.Nil { @@ -569,23 +620,8 @@ func (s *scheduler) addOrUpdateJob(id uuid.UUID, definition JobDefinition, taskW return nil, ErrNewJobTaskNotFunc } - expectedParameterLength := taskFunc.Type().NumIn() - if len(tsk.parameters) != expectedParameterLength { - return nil, ErrNewJobWrongNumberOfParameters - } - - for i := 0; i < expectedParameterLength; i++ { - t1 := reflect.TypeOf(tsk.parameters[i]).Kind() - if t1 == reflect.Interface || t1 == reflect.Pointer { - t1 = reflect.TypeOf(tsk.parameters[i]).Elem().Kind() - } - t2 := reflect.New(taskFunc.Type().In(i)).Elem().Kind() - if t2 == reflect.Interface || t2 == reflect.Pointer { - t2 = reflect.Indirect(reflect.ValueOf(taskFunc.Type().In(i))).Kind() - } - if t1 != t2 { - return nil, ErrNewJobWrongTypeOfParameters - } + if err := s.verifyParameterType(taskFunc, tsk); err != nil { + return nil, err } j.name = runtime.FuncForPC(taskFunc.Pointer()).Name() diff --git a/scheduler_test.go b/scheduler_test.go index ceacc55c..52fb8e2f 100644 --- a/scheduler_test.go +++ b/scheduler_test.go @@ -899,6 +899,31 @@ func TestScheduler_NewJobTask(t *testing.T) { NewTask(&testFuncWithParams, "one", "two"), nil, }, + { + "parameter type does not match - different argument types against variadic parameters", + NewTask(func(args ...string) {}, "one", 2), + ErrNewJobWrongTypeOfParameters, + }, + { + "all good string - variadic", + NewTask(func(args ...string) {}, "one", "two"), + nil, + }, + { + "all good mixed variadic", + NewTask(func(arg int, args ...string) {}, 1, "one", "two"), + nil, + }, + { + "all good struct - variadic", + NewTask(func(args ...interface{}) {}, struct{}{}), + nil, + }, + { + "all good no arguments passed in - variadic", + NewTask(func(args ...interface{}) {}), + nil, + }, } for _, tt := range tests {