|
1 | 1 | import logging |
2 | 2 |
|
3 | | -TYPE_SERVER=0 |
4 | | -TYPE_CLIENT=1 |
5 | | - |
6 | 3 | class Connection: |
7 | | - def __init__(self, sock, target, name): |
| 4 | + def __init__(self, sock, address, name, events, context): |
8 | 5 | self.sock = sock |
9 | | - self.target = target |
| 6 | + self.address = address |
10 | 7 | self.name = name |
11 | | - |
12 | | - def send(self, message): |
13 | | - logging.debug("sending message to {}:\n{}".format(self.name, message)) |
14 | | - total = len(message) |
15 | | - total_sent = 0 |
16 | | - remaining = message |
17 | | - while total_sent < total: |
18 | | - sent = self.sock.send(remaining) |
19 | | - total_sent += sent |
20 | | - remaining = remaining[sent:] |
21 | | - |
22 | | - def __receive_raw(self, length): |
23 | | - total_received = 0 |
24 | | - chunks = [] |
25 | | - while total_received < length: |
26 | | - chunk = self.sock.recv(min([length - total_received]), 4096) |
27 | | - if chunk==b'': |
28 | | - raise RuntimeError("socket connection broken") |
29 | | - chunks.append(chunk) |
30 | | - total_received += len(chunk) |
31 | | - return b''.join(chunks) |
32 | | - |
33 | | - |
34 | | - def receive_packet(self): |
35 | | - pack_type = self.__receive_raw(1) |
36 | | - if pack_type == b'N': |
37 | | - # Null message? This message has no length. Just a single byte. Weird. |
38 | | - return pack_type, pack_type |
39 | | - if pack_type == b'\x00': |
40 | | - # Initialization packet. No type. This, and the next 3 bytes are the length |
41 | | - pack_length = self.__receive_raw(3) |
42 | | - pack_length = b''.join([pack_type, pack_length]) |
43 | | - pack_header = pack_length |
44 | | - else: |
45 | | - pack_length = self.__receive_raw(4) |
46 | | - pack_header = b''.join([pack_type, pack_length]) |
47 | | - pack_length = int.from_bytes(pack_length, 'big') |
48 | | - pack_body = self.__receive_raw(pack_length - 4) |
49 | | - pack = b''.join([pack_header, pack_body]) |
50 | | - return pack, pack_type |
51 | | - |
52 | | - |
53 | | - def receive(self): |
54 | | - packets = [] |
55 | | - logging.debug("receive message from {}:".format(self.name)) |
| 8 | + self.events = events |
| 9 | + self.context = context |
| 10 | + self.is_reading = False |
| 11 | + self.is_writing = False |
| 12 | + self.interceptor = None |
| 13 | + self.redirect_conn = None |
| 14 | + self.out_bytes = b'' |
| 15 | + self.in_bytes = b'' |
| 16 | + |
| 17 | + def parse_length(self, length_bytes): |
| 18 | + return int.from_bytes(length_bytes, 'big') |
| 19 | + |
| 20 | + def encode_length(self, length): |
| 21 | + return length.to_bytes(4, byteorder='big') |
| 22 | + |
| 23 | + def received(self, in_bytes): |
| 24 | + self.in_bytes += in_bytes |
| 25 | + # Read packet from byte array while there are enough bytes to make up a packet. |
| 26 | + # Otherwise wait for more bytes to be received (break and exit) |
56 | 27 | while True: |
57 | | - logging.debug("receive packet from {}:".format(self.name)) |
58 | | - packet, pack_type = self.receive_packet() |
59 | | - packets.append(packet) |
60 | | - logging.debug("received packet from {}:\n{}".format(self.name, packet)) |
61 | | - if not pack_type in (b'B', b'D', b'P', b'E'): |
| 28 | + ptype = self.in_bytes[0:1] |
| 29 | + if ptype == b'\x00': |
| 30 | + if len(self.in_bytes) < 4: |
| 31 | + break |
| 32 | + header_length = 4 |
| 33 | + body_length = self.parse_length(self.in_bytes[0:4]) - 4 |
| 34 | + elif ptype == b'N': |
| 35 | + header_length = 1 |
| 36 | + body_length = 0 |
| 37 | + else: |
| 38 | + if len(self.in_bytes) < 5: |
| 39 | + break |
| 40 | + header_length = 5 |
| 41 | + body_length = self.parse_length(self.in_bytes[1:5]) - 4 |
| 42 | + |
| 43 | + length = header_length + body_length |
| 44 | + if len(self.in_bytes) < length: |
62 | 45 | break |
63 | | - message = b''.join(packets) |
64 | | - logging.debug("received message from {}:\n{}".format(self.name, message)) |
65 | | - return message |
| 46 | + header = self.in_bytes[0:header_length] |
| 47 | + body = self.in_bytes[header_length:length] |
| 48 | + self.process_inbound_packet(header, body) |
| 49 | + self.in_bytes = self.in_bytes[length:] |
| 50 | + |
| 51 | + def process_inbound_packet(self, header, body): |
| 52 | + if header != b'N': |
| 53 | + packet_type = header[0:-4] |
| 54 | + logging.info("intercepting packet of type '%s' from %s", packet_type, self.name) |
| 55 | + body = self.interceptor.intercept(packet_type, body) |
| 56 | + header = packet_type + self.encode_length(len(body) + 4) |
| 57 | + message = header + body |
| 58 | + logging.debug("Received message. Relaying. Speaker: %s, message:\n%s", self.name, message) |
| 59 | + self.redirect_conn.out_bytes += message |
| 60 | + |
| 61 | + def sent(self, num_bytes): |
| 62 | + self.out_bytes = self.out_bytes[num_bytes:] |
0 commit comments