Skip to content

Commit 2df6248

Browse files
committed
fix(webgl): fix slice stack op run in webgl1.0 GLSL ES 1.0
1 parent eeefdf4 commit 2df6248

File tree

4 files changed

+71
-12
lines changed

4 files changed

+71
-12
lines changed
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
/**
2+
* @file common utils
3+
*/
4+
5+
import { env } from '@paddlejs/paddlejs-core';
6+
7+
8+
enum ArrTypeEnum {
9+
INT_TYPE = 'int',
10+
FLOAT_TYPE = 'float'
11+
}
12+
13+
// GLSL ES 3.00 支持 arr => int arr = int[](x, x, x,... x);
14+
// GLSL ES 1.0 (1.0) 不支持 array constructor
15+
// '[]' : array constructor supported in GLSL ES 3.00 and above only
16+
const initializeGLSLArr = (arr: Array<Number>, type: ArrTypeEnum) => {
17+
if (env.get('webglVersion') !== 2) {
18+
return arr.reduce((acc, cur, index) => {
19+
const tmp = index < arr.length - 1 ? `${cur}, ` : `${cur});`;
20+
return acc + tmp;
21+
}, `${type} arr[] = ${type}[](`);
22+
}
23+
24+
const arr_value = arr.reduce((acc, cur, index) => {
25+
return acc + `
26+
arr[${index}] = ${cur};`;
27+
}, '');
28+
29+
return `
30+
${type} arr[${arr.length}];
31+
${arr_value}
32+
`;
33+
};
34+
35+
export {
36+
initializeGLSLArr,
37+
ArrTypeEnum
38+
};

packages/paddlejs-backend-webgl/src/ops/shader/slice.ts

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,10 @@
44
* @example x = [[1,2,3,4],[5,6,7,8]] axes=[1] starts=[2] ends=[3] => out [3,7]
55
*/
66

7+
import { env } from '@paddlejs/paddlejs-core';
8+
import { initializeGLSLArr, ArrTypeEnum } from '../atom/common_utils';
9+
10+
711
function mainFunc(
812
{
913
out, origin
@@ -70,23 +74,39 @@ function mainFunc(
7074
}
7175
}
7276
}
73-
// 生成 glsl arr => int arr = int[](x, x, x,... x);
74-
const glslIndexArr = res_pos.reduce((acc, cur, index) => {
75-
const tmp = index < res_pos.length - 1 ? `${cur}, ` : `${cur});`;
76-
return acc + tmp;
77-
}, 'int arr[] = int[](');
7877

78+
const glslIndexArr = initializeGLSLArr(res_pos, ArrTypeEnum.INT_TYPE);
79+
80+
const ifConditions = res_pos.reduce((acc, _, idx) => {
81+
const ifCondition = idx === 0
82+
? `
83+
int index = 0;
84+
if (sumVal == ${idx}) {
85+
index = arr[${idx}];
86+
}`
87+
: `
88+
else if (sumVal == ${idx}) {
89+
index = arr[${idx}];
90+
}
91+
`;
92+
return acc + ifCondition;
93+
}, '');
94+
95+
const getValueFromArrIndex = env.get('webglVersion') === 2
96+
? 'int index = arr[sumVal];'
97+
: ifConditions;
7998
return `
8099
void main(void) {
81100
ivec4 oPos = getOutputTensorPos();
82101
${glslIndexArr}
102+
83103
// 输出坐标转换为输入坐标
84104
int sumVal = oPos.a
85105
+ oPos.b * ${out.width_shape}
86106
+ oPos.g * ${out.height_shape} * ${out.width_shape}
87107
+ oPos.r * ${out.channel} * ${out.width_shape} * ${out.height_shape};
88108
89-
int index = arr[sumVal];
109+
${getValueFromArrIndex}
90110
91111
float res = 0.0;
92112
ivec4 co = getTensorPosFromArrayIndex_origin(index);

packages/paddlejs-backend-webgl/src/ops/shader/stack.ts

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -66,14 +66,13 @@ function mainFunc(
6666
+ oPos.g * ${out.height_shape} * ${out.width_shape}
6767
+ oPos.r * ${out.channel} * ${out.width_shape} * ${out.height_shape};
6868
69-
int index = sumVal % ${pre_every_num};
69+
int index = calMod(sumVal, ${pre_every_num});
7070
7171
int layer = sumVal / ${pre_every_num};
7272
7373
int i = index / ${post};
74-
int j = index % ${post} + layer * ${post};
74+
int j = calMod(index, ${post}) + layer * ${post};
7575
76-
7776
float o = 0.0;
7877
${getMultiInputsValue}
7978
setOutput(float(o));

packages/paddlejs-backend-webgl/test/op/data/stack.json

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -86,9 +86,11 @@
8686
"name": "stack_out",
8787
"shape": [3, 3, 3],
8888
"data": [
89-
1, 2, 3, 4, 5, 6, 7, 8, 9,
90-
11, 12, 13, 14, 15, 16, 17, 18, 19,
91-
21, 22, 23, 24, 25, 26, 27, 28, 29
89+
1, 11, 21, 2, 12, 22,
90+
3, 13, 23, 4, 14, 24,
91+
5, 15, 25, 6, 16, 26,
92+
7, 17, 27, 8, 18, 28,
93+
9, 19, 29
9294
]
9395
}
9496
]

0 commit comments

Comments
 (0)