diff --git a/pkg/govis/govis.go b/pkg/govis/govis.go index 003b182..4200af2 100644 --- a/pkg/govis/govis.go +++ b/pkg/govis/govis.go @@ -19,6 +19,8 @@ package govis import ( + "errors" + "fmt" "strconv" "strings" ) @@ -44,6 +46,25 @@ const ( VisWhite VisFlag = (VisSpace | VisTab | VisNewline) ) +// errUnknownVisFlagsError is a special value that lets you use [errors.Is] +// with [unknownVisFlagsError]. Don't actually return this value, use +// [unknownVisFlagsError] instead! +var errUnknownVisFlagsError = errors.New("unknown or unsupported vis flags") + +// unknownVisFlagsError represents an error caused by unknown [VisFlag]s being +// passed to [Vis] or [Unvis]. +type unknownVisFlagsError struct { + flags VisFlag +} + +func (err unknownVisFlagsError) Is(target error) bool { + return target == errUnknownVisFlagsError +} + +func (err unknownVisFlagsError) Error() string { + return fmt.Sprintf("%s contains unknown or unsupported flags %s", err.flags, err.flags&^visMask) +} + // String pretty-prints VisFlag. func (vflags VisFlag) String() string { flagNames := []struct { diff --git a/pkg/govis/govis_test.go b/pkg/govis/govis_test.go index 5027030..e3c9245 100644 --- a/pkg/govis/govis_test.go +++ b/pkg/govis/govis_test.go @@ -215,3 +215,21 @@ func TestVisFlagsString(t *testing.T) { }) } } + +func TestUnknownVisFlagsError(t *testing.T) { + t.Run("Vis", func(t *testing.T) { + enc, err := Vis("dummy text", visMask+1) + require.Error(t, err, "Vis with invalid flags should fail") + assert.ErrorIs(t, err, errUnknownVisFlagsError, "Vis with invalid flags") + assert.ErrorContains(t, err, "contains unknown or unsupported flags", "Vis with invalid flags") + assert.Equal(t, "", enc, "error Vis should return empty string") + }) + + t.Run("Unvis", func(t *testing.T) { + dec, err := Unvis("dummy text", visMask+1) + require.Error(t, err, "Unvis with invalid flags should fail") + assert.ErrorIs(t, err, errUnknownVisFlagsError, "Unvis with invalid flags") + assert.ErrorContains(t, err, "contains unknown or unsupported flags", "Vis with invalid flags") + assert.Equal(t, "", dec, "error Unvis should return empty string") + }) +} diff --git a/pkg/govis/unvis.go b/pkg/govis/unvis.go index ddf7725..5dbf935 100644 --- a/pkg/govis/unvis.go +++ b/pkg/govis/unvis.go @@ -19,11 +19,19 @@ package govis import ( + "errors" "fmt" "strconv" "unicode" ) +var ( + errEndOfString = errors.New("unexpectedly reached end of string") + errUnknownEscapeChar = errors.New("unknown escape character") + errOutsideLatin1 = errors.New("outside latin-1 encoding") + errParseDigit = errors.New("could not parse digit") +) + // unvisParser stores the current state of the token parser. type unvisParser struct { tokens []rune @@ -39,7 +47,7 @@ func (p *unvisParser) Next() { // Peek gets the current token. func (p *unvisParser) Peek() (rune, error) { if p.idx >= len(p.tokens) { - return unicode.ReplacementChar, fmt.Errorf("tried to read past end of token list") + return unicode.ReplacementChar, errEndOfString } return p.tokens[p.idx], nil } @@ -75,23 +83,16 @@ func newParser(input string, flag VisFlag) *unvisParser { func unvisPlainRune(p *unvisParser) ([]byte, error) { ch, err := p.Peek() if err != nil { - return nil, fmt.Errorf("plain rune: %c", ch) + return nil, fmt.Errorf("plain rune: %w", err) } p.Next() - - // XXX: Maybe we should not be converting to runes and then back to strings - // here. Are we sure that the byte-for-byte representation is the - // same? If the bytes change, then using these strings for paths will - // break... - - str := string(ch) - return []byte(str), nil + return []byte(string(ch)), nil } func unvisEscapeCStyle(p *unvisParser) ([]byte, error) { ch, err := p.Peek() if err != nil { - return nil, fmt.Errorf("escape hex: %s", err) + return nil, fmt.Errorf("escape cstyle: %w", err) } output := "" @@ -120,7 +121,7 @@ func unvisEscapeCStyle(p *unvisParser) ([]byte, error) { // Hidden marker. default: // XXX: We should probably allow falling through and return "\" here... - return nil, fmt.Errorf("escape cstyle: unknown escape character: %q", ch) + return nil, fmt.Errorf("escape cstyle: %w %q", errUnknownEscapeChar, ch) } p.Next() @@ -136,7 +137,7 @@ func unvisEscapeDigits(p *unvisParser, base int, force bool) ([]byte, error) { if !force && i != 0xFF { break } - return nil, fmt.Errorf("escape base %d: %s", base, err) + return nil, fmt.Errorf("escape base %d: %w", base, err) } digit, err := strconv.ParseInt(string(ch), base, 8) @@ -144,7 +145,7 @@ func unvisEscapeDigits(p *unvisParser, base int, force bool) ([]byte, error) { if !force && i != 0xFF { break } - return nil, fmt.Errorf("escape base %d: could not parse digit: %s", base, err) + return nil, fmt.Errorf("escape base %d: %w %q: %w", base, errParseDigit, ch, err) } code = (code * base) + int(digit) @@ -152,7 +153,7 @@ func unvisEscapeDigits(p *unvisParser, base int, force bool) ([]byte, error) { } if code > unicode.MaxLatin1 { - return nil, fmt.Errorf("escape base %d: code %q outside latin-1 encoding", base, code) + return nil, fmt.Errorf("escape base %d: code %+.2x %w", base, code, errOutsideLatin1) } char := byte(code & 0xFF) @@ -162,10 +163,10 @@ func unvisEscapeDigits(p *unvisParser, base int, force bool) ([]byte, error) { func unvisEscapeCtrl(p *unvisParser, mask byte) ([]byte, error) { ch, err := p.Peek() if err != nil { - return nil, fmt.Errorf("escape ctrl: %s", err) + return nil, fmt.Errorf("escape ctrl: %w", err) } if ch > unicode.MaxLatin1 { - return nil, fmt.Errorf("escape ctrl: code %q outside latin-1 encoding", ch) + return nil, fmt.Errorf("escape ctrl: code %q %w", ch, errOutsideLatin1) } char := byte(ch) & 0x1f @@ -180,7 +181,7 @@ func unvisEscapeCtrl(p *unvisParser, mask byte) ([]byte, error) { func unvisEscapeMeta(p *unvisParser) ([]byte, error) { ch, err := p.Peek() if err != nil { - return nil, fmt.Errorf("escape meta: %s", err) + return nil, fmt.Errorf("escape meta: %w", err) } mask := byte(0x80) @@ -196,10 +197,10 @@ func unvisEscapeMeta(p *unvisParser) ([]byte, error) { ch, err := p.Peek() if err != nil { - return nil, fmt.Errorf("escape meta1: %s", err) + return nil, fmt.Errorf("escape meta1: %w", err) } if ch > unicode.MaxLatin1 { - return nil, fmt.Errorf("escape meta1: code %q outside latin-1 encoding", ch) + return nil, fmt.Errorf("escape meta1: code %q %w", ch, errOutsideLatin1) } // Add mask to character. @@ -207,13 +208,13 @@ func unvisEscapeMeta(p *unvisParser) ([]byte, error) { return []byte{mask | byte(ch)}, nil } - return nil, fmt.Errorf("escape meta: unknown escape char: %s", err) + return nil, fmt.Errorf("escape meta: %w %q", errUnknownEscapeChar, ch) } func unvisEscapeSequence(p *unvisParser) ([]byte, error) { ch, err := p.Peek() if err != nil { - return nil, fmt.Errorf("escape sequence: %s", err) + return nil, fmt.Errorf("escape sequence: %w", err) } switch ch { @@ -244,7 +245,7 @@ func unvisEscapeSequence(p *unvisParser) ([]byte, error) { func unvisRune(p *unvisParser) ([]byte, error) { ch, err := p.Peek() if err != nil { - return nil, fmt.Errorf("rune: %s", err) + return nil, err } switch ch { @@ -258,11 +259,8 @@ func unvisRune(p *unvisParser) ([]byte, error) { p.Next() return unvisEscapeDigits(p, 16, true) } - fallthrough - - default: - return unvisPlainRune(p) } + return unvisPlainRune(p) } func unvis(p *unvisParser) (string, error) { @@ -270,7 +268,7 @@ func unvis(p *unvisParser) (string, error) { for !p.End() { ch, err := unvisRune(p) if err != nil { - return "", fmt.Errorf("input: %s", err) + return "", err } output = append(output, ch...) } @@ -281,15 +279,14 @@ func unvis(p *unvisParser) (string, error) { // VisHTTPStyle flag is checked) and output the un-encoded version of the // encoded string. An error is returned if any escape sequences in the input // string were invalid. -func Unvis(input string, flag VisFlag) (string, error) { - // TODO: Check all of the VisFlag bits. - p := newParser(input, flag) +func Unvis(input string, flags VisFlag) (string, error) { + if unknown := flags &^ visMask; unknown != 0 { + return "", unknownVisFlagsError{flags: flags} + } + p := newParser(input, flags) output, err := unvis(p) if err != nil { - return "", fmt.Errorf("unvis: %s", err) - } - if !p.End() { - return "", fmt.Errorf("unvis: trailing characters at end of input") + return "", fmt.Errorf("unvis '%s': %w", input, err) } return output, nil } diff --git a/pkg/govis/vis.go b/pkg/govis/vis.go index 1af15eb..620b49a 100644 --- a/pkg/govis/vis.go +++ b/pkg/govis/vis.go @@ -59,14 +59,14 @@ func isgraph(ch rune) bool { // the plus side this is actually a benefit on the encoding side (it will // always work with the simple unvis(3) implementation). It also means that we // don't have to worry about different multi-byte encodings. -func vis(b byte, flag VisFlag) (string, error) { +func vis(b byte, flag VisFlag) string { // Treat the single-byte character as a rune. ch := rune(b) // XXX: This is quite a horrible thing to support. if flag&VisHTTPStyle == VisHTTPStyle { if !ishttp(ch) { - return "%" + fmt.Sprintf("%.2X", ch), nil + return "%" + fmt.Sprintf("%.2X", ch) } } @@ -90,31 +90,31 @@ func vis(b byte, flag VisFlag) (string, error) { if ch == '\\' && flag&VisNoSlash == 0 { encoded += "\\" } - return encoded, nil + return encoded } // Try to use C-style escapes first. if flag&VisCStyle == VisCStyle { switch ch { case ' ': - return "\\s", nil + return "\\s" case '\n': - return "\\n", nil + return "\\n" case '\r': - return "\\r", nil + return "\\r" case '\b': - return "\\b", nil + return "\\b" case '\a': - return "\\a", nil + return "\\a" case '\v': - return "\\v", nil + return "\\v" case '\t': - return "\\t", nil + return "\\t" case '\f': - return "\\f", nil + return "\\f" case '\x00': // Output octal just to be safe. - return "\\000", nil + return "\\000" } } @@ -123,7 +123,7 @@ func vis(b byte, flag VisFlag) (string, error) { // encoded as octal. if flag&VisOctal == VisOctal || isgraph(ch) || ch&0x7f == ' ' { // Always output three-character octal just to be safe. - return fmt.Sprintf("\\%.3o", ch), nil + return fmt.Sprintf("\\%.3o", ch) } // Now we have to output meta or ctrl escapes. As far as I can tell, this @@ -154,25 +154,20 @@ func vis(b byte, flag VisFlag) (string, error) { encoded += fmt.Sprintf("-%c", b) } - return encoded, nil + return encoded } // Vis encodes the provided string to a BSD-compatible encoding using BSD's // vis() flags. However, it will correctly handle multi-byte encoding (which is // not done properly by BSD's vis implementation). -func Vis(src string, flag VisFlag) (string, error) { - if flag&visMask != flag { - return "", fmt.Errorf("vis: flag %q contains unknown or unsupported flags", flag) +func Vis(src string, flags VisFlag) (string, error) { + if unknown := flags &^ visMask; unknown != 0 { + return "", unknownVisFlagsError{flags: flags} } output := "" for _, ch := range []byte(src) { - encodedCh, err := vis(ch, flag) - if err != nil { - return "", err - } - output += encodedCh + output += vis(ch, flags) } - return output, nil }