1
0
Fork 0
mirror of https://github.com/vbatts/go-mtree.git synced 2025-10-03 20:21:01 +00:00

govis: modernise errors

This code was written before %w was added to Go, and there were a fair
few mistakes in the copy-pasted error code.

Signed-off-by: Aleksa Sarai <cyphar@cyphar.com>
This commit is contained in:
Aleksa Sarai 2025-09-21 01:38:26 +10:00
parent 3ce83fca15
commit fbada2e081
No known key found for this signature in database
GPG key ID: 2897FAD2B7E9446F
4 changed files with 89 additions and 58 deletions

View file

@ -19,6 +19,8 @@
package govis
import (
"errors"
"fmt"
"strconv"
"strings"
)
@ -44,6 +46,25 @@ const (
VisWhite VisFlag = (VisSpace | VisTab | VisNewline)
)
// errUnknownVisFlagsError is a sentinel that exists only to be used as an
// [errors.Is] target: unknownVisFlagsError.Is reports true when compared
// against it, so callers can test for "unknown flags" without caring which
// flags were rejected. Don't actually return this value, return an
// [unknownVisFlagsError] (which records the offending flags) instead!
var errUnknownVisFlagsError = errors.New("unknown or unsupported vis flags")
// unknownVisFlagsError represents an error caused by unknown [VisFlag]s being
// passed to [Vis] or [Unvis]. Its Is method makes it match
// errUnknownVisFlagsError under [errors.Is].
type unknownVisFlagsError struct {
// flags is the complete flag set supplied by the caller (known and unknown
// bits alike); Error derives the unknown bits with flags &^ visMask.
flags VisFlag
}
// Is lets [errors.Is] match any unknownVisFlagsError against the
// errUnknownVisFlagsError sentinel, regardless of which flags it carries.
// Every other target (including nil) is reported as not matching.
func (err unknownVisFlagsError) Is(target error) bool {
	switch target {
	case errUnknownVisFlagsError:
		return true
	default:
		return false
	}
}
// Error implements the error interface. The message names the full flag set
// that was supplied first, followed by only those bits that fall outside
// visMask (i.e. the unknown or unsupported ones).
func (err unknownVisFlagsError) Error() string {
	supplied := err.flags
	unknown := supplied &^ visMask
	return fmt.Sprintf("%s contains unknown or unsupported flags %s", supplied, unknown)
}
// String pretty-prints VisFlag.
func (vflags VisFlag) String() string {
flagNames := []struct {

View file

@ -215,3 +215,21 @@ func TestVisFlagsString(t *testing.T) {
})
}
}
// TestUnknownVisFlagsError checks that both Vis and Unvis reject a flag value
// containing bits outside visMask, that the returned error matches
// errUnknownVisFlagsError under errors.Is and carries the expected message,
// and that no partial output is returned alongside the error.
func TestUnknownVisFlagsError(t *testing.T) {
	// NOTE(review): visMask+1 is assumed to set a bit outside the mask, which
	// holds when visMask is a contiguous run of low bits — confirm.
	t.Run("Vis", func(t *testing.T) {
		enc, err := Vis("dummy text", visMask+1)
		require.Error(t, err, "Vis with invalid flags should fail")
		assert.ErrorIs(t, err, errUnknownVisFlagsError, "Vis with invalid flags")
		assert.ErrorContains(t, err, "contains unknown or unsupported flags", "Vis with invalid flags")
		assert.Equal(t, "", enc, "error Vis should return empty string")
	})
	t.Run("Unvis", func(t *testing.T) {
		dec, err := Unvis("dummy text", visMask+1)
		require.Error(t, err, "Unvis with invalid flags should fail")
		assert.ErrorIs(t, err, errUnknownVisFlagsError, "Unvis with invalid flags")
		// Failure message names Unvis (it previously said "Vis", a
		// copy-paste slip that mislabelled which function had failed).
		assert.ErrorContains(t, err, "contains unknown or unsupported flags", "Unvis with invalid flags")
		assert.Equal(t, "", dec, "error Unvis should return empty string")
	})
}

View file

