Skip to content

Commit fcc1b71

Browse files
committed
Streaming token callback
1 parent a780ad9 commit fcc1b71

File tree

10 files changed

+161
-27
lines changed

10 files changed

+161
-27
lines changed
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
// Copyright (c) TensorStack. All rights reserved.
2+
// Licensed under the Apache 2.0 License.
3+
4+
5+
// Copyright (c) TensorStack. All rights reserved.
6+
// Licensed under the Apache 2.0 License.
7+
8+
using TensorStack.Common.Pipeline;
9+
10+
namespace TensorStack.TextGeneration.Common
11+
{
12+
public record GenerateProgress : IRunProgress
13+
{
14+
public bool IsReset { get; set; }
15+
public string Result { get; set; }
16+
}
17+
}

TensorStack.TextGeneration/ITextGeneration.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44
namespace TensorStack.TextGeneration
55
{
66
public interface ITextGeneration :
7-
IPipeline<GenerateResult, GenerateOptions>,
8-
IPipeline<GenerateResult[], SearchOptions>
7+
IPipeline<GenerateResult, GenerateOptions, GenerateProgress>,
8+
IPipeline<GenerateResult[], SearchOptions, GenerateProgress>
99
{
1010
}
1111
}

TensorStack.TextGeneration/Pipelines/DecoderPipeline.cs

