from __future__ import annotations

import ssl
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any

import requests.exceptions
import tiktoken
    (1_000, "k"),
]

+# Cache the tiktoken encoding for performance.
+_TIKTOKEN_ENCODING: Any | None = None
+
+def _get_tiktoken_encoding() -> Any:
+    """Return the cached tiktoken encoding, initializing it only once."""
+    global _TIKTOKEN_ENCODING
+    if _TIKTOKEN_ENCODING is None:
+        _TIKTOKEN_ENCODING = tiktoken.get_encoding("o200k_base")
+    return _TIKTOKEN_ENCODING
+
+
+def _estimate_tokens(text: str) -> int:
+    """Estimate the token count for a given text.
+
+    Parameters
+    ----------
+    text : str
+        The text string for which the token count is to be estimated.
+
+    Returns
+    -------
+    int
+        The number of tokens, or 0 if an error occurs.
+
+    """
+    if not text:
+        return 0
+    try:
+        encoding = _get_tiktoken_encoding()
+        return len(encoding.encode(text, disallowed_special=()))
+    except (ValueError, UnicodeEncodeError) as exc:
+        logger.warning("Failed to estimate token size", extra={"error": str(exc)})
+        return 0
+    except (requests.exceptions.RequestException, ssl.SSLError) as exc:
+        # On network errors, skip token count estimation instead of erroring out.
+        logger.warning("Failed to download tiktoken model", extra={"error": str(exc)})
+        return 0
+
+
+def _format_token_number(count: int) -> str:
+    """Return a human-readable token-count string (e.g. 1.2k, 1.2M).
+
+    Parameters
+    ----------
+    count : int
+        The token count to format.
+
+    Returns
+    -------
+    str
+        The formatted token count (e.g., ``"1.2k"``, ``"1.2M"``), or an empty string if ``count`` is 0.
+
+    """
+    if count == 0:
+        return ""
+    for threshold, suffix in _TOKEN_THRESHOLDS:
+        if count >= threshold:
+            return f"{count / threshold:.1f}{suffix}"
+    return str(count)
+

def format_node(node: FileSystemNode, query: IngestionQuery) -> tuple[str, str, str]:
    """Generate a summary, directory structure, and file contents for a given file system node.
@@ -51,9 +111,17 @@ def format_node(node: FileSystemNode, query: IngestionQuery) -> tuple[str, str,
        summary += f"File: {node.name}\n"
        summary += f"Lines: {len(node.content.splitlines()):,}\n"

+    content = _gather_file_contents(node)
+
    tree = "Directory structure:\n" + _create_tree_structure(query, node=node)

-    content = _gather_file_contents(node)
+    # Estimate tokens for the entire digest (tree + content), i.e. what users download or copy.
+    total_tokens = _estimate_tokens(tree + content)
+
+    # Set the root node's token count so it matches the total exactly.
+    node.token_count = total_tokens
+
+    tree = "Directory structure:\n" + _create_tree_structure(query, node=node)

    token_estimate = _format_token_count(tree + content)
    if token_estimate:
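
The ordering above matters: the tree is rendered once so the grand total can be estimated over `tree + content`, then rendered again after `node.token_count` is set so the root line shows that total. A minimal sketch of the dependency, using hypothetical stand-ins rather than the real `FileSystemNode` and helpers:

```python
# Hypothetical stand-ins (not the real code) showing why the tree is rendered twice.
class _Root:
    def __init__(self, name: str) -> None:
        self.name = name
        self.token_count = 0

def _render(node: _Root) -> str:  # stand-in for _create_tree_structure
    suffix = f" ({node.token_count} tokens)" if node.token_count else ""
    return f"{node.name}/{suffix}\n"

root = _Root("repo")
content = "print('hello')\n"
tree = _render(root)                              # "repo/\n" -- no count yet
root.token_count = len((tree + content).split())  # crude stand-in for _estimate_tokens
tree = _render(root)                              # root line now carries the total
print(tree, end="")                               # repo/ (2 tokens)
```
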
@@ -107,6 +175,7 @@ def _gather_file_contents(node: FileSystemNode) -> str:

    This function recursively processes a directory node and gathers the contents of all files
    under that node. It returns the concatenated content of all files as a single string.
+    It also calculates and aggregates token counts during the traversal.

    Parameters
    ----------
@@ -120,10 +189,17 @@ def _gather_file_contents(node: FileSystemNode) -> str:

    """
    if node.type != FileSystemNodeType.DIRECTORY:
+        node.token_count = _estimate_tokens(node.content)
        return node.content_string

-    # Recursively gather contents of all files under the current directory
-    return "\n".join(_gather_file_contents(child) for child in node.children)
+    # Recursively gather contents and aggregate token counts.
+    node.token_count = 0
+    contents = []
+    for child in node.children:
+        contents.append(_gather_file_contents(child))
+        node.token_count += child.token_count
+
+    return "\n".join(contents)


def _create_tree_structure(
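
The aggregation above is a post-order walk: each file node gets its own estimate, and each directory ends up with the sum over its children. A toy version with hypothetical stand-in nodes (not the real `FileSystemNode`):

```python
# Toy nodes (illustration only); the real code estimates tokens from file content.
class _ToyNode:
    def __init__(self, children=None, tokens=0):
        self.children = children or []
        self.token_count = tokens

def _aggregate(node: "_ToyNode") -> int:
    if not node.children:  # file: token_count already set
        return node.token_count
    node.token_count = sum(_aggregate(child) for child in node.children)
    return node.token_count

root = _ToyNode(children=[_ToyNode(tokens=120),
                          _ToyNode(children=[_ToyNode(tokens=30), _ToyNode(tokens=50)])])
print(_aggregate(root))  # 200 -- a directory reports the sum over its subtree
```
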
@@ -169,6 +245,10 @@ def _create_tree_structure(
    elif node.type == FileSystemNodeType.SYMLINK:
        display_name += " -> " + readlink(node.path).name

+    if node.token_count > 0:
+        formatted_tokens = _format_token_number(node.token_count)
+        display_name += f" ({formatted_tokens} tokens)"
+
    tree_str += f"{prefix}{current_prefix}{display_name}\n"

    if node.type == FileSystemNodeType.DIRECTORY and node.children:
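
For reference, a rendered tree entry ends up looking like the following (illustrative values only):

```python
# Illustrative only: how a directory entry reads once its count is attached.
display_name = "src/"
formatted_tokens = "12.3k"  # e.g. _format_token_number(12_345) with a (1_000, "k") threshold
display_name += f" ({formatted_tokens} tokens)"
print(display_name)  # src/ (12.3k tokens)
```
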
@@ -192,15 +272,8 @@ def _format_token_count(text: str) -> str | None:
        The formatted number of tokens as a string (e.g., ``"1.2k"``, ``"1.2M"``), or ``None`` if an error occurs.

    """
-    try:
-        encoding = tiktoken.get_encoding("o200k_base")  # gpt-4o, gpt-4o-mini
-        total_tokens = len(encoding.encode(text, disallowed_special=()))
-    except (ValueError, UnicodeEncodeError) as exc:
-        logger.warning("Failed to estimate token size", extra={"error": str(exc)})
-        return None
-    except (requests.exceptions.RequestException, ssl.SSLError) as exc:
-        # If network errors, skip token count estimation instead of erroring out
-        logger.warning("Failed to download tiktoken model", extra={"error": str(exc)})
+    total_tokens = _estimate_tokens(text)
+    if total_tokens == 0:
        return None

    for threshold, suffix in _TOKEN_THRESHOLDS:
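
Finally, a self-contained sketch of the threshold formatting shared by `_format_token_number` and `_format_token_count`. Only the `(1_000, "k")` entry of `_TOKEN_THRESHOLDS` is visible in this diff, so the `(1_000_000, "M")` entry below is an assumption:

```python
# Standalone sketch; mirrors _format_token_number from this commit.
_TOKEN_THRESHOLDS = [(1_000_000, "M"), (1_000, "k")]  # descending; the "M" entry is assumed

def format_token_number(count: int) -> str:
    if count == 0:
        return ""
    for threshold, suffix in _TOKEN_THRESHOLDS:
        if count >= threshold:
            return f"{count / threshold:.1f}{suffix}"
    return str(count)

print(format_token_number(0))          # "" -- zero is treated as "no estimate"
print(format_token_number(950))        # 950
print(format_token_number(1_234))      # 1.2k
print(format_token_number(2_500_000))  # 2.5M
```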