diff --git a/omaha/package.go b/omaha/package.go new file mode 100644 index 0000000..497d239 --- /dev/null +++ b/omaha/package.go @@ -0,0 +1,112 @@ +// Copyright 2015 CoreOS, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package omaha + +import ( + "crypto/sha1" + "crypto/sha256" + "encoding/base64" + "errors" + "io" + "os" + "path/filepath" +) + +var ( + PackageHashMismatchError = errors.New("package hash is invalid") + PackageSizeMismatchError = errors.New("package size is invalid") +) + +// Package represents a single downloadable file. The Sha256 attribute +// is not a standard part of the Omaha protocol which only uses Sha1. +type Package struct { + Name string `xml:"name,attr"` + Sha1 string `xml:"hash,attr"` + Sha256 string `xml:"sha256,attr,omitempty"` + Size uint64 `xml:"size,attr"` + Required bool `xml:"required,attr"` +} + +func (p *Package) FromPath(name string) error { + f, err := os.Open(name) + if err != nil { + return err + } + defer f.Close() + + err = p.FromReader(f) + if err != nil { + return err + } + + p.Name = filepath.Base(name) + return nil +} + +func (p *Package) FromReader(r io.Reader) error { + sha1b64, sha256b64, n, err := multihash(r) + if err != nil { + return err + } + + p.Sha1 = sha1b64 + p.Sha256 = sha256b64 + p.Size = uint64(n) + return nil +} + +func (p *Package) Verify(dir string) error { + f, err := os.Open(filepath.Join(dir, p.Name)) + if err != nil { + return err + } + defer f.Close() + + return p.VerifyReader(f) +} + +func (p *Package) VerifyReader(r io.Reader) error { + sha1b64, sha256b64, n, err := multihash(r) + if err != nil { + return err + } + + if p.Size != uint64(n) { + return PackageSizeMismatchError + } + + if p.Sha1 != sha1b64 { + return PackageHashMismatchError + } + + // Allow Sha256 to be empty since it is a protocol extension. + if p.Sha256 != "" && p.Sha256 != sha256b64 { + return PackageHashMismatchError + } + + return nil +} + +func multihash(r io.Reader) (sha1b64, sha256b64 string, n int64, err error) { + h1 := sha1.New() + h256 := sha256.New() + if n, err = io.Copy(io.MultiWriter(h1, h256), r); err != nil { + return + } + + sha1b64 = base64.StdEncoding.EncodeToString(h1.Sum(nil)) + sha256b64 = base64.StdEncoding.EncodeToString(h256.Sum(nil)) + return +} diff --git a/omaha/package_test.go b/omaha/package_test.go new file mode 100644 index 0000000..c617218 --- /dev/null +++ b/omaha/package_test.go @@ -0,0 +1,144 @@ +// Copyright 2015 CoreOS, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package omaha + +import ( + "strings" + "testing" + + "github.com/kylelemons/godebug/pretty" +) + +func TestPackageFromPath(t *testing.T) { + expect := Package{ + Name: "null", + Sha1: "2jmj7l5rSw0yVb/vlWAYkK/YBwk=", + Sha256: "47DEQpj8HBSa+/TImW+5JCeuQeRkm5NMpJWZG3hSuFU=", + Size: 0, + Required: false, + } + + p := Package{} + if err := p.FromPath("/dev/null"); err != nil { + t.Fatal(err) + } + + if diff := pretty.Compare(expect, p); diff != "" { + t.Errorf("Hashing /dev/null failed: %v", diff) + } +} + +func TestProtocolFromReader(t *testing.T) { + data := strings.NewReader("testing\n") + expect := Package{ + Name: "", + Sha1: "mAFznarkTsUpPU4fU9P00tQm2Rw=", + Sha256: "EqYfThc/s6EcBdZHH3Ryj3YjG0pfzZZnzvOvh6OuTcI=", + Size: 8, + Required: false, + } + + p := Package{} + if err := p.FromReader(data); err != nil { + t.Fatal(err) + } + + if diff := pretty.Compare(expect, p); diff != "" { + t.Errorf("Hashing failed: %v", diff) + } +} + +func TestPackageVerify(t *testing.T) { + p := Package{ + Name: "null", + Sha1: "2jmj7l5rSw0yVb/vlWAYkK/YBwk=", + Sha256: "47DEQpj8HBSa+/TImW+5JCeuQeRkm5NMpJWZG3hSuFU=", + Size: 0, + Required: false, + } + + if err := p.Verify("/dev"); err != nil { + t.Fatal(err) + } +} + +func TestPackageVerifyNoSha256(t *testing.T) { + p := Package{ + Name: "null", + Sha1: "2jmj7l5rSw0yVb/vlWAYkK/YBwk=", + Sha256: "", + Size: 0, + Required: false, + } + + if err := p.Verify("/dev"); err != nil { + t.Fatal(err) + } +} + +func TestPackageVerifyBadSize(t *testing.T) { + p := Package{ + Name: "null", + Sha1: "2jmj7l5rSw0yVb/vlWAYkK/YBwk=", + Sha256: "47DEQpj8HBSa+/TImW+5JCeuQeRkm5NMpJWZG3hSuFU=", + Size: 1, + Required: false, + } + + err := p.Verify("/dev") + if err == nil { + t.Error("verify passed") + } + if err != PackageSizeMismatchError { + t.Error(err) + } + +} + +func TestPackageVerifyBadSha1(t *testing.T) { + p := Package{ + Name: "null", + Sha1: "xxxxxxxxxxxxxxxxxxxxxxxxxxx=", + Sha256: "47DEQpj8HBSa+/TImW+5JCeuQeRkm5NMpJWZG3hSuFU=", + Size: 0, + Required: false, + } + + err := p.Verify("/dev") + if err == nil { + t.Error("verify passed") + } + if err != PackageHashMismatchError { + t.Error(err) + } +} + +func TestPackageVerifyBadSha256(t *testing.T) { + p := Package{ + Name: "null", + Sha1: "2jmj7l5rSw0yVb/vlWAYkK/YBwk=", + Sha256: "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx=", + Size: 0, + Required: false, + } + + err := p.Verify("/dev") + if err == nil { + t.Error("verify passed") + } + if err != PackageHashMismatchError { + t.Error(err) + } +} diff --git a/omaha/protocol.go b/omaha/protocol.go index 3abe808..b7b152e 100644 --- a/omaha/protocol.go +++ b/omaha/protocol.go @@ -224,19 +224,21 @@ type Manifest struct { Version string `xml:"version,attr"` } -type Package struct { - Hash string `xml:"hash,attr"` - Name string `xml:"name,attr"` - Size uint64 `xml:"size,attr"` - Required bool `xml:"required,attr"` -} - func (m *Manifest) AddPackage() *Package { p := &Package{} m.Packages = append(m.Packages, p) return p } +func (m *Manifest) AddPackageFromPath(path string) (*Package, error) { + p := &Package{} + if err := p.FromPath(path); err != nil { + return nil, err + } + m.Packages = append(m.Packages, p) + return p, nil +} + func (m *Manifest) AddAction(event string) *Action { a := &Action{Event: event} m.Actions = append(m.Actions, a) diff --git a/omaha/protocol_test.go b/omaha/protocol_test.go index 6fb106f..54b6b0f 100644 --- a/omaha/protocol_test.go +++ b/omaha/protocol_test.go @@ -88,7 +88,7 @@ func ExampleNewResponse() { u.AddURL("http://localhost/updates") m := u.AddManifest("9999.0.0") k := m.AddPackage() - k.Hash = "+LXvjiaPkeYDLHoNKlf9qbJwvnk=" + k.Sha1 = "+LXvjiaPkeYDLHoNKlf9qbJwvnk=" k.Name = "update.gz" k.Size = 67546213 k.Required = true @@ -118,7 +118,7 @@ func ExampleNewResponse() { // // // - // + // // // //