@@ -139,7 +139,6 @@ def get_insertion_sort_ptr(dtype: str) -> int:
139139 return _materialize (dtype , "insertion" )
140140
141141def _build_insertion_sort_ir (dtype : str ) -> str :
142- """Generate LLVM IR for insertion sort for the given dtype."""
143142 if dtype not in _SUPPORTED :
144143 raise ValueError (f"Unsupported dtype '{ dtype } '. Supported: { list (_SUPPORTED )} " )
145144
@@ -156,10 +155,13 @@ def _build_insertion_sort_ir(dtype: str) -> str:
156155 arr , n = fn .args
157156 arr .name , n .name = "arr" , "n"
158157
158+ # Basic blocks
159159 b_entry = fn .append_basic_block ("entry" )
160160 b_outer = fn .append_basic_block ("outer" )
161- b_inner = fn .append_basic_block ("inner" )
162- b_latch = fn .append_basic_block ("latch" )
161+ b_inner_cond = fn .append_basic_block ("inner.cond" )
162+ b_inner_body = fn .append_basic_block ("inner.body" )
163+ b_inner_latch = fn .append_basic_block ("inner.latch" )
164+ b_outer_latch = fn .append_basic_block ("outer.latch" )
163165 b_exit = fn .append_basic_block ("exit" )
164166
165167 b = ir .IRBuilder (b_entry )
@@ -169,69 +171,46 @@ def _build_insertion_sort_ir(dtype: str) -> str:
169171
170172 b .position_at_end (b_outer )
171173 i_phi = b .phi (i32 , name = "i" )
172- i_phi .add_incoming (ir .Constant (i32 , 1 ), b_entry )
174+ i_phi .add_incoming (ir .Constant (i32 , 1 ), b_entry ) # start from 1 for insertion sort
173175
174176 cond_outer = b .icmp_signed ("<" , i_phi , n )
175- b .cbranch (cond_outer , b_inner , b_exit )
176-
177- b .position_at_end (b_inner )
178- i64_cast = b .sext (i_phi , i64 )
177+ b .cbranch (cond_outer , b_inner_cond , b_exit )
179178
180- # key = arr[i]
181- ptr_i = b .gep (arr , [i64_cast ])
182- key = b .load (ptr_i , name = "key " )
179+ b . position_at_end ( b_inner_cond )
180+ key_ptr = b .gep (arr , [i_phi ])
181+ key_val = b .load (key_ptr , name = "key_val " )
183182
184- # j = i - 1
185- j = b .sub (i_phi , ir .Constant (i32 , 1 ), name = "j" )
186-
187- # while j >= 0 and arr[j] > key
188- loop_cond = fn .append_basic_block ("loop_cond" )
189- loop_body = fn .append_basic_block ("loop_body" )
190- loop_exit = fn .append_basic_block ("loop_exit" )
183+ j_phi = b .phi (i32 , name = "j" )
184+ j_phi .add_incoming (i_phi , b_outer ) # j = i
191185
192- b . branch ( loop_cond )
193- b .position_at_end ( loop_cond )
186+ inner_cond = b . icmp_signed ( ">" , j_phi , ir . Constant ( i32 , 0 ) )
187+ b .cbranch ( inner_cond , b_inner_body , b_outer_latch )
194188
195- j64 = b .sext (j , i64 )
196- ptr_j = b .gep (arr , [j64 ])
197- val_j = b .load (ptr_j , name = "val_j" )
189+ b .position_at_end (b_inner_body )
190+ j_minus_1 = b .sub (j_phi , ir .Constant (i32 , 1 ))
191+ ptr_j_minus_1 = b .gep (arr , [j_minus_1 ])
192+ val_j_minus_1 = b .load (ptr_j_minus_1 )
198193
199- if isinstance (T , ir .IntType ):
200- cmp1 = b .icmp_signed (">=" , j , ir .Constant (i32 , 0 ))
201- cmp2 = b .icmp_signed (">" , val_j , key )
202- else :
203- cmp1 = b .icmp_signed (">=" , j , ir .Constant (i32 , 0 ))
204- cmp2 = b .fcmp_ordered (">" , val_j , key )
194+ cmp = b .icmp_signed (">" , val_j_minus_1 , key_val ) if isinstance (T , ir .IntType ) else b .fcmp_ordered (">" , val_j_minus_1 , key_val )
195+ b .cbranch (cmp , b_inner_latch , b_outer_latch )
205196
206- cond = b .and_ (cmp1 , cmp2 )
207- b .cbranch (cond , loop_body , loop_exit )
208-
209- # loop body
210- b .position_at_end (loop_body )
211- jp1 = b .add (j , ir .Constant (i32 , 1 ))
212- jp1_64 = b .sext (jp1 , i64 )
213- ptr_jp1 = b .gep (arr , [jp1_64 ])
214- b .store (val_j , ptr_jp1 )
215-
216- j_next = b .sub (j , ir .Constant (i32 , 1 ))
217- j = j_next
218- b .branch (loop_cond )
197+ b .position_at_end (b_inner_latch )
198+ ptr_j = b .gep (arr , [j_phi ])
199+ b .store (val_j_minus_1 , ptr_j )
219200
220- # after loop: arr[j + 1] = key
221- b .position_at_end (loop_exit )
222- jp1_final = b .add (j , ir .Constant (i32 , 1 ))
223- jp1_final_64 = b .sext (jp1_final , i64 )
224- ptr_jp1_final = b .gep (arr , [jp1_final_64 ])
225- b .store (key , ptr_jp1_final )
201+ j_next = b .sub (j_phi , ir .Constant (i32 , 1 ))
202+ j_phi .add_incoming (j_next , b_inner_latch )
203+ b .branch (b_inner_cond )
226204
227- b .branch (b_latch )
205+ b .position_at_end (b_outer_latch )
206+ ptr_j = b .gep (arr , [j_phi ])
207+ b .store (key_val , ptr_j )
228208
229- # outer latch
230- b .position_at_end (b_latch )
231209 i_next = b .add (i_phi , ir .Constant (i32 , 1 ))
232- i_phi .add_incoming (i_next , b_latch )
210+ i_phi .add_incoming (i_next , b_outer_latch )
233211 b .branch (b_outer )
234212
213+ # Exit
235214 b .position_at_end (b_exit )
236215 b .ret_void ()
237216
0 commit comments