Skip to content

Commit 60af354

Browse files
committed
feat: support modelObj.params type ParamObject
1 parent cf9e895 commit 60af354

File tree

2 files changed

+9
-4
lines changed

2 files changed

+9
-4
lines changed

packages/paddlejs-core/src/commons/interface.ts

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,9 +63,13 @@ export interface FeedShape {
6363
fh: number;
6464
};
6565

66+
67+
export interface ParamObject {
68+
[key: string]: number;
69+
}
6670
interface ModelObj {
6771
model: Model;
68-
params: Float32Array
72+
params: Float32Array | ParamObject
6973
}
7074
export interface RunnerConfig {
7175
modelPath?: string;

packages/paddlejs-core/src/loader.ts

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
*/
44

55
import env from './env';
6-
import { Model } from './commons/interface';
6+
import { Model, ParamObject } from './commons/interface';
77
import { traverseVars } from './commons/utils';
88

99
interface UrlConf {
@@ -126,14 +126,15 @@ export default class ModelLoader {
126126
});
127127
}
128128

129-
static allocateParamsVar(vars, allChunksData: Float32Array) {
129+
static allocateParamsVar(vars, allChunksData: Float32Array | ParamObject) {
130130
let marker = 0; // 读到哪个位置了
131131
let len; // 当前op长度
132+
const chunkData: number[] = Array.isArray(allChunksData) ? allChunksData : Object.values(allChunksData);
132133
traverseVars(vars, item => {
133134
len = item.shape.reduce((a, b) => a * b); // 长度为shape的乘积
134135
// 为了减少模型体积,模型转换工具不会导出非persistable的数据,这里只需要读取persistable的数据
135136
if (item.persistable) {
136-
item.data = allChunksData.slice(marker, marker + len);
137+
item.data = chunkData.slice(marker, marker + len);
137138
marker += len;
138139
}
139140
});

0 commit comments

Comments
 (0)