 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 * @lint-ignore-every CLANGTIDY facebook-hte-Deprecated
 */

// A simple llama2 runner that includes preprocessing and post processing logic.
// The module takes in a string as input and emits a string as output.
1213#include < executorch/examples/models/llama/runner/runner.h>
1314
14- #include < algorithm>
15- #include < ctime>
16-
1715#include < executorch/extension/llm/runner/util.h>
1816
1917#include < executorch/examples/models/llama/tokenizer/llama_tiktoken.h>
@@ -62,125 +60,162 @@ std::unique_ptr<::tokenizers::Tokenizer> load_tokenizer(
6260}
6361} // namespace
6462
65- Runner:: Runner (
63+ std::unique_ptr< Runner> Runner::create (
6664 const std::string& model_path,
6765 const std::string& tokenizer_path,
68- std::optional<const std::string> data_path)
69- // NOTE: we observed ~2x loading performance increase on iPhone 15
70- // and a ~5% improvement on Galaxy S22 by switching to
71- // FileDataLoader instead of MmapDataLoader + UseMlockIgnoreErrors.
72- : tokenizer_path_(tokenizer_path),
73- metadata_ ({
74- {kEnableDynamicShape , false },
75- {kMaxSeqLen , 128 },
76- {kMaxContextLen , 128 },
77- {kUseKVCache , true },
78- {kUseSDPAWithKVCache , false },
79- }) {
80- if (data_path.has_value ()) {
81- module_ = std::make_unique<Module>(
82- model_path, data_path.value (), Module::LoadMode::File);
83- } else {
84- module_ = std::make_unique<Module>(model_path, Module::LoadMode::File);
85- }
66+ std::optional<const std::string> data_path,
67+ float temperature) {
8668 ET_LOG (
8769 Info,
8870 " Creating LLaMa runner: model_path=%s, tokenizer_path=%s" ,
8971 model_path.c_str (),
9072 tokenizer_path.c_str ());
91- }
9273
93- [[deprecated(
94- " This constructor is deprecated. Use the constructor without temperature parameter instead." )]]
95- Runner::Runner (
96- const std::string& model_path,
97- const std::string& tokenizer_path,
98- const float temperature,
99- std::optional<const std::string> data_path)
100- : Runner(model_path, tokenizer_path, std::move(data_path)) {
101- temperature_ = temperature;
102- }
103-
104- bool Runner::is_loaded () const {
105- return module_->is_loaded () && tokenizer_ && text_decoder_runner_ &&
106- text_prefiller_ && text_token_generator_;
107- }
108-
109- Error Runner::load () {
110- if (is_loaded ()) {
111- return Error::Ok;
74+ // Create the Module
75+ std::unique_ptr<Module> module ;
76+ if (data_path.has_value ()) {
77+ module = std::make_unique<Module>(
78+ model_path, data_path.value (), Module::LoadMode::File);
79+ } else {
80+ module = std::make_unique<Module>(model_path, Module::LoadMode::File);
11281 }
113- ET_CHECK_OK_OR_RETURN_ERROR (module_->load_method (" forward" ));
11482
115- // Load tokenizer.
116- tokenizer_ = load_tokenizer (tokenizer_path_);
117- if (tokenizer_ == nullptr ) {
83+ // Initialize metadata with default values
84+ std::unordered_map<std::string, int64_t > metadata ({
85+ {kEnableDynamicShape , false },
86+ {kMaxSeqLen , 128 },
87+ {kMaxContextLen , 128 },
88+ {kUseKVCache , true },
89+ {kUseSDPAWithKVCache , false },
90+ });
91+
92+ // Create and load tokenizer
93+ std::unique_ptr<::tokenizers::Tokenizer> tokenizer =
94+ load_tokenizer (tokenizer_path);
95+
96+ // Fallback to BPE tokenizer if tiktoken fails
97+ if (tokenizer == nullptr ) {
11898 ET_LOG (
11999 Info,
120- " Failed to load %s as a Tiktoken artifact, trying BPE tokenizer" ,
121- tokenizer_path_.c_str ());
122- tokenizer_.reset ();
123- // @lint-ignore CLANGTIDY facebook-hte-Deprecated
124- tokenizer_ = std::make_unique<::tokenizers::Llama2cTokenizer>();
125- auto err = tokenizer_->load (tokenizer_path_);
126- ET_CHECK_TK_OK_OR_RETURN_ERROR (
127- err,
128- " Failed to load %s as a llama2.c tokenizer artifact" ,
129- tokenizer_path_.c_str ());
130- return ::executorch::runtime::Error::InvalidArgument;
100+ " Failed to load %s as a Tiktoken, Sentencepiece or Llama2.c tokenizer, make sure the artifact is one of these types" ,
101+ tokenizer_path.c_str ());
102+ return nullptr ;
131103 }
132104
133105 ET_LOG (Info, " Reading metadata from model" );
134106
135- metadata_[kBosId ] = tokenizer_->bos_tok ();
107+ // Set tokenizer-related metadata
108+ metadata[kBosId ] = tokenizer->bos_tok ();
136109 auto eos_ids = std::make_unique<std::unordered_set<uint64_t >>(
137- std::unordered_set<uint64_t >{tokenizer_->eos_tok ()});
138- metadata_[kVocabSize ] = tokenizer_->vocab_size ();
139-
140- const auto method_names =
141- ET_UNWRAP (module_->method_names (), " Failed reading method names" );
110+ std::unordered_set<uint64_t >{tokenizer->eos_tok ()});
111+ metadata[kVocabSize ] = tokenizer->vocab_size ();
112+
113+ // Read metadata from the model
114+ auto method_names_result = module ->method_names ();
115+ if (method_names_result.error () != Error::Ok) {
116+ ET_LOG (Error, " Failed reading method names" );
117+ return nullptr ;
118+ }
119+ const auto method_names = method_names_result.get ();
142120
143- for (auto & pair : metadata_ ) {
121+ for (auto & pair : metadata ) {
144122 const auto & method_name = pair.first ;
145123 auto & value = pair.second ;
146124
147125 if (method_names.count (method_name)) {
148- value = ET_UNWRAP (module_->get (method_name))
149- .toScalar ()
150- .to <decltype (metadata_)::mapped_type>();
126+ auto get_result = module ->get (method_name);
127+ value = get_result.get ().toScalar ().to <decltype (metadata)::mapped_type>();
151128 } else {
152129 ET_LOG (
153130 Info,
154- " Methond %s not found, using the default value %" PRId64,
131+ " Method %s not found, using the default value %" PRId64,
155132 method_name.c_str (),
156133 value);
157134 }
158135 ET_LOG (Info, " Metadata: %s = %" PRId64, method_name.c_str (), value);
159136 }
137+
138+ // Get EOS IDs if available
160139 if (method_names.count (kEosIds )) {
161140 eos_ids->clear ();
162- for (const auto & eos_id : ET_UNWRAP (module_->execute (kEosIds ))) {
141+ auto execute_result = module ->execute (kEosIds );
142+ if (execute_result.error () != Error::Ok) {
143+ ET_LOG (Error, " Failed to execute %s" , kEosIds );
144+ return nullptr ;
145+ }
146+ for (const auto & eos_id : execute_result.get ()) {
163147 auto value = eos_id.toScalar ().to <int64_t >();
164148 eos_ids->emplace (value);
165149 ET_LOG (Info, " eos_id = %" PRId64, value);
166150 }
167151 }
168- // @lint-ignore CLANGTIDY facebook-hte-Deprecated
169- text_decoder_runner_ = std::make_unique<llm::TextDecoderRunner>(
170- module_.get (), metadata_.at (kUseKVCache ));
171- text_prefiller_ = std::make_unique<llm::TextPrefiller>(
172- text_decoder_runner_.get (),
173- metadata_.at (kUseKVCache ),
174- metadata_.at (kEnableDynamicShape ),
175- metadata_.at (kMaxSeqLen ));
176-
177- text_token_generator_ = std::make_unique<llm::TextTokenGenerator>(
178- tokenizer_.get (),
179- text_decoder_runner_.get (),
180- metadata_.at (kUseKVCache ),
152+
153+ // Create text_decoder_runner. Use a shared_ptr so that it can be shared with
154+ // TextPrefiller and TextTokenGenerator
155+ auto text_decoder_runner = std::make_unique<llm::TextDecoderRunner>(
156+ module .get (), metadata.at (kUseKVCache ));
157+
158+ // Create text_prefiller
159+ auto text_prefiller = std::make_unique<llm::TextPrefiller>(
160+ text_decoder_runner.get (),
161+ metadata.at (kUseKVCache ),
162+ metadata.at (kEnableDynamicShape ),
163+ metadata.at (kMaxSeqLen ));
164+
165+ // Create text_token_generator with stats
166+ auto stats = std::make_unique<llm::Stats>();
167+ auto text_token_generator = std::make_unique<llm::TextTokenGenerator>(
168+ tokenizer.get (),
169+ text_decoder_runner.get (),
170+ metadata.at (kUseKVCache ),
181171 std::move (eos_ids),
182- &stats_);
172+ stats.get ());
173+
174+ // Create and return the Runner instance
175+ return std::make_unique<Runner>(
176+ std::move (metadata),
177+ std::move (tokenizer),
178+ std::move (module ),
179+ std::move (text_decoder_runner),
180+ std::move (text_prefiller),
181+ std::move (text_token_generator),
182+ std::move (stats),
183+ temperature);
184+ }
185+
186+ Runner::Runner (
187+ std::unordered_map<std::string, int64_t > metadata,
188+ std::unique_ptr<::tokenizers::Tokenizer> tokenizer,
189+ std::unique_ptr<::executorch::extension::Module> module ,
190+ std::unique_ptr<::executorch::extension::llm::TextDecoderRunner>
191+ text_decoder_runner,
192+ std::unique_ptr<::executorch::extension::llm::TextPrefiller> text_prefiller,
193+ std::unique_ptr<::executorch::extension::llm::TextTokenGenerator>
194+ text_token_generator,
195+ std::unique_ptr<::executorch::extension::llm::Stats> stats,
196+ float temperature)
197+ : tokenizer_(std::move(tokenizer)),
198+ metadata_ (std::move(metadata)),
199+ module_(std::move(module )),
200+ text_decoder_runner_(std::move(text_decoder_runner)),
201+ text_prefiller_(std::move(text_prefiller)),
202+ text_token_generator_(std::move(text_token_generator)),
203+ stats_(std::move(stats)),
204+ temperature_(temperature) {
205+ // Note: This constructor assumes that text_prefiller and text_token_generator
206+ // already have references to the Module and TextDecoderRunner they need
207+ }
208+
209+ bool Runner::is_loaded () const {
210+ return text_prefiller_->is_loaded () && text_token_generator_->is_loaded ();
211+ }
183212
213+ Error Runner::load () {
214+ if (is_loaded ()) {
215+ return Error::Ok;
216+ }
217+ ET_CHECK_OK_OR_RETURN_ERROR (text_prefiller_->load ());
218+ ET_CHECK_OK_OR_RETURN_ERROR (text_token_generator_->load ());
184219 return Error::Ok;
185220}
186221
@@ -201,9 +236,9 @@ Error Runner::generate(
201236 // Use ones-initialized inputs.
202237 ET_CHECK_MSG (!prompt.empty (), " Prompt cannot be null" );
203238 if (!is_loaded ()) {
204- stats_. model_load_start_ms = llm::time_in_ms ();
239+ stats_-> model_load_start_ms = llm::time_in_ms ();
205240 ET_CHECK_OK_OR_RETURN_ERROR (load ());
206- stats_. model_load_end_ms = llm::time_in_ms ();
241+ stats_-> model_load_end_ms = llm::time_in_ms ();
207242 }
208243
209244 if (config.warming ) {
@@ -229,7 +264,7 @@ Error Runner::generate(
229264 // First token time only measures the time it takes to encode the prompt and
230265 // return a response token.
231266
232- stats_. inference_start_ms = llm::time_in_ms ();
267+ stats_-> inference_start_ms = llm::time_in_ms ();
233268 shouldStop_ = false ;
234269
235270 ::tokenizers::Result<std::vector<uint64_t >> encode_res = tokenizer_->encode (
@@ -270,8 +305,8 @@ Error Runner::generate(
270305 auto prefill_res = text_prefiller_->prefill (prompt_tokens, pos);
271306 ET_CHECK_OK_OR_RETURN_ERROR (prefill_res.error ());
272307 uint64_t cur_token = prefill_res.get ();
273- stats_. first_token_ms = llm::time_in_ms ();
274- stats_. prompt_eval_end_ms = llm::time_in_ms ();
308+ stats_-> first_token_ms = llm::time_in_ms ();
309+ stats_-> prompt_eval_end_ms = llm::time_in_ms ();
275310
276311 // print the first token from prefill. No prev_token so use cur_token for it.
277312 wrapped_callback (
@@ -292,7 +327,7 @@ Error Runner::generate(
292327 temperature_ == -1 .0f ? config.temperature : temperature_,
293328 wrapped_callback));
294329
295- stats_. inference_end_ms = llm::time_in_ms ();
330+ stats_-> inference_end_ms = llm::time_in_ms ();
296331 if (!config.warming ) {
297332 printf (" \n " );
298333 }
@@ -305,17 +340,17 @@ Error Runner::generate(
305340 RUNNER_ET_LOG (config.warming , " Max new tokens %i reached!" , max_new_tokens);
306341 }
307342
308- stats_. num_prompt_tokens = num_prompt_tokens;
309- stats_. num_generated_tokens = num_generated_tokens;
343+ stats_-> num_prompt_tokens = num_prompt_tokens;
344+ stats_-> num_generated_tokens = num_generated_tokens;
310345
311346 if (config.warming ) {
312347 ET_LOG (Info, " Warmup run finished!" );
313348 } else {
314349 // Do not print report during warmup
315- ::executorch::llm::print_report (stats_);
350+ ::executorch::llm::print_report (* stats_);
316351 }
317352 if (stats_callback) {
318- stats_callback (stats_);
353+ stats_callback (* stats_);
319354 }
320355
321356 return Error::Ok;
@@ -329,8 +364,8 @@ Error Runner::warmup(const std::string& prompt, int32_t max_new_tokens) {
329364 // Call generate with the warmup config
330365 Error err = generate (prompt, config);
331366
332- // Reset stats after warmup
333- stats_. reset ();
367+ // Reset stats after warmup, not resetting the std::unique_ptr!
368+ stats_-> reset ();
334369 return err;
335370}
336371
0 commit comments