From a25171a94baf72c6e4d87c4fe3b4f7ac1e29a5d6 Mon Sep 17 00:00:00 2001 From: mikeos Date: Fri, 1 Feb 2013 00:46:24 +0400 Subject: add support of unix sockets --- main.c | 153 +++++++++++++++++++++++++++++++++++++++++++++++++---------------- 1 file changed, 115 insertions(+), 38 deletions(-) diff --git a/main.c b/main.c index 1ebc4c7..db32afa 100644 --- a/main.c +++ b/main.c @@ -34,6 +34,11 @@ #define bputconst(b, s) bputs(b, s, sizeof(s) - 1) +typedef enum { + UNIX, + TCP +} socket_type_t; + LIST_HEAD(channel_list, channel) channels; SLIST_HEAD(server_list, server) servers; @@ -44,7 +49,8 @@ struct server { int sock; int local_port; char *remote_addr; - struct sockaddr_in remote_sa; + socket_type_t type; + struct sockaddr_storage remote_sa; int nchannel; SLIST_ENTRY(server) entries; @@ -103,12 +109,14 @@ static void on_quit(void); static void makeaddr(struct sockaddr_in *sa, const char *host, const char *port); static int portbyname(const char *name); -static char *addrstr(struct sockaddr_in *sa); +static char *addrstr(struct sockaddr_storage *ss); static struct channel *chan_new(struct server *s, char *addr, int fd1, int fd2); static void chan_free(struct channel *chan); -static void serv_add_new(char *local_port, char *host, char *port); +static void server_add_tcp(char *local_port, char *host, char *port); +static void server_add_unix(char *local_port, char *path); +static void server_add_new(char *local_port, socket_type_t type, struct sockaddr_storage *ss); static void endless_loop(int unix_sock); @@ -125,24 +133,47 @@ static void read_config_file(char *name) sys_err("fopen %s", name); while (getline(&line, &len, fp) >= 0) { - char *local, *host, *port; + char *local, *type; + + if ((local = strtok(line, delim)) == NULL) + goto invalid; + + if ((type = strtok(NULL, delim)) == NULL) + goto invalid; + + if (! strcmp(type, "unix")) { + char *path = strtok(NULL, delim); + + if (path == NULL) + goto invalid; - local = host = port = NULL; + server_add_unix(local, path); + } else { + char *host, *port = NULL; - if ((local = strtok(line, delim)) != NULL) { - if ((host = strtok(NULL, delim)) != NULL) + if (! strcmp(type, "tcp")) + host = strtok(NULL, delim); + else + host = type; + + if (host != NULL) port = strtok(NULL, delim); - } - if (port == NULL) - err_quit("%s: invalid line %d", name, nline); + if (port == NULL) + goto invalid; + + server_add_tcp(local, host, port); + } - serv_add_new(local, host, port); nline++; } fclose(fp); free(line); + return; + +invalid: + err_quit("%s: invalid line %d", name, nline); } static void usage(void) @@ -277,7 +308,31 @@ static void chan_free(struct channel *chan) s->nchannel--; } -static void serv_add_new(char *local_port, char *host, char *port) +static void server_add_tcp(char *local_port, char *host, char *port) +{ + struct sockaddr_storage storage; + + memset(&storage, 0, sizeof(storage)); + makeaddr((struct sockaddr_in *) &storage, host, port); + + server_add_new(local_port, TCP, &storage); +} + +static void server_add_unix(char *local_port, char *path) +{ + union { + struct sockaddr_storage storage; + struct sockaddr_un un; + } sa; + + memset(&sa, 0, sizeof(sa)); + sa.un.sun_family = AF_UNIX; + strcpy(sa.un.sun_path, path); + + server_add_new(local_port, UNIX, &sa.storage); +} + +static void server_add_new(char *local_port, socket_type_t type, struct sockaddr_storage *ss) { struct sockaddr_in sa; struct server *s; @@ -285,10 +340,10 @@ static void serv_add_new(char *local_port, char *host, char *port) s = (struct server *) xmalloc(sizeof(*s)); s->nchannel = 0; - makeaddr(&s->remote_sa, host, port); + s->type = type; + memcpy(&s->remote_sa, ss, sizeof(struct sockaddr_storage)); s->remote_addr = addrstr(&s->remote_sa); - memset(&sa, 0, sizeof(sa)); makeaddr(&sa, "0.0.0.0", local_port); s->local_port = ntohs(sa.sin_port); s->sock = tcp_server((struct sockaddr *) &sa, sizeof(sa)); @@ -353,12 +408,24 @@ static void makeaddr(struct sockaddr_in *sa, const char *host, const char *port) sa->sin_port = portbyname(port); } -static char *addrstr(struct sockaddr_in *sa) +static char *addrstr(struct sockaddr_storage *ss) { char buf[512]; - snprintf(buf, sizeof(buf), "%s:%d", inet_ntoa(sa->sin_addr), - ntohs(sa->sin_port)); + switch (ss->ss_family) { + case AF_UNIX: { + struct sockaddr_un *sa = (struct sockaddr_un *) ss; + strcpy(buf, sa->sun_path); + } break; + case AF_INET: { + struct sockaddr_in *sa = (struct sockaddr_in *) ss; + snprintf(buf, sizeof(buf), "%s:%d", + inet_ntoa(sa->sin_addr), ntohs(sa->sin_port)); + } break; + default: + strcpy(buf, "unknown"); + } + return xstrdup(buf); } @@ -394,6 +461,7 @@ static inline int unix_socket(void) sock = socket(PF_UNIX, SOCK_STREAM, 0); if (sock < 0) sys_err("socket PF_UNIX"); + set_nonblock(sock); return sock; } @@ -776,21 +844,28 @@ static void test_chans(fd_set *readfds, fd_set *writefds, fd_set *exceptfds) static void accept_client(struct server *server, int serv_sock) { struct channel *chan; - struct sockaddr_in sa; - socklen_t salen; - int fd1, fd2; + struct sockaddr_storage ss; + socklen_t salen, addrlen = 0; + int fd1, fd2 = -1; int ret; - salen = sizeof(sa); - fd1 = accept(serv_sock, (struct sockaddr *) &sa, &salen); + salen = sizeof(ss); + fd1 = accept(serv_sock, (struct sockaddr *) &ss, &salen); if (fd1 < 0) { if (errno != EAGAIN) msg_err("accept"); return; } - fd2 = tcp_socket(); - ret = connect(fd2, (struct sockaddr *) &server->remote_sa, sizeof(struct sockaddr_in)); + if (server->type == TCP) { + fd2 = tcp_socket(); + addrlen = sizeof(struct sockaddr_in); + } else if (server->type == UNIX) { + fd2 = unix_socket(); + addrlen = sizeof(struct sockaddr_un); + } + + ret = connect(fd2, (struct sockaddr *) &server->remote_sa, addrlen); if (ret < 0 && errno != EINPROGRESS) { close(fd1); close(fd2); @@ -799,7 +874,7 @@ static void accept_client(struct server *server, int serv_sock) return; } - chan = chan_new(server, addrstr(&sa), fd1, fd2); + chan = chan_new(server, addrstr(&ss), fd1, fd2); if (ret == 0) chan->connected++; @@ -844,24 +919,26 @@ static void accept_unix(int unix_sock) static void endless_loop(int unix_sock) { - fd_set rd, wr, er; - int nfds; - int ret; + fd_set init_rd; + int init_nfds = -1; + + FD_ZERO(&init_rd); + if (unix_sock >= 0) { + FD_SET(unix_sock, &init_rd); + init_nfds = unix_sock; + } + + init_nfds = max(add_servs(&init_rd), init_nfds); for (;;) { - FD_ZERO(&rd); + fd_set rd, wr, er; + int nfds, ret; + + memcpy(&rd, &init_rd, sizeof(fd_set)); FD_ZERO(&wr); FD_ZERO(&er); - if (unix_sock < 0) - nfds = -1; - else { - FD_SET(unix_sock, &rd); - nfds = unix_sock; - } - - nfds = max(add_servs(&rd), nfds); - nfds = max(add_chans(&rd, &wr, &er), nfds); + nfds = max(add_chans(&rd, &wr, &er), init_nfds); if (nfds < 0) break; -- cgit v1.2.3-70-g09d2