Skip to content

Commit c8003d4

Browse files
committed
Use bufio reader
1 parent 5f4bc7a commit c8003d4

3 files changed

Lines changed: 48 additions & 68 deletions

File tree

httpmuxer/https.go

Lines changed: 6 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ func (pL *proxyListener) Accept() (net.Conn, error) {
3131
continue
3232
}
3333

34-
tlsHello, buf, teeConn, peekErr := utils.PeekTLSHello(cl)
34+
tlsHello, teeConn, peekErr := utils.PeekTLSHello(cl)
3535
if peekErr != nil && tlsHello == nil {
3636
return teeConn, nil
3737
}
@@ -59,20 +59,20 @@ func (pL *proxyListener) Accept() (net.Conn, error) {
5959
connectionLocation, err := balancer.NextServer()
6060
if err != nil {
6161
log.Println("Unable to load connection location:", err)
62-
cl.Close()
62+
teeConn.Close()
6363
continue
6464
}
6565

6666
host, err := base64.StdEncoding.DecodeString(connectionLocation.Host)
6767
if err != nil {
6868
log.Println("Unable to decode connection location:", err)
69-
cl.Close()
69+
teeConn.Close()
7070
continue
7171
}
7272

7373
hostAddr := string(host)
7474

75-
logLine := fmt.Sprintf("Accepted connection from %s -> %s", cl.RemoteAddr().String(), cl.LocalAddr().String())
75+
logLine := fmt.Sprintf("Accepted connection from %s -> %s", teeConn.RemoteAddr().String(), teeConn.LocalAddr().String())
7676
log.Println(logLine)
7777

7878
if viper.GetBool("log-to-client") {
@@ -94,18 +94,11 @@ func (pL *proxyListener) Accept() (net.Conn, error) {
9494
conn, err := net.Dial("unix", hostAddr)
9595
if err != nil {
9696
log.Println("Error connecting to tcp balancer:", err)
97-
cl.Close()
98-
continue
99-
}
100-
101-
_, err = conn.Write(buf.Bytes())
102-
if err != nil {
103-
log.Println("Unable to write to conn:", err)
104-
cl.Close()
97+
teeConn.Close()
10598
continue
10699
}
107100

108-
go utils.CopyBoth(conn, cl)
101+
go utils.CopyBoth(conn, teeConn)
109102
}
110103
}
111104

utils/conn.go

Lines changed: 33 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package utils
22

33
import (
4+
"bufio"
45
"bytes"
56
"crypto/tls"
67
"io"
@@ -88,44 +89,22 @@ func (s *SSHConnection) CleanUp(state *State) {
8889

8990
// TeeConn represents a simple net.Conn interface for SNI Processing.
9091
type TeeConn struct {
91-
Conn net.Conn
92-
Reader io.Reader
93-
Buffer *bytes.Buffer
94-
FirstRead bool
95-
Flushed bool
92+
Conn net.Conn
93+
Buffer *bufio.ReadWriter
9694
}
9795

9896
// Read implements a reader ontop of the TeeReader.
9997
func (conn *TeeConn) Read(p []byte) (int, error) {
100-
if !conn.FirstRead {
101-
conn.FirstRead = true
102-
return conn.Reader.Read(p)
103-
}
104-
105-
if conn.FirstRead && !conn.Flushed {
106-
conn.Flushed = true
107-
copy(p[0:conn.Buffer.Len()], conn.Buffer.Bytes())
108-
return conn.Buffer.Len(), nil
109-
}
110-
111-
return conn.Conn.Read(p)
98+
return conn.Buffer.Read(p)
11299
}
113100

114101
// Write is a shim function to fit net.Conn.
115102
func (conn *TeeConn) Write(p []byte) (int, error) {
116-
if !conn.Flushed {
117-
return 0, io.ErrClosedPipe
118-
}
119-
120-
return conn.Conn.Write(p)
103+
return conn.Buffer.Write(p)
121104
}
122105

123106
// Close is a shim function to fit net.Conn.
124107
func (conn *TeeConn) Close() error {
125-
if !conn.Flushed {
126-
return nil
127-
}
128-
129108
return conn.Conn.Close()
130109
}
131110

@@ -145,22 +124,19 @@ func (conn *TeeConn) SetReadDeadline(t time.Time) error { return conn.Conn.SetRe
145124
func (conn *TeeConn) SetWriteDeadline(t time.Time) error { return conn.Conn.SetWriteDeadline(t) }
146125

147126
// GetBuffer returns the tee'd buffer.
148-
func (conn *TeeConn) GetBuffer() *bytes.Buffer { return conn.Buffer }
127+
func (conn *TeeConn) GetBuffer() *bufio.ReadWriter { return conn.Buffer }
149128

150129
func NewTeeConn(conn net.Conn) *TeeConn {
151130
teeConn := &TeeConn{
152-
Conn: conn,
153-
Buffer: bytes.NewBuffer([]byte{}),
154-
Flushed: false,
131+
Conn: conn,
132+
Buffer: bufio.NewReadWriter(bufio.NewReaderSize(conn, 8192), bufio.NewWriterSize(conn, 8192)),
155133
}
156134

157-
teeConn.Reader = io.TeeReader(conn, teeConn.Buffer)
158-
159135
return teeConn
160136
}
161137

162138
// PeekTLSHello peeks the TLS Connection Hello to proxy based on SNI.
163-
func PeekTLSHello(conn net.Conn) (*tls.ClientHelloInfo, *bytes.Buffer, *TeeConn, error) {
139+
func PeekTLSHello(conn net.Conn) (*tls.ClientHelloInfo, *TeeConn, error) {
164140
var tlsHello *tls.ClientHelloInfo
165141

166142
tlsConfig := &tls.Config{
@@ -172,11 +148,33 @@ func PeekTLSHello(conn net.Conn) (*tls.ClientHelloInfo, *bytes.Buffer, *TeeConn,
172148

173149
teeConn := NewTeeConn(conn)
174150

175-
err := tls.Server(teeConn, tlsConfig).Handshake()
151+
header, err := teeConn.GetBuffer().Peek(5)
152+
if err != nil {
153+
return tlsHello, teeConn, err
154+
}
155+
156+
if header[0] != 0x16 {
157+
return tlsHello, teeConn, err
158+
}
159+
160+
helloBytes, err := teeConn.GetBuffer().Peek(len(header) + (int(header[3])<<8 | int(header[4])))
161+
if err != nil {
162+
return tlsHello, teeConn, err
163+
}
164+
165+
err = tls.Server(bufConn{reader: bytes.NewReader(helloBytes)}, tlsConfig).Handshake()
176166

177-
return tlsHello, teeConn.GetBuffer(), teeConn, err
167+
return tlsHello, teeConn, err
178168
}
179169

170+
type bufConn struct {
171+
reader io.Reader
172+
net.Conn
173+
}
174+
175+
func (b bufConn) Read(p []byte) (int, error) { return b.reader.Read(p) }
176+
func (bufConn) Write(p []byte) (int, error) { return 0, io.EOF }
177+
180178
// IdleTimeoutConn handles the connection with a context deadline.
181179
// code adapted from https://qiita.com/kwi/items/b38d6273624ad3f6ae79
182180
type IdleTimeoutConn struct {

utils/state.go

Lines changed: 9 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
package utils
22

33
import (
4-
"bytes"
54
"encoding/base64"
65
"fmt"
76
"io"
@@ -94,19 +93,18 @@ func (tH *TCPHolder) Handle(state *State) {
9493
continue
9594
}
9695

97-
var firstWrite *bytes.Buffer
96+
realConn := cl
9897

9998
balancerName := ""
10099
if tH.SNIProxy {
101-
tlsHello, buf, _, err := PeekTLSHello(cl)
100+
tlsHello, realConn, err := PeekTLSHello(cl)
102101
if err != nil && tlsHello == nil {
103102
log.Printf("Unable to read TLS hello: %s", err)
104-
cl.Close()
103+
realConn.Close()
105104
continue
106105
}
107106

108107
balancerName = tlsHello.ServerName
109-
firstWrite = buf
110108
}
111109

112110
pB, ok := tH.Balancers.Load(balancerName)
@@ -121,7 +119,7 @@ func (tH *TCPHolder) Handle(state *State) {
121119

122120
if pB == nil {
123121
log.Printf("Unable to load connection location: %s not found on TCP listener %s", balancerName, tH.TCPHost)
124-
cl.Close()
122+
realConn.Close()
125123
continue
126124
}
127125
}
@@ -131,20 +129,20 @@ func (tH *TCPHolder) Handle(state *State) {
131129
connectionLocation, err := balancer.NextServer()
132130
if err != nil {
133131
log.Println("Unable to load connection location:", err)
134-
cl.Close()
132+
realConn.Close()
135133
continue
136134
}
137135

138136
host, err := base64.StdEncoding.DecodeString(connectionLocation.Host)
139137
if err != nil {
140138
log.Println("Unable to decode connection location:", err)
141-
cl.Close()
139+
realConn.Close()
142140
continue
143141
}
144142

145143
hostAddr := string(host)
146144

147-
logLine := fmt.Sprintf("Accepted connection from %s -> %s", cl.RemoteAddr().String(), cl.LocalAddr().String())
145+
logLine := fmt.Sprintf("Accepted connection from %s -> %s", realConn.RemoteAddr().String(), realConn.LocalAddr().String())
148146
log.Println(logLine)
149147

150148
if viper.GetBool("log-to-client") {
@@ -166,20 +164,11 @@ func (tH *TCPHolder) Handle(state *State) {
166164
conn, err := net.Dial("unix", hostAddr)
167165
if err != nil {
168166
log.Println("Error connecting to tcp balancer:", err)
169-
cl.Close()
167+
realConn.Close()
170168
continue
171169
}
172170

173-
if firstWrite != nil {
174-
_, err := conn.Write(firstWrite.Bytes())
175-
if err != nil {
176-
log.Println("Unable to write to conn:", err)
177-
cl.Close()
178-
continue
179-
}
180-
}
181-
182-
go CopyBoth(conn, cl)
171+
go CopyBoth(conn, realConn)
183172
}
184173
}
185174

0 commit comments

Comments
 (0)