websocket_message.go 7.8 KB
package httpexpect

import (
	"encoding/json"

	"github.com/gorilla/websocket"
)

// WebsocketMessage provides methods to inspect message read from WebSocket connection.
//
// It bundles the message type, the raw payload and the close code together
// with the chain that the assertion methods use to check for and record
// failures.
type WebsocketMessage struct {
	// chain is consulted via failed() before each check and receives
	// failure reports via fail().
	chain     chain
	// typ is the message type; it is compared against the
	// websocket.*Message constants.
	typ       int
	// content is the raw message payload.
	content   []byte
	// closeCode is the close code; Code and NotCode only inspect it when
	// typ is websocket.CloseMessage.
	closeCode int
}

// NewWebsocketMessage builds a WebsocketMessage from a reporter, which is used
// to report failures, and the raw message parameters to be inspected.
//
// reporter should not be nil.
//
// closeCode is optional. When one or more values are given, only the first
// one is stored; when it is omitted, the close code is left at zero.
// content is stored as passed, without copying.
//
// Example:
//   m := NewWebsocketMessage(reporter, websocket.TextMessage, []byte("content"), 0)
//   m.TextMessage()
func NewWebsocketMessage(
	reporter Reporter, typ int, content []byte, closeCode ...int,
) *WebsocketMessage {
	// Resolve the optional close code up front so the message can be
	// constructed in a single literal.
	code := 0
	if len(closeCode) > 0 {
		code = closeCode[0]
	}
	return &WebsocketMessage{
		chain:     makeChain(reporter),
		typ:       typ,
		content:   content,
		closeCode: code,
	}
}

// makeWebsocketMessage returns a WebsocketMessage bound to the given chain.
// Type, content and close code are left at their zero values.
func makeWebsocketMessage(chain chain) *WebsocketMessage {
	msg := new(WebsocketMessage)
	msg.chain = chain
	return msg
}

// Raw returns the underlying type, content and close code of the WebSocket
// message, exactly as they are stored in the message.
//
// The content slice is returned as is, without copying.
func (m *WebsocketMessage) Raw() (typ int, content []byte, closeCode int) {
	typ = m.typ
	content = m.content
	closeCode = m.closeCode
	return typ, content, closeCode
}

// CloseMessage succeeds if the message is a close message.
//
// It is a shorthand for m.Type(websocket.CloseMessage).
func (m *WebsocketMessage) CloseMessage() *WebsocketMessage {
	const kind = websocket.CloseMessage
	return m.Type(kind)
}

// NotCloseMessage succeeds if the message is not a close message.
//
// It is a shorthand for m.NotType(websocket.CloseMessage).
func (m *WebsocketMessage) NotCloseMessage() *WebsocketMessage {
	const kind = websocket.CloseMessage
	return m.NotType(kind)
}

// BinaryMessage succeeds if the message is a binary message.
//
// It is a shorthand for m.Type(websocket.BinaryMessage).
func (m *WebsocketMessage) BinaryMessage() *WebsocketMessage {
	const kind = websocket.BinaryMessage
	return m.Type(kind)
}

// NotBinaryMessage succeeds if the message is not a binary message.
//
// It is a shorthand for m.NotType(websocket.BinaryMessage).
func (m *WebsocketMessage) NotBinaryMessage() *WebsocketMessage {
	const kind = websocket.BinaryMessage
	return m.NotType(kind)
}

// TextMessage succeeds if the message is a text message.
//
// It is a shorthand for m.Type(websocket.TextMessage).
func (m *WebsocketMessage) TextMessage() *WebsocketMessage {
	const kind = websocket.TextMessage
	return m.Type(kind)
}

// NotTextMessage succeeds if the message is not a text message.
//
// It is a shorthand for m.NotType(websocket.TextMessage).
func (m *WebsocketMessage) NotTextMessage() *WebsocketMessage {
	const kind = websocket.TextMessage
	return m.NotType(kind)
}

