1010from src .arch .z80 .optimizer import helpers
1111from src .arch .z80 .optimizer .common import JUMP_LABELS , LABELS
1212from src .arch .z80 .optimizer .cpustate import CPUState
13- from src .arch .z80 .optimizer .errors import OptimizerError
1413from src .arch .z80 .optimizer .helpers import ALL_REGS
1514from src .arch .z80 .optimizer .labelinfo import LabelInfo
1615from src .arch .z80 .optimizer .memcell import MemCell
@@ -237,108 +236,15 @@ def update_next_block(self):
237236 n_block = LABELS [last .opers [0 ]].basic_block
238237 self .add_goes_to (n_block )
239238
240- def update_used_by_list (self ):
241- """Every label has a set containing
242- which blocks jumps (jp, jr, call) if any.
243- A block can "use" (call/jump) only another block
244- and only one"""
245-
246- # Searches all labels and remove this block out
247- # of their used_by set, since this might have changed
248- for label in LABELS .values ():
249- label .used_by .remove (self ) # Delete this bblock
250-
251- def update_goes_and_comes (self ):
252- """Once the block is a Basic one, check the last instruction and updates
253- goes_to and comes_from set of the receivers.
254- Note: jp, jr and ret are already done in update_next_block()
255- """
256- if not len (self ):
257- return
258-
259- last = self .mem [- 1 ]
260- inst = last .inst
261- oper = last .opers
262- cond = last .condition_flag
263-
264- for blk in list (self .goes_to ):
265- self .delete_goes_to (blk )
266-
267- if self .next :
268- self .add_goes_to (self .next )
269-
270- if not last .is_ender :
271- return
272-
273- if cond is None :
274- self .delete_goes_to (self .next )
275-
276- if last .inst in {"ret" , "reti" , "retn" } and cond is None :
277- return # subroutine returns are updated from CALLer blocks
278-
279- if oper and oper [0 ]:
280- if oper [0 ] not in LABELS :
281- __DEBUG__ ("INFO: %s is not defined. No optimization is done." % oper [0 ], 1 )
282- LABELS [oper [0 ]] = LabelInfo (oper [0 ], 0 , DummyBasicBlock (ALL_REGS , ALL_REGS ))
283-
284- LABELS [oper [0 ]].used_by .add (self )
285- self .add_goes_to (LABELS [oper [0 ]].basic_block )
286-
287- if inst in {"djnz" , "jp" , "jr" }:
288- return
289-
290- assert inst in ("call" , "rst" )
291-
292- if self .next is None :
293- raise OptimizerError ("Unexpected NULL next block" )
294-
295- final_blk = self .next # The block all the final returns should go to
296- stack = [LABELS [oper [0 ]].basic_block ]
297- bbset : set [BasicBlock ] = set ()
298-
299- while stack :
300- bb = stack .pop (0 )
301- while True :
302- if bb is None :
303- bb = DummyBasicBlock (ALL_REGS , ALL_REGS )
304-
305- if bb in bbset :
306- break
307-
308- bbset .add (bb )
309-
310- if isinstance (bb , DummyBasicBlock ):
311- bb .add_goes_to (final_blk )
312- break
313-
314- if bb :
315- bb1 = bb [- 1 ]
316- if bb1 .inst in {"ret" , "reti" , "retn" }:
317- bb .add_goes_to (final_blk )
318- if bb1 .condition_flag is None : # 'ret'
319- break
320- elif bb1 .inst in ("jp" , "jr" ) and bb1 .condition_flag is not None : # jp/jr nc/nz/.. LABEL
321- if bb1 .opers [0 ] in LABELS : # some labels does not exist (e.g. immediate numeric addresses)
322- stack .append (LABELS [bb1 .opers [0 ]].basic_block )
323- else :
324- raise OptimizerError ("Unknown block label '{}'" .format (bb1 .opers [0 ]))
325-
326- bb = bb .next # next contiguous block
327-
328- def is_used (self , regs , i , top = None ):
239+ def is_used (self , regs : list [str ], i : int , top : int | None = None ) -> bool :
329240 """Checks whether any of the given regs are required from the given point
330241 to the end or not.
331242 """
332- if i < 0 :
333- i = 0
334-
335243 if self .lock :
336244 return True
337245
338- if top is None :
339- top = len (self )
340- else :
341- top -= 1
246+ i = max (i , 0 )
247+ top = len (self ) if top is None else top + 1
342248
343249 if regs and regs [0 ][0 ] == "(" and regs [0 ][- 1 ] == ")" : # A memory address
344250 r16 = helpers .single_registers (regs [0 ][1 :- 1 ]) if helpers .is_16bit_oper_register (regs [0 ][1 :- 1 ]) else []
@@ -384,33 +290,17 @@ def is_used(self, regs, i, top=None):
384290
385291 return result
386292
387- def safe_to_write (self , regs , i = 0 , end_ = 0 ):
388- """Given a list of registers (8 or 16 bits) returns a list of them
389- that are safe to modify from the given index until the position given
390- which, if omitted, defaults to the end of the block.
391- :param regs: register or iterable of registers (8 or 16 bit one)
392- :param i: initial position of the block to examine
393- :param end_: final position to examine
394- :returns: registers safe to write
395- """
396- if helpers .is_register (regs ):
397- regs = set (helpers .single_registers (regs ))
398- else :
399- regs = set (helpers .single_registers (x ) for x in regs )
400- return not regs .intersection (self .requires (i , end_ ))
401-
402- def requires (self , i = 0 , end_ = None ):
293+ def requires (self , i : int = 0 , end_ : int | None = None ) -> set [str ]:
403294 """Returns a list of registers and variables this block requires.
404295 By default checks from the beginning (i = 0).
405296 :param i: initial position of the block to examine
406297 :param end_: final position to examine
407298 :returns: registers safe to write
408299 """
409- if i < 0 :
410- i = 0
300+ i = max (i , 0 )
411301 end_ = len (self ) if end_ is None or end_ > len (self ) else end_
412302 regs = {"a" , "b" , "c" , "d" , "e" , "f" , "h" , "l" , "i" , "ixh" , "ixl" , "iyh" , "iyl" , "sp" }
413- result = set ()
303+ result : set [ str ] = set ()
414304
415305 for ii in range (i , end_ ):
416306 for r in self .mem [ii ].requires :
@@ -429,13 +319,13 @@ def requires(self, i=0, end_=None):
429319
430320 return result
431321
432- def destroys (self , i = 0 ) :
322+ def destroys (self , i : int = 0 ) -> list [ str ] :
433323 """Returns a list of registers this block destroys
434324 By default checks from the beginning (i = 0).
435325 """
436326 regs = {"a" , "b" , "c" , "d" , "e" , "f" , "h" , "l" , "i" , "ixh" , "ixl" , "iyh" , "iyl" , "sp" }
437327 top = len (self )
438- result = []
328+ result : list [ str ] = []
439329
440330 for ii in range (i , top ):
441331 for r in self .mem [ii ].destroys :
@@ -448,10 +338,6 @@ def destroys(self, i=0):
448338
449339 return result
450340
451- def swap (self , a : int , b : int ) -> None :
452- """Swaps mem positions a and b"""
453- self .mem [a ], self .mem [b ] = self .mem [b ], self .mem [a ]
454-
455341 def goes_requires (self , regs ):
456342 """Returns whether any of the goes_to block requires any of
457343 the given registers.
@@ -462,16 +348,6 @@ def goes_requires(self, regs):
462348
463349 return False
464350
465- def get_label_idx (self , label ):
466- """Returns the index of a label.
467- Returns None if not found.
468- """
469- for i in range (len (self )):
470- if self .mem [i ].is_label and self .mem [i ].inst == label :
471- return i
472-
473- return None
474-
475351 def get_first_non_label_instruction (self ):
476352 """Returns the memcell of the given block, which is
477353 not a LABEL.
@@ -563,7 +439,7 @@ def optimize(self, patterns_list):
563439 if not p .cond .eval (match ):
564440 continue
565441
566- # all patterns applied successfully. Apply this pattern
442+ # all patterns matched successfully. Apply this rule
567443 new_code = list (code )
568444 matched = new_code [i : i + len (p .patt )]
569445 new_code [i : i + len (p .patt )] = p .template .filter (match )
@@ -590,60 +466,22 @@ class DummyBasicBlock(BasicBlock):
590466 about what registers uses an destroys
591467 """
592468
593- def __init__ (self , destroys , requires ):
469+ def __init__ (self , destroys : Iterable [ str ] , requires : Iterable [ str ] ):
594470 BasicBlock .__init__ (self , [])
595- self .__destroys = [x for x in destroys ]
596- self .__requires = [x for x in requires ]
471+ self .__destroys = tuple (destroys )
472+ self .__requires = set (requires )
473+ self .code = ["ret" ]
597474
598- def destroys (self , i : int = 0 ):
599- return [ x for x in self .__destroys ]
475+ def destroys (self , i : int = 0 ) -> list [ str ] :
476+ return list ( self .__destroys )
600477
601- def requires (self , i : int = 0 , end_ = None ):
602- return [ x for x in self .__requires ]
478+ def requires (self , i : int = 0 , end_ = None ) -> set [ str ] :
479+ return set ( self .__requires )
603480
604- def is_used (self , regs , i , top = None ):
481+ def is_used (self , regs : Iterable [ str ] , i : int , top : int | None = None ) -> bool :
605482 return len ([x for x in regs if x in self .__requires ]) > 0
606483
607484
608- def block_partition (block , i ):
609- """Returns two blocks, as a result of partitioning the given one at
610- i-th instruction.
611- """
612- i += 1
613- new_block = BasicBlock ([])
614- new_block .mem = block .mem [i :]
615- block .mem = block .mem [:i ]
616-
617- for label , lbl_info in LABELS .items ():
618- if lbl_info .basic_block != block or lbl_info .position < len (block ):
619- continue
620-
621- lbl_info .basic_block = new_block
622- lbl_info .position -= len (block )
623-
624- for b_ in list (block .goes_to ):
625- block .delete_goes_to (b_ )
626- new_block .add_goes_to (b_ )
627-
628- new_block .label_goes = block .label_goes
629- block .label_goes = []
630-
631- new_block .next = block .next
632- new_block .prev = block
633- block .next = new_block
634- new_block .add_comes_from (block )
635-
636- if new_block .next is not None :
637- new_block .next .prev = new_block
638- if block in new_block .next .comes_from :
639- new_block .next .delete_comes_from (block )
640- new_block .next .add_comes_from (new_block )
641-
642- block .update_next_block ()
643-
644- return block , new_block
645-
646-
647485def split_block (block : BasicBlock , start_of_new_block : int ) -> tuple [BasicBlock , BasicBlock ]:
648486 assert 0 <= start_of_new_block < len (block ), f"Invalid split pos: { start_of_new_block } "
649487 new_block = BasicBlock ([])
@@ -680,13 +518,11 @@ def compute_calls(basic_blocks: list[BasicBlock], jump_labels: set[str]) -> None
680518
681519 # Compute which blocks use jump labels
682520 for bb in basic_blocks :
683- if bb [- 1 ].is_ender :
684- for op in bb [- 1 ].opers :
685- if op in LABELS :
686- LABELS [op ].used_by .add (bb )
521+ if bb [- 1 ].is_ender and (op := bb [- 1 ].branch_arg ) in LABELS :
522+ LABELS [op ].used_by .add (bb )
687523
688524 # For these blocks, add the referenced block in the goes_to
689- for label in JUMP_LABELS :
525+ for label in jump_labels :
690526 for bb in LABELS [label ].used_by :
691527 bb .add_goes_to (LABELS [label ].basic_block )
692528
@@ -695,11 +531,10 @@ def compute_calls(basic_blocks: list[BasicBlock], jump_labels: set[str]) -> None
695531 if bb [- 1 ].inst != "call" :
696532 continue
697533
698- for op in bb [- 1 ].opers :
699- if op in LABELS :
700- LABELS [op ].basic_block .called_by .add (bb )
701- calling_blocks [bb ] = LABELS [op ].basic_block
702- break
534+ op = bb [- 1 ].branch_arg
535+ if op in LABELS :
536+ LABELS [op ].basic_block .called_by .add (bb )
537+ calling_blocks [bb ] = LABELS [op ].basic_block
703538
704539 # For the annotated blocks, trace their goes_to, and their goes_to from
705540 # their goes_to and so on, until ret (unconditional or not) is found, and
@@ -751,9 +586,15 @@ def get_jump_labels(main_basic_block: BasicBlock) -> set[str]:
751586 if not mem .is_ender :
752587 continue
753588
754- for op in mem .opers :
755- if op in LABELS :
756- jump_labels .add (op )
589+ lbl = mem .branch_arg
590+ if lbl is None :
591+ continue
592+
593+ jump_labels .add (lbl )
594+
595+ if lbl not in LABELS :
596+ __DEBUG__ (f"INFO: { lbl } is not defined. No optimization is done." , 2 )
597+ LABELS [lbl ] = LabelInfo (lbl , 0 , DummyBasicBlock (ALL_REGS , ALL_REGS ))
757598
758599 return jump_labels
759600
0 commit comments