Skip to content

Commit 8f74aa2

Browse files
refactor(llm): improve code quality in receipt extraction
- Reorder imports alphabetically (PIL after llama_parse)
- Add strict=False to zip() for Python 3.10+ compatibility
- Fix lambda variable binding in fuzzy matching loop
1 parent 7501b3a commit 8f74aa2

File tree

2 files changed

+327
-1
lines changed

2 files changed

+327
-1
lines changed

.gitignore

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,4 +152,6 @@ marimo_notebooks
152152
__marimo__/
153153

154154
# Claude Code documentation
155-
CLAUDE.md
155+
CLAUDE.md
156+
157+
data
Lines changed: 324 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,324 @@
1+
# /// script
2+
# requires-python = ">=3.11"
3+
# dependencies = [
4+
# "llama-index",
5+
# "llama-index-program-openai",
6+
# "llama-parse",
7+
# "python-dotenv",
8+
# "rapidfuzz",
9+
# ]
10+
# ///
11+
import os
12+
from datetime import date
13+
from pathlib import Path
14+
from typing import List, Optional
15+
16+
import pandas as pd
17+
from dotenv import load_dotenv
18+
from llama_index.core import Settings
19+
from llama_index.llms.openai import OpenAI
20+
from llama_index.program.openai import OpenAIPydanticProgram
21+
from llama_parse import LlamaParse
22+
from PIL import Image
23+
from pydantic import BaseModel, Field
24+
from rapidfuzz import fuzz
25+
26+
27+
def configure_settings() -> None:
    """Read the ``.env`` file and install the default LLM on ``Settings``.

    Values from ``.env`` override variables already present in the
    environment. The OpenAI key is looked up with ``os.environ.get``, so a
    missing key is passed on as ``None`` rather than raising here.
    """
    load_dotenv(override=True)

    Settings.llm = OpenAI(
        api_key=os.environ.get("OPENAI_API_KEY"),
        model="gpt-4o-mini",
        temperature=0,
    )
    Settings.context_window = 8000
34+
35+
36+
def scale_image(image_path: Path, output_dir: Path, scale_factor: int = 3) -> Path:
    """Scale up an image with Lanczos resampling and save it in ``output_dir``.

    Args:
        image_path: Path to the original image.
        output_dir: Directory to save the scaled image; created (with
            parents) when it does not exist yet.
        scale_factor: Multiplier applied to both width and height
            (default: 3).

    Returns:
        Path to the scaled image: ``output_dir`` joined with the original
        file name.
    """
    # Open through a context manager so the source file handle is released
    # once the pixels have been resampled, even if resizing raises.
    with Image.open(image_path) as img:
        new_size = (img.width * scale_factor, img.height * scale_factor)
        img_resized = img.resize(new_size, Image.Resampling.LANCZOS)

    # Save under the same file name inside the output directory.
    output_dir.mkdir(parents=True, exist_ok=True)
    output_path = output_dir / image_path.name
    img_resized.save(output_path, quality=95)

    return output_path
60+
61+
62+
class ReceiptItem(BaseModel):
    """One purchased line taken from a receipt."""

    # Item text, copied as printed.
    description: str = Field(
        description="Item name exactly as shown on the receipt",
    )
    # Whole-number count; defaults to a single unit and must be at least 1.
    quantity: int = Field(
        default=1,
        ge=1,
        description="Integer quantity of the item",
    )
    # Optional non-negative price of one unit.
    unit_price: Optional[float] = Field(
        default=None,
        ge=0,
        description="Price per unit in the receipt currency",
    )
    # Non-negative discount for this line; zero when none is given.
    discount_amount: float = Field(
        default=0.0,
        ge=0,
        description="Discount applied to this line item",
    )
73+
74+
75+
class Receipt(BaseModel):
    """Receipt-level fields the LLM fills in from the parsed receipt text."""

    # Required merchant name.
    company: str = Field(
        description="Business or merchant name",
    )
    # Optional purchase date.
    purchase_date: Optional[date] = Field(
        default=None,
        description="Date in YYYY-MM-DD format",
    )
    # Optional merchant address.
    address: Optional[str] = Field(
        default=None,
        description="Address of the business",
    )
    # Required final amount.
    total: float = Field(
        description="Final charged amount",
    )
    # Line items; a fresh empty list per instance when none are supplied.
    items: List[ReceiptItem] = Field(
        default_factory=list,
    )