// Type succeeds if the WebSocket message type equals any of the given types.
//
// If the chain has already failed, Type does nothing. Calling it without
// arguments is reported as a failure. In all cases the receiver is returned
// so that calls can be chained.
//
// WebSocket message types are defined in RFC 6455, section 11.8.
// See also https://godoc.org/github.com/gorilla/websocket#pkg-constants
//
// Example:
//  msg := conn.Expect()
//  msg.Type(websocket.TextMessage, websocket.BinaryMessage)
func (m *WebsocketMessage) Type(typ ...int) *WebsocketMessage {
	if m.chain.failed() {
		return m
	}
	if len(typ) == 0 {
		m.chain.fail("\nunexpected nil argument passed to Type")
		return m
	}

	// Any match means success; nothing to report.
	for _, want := range typ {
		if want == m.typ {
			return m
		}
	}

	// No match: pick the message wording by how many types were allowed.
	if len(typ) == 1 {
		m.chain.fail(
			"\nexpected message type:\n %d\n\nbut got:\n %d",
			typ[0], m.typ)
		return m
	}
	m.chain.fail(
		"\nexpected message type equal to one of:\n %v\n\nbut got:\n %d",
		typ, m.typ)
	return m
}

// NotType succeeds if WebSocket message type is none of the given.
//
// If the chain has already failed, NotType does nothing. Calling it without
// arguments is reported as a failure. In all cases the receiver is returned
// so that calls can be chained.
//
// WebSocket message types are defined in RFC 6455, section 11.8.
// See also https://godoc.org/github.com/gorilla/websocket#pkg-constants
//
// Example:
//  msg := conn.Expect()
//  msg.NotType(websocket.CloseMessage, websocket.BinaryMessage)
func (m *WebsocketMessage) NotType(typ ...int) *WebsocketMessage {
	switch {
	case m.chain.failed():
		return m
	case len(typ) == 0:
		m.chain.fail("\nunexpected nil argument passed to NotType")
		return m
	}
	for _, t := range typ {
		if t != m.typ {
			continue
		}
		// The message type is one of the forbidden ones: report once and stop.
		if len(typ) > 1 {
			m.chain.fail(
				"\nexpected message type not equal:\n %v\n\nbut got:\n %d",
				typ, m.typ)
		} else {
			// The format has a single verb, so pass a single argument;
			// an extra one would be rendered as "%!(EXTRA int=...)".
			m.chain.fail(
				"\nexpected message type not equal:\n %d\n\nbut it did",
				typ[0])
		}
		return m
	}
	return m
}

// Code succeeds if the WebSocket close code equals any of the given codes.
//
// Code fails if WebSocket message type is not "8 - Connection Close Frame".
//
// If the chain has already failed, Code does nothing. Calling it without
// arguments is reported as a failure. In all cases the receiver is returned
// so that calls can be chained.
//
// WebSocket close codes are defined in RFC 6455, section 11.7.
// See also https://godoc.org/github.com/gorilla/websocket#pkg-constants
//
// Example:
//  msg := conn.Expect().Closed()
//  msg.Code(websocket.CloseNormalClosure, websocket.CloseGoingAway)
func (m *WebsocketMessage) Code(code ...int) *WebsocketMessage {
	if m.chain.failed() {
		return m
	}
	if len(code) == 0 {
		m.chain.fail("\nunexpected nil argument passed to Code")
		return m
	}
	// checkClosed records the failure itself when the message is not a
	// close message.
	if m.checkClosed("Code") {
		return m
	}

	// Any match means success; nothing to report.
	for _, want := range code {
		if want == m.closeCode {
			return m
		}
	}

	// No match: pick the message wording by how many codes were allowed.
	if len(code) == 1 {
		m.chain.fail(
			"\nexpected close code:\n %d\n\nbut got:\n %d",
			code[0], m.closeCode)
		return m
	}
	m.chain.fail(
		"\nexpected close code equal to one of:\n %v\n\nbut got:\n %d",
		code, m.closeCode)
	return m
}

// NotCode succeeds if WebSocket close code is none of the given.
//
// NotCode fails if WebSocket message type is not "8 - Connection Close Frame".
//
// If the chain has already failed, NotCode does nothing. Calling it without
// arguments is reported as a failure. In all cases the receiver is returned
// so that calls can be chained.
//
// WebSocket close codes are defined in RFC 6455, section 11.7.
// See also https://godoc.org/github.com/gorilla/websocket#pkg-constants
//
// Example:
//  msg := conn.Expect().Closed()
//  msg.NotCode(websocket.CloseAbnormalClosure, websocket.CloseNoStatusReceived)
func (m *WebsocketMessage) NotCode(code ...int) *WebsocketMessage {
	switch {
	case m.chain.failed():
		return m
	case len(code) == 0:
		// Name the method the caller actually invoked.
		m.chain.fail("\nunexpected nil argument passed to NotCode")
		return m
	case m.checkClosed("NotCode"):
		// checkClosed has already recorded the failure.
		return m
	}
	for _, c := range code {
		if c != m.closeCode {
			continue
		}
		// The close code is one of the forbidden ones: report once and stop.
		if len(code) > 1 {
			m.chain.fail(
				"\nexpected close code not equal:\n %v\n\nbut got:\n %d",
				code, m.closeCode)
		} else {
			// The format has a single verb, so pass a single argument;
			// an extra one would be rendered as "%!(EXTRA int=...)".
			m.chain.fail(
				"\nexpected close code not equal:\n %d\n\nbut it did",
				code[0])
		}
		return m
	}
	return m
}

