Skip to content

Commit 6ad8f80

Browse files
committed
fix: second part of the fix
Adds a test, fixes some comments and typings. This test (finally) test the optimizer itself!
1 parent b85cf65 commit 6ad8f80

File tree

7 files changed

+96
-205
lines changed

7 files changed

+96
-205
lines changed

src/arch/z80/optimizer/__init__.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,6 @@
1414

1515

1616
def init():
17-
global LABELS
18-
global JUMP_LABELS
19-
2017
LABELS.clear()
2118
JUMP_LABELS.clear()
2219

@@ -172,7 +169,7 @@ def initialize_memory(basic_block):
172169
get_labels(basic_block)
173170

174171

175-
def optimize(initial_memory):
172+
def optimize(initial_memory: list[str]) -> str:
176173
"""This will remove useless instructions"""
177174
global BLOCKS
178175
global PROC_COUNTER

src/arch/z80/optimizer/basicblock.py

Lines changed: 34 additions & 193 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
from src.arch.z80.optimizer import helpers
1111
from src.arch.z80.optimizer.common import JUMP_LABELS, LABELS
1212
from src.arch.z80.optimizer.cpustate import CPUState
13-
from src.arch.z80.optimizer.errors import OptimizerError
1413
from src.arch.z80.optimizer.helpers import ALL_REGS
1514
from src.arch.z80.optimizer.labelinfo import LabelInfo
1615
from src.arch.z80.optimizer.memcell import MemCell
@@ -237,108 +236,15 @@ def update_next_block(self):
237236
n_block = LABELS[last.opers[0]].basic_block
238237
self.add_goes_to(n_block)
239238

240-
def update_used_by_list(self):
241-
"""Every label has a set containing
242-
which blocks jumps (jp, jr, call) if any.
243-
A block can "use" (call/jump) only another block
244-
and only one"""
245-
246-
# Searches all labels and remove this block out
247-
# of their used_by set, since this might have changed
248-
for label in LABELS.values():
249-
label.used_by.remove(self) # Delete this bblock
250-
251-
def update_goes_and_comes(self):
252-
"""Once the block is a Basic one, check the last instruction and updates
253-
goes_to and comes_from set of the receivers.
254-
Note: jp, jr and ret are already done in update_next_block()
255-
"""
256-
if not len(self):
257-
return
258-
259-
last = self.mem[-1]
260-
inst = last.inst
261-
oper = last.opers
262-
cond = last.condition_flag
263-
264-
for blk in list(self.goes_to):
265-
self.delete_goes_to(blk)
266-
267-
if self.next:
268-
self.add_goes_to(self.next)
269-
270-
if not last.is_ender:
271-
return
272-
273-
if cond is None:
274-
self.delete_goes_to(self.next)
275-
276-
if last.inst in {"ret", "reti", "retn"} and cond is None:
277-
return # subroutine returns are updated from CALLer blocks
278-
279-
if oper and oper[0]:
280-
if oper[0] not in LABELS:
281-
__DEBUG__("INFO: %s is not defined. No optimization is done." % oper[0], 1)
282-
LABELS[oper[0]] = LabelInfo(oper[0], 0, DummyBasicBlock(ALL_REGS, ALL_REGS))
283-
284-
LABELS[oper[0]].used_by.add(self)
285-
self.add_goes_to(LABELS[oper[0]].basic_block)
286-
287-
if inst in {"djnz", "jp", "jr"}:
288-
return
289-
290-
assert inst in ("call", "rst")
291-
292-
if self.next is None:
293-
raise OptimizerError("Unexpected NULL next block")
294-
295-
final_blk = self.next # The block all the final returns should go to
296-
stack = [LABELS[oper[0]].basic_block]
297-
bbset: set[BasicBlock] = set()
298-
299-
while stack:
300-
bb = stack.pop(0)
301-
while True:
302-
if bb is None:
303-
bb = DummyBasicBlock(ALL_REGS, ALL_REGS)
304-
305-
if bb in bbset:
306-
break
307-
308-
bbset.add(bb)
309-
310-
if isinstance(bb, DummyBasicBlock):
311-
bb.add_goes_to(final_blk)
312-
break
313-
314-
if bb:
315-
bb1 = bb[-1]
316-
if bb1.inst in {"ret", "reti", "retn"}:
317-
bb.add_goes_to(final_blk)
318-
if bb1.condition_flag is None: # 'ret'
319-
break
320-
elif bb1.inst in ("jp", "jr") and bb1.condition_flag is not None: # jp/jr nc/nz/.. LABEL
321-
if bb1.opers[0] in LABELS: # some labels does not exist (e.g. immediate numeric addresses)
322-
stack.append(LABELS[bb1.opers[0]].basic_block)
323-
else:
324-
raise OptimizerError("Unknown block label '{}'".format(bb1.opers[0]))
325-
326-
bb = bb.next # next contiguous block
327-
328-
def is_used(self, regs, i, top=None):
239+
def is_used(self, regs: list[str], i: int, top: int | None = None) -> bool:
329240
"""Checks whether any of the given regs are required from the given point
330241
to the end or not.
331242
"""
332-
if i < 0:
333-
i = 0
334-
335243
if self.lock:
336244
return True
337245

