Skip to content

Commit 28872e3

Browse files
Copilot and gkorland committed
Add Import support for Python - initial implementation
Co-authored-by: gkorland <753206+gkorland@users.noreply.github.com>
1 parent 07b7eac commit 28872e3

File tree

5 files changed

+149
-0
lines changed

5 files changed

+149
-0
lines changed

api/analyzers/analyzer.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,3 +143,32 @@ def resolve_symbol(self, files: dict[Path, File], lsp: SyncLanguageServer, file_
143143

144144
pass
145145

146+
@abstractmethod
def add_file_imports(self, file: File) -> None:
    """
    Attach the import statements of ``file`` to the file object.

    Language-specific analyzers provide the actual implementation.

    Args:
        file (File): The file whose import statements should be added.
    """
    ...
156+
157+
@abstractmethod
def resolve_import(self, files: dict[Path, File], lsp: SyncLanguageServer, file_path: Path, path: Path, import_node: Node) -> list[Entity]:
    """
    Turn a single import statement into the entities it refers to.

    Language-specific analyzers provide the actual implementation.

    Args:
        files (dict[Path, File]): Every file in the project, keyed by path.
        lsp (SyncLanguageServer): Language server used for the lookup.
        file_path (Path): Path of the file that contains the import.
        path (Path): Path of the project root.
        import_node (Node): Syntax node of the import statement.

    Returns:
        list[Entity]: The entities the import resolved to.
    """
    ...
174+

api/analyzers/java/analyzer.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,3 +132,19 @@ def resolve_symbol(self, files: dict[Path, File], lsp: SyncLanguageServer, file_
132132
return self.resolve_method(files, lsp, file_path, path, symbol)
133133
else:
134134
raise ValueError(f"Unknown key {key}")
135+
136+
def add_file_imports(self, file: File) -> None:
    """
    Collect the import statements of a Java file.

    Placeholder: import tracking is not implemented for Java yet, so the
    file is left untouched.
    """
    # TODO: Implement Java import tracking
    return None
143+
144+
def resolve_import(self, files: dict[Path, File], lsp: SyncLanguageServer, file_path: Path, path: Path, import_node: Node) -> list[Entity]:
    """
    Map a Java import statement to the entities it brings in.

    Placeholder: import resolution is not implemented for Java yet, so the
    result is always an empty list.
    """
    # TODO: Implement Java import resolution
    nothing_resolved: list[Entity] = []
    return nothing_resolved

api/analyzers/python/analyzer.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,3 +121,73 @@ def resolve_symbol(self, files: dict[Path, File], lsp: SyncLanguageServer, file_
121121
return self.resolve_method(files, lsp, file_path, path, symbol)
122122
else:
123123
raise ValueError(f"Unknown key {key}")
124+
125+
def add_file_imports(self, file: File) -> None:
    """
    Find every import statement in the file's syntax tree and record it.

    Plain ``import ...`` statements are added first, followed by
    ``from ... import ...`` statements, each via ``file.add_import``.

    Args:
        file (File): The parsed file; its ``tree`` is searched for imports.
    """
    import warnings

    # One pattern per statement kind, each with its own capture name.
    pattern = """
    (import_statement) @import
    (import_from_statement) @import_from
    """

    # Warnings raised while building and running the query are silenced.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        found = self.language.query(pattern).captures(file.tree.root_node)

    # A capture name is absent from the result when nothing matched it.
    for capture_name in ('import', 'import_from'):
        for statement in found.get(capture_name, ()):
            file.add_import(statement)
148+
149+
def resolve_import(self, files: dict[Path, File], lsp: SyncLanguageServer, file_path: Path, path: Path, import_node: Node) -> list[Entity]:
    """
    Resolve an import statement to the entities it imports.

    Only the imported names are resolved, i.e. the children stored under the
    ``name`` field of the statement:

    * ``import a, b.c``          -> ``a``, ``b.c``
    * ``import a as x``          -> ``a``
    * ``from m import a, b as y`` -> ``a``, ``b`` (``m`` itself is not resolved)

    Each name is resolved exactly once. Wildcard imports carry no ``name``
    field and therefore resolve to nothing.

    Args:
        files (dict[Path, File]): All files in the project.
        lsp (SyncLanguageServer): The language server.
        file_path (Path): The path to the file containing the import.
        path (Path): The path to the project root.
        import_node (Node): The import statement node.

    Returns:
        list[Entity]: Entities returned by ``resolve_type`` for every
        imported name, in source order.
    """
    res = []

    # Both `import_statement` and `import_from_statement` keep their imported
    # items in the `name` field; each item is either a `dotted_name` or an
    # `aliased_import` wrapping one.
    for imported in import_node.children_by_field_name('name'):
        if imported.type == 'aliased_import':
            # Resolve the real name, not the local alias.
            imported = imported.child_by_field_name('name')
            if imported is None:
                continue
        res.extend(self.resolve_type(files, lsp, file_path, path, imported))

    return res

api/analyzers/source_analyzer.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,10 @@ def first_pass(self, path: Path, files: list[Path], ignore: list[str], graph: Gr
112112
# Walk through the AST
113113
graph.add_file(file)
114114
self.create_hierarchy(file, analyzer, graph)
115+
116+
# Extract import statements
117+
if not analyzer.is_dependency(str(file_path)):
118+
analyzer.add_file_imports(file)
115119

116120
def second_pass(self, graph: Graph, files: list[Path], path: Path) -> None:
117121
"""
@@ -141,6 +145,8 @@ def second_pass(self, graph: Graph, files: list[Path], path: Path) -> None:
141145
for i, file_path in enumerate(files):
142146
file = self.files[file_path]
143147
logging.info(f'Processing file ({i + 1}/{files_len}): {file_path}')
148+
149+
# Resolve entity symbols
144150
for _, entity in file.entities.items():
145151
entity.resolved_symbol(lambda key, symbol: analyzers[file_path.suffix].resolve_symbol(self.files, lsps[file_path.suffix], file_path, path, key, symbol))
146152
for key, symbols in entity.resolved_symbols.items():
@@ -157,6 +163,13 @@ def second_pass(self, graph: Graph, files: list[Path], path: Path) -> None:
157163
graph.connect_entities("RETURNS", entity.id, symbol.id)
158164
elif key == "parameters":
159165
graph.connect_entities("PARAMETERS", entity.id, symbol.id)
166+
167+
# Resolve file imports
168+
for import_node in file.imports:
169+
resolved_entities = analyzers[file_path.suffix].resolve_import(self.files, lsps[file_path.suffix], file_path, path, import_node)
170+
for resolved_entity in resolved_entities:
171+
file.add_resolved_import(resolved_entity)
172+
graph.connect_entities("IMPORTS", file.id, resolved_entity.id)
160173

161174
def analyze_files(self, files: list[Path], path: Path, graph: Graph) -> None:
162175
self.first_pass(path, files, [], graph)

api/entities/file.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from pathlib import Path
22
from tree_sitter import Node, Tree
3+
from typing import Self
34

45
from api.entities.entity import Entity
56

@@ -21,10 +22,30 @@ def __init__(self, path: Path, tree: Tree) -> None:
2122
self.path = path
2223
self.tree = tree
2324
self.entities: dict[Node, Entity] = {}
25+
self.imports: list[Node] = []
26+
self.resolved_imports: set[Self] = set()
2427

2528
def add_entity(self, entity: Entity):
    """
    Register an entity as belonging to this file.

    The entity's ``parent`` is set to this file, and the entity is stored
    in ``self.entities`` under its syntax node.

    Args:
        entity (Entity): The entity to register.
    """
    entity.parent = self
    key = entity.node
    self.entities[key] = entity
31+
32+
def add_import(self, import_node: Node):
    """
    Remember an import statement found in this file.

    Args:
        import_node (Node): Syntax node of the import statement; it is
            appended to ``self.imports``.
    """
    tracked = self.imports
    tracked.append(import_node)
40+
41+
def add_resolved_import(self, resolved_entity: Entity):
    """
    Add an entity that one of this file's imports resolved to.

    Resolved imports are kept in a set, so adding the same entity more than
    once has no further effect.

    Args:
        resolved_entity (Entity): The resolved entity that is imported.
    """
    self.resolved_imports.add(resolved_entity)
2849

2950
def __str__(self) -> str:
3051
return f"path: {self.path}"

0 commit comments

Comments
 (0)