-
Notifications
You must be signed in to change notification settings - Fork 3
Add class that treats Codex as a backup #11
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 13 commits
3acb048
74755d2
f7f8156
d20d31c
a223b04
f15424c
23eeb58
9e8690c
d0ad8df
b4bff54
038a475
5fbb48e
3892b52
22253e9
e4bdf2c
e5a6164
807d7fa
00def49
d8a6e86
2630a2c
3286674
0ebd4fe
4eca7d3
c59cec5
6026179
a94ffb5
2510255
b439113
38666de
a5d655b
e776dfe
7866f0c
26adbf1
36f80e9
febbfd0
dc1d003
739ffc6
3e4864a
81cc934
9e91e9b
c5843c9
49f9a9d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,5 +1,6 @@ | ||
| # SPDX-License-Identifier: MIT | ||
| from cleanlab_codex.codex import Codex | ||
| from cleanlab_codex.codex_backup import CodexBackup | ||
| from cleanlab_codex.codex_tool import CodexTool | ||
|
|
||
| __all__ = ["Codex", "CodexTool"] | ||
| __all__ = ["Codex", "CodexTool", "CodexBackup"] |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,112 @@ | ||
| from __future__ import annotations | ||
|
|
||
| from functools import wraps | ||
| from typing import Any, Callable, Optional | ||
|
|
||
| from cleanlab_codex.codex import Codex | ||
| from cleanlab_codex.validation import is_bad_response | ||
|
|
||
|
|
||
| def handle_backup_default(backup_response: str, decorated_instance: Any) -> None: # noqa: ARG001 | ||
| """Default implementation is a no-op.""" | ||
| return None | ||
|
|
||
|
|
||
| class CodexBackup: | ||
| """A backup decorator that connects to a Codex project to answer questions that | ||
| cannot be adequately answered by the existing agent. | ||
| """ | ||
|
|
||
| DEFAULT_FALLBACK_ANSWER = "Based on the available information, I cannot provide a complete answer to this question." | ||
|
|
||
| def __init__( | ||
| self, | ||
| codex_client: Codex, | ||
| *, | ||
| project_id: Optional[str] = None, | ||
| fallback_answer: Optional[str] = DEFAULT_FALLBACK_ANSWER, | ||
|
||
| backup_handler: Callable[[str, Any], None] = handle_backup_default, | ||
|
||
| ): | ||
| self._codex_client = codex_client | ||
| self._project_id = project_id | ||
| self._fallback_answer = fallback_answer | ||
| self._backup_handler = backup_handler | ||
|
|
||
| @classmethod | ||
| def from_access_key( | ||
| cls, | ||
| access_key: str, | ||
| *, | ||
| project_id: Optional[str] = None, | ||
| fallback_answer: Optional[str] = DEFAULT_FALLBACK_ANSWER, | ||
| backup_handler: Callable[[str, Any], None] = handle_backup_default, | ||
| ) -> CodexBackup: | ||
| """Creates a CodexBackup from an access key. The project ID that the CodexBackup will use is the one that is associated with the access key.""" | ||
| return cls( | ||
| codex_client=Codex(key=access_key), | ||
| project_id=project_id, | ||
| fallback_answer=fallback_answer, | ||
| backup_handler=backup_handler, | ||
| ) | ||
|
|
||
| @classmethod | ||
| def from_client( | ||
| cls, | ||
| codex_client: Codex, | ||
| *, | ||
| project_id: Optional[str] = None, | ||
| fallback_answer: Optional[str] = DEFAULT_FALLBACK_ANSWER, | ||
| backup_handler: Callable[[str, Any], None] = handle_backup_default, | ||
| ) -> CodexBackup: | ||
| """Creates a CodexBackup from a Codex client. | ||
| If the Codex client is initialized with a project access key, the CodexBackup will use the project ID that is associated with the access key. | ||
| If the Codex client is initialized with a user API key, a project ID must be provided. | ||
| """ | ||
| return cls( | ||
| codex_client=codex_client, | ||
| project_id=project_id, | ||
| fallback_answer=fallback_answer, | ||
| backup_handler=backup_handler, | ||
| ) | ||
|
|
||
| def to_decorator(self): | ||
|
||
| """Factory that creates a backup decorator using the provided Codex client""" | ||
|
|
||
| def decorator(chat_method): | ||
| """ | ||
| Decorator for RAG chat methods that adds backup response handling. | ||
|
|
||
| If the original chat method returns an inadequate response, attempts to get | ||
| a backup response from Codex. Returns the backup response if available, | ||
| otherwise returns the original response. | ||
|
|
||
| Args: | ||
| chat_method: Method with signature (self, user_message: str) -> str | ||
|
||
| where 'self' refers to the instance being decorated, not an instance of CodexBackup. | ||
| """ | ||
|
|
||
| @wraps(chat_method) | ||
| def wrapper(decorated_instance, user_message): | ||
| # Call the original chat method | ||
| assistant_response = chat_method(decorated_instance, user_message) | ||
|
|
||
| # Return original response if it's adequate | ||
| # TODO: Update usage of is_bad_response | ||
| if not is_bad_response(assistant_response, self._fallback_answer): | ||
| return assistant_response | ||
|
|
||
| # Query Codex for a backup response | ||
| cache_result = self._codex_client.query(user_message)[0] | ||
| if not cache_result: | ||
| return assistant_response | ||
|
|
||
| # Handle backup response if handler exists | ||
| self._backup_handler( | ||
| backup_response=cache_result, | ||
| decorated_instance=decorated_instance, | ||
| ) | ||
| return cache_result | ||
|
|
||
| return wrapper | ||
|
|
||
| return decorator | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,21 @@ | ||
| """ | ||
| Utility functions for RAG (Retrieval Augmented Generation) operations. | ||
| """ | ||
|
|
||
|
|
||
| def default_format_prompt(query: str, context: str) -> str: | ||
| """Default function for formatting RAG prompts. | ||
|
|
||
| Args: | ||
| query: The user's question | ||
| context: The context/documents to use for answering | ||
|
|
||
| Returns: | ||
| str: A formatted prompt combining the query and context | ||
| """ | ||
| template = ( | ||
| "Using only information from the following Context, answer the following Query.\n\n" | ||
| "Context:\n{context}\n\n" | ||
| "Query: {query}" | ||
| ) | ||
| return template.format(context=context, query=query) |
elisno marked this conversation as resolved.
Show resolved
Hide resolved
|
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,190 @@ | ||
| """ | ||
elisno marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| This module provides validation functions for checking if an LLM response is unhelpful. | ||
elisno marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| """ | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| from typing import TYPE_CHECKING, Callable, Optional | ||
|
|
||
| from cleanlab_codex.utils.prompt import default_format_prompt | ||
|
|
||
| if TYPE_CHECKING: | ||
| from cleanlab_studio.studio.trustworthy_language_model import TLM # type: ignore | ||
|
|
||
|
|
||
| DEFAULT_FALLBACK_ANSWER = "Based on the available information, I cannot provide a complete answer to this question." | ||
| DEFAULT_PARTIAL_RATIO_THRESHOLD = 70 | ||
| DEFAULT_TRUSTWORTHINESS_THRESHOLD = 0.5 | ||
|
|
||
|
|
||
| def is_bad_response( | ||
elisno marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| response: str, | ||
| context: str, | ||
| tlm: TLM, # TODO: Make this optional | ||
| query: Optional[str] = None, | ||
| # is_fallback_response args | ||
| fallback_answer: str = DEFAULT_FALLBACK_ANSWER, | ||
elisno marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| partial_ratio_threshold: int = DEFAULT_PARTIAL_RATIO_THRESHOLD, | ||
elisno marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| # is_untrustworthy_response args | ||
| trustworthiness_threshold: float = DEFAULT_TRUSTWORTHINESS_THRESHOLD, | ||
| # is_unhelpful_response args | ||
| unhelpful_trustworthiness_threshold: Optional[float] = None, | ||
elisno marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| ) -> bool: | ||
| """Run a series of checks to determine if a response is bad. If any of the checks pass, return True. | ||
elisno marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| Checks: | ||
| - Is the response too similar to a known fallback answer? | ||
| - Is the response untrustworthy? | ||
| - Is the response unhelpful? | ||
|
|
||
| Args: | ||
| response: The response to check. See `is_fallback_response`, `is_untrustworthy_response`, and `is_unhelpful_response`. | ||
| context: The context/documents to use for answering. See `is_untrustworthy_response`. | ||
| tlm: The TLM model to use for evaluation. See `is_untrustworthy_response` and `is_unhelpful_response`. | ||
| query: The user's question (optional). See `is_untrustworthy_response` and `is_unhelpful_response`. | ||
| fallback_answer: The fallback answer to compare against. See `is_fallback_response`. | ||
| partial_ratio_threshold: The threshold for detecting fallback responses. See `is_fallback_response`. | ||
| trustworthiness_threshold: The threshold for detecting untrustworthy responses. See `is_untrustworthy_response`. | ||
| unhelpful_trustworthiness_threshold: The threshold for detecting unhelpful responses. See `is_unhelpful_response`. | ||
| """ | ||
| validation_checks = [ | ||
| lambda: is_fallback_response(response, fallback_answer, threshold=partial_ratio_threshold), | ||
| lambda: ( | ||
| is_untrustworthy_response(response, context, query, tlm, threshold=trustworthiness_threshold) | ||
| if query is not None | ||
| else False | ||
| ), | ||
| lambda: is_unhelpful_response( | ||
| response, tlm, query, trustworthiness_score_threshold=unhelpful_trustworthiness_threshold | ||
| ), | ||
| ] | ||
|
|
||
| return any(check() for check in validation_checks) | ||
|
|
||
|
|
||
| def is_fallback_response( | ||
| response: str, fallback_answer: str = DEFAULT_FALLBACK_ANSWER, threshold: int = DEFAULT_PARTIAL_RATIO_THRESHOLD | ||
| ) -> bool: | ||
| """Check if a response is too similar to a known fallback answer. | ||
|
|
||
| Uses fuzzy string matching to compare the response against a known fallback answer. | ||
| Returns True if the response is similar enough to be considered unhelpful. | ||
|
|
||
| Args: | ||
| response: The response to check | ||
| fallback_answer: A known unhelpful/fallback response to compare against | ||
| threshold: Similarity threshold (0-100). Higher values require more similarity. | ||
| Default 70 means responses that are 70% or more similar are considered bad. | ||
|
|
||
| Returns: | ||
| bool: True if the response is too similar to the fallback answer, False otherwise | ||
| """ | ||
| try: | ||
| from thefuzz import fuzz # type: ignore | ||
| except ImportError as e: | ||
| error_msg = "The 'thefuzz' library is required. Please install it with `pip install thefuzz`." | ||
| raise ImportError(error_msg) from e | ||
|
|
||
| partial_ratio: int = fuzz.partial_ratio(fallback_answer.lower(), response.lower()) | ||
jwmueller marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| return bool(partial_ratio >= threshold) | ||
|
|
||
|
|
||
| def is_untrustworthy_response( | ||
| response: str, | ||
| context: str, | ||
| query: str, | ||
| tlm: TLM, | ||
| threshold: float = DEFAULT_TRUSTWORTHINESS_THRESHOLD, | ||
elisno marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| format_prompt: Callable[[str, str], str] = default_format_prompt, | ||
| ) -> bool: | ||
| """Check if a response is untrustworthy based on TLM's evaluation. | ||
|
|
||
| Uses TLM to evaluate whether a response is trustworthy given the context and query. | ||
| Returns True if TLM's trustworthiness score falls below the threshold, indicating | ||
| the response may be incorrect or unreliable. | ||
|
|
||
| Args: | ||
| response: The response to check from the assistant | ||
| context: The context information available for answering the query | ||
| query: The user's question or request | ||
| tlm: The TLM model to use for evaluation | ||
| threshold: Score threshold (0.0-1.0). Lower values allow less trustworthy responses. | ||
| Default 0.5, meaning responses with scores less than 0.5 are considered untrustworthy. | ||
| format_prompt: Function that takes (query, context) and returns a formatted prompt string. | ||
| Users should provide their RAG app's own prompt formatting function here | ||
| to match how their LLM is prompted. | ||
|
|
||
| Returns: | ||
| bool: True if the response is deemed untrustworthy by TLM, False otherwise | ||
| """ | ||
| try: | ||
| from cleanlab_studio.studio.trustworthy_language_model import TLM # noqa: F401 | ||
| except ImportError as e: | ||
| error_msg = "The 'cleanlab_studio' library is required. Please install it with `pip install cleanlab-studio`." | ||
| raise ImportError(error_msg) from e | ||
|
|
||
| prompt = format_prompt(query, context) | ||
| resp = tlm.get_trustworthiness_score(prompt, response) | ||
| score: float = resp["trustworthiness_score"] | ||
elisno marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| return score < threshold | ||
|
|
||
|
|
||
| def is_unhelpful_response( | ||
| response: str, tlm: TLM, query: Optional[str] = None, trustworthiness_score_threshold: Optional[float] = None | ||
| ) -> bool: | ||
elisno marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| """Check if a response is unhelpful by asking TLM to evaluate it. | ||
|
|
||
| Uses TLM to evaluate whether a response is helpful by asking it to make a Yes/No judgment. | ||
| The evaluation considers both the TLM's binary classification of helpfulness and its | ||
| confidence score. Returns True only if TLM classifies the response as unhelpful AND | ||
| is sufficiently confident in that assessment (if a threshold is provided). | ||
|
|
||
| Args: | ||
| response: The response to check from the assistant | ||
| tlm: The TLM model to use for evaluation | ||
| query: Optional user query to provide context for evaluating helpfulness. | ||
| If provided, TLM will assess if the response helpfully answers this query. | ||
| trustworthiness_score_threshold: Optional confidence threshold (0.0-1.0). | ||
| If provided, responses are only marked unhelpful if TLM's | ||
| confidence score exceeds this threshold. | ||
|
|
||
| Returns: | ||
| bool: True if TLM determines the response is unhelpful with sufficient confidence, | ||
| False otherwise | ||
| """ | ||
| try: | ||
| from cleanlab_studio.studio.trustworthy_language_model import TLM # noqa: F401 | ||
| except ImportError as e: | ||
| error_msg = "The 'cleanlab_studio' library is required. Please install it with `pip install cleanlab-studio`." | ||
| raise ImportError(error_msg) from e | ||
|
|
||
| # The question and expected "unhelpful" response are linked: | ||
| # - When asking "is helpful?" -> "no" means unhelpful | ||
| # - When asking "is unhelpful?" -> "yes" means unhelpful | ||
| question = ( | ||
| "Is the AI Assistant Response unhelpful? " | ||
| "Unhelpful responses include answers that:\n" | ||
| "- Are not useful, incomplete, incorrect, uncertain or unclear.\n" | ||
| "- Abstain or refuse to answer the question\n" | ||
| "- Leave the original question unresolved\n" | ||
| "- Are irrelevant to the question\n" | ||
| "Answer Yes/No only." | ||
| ) | ||
| expected_unhelpful_response = "yes" | ||
|
|
||
| prompt = ( | ||
| "Consider the following" | ||
| + ( | ||
| f" User Query and AI Assistant Response.\n\nUser Query: {query}\n\n" | ||
| if query | ||
| else " AI Assistant Response.\n\n" | ||
| ) | ||
| + f"AI Assistant Response: {response}\n\n{question}" | ||
| ) | ||
|
|
||
| output = tlm.prompt(prompt, constrain_outputs=["Yes", "No"]) | ||
| response_marked_unhelpful = output["response"].lower() == expected_unhelpful_response | ||
| is_trustworthy = trustworthiness_score_threshold is None or ( | ||
| output["trustworthiness_score"] > trustworthiness_score_threshold | ||
| ) | ||
| return response_marked_unhelpful and is_trustworthy | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Delete this?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'll hold off on removing this until we've finalized the code in "validation.py".
The intention was to pass the fallback answer from the backup object to the relevant
is_fallback_response helper function before deciding to call Codex as a backup.