gitlab-org--gitlab-foss/workhorse/internal/upstream/roundtripper/roundtripper_test.go

94 lines
2.3 KiB
Go

package roundtripper
import (
"crypto/tls"
"crypto/x509"
"fmt"
"io"
"net/http"
"net/http/httptest"
"net/url"
"strconv"
"testing"
"github.com/stretchr/testify/require"
)
func TestMustParseAddress(t *testing.T) {
successExamples := []struct{ address, scheme, expected string }{
{"1.2.3.4:56", "http", "1.2.3.4:56"},
{"[::1]:23", "http", "::1:23"},
{"4.5.6.7", "http", "4.5.6.7:http"},
{"4.5.6.7", "https", "4.5.6.7:https"},
}
for i, example := range successExamples {
t.Run(strconv.Itoa(i), func(t *testing.T) {
require.Equal(t, example.expected, mustParseAddress(example.address, example.scheme))
})
}
}
func TestMustParseAddressPanic(t *testing.T) {
panicExamples := []struct{ address, scheme string }{
{"1.2.3.4", ""},
}
for i, panicExample := range panicExamples {
t.Run(strconv.Itoa(i), func(t *testing.T) {
defer func() {
if r := recover(); r == nil {
t.Fatal("expected panic")
}
}()
mustParseAddress(panicExample.address, panicExample.scheme)
})
}
}
func TestSupportsHTTPBackend(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(200)
fmt.Fprint(w, "successful response")
}))
defer ts.Close()
testNewBackendRoundTripper(t, ts, nil, "successful response")
}
func TestSupportsHTTPSBackend(t *testing.T) {
ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(200)
fmt.Fprint(w, "successful response")
}))
defer ts.Close()
certpool := x509.NewCertPool()
certpool.AddCert(ts.Certificate())
tlsClientConfig := &tls.Config{
RootCAs: certpool,
}
testNewBackendRoundTripper(t, ts, tlsClientConfig, "successful response")
}
func testNewBackendRoundTripper(t *testing.T, ts *httptest.Server, tlsClientConfig *tls.Config, expectedResponseBody string) {
t.Helper()
backend, err := url.Parse(ts.URL)
require.NoError(t, err, "parse url")
rt := newBackendRoundTripper(backend, "", 0, true, tlsClientConfig)
req, err := http.NewRequest("GET", ts.URL+"/", nil)
require.NoError(t, err, "build request")
response, err := rt.RoundTrip(req)
require.NoError(t, err, "perform roundtrip")
defer response.Body.Close()
body, err := io.ReadAll(response.Body)
require.NoError(t, err)
require.Equal(t, expectedResponseBody, string(body))
}