diff --git a/.github/workflows/update-enzyme-jax.yml b/.github/workflows/update-enzyme-jax.yml new file mode 100644 index 0000000000..af20112d24 --- /dev/null +++ b/.github/workflows/update-enzyme-jax.yml @@ -0,0 +1,28 @@ +name: "Open PR to update Enzyme-JAX commit" + +on: + schedule: + - cron: '19 16 * * *' + workflow_dispatch: + inputs: + enzyme_jax_commit: + description: 'The Enzyme-JAX commit to update to (optional)' + default: '' + type: 'string' + +concurrency: + # Skip intermediate builds: always. + # Cancel intermediate builds: only if it is a pull request build. + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }} + +jobs: + pr-latest-enzyme-jax: + name: 'Update Enzyme-JAX' + uses: EnzymeAD/Enzyme-JAX/.github/workflows/update-dependency.yml@main + with: + upstream_repo: 'EnzymeAD/Enzyme-JAX' + upstream_commit: ${{ inputs.enzyme_jax_commit }} + variable_name: 'ENZYMEXLA_COMMIT' + workspace_path: 'deps/ReactantExtra/WORKSPACE' + secrets: inherit diff --git a/Project.toml b/Project.toml index fbde73622a..ce091e34c7 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Reactant" uuid = "3c362404-f566-11ee-1572-e11a4b42c853" authors = ["William Moses ", "Valentin Churavy ", "Sergio Sánchez Ramírez ", "Paul Berg ", "Avik Pal ", "Mosè Giordano "] -version = "0.2.171" +version = "0.2.172" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" @@ -105,7 +105,7 @@ PythonCall = "0.9.25" Random = "1.10" Random123 = "1.7" ReactantCore = "0.1.16" -Reactant_jll = "0.0.251" +Reactant_jll = "0.0.254" ScopedValues = "1.3.0" Scratch = "1.2" Sockets = "1.10" diff --git a/benchmark/aggregate.jl b/benchmark/aggregate.jl index aaf4e25fa5..7d71352255 100644 --- a/benchmark/aggregate.jl +++ b/benchmark/aggregate.jl @@ -15,5 +15,5 @@ for backend in BACKENDS end open(joinpath(dirname(@__FILE__), "results", "combinedbenchmarks.json"), "w") do io - JSON3.pretty(io, JSON3.write(all_results)) + return JSON3.pretty(io, JSON3.write(all_results)) end diff --git a/benchmark/runbenchmarks.jl b/benchmark/runbenchmarks.jl index 1a80c3a0d7..3c389e0ad2 100644 --- a/benchmark/runbenchmarks.jl +++ b/benchmark/runbenchmarks.jl @@ -44,7 +44,7 @@ for (i, (k, v)) in enumerate(results) end open(joinpath(filepath, filename), "w") do io - JSON3.pretty(io, JSON3.write(standardized_results)) + return JSON3.pretty(io, JSON3.write(standardized_results)) end @info "Saved results to $(joinpath(filepath, filename))" diff --git a/deps/ReactantExtra/.bazelrc b/deps/ReactantExtra/.bazelrc index 371c1cb70b..eef15c815a 100644 --- a/deps/ReactantExtra/.bazelrc +++ b/deps/ReactantExtra/.bazelrc @@ -29,9 +29,13 @@ build --repo_env=RULES_PYTHON_ENABLE_PYSTAR=0 build -c opt +common:macos --define ynn_enable_arm64_sme=false + common:cuda --repo_env TF_NEED_CUDA=1 common:cuda --repo_env TF_NVCC_CLANG=1 common:cuda --repo_env TF_NCCL_USE_STUB=1 +common:cuda_static --@rules_ml_toolchain//common:link_cuda_static_libs=true +common:cuda_static --@rules_ml_toolchain//common:link_nvrtc_static_libs=true common:cuda --@local_config_cuda//:enable_cuda common:cuda --crosstool_top="@local_config_cuda//crosstool:toolchain" # Default hermetic CUDA and CUDNN versions. @@ -40,19 +44,19 @@ common:cuda --@local_config_cuda//:cuda_compiler=nvcc # common:cuda --@local_config_nvshmem//:override_include_nvshmem_libs=true # common:cuda --@local_config_nvshmem//cuda:include_nvshmem_libs=true - common:cuda12 --config=cuda -common:cuda12 --repo_env=HERMETIC_CUDA_VERSION="12.8.1" -common:cuda12 --repo_env=HERMETIC_CUDNN_VERSION="9.8.0" -common:cuda12 --repo_env=HERMETIC_NVSHMEM_VERSION="3.2.5" +common:cuda12 --config=cuda_static +common:cuda12 --repo_env=HERMETIC_CUDA_VERSION="12.9.1" +common:cuda12 --repo_env=HERMETIC_CUDNN_VERSION="9.14.0" +common:cuda12 --repo_env=HERMETIC_NVSHMEM_VERSION="3.3.9" # "sm" means we emit only cubin, which is forward compatible within a GPU generation. # "compute" means we emit both cubin and PTX, which is larger but also forward compatible to future GPU generations. -common:cuda12 --repo_env HERMETIC_CUDA_COMPUTE_CAPABILITIES="sm_60,sm_70,sm_80,compute_90" +common:cuda12 --repo_env HERMETIC_CUDA_COMPUTE_CAPABILITIES="sm_50,sm_60,sm_70,sm_80,sm_90,sm_100,compute_120" common:cuda13 --config=cuda -common:cuda13 --repo_env=HERMETIC_CUDA_VERSION="13.0.0" -common:cuda13 --repo_env=HERMETIC_CUDNN_VERSION="9.12.0" +common:cuda13 --repo_env=HERMETIC_CUDA_VERSION="13.0.2" +common:cuda13 --repo_env=HERMETIC_CUDNN_VERSION="9.14.0" common:cuda13 --repo_env=HERMETIC_NVSHMEM_VERSION="3.3.20" # "sm" means we emit only cubin, which is forward compatible within a GPU generation. # "compute" means we emit both cubin and PTX, which is larger but also forward compatible to future GPU generations. diff --git a/deps/ReactantExtra/BUILD b/deps/ReactantExtra/BUILD index a9a7666e87..e014f71e80 100644 --- a/deps/ReactantExtra/BUILD +++ b/deps/ReactantExtra/BUILD @@ -928,6 +928,7 @@ cc_library( "//conditions:default": [], "@bazel_tools//src/conditions:darwin": [ "-Wl,-exported_symbol,_stablehlo*", + "-Wl,-exported_symbol,_enzymexla*", "-Wl,-exported_symbol,_mlir*", "-Wl,-exported_symbol,_sdy*", "-Wl,-exported_symbol,_EnzymeJaXMapSymbol", @@ -1078,6 +1079,7 @@ cc_library( "@llvm-project//llvm:X86CodeGen", "@enzyme_ad//src/enzyme_ad/jax:TransformOps", "@enzyme_ad//src/enzyme_ad/jax:XLADerivatives", + "@enzyme_ad//src/enzyme_ad/jax:CInterface", # "@enzyme_ad//src/enzyme_ad/jax:gpu", "@xla//xla/ffi/api:ffi", "@xla//xla/ffi:ffi_api", diff --git a/deps/ReactantExtra/WORKSPACE b/deps/ReactantExtra/WORKSPACE index a46b45c33c..d15a5f0726 100644 --- a/deps/ReactantExtra/WORKSPACE +++ b/deps/ReactantExtra/WORKSPACE @@ -4,7 +4,7 @@ NSYNC_COMMIT = "82b118aa7ace3132e517e2c467f8732978cf4023" NSYNC_SHA256 = "" -ENZYMEXLA_COMMIT = "a7c38bce984c3adedafb8e03282c0e39640ab6f9" +ENZYMEXLA_COMMIT = "c5b0090d53998673b2f728b7590b97d7bc548d2b" ENZYMEXLA_SHA256 = "" diff --git a/deps/ReactantExtra/make-bindings.jl b/deps/ReactantExtra/make-bindings.jl index f84309fef1..ef14ab82f1 100644 --- a/deps/ReactantExtra/make-bindings.jl +++ b/deps/ReactantExtra/make-bindings.jl @@ -22,26 +22,26 @@ end src_dir = joinpath(dirname(dirname(@__DIR__)), "src") for file in [ - "Builtin.jl", - "Arith.jl", - "Affine.jl", - "Func.jl", - "Enzyme.jl", + # "Builtin.jl", + # "Arith.jl", + # "Affine.jl", + # "Func.jl", + # "Enzyme.jl", "EnzymeXLA.jl", - "StableHLO.jl", - "CHLO.jl", - "VHLO.jl", - "Llvm.jl", - "Nvvm.jl", - "Gpu.jl", - "Affine.jl", - "TPU.jl", - "MosaicGPU.jl", - "Triton.jl", - "Shardy.jl", - "MPI.jl", - "MemRef.jl", - "SparseTensor.jl", + # "StableHLO.jl", + # "CHLO.jl", + # "VHLO.jl", + # "Llvm.jl", + # "Nvvm.jl", + # "Gpu.jl", + # "Affine.jl", + # "TPU.jl", + # "MosaicGPU.jl", + # "Triton.jl", + # "Shardy.jl", + # "MPI.jl", + # "MemRef.jl", + # "SparseTensor.jl", ] build_file(joinpath(src_dir, "mlir", "Dialects", file)) end diff --git a/deps/build_local.jl b/deps/build_local.jl index b135800600..8d1a43821e 100644 --- a/deps/build_local.jl +++ b/deps/build_local.jl @@ -115,36 +115,50 @@ source_dir = joinpath(@__DIR__, "ReactantExtra") # --@local_config_cuda//:cuda_compiler=nvcc # --crosstool_top="@local_config_cuda//crosstool:toolchain" -build_kind = parsed_args["debug"] ? "dbg" : "opt" +abstract type AbstractBackend end +struct CPUBackend <: AbstractBackend end +struct CUDABackend <: AbstractBackend + version::VersionNumber + CUDABackend(ver::VersionNumber) = new(VersionNumber(ver.major)) +end -build_backend = parsed_args["backend"] +function parse_build_backend(str::String)::AbstractBackend + if str == "cpu" + return CPUBackend() + elseif str == "cuda12" + return CUDABackend(v"12") + elseif str == "cuda13" + return CUDABackend(v"13") + end -if build_backend == "auto" || build_backend == "cuda" - cuda_ver = get_cuda_version() - @show cuda_ver - if cuda_ver === nothing - if build_backend == "cuda" - throw( - AssertionError( - "Could not detect cuda version, but requested cuda with auto version build", - ), - ) - end - build_backend = "cpu" - else - if Int(get_cuda_version().major) == 13 - build_backend = "cuda13" + if str in ("auto", "cuda") + cuda_ver = get_cuda_version() + if isnothing(cuda_ver) + if str == "cuda" + throw( + AssertionError( + "Could not detect cuda version, but requested cuda with auto version build", + ), + ) + end + return CPUBackend() else - build_backend = "cuda12" + return CUDABackend(get_cuda_version()) end + else + error("Unknown backend '$(str)'") end end -arg = if build_backend == "cuda12" +build_kind = parsed_args["debug"] ? "dbg" : "opt" + +build_backend = parse_build_backend(parsed_args["backend"]) + +arg = if build_backend == CUDABackend(v"12") "--config=cuda12" -elseif build_backend == "cuda13" +elseif build_backend == CUDABackend(v"13") "--config=cuda13" -elseif build_backend == "cpu" +elseif build_backend == CPUBackend() "" else throw(AssertionError("Unknown backend `$build_backend`")) @@ -197,6 +211,8 @@ push!(build_cmd_list, "--jobs=$(parsed_args["jobs"])") push!(build_cmd_list, "--experimental_ui_max_stdouterr_bytes=-1") push!(build_cmd_list, "--sandbox_debug") +push!(build_cmd_list, "--linkopt=-fuse-ld=lld") + for opt in parsed_args["copt"] push!(build_cmd_list, "--copt=$(opt)") end @@ -231,17 +247,18 @@ push!(build_cmd_list, "--copt=-Wno-private-header") push!(build_cmd_list, "--color=$(parsed_args["color"])") push!(build_cmd_list, ":libReactantExtra.so") +@info "About to run Bazel" build_cmd_list run(Cmd(Cmd(build_cmd_list); dir=source_dir)) # Discover built libraries built_libs = filter(readdir(joinpath(source_dir, "bazel-bin"))) do file - endswith(file, "Extra.so") && startswith(file, "lib") + return endswith(file, "Extra.so") && startswith(file, "lib") end lib_path = joinpath(source_dir, "bazel-bin", only(built_libs)) isfile(lib_path) || error("Could not find library $lib_path in build directory") -if build_backend == "cuda" +if build_backend isa CUDABackend for path in ( joinpath("bin", "ptxas"), joinpath("bin", "fatbinary"), @@ -249,17 +266,27 @@ if build_backend == "cuda" ) full_path = joinpath(source_dir, "bazel-bin", "cuda", path) if !Base.Filesystem.ispath(full_path) - Base.Filesystem.mkpath(dirname(full_path)) - Base.Filesystem.symlink( - joinpath( - source_dir, - "bazel-bin", - "libReactantExtra.so.runfiles", - "cuda_nvcc", - path, - ), - full_path, + source = joinpath( + source_dir, + "bazel-bin", + "libReactantExtra.so.runfiles", + # libdevice's directory was moved in CUDA 13, before was in same + # dir as ptxas and fatbinary + if contains(basename(path), "libdevice") && build_backend.version >= v"13" + "cuda_nvvm" + else + "cuda_nvcc" + end, + path, ) + if !Base.Filesystem.ispath(source) + error( + "File $(source) does not exist, are you sure it is where you expect it to be?", + ) + end + Base.Filesystem.mkpath(dirname(full_path)) + @info "Symlinking $(full_path) -> $(source)" + Base.Filesystem.symlink(source, full_path) end end end diff --git a/ext/ReactantCUDAExt.jl b/ext/ReactantCUDAExt.jl index bd9b55edde..68117d5050 100644 --- a/ext/ReactantCUDAExt.jl +++ b/ext/ReactantCUDAExt.jl @@ -1282,8 +1282,8 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})( @assert length(restys) == length(aliases) call = MLIR.Dialects.enzymexla.kernel_call( - blk_operands..., - mlir_args; + blk_operands...; + inputs=mlir_args, result_0=restys, fn=MLIR.IR.FlatSymbolRefAttribute(sym_name), output_operand_aliases=MLIR.IR.Attribute(output_operand_aliases), diff --git a/ext/ReactantKernelAbstractionsExt.jl b/ext/ReactantKernelAbstractionsExt.jl index ee13c3bb3f..617075b086 100644 --- a/ext/ReactantKernelAbstractionsExt.jl +++ b/ext/ReactantKernelAbstractionsExt.jl @@ -109,15 +109,26 @@ function (obj::KA.Kernel{ReactantBackend})(args...; ndrange=nothing, workgroupsi return nothing end -Reactant.@reactant_overlay @noinline Base.@nospecializeinfer function ( - obj::KA.Kernel{ReactantBackend} -)( - args...; ndrange=nothing, workgroupsize=nothing -) - @nospecialize - return Reactant.call_with_reactant( - Reactant.ka_with_reactant, ndrange, workgroupsize, obj, args... +@static if VERSION < v"1.12-" + Reactant.@reactant_overlay Base.@nospecializeinfer @noinline function ( + obj::KA.Kernel{ReactantBackend} + )( + @nospecialize args...; ndrange=nothing, workgroupsize=nothing ) + return Reactant.call_with_reactant( + Reactant.ka_with_reactant, ndrange, workgroupsize, obj, args... + ) + end +else + Reactant.@reactant_overlay function (obj::KA.Kernel{ReactantBackend})( + args...; ndrange=nothing, workgroupsize=nothing + ) + Base.@_noinline_meta + Base.@_nospecializeinfer_meta + return Reactant.call_with_reactant( + Reactant.ka_with_reactant, ndrange, workgroupsize, obj, args... + ) + end end end diff --git a/src/Compiler.jl b/src/Compiler.jl index a6be948b1c..4ce10365a4 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -910,6 +910,7 @@ function optimization_passes( "remove_no_ops_from_while_loop", "while_is_copy_simplify", "split_variadic_scatter_op", + "dynamic_slice_simplify", ] if !compile_options.disable_auto_batching_passes @@ -1124,9 +1125,11 @@ function optimization_passes( if AGGRESSIVE_PROPAGATION[] push!(transform_passes_list, "reshape_slice(0)") push!(transform_passes_list, "reshape_elementwise(0)") + push!(transform_passes_list, "reshape_dynamic_slice(0)") else push!(transform_passes_list, "reshape_slice(1)") push!(transform_passes_list, "reshape_elementwise(1)") + push!(transform_passes_list, "reshape_dynamic_slice(1)") end elseif compile_options.reshape_propagate === :down append!( @@ -1395,7 +1398,8 @@ function __get_compile_options_and_kwargs(; end function compile_mlir(f, args; client=nothing, kwargs...) - backend = XLA.platform_name(client !== nothing ? client : XLA.default_backend()) + client = client !== nothing ? client : XLA.default_backend() + backend = XLA.platform_name(client) if backend == "CUDA" backend = "GPU" @@ -1414,6 +1418,7 @@ function compile_mlir(f, args; client=nothing, kwargs...) compile_options; backend, runtime=XLA.runtime(client), + client, kwargs..., ) @@ -1430,11 +1435,9 @@ end const PartitionKA = Ref{Bool}(true) -const cubinChip = Ref{String}("sm_60") -const cubinFormat = Ref{String}("bin") const cuindexBitWidth = Ref{Int}(32) +const cubinFormat = Ref{String}("bin") const cuOptLevel = Ref{Int}(2) -const cuWarpSize = Ref{Int}(32) # Wgatever the relevant highest version from our LLVM is within NVPTX.td # Or more specifically looking at clang/lib/Driver/ToolChains/Cuda.cpp:684 @@ -1580,8 +1583,11 @@ function compile_mlir!( backend="gpu", runtime::Union{Val{:PJRT},Val{:IFRT}}, legalize_stablehlo_to_mhlo::Bool=false, + client=nothing, kwargs..., ) + client = client !== nothing ? client : XLA.default_backend() + # Explicitly don't use block! to avoid creating a closure, which creates # both compile-time and relocatability issues @@ -1655,25 +1661,27 @@ function compile_mlir!( else jit = "lower-jit{openmp=$(OpenMP[]) backend=cpu},symbol-dce" end - elseif DEBUG_KERNEL[] - curesulthandler = dlsym( - Reactant_jll.libReactantExtra_handle, "ReactantHandleCuResult" - ) - @assert curesulthandler !== nothing - curesulthandler = Base.reinterpret(UInt, curesulthandler) + else kern = if is_raising "lower-kernel{backend=cpu},symbol-dce,canonicalize" else "lower-kernel,canonicalize" end - jit = "lower-jit{debug=true cuResultHandlerPtr=$curesulthandler cuOptLevel=$(cuOptLevel[]) cubinFormat=$(cubinFormat[]) indexBitWidth=$(cuindexBitWidth[]) cubinChip=$(cubinChip[]) cubinFeatures=$(cubinFeatures()) run_init=true toolkitPath=$toolkit},symbol-dce" - else - kern = if is_raising - "lower-kernel{backend=cpu},symbol-dce,canonicalize" + + device_properties = XLA.device_properties(XLA.default_device(client)) + cubinChip = "sm_$(device_properties.major)$(device_properties.minor)" + + if DEBUG_KERNEL[] + curesulthandler = dlsym( + Reactant_jll.libReactantExtra_handle, "ReactantHandleCuResult" + ) + @assert curesulthandler !== nothing + curesulthandler = Base.reinterpret(UInt, curesulthandler) + extra_lowerjit_options = "debug=true cuResultHandlerPtr=$curesulthandler " else - "lower-kernel,canonicalize" + extra_lowerjit_options = "" end - jit = "lower-jit{cuOptLevel=$(cuOptLevel[]) indexBitWidth=$(cuindexBitWidth[]) cubinFormat=$(cubinFormat[]) cubinChip=$(cubinChip[]) cubinFeatures=$(cubinFeatures()) run_init=true toolkitPath=$toolkit},symbol-dce" + jit = "lower-jit{$(extra_lowerjit_options)cuOptLevel=$(cuOptLevel[]) cubinFormat=$(cubinFormat[]) indexBitWidth=$(cuindexBitWidth[]) cubinChip=$(cubinChip) cubinFeatures=$(cubinFeatures()) run_init=true toolkitPath=$toolkit},symbol-dce" end recognize_comms = true @@ -3477,7 +3485,8 @@ function compile_xla( context_gc_vector[ctx] = Vector{Union{TracedRArray,TracedRNumber}}(undef, 0) @ccall MLIR.API.mlir_c.RegisterDialects(ctx::MLIR.API.MlirContext)::Cvoid - backend = XLA.platform_name(client !== nothing ? client : XLA.default_backend()) + client = client !== nothing ? client : XLA.default_backend() + backend = XLA.platform_name(client) if backend == "CUDA" backend = "GPU" @@ -3498,6 +3507,7 @@ function compile_xla( compile_options; backend, runtime=XLA.runtime(client), + client, kwargs..., ) diff --git a/src/Interpreter.jl b/src/Interpreter.jl index 6ffcec0ab3..0e73cff527 100644 --- a/src/Interpreter.jl +++ b/src/Interpreter.jl @@ -41,18 +41,31 @@ function set_reactant_abi( if length(argtypes) != 1 @static if VERSION < v"1.11.0-" return CallMeta(Union{}, Effects(), NoCallInfo()) - else + elseif VERSION < v"1.12.0-" return CallMeta(Union{}, Union{}, Effects(), NoCallInfo()) + else + return Core.Compiler.Future{Core.Compiler.CallMeta}( + CallMeta(Union{}, Union{}, Effects(), NoCallInfo()) + ) end end @static if VERSION < v"1.11.0-" return CallMeta( Core.Const(true), Core.Compiler.EFFECTS_TOTAL, MethodResultPure() ) - else + elseif VERSION < v"1.12.0-" return CallMeta( Core.Const(true), Union{}, Core.Compiler.EFFECTS_TOTAL, MethodResultPure() ) + else + return Core.Compiler.Future{Core.Compiler.CallMeta}( + CallMeta( + Core.Const(true), + Union{}, + Core.Compiler.EFFECTS_TOTAL, + MethodResultPure(), + ), + ) end end diff --git a/src/Ops.jl b/src/Ops.jl index 252d87dc84..cbfbd9decd 100644 --- a/src/Ops.jl +++ b/src/Ops.jl @@ -934,6 +934,36 @@ end return TracedRArray{T,N}((), MLIR.IR.result(conv), result_size) end +@noinline function lapack_symm( + A::TracedRArray{T}, + B::TracedRArray{T}, + C::TracedRArray{T}, + alpha::TracedRNumber{T}, + beta::TracedRNumber{T}; + side::Symbol, + uplo::Symbol, + location=mlir_stacktrace("lapack_symm", @__FILE__, @__LINE__), +) where {T} + ctx = MLIR.IR.context() + ressize = size(C) + resT = mlir_type(TracedRArray{unwrapped_eltype(C),length(ressize)}, ressize) + + res = MLIR.IR.result( + enzymexla.lapack_symm( + A.mlir_data, + B.mlir_data, + C.mlir_data, + alpha.mlir_data, + beta.mlir_data; + output=resT, + side=MLIR.API.enzymexlaLapackSideAttrGet(ctx, side == :L ? 1 : 0), + uplo=MLIR.API.enzymexlaLapackUploAttrGet(ctx, uplo == :U ? 1 : 0), + location, + ), + ) + return TracedRArray{resT,length(ressize)}((), res, ressize) +end + Base.@nospecializeinfer @noinline function dot_general( @nospecialize(lhs::TracedRArray{T1}), @nospecialize(rhs::TracedRArray{T2}); diff --git a/src/accelerators/TPU.jl b/src/accelerators/TPU.jl index cf79fa4d0c..a0b5564914 100644 --- a/src/accelerators/TPU.jl +++ b/src/accelerators/TPU.jl @@ -10,6 +10,9 @@ using unzip_jll: unzip const libtpu_dir = Ref{Union{Nothing,String}}(nothing) const RUNNING_IN_CLOUD_TPU_VM = Ref(false) +const LIBTPU_VERSION = "0.0.28.dev20251027" +const LIBTPU_SO = "libtpu-$(replace(string(LIBTPU_VERSION), '.' => '_')).so" + function __init__() @static if !Sys.isapple() if !Reactant.precompiling() && has_tpu() @@ -32,18 +35,18 @@ end get_libtpu_dir() = libtpu_dir[] -get_libtpu_path() = joinpath(get_libtpu_dir(), "libtpu.so") +get_libtpu_path() = joinpath(get_libtpu_dir(), LIBTPU_SO) function download_libtpu_if_needed(path=nothing) path === nothing && (path = get_libtpu_dir()) @assert path !== nothing "libtpu_dir is not set!" - libtpu_path = joinpath(path, "libtpu.so") + libtpu_path = joinpath(path, LIBTPU_SO) if !isfile(libtpu_path) zip_file_path = joinpath(path, "tpu.zip") tmp_dir = joinpath(path, "tmp") Downloads.download( - "https://storage.googleapis.com/libtpu-nightly-releases/wheels/libtpu-nightly/libtpu_nightly-0.1.dev20250811+nightly-py3-none-manylinux_2_31_x86_64.whl", + "https://storage.googleapis.com/libtpu-nightly-releases/wheels/libtpu/libtpu-0.0.28.dev20251027+nightly-cp314-cp314t-manylinux_2_31_x86_64.whl", zip_file_path, ) run(`$(unzip()) -qq $(zip_file_path) -d $(tmp_dir)`) diff --git a/src/mlir/Dialects/Enzyme.jl b/src/mlir/Dialects/Enzyme.jl index 42569f0784..3f0e47267e 100755 --- a/src/mlir/Dialects/Enzyme.jl +++ b/src/mlir/Dialects/Enzyme.jl @@ -259,6 +259,127 @@ function broadcast(input::Value; output::IR.Type, shape, location=Location()) ) end +""" +`cholesky_solve` + +Solves the linear system Ax = b for x using Cholesky decomposition. +Assuming A is symmetric positive definite! +""" +function cholesky_solve(lhs::Value, rhs::Value; result::IR.Type, location=Location()) + op_ty_results = IR.Type[result,] + operands = Value[lhs, rhs] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "enzyme.cholesky_solve", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`concat` + +Concat list of input arguments into a generic value +""" +function concat(inputs::Vector{Value}; output::IR.Type, location=Location()) + op_ty_results = IR.Type[output,] + operands = Value[inputs...,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "enzyme.concat", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`dot` + +Computes the dot product of two 1D tensors (vectors). +""" +function dot(lhs::Value, rhs::Value; result::IR.Type, location=Location()) + op_ty_results = IR.Type[result,] + operands = Value[lhs, rhs] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "enzyme.dot", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`dump` + +Debug operation that dumps a tensor value with a label. +""" +function dump(value::Value; output::IR.Type, label, location=Location()) + op_ty_results = IR.Type[output,] + operands = Value[value,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("label", label),] + + return create_operation( + "enzyme.dump", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`extract` + +Extract value from batched operand at index +""" +function extract(input::Value, index::Value; output::IR.Type, location=Location()) + op_ty_results = IR.Type[output,] + operands = Value[input, index] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "enzyme.extract", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + function fwddiff( inputs::Vector{Value}; outputs::Vector{IR.Type}, @@ -372,6 +493,33 @@ function genericAdjoint( ) end +""" +`getFlattenedSamplesFromTrace` + +Get sampled values for multiple addresses from an execution trace and +flatten them into a single position vector for HMC. +""" +function getFlattenedSamplesFromTrace( + trace::Value; position::IR.Type, selection, location=Location() +) + op_ty_results = IR.Type[position,] + operands = Value[trace,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("selection", selection),] + + return create_operation( + "enzyme.getFlattenedSamplesFromTrace", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + function get(gradient::Value; result_0::IR.Type, location=Location()) op_ty_results = IR.Type[result_0,] operands = Value[gradient,] @@ -417,6 +565,32 @@ function getSampleFromConstraint( ) end +""" +`getSampleFromTrace` + +Get the sampled value for a given symbol from an execution trace. +""" +function getSampleFromTrace( + trace::Value; sample::Vector{IR.Type}, symbol, location=Location() +) + op_ty_results = IR.Type[sample...,] + operands = Value[trace,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("symbol", symbol),] + + return create_operation( + "enzyme.getSampleFromTrace", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + """ `getSubconstraint` @@ -443,6 +617,54 @@ function getSubconstraint( ) end +""" +`getSubtrace` + +Get a subtrace from a trace for a given symbol. +""" +function getSubtrace(trace::Value; subtrace::IR.Type, symbol, location=Location()) + op_ty_results = IR.Type[subtrace,] + operands = Value[trace,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("symbol", symbol),] + + return create_operation( + "enzyme.getSubtrace", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`getWeightFromTrace` + +Get the accumulated log-probability weight from an execution trace. +""" +function getWeightFromTrace(trace::Value; weight::IR.Type, location=Location()) + op_ty_results = IR.Type[weight,] + operands = Value[trace,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "enzyme.getWeightFromTrace", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + function ignore_derivatives(input::Value; output::IR.Type, location=Location()) op_ty_results = IR.Type[output,] operands = Value[input,] @@ -524,6 +746,150 @@ function load(cache::Value, indices::Vector{Value}; result::IR.Type, location=Lo ) end +""" +`loop` + +A counted loop operation that iterates from `lowerBound` to `upperBound` +by `step`, carrying `iter_args` through each iteration. The iteration +variable and iter_args are passed to the body region. +""" +function loop( + lowerBound::Value, + upperBound::Value, + step::Value, + initArgs::Vector{Value}; + results::Vector{IR.Type}, + region::Region, + location=Location(), +) + op_ty_results = IR.Type[results...,] + operands = Value[lowerBound, upperBound, step, initArgs...] + owned_regions = Region[region,] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "enzyme.loop", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`mcmc` + +Perform an MCMC inference step (HMC, NUTS, etc.) on a probabilistic function. +This operation proposes a new trace using the specified algorithm, +computes the acceptance probability, and returns the updated trace. +By convention, the 0th operand in inputs is the initial RNG state +and the 0th operand in results is the updated RNG state. + +Optional HMC-specific parameters: +- mass: Mass matrix (identity assumed if not provided) +- step_size: Leapfrong integration step size +- num_steps: Number of leapfrog steps +- initial_momentum: deterministic initial momentum (debug) +""" +function mcmc( + inputs::Vector{Value}, + original_trace::Value, + mass=nothing::Union{Nothing,Value}; + step_size=nothing::Union{Nothing,Value}, + num_steps=nothing::Union{Nothing,Value}, + initial_momentum=nothing::Union{Nothing,Value}, + new_trace::IR.Type, + accepted::IR.Type, + output_rng_state::IR.Type, + alg, + fn, + selection, + name=nothing, + location=Location(), +) + op_ty_results = IR.Type[new_trace, accepted, output_rng_state] + operands = Value[inputs..., original_trace] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[ + namedattribute("alg", alg), + namedattribute("fn", fn), + namedattribute("selection", selection), + ] + !isnothing(mass) && push!(operands, mass) + !isnothing(step_size) && push!(operands, step_size) + !isnothing(num_steps) && push!(operands, num_steps) + !isnothing(initial_momentum) && push!(operands, initial_momentum) + push!( + attributes, + operandsegmentsizes([ + length(inputs), + 1, + (mass == nothing) ? 0 : 1, + (step_size == nothing) ? 0 : 1, + (num_steps == nothing) ? 0 : 1, + (initial_momentum == nothing) ? 0 : 1, + ]), + ) + !isnothing(name) && push!(attributes, namedattribute("name", name)) + + return create_operation( + "enzyme.mcmc", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`mh` + +Perform a Metropolis-Hastings step on a probabilistic function. +This operation proposes a new trace by regenerating selected addresses, +computes the acceptance probability, and returns the updated trace. +By convention, the 0th operand in inputs is the initial RNG state +and the 0th operand in results is the updated RNG state. +""" +function mh( + inputs::Vector{Value}, + original_trace::Value; + new_trace::IR.Type, + accepted::IR.Type, + output_rng_state::IR.Type, + fn, + selection, + name=nothing, + location=Location(), +) + op_ty_results = IR.Type[new_trace, accepted, output_rng_state] + operands = Value[inputs..., original_trace] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[ + namedattribute("fn", fn), namedattribute("selection", selection) + ] + !isnothing(name) && push!(attributes, namedattribute("name", name)) + + return create_operation( + "enzyme.mh", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + function placeholder(; output::IR.Type, location=Location()) op_ty_results = IR.Type[output,] operands = Value[] @@ -581,6 +947,94 @@ function push(cache::Value, value::Value; location=Location()) ) end +""" +`random` + +Generates random numbers using the rng_distribution algorithm and produces +a result tensor. + +If rng_distribution = UNIFORM, then the random numbers are generated following +the uniform distribution over the interval [a, b). If a >= b, the behavior is +undefined. + +If rng_distribution = NORMAL, then the random numbers are generated following +the normal distribution with mean = a and standard deviation = b. If b < 0, +the behavior is undefined. + +If rng_distribution = MULTINORMAL, then the random numbers are generated +following the multivariate normal distribution with mean = a (scalar or vector) +and covariance matrix = b. The parameter b should be a positive definite matrix. + +By convention, the 0th operand in inputs is the initial RNG state and the +0th operand in results is the updated RNG state. +""" +function random( + rng_state::Value, + a::Value, + b::Value; + output_rng_state::IR.Type, + result::IR.Type, + rng_distribution, + location=Location(), +) + op_ty_results = IR.Type[output_rng_state, result] + operands = Value[rng_state, a, b] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("rng_distribution", rng_distribution),] + + return create_operation( + "enzyme.random", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`regenerate` + +Regenerate selected addresses in a probabilistic function while keeping +other addresses fixed to their values in the given trace. +By convention, the 0th operand in inputs is the initial RNG state +and the 0th operand in results is the updated RNG state. +""" +function regenerate( + inputs::Vector{Value}, + original_trace::Value; + trace::IR.Type, + weight::IR.Type, + output_rng_state::IR.Type, + fn, + selection, + name=nothing, + location=Location(), +) + op_ty_results = IR.Type[trace, weight, output_rng_state] + operands = Value[inputs..., original_trace] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[ + namedattribute("fn", fn), namedattribute("selection", selection) + ] + !isnothing(name) && push!(attributes, namedattribute("name", name)) + + return create_operation( + "enzyme.regenerate", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + """ `sample` @@ -617,6 +1071,36 @@ function sample( ) end +""" +`selectTrace` + +Selects between two !enzyme.Trace values (considered scalars here) based on a tensor condition. +""" +function selectTrace( + condition::Value, + true_value::Value, + false_value::Value; + result::IR.Type, + location=Location(), +) + op_ty_results = IR.Type[result,] + operands = Value[condition, true_value, false_value] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "enzyme.selectTrace", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + function set(gradient::Value, value::Value; location=Location()) op_ty_results = IR.Type[] operands = Value[gradient, value] @@ -692,6 +1176,31 @@ function store(value::Value, cache::Value, indices::Vector{Value}; location=Loca ) end +""" +`unflatten_slice` + +Extract a slice from a 1D position vector starting at the given offset, +and reconstruct the original multi-dimensional tensor shape (implied by the type). +""" +function unflatten_slice(position::Value; result::IR.Type, offset, location=Location()) + op_ty_results = IR.Type[result,] + operands = Value[position,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("offset", offset),] + + return create_operation( + "enzyme.unflatten_slice", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + """ `untracedCall` @@ -720,6 +1229,47 @@ function untracedCall( ) end +""" +`update` + +Update selected addresses in a trace with new values from a position vector, +re-evaluate the probabilistic function, and return the updated trace with +the new weight (log probability) and updated RNG state. +By convention, the 0th operand in inputs is the initial RNG state. +""" +function update( + inputs::Vector{Value}, + original_trace::Value, + position::Value; + updated_trace::IR.Type, + weight::IR.Type, + output_rng_state::IR.Type, + fn, + selection, + name=nothing, + location=Location(), +) + op_ty_results = IR.Type[updated_trace, weight, output_rng_state] + operands = Value[inputs..., original_trace, position] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[ + namedattribute("fn", fn), namedattribute("selection", selection) + ] + !isnothing(name) && push!(attributes, namedattribute("name", name)) + + return create_operation( + "enzyme.update", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + function yield(operands::Vector{Value}; location=Location()) op_ty_results = IR.Type[] operands = Value[operands...,] diff --git a/src/mlir/Dialects/EnzymeXLA.jl b/src/mlir/Dialects/EnzymeXLA.jl index 79a5cd298a..49067fe710 100755 --- a/src/mlir/Dialects/EnzymeXLA.jl +++ b/src/mlir/Dialects/EnzymeXLA.jl @@ -1,238 +1,168 @@ module enzymexla using ...IR -import ...IR: - NamedAttribute, - Value, - Location, - Block, - Region, - Attribute, - create_operation, - context, - IndexType +import ...IR: NamedAttribute, Value, Location, Block, Region, Attribute, create_operation, context, IndexType import ..Dialects: namedattribute, operandsegmentsizes import ...API -function scope( - operands::Vector{Value}; results::Vector{IR.Type}, region::Region, location=Location() -) - op_ty_results = IR.Type[results...,] - operands = Value[operands...,] - owned_regions = Region[region,] + + +function scope(operands::Vector{Value}; results::Vector{IR.Type}, region::Region, location=Location()) + op_ty_results = IR.Type[results..., ] + operands = Value[operands..., ] + owned_regions = Region[region, ] successors = Block[] attributes = NamedAttribute[] - - return create_operation( - "enzymexla.scope", - location; - operands, - owned_regions, - successors, - attributes, + + create_operation( + "enzymexla.scope", location; + operands, owned_regions, successors, attributes, results=op_ty_results, - result_inference=false, + result_inference=false ) end + function alternatives(; regions::Vector{Region}, location=Location()) op_ty_results = IR.Type[] operands = Value[] - owned_regions = Region[regions...,] + owned_regions = Region[regions..., ] successors = Block[] attributes = NamedAttribute[] - - return create_operation( - "enzymexla.alternatives", - location; - operands, - owned_regions, - successors, - attributes, + + create_operation( + "enzymexla.alternatives", location; + operands, owned_regions, successors, attributes, results=op_ty_results, - result_inference=false, + result_inference=false ) end + function barrier(indices::Vector{Value}; location=Location()) op_ty_results = IR.Type[] - operands = Value[indices...,] + operands = Value[indices..., ] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[] - - return create_operation( - "enzymexla.barrier", - location; - operands, - owned_regions, - successors, - attributes, + + create_operation( + "enzymexla.barrier", location; + operands, owned_regions, successors, attributes, results=op_ty_results, - result_inference=false, + result_inference=false ) end + function comm_region(; result_0::Vector{IR.Type}, body::Region, location=Location()) - op_ty_results = IR.Type[result_0...,] + op_ty_results = IR.Type[result_0..., ] operands = Value[] - owned_regions = Region[body,] + owned_regions = Region[body, ] successors = Block[] attributes = NamedAttribute[] - - return create_operation( - "enzymexla.comm_region", - location; - operands, - owned_regions, - successors, - attributes, + + create_operation( + "enzymexla.comm_region", location; + operands, owned_regions, successors, attributes, results=op_ty_results, - result_inference=false, + result_inference=false ) end -function extend( - operand::Value; - result=nothing::Union{Nothing,IR.Type}, - lhs, - rhs, - dimension, - location=Location(), -) + +function extend(operand::Value; result=nothing::Union{Nothing, IR.Type}, lhs, rhs, dimension, location=Location()) op_ty_results = IR.Type[] - operands = Value[operand,] + operands = Value[operand, ] owned_regions = Region[] successors = Block[] - attributes = NamedAttribute[ - namedattribute("lhs", lhs), - namedattribute("rhs", rhs), - namedattribute("dimension", dimension), - ] + attributes = NamedAttribute[namedattribute("lhs", lhs), namedattribute("rhs", rhs), namedattribute("dimension", dimension), ] !isnothing(result) && push!(op_ty_results, result) - - return create_operation( - "enzymexla.extend", - location; - operands, - owned_regions, - successors, - attributes, + + create_operation( + "enzymexla.extend", location; + operands, owned_regions, successors, attributes, results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + result_inference=(length(op_ty_results) == 0 ? true : false) ) end -function gpu_block( - blockIndexX::Value, - blockIndexY::Value, - blockIndexZ::Value; - region::Region, - location=Location(), -) + +function gpu_block(blockIndexX::Value, blockIndexY::Value, blockIndexZ::Value; region::Region, location=Location()) op_ty_results = IR.Type[] - operands = Value[blockIndexX, blockIndexY, blockIndexZ] - owned_regions = Region[region,] + operands = Value[blockIndexX, blockIndexY, blockIndexZ, ] + owned_regions = Region[region, ] successors = Block[] attributes = NamedAttribute[] - - return create_operation( - "enzymexla.gpu_block", - location; - operands, - owned_regions, - successors, - attributes, + + create_operation( + "enzymexla.gpu_block", location; + operands, owned_regions, successors, attributes, results=op_ty_results, - result_inference=false, + result_inference=false ) end + function gpu_error(; result::IR.Type, region::Region, location=Location()) - op_ty_results = IR.Type[result,] + op_ty_results = IR.Type[result, ] operands = Value[] - owned_regions = Region[region,] + owned_regions = Region[region, ] successors = Block[] attributes = NamedAttribute[] - - return create_operation( - "enzymexla.gpu_error", - location; - operands, - owned_regions, - successors, - attributes, + + create_operation( + "enzymexla.gpu_error", location; + operands, owned_regions, successors, attributes, results=op_ty_results, - result_inference=false, + result_inference=false ) end + function gpu_kernel_address(; result::IR.Type, fn, location=Location()) - op_ty_results = IR.Type[result,] + op_ty_results = IR.Type[result, ] operands = Value[] owned_regions = Region[] successors = Block[] - attributes = NamedAttribute[namedattribute("fn", fn),] - - return create_operation( - "enzymexla.gpu_kernel_address", - location; - operands, - owned_regions, - successors, - attributes, + attributes = NamedAttribute[namedattribute("fn", fn), ] + + create_operation( + "enzymexla.gpu_kernel_address", location; + operands, owned_regions, successors, attributes, results=op_ty_results, - result_inference=false, + result_inference=false ) end -function gpu_occupancy( - blockSize::Value, - dynamicSMemSize::Value, - flags::Value; - result::IR.Type, - fn, - location=Location(), -) - op_ty_results = IR.Type[result,] - operands = Value[blockSize, dynamicSMemSize, flags] + +function gpu_occupancy(blockSize::Value, dynamicSMemSize::Value, flags::Value; result::IR.Type, fn, location=Location()) + op_ty_results = IR.Type[result, ] + operands = Value[blockSize, dynamicSMemSize, flags, ] owned_regions = Region[] successors = Block[] - attributes = NamedAttribute[namedattribute("fn", fn),] - - return create_operation( - "enzymexla.gpu_occupancy", - location; - operands, - owned_regions, - successors, - attributes, + attributes = NamedAttribute[namedattribute("fn", fn), ] + + create_operation( + "enzymexla.gpu_occupancy", location; + operands, owned_regions, successors, attributes, results=op_ty_results, - result_inference=false, + result_inference=false ) end -function gpu_thread( - threadIndexX::Value, - threadIndexY::Value, - threadIndexZ::Value; - region::Region, - location=Location(), -) + +function gpu_thread(threadIndexX::Value, threadIndexY::Value, threadIndexZ::Value; region::Region, location=Location()) op_ty_results = IR.Type[] - operands = Value[threadIndexX, threadIndexY, threadIndexZ] - owned_regions = Region[region,] + operands = Value[threadIndexX, threadIndexY, threadIndexZ, ] + owned_regions = Region[region, ] successors = Block[] attributes = NamedAttribute[] - - return create_operation( - "enzymexla.gpu_thread", - location; - operands, - owned_regions, - successors, - attributes, + + create_operation( + "enzymexla.gpu_thread", location; + operands, owned_regions, successors, attributes, results=op_ty_results, - result_inference=false, + result_inference=false ) end @@ -243,49 +173,35 @@ The optional arguments to this operation are suggestions about what block dimensions this gpu kernel should have - usually taken from kernel launch params """ -function gpu_wrapper( - blockDims::Vector{Value}; result::IR.Type, region::Region, location=Location() -) - op_ty_results = IR.Type[result,] - operands = Value[blockDims...,] - owned_regions = Region[region,] +function gpu_wrapper(blockDims::Vector{Value}; result::IR.Type, region::Region, location=Location()) + op_ty_results = IR.Type[result, ] + operands = Value[blockDims..., ] + owned_regions = Region[region, ] successors = Block[] attributes = NamedAttribute[] - - return create_operation( - "enzymexla.gpu_wrapper", - location; - operands, - owned_regions, - successors, - attributes, + + create_operation( + "enzymexla.gpu_wrapper", location; + operands, owned_regions, successors, attributes, results=op_ty_results, - result_inference=false, + result_inference=false ) end -function ml_gelu( - input::Value; - result=nothing::Union{Nothing,IR.Type}, - gelu_approximation, - location=Location(), -) + +function ml_gelu(input::Value; result=nothing::Union{Nothing, IR.Type}, gelu_approximation, location=Location()) op_ty_results = IR.Type[] - operands = Value[input,] + operands = Value[input, ] owned_regions = Region[] successors = Block[] - attributes = NamedAttribute[namedattribute("gelu_approximation", gelu_approximation),] + attributes = NamedAttribute[namedattribute("gelu_approximation", gelu_approximation), ] !isnothing(result) && push!(op_ty_results, result) - - return create_operation( - "enzymexla.ml.gelu", - location; - operands, - owned_regions, - successors, - attributes, + + create_operation( + "enzymexla.ml.gelu", location; + operands, owned_regions, successors, attributes, results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + result_inference=(length(op_ty_results) == 0 ? true : false) ) end @@ -294,31 +210,19 @@ end This operation is modeled after LAPACK\'s *GEMQR routines. """ -function lapack_gemqrt( - V::Value, - T::Value, - C::Value; - output::IR.Type, - side, - transpose=nothing, - location=Location(), -) - op_ty_results = IR.Type[output,] - operands = Value[V, T, C] +function lapack_gemqrt(V::Value, T::Value, C::Value; output::IR.Type, side, transpose=nothing, location=Location()) + op_ty_results = IR.Type[output, ] + operands = Value[V, T, C, ] owned_regions = Region[] successors = Block[] - attributes = NamedAttribute[namedattribute("side", side),] + attributes = NamedAttribute[namedattribute("side", side), ] !isnothing(transpose) && push!(attributes, namedattribute("transpose", transpose)) - - return create_operation( - "enzymexla.lapack.gemqrt", - location; - operands, - owned_regions, - successors, - attributes, + + create_operation( + "enzymexla.lapack.gemqrt", location; + operands, owned_regions, successors, attributes, results=op_ty_results, - result_inference=false, + result_inference=false ) end @@ -332,24 +236,18 @@ orthogonal matrix Q and an upper triangular matrix R, such that A = QR. This operation is modeled after LAPACK\'s *GEQRF routines, which returns the result in the QR packed format. """ -function lapack_geqrf( - input::Value; output::IR.Type, tau::IR.Type, info::IR.Type, location=Location() -) - op_ty_results = IR.Type[output, tau, info] - operands = Value[input,] +function lapack_geqrf(input::Value; output::IR.Type, tau::IR.Type, info::IR.Type, location=Location()) + op_ty_results = IR.Type[output, tau, info, ] + operands = Value[input, ] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[] - - return create_operation( - "enzymexla.lapack.geqrf", - location; - operands, - owned_regions, - successors, - attributes, + + create_operation( + "enzymexla.lapack.geqrf", location; + operands, owned_regions, successors, attributes, results=op_ty_results, - result_inference=false, + result_inference=false ) end @@ -363,92 +261,58 @@ orthogonal matrix Q and an upper triangular matrix R, such that A = QR. This operation is modeled after LAPACK\'s *GEQRT routines, which returns the result in the QR CompactWY format. """ -function lapack_geqrt( - input::Value; - output::IR.Type, - T::IR.Type, - info::IR.Type, - blocksize=nothing, - location=Location(), -) - op_ty_results = IR.Type[output, T, info] - operands = Value[input,] +function lapack_geqrt(input::Value; output::IR.Type, T::IR.Type, info::IR.Type, blocksize=nothing, location=Location()) + op_ty_results = IR.Type[output, T, info, ] + operands = Value[input, ] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[] !isnothing(blocksize) && push!(attributes, namedattribute("blocksize", blocksize)) - - return create_operation( - "enzymexla.lapack.geqrt", - location; - operands, - owned_regions, - successors, - attributes, + + create_operation( + "enzymexla.lapack.geqrt", location; + operands, owned_regions, successors, attributes, results=op_ty_results, - result_inference=false, + result_inference=false ) end + function get_stream(; result::IR.Type, location=Location()) - op_ty_results = IR.Type[result,] + op_ty_results = IR.Type[result, ] operands = Value[] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[] - - return create_operation( - "enzymexla.get_stream", - location; - operands, - owned_regions, - successors, - attributes, + + create_operation( + "enzymexla.get_stream", location; + operands, owned_regions, successors, attributes, results=op_ty_results, - result_inference=false, + result_inference=false ) end -function jit_call( - inputs::Vector{Value}; - result_0::Vector{IR.Type}, - fn, - backend_config=nothing, - operand_layouts=nothing, - result_layouts=nothing, - arg_attrs=nothing, - res_attrs=nothing, - output_operand_aliases=nothing, - xla_side_effect_free=nothing, - location=Location(), -) - op_ty_results = IR.Type[result_0...,] - operands = Value[inputs...,] + +function jit_call(inputs::Vector{Value}; result_0::Vector{IR.Type}, fn, backend_config=nothing, operand_layouts=nothing, result_layouts=nothing, arg_attrs=nothing, res_attrs=nothing, output_operand_aliases=nothing, xla_side_effect_free=nothing, location=Location()) + op_ty_results = IR.Type[result_0..., ] + operands = Value[inputs..., ] owned_regions = Region[] successors = Block[] - attributes = NamedAttribute[namedattribute("fn", fn),] - !isnothing(backend_config) && - push!(attributes, namedattribute("backend_config", backend_config)) - !isnothing(operand_layouts) && - push!(attributes, namedattribute("operand_layouts", operand_layouts)) - !isnothing(result_layouts) && - push!(attributes, namedattribute("result_layouts", result_layouts)) + attributes = NamedAttribute[namedattribute("fn", fn), ] + !isnothing(backend_config) && push!(attributes, namedattribute("backend_config", backend_config)) + !isnothing(operand_layouts) && push!(attributes, namedattribute("operand_layouts", operand_layouts)) + !isnothing(result_layouts) && push!(attributes, namedattribute("result_layouts", result_layouts)) !isnothing(arg_attrs) && push!(attributes, namedattribute("arg_attrs", arg_attrs)) !isnothing(res_attrs) && push!(attributes, namedattribute("res_attrs", res_attrs)) - !isnothing(output_operand_aliases) && - push!(attributes, namedattribute("output_operand_aliases", output_operand_aliases)) - !isnothing(xla_side_effect_free) && - push!(attributes, namedattribute("xla_side_effect_free", xla_side_effect_free)) - - return create_operation( - "enzymexla.jit_call", - location; - operands, - owned_regions, - successors, - attributes, + !isnothing(output_operand_aliases) && push!(attributes, namedattribute("output_operand_aliases", output_operand_aliases)) + !isnothing(xla_side_effect_free) && push!(attributes, namedattribute("xla_side_effect_free", xla_side_effect_free)) + + create_operation( + "enzymexla.jit_call", location; + operands, owned_regions, successors, attributes, results=op_ty_results, - result_inference=false, + result_inference=false ) end @@ -460,7 +324,10 @@ function kernel_call( blocky::Value, blockz::Value, shmem::Value, - inputs::Vector{Value}; + clusterx=nothing::Union{Nothing,Value}; + clustery=nothing::Union{Nothing,Value}, + clusterz=nothing::Union{Nothing,Value}, + inputs::Vector{Value}, result_0::Vector{IR.Type}, fn, backend_config=nothing, @@ -477,6 +344,25 @@ function kernel_call( owned_regions = Region[] successors = Block[] attributes = NamedAttribute[namedattribute("fn", fn),] + !isnothing(clusterx) && push!(operands, clusterx) + !isnothing(clustery) && push!(operands, clustery) + !isnothing(clusterz) && push!(operands, clusterz) + push!( + attributes, + operandsegmentsizes([ + 1, + 1, + 1, + 1, + 1, + 1, + 1, + (clusterx == nothing) ? 0 : 1, + (clustery == nothing) ? 0 : 1, + (clusterz == nothing) ? 0 : 1, + length(inputs), + ]), + ) !isnothing(backend_config) && push!(attributes, namedattribute("backend_config", backend_config)) !isnothing(operand_layouts) && @@ -485,46 +371,30 @@ function kernel_call( push!(attributes, namedattribute("result_layouts", result_layouts)) !isnothing(arg_attrs) && push!(attributes, namedattribute("arg_attrs", arg_attrs)) !isnothing(res_attrs) && push!(attributes, namedattribute("res_attrs", res_attrs)) - !isnothing(output_operand_aliases) && - push!(attributes, namedattribute("output_operand_aliases", output_operand_aliases)) - !isnothing(xla_side_effect_free) && - push!(attributes, namedattribute("xla_side_effect_free", xla_side_effect_free)) - - return create_operation( - "enzymexla.kernel_call", - location; - operands, - owned_regions, - successors, - attributes, + !isnothing(output_operand_aliases) && push!(attributes, namedattribute("output_operand_aliases", output_operand_aliases)) + !isnothing(xla_side_effect_free) && push!(attributes, namedattribute("xla_side_effect_free", xla_side_effect_free)) + + create_operation( + "enzymexla.kernel_call", location; + operands, owned_regions, successors, attributes, results=op_ty_results, - result_inference=false, + result_inference=false ) end -function linalg_lu( - input::Value; - output::IR.Type, - pivots::IR.Type, - permutation::IR.Type, - info::IR.Type, - location=Location(), -) - op_ty_results = IR.Type[output, pivots, permutation, info] - operands = Value[input,] + +function linalg_lu(input::Value; output::IR.Type, pivots::IR.Type, permutation::IR.Type, info::IR.Type, location=Location()) + op_ty_results = IR.Type[output, pivots, permutation, info, ] + operands = Value[input, ] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[] - - return create_operation( - "enzymexla.linalg.lu", - location; - operands, - owned_regions, - successors, - attributes, + + create_operation( + "enzymexla.linalg.lu", location; + operands, owned_regions, successors, attributes, results=op_ty_results, - result_inference=false, + result_inference=false ) end @@ -546,68 +416,51 @@ that case, it returns a !gpu.async.token. %token = gpu.memcpy async [%dep] %dst, %src : memref, memref ``` """ -function memcpy( - asyncDependencies::Vector{Value}, - target::Value, - source::Value, - size::Value; - asyncToken=nothing::Union{Nothing,IR.Type}, - location=Location(), -) +function memcpy(asyncDependencies::Vector{Value}, target::Value, source::Value, size::Value; asyncToken=nothing::Union{Nothing, IR.Type}, location=Location()) op_ty_results = IR.Type[] - operands = Value[asyncDependencies..., target, source, size] + operands = Value[asyncDependencies..., target, source, size, ] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[] !isnothing(asyncToken) && push!(op_ty_results, asyncToken) - - return create_operation( - "enzymexla.memcpy", - location; - operands, - owned_regions, - successors, - attributes, + + create_operation( + "enzymexla.memcpy", location; + operands, owned_regions, successors, attributes, results=op_ty_results, - result_inference=false, + result_inference=false ) end + function memref2pointer(source::Value; result::IR.Type, location=Location()) - op_ty_results = IR.Type[result,] - operands = Value[source,] + op_ty_results = IR.Type[result, ] + operands = Value[source, ] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[] - - return create_operation( - "enzymexla.memref2pointer", - location; - operands, - owned_regions, - successors, - attributes, + + create_operation( + "enzymexla.memref2pointer", location; + operands, owned_regions, successors, attributes, results=op_ty_results, - result_inference=false, + result_inference=false ) end + function noop(blockDims::Vector{Value}; location=Location()) op_ty_results = IR.Type[] - operands = Value[blockDims...,] + operands = Value[blockDims..., ] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[] - - return create_operation( - "enzymexla.noop", - location; - operands, - owned_regions, - successors, - attributes, + + create_operation( + "enzymexla.noop", location; + operands, owned_regions, successors, attributes, results=op_ty_results, - result_inference=false, + result_inference=false ) end @@ -617,21 +470,17 @@ end This operation is modeled after LAPACK\'s *ORGQR/*UNGQR routines. """ function lapack_orgqr(input::Value, tau::Value; output::IR.Type, location=Location()) - op_ty_results = IR.Type[output,] - operands = Value[input, tau] + op_ty_results = IR.Type[output, ] + operands = Value[input, tau, ] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[] - - return create_operation( - "enzymexla.lapack.orgqr", - location; - operands, - owned_regions, - successors, - attributes, + + create_operation( + "enzymexla.lapack.orgqr", location; + operands, owned_regions, successors, attributes, results=op_ty_results, - result_inference=false, + result_inference=false ) end @@ -640,69 +489,51 @@ end This operation is modeled after LAPACK\'s *ORMQR routines. """ -function lapack_ormqr( - A::Value, - tau::Value, - C::Value; - output::IR.Type, - side, - transpose=nothing, - location=Location(), -) - op_ty_results = IR.Type[output,] - operands = Value[A, tau, C] +function lapack_ormqr(A::Value, tau::Value, C::Value; output::IR.Type, side, transpose=nothing, location=Location()) + op_ty_results = IR.Type[output, ] + operands = Value[A, tau, C, ] owned_regions = Region[] successors = Block[] - attributes = NamedAttribute[namedattribute("side", side),] + attributes = NamedAttribute[namedattribute("side", side), ] !isnothing(transpose) && push!(attributes, namedattribute("transpose", transpose)) - - return create_operation( - "enzymexla.lapack.ormqr", - location; - operands, - owned_regions, - successors, - attributes, + + create_operation( + "enzymexla.lapack.ormqr", location; + operands, owned_regions, successors, attributes, results=op_ty_results, - result_inference=false, + result_inference=false ) end + function pointer2memref(source::Value; result::IR.Type, location=Location()) - op_ty_results = IR.Type[result,] - operands = Value[source,] + op_ty_results = IR.Type[result, ] + operands = Value[source, ] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[] - - return create_operation( - "enzymexla.pointer2memref", - location; - operands, - owned_regions, - successors, - attributes, + + create_operation( + "enzymexla.pointer2memref", location; + operands, owned_regions, successors, attributes, results=op_ty_results, - result_inference=false, + result_inference=false ) end + function polygeist_yield(; location=Location()) op_ty_results = IR.Type[] operands = Value[] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[] - - return create_operation( - "enzymexla.polygeist_yield", - location; - operands, - owned_regions, - successors, - attributes, + + create_operation( + "enzymexla.polygeist_yield", location; + operands, owned_regions, successors, attributes, results=op_ty_results, - result_inference=false, + result_inference=false ) end @@ -720,120 +551,86 @@ will be a m x n trapezoidal matrix. This operation is modeled after the mathematical formulation of the QR factorization, and not after LAPACK\'s compact formats. """ -function linalg_qr( - input::Value; Q::IR.Type, R::IR.Type, algorithm=nothing, location=Location() -) - op_ty_results = IR.Type[Q, R] - operands = Value[input,] +function linalg_qr(input::Value; Q::IR.Type, R::IR.Type, algorithm=nothing, location=Location()) + op_ty_results = IR.Type[Q, R, ] + operands = Value[input, ] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[] !isnothing(algorithm) && push!(attributes, namedattribute("algorithm", algorithm)) - - return create_operation( - "enzymexla.linalg.qr", - location; - operands, - owned_regions, - successors, - attributes, + + create_operation( + "enzymexla.linalg.qr", location; + operands, owned_regions, successors, attributes, results=op_ty_results, - result_inference=false, + result_inference=false ) end -function ml_relu(input::Value; result=nothing::Union{Nothing,IR.Type}, location=Location()) + +function ml_relu(input::Value; result=nothing::Union{Nothing, IR.Type}, location=Location()) op_ty_results = IR.Type[] - operands = Value[input,] + operands = Value[input, ] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[] !isnothing(result) && push!(op_ty_results, result) - - return create_operation( - "enzymexla.ml.relu", - location; - operands, - owned_regions, - successors, - attributes, + + create_operation( + "enzymexla.ml.relu", location; + operands, owned_regions, successors, attributes, results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + result_inference=(length(op_ty_results) == 0 ? true : false) ) end -function rotate( - operand::Value; - result=nothing::Union{Nothing,IR.Type}, - amount, - dimension, - location=Location(), -) + +function rotate(operand::Value; result=nothing::Union{Nothing, IR.Type}, amount, dimension, location=Location()) op_ty_results = IR.Type[] - operands = Value[operand,] + operands = Value[operand, ] owned_regions = Region[] successors = Block[] - attributes = NamedAttribute[ - namedattribute("amount", amount), namedattribute("dimension", dimension) - ] + attributes = NamedAttribute[namedattribute("amount", amount), namedattribute("dimension", dimension), ] !isnothing(result) && push!(op_ty_results, result) - - return create_operation( - "enzymexla.rotate", - location; - operands, - owned_regions, - successors, - attributes, + + create_operation( + "enzymexla.rotate", location; + operands, owned_regions, successors, attributes, results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + result_inference=(length(op_ty_results) == 0 ? true : false) ) end -function linalg_svd( - input::Value; - U::IR.Type, - S::IR.Type, - Vt::IR.Type, - info::IR.Type, - full=nothing, - location=Location(), -) - op_ty_results = IR.Type[U, S, Vt, info] - operands = Value[input,] + +function linalg_svd(input::Value; U::IR.Type, S::IR.Type, Vt::IR.Type, info::IR.Type, full=nothing, location=Location()) + op_ty_results = IR.Type[U, S, Vt, info, ] + operands = Value[input, ] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[] !isnothing(full) && push!(attributes, namedattribute("full", full)) - - return create_operation( - "enzymexla.linalg.svd", - location; - operands, - owned_regions, - successors, - attributes, + + create_operation( + "enzymexla.linalg.svd", location; + operands, owned_regions, successors, attributes, results=op_ty_results, - result_inference=false, + result_inference=false ) end + function stream2token(source::Value; result::IR.Type, location=Location()) - op_ty_results = IR.Type[result,] - operands = Value[source,] + op_ty_results = IR.Type[result, ] + operands = Value[source, ] owned_regions = Region[] successors = Block[] attributes = NamedAttribute[] - - return create_operation( - "enzymexla.stream2token", - location; - operands, - owned_regions, - successors, - attributes, + + create_operation( + "enzymexla.stream2token", location; + operands, owned_regions, successors, attributes, results=op_ty_results, - result_inference=false, + result_inference=false ) end @@ -842,86 +639,53 @@ end C := alpha*A*B + beta*C, or C := alpha*B*A + beta*C, where alpha and beta are scalars, A is a symmetric matrix\" """ -function lapack_symm( - A::Value, - B::Value, - C::Value, - alpha::Value, - beta::Value; - output::IR.Type, - side, - uplo, - location=Location(), -) - op_ty_results = IR.Type[output,] - operands = Value[A, B, C, alpha, beta] +function lapack_symm(A::Value, B::Value, C::Value, alpha::Value, beta::Value; output::IR.Type, side, uplo, location=Location()) + op_ty_results = IR.Type[output, ] + operands = Value[A, B, C, alpha, beta, ] owned_regions = Region[] successors = Block[] - attributes = NamedAttribute[namedattribute("side", side), namedattribute("uplo", uplo)] - - return create_operation( - "enzymexla.lapack.symm", - location; - operands, - owned_regions, - successors, - attributes, + attributes = NamedAttribute[namedattribute("side", side), namedattribute("uplo", uplo), ] + + create_operation( + "enzymexla.lapack.symm", location; + operands, owned_regions, successors, attributes, results=op_ty_results, - result_inference=false, + result_inference=false ) end -function wrap( - operand::Value; - result=nothing::Union{Nothing,IR.Type}, - lhs, - rhs, - dimension, - location=Location(), -) + +function wrap(operand::Value; result=nothing::Union{Nothing, IR.Type}, lhs, rhs, dimension, location=Location()) op_ty_results = IR.Type[] - operands = Value[operand,] + operands = Value[operand, ] owned_regions = Region[] successors = Block[] - attributes = NamedAttribute[ - namedattribute("lhs", lhs), - namedattribute("rhs", rhs), - namedattribute("dimension", dimension), - ] + attributes = NamedAttribute[namedattribute("lhs", lhs), namedattribute("rhs", rhs), namedattribute("dimension", dimension), ] !isnothing(result) && push!(op_ty_results, result) - - return create_operation( - "enzymexla.wrap", - location; - operands, - owned_regions, - successors, - attributes, + + create_operation( + "enzymexla.wrap", location; + operands, owned_regions, successors, attributes, results=(length(op_ty_results) == 0 ? nothing : op_ty_results), - result_inference=(length(op_ty_results) == 0 ? true : false), + result_inference=(length(op_ty_results) == 0 ? true : false) ) end -function xla_wrapper( - inputs::Vector{Value}; fn, arg_attrs=nothing, res_attrs=nothing, location=Location() -) + +function xla_wrapper(inputs::Vector{Value}; fn, arg_attrs=nothing, res_attrs=nothing, location=Location()) op_ty_results = IR.Type[] - operands = Value[inputs...,] + operands = Value[inputs..., ] owned_regions = Region[] successors = Block[] - attributes = NamedAttribute[namedattribute("fn", fn),] + attributes = NamedAttribute[namedattribute("fn", fn), ] !isnothing(arg_attrs) && push!(attributes, namedattribute("arg_attrs", arg_attrs)) !isnothing(res_attrs) && push!(attributes, namedattribute("res_attrs", res_attrs)) - - return create_operation( - "enzymexla.xla_wrapper", - location; - operands, - owned_regions, - successors, - attributes, + + create_operation( + "enzymexla.xla_wrapper", location; + operands, owned_regions, successors, attributes, results=op_ty_results, - result_inference=false, + result_inference=false ) end diff --git a/src/mlir/libMLIR_h.jl b/src/mlir/libMLIR_h.jl index ce9c4c230b..916391f436 100755 --- a/src/mlir/libMLIR_h.jl +++ b/src/mlir/libMLIR_h.jl @@ -39,6 +39,7 @@ elseif Sys.iswindows() && Sys.ARCH === :x86_64 const off_t = off32_t end + struct MlirDialectHandle ptr::Ptr{Cvoid} end @@ -233,9 +234,7 @@ end Allocates a type id that is valid for the lifetime of the allocator """ function mlirTypeIDAllocatorAllocateTypeID(allocator) - @ccall mlir_c.mlirTypeIDAllocatorAllocateTypeID( - allocator::MlirTypeIDAllocator - )::MlirTypeID + @ccall mlir_c.mlirTypeIDAllocatorAllocateTypeID(allocator::MlirTypeIDAllocator)::MlirTypeID end struct MlirAsmState @@ -342,9 +341,7 @@ end Creates an MLIR context, setting the multithreading setting explicitly and pre-loading the dialects from the provided DialectRegistry. """ function mlirContextCreateWithRegistry(registry, threadingEnabled) - @ccall mlir_c.mlirContextCreateWithRegistry( - registry::MlirDialectRegistry, threadingEnabled::Bool - )::MlirContext + @ccall mlir_c.mlirContextCreateWithRegistry(registry::MlirDialectRegistry, threadingEnabled::Bool)::MlirContext end """ @@ -380,9 +377,7 @@ end Sets whether unregistered dialects are allowed in this context. """ function mlirContextSetAllowUnregisteredDialects(context, allow) - @ccall mlir_c.mlirContextSetAllowUnregisteredDialects( - context::MlirContext, allow::Bool - )::Cvoid + @ccall mlir_c.mlirContextSetAllowUnregisteredDialects(context::MlirContext, allow::Bool)::Cvoid end """ @@ -409,9 +404,7 @@ end Append the contents of the given dialect registry to the registry associated with the context. """ function mlirContextAppendDialectRegistry(ctx, registry) - @ccall mlir_c.mlirContextAppendDialectRegistry( - ctx::MlirContext, registry::MlirDialectRegistry - )::Cvoid + @ccall mlir_c.mlirContextAppendDialectRegistry(ctx::MlirContext, registry::MlirDialectRegistry)::Cvoid end """ @@ -429,9 +422,7 @@ end Gets the dialect instance owned by the given context using the dialect namespace to identify it, loads (i.e., constructs the instance of) the dialect if necessary. If the dialect is not registered with the context, returns null. Use mlirContextLoadDialect to load an unregistered dialect. """ function mlirContextGetOrLoadDialect(context, name) - @ccall mlir_c.mlirContextGetOrLoadDialect( - context::MlirContext, name::MlirStringRef - )::MlirDialect + @ccall mlir_c.mlirContextGetOrLoadDialect(context::MlirContext, name::MlirStringRef)::MlirDialect end """ @@ -458,9 +449,7 @@ end Returns whether the given fully-qualified operation (i.e. 'dialect.operation') is registered with the context. This will return true if the dialect is loaded and the operation is registered within the dialect. """ function mlirContextIsRegisteredOperation(context, name) - @ccall mlir_c.mlirContextIsRegisteredOperation( - context::MlirContext, name::MlirStringRef - )::Bool + @ccall mlir_c.mlirContextIsRegisteredOperation(context::MlirContext, name::MlirStringRef)::Bool end """ @@ -469,9 +458,7 @@ end Sets the thread pool of the context explicitly, enabling multithreading in the process. This API should be used to avoid re-creating thread pools in long-running applications that perform multiple compilations, see the C++ documentation for MLIRContext for details. """ function mlirContextSetThreadPool(context, threadPool) - @ccall mlir_c.mlirContextSetThreadPool( - context::MlirContext, threadPool::MlirLlvmThreadPool - )::Cvoid + @ccall mlir_c.mlirContextSetThreadPool(context::MlirContext, threadPool::MlirLlvmThreadPool)::Cvoid end """ @@ -543,9 +530,7 @@ end Inserts the dialect associated with the provided dialect handle into the provided dialect registry """ function mlirDialectHandleInsertDialect(arg1, arg2) - @ccall mlir_c.mlirDialectHandleInsertDialect( - arg1::MlirDialectHandle, arg2::MlirDialectRegistry - )::Cvoid + @ccall mlir_c.mlirDialectHandleInsertDialect(arg1::MlirDialectHandle, arg2::MlirDialectRegistry)::Cvoid end """ @@ -554,9 +539,7 @@ end Registers the dialect associated with the provided dialect handle. """ function mlirDialectHandleRegisterDialect(arg1, arg2) - @ccall mlir_c.mlirDialectHandleRegisterDialect( - arg1::MlirDialectHandle, arg2::MlirContext - )::Cvoid + @ccall mlir_c.mlirDialectHandleRegisterDialect(arg1::MlirDialectHandle, arg2::MlirContext)::Cvoid end """ @@ -565,9 +548,7 @@ end Loads the dialect associated with the provided dialect handle. """ function mlirDialectHandleLoadDialect(arg1, arg2) - @ccall mlir_c.mlirDialectHandleLoadDialect( - arg1::MlirDialectHandle, arg2::MlirContext - )::MlirDialect + @ccall mlir_c.mlirDialectHandleLoadDialect(arg1::MlirDialectHandle, arg2::MlirContext)::MlirDialect end """ @@ -621,9 +602,7 @@ end Creates an File/Line/Column location owned by the given context. """ function mlirLocationFileLineColGet(context, filename, line, col) - @ccall mlir_c.mlirLocationFileLineColGet( - context::MlirContext, filename::MlirStringRef, line::Cuint, col::Cuint - )::MlirLocation + @ccall mlir_c.mlirLocationFileLineColGet(context::MlirContext, filename::MlirStringRef, line::Cuint, col::Cuint)::MlirLocation end """ @@ -631,17 +610,8 @@ end Creates an File/Line/Column range location owned by the given context. """ -function mlirLocationFileLineColRangeGet( - context, filename, start_line, start_col, end_line, end_col -) - @ccall mlir_c.mlirLocationFileLineColRangeGet( - context::MlirContext, - filename::MlirStringRef, - start_line::Cuint, - start_col::Cuint, - end_line::Cuint, - end_col::Cuint, - )::MlirLocation +function mlirLocationFileLineColRangeGet(context, filename, start_line, start_col, end_line, end_col) + @ccall mlir_c.mlirLocationFileLineColRangeGet(context::MlirContext, filename::MlirStringRef, start_line::Cuint, start_col::Cuint, end_line::Cuint, end_col::Cuint)::MlirLocation end """ @@ -650,9 +620,7 @@ end Getter for filename of FileLineColRange. """ function mlirLocationFileLineColRangeGetFilename(location) - @ccall mlir_c.mlirLocationFileLineColRangeGetFilename( - location::MlirLocation - )::MlirIdentifier + @ccall mlir_c.mlirLocationFileLineColRangeGetFilename(location::MlirLocation)::MlirIdentifier end """ @@ -715,9 +683,7 @@ end Creates a call site location with a callee and a caller. """ function mlirLocationCallSiteGet(callee, caller) - @ccall mlir_c.mlirLocationCallSiteGet( - callee::MlirLocation, caller::MlirLocation - )::MlirLocation + @ccall mlir_c.mlirLocationCallSiteGet(callee::MlirLocation, caller::MlirLocation)::MlirLocation end """ @@ -762,12 +728,7 @@ end Creates a fused location with an array of locations and metadata. """ function mlirLocationFusedGet(ctx, nLocations, locations, metadata) - @ccall mlir_c.mlirLocationFusedGet( - ctx::MlirContext, - nLocations::Cptrdiff_t, - locations::Ptr{MlirLocation}, - metadata::MlirAttribute, - )::MlirLocation + @ccall mlir_c.mlirLocationFusedGet(ctx::MlirContext, nLocations::Cptrdiff_t, locations::Ptr{MlirLocation}, metadata::MlirAttribute)::MlirLocation end """ @@ -785,9 +746,7 @@ end Getter for locations of Fused. Requires pre-allocated memory of #fusedLocations X sizeof([`MlirLocation`](@ref)). """ function mlirLocationFusedGetLocations(location, locationsCPtr) - @ccall mlir_c.mlirLocationFusedGetLocations( - location::MlirLocation, locationsCPtr::Ptr{MlirLocation} - )::Cvoid + @ccall mlir_c.mlirLocationFusedGetLocations(location::MlirLocation, locationsCPtr::Ptr{MlirLocation})::Cvoid end """ @@ -823,9 +782,7 @@ end Creates a name location owned by the given context. Providing null location for childLoc is allowed and if childLoc is null location, then the behavior is the same as having unknown child location. """ function mlirLocationNameGet(context, name, childLoc) - @ccall mlir_c.mlirLocationNameGet( - context::MlirContext, name::MlirStringRef, childLoc::MlirLocation - )::MlirLocation + @ccall mlir_c.mlirLocationNameGet(context::MlirContext, name::MlirStringRef, childLoc::MlirLocation)::MlirLocation end """ @@ -906,9 +863,7 @@ end Prints a location by sending chunks of the string representation and forwarding `userData to `callback`. Note that the callback may be called several times with consecutive chunks of the string. """ function mlirLocationPrint(location, callback, userData) - @ccall mlir_c.mlirLocationPrint( - location::MlirLocation, callback::MlirStringCallback, userData::Ptr{Cvoid} - )::Cvoid + @ccall mlir_c.mlirLocationPrint(location::MlirLocation, callback::MlirStringCallback, userData::Ptr{Cvoid})::Cvoid end """ @@ -926,9 +881,7 @@ end Parses a module from the string and transfers ownership to the caller. """ function mlirModuleCreateParse(context, _module) - @ccall mlir_c.mlirModuleCreateParse( - context::MlirContext, _module::MlirStringRef - )::MlirModule + @ccall mlir_c.mlirModuleCreateParse(context::MlirContext, _module::MlirStringRef)::MlirModule end """ @@ -937,9 +890,7 @@ end Parses a module from file and transfers ownership to the caller. """ function mlirModuleCreateParseFromFile(context, fileName) - @ccall mlir_c.mlirModuleCreateParseFromFile( - context::MlirContext, fileName::MlirStringRef - )::MlirModule + @ccall mlir_c.mlirModuleCreateParseFromFile(context::MlirContext, fileName::MlirStringRef)::MlirModule end """ @@ -1043,9 +994,7 @@ end Constructs an operation state from a name and a location. """ function mlirOperationStateGet(name, loc) - @ccall mlir_c.mlirOperationStateGet( - name::MlirStringRef, loc::MlirLocation - )::MlirOperationState + @ccall mlir_c.mlirOperationStateGet(name::MlirStringRef, loc::MlirLocation)::MlirOperationState end """ @@ -1054,33 +1003,23 @@ end Adds a list of components to the operation state. """ function mlirOperationStateAddResults(state, n, results) - @ccall mlir_c.mlirOperationStateAddResults( - state::Ptr{MlirOperationState}, n::Cptrdiff_t, results::Ptr{MlirType} - )::Cvoid + @ccall mlir_c.mlirOperationStateAddResults(state::Ptr{MlirOperationState}, n::Cptrdiff_t, results::Ptr{MlirType})::Cvoid end function mlirOperationStateAddOperands(state, n, operands) - @ccall mlir_c.mlirOperationStateAddOperands( - state::Ptr{MlirOperationState}, n::Cptrdiff_t, operands::Ptr{MlirValue} - )::Cvoid + @ccall mlir_c.mlirOperationStateAddOperands(state::Ptr{MlirOperationState}, n::Cptrdiff_t, operands::Ptr{MlirValue})::Cvoid end function mlirOperationStateAddOwnedRegions(state, n, regions) - @ccall mlir_c.mlirOperationStateAddOwnedRegions( - state::Ptr{MlirOperationState}, n::Cptrdiff_t, regions::Ptr{MlirRegion} - )::Cvoid + @ccall mlir_c.mlirOperationStateAddOwnedRegions(state::Ptr{MlirOperationState}, n::Cptrdiff_t, regions::Ptr{MlirRegion})::Cvoid end function mlirOperationStateAddSuccessors(state, n, successors) - @ccall mlir_c.mlirOperationStateAddSuccessors( - state::Ptr{MlirOperationState}, n::Cptrdiff_t, successors::Ptr{MlirBlock} - )::Cvoid + @ccall mlir_c.mlirOperationStateAddSuccessors(state::Ptr{MlirOperationState}, n::Cptrdiff_t, successors::Ptr{MlirBlock})::Cvoid end function mlirOperationStateAddAttributes(state, n, attributes) - @ccall mlir_c.mlirOperationStateAddAttributes( - state::Ptr{MlirOperationState}, n::Cptrdiff_t, attributes::Ptr{MlirNamedAttribute} - )::Cvoid + @ccall mlir_c.mlirOperationStateAddAttributes(state::Ptr{MlirOperationState}, n::Cptrdiff_t, attributes::Ptr{MlirNamedAttribute})::Cvoid end """ @@ -1089,9 +1028,7 @@ end Enables result type inference for the operation under construction. If enabled, then the caller must not have called [`mlirOperationStateAddResults`](@ref)(). Note that if enabled, the [`mlirOperationCreate`](@ref)() call is failable: it will return a null operation on inference failure and will emit diagnostics. """ function mlirOperationStateEnableResultTypeInference(state) - @ccall mlir_c.mlirOperationStateEnableResultTypeInference( - state::Ptr{MlirOperationState} - )::Cvoid + @ccall mlir_c.mlirOperationStateEnableResultTypeInference(state::Ptr{MlirOperationState})::Cvoid end """ @@ -1100,9 +1037,7 @@ end Creates new AsmState, as with AsmState the IR should not be mutated in-between using this state. Must be freed with a call to [`mlirAsmStateDestroy`](@ref)(). """ function mlirAsmStateCreateForOperation(op, flags) - @ccall mlir_c.mlirAsmStateCreateForOperation( - op::MlirOperation, flags::MlirOpPrintingFlags - )::MlirAsmState + @ccall mlir_c.mlirAsmStateCreateForOperation(op::MlirOperation, flags::MlirOpPrintingFlags)::MlirAsmState end """ @@ -1111,9 +1046,7 @@ end Creates new AsmState from value. Must be freed with a call to [`mlirAsmStateDestroy`](@ref)(). """ function mlirAsmStateCreateForValue(value, flags) - @ccall mlir_c.mlirAsmStateCreateForValue( - value::MlirValue, flags::MlirOpPrintingFlags - )::MlirAsmState + @ccall mlir_c.mlirAsmStateCreateForValue(value::MlirValue, flags::MlirOpPrintingFlags)::MlirAsmState end """ @@ -1149,9 +1082,7 @@ end Enables the elision of large elements attributes by printing a lexically valid but otherwise meaningless form instead of the element data. The `largeElementLimit` is used to configure what is considered to be a "large" ElementsAttr by providing an upper limit to the number of elements. """ function mlirOpPrintingFlagsElideLargeElementsAttrs(flags, largeElementLimit) - @ccall mlir_c.mlirOpPrintingFlagsElideLargeElementsAttrs( - flags::MlirOpPrintingFlags, largeElementLimit::Cptrdiff_t - )::Cvoid + @ccall mlir_c.mlirOpPrintingFlagsElideLargeElementsAttrs(flags::MlirOpPrintingFlags, largeElementLimit::Cptrdiff_t)::Cvoid end """ @@ -1160,9 +1091,7 @@ end Enables the elision of large resources strings by omitting them from the `dialect_resources` section. The `largeResourceLimit` is used to configure what is considered to be a "large" resource by providing an upper limit to the string size. """ function mlirOpPrintingFlagsElideLargeResourceString(flags, largeResourceLimit) - @ccall mlir_c.mlirOpPrintingFlagsElideLargeResourceString( - flags::MlirOpPrintingFlags, largeResourceLimit::Cptrdiff_t - )::Cvoid + @ccall mlir_c.mlirOpPrintingFlagsElideLargeResourceString(flags::MlirOpPrintingFlags, largeResourceLimit::Cptrdiff_t)::Cvoid end """ @@ -1171,9 +1100,7 @@ end Enable or disable printing of debug information (based on `enable`). If 'prettyForm' is set to true, debug information is printed in a more readable 'pretty' form. Note: The IR generated with 'prettyForm' is not parsable. """ function mlirOpPrintingFlagsEnableDebugInfo(flags, enable, prettyForm) - @ccall mlir_c.mlirOpPrintingFlagsEnableDebugInfo( - flags::MlirOpPrintingFlags, enable::Bool, prettyForm::Bool - )::Cvoid + @ccall mlir_c.mlirOpPrintingFlagsEnableDebugInfo(flags::MlirOpPrintingFlags, enable::Bool, prettyForm::Bool)::Cvoid end """ @@ -1245,9 +1172,7 @@ end Sets the version to emit in the writer config. """ function mlirBytecodeWriterConfigDesiredEmitVersion(flags, version) - @ccall mlir_c.mlirBytecodeWriterConfigDesiredEmitVersion( - flags::MlirBytecodeWriterConfig, version::Int64 - )::Cvoid + @ccall mlir_c.mlirBytecodeWriterConfigDesiredEmitVersion(flags::MlirBytecodeWriterConfig, version::Int64)::Cvoid end """ @@ -1269,9 +1194,7 @@ Parses an operation, giving ownership to the caller. If parsing fails a null ope `sourceStr` may be either the text assembly format, or binary bytecode format. `sourceName` is used as the file name of the source; any IR without locations will get a `FileLineColLoc` location with `sourceName` as the file name. """ function mlirOperationCreateParse(context, sourceStr, sourceName) - @ccall mlir_c.mlirOperationCreateParse( - context::MlirContext, sourceStr::MlirStringRef, sourceName::MlirStringRef - )::MlirOperation + @ccall mlir_c.mlirOperationCreateParse(context::MlirContext, sourceStr::MlirStringRef, sourceName::MlirStringRef)::MlirOperation end """ @@ -1442,9 +1365,7 @@ end Sets the `pos`-th operand of the operation. """ function mlirOperationSetOperand(op, pos, newValue) - @ccall mlir_c.mlirOperationSetOperand( - op::MlirOperation, pos::Cptrdiff_t, newValue::MlirValue - )::Cvoid + @ccall mlir_c.mlirOperationSetOperand(op::MlirOperation, pos::Cptrdiff_t, newValue::MlirValue)::Cvoid end """ @@ -1453,9 +1374,7 @@ end Replaces the operands of the operation. """ function mlirOperationSetOperands(op, nOperands, operands) - @ccall mlir_c.mlirOperationSetOperands( - op::MlirOperation, nOperands::Cptrdiff_t, operands::Ptr{MlirValue} - )::Cvoid + @ccall mlir_c.mlirOperationSetOperands(op::MlirOperation, nOperands::Cptrdiff_t, operands::Ptr{MlirValue})::Cvoid end """ @@ -1500,9 +1419,7 @@ end Set `pos`-th successor of the operation. """ function mlirOperationSetSuccessor(op, pos, block) - @ccall mlir_c.mlirOperationSetSuccessor( - op::MlirOperation, pos::Cptrdiff_t, block::MlirBlock - )::Cvoid + @ccall mlir_c.mlirOperationSetSuccessor(op::MlirOperation, pos::Cptrdiff_t, block::MlirBlock)::Cvoid end """ @@ -1511,9 +1428,7 @@ end Returns true if this operation defines an inherent attribute with this name. Note: the attribute can be optional, so [`mlirOperationGetInherentAttributeByName`](@ref) can still return a null attribute. """ function mlirOperationHasInherentAttributeByName(op, name) - @ccall mlir_c.mlirOperationHasInherentAttributeByName( - op::MlirOperation, name::MlirStringRef - )::Bool + @ccall mlir_c.mlirOperationHasInherentAttributeByName(op::MlirOperation, name::MlirStringRef)::Bool end """ @@ -1522,9 +1437,7 @@ end Returns an inherent attribute attached to the operation given its name. """ function mlirOperationGetInherentAttributeByName(op, name) - @ccall mlir_c.mlirOperationGetInherentAttributeByName( - op::MlirOperation, name::MlirStringRef - )::MlirAttribute + @ccall mlir_c.mlirOperationGetInherentAttributeByName(op::MlirOperation, name::MlirStringRef)::MlirAttribute end """ @@ -1533,9 +1446,7 @@ end Sets an inherent attribute by name, replacing the existing if it exists. This has no effect if "name" does not match an inherent attribute. """ function mlirOperationSetInherentAttributeByName(op, name, attr) - @ccall mlir_c.mlirOperationSetInherentAttributeByName( - op::MlirOperation, name::MlirStringRef, attr::MlirAttribute - )::Cvoid + @ccall mlir_c.mlirOperationSetInherentAttributeByName(op::MlirOperation, name::MlirStringRef, attr::MlirAttribute)::Cvoid end """ @@ -1553,9 +1464,7 @@ end Return `pos`-th discardable attribute of the operation. """ function mlirOperationGetDiscardableAttribute(op, pos) - @ccall mlir_c.mlirOperationGetDiscardableAttribute( - op::MlirOperation, pos::Cptrdiff_t - )::MlirNamedAttribute + @ccall mlir_c.mlirOperationGetDiscardableAttribute(op::MlirOperation, pos::Cptrdiff_t)::MlirNamedAttribute end """ @@ -1564,9 +1473,7 @@ end Returns a discardable attribute attached to the operation given its name. """ function mlirOperationGetDiscardableAttributeByName(op, name) - @ccall mlir_c.mlirOperationGetDiscardableAttributeByName( - op::MlirOperation, name::MlirStringRef - )::MlirAttribute + @ccall mlir_c.mlirOperationGetDiscardableAttributeByName(op::MlirOperation, name::MlirStringRef)::MlirAttribute end """ @@ -1575,9 +1482,7 @@ end Sets a discardable attribute by name, replacing the existing if it exists or adding a new one otherwise. The new `attr` Attribute is not allowed to be null, use [`mlirOperationRemoveDiscardableAttributeByName`](@ref) to remove an Attribute instead. """ function mlirOperationSetDiscardableAttributeByName(op, name, attr) - @ccall mlir_c.mlirOperationSetDiscardableAttributeByName( - op::MlirOperation, name::MlirStringRef, attr::MlirAttribute - )::Cvoid + @ccall mlir_c.mlirOperationSetDiscardableAttributeByName(op::MlirOperation, name::MlirStringRef, attr::MlirAttribute)::Cvoid end """ @@ -1586,9 +1491,7 @@ end Removes a discardable attribute by name. Returns false if the attribute was not found and true if removed. """ function mlirOperationRemoveDiscardableAttributeByName(op, name) - @ccall mlir_c.mlirOperationRemoveDiscardableAttributeByName( - op::MlirOperation, name::MlirStringRef - )::Bool + @ccall mlir_c.mlirOperationRemoveDiscardableAttributeByName(op::MlirOperation, name::MlirStringRef)::Bool end """ @@ -1606,9 +1509,7 @@ end Return `pos`-th attribute of the operation. Deprecated, please use `mlirOperationGetInherentAttribute` or [`mlirOperationGetDiscardableAttribute`](@ref). """ function mlirOperationGetAttribute(op, pos) - @ccall mlir_c.mlirOperationGetAttribute( - op::MlirOperation, pos::Cptrdiff_t - )::MlirNamedAttribute + @ccall mlir_c.mlirOperationGetAttribute(op::MlirOperation, pos::Cptrdiff_t)::MlirNamedAttribute end """ @@ -1617,9 +1518,7 @@ end Returns an attribute attached to the operation given its name. Deprecated, please use [`mlirOperationGetInherentAttributeByName`](@ref) or [`mlirOperationGetDiscardableAttributeByName`](@ref). """ function mlirOperationGetAttributeByName(op, name) - @ccall mlir_c.mlirOperationGetAttributeByName( - op::MlirOperation, name::MlirStringRef - )::MlirAttribute + @ccall mlir_c.mlirOperationGetAttributeByName(op::MlirOperation, name::MlirStringRef)::MlirAttribute end """ @@ -1628,9 +1527,7 @@ end Sets an attribute by name, replacing the existing if it exists or adding a new one otherwise. Deprecated, please use [`mlirOperationSetInherentAttributeByName`](@ref) or [`mlirOperationSetDiscardableAttributeByName`](@ref). """ function mlirOperationSetAttributeByName(op, name, attr) - @ccall mlir_c.mlirOperationSetAttributeByName( - op::MlirOperation, name::MlirStringRef, attr::MlirAttribute - )::Cvoid + @ccall mlir_c.mlirOperationSetAttributeByName(op::MlirOperation, name::MlirStringRef, attr::MlirAttribute)::Cvoid end """ @@ -1639,9 +1536,7 @@ end Removes an attribute by name. Returns false if the attribute was not found and true if removed. Deprecated, please use `mlirOperationRemoveInherentAttributeByName` or [`mlirOperationRemoveDiscardableAttributeByName`](@ref). """ function mlirOperationRemoveAttributeByName(op, name) - @ccall mlir_c.mlirOperationRemoveAttributeByName( - op::MlirOperation, name::MlirStringRef - )::Bool + @ccall mlir_c.mlirOperationRemoveAttributeByName(op::MlirOperation, name::MlirStringRef)::Bool end """ @@ -1650,9 +1545,7 @@ end Prints an operation by sending chunks of the string representation and forwarding `userData to `callback`. Note that the callback may be called several times with consecutive chunks of the string. """ function mlirOperationPrint(op, callback, userData) - @ccall mlir_c.mlirOperationPrint( - op::MlirOperation, callback::MlirStringCallback, userData::Ptr{Cvoid} - )::Cvoid + @ccall mlir_c.mlirOperationPrint(op::MlirOperation, callback::MlirStringCallback, userData::Ptr{Cvoid})::Cvoid end """ @@ -1661,12 +1554,7 @@ end Same as [`mlirOperationPrint`](@ref) but accepts flags controlling the printing behavior. """ function mlirOperationPrintWithFlags(op, flags, callback, userData) - @ccall mlir_c.mlirOperationPrintWithFlags( - op::MlirOperation, - flags::MlirOpPrintingFlags, - callback::MlirStringCallback, - userData::Ptr{Cvoid}, - )::Cvoid + @ccall mlir_c.mlirOperationPrintWithFlags(op::MlirOperation, flags::MlirOpPrintingFlags, callback::MlirStringCallback, userData::Ptr{Cvoid})::Cvoid end """ @@ -1675,12 +1563,7 @@ end Same as [`mlirOperationPrint`](@ref) but accepts AsmState controlling the printing behavior as well as caching computed names. """ function mlirOperationPrintWithState(op, state, callback, userData) - @ccall mlir_c.mlirOperationPrintWithState( - op::MlirOperation, - state::MlirAsmState, - callback::MlirStringCallback, - userData::Ptr{Cvoid}, - )::Cvoid + @ccall mlir_c.mlirOperationPrintWithState(op::MlirOperation, state::MlirAsmState, callback::MlirStringCallback, userData::Ptr{Cvoid})::Cvoid end """ @@ -1689,9 +1572,7 @@ end Same as [`mlirOperationPrint`](@ref) but writing the bytecode format. """ function mlirOperationWriteBytecode(op, callback, userData) - @ccall mlir_c.mlirOperationWriteBytecode( - op::MlirOperation, callback::MlirStringCallback, userData::Ptr{Cvoid} - )::Cvoid + @ccall mlir_c.mlirOperationWriteBytecode(op::MlirOperation, callback::MlirStringCallback, userData::Ptr{Cvoid})::Cvoid end """ @@ -1700,12 +1581,7 @@ end Same as [`mlirOperationWriteBytecode`](@ref) but with writer config and returns failure only if desired bytecode could not be honored. """ function mlirOperationWriteBytecodeWithConfig(op, config, callback, userData) - @ccall mlir_c.mlirOperationWriteBytecodeWithConfig( - op::MlirOperation, - config::MlirBytecodeWriterConfig, - callback::MlirStringCallback, - userData::Ptr{Cvoid}, - )::MlirLogicalResult + @ccall mlir_c.mlirOperationWriteBytecodeWithConfig(op::MlirOperation, config::MlirBytecodeWriterConfig, callback::MlirStringCallback, userData::Ptr{Cvoid})::MlirLogicalResult end """ @@ -1750,9 +1626,7 @@ end Given an operation 'other' that is within the same parent block, return whether the current operation is before 'other' in the operation list of the parent block. Note: This function has an average complexity of O(1), but worst case may take O(N) where N is the number of operations within the parent block. """ function mlirOperationIsBeforeInBlock(op, other) - @ccall mlir_c.mlirOperationIsBeforeInBlock( - op::MlirOperation, other::MlirOperation - )::Bool + @ccall mlir_c.mlirOperationIsBeforeInBlock(op::MlirOperation, other::MlirOperation)::Bool end """ @@ -1788,12 +1662,7 @@ const MlirOperationWalkCallback = Ptr{Cvoid} Walks operation `op` in `walkOrder` and calls `callback` on that operation. `*userData` is passed to the callback as well and can be used to tunnel some context or other data into the callback. """ function mlirOperationWalk(op, callback, userData, walkOrder) - @ccall mlir_c.mlirOperationWalk( - op::MlirOperation, - callback::MlirOperationWalkCallback, - userData::Ptr{Cvoid}, - walkOrder::MlirWalkOrder, - )::Cvoid + @ccall mlir_c.mlirOperationWalk(op::MlirOperation, callback::MlirOperationWalkCallback, userData::Ptr{Cvoid}, walkOrder::MlirWalkOrder)::Cvoid end """ @@ -1856,9 +1725,7 @@ end Takes a block owned by the caller and inserts it at `pos` to the given region. This is an expensive operation that linearly scans the region, prefer insertAfter/Before instead. """ function mlirRegionInsertOwnedBlock(region, pos, block) - @ccall mlir_c.mlirRegionInsertOwnedBlock( - region::MlirRegion, pos::Cptrdiff_t, block::MlirBlock - )::Cvoid + @ccall mlir_c.mlirRegionInsertOwnedBlock(region::MlirRegion, pos::Cptrdiff_t, block::MlirBlock)::Cvoid end """ @@ -1867,9 +1734,7 @@ end Takes a block owned by the caller and inserts it after the (non-owned) reference block in the given region. The reference block must belong to the region. If the reference block is null, prepends the block to the region. """ function mlirRegionInsertOwnedBlockAfter(region, reference, block) - @ccall mlir_c.mlirRegionInsertOwnedBlockAfter( - region::MlirRegion, reference::MlirBlock, block::MlirBlock - )::Cvoid + @ccall mlir_c.mlirRegionInsertOwnedBlockAfter(region::MlirRegion, reference::MlirBlock, block::MlirBlock)::Cvoid end """ @@ -1878,9 +1743,7 @@ end Takes a block owned by the caller and inserts it before the (non-owned) reference block in the given region. The reference block must belong to the region. If the reference block is null, appends the block to the region. """ function mlirRegionInsertOwnedBlockBefore(region, reference, block) - @ccall mlir_c.mlirRegionInsertOwnedBlockBefore( - region::MlirRegion, reference::MlirBlock, block::MlirBlock - )::Cvoid + @ccall mlir_c.mlirRegionInsertOwnedBlockBefore(region::MlirRegion, reference::MlirBlock, block::MlirBlock)::Cvoid end """ @@ -1916,9 +1779,7 @@ end Creates a new empty block with the given argument types and transfers ownership to the caller. """ function mlirBlockCreate(nArgs, args, locs) - @ccall mlir_c.mlirBlockCreate( - nArgs::Cptrdiff_t, args::Ptr{MlirType}, locs::Ptr{MlirLocation} - )::MlirBlock + @ccall mlir_c.mlirBlockCreate(nArgs::Cptrdiff_t, args::Ptr{MlirType}, locs::Ptr{MlirLocation})::MlirBlock end """ @@ -2008,9 +1869,7 @@ end Takes an operation owned by the caller and appends it to the block. """ function mlirBlockAppendOwnedOperation(block, operation) - @ccall mlir_c.mlirBlockAppendOwnedOperation( - block::MlirBlock, operation::MlirOperation - )::Cvoid + @ccall mlir_c.mlirBlockAppendOwnedOperation(block::MlirBlock, operation::MlirOperation)::Cvoid end """ @@ -2019,9 +1878,7 @@ end Takes an operation owned by the caller and inserts it as `pos` to the block. This is an expensive operation that scans the block linearly, prefer insertBefore/After instead. """ function mlirBlockInsertOwnedOperation(block, pos, operation) - @ccall mlir_c.mlirBlockInsertOwnedOperation( - block::MlirBlock, pos::Cptrdiff_t, operation::MlirOperation - )::Cvoid + @ccall mlir_c.mlirBlockInsertOwnedOperation(block::MlirBlock, pos::Cptrdiff_t, operation::MlirOperation)::Cvoid end """ @@ -2030,9 +1887,7 @@ end Takes an operation owned by the caller and inserts it after the (non-owned) reference operation in the given block. If the reference is null, prepends the operation. Otherwise, the reference must belong to the block. """ function mlirBlockInsertOwnedOperationAfter(block, reference, operation) - @ccall mlir_c.mlirBlockInsertOwnedOperationAfter( - block::MlirBlock, reference::MlirOperation, operation::MlirOperation - )::Cvoid + @ccall mlir_c.mlirBlockInsertOwnedOperationAfter(block::MlirBlock, reference::MlirOperation, operation::MlirOperation)::Cvoid end """ @@ -2041,9 +1896,7 @@ end Takes an operation owned by the caller and inserts it before the (non-owned) reference operation in the given block. If the reference is null, appends the operation. Otherwise, the reference must belong to the block. """ function mlirBlockInsertOwnedOperationBefore(block, reference, operation) - @ccall mlir_c.mlirBlockInsertOwnedOperationBefore( - block::MlirBlock, reference::MlirOperation, operation::MlirOperation - )::Cvoid + @ccall mlir_c.mlirBlockInsertOwnedOperationBefore(block::MlirBlock, reference::MlirOperation, operation::MlirOperation)::Cvoid end """ @@ -2061,9 +1914,7 @@ end Appends an argument of the specified type to the block. Returns the newly added argument. """ function mlirBlockAddArgument(block, type, loc) - @ccall mlir_c.mlirBlockAddArgument( - block::MlirBlock, type::MlirType, loc::MlirLocation - )::MlirValue + @ccall mlir_c.mlirBlockAddArgument(block::MlirBlock, type::MlirType, loc::MlirLocation)::MlirValue end """ @@ -2081,9 +1932,7 @@ end Inserts an argument of the specified type at a specified index to the block. Returns the newly added argument. """ function mlirBlockInsertArgument(block, pos, type, loc) - @ccall mlir_c.mlirBlockInsertArgument( - block::MlirBlock, pos::Cptrdiff_t, type::MlirType, loc::MlirLocation - )::MlirValue + @ccall mlir_c.mlirBlockInsertArgument(block::MlirBlock, pos::Cptrdiff_t, type::MlirType, loc::MlirLocation)::MlirValue end """ @@ -2101,9 +1950,7 @@ end Prints a block by sending chunks of the string representation and forwarding `userData to `callback`. Note that the callback may be called several times with consecutive chunks of the string. """ function mlirBlockPrint(block, callback, userData) - @ccall mlir_c.mlirBlockPrint( - block::MlirBlock, callback::MlirStringCallback, userData::Ptr{Cvoid} - )::Cvoid + @ccall mlir_c.mlirBlockPrint(block::MlirBlock, callback::MlirStringCallback, userData::Ptr{Cvoid})::Cvoid end """ @@ -2258,9 +2105,7 @@ end Prints a value by sending chunks of the string representation and forwarding `userData to `callback`. Note that the callback may be called several times with consecutive chunks of the string. """ function mlirValuePrint(value, callback, userData) - @ccall mlir_c.mlirValuePrint( - value::MlirValue, callback::MlirStringCallback, userData::Ptr{Cvoid} - )::Cvoid + @ccall mlir_c.mlirValuePrint(value::MlirValue, callback::MlirStringCallback, userData::Ptr{Cvoid})::Cvoid end """ @@ -2269,12 +2114,7 @@ end Prints a value as an operand (i.e., the ValueID). """ function mlirValuePrintAsOperand(value, state, callback, userData) - @ccall mlir_c.mlirValuePrintAsOperand( - value::MlirValue, - state::MlirAsmState, - callback::MlirStringCallback, - userData::Ptr{Cvoid}, - )::Cvoid + @ccall mlir_c.mlirValuePrintAsOperand(value::MlirValue, state::MlirAsmState, callback::MlirStringCallback, userData::Ptr{Cvoid})::Cvoid end """ @@ -2301,12 +2141,7 @@ end Replace all uses of 'of' value with 'with' value, updating anything in the IR that uses 'of' to use 'with' instead, except if the user is listed in 'exceptions'. The 'exceptions' parameter is an array of [`MlirOperation`](@ref) pointers with a length of 'numExceptions'. """ function mlirValueReplaceAllUsesExcept(of, with, numExceptions, exceptions) - @ccall mlir_c.mlirValueReplaceAllUsesExcept( - of::MlirValue, - with::MlirValue, - numExceptions::Cptrdiff_t, - exceptions::Ptr{MlirOperation}, - )::Cvoid + @ccall mlir_c.mlirValueReplaceAllUsesExcept(of::MlirValue, with::MlirValue, numExceptions::Cptrdiff_t, exceptions::Ptr{MlirOperation})::Cvoid end """ @@ -2432,9 +2267,7 @@ end Prints a location by sending chunks of the string representation and forwarding `userData to `callback`. Note that the callback may be called several times with consecutive chunks of the string. """ function mlirTypePrint(type, callback, userData) - @ccall mlir_c.mlirTypePrint( - type::MlirType, callback::MlirStringCallback, userData::Ptr{Cvoid} - )::Cvoid + @ccall mlir_c.mlirTypePrint(type::MlirType, callback::MlirStringCallback, userData::Ptr{Cvoid})::Cvoid end """ @@ -2452,9 +2285,7 @@ end Parses an attribute. The attribute is owned by the context. """ function mlirAttributeParseGet(context, attr) - @ccall mlir_c.mlirAttributeParseGet( - context::MlirContext, attr::MlirStringRef - )::MlirAttribute + @ccall mlir_c.mlirAttributeParseGet(context::MlirContext, attr::MlirStringRef)::MlirAttribute end """ @@ -2517,9 +2348,7 @@ end Prints an attribute by sending chunks of the string representation and forwarding `userData to `callback`. Note that the callback may be called several times with consecutive chunks of the string. """ function mlirAttributePrint(attr, callback, userData) - @ccall mlir_c.mlirAttributePrint( - attr::MlirAttribute, callback::MlirStringCallback, userData::Ptr{Cvoid} - )::Cvoid + @ccall mlir_c.mlirAttributePrint(attr::MlirAttribute, callback::MlirStringCallback, userData::Ptr{Cvoid})::Cvoid end """ @@ -2537,9 +2366,7 @@ end Associates an attribute with the name. Takes ownership of neither. """ function mlirNamedAttributeGet(name, attr) - @ccall mlir_c.mlirNamedAttributeGet( - name::MlirIdentifier, attr::MlirAttribute - )::MlirNamedAttribute + @ccall mlir_c.mlirNamedAttributeGet(name::MlirIdentifier, attr::MlirAttribute)::MlirNamedAttribute end """ @@ -2548,9 +2375,7 @@ end Gets an identifier with the given string value. """ function mlirIdentifierGet(context, str) - @ccall mlir_c.mlirIdentifierGet( - context::MlirContext, str::MlirStringRef - )::MlirIdentifier + @ccall mlir_c.mlirIdentifierGet(context::MlirContext, str::MlirStringRef)::MlirIdentifier end """ @@ -2631,9 +2456,7 @@ end Looks up a symbol with the given name in the given symbol table and returns the operation that corresponds to the symbol. If the symbol cannot be found, returns a null operation. """ function mlirSymbolTableLookup(symbolTable, name) - @ccall mlir_c.mlirSymbolTableLookup( - symbolTable::MlirSymbolTable, name::MlirStringRef - )::MlirOperation + @ccall mlir_c.mlirSymbolTableLookup(symbolTable::MlirSymbolTable, name::MlirStringRef)::MlirOperation end """ @@ -2642,9 +2465,7 @@ end Inserts the given operation into the given symbol table. The operation must have the symbol trait. If the symbol table already has a symbol with the same name, renames the symbol being inserted to ensure name uniqueness. Note that this does not move the operation itself into the block of the symbol table operation, this should be done separately. Returns the name of the symbol after insertion. """ function mlirSymbolTableInsert(symbolTable, operation) - @ccall mlir_c.mlirSymbolTableInsert( - symbolTable::MlirSymbolTable, operation::MlirOperation - )::MlirAttribute + @ccall mlir_c.mlirSymbolTableInsert(symbolTable::MlirSymbolTable, operation::MlirOperation)::MlirAttribute end """ @@ -2653,9 +2474,7 @@ end Removes the given operation from the symbol table and erases it. """ function mlirSymbolTableErase(symbolTable, operation) - @ccall mlir_c.mlirSymbolTableErase( - symbolTable::MlirSymbolTable, operation::MlirOperation - )::Cvoid + @ccall mlir_c.mlirSymbolTableErase(symbolTable::MlirSymbolTable, operation::MlirOperation)::Cvoid end """ @@ -2664,9 +2483,7 @@ end Attempt to replace all uses that are nested within the given operation of the given symbol 'oldSymbol' with the provided 'newSymbol'. This does not traverse into nested symbol tables. Will fail atomically if there are any unknown operations that may be potential symbol tables. """ function mlirSymbolTableReplaceAllSymbolUses(oldSymbol, newSymbol, from) - @ccall mlir_c.mlirSymbolTableReplaceAllSymbolUses( - oldSymbol::MlirStringRef, newSymbol::MlirStringRef, from::MlirOperation - )::MlirLogicalResult + @ccall mlir_c.mlirSymbolTableReplaceAllSymbolUses(oldSymbol::MlirStringRef, newSymbol::MlirStringRef, from::MlirOperation)::MlirLogicalResult end """ @@ -2675,12 +2492,7 @@ end Walks all symbol table operations nested within, and including, `op`. For each symbol table operation, the provided callback is invoked with the op and a boolean signifying if the symbols within that symbol table can be treated as if all uses within the IR are visible to the caller. `allSymUsesVisible` identifies whether all of the symbol uses of symbols within `op` are visible. """ function mlirSymbolTableWalkSymbolTables(from, allSymUsesVisible, callback, userData) - @ccall mlir_c.mlirSymbolTableWalkSymbolTables( - from::MlirOperation, - allSymUsesVisible::Bool, - callback::Ptr{Cvoid}, - userData::Ptr{Cvoid}, - )::Cvoid + @ccall mlir_c.mlirSymbolTableWalkSymbolTables(from::MlirOperation, allSymUsesVisible::Bool, callback::Ptr{Cvoid}, userData::Ptr{Cvoid})::Cvoid end struct MlirAffineExpr @@ -2720,9 +2532,7 @@ end Prints an affine expression by sending chunks of the string representation and forwarding `userData to `callback`. Note that the callback may be called several times with consecutive chunks of the string. """ function mlirAffineExprPrint(affineExpr, callback, userData) - @ccall mlir_c.mlirAffineExprPrint( - affineExpr::MlirAffineExpr, callback::MlirStringCallback, userData::Ptr{Cvoid} - )::Cvoid + @ccall mlir_c.mlirAffineExprPrint(affineExpr::MlirAffineExpr, callback::MlirStringCallback, userData::Ptr{Cvoid})::Cvoid end """ @@ -2767,9 +2577,7 @@ end Checks whether the given affine expression is a multiple of 'factor'. """ function mlirAffineExprIsMultipleOf(affineExpr, factor) - @ccall mlir_c.mlirAffineExprIsMultipleOf( - affineExpr::MlirAffineExpr, factor::Int64 - )::Bool + @ccall mlir_c.mlirAffineExprIsMultipleOf(affineExpr::MlirAffineExpr, factor::Int64)::Bool end """ @@ -2778,9 +2586,7 @@ end Checks whether the given affine expression involves AffineDimExpr 'position'. """ function mlirAffineExprIsFunctionOfDim(affineExpr, position) - @ccall mlir_c.mlirAffineExprIsFunctionOfDim( - affineExpr::MlirAffineExpr, position::Cptrdiff_t - )::Bool + @ccall mlir_c.mlirAffineExprIsFunctionOfDim(affineExpr::MlirAffineExpr, position::Cptrdiff_t)::Bool end struct MlirAffineMap @@ -2793,9 +2599,7 @@ end Composes the given map with the given expression. """ function mlirAffineExprCompose(affineExpr, affineMap) - @ccall mlir_c.mlirAffineExprCompose( - affineExpr::MlirAffineExpr, affineMap::MlirAffineMap - )::MlirAffineExpr + @ccall mlir_c.mlirAffineExprCompose(affineExpr::MlirAffineExpr, affineMap::MlirAffineMap)::MlirAffineExpr end """ @@ -2804,9 +2608,7 @@ end Replace dims[offset ... numDims) by dims[offset + shift ... shift + numDims). """ function mlirAffineExprShiftDims(affineExpr, numDims, shift, offset) - @ccall mlir_c.mlirAffineExprShiftDims( - affineExpr::MlirAffineExpr, numDims::UInt32, shift::UInt32, offset::UInt32 - )::MlirAffineExpr + @ccall mlir_c.mlirAffineExprShiftDims(affineExpr::MlirAffineExpr, numDims::UInt32, shift::UInt32, offset::UInt32)::MlirAffineExpr end """ @@ -2815,9 +2617,7 @@ end Replace symbols[offset ... numSymbols) by symbols[offset + shift ... shift + numSymbols). """ function mlirAffineExprShiftSymbols(affineExpr, numSymbols, shift, offset) - @ccall mlir_c.mlirAffineExprShiftSymbols( - affineExpr::MlirAffineExpr, numSymbols::UInt32, shift::UInt32, offset::UInt32 - )::MlirAffineExpr + @ccall mlir_c.mlirAffineExprShiftSymbols(affineExpr::MlirAffineExpr, numSymbols::UInt32, shift::UInt32, offset::UInt32)::MlirAffineExpr end """ @@ -2826,9 +2626,7 @@ end Simplify an affine expression by flattening and some amount of simple analysis. This has complexity linear in the number of nodes in 'expr'. Returns the simplified expression, which is the same as the input expression if it can't be simplified. When `expr` is semi-affine, a simplified semi-affine expression is constructed in the sorted order of dimension and symbol positions. """ function mlirSimplifyAffineExpr(expr, numDims, numSymbols) - @ccall mlir_c.mlirSimplifyAffineExpr( - expr::MlirAffineExpr, numDims::UInt32, numSymbols::UInt32 - )::MlirAffineExpr + @ccall mlir_c.mlirSimplifyAffineExpr(expr::MlirAffineExpr, numDims::UInt32, numSymbols::UInt32)::MlirAffineExpr end """ @@ -2846,9 +2644,7 @@ end Creates an affine dimension expression with 'position' in the context. """ function mlirAffineDimExprGet(ctx, position) - @ccall mlir_c.mlirAffineDimExprGet( - ctx::MlirContext, position::Cptrdiff_t - )::MlirAffineExpr + @ccall mlir_c.mlirAffineDimExprGet(ctx::MlirContext, position::Cptrdiff_t)::MlirAffineExpr end """ @@ -2875,9 +2671,7 @@ end Creates an affine symbol expression with 'position' in the context. """ function mlirAffineSymbolExprGet(ctx, position) - @ccall mlir_c.mlirAffineSymbolExprGet( - ctx::MlirContext, position::Cptrdiff_t - )::MlirAffineExpr + @ccall mlir_c.mlirAffineSymbolExprGet(ctx::MlirContext, position::Cptrdiff_t)::MlirAffineExpr end """ @@ -2904,9 +2698,7 @@ end Creates an affine constant expression with 'constant' in the context. """ function mlirAffineConstantExprGet(ctx, constant) - @ccall mlir_c.mlirAffineConstantExprGet( - ctx::MlirContext, constant::Int64 - )::MlirAffineExpr + @ccall mlir_c.mlirAffineConstantExprGet(ctx::MlirContext, constant::Int64)::MlirAffineExpr end """ @@ -2933,9 +2725,7 @@ end Creates an affine add expression with 'lhs' and 'rhs'. """ function mlirAffineAddExprGet(lhs, rhs) - @ccall mlir_c.mlirAffineAddExprGet( - lhs::MlirAffineExpr, rhs::MlirAffineExpr - )::MlirAffineExpr + @ccall mlir_c.mlirAffineAddExprGet(lhs::MlirAffineExpr, rhs::MlirAffineExpr)::MlirAffineExpr end """ @@ -2953,9 +2743,7 @@ end Creates an affine mul expression with 'lhs' and 'rhs'. """ function mlirAffineMulExprGet(lhs, rhs) - @ccall mlir_c.mlirAffineMulExprGet( - lhs::MlirAffineExpr, rhs::MlirAffineExpr - )::MlirAffineExpr + @ccall mlir_c.mlirAffineMulExprGet(lhs::MlirAffineExpr, rhs::MlirAffineExpr)::MlirAffineExpr end """ @@ -2973,9 +2761,7 @@ end Creates an affine mod expression with 'lhs' and 'rhs'. """ function mlirAffineModExprGet(lhs, rhs) - @ccall mlir_c.mlirAffineModExprGet( - lhs::MlirAffineExpr, rhs::MlirAffineExpr - )::MlirAffineExpr + @ccall mlir_c.mlirAffineModExprGet(lhs::MlirAffineExpr, rhs::MlirAffineExpr)::MlirAffineExpr end """ @@ -2993,9 +2779,7 @@ end Creates an affine floordiv expression with 'lhs' and 'rhs'. """ function mlirAffineFloorDivExprGet(lhs, rhs) - @ccall mlir_c.mlirAffineFloorDivExprGet( - lhs::MlirAffineExpr, rhs::MlirAffineExpr - )::MlirAffineExpr + @ccall mlir_c.mlirAffineFloorDivExprGet(lhs::MlirAffineExpr, rhs::MlirAffineExpr)::MlirAffineExpr end """ @@ -3013,9 +2797,7 @@ end Creates an affine ceildiv expression with 'lhs' and 'rhs'. """ function mlirAffineCeilDivExprGet(lhs, rhs) - @ccall mlir_c.mlirAffineCeilDivExprGet( - lhs::MlirAffineExpr, rhs::MlirAffineExpr - )::MlirAffineExpr + @ccall mlir_c.mlirAffineCeilDivExprGet(lhs::MlirAffineExpr, rhs::MlirAffineExpr)::MlirAffineExpr end """ @@ -3078,9 +2860,7 @@ end Prints an affine map by sending chunks of the string representation and forwarding `userData to `callback`. Note that the callback may be called several times with consecutive chunks of the string. """ function mlirAffineMapPrint(affineMap, callback, userData) - @ccall mlir_c.mlirAffineMapPrint( - affineMap::MlirAffineMap, callback::MlirStringCallback, userData::Ptr{Cvoid} - )::Cvoid + @ccall mlir_c.mlirAffineMapPrint(affineMap::MlirAffineMap, callback::MlirStringCallback, userData::Ptr{Cvoid})::Cvoid end """ @@ -3107,9 +2887,7 @@ end Creates a zero result affine map of the given dimensions and symbols in the context. The affine map is owned by the context. """ function mlirAffineMapZeroResultGet(ctx, dimCount, symbolCount) - @ccall mlir_c.mlirAffineMapZeroResultGet( - ctx::MlirContext, dimCount::Cptrdiff_t, symbolCount::Cptrdiff_t - )::MlirAffineMap + @ccall mlir_c.mlirAffineMapZeroResultGet(ctx::MlirContext, dimCount::Cptrdiff_t, symbolCount::Cptrdiff_t)::MlirAffineMap end """ @@ -3118,13 +2896,7 @@ end Creates an affine map with results defined by the given list of affine expressions. The map resulting map also has the requested number of input dimensions and symbols, regardless of them being used in the results. """ function mlirAffineMapGet(ctx, dimCount, symbolCount, nAffineExprs, affineExprs) - @ccall mlir_c.mlirAffineMapGet( - ctx::MlirContext, - dimCount::Cptrdiff_t, - symbolCount::Cptrdiff_t, - nAffineExprs::Cptrdiff_t, - affineExprs::Ptr{MlirAffineExpr}, - )::MlirAffineMap + @ccall mlir_c.mlirAffineMapGet(ctx::MlirContext, dimCount::Cptrdiff_t, symbolCount::Cptrdiff_t, nAffineExprs::Cptrdiff_t, affineExprs::Ptr{MlirAffineExpr})::MlirAffineMap end """ @@ -3142,9 +2914,7 @@ end Creates an affine map with 'numDims' identity in the context. The affine map is owned by the context. """ function mlirAffineMapMultiDimIdentityGet(ctx, numDims) - @ccall mlir_c.mlirAffineMapMultiDimIdentityGet( - ctx::MlirContext, numDims::Cptrdiff_t - )::MlirAffineMap + @ccall mlir_c.mlirAffineMapMultiDimIdentityGet(ctx::MlirContext, numDims::Cptrdiff_t)::MlirAffineMap end """ @@ -3153,9 +2923,7 @@ end Creates an identity affine map on the most minor dimensions in the context. The affine map is owned by the context. The function asserts that the number of dimensions is greater or equal to the number of results. """ function mlirAffineMapMinorIdentityGet(ctx, dims, results) - @ccall mlir_c.mlirAffineMapMinorIdentityGet( - ctx::MlirContext, dims::Cptrdiff_t, results::Cptrdiff_t - )::MlirAffineMap + @ccall mlir_c.mlirAffineMapMinorIdentityGet(ctx::MlirContext, dims::Cptrdiff_t, results::Cptrdiff_t)::MlirAffineMap end """ @@ -3164,9 +2932,7 @@ end Creates an affine map with a permutation expression and its size in the context. The permutation expression is a non-empty vector of integers. The elements of the permutation vector must be continuous from 0 and cannot be repeated (i.e. `[1,2,0]` is a valid permutation. `[2,0]` or `[1,1,2]` is an invalid permutation.) The affine map is owned by the context. """ function mlirAffineMapPermutationGet(ctx, size, permutation) - @ccall mlir_c.mlirAffineMapPermutationGet( - ctx::MlirContext, size::Cptrdiff_t, permutation::Ptr{Cuint} - )::MlirAffineMap + @ccall mlir_c.mlirAffineMapPermutationGet(ctx::MlirContext, size::Cptrdiff_t, permutation::Ptr{Cuint})::MlirAffineMap end """ @@ -3247,9 +3013,7 @@ end Returns the result at the given position. """ function mlirAffineMapGetResult(affineMap, pos) - @ccall mlir_c.mlirAffineMapGetResult( - affineMap::MlirAffineMap, pos::Cptrdiff_t - )::MlirAffineExpr + @ccall mlir_c.mlirAffineMapGetResult(affineMap::MlirAffineMap, pos::Cptrdiff_t)::MlirAffineExpr end """ @@ -3285,9 +3049,7 @@ end Returns the affine map consisting of the `resultPos` subset. """ function mlirAffineMapGetSubMap(affineMap, size, resultPos) - @ccall mlir_c.mlirAffineMapGetSubMap( - affineMap::MlirAffineMap, size::Cptrdiff_t, resultPos::Ptr{Cptrdiff_t} - )::MlirAffineMap + @ccall mlir_c.mlirAffineMapGetSubMap(affineMap::MlirAffineMap, size::Cptrdiff_t, resultPos::Ptr{Cptrdiff_t})::MlirAffineMap end """ @@ -3296,9 +3058,7 @@ end Returns the affine map consisting of the most major `numResults` results. Returns the null AffineMap if the `numResults` is equal to zero. Returns the `affineMap` if `numResults` is greater or equals to number of results of the given affine map. """ function mlirAffineMapGetMajorSubMap(affineMap, numResults) - @ccall mlir_c.mlirAffineMapGetMajorSubMap( - affineMap::MlirAffineMap, numResults::Cptrdiff_t - )::MlirAffineMap + @ccall mlir_c.mlirAffineMapGetMajorSubMap(affineMap::MlirAffineMap, numResults::Cptrdiff_t)::MlirAffineMap end """ @@ -3307,9 +3067,7 @@ end Returns the affine map consisting of the most minor `numResults` results. Returns the null AffineMap if the `numResults` is equal to zero. Returns the `affineMap` if `numResults` is greater or equals to number of results of the given affine map. """ function mlirAffineMapGetMinorSubMap(affineMap, numResults) - @ccall mlir_c.mlirAffineMapGetMinorSubMap( - affineMap::MlirAffineMap, numResults::Cptrdiff_t - )::MlirAffineMap + @ccall mlir_c.mlirAffineMapGetMinorSubMap(affineMap::MlirAffineMap, numResults::Cptrdiff_t)::MlirAffineMap end """ @@ -3317,16 +3075,8 @@ end Apply AffineExpr::replace(`map`) to each of the results and return a new new AffineMap with the new results and the specified number of dims and symbols. """ -function mlirAffineMapReplace( - affineMap, expression, replacement, numResultDims, numResultSyms -) - @ccall mlir_c.mlirAffineMapReplace( - affineMap::MlirAffineMap, - expression::MlirAffineExpr, - replacement::MlirAffineExpr, - numResultDims::Cptrdiff_t, - numResultSyms::Cptrdiff_t, - )::MlirAffineMap +function mlirAffineMapReplace(affineMap, expression, replacement, numResultDims, numResultSyms) + @ccall mlir_c.mlirAffineMapReplace(affineMap::MlirAffineMap, expression::MlirAffineExpr, replacement::MlirAffineExpr, numResultDims::Cptrdiff_t, numResultSyms::Cptrdiff_t)::MlirAffineMap end """ @@ -3335,12 +3085,7 @@ end Returns the simplified affine map resulting from dropping the symbols that do not appear in any of the individual maps in `affineMaps`. Asserts that all maps in `affineMaps` are normalized to the same number of dims and symbols. Takes a callback `populateResult` to fill the `res` container with value `m` at entry `idx`. This allows returning without worrying about ownership considerations. """ function mlirAffineMapCompressUnusedSymbols(affineMaps, size, result, populateResult) - @ccall mlir_c.mlirAffineMapCompressUnusedSymbols( - affineMaps::Ptr{MlirAffineMap}, - size::Cptrdiff_t, - result::Ptr{Cvoid}, - populateResult::Ptr{Cvoid}, - )::Cvoid + @ccall mlir_c.mlirAffineMapCompressUnusedSymbols(affineMaps::Ptr{MlirAffineMap}, size::Cptrdiff_t, result::Ptr{Cvoid}, populateResult::Ptr{Cvoid})::Cvoid end struct MlirIntegerSet @@ -3380,9 +3125,7 @@ end Prints an integer set by sending chunks of the string representation and forwarding `userData to `callback`. Note that the callback may be called several times with consecutive chunks of the string. """ function mlirIntegerSetPrint(set, callback, userData) - @ccall mlir_c.mlirIntegerSetPrint( - set::MlirIntegerSet, callback::MlirStringCallback, userData::Ptr{Cvoid} - )::Cvoid + @ccall mlir_c.mlirIntegerSetPrint(set::MlirIntegerSet, callback::MlirStringCallback, userData::Ptr{Cvoid})::Cvoid end """ @@ -3400,9 +3143,7 @@ end Gets or creates a new canonically empty integer set with the give number of dimensions and symbols in the given context. """ function mlirIntegerSetEmptyGet(context, numDims, numSymbols) - @ccall mlir_c.mlirIntegerSetEmptyGet( - context::MlirContext, numDims::Cptrdiff_t, numSymbols::Cptrdiff_t - )::MlirIntegerSet + @ccall mlir_c.mlirIntegerSetEmptyGet(context::MlirContext, numDims::Cptrdiff_t, numSymbols::Cptrdiff_t)::MlirIntegerSet end """ @@ -3410,17 +3151,8 @@ end Gets or creates a new integer set in the given context. The set is defined by a list of affine constraints, with the given number of input dimensions and symbols, which are treated as either equalities (eqFlags is 1) or inequalities (eqFlags is 0). Both `constraints` and `eqFlags` are expected to point to at least `numConstraint` consecutive values. """ -function mlirIntegerSetGet( - context, numDims, numSymbols, numConstraints, constraints, eqFlags -) - @ccall mlir_c.mlirIntegerSetGet( - context::MlirContext, - numDims::Cptrdiff_t, - numSymbols::Cptrdiff_t, - numConstraints::Cptrdiff_t, - constraints::Ptr{MlirAffineExpr}, - eqFlags::Ptr{Bool}, - )::MlirIntegerSet +function mlirIntegerSetGet(context, numDims, numSymbols, numConstraints, constraints, eqFlags) + @ccall mlir_c.mlirIntegerSetGet(context::MlirContext, numDims::Cptrdiff_t, numSymbols::Cptrdiff_t, numConstraints::Cptrdiff_t, constraints::Ptr{MlirAffineExpr}, eqFlags::Ptr{Bool})::MlirIntegerSet end """ @@ -3428,16 +3160,8 @@ end Gets or creates a new integer set in which the values and dimensions of the given set are replaced with the given affine expressions. `dimReplacements` and `symbolReplacements` are expected to point to at least as many consecutive expressions as the given set has dimensions and symbols, respectively. The new set will have `numResultDims` and `numResultSymbols` dimensions and symbols, respectively. """ -function mlirIntegerSetReplaceGet( - set, dimReplacements, symbolReplacements, numResultDims, numResultSymbols -) - @ccall mlir_c.mlirIntegerSetReplaceGet( - set::MlirIntegerSet, - dimReplacements::Ptr{MlirAffineExpr}, - symbolReplacements::Ptr{MlirAffineExpr}, - numResultDims::Cptrdiff_t, - numResultSymbols::Cptrdiff_t, - )::MlirIntegerSet +function mlirIntegerSetReplaceGet(set, dimReplacements, symbolReplacements, numResultDims, numResultSymbols) + @ccall mlir_c.mlirIntegerSetReplaceGet(set::MlirIntegerSet, dimReplacements::Ptr{MlirAffineExpr}, symbolReplacements::Ptr{MlirAffineExpr}, numResultDims::Cptrdiff_t, numResultSymbols::Cptrdiff_t)::MlirIntegerSet end """ @@ -3509,9 +3233,7 @@ end Returns `pos`-th constraint of the set. """ function mlirIntegerSetGetConstraint(set, pos) - @ccall mlir_c.mlirIntegerSetGetConstraint( - set::MlirIntegerSet, pos::Cptrdiff_t - )::MlirAffineExpr + @ccall mlir_c.mlirIntegerSetGetConstraint(set::MlirIntegerSet, pos::Cptrdiff_t)::MlirAffineExpr end """ @@ -3587,9 +3309,7 @@ end Creates an array element containing the given list of elements in the given context. """ function mlirArrayAttrGet(ctx, numElements, elements) - @ccall mlir_c.mlirArrayAttrGet( - ctx::MlirContext, numElements::Cptrdiff_t, elements::Ptr{MlirAttribute} - )::MlirAttribute + @ccall mlir_c.mlirArrayAttrGet(ctx::MlirContext, numElements::Cptrdiff_t, elements::Ptr{MlirAttribute})::MlirAttribute end """ @@ -3607,9 +3327,7 @@ end Returns pos-th element stored in the given array attribute. """ function mlirArrayAttrGetElement(attr, pos) - @ccall mlir_c.mlirArrayAttrGetElement( - attr::MlirAttribute, pos::Cptrdiff_t - )::MlirAttribute + @ccall mlir_c.mlirArrayAttrGetElement(attr::MlirAttribute, pos::Cptrdiff_t)::MlirAttribute end """ @@ -3636,9 +3354,7 @@ end Creates a dictionary attribute containing the given list of elements in the provided context. """ function mlirDictionaryAttrGet(ctx, numElements, elements) - @ccall mlir_c.mlirDictionaryAttrGet( - ctx::MlirContext, numElements::Cptrdiff_t, elements::Ptr{MlirNamedAttribute} - )::MlirAttribute + @ccall mlir_c.mlirDictionaryAttrGet(ctx::MlirContext, numElements::Cptrdiff_t, elements::Ptr{MlirNamedAttribute})::MlirAttribute end """ @@ -3656,9 +3372,7 @@ end Returns pos-th element of the given dictionary attribute. """ function mlirDictionaryAttrGetElement(attr, pos) - @ccall mlir_c.mlirDictionaryAttrGetElement( - attr::MlirAttribute, pos::Cptrdiff_t - )::MlirNamedAttribute + @ccall mlir_c.mlirDictionaryAttrGetElement(attr::MlirAttribute, pos::Cptrdiff_t)::MlirNamedAttribute end """ @@ -3667,9 +3381,7 @@ end Returns the dictionary attribute element with the given name or NULL if the given name does not exist in the dictionary. """ function mlirDictionaryAttrGetElementByName(attr, name) - @ccall mlir_c.mlirDictionaryAttrGetElementByName( - attr::MlirAttribute, name::MlirStringRef - )::MlirAttribute + @ccall mlir_c.mlirDictionaryAttrGetElementByName(attr::MlirAttribute, name::MlirStringRef)::MlirAttribute end """ @@ -3696,9 +3408,7 @@ end Creates a floating point attribute in the given context with the given double value and double-precision FP semantics. """ function mlirFloatAttrDoubleGet(ctx, type, value) - @ccall mlir_c.mlirFloatAttrDoubleGet( - ctx::MlirContext, type::MlirType, value::Cdouble - )::MlirAttribute + @ccall mlir_c.mlirFloatAttrDoubleGet(ctx::MlirContext, type::MlirType, value::Cdouble)::MlirAttribute end """ @@ -3707,9 +3417,7 @@ end Same as "[`mlirFloatAttrDoubleGet`](@ref)", but if the type is not valid for a construction of a FloatAttr, returns a null [`MlirAttribute`](@ref). """ function mlirFloatAttrDoubleGetChecked(loc, type, value) - @ccall mlir_c.mlirFloatAttrDoubleGetChecked( - loc::MlirLocation, type::MlirType, value::Cdouble - )::MlirAttribute + @ccall mlir_c.mlirFloatAttrDoubleGetChecked(loc::MlirLocation, type::MlirType, value::Cdouble)::MlirAttribute end """ @@ -3862,13 +3570,7 @@ end Creates an opaque attribute in the given context associated with the dialect identified by its namespace. The attribute contains opaque byte data of the specified length (data need not be null-terminated). """ function mlirOpaqueAttrGet(ctx, dialectNamespace, dataLength, data, type) - @ccall mlir_c.mlirOpaqueAttrGet( - ctx::MlirContext, - dialectNamespace::MlirStringRef, - dataLength::Cptrdiff_t, - data::Cstring, - type::MlirType, - )::MlirAttribute + @ccall mlir_c.mlirOpaqueAttrGet(ctx::MlirContext, dialectNamespace::MlirStringRef, dataLength::Cptrdiff_t, data::Cstring, type::MlirType)::MlirAttribute end """ @@ -3958,12 +3660,7 @@ end Creates a symbol reference attribute in the given context referencing a symbol identified by the given string inside a list of nested references. Each of the references in the list must not be nested. """ function mlirSymbolRefAttrGet(ctx, symbol, numReferences, references) - @ccall mlir_c.mlirSymbolRefAttrGet( - ctx::MlirContext, - symbol::MlirStringRef, - numReferences::Cptrdiff_t, - references::Ptr{MlirAttribute}, - )::MlirAttribute + @ccall mlir_c.mlirSymbolRefAttrGet(ctx::MlirContext, symbol::MlirStringRef, numReferences::Cptrdiff_t, references::Ptr{MlirAttribute})::MlirAttribute end """ @@ -3999,9 +3696,7 @@ end Returns pos-th reference nested in the given symbol reference attribute. """ function mlirSymbolRefAttrGetNestedReference(attr, pos) - @ccall mlir_c.mlirSymbolRefAttrGetNestedReference( - attr::MlirAttribute, pos::Cptrdiff_t - )::MlirAttribute + @ccall mlir_c.mlirSymbolRefAttrGetNestedReference(attr::MlirAttribute, pos::Cptrdiff_t)::MlirAttribute end """ @@ -4037,9 +3732,7 @@ end Creates a flat symbol reference attribute in the given context referencing a symbol identified by the given string. """ function mlirFlatSymbolRefAttrGet(ctx, symbol) - @ccall mlir_c.mlirFlatSymbolRefAttrGet( - ctx::MlirContext, symbol::MlirStringRef - )::MlirAttribute + @ccall mlir_c.mlirFlatSymbolRefAttrGet(ctx::MlirContext, symbol::MlirStringRef)::MlirAttribute end """ @@ -4129,9 +3822,7 @@ end Returns the element at the given rank-dimensional index. """ function mlirElementsAttrGetValue(attr, rank, idxs) - @ccall mlir_c.mlirElementsAttrGetValue( - attr::MlirAttribute, rank::Cptrdiff_t, idxs::Ptr{UInt64} - )::MlirAttribute + @ccall mlir_c.mlirElementsAttrGetValue(attr::MlirAttribute, rank::Cptrdiff_t, idxs::Ptr{UInt64})::MlirAttribute end """ @@ -4140,9 +3831,7 @@ end Checks whether the given rank-dimensional index is valid in the given elements attribute. """ function mlirElementsAttrIsValidIndex(attr, rank, idxs) - @ccall mlir_c.mlirElementsAttrIsValidIndex( - attr::MlirAttribute, rank::Cptrdiff_t, idxs::Ptr{UInt64} - )::Bool + @ccall mlir_c.mlirElementsAttrIsValidIndex(attr::MlirAttribute, rank::Cptrdiff_t, idxs::Ptr{UInt64})::Bool end """ @@ -4197,45 +3886,31 @@ end Create a dense array attribute with the given elements. """ function mlirDenseBoolArrayGet(ctx, size, values) - @ccall mlir_c.mlirDenseBoolArrayGet( - ctx::MlirContext, size::Cptrdiff_t, values::Ptr{Cint} - )::MlirAttribute + @ccall mlir_c.mlirDenseBoolArrayGet(ctx::MlirContext, size::Cptrdiff_t, values::Ptr{Cint})::MlirAttribute end function mlirDenseI8ArrayGet(ctx, size, values) - @ccall mlir_c.mlirDenseI8ArrayGet( - ctx::MlirContext, size::Cptrdiff_t, values::Ptr{Int8} - )::MlirAttribute + @ccall mlir_c.mlirDenseI8ArrayGet(ctx::MlirContext, size::Cptrdiff_t, values::Ptr{Int8})::MlirAttribute end function mlirDenseI16ArrayGet(ctx, size, values) - @ccall mlir_c.mlirDenseI16ArrayGet( - ctx::MlirContext, size::Cptrdiff_t, values::Ptr{Int16} - )::MlirAttribute + @ccall mlir_c.mlirDenseI16ArrayGet(ctx::MlirContext, size::Cptrdiff_t, values::Ptr{Int16})::MlirAttribute end function mlirDenseI32ArrayGet(ctx, size, values) - @ccall mlir_c.mlirDenseI32ArrayGet( - ctx::MlirContext, size::Cptrdiff_t, values::Ptr{Int32} - )::MlirAttribute + @ccall mlir_c.mlirDenseI32ArrayGet(ctx::MlirContext, size::Cptrdiff_t, values::Ptr{Int32})::MlirAttribute end function mlirDenseI64ArrayGet(ctx, size, values) - @ccall mlir_c.mlirDenseI64ArrayGet( - ctx::MlirContext, size::Cptrdiff_t, values::Ptr{Int64} - )::MlirAttribute + @ccall mlir_c.mlirDenseI64ArrayGet(ctx::MlirContext, size::Cptrdiff_t, values::Ptr{Int64})::MlirAttribute end function mlirDenseF32ArrayGet(ctx, size, values) - @ccall mlir_c.mlirDenseF32ArrayGet( - ctx::MlirContext, size::Cptrdiff_t, values::Ptr{Cfloat} - )::MlirAttribute + @ccall mlir_c.mlirDenseF32ArrayGet(ctx::MlirContext, size::Cptrdiff_t, values::Ptr{Cfloat})::MlirAttribute end function mlirDenseF64ArrayGet(ctx, size, values) - @ccall mlir_c.mlirDenseF64ArrayGet( - ctx::MlirContext, size::Cptrdiff_t, values::Ptr{Cdouble} - )::MlirAttribute + @ccall mlir_c.mlirDenseF64ArrayGet(ctx::MlirContext, size::Cptrdiff_t, values::Ptr{Cdouble})::MlirAttribute end """ @@ -4312,9 +3987,7 @@ end Creates a dense elements attribute with the given Shaped type and elements in the same context as the type. """ function mlirDenseElementsAttrGet(shapedType, numElements, elements) - @ccall mlir_c.mlirDenseElementsAttrGet( - shapedType::MlirType, numElements::Cptrdiff_t, elements::Ptr{MlirAttribute} - )::MlirAttribute + @ccall mlir_c.mlirDenseElementsAttrGet(shapedType::MlirType, numElements::Cptrdiff_t, elements::Ptr{MlirAttribute})::MlirAttribute end """ @@ -4327,9 +4000,7 @@ The format of the raw buffer is a densely packed array of values that can be bit A raw buffer of a single element (or for 1-bit, a byte of value 0 or 255) will be interpreted as a splat. User code should be prepared for additional, conformant patterns to be identified as splats in the future. """ function mlirDenseElementsAttrRawBufferGet(shapedType, rawBufferSize, rawBuffer) - @ccall mlir_c.mlirDenseElementsAttrRawBufferGet( - shapedType::MlirType, rawBufferSize::Csize_t, rawBuffer::Ptr{Cvoid} - )::MlirAttribute + @ccall mlir_c.mlirDenseElementsAttrRawBufferGet(shapedType::MlirType, rawBufferSize::Csize_t, rawBuffer::Ptr{Cvoid})::MlirAttribute end """ @@ -4338,63 +4009,43 @@ end Creates a dense elements attribute with the given Shaped type containing a single replicated element (splat). """ function mlirDenseElementsAttrSplatGet(shapedType, element) - @ccall mlir_c.mlirDenseElementsAttrSplatGet( - shapedType::MlirType, element::MlirAttribute - )::MlirAttribute + @ccall mlir_c.mlirDenseElementsAttrSplatGet(shapedType::MlirType, element::MlirAttribute)::MlirAttribute end function mlirDenseElementsAttrBoolSplatGet(shapedType, element) - @ccall mlir_c.mlirDenseElementsAttrBoolSplatGet( - shapedType::MlirType, element::Bool - )::MlirAttribute + @ccall mlir_c.mlirDenseElementsAttrBoolSplatGet(shapedType::MlirType, element::Bool)::MlirAttribute end function mlirDenseElementsAttrUInt8SplatGet(shapedType, element) - @ccall mlir_c.mlirDenseElementsAttrUInt8SplatGet( - shapedType::MlirType, element::UInt8 - )::MlirAttribute + @ccall mlir_c.mlirDenseElementsAttrUInt8SplatGet(shapedType::MlirType, element::UInt8)::MlirAttribute end function mlirDenseElementsAttrInt8SplatGet(shapedType, element) - @ccall mlir_c.mlirDenseElementsAttrInt8SplatGet( - shapedType::MlirType, element::Int8 - )::MlirAttribute + @ccall mlir_c.mlirDenseElementsAttrInt8SplatGet(shapedType::MlirType, element::Int8)::MlirAttribute end function mlirDenseElementsAttrUInt32SplatGet(shapedType, element) - @ccall mlir_c.mlirDenseElementsAttrUInt32SplatGet( - shapedType::MlirType, element::UInt32 - )::MlirAttribute + @ccall mlir_c.mlirDenseElementsAttrUInt32SplatGet(shapedType::MlirType, element::UInt32)::MlirAttribute end function mlirDenseElementsAttrInt32SplatGet(shapedType, element) - @ccall mlir_c.mlirDenseElementsAttrInt32SplatGet( - shapedType::MlirType, element::Int32 - )::MlirAttribute + @ccall mlir_c.mlirDenseElementsAttrInt32SplatGet(shapedType::MlirType, element::Int32)::MlirAttribute end function mlirDenseElementsAttrUInt64SplatGet(shapedType, element) - @ccall mlir_c.mlirDenseElementsAttrUInt64SplatGet( - shapedType::MlirType, element::UInt64 - )::MlirAttribute + @ccall mlir_c.mlirDenseElementsAttrUInt64SplatGet(shapedType::MlirType, element::UInt64)::MlirAttribute end function mlirDenseElementsAttrInt64SplatGet(shapedType, element) - @ccall mlir_c.mlirDenseElementsAttrInt64SplatGet( - shapedType::MlirType, element::Int64 - )::MlirAttribute + @ccall mlir_c.mlirDenseElementsAttrInt64SplatGet(shapedType::MlirType, element::Int64)::MlirAttribute end function mlirDenseElementsAttrFloatSplatGet(shapedType, element) - @ccall mlir_c.mlirDenseElementsAttrFloatSplatGet( - shapedType::MlirType, element::Cfloat - )::MlirAttribute + @ccall mlir_c.mlirDenseElementsAttrFloatSplatGet(shapedType::MlirType, element::Cfloat)::MlirAttribute end function mlirDenseElementsAttrDoubleSplatGet(shapedType, element) - @ccall mlir_c.mlirDenseElementsAttrDoubleSplatGet( - shapedType::MlirType, element::Cdouble - )::MlirAttribute + @ccall mlir_c.mlirDenseElementsAttrDoubleSplatGet(shapedType::MlirType, element::Cdouble)::MlirAttribute end """ @@ -4403,81 +4054,55 @@ end Creates a dense elements attribute with the given shaped type from elements of a specific type. Expects the element type of the shaped type to match the data element type. """ function mlirDenseElementsAttrBoolGet(shapedType, numElements, elements) - @ccall mlir_c.mlirDenseElementsAttrBoolGet( - shapedType::MlirType, numElements::Cptrdiff_t, elements::Ptr{Cint} - )::MlirAttribute + @ccall mlir_c.mlirDenseElementsAttrBoolGet(shapedType::MlirType, numElements::Cptrdiff_t, elements::Ptr{Cint})::MlirAttribute end function mlirDenseElementsAttrUInt8Get(shapedType, numElements, elements) - @ccall mlir_c.mlirDenseElementsAttrUInt8Get( - shapedType::MlirType, numElements::Cptrdiff_t, elements::Ptr{UInt8} - )::MlirAttribute + @ccall mlir_c.mlirDenseElementsAttrUInt8Get(shapedType::MlirType, numElements::Cptrdiff_t, elements::Ptr{UInt8})::MlirAttribute end function mlirDenseElementsAttrInt8Get(shapedType, numElements, elements) - @ccall mlir_c.mlirDenseElementsAttrInt8Get( - shapedType::MlirType, numElements::Cptrdiff_t, elements::Ptr{Int8} - )::MlirAttribute + @ccall mlir_c.mlirDenseElementsAttrInt8Get(shapedType::MlirType, numElements::Cptrdiff_t, elements::Ptr{Int8})::MlirAttribute end function mlirDenseElementsAttrUInt16Get(shapedType, numElements, elements) - @ccall mlir_c.mlirDenseElementsAttrUInt16Get( - shapedType::MlirType, numElements::Cptrdiff_t, elements::Ptr{UInt16} - )::MlirAttribute + @ccall mlir_c.mlirDenseElementsAttrUInt16Get(shapedType::MlirType, numElements::Cptrdiff_t, elements::Ptr{UInt16})::MlirAttribute end function mlirDenseElementsAttrInt16Get(shapedType, numElements, elements) - @ccall mlir_c.mlirDenseElementsAttrInt16Get( - shapedType::MlirType, numElements::Cptrdiff_t, elements::Ptr{Int16} - )::MlirAttribute + @ccall mlir_c.mlirDenseElementsAttrInt16Get(shapedType::MlirType, numElements::Cptrdiff_t, elements::Ptr{Int16})::MlirAttribute end function mlirDenseElementsAttrUInt32Get(shapedType, numElements, elements) - @ccall mlir_c.mlirDenseElementsAttrUInt32Get( - shapedType::MlirType, numElements::Cptrdiff_t, elements::Ptr{UInt32} - )::MlirAttribute + @ccall mlir_c.mlirDenseElementsAttrUInt32Get(shapedType::MlirType, numElements::Cptrdiff_t, elements::Ptr{UInt32})::MlirAttribute end function mlirDenseElementsAttrInt32Get(shapedType, numElements, elements) - @ccall mlir_c.mlirDenseElementsAttrInt32Get( - shapedType::MlirType, numElements::Cptrdiff_t, elements::Ptr{Int32} - )::MlirAttribute + @ccall mlir_c.mlirDenseElementsAttrInt32Get(shapedType::MlirType, numElements::Cptrdiff_t, elements::Ptr{Int32})::MlirAttribute end function mlirDenseElementsAttrUInt64Get(shapedType, numElements, elements) - @ccall mlir_c.mlirDenseElementsAttrUInt64Get( - shapedType::MlirType, numElements::Cptrdiff_t, elements::Ptr{UInt64} - )::MlirAttribute + @ccall mlir_c.mlirDenseElementsAttrUInt64Get(shapedType::MlirType, numElements::Cptrdiff_t, elements::Ptr{UInt64})::MlirAttribute end function mlirDenseElementsAttrInt64Get(shapedType, numElements, elements) - @ccall mlir_c.mlirDenseElementsAttrInt64Get( - shapedType::MlirType, numElements::Cptrdiff_t, elements::Ptr{Int64} - )::MlirAttribute + @ccall mlir_c.mlirDenseElementsAttrInt64Get(shapedType::MlirType, numElements::Cptrdiff_t, elements::Ptr{Int64})::MlirAttribute end function mlirDenseElementsAttrFloatGet(shapedType, numElements, elements) - @ccall mlir_c.mlirDenseElementsAttrFloatGet( - shapedType::MlirType, numElements::Cptrdiff_t, elements::Ptr{Cfloat} - )::MlirAttribute + @ccall mlir_c.mlirDenseElementsAttrFloatGet(shapedType::MlirType, numElements::Cptrdiff_t, elements::Ptr{Cfloat})::MlirAttribute end function mlirDenseElementsAttrDoubleGet(shapedType, numElements, elements) - @ccall mlir_c.mlirDenseElementsAttrDoubleGet( - shapedType::MlirType, numElements::Cptrdiff_t, elements::Ptr{Cdouble} - )::MlirAttribute + @ccall mlir_c.mlirDenseElementsAttrDoubleGet(shapedType::MlirType, numElements::Cptrdiff_t, elements::Ptr{Cdouble})::MlirAttribute end function mlirDenseElementsAttrBFloat16Get(shapedType, numElements, elements) - @ccall mlir_c.mlirDenseElementsAttrBFloat16Get( - shapedType::MlirType, numElements::Cptrdiff_t, elements::Ptr{UInt16} - )::MlirAttribute + @ccall mlir_c.mlirDenseElementsAttrBFloat16Get(shapedType::MlirType, numElements::Cptrdiff_t, elements::Ptr{UInt16})::MlirAttribute end function mlirDenseElementsAttrFloat16Get(shapedType, numElements, elements) - @ccall mlir_c.mlirDenseElementsAttrFloat16Get( - shapedType::MlirType, numElements::Cptrdiff_t, elements::Ptr{UInt16} - )::MlirAttribute + @ccall mlir_c.mlirDenseElementsAttrFloat16Get(shapedType::MlirType, numElements::Cptrdiff_t, elements::Ptr{UInt16})::MlirAttribute end """ @@ -4486,9 +4111,7 @@ end Creates a dense elements attribute with the given shaped type from string elements. """ function mlirDenseElementsAttrStringGet(shapedType, numElements, strs) - @ccall mlir_c.mlirDenseElementsAttrStringGet( - shapedType::MlirType, numElements::Cptrdiff_t, strs::Ptr{MlirStringRef} - )::MlirAttribute + @ccall mlir_c.mlirDenseElementsAttrStringGet(shapedType::MlirType, numElements::Cptrdiff_t, strs::Ptr{MlirStringRef})::MlirAttribute end """ @@ -4497,9 +4120,7 @@ end Creates a dense elements attribute that has the same data as the given dense elements attribute and a different shaped type. The new type must have the same total number of elements. """ function mlirDenseElementsAttrReshapeGet(attr, shapedType) - @ccall mlir_c.mlirDenseElementsAttrReshapeGet( - attr::MlirAttribute, shapedType::MlirType - )::MlirAttribute + @ccall mlir_c.mlirDenseElementsAttrReshapeGet(attr::MlirAttribute, shapedType::MlirType)::MlirAttribute end """ @@ -4557,9 +4178,7 @@ function mlirDenseElementsAttrGetDoubleSplatValue(attr) end function mlirDenseElementsAttrGetStringSplatValue(attr) - @ccall mlir_c.mlirDenseElementsAttrGetStringSplatValue( - attr::MlirAttribute - )::MlirStringRef + @ccall mlir_c.mlirDenseElementsAttrGetStringSplatValue(attr::MlirAttribute)::MlirStringRef end """ @@ -4568,81 +4187,55 @@ end Returns the pos-th value (flat contiguous indexing) of a specific type contained by the given dense elements attribute. """ function mlirDenseElementsAttrGetBoolValue(attr, pos) - @ccall mlir_c.mlirDenseElementsAttrGetBoolValue( - attr::MlirAttribute, pos::Cptrdiff_t - )::Bool + @ccall mlir_c.mlirDenseElementsAttrGetBoolValue(attr::MlirAttribute, pos::Cptrdiff_t)::Bool end function mlirDenseElementsAttrGetInt8Value(attr, pos) - @ccall mlir_c.mlirDenseElementsAttrGetInt8Value( - attr::MlirAttribute, pos::Cptrdiff_t - )::Int8 + @ccall mlir_c.mlirDenseElementsAttrGetInt8Value(attr::MlirAttribute, pos::Cptrdiff_t)::Int8 end function mlirDenseElementsAttrGetUInt8Value(attr, pos) - @ccall mlir_c.mlirDenseElementsAttrGetUInt8Value( - attr::MlirAttribute, pos::Cptrdiff_t - )::UInt8 + @ccall mlir_c.mlirDenseElementsAttrGetUInt8Value(attr::MlirAttribute, pos::Cptrdiff_t)::UInt8 end function mlirDenseElementsAttrGetInt16Value(attr, pos) - @ccall mlir_c.mlirDenseElementsAttrGetInt16Value( - attr::MlirAttribute, pos::Cptrdiff_t - )::Int16 + @ccall mlir_c.mlirDenseElementsAttrGetInt16Value(attr::MlirAttribute, pos::Cptrdiff_t)::Int16 end function mlirDenseElementsAttrGetUInt16Value(attr, pos) - @ccall mlir_c.mlirDenseElementsAttrGetUInt16Value( - attr::MlirAttribute, pos::Cptrdiff_t - )::UInt16 + @ccall mlir_c.mlirDenseElementsAttrGetUInt16Value(attr::MlirAttribute, pos::Cptrdiff_t)::UInt16 end function mlirDenseElementsAttrGetInt32Value(attr, pos) - @ccall mlir_c.mlirDenseElementsAttrGetInt32Value( - attr::MlirAttribute, pos::Cptrdiff_t - )::Int32 + @ccall mlir_c.mlirDenseElementsAttrGetInt32Value(attr::MlirAttribute, pos::Cptrdiff_t)::Int32 end function mlirDenseElementsAttrGetUInt32Value(attr, pos) - @ccall mlir_c.mlirDenseElementsAttrGetUInt32Value( - attr::MlirAttribute, pos::Cptrdiff_t - )::UInt32 + @ccall mlir_c.mlirDenseElementsAttrGetUInt32Value(attr::MlirAttribute, pos::Cptrdiff_t)::UInt32 end function mlirDenseElementsAttrGetInt64Value(attr, pos) - @ccall mlir_c.mlirDenseElementsAttrGetInt64Value( - attr::MlirAttribute, pos::Cptrdiff_t - )::Int64 + @ccall mlir_c.mlirDenseElementsAttrGetInt64Value(attr::MlirAttribute, pos::Cptrdiff_t)::Int64 end function mlirDenseElementsAttrGetUInt64Value(attr, pos) - @ccall mlir_c.mlirDenseElementsAttrGetUInt64Value( - attr::MlirAttribute, pos::Cptrdiff_t - )::UInt64 + @ccall mlir_c.mlirDenseElementsAttrGetUInt64Value(attr::MlirAttribute, pos::Cptrdiff_t)::UInt64 end function mlirDenseElementsAttrGetIndexValue(attr, pos) - @ccall mlir_c.mlirDenseElementsAttrGetIndexValue( - attr::MlirAttribute, pos::Cptrdiff_t - )::UInt64 + @ccall mlir_c.mlirDenseElementsAttrGetIndexValue(attr::MlirAttribute, pos::Cptrdiff_t)::UInt64 end function mlirDenseElementsAttrGetFloatValue(attr, pos) - @ccall mlir_c.mlirDenseElementsAttrGetFloatValue( - attr::MlirAttribute, pos::Cptrdiff_t - )::Cfloat + @ccall mlir_c.mlirDenseElementsAttrGetFloatValue(attr::MlirAttribute, pos::Cptrdiff_t)::Cfloat end function mlirDenseElementsAttrGetDoubleValue(attr, pos) - @ccall mlir_c.mlirDenseElementsAttrGetDoubleValue( - attr::MlirAttribute, pos::Cptrdiff_t - )::Cdouble + @ccall mlir_c.mlirDenseElementsAttrGetDoubleValue(attr::MlirAttribute, pos::Cptrdiff_t)::Cdouble end function mlirDenseElementsAttrGetStringValue(attr, pos) - @ccall mlir_c.mlirDenseElementsAttrGetStringValue( - attr::MlirAttribute, pos::Cptrdiff_t - )::MlirStringRef + @ccall mlir_c.mlirDenseElementsAttrGetStringValue(attr::MlirAttribute, pos::Cptrdiff_t)::MlirStringRef end """ @@ -4663,140 +4256,52 @@ end Unlike the typed accessors below, constructs the attribute with a raw data buffer and no type/alignment checking. Use a more strongly typed accessor if possible. If dataIsMutable is false, then an immutable AsmResourceBlob will be created and that passed data contents will be treated as const. If the deleter is non NULL, then it will be called when the data buffer can no longer be accessed (passing userData to it). """ -function mlirUnmanagedDenseResourceElementsAttrGet( - shapedType, name, data, dataLength, dataAlignment, dataIsMutable, deleter, userData -) - @ccall mlir_c.mlirUnmanagedDenseResourceElementsAttrGet( - shapedType::MlirType, - name::MlirStringRef, - data::Ptr{Cvoid}, - dataLength::Csize_t, - dataAlignment::Csize_t, - dataIsMutable::Bool, - deleter::Ptr{Cvoid}, - userData::Ptr{Cvoid}, - )::MlirAttribute +function mlirUnmanagedDenseResourceElementsAttrGet(shapedType, name, data, dataLength, dataAlignment, dataIsMutable, deleter, userData) + @ccall mlir_c.mlirUnmanagedDenseResourceElementsAttrGet(shapedType::MlirType, name::MlirStringRef, data::Ptr{Cvoid}, dataLength::Csize_t, dataAlignment::Csize_t, dataIsMutable::Bool, deleter::Ptr{Cvoid}, userData::Ptr{Cvoid})::MlirAttribute end -function mlirUnmanagedDenseBoolResourceElementsAttrGet( - shapedType, name, numElements, elements -) - @ccall mlir_c.mlirUnmanagedDenseBoolResourceElementsAttrGet( - shapedType::MlirType, - name::MlirStringRef, - numElements::Cptrdiff_t, - elements::Ptr{Cint}, - )::MlirAttribute +function mlirUnmanagedDenseBoolResourceElementsAttrGet(shapedType, name, numElements, elements) + @ccall mlir_c.mlirUnmanagedDenseBoolResourceElementsAttrGet(shapedType::MlirType, name::MlirStringRef, numElements::Cptrdiff_t, elements::Ptr{Cint})::MlirAttribute end -function mlirUnmanagedDenseUInt8ResourceElementsAttrGet( - shapedType, name, numElements, elements -) - @ccall mlir_c.mlirUnmanagedDenseUInt8ResourceElementsAttrGet( - shapedType::MlirType, - name::MlirStringRef, - numElements::Cptrdiff_t, - elements::Ptr{UInt8}, - )::MlirAttribute +function mlirUnmanagedDenseUInt8ResourceElementsAttrGet(shapedType, name, numElements, elements) + @ccall mlir_c.mlirUnmanagedDenseUInt8ResourceElementsAttrGet(shapedType::MlirType, name::MlirStringRef, numElements::Cptrdiff_t, elements::Ptr{UInt8})::MlirAttribute end -function mlirUnmanagedDenseInt8ResourceElementsAttrGet( - shapedType, name, numElements, elements -) - @ccall mlir_c.mlirUnmanagedDenseInt8ResourceElementsAttrGet( - shapedType::MlirType, - name::MlirStringRef, - numElements::Cptrdiff_t, - elements::Ptr{Int8}, - )::MlirAttribute +function mlirUnmanagedDenseInt8ResourceElementsAttrGet(shapedType, name, numElements, elements) + @ccall mlir_c.mlirUnmanagedDenseInt8ResourceElementsAttrGet(shapedType::MlirType, name::MlirStringRef, numElements::Cptrdiff_t, elements::Ptr{Int8})::MlirAttribute end -function mlirUnmanagedDenseUInt16ResourceElementsAttrGet( - shapedType, name, numElements, elements -) - @ccall mlir_c.mlirUnmanagedDenseUInt16ResourceElementsAttrGet( - shapedType::MlirType, - name::MlirStringRef, - numElements::Cptrdiff_t, - elements::Ptr{UInt16}, - )::MlirAttribute +function mlirUnmanagedDenseUInt16ResourceElementsAttrGet(shapedType, name, numElements, elements) + @ccall mlir_c.mlirUnmanagedDenseUInt16ResourceElementsAttrGet(shapedType::MlirType, name::MlirStringRef, numElements::Cptrdiff_t, elements::Ptr{UInt16})::MlirAttribute end -function mlirUnmanagedDenseInt16ResourceElementsAttrGet( - shapedType, name, numElements, elements -) - @ccall mlir_c.mlirUnmanagedDenseInt16ResourceElementsAttrGet( - shapedType::MlirType, - name::MlirStringRef, - numElements::Cptrdiff_t, - elements::Ptr{Int16}, - )::MlirAttribute +function mlirUnmanagedDenseInt16ResourceElementsAttrGet(shapedType, name, numElements, elements) + @ccall mlir_c.mlirUnmanagedDenseInt16ResourceElementsAttrGet(shapedType::MlirType, name::MlirStringRef, numElements::Cptrdiff_t, elements::Ptr{Int16})::MlirAttribute end -function mlirUnmanagedDenseUInt32ResourceElementsAttrGet( - shapedType, name, numElements, elements -) - @ccall mlir_c.mlirUnmanagedDenseUInt32ResourceElementsAttrGet( - shapedType::MlirType, - name::MlirStringRef, - numElements::Cptrdiff_t, - elements::Ptr{UInt32}, - )::MlirAttribute +function mlirUnmanagedDenseUInt32ResourceElementsAttrGet(shapedType, name, numElements, elements) + @ccall mlir_c.mlirUnmanagedDenseUInt32ResourceElementsAttrGet(shapedType::MlirType, name::MlirStringRef, numElements::Cptrdiff_t, elements::Ptr{UInt32})::MlirAttribute end -function mlirUnmanagedDenseInt32ResourceElementsAttrGet( - shapedType, name, numElements, elements -) - @ccall mlir_c.mlirUnmanagedDenseInt32ResourceElementsAttrGet( - shapedType::MlirType, - name::MlirStringRef, - numElements::Cptrdiff_t, - elements::Ptr{Int32}, - )::MlirAttribute +function mlirUnmanagedDenseInt32ResourceElementsAttrGet(shapedType, name, numElements, elements) + @ccall mlir_c.mlirUnmanagedDenseInt32ResourceElementsAttrGet(shapedType::MlirType, name::MlirStringRef, numElements::Cptrdiff_t, elements::Ptr{Int32})::MlirAttribute end -function mlirUnmanagedDenseUInt64ResourceElementsAttrGet( - shapedType, name, numElements, elements -) - @ccall mlir_c.mlirUnmanagedDenseUInt64ResourceElementsAttrGet( - shapedType::MlirType, - name::MlirStringRef, - numElements::Cptrdiff_t, - elements::Ptr{UInt64}, - )::MlirAttribute +function mlirUnmanagedDenseUInt64ResourceElementsAttrGet(shapedType, name, numElements, elements) + @ccall mlir_c.mlirUnmanagedDenseUInt64ResourceElementsAttrGet(shapedType::MlirType, name::MlirStringRef, numElements::Cptrdiff_t, elements::Ptr{UInt64})::MlirAttribute end -function mlirUnmanagedDenseInt64ResourceElementsAttrGet( - shapedType, name, numElements, elements -) - @ccall mlir_c.mlirUnmanagedDenseInt64ResourceElementsAttrGet( - shapedType::MlirType, - name::MlirStringRef, - numElements::Cptrdiff_t, - elements::Ptr{Int64}, - )::MlirAttribute +function mlirUnmanagedDenseInt64ResourceElementsAttrGet(shapedType, name, numElements, elements) + @ccall mlir_c.mlirUnmanagedDenseInt64ResourceElementsAttrGet(shapedType::MlirType, name::MlirStringRef, numElements::Cptrdiff_t, elements::Ptr{Int64})::MlirAttribute end -function mlirUnmanagedDenseFloatResourceElementsAttrGet( - shapedType, name, numElements, elements -) - @ccall mlir_c.mlirUnmanagedDenseFloatResourceElementsAttrGet( - shapedType::MlirType, - name::MlirStringRef, - numElements::Cptrdiff_t, - elements::Ptr{Cfloat}, - )::MlirAttribute +function mlirUnmanagedDenseFloatResourceElementsAttrGet(shapedType, name, numElements, elements) + @ccall mlir_c.mlirUnmanagedDenseFloatResourceElementsAttrGet(shapedType::MlirType, name::MlirStringRef, numElements::Cptrdiff_t, elements::Ptr{Cfloat})::MlirAttribute end -function mlirUnmanagedDenseDoubleResourceElementsAttrGet( - shapedType, name, numElements, elements -) - @ccall mlir_c.mlirUnmanagedDenseDoubleResourceElementsAttrGet( - shapedType::MlirType, - name::MlirStringRef, - numElements::Cptrdiff_t, - elements::Ptr{Cdouble}, - )::MlirAttribute +function mlirUnmanagedDenseDoubleResourceElementsAttrGet(shapedType, name, numElements, elements) + @ccall mlir_c.mlirUnmanagedDenseDoubleResourceElementsAttrGet(shapedType::MlirType, name::MlirStringRef, numElements::Cptrdiff_t, elements::Ptr{Cdouble})::MlirAttribute end """ @@ -4805,69 +4310,47 @@ end Returns the pos-th value (flat contiguous indexing) of a specific type contained by the given dense resource elements attribute. """ function mlirDenseBoolResourceElementsAttrGetValue(attr, pos) - @ccall mlir_c.mlirDenseBoolResourceElementsAttrGetValue( - attr::MlirAttribute, pos::Cptrdiff_t - )::Bool + @ccall mlir_c.mlirDenseBoolResourceElementsAttrGetValue(attr::MlirAttribute, pos::Cptrdiff_t)::Bool end function mlirDenseInt8ResourceElementsAttrGetValue(attr, pos) - @ccall mlir_c.mlirDenseInt8ResourceElementsAttrGetValue( - attr::MlirAttribute, pos::Cptrdiff_t - )::Int8 + @ccall mlir_c.mlirDenseInt8ResourceElementsAttrGetValue(attr::MlirAttribute, pos::Cptrdiff_t)::Int8 end function mlirDenseUInt8ResourceElementsAttrGetValue(attr, pos) - @ccall mlir_c.mlirDenseUInt8ResourceElementsAttrGetValue( - attr::MlirAttribute, pos::Cptrdiff_t - )::UInt8 + @ccall mlir_c.mlirDenseUInt8ResourceElementsAttrGetValue(attr::MlirAttribute, pos::Cptrdiff_t)::UInt8 end function mlirDenseInt16ResourceElementsAttrGetValue(attr, pos) - @ccall mlir_c.mlirDenseInt16ResourceElementsAttrGetValue( - attr::MlirAttribute, pos::Cptrdiff_t - )::Int16 + @ccall mlir_c.mlirDenseInt16ResourceElementsAttrGetValue(attr::MlirAttribute, pos::Cptrdiff_t)::Int16 end function mlirDenseUInt16ResourceElementsAttrGetValue(attr, pos) - @ccall mlir_c.mlirDenseUInt16ResourceElementsAttrGetValue( - attr::MlirAttribute, pos::Cptrdiff_t - )::UInt16 + @ccall mlir_c.mlirDenseUInt16ResourceElementsAttrGetValue(attr::MlirAttribute, pos::Cptrdiff_t)::UInt16 end function mlirDenseInt32ResourceElementsAttrGetValue(attr, pos) - @ccall mlir_c.mlirDenseInt32ResourceElementsAttrGetValue( - attr::MlirAttribute, pos::Cptrdiff_t - )::Int32 + @ccall mlir_c.mlirDenseInt32ResourceElementsAttrGetValue(attr::MlirAttribute, pos::Cptrdiff_t)::Int32 end function mlirDenseUInt32ResourceElementsAttrGetValue(attr, pos) - @ccall mlir_c.mlirDenseUInt32ResourceElementsAttrGetValue( - attr::MlirAttribute, pos::Cptrdiff_t - )::UInt32 + @ccall mlir_c.mlirDenseUInt32ResourceElementsAttrGetValue(attr::MlirAttribute, pos::Cptrdiff_t)::UInt32 end function mlirDenseInt64ResourceElementsAttrGetValue(attr, pos) - @ccall mlir_c.mlirDenseInt64ResourceElementsAttrGetValue( - attr::MlirAttribute, pos::Cptrdiff_t - )::Int64 + @ccall mlir_c.mlirDenseInt64ResourceElementsAttrGetValue(attr::MlirAttribute, pos::Cptrdiff_t)::Int64 end function mlirDenseUInt64ResourceElementsAttrGetValue(attr, pos) - @ccall mlir_c.mlirDenseUInt64ResourceElementsAttrGetValue( - attr::MlirAttribute, pos::Cptrdiff_t - )::UInt64 + @ccall mlir_c.mlirDenseUInt64ResourceElementsAttrGetValue(attr::MlirAttribute, pos::Cptrdiff_t)::UInt64 end function mlirDenseFloatResourceElementsAttrGetValue(attr, pos) - @ccall mlir_c.mlirDenseFloatResourceElementsAttrGetValue( - attr::MlirAttribute, pos::Cptrdiff_t - )::Cfloat + @ccall mlir_c.mlirDenseFloatResourceElementsAttrGetValue(attr::MlirAttribute, pos::Cptrdiff_t)::Cfloat end function mlirDenseDoubleResourceElementsAttrGetValue(attr, pos) - @ccall mlir_c.mlirDenseDoubleResourceElementsAttrGetValue( - attr::MlirAttribute, pos::Cptrdiff_t - )::Cdouble + @ccall mlir_c.mlirDenseDoubleResourceElementsAttrGetValue(attr::MlirAttribute, pos::Cptrdiff_t)::Cdouble end """ @@ -4885,9 +4368,7 @@ end Creates a sparse elements attribute of the given shape from a list of indices and a list of associated values. Both lists are expected to be dense elements attributes with the same number of elements. The list of indices is expected to contain 64-bit integers. The attribute is created in the same context as the type. """ function mlirSparseElementsAttribute(shapedType, denseIndices, denseValues) - @ccall mlir_c.mlirSparseElementsAttribute( - shapedType::MlirType, denseIndices::MlirAttribute, denseValues::MlirAttribute - )::MlirAttribute + @ccall mlir_c.mlirSparseElementsAttribute(shapedType::MlirType, denseIndices::MlirAttribute, denseValues::MlirAttribute)::MlirAttribute end """ @@ -4922,9 +4403,7 @@ function mlirAttributeIsAStridedLayout(attr) end function mlirStridedLayoutAttrGet(ctx, offset, numStrides, strides) - @ccall mlir_c.mlirStridedLayoutAttrGet( - ctx::MlirContext, offset::Int64, numStrides::Cptrdiff_t, strides::Ptr{Int64} - )::MlirAttribute + @ccall mlir_c.mlirStridedLayoutAttrGet(ctx::MlirContext, offset::Int64, numStrides::Cptrdiff_t, strides::Ptr{Int64})::MlirAttribute end function mlirStridedLayoutAttrGetOffset(attr) @@ -4936,9 +4415,7 @@ function mlirStridedLayoutAttrGetNumStrides(attr) end function mlirStridedLayoutAttrGetStride(attr, pos) - @ccall mlir_c.mlirStridedLayoutAttrGetStride( - attr::MlirAttribute, pos::Cptrdiff_t - )::Int64 + @ccall mlir_c.mlirStridedLayoutAttrGetStride(attr::MlirAttribute, pos::Cptrdiff_t)::Int64 end """ @@ -5721,9 +5198,7 @@ end Creates a vector type of the shape identified by its rank and dimensions, with the given element type in the same context as the element type. The type is owned by the context. """ function mlirVectorTypeGet(rank, shape, elementType) - @ccall mlir_c.mlirVectorTypeGet( - rank::Cptrdiff_t, shape::Ptr{Int64}, elementType::MlirType - )::MlirType + @ccall mlir_c.mlirVectorTypeGet(rank::Cptrdiff_t, shape::Ptr{Int64}, elementType::MlirType)::MlirType end """ @@ -5732,9 +5207,7 @@ end Same as "[`mlirVectorTypeGet`](@ref)" but returns a nullptr wrapping [`MlirType`](@ref) on illegal arguments, emitting appropriate diagnostics. """ function mlirVectorTypeGetChecked(loc, rank, shape, elementType) - @ccall mlir_c.mlirVectorTypeGetChecked( - loc::MlirLocation, rank::Cptrdiff_t, shape::Ptr{Int64}, elementType::MlirType - )::MlirType + @ccall mlir_c.mlirVectorTypeGetChecked(loc::MlirLocation, rank::Cptrdiff_t, shape::Ptr{Int64}, elementType::MlirType)::MlirType end """ @@ -5743,9 +5216,7 @@ end Creates a scalable vector type with the shape identified by its rank and dimensions. A subset of dimensions may be marked as scalable via the corresponding flag list, which is expected to have as many entries as the rank of the vector. The vector is created in the same context as the element type. """ function mlirVectorTypeGetScalable(rank, shape, scalable, elementType) - @ccall mlir_c.mlirVectorTypeGetScalable( - rank::Cptrdiff_t, shape::Ptr{Int64}, scalable::Ptr{Bool}, elementType::MlirType - )::MlirType + @ccall mlir_c.mlirVectorTypeGetScalable(rank::Cptrdiff_t, shape::Ptr{Int64}, scalable::Ptr{Bool}, elementType::MlirType)::MlirType end """ @@ -5754,13 +5225,7 @@ end Same as "[`mlirVectorTypeGetScalable`](@ref)" but returns a nullptr wrapping [`MlirType`](@ref) on illegal arguments, emitting appropriate diagnostics. """ function mlirVectorTypeGetScalableChecked(loc, rank, shape, scalable, elementType) - @ccall mlir_c.mlirVectorTypeGetScalableChecked( - loc::MlirLocation, - rank::Cptrdiff_t, - shape::Ptr{Int64}, - scalable::Ptr{Bool}, - elementType::MlirType, - )::MlirType + @ccall mlir_c.mlirVectorTypeGetScalableChecked(loc::MlirLocation, rank::Cptrdiff_t, shape::Ptr{Int64}, scalable::Ptr{Bool}, elementType::MlirType)::MlirType end """ @@ -5832,9 +5297,7 @@ end Creates a tensor type of a fixed rank with the given shape, element type, and optional encoding in the same context as the element type. The type is owned by the context. Tensor types without any specific encoding field should assign [`mlirAttributeGetNull`](@ref)() to this parameter. """ function mlirRankedTensorTypeGet(rank, shape, elementType, encoding) - @ccall mlir_c.mlirRankedTensorTypeGet( - rank::Cptrdiff_t, shape::Ptr{Int64}, elementType::MlirType, encoding::MlirAttribute - )::MlirType + @ccall mlir_c.mlirRankedTensorTypeGet(rank::Cptrdiff_t, shape::Ptr{Int64}, elementType::MlirType, encoding::MlirAttribute)::MlirType end """ @@ -5843,13 +5306,7 @@ end Same as "[`mlirRankedTensorTypeGet`](@ref)" but returns a nullptr wrapping [`MlirType`](@ref) on illegal arguments, emitting appropriate diagnostics. """ function mlirRankedTensorTypeGetChecked(loc, rank, shape, elementType, encoding) - @ccall mlir_c.mlirRankedTensorTypeGetChecked( - loc::MlirLocation, - rank::Cptrdiff_t, - shape::Ptr{Int64}, - elementType::MlirType, - encoding::MlirAttribute, - )::MlirType + @ccall mlir_c.mlirRankedTensorTypeGetChecked(loc::MlirLocation, rank::Cptrdiff_t, shape::Ptr{Int64}, elementType::MlirType, encoding::MlirAttribute)::MlirType end """ @@ -5876,9 +5333,7 @@ end Same as "[`mlirUnrankedTensorTypeGet`](@ref)" but returns a nullptr wrapping [`MlirType`](@ref) on illegal arguments, emitting appropriate diagnostics. """ function mlirUnrankedTensorTypeGetChecked(loc, elementType) - @ccall mlir_c.mlirUnrankedTensorTypeGetChecked( - loc::MlirLocation, elementType::MlirType - )::MlirType + @ccall mlir_c.mlirUnrankedTensorTypeGetChecked(loc::MlirLocation, elementType::MlirType)::MlirType end """ @@ -5923,13 +5378,7 @@ end Creates a MemRef type with the given rank and shape, a potentially empty list of affine layout maps, the given memory space and element type, in the same context as element type. The type is owned by the context. """ function mlirMemRefTypeGet(elementType, rank, shape, layout, memorySpace) - @ccall mlir_c.mlirMemRefTypeGet( - elementType::MlirType, - rank::Cptrdiff_t, - shape::Ptr{Int64}, - layout::MlirAttribute, - memorySpace::MlirAttribute, - )::MlirType + @ccall mlir_c.mlirMemRefTypeGet(elementType::MlirType, rank::Cptrdiff_t, shape::Ptr{Int64}, layout::MlirAttribute, memorySpace::MlirAttribute)::MlirType end """ @@ -5938,14 +5387,7 @@ end Same as "[`mlirMemRefTypeGet`](@ref)" but returns a nullptr-wrapping [`MlirType`](@ref) o illegal arguments, emitting appropriate diagnostics. """ function mlirMemRefTypeGetChecked(loc, elementType, rank, shape, layout, memorySpace) - @ccall mlir_c.mlirMemRefTypeGetChecked( - loc::MlirLocation, - elementType::MlirType, - rank::Cptrdiff_t, - shape::Ptr{Int64}, - layout::MlirAttribute, - memorySpace::MlirAttribute, - )::MlirType + @ccall mlir_c.mlirMemRefTypeGetChecked(loc::MlirLocation, elementType::MlirType, rank::Cptrdiff_t, shape::Ptr{Int64}, layout::MlirAttribute, memorySpace::MlirAttribute)::MlirType end """ @@ -5954,12 +5396,7 @@ end Creates a MemRef type with the given rank, shape, memory space and element type in the same context as the element type. The type has no affine maps, i.e. represents a default row-major contiguous memref. The type is owned by the context. """ function mlirMemRefTypeContiguousGet(elementType, rank, shape, memorySpace) - @ccall mlir_c.mlirMemRefTypeContiguousGet( - elementType::MlirType, - rank::Cptrdiff_t, - shape::Ptr{Int64}, - memorySpace::MlirAttribute, - )::MlirType + @ccall mlir_c.mlirMemRefTypeContiguousGet(elementType::MlirType, rank::Cptrdiff_t, shape::Ptr{Int64}, memorySpace::MlirAttribute)::MlirType end """ @@ -5968,13 +5405,7 @@ end Same as "[`mlirMemRefTypeContiguousGet`](@ref)" but returns a nullptr wrapping [`MlirType`](@ref) on illegal arguments, emitting appropriate diagnostics. """ function mlirMemRefTypeContiguousGetChecked(loc, elementType, rank, shape, memorySpace) - @ccall mlir_c.mlirMemRefTypeContiguousGetChecked( - loc::MlirLocation, - elementType::MlirType, - rank::Cptrdiff_t, - shape::Ptr{Int64}, - memorySpace::MlirAttribute, - )::MlirType + @ccall mlir_c.mlirMemRefTypeContiguousGetChecked(loc::MlirLocation, elementType::MlirType, rank::Cptrdiff_t, shape::Ptr{Int64}, memorySpace::MlirAttribute)::MlirType end """ @@ -5983,9 +5414,7 @@ end Creates an Unranked MemRef type with the given element type and in the given memory space. The type is owned by the context of element type. """ function mlirUnrankedMemRefTypeGet(elementType, memorySpace) - @ccall mlir_c.mlirUnrankedMemRefTypeGet( - elementType::MlirType, memorySpace::MlirAttribute - )::MlirType + @ccall mlir_c.mlirUnrankedMemRefTypeGet(elementType::MlirType, memorySpace::MlirAttribute)::MlirType end """ @@ -5994,9 +5423,7 @@ end Same as "[`mlirUnrankedMemRefTypeGet`](@ref)" but returns a nullptr wrapping [`MlirType`](@ref) on illegal arguments, emitting appropriate diagnostics. """ function mlirUnrankedMemRefTypeGetChecked(loc, elementType, memorySpace) - @ccall mlir_c.mlirUnrankedMemRefTypeGetChecked( - loc::MlirLocation, elementType::MlirType, memorySpace::MlirAttribute - )::MlirType + @ccall mlir_c.mlirUnrankedMemRefTypeGetChecked(loc::MlirLocation, elementType::MlirType, memorySpace::MlirAttribute)::MlirType end """ @@ -6032,9 +5459,7 @@ end Returns the strides of the MemRef if the layout map is in strided form. Both strides and offset are out params. strides must point to pre-allocated memory of length equal to the rank of the memref. """ function mlirMemRefTypeGetStridesAndOffset(type, strides, offset) - @ccall mlir_c.mlirMemRefTypeGetStridesAndOffset( - type::MlirType, strides::Ptr{Int64}, offset::Ptr{Int64} - )::MlirLogicalResult + @ccall mlir_c.mlirMemRefTypeGetStridesAndOffset(type::MlirType, strides::Ptr{Int64}, offset::Ptr{Int64})::MlirLogicalResult end """ @@ -6070,9 +5495,7 @@ end Creates a tuple type that consists of the given list of elemental types. The type is owned by the context. """ function mlirTupleTypeGet(ctx, numElements, elements) - @ccall mlir_c.mlirTupleTypeGet( - ctx::MlirContext, numElements::Cptrdiff_t, elements::Ptr{MlirType} - )::MlirType + @ccall mlir_c.mlirTupleTypeGet(ctx::MlirContext, numElements::Cptrdiff_t, elements::Ptr{MlirType})::MlirType end """ @@ -6117,13 +5540,7 @@ end Creates a function type, mapping a list of input types to result types. """ function mlirFunctionTypeGet(ctx, numInputs, inputs, numResults, results) - @ccall mlir_c.mlirFunctionTypeGet( - ctx::MlirContext, - numInputs::Cptrdiff_t, - inputs::Ptr{MlirType}, - numResults::Cptrdiff_t, - results::Ptr{MlirType}, - )::MlirType + @ccall mlir_c.mlirFunctionTypeGet(ctx::MlirContext, numInputs::Cptrdiff_t, inputs::Ptr{MlirType}, numResults::Cptrdiff_t, results::Ptr{MlirType})::MlirType end """ @@ -6186,9 +5603,7 @@ end Creates an opaque type in the given context associated with the dialect identified by its namespace. The type contains opaque byte data of the specified length (data need not be null-terminated). """ function mlirOpaqueTypeGet(ctx, dialectNamespace, typeData) - @ccall mlir_c.mlirOpaqueTypeGet( - ctx::MlirContext, dialectNamespace::MlirStringRef, typeData::MlirStringRef - )::MlirType + @ccall mlir_c.mlirOpaqueTypeGet(ctx::MlirContext, dialectNamespace::MlirStringRef, typeData::MlirStringRef)::MlirType end """ @@ -6292,9 +5707,7 @@ const MlirDiagnosticHandler = Ptr{Cvoid} Prints a diagnostic using the provided callback. """ function mlirDiagnosticPrint(diagnostic, callback, userData) - @ccall mlir_c.mlirDiagnosticPrint( - diagnostic::MlirDiagnostic, callback::MlirStringCallback, userData::Ptr{Cvoid} - )::Cvoid + @ccall mlir_c.mlirDiagnosticPrint(diagnostic::MlirDiagnostic, callback::MlirStringCallback, userData::Ptr{Cvoid})::Cvoid end """ @@ -6312,9 +5725,7 @@ end Returns the severity of the diagnostic. """ function mlirDiagnosticGetSeverity(diagnostic) - @ccall mlir_c.mlirDiagnosticGetSeverity( - diagnostic::MlirDiagnostic - )::MlirDiagnosticSeverity + @ccall mlir_c.mlirDiagnosticGetSeverity(diagnostic::MlirDiagnostic)::MlirDiagnosticSeverity end """ @@ -6332,9 +5743,7 @@ end Returns `pos`-th note attached to the diagnostic. Expects `pos` to be a valid zero-based index into the list of notes. """ function mlirDiagnosticGetNote(diagnostic, pos) - @ccall mlir_c.mlirDiagnosticGetNote( - diagnostic::MlirDiagnostic, pos::Cptrdiff_t - )::MlirDiagnostic + @ccall mlir_c.mlirDiagnosticGetNote(diagnostic::MlirDiagnostic, pos::Cptrdiff_t)::MlirDiagnostic end """ @@ -6343,12 +5752,7 @@ end Attaches the diagnostic handler to the context. Handlers are invoked in the reverse order of attachment until one of them processes the diagnostic completely. When a handler is invoked it is passed the `userData` that was provided when it was attached. If non-NULL, `deleteUserData` is called once the system no longer needs to call the handler (for instance after the handler is detached or the context is destroyed). Returns an identifier that can be used to detach the handler. """ function mlirContextAttachDiagnosticHandler(context, handler, userData, deleteUserData) - @ccall mlir_c.mlirContextAttachDiagnosticHandler( - context::MlirContext, - handler::MlirDiagnosticHandler, - userData::Ptr{Cvoid}, - deleteUserData::Ptr{Cvoid}, - )::MlirDiagnosticHandlerID + @ccall mlir_c.mlirContextAttachDiagnosticHandler(context::MlirContext, handler::MlirDiagnosticHandler, userData::Ptr{Cvoid}, deleteUserData::Ptr{Cvoid})::MlirDiagnosticHandlerID end """ @@ -6357,9 +5761,7 @@ end Detaches an attached diagnostic handler from the context given its identifier. """ function mlirContextDetachDiagnosticHandler(context, id) - @ccall mlir_c.mlirContextDetachDiagnosticHandler( - context::MlirContext, id::MlirDiagnosticHandlerID - )::Cvoid + @ccall mlir_c.mlirContextDetachDiagnosticHandler(context::MlirContext, id::MlirDiagnosticHandlerID)::Cvoid end """ @@ -6410,9 +5812,7 @@ function mlirEmitCArrayTypeGetTypeID() end function mlirEmitCArrayTypeGet(nDims, shape, elementType) - @ccall mlir_c.mlirEmitCArrayTypeGet( - nDims::Cptrdiff_t, shape::Ptr{Int64}, elementType::MlirType - )::MlirType + @ccall mlir_c.mlirEmitCArrayTypeGet(nDims::Cptrdiff_t, shape::Ptr{Int64}, elementType::MlirType)::MlirType end function mlirTypeIsAEmitCLValueType(type) @@ -6492,15 +5892,11 @@ function mlirAttributeIsAEmitCCmpPredicate(attr) end function mlirEmitCCmpPredicateAttrGet(ctx, val) - @ccall mlir_c.mlirEmitCCmpPredicateAttrGet( - ctx::MlirContext, val::MlirEmitCCmpPredicate - )::MlirAttribute + @ccall mlir_c.mlirEmitCCmpPredicateAttrGet(ctx::MlirContext, val::MlirEmitCCmpPredicate)::MlirAttribute end function mlirEmitCCmpPredicateAttrGetValue(attr) - @ccall mlir_c.mlirEmitCCmpPredicateAttrGetValue( - attr::MlirAttribute - )::MlirEmitCCmpPredicate + @ccall mlir_c.mlirEmitCCmpPredicateAttrGetValue(attr::MlirAttribute)::MlirEmitCCmpPredicate end function mlirEmitCCmpPredicateAttrGetTypeID() @@ -6512,9 +5908,7 @@ function mlirAttributeIsAEmitCOpaque(attr) end function mlirEmitCOpaqueAttrGet(ctx, value) - @ccall mlir_c.mlirEmitCOpaqueAttrGet( - ctx::MlirContext, value::MlirStringRef - )::MlirAttribute + @ccall mlir_c.mlirEmitCOpaqueAttrGet(ctx::MlirContext, value::MlirStringRef)::MlirAttribute end function mlirEmitCOpaqueAttrGetValue(attr) @@ -6535,15 +5929,11 @@ end Sets the argument attribute 'name' of an argument at index 'pos'. Asserts that the operation is a FuncOp. """ function mlirFuncSetArgAttr(op, pos, name, attr) - @ccall mlir_c.mlirFuncSetArgAttr( - op::MlirOperation, pos::Cptrdiff_t, name::MlirStringRef, attr::MlirAttribute - )::Cvoid + @ccall mlir_c.mlirFuncSetArgAttr(op::MlirOperation, pos::Cptrdiff_t, name::MlirStringRef, attr::MlirAttribute)::Cvoid end function mlirFuncSetResultAttr(op, pos, name, attr) - @ccall mlir_c.mlirFuncSetResultAttr( - op::MlirOperation, pos::Cptrdiff_t, name::MlirStringRef, attr::MlirAttribute - )::Cvoid + @ccall mlir_c.mlirFuncSetResultAttr(op::MlirOperation, pos::Cptrdiff_t, name::MlirStringRef, attr::MlirAttribute)::Cvoid end function mlirGetDialectHandle__gpu__() @@ -6563,26 +5953,11 @@ function mlirAttributeIsAGPUObjectAttr(attr) end function mlirGPUObjectAttrGet(mlirCtx, target, format, objectStrRef, mlirObjectProps) - @ccall mlir_c.mlirGPUObjectAttrGet( - mlirCtx::MlirContext, - target::MlirAttribute, - format::UInt32, - objectStrRef::MlirStringRef, - mlirObjectProps::MlirAttribute, - )::MlirAttribute + @ccall mlir_c.mlirGPUObjectAttrGet(mlirCtx::MlirContext, target::MlirAttribute, format::UInt32, objectStrRef::MlirStringRef, mlirObjectProps::MlirAttribute)::MlirAttribute end -function mlirGPUObjectAttrGetWithKernels( - mlirCtx, target, format, objectStrRef, mlirObjectProps, mlirKernelsAttr -) - @ccall mlir_c.mlirGPUObjectAttrGetWithKernels( - mlirCtx::MlirContext, - target::MlirAttribute, - format::UInt32, - objectStrRef::MlirStringRef, - mlirObjectProps::MlirAttribute, - mlirKernelsAttr::MlirAttribute, - )::MlirAttribute +function mlirGPUObjectAttrGetWithKernels(mlirCtx, target, format, objectStrRef, mlirObjectProps, mlirKernelsAttr) + @ccall mlir_c.mlirGPUObjectAttrGetWithKernels(mlirCtx::MlirContext, target::MlirAttribute, format::UInt32, objectStrRef::MlirStringRef, mlirObjectProps::MlirAttribute, mlirKernelsAttr::MlirAttribute)::MlirAttribute end function mlirGPUObjectAttrGetTarget(mlirObjectAttr) @@ -6602,9 +5977,7 @@ function mlirGPUObjectAttrHasProperties(mlirObjectAttr) end function mlirGPUObjectAttrGetProperties(mlirObjectAttr) - @ccall mlir_c.mlirGPUObjectAttrGetProperties( - mlirObjectAttr::MlirAttribute - )::MlirAttribute + @ccall mlir_c.mlirGPUObjectAttrGetProperties(mlirObjectAttr::MlirAttribute)::MlirAttribute end function mlirGPUObjectAttrHasKernels(mlirObjectAttr) @@ -6696,12 +6069,7 @@ end Creates an llvm.func type. """ function mlirLLVMFunctionTypeGet(resultType, nArgumentTypes, argumentTypes, isVarArg) - @ccall mlir_c.mlirLLVMFunctionTypeGet( - resultType::MlirType, - nArgumentTypes::Cptrdiff_t, - argumentTypes::Ptr{MlirType}, - isVarArg::Bool, - )::MlirType + @ccall mlir_c.mlirLLVMFunctionTypeGet(resultType::MlirType, nArgumentTypes::Cptrdiff_t, argumentTypes::Ptr{MlirType}, isVarArg::Bool)::MlirType end """ @@ -6764,9 +6132,7 @@ end Returns the `positions`-th field of the struct. Asserts if the struct is opaque, not yet initialized or if the position is out of range. """ function mlirLLVMStructTypeGetElementType(type, position) - @ccall mlir_c.mlirLLVMStructTypeGetElementType( - type::MlirType, position::Cptrdiff_t - )::MlirType + @ccall mlir_c.mlirLLVMStructTypeGetElementType(type::MlirType, position::Cptrdiff_t)::MlirType end """ @@ -6802,9 +6168,7 @@ end Creates an LLVM literal (unnamed) struct type. This may assert if the fields have types not compatible with the LLVM dialect. For a graceful failure, use the checked version. """ function mlirLLVMStructTypeLiteralGet(ctx, nFieldTypes, fieldTypes, isPacked) - @ccall mlir_c.mlirLLVMStructTypeLiteralGet( - ctx::MlirContext, nFieldTypes::Cptrdiff_t, fieldTypes::Ptr{MlirType}, isPacked::Bool - )::MlirType + @ccall mlir_c.mlirLLVMStructTypeLiteralGet(ctx::MlirContext, nFieldTypes::Cptrdiff_t, fieldTypes::Ptr{MlirType}, isPacked::Bool)::MlirType end """ @@ -6813,12 +6177,7 @@ end Creates an LLVM literal (unnamed) struct type if possible. Emits a diagnostic at the given location and returns null otherwise. """ function mlirLLVMStructTypeLiteralGetChecked(loc, nFieldTypes, fieldTypes, isPacked) - @ccall mlir_c.mlirLLVMStructTypeLiteralGetChecked( - loc::MlirLocation, - nFieldTypes::Cptrdiff_t, - fieldTypes::Ptr{MlirType}, - isPacked::Bool, - )::MlirType + @ccall mlir_c.mlirLLVMStructTypeLiteralGetChecked(loc::MlirLocation, nFieldTypes::Cptrdiff_t, fieldTypes::Ptr{MlirType}, isPacked::Bool)::MlirType end """ @@ -6827,9 +6186,7 @@ end Creates an LLVM identified struct type with no body. If a struct type with this name already exists in the context, returns that type. Use [`mlirLLVMStructTypeIdentifiedNewGet`](@ref) to create a fresh struct type, potentially renaming it. The body should be set separatelty by calling [`mlirLLVMStructTypeSetBody`](@ref), if it isn't set already. """ function mlirLLVMStructTypeIdentifiedGet(ctx, name) - @ccall mlir_c.mlirLLVMStructTypeIdentifiedGet( - ctx::MlirContext, name::MlirStringRef - )::MlirType + @ccall mlir_c.mlirLLVMStructTypeIdentifiedGet(ctx::MlirContext, name::MlirStringRef)::MlirType end """ @@ -6838,19 +6195,11 @@ end Creates an LLVM identified struct type with no body and a name starting with the given prefix. If a struct with the exact name as the given prefix already exists, appends an unspecified suffix to the name so that the name is unique in context. """ function mlirLLVMStructTypeIdentifiedNewGet(ctx, name, nFieldTypes, fieldTypes, isPacked) - @ccall mlir_c.mlirLLVMStructTypeIdentifiedNewGet( - ctx::MlirContext, - name::MlirStringRef, - nFieldTypes::Cptrdiff_t, - fieldTypes::Ptr{MlirType}, - isPacked::Bool, - )::MlirType + @ccall mlir_c.mlirLLVMStructTypeIdentifiedNewGet(ctx::MlirContext, name::MlirStringRef, nFieldTypes::Cptrdiff_t, fieldTypes::Ptr{MlirType}, isPacked::Bool)::MlirType end function mlirLLVMStructTypeOpaqueGet(ctx, name) - @ccall mlir_c.mlirLLVMStructTypeOpaqueGet( - ctx::MlirContext, name::MlirStringRef - )::MlirType + @ccall mlir_c.mlirLLVMStructTypeOpaqueGet(ctx::MlirContext, name::MlirStringRef)::MlirType end """ @@ -6859,12 +6208,7 @@ end Sets the body of the identified struct if it hasn't been set yet. Returns whether the operation was successful. """ function mlirLLVMStructTypeSetBody(structType, nFieldTypes, fieldTypes, isPacked) - @ccall mlir_c.mlirLLVMStructTypeSetBody( - structType::MlirType, - nFieldTypes::Cptrdiff_t, - fieldTypes::Ptr{MlirType}, - isPacked::Bool, - )::MlirLogicalResult + @ccall mlir_c.mlirLLVMStructTypeSetBody(structType::MlirType, nFieldTypes::Cptrdiff_t, fieldTypes::Ptr{MlirType}, isPacked::Bool)::MlirLogicalResult end @cenum MlirLLVMCConv::UInt32 begin @@ -6923,9 +6267,7 @@ end Creates a LLVM CConv attribute. """ function mlirLLVMCConvAttrGet(ctx, cconv) - @ccall mlir_c.mlirLLVMCConvAttrGet( - ctx::MlirContext, cconv::MlirLLVMCConv - )::MlirAttribute + @ccall mlir_c.mlirLLVMCConvAttrGet(ctx::MlirContext, cconv::MlirLLVMCConv)::MlirAttribute end @cenum MlirLLVMComdat::UInt32 begin @@ -6942,9 +6284,7 @@ end Creates a LLVM Comdat attribute. """ function mlirLLVMComdatAttrGet(ctx, comdat) - @ccall mlir_c.mlirLLVMComdatAttrGet( - ctx::MlirContext, comdat::MlirLLVMComdat - )::MlirAttribute + @ccall mlir_c.mlirLLVMComdatAttrGet(ctx::MlirContext, comdat::MlirLLVMComdat)::MlirAttribute end @cenum MlirLLVMLinkage::UInt32 begin @@ -6967,9 +6307,7 @@ end Creates a LLVM Linkage attribute. """ function mlirLLVMLinkageAttrGet(ctx, linkage) - @ccall mlir_c.mlirLLVMLinkageAttrGet( - ctx::MlirContext, linkage::MlirLLVMLinkage - )::MlirAttribute + @ccall mlir_c.mlirLLVMLinkageAttrGet(ctx::MlirContext, linkage::MlirLLVMLinkage)::MlirAttribute end """ @@ -6987,9 +6325,7 @@ end Creates a LLVM DIExpressionElem attribute. """ function mlirLLVMDIExpressionElemAttrGet(ctx, opcode, nArguments, arguments) - @ccall mlir_c.mlirLLVMDIExpressionElemAttrGet( - ctx::MlirContext, opcode::Cuint, nArguments::Cptrdiff_t, arguments::Ptr{UInt64} - )::MlirAttribute + @ccall mlir_c.mlirLLVMDIExpressionElemAttrGet(ctx::MlirContext, opcode::Cuint, nArguments::Cptrdiff_t, arguments::Ptr{UInt64})::MlirAttribute end """ @@ -6998,9 +6334,7 @@ end Creates a LLVM DIExpression attribute. """ function mlirLLVMDIExpressionAttrGet(ctx, nOperations, operations) - @ccall mlir_c.mlirLLVMDIExpressionAttrGet( - ctx::MlirContext, nOperations::Cptrdiff_t, operations::Ptr{MlirAttribute} - )::MlirAttribute + @ccall mlir_c.mlirLLVMDIExpressionAttrGet(ctx::MlirContext, nOperations::Cptrdiff_t, operations::Ptr{MlirAttribute})::MlirAttribute end @cenum MlirLLVMTypeEncoding::UInt32 begin @@ -7032,13 +6366,7 @@ end Creates a LLVM DIBasicType attribute. """ function mlirLLVMDIBasicTypeAttrGet(ctx, tag, name, sizeInBits, encoding) - @ccall mlir_c.mlirLLVMDIBasicTypeAttrGet( - ctx::MlirContext, - tag::Cuint, - name::MlirAttribute, - sizeInBits::UInt64, - encoding::MlirLLVMTypeEncoding, - )::MlirAttribute + @ccall mlir_c.mlirLLVMDIBasicTypeAttrGet(ctx::MlirContext, tag::Cuint, name::MlirAttribute, sizeInBits::UInt64, encoding::MlirLLVMTypeEncoding)::MlirAttribute end """ @@ -7055,46 +6383,8 @@ end Creates a LLVM DICompositeType attribute. """ -function mlirLLVMDICompositeTypeAttrGet( - ctx, - recId, - isRecSelf, - tag, - name, - file, - line, - scope, - baseType, - flags, - sizeInBits, - alignInBits, - nElements, - elements, - dataLocation, - rank, - allocated, - associated, -) - @ccall mlir_c.mlirLLVMDICompositeTypeAttrGet( - ctx::MlirContext, - recId::MlirAttribute, - isRecSelf::Bool, - tag::Cuint, - name::MlirAttribute, - file::MlirAttribute, - line::UInt32, - scope::MlirAttribute, - baseType::MlirAttribute, - flags::Int64, - sizeInBits::UInt64, - alignInBits::UInt64, - nElements::Cptrdiff_t, - elements::Ptr{MlirAttribute}, - dataLocation::MlirAttribute, - rank::MlirAttribute, - allocated::MlirAttribute, - associated::MlirAttribute, - )::MlirAttribute +function mlirLLVMDICompositeTypeAttrGet(ctx, recId, isRecSelf, tag, name, file, line, scope, baseType, flags, sizeInBits, alignInBits, nElements, elements, dataLocation, rank, allocated, associated) + @ccall mlir_c.mlirLLVMDICompositeTypeAttrGet(ctx::MlirContext, recId::MlirAttribute, isRecSelf::Bool, tag::Cuint, name::MlirAttribute, file::MlirAttribute, line::UInt32, scope::MlirAttribute, baseType::MlirAttribute, flags::Int64, sizeInBits::UInt64, alignInBits::UInt64, nElements::Cptrdiff_t, elements::Ptr{MlirAttribute}, dataLocation::MlirAttribute, rank::MlirAttribute, allocated::MlirAttribute, associated::MlirAttribute)::MlirAttribute end """ @@ -7102,52 +6392,12 @@ end Creates a LLVM DIDerivedType attribute. Note that `dwarfAddressSpace` is an optional field, where [`MLIR_CAPI_DWARF_ADDRESS_SPACE_NULL`](@ref) indicates null and non-negative values indicate a value present. """ -function mlirLLVMDIDerivedTypeAttrGet( - ctx, - tag, - name, - baseType, - sizeInBits, - alignInBits, - offsetInBits, - dwarfAddressSpace, - extraData, -) - @ccall mlir_c.mlirLLVMDIDerivedTypeAttrGet( - ctx::MlirContext, - tag::Cuint, - name::MlirAttribute, - baseType::MlirAttribute, - sizeInBits::UInt64, - alignInBits::UInt32, - offsetInBits::UInt64, - dwarfAddressSpace::Int64, - extraData::MlirAttribute, - )::MlirAttribute +function mlirLLVMDIDerivedTypeAttrGet(ctx, tag, name, baseType, sizeInBits, alignInBits, offsetInBits, dwarfAddressSpace, extraData) + @ccall mlir_c.mlirLLVMDIDerivedTypeAttrGet(ctx::MlirContext, tag::Cuint, name::MlirAttribute, baseType::MlirAttribute, sizeInBits::UInt64, alignInBits::UInt32, offsetInBits::UInt64, dwarfAddressSpace::Int64, extraData::MlirAttribute)::MlirAttribute end -function mlirLLVMDIStringTypeAttrGet( - ctx, - tag, - name, - sizeInBits, - alignInBits, - stringLength, - stringLengthExp, - stringLocationExp, - encoding, -) - @ccall mlir_c.mlirLLVMDIStringTypeAttrGet( - ctx::MlirContext, - tag::Cuint, - name::MlirAttribute, - sizeInBits::UInt64, - alignInBits::UInt32, - stringLength::MlirAttribute, - stringLengthExp::MlirAttribute, - stringLocationExp::MlirAttribute, - encoding::MlirLLVMTypeEncoding, - )::MlirAttribute +function mlirLLVMDIStringTypeAttrGet(ctx, tag, name, sizeInBits, alignInBits, stringLength, stringLengthExp, stringLocationExp, encoding) + @ccall mlir_c.mlirLLVMDIStringTypeAttrGet(ctx::MlirContext, tag::Cuint, name::MlirAttribute, sizeInBits::UInt64, alignInBits::UInt32, stringLength::MlirAttribute, stringLengthExp::MlirAttribute, stringLocationExp::MlirAttribute, encoding::MlirLLVMTypeEncoding)::MlirAttribute end """ @@ -7156,9 +6406,7 @@ end Gets the base type from a LLVM DIDerivedType attribute. """ function mlirLLVMDIDerivedTypeAttrGetBaseType(diDerivedType) - @ccall mlir_c.mlirLLVMDIDerivedTypeAttrGetBaseType( - diDerivedType::MlirAttribute - )::MlirAttribute + @ccall mlir_c.mlirLLVMDIDerivedTypeAttrGetBaseType(diDerivedType::MlirAttribute)::MlirAttribute end """ @@ -7167,9 +6415,7 @@ end Creates a LLVM DIFileAttr attribute. """ function mlirLLVMDIFileAttrGet(ctx, name, directory) - @ccall mlir_c.mlirLLVMDIFileAttrGet( - ctx::MlirContext, name::MlirAttribute, directory::MlirAttribute - )::MlirAttribute + @ccall mlir_c.mlirLLVMDIFileAttrGet(ctx::MlirContext, name::MlirAttribute, directory::MlirAttribute)::MlirAttribute end @cenum MlirLLVMDIEmissionKind::UInt32 begin @@ -7230,13 +6476,7 @@ end Creates a LLVM DILexicalBlock attribute. """ function mlirLLVMDILexicalBlockAttrGet(ctx, scope, file, line, column) - @ccall mlir_c.mlirLLVMDILexicalBlockAttrGet( - ctx::MlirContext, - scope::MlirAttribute, - file::MlirAttribute, - line::Cuint, - column::Cuint, - )::MlirAttribute + @ccall mlir_c.mlirLLVMDILexicalBlockAttrGet(ctx::MlirContext, scope::MlirAttribute, file::MlirAttribute, line::Cuint, column::Cuint)::MlirAttribute end """ @@ -7245,9 +6485,7 @@ end Creates a LLVM DILexicalBlockFile attribute. """ function mlirLLVMDILexicalBlockFileAttrGet(ctx, scope, file, discriminator) - @ccall mlir_c.mlirLLVMDILexicalBlockFileAttrGet( - ctx::MlirContext, scope::MlirAttribute, file::MlirAttribute, discriminator::Cuint - )::MlirAttribute + @ccall mlir_c.mlirLLVMDILexicalBlockFileAttrGet(ctx::MlirContext, scope::MlirAttribute, file::MlirAttribute, discriminator::Cuint)::MlirAttribute end """ @@ -7255,20 +6493,8 @@ end Creates a LLVM DILocalVariableAttr attribute. """ -function mlirLLVMDILocalVariableAttrGet( - ctx, scope, name, diFile, line, arg, alignInBits, diType, flags -) - @ccall mlir_c.mlirLLVMDILocalVariableAttrGet( - ctx::MlirContext, - scope::MlirAttribute, - name::MlirAttribute, - diFile::MlirAttribute, - line::Cuint, - arg::Cuint, - alignInBits::Cuint, - diType::MlirAttribute, - flags::Int64, - )::MlirAttribute +function mlirLLVMDILocalVariableAttrGet(ctx, scope, name, diFile, line, arg, alignInBits, diType, flags) + @ccall mlir_c.mlirLLVMDILocalVariableAttrGet(ctx::MlirContext, scope::MlirAttribute, name::MlirAttribute, diFile::MlirAttribute, line::Cuint, arg::Cuint, alignInBits::Cuint, diType::MlirAttribute, flags::Int64)::MlirAttribute end """ @@ -7285,44 +6511,8 @@ end Creates a LLVM DISubprogramAttr attribute. """ -function mlirLLVMDISubprogramAttrGet( - ctx, - recId, - isRecSelf, - id, - compileUnit, - scope, - name, - linkageName, - file, - line, - scopeLine, - subprogramFlags, - type, - nRetainedNodes, - retainedNodes, - nAnnotations, - annotations, -) - @ccall mlir_c.mlirLLVMDISubprogramAttrGet( - ctx::MlirContext, - recId::MlirAttribute, - isRecSelf::Bool, - id::MlirAttribute, - compileUnit::MlirAttribute, - scope::MlirAttribute, - name::MlirAttribute, - linkageName::MlirAttribute, - file::MlirAttribute, - line::Cuint, - scopeLine::Cuint, - subprogramFlags::UInt64, - type::MlirAttribute, - nRetainedNodes::Cptrdiff_t, - retainedNodes::Ptr{MlirAttribute}, - nAnnotations::Cptrdiff_t, - annotations::Ptr{MlirAttribute}, - )::MlirAttribute +function mlirLLVMDISubprogramAttrGet(ctx, recId, isRecSelf, id, compileUnit, scope, name, linkageName, file, line, scopeLine, subprogramFlags, type, nRetainedNodes, retainedNodes, nAnnotations, annotations) + @ccall mlir_c.mlirLLVMDISubprogramAttrGet(ctx::MlirContext, recId::MlirAttribute, isRecSelf::Bool, id::MlirAttribute, compileUnit::MlirAttribute, scope::MlirAttribute, name::MlirAttribute, linkageName::MlirAttribute, file::MlirAttribute, line::Cuint, scopeLine::Cuint, subprogramFlags::UInt64, type::MlirAttribute, nRetainedNodes::Cptrdiff_t, retainedNodes::Ptr{MlirAttribute}, nAnnotations::Cptrdiff_t, annotations::Ptr{MlirAttribute})::MlirAttribute end """ @@ -7331,9 +6521,7 @@ end Creates a LLVM DIAnnotation attribute. """ function mlirLLVMDIAnnotationAttrGet(ctx, name, value) - @ccall mlir_c.mlirLLVMDIAnnotationAttrGet( - ctx::MlirContext, name::MlirAttribute, value::MlirAttribute - )::MlirAttribute + @ccall mlir_c.mlirLLVMDIAnnotationAttrGet(ctx::MlirContext, name::MlirAttribute, value::MlirAttribute)::MlirAttribute end """ @@ -7342,9 +6530,7 @@ end Gets the scope from this DISubprogramAttr. """ function mlirLLVMDISubprogramAttrGetScope(diSubprogram) - @ccall mlir_c.mlirLLVMDISubprogramAttrGetScope( - diSubprogram::MlirAttribute - )::MlirAttribute + @ccall mlir_c.mlirLLVMDISubprogramAttrGetScope(diSubprogram::MlirAttribute)::MlirAttribute end """ @@ -7371,9 +6557,7 @@ end Gets the compile unit from this DISubprogram. """ function mlirLLVMDISubprogramAttrGetCompileUnit(diSubprogram) - @ccall mlir_c.mlirLLVMDISubprogramAttrGetCompileUnit( - diSubprogram::MlirAttribute - )::MlirAttribute + @ccall mlir_c.mlirLLVMDISubprogramAttrGetCompileUnit(diSubprogram::MlirAttribute)::MlirAttribute end """ @@ -7382,9 +6566,7 @@ end Gets the file from this DISubprogramAttr. """ function mlirLLVMDISubprogramAttrGetFile(diSubprogram) - @ccall mlir_c.mlirLLVMDISubprogramAttrGetFile( - diSubprogram::MlirAttribute - )::MlirAttribute + @ccall mlir_c.mlirLLVMDISubprogramAttrGetFile(diSubprogram::MlirAttribute)::MlirAttribute end """ @@ -7393,9 +6575,7 @@ end Gets the type from this DISubprogramAttr. """ function mlirLLVMDISubprogramAttrGetType(diSubprogram) - @ccall mlir_c.mlirLLVMDISubprogramAttrGetType( - diSubprogram::MlirAttribute - )::MlirAttribute + @ccall mlir_c.mlirLLVMDISubprogramAttrGetType(diSubprogram::MlirAttribute)::MlirAttribute end """ @@ -7404,12 +6584,7 @@ end Creates a LLVM DISubroutineTypeAttr attribute. """ function mlirLLVMDISubroutineTypeAttrGet(ctx, callingConvention, nTypes, types) - @ccall mlir_c.mlirLLVMDISubroutineTypeAttrGet( - ctx::MlirContext, - callingConvention::Cuint, - nTypes::Cptrdiff_t, - types::Ptr{MlirAttribute}, - )::MlirAttribute + @ccall mlir_c.mlirLLVMDISubroutineTypeAttrGet(ctx::MlirContext, callingConvention::Cuint, nTypes::Cptrdiff_t, types::Ptr{MlirAttribute})::MlirAttribute end """ @@ -7417,20 +6592,8 @@ end Creates a LLVM DIModuleAttr attribute. """ -function mlirLLVMDIModuleAttrGet( - ctx, file, scope, name, configMacros, includePath, apinotes, line, isDecl -) - @ccall mlir_c.mlirLLVMDIModuleAttrGet( - ctx::MlirContext, - file::MlirAttribute, - scope::MlirAttribute, - name::MlirAttribute, - configMacros::MlirAttribute, - includePath::MlirAttribute, - apinotes::MlirAttribute, - line::Cuint, - isDecl::Bool, - )::MlirAttribute +function mlirLLVMDIModuleAttrGet(ctx, file, scope, name, configMacros, includePath, apinotes, line, isDecl) + @ccall mlir_c.mlirLLVMDIModuleAttrGet(ctx::MlirContext, file::MlirAttribute, scope::MlirAttribute, name::MlirAttribute, configMacros::MlirAttribute, includePath::MlirAttribute, apinotes::MlirAttribute, line::Cuint, isDecl::Bool)::MlirAttribute end """ @@ -7438,20 +6601,8 @@ end Creates a LLVM DIImportedEntityAttr attribute. """ -function mlirLLVMDIImportedEntityAttrGet( - ctx, tag, scope, entity, file, line, name, nElements, elements -) - @ccall mlir_c.mlirLLVMDIImportedEntityAttrGet( - ctx::MlirContext, - tag::Cuint, - scope::MlirAttribute, - entity::MlirAttribute, - file::MlirAttribute, - line::Cuint, - name::MlirAttribute, - nElements::Cptrdiff_t, - elements::Ptr{MlirAttribute}, - )::MlirAttribute +function mlirLLVMDIImportedEntityAttrGet(ctx, tag, scope, entity, file, line, name, nElements, elements) + @ccall mlir_c.mlirLLVMDIImportedEntityAttrGet(ctx::MlirContext, tag::Cuint, scope::MlirAttribute, entity::MlirAttribute, file::MlirAttribute, line::Cuint, name::MlirAttribute, nElements::Cptrdiff_t, elements::Ptr{MlirAttribute})::MlirAttribute end """ @@ -7484,9 +6635,7 @@ struct MlirLinalgContractionDimensions end function mlirLinalgInferContractionDimensions(op) - @ccall mlir_c.mlirLinalgInferContractionDimensions( - op::MlirOperation - )::MlirLinalgContractionDimensions + @ccall mlir_c.mlirLinalgInferContractionDimensions(op::MlirOperation)::MlirLinalgContractionDimensions end function mlirLinalgIsAConvolutionOp(op) @@ -7505,9 +6654,7 @@ struct MlirLinalgConvolutionDimensions end function mlirLinalgInferConvolutionDimensions(op) - @ccall mlir_c.mlirLinalgInferConvolutionDimensions( - op::MlirOperation - )::MlirLinalgConvolutionDimensions + @ccall mlir_c.mlirLinalgInferConvolutionDimensions(op::MlirOperation)::MlirLinalgConvolutionDimensions end function mlirLinalgGetIndexingMapsAttribute(op) @@ -7538,17 +6685,8 @@ function mlirTypeIsANVGPUTensorMapDescriptorType(type) @ccall mlir_c.mlirTypeIsANVGPUTensorMapDescriptorType(type::MlirType)::Bool end -function mlirNVGPUTensorMapDescriptorTypeGet( - ctx, tensorMemrefType, swizzle, l2promo, oobFill, interleave -) - @ccall mlir_c.mlirNVGPUTensorMapDescriptorTypeGet( - ctx::MlirContext, - tensorMemrefType::MlirType, - swizzle::Cint, - l2promo::Cint, - oobFill::Cint, - interleave::Cint, - )::MlirType +function mlirNVGPUTensorMapDescriptorTypeGet(ctx, tensorMemrefType, swizzle, l2promo, oobFill, interleave) + @ccall mlir_c.mlirNVGPUTensorMapDescriptorTypeGet(ctx::MlirContext, tensorMemrefType::MlirType, swizzle::Cint, l2promo::Cint, oobFill::Cint, interleave::Cint)::MlirType end function mlirGetDialectHandle__nvvm__() @@ -7639,9 +6777,7 @@ end Returns the minimum possible value stored by a quantized type. """ function mlirQuantizedTypeGetDefaultMinimumForInteger(isSigned, integralWidth) - @ccall mlir_c.mlirQuantizedTypeGetDefaultMinimumForInteger( - isSigned::Bool, integralWidth::Cuint - )::Int64 + @ccall mlir_c.mlirQuantizedTypeGetDefaultMinimumForInteger(isSigned::Bool, integralWidth::Cuint)::Int64 end """ @@ -7650,9 +6786,7 @@ end Returns the maximum possible value stored by a quantized type. """ function mlirQuantizedTypeGetDefaultMaximumForInteger(isSigned, integralWidth) - @ccall mlir_c.mlirQuantizedTypeGetDefaultMaximumForInteger( - isSigned::Bool, integralWidth::Cuint - )::Int64 + @ccall mlir_c.mlirQuantizedTypeGetDefaultMaximumForInteger(isSigned::Bool, integralWidth::Cuint)::Int64 end """ @@ -7724,9 +6858,7 @@ end Returns `true` if the `candidate` type is compatible with the given quantized `type`. """ function mlirQuantizedTypeIsCompatibleExpressedType(type, candidate) - @ccall mlir_c.mlirQuantizedTypeIsCompatibleExpressedType( - type::MlirType, candidate::MlirType - )::Bool + @ccall mlir_c.mlirQuantizedTypeIsCompatibleExpressedType(type::MlirType, candidate::MlirType)::Bool end """ @@ -7744,9 +6876,7 @@ end Casts from a type based on the storage type of the given type to a corresponding type based on the given type. Returns a null type if the cast is not valid. """ function mlirQuantizedTypeCastFromStorageType(type, candidate) - @ccall mlir_c.mlirQuantizedTypeCastFromStorageType( - type::MlirType, candidate::MlirType - )::MlirType + @ccall mlir_c.mlirQuantizedTypeCastFromStorageType(type::MlirType, candidate::MlirType)::MlirType end """ @@ -7764,9 +6894,7 @@ end Casts from a type based on the expressed type of the given type to a corresponding type based on the given type. Returns a null type if the cast is not valid. """ function mlirQuantizedTypeCastFromExpressedType(type, candidate) - @ccall mlir_c.mlirQuantizedTypeCastFromExpressedType( - type::MlirType, candidate::MlirType - )::MlirType + @ccall mlir_c.mlirQuantizedTypeCastFromExpressedType(type::MlirType, candidate::MlirType)::MlirType end """ @@ -7784,9 +6912,7 @@ end Casts from a type based on the expressed type of the given quantized type to equivalent type based on storage type of the same quantized type. """ function mlirQuantizedTypeCastExpressedToStorageType(type, candidate) - @ccall mlir_c.mlirQuantizedTypeCastExpressedToStorageType( - type::MlirType, candidate::MlirType - )::MlirType + @ccall mlir_c.mlirQuantizedTypeCastExpressedToStorageType(type::MlirType, candidate::MlirType)::MlirType end """ @@ -7803,16 +6929,8 @@ end Creates an instance of AnyQuantizedType with the given parameters in the same context as `storageType` and returns it. The instance is owned by the context. """ -function mlirAnyQuantizedTypeGet( - flags, storageType, expressedType, storageTypeMin, storageTypeMax -) - @ccall mlir_c.mlirAnyQuantizedTypeGet( - flags::Cuint, - storageType::MlirType, - expressedType::MlirType, - storageTypeMin::Int64, - storageTypeMax::Int64, - )::MlirType +function mlirAnyQuantizedTypeGet(flags, storageType, expressedType, storageTypeMin, storageTypeMax) + @ccall mlir_c.mlirAnyQuantizedTypeGet(flags::Cuint, storageType::MlirType, expressedType::MlirType, storageTypeMin::Int64, storageTypeMax::Int64)::MlirType end """ @@ -7829,18 +6947,8 @@ end Creates an instance of UniformQuantizedType with the given parameters in the same context as `storageType` and returns it. The instance is owned by the context. """ -function mlirUniformQuantizedTypeGet( - flags, storageType, expressedType, scale, zeroPoint, storageTypeMin, storageTypeMax -) - @ccall mlir_c.mlirUniformQuantizedTypeGet( - flags::Cuint, - storageType::MlirType, - expressedType::MlirType, - scale::Cdouble, - zeroPoint::Int64, - storageTypeMin::Int64, - storageTypeMax::Int64, - )::MlirType +function mlirUniformQuantizedTypeGet(flags, storageType, expressedType, scale, zeroPoint, storageTypeMin, storageTypeMax) + @ccall mlir_c.mlirUniformQuantizedTypeGet(flags::Cuint, storageType::MlirType, expressedType::MlirType, scale::Cdouble, zeroPoint::Int64, storageTypeMin::Int64, storageTypeMax::Int64)::MlirType end """ @@ -7884,28 +6992,8 @@ end Creates an instance of UniformQuantizedPerAxisType with the given parameters in the same context as `storageType` and returns it. `scales` and `zeroPoints` point to `nDims` number of elements. The instance is owned by the context. """ -function mlirUniformQuantizedPerAxisTypeGet( - flags, - storageType, - expressedType, - nDims, - scales, - zeroPoints, - quantizedDimension, - storageTypeMin, - storageTypeMax, -) - @ccall mlir_c.mlirUniformQuantizedPerAxisTypeGet( - flags::Cuint, - storageType::MlirType, - expressedType::MlirType, - nDims::Cptrdiff_t, - scales::Ptr{Cdouble}, - zeroPoints::Ptr{Int64}, - quantizedDimension::Int32, - storageTypeMin::Int64, - storageTypeMax::Int64, - )::MlirType +function mlirUniformQuantizedPerAxisTypeGet(flags, storageType, expressedType, nDims, scales, zeroPoints, quantizedDimension, storageTypeMin, storageTypeMax) + @ccall mlir_c.mlirUniformQuantizedPerAxisTypeGet(flags::Cuint, storageType::MlirType, expressedType::MlirType, nDims::Cptrdiff_t, scales::Ptr{Cdouble}, zeroPoints::Ptr{Int64}, quantizedDimension::Int32, storageTypeMin::Int64, storageTypeMax::Int64)::MlirType end """ @@ -7923,9 +7011,7 @@ end Returns `pos`-th scale of the given quantized per-axis type. """ function mlirUniformQuantizedPerAxisTypeGetScale(type, pos) - @ccall mlir_c.mlirUniformQuantizedPerAxisTypeGetScale( - type::MlirType, pos::Cptrdiff_t - )::Cdouble + @ccall mlir_c.mlirUniformQuantizedPerAxisTypeGetScale(type::MlirType, pos::Cptrdiff_t)::Cdouble end """ @@ -7934,9 +7020,7 @@ end Returns `pos`-th zero point of the given quantized per-axis type. """ function mlirUniformQuantizedPerAxisTypeGetZeroPoint(type, pos) - @ccall mlir_c.mlirUniformQuantizedPerAxisTypeGetZeroPoint( - type::MlirType, pos::Cptrdiff_t - )::Int64 + @ccall mlir_c.mlirUniformQuantizedPerAxisTypeGetZeroPoint(type::MlirType, pos::Cptrdiff_t)::Int64 end """ @@ -7945,9 +7029,7 @@ end Returns the index of the quantized dimension in the given quantized per-axis type. """ function mlirUniformQuantizedPerAxisTypeGetQuantizedDimension(type) - @ccall mlir_c.mlirUniformQuantizedPerAxisTypeGetQuantizedDimension( - type::MlirType - )::Int32 + @ccall mlir_c.mlirUniformQuantizedPerAxisTypeGetQuantizedDimension(type::MlirType)::Int32 end """ @@ -7975,30 +7057,8 @@ Creates a UniformQuantizedSubChannelType with the given parameters. The type is owned by the context. `scalesAttr` and `zeroPointsAttr` must be DenseElementsAttrs. `quantizedDimensions` and `blockSizes` point to `blockSizeInfoLength` number of elements, describing respectively the quantization axis and corresponding block size. """ -function mlirUniformQuantizedSubChannelTypeGet( - flags, - storageType, - expressedType, - scalesAttr, - zeroPointsAttr, - blockSizeInfoLength, - quantizedDimensions, - blockSizes, - storageTypeMin, - storageTypeMax, -) - @ccall mlir_c.mlirUniformQuantizedSubChannelTypeGet( - flags::Cuint, - storageType::MlirType, - expressedType::MlirType, - scalesAttr::MlirAttribute, - zeroPointsAttr::MlirAttribute, - blockSizeInfoLength::Cptrdiff_t, - quantizedDimensions::Ptr{Int32}, - blockSizes::Ptr{Int64}, - storageTypeMin::Int64, - storageTypeMax::Int64, - )::MlirType +function mlirUniformQuantizedSubChannelTypeGet(flags, storageType, expressedType, scalesAttr, zeroPointsAttr, blockSizeInfoLength, quantizedDimensions, blockSizes, storageTypeMin, storageTypeMax) + @ccall mlir_c.mlirUniformQuantizedSubChannelTypeGet(flags::Cuint, storageType::MlirType, expressedType::MlirType, scalesAttr::MlirAttribute, zeroPointsAttr::MlirAttribute, blockSizeInfoLength::Cptrdiff_t, quantizedDimensions::Ptr{Int32}, blockSizes::Ptr{Int64}, storageTypeMin::Int64, storageTypeMax::Int64)::MlirType end """ @@ -8007,9 +7067,7 @@ end Returns the number of block sizes provided in type. """ function mlirUniformQuantizedSubChannelTypeGetNumBlockSizes(type) - @ccall mlir_c.mlirUniformQuantizedSubChannelTypeGetNumBlockSizes( - type::MlirType - )::Cptrdiff_t + @ccall mlir_c.mlirUniformQuantizedSubChannelTypeGetNumBlockSizes(type::MlirType)::Cptrdiff_t end """ @@ -8018,9 +7076,7 @@ end Returns the quantized dimension at the given position. """ function mlirUniformQuantizedSubChannelTypeGetQuantizedDimension(type, pos) - @ccall mlir_c.mlirUniformQuantizedSubChannelTypeGetQuantizedDimension( - type::MlirType, pos::Cptrdiff_t - )::Int32 + @ccall mlir_c.mlirUniformQuantizedSubChannelTypeGetQuantizedDimension(type::MlirType, pos::Cptrdiff_t)::Int32 end """ @@ -8029,9 +7085,7 @@ end Returns the block size at the given position. """ function mlirUniformQuantizedSubChannelTypeGetBlockSize(type, pos) - @ccall mlir_c.mlirUniformQuantizedSubChannelTypeGetBlockSize( - type::MlirType, pos::Cptrdiff_t - )::Int64 + @ccall mlir_c.mlirUniformQuantizedSubChannelTypeGetBlockSize(type::MlirType, pos::Cptrdiff_t)::Int64 end """ @@ -8049,9 +7103,7 @@ end Returns the zero-points of the quantized type. """ function mlirUniformQuantizedSubChannelTypeGetZeroPoints(type) - @ccall mlir_c.mlirUniformQuantizedSubChannelTypeGetZeroPoints( - type::MlirType - )::MlirAttribute + @ccall mlir_c.mlirUniformQuantizedSubChannelTypeGetZeroPoints(type::MlirType)::MlirAttribute end """ @@ -8069,9 +7121,7 @@ end Creates an instance of CalibratedQuantizedType with the given parameters in the same context as `expressedType` and returns it. The instance is owned by the context. """ function mlirCalibratedQuantizedTypeGet(expressedType, min, max) - @ccall mlir_c.mlirCalibratedQuantizedTypeGet( - expressedType::MlirType, min::Cdouble, max::Cdouble - )::MlirType + @ccall mlir_c.mlirCalibratedQuantizedTypeGet(expressedType::MlirType, min::Cdouble, max::Cdouble)::MlirType end """ @@ -8137,9 +7187,7 @@ end Creates an array type with the given domain and range types. """ function mlirSMTTypeGetArray(ctx, domainType, rangeType) - @ccall mlir_c.mlirSMTTypeGetArray( - ctx::MlirContext, domainType::MlirType, rangeType::MlirType - )::MlirType + @ccall mlir_c.mlirSMTTypeGetArray(ctx::MlirContext, domainType::MlirType, rangeType::MlirType)::MlirType end """ @@ -8211,12 +7259,7 @@ end Creates a smt::FuncType with the given domain and range types. """ function mlirSMTTypeGetSMTFunc(ctx, numberOfDomainTypes, domainTypes, rangeType) - @ccall mlir_c.mlirSMTTypeGetSMTFunc( - ctx::MlirContext, - numberOfDomainTypes::Csize_t, - domainTypes::Ptr{MlirType}, - rangeType::MlirType, - )::MlirType + @ccall mlir_c.mlirSMTTypeGetSMTFunc(ctx::MlirContext, numberOfDomainTypes::Csize_t, domainTypes::Ptr{MlirType}, rangeType::MlirType)::MlirType end """ @@ -8234,12 +7277,7 @@ end Creates a smt::SortType with the given identifier and sort parameters. """ function mlirSMTTypeGetSort(ctx, identifier, numberOfSortParams, sortParams) - @ccall mlir_c.mlirSMTTypeGetSort( - ctx::MlirContext, - identifier::MlirIdentifier, - numberOfSortParams::Csize_t, - sortParams::Ptr{MlirType}, - )::MlirType + @ccall mlir_c.mlirSMTTypeGetSort(ctx::MlirContext, identifier::MlirIdentifier, numberOfSortParams::Csize_t, sortParams::Ptr{MlirType})::MlirType end """ @@ -8275,9 +7313,7 @@ end Creates a smt::BitVectorAttr with the given value and width. """ function mlirSMTAttrGetBitVector(ctx, value, width) - @ccall mlir_c.mlirSMTAttrGetBitVector( - ctx::MlirContext, value::UInt64, width::Cuint - )::MlirAttribute + @ccall mlir_c.mlirSMTAttrGetBitVector(ctx::MlirContext, value::UInt64, width::Cuint)::MlirAttribute end """ @@ -8286,9 +7322,7 @@ end Creates a smt::BVCmpPredicateAttr with the given string. """ function mlirSMTAttrGetBVCmpPredicate(ctx, str) - @ccall mlir_c.mlirSMTAttrGetBVCmpPredicate( - ctx::MlirContext, str::MlirStringRef - )::MlirAttribute + @ccall mlir_c.mlirSMTAttrGetBVCmpPredicate(ctx::MlirContext, str::MlirStringRef)::MlirAttribute end """ @@ -8297,9 +7331,7 @@ end Creates a smt::IntPredicateAttr with the given string. """ function mlirSMTAttrGetIntPredicate(ctx, str) - @ccall mlir_c.mlirSMTAttrGetIntPredicate( - ctx::MlirContext, str::MlirStringRef - )::MlirAttribute + @ccall mlir_c.mlirSMTAttrGetIntPredicate(ctx::MlirContext, str::MlirStringRef)::MlirAttribute end function mlirGetDialectHandle__spirv__() @@ -8350,20 +7382,8 @@ end Creates a `sparse\\_tensor.encoding` attribute with the given parameters. """ -function mlirSparseTensorEncodingAttrGet( - ctx, lvlRank, lvlTypes, dimToLvl, lvlTodim, posWidth, crdWidth, explicitVal, implicitVal -) - @ccall mlir_c.mlirSparseTensorEncodingAttrGet( - ctx::MlirContext, - lvlRank::Cptrdiff_t, - lvlTypes::Ptr{MlirSparseTensorLevelType}, - dimToLvl::MlirAffineMap, - lvlTodim::MlirAffineMap, - posWidth::Cint, - crdWidth::Cint, - explicitVal::MlirAttribute, - implicitVal::MlirAttribute, - )::MlirAttribute +function mlirSparseTensorEncodingAttrGet(ctx, lvlRank, lvlTypes, dimToLvl, lvlTodim, posWidth, crdWidth, explicitVal, implicitVal) + @ccall mlir_c.mlirSparseTensorEncodingAttrGet(ctx::MlirContext, lvlRank::Cptrdiff_t, lvlTypes::Ptr{MlirSparseTensorLevelType}, dimToLvl::MlirAffineMap, lvlTodim::MlirAffineMap, posWidth::Cint, crdWidth::Cint, explicitVal::MlirAttribute, implicitVal::MlirAttribute)::MlirAttribute end """ @@ -8381,9 +7401,7 @@ end Returns a specified level-type of the `sparse\\_tensor.encoding` attribute. """ function mlirSparseTensorEncodingAttrGetLvlType(attr, lvl) - @ccall mlir_c.mlirSparseTensorEncodingAttrGetLvlType( - attr::MlirAttribute, lvl::Cptrdiff_t - )::MlirSparseTensorLevelType + @ccall mlir_c.mlirSparseTensorEncodingAttrGetLvlType(attr::MlirAttribute, lvl::Cptrdiff_t)::MlirSparseTensorLevelType end """ @@ -8392,9 +7410,7 @@ end Returns a specified level-format of the `sparse\\_tensor.encoding` attribute. """ function mlirSparseTensorEncodingAttrGetLvlFmt(attr, lvl) - @ccall mlir_c.mlirSparseTensorEncodingAttrGetLvlFmt( - attr::MlirAttribute, lvl::Cptrdiff_t - )::MlirSparseTensorLevelFormat + @ccall mlir_c.mlirSparseTensorEncodingAttrGetLvlFmt(attr::MlirAttribute, lvl::Cptrdiff_t)::MlirSparseTensorLevelFormat end """ @@ -8403,9 +7419,7 @@ end Returns the dimension-to-level mapping of the `sparse\\_tensor.encoding` attribute. """ function mlirSparseTensorEncodingAttrGetDimToLvl(attr) - @ccall mlir_c.mlirSparseTensorEncodingAttrGetDimToLvl( - attr::MlirAttribute - )::MlirAffineMap + @ccall mlir_c.mlirSparseTensorEncodingAttrGetDimToLvl(attr::MlirAttribute)::MlirAffineMap end """ @@ -8414,9 +7428,7 @@ end Returns the level-to-dimension mapping of the `sparse\\_tensor.encoding` attribute. """ function mlirSparseTensorEncodingAttrGetLvlToDim(attr) - @ccall mlir_c.mlirSparseTensorEncodingAttrGetLvlToDim( - attr::MlirAttribute - )::MlirAffineMap + @ccall mlir_c.mlirSparseTensorEncodingAttrGetLvlToDim(attr::MlirAttribute)::MlirAffineMap end """ @@ -8443,9 +7455,7 @@ end Returns the explicit value of the `sparse\\_tensor.encoding` attribute. """ function mlirSparseTensorEncodingAttrGetExplicitVal(attr) - @ccall mlir_c.mlirSparseTensorEncodingAttrGetExplicitVal( - attr::MlirAttribute - )::MlirAttribute + @ccall mlir_c.mlirSparseTensorEncodingAttrGetExplicitVal(attr::MlirAttribute)::MlirAttribute end """ @@ -8454,31 +7464,19 @@ end Returns the implicit value of the `sparse\\_tensor.encoding` attribute. """ function mlirSparseTensorEncodingAttrGetImplicitVal(attr) - @ccall mlir_c.mlirSparseTensorEncodingAttrGetImplicitVal( - attr::MlirAttribute - )::MlirAttribute + @ccall mlir_c.mlirSparseTensorEncodingAttrGetImplicitVal(attr::MlirAttribute)::MlirAttribute end function mlirSparseTensorEncodingAttrGetStructuredN(lvlType) - @ccall mlir_c.mlirSparseTensorEncodingAttrGetStructuredN( - lvlType::MlirSparseTensorLevelType - )::Cuint + @ccall mlir_c.mlirSparseTensorEncodingAttrGetStructuredN(lvlType::MlirSparseTensorLevelType)::Cuint end function mlirSparseTensorEncodingAttrGetStructuredM(lvlType) - @ccall mlir_c.mlirSparseTensorEncodingAttrGetStructuredM( - lvlType::MlirSparseTensorLevelType - )::Cuint + @ccall mlir_c.mlirSparseTensorEncodingAttrGetStructuredM(lvlType::MlirSparseTensorLevelType)::Cuint end function mlirSparseTensorEncodingAttrBuildLvlType(lvlFmt, properties, propSize, n, m) - @ccall mlir_c.mlirSparseTensorEncodingAttrBuildLvlType( - lvlFmt::MlirSparseTensorLevelFormat, - properties::Ptr{MlirSparseTensorLevelPropertyNondefault}, - propSize::Cuint, - n::Cuint, - m::Cuint, - )::MlirSparseTensorLevelType + @ccall mlir_c.mlirSparseTensorEncodingAttrBuildLvlType(lvlFmt::MlirSparseTensorLevelFormat, properties::Ptr{MlirSparseTensorLevelPropertyNondefault}, propSize::Cuint, n::Cuint, m::Cuint)::MlirSparseTensorLevelType end function mlirGetDialectHandle__tensor__() @@ -8534,9 +7532,7 @@ function mlirTransformOperationTypeGetTypeID() end function mlirTransformOperationTypeGet(ctx, operationName) - @ccall mlir_c.mlirTransformOperationTypeGet( - ctx::MlirContext, operationName::MlirStringRef - )::MlirType + @ccall mlir_c.mlirTransformOperationTypeGet(ctx::MlirContext, operationName::MlirStringRef)::MlirType end function mlirTransformOperationTypeGetOperationName(type) @@ -8578,9 +7574,7 @@ end Enables or disables expensive checks in transform options. """ function mlirTransformOptionsEnableExpensiveChecks(transformOptions, enable) - @ccall mlir_c.mlirTransformOptionsEnableExpensiveChecks( - transformOptions::MlirTransformOptions, enable::Bool - )::Cvoid + @ccall mlir_c.mlirTransformOptionsEnableExpensiveChecks(transformOptions::MlirTransformOptions, enable::Bool)::Cvoid end """ @@ -8589,9 +7583,7 @@ end Returns true if expensive checks are enabled in transform options. """ function mlirTransformOptionsGetExpensiveChecksEnabled(transformOptions) - @ccall mlir_c.mlirTransformOptionsGetExpensiveChecksEnabled( - transformOptions::MlirTransformOptions - )::Bool + @ccall mlir_c.mlirTransformOptionsGetExpensiveChecksEnabled(transformOptions::MlirTransformOptions)::Bool end """ @@ -8600,9 +7592,7 @@ end Enables or disables the enforcement of the top-level transform op being single in transform options. """ function mlirTransformOptionsEnforceSingleTopLevelTransformOp(transformOptions, enable) - @ccall mlir_c.mlirTransformOptionsEnforceSingleTopLevelTransformOp( - transformOptions::MlirTransformOptions, enable::Bool - )::Cvoid + @ccall mlir_c.mlirTransformOptionsEnforceSingleTopLevelTransformOp(transformOptions::MlirTransformOptions, enable::Bool)::Cvoid end """ @@ -8611,9 +7601,7 @@ end Returns true if the enforcement of the top-level transform op being single is enabled in transform options. """ function mlirTransformOptionsGetEnforceSingleTopLevelTransformOp(transformOptions) - @ccall mlir_c.mlirTransformOptionsGetEnforceSingleTopLevelTransformOp( - transformOptions::MlirTransformOptions - )::Bool + @ccall mlir_c.mlirTransformOptionsGetEnforceSingleTopLevelTransformOp(transformOptions::MlirTransformOptions)::Bool end """ @@ -8630,15 +7618,8 @@ end Applies the transformation script starting at the given transform root operation to the given payload operation. The module containing the transform root as well as the transform options should be provided. The transform operation must implement TransformOpInterface and the module must be a ModuleOp. Returns the status of the application. """ -function mlirTransformApplyNamedSequence( - payload, transformRoot, transformModule, transformOptions -) - @ccall mlir_c.mlirTransformApplyNamedSequence( - payload::MlirOperation, - transformRoot::MlirOperation, - transformModule::MlirOperation, - transformOptions::MlirTransformOptions, - )::MlirLogicalResult +function mlirTransformApplyNamedSequence(payload, transformRoot, transformModule, transformOptions) + @ccall mlir_c.mlirTransformApplyNamedSequence(payload::MlirOperation, transformRoot::MlirOperation, transformModule::MlirOperation, transformOptions::MlirTransformOptions)::MlirLogicalResult end """ @@ -8649,9 +7630,7 @@ Merge the symbols from `other` into `target`, potentially renaming them to avoid Note that this clones the `other` operation unlike the C++ counterpart that takes ownership. """ function mlirMergeSymbolsIntoFromClone(target, other) - @ccall mlir_c.mlirMergeSymbolsIntoFromClone( - target::MlirOperation, other::MlirOperation - )::MlirLogicalResult + @ccall mlir_c.mlirMergeSymbolsIntoFromClone(target::MlirOperation, other::MlirOperation)::MlirLogicalResult end function mlirGetDialectHandle__vector__() @@ -8668,13 +7647,7 @@ end Creates an ExecutionEngine for the provided ModuleOp. The ModuleOp is expected to be "translatable" to LLVM IR (only contains operations in dialects that implement the `LLVMTranslationDialectInterface`). The module ownership stays with the client and can be destroyed as soon as the call returns. `optLevel` is the optimization level to be used for transformation and code generation. LLVM passes at `optLevel` are run before code generation. The number and array of paths corresponding to shared libraries that will be loaded are specified via `numPaths` and `sharedLibPaths` respectively. TODO: figure out other options. """ function mlirExecutionEngineCreate(op, optLevel, numPaths, sharedLibPaths, enableObjectDump) - @ccall mlir_c.mlirExecutionEngineCreate( - op::MlirModule, - optLevel::Cint, - numPaths::Cint, - sharedLibPaths::Ptr{MlirStringRef}, - enableObjectDump::Bool, - )::MlirExecutionEngine + @ccall mlir_c.mlirExecutionEngineCreate(op::MlirModule, optLevel::Cint, numPaths::Cint, sharedLibPaths::Ptr{MlirStringRef}, enableObjectDump::Bool)::MlirExecutionEngine end """ @@ -8710,9 +7683,7 @@ end Invoke a native function in the execution engine by name with the arguments and result of the invoked function passed as an array of pointers. The function must have been tagged with the `llvm.emit\\_c\\_interface` attribute. Returns a failure if the execution fails for any reason (the function name can't be resolved for instance). """ function mlirExecutionEngineInvokePacked(jit, name, arguments) - @ccall mlir_c.mlirExecutionEngineInvokePacked( - jit::MlirExecutionEngine, name::MlirStringRef, arguments::Ptr{Ptr{Cvoid}} - )::MlirLogicalResult + @ccall mlir_c.mlirExecutionEngineInvokePacked(jit::MlirExecutionEngine, name::MlirStringRef, arguments::Ptr{Ptr{Cvoid}})::MlirLogicalResult end """ @@ -8721,9 +7692,7 @@ end Lookup the wrapper of the native function in the execution engine with the given name, returns nullptr if the function can't be looked-up. """ function mlirExecutionEngineLookupPacked(jit, name) - @ccall mlir_c.mlirExecutionEngineLookupPacked( - jit::MlirExecutionEngine, name::MlirStringRef - )::Ptr{Cvoid} + @ccall mlir_c.mlirExecutionEngineLookupPacked(jit::MlirExecutionEngine, name::MlirStringRef)::Ptr{Cvoid} end """ @@ -8732,9 +7701,7 @@ end Lookup a native function in the execution engine by name, returns nullptr if the name can't be looked-up. """ function mlirExecutionEngineLookup(jit, name) - @ccall mlir_c.mlirExecutionEngineLookup( - jit::MlirExecutionEngine, name::MlirStringRef - )::Ptr{Cvoid} + @ccall mlir_c.mlirExecutionEngineLookup(jit::MlirExecutionEngine, name::MlirStringRef)::Ptr{Cvoid} end """ @@ -8743,9 +7710,7 @@ end Register a symbol with the jit: this symbol will be accessible to the jitted code. """ function mlirExecutionEngineRegisterSymbol(jit, name, sym) - @ccall mlir_c.mlirExecutionEngineRegisterSymbol( - jit::MlirExecutionEngine, name::MlirStringRef, sym::Ptr{Cvoid} - )::Cvoid + @ccall mlir_c.mlirExecutionEngineRegisterSymbol(jit::MlirExecutionEngine, name::MlirStringRef, sym::Ptr{Cvoid})::Cvoid end """ @@ -8754,9 +7719,7 @@ end Dump as an object in `fileName`. """ function mlirExecutionEngineDumpToObjectFile(jit, fileName) - @ccall mlir_c.mlirExecutionEngineDumpToObjectFile( - jit::MlirExecutionEngine, fileName::MlirStringRef - )::Cvoid + @ccall mlir_c.mlirExecutionEngineDumpToObjectFile(jit::MlirExecutionEngine, fileName::MlirStringRef)::Cvoid end """ @@ -8765,9 +7728,7 @@ end Returns `true` if the given operation implements an interface identified by its TypeID. """ function mlirOperationImplementsInterface(operation, interfaceTypeID) - @ccall mlir_c.mlirOperationImplementsInterface( - operation::MlirOperation, interfaceTypeID::MlirTypeID - )::Bool + @ccall mlir_c.mlirOperationImplementsInterface(operation::MlirOperation, interfaceTypeID::MlirTypeID)::Bool end """ @@ -8776,9 +7737,7 @@ end Returns `true` if the operation identified by its canonical string name implements the interface identified by its TypeID in the given context. Note that interfaces may be attached to operations in some contexts and not others. """ function mlirOperationImplementsInterfaceStatic(operationName, context, interfaceTypeID) - @ccall mlir_c.mlirOperationImplementsInterfaceStatic( - operationName::MlirStringRef, context::MlirContext, interfaceTypeID::MlirTypeID - )::Bool + @ccall mlir_c.mlirOperationImplementsInterfaceStatic(operationName::MlirStringRef, context::MlirContext, interfaceTypeID::MlirTypeID)::Bool end """ @@ -8801,32 +7760,8 @@ const MlirTypesCallback = Ptr{Cvoid} Infers the return types of the operation identified by its canonical given the arguments that will be supplied to its generic builder. Calls `callback` with the types of inferred arguments, potentially several times, on success. Returns failure otherwise. """ -function mlirInferTypeOpInterfaceInferReturnTypes( - opName, - context, - location, - nOperands, - operands, - attributes, - properties, - nRegions, - regions, - callback, - userData, -) - @ccall mlir_c.mlirInferTypeOpInterfaceInferReturnTypes( - opName::MlirStringRef, - context::MlirContext, - location::MlirLocation, - nOperands::Cptrdiff_t, - operands::Ptr{MlirValue}, - attributes::MlirAttribute, - properties::Ptr{Cvoid}, - nRegions::Cptrdiff_t, - regions::Ptr{MlirRegion}, - callback::MlirTypesCallback, - userData::Ptr{Cvoid}, - )::MlirLogicalResult +function mlirInferTypeOpInterfaceInferReturnTypes(opName, context, location, nOperands, operands, attributes, properties, nRegions, regions, callback, userData) + @ccall mlir_c.mlirInferTypeOpInterfaceInferReturnTypes(opName::MlirStringRef, context::MlirContext, location::MlirLocation, nOperands::Cptrdiff_t, operands::Ptr{MlirValue}, attributes::MlirAttribute, properties::Ptr{Cvoid}, nRegions::Cptrdiff_t, regions::Ptr{MlirRegion}, callback::MlirTypesCallback, userData::Ptr{Cvoid})::MlirLogicalResult end """ @@ -8849,32 +7784,8 @@ const MlirShapedTypeComponentsCallback = Ptr{Cvoid} Infers the return shaped type components of the operation. Calls `callback` with the types of inferred arguments on success. Returns failure otherwise. """ -function mlirInferShapedTypeOpInterfaceInferReturnTypes( - opName, - context, - location, - nOperands, - operands, - attributes, - properties, - nRegions, - regions, - callback, - userData, -) - @ccall mlir_c.mlirInferShapedTypeOpInterfaceInferReturnTypes( - opName::MlirStringRef, - context::MlirContext, - location::MlirLocation, - nOperands::Cptrdiff_t, - operands::Ptr{MlirValue}, - attributes::MlirAttribute, - properties::Ptr{Cvoid}, - nRegions::Cptrdiff_t, - regions::Ptr{MlirRegion}, - callback::MlirShapedTypeComponentsCallback, - userData::Ptr{Cvoid}, - )::MlirLogicalResult +function mlirInferShapedTypeOpInterfaceInferReturnTypes(opName, context, location, nOperands, operands, attributes, properties, nRegions, regions, callback, userData) + @ccall mlir_c.mlirInferShapedTypeOpInterfaceInferReturnTypes(opName::MlirStringRef, context::MlirContext, location::MlirLocation, nOperands::Cptrdiff_t, operands::Ptr{MlirValue}, attributes::MlirAttribute, properties::Ptr{Cvoid}, nRegions::Cptrdiff_t, regions::Ptr{MlirRegion}, callback::MlirShapedTypeComponentsCallback, userData::Ptr{Cvoid})::MlirLogicalResult end struct MlirPass @@ -8908,9 +7819,7 @@ end Create a new top-level PassManager anchored on `anchorOp`. """ function mlirPassManagerCreateOnOperation(ctx, anchorOp) - @ccall mlir_c.mlirPassManagerCreateOnOperation( - ctx::MlirContext, anchorOp::MlirStringRef - )::MlirPassManager + @ccall mlir_c.mlirPassManagerCreateOnOperation(ctx::MlirContext, anchorOp::MlirStringRef)::MlirPassManager end """ @@ -8937,9 +7846,7 @@ end Cast a top-level PassManager to a generic OpPassManager. """ function mlirPassManagerGetAsOpPassManager(passManager) - @ccall mlir_c.mlirPassManagerGetAsOpPassManager( - passManager::MlirPassManager - )::MlirOpPassManager + @ccall mlir_c.mlirPassManagerGetAsOpPassManager(passManager::MlirPassManager)::MlirOpPassManager end """ @@ -8948,9 +7855,7 @@ end Run the provided `passManager` on the given `op`. """ function mlirPassManagerRunOnOp(passManager, op) - @ccall mlir_c.mlirPassManagerRunOnOp( - passManager::MlirPassManager, op::MlirOperation - )::MlirLogicalResult + @ccall mlir_c.mlirPassManagerRunOnOp(passManager::MlirPassManager, op::MlirOperation)::MlirLogicalResult end """ @@ -8958,26 +7863,8 @@ end Enable IR printing. The treePrintingPath argument is an optional path to a directory where the dumps will be produced. If it isn't provided then dumps are produced to stderr. """ -function mlirPassManagerEnableIRPrinting( - passManager, - printBeforeAll, - printAfterAll, - printModuleScope, - printAfterOnlyOnChange, - printAfterOnlyOnFailure, - flags, - treePrintingPath, -) - @ccall mlir_c.mlirPassManagerEnableIRPrinting( - passManager::MlirPassManager, - printBeforeAll::Bool, - printAfterAll::Bool, - printModuleScope::Bool, - printAfterOnlyOnChange::Bool, - printAfterOnlyOnFailure::Bool, - flags::MlirOpPrintingFlags, - treePrintingPath::MlirStringRef, - )::Cvoid +function mlirPassManagerEnableIRPrinting(passManager, printBeforeAll, printAfterAll, printModuleScope, printAfterOnlyOnChange, printAfterOnlyOnFailure, flags, treePrintingPath) + @ccall mlir_c.mlirPassManagerEnableIRPrinting(passManager::MlirPassManager, printBeforeAll::Bool, printAfterAll::Bool, printModuleScope::Bool, printAfterOnlyOnChange::Bool, printAfterOnlyOnFailure::Bool, flags::MlirOpPrintingFlags, treePrintingPath::MlirStringRef)::Cvoid end """ @@ -8986,9 +7873,7 @@ end Enable / disable verify-each. """ function mlirPassManagerEnableVerifier(passManager, enable) - @ccall mlir_c.mlirPassManagerEnableVerifier( - passManager::MlirPassManager, enable::Bool - )::Cvoid + @ccall mlir_c.mlirPassManagerEnableVerifier(passManager::MlirPassManager, enable::Bool)::Cvoid end """ @@ -9027,9 +7912,7 @@ end Nest an OpPassManager under the top-level PassManager, the nested passmanager will only run on operations matching the provided name. The returned OpPassManager will be destroyed when the parent is destroyed. To further nest more OpPassManager under the newly returned one, see `mlirOpPassManagerNest` below. """ function mlirPassManagerGetNestedUnder(passManager, operationName) - @ccall mlir_c.mlirPassManagerGetNestedUnder( - passManager::MlirPassManager, operationName::MlirStringRef - )::MlirOpPassManager + @ccall mlir_c.mlirPassManagerGetNestedUnder(passManager::MlirPassManager, operationName::MlirStringRef)::MlirOpPassManager end """ @@ -9038,9 +7921,7 @@ end Nest an OpPassManager under the provided OpPassManager, the nested passmanager will only run on operations matching the provided name. The returned OpPassManager will be destroyed when the parent is destroyed. """ function mlirOpPassManagerGetNestedUnder(passManager, operationName) - @ccall mlir_c.mlirOpPassManagerGetNestedUnder( - passManager::MlirOpPassManager, operationName::MlirStringRef - )::MlirOpPassManager + @ccall mlir_c.mlirOpPassManagerGetNestedUnder(passManager::MlirOpPassManager, operationName::MlirStringRef)::MlirOpPassManager end """ @@ -9049,9 +7930,7 @@ end Add a pass and transfer ownership to the provided top-level mlirPassManager. If the pass is not a generic operation pass or a ModulePass, a new OpPassManager is implicitly nested under the provided PassManager. """ function mlirPassManagerAddOwnedPass(passManager, pass) - @ccall mlir_c.mlirPassManagerAddOwnedPass( - passManager::MlirPassManager, pass::MlirPass - )::Cvoid + @ccall mlir_c.mlirPassManagerAddOwnedPass(passManager::MlirPassManager, pass::MlirPass)::Cvoid end """ @@ -9060,9 +7939,7 @@ end Add a pass and transfer ownership to the provided mlirOpPassManager. If the pass is not a generic operation pass or matching the type of the provided PassManager, a new OpPassManager is implicitly nested under the provided PassManager. """ function mlirOpPassManagerAddOwnedPass(passManager, pass) - @ccall mlir_c.mlirOpPassManagerAddOwnedPass( - passManager::MlirOpPassManager, pass::MlirPass - )::Cvoid + @ccall mlir_c.mlirOpPassManagerAddOwnedPass(passManager::MlirOpPassManager, pass::MlirPass)::Cvoid end """ @@ -9071,12 +7948,7 @@ end Parse a sequence of textual MLIR pass pipeline elements and add them to the provided OpPassManager. If parsing fails an error message is reported using the provided callback. """ function mlirOpPassManagerAddPipeline(passManager, pipelineElements, callback, userData) - @ccall mlir_c.mlirOpPassManagerAddPipeline( - passManager::MlirOpPassManager, - pipelineElements::MlirStringRef, - callback::MlirStringCallback, - userData::Ptr{Cvoid}, - )::MlirLogicalResult + @ccall mlir_c.mlirOpPassManagerAddPipeline(passManager::MlirOpPassManager, pipelineElements::MlirStringRef, callback::MlirStringCallback, userData::Ptr{Cvoid})::MlirLogicalResult end """ @@ -9085,9 +7957,7 @@ end Print a textual MLIR pass pipeline by sending chunks of the string representation and forwarding `userData to `callback`. Note that the callback may be called several times with consecutive chunks of the string. """ function mlirPrintPassPipeline(passManager, callback, userData) - @ccall mlir_c.mlirPrintPassPipeline( - passManager::MlirOpPassManager, callback::MlirStringCallback, userData::Ptr{Cvoid} - )::Cvoid + @ccall mlir_c.mlirPrintPassPipeline(passManager::MlirOpPassManager, callback::MlirStringCallback, userData::Ptr{Cvoid})::Cvoid end """ @@ -9096,12 +7966,7 @@ end Parse a textual MLIR pass pipeline and assign it to the provided OpPassManager. If parsing fails an error message is reported using the provided callback. """ function mlirParsePassPipeline(passManager, pipeline, callback, userData) - @ccall mlir_c.mlirParsePassPipeline( - passManager::MlirOpPassManager, - pipeline::MlirStringRef, - callback::MlirStringCallback, - userData::Ptr{Cvoid}, - )::MlirLogicalResult + @ccall mlir_c.mlirParsePassPipeline(passManager::MlirOpPassManager, pipeline::MlirStringRef, callback::MlirStringCallback, userData::Ptr{Cvoid})::MlirLogicalResult end """ @@ -9130,28 +7995,8 @@ end Creates an external [`MlirPass`](@ref) that calls the supplied `callbacks` using the supplied `userData`. If `opName` is empty, the pass is a generic operation pass. Otherwise it is an operation pass specific to the specified pass name. """ -function mlirCreateExternalPass( - passID, - name, - argument, - description, - opName, - nDependentDialects, - dependentDialects, - callbacks, - userData, -) - @ccall mlir_c.mlirCreateExternalPass( - passID::MlirTypeID, - name::MlirStringRef, - argument::MlirStringRef, - description::MlirStringRef, - opName::MlirStringRef, - nDependentDialects::Cptrdiff_t, - dependentDialects::Ptr{MlirDialectHandle}, - callbacks::MlirExternalPassCallbacks, - userData::Ptr{Cvoid}, - )::MlirPass +function mlirCreateExternalPass(passID, name, argument, description, opName, nDependentDialects, dependentDialects, callbacks, userData) + @ccall mlir_c.mlirCreateExternalPass(passID::MlirTypeID, name::MlirStringRef, argument::MlirStringRef, description::MlirStringRef, opName::MlirStringRef, nDependentDialects::Cptrdiff_t, dependentDialects::Ptr{MlirDialectHandle}, callbacks::MlirExternalPassCallbacks, userData::Ptr{Cvoid})::MlirPass end """ @@ -9238,9 +8083,7 @@ end Sets the insertion point to the specified operation, which will cause subsequent insertions to go right before it. """ function mlirRewriterBaseSetInsertionPointBefore(rewriter, op) - @ccall mlir_c.mlirRewriterBaseSetInsertionPointBefore( - rewriter::MlirRewriterBase, op::MlirOperation - )::Cvoid + @ccall mlir_c.mlirRewriterBaseSetInsertionPointBefore(rewriter::MlirRewriterBase, op::MlirOperation)::Cvoid end """ @@ -9249,9 +8092,7 @@ end Sets the insertion point to the node after the specified operation, which will cause subsequent insertions to go right after it. """ function mlirRewriterBaseSetInsertionPointAfter(rewriter, op) - @ccall mlir_c.mlirRewriterBaseSetInsertionPointAfter( - rewriter::MlirRewriterBase, op::MlirOperation - )::Cvoid + @ccall mlir_c.mlirRewriterBaseSetInsertionPointAfter(rewriter::MlirRewriterBase, op::MlirOperation)::Cvoid end """ @@ -9260,9 +8101,7 @@ end Sets the insertion point to the node after the specified value. If value has a defining operation, sets the insertion point to the node after such defining operation. This will cause subsequent insertions to go right after it. Otherwise, value is a BlockArgument. Sets the insertion point to the start of its block. """ function mlirRewriterBaseSetInsertionPointAfterValue(rewriter, value) - @ccall mlir_c.mlirRewriterBaseSetInsertionPointAfterValue( - rewriter::MlirRewriterBase, value::MlirValue - )::Cvoid + @ccall mlir_c.mlirRewriterBaseSetInsertionPointAfterValue(rewriter::MlirRewriterBase, value::MlirValue)::Cvoid end """ @@ -9271,9 +8110,7 @@ end Sets the insertion point to the start of the specified block. """ function mlirRewriterBaseSetInsertionPointToStart(rewriter, block) - @ccall mlir_c.mlirRewriterBaseSetInsertionPointToStart( - rewriter::MlirRewriterBase, block::MlirBlock - )::Cvoid + @ccall mlir_c.mlirRewriterBaseSetInsertionPointToStart(rewriter::MlirRewriterBase, block::MlirBlock)::Cvoid end """ @@ -9282,9 +8119,7 @@ end Sets the insertion point to the end of the specified block. """ function mlirRewriterBaseSetInsertionPointToEnd(rewriter, block) - @ccall mlir_c.mlirRewriterBaseSetInsertionPointToEnd( - rewriter::MlirRewriterBase, block::MlirBlock - )::Cvoid + @ccall mlir_c.mlirRewriterBaseSetInsertionPointToEnd(rewriter::MlirRewriterBase, block::MlirBlock)::Cvoid end """ @@ -9321,16 +8156,8 @@ end Add new block with 'argTypes' arguments and set the insertion point to the end of it. The block is placed before 'insertBefore'. `locs` contains the locations of the inserted arguments, and should match the size of `argTypes`. """ -function mlirRewriterBaseCreateBlockBefore( - rewriter, insertBefore, nArgTypes, argTypes, locations -) - @ccall mlir_c.mlirRewriterBaseCreateBlockBefore( - rewriter::MlirRewriterBase, - insertBefore::MlirBlock, - nArgTypes::Cptrdiff_t, - argTypes::Ptr{MlirType}, - locations::Ptr{MlirLocation}, - )::MlirBlock +function mlirRewriterBaseCreateBlockBefore(rewriter, insertBefore, nArgTypes, argTypes, locations) + @ccall mlir_c.mlirRewriterBaseCreateBlockBefore(rewriter::MlirRewriterBase, insertBefore::MlirBlock, nArgTypes::Cptrdiff_t, argTypes::Ptr{MlirType}, locations::Ptr{MlirLocation})::MlirBlock end """ @@ -9339,9 +8166,7 @@ end Insert the given operation at the current insertion point and return it. """ function mlirRewriterBaseInsert(rewriter, op) - @ccall mlir_c.mlirRewriterBaseInsert( - rewriter::MlirRewriterBase, op::MlirOperation - )::MlirOperation + @ccall mlir_c.mlirRewriterBaseInsert(rewriter::MlirRewriterBase, op::MlirOperation)::MlirOperation end """ @@ -9350,9 +8175,7 @@ end Creates a deep copy of the specified operation. """ function mlirRewriterBaseClone(rewriter, op) - @ccall mlir_c.mlirRewriterBaseClone( - rewriter::MlirRewriterBase, op::MlirOperation - )::MlirOperation + @ccall mlir_c.mlirRewriterBaseClone(rewriter::MlirRewriterBase, op::MlirOperation)::MlirOperation end """ @@ -9361,9 +8184,7 @@ end Creates a deep copy of this operation but keep the operation regions empty. """ function mlirRewriterBaseCloneWithoutRegions(rewriter, op) - @ccall mlir_c.mlirRewriterBaseCloneWithoutRegions( - rewriter::MlirRewriterBase, op::MlirOperation - )::MlirOperation + @ccall mlir_c.mlirRewriterBaseCloneWithoutRegions(rewriter::MlirRewriterBase, op::MlirOperation)::MlirOperation end """ @@ -9372,9 +8193,7 @@ end Clone the blocks that belong to "region" before the given position in another region "parent". """ function mlirRewriterBaseCloneRegionBefore(rewriter, region, before) - @ccall mlir_c.mlirRewriterBaseCloneRegionBefore( - rewriter::MlirRewriterBase, region::MlirRegion, before::MlirBlock - )::Cvoid + @ccall mlir_c.mlirRewriterBaseCloneRegionBefore(rewriter::MlirRewriterBase, region::MlirRegion, before::MlirBlock)::Cvoid end """ @@ -9383,9 +8202,7 @@ end Move the blocks that belong to "region" before the given position in another region "parent". The two regions must be different. The caller is responsible for creating or updating the operation transferring flow of control to the region and passing it the correct block arguments. """ function mlirRewriterBaseInlineRegionBefore(rewriter, region, before) - @ccall mlir_c.mlirRewriterBaseInlineRegionBefore( - rewriter::MlirRewriterBase, region::MlirRegion, before::MlirBlock - )::Cvoid + @ccall mlir_c.mlirRewriterBaseInlineRegionBefore(rewriter::MlirRewriterBase, region::MlirRegion, before::MlirBlock)::Cvoid end """ @@ -9394,12 +8211,7 @@ end Replace the results of the given (original) operation with the specified list of values (replacements). The result types of the given op and the replacements must match. The original op is erased. """ function mlirRewriterBaseReplaceOpWithValues(rewriter, op, nValues, values) - @ccall mlir_c.mlirRewriterBaseReplaceOpWithValues( - rewriter::MlirRewriterBase, - op::MlirOperation, - nValues::Cptrdiff_t, - values::Ptr{MlirValue}, - )::Cvoid + @ccall mlir_c.mlirRewriterBaseReplaceOpWithValues(rewriter::MlirRewriterBase, op::MlirOperation, nValues::Cptrdiff_t, values::Ptr{MlirValue})::Cvoid end """ @@ -9408,9 +8220,7 @@ end Replace the results of the given (original) operation with the specified new op (replacement). The result types of the two ops must match. The original op is erased. """ function mlirRewriterBaseReplaceOpWithOperation(rewriter, op, newOp) - @ccall mlir_c.mlirRewriterBaseReplaceOpWithOperation( - rewriter::MlirRewriterBase, op::MlirOperation, newOp::MlirOperation - )::Cvoid + @ccall mlir_c.mlirRewriterBaseReplaceOpWithOperation(rewriter::MlirRewriterBase, op::MlirOperation, newOp::MlirOperation)::Cvoid end """ @@ -9419,9 +8229,7 @@ end Erases an operation that is known to have no uses. """ function mlirRewriterBaseEraseOp(rewriter, op) - @ccall mlir_c.mlirRewriterBaseEraseOp( - rewriter::MlirRewriterBase, op::MlirOperation - )::Cvoid + @ccall mlir_c.mlirRewriterBaseEraseOp(rewriter::MlirRewriterBase, op::MlirOperation)::Cvoid end """ @@ -9430,9 +8238,7 @@ end Erases a block along with all operations inside it. """ function mlirRewriterBaseEraseBlock(rewriter, block) - @ccall mlir_c.mlirRewriterBaseEraseBlock( - rewriter::MlirRewriterBase, block::MlirBlock - )::Cvoid + @ccall mlir_c.mlirRewriterBaseEraseBlock(rewriter::MlirRewriterBase, block::MlirBlock)::Cvoid end """ @@ -9443,13 +8249,7 @@ Inline the operations of block 'source' before the operation 'op'. The source bl The source block must have no successors. Otherwise, the resulting IR would have unreachable operations. """ function mlirRewriterBaseInlineBlockBefore(rewriter, source, op, nArgValues, argValues) - @ccall mlir_c.mlirRewriterBaseInlineBlockBefore( - rewriter::MlirRewriterBase, - source::MlirBlock, - op::MlirOperation, - nArgValues::Cptrdiff_t, - argValues::Ptr{MlirValue}, - )::Cvoid + @ccall mlir_c.mlirRewriterBaseInlineBlockBefore(rewriter::MlirRewriterBase, source::MlirBlock, op::MlirOperation, nArgValues::Cptrdiff_t, argValues::Ptr{MlirValue})::Cvoid end """ @@ -9460,13 +8260,7 @@ Inline the operations of block 'source' into the end of block 'dest'. The source The dest block must have no successors. Otherwise, the resulting IR would have unreachable operation. """ function mlirRewriterBaseMergeBlocks(rewriter, source, dest, nArgValues, argValues) - @ccall mlir_c.mlirRewriterBaseMergeBlocks( - rewriter::MlirRewriterBase, - source::MlirBlock, - dest::MlirBlock, - nArgValues::Cptrdiff_t, - argValues::Ptr{MlirValue}, - )::Cvoid + @ccall mlir_c.mlirRewriterBaseMergeBlocks(rewriter::MlirRewriterBase, source::MlirBlock, dest::MlirBlock, nArgValues::Cptrdiff_t, argValues::Ptr{MlirValue})::Cvoid end """ @@ -9475,9 +8269,7 @@ end Unlink this operation from its current block and insert it right before `existingOp` which may be in the same or another block in the same function. """ function mlirRewriterBaseMoveOpBefore(rewriter, op, existingOp) - @ccall mlir_c.mlirRewriterBaseMoveOpBefore( - rewriter::MlirRewriterBase, op::MlirOperation, existingOp::MlirOperation - )::Cvoid + @ccall mlir_c.mlirRewriterBaseMoveOpBefore(rewriter::MlirRewriterBase, op::MlirOperation, existingOp::MlirOperation)::Cvoid end """ @@ -9486,9 +8278,7 @@ end Unlink this operation from its current block and insert it right after `existingOp` which may be in the same or another block in the same function. """ function mlirRewriterBaseMoveOpAfter(rewriter, op, existingOp) - @ccall mlir_c.mlirRewriterBaseMoveOpAfter( - rewriter::MlirRewriterBase, op::MlirOperation, existingOp::MlirOperation - )::Cvoid + @ccall mlir_c.mlirRewriterBaseMoveOpAfter(rewriter::MlirRewriterBase, op::MlirOperation, existingOp::MlirOperation)::Cvoid end """ @@ -9497,9 +8287,7 @@ end Unlink this block and insert it right before `existingBlock`. """ function mlirRewriterBaseMoveBlockBefore(rewriter, block, existingBlock) - @ccall mlir_c.mlirRewriterBaseMoveBlockBefore( - rewriter::MlirRewriterBase, block::MlirBlock, existingBlock::MlirBlock - )::Cvoid + @ccall mlir_c.mlirRewriterBaseMoveBlockBefore(rewriter::MlirRewriterBase, block::MlirBlock, existingBlock::MlirBlock)::Cvoid end """ @@ -9508,9 +8296,7 @@ end This method is used to notify the rewriter that an in-place operation modification is about to happen. A call to this function *must* be followed by a call to either `finalizeOpModification` or `cancelOpModification`. This is a minor efficiency win (it avoids creating a new operation and removing the old one) but also often allows simpler code in the client. """ function mlirRewriterBaseStartOpModification(rewriter, op) - @ccall mlir_c.mlirRewriterBaseStartOpModification( - rewriter::MlirRewriterBase, op::MlirOperation - )::Cvoid + @ccall mlir_c.mlirRewriterBaseStartOpModification(rewriter::MlirRewriterBase, op::MlirOperation)::Cvoid end """ @@ -9519,9 +8305,7 @@ end This method is used to signal the end of an in-place modification of the given operation. This can only be called on operations that were provided to a call to `startOpModification`. """ function mlirRewriterBaseFinalizeOpModification(rewriter, op) - @ccall mlir_c.mlirRewriterBaseFinalizeOpModification( - rewriter::MlirRewriterBase, op::MlirOperation - )::Cvoid + @ccall mlir_c.mlirRewriterBaseFinalizeOpModification(rewriter::MlirRewriterBase, op::MlirOperation)::Cvoid end """ @@ -9530,9 +8314,7 @@ end This method cancels a pending in-place modification. This can only be called on operations that were provided to a call to `startOpModification`. """ function mlirRewriterBaseCancelOpModification(rewriter, op) - @ccall mlir_c.mlirRewriterBaseCancelOpModification( - rewriter::MlirRewriterBase, op::MlirOperation - )::Cvoid + @ccall mlir_c.mlirRewriterBaseCancelOpModification(rewriter::MlirRewriterBase, op::MlirOperation)::Cvoid end """ @@ -9541,9 +8323,7 @@ end Find uses of `from` and replace them with `to`. Also notify the listener about every in-place op modification (for every use that was replaced). """ function mlirRewriterBaseReplaceAllUsesWith(rewriter, from, to) - @ccall mlir_c.mlirRewriterBaseReplaceAllUsesWith( - rewriter::MlirRewriterBase, from::MlirValue, to::MlirValue - )::Cvoid + @ccall mlir_c.mlirRewriterBaseReplaceAllUsesWith(rewriter::MlirRewriterBase, from::MlirValue, to::MlirValue)::Cvoid end """ @@ -9552,12 +8332,7 @@ end Find uses of `from` and replace them with `to`. Also notify the listener about every in-place op modification (for every use that was replaced). """ function mlirRewriterBaseReplaceAllValueRangeUsesWith(rewriter, nValues, from, to) - @ccall mlir_c.mlirRewriterBaseReplaceAllValueRangeUsesWith( - rewriter::MlirRewriterBase, - nValues::Cptrdiff_t, - from::Ptr{MlirValue}, - to::Ptr{MlirValue}, - )::Cvoid + @ccall mlir_c.mlirRewriterBaseReplaceAllValueRangeUsesWith(rewriter::MlirRewriterBase, nValues::Cptrdiff_t, from::Ptr{MlirValue}, to::Ptr{MlirValue})::Cvoid end """ @@ -9566,9 +8341,7 @@ end Find uses of `from` and replace them with `to`. Also notify the listener about every in-place op modification (for every use that was replaced) and that the `from` operation is about to be replaced. """ function mlirRewriterBaseReplaceAllOpUsesWithValueRange(rewriter, from, nTo, to) - @ccall mlir_c.mlirRewriterBaseReplaceAllOpUsesWithValueRange( - rewriter::MlirRewriterBase, from::MlirOperation, nTo::Cptrdiff_t, to::Ptr{MlirValue} - )::Cvoid + @ccall mlir_c.mlirRewriterBaseReplaceAllOpUsesWithValueRange(rewriter::MlirRewriterBase, from::MlirOperation, nTo::Cptrdiff_t, to::Ptr{MlirValue})::Cvoid end """ @@ -9577,9 +8350,7 @@ end Find uses of `from` and replace them with `to`. Also notify the listener about every in-place op modification (for every use that was replaced) and that the `from` operation is about to be replaced. """ function mlirRewriterBaseReplaceAllOpUsesWithOperation(rewriter, from, to) - @ccall mlir_c.mlirRewriterBaseReplaceAllOpUsesWithOperation( - rewriter::MlirRewriterBase, from::MlirOperation, to::MlirOperation - )::Cvoid + @ccall mlir_c.mlirRewriterBaseReplaceAllOpUsesWithOperation(rewriter::MlirRewriterBase, from::MlirOperation, to::MlirOperation)::Cvoid end """ @@ -9587,16 +8358,8 @@ end Find uses of `from` within `block` and replace them with `to`. Also notify the listener about every in-place op modification (for every use that was replaced). The optional `allUsesReplaced` flag is set to "true" if all uses were replaced. """ -function mlirRewriterBaseReplaceOpUsesWithinBlock( - rewriter, op, nNewValues, newValues, block -) - @ccall mlir_c.mlirRewriterBaseReplaceOpUsesWithinBlock( - rewriter::MlirRewriterBase, - op::MlirOperation, - nNewValues::Cptrdiff_t, - newValues::Ptr{MlirValue}, - block::MlirBlock, - )::Cvoid +function mlirRewriterBaseReplaceOpUsesWithinBlock(rewriter, op, nNewValues, newValues, block) + @ccall mlir_c.mlirRewriterBaseReplaceOpUsesWithinBlock(rewriter::MlirRewriterBase, op::MlirOperation, nNewValues::Cptrdiff_t, newValues::Ptr{MlirValue}, block::MlirBlock)::Cvoid end """ @@ -9605,12 +8368,7 @@ end Find uses of `from` and replace them with `to` except if the user is `exceptedUser`. Also notify the listener about every in-place op modification (for every use that was replaced). """ function mlirRewriterBaseReplaceAllUsesExcept(rewriter, from, to, exceptedUser) - @ccall mlir_c.mlirRewriterBaseReplaceAllUsesExcept( - rewriter::MlirRewriterBase, - from::MlirValue, - to::MlirValue, - exceptedUser::MlirOperation, - )::Cvoid + @ccall mlir_c.mlirRewriterBaseReplaceAllUsesExcept(rewriter::MlirRewriterBase, from::MlirValue, to::MlirValue, exceptedUser::MlirOperation)::Cvoid end """ @@ -9663,19 +8421,11 @@ function mlirFrozenRewritePatternSetDestroy(set) end function mlirApplyPatternsAndFoldGreedilyWithOp(op, patterns, arg3) - @ccall mlir_c.mlirApplyPatternsAndFoldGreedilyWithOp( - op::MlirOperation, - patterns::MlirFrozenRewritePatternSet, - arg3::MlirGreedyRewriteDriverConfig, - )::MlirLogicalResult + @ccall mlir_c.mlirApplyPatternsAndFoldGreedilyWithOp(op::MlirOperation, patterns::MlirFrozenRewritePatternSet, arg3::MlirGreedyRewriteDriverConfig)::MlirLogicalResult end function mlirApplyPatternsAndFoldGreedily(op, patterns, arg3) - @ccall mlir_c.mlirApplyPatternsAndFoldGreedily( - op::MlirModule, - patterns::MlirFrozenRewritePatternSet, - arg3::MlirGreedyRewriteDriverConfig, - )::MlirLogicalResult + @ccall mlir_c.mlirApplyPatternsAndFoldGreedily(op::MlirModule, patterns::MlirFrozenRewritePatternSet, arg3::MlirGreedyRewriteDriverConfig)::MlirLogicalResult end """ @@ -9757,28 +8507,12 @@ end Emits SMTLIB for the specified module using the provided callback and user data """ -function mlirTranslateModuleToSMTLIB( - arg1, arg2, userData, inlineSingleUseValues, indentLetBody -) - @ccall mlir_c.mlirTranslateModuleToSMTLIB( - arg1::MlirModule, - arg2::MlirStringCallback, - userData::Ptr{Cvoid}, - inlineSingleUseValues::Bool, - indentLetBody::Bool, - )::MlirLogicalResult +function mlirTranslateModuleToSMTLIB(arg1, arg2, userData, inlineSingleUseValues, indentLetBody) + @ccall mlir_c.mlirTranslateModuleToSMTLIB(arg1::MlirModule, arg2::MlirStringCallback, userData::Ptr{Cvoid}, inlineSingleUseValues::Bool, indentLetBody::Bool)::MlirLogicalResult end -function mlirTranslateOperationToSMTLIB( - arg1, arg2, userData, inlineSingleUseValues, indentLetBody -) - @ccall mlir_c.mlirTranslateOperationToSMTLIB( - arg1::MlirOperation, - arg2::MlirStringCallback, - userData::Ptr{Cvoid}, - inlineSingleUseValues::Bool, - indentLetBody::Bool, - )::MlirLogicalResult +function mlirTranslateOperationToSMTLIB(arg1, arg2, userData, inlineSingleUseValues, indentLetBody) + @ccall mlir_c.mlirTranslateOperationToSMTLIB(arg1::MlirOperation, arg2::MlirStringCallback, userData::Ptr{Cvoid}, inlineSingleUseValues::Bool, indentLetBody::Bool)::MlirLogicalResult end """ @@ -9980,9 +8714,7 @@ llvm::DbgRecord const LLVMDbgRecordRef = Ptr{LLVMOpaqueDbgRecord} function LLVMParseCommandLineOptions(argc, argv, Overview) - @ccall mlir_c.LLVMParseCommandLineOptions( - argc::Cint, argv::Ptr{Cstring}, Overview::Cstring - )::Cint + @ccall mlir_c.LLVMParseCommandLineOptions(argc::Cint, argv::Ptr{Cstring}, Overview::Cstring)::Cint end function LLVMSearchForAddressOfSymbol(symbolName) @@ -10002,9 +8734,7 @@ Translate operation that satisfies LLVM dialect module requirements into an LLVM the generated LLVM IR Module from the translated MLIR module, it is owned by the caller. """ function mlirTranslateModuleToLLVMIR(_module, context) - @ccall mlir_c.mlirTranslateModuleToLLVMIR( - _module::MlirOperation, context::LLVMContextRef - )::LLVMModuleRef + @ccall mlir_c.mlirTranslateModuleToLLVMIR(_module::MlirOperation, context::LLVMContextRef)::LLVMModuleRef end struct MlirTypeFromLLVMIRTranslator @@ -10017,9 +8747,7 @@ end Create an LLVM::TypeFromLLVMIRTranslator and transfer ownership to the caller. """ function mlirTypeFromLLVMIRTranslatorCreate(ctx) - @ccall mlir_c.mlirTypeFromLLVMIRTranslatorCreate( - ctx::MlirContext - )::MlirTypeFromLLVMIRTranslator + @ccall mlir_c.mlirTypeFromLLVMIRTranslatorCreate(ctx::MlirContext)::MlirTypeFromLLVMIRTranslator end """ @@ -10028,9 +8756,7 @@ end Takes an LLVM::TypeFromLLVMIRTranslator owned by the caller and destroys it. It is the responsibility of the user to only pass an LLVM::TypeFromLLVMIRTranslator class. """ function mlirTypeFromLLVMIRTranslatorDestroy(translator) - @ccall mlir_c.mlirTypeFromLLVMIRTranslatorDestroy( - translator::MlirTypeFromLLVMIRTranslator - )::Cvoid + @ccall mlir_c.mlirTypeFromLLVMIRTranslatorDestroy(translator::MlirTypeFromLLVMIRTranslator)::Cvoid end """ @@ -10039,9 +8765,7 @@ end Translates the given LLVM IR type to the MLIR LLVM dialect. """ function mlirTypeFromLLVMIRTranslatorTranslateType(translator, llvmType) - @ccall mlir_c.mlirTypeFromLLVMIRTranslatorTranslateType( - translator::MlirTypeFromLLVMIRTranslator, llvmType::LLVMTypeRef - )::MlirType + @ccall mlir_c.mlirTypeFromLLVMIRTranslatorTranslateType(translator::MlirTypeFromLLVMIRTranslator, llvmType::LLVMTypeRef)::MlirType end struct MlirTypeToLLVMIRTranslator @@ -10054,9 +8778,7 @@ end Create an LLVM::TypeToLLVMIRTranslator and transfer ownership to the caller. """ function mlirTypeToLLVMIRTranslatorCreate(ctx) - @ccall mlir_c.mlirTypeToLLVMIRTranslatorCreate( - ctx::LLVMContextRef - )::MlirTypeToLLVMIRTranslator + @ccall mlir_c.mlirTypeToLLVMIRTranslatorCreate(ctx::LLVMContextRef)::MlirTypeToLLVMIRTranslator end """ @@ -10065,9 +8787,7 @@ end Takes an LLVM::TypeToLLVMIRTranslator owned by the caller and destroys it. It is the responsibility of the user to only pass an LLVM::TypeToLLVMIRTranslator class. """ function mlirTypeToLLVMIRTranslatorDestroy(translator) - @ccall mlir_c.mlirTypeToLLVMIRTranslatorDestroy( - translator::MlirTypeToLLVMIRTranslator - )::Cvoid + @ccall mlir_c.mlirTypeToLLVMIRTranslatorDestroy(translator::MlirTypeToLLVMIRTranslator)::Cvoid end """ @@ -10076,39 +8796,11 @@ end Translates the given MLIR LLVM dialect to the LLVM IR type. """ function mlirTypeToLLVMIRTranslatorTranslateType(translator, mlirType) - @ccall mlir_c.mlirTypeToLLVMIRTranslatorTranslateType( - translator::MlirTypeToLLVMIRTranslator, mlirType::MlirType - )::LLVMTypeRef + @ccall mlir_c.mlirTypeToLLVMIRTranslatorTranslateType(translator::MlirTypeToLLVMIRTranslator, mlirType::MlirType)::LLVMTypeRef end -function stablehloScatterDimensionNumbersGet( - ctx, - nUpdateWindowDims, - updateWindowDims, - nInsertedWindowDims, - insertedWindowDims, - nInputBatchingDims, - inputBatchingDims, - nScatterIndicesBatchingDims, - scatterIndicesBatchingDims, - nScatteredDimsToOperandDims, - scatteredDimsToOperandDims, - indexVectorDim, -) - @ccall mlir_c.stablehloScatterDimensionNumbersGet( - ctx::MlirContext, - nUpdateWindowDims::Cptrdiff_t, - updateWindowDims::Ptr{Int64}, - nInsertedWindowDims::Cptrdiff_t, - insertedWindowDims::Ptr{Int64}, - nInputBatchingDims::Cptrdiff_t, - inputBatchingDims::Ptr{Int64}, - nScatterIndicesBatchingDims::Cptrdiff_t, - scatterIndicesBatchingDims::Ptr{Int64}, - nScatteredDimsToOperandDims::Cptrdiff_t, - scatteredDimsToOperandDims::Ptr{Int64}, - indexVectorDim::Int64, - )::MlirAttribute +function stablehloScatterDimensionNumbersGet(ctx, nUpdateWindowDims, updateWindowDims, nInsertedWindowDims, insertedWindowDims, nInputBatchingDims, inputBatchingDims, nScatterIndicesBatchingDims, scatterIndicesBatchingDims, nScatteredDimsToOperandDims, scatteredDimsToOperandDims, indexVectorDim) + @ccall mlir_c.stablehloScatterDimensionNumbersGet(ctx::MlirContext, nUpdateWindowDims::Cptrdiff_t, updateWindowDims::Ptr{Int64}, nInsertedWindowDims::Cptrdiff_t, insertedWindowDims::Ptr{Int64}, nInputBatchingDims::Cptrdiff_t, inputBatchingDims::Ptr{Int64}, nScatterIndicesBatchingDims::Cptrdiff_t, scatterIndicesBatchingDims::Ptr{Int64}, nScatteredDimsToOperandDims::Cptrdiff_t, scatteredDimsToOperandDims::Ptr{Int64}, indexVectorDim::Int64)::MlirAttribute end function stablehloAttributeIsAScatterDimensionNumbers(attr) @@ -10116,97 +8808,51 @@ function stablehloAttributeIsAScatterDimensionNumbers(attr) end function stablehloScatterDimensionNumbersGetUpdateWindowDimsSize(attr) - @ccall mlir_c.stablehloScatterDimensionNumbersGetUpdateWindowDimsSize( - attr::MlirAttribute - )::Cptrdiff_t + @ccall mlir_c.stablehloScatterDimensionNumbersGetUpdateWindowDimsSize(attr::MlirAttribute)::Cptrdiff_t end function stablehloScatterDimensionNumbersGetUpdateWindowDimsElem(attr, pos) - @ccall mlir_c.stablehloScatterDimensionNumbersGetUpdateWindowDimsElem( - attr::MlirAttribute, pos::Cptrdiff_t - )::Int64 + @ccall mlir_c.stablehloScatterDimensionNumbersGetUpdateWindowDimsElem(attr::MlirAttribute, pos::Cptrdiff_t)::Int64 end function stablehloScatterDimensionNumbersGetInsertedWindowDimsSize(attr) - @ccall mlir_c.stablehloScatterDimensionNumbersGetInsertedWindowDimsSize( - attr::MlirAttribute - )::Cptrdiff_t + @ccall mlir_c.stablehloScatterDimensionNumbersGetInsertedWindowDimsSize(attr::MlirAttribute)::Cptrdiff_t end function stablehloScatterDimensionNumbersGetInsertedWindowDimsElem(attr, pos) - @ccall mlir_c.stablehloScatterDimensionNumbersGetInsertedWindowDimsElem( - attr::MlirAttribute, pos::Cptrdiff_t - )::Int64 + @ccall mlir_c.stablehloScatterDimensionNumbersGetInsertedWindowDimsElem(attr::MlirAttribute, pos::Cptrdiff_t)::Int64 end function stablehloScatterDimensionNumbersGetInputBatchingDimsSize(attr) - @ccall mlir_c.stablehloScatterDimensionNumbersGetInputBatchingDimsSize( - attr::MlirAttribute - )::Cptrdiff_t + @ccall mlir_c.stablehloScatterDimensionNumbersGetInputBatchingDimsSize(attr::MlirAttribute)::Cptrdiff_t end function stablehloScatterDimensionNumbersGetInputBatchingDimsElem(attr, pos) - @ccall mlir_c.stablehloScatterDimensionNumbersGetInputBatchingDimsElem( - attr::MlirAttribute, pos::Cptrdiff_t - )::Int64 + @ccall mlir_c.stablehloScatterDimensionNumbersGetInputBatchingDimsElem(attr::MlirAttribute, pos::Cptrdiff_t)::Int64 end function stablehloScatterDimensionNumbersGetScatterIndicesBatchingDimsSize(attr) - @ccall mlir_c.stablehloScatterDimensionNumbersGetScatterIndicesBatchingDimsSize( - attr::MlirAttribute - )::Cptrdiff_t + @ccall mlir_c.stablehloScatterDimensionNumbersGetScatterIndicesBatchingDimsSize(attr::MlirAttribute)::Cptrdiff_t end function stablehloScatterDimensionNumbersGetScatterIndicesBatchingDimsElem(attr, pos) - @ccall mlir_c.stablehloScatterDimensionNumbersGetScatterIndicesBatchingDimsElem( - attr::MlirAttribute, pos::Cptrdiff_t - )::Int64 + @ccall mlir_c.stablehloScatterDimensionNumbersGetScatterIndicesBatchingDimsElem(attr::MlirAttribute, pos::Cptrdiff_t)::Int64 end function stablehloScatterDimensionNumbersGetScatteredDimsToOperandDimsSize(attr) - @ccall mlir_c.stablehloScatterDimensionNumbersGetScatteredDimsToOperandDimsSize( - attr::MlirAttribute - )::Cptrdiff_t + @ccall mlir_c.stablehloScatterDimensionNumbersGetScatteredDimsToOperandDimsSize(attr::MlirAttribute)::Cptrdiff_t end function stablehloScatterDimensionNumbersGetScatteredDimsToOperandDimsElem(attr, pos) - @ccall mlir_c.stablehloScatterDimensionNumbersGetScatteredDimsToOperandDimsElem( - attr::MlirAttribute, pos::Cptrdiff_t - )::Int64 + @ccall mlir_c.stablehloScatterDimensionNumbersGetScatteredDimsToOperandDimsElem(attr::MlirAttribute, pos::Cptrdiff_t)::Int64 end function stablehloDimensionNumbersGetIndexVectorDim(attr) @ccall mlir_c.stablehloDimensionNumbersGetIndexVectorDim(attr::MlirAttribute)::Int64 end -function stablehloGatherDimensionNumbersGet( - ctx, - nOffsetDims, - offsetDims, - nCollapsedSliceDims, - collapsedSliceDims, - nOperandBatchingDims, - operandBatchingDims, - nStartIndicesBatchingDims, - startIndicesBatchingDims, - nStartIndexMap, - startIndexMap, - indexVectorDim, -) - @ccall mlir_c.stablehloGatherDimensionNumbersGet( - ctx::MlirContext, - nOffsetDims::Cptrdiff_t, - offsetDims::Ptr{Int64}, - nCollapsedSliceDims::Cptrdiff_t, - collapsedSliceDims::Ptr{Int64}, - nOperandBatchingDims::Cptrdiff_t, - operandBatchingDims::Ptr{Int64}, - nStartIndicesBatchingDims::Cptrdiff_t, - startIndicesBatchingDims::Ptr{Int64}, - nStartIndexMap::Cptrdiff_t, - startIndexMap::Ptr{Int64}, - indexVectorDim::Int64, - )::MlirAttribute +function stablehloGatherDimensionNumbersGet(ctx, nOffsetDims, offsetDims, nCollapsedSliceDims, collapsedSliceDims, nOperandBatchingDims, operandBatchingDims, nStartIndicesBatchingDims, startIndicesBatchingDims, nStartIndexMap, startIndexMap, indexVectorDim) + @ccall mlir_c.stablehloGatherDimensionNumbersGet(ctx::MlirContext, nOffsetDims::Cptrdiff_t, offsetDims::Ptr{Int64}, nCollapsedSliceDims::Cptrdiff_t, collapsedSliceDims::Ptr{Int64}, nOperandBatchingDims::Cptrdiff_t, operandBatchingDims::Ptr{Int64}, nStartIndicesBatchingDims::Cptrdiff_t, startIndicesBatchingDims::Ptr{Int64}, nStartIndexMap::Cptrdiff_t, startIndexMap::Ptr{Int64}, indexVectorDim::Int64)::MlirAttribute end function stablehloAttributeIsAGatherDimensionNumbers(attr) @@ -10214,91 +8860,51 @@ function stablehloAttributeIsAGatherDimensionNumbers(attr) end function stablehloGatherDimensionNumbersGetOffsetDimsSize(attr) - @ccall mlir_c.stablehloGatherDimensionNumbersGetOffsetDimsSize( - attr::MlirAttribute - )::Cptrdiff_t + @ccall mlir_c.stablehloGatherDimensionNumbersGetOffsetDimsSize(attr::MlirAttribute)::Cptrdiff_t end function stablehloGatherDimensionNumbersGetOffsetDimsElem(attr, pos) - @ccall mlir_c.stablehloGatherDimensionNumbersGetOffsetDimsElem( - attr::MlirAttribute, pos::Cptrdiff_t - )::Int64 + @ccall mlir_c.stablehloGatherDimensionNumbersGetOffsetDimsElem(attr::MlirAttribute, pos::Cptrdiff_t)::Int64 end function stablehloGatherDimensionNumbersGetCollapsedSliceDimsSize(attr) - @ccall mlir_c.stablehloGatherDimensionNumbersGetCollapsedSliceDimsSize( - attr::MlirAttribute - )::Cptrdiff_t + @ccall mlir_c.stablehloGatherDimensionNumbersGetCollapsedSliceDimsSize(attr::MlirAttribute)::Cptrdiff_t end function stablehloGatherDimensionNumbersGetCollapsedSliceDimsElem(attr, pos) - @ccall mlir_c.stablehloGatherDimensionNumbersGetCollapsedSliceDimsElem( - attr::MlirAttribute, pos::Cptrdiff_t - )::Int64 + @ccall mlir_c.stablehloGatherDimensionNumbersGetCollapsedSliceDimsElem(attr::MlirAttribute, pos::Cptrdiff_t)::Int64 end function stablehloGatherDimensionNumbersGetOperandBatchingDimsSize(attr) - @ccall mlir_c.stablehloGatherDimensionNumbersGetOperandBatchingDimsSize( - attr::MlirAttribute - )::Cptrdiff_t + @ccall mlir_c.stablehloGatherDimensionNumbersGetOperandBatchingDimsSize(attr::MlirAttribute)::Cptrdiff_t end function stablehloGatherDimensionNumbersGetOperandBatchingDimsElem(attr, pos) - @ccall mlir_c.stablehloGatherDimensionNumbersGetOperandBatchingDimsElem( - attr::MlirAttribute, pos::Cptrdiff_t - )::Int64 + @ccall mlir_c.stablehloGatherDimensionNumbersGetOperandBatchingDimsElem(attr::MlirAttribute, pos::Cptrdiff_t)::Int64 end function stablehloGatherDimensionNumbersGetStartIndicesBatchingDimsSize(attr) - @ccall mlir_c.stablehloGatherDimensionNumbersGetStartIndicesBatchingDimsSize( - attr::MlirAttribute - )::Cptrdiff_t + @ccall mlir_c.stablehloGatherDimensionNumbersGetStartIndicesBatchingDimsSize(attr::MlirAttribute)::Cptrdiff_t end function stablehloGatherDimensionNumbersGetStartIndicesBatchingDimsElem(attr, pos) - @ccall mlir_c.stablehloGatherDimensionNumbersGetStartIndicesBatchingDimsElem( - attr::MlirAttribute, pos::Cptrdiff_t - )::Int64 + @ccall mlir_c.stablehloGatherDimensionNumbersGetStartIndicesBatchingDimsElem(attr::MlirAttribute, pos::Cptrdiff_t)::Int64 end function stablehloGatherDimensionNumbersGetStartIndexMapSize(attr) - @ccall mlir_c.stablehloGatherDimensionNumbersGetStartIndexMapSize( - attr::MlirAttribute - )::Cptrdiff_t + @ccall mlir_c.stablehloGatherDimensionNumbersGetStartIndexMapSize(attr::MlirAttribute)::Cptrdiff_t end function stablehloGatherDimensionNumbersGetStartIndexMapElem(attr, pos) - @ccall mlir_c.stablehloGatherDimensionNumbersGetStartIndexMapElem( - attr::MlirAttribute, pos::Cptrdiff_t - )::Int64 + @ccall mlir_c.stablehloGatherDimensionNumbersGetStartIndexMapElem(attr::MlirAttribute, pos::Cptrdiff_t)::Int64 end function stablehloGatherDimensionNumbersGetIndexVectorDim(attr) - @ccall mlir_c.stablehloGatherDimensionNumbersGetIndexVectorDim( - attr::MlirAttribute - )::Int64 + @ccall mlir_c.stablehloGatherDimensionNumbersGetIndexVectorDim(attr::MlirAttribute)::Int64 end -function stablehloDotAlgorithmGet( - ctx, - lhsPrecisionType, - rhsPrecisionType, - accumulationType, - lhsComponentCount, - rhsComponentCount, - numPrimitiveOperations, - allowImpreciseAccumulation, -) - @ccall mlir_c.stablehloDotAlgorithmGet( - ctx::MlirContext, - lhsPrecisionType::MlirType, - rhsPrecisionType::MlirType, - accumulationType::MlirType, - lhsComponentCount::Int64, - rhsComponentCount::Int64, - numPrimitiveOperations::Int64, - allowImpreciseAccumulation::Bool, - )::MlirAttribute +function stablehloDotAlgorithmGet(ctx, lhsPrecisionType, rhsPrecisionType, accumulationType, lhsComponentCount, rhsComponentCount, numPrimitiveOperations, allowImpreciseAccumulation) + @ccall mlir_c.stablehloDotAlgorithmGet(ctx::MlirContext, lhsPrecisionType::MlirType, rhsPrecisionType::MlirType, accumulationType::MlirType, lhsComponentCount::Int64, rhsComponentCount::Int64, numPrimitiveOperations::Int64, allowImpreciseAccumulation::Bool)::MlirAttribute end function stablehloAttributeIsADotAlgorithm(attr) @@ -10330,33 +8936,11 @@ function stablehloDotAlgorithmGetNumPrimitiveOperations(attr) end function stablehloDotAlgorithmGetAllowImpreciseAccumulation(attr) - @ccall mlir_c.stablehloDotAlgorithmGetAllowImpreciseAccumulation( - attr::MlirAttribute - )::Bool + @ccall mlir_c.stablehloDotAlgorithmGetAllowImpreciseAccumulation(attr::MlirAttribute)::Bool end -function stablehloDotDimensionNumbersGet( - ctx, - nLhsBatchingDimensions, - lhsBatchingDimensions, - nRhsBatchingDimensions, - rhsBatchingDimensions, - nLhsContractingDimensions, - lhsContractingDimensions, - nRhsContractingDimensions, - rhsContractingDimensions, -) - @ccall mlir_c.stablehloDotDimensionNumbersGet( - ctx::MlirContext, - nLhsBatchingDimensions::Cptrdiff_t, - lhsBatchingDimensions::Ptr{Int64}, - nRhsBatchingDimensions::Cptrdiff_t, - rhsBatchingDimensions::Ptr{Int64}, - nLhsContractingDimensions::Cptrdiff_t, - lhsContractingDimensions::Ptr{Int64}, - nRhsContractingDimensions::Cptrdiff_t, - rhsContractingDimensions::Ptr{Int64}, - )::MlirAttribute +function stablehloDotDimensionNumbersGet(ctx, nLhsBatchingDimensions, lhsBatchingDimensions, nRhsBatchingDimensions, rhsBatchingDimensions, nLhsContractingDimensions, lhsContractingDimensions, nRhsContractingDimensions, rhsContractingDimensions) + @ccall mlir_c.stablehloDotDimensionNumbersGet(ctx::MlirContext, nLhsBatchingDimensions::Cptrdiff_t, lhsBatchingDimensions::Ptr{Int64}, nRhsBatchingDimensions::Cptrdiff_t, rhsBatchingDimensions::Ptr{Int64}, nLhsContractingDimensions::Cptrdiff_t, lhsContractingDimensions::Ptr{Int64}, nRhsContractingDimensions::Cptrdiff_t, rhsContractingDimensions::Ptr{Int64})::MlirAttribute end function stablehloAttributeIsADotDimensionNumbers(attr) @@ -10364,83 +8948,39 @@ function stablehloAttributeIsADotDimensionNumbers(attr) end function stablehloDotDimensionNumbersGetLhsBatchingDimensionsSize(attr) - @ccall mlir_c.stablehloDotDimensionNumbersGetLhsBatchingDimensionsSize( - attr::MlirAttribute - )::Cptrdiff_t + @ccall mlir_c.stablehloDotDimensionNumbersGetLhsBatchingDimensionsSize(attr::MlirAttribute)::Cptrdiff_t end function stablehloDotDimensionNumbersGetLhsBatchingDimensionsElem(attr, pos) - @ccall mlir_c.stablehloDotDimensionNumbersGetLhsBatchingDimensionsElem( - attr::MlirAttribute, pos::Cptrdiff_t - )::Int64 + @ccall mlir_c.stablehloDotDimensionNumbersGetLhsBatchingDimensionsElem(attr::MlirAttribute, pos::Cptrdiff_t)::Int64 end function stablehloDotDimensionNumbersGetRhsBatchingDimensionsSize(attr) - @ccall mlir_c.stablehloDotDimensionNumbersGetRhsBatchingDimensionsSize( - attr::MlirAttribute - )::Cptrdiff_t + @ccall mlir_c.stablehloDotDimensionNumbersGetRhsBatchingDimensionsSize(attr::MlirAttribute)::Cptrdiff_t end function stablehloDotDimensionNumbersGetRhsBatchingDimensionsElem(attr, pos) - @ccall mlir_c.stablehloDotDimensionNumbersGetRhsBatchingDimensionsElem( - attr::MlirAttribute, pos::Cptrdiff_t - )::Int64 + @ccall mlir_c.stablehloDotDimensionNumbersGetRhsBatchingDimensionsElem(attr::MlirAttribute, pos::Cptrdiff_t)::Int64 end function stablehloDotDimensionNumbersGetLhsContractingDimensionsSize(attr) - @ccall mlir_c.stablehloDotDimensionNumbersGetLhsContractingDimensionsSize( - attr::MlirAttribute - )::Cptrdiff_t + @ccall mlir_c.stablehloDotDimensionNumbersGetLhsContractingDimensionsSize(attr::MlirAttribute)::Cptrdiff_t end function stablehloDotDimensionNumbersGetLhsContractingDimensionsElem(attr, pos) - @ccall mlir_c.stablehloDotDimensionNumbersGetLhsContractingDimensionsElem( - attr::MlirAttribute, pos::Cptrdiff_t - )::Int64 + @ccall mlir_c.stablehloDotDimensionNumbersGetLhsContractingDimensionsElem(attr::MlirAttribute, pos::Cptrdiff_t)::Int64 end function stablehloDotDimensionNumbersGetRhsContractingDimensionsSize(attr) - @ccall mlir_c.stablehloDotDimensionNumbersGetRhsContractingDimensionsSize( - attr::MlirAttribute - )::Cptrdiff_t + @ccall mlir_c.stablehloDotDimensionNumbersGetRhsContractingDimensionsSize(attr::MlirAttribute)::Cptrdiff_t end function stablehloDotDimensionNumbersGetRhsContractingDimensionsElem(attr, pos) - @ccall mlir_c.stablehloDotDimensionNumbersGetRhsContractingDimensionsElem( - attr::MlirAttribute, pos::Cptrdiff_t - )::Int64 + @ccall mlir_c.stablehloDotDimensionNumbersGetRhsContractingDimensionsElem(attr::MlirAttribute, pos::Cptrdiff_t)::Int64 end -function stablehloConvDimensionNumbersGet( - ctx, - inputBatchDimension, - inputFeatureDimension, - nInputSpatialDimensions, - inputSpatialDimensions, - kernelInputFeatureDimension, - kernelOutputFeatureDimension, - nKernelSpatialDimensions, - kernelSpatialDimensions, - outputBatchDimension, - outputFeatureDimension, - nOutputSpatialDimensions, - outputSpatialDimensions, -) - @ccall mlir_c.stablehloConvDimensionNumbersGet( - ctx::MlirContext, - inputBatchDimension::Int64, - inputFeatureDimension::Int64, - nInputSpatialDimensions::Cptrdiff_t, - inputSpatialDimensions::Ptr{Int64}, - kernelInputFeatureDimension::Int64, - kernelOutputFeatureDimension::Int64, - nKernelSpatialDimensions::Cptrdiff_t, - kernelSpatialDimensions::Ptr{Int64}, - outputBatchDimension::Int64, - outputFeatureDimension::Int64, - nOutputSpatialDimensions::Cptrdiff_t, - outputSpatialDimensions::Ptr{Int64}, - )::MlirAttribute +function stablehloConvDimensionNumbersGet(ctx, inputBatchDimension, inputFeatureDimension, nInputSpatialDimensions, inputSpatialDimensions, kernelInputFeatureDimension, kernelOutputFeatureDimension, nKernelSpatialDimensions, kernelSpatialDimensions, outputBatchDimension, outputFeatureDimension, nOutputSpatialDimensions, outputSpatialDimensions) + @ccall mlir_c.stablehloConvDimensionNumbersGet(ctx::MlirContext, inputBatchDimension::Int64, inputFeatureDimension::Int64, nInputSpatialDimensions::Cptrdiff_t, inputSpatialDimensions::Ptr{Int64}, kernelInputFeatureDimension::Int64, kernelOutputFeatureDimension::Int64, nKernelSpatialDimensions::Cptrdiff_t, kernelSpatialDimensions::Ptr{Int64}, outputBatchDimension::Int64, outputFeatureDimension::Int64, nOutputSpatialDimensions::Cptrdiff_t, outputSpatialDimensions::Ptr{Int64})::MlirAttribute end function stablehloAttributeIsAConvDimensionNumbers(attr) @@ -10448,93 +8988,55 @@ function stablehloAttributeIsAConvDimensionNumbers(attr) end function stablehloConvDimensionNumbersGetInputBatchDimension(attr) - @ccall mlir_c.stablehloConvDimensionNumbersGetInputBatchDimension( - attr::MlirAttribute - )::Int64 + @ccall mlir_c.stablehloConvDimensionNumbersGetInputBatchDimension(attr::MlirAttribute)::Int64 end function stablehloConvDimensionNumbersGetInputFeatureDimension(attr) - @ccall mlir_c.stablehloConvDimensionNumbersGetInputFeatureDimension( - attr::MlirAttribute - )::Int64 + @ccall mlir_c.stablehloConvDimensionNumbersGetInputFeatureDimension(attr::MlirAttribute)::Int64 end function stablehloConvDimensionNumbersGetInputSpatialDimensionsSize(attr) - @ccall mlir_c.stablehloConvDimensionNumbersGetInputSpatialDimensionsSize( - attr::MlirAttribute - )::Cptrdiff_t + @ccall mlir_c.stablehloConvDimensionNumbersGetInputSpatialDimensionsSize(attr::MlirAttribute)::Cptrdiff_t end function stablehloConvDimensionNumbersGetInputSpatialDimensionsElem(attr, pos) - @ccall mlir_c.stablehloConvDimensionNumbersGetInputSpatialDimensionsElem( - attr::MlirAttribute, pos::Cptrdiff_t - )::Int64 + @ccall mlir_c.stablehloConvDimensionNumbersGetInputSpatialDimensionsElem(attr::MlirAttribute, pos::Cptrdiff_t)::Int64 end function stablehloConvDimensionNumbersGetKernelInputFeatureDimension(attr) - @ccall mlir_c.stablehloConvDimensionNumbersGetKernelInputFeatureDimension( - attr::MlirAttribute - )::Int64 + @ccall mlir_c.stablehloConvDimensionNumbersGetKernelInputFeatureDimension(attr::MlirAttribute)::Int64 end function stablehloConvDimensionNumbersGetKernelOutputFeatureDimension(attr) - @ccall mlir_c.stablehloConvDimensionNumbersGetKernelOutputFeatureDimension( - attr::MlirAttribute - )::Int64 + @ccall mlir_c.stablehloConvDimensionNumbersGetKernelOutputFeatureDimension(attr::MlirAttribute)::Int64 end function stablehloConvDimensionNumbersGetKernelSpatialDimensionsSize(attr) - @ccall mlir_c.stablehloConvDimensionNumbersGetKernelSpatialDimensionsSize( - attr::MlirAttribute - )::Cptrdiff_t + @ccall mlir_c.stablehloConvDimensionNumbersGetKernelSpatialDimensionsSize(attr::MlirAttribute)::Cptrdiff_t end function stablehloConvDimensionNumbersGetKernelSpatialDimensionsElem(attr, pos) - @ccall mlir_c.stablehloConvDimensionNumbersGetKernelSpatialDimensionsElem( - attr::MlirAttribute, pos::Cptrdiff_t - )::Int64 + @ccall mlir_c.stablehloConvDimensionNumbersGetKernelSpatialDimensionsElem(attr::MlirAttribute, pos::Cptrdiff_t)::Int64 end function stablehloConvDimensionNumbersGetOutputBatchDimension(attr) - @ccall mlir_c.stablehloConvDimensionNumbersGetOutputBatchDimension( - attr::MlirAttribute - )::Int64 + @ccall mlir_c.stablehloConvDimensionNumbersGetOutputBatchDimension(attr::MlirAttribute)::Int64 end function stablehloConvDimensionNumbersGetOutputFeatureDimension(attr) - @ccall mlir_c.stablehloConvDimensionNumbersGetOutputFeatureDimension( - attr::MlirAttribute - )::Int64 + @ccall mlir_c.stablehloConvDimensionNumbersGetOutputFeatureDimension(attr::MlirAttribute)::Int64 end function stablehloConvDimensionNumbersGetOutputSpatialDimensionsSize(attr) - @ccall mlir_c.stablehloConvDimensionNumbersGetOutputSpatialDimensionsSize( - attr::MlirAttribute - )::Cptrdiff_t + @ccall mlir_c.stablehloConvDimensionNumbersGetOutputSpatialDimensionsSize(attr::MlirAttribute)::Cptrdiff_t end function stablehloConvDimensionNumbersGetOutputSpatialDimensionsElem(attr, pos) - @ccall mlir_c.stablehloConvDimensionNumbersGetOutputSpatialDimensionsElem( - attr::MlirAttribute, pos::Cptrdiff_t - )::Int64 + @ccall mlir_c.stablehloConvDimensionNumbersGetOutputSpatialDimensionsElem(attr::MlirAttribute, pos::Cptrdiff_t)::Int64 end -function stablehloOutputOperandAliasGet( - ctx, - nOutputTupleIndices, - outputTupleIndices, - operandIndex, - nOperandTupleIndices, - operandTupleIndices, -) - @ccall mlir_c.stablehloOutputOperandAliasGet( - ctx::MlirContext, - nOutputTupleIndices::Cptrdiff_t, - outputTupleIndices::Ptr{Int64}, - operandIndex::Int64, - nOperandTupleIndices::Cptrdiff_t, - operandTupleIndices::Ptr{Int64}, - )::MlirAttribute +function stablehloOutputOperandAliasGet(ctx, nOutputTupleIndices, outputTupleIndices, operandIndex, nOperandTupleIndices, operandTupleIndices) + @ccall mlir_c.stablehloOutputOperandAliasGet(ctx::MlirContext, nOutputTupleIndices::Cptrdiff_t, outputTupleIndices::Ptr{Int64}, operandIndex::Int64, nOperandTupleIndices::Cptrdiff_t, operandTupleIndices::Ptr{Int64})::MlirAttribute end function stablehloAttributeIsAOutputOperandAlias(attr) @@ -10542,15 +9044,11 @@ function stablehloAttributeIsAOutputOperandAlias(attr) end function stablehloOutputOperandAliasGetOutputTupleIndicesSize(attr) - @ccall mlir_c.stablehloOutputOperandAliasGetOutputTupleIndicesSize( - attr::MlirAttribute - )::Cptrdiff_t + @ccall mlir_c.stablehloOutputOperandAliasGetOutputTupleIndicesSize(attr::MlirAttribute)::Cptrdiff_t end function stablehloOutputOperandAliasGetOutputTupleIndicesElem(attr, pos) - @ccall mlir_c.stablehloOutputOperandAliasGetOutputTupleIndicesElem( - attr::MlirAttribute, pos::Cptrdiff_t - )::Int64 + @ccall mlir_c.stablehloOutputOperandAliasGetOutputTupleIndicesElem(attr::MlirAttribute, pos::Cptrdiff_t)::Int64 end function stablehloOutputOperandAliasGetOperandIndex(attr) @@ -10558,21 +9056,15 @@ function stablehloOutputOperandAliasGetOperandIndex(attr) end function stablehloOutputOperandAliasGetOperandTupleIndicesSize(attr) - @ccall mlir_c.stablehloOutputOperandAliasGetOperandTupleIndicesSize( - attr::MlirAttribute - )::Cptrdiff_t + @ccall mlir_c.stablehloOutputOperandAliasGetOperandTupleIndicesSize(attr::MlirAttribute)::Cptrdiff_t end function stablehloOutputOperandAliasGetOperandTupleIndicesElem(attr, pos) - @ccall mlir_c.stablehloOutputOperandAliasGetOperandTupleIndicesElem( - attr::MlirAttribute, pos::Cptrdiff_t - )::Int64 + @ccall mlir_c.stablehloOutputOperandAliasGetOperandTupleIndicesElem(attr::MlirAttribute, pos::Cptrdiff_t)::Int64 end function stablehloComparisonDirectionAttrGet(ctx, value) - @ccall mlir_c.stablehloComparisonDirectionAttrGet( - ctx::MlirContext, value::MlirStringRef - )::MlirAttribute + @ccall mlir_c.stablehloComparisonDirectionAttrGet(ctx::MlirContext, value::MlirStringRef)::MlirAttribute end function stablehloAttributeIsAComparisonDirectionAttr(attr) @@ -10580,15 +9072,11 @@ function stablehloAttributeIsAComparisonDirectionAttr(attr) end function stablehloComparisonDirectionAttrGetValue(attr) - @ccall mlir_c.stablehloComparisonDirectionAttrGetValue( - attr::MlirAttribute - )::MlirStringRef + @ccall mlir_c.stablehloComparisonDirectionAttrGetValue(attr::MlirAttribute)::MlirStringRef end function stablehloComparisonTypeAttrGet(ctx, value) - @ccall mlir_c.stablehloComparisonTypeAttrGet( - ctx::MlirContext, value::MlirStringRef - )::MlirAttribute + @ccall mlir_c.stablehloComparisonTypeAttrGet(ctx::MlirContext, value::MlirStringRef)::MlirAttribute end function stablehloAttributeIsAComparisonTypeAttr(attr) @@ -10600,9 +9088,7 @@ function stablehloComparisonTypeAttrGetValue(attr) end function stablehloPrecisionAttrGet(ctx, value) - @ccall mlir_c.stablehloPrecisionAttrGet( - ctx::MlirContext, value::MlirStringRef - )::MlirAttribute + @ccall mlir_c.stablehloPrecisionAttrGet(ctx::MlirContext, value::MlirStringRef)::MlirAttribute end function stablehloAttributeIsAPrecisionAttr(attr) @@ -10614,9 +9100,7 @@ function stablehloPrecisionAttrGetValue(attr) end function stablehloFftTypeAttrGet(ctx, value) - @ccall mlir_c.stablehloFftTypeAttrGet( - ctx::MlirContext, value::MlirStringRef - )::MlirAttribute + @ccall mlir_c.stablehloFftTypeAttrGet(ctx::MlirContext, value::MlirStringRef)::MlirAttribute end function stablehloAttributeIsAFftTypeAttr(attr) @@ -10628,9 +9112,7 @@ function stablehloFftTypeAttrGetValue(attr) end function stablehloTransposeAttrGet(ctx, value) - @ccall mlir_c.stablehloTransposeAttrGet( - ctx::MlirContext, value::MlirStringRef - )::MlirAttribute + @ccall mlir_c.stablehloTransposeAttrGet(ctx::MlirContext, value::MlirStringRef)::MlirAttribute end function stablehloAttributeIsATransposeAttr(attr) @@ -10642,9 +9124,7 @@ function stablehloTransposeAttrGetValue(attr) end function stablehloRngDistributionAttrGet(ctx, value) - @ccall mlir_c.stablehloRngDistributionAttrGet( - ctx::MlirContext, value::MlirStringRef - )::MlirAttribute + @ccall mlir_c.stablehloRngDistributionAttrGet(ctx::MlirContext, value::MlirStringRef)::MlirAttribute end function stablehloAttributeIsARngDistributionAttr(attr) @@ -10656,9 +9136,7 @@ function stablehloRngDistributionAttrGetValue(attr) end function stablehloRngAlgorithmAttrGet(ctx, value) - @ccall mlir_c.stablehloRngAlgorithmAttrGet( - ctx::MlirContext, value::MlirStringRef - )::MlirAttribute + @ccall mlir_c.stablehloRngAlgorithmAttrGet(ctx::MlirContext, value::MlirStringRef)::MlirAttribute end function stablehloAttributeIsARngAlgorithmAttr(attr) @@ -10670,9 +9148,7 @@ function stablehloRngAlgorithmAttrGetValue(attr) end function stablehloChannelHandleGet(ctx, handle, type) - @ccall mlir_c.stablehloChannelHandleGet( - ctx::MlirContext, handle::Int64, type::Int64 - )::MlirAttribute + @ccall mlir_c.stablehloChannelHandleGet(ctx::MlirContext, handle::Int64, type::Int64)::MlirAttribute end function stablehloAttributeIsChannelHandle(attr) @@ -10688,9 +9164,7 @@ function stablehloChannelHandleGetType(attr) end function stablehloTypeExtensionsGet(ctx, nBounds, bounds) - @ccall mlir_c.stablehloTypeExtensionsGet( - ctx::MlirContext, nBounds::Cptrdiff_t, bounds::Ptr{Int64} - )::MlirAttribute + @ccall mlir_c.stablehloTypeExtensionsGet(ctx::MlirContext, nBounds::Cptrdiff_t, bounds::Ptr{Int64})::MlirAttribute end function stablehloAttributeIsTypeExtensions(attr) @@ -10702,15 +9176,11 @@ function stablehloTypeExtensionsGetBoundsSize(attr) end function stablehloTypeExtensionsGetBoundsElem(attr, pos) - @ccall mlir_c.stablehloTypeExtensionsGetBoundsElem( - attr::MlirAttribute, pos::Cptrdiff_t - )::Int64 + @ccall mlir_c.stablehloTypeExtensionsGetBoundsElem(attr::MlirAttribute, pos::Cptrdiff_t)::Int64 end function stablehloResultAccuracyModeAttrGet(ctx, value) - @ccall mlir_c.stablehloResultAccuracyModeAttrGet( - ctx::MlirContext, value::MlirStringRef - )::MlirAttribute + @ccall mlir_c.stablehloResultAccuracyModeAttrGet(ctx::MlirContext, value::MlirStringRef)::MlirAttribute end function stablehloAttributeIsAResultAccuracyModeAttr(attr) @@ -10718,15 +9188,11 @@ function stablehloAttributeIsAResultAccuracyModeAttr(attr) end function stablehloResultAccuracyModeAttrGetValue(attr) - @ccall mlir_c.stablehloResultAccuracyModeAttrGetValue( - attr::MlirAttribute - )::MlirStringRef + @ccall mlir_c.stablehloResultAccuracyModeAttrGetValue(attr::MlirAttribute)::MlirStringRef end function stablehloResultAccuracyAttrGet(ctx, atol, rtol, ulps, value) - @ccall mlir_c.stablehloResultAccuracyAttrGet( - ctx::MlirContext, atol::Cdouble, rtol::Cdouble, ulps::Int64, value::MlirStringRef - )::MlirAttribute + @ccall mlir_c.stablehloResultAccuracyAttrGet(ctx::MlirContext, atol::Cdouble, rtol::Cdouble, ulps::Int64, value::MlirStringRef)::MlirAttribute end function stablehloAttributeIsAResultAccuracyAttr(attr) @@ -10765,67 +9231,35 @@ end end function stablehloVersionFromCompatibilityRequirement(requirement, callback, userData) - @ccall mlir_c.stablehloVersionFromCompatibilityRequirement( - requirement::MlirStablehloCompatibilityRequirement, - callback::MlirStringCallback, - userData::Ptr{Cvoid}, - )::Cvoid + @ccall mlir_c.stablehloVersionFromCompatibilityRequirement(requirement::MlirStablehloCompatibilityRequirement, callback::MlirStringCallback, userData::Ptr{Cvoid})::Cvoid end function stablehloGetCurrentVersion(callback, userData) - @ccall mlir_c.stablehloGetCurrentVersion( - callback::MlirStringCallback, userData::Ptr{Cvoid} - )::Cvoid + @ccall mlir_c.stablehloGetCurrentVersion(callback::MlirStringCallback, userData::Ptr{Cvoid})::Cvoid end function stablehloGetMinimumVersion(callback, userData) - @ccall mlir_c.stablehloGetMinimumVersion( - callback::MlirStringCallback, userData::Ptr{Cvoid} - )::Cvoid + @ccall mlir_c.stablehloGetMinimumVersion(callback::MlirStringCallback, userData::Ptr{Cvoid})::Cvoid end function stablehloGetSmallerVersion(version1, version2, callback, userData) - @ccall mlir_c.stablehloGetSmallerVersion( - version1::MlirStringRef, - version2::MlirStringRef, - callback::MlirStringCallback, - userData::Ptr{Cvoid}, - )::MlirLogicalResult + @ccall mlir_c.stablehloGetSmallerVersion(version1::MlirStringRef, version2::MlirStringRef, callback::MlirStringCallback, userData::Ptr{Cvoid})::MlirLogicalResult end -function stablehloSerializePortableArtifactFromStringRef( - moduleStr, targetVersion, callback, userData -) - @ccall mlir_c.stablehloSerializePortableArtifactFromStringRef( - moduleStr::MlirStringRef, - targetVersion::MlirStringRef, - callback::MlirStringCallback, - userData::Ptr{Cvoid}, - )::MlirLogicalResult +function stablehloSerializePortableArtifactFromStringRef(moduleStr, targetVersion, callback, userData) + @ccall mlir_c.stablehloSerializePortableArtifactFromStringRef(moduleStr::MlirStringRef, targetVersion::MlirStringRef, callback::MlirStringCallback, userData::Ptr{Cvoid})::MlirLogicalResult end -function stablehloSerializePortableArtifactFromModule( - moduleStr, targetVersion, callback, userData, allowOtherDialects -) - @ccall mlir_c.stablehloSerializePortableArtifactFromModule( - moduleStr::MlirModule, - targetVersion::MlirStringRef, - callback::MlirStringCallback, - userData::Ptr{Cvoid}, - allowOtherDialects::Bool, - )::MlirLogicalResult +function stablehloSerializePortableArtifactFromModule(moduleStr, targetVersion, callback, userData, allowOtherDialects) + @ccall mlir_c.stablehloSerializePortableArtifactFromModule(moduleStr::MlirModule, targetVersion::MlirStringRef, callback::MlirStringCallback, userData::Ptr{Cvoid}, allowOtherDialects::Bool)::MlirLogicalResult end function stablehloDeserializePortableArtifact(artifactStr, callback, userData) - @ccall mlir_c.stablehloDeserializePortableArtifact( - artifactStr::MlirStringRef, callback::MlirStringCallback, userData::Ptr{Cvoid} - )::MlirLogicalResult + @ccall mlir_c.stablehloDeserializePortableArtifact(artifactStr::MlirStringRef, callback::MlirStringCallback, userData::Ptr{Cvoid})::MlirLogicalResult end function stablehloDeserializePortableArtifactNoError(artifactStr, ctx) - @ccall mlir_c.stablehloDeserializePortableArtifactNoError( - artifactStr::MlirStringRef, ctx::MlirContext - )::MlirModule + @ccall mlir_c.stablehloDeserializePortableArtifactNoError(artifactStr::MlirStringRef, ctx::MlirContext)::MlirModule end function stablehloTokenTypeGet(ctx) @@ -10841,9 +9275,7 @@ function sdyAttributeIsAMeshAxisAttr(attr) end function sdyMeshAxisAttrGet(ctx, name, size) - @ccall mlir_c.sdyMeshAxisAttrGet( - ctx::MlirContext, name::MlirStringRef, size::Int64 - )::MlirAttribute + @ccall mlir_c.sdyMeshAxisAttrGet(ctx::MlirContext, name::MlirStringRef, size::Int64)::MlirAttribute end function sdyMeshAxisAttrGetName(attr) @@ -10859,13 +9291,7 @@ function sdyAttributeIsAMeshAttr(attr) end function sdyMeshAttrGet(ctx, nAxes, axes, nDeviceIds, deviceIds) - @ccall mlir_c.sdyMeshAttrGet( - ctx::MlirContext, - nAxes::Cptrdiff_t, - axes::Ptr{MlirAttribute}, - nDeviceIds::Cptrdiff_t, - deviceIds::Ptr{Int64}, - )::MlirAttribute + @ccall mlir_c.sdyMeshAttrGet(ctx::MlirContext, nAxes::Cptrdiff_t, axes::Ptr{MlirAttribute}, nDeviceIds::Cptrdiff_t, deviceIds::Ptr{Int64})::MlirAttribute end function sdyMeshAttrGetDeviceIdsSize(attr) @@ -10881,9 +9307,7 @@ function sdyMeshAttrGetAxesSize(attr) end function sdyMeshAttrGetAxesElem(attr, pos) - @ccall mlir_c.sdyMeshAttrGetAxesElem( - attr::MlirAttribute, pos::Cptrdiff_t - )::MlirAttribute + @ccall mlir_c.sdyMeshAttrGetAxesElem(attr::MlirAttribute, pos::Cptrdiff_t)::MlirAttribute end function sdyAttributeIsASubAxisInfoAttr(attr) @@ -10891,9 +9315,7 @@ function sdyAttributeIsASubAxisInfoAttr(attr) end function sdySubAxisInfoAttrGet(ctx, preSize, size) - @ccall mlir_c.sdySubAxisInfoAttrGet( - ctx::MlirContext, preSize::Int64, size::Int64 - )::MlirAttribute + @ccall mlir_c.sdySubAxisInfoAttrGet(ctx::MlirContext, preSize::Int64, size::Int64)::MlirAttribute end function sdySubAxisInfoAttrGetPreSize(attr) @@ -10909,9 +9331,7 @@ function sdyAttributeIsAnAxisRefAttr(attr) end function sdyAxisRefAttrGet(ctx, name, subAxisInfo) - @ccall mlir_c.sdyAxisRefAttrGet( - ctx::MlirContext, name::MlirStringRef, subAxisInfo::MlirAttribute - )::MlirAttribute + @ccall mlir_c.sdyAxisRefAttrGet(ctx::MlirContext, name::MlirStringRef, subAxisInfo::MlirAttribute)::MlirAttribute end function sdyAxisRefAttrGetName(attr) @@ -10927,13 +9347,7 @@ function sdyAttributeIsADimensionShardingAttr(attr) end function sdyDimensionShardingAttrGet(ctx, nAxes, axes, isClosed, priority) - @ccall mlir_c.sdyDimensionShardingAttrGet( - ctx::MlirContext, - nAxes::Cptrdiff_t, - axes::Ptr{MlirAttribute}, - isClosed::Bool, - priority::Int64, - )::MlirAttribute + @ccall mlir_c.sdyDimensionShardingAttrGet(ctx::MlirContext, nAxes::Cptrdiff_t, axes::Ptr{MlirAttribute}, isClosed::Bool, priority::Int64)::MlirAttribute end function sdyDimensionShardingAttrGetAxesSize(attr) @@ -10941,9 +9355,7 @@ function sdyDimensionShardingAttrGetAxesSize(attr) end function sdyDimensionShardingAttrGetAxesElem(attr, pos) - @ccall mlir_c.sdyDimensionShardingAttrGetAxesElem( - attr::MlirAttribute, pos::Cptrdiff_t - )::MlirAttribute + @ccall mlir_c.sdyDimensionShardingAttrGetAxesElem(attr::MlirAttribute, pos::Cptrdiff_t)::MlirAttribute end function sdyDimensionShardingAttrGetIsClosed(attr) @@ -10958,26 +9370,8 @@ function sdyAttributeIsATensorShardingAttr(attr) @ccall mlir_c.sdyAttributeIsATensorShardingAttr(attr::MlirAttribute)::Bool end -function sdyTensorShardingAttrGet( - ctx, - meshOrRef, - nDimShardings, - dimShardings, - nReplicatedAxes, - replicatedAxes, - nUnreducedAxes, - unreducedAxes, -) - @ccall mlir_c.sdyTensorShardingAttrGet( - ctx::MlirContext, - meshOrRef::MlirAttribute, - nDimShardings::Cptrdiff_t, - dimShardings::Ptr{MlirAttribute}, - nReplicatedAxes::Cptrdiff_t, - replicatedAxes::Ptr{MlirAttribute}, - nUnreducedAxes::Cptrdiff_t, - unreducedAxes::Ptr{MlirAttribute}, - )::MlirAttribute +function sdyTensorShardingAttrGet(ctx, meshOrRef, nDimShardings, dimShardings, nReplicatedAxes, replicatedAxes, nUnreducedAxes, unreducedAxes) + @ccall mlir_c.sdyTensorShardingAttrGet(ctx::MlirContext, meshOrRef::MlirAttribute, nDimShardings::Cptrdiff_t, dimShardings::Ptr{MlirAttribute}, nReplicatedAxes::Cptrdiff_t, replicatedAxes::Ptr{MlirAttribute}, nUnreducedAxes::Cptrdiff_t, unreducedAxes::Ptr{MlirAttribute})::MlirAttribute end function sdyTensorShardingAttrGetMeshOrRef(attr) @@ -10989,21 +9383,15 @@ function sdyTensorShardingAttrGetDimShardingsSize(attr) end function sdyTensorShardingAttrGetDimShardingsElem(attr, pos) - @ccall mlir_c.sdyTensorShardingAttrGetDimShardingsElem( - attr::MlirAttribute, pos::Cptrdiff_t - )::MlirAttribute + @ccall mlir_c.sdyTensorShardingAttrGetDimShardingsElem(attr::MlirAttribute, pos::Cptrdiff_t)::MlirAttribute end function sdyTensorShardingAttrGetReplicatedAxesSize(attr) - @ccall mlir_c.sdyTensorShardingAttrGetReplicatedAxesSize( - attr::MlirAttribute - )::Cptrdiff_t + @ccall mlir_c.sdyTensorShardingAttrGetReplicatedAxesSize(attr::MlirAttribute)::Cptrdiff_t end function sdyTensorShardingAttrGetReplicatedAxesElem(attr, pos) - @ccall mlir_c.sdyTensorShardingAttrGetReplicatedAxesElem( - attr::MlirAttribute, pos::Cptrdiff_t - )::MlirAttribute + @ccall mlir_c.sdyTensorShardingAttrGetReplicatedAxesElem(attr::MlirAttribute, pos::Cptrdiff_t)::MlirAttribute end function sdyTensorShardingAttrGetUnreducedAxesSize(attr) @@ -11011,9 +9399,7 @@ function sdyTensorShardingAttrGetUnreducedAxesSize(attr) end function sdyTensorShardingAttrGetUnreducedAxesElem(attr, pos) - @ccall mlir_c.sdyTensorShardingAttrGetUnreducedAxesElem( - attr::MlirAttribute, pos::Cptrdiff_t - )::MlirAttribute + @ccall mlir_c.sdyTensorShardingAttrGetUnreducedAxesElem(attr::MlirAttribute, pos::Cptrdiff_t)::MlirAttribute end function sdyAttributeIsATensorShardingPerValueAttr(attr) @@ -11021,21 +9407,15 @@ function sdyAttributeIsATensorShardingPerValueAttr(attr) end function sdyTensorShardingPerValueAttrGet(ctx, nShardings, shardings) - @ccall mlir_c.sdyTensorShardingPerValueAttrGet( - ctx::MlirContext, nShardings::Cptrdiff_t, shardings::Ptr{MlirAttribute} - )::MlirAttribute + @ccall mlir_c.sdyTensorShardingPerValueAttrGet(ctx::MlirContext, nShardings::Cptrdiff_t, shardings::Ptr{MlirAttribute})::MlirAttribute end function sdyTensorShardingPerValueAttrGetShardingsSize(attr) - @ccall mlir_c.sdyTensorShardingPerValueAttrGetShardingsSize( - attr::MlirAttribute - )::Cptrdiff_t + @ccall mlir_c.sdyTensorShardingPerValueAttrGetShardingsSize(attr::MlirAttribute)::Cptrdiff_t end function sdyTensorShardingPerValueAttrGetShardingsElem(attr, pos) - @ccall mlir_c.sdyTensorShardingPerValueAttrGetShardingsElem( - attr::MlirAttribute, pos::Cptrdiff_t - )::MlirAttribute + @ccall mlir_c.sdyTensorShardingPerValueAttrGetShardingsElem(attr::MlirAttribute, pos::Cptrdiff_t)::MlirAttribute end function sdyAttributeIsADimMappingAttr(attr) @@ -11043,9 +9423,7 @@ function sdyAttributeIsADimMappingAttr(attr) end function sdyDimMappingAttrGet(ctx, nFactorIndices, factorIndices) - @ccall mlir_c.sdyDimMappingAttrGet( - ctx::MlirContext, nFactorIndices::Cptrdiff_t, factorIndices::Ptr{Int64} - )::MlirAttribute + @ccall mlir_c.sdyDimMappingAttrGet(ctx::MlirContext, nFactorIndices::Cptrdiff_t, factorIndices::Ptr{Int64})::MlirAttribute end function sdyDimMappingAttrGetFactorIndicesSize(attr) @@ -11053,9 +9431,7 @@ function sdyDimMappingAttrGetFactorIndicesSize(attr) end function sdyDimMappingAttrGetFactorIndicesElem(attr, pos) - @ccall mlir_c.sdyDimMappingAttrGetFactorIndicesElem( - attr::MlirAttribute, pos::Cptrdiff_t - )::Int64 + @ccall mlir_c.sdyDimMappingAttrGetFactorIndicesElem(attr::MlirAttribute, pos::Cptrdiff_t)::Int64 end function sdyAttributeIsATensorMappingAttr(attr) @@ -11063,9 +9439,7 @@ function sdyAttributeIsATensorMappingAttr(attr) end function sdyTensorMappingAttrGet(ctx, nMappings, mappings) - @ccall mlir_c.sdyTensorMappingAttrGet( - ctx::MlirContext, nMappings::Cptrdiff_t, mappings::Ptr{MlirAttribute} - )::MlirAttribute + @ccall mlir_c.sdyTensorMappingAttrGet(ctx::MlirContext, nMappings::Cptrdiff_t, mappings::Ptr{MlirAttribute})::MlirAttribute end function sdyTensorMappingAttrGetRank(attr) @@ -11077,51 +9451,15 @@ function sdyTensorMappingAttrGetDimMappingsSize(attr) end function sdyTensorMappingAttrGetDimMappingsElem(attr, pos) - @ccall mlir_c.sdyTensorMappingAttrGetDimMappingsElem( - attr::MlirAttribute, pos::Cptrdiff_t - )::MlirAttribute + @ccall mlir_c.sdyTensorMappingAttrGetDimMappingsElem(attr::MlirAttribute, pos::Cptrdiff_t)::MlirAttribute end function sdyAttributeIsAOpShardingRuleAttr(attr) @ccall mlir_c.sdyAttributeIsAOpShardingRuleAttr(attr::MlirAttribute)::Bool end -function sdyOpShardingRuleAttrGet( - ctx, - nFactorSizes, - factorSizes, - nOperandMappings, - operandMappings, - nResultMappings, - resultMappings, - nReductionFactors, - reductionFactors, - nNeedReplicationFactors, - needReplicationFactors, - nPermutationFactors, - permutationFactors, - nBlockedPropagationFactors, - blockedPropagationFactors, - isCustomRule, -) - @ccall mlir_c.sdyOpShardingRuleAttrGet( - ctx::MlirContext, - nFactorSizes::Cptrdiff_t, - factorSizes::Ptr{Int64}, - nOperandMappings::Cptrdiff_t, - operandMappings::Ptr{MlirAttribute}, - nResultMappings::Cptrdiff_t, - resultMappings::Ptr{MlirAttribute}, - nReductionFactors::Cptrdiff_t, - reductionFactors::Ptr{Int64}, - nNeedReplicationFactors::Cptrdiff_t, - needReplicationFactors::Ptr{Int64}, - nPermutationFactors::Cptrdiff_t, - permutationFactors::Ptr{Int64}, - nBlockedPropagationFactors::Cptrdiff_t, - blockedPropagationFactors::Ptr{Int64}, - isCustomRule::Bool, - )::MlirAttribute +function sdyOpShardingRuleAttrGet(ctx, nFactorSizes, factorSizes, nOperandMappings, operandMappings, nResultMappings, resultMappings, nReductionFactors, reductionFactors, nNeedReplicationFactors, needReplicationFactors, nPermutationFactors, permutationFactors, nBlockedPropagationFactors, blockedPropagationFactors, isCustomRule) + @ccall mlir_c.sdyOpShardingRuleAttrGet(ctx::MlirContext, nFactorSizes::Cptrdiff_t, factorSizes::Ptr{Int64}, nOperandMappings::Cptrdiff_t, operandMappings::Ptr{MlirAttribute}, nResultMappings::Cptrdiff_t, resultMappings::Ptr{MlirAttribute}, nReductionFactors::Cptrdiff_t, reductionFactors::Ptr{Int64}, nNeedReplicationFactors::Cptrdiff_t, needReplicationFactors::Ptr{Int64}, nPermutationFactors::Cptrdiff_t, permutationFactors::Ptr{Int64}, nBlockedPropagationFactors::Cptrdiff_t, blockedPropagationFactors::Ptr{Int64}, isCustomRule::Bool)::MlirAttribute end function sdyOpShardingRuleAttrGetIsCustom(attr) @@ -11133,81 +9471,55 @@ function sdyOpShardingRuleAttrGetFactorSizesSize(attr) end function sdyOpShardingRuleAttrGetFactorSizesElem(attr, pos) - @ccall mlir_c.sdyOpShardingRuleAttrGetFactorSizesElem( - attr::MlirAttribute, pos::Cptrdiff_t - )::Int64 + @ccall mlir_c.sdyOpShardingRuleAttrGetFactorSizesElem(attr::MlirAttribute, pos::Cptrdiff_t)::Int64 end function sdyOpShardingRuleAttrGetOperandMappingsSize(attr) - @ccall mlir_c.sdyOpShardingRuleAttrGetOperandMappingsSize( - attr::MlirAttribute - )::Cptrdiff_t + @ccall mlir_c.sdyOpShardingRuleAttrGetOperandMappingsSize(attr::MlirAttribute)::Cptrdiff_t end function sdyOpShardingRuleAttrGetOperandMappingsElem(attr, pos) - @ccall mlir_c.sdyOpShardingRuleAttrGetOperandMappingsElem( - attr::MlirAttribute, pos::Cptrdiff_t - )::MlirAttribute + @ccall mlir_c.sdyOpShardingRuleAttrGetOperandMappingsElem(attr::MlirAttribute, pos::Cptrdiff_t)::MlirAttribute end function sdyOpShardingRuleAttrGetResultMappingsSize(attr) - @ccall mlir_c.sdyOpShardingRuleAttrGetResultMappingsSize( - attr::MlirAttribute - )::Cptrdiff_t + @ccall mlir_c.sdyOpShardingRuleAttrGetResultMappingsSize(attr::MlirAttribute)::Cptrdiff_t end function sdyOpShardingRuleAttrGetResultMappingsElem(attr, pos) - @ccall mlir_c.sdyOpShardingRuleAttrGetResultMappingsElem( - attr::MlirAttribute, pos::Cptrdiff_t - )::MlirAttribute + @ccall mlir_c.sdyOpShardingRuleAttrGetResultMappingsElem(attr::MlirAttribute, pos::Cptrdiff_t)::MlirAttribute end function sdyOpShardingRuleAttrGetReductionFactorsSize(attr) - @ccall mlir_c.sdyOpShardingRuleAttrGetReductionFactorsSize( - attr::MlirAttribute - )::Cptrdiff_t + @ccall mlir_c.sdyOpShardingRuleAttrGetReductionFactorsSize(attr::MlirAttribute)::Cptrdiff_t end function sdyOpShardingRuleAttrGetReductionFactorsElem(attr, pos) - @ccall mlir_c.sdyOpShardingRuleAttrGetReductionFactorsElem( - attr::MlirAttribute, pos::Cptrdiff_t - )::Int64 + @ccall mlir_c.sdyOpShardingRuleAttrGetReductionFactorsElem(attr::MlirAttribute, pos::Cptrdiff_t)::Int64 end function sdyOpShardingRuleAttrGetNeedReplicationFactorsSize(attr) - @ccall mlir_c.sdyOpShardingRuleAttrGetNeedReplicationFactorsSize( - attr::MlirAttribute - )::Cptrdiff_t + @ccall mlir_c.sdyOpShardingRuleAttrGetNeedReplicationFactorsSize(attr::MlirAttribute)::Cptrdiff_t end function sdyOpShardingRuleAttrGetNeedReplicationFactorsElem(attr, pos) - @ccall mlir_c.sdyOpShardingRuleAttrGetNeedReplicationFactorsElem( - attr::MlirAttribute, pos::Cptrdiff_t - )::Int64 + @ccall mlir_c.sdyOpShardingRuleAttrGetNeedReplicationFactorsElem(attr::MlirAttribute, pos::Cptrdiff_t)::Int64 end function sdyOpShardingRuleAttrGetPermutationFactorsSize(attr) - @ccall mlir_c.sdyOpShardingRuleAttrGetPermutationFactorsSize( - attr::MlirAttribute - )::Cptrdiff_t + @ccall mlir_c.sdyOpShardingRuleAttrGetPermutationFactorsSize(attr::MlirAttribute)::Cptrdiff_t end function sdyOpShardingRuleAttrGetPermutationFactorsElem(attr, pos) - @ccall mlir_c.sdyOpShardingRuleAttrGetPermutationFactorsElem( - attr::MlirAttribute, pos::Cptrdiff_t - )::Int64 + @ccall mlir_c.sdyOpShardingRuleAttrGetPermutationFactorsElem(attr::MlirAttribute, pos::Cptrdiff_t)::Int64 end function sdyOpShardingRuleAttrGetBlockedPropagationFactorsSize(attr) - @ccall mlir_c.sdyOpShardingRuleAttrGetBlockedPropagationFactorsSize( - attr::MlirAttribute - )::Cptrdiff_t + @ccall mlir_c.sdyOpShardingRuleAttrGetBlockedPropagationFactorsSize(attr::MlirAttribute)::Cptrdiff_t end function sdyOpShardingRuleAttrGetBlockedPropagationFactorsElem(attr, pos) - @ccall mlir_c.sdyOpShardingRuleAttrGetBlockedPropagationFactorsElem( - attr::MlirAttribute, pos::Cptrdiff_t - )::Int64 + @ccall mlir_c.sdyOpShardingRuleAttrGetBlockedPropagationFactorsElem(attr::MlirAttribute, pos::Cptrdiff_t)::Int64 end function sdyAttributeIsAManualAxesAttr(attr) @@ -11215,9 +9527,7 @@ function sdyAttributeIsAManualAxesAttr(attr) end function sdyManualAxesAttrGet(ctx, nAxes, axes) - @ccall mlir_c.sdyManualAxesAttrGet( - ctx::MlirContext, nAxes::Cptrdiff_t, axes::Ptr{MlirAttribute} - )::MlirAttribute + @ccall mlir_c.sdyManualAxesAttrGet(ctx::MlirContext, nAxes::Cptrdiff_t, axes::Ptr{MlirAttribute})::MlirAttribute end function sdyManualAxesAttrGetAxesSize(attr) @@ -11225,9 +9535,7 @@ function sdyManualAxesAttrGetAxesSize(attr) end function sdyManualAxesAttrGetAxesElem(attr, pos) - @ccall mlir_c.sdyManualAxesAttrGetAxesElem( - attr::MlirAttribute, pos::Cptrdiff_t - )::MlirStringRef + @ccall mlir_c.sdyManualAxesAttrGetAxesElem(attr::MlirAttribute, pos::Cptrdiff_t)::MlirStringRef end function mlirGetDialectHandle__triton__() @@ -11235,9 +9543,7 @@ function mlirGetDialectHandle__triton__() end function mlirTritonPointerTypeGet(pointeeType, addressSpace) - @ccall mlir_c.mlirTritonPointerTypeGet( - pointeeType::MlirType, addressSpace::Cint - )::MlirType + @ccall mlir_c.mlirTritonPointerTypeGet(pointeeType::MlirType, addressSpace::Cint)::MlirType end function mlirTritonIsAPointer(type) @@ -11253,9 +9559,7 @@ function mlirTritonPointerTypeGetAddressSpace(pointerType) end function mlirTritonInferReduceOpEncoding(operandEncoding, axis) - @ccall mlir_c.mlirTritonInferReduceOpEncoding( - operandEncoding::MlirAttribute, axis::Cint - )::MlirAttribute + @ccall mlir_c.mlirTritonInferReduceOpEncoding(operandEncoding::MlirAttribute, axis::Cint)::MlirAttribute end function mlirGetDialectHandle__tpu__() @@ -11276,9 +9580,7 @@ function mlirTPUTiledLayoutAttrGetTiles(attr) end function mlirTPUAnalyzePotentialCommunication(op, has_communication, has_custom_barrier) - @ccall mlir_c.mlirTPUAnalyzePotentialCommunication( - op::MlirOperation, has_communication::Ptr{Bool}, has_custom_barrier::Ptr{Bool} - )::Cvoid + @ccall mlir_c.mlirTPUAnalyzePotentialCommunication(op::MlirOperation, has_communication::Ptr{Bool}, has_custom_barrier::Ptr{Bool})::Cvoid end @cenum MlirTpuImplicitDim::UInt32 begin @@ -11346,12 +9648,7 @@ struct MlirTpuApplyVectorLayoutContext end function mlirTpuVectorLayoutCreate(bitwidth, offsets, tiling, implicit_dim) - @ccall mlir_c.mlirTpuVectorLayoutCreate( - bitwidth::Cint, - offsets::MlirTpuLayoutOffsets, - tiling::MlirTpuI64TargetTuple, - implicit_dim::MlirTpuImplicitDim, - )::MlirTpuVectorLayout + @ccall mlir_c.mlirTpuVectorLayoutCreate(bitwidth::Cint, offsets::MlirTpuLayoutOffsets, tiling::MlirTpuI64TargetTuple, implicit_dim::MlirTpuImplicitDim)::MlirTpuVectorLayout end function mlirTpuVectorLayoutDestroy(arg1) @@ -11363,21 +9660,15 @@ function mlirTpuVectorLayoutGetBitwidth(layout) end function mlirTpuVectorLayoutGetOffsets(layout) - @ccall mlir_c.mlirTpuVectorLayoutGetOffsets( - layout::MlirTpuVectorLayout - )::MlirTpuLayoutOffsets + @ccall mlir_c.mlirTpuVectorLayoutGetOffsets(layout::MlirTpuVectorLayout)::MlirTpuLayoutOffsets end function mlirTpuVectorLayoutGetTiling(layout) - @ccall mlir_c.mlirTpuVectorLayoutGetTiling( - layout::MlirTpuVectorLayout - )::MlirTpuI64TargetTuple + @ccall mlir_c.mlirTpuVectorLayoutGetTiling(layout::MlirTpuVectorLayout)::MlirTpuI64TargetTuple end function mlirTpuVectorLayoutGetImplicitDim(layout) - @ccall mlir_c.mlirTpuVectorLayoutGetImplicitDim( - layout::MlirTpuVectorLayout - )::MlirTpuImplicitDim + @ccall mlir_c.mlirTpuVectorLayoutGetImplicitDim(layout::MlirTpuVectorLayout)::MlirTpuImplicitDim end function mlirTpuVectorLayoutGetPacking(layout) @@ -11389,97 +9680,55 @@ function mlirTpuVectorLayoutGetLayoutRank(layout) end function mlirTpuVectorLayoutEquals(lhs, rhs) - @ccall mlir_c.mlirTpuVectorLayoutEquals( - lhs::MlirTpuVectorLayout, rhs::MlirTpuVectorLayout - )::Bool + @ccall mlir_c.mlirTpuVectorLayoutEquals(lhs::MlirTpuVectorLayout, rhs::MlirTpuVectorLayout)::Bool end function mlirTpuVectorLayoutTilesPerVreg(layout, target_shape) - @ccall mlir_c.mlirTpuVectorLayoutTilesPerVreg( - layout::MlirTpuVectorLayout, target_shape::MlirTpuI64TargetTuple - )::Int64 + @ccall mlir_c.mlirTpuVectorLayoutTilesPerVreg(layout::MlirTpuVectorLayout, target_shape::MlirTpuI64TargetTuple)::Int64 end function mlirTpuVectorLayoutSublanesPerTile(layout, target_shape) - @ccall mlir_c.mlirTpuVectorLayoutSublanesPerTile( - layout::MlirTpuVectorLayout, target_shape::MlirTpuI64TargetTuple - )::Int64 + @ccall mlir_c.mlirTpuVectorLayoutSublanesPerTile(layout::MlirTpuVectorLayout, target_shape::MlirTpuI64TargetTuple)::Int64 end function mlirTpuVectorLayoutVregSlice(layout, target_shape) - @ccall mlir_c.mlirTpuVectorLayoutVregSlice( - layout::MlirTpuVectorLayout, target_shape::MlirTpuI64TargetTuple - )::MlirTpuI64TargetTuple + @ccall mlir_c.mlirTpuVectorLayoutVregSlice(layout::MlirTpuVectorLayout, target_shape::MlirTpuI64TargetTuple)::MlirTpuI64TargetTuple end function mlirTpuVectorLayoutImplicitShape(layout, shape) - @ccall mlir_c.mlirTpuVectorLayoutImplicitShape( - layout::MlirTpuVectorLayout, shape::MlirTpuI64ArrayRef - )::MlirTpuI64ArrayRef + @ccall mlir_c.mlirTpuVectorLayoutImplicitShape(layout::MlirTpuVectorLayout, shape::MlirTpuI64ArrayRef)::MlirTpuI64ArrayRef end function mlirTpuVectorLayoutTileArrayShape(layout, shape, target_shape) - @ccall mlir_c.mlirTpuVectorLayoutTileArrayShape( - layout::MlirTpuVectorLayout, - shape::MlirTpuI64ArrayRef, - target_shape::MlirTpuI64TargetTuple, - )::MlirTpuI64ArrayRef + @ccall mlir_c.mlirTpuVectorLayoutTileArrayShape(layout::MlirTpuVectorLayout, shape::MlirTpuI64ArrayRef, target_shape::MlirTpuI64TargetTuple)::MlirTpuI64ArrayRef end -function mlirTpuVectorLayoutTileDataBounds( - layout, ctx, full_shape, idxs, size, target_shape, allow_replicated -) - @ccall mlir_c.mlirTpuVectorLayoutTileDataBounds( - layout::MlirTpuVectorLayout, - ctx::MlirContext, - full_shape::Ptr{Int64}, - idxs::Ptr{Int64}, - size::Csize_t, - target_shape::MlirTpuI64TargetTuple, - allow_replicated::MlirTpuBoolTargetTuple, - )::MlirTpuVregDataBounds +function mlirTpuVectorLayoutTileDataBounds(layout, ctx, full_shape, idxs, size, target_shape, allow_replicated) + @ccall mlir_c.mlirTpuVectorLayoutTileDataBounds(layout::MlirTpuVectorLayout, ctx::MlirContext, full_shape::Ptr{Int64}, idxs::Ptr{Int64}, size::Csize_t, target_shape::MlirTpuI64TargetTuple, allow_replicated::MlirTpuBoolTargetTuple)::MlirTpuVregDataBounds end function mlirTpuVectorLayoutHasNaturalTopology(layout, target_shape) - @ccall mlir_c.mlirTpuVectorLayoutHasNaturalTopology( - layout::MlirTpuVectorLayout, target_shape::MlirTpuI64TargetTuple - )::Bool + @ccall mlir_c.mlirTpuVectorLayoutHasNaturalTopology(layout::MlirTpuVectorLayout, target_shape::MlirTpuI64TargetTuple)::Bool end function mlirTpuVectorLayoutHasNativeTiling(layout, target_shape) - @ccall mlir_c.mlirTpuVectorLayoutHasNativeTiling( - layout::MlirTpuVectorLayout, target_shape::MlirTpuI64TargetTuple - )::Bool + @ccall mlir_c.mlirTpuVectorLayoutHasNativeTiling(layout::MlirTpuVectorLayout, target_shape::MlirTpuI64TargetTuple)::Bool end function mlirTpuVectorLayoutGeneralizes(layout, other, shape, target_shape) - @ccall mlir_c.mlirTpuVectorLayoutGeneralizes( - layout::MlirTpuVectorLayout, - other::MlirTpuVectorLayout, - shape::MlirTpuI64ArrayRef, - target_shape::MlirTpuI64TargetTuple, - )::Bool + @ccall mlir_c.mlirTpuVectorLayoutGeneralizes(layout::MlirTpuVectorLayout, other::MlirTpuVectorLayout, shape::MlirTpuI64ArrayRef, target_shape::MlirTpuI64TargetTuple)::Bool end function mlirTpuVectorLayoutEquivalentTo(layout, other, shape, target_shape) - @ccall mlir_c.mlirTpuVectorLayoutEquivalentTo( - layout::MlirTpuVectorLayout, - other::MlirTpuVectorLayout, - shape::MlirTpuI64ArrayRef, - target_shape::MlirTpuI64TargetTuple, - )::Bool + @ccall mlir_c.mlirTpuVectorLayoutEquivalentTo(layout::MlirTpuVectorLayout, other::MlirTpuVectorLayout, shape::MlirTpuI64ArrayRef, target_shape::MlirTpuI64TargetTuple)::Bool end function mlirTpuVectorLayoutPrint(layout, callback, user_data) - @ccall mlir_c.mlirTpuVectorLayoutPrint( - layout::MlirTpuVectorLayout, callback::MlirStringCallback, user_data::Ptr{Cvoid} - )::Cvoid + @ccall mlir_c.mlirTpuVectorLayoutPrint(layout::MlirTpuVectorLayout, callback::MlirStringCallback, user_data::Ptr{Cvoid})::Cvoid end function mlirTpuVectorLayoutIsValid(layout, target_shape) - @ccall mlir_c.mlirTpuVectorLayoutIsValid( - layout::MlirTpuVectorLayout, target_shape::MlirTpuI64TargetTuple - )::Bool + @ccall mlir_c.mlirTpuVectorLayoutIsValid(layout::MlirTpuVectorLayout, target_shape::MlirTpuI64TargetTuple)::Bool end function mlirTpuVregDataBoundsDestroy(data_bounds) @@ -11487,72 +9736,35 @@ function mlirTpuVregDataBoundsDestroy(data_bounds) end function mlirTpuVregDataBoundsMaskVariesAlong(data_bounds, direction, target_shape) - @ccall mlir_c.mlirTpuVregDataBoundsMaskVariesAlong( - data_bounds::MlirTpuVregDataBounds, - direction::MlirTpuDirection, - target_shape::MlirTpuI64TargetTuple, - )::Bool + @ccall mlir_c.mlirTpuVregDataBoundsMaskVariesAlong(data_bounds::MlirTpuVregDataBounds, direction::MlirTpuDirection, target_shape::MlirTpuI64TargetTuple)::Bool end function mlirTpuVregDataBoundsIsComplete(data_bounds, target_shape) - @ccall mlir_c.mlirTpuVregDataBoundsIsComplete( - data_bounds::MlirTpuVregDataBounds, target_shape::MlirTpuI64TargetTuple - )::Bool + @ccall mlir_c.mlirTpuVregDataBoundsIsComplete(data_bounds::MlirTpuVregDataBounds, target_shape::MlirTpuI64TargetTuple)::Bool end -function mlirTpuVregDataBoundsGetVectorMask( - data_bounds, insertion_point, location, generation, target_shape -) - @ccall mlir_c.mlirTpuVregDataBoundsGetVectorMask( - data_bounds::MlirTpuVregDataBounds, - insertion_point::MlirTpuInsertionPoint, - location::MlirLocation, - generation::Cint, - target_shape::MlirTpuI64TargetTuple, - )::MlirValue +function mlirTpuVregDataBoundsGetVectorMask(data_bounds, insertion_point, location, generation, target_shape) + @ccall mlir_c.mlirTpuVregDataBoundsGetVectorMask(data_bounds::MlirTpuVregDataBounds, insertion_point::MlirTpuInsertionPoint, location::MlirLocation, generation::Cint, target_shape::MlirTpuI64TargetTuple)::MlirValue end function mlirTpuVregDataBoundsGetSublaneMask(data_bounds, ctx, target_shape) - @ccall mlir_c.mlirTpuVregDataBoundsGetSublaneMask( - data_bounds::MlirTpuVregDataBounds, - ctx::MlirContext, - target_shape::MlirTpuI64TargetTuple, - )::MlirAttribute + @ccall mlir_c.mlirTpuVregDataBoundsGetSublaneMask(data_bounds::MlirTpuVregDataBounds, ctx::MlirContext, target_shape::MlirTpuI64TargetTuple)::MlirAttribute end function mlirTpuAssemble(insertion_point, vector_type, layout, vals, target_shape) - @ccall mlir_c.mlirTpuAssemble( - insertion_point::MlirTpuInsertionPoint, - vector_type::MlirType, - layout::MlirTpuVectorLayout, - vals::MlirTpuValueArray, - target_shape::MlirTpuI64TargetTuple, - )::MlirOperation + @ccall mlir_c.mlirTpuAssemble(insertion_point::MlirTpuInsertionPoint, vector_type::MlirType, layout::MlirTpuVectorLayout, vals::MlirTpuValueArray, target_shape::MlirTpuI64TargetTuple)::MlirOperation end function mlirTpuDisassemble(insertion_point, layout, val, target_shape) - @ccall mlir_c.mlirTpuDisassemble( - insertion_point::MlirTpuInsertionPoint, - layout::MlirTpuVectorLayout, - val::MlirValue, - target_shape::MlirTpuI64TargetTuple, - )::MlirTpuValueArray + @ccall mlir_c.mlirTpuDisassemble(insertion_point::MlirTpuInsertionPoint, layout::MlirTpuVectorLayout, val::MlirValue, target_shape::MlirTpuI64TargetTuple)::MlirTpuValueArray end function mlirTpuApplyLayoutOp(ctx, op) - @ccall mlir_c.mlirTpuApplyLayoutOp( - ctx::MlirTpuApplyVectorLayoutContext, op::MlirOperation - )::MlirLogicalResult + @ccall mlir_c.mlirTpuApplyLayoutOp(ctx::MlirTpuApplyVectorLayoutContext, op::MlirOperation)::MlirLogicalResult end function mlirTpuRelayout(insertion_point, val, src, dst, ctx) - @ccall mlir_c.mlirTpuRelayout( - insertion_point::MlirTpuInsertionPoint, - val::MlirValue, - src::MlirTpuVectorLayout, - dst::MlirTpuVectorLayout, - ctx::MlirTpuApplyVectorLayoutContext, - )::MlirValue + @ccall mlir_c.mlirTpuRelayout(insertion_point::MlirTpuInsertionPoint, val::MlirValue, src::MlirTpuVectorLayout, dst::MlirTpuVectorLayout, ctx::MlirTpuApplyVectorLayoutContext)::MlirValue end function mlirTpuRegisterMosaicSerdePass() @@ -11564,9 +9776,7 @@ function mlirMosaicGpuIsATileTransformAttr(attr) end function mlirMosaicGpuTileTransformAttrGet(ctx, tiling, tiling_size) - @ccall mlir_c.mlirMosaicGpuTileTransformAttrGet( - ctx::MlirContext, tiling::Ptr{Int32}, tiling_size::Int32 - )::MlirAttribute + @ccall mlir_c.mlirMosaicGpuTileTransformAttrGet(ctx::MlirContext, tiling::Ptr{Int32}, tiling_size::Int32)::MlirAttribute end function mlirMosaicGpuTileTransformAttrGetTilingSize(attr) @@ -11574,9 +9784,7 @@ function mlirMosaicGpuTileTransformAttrGetTilingSize(attr) end function mlirMosaicGpuTileTransformAttrGetTiling(attr, index) - @ccall mlir_c.mlirMosaicGpuTileTransformAttrGetTiling( - attr::MlirAttribute, index::Int32 - )::Int32 + @ccall mlir_c.mlirMosaicGpuTileTransformAttrGetTiling(attr::MlirAttribute, index::Int32)::Int32 end function mlirMosaicGpuIsATransposeTransformAttr(attr) @@ -11584,21 +9792,15 @@ function mlirMosaicGpuIsATransposeTransformAttr(attr) end function mlirMosaicGpuTransposeTransformAttrGet(ctx, permutation, permutation_size) - @ccall mlir_c.mlirMosaicGpuTransposeTransformAttrGet( - ctx::MlirContext, permutation::Ptr{Int32}, permutation_size::Int32 - )::MlirAttribute + @ccall mlir_c.mlirMosaicGpuTransposeTransformAttrGet(ctx::MlirContext, permutation::Ptr{Int32}, permutation_size::Int32)::MlirAttribute end function mlirMosaicGpuTransposeTransformAttrGetPermutationSize(attr) - @ccall mlir_c.mlirMosaicGpuTransposeTransformAttrGetPermutationSize( - attr::MlirAttribute - )::Int32 + @ccall mlir_c.mlirMosaicGpuTransposeTransformAttrGetPermutationSize(attr::MlirAttribute)::Int32 end function mlirMosaicGpuTransposeTransformAttrGetPermutation(attr, index) - @ccall mlir_c.mlirMosaicGpuTransposeTransformAttrGetPermutation( - attr::MlirAttribute, index::Int32 - )::Int32 + @ccall mlir_c.mlirMosaicGpuTransposeTransformAttrGetPermutation(attr::MlirAttribute, index::Int32)::Int32 end function mlirMosaicGpuIsASwizzleTransformAttr(attr) @@ -11606,9 +9808,7 @@ function mlirMosaicGpuIsASwizzleTransformAttr(attr) end function mlirMosaicGpuSwizzleTransformAttrGet(ctx, swizzle) - @ccall mlir_c.mlirMosaicGpuSwizzleTransformAttrGet( - ctx::MlirContext, swizzle::Int32 - )::MlirAttribute + @ccall mlir_c.mlirMosaicGpuSwizzleTransformAttrGet(ctx::MlirContext, swizzle::Int32)::MlirAttribute end function mlirMosaicGpuSwizzleTransformAttrGetSwizzle(attr) @@ -11620,21 +9820,15 @@ function mlirGetDialectHandle__mosaic_gpu__() end function enzymexlaLapackLayoutAttrGet(ctx, col_major) - @ccall mlir_c.enzymexlaLapackLayoutAttrGet( - ctx::MlirContext, col_major::UInt8 - )::MlirAttribute + @ccall mlir_c.enzymexlaLapackLayoutAttrGet(ctx::MlirContext, col_major::UInt8)::MlirAttribute end function enzymexlaLapackTransposeAttrGet(ctx, mode) - @ccall mlir_c.enzymexlaLapackTransposeAttrGet( - ctx::MlirContext, mode::Int32 - )::MlirAttribute + @ccall mlir_c.enzymexlaLapackTransposeAttrGet(ctx::MlirContext, mode::Int32)::MlirAttribute end function enzymexlaLapackSideAttrGet(ctx, left_side) - @ccall mlir_c.enzymexlaLapackSideAttrGet( - ctx::MlirContext, left_side::UInt8 - )::MlirAttribute + @ccall mlir_c.enzymexlaLapackSideAttrGet(ctx::MlirContext, left_side::UInt8)::MlirAttribute end function enzymexlaLapackUploAttrGet(ctx, up) @@ -11646,9 +9840,8 @@ function enzymexlaQRAlgorithmAttrGet(ctx, mode) end function enzymexlaGeluApproximationAttrGet(ctx, mode) - @ccall mlir_c.enzymexlaGeluApproximationAttrGet( - ctx::MlirContext, mode::Int32 - )::MlirAttribute + @ccall mlir_c.enzymexlaGeluApproximationAttrGet(ctx::MlirContext, mode::Int32)::MlirAttribute end const MLIR_CAPI_DWARF_ADDRESS_SPACE_NULL = -1 + diff --git a/src/stdlibs/LinearAlgebra.jl b/src/stdlibs/LinearAlgebra.jl index d4820c9966..4549f8370a 100644 --- a/src/stdlibs/LinearAlgebra.jl +++ b/src/stdlibs/LinearAlgebra.jl @@ -273,6 +273,69 @@ function overloaded_mul!( return C end +function overloaded_mul!( + @nospecialize(C::TracedRArray{T,2} where {T}), + @nospecialize(A::Symmetric), + @nospecialize(B::AbstractMatrix), + α::Number=true, + β::Number=true, +) + # Promote to traced arrays + A = call_with_reactant(Reactant.promote_to, TracedRArray, parent(A)) + B = call_with_reactant(Reactant.promote_to, TracedRArray, B) + + # Dimension checks + if size(C) != (size(A, 1), size(B, 2)) + throw(DimensionMismatch("C=$(size(C)), A=$(size(A)), B=$(size(B))")) + end + + T = Reactant.unwrapped_eltype(C) + tmp = @opcall lapack_symm( + T.(materialize_traced_array(A)), + T.(materialize_traced_array(B)), + T.(materialize_traced_array(C)), + Reactant.promote_to(TracedRNumber{T}, α), + Reactant.promote_to(TracedRNumber{T}, β), + side=:L, + uplo=:U, + ) + + set_mlir_data!(C, get_mlir_data(tmp)) # TODO remove later, handling in place ops are weird + return C +end + +function overloaded_mul!( + @nospecialize(C::TracedRArray{T,2} where {T}), + @nospecialize(A::AbstractMatrix), + @nospecialize(B::Symmetric), + α::Number=true, + β::Number=true, +) + # Promote to traced arrays + A = call_with_reactant(Reactant.promote_to, TracedRArray, A) + B = call_with_reactant(Reactant.promote_to, TracedRArray, parent(B)) + + # Dimension checks + if size(C) != (size(A, 1), size(B, 2)) + throw(DimensionMismatch("C=$(size(C)), A=$(size(A)), B=$(size(B))")) + end + + T = Reactant.unwrapped_eltype(C) + tmp = @opcall lapack_symm( + T.(materialize_traced_array(A)), + T.(materialize_traced_array(B)), + T.(materialize_traced_array(C)), + Reactant.promote_to(TracedRNumber{T}, α), + Reactant.promote_to(TracedRNumber{T}, β), + side=:R, + uplo=:U, + ) + + set_mlir_data!(C, get_mlir_data(tmp)) # TODO remove later, handling in place ops are weird + return C +end + + function LinearAlgebra.triu!(@nospecialize(X::TracedRArray{T,2}), k::Integer) where {T} iota_1 = @opcall iota(Int64, [size(X)...]; iota_dimension=1) iota_2 = @opcall subtract( diff --git a/src/utils.jl b/src/utils.jl index c7cb254946..8688603ff1 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -369,7 +369,11 @@ function rewrite_inst(inst, ir, interp, RT, guaranteed_error) end end if Meta.isexpr(inst, :invoke) - omi = inst.args[1]::Core.MethodInstance + omi = if inst.args[1] isa Core.MethodInstance + inst.args[1] + else + (inst.args[1]::Core.CodeInstance).def + end sig = omi.specTypes ft = sig.parameters[1] argsig = sig.parameters[2:end] @@ -518,22 +522,42 @@ function make_oc_ref( if Base.isassigned(oc_captures) return oc_captures[] else - ores = ccall( - :jl_new_opaque_closure_from_code_info, - Any, - (Any, Any, Any, Any, Any, Cint, Any, Cint, Cint, Any, Cint), - sig, - rt, - rt, - @__MODULE__, - src, - 0, - nothing, - nargs, - isva, - f, - true, - )::Core.OpaqueClosure + ores = @static if VERSION < v"1.11" + ccall( + :jl_new_opaque_closure_from_code_info, + Any, + (Any, Any, Any, Any, Any, Cint, Any, Cint, Cint, Any, Cint), + sig, + rt, + rt, + @__MODULE__, + src, + 0, + nothing, + nargs, + isva, + f, + true, + )::Core.OpaqueClosure + else + ccall( + :jl_new_opaque_closure_from_code_info, + Any, + (Any, Any, Any, Any, Any, Cint, Any, Cint, Cint, Any, Cint, Cint), + sig, # jl_tupletype_t *argt + rt, # jl_value_t *rt_lb + rt, # jl_value_t *rt_ub + @__MODULE__, # jl_module_t *mod + src, # jl_code_info_t *ci + 0, # int lineno + nothing, # jl_value_t *file + nargs, # int nargs + isva, # int isva + f, # jl_value_t *env + true, # int do_compile + true, # int isinferred + )::Core.OpaqueClosure + end oc_captures[] = ores return ores end @@ -610,7 +634,10 @@ end # using a custom interpreter in type unstable code. # `redub_arguments` is `(typeof(original_function), map(typeof, original_args_tuple)...)` function call_with_reactant_generator( - world::UInt, source::LineNumberNode, self, @nospecialize(redub_arguments) + world::UInt, + source::Union{LineNumberNode,Core.Method}, + self, + @nospecialize(redub_arguments) ) @nospecialize args = redub_arguments @@ -722,7 +749,9 @@ function call_with_reactant_generator( src.slotnames = fill(:none, length(ir.argtypes) + 1) src.slotflags = fill(zero(UInt8), length(ir.argtypes)) src.slottypes = copy(ir.argtypes) - src.rettype = rt + @static if VERSION < v"1.12.0-" + src.rettype = rt + end src = CC.ir_to_codeinf!(src, ir) if DEBUG_INTERP[] @@ -744,6 +773,12 @@ function call_with_reactant_generator( # and the REDUB_ARGUMENTS_NAME tuple of input arguments code_info.slotnames = Any[:call_with_reactant, REDUB_ARGUMENTS_NAME] code_info.slotflags = UInt8[0x00, 0x00] + + if VERSION >= v"1.12-" + code_info.nargs = length(code_info.slotnames) + code_info.isva = true + end + n_prepended_slots = 2 overdub_args_slot = Core.SlotNumber(n_prepended_slots) @@ -751,10 +786,18 @@ function call_with_reactant_generator( # into these overdubbed equivalents instead of updating `code_info` in-place. Then, at # the end of the pass, we'll reset `code_info` fields accordingly. overdubbed_code = Any[] - overdubbed_codelocs = Int32[] + + overdubbed_codelocs = @static if isdefined(Core, :DebugInfo) + nothing + else + Int32[] + end + function push_inst!(inst) push!(overdubbed_code, inst) - push!(overdubbed_codelocs, code_info.codelocs[1]) + @static if !isdefined(Core, :DebugInfo) + push!(overdubbed_codelocs, code_info.codelocs[1]) + end return Core.SSAValue(length(overdubbed_code)) end # Rewire the arguments from our tuple input of fn and args, to the corresponding calling convention @@ -778,6 +821,11 @@ function call_with_reactant_generator( iter_args = min(n_actual_args, n_method_args - 1) end + if VERSION >= v"1.12-" + src.nargs = length(src.slottypes) + src.isva = false + end + for i in 1:iter_args actual_argument = Expr( :call, Core.GlobalRef(Core, :getfield), overdub_args_slot, offset @@ -859,12 +907,9 @@ function call_with_reactant_generator( farg = nothing rep = Expr(:call, make_oc, dict, octup, rt, src, ocnargs, ocva, farg) push_inst!(rep) - Core.SSAValue(length(overdubbed_code)) end - push_inst!(Expr(:call, oc, fn_args[1:end]...)) - - ocres = Core.SSAValue(length(overdubbed_code)) + ocres = push_inst!(Expr(:call, oc, fn_args[1:end]...)) if DEBUG_INTERP[] push_inst!(Expr(:call, safe_print, "ocres", ocres)) @@ -879,7 +924,13 @@ function call_with_reactant_generator( end code_info.code = overdubbed_code - code_info.codelocs = overdubbed_codelocs + + @static if isdefined(Core, :DebugInfo) + code_info.debuginfo = Core.DebugInfo(:none) # Core.DebugInfoStream(overdubbed_codelocs), length(overdubbed_codelocs)) + else + code_info.codelocs = overdubbed_codelocs + end + code_info.ssavaluetypes = length(overdubbed_code) code_info.ssaflags = [0x00 for _ in 1:length(overdubbed_code)] # XXX we need to copy flags that are set for the original code diff --git a/src/xla/Device.jl b/src/xla/Device.jl index 19e9ef737f..fd76bb6e3e 100644 --- a/src/xla/Device.jl +++ b/src/xla/Device.jl @@ -11,6 +11,7 @@ function device_kind end function default_memory end function memories end function is_addressable end +function get_local_hardware_id end """ device_ordinal(device::Device) @@ -29,3 +30,74 @@ end function is_addressable(device::AbstractDevice) return device ∈ addressable_devices(client(device)) end + +# Keep in sync with API.cpp +struct DeviceProperties + total_global_mem::Csize_t + shared_mem_per_block::Csize_t + regs_per_block::Cint + warp_size::Cint + max_threads_per_block::Cint + max_threads_dim::NTuple{3,Cint} + max_grid_size::NTuple{3,Cint} + total_const_mem::Csize_t + major::Cint + minor::Cint + multi_processor_count::Cint + can_map_host_memory::Cint + l2_cache_size::Cint + max_threads_per_multiprocessor::Cint +end + +const DEVICE_PROPERTIES_CACHE = Dict{Tuple{Int,String},DeviceProperties}() + +""" + device_properties(device::AbstractDevice) + +Get a struct containing device properties. Which exact fields are populated relies on the +underlying device implementation. +""" +function device_properties(device::AbstractDevice) + pname = platform_name(client(device)) + local_hardware_id = get_local_hardware_id(device) + + if haskey(DEVICE_PROPERTIES_CACHE, (local_hardware_id, pname)) + return DEVICE_PROPERTIES_CACHE[(local_hardware_id, pname)] + end + + jldevprops = Ref{DeviceProperties}() + if pname == "cuda" + GC.@preserve jldevprops begin + @ccall MLIR.API.mlir_c.ReactantCudaDeviceGetProperties( + jldevprops::Ptr{Cvoid}, local_hardware_id::Cint + )::Cvoid + end + else + @warn "`get_properties` not implemented for platform: $(pname)" maxlog = 1 + end + DEVICE_PROPERTIES_CACHE[(local_hardware_id, pname)] = jldevprops[] + return jldevprops[] +end + +function Base.show(io::IO, ::MIME"text/plain", props::DeviceProperties) + return print( + io, + """ + DeviceProperties + ---------------- + Total Global Mem: $(_format_bytes(props.total_global_mem)) + Shared Mem Per Block: $(_format_bytes(props.shared_mem_per_block)) + Regs Per Block: $(props.regs_per_block) + Warp Size: $(props.warp_size) + Max Threads Per Block: $(props.max_threads_per_block) + Max Threads Dim: $(props.max_threads_dim) + Max Grid Size: $(props.max_grid_size) + Total Const Mem: $(_format_bytes(props.total_const_mem)) + Version: $(VersionNumber(props.major, props.minor)) + Multi Processor Count: $(props.multi_processor_count) + Can Map Host Memory: $(props.can_map_host_memory) + L2 Cache Size: $(props.l2_cache_size) + Max Threads Per Multiprocessor: $(props.max_threads_per_multiprocessor) + """, + ) +end diff --git a/src/xla/IFRT/Device.jl b/src/xla/IFRT/Device.jl index 7d269e166c..672900454a 100644 --- a/src/xla/IFRT/Device.jl +++ b/src/xla/IFRT/Device.jl @@ -31,6 +31,14 @@ function XLA.get_local_device_id(::Device) return error("Not implemented for ifrt devices") end +function XLA.get_local_hardware_id(device::Device) + GC.@preserve device begin + return @ccall MLIR.API.mlir_c.ifrt_DeviceGetLocalHardwareId( + device.device::Ptr{Cvoid} + )::Cint + end +end + function XLA.default_memory(device::Device) GC.@preserve device begin return Memory( diff --git a/src/xla/PJRT/Device.jl b/src/xla/PJRT/Device.jl index 2a29c6279b..4a4dd178e7 100644 --- a/src/xla/PJRT/Device.jl +++ b/src/xla/PJRT/Device.jl @@ -33,6 +33,14 @@ function XLA.get_local_device_id(device::Device) end end +function XLA.get_local_hardware_id(device::Device) + GC.@preserve device begin + return @ccall MLIR.API.mlir_c.PjRtDeviceGetLocalHardwareId( + device.device::Ptr{Cvoid} + )::Cint + end +end + function XLA.is_addressable(device::Device) GC.@preserve device begin return @ccall MLIR.API.mlir_c.pjrt_device_is_addressable( diff --git a/src/xla/PJRT/LoadedExecutable.jl b/src/xla/PJRT/LoadedExecutable.jl index 65aedbf6d9..02e884f6ae 100644 --- a/src/xla/PJRT/LoadedExecutable.jl +++ b/src/xla/PJRT/LoadedExecutable.jl @@ -105,7 +105,11 @@ function XLA.compile( end function execute_ir(N, M, n_outs, with_device::Bool, nmesh_ids::Int64) - ptr = sizeof(Int) == sizeof(Int64) ? "i64" : "i32" + ptr = @static if VERSION < v"1.12" + sizeof(Int) == sizeof(Int64) ? "i64" : "i32" + else + "ptr" + end cint = sizeof(Cint) == sizeof(Int64) ? "i64" : "i32" args = N > 0 ? ", [$N x $ptr] %inps, [$M x i8] %donated" : "" if with_device diff --git a/src/xla/Stats.jl b/src/xla/Stats.jl index bc66cc348a..59f62609c2 100644 --- a/src/xla/Stats.jl +++ b/src/xla/Stats.jl @@ -13,7 +13,7 @@ struct JLAllocatorStats peak_pool_bytes::Int64 end -_format_bytes(x) = Base.format_bytes(x) +_format_bytes(x) = x < 0 ? nothing : Base.format_bytes(x) _format_bytes(x::Nothing) = x """ diff --git a/src/xla/XLA.jl b/src/xla/XLA.jl index 1a7ffc17f2..f14139b890 100644 --- a/src/xla/XLA.jl +++ b/src/xla/XLA.jl @@ -234,15 +234,6 @@ for runtime in (:PJRT, :IFRT) ) state.clients["cuda"] = gpu state.default_client = gpu - - # set values for cuda. This is being done here since we need cuda - # to be initialized before we can use it. initializing the devices - # implicitly initializes cuda. - cc_major = @ccall MLIR.API.mlir_c.ReactantCudaDeviceGetComputeCapalilityMajor()::Int32 - cc_minor = @ccall MLIR.API.mlir_c.ReactantCudaDeviceGetComputeCapalilityMinor()::Int32 - Reactant.Compiler.cubinChip[] = "sm_$(cc_major)$(cc_minor)" - - Reactant.Compiler.cuWarpSize[] = @ccall MLIR.API.mlir_c.ReactantCudaDeviceGetWarpSizeInThreads()::Int32 catch e println(stdout, e) end diff --git a/test/integration/linear_algebra.jl b/test/integration/linear_algebra.jl index 5790bfc928..e6bc28913b 100644 --- a/test/integration/linear_algebra.jl +++ b/test/integration/linear_algebra.jl @@ -432,3 +432,34 @@ end 1e-2 end end + +@testset "Symmetric Multiplication" begin + @testset "F32" begin + A = Symmetric(rand(Float32,(10,10))) + B = rand(Float32,(10,10)) + C = rand(Float32,(10,10)) + A_ra = Reactant.to_rarray(A) + B_ra = Reactant.to_rarray(B) + C_ra = Reactant.to_rarray(C) + + alpha = rand(Float32) + beta = rand(Float32) + + @test @code_hlo optimize=false A_ra * B_ra * alpha + + end + @testset "F64" begin + A = Symmetric(rand(Float64,(10,10))) + B = rand(Float64,(10,10)) + C = rand(Float64,(10,10)) + A_ra = Reactant.to_rarray(A) + B_ra = Reactant.to_rarray(B) + C_ra = Reactant.to_rarray(C) + + alpha = rand(Float64) + beta = rand(Float64) + + @test @code_hlo optimize=false A_ra * B_ra * alpha + + end +end \ No newline at end of file diff --git a/test/nn/nnlib.jl b/test/nn/nnlib.jl index b4547400a6..693722e4c4 100644 --- a/test/nn/nnlib.jl +++ b/test/nn/nnlib.jl @@ -1,4 +1,5 @@ using NNlib, Reactant, Enzyme, Random, Statistics +using Test @testset "Activation Functions" begin sumabs2(f, x) = sum(abs2, f.(x)) diff --git a/test/ops.jl b/test/ops.jl index 104316bd21..b631d20397 100644 --- a/test/ops.jl +++ b/test/ops.jl @@ -791,12 +791,12 @@ end @testset "acos" begin x = Reactant.to_rarray(Float32[-1.0, 0.0, 1.0]) - @test acos.(Array(x)) ≈ @jit(Ops.acos(x)) broken = RunningOnTPU + @test acos.(Array(x)) ≈ @jit(Ops.acos(x)) end @testset "acosh" begin x = Reactant.to_rarray(Float32[1.0, 10.0]) - @test acosh.(Array(x)) ≈ @jit(Ops.acosh(x)) broken = RunningOnTPU + @test acosh.(Array(x)) ≈ @jit(Ops.acosh(x)) end @testset "asin" begin