fs.Visit() returns nil flag
Signed-off-by: Sven Dowideit <SvenDowideit@docker.com>
This commit is contained in:
parent
47f57a7b7c
commit
6a1cc969fc
2 changed files with 56 additions and 4 deletions
|
@ -317,8 +317,13 @@ func (p flagSlice) Swap(i, j int) { p[i], p[j] = p[j], p[i] }
|
||||||
// sortFlags returns the flags as a slice in lexicographical sorted order.
|
// sortFlags returns the flags as a slice in lexicographical sorted order.
|
||||||
func sortFlags(flags map[string]*Flag) []*Flag {
|
func sortFlags(flags map[string]*Flag) []*Flag {
|
||||||
var list flagSlice
|
var list flagSlice
|
||||||
for _, f := range flags {
|
|
||||||
|
// The sorted list is based on the first name, when flag map might use the other names.
|
||||||
|
nameMap := make(map[string]string)
|
||||||
|
|
||||||
|
for n, f := range flags {
|
||||||
fName := strings.TrimPrefix(f.Names[0], "#")
|
fName := strings.TrimPrefix(f.Names[0], "#")
|
||||||
|
nameMap[fName] = n
|
||||||
if len(f.Names) == 1 {
|
if len(f.Names) == 1 {
|
||||||
list = append(list, fName)
|
list = append(list, fName)
|
||||||
continue
|
continue
|
||||||
|
@ -338,7 +343,7 @@ func sortFlags(flags map[string]*Flag) []*Flag {
|
||||||
sort.Sort(list)
|
sort.Sort(list)
|
||||||
result := make([]*Flag, len(list))
|
result := make([]*Flag, len(list))
|
||||||
for i, name := range list {
|
for i, name := range list {
|
||||||
result[i] = flags[name]
|
result[i] = flags[nameMap[name]]
|
||||||
}
|
}
|
||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
@ -473,7 +478,7 @@ var Usage = func() {
|
||||||
}
|
}
|
||||||
|
|
||||||
// FlagCount returns the number of flags that have been defined.
|
// FlagCount returns the number of flags that have been defined.
|
||||||
func (f *FlagSet) FlagCount() int { return len(f.formal) }
|
func (f *FlagSet) FlagCount() int { return len(sortFlags(f.formal)) }
|
||||||
|
|
||||||
// FlagCountUndeprecated returns the number of undeprecated flags that have been defined.
|
// FlagCountUndeprecated returns the number of undeprecated flags that have been defined.
|
||||||
func (f *FlagSet) FlagCountUndeprecated() int {
|
func (f *FlagSet) FlagCountUndeprecated() int {
|
||||||
|
|
|
@ -440,7 +440,7 @@ func TestFlagCounts(t *testing.T) {
|
||||||
fs.BoolVar(&flag, []string{"flag3"}, false, "regular flag")
|
fs.BoolVar(&flag, []string{"flag3"}, false, "regular flag")
|
||||||
fs.BoolVar(&flag, []string{"g", "#flag4", "-flag4"}, false, "regular flag")
|
fs.BoolVar(&flag, []string{"g", "#flag4", "-flag4"}, false, "regular flag")
|
||||||
|
|
||||||
if fs.FlagCount() != 10 {
|
if fs.FlagCount() != 6 {
|
||||||
t.Fatal("FlagCount wrong. ", fs.FlagCount())
|
t.Fatal("FlagCount wrong. ", fs.FlagCount())
|
||||||
}
|
}
|
||||||
if fs.FlagCountUndeprecated() != 4 {
|
if fs.FlagCountUndeprecated() != 4 {
|
||||||
|
@ -457,3 +457,50 @@ func TestFlagCounts(t *testing.T) {
|
||||||
t.Fatal("NFlag wrong. ", fs.NFlag())
|
t.Fatal("NFlag wrong. ", fs.NFlag())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Show up bug in sortFlags
|
||||||
|
func TestSortFlags(t *testing.T) {
|
||||||
|
fs := NewFlagSet("help TestSortFlags", ContinueOnError)
|
||||||
|
|
||||||
|
var err error
|
||||||
|
|
||||||
|
var b bool
|
||||||
|
fs.BoolVar(&b, []string{"b", "-banana"}, false, "usage")
|
||||||
|
|
||||||
|
err = fs.Parse([]string{"--banana=true"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("expected no error; got ", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
count := 0
|
||||||
|
|
||||||
|
fs.VisitAll(func(flag *Flag) {
|
||||||
|
count++
|
||||||
|
if flag == nil {
|
||||||
|
t.Fatal("VisitAll should not return a nil flag")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
flagcount := fs.FlagCount()
|
||||||
|
if flagcount != count {
|
||||||
|
t.Fatalf("FlagCount (%d) != number (%d) of elements visited", flagcount, count)
|
||||||
|
}
|
||||||
|
// Make sure its idempotent
|
||||||
|
if flagcount != fs.FlagCount() {
|
||||||
|
t.Fatalf("FlagCount (%d) != fs.FlagCount() (%d) of elements visited", flagcount, fs.FlagCount())
|
||||||
|
}
|
||||||
|
|
||||||
|
count = 0
|
||||||
|
fs.Visit(func(flag *Flag) {
|
||||||
|
count++
|
||||||
|
if flag == nil {
|
||||||
|
t.Fatal("Visit should not return a nil flag")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
nflag := fs.NFlag()
|
||||||
|
if nflag != count {
|
||||||
|
t.Fatalf("NFlag (%d) != number (%d) of elements visited", nflag, count)
|
||||||
|
}
|
||||||
|
if nflag != fs.NFlag() {
|
||||||
|
t.Fatalf("NFlag (%d) != fs.NFlag() (%d) of elements visited", nflag, fs.NFlag())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in a new issue