class DirectFsAccessPyFixer(DirectFsAccessPyLinter):
    """Rewrites direct filesystem access (DFSA) references in Python code to the
    Unity Catalog tables they were migrated to.

    First call :meth:`populate_directfs_table_list` to build a lookup of
    DFSA path -> destination-table details, then :meth:`fix_tree` to walk a
    parsed tree and rewrite the read/write nodes found in it.
    """

    def __init__(
        self,
        session_state: CurrentSessionState,
        directfs_crawler: DirectFsAccessCrawler,
        tables_crawler: TablesCrawler,
        prevent_spark_duplicates: bool = True,
    ):
        super().__init__(session_state, prevent_spark_duplicates)
        self.directfs_crawler = directfs_crawler
        self.tables_crawler = tables_crawler
        # Maps a DFSA path to the details of the table that replaces it.
        # NOTE: the original annotation `[Any, [dict[str, str], Any]]` was not a
        # valid type and declared a *list*, yet `_replace_read` looks entries up
        # by path — a dict keyed by path is what the lookup actually requires.
        self.direct_fs_table_list: dict[str, dict[str, Any]] = {}

    def fix_tree(self, tree: Tree) -> Tree:
        """Rewrite every DFSA node found in *tree*; returns the (mutated) tree."""
        for directfs_node in self.collect_dfsas_from_tree(tree):
            self._fix_node(directfs_node)
        return tree

    def _fix_node(self, directfs_node: DirectFsAccessNode) -> None:
        """Dispatch a single DFSA node to the read or write replacement path."""
        dfsa = directfs_node.dfsa
        if dfsa.is_read:
            self._replace_read(directfs_node)
        elif dfsa.is_write:
            self._replace_write(directfs_node)

    def _replace_read(self, directfs_node: DirectFsAccessNode) -> None:
        """Replace a DFSA read with a read of the mapped Unity Catalog table."""
        dfsa = directfs_node.dfsa
        dfsa_details = self.direct_fs_table_list.get(dfsa.path)
        if dfsa_details is None:
            # No mapping was produced for this path; nothing to rewrite.
            logger.warning(f"No table mapping found for direct fs path: {dfsa.path}")
            return
        # TODO: Actual code replacement
        logger.info(
            f"Replacing read of {dfsa.path} with table "
            f"{dfsa_details['dst_schema']}.{dfsa_details['dst_table']}"
        )

    def _replace_write(self, directfs_node: DirectFsAccessNode) -> None:
        """Replace a DFSA write with a write to the mapped Unity Catalog table."""
        dfsa = directfs_node.dfsa
        # TODO: Actual code replacement
        # (original message said "read" for the write path — fixed)
        logger.info(f"Replacing write of {dfsa.path} with table")

    def populate_directfs_table_list(
        self,
        directfs_crawlers: list[DirectFsAccessCrawler],
        tables_crawler: TablesCrawler,
        workspace_name: str,
        catalog_name: str,
    ) -> None:
        """Build the DFSA-path -> destination-table lookup from crawler snapshots.

        Args:
            directfs_crawlers: crawlers whose snapshots list DFSA records.
            tables_crawler: crawler whose snapshot lists candidate tables.
            workspace_name: recorded verbatim in each mapping entry.
            catalog_name: recorded verbatim in each mapping entry.

        Raises:
            ValueError: when either snapshot is empty.
        """
        directfs_snapshot = [
            record for crawler in directfs_crawlers for record in crawler.snapshot()
        ]
        tables_snapshot = list(tables_crawler.snapshot())
        if not tables_snapshot:
            msg = "No tables found. Please run: databricks labs ucx ensure-assessment-run"
            raise ValueError(msg)
        if not directfs_snapshot:
            msg = "No directfs references found in code"
            raise ValueError(msg)

        # TODO: very inefficient search, just for initial testing
        for table in tables_snapshot:
            if not table.location:
                continue  # tables without a location can never match a DFSA path
            for directfs_record in directfs_snapshot:
                if directfs_record.path in table.location:
                    self.direct_fs_table_list[directfs_record.path] = {
                        "workspace_name": workspace_name,
                        "is_read": directfs_record.is_read,
                        "is_write": directfs_record.is_write,
                        "catalog_name": catalog_name,
                        "dst_schema": table.database,
                        "dst_table": table.name,
                    }
def test_path_dfsa_replacement(
    runtime_ctx,
    make_directory,
    make_mounted_location,
    inventory_schema,
    sql_backend,
) -> None:
    """Verify that direct-fs access in a Python notebook is replaced with a Unity Catalog table.

    NOTE(review): fixtures are injected by pytest by name — the added
    `from integration.conftest import runtime_ctx` import is unnecessary and
    should be dropped from the test module.
    """
    mounted_location = '/mnt/things/e/f/g'
    # The external table gives the tables crawler a location matching the DFSA path.
    runtime_ctx.make_table(external_csv=mounted_location)
    notebook_content = f"display(spark.read.csv('{mounted_location}'))"
    notebook = runtime_ctx.make_notebook(
        path=f"{make_directory()}/notebook.py",
        content=notebook_content.encode("ASCII"),
    )
    job = runtime_ctx.make_job(notebook_path=notebook)

    # Produce a DFSA record for the job.
    linter = WorkflowLinter(
        runtime_ctx.workspace_client,
        runtime_ctx.dependency_resolver,
        runtime_ctx.path_lookup,
        TableMigrationIndex([]),
        runtime_ctx.directfs_access_crawler_for_paths,
        runtime_ctx.used_tables_crawler_for_paths,
        include_job_ids=[job.job_id],
    )
    linter.refresh_report(sql_backend, inventory_schema)

    # Ensure both snapshots are materialized before building the mapping.
    runtime_ctx.tables_crawler.snapshot()
    runtime_ctx.directfs_access_crawler_for_paths.snapshot()

    session_state = CurrentSessionState()
    directfs_py_fixer = DirectFsAccessPyFixer(
        session_state,
        runtime_ctx.directfs_access_crawler_for_paths,
        runtime_ctx.tables_crawler,
    )
    directfs_py_fixer.populate_directfs_table_list(
        [runtime_ctx.directfs_access_crawler_for_paths],
        runtime_ctx.tables_crawler,
        "workspace_name",
        "catalog_name",
    )
    # The DFSA path must have been mapped to a destination table
    # (replaces the original meaningless `assert True`).
    assert directfs_py_fixer.direct_fs_table_list

    fixed_tree = directfs_py_fixer.fix_tree(Tree.maybe_normalized_parse(notebook_content).tree)
    assert fixed_tree is not None