1
0
Fork 0
mirror of https://github.com/vbatts/go-mtree.git synced 2025-10-03 20:21:01 +00:00

govis: modernise errors

This code was written before %w was added to Go, and there were a fair
few mistakes in the copy-pasted error code.

Signed-off-by: Aleksa Sarai <cyphar@cyphar.com>
This commit is contained in:
Aleksa Sarai 2025-09-21 01:38:26 +10:00
parent 3ce83fca15
commit fbada2e081
No known key found for this signature in database
GPG key ID: 2897FAD2B7E9446F
4 changed files with 89 additions and 58 deletions

View file

@ -19,6 +19,8 @@
package govis package govis
import ( import (
"errors"
"fmt"
"strconv" "strconv"
"strings" "strings"
) )
@ -44,6 +46,25 @@ const (
VisWhite VisFlag = (VisSpace | VisTab | VisNewline) VisWhite VisFlag = (VisSpace | VisTab | VisNewline)
) )
// errUnknownVisFlagsError is a special value that lets you use [errors.Is]
// with [unknownVisFlagsError]. Don't actually return this value, use
// [unknownVisFlagsError] instead!
var errUnknownVisFlagsError = errors.New("unknown or unsupported vis flags")
// unknownVisFlagsError represents an error caused by unknown [VisFlag]s being
// passed to [Vis] or [Unvis].
type unknownVisFlagsError struct {
flags VisFlag
}
func (err unknownVisFlagsError) Is(target error) bool {
return target == errUnknownVisFlagsError
}
func (err unknownVisFlagsError) Error() string {
return fmt.Sprintf("%s contains unknown or unsupported flags %s", err.flags, err.flags&^visMask)
}
// String pretty-prints VisFlag. // String pretty-prints VisFlag.
func (vflags VisFlag) String() string { func (vflags VisFlag) String() string {
flagNames := []struct { flagNames := []struct {

View file

@ -215,3 +215,21 @@ func TestVisFlagsString(t *testing.T) {
}) })
} }
} }
func TestUnknownVisFlagsError(t *testing.T) {
t.Run("Vis", func(t *testing.T) {
enc, err := Vis("dummy text", visMask+1)
require.Error(t, err, "Vis with invalid flags should fail")
assert.ErrorIs(t, err, errUnknownVisFlagsError, "Vis with invalid flags")
assert.ErrorContains(t, err, "contains unknown or unsupported flags", "Vis with invalid flags")
assert.Equal(t, "", enc, "error Vis should return empty string")
})
t.Run("Unvis", func(t *testing.T) {
dec, err := Unvis("dummy text", visMask+1)
require.Error(t, err, "Unvis with invalid flags should fail")
assert.ErrorIs(t, err, errUnknownVisFlagsError, "Unvis with invalid flags")
assert.ErrorContains(t, err, "contains unknown or unsupported flags", "Vis with invalid flags")
assert.Equal(t, "", dec, "error Unvis should return empty string")
})
}

View file

