@@ -374,6 +374,178 @@ static std::vector<size_t> unicode_regex_split_custom_gpt2(const std::string& te
374374 return bpe_offsets;
375375}
376376
377+ // K2 system regex patterns (from tokenization_kimi.py):
378+ // [\p{Han}]+|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]*[\p{Ll}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]+[\p{Ll}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+
379+ static std::vector<size_t > unicode_regex_split_custom_kimi_k2 (const std::string & text, const std::vector<size_t > & offsets) {
380+ std::vector<size_t > bpe_offsets;
381+ bpe_offsets.reserve (offsets.size ());
382+
383+ const auto cpts = unicode_cpts_from_utf8 (text);
384+
385+ size_t start = 0 ;
386+ for (auto offset : offsets) {
387+ const size_t offset_ini = start;
388+ const size_t offset_end = start + offset;
389+ assert (offset_end <= cpts.size ());
390+ start = offset_end;
391+
392+ static const uint32_t OUT_OF_RANGE = 0xFFFFFFFF ;
393+ auto _get_cpt = [&] (const size_t pos) -> uint32_t {
394+ return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : OUT_OF_RANGE;
395+ };
396+
397+ auto _get_flags = [&] (const size_t pos) -> codepoint_flags {
398+ return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags (cpts[pos]) : codepoint_flags{};
399+ };
400+
401+ size_t _prev_end = offset_ini;
402+ auto _add_token = [&] (const size_t end) -> size_t {
403+ assert (_prev_end <= end && end <= offset_end);
404+ size_t len = end - _prev_end;
405+ if (len > 0 ) {
406+ bpe_offsets.push_back (len);
407+ }
408+ _prev_end = end;
409+ return len;
410+ };
411+
412+ for (size_t pos = offset_ini; pos < offset_end; /* pos++*/ ) {
413+ const uint32_t cpt = _get_cpt (pos);
414+ const auto flags = _get_flags (pos);
415+
416+ // Pattern 1: [\p{Han}]+ (Chinese characters)
417+ if (unicode_cpt_is_han (cpt)) {
418+ while (unicode_cpt_is_han (_get_cpt (pos))) {
419+ pos++;
420+ }
421+ _add_token (pos);
422+ continue ;
423+ }
424+
425+ // Pattern 2 & 3: Letter words excluding Han characters with optional contractions
426+ // [^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]*[\p{Ll}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]+(?:'s|'t|'re|'ve|'m|'ll|'d)?
427+ // [^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]+[\p{Ll}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]*(?:'s|'t|'re|'ve|'m|'ll|'d)?
428+ // Check if current char is a letter OR if current char could be a leading char and next char is a letter
429+ bool is_letter_pattern = (flags.is_letter && !unicode_cpt_is_han (cpt)) ||
430+ (!(cpt == ' \r ' || cpt == ' \n ' || flags.is_letter || flags.is_number ) &&
431+ _get_flags (pos + 1 ).is_letter && !unicode_cpt_is_han (_get_cpt (pos + 1 )));
432+
433+ if (is_letter_pattern) {
434+ // Handle optional leading non-letter/non-number character
435+ bool has_leading_char = false ;
436+ if (!(cpt == ' \r ' || cpt == ' \n ' || flags.is_letter || flags.is_number )) {
437+ has_leading_char = true ;
438+ pos++;
439+ }
440+
441+ // Match letter sequence (excluding Han characters)
442+ bool has_letters = false ;
443+ while (_get_flags (pos).is_letter && !unicode_cpt_is_han (_get_cpt (pos))) {
444+ has_letters = true ;
445+ pos++;
446+ }
447+
448+ // Only proceed if we found letters (after potentially skipping leading char)
449+ if (has_letters || (!has_leading_char && _get_flags (pos).is_letter && !unicode_cpt_is_han (_get_cpt (pos)))) {
450+ if (!has_letters) pos++; // consume the first letter if we didn't already
451+
452+ // Continue consuming letters
453+ while (_get_flags (pos).is_letter && !unicode_cpt_is_han (_get_cpt (pos))) {
454+ pos++;
455+ }
456+
457+ // Check for optional contractions (?:'s|'t|'re|'ve|'m|'ll|'d)
458+ if (_get_cpt (pos) == ' \' ' && pos + 1 < offset_end) {
459+ uint32_t cpt_next = unicode_tolower (_get_cpt (pos + 1 ));
460+ if (cpt_next == ' s' || cpt_next == ' t' || cpt_next == ' m' || cpt_next == ' d' ) {
461+ pos += 2 ;
462+ } else if (pos + 2 < offset_end) {
463+ uint32_t cpt_next_next = unicode_tolower (_get_cpt (pos + 2 ));
464+ if ((cpt_next == ' r' && cpt_next_next == ' e' ) ||
465+ (cpt_next == ' v' && cpt_next_next == ' e' ) ||
466+ (cpt_next == ' l' && cpt_next_next == ' l' )) {
467+ pos += 3 ;
468+ }
469+ }
470+ }
471+
472+ _add_token (pos);
473+ continue ;
474+ } else if (has_leading_char) {
475+ // We consumed a leading char but found no letters, backtrack
476+ pos--;
477+ }
478+ }
479+
480+ // Pattern 4: \p{N}{1,3} (numbers 1-3 digits)
481+ if (flags.is_number ) {
482+ size_t ini = pos;
483+ while (_get_flags (pos).is_number ) {
484+ if (++pos - ini >= 3 ) {
485+ _add_token (pos);
486+ ini = pos;
487+ }
488+ }
489+ _add_token (pos);
490+ continue ;
491+ }
492+
493+ // Pattern 5: ?[^\s\p{L}\p{N}]+[\r\n]* (optional space + non-word chars + optional newlines)
494+ auto flags2 = (cpt == ' ' ? _get_flags (pos + 1 ) : flags);
495+ if (!(flags2.is_whitespace || flags2.is_letter || flags2.is_number ) && flags2.as_uint ()) {
496+ pos += (cpt == ' ' );
497+ while (!(flags2.is_whitespace || flags2.is_letter || flags2.is_number ) && flags2.as_uint ()) {
498+ flags2 = _get_flags (++pos);
499+ }
500+ // Match optional [\r\n]*
501+ uint32_t cpt2 = _get_cpt (pos);
502+ while (cpt2 == ' \r ' || cpt2 == ' \n ' ) {
503+ cpt2 = _get_cpt (++pos);
504+ }
505+ _add_token (pos);
506+ continue ;
507+ }
508+
509+ // Count whitespace characters
510+ size_t num_whitespaces = 0 ;
511+ size_t last_end_r_or_n = 0 ;
512+ while (_get_flags (pos + num_whitespaces).is_whitespace ) {
513+ uint32_t cpt2 = _get_cpt (pos + num_whitespaces);
514+ if (cpt2 == ' \r ' || cpt2 == ' \n ' ) {
515+ last_end_r_or_n = pos + num_whitespaces + 1 ;
516+ }
517+ num_whitespaces++;
518+ }
519+
520+ // Pattern 6: \s*[\r\n]+ (whitespace with newlines)
521+ if (last_end_r_or_n > 0 ) {
522+ pos = last_end_r_or_n;
523+ _add_token (pos);
524+ continue ;
525+ }
526+
527+ // Pattern 7: \s+(?!\S) (trailing whitespace)
528+ if (num_whitespaces > 1 && _get_cpt (pos + num_whitespaces) != OUT_OF_RANGE) {
529+ pos += num_whitespaces - 1 ;
530+ _add_token (pos);
531+ continue ;
532+ }
533+
534+ // Pattern 8: \s+ (general whitespace)
535+ if (num_whitespaces > 0 ) {
536+ pos += num_whitespaces;
537+ _add_token (pos);
538+ continue ;
539+ }
540+
541+ // No matches - consume single character
542+ _add_token (++pos);
543+ }
544+ }
545+
546+ return bpe_offsets;
547+ }
548+
377549// LLAMA3 system regex: "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"
378550static std::vector<size_t > unicode_regex_split_custom_llama3 (const std::string& text, const std::vector<size_t >& offsets) {
379551 std::vector<size_t > bpe_offsets; // store the offset of each word
@@ -587,6 +759,10 @@ static std::vector<size_t> unicode_regex_split_custom(const std::string& text, c
587759
588760 bpe_offsets = unicode_regex_split_custom_llama3 (text, offsets);
589761 }
762+ else if (regex_expr == " \\ p{Han}+" ) {
763+ // K2's first pattern - handle all K2 patterns together
764+ bpe_offsets = unicode_regex_split_custom_kimi_k2 (text, offsets);
765+ }
590766
591767 return bpe_offsets;
592768}
@@ -662,6 +838,38 @@ codepoint_flags unicode_cpt_flags(const std::string& utf8) {
662838 return unicode_cpt_flags (unicode_cpt_from_utf8 (utf8, offset));
663839}
664840
841+ bool unicode_cpt_is_han (uint32_t cpt) {
842+ // Han character ranges (Chinese/CJK characters)
843+ // CJK Unified Ideographs (most common)
844+ if (cpt >= 0x4E00 && cpt <= 0x9FFF ) return true ;
845+
846+ // CJK Extension A
847+ if (cpt >= 0x3400 && cpt <= 0x4DBF ) return true ;
848+
849+ // CJK Extension B
850+ if (cpt >= 0x20000 && cpt <= 0x2A6DF ) return true ;
851+
852+ // CJK Extension C
853+ if (cpt >= 0x2A700 && cpt <= 0x2B73F ) return true ;
854+
855+ // CJK Extension D
856+ if (cpt >= 0x2B740 && cpt <= 0x2B81F ) return true ;
857+
858+ // CJK Extension E
859+ if (cpt >= 0x2B820 && cpt <= 0x2CEAF ) return true ;
860+
861+ // CJK Extension F
862+ if (cpt >= 0x2CEB0 && cpt <= 0x2EBEF ) return true ;
863+
864+ // CJK Compatibility Ideographs
865+ if (cpt >= 0xF900 && cpt <= 0xFAFF ) return true ;
866+
867+ // CJK Compatibility Ideographs Supplement
868+ if (cpt >= 0x2F800 && cpt <= 0x2FA1F ) return true ;
869+
870+ return false ;
871+ }
872+
665873std::string unicode_byte_to_utf8 (uint8_t byte) {
666874 static std::unordered_map<uint8_t , std::string> map = unicode_byte_to_utf8_map ();
667875 return map.at (byte);
0 commit comments