diff --git a/omaha/handler.go b/omaha/handler.go
index 42898f9..2754ce9 100644
--- a/omaha/handler.go
+++ b/omaha/handler.go
@@ -33,25 +33,18 @@ func (o *OmahaHandler) ServeHTTP(w http.ResponseWriter, httpReq *http.Request) {
// A request over 1M in size is certainly bogus.
reader := http.MaxBytesReader(w, httpReq.Body, 1024*1024)
-
- decoder := xml.NewDecoder(reader)
- var omahaReq Request
- if err := decoder.Decode(&omahaReq); err != nil {
- log.Printf("omaha: Failed decoding XML: %v", err)
- http.Error(w, "Invalid XML", http.StatusBadRequest)
- return
- }
-
- if omahaReq.Protocol != "3.0" {
- log.Printf("omaha: Unexpected protocol: %q", omahaReq.Protocol)
- http.Error(w, "Omaha 3.0 Required", http.StatusBadRequest)
+ contentType := httpReq.Header.Get("Content-Type")
+ omahaReq, err := ParseRequest(contentType, reader)
+ if err != nil {
+ log.Printf("omaha: Failed parsing request: %v", err)
+ http.Error(w, "Bad Omaha Request", http.StatusBadRequest)
return
}
httpStatus := 0
omahaResp := NewResponse()
for _, appReq := range omahaReq.Apps {
- appResp := o.serveApp(omahaResp, httpReq, &omahaReq, appReq)
+ appResp := o.serveApp(omahaResp, httpReq, omahaReq, appReq)
if appResp.Status == AppOK {
// HTTP is ok if any app is ok.
httpStatus = http.StatusOK
diff --git a/omaha/parse.go b/omaha/parse.go
new file mode 100644
index 0000000..91ec45b
--- /dev/null
+++ b/omaha/parse.go
@@ -0,0 +1,71 @@
+// Copyright 2017 CoreOS, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package omaha
+
+import (
+ "encoding/xml"
+ "fmt"
+ "io"
+ "mime"
+ "strings"
+)
+
+// checkContentType verifies the HTTP Content-Type header properly
+// declares the document is XML and UTF-8. Blank is assumed OK.
+func checkContentType(contentType string) error {
+ if contentType == "" {
+ return nil
+ }
+
+ mType, mParams, err := mime.ParseMediaType(contentType)
+ if err != nil {
+ return err
+ }
+
+ if mType != "text/xml" && mType != "application/xml" {
+ return fmt.Errorf("unsupported content type %q", mType)
+ }
+
+ charset, _ := mParams["charset"]
+ if charset != "" && strings.ToLower(charset) != "utf-8" {
+ return fmt.Errorf("unsupported content charset %q", charset)
+ }
+
+ return nil
+}
+
+// parseReqOrResp parses Request and Response objects.
+func parseReqOrResp(r io.Reader, v interface{}) error {
+ decoder := xml.NewDecoder(r)
+ if err := decoder.Decode(v); err != nil {
+ return err
+ }
+
+ var protocol string
+ switch v := v.(type) {
+ case *Request:
+ protocol = v.Protocol
+ case *Response:
+ protocol = v.Protocol
+ default:
+ panic(fmt.Errorf("unexpected type %T", v))
+ }
+
+ if protocol != "3.0" {
+ return fmt.Errorf("unsupported omaha protocol: %q", protocol)
+ }
+
+ return nil
+}
diff --git a/omaha/parse_test.go b/omaha/parse_test.go
new file mode 100644
index 0000000..835cbad
--- /dev/null
+++ b/omaha/parse_test.go
@@ -0,0 +1,55 @@
+// Copyright 2017 CoreOS, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package omaha
+
+import (
+ "strings"
+ "testing"
+)
+
+func TestCheckContentType(t *testing.T) {
+ for _, tt := range []struct {
+ ct string
+ ok bool
+ }{
+ {"", true},
+ {"text/xml", true},
+ {"text/XML", true},
+ {"application/xml", true},
+ {"text/plain", false},
+ {"xml", false},
+ {"text/xml; charset=utf-8", true},
+ {"text/xml; charset=UTF-8", true},
+ {"text/xml; charset=ascii", false},
+ } {
+ err := checkContentType(tt.ct)
+ if tt.ok && err != nil {
+ t.Errorf("%q failed: %v", tt.ct, err)
+ }
+ if !tt.ok && err == nil {
+ t.Errorf("%q was not rejected", tt.ct)
+ }
+ }
+}
+
+func TestParseBadVersion(t *testing.T) {
+ r := strings.NewReader(``)
+ err := parseReqOrResp(r, &Request{})
+ if err == nil {
+ t.Error("Bad protocol version was accepted")
+ } else if err.Error() != `unsupported omaha protocol: "2.0"` {
+ t.Errorf("Wrong error: %v", err)
+ }
+}
diff --git a/omaha/protocol.go b/omaha/protocol.go
index 4cc7922..90daf07 100644
--- a/omaha/protocol.go
+++ b/omaha/protocol.go
@@ -24,6 +24,7 @@ package omaha
import (
"encoding/xml"
+ "io"
)
// Request sent by the Omaha client
@@ -56,6 +57,22 @@ func NewRequest() *Request {
}
}
+// ParseRequest verifies and returns the parsed Request document.
+// The MIME Content-Type header may be provided to sanity check its
+// value; if blank it is assumed to be XML in UTF-8.
+func ParseRequest(contentType string, body io.Reader) (*Request, error) {
+ if err := checkContentType(contentType); err != nil {
+ return nil, err
+ }
+
+ r := &Request{}
+ if err := parseReqOrResp(body, r); err != nil {
+ return nil, err
+ }
+
+ return r, nil
+}
+
func (r *Request) AddApp(id, version string) *AppRequest {
a := &AppRequest{ID: id, Version: version}
r.Apps = append(r.Apps, a)
@@ -138,6 +155,22 @@ func NewResponse() *Response {
}
}
+// ParseResponse verifies and returns the parsed Response document.
+// The MIME Content-Type header may be provided to sanity check its
+// value; if blank it is assumed to be XML in UTF-8.
+func ParseResponse(contentType string, body io.Reader) (*Response, error) {
+ if err := checkContentType(contentType); err != nil {
+ return nil, err
+ }
+
+ r := &Response{}
+ if err := parseReqOrResp(body, r); err != nil {
+ return nil, err
+ }
+
+ return r, nil
+}
+
type DayStart struct {
ElapsedSeconds string `xml:"elapsed_seconds,attr"`
}
diff --git a/omaha/protocol_test.go b/omaha/protocol_test.go
index 2dd3dbc..378e9d9 100644
--- a/omaha/protocol_test.go
+++ b/omaha/protocol_test.go
@@ -17,10 +17,13 @@ package omaha
import (
"encoding/xml"
"fmt"
+ "reflect"
+ "strings"
"testing"
)
-const SampleRequest = `
+const (
+ sampleRequest = `
@@ -30,10 +33,34 @@ const SampleRequest = `
`
+ sampleResponse = `
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+`
+)
func TestOmahaRequestUpdateCheck(t *testing.T) {
- v := Request{}
- xml.Unmarshal([]byte(SampleRequest), &v)
+ v, err := ParseRequest("", strings.NewReader(sampleRequest))
+ if err != nil {
+ t.Fatalf("ParseRequest failed: %v", err)
+ }
if v.OS.Version != "Indy" {
t.Error("Unexpected version", v.OS.Version)
@@ -80,6 +107,63 @@ func TestOmahaRequestUpdateCheck(t *testing.T) {
}
}
+func TestOmahaResponseWithUpdate(t *testing.T) {
+ parsed, err := ParseResponse("", strings.NewReader(sampleResponse))
+ if err != nil {
+ t.Fatalf("ParseResponse failed: %v", err)
+ }
+
+ expected := &Response{
+ XMLName: xml.Name{Local: "response"},
+ Protocol: "3.0",
+ DayStart: DayStart{ElapsedSeconds: "49008"},
+ Apps: []*AppResponse{&AppResponse{
+ ID: "{87efface-864d-49a5-9bb3-4b050a7c227a}",
+ Status: AppOK,
+ Ping: &PingResponse{"ok"},
+ UpdateCheck: &UpdateResponse{
+ Status: UpdateOK,
+ URLs: []*URL{&URL{
+ CodeBase: "http://kam:8080/static/",
+ }},
+ Manifest: &Manifest{
+ Version: "9999.0.0",
+ Packages: []*Package{&Package{
+ SHA1: "+LXvjiaPkeYDLHoNKlf9qbJwvnk=",
+ Name: "update.gz",
+ Size: 67546213,
+ Required: true,
+ }},
+ Actions: []*Action{&Action{
+ Event: "postinstall",
+ DisplayVersion: "9999.0.0",
+ SHA256: "0VAlQW3RE99SGtSB5R4m08antAHO8XDoBMKDyxQT/Mg=",
+ IsDeltaPayload: true,
+ }},
+ },
+ },
+ }},
+ }
+
+ if !reflect.DeepEqual(parsed, expected) {
+ t.Errorf("parsed != expected\n%s\n%s", parsed, expected)
+ }
+}
+
+func TestOmahaResponsAsRequest(t *testing.T) {
+ _, err := ParseRequest("", strings.NewReader(sampleResponse))
+ if err == nil {
+ t.Fatal("ParseRequest successfully parsed a response")
+ }
+}
+
+func TestOmahaRequestAsResponse(t *testing.T) {
+ _, err := ParseResponse("", strings.NewReader(sampleRequest))
+ if err == nil {
+ t.Fatal("ParseResponse successfully parsed a request")
+ }
+}
+
func ExampleNewResponse() {
response := NewResponse()
app := response.AddApp("{52F1B9BC-D31A-4D86-9276-CBC256AADF9A}", "ok")