Plan 9 from Bell Labs’s /usr/web/sources/contrib/axel/8021x/v03/ttls.c

Copyright © 2021 Plan 9 Foundation.
Distributed under the MIT License.
Download the Plan 9 distribution.


#include <u.h>
#include <libc.h>
#include <thread.h>
#include <bio.h>
#include <ip.h>
#include <mp.h>
#include <libsec.h>
#include "dat.h"
#include "fns.h"


/*
 * EAP-TTLS header as it appears at the start of a frame buffer
 * (overlaid by casting the buffer to TTLS*):
 *	tp	EAP method type (EapTpTtls)
 *	flags	L/M/S flag bits plus version bits (see enum below)
 *	tln	total message length in network byte order (hnputl/nhgetl),
 *		only present when the L flag is set
 * All fields are uchar, so the struct has no padding and matches the wire.
 */
typedef struct TTLS {
	uchar tp;
	uchar flags;
	uchar tln[4];	//optional, present if L flag set
} TTLS;

/*
 * EAP-TTLS flag bits, header sizes, and the states of the
 * fragmentation/reassembly state machine implemented by ttls().
 */
enum {
	TtlsFlagL		= 1<<7,	// header contains tln field
	TtlsFlagM		= 1<<6,	// more fragment(s) will follow for current msg
	TtlsFlagS		= 1<<5,	// start of tls session
	TtlsVersion	= (1<<2)|(1<<1)|(1<<0),	// mask selecting the version bits of flags

	TtlsShortHlen	= 2,	// without tln field
	TtlsLongHlen	= TtlsShortHlen+4, // with tln field

	// ttls() states; see snames[] for printable names
	Idle = 0,	// no session active
	Start,		// start packet seen: set up a new tls session
	Waiting,	// wait for tlsClient result or for data read from the tls pipe
	Sending,	// send (a fragment of) the message read from the tls pipe
	RecvAck,	// fragment sent; next incoming frame should be an ack
	Receiving,	// accumulate (fragments of) a message from the peer
	SendAck,	// acknowledge a received fragment
	Received,	// complete message reassembled: write it to the tls pipe
};

/* Printable names of the ttls() states, indexed by state; used in debug output. */
char *snames[] = {
[Idle]			"Idle",
[Start]		"Start",
[Waiting]		"Waiting",
[Sending]		"Sending",
[RecvAck]		"RecvAck",
[Receiving]	"Receiving",
[SendAck]		"SendAck",
[Received]		"Received",
};

/*
 * All state of the (single) EAP-TTLS client: the tls pipe shared with
 * tlsClient, the channels linking ttls() with readproc and clientproc,
 * and the bookkeeping for sending and receiving fragmented messages.
 */
typedef struct TTLSstate {
	TLSconn tlsconn;	// our handle to the tls connection
	int tlspipe[2];	// double pipe over which we talk with our tls
				// the stuff we read from it has to be fragmented, encapsulated and sent
				// the fragments we receive have to be reassembled and then written to it
	Channel *tlsfdc;	// clientproc sends the fd returned by tlsClient (-1 on failure)
	Channel *readc;	// readproc sends the rbuf index holding the last msg read from tlspipe
	Channel *eofc;		// readproc confirms eof/error on tlspipe (received by cleanup)
	Channel *startclientc;	// created by initTTLS; no other use visible in this file
	Channel *startreadc;		// setupTls sends nonzero to (re)start a readproc session
	int tlsfd;	// fd from tlsClient, -1 when none

	int ttslTxLen;	// length of frame we prepared for sending (returned by processTTLS)
	int ttlsDone;	// done processing the frame (and, if needed, preparing the response)?
	int ttlsState;	// ttls state we are in
	uint ttlsVersion;	// version bits from the start packet

	Buf rbuf[Nrbuf];	// msg read from the tls pipe, to be sent (possibly in fragments)
	int ridx;	// index of first free rbuf
	int sendT;	// total length of msg to be sent
	uint sendL;	// length remaining to be sent
	uchar*sendP;	// pointer in rbuf[...] pointing to stuff remaining to be sent
	int sendS;	// still have to send first frame (fragment) for current msg?

	Buf wbuf;		// receive buffer in which we reassemble fragments, and then write to tlspipe
	uint recvT;	// total length we want to receive (and reassemble); 0 if not announced
	uint recvL;	// length received (and reassembled) so far
	uchar*recvP;	// first free position (reassembly insert point) in recv buffer

	Thumbprint *thumbTable;	// accepted server cert thumbprints, nil if none configured

	int inuse;	// set once setupTls has run; cleanup does nothing before that
	int clientid;	// thread id of a running clientproc, 0 if none

	// the following are not referenced in the code visible in this file
	uchar*theSessionCert;
	int theSessionCertlen;
	uchar* theSessionID;
	int theSessionIDlen;

} TTLSstate;

