@@ -5,7 +5,7 @@
 import asyncio
 import hashlib
 import json
-from typing import Any, List, Optional, Union
+from typing import Any, List, Literal, Optional, Union
 
 import numpy as np
 from langchain_core.caches import RETURN_VAL_TYPE, BaseCache
@@ -701,3 +701,196 @@ async def aupdate(
     async def aclear(self, **kwargs: Any) -> None:
         """Async clear cache that can take additional keyword arguments."""
         await self.cache.aclear()
+
+
+class LangCacheSemanticCache(BaseCache):
+    """Semantic cache backed by RedisVL's LangCacheSemanticCache.
+
+    This wraps ``redisvl.extensions.cache.llm.LangCacheSemanticCache`` (itself
+    a client for the managed LangCache API). The optional dependency
+    ``langcache`` must be installed at runtime when this class is actually used.
+
+    Install it with either ``pip install 'langchain-redis[langcache]'`` or
+    ``pip install 'langcache>=0.10.0'``.
+
+    Parameters mirror ``RedisSemanticCache`` where possible. ``name`` and
+    ``prefix`` are combined to derive a human-friendly cache name. The LangCache
+    ``cache_id`` must be provided explicitly; obtain it from the LangCache
+    service before instantiating this class.
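+
+    Example (a minimal sketch: ``my_embeddings`` stands for any
+    ``langchain_core.embeddings.Embeddings`` instance, the credentials are
+    placeholders, and the top-level import assumes the class is exported
+    from ``langchain_redis``):
+
+    .. code-block:: python
+
+        from langchain_core.globals import set_llm_cache
+        from langchain_redis import LangCacheSemanticCache
+
+        cache = LangCacheSemanticCache(
+            embeddings=my_embeddings,
+            cache_id="your-cache-id",
+            api_key="your-langcache-api-key",
+        )
+        set_llm_cache(cache)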
+    """
+
+    def __init__(
+        self,
+        embeddings: Embeddings,
+        distance_threshold: float = 0.2,
+        ttl: Optional[int] = None,
+        name: Optional[str] = "llmcache",
+        prefix: Optional[str] = "llmcache",
+        *,
+        server_url: Optional[str] = None,
+        api_key: Optional[str] = None,
+        cache_id: Optional[str] = None,
+        use_exact_search: bool = True,
+        use_semantic_search: bool = True,
+        distance_scale: Literal["normalized", "redis"] = "normalized",
+        **kwargs: Any,
+    ):
+        if not cache_id:
+            raise ValueError("cache_id is required for LangCacheSemanticCache")
+        if not api_key:
+            raise ValueError("api_key is required for LangCacheSemanticCache")
+
+        # RedisVL's SemanticCache uses 'name' as the prefix for keys.
+        # To support the 'prefix' parameter for multi-tenant isolation,
+        # map it as follows:
+        #   - if both name and prefix are provided and differ, combine them
+        #   - if only prefix is provided (and differs from the default), use it
+        #   - otherwise use name (preserves backward compatibility)
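+        # For example: name="llmcache", prefix="tenant-a" -> "tenant-a";
+        # name="cache1", prefix="tenant-a" -> "cache1:tenant-a";
+        # name="cache1", prefix="llmcache" -> "cache1".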
+        cache_name = name
+        if prefix and prefix != "llmcache":
+            if name and name != "llmcache" and name != prefix:
+                cache_name = f"{name}:{prefix}"
+            else:
+                cache_name = prefix
+        self._cache_name = cache_name or "llmcache"
+
+        self.ttl = ttl
+        self._distance_threshold = distance_threshold
+        # Store embeddings for optional future vectorization; avoid constructing
+        # a BaseVectorizer eagerly (tests may pass MagicMocks without real dims).
+        self.embeddings = embeddings
+
+        try:
+            from redisvl.extensions.cache.llm import (
+                LangCacheSemanticCache as RVLLangCacheSemanticCache,
+            )
+        except ImportError as e:
+            # Distinguish a missing langcache dependency from an outdated redisvl.
+            error_msg = str(e).lower()
+            if "langcache" in error_msg:
+                raise ImportError(
+                    "LangCacheSemanticCache requires the langcache package. "
+                    "Install it with: pip install langcache "
+                    "or pip install 'redisvl[langcache]'"
+                ) from e
+            else:
+                raise ImportError(
+                    "LangCacheSemanticCache requires redisvl>=0.11.0. "
+                    "Update redisvl with: pip install --upgrade redisvl"
+                ) from e
+
+        # Instantiate the LangCache wrapper; it validates cache_id/api_key.
+        self.cache: Any = RVLLangCacheSemanticCache(
+            name=self._cache_name,
+            server_url=server_url or "https://aws-us-east-1.langcache.redis.io",
+            cache_id=cache_id,
+            api_key=api_key,
+            ttl=ttl,
+            use_exact_search=use_exact_search,
+            use_semantic_search=use_semantic_search,
+            distance_scale=distance_scale,
+            **kwargs,
+        )
+
+    def _vectorize(self, prompt: str) -> List[float]:
+        """Vectorize the prompt via the embeddings object directly (kept for
+        future use); we avoid BaseVectorizer here to keep initialization light.
+        """
+        # Embeddings.embed_query returns list[float] in real usage; mocks are fine.
+        return self.embeddings.embed_query(prompt)  # type: ignore[no-any-return]
+
+    def lookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
+        """Look up a cached response via RedisVL's LangCacheSemanticCache check API."""
+        results = self.cache.check(
+            prompt=prompt,
+            num_results=1,
+            distance_threshold=self._distance_threshold,
+            attributes={"llm_string": llm_string},
+        )
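+        # The attributes filter scopes candidates; we still re-check llm_string
+        # on each result so a hit from a different model is never returned.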
+        for result in results:
+            metadata = result.get("metadata", {}) or {}
+            if metadata.get("llm_string") == llm_string:
+                try:
+                    return [loads(s) for s in json.loads(result.get("response", "[]"))]
+                except (json.JSONDecodeError, TypeError):
+                    return None
+        return None
+
+    def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:
+        """Store a response using RedisVL's LangCacheSemanticCache store API."""
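+        # Serialize each Generation with LangChain's dumps() so lookup() can
+        # round-trip it with loads().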
+        serialized_response = json.dumps([dumps(gen) for gen in return_val])
+        # LangCache applies TTL at the cache level rather than per entry; ttl
+        # is passed through for interface compatibility.
+        self.cache.store(
+            prompt=prompt,
+            response=serialized_response,
+            metadata={"llm_string": llm_string},
+            ttl=self.ttl,
+        )
+
+    def clear(self, **kwargs: Any) -> None:
+        """Clear all entries via the wrapper's clear API."""
+        self.cache.clear()
+
+    def name(self) -> str:
+        """Return the derived cache name."""
+        return self._cache_name
+
+    async def alookup(self, prompt: str, llm_string: str) -> Optional[RETURN_VAL_TYPE]:
+        """Async lookup through RedisVL's LangCacheSemanticCache."""
+        results = await self.cache.acheck(
+            prompt=prompt,
+            num_results=1,
+            distance_threshold=self._distance_threshold,
+            attributes={"llm_string": llm_string},
+        )
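+        # Same post-filtering as the sync lookup(): verify llm_string before
+        # deserializing so a hit from a different model is never returned.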
+        for result in results:
+            metadata = result.get("metadata", {}) or {}
+            if metadata.get("llm_string") == llm_string:
+                try:
+                    return [loads(s) for s in json.loads(result.get("response", "[]"))]
+                except (json.JSONDecodeError, TypeError):
+                    return None
+        return None
+
+    async def aupdate(
+        self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE
+    ) -> None:
+        """Async store using RedisVL's LangCacheSemanticCache."""
+        serialized_response = json.dumps([dumps(gen) for gen in return_val])
+        await self.cache.astore(
+            prompt=prompt,
+            response=serialized_response,
+            metadata={"llm_string": llm_string},
+            ttl=self.ttl,
+        )
+
+    async def aclear(self, **kwargs: Any) -> None:
+        """Async clear via the wrapper's aclear API."""
+        await self.cache.aclear()