gitlab-org--gitlab-foss/workhorse/channel_test.go

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

246 lines
6.8 KiB
Go
Raw Normal View History

package main
import (
"bytes"
"encoding/pem"
"fmt"
"net"
"net/http"
"net/http/httptest"
"net/url"
"path"
"strings"
"testing"
"time"
"github.com/gorilla/websocket"
"github.com/stretchr/testify/require"
"gitlab.com/gitlab-org/labkit/log"
"gitlab.com/gitlab-org/gitlab/workhorse/internal/api"
)
var (
envTerminalPath = fmt.Sprintf("%s/-/environments/1/terminal.ws", testProject)
jobTerminalPath = fmt.Sprintf("%s/-/jobs/1/terminal.ws", testProject)
servicesProxyWSPath = fmt.Sprintf("%s/-/jobs/1/proxy.ws", testProject)
)
type connWithReq struct {
conn *websocket.Conn
req *http.Request
}
func TestChannelHappyPath(t *testing.T) {
tests := []struct {
name string
channelPath string
}{
{"environments", envTerminalPath},
{"jobs", jobTerminalPath},
{"services", servicesProxyWSPath},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
serverConns, clientURL, close := wireupChannel(t, test.channelPath, nil, "channel.k8s.io")
defer close()
client, _, err := dialWebsocket(clientURL, nil, "terminal.gitlab.com")
require.NoError(t, err)
server := (<-serverConns).conn
defer server.Close()
message := "test message"
// channel.k8s.io: server writes to channel 1, STDOUT
require.NoError(t, say(server, "\x01"+message))
requireReadMessage(t, client, websocket.BinaryMessage, message)
require.NoError(t, say(client, message))
// channel.k8s.io: client writes get put on channel 0, STDIN
requireReadMessage(t, server, websocket.BinaryMessage, "\x00"+message)
// Closing the client should send an EOT signal to the server's STDIN
client.Close()
requireReadMessage(t, server, websocket.BinaryMessage, "\x00\x04")
})
}
}
func TestChannelBadTLS(t *testing.T) {
_, clientURL, close := wireupChannel(t, envTerminalPath, badCA, "channel.k8s.io")
defer close()
_, _, err := dialWebsocket(clientURL, nil, "terminal.gitlab.com")
require.Equal(t, websocket.ErrBadHandshake, err, "unexpected error %v", err)
}
func TestChannelSessionTimeout(t *testing.T) {
serverConns, clientURL, close := wireupChannel(t, envTerminalPath, timeout, "channel.k8s.io")
defer close()
client, _, err := dialWebsocket(clientURL, nil, "terminal.gitlab.com")
require.NoError(t, err)
sc := <-serverConns
defer sc.conn.Close()
client.SetReadDeadline(time.Now().Add(time.Duration(2) * time.Second))
_, _, err = client.ReadMessage()
require.True(t, websocket.IsCloseError(err, websocket.CloseAbnormalClosure), "Client connection was not closed, got %v", err)
}
func TestChannelProxyForwardsHeadersFromUpstream(t *testing.T) {
hdr := make(http.Header)
hdr.Set("Random-Header", "Value")
serverConns, clientURL, close := wireupChannel(t, envTerminalPath, setHeader(hdr), "channel.k8s.io")
defer close()
client, _, err := dialWebsocket(clientURL, nil, "terminal.gitlab.com")
require.NoError(t, err)
defer client.Close()
sc := <-serverConns
defer sc.conn.Close()
require.Equal(t, "Value", sc.req.Header.Get("Random-Header"), "Header specified by upstream not sent to remote")
}
func TestChannelProxyForwardsXForwardedForFromClient(t *testing.T) {
serverConns, clientURL, close := wireupChannel(t, envTerminalPath, nil, "channel.k8s.io")
defer close()
hdr := make(http.Header)
hdr.Set("X-Forwarded-For", "127.0.0.2")
client, _, err := dialWebsocket(clientURL, hdr, "terminal.gitlab.com")
require.NoError(t, err)
defer client.Close()
clientIP, _, err := net.SplitHostPort(client.LocalAddr().String())
require.NoError(t, err)
sc := <-serverConns
defer sc.conn.Close()
require.Equal(t, "127.0.0.2, "+clientIP, sc.req.Header.Get("X-Forwarded-For"), "X-Forwarded-For from client not sent to remote")
}
func wireupChannel(t *testing.T, channelPath string, modifier func(*api.Response), subprotocols ...string) (chan connWithReq, string, func()) {
serverConns, remote := startWebsocketServer(subprotocols...)
authResponse := channelOkBody(remote, nil, subprotocols...)
if modifier != nil {
modifier(authResponse)
}
upstream := testAuthServer(t, nil, nil, 200, authResponse)
workhorse := startWorkhorseServer(upstream.URL)
return serverConns, websocketURL(workhorse.URL, channelPath), func() {
workhorse.Close()
upstream.Close()
remote.Close()
}
}
func startWebsocketServer(subprotocols ...string) (chan connWithReq, *httptest.Server) {
upgrader := &websocket.Upgrader{Subprotocols: subprotocols}
connCh := make(chan connWithReq, 1)
server := httptest.NewTLSServer(webSocketHandler(upgrader, connCh))
return connCh, server
}
func webSocketHandler(upgrader *websocket.Upgrader, connCh chan connWithReq) http.HandlerFunc {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
logEntry := log.WithFields(log.Fields{
"method": r.Method,
"url": r.URL,
"headers": r.Header,
})
logEntry.Info("WEBSOCKET")
conn, err := upgrader.Upgrade(w, r, nil)
if err != nil {
logEntry.WithError(err).Error("WEBSOCKET Upgrade failed")
return
}
connCh <- connWithReq{conn, r}
// The connection has been hijacked so it's OK to end here
})
}
func channelOkBody(remote *httptest.Server, header http.Header, subprotocols ...string) *api.Response {
out := &api.Response{
Channel: &api.ChannelSettings{
Url: websocketURL(remote.URL),
Header: header,
Subprotocols: subprotocols,
MaxSessionTime: 0,
},
}
if len(remote.TLS.Certificates) > 0 {
data := bytes.NewBuffer(nil)
pem.Encode(data, &pem.Block{Type: "CERTIFICATE", Bytes: remote.TLS.Certificates[0].Certificate[0]})
out.Channel.CAPem = data.String()
}
return out
}
func badCA(authResponse *api.Response) {
authResponse.Channel.CAPem = "Bad CA"
}
func timeout(authResponse *api.Response) {
authResponse.Channel.MaxSessionTime = 1
}
func setHeader(hdr http.Header) func(*api.Response) {
return func(authResponse *api.Response) {
authResponse.Channel.Header = hdr
}
}
func dialWebsocket(url string, header http.Header, subprotocols ...string) (*websocket.Conn, *http.Response, error) {
dialer := &websocket.Dialer{
Subprotocols: subprotocols,
}
return dialer.Dial(url, header)
}
func websocketURL(httpURL string, suffix ...string) string {
url, err := url.Parse(httpURL)
if err != nil {
panic(err)
}
switch url.Scheme {
case "http":
url.Scheme = "ws"
case "https":
url.Scheme = "wss"
default:
panic("Unknown scheme: " + url.Scheme)
}
url.Path = path.Join(url.Path, strings.Join(suffix, "/"))
return url.String()
}
func say(conn *websocket.Conn, message string) error {
return conn.WriteMessage(websocket.TextMessage, []byte(message))
}
func requireReadMessage(t *testing.T, conn *websocket.Conn, expectedMessageType int, expectedData string) {
messageType, data, err := conn.ReadMessage()
require.NoError(t, err)
require.Equal(t, expectedMessageType, messageType, "message type")
require.Equal(t, expectedData, string(data), "message data")
}