Update bencode for weekly.2011-12-06

This commit is contained in:
Yves Junqueira 2011-12-17 17:35:20 +01:00
parent e65ed38f87
commit 335317be3b
4 changed files with 171 additions and 173 deletions

View file

@ -2,35 +2,35 @@ package bencode
import ( import (
"bytes" "bytes"
"errors"
"fmt" "fmt"
"os"
"reflect" "reflect"
"testing" "testing"
) )
type any interface{} type any interface{}
func checkMarshal(expected string, data any) (err os.Error) { func checkMarshal(expected string, data any) (err error) {
var b bytes.Buffer var b bytes.Buffer
if err = Marshal(&b, data); err != nil { if err = Marshal(&b, data); err != nil {
return return
} }
s := b.String() s := b.String()
if expected != s { if expected != s {
err = os.NewError(fmt.Sprintf("Expected %s got %s", expected, s)) err = errors.New(fmt.Sprintf("Expected %s got %s", expected, s))
return return
} }
return return
} }
func check(expected string, data any) (err os.Error) { func check(expected string, data any) (err error) {
if err = checkMarshal(expected, data); err != nil { if err = checkMarshal(expected, data); err != nil {
return return
} }
b2 := bytes.NewBufferString(expected) b2 := bytes.NewBufferString(expected)
val, err := Decode(b2) val, err := Decode(b2)
if err != nil { if err != nil {
err = os.NewError(fmt.Sprint("Failed decoding ", expected, " ", err)) err = errors.New(fmt.Sprint("Failed decoding ", expected, " ", err))
return return
} }
if err = checkFuzzyEqual(data, val); err != nil { if err = checkFuzzyEqual(data, val); err != nil {
@ -39,41 +39,41 @@ func check(expected string, data any) (err os.Error) {
return return
} }
func checkFuzzyEqual(a any, b any) (err os.Error) { func checkFuzzyEqual(a any, b any) (err error) {
if !fuzzyEqual(a, b) { if !fuzzyEqual(a, b) {
err = os.NewError(fmt.Sprint(a, " != ", b, err = errors.New(fmt.Sprint(a, " != ", b,
":", reflect.NewValue(a), "!=", reflect.NewValue(b))) ":", reflect.ValueOf(a), "!=", reflect.ValueOf(b)))
} }
return return
} }
func fuzzyEqual(a, b any) bool { func fuzzyEqual(a, b any) bool {
return fuzzyEqualValue(reflect.NewValue(a), reflect.NewValue(b)) return fuzzyEqualValue(reflect.ValueOf(a), reflect.ValueOf(b))
} }
func checkFuzzyEqualValue(a, b reflect.Value) (err os.Error) { func checkFuzzyEqualValue(a, b reflect.Value) (err error) {
if !fuzzyEqualValue(a, b) { if !fuzzyEqualValue(a, b) {
err = os.NewError(fmt.Sprint(a, " != ", b, err = errors.New(fmt.Sprint(a, " != ", b,
":", a.Interface(), "!=", b.Interface())) ":", a.Interface(), "!=", b.Interface()))
} }
return return
} }
func fuzzyEqualInt64(a int64, b reflect.Value) bool { func fuzzyEqualInt64(a int64, b reflect.Value) bool {
switch vb := b.(type) { switch vb := b; vb.Kind() {
case *reflect.IntValue: case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return a == (vb.Get()) return a == (vb.Int())
default: default:
return false return false
} }
return false return false
} }
func fuzzyEqualArrayOrSlice(va reflect.ArrayOrSliceValue, b reflect.Value) bool { func fuzzyEqualArrayOrSlice(va reflect.Value, b reflect.Value) bool {
switch vb := b.(type) { switch vb := b; vb.Kind() {
case *reflect.ArrayValue: case reflect.Array:
return fuzzyEqualArrayOrSlice2(va, vb) return fuzzyEqualArrayOrSlice2(va, vb)
case *reflect.SliceValue: case reflect.Slice:
return fuzzyEqualArrayOrSlice2(va, vb) return fuzzyEqualArrayOrSlice2(va, vb)
default: default:
return false return false
@ -82,21 +82,21 @@ func fuzzyEqualArrayOrSlice(va reflect.ArrayOrSliceValue, b reflect.Value) bool
} }
func deInterface(a reflect.Value) reflect.Value { func deInterface(a reflect.Value) reflect.Value {
switch va := a.(type) { switch va := a; va.Kind() {
case *reflect.InterfaceValue: case reflect.Interface:
return va.Elem() return va.Elem()
} }
return a return a
} }
func fuzzyEqualArrayOrSlice2(a reflect.ArrayOrSliceValue, b reflect.ArrayOrSliceValue) bool { func fuzzyEqualArrayOrSlice2(a reflect.Value, b reflect.Value) bool {
if a.Len() != b.Len() { if a.Len() != b.Len() {
return false return false
} }
for i := 0; i < a.Len(); i++ { for i := 0; i < a.Len(); i++ {
ea := deInterface(a.Elem(i)) ea := deInterface(a.Index(i))
eb := deInterface(b.Elem(i)) eb := deInterface(b.Index(i))
if !fuzzyEqualValue(ea, eb) { if !fuzzyEqualValue(ea, eb) {
return false return false
} }
@ -104,31 +104,31 @@ func fuzzyEqualArrayOrSlice2(a reflect.ArrayOrSliceValue, b reflect.ArrayOrSlice
return true return true
} }
func fuzzyEqualMap(a *reflect.MapValue, b *reflect.MapValue) bool { func fuzzyEqualMap(a reflect.Value, b reflect.Value) bool {
key := a.Type().(*reflect.MapType).Key() key := a.Type().Key()
if _, ok := key.(*reflect.StringType); !ok { if key.Kind() != reflect.String {
return false return false
} }
key = b.Type().(*reflect.MapType).Key() key = b.Type().Key()
if _, ok := key.(*reflect.StringType); !ok { if key.Kind() != reflect.String {
return false return false
} }
aKeys, bKeys := a.Keys(), b.Keys() aKeys, bKeys := a.MapKeys(), b.MapKeys()
if len(aKeys) != len(bKeys) { if len(aKeys) != len(bKeys) {
return false return false
} }
for _, k := range aKeys { for _, k := range aKeys {
if !fuzzyEqualValue(a.Elem(k), b.Elem(k)) { if !fuzzyEqualValue(a.MapIndex(k), b.MapIndex(k)) {
return false return false
} }
} }
return true return true
} }
func fuzzyEqualStruct(a *reflect.StructValue, b *reflect.StructValue) bool { func fuzzyEqualStruct(a reflect.Value, b reflect.Value) bool {
numA, numB := a.NumField(), b.NumField() numA, numB := a.NumField(), b.NumField()
if numA != numB { if numA != numB {
return false return false
@ -143,37 +143,37 @@ func fuzzyEqualStruct(a *reflect.StructValue, b *reflect.StructValue) bool {
} }
func fuzzyEqualValue(a, b reflect.Value) bool { func fuzzyEqualValue(a, b reflect.Value) bool {
switch va := a.(type) { switch va := a; va.Kind() {
case *reflect.StringValue: case reflect.String:
switch vb := b.(type) { switch vb := b; vb.Kind() {
case *reflect.StringValue: case reflect.String:
return va.Get() == vb.Get() return va.String() == vb.String()
default: default:
return false return false
} }
case *reflect.IntValue: case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return fuzzyEqualInt64(va.Get(), b) return fuzzyEqualInt64(va.Int(), b)
case *reflect.ArrayValue: case reflect.Array:
return fuzzyEqualArrayOrSlice(va, b) return fuzzyEqualArrayOrSlice(va, b)
case *reflect.SliceValue: case reflect.Slice:
return fuzzyEqualArrayOrSlice(va, b) return fuzzyEqualArrayOrSlice(va, b)
case *reflect.MapValue: case reflect.Map:
switch vb := b.(type) { switch vb := b; vb.Kind() {
case *reflect.MapValue: case reflect.Map:
return fuzzyEqualMap(va, vb) return fuzzyEqualMap(va, vb)
default: default:
return false return false
} }
case *reflect.StructValue: case reflect.Struct:
switch vb := b.(type) { switch vb := b; vb.Kind() {
case *reflect.StructValue: case reflect.Struct:
return fuzzyEqualStruct(va, vb) return fuzzyEqualStruct(va, vb)
default: default:
return false return false
} }
case *reflect.InterfaceValue: case reflect.Interface:
switch vb := b.(type) { switch vb := b; vb.Kind() {
case *reflect.InterfaceValue: case reflect.Interface:
return fuzzyEqualValue(va.Elem(), vb.Elem()) return fuzzyEqualValue(va.Elem(), vb.Elem())
default: default:
return false return false
@ -184,12 +184,12 @@ func fuzzyEqualValue(a, b reflect.Value) bool {
return false return false
} }
func checkUnmarshal(expected string, data any) (err os.Error) { func checkUnmarshal(expected string, data any) (err error) {
if err = checkMarshal(expected, data); err != nil { if err = checkMarshal(expected, data); err != nil {
return return
} }
dataValue := reflect.NewValue(data) dataValue := reflect.ValueOf(data)
newOne := reflect.MakeZero(dataValue.Type()) newOne := reflect.Zero(dataValue.Type())
buf := bytes.NewBufferString(expected) buf := bytes.NewBufferString(expected)
if err = UnmarshalValue(buf, newOne); err != nil { if err = UnmarshalValue(buf, newOne); err != nil {
return return
@ -223,7 +223,7 @@ func TestDecode(t *testing.T) {
} }
for _, sv := range tests { for _, sv := range tests {
if err := check(sv.s, sv.v); err != nil { if err := check(sv.s, sv.v); err != nil {
t.Error(err.String()) t.Error(err.Error())
} }
} }
} }
@ -262,7 +262,7 @@ func TestUnmarshal(t *testing.T) {
} }
for _, sv := range tests { for _, sv := range tests {
if err := checkUnmarshal(sv.s, sv.v); err != nil { if err := checkUnmarshal(sv.s, sv.v); err != nil {
t.Error(err.String()) t.Error(err.Error())
} }
} }
} }

View file

@ -8,9 +8,7 @@
package bencode package bencode
import ( import (
"container/vector"
"io" "io"
"os"
) )
// Decode a bencode stream // Decode a bencode stream
@ -23,7 +21,7 @@ import (
// //
// If Decode encounters a syntax error, it returns with err set to an // If Decode encounters a syntax error, it returns with err set to an
// instance of ParseError. See ParseError documentation for details. // instance of ParseError. See ParseError documentation for details.
func Decode(r io.Reader) (data interface{}, err os.Error) { func Decode(r io.Reader) (data interface{}, err error) {
jb := newDecoder(nil, nil) jb := newDecoder(nil, nil)
err = Parse(r, jb) err = Parse(r, jb)
if err == nil { if err == nil {
@ -35,7 +33,7 @@ func Decode(r io.Reader) (data interface{}, err os.Error) {
type decoder struct { type decoder struct {
// A value being constructed. // A value being constructed.
value interface{} value interface{}
// Container entity to flush into. Can be either vector.Vector or // Container entity to flush into. Can be either []interface{} or
// map[string]interface{}. // map[string]interface{}.
container interface{} container interface{}
// The index into the container interface. Either int or string. // The index into the container interface. Either int or string.
@ -58,19 +56,24 @@ func (j *decoder) Bool(b bool) { j.value = b }
func (j *decoder) Null() { j.value = nil } func (j *decoder) Null() { j.value = nil }
func (j *decoder) Array() { j.value = new(vector.Vector) } func (j *decoder) Array() { j.value = make([]interface{}, 0, 8) }
func (j *decoder) Map() { j.value = make(map[string]interface{}) } func (j *decoder) Map() { j.value = make(map[string]interface{}) }
func (j *decoder) Elem(i int) Builder { func (j *decoder) Elem(i int) Builder {
v, ok := j.value.(*vector.Vector) v, ok := j.value.([]interface{})
if !ok { if !ok {
v = new(vector.Vector) v = make([]interface{}, 0, 8)
j.value = v j.value = v
} }
if v.Len() <= i { lens := len(v)
v.Resize(i+1, (i+1)*2) if cap(v) <= lens {
news := make([]interface{}, 0, lens*2)
copy(news, j.value.([]interface{}))
v = news
} }
v = v[0 : lens+1]
j.value = v
return newDecoder(v, i) return newDecoder(v, i)
} }
@ -85,9 +88,9 @@ func (j *decoder) Key(s string) Builder {
func (j *decoder) Flush() { func (j *decoder) Flush() {
switch c := j.container.(type) { switch c := j.container.(type) {
case *vector.Vector: case []interface{}:
index := j.index.(int) index := j.index.(int)
c.Set(index, j.Copy()) c[index] = j.Copy()
case map[string]interface{}: case map[string]interface{}:
index := j.index.(string) index := j.index.(string)
c[index] = j.Copy() c[index] = j.Copy()
@ -96,9 +99,6 @@ func (j *decoder) Flush() {
// Get the value built by this builder. // Get the value built by this builder.
func (j *decoder) Copy() interface{} { func (j *decoder) Copy() interface{} {
switch v := j.value.(type) { // XXX correct for slice?
case *vector.Vector:
return v.Copy()
}
return j.value return j.value
} }

View file

@ -9,16 +9,16 @@ package bencode
import ( import (
"bufio" "bufio"
"errors"
"fmt" "fmt"
"io" "io"
"os"
"strconv" "strconv"
) )
type Reader interface { type Reader interface {
io.Reader io.Reader
ReadByte() (c byte, err os.Error) ReadByte() (c byte, err error)
UnreadByte() os.Error UnreadByte() error
} }
// Parser // Parser
@ -38,7 +38,6 @@ type Reader interface {
// nested data structure, using the "map keys" // nested data structure, using the "map keys"
// as struct field names. // as struct field names.
// A Builder is an interface implemented by clients and passed // A Builder is an interface implemented by clients and passed
// to the bencode parser. It gives clients full control over the // to the bencode parser. It gives clients full control over the
// eventual representation returned by the parser. // eventual representation returned by the parser.
@ -58,7 +57,7 @@ type Builder interface {
Flush() Flush()
} }
func collectInt(r Reader, delim byte) (buf []byte, err os.Error) { func collectInt(r Reader, delim byte) (buf []byte, err error) {
for { for {
var c byte var c byte
c, err = r.ReadByte() c, err = r.ReadByte()
@ -69,7 +68,7 @@ func collectInt(r Reader, delim byte) (buf []byte, err os.Error) {
return return
} }
if !(c == '-' || (c >= '0' && c <= '9')) { if !(c == '-' || (c >= '0' && c <= '9')) {
err = os.NewError("expected digit") err = errors.New("expected digit")
return return
} }
buf = append(buf, c) buf = append(buf, c)
@ -77,22 +76,22 @@ func collectInt(r Reader, delim byte) (buf []byte, err os.Error) {
return return
} }
func decodeInt64(r Reader, delim byte) (data int64, err os.Error) { func decodeInt64(r Reader, delim byte) (data int64, err error) {
buf, err := collectInt(r, delim) buf, err := collectInt(r, delim)
if err != nil { if err != nil {
return return
} }
data, err = strconv.Atoi64(string(buf)) data, err = strconv.ParseInt(string(buf), 10, 64)
return return
} }
func decodeString(r Reader) (data string, err os.Error) { func decodeString(r Reader) (data string, err error) {
length, err := decodeInt64(r, ':') length, err := decodeInt64(r, ':')
if err != nil { if err != nil {
return return
} }
if length < 0 { if length < 0 {
err = os.NewError("Bad string length") err = errors.New("Bad string length")
return return
} }
var buf = make([]byte, length) var buf = make([]byte, length)
@ -104,8 +103,7 @@ func decodeString(r Reader) (data string, err os.Error) {
return return
} }
func parse(r Reader, build Builder) (err os.Error) { func parse(r Reader, build Builder) (err error) {
Switch:
c, err := r.ReadByte() c, err := r.ReadByte()
if err != nil { if err != nil {
goto exit goto exit
@ -163,12 +161,12 @@ Switch:
var i2 uint64 var i2 uint64
str = string(buf) str = string(buf)
// If the number is exactly an integer, use that. // If the number is exactly an integer, use that.
if i, err = strconv.Atoi64(str); err == nil { if i, err = strconv.ParseInt(str, 10, 64); err == nil {
build.Int64(i) build.Int64(i)
} else if i2, err = strconv.Atoui64(str); err == nil { } else if i2, err = strconv.ParseUint(str, 10, 64); err == nil {
build.Uint64(i2) build.Uint64(i2)
} else { } else {
err = os.NewError("Bad integer") err = errors.New("Bad integer")
} }
case c == 'l': case c == 'l':
@ -194,7 +192,7 @@ Switch:
n++ n++
} }
default: default:
err = os.NewError(fmt.Sprintf("Unexpected character: '%v'", c)) err = errors.New(fmt.Sprintf("Unexpected character: '%v'", c))
} }
exit: exit:
build.Flush() build.Flush()
@ -203,7 +201,7 @@ exit:
// Parse parses the bencode stream and makes calls to // Parse parses the bencode stream and makes calls to
// the builder to construct a parsed representation. // the builder to construct a parsed representation.
func Parse(r io.Reader, builder Builder) (err os.Error) { func Parse(r io.Reader, builder Builder) (err error) {
rr, ok := r.(Reader) rr, ok := r.(Reader)
if !ok { if !ok {
rr = bufio.NewReader(r) rr = bufio.NewReader(r)

174
struct.go
View file

@ -10,9 +10,10 @@
package bencode package bencode
import ( import (
"errors"
"fmt" "fmt"
"io" "io"
"os"
"reflect" "reflect"
"sort" "sort"
"strings" "strings"
@ -22,35 +23,35 @@ type structBuilder struct {
val reflect.Value val reflect.Value
// if map_ != nil, write val to map_[key] on each change // if map_ != nil, write val to map_[key] on each change
map_ *reflect.MapValue map_ reflect.Value
key reflect.Value key reflect.Value
} }
var nobuilder *structBuilder var nobuilder *structBuilder
func isfloat(v reflect.Value) bool { func isfloat(v reflect.Value) bool {
switch v.(type) { switch v.Kind() {
case *reflect.FloatValue: case reflect.Float32, reflect.Float64:
return true return true
} }
return false return false
} }
func setfloat(v reflect.Value, f float64) { func setfloat(v reflect.Value, f float64) {
switch v := v.(type) { switch v.Kind() {
case *reflect.FloatValue: case reflect.Float32, reflect.Float64:
v.Set(f) v.SetFloat(f)
} }
} }
func setint(val reflect.Value, i int64) { func setint(val reflect.Value, i int64) {
switch v := val.(type) { switch v := val; v.Kind() {
case *reflect.IntValue: case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
v.Set(int64(i)) v.SetInt(int64(i))
case *reflect.UintValue: case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
v.Set(uint64(i)) v.SetUint(uint64(i))
case *reflect.InterfaceValue: case reflect.Interface:
v.Set(reflect.NewValue(i)) v.Set(reflect.ValueOf(i))
} }
} }
@ -60,8 +61,8 @@ func (b *structBuilder) Flush() {
if b == nil { if b == nil {
return return
} }
if b.map_ != nil { if b.map_.IsValid() {
b.map_.SetElem(b.key, b.val) b.map_.SetMapIndex(b.key, b.val)
} }
} }
@ -106,11 +107,11 @@ func (b *structBuilder) String(s string) {
return return
} }
switch v := b.val.(type) { switch v := b.val; v.Kind() {
case *reflect.StringValue: case reflect.String:
v.Set(s) v.SetString(s)
case *reflect.InterfaceValue: case reflect.Interface:
v.Set(reflect.NewValue(s)) v.Set(reflect.ValueOf(s))
} }
} }
@ -118,9 +119,9 @@ func (b *structBuilder) Array() {
if b == nil { if b == nil {
return return
} }
if v, ok := b.val.(*reflect.SliceValue); ok { if v := b.val; v.Kind() == reflect.Slice {
if v.IsNil() { if v.IsNil() {
v.Set(reflect.MakeSlice(v.Type().(*reflect.SliceType), 0, 8)) v.Set(reflect.MakeSlice(v.Type(), 0, 8))
} }
} }
} }
@ -129,12 +130,12 @@ func (b *structBuilder) Elem(i int) Builder {
if b == nil || i < 0 { if b == nil || i < 0 {
return nobuilder return nobuilder
} }
switch v := b.val.(type) { switch v := b.val; v.Kind() {
case *reflect.ArrayValue: case reflect.Array:
if i < v.Len() { if i < v.Len() {
return &structBuilder{val: v.Elem(i)} return &structBuilder{val: v.Index(i)}
} }
case *reflect.SliceValue: case reflect.Slice:
if i >= v.Cap() { if i >= v.Cap() {
n := v.Cap() n := v.Cap()
if n < 8 { if n < 8 {
@ -143,7 +144,7 @@ func (b *structBuilder) Elem(i int) Builder {
for n <= i { for n <= i {
n *= 2 n *= 2
} }
nv := reflect.MakeSlice(v.Type().(*reflect.SliceType), v.Len(), n) nv := reflect.MakeSlice(v.Type(), v.Len(), n)
reflect.Copy(nv, v) reflect.Copy(nv, v)
v.Set(nv) v.Set(nv)
} }
@ -151,7 +152,7 @@ func (b *structBuilder) Elem(i int) Builder {
v.SetLen(i + 1) v.SetLen(i + 1)
} }
if i < v.Len() { if i < v.Len() {
return &structBuilder{val: v.Elem(i)} return &structBuilder{val: v.Index(i)}
} }
} }
return nobuilder return nobuilder
@ -161,16 +162,16 @@ func (b *structBuilder) Map() {
if b == nil { if b == nil {
return return
} }
if v, ok := b.val.(*reflect.PtrValue); ok && v.IsNil() { if v := b.val; v.Kind() == reflect.Ptr && v.IsNil() {
if v.IsNil() { if v.IsNil() {
v.PointTo(reflect.MakeZero(v.Type().(*reflect.PtrType).Elem())) v.Set(reflect.Zero(v.Type().Elem()).Addr())
b.Flush() b.Flush()
} }
b.map_ = nil b.map_ = reflect.Value{}
b.val = v.Elem() b.val = v.Elem()
} }
if v, ok := b.val.(*reflect.MapValue); ok && v.IsNil() { if v := b.val; v.Kind() == reflect.Map && v.IsNil() {
v.Set(reflect.MakeMap(v.Type().(*reflect.MapType))) v.Set(reflect.MakeMap(v.Type()))
} }
} }
@ -178,28 +179,28 @@ func (b *structBuilder) Key(k string) Builder {
if b == nil { if b == nil {
return nobuilder return nobuilder
} }
switch v := reflect.Indirect(b.val).(type) { switch v := reflect.Indirect(b.val); v.Kind() {
case *reflect.StructValue: case reflect.Struct:
t := v.Type().(*reflect.StructType) t := v.Type()
// Case-insensitive field lookup. // Case-insensitive field lookup.
k = strings.ToLower(k) k = strings.ToLower(k)
for i := 0; i < t.NumField(); i++ { for i := 0; i < t.NumField(); i++ {
field := t.Field(i) field := t.Field(i)
if strings.ToLower(field.Tag) == k || if strings.ToLower(string(field.Tag)) == k ||
strings.ToLower(field.Name) == k { strings.ToLower(field.Name) == k {
return &structBuilder{val: v.Field(i)} return &structBuilder{val: v.Field(i)}
} }
} }
case *reflect.MapValue: case reflect.Map:
t := v.Type().(*reflect.MapType) t := v.Type()
if t.Key() != reflect.Typeof(k) { if t.Key() != reflect.TypeOf(k) {
break break
} }
key := reflect.NewValue(k) key := reflect.ValueOf(k)
elem := v.Elem(key) elem := v.MapIndex(key)
if elem == nil { if !elem.IsValid() {
v.SetElem(key, reflect.MakeZero(t.Elem())) v.SetMapIndex(key, reflect.Zero(t.Elem()))
elem = v.Elem(key) elem = v.MapIndex(key)
} }
return &structBuilder{val: elem, map_: v, key: key} return &structBuilder{val: elem, map_: v, key: key}
} }
@ -258,26 +259,26 @@ func (b *structBuilder) Key(k string) Builder {
// slice of the correct type. // slice of the correct type.
// //
func Unmarshal(r io.Reader, val interface{}) (err os.Error) { func Unmarshal(r io.Reader, val interface{}) (err error) {
// If e represents a value, the answer won't get back to the // If e represents a value, the answer won't get back to the
// caller. Make sure it's a pointer. // caller. Make sure it's a pointer.
if _, ok := reflect.Typeof(val).(*reflect.PtrType); !ok { if reflect.TypeOf(val).Kind() != reflect.Ptr {
err = os.ErrorString("Attempt to unmarshal into a non-pointer") err = errors.New("Attempt to unmarshal into a non-pointer")
return return
} }
err = UnmarshalValue(r, reflect.NewValue(val)) err = UnmarshalValue(r, reflect.ValueOf(val))
return return
} }
// This API is public primarily to make testing easier, but it is available if you // This API is public primarily to make testing easier, but it is available if you
// have a use for it. // have a use for it.
func UnmarshalValue(r io.Reader, v reflect.Value) (err os.Error) { func UnmarshalValue(r io.Reader, v reflect.Value) (err error) {
var b *structBuilder var b *structBuilder
// If val is a pointer to a slice, we append to the slice. // If val is a pointer to a slice, we append to the slice.
if ptr, ok := v.(*reflect.PtrValue); ok { if ptr := v; ptr.Kind() == reflect.Ptr {
if slice, ok := ptr.Elem().(*reflect.SliceValue); ok { if slice := ptr.Elem(); slice.Kind() == reflect.Slice {
b = &structBuilder{val: slice} b = &structBuilder{val: slice}
} }
} }
@ -294,17 +295,17 @@ type MarshalError struct {
T reflect.Type T reflect.Type
} }
func (e *MarshalError) String() string { func (e *MarshalError) Error() string {
return "bencode cannot encode value of type " + e.T.String() return "bencode cannot encode value of type " + e.T.String()
} }
func writeArrayOrSlice(w io.Writer, val reflect.ArrayOrSliceValue) (err os.Error) { func writeArrayOrSlice(w io.Writer, val reflect.Value) (err error) {
_, err = fmt.Fprint(w, "l") _, err = fmt.Fprint(w, "l")
if err != nil { if err != nil {
return return
} }
for i := 0; i < val.Len(); i++ { for i := 0; i < val.Len(); i++ {
if err := writeValue(w, val.Elem(i)); err != nil { if err := writeValue(w, val.Index(i)); err != nil {
return err return err
} }
} }
@ -331,7 +332,7 @@ func (a StringValueArray) Less(i, j int) bool { return a[i].key < a[j].key }
func (a StringValueArray) Swap(i, j int) { a[i], a[j] = a[j], a[i] } func (a StringValueArray) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
func writeSVList(w io.Writer, svList StringValueArray) (err os.Error) { func writeSVList(w io.Writer, svList StringValueArray) (err error) {
sort.Sort(svList) sort.Sort(svList)
for _, sv := range svList { for _, sv := range svList {
@ -351,10 +352,9 @@ func writeSVList(w io.Writer, svList StringValueArray) (err os.Error) {
return return
} }
func writeMap(w io.Writer, val reflect.Value) (err error) {
func writeMap(w io.Writer, val *reflect.MapValue) (err os.Error) { key := val.Type().Key()
key := val.Type().(*reflect.MapType).Key() if key.Kind() != reflect.String {
if _, ok := key.(*reflect.StringType); !ok {
return &MarshalError{val.Type()} return &MarshalError{val.Type()}
} }
_, err = fmt.Fprint(w, "d") _, err = fmt.Fprint(w, "d")
@ -362,14 +362,14 @@ func writeMap(w io.Writer, val *reflect.MapValue) (err os.Error) {
return return
} }
keys := val.Keys() keys := val.MapKeys()
// Sort keys // Sort keys
svList := make(StringValueArray, len(keys)) svList := make(StringValueArray, len(keys))
for i, key := range keys { for i, key := range keys {
svList[i].key = key.(*reflect.StringValue).Get() svList[i].key = key.String()
svList[i].value = val.Elem(key) svList[i].value = val.MapIndex(key)
} }
err = writeSVList(w, svList) err = writeSVList(w, svList)
@ -384,13 +384,13 @@ func writeMap(w io.Writer, val *reflect.MapValue) (err os.Error) {
return return
} }
func writeStruct(w io.Writer, val *reflect.StructValue) (err os.Error) { func writeStruct(w io.Writer, val reflect.Value) (err error) {
_, err = fmt.Fprint(w, "d") _, err = fmt.Fprint(w, "d")
if err != nil { if err != nil {
return return
} }
typ := val.Type().(*reflect.StructType) typ := val.Type()
numFields := val.NumField() numFields := val.NumField()
svList := make(StringValueArray, numFields) svList := make(StringValueArray, numFields)
@ -399,7 +399,7 @@ func writeStruct(w io.Writer, val *reflect.StructValue) (err os.Error) {
field := typ.Field(i) field := typ.Field(i)
key := field.Name key := field.Name
if len(field.Tag) > 0 { if len(field.Tag) > 0 {
key = field.Tag key = string(field.Tag)
} }
svList[i].key = key svList[i].key = key
svList[i].value = val.Field(i) svList[i].value = val.Field(i)
@ -417,29 +417,29 @@ func writeStruct(w io.Writer, val *reflect.StructValue) (err os.Error) {
return return
} }
func writeValue(w io.Writer, val reflect.Value) (err os.Error) { func writeValue(w io.Writer, val reflect.Value) (err error) {
if val == nil { if !val.IsValid() {
err = os.NewError("Can't write null value") err = errors.New("Can't write null value")
return return
} }
switch v := val.(type) { switch v := val; v.Kind() {
case *reflect.StringValue: case reflect.String:
s := v.Get() s := v.String()
_, err = fmt.Fprintf(w, "%d:%s", len(s), s) _, err = fmt.Fprintf(w, "%d:%s", len(s), s)
case *reflect.IntValue: case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
_, err = fmt.Fprintf(w, "i%de", v.Get()) _, err = fmt.Fprintf(w, "i%de", v.Int())
case *reflect.UintValue: case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
_, err = fmt.Fprintf(w, "i%de", v.Get()) _, err = fmt.Fprintf(w, "i%de", v.Uint())
case *reflect.ArrayValue: case reflect.Array:
err = writeArrayOrSlice(w, v) err = writeArrayOrSlice(w, v)
case *reflect.SliceValue: case reflect.Slice:
err = writeArrayOrSlice(w, v) err = writeArrayOrSlice(w, v)
case *reflect.MapValue: case reflect.Map:
err = writeMap(w, v) err = writeMap(w, v)
case *reflect.StructValue: case reflect.Struct:
err = writeStruct(w, v) err = writeStruct(w, v)
case *reflect.InterfaceValue: case reflect.Interface:
err = writeValue(w, v.Elem()) err = writeValue(w, v.Elem())
default: default:
err = &MarshalError{val.Type()} err = &MarshalError{val.Type()}
@ -448,11 +448,11 @@ func writeValue(w io.Writer, val reflect.Value) (err os.Error) {
} }
func isValueNil(val reflect.Value) bool { func isValueNil(val reflect.Value) bool {
if val == nil { if !val.IsValid() {
return true return true
} }
switch v := val.(type) { switch v := val; v.Kind() {
case *reflect.InterfaceValue: case reflect.Interface:
return isValueNil(v.Elem()) return isValueNil(v.Elem())
default: default:
return false return false
@ -460,6 +460,6 @@ func isValueNil(val reflect.Value) bool {
return false return false
} }
func Marshal(w io.Writer, val interface{}) os.Error { func Marshal(w io.Writer, val interface{}) error {
return writeValue(w, reflect.NewValue(val)) return writeValue(w, reflect.ValueOf(val))
} }