/* The single client state; initialised by initTTLS. */
static TTLSstate theTTLSstate;
/* Scratch buffer for fatal-error messages passed to threadexitsall. */
static char errbuf[256];

/*
 * cleanup tears down the previous tls session; processTTLS calls it
 * when a start (S flag) packet arrives.
 *
 * It does nothing unless setupTls has run at least once (s->inuse).
 * Otherwise it closes tlsfd and tlspipe[0] and kills clientproc if it is
 * still running.  Unless every fd was already closed on entry, it then
 * waits until readproc confirms eof on eofc, discarding anything still
 * arriving on readc.  Finally it closes tlspipe[1].
 *
 * NOTE(review): inuse is never cleared here, so after the first session
 * only the "all fds closed" shortcut avoids waiting for readproc.
 */
static void
cleanup(TTLSstate* s)
{
	int idx, readdone;
	Alt a[] = {
	/*	 c			v		op   */
		{s->readc,	&idx,	CHANRCV},
		{s->eofc,		nil,		CHANRCV},
		{nil,			nil,		CHANEND},
	};

	syslog(0, logname, "cleanup pre tlsfd=%d tlspipe[0]=%d  tlspipe[1]=%d", s->tlsfd, s->tlspipe[0], s->tlspipe[1]);
	if (!s->inuse)
		return;

	readdone = 0;

	// nothing open: readproc has no session to finish, so don't wait for it
	if (s->tlsfd < 0 && s->tlspipe[0] < 0 && s->tlspipe[1] < 0)
		readdone = 1;
	if (s->tlsfd >= 0) {
		syslog(0, logname, "\tcleanup: closing tlsfd: %d", s->tlsfd);
		close(s->tlsfd); // expected to make devtls close its side of s->tlspipe[1], causing eof on s->tlspipe[0] in readproc
		s->tlsfd = -1;
	}
	if (s->tlspipe[0] >= 0) {
		syslog(0, logname, "\tcleanup: closing tlspipe[0]: %d", s->tlspipe[0]);
		close(s->tlspipe[0]);
		s->tlspipe[0] = -1;
	}

	// clientproc may still be blocked inside tlsClient
	if (s->clientid != 0)
		threadkill(s->clientid);
	s->clientid = 0;

	syslog(0, logname, "\tcleanup middle readdone=%d tlsfd=%d tlspipe[0]=%d  tlspipe[1]=%d", readdone, s->tlsfd, s->tlspipe[0], s->tlspipe[1]);

	// drain readc (readproc blocks in sendul otherwise) until readproc reports eof
	while(!readdone) {
		syslog(0, logname, "\tcleanup receiving...");
		switch(alt(a)){
		case 0:
			syslog(0, logname, "\t\toops... cleanup recv from readc: %d", idx);
			// data from the old session is dropped; possibly a tls close alert.
			// open question: should it be forwarded to the peer instead?
			break;
		case 1:
			syslog(0, logname, "\t\tcleanup: confirmed eof from readproc");
			readdone =  1;
			break;
		}
	}
	// close our tlsClient-side end only after readproc is done with the pipe
	if (s->tlspipe[1] >= 0) {
		syslog(0, logname, "\tcleanup: closing tlspipe[1]: %d", s->tlspipe[1]);
		close(s->tlspipe[1]);
		s->tlspipe[1] = -1;
	}
	syslog(0, logname, "cleanup post tlsfd=%d tlspipe[0]=%d  tlspipe[1]=%d", s->tlsfd, s->tlspipe[0], s->tlspipe[1]);
}

