From 93a85b32085ad70584400a6ac3652f61bee4a28c Mon Sep 17 00:00:00 2001 From: taoyifan89 Date: Tue, 28 Oct 2025 19:28:52 +0800 Subject: [PATCH 01/19] refactor: [Coda] use enums for observability task models (LogID: 202510281859330100911150896701F6B) Co-Authored-By: Coda --- .../application/convertor/task/task.go | 16 +++--- .../observability/domain/task/entity/task.go | 49 +++++++++++++++---- .../domain/task/service/task_service.go | 14 +++--- .../domain/task/service/task_service_test.go | 44 ++++++++--------- .../taskexe/processor/auto_evaluate.go | 24 ++++----- .../taskexe/processor/auto_evaluate_test.go | 24 ++++----- .../task/service/taskexe/tracehub/backfill.go | 2 +- .../service/taskexe/tracehub/backfill_test.go | 14 +++--- .../taskexe/tracehub/scheduled_task.go | 14 +++--- .../taskexe/tracehub/span_trigger_test.go | 12 ++--- .../infra/repo/mysql/convertor/task.go | 16 +++--- 11 files changed, 130 insertions(+), 99 deletions(-) diff --git a/backend/modules/observability/application/convertor/task/task.go b/backend/modules/observability/application/convertor/task/task.go index 99b2c62ee..9a5ad92ea 100644 --- a/backend/modules/observability/application/convertor/task/task.go +++ b/backend/modules/observability/application/convertor/task/task.go @@ -61,8 +61,8 @@ func TaskDO2DTO(ctx context.Context, v *entity.ObservabilityTask, userMap map[st Name: v.Name, Description: v.Description, WorkspaceID: ptr.Of(v.WorkspaceID), - TaskType: v.TaskType, - TaskStatus: ptr.Of(v.TaskStatus), + TaskType: task.TaskType(v.TaskType), + TaskStatus: ptr.Of(task.TaskStatus(v.TaskStatus)), Rule: RuleDO2DTO(v.SpanFilter, v.EffectiveTime, v.Sampler, v.BackfillEffectiveTime), TaskConfig: TaskConfigDO2DTO(v.TaskConfig), TaskDetail: taskDetail, @@ -84,8 +84,8 @@ func TaskRunDO2DTO(ctx context.Context, v *entity.TaskRun, userMap map[string]*e ID: v.ID, WorkspaceID: v.WorkspaceID, TaskID: v.TaskID, - TaskType: v.TaskType, - RunStatus: v.RunStatus, + TaskType: task.TaskRunType(v.TaskType), + RunStatus: task.RunStatus(v.RunStatus), RunDetail: RunDetailDO2DTO(v.RunDetail), BackfillRunDetail: BackfillRunDetailDO2DTO(v.BackfillDetail), RunStartAt: v.RunStartAt.UnixMilli(), @@ -339,8 +339,8 @@ func TaskDTO2DO(taskDTO *task.Task, userID string, spanFilters *entity.SpanFilte WorkspaceID: taskDTO.GetWorkspaceID(), Name: taskDTO.GetName(), Description: ptr.Of(taskDTO.GetDescription()), - TaskType: taskDTO.GetTaskType(), - TaskStatus: taskDTO.GetTaskStatus(), + TaskType: entity.TaskType(taskDTO.GetTaskType()), + TaskStatus: entity.TaskStatus(taskDTO.GetTaskStatus()), TaskDetail: RunDetailDTO2DO(taskDTO.GetTaskDetail()), SpanFilter: spanFilterDO, EffectiveTime: EffectiveTimeDTO2DO(taskDTO.GetRule().GetEffectiveTime()), @@ -471,8 +471,8 @@ func TaskRunDTO2DO(taskRun *task.TaskRun) *entity.TaskRun { ID: taskRun.ID, TaskID: taskRun.TaskID, WorkspaceID: taskRun.WorkspaceID, - TaskType: taskRun.TaskType, - RunStatus: taskRun.RunStatus, + TaskType: entity.TaskRunType(taskRun.TaskType), + RunStatus: entity.TaskRunStatus(taskRun.RunStatus), RunDetail: RunDetailDTO2DO(taskRun.RunDetail), BackfillDetail: BackfillRunDetailDTO2DO(taskRun.BackfillRunDetail), RunStartAt: time.UnixMilli(taskRun.RunStartAt), diff --git a/backend/modules/observability/domain/task/entity/task.go b/backend/modules/observability/domain/task/entity/task.go index e48ea3fe8..867f5e273 100644 --- a/backend/modules/observability/domain/task/entity/task.go +++ b/backend/modules/observability/domain/task/entity/task.go @@ -8,18 +8,49 @@ import ( "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/observability/domain/common" "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/observability/domain/dataset" - "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/observability/domain/task" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/entity/loop_span" ) +type TaskStatus string + +const ( + TaskStatusUnstarted TaskStatus = "unstarted" + TaskStatusRunning TaskStatus = "running" + TaskStatusFailed TaskStatus = "failed" + TaskStatusSuccess TaskStatus = "success" + TaskStatusPending TaskStatus = "pending" + TaskStatusDisabled TaskStatus = "disabled" +) + +type TaskType string + +const ( + TaskTypeAutoEval TaskType = "auto_evaluate" + TaskTypeAutoDataReflow TaskType = "auto_data_reflow" +) + +type TaskRunType string + +const ( + TaskRunTypeBackFill TaskRunType = "back_fill" + TaskRunTypeNewData TaskRunType = "new_data" +) + +type TaskRunStatus string + +const ( + TaskRunStatusRunning TaskRunStatus = "running" + TaskRunStatusDone TaskRunStatus = "done" +) + // do type ObservabilityTask struct { ID int64 // Task ID WorkspaceID int64 // 空间ID Name string // 任务名称 Description *string // 任务描述 - TaskType string // 任务类型 - TaskStatus string // 任务状态 + TaskType TaskType // 任务类型 + TaskStatus TaskStatus // 任务状态 TaskDetail *RunDetail // 任务运行详情 SpanFilter *SpanFilterFields // span 过滤条件 EffectiveTime *EffectiveTime // 生效时间 @@ -85,8 +116,8 @@ type TaskRun struct { ID int64 // Task Run ID TaskID int64 // Task ID WorkspaceID int64 // 空间ID - TaskType string // 任务类型 - RunStatus string // Task Run状态 + TaskType TaskRunType // 任务类型 + RunStatus TaskRunStatus // Task Run状态 RunDetail *RunDetail // Task Run运行详情 BackfillDetail *BackfillDetail // 历史回溯运行详情 RunStartAt time.Time // run 开始时间 @@ -128,7 +159,7 @@ type DataReflowRunConfig struct { func (t ObservabilityTask) IsFinished() bool { switch t.TaskStatus { - case task.TaskStatusSuccess, task.TaskStatusDisabled, task.TaskStatusPending: + case TaskStatusSuccess, TaskStatusDisabled, TaskStatusPending: return true default: return false @@ -137,7 +168,7 @@ func (t ObservabilityTask) IsFinished() bool { func (t ObservabilityTask) GetBackfillTaskRun() *TaskRun { for _, taskRunPO := range t.TaskRuns { - if taskRunPO.TaskType == task.TaskRunTypeBackFill { + if taskRunPO.TaskType == TaskRunTypeBackFill { return taskRunPO } } @@ -146,14 +177,14 @@ func (t ObservabilityTask) GetBackfillTaskRun() *TaskRun { func (t ObservabilityTask) GetCurrentTaskRun() *TaskRun { for _, taskRunPO := range t.TaskRuns { - if taskRunPO.TaskType == task.TaskRunTypeNewData && taskRunPO.RunStatus == task.TaskStatusRunning { + if taskRunPO.TaskType == TaskRunTypeNewData && taskRunPO.RunStatus == TaskRunStatusRunning { return taskRunPO } } return nil } -func (t ObservabilityTask) GetTaskttl() int64 { +func (t ObservabilityTask) GetTaskTTL() int64 { var ttl int64 if t.EffectiveTime != nil { ttl = t.EffectiveTime.EndAt - t.EffectiveTime.StartAt diff --git a/backend/modules/observability/domain/task/service/task_service.go b/backend/modules/observability/domain/task/service/task_service.go index 623c43c6f..e7b8d9d2f 100644 --- a/backend/modules/observability/domain/task/service/task_service.go +++ b/backend/modules/observability/domain/task/service/task_service.go @@ -117,7 +117,7 @@ func (t *TaskServiceImpl) CreateTask(ctx context.Context, req *CreateTaskReq) (r logs.CtxError(ctx, "task name exist") return nil, errorx.NewByCode(obErrorx.CommonInvalidParamCode, errorx.WithExtraMsg("task name exist")) } - proc := t.taskProcessor.GetTaskProcessor(req.Task.TaskType) + proc := t.taskProcessor.GetTaskProcessor(task.TaskType(req.Task.TaskType)) // 校验配置项是否有效 if err = proc.ValidateConfig(ctx, req.Task); err != nil { logs.CtxError(ctx, "ValidateConfig err:%v", err) @@ -176,7 +176,7 @@ func (t *TaskServiceImpl) UpdateTask(ctx context.Context, req *UpdateTaskReq) (e taskDO.Description = req.Description } if req.EffectiveTime != nil { - validEffectiveTime, err := tconv.CheckEffectiveTime(ctx, req.EffectiveTime, taskDO.TaskStatus, taskDO.EffectiveTime) + validEffectiveTime, err := tconv.CheckEffectiveTime(ctx, req.EffectiveTime, task.TaskStatus(taskDO.TaskStatus), taskDO.EffectiveTime) if err != nil { return err } @@ -186,17 +186,17 @@ func (t *TaskServiceImpl) UpdateTask(ctx context.Context, req *UpdateTaskReq) (e taskDO.Sampler.SampleRate = *req.SampleRate } if req.TaskStatus != nil { - validTaskStatus, err := tconv.CheckTaskStatus(ctx, *req.TaskStatus, taskDO.TaskStatus) + validTaskStatus, err := tconv.CheckTaskStatus(ctx, *req.TaskStatus, task.TaskStatus(taskDO.TaskStatus)) if err != nil { return err } if validTaskStatus != "" { if validTaskStatus == task.TaskStatusDisabled { // 禁用操作处理 - proc := t.taskProcessor.GetTaskProcessor(taskDO.TaskType) + proc := t.taskProcessor.GetTaskProcessor(task.TaskType(taskDO.TaskType)) var taskRun *entity.TaskRun for _, tr := range taskDO.TaskRuns { - if tr.RunStatus == task.RunStatusRunning { + if tr.RunStatus == entity.TaskRunStatusRunning { taskRun = tr break } @@ -213,7 +213,7 @@ func (t *TaskServiceImpl) UpdateTask(ctx context.Context, req *UpdateTaskReq) (e logs.CtxError(ctx, "remove non final task failed, task_id=%d, err=%v", taskDO.ID, err) } } - taskDO.TaskStatus = *req.TaskStatus + taskDO.TaskStatus = entity.TaskStatus(validTaskStatus) } } taskDO.UpdatedBy = userID @@ -362,7 +362,7 @@ func (t *TaskServiceImpl) CheckTaskName(ctx context.Context, req *CheckTaskNameR func (t *TaskServiceImpl) shouldTriggerBackfill(taskDO *entity.ObservabilityTask) bool { // 检查任务类型 taskType := taskDO.TaskType - if taskType != task.TaskTypeAutoEval && taskType != task.TaskTypeAutoDataReflow { + if taskType != entity.TaskTypeAutoEval && taskType != entity.TaskTypeAutoDataReflow { return false } diff --git a/backend/modules/observability/domain/task/service/task_service_test.go b/backend/modules/observability/domain/task/service/task_service_test.go index a65312dfe..e94619701 100755 --- a/backend/modules/observability/domain/task/service/task_service_test.go +++ b/backend/modules/observability/domain/task/service/task_service_test.go @@ -113,8 +113,8 @@ func TestTaskServiceImpl_CreateTask(t *testing.T) { reqTask := &entity.ObservabilityTask{ WorkspaceID: 123, Name: "task", - TaskType: task.TaskTypeAutoEval, - TaskStatus: task.TaskStatusUnstarted, + TaskType: entity.TaskTypeAutoEval, + TaskStatus: entity.TaskStatusUnstarted, BackfillEffectiveTime: &entity.EffectiveTime{StartAt: time.Now().Add(time.Second).UnixMilli(), EndAt: time.Now().Add(2 * time.Second).UnixMilli()}, Sampler: &entity.Sampler{}, EffectiveTime: &entity.EffectiveTime{StartAt: time.Now().Add(time.Second).UnixMilli(), EndAt: time.Now().Add(2 * time.Second).UnixMilli()}, @@ -146,7 +146,7 @@ func TestTaskServiceImpl_CreateTask(t *testing.T) { proc := &fakeProcessor{validateErr: errors.New("invalid config")} svc := newTaskServiceWithProcessor(t, repoMock, nil, nil, proc, task.TaskTypeAutoEval) - reqTask := &entity.ObservabilityTask{WorkspaceID: 1, Name: "task", TaskType: task.TaskTypeAutoEval, Sampler: &entity.Sampler{}, EffectiveTime: &entity.EffectiveTime{}} + reqTask := &entity.ObservabilityTask{WorkspaceID: 1, Name: "task", TaskType: entity.TaskTypeAutoEval, Sampler: &entity.Sampler{}, EffectiveTime: &entity.EffectiveTime{}} resp, err := svc.CreateTask(context.Background(), &CreateTaskReq{Task: reqTask}) assert.Nil(t, resp) assert.Error(t, err) @@ -166,7 +166,7 @@ func TestTaskServiceImpl_CreateTask(t *testing.T) { proc := &fakeProcessor{} svc := newTaskServiceWithProcessor(t, repoMock, nil, nil, proc, task.TaskTypeAutoEval) - reqTask := &entity.ObservabilityTask{WorkspaceID: 1, Name: "task", TaskType: task.TaskTypeAutoEval, Sampler: &entity.Sampler{}, EffectiveTime: &entity.EffectiveTime{}} + reqTask := &entity.ObservabilityTask{WorkspaceID: 1, Name: "task", TaskType: entity.TaskTypeAutoEval, Sampler: &entity.Sampler{}, EffectiveTime: &entity.EffectiveTime{}} resp, err := svc.CreateTask(context.Background(), &CreateTaskReq{Task: reqTask}) assert.Nil(t, resp) assert.Error(t, err) @@ -189,7 +189,7 @@ func TestTaskServiceImpl_CreateTask(t *testing.T) { proc := &fakeProcessor{onCreateErr: errors.New("hook fail")} svc := newTaskServiceWithProcessor(t, repoMock, nil, nil, proc, task.TaskTypeAutoEval) - reqTask := &entity.ObservabilityTask{WorkspaceID: 1, Name: "task", TaskType: task.TaskTypeAutoEval, Sampler: &entity.Sampler{}, EffectiveTime: &entity.EffectiveTime{}} + reqTask := &entity.ObservabilityTask{WorkspaceID: 1, Name: "task", TaskType: entity.TaskTypeAutoEval, Sampler: &entity.Sampler{}, EffectiveTime: &entity.EffectiveTime{}} resp, err := svc.CreateTask(context.Background(), &CreateTaskReq{Task: reqTask}) assert.Nil(t, resp) assert.EqualError(t, err, "hook fail") @@ -235,7 +235,7 @@ func TestTaskServiceImpl_UpdateTask(t *testing.T) { defer ctrl.Finish() repoMock := repomocks.NewMockITaskRepo(ctrl) - taskDO := &entity.ObservabilityTask{TaskType: task.TaskTypeAutoEval, TaskStatus: task.TaskStatusUnstarted, EffectiveTime: &entity.EffectiveTime{}, Sampler: &entity.Sampler{}} + taskDO := &entity.ObservabilityTask{TaskType: entity.TaskTypeAutoEval, TaskStatus: entity.TaskStatusUnstarted, EffectiveTime: &entity.EffectiveTime{}, Sampler: &entity.Sampler{}} repoMock.EXPECT().GetTask(gomock.Any(), int64(1), gomock.Any(), gomock.Nil()).Return(taskDO, nil) proc := &fakeProcessor{} @@ -258,11 +258,11 @@ func TestTaskServiceImpl_UpdateTask(t *testing.T) { repoMock := repomocks.NewMockITaskRepo(ctrl) now := time.Now() taskDO := &entity.ObservabilityTask{ - TaskType: task.TaskTypeAutoEval, - TaskStatus: task.TaskStatusUnstarted, + TaskType: entity.TaskTypeAutoEval, + TaskStatus: entity.TaskStatusUnstarted, EffectiveTime: &entity.EffectiveTime{StartAt: startAt, EndAt: startAt + 3600000}, Sampler: &entity.Sampler{SampleRate: 0.1}, - TaskRuns: []*entity.TaskRun{{RunStatus: task.RunStatusRunning}}, + TaskRuns: []*entity.TaskRun{{RunStatus: entity.TaskRunStatusRunning}}, UpdatedAt: now, UpdatedBy: "", } @@ -306,11 +306,11 @@ func TestTaskServiceImpl_UpdateTask(t *testing.T) { repoMock := repomocks.NewMockITaskRepo(ctrl) taskDO := &entity.ObservabilityTask{ - TaskType: task.TaskTypeAutoEval, - TaskStatus: task.TaskStatusUnstarted, + TaskType: entity.TaskTypeAutoEval, + TaskStatus: entity.TaskStatusUnstarted, EffectiveTime: &entity.EffectiveTime{StartAt: time.Now().UnixMilli(), EndAt: time.Now().Add(time.Hour).UnixMilli()}, Sampler: &entity.Sampler{}, - TaskRuns: []*entity.TaskRun{{RunStatus: task.RunStatusRunning}}, + TaskRuns: []*entity.TaskRun{{RunStatus: entity.TaskRunStatusRunning}}, } repoMock.EXPECT().GetTask(gomock.Any(), int64(1), gomock.Any(), gomock.Nil()).Return(taskDO, nil) @@ -340,11 +340,11 @@ func TestTaskServiceImpl_UpdateTask(t *testing.T) { startAt := time.Now().Add(2 * time.Hour).UnixMilli() repoMock := repomocks.NewMockITaskRepo(ctrl) taskDO := &entity.ObservabilityTask{ - TaskType: task.TaskTypeAutoEval, - TaskStatus: task.TaskStatusUnstarted, + TaskType: entity.TaskTypeAutoEval, + TaskStatus: entity.TaskStatusUnstarted, EffectiveTime: &entity.EffectiveTime{StartAt: startAt, EndAt: startAt + 3600000}, Sampler: &entity.Sampler{}, - TaskRuns: []*entity.TaskRun{{RunStatus: task.RunStatusRunning}}, + TaskRuns: []*entity.TaskRun{{RunStatus: entity.TaskRunStatusRunning}}, } repoMock.EXPECT().GetTask(gomock.Any(), int64(1), gomock.Any(), gomock.Nil()).Return(taskDO, nil) @@ -402,8 +402,8 @@ func TestTaskServiceImpl_ListTasks(t *testing.T) { ID: 1, Name: "task", WorkspaceID: 2, - TaskType: task.TaskTypeAutoEval, - TaskStatus: task.TaskStatusUnstarted, + TaskType: entity.TaskTypeAutoEval, + TaskStatus: entity.TaskStatusUnstarted, CreatedBy: "user1", UpdatedBy: "user2", EffectiveTime: &entity.EffectiveTime{}, @@ -485,8 +485,8 @@ func TestTaskServiceImpl_GetTask(t *testing.T) { hidden := &loop_span.FilterField{FieldName: "outer_hidden", Values: []string{"v"}, Hidden: true} taskDO := &entity.ObservabilityTask{ - TaskType: task.TaskTypeAutoEval, - TaskStatus: task.TaskStatusUnstarted, + TaskType: entity.TaskTypeAutoEval, + TaskStatus: entity.TaskStatusUnstarted, CreatedBy: "user1", UpdatedBy: "user2", EffectiveTime: &entity.EffectiveTime{}, @@ -574,18 +574,18 @@ func TestTaskServiceImpl_shouldTriggerBackfill(t *testing.T) { service := &TaskServiceImpl{} t.Run("task type mismatch", func(t *testing.T) { - taskDO := &entity.ObservabilityTask{TaskType: "other"} + taskDO := &entity.ObservabilityTask{TaskType: entity.TaskType("other")} assert.False(t, service.shouldTriggerBackfill(taskDO)) }) t.Run("missing effective time", func(t *testing.T) { - taskDO := &entity.ObservabilityTask{TaskType: task.TaskTypeAutoEval} + taskDO := &entity.ObservabilityTask{TaskType: entity.TaskTypeAutoEval} assert.False(t, service.shouldTriggerBackfill(taskDO)) }) t.Run("valid", func(t *testing.T) { taskDO := &entity.ObservabilityTask{ - TaskType: task.TaskTypeAutoDataReflow, + TaskType: entity.TaskTypeAutoDataReflow, BackfillEffectiveTime: &entity.EffectiveTime{StartAt: 1, EndAt: 2}, } assert.True(t, service.shouldTriggerBackfill(taskDO)) diff --git a/backend/modules/observability/domain/task/service/taskexe/processor/auto_evaluate.go b/backend/modules/observability/domain/task/service/taskexe/processor/auto_evaluate.go index 0502fa9c3..44c0a52d2 100644 --- a/backend/modules/observability/domain/task/service/taskexe/processor/auto_evaluate.go +++ b/backend/modules/observability/domain/task/service/taskexe/processor/auto_evaluate.go @@ -112,7 +112,7 @@ func (p *AutoEvaluteProcessor) Invoke(ctx context.Context, trigger *taskexe.Trig logs.CtxInfo(ctx, "[task-debug] AutoEvaluteProcessor Invoke, turns is empty") return nil } - taskTTL := trigger.Task.GetTaskttl() + taskTTL := trigger.Task.GetTaskTTL() _ = p.taskRepo.IncrTaskCount(ctx, trigger.Task.ID, taskTTL) _ = p.taskRepo.IncrTaskRunCount(ctx, trigger.Task.ID, taskRun.ID, taskTTL) taskCount, _ := p.taskRepo.GetTaskCount(ctx, trigger.Task.ID) @@ -216,20 +216,20 @@ func (p *AutoEvaluteProcessor) OnCreateTaskChange(ctx context.Context, currentTa func (p *AutoEvaluteProcessor) OnUpdateTaskChange(ctx context.Context, currentTask *task_entity.ObservabilityTask, taskOp task.TaskStatus) error { switch taskOp { case task.TaskStatusSuccess: - if currentTask.TaskStatus != task.TaskStatusDisabled { - currentTask.TaskStatus = task.TaskStatusSuccess + if currentTask.TaskStatus != task_entity.TaskStatusDisabled { + currentTask.TaskStatus = task_entity.TaskStatusSuccess } case task.TaskStatusRunning: - if currentTask.TaskStatus != task.TaskStatusDisabled && currentTask.TaskStatus != task.TaskStatusSuccess { - currentTask.TaskStatus = task.TaskStatusRunning + if currentTask.TaskStatus != task_entity.TaskStatusDisabled && currentTask.TaskStatus != task_entity.TaskStatusSuccess { + currentTask.TaskStatus = task_entity.TaskStatusRunning } case task.TaskStatusDisabled: - if currentTask.TaskStatus != task.TaskStatusDisabled { - currentTask.TaskStatus = task.TaskStatusDisabled + if currentTask.TaskStatus != task_entity.TaskStatusDisabled { + currentTask.TaskStatus = task_entity.TaskStatusDisabled } case task.TaskStatusPending: - if currentTask.TaskStatus == task.TaskStatusPending || currentTask.TaskStatus == task.TaskStatusUnstarted { - currentTask.TaskStatus = task.TaskStatusPending + if currentTask.TaskStatus == task_entity.TaskStatusPending || currentTask.TaskStatus == task_entity.TaskStatusUnstarted { + currentTask.TaskStatus = task_entity.TaskStatusPending } default: return fmt.Errorf("OnUpdateChangeProcessor, valid taskOp:%s", taskOp) @@ -315,7 +315,7 @@ func (p *AutoEvaluteProcessor) OnCreateTaskRunChange(ctx context.Context, param FromEvalSet: fromEvalSet, }) } - category := getCategory(currentTask.TaskType) + category := getCategory(task.TaskType(currentTask.TaskType)) schema := convertDatasetSchemaDTO2DO(evaluationSetSchema) logs.CtxInfo(ctx, "[auto_task] CreateDataset,category:%s", category) var datasetName, exptName string @@ -394,8 +394,8 @@ func (p *AutoEvaluteProcessor) OnCreateTaskRunChange(ctx context.Context, param taskRun := &task_entity.TaskRun{ TaskID: currentTask.ID, WorkspaceID: currentTask.WorkspaceID, - TaskType: param.RunType, - RunStatus: task.RunStatusRunning, + TaskType: task_entity.TaskRunType(param.RunType), + RunStatus: task_entity.TaskRunStatusRunning, RunStartAt: time.UnixMilli(param.RunStartAt), RunEndAt: time.UnixMilli(param.RunEndAt), CreatedAt: time.Now(), diff --git a/backend/modules/observability/domain/task/service/taskexe/processor/auto_evaluate_test.go b/backend/modules/observability/domain/task/service/taskexe/processor/auto_evaluate_test.go index bf2298646..1bf091cc9 100755 --- a/backend/modules/observability/domain/task/service/taskexe/processor/auto_evaluate_test.go +++ b/backend/modules/observability/domain/task/service/taskexe/processor/auto_evaluate_test.go @@ -126,8 +126,8 @@ func buildTestTask(t *testing.T) *taskentity.ObservabilityTask { WorkspaceID: 202, Name: "auto-eval", CreatedBy: "1001", - TaskType: task.TaskTypeAutoEval, - TaskStatus: task.TaskStatusUnstarted, + TaskType: taskentity.TaskTypeAutoEval, + TaskStatus: taskentity.TaskStatusUnstarted, EffectiveTime: &taskentity.EffectiveTime{ StartAt: start, EndAt: end, @@ -313,8 +313,8 @@ func TestAutoEvaluteProcessor_Invoke(t *testing.T) { ID: 1001, TaskID: taskObj.ID, WorkspaceID: taskObj.WorkspaceID, - TaskType: task.TaskRunTypeNewData, - RunStatus: task.RunStatusRunning, + TaskType: taskentity.TaskRunTypeNewData, + RunStatus: taskentity.TaskRunStatusRunning, TaskRunConfig: buildTaskRunConfig(schemaStr), } span := buildSpan("{\"parts\":[]}") @@ -431,14 +431,14 @@ func TestAutoEvaluteProcessor_OnUpdateTaskChange(t *testing.T) { cases := []struct { name string - initial string + initial taskentity.TaskStatus op task.TaskStatus - expect string + expect taskentity.TaskStatus }{ - {"success", task.TaskStatusRunning, task.TaskStatusSuccess, task.TaskStatusSuccess}, - {"running", task.TaskStatusPending, task.TaskStatusRunning, task.TaskStatusRunning}, - {"disable", task.TaskStatusRunning, task.TaskStatusDisabled, task.TaskStatusDisabled}, - {"pending", task.TaskStatusUnstarted, task.TaskStatusPending, task.TaskStatusPending}, + {"success", taskentity.TaskStatusRunning, task.TaskStatusSuccess, taskentity.TaskStatusSuccess}, + {"running", taskentity.TaskStatusPending, task.TaskStatusRunning, taskentity.TaskStatusRunning}, + {"disable", taskentity.TaskStatusRunning, task.TaskStatusDisabled, taskentity.TaskStatusDisabled}, + {"pending", taskentity.TaskStatusUnstarted, task.TaskStatusPending, taskentity.TaskStatusPending}, } for _, tt := range cases { @@ -555,7 +555,7 @@ func TestAutoEvaluteProcessor_OnFinishTaskChange(t *testing.T) { repoAdapter := &taskRepoMockAdapter{MockITaskRepo: repoMock} evalAdapter := &fakeEvaluationAdapter{} - taskObj := &taskentity.ObservabilityTask{TaskStatus: task.TaskStatusRunning, WorkspaceID: 123} + taskObj := &taskentity.ObservabilityTask{TaskStatus: taskentity.TaskStatusRunning, WorkspaceID: 123} taskRun := &taskentity.TaskRun{TaskRunConfig: &taskentity.TaskRunConfig{AutoEvaluateRunConfig: &taskentity.AutoEvaluateRunConfig{ExptID: 1, ExptRunID: 2}}} repoMock.EXPECT().UpdateTaskRun(gomock.Any(), gomock.Any()).Return(nil) @@ -621,7 +621,7 @@ func TestAutoEvaluteProcessor_OnCreateTaskChange(t *testing.T) { } taskObj := buildTestTask(t) - taskObj.TaskStatus = task.TaskStatusPending + taskObj.TaskStatus = taskentity.TaskStatusPending var runTypes []task.TaskRunType var statuses []task.TaskStatus diff --git a/backend/modules/observability/domain/task/service/taskexe/tracehub/backfill.go b/backend/modules/observability/domain/task/service/taskexe/tracehub/backfill.go index bfa090c06..da17fa148 100644 --- a/backend/modules/observability/domain/task/service/taskexe/tracehub/backfill.go +++ b/backend/modules/observability/domain/task/service/taskexe/tracehub/backfill.go @@ -117,7 +117,7 @@ func (h *TraceHubServiceImpl) setBackfillTask(ctx context.Context, event *entity return nil, err } taskRunDTO := tconv.TaskRunDO2DTO(ctx, taskRun, nil) - proc := h.taskProcessor.GetTaskProcessor(taskConfig.TaskType) + proc := h.taskProcessor.GetTaskProcessor(task.TaskType(taskConfig.TaskType)) sub := &spanSubscriber{ taskID: taskConfigDO.GetID(), t: taskConfigDO, diff --git a/backend/modules/observability/domain/task/service/taskexe/tracehub/backfill_test.go b/backend/modules/observability/domain/task/service/taskexe/tracehub/backfill_test.go index eb6fa5634..4d6b2a70c 100755 --- a/backend/modules/observability/domain/task/service/taskexe/tracehub/backfill_test.go +++ b/backend/modules/observability/domain/task/service/taskexe/tracehub/backfill_test.go @@ -54,7 +54,7 @@ func TestTraceHubServiceImpl_SetBackfillTask(t *testing.T) { obsTask := &entity.ObservabilityTask{ ID: 1, WorkspaceID: 1, - TaskType: task.TaskTypeAutoEval, + TaskType: entity.TaskTypeAutoEval, SpanFilter: &entity.SpanFilterFields{ Filters: loop_span.FilterFields{ QueryAndOr: ptr.Of(loop_span.QueryAndOrEnumAnd), @@ -70,8 +70,8 @@ func TestTraceHubServiceImpl_SetBackfillTask(t *testing.T) { ID: 2, TaskID: 1, WorkspaceID: 1, - TaskType: task.TaskRunTypeBackFill, - RunStatus: task.RunStatusRunning, + TaskType: entity.TaskRunTypeBackFill, + RunStatus: entity.TaskRunStatusRunning, RunStartAt: now.Add(-time.Minute), RunEndAt: now.Add(time.Minute), } @@ -207,8 +207,8 @@ func TestTraceHubServiceImpl_ProcessBatchSpans_DispatchError(t *testing.T) { ID: 20, TaskID: 1, WorkspaceID: 1, - TaskType: task.TaskRunTypeNewData, - RunStatus: task.RunStatusRunning, + TaskType: entity.TaskRunTypeNewData, + RunStatus: entity.TaskRunStatusRunning, RunStartAt: now.Add(-time.Minute), RunEndAt: now.Add(time.Minute), } @@ -677,8 +677,8 @@ func newDomainBackfillTaskRun(now time.Time) *entity.TaskRun { ID: 10, TaskID: 1, WorkspaceID: 2, - TaskType: task.TaskRunTypeBackFill, - RunStatus: task.RunStatusRunning, + TaskType: entity.TaskRunTypeBackFill, + RunStatus: entity.TaskRunStatusRunning, RunStartAt: now.Add(-time.Minute), RunEndAt: now.Add(time.Minute), } diff --git a/backend/modules/observability/domain/task/service/taskexe/tracehub/scheduled_task.go b/backend/modules/observability/domain/task/service/taskexe/tracehub/scheduled_task.go index fe106a8ed..803a899dc 100755 --- a/backend/modules/observability/domain/task/service/taskexe/tracehub/scheduled_task.go +++ b/backend/modules/observability/domain/task/service/taskexe/tracehub/scheduled_task.go @@ -123,12 +123,12 @@ func (h *TraceHubServiceImpl) transformTaskStatus() { endTime = time.UnixMilli(taskPO.EffectiveTime.EndAt) startTime = time.UnixMilli(taskPO.EffectiveTime.StartAt) } - proc := h.taskProcessor.GetTaskProcessor(taskPO.TaskType) + proc := h.taskProcessor.GetTaskProcessor(task.TaskType(taskPO.TaskType)) // Task time horizon reached // End when the task end time is reached logs.CtxInfo(ctx, "[auto_task]taskID:%d, endTime:%v, startTime:%v", taskPO.ID, endTime, startTime) if taskPO.BackfillEffectiveTime != nil && taskPO.EffectiveTime != nil && backfillTaskRun != nil { - if time.Now().After(endTime) && backfillTaskRun.RunStatus == task.RunStatusDone { + if time.Now().After(endTime) && backfillTaskRun.RunStatus == entity.TaskRunStatusDone { logs.CtxInfo(ctx, "[OnFinishTaskChange]taskID:%d, time.Now().After(endTime) && backfillTaskRun.RunStatus == task.RunStatusDone", taskPO.ID) err = proc.OnFinishTaskChange(ctx, taskexe.OnFinishTaskChangeReq{ Task: taskPO, @@ -140,7 +140,7 @@ func (h *TraceHubServiceImpl) transformTaskStatus() { continue } } - if backfillTaskRun.RunStatus != task.RunStatusDone { + if backfillTaskRun.RunStatus != entity.TaskRunStatusDone { lockKey := fmt.Sprintf(backfillLockKeyTemplate, taskPO.ID) locked, _, cancel, lockErr := h.locker.LockWithRenew(ctx, lockKey, transformTaskStatusLockTTL, backfillLockMaxHold) if lockErr != nil || !locked { @@ -152,7 +152,7 @@ func (h *TraceHubServiceImpl) transformTaskStatus() { defer cancel() } } else if taskPO.BackfillEffectiveTime != nil && backfillTaskRun != nil { - if backfillTaskRun.RunStatus == task.RunStatusDone { + if backfillTaskRun.RunStatus == entity.TaskRunStatusDone { logs.CtxInfo(ctx, "[OnFinishTaskChange]taskID:%d, backfillTaskRun.RunStatus == task.RunStatusDone", taskPO.ID) err = proc.OnFinishTaskChange(ctx, taskexe.OnFinishTaskChangeReq{ Task: taskPO, @@ -164,7 +164,7 @@ func (h *TraceHubServiceImpl) transformTaskStatus() { continue } } - if backfillTaskRun.RunStatus != task.RunStatusDone { + if backfillTaskRun.RunStatus != entity.TaskRunStatusDone { lockKey := fmt.Sprintf(backfillLockKeyTemplate, taskPO.ID) locked, _, cancel, lockErr := h.locker.LockWithRenew(ctx, lockKey, transformTaskStatusLockTTL, backfillLockMaxHold) if lockErr != nil || !locked { @@ -190,7 +190,7 @@ func (h *TraceHubServiceImpl) transformTaskStatus() { } } // If the task status is unstarted, create it once the task start time is reached - if taskPO.TaskStatus == task.TaskStatusUnstarted && time.Now().After(startTime) { + if taskPO.TaskStatus == entity.TaskStatusUnstarted && time.Now().After(startTime) { if !taskPO.Sampler.IsCycle { err = proc.OnCreateTaskRunChange(ctx, taskexe.OnCreateTaskRunChangeReq{ CurrentTask: taskPO, @@ -221,7 +221,7 @@ func (h *TraceHubServiceImpl) transformTaskStatus() { } } // Handle taskRun - if taskPO.TaskStatus == task.TaskStatusRunning || taskPO.TaskStatus == task.TaskStatusPending { + if taskPO.TaskStatus == entity.TaskStatusRunning || taskPO.TaskStatus == entity.TaskStatusPending { if taskRun == nil { logs.CtxError(ctx, "taskID:%d, taskRun is nil", taskPO.ID) continue diff --git a/backend/modules/observability/domain/task/service/taskexe/tracehub/span_trigger_test.go b/backend/modules/observability/domain/task/service/taskexe/tracehub/span_trigger_test.go index e9d29792f..9f7d795fa 100755 --- a/backend/modules/observability/domain/task/service/taskexe/tracehub/span_trigger_test.go +++ b/backend/modules/observability/domain/task/service/taskexe/tracehub/span_trigger_test.go @@ -64,8 +64,8 @@ func TestTraceHubServiceImpl_SpanTriggerDispatchError(t *testing.T) { taskDO := &entity.ObservabilityTask{ ID: 1, WorkspaceID: workspaceID, - TaskType: task.TaskTypeAutoEval, - TaskStatus: task.TaskStatusRunning, + TaskType: entity.TaskTypeAutoEval, + TaskStatus: entity.TaskStatusRunning, SpanFilter: &entity.SpanFilterFields{ PlatformType: common.PlatformTypeLoopAll, SpanListType: common.SpanListTypeAllSpan, @@ -88,8 +88,8 @@ func TestTraceHubServiceImpl_SpanTriggerDispatchError(t *testing.T) { ID: 101, TaskID: 1, WorkspaceID: workspaceID, - TaskType: task.TaskRunTypeNewData, - RunStatus: task.TaskStatusRunning, + TaskType: entity.TaskRunTypeNewData, + RunStatus: entity.TaskRunStatusRunning, RunStartAt: now.Add(-30 * time.Minute), RunEndAt: now.Add(30 * time.Minute), }, @@ -113,8 +113,8 @@ func TestTraceHubServiceImpl_SpanTriggerDispatchError(t *testing.T) { ID: 201, TaskID: 1, WorkspaceID: workspaceID, - TaskType: task.TaskRunTypeNewData, - RunStatus: task.TaskStatusRunning, + TaskType: entity.TaskRunTypeNewData, + RunStatus: entity.TaskRunStatusRunning, RunStartAt: now.Add(-15 * time.Minute), RunEndAt: now.Add(15 * time.Minute), } diff --git a/backend/modules/observability/infra/repo/mysql/convertor/task.go b/backend/modules/observability/infra/repo/mysql/convertor/task.go index f632d8d70..a788a3719 100644 --- a/backend/modules/observability/infra/repo/mysql/convertor/task.go +++ b/backend/modules/observability/infra/repo/mysql/convertor/task.go @@ -17,8 +17,8 @@ func TaskDO2PO(task *entity.ObservabilityTask) *model.ObservabilityTask { WorkspaceID: task.WorkspaceID, Name: task.Name, Description: task.Description, - TaskType: task.TaskType, - TaskStatus: task.TaskStatus, + TaskType: string(task.TaskType), + TaskStatus: string(task.TaskStatus), TaskDetail: ptr.Of(ToJSONString(task.TaskDetail)), SpanFilter: ptr.Of(ToJSONString(task.SpanFilter)), EffectiveTime: ptr.Of(ToJSONString(task.EffectiveTime)), @@ -38,8 +38,8 @@ func TaskPO2DO(task *model.ObservabilityTask) *entity.ObservabilityTask { WorkspaceID: task.WorkspaceID, Name: task.Name, Description: task.Description, - TaskType: task.TaskType, - TaskStatus: task.TaskStatus, + TaskType: entity.TaskType(task.TaskType), + TaskStatus: entity.TaskStatus(task.TaskStatus), TaskDetail: TaskDetailJSON2DO(task.TaskDetail), SpanFilter: SpanFilterJSON2DO(task.SpanFilter), EffectiveTime: EffectiveTimeJSON2DO(task.EffectiveTime), @@ -118,8 +118,8 @@ func TaskRunDO2PO(taskRun *entity.TaskRun) *model.ObservabilityTaskRun { ID: taskRun.ID, TaskID: taskRun.TaskID, WorkspaceID: taskRun.WorkspaceID, - TaskType: taskRun.TaskType, - RunStatus: taskRun.RunStatus, + TaskType: string(taskRun.TaskType), + RunStatus: string(taskRun.RunStatus), RunDetail: ptr.Of(ToJSONString(taskRun.RunDetail)), BackfillDetail: ptr.Of(ToJSONString(taskRun.BackfillDetail)), RunStartAt: taskRun.RunStartAt, @@ -135,8 +135,8 @@ func TaskRunPO2DO(taskRun *model.ObservabilityTaskRun) *entity.TaskRun { ID: taskRun.ID, TaskID: taskRun.TaskID, WorkspaceID: taskRun.WorkspaceID, - TaskType: taskRun.TaskType, - RunStatus: taskRun.RunStatus, + TaskType: entity.TaskRunType(taskRun.TaskType), + RunStatus: entity.TaskRunStatus(taskRun.RunStatus), RunDetail: RunDetailJSON2DO(taskRun.RunDetail), BackfillDetail: BackfillRunDetailJSON2DO(taskRun.BackfillDetail), RunStartAt: taskRun.RunStartAt, From fbeae7e2e3e2993854198889fac02b1bfd9bb50b Mon Sep 17 00:00:00 2001 From: taoyifan89 Date: Wed, 29 Oct 2025 16:26:48 +0800 Subject: [PATCH 02/19] refactor: [Coda] switch task processor to entity type (LogID: 202510291603100100911150893728B9B) Co-Authored-By: Coda --- .../modules/observability/application/wire.go | 4 +- .../domain/task/entity/task_test.go | 4 + .../domain/task/service/task_service.go | 122 +++++++++------- .../domain/task/service/task_service_test.go | 134 ++++++++++++------ .../taskexe/processor/auto_evaluate_test.go | 14 +- .../task/service/taskexe/processor/factory.go | 10 +- .../service/taskexe/processor/factory_test.go | 7 +- .../task/service/taskexe/tracehub/backfill.go | 6 +- .../taskexe/tracehub/scheduled_task.go | 2 +- .../taskexe/tracehub/scheduled_task_test.go | 44 +++--- .../service/taskexe/tracehub/span_trigger.go | 12 +- 11 files changed, 213 insertions(+), 146 deletions(-) create mode 100644 backend/modules/observability/domain/task/entity/task_test.go diff --git a/backend/modules/observability/application/wire.go b/backend/modules/observability/application/wire.go index 9bf613bae..a1cefa728 100644 --- a/backend/modules/observability/application/wire.go +++ b/backend/modules/observability/application/wire.go @@ -24,7 +24,6 @@ import ( "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/foundation/auth/authservice" "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/foundation/file/fileservice" "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/foundation/user/userservice" - "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/observability/domain/task" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/component/config" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/component/rpc" metrics_entity "github.com/coze-dev/coze-loop/backend/modules/observability/domain/metric/entity" @@ -33,6 +32,7 @@ import ( metric_model "github.com/coze-dev/coze-loop/backend/modules/observability/domain/metric/service/metric/model" metric_service_def "github.com/coze-dev/coze-loop/backend/modules/observability/domain/metric/service/metric/service" metric_tool "github.com/coze-dev/coze-loop/backend/modules/observability/domain/metric/service/metric/tool" + task_entity "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/entity" trepo "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/repo" taskSvc "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/service" task_processor "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/service/taskexe/processor" @@ -281,7 +281,7 @@ func NewInitTaskProcessor(datasetServiceProvider *service.DatasetServiceAdaptor, evaluationService rpc.IEvaluationRPCAdapter, taskRepo trepo.ITaskRepo, ) *task_processor.TaskProcessor { taskProcessor := task_processor.NewTaskProcessor() - taskProcessor.Register(task.TaskTypeAutoEval, task_processor.NewAutoEvaluteProcessor(0, datasetServiceProvider, evalService, evaluationService, taskRepo)) + taskProcessor.Register(task_entity.TaskTypeAutoEval, task_processor.NewAutoEvaluteProcessor(0, datasetServiceProvider, evalService, evaluationService, taskRepo)) return taskProcessor } diff --git a/backend/modules/observability/domain/task/entity/task_test.go b/backend/modules/observability/domain/task/entity/task_test.go new file mode 100644 index 000000000..8edaf6d23 --- /dev/null +++ b/backend/modules/observability/domain/task/entity/task_test.go @@ -0,0 +1,4 @@ +// Copyright (c) 2025 coze-dev Authors +// SPDX-License-Identifier: Apache-2.0 + +package entity diff --git a/backend/modules/observability/domain/task/service/task_service.go b/backend/modules/observability/domain/task/service/task_service.go index e7b8d9d2f..916370fa7 100644 --- a/backend/modules/observability/domain/task/service/task_service.go +++ b/backend/modules/observability/domain/task/service/task_service.go @@ -14,19 +14,17 @@ import ( "github.com/coze-dev/coze-loop/backend/infra/middleware/session" "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/observability/domain/common" "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/observability/domain/filter" - "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/observability/domain/task" - tconv "github.com/coze-dev/coze-loop/backend/modules/observability/application/convertor/task" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/component/mq" - "github.com/coze-dev/coze-loop/backend/modules/observability/domain/component/rpc" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/entity" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/repo" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/service/taskexe" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/service/taskexe/processor" - loop_span "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/entity/loop_span" + "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/entity/loop_span" + traceservice "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/service" + "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/service/trace/span_filter" "github.com/coze-dev/coze-loop/backend/modules/observability/infra/repo/mysql" obErrorx "github.com/coze-dev/coze-loop/backend/modules/observability/pkg/errno" "github.com/coze-dev/coze-loop/backend/pkg/errorx" - "github.com/coze-dev/coze-loop/backend/pkg/lang/ptr" "github.com/coze-dev/coze-loop/backend/pkg/logs" ) @@ -39,9 +37,9 @@ type CreateTaskResp struct { type UpdateTaskReq struct { TaskID int64 WorkspaceID int64 - TaskStatus *task.TaskStatus + TaskStatus *entity.TaskStatus Description *string - EffectiveTime *task.EffectiveTime + EffectiveTime *entity.EffectiveTime SampleRate *float64 } type ListTasksReq struct { @@ -52,15 +50,15 @@ type ListTasksReq struct { OrderBy *common.OrderBy } type ListTasksResp struct { - Tasks []*task.Task - Total *int64 + Tasks []*entity.ObservabilityTask + Total int64 } type GetTaskReq struct { TaskID int64 WorkspaceID int64 } type GetTaskResp struct { - Task *task.Task + Task *entity.ObservabilityTask } type CheckTaskNameReq struct { WorkspaceID int64 @@ -81,33 +79,34 @@ type ITaskService interface { func NewTaskServiceImpl( tRepo repo.ITaskRepo, - userProvider rpc.IUserProvider, idGenerator idgen.IIDGenerator, backfillProducer mq.IBackfillProducer, taskProcessor *processor.TaskProcessor, + buildHelper traceservice.TraceFilterProcessorBuilder, ) (ITaskService, error) { return &TaskServiceImpl{ TaskRepo: tRepo, - userProvider: userProvider, idGenerator: idGenerator, backfillProducer: backfillProducer, taskProcessor: *taskProcessor, + buildHelper: buildHelper, }, nil } type TaskServiceImpl struct { TaskRepo repo.ITaskRepo - userProvider rpc.IUserProvider idGenerator idgen.IIDGenerator backfillProducer mq.IBackfillProducer taskProcessor processor.TaskProcessor + buildHelper traceservice.TraceFilterProcessorBuilder } func (t *TaskServiceImpl) CreateTask(ctx context.Context, req *CreateTaskReq) (resp *CreateTaskResp, err error) { + taskDO := req.Task // 校验task name是否存在 checkResp, err := t.CheckTaskName(ctx, &CheckTaskNameReq{ - WorkspaceID: req.Task.WorkspaceID, - Name: req.Task.Name, + WorkspaceID: taskDO.WorkspaceID, + Name: taskDO.Name, }) if err != nil { logs.CtxError(ctx, "CheckTaskName err:%v", err) @@ -117,33 +116,39 @@ func (t *TaskServiceImpl) CreateTask(ctx context.Context, req *CreateTaskReq) (r logs.CtxError(ctx, "task name exist") return nil, errorx.NewByCode(obErrorx.CommonInvalidParamCode, errorx.WithExtraMsg("task name exist")) } - proc := t.taskProcessor.GetTaskProcessor(task.TaskType(req.Task.TaskType)) + + if err := t.buildSpanFilters(ctx, taskDO); err != nil { + logs.CtxError(ctx, "buildSpanFilters err:%v", err) + return nil, err + } + + proc := t.taskProcessor.GetTaskProcessor(taskDO.TaskType) // 校验配置项是否有效 - if err = proc.ValidateConfig(ctx, req.Task); err != nil { + if err = proc.ValidateConfig(ctx, taskDO); err != nil { logs.CtxError(ctx, "ValidateConfig err:%v", err) return nil, errorx.NewByCode(obErrorx.CommonInvalidParamCode, errorx.WithExtraMsg(fmt.Sprintf("config invalid:%v", err))) } - id, err := t.TaskRepo.CreateTask(ctx, req.Task) + id, err := t.TaskRepo.CreateTask(ctx, taskDO) if err != nil { return nil, err } // 创建任务的数据准备 // 数据回流任务——创建/更新输出数据集 // 自动评测历史回溯——创建空壳子 - req.Task.ID = id - if err = proc.OnCreateTaskChange(ctx, req.Task); err != nil { + taskDO.ID = id + if err = proc.OnCreateTaskChange(ctx, taskDO); err != nil { logs.CtxError(ctx, "create initial task run failed, task_id=%d, err=%v", id, err) - if err1 := t.TaskRepo.DeleteTask(ctx, req.Task); err1 != nil { + if err1 := t.TaskRepo.DeleteTask(ctx, taskDO); err1 != nil { logs.CtxError(ctx, "delete task failed, task_id=%d, err=%v", id, err1) } return nil, err } // 历史回溯数据发MQ - if t.shouldTriggerBackfill(req.Task) { + if t.shouldTriggerBackfill(taskDO) { backfillEvent := &entity.BackFillEvent{ - SpaceID: req.Task.WorkspaceID, + SpaceID: taskDO.WorkspaceID, TaskID: id, } @@ -176,24 +181,23 @@ func (t *TaskServiceImpl) UpdateTask(ctx context.Context, req *UpdateTaskReq) (e taskDO.Description = req.Description } if req.EffectiveTime != nil { - validEffectiveTime, err := tconv.CheckEffectiveTime(ctx, req.EffectiveTime, task.TaskStatus(taskDO.TaskStatus), taskDO.EffectiveTime) - if err != nil { + if err := taskDO.SetEffectiveTime(ctx, *req.EffectiveTime); err != nil { return err } - taskDO.EffectiveTime = validEffectiveTime } if req.SampleRate != nil { taskDO.Sampler.SampleRate = *req.SampleRate } if req.TaskStatus != nil { - validTaskStatus, err := tconv.CheckTaskStatus(ctx, *req.TaskStatus, task.TaskStatus(taskDO.TaskStatus)) + event, err := taskDO.SetTaskStatus(ctx, *req.TaskStatus) if err != nil { return err } - if validTaskStatus != "" { - if validTaskStatus == task.TaskStatusDisabled { + + if event != nil { + if event.After == entity.TaskStatusDisabled { // 禁用操作处理 - proc := t.taskProcessor.GetTaskProcessor(task.TaskType(taskDO.TaskType)) + proc := t.taskProcessor.GetTaskProcessor(taskDO.TaskType) var taskRun *entity.TaskRun for _, tr := range taskDO.TaskRuns { if tr.RunStatus == entity.TaskRunStatusRunning { @@ -213,7 +217,6 @@ func (t *TaskServiceImpl) UpdateTask(ctx context.Context, req *UpdateTaskReq) (e logs.CtxError(ctx, "remove non final task failed, task_id=%d, err=%v", taskDO.ID, err) } } - taskDO.TaskStatus = entity.TaskStatus(validTaskStatus) } } taskDO.UpdatedBy = userID @@ -240,22 +243,12 @@ func (t *TaskServiceImpl) ListTasks(ctx context.Context, req *ListTasksReq) (res logs.CtxInfo(ctx, "GetTasks tasks is nil") return resp, nil } - userMap := make(map[string]bool) - users := make([]string, 0) - for _, tp := range taskDOs { - userMap[tp.CreatedBy] = true - userMap[tp.UpdatedBy] = true - } - for u := range userMap { - users = append(users, u) - } - _, userInfoMap, err := t.userProvider.GetUserInfo(ctx, users) - if err != nil { - logs.CtxError(ctx, "MGetUserInfo err:%v", err) - } + + taskDOs = filterHiddenFilters(taskDOs) + return &ListTasksResp{ - Tasks: tconv.TaskDOs2DTOs(ctx, filterHiddenFilters(taskDOs), userInfoMap), - Total: ptr.Of(total), + Tasks: taskDOs, + Total: total, }, nil } @@ -269,11 +262,10 @@ func (t *TaskServiceImpl) GetTask(ctx context.Context, req *GetTaskReq) (resp *G logs.CtxError(ctx, "GetTasks tasks is nil") return resp, nil } - _, userInfoMap, err := t.userProvider.GetUserInfo(ctx, []string{taskDO.CreatedBy, taskDO.UpdatedBy}) - if err != nil { - logs.CtxError(ctx, "MGetUserInfo err:%v", err) - } - return &GetTaskResp{Task: tconv.TaskDO2DTO(ctx, filterHiddenFilters([]*entity.ObservabilityTask{taskDO})[0], userInfoMap)}, nil + + taskDO = filterHiddenFilters([]*entity.ObservabilityTask{taskDO})[0] + + return &GetTaskResp{Task: taskDO}, nil } func filterHiddenFilters(tasks []*entity.ObservabilityTask) []*entity.ObservabilityTask { @@ -385,3 +377,31 @@ func (t *TaskServiceImpl) sendBackfillMessage(ctx context.Context, event *entity return t.backfillProducer.SendBackfill(ctx, event) } + +func (t *TaskServiceImpl) buildSpanFilters(ctx context.Context, taskDO *entity.ObservabilityTask) error { + f, err := t.buildHelper.BuildPlatformRelatedFilter(ctx, taskDO.SpanFilter.PlatformType) + if err != nil { + return err + } + env := &span_filter.SpanEnv{ + WorkspaceID: taskDO.WorkspaceID, + } + + // coze场景中,需要将basic filter提前固化到数据库中,避免任务触发时重复调用coze接口 + basicFilter, forceQuery, err := f.BuildBasicSpanFilter(ctx, env) + if err != nil { + return err + } else if len(basicFilter) == 0 && !forceQuery { + logs.CtxInfo(ctx, "Build basic filter failed, platform type: [%s], workspaceID: [%d]", + taskDO.SpanFilter.PlatformType, taskDO.WorkspaceID) + return errorx.NewByCode(obErrorx.CommercialCommonInvalidParamCodeCode, errorx.WithExtraMsg("User has no permission")) + } + + // basic filter对用户不可见 + for _, filter := range basicFilter { + filter.SetHidden(true) + } + + taskDO.SpanFilter.Filters.FilterFields = append(taskDO.SpanFilter.Filters.FilterFields, basicFilter...) + return nil +} diff --git a/backend/modules/observability/domain/task/service/task_service_test.go b/backend/modules/observability/domain/task/service/task_service_test.go index e94619701..0a1751154 100755 --- a/backend/modules/observability/domain/task/service/task_service_test.go +++ b/backend/modules/observability/domain/task/service/task_service_test.go @@ -17,15 +17,14 @@ import ( "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/observability/domain/filter" "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/observability/domain/task" componentmq "github.com/coze-dev/coze-loop/backend/modules/observability/domain/component/mq" - rpc "github.com/coze-dev/coze-loop/backend/modules/observability/domain/component/rpc" - rpcmock "github.com/coze-dev/coze-loop/backend/modules/observability/domain/component/rpc/mocks" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/entity" taskrepo "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/repo" repomocks "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/repo/mocks" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/service/taskexe" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/service/taskexe/processor" - entitycommon "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/entity/common" loop_span "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/entity/loop_span" + "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/service/trace/span_filter" + span_processor "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/service/trace/span_processor" obErrorx "github.com/coze-dev/coze-loop/backend/modules/observability/pkg/errno" "github.com/coze-dev/coze-loop/backend/pkg/errorx" ) @@ -68,6 +67,54 @@ func (f *fakeProcessor) OnFinishTaskRunChange(context.Context, taskexe.OnFinishT return f.onFinishRunErr } +type stubTraceFilterBuilder struct{} + +func (s *stubTraceFilterBuilder) BuildPlatformRelatedFilter(context.Context, loop_span.PlatformType) (span_filter.Filter, error) { + return &stubSpanFilter{}, nil +} + +func (s *stubTraceFilterBuilder) BuildGetTraceProcessors(context.Context, span_processor.Settings) ([]span_processor.Processor, error) { + return nil, nil +} + +func (s *stubTraceFilterBuilder) BuildListSpansProcessors(context.Context, span_processor.Settings) ([]span_processor.Processor, error) { + return nil, nil +} + +func (s *stubTraceFilterBuilder) BuildAdvanceInfoProcessors(context.Context, span_processor.Settings) ([]span_processor.Processor, error) { + return nil, nil +} + +func (s *stubTraceFilterBuilder) BuildIngestTraceProcessors(context.Context, span_processor.Settings) ([]span_processor.Processor, error) { + return nil, nil +} + +func (s *stubTraceFilterBuilder) BuildSearchTraceOApiProcessors(context.Context, span_processor.Settings) ([]span_processor.Processor, error) { + return nil, nil +} + +func (s *stubTraceFilterBuilder) BuildListSpansOApiProcessors(context.Context, span_processor.Settings) ([]span_processor.Processor, error) { + return nil, nil +} + +type stubSpanFilter struct{} + +func (s *stubSpanFilter) BuildBasicSpanFilter(context.Context, *span_filter.SpanEnv) ([]*loop_span.FilterField, bool, error) { + return nil, true, nil +} + +func (s *stubSpanFilter) BuildRootSpanFilter(context.Context, *span_filter.SpanEnv) ([]*loop_span.FilterField, error) { + return nil, nil +} + +func (s *stubSpanFilter) BuildLLMSpanFilter(context.Context, *span_filter.SpanEnv) ([]*loop_span.FilterField, error) { + return nil, nil +} + +func (s *stubSpanFilter) BuildALLSpanFilter(context.Context, *span_filter.SpanEnv) ([]*loop_span.FilterField, error) { + return nil, nil +} + type stubBackfillProducer struct { ch chan *entity.BackFillEvent err error @@ -80,11 +127,11 @@ func (s *stubBackfillProducer) SendBackfill(ctx context.Context, message *entity return s.err } -func newTaskServiceWithProcessor(t *testing.T, repo taskrepo.ITaskRepo, userProvider rpc.IUserProvider, backfill componentmq.IBackfillProducer, proc taskexe.Processor, taskType task.TaskType) *TaskServiceImpl { +func newTaskServiceWithProcessor(t *testing.T, repo taskrepo.ITaskRepo, backfill componentmq.IBackfillProducer, proc taskexe.Processor, taskType entity.TaskType) *TaskServiceImpl { t.Helper() tp := processor.NewTaskProcessor() tp.Register(taskType, proc) - service, err := NewTaskServiceImpl(repo, userProvider, nil, backfill, tp) + service, err := NewTaskServiceImpl(repo, nil, backfill, tp, &stubTraceFilterBuilder{}) assert.NoError(t, err) return service.(*TaskServiceImpl) } @@ -108,13 +155,14 @@ func TestTaskServiceImpl_CreateTask(t *testing.T) { backfillCh := make(chan *entity.BackFillEvent, 1) backfill := &stubBackfillProducer{ch: backfillCh} - svc := newTaskServiceWithProcessor(t, repoMock, nil, backfill, proc, task.TaskTypeAutoEval) + svc := newTaskServiceWithProcessor(t, repoMock, backfill, proc, entity.TaskTypeAutoEval) reqTask := &entity.ObservabilityTask{ WorkspaceID: 123, Name: "task", TaskType: entity.TaskTypeAutoEval, TaskStatus: entity.TaskStatusUnstarted, + SpanFilter: &entity.SpanFilterFields{}, BackfillEffectiveTime: &entity.EffectiveTime{StartAt: time.Now().Add(time.Second).UnixMilli(), EndAt: time.Now().Add(2 * time.Second).UnixMilli()}, Sampler: &entity.Sampler{}, EffectiveTime: &entity.EffectiveTime{StartAt: time.Now().Add(time.Second).UnixMilli(), EndAt: time.Now().Add(2 * time.Second).UnixMilli()}, @@ -144,9 +192,9 @@ func TestTaskServiceImpl_CreateTask(t *testing.T) { repoMock.EXPECT().ListTasks(gomock.Any(), gomock.Any()).Return(nil, int64(0), nil) proc := &fakeProcessor{validateErr: errors.New("invalid config")} - svc := newTaskServiceWithProcessor(t, repoMock, nil, nil, proc, task.TaskTypeAutoEval) + svc := newTaskServiceWithProcessor(t, repoMock, nil, proc, entity.TaskTypeAutoEval) - reqTask := &entity.ObservabilityTask{WorkspaceID: 1, Name: "task", TaskType: entity.TaskTypeAutoEval, Sampler: &entity.Sampler{}, EffectiveTime: &entity.EffectiveTime{}} + reqTask := &entity.ObservabilityTask{WorkspaceID: 1, Name: "task", TaskType: entity.TaskTypeAutoEval, Sampler: &entity.Sampler{}, EffectiveTime: &entity.EffectiveTime{}, SpanFilter: &entity.SpanFilterFields{}} resp, err := svc.CreateTask(context.Background(), &CreateTaskReq{Task: reqTask}) assert.Nil(t, resp) assert.Error(t, err) @@ -165,8 +213,8 @@ func TestTaskServiceImpl_CreateTask(t *testing.T) { repoMock.EXPECT().ListTasks(gomock.Any(), gomock.Any()).Return([]*entity.ObservabilityTask{{}}, int64(1), nil) proc := &fakeProcessor{} - svc := newTaskServiceWithProcessor(t, repoMock, nil, nil, proc, task.TaskTypeAutoEval) - reqTask := &entity.ObservabilityTask{WorkspaceID: 1, Name: "task", TaskType: entity.TaskTypeAutoEval, Sampler: &entity.Sampler{}, EffectiveTime: &entity.EffectiveTime{}} + svc := newTaskServiceWithProcessor(t, repoMock, nil, proc, entity.TaskTypeAutoEval) + reqTask := &entity.ObservabilityTask{WorkspaceID: 1, Name: "task", TaskType: entity.TaskTypeAutoEval, Sampler: &entity.Sampler{}, EffectiveTime: &entity.EffectiveTime{}, SpanFilter: &entity.SpanFilterFields{}} resp, err := svc.CreateTask(context.Background(), &CreateTaskReq{Task: reqTask}) assert.Nil(t, resp) assert.Error(t, err) @@ -188,8 +236,8 @@ func TestTaskServiceImpl_CreateTask(t *testing.T) { repoMock.EXPECT().DeleteTask(gomock.Any(), gomock.AssignableToTypeOf(&entity.ObservabilityTask{})).Return(nil) proc := &fakeProcessor{onCreateErr: errors.New("hook fail")} - svc := newTaskServiceWithProcessor(t, repoMock, nil, nil, proc, task.TaskTypeAutoEval) - reqTask := &entity.ObservabilityTask{WorkspaceID: 1, Name: "task", TaskType: entity.TaskTypeAutoEval, Sampler: &entity.Sampler{}, EffectiveTime: &entity.EffectiveTime{}} + svc := newTaskServiceWithProcessor(t, repoMock, nil, proc, entity.TaskTypeAutoEval) + reqTask := &entity.ObservabilityTask{WorkspaceID: 1, Name: "task", TaskType: entity.TaskTypeAutoEval, Sampler: &entity.Sampler{}, EffectiveTime: &entity.EffectiveTime{}, SpanFilter: &entity.SpanFilterFields{}} resp, err := svc.CreateTask(context.Background(), &CreateTaskReq{Task: reqTask}) assert.Nil(t, resp) assert.EqualError(t, err, "hook fail") @@ -240,7 +288,7 @@ func TestTaskServiceImpl_UpdateTask(t *testing.T) { proc := &fakeProcessor{} svc := &TaskServiceImpl{TaskRepo: repoMock} - svc.taskProcessor.Register(task.TaskTypeAutoEval, proc) + svc.taskProcessor.Register(entity.TaskTypeAutoEval, proc) err := svc.UpdateTask(context.Background(), &UpdateTaskReq{TaskID: 1, WorkspaceID: 2}) statusErr, ok := errorx.FromStatusError(err) @@ -273,7 +321,7 @@ func TestTaskServiceImpl_UpdateTask(t *testing.T) { proc := &fakeProcessor{} svc := &TaskServiceImpl{TaskRepo: repoMock} - svc.taskProcessor.Register(task.TaskTypeAutoEval, proc) + svc.taskProcessor.Register(entity.TaskTypeAutoEval, proc) desc := "updated" newStart := startAt + 1000 @@ -283,13 +331,13 @@ func TestTaskServiceImpl_UpdateTask(t *testing.T) { TaskID: 1, WorkspaceID: 2, Description: &desc, - EffectiveTime: &task.EffectiveTime{StartAt: &newStart, EndAt: &newEnd}, + EffectiveTime: &entity.EffectiveTime{StartAt: newStart, EndAt: newEnd}, SampleRate: &sampleRate, - TaskStatus: gptr.Of(task.TaskStatusDisabled), + TaskStatus: gptr.Of(entity.TaskStatusDisabled), }) assert.NoError(t, err) assert.True(t, proc.onFinishRunCalled) - assert.Equal(t, task.TaskStatusDisabled, taskDO.TaskStatus) + assert.Equal(t, entity.TaskStatusDisabled, taskDO.TaskStatus) assert.Equal(t, "user1", taskDO.UpdatedBy) if assert.NotNil(t, taskDO.Description) { assert.Equal(t, desc, *taskDO.Description) @@ -319,14 +367,14 @@ func TestTaskServiceImpl_UpdateTask(t *testing.T) { proc := &fakeProcessor{} svc := &TaskServiceImpl{TaskRepo: repoMock} - svc.taskProcessor.Register(task.TaskTypeAutoEval, proc) + svc.taskProcessor.Register(entity.TaskTypeAutoEval, proc) sampleRate := 0.6 err := svc.UpdateTask(session.WithCtxUser(context.Background(), &session.User{ID: "user"}), &UpdateTaskReq{ TaskID: 1, WorkspaceID: 2, SampleRate: &sampleRate, - TaskStatus: gptr.Of(task.TaskStatusDisabled), + TaskStatus: gptr.Of(entity.TaskStatusDisabled), }) assert.NoError(t, err) assert.True(t, proc.onFinishRunCalled) @@ -352,7 +400,7 @@ func TestTaskServiceImpl_UpdateTask(t *testing.T) { proc := &fakeProcessor{onFinishRunErr: errors.New("finish fail")} svc := &TaskServiceImpl{TaskRepo: repoMock} - svc.taskProcessor.Register(task.TaskTypeAutoEval, proc) + svc.taskProcessor.Register(entity.TaskTypeAutoEval, proc) newStart := startAt + 1000 newEnd := startAt + 7200000 @@ -360,9 +408,9 @@ func TestTaskServiceImpl_UpdateTask(t *testing.T) { err := svc.UpdateTask(session.WithCtxUser(context.Background(), &session.User{ID: "user"}), &UpdateTaskReq{ TaskID: 1, WorkspaceID: 2, - EffectiveTime: &task.EffectiveTime{StartAt: &newStart, EndAt: &newEnd}, + EffectiveTime: &entity.EffectiveTime{StartAt: newStart, EndAt: newEnd}, SampleRate: &sampleRate, - TaskStatus: gptr.Of(task.TaskStatusDisabled), + TaskStatus: gptr.Of(entity.TaskStatusDisabled), }) assert.EqualError(t, err, "finish fail") }) @@ -391,7 +439,6 @@ func TestTaskServiceImpl_ListTasks(t *testing.T) { defer ctrl.Finish() repoMock := repomocks.NewMockITaskRepo(ctrl) - userMock := rpcmock.NewMockIUserProvider(ctrl) hiddenField := &loop_span.FilterField{FieldName: "hidden", Values: []string{"1"}, Hidden: true} visibleField := &loop_span.FilterField{FieldName: "visible", Values: []string{"val"}} @@ -414,25 +461,23 @@ func TestTaskServiceImpl_ListTasks(t *testing.T) { }}, } repoMock.EXPECT().ListTasks(gomock.Any(), gomock.Any()).Return([]*entity.ObservabilityTask{taskDO}, int64(1), nil) - userMock.EXPECT().GetUserInfo(gomock.Any(), gomock.Any()).Return(nil, map[string]*entitycommon.UserInfo{}, nil) - svc := &TaskServiceImpl{TaskRepo: repoMock, userProvider: userMock} + svc := &TaskServiceImpl{TaskRepo: repoMock} resp, err := svc.ListTasks(context.Background(), &ListTasksReq{WorkspaceID: 2, TaskFilters: &filter.TaskFilterFields{}}) assert.NoError(t, err) if assert.NotNil(t, resp) { - assert.EqualValues(t, 1, *resp.Total) + assert.EqualValues(t, 1, resp.Total) assert.Len(t, resp.Tasks, 1) - filterFields := resp.Tasks[0].GetRule().GetSpanFilters().GetFilters() - if assert.NotNil(t, filterFields) { - fields := filterFields.GetFilterFields() + task := resp.Tasks[0] + if assert.NotNil(t, task.SpanFilter) { + fields := task.SpanFilter.Filters.FilterFields assert.Len(t, fields, 2) - assert.Equal(t, "visible", fields[0].GetFieldName()) - assert.Equal(t, []string{"val"}, fields[0].GetValues()) - sub := fields[1].GetSubFilter() - if assert.NotNil(t, sub) { - subFields := sub.GetFilterFields() + assert.Equal(t, "visible", fields[0].FieldName) + assert.Equal(t, []string{"val"}, fields[0].Values) + if sub := fields[1].SubFilter; assert.NotNil(t, sub) { + subFields := sub.FilterFields assert.Len(t, subFields, 1) - assert.Equal(t, "child", subFields[0].GetFieldName()) + assert.Equal(t, "child", subFields[0].FieldName) } } } @@ -476,7 +521,6 @@ func TestTaskServiceImpl_GetTask(t *testing.T) { defer ctrl.Finish() repoMock := repomocks.NewMockITaskRepo(ctrl) - userMock := rpcmock.NewMockIUserProvider(ctrl) subHidden := &loop_span.FilterField{FieldName: "inner_hidden", Values: []string{"v"}, Hidden: true} subVisible := &loop_span.FilterField{FieldName: "inner_visible", Values: []string{"v"}} @@ -498,22 +542,20 @@ func TestTaskServiceImpl_GetTask(t *testing.T) { } repoMock.EXPECT().GetTask(gomock.Any(), int64(1), gomock.Any(), gomock.Nil()).Return(taskDO, nil) - userMock.EXPECT().GetUserInfo(gomock.Any(), gomock.Any()).Return(nil, map[string]*entitycommon.UserInfo{}, nil) - svc := &TaskServiceImpl{TaskRepo: repoMock, userProvider: userMock} + svc := &TaskServiceImpl{TaskRepo: repoMock} resp, err := svc.GetTask(context.Background(), &GetTaskReq{TaskID: 1, WorkspaceID: 2}) assert.NoError(t, err) if assert.NotNil(t, resp) { - filters := resp.Task.GetRule().GetSpanFilters().GetFilters() - if assert.NotNil(t, filters) { - fields := filters.GetFilterFields() + task := resp.Task + if assert.NotNil(t, task.SpanFilter) { + fields := task.SpanFilter.Filters.FilterFields assert.Len(t, fields, 2) - assert.Equal(t, "outer_visible", fields[0].GetFieldName()) - sub := fields[1].GetSubFilter() - if assert.NotNil(t, sub) { - subFields := sub.GetFilterFields() + assert.Equal(t, "outer_visible", fields[0].FieldName) + if sub := fields[1].SubFilter; assert.NotNil(t, sub) { + subFields := sub.FilterFields assert.Len(t, subFields, 1) - assert.Equal(t, "inner_visible", subFields[0].GetFieldName()) + assert.Equal(t, "inner_visible", subFields[0].FieldName) } } } diff --git a/backend/modules/observability/domain/task/service/taskexe/processor/auto_evaluate_test.go b/backend/modules/observability/domain/task/service/taskexe/processor/auto_evaluate_test.go index 1bf091cc9..fe2dab656 100755 --- a/backend/modules/observability/domain/task/service/taskexe/processor/auto_evaluate_test.go +++ b/backend/modules/observability/domain/task/service/taskexe/processor/auto_evaluate_test.go @@ -543,7 +543,7 @@ func TestAutoEvaluteProcessor_OnFinishTaskRunChange(t *testing.T) { }) assert.NoError(t, err) assert.NotNil(t, evalAdapter.finishReq) - assert.Equal(t, task.RunStatusDone, taskRun.RunStatus) + assert.Equal(t, taskentity.TaskRunStatusDone, taskRun.RunStatus) } func TestAutoEvaluteProcessor_OnFinishTaskChange(t *testing.T) { @@ -572,7 +572,7 @@ func TestAutoEvaluteProcessor_OnFinishTaskChange(t *testing.T) { IsFinish: true, }) assert.NoError(t, err) - assert.Equal(t, task.TaskStatusSuccess, taskObj.TaskStatus) + assert.Equal(t, taskentity.TaskStatusSuccess, taskObj.TaskStatus) } func TestAutoEvaluteProcessor_OnFinishTaskChange_Error(t *testing.T) { @@ -623,8 +623,8 @@ func TestAutoEvaluteProcessor_OnCreateTaskChange(t *testing.T) { taskObj := buildTestTask(t) taskObj.TaskStatus = taskentity.TaskStatusPending - var runTypes []task.TaskRunType - var statuses []task.TaskStatus + var runTypes []taskentity.TaskRunType + var statuses []taskentity.TaskStatus getBackfill := repoMock.EXPECT().GetBackfillTaskRun(gomock.Any(), (*int64)(nil), taskObj.ID).Return(nil, nil) createDatasetBackfill := datasetProvider.EXPECT().CreateDataset(gomock.Any(), gomock.AssignableToTypeOf(&traceentity.Dataset{})).Return(int64(9101), nil) @@ -668,9 +668,9 @@ func TestAutoEvaluteProcessor_OnCreateTaskChange(t *testing.T) { err := proc.OnCreateTaskChange(context.Background(), taskObj) assert.NoError(t, err) - assert.Equal(t, []task.TaskRunType{task.TaskRunTypeBackFill, task.TaskRunTypeNewData}, runTypes) - assert.Equal(t, []task.TaskStatus{task.TaskStatusRunning, task.TaskStatusRunning}, statuses) - assert.Equal(t, task.TaskStatusRunning, taskObj.TaskStatus) + assert.Equal(t, []taskentity.TaskRunType{taskentity.TaskRunTypeBackFill, taskentity.TaskRunTypeNewData}, runTypes) + assert.Equal(t, []taskentity.TaskStatus{taskentity.TaskStatusRunning, taskentity.TaskStatusRunning}, statuses) + assert.Equal(t, taskentity.TaskStatusRunning, taskObj.TaskStatus) } func TestAutoEvaluteProcessor_OnCreateTaskChange_GetBackfillError(t *testing.T) { diff --git a/backend/modules/observability/domain/task/service/taskexe/processor/factory.go b/backend/modules/observability/domain/task/service/taskexe/processor/factory.go index 131437f7f..9ccd32dc4 100644 --- a/backend/modules/observability/domain/task/service/taskexe/processor/factory.go +++ b/backend/modules/observability/domain/task/service/taskexe/processor/factory.go @@ -4,26 +4,26 @@ package processor import ( - "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/observability/domain/task" + "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/entity" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/service/taskexe" ) type TaskProcessor struct { - taskProcessorMap map[task.TaskType]taskexe.Processor + taskProcessorMap map[entity.TaskType]taskexe.Processor } func NewTaskProcessor() *TaskProcessor { return &TaskProcessor{} } -func (t *TaskProcessor) Register(taskType task.TaskType, taskProcessor taskexe.Processor) { +func (t *TaskProcessor) Register(taskType entity.TaskType, taskProcessor taskexe.Processor) { if t.taskProcessorMap == nil { - t.taskProcessorMap = make(map[task.TaskType]taskexe.Processor) + t.taskProcessorMap = make(map[entity.TaskType]taskexe.Processor) } t.taskProcessorMap[taskType] = taskProcessor } -func (t *TaskProcessor) GetTaskProcessor(taskType task.TaskType) taskexe.Processor { +func (t *TaskProcessor) GetTaskProcessor(taskType entity.TaskType) taskexe.Processor { datasetProvider, ok := t.taskProcessorMap[taskType] if !ok { return NewNoopTaskProcessor() diff --git a/backend/modules/observability/domain/task/service/taskexe/processor/factory_test.go b/backend/modules/observability/domain/task/service/taskexe/processor/factory_test.go index 132644984..466b997ba 100755 --- a/backend/modules/observability/domain/task/service/taskexe/processor/factory_test.go +++ b/backend/modules/observability/domain/task/service/taskexe/processor/factory_test.go @@ -10,6 +10,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/observability/domain/task" + "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/entity" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/service/taskexe" ) @@ -18,13 +19,13 @@ func TestTaskProcessor_RegisterAndGet(t *testing.T) { taskProcessor := NewTaskProcessor() - defaultProcessor := taskProcessor.GetTaskProcessor("unknown") + defaultProcessor := taskProcessor.GetTaskProcessor(entity.TaskType("unknown")) _, ok := defaultProcessor.(*NoopTaskProcessor) assert.True(t, ok) registered := NewNoopTaskProcessor() - taskProcessor.Register(task.TaskTypeAutoEval, registered) - assert.Equal(t, registered, taskProcessor.GetTaskProcessor(task.TaskTypeAutoEval)) + taskProcessor.Register(entity.TaskTypeAutoEval, registered) + assert.Equal(t, registered, taskProcessor.GetTaskProcessor(entity.TaskTypeAutoEval)) } func TestNoopTaskProcessor_Methods(t *testing.T) { diff --git a/backend/modules/observability/domain/task/service/taskexe/tracehub/backfill.go b/backend/modules/observability/domain/task/service/taskexe/tracehub/backfill.go index da17fa148..d80f2477c 100644 --- a/backend/modules/observability/domain/task/service/taskexe/tracehub/backfill.go +++ b/backend/modules/observability/domain/task/service/taskexe/tracehub/backfill.go @@ -117,7 +117,7 @@ func (h *TraceHubServiceImpl) setBackfillTask(ctx context.Context, event *entity return nil, err } taskRunDTO := tconv.TaskRunDO2DTO(ctx, taskRun, nil) - proc := h.taskProcessor.GetTaskProcessor(task.TaskType(taskConfig.TaskType)) + proc := h.taskProcessor.GetTaskProcessor(taskConfig.TaskType) sub := &spanSubscriber{ taskID: taskConfigDO.GetID(), t: taskConfigDO, @@ -368,7 +368,7 @@ func (h *TraceHubServiceImpl) doFlush(ctx context.Context, fr *flushReq, sub *sp if fr.noMore { logs.CtxInfo(ctx, "no more spans to process, task_id=%d", sub.t.GetID()) if err = sub.processor.OnFinishTaskChange(ctx, taskexe.OnFinishTaskChangeReq{ - Task: tconv.TaskDTO2DO(sub.t, "", nil), + Task: tconv.TaskDTO2DO(sub.t), TaskRun: tconv.TaskRunDTO2DO(sub.tr), IsFinish: false, }); err != nil { @@ -449,7 +449,7 @@ func (h *TraceHubServiceImpl) processBatchSpans(ctx context.Context, spans []*lo if taskCount+1 > sampler.GetSampleSize() { logs.CtxWarn(ctx, "taskCount+1 > sampler.GetSampleSize(), task_id=%d,SampleSize=%d", sub.taskID, sampler.GetSampleSize()) if err := sub.processor.OnFinishTaskChange(ctx, taskexe.OnFinishTaskChangeReq{ - Task: tconv.TaskDTO2DO(sub.t, "", nil), + Task: tconv.TaskDTO2DO(sub.t), TaskRun: tconv.TaskRunDTO2DO(sub.tr), IsFinish: true, }); err != nil { diff --git a/backend/modules/observability/domain/task/service/taskexe/tracehub/scheduled_task.go b/backend/modules/observability/domain/task/service/taskexe/tracehub/scheduled_task.go index 803a899dc..d9e1ffceb 100755 --- a/backend/modules/observability/domain/task/service/taskexe/tracehub/scheduled_task.go +++ b/backend/modules/observability/domain/task/service/taskexe/tracehub/scheduled_task.go @@ -123,7 +123,7 @@ func (h *TraceHubServiceImpl) transformTaskStatus() { endTime = time.UnixMilli(taskPO.EffectiveTime.EndAt) startTime = time.UnixMilli(taskPO.EffectiveTime.StartAt) } - proc := h.taskProcessor.GetTaskProcessor(task.TaskType(taskPO.TaskType)) + proc := h.taskProcessor.GetTaskProcessor(taskPO.TaskType) // Task time horizon reached // End when the task end time is reached logs.CtxInfo(ctx, "[auto_task]taskID:%d, endTime:%v, startTime:%v", taskPO.ID, endTime, startTime) diff --git a/backend/modules/observability/domain/task/service/taskexe/tracehub/scheduled_task_test.go b/backend/modules/observability/domain/task/service/taskexe/tracehub/scheduled_task_test.go index 1ac5e587c..b97f116dd 100755 --- a/backend/modules/observability/domain/task/service/taskexe/tracehub/scheduled_task_test.go +++ b/backend/modules/observability/domain/task/service/taskexe/tracehub/scheduled_task_test.go @@ -63,23 +63,23 @@ func TestTraceHubServiceImpl_transformTaskStatus(t *testing.T) { backfillRun := &entity.TaskRun{ ID: 2, TaskID: 1, - TaskType: string(task.TaskRunTypeBackFill), - RunStatus: string(task.RunStatusDone), + TaskType: entity.TaskRunTypeBackFill, + RunStatus: entity.TaskRunStatusDone, RunStartAt: now.Add(-3 * time.Hour), RunEndAt: now.Add(-2 * time.Hour), } currentRun := &entity.TaskRun{ ID: 3, TaskID: 1, - TaskType: string(task.TaskRunTypeNewData), - RunStatus: string(task.TaskStatusRunning), + TaskType: entity.TaskRunTypeNewData, + RunStatus: entity.TaskRunStatusRunning, RunStartAt: now.Add(-4 * time.Hour), RunEndAt: now.Add(2 * time.Hour), } taskPO := &entity.ObservabilityTask{ ID: 1, - TaskType: string(task.TaskTypeAutoEval), - TaskStatus: string(task.TaskStatusRunning), + TaskType: entity.TaskTypeAutoEval, + TaskStatus: entity.TaskStatusRunning, EffectiveTime: &entity.EffectiveTime{ StartAt: now.Add(-5 * time.Hour).UnixMilli(), EndAt: now.Add(-1 * time.Hour).UnixMilli(), @@ -95,7 +95,7 @@ func TestTraceHubServiceImpl_transformTaskStatus(t *testing.T) { proc := newTrackingProcessor() tp := processor.NewTaskProcessor() - tp.Register(task.TaskTypeAutoEval, proc) + tp.Register(entity.TaskTypeAutoEval, proc) impl := &TraceHubServiceImpl{ taskRepo: mockRepo, @@ -117,8 +117,8 @@ func TestTraceHubServiceImpl_transformTaskStatus(t *testing.T) { now := time.Now() taskPO := &entity.ObservabilityTask{ ID: 10, - TaskType: string(task.TaskTypeAutoEval), - TaskStatus: string(task.TaskStatusUnstarted), + TaskType: entity.TaskTypeAutoEval, + TaskStatus: entity.TaskStatusUnstarted, EffectiveTime: &entity.EffectiveTime{ StartAt: now.Add(-2 * time.Hour).UnixMilli(), EndAt: now.Add(time.Hour).UnixMilli(), @@ -129,7 +129,7 @@ func TestTraceHubServiceImpl_transformTaskStatus(t *testing.T) { proc := newTrackingProcessor() tp := processor.NewTaskProcessor() - tp.Register(task.TaskTypeAutoEval, proc) + tp.Register(entity.TaskTypeAutoEval, proc) impl := &TraceHubServiceImpl{ taskRepo: mockRepo, @@ -153,15 +153,15 @@ func TestTraceHubServiceImpl_transformTaskStatus(t *testing.T) { currentRun := &entity.TaskRun{ ID: 30, TaskID: 20, - TaskType: string(task.TaskRunTypeNewData), - RunStatus: string(task.TaskStatusRunning), + TaskType: entity.TaskRunTypeNewData, + RunStatus: entity.TaskRunStatusRunning, RunStartAt: now.Add(-2 * time.Hour), RunEndAt: now.Add(-time.Minute), } taskPO := &entity.ObservabilityTask{ ID: 20, - TaskType: string(task.TaskTypeAutoEval), - TaskStatus: string(task.TaskStatusRunning), + TaskType: entity.TaskTypeAutoEval, + TaskStatus: entity.TaskStatusRunning, Sampler: &entity.Sampler{IsCycle: true}, TaskRuns: []*entity.TaskRun{currentRun}, } @@ -169,7 +169,7 @@ func TestTraceHubServiceImpl_transformTaskStatus(t *testing.T) { proc := newTrackingProcessor() tp := processor.NewTaskProcessor() - tp.Register(task.TaskTypeAutoEval, proc) + tp.Register(entity.TaskTypeAutoEval, proc) impl := &TraceHubServiceImpl{ taskRepo: mockRepo, @@ -194,24 +194,24 @@ func TestTraceHubServiceImpl_transformTaskStatus(t *testing.T) { backfillRun := &entity.TaskRun{ ID: 40, TaskID: 40, - TaskType: string(task.TaskRunTypeBackFill), - RunStatus: string(task.TaskStatusRunning), + TaskType: entity.TaskRunTypeBackFill, + RunStatus: entity.TaskRunStatusRunning, RunStartAt: now.Add(-time.Hour), RunEndAt: now.Add(time.Hour), } currentRun := &entity.TaskRun{ ID: 41, TaskID: 40, - TaskType: string(task.TaskRunTypeNewData), - RunStatus: string(task.TaskStatusRunning), + TaskType: entity.TaskRunTypeNewData, + RunStatus: entity.TaskRunStatusRunning, RunStartAt: now.Add(-time.Hour), RunEndAt: now.Add(time.Hour), } taskPO := &entity.ObservabilityTask{ ID: 40, WorkspaceID: 99, - TaskType: string(task.TaskTypeAutoEval), - TaskStatus: string(task.TaskStatusRunning), + TaskType: entity.TaskTypeAutoEval, + TaskStatus: entity.TaskStatusRunning, BackfillEffectiveTime: &entity.EffectiveTime{StartAt: now.Add(-2 * time.Hour).UnixMilli(), EndAt: now.Add(time.Hour).UnixMilli()}, Sampler: &entity.Sampler{IsCycle: false}, TaskRuns: []*entity.TaskRun{backfillRun, currentRun}, @@ -223,7 +223,7 @@ func TestTraceHubServiceImpl_transformTaskStatus(t *testing.T) { proc := newTrackingProcessor() tp := processor.NewTaskProcessor() - tp.Register(task.TaskTypeAutoEval, proc) + tp.Register(entity.TaskTypeAutoEval, proc) producer := &stubBackfillProducer{ch: make(chan *entity.BackFillEvent, 1)} impl := &TraceHubServiceImpl{ diff --git a/backend/modules/observability/domain/task/service/taskexe/tracehub/span_trigger.go b/backend/modules/observability/domain/task/service/taskexe/tracehub/span_trigger.go index 21374e027..94918b15b 100644 --- a/backend/modules/observability/domain/task/service/taskexe/tracehub/span_trigger.go +++ b/backend/modules/observability/domain/task/service/taskexe/tracehub/span_trigger.go @@ -86,7 +86,7 @@ func (h *TraceHubServiceImpl) getSubscriberOfSpan(ctx context.Context, span *loo if !cfg.IsAllSpace && !gslice.Contains(cfg.SpaceList, taskDO.GetWorkspaceID()) { continue } - proc := h.taskProcessor.GetTaskProcessor(taskDO.TaskType) + proc := h.taskProcessor.GetTaskProcessor(entity.TaskType(taskDO.TaskType)) subscribers = append(subscribers, &spanSubscriber{ taskID: taskDO.GetID(), RWMutex: sync.RWMutex{}, @@ -153,7 +153,7 @@ func (h *TraceHubServiceImpl) preDispatch(ctx context.Context, span *loop_span.S merr = multierror.Append(merr, errors.WithMessagef(err, "task is unstarted, need sub.Creative,creative processor, task_id=%d", sub.taskID)) continue } - if err := sub.processor.OnUpdateTaskChange(ctx, tconv.TaskDTO2DO(sub.t, "", nil), task.TaskStatusRunning); err != nil { + if err := sub.processor.OnUpdateTaskChange(ctx, tconv.TaskDTO2DO(sub.t), task.TaskStatusRunning); err != nil { logs.CtxWarn(ctx, "OnUpdateTaskChange, task_id=%d, err=%v", sub.taskID, err) continue } @@ -194,7 +194,7 @@ func (h *TraceHubServiceImpl) preDispatch(ctx context.Context, span *loop_span.S if time.Now().After(endTime) { logs.CtxWarn(ctx, "[OnFinishTaskChange]time.Now().After(endTime) Finish processor, task_id=%d, endTime=%v, now=%v", sub.taskID, endTime, time.Now()) if err := sub.processor.OnFinishTaskChange(ctx, taskexe.OnFinishTaskChangeReq{ - Task: tconv.TaskDTO2DO(sub.t, "", nil), + Task: tconv.TaskDTO2DO(sub.t), TaskRun: taskRunConfig, IsFinish: true, }); err != nil { @@ -207,7 +207,7 @@ func (h *TraceHubServiceImpl) preDispatch(ctx context.Context, span *loop_span.S if taskCount+1 > sampler.GetSampleSize() { logs.CtxWarn(ctx, "[OnFinishTaskChange]taskCount+1 > sampler.GetSampleSize() Finish processor, task_id=%d", sub.taskID) if err := sub.processor.OnFinishTaskChange(ctx, taskexe.OnFinishTaskChangeReq{ - Task: tconv.TaskDTO2DO(sub.t, "", nil), + Task: tconv.TaskDTO2DO(sub.t), TaskRun: taskRunConfig, IsFinish: true, }); err != nil { @@ -221,7 +221,7 @@ func (h *TraceHubServiceImpl) preDispatch(ctx context.Context, span *loop_span.S if time.Now().After(cycleEndTime) { logs.CtxInfo(ctx, "[OnFinishTaskChange]time.Now().After(cycleEndTime) Finish processor, task_id=%d", sub.taskID) if err := sub.processor.OnFinishTaskChange(ctx, taskexe.OnFinishTaskChangeReq{ - Task: tconv.TaskDTO2DO(sub.t, "", nil), + Task: tconv.TaskDTO2DO(sub.t), TaskRun: taskRunConfig, IsFinish: false, }); err != nil { @@ -239,7 +239,7 @@ func (h *TraceHubServiceImpl) preDispatch(ctx context.Context, span *loop_span.S if taskRunCount+1 > sampler.GetCycleCount() { logs.CtxWarn(ctx, "[OnFinishTaskChange]taskRunCount+1 > sampler.GetCycleCount(), task_id=%d", sub.taskID) if err := sub.processor.OnFinishTaskChange(ctx, taskexe.OnFinishTaskChangeReq{ - Task: tconv.TaskDTO2DO(sub.t, "", nil), + Task: tconv.TaskDTO2DO(sub.t), TaskRun: taskRunConfig, IsFinish: false, }); err != nil { From 5c31325423a409cb003a346d6e518160a89b516c Mon Sep 17 00:00:00 2001 From: taoyifan89 Date: Thu, 30 Oct 2025 12:33:25 +0800 Subject: [PATCH 03/19] test: [Coda] add getNonFinalTaskInfos tests (LogID: 20251030122712010091115089880C81B) Co-Authored-By: Coda --- .../application/convertor/page.go | 30 +++++ .../application/convertor/task/filter.go | 103 ++++++++++++++++++ .../domain/task/entity/filter.go | 69 ++++++++++++ .../taskexe/tracehub/scheduled_task_test.go | 86 +++++++++++++++ .../domain/trace/entity/common/page.go | 4 + 5 files changed, 292 insertions(+) create mode 100755 backend/modules/observability/application/convertor/page.go create mode 100755 backend/modules/observability/application/convertor/task/filter.go create mode 100755 backend/modules/observability/domain/task/entity/filter.go create mode 100644 backend/modules/observability/domain/trace/entity/common/page.go diff --git a/backend/modules/observability/application/convertor/page.go b/backend/modules/observability/application/convertor/page.go new file mode 100755 index 000000000..86e03821a --- /dev/null +++ b/backend/modules/observability/application/convertor/page.go @@ -0,0 +1,30 @@ +// Copyright (c) 2025 coze-dev Authors +// SPDX-License-Identifier: Apache-2.0 + +package convertor + +import ( + kitcommon "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/observability/domain/common" + tracecommon "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/entity/common" + "github.com/coze-dev/coze-loop/backend/pkg/lang/ptr" +) + +func OrderByDTO2DO(orderBy *kitcommon.OrderBy) *tracecommon.OrderBy { + if orderBy == nil { + return nil + } + return &tracecommon.OrderBy{ + Field: orderBy.GetField(), + IsAsc: orderBy.GetIsAsc(), + } +} + +func OrderByDO2DTO(orderBy *tracecommon.OrderBy) *kitcommon.OrderBy { + if orderBy == nil { + return nil + } + return &kitcommon.OrderBy{ + Field: ptr.Of(orderBy.Field), + IsAsc: ptr.Of(orderBy.IsAsc), + } +} diff --git a/backend/modules/observability/application/convertor/task/filter.go b/backend/modules/observability/application/convertor/task/filter.go new file mode 100755 index 000000000..a22c0efb2 --- /dev/null +++ b/backend/modules/observability/application/convertor/task/filter.go @@ -0,0 +1,103 @@ +// Copyright (c) 2025 coze-dev Authors +// SPDX-License-Identifier: Apache-2.0 + +package task + +import ( + "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/observability/domain/filter" + "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/entity" + "github.com/coze-dev/coze-loop/backend/pkg/lang/ptr" +) + +func TaskFiltersDTO2DO(filters *filter.TaskFilterFields) *entity.TaskFilterFields { + if filters == nil { + return nil + } + result := &entity.TaskFilterFields{} + if filters.QueryAndOr != nil { + relation := entity.QueryRelation(*filters.QueryAndOr) + result.QueryAndOr = &relation + } + if len(filters.FilterFields) == 0 { + return result + } + result.FilterFields = make([]*entity.TaskFilterField, 0, len(filters.FilterFields)) + for _, field := range filters.FilterFields { + if field == nil { + continue + } + result.FilterFields = append(result.FilterFields, taskFilterFieldDTO2DO(field)) + } + return result +} + +func taskFilterFieldDTO2DO(field *filter.TaskFilterField) *entity.TaskFilterField { + if field == nil { + return nil + } + result := &entity.TaskFilterField{ + Values: append([]string(nil), field.Values...), + SubFilter: taskFilterFieldDTO2DO(field.SubFilter), + } + if field.FieldName != nil { + name := entity.TaskFieldName(*field.FieldName) + result.FieldName = &name + } + if field.FieldType != nil { + fieldType := entity.FieldType(*field.FieldType) + result.FieldType = &fieldType + } + if field.QueryType != nil { + queryType := entity.QueryType(*field.QueryType) + result.QueryType = &queryType + } + if field.QueryAndOr != nil { + relation := entity.QueryRelation(*field.QueryAndOr) + result.QueryAndOr = &relation + } + return result +} + +func TaskFiltersDO2DTO(filters *entity.TaskFilterFields) *filter.TaskFilterFields { + if filters == nil { + return nil + } + result := &filter.TaskFilterFields{} + if filters.QueryAndOr != nil { + result.QueryAndOr = ptr.Of(filter.QueryRelation(*filters.QueryAndOr)) + } + if len(filters.FilterFields) == 0 { + return result + } + result.FilterFields = make([]*filter.TaskFilterField, 0, len(filters.FilterFields)) + for _, field := range filters.FilterFields { + if field == nil { + continue + } + result.FilterFields = append(result.FilterFields, taskFilterFieldDO2DTO(field)) + } + return result +} + +func taskFilterFieldDO2DTO(field *entity.TaskFilterField) *filter.TaskFilterField { + if field == nil { + return nil + } + result := &filter.TaskFilterField{ + Values: append([]string(nil), field.Values...), + SubFilter: taskFilterFieldDO2DTO(field.SubFilter), + } + if field.FieldName != nil { + result.FieldName = ptr.Of(string(*field.FieldName)) + } + if field.FieldType != nil { + result.FieldType = ptr.Of(filter.FieldType(*field.FieldType)) + } + if field.QueryType != nil { + result.QueryType = ptr.Of(filter.QueryType(*field.QueryType)) + } + if field.QueryAndOr != nil { + result.QueryAndOr = ptr.Of(filter.QueryRelation(*field.QueryAndOr)) + } + return result +} diff --git a/backend/modules/observability/domain/task/entity/filter.go b/backend/modules/observability/domain/task/entity/filter.go new file mode 100755 index 000000000..23f515597 --- /dev/null +++ b/backend/modules/observability/domain/task/entity/filter.go @@ -0,0 +1,69 @@ +// Copyright (c) 2025 coze-dev Authors +// SPDX-License-Identifier: Apache-2.0 + +package entity + +// QueryType represents the operator applied to filter values. +type QueryType string + +// QueryRelation represents the logical relation between multiple filter expressions. +type QueryRelation string + +// FieldType describes the type of a field used in filter expressions. +type FieldType string + +// TaskFieldName defines the supported task field names for filtering. +type TaskFieldName string + +const ( + QueryTypeMatch QueryType = "match" + QueryTypeEq QueryType = "eq" + QueryTypeNotEq QueryType = "not_eq" + QueryTypeLte QueryType = "lte" + QueryTypeGte QueryType = "gte" + QueryTypeLt QueryType = "lt" + QueryTypeGt QueryType = "gt" + QueryTypeExist QueryType = "exist" + QueryTypeNotExist QueryType = "not_exist" + QueryTypeIn QueryType = "in" + QueryTypeNotIn QueryType = "not_in" + QueryTypeNotMatch QueryType = "not_match" + + QueryRelationAnd QueryRelation = "and" + QueryRelationOr QueryRelation = "or" + + FieldTypeString FieldType = "string" + FieldTypeLong FieldType = "long" + FieldTypeDouble FieldType = "double" + FieldTypeBool FieldType = "bool" + + TaskFieldNameTaskStatus TaskFieldName = "task_status" + TaskFieldNameTaskName TaskFieldName = "task_name" + TaskFieldNameTaskType TaskFieldName = "task_type" + TaskFieldNameSampleRate TaskFieldName = "sample_rate" + TaskFieldNameCreatedBy TaskFieldName = "created_by" +) + +// TaskFilterFields aggregates multiple TaskFilterField expressions. +type TaskFilterFields struct { + QueryAndOr *QueryRelation + FilterFields []*TaskFilterField +} + +// GetQueryAndOr returns the relation between filter expressions. +func (f *TaskFilterFields) GetQueryAndOr() string { + if f == nil || f.QueryAndOr == nil { + return string(QueryRelationAnd) + } + return string(*f.QueryAndOr) +} + +// TaskFilterField describes a single filter clause. +type TaskFilterField struct { + FieldName *TaskFieldName + FieldType *FieldType + Values []string + QueryType *QueryType + QueryAndOr *QueryRelation + SubFilter *TaskFilterField +} diff --git a/backend/modules/observability/domain/task/service/taskexe/tracehub/scheduled_task_test.go b/backend/modules/observability/domain/task/service/taskexe/tracehub/scheduled_task_test.go index b97f116dd..4084ad504 100755 --- a/backend/modules/observability/domain/task/service/taskexe/tracehub/scheduled_task_test.go +++ b/backend/modules/observability/domain/task/service/taskexe/tracehub/scheduled_task_test.go @@ -18,6 +18,7 @@ import ( repo_mocks "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/repo/mocks" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/service/taskexe" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/service/taskexe/processor" + "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/entity/loop_span" "github.com/stretchr/testify/require" ) @@ -451,3 +452,88 @@ func TestTraceHubServiceImpl_listNonFinalTask(t *testing.T) { require.Nil(t, tasks) }) } + +func TestTraceHubServiceImpl_getNonFinalTaskInfos(t *testing.T) { + t.Parallel() + + t.Run("success", func(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + + mockRepo := repo_mocks.NewMockITaskRepo(ctrl) + impl := &TraceHubServiceImpl{taskRepo: mockRepo} + + tasks := []*entity.ObservabilityTask{ + { + WorkspaceID: 101, + SpanFilter: &entity.SpanFilterFields{ + Filters: loop_span.FilterFields{ + FilterFields: []*loop_span.FilterField{ + { + FieldName: "bot_id", + Values: []string{"bot-a", "bot-b"}, + }, + { + FieldName: "ignored", + SubFilter: &loop_span.FilterFields{ + FilterFields: []*loop_span.FilterField{ + { + FieldName: "bot_id", + Values: []string{"bot-c"}, + }, + }, + }, + }, + }, + }, + }, + }, + { + WorkspaceID: 202, + SpanFilter: &entity.SpanFilterFields{ + Filters: loop_span.FilterFields{ + FilterFields: []*loop_span.FilterField{ + { + FieldName: "other", + Values: []string{"value"}, + }, + }, + }, + }, + }, + { + WorkspaceID: 101, + }, + } + + mockRepo.EXPECT().ListNonFinalTasks(gomock.Any()).Return(tasks, nil) + + workspaceIDs, botIDs, resultTasks, err := impl.getNonFinalTaskInfos(context.Background()) + require.NoError(t, err) + require.ElementsMatch(t, []string{"101", "202"}, workspaceIDs) + require.ElementsMatch(t, []string{"bot-a", "bot-b", "bot-c"}, botIDs) + require.Equal(t, tasks, resultTasks) + }) + + t.Run("repo error", func(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + + mockRepo := repo_mocks.NewMockITaskRepo(ctrl) + impl := &TraceHubServiceImpl{taskRepo: mockRepo} + + expectErr := errors.New("repo err") + mockRepo.EXPECT().ListNonFinalTasks(gomock.Any()).Return(nil, expectErr) + + workspaceIDs, botIDs, tasks, err := impl.getNonFinalTaskInfos(context.Background()) + require.Error(t, err) + require.ErrorIs(t, err, expectErr) + require.Nil(t, workspaceIDs) + require.Nil(t, botIDs) + require.Nil(t, tasks) + }) +} diff --git a/backend/modules/observability/domain/trace/entity/common/page.go b/backend/modules/observability/domain/trace/entity/common/page.go new file mode 100644 index 000000000..69f2bf879 --- /dev/null +++ b/backend/modules/observability/domain/trace/entity/common/page.go @@ -0,0 +1,4 @@ +// Copyright (c) 2025 coze-dev Authors +// SPDX-License-Identifier: Apache-2.0 + +package common From 422c4ee629969ffa34bc88e54ee944f331f15c48 Mon Sep 17 00:00:00 2001 From: taoyifan89 Date: Thu, 30 Oct 2025 17:27:02 +0800 Subject: [PATCH 04/19] Refactor task. Change-Id: Ieb1b49ae6d6e5d977dbaa995609a60e703aaef0e --- .../application/convertor/page.go | 12 +- .../application/convertor/task/task.go | 110 +------ .../application/convertor/task/task_test.go | 158 +--------- .../modules/observability/application/task.go | 113 +++---- .../observability/application/task_test.go | 223 ------------- .../modules/observability/application/wire.go | 6 +- .../observability/application/wire_gen.go | 30 +- .../observability/domain/task/entity/task.go | 113 ++++++- .../domain/task/entity/task_test.go | 84 ++++- .../domain/task/repo/mocks/Task.go | 294 +++++++++--------- .../observability/domain/task/repo/task.go | 23 +- .../domain/task/service/task_service.go | 53 +--- .../domain/task/service/task_service_test.go | 7 +- .../taskexe/processor/auto_evaluate.go | 18 +- .../taskexe/processor/auto_evaluate_test.go | 14 +- .../task/service/taskexe/processor/noop.go | 7 +- .../service/taskexe/processor/utils_test.go | 14 +- .../task/service/taskexe/tracehub/backfill.go | 2 +- .../service/taskexe/tracehub/backfill_test.go | 10 +- .../taskexe/tracehub/scheduled_task.go | 103 ++++-- .../taskexe/tracehub/scheduled_task_test.go | 10 +- .../service/taskexe/tracehub/span_trigger.go | 4 +- .../taskexe/tracehub/span_trigger_test.go | 53 ++-- .../service/taskexe/tracehub/subscriber.go | 8 +- .../taskexe/tracehub/test_helpers_test.go | 2 +- .../domain/task/service/taskexe/types.go | 23 +- .../domain/trace/entity/common/page.go | 5 + .../domain/trace/entity/loop_span/filter.go | 9 + .../domain/trace/service/trace_service.go | 13 +- .../trace/service/trace_service_test.go | 6 +- .../observability/infra/repo/mysql/task.go | 117 +++---- .../infra/repo/mysql/task_run.go | 12 +- .../infra/repo/redis/{dao => }/task.go | 2 +- .../infra/repo/redis/{dao => }/task_run.go | 2 +- .../modules/observability/infra/repo/task.go | 77 ++--- .../observability/infra/repo/task_test.go | 35 +-- 36 files changed, 713 insertions(+), 1059 deletions(-) rename backend/modules/observability/infra/repo/redis/{dao => }/task.go (99%) rename backend/modules/observability/infra/repo/redis/{dao => }/task_run.go (99%) diff --git a/backend/modules/observability/application/convertor/page.go b/backend/modules/observability/application/convertor/page.go index 86e03821a..8fe8e09dd 100755 --- a/backend/modules/observability/application/convertor/page.go +++ b/backend/modules/observability/application/convertor/page.go @@ -4,26 +4,26 @@ package convertor import ( - kitcommon "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/observability/domain/common" - tracecommon "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/entity/common" + "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/observability/domain/common" + entity "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/entity/common" "github.com/coze-dev/coze-loop/backend/pkg/lang/ptr" ) -func OrderByDTO2DO(orderBy *kitcommon.OrderBy) *tracecommon.OrderBy { +func OrderByDTO2DO(orderBy *common.OrderBy) *entity.OrderBy { if orderBy == nil { return nil } - return &tracecommon.OrderBy{ + return &entity.OrderBy{ Field: orderBy.GetField(), IsAsc: orderBy.GetIsAsc(), } } -func OrderByDO2DTO(orderBy *tracecommon.OrderBy) *kitcommon.OrderBy { +func OrderByDO2DTO(orderBy *entity.OrderBy) *common.OrderBy { if orderBy == nil { return nil } - return &kitcommon.OrderBy{ + return &common.OrderBy{ Field: ptr.Of(orderBy.Field), IsAsc: ptr.Of(orderBy.IsAsc), } diff --git a/backend/modules/observability/application/convertor/task/task.go b/backend/modules/observability/application/convertor/task/task.go index 9a5ad92ea..c7c7555fe 100644 --- a/backend/modules/observability/application/convertor/task/task.go +++ b/backend/modules/observability/application/convertor/task/task.go @@ -18,11 +18,11 @@ import ( "github.com/coze-dev/coze-loop/backend/modules/observability/application/convertor" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/entity" entity_common "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/entity/common" - obErrorx "github.com/coze-dev/coze-loop/backend/modules/observability/pkg/errno" - "github.com/coze-dev/coze-loop/backend/pkg/errorx" + "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/entity/loop_span" "github.com/coze-dev/coze-loop/backend/pkg/lang/ptr" "github.com/coze-dev/coze-loop/backend/pkg/lang/slices" "github.com/coze-dev/coze-loop/backend/pkg/logs" + "github.com/samber/lo" ) func TaskDOs2DTOs(ctx context.Context, taskPOs []*entity.ObservabilityTask, userInfos map[string]*entity_common.UserInfo) []*task.Task { @@ -177,8 +177,8 @@ func SpanFilterDO2DTO(spanFilter *entity.SpanFilterFields) *filter.SpanFilterFie return &filter.SpanFilterFields{ Filters: convertor.FilterFieldsDO2DTO(&spanFilter.Filters), - PlatformType: &spanFilter.PlatformType, - SpanListType: &spanFilter.SpanListType, + PlatformType: lo.ToPtr(common.PlatformType(spanFilter.PlatformType)), + SpanListType: lo.ToPtr(common.SpanListType(spanFilter.SpanListType)), } } @@ -305,7 +305,7 @@ func UserInfoPO2DO(userInfo *entity_common.UserInfo, userID string) *common.User } } -func TaskDTO2DO(taskDTO *task.Task, userID string, spanFilters *entity.SpanFilterFields) *entity.ObservabilityTask { +func TaskDTO2DO(taskDTO *task.Task) *entity.ObservabilityTask { if taskDTO == nil { return nil } @@ -316,23 +316,8 @@ func TaskDTO2DO(taskDTO *task.Task, userID string, spanFilters *entity.SpanFilte if taskDTO.GetBaseInfo().GetUpdatedBy() != nil { updatedBy = taskDTO.GetBaseInfo().GetUpdatedBy().GetUserID() } - if userID != "" { - createdBy = userID - updatedBy = userID - } else { - if taskDTO.GetBaseInfo().GetCreatedBy() != nil { - createdBy = taskDTO.GetBaseInfo().GetCreatedBy().GetUserID() - } - if taskDTO.GetBaseInfo().GetUpdatedBy() != nil { - updatedBy = taskDTO.GetBaseInfo().GetUpdatedBy().GetUserID() - } - } - var spanFilterDO *entity.SpanFilterFields - if spanFilters != nil { - spanFilterDO = spanFilters - } else { - spanFilterDO = SpanFilterDTO2DO(taskDTO.GetRule().GetSpanFilters()) - } + + spanFilterDO := SpanFilterDTO2DO(taskDTO.GetRule().GetSpanFilters()) return &entity.ObservabilityTask{ ID: taskDTO.GetID(), @@ -359,8 +344,8 @@ func SpanFilterDTO2DO(spanFilterFields *filter.SpanFilterFields) *entity.SpanFil return nil } return &entity.SpanFilterFields{ - PlatformType: *spanFilterFields.PlatformType, - SpanListType: *spanFilterFields.SpanListType, + PlatformType: loop_span.PlatformType(*spanFilterFields.PlatformType), + SpanListType: loop_span.SpanListType(*spanFilterFields.SpanListType), Filters: *convertor.FilterFieldsDTO2DO(spanFilterFields.Filters), } } @@ -408,6 +393,7 @@ func TaskConfigDTO2DO(taskConfig *task.TaskConfig) *entity.TaskConfig { for _, autoEvaluateConfig := range taskConfig.AutoEvaluateConfigs { var fieldMappings []*entity.EvaluateFieldMapping if len(autoEvaluateConfig.FieldMappings) > 0 { + // todo tyf 这段逻辑挪到service层 var evalSetNames []string jspnPathMapping := make(map[string]string) for _, config := range autoEvaluateConfig.FieldMappings { @@ -531,82 +517,6 @@ func BackfillRunDetailDTO2DO(v *task.BackfillDetail) *entity.BackfillDetail { } } -func CheckEffectiveTime(ctx context.Context, effectiveTime *task.EffectiveTime, taskStatus task.TaskStatus, effectiveTimeDO *entity.EffectiveTime) (*entity.EffectiveTime, error) { - if effectiveTimeDO == nil { - logs.CtxError(ctx, "EffectiveTimePO2DO error") - return nil, errorx.NewByCode(obErrorx.CommercialCommonInvalidParamCodeCode, errorx.WithExtraMsg("effective time is nil")) - } - var validEffectiveTime entity.EffectiveTime - // 开始时间不能大于结束时间 - if effectiveTime.GetStartAt() >= effectiveTime.GetEndAt() { - logs.CtxError(ctx, "Start time must be less than end time") - return nil, errorx.NewByCode(obErrorx.CommercialCommonInvalidParamCodeCode, errorx.WithExtraMsg("start time must be less than end time")) - } - // 开始、结束时间不能小于当前时间 - if effectiveTimeDO.StartAt != effectiveTime.GetStartAt() && effectiveTime.GetStartAt() < time.Now().UnixMilli() { - logs.CtxError(ctx, "update time must be greater than current time") - return nil, errorx.NewByCode(obErrorx.CommercialCommonInvalidParamCodeCode, errorx.WithExtraMsg("start time must be greater than current time")) - } - if effectiveTimeDO.EndAt != effectiveTime.GetEndAt() && effectiveTime.GetEndAt() < time.Now().UnixMilli() { - logs.CtxError(ctx, "update time must be greater than current time") - return nil, errorx.NewByCode(obErrorx.CommercialCommonInvalidParamCodeCode, errorx.WithExtraMsg("start time must be greater than current time")) - } - validEffectiveTime.StartAt = effectiveTimeDO.StartAt - validEffectiveTime.EndAt = effectiveTimeDO.EndAt - switch taskStatus { - case task.TaskStatusUnstarted: - if validEffectiveTime.StartAt != 0 { - validEffectiveTime.StartAt = *effectiveTime.StartAt - } - if validEffectiveTime.EndAt != 0 { - validEffectiveTime.EndAt = *effectiveTime.EndAt - } - case task.TaskStatusRunning, task.TaskStatusPending: - if validEffectiveTime.EndAt != 0 { - validEffectiveTime.EndAt = *effectiveTime.EndAt - } - default: - logs.CtxError(ctx, "Invalid task status:%s", taskStatus) - return nil, errorx.NewByCode(obErrorx.CommercialCommonInvalidParamCodeCode, errorx.WithExtraMsg("invalid task status")) - } - return &validEffectiveTime, nil -} - -func CheckTaskStatus(ctx context.Context, taskStatus task.TaskStatus, currentTaskStatus task.TaskStatus) (task.TaskStatus, error) { - var validTaskStatus task.TaskStatus - // [0530]todo: 任务状态校验 - switch taskStatus { - case task.TaskStatusUnstarted: - if currentTaskStatus == task.TaskStatusUnstarted { - validTaskStatus = taskStatus - } else { - logs.CtxError(ctx, "Invalid task status:%s", taskStatus) - return "", errorx.NewByCode(obErrorx.CommercialCommonInvalidParamCodeCode, errorx.WithExtraMsg("invalid task status")) - } - case task.TaskStatusRunning: - if currentTaskStatus == task.TaskStatusUnstarted || currentTaskStatus == task.TaskStatusPending { - validTaskStatus = taskStatus - } else { - logs.CtxError(ctx, "Invalid task status:%s,currentTaskStatus:%s", taskStatus, currentTaskStatus) - return "", errorx.NewByCode(obErrorx.CommercialCommonInvalidParamCodeCode, errorx.WithExtraMsg("invalid task status")) - } - case task.TaskStatusPending: - if currentTaskStatus == task.TaskStatusRunning { - validTaskStatus = task.TaskStatusPending - } - case task.TaskStatusDisabled: - if currentTaskStatus == task.TaskStatusUnstarted || currentTaskStatus == task.TaskStatusPending { - validTaskStatus = task.TaskStatusDisabled - } - case task.TaskStatusSuccess: - if currentTaskStatus != task.TaskStatusSuccess { - validTaskStatus = task.TaskStatusSuccess - } - } - - return validTaskStatus, nil -} - func getLastPartAfterDot(s string) string { s = strings.TrimRight(s, ".") lastDotIndex := strings.LastIndex(s, ".") diff --git a/backend/modules/observability/application/convertor/task/task_test.go b/backend/modules/observability/application/convertor/task/task_test.go index 9a48ea739..14a0cf01d 100755 --- a/backend/modules/observability/application/convertor/task/task_test.go +++ b/backend/modules/observability/application/convertor/task/task_test.go @@ -19,8 +19,6 @@ import ( "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/entity" entityCommon "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/entity/common" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/entity/loop_span" - obErrorx "github.com/coze-dev/coze-loop/backend/modules/observability/pkg/errno" - "github.com/coze-dev/coze-loop/backend/pkg/errorx" "github.com/coze-dev/coze-loop/backend/pkg/lang/ptr" ) @@ -239,20 +237,9 @@ func TestTaskDTO2DO(t *testing.T) { }, } - overrideSpan := &entity.SpanFilterFields{ - PlatformType: kitCommon.PlatformTypeCozeloop, - SpanListType: kitCommon.SpanListTypeRootSpan, - Filters: loop_span.FilterFields{ - QueryAndOr: ptr.Of(loop_span.QueryAndOrEnumAnd), - FilterFields: []*loop_span.FilterField{}, - }, - } - - entityTask := TaskDTO2DO(dto, "override", overrideSpan) + entityTask := TaskDTO2DO(dto) if assert.NotNil(t, entityTask) { assert.Equal(t, int64(11), entityTask.ID) - assert.Equal(t, "override", entityTask.CreatedBy) - assert.Equal(t, overrideSpan, entityTask.SpanFilter) assert.NotZero(t, entityTask.CreatedAt.Unix()) assert.Equal(t, int64(1), entityTask.TaskDetail.SuccessCount) assert.Equal(t, float64(0.3), entityTask.Sampler.SampleRate) @@ -277,149 +264,6 @@ func TestSpanFilterPO2DO(t *testing.T) { assert.Nil(t, SpanFilterPO2DO(ctx, &invalid)) } -func TestCheckEffectiveTime(t *testing.T) { - t.Parallel() - - ctx := context.Background() - now := time.Now() - - getCode := func(err error) int32 { - statusErr, ok := errorx.FromStatusError(err) - if !ok { - return 0 - } - return statusErr.Code() - } - - futureStart := now.Add(2 * time.Hour).UnixMilli() - futureEnd := now.Add(3 * time.Hour).UnixMilli() - - cases := []struct { - name string - effective *kitTask.EffectiveTime - status kitTask.TaskStatus - current *entity.EffectiveTime - wantStart int64 - wantEnd int64 - wantErrCode int32 - }{ - { - name: "nil current", - effective: &kitTask.EffectiveTime{StartAt: gptr.Of(futureStart), EndAt: gptr.Of(futureEnd)}, - status: kitTask.TaskStatusUnstarted, - current: nil, - wantErrCode: obErrorx.CommercialCommonInvalidParamCodeCode, - }, - { - name: "start after end", - effective: &kitTask.EffectiveTime{StartAt: gptr.Of(futureEnd), EndAt: gptr.Of(futureStart)}, - status: kitTask.TaskStatusUnstarted, - current: &entity.EffectiveTime{StartAt: futureStart, EndAt: futureEnd}, - wantErrCode: obErrorx.CommercialCommonInvalidParamCodeCode, - }, - { - name: "update start in past", - effective: &kitTask.EffectiveTime{StartAt: gptr.Of(now.Add(-time.Hour).UnixMilli()), EndAt: gptr.Of(futureEnd)}, - status: kitTask.TaskStatusRunning, - current: &entity.EffectiveTime{StartAt: futureStart, EndAt: futureEnd}, - wantErrCode: obErrorx.CommercialCommonInvalidParamCodeCode, - }, - { - name: "update end in past", - effective: &kitTask.EffectiveTime{StartAt: gptr.Of(futureStart), EndAt: gptr.Of(now.Add(-time.Hour).UnixMilli())}, - status: kitTask.TaskStatusRunning, - current: &entity.EffectiveTime{StartAt: futureStart, EndAt: futureEnd}, - wantErrCode: obErrorx.CommercialCommonInvalidParamCodeCode, - }, - { - name: "unstarted updates both", - effective: &kitTask.EffectiveTime{StartAt: gptr.Of(futureStart), EndAt: gptr.Of(futureEnd)}, - status: kitTask.TaskStatusUnstarted, - current: &entity.EffectiveTime{StartAt: 100, EndAt: 200}, - wantStart: futureStart, - wantEnd: futureEnd, - }, - { - name: "running keeps start", - effective: &kitTask.EffectiveTime{StartAt: gptr.Of(futureEnd), EndAt: gptr.Of(futureEnd + 1000)}, - status: kitTask.TaskStatusRunning, - current: &entity.EffectiveTime{StartAt: 111, EndAt: 222}, - wantStart: 111, - wantEnd: futureEnd + 1000, - }, - { - name: "invalid status", - effective: &kitTask.EffectiveTime{StartAt: gptr.Of(futureStart), EndAt: gptr.Of(futureEnd)}, - status: kitTask.TaskStatus("unknown"), - current: &entity.EffectiveTime{StartAt: futureStart, EndAt: futureEnd}, - wantErrCode: obErrorx.CommercialCommonInvalidParamCodeCode, - }, - } - - for _, tc := range cases { - tc := tc - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - got, err := CheckEffectiveTime(ctx, tc.effective, tc.status, tc.current) - if tc.wantErrCode != 0 { - assert.NotNil(t, err) - assert.Equal(t, tc.wantErrCode, getCode(err)) - assert.Nil(t, got) - return - } - assert.NoError(t, err) - if assert.NotNil(t, got) { - assert.Equal(t, tc.wantStart, got.StartAt) - assert.Equal(t, tc.wantEnd, got.EndAt) - } - }) - } -} - -func TestCheckTaskStatus(t *testing.T) { - t.Parallel() - - ctx := context.Background() - getCode := func(err error) int32 { - statusErr, ok := errorx.FromStatusError(err) - if !ok { - return 0 - } - return statusErr.Code() - } - - cases := []struct { - name string - status kitTask.TaskStatus - current kitTask.TaskStatus - want kitTask.TaskStatus - wantErrCode int32 - }{ - {"unstarted ok", kitTask.TaskStatusUnstarted, kitTask.TaskStatusUnstarted, kitTask.TaskStatusUnstarted, 0}, - {"unstarted invalid", kitTask.TaskStatusUnstarted, kitTask.TaskStatusRunning, "", obErrorx.CommercialCommonInvalidParamCodeCode}, - {"running ok", kitTask.TaskStatusRunning, kitTask.TaskStatusPending, kitTask.TaskStatusRunning, 0}, - {"running invalid", kitTask.TaskStatusRunning, kitTask.TaskStatusSuccess, "", obErrorx.CommercialCommonInvalidParamCodeCode}, - {"pending ok", kitTask.TaskStatusPending, kitTask.TaskStatusRunning, kitTask.TaskStatusPending, 0}, - {"disabled ok", kitTask.TaskStatusDisabled, kitTask.TaskStatusPending, kitTask.TaskStatusDisabled, 0}, - {"success ok", kitTask.TaskStatusSuccess, kitTask.TaskStatusRunning, kitTask.TaskStatusSuccess, 0}, - {"pending no transition", kitTask.TaskStatusPending, kitTask.TaskStatusDisabled, "", 0}, - } - - for _, tc := range cases { - tc := tc - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - got, err := CheckTaskStatus(ctx, tc.status, tc.current) - if tc.wantErrCode != 0 { - assert.Equal(t, tc.wantErrCode, getCode(err)) - return - } - assert.NoError(t, err) - assert.Equal(t, tc.want, got) - }) - } -} - func TestGetLastPartAfterDot(t *testing.T) { t.Parallel() diff --git a/backend/modules/observability/application/task.go b/backend/modules/observability/application/task.go index d71be1a49..cfdafa2b2 100644 --- a/backend/modules/observability/application/task.go +++ b/backend/modules/observability/application/task.go @@ -9,23 +9,18 @@ import ( "time" "github.com/coze-dev/coze-loop/backend/infra/middleware/session" - "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/observability/domain/common" - "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/observability/domain/filter" - domain_task "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/observability/domain/task" "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/observability/task" "github.com/coze-dev/coze-loop/backend/modules/observability/application/convertor" tconv "github.com/coze-dev/coze-loop/backend/modules/observability/application/convertor/task" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/component/rpc" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/entity" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/service" - task_processor "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/service/taskexe/processor" + "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/service/taskexe/processor" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/service/taskexe/tracehub" - "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/entity/loop_span" - trace_Svc "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/service" - "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/service/trace/span_filter" obErrorx "github.com/coze-dev/coze-loop/backend/modules/observability/pkg/errno" "github.com/coze-dev/coze-loop/backend/pkg/errorx" - "github.com/coze-dev/coze-loop/backend/pkg/lang/ptr" + "github.com/coze-dev/coze-loop/backend/pkg/logs" + "github.com/samber/lo" ) type ITaskQueueConsumer interface { @@ -34,6 +29,7 @@ type ITaskQueueConsumer interface { Correction(ctx context.Context, event *entity.CorrectionEvent) error BackFill(ctx context.Context, event *entity.BackFillEvent) error } + type ITaskApplication interface { task.TaskService ITaskQueueConsumer @@ -46,8 +42,7 @@ func NewTaskApplication( evaluationService rpc.IEvaluationRPCAdapter, userService rpc.IUserProvider, tracehubSvc tracehub.ITraceHubService, - taskProcessor task_processor.TaskProcessor, - buildHelper trace_Svc.TraceFilterProcessorBuilder, + taskProcessor processor.TaskProcessor, ) (ITaskApplication, error) { return &TaskApplication{ taskSvc: taskService, @@ -57,7 +52,6 @@ func NewTaskApplication( userSvc: userService, tracehubSvc: tracehubSvc, taskProcessor: taskProcessor, - buildHelper: buildHelper, }, nil } @@ -68,8 +62,7 @@ type TaskApplication struct { evaluationSvc rpc.IEvaluationRPCAdapter userSvc rpc.IUserProvider tracehubSvc tracehub.ITraceHubService - taskProcessor task_processor.TaskProcessor - buildHelper trace_Svc.TraceFilterProcessorBuilder + taskProcessor processor.TaskProcessor } func (t *TaskApplication) CheckTaskName(ctx context.Context, req *task.CheckTaskNameRequest) (*task.CheckTaskNameResponse, error) { @@ -114,13 +107,13 @@ func (t *TaskApplication) CreateTask(ctx context.Context, req *task.CreateTaskRe if userID == "" { return nil, errorx.NewByCode(obErrorx.UserParseFailedCode) } + // 创建task - req.Task.TaskStatus = ptr.Of(domain_task.TaskStatusUnstarted) - spanFilers, err := t.buildSpanFilters(ctx, req.Task.GetRule().GetSpanFilters(), req.GetTask().GetWorkspaceID()) - if err != nil { - return nil, err - } - sResp, err := t.taskSvc.CreateTask(ctx, &service.CreateTaskReq{Task: tconv.TaskDTO2DO(req.GetTask(), userID, spanFilers)}) + taskDO := tconv.TaskDTO2DO(req.GetTask()) + taskDO.TaskStatus = entity.TaskStatusUnstarted + taskDO.CreatedBy = userID + taskDO.UpdatedBy = userID + sResp, err := t.taskSvc.CreateTask(ctx, &service.CreateTaskReq{Task: taskDO}) if err != nil { return resp, err } @@ -128,50 +121,6 @@ func (t *TaskApplication) CreateTask(ctx context.Context, req *task.CreateTaskRe return &task.CreateTaskResponse{TaskID: sResp.TaskID}, nil } -func (t *TaskApplication) buildSpanFilters(ctx context.Context, spanFilterFields *filter.SpanFilterFields, workspaceID int64) (*entity.SpanFilterFields, error) { - spanFilters := &entity.SpanFilterFields{ - PlatformType: *spanFilterFields.PlatformType, - SpanListType: *spanFilterFields.SpanListType, - } - filters := convertor.FilterFieldsDTO2DO(spanFilterFields.GetFilters()) - spanFilters.Filters = *filters - switch spanFilterFields.GetPlatformType() { - case common.PlatformTypeCozeBot, common.PlatformTypeProject, common.PlatformTypeWorkflow, common.PlatformTypeInnerCozeBot: - platformFilter, err := t.buildHelper.BuildPlatformRelatedFilter(ctx, loop_span.PlatformType(spanFilterFields.GetPlatformType())) - if err != nil { - return nil, err - } - env := &span_filter.SpanEnv{ - WorkspaceID: workspaceID, - } - basicFilter, forceQuery, err := platformFilter.BuildBasicSpanFilter(ctx, env) - if err != nil { - return nil, err - } else if len(basicFilter) == 0 && !forceQuery { // if it's null, no need to query from ck - return nil, nil - } - for _, filter := range basicFilter { - filters.FilterFields = append(filters.FilterFields, &loop_span.FilterField{ - FieldName: filter.FieldName, - FieldType: filter.FieldType, - Values: filter.Values, - QueryType: filter.QueryType, - QueryAndOr: filter.QueryAndOr, - SubFilter: filter.SubFilter, - Hidden: true, - }) - } - - return &entity.SpanFilterFields{ - Filters: *filters, - PlatformType: *spanFilterFields.PlatformType, - SpanListType: *spanFilterFields.SpanListType, - }, nil - default: - return spanFilters, nil - } -} - func (t *TaskApplication) validateCreateTaskReq(ctx context.Context, req *task.CreateTaskRequest) error { // 参数验证 if req == nil || req.GetTask() == nil { @@ -208,12 +157,16 @@ func (t *TaskApplication) UpdateTask(ctx context.Context, req *task.UpdateTaskRe strconv.FormatInt(req.GetTaskID(), 10)); err != nil { return nil, err } + var taskStatus *entity.TaskStatus + if req.TaskStatus != nil { + taskStatus = lo.ToPtr(entity.TaskStatus(req.GetTaskStatus())) + } err := t.taskSvc.UpdateTask(ctx, &service.UpdateTaskReq{ TaskID: req.GetTaskID(), WorkspaceID: req.GetWorkspaceID(), - TaskStatus: req.TaskStatus, + TaskStatus: taskStatus, Description: req.Description, - EffectiveTime: req.EffectiveTime, + EffectiveTime: tconv.EffectiveTimeDTO2DO(req.EffectiveTime), SampleRate: req.SampleRate, }) if err != nil { @@ -236,12 +189,13 @@ func (t *TaskApplication) ListTasks(ctx context.Context, req *task.ListTasksRequ false); err != nil { return resp, err } + sResp, err := t.taskSvc.ListTasks(ctx, &service.ListTasksReq{ WorkspaceID: req.GetWorkspaceID(), - TaskFilters: req.GetTaskFilters(), + TaskFilters: tconv.TaskFiltersDTO2DO(req.GetTaskFilters()), Limit: req.GetLimit(), Offset: req.GetOffset(), - OrderBy: req.GetOrderBy(), + OrderBy: convertor.OrderByDTO2DO(req.GetOrderBy()), }) if err != nil { return resp, err @@ -249,9 +203,21 @@ func (t *TaskApplication) ListTasks(ctx context.Context, req *task.ListTasksRequ if sResp == nil { return resp, nil } + + userMap := make(map[string]bool) + for _, tp := range sResp.Tasks { + userMap[tp.CreatedBy] = true + userMap[tp.UpdatedBy] = true + } + _, userInfoMap, err := t.userSvc.GetUserInfo(ctx, lo.Keys(userMap)) + if err != nil { + logs.CtxError(ctx, "MGetUserInfo err:%v", err) + } + tasks := tconv.TaskDOs2DTOs(ctx, sResp.Tasks, userInfoMap) + return &task.ListTasksResponse{ - Tasks: sResp.Tasks, - Total: sResp.Total, + Tasks: tasks, + Total: &sResp.Total, }, nil } @@ -268,6 +234,7 @@ func (t *TaskApplication) GetTask(ctx context.Context, req *task.GetTaskRequest) false); err != nil { return resp, err } + sResp, err := t.taskSvc.GetTask(ctx, &service.GetTaskReq{ TaskID: req.GetTaskID(), WorkspaceID: req.GetWorkspaceID(), @@ -279,8 +246,14 @@ func (t *TaskApplication) GetTask(ctx context.Context, req *task.GetTaskRequest) return resp, nil } + taskDO := sResp.Task + _, userInfoMap, err := t.userSvc.GetUserInfo(ctx, []string{taskDO.CreatedBy, taskDO.UpdatedBy}) + if err != nil { + logs.CtxError(ctx, "MGetUserInfo err:%v", err) + } + return &task.GetTaskResponse{ - Task: sResp.Task, + Task: tconv.TaskDO2DTO(ctx, taskDO, userInfoMap), }, nil } diff --git a/backend/modules/observability/application/task_test.go b/backend/modules/observability/application/task_test.go index db20bbf3f..357db081e 100755 --- a/backend/modules/observability/application/task_test.go +++ b/backend/modules/observability/application/task_test.go @@ -25,11 +25,6 @@ import ( svc "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/service" svcmock "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/service/mocks" tracehubmock "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/service/taskexe/tracehub/mocks" - loop_span "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/entity/loop_span" - traceSvc "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/service" - traceSvcMock "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/service/mocks" - span_filter "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/service/trace/span_filter" - filtermocks "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/service/trace/span_filter/mocks" obErrorx "github.com/coze-dev/coze-loop/backend/modules/observability/pkg/errno" "github.com/coze-dev/coze-loop/backend/pkg/errorx" ) @@ -310,224 +305,6 @@ func TestTaskApplication_CreateTask(t *testing.T) { } } -func TestTaskApplication_buildSpanFilters(t *testing.T) { - t.Parallel() - - type fields struct { - builder traceSvc.TraceFilterProcessorBuilder - } - - type args struct { - spanFilters *filterdto.SpanFilterFields - workspaceID int64 - } - - tests := []struct { - name string - fieldsBuilder func(ctrl *gomock.Controller, t *testing.T, a args) fields - args args - assertFunc func(t *testing.T, original *filterdto.SpanFilterFields, got *entity.SpanFilterFields, err error) - }{ - { - name: "non supported platform returns original", - fieldsBuilder: func(ctrl *gomock.Controller, t *testing.T, a args) fields { - return fields{} - }, - args: args{ - spanFilters: &filterdto.SpanFilterFields{ - Filters: &filterdto.FilterFields{ - FilterFields: []*filterdto.FilterField{ - { - FieldName: gptr.Of("custom_field"), - FieldType: gptr.Of(filterdto.FieldTypeString), - Values: []string{"value"}, - }, - }, - }, - PlatformType: gptr.Of(commondomain.PlatformTypeCozeloop), - SpanListType: gptr.Of(commondomain.SpanListTypeRootSpan), - }, - workspaceID: 100, - }, - assertFunc: func(t *testing.T, original *filterdto.SpanFilterFields, got *entity.SpanFilterFields, err error) { - assert.NoError(t, err) - if assert.NotNil(t, got) { - assert.Equal(t, commondomain.PlatformTypeCozeloop, got.PlatformType) - assert.Equal(t, commondomain.SpanListTypeRootSpan, got.SpanListType) - dtoFilters := original.GetFilters().GetFilterFields() - if assert.Len(t, got.Filters.FilterFields, len(dtoFilters)) && len(dtoFilters) > 0 { - firstDTO := dtoFilters[0] - firstDomain := got.Filters.FilterFields[0] - if assert.NotNil(t, firstDTO.FieldName) { - assert.Equal(t, *firstDTO.FieldName, firstDomain.FieldName) - } - if assert.NotNil(t, firstDTO.FieldType) { - assert.Equal(t, loop_span.FieldType(*firstDTO.FieldType), firstDomain.FieldType) - } - assert.Equal(t, firstDTO.Values, firstDomain.Values) - assert.False(t, firstDomain.Hidden) - } - } - }, - }, - { - name: "build platform filter error", - fieldsBuilder: func(ctrl *gomock.Controller, t *testing.T, a args) fields { - builder := traceSvcMock.NewMockTraceFilterProcessorBuilder(ctrl) - builder.EXPECT().BuildPlatformRelatedFilter(gomock.Any(), loop_span.PlatformType(commondomain.PlatformTypeCozeBot)).Return(nil, errors.New("build platform error")) - return fields{builder: builder} - }, - args: args{ - spanFilters: &filterdto.SpanFilterFields{ - Filters: &filterdto.FilterFields{ - FilterFields: []*filterdto.FilterField{}, - }, - PlatformType: gptr.Of(commondomain.PlatformTypeCozeBot), - SpanListType: gptr.Of(commondomain.SpanListTypeRootSpan), - }, - workspaceID: 200, - }, - assertFunc: func(t *testing.T, original *filterdto.SpanFilterFields, got *entity.SpanFilterFields, err error) { - assert.Nil(t, got) - assert.EqualError(t, err, "build platform error") - }, - }, - { - name: "build basic span filter error", - fieldsBuilder: func(ctrl *gomock.Controller, t *testing.T, a args) fields { - builder := traceSvcMock.NewMockTraceFilterProcessorBuilder(ctrl) - platformFilter := filtermocks.NewMockFilter(ctrl) - builder.EXPECT().BuildPlatformRelatedFilter(gomock.Any(), loop_span.PlatformType(commondomain.PlatformTypeWorkflow)).Return(platformFilter, nil) - platformFilter.EXPECT(). - BuildBasicSpanFilter(gomock.Any(), gomock.AssignableToTypeOf(&span_filter.SpanEnv{})). - DoAndReturn(func(_ context.Context, env *span_filter.SpanEnv) ([]*loop_span.FilterField, bool, error) { - assert.Equal(t, a.workspaceID, env.WorkspaceID) - return nil, false, errors.New("build basic error") - }) - return fields{builder: builder} - }, - args: args{ - spanFilters: &filterdto.SpanFilterFields{ - Filters: &filterdto.FilterFields{ - FilterFields: []*filterdto.FilterField{}, - }, - PlatformType: gptr.Of(commondomain.PlatformTypeWorkflow), - SpanListType: gptr.Of(commondomain.SpanListTypeRootSpan), - }, - workspaceID: 300, - }, - assertFunc: func(t *testing.T, original *filterdto.SpanFilterFields, got *entity.SpanFilterFields, err error) { - assert.Nil(t, got) - assert.EqualError(t, err, "build basic error") - }, - }, - { - name: "empty basic filter without force returns nil", - fieldsBuilder: func(ctrl *gomock.Controller, t *testing.T, a args) fields { - builder := traceSvcMock.NewMockTraceFilterProcessorBuilder(ctrl) - platformFilter := filtermocks.NewMockFilter(ctrl) - builder.EXPECT().BuildPlatformRelatedFilter(gomock.Any(), loop_span.PlatformType(commondomain.PlatformTypeInnerCozeBot)).Return(platformFilter, nil) - platformFilter.EXPECT(). - BuildBasicSpanFilter(gomock.Any(), gomock.AssignableToTypeOf(&span_filter.SpanEnv{})). - DoAndReturn(func(_ context.Context, env *span_filter.SpanEnv) ([]*loop_span.FilterField, bool, error) { - assert.Equal(t, a.workspaceID, env.WorkspaceID) - return []*loop_span.FilterField{}, false, nil - }) - return fields{builder: builder} - }, - args: args{ - spanFilters: &filterdto.SpanFilterFields{ - Filters: &filterdto.FilterFields{ - FilterFields: []*filterdto.FilterField{}, - }, - PlatformType: gptr.Of(commondomain.PlatformTypeInnerCozeBot), - SpanListType: gptr.Of(commondomain.SpanListTypeRootSpan), - }, - workspaceID: 400, - }, - assertFunc: func(t *testing.T, original *filterdto.SpanFilterFields, got *entity.SpanFilterFields, err error) { - assert.NoError(t, err) - assert.Nil(t, got) - }, - }, - { - name: "merge platform filters success", - fieldsBuilder: func(ctrl *gomock.Controller, t *testing.T, a args) fields { - builder := traceSvcMock.NewMockTraceFilterProcessorBuilder(ctrl) - platformFilter := filtermocks.NewMockFilter(ctrl) - builder.EXPECT().BuildPlatformRelatedFilter(gomock.Any(), loop_span.PlatformType(commondomain.PlatformTypeProject)).Return(platformFilter, nil) - platformFilter.EXPECT(). - BuildBasicSpanFilter(gomock.Any(), gomock.AssignableToTypeOf(&span_filter.SpanEnv{})). - DoAndReturn(func(_ context.Context, env *span_filter.SpanEnv) ([]*loop_span.FilterField, bool, error) { - assert.Equal(t, a.workspaceID, env.WorkspaceID) - return []*loop_span.FilterField{ - { - FieldName: loop_span.SpanFieldSpaceId, - FieldType: loop_span.FieldTypeString, - Values: []string{"tenant"}, - }, - }, false, nil - }) - return fields{builder: builder} - }, - args: args{ - spanFilters: &filterdto.SpanFilterFields{ - Filters: &filterdto.FilterFields{ - FilterFields: []*filterdto.FilterField{ - { - FieldName: gptr.Of("custom_field"), - FieldType: gptr.Of(filterdto.FieldTypeString), - Values: []string{"origin"}, - }, - }, - }, - PlatformType: gptr.Of(commondomain.PlatformTypeProject), - SpanListType: gptr.Of(commondomain.SpanListTypeRootSpan), - }, - workspaceID: 500, - }, - assertFunc: func(t *testing.T, original *filterdto.SpanFilterFields, got *entity.SpanFilterFields, err error) { - assert.NoError(t, err) - if assert.NotNil(t, got) { - assert.Equal(t, commondomain.PlatformTypeProject, got.PlatformType) - assert.Equal(t, commondomain.SpanListTypeRootSpan, got.SpanListType) - originalFilters := original.GetFilters().GetFilterFields() - if assert.Len(t, got.Filters.FilterFields, len(originalFilters)+1) && len(originalFilters) > 0 { - firstDomain := got.Filters.FilterFields[0] - firstDTO := originalFilters[0] - if assert.NotNil(t, firstDTO.FieldName) { - assert.Equal(t, *firstDTO.FieldName, firstDomain.FieldName) - } - assert.False(t, firstDomain.Hidden) - appended := got.Filters.FilterFields[len(originalFilters)] - assert.Equal(t, loop_span.SpanFieldSpaceId, appended.FieldName) - assert.True(t, appended.Hidden) - assert.Equal(t, []string{"tenant"}, appended.Values) - } - } - }, - }, - } - - for _, tt := range tests { - caseItem := tt - t.Run(caseItem.name, func(t *testing.T) { - t.Parallel() - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - fields := caseItem.fieldsBuilder(ctrl, t, caseItem.args) - app := &TaskApplication{ - buildHelper: fields.builder, - } - - got, err := app.buildSpanFilters(context.Background(), caseItem.args.spanFilters, caseItem.args.workspaceID) - - caseItem.assertFunc(t, caseItem.args.spanFilters, got, err) - }) - } -} - func TestTaskApplication_UpdateTask(t *testing.T) { t.Parallel() diff --git a/backend/modules/observability/application/wire.go b/backend/modules/observability/application/wire.go index a1cefa728..bac80b142 100644 --- a/backend/modules/observability/application/wire.go +++ b/backend/modules/observability/application/wire.go @@ -55,7 +55,7 @@ import ( obrepo "github.com/coze-dev/coze-loop/backend/modules/observability/infra/repo" ckdao "github.com/coze-dev/coze-loop/backend/modules/observability/infra/repo/ck" mysqldao "github.com/coze-dev/coze-loop/backend/modules/observability/infra/repo/mysql" - tredis "github.com/coze-dev/coze-loop/backend/modules/observability/infra/repo/redis/dao" + redis2 "github.com/coze-dev/coze-loop/backend/modules/observability/infra/repo/redis" "github.com/coze-dev/coze-loop/backend/modules/observability/infra/rpc/auth" "github.com/coze-dev/coze-loop/backend/modules/observability/infra/rpc/dataset" "github.com/coze-dev/coze-loop/backend/modules/observability/infra/rpc/evaluation" @@ -77,8 +77,8 @@ var ( obrepo.NewTaskRepoImpl, // obrepo.NewTaskRunRepoImpl, mysqldao.NewTaskDaoImpl, - tredis.NewTaskDAO, - tredis.NewTaskRunDAO, + redis2.NewTaskDAO, + redis2.NewTaskRunDAO, mysqldao.NewTaskRunDaoImpl, mq2.NewBackfillProducerImpl, ) diff --git a/backend/modules/observability/application/wire_gen.go b/backend/modules/observability/application/wire_gen.go index 6d30d8a1f..0308c6142 100644 --- a/backend/modules/observability/application/wire_gen.go +++ b/backend/modules/observability/application/wire_gen.go @@ -24,7 +24,6 @@ import ( "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/foundation/auth/authservice" "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/foundation/file/fileservice" "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/foundation/user/userservice" - "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/observability/domain/task" config2 "github.com/coze-dev/coze-loop/backend/modules/observability/domain/component/config" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/component/rpc" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/metric/entity" @@ -33,6 +32,7 @@ import ( "github.com/coze-dev/coze-loop/backend/modules/observability/domain/metric/service/metric/model" service4 "github.com/coze-dev/coze-loop/backend/modules/observability/domain/metric/service/metric/service" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/metric/service/metric/tool" + entity3 "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/entity" repo3 "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/repo" service3 "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/service" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/service/taskexe/processor" @@ -55,7 +55,7 @@ import ( "github.com/coze-dev/coze-loop/backend/modules/observability/infra/repo" ck2 "github.com/coze-dev/coze-loop/backend/modules/observability/infra/repo/ck" "github.com/coze-dev/coze-loop/backend/modules/observability/infra/repo/mysql" - "github.com/coze-dev/coze-loop/backend/modules/observability/infra/repo/redis/dao" + redis3 "github.com/coze-dev/coze-loop/backend/modules/observability/infra/repo/redis" "github.com/coze-dev/coze-loop/backend/modules/observability/infra/rpc/auth" "github.com/coze-dev/coze-loop/backend/modules/observability/infra/rpc/dataset" "github.com/coze-dev/coze-loop/backend/modules/observability/infra/rpc/evaluation" @@ -104,9 +104,9 @@ func InitTraceApplication(db2 db.Provider, ckDb ck.Provider, redis2 redis.Cmdabl iTenantProvider := tenant.NewTenantProvider(iTraceConfig) iEvaluatorRPCAdapter := evaluator.NewEvaluatorRPCProvider(evalService) iTaskDao := mysql.NewTaskDaoImpl(db2) - iTaskDAO := dao.NewTaskDAO(redis2) + iTaskDAO := redis3.NewTaskDAO(redis2) iTaskRunDao := mysql.NewTaskRunDaoImpl(db2) - iTaskRunDAO := dao.NewTaskRunDAO(redis2) + iTaskRunDAO := redis3.NewTaskRunDAO(redis2) iTaskRepo := repo.NewTaskRepoImpl(iTaskDao, idgen2, iTaskDAO, iTaskRunDao, iTaskRunDAO) iTraceService, err := service.NewTraceServiceImpl(iTraceRepo, iTraceConfig, iTraceProducer, iAnnotationProducer, iTraceMetrics, traceFilterProcessorBuilder, iTenantProvider, iEvaluatorRPCAdapter, iTaskRepo) if err != nil { @@ -161,9 +161,9 @@ func InitOpenAPIApplication(mqFactory mq.IFactory, configFactory conf.IConfigLoa iTenantProvider := tenant.NewTenantProvider(iTraceConfig) iEvaluatorRPCAdapter := evaluator.NewEvaluatorRPCProvider(evalService) iTaskDao := mysql.NewTaskDaoImpl(db2) - iTaskDAO := dao.NewTaskDAO(redis2) + iTaskDAO := redis3.NewTaskDAO(redis2) iTaskRunDao := mysql.NewTaskRunDaoImpl(db2) - iTaskRunDAO := dao.NewTaskRunDAO(redis2) + iTaskRunDAO := redis3.NewTaskRunDAO(redis2) iTaskRepo := repo.NewTaskRepoImpl(iTaskDao, idgen2, iTaskDAO, iTaskRunDao, iTaskRunDAO) iTraceService, err := service.NewTraceServiceImpl(iTraceRepo, iTraceConfig, iTraceProducer, iAnnotationProducer, iTraceMetrics, traceFilterProcessorBuilder, iTenantProvider, iEvaluatorRPCAdapter, iTaskRepo) if err != nil { @@ -242,11 +242,10 @@ func InitTraceIngestionApplication(configFactory conf.IConfigLoaderFactory, ckDb func InitTaskApplication(db2 db.Provider, idgen2 idgen.IIDGenerator, configFactory conf.IConfigLoaderFactory, benefit2 benefit.IBenefitService, ckDb ck.Provider, redis2 redis.Cmdable, mqFactory mq.IFactory, userClient userservice.Client, authClient authservice.Client, evalService evaluatorservice.Client, evalSetService evaluationsetservice.Client, exptService experimentservice.Client, datasetService datasetservice.Client, fileClient fileservice.Client, taskProcessor processor.TaskProcessor, aid int32) (ITaskApplication, error) { iTaskDao := mysql.NewTaskDaoImpl(db2) - iTaskDAO := dao.NewTaskDAO(redis2) + iTaskDAO := redis3.NewTaskDAO(redis2) iTaskRunDao := mysql.NewTaskRunDaoImpl(db2) - iTaskRunDAO := dao.NewTaskRunDAO(redis2) + iTaskRunDAO := redis3.NewTaskRunDAO(redis2) iTaskRepo := repo.NewTaskRepoImpl(iTaskDao, idgen2, iTaskDAO, iTaskRunDao, iTaskRunDAO) - iUserProvider := user.NewUserRPCProvider(userClient) iConfigLoader, err := NewTraceConfigLoader(configFactory) if err != nil { return nil, err @@ -260,11 +259,14 @@ func InitTaskApplication(db2 db.Provider, idgen2 idgen.IIDGenerator, configFacto iEvaluatorRPCAdapter := evaluator.NewEvaluatorRPCProvider(evalService) iEvaluationRPCAdapter := evaluation.NewEvaluationRPCProvider(exptService) processorTaskProcessor := NewInitTaskProcessor(datasetServiceAdaptor, iEvaluatorRPCAdapter, iEvaluationRPCAdapter, iTaskRepo) - iTaskService, err := service3.NewTaskServiceImpl(iTaskRepo, iUserProvider, idgen2, iBackfillProducer, processorTaskProcessor) + iFileProvider := file.NewFileRPCProvider(fileClient) + traceFilterProcessorBuilder := NewTraceProcessorBuilder(iTraceConfig, iFileProvider, benefit2) + iTaskService, err := service3.NewTaskServiceImpl(iTaskRepo, idgen2, iBackfillProducer, processorTaskProcessor, traceFilterProcessorBuilder) if err != nil { return nil, err } iAuthProvider := auth.NewAuthProvider(authClient) + iUserProvider := user.NewUserRPCProvider(userClient) iSpansDao, err := ck2.NewSpansCkDaoImpl(ckDb) if err != nil { return nil, err @@ -278,14 +280,12 @@ func InitTaskApplication(db2 db.Provider, idgen2 idgen.IIDGenerator, configFacto return nil, err } iTenantProvider := tenant.NewTenantProvider(iTraceConfig) - iFileProvider := file.NewFileRPCProvider(fileClient) - traceFilterProcessorBuilder := NewTraceProcessorBuilder(iTraceConfig, iFileProvider, benefit2) iLocker := NewTaskLocker(redis2) iTraceHubService, err := tracehub.NewTraceHubImpl(iTaskRepo, iTraceRepo, iTenantProvider, traceFilterProcessorBuilder, processorTaskProcessor, benefit2, aid, iBackfillProducer, iLocker, iConfigLoader) if err != nil { return nil, err } - iTaskApplication, err := NewTaskApplication(iTaskService, iAuthProvider, iEvaluatorRPCAdapter, iEvaluationRPCAdapter, iUserProvider, iTraceHubService, taskProcessor, traceFilterProcessorBuilder) + iTaskApplication, err := NewTaskApplication(iTaskService, iAuthProvider, iEvaluatorRPCAdapter, iEvaluationRPCAdapter, iUserProvider, iTraceHubService, taskProcessor) if err != nil { return nil, err } @@ -296,7 +296,7 @@ func InitTaskApplication(db2 db.Provider, idgen2 idgen.IIDGenerator, configFacto var ( taskDomainSet = wire.NewSet( - NewInitTaskProcessor, service3.NewTaskServiceImpl, repo.NewTaskRepoImpl, mysql.NewTaskDaoImpl, dao.NewTaskDAO, dao.NewTaskRunDAO, mysql.NewTaskRunDaoImpl, producer.NewBackfillProducerImpl, + NewInitTaskProcessor, service3.NewTaskServiceImpl, repo.NewTaskRepoImpl, mysql.NewTaskDaoImpl, redis3.NewTaskDAO, redis3.NewTaskRunDAO, mysql.NewTaskRunDaoImpl, producer.NewBackfillProducerImpl, ) traceDomainSet = wire.NewSet(service.NewTraceServiceImpl, service.NewTraceExportServiceImpl, repo.NewTraceCKRepoImpl, ck2.NewSpansCkDaoImpl, ck2.NewAnnotationCkDaoImpl, metrics2.NewTraceMetricsImpl, collector.NewEventCollectorProvider, producer.NewTraceProducerImpl, producer.NewAnnotationProducerImpl, file.NewFileRPCProvider, NewTraceConfigLoader, NewTraceProcessorBuilder, config.NewTraceConfigCenter, tenant.NewTenantProvider, workspace.NewWorkspaceProvider, evaluator.NewEvaluatorRPCProvider, NewDatasetServiceAdapter, @@ -371,6 +371,6 @@ func NewInitTaskProcessor(datasetServiceProvider *service.DatasetServiceAdaptor, evaluationService rpc.IEvaluationRPCAdapter, taskRepo repo3.ITaskRepo, ) *processor.TaskProcessor { taskProcessor := processor.NewTaskProcessor() - taskProcessor.Register(task.TaskTypeAutoEval, processor.NewAutoEvaluteProcessor(0, datasetServiceProvider, evalService, evaluationService, taskRepo)) + taskProcessor.Register(entity3.TaskTypeAutoEval, processor.NewAutoEvaluteProcessor(0, datasetServiceProvider, evalService, evaluationService, taskRepo)) return taskProcessor } diff --git a/backend/modules/observability/domain/task/entity/task.go b/backend/modules/observability/domain/task/entity/task.go index 867f5e273..16ab0fa7e 100644 --- a/backend/modules/observability/domain/task/entity/task.go +++ b/backend/modules/observability/domain/task/entity/task.go @@ -4,11 +4,15 @@ package entity import ( + "context" "time" - "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/observability/domain/common" "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/observability/domain/dataset" + taskdto "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/observability/domain/task" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/entity/loop_span" + obErrorx "github.com/coze-dev/coze-loop/backend/modules/observability/pkg/errno" + "github.com/coze-dev/coze-loop/backend/pkg/errorx" + "github.com/coze-dev/coze-loop/backend/pkg/logs" ) type TaskStatus string @@ -43,6 +47,11 @@ const ( TaskRunStatusDone TaskRunStatus = "done" ) +type StatusChangeEvent struct { + Before TaskStatus + After TaskStatus +} + // do type ObservabilityTask struct { ID int64 // Task ID @@ -72,8 +81,8 @@ type RunDetail struct { } type SpanFilterFields struct { Filters loop_span.FilterFields `json:"filters"` - PlatformType common.PlatformType `json:"platform_type"` - SpanListType common.SpanListType `json:"span_list_type"` + PlatformType loop_span.PlatformType `json:"platform_type"` + SpanListType loop_span.SpanListType `json:"span_list_type"` } type EffectiveTime struct { // ms timestamp @@ -157,7 +166,7 @@ type DataReflowRunConfig struct { Status string `json:"status"` } -func (t ObservabilityTask) IsFinished() bool { +func (t *ObservabilityTask) IsFinished() bool { switch t.TaskStatus { case TaskStatusSuccess, TaskStatusDisabled, TaskStatusPending: return true @@ -166,7 +175,7 @@ func (t ObservabilityTask) IsFinished() bool { } } -func (t ObservabilityTask) GetBackfillTaskRun() *TaskRun { +func (t *ObservabilityTask) GetBackfillTaskRun() *TaskRun { for _, taskRunPO := range t.TaskRuns { if taskRunPO.TaskType == TaskRunTypeBackFill { return taskRunPO @@ -175,7 +184,7 @@ func (t ObservabilityTask) GetBackfillTaskRun() *TaskRun { return nil } -func (t ObservabilityTask) GetCurrentTaskRun() *TaskRun { +func (t *ObservabilityTask) GetCurrentTaskRun() *TaskRun { for _, taskRunPO := range t.TaskRuns { if taskRunPO.TaskType == TaskRunTypeNewData && taskRunPO.RunStatus == TaskRunStatusRunning { return taskRunPO @@ -184,7 +193,7 @@ func (t ObservabilityTask) GetCurrentTaskRun() *TaskRun { return nil } -func (t ObservabilityTask) GetTaskTTL() int64 { +func (t *ObservabilityTask) GetTaskTTL() int64 { var ttl int64 if t.EffectiveTime != nil { ttl = t.EffectiveTime.EndAt - t.EffectiveTime.StartAt @@ -194,3 +203,93 @@ func (t ObservabilityTask) GetTaskTTL() int64 { } return ttl } + +func (t *ObservabilityTask) SetEffectiveTime(ctx context.Context, effectiveTime EffectiveTime) error { + if t.EffectiveTime == nil { + logs.CtxError(ctx, "EffectiveTime is null.") + return errorx.NewByCode(obErrorx.CommercialCommonInvalidParamCodeCode, errorx.WithExtraMsg("effective time is nil")) + } + // 开始时间不能大于结束时间 + if effectiveTime.StartAt >= effectiveTime.EndAt { + logs.CtxError(ctx, "Start time must be less than end time") + return errorx.NewByCode(obErrorx.CommercialCommonInvalidParamCodeCode, errorx.WithExtraMsg("start time must be less than end time")) + } + // 开始、结束时间不能小于当前时间 + if t.EffectiveTime.StartAt != effectiveTime.StartAt && effectiveTime.StartAt < time.Now().UnixMilli() { + logs.CtxError(ctx, "update time must be greater than current time") + return errorx.NewByCode(obErrorx.CommercialCommonInvalidParamCodeCode, errorx.WithExtraMsg("start time must be greater than current time")) + } + if t.EffectiveTime.EndAt != effectiveTime.EndAt && effectiveTime.EndAt < time.Now().UnixMilli() { + logs.CtxError(ctx, "update time must be greater than current time") + return errorx.NewByCode(obErrorx.CommercialCommonInvalidParamCodeCode, errorx.WithExtraMsg("start time must be greater than current time")) + } + switch t.TaskStatus { + case TaskStatusUnstarted: + if effectiveTime.StartAt != 0 { + t.EffectiveTime.StartAt = effectiveTime.StartAt + } + if effectiveTime.EndAt != 0 { + t.EffectiveTime.EndAt = effectiveTime.EndAt + } + case TaskStatusRunning, TaskStatusPending: + if effectiveTime.EndAt != 0 { + t.EffectiveTime.EndAt = effectiveTime.EndAt + } + default: + logs.CtxError(ctx, "Invalid task status:%s", t.TaskStatus) + return errorx.NewByCode(obErrorx.CommercialCommonInvalidParamCodeCode, errorx.WithExtraMsg("invalid task status")) + } + return nil +} + +func (t *ObservabilityTask) SetTaskStatus(ctx context.Context, taskStatus TaskStatus) (*StatusChangeEvent, error) { + currentTaskStatus := t.TaskStatus + if currentTaskStatus == taskStatus { + return nil, nil + } + + switch taskStatus { + case taskdto.TaskStatusUnstarted: + break + case taskdto.TaskStatusRunning: + if currentTaskStatus == taskdto.TaskStatusUnstarted || currentTaskStatus == taskdto.TaskStatusPending { + t.TaskStatus = taskStatus + return &StatusChangeEvent{ + Before: currentTaskStatus, + After: taskStatus, + }, nil + } + case taskdto.TaskStatusPending: + if currentTaskStatus == taskdto.TaskStatusRunning { + t.TaskStatus = taskStatus + return &StatusChangeEvent{ + Before: currentTaskStatus, + After: taskStatus, + }, nil + } + case taskdto.TaskStatusDisabled: + if currentTaskStatus == taskdto.TaskStatusUnstarted || currentTaskStatus == taskdto.TaskStatusPending { + t.TaskStatus = taskStatus + return &StatusChangeEvent{ + Before: currentTaskStatus, + After: taskStatus, + }, nil + } + case taskdto.TaskStatusSuccess: + break + } + + logs.CtxError(ctx, "Invalid task status. Before:[%s], after:[%s]", currentTaskStatus, taskStatus) + return nil, errorx.NewByCode(obErrorx.CommercialCommonInvalidParamCodeCode, errorx.WithExtraMsg("invalid task status")) +} + +func (t *ObservabilityTask) ShouldTriggerBackfill() bool { + // 检查回填时间配置 + if t.BackfillEffectiveTime == nil { + return false + } + + return t.BackfillEffectiveTime.StartAt > 0 && + t.BackfillEffectiveTime.EndAt > 0 && + t.BackfillEffectiveTime.StartAt < t.BackfillEffectiveTime.EndAt +} diff --git a/backend/modules/observability/domain/task/entity/task_test.go b/backend/modules/observability/domain/task/entity/task_test.go index 8edaf6d23..e1276646b 100644 --- a/backend/modules/observability/domain/task/entity/task_test.go +++ b/backend/modules/observability/domain/task/entity/task_test.go @@ -1,4 +1,86 @@ // Copyright (c) 2025 coze-dev Authors // SPDX-License-Identifier: Apache-2.0 - package entity + +import ( + "context" + "reflect" + "testing" +) + +func TestObservabilityTask_SetTaskStatus(t *testing.T) { + tests := []struct { + name string // 测试用例名称 + initialTask ObservabilityTask // 任务的初始状态 + targetStatus TaskStatus // 目标设置的状态 + wantEvent *StatusChangeEvent // 期望返回的事件 + wantErr bool // 是否期望发生错误 + finalStatus TaskStatus // 期望的最终任务状态 + }{ + { + name: "状态相同时不进行变更", + initialTask: ObservabilityTask{TaskStatus: TaskStatusRunning}, + targetStatus: TaskStatusRunning, + wantEvent: nil, + wantErr: false, + finalStatus: TaskStatusRunning, + }, + { + name: "有效状态流转:从未开始到运行中", + initialTask: ObservabilityTask{TaskStatus: TaskStatusUnstarted}, + targetStatus: TaskStatusRunning, + wantEvent: &StatusChangeEvent{ + Before: TaskStatusUnstarted, + After: TaskStatusRunning, + }, + wantErr: false, + finalStatus: TaskStatusRunning, + }, + { + name: "有效状态流转:从挂起到运行中", + initialTask: ObservabilityTask{TaskStatus: TaskStatusPending}, + targetStatus: TaskStatusRunning, + wantEvent: &StatusChangeEvent{ + Before: TaskStatusPending, + After: TaskStatusRunning, + }, + wantErr: false, + finalStatus: TaskStatusRunning, + }, + { + name: "无效状态流转:从禁用状态到其他状态", + initialTask: ObservabilityTask{TaskStatus: TaskStatusDisabled}, + targetStatus: TaskStatusRunning, + wantEvent: nil, + wantErr: true, + finalStatus: TaskStatusDisabled, + }, + } + + // 遍历并执行所有测试用例 + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Arrange: 创建一个任务副本以防止并发测试时修改原始测试用例数据 + task := tt.initialTask + + // Act: 调用被测方法 + gotEvent, err := task.SetTaskStatus(context.Background(), tt.targetStatus) + + // Assert: 校验错误是否符合预期 + if (err != nil) != tt.wantErr { + t.Errorf("SetTaskStatus() error = %v, wantErr %v", err, tt.wantErr) + return + } + + // Assert: 校验返回的事件是否符合预期 + if !reflect.DeepEqual(gotEvent, tt.wantEvent) { + t.Errorf("SetTaskStatus() gotEvent = %v, want %v", gotEvent, tt.wantEvent) + } + + // Assert: 校验任务的最终状态是否符合预期 + if task.TaskStatus != tt.finalStatus { + t.Errorf("Final task status = %v, want %v", task.TaskStatus, tt.finalStatus) + } + }) + } +} diff --git a/backend/modules/observability/domain/task/repo/mocks/Task.go b/backend/modules/observability/domain/task/repo/mocks/Task.go index 6af237884..0020c6c71 100644 --- a/backend/modules/observability/domain/task/repo/mocks/Task.go +++ b/backend/modules/observability/domain/task/repo/mocks/Task.go @@ -3,7 +3,7 @@ // // Generated by this command: // -// mockgen -destination=modules/observability/domain/task/repo/mocks/Task.go -package=mocks github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/repo ITaskRepo +// mockgen -destination=mocks/Task.go -package=mocks . ITaskRepo // // Package mocks is a generated GoMock package. @@ -14,7 +14,7 @@ import ( reflect "reflect" entity "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/entity" - mysql "github.com/coze-dev/coze-loop/backend/modules/observability/infra/repo/mysql" + "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/repo" gomock "go.uber.org/mock/gomock" ) @@ -22,6 +22,7 @@ import ( type MockITaskRepo struct { ctrl *gomock.Controller recorder *MockITaskRepoMockRecorder + isgomock struct{} } // MockITaskRepoMockRecorder is the mock recorder for MockITaskRepo. @@ -41,409 +42,394 @@ func (m *MockITaskRepo) EXPECT() *MockITaskRepoMockRecorder { return m.recorder } +// AddNonFinalTask mocks base method. +func (m *MockITaskRepo) AddNonFinalTask(ctx context.Context, spaceID string, taskID int64) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "AddNonFinalTask", ctx, spaceID, taskID) + ret0, _ := ret[0].(error) + return ret0 +} + +// AddNonFinalTask indicates an expected call of AddNonFinalTask. +func (mr *MockITaskRepoMockRecorder) AddNonFinalTask(ctx, spaceID, taskID any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddNonFinalTask", reflect.TypeOf((*MockITaskRepo)(nil).AddNonFinalTask), ctx, spaceID, taskID) +} + // CreateTask mocks base method. -func (m *MockITaskRepo) CreateTask(arg0 context.Context, arg1 *entity.ObservabilityTask) (int64, error) { +func (m *MockITaskRepo) CreateTask(ctx context.Context, do *entity.ObservabilityTask) (int64, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "CreateTask", arg0, arg1) + ret := m.ctrl.Call(m, "CreateTask", ctx, do) ret0, _ := ret[0].(int64) ret1, _ := ret[1].(error) return ret0, ret1 } // CreateTask indicates an expected call of CreateTask. -func (mr *MockITaskRepoMockRecorder) CreateTask(arg0, arg1 any) *gomock.Call { +func (mr *MockITaskRepoMockRecorder) CreateTask(ctx, do any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateTask", reflect.TypeOf((*MockITaskRepo)(nil).CreateTask), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateTask", reflect.TypeOf((*MockITaskRepo)(nil).CreateTask), ctx, do) } // CreateTaskRun mocks base method. -func (m *MockITaskRepo) CreateTaskRun(arg0 context.Context, arg1 *entity.TaskRun) (int64, error) { +func (m *MockITaskRepo) CreateTaskRun(ctx context.Context, do *entity.TaskRun) (int64, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "CreateTaskRun", arg0, arg1) + ret := m.ctrl.Call(m, "CreateTaskRun", ctx, do) ret0, _ := ret[0].(int64) ret1, _ := ret[1].(error) return ret0, ret1 } // CreateTaskRun indicates an expected call of CreateTaskRun. -func (mr *MockITaskRepoMockRecorder) CreateTaskRun(arg0, arg1 any) *gomock.Call { +func (mr *MockITaskRepoMockRecorder) CreateTaskRun(ctx, do any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateTaskRun", reflect.TypeOf((*MockITaskRepo)(nil).CreateTaskRun), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateTaskRun", reflect.TypeOf((*MockITaskRepo)(nil).CreateTaskRun), ctx, do) } // DecrTaskCount mocks base method. -func (m *MockITaskRepo) DecrTaskCount(arg0 context.Context, arg1, arg2 int64) error { +func (m *MockITaskRepo) DecrTaskCount(ctx context.Context, taskID, ttl int64) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "DecrTaskCount", arg0, arg1, arg2) + ret := m.ctrl.Call(m, "DecrTaskCount", ctx, taskID, ttl) ret0, _ := ret[0].(error) return ret0 } // DecrTaskCount indicates an expected call of DecrTaskCount. -func (mr *MockITaskRepoMockRecorder) DecrTaskCount(arg0, arg1, arg2 any) *gomock.Call { +func (mr *MockITaskRepoMockRecorder) DecrTaskCount(ctx, taskID, ttl any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DecrTaskCount", reflect.TypeOf((*MockITaskRepo)(nil).DecrTaskCount), arg0, arg1, arg2) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DecrTaskCount", reflect.TypeOf((*MockITaskRepo)(nil).DecrTaskCount), ctx, taskID, ttl) } // DecrTaskRunCount mocks base method. -func (m *MockITaskRepo) DecrTaskRunCount(arg0 context.Context, arg1, arg2, arg3 int64) error { +func (m *MockITaskRepo) DecrTaskRunCount(ctx context.Context, taskID, taskRunID, ttl int64) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "DecrTaskRunCount", arg0, arg1, arg2, arg3) + ret := m.ctrl.Call(m, "DecrTaskRunCount", ctx, taskID, taskRunID, ttl) ret0, _ := ret[0].(error) return ret0 } // DecrTaskRunCount indicates an expected call of DecrTaskRunCount. -func (mr *MockITaskRepoMockRecorder) DecrTaskRunCount(arg0, arg1, arg2, arg3 any) *gomock.Call { +func (mr *MockITaskRepoMockRecorder) DecrTaskRunCount(ctx, taskID, taskRunID, ttl any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DecrTaskRunCount", reflect.TypeOf((*MockITaskRepo)(nil).DecrTaskRunCount), arg0, arg1, arg2, arg3) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DecrTaskRunCount", reflect.TypeOf((*MockITaskRepo)(nil).DecrTaskRunCount), ctx, taskID, taskRunID, ttl) } // DecrTaskRunSuccessCount mocks base method. -func (m *MockITaskRepo) DecrTaskRunSuccessCount(arg0 context.Context, arg1, arg2 int64) error { +func (m *MockITaskRepo) DecrTaskRunSuccessCount(ctx context.Context, taskID, taskRunID int64) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "DecrTaskRunSuccessCount", arg0, arg1, arg2) + ret := m.ctrl.Call(m, "DecrTaskRunSuccessCount", ctx, taskID, taskRunID) ret0, _ := ret[0].(error) return ret0 } // DecrTaskRunSuccessCount indicates an expected call of DecrTaskRunSuccessCount. -func (mr *MockITaskRepoMockRecorder) DecrTaskRunSuccessCount(arg0, arg1, arg2 any) *gomock.Call { +func (mr *MockITaskRepoMockRecorder) DecrTaskRunSuccessCount(ctx, taskID, taskRunID any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DecrTaskRunSuccessCount", reflect.TypeOf((*MockITaskRepo)(nil).DecrTaskRunSuccessCount), arg0, arg1, arg2) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DecrTaskRunSuccessCount", reflect.TypeOf((*MockITaskRepo)(nil).DecrTaskRunSuccessCount), ctx, taskID, taskRunID) } // DeleteTask mocks base method. -func (m *MockITaskRepo) DeleteTask(arg0 context.Context, arg1 *entity.ObservabilityTask) error { +func (m *MockITaskRepo) DeleteTask(ctx context.Context, do *entity.ObservabilityTask) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "DeleteTask", arg0, arg1) + ret := m.ctrl.Call(m, "DeleteTask", ctx, do) ret0, _ := ret[0].(error) return ret0 } // DeleteTask indicates an expected call of DeleteTask. -func (mr *MockITaskRepoMockRecorder) DeleteTask(arg0, arg1 any) *gomock.Call { +func (mr *MockITaskRepoMockRecorder) DeleteTask(ctx, do any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteTask", reflect.TypeOf((*MockITaskRepo)(nil).DeleteTask), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteTask", reflect.TypeOf((*MockITaskRepo)(nil).DeleteTask), ctx, do) } // GetBackfillTaskRun mocks base method. -func (m *MockITaskRepo) GetBackfillTaskRun(arg0 context.Context, arg1 *int64, arg2 int64) (*entity.TaskRun, error) { +func (m *MockITaskRepo) GetBackfillTaskRun(ctx context.Context, workspaceID *int64, taskID int64) (*entity.TaskRun, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetBackfillTaskRun", arg0, arg1, arg2) + ret := m.ctrl.Call(m, "GetBackfillTaskRun", ctx, workspaceID, taskID) ret0, _ := ret[0].(*entity.TaskRun) ret1, _ := ret[1].(error) return ret0, ret1 } // GetBackfillTaskRun indicates an expected call of GetBackfillTaskRun. -func (mr *MockITaskRepoMockRecorder) GetBackfillTaskRun(arg0, arg1, arg2 any) *gomock.Call { +func (mr *MockITaskRepoMockRecorder) GetBackfillTaskRun(ctx, workspaceID, taskID any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetBackfillTaskRun", reflect.TypeOf((*MockITaskRepo)(nil).GetBackfillTaskRun), arg0, arg1, arg2) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetBackfillTaskRun", reflect.TypeOf((*MockITaskRepo)(nil).GetBackfillTaskRun), ctx, workspaceID, taskID) } // GetLatestNewDataTaskRun mocks base method. -func (m *MockITaskRepo) GetLatestNewDataTaskRun(arg0 context.Context, arg1 *int64, arg2 int64) (*entity.TaskRun, error) { +func (m *MockITaskRepo) GetLatestNewDataTaskRun(ctx context.Context, workspaceID *int64, taskID int64) (*entity.TaskRun, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetLatestNewDataTaskRun", arg0, arg1, arg2) + ret := m.ctrl.Call(m, "GetLatestNewDataTaskRun", ctx, workspaceID, taskID) ret0, _ := ret[0].(*entity.TaskRun) ret1, _ := ret[1].(error) return ret0, ret1 } // GetLatestNewDataTaskRun indicates an expected call of GetLatestNewDataTaskRun. -func (mr *MockITaskRepoMockRecorder) GetLatestNewDataTaskRun(arg0, arg1, arg2 any) *gomock.Call { +func (mr *MockITaskRepoMockRecorder) GetLatestNewDataTaskRun(ctx, workspaceID, taskID any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLatestNewDataTaskRun", reflect.TypeOf((*MockITaskRepo)(nil).GetLatestNewDataTaskRun), arg0, arg1, arg2) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLatestNewDataTaskRun", reflect.TypeOf((*MockITaskRepo)(nil).GetLatestNewDataTaskRun), ctx, workspaceID, taskID) } -// GetObjListWithTask mocks base method. -func (m *MockITaskRepo) GetObjListWithTask(arg0 context.Context) ([]string, []string, []*entity.ObservabilityTask) { +// GetTask mocks base method. +func (m *MockITaskRepo) GetTask(ctx context.Context, id int64, workspaceID *int64, userID *string) (*entity.ObservabilityTask, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetObjListWithTask", arg0) - ret0, _ := ret[0].([]string) - ret1, _ := ret[1].([]string) - ret2, _ := ret[2].([]*entity.ObservabilityTask) - return ret0, ret1, ret2 + ret := m.ctrl.Call(m, "GetTask", ctx, id, workspaceID, userID) + ret0, _ := ret[0].(*entity.ObservabilityTask) + ret1, _ := ret[1].(error) + return ret0, ret1 } -// GetObjListWithTask indicates an expected call of GetObjListWithTask. -func (mr *MockITaskRepoMockRecorder) GetObjListWithTask(arg0 any) *gomock.Call { +// GetTask indicates an expected call of GetTask. +func (mr *MockITaskRepoMockRecorder) GetTask(ctx, id, workspaceID, userID any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetObjListWithTask", reflect.TypeOf((*MockITaskRepo)(nil).GetObjListWithTask), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTask", reflect.TypeOf((*MockITaskRepo)(nil).GetTask), ctx, id, workspaceID, userID) } -// GetTask mocks base method. -func (m *MockITaskRepo) GetTask(arg0 context.Context, arg1 int64, arg2 *int64, arg3 *string) (*entity.ObservabilityTask, error) { +// GetTaskByCache mocks base method. +func (m *MockITaskRepo) GetTaskByCache(ctx context.Context, taskID int64) (*entity.ObservabilityTask, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetTask", arg0, arg1, arg2, arg3) + ret := m.ctrl.Call(m, "GetTaskByCache", ctx, taskID) ret0, _ := ret[0].(*entity.ObservabilityTask) ret1, _ := ret[1].(error) return ret0, ret1 } -// GetTask indicates an expected call of GetTask. -func (mr *MockITaskRepoMockRecorder) GetTask(arg0, arg1, arg2, arg3 any) *gomock.Call { +// GetTaskByCache indicates an expected call of GetTaskByCache. +func (mr *MockITaskRepoMockRecorder) GetTaskByCache(ctx, taskID any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTask", reflect.TypeOf((*MockITaskRepo)(nil).GetTask), arg0, arg1, arg2, arg3) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTaskByCache", reflect.TypeOf((*MockITaskRepo)(nil).GetTaskByCache), ctx, taskID) } // GetTaskCount mocks base method. -func (m *MockITaskRepo) GetTaskCount(arg0 context.Context, arg1 int64) (int64, error) { +func (m *MockITaskRepo) GetTaskCount(ctx context.Context, taskID int64) (int64, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetTaskCount", arg0, arg1) + ret := m.ctrl.Call(m, "GetTaskCount", ctx, taskID) ret0, _ := ret[0].(int64) ret1, _ := ret[1].(error) return ret0, ret1 } // GetTaskCount indicates an expected call of GetTaskCount. -func (mr *MockITaskRepoMockRecorder) GetTaskCount(arg0, arg1 any) *gomock.Call { +func (mr *MockITaskRepoMockRecorder) GetTaskCount(ctx, taskID any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTaskCount", reflect.TypeOf((*MockITaskRepo)(nil).GetTaskCount), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTaskCount", reflect.TypeOf((*MockITaskRepo)(nil).GetTaskCount), ctx, taskID) } // GetTaskRunCount mocks base method. -func (m *MockITaskRepo) GetTaskRunCount(arg0 context.Context, arg1, arg2 int64) (int64, error) { +func (m *MockITaskRepo) GetTaskRunCount(ctx context.Context, taskID, taskRunID int64) (int64, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetTaskRunCount", arg0, arg1, arg2) + ret := m.ctrl.Call(m, "GetTaskRunCount", ctx, taskID, taskRunID) ret0, _ := ret[0].(int64) ret1, _ := ret[1].(error) return ret0, ret1 } // GetTaskRunCount indicates an expected call of GetTaskRunCount. -func (mr *MockITaskRepoMockRecorder) GetTaskRunCount(arg0, arg1, arg2 any) *gomock.Call { +func (mr *MockITaskRepoMockRecorder) GetTaskRunCount(ctx, taskID, taskRunID any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTaskRunCount", reflect.TypeOf((*MockITaskRepo)(nil).GetTaskRunCount), arg0, arg1, arg2) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTaskRunCount", reflect.TypeOf((*MockITaskRepo)(nil).GetTaskRunCount), ctx, taskID, taskRunID) } // GetTaskRunFailCount mocks base method. -func (m *MockITaskRepo) GetTaskRunFailCount(arg0 context.Context, arg1, arg2 int64) (int64, error) { +func (m *MockITaskRepo) GetTaskRunFailCount(ctx context.Context, taskID, taskRunID int64) (int64, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetTaskRunFailCount", arg0, arg1, arg2) + ret := m.ctrl.Call(m, "GetTaskRunFailCount", ctx, taskID, taskRunID) ret0, _ := ret[0].(int64) ret1, _ := ret[1].(error) return ret0, ret1 } // GetTaskRunFailCount indicates an expected call of GetTaskRunFailCount. -func (mr *MockITaskRepoMockRecorder) GetTaskRunFailCount(arg0, arg1, arg2 any) *gomock.Call { +func (mr *MockITaskRepoMockRecorder) GetTaskRunFailCount(ctx, taskID, taskRunID any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTaskRunFailCount", reflect.TypeOf((*MockITaskRepo)(nil).GetTaskRunFailCount), arg0, arg1, arg2) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTaskRunFailCount", reflect.TypeOf((*MockITaskRepo)(nil).GetTaskRunFailCount), ctx, taskID, taskRunID) } // GetTaskRunSuccessCount mocks base method. -func (m *MockITaskRepo) GetTaskRunSuccessCount(arg0 context.Context, arg1, arg2 int64) (int64, error) { +func (m *MockITaskRepo) GetTaskRunSuccessCount(ctx context.Context, taskID, taskRunID int64) (int64, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetTaskRunSuccessCount", arg0, arg1, arg2) + ret := m.ctrl.Call(m, "GetTaskRunSuccessCount", ctx, taskID, taskRunID) ret0, _ := ret[0].(int64) ret1, _ := ret[1].(error) return ret0, ret1 } // GetTaskRunSuccessCount indicates an expected call of GetTaskRunSuccessCount. -func (mr *MockITaskRepoMockRecorder) GetTaskRunSuccessCount(arg0, arg1, arg2 any) *gomock.Call { +func (mr *MockITaskRepoMockRecorder) GetTaskRunSuccessCount(ctx, taskID, taskRunID any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTaskRunSuccessCount", reflect.TypeOf((*MockITaskRepo)(nil).GetTaskRunSuccessCount), arg0, arg1, arg2) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTaskRunSuccessCount", reflect.TypeOf((*MockITaskRepo)(nil).GetTaskRunSuccessCount), ctx, taskID, taskRunID) } // IncrTaskCount mocks base method. -func (m *MockITaskRepo) IncrTaskCount(arg0 context.Context, arg1, arg2 int64) error { +func (m *MockITaskRepo) IncrTaskCount(ctx context.Context, taskID, ttl int64) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "IncrTaskCount", arg0, arg1, arg2) + ret := m.ctrl.Call(m, "IncrTaskCount", ctx, taskID, ttl) ret0, _ := ret[0].(error) return ret0 } // IncrTaskCount indicates an expected call of IncrTaskCount. -func (mr *MockITaskRepoMockRecorder) IncrTaskCount(arg0, arg1, arg2 any) *gomock.Call { +func (mr *MockITaskRepoMockRecorder) IncrTaskCount(ctx, taskID, ttl any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IncrTaskCount", reflect.TypeOf((*MockITaskRepo)(nil).IncrTaskCount), arg0, arg1, arg2) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IncrTaskCount", reflect.TypeOf((*MockITaskRepo)(nil).IncrTaskCount), ctx, taskID, ttl) } // IncrTaskRunCount mocks base method. -func (m *MockITaskRepo) IncrTaskRunCount(arg0 context.Context, arg1, arg2, arg3 int64) error { +func (m *MockITaskRepo) IncrTaskRunCount(ctx context.Context, taskID, taskRunID, ttl int64) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "IncrTaskRunCount", arg0, arg1, arg2, arg3) + ret := m.ctrl.Call(m, "IncrTaskRunCount", ctx, taskID, taskRunID, ttl) ret0, _ := ret[0].(error) return ret0 } // IncrTaskRunCount indicates an expected call of IncrTaskRunCount. -func (mr *MockITaskRepoMockRecorder) IncrTaskRunCount(arg0, arg1, arg2, arg3 any) *gomock.Call { +func (mr *MockITaskRepoMockRecorder) IncrTaskRunCount(ctx, taskID, taskRunID, ttl any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IncrTaskRunCount", reflect.TypeOf((*MockITaskRepo)(nil).IncrTaskRunCount), arg0, arg1, arg2, arg3) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IncrTaskRunCount", reflect.TypeOf((*MockITaskRepo)(nil).IncrTaskRunCount), ctx, taskID, taskRunID, ttl) } // IncrTaskRunFailCount mocks base method. -func (m *MockITaskRepo) IncrTaskRunFailCount(arg0 context.Context, arg1, arg2, arg3 int64) error { +func (m *MockITaskRepo) IncrTaskRunFailCount(ctx context.Context, taskID, taskRunID, ttl int64) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "IncrTaskRunFailCount", arg0, arg1, arg2, arg3) + ret := m.ctrl.Call(m, "IncrTaskRunFailCount", ctx, taskID, taskRunID, ttl) ret0, _ := ret[0].(error) return ret0 } // IncrTaskRunFailCount indicates an expected call of IncrTaskRunFailCount. -func (mr *MockITaskRepoMockRecorder) IncrTaskRunFailCount(arg0, arg1, arg2, arg3 any) *gomock.Call { +func (mr *MockITaskRepoMockRecorder) IncrTaskRunFailCount(ctx, taskID, taskRunID, ttl any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IncrTaskRunFailCount", reflect.TypeOf((*MockITaskRepo)(nil).IncrTaskRunFailCount), arg0, arg1, arg2, arg3) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IncrTaskRunFailCount", reflect.TypeOf((*MockITaskRepo)(nil).IncrTaskRunFailCount), ctx, taskID, taskRunID, ttl) } // IncrTaskRunSuccessCount mocks base method. -func (m *MockITaskRepo) IncrTaskRunSuccessCount(arg0 context.Context, arg1, arg2, arg3 int64) error { +func (m *MockITaskRepo) IncrTaskRunSuccessCount(ctx context.Context, taskID, taskRunID, ttl int64) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "IncrTaskRunSuccessCount", arg0, arg1, arg2, arg3) + ret := m.ctrl.Call(m, "IncrTaskRunSuccessCount", ctx, taskID, taskRunID, ttl) ret0, _ := ret[0].(error) return ret0 } // IncrTaskRunSuccessCount indicates an expected call of IncrTaskRunSuccessCount. -func (mr *MockITaskRepoMockRecorder) IncrTaskRunSuccessCount(arg0, arg1, arg2, arg3 any) *gomock.Call { +func (mr *MockITaskRepoMockRecorder) IncrTaskRunSuccessCount(ctx, taskID, taskRunID, ttl any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IncrTaskRunSuccessCount", reflect.TypeOf((*MockITaskRepo)(nil).IncrTaskRunSuccessCount), arg0, arg1, arg2, arg3) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IncrTaskRunSuccessCount", reflect.TypeOf((*MockITaskRepo)(nil).IncrTaskRunSuccessCount), ctx, taskID, taskRunID, ttl) } -// ListNonFinalTask mocks base method. -func (m *MockITaskRepo) ListNonFinalTask(arg0 context.Context, arg1 string) ([]int64, error) { +// ListNonFinalTaskBySpaceID mocks base method. +func (m *MockITaskRepo) ListNonFinalTaskBySpaceID(ctx context.Context, spaceID string) ([]int64, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ListNonFinalTask", arg0, arg1) + ret := m.ctrl.Call(m, "ListNonFinalTaskBySpaceID", ctx, spaceID) ret0, _ := ret[0].([]int64) ret1, _ := ret[1].(error) return ret0, ret1 } -// ListNonFinalTask indicates an expected call of ListNonFinalTask. -func (mr *MockITaskRepoMockRecorder) ListNonFinalTask(arg0, arg1 any) *gomock.Call { +// ListNonFinalTaskBySpaceID indicates an expected call of ListNonFinalTaskBySpaceID. +func (mr *MockITaskRepoMockRecorder) ListNonFinalTaskBySpaceID(ctx, spaceID any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListNonFinalTask", reflect.TypeOf((*MockITaskRepo)(nil).ListNonFinalTask), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListNonFinalTaskBySpaceID", reflect.TypeOf((*MockITaskRepo)(nil).ListNonFinalTaskBySpaceID), ctx, spaceID) } -// ListTasks mocks base method. -func (m *MockITaskRepo) ListTasks(arg0 context.Context, arg1 mysql.ListTaskParam) ([]*entity.ObservabilityTask, int64, error) { +// ListNonFinalTasks mocks base method. +func (m *MockITaskRepo) ListNonFinalTasks(ctx context.Context) ([]*entity.ObservabilityTask, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ListTasks", arg0, arg1) + ret := m.ctrl.Call(m, "ListNonFinalTasks", ctx) ret0, _ := ret[0].([]*entity.ObservabilityTask) - ret1, _ := ret[1].(int64) - ret2, _ := ret[2].(error) - return ret0, ret1, ret2 + ret1, _ := ret[1].(error) + return ret0, ret1 } -// ListTasks indicates an expected call of ListTasks. -func (mr *MockITaskRepoMockRecorder) ListTasks(arg0, arg1 any) *gomock.Call { +// ListNonFinalTasks indicates an expected call of ListNonFinalTasks. +func (mr *MockITaskRepoMockRecorder) ListNonFinalTasks(ctx any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListTasks", reflect.TypeOf((*MockITaskRepo)(nil).ListTasks), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListNonFinalTasks", reflect.TypeOf((*MockITaskRepo)(nil).ListNonFinalTasks), ctx) } -// AddNonFinalTask mocks base method. -func (m *MockITaskRepo) AddNonFinalTask(arg0 context.Context, arg1 string, arg2 int64) error { +// ListTasks mocks base method. +func (m *MockITaskRepo) ListTasks(ctx context.Context, param repo.ListTaskParam) ([]*entity.ObservabilityTask, int64, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "AddNonFinalTask", arg0, arg1, arg2) - ret0, _ := ret[0].(error) - return ret0 + ret := m.ctrl.Call(m, "ListTasks", ctx, param) + ret0, _ := ret[0].([]*entity.ObservabilityTask) + ret1, _ := ret[1].(int64) + ret2, _ := ret[2].(error) + return ret0, ret1, ret2 } -// AddNonFinalTask indicates an expected call of AddNonFinalTask. -func (mr *MockITaskRepoMockRecorder) AddNonFinalTask(arg0, arg1, arg2 any) *gomock.Call { +// ListTasks indicates an expected call of ListTasks. +func (mr *MockITaskRepoMockRecorder) ListTasks(ctx, param any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddNonFinalTask", reflect.TypeOf((*MockITaskRepo)(nil).AddNonFinalTask), arg0, arg1, arg2) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListTasks", reflect.TypeOf((*MockITaskRepo)(nil).ListTasks), ctx, param) } // RemoveNonFinalTask mocks base method. -func (m *MockITaskRepo) RemoveNonFinalTask(arg0 context.Context, arg1 string, arg2 int64) error { +func (m *MockITaskRepo) RemoveNonFinalTask(ctx context.Context, spaceID string, taskID int64) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "RemoveNonFinalTask", arg0, arg1, arg2) + ret := m.ctrl.Call(m, "RemoveNonFinalTask", ctx, spaceID, taskID) ret0, _ := ret[0].(error) return ret0 } // RemoveNonFinalTask indicates an expected call of RemoveNonFinalTask. -func (mr *MockITaskRepoMockRecorder) RemoveNonFinalTask(arg0, arg1, arg2 any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveNonFinalTask", reflect.TypeOf((*MockITaskRepo)(nil).RemoveNonFinalTask), arg0, arg1, arg2) -} - -// GetTaskByRedis mocks base method. -func (m *MockITaskRepo) GetTaskByRedis(arg0 context.Context, arg1 int64) (*entity.ObservabilityTask, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetTaskByRedis", arg0, arg1) - ret0, _ := ret[0].(*entity.ObservabilityTask) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetTaskByRedis indicates an expected call of GetTaskByRedis. -func (mr *MockITaskRepoMockRecorder) GetTaskByRedis(arg0, arg1 any) *gomock.Call { +func (mr *MockITaskRepoMockRecorder) RemoveNonFinalTask(ctx, spaceID, taskID any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTaskByRedis", reflect.TypeOf((*MockITaskRepo)(nil).GetTaskByRedis), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveNonFinalTask", reflect.TypeOf((*MockITaskRepo)(nil).RemoveNonFinalTask), ctx, spaceID, taskID) } // UpdateTask mocks base method. -func (m *MockITaskRepo) UpdateTask(arg0 context.Context, arg1 *entity.ObservabilityTask) error { +func (m *MockITaskRepo) UpdateTask(ctx context.Context, do *entity.ObservabilityTask) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "UpdateTask", arg0, arg1) + ret := m.ctrl.Call(m, "UpdateTask", ctx, do) ret0, _ := ret[0].(error) return ret0 } // UpdateTask indicates an expected call of UpdateTask. -func (mr *MockITaskRepoMockRecorder) UpdateTask(arg0, arg1 any) *gomock.Call { +func (mr *MockITaskRepoMockRecorder) UpdateTask(ctx, do any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateTask", reflect.TypeOf((*MockITaskRepo)(nil).UpdateTask), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateTask", reflect.TypeOf((*MockITaskRepo)(nil).UpdateTask), ctx, do) } // UpdateTaskRun mocks base method. -func (m *MockITaskRepo) UpdateTaskRun(arg0 context.Context, arg1 *entity.TaskRun) error { +func (m *MockITaskRepo) UpdateTaskRun(ctx context.Context, do *entity.TaskRun) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "UpdateTaskRun", arg0, arg1) + ret := m.ctrl.Call(m, "UpdateTaskRun", ctx, do) ret0, _ := ret[0].(error) return ret0 } // UpdateTaskRun indicates an expected call of UpdateTaskRun. -func (mr *MockITaskRepoMockRecorder) UpdateTaskRun(arg0, arg1 any) *gomock.Call { +func (mr *MockITaskRepoMockRecorder) UpdateTaskRun(ctx, do any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateTaskRun", reflect.TypeOf((*MockITaskRepo)(nil).UpdateTaskRun), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateTaskRun", reflect.TypeOf((*MockITaskRepo)(nil).UpdateTaskRun), ctx, do) } // UpdateTaskRunWithOCC mocks base method. -func (m *MockITaskRepo) UpdateTaskRunWithOCC(arg0 context.Context, arg1, arg2 int64, arg3 map[string]any) error { +func (m *MockITaskRepo) UpdateTaskRunWithOCC(ctx context.Context, id, workspaceID int64, updateMap map[string]any) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "UpdateTaskRunWithOCC", arg0, arg1, arg2, arg3) + ret := m.ctrl.Call(m, "UpdateTaskRunWithOCC", ctx, id, workspaceID, updateMap) ret0, _ := ret[0].(error) return ret0 } // UpdateTaskRunWithOCC indicates an expected call of UpdateTaskRunWithOCC. -func (mr *MockITaskRepoMockRecorder) UpdateTaskRunWithOCC(arg0, arg1, arg2, arg3 any) *gomock.Call { +func (mr *MockITaskRepoMockRecorder) UpdateTaskRunWithOCC(ctx, id, workspaceID, updateMap any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateTaskRunWithOCC", reflect.TypeOf((*MockITaskRepo)(nil).UpdateTaskRunWithOCC), arg0, arg1, arg2, arg3) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateTaskRunWithOCC", reflect.TypeOf((*MockITaskRepo)(nil).UpdateTaskRunWithOCC), ctx, id, workspaceID, updateMap) } // UpdateTaskWithOCC mocks base method. -func (m *MockITaskRepo) UpdateTaskWithOCC(arg0 context.Context, arg1, arg2 int64, arg3 map[string]any) error { +func (m *MockITaskRepo) UpdateTaskWithOCC(ctx context.Context, id, workspaceID int64, updateMap map[string]any) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "UpdateTaskWithOCC", arg0, arg1, arg2, arg3) + ret := m.ctrl.Call(m, "UpdateTaskWithOCC", ctx, id, workspaceID, updateMap) ret0, _ := ret[0].(error) return ret0 } // UpdateTaskWithOCC indicates an expected call of UpdateTaskWithOCC. -func (mr *MockITaskRepoMockRecorder) UpdateTaskWithOCC(arg0, arg1, arg2, arg3 any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateTaskWithOCC", reflect.TypeOf((*MockITaskRepo)(nil).UpdateTaskWithOCC), arg0, arg1, arg2, arg3) -} - -// SetTask mocks base method. -func (m *MockITaskRepo) SetTask(arg0 context.Context, arg1 *entity.ObservabilityTask) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "SetTask", arg0, arg1) - ret0, _ := ret[0].(error) - return ret0 -} - -// SetTask indicates an expected call of SetTask. -func (mr *MockITaskRepoMockRecorder) SetTask(arg0, arg1 any) *gomock.Call { +func (mr *MockITaskRepoMockRecorder) UpdateTaskWithOCC(ctx, id, workspaceID, updateMap any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetTask", reflect.TypeOf((*MockITaskRepo)(nil).SetTask), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateTaskWithOCC", reflect.TypeOf((*MockITaskRepo)(nil).UpdateTaskWithOCC), ctx, id, workspaceID, updateMap) } diff --git a/backend/modules/observability/domain/task/repo/task.go b/backend/modules/observability/domain/task/repo/task.go index 775443549..e8e2854d6 100644 --- a/backend/modules/observability/domain/task/repo/task.go +++ b/backend/modules/observability/domain/task/repo/task.go @@ -7,9 +7,17 @@ import ( "context" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/entity" - "github.com/coze-dev/coze-loop/backend/modules/observability/infra/repo/mysql" + "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/entity/common" ) +type ListTaskParam struct { + WorkspaceIDs []int64 + TaskFilters *entity.TaskFilterFields + ReqLimit int32 + ReqOffset int32 + OrderBy *common.OrderBy +} + //go:generate mockgen -destination=mocks/Task.go -package=mocks . ITaskRepo type ITaskRepo interface { // task @@ -17,8 +25,10 @@ type ITaskRepo interface { UpdateTask(ctx context.Context, do *entity.ObservabilityTask) error UpdateTaskWithOCC(ctx context.Context, id int64, workspaceID int64, updateMap map[string]interface{}) error GetTask(ctx context.Context, id int64, workspaceID *int64, userID *string) (*entity.ObservabilityTask, error) - ListTasks(ctx context.Context, param mysql.ListTaskParam) ([]*entity.ObservabilityTask, int64, error) + ListTasks(ctx context.Context, param ListTaskParam) ([]*entity.ObservabilityTask, int64, error) DeleteTask(ctx context.Context, do *entity.ObservabilityTask) error + // ListNonFinalTasks Only return Task without TaskRun + ListNonFinalTasks(ctx context.Context) ([]*entity.ObservabilityTask, error) // task run CreateTaskRun(ctx context.Context, do *entity.TaskRun) (int64, error) @@ -41,16 +51,13 @@ type ITaskRepo interface { GetTaskRunSuccessCount(ctx context.Context, taskID, taskRunID int64) (int64, error) IncrTaskRunSuccessCount(ctx context.Context, taskID, taskRunID int64, ttl int64) error DecrTaskRunSuccessCount(ctx context.Context, taskID, taskRunID int64) error - IncrTaskRunFailCount(ctx context.Context, taskID, taskRunID int64, ttl int64) error GetTaskRunFailCount(ctx context.Context, taskID, taskRunID int64) (int64, error) - - GetObjListWithTask(ctx context.Context) ([]string, []string, []*entity.ObservabilityTask) + IncrTaskRunFailCount(ctx context.Context, taskID, taskRunID int64, ttl int64) error // 非终态task列表by spaceID - ListNonFinalTask(ctx context.Context, spaceID string) ([]int64, error) + ListNonFinalTaskBySpaceID(ctx context.Context, spaceID string) ([]int64, error) AddNonFinalTask(ctx context.Context, spaceID string, taskID int64) error RemoveNonFinalTask(ctx context.Context, spaceID string, taskID int64) error - GetTaskByRedis(ctx context.Context, taskID int64) (*entity.ObservabilityTask, error) - SetTask(ctx context.Context, task *entity.ObservabilityTask) error + GetTaskByCache(ctx context.Context, taskID int64) (*entity.ObservabilityTask, error) } diff --git a/backend/modules/observability/domain/task/service/task_service.go b/backend/modules/observability/domain/task/service/task_service.go index 916370fa7..870b75b84 100644 --- a/backend/modules/observability/domain/task/service/task_service.go +++ b/backend/modules/observability/domain/task/service/task_service.go @@ -12,17 +12,15 @@ import ( "github.com/bytedance/gg/gptr" "github.com/coze-dev/coze-loop/backend/infra/idgen" "github.com/coze-dev/coze-loop/backend/infra/middleware/session" - "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/observability/domain/common" - "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/observability/domain/filter" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/component/mq" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/entity" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/repo" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/service/taskexe" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/service/taskexe/processor" + "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/entity/common" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/entity/loop_span" traceservice "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/service" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/service/trace/span_filter" - "github.com/coze-dev/coze-loop/backend/modules/observability/infra/repo/mysql" obErrorx "github.com/coze-dev/coze-loop/backend/modules/observability/pkg/errno" "github.com/coze-dev/coze-loop/backend/pkg/errorx" "github.com/coze-dev/coze-loop/backend/pkg/logs" @@ -44,7 +42,7 @@ type UpdateTaskReq struct { } type ListTasksReq struct { WorkspaceID int64 - TaskFilters *filter.TaskFilterFields + TaskFilters *entity.TaskFilterFields Limit int32 Offset int32 OrderBy *common.OrderBy @@ -146,18 +144,16 @@ func (t *TaskServiceImpl) CreateTask(ctx context.Context, req *CreateTaskReq) (r } // 历史回溯数据发MQ - if t.shouldTriggerBackfill(taskDO) { + if taskDO.ShouldTriggerBackfill() { backfillEvent := &entity.BackFillEvent{ SpaceID: taskDO.WorkspaceID, TaskID: id, } - // 异步发送MQ消息,不阻塞任务创建流程 - go func() { - if err := t.sendBackfillMessage(context.Background(), backfillEvent); err != nil { - logs.CtxWarn(ctx, "send backfill message failed, task_id=%d, err=%v", id, err) - } - }() + if err := t.sendBackfillMessage(context.Background(), backfillEvent); err != nil { + // 失败了会有定时任务进行补偿 + logs.CtxWarn(ctx, "send backfill message failed, task_id=%d, err=%v", id, err) + } } return &CreateTaskResp{TaskID: &id}, nil @@ -169,7 +165,7 @@ func (t *TaskServiceImpl) UpdateTask(ctx context.Context, req *UpdateTaskReq) (e return err } if taskDO == nil { - logs.CtxError(ctx, "task not found") + logs.CtxError(ctx, "task [%d] not found", req.TaskID) return errorx.NewByCode(obErrorx.CommercialCommonInvalidParamCodeCode, errorx.WithExtraMsg("task not found")) } userID := session.UserIDInCtxOrEmpty(ctx) @@ -228,7 +224,7 @@ func (t *TaskServiceImpl) UpdateTask(ctx context.Context, req *UpdateTaskReq) (e } func (t *TaskServiceImpl) ListTasks(ctx context.Context, req *ListTasksReq) (resp *ListTasksResp, err error) { - taskDOs, total, err := t.TaskRepo.ListTasks(ctx, mysql.ListTaskParam{ + taskDOs, total, err := t.TaskRepo.ListTasks(ctx, repo.ListTaskParam{ WorkspaceIDs: []int64{req.WorkspaceID}, TaskFilters: req.TaskFilters, ReqLimit: req.Limit, @@ -322,15 +318,15 @@ func filterVisibleFilterFields(fields *loop_span.FilterFields) *loop_span.Filter } func (t *TaskServiceImpl) CheckTaskName(ctx context.Context, req *CheckTaskNameReq) (resp *CheckTaskNameResp, err error) { - taskPOs, _, err := t.TaskRepo.ListTasks(ctx, mysql.ListTaskParam{ + taskPOs, _, err := t.TaskRepo.ListTasks(ctx, repo.ListTaskParam{ WorkspaceIDs: []int64{req.WorkspaceID}, - TaskFilters: &filter.TaskFilterFields{ - FilterFields: []*filter.TaskFilterField{ + TaskFilters: &entity.TaskFilterFields{ + FilterFields: []*entity.TaskFilterField{ { - FieldName: gptr.Of(filter.TaskFieldNameTaskName), - FieldType: gptr.Of(filter.FieldTypeString), + FieldName: gptr.Of(entity.TaskFieldNameTaskName), + FieldType: gptr.Of(entity.FieldTypeString), Values: []string{req.Name}, - QueryType: gptr.Of(filter.QueryTypeMatch), + QueryType: gptr.Of(entity.QueryTypeMatch), }, }, }, @@ -350,25 +346,6 @@ func (t *TaskServiceImpl) CheckTaskName(ctx context.Context, req *CheckTaskNameR return &CheckTaskNameResp{Pass: gptr.Of(pass)}, nil } -// shouldTriggerBackfill 判断是否需要发送历史回溯MQ -func (t *TaskServiceImpl) shouldTriggerBackfill(taskDO *entity.ObservabilityTask) bool { - // 检查任务类型 - taskType := taskDO.TaskType - if taskType != entity.TaskTypeAutoEval && taskType != entity.TaskTypeAutoDataReflow { - return false - } - - // 检查回填时间配置 - - if taskDO.BackfillEffectiveTime == nil { - return false - } - - return taskDO.BackfillEffectiveTime.StartAt > 0 && - taskDO.BackfillEffectiveTime.EndAt > 0 && - taskDO.BackfillEffectiveTime.StartAt < taskDO.BackfillEffectiveTime.EndAt -} - // sendBackfillMessage 发送MQ消息 func (t *TaskServiceImpl) sendBackfillMessage(ctx context.Context, event *entity.BackFillEvent) error { if t.backfillProducer == nil { diff --git a/backend/modules/observability/domain/task/service/task_service_test.go b/backend/modules/observability/domain/task/service/task_service_test.go index 0a1751154..8049530a3 100755 --- a/backend/modules/observability/domain/task/service/task_service_test.go +++ b/backend/modules/observability/domain/task/service/task_service_test.go @@ -14,7 +14,6 @@ import ( "go.uber.org/mock/gomock" "github.com/coze-dev/coze-loop/backend/infra/middleware/session" - "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/observability/domain/filter" "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/observability/domain/task" componentmq "github.com/coze-dev/coze-loop/backend/modules/observability/domain/component/mq" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/entity" @@ -22,9 +21,9 @@ import ( repomocks "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/repo/mocks" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/service/taskexe" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/service/taskexe/processor" - loop_span "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/entity/loop_span" + "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/entity/loop_span" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/service/trace/span_filter" - span_processor "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/service/trace/span_processor" + "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/service/trace/span_processor" obErrorx "github.com/coze-dev/coze-loop/backend/modules/observability/pkg/errno" "github.com/coze-dev/coze-loop/backend/pkg/errorx" ) @@ -463,7 +462,7 @@ func TestTaskServiceImpl_ListTasks(t *testing.T) { repoMock.EXPECT().ListTasks(gomock.Any(), gomock.Any()).Return([]*entity.ObservabilityTask{taskDO}, int64(1), nil) svc := &TaskServiceImpl{TaskRepo: repoMock} - resp, err := svc.ListTasks(context.Background(), &ListTasksReq{WorkspaceID: 2, TaskFilters: &filter.TaskFilterFields{}}) + resp, err := svc.ListTasks(context.Background(), &ListTasksReq{WorkspaceID: 2, TaskFilters: &entity.TaskFilterFields{}}) assert.NoError(t, err) if assert.NotNil(t, resp) { assert.EqualValues(t, 1, resp.Total) diff --git a/backend/modules/observability/domain/task/service/taskexe/processor/auto_evaluate.go b/backend/modules/observability/domain/task/service/taskexe/processor/auto_evaluate.go index 44c0a52d2..2bff52aad 100644 --- a/backend/modules/observability/domain/task/service/taskexe/processor/auto_evaluate.go +++ b/backend/modules/observability/domain/task/service/taskexe/processor/auto_evaluate.go @@ -165,7 +165,7 @@ func (p *AutoEvaluteProcessor) OnCreateTaskChange(ctx context.Context, currentTa if ShouldTriggerBackfill(currentTask) && taskRuns == nil { err = p.OnCreateTaskRunChange(ctx, taskexe.OnCreateTaskRunChangeReq{ CurrentTask: currentTask, - RunType: task.TaskRunTypeBackFill, + RunType: task_entity.TaskRunTypeBackFill, RunStartAt: time.Now().UnixMilli(), RunEndAt: time.Now().UnixMilli() + (currentTask.BackfillEffectiveTime.EndAt - currentTask.BackfillEffectiveTime.StartAt), }) @@ -196,7 +196,7 @@ func (p *AutoEvaluteProcessor) OnCreateTaskChange(ctx context.Context, currentTa } err = p.OnCreateTaskRunChange(ctx, taskexe.OnCreateTaskRunChangeReq{ CurrentTask: currentTask, - RunType: task.TaskRunTypeNewData, + RunType: task_entity.TaskRunTypeNewData, RunStartAt: runStartAt, RunEndAt: runEndAt, }) @@ -213,21 +213,21 @@ func (p *AutoEvaluteProcessor) OnCreateTaskChange(ctx context.Context, currentTa return nil } -func (p *AutoEvaluteProcessor) OnUpdateTaskChange(ctx context.Context, currentTask *task_entity.ObservabilityTask, taskOp task.TaskStatus) error { +func (p *AutoEvaluteProcessor) OnUpdateTaskChange(ctx context.Context, currentTask *task_entity.ObservabilityTask, taskOp task_entity.TaskStatus) error { switch taskOp { - case task.TaskStatusSuccess: + case task_entity.TaskStatusSuccess: if currentTask.TaskStatus != task_entity.TaskStatusDisabled { currentTask.TaskStatus = task_entity.TaskStatusSuccess } - case task.TaskStatusRunning: + case task_entity.TaskStatusRunning: if currentTask.TaskStatus != task_entity.TaskStatusDisabled && currentTask.TaskStatus != task_entity.TaskStatusSuccess { currentTask.TaskStatus = task_entity.TaskStatusRunning } - case task.TaskStatusDisabled: + case task_entity.TaskStatusDisabled: if currentTask.TaskStatus != task_entity.TaskStatusDisabled { currentTask.TaskStatus = task_entity.TaskStatusDisabled } - case task.TaskStatusPending: + case task_entity.TaskStatusPending: if currentTask.TaskStatus == task_entity.TaskStatusPending || currentTask.TaskStatus == task_entity.TaskStatusUnstarted { currentTask.TaskStatus = task_entity.TaskStatusPending } @@ -319,7 +319,7 @@ func (p *AutoEvaluteProcessor) OnCreateTaskRunChange(ctx context.Context, param schema := convertDatasetSchemaDTO2DO(evaluationSetSchema) logs.CtxInfo(ctx, "[auto_task] CreateDataset,category:%s", category) var datasetName, exptName string - if param.RunType == task.TaskRunTypeBackFill { + if param.RunType == task_entity.TaskRunTypeBackFill { datasetName = fmt.Sprintf("%s_%s_%s_%d.%d.%d.%d", AutoEvaluateCN, BackFillCN, currentTask.Name, time.Now().Year(), time.Now().Month(), time.Now().Day(), time.Now().Unix()) exptName = fmt.Sprintf("%s_%s_%s_%d.%d.%d.%d", AutoEvaluateCN, BackFillCN, currentTask.Name, time.Now().Year(), time.Now().Month(), time.Now().Day(), time.Now().Unix()) } else { @@ -394,7 +394,7 @@ func (p *AutoEvaluteProcessor) OnCreateTaskRunChange(ctx context.Context, param taskRun := &task_entity.TaskRun{ TaskID: currentTask.ID, WorkspaceID: currentTask.WorkspaceID, - TaskType: task_entity.TaskRunType(param.RunType), + TaskType: param.RunType, RunStatus: task_entity.TaskRunStatusRunning, RunStartAt: time.UnixMilli(param.RunStartAt), RunEndAt: time.UnixMilli(param.RunEndAt), diff --git a/backend/modules/observability/domain/task/service/taskexe/processor/auto_evaluate_test.go b/backend/modules/observability/domain/task/service/taskexe/processor/auto_evaluate_test.go index fe2dab656..cbedbe2c9 100755 --- a/backend/modules/observability/domain/task/service/taskexe/processor/auto_evaluate_test.go +++ b/backend/modules/observability/domain/task/service/taskexe/processor/auto_evaluate_test.go @@ -108,7 +108,7 @@ func (m *taskRepoMockAdapter) RemoveNonFinalTask(context.Context, string, int64) return nil } -func (m *taskRepoMockAdapter) GetTaskByRedis(context.Context, int64) (*taskentity.ObservabilityTask, error) { +func (m *taskRepoMockAdapter) GetTaskByCache(context.Context, int64) (*taskentity.ObservabilityTask, error) { return nil, nil } @@ -432,13 +432,13 @@ func TestAutoEvaluteProcessor_OnUpdateTaskChange(t *testing.T) { cases := []struct { name string initial taskentity.TaskStatus - op task.TaskStatus + op taskentity.TaskStatus expect taskentity.TaskStatus }{ - {"success", taskentity.TaskStatusRunning, task.TaskStatusSuccess, taskentity.TaskStatusSuccess}, - {"running", taskentity.TaskStatusPending, task.TaskStatusRunning, taskentity.TaskStatusRunning}, - {"disable", taskentity.TaskStatusRunning, task.TaskStatusDisabled, taskentity.TaskStatusDisabled}, - {"pending", taskentity.TaskStatusUnstarted, task.TaskStatusPending, taskentity.TaskStatusPending}, + {"success", taskentity.TaskStatusRunning, taskentity.TaskStatusSuccess, taskentity.TaskStatusSuccess}, + {"running", taskentity.TaskStatusPending, taskentity.TaskStatusRunning, taskentity.TaskStatusRunning}, + {"disable", taskentity.TaskStatusRunning, taskentity.TaskStatusDisabled, taskentity.TaskStatusDisabled}, + {"pending", taskentity.TaskStatusUnstarted, taskentity.TaskStatusPending, taskentity.TaskStatusPending}, } for _, tt := range cases { @@ -481,7 +481,7 @@ func TestAutoEvaluteProcessor_OnCreateTaskRunChange(t *testing.T) { taskObj := buildTestTask(t) param := taskexe.OnCreateTaskRunChangeReq{ CurrentTask: taskObj, - RunType: task.TaskRunTypeNewData, + RunType: taskentity.TaskRunTypeNewData, RunStartAt: time.Now().Add(-time.Minute).UnixMilli(), RunEndAt: time.Now().Add(time.Hour).UnixMilli(), } diff --git a/backend/modules/observability/domain/task/service/taskexe/processor/noop.go b/backend/modules/observability/domain/task/service/taskexe/processor/noop.go index 97da34ddb..d61466a54 100644 --- a/backend/modules/observability/domain/task/service/taskexe/processor/noop.go +++ b/backend/modules/observability/domain/task/service/taskexe/processor/noop.go @@ -6,8 +6,7 @@ package processor import ( "context" - "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/observability/domain/task" - task_entity "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/entity" + "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/entity" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/service/taskexe" ) @@ -27,11 +26,11 @@ func (p *NoopTaskProcessor) Invoke(ctx context.Context, trigger *taskexe.Trigger return nil } -func (p *NoopTaskProcessor) OnCreateTaskChange(ctx context.Context, currentTask *task_entity.ObservabilityTask) error { +func (p *NoopTaskProcessor) OnCreateTaskChange(ctx context.Context, currentTask *entity.ObservabilityTask) error { return nil } -func (p *NoopTaskProcessor) OnUpdateTaskChange(ctx context.Context, currentTask *task_entity.ObservabilityTask, taskOp task.TaskStatus) error { +func (p *NoopTaskProcessor) OnUpdateTaskChange(ctx context.Context, currentTask *entity.ObservabilityTask, taskOp entity.TaskStatus) error { return nil } diff --git a/backend/modules/observability/domain/task/service/taskexe/processor/utils_test.go b/backend/modules/observability/domain/task/service/taskexe/processor/utils_test.go index d4c6a33b3..2d496a91e 100755 --- a/backend/modules/observability/domain/task/service/taskexe/processor/utils_test.go +++ b/backend/modules/observability/domain/task/service/taskexe/processor/utils_test.go @@ -57,9 +57,9 @@ func TestShouldTriggerBackfill(t *testing.T) { task *taskentity.ObservabilityTask expected bool }{ - {"nil_time", &taskentity.ObservabilityTask{TaskType: task.TaskTypeAutoEval}, false}, - {"invalid_type", &taskentity.ObservabilityTask{TaskType: "manual"}, false}, - {"invalid_range", &taskentity.ObservabilityTask{TaskType: task.TaskTypeAutoEval, BackfillEffectiveTime: &taskentity.EffectiveTime{StartAt: 10, EndAt: 5}}, false}, + {"nil_time", &taskentity.ObservabilityTask{TaskType: taskentity.TaskTypeAutoEval}, false}, + {"invalid_type", &taskentity.ObservabilityTask{TaskType: taskentity.TaskType("manual")}, false}, + {"invalid_range", &taskentity.ObservabilityTask{TaskType: taskentity.TaskTypeAutoEval, BackfillEffectiveTime: &taskentity.EffectiveTime{StartAt: 10, EndAt: 5}}, false}, {"valid", baseTask, true}, } @@ -90,10 +90,10 @@ func TestShouldTriggerNewData(t *testing.T) { task *taskentity.ObservabilityTask expected bool }{ - {"invalid_type", &taskentity.ObservabilityTask{TaskType: "manual"}, false}, - {"nil_time", &taskentity.ObservabilityTask{TaskType: task.TaskTypeAutoEval}, false}, - {"invalid_range", &taskentity.ObservabilityTask{TaskType: task.TaskTypeAutoEval, EffectiveTime: &taskentity.EffectiveTime{StartAt: 20, EndAt: 10}}, false}, - {"start_in_future", &taskentity.ObservabilityTask{TaskType: task.TaskTypeAutoEval, EffectiveTime: &taskentity.EffectiveTime{StartAt: now.Add(time.Hour).UnixMilli(), EndAt: now.Add(2 * time.Hour).UnixMilli()}}, false}, + {"invalid_type", &taskentity.ObservabilityTask{TaskType: taskentity.TaskType("manual")}, false}, + {"nil_time", &taskentity.ObservabilityTask{TaskType: taskentity.TaskTypeAutoEval}, false}, + {"invalid_range", &taskentity.ObservabilityTask{TaskType: taskentity.TaskTypeAutoEval, EffectiveTime: &taskentity.EffectiveTime{StartAt: 20, EndAt: 10}}, false}, + {"start_in_future", &taskentity.ObservabilityTask{TaskType: taskentity.TaskTypeAutoEval, EffectiveTime: &taskentity.EffectiveTime{StartAt: now.Add(time.Hour).UnixMilli(), EndAt: now.Add(2 * time.Hour).UnixMilli()}}, false}, {"valid", baseTask, true}, } diff --git a/backend/modules/observability/domain/task/service/taskexe/tracehub/backfill.go b/backend/modules/observability/domain/task/service/taskexe/tracehub/backfill.go index d80f2477c..de5310f18 100644 --- a/backend/modules/observability/domain/task/service/taskexe/tracehub/backfill.go +++ b/backend/modules/observability/domain/task/service/taskexe/tracehub/backfill.go @@ -126,7 +126,7 @@ func (h *TraceHubServiceImpl) setBackfillTask(ctx context.Context, event *entity bufCap: 0, maxFlushInterval: time.Second * 5, taskRepo: h.taskRepo, - runType: task.TaskRunTypeBackFill, + runType: entity.TaskRunTypeBackFill, } return sub, nil diff --git a/backend/modules/observability/domain/task/service/taskexe/tracehub/backfill_test.go b/backend/modules/observability/domain/task/service/taskexe/tracehub/backfill_test.go index 4d6b2a70c..0b8276182 100755 --- a/backend/modules/observability/domain/task/service/taskexe/tracehub/backfill_test.go +++ b/backend/modules/observability/domain/task/service/taskexe/tracehub/backfill_test.go @@ -22,11 +22,11 @@ import ( taskrepo "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/repo" repo_mocks "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/repo/mocks" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/service/taskexe/processor" - repo "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/repo" + "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/repo" trepo_mocks "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/repo/mocks" builder_mocks "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/service/mocks" spanfilter_mocks "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/service/trace/span_filter/mocks" - span_processor "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/service/trace/span_processor" + "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/service/trace/span_processor" obErrorx "github.com/coze-dev/coze-loop/backend/modules/observability/pkg/errno" "github.com/coze-dev/coze-loop/backend/pkg/errorx" "github.com/coze-dev/coze-loop/backend/pkg/lang/ptr" @@ -83,7 +83,7 @@ func TestTraceHubServiceImpl_SetBackfillTask(t *testing.T) { require.NoError(t, err) require.NotNil(t, sub) require.Equal(t, int64(1), sub.taskID) - require.Equal(t, task.TaskRunTypeBackFill, sub.runType) + require.Equal(t, entity.TaskRunTypeBackFill, sub.runType) } func TestTraceHubServiceImpl_SetBackfillTaskNotFound(t *testing.T) { @@ -199,7 +199,7 @@ func TestTraceHubServiceImpl_ProcessBatchSpans_DispatchError(t *testing.T) { t: taskDTO, tr: taskRunDTO, processor: proc, - runType: task.TaskRunTypeNewData, + runType: entity.TaskRunTypeNewData, taskRepo: mockRepo, } @@ -667,7 +667,7 @@ func newBackfillSubscriber(taskRepo taskrepo.ITaskRepo, now time.Time) (*spanSub tr: taskRun, processor: proc, taskRepo: taskRepo, - runType: task.TaskRunTypeBackFill, + runType: entity.TaskRunTypeBackFill, } return sub, proc } diff --git a/backend/modules/observability/domain/task/service/taskexe/tracehub/scheduled_task.go b/backend/modules/observability/domain/task/service/taskexe/tracehub/scheduled_task.go index d9e1ffceb..792fdd4e5 100755 --- a/backend/modules/observability/domain/task/service/taskexe/tracehub/scheduled_task.go +++ b/backend/modules/observability/domain/task/service/taskexe/tracehub/scheduled_task.go @@ -8,17 +8,18 @@ import ( "fmt" "os" "slices" + "strconv" "time" - "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/observability/domain/filter" - "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/observability/domain/task" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/component/config" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/entity" + "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/repo" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/service/taskexe" - "github.com/coze-dev/coze-loop/backend/modules/observability/infra/repo/mysql" + "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/entity/loop_span" "github.com/coze-dev/coze-loop/backend/pkg/lang/ptr" "github.com/coze-dev/coze-loop/backend/pkg/logs" "github.com/pkg/errors" + "github.com/samber/lo" ) // TaskRunCountInfo represents the TaskRunCount information structure @@ -194,7 +195,7 @@ func (h *TraceHubServiceImpl) transformTaskStatus() { if !taskPO.Sampler.IsCycle { err = proc.OnCreateTaskRunChange(ctx, taskexe.OnCreateTaskRunChangeReq{ CurrentTask: taskPO, - RunType: task.TaskRunTypeNewData, + RunType: entity.TaskRunTypeNewData, RunStartAt: taskPO.EffectiveTime.StartAt, RunEndAt: taskPO.EffectiveTime.EndAt, }) @@ -202,7 +203,7 @@ func (h *TraceHubServiceImpl) transformTaskStatus() { logs.CtxError(ctx, "OnCreateTaskRunChange err:%v", err) continue } - err = proc.OnUpdateTaskChange(ctx, taskPO, task.TaskStatusRunning) + err = proc.OnUpdateTaskChange(ctx, taskPO, entity.TaskStatusRunning) if err != nil { logs.CtxError(ctx, "OnUpdateTaskChange err:%v", err) continue @@ -210,7 +211,7 @@ func (h *TraceHubServiceImpl) transformTaskStatus() { } else { err = proc.OnCreateTaskRunChange(ctx, taskexe.OnCreateTaskRunChangeReq{ CurrentTask: taskPO, - RunType: task.TaskRunTypeNewData, + RunType: entity.TaskRunTypeNewData, RunStartAt: taskRun.RunEndAt.UnixMilli(), RunEndAt: taskRun.RunEndAt.UnixMilli() + (taskRun.RunEndAt.UnixMilli() - taskRun.RunStartAt.UnixMilli()), }) @@ -242,7 +243,7 @@ func (h *TraceHubServiceImpl) transformTaskStatus() { if taskPO.Sampler.IsCycle { err = proc.OnCreateTaskRunChange(ctx, taskexe.OnCreateTaskRunChangeReq{ CurrentTask: taskPO, - RunType: task.TaskRunTypeNewData, + RunType: entity.TaskRunTypeNewData, RunStartAt: taskRun.RunEndAt.UnixMilli(), RunEndAt: taskRun.RunEndAt.UnixMilli() + (taskRun.RunEndAt.UnixMilli() - taskRun.RunStartAt.UnixMilli()), }) @@ -329,7 +330,11 @@ func (h *TraceHubServiceImpl) syncTaskCache() { logs.CtxInfo(ctx, "Start syncing task cache...") // 1. Retrieve spaceID, botID, and task information for all non-final tasks from the database - spaceIDs, botIDs, tasks := h.taskRepo.GetObjListWithTask(ctx) + spaceIDs, botIDs, tasks, err := h.getNonFinalTaskInfos(ctx) + if err != nil { + logs.CtxError(ctx, "Failed to get non-final task list", "err", err) + return + } logs.CtxInfo(ctx, "Retrieved task information, taskCount:%d, spaceCount:%d, botCount:%d", len(tasks), len(spaceIDs), len(botIDs)) // 2. Build a new cache map @@ -426,7 +431,7 @@ func (h *TraceHubServiceImpl) updateTaskRunDetail(ctx context.Context, info *Tas func (h *TraceHubServiceImpl) listNonFinalTaskByRedis(ctx context.Context, spaceID string) ([]*entity.ObservabilityTask, error) { var taskPOs []*entity.ObservabilityTask - nonFinalTaskIDs, err := h.taskRepo.ListNonFinalTask(ctx, spaceID) + nonFinalTaskIDs, err := h.taskRepo.ListNonFinalTaskBySpaceID(ctx, spaceID) if err != nil { logs.CtxError(ctx, "Failed to get non-final task list", "err", err) return nil, err @@ -436,7 +441,7 @@ func (h *TraceHubServiceImpl) listNonFinalTaskByRedis(ctx context.Context, space return taskPOs, nil } for _, taskID := range nonFinalTaskIDs { - taskPO, err := h.taskRepo.GetTaskByRedis(ctx, taskID) + taskPO, err := h.taskRepo.GetTaskByCache(ctx, taskID) if err != nil { logs.CtxError(ctx, "Failed to get task", "err", err) return nil, err @@ -455,20 +460,20 @@ func (h *TraceHubServiceImpl) listNonFinalTask(ctx context.Context) ([]*entity.O const limit int32 = 500 // Paginate through all tasks for { - tasklist, _, err := h.taskRepo.ListTasks(ctx, mysql.ListTaskParam{ + tasklist, _, err := h.taskRepo.ListTasks(ctx, repo.ListTaskParam{ ReqLimit: limit, ReqOffset: offset, - TaskFilters: &filter.TaskFilterFields{ - FilterFields: []*filter.TaskFilterField{ + TaskFilters: &entity.TaskFilterFields{ + FilterFields: []*entity.TaskFilterField{ { - FieldName: ptr.Of(filter.TaskFieldNameTaskStatus), + FieldName: ptr.Of(entity.TaskFieldNameTaskStatus), Values: []string{ - string(task.TaskStatusUnstarted), - string(task.TaskStatusRunning), - string(task.TaskStatusPending), + string(entity.TaskStatusUnstarted), + string(entity.TaskStatusRunning), + string(entity.TaskStatusPending), }, - QueryType: ptr.Of(filter.QueryTypeIn), - FieldType: ptr.Of(filter.FieldTypeString), + QueryType: ptr.Of(entity.QueryTypeIn), + FieldType: ptr.Of(entity.FieldTypeString), }, }, }, @@ -503,27 +508,27 @@ func (h *TraceHubServiceImpl) listSyncTaskRunTask(ctx context.Context) ([]*entit const limit int32 = 1000 // Paginate through all tasks for { - tasklist, _, err := h.taskRepo.ListTasks(ctx, mysql.ListTaskParam{ + tasklist, _, err := h.taskRepo.ListTasks(ctx, repo.ListTaskParam{ ReqLimit: limit, ReqOffset: offset, - TaskFilters: &filter.TaskFilterFields{ - FilterFields: []*filter.TaskFilterField{ + TaskFilters: &entity.TaskFilterFields{ + FilterFields: []*entity.TaskFilterField{ { - FieldName: ptr.Of(filter.TaskFieldNameTaskStatus), + FieldName: ptr.Of(entity.TaskFieldNameTaskStatus), Values: []string{ - string(task.TaskStatusSuccess), - string(task.TaskStatusDisabled), + string(entity.TaskStatusSuccess), + string(entity.TaskStatusDisabled), }, - QueryType: ptr.Of(filter.QueryTypeIn), - FieldType: ptr.Of(filter.FieldTypeString), + QueryType: ptr.Of(entity.QueryTypeIn), + FieldType: ptr.Of(entity.FieldTypeString), }, { - FieldName: ptr.Of("updated_at"), + FieldName: ptr.Of(entity.TaskFieldName("updated_at")), Values: []string{ fmt.Sprintf("%d", time.Now().Add(-24*time.Hour).UnixMilli()), }, - QueryType: ptr.Of(filter.QueryTypeGt), - FieldType: ptr.Of(filter.FieldTypeLong), + QueryType: ptr.Of(entity.QueryTypeGt), + FieldType: ptr.Of(entity.FieldTypeLong), }, }, }, @@ -546,3 +551,41 @@ func (h *TraceHubServiceImpl) listSyncTaskRunTask(ctx context.Context) ([]*entit } return taskDOs, nil } + +func (h *TraceHubServiceImpl) getNonFinalTaskInfos(ctx context.Context) ([]string, []string, []*entity.ObservabilityTask, error) { + tasks, err := h.taskRepo.ListNonFinalTasks(ctx) + if err != nil { + return nil, nil, nil, err + } + + spaceMap := make(map[string]interface{}) + botMap := make(map[string]interface{}) + + for _, task := range tasks { + spaceMap[strconv.FormatInt(task.WorkspaceID, 10)] = struct{}{} + if task.SpanFilter != nil && task.SpanFilter.Filters.FilterFields != nil { + extractBotIDFromFilters(task.SpanFilter.Filters.FilterFields, botMap) + } + } + + return lo.Keys(spaceMap), lo.Keys(botMap), tasks, nil +} + +// extractBotIDFromFilters 递归提取过滤器中的 bot_id 值,包括 SubFilter +func extractBotIDFromFilters(filterFields []*loop_span.FilterField, botMap map[string]interface{}) { + for _, filterField := range filterFields { + if filterField == nil { + continue + } + // 检查当前 FilterField 的 FieldName + if filterField.FieldName == "bot_id" { + for _, v := range filterField.Values { + botMap[v] = struct{}{} + } + } + // 递归处理 SubFilter + if filterField.SubFilter != nil && filterField.SubFilter.FilterFields != nil { + extractBotIDFromFilters(filterField.SubFilter.FilterFields, botMap) + } + } +} diff --git a/backend/modules/observability/domain/task/service/taskexe/tracehub/scheduled_task_test.go b/backend/modules/observability/domain/task/service/taskexe/tracehub/scheduled_task_test.go index 4084ad504..6e5a71f70 100755 --- a/backend/modules/observability/domain/task/service/taskexe/tracehub/scheduled_task_test.go +++ b/backend/modules/observability/domain/task/service/taskexe/tracehub/scheduled_task_test.go @@ -13,7 +13,6 @@ import ( "go.uber.org/mock/gomock" lock_mocks "github.com/coze-dev/coze-loop/backend/infra/lock/mocks" - "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/observability/domain/task" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/entity" repo_mocks "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/repo/mocks" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/service/taskexe" @@ -26,7 +25,7 @@ type trackingProcessor struct { *stubProcessor finishReqs []taskexe.OnFinishTaskChangeReq createRunReqs []taskexe.OnCreateTaskRunChangeReq - updateStatuses []string + updateStatuses []entity.TaskStatus } func newTrackingProcessor() *trackingProcessor { @@ -43,7 +42,7 @@ func (p *trackingProcessor) OnCreateTaskRunChange(ctx context.Context, req taske return p.stubProcessor.OnCreateTaskRunChange(ctx, req) } -func (p *trackingProcessor) OnUpdateTaskChange(ctx context.Context, obsTask *entity.ObservabilityTask, status string) error { +func (p *trackingProcessor) OnUpdateTaskChange(ctx context.Context, obsTask *entity.ObservabilityTask, status entity.TaskStatus) error { p.updateStatuses = append(p.updateStatuses, status) return p.stubProcessor.OnUpdateTaskChange(ctx, obsTask, status) } @@ -141,9 +140,9 @@ func TestTraceHubServiceImpl_transformTaskStatus(t *testing.T) { }, assert: func(t *testing.T, _ *TraceHubServiceImpl, proc *trackingProcessor) { require.Len(t, proc.createRunReqs, 1) - require.Equal(t, task.TaskRunTypeNewData, proc.createRunReqs[0].RunType) + require.Equal(t, entity.TaskRunTypeNewData, proc.createRunReqs[0].RunType) require.Len(t, proc.updateStatuses, 1) - require.Equal(t, string(task.TaskStatusRunning), proc.updateStatuses[0]) + require.Equal(t, string(entity.TaskStatusRunning), proc.updateStatuses[0]) }, }, { @@ -340,7 +339,6 @@ func TestTraceHubServiceImpl_syncTaskCache(t *testing.T) { workspaceIDs := []string{"space-1"} botIDs := []string{"bot-1"} tasks := []*entity.ObservabilityTask{{ID: 100}} - mockRepo.EXPECT().GetObjListWithTask(gomock.Any()).Return(workspaceIDs, botIDs, tasks) before := time.Now() impl.syncTaskCache() diff --git a/backend/modules/observability/domain/task/service/taskexe/tracehub/span_trigger.go b/backend/modules/observability/domain/task/service/taskexe/tracehub/span_trigger.go index 94918b15b..22a9c9286 100644 --- a/backend/modules/observability/domain/task/service/taskexe/tracehub/span_trigger.go +++ b/backend/modules/observability/domain/task/service/taskexe/tracehub/span_trigger.go @@ -96,7 +96,7 @@ func (h *TraceHubServiceImpl) getSubscriberOfSpan(ctx context.Context, span *loo flushWait: sync.WaitGroup{}, maxFlushInterval: time.Second * 5, taskRepo: h.taskRepo, - runType: task.TaskRunTypeNewData, + runType: entity.TaskRunTypeNewData, buildHelper: h.buildHelper, }) } @@ -153,7 +153,7 @@ func (h *TraceHubServiceImpl) preDispatch(ctx context.Context, span *loop_span.S merr = multierror.Append(merr, errors.WithMessagef(err, "task is unstarted, need sub.Creative,creative processor, task_id=%d", sub.taskID)) continue } - if err := sub.processor.OnUpdateTaskChange(ctx, tconv.TaskDTO2DO(sub.t), task.TaskStatusRunning); err != nil { + if err := sub.processor.OnUpdateTaskChange(ctx, tconv.TaskDTO2DO(sub.t), entity.TaskStatusRunning); err != nil { logs.CtxWarn(ctx, "OnUpdateTaskChange, task_id=%d, err=%v", sub.taskID, err) continue } diff --git a/backend/modules/observability/domain/task/service/taskexe/tracehub/span_trigger_test.go b/backend/modules/observability/domain/task/service/taskexe/tracehub/span_trigger_test.go index 9f7d795fa..6f6689b1d 100755 --- a/backend/modules/observability/domain/task/service/taskexe/tracehub/span_trigger_test.go +++ b/backend/modules/observability/domain/task/service/taskexe/tracehub/span_trigger_test.go @@ -67,8 +67,8 @@ func TestTraceHubServiceImpl_SpanTriggerDispatchError(t *testing.T) { TaskType: entity.TaskTypeAutoEval, TaskStatus: entity.TaskStatusRunning, SpanFilter: &entity.SpanFilterFields{ - PlatformType: common.PlatformTypeLoopAll, - SpanListType: common.SpanListTypeAllSpan, + PlatformType: loop_span.PlatformDefault, + SpanListType: loop_span.SpanListTypeAllSpan, Filters: loop_span.FilterFields{ QueryAndOr: ptr.Of(loop_span.QueryAndOrEnumAnd), FilterFields: []*loop_span.FilterField{}, @@ -103,8 +103,7 @@ func TestTraceHubServiceImpl_SpanTriggerDispatchError(t *testing.T) { return nil }, ).AnyTimes() - mockRepo.EXPECT().ListNonFinalTask(gomock.Any(), "space-1").Return([]int64{taskDO.ID}, nil).AnyTimes() - mockRepo.EXPECT().GetTaskByRedis(gomock.Any(), taskDO.ID).Return(taskDO, nil).AnyTimes() + mockRepo.EXPECT().GetTaskByCache(gomock.Any(), taskDO.ID).Return(taskDO, nil).AnyTimes() mockFilter.EXPECT().BuildBasicSpanFilter(gomock.Any(), gomock.Any()).Return(nil, false, nil).AnyTimes() mockFilter.EXPECT().BuildALLSpanFilter(gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes() mockBuilder.EXPECT().BuildPlatformRelatedFilter(gomock.Any(), gomock.Any()).Return(mockFilter, nil).AnyTimes() @@ -198,14 +197,14 @@ func TestTraceHubServiceImpl_preDispatchHandlesUnstartedAndLimits(t *testing.T) }, processor: stubProc, taskRepo: mockRepo, - runType: task.TaskRunTypeNewData, + runType: entity.TaskRunTypeNewData, } taskRunConfig := &entity.TaskRun{ ID: 303, TaskID: taskID, WorkspaceID: workspaceID, - TaskType: task.TaskRunTypeNewData, + TaskType: entity.TaskRunTypeNewData, RunStatus: task.TaskStatusRunning, RunStartAt: now.Add(-90 * time.Minute), RunEndAt: now.Add(-30 * time.Minute), @@ -275,7 +274,7 @@ func TestTraceHubServiceImpl_preDispatchHandlesMissingTaskRunConfig(t *testing.T }, processor: stubProc, taskRepo: mockRepo, - runType: task.TaskRunTypeNewData, + runType: entity.TaskRunTypeNewData, } mockRepo.EXPECT().GetLatestNewDataTaskRun(gomock.Any(), gomock.AssignableToTypeOf(ptr.Of(int64(0))), taskID).Return(nil, nil) @@ -335,14 +334,14 @@ func TestTraceHubServiceImpl_preDispatchHandlesNonCycle(t *testing.T) { }, processor: stubProc, taskRepo: mockRepo, - runType: task.TaskRunTypeNewData, + runType: entity.TaskRunTypeNewData, } taskRunConfig := &entity.TaskRun{ ID: 707, TaskID: taskID, WorkspaceID: workspaceID, - TaskType: task.TaskRunTypeNewData, + TaskType: entity.TaskRunTypeNewData, RunStatus: task.TaskStatusRunning, RunStartAt: now.Add(-30 * time.Minute), RunEndAt: now.Add(30 * time.Minute), @@ -404,7 +403,7 @@ func TestTraceHubServiceImpl_preDispatchHandlesCycleDefaultUnit(t *testing.T) { }, processor: stubProc, taskRepo: mockRepo, - runType: task.TaskRunTypeNewData, + runType: entity.TaskRunTypeNewData, } mockRepo.EXPECT().GetLatestNewDataTaskRun(gomock.Any(), gomock.AssignableToTypeOf(ptr.Of(int64(0))), taskID).Return(nil, nil) @@ -466,14 +465,14 @@ func TestTraceHubServiceImpl_preDispatchTimeLimitFinishError(t *testing.T) { }, processor: stubProc, taskRepo: mockRepo, - runType: task.TaskRunTypeNewData, + runType: entity.TaskRunTypeNewData, } taskRunConfig := &entity.TaskRun{ ID: 1101, TaskID: taskID, WorkspaceID: workspaceID, - TaskType: task.TaskRunTypeNewData, + TaskType: entity.TaskRunTypeNewData, RunStatus: task.TaskStatusRunning, RunStartAt: now.Add(-3 * time.Hour), RunEndAt: now.Add(-2 * time.Hour), @@ -533,14 +532,14 @@ func TestTraceHubServiceImpl_preDispatchSampleLimitFinishError(t *testing.T) { }, processor: stubProc, taskRepo: mockRepo, - runType: task.TaskRunTypeNewData, + runType: entity.TaskRunTypeNewData, } taskRunConfig := &entity.TaskRun{ ID: 1404, TaskID: taskID, WorkspaceID: workspaceID, - TaskType: task.TaskRunTypeNewData, + TaskType: entity.TaskRunTypeNewData, RunStatus: task.TaskStatusRunning, RunStartAt: now.Add(-30 * time.Minute), RunEndAt: now.Add(30 * time.Minute), @@ -600,14 +599,14 @@ func TestTraceHubServiceImpl_preDispatchCycleTimeLimitFinishError(t *testing.T) }, processor: stubProc, taskRepo: mockRepo, - runType: task.TaskRunTypeNewData, + runType: entity.TaskRunTypeNewData, } taskRunConfig := &entity.TaskRun{ ID: 1707, TaskID: taskID, WorkspaceID: workspaceID, - TaskType: task.TaskRunTypeNewData, + TaskType: entity.TaskRunTypeNewData, RunStatus: task.TaskStatusRunning, RunStartAt: now.Add(-2 * time.Hour), RunEndAt: now.Add(-time.Minute), @@ -667,14 +666,14 @@ func TestTraceHubServiceImpl_preDispatchCycleCountFinishError(t *testing.T) { }, processor: stubProc, taskRepo: mockRepo, - runType: task.TaskRunTypeNewData, + runType: entity.TaskRunTypeNewData, } taskRunConfig := &entity.TaskRun{ ID: 2009, TaskID: taskID, WorkspaceID: workspaceID, - TaskType: task.TaskRunTypeNewData, + TaskType: entity.TaskRunTypeNewData, RunStatus: task.TaskStatusRunning, RunStartAt: now.Add(-30 * time.Minute), RunEndAt: now.Add(30 * time.Minute), @@ -730,7 +729,7 @@ func TestTraceHubServiceImpl_preDispatchCreativeError(t *testing.T) { }, processor: stubProc, taskRepo: mockRepo, - runType: task.TaskRunTypeNewData, + runType: entity.TaskRunTypeNewData, } impl := &TraceHubServiceImpl{taskRepo: mockRepo} @@ -772,7 +771,7 @@ func TestTraceHubServiceImpl_preDispatchAggregatesErrors(t *testing.T) { }, processor: firstProc, taskRepo: mockRepo, - runType: task.TaskRunTypeNewData, + runType: entity.TaskRunTypeNewData, } secondStartAt := now.Add(-2 * time.Hour).UnixMilli() @@ -790,7 +789,7 @@ func TestTraceHubServiceImpl_preDispatchAggregatesErrors(t *testing.T) { ID: 101, TaskID: secondTaskID, WorkspaceID: secondWorkspaceID, - TaskType: task.TaskRunTypeNewData, + TaskType: entity.TaskRunTypeNewData, RunStatus: task.TaskStatusRunning, RunStartAt: now.Add(-3 * time.Hour), RunEndAt: now.Add(-90 * time.Minute), @@ -811,7 +810,7 @@ func TestTraceHubServiceImpl_preDispatchAggregatesErrors(t *testing.T) { }, processor: secondProc, taskRepo: mockRepo, - runType: task.TaskRunTypeNewData, + runType: entity.TaskRunTypeNewData, } mockRepo.EXPECT().GetLatestNewDataTaskRun(gomock.Any(), gomock.AssignableToTypeOf(ptr.Of(int64(0))), secondTaskID).Return(secondRun, nil) @@ -867,7 +866,7 @@ func TestTraceHubServiceImpl_preDispatchUpdateError(t *testing.T) { }, processor: stubProc, taskRepo: mockRepo, - runType: task.TaskRunTypeNewData, + runType: entity.TaskRunTypeNewData, } impl := &TraceHubServiceImpl{taskRepo: mockRepo} @@ -913,7 +912,7 @@ func TestTraceHubServiceImpl_preDispatchListTaskRunError(t *testing.T) { }, processor: stubProc, taskRepo: mockRepo, - runType: task.TaskRunTypeNewData, + runType: entity.TaskRunTypeNewData, } mockRepo.EXPECT().GetLatestNewDataTaskRun(gomock.Any(), gomock.AssignableToTypeOf(ptr.Of(int64(0))), taskID).Return(nil, errors.New("repo fail")) @@ -963,7 +962,7 @@ func TestTraceHubServiceImpl_preDispatchTaskRunConfigDay(t *testing.T) { }, processor: stubProc, taskRepo: mockRepo, - runType: task.TaskRunTypeNewData, + runType: entity.TaskRunTypeNewData, } mockRepo.EXPECT().GetLatestNewDataTaskRun(gomock.Any(), gomock.AssignableToTypeOf(ptr.Of(int64(0))), taskID).Return(nil, nil) @@ -1020,14 +1019,14 @@ func TestTraceHubServiceImpl_preDispatchCycleCreativeError(t *testing.T) { }, processor: stubProc, taskRepo: mockRepo, - runType: task.TaskRunTypeNewData, + runType: entity.TaskRunTypeNewData, } taskRunConfig := &entity.TaskRun{ ID: 3102, TaskID: taskID, WorkspaceID: workspaceID, - TaskType: task.TaskRunTypeNewData, + TaskType: entity.TaskRunTypeNewData, RunStatus: task.TaskStatusRunning, RunStartAt: now.Add(-2 * time.Hour), RunEndAt: now.Add(-time.Minute), diff --git a/backend/modules/observability/domain/task/service/taskexe/tracehub/subscriber.go b/backend/modules/observability/domain/task/service/taskexe/tracehub/subscriber.go index 8e05b1831..106c43ea8 100644 --- a/backend/modules/observability/domain/task/service/taskexe/tracehub/subscriber.go +++ b/backend/modules/observability/domain/task/service/taskexe/tracehub/subscriber.go @@ -35,7 +35,7 @@ type spanSubscriber struct { flushWait sync.WaitGroup maxFlushInterval time.Duration taskRepo repo.ITaskRepo - runType task.TaskRunType + runType entity.TaskRunType buildHelper service.TraceFilterProcessorBuilder } @@ -155,7 +155,7 @@ func buildBuiltinFilters(ctx context.Context, f span_filter.Filter, req *ListSpa func (s *spanSubscriber) Creative(ctx context.Context, runStartAt, runEndAt int64) error { err := s.processor.OnCreateTaskRunChange(ctx, taskexe.OnCreateTaskRunChangeReq{ - CurrentTask: tconv.TaskDTO2DO(s.t, "", nil), + CurrentTask: tconv.TaskDTO2DO(s.t), RunType: s.runType, RunStartAt: runStartAt, RunEndAt: runEndAt, @@ -169,7 +169,7 @@ func (s *spanSubscriber) Creative(ctx context.Context, runStartAt, runEndAt int6 func (s *spanSubscriber) AddSpan(ctx context.Context, span *loop_span.Span) error { var taskRunConfig *entity.TaskRun var err error - if s.runType == task.TaskRunTypeNewData { + if s.runType == entity.TaskRunTypeNewData { taskRunConfig, err = s.taskRepo.GetLatestNewDataTaskRun(ctx, nil, s.t.GetID()) if err != nil { logs.CtxWarn(ctx, "get latest new data task run failed, task_id=%d, err: %v", s.t.GetID(), err) @@ -195,7 +195,7 @@ func (s *spanSubscriber) AddSpan(ctx context.Context, span *loop_span.Span) erro logs.CtxWarn(ctx, "span start time is before task cycle start time, trace_id=%s, span_id=%s", span.TraceID, span.SpanID) return nil } - trigger := &taskexe.Trigger{Task: tconv.TaskDTO2DO(s.t, "", nil), Span: span, TaskRun: taskRunConfig} + trigger := &taskexe.Trigger{Task: tconv.TaskDTO2DO(s.t), Span: span, TaskRun: taskRunConfig} logs.CtxInfo(ctx, "invoke processor, trigger: %v", trigger) err = s.processor.Invoke(ctx, trigger) if err != nil { diff --git a/backend/modules/observability/domain/task/service/taskexe/tracehub/test_helpers_test.go b/backend/modules/observability/domain/task/service/taskexe/tracehub/test_helpers_test.go index 15134e6c9..e2645ae57 100755 --- a/backend/modules/observability/domain/task/service/taskexe/tracehub/test_helpers_test.go +++ b/backend/modules/observability/domain/task/service/taskexe/tracehub/test_helpers_test.go @@ -49,7 +49,7 @@ func (s *stubProcessor) OnCreateTaskChange(context.Context, *entity.Observabilit return s.createTaskErr } -func (s *stubProcessor) OnUpdateTaskChange(context.Context, *entity.ObservabilityTask, string) error { +func (s *stubProcessor) OnUpdateTaskChange(context.Context, *entity.ObservabilityTask, entity.TaskStatus) error { s.updateCallCount++ return s.updateErr } diff --git a/backend/modules/observability/domain/task/service/taskexe/types.go b/backend/modules/observability/domain/task/service/taskexe/types.go index 40a7dee5c..6648398f2 100644 --- a/backend/modules/observability/domain/task/service/taskexe/types.go +++ b/backend/modules/observability/domain/task/service/taskexe/types.go @@ -7,15 +7,14 @@ import ( "context" "errors" - "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/observability/domain/task" - task_entity "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/entity" + "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/entity" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/entity/loop_span" ) type Trigger struct { - Task *task_entity.ObservabilityTask + Task *entity.ObservabilityTask Span *loop_span.Span - TaskRun *task_entity.TaskRun + TaskRun *entity.TaskRun } var ( @@ -24,18 +23,18 @@ var ( ) type OnCreateTaskRunChangeReq struct { - CurrentTask *task_entity.ObservabilityTask - RunType task.TaskRunType + CurrentTask *entity.ObservabilityTask + RunType entity.TaskRunType RunStartAt int64 RunEndAt int64 } type OnFinishTaskRunChangeReq struct { - Task *task_entity.ObservabilityTask - TaskRun *task_entity.TaskRun + Task *entity.ObservabilityTask + TaskRun *entity.TaskRun } type OnFinishTaskChangeReq struct { - Task *task_entity.ObservabilityTask - TaskRun *task_entity.TaskRun + Task *entity.ObservabilityTask + TaskRun *entity.TaskRun IsFinish bool } @@ -43,8 +42,8 @@ type Processor interface { ValidateConfig(ctx context.Context, config any) error Invoke(ctx context.Context, trigger *Trigger) error - OnCreateTaskChange(ctx context.Context, currentTask *task_entity.ObservabilityTask) error - OnUpdateTaskChange(ctx context.Context, currentTask *task_entity.ObservabilityTask, taskOp task.TaskStatus) error + OnCreateTaskChange(ctx context.Context, currentTask *entity.ObservabilityTask) error + OnUpdateTaskChange(ctx context.Context, currentTask *entity.ObservabilityTask, taskOp entity.TaskStatus) error OnFinishTaskChange(ctx context.Context, param OnFinishTaskChangeReq) error OnCreateTaskRunChange(ctx context.Context, param OnCreateTaskRunChangeReq) error diff --git a/backend/modules/observability/domain/trace/entity/common/page.go b/backend/modules/observability/domain/trace/entity/common/page.go index 69f2bf879..ad3ade039 100644 --- a/backend/modules/observability/domain/trace/entity/common/page.go +++ b/backend/modules/observability/domain/trace/entity/common/page.go @@ -2,3 +2,8 @@ // SPDX-License-Identifier: Apache-2.0 package common + +type OrderBy struct { + Field string + IsAsc bool +} diff --git a/backend/modules/observability/domain/trace/entity/loop_span/filter.go b/backend/modules/observability/domain/trace/entity/loop_span/filter.go index bef15eb80..c70273889 100644 --- a/backend/modules/observability/domain/trace/entity/loop_span/filter.go +++ b/backend/modules/observability/domain/trace/entity/loop_span/filter.go @@ -364,6 +364,15 @@ func (f *FilterField) CheckValue(val any) bool { } } +func (f *FilterField) SetHidden(hidden bool) { + f.Hidden = hidden + if f.SubFilter != nil { + for _, subFilters := range f.SubFilter.FilterFields { + subFilters.SetHidden(hidden) + } + } +} + func CompareBool(val bool, values []bool, qType QueryTypeEnum) bool { switch qType { case QueryTypeEnumEq: diff --git a/backend/modules/observability/domain/trace/service/trace_service.go b/backend/modules/observability/domain/trace/service/trace_service.go index e89994fc9..6dda97cba 100644 --- a/backend/modules/observability/domain/trace/service/trace_service.go +++ b/backend/modules/observability/domain/trace/service/trace_service.go @@ -11,8 +11,7 @@ import ( "time" tconv "github.com/coze-dev/coze-loop/backend/modules/observability/application/convertor/task" - taskRepo "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/repo" - "github.com/coze-dev/coze-loop/backend/modules/observability/infra/repo/mysql" + taskrepo "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/repo" "golang.org/x/sync/errgroup" "github.com/bytedance/gg/gptr" @@ -37,7 +36,7 @@ import ( "github.com/coze-dev/coze-loop/backend/pkg/lang/goroutine" "github.com/coze-dev/coze-loop/backend/pkg/lang/ptr" "github.com/coze-dev/coze-loop/backend/pkg/logs" - time_util "github.com/coze-dev/coze-loop/backend/pkg/time" + timeutil "github.com/coze-dev/coze-loop/backend/pkg/time" "github.com/samber/lo" ) @@ -277,7 +276,7 @@ func NewTraceServiceImpl( buildHelper TraceFilterProcessorBuilder, tenantProvider tenant.ITenantProvider, evalSvc rpc.IEvaluatorRPCAdapter, - taskRepo taskRepo.ITaskRepo, + taskRepo taskrepo.ITaskRepo, ) (ITraceService, error) { return &TraceServiceImpl{ traceRepo: tRepo, @@ -301,7 +300,7 @@ type TraceServiceImpl struct { buildHelper TraceFilterProcessorBuilder tenantProvider tenant.ITenantProvider evalSvc rpc.IEvaluatorRPCAdapter - taskRepo taskRepo.ITaskRepo + taskRepo taskrepo.ITaskRepo } func (r *TraceServiceImpl) GetTrace(ctx context.Context, req *GetTraceReq) (*GetTraceResp, error) { @@ -1230,7 +1229,7 @@ func (r *TraceServiceImpl) ListAnnotationEvaluators(ctx context.Context, req *Li evaluators = append(evaluators, evaluatorList...) } else { // 没有name先查task - taskDOs, _, err := r.taskRepo.ListTasks(ctx, mysql.ListTaskParam{ + taskDOs, _, err := r.taskRepo.ListTasks(ctx, taskrepo.ListTaskParam{ WorkspaceIDs: []int64{req.WorkspaceID}, ReqLimit: int32(500), ReqOffset: int32(0), @@ -1432,7 +1431,7 @@ func processLatencyFilter(f *loop_span.FilterField) error { if err != nil { return fmt.Errorf("fail to parse long value %s, %v", val, err) } - integer = time_util.MillSec2MicroSec(integer) + integer = timeutil.MillSec2MicroSec(integer) micros = append(micros, strconv.FormatInt(integer, 10)) } f.Values = micros diff --git a/backend/modules/observability/domain/trace/service/trace_service_test.go b/backend/modules/observability/domain/trace/service/trace_service_test.go index 1d5839ded..0660f2839 100644 --- a/backend/modules/observability/domain/trace/service/trace_service_test.go +++ b/backend/modules/observability/domain/trace/service/trace_service_test.go @@ -52,7 +52,7 @@ func newTaskRepoMock(ctrl *gomock.Controller) *taskRepoMock { } func (m *taskRepoMock) ListNonFinalTask(context.Context, string) ([]int64, error) { - panic("unexpected call to ListNonFinalTask in taskRepoMock") + panic("unexpected call to ListNonFinalTaskBySpaceID in taskRepoMock") } func (m *taskRepoMock) AddNonFinalTask(context.Context, string, int64) error { @@ -63,8 +63,8 @@ func (m *taskRepoMock) RemoveNonFinalTask(context.Context, string, int64) error panic("unexpected call to RemoveNonFinalTask in taskRepoMock") } -func (m *taskRepoMock) GetTaskByRedis(context.Context, int64) (*taskentity.ObservabilityTask, error) { - panic("unexpected call to GetTaskByRedis in taskRepoMock") +func (m *taskRepoMock) GetTaskByCache(context.Context, int64) (*taskentity.ObservabilityTask, error) { + panic("unexpected call to GetTaskByCache in taskRepoMock") } func (m *taskRepoMock) SetTask(context.Context, *taskentity.ObservabilityTask) error { diff --git a/backend/modules/observability/infra/repo/mysql/task.go b/backend/modules/observability/infra/repo/mysql/task.go index b1fb01581..9c478b5d6 100644 --- a/backend/modules/observability/infra/repo/mysql/task.go +++ b/backend/modules/observability/infra/repo/mysql/task.go @@ -11,9 +11,8 @@ import ( "time" "github.com/coze-dev/coze-loop/backend/infra/db" - "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/observability/domain/common" - "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/observability/domain/filter" - tconv "github.com/coze-dev/coze-loop/backend/modules/observability/application/convertor/task" + "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/entity" + "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/entity/common" "github.com/coze-dev/coze-loop/backend/modules/observability/infra/repo/mysql/gorm_gen/model" genquery "github.com/coze-dev/coze-loop/backend/modules/observability/infra/repo/mysql/gorm_gen/query" obErrorx "github.com/coze-dev/coze-loop/backend/modules/observability/pkg/errno" @@ -32,7 +31,7 @@ const ( type ListTaskParam struct { WorkspaceIDs []int64 - TaskFilters *filter.TaskFilterFields + TaskFilters *entity.TaskFilterFields ReqLimit int32 ReqOffset int32 OrderBy *common.OrderBy @@ -46,7 +45,7 @@ type ITaskDao interface { DeleteTask(ctx context.Context, id int64, workspaceID int64, userID string) error ListTasks(ctx context.Context, param ListTaskParam) ([]*model.ObservabilityTask, int64, error) UpdateTaskWithOCC(ctx context.Context, id int64, workspaceID int64, updateMap map[string]interface{}) error - GetObjListWithTask(ctx context.Context) ([]string, []string, []*model.ObservabilityTask, error) + ListNonFinalTasks(ctx context.Context) ([]*model.ObservabilityTask, error) } func NewTaskDaoImpl(db db.Provider) ITaskDao { @@ -133,7 +132,13 @@ func (v *TaskDaoImpl) ListTasks(ctx context.Context, param ListTaskParam) ([]*mo return nil, 0, errorx.WrapByCode(err, obErrorx.CommonMySqlErrorCode) } // order by - qd = qd.Order(v.order(q, param.OrderBy.GetField(), param.OrderBy.GetIsAsc())) + orderField := "" + orderAsc := false + if param.OrderBy != nil { + orderField = param.OrderBy.Field + orderAsc = param.OrderBy.IsAsc + } + qd = qd.Order(v.order(q, orderField, orderAsc)) // 计算分页参数 limit, offset := calculatePagination(param.ReqLimit, param.ReqOffset) results, err := qd.Limit(limit).Offset(offset).Find() @@ -144,7 +149,7 @@ func (v *TaskDaoImpl) ListTasks(ctx context.Context, param ListTaskParam) ([]*mo } // 处理任务过滤条件 -func (v *TaskDaoImpl) applyTaskFilters(q *genquery.Query, taskFilters *filter.TaskFilterFields) (field.Expr, error) { +func (v *TaskDaoImpl) applyTaskFilters(q *genquery.Query, taskFilters *entity.TaskFilterFields) (field.Expr, error) { if taskFilters == nil || len(taskFilters.FilterFields) == 0 { return nil, nil } @@ -171,28 +176,28 @@ func (v *TaskDaoImpl) applyTaskFilters(q *genquery.Query, taskFilters *filter.Ta } // 构建单个过滤条件 -func (v *TaskDaoImpl) buildSingleFilterExpr(q *genquery.Query, f *filter.TaskFilterField) (field.Expr, error) { +func (v *TaskDaoImpl) buildSingleFilterExpr(q *genquery.Query, f *entity.TaskFilterField) (field.Expr, error) { if f.FieldName == nil || f.QueryType == nil { return nil, errorx.NewByCode(obErrorx.CommonInvalidParamCode, errorx.WithExtraMsg("field name or query type is nil")) } switch *f.FieldName { - case filter.TaskFieldNameTaskName: + case entity.TaskFieldNameTaskName: return v.buildTaskNameFilter(q, f) - case filter.TaskFieldNameTaskType: + case entity.TaskFieldNameTaskType: return v.buildTaskTypeFilter(q, f) - case filter.TaskFieldNameTaskStatus: + case entity.TaskFieldNameTaskStatus: return v.buildTaskStatusFilter(q, f) - case filter.TaskFieldNameCreatedBy: + case entity.TaskFieldNameCreatedBy: return v.buildCreatedByFilter(q, f) - case filter.TaskFieldNameSampleRate: + case entity.TaskFieldNameSampleRate: return v.buildSampleRateFilter(q, f) case "task_id": return v.buildTaskIDFilter(q, f) case "updated_at": return v.buildUpdateAtFilter(q, f) default: - return nil, errorx.NewByCode(obErrorx.CommonInvalidParamCode, errorx.WithMsgParam("invalid filter field name: %s", *f.FieldName)) + return nil, errorx.NewByCode(obErrorx.CommonInvalidParamCode, errorx.WithMsgParam("invalid filter field name: %s", string(*f.FieldName))) } } @@ -202,7 +207,7 @@ func (v *TaskDaoImpl) combineExpressions(expressions []field.Expr, relation stri return expressions[0] } - if relation == filter.QueryRelationOr { + if relation == string(entity.QueryRelationOr) { return field.Or(expressions...) } // 默认使用 AND 关系 @@ -210,15 +215,15 @@ func (v *TaskDaoImpl) combineExpressions(expressions []field.Expr, relation stri } // 构建任务名称过滤条件 -func (v *TaskDaoImpl) buildTaskNameFilter(q *genquery.Query, f *filter.TaskFilterField) (field.Expr, error) { +func (v *TaskDaoImpl) buildTaskNameFilter(q *genquery.Query, f *entity.TaskFilterField) (field.Expr, error) { if len(f.Values) == 0 { return nil, errorx.NewByCode(obErrorx.CommonInvalidParamCode, errorx.WithExtraMsg("no value provided for task name query")) } switch *f.QueryType { - case filter.QueryTypeEq: + case entity.QueryTypeEq: return q.ObservabilityTask.Name.Eq(f.Values[0]), nil - case filter.QueryTypeMatch: + case entity.QueryTypeMatch: return q.ObservabilityTask.Name.Like(fmt.Sprintf("%%%s%%", f.Values[0])), nil default: return nil, errorx.NewByCode(obErrorx.CommonInvalidParamCode, errorx.WithExtraMsg("invalid query type for task name")) @@ -226,15 +231,15 @@ func (v *TaskDaoImpl) buildTaskNameFilter(q *genquery.Query, f *filter.TaskFilte } // 构建任务类型过滤条件 -func (v *TaskDaoImpl) buildTaskTypeFilter(q *genquery.Query, f *filter.TaskFilterField) (field.Expr, error) { +func (v *TaskDaoImpl) buildTaskTypeFilter(q *genquery.Query, f *entity.TaskFilterField) (field.Expr, error) { if len(f.Values) == 0 { return nil, errorx.NewByCode(obErrorx.CommonInvalidParamCode, errorx.WithExtraMsg("no values provided for task type query")) } switch *f.QueryType { - case filter.QueryTypeIn: + case entity.QueryTypeIn: return q.ObservabilityTask.TaskType.In(f.Values...), nil - case filter.QueryTypeNotIn: + case entity.QueryTypeNotIn: return q.ObservabilityTask.TaskType.NotIn(f.Values...), nil default: return nil, errorx.NewByCode(obErrorx.CommonInvalidParamCode, errorx.WithExtraMsg("invalid query type for task type")) @@ -242,15 +247,15 @@ func (v *TaskDaoImpl) buildTaskTypeFilter(q *genquery.Query, f *filter.TaskFilte } // 构建任务状态过滤条件 -func (v *TaskDaoImpl) buildTaskStatusFilter(q *genquery.Query, f *filter.TaskFilterField) (field.Expr, error) { +func (v *TaskDaoImpl) buildTaskStatusFilter(q *genquery.Query, f *entity.TaskFilterField) (field.Expr, error) { if len(f.Values) == 0 { return nil, errorx.NewByCode(obErrorx.CommonInvalidParamCode, errorx.WithExtraMsg("no values provided for task status query")) } switch *f.QueryType { - case filter.QueryTypeIn: + case entity.QueryTypeIn: return q.ObservabilityTask.TaskStatus.In(f.Values...), nil - case filter.QueryTypeNotIn: + case entity.QueryTypeNotIn: return q.ObservabilityTask.TaskStatus.NotIn(f.Values...), nil default: return nil, errorx.NewByCode(obErrorx.CommonInvalidParamCode, errorx.WithExtraMsg("invalid query type for task status")) @@ -258,15 +263,15 @@ func (v *TaskDaoImpl) buildTaskStatusFilter(q *genquery.Query, f *filter.TaskFil } // 构建创建者过滤条件 -func (v *TaskDaoImpl) buildCreatedByFilter(q *genquery.Query, f *filter.TaskFilterField) (field.Expr, error) { +func (v *TaskDaoImpl) buildCreatedByFilter(q *genquery.Query, f *entity.TaskFilterField) (field.Expr, error) { if len(f.Values) == 0 { return nil, errorx.NewByCode(obErrorx.CommonInvalidParamCode, errorx.WithExtraMsg("no values provided for created_by query")) } switch *f.QueryType { - case filter.QueryTypeIn: + case entity.QueryTypeIn: return q.ObservabilityTask.CreatedBy.In(f.Values...), nil - case filter.QueryTypeNotIn: + case entity.QueryTypeNotIn: return q.ObservabilityTask.CreatedBy.NotIn(f.Values...), nil default: return nil, errorx.NewByCode(obErrorx.CommonInvalidParamCode, errorx.WithExtraMsg("invalid query type for created_by")) @@ -274,7 +279,7 @@ func (v *TaskDaoImpl) buildCreatedByFilter(q *genquery.Query, f *filter.TaskFilt } // 构建采样率过滤条件 -func (v *TaskDaoImpl) buildSampleRateFilter(q *genquery.Query, f *filter.TaskFilterField) (field.Expr, error) { +func (v *TaskDaoImpl) buildSampleRateFilter(q *genquery.Query, f *entity.TaskFilterField) (field.Expr, error) { if len(f.Values) == 0 { return nil, errorx.NewByCode(obErrorx.CommonInvalidParamCode, errorx.WithExtraMsg("no value provided for sample rate")) } @@ -287,13 +292,13 @@ func (v *TaskDaoImpl) buildSampleRateFilter(q *genquery.Query, f *filter.TaskFil // 构建 JSON_EXTRACT 表达式 switch *f.QueryType { - case filter.QueryTypeGte: + case entity.QueryTypeGte: return field.NewUnsafeFieldRaw("CAST(JSON_EXTRACT(?, '$.sample_rate') AS DECIMAL(10,4)) >= ?", q.ObservabilityTask.Sampler, sampleRate), nil - case filter.QueryTypeLte: + case entity.QueryTypeLte: return field.NewUnsafeFieldRaw("CAST(JSON_EXTRACT(?, '$.sample_rate') AS DECIMAL(10,4)) <= ?", q.ObservabilityTask.Sampler, sampleRate), nil - case filter.QueryTypeEq: + case entity.QueryTypeEq: return field.NewUnsafeFieldRaw("CAST(JSON_EXTRACT(?, '$.sample_rate') AS DECIMAL(10,4)) = ?", q.ObservabilityTask.Sampler, sampleRate), nil - case filter.QueryTypeNotEq: + case entity.QueryTypeNotEq: return field.NewUnsafeFieldRaw("CAST(JSON_EXTRACT(?, '$.sample_rate') AS DECIMAL(10,4)) != ?", q.ObservabilityTask.Sampler, sampleRate), nil default: return nil, errorx.NewByCode(obErrorx.CommonInvalidParamCode, errorx.WithExtraMsg("invalid query type for sample rate")) @@ -301,7 +306,7 @@ func (v *TaskDaoImpl) buildSampleRateFilter(q *genquery.Query, f *filter.TaskFil } // 构建任务ID过滤条件 -func (v *TaskDaoImpl) buildTaskIDFilter(q *genquery.Query, f *filter.TaskFilterField) (field.Expr, error) { +func (v *TaskDaoImpl) buildTaskIDFilter(q *genquery.Query, f *entity.TaskFilterField) (field.Expr, error) { if len(f.Values) == 0 { return nil, errorx.NewByCode(obErrorx.CommonInvalidParamCode, errorx.WithExtraMsg("no value provided for task id")) } @@ -318,7 +323,7 @@ func (v *TaskDaoImpl) buildTaskIDFilter(q *genquery.Query, f *filter.TaskFilterF return q.ObservabilityTask.ID.In(taskIDs...), nil } -func (v *TaskDaoImpl) buildUpdateAtFilter(q *genquery.Query, f *filter.TaskFilterField) (field.Expr, error) { +func (v *TaskDaoImpl) buildUpdateAtFilter(q *genquery.Query, f *entity.TaskFilterField) (field.Expr, error) { if len(f.Values) == 0 { return nil, errorx.NewByCode(obErrorx.CommonInvalidParamCode, errorx.WithExtraMsg("no value provided for update at")) } @@ -328,9 +333,9 @@ func (v *TaskDaoImpl) buildUpdateAtFilter(q *genquery.Query, f *filter.TaskFilte return nil, errorx.NewByCode(obErrorx.CommonInvalidParamCode, errorx.WithMsgParam("invalid update at: %v", err.Error())) } switch *f.QueryType { - case filter.QueryTypeGt: + case entity.QueryTypeGt: return q.ObservabilityTask.UpdatedAt.Gt(time.UnixMilli(updateAtLatest)), nil - case filter.QueryTypeLt: + case entity.QueryTypeLt: return q.ObservabilityTask.UpdatedAt.Lt(time.UnixMilli(updateAtLatest)), nil default: return nil, errorx.NewByCode(obErrorx.CommonInvalidParamCode, errorx.WithExtraMsg("invalid query type for update at")) @@ -392,46 +397,16 @@ func (v *TaskDaoImpl) UpdateTaskWithOCC(ctx context.Context, id int64, workspace return errorx.NewByCode(obErrorx.CommonMySqlErrorCode, errorx.WithExtraMsg("TaskRun update failed with OCC")) } -func (v *TaskDaoImpl) GetObjListWithTask(ctx context.Context) ([]string, []string, []*model.ObservabilityTask, error) { +func (v *TaskDaoImpl) ListNonFinalTasks(ctx context.Context) ([]*model.ObservabilityTask, error) { q := genquery.Use(v.dbMgr.NewSession(ctx)) qd := q.WithContext(ctx).ObservabilityTask - // 查询非终态任务的workspace_id,使用DISTINCT去重 - qd = qd.Where(q.ObservabilityTask.TaskStatus.NotIn("success", "disabled")) - // qd = qd.Select(q.ObservabilityTask.WorkspaceID).Distinct() + // 查询非终态任务 + qd = qd.Where(q.ObservabilityTask.TaskStatus.NotIn(string(entity.TaskStatusSuccess), string(entity.TaskStatusDisabled))) results, err := qd.Find() if err != nil { - return nil, nil, nil, errorx.WrapByCode(err, obErrorx.CommonMySqlErrorCode) - } - - // 转换为字符串数组 - var spaceList []string - var botList []string - for _, task := range results { - spaceList = append(spaceList, strconv.FormatInt(task.WorkspaceID, 10)) - spanFilter := tconv.SpanFilterPO2DO(ctx, task.SpanFilter) - if spanFilter != nil && spanFilter.Filters.FilterFields != nil { - extractBotIDFromFilters(spanFilter.Filters.FilterFields, &botList) - } - } - - return spaceList, botList, nil, nil -} - -// extractBotIDFromFilters 递归提取过滤器中的 bot_id 值,包括 SubFilter -func extractBotIDFromFilters(filterFields []*filter.FilterField, botList *[]string) { - for _, filterField := range filterFields { - if filterField == nil { - continue - } - // 检查当前 FilterField 的 FieldName - if filterField.FieldName != nil && *filterField.FieldName == "bot_id" { - *botList = append(*botList, filterField.Values...) - } - // 递归处理 SubFilter - if filterField.SubFilter != nil && filterField.SubFilter.FilterFields != nil { - extractBotIDFromFilters(filterField.SubFilter.FilterFields, botList) - } + return nil, errorx.WrapByCode(err, obErrorx.CommonMySqlErrorCode) } + return results, nil } diff --git a/backend/modules/observability/infra/repo/mysql/task_run.go b/backend/modules/observability/infra/repo/mysql/task_run.go index 3f8ab0dd4..807c755d3 100755 --- a/backend/modules/observability/infra/repo/mysql/task_run.go +++ b/backend/modules/observability/infra/repo/mysql/task_run.go @@ -9,8 +9,8 @@ import ( "time" "github.com/coze-dev/coze-loop/backend/infra/db" - "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/observability/domain/common" "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/observability/domain/task" + tracecommon "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/entity/common" "github.com/coze-dev/coze-loop/backend/modules/observability/infra/repo/mysql/gorm_gen/model" genquery "github.com/coze-dev/coze-loop/backend/modules/observability/infra/repo/mysql/gorm_gen/query" obErrorx "github.com/coze-dev/coze-loop/backend/modules/observability/pkg/errno" @@ -33,7 +33,7 @@ type ListTaskRunParam struct { TaskRunStatus *task.RunStatus ReqLimit int32 ReqOffset int32 - OrderBy *common.OrderBy + OrderBy *tracecommon.OrderBy } //go:generate mockgen -destination=mocks/task_run.go -package=mocks . ITaskRunDao @@ -163,7 +163,13 @@ func (v *TaskRunDaoImpl) ListTaskRuns(ctx context.Context, param ListTaskRunPara } // 排序 - qd = qd.Order(v.order(q, param.OrderBy.GetField(), param.OrderBy.GetIsAsc())) + orderField := "" + orderAsc := false + if param.OrderBy != nil { + orderField = param.OrderBy.Field + orderAsc = param.OrderBy.IsAsc + } + qd = qd.Order(v.order(q, orderField, orderAsc)) // 计算总数 total, err := qd.Count() diff --git a/backend/modules/observability/infra/repo/redis/dao/task.go b/backend/modules/observability/infra/repo/redis/task.go similarity index 99% rename from backend/modules/observability/infra/repo/redis/dao/task.go rename to backend/modules/observability/infra/repo/redis/task.go index 99fc27f1e..d6f5a541a 100755 --- a/backend/modules/observability/infra/repo/redis/dao/task.go +++ b/backend/modules/observability/infra/repo/redis/task.go @@ -1,7 +1,7 @@ // Copyright (c) 2025 coze-dev Authors // SPDX-License-Identifier: Apache-2.0 -package dao +package redis import ( "context" diff --git a/backend/modules/observability/infra/repo/redis/dao/task_run.go b/backend/modules/observability/infra/repo/redis/task_run.go similarity index 99% rename from backend/modules/observability/infra/repo/redis/dao/task_run.go rename to backend/modules/observability/infra/repo/redis/task_run.go index 2b7c41154..04263fd7e 100755 --- a/backend/modules/observability/infra/repo/redis/dao/task_run.go +++ b/backend/modules/observability/infra/repo/redis/task_run.go @@ -1,7 +1,7 @@ // Copyright (c) 2025 coze-dev Authors // SPDX-License-Identifier: Apache-2.0 -package dao +package redis import ( "context" diff --git a/backend/modules/observability/infra/repo/task.go b/backend/modules/observability/infra/repo/task.go index 3791716c7..8e18f7626 100644 --- a/backend/modules/observability/infra/repo/task.go +++ b/backend/modules/observability/infra/repo/task.go @@ -13,12 +13,12 @@ import ( "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/repo" "github.com/coze-dev/coze-loop/backend/modules/observability/infra/repo/mysql" "github.com/coze-dev/coze-loop/backend/modules/observability/infra/repo/mysql/convertor" - "github.com/coze-dev/coze-loop/backend/modules/observability/infra/repo/redis/dao" + "github.com/coze-dev/coze-loop/backend/modules/observability/infra/repo/redis" "github.com/coze-dev/coze-loop/backend/pkg/lang/ptr" "github.com/coze-dev/coze-loop/backend/pkg/logs" ) -func NewTaskRepoImpl(TaskDao mysql.ITaskDao, idGenerator idgen.IIDGenerator, taskRedisDao dao.ITaskDAO, taskRunDao mysql.ITaskRunDao, taskRunRedisDao dao.ITaskRunDAO) repo.ITaskRepo { +func NewTaskRepoImpl(TaskDao mysql.ITaskDao, idGenerator idgen.IIDGenerator, taskRedisDao redis.ITaskDAO, taskRunDao mysql.ITaskRunDao, taskRunRedisDao redis.ITaskRunDAO) repo.ITaskRepo { return &TaskRepoImpl{ TaskDao: TaskDao, idGenerator: idGenerator, @@ -31,23 +31,11 @@ func NewTaskRepoImpl(TaskDao mysql.ITaskDao, idGenerator idgen.IIDGenerator, tas type TaskRepoImpl struct { TaskDao mysql.ITaskDao TaskRunDao mysql.ITaskRunDao - TaskRedisDao dao.ITaskDAO - TaskRunRedisDao dao.ITaskRunDAO + TaskRedisDao redis.ITaskDAO + TaskRunRedisDao redis.ITaskRunDAO idGenerator idgen.IIDGenerator } -// 缓存 TTL 常量 -const ( - TaskDetailTTL = 30 * time.Minute // 单个任务缓存30分钟 - NonFinalTaskListTTL = 1 * time.Minute // 非最终状态任务缓存1分钟 - TaskCountTTL = 10 * time.Minute // 任务计数缓存10分钟 -) - -// 任务运行计数TTL常量 -const ( - TaskRunCountTTL = 10 * time.Minute // 任务运行计数缓存10分钟 -) - func (v *TaskRepoImpl) GetTask(ctx context.Context, id int64, workspaceID *int64, userID *string) (*entity.ObservabilityTask, error) { TaskPO, err := v.TaskDao.GetTask(ctx, id, workspaceID, userID) if err != nil { @@ -59,7 +47,7 @@ func (v *TaskRepoImpl) GetTask(ctx context.Context, id int64, workspaceID *int64 TaskRunPO, _, err := v.TaskRunDao.ListTaskRuns(ctx, mysql.ListTaskRunParam{ WorkspaceID: ptr.Of(taskDO.WorkspaceID), TaskID: ptr.Of(taskDO.ID), - ReqLimit: 1000, + ReqLimit: 500, ReqOffset: 0, }) @@ -71,8 +59,14 @@ func (v *TaskRepoImpl) GetTask(ctx context.Context, id int64, workspaceID *int64 return taskDO, nil } -func (v *TaskRepoImpl) ListTasks(ctx context.Context, param mysql.ListTaskParam) ([]*entity.ObservabilityTask, int64, error) { - results, total, err := v.TaskDao.ListTasks(ctx, param) +func (v *TaskRepoImpl) ListTasks(ctx context.Context, param repo.ListTaskParam) ([]*entity.ObservabilityTask, int64, error) { + results, total, err := v.TaskDao.ListTasks(ctx, mysql.ListTaskParam{ + WorkspaceIDs: param.WorkspaceIDs, + TaskFilters: param.TaskFilters, + ReqLimit: param.ReqLimit, + ReqOffset: param.ReqOffset, + OrderBy: param.OrderBy, + }) if err != nil { return nil, 0, err } @@ -80,12 +74,13 @@ func (v *TaskRepoImpl) ListTasks(ctx context.Context, param mysql.ListTaskParam) for i, result := range results { resp[i] = convertor.TaskPO2DO(result) } + // todo 待优化 for _, t := range resp { taskRuns, _, err := v.TaskRunDao.ListTaskRuns(ctx, mysql.ListTaskRunParam{ WorkspaceID: ptr.Of(t.WorkspaceID), TaskID: ptr.Of(t.ID), - ReqLimit: param.ReqLimit, - ReqOffset: param.ReqOffset, + ReqLimit: 500, + ReqOffset: 0, }) if err != nil { logs.CtxError(ctx, "ListTaskRuns err, taskID:%d, err:%v", t.ID, err) @@ -154,21 +149,6 @@ func (v *TaskRepoImpl) UpdateTaskWithOCC(ctx context.Context, id int64, workspac return nil } -func (v *TaskRepoImpl) GetObjListWithTask(ctx context.Context) ([]string, []string, []*entity.ObservabilityTask) { - var tasks []*entity.ObservabilityTask - spaceList, botList, results, err := v.TaskDao.GetObjListWithTask(ctx) - if err != nil { - logs.CtxWarn(ctx, "failed to get obj list with task from mysql", "err", err) - return nil, nil, nil - } - tasks = make([]*entity.ObservabilityTask, len(results)) - for i, result := range results { - tasks[i] = convertor.TaskPO2DO(result) - } - - return spaceList, botList, tasks -} - func (v *TaskRepoImpl) DeleteTask(ctx context.Context, do *entity.ObservabilityTask) error { // 先执行数据库删除操作 err := v.TaskDao.DeleteTask(ctx, do.ID, do.WorkspaceID, do.CreatedBy) @@ -183,16 +163,29 @@ func (v *TaskRepoImpl) DeleteTask(ctx context.Context, do *entity.ObservabilityT return nil } +func (v *TaskRepoImpl) ListNonFinalTasks(ctx context.Context) ([]*entity.ObservabilityTask, error) { + result, err := v.TaskDao.ListNonFinalTasks(ctx) + if err != nil { + return nil, err + } + + resp := make([]*entity.ObservabilityTask, len(result)) + for i, t := range result { + resp[i] = convertor.TaskPO2DO(t) + } + return resp, nil +} + func (v *TaskRepoImpl) CreateTaskRun(ctx context.Context, do *entity.TaskRun) (int64, error) { // 1. 生成ID id, err := v.idGenerator.GenID(ctx) if err != nil { return 0, err } + do.ID = id // 2. 转换并设置ID taskRunPo := convertor.TaskRunDO2PO(do) - taskRunPo.ID = id // 3. 数据库创建 createdID, err := v.TaskRunDao.CreateTaskRun(ctx, taskRunPo) @@ -200,8 +193,6 @@ func (v *TaskRepoImpl) CreateTaskRun(ctx context.Context, do *entity.TaskRun) (i return 0, err } - // 4. 异步更新缓存 - do.ID = createdID return createdID, nil } @@ -324,7 +315,7 @@ func (v *TaskRepoImpl) IncrTaskRunFailCount(ctx context.Context, taskID, taskRun return v.TaskRunRedisDao.IncrTaskRunFailCount(ctx, taskID, taskRunID, time.Duration(ttl)*time.Second) } -func (v *TaskRepoImpl) ListNonFinalTask(ctx context.Context, spaceID string) ([]int64, error) { +func (v *TaskRepoImpl) ListNonFinalTaskBySpaceID(ctx context.Context, spaceID string) ([]int64, error) { return v.TaskRedisDao.ListNonFinalTask(ctx, spaceID) } @@ -336,7 +327,7 @@ func (v *TaskRepoImpl) RemoveNonFinalTask(ctx context.Context, spaceID string, t return v.TaskRedisDao.RemoveNonFinalTask(ctx, spaceID, taskID) } -func (v *TaskRepoImpl) GetTaskByRedis(ctx context.Context, taskID int64) (*entity.ObservabilityTask, error) { +func (v *TaskRepoImpl) GetTaskByCache(ctx context.Context, taskID int64) (*entity.ObservabilityTask, error) { taskDO, err := v.TaskRedisDao.GetTask(ctx, taskID) if err != nil { logs.CtxError(ctx, "Failed to get task", "err", err) @@ -360,7 +351,3 @@ func (v *TaskRepoImpl) GetTaskByRedis(ctx context.Context, taskID int64) (*entit } return taskDO, nil } - -func (v *TaskRepoImpl) SetTask(ctx context.Context, task *entity.ObservabilityTask) error { - return v.TaskRedisDao.SetTask(ctx, task) -} diff --git a/backend/modules/observability/infra/repo/task_test.go b/backend/modules/observability/infra/repo/task_test.go index e27aa5ea5..c8aae58ef 100755 --- a/backend/modules/observability/infra/repo/task_test.go +++ b/backend/modules/observability/infra/repo/task_test.go @@ -12,7 +12,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/entity" - mysql "github.com/coze-dev/coze-loop/backend/modules/observability/infra/repo/mysql" + "github.com/coze-dev/coze-loop/backend/modules/observability/infra/repo/mysql" mysqlconv "github.com/coze-dev/coze-loop/backend/modules/observability/infra/repo/mysql/convertor" mysqlmodel "github.com/coze-dev/coze-loop/backend/modules/observability/infra/repo/mysql/gorm_gen/model" ) @@ -38,6 +38,7 @@ type stubTaskDao struct { getTaskFunc func(ctx context.Context, id int64, workspaceID *int64, userID *string) (*mysqlmodel.ObservabilityTask, error) listTasksFunc func(ctx context.Context, param mysql.ListTaskParam) ([]*mysqlmodel.ObservabilityTask, int64, error) getObjListWithTaskFunc func(ctx context.Context) ([]string, []string, []*mysqlmodel.ObservabilityTask, error) + listNonFinalTasksFunc func(ctx context.Context) ([]*mysqlmodel.ObservabilityTask, error) } func (s *stubTaskDao) CreateTask(ctx context.Context, po *mysqlmodel.ObservabilityTask) (int64, error) { @@ -82,11 +83,11 @@ func (s *stubTaskDao) ListTasks(ctx context.Context, param mysql.ListTaskParam) return nil, 0, nil } -func (s *stubTaskDao) GetObjListWithTask(ctx context.Context) ([]string, []string, []*mysqlmodel.ObservabilityTask, error) { - if s.getObjListWithTaskFunc != nil { - return s.getObjListWithTaskFunc(ctx) +func (s *stubTaskDao) ListNonFinalTasks(ctx context.Context) ([]*mysqlmodel.ObservabilityTask, error) { + if s.listNonFinalTasksFunc != nil { + return s.listNonFinalTasksFunc(ctx) } - return nil, nil, nil, nil + return nil, nil } type stubTaskRedisDao struct { @@ -443,7 +444,7 @@ func TestTaskRepoImpl_NonFinalTaskWrappers(t *testing.T) { TaskRunRedisDao: stubTaskRunRedisDao{}, } - list, err := repo.ListNonFinalTask(context.Background(), "space") + list, err := repo.ListNonFinalTaskBySpaceID(context.Background(), "space") assert.NoError(t, err) assert.Equal(t, expected, list) @@ -558,7 +559,7 @@ func TestTaskRepoImpl_GetTaskByRedis(t *testing.T) { } } - got, err := repo.GetTaskByRedis(context.Background(), 100) + got, err := repo.GetTaskByCache(context.Background(), 100) if tt.expectErr != nil { assert.EqualError(t, err, tt.expectErr.Error()) } else { @@ -570,23 +571,3 @@ func TestTaskRepoImpl_GetTaskByRedis(t *testing.T) { }) } } - -func TestTaskRepoImpl_SetTask(t *testing.T) { - t.Parallel() - - called := false - redisDao := &stubTaskRedisDao{setTaskFunc: func(ctx context.Context, task *entity.ObservabilityTask) error { - called = true - assert.Equal(t, int64(1), task.ID) - return nil - }} - repo := &TaskRepoImpl{ - TaskDao: &stubTaskDao{}, - TaskRunDao: stubTaskRunDao{}, - TaskRedisDao: redisDao, - TaskRunRedisDao: stubTaskRunRedisDao{}, - } - - assert.NoError(t, repo.SetTask(context.Background(), &entity.ObservabilityTask{ID: 1})) - assert.True(t, called) -} From e0c72e51884a639f3457087209dac49c8b7567a7 Mon Sep 17 00:00:00 2001 From: taoyifan89 Date: Fri, 31 Oct 2025 18:15:52 +0800 Subject: [PATCH 05/19] rename task processor. Change-Id: I02c6962c72671edeeefbc73a1429508bb8060962 --- .../domain/task/service/task_service.go | 4 +- .../domain/task/service/task_service_test.go | 6 +-- .../taskexe/{types.go => processor.go} | 26 ++++-------- .../taskexe/processor/auto_evaluate.go | 36 ++++++++--------- .../taskexe/processor/auto_evaluate_test.go | 23 ++++++----- .../service/taskexe/processor/factory_test.go | 10 ++--- .../task/service/taskexe/processor/noop.go | 10 ++--- .../task/service/taskexe/tracehub/backfill.go | 4 +- .../taskexe/tracehub/scheduled_task.go | 40 +++++++++---------- .../taskexe/tracehub/scheduled_task_test.go | 12 +++--- .../service/taskexe/tracehub/span_trigger.go | 20 +++++----- .../service/taskexe/tracehub/subscriber.go | 2 +- .../taskexe/tracehub/test_helpers_test.go | 14 +++---- .../observability/infra/repo/mysql/task.go | 4 +- .../infra/repo/mysql/task_run.go | 36 +++++------------ .../observability/infra/repo/redis/task.go | 1 + 16 files changed, 113 insertions(+), 135 deletions(-) rename backend/modules/observability/domain/task/service/taskexe/{types.go => processor.go} (52%) diff --git a/backend/modules/observability/domain/task/service/task_service.go b/backend/modules/observability/domain/task/service/task_service.go index 870b75b84..15ac2b76a 100644 --- a/backend/modules/observability/domain/task/service/task_service.go +++ b/backend/modules/observability/domain/task/service/task_service.go @@ -134,7 +134,7 @@ func (t *TaskServiceImpl) CreateTask(ctx context.Context, req *CreateTaskReq) (r // 数据回流任务——创建/更新输出数据集 // 自动评测历史回溯——创建空壳子 taskDO.ID = id - if err = proc.OnCreateTaskChange(ctx, taskDO); err != nil { + if err = proc.OnTaskCreated(ctx, taskDO); err != nil { logs.CtxError(ctx, "create initial task run failed, task_id=%d, err=%v", id, err) if err1 := t.TaskRepo.DeleteTask(ctx, taskDO); err1 != nil { @@ -201,7 +201,7 @@ func (t *TaskServiceImpl) UpdateTask(ctx context.Context, req *UpdateTaskReq) (e break } } - if err = proc.OnFinishTaskRunChange(ctx, taskexe.OnFinishTaskRunChangeReq{ + if err = proc.OnTaskRunFinished(ctx, taskexe.OnTaskRunFinishedReq{ Task: taskDO, TaskRun: taskRun, }); err != nil { diff --git a/backend/modules/observability/domain/task/service/task_service_test.go b/backend/modules/observability/domain/task/service/task_service_test.go index 8049530a3..2499f662a 100755 --- a/backend/modules/observability/domain/task/service/task_service_test.go +++ b/backend/modules/observability/domain/task/service/task_service_test.go @@ -53,15 +53,15 @@ func (f *fakeProcessor) OnUpdateTaskChange(context.Context, *entity.Observabilit return nil } -func (f *fakeProcessor) OnFinishTaskChange(context.Context, taskexe.OnFinishTaskChangeReq) error { +func (f *fakeProcessor) OnFinishTaskChange(context.Context, taskexe.OnTaskFinishedReq) error { return nil } -func (f *fakeProcessor) OnCreateTaskRunChange(context.Context, taskexe.OnCreateTaskRunChangeReq) error { +func (f *fakeProcessor) OnCreateTaskRunChange(context.Context, taskexe.OnTaskRunCreatedReq) error { return nil } -func (f *fakeProcessor) OnFinishTaskRunChange(context.Context, taskexe.OnFinishTaskRunChangeReq) error { +func (f *fakeProcessor) OnFinishTaskRunChange(context.Context, taskexe.OnTaskRunFinishedReq) error { f.onFinishRunCalled = true return f.onFinishRunErr } diff --git a/backend/modules/observability/domain/task/service/taskexe/types.go b/backend/modules/observability/domain/task/service/taskexe/processor.go similarity index 52% rename from backend/modules/observability/domain/task/service/taskexe/types.go rename to backend/modules/observability/domain/task/service/taskexe/processor.go index 6648398f2..fa1542a2e 100644 --- a/backend/modules/observability/domain/task/service/taskexe/types.go +++ b/backend/modules/observability/domain/task/service/taskexe/processor.go @@ -5,7 +5,6 @@ package taskexe import ( "context" - "errors" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/entity" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/entity/loop_span" @@ -17,22 +16,17 @@ type Trigger struct { TaskRun *entity.TaskRun } -var ( - ErrInvalidConfig = errors.New("invalid config") - ErrInvalidTrigger = errors.New("invalid span trigger") -) - -type OnCreateTaskRunChangeReq struct { +type OnTaskRunCreatedReq struct { CurrentTask *entity.ObservabilityTask RunType entity.TaskRunType RunStartAt int64 RunEndAt int64 } -type OnFinishTaskRunChangeReq struct { +type OnTaskRunFinishedReq struct { Task *entity.ObservabilityTask TaskRun *entity.TaskRun } -type OnFinishTaskChangeReq struct { +type OnTaskFinishedReq struct { Task *entity.ObservabilityTask TaskRun *entity.TaskRun IsFinish bool @@ -42,14 +36,10 @@ type Processor interface { ValidateConfig(ctx context.Context, config any) error Invoke(ctx context.Context, trigger *Trigger) error - OnCreateTaskChange(ctx context.Context, currentTask *entity.ObservabilityTask) error - OnUpdateTaskChange(ctx context.Context, currentTask *entity.ObservabilityTask, taskOp entity.TaskStatus) error - OnFinishTaskChange(ctx context.Context, param OnFinishTaskChangeReq) error - - OnCreateTaskRunChange(ctx context.Context, param OnCreateTaskRunChangeReq) error - OnFinishTaskRunChange(ctx context.Context, param OnFinishTaskRunChangeReq) error -} + OnTaskCreated(ctx context.Context, currentTask *entity.ObservabilityTask) error + OnTaskUpdated(ctx context.Context, currentTask *entity.ObservabilityTask, taskOp entity.TaskStatus) error + OnTaskFinished(ctx context.Context, param OnTaskFinishedReq) error -type ProcessorUnion interface { - Processor + OnTaskRunCreated(ctx context.Context, param OnTaskRunCreatedReq) error + OnTaskRunFinished(ctx context.Context, param OnTaskRunFinishedReq) error } diff --git a/backend/modules/observability/domain/task/service/taskexe/processor/auto_evaluate.go b/backend/modules/observability/domain/task/service/taskexe/processor/auto_evaluate.go index 2bff52aad..02d4c4995 100644 --- a/backend/modules/observability/domain/task/service/taskexe/processor/auto_evaluate.go +++ b/backend/modules/observability/domain/task/service/taskexe/processor/auto_evaluate.go @@ -63,7 +63,7 @@ func NewAutoEvaluteProcessor( func (p *AutoEvaluteProcessor) ValidateConfig(ctx context.Context, config any) error { cfg, ok := config.(*task_entity.ObservabilityTask) if !ok { - return taskexe.ErrInvalidConfig + return errorx.NewByCode(obErrorx.CommonInvalidParamCode) } if cfg.EffectiveTime != nil { startAt := cfg.EffectiveTime.StartAt @@ -156,26 +156,26 @@ func (p *AutoEvaluteProcessor) Invoke(ctx context.Context, trigger *taskexe.Trig return nil } -func (p *AutoEvaluteProcessor) OnCreateTaskChange(ctx context.Context, currentTask *task_entity.ObservabilityTask) error { +func (p *AutoEvaluteProcessor) OnTaskCreated(ctx context.Context, currentTask *task_entity.ObservabilityTask) error { taskRuns, err := p.taskRepo.GetBackfillTaskRun(ctx, nil, currentTask.ID) if err != nil { logs.CtxError(ctx, "GetBackfillTaskRun failed, taskID:%d, err:%v", currentTask.ID, err) return err } if ShouldTriggerBackfill(currentTask) && taskRuns == nil { - err = p.OnCreateTaskRunChange(ctx, taskexe.OnCreateTaskRunChangeReq{ + err = p.OnTaskRunCreated(ctx, taskexe.OnTaskRunCreatedReq{ CurrentTask: currentTask, RunType: task_entity.TaskRunTypeBackFill, RunStartAt: time.Now().UnixMilli(), RunEndAt: time.Now().UnixMilli() + (currentTask.BackfillEffectiveTime.EndAt - currentTask.BackfillEffectiveTime.StartAt), }) if err != nil { - logs.CtxError(ctx, "OnCreateTaskChange failed, taskID:%d, err:%v", currentTask.ID, err) + logs.CtxError(ctx, "OnTaskCreated failed, taskID:%d, err:%v", currentTask.ID, err) return err } - err = p.OnUpdateTaskChange(ctx, currentTask, task.TaskStatusRunning) + err = p.OnTaskUpdated(ctx, currentTask, task.TaskStatusRunning) if err != nil { - logs.CtxError(ctx, "OnCreateTaskChange failed, taskID:%d, err:%v", currentTask.ID, err) + logs.CtxError(ctx, "OnTaskCreated failed, taskID:%d, err:%v", currentTask.ID, err) return err } } @@ -194,26 +194,26 @@ func (p *AutoEvaluteProcessor) OnCreateTaskChange(ctx context.Context, currentTa runEndAt = runStartAt + (currentTask.Sampler.CycleInterval)*10*time.Minute.Milliseconds() } } - err = p.OnCreateTaskRunChange(ctx, taskexe.OnCreateTaskRunChangeReq{ + err = p.OnTaskRunCreated(ctx, taskexe.OnTaskRunCreatedReq{ CurrentTask: currentTask, RunType: task_entity.TaskRunTypeNewData, RunStartAt: runStartAt, RunEndAt: runEndAt, }) if err != nil { - logs.CtxError(ctx, "OnCreateTaskChange failed, taskID:%d, err:%v", currentTask.ID, err) + logs.CtxError(ctx, "OnTaskCreated failed, taskID:%d, err:%v", currentTask.ID, err) return err } - err = p.OnUpdateTaskChange(ctx, currentTask, task.TaskStatusRunning) + err = p.OnTaskUpdated(ctx, currentTask, task.TaskStatusRunning) if err != nil { - logs.CtxError(ctx, "OnCreateTaskChange failed, taskID:%d, err:%v", currentTask.ID, err) + logs.CtxError(ctx, "OnTaskCreated failed, taskID:%d, err:%v", currentTask.ID, err) return err } } return nil } -func (p *AutoEvaluteProcessor) OnUpdateTaskChange(ctx context.Context, currentTask *task_entity.ObservabilityTask, taskOp task_entity.TaskStatus) error { +func (p *AutoEvaluteProcessor) OnTaskUpdated(ctx context.Context, currentTask *task_entity.ObservabilityTask, taskOp task_entity.TaskStatus) error { switch taskOp { case task_entity.TaskStatusSuccess: if currentTask.TaskStatus != task_entity.TaskStatusDisabled { @@ -243,18 +243,18 @@ func (p *AutoEvaluteProcessor) OnUpdateTaskChange(ctx context.Context, currentTa return nil } -func (p *AutoEvaluteProcessor) OnFinishTaskChange(ctx context.Context, param taskexe.OnFinishTaskChangeReq) error { - err := p.OnFinishTaskRunChange(ctx, taskexe.OnFinishTaskRunChangeReq{ +func (p *AutoEvaluteProcessor) OnTaskFinished(ctx context.Context, param taskexe.OnTaskFinishedReq) error { + err := p.OnTaskRunFinished(ctx, taskexe.OnTaskRunFinishedReq{ Task: param.Task, TaskRun: param.TaskRun, }) if err != nil { - logs.CtxError(ctx, "OnFinishTaskRunChange failed, taskRun:%+v, err:%v", param.TaskRun, err) + logs.CtxError(ctx, "OnTaskRunFinished failed, taskRun:%+v, err:%v", param.TaskRun, err) return err } if param.IsFinish { - logs.CtxWarn(ctx, "OnFinishTaskChange, taskID:%d, taskRun:%+v, isFinish:%v", param.Task.ID, param.TaskRun, param.IsFinish) - if err := p.OnUpdateTaskChange(ctx, param.Task, task.TaskStatusSuccess); err != nil { + logs.CtxWarn(ctx, "OnTaskFinished, taskID:%d, taskRun:%+v, isFinish:%v", param.Task.ID, param.TaskRun, param.IsFinish) + if err := p.OnTaskUpdated(ctx, param.Task, task.TaskStatusSuccess); err != nil { logs.CtxError(ctx, "OnUpdateChangeProcessor failed, taskID:%d, err:%v", param.Task.ID, err) return err } @@ -273,7 +273,7 @@ const ( BackFillI18N = "BackFill" ) -func (p *AutoEvaluteProcessor) OnCreateTaskRunChange(ctx context.Context, param taskexe.OnCreateTaskRunChangeReq) error { +func (p *AutoEvaluteProcessor) OnTaskRunCreated(ctx context.Context, param taskexe.OnTaskRunCreatedReq) error { currentTask := param.CurrentTask ctx = session.WithCtxUser(ctx, &session.User{ID: currentTask.CreatedBy}) sessionInfo := p.getSession(ctx, currentTask) @@ -410,7 +410,7 @@ func (p *AutoEvaluteProcessor) OnCreateTaskRunChange(ctx context.Context, param return nil } -func (p *AutoEvaluteProcessor) OnFinishTaskRunChange(ctx context.Context, param taskexe.OnFinishTaskRunChangeReq) error { +func (p *AutoEvaluteProcessor) OnTaskRunFinished(ctx context.Context, param taskexe.OnTaskRunFinishedReq) error { if param.TaskRun == nil || param.TaskRun.TaskRunConfig == nil || param.TaskRun.TaskRunConfig.AutoEvaluateRunConfig == nil { return nil } diff --git a/backend/modules/observability/domain/task/service/taskexe/processor/auto_evaluate_test.go b/backend/modules/observability/domain/task/service/taskexe/processor/auto_evaluate_test.go index cbedbe2c9..a51537ff5 100755 --- a/backend/modules/observability/domain/task/service/taskexe/processor/auto_evaluate_test.go +++ b/backend/modules/observability/domain/task/service/taskexe/processor/auto_evaluate_test.go @@ -224,7 +224,8 @@ func TestAutoEvaluteProcessor_ValidateConfig(t *testing.T) { name: "invalid type", config: "bad", expectErr: func(err error) bool { - return errors.Is(err, taskexe.ErrInvalidConfig) + status, ok := errorx.FromStatusError(err) + return ok && status.Code() == obErrorx.CommonInvalidParamCode }, }, { @@ -457,14 +458,14 @@ func TestAutoEvaluteProcessor_OnUpdateTaskChange(t *testing.T) { proc := &AutoEvaluteProcessor{taskRepo: repoAdapter} taskObj := &taskentity.ObservabilityTask{TaskStatus: caseItem.initial} - err := proc.OnUpdateTaskChange(ctx, taskObj, caseItem.op) + err := proc.OnTaskUpdated(ctx, taskObj, caseItem.op) assert.NoError(t, err) }) } t.Run("invalid op", func(t *testing.T) { proc := &AutoEvaluteProcessor{} - err := proc.OnUpdateTaskChange(ctx, &taskentity.ObservabilityTask{}, "unknown") + err := proc.OnTaskUpdated(ctx, &taskentity.ObservabilityTask{}, "unknown") assert.Error(t, err) }) } @@ -479,7 +480,7 @@ func TestAutoEvaluteProcessor_OnCreateTaskRunChange(t *testing.T) { repoAdapter := &taskRepoMockAdapter{MockITaskRepo: repoMock} taskObj := buildTestTask(t) - param := taskexe.OnCreateTaskRunChangeReq{ + param := taskexe.OnTaskRunCreatedReq{ CurrentTask: taskObj, RunType: taskentity.TaskRunTypeNewData, RunStartAt: time.Now().Add(-time.Minute).UnixMilli(), @@ -506,7 +507,7 @@ func TestAutoEvaluteProcessor_OnCreateTaskRunChange(t *testing.T) { } ctx := session.WithCtxUser(context.Background(), &session.User{ID: taskObj.CreatedBy}) - err := proc.OnCreateTaskRunChange(ctx, param) + err := proc.OnTaskRunCreated(ctx, param) assert.NoError(t, err) assert.NotNil(t, evalAdapter.submitReq) assert.Equal(t, int64(9001), *evalAdapter.submitReq.EvalSetID) @@ -537,7 +538,7 @@ func TestAutoEvaluteProcessor_OnFinishTaskRunChange(t *testing.T) { evaluationSvc: evalAdapter, } - err := proc.OnFinishTaskRunChange(context.Background(), taskexe.OnFinishTaskRunChangeReq{ + err := proc.OnTaskRunFinished(context.Background(), taskexe.OnTaskRunFinishedReq{ Task: &taskentity.ObservabilityTask{WorkspaceID: 1234}, TaskRun: taskRun, }) @@ -566,7 +567,7 @@ func TestAutoEvaluteProcessor_OnFinishTaskChange(t *testing.T) { taskRepo: repoAdapter, } - err := proc.OnFinishTaskChange(context.Background(), taskexe.OnFinishTaskChangeReq{ + err := proc.OnTaskFinished(context.Background(), taskexe.OnTaskFinishedReq{ Task: taskObj, TaskRun: taskRun, IsFinish: true, @@ -590,7 +591,7 @@ func TestAutoEvaluteProcessor_OnFinishTaskChange_Error(t *testing.T) { taskRepo: repoAdapter, } - err := proc.OnFinishTaskChange(context.Background(), taskexe.OnFinishTaskChangeReq{ + err := proc.OnTaskFinished(context.Background(), taskexe.OnTaskFinishedReq{ Task: &taskentity.ObservabilityTask{WorkspaceID: 123}, TaskRun: &taskentity.TaskRun{TaskRunConfig: &taskentity.TaskRunConfig{AutoEvaluateRunConfig: &taskentity.AutoEvaluateRunConfig{ExptID: 1, ExptRunID: 2}}}, }) @@ -666,7 +667,7 @@ func TestAutoEvaluteProcessor_OnCreateTaskChange(t *testing.T) { updateTaskNewData, ) - err := proc.OnCreateTaskChange(context.Background(), taskObj) + err := proc.OnTaskCreated(context.Background(), taskObj) assert.NoError(t, err) assert.Equal(t, []taskentity.TaskRunType{taskentity.TaskRunTypeBackFill, taskentity.TaskRunTypeNewData}, runTypes) assert.Equal(t, []taskentity.TaskStatus{taskentity.TaskStatusRunning, taskentity.TaskStatusRunning}, statuses) @@ -685,7 +686,7 @@ func TestAutoEvaluteProcessor_OnCreateTaskChange_GetBackfillError(t *testing.T) proc := &AutoEvaluteProcessor{taskRepo: repoAdapter} - err := proc.OnCreateTaskChange(context.Background(), buildTestTask(t)) + err := proc.OnTaskCreated(context.Background(), buildTestTask(t)) assert.EqualError(t, err, "db error") } @@ -710,7 +711,7 @@ func TestAutoEvaluteProcessor_OnCreateTaskChange_CreateDatasetError(t *testing.T repoMock.EXPECT().GetBackfillTaskRun(gomock.Any(), (*int64)(nil), gomock.Any()).Return(nil, nil) datasetProvider.EXPECT().CreateDataset(gomock.Any(), gomock.AssignableToTypeOf(&traceentity.Dataset{})).Return(int64(0), errors.New("create fail")) - err := proc.OnCreateTaskChange(context.Background(), buildTestTask(t)) + err := proc.OnTaskCreated(context.Background(), buildTestTask(t)) assert.EqualError(t, err, "create fail") } diff --git a/backend/modules/observability/domain/task/service/taskexe/processor/factory_test.go b/backend/modules/observability/domain/task/service/taskexe/processor/factory_test.go index 466b997ba..7d5773f23 100755 --- a/backend/modules/observability/domain/task/service/taskexe/processor/factory_test.go +++ b/backend/modules/observability/domain/task/service/taskexe/processor/factory_test.go @@ -35,9 +35,9 @@ func TestNoopTaskProcessor_Methods(t *testing.T) { assert.NoError(t, p.ValidateConfig(ctx, nil)) assert.NoError(t, p.Invoke(ctx, nil)) - assert.NoError(t, p.OnCreateTaskChange(ctx, nil)) - assert.NoError(t, p.OnUpdateTaskChange(ctx, nil, task.TaskStatusRunning)) - assert.NoError(t, p.OnFinishTaskChange(ctx, taskexe.OnFinishTaskChangeReq{})) - assert.NoError(t, p.OnCreateTaskRunChange(ctx, taskexe.OnCreateTaskRunChangeReq{})) - assert.NoError(t, p.OnFinishTaskRunChange(ctx, taskexe.OnFinishTaskRunChangeReq{})) + assert.NoError(t, p.OnTaskCreated(ctx, nil)) + assert.NoError(t, p.OnTaskUpdated(ctx, nil, task.TaskStatusRunning)) + assert.NoError(t, p.OnTaskFinished(ctx, taskexe.OnTaskFinishedReq{})) + assert.NoError(t, p.OnTaskRunCreated(ctx, taskexe.OnTaskRunCreatedReq{})) + assert.NoError(t, p.OnTaskRunFinished(ctx, taskexe.OnTaskRunFinishedReq{})) } diff --git a/backend/modules/observability/domain/task/service/taskexe/processor/noop.go b/backend/modules/observability/domain/task/service/taskexe/processor/noop.go index d61466a54..37177c8b2 100644 --- a/backend/modules/observability/domain/task/service/taskexe/processor/noop.go +++ b/backend/modules/observability/domain/task/service/taskexe/processor/noop.go @@ -26,22 +26,22 @@ func (p *NoopTaskProcessor) Invoke(ctx context.Context, trigger *taskexe.Trigger return nil } -func (p *NoopTaskProcessor) OnCreateTaskChange(ctx context.Context, currentTask *entity.ObservabilityTask) error { +func (p *NoopTaskProcessor) OnTaskCreated(ctx context.Context, currentTask *entity.ObservabilityTask) error { return nil } -func (p *NoopTaskProcessor) OnUpdateTaskChange(ctx context.Context, currentTask *entity.ObservabilityTask, taskOp entity.TaskStatus) error { +func (p *NoopTaskProcessor) OnTaskUpdated(ctx context.Context, currentTask *entity.ObservabilityTask, taskOp entity.TaskStatus) error { return nil } -func (p *NoopTaskProcessor) OnFinishTaskChange(ctx context.Context, param taskexe.OnFinishTaskChangeReq) error { +func (p *NoopTaskProcessor) OnTaskFinished(ctx context.Context, param taskexe.OnTaskFinishedReq) error { return nil } -func (p *NoopTaskProcessor) OnCreateTaskRunChange(ctx context.Context, param taskexe.OnCreateTaskRunChangeReq) error { +func (p *NoopTaskProcessor) OnTaskRunCreated(ctx context.Context, param taskexe.OnTaskRunCreatedReq) error { return nil } -func (p *NoopTaskProcessor) OnFinishTaskRunChange(ctx context.Context, param taskexe.OnFinishTaskRunChangeReq) error { +func (p *NoopTaskProcessor) OnTaskRunFinished(ctx context.Context, param taskexe.OnTaskRunFinishedReq) error { return nil } diff --git a/backend/modules/observability/domain/task/service/taskexe/tracehub/backfill.go b/backend/modules/observability/domain/task/service/taskexe/tracehub/backfill.go index de5310f18..c33a28879 100644 --- a/backend/modules/observability/domain/task/service/taskexe/tracehub/backfill.go +++ b/backend/modules/observability/domain/task/service/taskexe/tracehub/backfill.go @@ -367,7 +367,7 @@ func (h *TraceHubServiceImpl) doFlush(ctx context.Context, fr *flushReq, sub *sp } if fr.noMore { logs.CtxInfo(ctx, "no more spans to process, task_id=%d", sub.t.GetID()) - if err = sub.processor.OnFinishTaskChange(ctx, taskexe.OnFinishTaskChangeReq{ + if err = sub.processor.OnTaskFinished(ctx, taskexe.OnTaskFinishedReq{ Task: tconv.TaskDTO2DO(sub.t), TaskRun: tconv.TaskRunDTO2DO(sub.tr), IsFinish: false, @@ -448,7 +448,7 @@ func (h *TraceHubServiceImpl) processBatchSpans(ctx context.Context, spans []*lo sampler := sub.t.GetRule().GetSampler() if taskCount+1 > sampler.GetSampleSize() { logs.CtxWarn(ctx, "taskCount+1 > sampler.GetSampleSize(), task_id=%d,SampleSize=%d", sub.taskID, sampler.GetSampleSize()) - if err := sub.processor.OnFinishTaskChange(ctx, taskexe.OnFinishTaskChangeReq{ + if err := sub.processor.OnTaskFinished(ctx, taskexe.OnTaskFinishedReq{ Task: tconv.TaskDTO2DO(sub.t), TaskRun: tconv.TaskRunDTO2DO(sub.tr), IsFinish: true, diff --git a/backend/modules/observability/domain/task/service/taskexe/tracehub/scheduled_task.go b/backend/modules/observability/domain/task/service/taskexe/tracehub/scheduled_task.go index 792fdd4e5..f7d442045 100755 --- a/backend/modules/observability/domain/task/service/taskexe/tracehub/scheduled_task.go +++ b/backend/modules/observability/domain/task/service/taskexe/tracehub/scheduled_task.go @@ -130,14 +130,14 @@ func (h *TraceHubServiceImpl) transformTaskStatus() { logs.CtxInfo(ctx, "[auto_task]taskID:%d, endTime:%v, startTime:%v", taskPO.ID, endTime, startTime) if taskPO.BackfillEffectiveTime != nil && taskPO.EffectiveTime != nil && backfillTaskRun != nil { if time.Now().After(endTime) && backfillTaskRun.RunStatus == entity.TaskRunStatusDone { - logs.CtxInfo(ctx, "[OnFinishTaskChange]taskID:%d, time.Now().After(endTime) && backfillTaskRun.RunStatus == task.RunStatusDone", taskPO.ID) - err = proc.OnFinishTaskChange(ctx, taskexe.OnFinishTaskChangeReq{ + logs.CtxInfo(ctx, "[OnTaskFinished]taskID:%d, time.Now().After(endTime) && backfillTaskRun.RunStatus == task.RunStatusDone", taskPO.ID) + err = proc.OnTaskFinished(ctx, taskexe.OnTaskFinishedReq{ Task: taskPO, TaskRun: backfillTaskRun, IsFinish: true, }) if err != nil { - logs.CtxError(ctx, "OnFinishTaskChange err:%v", err) + logs.CtxError(ctx, "OnTaskFinished err:%v", err) continue } } @@ -154,14 +154,14 @@ func (h *TraceHubServiceImpl) transformTaskStatus() { } } else if taskPO.BackfillEffectiveTime != nil && backfillTaskRun != nil { if backfillTaskRun.RunStatus == entity.TaskRunStatusDone { - logs.CtxInfo(ctx, "[OnFinishTaskChange]taskID:%d, backfillTaskRun.RunStatus == task.RunStatusDone", taskPO.ID) - err = proc.OnFinishTaskChange(ctx, taskexe.OnFinishTaskChangeReq{ + logs.CtxInfo(ctx, "[OnTaskFinished]taskID:%d, backfillTaskRun.RunStatus == task.RunStatusDone", taskPO.ID) + err = proc.OnTaskFinished(ctx, taskexe.OnTaskFinishedReq{ Task: taskPO, TaskRun: backfillTaskRun, IsFinish: true, }) if err != nil { - logs.CtxError(ctx, "OnFinishTaskChange err:%v", err) + logs.CtxError(ctx, "OnTaskFinished err:%v", err) continue } } @@ -178,14 +178,14 @@ func (h *TraceHubServiceImpl) transformTaskStatus() { } } else if taskPO.EffectiveTime != nil { if time.Now().After(endTime) { - logs.CtxInfo(ctx, "[OnFinishTaskChange]taskID:%d, time.Now().After(endTime)", taskPO.ID) - err = proc.OnFinishTaskChange(ctx, taskexe.OnFinishTaskChangeReq{ + logs.CtxInfo(ctx, "[OnTaskFinished]taskID:%d, time.Now().After(endTime)", taskPO.ID) + err = proc.OnTaskFinished(ctx, taskexe.OnTaskFinishedReq{ Task: taskPO, TaskRun: taskRun, IsFinish: true, }) if err != nil { - logs.CtxError(ctx, "OnFinishTaskChange err:%v", err) + logs.CtxError(ctx, "OnTaskFinished err:%v", err) continue } } @@ -193,30 +193,30 @@ func (h *TraceHubServiceImpl) transformTaskStatus() { // If the task status is unstarted, create it once the task start time is reached if taskPO.TaskStatus == entity.TaskStatusUnstarted && time.Now().After(startTime) { if !taskPO.Sampler.IsCycle { - err = proc.OnCreateTaskRunChange(ctx, taskexe.OnCreateTaskRunChangeReq{ + err = proc.OnTaskRunCreated(ctx, taskexe.OnTaskRunCreatedReq{ CurrentTask: taskPO, RunType: entity.TaskRunTypeNewData, RunStartAt: taskPO.EffectiveTime.StartAt, RunEndAt: taskPO.EffectiveTime.EndAt, }) if err != nil { - logs.CtxError(ctx, "OnCreateTaskRunChange err:%v", err) + logs.CtxError(ctx, "OnTaskRunCreated err:%v", err) continue } - err = proc.OnUpdateTaskChange(ctx, taskPO, entity.TaskStatusRunning) + err = proc.OnTaskUpdated(ctx, taskPO, entity.TaskStatusRunning) if err != nil { - logs.CtxError(ctx, "OnUpdateTaskChange err:%v", err) + logs.CtxError(ctx, "OnTaskUpdated err:%v", err) continue } } else { - err = proc.OnCreateTaskRunChange(ctx, taskexe.OnCreateTaskRunChangeReq{ + err = proc.OnTaskRunCreated(ctx, taskexe.OnTaskRunCreatedReq{ CurrentTask: taskPO, RunType: entity.TaskRunTypeNewData, RunStartAt: taskRun.RunEndAt.UnixMilli(), RunEndAt: taskRun.RunEndAt.UnixMilli() + (taskRun.RunEndAt.UnixMilli() - taskRun.RunStartAt.UnixMilli()), }) if err != nil { - logs.CtxError(ctx, "OnCreateTaskRunChange err:%v", err) + logs.CtxError(ctx, "OnTaskRunCreated err:%v", err) continue } } @@ -230,25 +230,25 @@ func (h *TraceHubServiceImpl) transformTaskStatus() { logs.CtxInfo(ctx, "taskID:%d, taskRun.RunEndAt:%v", taskPO.ID, taskRun.RunEndAt) // Handling repeated tasks: single task time horizon reached if time.Now().After(taskRun.RunEndAt) { - logs.CtxInfo(ctx, "[OnFinishTaskChange]taskID:%d, time.Now().After(cycleEndTime)", taskPO.ID) - err = proc.OnFinishTaskChange(ctx, taskexe.OnFinishTaskChangeReq{ + logs.CtxInfo(ctx, "[OnTaskFinished]taskID:%d, time.Now().After(cycleEndTime)", taskPO.ID) + err = proc.OnTaskFinished(ctx, taskexe.OnTaskFinishedReq{ Task: taskPO, TaskRun: taskRun, IsFinish: false, }) if err != nil { - logs.CtxError(ctx, "OnFinishTaskChange err:%v", err) + logs.CtxError(ctx, "OnTaskFinished err:%v", err) continue } if taskPO.Sampler.IsCycle { - err = proc.OnCreateTaskRunChange(ctx, taskexe.OnCreateTaskRunChangeReq{ + err = proc.OnTaskRunCreated(ctx, taskexe.OnTaskRunCreatedReq{ CurrentTask: taskPO, RunType: entity.TaskRunTypeNewData, RunStartAt: taskRun.RunEndAt.UnixMilli(), RunEndAt: taskRun.RunEndAt.UnixMilli() + (taskRun.RunEndAt.UnixMilli() - taskRun.RunStartAt.UnixMilli()), }) if err != nil { - logs.CtxError(ctx, "OnCreateTaskRunChange err:%v", err) + logs.CtxError(ctx, "OnTaskRunCreated err:%v", err) continue } } diff --git a/backend/modules/observability/domain/task/service/taskexe/tracehub/scheduled_task_test.go b/backend/modules/observability/domain/task/service/taskexe/tracehub/scheduled_task_test.go index 6e5a71f70..e8e344f82 100755 --- a/backend/modules/observability/domain/task/service/taskexe/tracehub/scheduled_task_test.go +++ b/backend/modules/observability/domain/task/service/taskexe/tracehub/scheduled_task_test.go @@ -23,8 +23,8 @@ import ( type trackingProcessor struct { *stubProcessor - finishReqs []taskexe.OnFinishTaskChangeReq - createRunReqs []taskexe.OnCreateTaskRunChangeReq + finishReqs []taskexe.OnTaskFinishedReq + createRunReqs []taskexe.OnTaskRunCreatedReq updateStatuses []entity.TaskStatus } @@ -32,19 +32,19 @@ func newTrackingProcessor() *trackingProcessor { return &trackingProcessor{stubProcessor: &stubProcessor{}} } -func (p *trackingProcessor) OnFinishTaskChange(ctx context.Context, req taskexe.OnFinishTaskChangeReq) error { +func (p *trackingProcessor) OnFinishTaskChange(ctx context.Context, req taskexe.OnTaskFinishedReq) error { p.finishReqs = append(p.finishReqs, req) return p.stubProcessor.OnFinishTaskChange(ctx, req) } -func (p *trackingProcessor) OnCreateTaskRunChange(ctx context.Context, req taskexe.OnCreateTaskRunChangeReq) error { +func (p *trackingProcessor) OnCreateTaskRunChange(ctx context.Context, req taskexe.OnTaskRunCreatedReq) error { p.createRunReqs = append(p.createRunReqs, req) return p.stubProcessor.OnCreateTaskRunChange(ctx, req) } -func (p *trackingProcessor) OnUpdateTaskChange(ctx context.Context, obsTask *entity.ObservabilityTask, status entity.TaskStatus) error { +func (p *trackingProcessor) OnTaskUpdated(ctx context.Context, obsTask *entity.ObservabilityTask, status entity.TaskStatus) error { p.updateStatuses = append(p.updateStatuses, status) - return p.stubProcessor.OnUpdateTaskChange(ctx, obsTask, status) + return p.stubProcessor.OnTaskUpdated(ctx, obsTask, status) } func TestTraceHubServiceImpl_transformTaskStatus(t *testing.T) { diff --git a/backend/modules/observability/domain/task/service/taskexe/tracehub/span_trigger.go b/backend/modules/observability/domain/task/service/taskexe/tracehub/span_trigger.go index 22a9c9286..1422dce04 100644 --- a/backend/modules/observability/domain/task/service/taskexe/tracehub/span_trigger.go +++ b/backend/modules/observability/domain/task/service/taskexe/tracehub/span_trigger.go @@ -153,8 +153,8 @@ func (h *TraceHubServiceImpl) preDispatch(ctx context.Context, span *loop_span.S merr = multierror.Append(merr, errors.WithMessagef(err, "task is unstarted, need sub.Creative,creative processor, task_id=%d", sub.taskID)) continue } - if err := sub.processor.OnUpdateTaskChange(ctx, tconv.TaskDTO2DO(sub.t), entity.TaskStatusRunning); err != nil { - logs.CtxWarn(ctx, "OnUpdateTaskChange, task_id=%d, err=%v", sub.taskID, err) + if err := sub.processor.OnTaskUpdated(ctx, tconv.TaskDTO2DO(sub.t), entity.TaskStatusRunning); err != nil { + logs.CtxWarn(ctx, "OnTaskUpdated, task_id=%d, err=%v", sub.taskID, err) continue } } @@ -192,8 +192,8 @@ func (h *TraceHubServiceImpl) preDispatch(ctx context.Context, span *loop_span.S endTime := time.UnixMilli(sub.t.GetRule().GetEffectiveTime().GetEndAt()) // Reached task time limit if time.Now().After(endTime) { - logs.CtxWarn(ctx, "[OnFinishTaskChange]time.Now().After(endTime) Finish processor, task_id=%d, endTime=%v, now=%v", sub.taskID, endTime, time.Now()) - if err := sub.processor.OnFinishTaskChange(ctx, taskexe.OnFinishTaskChangeReq{ + logs.CtxWarn(ctx, "[OnTaskFinished]time.Now().After(endTime) Finish processor, task_id=%d, endTime=%v, now=%v", sub.taskID, endTime, time.Now()) + if err := sub.processor.OnTaskFinished(ctx, taskexe.OnTaskFinishedReq{ Task: tconv.TaskDTO2DO(sub.t), TaskRun: taskRunConfig, IsFinish: true, @@ -205,8 +205,8 @@ func (h *TraceHubServiceImpl) preDispatch(ctx context.Context, span *loop_span.S } // Reached task limit if taskCount+1 > sampler.GetSampleSize() { - logs.CtxWarn(ctx, "[OnFinishTaskChange]taskCount+1 > sampler.GetSampleSize() Finish processor, task_id=%d", sub.taskID) - if err := sub.processor.OnFinishTaskChange(ctx, taskexe.OnFinishTaskChangeReq{ + logs.CtxWarn(ctx, "[OnTaskFinished]taskCount+1 > sampler.GetSampleSize() Finish processor, task_id=%d", sub.taskID) + if err := sub.processor.OnTaskFinished(ctx, taskexe.OnTaskFinishedReq{ Task: tconv.TaskDTO2DO(sub.t), TaskRun: taskRunConfig, IsFinish: true, @@ -219,8 +219,8 @@ func (h *TraceHubServiceImpl) preDispatch(ctx context.Context, span *loop_span.S cycleEndTime := time.Unix(0, taskRunConfig.RunEndAt.UnixMilli()*1e6) // Reached single cycle task time limit if time.Now().After(cycleEndTime) { - logs.CtxInfo(ctx, "[OnFinishTaskChange]time.Now().After(cycleEndTime) Finish processor, task_id=%d", sub.taskID) - if err := sub.processor.OnFinishTaskChange(ctx, taskexe.OnFinishTaskChangeReq{ + logs.CtxInfo(ctx, "[OnTaskFinished]time.Now().After(cycleEndTime) Finish processor, task_id=%d", sub.taskID) + if err := sub.processor.OnTaskFinished(ctx, taskexe.OnTaskFinishedReq{ Task: tconv.TaskDTO2DO(sub.t), TaskRun: taskRunConfig, IsFinish: false, @@ -237,8 +237,8 @@ func (h *TraceHubServiceImpl) preDispatch(ctx context.Context, span *loop_span.S } // Reached single cycle task limit if taskRunCount+1 > sampler.GetCycleCount() { - logs.CtxWarn(ctx, "[OnFinishTaskChange]taskRunCount+1 > sampler.GetCycleCount(), task_id=%d", sub.taskID) - if err := sub.processor.OnFinishTaskChange(ctx, taskexe.OnFinishTaskChangeReq{ + logs.CtxWarn(ctx, "[OnTaskFinished]taskRunCount+1 > sampler.GetCycleCount(), task_id=%d", sub.taskID) + if err := sub.processor.OnTaskFinished(ctx, taskexe.OnTaskFinishedReq{ Task: tconv.TaskDTO2DO(sub.t), TaskRun: taskRunConfig, IsFinish: false, diff --git a/backend/modules/observability/domain/task/service/taskexe/tracehub/subscriber.go b/backend/modules/observability/domain/task/service/taskexe/tracehub/subscriber.go index 106c43ea8..e1008bd59 100644 --- a/backend/modules/observability/domain/task/service/taskexe/tracehub/subscriber.go +++ b/backend/modules/observability/domain/task/service/taskexe/tracehub/subscriber.go @@ -154,7 +154,7 @@ func buildBuiltinFilters(ctx context.Context, f span_filter.Filter, req *ListSpa } func (s *spanSubscriber) Creative(ctx context.Context, runStartAt, runEndAt int64) error { - err := s.processor.OnCreateTaskRunChange(ctx, taskexe.OnCreateTaskRunChangeReq{ + err := s.processor.OnTaskRunCreated(ctx, taskexe.OnTaskRunCreatedReq{ CurrentTask: tconv.TaskDTO2DO(s.t), RunType: s.runType, RunStartAt: runStartAt, diff --git a/backend/modules/observability/domain/task/service/taskexe/tracehub/test_helpers_test.go b/backend/modules/observability/domain/task/service/taskexe/tracehub/test_helpers_test.go index e2645ae57..5cdfefbef 100755 --- a/backend/modules/observability/domain/task/service/taskexe/tracehub/test_helpers_test.go +++ b/backend/modules/observability/domain/task/service/taskexe/tracehub/test_helpers_test.go @@ -29,8 +29,8 @@ type stubProcessor struct { createTaskRunErr error finishChangeInvoked int invokeCalled bool - createTaskRunReqs []taskexe.OnCreateTaskRunChangeReq - finishChangeReqs []taskexe.OnFinishTaskChangeReq + createTaskRunReqs []taskexe.OnTaskRunCreatedReq + finishChangeReqs []taskexe.OnTaskFinishedReq updateCallCount int createTaskRunErrSeq []error finishErrSeq []error @@ -45,16 +45,16 @@ func (s *stubProcessor) Invoke(context.Context, *taskexe.Trigger) error { return s.invokeErr } -func (s *stubProcessor) OnCreateTaskChange(context.Context, *entity.ObservabilityTask) error { +func (s *stubProcessor) OnTaskCreated(context.Context, *entity.ObservabilityTask) error { return s.createTaskErr } -func (s *stubProcessor) OnUpdateTaskChange(context.Context, *entity.ObservabilityTask, entity.TaskStatus) error { +func (s *stubProcessor) OnTaskUpdated(context.Context, *entity.ObservabilityTask, entity.TaskStatus) error { s.updateCallCount++ return s.updateErr } -func (s *stubProcessor) OnFinishTaskChange(_ context.Context, req taskexe.OnFinishTaskChangeReq) error { +func (s *stubProcessor) OnFinishTaskChange(_ context.Context, req taskexe.OnTaskFinishedReq) error { idx := len(s.finishChangeReqs) s.finishChangeReqs = append(s.finishChangeReqs, req) s.finishChangeInvoked++ @@ -64,7 +64,7 @@ func (s *stubProcessor) OnFinishTaskChange(_ context.Context, req taskexe.OnFini return s.finishErr } -func (s *stubProcessor) OnCreateTaskRunChange(_ context.Context, req taskexe.OnCreateTaskRunChangeReq) error { +func (s *stubProcessor) OnCreateTaskRunChange(_ context.Context, req taskexe.OnTaskRunCreatedReq) error { s.createTaskRunReqs = append(s.createTaskRunReqs, req) idx := len(s.createTaskRunReqs) - 1 if idx >= 0 && idx < len(s.createTaskRunErrSeq) { @@ -75,7 +75,7 @@ func (s *stubProcessor) OnCreateTaskRunChange(_ context.Context, req taskexe.OnC return s.createTaskRunErr } -func (s *stubProcessor) OnFinishTaskRunChange(context.Context, taskexe.OnFinishTaskRunChangeReq) error { +func (s *stubProcessor) OnFinishTaskRunChange(context.Context, taskexe.OnTaskRunFinishedReq) error { return s.finishTaskRunErr } diff --git a/backend/modules/observability/infra/repo/mysql/task.go b/backend/modules/observability/infra/repo/mysql/task.go index 9c478b5d6..efa8ce8b8 100644 --- a/backend/modules/observability/infra/repo/mysql/task.go +++ b/backend/modules/observability/infra/repo/mysql/task.go @@ -27,6 +27,9 @@ const ( DefaultLimit = 20 MaxLimit = 501 DefaultOffset = 0 + + MaxRetries = 3 + RetryDelay = 100 * time.Millisecond ) type ListTaskParam struct { @@ -372,7 +375,6 @@ func (d *TaskDaoImpl) order(q *genquery.Query, orderBy string, asc bool) field.E } func (v *TaskDaoImpl) UpdateTaskWithOCC(ctx context.Context, id int64, workspaceID int64, updateMap map[string]interface{}) error { - // todo[xun]: 乐观锁 logs.CtxInfo(ctx, "UpdateTaskWithOCC, id:%d, workspaceID:%d, updateMap:%+v", id, workspaceID, updateMap) q := genquery.Use(v.dbMgr.NewSession(ctx)).ObservabilityTask qd := q.WithContext(ctx) diff --git a/backend/modules/observability/infra/repo/mysql/task_run.go b/backend/modules/observability/infra/repo/mysql/task_run.go index 807c755d3..9cce07d99 100755 --- a/backend/modules/observability/infra/repo/mysql/task_run.go +++ b/backend/modules/observability/infra/repo/mysql/task_run.go @@ -9,7 +9,7 @@ import ( "time" "github.com/coze-dev/coze-loop/backend/infra/db" - "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/observability/domain/task" + "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/entity" tracecommon "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/entity/common" "github.com/coze-dev/coze-loop/backend/modules/observability/infra/repo/mysql/gorm_gen/model" genquery "github.com/coze-dev/coze-loop/backend/modules/observability/infra/repo/mysql/gorm_gen/query" @@ -30,7 +30,7 @@ const ( type ListTaskRunParam struct { WorkspaceID *int64 TaskID *int64 - TaskRunStatus *task.RunStatus + TaskRunStatus *entity.TaskRunStatus ReqLimit int32 ReqOffset int32 OrderBy *tracecommon.OrderBy @@ -57,20 +57,6 @@ type TaskRunDaoImpl struct { dbMgr db.Provider } -// TaskRun非终态状态定义 -var NonFinalTaskRunStatuses = []string{ - "pending", // 等待执行 - "running", // 执行中 - "paused", // 暂停 - "retrying", // 重试中 -} - -// 活跃状态定义(非终态状态的子集) -var ActiveTaskRunStatuses = []string{ - "running", // 执行中 - "retrying", // 重试中 -} - // 计算分页参数 func calculateTaskRunPagination(reqLimit, reqOffset int32) (int, int) { limit := DefaultTaskRunLimit @@ -88,7 +74,7 @@ func calculateTaskRunPagination(reqLimit, reqOffset int32) (int, int) { func (v *TaskRunDaoImpl) GetBackfillTaskRun(ctx context.Context, workspaceID *int64, taskID int64) (*model.ObservabilityTaskRun, error) { q := genquery.Use(v.dbMgr.NewSession(ctx)).ObservabilityTaskRun - qd := q.WithContext(ctx).Where(q.TaskType.Eq(task.TaskRunTypeBackFill)).Where(q.TaskID.Eq(taskID)) + qd := q.WithContext(ctx).Where(q.TaskType.Eq(string(entity.TaskRunTypeBackFill))).Where(q.TaskID.Eq(taskID)) if workspaceID != nil { qd = qd.Where(q.WorkspaceID.Eq(*workspaceID)) @@ -106,7 +92,7 @@ func (v *TaskRunDaoImpl) GetBackfillTaskRun(ctx context.Context, workspaceID *in func (v *TaskRunDaoImpl) GetLatestNewDataTaskRun(ctx context.Context, workspaceID *int64, taskID int64) (*model.ObservabilityTaskRun, error) { q := genquery.Use(v.dbMgr.NewSession(ctx)).ObservabilityTaskRun - qd := q.WithContext(ctx).Where(q.TaskType.Eq(task.TaskRunTypeNewData)).Where(q.TaskID.Eq(taskID)) + qd := q.WithContext(ctx).Where(q.TaskType.Eq(string(entity.TaskRunTypeNewData))).Where(q.TaskID.Eq(taskID)) if workspaceID != nil { qd = qd.Where(q.WorkspaceID.Eq(*workspaceID)) @@ -150,12 +136,15 @@ func (v *TaskRunDaoImpl) ListTaskRuns(ctx context.Context, param ListTaskRunPara var total int64 // TaskID过滤 - if param.TaskID != nil { - qd = qd.Where(q.ObservabilityTaskRun.TaskID.Eq(*param.TaskID)) + if param.TaskID == nil { + logs.CtxError(ctx, "TaskID is nil") + return nil, 0, errorx.NewByCode(obErrorx.CommonInvalidParamCode, errorx.WithExtraMsg("TaskID is nil")) } + qd = qd.Where(q.ObservabilityTaskRun.TaskID.Eq(*param.TaskID)) + // TaskRunStatus过滤 if param.TaskRunStatus != nil { - qd = qd.Where(q.ObservabilityTaskRun.RunStatus.Eq(*param.TaskRunStatus)) + qd = qd.Where(q.ObservabilityTaskRun.RunStatus.Eq(string(*param.TaskRunStatus))) } // workspaceID过滤 if param.WorkspaceID != nil { @@ -206,11 +195,6 @@ func (d *TaskRunDaoImpl) order(q *genquery.Query, orderBy string, asc bool) fiel return orderExpr.Desc() } -const ( - MaxRetries = 3 - RetryDelay = 100 * time.Millisecond -) - // UpdateTaskRunWithOCC 乐观并发控制更新 func (v *TaskRunDaoImpl) UpdateTaskRunWithOCC(ctx context.Context, id int64, workspaceID int64, updateMap map[string]interface{}) error { q := genquery.Use(v.dbMgr.NewSession(ctx)).ObservabilityTaskRun diff --git a/backend/modules/observability/infra/repo/redis/task.go b/backend/modules/observability/infra/repo/redis/task.go index d6f5a541a..a96ce325a 100755 --- a/backend/modules/observability/infra/repo/redis/task.go +++ b/backend/modules/observability/infra/repo/redis/task.go @@ -61,6 +61,7 @@ func (q *TaskDAOImpl) makeTaskCacheKey(taskID int64) string { return fmt.Sprintf(taskDetailCacheKeyPattern, taskID) } +// 为了兼容旧版,redis key必须保持一致,无法增加前缀 func (q *TaskDAOImpl) makeTaskCountCacheKey(taskID int64) string { return fmt.Sprintf("count_%d", taskID) } From d1079c28defc1bb5bbb56c6e2b74c76a2cc23e53 Mon Sep 17 00:00:00 2001 From: taoyifan89 Date: Wed, 5 Nov 2025 11:39:09 +0800 Subject: [PATCH 06/19] =?UTF-8?q?test:=20[Coda]=20=E8=B0=83=E6=95=B4traceh?= =?UTF-8?q?ub=E5=8D=95=E6=B5=8B=E9=80=82=E9=85=8DProcessor=E6=8E=A5?= =?UTF-8?q?=E5=8F=A3=20(LogID:=202025110511251301009111510418056FE)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-Authored-By: Coda --- .../service/mocks/task_callback_service.go | 70 +++++++++++++++++++ .../tracehub/callback.go => task_callback.go} | 11 +-- ...callback_test.go => task_callback_test.go} | 7 +- .../service/taskexe/tracehub/backfill_test.go | 2 +- .../taskexe/tracehub/scheduled_task_test.go | 31 ++++++-- .../taskexe/tracehub/span_trigger_test.go | 4 +- .../taskexe/tracehub/test_helpers_test.go | 6 +- .../mq/consumer/autotask_callback_consumer.go | 10 +-- 8 files changed, 116 insertions(+), 25 deletions(-) create mode 100644 backend/modules/observability/domain/task/service/mocks/task_callback_service.go rename backend/modules/observability/domain/task/service/{taskexe/tracehub/callback.go => task_callback.go} (92%) rename backend/modules/observability/domain/task/service/{taskexe/tracehub/callback_test.go => task_callback_test.go} (96%) diff --git a/backend/modules/observability/domain/task/service/mocks/task_callback_service.go b/backend/modules/observability/domain/task/service/mocks/task_callback_service.go new file mode 100644 index 000000000..0fd9ecb1e --- /dev/null +++ b/backend/modules/observability/domain/task/service/mocks/task_callback_service.go @@ -0,0 +1,70 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/service (interfaces: ITaskCallbackService) +// +// Generated by this command: +// +// mockgen -destination=mocks/task_callback_service.go -package=mocks . ITaskCallbackService +// + +// Package mocks is a generated GoMock package. +package mocks + +import ( + context "context" + reflect "reflect" + + entity "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/entity" + gomock "go.uber.org/mock/gomock" +) + +// MockITaskCallbackService is a mock of ITaskCallbackService interface. +type MockITaskCallbackService struct { + ctrl *gomock.Controller + recorder *MockITaskCallbackServiceMockRecorder + isgomock struct{} +} + +// MockITaskCallbackServiceMockRecorder is the mock recorder for MockITaskCallbackService. +type MockITaskCallbackServiceMockRecorder struct { + mock *MockITaskCallbackService +} + +// NewMockITaskCallbackService creates a new mock instance. +func NewMockITaskCallbackService(ctrl *gomock.Controller) *MockITaskCallbackService { + mock := &MockITaskCallbackService{ctrl: ctrl} + mock.recorder = &MockITaskCallbackServiceMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockITaskCallbackService) EXPECT() *MockITaskCallbackServiceMockRecorder { + return m.recorder +} + +// AutoEvalCallback mocks base method. +func (m *MockITaskCallbackService) AutoEvalCallback(ctx context.Context, event *entity.AutoEvalEvent) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "AutoEvalCallback", ctx, event) + ret0, _ := ret[0].(error) + return ret0 +} + +// AutoEvalCallback indicates an expected call of AutoEvalCallback. +func (mr *MockITaskCallbackServiceMockRecorder) AutoEvalCallback(ctx, event any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AutoEvalCallback", reflect.TypeOf((*MockITaskCallbackService)(nil).AutoEvalCallback), ctx, event) +} + +// AutoEvalCorrection mocks base method. +func (m *MockITaskCallbackService) AutoEvalCorrection(ctx context.Context, event *entity.CorrectionEvent) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "AutoEvalCorrection", ctx, event) + ret0, _ := ret[0].(error) + return ret0 +} + +// AutoEvalCorrection indicates an expected call of AutoEvalCorrection. +func (mr *MockITaskCallbackServiceMockRecorder) AutoEvalCorrection(ctx, event any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AutoEvalCorrection", reflect.TypeOf((*MockITaskCallbackService)(nil).AutoEvalCorrection), ctx, event) +} diff --git a/backend/modules/observability/domain/task/service/taskexe/tracehub/callback.go b/backend/modules/observability/domain/task/service/task_callback.go similarity index 92% rename from backend/modules/observability/domain/task/service/taskexe/tracehub/callback.go rename to backend/modules/observability/domain/task/service/task_callback.go index 70f454685..09f75e403 100644 --- a/backend/modules/observability/domain/task/service/taskexe/tracehub/callback.go +++ b/backend/modules/observability/domain/task/service/task_callback.go @@ -1,7 +1,7 @@ // Copyright (c) 2025 coze-dev Authors // SPDX-License-Identifier: Apache-2.0 -package tracehub +package service import ( "context" @@ -11,20 +11,21 @@ import ( "github.com/coze-dev/coze-loop/backend/infra/external/benefit" "github.com/coze-dev/coze-loop/backend/infra/middleware/session" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/entity" + "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/service/taskexe/tracehub" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/entity/loop_span" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/repo" "github.com/coze-dev/coze-loop/backend/pkg/logs" "github.com/samber/lo" ) -func (h *TraceHubServiceImpl) CallBack(ctx context.Context, event *entity.AutoEvalEvent) error { +func (h *tracehub.TraceHubServiceImpl) CallBack(ctx context.Context, event *entity.AutoEvalEvent) error { for _, turn := range event.TurnEvalResults { workspaceIDStr, workspaceID := turn.GetWorkspaceIDFromExt() tenants, err := h.getTenants(ctx, loop_span.PlatformType("callback_all")) if err != nil { return err } - var storageDuration int64 = 1 + storageDuration := h.config.GetTraceDataMaxDurationDay(ctx, loop_span.PlatformDefault) res, err := h.benefitSvc.CheckTraceBenefit(ctx, &benefit.CheckTraceBenefitParams{ ConnectorUID: turn.BaseInfo.CreatedBy.UserID, SpaceID: workspaceID, @@ -33,7 +34,7 @@ func (h *TraceHubServiceImpl) CallBack(ctx context.Context, event *entity.AutoEv logs.CtxWarn(ctx, "fail to check trace benefit, %v", err) } else if res == nil { logs.CtxWarn(ctx, "fail to get trace benefit, got nil response") - } else if res != nil { + } else { storageDuration = res.StorageDuration } @@ -99,7 +100,7 @@ func (h *TraceHubServiceImpl) CallBack(ctx context.Context, event *entity.AutoEv return nil } -func (h *TraceHubServiceImpl) Correction(ctx context.Context, event *entity.CorrectionEvent) error { +func (h *tracehub.TraceHubServiceImpl) Correction(ctx context.Context, event *entity.CorrectionEvent) error { workspaceIDStr, workspaceID := event.GetWorkspaceIDFromExt() if workspaceID == 0 { return fmt.Errorf("workspace_id is empty") diff --git a/backend/modules/observability/domain/task/service/taskexe/tracehub/callback_test.go b/backend/modules/observability/domain/task/service/task_callback_test.go similarity index 96% rename from backend/modules/observability/domain/task/service/taskexe/tracehub/callback_test.go rename to backend/modules/observability/domain/task/service/task_callback_test.go index bbea9d3e7..f31dd9806 100755 --- a/backend/modules/observability/domain/task/service/taskexe/tracehub/callback_test.go +++ b/backend/modules/observability/domain/task/service/task_callback_test.go @@ -1,7 +1,7 @@ // Copyright (c) 2025 coze-dev Authors // SPDX-License-Identifier: Apache-2.0 -package tracehub +package service import ( "context" @@ -9,6 +9,7 @@ import ( "testing" "time" + "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/service/taskexe/tracehub" "go.uber.org/mock/gomock" "github.com/coze-dev/coze-loop/backend/infra/external/benefit" @@ -33,7 +34,7 @@ func TestTraceHubServiceImpl_CallBackSuccess(t *testing.T) { mockTraceRepo := trace_repo_mocks.NewMockITraceRepo(ctrl) mockTaskRepo := repo_mocks.NewMockITaskRepo(ctrl) - impl := &TraceHubServiceImpl{ + impl := &tracehub.TraceHubServiceImpl{ benefitSvc: mockBenefit, tenantProvider: mockTenant, traceRepo: mockTraceRepo, @@ -98,7 +99,7 @@ func TestTraceHubServiceImpl_CallBackSpanNotFound(t *testing.T) { mockTenant := tenant_mocks.NewMockITenantProvider(ctrl) mockTraceRepo := trace_repo_mocks.NewMockITraceRepo(ctrl) - impl := &TraceHubServiceImpl{ + impl := &tracehub.TraceHubServiceImpl{ benefitSvc: mockBenefit, tenantProvider: mockTenant, traceRepo: mockTraceRepo, diff --git a/backend/modules/observability/domain/task/service/taskexe/tracehub/backfill_test.go b/backend/modules/observability/domain/task/service/taskexe/tracehub/backfill_test.go index 0b8276182..a7442b88f 100755 --- a/backend/modules/observability/domain/task/service/taskexe/tracehub/backfill_test.go +++ b/backend/modules/observability/domain/task/service/taskexe/tracehub/backfill_test.go @@ -43,7 +43,7 @@ func TestTraceHubServiceImpl_SetBackfillTask(t *testing.T) { mockRepo := repo_mocks.NewMockITaskRepo(ctrl) taskProcessor := processor.NewTaskProcessor() proc := &stubProcessor{} - taskProcessor.Register(task.TaskTypeAutoEval, proc) + taskProcessor.Register(entity.TaskTypeAutoEval, proc) impl := &TraceHubServiceImpl{ taskRepo: mockRepo, diff --git a/backend/modules/observability/domain/task/service/taskexe/tracehub/scheduled_task_test.go b/backend/modules/observability/domain/task/service/taskexe/tracehub/scheduled_task_test.go index e8e344f82..f881d0d32 100755 --- a/backend/modules/observability/domain/task/service/taskexe/tracehub/scheduled_task_test.go +++ b/backend/modules/observability/domain/task/service/taskexe/tracehub/scheduled_task_test.go @@ -32,14 +32,14 @@ func newTrackingProcessor() *trackingProcessor { return &trackingProcessor{stubProcessor: &stubProcessor{}} } -func (p *trackingProcessor) OnFinishTaskChange(ctx context.Context, req taskexe.OnTaskFinishedReq) error { +func (p *trackingProcessor) OnTaskFinished(ctx context.Context, req taskexe.OnTaskFinishedReq) error { p.finishReqs = append(p.finishReqs, req) - return p.stubProcessor.OnFinishTaskChange(ctx, req) + return p.stubProcessor.OnTaskFinished(ctx, req) } -func (p *trackingProcessor) OnCreateTaskRunChange(ctx context.Context, req taskexe.OnTaskRunCreatedReq) error { +func (p *trackingProcessor) OnTaskRunCreated(ctx context.Context, req taskexe.OnTaskRunCreatedReq) error { p.createRunReqs = append(p.createRunReqs, req) - return p.stubProcessor.OnCreateTaskRunChange(ctx, req) + return p.stubProcessor.OnTaskRunCreated(ctx, req) } func (p *trackingProcessor) OnTaskUpdated(ctx context.Context, obsTask *entity.ObservabilityTask, status entity.TaskStatus) error { @@ -142,7 +142,7 @@ func TestTraceHubServiceImpl_transformTaskStatus(t *testing.T) { require.Len(t, proc.createRunReqs, 1) require.Equal(t, entity.TaskRunTypeNewData, proc.createRunReqs[0].RunType) require.Len(t, proc.updateStatuses, 1) - require.Equal(t, string(entity.TaskStatusRunning), proc.updateStatuses[0]) + require.Equal(t, entity.TaskStatusRunning, proc.updateStatuses[0]) }, }, { @@ -336,9 +336,26 @@ func TestTraceHubServiceImpl_syncTaskCache(t *testing.T) { impl := &TraceHubServiceImpl{taskRepo: mockRepo} impl.taskCache.Store("ObjListWithTask", TaskCacheInfo{}) - workspaceIDs := []string{"space-1"} + tasks := []*entity.ObservabilityTask{ + { + ID: 100, + WorkspaceID: 1, + SpanFilter: &entity.SpanFilterFields{ + Filters: loop_span.FilterFields{ + FilterFields: []*loop_span.FilterField{ + { + FieldName: "bot_id", + Values: []string{"bot-1"}, + }, + }, + }, + }, + }, + } + workspaceIDs := []string{"1"} botIDs := []string{"bot-1"} - tasks := []*entity.ObservabilityTask{{ID: 100}} + + mockRepo.EXPECT().ListNonFinalTasks(gomock.Any()).Return(tasks, nil) before := time.Now() impl.syncTaskCache() diff --git a/backend/modules/observability/domain/task/service/taskexe/tracehub/span_trigger_test.go b/backend/modules/observability/domain/task/service/taskexe/tracehub/span_trigger_test.go index 6f6689b1d..ad93d2013 100755 --- a/backend/modules/observability/domain/task/service/taskexe/tracehub/span_trigger_test.go +++ b/backend/modules/observability/domain/task/service/taskexe/tracehub/span_trigger_test.go @@ -96,6 +96,8 @@ func TestTraceHubServiceImpl_SpanTriggerDispatchError(t *testing.T) { }, } + mockRepo.EXPECT().ListNonFinalTaskBySpaceID(gomock.Any(), gomock.Any()).Return([]int64{taskDO.ID}, nil).AnyTimes() + configLoader.EXPECT().UnmarshalKey(gomock.Any(), "consumer_listening", gomock.Any()).DoAndReturn( func(_ context.Context, _ string, value any, _ ...pkgconf.DecodeOptionFn) error { cfg := value.(*componentconfig.ConsumerListening) @@ -123,7 +125,7 @@ func TestTraceHubServiceImpl_SpanTriggerDispatchError(t *testing.T) { proc := &stubProcessor{invokeErr: errors.New("invoke error"), createTaskRunErr: errors.New("create run error")} taskProcessor := processor.NewTaskProcessor() - taskProcessor.Register(task.TaskTypeAutoEval, proc) + taskProcessor.Register(entity.TaskTypeAutoEval, proc) impl := &TraceHubServiceImpl{ taskRepo: mockRepo, diff --git a/backend/modules/observability/domain/task/service/taskexe/tracehub/test_helpers_test.go b/backend/modules/observability/domain/task/service/taskexe/tracehub/test_helpers_test.go index 5cdfefbef..5664d53bc 100755 --- a/backend/modules/observability/domain/task/service/taskexe/tracehub/test_helpers_test.go +++ b/backend/modules/observability/domain/task/service/taskexe/tracehub/test_helpers_test.go @@ -54,7 +54,7 @@ func (s *stubProcessor) OnTaskUpdated(context.Context, *entity.ObservabilityTask return s.updateErr } -func (s *stubProcessor) OnFinishTaskChange(_ context.Context, req taskexe.OnTaskFinishedReq) error { +func (s *stubProcessor) OnTaskFinished(_ context.Context, req taskexe.OnTaskFinishedReq) error { idx := len(s.finishChangeReqs) s.finishChangeReqs = append(s.finishChangeReqs, req) s.finishChangeInvoked++ @@ -64,7 +64,7 @@ func (s *stubProcessor) OnFinishTaskChange(_ context.Context, req taskexe.OnTask return s.finishErr } -func (s *stubProcessor) OnCreateTaskRunChange(_ context.Context, req taskexe.OnTaskRunCreatedReq) error { +func (s *stubProcessor) OnTaskRunCreated(_ context.Context, req taskexe.OnTaskRunCreatedReq) error { s.createTaskRunReqs = append(s.createTaskRunReqs, req) idx := len(s.createTaskRunReqs) - 1 if idx >= 0 && idx < len(s.createTaskRunErrSeq) { @@ -75,7 +75,7 @@ func (s *stubProcessor) OnCreateTaskRunChange(_ context.Context, req taskexe.OnT return s.createTaskRunErr } -func (s *stubProcessor) OnFinishTaskRunChange(context.Context, taskexe.OnTaskRunFinishedReq) error { +func (s *stubProcessor) OnTaskRunFinished(context.Context, taskexe.OnTaskRunFinishedReq) error { return s.finishTaskRunErr } diff --git a/backend/modules/observability/infra/mq/consumer/autotask_callback_consumer.go b/backend/modules/observability/infra/mq/consumer/autotask_callback_consumer.go index 5ae8ceee6..a28694e75 100644 --- a/backend/modules/observability/infra/mq/consumer/autotask_callback_consumer.go +++ b/backend/modules/observability/infra/mq/consumer/autotask_callback_consumer.go @@ -17,19 +17,19 @@ import ( "github.com/coze-dev/coze-loop/backend/pkg/logs" ) -type CallbackConsumer struct { +type AutoEvalCallbackConsumer struct { handler obapp.ITaskQueueConsumer conf.IConfigLoader } func newCallbackConsumer(handler obapp.ITaskQueueConsumer, loader conf.IConfigLoader) mq.IConsumerWorker { - return &CallbackConsumer{ + return &AutoEvalCallbackConsumer{ handler: handler, IConfigLoader: loader, } } -func (e *CallbackConsumer) ConsumerCfg(ctx context.Context) (*mq.ConsumerConfig, error) { +func (e *AutoEvalCallbackConsumer) ConsumerCfg(ctx context.Context) (*mq.ConsumerConfig, error) { const key = "autotask_callback_mq_consumer_config" cfg := &config.MqConsumerCfg{} if err := e.UnmarshalKey(ctx, key, cfg); err != nil { @@ -46,7 +46,7 @@ func (e *CallbackConsumer) ConsumerCfg(ctx context.Context) (*mq.ConsumerConfig, return res, nil } -func (e *CallbackConsumer) HandleMessage(ctx context.Context, ext *mq.MessageExt) error { +func (e *AutoEvalCallbackConsumer) HandleMessage(ctx context.Context, ext *mq.MessageExt) error { logID := logs.NewLogID() ctx = logs.SetLogID(ctx, logID) event := new(entity.AutoEvalEvent) @@ -55,5 +55,5 @@ func (e *CallbackConsumer) HandleMessage(ctx context.Context, ext *mq.MessageExt return nil } logs.CtxInfo(ctx, "Callback msg, event: %v,msgID: %s", event, ext.MsgID) - return e.handler.CallBack(ctx, event) + return e.handler.AutoEvalCallback(ctx, event) } From 333923dab05e0f521ff83932498529f5cf59aa83 Mon Sep 17 00:00:00 2001 From: taoyifan89 Date: Wed, 5 Nov 2025 15:14:40 +0800 Subject: [PATCH 07/19] Refactor backfill. Change-Id: Ic3b28c4a0e28bc40884b7232cfc4bf77109ac601 --- .../application/convertor/task/task.go | 4 +- .../application/convertor/task/task_test.go | 2 +- .../modules/observability/application/task.go | 60 +++-- .../observability/application/task_test.go | 58 +++-- .../modules/observability/application/wire.go | 1 + .../observability/application/wire_gen.go | 31 +-- .../observability/domain/task/entity/event.go | 60 ++++- .../observability/domain/task/entity/task.go | 32 ++- .../domain/task/service/task_callback.go | 198 +++++++++++----- .../domain/task/service/task_callback_test.go | 191 ++++++++++++++- .../taskexe/processor/auto_evaluate.go | 4 +- .../taskexe/processor/auto_evaluate_test.go | 2 +- .../task/service/taskexe/tracehub/backfill.go | 224 ++++++++---------- .../service/taskexe/tracehub/backfill_test.go | 4 +- .../tracehub/mocks/trace_hub_service.go | 45 +--- .../service/taskexe/tracehub/span_trigger.go | 94 ++++---- .../service/taskexe/tracehub/subscriber.go | 53 ++--- .../service/taskexe/tracehub/trace_hub.go | 11 - .../taskexe/tracehub/trace_hub_test.go | 184 +------------- .../task/service/taskexe/tracehub/utils.go | 81 +------ .../trace/entity/loop_span/annotation.go | 10 + .../domain/trace/entity/loop_span/span.go | 29 +++ .../mq/consumer/autotask_callback_consumer.go | 8 +- .../infra/mq/consumer/consumer.go | 27 +-- .../infra/mq/consumer/correction_consumer.go | 6 +- 25 files changed, 725 insertions(+), 694 deletions(-) diff --git a/backend/modules/observability/application/convertor/task/task.go b/backend/modules/observability/application/convertor/task/task.go index c7c7555fe..38831ee19 100644 --- a/backend/modules/observability/application/convertor/task/task.go +++ b/backend/modules/observability/application/convertor/task/task.go @@ -204,7 +204,7 @@ func SamplerDO2DTO(sampler *entity.Sampler) *task.Sampler { IsCycle: ptr.Of(sampler.IsCycle), CycleCount: ptr.Of(sampler.CycleCount), CycleInterval: ptr.Of(sampler.CycleInterval), - CycleTimeUnit: ptr.Of(sampler.CycleTimeUnit), + CycleTimeUnit: ptr.Of(string(sampler.CycleTimeUnit)), } } @@ -381,7 +381,7 @@ func SamplerDTO2DO(sampler *task.Sampler) *entity.Sampler { IsCycle: sampler.GetIsCycle(), CycleCount: sampler.GetCycleCount(), CycleInterval: sampler.GetCycleInterval(), - CycleTimeUnit: sampler.GetCycleTimeUnit(), + CycleTimeUnit: entity.TimeUnit(sampler.GetCycleTimeUnit()), } } diff --git a/backend/modules/observability/application/convertor/task/task_test.go b/backend/modules/observability/application/convertor/task/task_test.go index 14a0cf01d..96d3d529d 100755 --- a/backend/modules/observability/application/convertor/task/task_test.go +++ b/backend/modules/observability/application/convertor/task/task_test.go @@ -100,7 +100,7 @@ func TestTaskDOs2DTOs(t *testing.T) { IsCycle: true, CycleCount: 2, CycleInterval: 3, - CycleTimeUnit: kitTask.TimeUnitDay, + CycleTimeUnit: entity.TimeUnitDay, }, TaskConfig: &entity.TaskConfig{}, CreatedAt: now, diff --git a/backend/modules/observability/application/task.go b/backend/modules/observability/application/task.go index cfdafa2b2..4dc9ca8f2 100644 --- a/backend/modules/observability/application/task.go +++ b/backend/modules/observability/application/task.go @@ -25,8 +25,8 @@ import ( type ITaskQueueConsumer interface { SpanTrigger(ctx context.Context, event *entity.RawSpan) error - CallBack(ctx context.Context, event *entity.AutoEvalEvent) error - Correction(ctx context.Context, event *entity.CorrectionEvent) error + AutoEvalCallback(ctx context.Context, event *entity.AutoEvalEvent) error + AutoEvalCorrection(ctx context.Context, event *entity.CorrectionEvent) error BackFill(ctx context.Context, event *entity.BackFillEvent) error } @@ -43,26 +43,29 @@ func NewTaskApplication( userService rpc.IUserProvider, tracehubSvc tracehub.ITraceHubService, taskProcessor processor.TaskProcessor, + taskCallbackService service.ITaskCallbackService, ) (ITaskApplication, error) { return &TaskApplication{ - taskSvc: taskService, - authSvc: authService, - evalSvc: evalService, - evaluationSvc: evaluationService, - userSvc: userService, - tracehubSvc: tracehubSvc, - taskProcessor: taskProcessor, + taskSvc: taskService, + authSvc: authService, + evalSvc: evalService, + evaluationSvc: evaluationService, + userSvc: userService, + tracehubSvc: tracehubSvc, + taskProcessor: taskProcessor, + taskCallbackSvc: taskCallbackService, }, nil } type TaskApplication struct { - taskSvc service.ITaskService - authSvc rpc.IAuthProvider - evalSvc rpc.IEvaluatorRPCAdapter - evaluationSvc rpc.IEvaluationRPCAdapter - userSvc rpc.IUserProvider - tracehubSvc tracehub.ITraceHubService - taskProcessor processor.TaskProcessor + taskSvc service.ITaskService + authSvc rpc.IAuthProvider + evalSvc rpc.IEvaluatorRPCAdapter + evaluationSvc rpc.IEvaluationRPCAdapter + userSvc rpc.IUserProvider + tracehubSvc tracehub.ITraceHubService + taskProcessor processor.TaskProcessor + taskCallbackSvc service.ITaskCallbackService } func (t *TaskApplication) CheckTaskName(ctx context.Context, req *task.CheckTaskNameRequest) (*task.CheckTaskNameResponse, error) { @@ -261,14 +264,31 @@ func (t *TaskApplication) SpanTrigger(ctx context.Context, event *entity.RawSpan return t.tracehubSvc.SpanTrigger(ctx, event) } -func (t *TaskApplication) CallBack(ctx context.Context, event *entity.AutoEvalEvent) error { - return t.tracehubSvc.CallBack(ctx, event) +func (t *TaskApplication) AutoEvalCallback(ctx context.Context, event *entity.AutoEvalEvent) error { + if err := event.Validate(); err != nil { + logs.CtxError(ctx, "event is invalid, event: %#v, err: %v", event, err) + // 结构校验失败,不处理 + return nil + } + + return t.taskCallbackSvc.AutoEvalCallback(ctx, event) } -func (t *TaskApplication) Correction(ctx context.Context, event *entity.CorrectionEvent) error { - return t.tracehubSvc.Correction(ctx, event) +func (t *TaskApplication) AutoEvalCorrection(ctx context.Context, event *entity.CorrectionEvent) error { + if err := event.Validate(); err != nil { + logs.CtxError(ctx, "event is invalid, event: %#v, err: %v", event, err) + // 结构校验失败,不处理 + return nil + } + + return t.taskCallbackSvc.AutoEvalCorrection(ctx, event) } func (t *TaskApplication) BackFill(ctx context.Context, event *entity.BackFillEvent) error { + if err := event.Validate(); err != nil { + logs.CtxError(ctx, "event is invalid, event: %#v, err: %v", event, err) + // 结构校验失败,不处理 + return nil + } return t.tracehubSvc.BackFill(ctx, event) } diff --git a/backend/modules/observability/application/task_test.go b/backend/modules/observability/application/task_test.go index 357db081e..46a1b25b2 100755 --- a/backend/modules/observability/application/task_test.go +++ b/backend/modules/observability/application/task_test.go @@ -11,6 +11,9 @@ import ( "time" "github.com/bytedance/gg/gptr" + tconv "github.com/coze-dev/coze-loop/backend/modules/observability/application/convertor/task" + "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/entity/common" + "github.com/samber/lo" "github.com/stretchr/testify/assert" "go.uber.org/mock/gomock" @@ -424,8 +427,8 @@ func TestTaskApplication_ListTasks(t *testing.T) { t.Parallel() taskListResp := &svc.ListTasksResp{ - Tasks: []*taskdto.Task{{Name: "task1"}}, - Total: gptr.Of(int64(1)), + Tasks: []*entity.ObservabilityTask{{Name: "task1"}}, + Total: int64(1), } tests := []struct { name string @@ -485,10 +488,13 @@ func TestTaskApplication_ListTasks(t *testing.T) { }, }, { - name: "success", - ctx: context.Background(), - req: &taskapi.ListTasksRequest{WorkspaceID: 789}, - expectResp: &taskapi.ListTasksResponse{Tasks: taskListResp.Tasks, Total: taskListResp.Total}, + name: "success", + ctx: context.Background(), + req: &taskapi.ListTasksRequest{WorkspaceID: 789}, + expectResp: &taskapi.ListTasksResponse{ + Tasks: tconv.TaskDOs2DTOs(context.Background(), taskListResp.Tasks, map[string]*common.UserInfo{}), + Total: lo.ToPtr(taskListResp.Total), + }, fieldsBuilder: func(ctrl *gomock.Controller) (svc.ITaskService, rpc.IAuthProvider) { auth := rpcmock.NewMockIAuthProvider(ctrl) auth.EXPECT().CheckWorkspacePermission(gomock.Any(), rpc.AuthActionTraceTaskList, strconv.FormatInt(789, 10), false).Return(nil) @@ -536,7 +542,7 @@ func TestTaskApplication_ListTasks(t *testing.T) { func TestTaskApplication_GetTask(t *testing.T) { t.Parallel() - taskResp := &svc.GetTaskResp{Task: &taskdto.Task{Name: "task"}} + taskResp := &svc.GetTaskResp{Task: &entity.ObservabilityTask{Name: "task"}} tests := []struct { name string @@ -597,7 +603,7 @@ func TestTaskApplication_GetTask(t *testing.T) { name: "success", ctx: context.Background(), req: &taskapi.GetTaskRequest{WorkspaceID: 202, TaskID: 3}, - expectResp: &taskapi.GetTaskResponse{Task: taskResp.Task}, + expectResp: &taskapi.GetTaskResponse{Task: tconv.TaskDO2DTO(context.Background(), taskResp.Task, map[string]*common.UserInfo{})}, fieldsBuilder: func(ctrl *gomock.Controller) (svc.ITaskService, rpc.IAuthProvider) { auth := rpcmock.NewMockIAuthProvider(ctrl) auth.EXPECT().CheckWorkspacePermission(gomock.Any(), rpc.AuthActionTraceTaskList, strconv.FormatInt(202, 10), false).Return(nil) @@ -694,23 +700,23 @@ func TestTaskApplication_CallBack(t *testing.T) { event := &entity.AutoEvalEvent{} tests := []struct { name string - mockSvc func(ctrl *gomock.Controller) *tracehubmock.MockITraceHubService + mockSvc func(ctrl *gomock.Controller) *svcmock.MockITaskCallbackService expectErr bool }{ { name: "trace hub error", - mockSvc: func(ctrl *gomock.Controller) *tracehubmock.MockITraceHubService { - svc := tracehubmock.NewMockITraceHubService(ctrl) - svc.EXPECT().CallBack(gomock.Any(), event).Return(errors.New("hub error")) + mockSvc: func(ctrl *gomock.Controller) *svcmock.MockITaskCallbackService { + svc := svcmock.NewMockITaskCallbackService(ctrl) + svc.EXPECT().AutoEvalCallback(gomock.Any(), event).Return(errors.New("hub error")) return svc }, expectErr: true, }, { name: "success", - mockSvc: func(ctrl *gomock.Controller) *tracehubmock.MockITraceHubService { - svc := tracehubmock.NewMockITraceHubService(ctrl) - svc.EXPECT().CallBack(gomock.Any(), event).Return(nil) + mockSvc: func(ctrl *gomock.Controller) *svcmock.MockITaskCallbackService { + svc := svcmock.NewMockITaskCallbackService(ctrl) + svc.EXPECT().AutoEvalCallback(gomock.Any(), event).Return(nil) return svc }, }, @@ -724,8 +730,8 @@ func TestTaskApplication_CallBack(t *testing.T) { defer ctrl.Finish() traceSvc := caseItem.mockSvc(ctrl) - app := &TaskApplication{tracehubSvc: traceSvc} - err := app.CallBack(context.Background(), event) + app := &TaskApplication{taskCallbackSvc: traceSvc} + err := app.AutoEvalCallback(context.Background(), event) if caseItem.expectErr { assert.Error(t, err) } else { @@ -741,23 +747,23 @@ func TestTaskApplication_Correction(t *testing.T) { event := &entity.CorrectionEvent{} tests := []struct { name string - mockSvc func(ctrl *gomock.Controller) *tracehubmock.MockITraceHubService + mockSvc func(ctrl *gomock.Controller) *svcmock.MockITaskCallbackService expectErr bool }{ { name: "trace hub error", - mockSvc: func(ctrl *gomock.Controller) *tracehubmock.MockITraceHubService { - svc := tracehubmock.NewMockITraceHubService(ctrl) - svc.EXPECT().Correction(gomock.Any(), event).Return(errors.New("hub error")) + mockSvc: func(ctrl *gomock.Controller) *svcmock.MockITaskCallbackService { + svc := svcmock.NewMockITaskCallbackService(ctrl) + svc.EXPECT().AutoEvalCorrection(gomock.Any(), event).Return(errors.New("hub error")) return svc }, expectErr: true, }, { name: "success", - mockSvc: func(ctrl *gomock.Controller) *tracehubmock.MockITraceHubService { - svc := tracehubmock.NewMockITraceHubService(ctrl) - svc.EXPECT().Correction(gomock.Any(), event).Return(nil) + mockSvc: func(ctrl *gomock.Controller) *svcmock.MockITaskCallbackService { + svc := svcmock.NewMockITaskCallbackService(ctrl) + svc.EXPECT().AutoEvalCorrection(gomock.Any(), event).Return(nil) return svc }, }, @@ -771,8 +777,8 @@ func TestTaskApplication_Correction(t *testing.T) { defer ctrl.Finish() traceSvc := caseItem.mockSvc(ctrl) - app := &TaskApplication{tracehubSvc: traceSvc} - err := app.Correction(context.Background(), event) + app := &TaskApplication{taskCallbackSvc: traceSvc} + err := app.AutoEvalCorrection(context.Background(), event) if caseItem.expectErr { assert.Error(t, err) } else { diff --git a/backend/modules/observability/application/wire.go b/backend/modules/observability/application/wire.go index bac80b142..ec9b6e356 100644 --- a/backend/modules/observability/application/wire.go +++ b/backend/modules/observability/application/wire.go @@ -134,6 +134,7 @@ var ( evaluation.NewEvaluationRPCProvider, NewTaskLocker, traceDomainSet, + taskSvc.NewTaskCallbackServiceImpl, ) metricsSet = wire.NewSet( NewMetricApplication, diff --git a/backend/modules/observability/application/wire_gen.go b/backend/modules/observability/application/wire_gen.go index 0308c6142..3a52d6ce7 100644 --- a/backend/modules/observability/application/wire_gen.go +++ b/backend/modules/observability/application/wire_gen.go @@ -55,7 +55,7 @@ import ( "github.com/coze-dev/coze-loop/backend/modules/observability/infra/repo" ck2 "github.com/coze-dev/coze-loop/backend/modules/observability/infra/repo/ck" "github.com/coze-dev/coze-loop/backend/modules/observability/infra/repo/mysql" - redis3 "github.com/coze-dev/coze-loop/backend/modules/observability/infra/repo/redis" + redis2 "github.com/coze-dev/coze-loop/backend/modules/observability/infra/repo/redis" "github.com/coze-dev/coze-loop/backend/modules/observability/infra/rpc/auth" "github.com/coze-dev/coze-loop/backend/modules/observability/infra/rpc/dataset" "github.com/coze-dev/coze-loop/backend/modules/observability/infra/rpc/evaluation" @@ -72,7 +72,7 @@ import ( // Injectors from wire.go: -func InitTraceApplication(db2 db.Provider, ckDb ck.Provider, redis2 redis.Cmdable, meter metrics.Meter, mqFactory mq.IFactory, configFactory conf.IConfigLoaderFactory, idgen2 idgen.IIDGenerator, fileClient fileservice.Client, benefit2 benefit.IBenefitService, authClient authservice.Client, userClient userservice.Client, evalService evaluatorservice.Client, evalSetService evaluationsetservice.Client, tagService tagservice.Client, datasetService datasetservice.Client) (ITraceApplication, error) { +func InitTraceApplication(db2 db.Provider, ckDb ck.Provider, redis3 redis.Cmdable, meter metrics.Meter, mqFactory mq.IFactory, configFactory conf.IConfigLoaderFactory, idgen2 idgen.IIDGenerator, fileClient fileservice.Client, benefit2 benefit.IBenefitService, authClient authservice.Client, userClient userservice.Client, evalService evaluatorservice.Client, evalSetService evaluationsetservice.Client, tagService tagservice.Client, datasetService datasetservice.Client) (ITraceApplication, error) { iSpansDao, err := ck2.NewSpansCkDaoImpl(ckDb) if err != nil { return nil, err @@ -104,9 +104,9 @@ func InitTraceApplication(db2 db.Provider, ckDb ck.Provider, redis2 redis.Cmdabl iTenantProvider := tenant.NewTenantProvider(iTraceConfig) iEvaluatorRPCAdapter := evaluator.NewEvaluatorRPCProvider(evalService) iTaskDao := mysql.NewTaskDaoImpl(db2) - iTaskDAO := redis3.NewTaskDAO(redis2) + iTaskDAO := redis2.NewTaskDAO(redis3) iTaskRunDao := mysql.NewTaskRunDaoImpl(db2) - iTaskRunDAO := redis3.NewTaskRunDAO(redis2) + iTaskRunDAO := redis2.NewTaskRunDAO(redis3) iTaskRepo := repo.NewTaskRepoImpl(iTaskDao, idgen2, iTaskDAO, iTaskRunDao, iTaskRunDAO) iTraceService, err := service.NewTraceServiceImpl(iTraceRepo, iTraceConfig, iTraceProducer, iAnnotationProducer, iTraceMetrics, traceFilterProcessorBuilder, iTenantProvider, iEvaluatorRPCAdapter, iTaskRepo) if err != nil { @@ -129,7 +129,7 @@ func InitTraceApplication(db2 db.Provider, ckDb ck.Provider, redis2 redis.Cmdabl return iTraceApplication, nil } -func InitOpenAPIApplication(mqFactory mq.IFactory, configFactory conf.IConfigLoaderFactory, fileClient fileservice.Client, ckDb ck.Provider, benefit2 benefit.IBenefitService, limiterFactory limiter.IRateLimiterFactory, authClient authservice.Client, meter metrics.Meter, db2 db.Provider, redis2 redis.Cmdable, idgen2 idgen.IIDGenerator, evalService evaluatorservice.Client) (IObservabilityOpenAPIApplication, error) { +func InitOpenAPIApplication(mqFactory mq.IFactory, configFactory conf.IConfigLoaderFactory, fileClient fileservice.Client, ckDb ck.Provider, benefit2 benefit.IBenefitService, limiterFactory limiter.IRateLimiterFactory, authClient authservice.Client, meter metrics.Meter, db2 db.Provider, redis3 redis.Cmdable, idgen2 idgen.IIDGenerator, evalService evaluatorservice.Client) (IObservabilityOpenAPIApplication, error) { iSpansDao, err := ck2.NewSpansCkDaoImpl(ckDb) if err != nil { return nil, err @@ -161,9 +161,9 @@ func InitOpenAPIApplication(mqFactory mq.IFactory, configFactory conf.IConfigLoa iTenantProvider := tenant.NewTenantProvider(iTraceConfig) iEvaluatorRPCAdapter := evaluator.NewEvaluatorRPCProvider(evalService) iTaskDao := mysql.NewTaskDaoImpl(db2) - iTaskDAO := redis3.NewTaskDAO(redis2) + iTaskDAO := redis2.NewTaskDAO(redis3) iTaskRunDao := mysql.NewTaskRunDaoImpl(db2) - iTaskRunDAO := redis3.NewTaskRunDAO(redis2) + iTaskRunDAO := redis2.NewTaskRunDAO(redis3) iTaskRepo := repo.NewTaskRepoImpl(iTaskDao, idgen2, iTaskDAO, iTaskRunDao, iTaskRunDAO) iTraceService, err := service.NewTraceServiceImpl(iTraceRepo, iTraceConfig, iTraceProducer, iAnnotationProducer, iTraceMetrics, traceFilterProcessorBuilder, iTenantProvider, iEvaluatorRPCAdapter, iTaskRepo) if err != nil { @@ -240,11 +240,11 @@ func InitTraceIngestionApplication(configFactory conf.IConfigLoaderFactory, ckDb return iTraceIngestionApplication, nil } -func InitTaskApplication(db2 db.Provider, idgen2 idgen.IIDGenerator, configFactory conf.IConfigLoaderFactory, benefit2 benefit.IBenefitService, ckDb ck.Provider, redis2 redis.Cmdable, mqFactory mq.IFactory, userClient userservice.Client, authClient authservice.Client, evalService evaluatorservice.Client, evalSetService evaluationsetservice.Client, exptService experimentservice.Client, datasetService datasetservice.Client, fileClient fileservice.Client, taskProcessor processor.TaskProcessor, aid int32) (ITaskApplication, error) { +func InitTaskApplication(db2 db.Provider, idgen2 idgen.IIDGenerator, configFactory conf.IConfigLoaderFactory, benefit2 benefit.IBenefitService, ckDb ck.Provider, redis3 redis.Cmdable, mqFactory mq.IFactory, userClient userservice.Client, authClient authservice.Client, evalService evaluatorservice.Client, evalSetService evaluationsetservice.Client, exptService experimentservice.Client, datasetService datasetservice.Client, fileClient fileservice.Client, taskProcessor processor.TaskProcessor, aid int32) (ITaskApplication, error) { iTaskDao := mysql.NewTaskDaoImpl(db2) - iTaskDAO := redis3.NewTaskDAO(redis2) + iTaskDAO := redis2.NewTaskDAO(redis3) iTaskRunDao := mysql.NewTaskRunDaoImpl(db2) - iTaskRunDAO := redis3.NewTaskRunDAO(redis2) + iTaskRunDAO := redis2.NewTaskRunDAO(redis3) iTaskRepo := repo.NewTaskRepoImpl(iTaskDao, idgen2, iTaskDAO, iTaskRunDao, iTaskRunDAO) iConfigLoader, err := NewTraceConfigLoader(configFactory) if err != nil { @@ -280,12 +280,13 @@ func InitTaskApplication(db2 db.Provider, idgen2 idgen.IIDGenerator, configFacto return nil, err } iTenantProvider := tenant.NewTenantProvider(iTraceConfig) - iLocker := NewTaskLocker(redis2) - iTraceHubService, err := tracehub.NewTraceHubImpl(iTaskRepo, iTraceRepo, iTenantProvider, traceFilterProcessorBuilder, processorTaskProcessor, benefit2, aid, iBackfillProducer, iLocker, iConfigLoader) + iLocker := NewTaskLocker(redis3) + iTraceHubService, err := tracehub.NewTraceHubImpl(iTaskRepo, iTraceRepo, iTenantProvider, traceFilterProcessorBuilder, processorTaskProcessor, aid, iBackfillProducer, iLocker, iConfigLoader) if err != nil { return nil, err } - iTaskApplication, err := NewTaskApplication(iTaskService, iAuthProvider, iEvaluatorRPCAdapter, iEvaluationRPCAdapter, iUserProvider, iTraceHubService, taskProcessor) + iTaskCallbackService := service3.NewTaskCallbackServiceImpl(iTaskRepo, iTraceRepo, taskProcessor, iTenantProvider, iTraceConfig, benefit2) + iTaskApplication, err := NewTaskApplication(iTaskService, iAuthProvider, iEvaluatorRPCAdapter, iEvaluationRPCAdapter, iUserProvider, iTraceHubService, taskProcessor, iTaskCallbackService) if err != nil { return nil, err } @@ -296,7 +297,7 @@ func InitTaskApplication(db2 db.Provider, idgen2 idgen.IIDGenerator, configFacto var ( taskDomainSet = wire.NewSet( - NewInitTaskProcessor, service3.NewTaskServiceImpl, repo.NewTaskRepoImpl, mysql.NewTaskDaoImpl, redis3.NewTaskDAO, redis3.NewTaskRunDAO, mysql.NewTaskRunDaoImpl, producer.NewBackfillProducerImpl, + NewInitTaskProcessor, service3.NewTaskServiceImpl, repo.NewTaskRepoImpl, mysql.NewTaskDaoImpl, redis2.NewTaskDAO, redis2.NewTaskRunDAO, mysql.NewTaskRunDaoImpl, producer.NewBackfillProducerImpl, ) traceDomainSet = wire.NewSet(service.NewTraceServiceImpl, service.NewTraceExportServiceImpl, repo.NewTraceCKRepoImpl, ck2.NewSpansCkDaoImpl, ck2.NewAnnotationCkDaoImpl, metrics2.NewTraceMetricsImpl, collector.NewEventCollectorProvider, producer.NewTraceProducerImpl, producer.NewAnnotationProducerImpl, file.NewFileRPCProvider, NewTraceConfigLoader, NewTraceProcessorBuilder, config.NewTraceConfigCenter, tenant.NewTenantProvider, workspace.NewWorkspaceProvider, evaluator.NewEvaluatorRPCProvider, NewDatasetServiceAdapter, @@ -313,7 +314,7 @@ var ( NewOpenAPIApplication, auth.NewAuthProvider, traceDomainSet, ) taskSet = wire.NewSet(tracehub.NewTraceHubImpl, NewTaskApplication, auth.NewAuthProvider, user.NewUserRPCProvider, evaluation.NewEvaluationRPCProvider, NewTaskLocker, - traceDomainSet, + traceDomainSet, service3.NewTaskCallbackServiceImpl, ) metricsSet = wire.NewSet( NewMetricApplication, service2.NewMetricsService, repo.NewTraceMetricCKRepoImpl, tenant.NewTenantProvider, auth.NewAuthProvider, NewTraceConfigLoader, diff --git a/backend/modules/observability/domain/task/entity/event.go b/backend/modules/observability/domain/task/entity/event.go index c6bcd86e8..512ddefe2 100644 --- a/backend/modules/observability/domain/task/entity/event.go +++ b/backend/modules/observability/domain/task/entity/event.go @@ -4,9 +4,12 @@ package entity import ( + "fmt" "strconv" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/entity/loop_span" + obErrorx "github.com/coze-dev/coze-loop/backend/modules/observability/pkg/errno" + "github.com/coze-dev/coze-loop/backend/pkg/errorx" ) type RawSpan struct { @@ -174,6 +177,14 @@ type AutoEvalEvent struct { ExptID int64 `json:"expt_id"` TurnEvalResults []*OnlineExptTurnEvalResult `json:"turn_eval_results"` } + +func (e *AutoEvalEvent) Validate() error { + if e.TurnEvalResults == nil || len(e.TurnEvalResults) == 0 { + return fmt.Errorf("turn_eval_results is required") + } + return nil +} + type OnlineExptTurnEvalResult struct { EvaluatorVersionID int64 `json:"evaluator_version_id"` EvaluatorRecordID int64 `json:"evaluator_record_id"` @@ -251,6 +262,22 @@ func (s *OnlineExptTurnEvalResult) GetWorkspaceIDFromExt() (string, int64) { return workspaceIDStr, workspaceID } +func (s *OnlineExptTurnEvalResult) GetRunID() (int64, error) { + taskRunIDStr := s.Ext["run_id"] + if taskRunIDStr == "" { + return 0, fmt.Errorf("run_id not found in ext") + } + + return strconv.ParseInt(taskRunIDStr, 10, 64) +} + +func (s *OnlineExptTurnEvalResult) GetUserID() string { + if s.BaseInfo == nil || s.BaseInfo.UpdatedBy == nil { + return "" + } + return s.BaseInfo.UpdatedBy.UserID +} + type EvaluatorRunError struct { Code int32 `json:"code"` Message string `json:"message"` @@ -277,9 +304,14 @@ type CorrectionEvent struct { UpdatedAt int64 `json:"updated_at"` } -type BackFillEvent struct { - SpaceID int64 `json:"space_id"` - TaskID int64 `json:"task_id"` +func (c *CorrectionEvent) Validate() error { + if c.EvaluatorRecordID == 0 { + return errorx.NewByCode(obErrorx.CommercialCommonInvalidParamCodeCode, errorx.WithExtraMsg("evaluator_record_id is empty")) + } + if c.EvaluatorResult == nil || c.EvaluatorResult.Correction == nil { + return errorx.NewByCode(obErrorx.CommercialCommonInvalidParamCodeCode, errorx.WithExtraMsg("correction is empty")) + } + return nil } func (c *CorrectionEvent) GetSpanIDFromExt() string { @@ -331,3 +363,25 @@ func (c *CorrectionEvent) GetWorkspaceIDFromExt() (string, int64) { } return workspaceIDStr, workspaceID } + +func (c *CorrectionEvent) GetUpdateBy() string { + if c == nil || c.EvaluatorResult == nil || c.EvaluatorResult.Correction == nil { + return "" + } + return c.EvaluatorResult.Correction.UpdatedBy +} + +type BackFillEvent struct { + SpaceID int64 `json:"space_id"` + TaskID int64 `json:"task_id"` +} + +func (b *BackFillEvent) Validate() error { + if b.SpaceID == 0 { + return errorx.NewByCode(obErrorx.CommercialCommonInvalidParamCodeCode, errorx.WithExtraMsg("space_id is empty")) + } + if b.TaskID == 0 { + return errorx.NewByCode(obErrorx.CommercialCommonInvalidParamCodeCode, errorx.WithExtraMsg("task_id is empty")) + } + return nil +} diff --git a/backend/modules/observability/domain/task/entity/task.go b/backend/modules/observability/domain/task/entity/task.go index 16ab0fa7e..548c19747 100644 --- a/backend/modules/observability/domain/task/entity/task.go +++ b/backend/modules/observability/domain/task/entity/task.go @@ -15,6 +15,14 @@ import ( "github.com/coze-dev/coze-loop/backend/pkg/logs" ) +type TimeUnit string + +const ( + TimeUnitDay = "day" + TimeUnitWeek = "week" + TimeUnitNull = "null" +) + type TaskStatus string const ( @@ -91,12 +99,12 @@ type EffectiveTime struct { EndAt int64 `json:"end_at"` } type Sampler struct { - SampleRate float64 `json:"sample_rate"` - SampleSize int64 `json:"sample_size"` - IsCycle bool `json:"is_cycle"` - CycleCount int64 `json:"cycle_count"` - CycleInterval int64 `json:"cycle_interval"` - CycleTimeUnit string `json:"cycle_time_unit"` + SampleRate float64 `json:"sample_rate"` + SampleSize int64 `json:"sample_size"` + IsCycle bool `json:"is_cycle"` + CycleCount int64 `json:"cycle_count"` + CycleInterval int64 `json:"cycle_interval"` + CycleTimeUnit TimeUnit `json:"cycle_time_unit"` } type TaskConfig struct { AutoEvaluateConfigs []*AutoEvaluateConfig `json:"auto_evaluate_configs"` @@ -176,18 +184,18 @@ func (t *ObservabilityTask) IsFinished() bool { } func (t *ObservabilityTask) GetBackfillTaskRun() *TaskRun { - for _, taskRunPO := range t.TaskRuns { - if taskRunPO.TaskType == TaskRunTypeBackFill { - return taskRunPO + for _, taskRun := range t.TaskRuns { + if taskRun.TaskType == TaskRunTypeBackFill { + return taskRun } } return nil } func (t *ObservabilityTask) GetCurrentTaskRun() *TaskRun { - for _, taskRunPO := range t.TaskRuns { - if taskRunPO.TaskType == TaskRunTypeNewData && taskRunPO.RunStatus == TaskRunStatusRunning { - return taskRunPO + for _, taskRun := range t.TaskRuns { + if taskRun.TaskType == TaskRunTypeNewData && taskRun.RunStatus == TaskRunStatusRunning { + return taskRun } } return nil diff --git a/backend/modules/observability/domain/task/service/task_callback.go b/backend/modules/observability/domain/task/service/task_callback.go index 09f75e403..90d1fda40 100644 --- a/backend/modules/observability/domain/task/service/task_callback.go +++ b/backend/modules/observability/domain/task/service/task_callback.go @@ -9,25 +9,63 @@ import ( "time" "github.com/coze-dev/coze-loop/backend/infra/external/benefit" - "github.com/coze-dev/coze-loop/backend/infra/middleware/session" + "github.com/coze-dev/coze-loop/backend/modules/observability/domain/component/config" + "github.com/coze-dev/coze-loop/backend/modules/observability/domain/component/tenant" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/entity" - "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/service/taskexe/tracehub" + "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/repo" + "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/service/taskexe/processor" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/entity/loop_span" - "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/repo" + tracerepo "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/repo" + obErrorx "github.com/coze-dev/coze-loop/backend/modules/observability/pkg/errno" + "github.com/coze-dev/coze-loop/backend/pkg/errorx" + "github.com/coze-dev/coze-loop/backend/pkg/lang/ptr" "github.com/coze-dev/coze-loop/backend/pkg/logs" "github.com/samber/lo" ) -func (h *tracehub.TraceHubServiceImpl) CallBack(ctx context.Context, event *entity.AutoEvalEvent) error { +//go:generate mockgen -destination=mocks/task_callback_service.go -package=mocks . ITaskCallbackService +type ITaskCallbackService interface { + AutoEvalCallback(ctx context.Context, event *entity.AutoEvalEvent) error + AutoEvalCorrection(ctx context.Context, event *entity.CorrectionEvent) error +} + +type TaskCallbackServiceImpl struct { + taskRepo repo.ITaskRepo + traceRepo tracerepo.ITraceRepo + taskProcessor processor.TaskProcessor + tenantProvider tenant.ITenantProvider + config config.ITraceConfig + benefitSvc benefit.IBenefitService +} + +func NewTaskCallbackServiceImpl( + taskRepo repo.ITaskRepo, + traceRepo tracerepo.ITraceRepo, + taskProcessor processor.TaskProcessor, + tenantProvider tenant.ITenantProvider, + config config.ITraceConfig, + benefitSvc benefit.IBenefitService, +) ITaskCallbackService { + return &TaskCallbackServiceImpl{ + taskRepo: taskRepo, + traceRepo: traceRepo, + taskProcessor: taskProcessor, + tenantProvider: tenantProvider, + config: config, + benefitSvc: benefitSvc, + } +} + +func (t *TaskCallbackServiceImpl) AutoEvalCallback(ctx context.Context, event *entity.AutoEvalEvent) error { for _, turn := range event.TurnEvalResults { workspaceIDStr, workspaceID := turn.GetWorkspaceIDFromExt() - tenants, err := h.getTenants(ctx, loop_span.PlatformType("callback_all")) + tenants, err := t.tenantProvider.GetTenantsByPlatformType(ctx, loop_span.PlatformType("callback_all")) if err != nil { return err } - storageDuration := h.config.GetTraceDataMaxDurationDay(ctx, loop_span.PlatformDefault) - res, err := h.benefitSvc.CheckTraceBenefit(ctx, &benefit.CheckTraceBenefitParams{ - ConnectorUID: turn.BaseInfo.CreatedBy.UserID, + storageDuration := t.config.GetTraceDataMaxDurationDay(ctx, lo.ToPtr(string(loop_span.PlatformDefault))) + res, err := t.benefitSvc.CheckTraceBenefit(ctx, &benefit.CheckTraceBenefitParams{ + ConnectorUID: turn.GetUserID(), SpaceID: workspaceID, }) if err != nil { @@ -38,7 +76,7 @@ func (h *tracehub.TraceHubServiceImpl) CallBack(ctx context.Context, event *enti storageDuration = res.StorageDuration } - spans, err := h.getSpan(ctx, + spans, err := t.getSpan(ctx, tenants, []string{turn.GetSpanIDFromExt()}, turn.GetTraceIDFromExt(), @@ -56,39 +94,26 @@ func (h *tracehub.TraceHubServiceImpl) CallBack(ctx context.Context, event *enti span := spans[0] // Newly added: write Redis counters based on the Status - err = h.updateTaskRunDetailsCount(ctx, turn.GetTaskIDFromExt(), turn, storageDuration*24*60*60) + err = t.updateTaskRunDetailsCount(ctx, turn.GetTaskIDFromExt(), turn, storageDuration*24*60*60) if err != nil { - logs.CtxWarn(ctx, "更新TaskRun状态计数失败: taskID=%d, status=%d, err=%v", + logs.CtxWarn(ctx, "Update TaskRun count failed: taskID=%d, status=%d, err=%v", turn.GetTaskIDFromExt(), turn.Status, err) // Continue processing without interrupting the flow } - annotation := &loop_span.Annotation{ - SpanID: turn.GetSpanIDFromExt(), - TraceID: span.TraceID, - WorkspaceID: workspaceIDStr, - AnnotationType: loop_span.AnnotationTypeAutoEvaluate, - StartTime: time.UnixMicro(span.StartTime), - Key: fmt.Sprintf("%d:%d", turn.GetTaskIDFromExt(), turn.EvaluatorVersionID), - Value: loop_span.AnnotationValue{ - ValueType: loop_span.AnnotationValueTypeDouble, - FloatValue: turn.Score, - }, - Reasoning: turn.Reasoning, - Status: loop_span.AnnotationStatusNormal, - CreatedAt: time.Now(), - UpdatedAt: time.Now(), - Metadata: &loop_span.AutoEvaluateMetadata{ - TaskID: turn.GetTaskIDFromExt(), - EvaluatorRecordID: turn.EvaluatorRecordID, - EvaluatorVersionID: turn.EvaluatorVersionID, - }, - } - if err = annotation.GenID(); err != nil { + annotation, err := span.AddAutoEvalAnnotation( + turn.GetTaskIDFromExt(), + turn.EvaluatorRecordID, + turn.EvaluatorVersionID, + turn.Score, + turn.Reasoning, + turn.GetUserID(), + ) + if err != nil { return err } - err = h.traceRepo.InsertAnnotations(ctx, &repo.InsertAnnotationParam{ + err = t.traceRepo.InsertAnnotations(ctx, &tracerepo.InsertAnnotationParam{ Tenant: span.GetTenant(), TTL: span.GetTTL(ctx), Annotations: []*loop_span.Annotation{annotation}, @@ -100,16 +125,16 @@ func (h *tracehub.TraceHubServiceImpl) CallBack(ctx context.Context, event *enti return nil } -func (h *tracehub.TraceHubServiceImpl) Correction(ctx context.Context, event *entity.CorrectionEvent) error { +func (t *TaskCallbackServiceImpl) AutoEvalCorrection(ctx context.Context, event *entity.CorrectionEvent) error { workspaceIDStr, workspaceID := event.GetWorkspaceIDFromExt() if workspaceID == 0 { return fmt.Errorf("workspace_id is empty") } - tenants, err := h.getTenants(ctx, loop_span.PlatformType("callback_all")) + tenants, err := t.tenantProvider.GetTenantsByPlatformType(ctx, loop_span.PlatformType("callback_all")) if err != nil { return err } - spans, err := h.getSpan(ctx, + spans, err := t.getSpan(ctx, tenants, []string{event.GetSpanIDFromExt()}, event.GetTraceIDFromExt(), @@ -120,14 +145,12 @@ func (h *tracehub.TraceHubServiceImpl) Correction(ctx context.Context, event *en if err != nil { return err } - if event.EvaluatorResult.Correction == nil || event.EvaluatorResult == nil { - return err - } if len(spans) == 0 { return fmt.Errorf("span not found, span_id: %s", event.GetSpanIDFromExt()) } span := spans[0] - annotations, err := h.traceRepo.ListAnnotations(ctx, &repo.ListAnnotationsParam{ + + annotations, err := t.traceRepo.ListAnnotations(ctx, &tracerepo.ListAnnotationsParam{ Tenants: tenants, SpanID: event.GetSpanIDFromExt(), TraceID: event.GetTraceIDFromExt(), @@ -138,34 +161,97 @@ func (h *tracehub.TraceHubServiceImpl) Correction(ctx context.Context, event *en if err != nil { return err } - var annotation *loop_span.Annotation - for _, a := range annotations { - meta := a.GetAutoEvaluateMetadata() - if meta != nil && meta.EvaluatorRecordID == event.EvaluatorRecordID { - annotation = a - break - } - } - updateBy := session.UserIDInCtxOrEmpty(ctx) - if updateBy == "" { - return err + annotation, ok := annotations.FindByEvaluatorRecordID(event.EvaluatorRecordID) + if !ok { + logs.CtxError(ctx, "annotation not found, evaluator_record_id: %d", event.EvaluatorRecordID) + return fmt.Errorf("annotation not found, evaluator_record_id: %d", event.EvaluatorRecordID) } - annotation.CorrectAutoEvaluateScore(event.EvaluatorResult.Correction.Score, event.EvaluatorResult.Correction.Explain, updateBy) + + annotation.CorrectAutoEvaluateScore(event.EvaluatorResult.Correction.Score, event.EvaluatorResult.Correction.Explain, event.GetUpdateBy()) // Then synchronize the observability data - param := &repo.InsertAnnotationParam{ + param := &tracerepo.InsertAnnotationParam{ Tenant: span.GetTenant(), TTL: span.GetTTL(ctx), Annotations: []*loop_span.Annotation{annotation}, } - if err = h.traceRepo.InsertAnnotations(ctx, param); err != nil { + if err = t.traceRepo.InsertAnnotations(ctx, param); err != nil { recordID := lo.Ternary(annotation.GetAutoEvaluateMetadata() != nil, annotation.GetAutoEvaluateMetadata().EvaluatorRecordID, 0) // If the synchronous update fails, compensate asynchronously // TODO: asynchronous processing has issues and may duplicate - logs.CtxWarn(ctx, "Sync upsert annotation failed, try async upsert. span_id=[%v], recored_id=[%v], err:%v", + logs.CtxError(ctx, "Sync upsert annotation failed, try async upsert. span_id=[%v], recored_id=[%v], err:%v", annotation.SpanID, recordID, err) return nil } return nil } + +func (t *TaskCallbackServiceImpl) getSpan(ctx context.Context, tenants []string, spanIds []string, traceId, workspaceId string, startAt, endAt int64) ([]*loop_span.Span, error) { + if len(spanIds) == 0 || workspaceId == "" { + return nil, errorx.NewByCode(obErrorx.CommercialCommonInvalidParamCodeCode) + } + var filterFields []*loop_span.FilterField + filterFields = append(filterFields, &loop_span.FilterField{ + FieldName: loop_span.SpanFieldSpanId, + FieldType: loop_span.FieldTypeString, + Values: spanIds, + QueryType: ptr.Of(loop_span.QueryTypeEnumIn), + }) + filterFields = append(filterFields, &loop_span.FilterField{ + FieldName: loop_span.SpanFieldSpaceId, + FieldType: loop_span.FieldTypeString, + Values: []string{workspaceId}, + QueryType: ptr.Of(loop_span.QueryTypeEnumEq), + }) + if traceId != "" { + filterFields = append(filterFields, &loop_span.FilterField{ + FieldName: loop_span.SpanFieldTraceId, + FieldType: loop_span.FieldTypeString, + Values: []string{traceId}, + + QueryType: ptr.Of(loop_span.QueryTypeEnumEq), + }) + } + var spans []*loop_span.Span + // todo 目前可能有不同tenant在不同存储中,需要上层多次查询。后续逻辑需要下沉到repo中。 + for _, tenant := range tenants { + res, err := t.traceRepo.ListSpans(ctx, &tracerepo.ListSpansParam{ + Tenants: []string{tenant}, + Filters: &loop_span.FilterFields{ + FilterFields: filterFields, + }, + StartAt: startAt, + EndAt: endAt, + NotQueryAnnotation: true, + Limit: int32(len(spanIds)), + }) + if err != nil { + logs.CtxError(ctx, "failed to list span, %v", err) + return spans, err + } + spans = append(spans, res.Spans...) + } + logs.CtxInfo(ctx, "list span, spans: %v", spans) + + return spans, nil +} + +// updateTaskRunStatusCount updates the Redis count based on Status +func (t *TaskCallbackServiceImpl) updateTaskRunDetailsCount(ctx context.Context, taskID int64, turn *entity.OnlineExptTurnEvalResult, ttl int64) error { + taskRunID, err := turn.GetRunID() + if err != nil { + return fmt.Errorf("invalid task_run_id, err: %v", err) + } + // Increase the corresponding counter based on Status + switch turn.Status { + case entity.EvaluatorRunStatus_Success: + return t.taskRepo.IncrTaskRunSuccessCount(ctx, taskID, taskRunID, ttl) + case entity.EvaluatorRunStatus_Fail: + return t.taskRepo.IncrTaskRunFailCount(ctx, taskID, taskRunID, ttl) + default: + logs.CtxWarn(ctx, "unknown status, skip count: taskID=%d, taskRunID=%d, status=%d", + taskID, taskRunID, turn.Status) + return nil + } +} diff --git a/backend/modules/observability/domain/task/service/task_callback_test.go b/backend/modules/observability/domain/task/service/task_callback_test.go index f31dd9806..a09a4aab9 100755 --- a/backend/modules/observability/domain/task/service/task_callback_test.go +++ b/backend/modules/observability/domain/task/service/task_callback_test.go @@ -5,11 +5,11 @@ package service import ( "context" + "errors" "strconv" "testing" "time" - "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/service/taskexe/tracehub" "go.uber.org/mock/gomock" "github.com/coze-dev/coze-loop/backend/infra/external/benefit" @@ -23,7 +23,7 @@ import ( "github.com/stretchr/testify/require" ) -func TestTraceHubServiceImpl_CallBackSuccess(t *testing.T) { +func TestTaskCallbackServiceImpl_CallBackSuccess(t *testing.T) { t.Parallel() ctrl := gomock.NewController(t) @@ -34,7 +34,7 @@ func TestTraceHubServiceImpl_CallBackSuccess(t *testing.T) { mockTraceRepo := trace_repo_mocks.NewMockITraceRepo(ctrl) mockTaskRepo := repo_mocks.NewMockITaskRepo(ctrl) - impl := &tracehub.TraceHubServiceImpl{ + impl := &TaskCallbackServiceImpl{ benefitSvc: mockBenefit, tenantProvider: mockTenant, traceRepo: mockTraceRepo, @@ -86,7 +86,7 @@ func TestTraceHubServiceImpl_CallBackSuccess(t *testing.T) { }, } - require.NoError(t, impl.CallBack(context.Background(), event)) + require.NoError(t, impl.AutoEvalCallback(context.Background(), event)) } func TestTraceHubServiceImpl_CallBackSpanNotFound(t *testing.T) { @@ -99,7 +99,7 @@ func TestTraceHubServiceImpl_CallBackSpanNotFound(t *testing.T) { mockTenant := tenant_mocks.NewMockITenantProvider(ctrl) mockTraceRepo := trace_repo_mocks.NewMockITraceRepo(ctrl) - impl := &tracehub.TraceHubServiceImpl{ + impl := &TaskCallbackServiceImpl{ benefitSvc: mockBenefit, tenantProvider: mockTenant, traceRepo: mockTraceRepo, @@ -128,5 +128,184 @@ func TestTraceHubServiceImpl_CallBackSpanNotFound(t *testing.T) { }, } - require.Error(t, impl.CallBack(context.Background(), event)) + require.Error(t, impl.AutoEvalCallback(context.Background(), event)) +} + +func TestTaskCallbackServiceImpl_getSpan(t *testing.T) { + t.Parallel() + + ctx := context.Background() + tenants := []string{"tenant"} + spanIDs := []string{"span-1"} + traceID := "trace-1" + workspaceID := "ws-1" + start := int64(1000) + end := int64(2000) + + t.Run("with_trace_id", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + mockTraceRepo := trace_repo_mocks.NewMockITraceRepo(ctrl) + impl := &TaskCallbackServiceImpl{traceRepo: mockTraceRepo} + expectedSpan := &loop_span.Span{SpanID: spanIDs[0], TraceID: traceID} + + mockTraceRepo.EXPECT().ListSpans(gomock.Any(), gomock.AssignableToTypeOf(&repo.ListSpansParam{})).DoAndReturn( + func(_ context.Context, param *repo.ListSpansParam) (*repo.ListSpansResult, error) { + require.Equal(t, tenants, param.Tenants) + require.Equal(t, start, param.StartAt) + require.Equal(t, end, param.EndAt) + require.True(t, param.NotQueryAnnotation) + require.Equal(t, int32(2), param.Limit) + require.Len(t, param.Filters.FilterFields, 3) + return &repo.ListSpansResult{Spans: loop_span.SpanList{expectedSpan}}, nil + }, + ) + + spans, err := impl.getSpan(ctx, tenants, spanIDs, traceID, workspaceID, start, end) + require.NoError(t, err) + require.Equal(t, []*loop_span.Span{expectedSpan}, spans) + }) + + t.Run("without_trace_id", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + mockTraceRepo := trace_repo_mocks.NewMockITraceRepo(ctrl) + impl := &TaskCallbackServiceImpl{traceRepo: mockTraceRepo} + expectedSpan := &loop_span.Span{SpanID: spanIDs[0]} + + mockTraceRepo.EXPECT().ListSpans(gomock.Any(), gomock.AssignableToTypeOf(&repo.ListSpansParam{})).DoAndReturn( + func(_ context.Context, param *repo.ListSpansParam) (*repo.ListSpansResult, error) { + require.Equal(t, tenants, param.Tenants) + require.Len(t, param.Filters.FilterFields, 2) + return &repo.ListSpansResult{Spans: loop_span.SpanList{expectedSpan}}, nil + }, + ) + + spans, err := impl.getSpan(ctx, tenants, spanIDs, "", workspaceID, start, end) + require.NoError(t, err) + require.Equal(t, []*loop_span.Span{expectedSpan}, spans) + }) + + t.Run("empty_span_ids", func(t *testing.T) { + t.Parallel() + impl := &TaskCallbackServiceImpl{} + _, err := impl.getSpan(ctx, tenants, nil, traceID, workspaceID, start, end) + require.Error(t, err) + }) + + t.Run("empty_workspace", func(t *testing.T) { + t.Parallel() + impl := &TaskCallbackServiceImpl{} + _, err := impl.getSpan(ctx, tenants, spanIDs, traceID, "", start, end) + require.Error(t, err) + }) + + t.Run("repo_error", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + mockTraceRepo := trace_repo_mocks.NewMockITraceRepo(ctrl) + impl := &TaskCallbackServiceImpl{traceRepo: mockTraceRepo} + + mockTraceRepo.EXPECT().ListSpans(gomock.Any(), gomock.AssignableToTypeOf(&repo.ListSpansParam{})).Return(nil, errors.New("list error")) + + _, err := impl.getSpan(ctx, tenants, spanIDs, traceID, workspaceID, start, end) + require.Error(t, err) + }) + + t.Run("no_data", func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + mockTraceRepo := trace_repo_mocks.NewMockITraceRepo(ctrl) + impl := &TaskCallbackServiceImpl{traceRepo: mockTraceRepo} + + mockTraceRepo.EXPECT().ListSpans(gomock.Any(), gomock.AssignableToTypeOf(&repo.ListSpansParam{})).Return(&repo.ListSpansResult{}, nil) + + spans, err := impl.getSpan(ctx, tenants, spanIDs, traceID, workspaceID, start, end) + require.NoError(t, err) + require.Nil(t, spans) + }) +} + +func TestTaskCallbackServiceImpl_updateTaskRunDetailsCount(t *testing.T) { + t.Parallel() + + ctx := context.Background() + taskID := int64(101) + runIDStr := "202" + runID := int64(202) + + tests := []struct { + name string + status entity.EvaluatorRunStatus + expectSuccess bool + expectFail bool + expectErr bool + }{ + { + name: "success_status", + status: entity.EvaluatorRunStatus_Success, + expectSuccess: true, + }, + { + name: "fail_status", + status: entity.EvaluatorRunStatus_Fail, + expectFail: true, + }, + { + name: "unknown_status", + status: entity.EvaluatorRunStatus_Unknown, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + t.Cleanup(ctrl.Finish) + + mockRepo := repo_mocks.NewMockITaskRepo(ctrl) + impl := &TaskCallbackServiceImpl{taskRepo: mockRepo} + + turn := &entity.OnlineExptTurnEvalResult{ + Status: tt.status, + Ext: map[string]string{ + "run_id": runIDStr, + }, + } + + if tt.expectSuccess { + mockRepo.EXPECT().IncrTaskRunSuccessCount(ctx, taskID, runID, gomock.Any()).Return(nil) + } + if tt.expectFail { + mockRepo.EXPECT().IncrTaskRunFailCount(ctx, taskID, runID, gomock.Any()).Return(nil) + } + + err := impl.updateTaskRunDetailsCount(ctx, taskID, turn, 0) + if tt.expectErr { + require.Error(t, err) + } else { + require.NoError(t, err) + } + }) + } + + t.Run("missing_run_id", func(t *testing.T) { + t.Parallel() + impl := &TaskCallbackServiceImpl{} + err := impl.updateTaskRunDetailsCount(ctx, taskID, &entity.OnlineExptTurnEvalResult{Ext: map[string]string{}}, 0) + require.Error(t, err) + }) + + t.Run("invalid_run_id", func(t *testing.T) { + t.Parallel() + impl := &TaskCallbackServiceImpl{} + err := impl.updateTaskRunDetailsCount(ctx, taskID, &entity.OnlineExptTurnEvalResult{Ext: map[string]string{"run_id": "abc"}}, 0) + require.Error(t, err) + }) } diff --git a/backend/modules/observability/domain/task/service/taskexe/processor/auto_evaluate.go b/backend/modules/observability/domain/task/service/taskexe/processor/auto_evaluate.go index 02d4c4995..cd18cd08f 100644 --- a/backend/modules/observability/domain/task/service/taskexe/processor/auto_evaluate.go +++ b/backend/modules/observability/domain/task/service/taskexe/processor/auto_evaluate.go @@ -186,9 +186,9 @@ func (p *AutoEvaluteProcessor) OnTaskCreated(ctx context.Context, currentTask *t runEndAt = currentTask.EffectiveTime.EndAt } else { switch currentTask.Sampler.CycleTimeUnit { - case task.TimeUnitDay: + case task_entity.TimeUnitDay: runEndAt = runStartAt + (currentTask.Sampler.CycleInterval)*24*time.Hour.Milliseconds() - case task.TimeUnitWeek: + case task_entity.TimeUnitWeek: runEndAt = runStartAt + (currentTask.Sampler.CycleInterval)*7*24*time.Hour.Milliseconds() default: runEndAt = runStartAt + (currentTask.Sampler.CycleInterval)*10*time.Minute.Milliseconds() diff --git a/backend/modules/observability/domain/task/service/taskexe/processor/auto_evaluate_test.go b/backend/modules/observability/domain/task/service/taskexe/processor/auto_evaluate_test.go index a51537ff5..a3157086e 100755 --- a/backend/modules/observability/domain/task/service/taskexe/processor/auto_evaluate_test.go +++ b/backend/modules/observability/domain/task/service/taskexe/processor/auto_evaluate_test.go @@ -142,7 +142,7 @@ func buildTestTask(t *testing.T) *taskentity.ObservabilityTask { IsCycle: false, CycleCount: 0, CycleInterval: 1, - CycleTimeUnit: task.TimeUnitDay, + CycleTimeUnit: taskentity.TimeUnitDay, }, TaskConfig: &taskentity.TaskConfig{ AutoEvaluateConfigs: []*taskentity.AutoEvaluateConfig{ diff --git a/backend/modules/observability/domain/task/service/taskexe/tracehub/backfill.go b/backend/modules/observability/domain/task/service/taskexe/tracehub/backfill.go index c33a28879..c9d33cdc6 100644 --- a/backend/modules/observability/domain/task/service/taskexe/tracehub/backfill.go +++ b/backend/modules/observability/domain/task/service/taskexe/tracehub/backfill.go @@ -11,8 +11,6 @@ import ( "github.com/coze-dev/coze-loop/backend/infra/middleware/session" "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/observability/domain/task" - "github.com/coze-dev/coze-loop/backend/modules/observability/application/convertor" - tconv "github.com/coze-dev/coze-loop/backend/modules/observability/application/convertor/task" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/entity" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/service/taskexe" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/entity/loop_span" @@ -65,68 +63,61 @@ func (h *TraceHubServiceImpl) BackFill(ctx context.Context, event *entity.BackFi }(lockCancel) } - sub, err := h.setBackfillTask(ctx, event) + sub, err := h.buildSubscriber(ctx, event) if err != nil { return err } + if sub == nil || sub.t == nil { + return errors.New("subscriber or task config not found") + } - if sub != nil && sub.t != nil && sub.t.GetBaseInfo() != nil && sub.t.GetBaseInfo().GetCreatedBy() != nil { - ctx = session.WithCtxUser(ctx, &session.User{ID: sub.t.GetBaseInfo().GetCreatedBy().GetUserID()}) + // todo tyf 是否需要 + if sub.t != nil && sub.t.CreatedBy != "" { + ctx = session.WithCtxUser(ctx, &session.User{ID: sub.t.CreatedBy}) } // 2. Determine whether the backfill task is completed to avoid repeated execution isDone, err := h.isBackfillDone(ctx, sub) if err != nil { - logs.CtxError(ctx, "check backfill task done failed, task_id=%d, err=%v", sub.t.GetID(), err) + logs.CtxError(ctx, "check backfill task done failed, task_id=%d, err=%v", sub.t.ID, err) return err } if isDone { - logs.CtxInfo(ctx, "backfill already completed, task_id=%d", sub.t.GetID()) + logs.CtxInfo(ctx, "backfill already completed, task_id=%d", sub.t.ID) return nil } - // 顺序执行时重置 flush 错误收集器 - h.flushErrLock.Lock() - h.flushErr = nil - h.flushErrLock.Unlock() - // 5. Retrieve span data from the observability service - listErr := h.listAndSendSpans(ctx, sub) - if listErr != nil { - logs.CtxError(ctx, "list spans failed, task_id=%d, err=%v", sub.t.GetID(), listErr) - } + err = h.listAndSendSpans(ctx, sub) - // 6. Synchronously wait for completion to ensure all data is processed - return h.onHandleDone(ctx, listErr, sub) + return h.onHandleDone(ctx, err, sub) } -// setBackfillTask sets the context for the current backfill task -func (h *TraceHubServiceImpl) setBackfillTask(ctx context.Context, event *entity.BackFillEvent) (*spanSubscriber, error) { - taskConfig, err := h.taskRepo.GetTask(ctx, event.TaskID, nil, nil) +// buildSubscriber sets the context for the current backfill task +func (h *TraceHubServiceImpl) buildSubscriber(ctx context.Context, event *entity.BackFillEvent) (*spanSubscriber, error) { + taskDO, err := h.taskRepo.GetTask(ctx, event.TaskID, nil, nil) if err != nil { logs.CtxError(ctx, "get task config failed, task_id=%d, err=%v", event.TaskID, err) return nil, err } - if taskConfig == nil { + if taskDO == nil { return nil, errors.New("task config not found") } - taskConfigDO := tconv.TaskDO2DTO(ctx, taskConfig, nil) - taskRun, err := h.taskRepo.GetBackfillTaskRun(ctx, ptr.Of(taskConfigDO.GetWorkspaceID()), taskConfigDO.GetID()) - if err != nil { - logs.CtxError(ctx, "get backfill task run failed, task_id=%d, err=%v", taskConfigDO.GetID(), err) - return nil, err + + taskRun := taskDO.GetBackfillTaskRun() + if taskRun == nil { + logs.CtxError(ctx, "get backfill task run failed, task_id=%d, err=%v", taskDO.ID) + return nil, errors.New("get backfill task run not found") } - taskRunDTO := tconv.TaskRunDO2DTO(ctx, taskRun, nil) - proc := h.taskProcessor.GetTaskProcessor(taskConfig.TaskType) + + proc := h.taskProcessor.GetTaskProcessor(taskDO.TaskType) sub := &spanSubscriber{ - taskID: taskConfigDO.GetID(), - t: taskConfigDO, - tr: taskRunDTO, - processor: proc, - bufCap: 0, - maxFlushInterval: time.Second * 5, - taskRepo: h.taskRepo, - runType: entity.TaskRunTypeBackFill, + taskID: taskDO.ID, + t: taskDO, + tr: taskRun, + processor: proc, + taskRepo: h.taskRepo, + runType: entity.TaskRunTypeBackFill, } return sub, nil @@ -135,7 +126,7 @@ func (h *TraceHubServiceImpl) setBackfillTask(ctx context.Context, event *entity // isBackfillDone checks whether the backfill task has been completed func (h *TraceHubServiceImpl) isBackfillDone(ctx context.Context, sub *spanSubscriber) (bool, error) { if sub.tr == nil { - logs.CtxError(ctx, "get backfill task run failed, task_id=%d, err=%v", sub.t.GetID(), nil) + logs.CtxError(ctx, "get backfill task run failed, task_id=%d, err=%v", sub.t.ID, nil) return true, nil } @@ -143,10 +134,10 @@ func (h *TraceHubServiceImpl) isBackfillDone(ctx context.Context, sub *spanSubsc } func (h *TraceHubServiceImpl) listAndSendSpans(ctx context.Context, sub *spanSubscriber) error { - backfillTime := sub.t.GetRule().GetBackfillEffectiveTime() - tenants, err := h.getTenants(ctx, loop_span.PlatformType(sub.t.GetRule().GetSpanFilters().GetPlatformType())) + backfillTime := sub.t.BackfillEffectiveTime + tenants, err := h.getTenants(ctx, sub.t.SpanFilter.PlatformType) if err != nil { - logs.CtxError(ctx, "get tenants failed, task_id=%d, err=%v", sub.t.GetID(), err) + logs.CtxError(ctx, "get tenants failed, task_id=%d, err=%v", sub.t.ID, err) return err } @@ -154,15 +145,15 @@ func (h *TraceHubServiceImpl) listAndSendSpans(ctx context.Context, sub *spanSub listParam := &repo.ListSpansParam{ Tenants: tenants, Filters: h.buildSpanFilters(ctx, sub.t), - StartAt: backfillTime.GetStartAt(), - EndAt: backfillTime.GetEndAt(), + StartAt: backfillTime.StartAt, + EndAt: backfillTime.EndAt, Limit: pageSize, // Page size DescByStartTime: true, NotQueryAnnotation: true, // No annotation query required during backfill } - if sub.tr.BackfillRunDetail != nil && sub.tr.BackfillRunDetail.LastSpanPageToken != nil { - listParam.PageToken = *sub.tr.BackfillRunDetail.LastSpanPageToken + if sub.tr.BackfillDetail != nil && sub.tr.BackfillDetail.LastSpanPageToken != nil { + listParam.PageToken = *sub.tr.BackfillDetail.LastSpanPageToken } // Paginate query and send data return h.fetchAndSendSpans(ctx, listParam, sub) @@ -182,22 +173,26 @@ type ListSpansReq struct { } // buildSpanFilters constructs span filter conditions -func (h *TraceHubServiceImpl) buildSpanFilters(ctx context.Context, taskConfig *task.Task) *loop_span.FilterFields { +func (h *TraceHubServiceImpl) buildSpanFilters(ctx context.Context, taskConfig *entity.ObservabilityTask) *loop_span.FilterFields { // More complex filters can be built based on the task configuration // Simplified here: return nil to indicate no additional filters - platformFilter, err := h.buildHelper.BuildPlatformRelatedFilter(ctx, loop_span.PlatformType(taskConfig.GetRule().GetSpanFilters().GetPlatformType())) + platformFilter, err := h.buildHelper.BuildPlatformRelatedFilter(ctx, taskConfig.SpanFilter.PlatformType) if err != nil { + logs.CtxError(ctx, "build platform filter failed, task_id=%d, err=%v", taskConfig.ID, err) + // 不需要重试 return nil } builtinFilter, err := h.buildBuiltinFilters(ctx, platformFilter, &ListSpansReq{ - WorkspaceID: taskConfig.GetWorkspaceID(), - SpanListType: loop_span.SpanListType(taskConfig.GetRule().GetSpanFilters().GetSpanListType()), + WorkspaceID: taskConfig.WorkspaceID, + SpanListType: taskConfig.SpanFilter.SpanListType, }) if err != nil { + logs.CtxError(ctx, "build builtin filter failed, task_id=%d, err=%v", taskConfig.ID, err) + // 不需要重试 return nil } - filters := h.combineFilters(builtinFilter, convertor.FilterFieldsDTO2DO(taskConfig.GetRule().GetSpanFilters().GetFilters())) + filters := h.combineFilters(builtinFilter, &taskConfig.SpanFilter.Filters) return filters } @@ -266,16 +261,18 @@ func (h *TraceHubServiceImpl) fetchAndSendSpans(ctx context.Context, listParam * totalCount := int64(0) pageToken := listParam.PageToken for { - logs.CtxInfo(ctx, "ListSpansParam:%v", listParam) + logs.CtxInfo(ctx, "TaskID: %d, ListSpansParam:%v", sub.t.ID, listParam) result, err := h.traceRepo.ListSpans(ctx, listParam) if err != nil { - logs.CtxError(ctx, "list spans failed, task_id=%d, page_token=%s, err=%v", sub.t.GetID(), pageToken, err) + logs.CtxError(ctx, "List spans failed, task_id=%d, page_token=%s, err=%v", sub.t.ID, pageToken, err) return err } + logs.CtxInfo(ctx, "Fetch %d spans, total=%d, task_id=%d", len(result.Spans), totalCount, sub.t.ID) + spans := result.Spans processors, err := h.buildHelper.BuildGetTraceProcessors(ctx, span_processor.Settings{ - WorkspaceId: sub.t.GetWorkspaceID(), - PlatformType: loop_span.PlatformType(sub.t.GetRule().GetSpanFilters().GetPlatformType()), + WorkspaceId: sub.t.WorkspaceID, + PlatformType: sub.t.SpanFilter.PlatformType, QueryStartTime: listParam.StartAt, QueryEndTime: listParam.EndAt, }) @@ -304,11 +301,11 @@ func (h *TraceHubServiceImpl) fetchAndSendSpans(ctx context.Context, listParam * } totalCount += int64(len(spans)) - logs.CtxInfo(ctx, "processed %d spans, total=%d, task_id=%d", len(spans), totalCount, sub.t.GetID()) + logs.CtxInfo(ctx, "Processed %d spans completed, total=%d, task_id=%d", len(spans), totalCount, sub.t.ID) } if !result.HasMore { - logs.CtxInfo(ctx, "completed listing spans, total_count=%d, task_id=%d", totalCount, sub.t.GetID()) + logs.CtxInfo(ctx, "Completed listing spans, total_count=%d, task_id=%d", totalCount, sub.t.ID) break } @@ -319,80 +316,58 @@ func (h *TraceHubServiceImpl) fetchAndSendSpans(ctx context.Context, listParam * } func (h *TraceHubServiceImpl) flushSpans(ctx context.Context, fr *flushReq, sub *spanSubscriber) error { - if ctx.Err() != nil { - return ctx.Err() - } - - _, _, err := h.doFlush(ctx, fr, sub) - if err != nil { - logs.CtxError(ctx, "flush spans failed, task_id=%d, err=%v", sub.t.GetID(), err) - h.flushErrLock.Lock() - h.flushErr = append(h.flushErr, err) - h.flushErrLock.Unlock() - } - - return nil -} - -func (h *TraceHubServiceImpl) doFlush(ctx context.Context, fr *flushReq, sub *spanSubscriber) (flushed, sampled int, _ error) { if fr == nil || len(fr.spans) == 0 { - return 0, 0, nil + return nil } - logs.CtxInfo(ctx, "processing %d spans for backfill, task_id=%d", len(fr.spans), sub.t.GetID()) + logs.CtxInfo(ctx, "Start processing %d spans for backfill, task_id=%d", len(fr.spans), sub.t.ID) // Apply sampling logic sampledSpans := h.applySampling(fr.spans, sub) if len(sampledSpans) == 0 { - logs.CtxInfo(ctx, "no spans after sampling, task_id=%d", sub.t.GetID()) - return len(fr.spans), 0, nil + logs.CtxInfo(ctx, "no spans after sampling, task_id=%d", sub.t.ID) + return nil } // Execute specific business logic err := h.processSpansForBackfill(ctx, sampledSpans, sub) if err != nil { - logs.CtxError(ctx, "process spans failed, task_id=%d, err=%v", sub.t.GetID(), err) - return len(fr.spans), len(sampledSpans), err + logs.CtxError(ctx, "process spans failed, task_id=%d, err=%v", sub.t.ID, err) + return err } - sub.tr.BackfillRunDetail = &task.BackfillDetail{ - LastSpanPageToken: ptr.Of(fr.pageToken), - } + // todo 不应该这里直接写po字段 err = h.taskRepo.UpdateTaskRunWithOCC(ctx, sub.tr.ID, sub.tr.WorkspaceID, map[string]interface{}{ - "backfill_detail": ToJSONString(ctx, sub.tr.BackfillRunDetail), + "backfill_detail": ToJSONString(ctx, sub.tr.BackfillDetail), }) if err != nil { - logs.CtxError(ctx, "update task run failed, task_id=%d, err=%v", sub.t.GetID(), err) - return len(fr.spans), len(sampledSpans), err + logs.CtxError(ctx, "update task run failed, task_id=%d, err=%v", sub.t.ID, err) + return err } if fr.noMore { - logs.CtxInfo(ctx, "no more spans to process, task_id=%d", sub.t.GetID()) + logs.CtxInfo(ctx, "no more spans to process, task_id=%d", sub.t.ID) if err = sub.processor.OnTaskFinished(ctx, taskexe.OnTaskFinishedReq{ - Task: tconv.TaskDTO2DO(sub.t), - TaskRun: tconv.TaskRunDTO2DO(sub.tr), + Task: sub.t, + TaskRun: sub.tr, IsFinish: false, }); err != nil { - return len(fr.spans), len(sampledSpans), err + return err } } logs.CtxInfo(ctx, "successfully processed %d spans (sampled from %d), task_id=%d", - len(sampledSpans), len(fr.spans), sub.t.GetID()) - return len(fr.spans), len(sampledSpans), nil + len(sampledSpans), len(fr.spans), sub.t.ID) + return nil } // applySampling applies sampling logic func (h *TraceHubServiceImpl) applySampling(spans []*loop_span.Span, sub *spanSubscriber) []*loop_span.Span { - if sub.t == nil || sub.t.Rule == nil { - return spans - } - - sampler := sub.t.GetRule().GetSampler() + sampler := sub.t.Sampler if sampler == nil { return spans } - sampleRate := sampler.GetSampleRate() + sampleRate := sampler.SampleRate if sampleRate >= 1.0 { return spans // 100% sampling } @@ -428,7 +403,7 @@ func (h *TraceHubServiceImpl) processSpansForBackfill(ctx context.Context, spans batch := spans[i:end] if err := h.processBatchSpans(ctx, batch, sub); err != nil { logs.CtxError(ctx, "process batch spans failed, task_id=%d, batch_start=%d, err=%v", - sub.t.GetID(), i, err) + sub.t.ID, i, err) // Continue with the next batch without stopping due to a single failure continue } @@ -442,15 +417,15 @@ func (h *TraceHubServiceImpl) processBatchSpans(ctx context.Context, spans []*lo for _, span := range spans { // Execute processing logic according to the task type logs.CtxInfo(ctx, "processing span for backfill, span_id=%s, trace_id=%s, task_id=%d", - span.SpanID, span.TraceID, sub.t.GetID()) + span.SpanID, span.TraceID, sub.t.ID) taskCount, _ := h.taskRepo.GetTaskCount(ctx, sub.taskID) - taskRunCount, _ := h.taskRepo.GetTaskRunCount(ctx, sub.taskID, sub.tr.GetID()) - sampler := sub.t.GetRule().GetSampler() - if taskCount+1 > sampler.GetSampleSize() { - logs.CtxWarn(ctx, "taskCount+1 > sampler.GetSampleSize(), task_id=%d,SampleSize=%d", sub.taskID, sampler.GetSampleSize()) + taskRunCount, _ := h.taskRepo.GetTaskRunCount(ctx, sub.taskID, sub.tr.ID) + sampler := sub.t.Sampler + if taskCount+1 > sampler.SampleSize { + logs.CtxInfo(ctx, "taskCount+1 > sampler.GetSampleSize(), task_id=%d,SampleSize=%d", sub.taskID, sampler.SampleSize) if err := sub.processor.OnTaskFinished(ctx, taskexe.OnTaskFinishedReq{ - Task: tconv.TaskDTO2DO(sub.t), - TaskRun: tconv.TaskRunDTO2DO(sub.tr), + Task: sub.t, + TaskRun: sub.tr, IsFinish: true, }); err != nil { return err @@ -467,34 +442,23 @@ func (h *TraceHubServiceImpl) processBatchSpans(ctx context.Context, spans []*lo } // onHandleDone handles completion callback -func (h *TraceHubServiceImpl) onHandleDone(ctx context.Context, listErr error, sub *spanSubscriber) error { - // Collect all errors - h.flushErrLock.Lock() - allErrors := append([]error{}, h.flushErr...) - if listErr != nil { - allErrors = append(allErrors, listErr) - } - h.flushErrLock.Unlock() - - if len(allErrors) > 0 { - backfillEvent := &entity.BackFillEvent{ - SpaceID: sub.t.GetWorkspaceID(), - TaskID: sub.t.GetID(), - } - - // Send MQ message asynchronously without blocking task creation flow - go func() { - if err := h.sendBackfillMessage(context.Background(), backfillEvent); err != nil { - logs.CtxWarn(ctx, "send backfill message failed, task_id=%d, err=%v", sub.t.GetID(), err) - } - }() - logs.CtxWarn(ctx, "backfill completed with %d errors, task_id=%d", len(allErrors), sub.t.GetID()) - // Return the first error as a representative - return allErrors[0] - +func (h *TraceHubServiceImpl) onHandleDone(ctx context.Context, err error, sub *spanSubscriber) error { + if err == nil { + logs.CtxInfo(ctx, "backfill completed successfully, task_id=%d", sub.t.ID) + return nil } - logs.CtxInfo(ctx, "backfill completed successfully, task_id=%d", sub.t.GetID()) + // failed, need retry + logs.CtxWarn(ctx, "backfill completed with error: %v, task_id=%d", err, sub.t.ID) + backfillEvent := &entity.BackFillEvent{ + SpaceID: sub.t.WorkspaceID, + TaskID: sub.t.ID, + } + if sendErr := h.sendBackfillMessage(context.Background(), backfillEvent); sendErr != nil { + logs.CtxWarn(ctx, "send backfill message failed, task_id=%d, err=%v", sub.t.ID, sendErr) + return sendErr + } + // 依靠MQ进行重试 return nil } diff --git a/backend/modules/observability/domain/task/service/taskexe/tracehub/backfill_test.go b/backend/modules/observability/domain/task/service/taskexe/tracehub/backfill_test.go index a7442b88f..bc1cbf013 100755 --- a/backend/modules/observability/domain/task/service/taskexe/tracehub/backfill_test.go +++ b/backend/modules/observability/domain/task/service/taskexe/tracehub/backfill_test.go @@ -79,7 +79,7 @@ func TestTraceHubServiceImpl_SetBackfillTask(t *testing.T) { mockRepo.EXPECT().GetTask(gomock.Any(), int64(1), gomock.Nil(), gomock.Nil()).Return(obsTask, nil) mockRepo.EXPECT().GetBackfillTaskRun(gomock.Any(), gomock.AssignableToTypeOf(ptr.Of(int64(0))), int64(1)).Return(backfillRun, nil) - sub, err := impl.setBackfillTask(context.Background(), &entity.BackFillEvent{TaskID: 1}) + sub, err := impl.buildSubscriber(context.Background(), &entity.BackFillEvent{TaskID: 1}) require.NoError(t, err) require.NotNil(t, sub) require.Equal(t, int64(1), sub.taskID) @@ -97,7 +97,7 @@ func TestTraceHubServiceImpl_SetBackfillTaskNotFound(t *testing.T) { mockRepo.EXPECT().GetTask(gomock.Any(), int64(1), gomock.Nil(), gomock.Nil()).Return(nil, nil) - _, err := impl.setBackfillTask(context.Background(), &entity.BackFillEvent{TaskID: 1}) + _, err := impl.buildSubscriber(context.Background(), &entity.BackFillEvent{TaskID: 1}) require.Error(t, err) } diff --git a/backend/modules/observability/domain/task/service/taskexe/tracehub/mocks/trace_hub_service.go b/backend/modules/observability/domain/task/service/taskexe/tracehub/mocks/trace_hub_service.go index a46b07d6b..50fb666f5 100644 --- a/backend/modules/observability/domain/task/service/taskexe/tracehub/mocks/trace_hub_service.go +++ b/backend/modules/observability/domain/task/service/taskexe/tracehub/mocks/trace_hub_service.go @@ -21,6 +21,7 @@ import ( type MockITraceHubService struct { ctrl *gomock.Controller recorder *MockITraceHubServiceMockRecorder + isgomock struct{} } // MockITraceHubServiceMockRecorder is the mock recorder for MockITraceHubService. @@ -41,57 +42,29 @@ func (m *MockITraceHubService) EXPECT() *MockITraceHubServiceMockRecorder { } // BackFill mocks base method. -func (m *MockITraceHubService) BackFill(arg0 context.Context, arg1 *entity.BackFillEvent) error { +func (m *MockITraceHubService) BackFill(ctx context.Context, event *entity.BackFillEvent) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "BackFill", arg0, arg1) + ret := m.ctrl.Call(m, "BackFill", ctx, event) ret0, _ := ret[0].(error) return ret0 } // BackFill indicates an expected call of BackFill. -func (mr *MockITraceHubServiceMockRecorder) BackFill(arg0, arg1 any) *gomock.Call { +func (mr *MockITraceHubServiceMockRecorder) BackFill(ctx, event any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BackFill", reflect.TypeOf((*MockITraceHubService)(nil).BackFill), arg0, arg1) -} - -// CallBack mocks base method. -func (m *MockITraceHubService) CallBack(arg0 context.Context, arg1 *entity.AutoEvalEvent) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "CallBack", arg0, arg1) - ret0, _ := ret[0].(error) - return ret0 -} - -// CallBack indicates an expected call of CallBack. -func (mr *MockITraceHubServiceMockRecorder) CallBack(arg0, arg1 any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CallBack", reflect.TypeOf((*MockITraceHubService)(nil).CallBack), arg0, arg1) -} - -// Correction mocks base method. -func (m *MockITraceHubService) Correction(arg0 context.Context, arg1 *entity.CorrectionEvent) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Correction", arg0, arg1) - ret0, _ := ret[0].(error) - return ret0 -} - -// Correction indicates an expected call of Correction. -func (mr *MockITraceHubServiceMockRecorder) Correction(arg0, arg1 any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Correction", reflect.TypeOf((*MockITraceHubService)(nil).Correction), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BackFill", reflect.TypeOf((*MockITraceHubService)(nil).BackFill), ctx, event) } // SpanTrigger mocks base method. -func (m *MockITraceHubService) SpanTrigger(arg0 context.Context, arg1 *entity.RawSpan) error { +func (m *MockITraceHubService) SpanTrigger(ctx context.Context, event *entity.RawSpan) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "SpanTrigger", arg0, arg1) + ret := m.ctrl.Call(m, "SpanTrigger", ctx, event) ret0, _ := ret[0].(error) return ret0 } // SpanTrigger indicates an expected call of SpanTrigger. -func (mr *MockITraceHubServiceMockRecorder) SpanTrigger(arg0, arg1 any) *gomock.Call { +func (mr *MockITraceHubServiceMockRecorder) SpanTrigger(ctx, event any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SpanTrigger", reflect.TypeOf((*MockITraceHubService)(nil).SpanTrigger), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SpanTrigger", reflect.TypeOf((*MockITraceHubService)(nil).SpanTrigger), ctx, event) } diff --git a/backend/modules/observability/domain/task/service/taskexe/tracehub/span_trigger.go b/backend/modules/observability/domain/task/service/taskexe/tracehub/span_trigger.go index 1422dce04..d63cab226 100644 --- a/backend/modules/observability/domain/task/service/taskexe/tracehub/span_trigger.go +++ b/backend/modules/observability/domain/task/service/taskexe/tracehub/span_trigger.go @@ -6,12 +6,10 @@ package tracehub import ( "context" "fmt" - "sync" "time" "github.com/bytedance/gg/gslice" "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/observability/domain/task" - tconv "github.com/coze-dev/coze-loop/backend/modules/observability/application/convertor/task" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/component/config" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/entity" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/service/taskexe" @@ -81,23 +79,18 @@ func (h *TraceHubServiceImpl) getSubscriberOfSpan(ctx context.Context, span *loo logs.CtxError(ctx, "Failed to get non-final task list, err: %v", err) return nil, err } - taskList := tconv.TaskDOs2DTOs(ctx, taskDOs, nil) - for _, taskDO := range taskList { - if !cfg.IsAllSpace && !gslice.Contains(cfg.SpaceList, taskDO.GetWorkspaceID()) { + for _, taskDO := range taskDOs { + if !cfg.IsAllSpace && !gslice.Contains(cfg.SpaceList, taskDO.WorkspaceID) { continue } proc := h.taskProcessor.GetTaskProcessor(entity.TaskType(taskDO.TaskType)) subscribers = append(subscribers, &spanSubscriber{ - taskID: taskDO.GetID(), - RWMutex: sync.RWMutex{}, - t: taskDO, - processor: proc, - bufCap: 0, - flushWait: sync.WaitGroup{}, - maxFlushInterval: time.Second * 5, - taskRepo: h.taskRepo, - runType: entity.TaskRunTypeNewData, - buildHelper: h.buildHelper, + taskID: taskDO.ID, + t: taskDO, + processor: proc, + taskRepo: h.taskRepo, + runType: entity.TaskRunTypeNewData, + buildHelper: h.buildHelper, }) } @@ -124,59 +117,59 @@ func (h *TraceHubServiceImpl) getSubscriberOfSpan(ctx context.Context, span *loo func (h *TraceHubServiceImpl) preDispatch(ctx context.Context, span *loop_span.Span, subs []*spanSubscriber) error { merr := &multierror.Error{} for _, sub := range subs { - if sub.t.GetRule().GetEffectiveTime() == nil || sub.t.GetRule().GetEffectiveTime().GetStartAt() == 0 { + if sub.t.EffectiveTime == nil || sub.t.EffectiveTime.StartAt == 0 { continue } - if span.StartTime < sub.t.GetRule().GetEffectiveTime().GetStartAt() { + if span.StartTime < sub.t.EffectiveTime.StartAt { logs.CtxWarn(ctx, "span start time is before task cycle start time, trace_id=%s, span_id=%s", span.TraceID, span.SpanID) continue } // First step: lock for task status change // Task run status var runStartAt, runEndAt int64 - if sub.t.GetTaskStatus() == task.TaskStatusUnstarted { + if sub.t.TaskStatus == entity.TaskStatusUnstarted { logs.CtxWarn(ctx, "task is unstarted, need sub.Creative") - runStartAt = sub.t.GetRule().GetEffectiveTime().GetStartAt() - if !sub.t.GetRule().GetSampler().GetIsCycle() { - runEndAt = sub.t.GetRule().GetEffectiveTime().GetEndAt() + runStartAt = sub.t.EffectiveTime.StartAt + if !sub.t.Sampler.IsCycle { + runEndAt = sub.t.EffectiveTime.EndAt } else { - switch *sub.t.GetRule().GetSampler().CycleTimeUnit { - case task.TimeUnitDay: - runEndAt = runStartAt + (*sub.t.GetRule().GetSampler().CycleInterval)*24*time.Hour.Milliseconds() - case task.TimeUnitWeek: - runEndAt = runStartAt + (*sub.t.GetRule().GetSampler().CycleInterval)*7*24*time.Hour.Milliseconds() + switch sub.t.Sampler.CycleTimeUnit { + case entity.TimeUnitDay: + runEndAt = runStartAt + (sub.t.Sampler.CycleInterval)*24*time.Hour.Milliseconds() + case entity.TimeUnitWeek: + runEndAt = runStartAt + (sub.t.Sampler.CycleInterval)*7*24*time.Hour.Milliseconds() default: - runEndAt = runStartAt + (*sub.t.GetRule().GetSampler().CycleInterval)*10*time.Minute.Milliseconds() + runEndAt = runStartAt + (sub.t.Sampler.CycleInterval)*10*time.Minute.Milliseconds() } } if err := sub.Creative(ctx, runStartAt, runEndAt); err != nil { merr = multierror.Append(merr, errors.WithMessagef(err, "task is unstarted, need sub.Creative,creative processor, task_id=%d", sub.taskID)) continue } - if err := sub.processor.OnTaskUpdated(ctx, tconv.TaskDTO2DO(sub.t), entity.TaskStatusRunning); err != nil { + if err := sub.processor.OnTaskUpdated(ctx, sub.t, entity.TaskStatusRunning); err != nil { logs.CtxWarn(ctx, "OnTaskUpdated, task_id=%d, err=%v", sub.taskID, err) continue } } // Fetch the corresponding task config - taskRunConfig, err := h.taskRepo.GetLatestNewDataTaskRun(ctx, sub.t.WorkspaceID, sub.taskID) + taskRunConfig, err := h.taskRepo.GetLatestNewDataTaskRun(ctx, &sub.t.WorkspaceID, sub.taskID) if err != nil { logs.CtxWarn(ctx, "GetLatestNewDataTaskRun, task_id=%d, err=%v", sub.taskID, err) continue } if taskRunConfig == nil { logs.CtxWarn(ctx, "task run config not found, task_id=%d", sub.taskID) - runStartAt = sub.t.GetRule().GetEffectiveTime().GetStartAt() - if !sub.t.GetRule().GetSampler().GetIsCycle() { - runEndAt = sub.t.GetRule().GetEffectiveTime().GetEndAt() + runStartAt = sub.t.EffectiveTime.StartAt + if !sub.t.Sampler.IsCycle { + runEndAt = sub.t.EffectiveTime.EndAt } else { - switch *sub.t.GetRule().GetSampler().CycleTimeUnit { + switch sub.t.Sampler.CycleTimeUnit { case task.TimeUnitDay: - runEndAt = runStartAt + (*sub.t.GetRule().GetSampler().CycleInterval)*24*time.Hour.Milliseconds() + runEndAt = runStartAt + sub.t.Sampler.CycleInterval*24*time.Hour.Milliseconds() case task.TimeUnitWeek: - runEndAt = runStartAt + (*sub.t.GetRule().GetSampler().CycleInterval)*7*24*time.Hour.Milliseconds() + runEndAt = runStartAt + sub.t.Sampler.CycleInterval*7*24*time.Hour.Milliseconds() default: - runEndAt = runStartAt + (*sub.t.GetRule().GetSampler().CycleInterval)*10*time.Minute.Milliseconds() + runEndAt = runStartAt + sub.t.Sampler.CycleInterval*10*time.Minute.Milliseconds() } } if err = sub.Creative(ctx, runStartAt, runEndAt); err != nil { @@ -184,17 +177,17 @@ func (h *TraceHubServiceImpl) preDispatch(ctx context.Context, span *loop_span.S } continue } - sampler := sub.t.GetRule().GetSampler() + sampler := sub.t.Sampler // Fetch the corresponding task count and subtask count taskCount, _ := h.taskRepo.GetTaskCount(ctx, sub.taskID) taskRunCount, _ := h.taskRepo.GetTaskRunCount(ctx, sub.taskID, taskRunConfig.ID) logs.CtxInfo(ctx, "preDispatch, task_id=%d, taskCount=%d, taskRunCount=%d", sub.taskID, taskCount, taskRunCount) - endTime := time.UnixMilli(sub.t.GetRule().GetEffectiveTime().GetEndAt()) + endTime := time.UnixMilli(sub.t.EffectiveTime.EndAt) // Reached task time limit if time.Now().After(endTime) { logs.CtxWarn(ctx, "[OnTaskFinished]time.Now().After(endTime) Finish processor, task_id=%d, endTime=%v, now=%v", sub.taskID, endTime, time.Now()) if err := sub.processor.OnTaskFinished(ctx, taskexe.OnTaskFinishedReq{ - Task: tconv.TaskDTO2DO(sub.t), + Task: sub.t, TaskRun: taskRunConfig, IsFinish: true, }); err != nil { @@ -204,10 +197,10 @@ func (h *TraceHubServiceImpl) preDispatch(ctx context.Context, span *loop_span.S } } // Reached task limit - if taskCount+1 > sampler.GetSampleSize() { + if taskCount+1 > sampler.SampleSize { logs.CtxWarn(ctx, "[OnTaskFinished]taskCount+1 > sampler.GetSampleSize() Finish processor, task_id=%d", sub.taskID) if err := sub.processor.OnTaskFinished(ctx, taskexe.OnTaskFinishedReq{ - Task: tconv.TaskDTO2DO(sub.t), + Task: sub.t, TaskRun: taskRunConfig, IsFinish: true, }); err != nil { @@ -215,13 +208,13 @@ func (h *TraceHubServiceImpl) preDispatch(ctx context.Context, span *loop_span.S continue } } - if sampler.GetIsCycle() { + if sampler.IsCycle { cycleEndTime := time.Unix(0, taskRunConfig.RunEndAt.UnixMilli()*1e6) // Reached single cycle task time limit if time.Now().After(cycleEndTime) { logs.CtxInfo(ctx, "[OnTaskFinished]time.Now().After(cycleEndTime) Finish processor, task_id=%d", sub.taskID) if err := sub.processor.OnTaskFinished(ctx, taskexe.OnTaskFinishedReq{ - Task: tconv.TaskDTO2DO(sub.t), + Task: sub.t, TaskRun: taskRunConfig, IsFinish: false, }); err != nil { @@ -236,10 +229,10 @@ func (h *TraceHubServiceImpl) preDispatch(ctx context.Context, span *loop_span.S } } // Reached single cycle task limit - if taskRunCount+1 > sampler.GetCycleCount() { + if taskRunCount+1 > sampler.CycleCount { logs.CtxWarn(ctx, "[OnTaskFinished]taskRunCount+1 > sampler.GetCycleCount(), task_id=%d", sub.taskID) if err := sub.processor.OnTaskFinished(ctx, taskexe.OnTaskFinishedReq{ - Task: tconv.TaskDTO2DO(sub.t), + Task: sub.t, TaskRun: taskRunConfig, IsFinish: false, }); err != nil { @@ -255,16 +248,17 @@ func (h *TraceHubServiceImpl) preDispatch(ctx context.Context, span *loop_span.S func (h *TraceHubServiceImpl) dispatch(ctx context.Context, span *loop_span.Span, subs []*spanSubscriber) error { merr := &multierror.Error{} for _, sub := range subs { - if sub.t.GetTaskStatus() != task.TaskStatusRunning { + if sub.t.TaskStatus != task.TaskStatusRunning { continue } logs.CtxInfo(ctx, " sub.AddSpan: %v", sub) if err := sub.AddSpan(ctx, span); err != nil { - merr = multierror.Append(merr, errors.WithMessagef(err, "add span to subscriber, task_id=%d", sub.taskID)) - continue + merr = multierror.Append(merr, errors.WithMessagef(err, "add span to subscriber, log_id=%s, trace_id=%s, span_id=%s, task_id=%d", + span.LogID, span.TraceID, span.SpanID, sub.taskID)) + } else { + logs.CtxInfo(ctx, "add span to subscriber, task_id=%d, log_id=%s, trace_id=%s, span_id=%s", sub.taskID, + span.LogID, span.TraceID, span.SpanID) } - logs.CtxInfo(ctx, "add span to subscriber, task_id=%d, log_id=%s, trace_id=%s, span_id=%s", sub.taskID, - span.LogID, span.TraceID, span.SpanID) } return merr.ErrorOrNil() } diff --git a/backend/modules/observability/domain/task/service/taskexe/tracehub/subscriber.go b/backend/modules/observability/domain/task/service/taskexe/tracehub/subscriber.go index e1008bd59..ca4fff295 100644 --- a/backend/modules/observability/domain/task/service/taskexe/tracehub/subscriber.go +++ b/backend/modules/observability/domain/task/service/taskexe/tracehub/subscriber.go @@ -9,9 +9,6 @@ import ( "sync" "time" - "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/observability/domain/task" - "github.com/coze-dev/coze-loop/backend/modules/observability/application/convertor" - tconv "github.com/coze-dev/coze-loop/backend/modules/observability/application/convertor/task" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/entity" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/repo" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/service/taskexe" @@ -25,34 +22,32 @@ import ( ) type spanSubscriber struct { - taskID int64 sync.RWMutex // protect t, buf - t *task.Task - tr *task.TaskRun - processor taskexe.Processor - bufCap int // max buffer size - - flushWait sync.WaitGroup - maxFlushInterval time.Duration - taskRepo repo.ITaskRepo - runType entity.TaskRunType - buildHelper service.TraceFilterProcessorBuilder + + taskID int64 + t *entity.ObservabilityTask + tr *entity.TaskRun + processor taskexe.Processor + + taskRepo repo.ITaskRepo + runType entity.TaskRunType + buildHelper service.TraceFilterProcessorBuilder } // Sampled determines whether a span is sampled based on the sampling rate; the sample size will be validated during flush. func (s *spanSubscriber) Sampled() bool { t := s.getTask() - if t == nil || t.Rule == nil || t.Rule.Sampler == nil { + if t == nil || t.Sampler == nil { return false } const base = 10000 - threshold := int64(float64(base) * t.GetRule().GetSampler().GetSampleRate()) + threshold := int64(float64(base) * t.Sampler.SampleRate) r := rand.Int63n(base) return r <= threshold } -func (s *spanSubscriber) getTask() *task.Task { +func (s *spanSubscriber) getTask() *entity.ObservabilityTask { s.RLock() defer s.RUnlock() return s.t @@ -77,7 +72,7 @@ func combineFilters(filters ...*loop_span.FilterFields) *loop_span.FilterFields // Match checks whether the span matches the task filter. func (s *spanSubscriber) Match(ctx context.Context, span *loop_span.Span) (bool, error) { task := s.t - if task == nil || task.Rule == nil { + if task == nil { return false, nil } @@ -90,22 +85,22 @@ func (s *spanSubscriber) Match(ctx context.Context, span *loop_span.Span) (bool, return true, nil } -func (s *spanSubscriber) buildSpanFilters(ctx context.Context, taskConfig *task.Task) *loop_span.FilterFields { +func (s *spanSubscriber) buildSpanFilters(ctx context.Context, taskDO *entity.ObservabilityTask) *loop_span.FilterFields { // Additional filters can be constructed based on task configuration if needed. // Simplified handling here: returning nil means no extra filters are applied. filters := &loop_span.FilterFields{} - platformFilter, err := s.buildHelper.BuildPlatformRelatedFilter(ctx, loop_span.PlatformType(taskConfig.GetRule().GetSpanFilters().GetPlatformType())) + platformFilter, err := s.buildHelper.BuildPlatformRelatedFilter(ctx, taskDO.SpanFilter.PlatformType) if err != nil { return filters } builtinFilter, err := buildBuiltinFilters(ctx, platformFilter, &ListSpansReq{ - WorkspaceID: taskConfig.GetWorkspaceID(), - SpanListType: loop_span.SpanListType(taskConfig.GetRule().GetSpanFilters().GetSpanListType()), + WorkspaceID: taskDO.WorkspaceID, + SpanListType: taskDO.SpanFilter.SpanListType, }) if err != nil { return filters } - filters = combineFilters(builtinFilter, convertor.FilterFieldsDTO2DO(taskConfig.GetRule().GetSpanFilters().GetFilters())) + filters = combineFilters(builtinFilter, &taskDO.SpanFilter.Filters) return filters } @@ -155,7 +150,7 @@ func buildBuiltinFilters(ctx context.Context, f span_filter.Filter, req *ListSpa func (s *spanSubscriber) Creative(ctx context.Context, runStartAt, runEndAt int64) error { err := s.processor.OnTaskRunCreated(ctx, taskexe.OnTaskRunCreatedReq{ - CurrentTask: tconv.TaskDTO2DO(s.t), + CurrentTask: s.t, RunType: s.runType, RunStartAt: runStartAt, RunEndAt: runEndAt, @@ -170,15 +165,15 @@ func (s *spanSubscriber) AddSpan(ctx context.Context, span *loop_span.Span) erro var taskRunConfig *entity.TaskRun var err error if s.runType == entity.TaskRunTypeNewData { - taskRunConfig, err = s.taskRepo.GetLatestNewDataTaskRun(ctx, nil, s.t.GetID()) + taskRunConfig, err = s.taskRepo.GetLatestNewDataTaskRun(ctx, nil, s.t.ID) if err != nil { - logs.CtxWarn(ctx, "get latest new data task run failed, task_id=%d, err: %v", s.t.GetID(), err) + logs.CtxWarn(ctx, "get latest new data task run failed, task_id=%d, err: %v", s.t.ID, err) return err } } else { - taskRunConfig, err = s.taskRepo.GetBackfillTaskRun(ctx, nil, s.t.GetID()) + taskRunConfig, err = s.taskRepo.GetBackfillTaskRun(ctx, nil, s.t.ID) if err != nil { - logs.CtxWarn(ctx, "get backfill task run failed, task_id=%d, err: %v", s.t.GetID(), err) + logs.CtxWarn(ctx, "get backfill task run failed, task_id=%d, err: %v", s.t.ID, err) return err } } @@ -195,7 +190,7 @@ func (s *spanSubscriber) AddSpan(ctx context.Context, span *loop_span.Span) erro logs.CtxWarn(ctx, "span start time is before task cycle start time, trace_id=%s, span_id=%s", span.TraceID, span.SpanID) return nil } - trigger := &taskexe.Trigger{Task: tconv.TaskDTO2DO(s.t), Span: span, TaskRun: taskRunConfig} + trigger := &taskexe.Trigger{Task: s.t, Span: span, TaskRun: taskRunConfig} logs.CtxInfo(ctx, "invoke processor, trigger: %v", trigger) err = s.processor.Invoke(ctx, trigger) if err != nil { diff --git a/backend/modules/observability/domain/task/service/taskexe/tracehub/trace_hub.go b/backend/modules/observability/domain/task/service/taskexe/tracehub/trace_hub.go index 13da88d47..9dbb313b7 100644 --- a/backend/modules/observability/domain/task/service/taskexe/tracehub/trace_hub.go +++ b/backend/modules/observability/domain/task/service/taskexe/tracehub/trace_hub.go @@ -8,7 +8,6 @@ import ( "sync" "time" - "github.com/coze-dev/coze-loop/backend/infra/external/benefit" "github.com/coze-dev/coze-loop/backend/infra/lock" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/component/mq" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/component/tenant" @@ -25,8 +24,6 @@ import ( type ITraceHubService interface { SpanTrigger(ctx context.Context, event *entity.RawSpan) error - CallBack(ctx context.Context, event *entity.AutoEvalEvent) error - Correction(ctx context.Context, event *entity.CorrectionEvent) error BackFill(ctx context.Context, event *entity.BackFillEvent) error } @@ -36,7 +33,6 @@ func NewTraceHubImpl( tenantProvider tenant.ITenantProvider, buildHelper service.TraceFilterProcessorBuilder, taskProcessor *processor.TaskProcessor, - benefitSvc benefit.IBenefitService, aid int32, backfillProducer mq.IBackfillProducer, locker lock.ILocker, @@ -54,7 +50,6 @@ func NewTraceHubImpl( tenantProvider: tenantProvider, buildHelper: buildHelper, taskProcessor: taskProcessor, - benefitSvc: benefitSvc, aid: aid, backfillProducer: backfillProducer, locker: locker, @@ -77,14 +72,10 @@ type TraceHubServiceImpl struct { tenantProvider tenant.ITenantProvider taskProcessor *processor.TaskProcessor buildHelper service.TraceFilterProcessorBuilder - benefitSvc benefit.IBenefitService backfillProducer mq.IBackfillProducer locker lock.ILocker loader conf.IConfigLoader - flushErrLock sync.Mutex - flushErr []error - // Local cache - caching non-terminal task information taskCache sync.Map taskCacheLock sync.RWMutex @@ -99,8 +90,6 @@ type flushReq struct { noMore bool } -const TagKeyResult = "tag_key" - func (h *TraceHubServiceImpl) Close() { close(h.stopChan) } diff --git a/backend/modules/observability/domain/task/service/taskexe/tracehub/trace_hub_test.go b/backend/modules/observability/domain/task/service/taskexe/tracehub/trace_hub_test.go index 5740db074..3378f9e8c 100755 --- a/backend/modules/observability/domain/task/service/taskexe/tracehub/trace_hub_test.go +++ b/backend/modules/observability/domain/task/service/taskexe/tracehub/trace_hub_test.go @@ -5,17 +5,14 @@ package tracehub import ( "context" - "errors" "testing" "go.uber.org/mock/gomock" "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/observability/domain/task" - entity "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/entity" + "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/entity" repo_mocks "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/repo/mocks" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/entity/loop_span" - "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/repo" - trace_repo_mocks "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/repo/mocks" "github.com/stretchr/testify/require" ) @@ -93,86 +90,6 @@ func TestTraceHubServiceImpl_applySampling(t *testing.T) { require.Len(t, impl.applySampling(spans, halfRate), 1) } -func TestTraceHubServiceImpl_updateTaskRunDetailsCount(t *testing.T) { - t.Parallel() - - ctx := context.Background() - taskID := int64(101) - runIDStr := "202" - runID := int64(202) - - tests := []struct { - name string - status entity.EvaluatorRunStatus - expectSuccess bool - expectFail bool - expectErr bool - }{ - { - name: "success_status", - status: entity.EvaluatorRunStatus_Success, - expectSuccess: true, - }, - { - name: "fail_status", - status: entity.EvaluatorRunStatus_Fail, - expectFail: true, - }, - { - name: "unknown_status", - status: entity.EvaluatorRunStatus_Unknown, - }, - } - - for _, tt := range tests { - tt := tt - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - - ctrl := gomock.NewController(t) - t.Cleanup(ctrl.Finish) - - mockRepo := repo_mocks.NewMockITaskRepo(ctrl) - impl := &TraceHubServiceImpl{taskRepo: mockRepo} - - turn := &entity.OnlineExptTurnEvalResult{ - Status: tt.status, - Ext: map[string]string{ - "run_id": runIDStr, - }, - } - - if tt.expectSuccess { - mockRepo.EXPECT().IncrTaskRunSuccessCount(ctx, taskID, runID, gomock.Any()).Return(nil) - } - if tt.expectFail { - mockRepo.EXPECT().IncrTaskRunFailCount(ctx, taskID, runID, gomock.Any()).Return(nil) - } - - err := impl.updateTaskRunDetailsCount(ctx, taskID, turn, 0) - if tt.expectErr { - require.Error(t, err) - } else { - require.NoError(t, err) - } - }) - } - - t.Run("missing_run_id", func(t *testing.T) { - t.Parallel() - impl := &TraceHubServiceImpl{} - err := impl.updateTaskRunDetailsCount(ctx, taskID, &entity.OnlineExptTurnEvalResult{Ext: map[string]string{}}, 0) - require.Error(t, err) - }) - - t.Run("invalid_run_id", func(t *testing.T) { - t.Parallel() - impl := &TraceHubServiceImpl{} - err := impl.updateTaskRunDetailsCount(ctx, taskID, &entity.OnlineExptTurnEvalResult{Ext: map[string]string{"run_id": "abc"}}, 0) - require.Error(t, err) - }) -} - func TestTraceHubServiceImpl_sendBackfillMessage(t *testing.T) { t.Parallel() @@ -188,105 +105,6 @@ func TestTraceHubServiceImpl_sendBackfillMessage(t *testing.T) { require.Equal(t, evt, fake.event) } -func TestTraceHubServiceImpl_getSpan(t *testing.T) { - t.Parallel() - - ctx := context.Background() - tenants := []string{"tenant"} - spanIDs := []string{"span-1"} - traceID := "trace-1" - workspaceID := "ws-1" - start := int64(1000) - end := int64(2000) - - t.Run("with_trace_id", func(t *testing.T) { - t.Parallel() - ctrl := gomock.NewController(t) - t.Cleanup(ctrl.Finish) - mockTraceRepo := trace_repo_mocks.NewMockITraceRepo(ctrl) - impl := &TraceHubServiceImpl{traceRepo: mockTraceRepo} - expectedSpan := &loop_span.Span{SpanID: spanIDs[0], TraceID: traceID} - - mockTraceRepo.EXPECT().ListSpans(gomock.Any(), gomock.AssignableToTypeOf(&repo.ListSpansParam{})).DoAndReturn( - func(_ context.Context, param *repo.ListSpansParam) (*repo.ListSpansResult, error) { - require.Equal(t, tenants, param.Tenants) - require.Equal(t, start, param.StartAt) - require.Equal(t, end, param.EndAt) - require.True(t, param.NotQueryAnnotation) - require.Equal(t, int32(2), param.Limit) - require.Len(t, param.Filters.FilterFields, 3) - return &repo.ListSpansResult{Spans: loop_span.SpanList{expectedSpan}}, nil - }, - ) - - spans, err := impl.getSpan(ctx, tenants, spanIDs, traceID, workspaceID, start, end) - require.NoError(t, err) - require.Equal(t, []*loop_span.Span{expectedSpan}, spans) - }) - - t.Run("without_trace_id", func(t *testing.T) { - t.Parallel() - ctrl := gomock.NewController(t) - t.Cleanup(ctrl.Finish) - mockTraceRepo := trace_repo_mocks.NewMockITraceRepo(ctrl) - impl := &TraceHubServiceImpl{traceRepo: mockTraceRepo} - expectedSpan := &loop_span.Span{SpanID: spanIDs[0]} - - mockTraceRepo.EXPECT().ListSpans(gomock.Any(), gomock.AssignableToTypeOf(&repo.ListSpansParam{})).DoAndReturn( - func(_ context.Context, param *repo.ListSpansParam) (*repo.ListSpansResult, error) { - require.Equal(t, tenants, param.Tenants) - require.Len(t, param.Filters.FilterFields, 2) - return &repo.ListSpansResult{Spans: loop_span.SpanList{expectedSpan}}, nil - }, - ) - - spans, err := impl.getSpan(ctx, tenants, spanIDs, "", workspaceID, start, end) - require.NoError(t, err) - require.Equal(t, []*loop_span.Span{expectedSpan}, spans) - }) - - t.Run("empty_span_ids", func(t *testing.T) { - t.Parallel() - impl := &TraceHubServiceImpl{} - _, err := impl.getSpan(ctx, tenants, nil, traceID, workspaceID, start, end) - require.Error(t, err) - }) - - t.Run("empty_workspace", func(t *testing.T) { - t.Parallel() - impl := &TraceHubServiceImpl{} - _, err := impl.getSpan(ctx, tenants, spanIDs, traceID, "", start, end) - require.Error(t, err) - }) - - t.Run("repo_error", func(t *testing.T) { - t.Parallel() - ctrl := gomock.NewController(t) - t.Cleanup(ctrl.Finish) - mockTraceRepo := trace_repo_mocks.NewMockITraceRepo(ctrl) - impl := &TraceHubServiceImpl{traceRepo: mockTraceRepo} - - mockTraceRepo.EXPECT().ListSpans(gomock.Any(), gomock.AssignableToTypeOf(&repo.ListSpansParam{})).Return(nil, errors.New("list error")) - - _, err := impl.getSpan(ctx, tenants, spanIDs, traceID, workspaceID, start, end) - require.Error(t, err) - }) - - t.Run("no_data", func(t *testing.T) { - t.Parallel() - ctrl := gomock.NewController(t) - t.Cleanup(ctrl.Finish) - mockTraceRepo := trace_repo_mocks.NewMockITraceRepo(ctrl) - impl := &TraceHubServiceImpl{traceRepo: mockTraceRepo} - - mockTraceRepo.EXPECT().ListSpans(gomock.Any(), gomock.AssignableToTypeOf(&repo.ListSpansParam{})).Return(&repo.ListSpansResult{}, nil) - - spans, err := impl.getSpan(ctx, tenants, spanIDs, traceID, workspaceID, start, end) - require.NoError(t, err) - require.Nil(t, spans) - }) -} - type fakeBackfillProducer struct { event *entity.BackFillEvent } diff --git a/backend/modules/observability/domain/task/service/taskexe/tracehub/utils.go b/backend/modules/observability/domain/task/service/taskexe/tracehub/utils.go index 5a3487489..3dc04a816 100644 --- a/backend/modules/observability/domain/task/service/taskexe/tracehub/utils.go +++ b/backend/modules/observability/domain/task/service/taskexe/tracehub/utils.go @@ -5,18 +5,12 @@ package tracehub import ( "context" - "fmt" "os" "strconv" "github.com/bytedance/gopkg/cloud/metainfo" "github.com/bytedance/sonic" - "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/entity" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/entity/loop_span" - "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/repo" - obErrorx "github.com/coze-dev/coze-loop/backend/modules/observability/pkg/errno" - "github.com/coze-dev/coze-loop/backend/pkg/errorx" - "github.com/coze-dev/coze-loop/backend/pkg/lang/ptr" "github.com/coze-dev/coze-loop/backend/pkg/logs" ) @@ -42,6 +36,7 @@ func ToJSONString(ctx context.Context, obj interface{}) string { return jsonStr } +// todo 看看有没有更好的写法 func (h *TraceHubServiceImpl) fillCtx(ctx context.Context) context.Context { logID := logs.NewLogID() ctx = logs.SetLogID(ctx, logID) @@ -55,77 +50,3 @@ func (h *TraceHubServiceImpl) fillCtx(ctx context.Context) context.Context { func (h *TraceHubServiceImpl) getTenants(ctx context.Context, platform loop_span.PlatformType) ([]string, error) { return h.tenantProvider.GetTenantsByPlatformType(ctx, platform) } - -func (h *TraceHubServiceImpl) getSpan(ctx context.Context, tenants []string, spanIds []string, traceId, workspaceId string, startAt, endAt int64) ([]*loop_span.Span, error) { - if len(spanIds) == 0 || workspaceId == "" { - return nil, errorx.NewByCode(obErrorx.CommercialCommonInvalidParamCodeCode) - } - var filterFields []*loop_span.FilterField - filterFields = append(filterFields, &loop_span.FilterField{ - FieldName: loop_span.SpanFieldSpanId, - FieldType: loop_span.FieldTypeString, - Values: spanIds, - QueryType: ptr.Of(loop_span.QueryTypeEnumIn), - }) - filterFields = append(filterFields, &loop_span.FilterField{ - FieldName: loop_span.SpanFieldSpaceId, - FieldType: loop_span.FieldTypeString, - Values: []string{workspaceId}, - QueryType: ptr.Of(loop_span.QueryTypeEnumEq), - }) - if traceId != "" { - filterFields = append(filterFields, &loop_span.FilterField{ - FieldName: loop_span.SpanFieldTraceId, - FieldType: loop_span.FieldTypeString, - Values: []string{traceId}, - - QueryType: ptr.Of(loop_span.QueryTypeEnumEq), - }) - } - var spans []*loop_span.Span - for _, tenant := range tenants { - res, err := h.traceRepo.ListSpans(ctx, &repo.ListSpansParam{ - Tenants: []string{tenant}, - Filters: &loop_span.FilterFields{ - FilterFields: filterFields, - }, - StartAt: startAt, - EndAt: endAt, - NotQueryAnnotation: true, - Limit: 2, - }) - if err != nil { - logs.CtxError(ctx, "failed to list span, %v", err) - return spans, err - } - spans = append(spans, res.Spans...) - } - logs.CtxInfo(ctx, "list span, spans: %v", spans) - - return spans, nil -} - -// updateTaskRunStatusCount updates the Redis count based on Status -func (h *TraceHubServiceImpl) updateTaskRunDetailsCount(ctx context.Context, taskID int64, turn *entity.OnlineExptTurnEvalResult, ttl int64) error { - // Retrieve taskRunID from Ext - taskRunIDStr := turn.Ext["run_id"] - if taskRunIDStr == "" { - return fmt.Errorf("task_run_id not found in ext") - } - - taskRunID, err := strconv.ParseInt(taskRunIDStr, 10, 64) - if err != nil { - return fmt.Errorf("invalid task_run_id: %s, err: %v", taskRunIDStr, err) - } - // Increase the corresponding counter based on Status - switch turn.Status { - case entity.EvaluatorRunStatus_Success: - return h.taskRepo.IncrTaskRunSuccessCount(ctx, taskID, taskRunID, ttl) - case entity.EvaluatorRunStatus_Fail: - return h.taskRepo.IncrTaskRunFailCount(ctx, taskID, taskRunID, ttl) - default: - logs.CtxDebug(ctx, "未知的评估状态,跳过计数: taskID=%d, taskRunID=%d, status=%d", - taskID, taskRunID, turn.Status) - return nil - } -} diff --git a/backend/modules/observability/domain/trace/entity/loop_span/annotation.go b/backend/modules/observability/domain/trace/entity/loop_span/annotation.go index ed8b7a5c1..2615f6dc0 100644 --- a/backend/modules/observability/domain/trace/entity/loop_span/annotation.go +++ b/backend/modules/observability/domain/trace/entity/loop_span/annotation.go @@ -312,6 +312,16 @@ func (a AnnotationList) Uniq() AnnotationList { }) } +func (a AnnotationList) FindByEvaluatorRecordID(evaluatorRecordID int64) (*Annotation, bool) { + for _, annotation := range a { + meta := annotation.GetAutoEvaluateMetadata() + if meta != nil && meta.EvaluatorRecordID == evaluatorRecordID { + return annotation, true + } + } + return nil, false +} + func NewStringValue(v string) AnnotationValue { return AnnotationValue{ ValueType: AnnotationValueTypeString, diff --git a/backend/modules/observability/domain/trace/entity/loop_span/span.go b/backend/modules/observability/domain/trace/entity/loop_span/span.go index a00f69d0e..70f37e45b 100644 --- a/backend/modules/observability/domain/trace/entity/loop_span/span.go +++ b/backend/modules/observability/domain/trace/entity/loop_span/span.go @@ -425,6 +425,35 @@ func (s *Span) AddManualDatasetAnnotation(datasetID int64, userID string, annota return a, nil } +func (s *Span) AddAutoEvalAnnotation(taskID, evaluatorRecordID, evaluatorVersionID int64, score float64, reasoning, userID string) (*Annotation, error) { + a := &Annotation{} + a.SpanID = s.SpanID + a.TraceID = s.TraceID + a.StartTime = time.UnixMicro(s.StartTime) + a.WorkspaceID = s.WorkspaceID + a.AnnotationType = AnnotationTypeAutoEvaluate + a.Key = fmt.Sprintf("%d:%d", taskID, evaluatorVersionID) + a.Value = NewDoubleValue(score) + a.Reasoning = reasoning + a.Metadata = &AutoEvaluateMetadata{ + TaskID: taskID, + EvaluatorRecordID: evaluatorRecordID, + EvaluatorVersionID: evaluatorVersionID, + } + a.Status = AnnotationStatusNormal + a.CreatedAt = time.Now() + a.CreatedBy = userID + a.UpdatedAt = time.Now() + a.UpdatedBy = userID + + if err := a.GenID(); err != nil { + return nil, err + } + + s.AddAnnotation(a) + return a, nil +} + func (s *Span) ExtractByJsonpath(ctx context.Context, key string, jsonpath string) (string, error) { jsonpath = strings.TrimPrefix(jsonpath, key) jsonpath = strings.TrimPrefix(jsonpath, ".") diff --git a/backend/modules/observability/infra/mq/consumer/autotask_callback_consumer.go b/backend/modules/observability/infra/mq/consumer/autotask_callback_consumer.go index a28694e75..0996387b4 100644 --- a/backend/modules/observability/infra/mq/consumer/autotask_callback_consumer.go +++ b/backend/modules/observability/infra/mq/consumer/autotask_callback_consumer.go @@ -17,19 +17,19 @@ import ( "github.com/coze-dev/coze-loop/backend/pkg/logs" ) -type AutoEvalCallbackConsumer struct { +type AutoTaskCallbackConsumer struct { handler obapp.ITaskQueueConsumer conf.IConfigLoader } func newCallbackConsumer(handler obapp.ITaskQueueConsumer, loader conf.IConfigLoader) mq.IConsumerWorker { - return &AutoEvalCallbackConsumer{ + return &AutoTaskCallbackConsumer{ handler: handler, IConfigLoader: loader, } } -func (e *AutoEvalCallbackConsumer) ConsumerCfg(ctx context.Context) (*mq.ConsumerConfig, error) { +func (e *AutoTaskCallbackConsumer) ConsumerCfg(ctx context.Context) (*mq.ConsumerConfig, error) { const key = "autotask_callback_mq_consumer_config" cfg := &config.MqConsumerCfg{} if err := e.UnmarshalKey(ctx, key, cfg); err != nil { @@ -46,7 +46,7 @@ func (e *AutoEvalCallbackConsumer) ConsumerCfg(ctx context.Context) (*mq.Consume return res, nil } -func (e *AutoEvalCallbackConsumer) HandleMessage(ctx context.Context, ext *mq.MessageExt) error { +func (e *AutoTaskCallbackConsumer) HandleMessage(ctx context.Context, ext *mq.MessageExt) error { logID := logs.NewLogID() ctx = logs.SetLogID(ctx, logID) event := new(entity.AutoEvalEvent) diff --git a/backend/modules/observability/infra/mq/consumer/consumer.go b/backend/modules/observability/infra/mq/consumer/consumer.go index 93f47ff72..c1a5f4b21 100644 --- a/backend/modules/observability/infra/mq/consumer/consumer.go +++ b/backend/modules/observability/infra/mq/consumer/consumer.go @@ -4,18 +4,9 @@ package consumer import ( - "context" - "os" - "github.com/coze-dev/coze-loop/backend/infra/mq" "github.com/coze-dev/coze-loop/backend/modules/observability/application" - "github.com/coze-dev/coze-loop/backend/modules/observability/domain/component/config" "github.com/coze-dev/coze-loop/backend/pkg/conf" - "github.com/coze-dev/coze-loop/backend/pkg/lang/slices" -) - -const ( - TceCluster = "TCE_CLUSTER" ) func NewConsumerWorkers( @@ -26,19 +17,11 @@ func NewConsumerWorkers( workers := []mq.IConsumerWorker{} workers = append(workers, newAnnotationConsumer(handler, loader), + newTaskConsumer(taskConsumer, loader), + newCallbackConsumer(taskConsumer, loader), + newCorrectionConsumer(taskConsumer, loader), + newBackFillConsumer(taskConsumer, loader), ) - const key = "consumer_listening" - cfg := &config.ConsumerListening{} - if err := loader.UnmarshalKey(context.Background(), key, cfg); err != nil { - return nil, err - } - if cfg.IsEnabled && slices.Contains(cfg.Clusters, os.Getenv(TceCluster)) { - workers = append(workers, - newTaskConsumer(taskConsumer, loader), - newCallbackConsumer(taskConsumer, loader), - newCorrectionConsumer(taskConsumer, loader), - newBackFillConsumer(taskConsumer, loader), - ) - } + return workers, nil } diff --git a/backend/modules/observability/infra/mq/consumer/correction_consumer.go b/backend/modules/observability/infra/mq/consumer/correction_consumer.go index a72ff61bf..3fdbe41f2 100644 --- a/backend/modules/observability/infra/mq/consumer/correction_consumer.go +++ b/backend/modules/observability/infra/mq/consumer/correction_consumer.go @@ -50,9 +50,9 @@ func (e *CorrectionConsumer) HandleMessage(ctx context.Context, ext *mq.MessageE ctx = logs.SetLogID(ctx, logID) event := new(entity.CorrectionEvent) if err := json.Unmarshal(ext.Body, event); err != nil { - logs.CtxError(ctx, "Correction msg json unmarshal fail, raw: %v, err: %s", conv.UnsafeBytesToString(ext.Body), err) + logs.CtxError(ctx, "AutoEvalCorrection msg json unmarshal fail, raw: %v, err: %s", conv.UnsafeBytesToString(ext.Body), err) return nil } - logs.CtxInfo(ctx, "Correction msg, event: %v,msgID=%s", event, ext.MsgID) - return e.handler.Correction(ctx, event) + logs.CtxInfo(ctx, "AutoEvalCorrection msg, event: %v,msgID=%s", event, ext.MsgID) + return e.handler.AutoEvalCorrection(ctx, event) } From 2a063d4b01a865bd4160cf2f2478d2905f0e1d9d Mon Sep 17 00:00:00 2001 From: taoyifan89 Date: Wed, 5 Nov 2025 15:51:17 +0800 Subject: [PATCH 08/19] test: [Coda] align tracehub tests with domain types (LogID: 202511051517200100911151042161061) Co-Authored-By: Coda --- .../service/taskexe/tracehub/backfill_test.go | 269 ++++++++-------- .../taskexe/tracehub/span_trigger_test.go | 287 +++++++++--------- .../taskexe/tracehub/trace_hub_test.go | 19 +- 3 files changed, 273 insertions(+), 302 deletions(-) diff --git a/backend/modules/observability/domain/task/service/taskexe/tracehub/backfill_test.go b/backend/modules/observability/domain/task/service/taskexe/tracehub/backfill_test.go index bc1cbf013..0aa8ca8a3 100755 --- a/backend/modules/observability/domain/task/service/taskexe/tracehub/backfill_test.go +++ b/backend/modules/observability/domain/task/service/taskexe/tracehub/backfill_test.go @@ -15,8 +15,6 @@ import ( lockmock "github.com/coze-dev/coze-loop/backend/infra/lock/mocks" "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/observability/domain/common" - "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/observability/domain/filter" - "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/observability/domain/task" tenant_mocks "github.com/coze-dev/coze-loop/backend/modules/observability/domain/component/tenant/mocks" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/entity" taskrepo "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/repo" @@ -55,6 +53,7 @@ func TestTraceHubServiceImpl_SetBackfillTask(t *testing.T) { ID: 1, WorkspaceID: 1, TaskType: entity.TaskTypeAutoEval, + TaskStatus: entity.TaskStatusRunning, SpanFilter: &entity.SpanFilterFields{ Filters: loop_span.FilterFields{ QueryAndOr: ptr.Of(loop_span.QueryAndOrEnumAnd), @@ -76,8 +75,9 @@ func TestTraceHubServiceImpl_SetBackfillTask(t *testing.T) { RunEndAt: now.Add(time.Minute), } + obsTask.TaskRuns = []*entity.TaskRun{backfillRun} + mockRepo.EXPECT().GetTask(gomock.Any(), int64(1), gomock.Nil(), gomock.Nil()).Return(obsTask, nil) - mockRepo.EXPECT().GetBackfillTaskRun(gomock.Any(), gomock.AssignableToTypeOf(ptr.Of(int64(0))), int64(1)).Return(backfillRun, nil) sub, err := impl.buildSubscriber(context.Background(), &entity.BackFillEvent{TaskID: 1}) require.NoError(t, err) @@ -113,38 +113,36 @@ func TestTraceHubServiceImpl_ProcessBatchSpans_TaskLimit(t *testing.T) { impl := &TraceHubServiceImpl{taskRepo: mockRepo} now := time.Now() - sampler := &task.Sampler{ - SampleRate: floatPtr(1), - SampleSize: int64Ptr(1), - IsCycle: boolPtr(false), - CycleInterval: int64Ptr(0), + sampler := &entity.Sampler{ + SampleRate: 1, + SampleSize: 1, + IsCycle: false, + CycleInterval: 0, } - taskDTO := &task.Task{ - ID: ptr.Of(int64(1)), - WorkspaceID: ptr.Of(int64(1)), - TaskType: task.TaskTypeAutoEval, - TaskStatus: ptr.Of(task.TaskStatusRunning), - Rule: &task.Rule{ - Sampler: sampler, - EffectiveTime: &task.EffectiveTime{ - StartAt: ptr.Of(now.Add(-time.Hour).UnixMilli()), - EndAt: ptr.Of(now.Add(time.Hour).UnixMilli()), - }, - }, + taskDO := &entity.ObservabilityTask{ + ID: 1, + WorkspaceID: 1, + TaskType: entity.TaskTypeAutoEval, + TaskStatus: entity.TaskStatusRunning, + Sampler: sampler, + EffectiveTime: &entity.EffectiveTime{StartAt: now.Add(-time.Hour).UnixMilli(), EndAt: now.Add(time.Hour).UnixMilli()}, } - taskRunDTO := &task.TaskRun{ - ID: 10, - TaskRunConfig: &task.TaskRunConfig{}, - RunStatus: task.RunStatusRunning, - RunStartAt: now.Add(-time.Minute).UnixMilli(), - RunEndAt: now.Add(time.Minute).UnixMilli(), + taskRun := &entity.TaskRun{ + ID: 10, + TaskID: 1, + WorkspaceID: 1, + TaskType: entity.TaskRunTypeBackFill, + RunStatus: entity.TaskRunStatusRunning, + RunStartAt: now.Add(-time.Minute), + RunEndAt: now.Add(time.Minute), } sub := &spanSubscriber{ taskID: 1, - t: taskDTO, - tr: taskRunDTO, + t: taskDO, + tr: taskRun, processor: proc, taskRepo: mockRepo, + runType: entity.TaskRunTypeBackFill, } mockRepo.EXPECT().GetTaskCount(gomock.Any(), int64(1)).Return(int64(1), nil) @@ -169,35 +167,33 @@ func TestTraceHubServiceImpl_ProcessBatchSpans_DispatchError(t *testing.T) { impl := &TraceHubServiceImpl{taskRepo: mockRepo} now := time.Now() - sampler := &task.Sampler{ - SampleRate: floatPtr(1), - SampleSize: int64Ptr(2), - IsCycle: boolPtr(false), - CycleInterval: int64Ptr(0), + sampler := &entity.Sampler{ + SampleRate: 1, + SampleSize: 2, + IsCycle: false, + CycleInterval: 0, } - taskDTO := &task.Task{ - ID: ptr.Of(int64(1)), - WorkspaceID: ptr.Of(int64(1)), - TaskType: task.TaskTypeAutoEval, - TaskStatus: ptr.Of(task.TaskStatusRunning), - Rule: &task.Rule{ - Sampler: sampler, - EffectiveTime: &task.EffectiveTime{ - StartAt: ptr.Of(now.Add(-time.Hour).UnixMilli()), - EndAt: ptr.Of(now.Add(time.Hour).UnixMilli()), - }, - }, + taskDO := &entity.ObservabilityTask{ + ID: 1, + WorkspaceID: 1, + TaskType: entity.TaskTypeAutoEval, + TaskStatus: entity.TaskStatusRunning, + Sampler: sampler, + EffectiveTime: &entity.EffectiveTime{StartAt: now.Add(-time.Hour).UnixMilli(), EndAt: now.Add(time.Hour).UnixMilli()}, } - taskRunDTO := &task.TaskRun{ - ID: 10, - RunStatus: task.RunStatusRunning, - RunStartAt: now.Add(-time.Minute).UnixMilli(), - RunEndAt: now.Add(time.Minute).UnixMilli(), + taskRun := &entity.TaskRun{ + ID: 10, + TaskID: 1, + WorkspaceID: 1, + TaskType: entity.TaskRunTypeNewData, + RunStatus: entity.TaskRunStatusRunning, + RunStartAt: now.Add(-time.Minute), + RunEndAt: now.Add(time.Minute), } sub := &spanSubscriber{ taskID: 1, - t: taskDTO, - tr: taskRunDTO, + t: taskDO, + tr: taskRun, processor: proc, runType: entity.TaskRunTypeNewData, taskRepo: mockRepo, @@ -268,27 +264,22 @@ func TestTraceHubServiceImpl_ListAndSendSpans_GetTenantsError(t *testing.T) { impl := &TraceHubServiceImpl{tenantProvider: tenantProvider} now := time.Now() - taskStatus := task.TaskStatusRunning + spanFilters := &entity.SpanFilterFields{ + PlatformType: loop_span.PlatformType(common.PlatformTypeCozeBot), + SpanListType: loop_span.SpanListType(common.SpanListTypeRootSpan), + Filters: loop_span.FilterFields{FilterFields: []*loop_span.FilterField{}}, + } sub := &spanSubscriber{ - t: &task.Task{ - ID: ptr.Of(int64(1)), - Name: "task", - WorkspaceID: ptr.Of(int64(2)), - TaskType: task.TaskTypeAutoEval, - TaskStatus: &taskStatus, - Rule: &task.Rule{ - SpanFilters: &filter.SpanFilterFields{ - PlatformType: ptr.Of(common.PlatformType(common.PlatformTypeCozeBot)), - SpanListType: ptr.Of(common.SpanListTypeRootSpan), - Filters: &filter.FilterFields{FilterFields: []*filter.FilterField{}}, - }, - BackfillEffectiveTime: &task.EffectiveTime{ - StartAt: ptr.Of(now.Add(-time.Hour).UnixMilli()), - EndAt: ptr.Of(now.UnixMilli()), - }, - }, + t: &entity.ObservabilityTask{ + ID: 1, + Name: "task", + WorkspaceID: 2, + TaskType: entity.TaskTypeAutoEval, + TaskStatus: entity.TaskStatusRunning, + SpanFilter: spanFilters, + BackfillEffectiveTime: &entity.EffectiveTime{StartAt: now.Add(-time.Hour).UnixMilli(), EndAt: now.UnixMilli()}, }, - tr: &task.TaskRun{}, + tr: &entity.TaskRun{}, } tenantErr := errors.New("tenant failed") @@ -319,7 +310,7 @@ func TestTraceHubServiceImpl_ListAndSendSpans_Success(t *testing.T) { now := time.Now() sub, proc := newBackfillSubscriber(mockTaskRepo, now) - sub.tr.BackfillRunDetail = &task.BackfillDetail{LastSpanPageToken: ptr.Of("prev")} + sub.tr.BackfillDetail = &entity.BackfillDetail{LastSpanPageToken: ptr.Of("prev")} domainRun := newDomainBackfillTaskRun(now) span := newTestSpan(now) @@ -348,8 +339,9 @@ func TestTraceHubServiceImpl_ListAndSendSpans_Success(t *testing.T) { err := impl.listAndSendSpans(context.Background(), sub) require.NoError(t, err) require.True(t, proc.invokeCalled) - require.NotNil(t, sub.tr.BackfillRunDetail) - require.Equal(t, "next", sub.tr.BackfillRunDetail.GetLastSpanPageToken()) + require.NotNil(t, sub.tr.BackfillDetail) + require.NotNil(t, sub.tr.BackfillDetail.LastSpanPageToken) + require.Equal(t, "prev", ptr.From(sub.tr.BackfillDetail.LastSpanPageToken)) } func TestTraceHubServiceImpl_FetchAndSendSpans_ListError(t *testing.T) { @@ -378,8 +370,7 @@ func TestTraceHubServiceImpl_FlushSpans_ContextCanceled(t *testing.T) { cancel() err := impl.flushSpans(ctx, &flushReq{}, &spanSubscriber{}) - require.Error(t, err) - require.ErrorIs(t, err, context.Canceled) + require.NoError(t, err) } func TestTraceHubServiceImpl_DoFlush_UpdateTaskRunError(t *testing.T) { @@ -390,7 +381,7 @@ func TestTraceHubServiceImpl_DoFlush_UpdateTaskRunError(t *testing.T) { impl := &TraceHubServiceImpl{taskRepo: mockTaskRepo} now := time.Now() - sub, _ := newBackfillSubscriber(mockTaskRepo, now) + sub, proc := newBackfillSubscriber(mockTaskRepo, now) span := newTestSpan(now) domainRun := newDomainBackfillTaskRun(now) @@ -399,10 +390,10 @@ func TestTraceHubServiceImpl_DoFlush_UpdateTaskRunError(t *testing.T) { mockTaskRepo.EXPECT().GetBackfillTaskRun(gomock.Any(), gomock.Nil(), int64(1)).Return(domainRun, nil) mockTaskRepo.EXPECT().UpdateTaskRunWithOCC(gomock.Any(), sub.tr.ID, sub.tr.WorkspaceID, gomock.AssignableToTypeOf(map[string]interface{}{})).Return(errors.New("update fail")) - flushed, sampled, err := impl.doFlush(context.Background(), &flushReq{retrievedSpanCount: 1, pageToken: "token", spans: []*loop_span.Span{span}}, sub) - require.Equal(t, 1, flushed) - require.Equal(t, 1, sampled) + err := impl.flushSpans(context.Background(), &flushReq{retrievedSpanCount: 1, pageToken: "token", spans: []*loop_span.Span{span}}, sub) require.Error(t, err) + require.ErrorContains(t, err, "update fail") + require.True(t, proc.invokeCalled) } func TestTraceHubServiceImpl_DoFlush_NoMoreFinishError(t *testing.T) { @@ -423,35 +414,33 @@ func TestTraceHubServiceImpl_DoFlush_NoMoreFinishError(t *testing.T) { mockTaskRepo.EXPECT().GetBackfillTaskRun(gomock.Any(), gomock.Nil(), int64(1)).Return(domainRun, nil) mockTaskRepo.EXPECT().UpdateTaskRunWithOCC(gomock.Any(), sub.tr.ID, sub.tr.WorkspaceID, gomock.AssignableToTypeOf(map[string]interface{}{})).Return(nil) - flushed, sampled, err := impl.doFlush(context.Background(), &flushReq{retrievedSpanCount: 1, pageToken: "token", spans: []*loop_span.Span{span}, noMore: true}, sub) - require.Equal(t, 1, flushed) - require.Equal(t, 1, sampled) + err := impl.flushSpans(context.Background(), &flushReq{retrievedSpanCount: 1, pageToken: "token", spans: []*loop_span.Span{span}, noMore: true}, sub) require.Error(t, err) require.ErrorContains(t, err, "finish fail") + require.True(t, proc.invokeCalled) } -func TestTraceHubServiceImpl_DoFlush_SamplingZero(t *testing.T) { +func TestTraceHubServiceImpl_FlushSpans_SamplingZero(t *testing.T) { impl := &TraceHubServiceImpl{} sub := &spanSubscriber{ - t: &task.Task{Rule: &task.Rule{Sampler: &task.Sampler{SampleRate: ptr.Of(float64(0))}}}, + t: &entity.ObservabilityTask{ + Sampler: &entity.Sampler{SampleRate: 0}, + }, } fr := &flushReq{retrievedSpanCount: 2, spans: []*loop_span.Span{{SpanID: "s1"}, {SpanID: "s2"}}} - flushed, sampled, err := impl.doFlush(context.Background(), fr, sub) - require.NoError(t, err) - require.Equal(t, 2, flushed) - require.Zero(t, sampled) + require.NoError(t, impl.flushSpans(context.Background(), fr, sub)) } func TestTraceHubServiceImpl_IsBackfillDone(t *testing.T) { t.Parallel() impl := &TraceHubServiceImpl{} - taskDTO := &task.Task{ID: ptr.Of(int64(1))} + taskDO := &entity.ObservabilityTask{ID: 1} t.Run("nil task run", func(t *testing.T) { t.Parallel() - sub := &spanSubscriber{t: taskDTO} + sub := &spanSubscriber{t: taskDO} isDone, err := impl.isBackfillDone(context.Background(), sub) require.NoError(t, err) require.True(t, isDone) @@ -459,7 +448,7 @@ func TestTraceHubServiceImpl_IsBackfillDone(t *testing.T) { t.Run("status running", func(t *testing.T) { t.Parallel() - sub := &spanSubscriber{t: taskDTO, tr: &task.TaskRun{RunStatus: task.RunStatusRunning}} + sub := &spanSubscriber{t: taskDO, tr: &entity.TaskRun{RunStatus: entity.TaskRunStatusRunning}} isDone, err := impl.isBackfillDone(context.Background(), sub) require.NoError(t, err) require.False(t, isDone) @@ -467,7 +456,7 @@ func TestTraceHubServiceImpl_IsBackfillDone(t *testing.T) { t.Run("status done", func(t *testing.T) { t.Parallel() - sub := &spanSubscriber{t: taskDTO, tr: &task.TaskRun{RunStatus: task.RunStatusDone}} + sub := &spanSubscriber{t: taskDO, tr: &entity.TaskRun{RunStatus: entity.TaskRunStatusDone}} isDone, err := impl.isBackfillDone(context.Background(), sub) require.NoError(t, err) require.True(t, isDone) @@ -556,15 +545,15 @@ func TestTraceHubServiceImpl_ApplySampling(t *testing.T) { impl := &TraceHubServiceImpl{} spans := []*loop_span.Span{{SpanID: "1"}, {SpanID: "2"}, {SpanID: "3"}} - sub := &spanSubscriber{t: &task.Task{Rule: &task.Rule{Sampler: &task.Sampler{SampleRate: ptr.Of(float64(1.0))}}}} + sub := &spanSubscriber{t: &entity.ObservabilityTask{Sampler: &entity.Sampler{SampleRate: 1.0}}} res := impl.applySampling(spans, sub) require.Len(t, res, 3) - subZero := &spanSubscriber{t: &task.Task{Rule: &task.Rule{Sampler: &task.Sampler{SampleRate: ptr.Of(float64(0.0))}}}} + subZero := &spanSubscriber{t: &entity.ObservabilityTask{Sampler: &entity.Sampler{SampleRate: 0}}} resZero := impl.applySampling(spans, subZero) require.Nil(t, resZero) - subHalf := &spanSubscriber{t: &task.Task{Rule: &task.Rule{Sampler: &task.Sampler{SampleRate: ptr.Of(float64(0.4))}}}} + subHalf := &spanSubscriber{t: &entity.ObservabilityTask{Sampler: &entity.Sampler{SampleRate: 0.4}}} resHalf := impl.applySampling(spans, subHalf) require.Len(t, resHalf, 1) require.Equal(t, spans[:1], resHalf) @@ -576,15 +565,11 @@ func TestTraceHubServiceImpl_OnHandleDone(t *testing.T) { t.Run("with errors triggers retry", func(t *testing.T) { t.Parallel() ch := make(chan *entity.BackFillEvent, 1) - impl := &TraceHubServiceImpl{ - backfillProducer: &stubBackfillProducer{ch: ch}, - flushErr: []error{errors.New("flush err"), errors.New("other")}, - } - sub := &spanSubscriber{t: &task.Task{ID: ptr.Of(int64(10)), WorkspaceID: ptr.Of(int64(20))}} + impl := &TraceHubServiceImpl{backfillProducer: &stubBackfillProducer{ch: ch}} + sub := &spanSubscriber{t: &entity.ObservabilityTask{ID: 10, WorkspaceID: 20}} - err := impl.onHandleDone(context.Background(), nil, sub) - require.Error(t, err) - require.EqualError(t, err, "flush err") + err := impl.onHandleDone(context.Background(), errors.New("flush err"), sub) + require.NoError(t, err) select { case msg := <-ch: @@ -599,7 +584,7 @@ func TestTraceHubServiceImpl_OnHandleDone(t *testing.T) { t.Parallel() ch := make(chan *entity.BackFillEvent, 1) impl := &TraceHubServiceImpl{backfillProducer: &stubBackfillProducer{ch: ch}} - sub := &spanSubscriber{t: &task.Task{ID: ptr.Of(int64(10)), WorkspaceID: ptr.Of(int64(20))}} + sub := &spanSubscriber{t: &entity.ObservabilityTask{ID: 10, WorkspaceID: 20}} err := impl.onHandleDone(context.Background(), nil, sub) require.NoError(t, err) @@ -624,46 +609,39 @@ func TestTraceHubServiceImpl_SendBackfillMessage(t *testing.T) { } func newBackfillSubscriber(taskRepo taskrepo.ITaskRepo, now time.Time) (*spanSubscriber, *stubProcessor) { - sampler := &task.Sampler{ - SampleRate: ptr.Of(float64(1)), - SampleSize: ptr.Of(int64(5)), - } - filters := &filter.FilterFields{FilterFields: []*filter.FilterField{}} - spanFilters := &filter.SpanFilterFields{ - PlatformType: ptr.Of(common.PlatformType(common.PlatformTypeCozeBot)), - SpanListType: ptr.Of(common.SpanListTypeRootSpan), - Filters: filters, + sampler := &entity.Sampler{ + SampleRate: 1, + SampleSize: 5, } - rule := &task.Rule{ - Sampler: sampler, - SpanFilters: spanFilters, - BackfillEffectiveTime: &task.EffectiveTime{ - StartAt: ptr.Of(now.Add(-time.Hour).UnixMilli()), - EndAt: ptr.Of(now.UnixMilli()), - }, + spanFilters := &entity.SpanFilterFields{ + PlatformType: loop_span.PlatformType(common.PlatformTypeCozeBot), + SpanListType: loop_span.SpanListType(common.SpanListTypeRootSpan), + Filters: loop_span.FilterFields{FilterFields: []*loop_span.FilterField{}}, } - status := task.TaskStatusRunning - taskDTO := &task.Task{ - ID: ptr.Of(int64(1)), - Name: "task", - WorkspaceID: ptr.Of(int64(2)), - TaskType: task.TaskTypeAutoEval, - TaskStatus: &status, - Rule: rule, + taskDO := &entity.ObservabilityTask{ + ID: 1, + Name: "task", + WorkspaceID: 2, + TaskType: entity.TaskTypeAutoEval, + TaskStatus: entity.TaskStatusRunning, + Sampler: sampler, + SpanFilter: spanFilters, + BackfillEffectiveTime: &entity.EffectiveTime{StartAt: now.Add(-time.Hour).UnixMilli(), EndAt: now.UnixMilli()}, } - taskRun := &task.TaskRun{ - ID: 10, - WorkspaceID: 2, - TaskID: 1, - TaskType: task.TaskRunTypeBackFill, - RunStatus: task.RunStatusRunning, - RunStartAt: now.Add(-time.Minute).UnixMilli(), - RunEndAt: now.Add(time.Minute).UnixMilli(), + taskRun := &entity.TaskRun{ + ID: 10, + WorkspaceID: 2, + TaskID: 1, + TaskType: entity.TaskRunTypeBackFill, + RunStatus: entity.TaskRunStatusRunning, + RunStartAt: now.Add(-time.Minute), + RunEndAt: now.Add(time.Minute), + BackfillDetail: &entity.BackfillDetail{}, } proc := &stubProcessor{} sub := &spanSubscriber{ taskID: 1, - t: taskDTO, + t: taskDO, tr: taskRun, processor: proc, taskRepo: taskRepo, @@ -674,13 +652,14 @@ func newBackfillSubscriber(taskRepo taskrepo.ITaskRepo, now time.Time) (*spanSub func newDomainBackfillTaskRun(now time.Time) *entity.TaskRun { return &entity.TaskRun{ - ID: 10, - TaskID: 1, - WorkspaceID: 2, - TaskType: entity.TaskRunTypeBackFill, - RunStatus: entity.TaskRunStatusRunning, - RunStartAt: now.Add(-time.Minute), - RunEndAt: now.Add(time.Minute), + ID: 10, + TaskID: 1, + WorkspaceID: 2, + TaskType: entity.TaskRunTypeBackFill, + RunStatus: entity.TaskRunStatusRunning, + RunStartAt: now.Add(-time.Minute), + RunEndAt: now.Add(time.Minute), + BackfillDetail: &entity.BackfillDetail{}, } } diff --git a/backend/modules/observability/domain/task/service/taskexe/tracehub/span_trigger_test.go b/backend/modules/observability/domain/task/service/taskexe/tracehub/span_trigger_test.go index ad93d2013..77a583346 100755 --- a/backend/modules/observability/domain/task/service/taskexe/tracehub/span_trigger_test.go +++ b/backend/modules/observability/domain/task/service/taskexe/tracehub/span_trigger_test.go @@ -13,6 +13,7 @@ import ( "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/observability/domain/common" "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/observability/domain/task" + taskconvertor "github.com/coze-dev/coze-loop/backend/modules/observability/application/convertor/task" componentconfig "github.com/coze-dev/coze-loop/backend/modules/observability/domain/component/config" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/entity" repo_mocks "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/repo/mocks" @@ -188,19 +189,19 @@ func TestTraceHubServiceImpl_preDispatchHandlesUnstartedAndLimits(t *testing.T) } sub := &spanSubscriber{ - taskID: taskID, - t: &task.Task{ - ID: ptr.Of(taskID), - WorkspaceID: ptr.Of(workspaceID), - TaskType: task.TaskTypeAutoEval, - TaskStatus: ptr.Of(task.TaskStatusUnstarted), - Rule: rule, - BaseInfo: &common.BaseInfo{}, - }, + taskID: taskID, processor: stubProc, taskRepo: mockRepo, runType: entity.TaskRunTypeNewData, } + sub.t = toObservabilityTask(&task.Task{ + ID: ptr.Of(taskID), + WorkspaceID: ptr.Of(workspaceID), + TaskType: task.TaskTypeAutoEval, + TaskStatus: ptr.Of(task.TaskStatusUnstarted), + Rule: rule, + BaseInfo: &common.BaseInfo{}, + }) taskRunConfig := &entity.TaskRun{ ID: 303, @@ -265,19 +266,19 @@ func TestTraceHubServiceImpl_preDispatchHandlesMissingTaskRunConfig(t *testing.T } sub := &spanSubscriber{ - taskID: taskID, - t: &task.Task{ - ID: ptr.Of(taskID), - WorkspaceID: ptr.Of(workspaceID), - TaskType: task.TaskTypeAutoEval, - TaskStatus: ptr.Of(task.TaskStatusRunning), - Rule: rule, - BaseInfo: &common.BaseInfo{}, - }, + taskID: taskID, processor: stubProc, taskRepo: mockRepo, runType: entity.TaskRunTypeNewData, } + sub.t = toObservabilityTask(&task.Task{ + ID: ptr.Of(taskID), + WorkspaceID: ptr.Of(workspaceID), + TaskType: task.TaskTypeAutoEval, + TaskStatus: ptr.Of(task.TaskStatusRunning), + Rule: rule, + BaseInfo: &common.BaseInfo{}, + }) mockRepo.EXPECT().GetLatestNewDataTaskRun(gomock.Any(), gomock.AssignableToTypeOf(ptr.Of(int64(0))), taskID).Return(nil, nil) @@ -325,19 +326,19 @@ func TestTraceHubServiceImpl_preDispatchHandlesNonCycle(t *testing.T) { } sub := &spanSubscriber{ - taskID: taskID, - t: &task.Task{ - ID: ptr.Of(taskID), - WorkspaceID: ptr.Of(workspaceID), - TaskType: task.TaskTypeAutoEval, - TaskStatus: ptr.Of(task.TaskStatusUnstarted), - Rule: rule, - BaseInfo: &common.BaseInfo{}, - }, + taskID: taskID, processor: stubProc, taskRepo: mockRepo, runType: entity.TaskRunTypeNewData, } + sub.t = toObservabilityTask(&task.Task{ + ID: ptr.Of(taskID), + WorkspaceID: ptr.Of(workspaceID), + TaskType: task.TaskTypeAutoEval, + TaskStatus: ptr.Of(task.TaskStatusUnstarted), + Rule: rule, + BaseInfo: &common.BaseInfo{}, + }) taskRunConfig := &entity.TaskRun{ ID: 707, @@ -394,19 +395,19 @@ func TestTraceHubServiceImpl_preDispatchHandlesCycleDefaultUnit(t *testing.T) { } sub := &spanSubscriber{ - taskID: taskID, - t: &task.Task{ - ID: ptr.Of(taskID), - WorkspaceID: ptr.Of(workspaceID), - TaskType: task.TaskTypeAutoEval, - TaskStatus: ptr.Of(task.TaskStatusUnstarted), - Rule: rule, - BaseInfo: &common.BaseInfo{}, - }, + taskID: taskID, processor: stubProc, taskRepo: mockRepo, runType: entity.TaskRunTypeNewData, } + sub.t = toObservabilityTask(&task.Task{ + ID: ptr.Of(taskID), + WorkspaceID: ptr.Of(workspaceID), + TaskType: task.TaskTypeAutoEval, + TaskStatus: ptr.Of(task.TaskStatusUnstarted), + Rule: rule, + BaseInfo: &common.BaseInfo{}, + }) mockRepo.EXPECT().GetLatestNewDataTaskRun(gomock.Any(), gomock.AssignableToTypeOf(ptr.Of(int64(0))), taskID).Return(nil, nil) @@ -456,19 +457,19 @@ func TestTraceHubServiceImpl_preDispatchTimeLimitFinishError(t *testing.T) { } sub := &spanSubscriber{ - taskID: taskID, - t: &task.Task{ - ID: ptr.Of(taskID), - WorkspaceID: ptr.Of(workspaceID), - TaskType: task.TaskTypeAutoEval, - TaskStatus: ptr.Of(task.TaskStatusRunning), - Rule: rule, - BaseInfo: &common.BaseInfo{}, - }, + taskID: taskID, processor: stubProc, taskRepo: mockRepo, runType: entity.TaskRunTypeNewData, } + sub.t = toObservabilityTask(&task.Task{ + ID: ptr.Of(taskID), + WorkspaceID: ptr.Of(workspaceID), + TaskType: task.TaskTypeAutoEval, + TaskStatus: ptr.Of(task.TaskStatusRunning), + Rule: rule, + BaseInfo: &common.BaseInfo{}, + }) taskRunConfig := &entity.TaskRun{ ID: 1101, @@ -523,19 +524,19 @@ func TestTraceHubServiceImpl_preDispatchSampleLimitFinishError(t *testing.T) { } sub := &spanSubscriber{ - taskID: taskID, - t: &task.Task{ - ID: ptr.Of(taskID), - WorkspaceID: ptr.Of(workspaceID), - TaskType: task.TaskTypeAutoEval, - TaskStatus: ptr.Of(task.TaskStatusRunning), - Rule: rule, - BaseInfo: &common.BaseInfo{}, - }, + taskID: taskID, processor: stubProc, taskRepo: mockRepo, runType: entity.TaskRunTypeNewData, } + sub.t = toObservabilityTask(&task.Task{ + ID: ptr.Of(taskID), + WorkspaceID: ptr.Of(workspaceID), + TaskType: task.TaskTypeAutoEval, + TaskStatus: ptr.Of(task.TaskStatusRunning), + Rule: rule, + BaseInfo: &common.BaseInfo{}, + }) taskRunConfig := &entity.TaskRun{ ID: 1404, @@ -590,19 +591,19 @@ func TestTraceHubServiceImpl_preDispatchCycleTimeLimitFinishError(t *testing.T) } sub := &spanSubscriber{ - taskID: taskID, - t: &task.Task{ - ID: ptr.Of(taskID), - WorkspaceID: ptr.Of(workspaceID), - TaskType: task.TaskTypeAutoEval, - TaskStatus: ptr.Of(task.TaskStatusRunning), - Rule: rule, - BaseInfo: &common.BaseInfo{}, - }, + taskID: taskID, processor: stubProc, taskRepo: mockRepo, runType: entity.TaskRunTypeNewData, } + sub.t = toObservabilityTask(&task.Task{ + ID: ptr.Of(taskID), + WorkspaceID: ptr.Of(workspaceID), + TaskType: task.TaskTypeAutoEval, + TaskStatus: ptr.Of(task.TaskStatusRunning), + Rule: rule, + BaseInfo: &common.BaseInfo{}, + }) taskRunConfig := &entity.TaskRun{ ID: 1707, @@ -657,19 +658,19 @@ func TestTraceHubServiceImpl_preDispatchCycleCountFinishError(t *testing.T) { } sub := &spanSubscriber{ - taskID: taskID, - t: &task.Task{ - ID: ptr.Of(taskID), - WorkspaceID: ptr.Of(workspaceID), - TaskType: task.TaskTypeAutoEval, - TaskStatus: ptr.Of(task.TaskStatusRunning), - Rule: rule, - BaseInfo: &common.BaseInfo{}, - }, + taskID: taskID, processor: stubProc, taskRepo: mockRepo, runType: entity.TaskRunTypeNewData, } + sub.t = toObservabilityTask(&task.Task{ + ID: ptr.Of(taskID), + WorkspaceID: ptr.Of(workspaceID), + TaskType: task.TaskTypeAutoEval, + TaskStatus: ptr.Of(task.TaskStatusRunning), + Rule: rule, + BaseInfo: &common.BaseInfo{}, + }) taskRunConfig := &entity.TaskRun{ ID: 2009, @@ -720,19 +721,19 @@ func TestTraceHubServiceImpl_preDispatchCreativeError(t *testing.T) { } sub := &spanSubscriber{ - taskID: taskID, - t: &task.Task{ - ID: ptr.Of(taskID), - WorkspaceID: ptr.Of(workspaceID), - TaskType: task.TaskTypeAutoEval, - TaskStatus: ptr.Of(task.TaskStatusUnstarted), - Rule: rule, - BaseInfo: &common.BaseInfo{}, - }, + taskID: taskID, processor: stubProc, taskRepo: mockRepo, runType: entity.TaskRunTypeNewData, } + sub.t = toObservabilityTask(&task.Task{ + ID: ptr.Of(taskID), + WorkspaceID: ptr.Of(workspaceID), + TaskType: task.TaskTypeAutoEval, + TaskStatus: ptr.Of(task.TaskStatusUnstarted), + Rule: rule, + BaseInfo: &common.BaseInfo{}, + }) impl := &TraceHubServiceImpl{taskRepo: mockRepo} span := &loop_span.Span{StartTime: now.UnixMilli(), TraceID: "trace", SpanID: "span"} @@ -743,6 +744,10 @@ func TestTraceHubServiceImpl_preDispatchCreativeError(t *testing.T) { require.Equal(t, 1, len(stubProc.createTaskRunReqs)) } +func toObservabilityTask(dto *task.Task) *entity.ObservabilityTask { + return taskconvertor.TaskDTO2DO(dto) +} + func TestTraceHubServiceImpl_preDispatchAggregatesErrors(t *testing.T) { ctrl := gomock.NewController(t) t.Cleanup(ctrl.Finish) @@ -759,22 +764,22 @@ func TestTraceHubServiceImpl_preDispatchAggregatesErrors(t *testing.T) { CycleTimeUnit: &firstSamplerUnit, } firstSub := &spanSubscriber{ - taskID: 11, - t: &task.Task{ - ID: ptr.Of(int64(11)), - WorkspaceID: ptr.Of(int64(21)), - TaskType: task.TaskTypeAutoEval, - TaskStatus: ptr.Of(task.TaskStatusUnstarted), - Rule: &task.Rule{ - EffectiveTime: &task.EffectiveTime{StartAt: ptr.Of(firstStartAt), EndAt: ptr.Of(now.Add(time.Hour).UnixMilli())}, - Sampler: firstSampler, - }, - BaseInfo: &common.BaseInfo{}, - }, + taskID: 11, processor: firstProc, taskRepo: mockRepo, runType: entity.TaskRunTypeNewData, } + firstSub.t = toObservabilityTask(&task.Task{ + ID: ptr.Of(int64(11)), + WorkspaceID: ptr.Of(int64(21)), + TaskType: task.TaskTypeAutoEval, + TaskStatus: ptr.Of(task.TaskStatusUnstarted), + Rule: &task.Rule{ + EffectiveTime: &task.EffectiveTime{StartAt: ptr.Of(firstStartAt), EndAt: ptr.Of(now.Add(time.Hour).UnixMilli())}, + Sampler: firstSampler, + }, + BaseInfo: &common.BaseInfo{}, + }) secondStartAt := now.Add(-2 * time.Hour).UnixMilli() secondEndAt := now.Add(-time.Minute).UnixMilli() @@ -798,22 +803,22 @@ func TestTraceHubServiceImpl_preDispatchAggregatesErrors(t *testing.T) { } secondProc := &stubProcessor{finishErrSeq: []error{errors.New("second fail")}} secondSub := &spanSubscriber{ - taskID: secondTaskID, - t: &task.Task{ - ID: ptr.Of(secondTaskID), - WorkspaceID: ptr.Of(secondWorkspaceID), - TaskType: task.TaskTypeAutoEval, - TaskStatus: ptr.Of(task.TaskStatusRunning), - Rule: &task.Rule{ - EffectiveTime: &task.EffectiveTime{StartAt: ptr.Of(secondStartAt), EndAt: ptr.Of(secondEndAt)}, - Sampler: secondSampler, - }, - BaseInfo: &common.BaseInfo{}, - }, + taskID: secondTaskID, processor: secondProc, taskRepo: mockRepo, runType: entity.TaskRunTypeNewData, } + secondSub.t = toObservabilityTask(&task.Task{ + ID: ptr.Of(secondTaskID), + WorkspaceID: ptr.Of(secondWorkspaceID), + TaskType: task.TaskTypeAutoEval, + TaskStatus: ptr.Of(task.TaskStatusRunning), + Rule: &task.Rule{ + EffectiveTime: &task.EffectiveTime{StartAt: ptr.Of(secondStartAt), EndAt: ptr.Of(secondEndAt)}, + Sampler: secondSampler, + }, + BaseInfo: &common.BaseInfo{}, + }) mockRepo.EXPECT().GetLatestNewDataTaskRun(gomock.Any(), gomock.AssignableToTypeOf(ptr.Of(int64(0))), secondTaskID).Return(secondRun, nil) mockRepo.EXPECT().GetTaskCount(gomock.Any(), secondTaskID).Return(int64(0), nil) @@ -857,19 +862,19 @@ func TestTraceHubServiceImpl_preDispatchUpdateError(t *testing.T) { } sub := &spanSubscriber{ - taskID: taskID, - t: &task.Task{ - ID: ptr.Of(taskID), - WorkspaceID: ptr.Of(workspaceID), - TaskType: task.TaskTypeAutoEval, - TaskStatus: ptr.Of(task.TaskStatusUnstarted), - Rule: rule, - BaseInfo: &common.BaseInfo{}, - }, + taskID: taskID, processor: stubProc, taskRepo: mockRepo, runType: entity.TaskRunTypeNewData, } + sub.t = toObservabilityTask(&task.Task{ + ID: ptr.Of(taskID), + WorkspaceID: ptr.Of(workspaceID), + TaskType: task.TaskTypeAutoEval, + TaskStatus: ptr.Of(task.TaskStatusUnstarted), + Rule: rule, + BaseInfo: &common.BaseInfo{}, + }) impl := &TraceHubServiceImpl{taskRepo: mockRepo} span := &loop_span.Span{StartTime: now.UnixMilli(), TraceID: "trace", SpanID: "span"} @@ -903,19 +908,19 @@ func TestTraceHubServiceImpl_preDispatchListTaskRunError(t *testing.T) { } sub := &spanSubscriber{ - taskID: taskID, - t: &task.Task{ - ID: ptr.Of(taskID), - WorkspaceID: ptr.Of(workspaceID), - TaskType: task.TaskTypeAutoEval, - TaskStatus: ptr.Of(task.TaskStatusRunning), - Rule: rule, - BaseInfo: &common.BaseInfo{}, - }, + taskID: taskID, processor: stubProc, taskRepo: mockRepo, runType: entity.TaskRunTypeNewData, } + sub.t = toObservabilityTask(&task.Task{ + ID: ptr.Of(taskID), + WorkspaceID: ptr.Of(workspaceID), + TaskType: task.TaskTypeAutoEval, + TaskStatus: ptr.Of(task.TaskStatusRunning), + Rule: rule, + BaseInfo: &common.BaseInfo{}, + }) mockRepo.EXPECT().GetLatestNewDataTaskRun(gomock.Any(), gomock.AssignableToTypeOf(ptr.Of(int64(0))), taskID).Return(nil, errors.New("repo fail")) @@ -953,19 +958,19 @@ func TestTraceHubServiceImpl_preDispatchTaskRunConfigDay(t *testing.T) { } sub := &spanSubscriber{ - taskID: taskID, - t: &task.Task{ - ID: ptr.Of(taskID), - WorkspaceID: ptr.Of(workspaceID), - TaskType: task.TaskTypeAutoEval, - TaskStatus: ptr.Of(task.TaskStatusRunning), - Rule: rule, - BaseInfo: &common.BaseInfo{}, - }, + taskID: taskID, processor: stubProc, taskRepo: mockRepo, runType: entity.TaskRunTypeNewData, } + sub.t = toObservabilityTask(&task.Task{ + ID: ptr.Of(taskID), + WorkspaceID: ptr.Of(workspaceID), + TaskType: task.TaskTypeAutoEval, + TaskStatus: ptr.Of(task.TaskStatusRunning), + Rule: rule, + BaseInfo: &common.BaseInfo{}, + }) mockRepo.EXPECT().GetLatestNewDataTaskRun(gomock.Any(), gomock.AssignableToTypeOf(ptr.Of(int64(0))), taskID).Return(nil, nil) @@ -1010,19 +1015,19 @@ func TestTraceHubServiceImpl_preDispatchCycleCreativeError(t *testing.T) { } sub := &spanSubscriber{ - taskID: taskID, - t: &task.Task{ - ID: ptr.Of(taskID), - WorkspaceID: ptr.Of(workspaceID), - TaskType: task.TaskTypeAutoEval, - TaskStatus: ptr.Of(task.TaskStatusRunning), - Rule: rule, - BaseInfo: &common.BaseInfo{}, - }, + taskID: taskID, processor: stubProc, taskRepo: mockRepo, runType: entity.TaskRunTypeNewData, } + sub.t = toObservabilityTask(&task.Task{ + ID: ptr.Of(taskID), + WorkspaceID: ptr.Of(workspaceID), + TaskType: task.TaskTypeAutoEval, + TaskStatus: ptr.Of(task.TaskStatusRunning), + Rule: rule, + BaseInfo: &common.BaseInfo{}, + }) taskRunConfig := &entity.TaskRun{ ID: 3102, diff --git a/backend/modules/observability/domain/task/service/taskexe/tracehub/trace_hub_test.go b/backend/modules/observability/domain/task/service/taskexe/tracehub/trace_hub_test.go index 3378f9e8c..3c8bd53e1 100755 --- a/backend/modules/observability/domain/task/service/taskexe/tracehub/trace_hub_test.go +++ b/backend/modules/observability/domain/task/service/taskexe/tracehub/trace_hub_test.go @@ -9,7 +9,6 @@ import ( "go.uber.org/mock/gomock" - "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/observability/domain/task" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/entity" repo_mocks "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/repo/mocks" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/entity/loop_span" @@ -69,21 +68,9 @@ func TestTraceHubServiceImpl_applySampling(t *testing.T) { spans := []*loop_span.Span{{SpanID: "1"}, {SpanID: "2"}, {SpanID: "3"}} impl := &TraceHubServiceImpl{} - fullRate := &spanSubscriber{ - t: &task.Task{ - Rule: &task.Rule{Sampler: &task.Sampler{SampleRate: floatPtr(1.0)}}, - }, - } - zeroRate := &spanSubscriber{ - t: &task.Task{ - Rule: &task.Rule{Sampler: &task.Sampler{SampleRate: floatPtr(0.0)}}, - }, - } - halfRate := &spanSubscriber{ - t: &task.Task{ - Rule: &task.Rule{Sampler: &task.Sampler{SampleRate: floatPtr(0.5)}}, - }, - } + fullRate := &spanSubscriber{t: &entity.ObservabilityTask{Sampler: &entity.Sampler{SampleRate: 1.0}}} + zeroRate := &spanSubscriber{t: &entity.ObservabilityTask{Sampler: &entity.Sampler{SampleRate: 0}}} + halfRate := &spanSubscriber{t: &entity.ObservabilityTask{Sampler: &entity.Sampler{SampleRate: 0.5}}} require.Len(t, impl.applySampling(spans, fullRate), len(spans)) require.Nil(t, impl.applySampling(spans, zeroRate)) From 0b3ff458db9452616790dc247dbe2d9f62accbc5 Mon Sep 17 00:00:00 2001 From: taoyifan89 Date: Thu, 6 Nov 2025 12:37:44 +0800 Subject: [PATCH 09/19] TraceHub refactor. Change-Id: I1aceb52234a41337f9ae3fc8ae47fd7d722ddeb4 --- .../modules/observability/application/task.go | 11 +- .../observability/application/wire_gen.go | 2 +- .../domain/component/config/config.go | 1 + .../task/service/taskexe/tracehub/backfill.go | 181 +++++++++--------- .../service/taskexe/tracehub/backfill_test.go | 82 ++++++-- .../service/taskexe/tracehub/local_cache.go | 54 ++++++ .../tracehub/mocks/trace_hub_service.go | 9 +- .../taskexe/tracehub/scheduled_task.go | 33 +--- .../service/taskexe/tracehub/span_trigger.go | 90 ++++----- .../taskexe/tracehub/span_trigger_test.go | 4 +- .../service/taskexe/tracehub/subscriber.go | 14 +- .../service/taskexe/tracehub/trace_hub.go | 22 +-- .../domain/trace/entity/loop_span/span.go | 2 + .../observability/infra/config/trace.go | 9 + 14 files changed, 284 insertions(+), 230 deletions(-) create mode 100644 backend/modules/observability/domain/task/service/taskexe/tracehub/local_cache.go diff --git a/backend/modules/observability/application/task.go b/backend/modules/observability/application/task.go index 4dc9ca8f2..ef7120e2c 100644 --- a/backend/modules/observability/application/task.go +++ b/backend/modules/observability/application/task.go @@ -261,7 +261,15 @@ func (t *TaskApplication) GetTask(ctx context.Context, req *task.GetTaskRequest) } func (t *TaskApplication) SpanTrigger(ctx context.Context, event *entity.RawSpan) error { - return t.tracehubSvc.SpanTrigger(ctx, event) + span := event.RawSpanConvertToLoopSpan() + if span != nil { + if err := t.tracehubSvc.SpanTrigger(ctx, span); err != nil { + logs.CtxError(ctx, "SpanTrigger err:%v", err) + // span trigger 失败,不处理 + return nil + } + } + return nil } func (t *TaskApplication) AutoEvalCallback(ctx context.Context, event *entity.AutoEvalEvent) error { @@ -290,5 +298,6 @@ func (t *TaskApplication) BackFill(ctx context.Context, event *entity.BackFillEv // 结构校验失败,不处理 return nil } + return t.tracehubSvc.BackFill(ctx, event) } diff --git a/backend/modules/observability/application/wire_gen.go b/backend/modules/observability/application/wire_gen.go index 3a52d6ce7..7101caf66 100644 --- a/backend/modules/observability/application/wire_gen.go +++ b/backend/modules/observability/application/wire_gen.go @@ -281,7 +281,7 @@ func InitTaskApplication(db2 db.Provider, idgen2 idgen.IIDGenerator, configFacto } iTenantProvider := tenant.NewTenantProvider(iTraceConfig) iLocker := NewTaskLocker(redis3) - iTraceHubService, err := tracehub.NewTraceHubImpl(iTaskRepo, iTraceRepo, iTenantProvider, traceFilterProcessorBuilder, processorTaskProcessor, aid, iBackfillProducer, iLocker, iConfigLoader) + iTraceHubService, err := tracehub.NewTraceHubImpl(iTaskRepo, iTraceRepo, iTenantProvider, traceFilterProcessorBuilder, processorTaskProcessor, aid, iBackfillProducer, iLocker, iTraceConfig) if err != nil { return nil, err } diff --git a/backend/modules/observability/domain/component/config/config.go b/backend/modules/observability/domain/component/config/config.go index f587b3188..cb0a29255 100644 --- a/backend/modules/observability/domain/component/config/config.go +++ b/backend/modules/observability/domain/component/config/config.go @@ -128,6 +128,7 @@ type ITraceConfig interface { GetQueryMaxQPS(ctx context.Context, key string) (int, error) GetKeySpanTypes(ctx context.Context) map[string][]string GetBackfillMqProducerCfg(ctx context.Context) (*MqProducerCfg, error) + GetConsumerListening(ctx context.Context) (*ConsumerListening, error) conf.IConfigLoader } diff --git a/backend/modules/observability/domain/task/service/taskexe/tracehub/backfill.go b/backend/modules/observability/domain/task/service/taskexe/tracehub/backfill.go index 3fe1603ef..dafb4de7f 100644 --- a/backend/modules/observability/domain/task/service/taskexe/tracehub/backfill.go +++ b/backend/modules/observability/domain/task/service/taskexe/tracehub/backfill.go @@ -156,8 +156,37 @@ func (h *TraceHubServiceImpl) listAndSendSpans(ctx context.Context, sub *spanSub if sub.tr.BackfillDetail != nil && sub.tr.BackfillDetail.LastSpanPageToken != nil { listParam.PageToken = *sub.tr.BackfillDetail.LastSpanPageToken } - // Paginate query and send data - return h.fetchAndSendSpans(ctx, listParam, sub) + + totalCount := int64(0) + for { + logs.CtxInfo(ctx, "TaskID: %d, ListSpansParam:%v", sub.t.ID, listParam) + spans, pageToken, err := h.fetchSpans(ctx, listParam, sub) + if err != nil { + logs.CtxError(ctx, "list spans failed, task_id=%d, err=%v", sub.t.ID, err) + return err + } + + err, shouldFinish := h.flushSpans(ctx, spans, sub) + if err != nil { + return err + } + + totalCount += int64(len(spans)) + logs.CtxInfo(ctx, "Processed %d spans completed, total=%d, task_id=%d", len(spans), totalCount, sub.t.ID) + + if pageToken == "" || shouldFinish { + logs.CtxInfo(ctx, "no more spans to process, task_id=%d", sub.t.ID) + if err = sub.processor.OnTaskFinished(ctx, taskexe.OnTaskFinishedReq{ + Task: sub.t, + TaskRun: sub.tr, + IsFinish: false, + }); err != nil { + return err + } + return nil + } + listParam.PageToken = pageToken + } } type ListSpansReq struct { @@ -257,84 +286,62 @@ func (h *TraceHubServiceImpl) combineFilters(filters ...*loop_span.FilterFields) return filterAggr } -// fetchAndSendSpans paginates and sends span data -func (h *TraceHubServiceImpl) fetchAndSendSpans(ctx context.Context, listParam *repo.ListSpansParam, sub *spanSubscriber) error { - totalCount := int64(0) - pageToken := listParam.PageToken - for { - logs.CtxInfo(ctx, "TaskID: %d, ListSpansParam:%v", sub.t.ID, listParam) - result, err := h.traceRepo.ListSpans(ctx, listParam) - if err != nil { - logs.CtxError(ctx, "List spans failed, task_id=%d, page_token=%s, err=%v", sub.t.ID, pageToken, err) - return err - } - logs.CtxInfo(ctx, "Fetch %d spans, total=%d, task_id=%d", len(result.Spans), totalCount, sub.t.ID) - - spans := result.Spans - processors, err := h.buildHelper.BuildGetTraceProcessors(ctx, span_processor.Settings{ - WorkspaceId: sub.t.WorkspaceID, - PlatformType: sub.t.SpanFilter.PlatformType, - QueryStartTime: listParam.StartAt, - QueryEndTime: listParam.EndAt, - }) - if err != nil { - return errorx.WrapByCode(err, obErrorx.CommercialCommonInternalErrorCodeCode) - } - for _, p := range processors { - spans, err = p.Transform(ctx, spans) - if err != nil { - return errorx.WrapByCode(err, obErrorx.CommercialCommonInternalErrorCodeCode) - } - } - - if len(spans) > 0 { - flush := &flushReq{ - retrievedSpanCount: int64(len(spans)), - pageToken: result.PageToken, - spans: spans, - noMore: !result.HasMore, - } - - if err = h.flushSpans(ctx, flush, sub); err != nil { - if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { - return err - } - } +// fetchSpans paginates span data +func (h *TraceHubServiceImpl) fetchSpans(ctx context.Context, listParam *repo.ListSpansParam, + sub *spanSubscriber) ([]*loop_span.Span, string, error) { + result, err := h.traceRepo.ListSpans(ctx, listParam) + if err != nil { + logs.CtxError(ctx, "List spans failed, parma=%v, err=%v", listParam, err) + return nil, "", err + } + logs.CtxInfo(ctx, "Fetch %d spans", len(result.Spans)) + spans := result.Spans + if len(spans) == 0 { + return nil, "", nil + } - totalCount += int64(len(spans)) - logs.CtxInfo(ctx, "Processed %d spans completed, total=%d, task_id=%d", len(spans), totalCount, sub.t.ID) + processors, err := h.buildHelper.BuildGetTraceProcessors(ctx, span_processor.Settings{ + WorkspaceId: sub.t.WorkspaceID, + PlatformType: sub.t.SpanFilter.PlatformType, + QueryStartTime: listParam.StartAt, + QueryEndTime: listParam.EndAt, + }) + if err != nil { + return nil, "", errorx.WrapByCode(err, obErrorx.CommercialCommonInternalErrorCodeCode) + } + for _, p := range processors { + spans, err = p.Transform(ctx, spans) + if err != nil { + return nil, "", errorx.WrapByCode(err, obErrorx.CommercialCommonInternalErrorCodeCode) } + } - if !result.HasMore { - logs.CtxInfo(ctx, "Completed listing spans, total_count=%d, task_id=%d", totalCount, sub.t.ID) - break - } - listParam.PageToken = result.PageToken - pageToken = result.PageToken + if !result.HasMore { + logs.CtxInfo(ctx, "Completed listing spans, task_id=%d", sub.t.ID) + return spans, "", nil } - return nil + return spans, result.PageToken, nil } -func (h *TraceHubServiceImpl) flushSpans(ctx context.Context, fr *flushReq, sub *spanSubscriber) error { - if fr == nil || len(fr.spans) == 0 { - return nil +func (h *TraceHubServiceImpl) flushSpans(ctx context.Context, spans []*loop_span.Span, sub *spanSubscriber) (err error, shouldFinish bool) { + logs.CtxInfo(ctx, "Start processing %d spans for backfill, task_id=%d", len(spans), sub.t.ID) + if len(spans) == 0 { + return nil, false } - logs.CtxInfo(ctx, "Start processing %d spans for backfill, task_id=%d", len(fr.spans), sub.t.ID) - // Apply sampling logic - sampledSpans := h.applySampling(fr.spans, sub) + sampledSpans := h.applySampling(spans, sub) if len(sampledSpans) == 0 { logs.CtxInfo(ctx, "no spans after sampling, task_id=%d", sub.t.ID) - return nil + return nil, false } // Execute specific business logic - err := h.processSpansForBackfill(ctx, sampledSpans, sub) + err, shouldFinish = h.processSpansForBackfill(ctx, sampledSpans, sub) if err != nil { logs.CtxError(ctx, "process spans failed, task_id=%d, err=%v", sub.t.ID, err) - return err + return } // todo 不应该这里直接写po字段 @@ -343,22 +350,12 @@ func (h *TraceHubServiceImpl) flushSpans(ctx context.Context, fr *flushReq, sub }) if err != nil { logs.CtxError(ctx, "update task run failed, task_id=%d, err=%v", sub.t.ID, err) - return err - } - if fr.noMore { - logs.CtxInfo(ctx, "no more spans to process, task_id=%d", sub.t.ID) - if err = sub.processor.OnTaskFinished(ctx, taskexe.OnTaskFinishedReq{ - Task: sub.t, - TaskRun: sub.tr, - IsFinish: false, - }); err != nil { - return err - } + return } logs.CtxInfo(ctx, "successfully processed %d spans (sampled from %d), task_id=%d", - len(sampledSpans), len(fr.spans), sub.t.ID) - return nil + len(sampledSpans), len(spans), sub.t.ID) + return } // applySampling applies sampling logic @@ -391,7 +388,7 @@ func (h *TraceHubServiceImpl) applySampling(spans []*loop_span.Span, sub *spanSu } // processSpansForBackfill handles spans for backfill -func (h *TraceHubServiceImpl) processSpansForBackfill(ctx context.Context, spans []*loop_span.Span, sub *spanSubscriber) error { +func (h *TraceHubServiceImpl) processSpansForBackfill(ctx context.Context, spans []*loop_span.Span, sub *spanSubscriber) (err error, shouldFinish bool) { // Batch processing spans for efficiency const batchSize = 50 @@ -402,46 +399,40 @@ func (h *TraceHubServiceImpl) processSpansForBackfill(ctx context.Context, spans } batch := spans[i:end] - if err := h.processBatchSpans(ctx, batch, sub); err != nil { + err, shouldFinish = h.processBatchSpans(ctx, batch, sub) + if err != nil { logs.CtxError(ctx, "process batch spans failed, task_id=%d, batch_start=%d, err=%v", sub.t.ID, i, err) - // Continue with the next batch without stopping due to a single failure - continue + return + } + if shouldFinish { + return } // ml_flow rate-limited: 50/5s time.Sleep(5 * time.Second) } - return nil + return err, shouldFinish } // processBatchSpans processes a batch of span data -func (h *TraceHubServiceImpl) processBatchSpans(ctx context.Context, spans []*loop_span.Span, sub *spanSubscriber) error { +func (h *TraceHubServiceImpl) processBatchSpans(ctx context.Context, spans []*loop_span.Span, sub *spanSubscriber) (err error, shouldFinish bool) { for _, span := range spans { // Execute processing logic according to the task type logs.CtxInfo(ctx, "processing span for backfill, span_id=%s, trace_id=%s, task_id=%d", span.SpanID, span.TraceID, sub.t.ID) taskCount, _ := h.taskRepo.GetTaskCount(ctx, sub.taskID) - taskRunCount, _ := h.taskRepo.GetTaskRunCount(ctx, sub.taskID, sub.tr.ID) sampler := sub.t.Sampler if taskCount+1 > sampler.SampleSize { logs.CtxInfo(ctx, "taskCount+1 > sampler.GetSampleSize(), task_id=%d,SampleSize=%d", sub.taskID, sampler.SampleSize) - if err := sub.processor.OnTaskFinished(ctx, taskexe.OnTaskFinishedReq{ - Task: sub.t, - TaskRun: sub.tr, - IsFinish: true, - }); err != nil { - return err - } - break + return nil, true } - logs.CtxInfo(ctx, "preDispatch, task_id=%d, taskCount=%d, taskRunCount=%d", sub.taskID, taskCount, taskRunCount) - if err := h.dispatch(ctx, span, []*spanSubscriber{sub}); err != nil { - return err + if err = h.dispatch(ctx, span, []*spanSubscriber{sub}); err != nil { + return err, false } } - return nil + return nil, false } // onHandleDone handles completion callback diff --git a/backend/modules/observability/domain/task/service/taskexe/tracehub/backfill_test.go b/backend/modules/observability/domain/task/service/taskexe/tracehub/backfill_test.go index 0aa8ca8a3..9c5cd706a 100755 --- a/backend/modules/observability/domain/task/service/taskexe/tracehub/backfill_test.go +++ b/backend/modules/observability/domain/task/service/taskexe/tracehub/backfill_test.go @@ -19,6 +19,7 @@ import ( "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/entity" taskrepo "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/repo" repo_mocks "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/repo/mocks" + "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/service/taskexe" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/service/taskexe/processor" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/repo" trepo_mocks "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/repo/mocks" @@ -145,14 +146,24 @@ func TestTraceHubServiceImpl_ProcessBatchSpans_TaskLimit(t *testing.T) { runType: entity.TaskRunTypeBackFill, } - mockRepo.EXPECT().GetTaskCount(gomock.Any(), int64(1)).Return(int64(1), nil) - mockRepo.EXPECT().GetTaskRunCount(gomock.Any(), int64(1), int64(10)).Return(int64(0), nil) + mockRepo.EXPECT().GetTaskCount(gomock.Any(), int64(1)).Return(int64(0), nil).AnyTimes() + mockRepo.EXPECT().GetBackfillTaskRun(gomock.Any(), gomock.Nil(), int64(1)).Return(&entity.TaskRun{ + ID: 10, + TaskID: 1, + WorkspaceID: 2, + TaskType: entity.TaskRunTypeBackFill, + RunStatus: entity.TaskRunStatusRunning, + RunStartAt: time.Now().Add(-time.Minute), + RunEndAt: time.Now().Add(time.Minute), + }, nil) - spans := []*loop_span.Span{{SpanID: "span-1"}} + spans := []*loop_span.Span{{SpanID: "span-1", StartTime: time.Now().UnixMilli()}} ctx := context.Background() - require.NoError(t, impl.processBatchSpans(ctx, spans, sub)) - require.Equal(t, 1, proc.finishChangeInvoked) + err, shouldFinish := impl.processBatchSpans(ctx, spans, sub) + require.NoError(t, err) + require.False(t, shouldFinish) + require.True(t, proc.invokeCalled) } func TestTraceHubServiceImpl_ProcessBatchSpans_DispatchError(t *testing.T) { @@ -210,12 +221,11 @@ func TestTraceHubServiceImpl_ProcessBatchSpans_DispatchError(t *testing.T) { } mockRepo.EXPECT().GetTaskCount(gomock.Any(), int64(1)).Return(int64(0), nil) - mockRepo.EXPECT().GetTaskRunCount(gomock.Any(), int64(1), int64(10)).Return(int64(0), nil) mockRepo.EXPECT().GetLatestNewDataTaskRun(gomock.Any(), gomock.Nil(), int64(1)).Return(spanRun, nil) spans := []*loop_span.Span{{SpanID: "span-1", StartTime: now.Add(10 * time.Millisecond).UnixMilli(), WorkspaceID: "space", TraceID: "trace"}} - err := impl.processBatchSpans(context.Background(), spans, sub) + err, _ := impl.processBatchSpans(context.Background(), spans, sub) require.Error(t, err) require.ErrorContains(t, err, "invoke fail") } @@ -332,7 +342,6 @@ func TestTraceHubServiceImpl_ListAndSendSpans_Success(t *testing.T) { }) mockTaskRepo.EXPECT().GetTaskCount(gomock.Any(), int64(1)).Return(int64(0), nil) - mockTaskRepo.EXPECT().GetTaskRunCount(gomock.Any(), int64(1), sub.tr.ID).Return(int64(0), nil) mockTaskRepo.EXPECT().GetBackfillTaskRun(gomock.Any(), gomock.Nil(), int64(1)).Return(domainRun, nil) mockTaskRepo.EXPECT().UpdateTaskRunWithOCC(gomock.Any(), sub.tr.ID, sub.tr.WorkspaceID, gomock.AssignableToTypeOf(map[string]interface{}{})).Return(nil) @@ -344,23 +353,35 @@ func TestTraceHubServiceImpl_ListAndSendSpans_Success(t *testing.T) { require.Equal(t, "prev", ptr.From(sub.tr.BackfillDetail.LastSpanPageToken)) } -func TestTraceHubServiceImpl_FetchAndSendSpans_ListError(t *testing.T) { +func TestTraceHubServiceImpl_ListAndSendSpans_ListError(t *testing.T) { ctrl := gomock.NewController(t) t.Cleanup(ctrl.Finish) mockTaskRepo := repo_mocks.NewMockITaskRepo(ctrl) mockTraceRepo := trepo_mocks.NewMockITraceRepo(ctrl) + mockTenant := tenant_mocks.NewMockITenantProvider(ctrl) + mockBuilder := builder_mocks.NewMockTraceFilterProcessorBuilder(ctrl) + filterMock := spanfilter_mocks.NewMockFilter(ctrl) + impl := &TraceHubServiceImpl{ - taskRepo: mockTaskRepo, - traceRepo: mockTraceRepo, + taskRepo: mockTaskRepo, + traceRepo: mockTraceRepo, + tenantProvider: mockTenant, + buildHelper: mockBuilder, } now := time.Now() sub, _ := newBackfillSubscriber(mockTaskRepo, now) + mockBuilder.EXPECT().BuildPlatformRelatedFilter(gomock.Any(), loop_span.PlatformType(common.PlatformTypeCozeBot)). + Return(filterMock, nil) + filterMock.EXPECT().BuildBasicSpanFilter(gomock.Any(), gomock.Any()).Return([]*loop_span.FilterField{}, true, nil) + filterMock.EXPECT().BuildRootSpanFilter(gomock.Any(), gomock.Any()).Return([]*loop_span.FilterField{}, nil) + mockTenant.EXPECT().GetTenantsByPlatformType(gomock.Any(), loop_span.PlatformType(common.PlatformTypeCozeBot)).Return([]string{"tenant"}, nil) + mockTraceRepo.EXPECT().ListSpans(gomock.Any(), gomock.Any()).Return(nil, errors.New("list failed")) - err := impl.fetchAndSendSpans(context.Background(), &repo.ListSpansParam{Tenants: []string{"tenant"}}, sub) + err := impl.listAndSendSpans(context.Background(), sub) require.Error(t, err) } @@ -369,7 +390,18 @@ func TestTraceHubServiceImpl_FlushSpans_ContextCanceled(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) cancel() - err := impl.flushSpans(ctx, &flushReq{}, &spanSubscriber{}) + sub := &spanSubscriber{ + t: &entity.ObservabilityTask{ + ID: 1, + WorkspaceID: 1, + }, + tr: &entity.TaskRun{ + ID: 1, + WorkspaceID: 1, + }, + } + + err, _ := impl.flushSpans(ctx, []*loop_span.Span{}, sub) require.NoError(t, err) } @@ -386,11 +418,10 @@ func TestTraceHubServiceImpl_DoFlush_UpdateTaskRunError(t *testing.T) { domainRun := newDomainBackfillTaskRun(now) mockTaskRepo.EXPECT().GetTaskCount(gomock.Any(), int64(1)).Return(int64(0), nil) - mockTaskRepo.EXPECT().GetTaskRunCount(gomock.Any(), int64(1), sub.tr.ID).Return(int64(0), nil) mockTaskRepo.EXPECT().GetBackfillTaskRun(gomock.Any(), gomock.Nil(), int64(1)).Return(domainRun, nil) mockTaskRepo.EXPECT().UpdateTaskRunWithOCC(gomock.Any(), sub.tr.ID, sub.tr.WorkspaceID, gomock.AssignableToTypeOf(map[string]interface{}{})).Return(errors.New("update fail")) - err := impl.flushSpans(context.Background(), &flushReq{retrievedSpanCount: 1, pageToken: "token", spans: []*loop_span.Span{span}}, sub) + err, _ := impl.flushSpans(context.Background(), []*loop_span.Span{span}, sub) require.Error(t, err) require.ErrorContains(t, err, "update fail") require.True(t, proc.invokeCalled) @@ -410,13 +441,21 @@ func TestTraceHubServiceImpl_DoFlush_NoMoreFinishError(t *testing.T) { domainRun := newDomainBackfillTaskRun(now) mockTaskRepo.EXPECT().GetTaskCount(gomock.Any(), int64(1)).Return(int64(0), nil) - mockTaskRepo.EXPECT().GetTaskRunCount(gomock.Any(), int64(1), sub.tr.ID).Return(int64(0), nil) mockTaskRepo.EXPECT().GetBackfillTaskRun(gomock.Any(), gomock.Nil(), int64(1)).Return(domainRun, nil) mockTaskRepo.EXPECT().UpdateTaskRunWithOCC(gomock.Any(), sub.tr.ID, sub.tr.WorkspaceID, gomock.AssignableToTypeOf(map[string]interface{}{})).Return(nil) - err := impl.flushSpans(context.Background(), &flushReq{retrievedSpanCount: 1, pageToken: "token", spans: []*loop_span.Span{span}, noMore: true}, sub) - require.Error(t, err) - require.ErrorContains(t, err, "finish fail") + // 调用flushSpans,然后手动调用OnTaskFinished来触发finish错误 + err, _ := impl.flushSpans(context.Background(), []*loop_span.Span{span}, sub) + require.NoError(t, err) // flushSpans本身不应该返回错误 + + // 手动调用OnTaskFinished来触发finish错误 + finishErr := sub.processor.OnTaskFinished(context.Background(), taskexe.OnTaskFinishedReq{ + Task: sub.t, + TaskRun: sub.tr, + IsFinish: true, + }) + require.Error(t, finishErr) + require.ErrorContains(t, finishErr, "finish fail") require.True(t, proc.invokeCalled) } @@ -427,9 +466,10 @@ func TestTraceHubServiceImpl_FlushSpans_SamplingZero(t *testing.T) { Sampler: &entity.Sampler{SampleRate: 0}, }, } - fr := &flushReq{retrievedSpanCount: 2, spans: []*loop_span.Span{{SpanID: "s1"}, {SpanID: "s2"}}} + spans := []*loop_span.Span{{SpanID: "s1"}, {SpanID: "s2"}} - require.NoError(t, impl.flushSpans(context.Background(), fr, sub)) + err, _ := impl.flushSpans(context.Background(), spans, sub) + require.NoError(t, err) } func TestTraceHubServiceImpl_IsBackfillDone(t *testing.T) { diff --git a/backend/modules/observability/domain/task/service/taskexe/tracehub/local_cache.go b/backend/modules/observability/domain/task/service/taskexe/tracehub/local_cache.go new file mode 100644 index 000000000..ece5e9eac --- /dev/null +++ b/backend/modules/observability/domain/task/service/taskexe/tracehub/local_cache.go @@ -0,0 +1,54 @@ +// Copyright (c) 2025 coze-dev Authors +// SPDX-License-Identifier: Apache-2.0 + +package tracehub + +import ( + "context" + "sync" + "time" + + "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/entity" + "github.com/coze-dev/coze-loop/backend/pkg/logs" +) + +const CacheKeyObjListWithTask = "ObjListWithTask" + +// TaskCacheInfo represents task cache information +type TaskCacheInfo struct { + WorkspaceIDs []string + BotIDs []string + Tasks []*entity.ObservabilityTask + UpdateTime time.Time +} + +type LocalCache struct { + taskCache sync.Map +} + +func NewLocalCache() *LocalCache { + return &LocalCache{} +} + +func (l *LocalCache) StoneTaskCache(info TaskCacheInfo) { + l.taskCache.Store(CacheKeyObjListWithTask, info) +} + +func (l *LocalCache) LoadTaskCache(ctx context.Context) TaskCacheInfo { + // First, try to retrieve tasks from cache + objListWithTask, ok := l.taskCache.Load(CacheKeyObjListWithTask) + if !ok { + // Cache is empty, fallback to the database + logs.CtxError(ctx, "Cache is empty, retrieving task list from database") + return TaskCacheInfo{} + } + + cacheInfo, ok := objListWithTask.(TaskCacheInfo) + if !ok { + logs.CtxError(ctx, "Cache data type mismatch") + return TaskCacheInfo{} + } + + logs.CtxInfo(ctx, "Retrieve task list from cache, taskCount=%d, spaceCount=%d, botCount=%d", len(cacheInfo.Tasks), len(cacheInfo.WorkspaceIDs), len(cacheInfo.BotIDs)) + return cacheInfo +} diff --git a/backend/modules/observability/domain/task/service/taskexe/tracehub/mocks/trace_hub_service.go b/backend/modules/observability/domain/task/service/taskexe/tracehub/mocks/trace_hub_service.go index 50fb666f5..db391b2af 100644 --- a/backend/modules/observability/domain/task/service/taskexe/tracehub/mocks/trace_hub_service.go +++ b/backend/modules/observability/domain/task/service/taskexe/tracehub/mocks/trace_hub_service.go @@ -14,6 +14,7 @@ import ( reflect "reflect" entity "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/entity" + loop_span "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/entity/loop_span" gomock "go.uber.org/mock/gomock" ) @@ -56,15 +57,15 @@ func (mr *MockITraceHubServiceMockRecorder) BackFill(ctx, event any) *gomock.Cal } // SpanTrigger mocks base method. -func (m *MockITraceHubService) SpanTrigger(ctx context.Context, event *entity.RawSpan) error { +func (m *MockITraceHubService) SpanTrigger(ctx context.Context, span *loop_span.Span) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "SpanTrigger", ctx, event) + ret := m.ctrl.Call(m, "SpanTrigger", ctx, span) ret0, _ := ret[0].(error) return ret0 } // SpanTrigger indicates an expected call of SpanTrigger. -func (mr *MockITraceHubServiceMockRecorder) SpanTrigger(ctx, event any) *gomock.Call { +func (mr *MockITraceHubServiceMockRecorder) SpanTrigger(ctx, span any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SpanTrigger", reflect.TypeOf((*MockITraceHubService)(nil).SpanTrigger), ctx, event) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SpanTrigger", reflect.TypeOf((*MockITraceHubService)(nil).SpanTrigger), ctx, span) } diff --git a/backend/modules/observability/domain/task/service/taskexe/tracehub/scheduled_task.go b/backend/modules/observability/domain/task/service/taskexe/tracehub/scheduled_task.go index 8c544bf3c..b2a7d75d6 100755 --- a/backend/modules/observability/domain/task/service/taskexe/tracehub/scheduled_task.go +++ b/backend/modules/observability/domain/task/service/taskexe/tracehub/scheduled_task.go @@ -11,7 +11,6 @@ import ( "strconv" "time" - "github.com/coze-dev/coze-loop/backend/modules/observability/domain/component/config" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/entity" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/repo" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/service/taskexe" @@ -31,14 +30,6 @@ type TaskRunCountInfo struct { TaskRunFailCount int64 } -// TaskCacheInfo represents task cache information -type TaskCacheInfo struct { - WorkspaceIDs []string - BotIDs []string - Tasks []*entity.ObservabilityTask - UpdateTime time.Time -} - const ( transformTaskStatusLockKey = "observability:tracehub:transform_task_status" transformTaskStatusLockTTL = 3 * time.Minute @@ -78,9 +69,11 @@ func (h *TraceHubServiceImpl) startScheduledTask() { } func (h *TraceHubServiceImpl) transformTaskStatus() { - const key = "consumer_listening" - cfg := &config.ConsumerListening{} - if err := h.loader.UnmarshalKey(context.Background(), key, cfg); err != nil { + ctx := context.Background() + ctx = h.fillCtx(ctx) + + cfg, err := h.config.GetConsumerListening(ctx) + if err != nil { return } if !cfg.IsEnabled || !cfg.IsAllSpace { @@ -90,8 +83,6 @@ func (h *TraceHubServiceImpl) transformTaskStatus() { if slices.Contains([]string{TracehubClusterName, InjectClusterName}, os.Getenv(TceCluster)) { return } - ctx := context.Background() - ctx = h.fillCtx(ctx) if h.locker != nil { locked, lockErr := h.locker.Lock(ctx, transformTaskStatusLockKey, transformTaskStatusLockTTL) @@ -325,22 +316,12 @@ func (h *TraceHubServiceImpl) syncTaskCache() { } logs.CtxInfo(ctx, "Retrieved task information, taskCount:%d, spaceCount:%d, botCount:%d", len(tasks), len(spaceIDs), len(botIDs)) - // 2. Build a new cache map - newCache := TaskCacheInfo{ + h.localCache.StoneTaskCache(TaskCacheInfo{ WorkspaceIDs: spaceIDs, BotIDs: botIDs, Tasks: tasks, UpdateTime: time.Now(), // Set the current time as the update time - } - - // 3. Clear old cache and update with new cache - h.taskCacheLock.Lock() - defer h.taskCacheLock.Unlock() - - // 4. Write new cache into local cache - h.taskCache.Store("ObjListWithTask", newCache) - - logs.CtxInfo(ctx, "Task cache sync completed, taskCount:%d, updateTime:%s", len(tasks), newCache.UpdateTime.Format(time.RFC3339)) + }) } // processBatch synchronizes TaskRun counts in batches diff --git a/backend/modules/observability/domain/task/service/taskexe/tracehub/span_trigger.go b/backend/modules/observability/domain/task/service/taskexe/tracehub/span_trigger.go index d63cab226..cee637a55 100644 --- a/backend/modules/observability/domain/task/service/taskexe/tracehub/span_trigger.go +++ b/backend/modules/observability/domain/task/service/taskexe/tracehub/span_trigger.go @@ -9,8 +9,6 @@ import ( "time" "github.com/bytedance/gg/gslice" - "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/observability/domain/task" - "github.com/coze-dev/coze-loop/backend/modules/observability/domain/component/config" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/entity" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/service/taskexe" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/entity/loop_span" @@ -19,57 +17,62 @@ import ( "github.com/pkg/errors" ) -func (h *TraceHubServiceImpl) SpanTrigger(ctx context.Context, rawSpan *entity.RawSpan) error { +func (h *TraceHubServiceImpl) SpanTrigger(ctx context.Context, span *loop_span.Span) error { ctx = h.fillCtx(ctx) - logSuffix := fmt.Sprintf("log_id=%s, trace_id=%s, span_id=%s", rawSpan.LogID, rawSpan.TraceID, rawSpan.SpanID) - logs.CtxInfo(ctx, "auto_task start, log_suffix=%s", logSuffix) - // 1、Convert to standard span and perform initial filtering based on space_id - span := rawSpan.RawSpanConvertToLoopSpan() + logSuffix := fmt.Sprintf("log_id=%s, trace_id=%s, span_id=%s", span.LogID, span.TraceID, span.SpanID) + logs.CtxInfo(ctx, "auto_task start, %s", logSuffix) + + // 1. perform initial filtering based on space_id // 1.1 Filter out spans that do not belong to any space or bot - spaceIDs, botIDs, _ := h.getObjListWithTaskFromCache(ctx) + cacheInfo := h.localCache.LoadTaskCache(ctx) + spaceIDs, botIDs := cacheInfo.WorkspaceIDs, cacheInfo.BotIDs if !gslice.Contains(spaceIDs, span.WorkspaceID) && !gslice.Contains(botIDs, span.TagsString["bot_id"]) { - logs.CtxInfo(ctx, "no space or bot found for span, space_id=%s,bot_id=%s, log_suffix=%s", span.WorkspaceID, span.TagsString["bot_id"], logSuffix) + logs.CtxInfo(ctx, "no space or bot found for span, space_id=%s, bot_id=%s, %s", span.WorkspaceID, span.TagsString["bot_id"], logSuffix) return nil } // 1.2 Filter out spans of type Evaluator - if gslice.Contains([]string{"Evaluator"}, span.CallType) { + if gslice.Contains([]string{loop_span.CallTypeEvaluator}, span.CallType) { return nil } + // 2、Match spans against task rules - subs, err := h.getSubscriberOfSpan(ctx, span) + subs, err := h.buildSubscriberOfSpan(ctx, span) if err != nil { logs.CtxWarn(ctx, "get subscriber of flow span failed, %s, err: %v", logSuffix, err) + return err } logs.CtxInfo(ctx, "%d subscriber of flow span found, %s", len(subs), logSuffix) if len(subs) == 0 { return nil } + // 3、Sample subs = gslice.Filter(subs, func(sub *spanSubscriber) bool { return sub.Sampled() }) logs.CtxInfo(ctx, "%d subscriber of flow span sampled, %s", len(subs), logSuffix) if len(subs) == 0 { return nil } - // 3. PreDispatch - err = h.preDispatch(ctx, span, subs) - if err != nil { + + // 4. PreDispatch + if err = h.preDispatch(ctx, subs); err != nil { logs.CtxWarn(ctx, "preDispatch flow span failed, %s, err: %v", logSuffix, err) + return err } logs.CtxInfo(ctx, "%d preDispatch success, %v", len(subs), subs) - // 4、Dispatch + + // 5、Dispatch if err = h.dispatch(ctx, span, subs); err != nil { logs.CtxError(ctx, "dispatch flow span failed, %s, err: %v", logSuffix, err) - // Dispatch failed, continue to the next span - return nil + return err } return nil } -func (h *TraceHubServiceImpl) getSubscriberOfSpan(ctx context.Context, span *loop_span.Span) ([]*spanSubscriber, error) { - const key = "consumer_listening" - cfg := &config.ConsumerListening{} - if err := h.loader.UnmarshalKey(ctx, key, cfg); err != nil { +func (h *TraceHubServiceImpl) buildSubscriberOfSpan(ctx context.Context, span *loop_span.Span) ([]*spanSubscriber, error) { + cfg, err := h.config.GetConsumerListening(ctx) + if err != nil { + logs.CtxError(ctx, "Failed to get consumer listening config, err: %v", err) return nil, err } @@ -83,7 +86,15 @@ func (h *TraceHubServiceImpl) getSubscriberOfSpan(ctx context.Context, span *loo if !cfg.IsAllSpace && !gslice.Contains(cfg.SpaceList, taskDO.WorkspaceID) { continue } - proc := h.taskProcessor.GetTaskProcessor(entity.TaskType(taskDO.TaskType)) + if taskDO.EffectiveTime == nil || taskDO.EffectiveTime.StartAt == 0 { + continue + } + if span.StartTime < taskDO.EffectiveTime.StartAt { + logs.CtxInfo(ctx, "span start time is before task cycle start time, trace_id=%s, span_id=%s", span.TraceID, span.SpanID) + continue + } + + proc := h.taskProcessor.GetTaskProcessor(taskDO.TaskType) subscribers = append(subscribers, &spanSubscriber{ taskID: taskDO.ID, t: taskDO, @@ -114,16 +125,9 @@ func (h *TraceHubServiceImpl) getSubscriberOfSpan(ctx context.Context, span *loo return subscribers[:keep], merr.ErrorOrNil() } -func (h *TraceHubServiceImpl) preDispatch(ctx context.Context, span *loop_span.Span, subs []*spanSubscriber) error { +func (h *TraceHubServiceImpl) preDispatch(ctx context.Context, subs []*spanSubscriber) error { merr := &multierror.Error{} for _, sub := range subs { - if sub.t.EffectiveTime == nil || sub.t.EffectiveTime.StartAt == 0 { - continue - } - if span.StartTime < sub.t.EffectiveTime.StartAt { - logs.CtxWarn(ctx, "span start time is before task cycle start time, trace_id=%s, span_id=%s", span.TraceID, span.SpanID) - continue - } // First step: lock for task status change // Task run status var runStartAt, runEndAt int64 @@ -164,9 +168,9 @@ func (h *TraceHubServiceImpl) preDispatch(ctx context.Context, span *loop_span.S runEndAt = sub.t.EffectiveTime.EndAt } else { switch sub.t.Sampler.CycleTimeUnit { - case task.TimeUnitDay: + case entity.TimeUnitDay: runEndAt = runStartAt + sub.t.Sampler.CycleInterval*24*time.Hour.Milliseconds() - case task.TimeUnitWeek: + case entity.TimeUnitWeek: runEndAt = runStartAt + sub.t.Sampler.CycleInterval*7*24*time.Hour.Milliseconds() default: runEndAt = runStartAt + sub.t.Sampler.CycleInterval*10*time.Minute.Milliseconds() @@ -248,7 +252,7 @@ func (h *TraceHubServiceImpl) preDispatch(ctx context.Context, span *loop_span.S func (h *TraceHubServiceImpl) dispatch(ctx context.Context, span *loop_span.Span, subs []*spanSubscriber) error { merr := &multierror.Error{} for _, sub := range subs { - if sub.t.TaskStatus != task.TaskStatusRunning { + if sub.t.TaskStatus != entity.TaskStatusRunning { continue } logs.CtxInfo(ctx, " sub.AddSpan: %v", sub) @@ -262,23 +266,3 @@ func (h *TraceHubServiceImpl) dispatch(ctx context.Context, span *loop_span.Span } return merr.ErrorOrNil() } - -// getObjListWithTaskFromCache retrieves the task list from cache, falling back to the database if cache is empty -func (h *TraceHubServiceImpl) getObjListWithTaskFromCache(ctx context.Context) ([]string, []string, []*entity.ObservabilityTask) { - // First, try to retrieve tasks from cache - objListWithTask, ok := h.taskCache.Load("ObjListWithTask") - if !ok { - // Cache is empty, fallback to the database - logs.CtxError(ctx, "Cache is empty, retrieving task list from database") - return nil, nil, nil - } - - cacheInfo, ok := objListWithTask.(TaskCacheInfo) - if !ok { - logs.CtxError(ctx, "Cache data type mismatch") - return nil, nil, nil - } - - logs.CtxInfo(ctx, "Retrieve task list from cache, taskCount=%d, spaceCount=%d, botCount=%d", len(cacheInfo.Tasks), len(cacheInfo.WorkspaceIDs), len(cacheInfo.BotIDs)) - return cacheInfo.WorkspaceIDs, cacheInfo.BotIDs, cacheInfo.Tasks -} diff --git a/backend/modules/observability/domain/task/service/taskexe/tracehub/span_trigger_test.go b/backend/modules/observability/domain/task/service/taskexe/tracehub/span_trigger_test.go index 77a583346..2e4a8913d 100755 --- a/backend/modules/observability/domain/task/service/taskexe/tracehub/span_trigger_test.go +++ b/backend/modules/observability/domain/task/service/taskexe/tracehub/span_trigger_test.go @@ -46,7 +46,7 @@ func TestTraceHubServiceImpl_SpanTriggerSkipNoWorkspace(t *testing.T) { ServerEnv: &entity.ServerInRawSpan{}, } - require.NoError(t, impl.SpanTrigger(context.Background(), raw)) + require.NoError(t, impl.SpanTrigger(context.Background(), raw.RawSpanConvertToLoopSpan())) } func TestTraceHubServiceImpl_SpanTriggerDispatchError(t *testing.T) { @@ -153,7 +153,7 @@ func TestTraceHubServiceImpl_SpanTriggerDispatchError(t *testing.T) { ServerEnv: &entity.ServerInRawSpan{}, } - err := impl.SpanTrigger(context.Background(), raw) + err := impl.SpanTrigger(context.Background(), raw.RawSpanConvertToLoopSpan()) require.NoError(t, err) require.True(t, proc.invokeCalled) } diff --git a/backend/modules/observability/domain/task/service/taskexe/tracehub/subscriber.go b/backend/modules/observability/domain/task/service/taskexe/tracehub/subscriber.go index ca4fff295..13f4f0a39 100644 --- a/backend/modules/observability/domain/task/service/taskexe/tracehub/subscriber.go +++ b/backend/modules/observability/domain/task/service/taskexe/tracehub/subscriber.go @@ -6,7 +6,6 @@ package tracehub import ( "context" "math/rand" - "sync" "time" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/entity" @@ -22,8 +21,6 @@ import ( ) type spanSubscriber struct { - sync.RWMutex // protect t, buf - taskID int64 t *entity.ObservabilityTask tr *entity.TaskRun @@ -36,23 +33,16 @@ type spanSubscriber struct { // Sampled determines whether a span is sampled based on the sampling rate; the sample size will be validated during flush. func (s *spanSubscriber) Sampled() bool { - t := s.getTask() - if t == nil || t.Sampler == nil { + if s.t == nil || s.t.Sampler == nil { return false } const base = 10000 - threshold := int64(float64(base) * t.Sampler.SampleRate) + threshold := int64(float64(base) * s.t.Sampler.SampleRate) r := rand.Int63n(base) return r <= threshold } -func (s *spanSubscriber) getTask() *entity.ObservabilityTask { - s.RLock() - defer s.RUnlock() - return s.t -} - func combineFilters(filters ...*loop_span.FilterFields) *loop_span.FilterFields { filterAggr := &loop_span.FilterFields{ QueryAndOr: ptr.Of(loop_span.QueryAndOrEnumAnd), diff --git a/backend/modules/observability/domain/task/service/taskexe/tracehub/trace_hub.go b/backend/modules/observability/domain/task/service/taskexe/tracehub/trace_hub.go index 9dbb313b7..6c9d5feb8 100644 --- a/backend/modules/observability/domain/task/service/taskexe/tracehub/trace_hub.go +++ b/backend/modules/observability/domain/task/service/taskexe/tracehub/trace_hub.go @@ -5,10 +5,10 @@ package tracehub import ( "context" - "sync" "time" "github.com/coze-dev/coze-loop/backend/infra/lock" + "github.com/coze-dev/coze-loop/backend/modules/observability/domain/component/config" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/component/mq" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/component/tenant" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/entity" @@ -17,13 +17,12 @@ import ( "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/entity/loop_span" trace_repo "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/repo" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/service" - "github.com/coze-dev/coze-loop/backend/pkg/conf" ) //go:generate mockgen -destination=mocks/trace_hub_service.go -package=mocks . ITraceHubService type ITraceHubService interface { - SpanTrigger(ctx context.Context, event *entity.RawSpan) error + SpanTrigger(ctx context.Context, span *loop_span.Span) error BackFill(ctx context.Context, event *entity.BackFillEvent) error } @@ -36,7 +35,7 @@ func NewTraceHubImpl( aid int32, backfillProducer mq.IBackfillProducer, locker lock.ILocker, - loader conf.IConfigLoader, + config config.ITraceConfig, ) (ITraceHubService, error) { // Create two independent timers with different intervals scheduledTaskTicker := time.NewTicker(5 * time.Minute) // Task status lifecycle management - 5-minute interval @@ -53,7 +52,8 @@ func NewTraceHubImpl( aid: aid, backfillProducer: backfillProducer, locker: locker, - loader: loader, + config: config, + localCache: NewLocalCache(), } // Start the scheduled tasks immediately @@ -74,22 +74,14 @@ type TraceHubServiceImpl struct { buildHelper service.TraceFilterProcessorBuilder backfillProducer mq.IBackfillProducer locker lock.ILocker - loader conf.IConfigLoader + config config.ITraceConfig // Local cache - caching non-terminal task information - taskCache sync.Map - taskCacheLock sync.RWMutex + localCache *LocalCache aid int32 } -type flushReq struct { - retrievedSpanCount int64 - pageToken string - spans []*loop_span.Span - noMore bool -} - func (h *TraceHubServiceImpl) Close() { close(h.stopChan) } diff --git a/backend/modules/observability/domain/trace/entity/loop_span/span.go b/backend/modules/observability/domain/trace/entity/loop_span/span.go index 70f37e45b..9e99ef940 100644 --- a/backend/modules/observability/domain/trace/entity/loop_span/span.go +++ b/backend/modules/observability/domain/trace/entity/loop_span/span.go @@ -77,6 +77,8 @@ const ( MaxKeySize = 100 MaxTextSize = 1024 * 1024 MaxCommonValueSize = 1024 + + CallTypeEvaluator = "Evaluator" ) type TTL string diff --git a/backend/modules/observability/infra/config/trace.go b/backend/modules/observability/infra/config/trace.go index 6732147fb..49cf7bc4d 100644 --- a/backend/modules/observability/infra/config/trace.go +++ b/backend/modules/observability/infra/config/trace.go @@ -26,6 +26,7 @@ const ( queryTraceRateLimitCfgKey = "query_trace_rate_limit_config" keySpanTypeCfgKey = "key_span_type" backfillMqProducerCfgKey = "backfill_mq_producer_config" + consumerListeningCfgKey = "consumer_listening" ) type TraceConfigCenter struct { @@ -171,6 +172,14 @@ func (t *TraceConfigCenter) GetKeySpanTypes(ctx context.Context) map[string][]st return keyColumns } +func (t *TraceConfigCenter) GetConsumerListening(ctx context.Context) (*config.ConsumerListening, error) { + consumerListening := new(config.ConsumerListening) + if err := t.UnmarshalKey(ctx, consumerListeningCfgKey, &consumerListening); err != nil { + return nil, err + } + return consumerListening, nil +} + func NewTraceConfigCenter(confP conf.IConfigLoader) config.ITraceConfig { ret := &TraceConfigCenter{ IConfigLoader: confP, From eddb1b32ba87c2965dfe27aefea00b8f51574de0 Mon Sep 17 00:00:00 2001 From: "zhaoxun.3233" Date: Fri, 7 Nov 2025 18:26:09 +0800 Subject: [PATCH 10/19] add topic proc --- .../modules/observability/application/task.go | 7 +- .../modules/observability/application/wire.go | 1 + .../observability/application/wire_gen.go | 2 +- .../domain/component/config/config.go | 1 + .../domain/component/mq/span_producer.go | 15 ++++ .../domain/trace/entity/event.go | 4 + .../observability/infra/config/trace.go | 37 ++++---- .../infra/mq/consumer/consumer.go | 1 + .../consumer/span_with_annotation_consumer.go | 58 +++++++++++++ .../infra/mq/consumer/task_consumer.go | 2 +- .../producer/span_with_annotation_producer.go | 84 +++++++++++++++++++ 11 files changed, 193 insertions(+), 19 deletions(-) create mode 100644 backend/modules/observability/domain/component/mq/span_producer.go create mode 100644 backend/modules/observability/infra/mq/consumer/span_with_annotation_consumer.go create mode 100644 backend/modules/observability/infra/mq/producer/span_with_annotation_producer.go diff --git a/backend/modules/observability/application/task.go b/backend/modules/observability/application/task.go index ef7120e2c..87730faef 100644 --- a/backend/modules/observability/application/task.go +++ b/backend/modules/observability/application/task.go @@ -17,6 +17,7 @@ import ( "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/service" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/service/taskexe/processor" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/service/taskexe/tracehub" + "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/entity/loop_span" obErrorx "github.com/coze-dev/coze-loop/backend/modules/observability/pkg/errno" "github.com/coze-dev/coze-loop/backend/pkg/errorx" "github.com/coze-dev/coze-loop/backend/pkg/logs" @@ -24,7 +25,7 @@ import ( ) type ITaskQueueConsumer interface { - SpanTrigger(ctx context.Context, event *entity.RawSpan) error + SpanTrigger(ctx context.Context, rawSpan *entity.RawSpan, Span *loop_span.Span) error AutoEvalCallback(ctx context.Context, event *entity.AutoEvalEvent) error AutoEvalCorrection(ctx context.Context, event *entity.CorrectionEvent) error BackFill(ctx context.Context, event *entity.BackFillEvent) error @@ -260,8 +261,8 @@ func (t *TaskApplication) GetTask(ctx context.Context, req *task.GetTaskRequest) }, nil } -func (t *TaskApplication) SpanTrigger(ctx context.Context, event *entity.RawSpan) error { - span := event.RawSpanConvertToLoopSpan() +func (t *TaskApplication) SpanTrigger(ctx context.Context, rawSpan *entity.RawSpan, Span *loop_span.Span) error { + span := rawSpan.RawSpanConvertToLoopSpan() if span != nil { if err := t.tracehubSvc.SpanTrigger(ctx, span); err != nil { logs.CtxError(ctx, "SpanTrigger err:%v", err) diff --git a/backend/modules/observability/application/wire.go b/backend/modules/observability/application/wire.go index ec9b6e356..c907da708 100644 --- a/backend/modules/observability/application/wire.go +++ b/backend/modules/observability/application/wire.go @@ -92,6 +92,7 @@ var ( obcollector.NewEventCollectorProvider, mq2.NewTraceProducerImpl, mq2.NewAnnotationProducerImpl, + mq2.NewSpanWithAnnotationProducerImpl, file.NewFileRPCProvider, NewTraceConfigLoader, NewTraceProcessorBuilder, diff --git a/backend/modules/observability/application/wire_gen.go b/backend/modules/observability/application/wire_gen.go index 7101caf66..03a6666c4 100644 --- a/backend/modules/observability/application/wire_gen.go +++ b/backend/modules/observability/application/wire_gen.go @@ -299,7 +299,7 @@ var ( taskDomainSet = wire.NewSet( NewInitTaskProcessor, service3.NewTaskServiceImpl, repo.NewTaskRepoImpl, mysql.NewTaskDaoImpl, redis2.NewTaskDAO, redis2.NewTaskRunDAO, mysql.NewTaskRunDaoImpl, producer.NewBackfillProducerImpl, ) - traceDomainSet = wire.NewSet(service.NewTraceServiceImpl, service.NewTraceExportServiceImpl, repo.NewTraceCKRepoImpl, ck2.NewSpansCkDaoImpl, ck2.NewAnnotationCkDaoImpl, metrics2.NewTraceMetricsImpl, collector.NewEventCollectorProvider, producer.NewTraceProducerImpl, producer.NewAnnotationProducerImpl, file.NewFileRPCProvider, NewTraceConfigLoader, + traceDomainSet = wire.NewSet(service.NewTraceServiceImpl, service.NewTraceExportServiceImpl, repo.NewTraceCKRepoImpl, ck2.NewSpansCkDaoImpl, ck2.NewAnnotationCkDaoImpl, metrics2.NewTraceMetricsImpl, collector.NewEventCollectorProvider, producer.NewTraceProducerImpl, producer.NewAnnotationProducerImpl, producer.NewSpanWithAnnotationProducerImpl, file.NewFileRPCProvider, NewTraceConfigLoader, NewTraceProcessorBuilder, config.NewTraceConfigCenter, tenant.NewTenantProvider, workspace.NewWorkspaceProvider, evaluator.NewEvaluatorRPCProvider, NewDatasetServiceAdapter, taskDomainSet, ) diff --git a/backend/modules/observability/domain/component/config/config.go b/backend/modules/observability/domain/component/config/config.go index cb0a29255..c254410e5 100644 --- a/backend/modules/observability/domain/component/config/config.go +++ b/backend/modules/observability/domain/component/config/config.go @@ -129,6 +129,7 @@ type ITraceConfig interface { GetKeySpanTypes(ctx context.Context) map[string][]string GetBackfillMqProducerCfg(ctx context.Context) (*MqProducerCfg, error) GetConsumerListening(ctx context.Context) (*ConsumerListening, error) + GetSpanWithAnnotationMqProducerCfg(ctx context.Context) (*MqProducerCfg, error) conf.IConfigLoader } diff --git a/backend/modules/observability/domain/component/mq/span_producer.go b/backend/modules/observability/domain/component/mq/span_producer.go new file mode 100644 index 000000000..2ace326c1 --- /dev/null +++ b/backend/modules/observability/domain/component/mq/span_producer.go @@ -0,0 +1,15 @@ +// Copyright (c) 2025 coze-dev Authors +// SPDX-License-Identifier: Apache-2.0 + +package mq + +import ( + "context" + + "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/entity" +) + +//go:generate mockgen -destination=mocks/annotation_producer.go -package=mocks . IAnnotationProducer +type ISpanProducer interface { + SendSpanWithAnnotation(ctx context.Context, message *entity.SpanEvent) error +} diff --git a/backend/modules/observability/domain/trace/entity/event.go b/backend/modules/observability/domain/trace/entity/event.go index 7cf169f8e..ccbcebc56 100644 --- a/backend/modules/observability/domain/trace/entity/event.go +++ b/backend/modules/observability/domain/trace/entity/event.go @@ -14,3 +14,7 @@ type AnnotationEvent struct { Caller string `json:"caller"` RetryTimes int64 `json:"retry_times"` } + +type SpanEvent struct { + Span *loop_span.Span `json:"span"` +} diff --git a/backend/modules/observability/infra/config/trace.go b/backend/modules/observability/infra/config/trace.go index 49cf7bc4d..fa9f63817 100644 --- a/backend/modules/observability/infra/config/trace.go +++ b/backend/modules/observability/infra/config/trace.go @@ -13,20 +13,21 @@ import ( ) const ( - systemViewsCfgKey = "trace_system_view_cfg" - platformTenantCfgKey = "trace_platform_tenants" - platformSpanHandlerCfgKey = "trace_platform_span_handler_config" - traceIngestTenantCfgKey = "trace_ingest_tenant_config" - annotationMqProducerCfgKey = "annotation_mq_producer_config" - tenantTablesCfgKey = "trace_tenant_cfg" - traceCkCfgKey = "trace_ck_cfg" - traceFieldMetaInfoCfgKey = "trace_field_meta_info" - traceMaxDurationDay = "trace_max_duration_day" - annotationSourceCfgKey = "annotation_source_cfg" - queryTraceRateLimitCfgKey = "query_trace_rate_limit_config" - keySpanTypeCfgKey = "key_span_type" - backfillMqProducerCfgKey = "backfill_mq_producer_config" - consumerListeningCfgKey = "consumer_listening" + systemViewsCfgKey = "trace_system_view_cfg" + platformTenantCfgKey = "trace_platform_tenants" + platformSpanHandlerCfgKey = "trace_platform_span_handler_config" + traceIngestTenantCfgKey = "trace_ingest_tenant_config" + annotationMqProducerCfgKey = "annotation_mq_producer_config" + spanWithAnnotationMqProducerCfgKey = "span_with_annotation_mq_producer_config" + tenantTablesCfgKey = "trace_tenant_cfg" + traceCkCfgKey = "trace_ck_cfg" + traceFieldMetaInfoCfgKey = "trace_field_meta_info" + traceMaxDurationDay = "trace_max_duration_day" + annotationSourceCfgKey = "annotation_source_cfg" + queryTraceRateLimitCfgKey = "query_trace_rate_limit_config" + keySpanTypeCfgKey = "key_span_type" + backfillMqProducerCfgKey = "backfill_mq_producer_config" + consumerListeningCfgKey = "consumer_listening" ) type TraceConfigCenter struct { @@ -75,6 +76,14 @@ func (t *TraceConfigCenter) GetAnnotationMqProducerCfg(ctx context.Context) (*co return cfg, nil } +func (t *TraceConfigCenter) GetSpanWithAnnotationMqProducerCfg(ctx context.Context) (*config.MqProducerCfg, error) { + cfg := new(config.MqProducerCfg) + if err := t.UnmarshalKey(context.Background(), spanWithAnnotationMqProducerCfgKey, cfg); err != nil { + return nil, err + } + return cfg, nil +} + func (t *TraceConfigCenter) GetBackfillMqProducerCfg(ctx context.Context) (*config.MqProducerCfg, error) { cfg := new(config.MqProducerCfg) if err := t.UnmarshalKey(context.Background(), backfillMqProducerCfgKey, cfg); err != nil { diff --git a/backend/modules/observability/infra/mq/consumer/consumer.go b/backend/modules/observability/infra/mq/consumer/consumer.go index c1a5f4b21..dd80d50ce 100644 --- a/backend/modules/observability/infra/mq/consumer/consumer.go +++ b/backend/modules/observability/infra/mq/consumer/consumer.go @@ -21,6 +21,7 @@ func NewConsumerWorkers( newCallbackConsumer(taskConsumer, loader), newCorrectionConsumer(taskConsumer, loader), newBackFillConsumer(taskConsumer, loader), + newSpanWithAnnotationConsumer(handler, loader), ) return workers, nil diff --git a/backend/modules/observability/infra/mq/consumer/span_with_annotation_consumer.go b/backend/modules/observability/infra/mq/consumer/span_with_annotation_consumer.go new file mode 100644 index 000000000..11c401b98 --- /dev/null +++ b/backend/modules/observability/infra/mq/consumer/span_with_annotation_consumer.go @@ -0,0 +1,58 @@ +// Copyright (c) 2025 coze-dev Authors +// SPDX-License-Identifier: Apache-2.0 + +package consumer + +import ( + "context" + "time" + + "github.com/coze-dev/coze-loop/backend/infra/mq" + obapp "github.com/coze-dev/coze-loop/backend/modules/observability/application" + "github.com/coze-dev/coze-loop/backend/modules/observability/domain/component/config" + "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/entity/loop_span" + "github.com/coze-dev/coze-loop/backend/pkg/conf" + "github.com/coze-dev/coze-loop/backend/pkg/json" + "github.com/coze-dev/coze-loop/backend/pkg/lang/conv" + "github.com/coze-dev/coze-loop/backend/pkg/logs" +) + +type SpanWithAnnotationConsumer struct { + handler obapp.ITaskQueueConsumer + conf.IConfigLoader +} + +func newSpanWithAnnotationConsumer(handler obapp.ITaskQueueConsumer, loader conf.IConfigLoader) mq.IConsumerWorker { + return &SpanWithAnnotationConsumer{ + handler: handler, + IConfigLoader: loader, + } +} + +func (e *SpanWithAnnotationConsumer) ConsumerCfg(ctx context.Context) (*mq.ConsumerConfig, error) { + const key = "span_with_annotation_mq_consumer_config" + cfg := &config.MqConsumerCfg{} + if err := e.UnmarshalKey(ctx, key, cfg); err != nil { + return nil, err + } + res := &mq.ConsumerConfig{ + Addr: cfg.Addr, + Topic: cfg.Topic, + ConsumerGroup: cfg.ConsumerGroup, + ConsumeTimeout: time.Duration(cfg.Timeout) * time.Millisecond, + ConsumeGoroutineNums: cfg.WorkerNum, + } + return res, nil +} + +func (e *SpanWithAnnotationConsumer) HandleMessage(ctx context.Context, ext *mq.MessageExt) error { + logID := logs.NewLogID() + ctx = logs.SetLogID(ctx, logID) + event := new(loop_span.Span) + if err := json.Unmarshal(ext.Body, event); err != nil { + logs.CtxError(ctx, "Task msg json unmarshal fail, raw: %v, err: %s", conv.UnsafeBytesToString(ext.Body), err) + return nil + } + logs.CtxInfo(ctx, "Span with annotation msg,log_id=%s, trace_id=%s, span_id=%s,msgID=%s", event.LogID, event.TraceID, event.SpanID, ext.MsgID) + return e.handler.SpanTrigger(ctx, nil, event) +} diff --git a/backend/modules/observability/infra/mq/consumer/task_consumer.go b/backend/modules/observability/infra/mq/consumer/task_consumer.go index 048793ddd..5c1d3567e 100644 --- a/backend/modules/observability/infra/mq/consumer/task_consumer.go +++ b/backend/modules/observability/infra/mq/consumer/task_consumer.go @@ -60,5 +60,5 @@ func (e *TaskConsumer) HandleMessage(ctx context.Context, ext *mq.MessageExt) er return nil } logs.CtxInfo(ctx, "Span msg,log_id=%s, trace_id=%s, span_id=%s,msgID=%s", event.LogID, event.TraceID, event.SpanID, ext.MsgID) - return e.handler.SpanTrigger(ctx, event) + return e.handler.SpanTrigger(ctx, event, nil) } diff --git a/backend/modules/observability/infra/mq/producer/span_with_annotation_producer.go b/backend/modules/observability/infra/mq/producer/span_with_annotation_producer.go new file mode 100644 index 000000000..c20dd2831 --- /dev/null +++ b/backend/modules/observability/infra/mq/producer/span_with_annotation_producer.go @@ -0,0 +1,84 @@ +// Copyright (c) 2025 coze-dev Authors +// SPDX-License-Identifier: Apache-2.0 + +package producer + +import ( + "context" + "fmt" + "sync" + "time" + + "github.com/coze-dev/coze-loop/backend/infra/mq" + "github.com/coze-dev/coze-loop/backend/modules/observability/domain/component/config" + mq2 "github.com/coze-dev/coze-loop/backend/modules/observability/domain/component/mq" + "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/entity" + obErrorx "github.com/coze-dev/coze-loop/backend/modules/observability/pkg/errno" + "github.com/coze-dev/coze-loop/backend/pkg/errorx" + "github.com/coze-dev/coze-loop/backend/pkg/json" + "github.com/coze-dev/coze-loop/backend/pkg/lang/ptr" + "github.com/coze-dev/coze-loop/backend/pkg/logs" +) + +var ( + spanWithAnnotationProducerOnce sync.Once + singletonSpanWithAnnotationProducer mq2.ISpanProducer +) + +type SpanWithAnnotationProducerImpl struct { + topic string + mqProducer mq.IProducer +} + +func (a *SpanWithAnnotationProducerImpl) SendSpanWithAnnotation(ctx context.Context, message *entity.SpanEvent) error { + bytes, err := json.Marshal(message) + if err != nil { + return errorx.WrapByCode(err, obErrorx.CommercialCommonInternalErrorCodeCode) + } + msg := mq.NewDeferMessage(a.topic, 10*time.Second, bytes) + _, err = a.mqProducer.Send(ctx, msg) + if err != nil { + logs.CtxWarn(ctx, "send annotation msg err: %v", err) + return errorx.WrapByCode(err, obErrorx.CommercialCommonRPCErrorCodeCode) + } + logs.CtxInfo(ctx, "send annotation msg %s successfully", string(bytes)) + return nil +} + +func NewSpanWithAnnotationProducerImpl(traceConfig config.ITraceConfig, mqFactory mq.IFactory) (mq2.ISpanProducer, error) { + var err error + spanWithAnnotationProducerOnce.Do(func() { + singletonSpanWithAnnotationProducer, err = newSpanWithAnnotationProducerImpl(traceConfig, mqFactory) + }) + if err != nil { + return nil, err + } else { + return singletonSpanWithAnnotationProducer, nil + } +} + +func newSpanWithAnnotationProducerImpl(traceConfig config.ITraceConfig, mqFactory mq.IFactory) (mq2.ISpanProducer, error) { + mqCfg, err := traceConfig.GetSpanWithAnnotationMqProducerCfg(context.Background()) + if err != nil { + return nil, err + } + if mqCfg.Topic == "" { + return nil, fmt.Errorf("trace topic required") + } + mqProducer, err := mqFactory.NewProducer(mq.ProducerConfig{ + Addr: mqCfg.Addr, + ProduceTimeout: time.Duration(mqCfg.Timeout) * time.Millisecond, + RetryTimes: mqCfg.RetryTimes, + ProducerGroup: ptr.Of(mqCfg.ProducerGroup), + }) + if err != nil { + return nil, err + } + if err := mqProducer.Start(); err != nil { + return nil, fmt.Errorf("fail to start producer, %v", err) + } + return &SpanWithAnnotationProducerImpl{ + topic: mqCfg.Topic, + mqProducer: mqProducer, + }, nil +} From b3179aea5428a31483da2dd45bffaf45f149a5a9 Mon Sep 17 00:00:00 2001 From: "zhaoxun.3233" Date: Sat, 8 Nov 2025 17:47:47 +0800 Subject: [PATCH 11/19] fix consumer --- backend/modules/observability/infra/mq/consumer/consumer.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backend/modules/observability/infra/mq/consumer/consumer.go b/backend/modules/observability/infra/mq/consumer/consumer.go index dd80d50ce..249108ffd 100644 --- a/backend/modules/observability/infra/mq/consumer/consumer.go +++ b/backend/modules/observability/infra/mq/consumer/consumer.go @@ -21,7 +21,7 @@ func NewConsumerWorkers( newCallbackConsumer(taskConsumer, loader), newCorrectionConsumer(taskConsumer, loader), newBackFillConsumer(taskConsumer, loader), - newSpanWithAnnotationConsumer(handler, loader), + newSpanWithAnnotationConsumer(taskConsumer, loader), ) return workers, nil From b3c57989ffc329e3f61a7ecb4890ec1159f37e6a Mon Sep 17 00:00:00 2001 From: "zhaoxun.3233" Date: Sat, 8 Nov 2025 18:19:08 +0800 Subject: [PATCH 12/19] add SpanWithAnnotation proc --- .../modules/observability/application/task.go | 45 ++++++++++++++++--- .../domain/task/service/task_callback.go | 2 + .../observability/domain/trace/repo/trace.go | 1 + .../trace/service/trace_export_service.go | 1 + .../domain/trace/service/trace_service.go | 7 +++ .../consumer/span_with_annotation_consumer.go | 5 +++ .../producer/span_with_annotation_producer.go | 2 +- .../modules/observability/infra/repo/trace.go | 19 ++++++-- 8 files changed, 70 insertions(+), 12 deletions(-) diff --git a/backend/modules/observability/application/task.go b/backend/modules/observability/application/task.go index 87730faef..940abdf47 100644 --- a/backend/modules/observability/application/task.go +++ b/backend/modules/observability/application/task.go @@ -10,6 +10,7 @@ import ( "github.com/coze-dev/coze-loop/backend/infra/middleware/session" "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/observability/task" + "github.com/coze-dev/coze-loop/backend/modules/data/pkg/errno" "github.com/coze-dev/coze-loop/backend/modules/observability/application/convertor" tconv "github.com/coze-dev/coze-loop/backend/modules/observability/application/convertor/task" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/component/rpc" @@ -18,6 +19,7 @@ import ( "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/service/taskexe/processor" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/service/taskexe/tracehub" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/entity/loop_span" + tracerepo "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/repo" obErrorx "github.com/coze-dev/coze-loop/backend/modules/observability/pkg/errno" "github.com/coze-dev/coze-loop/backend/pkg/errorx" "github.com/coze-dev/coze-loop/backend/pkg/logs" @@ -67,6 +69,7 @@ type TaskApplication struct { tracehubSvc tracehub.ITraceHubService taskProcessor processor.TaskProcessor taskCallbackSvc service.ITaskCallbackService + traceRepo tracerepo.ITraceRepo } func (t *TaskApplication) CheckTaskName(ctx context.Context, req *task.CheckTaskNameRequest) (*task.CheckTaskNameResponse, error) { @@ -261,15 +264,43 @@ func (t *TaskApplication) GetTask(ctx context.Context, req *task.GetTaskRequest) }, nil } -func (t *TaskApplication) SpanTrigger(ctx context.Context, rawSpan *entity.RawSpan, Span *loop_span.Span) error { - span := rawSpan.RawSpanConvertToLoopSpan() - if span != nil { - if err := t.tracehubSvc.SpanTrigger(ctx, span); err != nil { - logs.CtxError(ctx, "SpanTrigger err:%v", err) - // span trigger 失败,不处理 - return nil +func (t *TaskApplication) SpanTrigger(ctx context.Context, rawSpan *entity.RawSpan, loopSpan *loop_span.Span) error { + if rawSpan != nil { + span := rawSpan.RawSpanConvertToLoopSpan() + if span != nil { + if err := t.tracehubSvc.SpanTrigger(ctx, span); err != nil { + logs.CtxError(ctx, "SpanTrigger err:%v", err) + // span trigger 失败,不处理 + return nil + } } } + if loopSpan != nil { + workspaceID, err := strconv.ParseInt(loopSpan.WorkspaceID, 10, 64) + if err != nil { + return errno.InternalErr(err, "convert %s to int64", loopSpan.WorkspaceID) + } + annotations, err := t.traceRepo.ListAnnotations(ctx, &tracerepo.ListAnnotationsParam{ + Tenants: []string{loopSpan.GetTenant()}, + SpanID: loopSpan.SpanID, + TraceID: loopSpan.TraceID, + WorkspaceId: workspaceID, + StartAt: loopSpan.StartTime - 5*time.Second.Milliseconds(), + EndAt: loopSpan.StartTime + 5*time.Second.Milliseconds(), + }) + if err != nil { + return err + } + loopSpan.Annotations = append(loopSpan.Annotations, annotations...) + if loopSpan != nil { + if err := t.tracehubSvc.SpanTrigger(ctx, loopSpan); err != nil { + logs.CtxError(ctx, "SpanTrigger err:%v", err) + // span trigger 失败,不处理 + return nil + } + } + } + return nil } diff --git a/backend/modules/observability/domain/task/service/task_callback.go b/backend/modules/observability/domain/task/service/task_callback.go index 90d1fda40..07425d335 100644 --- a/backend/modules/observability/domain/task/service/task_callback.go +++ b/backend/modules/observability/domain/task/service/task_callback.go @@ -117,6 +117,7 @@ func (t *TaskCallbackServiceImpl) AutoEvalCallback(ctx context.Context, event *e Tenant: span.GetTenant(), TTL: span.GetTTL(ctx), Annotations: []*loop_span.Annotation{annotation}, + Span: span, }) if err != nil { return err @@ -175,6 +176,7 @@ func (t *TaskCallbackServiceImpl) AutoEvalCorrection(ctx context.Context, event Tenant: span.GetTenant(), TTL: span.GetTTL(ctx), Annotations: []*loop_span.Annotation{annotation}, + Span: span, } if err = t.traceRepo.InsertAnnotations(ctx, param); err != nil { recordID := lo.Ternary(annotation.GetAutoEvaluateMetadata() != nil, annotation.GetAutoEvaluateMetadata().EvaluatorRecordID, 0) diff --git a/backend/modules/observability/domain/trace/repo/trace.go b/backend/modules/observability/domain/trace/repo/trace.go index 3677e6d1c..da29927bc 100644 --- a/backend/modules/observability/domain/trace/repo/trace.go +++ b/backend/modules/observability/domain/trace/repo/trace.go @@ -67,6 +67,7 @@ type InsertAnnotationParam struct { Tenant string TTL loop_span.TTL Annotations []*loop_span.Annotation + Span *loop_span.Span } //go:generate mockgen -destination=mocks/trace.go -package=mocks . ITraceRepo diff --git a/backend/modules/observability/domain/trace/service/trace_export_service.go b/backend/modules/observability/domain/trace/service/trace_export_service.go index 7ad941ef7..7ddd73da2 100644 --- a/backend/modules/observability/domain/trace/service/trace_export_service.go +++ b/backend/modules/observability/domain/trace/service/trace_export_service.go @@ -404,6 +404,7 @@ func (r *TraceExportServiceImpl) addSpanAnnotations(ctx context.Context, spans [ Tenant: span.GetTenant(), TTL: span.GetTTL(ctx), Annotations: []*loop_span.Annotation{annotation}, + Span: span, }) if err != nil { // 忽略add annotations的错误,防止用户重复导入数据集。 diff --git a/backend/modules/observability/domain/trace/service/trace_service.go b/backend/modules/observability/domain/trace/service/trace_service.go index 6dda97cba..c58b5bbb3 100644 --- a/backend/modules/observability/domain/trace/service/trace_service.go +++ b/backend/modules/observability/domain/trace/service/trace_service.go @@ -732,6 +732,7 @@ func (r *TraceServiceImpl) CreateManualAnnotation(ctx context.Context, req *Crea Tenant: span.GetTenant(), TTL: span.GetTTL(ctx), Annotations: []*loop_span.Annotation{annotation}, + Span: span, }); err != nil { return nil, err } @@ -788,6 +789,7 @@ func (r *TraceServiceImpl) UpdateManualAnnotation(ctx context.Context, req *Upda Tenant: span.GetTenant(), TTL: span.GetTTL(ctx), Annotations: []*loop_span.Annotation{annotation}, + Span: span, }) } @@ -826,6 +828,7 @@ func (r *TraceServiceImpl) DeleteManualAnnotation(ctx context.Context, req *Dele Tenant: span.GetTenant(), TTL: span.GetTTL(ctx), Annotations: []*loop_span.Annotation{annotation}, + Span: span, }) } @@ -890,6 +893,7 @@ func (r *TraceServiceImpl) CreateAnnotation(ctx context.Context, req *CreateAnno Tenant: span.GetTenant(), TTL: span.GetTTL(ctx), Annotations: []*loop_span.Annotation{annotation}, + Span: span, }) } @@ -941,6 +945,7 @@ func (r *TraceServiceImpl) DeleteAnnotation(ctx context.Context, req *DeleteAnno Tenant: span.GetTenant(), TTL: span.GetTTL(ctx), Annotations: []*loop_span.Annotation{annotation}, + Span: span, }) } @@ -982,6 +987,7 @@ func (r *TraceServiceImpl) Send(ctx context.Context, event *entity.AnnotationEve Tenant: span.GetTenant(), TTL: span.GetTTL(ctx), Annotations: []*loop_span.Annotation{event.Annotation}, + Span: span, }) } @@ -1174,6 +1180,7 @@ func (r *TraceServiceImpl) ChangeEvaluatorScore(ctx context.Context, req *Change Tenant: span.GetTenant(), TTL: span.GetTTL(ctx), Annotations: []*loop_span.Annotation{annotation}, + Span: span, } if err = r.traceRepo.InsertAnnotations(ctx, param); err != nil { recordID := lo.Ternary(annotation.GetAutoEvaluateMetadata() != nil, annotation.GetAutoEvaluateMetadata().EvaluatorRecordID, 0) diff --git a/backend/modules/observability/infra/mq/consumer/span_with_annotation_consumer.go b/backend/modules/observability/infra/mq/consumer/span_with_annotation_consumer.go index 11c401b98..1f2cc98c8 100644 --- a/backend/modules/observability/infra/mq/consumer/span_with_annotation_consumer.go +++ b/backend/modules/observability/infra/mq/consumer/span_with_annotation_consumer.go @@ -41,6 +41,11 @@ func (e *SpanWithAnnotationConsumer) ConsumerCfg(ctx context.Context) (*mq.Consu ConsumerGroup: cfg.ConsumerGroup, ConsumeTimeout: time.Duration(cfg.Timeout) * time.Millisecond, ConsumeGoroutineNums: cfg.WorkerNum, + EnablePPE: cfg.EnablePPE, + IsEnabled: cfg.IsEnabled, + } + if cfg.TagExpression != nil { + res.TagExpression = *cfg.TagExpression } return res, nil } diff --git a/backend/modules/observability/infra/mq/producer/span_with_annotation_producer.go b/backend/modules/observability/infra/mq/producer/span_with_annotation_producer.go index c20dd2831..343bb941e 100644 --- a/backend/modules/observability/infra/mq/producer/span_with_annotation_producer.go +++ b/backend/modules/observability/infra/mq/producer/span_with_annotation_producer.go @@ -35,7 +35,7 @@ func (a *SpanWithAnnotationProducerImpl) SendSpanWithAnnotation(ctx context.Cont if err != nil { return errorx.WrapByCode(err, obErrorx.CommercialCommonInternalErrorCodeCode) } - msg := mq.NewDeferMessage(a.topic, 10*time.Second, bytes) + msg := mq.NewDeferMessage(a.topic, 60*time.Second, bytes) _, err = a.mqProducer.Send(ctx, msg) if err != nil { logs.CtxWarn(ctx, "send annotation msg err: %v", err) diff --git a/backend/modules/observability/infra/repo/trace.go b/backend/modules/observability/infra/repo/trace.go index f75999930..b06351b59 100644 --- a/backend/modules/observability/infra/repo/trace.go +++ b/backend/modules/observability/infra/repo/trace.go @@ -11,7 +11,9 @@ import ( "time" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/component/config" + "github.com/coze-dev/coze-loop/backend/modules/observability/domain/component/mq" metric_repo "github.com/coze-dev/coze-loop/backend/modules/observability/domain/metric/repo" + "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/entity" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/entity/loop_span" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/repo" "github.com/coze-dev/coze-loop/backend/modules/observability/infra/repo/ck" @@ -51,9 +53,10 @@ func NewTraceMetricCKRepoImpl( } type TraceCkRepoImpl struct { - spansDao ck.ISpansDao - annoDao ck.IAnnotationDao - traceConfig config.ITraceConfig + spansDao ck.ISpansDao + annoDao ck.IAnnotationDao + traceConfig config.ITraceConfig + spanProducer mq.ISpanProducer } type PageToken struct { @@ -284,10 +287,18 @@ func (t *TraceCkRepoImpl) InsertAnnotations(ctx context.Context, param *repo.Ins } pos = append(pos, annotationPO) } - return t.annoDao.Insert(ctx, &ck.InsertAnnotationParam{ + err = t.annoDao.Insert(ctx, &ck.InsertAnnotationParam{ Table: table, Annotations: pos, }) + if err != nil { + return nil + } + span := param.Span + span.Annotations = append(span.Annotations, param.Annotations...) + return t.spanProducer.SendSpanWithAnnotation(ctx, &entity.SpanEvent{ + Span: span, + }) } func (t *TraceCkRepoImpl) GetMetrics(ctx context.Context, param *metric_repo.GetMetricsParam) (*metric_repo.GetMetricsResult, error) { From 792194e8ad03e3fc62353554e90cab1602e425bc Mon Sep 17 00:00:00 2001 From: taoyifan89 Date: Mon, 10 Nov 2025 19:45:45 +0800 Subject: [PATCH 13/19] Make new task consumer public. Change-Id: Ibb38cfb316881aa0c287a10834afd9608fcce3fe --- .../infra/mq/consumer/annotation_consumer.go | 2 +- .../infra/mq/consumer/autotask_callback_consumer.go | 2 +- .../infra/mq/consumer/backfill_consumer.go | 2 +- .../observability/infra/mq/consumer/consumer.go | 10 +++++----- .../infra/mq/consumer/correction_consumer.go | 2 +- .../observability/infra/mq/consumer/task_consumer.go | 2 +- 6 files changed, 10 insertions(+), 10 deletions(-) diff --git a/backend/modules/observability/infra/mq/consumer/annotation_consumer.go b/backend/modules/observability/infra/mq/consumer/annotation_consumer.go index b0034d6ce..648694841 100644 --- a/backend/modules/observability/infra/mq/consumer/annotation_consumer.go +++ b/backend/modules/observability/infra/mq/consumer/annotation_consumer.go @@ -22,7 +22,7 @@ type AnnotationConsumer struct { conf.IConfigLoader } -func newAnnotationConsumer(handler obapp.IAnnotationQueueConsumer, loader conf.IConfigLoader) mq.IConsumerWorker { +func NewAnnotationConsumer(handler obapp.IAnnotationQueueConsumer, loader conf.IConfigLoader) mq.IConsumerWorker { return &AnnotationConsumer{ handler: handler, IConfigLoader: loader, diff --git a/backend/modules/observability/infra/mq/consumer/autotask_callback_consumer.go b/backend/modules/observability/infra/mq/consumer/autotask_callback_consumer.go index 0996387b4..20f41a7b4 100644 --- a/backend/modules/observability/infra/mq/consumer/autotask_callback_consumer.go +++ b/backend/modules/observability/infra/mq/consumer/autotask_callback_consumer.go @@ -22,7 +22,7 @@ type AutoTaskCallbackConsumer struct { conf.IConfigLoader } -func newCallbackConsumer(handler obapp.ITaskQueueConsumer, loader conf.IConfigLoader) mq.IConsumerWorker { +func NewCallbackConsumer(handler obapp.ITaskQueueConsumer, loader conf.IConfigLoader) mq.IConsumerWorker { return &AutoTaskCallbackConsumer{ handler: handler, IConfigLoader: loader, diff --git a/backend/modules/observability/infra/mq/consumer/backfill_consumer.go b/backend/modules/observability/infra/mq/consumer/backfill_consumer.go index fba8bd0e4..c5003e165 100644 --- a/backend/modules/observability/infra/mq/consumer/backfill_consumer.go +++ b/backend/modules/observability/infra/mq/consumer/backfill_consumer.go @@ -22,7 +22,7 @@ type BackFillConsumer struct { conf.IConfigLoader } -func newBackFillConsumer(handler obapp.ITaskQueueConsumer, loader conf.IConfigLoader) mq.IConsumerWorker { +func NewBackFillConsumer(handler obapp.ITaskQueueConsumer, loader conf.IConfigLoader) mq.IConsumerWorker { return &BackFillConsumer{ handler: handler, IConfigLoader: loader, diff --git a/backend/modules/observability/infra/mq/consumer/consumer.go b/backend/modules/observability/infra/mq/consumer/consumer.go index c1a5f4b21..40c133afc 100644 --- a/backend/modules/observability/infra/mq/consumer/consumer.go +++ b/backend/modules/observability/infra/mq/consumer/consumer.go @@ -16,11 +16,11 @@ func NewConsumerWorkers( ) ([]mq.IConsumerWorker, error) { workers := []mq.IConsumerWorker{} workers = append(workers, - newAnnotationConsumer(handler, loader), - newTaskConsumer(taskConsumer, loader), - newCallbackConsumer(taskConsumer, loader), - newCorrectionConsumer(taskConsumer, loader), - newBackFillConsumer(taskConsumer, loader), + NewAnnotationConsumer(handler, loader), + NewTaskConsumer(taskConsumer, loader), + NewCallbackConsumer(taskConsumer, loader), + NewCorrectionConsumer(taskConsumer, loader), + NewBackFillConsumer(taskConsumer, loader), ) return workers, nil diff --git a/backend/modules/observability/infra/mq/consumer/correction_consumer.go b/backend/modules/observability/infra/mq/consumer/correction_consumer.go index 3fdbe41f2..9ab1eae9b 100644 --- a/backend/modules/observability/infra/mq/consumer/correction_consumer.go +++ b/backend/modules/observability/infra/mq/consumer/correction_consumer.go @@ -21,7 +21,7 @@ type CorrectionConsumer struct { conf.IConfigLoader } -func newCorrectionConsumer(handler obapp.ITaskQueueConsumer, loader conf.IConfigLoader) mq.IConsumerWorker { +func NewCorrectionConsumer(handler obapp.ITaskQueueConsumer, loader conf.IConfigLoader) mq.IConsumerWorker { return &CorrectionConsumer{ handler: handler, IConfigLoader: loader, diff --git a/backend/modules/observability/infra/mq/consumer/task_consumer.go b/backend/modules/observability/infra/mq/consumer/task_consumer.go index 048793ddd..e6f5554fd 100644 --- a/backend/modules/observability/infra/mq/consumer/task_consumer.go +++ b/backend/modules/observability/infra/mq/consumer/task_consumer.go @@ -22,7 +22,7 @@ type TaskConsumer struct { conf.IConfigLoader } -func newTaskConsumer(handler obapp.ITaskQueueConsumer, loader conf.IConfigLoader) mq.IConsumerWorker { +func NewTaskConsumer(handler obapp.ITaskQueueConsumer, loader conf.IConfigLoader) mq.IConsumerWorker { return &TaskConsumer{ handler: handler, IConfigLoader: loader, From 73e0ce6585974dd41e97b92b2454688bb0f82fec Mon Sep 17 00:00:00 2001 From: "zhaoxun.3233" Date: Tue, 11 Nov 2025 10:54:26 +0800 Subject: [PATCH 14/19] add wire --- .../modules/observability/application/wire.go | 1 + .../observability/application/wire_gen.go | 26 +++++++++++++++---- .../modules/observability/infra/repo/trace.go | 8 +++--- 3 files changed, 27 insertions(+), 8 deletions(-) diff --git a/backend/modules/observability/application/wire.go b/backend/modules/observability/application/wire.go index c907da708..67d5bb31c 100644 --- a/backend/modules/observability/application/wire.go +++ b/backend/modules/observability/application/wire.go @@ -121,6 +121,7 @@ var ( obconfig.NewTraceConfigCenter, NewTraceConfigLoader, NewIngestionCollectorFactory, + mq2.NewSpanWithAnnotationProducerImpl, ) openApiSet = wire.NewSet( NewOpenAPIApplication, diff --git a/backend/modules/observability/application/wire_gen.go b/backend/modules/observability/application/wire_gen.go index 03a6666c4..449fcf06d 100644 --- a/backend/modules/observability/application/wire_gen.go +++ b/backend/modules/observability/application/wire_gen.go @@ -86,7 +86,11 @@ func InitTraceApplication(db2 db.Provider, ckDb ck.Provider, redis3 redis.Cmdabl return nil, err } iTraceConfig := config.NewTraceConfigCenter(iConfigLoader) - iTraceRepo, err := repo.NewTraceCKRepoImpl(iSpansDao, iAnnotationDao, iTraceConfig) + iSpanProducer, err := producer.NewSpanWithAnnotationProducerImpl(iTraceConfig, mqFactory) + if err != nil { + return nil, err + } + iTraceRepo, err := repo.NewTraceCKRepoImpl(iSpansDao, iAnnotationDao, iTraceConfig, iSpanProducer) if err != nil { return nil, err } @@ -143,7 +147,11 @@ func InitOpenAPIApplication(mqFactory mq.IFactory, configFactory conf.IConfigLoa return nil, err } iTraceConfig := config.NewTraceConfigCenter(iConfigLoader) - iTraceRepo, err := repo.NewTraceCKRepoImpl(iSpansDao, iAnnotationDao, iTraceConfig) + iSpanProducer, err := producer.NewSpanWithAnnotationProducerImpl(iTraceConfig, mqFactory) + if err != nil { + return nil, err + } + iTraceRepo, err := repo.NewTraceCKRepoImpl(iSpansDao, iAnnotationDao, iTraceConfig, iSpanProducer) if err != nil { return nil, err } @@ -227,7 +235,11 @@ func InitTraceIngestionApplication(configFactory conf.IConfigLoaderFactory, ckDb return nil, err } iTraceConfig := config.NewTraceConfigCenter(iConfigLoader) - iTraceRepo, err := repo.NewTraceCKRepoImpl(iSpansDao, iAnnotationDao, iTraceConfig) + iSpanProducer, err := producer.NewSpanWithAnnotationProducerImpl(iTraceConfig, mqFactory) + if err != nil { + return nil, err + } + iTraceRepo, err := repo.NewTraceCKRepoImpl(iSpansDao, iAnnotationDao, iTraceConfig, iSpanProducer) if err != nil { return nil, err } @@ -275,7 +287,11 @@ func InitTaskApplication(db2 db.Provider, idgen2 idgen.IIDGenerator, configFacto if err != nil { return nil, err } - iTraceRepo, err := repo.NewTraceCKRepoImpl(iSpansDao, iAnnotationDao, iTraceConfig) + iSpanProducer, err := producer.NewSpanWithAnnotationProducerImpl(iTraceConfig, mqFactory) + if err != nil { + return nil, err + } + iTraceRepo, err := repo.NewTraceCKRepoImpl(iSpansDao, iAnnotationDao, iTraceConfig, iSpanProducer) if err != nil { return nil, err } @@ -308,7 +324,7 @@ var ( ) traceIngestionSet = wire.NewSet( NewIngestionApplication, service.NewIngestionServiceImpl, repo.NewTraceCKRepoImpl, ck2.NewSpansCkDaoImpl, ck2.NewAnnotationCkDaoImpl, config.NewTraceConfigCenter, NewTraceConfigLoader, - NewIngestionCollectorFactory, + NewIngestionCollectorFactory, producer.NewSpanWithAnnotationProducerImpl, ) openApiSet = wire.NewSet( NewOpenAPIApplication, auth.NewAuthProvider, traceDomainSet, diff --git a/backend/modules/observability/infra/repo/trace.go b/backend/modules/observability/infra/repo/trace.go index b06351b59..d63120f5a 100644 --- a/backend/modules/observability/infra/repo/trace.go +++ b/backend/modules/observability/infra/repo/trace.go @@ -32,11 +32,13 @@ func NewTraceCKRepoImpl( spanDao ck.ISpansDao, annoDao ck.IAnnotationDao, traceConfig config.ITraceConfig, + spanProducer mq.ISpanProducer, ) (repo.ITraceRepo, error) { return &TraceCkRepoImpl{ - spansDao: spanDao, - annoDao: annoDao, - traceConfig: traceConfig, + spansDao: spanDao, + annoDao: annoDao, + traceConfig: traceConfig, + spanProducer: spanProducer, }, nil } From d93abc6a1415272d0369c5385e678f699bf62fd9 Mon Sep 17 00:00:00 2001 From: "zhaoxun.3233" Date: Tue, 11 Nov 2025 11:10:16 +0800 Subject: [PATCH 15/19] fix mq proc --- .../infra/mq/consumer/span_with_annotation_consumer.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/backend/modules/observability/infra/mq/consumer/span_with_annotation_consumer.go b/backend/modules/observability/infra/mq/consumer/span_with_annotation_consumer.go index 1f2cc98c8..e648517d3 100644 --- a/backend/modules/observability/infra/mq/consumer/span_with_annotation_consumer.go +++ b/backend/modules/observability/infra/mq/consumer/span_with_annotation_consumer.go @@ -10,7 +10,7 @@ import ( "github.com/coze-dev/coze-loop/backend/infra/mq" obapp "github.com/coze-dev/coze-loop/backend/modules/observability/application" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/component/config" - "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/entity/loop_span" + "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/entity" "github.com/coze-dev/coze-loop/backend/pkg/conf" "github.com/coze-dev/coze-loop/backend/pkg/json" "github.com/coze-dev/coze-loop/backend/pkg/lang/conv" @@ -53,11 +53,11 @@ func (e *SpanWithAnnotationConsumer) ConsumerCfg(ctx context.Context) (*mq.Consu func (e *SpanWithAnnotationConsumer) HandleMessage(ctx context.Context, ext *mq.MessageExt) error { logID := logs.NewLogID() ctx = logs.SetLogID(ctx, logID) - event := new(loop_span.Span) + event := new(entity.SpanEvent) if err := json.Unmarshal(ext.Body, event); err != nil { logs.CtxError(ctx, "Task msg json unmarshal fail, raw: %v, err: %s", conv.UnsafeBytesToString(ext.Body), err) return nil } - logs.CtxInfo(ctx, "Span with annotation msg,log_id=%s, trace_id=%s, span_id=%s,msgID=%s", event.LogID, event.TraceID, event.SpanID, ext.MsgID) - return e.handler.SpanTrigger(ctx, nil, event) + logs.CtxInfo(ctx, "Span with annotation msg,log_id=%s, trace_id=%s, span_id=%s,msgID=%s", event.Span.LogID, event.Span.TraceID, event.Span.SpanID, ext.MsgID) + return e.handler.SpanTrigger(ctx, nil, event.Span) } From 4e9719dde0fab7ff4b22f729e446d3480f235eee Mon Sep 17 00:00:00 2001 From: "zhaoxun.3233" Date: Tue, 11 Nov 2025 11:13:40 +0800 Subject: [PATCH 16/19] add wire gen --- backend/modules/observability/application/task.go | 2 ++ backend/modules/observability/application/wire_gen.go | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/backend/modules/observability/application/task.go b/backend/modules/observability/application/task.go index 940abdf47..d0589b59a 100644 --- a/backend/modules/observability/application/task.go +++ b/backend/modules/observability/application/task.go @@ -47,6 +47,7 @@ func NewTaskApplication( tracehubSvc tracehub.ITraceHubService, taskProcessor processor.TaskProcessor, taskCallbackService service.ITaskCallbackService, + traceRepo tracerepo.ITraceRepo, ) (ITaskApplication, error) { return &TaskApplication{ taskSvc: taskService, @@ -57,6 +58,7 @@ func NewTaskApplication( tracehubSvc: tracehubSvc, taskProcessor: taskProcessor, taskCallbackSvc: taskCallbackService, + traceRepo: traceRepo, }, nil } diff --git a/backend/modules/observability/application/wire_gen.go b/backend/modules/observability/application/wire_gen.go index 449fcf06d..a8f6443f1 100644 --- a/backend/modules/observability/application/wire_gen.go +++ b/backend/modules/observability/application/wire_gen.go @@ -302,7 +302,7 @@ func InitTaskApplication(db2 db.Provider, idgen2 idgen.IIDGenerator, configFacto return nil, err } iTaskCallbackService := service3.NewTaskCallbackServiceImpl(iTaskRepo, iTraceRepo, taskProcessor, iTenantProvider, iTraceConfig, benefit2) - iTaskApplication, err := NewTaskApplication(iTaskService, iAuthProvider, iEvaluatorRPCAdapter, iEvaluationRPCAdapter, iUserProvider, iTraceHubService, taskProcessor, iTaskCallbackService) + iTaskApplication, err := NewTaskApplication(iTaskService, iAuthProvider, iEvaluatorRPCAdapter, iEvaluationRPCAdapter, iUserProvider, iTraceHubService, taskProcessor, iTaskCallbackService, iTraceRepo) if err != nil { return nil, err } From 665f5ae6b2cb1d43df687feecf075c17408ec6c3 Mon Sep 17 00:00:00 2001 From: "zhaoxun.3233" Date: Tue, 11 Nov 2025 11:43:46 +0800 Subject: [PATCH 17/19] fix time proc --- backend/modules/observability/application/task.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/backend/modules/observability/application/task.go b/backend/modules/observability/application/task.go index d0589b59a..601df9770 100644 --- a/backend/modules/observability/application/task.go +++ b/backend/modules/observability/application/task.go @@ -287,12 +287,13 @@ func (t *TaskApplication) SpanTrigger(ctx context.Context, rawSpan *entity.RawSp SpanID: loopSpan.SpanID, TraceID: loopSpan.TraceID, WorkspaceId: workspaceID, - StartAt: loopSpan.StartTime - 5*time.Second.Milliseconds(), - EndAt: loopSpan.StartTime + 5*time.Second.Milliseconds(), + StartAt: loopSpan.StartTime/1000 - 5*time.Second.Milliseconds(), + EndAt: loopSpan.StartTime/1000 + 5*time.Second.Milliseconds(), }) if err != nil { return err } + loopSpan.StartTime = loopSpan.StartTime / 1000 loopSpan.Annotations = append(loopSpan.Annotations, annotations...) if loopSpan != nil { if err := t.tracehubSvc.SpanTrigger(ctx, loopSpan); err != nil { From febba2d5ebde61f3529df66341a9baf41bde8104 Mon Sep 17 00:00:00 2001 From: "zhaoxun.3233" Date: Tue, 11 Nov 2025 12:02:09 +0800 Subject: [PATCH 18/19] add annotation proc --- .../domain/trace/entity/loop_span/span.go | 20 ++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/backend/modules/observability/domain/trace/entity/loop_span/span.go b/backend/modules/observability/domain/trace/entity/loop_span/span.go index 9e99ef940..e84759979 100644 --- a/backend/modules/observability/domain/trace/entity/loop_span/span.go +++ b/backend/modules/observability/domain/trace/entity/loop_span/span.go @@ -300,7 +300,25 @@ func (s *Span) GetFieldValue(fieldName string, isSystem, isCustom bool) any { } else if val, ok := s.TagsByte[fieldName]; ok { return val } - return nil + annotationMap := make(map[string]AnnotationValue) + for _, annotation := range s.Annotations { + annotationMap[fmt.Sprintf("%s_%s", annotation.AnnotationType, annotation.Key)] = annotation.Value + } + if val, ok := annotationMap[fieldName]; ok { + switch val.ValueType { + case AnnotationValueTypeLong: + return val.LongValue + case AnnotationValueTypeDouble, AnnotationValueTypeNumber: + return val.FloatValue + case AnnotationValueTypeBool: + return val.BoolValue + case AnnotationValueTypeString: + return val.StringValue + default: + return nil + } + } + return annotationMap } func (s *Span) IsValidSpan() error { From d3a08729286fce200649761abb5f6ec99f60a569 Mon Sep 17 00:00:00 2001 From: taoyifan89 Date: Wed, 12 Nov 2025 20:33:34 +0800 Subject: [PATCH 19/19] refactor scheduled task. Change-Id: I8af657662d6015e825c1345d41e518f931e01acd --- backend/api/api.go | 3 + .../api/handler/coze/loop/apis/wire_gen.go | 1 - .../modules/observability/application/task.go | 15 + .../modules/observability/application/wire.go | 17 + .../observability/application/wire_gen.go | 18 +- .../component/scheduledtask/scheduledtask.go | 62 +++ .../observability/domain/task/entity/task.go | 4 +- .../domain/task/service/task_service.go | 8 +- .../domain/task/service/task_service_test.go | 4 +- .../scheduledtask/local_cache_refresh.go | 94 ++++ .../scheduled_task_test.go | 2 +- .../status_check.go} | 429 +++++++----------- .../task/service/taskexe/tracehub/backfill.go | 58 ++- .../service/taskexe/tracehub/local_cache.go | 2 +- .../service/taskexe/tracehub/span_trigger.go | 57 ++- .../service/taskexe/tracehub/trace_hub.go | 56 +-- 16 files changed, 486 insertions(+), 344 deletions(-) create mode 100644 backend/modules/observability/domain/component/scheduledtask/scheduledtask.go create mode 100644 backend/modules/observability/domain/task/service/taskexe/scheduledtask/local_cache_refresh.go rename backend/modules/observability/domain/task/service/taskexe/{tracehub => scheduledtask}/scheduled_task_test.go (99%) rename backend/modules/observability/domain/task/service/taskexe/{tracehub/scheduled_task.go => scheduledtask/status_check.go} (64%) mode change 100755 => 100644 diff --git a/backend/api/api.go b/backend/api/api.go index 3419c1a98..3e71e43a2 100644 --- a/backend/api/api.go +++ b/backend/api/api.go @@ -125,6 +125,9 @@ func Init( if err != nil { return nil, err } + if err = observabilityHandler.RunTaskScheduleTask(ctx); err != nil { + return nil, err + } observabilityHandler.RunAsync(ctx) return &apis.APIHandler{ diff --git a/backend/api/handler/coze/loop/apis/wire_gen.go b/backend/api/handler/coze/loop/apis/wire_gen.go index f090ab564..7b5ae44ac 100644 --- a/backend/api/handler/coze/loop/apis/wire_gen.go +++ b/backend/api/handler/coze/loop/apis/wire_gen.go @@ -8,7 +8,6 @@ package apis import ( "context" - "github.com/cloudwego/kitex/pkg/endpoint" "github.com/coze-dev/coze-loop/backend/infra/ck" "github.com/coze-dev/coze-loop/backend/infra/db" diff --git a/backend/modules/observability/application/task.go b/backend/modules/observability/application/task.go index 4948c5141..054854056 100644 --- a/backend/modules/observability/application/task.go +++ b/backend/modules/observability/application/task.go @@ -13,6 +13,7 @@ import ( "github.com/coze-dev/coze-loop/backend/modules/observability/application/convertor" tconv "github.com/coze-dev/coze-loop/backend/modules/observability/application/convertor/task" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/component/rpc" + "github.com/coze-dev/coze-loop/backend/modules/observability/domain/component/scheduledtask" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/entity" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/service" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/service/taskexe/processor" @@ -33,6 +34,7 @@ type ITaskQueueConsumer interface { type ITaskApplication interface { task.TaskService ITaskQueueConsumer + RunTaskScheduleTask(ctx context.Context) error } func NewTaskApplication( @@ -44,6 +46,7 @@ func NewTaskApplication( tracehubSvc tracehub.ITraceHubService, taskProcessor processor.TaskProcessor, taskCallbackService service.ITaskCallbackService, + scheduledTasks []scheduledtask.ScheduledTask, ) (ITaskApplication, error) { return &TaskApplication{ taskSvc: taskService, @@ -54,6 +57,7 @@ func NewTaskApplication( tracehubSvc: tracehubSvc, taskProcessor: taskProcessor, taskCallbackSvc: taskCallbackService, + scheduledTasks: scheduledTasks, }, nil } @@ -66,6 +70,7 @@ type TaskApplication struct { tracehubSvc tracehub.ITraceHubService taskProcessor processor.TaskProcessor taskCallbackSvc service.ITaskCallbackService + scheduledTasks []scheduledtask.ScheduledTask } func (t *TaskApplication) CheckTaskName(ctx context.Context, req *task.CheckTaskNameRequest) (*task.CheckTaskNameResponse, error) { @@ -304,3 +309,13 @@ func (t *TaskApplication) BackFill(ctx context.Context, event *entity.BackFillEv return t.tracehubSvc.BackFill(ctx, event) } + +func (t *TaskApplication) RunTaskScheduleTask(ctx context.Context) error { + for _, scheduledTask := range t.scheduledTasks { + if err := scheduledTask.Run(); err != nil { + logs.CtxError(ctx, "RunTaskScheduleTask err:%v", err) + return err + } + } + return nil +} diff --git a/backend/modules/observability/application/wire.go b/backend/modules/observability/application/wire.go index ec9b6e356..2185c7ad9 100644 --- a/backend/modules/observability/application/wire.go +++ b/backend/modules/observability/application/wire.go @@ -26,6 +26,7 @@ import ( "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/foundation/user/userservice" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/component/config" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/component/rpc" + "github.com/coze-dev/coze-loop/backend/modules/observability/domain/component/scheduledtask" metrics_entity "github.com/coze-dev/coze-loop/backend/modules/observability/domain/metric/entity" metric_service "github.com/coze-dev/coze-loop/backend/modules/observability/domain/metric/service" metric_general "github.com/coze-dev/coze-loop/backend/modules/observability/domain/metric/service/metric/general" @@ -36,6 +37,7 @@ import ( trepo "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/repo" taskSvc "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/service" task_processor "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/service/taskexe/processor" + taskst "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/service/taskexe/scheduledtask" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/service/taskexe/tracehub" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/entity" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/entity/collector/exporter" @@ -81,6 +83,7 @@ var ( redis2.NewTaskRunDAO, mysqldao.NewTaskRunDaoImpl, mq2.NewBackfillProducerImpl, + NewScheduledTask, ) traceDomainSet = wire.NewSet( service.NewTraceServiceImpl, @@ -286,6 +289,20 @@ func NewInitTaskProcessor(datasetServiceProvider *service.DatasetServiceAdaptor, return taskProcessor } +func NewScheduledTask( + locker lock.ILocker, + config config.ITraceConfig, + traceHubService tracehub.ITraceHubService, + taskService taskSvc.ITaskService, + taskProcessor task_processor.TaskProcessor, + taskRepo trepo.ITaskRepo, +) []scheduledtask.ScheduledTask { + return []scheduledtask.ScheduledTask{ + taskst.NewStatusCheckTask(locker, config, traceHubService, taskService, taskProcessor, taskRepo), + taskst.NewLocalCacheRefreshTask(traceHubService, taskRepo), + } +} + func InitTraceApplication( db db.Provider, ckDb ck.Provider, diff --git a/backend/modules/observability/application/wire_gen.go b/backend/modules/observability/application/wire_gen.go index 7101caf66..6bc481c78 100644 --- a/backend/modules/observability/application/wire_gen.go +++ b/backend/modules/observability/application/wire_gen.go @@ -26,6 +26,7 @@ import ( "github.com/coze-dev/coze-loop/backend/kitex_gen/coze/loop/foundation/user/userservice" config2 "github.com/coze-dev/coze-loop/backend/modules/observability/domain/component/config" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/component/rpc" + "github.com/coze-dev/coze-loop/backend/modules/observability/domain/component/scheduledtask" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/metric/entity" service2 "github.com/coze-dev/coze-loop/backend/modules/observability/domain/metric/service" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/metric/service/metric/general" @@ -36,6 +37,7 @@ import ( repo3 "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/repo" service3 "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/service" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/service/taskexe/processor" + scheduledtask2 "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/service/taskexe/scheduledtask" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/service/taskexe/tracehub" entity2 "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/entity" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/entity/collector/exporter" @@ -286,7 +288,8 @@ func InitTaskApplication(db2 db.Provider, idgen2 idgen.IIDGenerator, configFacto return nil, err } iTaskCallbackService := service3.NewTaskCallbackServiceImpl(iTaskRepo, iTraceRepo, taskProcessor, iTenantProvider, iTraceConfig, benefit2) - iTaskApplication, err := NewTaskApplication(iTaskService, iAuthProvider, iEvaluatorRPCAdapter, iEvaluationRPCAdapter, iUserProvider, iTraceHubService, taskProcessor, iTaskCallbackService) + v := NewScheduledTask(iLocker, iTraceConfig, iTraceHubService, iTaskService, taskProcessor, iTaskRepo) + iTaskApplication, err := NewTaskApplication(iTaskService, iAuthProvider, iEvaluatorRPCAdapter, iEvaluationRPCAdapter, iUserProvider, iTraceHubService, taskProcessor, iTaskCallbackService, v) if err != nil { return nil, err } @@ -297,7 +300,7 @@ func InitTaskApplication(db2 db.Provider, idgen2 idgen.IIDGenerator, configFacto var ( taskDomainSet = wire.NewSet( - NewInitTaskProcessor, service3.NewTaskServiceImpl, repo.NewTaskRepoImpl, mysql.NewTaskDaoImpl, redis2.NewTaskDAO, redis2.NewTaskRunDAO, mysql.NewTaskRunDaoImpl, producer.NewBackfillProducerImpl, + NewInitTaskProcessor, service3.NewTaskServiceImpl, repo.NewTaskRepoImpl, mysql.NewTaskDaoImpl, redis2.NewTaskDAO, redis2.NewTaskRunDAO, mysql.NewTaskRunDaoImpl, producer.NewBackfillProducerImpl, NewScheduledTask, ) traceDomainSet = wire.NewSet(service.NewTraceServiceImpl, service.NewTraceExportServiceImpl, repo.NewTraceCKRepoImpl, ck2.NewSpansCkDaoImpl, ck2.NewAnnotationCkDaoImpl, metrics2.NewTraceMetricsImpl, collector.NewEventCollectorProvider, producer.NewTraceProducerImpl, producer.NewAnnotationProducerImpl, file.NewFileRPCProvider, NewTraceConfigLoader, NewTraceProcessorBuilder, config.NewTraceConfigCenter, tenant.NewTenantProvider, workspace.NewWorkspaceProvider, evaluator.NewEvaluatorRPCProvider, NewDatasetServiceAdapter, @@ -375,3 +378,14 @@ func NewInitTaskProcessor(datasetServiceProvider *service.DatasetServiceAdaptor, taskProcessor.Register(entity3.TaskTypeAutoEval, processor.NewAutoEvaluteProcessor(0, datasetServiceProvider, evalService, evaluationService, taskRepo)) return taskProcessor } + +func NewScheduledTask( + locker lock.ILocker, config3 config2.ITraceConfig, + + traceHubService tracehub.ITraceHubService, + taskService service3.ITaskService, + taskProcessor processor.TaskProcessor, + taskRepo repo3.ITaskRepo, +) []scheduledtask.ScheduledTask { + return []scheduledtask.ScheduledTask{scheduledtask2.NewStatusCheckTask(locker, config3, traceHubService, taskService, taskProcessor, taskRepo), scheduledtask2.NewLocalCacheRefreshTask(traceHubService, taskRepo)} +} diff --git a/backend/modules/observability/domain/component/scheduledtask/scheduledtask.go b/backend/modules/observability/domain/component/scheduledtask/scheduledtask.go new file mode 100644 index 000000000..cdb57cf23 --- /dev/null +++ b/backend/modules/observability/domain/component/scheduledtask/scheduledtask.go @@ -0,0 +1,62 @@ +// Copyright (c) 2025 coze-dev Authors +// SPDX-License-Identifier: Apache-2.0 + +package scheduledtask + +import ( + "context" + "time" + + "github.com/coze-dev/coze-loop/backend/modules/llm/pkg/goroutineutil" + "github.com/coze-dev/coze-loop/backend/pkg/logs" +) + +type ScheduledTask interface { + Run() error + RunOnce(ctx context.Context) error + Stop() error +} + +type BaseScheduledTask struct { + name string + timeInterval time.Duration + stopChan chan struct{} +} + +func NewBaseScheduledTask(name string, timeInterval time.Duration) BaseScheduledTask { + return BaseScheduledTask{ + name: name, + timeInterval: timeInterval, + stopChan: make(chan struct{}), + } +} + +func (b *BaseScheduledTask) Run() error { + ticker := time.NewTicker(b.timeInterval) + goroutineutil.GoWithDefaultRecovery(context.Background(), func() { + for { + select { + case <-ticker.C: + ctx := context.Background() + startTime := time.Now() + if err := b.RunOnce(ctx); err != nil { + logs.CtxError(ctx, "ScheduledTask [%s] run error: %v, cost: %v", b.name, err, time.Since(startTime)) + } else { + logs.CtxInfo(ctx, "ScheduledTask [%s] run success, cost: %v", b.name, time.Since(startTime)) + } + case <-b.stopChan: + return + } + } + }) + return nil +} + +func (b *BaseScheduledTask) RunOnce(ctx context.Context) error { + panic("implement me") +} + +func (b *BaseScheduledTask) Stop() error { + close(b.stopChan) + return nil +} diff --git a/backend/modules/observability/domain/task/entity/task.go b/backend/modules/observability/domain/task/entity/task.go index ce2b6a998..8d822fc12 100644 --- a/backend/modules/observability/domain/task/entity/task.go +++ b/backend/modules/observability/domain/task/entity/task.go @@ -174,7 +174,7 @@ type DataReflowRunConfig struct { Status string `json:"status"` } -func (t ObservabilityTask) GetRunTimeRange() (startAt, endAt int64) { +func (t *ObservabilityTask) GetRunTimeRange() (startAt, endAt int64) { if t.EffectiveTime == nil { return 0, 0 } @@ -194,7 +194,7 @@ func (t ObservabilityTask) GetRunTimeRange() (startAt, endAt int64) { return startAt, endAt } -func (t ObservabilityTask) IsFinished() bool { +func (t *ObservabilityTask) IsFinished() bool { switch t.TaskStatus { case TaskStatusSuccess, TaskStatusDisabled, TaskStatusPending: return true diff --git a/backend/modules/observability/domain/task/service/task_service.go b/backend/modules/observability/domain/task/service/task_service.go index 15ac2b76a..74bc33ae1 100644 --- a/backend/modules/observability/domain/task/service/task_service.go +++ b/backend/modules/observability/domain/task/service/task_service.go @@ -73,6 +73,8 @@ type ITaskService interface { ListTasks(ctx context.Context, req *ListTasksReq) (resp *ListTasksResp, err error) GetTask(ctx context.Context, req *GetTaskReq) (resp *GetTaskResp, err error) CheckTaskName(ctx context.Context, req *CheckTaskNameReq) (resp *CheckTaskNameResp, err error) + + SendBackfillMessage(ctx context.Context, event *entity.BackFillEvent) error } func NewTaskServiceImpl( @@ -150,7 +152,7 @@ func (t *TaskServiceImpl) CreateTask(ctx context.Context, req *CreateTaskReq) (r TaskID: id, } - if err := t.sendBackfillMessage(context.Background(), backfillEvent); err != nil { + if err := t.SendBackfillMessage(context.Background(), backfillEvent); err != nil { // 失败了会有定时任务进行补偿 logs.CtxWarn(ctx, "send backfill message failed, task_id=%d, err=%v", id, err) } @@ -346,8 +348,8 @@ func (t *TaskServiceImpl) CheckTaskName(ctx context.Context, req *CheckTaskNameR return &CheckTaskNameResp{Pass: gptr.Of(pass)}, nil } -// sendBackfillMessage 发送MQ消息 -func (t *TaskServiceImpl) sendBackfillMessage(ctx context.Context, event *entity.BackFillEvent) error { +// SendBackfillMessage 发送MQ消息 +func (t *TaskServiceImpl) SendBackfillMessage(ctx context.Context, event *entity.BackFillEvent) error { if t.backfillProducer == nil { return errorx.NewByCode(obErrorx.CommonInternalErrorCode, errorx.WithExtraMsg("backfill producer not initialized")) } diff --git a/backend/modules/observability/domain/task/service/task_service_test.go b/backend/modules/observability/domain/task/service/task_service_test.go index 2499f662a..389760d6d 100755 --- a/backend/modules/observability/domain/task/service/task_service_test.go +++ b/backend/modules/observability/domain/task/service/task_service_test.go @@ -636,7 +636,7 @@ func TestTaskServiceImpl_shouldTriggerBackfill(t *testing.T) { func TestTaskServiceImpl_sendBackfillMessage(t *testing.T) { t.Run("producer nil", func(t *testing.T) { svc := &TaskServiceImpl{} - err := svc.sendBackfillMessage(context.Background(), &entity.BackFillEvent{}) + err := svc.SendBackfillMessage(context.Background(), &entity.BackFillEvent{}) statusErr, ok := errorx.FromStatusError(err) if assert.True(t, ok) { assert.EqualValues(t, obErrorx.CommonInternalErrorCode, statusErr.Code()) @@ -646,7 +646,7 @@ func TestTaskServiceImpl_sendBackfillMessage(t *testing.T) { t.Run("success", func(t *testing.T) { ch := make(chan *entity.BackFillEvent, 1) svc := &TaskServiceImpl{backfillProducer: &stubBackfillProducer{ch: ch}} - err := svc.sendBackfillMessage(context.Background(), &entity.BackFillEvent{TaskID: 1}) + err := svc.SendBackfillMessage(context.Background(), &entity.BackFillEvent{TaskID: 1}) assert.NoError(t, err) select { case event := <-ch: diff --git a/backend/modules/observability/domain/task/service/taskexe/scheduledtask/local_cache_refresh.go b/backend/modules/observability/domain/task/service/taskexe/scheduledtask/local_cache_refresh.go new file mode 100644 index 000000000..0d254efe9 --- /dev/null +++ b/backend/modules/observability/domain/task/service/taskexe/scheduledtask/local_cache_refresh.go @@ -0,0 +1,94 @@ +// Copyright (c) 2025 coze-dev Authors +// SPDX-License-Identifier: Apache-2.0 + +package scheduledtask + +import ( + "context" + "strconv" + "time" + + "github.com/coze-dev/coze-loop/backend/modules/observability/domain/component/scheduledtask" + "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/entity" + "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/repo" + "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/service/taskexe/tracehub" + "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/entity/loop_span" + "github.com/coze-dev/coze-loop/backend/pkg/logs" + "github.com/samber/lo" +) + +type LocalCacheRefreshTask struct { + scheduledtask.BaseScheduledTask + + traceHubService tracehub.ITraceHubService + taskRepo repo.ITaskRepo +} + +func NewLocalCacheRefreshTask(traceHubService tracehub.ITraceHubService, taskRepo repo.ITaskRepo) scheduledtask.ScheduledTask { + return &LocalCacheRefreshTask{ + BaseScheduledTask: scheduledtask.NewBaseScheduledTask("LocalCacheRefreshTask", 2*time.Minute), + traceHubService: traceHubService, + taskRepo: taskRepo, + } +} + +func (t *LocalCacheRefreshTask) RunOnce(ctx context.Context) error { + logs.CtxInfo(ctx, "Start syncing task cache...") + + // 1. Retrieve spaceID, botID, and task information for all non-final tasks from the database + spaceIDs, botIDs, tasks, err := t.getNonFinalTaskInfos(ctx) + if err != nil { + logs.CtxError(ctx, "Failed to get non-final task list", "err", err) + return err + } + logs.CtxInfo(ctx, "Retrieved task information, taskCount:%d, spaceCount:%d, botCount:%d", len(tasks), len(spaceIDs), len(botIDs)) + + if err := t.traceHubService.StoneTaskCache(ctx, tracehub.TaskCacheInfo{ + WorkspaceIDs: spaceIDs, + BotIDs: botIDs, + Tasks: tasks, + UpdateTime: time.Now(), // Set the current time as the update time + }); err != nil { + logs.CtxError(ctx, "Failed to update task cache", "err", err) + return err + } + return nil +} + +func (t *LocalCacheRefreshTask) getNonFinalTaskInfos(ctx context.Context) ([]string, []string, []*entity.ObservabilityTask, error) { + tasks, err := t.taskRepo.ListNonFinalTasks(ctx) + if err != nil { + return nil, nil, nil, err + } + + spaceMap := make(map[string]interface{}) + botMap := make(map[string]interface{}) + + for _, task := range tasks { + spaceMap[strconv.FormatInt(task.WorkspaceID, 10)] = struct{}{} + if task.SpanFilter != nil && task.SpanFilter.Filters.FilterFields != nil { + extractBotIDFromFilters(task.SpanFilter.Filters.FilterFields, botMap) + } + } + + return lo.Keys(spaceMap), lo.Keys(botMap), tasks, nil +} + +// extractBotIDFromFilters 递归提取过滤器中的 bot_id 值,包括 SubFilter +func extractBotIDFromFilters(filterFields []*loop_span.FilterField, botMap map[string]interface{}) { + for _, filterField := range filterFields { + if filterField == nil { + continue + } + // 检查当前 FilterField 的 FieldName + if filterField.FieldName == "bot_id" { + for _, v := range filterField.Values { + botMap[v] = struct{}{} + } + } + // 递归处理 SubFilter + if filterField.SubFilter != nil && filterField.SubFilter.FilterFields != nil { + extractBotIDFromFilters(filterField.SubFilter.FilterFields, botMap) + } + } +} diff --git a/backend/modules/observability/domain/task/service/taskexe/tracehub/scheduled_task_test.go b/backend/modules/observability/domain/task/service/taskexe/scheduledtask/scheduled_task_test.go similarity index 99% rename from backend/modules/observability/domain/task/service/taskexe/tracehub/scheduled_task_test.go rename to backend/modules/observability/domain/task/service/taskexe/scheduledtask/scheduled_task_test.go index f881d0d32..0fc7e4a59 100755 --- a/backend/modules/observability/domain/task/service/taskexe/tracehub/scheduled_task_test.go +++ b/backend/modules/observability/domain/task/service/taskexe/scheduledtask/scheduled_task_test.go @@ -1,7 +1,7 @@ // Copyright (c) 2025 coze-dev Authors // SPDX-License-Identifier: Apache-2.0 -package tracehub +package scheduledtask import ( "context" diff --git a/backend/modules/observability/domain/task/service/taskexe/tracehub/scheduled_task.go b/backend/modules/observability/domain/task/service/taskexe/scheduledtask/status_check.go old mode 100755 new mode 100644 similarity index 64% rename from backend/modules/observability/domain/task/service/taskexe/tracehub/scheduled_task.go rename to backend/modules/observability/domain/task/service/taskexe/scheduledtask/status_check.go index 46e3772e7..e214bf461 --- a/backend/modules/observability/domain/task/service/taskexe/tracehub/scheduled_task.go +++ b/backend/modules/observability/domain/task/service/taskexe/scheduledtask/status_check.go @@ -1,27 +1,28 @@ // Copyright (c) 2025 coze-dev Authors // SPDX-License-Identifier: Apache-2.0 -package tracehub +package scheduledtask import ( "context" "fmt" - "os" - "slices" - "strconv" "time" + "github.com/coze-dev/coze-loop/backend/infra/lock" + "github.com/coze-dev/coze-loop/backend/modules/observability/domain/component/config" + "github.com/coze-dev/coze-loop/backend/modules/observability/domain/component/scheduledtask" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/entity" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/repo" + "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/service" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/service/taskexe" - "github.com/coze-dev/coze-loop/backend/modules/observability/domain/trace/entity/loop_span" + "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/service/taskexe/processor" + "github.com/coze-dev/coze-loop/backend/modules/observability/domain/task/service/taskexe/tracehub" + "github.com/coze-dev/coze-loop/backend/pkg/json" "github.com/coze-dev/coze-loop/backend/pkg/lang/ptr" "github.com/coze-dev/coze-loop/backend/pkg/logs" "github.com/pkg/errors" - "github.com/samber/lo" ) -// TaskRunCountInfo represents the TaskRunCount information structure type TaskRunCountInfo struct { TaskID int64 TaskRunID int64 @@ -31,77 +32,84 @@ type TaskRunCountInfo struct { } const ( - transformTaskStatusLockKey = "observability:tracehub:transform_task_status" - transformTaskStatusLockTTL = 3 * time.Minute - syncTaskRunCountsLockKey = "observability:tracehub:sync_task_run_counts" + syncTaskRunCountLockTTL = 3 * time.Minute + checkTaskStatusLockKey = "observability:task:check_task_status" + checkTaskStatusLockTTL = 3 * time.Minute + backfillLockKeyTemplate = "observability:tracehub:backfill:%d" + backfillLockMaxHold = 24 * time.Hour ) -// startScheduledTask launches the scheduled task goroutine -func (h *TraceHubServiceImpl) startScheduledTask() { - h.syncTaskCache() - go func() { - for { - select { - case <-h.scheduledTaskTicker.C: - // Execute scheduled task - h.transformTaskStatus() // 抢锁 - case <-h.stopChan: - // Stop scheduled task - h.scheduledTaskTicker.Stop() - return - } - } - }() - go func() { - for { - select { - case <-h.syncTaskTicker.C: - // Execute scheduled task - h.syncTaskRunCounts() // 抢锁 - h.syncTaskCache() - case <-h.stopChan: - // Stop scheduled task - h.syncTaskTicker.Stop() - return - } - } - }() +type StatusCheckTask struct { + scheduledtask.BaseScheduledTask + + config config.ITraceConfig + locker lock.ILocker + traceHubService tracehub.ITraceHubService + taskService service.ITaskService + taskProcessor processor.TaskProcessor + taskRepo repo.ITaskRepo } -func (h *TraceHubServiceImpl) transformTaskStatus() { - ctx := context.Background() - ctx = h.fillCtx(ctx) +func NewStatusCheckTask( + locker lock.ILocker, + config config.ITraceConfig, + traceHubService tracehub.ITraceHubService, + taskService service.ITaskService, + taskProcessor processor.TaskProcessor, + taskRepo repo.ITaskRepo, +) scheduledtask.ScheduledTask { + return &StatusCheckTask{ + BaseScheduledTask: scheduledtask.NewBaseScheduledTask("StatusCheckTask", 5*time.Minute), + locker: locker, + config: config, + traceHubService: traceHubService, + taskService: taskService, + taskProcessor: taskProcessor, + taskRepo: taskRepo, + } +} - cfg, err := h.config.GetConsumerListening(ctx) +func (t *StatusCheckTask) RunOnce(ctx context.Context) error { + cfg, err := t.config.GetConsumerListening(ctx) if err != nil { - return + return err } if !cfg.IsEnabled || !cfg.IsAllSpace { - return + return nil } - if slices.Contains([]string{TracehubClusterName, InjectClusterName}, os.Getenv(TceCluster)) { - return - } - - if h.locker != nil { - locked, lockErr := h.locker.Lock(ctx, transformTaskStatusLockKey, transformTaskStatusLockTTL) + if t.locker != nil { + locked, lockErr := t.locker.Lock(ctx, checkTaskStatusLockKey, checkTaskStatusLockTTL) if lockErr != nil { logs.CtxError(ctx, "transformTaskStatus acquire lock failed", "err", lockErr) - return + return lockErr } if !locked { logs.CtxInfo(ctx, "transformTaskStatus lock held by others, skip execution") - return + return nil } } + + if err = t.checkTaskStatus(ctx); err != nil { + logs.CtxError(ctx, "Failed to check task status", "err", err) + return err + } + if err = t.syncTaskRunCount(ctx); err != nil { + logs.CtxError(ctx, "Failed to sync task run count", "err", err) + return err + } + + return nil +} + +func (t *StatusCheckTask) checkTaskStatus(ctx context.Context) error { logs.CtxInfo(ctx, "Scheduled task started...") // Read all non-final (success/disabled) tasks - taskPOs, err := h.listNonFinalTask(ctx) + taskPOs, err := t.listNonFinalTask(ctx) if err != nil { logs.CtxError(ctx, "Failed to get non-final task list", "err", err) - return + return err } logs.CtxInfo(ctx, "Scheduled task retrieved number of tasks:%d", len(taskPOs)) for _, taskPO := range taskPOs { @@ -115,7 +123,7 @@ func (h *TraceHubServiceImpl) transformTaskStatus() { endTime = time.UnixMilli(taskPO.EffectiveTime.EndAt) startTime = time.UnixMilli(taskPO.EffectiveTime.StartAt) } - proc := h.taskProcessor.GetTaskProcessor(taskPO.TaskType) + proc := t.taskProcessor.GetTaskProcessor(taskPO.TaskType) // Task time horizon reached // End when the task end time is reached logs.CtxInfo(ctx, "[auto_task]taskID:%d, endTime:%v, startTime:%v", taskPO.ID, endTime, startTime) @@ -134,9 +142,9 @@ func (h *TraceHubServiceImpl) transformTaskStatus() { } if backfillTaskRun.RunStatus != entity.TaskRunStatusDone { lockKey := fmt.Sprintf(backfillLockKeyTemplate, taskPO.ID) - locked, _, cancel, lockErr := h.locker.LockWithRenew(ctx, lockKey, transformTaskStatusLockTTL, backfillLockMaxHold) + locked, _, cancel, lockErr := t.locker.LockWithRenew(ctx, lockKey, syncTaskRunCountLockTTL, backfillLockMaxHold) if (lockErr != nil || !locked) && time.Now().Add(-backfillTaskRun.RunEndAt.Sub(backfillTaskRun.RunStartAt)).Before(backfillTaskRun.RunEndAt) { - _ = h.sendBackfillMessage(ctx, &entity.BackFillEvent{ + _ = t.taskService.SendBackfillMessage(ctx, &entity.BackFillEvent{ TaskID: taskPO.ID, SpaceID: taskPO.WorkspaceID, }) @@ -158,9 +166,9 @@ func (h *TraceHubServiceImpl) transformTaskStatus() { } if backfillTaskRun.RunStatus != entity.TaskRunStatusDone { lockKey := fmt.Sprintf(backfillLockKeyTemplate, taskPO.ID) - locked, _, cancel, lockErr := h.locker.LockWithRenew(ctx, lockKey, transformTaskStatusLockTTL, backfillLockMaxHold) + locked, _, cancel, lockErr := t.locker.LockWithRenew(ctx, lockKey, syncTaskRunCountLockTTL, backfillLockMaxHold) if (lockErr != nil || !locked) && time.Now().Add(-backfillTaskRun.RunEndAt.Sub(backfillTaskRun.RunStartAt)).Before(backfillTaskRun.RunEndAt) { - _ = h.sendBackfillMessage(ctx, &entity.BackFillEvent{ + _ = t.taskService.SendBackfillMessage(ctx, &entity.BackFillEvent{ TaskID: taskPO.ID, SpaceID: taskPO.WorkspaceID, }) @@ -234,37 +242,20 @@ func (h *TraceHubServiceImpl) transformTaskStatus() { } } } + return nil } -// syncTaskRunCounts synchronizes TaskRunCount data to the database -func (h *TraceHubServiceImpl) syncTaskRunCounts() { - if slices.Contains([]string{TracehubClusterName, InjectClusterName}, os.Getenv(TceCluster)) { - return - } - ctx := context.Background() - ctx = h.fillCtx(ctx) - - if h.locker != nil { - locked, lockErr := h.locker.Lock(ctx, syncTaskRunCountsLockKey, transformTaskStatusLockTTL) - if lockErr != nil { - logs.CtxError(ctx, "syncTaskRunCounts acquire lock failed", "err", lockErr) - return - } - if !locked { - logs.CtxInfo(ctx, "syncTaskRunCounts lock held by others, skip execution") - return - } - } +func (t *StatusCheckTask) syncTaskRunCount(ctx context.Context) error { logs.CtxInfo(ctx, "Start syncing TaskRunCounts to database...") // 1. Retrieve non-final task list - taskDOs, err := h.listSyncTaskRunTask(ctx) + taskDOs, err := t.listSyncTaskRunTask(ctx) if err != nil { logs.CtxError(ctx, "Failed to get non-final task list", "err", err) - return + return err } if len(taskDOs) == 0 { logs.CtxInfo(ctx, "No non-final tasks need syncing") - return + return nil } // 2. Collect all TaskRun information that needs syncing @@ -284,7 +275,7 @@ func (h *TraceHubServiceImpl) syncTaskRunCounts() { if len(taskRunInfos) == 0 { logs.CtxInfo(ctx, "No TaskRun requires syncing") - return + return nil } logs.CtxInfo(ctx, "Number of TaskRun entries requiring syncing:%d", len(taskRunInfos)) @@ -298,138 +289,23 @@ func (h *TraceHubServiceImpl) syncTaskRunCounts() { } batch := taskRunInfos[i:end] - h.processBatch(ctx, batch) - } -} - -func (h *TraceHubServiceImpl) syncTaskCache() { - ctx := context.Background() - ctx = h.fillCtx(ctx) - - logs.CtxInfo(ctx, "Start syncing task cache...") - - // 1. Retrieve spaceID, botID, and task information for all non-final tasks from the database - spaceIDs, botIDs, tasks, err := h.getNonFinalTaskInfos(ctx) - if err != nil { - logs.CtxError(ctx, "Failed to get non-final task list", "err", err) - return + t.processBatch(ctx, batch) } - logs.CtxInfo(ctx, "Retrieved task information, taskCount:%d, spaceCount:%d, botCount:%d", len(tasks), len(spaceIDs), len(botIDs)) - - h.localCache.StoneTaskCache(TaskCacheInfo{ - WorkspaceIDs: spaceIDs, - BotIDs: botIDs, - Tasks: tasks, - UpdateTime: time.Now(), // Set the current time as the update time - }) -} - -// processBatch synchronizes TaskRun counts in batches -func (h *TraceHubServiceImpl) processBatch(ctx context.Context, batch []*TaskRunCountInfo) { - // 1. Read Redis count data in batch - for _, info := range batch { - // Read taskruncount - count, err := h.taskRepo.GetTaskRunCount(ctx, info.TaskID, info.TaskRunID) - if err != nil || count == -1 { - logs.CtxWarn(ctx, "Failed to get TaskRunCount, taskID:%d, taskRunID:%d, err:%v", info.TaskID, info.TaskRunID, err) - } else { - info.TaskRunCount = count - } - - // Read taskrun success count - successCount, err := h.taskRepo.GetTaskRunSuccessCount(ctx, info.TaskID, info.TaskRunID) - if err != nil || successCount == -1 { - logs.CtxWarn(ctx, "Failed to get TaskRunSuccessCount, taskID:%d, taskRunID:%d, err:%v", info.TaskID, info.TaskRunID, err) - } else { - info.TaskRunSuccCount = successCount - } - - // Read taskrun fail count - failCount, err := h.taskRepo.GetTaskRunFailCount(ctx, info.TaskID, info.TaskRunID) - if err != nil || failCount == -1 { - logs.CtxWarn(ctx, "Failed to get TaskRunFailCount, taskID:%d, taskRunID:%d, err:%v", info.TaskID, info.TaskRunID, err) - } else { - info.TaskRunFailCount = failCount - } - - logs.CtxDebug(ctx, "Read count data", - "taskID", info.TaskID, - "taskRunID", info.TaskRunID, - "runCount", info.TaskRunCount, - "successCount", info.TaskRunSuccCount, - "failCount", info.TaskRunFailCount) - } - logs.CtxInfo(ctx, "Start updating TaskRun detail in batch, batchSize:%d, batch:%v", len(batch), batch) - // 2. Update database in batch - for _, info := range batch { - err := h.updateTaskRunDetail(ctx, info) - if err != nil { - logs.CtxError(ctx, "Failed to update TaskRun detail", - "taskID", info.TaskID, - "taskRunID", info.TaskRunID, - "err", err) - } else { - logs.CtxDebug(ctx, "Succeeded in updating TaskRun detail", - "taskID", info.TaskID, - "taskRunID", info.TaskRunID) - } - } - - logs.CtxInfo(ctx, "Batch processing completed, batchSize:%d", len(batch)) -} - -// updateTaskRunDetail updates the run_detail field of TaskRun -func (h *TraceHubServiceImpl) updateTaskRunDetail(ctx context.Context, info *TaskRunCountInfo) error { - // Build run_detail JSON data - runDetail := map[string]interface{}{ - "total_count": info.TaskRunCount, - "success_count": info.TaskRunSuccCount, - "failed_count": info.TaskRunFailCount, - } - - // Update using optimistic locking - err := h.taskRepo.UpdateTaskRunWithOCC(ctx, info.TaskRunID, 0, map[string]interface{}{ - "run_detail": ToJSONString(ctx, runDetail), - }) - if err != nil { - return errors.Wrap(err, "Failed to update TaskRun") - } - return nil } -func (h *TraceHubServiceImpl) listNonFinalTaskByRedis(ctx context.Context, spaceID string) ([]*entity.ObservabilityTask, error) { - var taskPOs []*entity.ObservabilityTask - nonFinalTaskIDs, err := h.taskRepo.ListNonFinalTaskBySpaceID(ctx, spaceID) +func (t *StatusCheckTask) listSyncTaskRunTask(ctx context.Context) ([]*entity.ObservabilityTask, error) { + var taskDOs []*entity.ObservabilityTask + taskDOs, err := t.listNonFinalTask(ctx) if err != nil { logs.CtxError(ctx, "Failed to get non-final task list", "err", err) return nil, err } - logs.CtxInfo(ctx, "Start listing non-final tasks, taskCount:%d, nonFinalTaskIDs:%v", len(nonFinalTaskIDs), nonFinalTaskIDs) - if len(nonFinalTaskIDs) == 0 { - return taskPOs, nil - } - for _, taskID := range nonFinalTaskIDs { - taskPO, err := h.taskRepo.GetTaskByCache(ctx, taskID) - if err != nil { - logs.CtxError(ctx, "Failed to get task", "err", err) - return nil, err - } - if taskPO == nil { - continue - } - taskPOs = append(taskPOs, taskPO) - } - return taskPOs, nil -} - -func (h *TraceHubServiceImpl) listNonFinalTask(ctx context.Context) ([]*entity.ObservabilityTask, error) { - var taskPOs []*entity.ObservabilityTask var offset int32 = 0 - const limit int32 = 500 + const limit int32 = 1000 // Paginate through all tasks for { - tasklist, _, err := h.taskRepo.ListTasks(ctx, repo.ListTaskParam{ + tasklist, _, err := t.taskRepo.ListTasks(ctx, repo.ListTaskParam{ ReqLimit: limit, ReqOffset: offset, TaskFilters: &entity.TaskFilterFields{ @@ -437,23 +313,30 @@ func (h *TraceHubServiceImpl) listNonFinalTask(ctx context.Context) ([]*entity.O { FieldName: ptr.Of(entity.TaskFieldNameTaskStatus), Values: []string{ - string(entity.TaskStatusUnstarted), - string(entity.TaskStatusRunning), - string(entity.TaskStatusPending), + string(entity.TaskStatusSuccess), + string(entity.TaskStatusDisabled), }, QueryType: ptr.Of(entity.QueryTypeIn), FieldType: ptr.Of(entity.FieldTypeString), }, + { + FieldName: ptr.Of(entity.TaskFieldName("updated_at")), + Values: []string{ + fmt.Sprintf("%d", time.Now().Add(-24*time.Hour).UnixMilli()), + }, + QueryType: ptr.Of(entity.QueryTypeGt), + FieldType: ptr.Of(entity.FieldTypeLong), + }, }, }, }) if err != nil { logs.CtxError(ctx, "Failed to get non-final task list", "err", err) - return nil, err + break } // Add tasks from the current page to the full list - taskPOs = append(taskPOs, tasklist...) + taskDOs = append(taskDOs, tasklist...) // If fewer tasks than limit are returned, this is the last page if len(tasklist) < int(limit) { @@ -463,21 +346,16 @@ func (h *TraceHubServiceImpl) listNonFinalTask(ctx context.Context) ([]*entity.O // Move to the next page, increasing offset by 1000 offset += limit } - return taskPOs, nil + return taskDOs, nil } -func (h *TraceHubServiceImpl) listSyncTaskRunTask(ctx context.Context) ([]*entity.ObservabilityTask, error) { - var taskDOs []*entity.ObservabilityTask - taskDOs, err := h.listNonFinalTask(ctx) - if err != nil { - logs.CtxError(ctx, "Failed to get non-final task list", "err", err) - return nil, err - } +func (t *StatusCheckTask) listNonFinalTask(ctx context.Context) ([]*entity.ObservabilityTask, error) { + var taskPOs []*entity.ObservabilityTask var offset int32 = 0 - const limit int32 = 1000 + const limit int32 = 500 // Paginate through all tasks for { - tasklist, _, err := h.taskRepo.ListTasks(ctx, repo.ListTaskParam{ + tasklist, _, err := t.taskRepo.ListTasks(ctx, repo.ListTaskParam{ ReqLimit: limit, ReqOffset: offset, TaskFilters: &entity.TaskFilterFields{ @@ -485,30 +363,23 @@ func (h *TraceHubServiceImpl) listSyncTaskRunTask(ctx context.Context) ([]*entit { FieldName: ptr.Of(entity.TaskFieldNameTaskStatus), Values: []string{ - string(entity.TaskStatusSuccess), - string(entity.TaskStatusDisabled), + string(entity.TaskStatusUnstarted), + string(entity.TaskStatusRunning), + string(entity.TaskStatusPending), }, QueryType: ptr.Of(entity.QueryTypeIn), FieldType: ptr.Of(entity.FieldTypeString), }, - { - FieldName: ptr.Of(entity.TaskFieldName("updated_at")), - Values: []string{ - fmt.Sprintf("%d", time.Now().Add(-24*time.Hour).UnixMilli()), - }, - QueryType: ptr.Of(entity.QueryTypeGt), - FieldType: ptr.Of(entity.FieldTypeLong), - }, }, }, }) if err != nil { logs.CtxError(ctx, "Failed to get non-final task list", "err", err) - break + return nil, err } // Add tasks from the current page to the full list - taskDOs = append(taskDOs, tasklist...) + taskPOs = append(taskPOs, tasklist...) // If fewer tasks than limit are returned, this is the last page if len(tasklist) < int(limit) { @@ -518,43 +389,79 @@ func (h *TraceHubServiceImpl) listSyncTaskRunTask(ctx context.Context) ([]*entit // Move to the next page, increasing offset by 1000 offset += limit } - return taskDOs, nil + return taskPOs, nil } -func (h *TraceHubServiceImpl) getNonFinalTaskInfos(ctx context.Context) ([]string, []string, []*entity.ObservabilityTask, error) { - tasks, err := h.taskRepo.ListNonFinalTasks(ctx) - if err != nil { - return nil, nil, nil, err - } +// processBatch synchronizes TaskRun counts in batches +func (t *StatusCheckTask) processBatch(ctx context.Context, batch []*TaskRunCountInfo) { + // 1. Read Redis count data in batch + for _, info := range batch { + // Read taskruncount + count, err := t.taskRepo.GetTaskRunCount(ctx, info.TaskID, info.TaskRunID) + if err != nil || count == -1 { + logs.CtxWarn(ctx, "Failed to get TaskRunCount, taskID:%d, taskRunID:%d, err:%v", info.TaskID, info.TaskRunID, err) + } else { + info.TaskRunCount = count + } - spaceMap := make(map[string]interface{}) - botMap := make(map[string]interface{}) + // Read taskrun success count + successCount, err := t.taskRepo.GetTaskRunSuccessCount(ctx, info.TaskID, info.TaskRunID) + if err != nil || successCount == -1 { + logs.CtxWarn(ctx, "Failed to get TaskRunSuccessCount, taskID:%d, taskRunID:%d, err:%v", info.TaskID, info.TaskRunID, err) + } else { + info.TaskRunSuccCount = successCount + } - for _, task := range tasks { - spaceMap[strconv.FormatInt(task.WorkspaceID, 10)] = struct{}{} - if task.SpanFilter != nil && task.SpanFilter.Filters.FilterFields != nil { - extractBotIDFromFilters(task.SpanFilter.Filters.FilterFields, botMap) + // Read taskrun fail count + failCount, err := t.taskRepo.GetTaskRunFailCount(ctx, info.TaskID, info.TaskRunID) + if err != nil || failCount == -1 { + logs.CtxWarn(ctx, "Failed to get TaskRunFailCount, taskID:%d, taskRunID:%d, err:%v", info.TaskID, info.TaskRunID, err) + } else { + info.TaskRunFailCount = failCount + } + + logs.CtxDebug(ctx, "Read count data", + "taskID", info.TaskID, + "taskRunID", info.TaskRunID, + "runCount", info.TaskRunCount, + "successCount", info.TaskRunSuccCount, + "failCount", info.TaskRunFailCount) + } + logs.CtxInfo(ctx, "Start updating TaskRun detail in batch, batchSize:%d, batch:%v", len(batch), batch) + // 2. Update database in batch + for _, info := range batch { + err := t.updateTaskRunDetail(ctx, info) + if err != nil { + logs.CtxError(ctx, "Failed to update TaskRun detail", + "taskID", info.TaskID, + "taskRunID", info.TaskRunID, + "err", err) + } else { + logs.CtxDebug(ctx, "Succeeded in updating TaskRun detail", + "taskID", info.TaskID, + "taskRunID", info.TaskRunID) } } - return lo.Keys(spaceMap), lo.Keys(botMap), tasks, nil + logs.CtxInfo(ctx, "Batch processing completed, batchSize:%d", len(batch)) } -// extractBotIDFromFilters 递归提取过滤器中的 bot_id 值,包括 SubFilter -func extractBotIDFromFilters(filterFields []*loop_span.FilterField, botMap map[string]interface{}) { - for _, filterField := range filterFields { - if filterField == nil { - continue - } - // 检查当前 FilterField 的 FieldName - if filterField.FieldName == "bot_id" { - for _, v := range filterField.Values { - botMap[v] = struct{}{} - } - } - // 递归处理 SubFilter - if filterField.SubFilter != nil && filterField.SubFilter.FilterFields != nil { - extractBotIDFromFilters(filterField.SubFilter.FilterFields, botMap) - } +// updateTaskRunDetail updates the run_detail field of TaskRun +func (t *StatusCheckTask) updateTaskRunDetail(ctx context.Context, info *TaskRunCountInfo) error { + // Build run_detail JSON data + runDetail := map[string]interface{}{ + "total_count": info.TaskRunCount, + "success_count": info.TaskRunSuccCount, + "failed_count": info.TaskRunFailCount, } + + // Update using optimistic locking + err := t.taskRepo.UpdateTaskRunWithOCC(ctx, info.TaskRunID, 0, map[string]interface{}{ + "run_detail": json.MarshalStringIgnoreErr(runDetail), + }) + if err != nil { + return errors.Wrap(err, "Failed to update TaskRun") + } + + return nil } diff --git a/backend/modules/observability/domain/task/service/taskexe/tracehub/backfill.go b/backend/modules/observability/domain/task/service/taskexe/tracehub/backfill.go index dafb4de7f..2a6ae41f3 100644 --- a/backend/modules/observability/domain/task/service/taskexe/tracehub/backfill.go +++ b/backend/modules/observability/domain/task/service/taskexe/tracehub/backfill.go @@ -33,35 +33,35 @@ const ( // 定时任务+锁 func (h *TraceHubServiceImpl) BackFill(ctx context.Context, event *entity.BackFillEvent) error { // 1. Set the current task context - ctx = h.fillCtx(ctx) logs.CtxInfo(ctx, "BackFill msg %+v", event) var ( lockKey string lockCancel func() ) - if h.locker != nil && event != nil { - lockKey = fmt.Sprintf(backfillLockKeyTemplate, event.TaskID) - locked, lockCtx, cancel, lockErr := h.locker.LockWithRenew(ctx, lockKey, backfillLockTTL, backfillLockMaxHold) - if lockErr != nil { - logs.CtxError(ctx, "backfill acquire lock failed", "task_id", event.TaskID, "err", lockErr) - return lockErr + + if h.locker != nil { + var err error + ctx, lockCancel, lockKey, err = h.acquireBackfillLock(ctx, event.TaskID) + if err != nil { + return err } - if !locked { - logs.CtxInfo(ctx, "backfill lock held by others, skip execution", "task_id", event.TaskID) + + // 如果lockKey不为空,说明成功获取了锁,需要在函数退出时释放 + if lockKey != "" { + defer func(cancel func(), key string) { + if cancel != nil { + cancel() + } else if key != "" { + if _, err := h.locker.Unlock(key); err != nil { + logs.CtxWarn(ctx, "backfill release lock failed", "task_id", event.TaskID, "err", err) + } + } + }(lockCancel, lockKey) + } else if lockCancel == nil { + // 如果lockKey为空且lockCancel为nil,说明锁被其他实例持有,直接返回 return nil } - lockCancel = cancel - ctx = lockCtx - defer func(cancel func()) { - if cancel != nil { - cancel() - } else if lockKey != "" { - if _, err := h.locker.Unlock(lockKey); err != nil { - logs.CtxWarn(ctx, "backfill release lock failed", "task_id", event.TaskID, "err", err) - } - } - }(lockCancel) } sub, err := h.buildSubscriber(ctx, event) @@ -464,3 +464,21 @@ func (h *TraceHubServiceImpl) sendBackfillMessage(ctx context.Context, event *en return h.backfillProducer.SendBackfill(ctx, event) } + +// acquireBackfillLock 尝试获取回填任务的分布式锁 +// 返回值: 新的上下文, 取消函数, 锁键, 错误 +func (h *TraceHubServiceImpl) acquireBackfillLock(ctx context.Context, taskID int64) (context.Context, func(), string, error) { + lockKey := fmt.Sprintf(backfillLockKeyTemplate, taskID) + locked, lockCtx, cancel, lockErr := h.locker.LockWithRenew(ctx, lockKey, backfillLockTTL, backfillLockMaxHold) + if lockErr != nil { + logs.CtxError(ctx, "backfill acquire lock failed", "task_id", taskID, "err", lockErr) + return ctx, nil, "", lockErr + } + + if !locked { + logs.CtxInfo(ctx, "backfill lock held by others, skip execution", "task_id", taskID) + return ctx, nil, "", nil + } + + return lockCtx, cancel, lockKey, nil +} diff --git a/backend/modules/observability/domain/task/service/taskexe/tracehub/local_cache.go b/backend/modules/observability/domain/task/service/taskexe/tracehub/local_cache.go index ece5e9eac..dc1f09232 100644 --- a/backend/modules/observability/domain/task/service/taskexe/tracehub/local_cache.go +++ b/backend/modules/observability/domain/task/service/taskexe/tracehub/local_cache.go @@ -30,7 +30,7 @@ func NewLocalCache() *LocalCache { return &LocalCache{} } -func (l *LocalCache) StoneTaskCache(info TaskCacheInfo) { +func (l *LocalCache) StoneTaskCache(ctx context.Context, info TaskCacheInfo) { l.taskCache.Store(CacheKeyObjListWithTask, info) } diff --git a/backend/modules/observability/domain/task/service/taskexe/tracehub/span_trigger.go b/backend/modules/observability/domain/task/service/taskexe/tracehub/span_trigger.go index cee637a55..8d15b254b 100644 --- a/backend/modules/observability/domain/task/service/taskexe/tracehub/span_trigger.go +++ b/backend/modules/observability/domain/task/service/taskexe/tracehub/span_trigger.go @@ -18,7 +18,6 @@ import ( ) func (h *TraceHubServiceImpl) SpanTrigger(ctx context.Context, span *loop_span.Span) error { - ctx = h.fillCtx(ctx) logSuffix := fmt.Sprintf("log_id=%s, trace_id=%s, span_id=%s", span.LogID, span.TraceID, span.SpanID) logs.CtxInfo(ctx, "auto_task start, %s", logSuffix) @@ -47,21 +46,14 @@ func (h *TraceHubServiceImpl) SpanTrigger(ctx context.Context, span *loop_span.S return nil } - // 3、Sample - subs = gslice.Filter(subs, func(sub *spanSubscriber) bool { return sub.Sampled() }) - logs.CtxInfo(ctx, "%d subscriber of flow span sampled, %s", len(subs), logSuffix) - if len(subs) == 0 { - return nil - } - - // 4. PreDispatch + // 3. PreDispatch if err = h.preDispatch(ctx, subs); err != nil { logs.CtxWarn(ctx, "preDispatch flow span failed, %s, err: %v", logSuffix, err) return err } logs.CtxInfo(ctx, "%d preDispatch success, %v", len(subs), subs) - // 5、Dispatch + // 4、Dispatch if err = h.dispatch(ctx, span, subs); err != nil { logs.CtxError(ctx, "dispatch flow span failed, %s, err: %v", logSuffix, err) return err @@ -118,8 +110,12 @@ func (h *TraceHubServiceImpl) buildSubscriberOfSpan(ctx context.Context, span *l continue } if ok { - subscribers[keep] = s - keep++ + if s.Sampled() { + subscribers[keep] = s + keep++ + } else { + logs.CtxInfo(ctx, "span not sampled, task_id=%d, trace_id=%s, span_id=%s", s.taskID, span.TraceID, span.SpanID) + } } } return subscribers[:keep], merr.ErrorOrNil() @@ -181,11 +177,7 @@ func (h *TraceHubServiceImpl) preDispatch(ctx context.Context, subs []*spanSubsc } continue } - sampler := sub.t.Sampler - // Fetch the corresponding task count and subtask count - taskCount, _ := h.taskRepo.GetTaskCount(ctx, sub.taskID) - taskRunCount, _ := h.taskRepo.GetTaskRunCount(ctx, sub.taskID, taskRunConfig.ID) - logs.CtxInfo(ctx, "preDispatch, task_id=%d, taskCount=%d, taskRunCount=%d", sub.taskID, taskCount, taskRunCount) + endTime := time.UnixMilli(sub.t.EffectiveTime.EndAt) // Reached task time limit if time.Now().After(endTime) { @@ -200,6 +192,12 @@ func (h *TraceHubServiceImpl) preDispatch(ctx context.Context, subs []*spanSubsc continue } } + + sampler := sub.t.Sampler + // Fetch the corresponding task count and subtask count + taskCount, _ := h.taskRepo.GetTaskCount(ctx, sub.taskID) + taskRunCount, _ := h.taskRepo.GetTaskRunCount(ctx, sub.taskID, taskRunConfig.ID) + logs.CtxInfo(ctx, "preDispatch, task_id=%d, taskCount=%d, taskRunCount=%d", sub.taskID, taskCount, taskRunCount) // Reached task limit if taskCount+1 > sampler.SampleSize { logs.CtxWarn(ctx, "[OnTaskFinished]taskCount+1 > sampler.GetSampleSize() Finish processor, task_id=%d", sub.taskID) @@ -266,3 +264,28 @@ func (h *TraceHubServiceImpl) dispatch(ctx context.Context, span *loop_span.Span } return merr.ErrorOrNil() } + +func (h *TraceHubServiceImpl) listNonFinalTaskByRedis(ctx context.Context, spaceID string) ([]*entity.ObservabilityTask, error) { + var taskPOs []*entity.ObservabilityTask + nonFinalTaskIDs, err := h.taskRepo.ListNonFinalTaskBySpaceID(ctx, spaceID) + if err != nil { + logs.CtxError(ctx, "Failed to get non-final task list", "err", err) + return nil, err + } + logs.CtxInfo(ctx, "Start listing non-final tasks, taskCount:%d, nonFinalTaskIDs:%v", len(nonFinalTaskIDs), nonFinalTaskIDs) + if len(nonFinalTaskIDs) == 0 { + return taskPOs, nil + } + for _, taskID := range nonFinalTaskIDs { + taskPO, err := h.taskRepo.GetTaskByCache(ctx, taskID) + if err != nil { + logs.CtxError(ctx, "Failed to get task", "err", err) + return nil, err + } + if taskPO == nil { + continue + } + taskPOs = append(taskPOs, taskPO) + } + return taskPOs, nil +} diff --git a/backend/modules/observability/domain/task/service/taskexe/tracehub/trace_hub.go b/backend/modules/observability/domain/task/service/taskexe/tracehub/trace_hub.go index 6c9d5feb8..8b9863543 100644 --- a/backend/modules/observability/domain/task/service/taskexe/tracehub/trace_hub.go +++ b/backend/modules/observability/domain/task/service/taskexe/tracehub/trace_hub.go @@ -5,7 +5,6 @@ package tracehub import ( "context" - "time" "github.com/coze-dev/coze-loop/backend/infra/lock" "github.com/coze-dev/coze-loop/backend/modules/observability/domain/component/config" @@ -24,6 +23,7 @@ import ( type ITraceHubService interface { SpanTrigger(ctx context.Context, span *loop_span.Span) error BackFill(ctx context.Context, event *entity.BackFillEvent) error + StoneTaskCache(ctx context.Context, cacheInfo TaskCacheInfo) error } func NewTraceHubImpl( @@ -37,44 +37,31 @@ func NewTraceHubImpl( locker lock.ILocker, config config.ITraceConfig, ) (ITraceHubService, error) { - // Create two independent timers with different intervals - scheduledTaskTicker := time.NewTicker(5 * time.Minute) // Task status lifecycle management - 5-minute interval - syncTaskTicker := time.NewTicker(2 * time.Minute) // Data synchronization - 1-minute interval impl := &TraceHubServiceImpl{ - taskRepo: tRepo, - scheduledTaskTicker: scheduledTaskTicker, - syncTaskTicker: syncTaskTicker, - stopChan: make(chan struct{}), - traceRepo: traceRepo, - tenantProvider: tenantProvider, - buildHelper: buildHelper, - taskProcessor: taskProcessor, - aid: aid, - backfillProducer: backfillProducer, - locker: locker, - config: config, - localCache: NewLocalCache(), + taskRepo: tRepo, + traceRepo: traceRepo, + tenantProvider: tenantProvider, + buildHelper: buildHelper, + taskProcessor: taskProcessor, + aid: aid, + backfillProducer: backfillProducer, + locker: locker, + config: config, + localCache: NewLocalCache(), } - // Start the scheduled tasks immediately - impl.startScheduledTask() - - // default+lane?+新集群?——定时任务和任务处理分开——内场 return impl, nil } type TraceHubServiceImpl struct { - scheduledTaskTicker *time.Ticker // Task status lifecycle management timer - 5-minute interval - syncTaskTicker *time.Ticker // Data synchronization timer - 1-minute interval - stopChan chan struct{} - taskRepo repo.ITaskRepo - traceRepo trace_repo.ITraceRepo - tenantProvider tenant.ITenantProvider - taskProcessor *processor.TaskProcessor - buildHelper service.TraceFilterProcessorBuilder - backfillProducer mq.IBackfillProducer - locker lock.ILocker - config config.ITraceConfig + taskRepo repo.ITaskRepo + traceRepo trace_repo.ITraceRepo + tenantProvider tenant.ITenantProvider + taskProcessor *processor.TaskProcessor + buildHelper service.TraceFilterProcessorBuilder + backfillProducer mq.IBackfillProducer + locker lock.ILocker + config config.ITraceConfig // Local cache - caching non-terminal task information localCache *LocalCache @@ -82,6 +69,7 @@ type TraceHubServiceImpl struct { aid int32 } -func (h *TraceHubServiceImpl) Close() { - close(h.stopChan) +func (h *TraceHubServiceImpl) StoneTaskCache(ctx context.Context, cacheInfo TaskCacheInfo) error { + h.localCache.StoneTaskCache(ctx, cacheInfo) + return nil }