This commit is contained in:
Sebastiaan van Stijn 2025-09-04 13:56:35 +02:00 committed by GitHub
commit 9e563efd3a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 131 additions and 1 deletions

View File

@ -6,6 +6,9 @@ package templates
import (
"bytes"
"encoding/json"
"fmt"
"reflect"
"sort"
"strings"
"text/template"
)
@ -26,7 +29,7 @@ var basicFunctions = template.FuncMap{
return strings.TrimSpace(buf.String())
},
"split": strings.Split,
"join": strings.Join,
"join": joinElements,
"title": strings.Title, //nolint:nolintlint,staticcheck // strings.Title is deprecated, but we only use it for ASCII, so replacing with golang.org/x/text is out of scope
"lower": strings.ToLower,
"upper": strings.ToUpper,
@ -103,3 +106,40 @@ func truncateWithLength(source string, length int) string {
}
return source[:length]
}
// joinElements joins a slice of items with the given separator. It uses
// [strings.Join] if it's a slice of strings, otherwise uses [fmt.Sprint]
// to join each item to the output.
func joinElements(elems any, sep string) (string, error) {
if elems == nil {
return "", nil
}
if ss, ok := elems.([]string); ok {
return strings.Join(ss, sep), nil
}
switch rv := reflect.ValueOf(elems); rv.Kind() { //nolint:exhaustive // ignore: too many options to make exhaustive
case reflect.Array, reflect.Slice:
var b strings.Builder
for i := range rv.Len() {
if i > 0 {
b.WriteString(sep)
}
_, _ = fmt.Fprint(&b, rv.Index(i).Interface())
}
return b.String(), nil
case reflect.Map:
var out []string
for _, k := range rv.MapKeys() {
out = append(out, fmt.Sprint(rv.MapIndex(k).Interface()))
}
// Not ideal, but trying to keep a consistent order
sort.Strings(out)
return strings.Join(out, sep), nil
default:
return "", fmt.Errorf("expected slice, got %T", elems)
}
}

View File

@ -3,6 +3,7 @@ package templates
import (
"bytes"
"testing"
"text/template"
"gotest.tools/v3/assert"
is "gotest.tools/v3/assert/cmp"
@ -139,3 +140,92 @@ func TestHeaderFunctions(t *testing.T) {
})
}
}
type stringerString string
func (s stringerString) String() string {
return "stringer" + string(s)
}
type stringerAndError string
func (s stringerAndError) String() string {
return "stringer" + string(s)
}
func (s stringerAndError) Error() string {
return "error" + string(s)
}
func TestJoinElements(t *testing.T) {
tests := []struct {
doc string
data any
expOut string
expErr string
}{
{
doc: "nil",
data: nil,
expOut: `output: ""`,
},
{
doc: "non-slice",
data: "hello",
expOut: `output: "`,
expErr: `error calling join: expected slice, got string`,
},
{
doc: "structs",
data: []struct{ A, B string }{{"1", "2"}, {"3", "4"}},
expOut: `output: "{1 2}, {3 4}"`,
},
{
doc: "map with strings",
data: map[string]string{"A": "1", "B": "2", "C": "3"},
expOut: `output: "1, 2, 3"`,
},
{
doc: "map with stringers",
data: map[string]stringerString{"A": "1", "B": "2", "C": "3"},
expOut: `output: "stringer1, stringer2, stringer3"`,
},
{
doc: "map with errors",
data: []stringerAndError{"1", "2", "3"},
expOut: `output: "error1, error2, error3"`,
},
{
doc: "stringers",
data: []stringerString{"1", "2", "3"},
expOut: `output: "stringer1, stringer2, stringer3"`,
},
{
doc: "stringer with errors",
data: []stringerAndError{"1", "2", "3"},
expOut: `output: "error1, error2, error3"`,
},
{
doc: "slice of bools",
data: []bool{true, false, true},
expOut: `output: "true, false, true"`,
},
}
const formatStr = `output: "{{- join . ", " -}}"`
tmpl, err := New("my-template").Funcs(template.FuncMap{"join": joinElements}).Parse(formatStr)
assert.NilError(t, err)
for _, tc := range tests {
t.Run(tc.doc, func(t *testing.T) {
var b bytes.Buffer
err := tmpl.Execute(&b, tc.data)
if tc.expErr != "" {
assert.ErrorContains(t, err, tc.expErr)
} else {
assert.NilError(t, err)
}
assert.Equal(t, b.String(), tc.expOut)
})
}
}