/*
 * readproc forwards data that tlsClient writes into the tls pipe to ttls().
 *
 * Each nonzero value received on startreadc (sent by setupTls) starts a
 * session.  During a session it reads messages from tlspipe[0] into the
 * next rbuf slot and sends that slot's index on readc.  On eof, a read
 * error, or a pipe that has been closed, it signals eofc (consumed by
 * cleanup) and waits for the next start.  A zero on startreadc makes it exit.
 *
 * readc is unbuffered (initTTLS), so ridx advances only after ttls() or
 * cleanup has accepted the index.  NOTE(review): a slot is reused after
 * Nrbuf further reads; this assumes ttls() has finished sending from it by then.
 */
static void
readproc(void *arg)
{
	TTLSstate *s;
	Buf *r;

	s = arg;

	syslog(0, logname, "readproc starts: %d", threadid());
	while(recvul(s->startreadc)) {
		syslog(0, logname, "readproc monitoring pipe: %d", s->tlspipe[0]);
		for(;;) {
			if (s->tlspipe[0] < 0) {
				syslog(0, logname, "readproc pipe not active: %d", s->tlspipe[0]);
				break;
			}
			r = &s->rbuf[s->ridx];
			r->n = read(s->tlspipe[0], r->b, Nbuf);
			syslog(0, logname, "readproc read from %d:%d", s->tlspipe[0], r->n);
			if (r->n <= 0) {
				syslog(0, logname, "readproc eof on pipe: %d", s->tlspipe[0]);
				break;
			}
			// cleanup/abortTTLS may have closed the pipe while we were blocked in read
			if (s->tlspipe[0] < 0) {
				syslog(0, logname, "readproc pipe no longer active: %d", s->tlspipe[0]);
				break;
			}
//			syslog(0, logname, "readproc sending...");
			sendul(s->readc, s->ridx);
			s->ridx = (s->ridx+1)%Nrbuf;
		}
		syslog(0, logname, "readproc sending eofc: %d", s->tlspipe[0]);
		sendul(s->eofc, 0);
		syslog(0, logname, "readproc restarts: %d", s->tlspipe[0]);
	}
	syslog(0, logname, "readproc exits: %d", threadid());
	threadexits(nil);
}

static void
clientproc(void *arg)
{
	TTLSstate *s;
	int fd;
	uchar hash[SHA1dlen];

	s = arg;
	syslog(0, logname, "clientproc starts: %d", threadid());
	s->clientid = threadid();

	syslog(0, logname, "clientproc (re)starting: tlspipe[1]=%d", s->tlspipe[1]);
	if (s->tlspipe[1] <= 0) {
		snprint(errbuf, sizeof(errbuf), "clientproc: no fd for tlsClient:%d", s->tlspipe[1]);
		syslog(0, logname, "%s", errbuf);
		fprint(2, "%s\n", errbuf);
		threadexitsall(errbuf);
	}
	
	syslog(0, logname, "calling tlsClient");
	fd  = tlsClient(s->tlspipe[1], &s->tlsconn);
	syslog(0, logname, "tlsClient result: fd=%d", fd);
	if (debug) print("clientproc: fd %d\n", fd);
	if (fd < 0) {
		syslog(0, logname, "tlsClient failed: %r");
		fprint(2, "tlsClient failed: %r\n");

	} else {
		syslog(0, logname, "tlsClient ok fd=%d", fd);
		if (s->tlsconn.cert==nil || s->tlsconn.certlen<=0) {
			syslog(0, logname, "server did not provide TLS certificate");
			fprint(2, "server did not provide TLS certificate\n");
		} else {
			X509dump(s->tlsconn.cert, s->tlsconn.certlen);
			if (s->thumbTable != nil) {
				sha1(s->tlsconn.cert, s->tlsconn.certlen, hash, nil);
				if(!okThumbprint(hash, s->thumbTable)) {
					syslog(0, logname, "server certificate %.*H not recognized", SHA1dlen, hash);
					fprint(2, "server certificate %.*H not recognized\n", SHA1dlen, hash);
				}
			}   else {
				syslog(0, logname, "no thumbprint to check server certificate");
			}
		}
	}

	// clean up before we (implicitly) yield
	if (s->tlsconn.sessionID != nil)
		free(s->tlsconn.sessionID);
	s->tlsconn.sessionID = nil;
	s->tlsconn.sessionIDlen = 0;

	if (s->tlsconn.cert)
		free(s->tlsconn.cert);
	s->tlsconn.cert = nil;
	s->tlsconn.certlen = 0;

	sendul(s->tlsfdc, fd);

	syslog(0, logname, "clientproc  ... finished: fd=%d", fd);
	syslog(0, logname, "clientproc exits: %d", threadid());
	s->clientid = 0;
	threadexits(nil);
}

