gitlab-org--gitlab-foss/workhorse/internal/channel/wrappers_test.go

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

156 lines
4.2 KiB
Go
Raw Permalink Normal View History

package channel
import (
"bytes"
"errors"
"net"
"testing"
"time"
"github.com/gorilla/websocket"
)
// testcase pairs the frame handed to the code under test (input) with the
// frame the test expects to observe afterwards (expected).
type testcase struct {
	input, expected *fakeConn
}
// fakeConn is an in-memory stand-in for a WebSocket connection used by the
// wrapper tests. Reads hand back the stored frame verbatim; data and control
// writes overwrite the stored frame and return the preset err.
type fakeConn struct {
	mt   int    // WebSocket message type
	data []byte // frame payload (stored without copying)
	err  error  // error returned by every read and write
}

// ReadMessage returns the stored message type, payload and error unchanged.
func (c *fakeConn) ReadMessage() (int, []byte, error) {
	return c.mt, c.data, c.err
}

// WriteMessage records the frame so the test can inspect it afterwards and
// returns the preset error.
func (c *fakeConn) WriteMessage(messageType int, payload []byte) error {
	c.mt, c.data = messageType, payload
	return c.err
}

// WriteControl behaves exactly like WriteMessage; the deadline is ignored.
func (c *fakeConn) WriteControl(messageType int, payload []byte, _ time.Time) error {
	return c.WriteMessage(messageType, payload)
}

// UnderlyingConn always returns nil: there is no real network connection
// behind the fake.
func (c *fakeConn) UnderlyingConn() net.Conn {
	return nil
}
// fake builds a fakeConn holding the given message type, payload and error.
// The payload slice is stored as-is (not copied); data is already a []byte,
// so no conversion is needed.
func fake(mt int, data []byte, err error) *fakeConn {
	return &fakeConn{mt: mt, data: data, err: err}
}
// Payload fixtures shared by the read and write tests.
var (
	// msg is the plain payload; msgBase64 is the same payload base64-encoded.
	msg       = []byte("foo bar")
	msgBase64 = []byte("Zm9vIGJhcg==")

	// kubeMsg and kubeMsgBase64 are the channel.k8s.io and
	// base64.channel.k8s.io framings of the payload: a single leading byte
	// (0x00 raw, ASCII '0' in the base64 form) followed by the payload.
	// NOTE(review): the leading byte is presumably the k8s channel number.
	kubeMsg       = append([]byte{0}, msg...)
	kubeMsgBase64 = append([]byte{'0'}, msgBase64...)
)

// errFake is the sentinel error the fakes are primed with.
var errFake = errors.New("fake error")