Lines changed: 58 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ protected virtual Sampler GetSampler(O options, bool isBeamSerach)
119119
/// <param name="options">The options.</param>
120120
/// <param name="cancellationToken">The cancellation token that can be used by other objects or threads to receive notice of cancellation.</param>
121121
/// <returns>A Task&lt;Sequence&gt; representing the asynchronous operation.</returns>
122-
protected virtual async Task<Sequence> GreedySearchAsync(O options, CancellationToken cancellationToken = default)
122+
protected virtual async Task<Sequence> GreedySearchAsync(O options, IProgress<GenerateProgress> progressCallback = null, CancellationToken cancellationToken = default)
123123
{
124124
var sampler = GetSampler(options, false);
125125
var logitsProcessors = GetLogitsProcessor(options);
@@ -142,6 +142,9 @@ protected virtual async Task<Sequence> GreedySearchAsync(O options, Cancellation
142142
sequence.Tokens.Add(sample.TokenId);
143143
sequence.Score += sample.Score;
144144

145+
// Notify
146+
NotifyProgress(progressCallback, sequence);
147+
145148
// Check Completion
146149
if (tokenProcessors.Any(x => x.Process(sequence)))
147150
break;
@@ -158,13 +161,14 @@ protected virtual async Task<Sequence> GreedySearchAsync(O options, Cancellation
158161
/// <param name="options">The options.</param>
159162
/// <param name="cancellationToken">The cancellation token that can be used by other objects or threads to receive notice of cancellation.</param>
160163
/// <returns>A Task&lt;Sequence[]&gt; representing the asynchronous operation.</returns>
161-
protected virtual async Task<Sequence[]> BeamSearchAsync(O options, CancellationToken cancellationToken = default)
162-
{
164+
protected virtual async Task<Sequence[]> BeamSearchAsync(O options, IProgress<GenerateProgress> progressCallback = null, CancellationToken cancellationToken = default)
165+
{
163166
var sampler = GetSampler(options, true);
164167
var logitsProcessors = GetLogitsProcessor(options);
165168
var tokenProcessors = GetTokenProcessors(options);
166169

167170
var initialPass = true;
171+
var progressTokens = new List<long>();
168172
var sequence = await InitializeAsync(options);
169173
var activeBeams = new SequenceCollection(sequence, options.Beams);
170174
while (!cancellationToken.IsCancellationRequested)
@@ -221,9 +225,9 @@ protected virtual async Task<Sequence[]> BeamSearchAsync(O options, Cancellation
221225

222226

223227
// Process Beams
228+
var bestBeam = activeBeams[0];
224229
foreach (var beam in activeBeams)
225230
{
226-
Console.WriteLine(Tokenizer.Decode(beam.Tokens));
227231
if (beam.IsComplete)
228232
continue;
229233

@@ -233,6 +237,9 @@ protected virtual async Task<Sequence[]> BeamSearchAsync(O options, Cancellation
233237
}
234238
}
235239

240+
// Notify
241+
NotifyProgress(progressCallback, bestBeam, progressTokens);
242+
236243
// Check Completion
237244
if (activeBeams.All(x => x.IsComplete))
238245
break;
@@ -325,7 +332,7 @@ protected virtual Sequence[] NormalizeAndSort(SequenceCollection sequences, O op
325332
.Where(x => x.IsComplete)
326333
.OrderByDescending(s => s.PenaltyScore)
327334
.ToArray();
328-
335+
329336
sequences.Remove(resultSequences);
330337
sequences.Clear();
331338
return resultSequences;
@@ -355,6 +362,52 @@ protected virtual Tensor<long> GetPositionIds(ModelMetadata metadata, int startP
355362
}
356363

357364

365+
/// <summary>
366+
/// Notify token progress.
367+
/// </summary>
368+
/// <param name="progressCallback">The progress callback.</param>
369+
/// <param name="sequence">The sequence.</param>
370+
/// <param name="previousTokens">The previous tokens.</param>
371+
protected void NotifyProgress(IProgress<GenerateProgress> progressCallback, Sequence sequence, List<long> previousTokens = null)
372+
{
373+
if (progressCallback == null)
374+
return;
375+
376+
string result;
377+
var hasBeamChanged = false;
378+
var newToken = sequence.Tokens[^1];
379+
if (previousTokens == null)
380+
{
381+
result = Tokenizer.Decode(newToken);
382+
}
383+
else
384+
{
385+
var newTokens = sequence.Tokens[..^1];
386+
if (sequence.Length == previousTokens.Count)
387+
return;
388+
389+
hasBeamChanged = !previousTokens.SequenceEqual(newTokens);
390+
if (hasBeamChanged)
391+
{
392+
previousTokens.Clear();
393+
previousTokens.AddRange(sequence.Tokens);
394+
result = Tokenizer.Decode(previousTokens);
395+
}
396+
else
397+
{
398+
previousTokens.Add(newToken);
399+
result = Tokenizer.Decode(newToken);
400+
}
401+
}
402+
403+
progressCallback.Report(new GenerateProgress
404+
{
405+
Result = result,
406+
IsReset = hasBeamChanged
407+
});
408+
}
409+
410+
358411
/// <summary>
359412
/// Performs application-defined tasks associated with freeing, releasing, or resetting unmanaged resources.
360413
/// </summary>

TensorStack.TextGeneration/Pipelines/Florence/FlorencePipeline.cs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@
1818
namespace TensorStack.TextGeneration.Pipelines.Florence
1919
{
2020
public class FlorencePipeline : EncoderDecoderPipeline<FlorenceOptions>,
21-
IPipeline<GenerateResult, FlorenceOptions>,
22-
IPipeline<GenerateResult[], FlorenceSearchOptions>
21+
IPipeline<GenerateResult, FlorenceOptions, GenerateProgress>,
22+
IPipeline<GenerateResult[], FlorenceSearchOptions, GenerateProgress>
2323
{
2424
private readonly FlorenceConfig _configuration;
2525
private readonly PreProcessor _preProcessor;
@@ -83,7 +83,7 @@ await Task.WhenAll
8383
/// <param name="progressCallback">The progress callback.</param>
8484
/// <param name="cancellationToken">The cancellation token that can be used by other objects or threads to receive notice of cancellation.</param>
8585
/// <returns>A Task&lt;GenerateResult&gt; representing the asynchronous operation.</returns>
86-
public virtual async Task<GenerateResult> RunAsync(FlorenceOptions options, IProgress<RunProgress> progressCallback = null, CancellationToken cancellationToken = default)
86+
public virtual async Task<GenerateResult> RunAsync(FlorenceOptions options, IProgress<GenerateProgress> progressCallback = null, CancellationToken cancellationToken = default)
8787
{
8888
var textPrompt = _preProcessor.ProcessPrompt(options);
8989
var imagePrompt = _preProcessor.ProcessImage(options);
@@ -93,7 +93,7 @@ public virtual async Task<GenerateResult> RunAsync(FlorenceOptions options, IPro
9393
_visionOutput = await RunVisionEncoderAsync(embedsOutput, imagePrompt);
9494
EncoderOutput = await RunEncoderAsync();
9595

96-
var sequence = await GreedySearchAsync(options, cancellationToken);
96+
var sequence = await GreedySearchAsync(options, progressCallback, cancellationToken);
9797
using (sequence)
9898
{
9999
var processedBeamOutput = _postProcessor.Process(options, sequence.Tokens);
@@ -107,7 +107,7 @@ public virtual async Task<GenerateResult> RunAsync(FlorenceOptions options, IPro
107107
}
108108

109109

110-
public virtual async Task<GenerateResult[]> RunAsync(FlorenceSearchOptions options, IProgress<RunProgress> progressCallback = null, CancellationToken cancellationToken = default)
110+
public virtual async Task<GenerateResult[]> RunAsync(FlorenceSearchOptions options, IProgress<GenerateProgress> progressCallback = null, CancellationToken cancellationToken = default)
111111
{
112112
var textPrompt = _preProcessor.ProcessPrompt(options);
113113
var imagePrompt = _preProcessor.ProcessImage(options);
@@ -117,7 +117,7 @@ public virtual async Task<GenerateResult[]> RunAsync(FlorenceSearchOptions optio
117117
_visionOutput = await RunVisionEncoderAsync(embedsOutput, imagePrompt);
118118
EncoderOutput = await RunEncoderAsync();
119119

120-
var sequences = await BeamSearchAsync(options, cancellationToken);
120+
var sequences = await BeamSearchAsync(options, progressCallback, cancellationToken);
121121
var results = new GenerateResult[sequences.Length];
122122
for (int beam = 0; beam < sequences.Length; beam++)
123123
{

TensorStack.TextGeneration/Pipelines/Other/SummaryPipeline.cs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,11 @@ public SummaryPipeline(SummaryConfig configuration)
2828
/// <param name="progressCallback">The progress callback.</param>
2929
/// <param name="cancellationToken">The cancellation token that can be used by other objects or threads to receive notice of cancellation.</param>
3030
/// <returns>A Task&lt;GenerateResult&gt; representing the asynchronous operation.</returns>
31-
public async Task<GenerateResult> RunAsync(GenerateOptions options, IProgress<RunProgress> progressCallback = null, CancellationToken cancellationToken = default)
31+
public async Task<GenerateResult> RunAsync(GenerateOptions options, IProgress<GenerateProgress> progressCallback = null, CancellationToken cancellationToken = default)
3232
{
3333
await TokenizePromptAsync(options);
3434

35-
var sequence = await GreedySearchAsync(options, cancellationToken);
35+
var sequence = await GreedySearchAsync(options, progressCallback, cancellationToken);
3636
using (sequence)
3737
{
3838
return new GenerateResult
@@ -51,11 +51,11 @@ public async Task<GenerateResult> RunAsync(GenerateOptions options, IProgress<Ru
5151
/// <param name="options">The options.</param>
5252
/// <param name="progressCallback">The progress callback.</param>
5353
/// <param name="cancellationToken">The cancellation token that can be used by other objects or threads to receive notice of cancellation.</param>
54-
public async Task<GenerateResult[]> RunAsync(SearchOptions options, IProgress<RunProgress> progressCallback = null, CancellationToken cancellationToken = default)
54+
public async Task<GenerateResult[]> RunAsync(SearchOptions options, IProgress<GenerateProgress> progressCallback = null, CancellationToken cancellationToken = default)
5555
{
5656
await TokenizePromptAsync(options);
5757

58-
var sequences = await BeamSearchAsync(options, cancellationToken);
58+
var sequences = await BeamSearchAsync(options, progressCallback, cancellationToken);
5959
var results = new GenerateResult[sequences.Length];
6060
for (int beam = 0; beam < sequences.Length; beam++)
6161
{

TensorStack.TextGeneration/Pipelines/Phi/Phi3Pipeline.cs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,10 @@ public Phi3Pipeline(Phi3Config configuration)
3737
/// <param name="options">The options.</param>
3838
/// <param name="cancellationToken">The cancellation token.</param>
3939
/// <returns></returns>
40-
public virtual async Task<GenerateResult> RunAsync(GenerateOptions options, IProgress<RunProgress> progressCallback = null, CancellationToken cancellationToken = default)
40+
public virtual async Task<GenerateResult> RunAsync(GenerateOptions options, IProgress<GenerateProgress> progressCallback = null, CancellationToken cancellationToken = default)
4141
{
4242
await TokenizePromptAsync(options);
43-
var sequence = await GreedySearchAsync(options, cancellationToken);
43+
var sequence = await GreedySearchAsync(options, progressCallback, cancellationToken);
4444
using (sequence)
4545
{
4646
return new GenerateResult
@@ -58,11 +58,11 @@ public virtual async Task<GenerateResult> RunAsync(GenerateOptions options, IPro
5858
/// <param name="options">The options.</param>
5959
/// <param name="progressCallback">The progress callback.</param>
6060
/// <param name="cancellationToken">The cancellation token that can be used by other objects or threads to receive notice of cancellation.</param>
61-
public async Task<GenerateResult[]> RunAsync(SearchOptions options, IProgress<RunProgress> progressCallback = null, CancellationToken cancellationToken = default)
61+
public async Task<GenerateResult[]> RunAsync(SearchOptions options, IProgress<GenerateProgress> progressCallback = null, CancellationToken cancellationToken = default)
6262
{
6363
await TokenizePromptAsync(options);
6464

65-
var sequences = await BeamSearchAsync(options, cancellationToken);
65+
var sequences = await BeamSearchAsync(options,progressCallback, cancellationToken);
6666
var results = new GenerateResult[sequences.Length];
6767
for (int beam = 0; beam < sequences.Length; beam++)
6868
{

TensorStack.TextGeneration/Pipelines/Whisper/WhisperPipeline.cs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@
1717
namespace TensorStack.TextGeneration.Pipelines.Whisper
1818
{
1919
public class WhisperPipeline : EncoderDecoderPipeline<WhisperOptions>,
20-
IPipeline<GenerateResult, WhisperOptions>,
21-
IPipeline<GenerateResult[], WhisperSearchOptions>
20+
IPipeline<GenerateResult, WhisperOptions, GenerateProgress>,
21+
IPipeline<GenerateResult[], WhisperSearchOptions, GenerateProgress>
2222
{
2323
private readonly PreProcessor _preProcessor;
2424
private Tensor<float> _currentAudioSample;
@@ -42,14 +42,14 @@ public WhisperPipeline(WhisperConfig configuration)
4242
/// <param name="progressCallback">The progress callback.</param>
4343
/// <param name="cancellationToken">The cancellation token that can be used by other objects or threads to receive notice of cancellation.</param>
4444
/// <returns>A Task&lt;GenerateResult&gt; representing the asynchronous operation.</returns>
45-
public async Task<GenerateResult> RunAsync(WhisperOptions options, IProgress<RunProgress> progressCallback = null, CancellationToken cancellationToken = default)
45+
public async Task<GenerateResult> RunAsync(WhisperOptions options, IProgress<GenerateProgress> progressCallback = null, CancellationToken cancellationToken = default)
4646
{
4747
var result = default(GenerateResult);
4848
var audioSamples = _preProcessor.ProcessInput(options.AudioInput);
4949
foreach (var sample in audioSamples)
5050
{
5151
await RunEncoderAsync(sample);
52-
var sequence = await GreedySearchAsync(options, cancellationToken);
52+
var sequence = await GreedySearchAsync(options, progressCallback, cancellationToken);
5353
using (sequence)
5454
{
5555
if (result != null)
@@ -79,14 +79,14 @@ public async Task<GenerateResult> RunAsync(WhisperOptions options, IProgress<Run
7979
/// <param name="options">The options.</param>
8080
/// <param name="progressCallback">The progress callback.</param>
8181
/// <param name="cancellationToken">The cancellation token that can be used by other objects or threads to receive notice of cancellation.</param>
82-
public async Task<GenerateResult[]> RunAsync(WhisperSearchOptions options, IProgress<RunProgress> progressCallback = null, CancellationToken cancellationToken = default)
82+
public async Task<GenerateResult[]> RunAsync(WhisperSearchOptions options, IProgress<GenerateProgress> progressCallback = null, CancellationToken cancellationToken = default)
8383
{
8484
var results = new List<GenerateResult>();
8585
var audioSamples = _preProcessor.ProcessInput(options.AudioInput);
8686
foreach (var sample in audioSamples)
8787
{
8888
await RunEncoderAsync(sample);
89-
var sequences = await BeamSearchAsync(options, cancellationToken);
89+
var sequences = await BeamSearchAsync(options, progressCallback, cancellationToken);
9090
for (int beam = 0; beam < sequences.Length; beam++)
9191
{
9292
var sequence = sequences[beam];

TensorStack.TextGeneration/Tokenizers/BPETokenizer.cs

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,26 @@ public Task<string> DecodeAsync(IEnumerable<long> tokens, bool considerSpecialTo
100100
}
101101

102102

103+
/// <summary>
104+
/// Decodes the specified token.
105+
/// </summary>
106+
/// <param name="token">The token.</param>
107+
public string Decode(int token, bool considerSpecialTokens = false)
108+
{
109+
return Decode(token);
110+
}
111+
112+
113+
/// <summary>
114+
/// Decodes the specified token.
115+
/// </summary>
116+
/// <param name="token">The token.</param>
117+
public string Decode(long token, bool considerSpecialTokens = false)
118+
{
119+
return VocabularyMap[token];
120+
}
121+
122+
103123
/// <summary>
104124
/// TokenId to Token.
105125
/// </summary>

TensorStack.TextGeneration/Tokenizers/ITokenizer.cs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,18 @@ public interface ITokenizer : IDisposable
3030
/// <param name="considerSpecialTokens">if set to <c>true</c> decode special tokens.</param>
3131
Task<TokenizerResult> EncodeAsync(ReadOnlySpan<char> input);
3232

33+
/// <summary>
34+
/// Decodes the specified token.
35+
/// </summary>
36+
/// <param name="token">The token.</param>
37+
string Decode(int token, bool considerSpecialTokens = false);
38+
39+
/// <summary>
40+
/// Decodes the specified token.
41+
/// </summary>
42+
/// <param name="token">The token.</param>
43+
string Decode(long token, bool considerSpecialTokens = false);
44+
3345
/// <summary>
3446
/// Decodes the specified tokens to string.
3547
/// </summary>

TensorStack.TextGeneration/Tokenizers/T5Tokenizer.cs

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,38 @@ public Task<string> DecodeAsync(IEnumerable<long> tokens, bool considerSpecialTo
102102
}
103103

104104

105+
/// <summary>
106+
/// Decodes the specified token.
107+
/// </summary>
108+
/// <param name="token">The token.</param>
109+
/// <param name="considerSpecialTokens">if set to <c>true</c> [consider special tokens].</param>
110+
public string Decode(int token, bool considerSpecialTokens = false)
111+
{
112+
return Decode(token);
113+
}
114+
115+
116+
/// <summary>
117+
/// Decodes the specified token.
118+
/// </summary>
119+
/// <param name="token">The token.</param>
120+
/// <param name="considerSpecialTokens">if set to <c>true</c> [consider special tokens].</param>
121+
public string Decode(long token, bool considerSpecialTokens = false)
122+
{
123+
var vocabResult = _tokenizer.Vocabulary.FirstOrDefault(x => x.Value == token);
124+
if (vocabResult.Key is not null)
125+
return vocabResult.Key.Replace('▁', ' ');
126+
127+
if (considerSpecialTokens)
128+
{
129+
var specialToken = _tokenizer.SpecialTokens.FirstOrDefault(x => x.Value == token);
130+
if (specialToken.Key is not null)
131+
return specialToken.Key.Replace('▁', ' ');
132+
}
133+
return string.Empty;
134+
}
135+
136+
105137
/// <summary>
106138
/// Creates the tokenizer.
107139
/// </summary>

0 commit comments

Comments
 (0)