|
5 | 5 |
|
6 | 6 | from __future__ import division |
7 | 7 |
|
| 8 | +import functools |
8 | 9 | import math |
9 | 10 | from warnings import warn |
10 | 11 |
|
@@ -570,6 +571,19 @@ def _where_c(inarray, rows, cols, search_value, index_array): |
570 | 571 | return number_found |
571 | 572 |
|
572 | 573 |
|
| 574 | +@functools.lru_cache(maxsize=128, typed=False) |
| 575 | +def _compute_branch_metrics(decoding_type, r_codeword, i_codeword_array): |
| 576 | + if decoding_type == 'hard': |
| 577 | + return hamming_dist(r_codeword.astype(int), i_codeword_array.astype(int)) |
| 578 | + elif decoding_type == 'soft': |
| 579 | + neg_LL_0 = np.log(np.exp(r_codeword) + 1) # negative log-likelihood to have received a 0 |
| 580 | + neg_LL_1 = neg_LL_0 - r_codeword # negative log-likelihood to have received a 1 |
| 581 | + return np.where(i_codeword_array, neg_LL_1, neg_LL_0).sum() |
| 582 | + elif decoding_type == 'unquantized': |
| 583 | + i_codeword_array = 2 * i_codeword_array - 1 |
| 584 | + return euclid_dist(r_codeword, i_codeword_array) |
| 585 | + |
| 586 | + |
573 | 587 | def _acs_traceback(r_codeword, trellis, decoding_type, |
574 | 588 | path_metrics, paths, decoded_symbols, |
575 | 589 | decoded_bits, tb_count, t, count, |
@@ -605,15 +619,7 @@ def _acs_traceback(r_codeword, trellis, decoding_type, |
605 | 619 | i_codeword_array = dec2bitarray(i_codeword, n) |
606 | 620 |
|
607 | 621 | # Compute Branch Metrics |
608 | | - if decoding_type == 'hard': |
609 | | - branch_metric = hamming_dist(r_codeword.astype(int), i_codeword_array.astype(int)) |
610 | | - elif decoding_type == 'soft': |
611 | | - neg_LL_0 = np.log(np.exp(r_codeword) + 1) # negative log-likelihood to have received a 0 |
612 | | - neg_LL_1 = neg_LL_0 - r_codeword # negative log-likelihood to have received a 1 |
613 | | - branch_metric = np.where(i_codeword_array, neg_LL_1, neg_LL_0).sum() |
614 | | - elif decoding_type == 'unquantized': |
615 | | - i_codeword_array = 2*i_codeword_array - 1 |
616 | | - branch_metric = euclid_dist(r_codeword, i_codeword_array) |
| 622 | + branch_metric = _compute_branch_metrics(decoding_type, tuple(r_codeword), tuple(i_codeword_array)) |
617 | 623 |
|
618 | 624 | # ADD operation: Add the branch metric to the |
619 | 625 | # accumulated path metric and store it in the temporary array |
|
0 commit comments