171 changes: 168 additions & 3 deletions src/backends/onnx.js
@@ -141,6 +141,57 @@ const IS_WEB_ENV = apis.IS_BROWSER_ENV || apis.IS_WEBWORKER_ENV;
*/
let webInitChain = Promise.resolve();

/**
 * Promise that resolves when the WASM binary has been loaded (if caching is enabled).
* This ensures we only attempt to load the WASM binary once.
* @type {Promise<void>|null}
*/
let wasmBinaryLoadPromise = null;

/**
* Ensures the WASM binary is loaded and cached before creating an inference session.
* Only runs once, even if called multiple times.
*
* Note: This caches the WASM binary via the wasmBinary option, which allows it to work offline.
* However, the MJS loader file still needs to be fetched from the network (or bundled).
* For full offline support including the MJS file, you need to either:
* 1. Use a Service Worker to cache all network requests
* 2. Bundle onnxruntime-web with your application
*
* @returns {Promise<void>}
*/
async function ensureWasmBinaryLoaded() {
// If already loading or loaded, return the existing promise
if (wasmBinaryLoadPromise) {
return wasmBinaryLoadPromise;
}

// Check if we should load the WASM binary
if (!env.useWasmCache || !ONNX_ENV?.wasm?.wasmPaths || !IS_WEB_ENV) {
wasmBinaryLoadPromise = Promise.resolve();
return wasmBinaryLoadPromise;
}

// Start loading the WASM binary
wasmBinaryLoadPromise = (async () => {
const urls = getUrlsFromPaths(ONNX_ENV.wasm.wasmPaths);

// Load and cache the WASM binary
if (urls.wasm) {
try {
const wasmBinary = await loadWasmBinary(urls.wasm);
if (wasmBinary) {
ONNX_ENV.wasm.wasmBinary = wasmBinary;
}
} catch (err) {
console.warn('Failed to pre-load WASM binary:', err);
}
}
})();

return wasmBinaryLoadPromise;
}

/**
* Create an ONNX inference session.
* @param {Uint8Array|string} buffer_or_path The ONNX model buffer or path.
@@ -149,6 +200,8 @@ let webInitChain = Promise.resolve();
* @returns {Promise<import('onnxruntime-common').InferenceSession & { config: Object}>} The ONNX inference session.
*/
export async function createInferenceSession(buffer_or_path, session_options, session_config) {
await ensureWasmBinaryLoaded();

const load = () => InferenceSession.create(buffer_or_path, session_options);
const session = await (IS_WEB_ENV ? (webInitChain = webInitChain.then(load)) : load());
session.config = session_config;
@@ -183,6 +236,121 @@ export function isONNXTensor(x) {
return x instanceof ONNX.Tensor;
}

/**
* Get the appropriate cache instance based on environment and settings.
* @returns {Promise<Cache|null>} The cache instance or null if caching is disabled.
*/
async function getWasmCache() {
if (!env.useWasmCache || !IS_WEB_ENV) {
return null;
}

// Try custom cache first
if (env.useCustomCache && env.customCache) {
return env.customCache;
}

// Try browser cache (only relevant for web environments)
if (env.useBrowserCache && typeof caches !== 'undefined') {
try {
return await caches.open(env.cacheKey);
} catch (e) {
console.warn('Failed to open cache:', e);
}
}

return null;
}

/**
* Extracts the WASM and MJS file URLs from the wasmPaths configuration.
* @param {any} wasmPaths The wasmPaths configuration.
* @returns {{wasm: string|null, mjs: string|null}} Object containing both URLs, or null values if not found.
*/
function getUrlsFromPaths(wasmPaths) {
if (!wasmPaths) return { wasm: null, mjs: null };

// If wasmPaths is an object (Safari case), use the wasm and mjs properties
if (typeof wasmPaths === 'object') {
return {
wasm: wasmPaths.wasm ? (wasmPaths.wasm instanceof URL ? wasmPaths.wasm.href : String(wasmPaths.wasm)) : null,
mjs: wasmPaths.mjs ? (wasmPaths.mjs instanceof URL ? wasmPaths.mjs.href : String(wasmPaths.mjs)) : null,
};
}

// If wasmPaths is a string (prefix), append the appropriate file names
if (typeof wasmPaths === 'string') {
// For non-Safari, use asyncify version
return {
wasm: `${wasmPaths}ort-wasm-simd-threaded.asyncify.wasm`,
mjs: `${wasmPaths}ort-wasm-simd-threaded.asyncify.mjs`,
};
}

return { wasm: null, mjs: null };
}

/**
* Loads and caches a file from the given URL.
* @param {string} url The URL of the file to load.
* @returns {Promise<Response|null>} The response object, or null if loading failed.
*/
async function loadAndCacheWasmFile(url) {
try {
const cache = await getWasmCache();
let response;

// Try to get from cache first
if (cache) {
try {
response = await cache.match(url);
} catch (e) {
console.warn(`Error reading wasm file from cache:`, e);
}
}

// If not in cache, fetch it
if (!response) {
response = await fetch(url);

if (!response.ok) {
throw new Error(`Failed to fetch wasm file: ${response.status} ${response.statusText}`);
}

// Cache the response for future use
if (cache) {
try {
await cache.put(url, response.clone());
} catch (e) {
console.warn(`Failed to cache wasm file:`, e);
}
}
}

return response;
} catch (error) {
console.warn(`Failed to load wasm file:`, error);
return null;
}
}

/**
* Loads and caches the WASM binary for ONNX Runtime.
* @param {string} wasmURL The URL of the WASM file to load.
* @returns {Promise<ArrayBuffer|null>} The WASM binary as an ArrayBuffer, or null if loading failed.
*/
async function loadWasmBinary(wasmURL) {
const response = await loadAndCacheWasmFile(wasmURL);
if (!response) return null;

try {
return await response.arrayBuffer();
} catch (error) {
console.warn('Failed to read WASM binary:', error);
return null;
}
}

/** @type {import('onnxruntime-common').Env} */
// @ts-ignore
const ONNX_ENV = ONNX?.env;
@@ -207,9 +375,6 @@ if (ONNX_ENV?.wasm) {
: wasmPathPrefix;
}

// TODO: Add support for loading WASM files from cached buffer when we upgrade to onnxruntime-web@1.19.0
// https://github.com/microsoft/onnxruntime/pull/21534

// Users may wish to proxy the WASM backend to prevent the UI from freezing,
// However, this is not necessary when using WebGPU, so we default to false.
ONNX_ENV.wasm.proxy = false;
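As the comment in `ensureWasmBinaryLoaded` notes, only the WASM binary is cached through the `wasmBinary` option; the `.mjs` loader is still fetched from the network unless it is cached by a Service Worker or bundled. A minimal Service Worker sketch of option 1 follows; the cache name and URL filter are illustrative assumptions and not part of this PR.

// sw.js — minimal sketch: cache the ONNX Runtime loader and WASM assets so they
// remain available offline. Cache name and URL pattern are assumptions for illustration.
const RUNTIME_CACHE = 'ort-runtime-cache';

self.addEventListener('fetch', (event) => {
    const url = event.request.url;
    // Only intercept requests for ONNX Runtime WASM/MJS assets.
    if (!/ort-wasm.*\.(mjs|wasm)$/.test(url)) return;

    event.respondWith(
        caches.open(RUNTIME_CACHE).then(async (cache) => {
            const cached = await cache.match(event.request);
            if (cached) return cached;
            const response = await fetch(event.request);
            if (response.ok) await cache.put(event.request, response.clone());
            return response;
        }),
    );
});
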
7 changes: 7 additions & 0 deletions src/env.js
@@ -155,6 +155,10 @@ const localModelPath = RUNNING_LOCALLY ? path.join(dirname__, DEFAULT_LOCAL_MODE
* @property {Object|null} customCache The custom cache to use. Defaults to `null`. Note: this must be an object which
* implements the `match` and `put` functions of the Web Cache API. For more information, see https://developer.mozilla.org/en-US/docs/Web/API/Cache.
* If you wish, you may also return a `Promise<string>` from the `match` function if you'd like to use a file path instead of `Promise<Response>`.
 * @property {boolean} useWasmCache Whether to pre-load and cache WASM binaries for ONNX Runtime. Defaults to `true` when a cache (browser Cache API or filesystem) is available.
* This can improve performance by avoiding repeated downloads of WASM files. Note: Only the WASM binary is cached.
* The MJS loader file still requires network access unless you use a Service Worker.
* @property {string} cacheKey The cache key to use for storing models and WASM binaries. Defaults to 'transformers-cache'.
*/

/** @type {TransformersEnvironment} */
@@ -185,6 +189,9 @@ export const env = {

useCustomCache: false,
customCache: null,

useWasmCache: IS_WEB_CACHE_AVAILABLE || IS_FS_AVAILABLE,
cacheKey: 'transformers-cache',
//////////////////////////////////////////////////////
};

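Putting the new `src/env.js` options together, a consumer might toggle the cache like this. This is a sketch, not part of the PR; the import path is an assumption and should be adjusted to the package name used in your setup.

// Sketch: enabling the WASM binary cache and sharing a cache key with model files.
// The package import path is an assumption for illustration only.
import { env, pipeline } from '@huggingface/transformers';

env.useWasmCache = true;             // pre-load and cache the ONNX Runtime WASM binary
env.cacheKey = 'transformers-cache'; // same named cache now holds models and the WASM binary

// Alternatively, route caching through a custom object implementing match/put:
// env.useCustomCache = true;
// env.customCache = myCacheLikeObject;

const extractor = await pipeline('feature-extraction');
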
2 changes: 1 addition & 1 deletion src/utils/hub.js
@@ -448,7 +448,7 @@ export async function getModelFile(path_or_repo_id, filename, fatal = true, opti
// incognito mode, the following error is thrown: `DOMException: Failed to execute 'open' on 'CacheStorage':
// An attempt was made to break through the security policy of the user agent.`
// So, instead of crashing, we just ignore the error and continue without using the cache.
cache = await caches.open('transformers-cache');
cache = await caches.open(env.cacheKey);
} catch (e) {
console.warn('An error occurred while opening the browser cache:', e);
}
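Because model files and the WASM binary now share `env.cacheKey` (see the `hub.js` change above), both can be cleared together with the standard Cache Storage API. A small sketch, assuming the same illustrative import path as above:

// Sketch: clear everything stored under the shared cache key (browser only).
// caches.delete() is the standard Cache Storage API and removes the entire named cache.
import { env } from '@huggingface/transformers';

await caches.delete(env.cacheKey); // drops cached model files and the WASM binary together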