diff --git a/firebase-ai/CHANGELOG.md b/firebase-ai/CHANGELOG.md
index 053febea25a..abf0bf55c68 100644
--- a/firebase-ai/CHANGELOG.md
+++ b/firebase-ai/CHANGELOG.md
@@ -1,5 +1,7 @@
 # Unreleased
+- [changed] Improved audio thread scheduling and increased playback volume for the Live API.
+- [changed] Added support for input and output transcription. (#7482)
 - [feature] Added support for sending realtime audio and video in a `LiveSession`.
 - [changed] Removed redundant internal exception types. (#7475)
diff --git a/firebase-ai/src/main/kotlin/com/google/firebase/ai/common/util/android.kt b/firebase-ai/src/main/kotlin/com/google/firebase/ai/common/util/android.kt
index 6179c8b52e9..9f1bbd37260 100644
--- a/firebase-ai/src/main/kotlin/com/google/firebase/ai/common/util/android.kt
+++ b/firebase-ai/src/main/kotlin/com/google/firebase/ai/common/util/android.kt
@@ -17,10 +17,8 @@
 package com.google.firebase.ai.common.util
 
 import android.media.AudioRecord
-import kotlin.time.Duration.Companion.milliseconds
 import kotlinx.coroutines.delay
 import kotlinx.coroutines.flow.flow
-import kotlinx.coroutines.yield
 
 /**
  * The minimum buffer size for this instance.
@@ -40,15 +38,17 @@ internal fun AudioRecord.readAsFlow() = flow {
   while (true) {
     if (recordingState != AudioRecord.RECORDSTATE_RECORDING) {
-      // TODO(vguthal): Investigate if both yield and delay are required.
-      delay(10.milliseconds)
-      yield()
+      // delay uses a different scheduler under the hood, so it's "stickier" in its enforcement
+      // when compared to yield.
+      delay(0)
       continue
     }
     val bytesRead = read(buffer, 0, buffer.size)
     if (bytesRead > 0) {
       emit(buffer.copyOf(bytesRead))
     }
-    yield()
+    // delay uses a different scheduler under the hood, so it's "stickier" in its enforcement
+    // when compared to yield.
+    delay(0)
   }
 }
diff --git a/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/AudioHelper.kt b/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/AudioHelper.kt
index 08e90fc8538..62fbb740e77 100644
--- a/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/AudioHelper.kt
+++ b/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/AudioHelper.kt
@@ -159,10 +159,13 @@ internal class AudioHelper(
    * constructor.
    */
   @RequiresPermission(Manifest.permission.RECORD_AUDIO)
-  fun build(): AudioHelper {
+  fun build(audioHandler: ((AudioRecord, AudioTrack) -> Unit)? = null): AudioHelper {
     val playbackTrack =
       AudioTrack(
-        AudioAttributes.Builder().setUsage(AudioAttributes.USAGE_VOICE_COMMUNICATION).build(),
+        AudioAttributes.Builder()
+          .setUsage(AudioAttributes.USAGE_MEDIA)
+          .setContentType(AudioAttributes.CONTENT_TYPE_SPEECH)
+          .build(),
         AudioFormat.Builder()
           .setSampleRate(24000)
           .setChannelMask(AudioFormat.CHANNEL_OUT_MONO)
@@ -176,7 +179,6 @@ internal class AudioHelper(
         AudioTrack.MODE_STREAM,
         AudioManager.AUDIO_SESSION_ID_GENERATE
       )
-
     val bufferSize =
       AudioRecord.getMinBufferSize(
         16000,
@@ -205,6 +207,9 @@ internal class AudioHelper(
     if (AcousticEchoCanceler.isAvailable()) {
       AcousticEchoCanceler.create(recorder.audioSessionId)?.enabled = true
     }
+    if (audioHandler != null) {
+      audioHandler(recorder, playbackTrack)
+    }
 
     return AudioHelper(recorder, playbackTrack)
   }
diff --git a/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/LiveSession.kt b/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/LiveSession.kt
index 9e8b7d7f683..1521d6dbcaf 100644
--- a/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/LiveSession.kt
+++ b/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/LiveSession.kt
@@ -17,12 +17,18 @@
 package com.google.firebase.ai.type
 
 import android.Manifest.permission.RECORD_AUDIO
+import android.annotation.SuppressLint
 import android.content.pm.PackageManager
 import android.media.AudioFormat
+import android.media.AudioRecord
 import android.media.AudioTrack
+import android.os.Process
+import android.os.StrictMode
+import android.os.StrictMode.ThreadPolicy
 import android.util.Log
 import androidx.annotation.RequiresPermission
 import androidx.core.content.ContextCompat
+import com.google.firebase.BuildConfig
 import com.google.firebase.FirebaseApp
 import com.google.firebase.ai.common.JSON
 import com.google.firebase.ai.common.util.CancelledCoroutineScope
@@ -34,21 +40,27 @@
 import io.ktor.websocket.Frame
 import io.ktor.websocket.close
 import io.ktor.websocket.readBytes
 import java.util.concurrent.ConcurrentLinkedQueue
+import java.util.concurrent.Executors
+import java.util.concurrent.ThreadFactory
 import java.util.concurrent.atomic.AtomicBoolean
+import java.util.concurrent.atomic.AtomicLong
 import kotlin.coroutines.CoroutineContext
+import kotlinx.coroutines.CoroutineName
 import kotlinx.coroutines.CoroutineScope
+import kotlinx.coroutines.asCoroutineDispatcher
 import kotlinx.coroutines.cancel
 import kotlinx.coroutines.channels.Channel.Factory.UNLIMITED
+import kotlinx.coroutines.delay
 import kotlinx.coroutines.flow.Flow
 import kotlinx.coroutines.flow.buffer
 import kotlinx.coroutines.flow.catch
 import kotlinx.coroutines.flow.flow
+import kotlinx.coroutines.flow.flowOn
 import kotlinx.coroutines.flow.launchIn
 import kotlinx.coroutines.flow.onCompletion
 import kotlinx.coroutines.flow.onEach
 import kotlinx.coroutines.isActive
 import kotlinx.coroutines.launch
-import kotlinx.coroutines.yield
 import kotlinx.serialization.ExperimentalSerializationApi
 import kotlinx.serialization.Serializable
 import kotlinx.serialization.encodeToString
@@ -65,11 +77,21 @@ internal constructor(
   private val firebaseApp: FirebaseApp,
 ) {
   /**
-   * Coroutine scope that we batch data on for [startAudioConversation].
+   * Coroutine scope that we batch network-related work on.
    *
    * Makes it easy to stop all the work with [stopAudioConversation] by just cancelling the scope.
    */
-  private var scope = CancelledCoroutineScope
+  private var networkScope = CancelledCoroutineScope
+
+  /**
+   * Coroutine scope that we batch audio recording and playback work on.
+   *
+   * Kept separate from [networkScope] so that switching between dispatchers cannot deadlock
+   * audio work against network work.
+   *
+   * Makes it easy to stop all the work with [stopAudioConversation] by just cancelling the scope.
+   */
+  private var audioScope = CancelledCoroutineScope
 
   /**
    * Playback audio data sent from the model.
@@ -145,9 +167,47 @@
    * consistently available.
    */
   @RequiresPermission(RECORD_AUDIO)
+  public suspend fun startAudioConversation(
+    functionCallHandler: ((FunctionCallPart) -> FunctionResponsePart)? = null,
+    transcriptHandler: ((Transcription?, Transcription?) -> Unit)?,
+    enableInterruptions: Boolean = false,
+  ) {
+    startAudioConversation(
+      functionCallHandler = functionCallHandler,
+      transcriptHandler = transcriptHandler,
+      audioHandler = null,
+      enableInterruptions = enableInterruptions
+    )
+  }
+
+  /**
+   * Starts an audio conversation with the model, which can only be stopped using
+   * [stopAudioConversation] or [close].
+   *
+   * @param functionCallHandler A callback function that is invoked whenever a function call is
+   * received from the model. The [FunctionResponsePart] that the callback function returns will
+   * be automatically sent to the model.
+   *
+   * @param transcriptHandler A callback function that is invoked whenever a transcription is
+   * received. The first [Transcription] object is the input transcription, and the second is the
+   * output transcription.
+   *
+   * @param audioHandler A callback function that is invoked once the associated [AudioRecord]
+   * and [AudioTrack] objects are successfully initialized. This offers a final opportunity to
+   * apply custom configuration to either object; the configuration remains in effect for the
+   * duration of the current audio session.
+   *
+   * @param enableInterruptions If enabled, allows the user to speak over or interrupt the
+   * model's ongoing reply.
+   *
+   * **WARNING**: The user interruption feature relies on device-specific support, and may not be
+   * consistently available.
+   */
+  @RequiresPermission(RECORD_AUDIO)
   public suspend fun startAudioConversation(
     functionCallHandler: ((FunctionCallPart) -> FunctionResponsePart)? = null,
     transcriptHandler: ((Transcription?, Transcription?) -> Unit)? = null,
+    audioHandler: ((AudioRecord, AudioTrack) -> Unit)? = null,
     enableInterruptions: Boolean = false,
   ) {
@@ -159,7 +219,7 @@
     }
 
     FirebaseAIException.catchAsync {
-      if (scope.isActive) {
+      if (networkScope.isActive || audioScope.isActive) {
         Log.w(
           TAG,
           "startAudioConversation called after the recording has already started. " +
" + @@ -167,9 +227,10 @@ internal constructor( ) return@catchAsync } - - scope = CoroutineScope(blockingDispatcher + childJob()) - audioHelper = AudioHelper.build() + networkScope = + CoroutineScope(blockingDispatcher + childJob() + CoroutineName("LiveSession Network")) + audioScope = CoroutineScope(audioDispatcher + childJob() + CoroutineName("LiveSession Audio")) + audioHelper = AudioHelper.build(audioHandler) recordUserAudio() processModelResponses(functionCallHandler, transcriptHandler) @@ -188,7 +249,8 @@ internal constructor( FirebaseAIException.catch { if (!startedReceiving.getAndSet(false)) return@catch - scope.cancel() + networkScope.cancel() + audioScope.cancel() playBackQueue.clear() audioHelper?.release() @@ -231,7 +293,9 @@ internal constructor( ) } ?.let { emit(it.toPublic()) } - yield() + // delay uses a different scheduler in the backend, so it's "stickier" in its + // enforcement when compared to yield. + delay(0) } } .onCompletion { stopAudioConversation() } @@ -258,7 +322,8 @@ internal constructor( FirebaseAIException.catch { if (!startedReceiving.getAndSet(false)) return@catch - scope.cancel() + networkScope.cancel() + audioScope.cancel() playBackQueue.clear() audioHelper?.release() @@ -403,10 +468,16 @@ internal constructor( audioHelper ?.listenToRecording() ?.buffer(UNLIMITED) + ?.flowOn(audioDispatcher) ?.accumulateUntil(MIN_BUFFER_SIZE) - ?.onEach { sendAudioRealtime(InlineData(it, "audio/pcm")) } + ?.onEach { + sendAudioRealtime(InlineData(it, "audio/pcm")) + // delay uses a different scheduler in the backend, so it's "stickier" in its enforcement + // when compared to yield. + delay(0) + } ?.catch { throw FirebaseAIException.from(it) } - ?.launchIn(scope) + ?.launchIn(networkScope) } /** @@ -414,7 +485,7 @@ internal constructor( * * Audio messages are added to [playBackQueue]. * - * Launched asynchronously on [scope]. + * Launched asynchronously on [networkScope]. * * @param functionCallHandler A callback function that is invoked whenever the server receives a * function call. @@ -471,7 +542,7 @@ internal constructor( } } } - .launchIn(scope) + .launchIn(networkScope) } /** @@ -479,10 +550,10 @@ internal constructor( * * Polls [playBackQueue] for data, and calls [AudioHelper.playAudio] when data is received. * - * Launched asynchronously on [scope]. + * Launched asynchronously on [networkScope]. */ private fun listenForModelPlayback(enableInterruptions: Boolean = false) { - scope.launch { + audioScope.launch { while (isActive) { val playbackData = playBackQueue.poll() if (playbackData == null) { @@ -491,14 +562,16 @@ internal constructor( if (!enableInterruptions) { audioHelper?.resumeRecording() } - yield() + // delay uses a different scheduler in the backend, so it's "stickier" in its enforcement + // when compared to yield. 
+          delay(0)
         } else {
           /**
            * We pause the recording while the model is speaking to avoid interrupting it because of
            * no echo cancellation
            */
           // TODO(b/408223520): Conditionally pause when param is added
-          if (enableInterruptions != true) {
+          if (!enableInterruptions) {
             audioHelper?.pauseRecording()
           }
           audioHelper?.playAudio(playbackData)
@@ -583,5 +656,38 @@
         AudioFormat.CHANNEL_OUT_MONO,
         AudioFormat.ENCODING_PCM_16BIT
       )
+
+    @SuppressLint("ThreadPoolCreation")
+    val audioDispatcher =
+      Executors.newCachedThreadPool(AudioThreadFactory()).asCoroutineDispatcher()
   }
 }
+
+internal class AudioThreadFactory : ThreadFactory {
+  private val threadCount = AtomicLong()
+  private val policy: ThreadPolicy = audioPolicy()
+
+  override fun newThread(task: Runnable?): Thread? {
+    val thread =
+      DEFAULT.newThread {
+        Process.setThreadPriority(Process.THREAD_PRIORITY_AUDIO)
+        StrictMode.setThreadPolicy(policy)
+        task?.run()
+      }
+    thread.name = "Firebase Audio Thread #${threadCount.andIncrement}"
+    return thread
+  }
+
+  companion object {
+    val DEFAULT: ThreadFactory = Executors.defaultThreadFactory()
+
+    private fun audioPolicy(): ThreadPolicy {
+      val builder = ThreadPolicy.Builder().detectNetwork()
+
+      if (BuildConfig.DEBUG) {
+        builder.penaltyDeath()
+      }
+
+      return builder.penaltyLog().build()
+    }
+  }
+}
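
Reviewer note: a minimal caller-side sketch of the new surface, not part of the diff itself. It assumes an already-connected `LiveSession` (e.g. obtained from `LiveGenerativeModel.connect()` elsewhere) and that `Transcription` exposes a nullable `text` field; the function name and log tag are illustrative only.

import android.Manifest
import android.media.AudioRecord
import android.media.AudioTrack
import android.util.Log
import androidx.annotation.RequiresPermission
import com.google.firebase.ai.type.LiveSession

@RequiresPermission(Manifest.permission.RECORD_AUDIO)
suspend fun startTranscribedConversation(session: LiveSession) {
  session.startAudioConversation(
    transcriptHandler = { input, output ->
      // The first argument is the input (user) transcription, the second is the output (model)
      // transcription; either may be null when the server omits one.
      input?.text?.let { Log.d("LiveDemo", "user: $it") }
      output?.text?.let { Log.d("LiveDemo", "model: $it") }
    },
    audioHandler = { _: AudioRecord, track: AudioTrack ->
      // Invoked once, right after both objects are initialized; any configuration applied here
      // stays in effect for the rest of the audio session.
      track.setVolume(AudioTrack.getMaxVolume())
    },
  )
}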
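
A second sketch of what the `AudioThreadFactory` policy buys: every dispatcher thread runs at `THREAD_PRIORITY_AUDIO` with a `StrictMode` policy built from `detectNetwork()`, so any network I/O that accidentally lands on an audio thread is logged (and crashes debug builds via `penaltyDeath`) instead of silently starving playback. This probe is hypothetical and would have to live inside the module, since the factory is internal.

import java.net.Socket
import java.util.concurrent.Executors
import kotlinx.coroutines.asCoroutineDispatcher
import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.withContext

fun main() = runBlocking {
  val dispatcher = Executors.newCachedThreadPool(AudioThreadFactory()).asCoroutineDispatcher()
  withContext(dispatcher) {
    // detectNetwork() flags this connect() on the audio-priority thread and applies the
    // configured penalties (log in release, death in debug builds).
    Socket("example.com", 80).close()
  }
}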