Skip to content

Commit 0de6836

Browse files
fix the implementation
1 parent 04bbdc0 commit 0de6836

File tree

1 file changed

+31
-52
lines changed

1 file changed

+31
-52
lines changed

pydatastructs/linear_data_structures/_backend/cpp/algorithms/llvm_algorithms.py

Lines changed: 31 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,6 @@ def get_insertion_sort_ptr(dtype: str) -> int:
139139
return _materialize(dtype, "insertion")
140140

141141
def _build_insertion_sort_ir(dtype: str) -> str:
142-
"""Generate LLVM IR for insertion sort for the given dtype."""
143142
if dtype not in _SUPPORTED:
144143
raise ValueError(f"Unsupported dtype '{dtype}'. Supported: {list(_SUPPORTED)}")
145144

@@ -156,10 +155,13 @@ def _build_insertion_sort_ir(dtype: str) -> str:
156155
arr, n = fn.args
157156
arr.name, n.name = "arr", "n"
158157

158+
# Basic blocks
159159
b_entry = fn.append_basic_block("entry")
160160
b_outer = fn.append_basic_block("outer")
161-
b_inner = fn.append_basic_block("inner")
162-
b_latch = fn.append_basic_block("latch")
161+
b_inner_cond = fn.append_basic_block("inner.cond")
162+
b_inner_body = fn.append_basic_block("inner.body")
163+
b_inner_latch = fn.append_basic_block("inner.latch")
164+
b_outer_latch = fn.append_basic_block("outer.latch")
163165
b_exit = fn.append_basic_block("exit")
164166

165167
b = ir.IRBuilder(b_entry)
@@ -169,69 +171,46 @@ def _build_insertion_sort_ir(dtype: str) -> str:
169171

170172
b.position_at_end(b_outer)
171173
i_phi = b.phi(i32, name="i")
172-
i_phi.add_incoming(ir.Constant(i32, 1), b_entry)
174+
i_phi.add_incoming(ir.Constant(i32, 1), b_entry) # start from 1 for insertion sort
173175

174176
cond_outer = b.icmp_signed("<", i_phi, n)
175-
b.cbranch(cond_outer, b_inner, b_exit)
176-
177-
b.position_at_end(b_inner)
178-
i64_cast = b.sext(i_phi, i64)
177+
b.cbranch(cond_outer, b_inner_cond, b_exit)
179178

180-
# key = arr[i]
181-
ptr_i = b.gep(arr, [i64_cast])
182-
key = b.load(ptr_i, name="key")
179+
b.position_at_end(b_inner_cond)
180+
key_ptr = b.gep(arr, [i_phi])
181+
key_val = b.load(key_ptr, name="key_val")
183182

184-
# j = i - 1
185-
j = b.sub(i_phi, ir.Constant(i32, 1), name="j")
186-
187-
# while j >= 0 and arr[j] > key
188-
loop_cond = fn.append_basic_block("loop_cond")
189-
loop_body = fn.append_basic_block("loop_body")
190-
loop_exit = fn.append_basic_block("loop_exit")
183+
j_phi = b.phi(i32, name="j")
184+
j_phi.add_incoming(i_phi, b_outer) # j = i
191185

192-
b.branch(loop_cond)
193-
b.position_at_end(loop_cond)
186+
inner_cond = b.icmp_signed(">", j_phi, ir.Constant(i32, 0))
187+
b.cbranch(inner_cond, b_inner_body, b_outer_latch)
194188

195-
j64 = b.sext(j, i64)
196-
ptr_j = b.gep(arr, [j64])
197-
val_j = b.load(ptr_j, name="val_j")
189+
b.position_at_end(b_inner_body)
190+
j_minus_1 = b.sub(j_phi, ir.Constant(i32, 1))
191+
ptr_j_minus_1 = b.gep(arr, [j_minus_1])
192+
val_j_minus_1 = b.load(ptr_j_minus_1)
198193

199-
if isinstance(T, ir.IntType):
200-
cmp1 = b.icmp_signed(">=", j, ir.Constant(i32, 0))
201-
cmp2 = b.icmp_signed(">", val_j, key)
202-
else:
203-
cmp1 = b.icmp_signed(">=", j, ir.Constant(i32, 0))
204-
cmp2 = b.fcmp_ordered(">", val_j, key)
194+
cmp = b.icmp_signed(">", val_j_minus_1, key_val) if isinstance(T, ir.IntType) else b.fcmp_ordered(">", val_j_minus_1, key_val)
195+
b.cbranch(cmp, b_inner_latch, b_outer_latch)
205196

206-
cond = b.and_(cmp1, cmp2)
207-
b.cbranch(cond, loop_body, loop_exit)
208-
209-
# loop body
210-
b.position_at_end(loop_body)
211-
jp1 = b.add(j, ir.Constant(i32, 1))
212-
jp1_64 = b.sext(jp1, i64)
213-
ptr_jp1 = b.gep(arr, [jp1_64])
214-
b.store(val_j, ptr_jp1)
215-
216-
j_next = b.sub(j, ir.Constant(i32, 1))
217-
j = j_next
218-
b.branch(loop_cond)
197+
b.position_at_end(b_inner_latch)
198+
ptr_j = b.gep(arr, [j_phi])
199+
b.store(val_j_minus_1, ptr_j)
219200

220-
# after loop: arr[j + 1] = key
221-
b.position_at_end(loop_exit)
222-
jp1_final = b.add(j, ir.Constant(i32, 1))
223-
jp1_final_64 = b.sext(jp1_final, i64)
224-
ptr_jp1_final = b.gep(arr, [jp1_final_64])
225-
b.store(key, ptr_jp1_final)
201+
j_next = b.sub(j_phi, ir.Constant(i32, 1))
202+
j_phi.add_incoming(j_next, b_inner_latch)
203+
b.branch(b_inner_cond)
226204

227-
b.branch(b_latch)
205+
b.position_at_end(b_outer_latch)
206+
ptr_j = b.gep(arr, [j_phi])
207+
b.store(key_val, ptr_j)
228208

229-
# outer latch
230-
b.position_at_end(b_latch)
231209
i_next = b.add(i_phi, ir.Constant(i32, 1))
232-
i_phi.add_incoming(i_next, b_latch)
210+
i_phi.add_incoming(i_next, b_outer_latch)
233211
b.branch(b_outer)
234212

213+
# Exit
235214
b.position_at_end(b_exit)
236215
b.ret_void()
237216

0 commit comments

Comments
 (0)