Skip to content

Commit 21c91d9

Browse files
committed
refactor: add decode_msg method to BaseFunASRClient
1 parent 037d998 commit 21c91d9

File tree

3 files changed

+9
-6
lines changed

3 files changed

+9
-6
lines changed

src/funasr_client/async_client.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
from .base_client import BaseFunASRClient
1717
from .types import FunASRMessageLike
18-
from .utils import sync_to_async, decode_msg, is_final_msg
18+
from .utils import sync_to_async, is_final_msg
1919

2020

2121
module_logger = logging.getLogger(__name__)
@@ -148,7 +148,7 @@ async def _recv(self, timeout: Optional[float] = None):
148148
raise TimeoutError() from e.__cause__
149149
response = json.loads(message)
150150
if self.decode:
151-
response = decode_msg(response, self.start_time)
151+
response = self.decode_msg(response)
152152
response = cast(MessageType, response)
153153
if is_final_msg(response):
154154
self._received_final = True

src/funasr_client/base_client.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@
88
from websockets.sync.client import connect as ws_connect
99
from websockets.sync.connection import Connection
1010

11-
from .types import FunASRMessageLike, InitMessageMode
12-
from .utils import typed_params
11+
from .types import FunASRMessage, FunASRMessageLike, InitMessageMode
12+
from .utils import typed_params, decode_msg
1313

1414

1515
MessageType = TypeVar("MessageType", bound=FunASRMessageLike)
@@ -167,3 +167,6 @@ def _get_connect_params(self):
167167
ssl=ssl_context,
168168
subprotocols=["binary"], # type: ignore
169169
)
170+
171+
def decode_msg(self, msg: FunASRMessage):
172+
return decode_msg(msg, start_time=self.start_time)

src/funasr_client/client.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
from .base_client import BaseFunASRClient
1717
from .types import FunASRMessageLike
18-
from .utils import decode_msg, is_final_msg
18+
from .utils import is_final_msg
1919

2020

2121
module_logger = logging.getLogger(__name__)
@@ -130,7 +130,7 @@ def _recv(self, timeout: Optional[float] = None):
130130
message = self._ws.recv(timeout=timeout)
131131
response = json.loads(message)
132132
if self.decode:
133-
response = decode_msg(response, self.start_time)
133+
response = self.decode_msg(response)
134134
response = cast(MessageType, response)
135135
if is_final_msg(response):
136136
self._received_final = True

0 commit comments

Comments
 (0)