@@ -1,15 +1,16 @@
 import itertools
 from abc import ABC, abstractmethod
-from typing import Generic, TypeVar
+from collections.abc import Sequence
+from typing import Generic
 
 from guidellm.backend.response import ResponseSummary
+from guidellm.config import settings
+from guidellm.preprocess.item import Item, ItemList
 from guidellm.request.request import GenerationRequest
+from guidellm.request.types import RequestT, ResponseT
 
 __all__ = ["GenerativeRequestSession", "RequestSession"]
 
-RequestT = TypeVar("RequestT")
-ResponseT = TypeVar("ResponseT")
-
 
 class RequestSession(ABC, Generic[RequestT, ResponseT]):
     """
@@ -35,44 +36,60 @@ def complete(self) -> bool: ...
 
 
 class GenerativeRequestSession(RequestSession[GenerationRequest, ResponseSummary]):
-    def __init__(self, prompts: list[GenerationRequest]) -> None:
-        if not prompts:
+    def __init__(self, items: ItemList) -> None:
+        if len(items) < 1:
             raise ValueError("Prompts cannot be empty")
 
-        self.prompts = prompts
-        self.responses: list[str] = []
+        self.prompts: Sequence[Item] = items
+        self.responses: list[Item] = []
 
     def __len__(self) -> int:
         return len(self.prompts)
 
     def get_next_request(self) -> GenerationRequest:
         completed_responses = len(self.responses)
-        base_request = self.prompts[completed_responses].model_copy(deep=True)
-        base_request.content = "".join(
+
+        # FIXME: Can only handle string requests
+        content = "".join(
             itertools.chain.from_iterable(
-                zip((x.content for x in self.prompts), self.responses + [""])
+                (x.value, y.value)
+                for x, y in zip(self.prompts, self.responses + [Item(value="")])
             )
         )
-        base_request.stats["prompt_tokens"] = sum(
-            x.stats["prompt_tokens"] for x in self.prompts[: completed_responses + 1]
+
+        prev_prompt_tokens = sum(
+            (x.prompt_tokens or 0) + (x.output_tokens or 0) for x in self.responses
         )
-        base_request.constraints["output_tokens"] = sum(
-            x.constraints["output_tokens"]
-            for x in self.prompts[: completed_responses + 1]
+        prompt_tokens = (
+            self.prompts[completed_responses].prompt_tokens or 0
+        ) + prev_prompt_tokens
+
+        output_tokens = self.prompts[completed_responses].output_tokens
+
+        return GenerationRequest(
+            request_type=settings.preferred_route,
+            content=content,
+            stats=(
+                {"prompt_tokens": prompt_tokens} if prompt_tokens is not None else {}
+            ),
+            constraints=(
+                {"output_tokens": output_tokens} if output_tokens is not None else {}
+            ),
         )
 
-        return base_request
-
     def get_next_delay(self) -> float:
         return 0.0
 
     def push_response(self, response: ResponseSummary) -> None:
         if len(self.responses) < len(self.prompts):
-            if response.response_output_tokens is not None:
-                self.prompts[len(self.responses)].constraints["output_tokens"] = (
-                    response.response_output_tokens
-                )
-            self.responses.append(response.value)
+            resp = Item(
+                value=response.value,
+                prompt_tokens=response.response_prompt_tokens
+                or response.request_prompt_tokens,
+                output_tokens=response.response_output_tokens
+                or response.request_output_tokens,
+            )
+            self.responses.append(resp)
         else:
             raise ValueError("Response list full")
 
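The interleaving in get_next_request is easy to miss in diff form: each completed prompt is paired with its response, the still-unanswered prompt is padded with an empty Item, and the pairs are flattened in order to rebuild the conversation as one string. A minimal self-contained sketch of the same pattern, using a dataclass as a stand-in for guidellm's Item (assumed here to be a plain value/token-count container, as the diff suggests):

    import itertools
    from dataclasses import dataclass

    # Stand-in for guidellm's Item, assumed to be a simple
    # value / token-count container based on how the diff uses it.
    @dataclass
    class Item:
        value: str
        prompt_tokens: int | None = None
        output_tokens: int | None = None

    prompts = [Item(value="Q1? "), Item(value="Q2? ")]
    responses = [Item(value="A1. ")]  # first turn already answered

    # Pair each prompt with its response; an empty Item pads the not-yet-
    # answered prompt. zip() stops at the shorter side, so only prompts up
    # to the next unanswered one are included.
    content = "".join(
        itertools.chain.from_iterable(
            (x.value, y.value)
            for x, y in zip(prompts, responses + [Item(value="")])
        )
    )
    print(content)  # "Q1? A1. Q2? "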
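Taken together, the session is meant to be drained turn by turn: get_next_request replays the full conversation so far, and push_response records the reply as an Item, preferring measured token counts (response_*) and falling back to the requested ones (request_*), so the next request's prompt_tokens reflects actual usage. A hypothetical driver loop; backend.generate and the items construction are placeholders, not real guidellm APIs:

    # Hypothetical driver; `backend` and `items` are placeholders.
    session = GenerativeRequestSession(items)  # items: an ItemList of user turns

    for _ in range(len(session)):              # one iteration per prompt
        request = session.get_next_request()   # conversation so far, one string
        summary = backend.generate(request)    # assumed to return a ResponseSummary
        session.push_response(summary)         # reply stored as an Item; measured
                                               # token counts fall back to requested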