11import ast
22
33
4- class AsyncTransformer (ast . NodeTransformer ):
4+ class AsyncTransformer ():
55 """Converts all async nodes into their synchronous counterparts."""
66
77 def visit_Await (self , node ):
@@ -16,3 +16,53 @@ def visit_AsyncFor(self, node):
1616
1717 def visit_AsyncWith (self , node ):
1818 return self .visit (ast .With (** node .__dict__ ))
19+
20+
21+ class ChainedFunctionTransformer ():
22+ def visit_chain (self , node , depth = 1 ):
23+ if (
24+ isinstance (node .value , ast .Call ) and
25+ isinstance (node .value .func , ast .Attribute ) and
26+ isinstance (node .value .func .value , ast .Call )
27+ ):
28+ # Node is assignment or return with value like `b.c().d()`
29+ call_node = node .value
30+ # If we want to handle nested functions in future, depth needs fixing
31+ temp_var_id = '__chain_tmp_{}' .format (depth )
32+ # AST tree is from right to left, so d() is the outer Call and b.c() is the inner Call
33+ unvisited_inner_call = ast .Assign (
34+ targets = [ast .Name (id = temp_var_id , ctx = ast .Store ())],
35+ value = call_node .func .value ,
36+ )
37+ ast .copy_location (unvisited_inner_call , node )
38+ inner_calls = self .visit_chain (unvisited_inner_call , depth + 1 )
39+ for inner_call_node in inner_calls :
40+ ast .copy_location (inner_call_node , node )
41+ outer_call = self .generic_visit (type (node )(
42+ value = ast .Call (
43+ func = ast .Attribute (
44+ value = ast .Name (id = temp_var_id , ctx = ast .Load ()),
45+ attr = call_node .func .attr ,
46+ ctx = ast .Load (),
47+ ),
48+ args = call_node .args ,
49+ keywords = call_node .keywords ,
50+ ),
51+ ** {field : value for field , value in ast .iter_fields (node ) if field != 'value' } # e.g. targets
52+ ))
53+ ast .copy_location (outer_call , node )
54+ ast .copy_location (outer_call .value , node )
55+ ast .copy_location (outer_call .value .func , node )
56+ return [* inner_calls , outer_call ]
57+ else :
58+ return [self .generic_visit (node )]
59+
60+ def visit_Assign (self , node ):
61+ return self .visit_chain (node )
62+
63+ def visit_Return (self , node ):
64+ return self .visit_chain (node )
65+
66+
67+ class PytTransformer (AsyncTransformer , ChainedFunctionTransformer , ast .NodeTransformer ):
68+ pass
0 commit comments