/**
 * Promise that resolves when the WASM binary has been loaded (if caching is enabled).
 * Kept at module scope so the load is attempted at most once per session;
 * concurrent callers all share the same promise.
 * @type {Promise<void>|null}
 */
let wasmBinaryLoadPromise = null;

/**
 * Ensures the WASM binary is loaded and cached before creating an inference session.
 * Only runs once, even if called multiple times.
 *
 * Note: This caches the WASM binary via the `wasmBinary` option, which allows it to work offline.
 * However, the MJS loader file still needs to be fetched from the network (or bundled).
 * For full offline support including the MJS file, you need to either:
 * 1. Use a Service Worker to cache all network requests
 * 2. Bundle onnxruntime-web with your application
 *
 * @returns {Promise<void>} Resolves when the binary has been cached, or when loading was skipped.
 */
async function ensureWasmBinaryLoaded() {
    // If already loading or loaded, return the existing (shared) promise.
    if (wasmBinaryLoadPromise) {
        return wasmBinaryLoadPromise;
    }

    // Skip when caching is disabled, no wasm paths are configured, or we are
    // not in a web environment.
    // NOTE(review): the skip decision is memoized permanently — toggling
    // `env.useWasmCache` after the first call has no effect. This matches the
    // "only runs once" contract above.
    if (!env.useWasmCache || !ONNX_ENV?.wasm?.wasmPaths || !IS_WEB_ENV) {
        wasmBinaryLoadPromise = Promise.resolve();
        return wasmBinaryLoadPromise;
    }

    // Start loading the WASM binary. All failures are caught inside, so the
    // shared promise never rejects and is safe to await from multiple sites.
    wasmBinaryLoadPromise = (async () => {
        const urls = getUrlsFromPaths(ONNX_ENV.wasm.wasmPaths);

        // Load and cache the WASM binary. On failure we degrade gracefully to
        // onnxruntime-web's default network-based loading.
        if (urls.wasm) {
            try {
                const wasmBinary = await loadWasmBinary(urls.wasm);
                if (wasmBinary) {
                    ONNX_ENV.wasm.wasmBinary = wasmBinary;
                }
            } catch (err) {
                console.warn('Failed to pre-load WASM binary:', err);
            }
        }
    })();

    return wasmBinaryLoadPromise;
}
/**
 * Create an ONNX inference session.
 * @param {Uint8Array|string} buffer_or_path The ONNX model buffer or path.
 * @param {Object} session_options ONNX inference session options.
 * @param {Object} session_config ONNX inference session configuration.
 * @returns {Promise<Object>} The ONNX inference session.
 */
export async function createInferenceSession(buffer_or_path, session_options, session_config) {
    // Pre-load (and cache) the WASM binary before the first session is created.
    await ensureWasmBinaryLoaded();

    const load = () => InferenceSession.create(buffer_or_path, session_options);
    // In web environments, session creation is serialized through `webInitChain`
    // to avoid concurrent WASM backend initialization.
    const session = await (IS_WEB_ENV ? (webInitChain = webInitChain.then(load)) : load());
    session.config = session_config;
    return session;
}

/**
 * Check whether an object is an ONNX tensor.
 * @param {any} x The object to check.
 * @returns {boolean} Whether the object is an ONNX tensor.
 */
export function isONNXTensor(x) {
    return x instanceof ONNX.Tensor;
}

/**
 * Get the appropriate cache instance based on environment and settings.
 * Resolution order: custom cache (if enabled) → browser Cache API → null.
 * @returns {Promise<Cache|Object|null>} The cache instance, or null if caching is disabled/unavailable.
 */
async function getWasmCache() {
    if (!env.useWasmCache || !IS_WEB_ENV) {
        return null;
    }

    // Try custom cache first
    if (env.useCustomCache && env.customCache) {
        return env.customCache;
    }

    // Try browser cache (only relevant for web environments)
    if (env.useBrowserCache && typeof caches !== 'undefined') {
        try {
            return await caches.open(env.cacheKey);
        } catch (e) {
            // e.g. incognito mode may deny CacheStorage access; degrade gracefully.
            console.warn('Failed to open cache:', e);
        }
    }

    return null;
}

