mirror of https://github.com/docker/docs.git
Merge pull request #10208 from mota/fix-env-writerto
Fix env.WriteTo count return
This commit is contained in:
commit
6d65fa1faa
|
@ -7,6 +7,8 @@ import (
|
||||||
"io"
|
"io"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/docker/docker/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Env []string
|
type Env []string
|
||||||
|
@ -242,9 +244,10 @@ func (env *Env) Encode(dst io.Writer) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (env *Env) WriteTo(dst io.Writer) (n int64, err error) {
|
func (env *Env) WriteTo(dst io.Writer) (int64, error) {
|
||||||
// FIXME: return the number of bytes written to respect io.WriterTo
|
wc := utils.NewWriteCounter(dst)
|
||||||
return 0, env.Encode(dst)
|
err := env.Encode(wc)
|
||||||
|
return wc.Count, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (env *Env) Import(src interface{}) (err error) {
|
func (env *Env) Import(src interface{}) (err error) {
|
||||||
|
|
|
@ -545,3 +545,24 @@ func ReadDockerIgnore(path string) ([]string, error) {
|
||||||
}
|
}
|
||||||
return excludes, nil
|
return excludes, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Wrap a concrete io.Writer and hold a count of the number
|
||||||
|
// of bytes written to the writer during a "session".
|
||||||
|
// This can be convenient when write return is masked
|
||||||
|
// (e.g., json.Encoder.Encode())
|
||||||
|
type WriteCounter struct {
|
||||||
|
Count int64
|
||||||
|
Writer io.Writer
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewWriteCounter(w io.Writer) *WriteCounter {
|
||||||
|
return &WriteCounter{
|
||||||
|
Writer: w,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (wc *WriteCounter) Write(p []byte) (count int, err error) {
|
||||||
|
count, err = wc.Writer.Write(p)
|
||||||
|
wc.Count += int64(count)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
|
@ -1,7 +1,9 @@
|
||||||
package utils
|
package utils
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"os"
|
"os"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -97,3 +99,26 @@ func TestReadSymlinkedDirectoryToFile(t *testing.T) {
|
||||||
t.Errorf("failed to remove symlink: %s", err)
|
t.Errorf("failed to remove symlink: %s", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestWriteCounter(t *testing.T) {
|
||||||
|
dummy1 := "This is a dummy string."
|
||||||
|
dummy2 := "This is another dummy string."
|
||||||
|
totalLength := int64(len(dummy1) + len(dummy2))
|
||||||
|
|
||||||
|
reader1 := strings.NewReader(dummy1)
|
||||||
|
reader2 := strings.NewReader(dummy2)
|
||||||
|
|
||||||
|
var buffer bytes.Buffer
|
||||||
|
wc := NewWriteCounter(&buffer)
|
||||||
|
|
||||||
|
reader1.WriteTo(wc)
|
||||||
|
reader2.WriteTo(wc)
|
||||||
|
|
||||||
|
if wc.Count != totalLength {
|
||||||
|
t.Errorf("Wrong count: %d vs. %d", wc.Count, totalLength)
|
||||||
|
}
|
||||||
|
|
||||||
|
if buffer.String() != dummy1+dummy2 {
|
||||||
|
t.Error("Wrong message written")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue