Skip to content

Commit 67529d1

Browse files
authored
Merge pull request #511 from guardrails-ai/two-word-fix
Better Fix Logic for TwoWord Validator
2 parents f805ce7 + c7d8435 commit 67529d1

File tree

12 files changed

+145
-61
lines changed

12 files changed

+145
-61
lines changed

docs/examples/extracting_entities.ipynb

Lines changed: 38 additions & 38 deletions
Large diffs are not rendered by default.

guardrails/classes/history/call.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from pydantic import Field
44
from rich.panel import Panel
5+
from rich.pretty import pretty_repr
56
from rich.tree import Tree
67

78
from guardrails.classes.generic.stack import Stack
@@ -319,4 +320,23 @@ def tree(self) -> Tree:
319320
tree = Tree("Logs")
320321
for i, iteration in enumerate(self.iterations):
321322
tree.add(Panel(iteration.rich_group, title=f"Step {i}"))
323+
324+
# Replace the last Validated Output panel if we applied fixes
325+
if self.failed_validations.length > 0 and self.status == pass_status:
326+
previous_panels = tree.children[ # type: ignore
327+
-1
328+
].label.renderable._renderables[ # type: ignore
329+
:-1
330+
]
331+
validated_outcome_panel = Panel(
332+
pretty_repr(self.validated_output),
333+
title="Validated Output",
334+
style="on #F0FFF0",
335+
)
336+
tree.children[
337+
-1
338+
].label.renderable._renderables = previous_panels + ( # type: ignore
339+
validated_outcome_panel,
340+
)
341+
322342
return tree

guardrails/utils/reask_utils.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -213,14 +213,15 @@ def sub_reasks_with_fixed_values(value: Any) -> Any:
213213
Returns:
214214
The value with ReAsk objects replaced with their fixed values.
215215
"""
216-
if isinstance(value, list):
217-
for index, item in enumerate(value):
218-
value[index] = sub_reasks_with_fixed_values(item)
219-
elif isinstance(value, dict):
216+
copy = deepcopy(value)
217+
if isinstance(copy, list):
218+
for index, item in enumerate(copy):
219+
copy[index] = sub_reasks_with_fixed_values(item)
220+
elif isinstance(copy, dict):
220221
for dict_key, dict_value in value.items():
221-
value[dict_key] = sub_reasks_with_fixed_values(dict_value)
222-
elif isinstance(value, FieldReAsk):
222+
copy[dict_key] = sub_reasks_with_fixed_values(dict_value)
223+
elif isinstance(copy, FieldReAsk):
223224
# TODO handle multiple fail results
224-
value = value.fail_results[0].fix_value
225+
copy = copy.fail_results[0].fix_value
225226

226-
return value
227+
return copy

guardrails/validators/two_words.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from typing import Any, Dict
22

3+
from pydash.strings import words as _words
4+
35
from guardrails.logger import logger
46
from guardrails.validator_base import (
57
FailResult,
@@ -23,13 +25,24 @@ class TwoWords(Validator):
2325
| Programmatic fix | Pick the first two words. |
2426
"""
2527

28+
def _get_fix_value(self, value: str) -> str:
29+
words = value.split()
30+
if len(words) == 1:
31+
words = _words(value)
32+
33+
if len(words) == 1:
34+
value = f"{value} {value}"
35+
words = value.split()
36+
37+
return " ".join(words[:2])
38+
2639
def validate(self, value: Any, metadata: Dict) -> ValidationResult:
2740
logger.debug(f"Validating {value} is two words...")
2841

2942
if len(value.split()) != 2:
3043
return FailResult(
3144
error_message="must be exactly two words",
32-
fix_value=" ".join(value.split()[:2]),
45+
fix_value=self._get_fix_value(str(value)),
3346
)
3447

3548
return PassResult()

poetry.lock

Lines changed: 18 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ anthropic = {version = "^0.7.2", optional = true}
4747
torch = {version = "^2.1.1", optional = true}
4848
guardrails-ai-unbabel-comet = {version = "^2.2.1", optional = true}
4949
huggingface_hub = {version = "^0.16.4", optional = true}
50+
pydash = "^7.0.6"
5051

5152

5253
[tool.poetry.extras]

tests/integration_tests/test_assets/entity_extraction/validated_output_fix.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
},
4040
{
4141
"index": 7,
42-
"name": "over-the-credit-limit",
42+
"name": "over the",
4343
"explanation": "Over-the-Credit-Limit None",
4444
"value": 0,
4545
},

tests/integration_tests/test_assets/entity_extraction/validated_output_reask_1.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@
5656
fail_results=[
5757
FailResult(
5858
error_message="must be exactly two words",
59-
fix_value="over-the-credit-limit",
59+
fix_value="over the",
6060
)
6161
],
6262
path=["fees", 6, "name"],

tests/integration_tests/test_assets/entity_extraction/validated_output_reask_2.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
},
4040
{
4141
"index": 7,
42-
"name": "over-the-credit-limit",
42+
"name": "over the",
4343
"explanation": "Over-the-Credit-Limit None",
4444
"value": 0,
4545
},

tests/integration_tests/test_assets/entity_extraction/validated_output_skeleton_reask_2.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
11
# flake8: noqa: E501
22
VALIDATED_OUTPUT_SKELETON_REASK_2 = {
33
"fees": [
4-
{"name": "annual_membership_fee", "explanation": "", "value": 0.0},
5-
{"name": "my_chase_plan_fee", "explanation": "", "value": 1.72},
6-
{"name": "balance_transfers", "explanation": "", "value": 5.0},
7-
{"name": "cash_advances", "explanation": "", "value": 5.0},
8-
{"name": "foreign_transactions", "explanation": "", "value": 3.0},
9-
{"name": "late_payment", "explanation": "", "value": 0.0},
10-
{"name": "over-the-credit-limit", "explanation": "", "value": 0.0},
11-
{"name": "return_payment", "explanation": "", "value": 0.0},
12-
{"name": "return_check", "explanation": "", "value": 0.0},
4+
{"name": "annual membership", "explanation": "", "value": 0.0},
5+
{"name": "my chase", "explanation": "", "value": 1.72},
6+
{"name": "balance transfers", "explanation": "", "value": 5.0},
7+
{"name": "cash advances", "explanation": "", "value": 5.0},
8+
{"name": "foreign transactions", "explanation": "", "value": 3.0},
9+
{"name": "late payment", "explanation": "", "value": 0.0},
10+
{"name": "over the", "explanation": "", "value": 0.0},
11+
{"name": "return payment", "explanation": "", "value": 0.0},
12+
{"name": "return check", "explanation": "", "value": 0.0},
1313
],
1414
"interest_rates": {
1515
"purchase": {

0 commit comments

Comments
 (0)