@@ -14,10 +14,12 @@ def __init__(self, backend: str, n_cpu: int, block_size: int) -> None:
1414 self .block_dim = block_size
1515 ti .init (arch = getattr (ti , backend .split ("-" )[- 1 ]))
1616 self .N = 0
17+ self .fb : ti .FieldsBuilder
18+ self .fbst : ti ._snode .snode_tree .SNodeTree
19+ self .terr = ti .field (ti .f32 , (3 ,))
1720 self .tA = ti .field (ti .i32 )
1821 self .tB = ti .field (ti .f32 )
1922 self .tX = ti .field (ti .f32 )
20- self .terr = ti .field (ti .f32 , (3 ,))
2123 self .tmp = ti .field (ti .f32 )
2224
2325 def partition (self , mask : np .ndarray ) -> np .ndarray :
@@ -29,10 +31,19 @@ def reset(self, N: int, A: np.ndarray, X: np.ndarray, B: np.ndarray) -> None:
2931 self .A = A
3032 self .B = B
3133 self .X = X
32- ti .root .dense (ti .ij , A .shape ).place (self .tA )
33- ti .root .dense (ti .ij , B .shape ).place (self .tB )
34- ti .root .dense (ti .ij , X .shape ).place (self .tX )
35- ti .root .dense (ti .ij , X .shape ).place (self .tmp )
34+ if hasattr (self , "fbst" ):
35+ self .fbst .destroy ()
36+ self .tA = ti .field (ti .i32 )
37+ self .tB = ti .field (ti .f32 )
38+ self .tX = ti .field (ti .f32 )
39+ self .tmp = ti .field (ti .f32 )
40+ self .fb = ti .FieldsBuilder ()
41+ layout = self .fb .dense (ti .i , N )
42+ layout .dense (ti .j , 4 ).place (self .tA )
43+ layout .dense (ti .j , 3 ).place (self .tB )
44+ layout .dense (ti .j , 3 ).place (self .tX )
45+ layout .dense (ti .j , 3 ).place (self .tmp )
46+ self .fbst = self .fb .finalize ()
3647 self .tA .from_numpy (A )
3748 self .tB .from_numpy (B )
3849 self .tX .from_numpy (X )
@@ -112,11 +123,13 @@ def __init__(
112123 self .parallelize = n_cpu
113124 self .block_dim = block_size
114125 ti .init (arch = getattr (ti , backend .split ("-" )[- 1 ]))
126+ self .fb : ti .FieldsBuilder
127+ self .fbst : ti ._snode .snode_tree .SNodeTree
128+ self .terr = ti .field (ti .f32 , (3 ,))
115129 self .tmask = ti .field (ti .i32 )
116130 self .ttgt = ti .field (ti .f32 )
117131 self .tgrad = ti .field (ti .f32 )
118132 self .tmp = ti .field (ti .f32 )
119- self .terr = ti .field (ti .f32 , (3 ,))
120133
121134 def reset (
122135 self , N : int , mask : np .ndarray , tgt : np .ndarray , grad : np .ndarray
@@ -132,15 +145,23 @@ def reset(
132145
133146 self .N , self .M = N , M = mask .shape
134147 bx , by = N // gx , M // gy
135- layout = ti .root .dense (ti .ij , (bx , by )).dense (ti .ij , (gx , gy ))
136148 self .mask = mask
137149 self .tgt = tgt
138150 self .grad = grad
139151
152+ if hasattr (self , "fbst" ):
153+ self .fbst .destroy ()
154+ self .tmask = ti .field (ti .i32 )
155+ self .ttgt = ti .field (ti .f32 )
156+ self .tgrad = ti .field (ti .f32 )
157+ self .tmp = ti .field (ti .f32 )
158+ self .fb = ti .FieldsBuilder ()
159+ layout = self .fb .dense (ti .ij , (bx , by )).dense (ti .ij , (gx , gy ))
140160 layout .place (self .tmask )
141161 layout .dense (ti .k , 3 ).place (self .ttgt )
142162 layout .dense (ti .k , 3 ).place (self .tgrad )
143163 layout .dense (ti .k , 3 ).place (self .tmp )
164+ self .fbst = self .fb .finalize ()
144165 self .tmask .from_numpy (mask )
145166 self .ttgt .from_numpy (tgt )
146167 self .tgrad .from_numpy (grad )
0 commit comments