Skip to content

Commit b104c99

Browse files
authored
Merge pull request #74 from eSoares/speed-improvements
Speed improvements
2 parents 210b973 + bcc36dd commit b104c99

File tree

3 files changed

+29
-20
lines changed

3 files changed

+29
-20
lines changed

commpy/channelcoding/convcode.py

Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
from __future__ import division
77

8+
import functools
89
import math
910
from warnings import warn
1011

@@ -560,16 +561,30 @@ def conv_encode(message_bits, trellis, termination = 'term', puncture_matrix=Non
560561
def _where_c(inarray, rows, cols, search_value, index_array):
561562

562563
number_found = 0
563-
for i in range(rows):
564-
for j in range(cols):
565-
if inarray[i, j] == search_value:
566-
index_array[number_found, 0] = i
567-
index_array[number_found, 1] = j
568-
number_found += 1
564+
res = np.where(inarray == search_value)
565+
i_s, j_s = res
566+
for i, j in zip(i_s, j_s):
567+
if inarray[i, j] == search_value:
568+
index_array[number_found, 0] = i
569+
index_array[number_found, 1] = j
570+
number_found += 1
569571

570572
return number_found
571573

572574

575+
@functools.lru_cache(maxsize=128, typed=False)
576+
def _compute_branch_metrics(decoding_type, r_codeword, i_codeword_array):
577+
if decoding_type == 'hard':
578+
return hamming_dist(r_codeword.astype(int), i_codeword_array.astype(int))
579+
elif decoding_type == 'soft':
580+
neg_LL_0 = np.log(np.exp(r_codeword) + 1) # negative log-likelihood to have received a 0
581+
neg_LL_1 = neg_LL_0 - r_codeword # negative log-likelihood to have received a 1
582+
return np.where(i_codeword_array, neg_LL_1, neg_LL_0).sum()
583+
elif decoding_type == 'unquantized':
584+
i_codeword_array = 2 * i_codeword_array - 1
585+
return euclid_dist(r_codeword, i_codeword_array)
586+
587+
573588
def _acs_traceback(r_codeword, trellis, decoding_type,
574589
path_metrics, paths, decoded_symbols,
575590
decoded_bits, tb_count, t, count,
@@ -605,15 +620,7 @@ def _acs_traceback(r_codeword, trellis, decoding_type,
605620
i_codeword_array = dec2bitarray(i_codeword, n)
606621

607622
# Compute Branch Metrics
608-
if decoding_type == 'hard':
609-
branch_metric = hamming_dist(r_codeword.astype(int), i_codeword_array.astype(int))
610-
elif decoding_type == 'soft':
611-
neg_LL_0 = np.log(np.exp(r_codeword) + 1) # negative log-likelihood to have received a 0
612-
neg_LL_1 = neg_LL_0 - r_codeword # negative log-likelihood to have received a 1
613-
branch_metric = np.where(i_codeword_array, neg_LL_1, neg_LL_0).sum()
614-
elif decoding_type == 'unquantized':
615-
i_codeword_array = 2*i_codeword_array - 1
616-
branch_metric = euclid_dist(r_codeword, i_codeword_array)
623+
branch_metric = _compute_branch_metrics(decoding_type, tuple(r_codeword), tuple(i_codeword_array))
617624

618625
# ADD operation: Add the branch metric to the
619626
# accumulated path metric and store it in the temporary array

commpy/links.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,7 @@ def link_performance(self, SNRs, send_max, err_min, send_chunk=None, code_rate=1
203203
# Deals with MIMO channel
204204
if isinstance(self.channel, MIMOFlatChannel):
205205
nb_symb_vector = len(channel_output)
206-
received_msg = np.empty(int(math.ceil(len(msg) / self.rate)))
206+
received_msg = np.empty(int(math.ceil(len(msg) / self.rate)), dtype=np.int8)
207207
for i in range(nb_symb_vector):
208208
received_msg[receive_size * i:receive_size * (i + 1)] = \
209209
self.receive(channel_output[i], self.channel.channel_gains[i],
@@ -216,9 +216,9 @@ def link_performance(self, SNRs, send_max, err_min, send_chunk=None, code_rate=1
216216
decoded_bits = self.decoder(channel_output, self.channel.channel_gains,
217217
self.constellation, self.channel.noise_std ** 2,
218218
received_msg, self.channel.nb_tx * self.num_bits_symbol)
219-
bit_err += (msg != decoded_bits[:len(msg)]).sum()
219+
bit_err += np.bitwise_xor(msg, decoded_bits[:len(msg)]).sum()
220220
else:
221-
bit_err += (msg != self.decoder(received_msg)[:len(msg)]).sum()
221+
bit_err += np.bitwise_xor(msg, self.decoder(received_msg)[:len(msg)]).sum()
222222
bit_send += send_chunk
223223
BERs[id_SNR] = bit_err / bit_send
224224
if bit_err < err_min:

commpy/utilities.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
upsample -- Upsample by an integral factor (zero insertion).
1818
signal_power -- Compute the power of a discrete time signal.
1919
"""
20+
import functools
2021

2122
import numpy as np
2223

@@ -47,13 +48,14 @@ def dec2bitarray(in_number, bit_width):
4748
"""
4849

4950
if isinstance(in_number, (np.integer, int)):
50-
return decimal2bitarray(in_number, bit_width)
51+
return decimal2bitarray(in_number, bit_width).copy()
5152
result = np.zeros(bit_width * len(in_number), np.int8)
5253
for pox, number in enumerate(in_number):
53-
result[pox * bit_width:(pox + 1) * bit_width] = decimal2bitarray(number, bit_width)
54+
result[pox * bit_width:(pox + 1) * bit_width] = decimal2bitarray(number, bit_width).copy()
5455
return result
5556

5657

58+
@functools.lru_cache(maxsize=128, typed=False)
5759
def decimal2bitarray(number, bit_width):
5860
"""
5961
Converts a positive integer to NumPy array of the specified size containing bits (0 and 1). This version is slightly

0 commit comments

Comments
 (0)