diff --git a/omaha/client/client.go b/omaha/client/client.go index 72ee957..3be8c44 100644 --- a/omaha/client/client.go +++ b/omaha/client/client.go @@ -21,10 +21,13 @@ import ( "net/url" "github.com/satori/go.uuid" + + "github.com/coreos/go-omaha/omaha" ) // Client supports managing multiple apps using a single server. type Client struct { + apiClient *httpClient apiEndpoint string clientVersion string userID string @@ -49,6 +52,7 @@ func New(serverURL, userID string) (*Client, error) { } c := &Client{ + apiClient: newHTTPClient(), clientVersion: "go-omaha", userID: userID, sessionID: uuid.NewV4().String(), @@ -152,3 +156,119 @@ func (ac *AppClient) SetTrack(track string) error { ac.track = track return nil } + +func (ac *AppClient) UpdateCheck() (*omaha.UpdateResponse, error) { + req := ac.newReq() + app := req.Apps[0] + app.AddPing() + app.AddUpdateCheck() + + appResp, err := ac.doReq(ac.apiEndpoint, req) + if err != nil { + return nil, err + } + + if appResp.Ping == nil { + return nil, fmt.Errorf("omaha: ping status missing from response") + } + + if appResp.Ping.Status != "ok" { + return nil, fmt.Errorf("omaha: ping status %s", appResp.Ping.Status) + } + + if appResp.UpdateCheck == nil { + return nil, fmt.Errorf("omaha: update check missing from response") + } + + if appResp.UpdateCheck.Status != omaha.UpdateOK { + return nil, appResp.UpdateCheck.Status + } + + return appResp.UpdateCheck, nil +} + +func (ac *AppClient) Ping() error { + req := ac.newReq() + app := req.Apps[0] + app.AddPing() + + appResp, err := ac.doReq(ac.apiEndpoint, req) + if err != nil { + return err + } + + if appResp.Ping == nil { + return fmt.Errorf("omaha: ping status missing from response") + } + + if appResp.Ping.Status != "ok" { + return fmt.Errorf("omaha: ping status %s", appResp.Ping.Status) + } + + return nil +} + +func (ac *AppClient) Event(event *omaha.EventRequest) error { + req := ac.newReq() + app := req.Apps[0] + app.Events = append(app.Events, event) + + appResp, err := ac.doReq(ac.apiEndpoint, req) + if err != nil { + return err + } + + if len(appResp.Events) == 0 { + return fmt.Errorf("omaha: event status missing from response") + } + + if appResp.Events[0].Status != "ok" { + return fmt.Errorf("omaha: event status %s", appResp.Events[0].Status) + } + + return nil +} + +func (ac *AppClient) newReq() *omaha.Request { + req := omaha.NewRequest() + req.Version = ac.clientVersion + req.UserID = ac.userID + req.SessionID = ac.sessionID + if ac.isMachine { + req.IsMachine = 1 + } + + app := req.AddApp(ac.appID, ac.version) + app.Track = ac.track + + // MachineID and BootID are non-standard fields used by CoreOS' + // update_engine and Core Update. Copy their values from the + // standard UserID and SessionID. Eventually the non-standard + // fields should be deprecated. + app.MachineID = req.UserID + app.BootID = req.SessionID + + return req +} + +func (ac *AppClient) doReq(url string, req *omaha.Request) (*omaha.AppResponse, error) { + if len(req.Apps) != 1 { + panic(fmt.Errorf("unexpected number of apps: %d", len(req.Apps))) + } + appID := req.Apps[0].ID + resp, err := ac.apiClient.Omaha(url, req) + if err != nil { + return nil, err + } + + appResp := resp.GetApp(appID) + if appResp == nil { + return nil, fmt.Errorf("omaha: app %s missing from response", appID) + } + + if appResp.Status != omaha.AppOK { + return nil, appResp.Status + } + + return appResp, nil +} diff --git a/omaha/client/client_test.go b/omaha/client/client_test.go new file mode 100644 index 0000000..05ee0d0 --- /dev/null +++ b/omaha/client/client_test.go @@ -0,0 +1,181 @@ +// Copyright 2017 CoreOS, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package client + +import ( + "reflect" + "testing" + + "github.com/coreos/go-omaha/omaha" +) + +// implements omaha.Updater +type recorder struct { + t *testing.T + update *omaha.Update + checks []*omaha.UpdateRequest + events []*omaha.EventRequest + pings []*omaha.PingRequest +} + +func newRecordingServer(t *testing.T, u *omaha.Update) (*recorder, *omaha.Server) { + r := &recorder{t: t, update: u} + s, err := omaha.NewServer("127.0.0.1:0", r) + if err != nil { + t.Fatal(err) + } + go s.Serve() + return r, s +} + +func (r *recorder) CheckApp(req *omaha.Request, app *omaha.AppRequest) error { + // CheckApp is meant for checking if app.ID is valid but we don't + // care and accept any ID. Instead this is just a convenient place + // to check that all requests are well formed. + if len(req.SessionID) != 36 { + r.t.Errorf("SessionID %q is not a UUID", req.SessionID) + } + if app.BootID != req.SessionID { + r.t.Errorf("BootID %q != SessionID %q", app.BootID, req.SessionID) + } + if req.UserID == "" { + r.t.Error("UserID is blank") + } + if app.MachineID != req.UserID { + r.t.Errorf("MachineID %q != UserID %q", app.MachineID, req.UserID) + } + if app.Version == "" { + r.t.Error("App Version is blank") + } + return nil +} + +func (r *recorder) CheckUpdate(req *omaha.Request, app *omaha.AppRequest) (*omaha.Update, error) { + r.checks = append(r.checks, app.UpdateCheck) + if r.update == nil { + return nil, omaha.NoUpdate + } else { + return r.update, nil + } +} + +func (r *recorder) Event(req *omaha.Request, app *omaha.AppRequest, event *omaha.EventRequest) { + r.events = append(r.events, event) +} + +func (r *recorder) Ping(req *omaha.Request, app *omaha.AppRequest) { + r.pings = append(r.pings, app.Ping) +} + +func TestClientNoUpdate(t *testing.T) { + r, s := newRecordingServer(t, nil) + defer s.Destroy() + + url := "http://" + s.Addr().String() + ac, err := NewAppClient(url, "client-id", "app-id", "0.0.0") + if err != nil { + t.Fatal(err) + } + + if _, err := ac.UpdateCheck(); err != omaha.NoUpdate { + t.Fatalf("UpdateCheck id not return NoUpdate: %v", err) + } + + if len(r.pings) != 1 { + t.Fatalf("expected 1 ping, not %d", len(r.pings)) + } + + if len(r.checks) != 1 { + t.Fatalf("expected 1 update check, not %d", len(r.checks)) + } +} + +func TestClientWithUpdate(t *testing.T) { + r, s := newRecordingServer(t, &omaha.Update{ + Manifest: omaha.Manifest{ + Version: "1.1.1", + }, + }) + defer s.Destroy() + + url := "http://" + s.Addr().String() + ac, err := NewAppClient(url, "client-id", "app-id", "0.0.0") + if err != nil { + t.Fatal(err) + } + + update, err := ac.UpdateCheck() + if err != nil { + t.Fatal(err) + } + + if update.Manifest.Version != "1.1.1" { + t.Fatalf("expected version 1.1.1, not %s", update.Manifest.Version) + } + + if len(r.pings) != 1 { + t.Fatalf("expected 1 ping, not %d", len(r.pings)) + } + + if len(r.checks) != 1 { + t.Fatalf("expected 1 update check, not %d", len(r.checks)) + } +} + +func TestClientPing(t *testing.T) { + r, s := newRecordingServer(t, nil) + defer s.Destroy() + + url := "http://" + s.Addr().String() + ac, err := NewAppClient(url, "client-id", "app-id", "0.0.0") + if err != nil { + t.Fatal(err) + } + + if err := ac.Ping(); err != nil { + t.Fatal(err) + } + + if len(r.pings) != 1 { + t.Fatalf("expected 1 ping, not %d", len(r.pings)) + } +} + +func TestClientEvent(t *testing.T) { + r, s := newRecordingServer(t, nil) + defer s.Destroy() + + url := "http://" + s.Addr().String() + ac, err := NewAppClient(url, "client-id", "app-id", "0.0.0") + if err != nil { + t.Fatal(err) + } + + event := &omaha.EventRequest{ + Type: omaha.EventTypeDownloadComplete, + Result: omaha.EventResultSuccess, + } + if err := ac.Event(event); err != nil { + t.Fatal(err) + } + + if len(r.events) != 1 { + t.Fatalf("expected 1 event, not %d", len(r.events)) + } + + if !reflect.DeepEqual(event, r.events[0]) { + t.Fatalf("sent != received:\n%#v\n%#v", event, r.events[0]) + } +}