@@ -506,7 +506,7 @@ def llama_mlock_supported() -> bool:
506506_lib .llama_mlock_supported .restype = c_bool
507507
508508
# LLAMA_API int llama_n_vocab(const struct llama_context * ctx);
def llama_n_vocab(ctx: llama_context_p) -> int:
    """Return the vocabulary size of the model loaded into *ctx*."""
    result = _lib.llama_n_vocab(ctx)
    return result


_lib.llama_n_vocab.restype = c_int
517517
# LLAMA_API int llama_n_ctx(const struct llama_context * ctx);
def llama_n_ctx(ctx: llama_context_p) -> int:
    """Return the context window size (in tokens) configured for *ctx*."""
    result = _lib.llama_n_ctx(ctx)
    return result


_lib.llama_n_ctx.restype = c_int
526526
# LLAMA_API int llama_n_ctx_train(const struct llama_context * ctx);
def llama_n_ctx_train(ctx: llama_context_p) -> int:
    """Return the context size the underlying model was trained with."""
    result = _lib.llama_n_ctx_train(ctx)
    return result


_lib.llama_n_ctx_train.restype = c_int
_lib.llama_n_ctx_train.argtypes = [llama_context_p]
# LLAMA_API int llama_n_embd (const struct llama_context * ctx);
def llama_n_embd(ctx: llama_context_p) -> int:
    """Return the embedding dimension of the model loaded into *ctx*."""
    result = _lib.llama_n_embd(ctx)
    return result
@@ -542,7 +551,7 @@ def llama_vocab_type(ctx: llama_context_p) -> int:
542551_lib .llama_vocab_type .restype = c_int
543552
544553
# LLAMA_API int llama_model_n_vocab(const struct llama_model * model);
def llama_model_n_vocab(model: llama_model_p) -> int:
    """Return the vocabulary size of *model* (no context required)."""
    result = _lib.llama_model_n_vocab(model)
    return result


_lib.llama_model_n_vocab.restype = c_int
552561
553562
# LLAMA_API int llama_model_n_ctx(const struct llama_model * model);
def llama_model_n_ctx(model: llama_model_p) -> int:
    """Return the context size associated with *model* (no context required)."""
    result = _lib.llama_model_n_ctx(model)
    return result


_lib.llama_model_n_ctx.restype = c_int
561570
562571
# LLAMA_API int llama_model_n_ctx_train(const struct llama_model * model);
def llama_model_n_ctx_train(model: llama_model_p) -> int:
    """Return the training-time context size of *model* (no context required)."""
    result = _lib.llama_model_n_ctx_train(model)
    return result


_lib.llama_model_n_ctx_train.restype = c_int
_lib.llama_model_n_ctx_train.argtypes = [llama_model_p]
579+
580+
# LLAMA_API int llama_model_n_embd (const struct llama_model * model);
def llama_model_n_embd(model: llama_model_p) -> int:
    """Return the embedding dimension of *model* (no context required)."""
    result = _lib.llama_model_n_embd(model)
    return result
566584
@@ -1046,74 +1064,14 @@ def llama_grammar_free(grammar: llama_grammar_p):
10461064_lib .llama_grammar_free .argtypes = [llama_grammar_p ]
10471065_lib .llama_grammar_free .restype = None
10481066
# //
# // Beam search
# //


# struct llama_beam_view {
#     const llama_token * tokens;
#     size_t n_tokens;
#     float p;   // Cumulative beam probability (renormalized relative to all beams)
#     bool eob;  // Callback should set this to true when a beam is at end-of-beam.
# };
class llama_beam_view(ctypes.Structure):
    """Read-only view of a single beam, passed to the beam-search callback."""

    _fields_ = [
        ("tokens", llama_token_p),
        ("n_tokens", c_size_t),
        ("p", c_float),
        ("eob", c_bool),
    ]
10671067
# LLAMA_API struct llama_grammar * llama_grammar_copy(const struct llama_grammar * grammar);
def llama_grammar_copy(grammar: llama_grammar_p) -> llama_grammar_p:
    """Return a copy of *grammar* as a new llama_grammar pointer."""
    result = _lib.llama_grammar_copy(grammar)
    return result
10681071
# // Passed to beam_search_callback function.
# // Whenever 0 < common_prefix_length, this number of tokens should be copied from any of the beams
# // (e.g. beams[0]) as they will be removed (shifted) from all beams in all subsequent callbacks.
# // These pointers are valid only during the synchronous callback, so should not be saved.
# struct llama_beams_state {
#     struct llama_beam_view * beam_views;
#     size_t n_beams;              // Number of elements in beam_views[].
#     size_t common_prefix_length; // Current max length of prefix tokens shared by all beams.
#     bool last_call;              // True iff this is the last callback invocation.
# };
class llama_beams_state(ctypes.Structure):
    """Snapshot of all beams handed to the beam-search callback each iteration."""

    _fields_ = [
        ("beam_views", POINTER(llama_beam_view)),
        ("n_beams", c_size_t),
        ("common_prefix_length", c_size_t),
        ("last_call", c_bool),
    ]
