diff --git a/omaha/isclosed.go b/omaha/isclosed.go new file mode 100644 index 0000000..5b00698 --- /dev/null +++ b/omaha/isclosed.go @@ -0,0 +1,32 @@ +// Copyright 2016 CoreOS, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package omaha + +import ( + "net" +) + +// isClosed detects if an error is due to a closed network connection, +// working around bug https://github.com/golang/go/issues/4373 +func isClosed(err error) bool { + if err == nil { + return false + } + if operr, ok := err.(*net.OpError); ok { + err = operr.Err + } + // cry softly + return err.Error() == "use of closed network connection" +} diff --git a/omaha/server.go b/omaha/server.go new file mode 100644 index 0000000..dcf025a --- /dev/null +++ b/omaha/server.go @@ -0,0 +1,73 @@ +// Copyright 2015 CoreOS, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package omaha + +import ( + "net" + "net/http" +) + +func NewServer(addr string, updater Updater) (*Server, error) { + l, err := net.Listen("tcp", addr) + if err != nil { + return nil, err + } + + mux := http.NewServeMux() + + srv := &http.Server{ + Addr: addr, + Handler: mux, + } + + s := &Server{ + Updater: updater, + Mux: mux, + l: l, + srv: srv, + } + + h := &OmahaHandler{s} + mux.Handle("/v1/update", h) + mux.Handle("/v1/update/", h) + + return s, nil +} + +type Server struct { + Updater + + Mux *http.ServeMux + + l net.Listener + srv *http.Server +} + +func (s *Server) Serve() error { + err := s.srv.Serve(s.l) + if isClosed(err) { + // gracefully quit + err = nil + } + return nil +} + +func (s *Server) Destroy() error { + return s.l.Close() +} + +func (s *Server) Addr() net.Addr { + return s.l.Addr() +} diff --git a/omaha/server_test.go b/omaha/server_test.go new file mode 100644 index 0000000..d5806bc --- /dev/null +++ b/omaha/server_test.go @@ -0,0 +1,114 @@ +// Copyright 2015 CoreOS, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package omaha + +import ( + "bytes" + "encoding/xml" + "fmt" + "net/http" + "sync" + "testing" + "time" +) + +type mockServer struct { + UpdaterStub + + reqChan chan *Request +} + +func (m *mockServer) CheckApp(req *Request, app *AppRequest) error { + m.reqChan <- req + return nil +} + +func TestServerRequestResponse(t *testing.T) { + var wg sync.WaitGroup + defer wg.Wait() + + // make an omaha server + svc := &mockServer{ + reqChan: make(chan *Request), + } + + s, err := NewServer("127.0.0.1:0", svc) + if err != nil { + t.Fatalf("failed to create omaha server: %v", err) + } + defer func() { + err := s.Destroy() + if err != nil { + t.Error(err) + } + close(svc.reqChan) + }() + + wg.Add(1) + go func() { + defer wg.Done() + if err := s.Serve(); err != nil { + t.Errorf("Serve failed: %v", err) + } + }() + + buf := new(bytes.Buffer) + enc := xml.NewEncoder(buf) + enc.Indent("", "\t") + err = enc.Encode(nilRequest) + if err != nil { + t.Fatalf("failed to marshal request: %v", err) + } + + // check that server gets the same thing we sent + wg.Add(1) + go func() { + defer wg.Done() + sreq, ok := <-svc.reqChan + if !ok { + t.Errorf("failed to get notification from server") + return + } + + if err := compareXML(nilRequest, sreq); err != nil { + t.Error(err) + } + }() + + // send omaha request + endpoint := fmt.Sprintf("http://%s/v1/update/", s.Addr()) + httpClient := &http.Client{ + Timeout: 2 * time.Second, + } + res, err := httpClient.Post(endpoint, "text/xml", buf) + if err != nil { + t.Fatalf("failed to post: %v", err) + } + + defer res.Body.Close() + + if res.StatusCode != 200 { + t.Fatalf("failed to post: %v", res.Status) + } + + dec := xml.NewDecoder(res.Body) + sresp := &Response{} + if err := dec.Decode(sresp); err != nil { + t.Fatalf("failed to parse body: %v", err) + } + if err := compareXML(nilResponse, sresp); err != nil { + t.Error(err) + } +}