Skip to content

Commit b41de2e

Browse files
committed
Add JSON example to skeleton reask
1 parent f805ce7 commit b41de2e

File tree

6 files changed

+98
-5
lines changed

6 files changed

+98
-5
lines changed

guardrails/constants.xml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,12 @@ ${output_schema}
5252
ONLY return a valid JSON object (no other text is necessary), where the key of the field in JSON is the `name` attribute of the corresponding XML, and the value is of the type specified by the corresponding XML's tag. The JSON MUST conform to the XML format, including any types and format requests e.g. requests for lists, objects and specific types. Be correct and concise. If you are unsure anywhere, enter `null`.
5353
</json_suffix_without_examples>
5454

55+
<json_suffix_with_structure_example>
56+
${gr.json_suffix_without_examples}
57+
Here's an example of the structure:
58+
${json_example}
59+
</json_suffix_with_structure_example>
60+
5561
<complete_json_suffix>
5662
Given below is XML that describes the information to extract from this document and the tags to extract it into.
5763

guardrails/datatypes.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,9 @@ def __init__(
6969
self.description = description
7070
self.optional = optional
7171

72+
def get_example(self):
    """Return an example value for this datatype.

    This base implementation has no example to offer; concrete datatypes
    override it.

    Raises:
        NotImplementedError: Always. The message names the concrete class so
            a datatype that lacks an override is easy to identify.
    """
    raise NotImplementedError(
        f"{type(self).__name__} does not implement get_example()"
    )
74+
7275
@property
7376
def validators(self) -> TypedList:
7477
return self.validators_attr.validators
@@ -186,6 +189,9 @@ class String(ScalarType):
186189

187190
tag = "string"
188191

192+
def get_example(self):
193+
return "string"
194+
189195
def from_str(self, s: str) -> Optional[str]:
190196
"""Create a String from a string."""
191197
return to_string(s)
@@ -212,6 +218,9 @@ class Integer(ScalarType):
212218

213219
tag = "integer"
214220

221+
def get_example(self):
222+
return 1
223+
215224
def from_str(self, s: str) -> Optional[int]:
216225
"""Create an Integer from a string."""
217226
return to_int(s)
@@ -223,6 +232,9 @@ class Float(ScalarType):
223232

224233
tag = "float"
225234

235+
def get_example(self):
236+
return 1.5
237+
226238
def from_str(self, s: str) -> Optional[float]:
227239
"""Create a Float from a string."""
228240
return to_float(s)
@@ -234,6 +246,9 @@ class Boolean(ScalarType):
234246

235247
tag = "bool"
236248

249+
def get_example(self):
250+
return True
251+
237252
def from_str(self, s: Union[str, bool]) -> Optional[bool]:
238253
"""Create a Boolean from a string."""
239254
if s is None:
@@ -271,6 +286,9 @@ def __init__(
271286
super().__init__(children, validators_attr, optional, name, description)
272287
self.date_format = None
273288

289+
def get_example(self):
    """Return a JSON-serializable example date.

    The example is rendered with ``json.dumps``, which cannot encode
    ``datetime.date`` objects, so today's date is returned as a string
    rather than as a date instance.

    Returns:
        str: Today's date formatted with ``self.date_format`` when one is
        set, otherwise in ISO 8601 form (``YYYY-MM-DD``).
    """
    today = datetime.date.today()
    # ``date_format`` is initialised to None and may never be configured.
    date_format = getattr(self, "date_format", None)
    if date_format:
        return today.strftime(date_format)
    return today.isoformat()
291+
274292
def from_str(self, s: str) -> Optional[datetime.date]:
275293
"""Create a Date from a string."""
276294
if s is None:
@@ -310,6 +328,9 @@ def __init__(
310328
self.time_format = "%H:%M:%S"
311329
super().__init__(children, validators_attr, optional, name, description)
312330

331+
def get_example(self):
    """Return a JSON-serializable example time.

    The example is rendered with ``json.dumps``, which cannot encode
    ``datetime.time`` objects, so the example is returned as a string.

    Returns:
        str: Midnight (the default ``datetime.time()``) formatted with
        ``self.time_format`` when set, otherwise in ISO 8601 form.
    """
    midnight = datetime.time()
    time_format = getattr(self, "time_format", None)
    if time_format:
        return midnight.strftime(time_format)
    return midnight.isoformat()
333+
313334
def from_str(self, s: str) -> Optional[datetime.time]:
314335
"""Create a Time from a string."""
315336
if s is None:
@@ -338,6 +359,9 @@ def __init__(self, *args, **kwargs):
338359
super().__init__(*args, **kwargs)
339360
deprecate_type(type(self))
340361

362+
def get_example(self):
363+
return "hello@example.com"
364+
341365

342366
@deprecate_type
343367
@register_type("url")
@@ -350,6 +374,9 @@ def __init__(self, *args, **kwargs):
350374
super().__init__(*args, **kwargs)
351375
deprecate_type(type(self))
352376

377+
def get_example(self):
378+
return "https://example.com"
379+
353380

354381
@deprecate_type
355382
@register_type("pythoncode")
@@ -362,6 +389,9 @@ def __init__(self, *args, **kwargs):
362389
super().__init__(*args, **kwargs)
363390
deprecate_type(type(self))
364391

392+
def get_example(self):
393+
return "print('hello world')"
394+
365395

366396
@deprecate_type
367397
@register_type("sql")
@@ -374,13 +404,19 @@ def __init__(self, *args, **kwargs):
374404
super().__init__(*args, **kwargs)
375405
deprecate_type(type(self))
376406

407+
def get_example(self):
408+
return "SELECT * FROM table"
409+
377410

378411
@register_type("percentage")
379412
class Percentage(ScalarType):
380413
"""Element tag: `<percentage>`"""
381414

382415
tag = "percentage"
383416

417+
def get_example(self):
418+
return "20%"
419+
384420

385421
@register_type("enum")
386422
class Enum(ScalarType):
@@ -400,6 +436,9 @@ def __init__(
400436
super().__init__(children, validators_attr, optional, name, description)
401437
self.enum_values = enum_values
402438

439+
def get_example(self):
    """Return the first allowed enum value as the example.

    Returns:
        The first entry of ``self.enum_values``, or ``None`` (rendered as
        JSON ``null``) when no values are configured, rather than raising
        ``IndexError`` on an empty collection.
    """
    # ``or ()`` also covers enum_values being None.
    return next(iter(self.enum_values or ()), None)
441+
403442
def from_str(self, s: str) -> Optional[str]:
404443
"""Create an Enum from a string."""
405444
if s is None:
@@ -432,6 +471,9 @@ class List(NonScalarType):
432471

433472
tag = "list"
434473

474+
def get_example(self):
475+
return [e.get_example() for e in self._children.values()]
476+
435477
def collect_validation(
436478
self,
437479
key: str,
@@ -474,6 +516,9 @@ class Object(NonScalarType):
474516

475517
tag = "object"
476518

519+
def get_example(self):
520+
return {k: v.get_example() for k, v in self._children.items()}
521+
477522
def collect_validation(
478523
self,
479524
key: str,
@@ -544,6 +589,14 @@ def __init__(
544589
super().__init__(children, validators_attr, optional, name, description)
545590
self.discriminator_key = discriminator_key
546591

592+
def get_example(self):
    """Return an example object for the first case of this choice.

    Returns:
        dict: ``{self.discriminator_key: <first child name>}`` merged with
        the first child's own example. Keys from the child's example take
        precedence on a clash. An empty dict is returned when there are no
        children, instead of raising ``IndexError``.
    """
    # Take the first (name, child) pair in one pass instead of copying all
    # keys and all values into two throwaway lists.
    first_case = next(iter(self._children.items()), None)
    if first_case is None:
        return {}
    first_discriminator, first_child = first_case
    return {
        self.discriminator_key: first_discriminator,
        **first_child.get_example(),
    }
599+
547600
@classmethod
548601
def from_xml(cls, element: ET._Element, strict: bool = False, **kwargs) -> Self:
549602
# grab `discriminator` attribute
@@ -604,6 +657,9 @@ def __init__(
604657
) -> None:
605658
super().__init__(children, validators_attr, optional, name, description)
606659

660+
def get_example(self):
661+
return {k: v.get_example() for k, v in self._children.items()}
662+
607663
def collect_validation(
608664
self,
609665
key: str,

guardrails/schema.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,7 @@ def check_valid_reask_prompt(self, reask_prompt: Optional[str]) -> None:
219219

220220

221221
class JsonSchema(Schema):
222-
reask_prompt_vars = {"previous_response", "output_schema"}
222+
reask_prompt_vars = {"previous_response", "output_schema", "json_example"}
223223

224224
def __init__(
225225
self,
@@ -269,7 +269,7 @@ def get_reask_setup(
269269
if reask_prompt_template is None:
270270
reask_prompt_template = Prompt(
271271
constants["high_level_skeleton_reask_prompt"]
272-
+ constants["json_suffix_without_examples"]
272+
+ constants["json_suffix_with_structure_example"]
273273
)
274274

275275
# This is incorrect
@@ -300,6 +300,10 @@ def get_reask_setup(
300300
)
301301

302302
pruned_tree_string = pruned_tree_schema.transpile()
303+
json_example = json.dumps(
304+
pruned_tree_schema.root_datatype.get_example(),
305+
indent=2,
306+
)
303307

304308
def reask_decoder(obj):
305309
decoded = {}
@@ -317,6 +321,7 @@ def reask_decoder(obj):
317321
reask_value, indent=2, default=reask_decoder, ensure_ascii=False
318322
),
319323
output_schema=pruned_tree_string,
324+
json_example=json_example,
320325
**(prompt_params or {}),
321326
)
322327

tests/integration_tests/test_assets/entity_extraction/compiled_prompt_skeleton_reask_2.txt

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ I was given the following JSON response, which had problems due to incorrect val
6969

7070
Help me correct the incorrect values based on the given error messages.
7171

72+
7273
Given below is XML that describes the information to extract from this document and the tags to extract it into.
7374

7475
<output>
@@ -85,6 +86,18 @@ Given below is XML that describes the information to extract from this document
8586

8687
ONLY return a valid JSON object (no other text is necessary), where the key of the field in JSON is the `name` attribute of the corresponding XML, and the value is of the type specified by the corresponding XML's tag. The JSON MUST conform to the XML format, including any types and format requests e.g. requests for lists, objects and specific types. Be correct and concise. If you are unsure anywhere, enter `null`.
8788

89+
Here's an example of the structure:
90+
{
91+
"fees": [
92+
{
93+
"name": "string",
94+
"explanation": "string",
95+
"value": 1.5
96+
}
97+
],
98+
"interest_rates": {}
99+
}
100+
88101

89102
Json Output:
90103

tests/integration_tests/test_assets/pydantic/msg_compiled_prompt_reask.txt

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ I was given the following JSON response, which had problems due to incorrect val
1313

1414
Help me correct the incorrect values based on the given error messages.
1515

16+
1617
Given below is XML that describes the information to extract from this document and the tags to extract it into.
1718

1819
<output>
@@ -23,3 +24,10 @@ Given below is XML that describes the information to extract from this document
2324

2425

2526
ONLY return a valid JSON object (no other text is necessary), where the key of the field in JSON is the `name` attribute of the corresponding XML, and the value is of the type specified by the corresponding XML's tag. The JSON MUST conform to the XML format, including any types and format requests e.g. requests for lists, objects and specific types. Be correct and concise. If you are unsure anywhere, enter `null`.
27+
28+
Here's an example of the structure:
29+
{
30+
"name": "string",
31+
"director": "string",
32+
"release_year": 1
33+
}

tests/unit_tests/utils/test_reask_utils.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import pytest
44
from lxml import etree as ET
55

6-
from guardrails import Instructions, Prompt
76
from guardrails.classes.history.iteration import Iteration
87
from guardrails.datatypes import Object
98
from guardrails.schema import JsonSchema
@@ -443,10 +442,14 @@ def test_get_reask_prompt(
443442
444443
Help me correct the incorrect values based on the given error messages.
445444
445+
446446
Given below is XML that describes the information to extract from this document and the tags to extract it into.
447447
%s
448448
449449
ONLY return a valid JSON object (no other text is necessary), where the key of the field in JSON is the `name` attribute of the corresponding XML, and the value is of the type specified by the corresponding XML's tag. The JSON MUST conform to the XML format, including any types and format requests e.g. requests for lists, objects and specific types. Be correct and concise. If you are unsure anywhere, enter `null`.
450+
451+
Here's an example of the structure:
452+
%s
450453
""" # noqa: E501
451454
expected_instructions = """
452455
You are a helpful assistant only capable of communicating with valid JSON, and no other text.
@@ -467,13 +470,15 @@ def test_get_reask_prompt(
467470
result_prompt,
468471
instructions,
469472
) = output_schema.get_reask_setup(reasks, reask_json, False)
473+
json_example = output_schema.root_datatype.get_example()
470474

471-
assert result_prompt == Prompt(
475+
assert result_prompt.source == (
472476
expected_result_template
473477
% (
474478
json.dumps(reask_json, indent=2),
475479
expected_rail,
480+
json.dumps(json_example, indent=2),
476481
)
477482
)
478483

479-
assert instructions == Instructions(expected_instructions)
484+
assert instructions.source == expected_instructions

0 commit comments

Comments
 (0)