2014-09-17 11:24:07 -04:00
|
|
|
package stdcopy
|
|
|
|
|
|
|
|
import (
|
|
|
|
"bytes"
|
2015-09-20 05:46:53 -04:00
|
|
|
"errors"
|
2015-09-24 06:39:54 -04:00
|
|
|
"io"
|
2014-09-17 11:24:07 -04:00
|
|
|
"io/ioutil"
|
2015-04-16 08:17:23 -04:00
|
|
|
"strings"
|
2014-09-17 11:24:07 -04:00
|
|
|
"testing"
|
|
|
|
)
|
|
|
|
|
2015-04-16 08:17:23 -04:00
|
|
|
func TestNewStdWriter(t *testing.T) {
|
|
|
|
writer := NewStdWriter(ioutil.Discard, Stdout)
|
|
|
|
if writer == nil {
|
|
|
|
t.Fatalf("NewStdWriter with an invalid StdType should not return nil.")
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2017-05-21 19:24:07 -04:00
|
|
|
func TestWriteWithUninitializedStdWriter(t *testing.T) {
|
2016-02-25 18:22:19 -05:00
|
|
|
writer := stdWriter{
|
|
|
|
Writer: nil,
|
|
|
|
prefix: byte(Stdout),
|
2015-04-16 08:17:23 -04:00
|
|
|
}
|
|
|
|
n, err := writer.Write([]byte("Something here"))
|
|
|
|
if n != 0 || err == nil {
|
|
|
|
t.Fatalf("Should fail when given an uncomplete or uninitialized StdWriter")
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
func TestWriteWithNilBytes(t *testing.T) {
|
|
|
|
writer := NewStdWriter(ioutil.Discard, Stdout)
|
|
|
|
n, err := writer.Write(nil)
|
|
|
|
if err != nil {
|
|
|
|
t.Fatalf("Shouldn't have fail when given no data")
|
|
|
|
}
|
|
|
|
if n > 0 {
|
|
|
|
t.Fatalf("Write should have written 0 byte, but has written %d", n)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
func TestWrite(t *testing.T) {
|
|
|
|
writer := NewStdWriter(ioutil.Discard, Stdout)
|
|
|
|
data := []byte("Test StdWrite.Write")
|
|
|
|
n, err := writer.Write(data)
|
|
|
|
if err != nil {
|
|
|
|
t.Fatalf("Error while writing with StdWrite")
|
|
|
|
}
|
|
|
|
if n != len(data) {
|
2015-08-07 18:24:18 -04:00
|
|
|
t.Fatalf("Write should have written %d byte but wrote %d.", len(data), n)
|
2015-04-16 08:17:23 -04:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2015-09-20 05:46:53 -04:00
|
|
|
// errWriter is a writer test double: its Write method ignores the data it
// is given and always returns the configured n and err.
type errWriter struct {
	// n is the byte count returned by every Write call.
	n int
	// err is the error returned by every Write call (nil means success).
	err error
}
|
|
|
|
|
|
|
|
func (f *errWriter) Write(buf []byte) (int, error) {
|
|
|
|
return f.n, f.err
|
|
|
|
}
|
|
|
|
|
|
|
|
func TestWriteWithWriterError(t *testing.T) {
|
|
|
|
expectedError := errors.New("expected")
|
|
|
|
expectedReturnedBytes := 10
|
|
|
|
writer := NewStdWriter(&errWriter{
|
|
|
|
n: stdWriterPrefixLen + expectedReturnedBytes,
|
|
|
|
err: expectedError}, Stdout)
|
|
|
|
data := []byte("This won't get written, sigh")
|
|
|
|
n, err := writer.Write(data)
|
|
|
|
if err != expectedError {
|
|
|
|
t.Fatalf("Didn't get expected error.")
|
|
|
|
}
|
|
|
|
if n != expectedReturnedBytes {
|
2016-02-22 14:22:20 -05:00
|
|
|
t.Fatalf("Didn't get expected written bytes %d, got %d.",
|
2015-09-20 05:46:53 -04:00
|
|
|
expectedReturnedBytes, n)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
func TestWriteDoesNotReturnNegativeWrittenBytes(t *testing.T) {
|
|
|
|
writer := NewStdWriter(&errWriter{n: -1}, Stdout)
|
|
|
|
data := []byte("This won't get written, sigh")
|
|
|
|
actual, _ := writer.Write(data)
|
|
|
|
if actual != 0 {
|
|
|
|
t.Fatalf("Expected returned written bytes equal to 0, got %d", actual)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2015-09-24 06:39:54 -04:00
|
|
|
func getSrcBuffer(stdOutBytes, stdErrBytes []byte) (buffer *bytes.Buffer, err error) {
|
|
|
|
buffer = new(bytes.Buffer)
|
2015-09-22 18:30:06 -04:00
|
|
|
dstOut := NewStdWriter(buffer, Stdout)
|
2015-09-24 06:39:54 -04:00
|
|
|
_, err = dstOut.Write(stdOutBytes)
|
2015-09-22 18:30:06 -04:00
|
|
|
if err != nil {
|
2015-09-24 06:39:54 -04:00
|
|
|
return
|
2015-09-22 18:30:06 -04:00
|
|
|
}
|
|
|
|
dstErr := NewStdWriter(buffer, Stderr)
|
|
|
|
_, err = dstErr.Write(stdErrBytes)
|
2015-09-24 06:39:54 -04:00
|
|
|
return
|
|
|
|
}
|
|
|
|
|
|
|
|
func TestStdCopyWriteAndRead(t *testing.T) {
|
|
|
|
stdOutBytes := []byte(strings.Repeat("o", startingBufLen))
|
|
|
|
stdErrBytes := []byte(strings.Repeat("e", startingBufLen))
|
|
|
|
buffer, err := getSrcBuffer(stdOutBytes, stdErrBytes)
|
2015-09-22 18:30:06 -04:00
|
|
|
if err != nil {
|
|
|
|
t.Fatal(err)
|
|
|
|
}
|
|
|
|
written, err := StdCopy(ioutil.Discard, ioutil.Discard, buffer)
|
|
|
|
if err != nil {
|
|
|
|
t.Fatal(err)
|
|
|
|
}
|
|
|
|
expectedTotalWritten := len(stdOutBytes) + len(stdErrBytes)
|
|
|
|
if written != int64(expectedTotalWritten) {
|
|
|
|
t.Fatalf("Expected to have total of %d bytes written, got %d", expectedTotalWritten, written)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2015-09-24 06:39:54 -04:00
|
|
|
// customReader is a reader test double. Its first correctCalls Read calls
// are served from src. Every later call returns the canned n and err instead.
type customReader struct {
	// n is the byte count returned once the delegated calls are used up.
	n int
	// err is the error returned once the delegated calls are used up.
	err error
	// totalCalls counts every Read call made so far.
	totalCalls int
	// correctCalls is how many initial Read calls are delegated to src.
	correctCalls int
	// src supplies data for the first correctCalls reads. It is never
	// touched (and may be nil) when correctCalls is 0.
	src *bytes.Buffer
}
|
|
|
|
|
|
|
|
func (f *customReader) Read(buf []byte) (int, error) {
|
|
|
|
f.totalCalls++
|
|
|
|
if f.totalCalls <= f.correctCalls {
|
|
|
|
return f.src.Read(buf)
|
|
|
|
}
|
|
|
|
return f.n, f.err
|
|
|
|
}
|
|
|
|
|
|
|
|
func TestStdCopyReturnsErrorReadingHeader(t *testing.T) {
|
|
|
|
expectedError := errors.New("error")
|
|
|
|
reader := &customReader{
|
|
|
|
err: expectedError}
|
|
|
|
written, err := StdCopy(ioutil.Discard, ioutil.Discard, reader)
|
|
|
|
if written != 0 {
|
|
|
|
t.Fatalf("Expected 0 bytes read, got %d", written)
|
|
|
|
}
|
|
|
|
if err != expectedError {
|
|
|
|
t.Fatalf("Didn't get expected error")
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
func TestStdCopyReturnsErrorReadingFrame(t *testing.T) {
|
|
|
|
expectedError := errors.New("error")
|
|
|
|
stdOutBytes := []byte(strings.Repeat("o", startingBufLen))
|
|
|
|
stdErrBytes := []byte(strings.Repeat("e", startingBufLen))
|
|
|
|
buffer, err := getSrcBuffer(stdOutBytes, stdErrBytes)
|
|
|
|
if err != nil {
|
|
|
|
t.Fatal(err)
|
|
|
|
}
|
|
|
|
reader := &customReader{
|
|
|
|
correctCalls: 1,
|
|
|
|
n: stdWriterPrefixLen + 1,
|
|
|
|
err: expectedError,
|
|
|
|
src: buffer}
|
|
|
|
written, err := StdCopy(ioutil.Discard, ioutil.Discard, reader)
|
|
|
|
if written != 0 {
|
|
|
|
t.Fatalf("Expected 0 bytes read, got %d", written)
|
|
|
|
}
|
|
|
|
if err != expectedError {
|
|
|
|
t.Fatalf("Didn't get expected error")
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
func TestStdCopyDetectsCorruptedFrame(t *testing.T) {
|
|
|
|
stdOutBytes := []byte(strings.Repeat("o", startingBufLen))
|
|
|
|
stdErrBytes := []byte(strings.Repeat("e", startingBufLen))
|
|
|
|
buffer, err := getSrcBuffer(stdOutBytes, stdErrBytes)
|
|
|
|
if err != nil {
|
|
|
|
t.Fatal(err)
|
|
|
|
}
|
|
|
|
reader := &customReader{
|
|
|
|
correctCalls: 1,
|
|
|
|
n: stdWriterPrefixLen + 1,
|
|
|
|
err: io.EOF,
|
|
|
|
src: buffer}
|
|
|
|
written, err := StdCopy(ioutil.Discard, ioutil.Discard, reader)
|
|
|
|
if written != startingBufLen {
|
2016-02-25 18:22:19 -05:00
|
|
|
t.Fatalf("Expected %d bytes read, got %d", startingBufLen, written)
|
2015-09-24 06:39:54 -04:00
|
|
|
}
|
|
|
|
if err != nil {
|
|
|
|
t.Fatal("Didn't get nil error")
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2015-04-16 08:17:23 -04:00
|
|
|
func TestStdCopyWithInvalidInputHeader(t *testing.T) {
|
|
|
|
dstOut := NewStdWriter(ioutil.Discard, Stdout)
|
|
|
|
dstErr := NewStdWriter(ioutil.Discard, Stderr)
|
|
|
|
src := strings.NewReader("Invalid input")
|
|
|
|
_, err := StdCopy(dstOut, dstErr, src)
|
|
|
|
if err == nil {
|
|
|
|
t.Fatal("StdCopy with invalid input header should fail.")
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
func TestStdCopyWithCorruptedPrefix(t *testing.T) {
|
|
|
|
data := []byte{0x01, 0x02, 0x03}
|
|
|
|
src := bytes.NewReader(data)
|
|
|
|
written, err := StdCopy(nil, nil, src)
|
|
|
|
if err != nil {
|
|
|
|
t.Fatalf("StdCopy should not return an error with corrupted prefix.")
|
|
|
|
}
|
|
|
|
if written != 0 {
|
|
|
|
t.Fatalf("StdCopy should have written 0, but has written %d", written)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2015-09-24 06:39:54 -04:00
|
|
|
func TestStdCopyReturnsWriteErrors(t *testing.T) {
|
|
|
|
stdOutBytes := []byte(strings.Repeat("o", startingBufLen))
|
|
|
|
stdErrBytes := []byte(strings.Repeat("e", startingBufLen))
|
|
|
|
buffer, err := getSrcBuffer(stdOutBytes, stdErrBytes)
|
|
|
|
if err != nil {
|
|
|
|
t.Fatal(err)
|
|
|
|
}
|
|
|
|
expectedError := errors.New("expected")
|
|
|
|
|
|
|
|
dstOut := &errWriter{err: expectedError}
|
|
|
|
|
|
|
|
written, err := StdCopy(dstOut, ioutil.Discard, buffer)
|
|
|
|
if written != 0 {
|
|
|
|
t.Fatalf("StdCopy should have written 0, but has written %d", written)
|
|
|
|
}
|
|
|
|
if err != expectedError {
|
|
|
|
t.Fatalf("Didn't get expected error, got %v", err)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
func TestStdCopyDetectsNotFullyWrittenFrames(t *testing.T) {
|
|
|
|
stdOutBytes := []byte(strings.Repeat("o", startingBufLen))
|
|
|
|
stdErrBytes := []byte(strings.Repeat("e", startingBufLen))
|
|
|
|
buffer, err := getSrcBuffer(stdOutBytes, stdErrBytes)
|
|
|
|
if err != nil {
|
|
|
|
t.Fatal(err)
|
|
|
|
}
|
|
|
|
dstOut := &errWriter{n: startingBufLen - 10}
|
|
|
|
|
|
|
|
written, err := StdCopy(dstOut, ioutil.Discard, buffer)
|
|
|
|
if written != 0 {
|
|
|
|
t.Fatalf("StdCopy should have return 0 written bytes, but returned %d", written)
|
|
|
|
}
|
|
|
|
if err != io.ErrShortWrite {
|
|
|
|
t.Fatalf("Didn't get expected io.ErrShortWrite error")
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2017-02-09 17:11:29 -05:00
|
|
|
// TestStdCopyReturnsErrorFromSystem tests that StdCopy correctly returns an
|
|
|
|
// error, when that error is muxed into the Systemerr stream.
|
|
|
|
func TestStdCopyReturnsErrorFromSystem(t *testing.T) {
|
|
|
|
// write in the basic messages, just so there's some fluff in there
|
|
|
|
stdOutBytes := []byte(strings.Repeat("o", startingBufLen))
|
|
|
|
stdErrBytes := []byte(strings.Repeat("e", startingBufLen))
|
|
|
|
buffer, err := getSrcBuffer(stdOutBytes, stdErrBytes)
|
|
|
|
if err != nil {
|
|
|
|
t.Fatal(err)
|
|
|
|
}
|
|
|
|
// add in an error message on the Systemerr stream
|
|
|
|
systemErrBytes := []byte(strings.Repeat("S", startingBufLen))
|
|
|
|
systemWriter := NewStdWriter(buffer, Systemerr)
|
|
|
|
_, err = systemWriter.Write(systemErrBytes)
|
|
|
|
if err != nil {
|
|
|
|
t.Fatal(err)
|
|
|
|
}
|
|
|
|
|
|
|
|
// now copy and demux. we should expect an error containing the string we
|
|
|
|
// wrote out
|
|
|
|
_, err = StdCopy(ioutil.Discard, ioutil.Discard, buffer)
|
|
|
|
if err == nil {
|
|
|
|
t.Fatal("expected error, got none")
|
|
|
|
}
|
|
|
|
if !strings.Contains(err.Error(), string(systemErrBytes)) {
|
|
|
|
t.Fatal("expected error to contain message")
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2014-09-17 11:24:07 -04:00
|
|
|
func BenchmarkWrite(b *testing.B) {
|
|
|
|
w := NewStdWriter(ioutil.Discard, Stdout)
|
|
|
|
data := []byte("Test line for testing stdwriter performance\n")
|
|
|
|
data = bytes.Repeat(data, 100)
|
|
|
|
b.SetBytes(int64(len(data)))
|
|
|
|
b.ResetTimer()
|
|
|
|
for i := 0; i < b.N; i++ {
|
|
|
|
if _, err := w.Write(data); err != nil {
|
|
|
|
b.Fatal(err)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|