diff --git a/deps/ReactantExtra/API.cpp b/deps/ReactantExtra/API.cpp
index efa1a96db9..47a8c5c376 100644
--- a/deps/ReactantExtra/API.cpp
+++ b/deps/ReactantExtra/API.cpp
@@ -160,6 +160,15 @@
 #include "xla/hlo/ir/hlo_module.h"
 #include "xla/service/hlo_cost_analysis.h"
 
+#include "xla/client/client_library.h"
+#include "xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.h"
+#include "xla/hlo/translate/mhlo_to_hlo/type_to_shape.h"
+#include "xla/service/compiler.h"
+#include "xla/service/cpu/backend_config.pb.h"
+#include "xla/service/cpu/cpu_executable.h"
+#include "xla/service/local_service_utils.h"
+#include "xla/service/service.h"
+
 #if defined(REACTANT_CUDA) || defined(REACTANT_ROCM)
 #include "xla/service/gpu/model/gpu_hlo_cost_analysis.h"
 #include "xla/service/gpu/model/gpu_performance_model.h"
@@ -1211,6 +1220,9 @@ GenerateCompileOptions(int64_t device_id, const int64_t *mesh_ids,
   debug_options->set_xla_gpu_cuda_data_dir(xla_gpu_cuda_data_dir);
   debug_options->set_xla_enable_enzyme_comms_opt(true);
 
+  // TODO: make this an option
+  debug_options->set_xla_embed_ir_in_executable(true);
+
   if (kernel_cache_enabled) {
     debug_options->set_xla_gpu_kernel_cache_file(kernel_cache_path);
     debug_options->set_xla_gpu_enable_llvm_module_compilation_parallelism(true);
@@ -3505,3 +3517,100 @@ REACTANT_ABI void EstimateRunTimeForInstruction(void *gpu_performance_model,
 }
 
 #endif
+
+// Lowers an MHLO/StableHLO MLIR module through the XLA CPU compiler and
+// returns the LLVM IR embedded in the resulting executable as a C string.
+REACTANT_ABI const char *ConvertMLIRModuleToLLVMIR(MlirModule mod) {
+  auto cmod_op = cast<mlir::ModuleOp>(*unwrap(mod));
+
+  xla::HloProto hlo_proto;
+  mlir::MlirToHloConversionOptions options;
+  options.use_tuple_args = false;
+  options.return_tuple = false;
+  auto status = mlir::ConvertMlirHloToHlo(cmod_op, &hlo_proto, options);
+  if (!status.ok()) {
+    ReactantThrowError(status.ToString().c_str());
+  }
+
+  // Strip the backend config from the entry computation's root so the CPU
+  // backend recomputes it during compilation.
+  for (auto &computation :
+       *hlo_proto.mutable_hlo_module()->mutable_computations()) {
+    if (computation.id() != hlo_proto.hlo_module().entry_computation_id())
+      continue;
+    // Assume root is the last instruction.
+    xla::HloInstructionProto &instruction =
+        *computation.mutable_instructions()->rbegin();
+    xla::cpu::BackendConfig backend_config;
+    backend_config.ParseFromString(instruction.backend_config());
+    backend_config.Clear();
+    instruction.set_backend_config(backend_config.SerializeAsString());
+    break;
+  }
+
+  xla::XlaComputation xla_computation(hlo_proto.hlo_module());
+
+  // Extract and convert the shapes from MHLO.
+  std::vector<xla::Shape> shapes;
+  mlir::SymbolTable symbol_table(cmod_op);
+  auto entry_point = symbol_table.lookup<mlir::func::FuncOp>("main");
+  shapes.reserve(entry_point.getNumArguments());
+  for (mlir::Type type : entry_point.getArgumentTypes()) {
+    shapes.push_back(xla::TypeToShape(type));
+  }
+  std::vector<const xla::Shape *> shape_pointers;
+  shape_pointers.reserve(shapes.size());
+  for (xla::Shape &shape : shapes) {
+    shape_pointers.push_back(&shape);
+  }
+
+  absl::StatusOr<xla::LocalClient *> local_client_or_error =
+      xla::ClientLibrary::GetOrCreateLocalClient();
+  if (!local_client_or_error.ok()) {
+    llvm::errs() << "failed to get local client\n";
+    ReactantThrowError(local_client_or_error.status().ToString().c_str());
+  }
+  xla::LocalClient *local_client = local_client_or_error.value();
+
+  xla::ExecutableBuildOptions build_options;
+  build_options.mutable_debug_options()->set_xla_embed_ir_in_executable(true);
+
+  if (build_options.device_ordinal() == -1) {
+    build_options.set_device_ordinal(local_client->default_device_ordinal());
+  }
+
+  absl::StatusOr<std::unique_ptr<xla::HloModuleConfig>>
+      module_config_or_error = xla::GetHloModuleConfig(
+          xla_computation, shape_pointers, build_options,
+          // TODO: how to get the service options?
+          // /*(service) options=*/&local_client->local_service()->options_,
+          nullptr, local_client->mutable_backend());
+  if (!module_config_or_error.ok()) {
+    llvm::errs() << "failed to get hlo module config\n";
+    ReactantThrowError(module_config_or_error.status().ToString().c_str());
+  }
+
+  auto executor = local_client->mutable_backend()->stream_executor(
+      build_options.device_ordinal());
+  if (!executor.ok()) {
+    llvm::errs() << "failed to get stream executor\n";
+    ReactantThrowError(executor.status().ToString().c_str());
+  }
+
+  xla::Compiler::CompileOptions opts = {
+      build_options.device_allocator(), build_options.compile_thread_pool(),
+      build_options.layout_canonicalization_callback()};
+  auto executable = local_client->local_service()->BuildExecutable(
+      xla_computation.proto(), std::move(module_config_or_error.value()),
+      local_client->mutable_backend(), executor.value(), opts,
+      build_options.run_backend_only());
+  if (!executable.ok()) {
+    llvm::errs() << "failed to build executable\n";
+    ReactantThrowError(executable.status().ToString().c_str());
+  }
+
+  auto local_executable = std::make_unique<xla::LocalExecutable>(
+      std::move(executable.value()),
+      local_client->local_service()->mutable_backend(), build_options);
+
+  auto *cpu_executable =
+      static_cast<xla::cpu::CpuExecutable *>(local_executable->executable());
+
+  return cstr_from_string(cpu_executable->ir_module_string());
+}
diff --git a/deps/ReactantExtra/BUILD b/deps/ReactantExtra/BUILD
index 7be2938210..ca365c493a 100644
--- a/deps/ReactantExtra/BUILD
+++ b/deps/ReactantExtra/BUILD
@@ -1059,6 +1059,7 @@ cc_library(
             "-Wl,-exported_symbol,_CreateGPUPerformanceModel",
             "-Wl,-exported_symbol,_RunAnalysisOnHloModule",
             "-Wl,-exported_symbol,_EstimateRunTimeForInstruction",
+            "-Wl,-exported_symbol,_ConvertMLIRModuleToLLVMIR",
         ],
     }),
     linkstatic = True,
@@ -1194,6 +1195,12 @@ cc_library(
         "@xla//xla/tsl/platform:errors",
         "@xla//xla/service:hlo_proto_cc_impl",
         "@com_google_absl//absl/status:statusor",
+        "@xla//xla/client:client_library",
+        "@xla//xla/hlo/translate/mhlo_to_hlo:type_to_shape",
+        "@xla//xla/hlo/translate/mhlo_to_hlo:mlir_hlo_to_hlo",
+        "@xla//xla/service:local_service_utils",
+        "@xla//xla/service",
+        "@xla//xla/service/cpu:backend_config_proto_cc",
     ] + if_cuda([
         "@xla//xla/stream_executor/cuda:cuda_compute_capability_proto_cc_impl",
         "@xla//xla/service:gpu_plugin",