/**
 * Extracts the WASM and MJS file URLs from the wasmPaths configuration.
 * @param {any} wasmPaths The wasmPaths configuration (string prefix, or an object with `wasm`/`mjs` entries).
 * @returns {{wasm: string|null, mjs: string|null}} Object containing both URLs, or null values if not found.
 */
function getUrlsFromPaths(wasmPaths) {
    if (!wasmPaths) return { wasm: null, mjs: null };

    // If wasmPaths is an object (Safari case), use the wasm and mjs properties
    if (typeof wasmPaths === 'object') {
        return {
            wasm: wasmPaths.wasm ? (wasmPaths.wasm instanceof URL ? wasmPaths.wasm.href : String(wasmPaths.wasm)) : null,
            mjs: wasmPaths.mjs ? (wasmPaths.mjs instanceof URL ? wasmPaths.mjs.href : String(wasmPaths.mjs)) : null,
        };
    }

    // If wasmPaths is a string (prefix), append the appropriate file names
    if (typeof wasmPaths === 'string') {
        // For non-Safari, use asyncify version
        return {
            wasm: `${wasmPaths}ort-wasm-simd-threaded.asyncify.wasm`,
            mjs: `${wasmPaths}ort-wasm-simd-threaded.asyncify.mjs`,
        };
    }

    return { wasm: null, mjs: null };
}

/**
 * Loads and caches a file from the given URL, preferring a cached copy.
 * @param {string} url The URL of the file to load.
 * @returns {Promise<Response|null>} The response object, or null if loading failed.
 */
async function loadAndCacheWasmFile(url) {
    try {
        const cache = await getWasmCache();
        let response;

        // Try to get from cache first
        if (cache) {
            try {
                response = await cache.match(url);
            } catch (e) {
                console.warn(`Error reading wasm file from cache:`, e);
            }
        }

        // If not in cache, fetch it
        if (!response) {
            response = await fetch(url);

            if (!response.ok) {
                throw new Error(`Failed to fetch wasm file: ${response.status} ${response.statusText}`);
            }

            // Cache the response for future use. Clone first, since a Response
            // body can only be consumed once.
            if (cache) {
                try {
                    await cache.put(url, response.clone());
                } catch (e) {
                    console.warn(`Failed to cache wasm file:`, e);
                }
            }
        }

        return response;
    } catch (error) {
        console.warn(`Failed to load wasm file:`, error);
        return null;
    }
}
/**
 * Loads and caches the WASM binary for ONNX Runtime.
 * @param {string} wasmURL The URL of the WASM file to load.
 * @returns {Promise<ArrayBuffer|null>} The WASM binary as an ArrayBuffer, or null if loading failed.
 */
async function loadWasmBinary(wasmURL) {
    // Fetch (or retrieve from cache) the raw response; null means the load failed.
    const response = await loadAndCacheWasmFile(wasmURL);
    if (!response) {
        return null;
    }

    try {
        const binary = await response.arrayBuffer();
        return binary;
    } catch (error) {
        console.warn('Failed to read WASM binary:', error);
        return null;
    }
}
*/ /** @type {TransformersEnvironment} */ @@ -185,6 +189,9 @@ export const env = { useCustomCache: false, customCache: null, + + useWasmCache: IS_WEB_CACHE_AVAILABLE || IS_FS_AVAILABLE, + cacheKey: 'transformers-cache', ////////////////////////////////////////////////////// }; diff --git a/src/utils/hub.js b/src/utils/hub.js index ca5967faf..7d75e7df3 100755 --- a/src/utils/hub.js +++ b/src/utils/hub.js @@ -448,7 +448,7 @@ export async function getModelFile(path_or_repo_id, filename, fatal = true, opti // incognito mode, the following error is thrown: `DOMException: Failed to execute 'open' on 'CacheStorage': // An attempt was made to break through the security policy of the user agent.` // So, instead of crashing, we just ignore the error and continue without using the cache. - cache = await caches.open('transformers-cache'); + cache = await caches.open(env.cacheKey); } catch (e) { console.warn('An error occurred while opening the browser cache:', e); }