
Commit 5e9bb75

add bindings for llm/generate-embeddings and related types (#53)
Signed-off-by: Joel Dice <joel.dice@fermyon.com>
1 parent 55377a2 commit 5e9bb75

3 files changed: +64 -4 lines changed

crates/spin-python-engine/src/lib.rs

Lines changed: 53 additions & 1 deletion

@@ -545,13 +545,65 @@ fn llm_infer(
         .map(LLMInferencingResult::from)
 }
 
+#[pyo3::pyfunction]
+fn generate_embeddings(model: &str, text: Vec<String>) -> Result<LLMEmbeddingsResult, Anyhow> {
+    let model = match model {
+        "all-minilm-l6-v2" => llm::EmbeddingModel::AllMiniLmL6V2,
+        _ => llm::EmbeddingModel::Other(model),
+    };
+
+    let text = text.iter().map(|s| s.as_str()).collect::<Vec<_>>();
+
+    llm::generate_embeddings(model, &text)
+        .map_err(Anyhow::from)
+        .map(LLMEmbeddingsResult::from)
+}
+
+#[derive(Clone)]
+#[pyo3::pyclass]
+#[pyo3(name = "LLMEmbeddingsUsage")]
+struct LLMEmbeddingsUsage {
+    #[pyo3(get)]
+    prompt_token_count: u32,
+}
+
+impl From<llm::EmbeddingsUsage> for LLMEmbeddingsUsage {
+    fn from(result: llm::EmbeddingsUsage) -> Self {
+        LLMEmbeddingsUsage {
+            prompt_token_count: result.prompt_token_count,
+        }
+    }
+}
+
+#[derive(Clone)]
+#[pyo3::pyclass]
+#[pyo3(name = "LLMEmbeddingsResult")]
+struct LLMEmbeddingsResult {
+    #[pyo3(get)]
+    embeddings: Vec<Vec<f32>>,
+    #[pyo3(get)]
+    usage: LLMEmbeddingsUsage,
+}
+
+impl From<llm::EmbeddingsResult> for LLMEmbeddingsResult {
+    fn from(result: llm::EmbeddingsResult) -> Self {
+        LLMEmbeddingsResult {
+            embeddings: result.embeddings,
+            usage: LLMEmbeddingsUsage::from(result.usage),
+        }
+    }
+}
+
 #[pyo3::pymodule]
 #[pyo3(name = "spin_llm")]
 fn spin_llm_module(_py: Python<'_>, module: &PyModule) -> PyResult<()> {
     module.add_function(pyo3::wrap_pyfunction!(llm_infer, module)?)?;
+    module.add_function(pyo3::wrap_pyfunction!(generate_embeddings, module)?)?;
     module.add_class::<LLMInferencingUsage>()?;
     module.add_class::<LLMInferencingParams>()?;
-    module.add_class::<LLMInferencingResult>()
+    module.add_class::<LLMInferencingResult>()?;
+    module.add_class::<LLMEmbeddingsUsage>()?;
+    module.add_class::<LLMEmbeddingsResult>()
 }
 
 pub fn run_ctors() {
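
In Python terms, these bindings expose one new function, generate_embeddings(model, texts), plus the LLMEmbeddingsResult and LLMEmbeddingsUsage classes with read-only fields. A minimal sketch of the resulting API from a component's point of view, based only on the bindings above:

    from spin_llm import generate_embeddings

    # "all-minilm-l6-v2" maps to llm::EmbeddingModel::AllMiniLmL6V2 on the
    # Rust side; any other string falls through to llm::EmbeddingModel::Other.
    result = generate_embeddings("all-minilm-l6-v2", ["hat", "cat", "bat"])

    vectors = result.embeddings                # Vec<Vec<f32>> -> list of lists of float
    tokens = result.usage.prompt_token_count   # u32 -> int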

examples/llm/app.py

Lines changed: 10 additions & 2 deletions

@@ -1,10 +1,18 @@
+import json
 from spin_http import Response
-from spin_llm import llm_infer
+from spin_llm import llm_infer, generate_embeddings
 
 
 def handle_request(request):
     prompt="You are a stand up comedy writer. Tell me a joke."
     result=llm_infer("llama2-chat", prompt)
+
+    embeddings = generate_embeddings("all-minilm-l6-v2", ["hat", "cat", "bat"])
+
+    body = (f"joke: {result.text}\n\n"
+            f"embeddings: {json.dumps(embeddings.embeddings)}\n"
+            f"prompt token count: {embeddings.usage.prompt_token_count}")
+
     return Response(200,
                     {"content-type": "text/plain"},
-                    bytes(result.text, "utf-8"))
+                    bytes(body, "utf-8"))

examples/llm/spin.toml

Lines changed: 1 addition & 1 deletion

@@ -8,7 +8,7 @@ version = "0.1.0"
 [[component]]
 id = "python-sdk-example"
 source = "app.wasm"
-ai_models = ["llama2-chat"]
+ai_models = ["llama2-chat", "all-minilm-l6-v2"]
 [component.trigger]
 route = "/..."
 [component.build]
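
The ai_models change is what grants the component access to the new model: Spin only lets a component use models listed in its ai_models entry, which is why "all-minilm-l6-v2" is added alongside "llama2-chat". The Other branch in lib.rs suggests names beyond the built-in one are passed through rather than rejected; a hedged sketch of exercising it, with "my-custom-embedder" as a purely hypothetical name that would likewise have to appear in ai_models:

    # Hypothetical model name (not from this commit); the Rust binding would
    # wrap it as llm::EmbeddingModel::Other("my-custom-embedder"), and the
    # component would still need it granted via ai_models in spin.toml.
    result = generate_embeddings("my-custom-embedder", ["example text"])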
