diff --git a/mflag/flag.go b/mflag/flag.go index 0eebfc1..b40f911 100644 --- a/mflag/flag.go +++ b/mflag/flag.go @@ -317,8 +317,13 @@ func (p flagSlice) Swap(i, j int) { p[i], p[j] = p[j], p[i] } // sortFlags returns the flags as a slice in lexicographical sorted order. func sortFlags(flags map[string]*Flag) []*Flag { var list flagSlice - for _, f := range flags { + + // The sorted list is based on the first name, when flag map might use the other names. + nameMap := make(map[string]string) + + for n, f := range flags { fName := strings.TrimPrefix(f.Names[0], "#") + nameMap[fName] = n if len(f.Names) == 1 { list = append(list, fName) continue @@ -338,7 +343,7 @@ func sortFlags(flags map[string]*Flag) []*Flag { sort.Sort(list) result := make([]*Flag, len(list)) for i, name := range list { - result[i] = flags[name] + result[i] = flags[nameMap[name]] } return result } @@ -473,7 +478,7 @@ var Usage = func() { } // FlagCount returns the number of flags that have been defined. -func (f *FlagSet) FlagCount() int { return len(f.formal) } +func (f *FlagSet) FlagCount() int { return len(sortFlags(f.formal)) } // FlagCountUndeprecated returns the number of undeprecated flags that have been defined. func (f *FlagSet) FlagCountUndeprecated() int { diff --git a/mflag/flag_test.go b/mflag/flag_test.go index 9321926..340a1cb 100644 --- a/mflag/flag_test.go +++ b/mflag/flag_test.go @@ -440,7 +440,7 @@ func TestFlagCounts(t *testing.T) { fs.BoolVar(&flag, []string{"flag3"}, false, "regular flag") fs.BoolVar(&flag, []string{"g", "#flag4", "-flag4"}, false, "regular flag") - if fs.FlagCount() != 10 { + if fs.FlagCount() != 6 { t.Fatal("FlagCount wrong. ", fs.FlagCount()) } if fs.FlagCountUndeprecated() != 4 { @@ -457,3 +457,50 @@ func TestFlagCounts(t *testing.T) { t.Fatal("NFlag wrong. ", fs.NFlag()) } } + +// Show up bug in sortFlags +func TestSortFlags(t *testing.T) { + fs := NewFlagSet("help TestSortFlags", ContinueOnError) + + var err error + + var b bool + fs.BoolVar(&b, []string{"b", "-banana"}, false, "usage") + + err = fs.Parse([]string{"--banana=true"}) + if err != nil { + t.Fatal("expected no error; got ", err) + } + + count := 0 + + fs.VisitAll(func(flag *Flag) { + count++ + if flag == nil { + t.Fatal("VisitAll should not return a nil flag") + } + }) + flagcount := fs.FlagCount() + if flagcount != count { + t.Fatalf("FlagCount (%d) != number (%d) of elements visited", flagcount, count) + } + // Make sure its idempotent + if flagcount != fs.FlagCount() { + t.Fatalf("FlagCount (%d) != fs.FlagCount() (%d) of elements visited", flagcount, fs.FlagCount()) + } + + count = 0 + fs.Visit(func(flag *Flag) { + count++ + if flag == nil { + t.Fatal("Visit should not return a nil flag") + } + }) + nflag := fs.NFlag() + if nflag != count { + t.Fatalf("NFlag (%d) != number (%d) of elements visited", nflag, count) + } + if nflag != fs.NFlag() { + t.Fatalf("NFlag (%d) != fs.NFlag() (%d) of elements visited", nflag, fs.NFlag()) + } +}