diff --git a/sdks/community/go/pkg/core/events/activity_events.go b/sdks/community/go/pkg/core/events/activity_events.go new file mode 100644 index 000000000..24eae50fe --- /dev/null +++ b/sdks/community/go/pkg/core/events/activity_events.go @@ -0,0 +1,109 @@ +package events + +import ( + "encoding/json" + "fmt" +) + +// ActivitySnapshotEvent contains a snapshot of an activity message. +type ActivitySnapshotEvent struct { + *BaseEvent + MessageID string `json:"messageId"` + ActivityType string `json:"activityType"` + Content any `json:"content"` + Replace *bool `json:"replace,omitempty"` +} + +// NewActivitySnapshotEvent creates a new activity snapshot event. +func NewActivitySnapshotEvent(messageID, activityType string, content any) *ActivitySnapshotEvent { + replace := true + return &ActivitySnapshotEvent{ + BaseEvent: NewBaseEvent(EventTypeActivitySnapshot), + MessageID: messageID, + ActivityType: activityType, + Content: content, + Replace: &replace, + } +} + +// WithReplace sets the replace flag for the snapshot event. +func (e *ActivitySnapshotEvent) WithReplace(replace bool) *ActivitySnapshotEvent { + e.Replace = &replace + return e +} + +// Validate validates the activity snapshot event. +func (e *ActivitySnapshotEvent) Validate() error { + if err := e.BaseEvent.Validate(); err != nil { + return err + } + + if e.MessageID == "" { + return fmt.Errorf("ActivitySnapshotEvent validation failed: messageId field is required") + } + + if e.ActivityType == "" { + return fmt.Errorf("ActivitySnapshotEvent validation failed: activityType field is required") + } + + if e.Content == nil { + return fmt.Errorf("ActivitySnapshotEvent validation failed: content field is required") + } + + return nil +} + +// ToJSON serializes the event to JSON. +func (e *ActivitySnapshotEvent) ToJSON() ([]byte, error) { + return json.Marshal(e) +} + +// ActivityDeltaEvent contains incremental updates for an activity message. +type ActivityDeltaEvent struct { + *BaseEvent + MessageID string `json:"messageId"` + ActivityType string `json:"activityType"` + Patch []JSONPatchOperation `json:"patch"` +} + +// NewActivityDeltaEvent creates a new activity delta event. +func NewActivityDeltaEvent(messageID, activityType string, patch []JSONPatchOperation) *ActivityDeltaEvent { + return &ActivityDeltaEvent{ + BaseEvent: NewBaseEvent(EventTypeActivityDelta), + MessageID: messageID, + ActivityType: activityType, + Patch: patch, + } +} + +// Validate validates the activity delta event. +func (e *ActivityDeltaEvent) Validate() error { + if err := e.BaseEvent.Validate(); err != nil { + return err + } + + if e.MessageID == "" { + return fmt.Errorf("ActivityDeltaEvent validation failed: messageId field is required") + } + + if e.ActivityType == "" { + return fmt.Errorf("ActivityDeltaEvent validation failed: activityType field is required") + } + + if len(e.Patch) == 0 { + return fmt.Errorf("ActivityDeltaEvent validation failed: patch field must contain at least one operation") + } + + for i, op := range e.Patch { + if err := validateJSONPatchOperation(op); err != nil { + return fmt.Errorf("ActivityDeltaEvent validation failed: invalid patch operation at index %d: %w", i, err) + } + } + + return nil +} + +// ToJSON serializes the event to JSON. +func (e *ActivityDeltaEvent) ToJSON() ([]byte, error) { + return json.Marshal(e) +} diff --git a/sdks/community/go/pkg/core/events/activity_events_test.go b/sdks/community/go/pkg/core/events/activity_events_test.go new file mode 100644 index 000000000..deee35be4 --- /dev/null +++ b/sdks/community/go/pkg/core/events/activity_events_test.go @@ -0,0 +1,109 @@ +package events + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestActivitySnapshotEventBasics(t *testing.T) { + content := map[string]any{"status": "draft"} + + event := NewActivitySnapshotEvent("activity-1", "PLAN", content) + + assert.Equal(t, EventTypeActivitySnapshot, event.Type()) + assert.Equal(t, "activity-1", event.MessageID) + assert.Equal(t, "PLAN", event.ActivityType) + require.NotNil(t, event.Replace) + assert.True(t, *event.Replace) + assert.NoError(t, event.Validate()) + + event = event.WithReplace(false) + require.NotNil(t, event.Replace) + assert.False(t, *event.Replace) +} + +func TestActivitySnapshotEventValidationAndJSON(t *testing.T) { + event := NewActivitySnapshotEvent("activity-1", "PLAN", map[string]any{"status": "draft"}) + + data, err := event.ToJSON() + require.NoError(t, err) + + var decoded map[string]any + require.NoError(t, json.Unmarshal(data, &decoded)) + + assert.Equal(t, string(EventTypeActivitySnapshot), decoded["type"]) + assert.Equal(t, "activity-1", decoded["messageId"]) + assert.Equal(t, "PLAN", decoded["activityType"]) + content, ok := decoded["content"].(map[string]any) + require.True(t, ok) + assert.Equal(t, "draft", content["status"]) + + event.MessageID = "" + assert.Error(t, event.Validate()) + + event.MessageID = "activity-1" + event.ActivityType = "" + assert.Error(t, event.Validate()) + + event.ActivityType = "PLAN" + event.Content = nil + assert.Error(t, event.Validate()) + + event.Content = map[string]any{"status": "draft"} + event.BaseEvent.EventType = "" + assert.Error(t, event.Validate()) +} + +func TestActivitySnapshotEvent_MissingActivityType(t *testing.T) { + event := NewActivitySnapshotEvent("activity-1", "", map[string]any{"status": "draft"}) + err := event.Validate() + assert.Error(t, err) +} + +func TestActivityDeltaEventValidationAndJSON(t *testing.T) { + patch := []JSONPatchOperation{{Op: "replace", Path: "/status", Value: "done"}} + event := NewActivityDeltaEvent("activity-1", "PLAN", patch) + + assert.Equal(t, EventTypeActivityDelta, event.Type()) + assert.NoError(t, event.Validate()) + + data, err := event.ToJSON() + require.NoError(t, err) + + var decoded map[string]any + require.NoError(t, json.Unmarshal(data, &decoded)) + + assert.Equal(t, string(EventTypeActivityDelta), decoded["type"]) + assert.Equal(t, "activity-1", decoded["messageId"]) + assert.Equal(t, "PLAN", decoded["activityType"]) + items, ok := decoded["patch"].([]any) + require.True(t, ok) + assert.Len(t, items, 1) + + event.MessageID = "" + assert.Error(t, event.Validate()) + + event.MessageID = "activity-1" + event.Patch = []JSONPatchOperation{} + assert.Error(t, event.Validate()) + + event.Patch = []JSONPatchOperation{{Op: "invalid", Path: "/status"}} + assert.Error(t, event.Validate()) + + event.Patch = []JSONPatchOperation{{Op: "replace", Path: "/status", Value: "ok"}} + event.ActivityType = "" + assert.Error(t, event.Validate()) + + event.ActivityType = "PLAN" + event.BaseEvent.EventType = "" + assert.Error(t, event.Validate()) +} + +func TestActivityDeltaEvent_MissingActivityType(t *testing.T) { + event := NewActivityDeltaEvent("activity-1", "", []JSONPatchOperation{{Op: "replace", Path: "/status", Value: "done"}}) + err := event.Validate() + assert.Error(t, err) +} diff --git a/sdks/community/go/pkg/core/events/decoder.go b/sdks/community/go/pkg/core/events/decoder.go index 942d1490d..6b68b8d39 100644 --- a/sdks/community/go/pkg/core/events/decoder.go +++ b/sdks/community/go/pkg/core/events/decoder.go @@ -130,6 +130,20 @@ func (ed *EventDecoder) DecodeEvent(eventName string, data []byte) (Event, error } return &evt, nil + case EventTypeActivitySnapshot: + var evt ActivitySnapshotEvent + if err := json.Unmarshal(data, &evt); err != nil { + return nil, fmt.Errorf("failed to decode ACTIVITY_SNAPSHOT: %w", err) + } + return &evt, nil + + case EventTypeActivityDelta: + var evt ActivityDeltaEvent + if err := json.Unmarshal(data, &evt); err != nil { + return nil, fmt.Errorf("failed to decode ACTIVITY_DELTA: %w", err) + } + return &evt, nil + case EventTypeStepStarted: var evt StepStartedEvent if err := json.Unmarshal(data, &evt); err != nil { diff --git a/sdks/community/go/pkg/core/events/decoder_test.go b/sdks/community/go/pkg/core/events/decoder_test.go index 058d3a6fb..958573815 100644 --- a/sdks/community/go/pkg/core/events/decoder_test.go +++ b/sdks/community/go/pkg/core/events/decoder_test.go @@ -204,6 +204,41 @@ func TestEventDecoder(t *testing.T) { assert.Equal(t, "msg-1", msgEvent.Messages[0].ID) }) + t.Run("DecodeEvent_ActivitySnapshot", func(t *testing.T) { + decoder := NewEventDecoder(nil) + data := []byte(`{"messageId": "activity-1", "activityType": "PLAN", "content": {"status": "draft"}, "replace": false}`) + + event, err := decoder.DecodeEvent("ACTIVITY_SNAPSHOT", data) + require.NoError(t, err) + require.NotNil(t, event) + + activityEvent, ok := event.(*ActivitySnapshotEvent) + require.True(t, ok) + assert.Equal(t, "activity-1", activityEvent.MessageID) + assert.Equal(t, "PLAN", activityEvent.ActivityType) + require.NotNil(t, activityEvent.Replace) + assert.False(t, *activityEvent.Replace) + content, ok := activityEvent.Content.(map[string]any) + require.True(t, ok) + assert.Equal(t, "draft", content["status"]) + }) + + t.Run("DecodeEvent_ActivityDelta", func(t *testing.T) { + decoder := NewEventDecoder(nil) + data := []byte(`{"messageId": "activity-1", "activityType": "PLAN", "patch": [{"op": "replace", "path": "/status", "value": "streaming"}]}`) + + event, err := decoder.DecodeEvent("ACTIVITY_DELTA", data) + require.NoError(t, err) + require.NotNil(t, event) + + activityEvent, ok := event.(*ActivityDeltaEvent) + require.True(t, ok) + assert.Equal(t, "activity-1", activityEvent.MessageID) + assert.Equal(t, "PLAN", activityEvent.ActivityType) + assert.Len(t, activityEvent.Patch, 1) + assert.Equal(t, "replace", activityEvent.Patch[0].Op) + }) + t.Run("DecodeEvent_StepStarted", func(t *testing.T) { decoder := NewEventDecoder(nil) data := []byte(`{"stepName": "step-1"}`) diff --git a/sdks/community/go/pkg/core/events/events.go b/sdks/community/go/pkg/core/events/events.go index b97d963db..04f1f62ed 100644 --- a/sdks/community/go/pkg/core/events/events.go +++ b/sdks/community/go/pkg/core/events/events.go @@ -24,6 +24,8 @@ const ( EventTypeStateSnapshot EventType = "STATE_SNAPSHOT" EventTypeStateDelta EventType = "STATE_DELTA" EventTypeMessagesSnapshot EventType = "MESSAGES_SNAPSHOT" + EventTypeActivitySnapshot EventType = "ACTIVITY_SNAPSHOT" + EventTypeActivityDelta EventType = "ACTIVITY_DELTA" EventTypeRaw EventType = "RAW" EventTypeCustom EventType = "CUSTOM" EventTypeRunStarted EventType = "RUN_STARTED" @@ -57,6 +59,8 @@ var validEventTypes = map[EventType]bool{ EventTypeStateSnapshot: true, EventTypeStateDelta: true, EventTypeMessagesSnapshot: true, + EventTypeActivitySnapshot: true, + EventTypeActivityDelta: true, EventTypeRaw: true, EventTypeCustom: true, EventTypeRunStarted: true, @@ -318,6 +322,14 @@ func ValidateSequence(events []Event) error { // They represent complete message state at any point in time // Additional validation could be added if needed (e.g., consistency checks) + case EventTypeActivitySnapshot: + // Activity snapshot events are always valid in sequence context + // They represent complete activity state at any point in time + + case EventTypeActivityDelta: + // Activity delta events are always valid in sequence context + // They represent incremental activity changes at any point in time + case EventTypeRaw: // Raw events are always valid in sequence context // They contain external data that should be passed through @@ -381,6 +393,10 @@ func EventFromJSON(data []byte) (Event, error) { event = &StateDeltaEvent{} case EventTypeMessagesSnapshot: event = &MessagesSnapshotEvent{} + case EventTypeActivitySnapshot: + event = &ActivitySnapshotEvent{} + case EventTypeActivityDelta: + event = &ActivityDeltaEvent{} case EventTypeRaw: event = &RawEvent{} case EventTypeCustom: diff --git a/sdks/community/go/pkg/core/events/events_test.go b/sdks/community/go/pkg/core/events/events_test.go index e11843c98..5eba91e65 100644 --- a/sdks/community/go/pkg/core/events/events_test.go +++ b/sdks/community/go/pkg/core/events/events_test.go @@ -1,6 +1,7 @@ package events import ( + "encoding/json" "testing" "time" @@ -323,6 +324,14 @@ func TestStateEvents(t *testing.T) { }, }, }, + { + ID: "activity-1", + Role: RoleActivity, + ActivityType: "PLAN", + ActivityContent: map[string]any{ + "status": "draft", + }, + }, } event := NewMessagesSnapshotEvent(messages) @@ -355,6 +364,78 @@ func TestStateEvents(t *testing.T) { } event.Messages = invalidMessages assert.Error(t, event.Validate()) + + invalidMessages = []Message{ + { + ID: "activity-1", + Role: RoleActivity, + // Missing activityType + ActivityContent: map[string]any{ + "status": "draft", + }, + }, + } + event.Messages = invalidMessages + assert.Error(t, event.Validate()) + + invalidMessages = []Message{ + { + ID: "activity-1", + Role: RoleActivity, + ActivityType: "PLAN", + // Missing content + }, + } + event.Messages = invalidMessages + assert.Error(t, event.Validate()) + }) +} + +func TestActivityEvents(t *testing.T) { + t.Run("ActivitySnapshotEvent", func(t *testing.T) { + content := map[string]any{ + "status": "draft", + } + + event := NewActivitySnapshotEvent("activity-1", "PLAN", content) + + assert.Equal(t, EventTypeActivitySnapshot, event.Type()) + assert.Equal(t, "activity-1", event.MessageID) + assert.Equal(t, "PLAN", event.ActivityType) + assert.NotNil(t, event.Replace) + assert.True(t, *event.Replace) + assert.NoError(t, event.Validate()) + + event.MessageID = "" + assert.Error(t, event.Validate()) + + event.MessageID = "activity-1" + event.ActivityType = "" + assert.Error(t, event.Validate()) + + event.ActivityType = "PLAN" + event.Content = nil + assert.Error(t, event.Validate()) + }) + + t.Run("ActivityDeltaEvent", func(t *testing.T) { + patch := []JSONPatchOperation{ + {Op: "replace", Path: "/status", Value: "done"}, + } + + event := NewActivityDeltaEvent("activity-1", "PLAN", patch) + + assert.Equal(t, EventTypeActivityDelta, event.Type()) + assert.Equal(t, "activity-1", event.MessageID) + assert.Equal(t, "PLAN", event.ActivityType) + assert.Len(t, event.Patch, 1) + assert.NoError(t, event.Validate()) + + event.Patch = []JSONPatchOperation{} + assert.Error(t, event.Validate()) + + event.Patch = []JSONPatchOperation{{Op: "replace", Path: ""}} + assert.Error(t, event.Validate()) }) } @@ -392,6 +473,50 @@ func TestCustomEvents(t *testing.T) { }) } +func TestMessageSerialization(t *testing.T) { + t.Run("MarshalAndUnmarshal_TextMessage", func(t *testing.T) { + msg := Message{ + ID: "msg-1", + Role: "user", + Content: strPtr("hello"), + } + + data, err := json.Marshal(msg) + require.NoError(t, err) + + var decoded Message + require.NoError(t, json.Unmarshal(data, &decoded)) + + assert.Equal(t, "msg-1", decoded.ID) + assert.Equal(t, "user", decoded.Role) + require.NotNil(t, decoded.Content) + assert.Equal(t, "hello", *decoded.Content) + assert.Nil(t, decoded.ActivityContent) + }) + + t.Run("MarshalAndUnmarshal_ActivityMessage", func(t *testing.T) { + msg := Message{ + ID: "activity-1", + Role: "activity", + ActivityType: "PLAN", + ActivityContent: map[string]any{"status": "draft"}, + } + + data, err := json.Marshal(msg) + require.NoError(t, err) + + var decoded Message + require.NoError(t, json.Unmarshal(data, &decoded)) + + assert.Equal(t, "activity-1", decoded.ID) + assert.Equal(t, "activity", decoded.Role) + assert.Equal(t, "PLAN", decoded.ActivityType) + require.Nil(t, decoded.Content) + require.NotNil(t, decoded.ActivityContent) + assert.Equal(t, "draft", decoded.ActivityContent["status"]) + }) +} + func TestEventSequenceValidation(t *testing.T) { t.Run("ValidSequence", func(t *testing.T) { events := []Event{ @@ -479,6 +604,8 @@ func TestJSONSerialization(t *testing.T) { NewTextMessageContentEvent("msg-1", "Hello"), NewToolCallStartEvent("tool-1", "get_weather", WithParentMessageID("msg-1")), NewStateSnapshotEvent(map[string]any{"counter": 42}), + NewActivitySnapshotEvent("activity-1", "PLAN", map[string]any{"status": "draft"}), + NewActivityDeltaEvent("activity-1", "PLAN", []JSONPatchOperation{{Op: "replace", Path: "/status", Value: "done"}}), NewCustomEvent("test-event", WithValue("test-value")), } diff --git a/sdks/community/go/pkg/core/events/id_utils_test.go b/sdks/community/go/pkg/core/events/id_utils_test.go index 5f033e9a1..529019d47 100644 --- a/sdks/community/go/pkg/core/events/id_utils_test.go +++ b/sdks/community/go/pkg/core/events/id_utils_test.go @@ -101,7 +101,7 @@ func TestTimestampIDGenerator(t *testing.T) { require.GreaterOrEqual(t, len(parts), 3) timestamp := parts[1] _, err := time.Parse("", timestamp) // Just check it's a number - assert.NotNil(t, err) // We expect an error because timestamp is just a number + assert.NotNil(t, err) // We expect an error because timestamp is just a number // Test uniqueness id2 := gen.GenerateRunID() @@ -172,19 +172,19 @@ func TestTimestampIDGenerator(t *testing.T) { t.Run("Timestamp_Ordering", func(t *testing.T) { gen := NewTimestampIDGenerator("") - + // Generate IDs with slight delay id1 := gen.GenerateRunID() time.Sleep(2 * time.Millisecond) id2 := gen.GenerateRunID() - + // Extract timestamps parts1 := strings.Split(id1, "-") parts2 := strings.Split(id2, "-") - + require.GreaterOrEqual(t, len(parts1), 3) require.GreaterOrEqual(t, len(parts2), 3) - + // The timestamp in id2 should be >= timestamp in id1 // (We can't parse them as ints here but the string comparison should work for ordering) assert.True(t, parts2[1] >= parts1[1]) @@ -195,7 +195,7 @@ func TestGlobalIDGenerator(t *testing.T) { t.Run("GetDefaultIDGenerator", func(t *testing.T) { gen := GetDefaultIDGenerator() assert.NotNil(t, gen) - + // Should be a DefaultIDGenerator by default _, ok := gen.(*DefaultIDGenerator) assert.True(t, ok) @@ -310,4 +310,4 @@ func TestGlobalIDGenerator(t *testing.T) { assert.Equal(t, 100, len(ids)) }) -} \ No newline at end of file +} diff --git a/sdks/community/go/pkg/core/events/state_events.go b/sdks/community/go/pkg/core/events/state_events.go index 1f317794b..1740f4fdc 100644 --- a/sdks/community/go/pkg/core/events/state_events.go +++ b/sdks/community/go/pkg/core/events/state_events.go @@ -15,6 +15,9 @@ var validJSONPatchOps = map[string]bool{ "test": true, } +// RoleActivity is the role for activity messages +const RoleActivity = "activity" + // StateSnapshotEvent contains a complete snapshot of the state type StateSnapshotEvent struct { *BaseEvent @@ -121,12 +124,96 @@ func (e *StateDeltaEvent) ToJSON() ([]byte, error) { // Message represents a message in the conversation type Message struct { - ID string `json:"id"` - Role string `json:"role"` - Content *string `json:"content,omitempty"` - Name *string `json:"name,omitempty"` - ToolCalls []ToolCall `json:"toolCalls,omitempty"` - ToolCallID *string `json:"toolCallId,omitempty"` + ID string `json:"id"` + Role string `json:"role"` + Content *string `json:"-"` + ActivityContent map[string]any `json:"-"` + Name *string `json:"name,omitempty"` + ToolCalls []ToolCall `json:"toolCalls,omitempty"` + ToolCallID *string `json:"toolCallId,omitempty"` + ActivityType string `json:"activityType,omitempty"` +} + +// MarshalJSON ensures content is correctly serialized for text and activity messages. +func (m Message) MarshalJSON() ([]byte, error) { + payload := map[string]any{ + "id": m.ID, + "role": m.Role, + } + + if m.Name != nil { + payload["name"] = *m.Name + } + if len(m.ToolCalls) > 0 { + payload["toolCalls"] = m.ToolCalls + } + if m.ToolCallID != nil { + payload["toolCallId"] = *m.ToolCallID + } + + if m.Role == RoleActivity { + if m.ActivityType != "" { + payload["activityType"] = m.ActivityType + } + if m.ActivityContent != nil { + payload["content"] = m.ActivityContent + } + } else if m.Content != nil { + payload["content"] = *m.Content + } + + return json.Marshal(payload) +} + +// UnmarshalJSON hydrates content into the appropriate field depending on role. +func (m *Message) UnmarshalJSON(data []byte) error { + var raw struct { + ID string `json:"id"` + Role string `json:"role"` + Name *string `json:"name,omitempty"` + Content json.RawMessage `json:"content,omitempty"` + ToolCalls []ToolCall `json:"toolCalls,omitempty"` + ToolCallID *string `json:"toolCallId,omitempty"` + ActivityType string `json:"activityType,omitempty"` + } + + if err := json.Unmarshal(data, &raw); err != nil { + return err + } + + m.ID = raw.ID + m.Role = raw.Role + m.Name = raw.Name + m.ToolCalls = raw.ToolCalls + m.ToolCallID = raw.ToolCallID + m.ActivityType = raw.ActivityType + + if raw.Role == RoleActivity { + if len(raw.Content) > 0 { + var content map[string]any + if err := json.Unmarshal(raw.Content, &content); err != nil { + return fmt.Errorf("failed to decode activity content: %w", err) + } + m.ActivityContent = content + } else { + m.ActivityContent = nil + } + m.Content = nil + } else { + if len(raw.Content) > 0 { + var text string + if err := json.Unmarshal(raw.Content, &text); err != nil { + return fmt.Errorf("failed to decode message content: %w", err) + } + m.Content = &text + } else { + m.Content = nil + } + m.ActivityContent = nil + m.ActivityType = "" + } + + return nil } // ToolCall represents a tool call within a message @@ -182,6 +269,22 @@ func validateMessage(msg Message) error { return fmt.Errorf("message role field is required") } + if msg.Role == RoleActivity { + if msg.ActivityType == "" { + return fmt.Errorf("activityType field is required for activity messages") + } + if msg.ActivityContent == nil { + return fmt.Errorf("content field is required for activity messages") + } + } else { + if msg.ActivityContent != nil { + return fmt.Errorf("activity content is only valid for activity messages") + } + if msg.ActivityType != "" { + return fmt.Errorf("activityType is only valid for activity messages") + } + } + // Validate tool calls if present for i, toolCall := range msg.ToolCalls { if err := validateToolCall(toolCall); err != nil { diff --git a/sdks/community/go/pkg/core/events/state_events_test.go b/sdks/community/go/pkg/core/events/state_events_test.go new file mode 100644 index 000000000..f36f99597 --- /dev/null +++ b/sdks/community/go/pkg/core/events/state_events_test.go @@ -0,0 +1,207 @@ +package events + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestMessageMarshalUnmarshal_Text(t *testing.T) { + msg := Message{ + ID: "msg-1", + Role: "user", + Content: strPtr("hello"), + } + + data, err := json.Marshal(msg) + require.NoError(t, err) + + var decoded Message + require.NoError(t, json.Unmarshal(data, &decoded)) + + assert.Equal(t, "msg-1", decoded.ID) + assert.Equal(t, "user", decoded.Role) + require.NotNil(t, decoded.Content) + assert.Equal(t, "hello", *decoded.Content) + assert.Nil(t, decoded.ActivityContent) + assert.Empty(t, decoded.ActivityType) +} + +func TestMessageMarshalUnmarshal_Activity(t *testing.T) { + msg := Message{ + ID: "activity-1", + Role: RoleActivity, + ActivityType: "PLAN", + ActivityContent: map[string]any{"status": "working"}, + } + + data, err := json.Marshal(msg) + require.NoError(t, err) + + var decoded Message + require.NoError(t, json.Unmarshal(data, &decoded)) + + assert.Equal(t, "activity-1", decoded.ID) + assert.Equal(t, "activity", decoded.Role) + assert.Equal(t, "PLAN", decoded.ActivityType) + require.Nil(t, decoded.Content) + require.NotNil(t, decoded.ActivityContent) + assert.Equal(t, "working", decoded.ActivityContent["status"]) +} + +func TestValidateMessage_NonActivityRejectsActivityFields(t *testing.T) { + msg := Message{ + ID: "msg-1", + Role: "user", + Content: strPtr("hello"), + ActivityType: "PLAN", + ActivityContent: map[string]any{"status": "draft"}, + } + + err := validateMessage(msg) + assert.Error(t, err) +} + +func TestValidateMessage_ActivityRequiresFields(t *testing.T) { + msg := Message{ + ID: "activity-1", + Role: RoleActivity, + } + + err := validateMessage(msg) + assert.Error(t, err) + + msg.ActivityType = "PLAN" + err = validateMessage(msg) + assert.Error(t, err) + + msg.ActivityContent = map[string]any{"status": "draft"} + err = validateMessage(msg) + assert.NoError(t, err) + + msg = Message{ + ID: "msg-1", + Role: "user", + Content: strPtr("hello"), + ActivityType: "PLAN", + ActivityContent: map[string]any{"status": "oops"}, + } + err = validateMessage(msg) + assert.Error(t, err) +} + +func TestMessageMarshalJSON_IncludesOptionalFields(t *testing.T) { + name := "bob" + toolCallID := "tool-123" + msg := Message{ + ID: "msg-1", + Role: "assistant", + Content: strPtr("hello"), + Name: &name, + ToolCalls: []ToolCall{ + { + ID: "tool-1", + Type: "function", + Function: Function{ + Name: "f", + Arguments: "{}", + }, + }, + }, + ToolCallID: &toolCallID, + } + + data, err := json.Marshal(msg) + require.NoError(t, err) + + var decoded map[string]any + require.NoError(t, json.Unmarshal(data, &decoded)) + + assert.Equal(t, "msg-1", decoded["id"]) + assert.Equal(t, "assistant", decoded["role"]) + assert.Equal(t, "hello", decoded["content"]) + assert.Equal(t, "bob", decoded["name"]) + assert.Equal(t, "tool-123", decoded["toolCallId"]) + toolCalls, ok := decoded["toolCalls"].([]any) + require.True(t, ok) + assert.Len(t, toolCalls, 1) +} + +func TestMessageMarshalJSON_ActivityPrefersActivityContent(t *testing.T) { + msg := Message{ + ID: "activity-1", + Role: "activity", + Content: strPtr("should-be-ignored"), + ActivityType: "PLAN", + ActivityContent: map[string]any{"status": "draft"}, + } + + data, err := json.Marshal(msg) + require.NoError(t, err) + + var decoded map[string]any + require.NoError(t, json.Unmarshal(data, &decoded)) + + assert.Equal(t, "activity-1", decoded["id"]) + assert.Equal(t, "activity", decoded["role"]) + assert.Equal(t, "PLAN", decoded["activityType"]) + content, ok := decoded["content"].(map[string]any) + require.True(t, ok) + assert.Equal(t, "draft", content["status"]) +} + +func TestMessageUnmarshalJSON_InvalidTextContent(t *testing.T) { + payload := []byte(`{"id":"msg-1","role":"user","content":123}`) + var msg Message + err := json.Unmarshal(payload, &msg) + assert.Error(t, err) +} + +func TestMessageUnmarshalJSON_InvalidActivityContent(t *testing.T) { + payload := []byte(`{"id":"activity-1","role":"activity","activityType":"PLAN","content":"not-an-object"}`) + var msg Message + err := json.Unmarshal(payload, &msg) + assert.Error(t, err) +} + +func TestMessageUnmarshalJSON_ResetsActivityFieldsForText(t *testing.T) { + payload := []byte(`{"id":"msg-1","role":"user","activityType":"PLAN","content":"hello"}`) + var msg Message + err := json.Unmarshal(payload, &msg) + require.NoError(t, err) + + assert.Equal(t, "msg-1", msg.ID) + assert.Equal(t, "user", msg.Role) + require.NotNil(t, msg.Content) + assert.Equal(t, "hello", *msg.Content) + assert.Empty(t, msg.ActivityType) + assert.Nil(t, msg.ActivityContent) +} + +func TestMessageUnmarshalJSON_TextWithNoContent(t *testing.T) { + payload := []byte(`{"id":"msg-1","role":"user"}`) + var msg Message + err := json.Unmarshal(payload, &msg) + require.NoError(t, err) + + assert.Equal(t, "msg-1", msg.ID) + assert.Equal(t, "user", msg.Role) + assert.Nil(t, msg.Content) + assert.Nil(t, msg.ActivityContent) + assert.Empty(t, msg.ActivityType) +} + +func TestMessageUnmarshalJSON_ActivityWithNoContent(t *testing.T) { + payload := []byte(`{"id":"activity-1","role":"activity","activityType":"PLAN"}`) + var msg Message + err := json.Unmarshal(payload, &msg) + require.NoError(t, err) + + assert.Equal(t, "activity-1", msg.ID) + assert.Equal(t, "activity", msg.Role) + assert.Equal(t, "PLAN", msg.ActivityType) + assert.Nil(t, msg.Content) + assert.Nil(t, msg.ActivityContent) +}