Skip to content

Commit 9b261b7

Browse files
authored
Merge pull request #414 from JingyuanZhang/master
[webgl op]: add slice stack exp op
2 parents c757296 + 2df6248 commit 9b261b7

File tree

13 files changed

+451
-8
lines changed

13 files changed

+451
-8
lines changed

packages/paddlejs-backend-webgl/src/ops/atom/common_func.ts

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,9 +69,16 @@ float tanh_func(float x, float y, float z) {
6969
return tanh_calc(x);
7070
}`;
7171

72+
/**
 * GLSL source of the `exp` activation helper.
 *
 * The generated function takes three floats but only reads `x`; `y` and `z`
 * are accepted and ignored.
 *
 * NOTE(review): the GLSL function is itself called `exp`, i.e. it overloads
 * the built-in it wraps, while the sibling helper in this file is named
 * `tanh_func`. GLSL ES 3.00 does not allow overloading built-in functions, so
 * this presumably needs a `_func`-style name under WebGL 2 — confirm, and
 * rename together with whatever code selects the helper by name.
 */
const exp = `
float exp(float x, float y, float z) {
    float result = exp(x);
    return result;
}`;
77+
7278
export {
7379
prelu,
7480
relu6,
81+
exp,
7582
leakyRelu,
7683
scale,
7784
sigmoid,
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
/**
2+
* @file common utils
3+
*/
4+
5+
import { env } from '@paddlejs/paddlejs-core';
6+
7+
8+
/**
 * Element types accepted by `initializeGLSLArr`; each value is the GLSL type
 * keyword that gets written into the generated shader source.
 */
enum ArrTypeEnum {
    INT_TYPE = 'int',
    FLOAT_TYPE = 'float'
}
12+
13+
/**
 * Build the GLSL statements that declare and fill a local array named `arr`.
 *
 * GLSL ES 3.00 (WebGL 2) accepts an array constructor, e.g.
 * `int arr[] = int[](x, x, ..., x);`. GLSL ES 1.00 (WebGL 1) rejects it with
 * "'[]' : array constructor supported in GLSL ES 3.00 and above only", so for
 * every other version the array is declared with an explicit size and filled
 * one element at a time.
 *
 * @param arr values to store, in order; must not be empty
 * @param type GLSL element type of the array
 * @returns GLSL source text that declares `arr`
 * @throws Error when `arr` is empty (a zero-length array cannot be declared)
 */
const initializeGLSLArr = (arr: ReadonlyArray<number>, type: ArrTypeEnum): string => {
    if (arr.length === 0) {
        throw Error('[initializeGLSLArr]: arr must contain at least one value');
    }

    // GLSL ES performs no implicit int -> float conversion, so whole numbers
    // stored in a float array need an explicit decimal point.
    const literals = arr.map(value => (
        type === ArrTypeEnum.FLOAT_TYPE && Number.isInteger(value)
            ? value.toFixed(1)
            : `${value}`
    ));

    // Only WebGL 2 may use the array constructor form.
    if (env.get('webglVersion') === 2) {
        return `${type} arr[] = ${type}[](${literals.join(', ')});`;
    }

    const assignments = literals
        .map((literal, index) => `arr[${index}] = ${literal};`)
        .join('\n    ');

    return `
    ${type} arr[${arr.length}];
    ${assignments}
    `;
};
34+
35+
// Public surface of the common GLSL utilities.
export {
    initializeGLSLArr,
    ArrTypeEnum
};

packages/paddlejs-backend-webgl/src/ops/index.ts

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,8 @@ import pool2d_avg from './shader/pool2d_avg';
5757
import density_prior_box from './shader/density_prior_box';
5858
import box_coder from './shader/box_coder';
5959
import prior_box from './shader/prior_box';
60+
import stack from './shader/stack';
61+
import slice from './shader/slice';
6062

6163
import {
6264
imgFeed, pack_out, nhwc_2_nchw, unpacked_2_packed,
@@ -125,6 +127,7 @@ const ops = {
125127
pow: dynamic('pow'),
126128
sqrt: dynamic('sqrt'),
127129
tanh: dynamic('tanh'),
130+
exp: dynamic('exp'),
128131
squeeze2,
129132
pad3d,
130133
bilinear_interp_v2,
@@ -135,7 +138,9 @@ const ops = {
135138
imgFeed,
136139
box_coder,
137140
density_prior_box,
138-
prior_box
141+
prior_box,
142+
stack,
143+
slice
139144
};
140145
export {
141146
ops

packages/paddlejs-backend-webgl/src/ops/shader/concat.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
/**
22
* @file concat dynamic inputs
3-
* @description concat inputs X supports no more than 4 tensors, eg. [a1, a2, a3, a4]
3+
* @description concat inputs X supports no more than 15 tensors, eg. [a1, a2, a3, a4, ... , a15]
44
*/
55

66
/* eslint-disable max-lines */

packages/paddlejs-backend-webgl/src/ops/shader/concat_mul.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
/**
22
* @file concat_mul dynamic inputs
3-
* @description concat inputs X supports no more than 4 tensors, eg. [a1, a2, a3, a4]
3+
* @description concat inputs X supports no more than 15 tensors, eg. [a1, a2, a3, a4, ... , a15]
44
*/
55

66
/* eslint-disable max-lines */

packages/paddlejs-backend-webgl/src/ops/shader/dynamic.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ const commonFuncBehaviors = {
1111
sigmoid: ['transToSigmoid'],
1212
hard_sigmoid: ['transToHardSigmoid'],
1313
pow: ['transToPow'],
14+
exp: ['transToExp'],
1415
sqrt: ['transToSqrt'],
1516
tanh: ['transToTanh']
1617
};
Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
/**
2+
* @file slice
3+
* @description slice op,目前 case 很单一,暂时仅支持遇到的 case:slice final dim
4+
* @example x = [[1,2,3,4],[5,6,7,8]] axes=[1] starts=[2] ends=[3] => out [3,7]
5+
*/
6+
7+
import { env } from '@paddlejs/paddlejs-core';
8+
import { initializeGLSLArr, ArrTypeEnum } from '../atom/common_utils';
9+
10+
11+
/**
 * Generate the fragment-shader `main` for the slice op.
 *
 * Only a slice along the last dimension of `origin` is supported. The flat
 * source index of every output element is computed here and baked into the
 * shader as a lookup array `arr`; the shader turns its output coordinate into
 * a flat output index and reads the matching source element.
 *
 * @param tensors `out` and `origin` tensor descriptors; reads `width_shape`,
 *     `height_shape` and `channel` from both, plus `total_shape` and
 *     `length_unformatted_shape` from `origin`
 * @param attrs slice attributes: `axes`, `starts`, `ends` (at most one entry
 *     each) and optional `decrease_axis`
 * @returns GLSL source of `main`
 * @throws Error when more than one axis is given, `decrease_axis` is an empty
 *     array, the axis is not the last dimension, or the range selects nothing
 */
function mainFunc(
    {
        out, origin
    },
    {
        axes,
        starts,
        ends,
        decrease_axis
    }
) {
    if (
        axes.length > 1
        || starts.length > 1
        || ends.length > 1
        || (decrease_axis && decrease_axis.length === 0)
    ) {
        throw Error('[slice op feature]: current support one dim, support decrease_axis');
    }

    const {
        width_shape,
        height_shape,
        channel,
        total_shape,
        length_unformatted_shape
    } = origin;
    const batch_num = total_shape / (width_shape * height_shape * channel);

    // Map the axis from the unformatted shape onto the padded 4-D layout
    // [batch, channel, height, width]; a negative axis counts from the end.
    let axis = axes[0];
    if (axis < 0) {
        axis += length_unformatted_shape;
    }
    axis = 4 - length_unformatted_shape + axis;

    // Index 3 is the width, i.e. the last dimension.
    if (axis !== 3) {
        throw Error('[slice op feature]: unsupport axis value');
    }

    // Negative bounds count from the end; bounds beyond the dimension are clamped.
    const clampToWidth = (pos: number) => Math.min(
        Math.max(pos < 0 ? pos + width_shape : pos, 0),
        width_shape
    );
    const start = clampToWidth(starts[0]);
    const end = clampToWidth(ends[0]);

    // Written as a negated comparison so that NaN (missing bounds) is rejected too.
    if (!(start < end)) {
        throw Error('[slice op feature]: empty slice range');
    }

    // Flat origin index for each output element, in output order: the sliced
    // dimension is the innermost one, so it must vary fastest.
    const res_pos: number[] = [];
    const row_num = batch_num * channel * height_shape;
    for (let row = 0; row < row_num; row++) {
        for (let index = start; index < end; index++) {
            res_pos.push(row * width_shape + index);
        }
    }

    const glslIndexArr = initializeGLSLArr(res_pos, ArrTypeEnum.INT_TYPE);

    // Outside WebGL 2 the lookup is unrolled into an if / else-if chain that
    // only uses constant array indices.
    const ifConditions = res_pos
        .map((_, idx) => (idx === 0
            ? `
    int index = 0;
    if (sumVal == ${idx}) {
        index = arr[${idx}];
    }`
            : `
    else if (sumVal == ${idx}) {
        index = arr[${idx}];
    }
    `))
        .join('');

    const getValueFromArrIndex = env.get('webglVersion') === 2
        ? 'int index = arr[sumVal];'
        : ifConditions;

    return `
    void main(void) {
        ivec4 oPos = getOutputTensorPos();
        ${glslIndexArr}

        // 输出坐标转换为输入坐标
        int sumVal = oPos.a
            + oPos.b * ${out.width_shape}
            + oPos.g * ${out.height_shape} * ${out.width_shape}
            + oPos.r * ${out.channel} * ${out.width_shape} * ${out.height_shape};

        ${getValueFromArrIndex}

        float res = 0.0;
        ivec4 co = getTensorPosFromArrayIndex_origin(index);
        res = getValueFromTensorPos_origin(co.r, co.g, co.b, co.a);
        setOutput(float(res));
    }
    `;
}
118+
// Op definition: the shader generator plus the texture helper names listed
// for the `origin` input (the generated shader calls their `_origin` variants).
export default {
    mainFunc,
    textureFuncConf: {
        origin: ['getValueFromTensorPos', 'getTensorPosFromArrayIndex']
    }
};
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
/**
2+
* @file stack dynamic inputs
3+
* @description stack inputs X supports no more than 15 tensors, eg. [a1, a2, a3, a4, ... , a15]
4+
* @detail https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/fluid/operators/stack_op.h#L56
5+
* @detail https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/stack_cn.html#stack
6+
*/
7+
8+
9+
/**
 * Generate the fragment-shader `main` for the stack op.
 *
 * Every key of the first argument other than `out` is counted as one input
 * tensor; the shader addresses them as `origin`, `origin_1`, `origin_2`, ...
 * The shape is taken from `origin` only.
 *
 * With `pre` = product of the input dims before `axis` and `post` = product of
 * the dims from `axis` on, a flat output index decomposes into
 * `layer * (inputs * post) + i * post + k`, where `i` selects the input tensor
 * and `layer * post + k` is the flat index inside that input.
 *
 * @param tensors `out` descriptor plus the input descriptors
 * @param attrs op attributes; reads `axis` (treated as 0 when absent, negative
 *     values count from the end of the output shape)
 * @returns GLSL source of `main`
 * @throws Error when `axis` lies outside `[0, rank of origin]` after normalisation
 */
function mainFunc(
    { out, ...inputs },
    attrs
) {
    const origin_tensor = inputs['origin'];
    const {
        width_shape,
        height_shape,
        channel,
        total_shape,
        length_unformatted_shape
    } = origin_tensor;
    const batch = total_shape / (width_shape * height_shape * channel);

    // Drop the leading padding dims to recover the unformatted input shape.
    const origin_shape = [batch, channel, height_shape, width_shape]
        .slice(4 - length_unformatted_shape);
    const rank = origin_shape.length;

    const inputs_num = Object.keys(inputs).length;

    // The output has one more dim than each input, hence the "+ 1".
    const raw_axis = attrs.axis == null ? 0 : attrs.axis;
    const axis = raw_axis < 0 ? raw_axis + rank + 1 : raw_axis;
    if (!(axis >= 0 && axis <= rank)) {
        throw Error('[stack op feature]: unsupport axis value');
    }

    const product = (dims: number[]) => dims.reduce((acc, dim) => acc * dim, 1);
    const pre = product(origin_shape.slice(0, axis));
    const post = product(origin_shape.slice(axis));

    // Number of output elements per "layer" (one step along the pre dims).
    const pre_every_num = out.total_shape / pre;

    // One branch per input tensor, selected by `i` in the shader.
    const getMultiInputsValue = Array.from({ length: inputs_num }, (_, cur) => (cur === 0
        ? `
        if (i == 0) {
            ivec4 co = getTensorPosFromArrayIndex_origin(j);
            o = getValueFromTensorPos_origin(co.r, co.g, co.b, co.a);
        }`
        : `
        else if (i == ${cur}) {
            ivec4 co = getTensorPosFromArrayIndex_origin_${cur}(j);
            o = getValueFromTensorPos_origin_${cur}(co.r, co.g, co.b, co.a);
        }`)).join('');

    return `
    // start函数
    void main(void) {
        ivec4 oPos = getOutputTensorPos();
        // 输出坐标转换为输入坐标
        int sumVal = oPos.a
            + oPos.b * ${out.width_shape}
            + oPos.g * ${out.height_shape} * ${out.width_shape}
            + oPos.r * ${out.channel} * ${out.width_shape} * ${out.height_shape};

        int index = calMod(sumVal, ${pre_every_num});

        int layer = sumVal / ${pre_every_num};

        int i = index / ${post};
        int j = calMod(index, ${post}) + layer * ${post};

        float o = 0.0;
        ${getMultiInputsValue}
        setOutput(float(o));
    }
    `;
}
82+
// Op definition: the shader generator plus the texture helper names listed
// under the '@all' key (the generated shader calls them for every input).
export default {
    mainFunc,
    textureFuncConf: {
        '@all': ['getValueFromTensorPos', 'getTensorPosFromArrayIndex']
    }
};

0 commit comments

Comments
 (0)