June 25, 2020

Writing an SNI Proxy in 115 Lines of Go

The very first message sent in a TLS connection is the Client Hello record, in which the client greets the server and tells it, among other things, the server name it wants to connect to. This is called Server Name Indication, or SNI for short, and it's quite handy as it allows many different servers to be co-located on a single IP address.

The server name is sent in plaintext, which is unfortunately really bad for privacy and censorship resistance, but does enable something very useful: a proxy server can read the server name and use it to decide where to route the connection, without having to decrypt the connection. You can leverage this to make many different physical servers accessible from the Internet even if you have only one public IPv4 address: the proxy listens on your public IP address and forwards connections to the appropriate private IP address based on the SNI.

I just finished writing such a proxy server, which I plan to run on my home network's router so that I can easily access my internal servers from anywhere on the Internet, without a VPN or SSH port forwarding. I was pleased by how easy it was to write this proxy server using only Go's standard library. It's a great example of how well-suited Go is for programs involving networking and cryptography.

Let's start with a standard listen/accept loop (right out of the examples for Go's net package):

func main() {
	l, err := net.Listen("tcp", ":443")
	if err != nil {
		log.Fatal(err)
	}
	for {
		conn, err := l.Accept()
		if err != nil {
			log.Print(err)
			continue
		}
		go handleConnection(conn)
	}
}

Here's a sketch of the handleConnection function, which reads the Client Hello record from the client, dials the backend server indicated by the Client Hello, and then proxies the client to and from the backend. (Note that we dial the backend using the SNI value, which works well with split-horizon DNS where the proxy sees the backend's private IP address and external clients see the proxy's public IP address. If that doesn't work for you, you can use more complicated routing logic.)

func handleConnection(clientConn net.Conn) {
	defer clientConn.Close()

	// ... read Client Hello from clientConn ...

	backendConn, err := net.Dial("tcp", net.JoinHostPort(clientHello.ServerName, "443"))
	if err != nil {
		log.Print(err)
		return
	}
	defer backendConn.Close()

	// ... proxy clientConn <==> backendConn ...
}
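
If split-horizon DNS isn't an option, one possible shape for the "more complicated routing logic" mentioned above is a lookup table from server name to backend address. This is only a sketch under assumptions: the routes map, the backendAddr function, and the host names and addresses in it are hypothetical and not part of the proxy described in this post.

```go
package main

import (
	"fmt"
	"strings"
)

// routes maps an SNI server name to the backend address to dial.
// The entries are made-up examples.
var routes = map[string]string{
	"wiki.internal.example.com": "192.168.1.10:443",
	"git.internal.example.com":  "192.168.1.11:8443",
}

// backendAddr returns the address to dial for serverName and reports
// whether a route exists. The lookup is case-insensitive because DNS
// names are case-insensitive.
func backendAddr(serverName string) (string, bool) {
	addr, ok := routes[strings.ToLower(serverName)]
	return addr, ok
}

func main() {
	names := []string{
		"wiki.internal.example.com",
		"GIT.internal.example.com",
		"unknown.example.org",
	}
	for _, name := range names {
		if addr, ok := backendAddr(name); ok {
			fmt.Printf("%s -> %s\n", name, addr)
		} else {
			fmt.Printf("%s -> no route\n", name)
		}
	}
}
```

With a table like this, handleConnection would dial the returned address instead of the SNI value, and a missing route doubles as a refusal to proxy to unknown names.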

Let's assume for now we have a convenient function to read a Client Hello record from an io.Reader and return a tls.ClientHelloInfo:

func readClientHello(reader io.Reader) (*tls.ClientHelloInfo, error)

We can't simply call this function from handleConnection, because once the Client Hello is read, the bytes are gone. We need to preserve the bytes and forward them along to the backend, which is expecting a proper TLS connection that starts with a Client Hello record.

What we need to do instead is "peek" at the Client Hello record, and thanks to some simple but powerful abstractions from Go's io package, this can be done with just six lines of code:

func peekClientHello(reader io.Reader) (*tls.ClientHelloInfo, io.Reader, error) {
        peekedBytes := new(bytes.Buffer)
        hello, err := readClientHello(io.TeeReader(reader, peekedBytes))
        if err != nil {
                return nil, nil, err
        }
        return hello, io.MultiReader(peekedBytes, reader), nil
}

This code creates a TeeReader, a reader that wraps another reader and writes everything read from it to a writer, which in our case is a byte buffer. We pass the TeeReader to readClientHello, so every byte read by readClientHello gets saved to our buffer. Finally, we create a MultiReader, which essentially concatenates our buffer with the original reader. Reads from the MultiReader initially come out of the buffer, and when that's exhausted, continue from the original reader. We return the MultiReader to the caller along with the ClientHelloInfo. When the caller reads from the MultiReader, it sees the full TLS connection stream, starting with the Client Hello.
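
To see the TeeReader/MultiReader pattern in isolation, here is a small sketch that peeks at the first few bytes of an ordinary string reader instead of a TLS connection. The peekN and demoPeek functions are hypothetical helpers written for this illustration; only io.TeeReader, io.MultiReader, and the other standard library calls are real.

```go
package main

import (
	"bytes"
	"fmt"
	"io"
	"strings"
)

// peekN reads the first n bytes of r and returns them together with a
// reader that replays those bytes before continuing with the rest of r.
func peekN(r io.Reader, n int) ([]byte, io.Reader, error) {
	peeked := new(bytes.Buffer)
	head := make([]byte, n)
	// Everything read through the TeeReader is also written to peeked.
	if _, err := io.ReadFull(io.TeeReader(r, peeked), head); err != nil {
		return nil, nil, err
	}
	// The MultiReader drains peeked first, then falls through to r.
	return head, io.MultiReader(peeked, r), nil
}

// demoPeek peeks at the first n bytes of input and then reads the
// whole replayed stream, returning both as strings.
func demoPeek(input string, n int) (head, all string, err error) {
	h, replay, err := peekN(strings.NewReader(input), n)
	if err != nil {
		return "", "", err
	}
	rest, err := io.ReadAll(replay)
	if err != nil {
		return "", "", err
	}
	return string(h), string(rest), nil
}

func main() {
	head, all, err := demoPeek("hello, world", 5)
	if err != nil {
		panic(err)
	}
	fmt.Printf("peeked: %q\nreplayed: %q\n", head, all)
}
```

The replayed stream contains the peeked bytes followed by the remainder, which is the property the proxy relies on when it forwards the Client Hello to the backend.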

Now we just need to implement readClientHello. We could open up the TLS RFCs and learn how to parse a Client Hello record, but it turns out we can let crypto/tls do the work for us, thanks to a callback function in tls.Config called GetConfigForClient:

// GetConfigForClient, if not nil, is called after a ClientHello is
// received from a client.
GetConfigForClient func(*ClientHelloInfo) (*Config, error) // Go 1.8

Roughly, what we need to do is create a TLS server-side connection with a GetConfigForClient callback that saves the ClientHelloInfo passed to it. However, creating a TLS connection requires a full-blown net.Conn, and readClientHello is passed merely an io.Reader. So let's create a type, readOnlyConn, which wraps an io.Reader and satisfies the net.Conn interface:

type readOnlyConn struct {
        reader io.Reader
}
func (conn readOnlyConn) Read(p []byte) (int, error)         { return conn.reader.Read(p) }
func (conn readOnlyConn) Write(p []byte) (int, error)        { return 0, io.ErrClosedPipe }
func (conn readOnlyConn) Close() error                       { return nil }
func (conn readOnlyConn) LocalAddr() net.Addr                { return nil }
func (conn readOnlyConn) RemoteAddr() net.Addr               { return nil }
func (conn readOnlyConn) SetDeadline(t time.Time) error      { return nil }
func (conn readOnlyConn) SetReadDeadline(t time.Time) error  { return nil }
func (conn readOnlyConn) SetWriteDeadline(t time.Time) error { return nil }

readOnlyConn forwards reads to the reader and simulates a broken pipe when written to (as if the client closed the connection before the server could reply). All other operations are no-ops.
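
As a quick sanity check of that behavior, the sketch below repeats the readOnlyConn type, adds a compile-time assertion that it satisfies net.Conn, and exercises one read and one write. The tryConn helper is hypothetical and exists only for this illustration.

```go
package main

import (
	"errors"
	"fmt"
	"io"
	"net"
	"strings"
	"time"
)

// readOnlyConn is repeated from above so the example is self-contained.
type readOnlyConn struct {
	reader io.Reader
}

func (conn readOnlyConn) Read(p []byte) (int, error)         { return conn.reader.Read(p) }
func (conn readOnlyConn) Write(p []byte) (int, error)        { return 0, io.ErrClosedPipe }
func (conn readOnlyConn) Close() error                       { return nil }
func (conn readOnlyConn) LocalAddr() net.Addr                { return nil }
func (conn readOnlyConn) RemoteAddr() net.Addr               { return nil }
func (conn readOnlyConn) SetDeadline(t time.Time) error      { return nil }
func (conn readOnlyConn) SetReadDeadline(t time.Time) error  { return nil }
func (conn readOnlyConn) SetWriteDeadline(t time.Time) error { return nil }

// Compile-time check that readOnlyConn satisfies net.Conn.
var _ net.Conn = readOnlyConn{}

// tryConn wraps input in a readOnlyConn, reads it back, and then
// attempts a write. It returns what was read and whether the write
// was refused with io.ErrClosedPipe.
func tryConn(input string) (read string, writeRefused bool) {
	var conn net.Conn = readOnlyConn{reader: strings.NewReader(input)}
	buf := make([]byte, len(input))
	n, _ := io.ReadFull(conn, buf)
	_, err := conn.Write([]byte("reply"))
	return string(buf[:n]), errors.Is(err, io.ErrClosedPipe)
}

func main() {
	read, writeRefused := tryConn("pretend Client Hello")
	fmt.Printf("read %q, write refused: %v\n", read, writeRefused)
}
```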

Now we're ready to write readClientHello:

func readClientHello(reader io.Reader) (*tls.ClientHelloInfo, error) {
        var hello *tls.ClientHelloInfo

        err := tls.Server(readOnlyConn{reader: reader}, &tls.Config{
                GetConfigForClient: func(argHello *tls.ClientHelloInfo) (*tls.Config, error) {
                        hello = new(tls.ClientHelloInfo)
                        *hello = *argHello
                        return nil, nil
                },
        }).Handshake()

        if hello == nil {
                return nil, err
        }

        return hello, nil
}

Note that Handshake always fails because the readOnlyConn is not a real connection. As long as the Client Hello is successfully read, the failure should only happen after GetConfigForClient is called, so we only care about the error if hello was never set.
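
One way to convince yourself that this works is to feed readClientHello a real Client Hello over an in-memory pipe. The sketch below does that with net.Pipe and a crypto/tls client; the sniOf helper and the server name passed to it are hypothetical, and the client's handshake is expected to fail because nothing ever replies to it.

```go
package main

import (
	"crypto/tls"
	"fmt"
	"io"
	"net"
	"time"
)

// readOnlyConn and readClientHello are repeated from above so the
// example is self-contained.
type readOnlyConn struct {
	reader io.Reader
}

func (conn readOnlyConn) Read(p []byte) (int, error)         { return conn.reader.Read(p) }
func (conn readOnlyConn) Write(p []byte) (int, error)        { return 0, io.ErrClosedPipe }
func (conn readOnlyConn) Close() error                       { return nil }
func (conn readOnlyConn) LocalAddr() net.Addr                { return nil }
func (conn readOnlyConn) RemoteAddr() net.Addr               { return nil }
func (conn readOnlyConn) SetDeadline(t time.Time) error      { return nil }
func (conn readOnlyConn) SetReadDeadline(t time.Time) error  { return nil }
func (conn readOnlyConn) SetWriteDeadline(t time.Time) error { return nil }

func readClientHello(reader io.Reader) (*tls.ClientHelloInfo, error) {
	var hello *tls.ClientHelloInfo

	err := tls.Server(readOnlyConn{reader: reader}, &tls.Config{
		GetConfigForClient: func(argHello *tls.ClientHelloInfo) (*tls.Config, error) {
			hello = new(tls.ClientHelloInfo)
			*hello = *argHello
			return nil, nil
		},
	}).Handshake()

	if hello == nil {
		return nil, err
	}
	return hello, nil
}

// sniOf starts a TLS client for serverName on one end of an in-memory
// pipe and returns the server name that readClientHello extracts from
// the other end.
func sniOf(serverName string) (string, error) {
	clientSide, proxySide := net.Pipe()
	// Closing proxySide on return unblocks the client goroutine.
	defer proxySide.Close()

	go func() {
		defer clientSide.Close()
		// The error is deliberately ignored: the handshake cannot
		// complete because no server ever answers.
		tls.Client(clientSide, &tls.Config{ServerName: serverName}).Handshake()
	}()

	hello, err := readClientHello(proxySide)
	if err != nil {
		return "", err
	}
	return hello.ServerName, nil
}

func main() {
	name, err := sniOf("backend.internal.example.com")
	if err != nil {
		panic(err)
	}
	fmt.Println("SNI:", name)
}
```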

Let's put everything together to write the full handleConnection function. I've added deadlines (thanks, Filippo!) and a check that the SNI value ends with .internal.example.com to prevent this from being used as an open proxy. When I deploy this, I will use the DNS suffix of my home network.

func handleConnection(clientConn net.Conn) {
	defer clientConn.Close()

	if err := clientConn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil {
		log.Print(err)
		return
	}

	clientHello, clientReader, err := peekClientHello(clientConn)
	if err != nil {
		log.Print(err)
		return
	}

	if err := clientConn.SetReadDeadline(time.Time{}); err != nil {
		log.Print(err)
		return
	}

	if !strings.HasSuffix(clientHello.ServerName, ".internal.example.com") {
		log.Print("Blocking connection to unauthorized backend")
		return
	}

	backendConn, err := net.DialTimeout("tcp", net.JoinHostPort(clientHello.ServerName, "443"), 5*time.Second)
	if err != nil {
		log.Print(err)
		return
	}
	defer backendConn.Close()

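	var wg sync.WaitGroup
	wg.Add(2)

	// Copy backend -> client, then half-close the client side so it sees EOF.
	go func() {
		io.Copy(clientConn, backendConn)
		clientConn.(*net.TCPConn).CloseWrite()
		wg.Done()
	}()
	// Copy client -> backend, starting with the peeked Client Hello bytes.
	go func() {
		io.Copy(backendConn, clientReader)
		backendConn.(*net.TCPConn).CloseWrite()
		wg.Done()
	}()

	// Return, and fully close both connections, once both directions are done.
	wg.Wait()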
}
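
The suffix check in handleConnection depends on the leading dot in ".internal.example.com". The sketch below isolates that check to show which names pass; the allowed function and the test names are hypothetical, chosen only to illustrate the edge cases.

```go
package main

import (
	"fmt"
	"strings"
)

// allowedSuffix is the DNS suffix that backends must have. The leading
// dot ensures that only names inside the domain match, not names that
// merely end with the same characters.
const allowedSuffix = ".internal.example.com"

// allowed reports whether serverName may be proxied to.
func allowed(serverName string) bool {
	return strings.HasSuffix(serverName, allowedSuffix)
}

func main() {
	names := []string{
		"wiki.internal.example.com",          // inside the domain
		"evilinternal.example.com",           // same trailing characters, different domain
		"internal.example.com.attacker.test", // suffix appears in the middle
		"",                                   // client sent no SNI
	}
	for _, name := range names {
		fmt.Printf("%q allowed: %v\n", name, allowed(name))
	}
}
```

Only the first name is accepted. Without the leading dot, a name such as evilinternal.example.com would also pass the check.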

Here's the complete Go source code - just 115 lines! (Not counting copyright legalese)

Comments

Reader Ameya on 2020-06-26 at 18:56:

Super useful post! re: deadlines, wouldn't you also want to set a finite deadline on the connection after the TLS handshake? Otherwise, if either side stops sending data or a FIN gets lost somehow, you could leak connections.

Andrew Ayer on 2020-06-27 at 14:21:

Thanks, Ameya! You raise a good point about leaking connections but it's unfortunately not that simple to fix. I want to avoid deadlines because the proxy should be able to handle any type of protocol, and some protocols might legitimately have long periods of inactivity. Instead, it should be the backend's responsibility to close idle connections and the proxy should detect when that has happened.

Unfortunately, Go provides no way to detect when a connection has been fully closed without trying to write at least one byte to it, which isn't sufficient because the proxy doesn't always have a byte to write. On Unix, closure can be detected without writing by polling the socket with events=0 and looking for POLLHUP in revents. I'm exploring possible solutions and plan to publish a blog post about it.

Reader Ameya on 2020-06-27 at 20:44:

> Unfortunately, Go provides no way to detect when a connection has been fully closed without trying to write at least one byte to it

Thank you! As a newcomer to Go, this problem drove me crazy. In Java, if I were running a select loop, I could just get an event that the remote end disconnected. I'm writing an HTTP (1.1) proxy for delivering webhooks in Go as a side project, and the way I "solve" this problem is by having an idle read timeout on the connection. Of course, this only works because it's HTTP, and the intended behavior of my proxy is to close the connection after a request-response interaction anyway.

> I'm exploring possible solutions and plan to publish a blog post about it

Very much looking forward to it!

Reader Jason Stangroome on 2020-07-06 at 05:52:

Hi Andrew, nice work! Have you considered how this might need to change for TLS v1.3 where the SNI value is encrypted?

Andrew Ayer on 2020-07-06 at 14:58:

Thanks, Jason!

SNI is not encrypted in TLS 1.3 and this code works with TLS 1.3.

There is ongoing work to add encrypted SNI to TLS <https://tools.ietf.org/html/draft-ietf-tls-esni-07>. The proposal explicitly supports SNI-based proxying. The proxy server would operate as the "provider", receive the encrypted Client Hello, decrypt it, and forward the connection along to the backend, without seeing the plaintext of the connection. This is the "split mode topology" shown on page 4.

Reader Nuno on 2021-01-25 at 11:38:

Hi, nice work!

If you're using this on Linux, I advise doing away with the io.MultiReader. Just return the bytes.Buffer, and io.Copy both. net.TCPConn implements io.ReaderFrom using the splice(2) syscall, which makes this much more efficient (everything happens in kernel space). io.Copy uses this implementation if it gets the unwrapped net.TCPConn.

Note, however, that this raises the number of open files. Every proxied connection uses 6 fds, 2 for the TCP connections and 4 pipes; your version is doing io.Copy on the return connection so it uses 4 fds per proxied connection.

So if you're using this for something more popular this might require some ulimit magic (the default is 1024, which is good for about 170 proxied connections).

Anonymous on 2022-10-17 at 10:34:

> If you're using this on Linux, I advise doing away with the io.MultiReader. Just return the bytes.Buffer, and io.Copy both. net.TCPConn implements io.ReaderFrom using the splice(2) syscall, which makes this much more efficient (everything happens in kernel space).

A working patch:

--- sniproxy.go.orig 2022-10-17 10:28:24.851394840 +0000
+++ sniproxy.go 2022-10-17 10:30:02.661635510 +0000
@@ -57,7 +57,7 @@
 	if err != nil {
 		return nil, nil, err
 	}
-	return hello, io.MultiReader(peekedBytes, reader), nil
+	return hello, peekedBytes, nil
 }
 
 type readOnlyConn struct {
@@ -99,7 +99,7 @@
 		return
 	}
 
-	clientHello, clientReader, err := peekClientHello(clientConn)
+	clientHello, clientHelloBytes, err := peekClientHello(clientConn)
 	if err != nil {
 		log.Print(err)
 		return
@@ -131,7 +131,8 @@
 		wg.Done()
 	}()
 	go func() {
-		io.Copy(backendConn, clientReader)
+		io.Copy(backendConn, clientHelloBytes)
+		io.Copy(backendConn, clientConn)
 		backendConn.(*net.TCPConn).CloseWrite()
 		wg.Done()
 	}()

Anonymous on 2023-08-14 at 11:50:

Very useful!