/*
 * setupTls starts a new tls session.  It creates a fresh tls pipe, starts
 * clientproc (which runs tlsClient on tlspipe[1] and later reports the
 * result on tlsfdc), and tells readproc to start reading tlspipe[0].
 * It does not wait for tlsClient; ttls() picks up the result in Waiting.
 * Exits the whole program if the pipe is already open or cannot be created.
 */
static void
setupTls(TTLSstate *s)
{
	syslog(0, logname, "setupTls pre tlspipe[0]=%d  tlspipe[1]=%d", s->tlspipe[0], s->tlspipe[1]);
	if (s->tlspipe[0] >= 0 || s->tlspipe[1] >= 0) {
		snprint(errbuf, sizeof(errbuf), "setupTls: pipe already open? %d %d", s->tlspipe[0], s->tlspipe[1]);
		fprint(2, "%s\n", errbuf);
		syslog(0, logname, "%s", errbuf);
		threadexitsall(errbuf);
	}
	if (pipe(s->tlspipe) < 0) {
		fprint(2, "pipe failed: %r\n");
		syslog(0, logname, "pipe failed: %r");
		threadexitsall("pipe failed");
	}

	// start clientproc to call tlsClient; its result arrives later on tlsfdc
	syslog(0, logname, "setupTls startclientc...");
	s->clientid = proccreate(clientproc, s, STACK);

	// signal reader to restart
	syslog(0, logname, "setupTls startreadc...");
	sendul(s->startreadc, 1);

	s->inuse = 1;

	syslog(0, logname, "setupTls post tlspipe[0]=%d  tlspipe[1]=%d", s->tlspipe[0], s->tlspipe[1]);
}

/*
 * buildFrameStart writes into b the first fragment of a message too big
 * for one frame.  The fragment is a long header (L and M flags, tln = the
 * whole remaining length) followed by as much payload as fits in mtu.
 * It advances sendP/sendL past the copied payload and sets ttslTxLen.
 * Returns the frame length, which is always mtu.
 * Size problems are only reported, not prevented.
 */
static int
buildFrameStart(TTLSstate *s, uchar*b, int mtu)
{
	TTLS *hdr;
	int room;

	if (mtu <= TtlsLongHlen)
		print("buildFrameStart error: mtu much too small: mtu=%d, longhdr=%d\n", mtu, TtlsLongHlen);
	room = mtu - TtlsLongHlen;	// payload bytes that fit after the long header
	if (s->sendL <= room)
		print("buildFrameStart error: small enough, no framing needed: sz=%d, space=%d\n", s->sendL, room);

	hdr = (TTLS*)b;
	memset(hdr, 0, TtlsLongHlen);
	hdr->tp = EapTpTtls;
	hdr->flags = TtlsFlagL | TtlsFlagM;
	hnputl(hdr->tln, s->sendL);	// announce the total length before consuming any of it
	memcpy(b+TtlsLongHlen, s->sendP, room);

	s->sendP += room;
	s->sendL -= room;
	s->ttslTxLen = mtu;
	return mtu;
}

/*
 * buildFrameMiddle writes into b a follow-on fragment of a message that
 * still does not fit in one frame.  The fragment is a short header with
 * only the M flag, filled with payload up to mtu.  It advances
 * sendP/sendL and sets ttslTxLen.
 * Returns the frame length, which is always mtu.
 * Size problems are only reported, not prevented.
 */
