|
5 | 5 |
|
6 | 6 | from __future__ import division |
7 | 7 |
|
| 8 | +import functools |
8 | 9 | import math |
9 | 10 | from warnings import warn |
10 | 11 |
|
@@ -560,16 +561,30 @@ def conv_encode(message_bits, trellis, termination = 'term', puncture_matrix=Non |
560 | 561 | def _where_c(inarray, rows, cols, search_value, index_array): |
561 | 562 |
|
562 | 563 | number_found = 0 |
563 | | - for i in range(rows): |
564 | | - for j in range(cols): |
565 | | - if inarray[i, j] == search_value: |
566 | | - index_array[number_found, 0] = i |
567 | | - index_array[number_found, 1] = j |
568 | | - number_found += 1 |
| 564 | + res = np.where(inarray == search_value) |
| 565 | + i_s, j_s = res |
| 566 | + for i, j in zip(i_s, j_s): |
| 567 | + if inarray[i, j] == search_value: |
| 568 | + index_array[number_found, 0] = i |
| 569 | + index_array[number_found, 1] = j |
| 570 | + number_found += 1 |
569 | 571 |
|
570 | 572 | return number_found |
571 | 573 |
|
572 | 574 |
|
| 575 | +@functools.lru_cache(maxsize=128, typed=False) |
| 576 | +def _compute_branch_metrics(decoding_type, r_codeword, i_codeword_array): |
| 577 | + if decoding_type == 'hard': |
| 578 | + return hamming_dist(r_codeword.astype(int), i_codeword_array.astype(int)) |
| 579 | + elif decoding_type == 'soft': |
| 580 | + neg_LL_0 = np.log(np.exp(r_codeword) + 1) # negative log-likelihood to have received a 0 |
| 581 | + neg_LL_1 = neg_LL_0 - r_codeword # negative log-likelihood to have received a 1 |
| 582 | + return np.where(i_codeword_array, neg_LL_1, neg_LL_0).sum() |
| 583 | + elif decoding_type == 'unquantized': |
| 584 | + i_codeword_array = 2 * i_codeword_array - 1 |
| 585 | + return euclid_dist(r_codeword, i_codeword_array) |
| 586 | + |
| 587 | + |
573 | 588 | def _acs_traceback(r_codeword, trellis, decoding_type, |
574 | 589 | path_metrics, paths, decoded_symbols, |
575 | 590 | decoded_bits, tb_count, t, count, |
@@ -605,15 +620,7 @@ def _acs_traceback(r_codeword, trellis, decoding_type, |
605 | 620 | i_codeword_array = dec2bitarray(i_codeword, n) |
606 | 621 |
|
607 | 622 | # Compute Branch Metrics |
608 | | - if decoding_type == 'hard': |
609 | | - branch_metric = hamming_dist(r_codeword.astype(int), i_codeword_array.astype(int)) |
610 | | - elif decoding_type == 'soft': |
611 | | - neg_LL_0 = np.log(np.exp(r_codeword) + 1) # negative log-likelihood to have received a 0 |
612 | | - neg_LL_1 = neg_LL_0 - r_codeword # negative log-likelihood to have received a 1 |
613 | | - branch_metric = np.where(i_codeword_array, neg_LL_1, neg_LL_0).sum() |
614 | | - elif decoding_type == 'unquantized': |
615 | | - i_codeword_array = 2*i_codeword_array - 1 |
616 | | - branch_metric = euclid_dist(r_codeword, i_codeword_array) |
| 623 | + branch_metric = _compute_branch_metrics(decoding_type, tuple(r_codeword), tuple(i_codeword_array)) |
617 | 624 |
|
618 | 625 | # ADD operation: Add the branch metric to the |
619 | 626 | # accumulated path metric and store it in the temporary array |
|
0 commit comments