@@ -568,13 +568,33 @@ def llama_model_n_embd(model: llama_model_p) -> int:
 
 
 # // Get a string describing the model type
-# LLAMA_API int llama_model_type(const struct llama_model * model, char * buf, size_t buf_size);
-def llama_model_type(model: llama_model_p, buf: bytes, buf_size: c_size_t) -> int:
-    return _lib.llama_model_type(model, buf, buf_size)
+# LLAMA_API int llama_model_desc(const struct llama_model * model, char * buf, size_t buf_size);
+def llama_model_desc(model: llama_model_p, buf: bytes, buf_size: c_size_t) -> int:
+    return _lib.llama_model_desc(model, buf, buf_size)
 
 
-_lib.llama_model_type.argtypes = [llama_model_p, c_char_p, c_size_t]
-_lib.llama_model_type.restype = c_int
+_lib.llama_model_desc.argtypes = [llama_model_p, c_char_p, c_size_t]
+_lib.llama_model_desc.restype = c_int
+
+
+# // Returns the total size of all the tensors in the model in bytes
+# LLAMA_API uint64_t llama_model_size(const struct llama_model * model);
+def llama_model_size(model: llama_model_p) -> int:
+    return _lib.llama_model_size(model)
+
+
+_lib.llama_model_size.argtypes = [llama_model_p]
+_lib.llama_model_size.restype = ctypes.c_uint64
+
+
+# // Returns the total number of parameters in the model
+# LLAMA_API uint64_t llama_model_n_params(const struct llama_model * model);
+def llama_model_n_params(model: llama_model_p) -> int:
+    return _lib.llama_model_n_params(model)
+
+
+_lib.llama_model_n_params.argtypes = [llama_model_p]
+_lib.llama_model_n_params.restype = ctypes.c_uint64
 
 
 # // Returns 0 on success
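
A minimal usage sketch for the renamed and new bindings above (not part of the diff). It assumes `model` is a handle obtained elsewhere, e.g. via llama_load_model_from_file, and that llama_model_desc, llama_model_size, and llama_model_n_params are imported from this bindings module:

    import ctypes

    # llama_model_desc fills a caller-owned buffer with a short description of the
    # model and returns an int (snprintf-style character count).
    buf = ctypes.create_string_buffer(1024)
    llama_model_desc(model, buf, ctypes.sizeof(buf))
    print(buf.value.decode("utf-8"))

    # Whole-model metadata: total tensor size in bytes and total parameter count.
    print(llama_model_size(model), "bytes")
    print(llama_model_n_params(model), "parameters")
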
@@ -1029,6 +1049,74 @@ def llama_grammar_free(grammar: llama_grammar_p):
 _lib.llama_grammar_free.argtypes = [llama_grammar_p]
 _lib.llama_grammar_free.restype = None
 
+# //
+# // Beam search
+# //
+
+
+# struct llama_beam_view {
+#     const llama_token * tokens;
+#     size_t n_tokens;
+#     float p;  // Cumulative beam probability (renormalized relative to all beams)
+#     bool eob; // Callback should set this to true when a beam is at end-of-beam.
+# };
+class llama_beam_view(ctypes.Structure):
+    _fields_ = [
+        ("tokens", llama_token_p),
+        ("n_tokens", c_size_t),
+        ("p", c_float),
+        ("eob", c_bool),
+    ]
+
+
+# // Passed to beam_search_callback function.
+# // Whenever 0 < common_prefix_length, this number of tokens should be copied from any of the beams
+# // (e.g. beams[0]) as they will be removed (shifted) from all beams in all subsequent callbacks.
+# // These pointers are valid only during the synchronous callback, so should not be saved.
+# struct llama_beams_state {
+#     struct llama_beam_view * beam_views;
+#     size_t n_beams;               // Number of elements in beam_views[].
+#     size_t common_prefix_length;  // Current max length of prefix tokens shared by all beams.
+#     bool last_call;               // True iff this is the last callback invocation.
+# };
+class llama_beams_state(ctypes.Structure):
+    _fields_ = [
+        ("beam_views", POINTER(llama_beam_view)),
+        ("n_beams", c_size_t),
+        ("common_prefix_length", c_size_t),
+        ("last_call", c_bool),
+    ]
+
+
+# // Type of pointer to the beam_search_callback function.
+# // void* callback_data is any custom data passed to llama_beam_search, that is subsequently
+# // passed back to beam_search_callback. This avoids having to use global variables in the callback.
+# typedef void (*llama_beam_search_callback_fn_t)(void * callback_data, llama_beams_state);
+llama_beam_search_callback_fn_t = ctypes.CFUNCTYPE(None, c_void_p, llama_beams_state)
+
+
+# /// @details Deterministically returns entire sentence constructed by a beam search.
+# /// @param ctx Pointer to the llama_context.
+# /// @param callback Invoked for each iteration of the beam_search loop, passing in beams_state.
+# /// @param callback_data A pointer that is simply passed back to callback.
+# /// @param n_beams Number of beams to use.
+# /// @param n_past Number of tokens already evaluated.
+# /// @param n_predict Maximum number of tokens to predict. EOS may occur earlier.
+# /// @param n_threads Number of threads as passed to llama_eval().
+# LLAMA_API void llama_beam_search(struct llama_context * ctx, llama_beam_search_callback_fn_t callback, void * callback_data, size_t n_beams, int n_past, int n_predict, int n_threads);
+def llama_beam_search(
+    ctx: llama_context_p,
+    callback: "ctypes._CFuncPtr[None, c_void_p, llama_beams_state]",  # type: ignore
+    callback_data: c_void_p,
+    n_beams: c_size_t,
+    n_past: c_int,
+    n_predict: c_int,
+    n_threads: c_int,
+):
+    return _lib.llama_beam_search(
+        ctx, callback, callback_data, n_beams, n_past, n_predict, n_threads
+    )
+
 
 # //
 # // Sampling functions
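
A rough sketch of driving the beam-search bindings above from Python (not part of the diff). It assumes `ctx` and `n_past` come from an earlier llama_eval call, and the beam, predict, and thread counts below are arbitrary. The Python function is wrapped in llama_beam_search_callback_fn_t and stays referenced for the duration of the synchronous call:

    @llama_beam_search_callback_fn_t
    def show_beams(callback_data, beams_state):
        # beam_views[i] is a view into library-owned memory that is only valid
        # during this callback; setting .eob = True would mark a beam as finished.
        for i in range(beams_state.n_beams):
            view = beams_state.beam_views[i]
            print(f"beam {i}: p={view.p:.4f}, n_tokens={view.n_tokens}")

    llama_beam_search(ctx, show_beams, None, 4, n_past, 64, 4)
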