diff --git a/apis/python/src/tiledb/vector_search/embeddings/__init__.py b/apis/python/src/tiledb/vector_search/embeddings/__init__.py index 5d3538370..b9254f848 100644 --- a/apis/python/src/tiledb/vector_search/embeddings/__init__.py +++ b/apis/python/src/tiledb/vector_search/embeddings/__init__.py @@ -2,6 +2,7 @@ from .image_resnetv2_embedding import ImageResNetV2Embedding from .langchain_embedding import LangChainEmbedding from .object_embedding import ObjectEmbedding +from .ollama_embedding import OllamaEmbedding from .random_embedding import RandomEmbedding from .sentence_transformers_embedding import SentenceTransformersEmbedding from .soma_geneptw_embedding import SomaGenePTwEmbedding @@ -18,4 +19,5 @@ "LangChainEmbedding", "SomaScGPTEmbedding", "SomaSCVIEmbedding", + "OllamaEmbedding", ] diff --git a/apis/python/src/tiledb/vector_search/embeddings/ollama_embedding.py b/apis/python/src/tiledb/vector_search/embeddings/ollama_embedding.py new file mode 100644 index 000000000..dabefead1 --- /dev/null +++ b/apis/python/src/tiledb/vector_search/embeddings/ollama_embedding.py @@ -0,0 +1,49 @@ +from typing import Dict, Optional, OrderedDict, Sequence, Union + +import numpy as np + +# from tiledb.vector_search.embeddings import ObjectEmbedding + + +class OllamaEmbedding: + """ + Embedding functions from Ollama. + + This attempts to import the embedding_class from the ollama module. + """ + + def __init__( + self, + dimensions: int, + embedding_class: str = "embed", # really it's the method + embedding_kwargs: Optional[Dict] = None, + ): + self.dim_num = dimensions + self.embedding_class = embedding_class + self.embedding_kwargs = embedding_kwargs + + def init_kwargs(self) -> Dict: + return { + "dimensions": self.dim_num, + "embedding_class": self.embedding_class, + "embedding_kwargs": self.embedding_kwargs, + } + + def dimensions(self) -> int: + return self.dim_num + + def vector_type(self) -> np.dtype: + return np.float32 + + def load(self) -> None: + import importlib + + try: + embeddings_module = importlib.import_module("ollama") + embedding_method_ = getattr(embeddings_module, self.embedding_class) + self.embedding = embedding_method_(**self.embedding_kwargs) + except ImportError as e: + print(e) + + def embed(self, objects: Union[str, Sequence[str]]) -> np.ndarray: + return np.array(self.embedding(input=objects).embeddings, dtype=np.float32) diff --git a/apis/python/test/test_ingestion.py b/apis/python/test/test_ingestion.py index 935482162..172d6c039 100644 --- a/apis/python/test/test_ingestion.py +++ b/apis/python/test/test_ingestion.py @@ -2012,6 +2012,74 @@ def test_ivf_flat_taskgraph_query(tmp_path): assert accuracy(result, gt_i) > MINIMUM_ACCURACY +def test_ollama_embedding(): + """Test OllamaEmbedding class with mocked ollama library.""" + from unittest.mock import MagicMock + from unittest.mock import Mock + from unittest.mock import patch + + from tiledb.vector_search.embeddings import OllamaEmbedding + + # Test initialization + dimensions = 384 + embedding_class = "embed" + embedding_kwargs = {"model": "nomic-embed-text"} + + embedding = OllamaEmbedding( + dimensions=dimensions, + embedding_class=embedding_class, + embedding_kwargs=embedding_kwargs, + ) + + # Test dimensions() method + assert embedding.dimensions() == dimensions + + # Test vector_type() method + assert embedding.vector_type() == np.float32 + + # Test init_kwargs() method + init_kwargs = embedding.init_kwargs() + assert init_kwargs["dimensions"] == dimensions + assert init_kwargs["embedding_class"] == embedding_class + assert init_kwargs["embedding_kwargs"] == embedding_kwargs + + # Mock the ollama module + mock_ollama = MagicMock() + + # Create a mock embedding result with the expected structure + mock_embed_result = Mock() + mock_embed_result.embeddings = [ + [0.1] * dimensions, # 384 dimensions for first text + [0.2] * dimensions, # 384 dimensions for second text + ] + + # Create a mock callable that will be returned by embed(**kwargs) + mock_callable = Mock(return_value=mock_embed_result) + + # Mock the embed function to return our callable when called with **embedding_kwargs + mock_ollama.embed = Mock(return_value=mock_callable) + + # Patch the importlib.import_module to return our mock + with patch("importlib.import_module", return_value=mock_ollama): + # Test load() method + embedding.load() + + # Test embed() method with multiple texts + test_texts = ["hello world", "test document"] + result = embedding.embed(test_texts) + + # Verify the result + assert isinstance(result, np.ndarray) + assert result.dtype == np.float32 + assert result.shape == (2, dimensions) + + # Verify embed was called with correct kwargs during load + mock_ollama.embed.assert_called_once_with(model="nomic-embed-text") + + # Verify the callable was called with correct input parameter + mock_callable.assert_called_once_with(input=test_texts) + + def test_dimensions_parameter_override(tmp_path): """ Test the dimensions parameter functionality with TileDB array input.