diff options
-rw-r--r-- | gnutls_network.c | 296 |
1 files changed, 223 insertions, 73 deletions
diff --git a/gnutls_network.c b/gnutls_network.c index d10bda6..2c80d53 100644 --- a/gnutls_network.c +++ b/gnutls_network.c @@ -28,6 +28,7 @@ #include <arpa/inet.h> #include <errno.h> +#include <fcntl.h> #include <gnutls/gnutls.h> #include <poll.h> #include <pthread.h> @@ -39,6 +40,7 @@ #include "config.h" #include "general_network.h" #include "gnutls_network.h" +#include "main.h" struct gnutls_handle { gnutls_session_t session; @@ -94,12 +96,14 @@ int gnutls_send(void *handle, struct string msg) { if (poll_res < 0) goto gnutls_send_error_unlock; - if ((pollfd.revents & (POLLIN | POLLOUT)) == (POLLIN | POLLOUT) || (pollfd.revents & (POLLIN | POLLOUT)) == 0) + if ((pollfd.revents & (POLLIN | POLLOUT)) == (POLLIN | POLLOUT)) continue; - else if (pollfd.revents & POLLIN) - pollfd.events = POLLOUT; + else if (pollfd.revents & (~(POLLIN | POLLOUT))) + goto gnutls_send_error_unlock; else - pollfd.events = POLLIN; + pollfd.events = (pollfd.revents & (POLLIN | POLLOUT)) ^ (POLLIN | POLLOUT); + } else { + goto gnutls_send_error_unlock; } } else { break; @@ -120,86 +124,168 @@ int gnutls_send(void *handle, struct string msg) { return 0; gnutls_send_error_unlock: - gnutls_handle->valid = 0; - pthread_mutex_unlock(&(gnutls_handle->mutex)); + if (gnutls_handle->valid) { + pthread_mutex_unlock(&(gnutls_handle->mutex)); + gnutls_shutdown(gnutls_handle); + } else { + pthread_mutex_unlock(&(gnutls_handle->mutex)); + } return 1; } -size_t gnutls_recv(void *session, char *data, size_t len, char *err) { - ssize_t res; +size_t gnutls_recv(void *handle, char *data, size_t len, char *err) { + struct gnutls_handle *gnutls_handle = handle; + + struct pollfd pollfd = { + .fd = gnutls_handle->fd, + }; + ssize_t gnutls_res; do { - res = gnutls_record_recv(*((gnutls_session_t*)session), data, len); - } while (res < 0 && (res == GNUTLS_E_INTERRUPTED || res == GNUTLS_E_AGAIN)); + int poll_res; + pthread_mutex_lock(&(gnutls_handle->mutex)); + if (!gnutls_handle->valid) { + pthread_mutex_unlock(&(gnutls_handle->mutex)); + *err = 3; + return 0; + } + do { + gnutls_res = gnutls_record_recv(gnutls_handle->session, data, len); + } while (gnutls_res == GNUTLS_E_INTERRUPTED); + pthread_mutex_unlock(&(gnutls_handle->mutex)); + if (gnutls_res < 0) { + if (gnutls_res == GNUTLS_E_AGAIN) { + pollfd.events = POLLIN | POLLOUT; + do { + poll_res = poll(&pollfd, 1, 0); + } while (poll_res < 0 && errno == EINTR); + if (poll_res < 0) { + *err = 3; + return 0; + } - if (res < 0) { - if (res == GNUTLS_E_TIMEDOUT) { - *err = 1; + if ((pollfd.revents & (POLLIN | POLLOUT)) == (POLLIN | POLLOUT)) + continue; + else + pollfd.events = (pollfd.revents & (POLLIN | POLLOUT)) ^ (POLLIN | POLLOUT); + } else { + *err = 3; + return 0; + } + } else if (gnutls_res == 0) { + *err = 2; + return 0; } else { + break; + } + + do { + poll_res = poll(&pollfd, 1, PING_INTERVAL*1000); + } while (poll_res < 0 && errno == EINTR); + if (poll_res < 0) { + *err = 3; + return 0; + } if (poll_res == 0) { // Timed out + *err = 1; + return 0; + } if ((pollfd.revents & (POLLIN | POLLOUT)) == 0) { *err = 3; + return 0; } - return 0; - } else if (res == 0) { - *err = 2; - return 0; - } + } while (1); + *err = 0; - return (size_t)res; + return (size_t)gnutls_res; } int gnutls_connect(void **handle, struct string address, struct string port, struct string *addr_out) { + struct gnutls_handle *gnutls_handle; + gnutls_handle = malloc(sizeof(*gnutls_handle)); + if (!gnutls_handle) + return -1; + + *handle = gnutls_handle; + + int res = pthread_mutex_init(&(gnutls_handle->mutex), &pthread_mutexattr); + if (res != 0) + goto gnutls_connect_free_gnutls_handle; + struct sockaddr sockaddr; if (resolve(address, port, &sockaddr) != 0) - return -1; + goto gnutls_connect_destroy_mutex; int fd = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP); if (fd == -1) - return -1; + goto gnutls_connect_destroy_mutex; - { - struct timeval timeout = { - .tv_sec = PING_INTERVAL, - .tv_usec = 0, - }; + gnutls_handle->fd = fd; + gnutls_handle->valid = 1; - setsockopt(fd, SOL_SOCKET, SO_RCVTIMEO, &timeout, sizeof(timeout)); - } - - int res; do { res = connect(fd, &sockaddr, sizeof(sockaddr)); } while (res < 0 && errno == EINTR); if (res < 0) goto gnutls_connect_close; - gnutls_session_t *session; - session = malloc(sizeof(*session)); - if (session == 0) + int flags = fcntl(fd, F_GETFL); + if (flags == -1) + goto gnutls_connect_close; + if (fcntl(fd, F_SETFL, flags | O_NONBLOCK) == -1) goto gnutls_connect_close; - *handle = session; - if (gnutls_init(session, GNUTLS_CLIENT | GNUTLS_NONBLOCK) != GNUTLS_E_SUCCESS) - goto gnutls_connect_free_session; + if (gnutls_init(&(gnutls_handle->session), GNUTLS_CLIENT | GNUTLS_NONBLOCK) != GNUTLS_E_SUCCESS) + goto gnutls_connect_close; - if (gnutls_server_name_set(*session, GNUTLS_NAME_DNS, address.data, address.len) != GNUTLS_E_SUCCESS) + if (gnutls_server_name_set(gnutls_handle->session, GNUTLS_NAME_DNS, address.data, address.len) != GNUTLS_E_SUCCESS) goto gnutls_connect_deinit_session; - if (gnutls_credentials_set(*session, GNUTLS_CRD_CERTIFICATE, gnutls_cert_creds) != GNUTLS_E_SUCCESS) + if (gnutls_credentials_set(gnutls_handle->session, GNUTLS_CRD_CERTIFICATE, gnutls_cert_creds) != GNUTLS_E_SUCCESS) goto gnutls_connect_deinit_session; - if (gnutls_set_default_priority(*session) != GNUTLS_E_SUCCESS) + if (gnutls_set_default_priority(gnutls_handle->session) != GNUTLS_E_SUCCESS) goto gnutls_connect_deinit_session; - gnutls_transport_set_int(*session, fd); - - gnutls_handshake_set_timeout(*session, PING_INTERVAL * 1000); - gnutls_record_set_timeout(*session, PING_INTERVAL * 1000); + gnutls_transport_set_int(gnutls_handle->session, fd); + struct pollfd pollfd = { + .fd = fd, + }; + ssize_t gnutls_res; do { - res = gnutls_handshake(*session); - } while (res == GNUTLS_E_INTERRUPTED || res == GNUTLS_E_AGAIN); - if (res < 0) - goto gnutls_connect_deinit_session; + int poll_res; + do { + gnutls_res = gnutls_handshake(gnutls_handle->session); + } while (res == GNUTLS_E_INTERRUPTED); + if (gnutls_res < 0) { + if (gnutls_res == GNUTLS_E_AGAIN) { + pollfd.events = POLLIN | POLLOUT; + do { + poll_res = poll(&pollfd, 1, 0); + } while (poll_res < 0 && errno == EINTR); + if (poll_res < 0) + goto gnutls_connect_deinit_session; + + if ((pollfd.revents & (POLLIN | POLLOUT)) == (POLLIN | POLLOUT)) + continue; + else if (pollfd.revents & (~(POLLIN | POLLOUT))) + goto gnutls_connect_deinit_session; + else + pollfd.events = pollfd.revents ^ (POLLIN | POLLOUT); + } + } else { + break; + } + + do { + poll_res = poll(&pollfd, 1, PING_INTERVAL*1000); + } while (poll_res < 0 && errno == EINTR); + if (poll_res < 0) + goto gnutls_connect_deinit_session; + if (poll_res == 0) // Timed out + goto gnutls_connect_deinit_session; + if ((pollfd.revents & (POLLIN | POLLOUT)) == 0) + goto gnutls_connect_deinit_session; + } while (1); addr_out->data = malloc(sizeof(sockaddr)); if (!addr_out->data) @@ -211,15 +297,21 @@ int gnutls_connect(void **handle, struct string address, struct string port, str return fd; gnutls_connect_deinit_session: - gnutls_deinit(*session); - gnutls_connect_free_session: - free(session); + gnutls_deinit(gnutls_handle->session); gnutls_connect_close: close(fd); + gnutls_connect_destroy_mutex: + pthread_mutex_destroy(&(gnutls_handle->mutex)); + gnutls_connect_free_gnutls_handle: + free(gnutls_handle); + return -1; } int gnutls_accept(int listen_fd, void **handle, struct string *addr) { + if (!GNUTLS_CERT_PATH || !GNUTLS_KEY_PATH) + return -1; + struct sockaddr address; socklen_t address_len = sizeof(address); @@ -231,55 +323,113 @@ int gnutls_accept(int listen_fd, void **handle, struct string *addr) { if (con_fd == -1) return -1; + int flags = fcntl(con_fd, F_GETFL); + if (flags == -1) + goto gnutls_accept_close; + if (fcntl(con_fd, F_SETFL, flags | O_NONBLOCK) == -1) + goto gnutls_accept_close; + + struct gnutls_handle *gnutls_handle; + gnutls_handle = malloc(sizeof(*gnutls_handle)); + if (!gnutls_handle) + goto gnutls_accept_close; + + *handle = gnutls_handle; + gnutls_handle->valid = 1; + gnutls_handle->fd = con_fd; + + int res = pthread_mutex_init(&(gnutls_handle->mutex), &(pthread_mutexattr)); + if (res != 0) + goto gnutls_accept_free_gnutls_handle; + addr->data = malloc(address_len); if (addr->data == 0 && address_len != 0) - goto gnutls_accept_close; + goto gnutls_accept_destroy_mutex; memcpy(addr->data, &address, address_len); addr->len = address_len; - gnutls_session_t *session; - session = malloc(sizeof(*session)); - if (!session) + if (gnutls_init(&(gnutls_handle->session), GNUTLS_SERVER | GNUTLS_NONBLOCK) != GNUTLS_E_SUCCESS) goto gnutls_accept_free_addr_data; - *handle = session; - if (gnutls_init(session, GNUTLS_SERVER | GNUTLS_NONBLOCK) != GNUTLS_E_SUCCESS) - goto gnutls_accept_free_session; - - if (gnutls_credentials_set(*session, GNUTLS_CRD_CERTIFICATE, gnutls_cert_creds) != GNUTLS_E_SUCCESS) + if (gnutls_credentials_set(gnutls_handle->session, GNUTLS_CRD_CERTIFICATE, gnutls_cert_creds) != GNUTLS_E_SUCCESS) goto gnutls_accept_deinit_session; - if (gnutls_set_default_priority(*session) != GNUTLS_E_SUCCESS) + if (gnutls_set_default_priority(gnutls_handle->session) != GNUTLS_E_SUCCESS) goto gnutls_accept_deinit_session; - gnutls_transport_set_int(*session, con_fd); + gnutls_transport_set_int(gnutls_handle->session, con_fd); - gnutls_handshake_set_timeout(*session, PING_INTERVAL * 1000); - gnutls_record_set_timeout(*session, PING_INTERVAL * 1000); + gnutls_handshake_set_timeout(gnutls_handle->session, PING_INTERVAL * 1000); + gnutls_record_set_timeout(gnutls_handle->session, PING_INTERVAL * 1000); - int res; + struct pollfd pollfd = { + .fd = con_fd, + }; + ssize_t gnutls_res; do { - res = gnutls_handshake(*session); - } while (res == GNUTLS_E_INTERRUPTED || res == GNUTLS_E_AGAIN); - if (res != GNUTLS_E_SUCCESS) - goto gnutls_accept_deinit_session; + int poll_res; + do { + gnutls_res = gnutls_handshake(gnutls_handle->session); + } while (res == GNUTLS_E_INTERRUPTED); + if (gnutls_res < 0) { + if (gnutls_res == GNUTLS_E_AGAIN) { + pollfd.events = POLLIN | POLLOUT; + do { + poll_res = poll(&pollfd, 1, 0); + } while (poll_res < 0 && errno == EINTR); + if (poll_res < 0) + goto gnutls_accept_deinit_session; + + if ((pollfd.revents & (POLLIN | POLLOUT)) == (POLLIN | POLLOUT)) + continue; + else if (pollfd.revents & (~(POLLIN | POLLOUT))) + goto gnutls_accept_deinit_session; + else + pollfd.events = pollfd.revents ^ (POLLIN | POLLOUT); + } + } else { + break; + } + + do { + poll_res = poll(&pollfd, 1, PING_INTERVAL*1000); + } while (poll_res < 0 && errno == EINTR); + if (poll_res < 0) + goto gnutls_accept_deinit_session; + if (poll_res == 0) // Timed out + goto gnutls_accept_deinit_session; + if ((pollfd.revents & (POLLIN | POLLOUT)) == 0) + goto gnutls_accept_deinit_session; + } while (1); return con_fd; gnutls_accept_deinit_session: - gnutls_deinit(*session); - gnutls_accept_free_session: - free(session); + gnutls_deinit(gnutls_handle->session); gnutls_accept_free_addr_data: free(addr->data); + gnutls_accept_destroy_mutex: + pthread_mutex_destroy(&(gnutls_handle->mutex)); + gnutls_accept_free_gnutls_handle: + free(gnutls_handle); gnutls_accept_close: close(con_fd); return -1; } +void gnutls_shutdown(void *handle) { + struct gnutls_handle *gnutls_handle = handle; + pthread_mutex_lock(&(gnutls_handle->mutex)); + shutdown(gnutls_handle->fd, SHUT_RDWR); + gnutls_handle->valid = 0; + pthread_mutex_unlock(&(gnutls_handle->mutex)); +} + void gnutls_close(int fd, void *handle) { - gnutls_deinit(*((gnutls_session_t*)handle)); - free(handle); + struct gnutls_handle *gnutls_handle = handle; + pthread_mutex_destroy(&(gnutls_handle->mutex)); + gnutls_deinit(gnutls_handle->session); + free(gnutls_handle); close(fd); } |