Commit cc5e5c9

feat: add token counting utils with caching
1 parent ffb035b commit cc5e5c9

File tree

1 file changed: +86 -13 lines changed


src/gitingest/output_formatter.py

Lines changed: 86 additions & 13 deletions
@@ -3,7 +3,7 @@
 from __future__ import annotations
 
 import ssl
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any
 
 import requests.exceptions
 import tiktoken
@@ -23,6 +23,66 @@
     (1_000, "k"),
 ]
 
+# cache tiktoken encoding for performance
+_TIKTOKEN_ENCODING: Any | None = None
+
+def _get_tiktoken_encoding() -> Any:
+    """Get cached tiktoken encoding, initializing only once."""
+    global _TIKTOKEN_ENCODING
+    if _TIKTOKEN_ENCODING is None:
+        _TIKTOKEN_ENCODING = tiktoken.get_encoding("o200k_base")
+    return _TIKTOKEN_ENCODING
+
+
+def _estimate_tokens(text: str) -> int:
+    """Estimate token count for a given text.
+
+    Parameters
+    ----------
+    text : str
+        The text string for which the token count is to be estimated.
+
+    Returns
+    -------
+    int
+        The number of tokens, or 0 if an error occurs.
+
+    """
+    if not text:
+        return 0
+    try:
+        encoding = _get_tiktoken_encoding()
+        return len(encoding.encode(text, disallowed_special=()))
+    except (ValueError, UnicodeEncodeError) as exc:
+        logger.warning("Failed to estimate token size", extra={"error": str(exc)})
+        return 0
+    except (requests.exceptions.RequestException, ssl.SSLError) as exc:
+        # if network errors, skip token count estimation instead of erroring out
+        logger.warning("Failed to download tiktoken model", extra={"error": str(exc)})
+        return 0
+
+
+def _format_token_number(count: int) -> str:
+    """Return a human-readable token-count string (e.g. 1.2k, 1.2M).
+
+    Parameters
+    ----------
+    count : int
+        The token count to format.
+
+    Returns
+    -------
+    str
+        The formatted number of tokens as a string (e.g., ``"1.2k"``, ``"1.2M"``), or empty string if count is 0.
+
+    """
+    if count == 0:
+        return ""
+    for threshold, suffix in _TOKEN_THRESHOLDS:
+        if count >= threshold:
+            return f"{count / threshold:.1f}{suffix}"
+    return str(count)
+
 
 def format_node(node: FileSystemNode, query: IngestionQuery) -> tuple[str, str, str]:
     """Generate a summary, directory structure, and file contents for a given file system node.
@@ -51,9 +111,17 @@ def format_node(node: FileSystemNode, query: IngestionQuery) -> tuple[str, str,
         summary += f"File: {node.name}\n"
         summary += f"Lines: {len(node.content.splitlines()):,}\n"
 
+    content = _gather_file_contents(node)
+
     tree = "Directory structure:\n" + _create_tree_structure(query, node=node)
 
-    content = _gather_file_contents(node)
+    # calculate total tokens for entire digest (tree + content) - what users download/copy
+    total_tokens = _estimate_tokens(tree + content)
+
+    # set root node token count to match the total exactly
+    node.token_count = total_tokens
+
+    tree = "Directory structure:\n" + _create_tree_structure(query, node=node)
 
     token_estimate = _format_token_count(tree + content)
     if token_estimate:
@@ -107,6 +175,7 @@ def _gather_file_contents(node: FileSystemNode) -> str:
 
     This function recursively processes a directory node and gathers the contents of all files
     under that node. It returns the concatenated content of all files as a single string.
+    Also calculates and aggregates token counts during traversal.
 
     Parameters
     ----------
@@ -120,10 +189,17 @@ def _gather_file_contents(node: FileSystemNode) -> str:
 
     """
     if node.type != FileSystemNodeType.DIRECTORY:
+        node.token_count = _estimate_tokens(node.content)
        return node.content_string
 
-    # Recursively gather contents of all files under the current directory
-    return "\n".join(_gather_file_contents(child) for child in node.children)
+    # recursively gather contents and aggregate token counts
+    node.token_count = 0
+    contents = []
+    for child in node.children:
+        contents.append(_gather_file_contents(child))
+        node.token_count += child.token_count
+
+    return "\n".join(contents)
 
 
 def _create_tree_structure(
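
Note on the hunk above: _gather_file_contents now aggregates token counts bottom-up, so every directory node carries the sum of its children's counts; the next hunk then renders those counts in the tree. A minimal sketch of that traversal, using a simplified hypothetical node type and a rough characters-per-token estimator in place of the project's FileSystemNode and tiktoken:

    from __future__ import annotations

    from dataclasses import dataclass, field

    @dataclass
    class Node:
        """Simplified stand-in for FileSystemNode; illustration only."""
        name: str
        content: str = ""
        children: list["Node"] = field(default_factory=list)
        token_count: int = 0

    def estimate_tokens(text: str) -> int:
        # Rough stand-in estimator (~4 characters per token) instead of tiktoken.
        return len(text) // 4

    def gather(node: Node) -> str:
        """Concatenate file contents while summing token counts into each directory node."""
        if not node.children:  # treat leaves as files
            node.token_count = estimate_tokens(node.content)
            return node.content
        node.token_count = 0
        parts = []
        for child in node.children:
            parts.append(gather(child))
            node.token_count += child.token_count  # children are counted before their parent
        return "\n".join(parts)

    root = Node("repo", children=[Node("a.py", "print('hello')\n"), Node("b.py", "x = 1\n")])
    digest = gather(root)
    print(root.token_count)  # sum of the two per-file estimates
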
@@ -169,6 +245,10 @@ def _create_tree_structure(
     elif node.type == FileSystemNodeType.SYMLINK:
         display_name += " -> " + readlink(node.path).name
 
+    if node.token_count > 0:
+        formatted_tokens = _format_token_number(node.token_count)
+        display_name += f" ({formatted_tokens} tokens)"
+
     tree_str += f"{prefix}{current_prefix}{display_name}\n"
 
     if node.type == FileSystemNodeType.DIRECTORY and node.children:
@@ -192,15 +272,8 @@ def _format_token_count(text: str) -> str | None:
         The formatted number of tokens as a string (e.g., ``"1.2k"``, ``"1.2M"``), or ``None`` if an error occurs.
 
     """
-    try:
-        encoding = tiktoken.get_encoding("o200k_base")  # gpt-4o, gpt-4o-mini
-        total_tokens = len(encoding.encode(text, disallowed_special=()))
-    except (ValueError, UnicodeEncodeError) as exc:
-        logger.warning("Failed to estimate token size", extra={"error": str(exc)})
-        return None
-    except (requests.exceptions.RequestException, ssl.SSLError) as exc:
-        # If network errors, skip token count estimation instead of erroring out
-        logger.warning("Failed to download tiktoken model", extra={"error": str(exc)})
+    total_tokens = _estimate_tokens(text)
+    if total_tokens == 0:
         return None
 
     for threshold, suffix in _TOKEN_THRESHOLDS:
