TLS fixes

Signed-off-by: Evan Hazlett <ejhazlett@gmail.com>
This commit is contained in:
Evan Hazlett 2019-10-08 02:06:48 -04:00
parent 512655fdb0
commit 6edfb4d93f
No known key found for this signature in database
GPG key ID: A519480096146526
8 changed files with 101 additions and 24 deletions

View file

@ -102,10 +102,10 @@ func (c *Client) Close() error {
func DialOptionsFromConfig(cfg *heimdall.Config) ([]grpc.DialOption, error) { func DialOptionsFromConfig(cfg *heimdall.Config) ([]grpc.DialOption, error) {
opts := []grpc.DialOption{} opts := []grpc.DialOption{}
if cfg.TLSClientCertificate != "" { if cfg.TLSClientCertificate != "" {
logrus.WithField("cert", cfg.TLSClientCertificate) logrus.WithField("cert", cfg.TLSClientCertificate).Debug("configuring TLS cert")
var creds credentials.TransportCredentials var creds credentials.TransportCredentials
if cfg.TLSClientKey != "" { if cfg.TLSClientKey != "" {
logrus.WithField("key", cfg.TLSClientKey) logrus.WithField("key", cfg.TLSClientKey).Debug("configuring TLS key")
cert, err := tls.LoadX509KeyPair(cfg.TLSClientCertificate, cfg.TLSClientKey) cert, err := tls.LoadX509KeyPair(cfg.TLSClientCertificate, cfg.TLSClientKey)
if err != nil { if err != nil {
return nil, err return nil, err

View file

@ -113,6 +113,30 @@ func main() {
Value: "darknet", Value: "darknet",
EnvVar: "HEIMDALL_INTERFACE_NAME", EnvVar: "HEIMDALL_INTERFACE_NAME",
}, },
cli.StringFlag{
Name: "cert, c",
Usage: "heimdall server certificate",
Value: "",
},
cli.StringFlag{
Name: "key, k",
Usage: "heimdall server key",
Value: "",
},
cli.StringFlag{
Name: "client-cert",
Usage: "heimdall client certificate",
Value: "",
},
cli.StringFlag{
Name: "client-key",
Usage: "heimdall client key",
Value: "",
},
cli.BoolFlag{
Name: "skip-verify",
Usage: "skip TLS verification",
},
} }
app.Before = func(c *cli.Context) error { app.Before = func(c *cli.Context) error {
if c.Bool("debug") { if c.Bool("debug") {

View file

@ -52,6 +52,11 @@ func runServer(cx *cli.Context) error {
EndpointPort: cx.Int("endpoint-port"), EndpointPort: cx.Int("endpoint-port"),
InterfaceName: cx.String("interface-name"), InterfaceName: cx.String("interface-name"),
RedisURL: cx.String("redis-url"), RedisURL: cx.String("redis-url"),
TLSServerCertificate: cx.String("cert"),
TLSServerKey: cx.String("key"),
TLSClientCertificate: cx.String("client-cert"),
TLSClientKey: cx.String("client-key"),
TLSInsecureSkipVerify: cx.Bool("skip-verify"),
} }
errCh := make(chan error, 1) errCh := make(chan error, 1)

View file

@ -39,6 +39,9 @@ func run(cx *cli.Context) error {
Address: cx.String("addr"), Address: cx.String("addr"),
UpdateInterval: cx.Duration("update-interval"), UpdateInterval: cx.Duration("update-interval"),
InterfaceName: cx.String("interface-name"), InterfaceName: cx.String("interface-name"),
TLSClientCertificate: cx.String("cert"),
TLSClientKey: cx.String("key"),
TLSInsecureSkipVerify: cx.Bool("skip-verify"),
} }
p, err := peer.NewPeer(cfg) p, err := peer.NewPeer(cfg)
if err != nil { if err != nil {

View file

@ -69,4 +69,10 @@ type PeerConfig struct {
UpdateInterval time.Duration UpdateInterval time.Duration
// InterfaceName is the interface used for peer communication // InterfaceName is the interface used for peer communication
InterfaceName string InterfaceName string
// TLSClientCertificate is the client certificate used for communication
TLSClientCertificate string
// TLSClientKey is the client key used for communication
TLSClientKey string
// TLSInsecureSkipVerify disables certificate verification
TLSInsecureSkipVerify bool
} }

View file

@ -22,12 +22,16 @@
package peer package peer
import ( import (
"context"
"fmt"
"path/filepath" "path/filepath"
"time" "time"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/stellarproject/heimdall" "github.com/stellarproject/heimdall"
"github.com/stellarproject/heimdall/client" "github.com/stellarproject/heimdall/client"
"github.com/stellarproject/heimdall/version"
"google.golang.org/grpc"
) )
const ( const (
@ -50,9 +54,12 @@ func NewPeer(cfg *heimdall.PeerConfig) (*Peer, error) {
func (p *Peer) Run() error { func (p *Peer) Run() error {
// initial sync // initial sync
logrus.Infof("connecting to peer %s", p.cfg.Address) logrus.Infof("connecting to peer %s", p.cfg.Address)
if err := p.sync(); err != nil { ctx, cancel := context.WithTimeout(context.Background(), p.cfg.UpdateInterval)
if err := p.sync(ctx); err != nil {
cancel()
return err return err
} }
cancel()
doneCh := make(chan bool) doneCh := make(chan bool)
errCh := make(chan error) errCh := make(chan error)
@ -60,10 +67,13 @@ func (p *Peer) Run() error {
t := time.NewTicker(p.cfg.UpdateInterval) t := time.NewTicker(p.cfg.UpdateInterval)
go func() { go func() {
for range t.C { for range t.C {
if err := p.sync(); err != nil { ctx, cancel := context.WithTimeout(context.Background(), p.cfg.UpdateInterval)
if err := p.sync(ctx); err != nil {
errCh <- err errCh <- err
cancel()
return return
} }
cancel()
} }
}() }()
select { select {
@ -89,5 +99,20 @@ func (p *Peer) getTunnelName() string {
} }
func (p *Peer) getClient(addr string) (*client.Client, error) { func (p *Peer) getClient(addr string) (*client.Client, error) {
return client.NewClient(p.cfg.ID, addr) cfg := &heimdall.Config{
TLSClientCertificate: p.cfg.TLSClientCertificate,
TLSClientKey: p.cfg.TLSClientKey,
TLSInsecureSkipVerify: p.cfg.TLSInsecureSkipVerify,
}
opts, err := client.DialOptionsFromConfig(cfg)
if err != nil {
return nil, err
}
opts = append(opts,
grpc.WithBlock(),
grpc.WithUserAgent(fmt.Sprintf("%s/%s", version.Name, version.Version)),
)
return client.NewClient(p.cfg.ID, addr, opts...)
} }

View file

@ -31,15 +31,13 @@ import (
"github.com/stellarproject/heimdall/wg" "github.com/stellarproject/heimdall/wg"
) )
func (p *Peer) sync() error { func (p *Peer) sync(ctx context.Context) error {
c, err := p.getClient(p.cfg.Address) c, err := p.getClient(p.cfg.Address)
if err != nil { if err != nil {
return err return err
} }
defer c.Close() defer c.Close()
ctx := context.Background()
resp, err := c.Connect() resp, err := c.Connect()
if err != nil { if err != nil {
return err return err

View file

@ -44,6 +44,7 @@ import (
"github.com/stellarproject/heimdall" "github.com/stellarproject/heimdall"
v1 "github.com/stellarproject/heimdall/api/v1" v1 "github.com/stellarproject/heimdall/api/v1"
"github.com/stellarproject/heimdall/client" "github.com/stellarproject/heimdall/client"
"github.com/stellarproject/heimdall/version"
"github.com/stellarproject/heimdall/wg" "github.com/stellarproject/heimdall/wg"
"google.golang.org/grpc" "google.golang.org/grpc"
) )
@ -442,7 +443,22 @@ func (s *Server) getNodeNetworkKey(id string) string {
} }
func (s *Server) getClient(addr string) (*client.Client, error) { func (s *Server) getClient(addr string) (*client.Client, error) {
return client.NewClient(s.cfg.ID, addr) cfg := &heimdall.Config{
TLSClientCertificate: s.cfg.TLSClientCertificate,
TLSClientKey: s.cfg.TLSClientKey,
TLSInsecureSkipVerify: s.cfg.TLSInsecureSkipVerify,
}
opts, err := client.DialOptionsFromConfig(cfg)
if err != nil {
return nil, err
}
opts = append(opts,
grpc.WithBlock(),
grpc.WithUserAgent(fmt.Sprintf("%s/%s", version.Name, version.Version)),
)
return client.NewClient(s.cfg.ID, addr, opts...)
} }
func (s *Server) getClusterKey(ctx context.Context) (string, error) { func (s *Server) getClusterKey(ctx context.Context) (string, error) {