diff --git a/cmd/client/main.go b/cmd/client/main.go index f129984b..6d123ebe 100644 --- a/cmd/client/main.go +++ b/cmd/client/main.go @@ -3,7 +3,9 @@ package main import ( "fmt" "log" + "net" "os" + "time" pb "github.com/kubernetes/kubernetes/pkg/kubelet/api/v1alpha1/runtime" "github.com/urfave/cli" @@ -12,9 +14,22 @@ import ( ) const ( - address = "localhost:49999" + unixDomainSocket = "/var/run/ocid.sock" + // TODO: Make configurable + timeout = 10 * time.Second ) +func getClientConnection() (*grpc.ClientConn, error) { + conn, err := grpc.Dial(unixDomainSocket, grpc.WithInsecure(), grpc.WithTimeout(timeout), + grpc.WithDialer(func(addr string, timeout time.Duration) (net.Conn, error) { + return net.DialTimeout("unix", addr, timeout) + })) + if err != nil { + return nil, fmt.Errorf("Failed to connect: %v", err) + } + return conn, nil +} + // Version sends a VersionRequest to the server, and parses the returned VersionResponse. func Version(client pb.RuntimeServiceClient, version string) error { r, err := client.Version(context.Background(), &pb.VersionRequest{Version: &version}) @@ -54,7 +69,7 @@ var pullImageCommand = cli.Command{ Usage: "pull an image", Action: func(context *cli.Context) error { // Set up a connection to the server. - conn, err := grpc.Dial(address, grpc.WithInsecure()) + conn, err := getClientConnection() if err != nil { return fmt.Errorf("Failed to connect: %v", err) } @@ -74,7 +89,7 @@ var runtimeVersionCommand = cli.Command{ Usage: "get runtime version information", Action: func(context *cli.Context) error { // Set up a connection to the server. - conn, err := grpc.Dial(address, grpc.WithInsecure()) + conn, err := getClientConnection() if err != nil { return fmt.Errorf("Failed to connect: %v", err) } diff --git a/cmd/server/main.go b/cmd/server/main.go index 961742df..bfe35abc 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -3,6 +3,7 @@ package main import ( "log" "net" + "os" "github.com/kubernetes/kubernetes/pkg/kubelet/api/v1alpha1/runtime" "github.com/mrunalp/ocid/server" @@ -10,11 +11,17 @@ import ( ) const ( - port = ":49999" + unixDomainSocket = "/var/run/ocid.sock" ) func main() { - lis, err := net.Listen("tcp", port) + // Remove the socket if it already exists + if _, err := os.Stat(unixDomainSocket); err == nil { + if err := os.Remove(unixDomainSocket); err != nil { + log.Fatal(err) + } + } + lis, err := net.Listen("unix", unixDomainSocket) if err != nil { log.Fatalf("failed to listen: %v", err) }