From a6290c1b4fe3c077e15afa37184b653724898fad Mon Sep 17 00:00:00 2001 From: Michael Marineau Date: Thu, 4 May 2017 13:00:46 -0700 Subject: [PATCH] protocol: add ParseRequest and ParseResponse functions For parsing and verification of of HTTP request and response bodies, including optional checking the Content-Type field which the handler previously didn't do. --- omaha/handler.go | 19 +++------ omaha/parse.go | 71 +++++++++++++++++++++++++++++++++ omaha/parse_test.go | 55 ++++++++++++++++++++++++++ omaha/protocol.go | 33 ++++++++++++++++ omaha/protocol_test.go | 90 ++++++++++++++++++++++++++++++++++++++++-- 5 files changed, 252 insertions(+), 16 deletions(-) create mode 100644 omaha/parse.go create mode 100644 omaha/parse_test.go diff --git a/omaha/handler.go b/omaha/handler.go index 42898f9..2754ce9 100644 --- a/omaha/handler.go +++ b/omaha/handler.go @@ -33,25 +33,18 @@ func (o *OmahaHandler) ServeHTTP(w http.ResponseWriter, httpReq *http.Request) { // A request over 1M in size is certainly bogus. reader := http.MaxBytesReader(w, httpReq.Body, 1024*1024) - - decoder := xml.NewDecoder(reader) - var omahaReq Request - if err := decoder.Decode(&omahaReq); err != nil { - log.Printf("omaha: Failed decoding XML: %v", err) - http.Error(w, "Invalid XML", http.StatusBadRequest) - return - } - - if omahaReq.Protocol != "3.0" { - log.Printf("omaha: Unexpected protocol: %q", omahaReq.Protocol) - http.Error(w, "Omaha 3.0 Required", http.StatusBadRequest) + contentType := httpReq.Header.Get("Content-Type") + omahaReq, err := ParseRequest(contentType, reader) + if err != nil { + log.Printf("omaha: Failed parsing request: %v", err) + http.Error(w, "Bad Omaha Request", http.StatusBadRequest) return } httpStatus := 0 omahaResp := NewResponse() for _, appReq := range omahaReq.Apps { - appResp := o.serveApp(omahaResp, httpReq, &omahaReq, appReq) + appResp := o.serveApp(omahaResp, httpReq, omahaReq, appReq) if appResp.Status == AppOK { // HTTP is ok if any app is ok. httpStatus = http.StatusOK diff --git a/omaha/parse.go b/omaha/parse.go new file mode 100644 index 0000000..91ec45b --- /dev/null +++ b/omaha/parse.go @@ -0,0 +1,71 @@ +// Copyright 2017 CoreOS, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package omaha + +import ( + "encoding/xml" + "fmt" + "io" + "mime" + "strings" +) + +// checkContentType verifies the HTTP Content-Type header properly +// declares the document is XML and UTF-8. Blank is assumed OK. +func checkContentType(contentType string) error { + if contentType == "" { + return nil + } + + mType, mParams, err := mime.ParseMediaType(contentType) + if err != nil { + return err + } + + if mType != "text/xml" && mType != "application/xml" { + return fmt.Errorf("unsupported content type %q", mType) + } + + charset, _ := mParams["charset"] + if charset != "" && strings.ToLower(charset) != "utf-8" { + return fmt.Errorf("unsupported content charset %q", charset) + } + + return nil +} + +// parseReqOrResp parses Request and Response objects. +func parseReqOrResp(r io.Reader, v interface{}) error { + decoder := xml.NewDecoder(r) + if err := decoder.Decode(v); err != nil { + return err + } + + var protocol string + switch v := v.(type) { + case *Request: + protocol = v.Protocol + case *Response: + protocol = v.Protocol + default: + panic(fmt.Errorf("unexpected type %T", v)) + } + + if protocol != "3.0" { + return fmt.Errorf("unsupported omaha protocol: %q", protocol) + } + + return nil +} diff --git a/omaha/parse_test.go b/omaha/parse_test.go new file mode 100644 index 0000000..835cbad --- /dev/null +++ b/omaha/parse_test.go @@ -0,0 +1,55 @@ +// Copyright 2017 CoreOS, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package omaha + +import ( + "strings" + "testing" +) + +func TestCheckContentType(t *testing.T) { + for _, tt := range []struct { + ct string + ok bool + }{ + {"", true}, + {"text/xml", true}, + {"text/XML", true}, + {"application/xml", true}, + {"text/plain", false}, + {"xml", false}, + {"text/xml; charset=utf-8", true}, + {"text/xml; charset=UTF-8", true}, + {"text/xml; charset=ascii", false}, + } { + err := checkContentType(tt.ct) + if tt.ok && err != nil { + t.Errorf("%q failed: %v", tt.ct, err) + } + if !tt.ok && err == nil { + t.Errorf("%q was not rejected", tt.ct) + } + } +} + +func TestParseBadVersion(t *testing.T) { + r := strings.NewReader(``) + err := parseReqOrResp(r, &Request{}) + if err == nil { + t.Error("Bad protocol version was accepted") + } else if err.Error() != `unsupported omaha protocol: "2.0"` { + t.Errorf("Wrong error: %v", err) + } +} diff --git a/omaha/protocol.go b/omaha/protocol.go index 4cc7922..90daf07 100644 --- a/omaha/protocol.go +++ b/omaha/protocol.go @@ -24,6 +24,7 @@ package omaha import ( "encoding/xml" + "io" ) // Request sent by the Omaha client @@ -56,6 +57,22 @@ func NewRequest() *Request { } } +// ParseRequest verifies and returns the parsed Request document. +// The MIME Content-Type header may be provided to sanity check its +// value; if blank it is assumed to be XML in UTF-8. +func ParseRequest(contentType string, body io.Reader) (*Request, error) { + if err := checkContentType(contentType); err != nil { + return nil, err + } + + r := &Request{} + if err := parseReqOrResp(body, r); err != nil { + return nil, err + } + + return r, nil +} + func (r *Request) AddApp(id, version string) *AppRequest { a := &AppRequest{ID: id, Version: version} r.Apps = append(r.Apps, a) @@ -138,6 +155,22 @@ func NewResponse() *Response { } } +// ParseResponse verifies and returns the parsed Response document. +// The MIME Content-Type header may be provided to sanity check its +// value; if blank it is assumed to be XML in UTF-8. +func ParseResponse(contentType string, body io.Reader) (*Response, error) { + if err := checkContentType(contentType); err != nil { + return nil, err + } + + r := &Response{} + if err := parseReqOrResp(body, r); err != nil { + return nil, err + } + + return r, nil +} + type DayStart struct { ElapsedSeconds string `xml:"elapsed_seconds,attr"` } diff --git a/omaha/protocol_test.go b/omaha/protocol_test.go index 2dd3dbc..378e9d9 100644 --- a/omaha/protocol_test.go +++ b/omaha/protocol_test.go @@ -17,10 +17,13 @@ package omaha import ( "encoding/xml" "fmt" + "reflect" + "strings" "testing" ) -const SampleRequest = ` +const ( + sampleRequest = ` @@ -30,10 +33,34 @@ const SampleRequest = ` ` + sampleResponse = ` + + + + + + + + + + + + + + + + + + + +` +) func TestOmahaRequestUpdateCheck(t *testing.T) { - v := Request{} - xml.Unmarshal([]byte(SampleRequest), &v) + v, err := ParseRequest("", strings.NewReader(sampleRequest)) + if err != nil { + t.Fatalf("ParseRequest failed: %v", err) + } if v.OS.Version != "Indy" { t.Error("Unexpected version", v.OS.Version) @@ -80,6 +107,63 @@ func TestOmahaRequestUpdateCheck(t *testing.T) { } } +func TestOmahaResponseWithUpdate(t *testing.T) { + parsed, err := ParseResponse("", strings.NewReader(sampleResponse)) + if err != nil { + t.Fatalf("ParseResponse failed: %v", err) + } + + expected := &Response{ + XMLName: xml.Name{Local: "response"}, + Protocol: "3.0", + DayStart: DayStart{ElapsedSeconds: "49008"}, + Apps: []*AppResponse{&AppResponse{ + ID: "{87efface-864d-49a5-9bb3-4b050a7c227a}", + Status: AppOK, + Ping: &PingResponse{"ok"}, + UpdateCheck: &UpdateResponse{ + Status: UpdateOK, + URLs: []*URL{&URL{ + CodeBase: "http://kam:8080/static/", + }}, + Manifest: &Manifest{ + Version: "9999.0.0", + Packages: []*Package{&Package{ + SHA1: "+LXvjiaPkeYDLHoNKlf9qbJwvnk=", + Name: "update.gz", + Size: 67546213, + Required: true, + }}, + Actions: []*Action{&Action{ + Event: "postinstall", + DisplayVersion: "9999.0.0", + SHA256: "0VAlQW3RE99SGtSB5R4m08antAHO8XDoBMKDyxQT/Mg=", + IsDeltaPayload: true, + }}, + }, + }, + }}, + } + + if !reflect.DeepEqual(parsed, expected) { + t.Errorf("parsed != expected\n%s\n%s", parsed, expected) + } +} + +func TestOmahaResponsAsRequest(t *testing.T) { + _, err := ParseRequest("", strings.NewReader(sampleResponse)) + if err == nil { + t.Fatal("ParseRequest successfully parsed a response") + } +} + +func TestOmahaRequestAsResponse(t *testing.T) { + _, err := ParseResponse("", strings.NewReader(sampleRequest)) + if err == nil { + t.Fatal("ParseResponse successfully parsed a request") + } +} + func ExampleNewResponse() { response := NewResponse() app := response.AddApp("{52F1B9BC-D31A-4D86-9276-CBC256AADF9A}", "ok")