|
5 | 5 | uv run -m pydantic_ai_examples.bank_support |
6 | 6 | """ |
7 | 7 |
|
| 8 | +import sqlite3 |
8 | 9 | from dataclasses import dataclass |
9 | 10 |
|
10 | 11 | from pydantic import BaseModel |
11 | 12 |
|
12 | 13 | from pydantic_ai import Agent, RunContext |
13 | 14 |
|
14 | 15 |
|
| 16 | +@dataclass |
15 | 17 | class DatabaseConn: |
16 | | - """This is a fake database for example purposes. |
17 | | -
|
18 | | - In reality, you'd be connecting to an external database |
19 | | - (e.g. PostgreSQL) to get information about customers. |
20 | | - """ |
21 | | - |
22 | | - @classmethod |
23 | | - async def customer_name(cls, *, id: int) -> str | None: |
24 | | - if id == 123: |
25 | | - return 'John' |
26 | | - |
27 | | - @classmethod |
28 | | - async def customer_balance(cls, *, id: int, include_pending: bool) -> float: |
29 | | - if id == 123: |
30 | | - if include_pending: |
31 | | - return 123.45 |
32 | | - else: |
33 | | - return 100.00 |
| 18 | + """A wrapper over the SQLite connection.""" |
| 19 | + |
| 20 | + sqlite_conn: sqlite3.Connection |
| 21 | + |
| 22 | + async def customer_name(self, *, id: int) -> str | None: |
| 23 | + res = cur.execute('SELECT name FROM customers WHERE id=?', (id,)) |
| 24 | + row = res.fetchone() |
| 25 | + if row: |
| 26 | + return row[0] |
| 27 | + return None |
| 28 | + |
| 29 | + async def customer_balance(self, *, id: int) -> float: |
| 30 | + res = cur.execute('SELECT balance FROM customers WHERE id=?', (id,)) |
| 31 | + row = res.fetchone() |
| 32 | + if row: |
| 33 | + return row[0] |
34 | 34 | else: |
35 | 35 | raise ValueError('Customer not found') |
36 | 36 |
|
37 | 37 |
|
| 38 | + |
38 | 39 | @dataclass |
39 | 40 | class SupportDependencies: |
40 | 41 | customer_id: int |
@@ -69,27 +70,33 @@ async def add_customer_name(ctx: RunContext[SupportDependencies]) -> str: |
69 | 70 |
|
70 | 71 |
|
71 | 72 | @support_agent.tool |
72 | | -async def customer_balance( |
73 | | - ctx: RunContext[SupportDependencies], include_pending: bool |
74 | | -) -> str: |
| 73 | +async def customer_balance(ctx: RunContext[SupportDependencies]) -> str: |
75 | 74 | """Returns the customer's current account balance.""" |
76 | 75 | balance = await ctx.deps.db.customer_balance( |
77 | 76 | id=ctx.deps.customer_id, |
78 | | - include_pending=include_pending, |
79 | 77 | ) |
80 | 78 | return f'${balance:.2f}' |
81 | 79 |
|
82 | 80 |
|
83 | 81 | if __name__ == '__main__': |
84 | | - deps = SupportDependencies(customer_id=123, db=DatabaseConn()) |
85 | | - result = support_agent.run_sync('What is my balance?', deps=deps) |
86 | | - print(result.output) |
87 | | - """ |
88 | | - support_advice='Hello John, your current account balance, including pending transactions, is $123.45.' block_card=False risk=1 |
89 | | - """ |
90 | | - |
91 | | - result = support_agent.run_sync('I just lost my card!', deps=deps) |
92 | | - print(result.output) |
93 | | - """ |
94 | | - support_advice="I'm sorry to hear that, John. We are temporarily blocking your card to prevent unauthorized transactions." block_card=True risk=8 |
95 | | - """ |
| 82 | + with sqlite3.connect(':memory:') as con: |
| 83 | + cur = con.cursor() |
| 84 | + cur.execute('CREATE TABLE customers(id, name, balance)') |
| 85 | + cur.execute(""" |
| 86 | + INSERT INTO customers VALUES |
| 87 | + (123, 'John', 123.45) |
| 88 | + """) |
| 89 | + con.commit() |
| 90 | + |
| 91 | + deps = SupportDependencies(customer_id=123, db=DatabaseConn(sqlite_conn=con)) |
| 92 | + result = support_agent.run_sync('What is my balance?', deps=deps) |
| 93 | + print(result.output) |
| 94 | + """ |
| 95 | + support_advice='Hello John, your current account balance, including pending transactions, is $123.45.' block_card=False risk=1 |
| 96 | + """ |
| 97 | + |
| 98 | + result = support_agent.run_sync('I just lost my card!', deps=deps) |
| 99 | + print(result.output) |
| 100 | + """ |
| 101 | + support_advice="I'm sorry to hear that, John. We are temporarily blocking your card to prevent unauthorized transactions." block_card=True risk=8 |
| 102 | + """ |
0 commit comments