@@ -545,13 +545,65 @@ fn llm_infer(
545545 . map ( LLMInferencingResult :: from)
546546}
547547
548+ #[ pyo3:: pyfunction]
549+ fn generate_embeddings ( model : & str , text : Vec < String > ) -> Result < LLMEmbeddingsResult , Anyhow > {
550+ let model = match model {
551+ "all-minilm-l6-v2" => llm:: EmbeddingModel :: AllMiniLmL6V2 ,
552+ _ => llm:: EmbeddingModel :: Other ( model) ,
553+ } ;
554+
555+ let text = text. iter ( ) . map ( |s| s. as_str ( ) ) . collect :: < Vec < _ > > ( ) ;
556+
557+ llm:: generate_embeddings ( model, & text)
558+ . map_err ( Anyhow :: from)
559+ . map ( LLMEmbeddingsResult :: from)
560+ }
561+
562+ #[ derive( Clone ) ]
563+ #[ pyo3:: pyclass]
564+ #[ pyo3( name = "LLMEmbeddingsUsage" ) ]
565+ struct LLMEmbeddingsUsage {
566+ #[ pyo3( get) ]
567+ prompt_token_count : u32 ,
568+ }
569+
570+ impl From < llm:: EmbeddingsUsage > for LLMEmbeddingsUsage {
571+ fn from ( result : llm:: EmbeddingsUsage ) -> Self {
572+ LLMEmbeddingsUsage {
573+ prompt_token_count : result. prompt_token_count ,
574+ }
575+ }
576+ }
577+
578+ #[ derive( Clone ) ]
579+ #[ pyo3:: pyclass]
580+ #[ pyo3( name = "LLMEmbeddingResult" ) ]
581+ struct LLMEmbeddingsResult {
582+ #[ pyo3( get) ]
583+ embeddings : Vec < Vec < f32 > > ,
584+ #[ pyo3( get) ]
585+ usage : LLMEmbeddingsUsage ,
586+ }
587+
588+ impl From < llm:: EmbeddingsResult > for LLMEmbeddingsResult {
589+ fn from ( result : llm:: EmbeddingsResult ) -> Self {
590+ LLMEmbeddingsResult {
591+ embeddings : result. embeddings ,
592+ usage : LLMEmbeddingsUsage :: from ( result. usage ) ,
593+ }
594+ }
595+ }
596+
548597#[ pyo3:: pymodule]
549598#[ pyo3( name = "spin_llm" ) ]
550599fn spin_llm_module ( _py : Python < ' _ > , module : & PyModule ) -> PyResult < ( ) > {
551600 module. add_function ( pyo3:: wrap_pyfunction!( llm_infer, module) ?) ?;
601+ module. add_function ( pyo3:: wrap_pyfunction!( generate_embeddings, module) ?) ?;
552602 module. add_class :: < LLMInferencingUsage > ( ) ?;
553603 module. add_class :: < LLMInferencingParams > ( ) ?;
554- module. add_class :: < LLMInferencingResult > ( )
604+ module. add_class :: < LLMInferencingResult > ( ) ?;
605+ module. add_class :: < LLMEmbeddingsUsage > ( ) ?;
606+ module. add_class :: < LLMEmbeddingsResult > ( )
555607}
556608
557609pub fn run_ctors ( ) {
0 commit comments