websocket_dialer.go 2.5 KB
package httpexpect

import (
	"bufio"
	"net"
	"net/http"
	"net/http/httptest"
	"sync"

	"github.com/gorilla/websocket"
	"github.com/valyala/fasthttp"
)

// NewWebsocketDialer produces new websocket.Dialer which dials to bound
// http.Handler without creating a real net.Conn.
func NewWebsocketDialer(handler http.Handler) *websocket.Dialer {
	return &websocket.Dialer{
		NetDial: func(network, addr string) (net.Conn, error) {
			hc := newHandlerConn()
			hc.runHandler(handler)
			return hc, nil
		},
	}
}

// NewFastWebsocketDialer produces new websocket.Dialer which dials to bound
// fasthttp.RequestHandler without creating a real net.Conn.
func NewFastWebsocketDialer(handler fasthttp.RequestHandler) *websocket.Dialer {
	return &websocket.Dialer{
		NetDial: func(network, addr string) (net.Conn, error) {
			hc := newHandlerConn()
			hc.runFastHandler(handler)
			return hc, nil
		},
	}
}

type handlerConn struct {
	net.Conn          // returned from dialer
	backConn net.Conn // passed to the background goroutine

	wg sync.WaitGroup
}

func newHandlerConn() *handlerConn {
	dialConn, backConn := net.Pipe()

	return &handlerConn{
		Conn:     dialConn,
		backConn: backConn,
	}
}

func (hc *handlerConn) Close() error {
	err := hc.Conn.Close()
	hc.wg.Wait() // wait the background goroutine
	return err
}

func (hc *handlerConn) runHandler(handler http.Handler) {
	hc.wg.Add(1)

	go func() {
		defer hc.wg.Done()

		recorder := &hijackRecorder{conn: hc.backConn}

		for {
			req, err := http.ReadRequest(bufio.NewReader(hc.backConn))
			if err != nil {
				return
			}
			handler.ServeHTTP(recorder, req)
		}
	}()
}

func (hc *handlerConn) runFastHandler(handler fasthttp.RequestHandler) {
	hc.wg.Add(1)

	go func() {
		defer hc.wg.Done()

		_ = fasthttp.ServeConn(hc.backConn, handler)
	}()
}

// hijackRecorder it similar to httptest.ResponseRecorder,
// but with Hijack capabilities.
//
// Original idea is stolen from https://github.com/posener/wstest
type hijackRecorder struct {
	httptest.ResponseRecorder
	conn net.Conn
}

// Hijack the connection for caller.
//
// Implements http.Hijacker interface.
func (r *hijackRecorder) Hijack() (net.Conn, *bufio.ReadWriter, error) {
	rw := bufio.NewReadWriter(bufio.NewReader(r.conn), bufio.NewWriter(r.conn))
	return r.conn, rw, nil
}

// WriteHeader write HTTP header to the client and closes the connection
//
// Implements http.ResponseWriter interface.
func (r *hijackRecorder) WriteHeader(code int) {
	resp := http.Response{StatusCode: code, Header: r.Header()}
	_ = resp.Write(r.conn)
}