@@ -717,10 +717,58 @@ def create_embedding(
         Returns:
             An embedding object.
         """
-        assert self._ctx.ctx is not None
         assert self._model.model is not None
         model_name: str = model if model is not None else self.model_path

+        # get numeric embeddings
+        embeds: List[List[float]]
+        total_tokens: int
+        embeds, total_tokens = self.embed(input, return_count=True)  # type: ignore
+
+        # convert to CreateEmbeddingResponse
+        data: List[Embedding] = [
+            {
+                "object": "embedding",
+                "embedding": emb,
+                "index": idx,
+            }
+            for idx, emb in enumerate(embeds)
+        ]
+
+        return {
+            "object": "list",
+            "data": data,
+            "model": model_name,
+            "usage": {
+                "prompt_tokens": total_tokens,
+                "total_tokens": total_tokens,
+            },
+        }
+
+    def embed(
+        self,
+        input: Union[str, List[str]],
+        normalize: bool = True,
+        truncate: bool = True,
+        return_count: bool = False,
+    ):
+        """Embed a string or list of strings.
+
+        Args:
+            input: The utf-8 encoded string (or list of strings) to embed.
+            normalize: Normalize each embedding to unit length; otherwise
+                divide each component by the sequence length.
+            truncate: Truncate inputs longer than the context window.
+            return_count: Also return the total number of tokens processed.
+
+        Returns:
+            A list of embeddings, or a single embedding if input was a string.
+            With return_count=True, a (embeddings, token_count) tuple.
+        """
+        assert self._ctx.ctx is not None
+        n_embd = self.n_embd()
+        n_ctx = self.n_ctx()
+
         if self.context_params.embedding == False:
             raise RuntimeError(
                 "Llama model must be created with embedding=True to call this method"
@@ -734,48 +782,74 @@ def create_embedding(
         else:
             inputs = input

-        data: List[Embedding] = []
+        # reset batch
+        self._batch.reset()
+
+        # decode and fetch embeddings
+        data: List[List[float]] = []
+        def decode_batch(sizes: List[int]):
+            assert self._ctx.ctx is not None
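+            # clear the KV cache so state from the previous batch does not leak in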
+            llama_cpp.llama_kv_cache_clear(self._ctx.ctx)
+            self._ctx.decode(self._batch)
+            self._batch.reset()
+
+            # store embeddings
+            for i, s in enumerate(sizes):
+                embedding = llama_cpp.llama_get_embeddings_ith(self._ctx.ctx, i)[
+                    :n_embd
+                ]
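+                # scale by the L2 norm, or by sequence length when normalize=False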
+                norm = np.linalg.norm(embedding) if normalize else s
+                embedding: List[float] = [v / float(norm) for v in embedding]
+                data.append(embedding)
+
+        # init state
         total_tokens = 0
-        for index, input in enumerate(inputs):
-            tokens = self.tokenize(input.encode("utf-8"), special=True)
-            self.reset()
-            self.eval(tokens)
+        t_batch = 0
+        s_sizes: List[int] = []
+
+        # accumulate batches and encode
+        for text in inputs:
+            tokens = self.tokenize(text.encode("utf-8"))
+            if truncate:
+                tokens = tokens[:n_ctx]
+
             n_tokens = len(tokens)
             total_tokens += n_tokens
-            embedding = llama_cpp.llama_get_embeddings(self._ctx.ctx)[
-                : llama_cpp.llama_n_embd(self._model.model)
-            ]

-            data.append(
-                {
-                    "object": "embedding",
-                    "embedding": embedding,
-                    "index": index,
-                }
-            )
+            # check for overrun
+            if n_tokens > n_ctx:
+                raise ValueError(
+                    f"Requested tokens ({n_tokens}) exceed context window of {n_ctx}"
+                )
+
+            # time to eval batch
+            if t_batch + n_tokens > self._n_ctx:
+                decode_batch(s_sizes)
+                t_batch = 0
+                s_sizes = []
+
+            # add to batch
+            self._batch.add_sequence(tokens, len(s_sizes), False)
+            t_batch += n_tokens
+            s_sizes.append(n_tokens)
+
+        # handle last batch
+        decode_batch(s_sizes)
+
         if self.verbose:
             llama_cpp.llama_print_timings(self._ctx.ctx)

-        return {
-            "object": "list",
-            "data": data,
-            "model": model_name,
-            "usage": {
-                "prompt_tokens": total_tokens,
-                "total_tokens": total_tokens,
-            },
-        }
-
-    def embed(self, input: str) -> List[float]:
-        """Embed a string.
+        output = data[0] if isinstance(input, str) else data

-        Args:
-            input: The utf-8 encoded string to embed.
+        llama_cpp.llama_kv_cache_clear(self._ctx.ctx)
+        self.reset()

-        Returns:
-            A list of embeddings
-        """
-        return list(map(float, self.create_embedding(input)["data"][0]["embedding"]))
+        if return_count:
+            return output, total_tokens
+        else:
+            return output

     def _create_completion(
         self,
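
A minimal usage sketch of the new embed() API, for orientation (the model path is a placeholder; behavior follows the diff above):

    from llama_cpp import Llama

    # embedding=True is required, per the RuntimeError guard in embed()
    llm = Llama(model_path="./models/example.gguf", embedding=True)

    # a single string returns one vector; a list returns a list of vectors
    vec = llm.embed("hello world")
    vecs, n_tokens = llm.embed(["first text", "second text"], return_count=True)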