#include #include #include #include "session.h" struct { int version; char *s; } vtVersions[] = { VtVersion02, "02", 0, 0, }; static char EBigString[] = "string too long"; static char EBigPacket[] = "packet too long"; static char ENullString[] = "missing string"; static char EBadVersion[] = "bad format in version string"; static Packet *vtRPC(VtSession *z, int op, Packet *p); VtSession * vtAlloc(void) { VtSession *z; z = vtMemAllocZ(sizeof(VtSession)); z->lk = vtLockAlloc(); // z->inHash = vtSha1Alloc(); z->inLock = vtLockAlloc(); z->part = packetAlloc(); // z->outHash = vtSha1Alloc(); z->outLock = vtLockAlloc(); z->fd = -1; z->uid = vtStrDup("anonymous"); z->sid = vtStrDup("anonymous"); return z; } void vtReset(VtSession *z) { vtLock(z->lk); z->cstate = VtStateAlloc; if(z->fd >= 0){ vtFdClose(z->fd); z->fd = -1; } vtUnlock(z->lk); } int vtConnected(VtSession *z) { return z->cstate == VtStateConnected; } void vtDisconnect(VtSession *z, int error) { Packet *p; uchar *b; vtDebug(z, "vtDisconnect\n"); vtLock(z->lk); if(z->cstate == VtStateConnected && !error && z->vtbl == nil) { /* clean shutdown */ p = packetAlloc(); b = packetHeader(p, 2); b[0] = VtQGoodbye; b[1] = 0; vtSendPacket(z, p); } if(z->fd >= 0) vtFdClose(z->fd); z->fd = -1; z->cstate = VtStateClosed; vtUnlock(z->lk); } void vtClose(VtSession *z) { vtDisconnect(z, 0); } void vtFree(VtSession *z) { if(z == nil) return; vtLockFree(z->lk); vtSha1Free(z->inHash); vtLockFree(z->inLock); packetFree(z->part); vtSha1Free(z->outHash); vtLockFree(z->outLock); vtMemFree(z->uid); vtMemFree(z->sid); vtMemFree(z->vtbl); memset(z, 0, sizeof(VtSession)); z->fd = -1; vtMemFree(z); } char * vtGetUid(VtSession *s) { return s->uid; } char * vtGetSid(VtSession *z) { return z->sid; } int vtSetDebug(VtSession *z, int debug) { int old; vtLock(z->lk); old = z->debug; z->debug = debug; vtUnlock(z->lk); return old; } int vtSetFd(VtSession *z, int fd) { vtLock(z->lk); if(z->cstate != VtStateAlloc) { vtSetError("bad state"); vtUnlock(z->lk); return 0; } if(z->fd >= 0) vtFdClose(z->fd); z->fd = fd; vtUnlock(z->lk); return 1; } int vtGetFd(VtSession *z) { return z->fd; } int vtSetCryptoStrength(VtSession *z, int c) { if(z->cstate != VtStateAlloc) { vtSetError("bad state"); return 0; } if(c != VtCryptoStrengthNone) { vtSetError("not supported yet"); return 0; } return 1; } int vtGetCryptoStrength(VtSession *s) { return s->cryptoStrength; } int vtSetCompression(VtSession *z, int fd) { vtLock(z->lk); if(z->cstate != VtStateAlloc) { vtSetError("bad state"); vtUnlock(z->lk); return 0; } z->fd = fd; vtUnlock(z->lk); return 1; } int vtGetCompression(VtSession *s) { return s->compression; } int vtGetCrypto(VtSession *s) { return s->crypto; } int vtGetCodec(VtSession *s) { return s->codec; } char * vtGetVersion(VtSession *z) { int v, i; v = z->version; if(v == 0) return "unknown"; for(i=0; vtVersions[i].version; i++) if(vtVersions[i].version == v) return vtVersions[i].s; assert(0); return 0; } /* hold z->inLock */ static int vtVersionRead(VtSession *z, char *prefix, int *ret) { char c; char buf[VtMaxStringSize]; char *q, *p, *pp; int i; q = prefix; p = buf; for(;;) { if(p >= buf + sizeof(buf)) { vtSetError(EBadVersion); return 0; } if(!vtFdReadFully(z->fd, (uchar*)&c, 1)) return 0; if(z->inHash) vtSha1Update(z->inHash, (uchar*)&c, 1); if(c == '\n') { *p = 0; break; } if(c < ' ' || *q && c != *q) { vtSetError(EBadVersion); return 0; } *p++ = c; if(*q) q++; } vtDebug(z, "version string in: %s\n", buf); p = buf + strlen(prefix); for(;;) { for(pp=p; *pp && *pp != ':' && *pp != '-'; pp++) ; for(i=0; vtVersions[i].version; i++) { if(strlen(vtVersions[i].s) != pp-p) continue; if(memcmp(vtVersions[i].s, p, pp-p) == 0) { *ret = vtVersions[i].version; return 1; } } p = pp; if(*p != ':') return 0; p++; } } Packet* vtRecvPacket(VtSession *z) { uchar buf[10], *b; int n; Packet *p; int size, len; if(z->cstate != VtStateConnected) { vtSetError("session not connected"); return 0; } vtLock(z->inLock); p = z->part; /* get enough for head size */ size = packetSize(p); while(size < 2) { b = packetTrailer(p, MaxFragSize); assert(b != nil); n = vtFdRead(z->fd, b, MaxFragSize); if(n <= 0) goto Err; size += n; packetTrim(p, 0, size); } if(!packetConsume(p, buf, 2)) goto Err; len = (buf[0] << 8) | buf[1]; size -= 2; while(size < len) { n = len - size; if(n > MaxFragSize) n = MaxFragSize; b = packetTrailer(p, n); if(!vtFdReadFully(z->fd, b, n)) goto Err; size += n; } p = packetSplit(p, len); vtUnlock(z->inLock); return p; Err: vtUnlock(z->inLock); return nil; } int vtSendPacket(VtSession *z, Packet *p) { IOchunk ioc; int n; uchar buf[2]; /* add framing */ n = packetSize(p); if(n >= (1<<16)) { vtSetError(EBigPacket); packetFree(p); return 0; } buf[0] = n>>8; buf[1] = n; packetPrefix(p, buf, 2); for(;;) { n = packetFragments(p, &ioc, 1, 0); if(n == 0) break; if(!vtFdWrite(z->fd, ioc.addr, ioc.len)) { packetFree(p); return 0; } packetConsume(p, nil, n); } packetFree(p); return 1; } int vtGetString(Packet *p, char **ret) { uchar buf[2]; int n; char *s; if(!packetConsume(p, buf, 2)) return 0; n = (buf[0]<<8) + buf[1]; if(n > VtMaxStringSize) { vtSetError(EBigString); return 0; } s = vtMemAlloc(n+1); setmalloctag(s, getcallerpc(&p)); if(!packetConsume(p, (uchar*)s, n)) { vtMemFree(s); return 0; } s[n] = 0; *ret = s; return 1; } int vtAddString(Packet *p, char *s) { uchar buf[2]; int n; if(s == nil) { vtSetError(ENullString); return 0; } n = strlen(s); if(n > VtMaxStringSize) { vtSetError(EBigString); return 0; } buf[0] = n>>8; buf[1] = n; packetAppend(p, buf, 2); packetAppend(p, (uchar*)s, n); return 1; } int vtConnect(VtSession *z, char *password) { char buf[VtMaxStringSize], *p, *ep, *prefix; int i; USED(password); vtLock(z->lk); if(z->cstate != VtStateAlloc) { vtSetError("bad session state"); vtUnlock(z->lk); return 0; } if(z->fd < 0){ vtSetError("%s", z->fderror); vtUnlock(z->lk); return 0; } /* be a little anal */ vtLock(z->inLock); vtLock(z->outLock); prefix = "venti-"; p = buf; ep = buf + sizeof(buf); p = seprint(p, ep, "%s", prefix); p += strlen(p); for(i=0; vtVersions[i].version; i++) { if(i != 0) *p++ = ':'; p = seprint(p, ep, "%s", vtVersions[i].s); } p = seprint(p, ep, "-libventi\n"); assert(p-buf < sizeof(buf)); if(z->outHash) vtSha1Update(z->outHash, (uchar*)buf, p-buf); if(!vtFdWrite(z->fd, (uchar*)buf, p-buf)) goto Err; vtDebug(z, "version string out: %s", buf); if(!vtVersionRead(z, prefix, &z->version)) goto Err; vtDebug(z, "version = %d: %s\n", z->version, vtGetVersion(z)); vtUnlock(z->inLock); vtUnlock(z->outLock); z->cstate = VtStateConnected; vtUnlock(z->lk); if(z->vtbl) return 1; if(!vtHello(z)) goto Err; return 1; Err: if(z->fd >= 0) vtFdClose(z->fd); z->fd = -1; vtUnlock(z->inLock); vtUnlock(z->outLock); z->cstate = VtStateClosed; vtUnlock(z->lk); return 0; }