338-
if top is None:
339-
top = len(self)
340-
else:
341-
top -= 1
246+
i = max(i, 0)
247+
top = len(self) if top is None else top + 1
342248

343249
if regs and regs[0][0] == "(" and regs[0][-1] == ")": # A memory address
344250
r16 = helpers.single_registers(regs[0][1:-1]) if helpers.is_16bit_oper_register(regs[0][1:-1]) else []
@@ -384,33 +290,17 @@ def is_used(self, regs, i, top=None):
384290

385291
return result
386292

387-
def safe_to_write(self, regs, i=0, end_=0):
388-
"""Given a list of registers (8 or 16 bits) returns a list of them
389-
that are safe to modify from the given index until the position given
390-
which, if omitted, defaults to the end of the block.
391-
:param regs: register or iterable of registers (8 or 16 bit one)
392-
:param i: initial position of the block to examine
393-
:param end_: final position to examine
394-
:returns: registers safe to write
395-
"""
396-
if helpers.is_register(regs):
397-
regs = set(helpers.single_registers(regs))
398-
else:
399-
regs = set(helpers.single_registers(x) for x in regs)
400-
return not regs.intersection(self.requires(i, end_))
401-
402-
def requires(self, i=0, end_=None):
293+
def requires(self, i: int = 0, end_: int | None = None) -> set[str]:
403294
"""Returns a list of registers and variables this block requires.
404295
By default checks from the beginning (i = 0).
405296
:param i: initial position of the block to examine
406297
:param end_: final position to examine
407298
:returns: registers safe to write
408299
"""
409-
if i < 0:
410-
i = 0
300+
i = max(i, 0)
411301
end_ = len(self) if end_ is None or end_ > len(self) else end_
412302
regs = {"a", "b", "c", "d", "e", "f", "h", "l", "i", "ixh", "ixl", "iyh", "iyl", "sp"}
413-
result = set()
303+
result: set[str] = set()
414304

415305
for ii in range(i, end_):
416306
for r in self.mem[ii].requires:
@@ -429,13 +319,13 @@ def requires(self, i=0, end_=None):
429319

430320
return result
431321

432-
def destroys(self, i=0):
322+
def destroys(self, i: int = 0) -> list[str]:
433323
"""Returns a list of registers this block destroys
434324
By default checks from the beginning (i = 0).
435325
"""
436326
regs = {"a", "b", "c", "d", "e", "f", "h", "l", "i", "ixh", "ixl", "iyh", "iyl", "sp"}
437327
top = len(self)
438-
result = []
328+
result: list[str] = []
439329

440330
for ii in range(i, top):
441331
for r in self.mem[ii].destroys:
@@ -448,10 +338,6 @@ def destroys(self, i=0):
448338

449339
return result
450340

451-
def swap(self, a: int, b: int) -> None:
452-
"""Swaps mem positions a and b"""
453-
self.mem[a], self.mem[b] = self.mem[b], self.mem[a]
454-
455341
def goes_requires(self, regs):
456342
"""Returns whether any of the goes_to block requires any of
457343
the given registers.
@@ -462,16 +348,6 @@ def goes_requires(self, regs):
462348

463349
return False
464350

465-
def get_label_idx(self, label):
466-
"""Returns the index of a label.
467-
Returns None if not found.
468-
"""
469-
for i in range(len(self)):
470-
if self.mem[i].is_label and self.mem[i].inst == label:
471-
return i
472-
473-
return None
474-
475351
def get_first_non_label_instruction(self):
476352
"""Returns the memcell of the given block, which is
477353
not a LABEL.
@@ -563,7 +439,7 @@ def optimize(self, patterns_list):
563439
if not p.cond.eval(match):
564440
continue
565441