// Message-type shorthands; other is a type no wrapper is expected to alter,
// and fakeOther is a frame of that type.
var (
	text      = websocket.TextMessage
	binary    = websocket.BinaryMessage
	other     = 999
	fakeOther = fake(other, []byte("foo"), nil)
)
// requireEqualConn fails the test immediately if actual differs from expected
// in message type, payload bytes, or error (errors are compared by identity).
// It first logs which field differed, then calls t.Fatalf(format, args...),
// so format/args should identify the failing case.
func requireEqualConn(t *testing.T, expected, actual *fakeConn, format string, args ...interface{}) {
	// Report failures at the caller's line rather than inside this helper.
	t.Helper()

	switch {
	case expected.mt != actual.mt:
		t.Logf("messageType expected to be %v but was %v", expected.mt, actual.mt)
	case !bytes.Equal(expected.data, actual.data):
		t.Logf("data expected to be %q but was %q", expected.data, actual.data)
	case expected.err != actual.err:
		t.Logf("error expected to be %v but was %v", expected.err, actual.err)
	default:
		return
	}
	t.Fatalf(format, args...)
}
// TestReadMessage checks, for each supported subprotocol, that ReadMessage on
// a wrapped connection turns the underlying frame into the expected
// (message type, payload, error) triple. Per the table: a frame read together
// with an error is returned untouched, successfully decoded frames are
// reported as binary, and frames of an unknown type pass through unchanged.
// Each subprotocol runs as its own subtest so a failure in one does not stop
// the others from being checked.
func TestReadMessage(t *testing.T) {
	testCases := map[string][]testcase{
		"channel.k8s.io": {
			{fake(binary, kubeMsg, errFake), fake(binary, kubeMsg, errFake)},
			{fake(binary, kubeMsg, nil), fake(binary, msg, nil)},
			{fake(text, kubeMsg, nil), fake(binary, msg, nil)},
			{fakeOther, fakeOther},
		},
		"base64.channel.k8s.io": {
			{fake(text, kubeMsgBase64, errFake), fake(text, kubeMsgBase64, errFake)},
			{fake(text, kubeMsgBase64, nil), fake(binary, msg, nil)},
			{fake(binary, kubeMsgBase64, nil), fake(binary, msg, nil)},
			{fakeOther, fakeOther},
		},
		"terminal.gitlab.com": {
			{fake(binary, msg, errFake), fake(binary, msg, errFake)},
			{fake(binary, msg, nil), fake(binary, msg, nil)},
			{fake(text, msg, nil), fake(binary, msg, nil)},
			{fakeOther, fakeOther},
		},
		"base64.terminal.gitlab.com": {
			{fake(text, msgBase64, errFake), fake(text, msgBase64, errFake)},
			{fake(text, msgBase64, nil), fake(binary, msg, nil)},
			{fake(binary, msgBase64, nil), fake(binary, msg, nil)},
			{fakeOther, fakeOther},
		},
	}

	for subprotocol, cases := range testCases {
		// t.Run blocks until the subtest returns (no t.Parallel), so the
		// closure safely uses this iteration's subprotocol and cases.
		t.Run(subprotocol, func(t *testing.T) {
			for i, tc := range cases {
				conn := Wrap(tc.input, subprotocol)
				mt, data, err := conn.ReadMessage()
				actual := fake(mt, data, err)
				requireEqualConn(t, tc.expected, actual, "%s test case %v", subprotocol, i)
			}
		})
	}
}
// TestWriteMessage checks, for each supported subprotocol, that WriteMessage
// on a wrapped connection hands the expected (message type, payload) to the
// underlying connection and returns that connection's error. Per the table:
// the encoded frame is still written when the underlying write fails, and
// frames of an unknown type pass through unchanged. Each subprotocol runs as
// its own subtest so a failure in one does not hide the others.
func TestWriteMessage(t *testing.T) {
	testCases := map[string][]testcase{
		"channel.k8s.io": {
			{fake(binary, msg, errFake), fake(binary, kubeMsg, errFake)},
			{fake(binary, msg, nil), fake(binary, kubeMsg, nil)},
			{fake(text, msg, nil), fake(binary, kubeMsg, nil)},
			{fakeOther, fakeOther},
		},
		"base64.channel.k8s.io": {
			{fake(binary, msg, errFake), fake(text, kubeMsgBase64, errFake)},
			{fake(binary, msg, nil), fake(text, kubeMsgBase64, nil)},
			{fake(text, msg, nil), fake(text, kubeMsgBase64, nil)},
			{fakeOther, fakeOther},
		},
		"terminal.gitlab.com": {
			{fake(binary, msg, errFake), fake(binary, msg, errFake)},
			{fake(binary, msg, nil), fake(binary, msg, nil)},
			{fake(text, msg, nil), fake(binary, msg, nil)},
			{fakeOther, fakeOther},
		},
		"base64.terminal.gitlab.com": {
			{fake(binary, msg, errFake), fake(text, msgBase64, errFake)},
			{fake(binary, msg, nil), fake(text, msgBase64, nil)},
			{fake(text, msg, nil), fake(text, msgBase64, nil)},
			{fakeOther, fakeOther},
		},
	}

	for subprotocol, cases := range testCases {
		// t.Run blocks until the subtest returns (no t.Parallel), so the
		// closure safely uses this iteration's subprotocol and cases.
		t.Run(subprotocol, func(t *testing.T) {
			for i, tc := range cases {
				// The underlying fake starts empty but is primed with the
				// error this case wants the write to fail with.
				actual := fake(0, nil, tc.input.err)
				conn := Wrap(actual, subprotocol)
				// Record what the wrapper returned so it is compared too.
				actual.err = conn.WriteMessage(tc.input.mt, tc.input.data)
				requireEqualConn(t, tc.expected, actual, "%s test case %v", subprotocol, i)
			}
		})
	}
}