Skip to content

Commit e682a4d

Browse files
authored
Merge pull request #181 from yueshuangyan/master
fix(converter): fix name attr visit of fetch_targets
2 parents 7658965 + 58305aa commit e682a4d

File tree

1 file changed

+15
-2
lines changed

1 file changed

+15
-2
lines changed

packages/paddlejs-converter/convertModel.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,18 @@
4242
# 在转换过程中新生成的、需要添加到vars中的variable
4343
appendedVarList = []
4444

45+
class ObjDict(dict):
46+
"""
47+
Makes a dictionary behave like an object,with attribute-style access.
48+
"""
49+
def __getattr__(self,name):
50+
try:
51+
return self[name]
52+
except:
53+
raise AttributeError(name)
54+
def __setattr__(self,name,value):
55+
self[name]=value
56+
4557
def validateShape(shape, name):
4658
"""检验shape长度,超过4则截断"""
4759
if len(shape) > 4:
@@ -323,7 +335,7 @@ def appendConnectOp(fetch_targets):
323335

324336
# 从fetch_targets中提取输出算子信息
325337
for target in fetch_targets:
326-
name = target['name']
338+
name = target.name
327339
curVar = fluid.global_scope().find_var(name)
328340
curTensor = np.array(curVar.get_tensor())
329341
shape = list(curTensor.shape)
@@ -409,7 +421,8 @@ def convertToPaddleJSModel():
409421
for input, value in op['inputs'].items():
410422
if len(value) <= 0:
411423
continue
412-
cur = {'name': value[0]}
424+
cur = ObjDict()
425+
cur.name = value[0]
413426
inputNames.append(cur)
414427
targets = appendConnectOp(inputNames)
415428
# op['inputs'] = targets

0 commit comments

Comments
 (0)