85+
86+
87+
def extract_documents(paths: List[str], prompt: str, id_column: str) -> List[dict]:
    """Extract structured data from documents using LlamaParse and an LLM.

    Args:
        paths: List of document file paths.
        prompt: Extraction prompt template; it is called with a
            ``context_str`` keyword holding the parsed document text.
        id_column: Key under which each document's id (the file stem) is
            stored in the result dictionaries.

    Returns:
        One dictionary per input path with ``id_column`` and ``"data"``
        (the populated ``Receipt``).

    Raises:
        KeyError: If ``LLAMA_CLOUD_API_KEY`` is not set in the environment.
        ValueError: If the parser returns a different number of documents
            than there are input paths.
    """
    paths = list(paths)

    parser = LlamaParse(
        api_key=os.environ["LLAMA_CLOUD_API_KEY"],
        result_type="markdown",
        num_workers=4,
        language="en",
        skip_diagonal_text=True,
    )

    documents = parser.load_data(paths)

    # Documents are matched to paths purely by position. If the counts
    # differ, every later pairing would attach a receipt to the wrong id,
    # so fail loudly before spending any LLM calls.
    if len(documents) != len(paths):
        raise ValueError(
            f"Parser returned {len(documents)} document(s) for "
            f"{len(paths)} input path(s); cannot match documents to files."
        )

    program = OpenAIPydanticProgram.from_defaults(
        output_cls=Receipt,
        llm=Settings.llm,
        prompt_template_str=prompt,
    )

    results: List[dict] = []
    for path, doc in zip(paths, documents, strict=True):
        results.append(
            {
                id_column: Path(path).stem,
                "data": program(context_str=doc.text),
            }
        )
    return results
126+
127+
128+
def transform_receipt_columns(df: pd.DataFrame) -> pd.DataFrame:
    """Return a new DataFrame with the receipt columns normalised.

    The input frame is left untouched.

    Transforms:
        - company: upper-cased
        - total: converted to numeric; unparseable values become NaN
        - purchase_date: parsed day-first and reduced to plain dates;
          unparseable values become NaT
    """
    company_upper = df["company"].str.upper()
    total_numeric = pd.to_numeric(df["total"], errors="coerce")
    parsed_dates = pd.to_datetime(
        df["purchase_date"], errors="coerce", dayfirst=True
    )

    return df.assign(
        company=company_upper,
        total=total_numeric,
        purchase_date=parsed_dates.dt.date,
    )
147+
148+
149+
def create_extracted_df(records: List[dict], id_column: str) -> pd.DataFrame:
    """Flatten extraction records into a normalised comparison DataFrame.

    Args:
        records: Dictionaries holding ``id_column`` and a ``"data"`` object
            that exposes ``company``, ``total`` and ``purchase_date``.
        id_column: Name of the identifier key/column.

    Returns:
        DataFrame with the id, company, total and purchase_date columns,
        passed through ``transform_receipt_columns``.
    """
    rows = []
    for entry in records:
        entry_id = entry[id_column]
        receipt = entry["data"]
        rows.append(
            {
                id_column: entry_id,
                "company": receipt.company,
                "total": receipt.total,
                "purchase_date": receipt.purchase_date,
            }
        )

    flattened = pd.DataFrame(rows)
    return transform_receipt_columns(flattened)
