diff --git a/stringutils/strslice.go b/stringutils/strslice.go new file mode 100644 index 0000000..4055754 --- /dev/null +++ b/stringutils/strslice.go @@ -0,0 +1,71 @@ +package stringutils + +import ( + "encoding/json" + "strings" +) + +// StrSlice representes a string or an array of strings. +// We need to override the json decoder to accept both options. +type StrSlice struct { + parts []string +} + +// MarshalJSON Marshals (or serializes) the StrSlice into the json format. +// This method is needed to implement json.Marshaller. +func (e *StrSlice) MarshalJSON() ([]byte, error) { + if e == nil { + return []byte{}, nil + } + return json.Marshal(e.Slice()) +} + +// UnmarshalJSON decodes the byte slice whether it's a string or an array of strings. +// This method is needed to implement json.Unmarshaler. +func (e *StrSlice) UnmarshalJSON(b []byte) error { + if len(b) == 0 { + return nil + } + + p := make([]string, 0, 1) + if err := json.Unmarshal(b, &p); err != nil { + var s string + if err := json.Unmarshal(b, &s); err != nil { + return err + } + p = append(p, s) + } + + e.parts = p + return nil +} + +// Len returns the number of parts of the StrSlice. +func (e *StrSlice) Len() int { + if e == nil { + return 0 + } + return len(e.parts) +} + +// Slice gets the parts of the StrSlice as a Slice of string. +func (e *StrSlice) Slice() []string { + if e == nil { + return nil + } + return e.parts +} + +// ToString gets space separated string of all the parts. +func (e *StrSlice) ToString() string { + s := e.Slice() + if s == nil { + return "" + } + return strings.Join(s, " ") +} + +// NewStrSlice creates an StrSlice based on the specified parts (as strings). +func NewStrSlice(parts ...string) *StrSlice { + return &StrSlice{parts} +} diff --git a/stringutils/strslice_test.go b/stringutils/strslice_test.go new file mode 100644 index 0000000..d0e6b1b --- /dev/null +++ b/stringutils/strslice_test.go @@ -0,0 +1,105 @@ +package stringutils + +import ( + "encoding/json" + "testing" +) + +func TestStrSliceMarshalJSON(t *testing.T) { + strss := map[*StrSlice]string{ + nil: "", + &StrSlice{}: "null", + &StrSlice{[]string{"/bin/sh", "-c", "echo"}}: `["/bin/sh","-c","echo"]`, + } + + for strs, expected := range strss { + data, err := strs.MarshalJSON() + if err != nil { + t.Fatal(err) + } + if string(data) != expected { + t.Fatalf("Expected %v, got %v", expected, string(data)) + } + } +} + +func TestStrSliceUnmarshalJSON(t *testing.T) { + parts := map[string][]string{ + "": {"default", "values"}, + "[]": {}, + `["/bin/sh","-c","echo"]`: {"/bin/sh", "-c", "echo"}, + } + for json, expectedParts := range parts { + strs := &StrSlice{ + []string{"default", "values"}, + } + if err := strs.UnmarshalJSON([]byte(json)); err != nil { + t.Fatal(err) + } + + actualParts := strs.Slice() + if len(actualParts) != len(expectedParts) { + t.Fatalf("Expected %v parts, got %v (%v)", len(expectedParts), len(actualParts), expectedParts) + } + for index, part := range actualParts { + if part != expectedParts[index] { + t.Fatalf("Expected %v, got %v", expectedParts, actualParts) + break + } + } + } +} + +func TestStrSliceUnmarshalString(t *testing.T) { + var e *StrSlice + echo, err := json.Marshal("echo") + if err != nil { + t.Fatal(err) + } + if err := json.Unmarshal(echo, &e); err != nil { + t.Fatal(err) + } + + slice := e.Slice() + if len(slice) != 1 { + t.Fatalf("expected 1 element after unmarshal: %q", slice) + } + + if slice[0] != "echo" { + t.Fatalf("expected `echo`, got: %q", slice[0]) + } +} + +func TestStrSliceUnmarshalSlice(t *testing.T) { + var e *StrSlice + echo, err := json.Marshal([]string{"echo"}) + if err != nil { + t.Fatal(err) + } + if err := json.Unmarshal(echo, &e); err != nil { + t.Fatal(err) + } + + slice := e.Slice() + if len(slice) != 1 { + t.Fatalf("expected 1 element after unmarshal: %q", slice) + } + + if slice[0] != "echo" { + t.Fatalf("expected `echo`, got: %q", slice[0]) + } +} + +func TestStrSliceToString(t *testing.T) { + slices := map[*StrSlice]string{ + NewStrSlice(""): "", + NewStrSlice("one"): "one", + NewStrSlice("one", "two"): "one two", + } + for s, expected := range slices { + toString := s.ToString() + if toString != expected { + t.Fatalf("Expected %v, got %v", expected, toString) + } + } +}