diff --git a/client/client.go b/client/client.go index 81e7bc2..3125d45 100644 --- a/client/client.go +++ b/client/client.go @@ -102,10 +102,10 @@ func (c *Client) Close() error { func DialOptionsFromConfig(cfg *heimdall.Config) ([]grpc.DialOption, error) { opts := []grpc.DialOption{} if cfg.TLSClientCertificate != "" { - logrus.WithField("cert", cfg.TLSClientCertificate) + logrus.WithField("cert", cfg.TLSClientCertificate).Debug("configuring TLS cert") var creds credentials.TransportCredentials if cfg.TLSClientKey != "" { - logrus.WithField("key", cfg.TLSClientKey) + logrus.WithField("key", cfg.TLSClientKey).Debug("configuring TLS key") cert, err := tls.LoadX509KeyPair(cfg.TLSClientCertificate, cfg.TLSClientKey) if err != nil { return nil, err diff --git a/cmd/heimdall/main.go b/cmd/heimdall/main.go index fc12248..cd40ca5 100644 --- a/cmd/heimdall/main.go +++ b/cmd/heimdall/main.go @@ -113,6 +113,30 @@ func main() { Value: "darknet", EnvVar: "HEIMDALL_INTERFACE_NAME", }, + cli.StringFlag{ + Name: "cert, c", + Usage: "heimdall server certificate", + Value: "", + }, + cli.StringFlag{ + Name: "key, k", + Usage: "heimdall server key", + Value: "", + }, + cli.StringFlag{ + Name: "client-cert", + Usage: "heimdall client certificate", + Value: "", + }, + cli.StringFlag{ + Name: "client-key", + Usage: "heimdall client key", + Value: "", + }, + cli.BoolFlag{ + Name: "skip-verify", + Usage: "skip TLS verification", + }, } app.Before = func(c *cli.Context) error { if c.Bool("debug") { diff --git a/cmd/heimdall/run.go b/cmd/heimdall/run.go index b801f0b..b94da14 100644 --- a/cmd/heimdall/run.go +++ b/cmd/heimdall/run.go @@ -41,17 +41,22 @@ import ( func runServer(cx *cli.Context) error { cfg := &heimdall.Config{ - ID: cx.String("id"), - GRPCAddress: cx.String("addr"), - AdvertiseGRPCAddress: cx.String("advertise-grpc-address"), - GRPCPeerAddress: cx.String("peer"), - ClusterKey: cx.String("cluster-key"), - NodeNetwork: cx.String("node-network"), - PeerNetwork: cx.String("peer-network"), - EndpointIP: cx.String("endpoint-ip"), - EndpointPort: cx.Int("endpoint-port"), - InterfaceName: cx.String("interface-name"), - RedisURL: cx.String("redis-url"), + ID: cx.String("id"), + GRPCAddress: cx.String("addr"), + AdvertiseGRPCAddress: cx.String("advertise-grpc-address"), + GRPCPeerAddress: cx.String("peer"), + ClusterKey: cx.String("cluster-key"), + NodeNetwork: cx.String("node-network"), + PeerNetwork: cx.String("peer-network"), + EndpointIP: cx.String("endpoint-ip"), + EndpointPort: cx.Int("endpoint-port"), + InterfaceName: cx.String("interface-name"), + RedisURL: cx.String("redis-url"), + TLSServerCertificate: cx.String("cert"), + TLSServerKey: cx.String("key"), + TLSClientCertificate: cx.String("client-cert"), + TLSClientKey: cx.String("client-key"), + TLSInsecureSkipVerify: cx.Bool("skip-verify"), } errCh := make(chan error, 1) diff --git a/cmd/hpeer/run.go b/cmd/hpeer/run.go index e248f1c..abd131c 100644 --- a/cmd/hpeer/run.go +++ b/cmd/hpeer/run.go @@ -35,10 +35,13 @@ import ( func run(cx *cli.Context) error { cfg := &heimdall.PeerConfig{ - ID: cx.String("id"), - Address: cx.String("addr"), - UpdateInterval: cx.Duration("update-interval"), - InterfaceName: cx.String("interface-name"), + ID: cx.String("id"), + Address: cx.String("addr"), + UpdateInterval: cx.Duration("update-interval"), + InterfaceName: cx.String("interface-name"), + TLSClientCertificate: cx.String("cert"), + TLSClientKey: cx.String("key"), + TLSInsecureSkipVerify: cx.Bool("skip-verify"), } p, err := peer.NewPeer(cfg) if err != nil { diff --git a/config.go b/config.go index d2f8075..4c41912 100644 --- a/config.go +++ b/config.go @@ -69,4 +69,10 @@ type PeerConfig struct { UpdateInterval time.Duration // InterfaceName is the interface used for peer communication InterfaceName string + // TLSClientCertificate is the client certificate used for communication + TLSClientCertificate string + // TLSClientKey is the client key used for communication + TLSClientKey string + // TLSInsecureSkipVerify disables certificate verification + TLSInsecureSkipVerify bool } diff --git a/peer/peer.go b/peer/peer.go index 21426ce..9258f83 100644 --- a/peer/peer.go +++ b/peer/peer.go @@ -22,12 +22,16 @@ package peer import ( + "context" + "fmt" "path/filepath" "time" "github.com/sirupsen/logrus" "github.com/stellarproject/heimdall" "github.com/stellarproject/heimdall/client" + "github.com/stellarproject/heimdall/version" + "google.golang.org/grpc" ) const ( @@ -50,9 +54,12 @@ func NewPeer(cfg *heimdall.PeerConfig) (*Peer, error) { func (p *Peer) Run() error { // initial sync logrus.Infof("connecting to peer %s", p.cfg.Address) - if err := p.sync(); err != nil { + ctx, cancel := context.WithTimeout(context.Background(), p.cfg.UpdateInterval) + if err := p.sync(ctx); err != nil { + cancel() return err } + cancel() doneCh := make(chan bool) errCh := make(chan error) @@ -60,10 +67,13 @@ func (p *Peer) Run() error { t := time.NewTicker(p.cfg.UpdateInterval) go func() { for range t.C { - if err := p.sync(); err != nil { + ctx, cancel := context.WithTimeout(context.Background(), p.cfg.UpdateInterval) + if err := p.sync(ctx); err != nil { errCh <- err + cancel() return } + cancel() } }() select { @@ -89,5 +99,20 @@ func (p *Peer) getTunnelName() string { } func (p *Peer) getClient(addr string) (*client.Client, error) { - return client.NewClient(p.cfg.ID, addr) + cfg := &heimdall.Config{ + TLSClientCertificate: p.cfg.TLSClientCertificate, + TLSClientKey: p.cfg.TLSClientKey, + TLSInsecureSkipVerify: p.cfg.TLSInsecureSkipVerify, + } + + opts, err := client.DialOptionsFromConfig(cfg) + if err != nil { + return nil, err + } + opts = append(opts, + grpc.WithBlock(), + grpc.WithUserAgent(fmt.Sprintf("%s/%s", version.Name, version.Version)), + ) + + return client.NewClient(p.cfg.ID, addr, opts...) } diff --git a/peer/sync.go b/peer/sync.go index 7edf8ad..9504b52 100644 --- a/peer/sync.go +++ b/peer/sync.go @@ -31,15 +31,13 @@ import ( "github.com/stellarproject/heimdall/wg" ) -func (p *Peer) sync() error { +func (p *Peer) sync(ctx context.Context) error { c, err := p.getClient(p.cfg.Address) if err != nil { return err } defer c.Close() - ctx := context.Background() - resp, err := c.Connect() if err != nil { return err diff --git a/server/server.go b/server/server.go index d42d2db..65c8e2d 100644 --- a/server/server.go +++ b/server/server.go @@ -44,6 +44,7 @@ import ( "github.com/stellarproject/heimdall" v1 "github.com/stellarproject/heimdall/api/v1" "github.com/stellarproject/heimdall/client" + "github.com/stellarproject/heimdall/version" "github.com/stellarproject/heimdall/wg" "google.golang.org/grpc" ) @@ -442,7 +443,22 @@ func (s *Server) getNodeNetworkKey(id string) string { } func (s *Server) getClient(addr string) (*client.Client, error) { - return client.NewClient(s.cfg.ID, addr) + cfg := &heimdall.Config{ + TLSClientCertificate: s.cfg.TLSClientCertificate, + TLSClientKey: s.cfg.TLSClientKey, + TLSInsecureSkipVerify: s.cfg.TLSInsecureSkipVerify, + } + + opts, err := client.DialOptionsFromConfig(cfg) + if err != nil { + return nil, err + } + opts = append(opts, + grpc.WithBlock(), + grpc.WithUserAgent(fmt.Sprintf("%s/%s", version.Name, version.Version)), + ) + + return client.NewClient(s.cfg.ID, addr, opts...) } func (s *Server) getClusterKey(ctx context.Context) (string, error) {