162+
163+
164+
def normalize_date(value: Optional[str]) -> str:
    """Normalise a label date string to ``/`` separators and a 4-digit year.

    Surrounding whitespace is stripped and ``-`` separators become ``/``.
    A two-digit trailing year is expanded to ``20YY``, but only for
    three-part, day-first style dates (first part at most two characters,
    last part all digits). Year-first dates such as ``2018-12-25`` keep
    their day untouched instead of having it mistaken for a short year.

    Args:
        value: Raw date text; ``None`` or an empty string is allowed.

    Returns:
        The normalised date string, or ``""`` when no value was given.
    """
    value = (value or "").strip()
    if not value:
        return value

    parts = value.replace("-", "/").split("/")
    trailing = parts[-1]

    # Only expand when the last part really is a short year: exactly three
    # components, the first one not a 4-digit year, and a numeric tail.
    looks_day_first = len(parts) == 3 and len(parts[0]) <= 2
    if looks_day_first and len(trailing) == 2 and trailing.isdigit():
        parts[-1] = f"20{trailing}"

    return "/".join(parts)
173+
174+
175+
def create_ground_truth_df(label_paths: List[str], id_column: str) -> pd.DataFrame:
    """Create the ground truth DataFrame from label JSON files.

    Each label file is read as a single JSON object. The ``company``,
    ``total`` and ``date`` entries are copied (missing keys default to an
    empty string) and the file stem is used as the record id.

    Args:
        label_paths: Paths of the label files.
        id_column: Name of the identifier column.

    Returns:
        DataFrame with the id, company, total and purchase_date columns,
        passed through ``transform_receipt_columns``.
    """
    # Local import keeps this block self-contained. The stdlib parser
    # returns the label values verbatim, with no dtype or date inference
    # applied before ``normalize_date`` sees them.
    import json

    records = []
    for path in label_paths:
        label_file = Path(path)
        payload = json.loads(label_file.read_text(encoding="utf-8"))
        records.append(
            {
                id_column: label_file.stem,
                "company": payload.get("company", ""),
                "total": payload.get("total", ""),
                "purchase_date": normalize_date(payload.get("date", "")),
            }
        )

    df = pd.DataFrame(records)
    return transform_receipt_columns(df)
191+
192+
193+
def fuzzy_match_score(text1: str, text2: str) -> float:
    """Calculate the fuzzy match score between two strings.

    Both arguments are passed through ``str()`` before comparison, so
    non-string values are compared by their string form.

    Args:
        text1: First string to compare.
        text2: Second string to compare.

    Returns:
        Token-set similarity score between 0 and 100. rapidfuzz reports
        this as a float, hence the ``float`` return annotation.
    """
    return fuzz.token_set_ratio(str(text1), str(text2))
204+
205+
206+
def compare_receipts(
    extracted_df: pd.DataFrame,
    ground_truth_df: pd.DataFrame,
    id_column: str,
    fuzzy_match_cols: List[str],
    exact_match_cols: List[str],
    fuzzy_threshold: int = 80,
) -> pd.DataFrame:
    """Compare extracted and ground truth data with explicit column specifications.

    The two frames are inner-joined on ``id_column``; shared columns get
    ``_extracted`` / ``_truth`` suffixes. Every fuzzy column gains
    ``<col>_score`` and ``<col>_match`` columns; every exact column gains a
    ``<col>_match`` column.

    Args:
        extracted_df: DataFrame with extracted data.
        ground_truth_df: DataFrame with ground truth data.
        id_column: Column to join on.
        fuzzy_match_cols: Columns to compare using fuzzy matching.
        exact_match_cols: Columns to compare using exact matching.
        fuzzy_threshold: Similarity threshold for fuzzy matching (default: 80).

    Returns:
        The merged DataFrame including the score and match columns.
    """
    comparison_df = extracted_df.merge(
        ground_truth_df,
        on=id_column,
        how="inner",
        suffixes=("_extracted", "_truth"),
    )

    # Fuzzy matching. Scores are built by pairing the two columns directly
    # instead of DataFrame.apply(axis=1): on an empty merge result, apply
    # hands back a DataFrame, which cannot be assigned to a single column.
    for col in fuzzy_match_cols:
        extracted_values = comparison_df[f"{col}_extracted"]
        truth_values = comparison_df[f"{col}_truth"]
        scores = [
            fuzzy_match_score(extracted, truth)
            for extracted, truth in zip(extracted_values, truth_values)
        ]
        comparison_df[f"{col}_score"] = pd.Series(
            scores, index=comparison_df.index, dtype="float64"
        )
        comparison_df[f"{col}_match"] = comparison_df[f"{col}_score"] >= fuzzy_threshold

    # Exact matching
    for col in exact_match_cols:
        comparison_df[f"{col}_match"] = (
            comparison_df[f"{col}_extracted"] == comparison_df[f"{col}_truth"]
        )

    return comparison_df
