package httpexpect import ( "encoding/json" "time" "github.com/gorilla/websocket" ) const noDuration = time.Duration(0) var infiniteTime = time.Time{} // Websocket provides methods to read from, write into and close WebSocket // connection. type Websocket struct { config Config chain chain conn *websocket.Conn readTimeout time.Duration writeTimeout time.Duration isClosed bool } // NewWebsocket returns a new Websocket given a Config with Reporter and // Printers, and websocket.Conn to be inspected and handled. func NewWebsocket(config Config, conn *websocket.Conn) *Websocket { return makeWebsocket(config, makeChain(config.Reporter), conn) } func makeWebsocket(config Config, chain chain, conn *websocket.Conn) *Websocket { return &Websocket{ config: config, chain: chain, conn: conn, } } // Raw returns underlying websocket.Conn object. // This is the value originally passed to NewConnection. func (c *Websocket) Raw() *websocket.Conn { return c.conn } // WithReadTimeout sets timeout duration for WebSocket connection reads. // // By default no timeout is used. func (c *Websocket) WithReadTimeout(timeout time.Duration) *Websocket { c.readTimeout = timeout return c } // WithoutReadTimeout removes timeout for WebSocket connection reads. func (c *Websocket) WithoutReadTimeout() *Websocket { c.readTimeout = noDuration return c } // WithWriteTimeout sets timeout duration for WebSocket connection writes. // // By default no timeout is used. func (c *Websocket) WithWriteTimeout(timeout time.Duration) *Websocket { c.writeTimeout = timeout return c } // WithoutWriteTimeout removes timeout for WebSocket connection writes. // // If not used then DefaultWebsocketTimeout will be used. func (c *Websocket) WithoutWriteTimeout() *Websocket { c.writeTimeout = noDuration return c } // Subprotocol returns a new String object that may be used to inspect // negotiated protocol for the connection. func (c *Websocket) Subprotocol() *String { s := &String{chain: c.chain} if c.conn != nil { s.value = c.conn.Subprotocol() } return s } // Expect reads next message from WebSocket connection and // returns a new WebsocketMessage object to inspect received message. // // Example: // msg := conn.Expect() // msg.JSON().Object().ValueEqual("message", "hi") func (c *Websocket) Expect() *WebsocketMessage { switch { case c.chain.failed(): return makeWebsocketMessage(c.chain) case c.conn == nil: c.chain.fail("\nunexpected read from failed WebSocket connection") return makeWebsocketMessage(c.chain) case c.isClosed: c.chain.fail("\nunexpected read from closed WebSocket connection") return makeWebsocketMessage(c.chain) case !c.setReadDeadline(): return makeWebsocketMessage(c.chain) } var err error m := makeWebsocketMessage(c.chain) m.typ, m.content, err = c.conn.ReadMessage() if err != nil { if cls, ok := err.(*websocket.CloseError); ok { m.typ = websocket.CloseMessage m.closeCode = cls.Code m.content = []byte(cls.Text) c.printRead(m.typ, m.content, m.closeCode) } else { c.chain.fail( "\nexpected read WebSocket connection, "+ "but got failure: %s", err.Error()) return makeWebsocketMessage(c.chain) } } else { c.printRead(m.typ, m.content, m.closeCode) } return m } func (c *Websocket) setReadDeadline() bool { deadline := infiniteTime if c.readTimeout != noDuration { deadline = time.Now().Add(c.readTimeout) } if err := c.conn.SetReadDeadline(deadline); err != nil { c.chain.fail( "\nunexpected failure when setting "+ "read WebSocket connection deadline: %s", err.Error()) return false } return true } func (c *Websocket) printRead(typ int, content []byte, closeCode int) { for _, printer := range c.config.Printers { if p, ok := printer.(WebsocketPrinter); ok { p.WebsocketRead(typ, content, closeCode) } } } // Disconnect closes the underlying WebSocket connection without sending or // waiting for a close message. // // It's okay to call this function multiple times. // // It's recommended to always call this function after connection usage is over // to ensure that no resource leaks will happen. // // Example: // conn := resp.Connection() // defer conn.Disconnect() func (c *Websocket) Disconnect() *Websocket { if c.conn == nil || c.isClosed { return c } c.isClosed = true if err := c.conn.Close(); err != nil { c.chain.fail("close error when disconnecting webcoket: " + err.Error()) } return c } // Close cleanly closes the underlying WebSocket connection // by sending an empty close message and then waiting (with timeout) // for the server to close the connection. // // WebSocket close code may be optionally specified. // If not, then "1000 - Normal Closure" will be used. // // WebSocket close codes are defined in RFC 6455, section 11.7. // See also https://godoc.org/github.com/gorilla/websocket#pkg-constants // // It's okay to call this function multiple times. // // Example: // conn := resp.Connection() // conn.Close(websocket.CloseUnsupportedData) func (c *Websocket) Close(code ...int) *Websocket { switch { case c.checkUnusable("Close"): return c case len(code) > 1: c.chain.fail("\nunexpected multiple code arguments passed to Close") return c } return c.CloseWithBytes(nil, code...) } // CloseWithBytes cleanly closes the underlying WebSocket connection // by sending given slice of bytes as a close message and then waiting // (with timeout) for the server to close the connection. // // WebSocket close code may be optionally specified. // If not, then "1000 - Normal Closure" will be used. // // WebSocket close codes are defined in RFC 6455, section 11.7. // See also https://godoc.org/github.com/gorilla/websocket#pkg-constants // // It's okay to call this function multiple times. // // Example: // conn := resp.Connection() // conn.CloseWithBytes([]byte("bye!"), websocket.CloseGoingAway) func (c *Websocket) CloseWithBytes(b []byte, code ...int) *Websocket { switch { case c.checkUnusable("CloseWithBytes"): return c case len(code) > 1: c.chain.fail( "\nunexpected multiple code arguments passed to CloseWithBytes") return c } c.WriteMessage(websocket.CloseMessage, b, code...) return c } // CloseWithJSON cleanly closes the underlying WebSocket connection // by sending given object (marshaled using json.Marshal()) as a close message // and then waiting (with timeout) for the server to close the connection. // // WebSocket close code may be optionally specified. // If not, then "1000 - Normal Closure" will be used. // // WebSocket close codes are defined in RFC 6455, section 11.7. // See also https://godoc.org/github.com/gorilla/websocket#pkg-constants // // It's okay to call this function multiple times. // // Example: // type MyJSON struct { // Foo int `json:"foo"` // } // // conn := resp.Connection() // conn.CloseWithJSON(MyJSON{Foo: 123}, websocket.CloseUnsupportedData) func (c *Websocket) CloseWithJSON( object interface{}, code ...int, ) *Websocket { switch { case c.checkUnusable("CloseWithJSON"): return c case len(code) > 1: c.chain.fail( "\nunexpected multiple code arguments passed to CloseWithJSON") return c } b, err := json.Marshal(object) if err != nil { c.chain.fail(err.Error()) return c } return c.CloseWithBytes(b, code...) } // CloseWithText cleanly closes the underlying WebSocket connection // by sending given text as a close message and then waiting (with timeout) // for the server to close the connection. // // WebSocket close code may be optionally specified. // If not, then "1000 - Normal Closure" will be used. // // WebSocket close codes are defined in RFC 6455, section 11.7. // See also https://godoc.org/github.com/gorilla/websocket#pkg-constants // // It's okay to call this function multiple times. // // Example: // conn := resp.Connection() // conn.CloseWithText("bye!") func (c *Websocket) CloseWithText(s string, code ...int) *Websocket { switch { case c.checkUnusable("CloseWithText"): return c case len(code) > 1: c.chain.fail( "\nunexpected multiple code arguments passed to CloseWithText") return c } return c.CloseWithBytes([]byte(s), code...) } // WriteMessage writes to the underlying WebSocket connection a message // of given type with given content. // Additionally, WebSocket close code may be specified for close messages. // // WebSocket message types are defined in RFC 6455, section 11.8. // See also https://godoc.org/github.com/gorilla/websocket#pkg-constants // // WebSocket close codes are defined in RFC 6455, section 11.7. // See also https://godoc.org/github.com/gorilla/websocket#pkg-constants // // Example: // conn := resp.Connection() // conn.WriteMessage(websocket.CloseMessage, []byte("Namárië...")) func (c *Websocket) WriteMessage( typ int, content []byte, closeCode ...int, ) *Websocket { if c.checkUnusable("WriteMessage") { return c } switch typ { case websocket.TextMessage, websocket.BinaryMessage: c.printWrite(typ, content, 0) case websocket.CloseMessage: if len(closeCode) > 1 { c.chain.fail("\nunexpected multiple closeCode arguments " + "passed to WriteMessage") return c } code := websocket.CloseNormalClosure if len(closeCode) > 0 { code = closeCode[0] } c.printWrite(typ, content, code) content = websocket.FormatCloseMessage(code, string(content)) default: c.chain.fail("\nunexpected WebSocket message type '%s' "+ "passed to WriteMessage", wsMessageTypeName(typ)) return c } if !c.setWriteDeadline() { return c } if err := c.conn.WriteMessage(typ, content); err != nil { c.chain.fail( "\nexpected write into WebSocket connection, "+ "but got failure: %s", err.Error()) } return c } // WriteBytesBinary is a shorthand for c.WriteMessage(websocket.BinaryMessage, b). func (c *Websocket) WriteBytesBinary(b []byte) *Websocket { if c.checkUnusable("WriteBytesBinary") { return c } return c.WriteMessage(websocket.BinaryMessage, b) } // WriteBytesText is a shorthand for c.WriteMessage(websocket.TextMessage, b). func (c *Websocket) WriteBytesText(b []byte) *Websocket { if c.checkUnusable("WriteBytesText") { return c } return c.WriteMessage(websocket.TextMessage, b) } // WriteText is a shorthand for // c.WriteMessage(websocket.TextMessage, []byte(s)). func (c *Websocket) WriteText(s string) *Websocket { if c.checkUnusable("WriteText") { return c } return c.WriteMessage(websocket.TextMessage, []byte(s)) } // WriteJSON writes to the underlying WebSocket connection given object, // marshaled using json.Marshal(). func (c *Websocket) WriteJSON(object interface{}) *Websocket { if c.checkUnusable("WriteJSON") { return c } b, err := json.Marshal(object) if err != nil { c.chain.fail(err.Error()) return c } return c.WriteMessage(websocket.TextMessage, b) } func (c *Websocket) checkUnusable(where string) bool { switch { case c.chain.failed(): return true case c.conn == nil: c.chain.fail("\nunexpected %s call for failed WebSocket connection", where) return true case c.isClosed: c.chain.fail("\nunexpected %s call for closed WebSocket connection", where) return true } return false } func (c *Websocket) setWriteDeadline() bool { deadline := infiniteTime if c.writeTimeout != noDuration { deadline = time.Now().Add(c.writeTimeout) } if err := c.conn.SetWriteDeadline(deadline); err != nil { c.chain.fail( "\nunexpected failure when setting "+ "write WebSocket connection deadline: %s", err.Error()) return false } return true } func (c *Websocket) printWrite(typ int, content []byte, closeCode int) { for _, printer := range c.config.Printers { if p, ok := printer.(WebsocketPrinter); ok { p.WebsocketWrite(typ, content, closeCode) } } }