1086-
1087-
# // Type of pointer to the beam_search_callback function.
# // void* callback_data is any custom data passed to llama_beam_search, that is subsequently
# // passed back to beam_search_callback. This avoids having to use global variables in the callback.
# typedef void (*llama_beam_search_callback_fn_t)(void * callback_data, struct llama_beams_state);
llama_beam_search_callback_fn_t = ctypes.CFUNCTYPE(None, c_void_p, llama_beams_state)
1093-
1094-
# LLAMA_API void llama_beam_search(struct llama_context * ctx, llama_beam_search_callback_fn_t callback, void * callback_data, size_t n_beams, int n_past, int n_predict, int n_threads);
def llama_beam_search(
    ctx: llama_context_p,
    callback: "ctypes._CFuncPtr[None, c_void_p, llama_beams_state]",  # type: ignore
    callback_data: c_void_p,
    n_beams: c_size_t,
    n_past: c_int,
    n_predict: c_int,
    n_threads: c_int,
):
    """Deterministically return an entire sentence constructed by beam search.

    :param ctx: pointer to the llama_context.
    :param callback: invoked for each iteration of the beam_search loop,
        passing in beams_state.
    :param callback_data: a pointer that is simply passed back to callback.
    :param n_beams: number of beams to use.
    :param n_past: number of tokens already evaluated.
    :param n_predict: maximum number of tokens to predict; EOS may occur earlier.
    :param n_threads: number of threads as passed to llama_eval().
    """
    return _lib.llama_beam_search(
        ctx, callback, callback_data, n_beams, n_past, n_predict, n_threads
    )
11161072
_lib.llama_grammar_copy.restype = llama_grammar_p
_lib.llama_grammar_copy.argtypes = [llama_grammar_p]
11171075
11181076# //
11191077# // Sampling functions
@@ -1436,6 +1394,74 @@ def llama_grammar_accept_token(
14361394 llama_token ,
14371395]
14381396_lib .llama_grammar_accept_token .restype = None
# //
# // Beam search
# //


# struct llama_beam_view {
#     const llama_token * tokens;
#     size_t n_tokens;
#     float p;   // Cumulative beam probability (renormalized relative to all beams)
#     bool eob;  // Callback should set this to true when a beam is at end-of-beam.
# };
class llama_beam_view(ctypes.Structure):
    """Read-only view of a single beam, passed to the beam-search callback."""

    _fields_ = [
        ("tokens", llama_token_p),
        ("n_tokens", c_size_t),
        ("p", c_float),
        ("eob", c_bool),
    ]
1415+
1416+
# // Passed to beam_search_callback function.
# // Whenever 0 < common_prefix_length, this number of tokens should be copied from any of the beams
# // (e.g. beams[0]) as they will be removed (shifted) from all beams in all subsequent callbacks.
# // These pointers are valid only during the synchronous callback, so should not be saved.
# struct llama_beams_state {
#     struct llama_beam_view * beam_views;
#     size_t n_beams;              // Number of elements in beam_views[].
#     size_t common_prefix_length; // Current max length of prefix tokens shared by all beams.
#     bool last_call;              // True iff this is the last callback invocation.
# };
class llama_beams_state(ctypes.Structure):
    """Snapshot of all beams handed to the beam-search callback each iteration."""

    _fields_ = [
        ("beam_views", POINTER(llama_beam_view)),
        ("n_beams", c_size_t),
        ("common_prefix_length", c_size_t),
        ("last_call", c_bool),
    ]
1434+
1435+
# // Type of pointer to the beam_search_callback function.
# // void* callback_data is any custom data passed to llama_beam_search, that is subsequently
# // passed back to beam_search_callback. This avoids having to use global variables in the callback.
# typedef void (*llama_beam_search_callback_fn_t)(void * callback_data, struct llama_beams_state);
llama_beam_search_callback_fn_t = ctypes.CFUNCTYPE(None, c_void_p, llama_beams_state)
1441+
1442+
# LLAMA_API void llama_beam_search(struct llama_context * ctx, llama_beam_search_callback_fn_t callback, void * callback_data, size_t n_beams, int n_past, int n_predict, int n_threads);
def llama_beam_search(
    ctx: llama_context_p,
    callback: "ctypes._CFuncPtr[None, c_void_p, llama_beams_state]",  # type: ignore
    callback_data: c_void_p,
    n_beams: c_size_t,
    n_past: c_int,
    n_predict: c_int,
    n_threads: c_int,
):
    """Deterministically return an entire sentence constructed by beam search.

    :param ctx: pointer to the llama_context.
    :param callback: invoked for each iteration of the beam_search loop,
        passing in beams_state.
    :param callback_data: a pointer that is simply passed back to callback.
    :param n_beams: number of beams to use.
    :param n_past: number of tokens already evaluated.
    :param n_predict: maximum number of tokens to predict; EOS may occur earlier.
    :param n_threads: number of threads as passed to llama_eval().
    """
    return _lib.llama_beam_search(
        ctx, callback, callback_data, n_beams, n_past, n_predict, n_threads
    )
1464+
14391465
14401466# Performance information
14411467
@@ -1494,6 +1520,7 @@ def llama_log_set(
def llama_dump_timing_info_yaml(stream: ctypes.c_void_p, ctx: llama_context_p):
    """Dump timing information for *ctx* to *stream* in YAML form."""
    return _lib.llama_dump_timing_info_yaml(stream, ctx)


_lib.llama_dump_timing_info_yaml.restype = None
_lib.llama_dump_timing_info_yaml.argtypes = [ctypes.c_void_p, llama_context_p]
14991526
0 commit comments