@ -19,11 +19,19 @@
package govis
import (
"errors"
"fmt"
"strconv"
"unicode"
)
// Sentinel errors wrapped (via %w) by the unvis parsing helpers so that a
// decoding failure can be classified with [errors.Is].
var (
// errEndOfString is returned by unvisParser.Peek when the read index is at
// or past the end of the token list.
errEndOfString = errors.New("unexpectedly reached end of string")
// errUnknownEscapeChar is wrapped by unvisEscapeCStyle and unvisEscapeMeta
// when the current character is not an escape they recognise.
errUnknownEscapeChar = errors.New("unknown escape character")
// errOutsideLatin1 is wrapped when a decoded code or the current character
// is greater than unicode.MaxLatin1.
errOutsideLatin1 = errors.New("outside latin-1 encoding")
// errParseDigit is wrapped by unvisEscapeDigits, together with the
// underlying strconv.ParseInt error, when a digit cannot be parsed in the
// requested base.
errParseDigit = errors.New("could not parse digit")
)
// unvisParser stores the current state of the token parser.
type unvisParser struct {
tokens []rune
@ -39,7 +47,7 @@ func (p *unvisParser) Next() {
// Peek gets the current token.
func (p *unvisParser) Peek() (rune, error) {
if p.idx >= len(p.tokens) {
return unicode.ReplacementChar, fmt.Errorf("tried to read past end of token list")
return unicode.ReplacementChar, errEndOfString
}
return p.tokens[p.idx], nil
}
@ -75,23 +83,16 @@ func newParser(input string, flag VisFlag) *unvisParser {
func unvisPlainRune(p *unvisParser) ([]byte, error) {
ch, err := p.Peek()
if err != nil {
return nil, fmt.Errorf("plain rune: %c", ch)
return nil, fmt.Errorf("plain rune: %w", err)
}
p.Next()
// XXX: Maybe we should not be converting to runes and then back to strings
// here. Are we sure that the byte-for-byte representation is the
// same? If the bytes change, then using these strings for paths will
// break...
str := string(ch)
return []byte(str), nil
return []byte(string(ch)), nil
}
func unvisEscapeCStyle(p *unvisParser) ([]byte, error) {
ch, err := p.Peek()
if err != nil {
return nil, fmt.Errorf("escape hex: %s", err)
return nil, fmt.Errorf("escape cstyle: %w", err)
}
output := ""
@ -120,7 +121,7 @@ func unvisEscapeCStyle(p *unvisParser) ([]byte, error) {
// Hidden marker.
default:
// XXX: We should probably allow falling through and return "\" here...
return nil, fmt.Errorf("escape cstyle: unknown escape character: %q", ch)
return nil, fmt.Errorf("escape cstyle: %w %q", errUnknownEscapeChar, ch)
}
p.Next()
@ -136,7 +137,7 @@ func unvisEscapeDigits(p *unvisParser, base int, force bool) ([]byte, error) {
if !force && i != 0xFF {
break
}
return nil, fmt.Errorf("escape base %d: %s", base, err)
return nil, fmt.Errorf("escape base %d: %w", base, err)
}
digit, err := strconv.ParseInt(string(ch), base, 8)
@ -144,7 +145,7 @@ func unvisEscapeDigits(p *unvisParser, base int, force bool) ([]byte, error) {
if !force && i != 0xFF {
break
}
return nil, fmt.Errorf("escape base %d: could not parse digit: %s", base, err)
return nil, fmt.Errorf("escape base %d: %w %q: %w", base, errParseDigit, ch, err)
}
code = (code * base) + int(digit)
@ -152,7 +153,7 @@ func unvisEscapeDigits(p *unvisParser, base int, force bool) ([]byte, error) {
}
if code > unicode.MaxLatin1 {
return nil, fmt.Errorf("escape base %d: code %q outside latin-1 encoding", base, code)
return nil, fmt.Errorf("escape base %d: code %+.2x %w", base, code, errOutsideLatin1)
}
char := byte(code & 0xFF)
@ -162,10 +163,10 @@ func unvisEscapeDigits(p *unvisParser, base int, force bool) ([]byte, error) {
func unvisEscapeCtrl(p *unvisParser, mask byte) ([]byte, error) {
ch, err := p.Peek()
if err != nil {
return nil, fmt.Errorf("escape ctrl: %s", err)
return nil, fmt.Errorf("escape ctrl: %w", err)
}
if ch > unicode.MaxLatin1 {
return nil, fmt.Errorf("escape ctrl: code %q outside latin-1 encoding", ch)
return nil, fmt.Errorf("escape ctrl: code %q %w", ch, errOutsideLatin1)
}
char := byte(ch) & 0x1f
@ -180,7 +181,7 @@ func unvisEscapeCtrl(p *unvisParser, mask byte) ([]byte, error) {
func unvisEscapeMeta(p *unvisParser) ([]byte, error) {
ch, err := p.Peek()
if err != nil {
return nil, fmt.Errorf("escape meta: %s", err)
return nil, fmt.Errorf("escape meta: %w", err)
}
mask := byte(0x80)
@ -196,10 +197,10 @@ func unvisEscapeMeta(p *unvisParser) ([]byte, error) {
ch, err := p.Peek()
if err != nil {
return nil, fmt.Errorf("escape meta1: %s", err)
return nil, fmt.Errorf("escape meta1: %w", err)
}
if ch > unicode.MaxLatin1 {
return nil, fmt.Errorf("escape meta1: code %q outside latin-1 encoding", ch)
return nil, fmt.Errorf("escape meta1: code %q %w", ch, errOutsideLatin1)
}
// Add mask to character.
@ -207,13 +208,13 @@ func unvisEscapeMeta(p *unvisParser) ([]byte, error) {
return []byte{mask | byte(ch)}, nil
}
return nil, fmt.Errorf("escape meta: unknown escape char: %s", err)
return nil, fmt.Errorf("escape meta: %w %q", errUnknownEscapeChar, ch)
}
func unvisEscapeSequence(p *unvisParser) ([]byte, error) {
ch, err := p.Peek()
if err != nil {
return nil, fmt.Errorf("escape sequence: %s", err)
return nil, fmt.Errorf("escape sequence: %w", err)
}
switch ch {
@ -244,7 +245,7 @@ func unvisEscapeSequence(p *unvisParser) ([]byte, error) {
func unvisRune(p *unvisParser) ([]byte, error) {
ch, err := p.Peek()
if err != nil {
return nil, fmt.Errorf("rune: %s", err)
return nil, err
}
switch ch {
@ -258,11 +259,8 @@ func unvisRune(p *unvisParser) ([]byte, error) {
p.Next()
return unvisEscapeDigits(p, 16, true)
}
fallthrough
default:
return unvisPlainRune(p)
}
return unvisPlainRune(p)
}
func unvis(p *unvisParser) (string, error) {
@ -270,7 +268,7 @@ func unvis(p *unvisParser) (string, error) {
for !p.End() {
ch, err := unvisRune(p)
if err != nil {
return "", fmt.Errorf("input: %s", err)
return "", err
}
output = append(output, ch...)
}
@ -281,15 +279,14 @@ func unvis(p *unvisParser) (string, error) {
// VisHTTPStyle flag is checked) and output the un-encoded version of the
// encoded string. An error is returned if any escape sequences in the input
// string were invalid.
func Unvis(input string, flag VisFlag) (string, error) {
// TODO: Check all of the VisFlag bits.
p := newParser(input, flag)
func Unvis(input string, flags VisFlag) (string, error) {
if unknown := flags &^ visMask; unknown != 0 {
return "", unknownVisFlagsError{flags: flags}
}
p := newParser(input, flags)
output, err := unvis(p)
if err != nil {
return "", fmt.Errorf("unvis: %s", err)
}
if !p.End() {
return "", fmt.Errorf("unvis: trailing characters at end of input")
return "", fmt.Errorf("unvis '%s': %w", input, err)
}
return output, nil
}

View file

@ -59,14 +59,14 @@ func isgraph(ch rune) bool {
// the plus side this is actually a benefit on the encoding side (it will
// always work with the simple unvis(3) implementation). It also means that we
// don't have to worry about different multi-byte encodings.
func vis(b byte, flag VisFlag) (string, error) {
func vis(b byte, flag VisFlag) string {
// Treat the single-byte character as a rune.
ch := rune(b)
// XXX: This is quite a horrible thing to support.
if flag&VisHTTPStyle == VisHTTPStyle {
if !ishttp(ch) {
return "%" + fmt.Sprintf("%.2X", ch), nil
return "%" + fmt.Sprintf("%.2X", ch)
}
}
@ -90,31 +90,31 @@ func vis(b byte, flag VisFlag) (string, error) {
if ch == '\\' && flag&VisNoSlash == 0 {
encoded += "\\"
}
return encoded, nil
return encoded
}
// Try to use C-style escapes first.
if flag&VisCStyle == VisCStyle {
switch ch {
case ' ':
return "\\s", nil
return "\\s"
case '\n':
return "\\n", nil
return "\\n"
case '\r':
return "\\r", nil
return "\\r"
case '\b':
return "\\b", nil
return "\\b"
case '\a':
return "\\a", nil
return "\\a"
case '\v':
return "\\v", nil
return "\\v"
case '\t':
return "\\t", nil
return "\\t"
case '\f':
return "\\f", nil
return "\\f"
case '\x00':
// Output octal just to be safe.
return "\\000", nil
return "\\000"
}
}
@ -123,7 +123,7 @@ func vis(b byte, flag VisFlag) (string, error) {
// encoded as octal.
if flag&VisOctal == VisOctal || isgraph(ch) || ch&0x7f == ' ' {
// Always output three-character octal just to be safe.
return fmt.Sprintf("\\%.3o", ch), nil
return fmt.Sprintf("\\%.3o", ch)
}
// Now we have to output meta or ctrl escapes. As far as I can tell, this
@ -154,25 +154,20 @@ func vis(b byte, flag VisFlag) (string, error) {
encoded += fmt.Sprintf("-%c", b)
}
return encoded, nil
return encoded
}
// Vis encodes the provided string to a BSD-compatible encoding using BSD's
// vis() flags. However, it will correctly handle multi-byte encoding (which is
// not done properly by BSD's vis implementation).
func Vis(src string, flag VisFlag) (string, error) {
if flag&visMask != flag {
return "", fmt.Errorf("vis: flag %q contains unknown or unsupported flags", flag)
func Vis(src string, flags VisFlag) (string, error) {
if unknown := flags &^ visMask; unknown != 0 {
return "", unknownVisFlagsError{flags: flags}
}
output := ""
for _, ch := range []byte(src) {
encodedCh, err := vis(ch, flag)
if err != nil {
return "", err
}
output += encodedCh
output += vis(ch, flags)
}
return output, nil
}