Skip to content

Commit 7244771

Browse files
SocketWrapper: use shared_ptr for sock_fd to prevent unwanted closure
1 parent 316fa77 commit 7244771

File tree

2 files changed

+29
-22
lines changed

2 files changed

+29
-22
lines changed

libraries/SocketWrapper/SocketWrapper.h

Lines changed: 28 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -7,36 +7,40 @@
77
#endif
88

99
#include <zephyr/net/socket.h>
10+
#include <memory>
11+
#include <cstring>
1012

1113
class ZephyrSocketWrapper {
1214
protected:
13-
int* sock_fd = nullptr;
15+
std::shared_ptr<int> sock_fd;
1416
bool is_ssl = false;
1517
int ssl_sock_temp_char = -1;
1618

17-
public:
18-
ZephyrSocketWrapper() {
19-
sock_fd = new int(-1);
19+
// custom deleter for shared_ptr to close automatically the socket
20+
static auto socket_deleter() {
21+
return [](int *fd) {
22+
if (fd && *fd != -1) {
23+
::close(*fd);
24+
delete fd;
25+
}
26+
};
2027
}
2128

22-
ZephyrSocketWrapper(int fd) {
23-
sock_fd = new int(fd);
24-
}
29+
public:
30+
ZephyrSocketWrapper() = default;
2531

26-
~ZephyrSocketWrapper() {
27-
if (sock_fd && *sock_fd != -1) {
28-
::close(*sock_fd);
29-
}
30-
delete sock_fd;
31-
sock_fd = nullptr;
32+
ZephyrSocketWrapper(int fd) : sock_fd(std::shared_ptr<int>(fd<0 ? nullptr : new int(fd), socket_deleter())) {
3233
}
3334

35+
~ZephyrSocketWrapper() = default; // socket close managed by shared_ptr
36+
3437
bool connect(const char *host, uint16_t port) {
3538

3639
// Resolve address
3740
struct addrinfo hints = {0};
3841
struct addrinfo *res = nullptr;
3942
bool rv = true;
43+
int raw_sock_fd;
4044

4145
hints.ai_family = AF_INET;
4246
hints.ai_socktype = SOCK_STREAM;
@@ -59,7 +63,8 @@ class ZephyrSocketWrapper {
5963
goto exit;
6064
}
6165

62-
*sock_fd = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
66+
raw_sock_fd = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
67+
sock_fd = std::shared_ptr<int>(raw_sock_fd < 0 ? nullptr : new int(raw_sock_fd), socket_deleter());
6368
if (!sock_fd || *sock_fd < 0) {
6469
rv = false;
6570

@@ -85,13 +90,15 @@ class ZephyrSocketWrapper {
8590
bool connect(IPAddress host, uint16_t port) {
8691

8792
const char *_host = host.toString().c_str();
93+
int raw_sock_fd;
8894

8995
struct sockaddr_in addr;
9096
addr.sin_family = AF_INET;
9197
addr.sin_port = htons(port);
9298
inet_pton(AF_INET, _host, &addr.sin_addr);
9399

94-
*sock_fd = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
100+
raw_sock_fd = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
101+
sock_fd = std::shared_ptr<int>(raw_sock_fd < 0 ? nullptr : new int(raw_sock_fd), socket_deleter());
95102
if (!sock_fd || *sock_fd < 0) {
96103
return false;
97104
}
@@ -118,6 +125,7 @@ class ZephyrSocketWrapper {
118125
int resolve_attempts = 100;
119126
int ret;
120127
bool rv = false;
128+
int raw_sock_fd;
121129

122130
sec_tag_t sec_tag_opt[] = {
123131
CA_CERTIFICATE_TAG,
@@ -150,7 +158,8 @@ class ZephyrSocketWrapper {
150158
}
151159
}
152160

153-
*sock_fd = socket(AF_INET, SOCK_STREAM, IPPROTO_TLS_1_2);
161+
raw_sock_fd = socket(AF_INET, SOCK_STREAM, IPPROTO_TLS_1_2);
162+
sock_fd = std::shared_ptr<int>(raw_sock_fd < 0 ? nullptr : new int(raw_sock_fd), socket_deleter());
154163
if (!sock_fd || *sock_fd < 0) {
155164
goto exit;
156165
}
@@ -238,11 +247,13 @@ class ZephyrSocketWrapper {
238247

239248
bool bind(uint16_t port) {
240249
struct sockaddr_in addr;
250+
int raw_sock_fd;
241251
addr.sin_family = AF_INET;
242252
addr.sin_port = htons(port);
243253
addr.sin_addr.s_addr = INADDR_ANY;
244254

245-
*sock_fd = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
255+
raw_sock_fd = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
256+
sock_fd = std::shared_ptr<int>(raw_sock_fd < 0 ? nullptr : new int(raw_sock_fd), socket_deleter());
246257
if (!sock_fd || *sock_fd < 0) {
247258
return false;
248259
}

libraries/SocketWrapper/ZephyrClient.h

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,7 @@ class ZephyrClient : public arduino::Client, ZephyrSocketWrapper {
1111

1212
protected:
1313
void setSocket(int sock) {
14-
if (sock_fd) {
15-
*sock_fd = sock;
16-
} else {
17-
sock_fd = new int(sock);
18-
}
14+
sock_fd = std::shared_ptr<int>(sock < 0 ? nullptr : new int(sock), socket_deleter());
1915
_connected = true;
2016
}
2117

0 commit comments

Comments
 (0)