@@ -17,6 +17,7 @@ limitations under the License.
1717package prefix
1818
1919import (
20+ "bytes"
2021 "context"
2122 "encoding/binary"
2223 "encoding/json"
@@ -28,6 +29,7 @@ import (
2829 k8stypes "k8s.io/apimachinery/pkg/types"
2930 "sigs.k8s.io/controller-runtime/pkg/log"
3031
32+ "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend"
3133 backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics"
3234 "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics"
3335 "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
@@ -114,6 +116,10 @@ var _ plugins.StateData = &SchedulingContextState{}
type SchedulingContextState struct {
	// PrefixHashes is a list of prefix hashes of the request prompt broken into blocks.
	PrefixHashes []BlockHash
	// RestBytes holds the trailing bytes of the prompt that did not fill a complete block and
	// therefore have no entry in PrefixHashes. When non-empty, they are prepended to the response
	// body so that the response is hashed as a continuation of the prompt. This matters mostly
	// in multi-turn scenarios.
	RestBytes []byte
	// PrefixCacheServers maps each server to its longest prefix cache match length.
	PrefixCacheServers map[ServerID]int
}
@@ -190,9 +196,10 @@ func (p *Plugin) WithName(name string) *Plugin {
190196// Score returns the scoring result for the given list of pods based on context.
191197func (p * Plugin ) Score (ctx context.Context , cycleState * types.CycleState , request * types.LLMRequest , pods []types.Pod ) map [types.Pod ]float64 {
192198 // pre score step, hashing prompt and find longest prefix match.
193- hashes := hashPrompt (ctx , request , p .config .BlockSize , p .config .MaxPrefixBlocksToMatch )
199+ hashes , restBytes := hashPrompt (ctx , request , p .config .BlockSize , p .config .MaxPrefixBlocksToMatch )
194200 state := & SchedulingContextState {
195201 PrefixHashes : hashes ,
202+ RestBytes : restBytes ,
196203 PrefixCacheServers : p .matchLongestPrefix (ctx , hashes ),
197204 }
198205
@@ -223,7 +230,6 @@ func (p *Plugin) PreRequest(ctx context.Context, request *types.LLMRequest, sche
223230 targetPod := primaryProfileResult .TargetPods [0 ].GetPod () // get the first pod of the primary profile
224231
225232 state , err := plugins .ReadPluginStateKey [* SchedulingContextState ](p .pluginState , request .RequestId , plugins .StateKey (p .TypedName ().String ()))
226- p .pluginState .Delete (request .RequestId ) // delete the state explicitly after completing using it
227233 if err != nil {
228234 log .FromContext (ctx ).Error (err , "failed to read prefix plugin state" , "requestID" , request .RequestId )
229235 return
@@ -296,47 +302,58 @@ func (m *Plugin) CleanUpInactivePods(ctx context.Context, handle plugins.Handle)
296302// hashPrompt divides the prompt into blocks and calculate the prefix cache for each block.
297303// hash[0] is calculated including the model name and cache_salt(if provided), since different models generally don't share prefix cache.
298304// For block i, hash(i) = hash(block i content, hash(i-1)).
299- func hashPrompt (ctx context.Context , request * types.LLMRequest , cacheBlockSize int , maxPrefixBlocks int ) []BlockHash {
305+ // Also return the extra string.
306+ func hashPrompt (ctx context.Context , request * types.LLMRequest , cacheBlockSize int , maxPrefixBlocks int ) ([]BlockHash , []byte ) {
300307 loggerDebug := log .FromContext (ctx ).V (logutil .DEBUG )
301308 if request == nil || request .Body == nil {
302309 loggerDebug .Info ("Request or request data is nil, skipping hashing" )
303- return nil
310+ return nil , nil
304311 }
305-
306312 userInput , err := getUserInputBytes (request )
307313 if err != nil {
308314 loggerDebug .Error (err , "Failed to get user input bytes" )
309- return nil
315+ return nil , nil
310316 }
317+ prevBlockHash := defaultPrevBlock (request )
318+ return hashInputWithPrevBlockHash (ctx , prevBlockHash , 0 , userInput , cacheBlockSize , maxPrefixBlocks )
319+ }
311320
312- if len (userInput ) < cacheBlockSize {
313- loggerDebug .Info ("Request body too small for prefix cache" , "size" , len (userInput ), "block size" , cacheBlockSize )
314- return nil
315- }
316- if len (userInput ) > cacheBlockSize * maxPrefixBlocks {
317- loggerDebug .Info ("Truncating input" , "size" , len (userInput ), "max prefix blocks" , maxPrefixBlocks , "block size" , cacheBlockSize )
318- userInput = userInput [:maxPrefixBlocks * cacheBlockSize ]
319- }
320- // Split the body into blocks of size cacheBlockSize.
321- // If the last block is smaller than cacheBlockSize, it will be ignored.
322- res := make ([]BlockHash , 0 , len (userInput )/ cacheBlockSize )
323- // Add the model to the first block hash so that different models have different hashes even with the same body.
321+ func defaultPrevBlock (request * types.LLMRequest ) BlockHash {
324322 h := xxhash .New ()
323+ // Add the model to the first block hash so that different models have different hashes even with the same body.
325324 _ , _ = h .Write ([]byte (request .TargetModel ))
326325 if cacheSalt := request .Body .CacheSalt (); cacheSalt != "" {
327326 _ , _ = h .Write ([]byte (cacheSalt ))
328327 }
329328
330- prevBlockHash := BlockHash (h .Sum64 ())
331- for i := 0 ; i + cacheBlockSize <= len (userInput ); i += cacheBlockSize {
329+ return BlockHash (h .Sum64 ())
330+ }
331+
332+ func hashInputWithPrevBlockHash (ctx context.Context , prevBlockHash BlockHash , prevBlockLength int , input []byte , cacheBlockSize int , maxPrefixBlocks int ) ([]BlockHash , []byte ) {
333+ loggerDebug := log .FromContext (ctx ).V (logutil .DEBUG )
334+ if len (input )+ prevBlockLength < cacheBlockSize {
335+ loggerDebug .Info ("Request body too small for prefix cache" , "size" , len (input ), "block size" , cacheBlockSize )
336+ return nil , input
337+ }
338+ if len (input )+ prevBlockLength > cacheBlockSize * maxPrefixBlocks {
339+ loggerDebug .Info ("Truncating input" , "size" , len (input ), "max prefix blocks" , maxPrefixBlocks , "block size" , cacheBlockSize )
340+ input = input [:(maxPrefixBlocks * cacheBlockSize - prevBlockLength )]
341+ }
342+ // Split the body into blocks of size cacheBlockSize.
343+ // If the last block is smaller than cacheBlockSize, it will be ignored.
344+ res := make ([]BlockHash , 0 , len (input )/ cacheBlockSize )
345+ lastOffSet := 0
346+ h := xxhash .New ()
347+ for i := 0 ; i + cacheBlockSize <= len (input ); i += cacheBlockSize {
332348 h .Reset ()
333- _ , _ = h .Write (userInput [i : i + cacheBlockSize ])
349+ _ , _ = h .Write (input [i : i + cacheBlockSize ])
334350 _ , _ = h .Write (toBytes (prevBlockHash ))
335351 res = append (res , BlockHash (h .Sum64 ()))
336352
337353 prevBlockHash = res [len (res )- 1 ]
354+ lastOffSet = i + cacheBlockSize
338355 }
339- return res
356+ return res , input [ lastOffSet :]
340357}
341358
342359func toBytes (i BlockHash ) []byte {
@@ -353,3 +370,30 @@ func getUserInputBytes(request *types.LLMRequest) ([]byte, error) {
353370 // must be chat-completions request at this point, return bytes of entire messages
354371 return json .Marshal (request .Body .ChatCompletions .Messages )
355372}
373+
374+ func (p * Plugin ) ResponseComplete (ctx context.Context , request * types.LLMRequest , response * requestcontrol.Response , targetPod * backend.Pod ) {
375+ state , err := plugins .ReadPluginStateKey [* SchedulingContextState ](p .pluginState , request .RequestId , plugins .StateKey (p .TypedName ().String ()))
376+ if err != nil {
377+ log .FromContext (ctx ).Error (err , "failed to read prefix plugin state" , "requestID" , request .RequestId )
378+ return
379+ }
380+ p .pluginState .Delete (request .RequestId ) // delete the state explicitly after completing using it.
381+ var input bytes.Buffer
382+ input .Write (state .RestBytes )
383+ input .Write ([]byte (response .Body ))
384+
385+ server := ServerID (targetPod .NamespacedName )
386+ prevBlockHash := defaultPrevBlock (request )
387+ prevBlockHashLength := 0
388+ if len (state .PrefixHashes ) > 0 {
389+ prevBlockHash = state .PrefixHashes [len (state .PrefixHashes )- 1 ]
390+ prevBlockHashLength = len (state .PrefixHashes )
391+ }
392+ inputBytes := input .Bytes ()
393+ hashBlocks , _ := hashInputWithPrevBlockHash (ctx , prevBlockHash , prevBlockHashLength , inputBytes , p .config .BlockSize , p .config .MaxPrefixBlocksToMatch )
394+ p .wg .Add (1 )
395+ go func () {
396+ p .indexer .Add (hashBlocks , server )
397+ p .wg .Done ()
398+ }()
399+ }
0 commit comments