@@ -119,7 +119,7 @@ protected virtual Sampler GetSampler(O options, bool isBeamSerach)
119119 /// <param name="options">The options.</param>
120120 /// <param name="cancellationToken">The cancellation token that can be used by other objects or threads to receive notice of cancellation.</param>
121121 /// <returns>A Task<Sequence> representing the asynchronous operation.</returns>
122- protected virtual async Task < Sequence > GreedySearchAsync ( O options , CancellationToken cancellationToken = default )
122+ protected virtual async Task < Sequence > GreedySearchAsync ( O options , IProgress < GenerateProgress > progressCallback = null , CancellationToken cancellationToken = default )
123123 {
124124 var sampler = GetSampler ( options , false ) ;
125125 var logitsProcessors = GetLogitsProcessor ( options ) ;
@@ -142,6 +142,9 @@ protected virtual async Task<Sequence> GreedySearchAsync(O options, Cancellation
142142 sequence . Tokens . Add ( sample . TokenId ) ;
143143 sequence . Score += sample . Score ;
144144
145+ // Notify
146+ NotifyProgress ( progressCallback , sequence ) ;
147+
145148 // Check Completion
146149 if ( tokenProcessors . Any ( x => x . Process ( sequence ) ) )
147150 break ;
@@ -158,13 +161,14 @@ protected virtual async Task<Sequence> GreedySearchAsync(O options, Cancellation
158161 /// <param name="options">The options.</param>
159162 /// <param name="cancellationToken">The cancellation token that can be used by other objects or threads to receive notice of cancellation.</param>
160163 /// <returns>A Task<Sequence[]> representing the asynchronous operation.</returns>
161- protected virtual async Task < Sequence [ ] > BeamSearchAsync ( O options , CancellationToken cancellationToken = default )
162- {
164+ protected virtual async Task < Sequence [ ] > BeamSearchAsync ( O options , IProgress < GenerateProgress > progressCallback = null , CancellationToken cancellationToken = default )
165+ {
163166 var sampler = GetSampler ( options , true ) ;
164167 var logitsProcessors = GetLogitsProcessor ( options ) ;
165168 var tokenProcessors = GetTokenProcessors ( options ) ;
166169
167170 var initialPass = true ;
171+ var progressTokens = new List < long > ( ) ;
168172 var sequence = await InitializeAsync ( options ) ;
169173 var activeBeams = new SequenceCollection ( sequence , options . Beams ) ;
170174 while ( ! cancellationToken . IsCancellationRequested )
@@ -221,9 +225,9 @@ protected virtual async Task<Sequence[]> BeamSearchAsync(O options, Cancellation
221225
222226
223227 // Process Beams
228+ var bestBeam = activeBeams [ 0 ] ;
224229 foreach ( var beam in activeBeams )
225230 {
226- Console . WriteLine ( Tokenizer . Decode ( beam . Tokens ) ) ;
227231 if ( beam . IsComplete )
228232 continue ;
229233
@@ -233,6 +237,9 @@ protected virtual async Task<Sequence[]> BeamSearchAsync(O options, Cancellation
233237 }
234238 }
235239
240+ // Notify
241+ NotifyProgress ( progressCallback , bestBeam , progressTokens ) ;
242+
236243 // Check Completion
237244 if ( activeBeams . All ( x => x . IsComplete ) )
238245 break ;
@@ -325,7 +332,7 @@ protected virtual Sequence[] NormalizeAndSort(SequenceCollection sequences, O op
325332 . Where ( x => x . IsComplete )
326333 . OrderByDescending ( s => s . PenaltyScore )
327334 . ToArray ( ) ;
328-
335+
329336 sequences . Remove ( resultSequences ) ;
330337 sequences . Clear ( ) ;
331338 return resultSequences ;
@@ -355,6 +362,52 @@ protected virtual Tensor<long> GetPositionIds(ModelMetadata metadata, int startP
355362 }
356363
357364
365+ /// <summary>
366+ /// Notify token progress.
367+ /// </summary>
368+ /// <param name="progressCallback">The progress callback.</param>
369+ /// <param name="sequence">The sequence.</param>
370+ /// <param name="previousTokens">The previous tokens.</param>
371+ protected void NotifyProgress ( IProgress < GenerateProgress > progressCallback , Sequence sequence , List < long > previousTokens = null )
372+ {
373+ if ( progressCallback == null )
374+ return ;
375+
376+ string result ;
377+ var hasBeamChanged = false ;
378+ var newToken = sequence . Tokens [ ^ 1 ] ;
379+ if ( previousTokens == null )
380+ {
381+ result = Tokenizer . Decode ( newToken ) ;
382+ }
383+ else
384+ {
385+ var newTokens = sequence . Tokens [ ..^ 1 ] ;
386+ if ( sequence . Length == previousTokens . Count )
387+ return ;
388+
389+ hasBeamChanged = ! previousTokens . SequenceEqual ( newTokens ) ;
390+ if ( hasBeamChanged )
391+ {
392+ previousTokens . Clear ( ) ;
393+ previousTokens . AddRange ( sequence . Tokens ) ;
394+ result = Tokenizer . Decode ( previousTokens ) ;
395+ }
396+ else
397+ {
398+ previousTokens . Add ( newToken ) ;
399+ result = Tokenizer . Decode ( newToken ) ;
400+ }
401+ }
402+
403+ progressCallback . Report ( new GenerateProgress
404+ {
405+ Result = result ,
406+ IsReset = hasBeamChanged
407+ } ) ;
408+ }
409+
410+
358411 /// <summary>
359412 /// Performs application-defined tasks associated with freeing, releasing, or resetting unmanaged resources.
360413 /// </summary>
0 commit comments