566-
# all patterns applied successfully. Apply this pattern
442+
# all patterns matched successfully. Apply this rule
567443
new_code = list(code)
568444
matched = new_code[i : i + len(p.patt)]
569445
new_code[i : i + len(p.patt)] = p.template.filter(match)
@@ -590,60 +466,22 @@ class DummyBasicBlock(BasicBlock):
590466
about what registers uses an destroys
591467
"""
592468

593-
def __init__(self, destroys, requires):
469+
def __init__(self, destroys: Iterable[str], requires: Iterable[str]):
594470
BasicBlock.__init__(self, [])
595-
self.__destroys = [x for x in destroys]
596-
self.__requires = [x for x in requires]
471+
self.__destroys = tuple(destroys)
472+
self.__requires = set(requires)
473+
self.code = ["ret"]
597474

598-
def destroys(self, i: int = 0):
599-
return [x for x in self.__destroys]
475+
def destroys(self, i: int = 0) -> list[str]:
476+
return list(self.__destroys)
600477

601-
def requires(self, i: int = 0, end_=None):
602-
return [x for x in self.__requires]
478+
def requires(self, i: int = 0, end_=None) -> set[str]:
479+
return set(self.__requires)
603480

604-
def is_used(self, regs, i, top=None):
481+
def is_used(self, regs: Iterable[str], i: int, top: int | None = None) -> bool:
605482
return len([x for x in regs if x in self.__requires]) > 0
606483

607484

608-
def block_partition(block, i):
609-
"""Returns two blocks, as a result of partitioning the given one at
610-
i-th instruction.
611-
"""
612-
i += 1
613-
new_block = BasicBlock([])
614-
new_block.mem = block.mem[i:]
615-
block.mem = block.mem[:i]
616-
617-
for label, lbl_info in LABELS.items():
618-
if lbl_info.basic_block != block or lbl_info.position < len(block):
619-
continue
620-
621-
lbl_info.basic_block = new_block
622-
lbl_info.position -= len(block)
623-
624-
for b_ in list(block.goes_to):
625-
block.delete_goes_to(b_)
626-
new_block.add_goes_to(b_)
627-
628-
new_block.label_goes = block.label_goes
629-
block.label_goes = []
630-
631-
new_block.next = block.next
632-
new_block.prev = block
633-
block.next = new_block
634-
new_block.add_comes_from(block)
635-
636-
if new_block.next is not None:
637-
new_block.next.prev = new_block
638-
if block in new_block.next.comes_from:
639-
new_block.next.delete_comes_from(block)
640-
new_block.next.add_comes_from(new_block)
641-
642-
block.update_next_block()
643-
644-
return block, new_block
645-
646-
647485
def split_block(block: BasicBlock, start_of_new_block: int) -> tuple[BasicBlock, BasicBlock]:
648486
assert 0 <= start_of_new_block < len(block), f"Invalid split pos: {start_of_new_block}"
649487
new_block = BasicBlock([])
@@ -680,13 +518,11 @@ def compute_calls(basic_blocks: list[BasicBlock], jump_labels: set[str]) -> None
680518

681519
# Compute which blocks use jump labels
682520
for bb in basic_blocks:
683-
if bb[-1].is_ender:
684-
for op in bb[-1].opers:
685-
if op in LABELS:
686-
LABELS[op].used_by.add(bb)
521+
if bb[-1].is_ender and (op := bb[-1].branch_arg) in LABELS:
522+
LABELS[op].used_by.add(bb)
687523

688524
# For these blocks, add the referenced block in the goes_to
689-
for label in JUMP_LABELS:
525+
for label in jump_labels:
690526
for bb in LABELS[label].used_by:
691527
bb.add_goes_to(LABELS[label].basic_block)
692528

@@ -695,11 +531,10 @@ def compute_calls(basic_blocks: list[BasicBlock], jump_labels: set[str]) -> None
695531
if bb[-1].inst != "call":
696532
continue
697533

698-
for op in bb[-1].opers:
699-
if op in LABELS:
700-
LABELS[op].basic_block.called_by.add(bb)
701-
calling_blocks[bb] = LABELS[op].basic_block
702-
break
534+
op = bb[-1].branch_arg
535+
if op in LABELS:
536+
LABELS[op].basic_block.called_by.add(bb)
537+
calling_blocks[bb] = LABELS[op].basic_block
703538

704539
# For the annotated blocks, trace their goes_to, and their goes_to from
705540
# their goes_to and so on, until ret (unconditional or not) is found, and
@@ -751,9 +586,15 @@ def get_jump_labels(main_basic_block: BasicBlock) -> set[str]:
751586
if not mem.is_ender:
752587
continue
753588

754-
for op in mem.opers:
755-
if op in LABELS:
756-
jump_labels.add(op)
589+
lbl = mem.branch_arg
590+
if lbl is None:
591+
continue
592+
593+
jump_labels.add(lbl)
594+
595+
if lbl not in LABELS:
596+
__DEBUG__(f"INFO: {lbl} is not defined. No optimization is done.", 2)
597+
LABELS[lbl] = LabelInfo(lbl, 0, DummyBasicBlock(ALL_REGS, ALL_REGS))
757598

758599
return jump_labels
759600

src/arch/z80/optimizer/common.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from typing import TYPE_CHECKING
55

66
if TYPE_CHECKING:
7-
from labelinfo import LabelInfo
7+
from .labelinfo import LabelInfo
88

99
# counter for generating unique random fake values
1010
RAND_COUNT = 0

0 commit comments

Comments
 (0)