# /// script
# requires-python = ">=3.11"
# dependencies = [
#     "llama-index",
#     "llama-index-program-openai",
#     "llama-parse",
#     "pandas",
#     "pillow",
#     "python-dotenv",
#     "rapidfuzz",
# ]
# ///
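#
# A minimal way to run this script, assuming a PEP 723-aware runner such as
# `uv` (the filename below is illustrative):
#
#   uv run extract_receipts.py
#
# OPENAI_API_KEY and LLAMA_CLOUD_API_KEY must be set in the environment or
# in a .env file.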
import os
from datetime import date
from pathlib import Path
from typing import List, Optional

import pandas as pd
from dotenv import load_dotenv
from llama_index.core import Settings
from llama_index.llms.openai import OpenAI
from llama_index.program.openai import OpenAIPydanticProgram
from llama_parse import LlamaParse
from PIL import Image
from pydantic import BaseModel, Field
from rapidfuzz import fuzz


def configure_settings() -> None:
    """Load environment variables and configure the default LLM."""
    load_dotenv(override=True)

    openai_key = os.environ.get("OPENAI_API_KEY")
    Settings.llm = OpenAI(api_key=openai_key, model="gpt-4o-mini", temperature=0)
    Settings.context_window = 8000


def scale_image(image_path: Path, output_dir: Path, scale_factor: int = 3) -> Path:
    """Scale up an image using high-quality resampling.

    Args:
        image_path: Path to the original image
        output_dir: Directory to save the scaled image
        scale_factor: Factor to scale up the image (default: 3x)

    Returns:
        Path to the scaled image
    """
    # Load the image
    img = Image.open(image_path)

    # Scale up the image using high-quality resampling
    new_size = (img.width * scale_factor, img.height * scale_factor)
    img_resized = img.resize(new_size, Image.Resampling.LANCZOS)

    # Save to the output directory with the same filename
    output_dir.mkdir(parents=True, exist_ok=True)
    output_path = output_dir / image_path.name
    img_resized.save(output_path, quality=95)

    return output_path
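# Illustrative usage of scale_image (paths are hypothetical):
#   scaled = scale_image(Path("data/receipt.jpg"), Path("data/scaled"), scale_factor=2)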


class ReceiptItem(BaseModel):
    """Line item extracted from a receipt."""

    description: str = Field(description="Item name exactly as shown on the receipt")
    quantity: int = Field(default=1, ge=1, description="Integer quantity of the item")
    unit_price: Optional[float] = Field(
        default=None, ge=0, description="Price per unit in the receipt currency"
    )
    discount_amount: float = Field(
        default=0.0, ge=0, description="Discount applied to this line item"
    )


class Receipt(BaseModel):
    """Structured receipt fields extracted from OCR."""

    company: str = Field(description="Business or merchant name")
    purchase_date: Optional[date] = Field(
        default=None, description="Date in YYYY-MM-DD format"
    )
    address: Optional[str] = Field(default=None, description="Address of the business")
    total: float = Field(description="Final charged amount")
    items: List[ReceiptItem] = Field(default_factory=list)

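# For illustration only (all values invented), an extracted Receipt might be:
#   Receipt(company="BOOK STORE SDN BHD", purchase_date=date(2018, 3, 12),
#           address="1, JALAN CONTOH, 81100 JOHOR", total=9.00,
#           items=[ReceiptItem(description="PEN", quantity=2, unit_price=4.50)])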

def extract_documents(paths: List[Path], prompt: str, id_column: str) -> List[dict]:
    """Extract structured data from documents using LlamaParse and an LLM.

    Args:
        paths: List of document file paths
        prompt: Extraction prompt template
        id_column: Name of the ID column that identifies each document

    Returns:
        List of dictionaries with the document ID and extracted data
    """
    results: List[dict] = []

    parser = LlamaParse(
        api_key=os.environ["LLAMA_CLOUD_API_KEY"],
        result_type="markdown",
        num_workers=4,
        language="en",
        skip_diagonal_text=True,
    )

    documents = parser.load_data([str(p) for p in paths])

    program = OpenAIPydanticProgram.from_defaults(
        output_cls=Receipt,
        llm=Settings.llm,
        prompt_template_str=prompt,
    )

    # strict=True fails loudly if LlamaParse does not return exactly one
    # document per input (expected here, since each receipt is a single image),
    # instead of silently mispairing paths and documents.
    for path, doc in zip(paths, documents, strict=True):
        document_id = Path(path).stem
        parsed_document = program(context_str=doc.text)
        results.append(
            {
                id_column: document_id,
                "data": parsed_document,
            }
        )
    return results
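# Illustrative call (hypothetical file; requires both API keys to be set):
#   records = extract_documents([Path("img/receipt_001.jpg")], prompt, "receipt_id")
#   records[0]["data"]  # -> a Receipt instance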


def transform_receipt_columns(df: pd.DataFrame) -> pd.DataFrame:
    """Apply standard transformations to receipt DataFrame columns.

    Transforms:
        - company: convert to uppercase
        - total: coerce to numeric
        - purchase_date: parse as a day-first date
    """
    df = df.copy()

    df["company"] = df["company"].str.upper()

    df["total"] = pd.to_numeric(df["total"], errors="coerce")

    df["purchase_date"] = pd.to_datetime(
        df["purchase_date"], errors="coerce", dayfirst=True
    ).dt.date

    return df


def create_extracted_df(records: List[dict], id_column: str) -> pd.DataFrame:
    """Build a normalized DataFrame from the extracted Receipt records."""
    df = pd.DataFrame(
        [
            {
                id_column: record[id_column],
                "company": record["data"].company,
                "total": record["data"].total,
                "purchase_date": record["data"].purchase_date,
            }
            for record in records
        ]
    )
    return transform_receipt_columns(df)


def normalize_date(value: str) -> str:
    """Normalize label dates to DD/MM/YYYY-style strings.

    Replaces hyphens with slashes and expands two-digit years to 20xx.
    """
    value = (value or "").strip()
    if not value:
        return value
    value = value.replace("-", "/")
    parts = value.split("/")
    if len(parts[-1]) == 2:
        parts[-1] = f"20{parts[-1]}"
    return "/".join(parts)
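# For example, normalize_date("12-03-18") returns "12/03/2018", which
# pd.to_datetime(..., dayfirst=True) later parses as 12 March 2018.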


def create_ground_truth_df(label_paths: List[Path], id_column: str) -> pd.DataFrame:
    """Create a ground-truth DataFrame from SROIE label JSON files."""
    records = []
    for path in label_paths:
        payload = pd.read_json(path, typ="series").to_dict()
        records.append(
            {
                id_column: path.stem,
                "company": payload.get("company", ""),
                "total": payload.get("total", ""),
                "purchase_date": normalize_date(payload.get("date", "")),
            }
        )

    df = pd.DataFrame(records)
    return transform_receipt_columns(df)


def fuzzy_match_score(text1: str, text2: str) -> float:
    """Calculate the fuzzy match score between two strings.

    Args:
        text1: First string to compare
        text2: Second string to compare

    Returns:
        Similarity score between 0 and 100
    """
    return fuzz.token_set_ratio(str(text1), str(text2))
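# token_set_ratio ignores word order and duplicated tokens, so for example
# fuzzy_match_score("STARBUCKS COFFEE", "COFFEE STARBUCKS") scores 100, while
# entirely unrelated strings score near 0.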


def compare_receipts(
    extracted_df: pd.DataFrame,
    ground_truth_df: pd.DataFrame,
    id_column: str,
    fuzzy_match_cols: List[str],
    exact_match_cols: List[str],
    fuzzy_threshold: int = 80,
) -> pd.DataFrame:
    """Compare extracted and ground-truth data with explicit column specifications.

    Args:
        extracted_df: DataFrame with extracted data
        ground_truth_df: DataFrame with ground-truth data
        id_column: Column to join on
        fuzzy_match_cols: Columns to compare using fuzzy matching
        exact_match_cols: Columns to compare using exact matching
        fuzzy_threshold: Similarity threshold for fuzzy matching (default: 80)
    """
    comparison_df = extracted_df.merge(
        ground_truth_df,
        on=id_column,
        how="inner",
        suffixes=("_extracted", "_truth"),
    )

    # Fuzzy matching
    for col in fuzzy_match_cols:
        extracted_col = f"{col}_extracted"
        truth_col = f"{col}_truth"
        comparison_df[f"{col}_score"] = comparison_df.apply(
            lambda row, ec=extracted_col, tc=truth_col: fuzzy_match_score(row[ec], row[tc]),
            axis=1,
        )
        comparison_df[f"{col}_match"] = comparison_df[f"{col}_score"] >= fuzzy_threshold

    # Exact matching
    for col in exact_match_cols:
        extracted_col = f"{col}_extracted"
        truth_col = f"{col}_truth"
        comparison_df[f"{col}_match"] = (
            comparison_df[extracted_col] == comparison_df[truth_col]
        )

    return comparison_df
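# With fuzzy_match_cols=["company"] and exact_match_cols=["total", "purchase_date"]
# (as used in main below), the merged frame gains company_score, company_match,
# total_match, and purchase_date_match columns alongside the suffixed value columns.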


def get_mismatch_rows(comparison_df: pd.DataFrame) -> pd.DataFrame:
    """Get mismatched rows, excluding match indicator columns."""
    # Separate match-indicator columns from data columns
    match_columns = [col for col in comparison_df.columns if col.endswith("_match")]
    data_columns = [col for col in comparison_df.columns if col not in match_columns]

    # Keep rows where at least one comparison failed
    has_mismatch = ~comparison_df[match_columns].all(axis=1)

    return comparison_df.loc[has_mismatch, data_columns]


def main(
    receipt_paths: List[Path],
    label_paths: List[Path],
    preprocess: bool = True,
    output_dir: Path = Path("data/SROIE2019/train/img_adjusted"),
    id_column: str = "receipt_id",
) -> None:
    configure_settings()

    # Preprocess images if requested
    if preprocess:
        print("Preprocessing receipt images...")
        receipt_paths_to_parse = [
            scale_image(p, output_dir, scale_factor=3) for p in receipt_paths
        ]
    else:
        receipt_paths_to_parse = receipt_paths

    prompt = """
    You are extracting structured data from a receipt.
    Use the provided text to populate the Receipt model.
    Interpret every receipt date as day-first.
    If a field is missing, return null.

    {context_str}
    """

    structured_receipts = extract_documents(receipt_paths_to_parse, prompt, id_column)

    extracted_df = create_extracted_df(structured_receipts, id_column)
    ground_truth_df = create_ground_truth_df(label_paths, id_column)

    comparison_df = compare_receipts(
        extracted_df,
        ground_truth_df,
        id_column,
        fuzzy_match_cols=["company"],
        exact_match_cols=["total", "purchase_date"],
    )
    mismatch_df = get_mismatch_rows(comparison_df)

    if mismatch_df.empty:
        print("All receipts matched the ground truth.")
    else:
        print("Mismatched receipts:")
        print(mismatch_df)


if __name__ == "__main__":
    # Default paths
    receipt_dir = Path("data/SROIE2019/train/img")
    label_dir = Path("data/SROIE2019/train/entities")
    adjusted_receipt_dir = Path("data/SROIE2019/train/img_adjusted")

    # Default number of receipts to evaluate
    num_receipts = 10
    receipt_paths = sorted(receipt_dir.glob("*.jpg"))[:num_receipts]
    label_paths = sorted(label_dir.glob("*.txt"))[:num_receipts]

    # Run the pipeline
    main(receipt_paths, label_paths, preprocess=True, output_dir=adjusted_receipt_dir)