June 25, 2020
Writing an SNI Proxy in 115 Lines of Go
The very first message sent in a TLS connection is the Client Hello record, in which the client greets the server and tells it, among other things, the server name it wants to connect to. This is called Server Name Indication, or SNI for short, and it's quite handy as it allows many different servers to be co-located on a single IP address.
The server name is sent in plaintext, which is unfortunately really bad for privacy and censorship resistance, but does enable something very useful: a proxy server can read the server name and use it to decide where to route the connection, without having to decrypt the connection. You can leverage this to make many different physical servers accessible from the Internet even if you have only one public IPv4 address: the proxy listens on your public IP address and forwards connections to the appropriate private IP address based on the SNI.
I just finished writing such a proxy server, which I plan to run on my home network's router so that I can easily access my internal servers from anywhere on the Internet, without a VPN or SSH port forwarding. I was pleased by how easy it was to write this proxy server using only Go's standard library. It's a great example of how well-suited Go is for programs involving networking and cryptography.
Let's start with a standard listen/accept loop (right out of the examples for Go's net package):

    func main() {
        l, err := net.Listen("tcp", ":443")
        if err != nil {
            log.Fatal(err)
        }
        for {
            conn, err := l.Accept()
            if err != nil {
                log.Print(err)
                continue
            }
            go handleConnection(conn)
        }
    }

Here's a sketch of the handleConnection function, which reads the Client Hello record from the client, dials the backend server indicated by the Client Hello, and then proxies the client to and from the backend. (Note that we dial the backend using the SNI value, which works well with split-horizon DNS where the proxy sees the backend's private IP address and external clients see the proxy's public IP address. If that doesn't work for you, you can use more complicated routing logic.)

    func handleConnection(clientConn net.Conn) {
        defer clientConn.Close()

        // ... read Client Hello from clientConn ...

        backendConn, err := net.Dial("tcp", net.JoinHostPort(clientHello.ServerName, "443"))
        if err != nil {
            log.Print(err)
            return
        }
        defer backendConn.Close()

        // ... proxy clientConn <==> backendConn ...
    }

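The "more complicated routing logic" mentioned above could be as simple as a static lookup table. The following is only a minimal sketch and not part of the original proxy: the routes map, the backendAddr helper, and every name and address in it are hypothetical.

```go
package main

import (
	"fmt"
	"strings"
)

// routes maps SNI server names to backend addresses.
// The names and addresses are made-up examples.
var routes = map[string]string{
	"wiki.internal.example.com": "192.168.1.10:443",
	"git.internal.example.com":  "192.168.1.11:8443",
}

// backendAddr returns the address to dial for the given SNI value.
// The second result is false if no route is configured.
func backendAddr(serverName string) (string, bool) {
	// DNS names are case-insensitive, so normalize before the lookup.
	addr, ok := routes[strings.ToLower(serverName)]
	return addr, ok
}

func main() {
	for _, name := range []string{"wiki.internal.example.com", "unknown.example.com"} {
		if addr, ok := backendAddr(name); ok {
			fmt.Printf("%s -> %s\n", name, addr)
		} else {
			fmt.Printf("%s -> no route\n", name)
		}
	}
}
```

In handleConnection, the result of backendAddr would take the place of net.JoinHostPort(clientHello.ServerName, "443"). A table like this also acts as an allowlist, since unknown names are never dialed.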
Let's assume for now we have a convenient function to read a Client Hello record from an io.Reader and return a tls.ClientHelloInfo:

    func readClientHello(reader io.Reader) (*tls.ClientHelloInfo, error)

We can't simply call this function from handleConnection, because once the Client Hello is read, the bytes are gone. We need to preserve the bytes and forward them along to the backend, which is expecting a proper TLS connection that starts with a Client Hello record.
What we need to do instead is "peek" at the Client Hello record, and thanks to some simple but powerful abstractions from Go's io package, this can be done with just six lines of code:

    func peekClientHello(reader io.Reader) (*tls.ClientHelloInfo, io.Reader, error) {
        peekedBytes := new(bytes.Buffer)
        hello, err := readClientHello(io.TeeReader(reader, peekedBytes))
        if err != nil {
            return nil, nil, err
        }
        return hello, io.MultiReader(peekedBytes, reader), nil
    }

What this code does is create a TeeReader, which is a reader that wraps another reader and writes everything that is read to a writer, which in our case is a byte buffer. We pass the TeeReader to readClientHello, so every byte read by readClientHello gets saved to our buffer. Finally, we create a MultiReader which essentially concatenates our buffer with the original reader. Reads from the MultiReader initially come out of the buffer, and when that's exhausted, continue from the original reader. We return the MultiReader to the caller along with the ClientHelloInfo. When the caller reads from the MultiReader it will see a full TLS connection stream, starting with the Client Hello.
Now we just need to implement readClientHello. We could open up the TLS RFCs and learn how to parse a Client Hello record, but it turns out we can let crypto/tls do the work for us, thanks to a callback function in tls.Config called GetConfigForClient:

    // GetConfigForClient, if not nil, is called after a ClientHello is
    // received from a client.
    GetConfigForClient func(*ClientHelloInfo) (*Config, error) // Go 1.8

Roughly, what we need to do is create a TLS server-side connection with a GetConfigForClient callback that saves the ClientHelloInfo passed to it. However, creating a TLS connection requires a full-blown net.Conn, and readClientHello is passed merely an io.Reader. So let's create a type, readOnlyConn, which wraps an io.Reader and satisfies the net.Conn interface:

    type readOnlyConn struct {
        reader io.Reader
    }

    func (conn readOnlyConn) Read(p []byte) (int, error)         { return conn.reader.Read(p) }
    func (conn readOnlyConn) Write(p []byte) (int, error)        { return 0, io.ErrClosedPipe }
    func (conn readOnlyConn) Close() error                       { return nil }
    func (conn readOnlyConn) LocalAddr() net.Addr                { return nil }
    func (conn readOnlyConn) RemoteAddr() net.Addr               { return nil }
    func (conn readOnlyConn) SetDeadline(t time.Time) error      { return nil }
    func (conn readOnlyConn) SetReadDeadline(t time.Time) error  { return nil }
    func (conn readOnlyConn) SetWriteDeadline(t time.Time) error { return nil }

readOnlyConn forwards reads to the reader and simulates a broken pipe when written to (as if the client closed the connection before the server could reply). All other operations are no-ops.
Now we're ready to write readClientHello:

    func readClientHello(reader io.Reader) (*tls.ClientHelloInfo, error) {
        var hello *tls.ClientHelloInfo

        err := tls.Server(readOnlyConn{reader: reader}, &tls.Config{
            GetConfigForClient: func(argHello *tls.ClientHelloInfo) (*tls.Config, error) {
                hello = new(tls.ClientHelloInfo)
                *hello = *argHello
                return nil, nil
            },
        }).Handshake()

        if hello == nil {
            return nil, err
        }

        return hello, nil
    }

Note that Handshake always fails because the readOnlyConn is not a real connection. As long as the Client Hello is successfully read, the failure should only happen after GetConfigForClient is called, so we only care about the error if hello was never set.
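One way to convince yourself that this works is to feed readClientHello a real Client Hello over an in-memory pipe. The sketch below repeats readOnlyConn and readClientHello from the post so that it runs on its own; the sniOf helper and the use of net.Pipe are assumptions made for this test and are not part of the proxy.

```go
package main

import (
	"crypto/tls"
	"fmt"
	"io"
	"net"
	"time"
)

type readOnlyConn struct {
	reader io.Reader
}

func (conn readOnlyConn) Read(p []byte) (int, error)         { return conn.reader.Read(p) }
func (conn readOnlyConn) Write(p []byte) (int, error)        { return 0, io.ErrClosedPipe }
func (conn readOnlyConn) Close() error                       { return nil }
func (conn readOnlyConn) LocalAddr() net.Addr                { return nil }
func (conn readOnlyConn) RemoteAddr() net.Addr               { return nil }
func (conn readOnlyConn) SetDeadline(t time.Time) error      { return nil }
func (conn readOnlyConn) SetReadDeadline(t time.Time) error  { return nil }
func (conn readOnlyConn) SetWriteDeadline(t time.Time) error { return nil }

func readClientHello(reader io.Reader) (*tls.ClientHelloInfo, error) {
	var hello *tls.ClientHelloInfo
	err := tls.Server(readOnlyConn{reader: reader}, &tls.Config{
		GetConfigForClient: func(argHello *tls.ClientHelloInfo) (*tls.Config, error) {
			hello = new(tls.ClientHelloInfo)
			*hello = *argHello
			return nil, nil
		},
	}).Handshake()
	if hello == nil {
		return nil, err
	}
	return hello, nil
}

// sniOf (hypothetical test helper) starts a TLS client for serverName on one
// end of an in-memory pipe and returns the SNI value seen on the other end.
func sniOf(serverName string) (string, error) {
	clientSide, proxySide := net.Pipe()
	defer proxySide.Close() // unblocks the client goroutine when we return

	go func() {
		defer clientSide.Close()
		// This handshake can never complete; it only needs to send the
		// Client Hello, so its error is deliberately ignored.
		tls.Client(clientSide, &tls.Config{ServerName: serverName}).Handshake()
	}()

	hello, err := readClientHello(proxySide)
	if err != nil {
		return "", err
	}
	return hello.ServerName, nil
}

func main() {
	name, err := sniOf("foo.internal.example.com")
	if err != nil {
		panic(err)
	}
	fmt.Println(name) // prints: foo.internal.example.com
}
```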
Let's put everything together to write the full handleConnection function. I've added deadlines (thanks, Filippo!) and a check that the SNI value ends with .internal.example.com to prevent this from being used as an open proxy. When I deploy this, I will use the DNS suffix of my home network.

    func handleConnection(clientConn net.Conn) {
        defer clientConn.Close()

        if err := clientConn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil {
            log.Print(err)
            return
        }

        clientHello, clientReader, err := peekClientHello(clientConn)
        if err != nil {
            log.Print(err)
            return
        }

        if err := clientConn.SetReadDeadline(time.Time{}); err != nil {
            log.Print(err)
            return
        }

        if !strings.HasSuffix(clientHello.ServerName, ".internal.example.com") {
            log.Print("Blocking connection to unauthorized backend")
            return
        }

        backendConn, err := net.DialTimeout("tcp", net.JoinHostPort(clientHello.ServerName, "443"), 5*time.Second)
        if err != nil {
            log.Print(err)
            return
        }
        defer backendConn.Close()

        var wg sync.WaitGroup
        wg.Add(2)

        go func() {
            io.Copy(clientConn, backendConn)
            clientConn.(*net.TCPConn).CloseWrite()
            wg.Done()
        }()
        go func() {
            io.Copy(backendConn, clientReader)
            backendConn.(*net.TCPConn).CloseWrite()
            wg.Done()
        }()

        wg.Wait()
    }

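The leading dot in the suffix matters. As a small illustration, the sketch below wraps the same strings.HasSuffix check in a hypothetical isAllowed helper and runs it against a few example names, all of them made up.

```go
package main

import (
	"fmt"
	"strings"
)

// allowedSuffix is the example suffix from the post; note the leading dot.
const allowedSuffix = ".internal.example.com"

// isAllowed (hypothetical helper) applies the same check as handleConnection.
func isAllowed(serverName string) bool {
	return strings.HasSuffix(serverName, allowedSuffix)
}

func main() {
	names := []string{
		"nas.internal.example.com", // allowed: a host under the suffix
		"evilinternal.example.com", // rejected: no dot before "internal"
		"internal.example.com",     // rejected: shorter than the suffix itself
		"",                         // rejected: the client sent no SNI
	}
	for _, name := range names {
		fmt.Printf("%q allowed: %v\n", name, isAllowed(name))
	}
}
```

Without the leading dot, a name such as evilinternal.example.com would pass the check. The empty-string case also shows that clients that send no SNI at all are turned away.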
Here's the complete Go source code - just 115 lines! (Not counting copyright legalese)
Comments
Reader Ameya on 2020-06-26 at 18:56:
Super useful post! re: deadlines, wouldn't you also want to set a finite deadline on the connection after the TLS handshake? Otherwise, if either side stops sending data or a FIN gets lost somehow, you could leak connections.
Andrew Ayer on 2020-06-27 at 14:21:
Thanks, Ameya! You raise a good point about leaking connections but it's unfortunately not that simple to fix. I want to avoid deadlines because the proxy should be able to handle any type of protocol, and some protocols might legitimately have long periods of inactivity. Instead, it should be the backend's responsibility to close idle connections and the proxy should detect when that has happened.
Unfortunately, Go provides no way to detect when a connection has been fully closed without trying to write at least one byte to it, which isn't sufficient because the proxy doesn't always have a byte to write. On Unix, closure can be detected without writing by polling the socket with events=0 and looking for POLLHUP in revents. I'm exploring possible solutions and plan to publish a blog post about it.
Reader Ameya on 2020-06-27 at 20:44:
Thank you! As a newcomer to Go, this problem drove me crazy. In Java, if I were running a select loop, I can just get an event that the remote end disconnected. I'm writing a HTTP (1.1) proxy for delivering webhooks in Go as a side project, and the way I "solve" this problem is by having an idle read timeout on the connection. Of course, this only works because it's HTTP, and the intended behavior of my proxy is to close the connection after a request-response interaction anyway.
Very much looking forward to it!
Reader Jason Stangroome on 2020-07-06 at 05:52:
Hi Andrew, nice work! Have you considered how this might need to change for TLS v1.3 where the SNI value is encrypted?
Andrew Ayer on 2020-07-06 at 14:58:
Thanks, Jason!
SNI is not encrypted in TLS 1.3 and this code works with TLS 1.3.
There is ongoing work to add encrypted SNI to TLS <https://tools.ietf.org/html/draft-ietf-tls-esni-07>. The proposal explicitly supports SNI-based proxying. The proxy server would operate as the "provider", receive the encrypted Client Hello, decrypt it, and forward the connection along to the backend, without seeing the plaintext of the connection. This is the "split mode topology" shown on page 4.
Reader Nuno on 2021-01-25 at 11:38:
Hi, nice work!
If you're using this on Linux, I advise doing away with the io.MultiReader. Just return the bytes.Buffer, and io.Copy both. net.TCPConn implements io.ReaderFrom using the splice(2) syscall, which makes this much more efficient (everything happens in kernel space). io.Copy uses this implementation if it gets the unwrapped net.TCPConn.
Note, however, that this raises the number of open files. Every proxied connection uses 6 fds, 2 for the TCP connections and 4 pipes; your version is doing io.Copy on the return connection so it uses 4 fds per proxied connection.
So if you're using this for something more popular this might require some ulimit magic (the default is 1024, which is good for about 170 proxied connections).
Anonymous on 2022-10-17 at 10:34:
A working patch:

    --- sniproxy.go.orig 2022-10-17 10:28:24.851394840 +0000
    +++ sniproxy.go 2022-10-17 10:30:02.661635510 +0000
    @@ -57,7 +57,7 @@
         if err != nil {
             return nil, nil, err
         }
    -    return hello, io.MultiReader(peekedBytes, reader), nil
    +    return hello, peekedBytes, nil
     }

     type readOnlyConn struct {
    @@ -99,7 +99,7 @@
             return
         }

    -    clientHello, clientReader, err := peekClientHello(clientConn)
    +    clientHello, clientHelloBytes, err := peekClientHello(clientConn)
         if err != nil {
             log.Print(err)
             return
    @@ -131,7 +131,8 @@
             wg.Done()
         }()
         go func() {
    -        io.Copy(backendConn, clientReader)
    +        io.Copy(backendConn, clientHelloBytes)
    +        io.Copy(backendConn, clientConn)
             backendConn.(*net.TCPConn).CloseWrite()
             wg.Done()
         }()

Anonymous on 2023-08-14 at 11:50:
Very useful!
Anonymous on 2024-05-08 at 03:33:
Hi, will this be possible to also proxy udp traffic so the new http/3 standard will be proxied for clients?
If not, will it be considered for a future feature?
Thanks
Andrew Ayer on 2024-05-12 at 21:34:
Proxying QUIC will require the backend servers to generate connection IDs that the proxy can decode, see https://quicwg.org/load-balancers/draft-ietf-quic-load-balancers.html
It will take more than 115 lines of code, but I intend to add support to snid at some point - https://github.com/AGWA/snid