250+
251+
252+
def get_mismatch_rows(comparison_df: pd.DataFrame) -> pd.DataFrame:
    """Return the rows that fail at least one match check.

    Columns whose name ends in ``_match`` are used only to pick the rows
    and are left out of the returned frame.
    """
    # Split the columns into match indicators and everything else.
    match_columns = []
    data_columns = []
    for name in comparison_df.columns:
        if name.endswith("_match"):
            match_columns.append(name)
        else:
            data_columns.append(name)

    # A row is a mismatch unless every indicator is True.
    every_check_passed = comparison_df[match_columns].all(axis=1)

    return comparison_df.loc[~every_check_passed, data_columns]
262+
263+
264+
def main(
    receipt_paths: List[str],
    label_paths: List[str],
    preprocess: bool = True,
    output_dir: Path = Path("data/SROIE2019/train/img_adjusted"),
    id_column: str = "receipt_id",
) -> None:
    """Run receipt extraction and print how it compares to the labels.

    Args:
        receipt_paths: Receipt image paths to extract from.
        label_paths: Ground-truth label file paths.
        preprocess: When true, every image is upscaled 3x into
            ``output_dir`` and the upscaled copies are parsed instead.
        output_dir: Destination directory for upscaled images.
        id_column: Name of the column used to join extraction and labels.
    """
    configure_settings()

    # Parse the originals unless preprocessing was requested.
    receipt_paths_to_parse = receipt_paths
    if preprocess:
        print("Preprocessing receipt images...")
        receipt_paths_to_parse = [
            scale_image(Path(source), output_dir, scale_factor=3)
            for source in receipt_paths
        ]

    prompt = """
You are extracting structured data from a receipt.
Use the provided text to populate the Receipt model.
Interpret every receipt date as day-first.
If a field is missing, return null.

{context_str}
"""

    structured_receipts = extract_documents(receipt_paths_to_parse, prompt, id_column)
    extracted_df = create_extracted_df(structured_receipts, id_column)
    ground_truth_df = create_ground_truth_df(label_paths, id_column)

    # Company names are compared fuzzily; totals and dates must be equal.
    comparison_df = compare_receipts(
        extracted_df,
        ground_truth_df,
        id_column,
        fuzzy_match_cols=["company"],
        exact_match_cols=["total", "purchase_date"],
    )
    mismatch_df = get_mismatch_rows(comparison_df)

    if not mismatch_df.empty:
        print("Mismatched receipts:")
        print(mismatch_df)
        return

    print("All receipts matched the ground truth.")
310+
311+
312+
if __name__ == "__main__":
    # Default paths
    receipt_dir = Path("data/SROIE2019/train/img")
    label_dir = Path("data/SROIE2019/train/entities")
    adjusted_receipt_dir = Path("data/SROIE2019/train/img_adjusted")

    # Default number of receipts
    num_receipts = 10

    # Pair each receipt with the label file that shares its stem. Slicing
    # two independently sorted listings can drift out of step when either
    # directory holds a file the other lacks, so labels are derived from
    # the receipts and receipts without a label file are skipped.
    labelled_pairs = [
        (image_path, label_dir / (image_path.stem + ".txt"))
        for image_path in sorted(receipt_dir.glob("*.jpg"))
    ]
    labelled_pairs = [
        (image_path, label_path)
        for image_path, label_path in labelled_pairs
        if label_path.is_file()
    ][:num_receipts]

    receipt_paths = [image_path for image_path, _ in labelled_pairs]
    label_paths = [label_path for _, label_path in labelled_pairs]

    # Run the pipeline
    main(receipt_paths, label_paths, preprocess=True, output_dir=adjusted_receipt_dir)

0 commit comments

Comments
 (0)