// checkClosed verifies that the message is a close message.
//
// where is the name of the calling method and is included in the failure
// message. checkClosed returns true, after recording a failure on the chain,
// when the message is NOT a close message; it returns false when the message
// is a close message and the caller may proceed.
func (m *WebsocketMessage) checkClosed(where string) bool {
	if m.typ == websocket.CloseMessage {
		return false
	}
	// Three %s verbs for three arguments: the calling method, the required
	// message type name and the actual message type name.
	m.chain.fail(
		"\nunexpected %s usage for not '%s' WebSocket message type\n\n"+
			"got type:\n %s",
		where,
		wsMessageTypeName(websocket.CloseMessage),
		wsMessageTypeName(m.typ))
	return true
}

// Body returns a new String object that may be used to inspect
// WebSocket message content.
//
// The content bytes are converted to a string regardless of the message type.
//
// Example:
//  msg := conn.Expect()
//  msg.Body().NotEmpty()
//  msg.Body().Length().Equal(100)
func (m *WebsocketMessage) Body() *String {
	text := string(m.content)
	return &String{m.chain, text}
}

// NoContent succeeds if WebSocket message has no content (is empty).
//
// If the chain has already failed, NoContent does nothing. On failure, binary
// messages are reported by their size in bytes, and all other message types by
// their content rendered as a string.
func (m *WebsocketMessage) NoContent() *WebsocketMessage {
	if m.chain.failed() || len(m.content) == 0 {
		return m
	}

	if m.typ == websocket.BinaryMessage {
		m.chain.fail(
			"\nexpected message body being empty, but got:\n %d bytes",
			len(m.content))
		return m
	}

	m.chain.fail(
		"\nexpected message body being empty, but got:\n %s",
		string(m.content))
	return m
}

// JSON returns a new Value object that may be used to inspect JSON contents
// of WebSocket message.
//
// JSON succeeds if JSON may be decoded from message content. If decoding
// fails, the failure is recorded on the chain and the returned Value wraps nil.
//
// Example:
//  msg := conn.Expect()
//  msg.JSON().Array().Elements("foo", "bar")
func (m *WebsocketMessage) JSON() *Value {
	// Decode first: getJSON may record a failure on m.chain, and Go does not
	// specify whether a plain read of m.chain inside a composite literal
	// happens before or after a function call in the same literal. Reading
	// m.chain only after the call guarantees the returned Value sees the
	// chain state left by decoding.
	decoded := m.getJSON()
	return &Value{m.chain, decoded}
}

// getJSON decodes the message content as JSON and returns the decoded value.
//
// It returns nil without attempting to decode if the chain has already
// failed. If decoding fails, the error is recorded on the chain and nil is
// returned.
func (m *WebsocketMessage) getJSON() interface{} {
	if m.chain.failed() {
		return nil
	}

	var value interface{}
	if err := json.Unmarshal(m.content, &value); err != nil {
		// fail takes a format string; decode errors can quote arbitrary
		// input characters (including '%'), so pass the text as an argument
		// rather than as the format itself.
		m.chain.fail("%s", err.Error())
		return nil
	}

	return value
}

// wsMessageTypeName returns a short lowercase name ("text", "binary",
// "close", "ping" or "pong") for the given WebSocket message type, or
// "unknown" if the type matches none of the websocket.*Message constants.
func wsMessageTypeName(typ int) string {
	name := "unknown"
	switch typ {
	case websocket.TextMessage:
		name = "text"
	case websocket.BinaryMessage:
		name = "binary"
	case websocket.CloseMessage:
		name = "close"
	case websocket.PingMessage:
		name = "ping"
	case websocket.PongMessage:
		name = "pong"
	}
	return name
}