diff --git a/mflag/flag.go b/mflag/flag.go index 2ad299a..e7fabe0 100644 --- a/mflag/flag.go +++ b/mflag/flag.go @@ -1223,11 +1223,27 @@ func (v mergeVal) IsBoolFlag() bool { return false } +// Name returns the name of a mergeVal. +// If the original value had a name, return the original name, +// otherwise, return the key asinged to this mergeVal. +func (v mergeVal) Name() string { + type namedValue interface { + Name() string + } + if nVal, ok := v.Value.(namedValue); ok { + return nVal.Name() + } + return v.key +} + // Merge is an helper function that merges n FlagSets into a single dest FlagSet // In case of name collision between the flagsets it will apply // the destination FlagSet's errorHandling behavior. func Merge(dest *FlagSet, flagsets ...*FlagSet) error { for _, fset := range flagsets { + if fset.formal == nil { + continue + } for k, f := range fset.formal { if _, ok := dest.formal[k]; ok { var err error @@ -1249,6 +1265,9 @@ func Merge(dest *FlagSet, flagsets ...*FlagSet) error { } newF := *f newF.Value = mergeVal{f.Value, k, fset} + if dest.formal == nil { + dest.formal = make(map[string]*Flag) + } dest.formal[k] = &newF } } diff --git a/mflag/flag_test.go b/mflag/flag_test.go index c28deda..1383555 100644 --- a/mflag/flag_test.go +++ b/mflag/flag_test.go @@ -514,3 +514,14 @@ func TestSortFlags(t *testing.T) { t.Fatalf("NFlag (%d) != fs.NFlag() (%d) of elements visited", nflag, fs.NFlag()) } } + +func TestMergeFlags(t *testing.T) { + base := NewFlagSet("base", ContinueOnError) + base.String([]string{"f"}, "", "") + + fs := NewFlagSet("test", ContinueOnError) + Merge(fs, base) + if len(fs.formal) != 1 { + t.Fatalf("FlagCount (%d) != number (1) of elements merged", len(fs.formal)) + } +}