static int
buildFrameMiddle(TTLSstate *s, uchar*b, int mtu)
{
	TTLS *hdr;
	int room;

	if (mtu <= TtlsShortHlen)
		print("buildFrameMiddle error: mtu much too small: mtu=%d, longhdr=%d\n", mtu, TtlsShortHlen);
	room = mtu - TtlsShortHlen;	// payload bytes that fit after the short header
	if (s->sendL <= room)
		print("buildFrameMiddle error: small enough, no framing needed: sz=%d, space=%d\n", s->sendL, room);

	hdr = (TTLS*)b;
	memset(hdr, 0, TtlsShortHlen);
	hdr->tp = EapTpTtls;
	hdr->flags = TtlsFlagM;
	memcpy(b+TtlsShortHlen, s->sendP, room);

	s->sendP += room;
	s->sendL -= room;
	s->ttslTxLen = mtu;
	return mtu;
}

/*
 * buildMsg writes into b the final (or only) frame of a message.  The frame
 * is a short header with no flags followed by all remaining payload.  It
 * sets ttslTxLen to header plus payload and clears sendP/sendL.
 * Returns the payload length (not the frame length, unlike
 * buildFrameStart/buildFrameMiddle).
 * Size problems are only reported, not prevented.
 */
static int
buildMsg(TTLSstate *s, uchar*b, int mtu)
{
	TTLS *hdr;
	int payload;

	if (mtu <= TtlsShortHlen)
		print("buildMsg error: mtu much too small: mtu=%d, longhdr=%d\n", mtu, TtlsShortHlen);
	if (s->sendL > mtu-TtlsShortHlen)
		print("buildMsg error: too big, framing needed: sz=%d, space=%d\n", s->sendL, mtu-TtlsShortHlen);

	hdr = (TTLS*)b;
	memset(hdr, 0, TtlsShortHlen);	// zero flags: not fragmented
	hdr->tp = EapTpTtls;
	memcpy(b+TtlsShortHlen, s->sendP, s->sendL);
	s->ttslTxLen = TtlsShortHlen + s->sendL;

	payload = s->sendL;
	s->sendP = nil;
	s->sendL = 0;
	return payload;
}

/*
 * buildAck writes into b an acknowledgement frame: a short header with
 * no flags and no payload.  It sets ttslTxLen accordingly.
 * mtu is unused; a bare header always fits.
 */
static void
buildAck(TTLSstate *s, uchar*b, int mtu)
{
	USED(mtu);
	memset(b, 0, TtlsShortHlen);
	((TTLS*)b)->tp = EapTpTtls;
	s->ttslTxLen = TtlsShortHlen;
}

/*
 * trans moves the state machine to state new.  Entering RecvAck,
 * Receiving or Idle means we must wait for the next incoming frame, so
 * ttlsDone is set to end the current run() loop.
 */
static void
trans(TTLSstate *s, int new)
{
	char *from;

	if (debug) {
		from = "-";
		if (s->ttlsState >= 0)
			from = snames[s->ttlsState];
		print("ttls trans: %s -> %s\n", from, snames[new]);
	}
	if (new == RecvAck || new == Receiving || new == Idle)
		s->ttlsDone = 1;
	s->ttlsState = new;
}

/*
 * ttlsRecvFail abandons the session after an incoming frame that cannot
 * be processed safely.  It reports why, sets *ttlsFail, and moves to Idle,
 * which also ends the current run() loop.  This mirrors how a failed
 * tlsClient is handled in the Waiting state.
 */
static void
ttlsRecvFail(TTLSstate *s, char *why, int *ttlsFail)
{
	print("ttls %s: %s\n", snames[s->ttlsState], why);
	syslog(0, logname, "ttls %s: %s", snames[s->ttlsState], why);
	*ttlsFail = 1;
	trans(s, Idle);
}

/*
 * ttls performs one step of the EAP-TTLS state machine.  run() calls it
 * repeatedly until a state sets ttlsDone.
 *	rcvp, rcvl	the incoming TTLS frame (header included)
 *	txp, mtu	buffer for the outgoing frame and its maximum size;
 *			the length of the prepared frame ends up in s->ttslTxLen
 *	ttlsFail	set to 1 when the session fails
 *	ttlsSuccess	not used here
 * Messages read from the tls pipe are sent in fragments of at most mtu.
 * Fragments from the peer are reassembled in wbuf and written back to the pipe.
 * Incoming lengths come from the network and are validated before use.
 */
