gitlab-org--gitlab-foss/workhorse/internal/httprs/httprs_test.go

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

263 lines
6.3 KiB
Go
Raw Normal View History

package httprs
import (
"fmt"
"io"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"testing"
"time"
. "github.com/smartystreets/goconvey/convey"
)
type fakeResponseWriter struct {
code int
h http.Header
tmp *os.File
}
func (f *fakeResponseWriter) Header() http.Header {
return f.h
}
func (f *fakeResponseWriter) Write(b []byte) (int, error) {
return f.tmp.Write(b)
}
func (f *fakeResponseWriter) Close(b []byte) error {
return f.tmp.Close()
}
func (f *fakeResponseWriter) WriteHeader(code int) {
f.code = code
}
func (f *fakeResponseWriter) Response() *http.Response {
f.tmp.Seek(0, io.SeekStart)
return &http.Response{Body: f.tmp, StatusCode: f.code, Header: f.h}
}
type fakeRoundTripper struct {
src *os.File
downgradeZeroToNoRange bool
}
func (f *fakeRoundTripper) RoundTrip(r *http.Request) (*http.Response, error) {
fw := &fakeResponseWriter{h: http.Header{}}
var err error
fw.tmp, err = os.CreateTemp(os.TempDir(), "httprs")
if err != nil {
return nil, err
}
if err := os.Remove(fw.tmp.Name()); err != nil {
return nil, err
}
if f.downgradeZeroToNoRange {
// There are implementations that downgrades bytes=0- to a normal un-ranged GET
if r.Header.Get("Range") == "bytes=0-" {
r.Header.Del("Range")
}
}
http.ServeContent(fw, r, "temp.txt", time.Now(), f.src)
return fw.Response(), nil
}
const SZ = 4096
const (
downgradeZeroToNoRange = 1 << iota
sendAcceptRanges
)
type RSFactory func() *HttpReadSeeker
func newRSFactory(flags int) RSFactory {
return func() *HttpReadSeeker {
tmp, err := os.CreateTemp(os.TempDir(), "httprs")
if err != nil {
return nil
}
if err := os.Remove(tmp.Name()); err != nil {
return nil
}
for i := 0; i < SZ; i++ {
tmp.WriteString(fmt.Sprintf("%04d", i))
}
req, err := http.NewRequest("GET", "http://www.example.com", nil)
if err != nil {
return nil
}
res := &http.Response{
Request: req,
ContentLength: SZ * 4,
}
if flags&sendAcceptRanges > 0 {
res.Header = http.Header{"Accept-Ranges": []string{"bytes"}}
}
downgradeZeroToNoRange := (flags & downgradeZeroToNoRange) > 0
return NewHttpReadSeeker(res, &http.Client{Transport: &fakeRoundTripper{src: tmp, downgradeZeroToNoRange: downgradeZeroToNoRange}})
}
}
func TestHttpWebServer(t *testing.T) {
Convey("Scenario: testing WebServer", t, func() {
dir := t.TempDir()
err := os.WriteFile(filepath.Join(dir, "file"), make([]byte, 10000), 0755)
So(err, ShouldBeNil)
server := httptest.NewServer(http.FileServer(http.Dir(dir)))
Convey("When requesting /file", func() {
res, err := http.Get(server.URL + "/file")
So(err, ShouldBeNil)
stream := NewHttpReadSeeker(res)
So(stream, ShouldNotBeNil)
Convey("Can read 100 bytes from start of file", func() {
n, err := stream.Read(make([]byte, 100))
So(err, ShouldBeNil)
So(n, ShouldEqual, 100)
Convey("When seeking 4KiB forward", func() {
pos, err := stream.Seek(4096, io.SeekCurrent)
So(err, ShouldBeNil)
So(pos, ShouldEqual, 4096+100)
Convey("Can read 100 bytes", func() {
n, err := stream.Read(make([]byte, 100))
So(err, ShouldBeNil)
So(n, ShouldEqual, 100)
})
})
})
})
})
}
func TestHttpReaderSeeker(t *testing.T) {
tests := []struct {
name string
newRS func() *HttpReadSeeker
}{
{name: "with no flags", newRS: newRSFactory(0)},
{name: "with only Accept-Ranges", newRS: newRSFactory(sendAcceptRanges)},
{name: "downgrade 0-range to no range", newRS: newRSFactory(downgradeZeroToNoRange)},
{name: "downgrade 0-range with Accept-Ranges", newRS: newRSFactory(downgradeZeroToNoRange | sendAcceptRanges)},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
testHttpReaderSeeker(t, test.newRS)
})
}
}
func testHttpReaderSeeker(t *testing.T, newRS RSFactory) {
Convey("Scenario: testing HttpReaderSeeker", t, func() {
Convey("Read should start at the beginning", func() {
r := newRS()
So(r, ShouldNotBeNil)
defer r.Close()
buf := make([]byte, 4)
n, err := io.ReadFull(r, buf)
So(n, ShouldEqual, 4)
So(err, ShouldBeNil)
So(string(buf), ShouldEqual, "0000")
})
Convey("Seek w SEEK_SET should seek to right offset", func() {
r := newRS()
So(r, ShouldNotBeNil)
defer r.Close()
s, err := r.Seek(4*64, io.SeekStart)
So(s, ShouldEqual, 4*64)
So(err, ShouldBeNil)
buf := make([]byte, 4)
n, err := io.ReadFull(r, buf)
So(n, ShouldEqual, 4)
So(err, ShouldBeNil)
So(string(buf), ShouldEqual, "0064")
})
Convey("Read + Seek w SEEK_CUR should seek to right offset", func() {
r := newRS()
So(r, ShouldNotBeNil)
defer r.Close()
buf := make([]byte, 4)
io.ReadFull(r, buf)
s, err := r.Seek(4*64, os.SEEK_CUR)
So(s, ShouldEqual, 4*64+4)
So(err, ShouldBeNil)
n, err := io.ReadFull(r, buf)
So(n, ShouldEqual, 4)
So(err, ShouldBeNil)
So(string(buf), ShouldEqual, "0065")
})
Convey("Seek w SEEK_END should seek to right offset", func() {
r := newRS()
So(r, ShouldNotBeNil)
defer r.Close()
buf := make([]byte, 4)
io.ReadFull(r, buf)
s, err := r.Seek(4, os.SEEK_END)
So(s, ShouldEqual, SZ*4-4)
So(err, ShouldBeNil)
n, err := io.ReadFull(r, buf)
So(n, ShouldEqual, 4)
So(err, ShouldBeNil)
So(string(buf), ShouldEqual, fmt.Sprintf("%04d", SZ-1))
})
Convey("Short seek should consume existing request", func() {
r := newRS()
So(r, ShouldNotBeNil)
defer r.Close()
buf := make([]byte, 4)
So(r.Requests, ShouldEqual, 0)
io.ReadFull(r, buf)
So(r.Requests, ShouldEqual, 1)
s, err := r.Seek(shortSeekBytes, os.SEEK_CUR)
So(r.Requests, ShouldEqual, 1)
So(s, ShouldEqual, shortSeekBytes+4)
So(err, ShouldBeNil)
n, err := io.ReadFull(r, buf)
So(n, ShouldEqual, 4)
So(err, ShouldBeNil)
So(string(buf), ShouldEqual, "0257")
So(r.Requests, ShouldEqual, 1)
})
Convey("Long seek should do a new request", func() {
r := newRS()
So(r, ShouldNotBeNil)
defer r.Close()
buf := make([]byte, 4)
So(r.Requests, ShouldEqual, 0)
io.ReadFull(r, buf)
So(r.Requests, ShouldEqual, 1)
s, err := r.Seek(shortSeekBytes+1, os.SEEK_CUR)
So(r.Requests, ShouldEqual, 1)
So(s, ShouldEqual, shortSeekBytes+4+1)
So(err, ShouldBeNil)
n, err := io.ReadFull(r, buf)
So(n, ShouldEqual, 4)
So(err, ShouldBeNil)
So(string(buf), ShouldEqual, "2570")
So(r.Requests, ShouldEqual, 2)
})
})
}