diff options
| author | Drew DeVault <sir@cmpwn.com> | 2020-10-25 14:50:07 -0400 |
|---|---|---|
| committer | Drew DeVault <sir@cmpwn.com> | 2020-10-25 14:50:07 -0400 |
| commit | 1fe107875b05cc07cf62c714c0136026eef7b93a (patch) | |
| tree | 463fa81cf3f67012ed542331cf072c10c72eaafd /src/server.c | |
| parent | a22bec51494a50c044416d469cc33e043480e7fd (diff) | |
| download | gmnisrv-1fe107875b05cc07cf62c714c0136026eef7b93a.tar.gz gmnisrv-1fe107875b05cc07cf62c714c0136026eef7b93a.tar.xz gmnisrv-1fe107875b05cc07cf62c714c0136026eef7b93a.zip | |
Overhaul network I/O to be async for real
Had to totally cut off OpenSSL from the network fd because obviously
OpenSSL is just going to wreck our shit
Diffstat (limited to 'src/server.c')
| -rw-r--r-- | src/server.c | 308 |
1 files changed, 217 insertions, 91 deletions
diff --git a/src/server.c b/src/server.c index 6412f1b..65b8204 100644 --- a/src/server.c +++ b/src/server.c @@ -186,9 +186,6 @@ disconnect_client(struct gmnisrv_server *server, struct gmnisrv_client *client) client->path ? client->path : "(none)", ms, client->bbytes, (int)client->status, client->meta); } - if (client->bio) { - BIO_free_all(client->bio); - } if (client->ssl) { SSL_free(client->ssl); } @@ -211,7 +208,7 @@ disconnect_client(struct gmnisrv_server *server, struct gmnisrv_client *client) static int client_init_ssl(struct gmnisrv_server *server, struct gmnisrv_client *client) { - client->ssl = tls_get_ssl(server->conf, client->sockfd); + client->ssl = tls_get_ssl(server->conf); if (!client->ssl) { client_error(&client->addr, "unable to initialize SSL, disconnecting"); @@ -219,151 +216,280 @@ client_init_ssl(struct gmnisrv_server *server, struct gmnisrv_client *client) return 1; } - int r = SSL_accept(client->ssl); - if (r != 1) { - r = SSL_get_error(client->ssl, r); - if (r == SSL_ERROR_WANT_READ || r == SSL_ERROR_WANT_WRITE) { - return 1; - } - client_error(&client->addr, "SSL accept error %s, disconnecting", - ERR_error_string(r, NULL)); - disconnect_client(server, client); - return 1; - } + client->rbio = BIO_new(BIO_s_mem()); + client->wbio = BIO_new(BIO_s_mem()); - client->sbio = BIO_new(BIO_f_ssl()); - BIO_set_ssl(client->sbio, client->ssl, 0); - client->bio = BIO_new(BIO_f_buffer()); - BIO_push(client->bio, client->sbio); + SSL_set_accept_state(client->ssl); + SSL_set_bio(client->ssl, client->rbio, client->wbio); return 0; } -enum client_state { - CLIENT_CONNECTED, - CLIENT_DISCONNECTED, +enum connection_state { + CONNECTED, + DISCONNECTED, }; -static enum client_state +static enum connection_state client_readable(struct gmnisrv_server *server, struct gmnisrv_client *client) { if (!client->ssl && client_init_ssl(server, client) != 0) { - return CLIENT_DISCONNECTED; + return DISCONNECTED; + } + + char buf[BUFSIZ]; + ssize_t n = read(client->sockfd, buf, sizeof(buf)); + if (n <= 0) { + disconnect_client(server, client); + return DISCONNECTED; } + + size_t w = 0; + while (w < (size_t)n) { + int r = BIO_write(client->rbio, &buf[w], n - w); + if (r <= 0) { + client_error(&client->addr, + "Error writing to client RBIO: %s", + ERR_error_string(r, NULL)); + disconnect_client(server, client); + return DISCONNECTED; + } + w += r; + } + + if (!SSL_is_init_finished(client->ssl)) { + int r = SSL_accept(client->ssl); + switch ((r = SSL_get_error(client->ssl, r))) { + case SSL_ERROR_NONE: + break; + case SSL_ERROR_WANT_READ: + case SSL_ERROR_WANT_WRITE: + goto queue_ssl_write; + case SSL_ERROR_SSL: + client_error(&client->addr, + "SSL accept error: %s", + ERR_error_string(ERR_get_error(), NULL)); + disconnect_client(server, client); + return DISCONNECTED; + default: + client_error(&client->addr, + "SSL accept error: %s", + ERR_error_string(r, NULL)); + disconnect_client(server, client); + return DISCONNECTED; + } + + if (!SSL_is_init_finished(client->ssl)) { + return CONNECTED; + } + } + if (!client->host) { + client_log(&client->addr, "missing client host"); const char *error = "This server requires clients to support the TLS SNI (server name identification) extension"; client_submit_response(client, GEMINI_STATUS_BAD_REQUEST, error, NULL); - return CLIENT_CONNECTED; + return CONNECTED; } - int r = BIO_gets(client->bio, client->buf, sizeof(client->buf)); - if (r <= 0) { - r = SSL_get_error(client->ssl, r); - if (r == SSL_ERROR_WANT_READ) { - return CLIENT_CONNECTED; + int r, e; + do { + if (client->bufln >= sizeof(client->buf)) { + client_log(&client->addr, "overlong"); + const char *error = "Protocol error: malformed request"; + client_submit_response(client, + GEMINI_STATUS_BAD_REQUEST, error, NULL); + return CONNECTED; } - client_error(&client->addr, "SSL read error %s, disconnecting", - ERR_error_string(r, NULL)); - disconnect_client(server, client); - return CLIENT_DISCONNECTED; - } - client->buf[r] = '\0'; + + r = SSL_read(client->ssl, + &client->buf[client->bufln], + sizeof(client->buf) - client->bufln); + + switch ((e = SSL_get_error(client->ssl, r))) { + case SSL_ERROR_NONE: + client->bufln += r; + break; + case SSL_ERROR_WANT_READ: + case SSL_ERROR_WANT_WRITE: + break; + case SSL_ERROR_SSL: + client_error(&client->addr, + "SSL read error: %s", + ERR_error_string(ERR_get_error(), NULL)); + disconnect_client(server, client); + return DISCONNECTED; + default: + client_error(&client->addr, + "SSL read error: %s", + ERR_error_string(e, NULL)); + disconnect_client(server, client); + return DISCONNECTED; + } + } while (r > 0); + + client->buf[client->bufln] = '\0'; char *newline = strstr(client->buf, "\r\n"); if (!newline) { const char *error = "Protocol error: malformed request"; - client_submit_response(client, - GEMINI_STATUS_BAD_REQUEST, error, NULL); - return CLIENT_CONNECTED; + switch (e) { + case SSL_ERROR_WANT_READ: + case SSL_ERROR_WANT_WRITE: + goto queue_ssl_write; + default: + client_submit_response(client, + GEMINI_STATUS_BAD_REQUEST, error, NULL); + return CONNECTED; + } } *newline = 0; if (!request_validate(client, &client->path)) { - return CLIENT_CONNECTED; + return CONNECTED; } serve_request(client); - return CLIENT_CONNECTED; + return CONNECTED; + +queue_ssl_write: + client->bufln = 0; + client->state = CLIENT_STATE_SSL; + client->next = CLIENT_STATE_REQUEST; + do { + assert(client->bufln < sizeof(client->buf)); + r = BIO_read(client->wbio, + &client->buf[client->bufln], + sizeof(client->buf) - client->bufln); + if (r <= 0) { + if (BIO_should_retry(client->wbio)) { + continue; + } + client_error(&client->addr, + "BIO read error: %s", + ERR_error_string(r, NULL)); + disconnect_client(server, client); + return DISCONNECTED; + } else { + client->bufln += r; + } + } while (r > 0); + client->pollfd->events = POLLOUT; + return CONNECTED; } -static enum client_state +static enum connection_state client_writable(struct gmnisrv_server *server, struct gmnisrv_client *client) { int r; ssize_t n; + char buf[BUFSIZ]; switch (client->state) { - case RESPOND_HEADER: + case CLIENT_STATE_REQUEST: + assert(0); // Invariant + case CLIENT_STATE_SSL: + assert(client->bufln > 0); + n = write(client->sockfd, client->buf, client->bufln); + if (n <= 0) { + client_log(&client->addr, "write error: %s", + strerror(errno)); + disconnect_client(server, client); + return DISCONNECTED; + } + client->bufln -= n; + if (client->bufln == 0) { + client->state = client->next; + if (client->state == CLIENT_STATE_REQUEST) { + client->pollfd->events = POLLIN; + } + } + return CONNECTED; + case CLIENT_STATE_HEADER: if (client->bufix == 0) { assert(strlen(client->meta) <= 1024); - n = snprintf(client->buf, sizeof(client->buf), + int n = snprintf(client->buf, sizeof(client->buf), "%02d %s\r\n", (int)client->status, client->meta); assert(n > 0); client->bufln = n; } - r = BIO_write(client->sbio, &client->buf[client->bufix], - client->bufln - client->bufix); - if (r <= 0) { - r = SSL_get_error(client->ssl, r); - if (r == SSL_ERROR_WANT_WRITE) { - return CLIENT_CONNECTED; - } - client->status = GEMINI_STATUS_NONE; - client_error(&client->addr, - "header write error %s, disconnecting", - ERR_error_string(r, NULL)); - disconnect_client(server, client); - return CLIENT_DISCONNECTED; - } - client->bufix += r; - if (client->bufix >= client->bufln) { - if (!client->body) { - disconnect_client(server, client); - return CLIENT_DISCONNECTED; - } else { - client->state = RESPOND_BODY; - client->bufix = client->bufln = 0; - return CLIENT_CONNECTED; - } - } break; - case RESPOND_BODY: + case CLIENT_STATE_BODY: if (client->bufix >= client->bufln) { - n = fread(client->buf, 1, sizeof(client->buf), - client->body); + int n = fread(client->buf, 1, + sizeof(client->buf), client->body); if (n == -1) { client_error(&client->addr, "Error reading response body: %s", strerror(errno)); disconnect_client(server, client); - return CLIENT_DISCONNECTED; + return DISCONNECTED; } if (n == 0) { // EOF disconnect_client(server, client); - return CLIENT_DISCONNECTED; + return DISCONNECTED; } client->bbytes += n; client->bufln = n; client->bufix = 0; } - r = BIO_write(client->sbio, &client->buf[client->bufix], + break; + } + + r = SSL_write(client->ssl, &client->buf[client->bufix], client->bufln - client->bufix); - if (r <= 0) { - r = SSL_get_error(client->ssl, r); - if (r == SSL_ERROR_WANT_WRITE) { - return CLIENT_CONNECTED; - } - client->status = GEMINI_STATUS_NONE; - client_error(&client->addr, "body write error %s, disconnecting", - ERR_error_string(r, NULL)); + if (r <= 0) { + r = SSL_get_error(client->ssl, r); + assert(r == SSL_ERROR_WANT_WRITE); // Hmm? + client->status = GEMINI_STATUS_NONE; + client_error(&client->addr, + "header write error %s, disconnecting", + ERR_error_string(r, NULL)); + disconnect_client(server, client); + return DISCONNECTED; + } + client->bufix += r; + + while (r > 0) { + r = BIO_read(client->wbio, buf, sizeof(buf)); + if (r < 0 && !BIO_should_retry(client->wbio)) { + client_error(&client->addr, + "BIO read error: %s", + ERR_error_string(r, NULL)); disconnect_client(server, client); - return CLIENT_DISCONNECTED; + return DISCONNECTED; + } + + for (int w = 0; w < r; ) { + int q = write(client->sockfd, &buf[w], r - w); + if (q < 0) { + assert(0); // TODO: handle write errors + } + w += q; } - client->bufix += r; + } + + switch (client->state) { + case CLIENT_STATE_REQUEST: + case CLIENT_STATE_SSL: + assert(0); // Invariant + case CLIENT_STATE_HEADER: + if (client->bufix >= client->bufln) { + if (!client->body) { + disconnect_client(server, client); + return DISCONNECTED; + } else { + client->state = CLIENT_STATE_BODY; + client->bufix = client->bufln = 0; + return CONNECTED; + } + } + break; + case CLIENT_STATE_BODY: break; } - return false; + + return CONNECTED; } static long @@ -450,18 +576,18 @@ server_run(struct gmnisrv_server *server) for (size_t i = 0; i < server->nclients; ++i) { int pi = i + server->nlisten; - enum client_state s = CLIENT_CONNECTED; + enum connection_state s = CONNECTED; if ((server->fds[pi].revents & (POLLHUP | POLLERR))) { disconnect_client(server, &server->clients[i]); - s = CLIENT_DISCONNECTED; + s = DISCONNECTED; } - if (s == CLIENT_CONNECTED && (server->fds[pi].revents & POLLIN)) { + if (s == CONNECTED && (server->fds[pi].revents & POLLIN)) { s = client_readable(server, &server->clients[i]); } - if (s == CLIENT_CONNECTED && (server->fds[pi].revents & POLLOUT)) { + if (s == CONNECTED && (server->fds[pi].revents & POLLOUT)) { s = client_writable(server, &server->clients[i]); } - if (s == CLIENT_DISCONNECTED) { + if (s == DISCONNECTED) { --i; } } |