static void
ttls(TTLSstate *s, uchar*rcvp, uint rcvl, uchar*txp, uint mtu, int*ttlsSuccess, int*ttlsFail)
{
	int fd;
	int i;
	Alt a[] = {
	/*	 c		v		op   */
		{s->tlsfdc,	&fd,	CHANRCV},
		{s->readc,	&i,	CHANRCV},
		{nil,	nil,	CHANEND},
	};
	TTLS *t;
	uchar *p;
	uint l;
	int n;
	int olen, flen;

	if (debug) print("ttls %s; recvL=%d; recvT=%d\n", snames[s->ttlsState], s->recvL, s->recvT);
	switch(s->ttlsState){
	case Idle:
		trans(s, Idle);
		break;
	case Start:
		setupTls(s); // new session
		trans(s, Waiting);
		break;
	case Waiting:
		while(s->ttlsState == Waiting) {
			switch(alt(a)){
			case 0: // the tlsClient call returned
				// if success, start phase 2
				if (debug) print("ttls tlsfdc: fd=%d\n", fd);
				syslog(0, logname, "ttls tlsfdc: fd=%d", fd);
				s->tlsfd = fd;
				if (fd < 0) {
					*ttlsFail = 1;
					trans(s, Idle);
				} else {
					doTTLSphase2(fd);
				}
				break;
			case 1: // something read from tlspipe: encapsulate and send
				s->sendP = s->rbuf[i].b;
				s->sendL = s->rbuf[i].n;
				s->sendT = s->sendL;
				if (debug) print("ttls readc: i=%d sendP=%p sendL=%d\n", i, s->sendP, s->sendL);
				s->sendS = 1;
				trans(s, Sending);
				break;
			}
		}
		break;
	case Sending:
		if (s->sendS && s->sendL > mtu-TtlsShortHlen) {
			olen = s->sendL;
			flen = buildFrameStart(s, txp, mtu);
			if (debug) print("ttls sendS and framed %d of %d, total %d, remains %d\n", flen, olen, s->sendT, s->sendL);
			s->sendS = 0;
			trans(s, RecvAck);
		} else if (s->sendL > mtu-TtlsShortHlen) {
			olen = s->sendL;
			flen = buildFrameMiddle(s, txp, mtu);
			if (debug) print("ttls framed %d of %d, total %d, remains %d\n", flen, olen, s->sendT, s->sendL);
			trans(s, RecvAck);
		} else {
			olen = s->sendL;
			flen = buildMsg(s, txp, mtu);
			if (debug) print("ttls framed %d of %d, total %d, remains %d\n", flen, olen, s->sendT, s->sendL);
			// message fully sent: prepare to reassemble the peer's reply
			s->recvP = s->wbuf.b;
			s->recvL = 0;
			s->recvT = 0;
			trans(s, Receiving);
		}
		break;
	case RecvAck:
		t = (TTLS*)rcvp;
		if (t->flags&TtlsFlagS)
			print("tls: unexpected TtlsFlagS in %s\n", snames[s->ttlsState]);
		if (t->flags&TtlsFlagM)
			print("tls: unexpected TtlsFlagM in %s\n", snames[s->ttlsState]);
		if (t->flags&TtlsFlagL)
			print("tls: unexpected TtlsFlagL in %s\n", snames[s->ttlsState]);
		trans(s, Sending);
		break;
	case Receiving:
		t = (TTLS*)rcvp;
		// rcvl is uint: subtracting a header length from a short frame would wrap around
		if (rcvl < TtlsShortHlen) {
			ttlsRecvFail(s, "frame too short for header", ttlsFail);
			break;
		}
		if (t->flags&TtlsFlagS)
			print("tls: unexpected TtlsFlagS in %s\n", snames[s->ttlsState]);
		if (t->flags&TtlsFlagL && s->recvT > 0)
			print("tls: TtlsFlagL when recvT=%d\n", s->recvT);
		if (t->flags&TtlsFlagL) {
			// tln must lie inside the frame before we read it
			if (rcvl < TtlsLongHlen) {
				ttlsRecvFail(s, "frame too short for long header", ttlsFail);
				break;
			}
			s->recvT = nhgetl(t->tln);
			if (debug) print("ttls:  TtlsFlagL len=%d\n", s->recvT);
			p = rcvp+TtlsLongHlen;
			l = rcvl-TtlsLongHlen;
			if (s->recvP != s->wbuf.b)
				print("ttls %s: recvP != wbuf.b  recvP=%p wbuf.b=%p \n", snames[s->ttlsState], s->recvP, s->wbuf.b);
			if (s->recvL != 0)
				print("ttls %s: recvL != 0  recvL=%d\n", snames[s->ttlsState], s->recvL);
		} else {
			p = rcvp+TtlsShortHlen;
			l = rcvl-TtlsShortHlen;
			if (s->recvP != s->wbuf.b + s->recvL)
				print("ttls %s: recvP != wbuf.b + s->recvL  recvP=%p wbuf.b=%p recvL=%d\n", snames[s->ttlsState], s->recvP, s->wbuf.b, s->recvL);
		}
		// never reassemble past the end of wbuf.
		// NOTE(review): assumes Buf.b holds Nbuf bytes, as the reads in readproc do.
		if (s->recvL > Nbuf || l > Nbuf - s->recvL) {
			ttlsRecvFail(s, "reassembled message exceeds buffer", ttlsFail);
			break;
		}
		memcpy(s->recvP, p, l);
		s->recvP += l;
		s->recvL += l;
		if (debug) print("ttls %s: received %d; recvL=%d; recvT=%d\n", snames[s->ttlsState], l, s->recvL, s->recvT);
		if (t->flags&TtlsFlagM)
			trans(s, SendAck);
		else {
			if (s->recvT > 0 && s->recvT != s->recvL)
				print("ttls : recvT=%d != recvL=%d\n", s->recvT, s->recvL);
			if (s->recvL > 0)
				trans(s, Received);
			else
				trans(s, Waiting);
		}
		break;
	case SendAck:
		buildAck(s, txp, mtu);
		trans(s, Receiving);
		break;
	case Received:
		if (debug) print("ttls %s: writing tlspipe[0]: %s\n", snames[s->ttlsState], hexprefix(s->wbuf.b, s->recvL, 5));
		n = write(s->tlspipe[0], s->wbuf.b, s->recvL);
		if (n<0)
			print("ttls %s: error writing tlspipe[0]: %r\n", snames[s->ttlsState]);
		syslog(0, logname, "writeproc written %d", n);
		if (n != s->recvL)
			print("ttls %s: writing tlspipe[0]: n != recvL  n=%d recvL=%d\n", snames[s->ttlsState], n, s->recvL);
		if (debug) print("ttls %s: written to tlspipe[0] : %d\n", snames[s->ttlsState], s->recvL);
		trans(s, Waiting);
		break;
	}

	if (debug) print("ttls %s; recvL=%d; recvT=%d\n", snames[s->ttlsState], s->recvL, s->recvT);
}

