// BSON library for Go // // Copyright (c) 2010-2012 - Gustavo Niemeyer // // All rights reserved. // // Redistribution and use in source and binary forms, with or without // modification, are permitted provided that the following conditions are met: // // 1. Redistributions of source code must retain the above copyright notice, this // list of conditions and the following disclaimer. // 2. Redistributions in binary form must reproduce the above copyright notice, // this list of conditions and the following disclaimer in the documentation // and/or other materials provided with the distribution. // // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND // ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED // WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE // DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR // ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES // (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; // LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND // ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS // SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. // gobson - BSON library for Go. package bson import ( "fmt" "math" "net/url" "reflect" "sync" "time" ) type decoder struct { in []byte i int docType reflect.Type } var typeM = reflect.TypeOf(M{}) func newDecoder(in []byte) *decoder { return &decoder{in, 0, typeM} } // -------------------------------------------------------------------------- // Some helper functions. func corrupted() { panic("Document is corrupted") } func settableValueOf(i interface{}) reflect.Value { v := reflect.ValueOf(i) sv := reflect.New(v.Type()).Elem() sv.Set(v) return sv } // -------------------------------------------------------------------------- // Unmarshaling of documents. const ( setterUnknown = iota setterNone setterType setterAddr ) var setterStyle map[reflect.Type]int var setterIface reflect.Type var setterMutex sync.RWMutex func init() { var iface Setter setterIface = reflect.TypeOf(&iface).Elem() setterStyle = make(map[reflect.Type]int) } func getSetter(outt reflect.Type, out reflect.Value) Setter { setterMutex.RLock() style := setterStyle[outt] setterMutex.RUnlock() if style == setterNone { return nil } if style == setterUnknown { setterMutex.Lock() defer setterMutex.Unlock() if outt.Implements(setterIface) { setterStyle[outt] = setterType } else if reflect.PtrTo(outt).Implements(setterIface) { setterStyle[outt] = setterAddr } else { setterStyle[outt] = setterNone return nil } style = setterStyle[outt] } if style == setterAddr { if !out.CanAddr() { return nil } out = out.Addr() } else if outt.Kind() == reflect.Ptr && out.IsNil() { out.Set(reflect.New(outt.Elem())) } return out.Interface().(Setter) } func clearMap(m reflect.Value) { var none reflect.Value for _, k := range m.MapKeys() { m.SetMapIndex(k, none) } } func (d *decoder) readDocTo(out reflect.Value) { var elemType reflect.Type outt := out.Type() outk := outt.Kind() for { if outk == reflect.Ptr && out.IsNil() { out.Set(reflect.New(outt.Elem())) } if setter := getSetter(outt, out); setter != nil { var raw Raw d.readDocTo(reflect.ValueOf(&raw)) err := setter.SetBSON(raw) if _, ok := err.(*TypeError); err != nil && !ok { panic(err) } return } if outk == reflect.Ptr { out = out.Elem() outt = out.Type() outk = out.Kind() continue } break } var fieldsMap map[string]fieldInfo var inlineMap reflect.Value start := d.i origout := out if outk == reflect.Interface { if d.docType.Kind() == reflect.Map { mv := reflect.MakeMap(d.docType) out.Set(mv) out = mv } else { dv := reflect.New(d.docType).Elem() out.Set(dv) out = dv } outt = out.Type() outk = outt.Kind() } docType := d.docType keyType := typeString convertKey := false switch outk { case reflect.Map: keyType = outt.Key() if keyType.Kind() != reflect.String { panic("BSON map must have string keys. Got: " + outt.String()) } if keyType != typeString { convertKey = true } elemType = outt.Elem() if elemType == typeIface { d.docType = outt } if out.IsNil() { out.Set(reflect.MakeMap(out.Type())) } else if out.Len() > 0 { clearMap(out) } case reflect.Struct: if outt != typeRaw { sinfo, err := getStructInfo(out.Type()) if err != nil { panic(err) } fieldsMap = sinfo.FieldsMap out.Set(sinfo.Zero) if sinfo.InlineMap != -1 { inlineMap = out.Field(sinfo.InlineMap) if !inlineMap.IsNil() && inlineMap.Len() > 0 { clearMap(inlineMap) } elemType = inlineMap.Type().Elem() if elemType == typeIface { d.docType = inlineMap.Type() } } } case reflect.Slice: switch outt.Elem() { case typeDocElem: origout.Set(d.readDocElems(outt)) return case typeRawDocElem: origout.Set(d.readRawDocElems(outt)) return } fallthrough default: panic("Unsupported document type for unmarshalling: " + out.Type().String()) } end := int(d.readInt32()) end += d.i - 4 if end <= d.i || end > len(d.in) || d.in[end-1] != '\x00' { corrupted() } for d.in[d.i] != '\x00' { kind := d.readByte() name := d.readCStr() if d.i >= end { corrupted() } switch outk { case reflect.Map: e := reflect.New(elemType).Elem() if d.readElemTo(e, kind) { k := reflect.ValueOf(name) if convertKey { k = k.Convert(keyType) } out.SetMapIndex(k, e) } case reflect.Struct: if outt == typeRaw { d.dropElem(kind) } else { if info, ok := fieldsMap[name]; ok { if info.Inline == nil { d.readElemTo(out.Field(info.Num), kind) } else { d.readElemTo(out.FieldByIndex(info.Inline), kind) } } else if inlineMap.IsValid() { if inlineMap.IsNil() { inlineMap.Set(reflect.MakeMap(inlineMap.Type())) } e := reflect.New(elemType).Elem() if d.readElemTo(e, kind) { inlineMap.SetMapIndex(reflect.ValueOf(name), e) } } else { d.dropElem(kind) } } case reflect.Slice: } if d.i >= end { corrupted() } } d.i++ // '\x00' if d.i != end { corrupted() } d.docType = docType if outt == typeRaw { out.Set(reflect.ValueOf(Raw{0x03, d.in[start:d.i]})) } } func (d *decoder) readArrayDocTo(out reflect.Value) { end := int(d.readInt32()) end += d.i - 4 if end <= d.i || end > len(d.in) || d.in[end-1] != '\x00' { corrupted() } i := 0 l := out.Len() for d.in[d.i] != '\x00' { if i >= l { panic("Length mismatch on array field") } kind := d.readByte() for d.i < end && d.in[d.i] != '\x00' { d.i++ } if d.i >= end { corrupted() } d.i++ d.readElemTo(out.Index(i), kind) if d.i >= end { corrupted() } i++ } if i != l { panic("Length mismatch on array field") } d.i++ // '\x00' if d.i != end { corrupted() } } func (d *decoder) readSliceDoc(t reflect.Type) interface{} { tmp := make([]reflect.Value, 0, 8) elemType := t.Elem() end := int(d.readInt32()) end += d.i - 4 if end <= d.i || end > len(d.in) || d.in[end-1] != '\x00' { corrupted() } for d.in[d.i] != '\x00' { kind := d.readByte() for d.i < end && d.in[d.i] != '\x00' { d.i++ } if d.i >= end { corrupted() } d.i++ e := reflect.New(elemType).Elem() if d.readElemTo(e, kind) { tmp = append(tmp, e) } if d.i >= end { corrupted() } } d.i++ // '\x00' if d.i != end { corrupted() } n := len(tmp) slice := reflect.MakeSlice(t, n, n) for i := 0; i != n; i++ { slice.Index(i).Set(tmp[i]) } return slice.Interface() } var typeSlice = reflect.TypeOf([]interface{}{}) var typeIface = typeSlice.Elem() func (d *decoder) readDocElems(typ reflect.Type) reflect.Value { docType := d.docType d.docType = typ slice := make([]DocElem, 0, 8) d.readDocWith(func(kind byte, name string) { e := DocElem{Name: name} v := reflect.ValueOf(&e.Value) if d.readElemTo(v.Elem(), kind) { slice = append(slice, e) } }) slicev := reflect.New(typ).Elem() slicev.Set(reflect.ValueOf(slice)) d.docType = docType return slicev } func (d *decoder) readRawDocElems(typ reflect.Type) reflect.Value { docType := d.docType d.docType = typ slice := make([]RawDocElem, 0, 8) d.readDocWith(func(kind byte, name string) { e := RawDocElem{Name: name} v := reflect.ValueOf(&e.Value) if d.readElemTo(v.Elem(), kind) { slice = append(slice, e) } }) slicev := reflect.New(typ).Elem() slicev.Set(reflect.ValueOf(slice)) d.docType = docType return slicev } func (d *decoder) readDocWith(f func(kind byte, name string)) { end := int(d.readInt32()) end += d.i - 4 if end <= d.i || end > len(d.in) || d.in[end-1] != '\x00' { corrupted() } for d.in[d.i] != '\x00' { kind := d.readByte() name := d.readCStr() if d.i >= end { corrupted() } f(kind, name) if d.i >= end { corrupted() } } d.i++ // '\x00' if d.i != end { corrupted() } } // -------------------------------------------------------------------------- // Unmarshaling of individual elements within a document. var blackHole = settableValueOf(struct{}{}) func (d *decoder) dropElem(kind byte) { d.readElemTo(blackHole, kind) } // Attempt to decode an element from the document and put it into out. // If the types are not compatible, the returned ok value will be // false and out will be unchanged. func (d *decoder) readElemTo(out reflect.Value, kind byte) (good bool) { start := d.i if kind == '\x03' { // Special case for documents. Delegate to readDocTo(). switch out.Kind() { case reflect.Interface, reflect.Ptr, reflect.Struct, reflect.Map: d.readDocTo(out) default: switch out.Interface().(type) { case D: out.Set(d.readDocElems(out.Type())) case RawD: out.Set(d.readRawDocElems(out.Type())) default: d.readDocTo(blackHole) } } return true } var in interface{} switch kind { case 0x01: // Float64 in = d.readFloat64() case 0x02: // UTF-8 string in = d.readStr() case 0x03: // Document panic("Can't happen. Handled above.") case 0x04: // Array outt := out.Type() for outt.Kind() == reflect.Ptr { outt = outt.Elem() } switch outt.Kind() { case reflect.Array: d.readArrayDocTo(out) return true case reflect.Slice: in = d.readSliceDoc(outt) default: in = d.readSliceDoc(typeSlice) } case 0x05: // Binary b := d.readBinary() if b.Kind == 0x00 || b.Kind == 0x02 { in = b.Data } else { in = b } case 0x06: // Undefined (obsolete, but still seen in the wild) in = Undefined case 0x07: // ObjectId in = ObjectId(d.readBytes(12)) case 0x08: // Bool in = d.readBool() case 0x09: // Timestamp // MongoDB handles timestamps as milliseconds. i := d.readInt64() if i == -62135596800000 { in = time.Time{} // In UTC for convenience. } else { in = time.Unix(i/1e3, i%1e3*1e6) } case 0x0A: // Nil in = nil case 0x0B: // RegEx in = d.readRegEx() case 0x0D: // JavaScript without scope in = JavaScript{Code: d.readStr()} case 0x0E: // Symbol in = Symbol(d.readStr()) case 0x0F: // JavaScript with scope d.i += 4 // Skip length js := JavaScript{d.readStr(), make(M)} d.readDocTo(reflect.ValueOf(js.Scope)) in = js case 0x10: // Int32 in = int(d.readInt32()) case 0x11: // Mongo-specific timestamp in = MongoTimestamp(d.readInt64()) case 0x12: // Int64 in = d.readInt64() case 0x7F: // Max key in = MaxKey case 0xFF: // Min key in = MinKey default: panic(fmt.Sprintf("Unknown element kind (0x%02X)", kind)) } outt := out.Type() if outt == typeRaw { out.Set(reflect.ValueOf(Raw{kind, d.in[start:d.i]})) return true } if setter := getSetter(outt, out); setter != nil { err := setter.SetBSON(Raw{kind, d.in[start:d.i]}) if err == SetZero { out.Set(reflect.Zero(outt)) return true } if err == nil { return true } if _, ok := err.(*TypeError); !ok { panic(err) } return false } if in == nil { out.Set(reflect.Zero(outt)) return true } outk := outt.Kind() // Dereference and initialize pointer if necessary. first := true for outk == reflect.Ptr { if !out.IsNil() { out = out.Elem() } else { elem := reflect.New(outt.Elem()) if first { // Only set if value is compatible. first = false defer func(out, elem reflect.Value) { if good { out.Set(elem) } }(out, elem) } else { out.Set(elem) } out = elem } outt = out.Type() outk = outt.Kind() } inv := reflect.ValueOf(in) if outt == inv.Type() { out.Set(inv) return true } switch outk { case reflect.Interface: out.Set(inv) return true case reflect.String: switch inv.Kind() { case reflect.String: out.SetString(inv.String()) return true case reflect.Slice: if b, ok := in.([]byte); ok { out.SetString(string(b)) return true } } case reflect.Slice, reflect.Array: // Remember, array (0x04) slices are built with the correct // element type. If we are here, must be a cross BSON kind // conversion (e.g. 0x05 unmarshalling on string). if outt.Elem().Kind() != reflect.Uint8 { break } switch inv.Kind() { case reflect.String: slice := []byte(inv.String()) out.Set(reflect.ValueOf(slice)) return true case reflect.Slice: switch outt.Kind() { case reflect.Array: reflect.Copy(out, inv) case reflect.Slice: out.SetBytes(inv.Bytes()) } return true } case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: switch inv.Kind() { case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: out.SetInt(inv.Int()) return true case reflect.Float32, reflect.Float64: out.SetInt(int64(inv.Float())) return true case reflect.Bool: if inv.Bool() { out.SetInt(1) } else { out.SetInt(0) } return true case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: panic("Can't happen. No uint types in BSON?") } case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: switch inv.Kind() { case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: out.SetUint(uint64(inv.Int())) return true case reflect.Float32, reflect.Float64: out.SetUint(uint64(inv.Float())) return true case reflect.Bool: if inv.Bool() { out.SetUint(1) } else { out.SetUint(0) } return true case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: panic("Can't happen. No uint types in BSON.") } case reflect.Float32, reflect.Float64: switch inv.Kind() { case reflect.Float32, reflect.Float64: out.SetFloat(inv.Float()) return true case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: out.SetFloat(float64(inv.Int())) return true case reflect.Bool: if inv.Bool() { out.SetFloat(1) } else { out.SetFloat(0) } return true case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: panic("Can't happen. No uint types in BSON?") } case reflect.Bool: switch inv.Kind() { case reflect.Bool: out.SetBool(inv.Bool()) return true case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: out.SetBool(inv.Int() != 0) return true case reflect.Float32, reflect.Float64: out.SetBool(inv.Float() != 0) return true case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: panic("Can't happen. No uint types in BSON?") } case reflect.Struct: if outt == typeURL && inv.Kind() == reflect.String { u, err := url.Parse(inv.String()) if err != nil { panic(err) } out.Set(reflect.ValueOf(u).Elem()) return true } } return false } // -------------------------------------------------------------------------- // Parsers of basic types. func (d *decoder) readRegEx() RegEx { re := RegEx{} re.Pattern = d.readCStr() re.Options = d.readCStr() return re } func (d *decoder) readBinary() Binary { l := d.readInt32() b := Binary{} b.Kind = d.readByte() b.Data = d.readBytes(l) if b.Kind == 0x02 && len(b.Data) >= 4 { // Weird obsolete format with redundant length. b.Data = b.Data[4:] } return b } func (d *decoder) readStr() string { l := d.readInt32() b := d.readBytes(l - 1) if d.readByte() != '\x00' { corrupted() } return string(b) } func (d *decoder) readCStr() string { start := d.i end := start l := len(d.in) for ; end != l; end++ { if d.in[end] == '\x00' { break } } d.i = end + 1 if d.i > l { corrupted() } return string(d.in[start:end]) } func (d *decoder) readBool() bool { if d.readByte() == 1 { return true } return false } func (d *decoder) readFloat64() float64 { return math.Float64frombits(uint64(d.readInt64())) } func (d *decoder) readInt32() int32 { b := d.readBytes(4) return int32((uint32(b[0]) << 0) | (uint32(b[1]) << 8) | (uint32(b[2]) << 16) | (uint32(b[3]) << 24)) } func (d *decoder) readInt64() int64 { b := d.readBytes(8) return int64((uint64(b[0]) << 0) | (uint64(b[1]) << 8) | (uint64(b[2]) << 16) | (uint64(b[3]) << 24) | (uint64(b[4]) << 32) | (uint64(b[5]) << 40) | (uint64(b[6]) << 48) | (uint64(b[7]) << 56)) } func (d *decoder) readByte() byte { i := d.i d.i++ if d.i > len(d.in) { corrupted() } return d.in[i] } func (d *decoder) readBytes(length int32) []byte { start := d.i d.i += int(length) if d.i > len(d.in) { corrupted() } return d.in[start : start+int(length)] }