@ -19,11 +19,19 @@
package govis package govis
import ( import (
"errors"
"fmt" "fmt"
"strconv" "strconv"
"unicode" "unicode"
) )
var (
errEndOfString = errors.New("unexpectedly reached end of string")
errUnknownEscapeChar = errors.New("unknown escape character")
errOutsideLatin1 = errors.New("outside latin-1 encoding")
errParseDigit = errors.New("could not parse digit")
)
// unvisParser stores the current state of the token parser. // unvisParser stores the current state of the token parser.
type unvisParser struct { type unvisParser struct {
tokens []rune tokens []rune
@ -39,7 +47,7 @@ func (p *unvisParser) Next() {
// Peek gets the current token. // Peek gets the current token.
func (p *unvisParser) Peek() (rune, error) { func (p *unvisParser) Peek() (rune, error) {
if p.idx >= len(p.tokens) { if p.idx >= len(p.tokens) {
return unicode.ReplacementChar, fmt.Errorf("tried to read past end of token list") return unicode.ReplacementChar, errEndOfString
} }
return p.tokens[p.idx], nil return p.tokens[p.idx], nil
} }
@ -75,23 +83,16 @@ func newParser(input string, flag VisFlag) *unvisParser {
func unvisPlainRune(p *unvisParser) ([]byte, error) { func unvisPlainRune(p *unvisParser) ([]byte, error) {
ch, err := p.Peek() ch, err := p.Peek()
if err != nil { if err != nil {
return nil, fmt.Errorf("plain rune: %c", ch) return nil, fmt.Errorf("plain rune: %w", err)
} }
p.Next() p.Next()
return []byte(string(ch)), nil
// XXX: Maybe we should not be converting to runes and then back to strings
// here. Are we sure that the byte-for-byte representation is the
// same? If the bytes change, then using these strings for paths will
// break...
str := string(ch)
return []byte(str), nil
} }
func unvisEscapeCStyle(p *unvisParser) ([]byte, error) { func unvisEscapeCStyle(p *unvisParser) ([]byte, error) {
ch, err := p.Peek() ch, err := p.Peek()
if err != nil { if err != nil {
return nil, fmt.Errorf("escape hex: %s", err) return nil, fmt.Errorf("escape cstyle: %w", err)
} }
output := "" output := ""
@ -120,7 +121,7 @@ func unvisEscapeCStyle(p *unvisParser) ([]byte, error) {
// Hidden marker. // Hidden marker.
default: default:
// XXX: We should probably allow falling through and return "\" here... // XXX: We should probably allow falling through and return "\" here...
return nil, fmt.Errorf("escape cstyle: unknown escape character: %q", ch) return nil, fmt.Errorf("escape cstyle: %w %q", errUnknownEscapeChar, ch)
} }
p.Next() p.Next()
@ -136,7 +137,7 @@ func unvisEscapeDigits(p *unvisParser, base int, force bool) ([]byte, error) {
if !force && i != 0xFF { if !force && i != 0xFF {
break break
} }
return nil, fmt.Errorf("escape base %d: %s", base, err) return nil, fmt.Errorf("escape base %d: %w", base, err)
} }
digit, err := strconv.ParseInt(string(ch), base, 8) digit, err := strconv.ParseInt(string(ch), base, 8)
@ -144,7 +145,7 @@ func unvisEscapeDigits(p *unvisParser, base int, force bool) ([]byte, error) {
if !force && i != 0xFF { if !force && i != 0xFF {
break break
} }
return nil, fmt.Errorf("escape base %d: could not parse digit: %s", base, err) return nil, fmt.Errorf("escape base %d: %w %q: %w", base, errParseDigit, ch, err)
} }
code = (code * base) + int(digit) code = (code * base) + int(digit)
@ -152,7 +153,7 @@ func unvisEscapeDigits(p *unvisParser, base int, force bool) ([]byte, error) {
} }
if code > unicode.MaxLatin1 { if code > unicode.MaxLatin1 {
return nil, fmt.Errorf("escape base %d: code %q outside latin-1 encoding", base, code) return nil, fmt.Errorf("escape base %d: code %+.2x %w", base, code, errOutsideLatin1)
} }
char := byte(code & 0xFF) char := byte(code & 0xFF)
@ -162,10 +163,10 @@ func unvisEscapeDigits(p *unvisParser, base int, force bool) ([]byte, error) {
func unvisEscapeCtrl(p *unvisParser, mask byte) ([]byte, error) { func unvisEscapeCtrl(p *unvisParser, mask byte) ([]byte, error) {
ch, err := p.Peek() ch, err := p.Peek()
if err != nil { if err != nil {
return nil, fmt.Errorf("escape ctrl: %s", err) return nil, fmt.Errorf("escape ctrl: %w", err)
} }
if ch > unicode.MaxLatin1 { if ch > unicode.MaxLatin1 {
return nil, fmt.Errorf("escape ctrl: code %q outside latin-1 encoding", ch) return nil, fmt.Errorf("escape ctrl: code %q %w", ch, errOutsideLatin1)
} }
char := byte(ch) & 0x1f char := byte(ch) & 0x1f
@ -180,7 +181,7 @@ func unvisEscapeCtrl(p *unvisParser, mask byte) ([]byte, error) {
func unvisEscapeMeta(p *unvisParser) ([]byte, error) { func unvisEscapeMeta(p *unvisParser) ([]byte, error) {
ch, err := p.Peek() ch, err := p.Peek()
if err != nil { if err != nil {
return nil, fmt.Errorf("escape meta: %s", err) return nil, fmt.Errorf("escape meta: %w", err)
} }
mask := byte(0x80) mask := byte(0x80)
@ -196,10 +197,10 @@ func unvisEscapeMeta(p *unvisParser) ([]byte, error) {
ch, err := p.Peek() ch, err := p.Peek()
if err != nil { if err != nil {
return nil, fmt.Errorf("escape meta1: %s", err) return nil, fmt.Errorf("escape meta1: %w", err)
} }
if ch > unicode.MaxLatin1 { if ch > unicode.MaxLatin1 {
return nil, fmt.Errorf("escape meta1: code %q outside latin-1 encoding", ch) return nil, fmt.Errorf("escape meta1: code %q %w", ch, errOutsideLatin1)
} }
// Add mask to character. // Add mask to character.
@ -207,13 +208,13 @@ func unvisEscapeMeta(p *unvisParser) ([]byte, error) {
return []byte{mask | byte(ch)}, nil return []byte{mask | byte(ch)}, nil
} }
return nil, fmt.Errorf("escape meta: unknown escape char: %s", err) return nil, fmt.Errorf("escape meta: %w %q", errUnknownEscapeChar, ch)
} }
func unvisEscapeSequence(p *unvisParser) ([]byte, error) { func unvisEscapeSequence(p *unvisParser) ([]byte, error) {
ch, err := p.Peek() ch, err := p.Peek()
if err != nil { if err != nil {
return nil, fmt.Errorf("escape sequence: %s", err) return nil, fmt.Errorf("escape sequence: %w", err)
} }
switch ch { switch ch {
@ -244,7 +245,7 @@ func unvisEscapeSequence(p *unvisParser) ([]byte, error) {
func unvisRune(p *unvisParser) ([]byte, error) { func unvisRune(p *unvisParser) ([]byte, error) {
ch, err := p.Peek() ch, err := p.Peek()
if err != nil { if err != nil {
return nil, fmt.Errorf("rune: %s", err) return nil, err
} }
switch ch { switch ch {
@ -258,11 +259,8 @@ func unvisRune(p *unvisParser) ([]byte, error) {
p.Next() p.Next()
return unvisEscapeDigits(p, 16, true) return unvisEscapeDigits(p, 16, true)
} }
fallthrough
default:
return unvisPlainRune(p)
} }
return unvisPlainRune(p)
} }
func unvis(p *unvisParser) (string, error) { func unvis(p *unvisParser) (string, error) {
@ -270,7 +268,7 @@ func unvis(p *unvisParser) (string, error) {
for !p.End() { for !p.End() {
ch, err := unvisRune(p) ch, err := unvisRune(p)
if err != nil { if err != nil {
return "", fmt.Errorf("input: %s", err) return "", err
} }
output = append(output, ch...) output = append(output, ch...)
} }
@ -281,15 +279,14 @@ func unvis(p *unvisParser) (string, error) {
// VisHTTPStyle flag is checked) and output the un-encoded version of the // VisHTTPStyle flag is checked) and output the un-encoded version of the
// encoded string. An error is returned if any escape sequences in the input // encoded string. An error is returned if any escape sequences in the input
// string were invalid. // string were invalid.
func Unvis(input string, flag VisFlag) (string, error) { func Unvis(input string, flags VisFlag) (string, error) {
// TODO: Check all of the VisFlag bits. if unknown := flags &^ visMask; unknown != 0 {
p := newParser(input, flag) return "", unknownVisFlagsError{flags: flags}
}
p := newParser(input, flags)
output, err := unvis(p) output, err := unvis(p)
if err != nil { if err != nil {
return "", fmt.Errorf("unvis: %s", err) return "", fmt.Errorf("unvis '%s': %w", input, err)
}
if !p.End() {
return "", fmt.Errorf("unvis: trailing characters at end of input")
} }
return output, nil return output, nil
} }

View file

@ -59,14 +59,14 @@ func isgraph(ch rune) bool {
// the plus side this is actually a benefit on the encoding side (it will // the plus side this is actually a benefit on the encoding side (it will
// always work with the simple unvis(3) implementation). It also means that we // always work with the simple unvis(3) implementation). It also means that we
// don't have to worry about different multi-byte encodings. // don't have to worry about different multi-byte encodings.
func vis(b byte, flag VisFlag) (string, error) { func vis(b byte, flag VisFlag) string {
// Treat the single-byte character as a rune. // Treat the single-byte character as a rune.
ch := rune(b) ch := rune(b)
// XXX: This is quite a horrible thing to support. // XXX: This is quite a horrible thing to support.
if flag&VisHTTPStyle == VisHTTPStyle { if flag&VisHTTPStyle == VisHTTPStyle {
if !ishttp(ch) { if !ishttp(ch) {
return "%" + fmt.Sprintf("%.2X", ch), nil return "%" + fmt.Sprintf("%.2X", ch)
} }
} }
@ -90,31 +90,31 @@ func vis(b byte, flag VisFlag) (string, error) {
if ch == '\\' && flag&VisNoSlash == 0 { if ch == '\\' && flag&VisNoSlash == 0 {
encoded += "\\" encoded += "\\"
} }
return encoded, nil return encoded
} }
// Try to use C-style escapes first. // Try to use C-style escapes first.
if flag&VisCStyle == VisCStyle { if flag&VisCStyle == VisCStyle {
switch ch { switch ch {
case ' ': case ' ':
return "\\s", nil return "\\s"
case '\n': case '\n':
return "\\n", nil return "\\n"
case '\r': case '\r':
return "\\r", nil return "\\r"
case '\b': case '\b':
return "\\b", nil return "\\b"
case '\a': case '\a':
return "\\a", nil return "\\a"
case '\v': case '\v':
return "\\v", nil return "\\v"
case '\t': case '\t':
return "\\t", nil return "\\t"
case '\f': case '\f':
return "\\f", nil return "\\f"
case '\x00': case '\x00':
// Output octal just to be safe. // Output octal just to be safe.
return "\\000", nil return "\\000"
} }
} }
@ -123,7 +123,7 @@ func vis(b byte, flag VisFlag) (string, error) {
// encoded as octal. // encoded as octal.
if flag&VisOctal == VisOctal || isgraph(ch) || ch&0x7f == ' ' { if flag&VisOctal == VisOctal || isgraph(ch) || ch&0x7f == ' ' {
// Always output three-character octal just to be safe. // Always output three-character octal just to be safe.
return fmt.Sprintf("\\%.3o", ch), nil return fmt.Sprintf("\\%.3o", ch)
} }
// Now we have to output meta or ctrl escapes. As far as I can tell, this // Now we have to output meta or ctrl escapes. As far as I can tell, this
@ -154,25 +154,20 @@ func vis(b byte, flag VisFlag) (string, error) {
encoded += fmt.Sprintf("-%c", b) encoded += fmt.Sprintf("-%c", b)
} }
return encoded, nil return encoded
} }
// Vis encodes the provided string to a BSD-compatible encoding using BSD's // Vis encodes the provided string to a BSD-compatible encoding using BSD's
// vis() flags. However, it will correctly handle multi-byte encoding (which is // vis() flags. However, it will correctly handle multi-byte encoding (which is
// not done properly by BSD's vis implementation). // not done properly by BSD's vis implementation).
func Vis(src string, flag VisFlag) (string, error) { func Vis(src string, flags VisFlag) (string, error) {
if flag&visMask != flag { if unknown := flags &^ visMask; unknown != 0 {
return "", fmt.Errorf("vis: flag %q contains unknown or unsupported flags", flag) return "", unknownVisFlagsError{flags: flags}
} }
output := "" output := ""
for _, ch := range []byte(src) { for _, ch := range []byte(src) {
encodedCh, err := vis(ch, flag) output += vis(ch, flags)
if err != nil {
return "", err
}
output += encodedCh
} }
return output, nil return output, nil
} }