From 1a758a0b43c2b7228c4e74702a09daf6c9d5335a Mon Sep 17 00:00:00 2001 From: Nico Martin Date: Mon, 1 Dec 2025 15:00:53 +0100 Subject: [PATCH 1/4] added wasm cache --- src/backends/onnx.js | 171 ++++++++++++++++++++++++++++++++++++++++++- src/env.js | 7 ++ src/utils/hub.js | 2 +- 3 files changed, 176 insertions(+), 4 deletions(-) diff --git a/src/backends/onnx.js b/src/backends/onnx.js index b5923a596..662702d61 100644 --- a/src/backends/onnx.js +++ b/src/backends/onnx.js @@ -141,6 +141,57 @@ const IS_WEB_ENV = apis.IS_BROWSER_ENV || apis.IS_WEBWORKER_ENV; */ let webInitChain = Promise.resolve(); +/** + * Promise that resolves when WASM binary has been loaded (if caching is enabled). + * This ensures we only attempt to load the WASM binary once. + * @type {Promise|null} + */ +let wasmBinaryLoadPromise = null; + +/** + * Ensures the WASM binary is loaded and cached before creating an inference session. + * Only runs once, even if called multiple times. + * + * Note: This caches the WASM binary via the wasmBinary option, which allows it to work offline. + * However, the MJS loader file still needs to be fetched from the network (or bundled). + * For full offline support including the MJS file, you need to either: + * 1. Use a Service Worker to cache all network requests + * 2. 
Bundle onnxruntime-web with your application + * + * @returns {Promise} + */ +async function ensureWasmBinaryLoaded() { + // If already loading or loaded, return the existing promise + if (wasmBinaryLoadPromise) { + return wasmBinaryLoadPromise; + } + + // Check if we should load the WASM binary + if (!env.useWasmCache || !ONNX_ENV?.wasm?.wasmPaths || !IS_WEB_ENV) { + wasmBinaryLoadPromise = Promise.resolve(); + return wasmBinaryLoadPromise; + } + + // Start loading the WASM binary + wasmBinaryLoadPromise = (async () => { + const urls = getUrlsFromPaths(ONNX_ENV.wasm.wasmPaths); + + // Load and cache the WASM binary + if (urls.wasm) { + try { + const wasmBinary = await loadWasmBinary(urls.wasm); + if (wasmBinary) { + ONNX_ENV.wasm.wasmBinary = wasmBinary; + } + } catch (err) { + console.warn('Failed to pre-load WASM binary:', err); + } + } + })(); + + return wasmBinaryLoadPromise; +} + /** * Create an ONNX inference session. * @param {Uint8Array|string} buffer_or_path The ONNX model buffer or path. @@ -149,6 +200,8 @@ let webInitChain = Promise.resolve(); * @returns {Promise} The ONNX inference session. */ export async function createInferenceSession(buffer_or_path, session_options, session_config) { + await ensureWasmBinaryLoaded(); + const load = () => InferenceSession.create(buffer_or_path, session_options); const session = await (IS_WEB_ENV ? (webInitChain = webInitChain.then(load)) : load()); session.config = session_config; @@ -183,6 +236,121 @@ export function isONNXTensor(x) { return x instanceof ONNX.Tensor; } +/** + * Get the appropriate cache instance based on environment and settings. + * @returns {Promise} The cache instance or null if caching is disabled. 
+ */ +async function getWasmCache() { + if (!env.useWasmCache || !IS_WEB_ENV) { + return null; + } + + // Try custom cache first + if (env.useCustomCache && env.customCache) { + return env.customCache; + } + + // Try browser cache (only relevant for web environments) + if (env.useBrowserCache && typeof caches !== 'undefined') { + try { + return await caches.open(env.cacheKey); + } catch (e) { + console.warn('Failed to open cache:', e); + } + } + + return null; +} + +/** + * Extracts the WASM and MJS file URLs from the wasmPaths configuration. + * @param {any} wasmPaths The wasmPaths configuration. + * @returns {{wasm: string|null, mjs: string|null}} Object containing both URLs, or null values if not found. + */ +function getUrlsFromPaths(wasmPaths) { + if (!wasmPaths) return { wasm: null, mjs: null }; + + // If wasmPaths is an object (Safari case), use the wasm and mjs properties + if (typeof wasmPaths === 'object') { + return { + wasm: wasmPaths.wasm ? (wasmPaths.wasm instanceof URL ? wasmPaths.wasm.href : String(wasmPaths.wasm)) : null, + mjs: wasmPaths.mjs ? (wasmPaths.mjs instanceof URL ? wasmPaths.mjs.href : String(wasmPaths.mjs)) : null, + }; + } + + // If wasmPaths is a string (prefix), append the appropriate file names + if (typeof wasmPaths === 'string') { + // For non-Safari, use asyncify version + return { + wasm: `${wasmPaths}ort-wasm-simd-threaded.asyncify.wasm`, + mjs: `${wasmPaths}ort-wasm-simd-threaded.asyncify.mjs`, + }; + } + + return { wasm: null, mjs: null }; +} + +/** + * Loads and caches a file from the given URL. + * @param {string} url The URL of the file to load. + * @returns {Promise} The response object, or null if loading failed. 
+ */ +async function loadAndCacheWasmFile(url) { + try { + const cache = await getWasmCache(); + let response; + + // Try to get from cache first + if (cache) { + try { + response = await cache.match(url); + } catch (e) { + console.warn(`Error reading wasm file from cache:`, e); + } + } + + // If not in cache, fetch it + if (!response) { + response = await fetch(url); + + if (!response.ok) { + throw new Error(`Failed to fetch wasm file: ${response.status} ${response.statusText}`); + } + + // Cache the response for future use + if (cache) { + try { + await cache.put(url, response.clone()); + } catch (e) { + console.warn(`Failed to cache wasm file:`, e); + } + } + } + + return response; + } catch (error) { + console.warn(`Failed to load wasm file:`, error); + return null; + } +} + +/** + * Loads and caches the WASM binary for ONNX Runtime. + * @param {string} wasmURL The URL of the WASM file to load. + * @returns {Promise} The WASM binary as an ArrayBuffer, or null if loading failed. + */ +async function loadWasmBinary(wasmURL) { + const response = await loadAndCacheWasmFile(wasmURL); + if (!response) return null; + + try { + return await response.arrayBuffer(); + } catch (error) { + console.warn('Failed to read WASM binary:', error); + return null; + } +} + /** @type {import('onnxruntime-common').Env} */ // @ts-ignore const ONNX_ENV = ONNX?.env; @@ -207,9 +375,6 @@ if (ONNX_ENV?.wasm) { : wasmPathPrefix; } - // TODO: Add support for loading WASM files from cached buffer when we upgrade to onnxruntime-web@1.19.0 - // https://github.com/microsoft/onnxruntime/pull/21534 - // Users may wish to proxy the WASM backend to prevent the UI from freezing, // However, this is not necessary when using WebGPU, so we default to false. ONNX_ENV.wasm.proxy = false; diff --git a/src/env.js b/src/env.js index 7b8fc9d03..c279f7286 100644 --- a/src/env.js +++ b/src/env.js @@ -155,6 +155,10 @@ const localModelPath = RUNNING_LOCALLY ? 
path.join(dirname__, DEFAULT_LOCAL_MODE * @property {Object|null} customCache The custom cache to use. Defaults to `null`. Note: this must be an object which * implements the `match` and `put` functions of the Web Cache API. For more information, see https://developer.mozilla.org/en-US/docs/Web/API/Cache. * If you wish, you may also return a `Promise` from the `match` function if you'd like to use a file path instead of `Promise`. + * @property {boolean} useWasmCache Whether to pre-load and cache WASM binaries for ONNX Runtime. Defaults to `true` when cache is available. + * This can improve performance by avoiding repeated downloads of WASM files. Note: Only the WASM binary is cached. + * The MJS loader file still requires network access unless you use a Service Worker. + * @property {string} cacheKey The cache key to use for storing models and WASM binaries. Defaults to 'transformers-cache'. */ /** @type {TransformersEnvironment} */ @@ -185,6 +189,9 @@ export const env = { useCustomCache: false, customCache: null, + + useWasmCache: IS_WEB_CACHE_AVAILABLE || IS_FS_AVAILABLE, + cacheKey: 'transformers-cache', ////////////////////////////////////////////////////// }; diff --git a/src/utils/hub.js b/src/utils/hub.js index ca5967faf..7d75e7df3 100755 --- a/src/utils/hub.js +++ b/src/utils/hub.js @@ -448,7 +448,7 @@ export async function getModelFile(path_or_repo_id, filename, fatal = true, opti // incognito mode, the following error is thrown: `DOMException: Failed to execute 'open' on 'CacheStorage': // An attempt was made to break through the security policy of the user agent.` // So, instead of crashing, we just ignore the error and continue without using the cache. 
- cache = await caches.open('transformers-cache'); + cache = await caches.open(env.cacheKey); } catch (e) { console.warn('An error occurred while opening the browser cache:', e); } From 1ce23bcc41cfc501e0b366caa5ab191057e799de Mon Sep 17 00:00:00 2001 From: Nico Martin Date: Thu, 4 Dec 2025 11:33:36 +0100 Subject: [PATCH 2/4] some refactoring of the hub.js and caching of the wasm factory --- src/backends/onnx.js | 199 +++++---------- src/backends/utils/cacheWasm.js | 83 ++++++ src/env.js | 3 +- src/utils/cache.js | 82 ++++++ src/utils/hub.js | 432 ++------------------------------ src/utils/hub/FileCache.js | 92 +++++++ src/utils/hub/FileResponse.js | 121 +++++++++ src/utils/hub/constants.js | 18 ++ src/utils/hub/utils.js | 128 ++++++++++ webpack.config.js | 4 +- 10 files changed, 601 insertions(+), 561 deletions(-) create mode 100644 src/backends/utils/cacheWasm.js create mode 100644 src/utils/cache.js create mode 100644 src/utils/hub/FileCache.js create mode 100644 src/utils/hub/FileResponse.js create mode 100644 src/utils/hub/constants.js create mode 100644 src/utils/hub/utils.js diff --git a/src/backends/onnx.js b/src/backends/onnx.js index 662702d61..38f303403 100644 --- a/src/backends/onnx.js +++ b/src/backends/onnx.js @@ -22,6 +22,7 @@ import { env, apis } from '../env.js'; // In either case, we select the default export if it exists, otherwise we use the named export. import * as ONNX_NODE from 'onnxruntime-node'; import * as ONNX_WEB from 'onnxruntime-web/webgpu'; +import { loadWasmBinary, loadWasmFactory } from './utils/cacheWasm.js'; export { Tensor } from 'onnxruntime-common'; @@ -146,7 +147,7 @@ let webInitChain = Promise.resolve(); * This ensures we only attempt to load the WASM binary once. * @type {Promise|null} */ -let wasmBinaryLoadPromise = null; +let wasmLoadPromise = null; /** * Ensures the WASM binary is loaded and cached before creating an inference session. 
@@ -160,36 +161,64 @@ let wasmBinaryLoadPromise = null; * * @returns {Promise} */ -async function ensureWasmBinaryLoaded() { +async function ensureWasmLoaded() { // If already loading or loaded, return the existing promise - if (wasmBinaryLoadPromise) { - return wasmBinaryLoadPromise; + if (wasmLoadPromise) { + return wasmLoadPromise; } + const shouldUseWasmCache = + env.useWasmCache && + typeof ONNX_ENV?.wasm?.wasmPaths === 'object' && + ONNX_ENV?.wasm?.wasmPaths?.wasm && + ONNX_ENV?.wasm?.wasmPaths?.mjs; + // Check if we should load the WASM binary - if (!env.useWasmCache || !ONNX_ENV?.wasm?.wasmPaths || !IS_WEB_ENV) { - wasmBinaryLoadPromise = Promise.resolve(); - return wasmBinaryLoadPromise; + if (!shouldUseWasmCache) { + wasmLoadPromise = Promise.resolve(); + return wasmLoadPromise; } // Start loading the WASM binary - wasmBinaryLoadPromise = (async () => { - const urls = getUrlsFromPaths(ONNX_ENV.wasm.wasmPaths); - - // Load and cache the WASM binary - if (urls.wasm) { - try { - const wasmBinary = await loadWasmBinary(urls.wasm); - if (wasmBinary) { - ONNX_ENV.wasm.wasmBinary = wasmBinary; - } - } catch (err) { - console.warn('Failed to pre-load WASM binary:', err); - } - } + wasmLoadPromise = (async () => { + // At this point, we know wasmPaths is an object (not a string) because + // shouldUseWasmCache checks for wasmPaths.wasm and wasmPaths.mjs + const urls = /** @type {{ wasm: string, mjs: string }} */ (ONNX_ENV.wasm.wasmPaths); + + // Load and cache both the WASM binary and factory + await Promise.all([ + // Load and cache the WASM binary + urls.wasm + ? (async () => { + try { + const wasmBinary = await loadWasmBinary(urls.wasm); + if (wasmBinary) { + ONNX_ENV.wasm.wasmBinary = wasmBinary; + } + } catch (err) { + console.warn('Failed to pre-load WASM binary:', err); + } + })() + : Promise.resolve(), + + // Load and cache the WASM factory + urls.mjs + ? 
(async () => { + try { + const wasmFactoryBlob = await loadWasmFactory(urls.mjs); + if (wasmFactoryBlob) { + // @ts-ignore + ONNX_ENV.wasm.wasmPaths.mjs = wasmFactoryBlob; + } + } catch (err) { + console.warn('Failed to pre-load WASM factory:', err); + } + })() + : Promise.resolve(), + ]); })(); - return wasmBinaryLoadPromise; + return wasmLoadPromise; } /** @@ -200,7 +229,7 @@ async function ensureWasmBinaryLoaded() { * @returns {Promise} The ONNX inference session. */ export async function createInferenceSession(buffer_or_path, session_options, session_config) { - await ensureWasmBinaryLoaded(); + await ensureWasmLoaded(); const load = () => InferenceSession.create(buffer_or_path, session_options); const session = await (IS_WEB_ENV ? (webInitChain = webInitChain.then(load)) : load()); @@ -236,121 +265,6 @@ export function isONNXTensor(x) { return x instanceof ONNX.Tensor; } -/** - * Get the appropriate cache instance based on environment and settings. - * @returns {Promise} The cache instance or null if caching is disabled. - */ -async function getWasmCache() { - if (!env.useWasmCache || !IS_WEB_ENV) { - return null; - } - - // Try custom cache first - if (env.useCustomCache && env.customCache) { - return env.customCache; - } - - // Try browser cache (only relevant for web environments) - if (env.useBrowserCache && typeof caches !== 'undefined') { - try { - return await caches.open(env.cacheKey); - } catch (e) { - console.warn('Failed to open cache:', e); - } - } - - return null; -} - -/** - * Extracts the WASM and MJS file URLs from the wasmPaths configuration. - * @param {any} wasmPaths The wasmPaths configuration. - * @returns {{wasm: string|null, mjs: string|null}} Object containing both URLs, or null values if not found. 
- */ -function getUrlsFromPaths(wasmPaths) { - if (!wasmPaths) return { wasm: null, mjs: null }; - - // If wasmPaths is an object (Safari case), use the wasm and mjs properties - if (typeof wasmPaths === 'object') { - return { - wasm: wasmPaths.wasm ? (wasmPaths.wasm instanceof URL ? wasmPaths.wasm.href : String(wasmPaths.wasm)) : null, - mjs: wasmPaths.mjs ? (wasmPaths.mjs instanceof URL ? wasmPaths.mjs.href : String(wasmPaths.mjs)) : null, - }; - } - - // If wasmPaths is a string (prefix), append the appropriate file names - if (typeof wasmPaths === 'string') { - // For non-Safari, use asyncify version - return { - wasm: `${wasmPaths}ort-wasm-simd-threaded.asyncify.wasm`, - mjs: `${wasmPaths}ort-wasm-simd-threaded.asyncify.mjs`, - }; - } - - return { wasm: null, mjs: null }; -} - -/** - * Loads and caches a file from the given URL. - * @param {string} url The URL of the file to load. - * @returns {Promise} The response object, or null if loading failed. - */ -async function loadAndCacheWasmFile(url) { - try { - const cache = await getWasmCache(); - let response; - - // Try to get from cache first - if (cache) { - try { - response = await cache.match(url); - } catch (e) { - console.warn(`Error reading wasm file from cache:`, e); - } - } - - // If not in cache, fetch it - if (!response) { - response = await fetch(url); - - if (!response.ok) { - throw new Error(`Failed to fetch wasm file: ${response.status} ${response.statusText}`); - } - - // Cache the response for future use - if (cache) { - try { - await cache.put(url, response.clone()); - } catch (e) { - console.warn(`Failed to cache wasm file:`, e); - } - } - } - - return response; - } catch (error) { - console.warn(`Failed to load wasm file:`, error); - return null; - } -} - -/** - * Loads and caches the WASM binary for ONNX Runtime. - * @param {string} wasmURL The URL of the WASM file to load. - * @returns {Promise} The WASM binary as an ArrayBuffer, or null if loading failed. 
- */ -async function loadWasmBinary(wasmURL) { - const response = await loadAndCacheWasmFile(wasmURL); - if (!response) return null; - - try { - return await response.arrayBuffer(); - } catch (error) { - console.warn('Failed to read WASM binary:', error); - return null; - } -} - /** @type {import('onnxruntime-common').Env} */ // @ts-ignore const ONNX_ENV = ONNX?.env; @@ -369,10 +283,13 @@ if (ONNX_ENV?.wasm) { ONNX_ENV.wasm.wasmPaths = apis.IS_SAFARI ? { - mjs: `${wasmPathPrefix}/ort-wasm-simd-threaded.mjs`, - wasm: `${wasmPathPrefix}/ort-wasm-simd-threaded.wasm`, + mjs: `${wasmPathPrefix}ort-wasm-simd-threaded.mjs`, + wasm: `${wasmPathPrefix}ort-wasm-simd-threaded.wasm`, } - : wasmPathPrefix; + : { + mjs: `${wasmPathPrefix}ort-wasm-simd-threaded.asyncify.mjs`, + wasm: `${wasmPathPrefix}ort-wasm-simd-threaded.asyncify.wasm`, + }; } // Users may wish to proxy the WASM backend to prevent the UI from freezing, diff --git a/src/backends/utils/cacheWasm.js b/src/backends/utils/cacheWasm.js new file mode 100644 index 000000000..237a7e2b1 --- /dev/null +++ b/src/backends/utils/cacheWasm.js @@ -0,0 +1,83 @@ +import { getCache } from '../../utils/cache.js'; + +/** + * Loads and caches a file from the given URL. + * @param {string} url The URL of the file to load. + * @returns {Promise} The response object, or null if loading failed. 
+ */
+async function loadAndCacheFile(url) {
+    const fileName = url.split('/').pop();
+    try {
+        const cache = await getCache();
+
+        // Try the cache first; a miss resolves to `undefined` and must fall through to the fetch
+        if (cache) {
+            try {
+                const cached = await cache.match(url);
+                if (cached) return cached;
+            } catch (e) {
+                console.warn(`Error reading ${fileName} from cache:`, e);
+            }
+        }
+
+        // If not in cache, fetch it
+        const response = await fetch(url);
+
+        if (!response.ok) {
+            throw new Error(`Failed to fetch ${fileName}: ${response.status} ${response.statusText}`);
+        }
+
+        // Cache the response for future use
+        if (cache) {
+            try {
+                await cache.put(url, response.clone());
+            } catch (e) {
+                console.warn(`Failed to cache ${fileName}:`, e);
+            }
+        }
+
+        return response;
+    } catch (error) {
+        console.warn(`Failed to load ${fileName}:`, error);
+        return null;
+    }
+}
+
+/**
+ * Loads and caches the WASM binary for ONNX Runtime.
+ * @param {string} wasmURL The URL of the WASM file to load.
+ * @returns {Promise<ArrayBuffer|null>} The WASM binary as an ArrayBuffer, or null if loading failed.
+ */
+export async function loadWasmBinary(wasmURL) {
+    const response = await loadAndCacheFile(wasmURL);
+    if (!response) return null;
+
+    try {
+        return await response.arrayBuffer();
+    } catch (error) {
+        console.warn('Failed to read WASM binary:', error);
+        return null;
+    }
+}
+
+/**
+ * Loads and caches the WASM Factory for ONNX Runtime.
+ * @param {string} libURL The URL of the WASM Factory to load.
+ * @returns {Promise<string|null>} The blob URL of the WASM Factory, or null if loading failed.
+ */
+export async function loadWasmFactory(libURL) {
+    const response = await loadAndCacheFile(libURL);
+    if (!response) return null;
+
+    try {
+        let code = await response.text();
+        // Fix relative paths when loading factory from blob, overwrite import.meta.url with actual baseURL
+        const baseUrl = libURL.split('/').slice(0, -1).join('/');
+        code = code.replace(/import\.meta\.url/g, `"${baseUrl}"`);
+        const blob = new Blob([code], { type: 'text/javascript' });
+        return URL.createObjectURL(blob);
+    } catch (error) {
+        console.warn('Failed to read WASM factory:', error);
+        return null;
+    }
+}
diff --git a/src/env.js b/src/env.js
index c279f7286..bf19628cf 100644
--- a/src/env.js
+++ b/src/env.js
@@ -152,9 +152,8 @@ const localModelPath = RUNNING_LOCALLY ? path.join(dirname__, DEFAULT_LOCAL_MODE
  * @property {boolean} useFSCache Whether to use the file system to cache files. By default, it is `true` if available.
  * @property {string|null} cacheDir The directory to use for caching files with the file system. By default, it is `./.cache`.
  * @property {boolean} useCustomCache Whether to use a custom cache system (defined by `customCache`), defaults to `false`.
- * @property {Object|null} customCache The custom cache to use. Defaults to `null`. Note: this must be an object which
+ * @property {import('./utils/cache.js').CacheInterface|null} customCache The custom cache to use. Defaults to `null`. Note: this must be an object which
  * implements the `match` and `put` functions of the Web Cache API. For more information, see https://developer.mozilla.org/en-US/docs/Web/API/Cache.
- * If you wish, you may also return a `Promise` from the `match` function if you'd like to use a file path instead of `Promise`.
  * @property {boolean} useWasmCache Whether to pre-load and cache WASM binaries for ONNX Runtime. Defaults to `true` when cache is available.
  * This can improve performance by avoiding repeated downloads of WASM files. Note: Only the WASM binary is cached.
* The MJS loader file still requires network access unless you use a Service Worker. diff --git a/src/utils/cache.js b/src/utils/cache.js new file mode 100644 index 000000000..0f8582c5b --- /dev/null +++ b/src/utils/cache.js @@ -0,0 +1,82 @@ +import { apis, env } from '../env.js'; +import FileCache from './hub/FileCache.js'; + +/** + * @typedef {Object} CacheInterface + * @property {(request: string) => Promise} match + * Checks if a request is in the cache and returns the cached response if found. + * @property {(request: string, response: Response, progress_callback?: (data: {progress: number, loaded: number, total: number}) => void) => Promise} put + * Adds a response to the cache. + */ + +/** + * Retrieves an appropriate caching backend based on the environment configuration. + * Attempts to use custom cache, browser cache, or file system cache in that order of priority. + * @returns {Promise} + * @param file_cache_dir {string|null} Path to a directory in which a downloaded pretrained model configuration should be cached if using the file system cache. + */ +export async function getCache(file_cache_dir = null) { + // First, check if the a caching backend is available + // If no caching mechanism available, will download the file every time + let cache = null; + if (env.useCustomCache) { + // Allow the user to specify a custom cache system. + if (!env.customCache) { + throw Error('`env.useCustomCache=true`, but `env.customCache` is not defined.'); + } + + // Check that the required methods are defined: + if (!env.customCache.match || !env.customCache.put) { + throw new Error( + '`env.customCache` must be an object which implements the `match` and `put` functions of the Web Cache API. 
' + + 'For more information, see https://developer.mozilla.org/en-US/docs/Web/API/Cache', + ); + } + cache = env.customCache; + } + + if (!cache && env.useBrowserCache) { + if (typeof caches === 'undefined') { + throw Error('Browser cache is not available in this environment.'); + } + try { + // In some cases, the browser cache may be visible, but not accessible due to security restrictions. + // For example, when running an application in an iframe, if a user attempts to load the page in + // incognito mode, the following error is thrown: `DOMException: Failed to execute 'open' on 'CacheStorage': + // An attempt was made to break through the security policy of the user agent.` + // So, instead of crashing, we just ignore the error and continue without using the cache. + cache = await caches.open(env.cacheKey); + } catch (e) { + console.warn('An error occurred while opening the browser cache:', e); + } + } + + if (!cache && env.useFSCache) { + if (!apis.IS_FS_AVAILABLE) { + throw Error('File System Cache is not available in this environment.'); + } + + // If `cache_dir` is not specified, use the default cache directory + cache = new FileCache(file_cache_dir ?? env.cacheDir); + } + + return cache; +} + +/** + * Searches the cache for any of the provided names and returns the first match found. + * @param {CacheInterface} cache The cache to search + * @param {...string} names The names of the items to search for + * @returns {Promise} The item from the cache, or undefined if not found. 
+ */ +export async function tryCache(cache, ...names) { + for (let name of names) { + try { + let result = await cache.match(name); + if (result) return result; + } catch (e) { + continue; + } + } + return undefined; +} diff --git a/src/utils/hub.js b/src/utils/hub.js index 7d75e7df3..218b812a1 100755 --- a/src/utils/hub.js +++ b/src/utils/hub.js @@ -4,18 +4,22 @@ * @module utils/hub */ -import fs from 'node:fs'; -import path from 'node:path'; - import { apis, env } from '../env.js'; import { dispatchCallback } from './core.js'; +import FileResponse from './hub/FileResponse.js'; +import FileCache from './hub/FileCache.js'; +import { handleError, isValidUrl, pathJoin, isValidHfModelId, readResponse } from './hub/utils.js'; +import { getCache, tryCache } from './cache.js'; + +export { MAX_EXTERNAL_DATA_CHUNKS } from './hub/constants.js'; /** - * @typedef {boolean|number} ExternalData Whether to load the model using the external data format (used for models >= 2GB in size). - * If `true`, the model will be loaded using the external data format. - * If a number, this many chunks will be loaded using the external data format (of the form: "model.onnx_data[_{chunk_number}]"). + * @typedef {boolean|number} ExternalData + * Specifies whether to load the model using the external data format. + * - `false`: Do not use external data format + * - `true`: Use external data format with 1 chunk + * - `number`: Use external data format with the specified number of chunks */ -export const MAX_EXTERNAL_DATA_CHUNKS = 100; /** * @typedef {Object} PretrainedOptions Options for loading a pretrained model. @@ -45,165 +49,6 @@ export const MAX_EXTERNAL_DATA_CHUNKS = 100; * @typedef {PretrainedOptions & ModelSpecificPretrainedOptions} PretrainedModelOptions Options for loading a pretrained model. */ -/** - * Mapping from file extensions to MIME types. 
- */ -const CONTENT_TYPE_MAP = { - txt: 'text/plain', - html: 'text/html', - css: 'text/css', - js: 'text/javascript', - json: 'application/json', - png: 'image/png', - jpg: 'image/jpeg', - jpeg: 'image/jpeg', - gif: 'image/gif', -}; -class FileResponse { - /** - * Creates a new `FileResponse` object. - * @param {string} filePath - */ - constructor(filePath) { - this.filePath = filePath; - this.headers = new Headers(); - - this.exists = fs.existsSync(filePath); - if (this.exists) { - this.status = 200; - this.statusText = 'OK'; - - let stats = fs.statSync(filePath); - this.headers.set('content-length', stats.size.toString()); - - this.updateContentType(); - - const stream = fs.createReadStream(filePath); - this.body = new ReadableStream({ - start(controller) { - stream.on('data', (chunk) => controller.enqueue(chunk)); - stream.on('end', () => controller.close()); - stream.on('error', (err) => controller.error(err)); - }, - cancel() { - stream.destroy(); - }, - }); - } else { - this.status = 404; - this.statusText = 'Not Found'; - this.body = null; - } - } - - /** - * Updates the 'content-type' header property of the response based on the extension of - * the file specified by the filePath property of the current object. - * @returns {void} - */ - updateContentType() { - // Set content-type header based on file extension - const extension = this.filePath.toString().split('.').pop().toLowerCase(); - this.headers.set('content-type', CONTENT_TYPE_MAP[extension] ?? 'application/octet-stream'); - } - - /** - * Clone the current FileResponse object. - * @returns {FileResponse} A new FileResponse object with the same properties as the current object. 
- */ - clone() { - let response = new FileResponse(this.filePath); - response.exists = this.exists; - response.status = this.status; - response.statusText = this.statusText; - response.headers = new Headers(this.headers); - return response; - } - - /** - * Reads the contents of the file specified by the filePath property and returns a Promise that - * resolves with an ArrayBuffer containing the file's contents. - * @returns {Promise} A Promise that resolves with an ArrayBuffer containing the file's contents. - * @throws {Error} If the file cannot be read. - */ - async arrayBuffer() { - const data = await fs.promises.readFile(this.filePath); - return /** @type {ArrayBuffer} */ (data.buffer); - } - - /** - * Reads the contents of the file specified by the filePath property and returns a Promise that - * resolves with a Blob containing the file's contents. - * @returns {Promise} A Promise that resolves with a Blob containing the file's contents. - * @throws {Error} If the file cannot be read. - */ - async blob() { - const data = await fs.promises.readFile(this.filePath); - return new Blob([/** @type {any} */ (data)], { type: this.headers.get('content-type') }); - } - - /** - * Reads the contents of the file specified by the filePath property and returns a Promise that - * resolves with a string containing the file's contents. - * @returns {Promise} A Promise that resolves with a string containing the file's contents. - * @throws {Error} If the file cannot be read. - */ - async text() { - const data = await fs.promises.readFile(this.filePath, 'utf8'); - return data; - } - - /** - * Reads the contents of the file specified by the filePath property and returns a Promise that - * resolves with a parsed JavaScript object containing the file's contents. - * - * @returns {Promise} A Promise that resolves with a parsed JavaScript object containing the file's contents. - * @throws {Error} If the file cannot be read. 
- */ - async json() { - return JSON.parse(await this.text()); - } -} - -/** - * Determines whether the given string is a valid URL. - * @param {string|URL} string The string to test for validity as an URL. - * @param {string[]} [protocols=null] A list of valid protocols. If specified, the protocol must be in this list. - * @param {string[]} [validHosts=null] A list of valid hostnames. If specified, the URL's hostname must be in this list. - * @returns {boolean} True if the string is a valid URL, false otherwise. - */ -function isValidUrl(string, protocols = null, validHosts = null) { - let url; - try { - url = new URL(string); - } catch (_) { - return false; - } - if (protocols && !protocols.includes(url.protocol)) { - return false; - } - if (validHosts && !validHosts.includes(url.hostname)) { - return false; - } - return true; -} - -const REPO_ID_REGEX = /^(\b[\w\-.]+\b\/)?\b[\w\-.]{1,96}\b$/; - -/** - * Tests whether a string is a valid Hugging Face model ID or not. - * Adapted from https://github.com/huggingface/huggingface_hub/blob/6378820ebb03f071988a96c7f3268f5bdf8f9449/src/huggingface_hub/utils/_validators.py#L119-L170 - * - * @param {string} string The string to test - * @returns {boolean} True if the string is a valid model ID, false otherwise. - */ -function isValidHfModelId(string) { - if (!REPO_ID_REGEX.test(string)) return false; - if (string.includes('..') || string.includes('--')) return false; - if (string.endsWith('.git') || string.endsWith('.ipynb')) return false; - return true; -} - /** * Helper function to get a file, using either the Fetch API or FileSystem API. 
* @@ -246,142 +91,6 @@ export async function getFile(urlOrPath) { } } -const ERROR_MAPPING = { - // 4xx errors (https://developer.mozilla.org/en-US/docs/Web/HTTP/Status#client_error_responses) - 400: 'Bad request error occurred while trying to load file', - 401: 'Unauthorized access to file', - 403: 'Forbidden access to file', - 404: 'Could not locate file', - 408: 'Request timeout error occurred while trying to load file', - - // 5xx errors (https://developer.mozilla.org/en-US/docs/Web/HTTP/Status#server_error_responses) - 500: 'Internal server error error occurred while trying to load file', - 502: 'Bad gateway error occurred while trying to load file', - 503: 'Service unavailable error occurred while trying to load file', - 504: 'Gateway timeout error occurred while trying to load file', -}; -/** - * Helper method to handle fatal errors that occur while trying to load a file from the Hugging Face Hub. - * @param {number} status The HTTP status code of the error. - * @param {string} remoteURL The URL of the file that could not be loaded. - * @param {boolean} fatal Whether to raise an error if the file could not be loaded. - * @returns {null} Returns `null` if `fatal = true`. - * @throws {Error} If `fatal = false`. - */ -function handleError(status, remoteURL, fatal) { - if (!fatal) { - // File was not loaded correctly, but it is optional. - // TODO in future, cache the response? - return null; - } - - const message = ERROR_MAPPING[status] ?? `Error (${status}) occurred while trying to load file`; - throw Error(`${message}: "${remoteURL}".`); -} - -class FileCache { - /** - * Instantiate a `FileCache` object. - * @param {string} path - */ - constructor(path) { - this.path = path; - } - - /** - * Checks whether the given request is in the cache. 
- * @param {string} request - * @returns {Promise} - */ - async match(request) { - let filePath = path.join(this.path, request); - let file = new FileResponse(filePath); - - if (file.exists) { - return file; - } else { - return undefined; - } - } - - /** - * Adds the given response to the cache. - * @param {string} request - * @param {Response} response - * @param {(data: {progress: number, loaded: number, total: number}) => void} [progress_callback] Optional. - * The function to call with progress updates - * @returns {Promise} - */ - async put(request, response, progress_callback = undefined) { - let filePath = path.join(this.path, request); - - try { - const contentLength = response.headers.get('Content-Length'); - const total = parseInt(contentLength ?? '0'); - let loaded = 0; - - await fs.promises.mkdir(path.dirname(filePath), { recursive: true }); - const fileStream = fs.createWriteStream(filePath); - const reader = response.body.getReader(); - - while (true) { - const { done, value } = await reader.read(); - if (done) { - break; - } - - await new Promise((resolve, reject) => { - fileStream.write(value, (err) => { - if (err) { - reject(err); - return; - } - resolve(); - }); - }); - - loaded += value.length; - const progress = total ? (loaded / total) * 100 : 0; - - progress_callback?.({ progress, loaded, total }); - } - - fileStream.close(); - } catch (error) { - // Clean up the file if an error occurred during download - try { - await fs.promises.unlink(filePath); - } catch {} - throw error; - } - } - - // TODO add the rest? 
- // addAll(requests: RequestInfo[]): Promise; - // delete(request: RequestInfo | URL, options?: CacheQueryOptions): Promise; - // keys(request?: RequestInfo | URL, options?: CacheQueryOptions): Promise>; - // match(request: RequestInfo | URL, options?: CacheQueryOptions): Promise; - // matchAll(request?: RequestInfo | URL, options?: CacheQueryOptions): Promise>; -} - -/** - * - * @param {FileCache|Cache} cache The cache to search - * @param {string[]} names The names of the item to search for - * @returns {Promise} The item from the cache, or undefined if not found. - */ -async function tryCache(cache, ...names) { - for (let name of names) { - try { - let result = await cache.match(name); - if (result) return result; - } catch (e) { - continue; - } - } - return undefined; -} - /** * Retrieves a file from either a remote URL using the Fetch API or from the local file system using the FileSystem API. * If the filesystem is available and `env.useCache = true`, the file will be downloaded and cached. @@ -419,49 +128,8 @@ export async function getModelFile(path_or_repo_id, filename, fatal = true, opti file: filename, }); - // First, check if the a caching backend is available - // If no caching mechanism available, will download the file every time - let cache; - if (!cache && env.useCustomCache) { - // Allow the user to specify a custom cache system. - if (!env.customCache) { - throw Error('`env.useCustomCache=true`, but `env.customCache` is not defined.'); - } - - // Check that the required methods are defined: - if (!env.customCache.match || !env.customCache.put) { - throw new Error( - '`env.customCache` must be an object which implements the `match` and `put` functions of the Web Cache API. 
' + - 'For more information, see https://developer.mozilla.org/en-US/docs/Web/API/Cache', - ); - } - cache = env.customCache; - } - - if (!cache && env.useBrowserCache) { - if (typeof caches === 'undefined') { - throw Error('Browser cache is not available in this environment.'); - } - try { - // In some cases, the browser cache may be visible, but not accessible due to security restrictions. - // For example, when running an application in an iframe, if a user attempts to load the page in - // incognito mode, the following error is thrown: `DOMException: Failed to execute 'open' on 'CacheStorage': - // An attempt was made to break through the security policy of the user agent.` - // So, instead of crashing, we just ignore the error and continue without using the cache. - cache = await caches.open(env.cacheKey); - } catch (e) { - console.warn('An error occurred while opening the browser cache:', e); - } - } - - if (!cache && env.useFSCache) { - if (!apis.IS_FS_AVAILABLE) { - throw Error('File System Cache is not available in this environment.'); - } - - // If `cache_dir` is not specified, use the default cache directory - cache = new FileCache(options.cache_dir ?? env.cacheDir); - } + /** @type {import('./cache.js').CacheInterface | null} */ + const cache = await getCache(options?.cache_dir); const revision = options.revision ?? 'main'; const requestURL = pathJoin(path_or_repo_id, filename); @@ -491,7 +159,7 @@ export async function getModelFile(path_or_repo_id, filename, fatal = true, opti // Whether to cache the final response in the end. 
let toCacheResponse = false; - /** @type {Response|FileResponse|undefined} */ + /** @type {Response|import('./hub/FileResponse.js').default|undefined} */ let response; if (cache) { @@ -503,7 +171,7 @@ export async function getModelFile(path_or_repo_id, filename, fatal = true, opti } const cacheHit = response !== undefined; - if (response === undefined) { + if (!cacheHit) { // Caching not available, or file is not cached, so we perform the request if (env.allowLocalModels) { @@ -726,73 +394,3 @@ export async function getModelJSON(modelPath, fileName, fatal = true, options = return JSON.parse(text); } -/** - * Read and track progress when reading a Response object - * - * @param {Response|FileResponse} response The Response object to read - * @param {(data: {progress: number, loaded: number, total: number}) => void} progress_callback The function to call with progress updates - * @returns {Promise} A Promise that resolves with the Uint8Array buffer - */ -async function readResponse(response, progress_callback) { - const contentLength = response.headers.get('Content-Length'); - if (contentLength === null) { - console.warn('Unable to determine content-length from response headers. Will expand buffer when needed.'); - } - let total = parseInt(contentLength ?? '0'); - let buffer = new Uint8Array(total); - let loaded = 0; - - const reader = response.body.getReader(); - async function read() { - const { done, value } = await reader.read(); - if (done) return; - - const newLoaded = loaded + value.length; - if (newLoaded > total) { - total = newLoaded; - - // Adding the new data will overflow buffer. 
- // In this case, we extend the buffer - const newBuffer = new Uint8Array(total); - - // copy contents - newBuffer.set(buffer); - - buffer = newBuffer; - } - buffer.set(value, loaded); - loaded = newLoaded; - - const progress = (loaded / total) * 100; - - // Call your function here - progress_callback({ progress, loaded, total }); - - return read(); - } - - // Actually read - await read(); - - return buffer; -} - -/** - * Joins multiple parts of a path into a single path, while handling leading and trailing slashes. - * - * @param {...string} parts Multiple parts of a path. - * @returns {string} A string representing the joined path. - */ -function pathJoin(...parts) { - // https://stackoverflow.com/a/55142565 - parts = parts.map((part, index) => { - if (index) { - part = part.replace(new RegExp('^/'), ''); - } - if (index !== parts.length - 1) { - part = part.replace(new RegExp('/$'), ''); - } - return part; - }); - return parts.join('/'); -} diff --git a/src/utils/hub/FileCache.js b/src/utils/hub/FileCache.js new file mode 100644 index 000000000..c227c8c8f --- /dev/null +++ b/src/utils/hub/FileCache.js @@ -0,0 +1,92 @@ +import fs from 'node:fs'; +import path from 'node:path'; +import FileResponse from './FileResponse.js'; + +/** + * File system cache implementation that implements the CacheInterface. + * Provides `match` and `put` methods compatible with the Web Cache API. + */ +export default class FileCache { + /** + * Instantiate a `FileCache` object. + * @param {string} path + */ + constructor(path) { + this.path = path; + } + + /** + * Checks whether the given request is in the cache. + * @param {string} request + * @returns {Promise} + */ + async match(request) { + let filePath = path.join(this.path, request); + let file = new FileResponse(filePath); + + if (file.exists) { + return file; + } else { + return undefined; + } + } + + /** + * Adds the given response to the cache. 
+ * @param {string} request + * @param {Response} response + * @param {(data: {progress: number, loaded: number, total: number}) => void} [progress_callback] Optional. + * The function to call with progress updates + * @returns {Promise} + */ + async put(request, response, progress_callback = undefined) { + let filePath = path.join(this.path, request); + + try { + const contentLength = response.headers.get('Content-Length'); + const total = parseInt(contentLength ?? '0'); + let loaded = 0; + + await fs.promises.mkdir(path.dirname(filePath), { recursive: true }); + const fileStream = fs.createWriteStream(filePath); + const reader = response.body.getReader(); + + while (true) { + const { done, value } = await reader.read(); + if (done) { + break; + } + + await new Promise((resolve, reject) => { + fileStream.write(value, (err) => { + if (err) { + reject(err); + return; + } + resolve(); + }); + }); + + loaded += value.length; + const progress = total ? (loaded / total) * 100 : 0; + + progress_callback?.({ progress, loaded, total }); + } + + fileStream.close(); + } catch (error) { + // Clean up the file if an error occurred during download + try { + await fs.promises.unlink(filePath); + } catch {} + throw error; + } + } + + // TODO add the rest? + // addAll(requests: RequestInfo[]): Promise; + // delete(request: RequestInfo | URL, options?: CacheQueryOptions): Promise; + // keys(request?: RequestInfo | URL, options?: CacheQueryOptions): Promise>; + // match(request: RequestInfo | URL, options?: CacheQueryOptions): Promise; + // matchAll(request?: RequestInfo | URL, options?: CacheQueryOptions): Promise>; +} diff --git a/src/utils/hub/FileResponse.js b/src/utils/hub/FileResponse.js new file mode 100644 index 000000000..2e421f6bc --- /dev/null +++ b/src/utils/hub/FileResponse.js @@ -0,0 +1,121 @@ +import fs from 'node:fs'; + +/** + * Mapping from file extensions to MIME types. 
+ */ +const CONTENT_TYPE_MAP = { + txt: 'text/plain', + html: 'text/html', + css: 'text/css', + js: 'text/javascript', + json: 'application/json', + png: 'image/png', + jpg: 'image/jpeg', + jpeg: 'image/jpeg', + gif: 'image/gif', +}; + +export default class FileResponse { + /** + * Creates a new `FileResponse` object. + * @param {string} filePath + */ + constructor(filePath) { + this.filePath = filePath; + this.headers = new Headers(); + + this.exists = fs.existsSync(filePath); + if (this.exists) { + this.status = 200; + this.statusText = 'OK'; + + let stats = fs.statSync(filePath); + this.headers.set('content-length', stats.size.toString()); + + this.updateContentType(); + + const stream = fs.createReadStream(filePath); + this.body = new ReadableStream({ + start(controller) { + stream.on('data', (chunk) => controller.enqueue(chunk)); + stream.on('end', () => controller.close()); + stream.on('error', (err) => controller.error(err)); + }, + cancel() { + stream.destroy(); + }, + }); + } else { + this.status = 404; + this.statusText = 'Not Found'; + this.body = null; + } + } + + /** + * Updates the 'content-type' header property of the response based on the extension of + * the file specified by the filePath property of the current object. + * @returns {void} + */ + updateContentType() { + // Set content-type header based on file extension + const extension = this.filePath.toString().split('.').pop().toLowerCase(); + this.headers.set('content-type', CONTENT_TYPE_MAP[extension] ?? 'application/octet-stream'); + } + + /** + * Clone the current FileResponse object. + * @returns {FileResponse} A new FileResponse object with the same properties as the current object. 
+ */ + clone() { + let response = new FileResponse(this.filePath); + response.exists = this.exists; + response.status = this.status; + response.statusText = this.statusText; + response.headers = new Headers(this.headers); + return response; + } + + /** + * Reads the contents of the file specified by the filePath property and returns a Promise that + * resolves with an ArrayBuffer containing the file's contents. + * @returns {Promise} A Promise that resolves with an ArrayBuffer containing the file's contents. + * @throws {Error} If the file cannot be read. + */ + async arrayBuffer() { + const data = await fs.promises.readFile(this.filePath); + return /** @type {ArrayBuffer} */ (data.buffer); + } + + /** + * Reads the contents of the file specified by the filePath property and returns a Promise that + * resolves with a Blob containing the file's contents. + * @returns {Promise} A Promise that resolves with a Blob containing the file's contents. + * @throws {Error} If the file cannot be read. + */ + async blob() { + const data = await fs.promises.readFile(this.filePath); + return new Blob([/** @type {any} */ (data)], { type: this.headers.get('content-type') }); + } + + /** + * Reads the contents of the file specified by the filePath property and returns a Promise that + * resolves with a string containing the file's contents. + * @returns {Promise} A Promise that resolves with a string containing the file's contents. + * @throws {Error} If the file cannot be read. + */ + async text() { + return await fs.promises.readFile(this.filePath, 'utf8'); + } + + /** + * Reads the contents of the file specified by the filePath property and returns a Promise that + * resolves with a parsed JavaScript object containing the file's contents. + * + * @returns {Promise} A Promise that resolves with a parsed JavaScript object containing the file's contents. + * @throws {Error} If the file cannot be read. 
+ */ + async json() { + return JSON.parse(await this.text()); + } +} diff --git a/src/utils/hub/constants.js b/src/utils/hub/constants.js new file mode 100644 index 000000000..b87c7e55d --- /dev/null +++ b/src/utils/hub/constants.js @@ -0,0 +1,18 @@ +export const ERROR_MAPPING = { + // 4xx errors (https://developer.mozilla.org/en-US/docs/Web/HTTP/Status#client_error_responses) + 400: 'Bad request error occurred while trying to load file', + 401: 'Unauthorized access to file', + 403: 'Forbidden access to file', + 404: 'Could not locate file', + 408: 'Request timeout error occurred while trying to load file', + + // 5xx errors (https://developer.mozilla.org/en-US/docs/Web/HTTP/Status#server_error_responses) + 500: 'Internal server error error occurred while trying to load file', + 502: 'Bad gateway error occurred while trying to load file', + 503: 'Service unavailable error occurred while trying to load file', + 504: 'Gateway timeout error occurred while trying to load file', +}; + +export const MAX_EXTERNAL_DATA_CHUNKS = 100; + +export const REPO_ID_REGEX = /^(\b[\w\-.]+\b\/)?\b[\w\-.]{1,96}\b$/; diff --git a/src/utils/hub/utils.js b/src/utils/hub/utils.js new file mode 100644 index 000000000..3f8bc8354 --- /dev/null +++ b/src/utils/hub/utils.js @@ -0,0 +1,128 @@ +import { ERROR_MAPPING, REPO_ID_REGEX } from './constants.js'; + +/** + * Joins multiple parts of a path into a single path, while handling leading and trailing slashes. + * + * @param {...string} parts Multiple parts of a path. + * @returns {string} A string representing the joined path. + */ +export function pathJoin(...parts) { + // https://stackoverflow.com/a/55142565 + parts = parts.map((part, index) => { + if (index) { + part = part.replace(new RegExp('^/'), ''); + } + if (index !== parts.length - 1) { + part = part.replace(new RegExp('/$'), ''); + } + return part; + }); + return parts.join('/'); +} + +/** + * Determines whether the given string is a valid URL. 
+ * @param {string|URL} string The string to test for validity as an URL. + * @param {string[]} [protocols=null] A list of valid protocols. If specified, the protocol must be in this list. + * @param {string[]} [validHosts=null] A list of valid hostnames. If specified, the URL's hostname must be in this list. + * @returns {boolean} True if the string is a valid URL, false otherwise. + */ +export function isValidUrl(string, protocols = null, validHosts = null) { + let url; + try { + url = new URL(string); + } catch (_) { + return false; + } + if (protocols && !protocols.includes(url.protocol)) { + return false; + } + if (validHosts && !validHosts.includes(url.hostname)) { + return false; + } + return true; +} + +/** + * Tests whether a string is a valid Hugging Face model ID or not. + * Adapted from https://github.com/huggingface/huggingface_hub/blob/6378820ebb03f071988a96c7f3268f5bdf8f9449/src/huggingface_hub/utils/_validators.py#L119-L170 + * + * @param {string} string The string to test + * @returns {boolean} True if the string is a valid model ID, false otherwise. + */ +export function isValidHfModelId(string) { + if (!REPO_ID_REGEX.test(string)) return false; + if (string.includes('..') || string.includes('--')) return false; + if (string.endsWith('.git') || string.endsWith('.ipynb')) return false; + return true; +} + +/** + * Helper method to handle fatal errors that occur while trying to load a file from the Hugging Face Hub. + * @param {number} status The HTTP status code of the error. + * @param {string} remoteURL The URL of the file that could not be loaded. + * @param {boolean} fatal Whether to raise an error if the file could not be loaded. + * @returns {null} Returns `null` if `fatal = true`. + * @throws {Error} If `fatal = false`. + */ +export function handleError(status, remoteURL, fatal) { + if (!fatal) { + // File was not loaded correctly, but it is optional. + // TODO in future, cache the response? 
+ return null; + } + + const message = ERROR_MAPPING[status] ?? `Error (${status}) occurred while trying to load file`; + throw Error(`${message}: "${remoteURL}".`); +} + +/** + * Read and track progress when reading a Response object + * + * @param {Response|import('./FileResponse.js').default} response The Response object to read + * @param {(data: {progress: number, loaded: number, total: number}) => void} progress_callback The function to call with progress updates + * @returns {Promise} A Promise that resolves with the Uint8Array buffer + */ +export async function readResponse(response, progress_callback) { + const contentLength = response.headers.get('Content-Length'); + if (contentLength === null) { + console.warn('Unable to determine content-length from response headers. Will expand buffer when needed.'); + } + let total = parseInt(contentLength ?? '0'); + let buffer = new Uint8Array(total); + let loaded = 0; + + const reader = response.body.getReader(); + async function read() { + const { done, value } = await reader.read(); + if (done) return; + + const newLoaded = loaded + value.length; + if (newLoaded > total) { + total = newLoaded; + + // Adding the new data will overflow buffer. + // In this case, we extend the buffer + const newBuffer = new Uint8Array(total); + + // copy contents + newBuffer.set(buffer); + + buffer = newBuffer; + } + buffer.set(value, loaded); + loaded = newLoaded; + + const progress = (loaded / total) * 100; + + // Call your function here + progress_callback({ progress, loaded, total }); + + return read(); + } + + // Actually read + await read(); + + return buffer; +} diff --git a/webpack.config.js b/webpack.config.js index 4d8edb24f..d1e264ac2 100644 --- a/webpack.config.js +++ b/webpack.config.js @@ -171,12 +171,14 @@ const NODE_EXTERNAL_MODULES = [ "node:fs", "node:path", "node:url", + "node:stream", + "node:stream/promises", ]; // Do not bundle node-only packages when bundling for the web. 
// NOTE: We can exclude the "node:" prefix for built-in modules here, // since we apply the `StripNodePrefixPlugin` to strip it. -const WEB_IGNORE_MODULES = ["onnxruntime-node", "sharp", "fs", "path", "url"]; +const WEB_IGNORE_MODULES = ["onnxruntime-node", "sharp", "fs", "path", "url", "stream", "stream/promises"]; // Do not bundle the following modules with webpack (mark as external) const WEB_EXTERNAL_MODULES = ["onnxruntime-common", "onnxruntime-web"]; From e480fc36df9060cfb952d20e0cfde0b58d297a26 Mon Sep 17 00:00:00 2001 From: Nico Martin Date: Thu, 4 Dec 2025 17:37:38 +0100 Subject: [PATCH 3/4] fixed comment --- src/backends/onnx.js | 6 ------ 1 file changed, 6 deletions(-) diff --git a/src/backends/onnx.js b/src/backends/onnx.js index 38f303403..9ac28e8df 100644 --- a/src/backends/onnx.js +++ b/src/backends/onnx.js @@ -153,12 +153,6 @@ let wasmLoadPromise = null; * Ensures the WASM binary is loaded and cached before creating an inference session. * Only runs once, even if called multiple times. * - * Note: This caches the WASM binary via the wasmBinary option, which allows it to work offline. - * However, the MJS loader file still needs to be fetched from the network (or bundled). - * For full offline support including the MJS file, you need to either: - * 1. Use a Service Worker to cache all network requests - * 2. 
Bundle onnxruntime-web with your application - * * @returns {Promise} */ async function ensureWasmLoaded() { From 24596d9cdc88fcf7ef3709c78512a9e9205a3dfb Mon Sep 17 00:00:00 2001 From: Nico Martin Date: Fri, 5 Dec 2025 09:45:25 +0100 Subject: [PATCH 4/4] added string as cache return --- src/backends/utils/cacheWasm.js | 6 +++--- src/utils/cache.js | 4 ++-- src/utils/hub.js | 13 +++++++------ 3 files changed, 12 insertions(+), 11 deletions(-) diff --git a/src/backends/utils/cacheWasm.js b/src/backends/utils/cacheWasm.js index 237a7e2b1..0c8170596 100644 --- a/src/backends/utils/cacheWasm.js +++ b/src/backends/utils/cacheWasm.js @@ -3,7 +3,7 @@ import { getCache } from '../../utils/cache.js'; /** * Loads and caches a file from the given URL. * @param {string} url The URL of the file to load. - * @returns {Promise} The response object, or null if loading failed. + * @returns {Promise} The response object, or null if loading failed. */ async function loadAndCacheFile(url) { const fileName = url.split('/').pop(); @@ -50,7 +50,7 @@ async function loadAndCacheFile(url) { export async function loadWasmBinary(wasmURL) { const response = await loadAndCacheFile(wasmURL); - if (!response) return null; + if (!response || typeof response === 'string') return null; try { return await response.arrayBuffer(); @@ -67,7 +67,7 @@ export async function loadWasmBinary(wasmURL) { */ export async function loadWasmFactory(libURL) { const response = await loadAndCacheFile(libURL); - if (!response) return null; + if (!response || typeof response === 'string') return null; try { let code = await response.text(); diff --git a/src/utils/cache.js b/src/utils/cache.js index 0f8582c5b..c803546d5 100644 --- a/src/utils/cache.js +++ b/src/utils/cache.js @@ -3,7 +3,7 @@ import FileCache from './hub/FileCache.js'; /** * @typedef {Object} CacheInterface - * @property {(request: string) => Promise} match + * @property {(request: string) => Promise} match * Checks if a request is in the cache and 
returns the cached response if found. * @property {(request: string, response: Response, progress_callback?: (data: {progress: number, loaded: number, total: number}) => void) => Promise} put * Adds a response to the cache. @@ -67,7 +67,7 @@ export async function getCache(file_cache_dir = null) { * Searches the cache for any of the provided names and returns the first match found. * @param {CacheInterface} cache The cache to search * @param {...string} names The names of the items to search for - * @returns {Promise} The item from the cache, or undefined if not found. + * @returns {Promise} The item from the cache, or undefined if not found. */ export async function tryCache(cache, ...names) { for (let name of names) { diff --git a/src/utils/hub.js b/src/utils/hub.js index 218b812a1..9954b8d36 100755 --- a/src/utils/hub.js +++ b/src/utils/hub.js @@ -159,7 +159,7 @@ export async function getModelFile(path_or_repo_id, filename, fatal = true, opti // Whether to cache the final response in the end. let toCacheResponse = false; - /** @type {Response|import('./hub/FileResponse.js').default|undefined} */ + /** @type {Response|import('./hub/FileResponse.js').default|undefined|string} */ let response; if (cache) { @@ -196,7 +196,7 @@ export async function getModelFile(path_or_repo_id, filename, fatal = true, opti } } - if (response === undefined || response.status === 404) { + if (response === undefined || (typeof response !== 'string' && response.status === 404)) { // File not found locally. 
This means either: // - The user has disabled local file access (`env.allowLocalModels=false`) // - the path is a valid HTTP url (`response === undefined`) @@ -253,14 +253,15 @@ export async function getModelFile(path_or_repo_id, filename, fatal = true, opti /** @type {Uint8Array} */ let buffer; - if (!options.progress_callback) { + if (!options.progress_callback && typeof response !== 'string') { // If no progress callback is specified, we can use the `.arrayBuffer()` // method to read the response. buffer = new Uint8Array(await response.arrayBuffer()); } else if ( cacheHit && // The item is being read from the cache typeof navigator !== 'undefined' && - /firefox/i.test(navigator.userAgent) // We are in Firefox + /firefox/i.test(navigator.userAgent) && // We are in Firefox + typeof response !== 'string' ) { // Due to bug in Firefox, we cannot display progress when loading from cache. // Fortunately, since this should be instantaneous, this should not impact users too much. @@ -275,7 +276,7 @@ export async function getModelFile(path_or_repo_id, filename, fatal = true, opti loaded: buffer.length, total: buffer.length, }); - } else { + } else if (typeof response !== 'string') { buffer = await readResponse(response, (data) => { dispatchCallback(options.progress_callback, { status: 'progress', @@ -309,7 +310,7 @@ export async function getModelFile(path_or_repo_id, filename, fatal = true, opti }) : undefined; await cache.put(cacheKey, /** @type {Response} */ (response), wrapped_progress); - } else { + } else if (typeof response !== 'string') { // NOTE: We use `new Response(buffer, ...)` instead of `response.clone()` to handle LFS files await cache .put(