Skip to content

Commit c8b02fc

Browse files
authored
Merge pull request #387 from JaredConover/fix/responses-api-type-parsing
2 parents 2e7849e + 381c0f1 commit c8b02fc

File tree

3 files changed

+103
-13
lines changed

3 files changed

+103
-13
lines changed

Sources/OpenAI/Private/Streaming/ModelResponseEventsStreamInterpreter.swift

Lines changed: 44 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -55,25 +55,24 @@ final class ModelResponseEventsStreamInterpreter: @unchecked Sendable, StreamInt
5555
}
5656

5757
private func processEvent(_ event: ServerSentEventsStreamParser.Event) throws {
58-
var finalEvent = event
59-
if event.eventType == "response.output_text.annotation.added" {
60-
// Remove when they have fixed (unified)!
61-
//
62-
// By looking at [API Reference](https://platform.openai.com/docs/api-reference/responses-streaming/response/output_text_annotation/added)
63-
// and generated type `Schemas.ResponseOutputTextAnnotationAddedEvent`
64-
// We can see that "output_text.annotation" is incorrect, whereas output_text_annotation is the correct one
65-
let fixedDataString = event.decodedData.replacingOccurrences(of: "response.output_text.annotation.added", with: "response.output_text_annotation.added")
66-
finalEvent = .init(id: event.id, data: fixedDataString.data(using: .utf8) ?? event.data, decodedData: fixedDataString, eventType: "response.output_text_annotation.added", retry: event.retry)
58+
let finalEvent = event.fixMappingError()
59+
var eventType = finalEvent.eventType
60+
61+
/// If the SSE `event` property is not specified by the provider service, our parser defaults it to "message" which is not a valid model response type.
62+
/// In this case we check the `data.type` property for a valid model response type.
63+
if eventType == "message" || eventType.isEmpty,
64+
let payloadEventType = finalEvent.getPayloadType() {
65+
eventType = payloadEventType
6766
}
68-
69-
guard let modelResponseEventType = ModelResponseStreamEventType(rawValue: finalEvent.eventType) else {
70-
throw InterpreterError.unknownEventType(finalEvent.eventType)
67+
68+
guard let modelResponseEventType = ModelResponseStreamEventType(rawValue: eventType) else {
69+
throw InterpreterError.unknownEventType(eventType)
7170
}
7271

7372
let responseStreamEvent = try responseStreamEvent(modelResponseEventType: modelResponseEventType, data: finalEvent.data)
7473
onEventDispatched?(responseStreamEvent)
7574
}
76-
75+
7776
private func processError(_ error: Error) {
7877
onError?(error)
7978
}
@@ -210,3 +209,35 @@ final class ModelResponseEventsStreamInterpreter: @unchecked Sendable, StreamInt
210209
try decoder.decode(T.self, from: data)
211210
}
212211
}
212+
213+
private extension ServerSentEventsStreamParser.Event {
214+
215+
// Remove when they have fixed (unified)!
216+
//
217+
// By looking at [API Reference](https://platform.openai.com/docs/api-reference/responses-streaming/response/output_text_annotation/added)
218+
// and generated type `Schemas.ResponseOutputTextAnnotationAddedEvent`
219+
// We can see that "output_text.annotation" is incorrect, whereas output_text_annotation is the correct one
220+
func fixMappingError() -> Self {
221+
let incorrectEventType = "response.output_text.annotation.added"
222+
let correctEventType = "response.output_text_annotation.added"
223+
224+
guard self.eventType == incorrectEventType || self.getPayloadType() == incorrectEventType else {
225+
return self
226+
}
227+
228+
let fixedDataString = self.decodedData.replacingOccurrences(of: incorrectEventType, with: correctEventType)
229+
return .init(
230+
id: self.id,
231+
data: fixedDataString.data(using: .utf8) ?? self.data,
232+
decodedData: fixedDataString,
233+
eventType: correctEventType,
234+
retry: self.retry
235+
)
236+
}
237+
238+
struct TypeEnvelope: Decodable { let type: String }
239+
240+
func getPayloadType() -> String? {
241+
try? JSONDecoder().decode(TypeEnvelope.self, from: self.data).type
242+
}
243+
}

Tests/OpenAITests/MockServerSentEvent.swift

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,4 +20,18 @@ struct MockServerSentEvent {
2020
static func chatCompletionError() -> Data {
2121
"{\n \"error\": {\n \"message\": \"The model `o3-mini` does not exist or you do not have access to it.\",\n \"type\": \"invalid_request_error\",\n \"param\": null,\n \"code\": \"model_not_found\"\n }\n}\n".data(using: .utf8)!
2222
}
23+
24+
static func responseStreamEvent(
25+
itemId: String = "msg_1",
26+
payloadType: String,
27+
outputIndex: Int = 0,
28+
contentIndex: Int = 0,
29+
delta: String = "",
30+
sequenceNumber: Int = 1
31+
) -> Data {
32+
let json = """
33+
{"type":"\(payloadType)","output_index":\(outputIndex),"item_id":"\(itemId)","content_index":\(contentIndex),"delta":"\(delta)","sequence_number":\(sequenceNumber)}
34+
"""
35+
return "data: \(json)\n\n".data(using: .utf8)!
36+
}
2337
}

Tests/OpenAITests/ModelResponseEventsStreamInterpreterTests.swift

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,4 +39,49 @@ final class ModelResponseEventsStreamInterpreterTests: XCTestCase {
3939
XCTAssertNotNil(receivedError, "Expected an error to be received, but got nil.")
4040
XCTAssertTrue(receivedError is APIErrorResponse, "Expected received error to be of type APIErrorResponse.")
4141
}
42+
43+
func testParsesOutputTextDeltaUsingPayloadType() async throws {
44+
let expectation = XCTestExpectation(description: "OutputText delta event received")
45+
var receivedEvent: ResponseStreamEvent?
46+
47+
interpreter.setCallbackClosures { event in
48+
Task {
49+
await MainActor.run {
50+
receivedEvent = event
51+
expectation.fulfill()
52+
}
53+
}
54+
} onError: { error in
55+
XCTFail("Unexpected error received: \(error)")
56+
}
57+
58+
interpreter.processData(
59+
MockServerSentEvent.responseStreamEvent(
60+
itemId: "msg_1",
61+
payloadType: "response.output_text.delta",
62+
outputIndex: 0,
63+
contentIndex: 0,
64+
delta: "Hi",
65+
sequenceNumber: 1
66+
)
67+
)
68+
69+
await fulfillment(of: [expectation], timeout: 1.0)
70+
71+
guard let receivedEvent else {
72+
XCTFail("No event received")
73+
return
74+
}
75+
76+
switch receivedEvent {
77+
case .outputText(.delta(let deltaEvent)):
78+
XCTAssertEqual(deltaEvent.itemId, "msg_1")
79+
XCTAssertEqual(deltaEvent.outputIndex, 0)
80+
XCTAssertEqual(deltaEvent.contentIndex, 0)
81+
XCTAssertEqual(deltaEvent.delta, "Hi")
82+
XCTAssertEqual(deltaEvent.sequenceNumber, 1)
83+
default:
84+
XCTFail("Expected .outputText(.delta), got \(receivedEvent)")
85+
}
86+
}
4287
}

0 commit comments

Comments
 (0)