#include #include #include #include #include #define NS(x) ((vlong)x) #define US(x) (NS(x) * 1000LL) #define MS(x) (US(x) * 1000LL) #define S(x) (MS(x) * 1000LL) #define LOGNAME "aan" enum { Synctime = S(8), Nbuf = 10, K = 1024, Bufsize = 8 * K, Stacksize = 8 * K, Timer = 0, /* Alt channels. */ Unsent = 1, Maxto = 24 * 3600, /* A full day to reconnect. */ Hdrsz = 12, }; typedef struct Endpoints Endpoints; struct Endpoints { char *lsys; char *lserv; char *rsys; char *rserv; }; typedef struct { ulong nb; /* Number of data bytes in this message */ ulong msg; /* Message number */ ulong acked; /* Number of messages acked */ } Hdr; typedef struct { Hdr hdr; uchar buf[Bufsize]; } Buf; static char *Logname = LOGNAME; static int client; static int debug; static char *devdir; static char *dialstring; static int done; static int inmsg; static int maxto = Maxto; static int netfd; static Channel *empty; static Channel *unacked; static Channel *unsent; static Alt a[] = { /* c v op */ { nil, nil, CHANRCV }, /* timer */ { nil, nil, CHANRCV }, /* unsent */ { nil, nil, CHANEND }, }; static void dmessage(int, char *, ...); static void freeendpoints(Endpoints *); static void fromclient(void*); static void fromnet(void*); static Endpoints *getendpoints(char *); static void packhdr(Hdr *, uchar *); static void reconnect(void); static void showmsg(int, char *, Buf *); static void synchronize(void); static void timerproc(void *); static void unpackhdr(Hdr *, uchar *); static int writen(int, uchar *, int); static void usage(void) { fprint(2, "Usage: %s [-cd] [-m maxto] dialstring|netdir\n", argv0); threadexitsall("usage"); } static int catch(void *, char *s) { if (strstr(s, "alarm") != nil) { syslog(0, Logname, "Timed out waiting for client on %s, exiting...", devdir); threadexitsall(nil); } return 0; } void threadmain(int argc, char **argv) { int i, fd, failed, delta; vlong synctime, now; char *p; uchar buf[Hdrsz]; Buf *b, *eb; Channel *timer; Hdr hdr; ARGBEGIN { case 'c': client++; break; case 'd': debug++; break; case 'm': maxto = strtol(EARGF(usage()), (char **)nil, 0); break; default: usage(); } ARGEND; if (argc != 1) usage(); if (!client) { devdir = argv[0]; if ((p = strstr(devdir, "/local")) != nil) *p = '\0'; }else dialstring = argv[0]; if (debug > 0) { fd = open("#c/cons", OWRITE|OCEXEC); dup(fd, 2); } fmtinstall('F', fcallfmt); atnotify(catch, 1); unsent = chancreate(sizeof(Buf *), Nbuf); unacked= chancreate(sizeof(Buf *), Nbuf); empty = chancreate(sizeof(Buf *), Nbuf); timer = chancreate(sizeof(uchar *), 1); for (i = 0; i != Nbuf; i++) { eb = malloc(sizeof(Buf)); sendp(empty, eb); } netfd = -1; if (proccreate(fromnet, nil, Stacksize) < 0) sysfatal("Cannot start fromnet; %r"); reconnect(); /* Set up the initial connection. */ synchronize(); if (proccreate(fromclient, nil, Stacksize) < 0) sysfatal("cannot start fromclient; %r"); if (proccreate(timerproc, timer, Stacksize) < 0) sysfatal("Cannot start timerproc; %r"); a[Timer].c = timer; a[Unsent].c = unsent; a[Unsent].v = &b; synctime = nsec() + Synctime; failed = 0; while (!done) { if (failed) { /* Wait for the netreader to die. */ while (netfd >= 0) { dmessage(1, "main; waiting for netreader to die\n"); sleep(1000); } /* the reader died; reestablish the world. */ reconnect(); synchronize(); failed = 0; } now = nsec(); delta = (synctime - nsec()) / MS(1); if (delta <= 0) { hdr.nb = 0; hdr.acked = inmsg; hdr.msg = -1; packhdr(&hdr, buf); if (writen(netfd, buf, sizeof(buf)) < 0) { dmessage(2, "main; writen failed; %r\n"); failed = 1; continue; } synctime = nsec() + Synctime; assert(synctime > now); } switch (alt(a)) { case Timer: break; case Unsent: sendp(unacked, b); b->hdr.acked = inmsg; packhdr(&b->hdr, buf); if (writen(netfd, buf, sizeof(buf)) < 0 || writen(netfd, b->buf, b->hdr.nb) < 0) { dmessage(2, "main; writen failed; %r\n"); failed = 1; } if (b->hdr.nb == 0) done = 1; break; } } syslog(0, Logname, "exiting..."); threadexitsall(nil); } static void fromclient(void*) { Buf *b; static int outmsg; do { b = recvp(empty); if ((int)(b->hdr.nb = read(0, b->buf, Bufsize)) <= 0) { if ((int)b->hdr.nb < 0) dmessage(2, "fromclient; Cannot read 9P message; %r\n"); else dmessage(2, "fromclient; Client terminated\n"); b->hdr.nb = 0; } b->hdr.msg = outmsg++; showmsg(1, "fromclient", b); sendp(unsent, b); } while (b->hdr.nb != 0); } static void fromnet(void*) { int len, acked, i; uchar buf[Hdrsz]; Buf *b, *rb; static int lastacked; b = (Buf *)malloc(sizeof(Buf)); assert(b); while (!done) { while (netfd < 0) { dmessage(1, "fromnet; waiting for connection... (inmsg %d)\n", inmsg); sleep(1000); } /* Read the header. */ if ((len = readn(netfd, buf, sizeof(buf))) <= 0) { if (len < 0) dmessage(1, "fromnet; (hdr) network failure; %r\n"); else dmessage(1, "fromnet; (hdr) network closed\n"); close(netfd); netfd = -1; continue; } unpackhdr(&b->hdr, buf); dmessage(2, "fromnet: Got message, size %d, nb %d, msg %d\n", len, b->hdr.nb, b->hdr.msg); if (b->hdr.nb == 0) { if ((long)b->hdr.msg >= 0) { dmessage(1, "fromnet; network closed\n"); break; } continue; } if ((len = readn(netfd, b->buf, b->hdr.nb)) <= 0 || len != b->hdr.nb) { if (len == 0) dmessage(1, "fromnet; network closed\n"); else dmessage(1, "fromnet; network failure; %r\n"); close(netfd); netfd = -1; continue; } if (b->hdr.msg < inmsg) { dmessage(1, "fromnet; skipping message %d, currently at %d\n", b->hdr.msg, inmsg); continue; } /* Process the acked list. */ acked = b->hdr.acked - lastacked; for (i = 0; i != acked; i++) { rb = recvp(unacked); if (rb->hdr.msg != lastacked + i) { dmessage(1, "rb %p, msg %d, lastacked %d, i %d\n", rb, rb? rb->hdr.msg: -2, lastacked, i); assert(0); } rb->hdr.msg = -1; sendp(empty, rb); } lastacked = b->hdr.acked; inmsg++; showmsg(1, "fromnet", b); if (writen(1, b->buf, len) < 0) sysfatal("fromnet; cannot write to client; %r"); } done = 1; } static void reconnect(void) { char err[32], ldir[40]; int lcfd, fd; Endpoints *ep; if (dialstring) { syslog(0, Logname, "dialing %s", dialstring); while ((fd = dial(dialstring, nil, nil, nil)) < 0) { err[0] = '\0'; errstr(err, sizeof err); if (strstr(err, "connection refused")) { dmessage(1, "reconnect; server died...\n"); threadexitsall("server died..."); } dmessage(1, "reconnect: dialed %s; %s\n", dialstring, err); sleep(1000); } syslog(0, Logname, "reconnected to %s", dialstring); } else { syslog(0, Logname, "waiting for connection on %s", devdir); alarm(maxto * 1000); if ((lcfd = listen(devdir, ldir)) < 0) sysfatal("reconnect; cannot listen; %r"); if ((fd = accept(lcfd, ldir)) < 0) sysfatal("reconnect; cannot accept; %r"); alarm(0); close(lcfd); ep = getendpoints(ldir); dmessage(1, "rsys '%s'\n", ep->rsys); syslog(0, Logname, "connected from %s", ep->rsys); freeendpoints(ep); } netfd = fd; /* Wakes up the netreader. */ } static void synchronize(void) { Channel *tmp; Buf *b; uchar buf[Hdrsz]; /* * Ignore network errors here. If we fail during * synchronization, the next alarm will pick up * the error. */ tmp = chancreate(sizeof(Buf *), Nbuf); while ((b = nbrecvp(unacked)) != nil) { packhdr(&b->hdr, buf); writen(netfd, buf, sizeof(buf)); writen(netfd, b->buf, b->hdr.nb); sendp(tmp, b); } chanfree(unacked); unacked = tmp; } static void showmsg(int level, char *s, Buf *b) { if (b == nil) { dmessage(level, "%s; b == nil\n", s); return; } dmessage(level, "%s; (len %d) %X %X %X %X %X %X %X %X %X (%p)\n", s, b->hdr.nb, b->buf[0], b->buf[1], b->buf[2], b->buf[3], b->buf[4], b->buf[5], b->buf[6], b->buf[7], b->buf[8], b); } static int writen(int fd, uchar *buf, int nb) { int n, len = nb; while (nb > 0) { if (fd < 0) return -1; if ((n = write(fd, buf, nb)) < 0) { dmessage(1, "writen; Write failed; %r\n"); return -1; } dmessage(2, "writen: wrote %d bytes\n", n); buf += n; nb -= n; } return len; } static void timerproc(void *x) { Channel *timer = x; while (!done) { sleep((Synctime / MS(1)) >> 1); sendp(timer, "timer"); } } static void dmessage(int level, char *fmt, ...) { va_list arg; if (level > debug) return; va_start(arg, fmt); vfprint(2, fmt, arg); va_end(arg); } static void getendpoint(char *dir, char *file, char **sysp, char **servp) { int fd, n; char buf[128]; char *sys, *serv; sys = serv = 0; snprint(buf, sizeof buf, "%s/%s", dir, file); fd = open(buf, OREAD); if(fd >= 0){ n = read(fd, buf, sizeof(buf)-1); if(n>0){ buf[n-1] = 0; serv = strchr(buf, '!'); if(serv){ *serv++ = 0; serv = strdup(serv); } sys = strdup(buf); } close(fd); } if(serv == 0) serv = strdup("unknown"); if(sys == 0) sys = strdup("unknown"); *servp = serv; *sysp = sys; } static Endpoints * getendpoints(char *dir) { Endpoints *ep; ep = malloc(sizeof(*ep)); getendpoint(dir, "local", &ep->lsys, &ep->lserv); getendpoint(dir, "remote", &ep->rsys, &ep->rserv); return ep; } static void freeendpoints(Endpoints *ep) { free(ep->lsys); free(ep->rsys); free(ep->lserv); free(ep->rserv); free(ep); } /* p must be a uchar* */ #define U32GET(p) (p[0] | p[1]<<8 | p[2]<<16 | p[3]<<24) #define U32PUT(p,v) (p)[0] = (v); (p)[1] = (v)>>8; \ (p)[2] = (v)>>16; (p)[3] = (v)>>24 static void packhdr(Hdr *hdr, uchar *buf) { uchar *p; p = buf; U32PUT(p, hdr->nb); p += 4; U32PUT(p, hdr->msg); p += 4; U32PUT(p, hdr->acked); } static void unpackhdr(Hdr *hdr, uchar *buf) { uchar *p; p = buf; hdr->nb = U32GET(p); p += 4; hdr->msg = U32GET(p); p += 4; hdr->acked = U32GET(p); }