/*
 * initTTLS resets the global TTLS client state.  It:
 *	- creates the channels;
 *	- marks all fds closed;
 *	- configures tlsconn for "ttls keying material" into theSessionKey;
 *	- loads thumbprints when file is given (file and filex are passed
 *	  to initThumbprints);
 *	- starts readproc.
 * Exits the whole program if the thumbprints cannot be loaded.
 */
void
initTTLS(char *file, char *filex)
{
	TTLSstate *s;

	syslog(0, logname, "initTTLS");

	s = &theTTLSstate;
	memset(s, 0, sizeof(TTLSstate));

	s->ttlsState = Idle;

	s->tlsfdc = chancreate(sizeof(int), 0);
	s->readc = chancreate(sizeof(int), 0);
	s->eofc = chancreate(sizeof(int), 0);

	s->tlsfd = -1;
	s->tlspipe[0] = -1;
	s->tlspipe[1] = -1;

	s->tlsconn.sessionType = "ttls";
	s->tlsconn.sessionConst = "ttls keying material";
	s->tlsconn.sessionKey = theSessionKey;
	s->tlsconn.sessionKeylen = sizeof(theSessionKey);
	if (debugTLS)
		s->tlsconn.trace = print;

	fmtinstall('H', encodefmt);	// %H is used to print certificate hashes
	if (file) {
		s->thumbTable = initThumbprints(file, filex);
		if (s->thumbTable == nil) {
			snprint(errbuf, sizeof(errbuf), "initThumbprints: %r");
			syslog(0, logname, "%s", errbuf);
			fprint(2, "%s\n", errbuf);
			threadexitsall(errbuf);
		}
	}

	// channel intended for starting clientproc; setupTls currently
	// proccreates clientproc directly and does not use it
	s->startclientc = chancreate(sizeof(int), 0);

	// channel on which setupTls (re)starts readproc, which reads the tls pipe
	s->startreadc = chancreate(sizeof(int), 0);
	proccreate(readproc, s, STACK);

}

