Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 15 additions & 23 deletions third_party/cuquantum/cuquantum_configure.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
"""Setup cuQuantum as external dependency"""
_CUQUANTUM_ROOT = "CUQUANTUM_ROOT"


def _tpl(repository_ctx, tpl, substitutions = {}, out = None):
if not out:
out = tpl
Expand All @@ -25,14 +24,12 @@ def _tpl(repository_ctx, tpl, substitutions = {}, out = None):
substitutions,
)


def _fail(msg):
"""Output failure message when auto configuration fails."""
red = "\033[0;31m"
no_color = "\033[0m"
fail("%sPython Configuration Error:%s %s\n" % (red, no_color, msg))


def _execute(
repository_ctx,
cmdline,
Expand Down Expand Up @@ -60,7 +57,6 @@ def _execute(
]))
return result


def _read_dir(repository_ctx, src_dir):
"""Returns a string with all files in a directory.

Expand All @@ -76,18 +72,17 @@ def _read_dir(repository_ctx, src_dir):
result = find_result.stdout
return result


def _find_file(repository_ctx, filename):
"""Returns a string with a directory path including the filename.

The returned string contains the parent path of the filename.
"""
result = repository_ctx.execute(
["timeout", "5", "find", "/", "-name", filename, "-print", "-quit", "-not", "-path", "'*/.*'", "-quit"]).stdout
result = result[:result.find(filename)+len(filename)]
["timeout", "5", "find", "/", "-path", "*/.*", "-prune", "-o", "-name", filename, "-print", "-quit"],
).stdout
result = result[:result.find(filename) + len(filename)]
return result


def _genrule(genrule_name, command, outs):
"""Returns a string with a genrule.

Expand Down Expand Up @@ -121,7 +116,6 @@ def _norm_path(path):
path = path[:-1]
return path


def _symlink_genrule_for_dir(
repository_ctx,
src_dir,
Expand Down Expand Up @@ -170,9 +164,9 @@ def _symlink_genrule_for_dir(
"""
if is_empty_genrule:
if dest_dir != "":
target_path = "%s/%s.h" % (dest_dir, genrule_name)
target_path = "%s/%s.h" % (dest_dir, genrule_name)
else:
target_path = genrule_name
target_path = genrule_name
genrule = _genrule(
genrule_name,
"touch $(OUTS)",
Expand Down Expand Up @@ -208,21 +202,20 @@ def _symlink_genrule_for_dir(
)
return genrule


def _cuquantum_pip_impl(repository_ctx):
if _CUQUANTUM_ROOT in repository_ctx.os.environ:
cuquantum_root = repository_ctx.os.environ[_CUQUANTUM_ROOT]
cuquantum_root = repository_ctx.os.environ[_CUQUANTUM_ROOT]
else:
repository_ctx.os.environ[_CUQUANTUM_ROOT] = ""
cuquantum_root = ""
repository_ctx.os.environ[_CUQUANTUM_ROOT] = ""
cuquantum_root = ""

if cuquantum_root == "":
cuquantum_header_path = _find_file(repository_ctx, "custatevec.h")
cuquantum_header_path = cuquantum_header_path[:cuquantum_header_path.find("/custatevec.h")]
custatevec_shared_library_path = _find_file(repository_ctx, "libcustatevec.so")
cuquantum_header_path = _find_file(repository_ctx, "custatevec.h")
cuquantum_header_path = cuquantum_header_path[:cuquantum_header_path.find("/custatevec.h")]
custatevec_shared_library_path = _find_file(repository_ctx, "libcustatevec.so")
else:
cuquantum_header_path = "%s/include" % cuquantum_root
custatevec_shared_library_path = "%s/lib/libcustatevec.so" % (cuquantum_root)
cuquantum_header_path = "%s/include" % cuquantum_root
custatevec_shared_library_path = "%s/lib/libcustatevec.so" % (cuquantum_root)

is_empty_genrule = cuquantum_header_path == "" or custatevec_shared_library_path == ""

Expand All @@ -231,7 +224,7 @@ def _cuquantum_pip_impl(repository_ctx):
cuquantum_header_path,
"include",
"cuquantum_header_include",
is_empty_genrule=is_empty_genrule,
is_empty_genrule = is_empty_genrule,
)

custatevec_shared_library_rule = _symlink_genrule_for_dir(
Expand All @@ -241,15 +234,14 @@ def _cuquantum_pip_impl(repository_ctx):
"libcustatevec.so",
[custatevec_shared_library_path],
["libcustatevec.so"],
is_empty_genrule=is_empty_genrule,
is_empty_genrule = is_empty_genrule,
)

_tpl(repository_ctx, "BUILD", {
"%{CUQUANTUM_HEADER_GENRULE}": cuquantum_header_rule,
"%{CUSTATEVEC_SHARED_LIBRARY_GENRULE}": custatevec_shared_library_rule,
})


cuquantum_configure = repository_rule(
implementation = _cuquantum_pip_impl,
environ = [
Expand Down
Loading