 import itertools
 from abc import ABC, abstractmethod
-from typing import Generic, TypeVar
+from collections.abc import Sequence
+from typing import Generic
 
 from guidellm.backend.response import ResponseSummary
+from guidellm.config import settings
+from guidellm.preprocess.item import Item, ItemList
 from guidellm.request.request import GenerationRequest
+from guidellm.request.types import RequestT, ResponseT
 
 __all__ = ["GenerativeRequestSession", "RequestSession"]
 
-# TODO: Replace with specific types that implement needed features
-RequestT = TypeVar("RequestT")
-ResponseT = TypeVar("ResponseT")
-
 
 class RequestSession(ABC, Generic[RequestT, ResponseT]):
     @abstractmethod
@@ -30,46 +30,61 @@ def push_response(self, response: ResponseT) -> None: ...
     def complete(self) -> bool: ...
 
 
-# FIXME: Bad implementation. Can only handle string requests
 class GenerativeRequestSession(RequestSession[GenerationRequest, ResponseSummary]):
-    def __init__(self, prompts: list[GenerationRequest]) -> None:
-        if not prompts:
+    def __init__(self, items: ItemList) -> None:
+        if len(items) < 1:
             raise ValueError("Prompts cannot be empty")
 
-        self.prompts = prompts
-        self.responses: list[str] = []
+        self.prompts: Sequence[Item] = items
+        self.responses: list[Item] = []
 
     def __len__(self) -> int:
         return len(self.prompts)
 
     def get_next_request(self) -> GenerationRequest:
         completed_responses = len(self.responses)
-        base_request = self.prompts[completed_responses].model_copy(deep=True)
-        base_request.content = "".join(
+
+        # FIXME: Can only handle string requests
+        # Interleave each answered prompt with its response; the current,
+        # unanswered prompt pairs with an empty Item so the transcript
+        # ends on it.
+        content = "".join(
             itertools.chain.from_iterable(
-                zip((x.content for x in self.prompts), self.responses + [""])
+                (x.value, y.value)
+                for x, y in zip(self.prompts, self.responses + [Item(value="")])
             )
         )
-        base_request.stats["prompt_tokens"] = sum(
-            x.stats["prompt_tokens"] for x in self.prompts[: completed_responses + 1]
+
+        # Prior turns are replayed as context, so both their prompt and
+        # output token counts feed the next request's prompt size.
+        prev_prompt_tokens = sum(
+            (x.prompt_tokens or 0) + (x.output_tokens or 0) for x in self.responses
         )
-        base_request.constraints["output_tokens"] = sum(
-            x.constraints["output_tokens"]
-            for x in self.prompts[: completed_responses + 1]
+        prompt_tokens = (
+            self.prompts[completed_responses].prompt_tokens or 0
+        ) + prev_prompt_tokens
+
+        output_tokens = self.prompts[completed_responses].output_tokens
+
+        return GenerationRequest(
+            request_type=settings.preferred_route,
+            content=content,
+            stats=(
+                {"prompt_tokens": prompt_tokens} if prompt_tokens is not None else {}
+            ),
+            constraints=(
+                {"output_tokens": output_tokens} if output_tokens is not None else {}
+            ),
         )
 
-        return base_request
-
     def get_next_delay(self) -> float:
         return 0.0
 
     def push_response(self, response: ResponseSummary) -> None:
         if len(self.responses) < len(self.prompts):
-            if response.response_output_tokens is not None:
-                self.prompts[len(self.responses)].constraints["output_tokens"] = (
-                    response.response_output_tokens
-                )
-            self.responses.append(response.value)
+            # Record the turn, preferring measured token counts from the
+            # response and falling back to the requested counts.
+            resp = Item(
+                value=response.value,
+                prompt_tokens=response.response_prompt_tokens
+                or response.request_prompt_tokens,
+                output_tokens=response.response_output_tokens
+                or response.request_output_tokens,
+            )
+            self.responses.append(resp)
         else:
             raise ValueError("Response list full")
 
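A minimal sketch of the transcript interleaving performed by get_next_request, using a stand-in Item that carries only the value field exercised here (the real guidellm.preprocess.item.Item has more fields):

import itertools
from dataclasses import dataclass

@dataclass
class Item:
    value: str

prompts = [Item("Q1 "), Item("Q2 "), Item("Q3 ")]
responses = [Item("A1 "), Item("A2 ")]  # two turns completed

# Pair each prompt with the response that followed it; the trailing empty
# Item pads the responses so zip() reaches the current, unanswered prompt.
content = "".join(
    itertools.chain.from_iterable(
        (x.value, y.value)
        for x, y in zip(prompts, responses + [Item("")])
    )
)
print(content)  # Q1 A1 Q2 A2 Q3

Because zip() stops at the shorter sequence, prompts beyond the current turn are never included in the outgoing request.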
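The prompt-token bookkeeping follows the same multi-turn view: every completed turn is replayed as context, so both its prompt tokens and its output tokens count toward the next request's prompt size. With assumed numbers, if turn one sent a 10-token prompt and received a 20-token answer, and the second prompt is 15 tokens, the second request reports stats["prompt_tokens"] = 10 + 20 + 15 = 45.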