/*
 * abortTTLS closes our end of the tls pipe (tlspipe[0]), if open, and
 * marks it closed.  readproc treats a closed tlspipe[0] as end of session.
 */
void
abortTTLS(void)
{
	TTLSstate *s = &theTTLSstate;

	syslog(0, logname, "abortTTLS");
	if (s->tlspipe[0] < 0)
		return;
	close(s->tlspipe[0]);
	s->tlspipe[0] = -1;
}

/*
 * run steps the state machine with the current incoming frame until a
 * state transition signals that the next frame is needed (ttlsDone).
 */
static void
run(TTLSstate *s, uchar*rcvp, uint rcvl, uchar*txp, uint mtu, int*success, int*failed)
{
	for (s->ttlsDone = 0; !s->ttlsDone; )
		ttls(s, rcvp, rcvl, txp, mtu, success, failed);
}

/*
 * processTTLS handles one incoming EAP-TTLS frame.
 *	rcvp, rcvl	the frame, starting at the TTLS header
 *	expectStart	nonzero if this frame must be a start (S flag) packet;
 *			otherwise the whole program exits
 *	txp, mtu	buffer and size limit for the reply frame
 *	success, failed	passed through to the state machine
 * A start packet tears down any previous session and resets the state
 * machine.  Returns the length of the reply frame prepared in txp.
 * Returns 0 if the frame is too short or is not an EAP-TTLS frame.
 */
int
processTTLS(uchar*rcvp, uint rcvl, int expectStart, uchar*txp, uint mtu, int*success, int*failed)
{
	TTLS *hr;
	uchar flags, version;
	TTLSstate *s;

	s = &theTTLSstate;

	// the type and flags octets must be present before we look at them
	if (rcvl < TtlsShortHlen)
		return 0;

	hr = (TTLS*)rcvp;

	if (hr->tp != EapTpTtls)
		return 0; // flag error??

	// first thing should be EAP-TTLS start packet
	flags = hr->flags;
	version = flags & TtlsVersion;
	if (debug) print("processTTLS flags=%s%s%s ver=%d mtu=%d bl=%d\n",
		(flags&TtlsFlagS ? "S":""),
		(flags&TtlsFlagM ? "M":""),
		(flags&TtlsFlagL ? "L":""),
		version, mtu, rcvl);
	// parenthesised: !flags&TtlsFlagS would parse as (!flags)&TtlsFlagS, which is always 0
	if (expectStart && !(flags&TtlsFlagS)) {
		fprint(2, "expected EAP-TTLS start packet\n");
		syslog(0, logname, "expected EAP-TTLS start packet");
		threadexitsall("expected EAP-TTLS start packet");
	}
	if (flags & TtlsFlagS) {
		cleanup(s); // previous session

		// ack??
		// look for piggy-backed stuff?

		s->ttlsVersion = version;
		s->ttlsState = Start;
		s->ttlsDone = 0;
		s->sendP = nil;
		s->sendL = 0;
		s->sendS = 0;
		s->sendT = 0;
		s->recvP = nil;
		s->recvL = 0;
		s->recvT = 0;
		// we don't have a client certificate
		s->tlsconn.cert = nil;
		s->tlsconn.certlen = 0;
		// avoid trying session resumption - tlsClient does not support it
		s->tlsconn.sessionID = nil;
		s->tlsconn.sessionIDlen = 0;
	}
	run(s, rcvp, rcvl, txp, mtu, success, failed);
	return s->ttslTxLen;
}

Bell Labs OSI certified Powered by Plan 9

(Return to Plan 9 Home Page)

Copyright © 2021 Plan 9 Foundation. All Rights Reserved.
Comments to webmaster@9p.io.