Commit 0b2c7ed

Stream output
1 parent 0e6c6e0 commit 0b2c7ed

2 files changed: +43 −44 lines

src/cpp/src/speculative_decoding/speculative_decoding_stateful_eagle3.cpp

Lines changed: 42 additions & 41 deletions
@@ -29,6 +29,15 @@ overloaded(Ts...) -> overloaded<Ts...>;
 
 namespace {
 
+// Stream generated tokens to output
+ov::genai::StreamingStatus stream_generated_tokens(std::shared_ptr<ov::genai::StreamerBase> streamer_ptr,
+                                                   const std::vector<int64_t>& tokens) {
+    if (streamer_ptr) {
+        return streamer_ptr->write(tokens);
+    }
+    return ov::genai::StreamingStatus{};
+}
+
 // Format microseconds for logging
 std::string format_duration_us(uint64_t microseconds) {
     if (microseconds < 1000) {
@@ -869,6 +878,9 @@ EncodedResults StatefulEagle3LLMPipeline::generate(const EncodedInputs& inputs,
 
     auto config = resolve_generation_config(generation_config);
 
+    // Create streamer for streaming output
+    std::shared_ptr<StreamerBase> streamer_ptr = ov::genai::utils::create_streamer(streamer, m_tokenizer);
+
     log_info("Starting Eagle3 generation with max_new_tokens=" + std::to_string(config.max_new_tokens) +
              ", draft_iterations=" + std::to_string(m_draft_iterations));
 
@@ -889,12 +901,9 @@ EncodedResults StatefulEagle3LLMPipeline::generate(const EncodedInputs& inputs,
 
     std::size_t prompt_len = prompt_shape[1];
 
-    // Initialize Pipeline's verified token storage with prompt
-    const int64_t* prompt_data = input_ids.data<const int64_t>();
-    m_verified_tokens.assign(prompt_data, prompt_data + prompt_len);
     m_prompt_length = prompt_len;
 
-    log_debug("Initialized verified_tokens with prompt (" + std::to_string(prompt_len) + " tokens)");
+    log_debug("Prompt length: " + std::to_string(prompt_len) + " tokens");
 
     // Initialize position IDs
     ov::Tensor position_ids{ov::element::i64, input_ids.get_shape()};
@@ -920,11 +929,10 @@ EncodedResults StatefulEagle3LLMPipeline::generate(const EncodedInputs& inputs,
     m_main_model->append_tokens({initial_token});
     m_draft_model->append_tokens({initial_token});
 
-    // Write initial verified token to Pipeline storage
-    m_verified_tokens.push_back(initial_token);
+    // Stream the initial token
+    auto streaming_status = stream_generated_tokens(streamer_ptr, std::vector<int64_t>{initial_token});
 
-    log_debug("Initial token generated: " + std::to_string(initial_token) +
-              ", verified_tokens now: " + std::to_string(m_verified_tokens.size()));
+    log_debug("Initial token generated: " + std::to_string(initial_token));
     log_sequence_state("after initial token generation");
 
     // Main generation loop using speculative decoding
@@ -940,13 +948,17 @@ EncodedResults StatefulEagle3LLMPipeline::generate(const EncodedInputs& inputs,
     std::size_t total_iterations = 0; // Number of speculative iterations
 
     while (!eos_reached && generated_tokens < max_new_tokens &&
-           m_main_model->get_sequence_length() < prompt_len + max_new_tokens) {
+           m_main_model->get_sequence_length() < prompt_len + max_new_tokens &&
+           (streaming_status == ov::genai::StreamingStatus::RUNNING)) {
         log_generation_step("Speculative Decoding Iteration", generated_tokens);
         log_sequence_state("iteration start");
 
         auto result =
             run_speculative_iteration(target_hidden_states, token_count, static_cast<int64_t>(config.eos_token_id));
 
+        // Stream validated tokens
+        streaming_status = stream_generated_tokens(streamer_ptr, result.validated_tokens);
+
         // Update iteration counter
         total_iterations++;
 
@@ -986,23 +998,28 @@ EncodedResults StatefulEagle3LLMPipeline::generate(const EncodedInputs& inputs,
         log_sequence_state("iteration end");
     }
 
-    // Convert all verified tokens to text and display
-    if (!m_verified_tokens.empty()) {
-        try {
-            std::string decoded_text = m_tokenizer.decode(m_verified_tokens);
-            std::cout << "[EAGLE3-FINAL] All verified tokens decoded (" << m_verified_tokens.size() << " tokens): \""
-                      << decoded_text << "\"" << std::endl;
-        } catch (const std::exception& e) {
-            std::cout << "[EAGLE3-FINAL] Failed to decode verified tokens: " << e.what() << std::endl;
-        }
+    m_streaming_was_cancelled = (streaming_status == ov::genai::StreamingStatus::CANCEL);
+    if (streamer_ptr) { // push streamer's cache
+        streamer_ptr->end();
     }
 
-    // Prepare results using Pipeline's verified tokens (source of truth)
+    // Prepare results using main model's tokens as source of truth
     EncodedResults results;
-    results.tokens = {m_verified_tokens};
+    results.tokens = {m_main_model->get_tokens()};
     results.scores.resize(1);
     results.scores[0] = 0.0f; // Greedy decoding, no scores
 
+    // Display final tokens if verbose
+    if (is_verbose() && !results.tokens[0].empty()) {
+        try {
+            std::string decoded_text = m_tokenizer.decode(results.tokens[0]);
+            std::cout << "[EAGLE3-FINAL] All tokens decoded (" << results.tokens[0].size() << " tokens): \""
+                      << decoded_text << "\"" << std::endl;
+        } catch (const std::exception& e) {
+            std::cout << "[EAGLE3-FINAL] Failed to decode tokens: " << e.what() << std::endl;
+        }
+    }
+
     // Update performance metrics following the standard stateful speculative decoding pattern
     generate_timer.end();
 
@@ -1151,11 +1168,6 @@ StatefulEagle3LLMPipeline::SpeculativeResult StatefulEagle3LLMPipeline::run_spec
     m_main_model->append_tokens(tokens_to_append);
     m_draft_model->append_tokens(tokens_to_append);
 
-    // Write accepted tokens to Pipeline's verified storage
-    for (const auto& token : tokens_to_append) {
-        m_verified_tokens.push_back(token);
-    }
-
     // Trim KV cache for rejected draft tokens
     if (tokens_to_remove_from_draft > 0) {
         m_main_model->trim_kv_cache(tokens_to_remove_from_draft);
@@ -1164,7 +1176,7 @@ StatefulEagle3LLMPipeline::SpeculativeResult StatefulEagle3LLMPipeline::run_spec
 
     log_debug("Accepted total " + std::to_string(total_accepted_tokens) + " tokens (" + std::to_string(accepted_count) +
               " draft + 1 main prediction), rejected " + std::to_string(tokens_to_remove_from_draft) +
-              " draft tokens. " + "Verified tokens now: " + std::to_string(m_verified_tokens.size()));
+              " draft tokens.");
 
     // Build next hidden window for next iteration
     ov::Tensor next_hidden;
@@ -1208,6 +1220,7 @@ StatefulEagle3LLMPipeline::SpeculativeResult StatefulEagle3LLMPipeline::run_spec
     result.next_window_size = total_accepted_tokens; // Total tokens for next iteration (draft + main)
     result.new_token = main_predicted_token; // Main model's prediction
     result.next_hidden_window = next_hidden;
+    result.validated_tokens = tokens_to_append; // Return validated tokens for streaming
 
     log_debug("Speculative iteration completed - accepted " + std::to_string(accepted_count) +
               " draft tokens + 1 main prediction = " + std::to_string(total_accepted_tokens) +
@@ -1291,23 +1304,11 @@ void StatefulEagle3LLMPipeline::log_sequence_state(const std::string& context) c
         return;
 
     std::cout << "[EAGLE3-PIPELINE] Sequence state (" << context << "):" << std::endl;
-    std::cout << "  Pipeline verified tokens: " << m_verified_tokens.size() << " tokens (prompt=" << m_prompt_length
-              << ", generated=" << (m_verified_tokens.size() - m_prompt_length) << ")" << std::endl;
-    std::cout << "  Main model tokens (speculative): " << m_main_model->get_sequence_length() << " tokens" << std::endl;
-    std::cout << "  Draft model tokens (speculative): " << m_draft_model->get_sequence_length() << " tokens"
+    std::cout << "  Prompt length: " << m_prompt_length << " tokens" << std::endl;
+    std::cout << "  Main model tokens: " << m_main_model->get_sequence_length() << " tokens" << std::endl;
+    std::cout << "  Draft model tokens: " << m_draft_model->get_sequence_length() << " tokens"
               << std::endl;
 
-    // Show verified tokens from Pipeline
-    if (!m_verified_tokens.empty()) {
-        std::cout << "  Pipeline verified tokens: ";
-        for (std::size_t i = 0; i < m_verified_tokens.size(); ++i) {
-            std::cout << m_verified_tokens[i];
-            if (i + 1 < m_verified_tokens.size())
-                std::cout << ", ";
-        }
-        std::cout << std::endl;
-    }
-
     // Show all tokens and positions from main model
     const auto& main_tokens = m_main_model->get_tokens();
     const auto& main_positions = m_main_model->get_positions();

src/cpp/src/speculative_decoding/speculative_decoding_stateful_eagle3.hpp

Lines changed: 1 addition & 3 deletions
@@ -123,7 +123,6 @@ class Eagle3InferWrapperBase {
     std::size_t m_kv_cache_capacity = 0;
 
     // Speculative token sequences (may be rolled back)
-    // Verified tokens stored in Pipeline's m_verified_tokens
     std::vector<int64_t> m_tokens;
     std::vector<int64_t> m_positions;
 
@@ -220,6 +219,7 @@ class StatefulEagle3LLMPipeline : public ov::genai::LLMPipelineImplBase {
         std::size_t next_window_size = 0;
         int64_t new_token = -1;
         bool eos_reached = false;
+        std::vector<int64_t> validated_tokens; // Tokens accepted by main model (draft + main prediction)
     };
 
     // Core algorithm
@@ -250,8 +250,6 @@ class StatefulEagle3LLMPipeline : public ov::genai::LLMPipelineImplBase {
     ov::Tensor m_draft_target_mapping;
     std::vector<int> m_hidden_layers_to_abstract;
 
-    // Verified tokens (source of truth)
-    std::vector<int64_t> m_verified_tokens;
     std::size_t m_prompt_length = 0;
 
     // Metrics
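
For context, the sketch below shows how a caller might consume the streaming output this commit enables. It is a minimal, hypothetical example assuming the public ov::genai::LLMPipeline API with a callback-style streamer that returns ov::genai::StreamingStatus; the model path, prompt, and cancellation threshold are illustrative and not part of this commit.

// Hypothetical caller-side sketch (not part of this commit): print tokens as they
// are validated and cancel generation early from the callback.
#include <cstddef>
#include <iostream>
#include <string>
#include "openvino/genai/llm_pipeline.hpp"

int main() {
    // Assumed model directory and device; adjust for a real setup.
    ov::genai::LLMPipeline pipe("path/to/model_dir", "CPU");

    ov::genai::GenerationConfig config;
    config.max_new_tokens = 128;

    std::size_t printed_chars = 0;
    auto streamer = [&](std::string subword) {
        std::cout << subword << std::flush;
        printed_chars += subword.size();
        // Returning CANCEL makes the generation loop added in this commit exit early.
        return printed_chars > 2000 ? ov::genai::StreamingStatus::CANCEL
                                    : ov::genai::StreamingStatus::RUNNING;
    };

    pipe.generate("Why is speculative decoding fast?", config, streamer);
    return 0;
}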
