1
0
Fork 0
mirror of https://github.com/vbatts/go-mtree.git synced 2025-10-06 13:11:01 +00:00

vendor: glide update

Signed-off-by: Vincent Batts <vbatts@hashbangbash.com>
This commit is contained in:
Vincent Batts 2019-01-21 10:10:51 -05:00
parent 53e54ea2f7
commit 8d3cf7ea39
Signed by: vbatts
GPG key ID: 10937E57733F1362
322 changed files with 47691 additions and 5542 deletions

View file

@ -2,12 +2,13 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
/*
Package agent implements a client to an ssh-agent daemon.
References:
[PROTOCOL.agent]: http://www.openbsd.org/cgi-bin/cvsweb/src/usr.bin/ssh/PROTOCOL.agent
*/
// Package agent implements the ssh-agent protocol, and provides both
// a client and a server. The client can talk to a standard ssh-agent
// that uses UNIX sockets, and one could implement an alternative
// ssh-agent process using the sample server.
//
// References:
// [PROTOCOL.agent]: https://tools.ietf.org/html/draft-miller-ssh-agent-00
package agent // import "golang.org/x/crypto/ssh/agent"
import (
@ -24,9 +25,22 @@ import (
"math/big"
"sync"
"crypto"
"golang.org/x/crypto/ed25519"
"golang.org/x/crypto/ssh"
)
// SignatureFlags represent additional flags that can be passed to the signature
// requests an defined in [PROTOCOL.agent] section 4.5.1.
type SignatureFlags uint32
// SignatureFlag values as defined in [PROTOCOL.agent] section 5.3.
const (
SignatureFlagReserved SignatureFlags = 1 << iota
SignatureFlagRsaSha256
SignatureFlagRsaSha512
)
// Agent represents the capabilities of an ssh-agent.
type Agent interface {
// List returns the identities known to the agent.
@ -36,9 +50,8 @@ type Agent interface {
// in [PROTOCOL.agent] section 2.6.2.
Sign(key ssh.PublicKey, data []byte) (*ssh.Signature, error)
// Insert adds a private key to the agent. If a certificate
// is given, that certificate is added as public key.
Add(s interface{}, cert *ssh.Certificate, comment string) error
// Add adds a private key to the agent.
Add(key AddedKey) error
// Remove removes all identities with the given public key.
Remove(key ssh.PublicKey) error
@ -56,15 +69,68 @@ type Agent interface {
Signers() ([]ssh.Signer, error)
}
type ExtendedAgent interface {
Agent
// SignWithFlags signs like Sign, but allows for additional flags to be sent/received
SignWithFlags(key ssh.PublicKey, data []byte, flags SignatureFlags) (*ssh.Signature, error)
// Extension processes a custom extension request. Standard-compliant agents are not
// required to support any extensions, but this method allows agents to implement
// vendor-specific methods or add experimental features. See [PROTOCOL.agent] section 4.7.
// If agent extensions are unsupported entirely this method MUST return an
// ErrExtensionUnsupported error. Similarly, if just the specific extensionType in
// the request is unsupported by the agent then ErrExtensionUnsupported MUST be
// returned.
//
// In the case of success, since [PROTOCOL.agent] section 4.7 specifies that the contents
// of the response are unspecified (including the type of the message), the complete
// response will be returned as a []byte slice, including the "type" byte of the message.
Extension(extensionType string, contents []byte) ([]byte, error)
}
// ConstraintExtension describes an optional constraint defined by users.
type ConstraintExtension struct {
// ExtensionName consist of a UTF-8 string suffixed by the
// implementation domain following the naming scheme defined
// in Section 4.2 of [RFC4251], e.g. "foo@example.com".
ExtensionName string
// ExtensionDetails contains the actual content of the extended
// constraint.
ExtensionDetails []byte
}
// AddedKey describes an SSH key to be added to an Agent.
type AddedKey struct {
// PrivateKey must be a *rsa.PrivateKey, *dsa.PrivateKey or
// *ecdsa.PrivateKey, which will be inserted into the agent.
PrivateKey interface{}
// Certificate, if not nil, is communicated to the agent and will be
// stored with the key.
Certificate *ssh.Certificate
// Comment is an optional, free-form string.
Comment string
// LifetimeSecs, if not zero, is the number of seconds that the
// agent will store the key for.
LifetimeSecs uint32
// ConfirmBeforeUse, if true, requests that the agent confirm with the
// user before each use of this key.
ConfirmBeforeUse bool
// ConstraintExtensions are the experimental or private-use constraints
// defined by users.
ConstraintExtensions []ConstraintExtension
}
// See [PROTOCOL.agent], section 3.
const (
agentRequestV1Identities = 1
agentRequestV1Identities = 1
agentRemoveAllV1Identities = 9
// 3.2 Requests from client to agent for protocol 2 key operations
agentAddIdentity = 17
agentRemoveIdentity = 18
agentRemoveAllIdentities = 19
agentAddIdConstrained = 25
agentAddIDConstrained = 25
// 3.3 Key-type independent requests from client to agent
agentAddSmartcardKey = 20
@ -74,8 +140,9 @@ const (
agentAddSmartcardKeyConstrained = 26
// 3.7 Key constraint identifiers
agentConstrainLifetime = 1
agentConstrainConfirm = 2
agentConstrainLifetime = 1
agentConstrainConfirm = 2
agentConstrainExtension = 3
)
// maxAgentResponseBytes is the maximum agent reply size that is accepted. This
@ -131,6 +198,36 @@ type publicKey struct {
Rest []byte `ssh:"rest"`
}
// 3.7 Key constraint identifiers
type constrainLifetimeAgentMsg struct {
LifetimeSecs uint32 `sshtype:"1"`
}
type constrainExtensionAgentMsg struct {
ExtensionName string `sshtype:"3"`
ExtensionDetails []byte
// Rest is a field used for parsing, not part of message
Rest []byte `ssh:"rest"`
}
// See [PROTOCOL.agent], section 4.7
const agentExtension = 27
const agentExtensionFailure = 28
// ErrExtensionUnsupported indicates that an extension defined in
// [PROTOCOL.agent] section 4.7 is unsupported by the agent. Specifically this
// error indicates that the agent returned a standard SSH_AGENT_FAILURE message
// as the result of a SSH_AGENTC_EXTENSION request. Note that the protocol
// specification (and therefore this error) does not distinguish between a
// specific extension being unsupported and extensions being unsupported entirely.
var ErrExtensionUnsupported = errors.New("agent: extension unsupported")
type extensionAgentMsg struct {
ExtensionType string `sshtype:"27"`
Contents []byte
}
// Key represents a protocol 2 public key as defined in
// [PROTOCOL.agent], section 2.5.2.
type Key struct {
@ -165,10 +262,13 @@ func (k *Key) Marshal() []byte {
return k.Blob
}
// Verify satisfies the ssh.PublicKey interface, but is not
// implemented for agent keys.
// Verify satisfies the ssh.PublicKey interface.
func (k *Key) Verify(data []byte, sig *ssh.Signature) error {
return errors.New("agent: agent key does not know how to verify")
pubKey, err := ssh.ParsePublicKey(k.Blob)
if err != nil {
return fmt.Errorf("agent: bad public key: %v", err)
}
return pubKey.Verify(data, sig)
}
type wireKey struct {
@ -209,7 +309,7 @@ type client struct {
// NewClient returns an Agent that talks to an ssh-agent process over
// the given connection.
func NewClient(rw io.ReadWriter) Agent {
func NewClient(rw io.ReadWriter) ExtendedAgent {
return &client{conn: rw}
}
@ -217,6 +317,21 @@ func NewClient(rw io.ReadWriter) Agent {
// unmarshaled into reply and replyType is set to the first byte of
// the reply, which contains the type of the message.
func (c *client) call(req []byte) (reply interface{}, err error) {
buf, err := c.callRaw(req)
if err != nil {
return nil, err
}
reply, err = unmarshal(buf)
if err != nil {
return nil, clientErr(err)
}
return reply, nil
}
// callRaw sends an RPC to the agent. On success, the raw
// bytes of the response are returned; no unmarshalling is
// performed on the response.
func (c *client) callRaw(req []byte) (reply []byte, err error) {
c.mu.Lock()
defer c.mu.Unlock()
@ -233,18 +348,14 @@ func (c *client) call(req []byte) (reply interface{}, err error) {
}
respSize := binary.BigEndian.Uint32(respSizeBuf[:])
if respSize > maxAgentResponseBytes {
return nil, clientErr(err)
return nil, clientErr(errors.New("response too large"))
}
buf := make([]byte, respSize)
if _, err = io.ReadFull(c.conn, buf); err != nil {
return nil, clientErr(err)
}
reply, err = unmarshal(buf)
if err != nil {
return nil, clientErr(err)
}
return reply, err
return buf, nil
}
func (c *client) simpleCall(req []byte) error {
@ -318,9 +429,14 @@ func (c *client) List() ([]*Key, error) {
// Sign has the agent sign the data using a protocol 2 key as defined
// in [PROTOCOL.agent] section 2.6.2.
func (c *client) Sign(key ssh.PublicKey, data []byte) (*ssh.Signature, error) {
return c.SignWithFlags(key, data, 0)
}
func (c *client) SignWithFlags(key ssh.PublicKey, data []byte, flags SignatureFlags) (*ssh.Signature, error) {
req := ssh.Marshal(signRequestAgentMsg{
KeyBlob: key.Marshal(),
Data: data,
Flags: uint32(flags),
})
msg, err := c.call(req)
@ -358,6 +474,8 @@ func unmarshal(packet []byte) (interface{}, error) {
msg = new(identitiesAnswerAgentMsg)
case agentSignResponse:
msg = new(signResponseAgentMsg)
case agentV1IdentitiesAnswer:
msg = new(agentV1IdentityMsg)
default:
return nil, fmt.Errorf("agent: unknown type tag %d", packet[0])
}
@ -368,36 +486,47 @@ func unmarshal(packet []byte) (interface{}, error) {
}
type rsaKeyMsg struct {
Type string `sshtype:"17"`
N *big.Int
E *big.Int
D *big.Int
Iqmp *big.Int // IQMP = Inverse Q Mod P
P *big.Int
Q *big.Int
Comments string
Type string `sshtype:"17|25"`
N *big.Int
E *big.Int
D *big.Int
Iqmp *big.Int // IQMP = Inverse Q Mod P
P *big.Int
Q *big.Int
Comments string
Constraints []byte `ssh:"rest"`
}
type dsaKeyMsg struct {
Type string `sshtype:"17"`
P *big.Int
Q *big.Int
G *big.Int
Y *big.Int
X *big.Int
Comments string
Type string `sshtype:"17|25"`
P *big.Int
Q *big.Int
G *big.Int
Y *big.Int
X *big.Int
Comments string
Constraints []byte `ssh:"rest"`
}
type ecdsaKeyMsg struct {
Type string `sshtype:"17"`
Curve string
KeyBytes []byte
D *big.Int
Comments string
Type string `sshtype:"17|25"`
Curve string
KeyBytes []byte
D *big.Int
Comments string
Constraints []byte `ssh:"rest"`
}
type ed25519KeyMsg struct {
Type string `sshtype:"17|25"`
Pub []byte
Priv []byte
Comments string
Constraints []byte `ssh:"rest"`
}
// Insert adds a private key to the agent.
func (c *client) insertKey(s interface{}, comment string) error {
func (c *client) insertKey(s interface{}, comment string, constraints []byte) error {
var req []byte
switch k := s.(type) {
case *rsa.PrivateKey:
@ -406,37 +535,54 @@ func (c *client) insertKey(s interface{}, comment string) error {
}
k.Precompute()
req = ssh.Marshal(rsaKeyMsg{
Type: ssh.KeyAlgoRSA,
N: k.N,
E: big.NewInt(int64(k.E)),
D: k.D,
Iqmp: k.Precomputed.Qinv,
P: k.Primes[0],
Q: k.Primes[1],
Comments: comment,
Type: ssh.KeyAlgoRSA,
N: k.N,
E: big.NewInt(int64(k.E)),
D: k.D,
Iqmp: k.Precomputed.Qinv,
P: k.Primes[0],
Q: k.Primes[1],
Comments: comment,
Constraints: constraints,
})
case *dsa.PrivateKey:
req = ssh.Marshal(dsaKeyMsg{
Type: ssh.KeyAlgoDSA,
P: k.P,
Q: k.Q,
G: k.G,
Y: k.Y,
X: k.X,
Comments: comment,
Type: ssh.KeyAlgoDSA,
P: k.P,
Q: k.Q,
G: k.G,
Y: k.Y,
X: k.X,
Comments: comment,
Constraints: constraints,
})
case *ecdsa.PrivateKey:
nistID := fmt.Sprintf("nistp%d", k.Params().BitSize)
req = ssh.Marshal(ecdsaKeyMsg{
Type: "ecdsa-sha2-" + nistID,
Curve: nistID,
KeyBytes: elliptic.Marshal(k.Curve, k.X, k.Y),
D: k.D,
Comments: comment,
Type: "ecdsa-sha2-" + nistID,
Curve: nistID,
KeyBytes: elliptic.Marshal(k.Curve, k.X, k.Y),
D: k.D,
Comments: comment,
Constraints: constraints,
})
case *ed25519.PrivateKey:
req = ssh.Marshal(ed25519KeyMsg{
Type: ssh.KeyAlgoED25519,
Pub: []byte(*k)[32:],
Priv: []byte(*k),
Comments: comment,
Constraints: constraints,
})
default:
return fmt.Errorf("agent: unsupported key type %T", s)
}
// if constraints are present then the message type needs to be changed.
if len(constraints) != 0 {
req[0] = agentAddIDConstrained
}
resp, err := c.call(req)
if err != nil {
return err
@ -448,40 +594,62 @@ func (c *client) insertKey(s interface{}, comment string) error {
}
type rsaCertMsg struct {
Type string `sshtype:"17"`
CertBytes []byte
D *big.Int
Iqmp *big.Int // IQMP = Inverse Q Mod P
P *big.Int
Q *big.Int
Comments string
Type string `sshtype:"17|25"`
CertBytes []byte
D *big.Int
Iqmp *big.Int // IQMP = Inverse Q Mod P
P *big.Int
Q *big.Int
Comments string
Constraints []byte `ssh:"rest"`
}
type dsaCertMsg struct {
Type string `sshtype:"17"`
CertBytes []byte
X *big.Int
Comments string
Type string `sshtype:"17|25"`
CertBytes []byte
X *big.Int
Comments string
Constraints []byte `ssh:"rest"`
}
type ecdsaCertMsg struct {
Type string `sshtype:"17"`
CertBytes []byte
D *big.Int
Comments string
Type string `sshtype:"17|25"`
CertBytes []byte
D *big.Int
Comments string
Constraints []byte `ssh:"rest"`
}
// Insert adds a private key to the agent. If a certificate is given,
type ed25519CertMsg struct {
Type string `sshtype:"17|25"`
CertBytes []byte
Pub []byte
Priv []byte
Comments string
Constraints []byte `ssh:"rest"`
}
// Add adds a private key to the agent. If a certificate is given,
// that certificate is added instead as public key.
func (c *client) Add(s interface{}, cert *ssh.Certificate, comment string) error {
if cert == nil {
return c.insertKey(s, comment)
} else {
return c.insertCert(s, cert, comment)
func (c *client) Add(key AddedKey) error {
var constraints []byte
if secs := key.LifetimeSecs; secs != 0 {
constraints = append(constraints, ssh.Marshal(constrainLifetimeAgentMsg{secs})...)
}
if key.ConfirmBeforeUse {
constraints = append(constraints, agentConstrainConfirm)
}
cert := key.Certificate
if cert == nil {
return c.insertKey(key.PrivateKey, key.Comment, constraints)
}
return c.insertCert(key.PrivateKey, cert, key.Comment, constraints)
}
func (c *client) insertCert(s interface{}, cert *ssh.Certificate, comment string) error {
func (c *client) insertCert(s interface{}, cert *ssh.Certificate, comment string, constraints []byte) error {
var req []byte
switch k := s.(type) {
case *rsa.PrivateKey:
@ -490,32 +658,49 @@ func (c *client) insertCert(s interface{}, cert *ssh.Certificate, comment string
}
k.Precompute()
req = ssh.Marshal(rsaCertMsg{
Type: cert.Type(),
CertBytes: cert.Marshal(),
D: k.D,
Iqmp: k.Precomputed.Qinv,
P: k.Primes[0],
Q: k.Primes[1],
Comments: comment,
Type: cert.Type(),
CertBytes: cert.Marshal(),
D: k.D,
Iqmp: k.Precomputed.Qinv,
P: k.Primes[0],
Q: k.Primes[1],
Comments: comment,
Constraints: constraints,
})
case *dsa.PrivateKey:
req = ssh.Marshal(dsaCertMsg{
Type: cert.Type(),
CertBytes: cert.Marshal(),
X: k.X,
Comments: comment,
Type: cert.Type(),
CertBytes: cert.Marshal(),
X: k.X,
Comments: comment,
Constraints: constraints,
})
case *ecdsa.PrivateKey:
req = ssh.Marshal(ecdsaCertMsg{
Type: cert.Type(),
CertBytes: cert.Marshal(),
D: k.D,
Comments: comment,
Type: cert.Type(),
CertBytes: cert.Marshal(),
D: k.D,
Comments: comment,
Constraints: constraints,
})
case *ed25519.PrivateKey:
req = ssh.Marshal(ed25519CertMsg{
Type: cert.Type(),
CertBytes: cert.Marshal(),
Pub: []byte(*k)[32:],
Priv: []byte(*k),
Comments: comment,
Constraints: constraints,
})
default:
return fmt.Errorf("agent: unsupported key type %T", s)
}
// if constraints are present then the message type needs to be changed.
if len(constraints) != 0 {
req[0] = agentAddIDConstrained
}
signer, err := ssh.NewSignerFromKey(s)
if err != nil {
return err
@ -561,3 +746,44 @@ func (s *agentKeyringSigner) Sign(rand io.Reader, data []byte) (*ssh.Signature,
// The agent has its own entropy source, so the rand argument is ignored.
return s.agent.Sign(s.pub, data)
}
func (s *agentKeyringSigner) SignWithOpts(rand io.Reader, data []byte, opts crypto.SignerOpts) (*ssh.Signature, error) {
var flags SignatureFlags
if opts != nil {
switch opts.HashFunc() {
case crypto.SHA256:
flags = SignatureFlagRsaSha256
case crypto.SHA512:
flags = SignatureFlagRsaSha512
}
}
return s.agent.SignWithFlags(s.pub, data, flags)
}
// Calls an extension method. It is up to the agent implementation as to whether or not
// any particular extension is supported and may always return an error. Because the
// type of the response is up to the implementation, this returns the bytes of the
// response and does not attempt any type of unmarshalling.
func (c *client) Extension(extensionType string, contents []byte) ([]byte, error) {
req := ssh.Marshal(extensionAgentMsg{
ExtensionType: extensionType,
Contents: contents,
})
buf, err := c.callRaw(req)
if err != nil {
return nil, err
}
if len(buf) == 0 {
return nil, errors.New("agent: failure; empty response")
}
// [PROTOCOL.agent] section 4.7 indicates that an SSH_AGENT_FAILURE message
// represents an agent that does not support the extension
if buf[0] == agentFailure {
return nil, ErrExtensionUnsupported
}
if buf[0] == agentExtensionFailure {
return nil, errors.New("agent: generic extension failure")
}
return buf, nil
}

View file

@ -14,12 +14,13 @@ import (
"path/filepath"
"strconv"
"testing"
"time"
"golang.org/x/crypto/ssh"
)
// startAgent executes ssh-agent, and returns a Agent interface to it.
func startAgent(t *testing.T) (client Agent, socket string, cleanup func()) {
// startOpenSSHAgent executes ssh-agent, and returns an Agent interface to it.
func startOpenSSHAgent(t *testing.T) (client ExtendedAgent, socket string, cleanup func()) {
if testing.Short() {
// ssh-agent is not always available, and the key
// types supported vary by platform.
@ -78,14 +79,39 @@ func startAgent(t *testing.T) (client Agent, socket string, cleanup func()) {
}
}
func testAgent(t *testing.T, key interface{}, cert *ssh.Certificate) {
agent, _, cleanup := startAgent(t)
defer cleanup()
func startAgent(t *testing.T, agent Agent) (client ExtendedAgent, cleanup func()) {
c1, c2, err := netPipe()
if err != nil {
t.Fatalf("netPipe: %v", err)
}
go ServeAgent(agent, c2)
testAgentInterface(t, agent, key, cert)
return NewClient(c1), func() {
c1.Close()
c2.Close()
}
}
func testAgentInterface(t *testing.T, agent Agent, key interface{}, cert *ssh.Certificate) {
// startKeyringAgent uses Keyring to simulate a ssh-agent Server and returns a client.
func startKeyringAgent(t *testing.T) (client ExtendedAgent, cleanup func()) {
return startAgent(t, NewKeyring())
}
func testOpenSSHAgent(t *testing.T, key interface{}, cert *ssh.Certificate, lifetimeSecs uint32) {
agent, _, cleanup := startOpenSSHAgent(t)
defer cleanup()
testAgentInterface(t, agent, key, cert, lifetimeSecs)
}
func testKeyringAgent(t *testing.T, key interface{}, cert *ssh.Certificate, lifetimeSecs uint32) {
agent, cleanup := startKeyringAgent(t)
defer cleanup()
testAgentInterface(t, agent, key, cert, lifetimeSecs)
}
func testAgentInterface(t *testing.T, agent ExtendedAgent, key interface{}, cert *ssh.Certificate, lifetimeSecs uint32) {
signer, err := ssh.NewSignerFromKey(key)
if err != nil {
t.Fatalf("NewSignerFromKey(%T): %v", key, err)
@ -100,10 +126,15 @@ func testAgentInterface(t *testing.T, agent Agent, key interface{}, cert *ssh.Ce
// Attempt to insert the key, with certificate if specified.
var pubKey ssh.PublicKey
if cert != nil {
err = agent.Add(key, cert, "comment")
err = agent.Add(AddedKey{
PrivateKey: key,
Certificate: cert,
Comment: "comment",
LifetimeSecs: lifetimeSecs,
})
pubKey = cert
} else {
err = agent.Add(key, nil, "comment")
err = agent.Add(AddedKey{PrivateKey: key, Comment: "comment", LifetimeSecs: lifetimeSecs})
pubKey = signer.PublicKey()
}
if err != nil {
@ -131,11 +162,44 @@ func testAgentInterface(t *testing.T, agent Agent, key interface{}, cert *ssh.Ce
if err := pubKey.Verify(data, sig); err != nil {
t.Fatalf("Verify(%s): %v", pubKey.Type(), err)
}
// For tests on RSA keys, try signing with SHA-256 and SHA-512 flags
if pubKey.Type() == "ssh-rsa" {
sshFlagTest := func(flag SignatureFlags, expectedSigFormat string) {
sig, err = agent.SignWithFlags(pubKey, data, flag)
if err != nil {
t.Fatalf("SignWithFlags(%s): %v", pubKey.Type(), err)
}
if sig.Format != expectedSigFormat {
t.Fatalf("Signature format didn't match expected value: %s != %s", sig.Format, expectedSigFormat)
}
if err := pubKey.Verify(data, sig); err != nil {
t.Fatalf("Verify(%s): %v", pubKey.Type(), err)
}
}
sshFlagTest(0, ssh.SigAlgoRSA)
sshFlagTest(SignatureFlagRsaSha256, ssh.SigAlgoRSASHA2256)
sshFlagTest(SignatureFlagRsaSha512, ssh.SigAlgoRSASHA2512)
}
// If the key has a lifetime, is it removed when it should be?
if lifetimeSecs > 0 {
time.Sleep(time.Second*time.Duration(lifetimeSecs) + 100*time.Millisecond)
keys, err := agent.List()
if err != nil {
t.Fatalf("List: %v", err)
}
if len(keys) > 0 {
t.Fatalf("key not expired")
}
}
}
func TestAgent(t *testing.T) {
for _, keyType := range []string{"rsa", "dsa", "ecdsa"} {
testAgent(t, testPrivateKeys[keyType], nil)
for _, keyType := range []string{"rsa", "dsa", "ecdsa", "ed25519"} {
testOpenSSHAgent(t, testPrivateKeys[keyType], nil, 0)
testKeyringAgent(t, testPrivateKeys[keyType], nil, 0)
}
}
@ -147,7 +211,8 @@ func TestCert(t *testing.T) {
}
cert.SignCert(rand.Reader, testSigners["ecdsa"])
testAgent(t, testPrivateKeys["rsa"], cert)
testOpenSSHAgent(t, testPrivateKeys["rsa"], cert, 0)
testKeyringAgent(t, testPrivateKeys["rsa"], cert, 0)
}
// netPipe is analogous to net.Pipe, but it uses a real net.Conn, and
@ -156,7 +221,10 @@ func TestCert(t *testing.T) {
func netPipe() (net.Conn, net.Conn, error) {
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
return nil, nil, err
listener, err = net.Listen("tcp", "[::1]:0")
if err != nil {
return nil, nil, err
}
}
defer listener.Close()
c1, err := net.Dial("tcp", listener.Addr().String())
@ -173,7 +241,7 @@ func netPipe() (net.Conn, net.Conn, error) {
return c1, c2, nil
}
func TestAuth(t *testing.T) {
func TestServerResponseTooLarge(t *testing.T) {
a, b, err := netPipe()
if err != nil {
t.Fatalf("netPipe: %v", err)
@ -182,10 +250,39 @@ func TestAuth(t *testing.T) {
defer a.Close()
defer b.Close()
agent, _, cleanup := startAgent(t)
var response identitiesAnswerAgentMsg
response.NumKeys = 1
response.Keys = make([]byte, maxAgentResponseBytes+1)
agent := NewClient(a)
go func() {
n, _ := b.Write(ssh.Marshal(response))
if n < 4 {
t.Fatalf("At least 4 bytes (the response size) should have been successfully written: %d < 4", n)
}
}()
_, err = agent.List()
if err == nil {
t.Fatal("Did not get error result")
}
if err.Error() != "agent: client error: response too large" {
t.Fatal("Did not get expected error result")
}
}
func TestAuth(t *testing.T) {
agent, _, cleanup := startOpenSSHAgent(t)
defer cleanup()
if err := agent.Add(testPrivateKeys["rsa"], nil, "comment"); err != nil {
a, b, err := netPipe()
if err != nil {
t.Fatalf("netPipe: %v", err)
}
defer a.Close()
defer b.Close()
if err := agent.Add(AddedKey{PrivateKey: testPrivateKeys["rsa"], Comment: "comment"}); err != nil {
t.Errorf("Add: %v", err)
}
@ -207,7 +304,9 @@ func TestAuth(t *testing.T) {
conn.Close()
}()
conf := ssh.ClientConfig{}
conf := ssh.ClientConfig{
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
}
conf.Auth = append(conf.Auth, ssh.PublicKeysCallback(agent.Signers))
conn, _, _, err := ssh.NewClientConn(b, "", &conf)
if err != nil {
@ -216,17 +315,23 @@ func TestAuth(t *testing.T) {
conn.Close()
}
func TestLockClient(t *testing.T) {
agent, _, cleanup := startAgent(t)
func TestLockOpenSSHAgent(t *testing.T) {
agent, _, cleanup := startOpenSSHAgent(t)
defer cleanup()
testLockAgent(agent, t)
}
func TestLockKeyringAgent(t *testing.T) {
agent, cleanup := startKeyringAgent(t)
defer cleanup()
testLockAgent(agent, t)
}
func testLockAgent(agent Agent, t *testing.T) {
if err := agent.Add(testPrivateKeys["rsa"], nil, "comment 1"); err != nil {
if err := agent.Add(AddedKey{PrivateKey: testPrivateKeys["rsa"], Comment: "comment 1"}); err != nil {
t.Errorf("Add: %v", err)
}
if err := agent.Add(testPrivateKeys["dsa"], nil, "comment dsa"); err != nil {
if err := agent.Add(AddedKey{PrivateKey: testPrivateKeys["dsa"], Comment: "comment dsa"}); err != nil {
t.Errorf("Add: %v", err)
}
if keys, err := agent.List(); err != nil {
@ -276,3 +381,86 @@ func testLockAgent(agent Agent, t *testing.T) {
t.Errorf("Want 1 keys, got %v", keys)
}
}
func testOpenSSHAgentLifetime(t *testing.T) {
agent, _, cleanup := startOpenSSHAgent(t)
defer cleanup()
testAgentLifetime(t, agent)
}
func testKeyringAgentLifetime(t *testing.T) {
agent, cleanup := startKeyringAgent(t)
defer cleanup()
testAgentLifetime(t, agent)
}
func testAgentLifetime(t *testing.T, agent Agent) {
for _, keyType := range []string{"rsa", "dsa", "ecdsa"} {
// Add private keys to the agent.
err := agent.Add(AddedKey{
PrivateKey: testPrivateKeys[keyType],
Comment: "comment",
LifetimeSecs: 1,
})
if err != nil {
t.Fatalf("add: %v", err)
}
// Add certs to the agent.
cert := &ssh.Certificate{
Key: testPublicKeys[keyType],
ValidBefore: ssh.CertTimeInfinity,
CertType: ssh.UserCert,
}
cert.SignCert(rand.Reader, testSigners[keyType])
err = agent.Add(AddedKey{
PrivateKey: testPrivateKeys[keyType],
Certificate: cert,
Comment: "comment",
LifetimeSecs: 1,
})
if err != nil {
t.Fatalf("add: %v", err)
}
}
time.Sleep(1100 * time.Millisecond)
if keys, err := agent.List(); err != nil {
t.Errorf("List: %v", err)
} else if len(keys) != 0 {
t.Errorf("Want 0 keys, got %v", len(keys))
}
}
type keyringExtended struct {
*keyring
}
func (r *keyringExtended) Extension(extensionType string, contents []byte) ([]byte, error) {
if extensionType != "my-extension@example.com" {
return []byte{agentExtensionFailure}, nil
}
return append([]byte{agentSuccess}, contents...), nil
}
func TestAgentExtensions(t *testing.T) {
agent, _, cleanup := startOpenSSHAgent(t)
defer cleanup()
result, err := agent.Extension("my-extension@example.com", []byte{0x00, 0x01, 0x02})
if err == nil {
t.Fatal("should have gotten agent extension failure")
}
agent, cleanup = startAgent(t, &keyringExtended{})
defer cleanup()
result, err = agent.Extension("my-extension@example.com", []byte{0x00, 0x01, 0x02})
if err != nil {
t.Fatalf("agent extension failure: %v", err)
}
if len(result) != 4 || !bytes.Equal(result, []byte{agentSuccess, 0x00, 0x01, 0x02}) {
t.Fatalf("agent extension result invalid: %v", result)
}
result, err = agent.Extension("bad-extension@example.com", []byte{0x00, 0x01, 0x02})
if err == nil {
t.Fatal("should have gotten agent extension failure")
}
}

41
vendor/golang.org/x/crypto/ssh/agent/example_test.go generated vendored Normal file
View file

@ -0,0 +1,41 @@
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package agent_test
import (
"log"
"net"
"os"
"golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/agent"
)
func ExampleClientAgent() {
// ssh-agent has a UNIX socket under $SSH_AUTH_SOCK
socket := os.Getenv("SSH_AUTH_SOCK")
conn, err := net.Dial("unix", socket)
if err != nil {
log.Fatalf("net.Dial: %v", err)
}
agentClient := agent.NewClient(conn)
config := &ssh.ClientConfig{
User: "username",
Auth: []ssh.AuthMethod{
// Use a callback rather than PublicKeys
// so we only consult the agent once the remote server
// wants it.
ssh.PublicKeysCallback(agentClient.Signers),
},
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
}
sshc, err := ssh.Dial("tcp", "localhost:22", config)
if err != nil {
log.Fatalf("Dial: %v", err)
}
// .. use sshc
sshc.Close()
}

View file

@ -11,6 +11,7 @@ import (
"errors"
"fmt"
"sync"
"time"
"golang.org/x/crypto/ssh"
)
@ -18,6 +19,7 @@ import (
type privKey struct {
signer ssh.Signer
comment string
expire *time.Time
}
type keyring struct {
@ -48,21 +50,15 @@ func (r *keyring) RemoveAll() error {
return nil
}
// Remove removes all identities with the given public key.
func (r *keyring) Remove(key ssh.PublicKey) error {
r.mu.Lock()
defer r.mu.Unlock()
if r.locked {
return errLocked
}
want := key.Marshal()
// removeLocked does the actual key removal. The caller must already be holding the
// keyring mutex.
func (r *keyring) removeLocked(want []byte) error {
found := false
for i := 0; i < len(r.keys); {
if bytes.Equal(r.keys[i].signer.PublicKey().Marshal(), want) {
found = true
r.keys[i] = r.keys[len(r.keys)-1]
r.keys = r.keys[len(r.keys)-1:]
r.keys = r.keys[:len(r.keys)-1]
continue
} else {
i++
@ -75,7 +71,18 @@ func (r *keyring) Remove(key ssh.PublicKey) error {
return nil
}
// Lock locks the agent. Sign and Remove will fail, and List will empty an empty list.
// Remove removes all identities with the given public key.
func (r *keyring) Remove(key ssh.PublicKey) error {
r.mu.Lock()
defer r.mu.Unlock()
if r.locked {
return errLocked
}
return r.removeLocked(key.Marshal())
}
// Lock locks the agent. Sign and Remove will fail, and List will return an empty list.
func (r *keyring) Lock(passphrase []byte) error {
r.mu.Lock()
defer r.mu.Unlock()
@ -95,7 +102,7 @@ func (r *keyring) Unlock(passphrase []byte) error {
if !r.locked {
return errors.New("agent: not locked")
}
if len(passphrase) != len(r.passphrase) || 1 != subtle.ConstantTimeCompare(passphrase, r.passphrase) {
if 1 != subtle.ConstantTimeCompare(passphrase, r.passphrase) {
return fmt.Errorf("agent: incorrect passphrase")
}
@ -104,6 +111,17 @@ func (r *keyring) Unlock(passphrase []byte) error {
return nil
}
// expireKeysLocked removes expired keys from the keyring. If a key was added
// with a lifetimesecs contraint and seconds >= lifetimesecs seconds have
// ellapsed, it is removed. The caller *must* be holding the keyring mutex.
func (r *keyring) expireKeysLocked() {
for _, k := range r.keys {
if k.expire != nil && time.Now().After(*k.expire) {
r.removeLocked(k.signer.PublicKey().Marshal())
}
}
}
// List returns the identities known to the agent.
func (r *keyring) List() ([]*Key, error) {
r.mu.Lock()
@ -113,6 +131,7 @@ func (r *keyring) List() ([]*Key, error) {
return nil, nil
}
r.expireKeysLocked()
var ids []*Key
for _, k := range r.keys {
pub := k.signer.PublicKey()
@ -125,43 +144,76 @@ func (r *keyring) List() ([]*Key, error) {
}
// Insert adds a private key to the keyring. If a certificate
// is given, that certificate is added as public key.
func (r *keyring) Add(priv interface{}, cert *ssh.Certificate, comment string) error {
// is given, that certificate is added as public key. Note that
// any constraints given are ignored.
func (r *keyring) Add(key AddedKey) error {
r.mu.Lock()
defer r.mu.Unlock()
if r.locked {
return errLocked
}
signer, err := ssh.NewSignerFromKey(priv)
signer, err := ssh.NewSignerFromKey(key.PrivateKey)
if err != nil {
return err
}
if cert != nil {
if cert := key.Certificate; cert != nil {
signer, err = ssh.NewCertSigner(cert, signer)
if err != nil {
return err
}
}
r.keys = append(r.keys, privKey{signer, comment})
p := privKey{
signer: signer,
comment: key.Comment,
}
if key.LifetimeSecs > 0 {
t := time.Now().Add(time.Duration(key.LifetimeSecs) * time.Second)
p.expire = &t
}
r.keys = append(r.keys, p)
return nil
}
// Sign returns a signature for the data.
func (r *keyring) Sign(key ssh.PublicKey, data []byte) (*ssh.Signature, error) {
return r.SignWithFlags(key, data, 0)
}
func (r *keyring) SignWithFlags(key ssh.PublicKey, data []byte, flags SignatureFlags) (*ssh.Signature, error) {
r.mu.Lock()
defer r.mu.Unlock()
if r.locked {
return nil, errLocked
}
r.expireKeysLocked()
wanted := key.Marshal()
for _, k := range r.keys {
if bytes.Equal(k.signer.PublicKey().Marshal(), wanted) {
return k.signer.Sign(rand.Reader, data)
if flags == 0 {
return k.signer.Sign(rand.Reader, data)
} else {
if algorithmSigner, ok := k.signer.(ssh.AlgorithmSigner); !ok {
return nil, fmt.Errorf("agent: signature does not support non-default signature algorithm: %T", k.signer)
} else {
var algorithm string
switch flags {
case SignatureFlagRsaSha256:
algorithm = ssh.SigAlgoRSASHA2256
case SignatureFlagRsaSha512:
algorithm = ssh.SigAlgoRSASHA2512
default:
return nil, fmt.Errorf("agent: unsupported signature flags: %d", flags)
}
return algorithmSigner.SignWithAlgorithm(rand.Reader, data, algorithm)
}
}
}
}
return nil, errors.New("not found")
@ -175,9 +227,15 @@ func (r *keyring) Signers() ([]ssh.Signer, error) {
return nil, errLocked
}
r.expireKeysLocked()
s := make([]ssh.Signer, 0, len(r.keys))
for _, k := range r.keys {
s = append(s, k.signer)
}
return s, nil
}
// The keyring does not support any extensions
func (r *keyring) Extension(extensionType string, contents []byte) ([]byte, error) {
return nil, ErrExtensionUnsupported
}

76
vendor/golang.org/x/crypto/ssh/agent/keyring_test.go generated vendored Normal file
View file

@ -0,0 +1,76 @@
// Copyright 2015 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package agent
import "testing"
func addTestKey(t *testing.T, a Agent, keyName string) {
err := a.Add(AddedKey{
PrivateKey: testPrivateKeys[keyName],
Comment: keyName,
})
if err != nil {
t.Fatalf("failed to add key %q: %v", keyName, err)
}
}
func removeTestKey(t *testing.T, a Agent, keyName string) {
err := a.Remove(testPublicKeys[keyName])
if err != nil {
t.Fatalf("failed to remove key %q: %v", keyName, err)
}
}
func validateListedKeys(t *testing.T, a Agent, expectedKeys []string) {
listedKeys, err := a.List()
if err != nil {
t.Fatalf("failed to list keys: %v", err)
return
}
actualKeys := make(map[string]bool)
for _, key := range listedKeys {
actualKeys[key.Comment] = true
}
matchedKeys := make(map[string]bool)
for _, expectedKey := range expectedKeys {
if !actualKeys[expectedKey] {
t.Fatalf("expected key %q, but was not found", expectedKey)
} else {
matchedKeys[expectedKey] = true
}
}
for actualKey := range actualKeys {
if !matchedKeys[actualKey] {
t.Fatalf("key %q was found, but was not expected", actualKey)
}
}
}
func TestKeyringAddingAndRemoving(t *testing.T) {
keyNames := []string{"dsa", "ecdsa", "rsa", "user"}
// add all test private keys
k := NewKeyring()
for _, keyName := range keyNames {
addTestKey(t, k, keyName)
}
validateListedKeys(t, k, keyNames)
// remove a key in the middle
keyToRemove := keyNames[1]
keyNames = append(keyNames[:1], keyNames[2:]...)
removeTestKey(t, k, keyToRemove)
validateListedKeys(t, k, keyNames)
// remove all keys
err := k.RemoveAll()
if err != nil {
t.Fatalf("failed to remove all keys: %v", err)
}
validateListedKeys(t, k, []string{})
}

View file

@ -5,13 +5,18 @@
package agent
import (
"crypto/dsa"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rsa"
"encoding/binary"
"errors"
"fmt"
"io"
"log"
"math/big"
"golang.org/x/crypto/ed25519"
"golang.org/x/crypto/ssh"
)
@ -49,6 +54,9 @@ func marshalKey(k *Key) []byte {
return ssh.Marshal(&record)
}
// See [PROTOCOL.agent], section 2.5.1.
const agentV1IdentitiesAnswer = 2
type agentV1IdentityMsg struct {
Numkeys uint32 `sshtype:"2"`
}
@ -69,6 +77,10 @@ func (s *server) processRequest(data []byte) (interface{}, error) {
switch data[0] {
case agentRequestV1Identities:
return &agentV1IdentityMsg{0}, nil
case agentRemoveAllV1Identities:
return nil, nil
case agentRemoveIdentity:
var req agentRemoveIdentityMsg
if err := ssh.Unmarshal(data, &req); err != nil {
@ -94,7 +106,7 @@ func (s *server) processRequest(data []byte) (interface{}, error) {
return nil, s.agent.Lock(req.Passphrase)
case agentUnlock:
var req agentLockMsg
var req agentUnlockMsg
if err := ssh.Unmarshal(data, &req); err != nil {
return nil, err
}
@ -116,11 +128,19 @@ func (s *server) processRequest(data []byte) (interface{}, error) {
Blob: req.KeyBlob,
}
sig, err := s.agent.Sign(k, req.Data) // TODO(hanwen): flags.
var sig *ssh.Signature
var err error
if extendedAgent, ok := s.agent.(ExtendedAgent); ok {
sig, err = extendedAgent.SignWithFlags(k, req.Data, SignatureFlags(req.Flags))
} else {
sig, err = s.agent.Sign(k, req.Data)
}
if err != nil {
return nil, err
}
return &signResponseAgentMsg{SigBlob: ssh.Marshal(sig)}, nil
case agentRequestIdentities:
keys, err := s.agent.List()
if err != nil {
@ -134,42 +154,380 @@ func (s *server) processRequest(data []byte) (interface{}, error) {
rep.Keys = append(rep.Keys, marshalKey(k)...)
}
return rep, nil
case agentAddIdentity:
case agentAddIDConstrained, agentAddIdentity:
return nil, s.insertIdentity(data)
case agentExtension:
// Return a stub object where the whole contents of the response gets marshaled.
var responseStub struct {
Rest []byte `ssh:"rest"`
}
if extendedAgent, ok := s.agent.(ExtendedAgent); !ok {
// If this agent doesn't implement extensions, [PROTOCOL.agent] section 4.7
// requires that we return a standard SSH_AGENT_FAILURE message.
responseStub.Rest = []byte{agentFailure}
} else {
var req extensionAgentMsg
if err := ssh.Unmarshal(data, &req); err != nil {
return nil, err
}
res, err := extendedAgent.Extension(req.ExtensionType, req.Contents)
if err != nil {
// If agent extensions are unsupported, return a standard SSH_AGENT_FAILURE
// message as required by [PROTOCOL.agent] section 4.7.
if err == ErrExtensionUnsupported {
responseStub.Rest = []byte{agentFailure}
} else {
// As the result of any other error processing an extension request,
// [PROTOCOL.agent] section 4.7 requires that we return a
// SSH_AGENT_EXTENSION_FAILURE code.
responseStub.Rest = []byte{agentExtensionFailure}
}
} else {
if len(res) == 0 {
return nil, nil
}
responseStub.Rest = res
}
}
return responseStub, nil
}
return nil, fmt.Errorf("unknown opcode %d", data[0])
}
func parseConstraints(constraints []byte) (lifetimeSecs uint32, confirmBeforeUse bool, extensions []ConstraintExtension, err error) {
for len(constraints) != 0 {
switch constraints[0] {
case agentConstrainLifetime:
lifetimeSecs = binary.BigEndian.Uint32(constraints[1:5])
constraints = constraints[5:]
case agentConstrainConfirm:
confirmBeforeUse = true
constraints = constraints[1:]
case agentConstrainExtension:
var msg constrainExtensionAgentMsg
if err = ssh.Unmarshal(constraints, &msg); err != nil {
return 0, false, nil, err
}
extensions = append(extensions, ConstraintExtension{
ExtensionName: msg.ExtensionName,
ExtensionDetails: msg.ExtensionDetails,
})
constraints = msg.Rest
default:
return 0, false, nil, fmt.Errorf("unknown constraint type: %d", constraints[0])
}
}
return
}
func setConstraints(key *AddedKey, constraintBytes []byte) error {
lifetimeSecs, confirmBeforeUse, constraintExtensions, err := parseConstraints(constraintBytes)
if err != nil {
return err
}
key.LifetimeSecs = lifetimeSecs
key.ConfirmBeforeUse = confirmBeforeUse
key.ConstraintExtensions = constraintExtensions
return nil
}
func parseRSAKey(req []byte) (*AddedKey, error) {
var k rsaKeyMsg
if err := ssh.Unmarshal(req, &k); err != nil {
return nil, err
}
if k.E.BitLen() > 30 {
return nil, errors.New("agent: RSA public exponent too large")
}
priv := &rsa.PrivateKey{
PublicKey: rsa.PublicKey{
E: int(k.E.Int64()),
N: k.N,
},
D: k.D,
Primes: []*big.Int{k.P, k.Q},
}
priv.Precompute()
addedKey := &AddedKey{PrivateKey: priv, Comment: k.Comments}
if err := setConstraints(addedKey, k.Constraints); err != nil {
return nil, err
}
return addedKey, nil
}
func parseEd25519Key(req []byte) (*AddedKey, error) {
var k ed25519KeyMsg
if err := ssh.Unmarshal(req, &k); err != nil {
return nil, err
}
priv := ed25519.PrivateKey(k.Priv)
addedKey := &AddedKey{PrivateKey: &priv, Comment: k.Comments}
if err := setConstraints(addedKey, k.Constraints); err != nil {
return nil, err
}
return addedKey, nil
}
func parseDSAKey(req []byte) (*AddedKey, error) {
var k dsaKeyMsg
if err := ssh.Unmarshal(req, &k); err != nil {
return nil, err
}
priv := &dsa.PrivateKey{
PublicKey: dsa.PublicKey{
Parameters: dsa.Parameters{
P: k.P,
Q: k.Q,
G: k.G,
},
Y: k.Y,
},
X: k.X,
}
addedKey := &AddedKey{PrivateKey: priv, Comment: k.Comments}
if err := setConstraints(addedKey, k.Constraints); err != nil {
return nil, err
}
return addedKey, nil
}
func unmarshalECDSA(curveName string, keyBytes []byte, privScalar *big.Int) (priv *ecdsa.PrivateKey, err error) {
priv = &ecdsa.PrivateKey{
D: privScalar,
}
switch curveName {
case "nistp256":
priv.Curve = elliptic.P256()
case "nistp384":
priv.Curve = elliptic.P384()
case "nistp521":
priv.Curve = elliptic.P521()
default:
return nil, fmt.Errorf("agent: unknown curve %q", curveName)
}
priv.X, priv.Y = elliptic.Unmarshal(priv.Curve, keyBytes)
if priv.X == nil || priv.Y == nil {
return nil, errors.New("agent: point not on curve")
}
return priv, nil
}
func parseEd25519Cert(req []byte) (*AddedKey, error) {
var k ed25519CertMsg
if err := ssh.Unmarshal(req, &k); err != nil {
return nil, err
}
pubKey, err := ssh.ParsePublicKey(k.CertBytes)
if err != nil {
return nil, err
}
priv := ed25519.PrivateKey(k.Priv)
cert, ok := pubKey.(*ssh.Certificate)
if !ok {
return nil, errors.New("agent: bad ED25519 certificate")
}
addedKey := &AddedKey{PrivateKey: &priv, Certificate: cert, Comment: k.Comments}
if err := setConstraints(addedKey, k.Constraints); err != nil {
return nil, err
}
return addedKey, nil
}
func parseECDSAKey(req []byte) (*AddedKey, error) {
var k ecdsaKeyMsg
if err := ssh.Unmarshal(req, &k); err != nil {
return nil, err
}
priv, err := unmarshalECDSA(k.Curve, k.KeyBytes, k.D)
if err != nil {
return nil, err
}
addedKey := &AddedKey{PrivateKey: priv, Comment: k.Comments}
if err := setConstraints(addedKey, k.Constraints); err != nil {
return nil, err
}
return addedKey, nil
}
func parseRSACert(req []byte) (*AddedKey, error) {
var k rsaCertMsg
if err := ssh.Unmarshal(req, &k); err != nil {
return nil, err
}
pubKey, err := ssh.ParsePublicKey(k.CertBytes)
if err != nil {
return nil, err
}
cert, ok := pubKey.(*ssh.Certificate)
if !ok {
return nil, errors.New("agent: bad RSA certificate")
}
// An RSA publickey as marshaled by rsaPublicKey.Marshal() in keys.go
var rsaPub struct {
Name string
E *big.Int
N *big.Int
}
if err := ssh.Unmarshal(cert.Key.Marshal(), &rsaPub); err != nil {
return nil, fmt.Errorf("agent: Unmarshal failed to parse public key: %v", err)
}
if rsaPub.E.BitLen() > 30 {
return nil, errors.New("agent: RSA public exponent too large")
}
priv := rsa.PrivateKey{
PublicKey: rsa.PublicKey{
E: int(rsaPub.E.Int64()),
N: rsaPub.N,
},
D: k.D,
Primes: []*big.Int{k.Q, k.P},
}
priv.Precompute()
addedKey := &AddedKey{PrivateKey: &priv, Certificate: cert, Comment: k.Comments}
if err := setConstraints(addedKey, k.Constraints); err != nil {
return nil, err
}
return addedKey, nil
}
func parseDSACert(req []byte) (*AddedKey, error) {
var k dsaCertMsg
if err := ssh.Unmarshal(req, &k); err != nil {
return nil, err
}
pubKey, err := ssh.ParsePublicKey(k.CertBytes)
if err != nil {
return nil, err
}
cert, ok := pubKey.(*ssh.Certificate)
if !ok {
return nil, errors.New("agent: bad DSA certificate")
}
// A DSA publickey as marshaled by dsaPublicKey.Marshal() in keys.go
var w struct {
Name string
P, Q, G, Y *big.Int
}
if err := ssh.Unmarshal(cert.Key.Marshal(), &w); err != nil {
return nil, fmt.Errorf("agent: Unmarshal failed to parse public key: %v", err)
}
priv := &dsa.PrivateKey{
PublicKey: dsa.PublicKey{
Parameters: dsa.Parameters{
P: w.P,
Q: w.Q,
G: w.G,
},
Y: w.Y,
},
X: k.X,
}
addedKey := &AddedKey{PrivateKey: priv, Certificate: cert, Comment: k.Comments}
if err := setConstraints(addedKey, k.Constraints); err != nil {
return nil, err
}
return addedKey, nil
}
func parseECDSACert(req []byte) (*AddedKey, error) {
var k ecdsaCertMsg
if err := ssh.Unmarshal(req, &k); err != nil {
return nil, err
}
pubKey, err := ssh.ParsePublicKey(k.CertBytes)
if err != nil {
return nil, err
}
cert, ok := pubKey.(*ssh.Certificate)
if !ok {
return nil, errors.New("agent: bad ECDSA certificate")
}
// An ECDSA publickey as marshaled by ecdsaPublicKey.Marshal() in keys.go
var ecdsaPub struct {
Name string
ID string
Key []byte
}
if err := ssh.Unmarshal(cert.Key.Marshal(), &ecdsaPub); err != nil {
return nil, err
}
priv, err := unmarshalECDSA(ecdsaPub.ID, ecdsaPub.Key, k.D)
if err != nil {
return nil, err
}
addedKey := &AddedKey{PrivateKey: priv, Certificate: cert, Comment: k.Comments}
if err := setConstraints(addedKey, k.Constraints); err != nil {
return nil, err
}
return addedKey, nil
}
func (s *server) insertIdentity(req []byte) error {
var record struct {
Type string `sshtype:"17"`
Type string `sshtype:"17|25"`
Rest []byte `ssh:"rest"`
}
if err := ssh.Unmarshal(req, &record); err != nil {
return err
}
var addedKey *AddedKey
var err error
switch record.Type {
case ssh.KeyAlgoRSA:
var k rsaKeyMsg
if err := ssh.Unmarshal(req, &k); err != nil {
return err
}
priv := rsa.PrivateKey{
PublicKey: rsa.PublicKey{
E: int(k.E.Int64()),
N: k.N,
},
D: k.D,
Primes: []*big.Int{k.P, k.Q},
}
priv.Precompute()
return s.agent.Add(&priv, nil, k.Comments)
addedKey, err = parseRSAKey(req)
case ssh.KeyAlgoDSA:
addedKey, err = parseDSAKey(req)
case ssh.KeyAlgoECDSA256, ssh.KeyAlgoECDSA384, ssh.KeyAlgoECDSA521:
addedKey, err = parseECDSAKey(req)
case ssh.KeyAlgoED25519:
addedKey, err = parseEd25519Key(req)
case ssh.CertAlgoRSAv01:
addedKey, err = parseRSACert(req)
case ssh.CertAlgoDSAv01:
addedKey, err = parseDSACert(req)
case ssh.CertAlgoECDSA256v01, ssh.CertAlgoECDSA384v01, ssh.CertAlgoECDSA521v01:
addedKey, err = parseECDSACert(req)
case ssh.CertAlgoED25519v01:
addedKey, err = parseEd25519Cert(req)
default:
return fmt.Errorf("agent: not implemented: %q", record.Type)
}
return fmt.Errorf("not implemented: %s", record.Type)
if err != nil {
return err
}
return s.agent.Add(*addedKey)
}
// ServeAgent serves the agent protocol on the given connection. It

View file

@ -5,6 +5,12 @@
package agent
import (
"crypto"
"crypto/rand"
"fmt"
pseudorand "math/rand"
"reflect"
"strings"
"testing"
"golang.org/x/crypto/ssh"
@ -21,7 +27,7 @@ func TestServer(t *testing.T) {
go ServeAgent(NewKeyring(), c2)
testAgentInterface(t, client, testPrivateKeys["rsa"], nil)
testAgentInterface(t, client, testPrivateKeys["rsa"], nil, 0)
}
func TestLockServer(t *testing.T) {
@ -37,7 +43,7 @@ func TestSetupForwardAgent(t *testing.T) {
defer a.Close()
defer b.Close()
_, socket, cleanup := startAgent(t)
_, socket, cleanup := startOpenSSHAgent(t)
defer cleanup()
serverConf := ssh.ServerConfig{
@ -53,7 +59,9 @@ func TestSetupForwardAgent(t *testing.T) {
incoming <- conn
}()
conf := ssh.ClientConfig{}
conf := ssh.ClientConfig{
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
}
conn, chans, reqs, err := ssh.NewClientConn(b, "", &conf)
if err != nil {
t.Fatalf("NewClientConn: %v", err)
@ -72,6 +80,180 @@ func TestSetupForwardAgent(t *testing.T) {
go ssh.DiscardRequests(reqs)
agentClient := NewClient(ch)
testAgentInterface(t, agentClient, testPrivateKeys["rsa"], nil)
testAgentInterface(t, agentClient, testPrivateKeys["rsa"], nil, 0)
conn.Close()
}
func TestV1ProtocolMessages(t *testing.T) {
c1, c2, err := netPipe()
if err != nil {
t.Fatalf("netPipe: %v", err)
}
defer c1.Close()
defer c2.Close()
c := NewClient(c1)
go ServeAgent(NewKeyring(), c2)
testV1ProtocolMessages(t, c.(*client))
}
func testV1ProtocolMessages(t *testing.T, c *client) {
reply, err := c.call([]byte{agentRequestV1Identities})
if err != nil {
t.Fatalf("v1 request all failed: %v", err)
}
if msg, ok := reply.(*agentV1IdentityMsg); !ok || msg.Numkeys != 0 {
t.Fatalf("invalid request all response: %#v", reply)
}
reply, err = c.call([]byte{agentRemoveAllV1Identities})
if err != nil {
t.Fatalf("v1 remove all failed: %v", err)
}
if _, ok := reply.(*successAgentMsg); !ok {
t.Fatalf("invalid remove all response: %#v", reply)
}
}
func verifyKey(sshAgent Agent) error {
keys, err := sshAgent.List()
if err != nil {
return fmt.Errorf("listing keys: %v", err)
}
if len(keys) != 1 {
return fmt.Errorf("bad number of keys found. expected 1, got %d", len(keys))
}
buf := make([]byte, 128)
if _, err := rand.Read(buf); err != nil {
return fmt.Errorf("rand: %v", err)
}
sig, err := sshAgent.Sign(keys[0], buf)
if err != nil {
return fmt.Errorf("sign: %v", err)
}
if err := keys[0].Verify(buf, sig); err != nil {
return fmt.Errorf("verify: %v", err)
}
return nil
}
func addKeyToAgent(key crypto.PrivateKey) error {
sshAgent := NewKeyring()
if err := sshAgent.Add(AddedKey{PrivateKey: key}); err != nil {
return fmt.Errorf("add: %v", err)
}
return verifyKey(sshAgent)
}
func TestKeyTypes(t *testing.T) {
for k, v := range testPrivateKeys {
if err := addKeyToAgent(v); err != nil {
t.Errorf("error adding key type %s, %v", k, err)
}
if err := addCertToAgentSock(v, nil); err != nil {
t.Errorf("error adding key type %s, %v", k, err)
}
}
}
func addCertToAgentSock(key crypto.PrivateKey, cert *ssh.Certificate) error {
a, b, err := netPipe()
if err != nil {
return err
}
agentServer := NewKeyring()
go ServeAgent(agentServer, a)
agentClient := NewClient(b)
if err := agentClient.Add(AddedKey{PrivateKey: key, Certificate: cert}); err != nil {
return fmt.Errorf("add: %v", err)
}
return verifyKey(agentClient)
}
func addCertToAgent(key crypto.PrivateKey, cert *ssh.Certificate) error {
sshAgent := NewKeyring()
if err := sshAgent.Add(AddedKey{PrivateKey: key, Certificate: cert}); err != nil {
return fmt.Errorf("add: %v", err)
}
return verifyKey(sshAgent)
}
func TestCertTypes(t *testing.T) {
for keyType, key := range testPublicKeys {
cert := &ssh.Certificate{
ValidPrincipals: []string{"gopher1"},
ValidAfter: 0,
ValidBefore: ssh.CertTimeInfinity,
Key: key,
Serial: 1,
CertType: ssh.UserCert,
SignatureKey: testPublicKeys["rsa"],
Permissions: ssh.Permissions{
CriticalOptions: map[string]string{},
Extensions: map[string]string{},
},
}
if err := cert.SignCert(rand.Reader, testSigners["rsa"]); err != nil {
t.Fatalf("signcert: %v", err)
}
if err := addCertToAgent(testPrivateKeys[keyType], cert); err != nil {
t.Fatalf("%v", err)
}
if err := addCertToAgentSock(testPrivateKeys[keyType], cert); err != nil {
t.Fatalf("%v", err)
}
}
}
func TestParseConstraints(t *testing.T) {
// Test LifetimeSecs
var msg = constrainLifetimeAgentMsg{pseudorand.Uint32()}
lifetimeSecs, _, _, err := parseConstraints(ssh.Marshal(msg))
if err != nil {
t.Fatalf("parseConstraints: %v", err)
}
if lifetimeSecs != msg.LifetimeSecs {
t.Errorf("got lifetime %v, want %v", lifetimeSecs, msg.LifetimeSecs)
}
// Test ConfirmBeforeUse
_, confirmBeforeUse, _, err := parseConstraints([]byte{agentConstrainConfirm})
if err != nil {
t.Fatalf("%v", err)
}
if !confirmBeforeUse {
t.Error("got comfirmBeforeUse == false")
}
// Test ConstraintExtensions
var data []byte
var expect []ConstraintExtension
for i := 0; i < 10; i++ {
var ext = ConstraintExtension{
ExtensionName: fmt.Sprintf("name%d", i),
ExtensionDetails: []byte(fmt.Sprintf("details: %d", i)),
}
expect = append(expect, ext)
data = append(data, agentConstrainExtension)
data = append(data, ssh.Marshal(ext)...)
}
_, _, extensions, err := parseConstraints(data)
if err != nil {
t.Fatalf("%v", err)
}
if !reflect.DeepEqual(expect, extensions) {
t.Errorf("got extension %v, want %v", extensions, expect)
}
// Test Unknown Constraint
_, _, _, err = parseConstraints([]byte{128})
if err == nil || !strings.Contains(err.Error(), "unknown constraint") {
t.Errorf("unexpected error: %v", err)
}
}

View file

@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// IMPLEMENTOR NOTE: To avoid a package loop, this file is in three places:
// IMPLEMENTATION NOTE: To avoid a package loop, this file is in three places:
// ssh/, ssh/agent, and ssh/test/. It should be kept in sync across all three
// instances.

View file

@ -40,7 +40,8 @@ func sshPipe() (Conn, *server, error) {
}
clientConf := ClientConfig{
User: "user",
User: "user",
HostKeyCallback: InsecureIgnoreHostKey(),
}
serverConf := ServerConfig{
NoClientAuth: true,

View file

@ -51,13 +51,12 @@ func (b *buffer) write(buf []byte) {
}
// eof closes the buffer. Reads from the buffer once all
// the data has been consumed will receive os.EOF.
func (b *buffer) eof() error {
// the data has been consumed will receive io.EOF.
func (b *buffer) eof() {
b.Cond.L.Lock()
b.closed = true
b.Cond.Signal()
b.Cond.L.Unlock()
return nil
}
// Read reads data from the internal buffer in buf. Reads will block

View file

@ -22,6 +22,7 @@ const (
CertAlgoECDSA256v01 = "ecdsa-sha2-nistp256-cert-v01@openssh.com"
CertAlgoECDSA384v01 = "ecdsa-sha2-nistp384-cert-v01@openssh.com"
CertAlgoECDSA521v01 = "ecdsa-sha2-nistp521-cert-v01@openssh.com"
CertAlgoED25519v01 = "ssh-ed25519-cert-v01@openssh.com"
)
// Certificate types distinguish between host and user
@ -43,7 +44,9 @@ type Signature struct {
const CertTimeInfinity = 1<<64 - 1
// An Certificate represents an OpenSSH certificate as defined in
// [PROTOCOL.certkeys]?rev=1.8.
// [PROTOCOL.certkeys]?rev=1.8. The Certificate type implements the
// PublicKey interface, so it can be unmarshaled using
// ParsePublicKey.
type Certificate struct {
Nonce []byte
Key PublicKey
@ -85,46 +88,73 @@ func marshalStringList(namelist []string) []byte {
return to
}
type optionsTuple struct {
Key string
Value []byte
}
type optionsTupleValue struct {
Value string
}
// serialize a map of critical options or extensions
// issue #10569 - per [PROTOCOL.certkeys] and SSH implementation,
// we need two length prefixes for a non-empty string value
func marshalTuples(tups map[string]string) []byte {
keys := make([]string, 0, len(tups))
for k := range tups {
keys = append(keys, k)
for key := range tups {
keys = append(keys, key)
}
sort.Strings(keys)
var r []byte
for _, k := range keys {
s := struct{ K, V string }{k, tups[k]}
r = append(r, Marshal(&s)...)
var ret []byte
for _, key := range keys {
s := optionsTuple{Key: key}
if value := tups[key]; len(value) > 0 {
s.Value = Marshal(&optionsTupleValue{value})
}
ret = append(ret, Marshal(&s)...)
}
return r
return ret
}
// issue #10569 - per [PROTOCOL.certkeys] and SSH implementation,
// we need two length prefixes for a non-empty option value
func parseTuples(in []byte) (map[string]string, error) {
tups := map[string]string{}
var lastKey string
var haveLastKey bool
for len(in) > 0 {
nameBytes, rest, ok := parseString(in)
if !ok {
return nil, errShortRead
}
data, rest, ok := parseString(rest)
if !ok {
return nil, errShortRead
}
name := string(nameBytes)
var key, val, extra []byte
var ok bool
if key, in, ok = parseString(in); !ok {
return nil, errShortRead
}
keyStr := string(key)
// according to [PROTOCOL.certkeys], the names must be in
// lexical order.
if haveLastKey && name <= lastKey {
if haveLastKey && keyStr <= lastKey {
return nil, fmt.Errorf("ssh: certificate options are not in lexical order")
}
lastKey, haveLastKey = name, true
tups[name] = string(data)
in = rest
lastKey, haveLastKey = keyStr, true
// the next field is a data field, which if non-empty has a string embedded
if val, in, ok = parseString(in); !ok {
return nil, errShortRead
}
if len(val) > 0 {
val, extra, ok = parseString(val)
if !ok {
return nil, errShortRead
}
if len(extra) > 0 {
return nil, fmt.Errorf("ssh: unexpected trailing data after certificate option value")
}
tups[keyStr] = string(val)
} else {
tups[keyStr] = ""
}
}
return tups, nil
}
@ -192,6 +222,11 @@ type openSSHCertSigner struct {
signer Signer
}
type algorithmOpenSSHCertSigner struct {
*openSSHCertSigner
algorithmSigner AlgorithmSigner
}
// NewCertSigner returns a Signer that signs with the given Certificate, whose
// private key is held by signer. It returns an error if the public key in cert
// doesn't match the key used by signer.
@ -200,7 +235,12 @@ func NewCertSigner(cert *Certificate, signer Signer) (Signer, error) {
return nil, errors.New("ssh: signer and cert have different public key")
}
return &openSSHCertSigner{cert, signer}, nil
if algorithmSigner, ok := signer.(AlgorithmSigner); ok {
return &algorithmOpenSSHCertSigner{
&openSSHCertSigner{cert, signer}, algorithmSigner}, nil
} else {
return &openSSHCertSigner{cert, signer}, nil
}
}
func (s *openSSHCertSigner) Sign(rand io.Reader, data []byte) (*Signature, error) {
@ -211,6 +251,10 @@ func (s *openSSHCertSigner) PublicKey() PublicKey {
return s.pub
}
func (s *algorithmOpenSSHCertSigner) SignWithAlgorithm(rand io.Reader, data []byte, algorithm string) (*Signature, error) {
return s.algorithmSigner.SignWithAlgorithm(rand, data, algorithm)
}
const sourceAddressCriticalOption = "source-address"
// CertChecker does the work of verifying a certificate. Its methods
@ -223,10 +267,18 @@ type CertChecker struct {
// for user certificates.
SupportedCriticalOptions []string
// IsAuthority should return true if the key is recognized as
// an authority. This allows for certificates to be signed by other
// certificates.
IsAuthority func(auth PublicKey) bool
// IsUserAuthority should return true if the key is recognized as an
// authority for the given user certificate. This allows for
// certificates to be signed by other certificates. This must be set
// if this CertChecker will be checking user certificates.
IsUserAuthority func(auth PublicKey) bool
// IsHostAuthority should report whether the key is recognized as
// an authority for this host. This allows for certificates to be
// signed by other keys, and for those other keys to only be valid
// signers for particular hostnames. This must be set if this
// CertChecker will be checking host certificates.
IsHostAuthority func(auth PublicKey, address string) bool
// Clock is used for verifying time stamps. If nil, time.Now
// is used.
@ -240,7 +292,7 @@ type CertChecker struct {
// HostKeyFallback is called when CertChecker.CheckHostKey encounters a
// public key that is not a certificate. It must implement host key
// validation or else, if nil, all such keys are rejected.
HostKeyFallback func(addr string, remote net.Addr, key PublicKey) error
HostKeyFallback HostKeyCallback
// IsRevoked is called for each certificate so that revocation checking
// can be implemented. It should return true if the given certificate
@ -262,8 +314,17 @@ func (c *CertChecker) CheckHostKey(addr string, remote net.Addr, key PublicKey)
if cert.CertType != HostCert {
return fmt.Errorf("ssh: certificate presented as a host key has type %d", cert.CertType)
}
if !c.IsHostAuthority(cert.SignatureKey, addr) {
return fmt.Errorf("ssh: no authorities for hostname: %v", addr)
}
return c.CheckCert(addr, cert)
hostname, _, err := net.SplitHostPort(addr)
if err != nil {
return err
}
// Pass hostname only as principal for host certificates (consistent with OpenSSH)
return c.CheckCert(hostname, cert)
}
// Authenticate checks a user certificate. Authenticate can be used as
@ -280,6 +341,9 @@ func (c *CertChecker) Authenticate(conn ConnMetadata, pubKey PublicKey) (*Permis
if cert.CertType != UserCert {
return nil, fmt.Errorf("ssh: cert has type %d", cert.CertType)
}
if !c.IsUserAuthority(cert.SignatureKey) {
return nil, fmt.Errorf("ssh: certificate signed by unrecognized authority")
}
if err := c.CheckCert(conn.User(), cert); err != nil {
return nil, err
@ -292,10 +356,10 @@ func (c *CertChecker) Authenticate(conn ConnMetadata, pubKey PublicKey) (*Permis
// the signature of the certificate.
func (c *CertChecker) CheckCert(principal string, cert *Certificate) error {
if c.IsRevoked != nil && c.IsRevoked(cert) {
return fmt.Errorf("ssh: certicate serial %d revoked", cert.Serial)
return fmt.Errorf("ssh: certificate serial %d revoked", cert.Serial)
}
for opt, _ := range cert.CriticalOptions {
for opt := range cert.CriticalOptions {
// sourceAddressCriticalOption will be enforced by
// serverAuthenticate
if opt == sourceAddressCriticalOption {
@ -328,10 +392,6 @@ func (c *CertChecker) CheckCert(principal string, cert *Certificate) error {
}
}
if !c.IsAuthority(cert.SignatureKey) {
return fmt.Errorf("ssh: certificate signed by unrecognized authority")
}
clock := c.Clock
if clock == nil {
clock = time.Now
@ -341,7 +401,7 @@ func (c *CertChecker) CheckCert(principal string, cert *Certificate) error {
if after := int64(cert.ValidAfter); after < 0 || unixNow < int64(cert.ValidAfter) {
return fmt.Errorf("ssh: cert is not yet valid")
}
if before := int64(cert.ValidBefore); cert.ValidBefore != CertTimeInfinity && (unixNow >= before || before < 0) {
if before := int64(cert.ValidBefore); cert.ValidBefore != uint64(CertTimeInfinity) && (unixNow >= before || before < 0) {
return fmt.Errorf("ssh: cert has expired")
}
if err := cert.SignatureKey.Verify(cert.bytesForSigning(), cert.Signature); err != nil {
@ -374,6 +434,7 @@ var certAlgoNames = map[string]string{
KeyAlgoECDSA256: CertAlgoECDSA256v01,
KeyAlgoECDSA384: CertAlgoECDSA384v01,
KeyAlgoECDSA521: CertAlgoECDSA521v01,
KeyAlgoED25519: CertAlgoED25519v01,
}
// certToPrivAlgo returns the underlying algorithm for a certificate algorithm.
@ -432,7 +493,7 @@ func (c *Certificate) Marshal() []byte {
func (c *Certificate) Type() string {
algo, ok := certAlgoNames[c.Key.Type()]
if !ok {
panic("unknown cert key type")
panic("unknown cert key type " + c.Key.Type())
}
return algo
}

View file

@ -6,14 +6,20 @@ package ssh
import (
"bytes"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"net"
"reflect"
"testing"
"time"
"golang.org/x/crypto/ssh/testdata"
)
// Cert generated by ssh-keygen 6.0p1 Debian-4.
// % ssh-keygen -s ca-key -I test user-key
var exampleSSHCert = `ssh-rsa-cert-v01@openssh.com AAAAHHNzaC1yc2EtY2VydC12MDFAb3BlbnNzaC5jb20AAAAgb1srW/W3ZDjYAO45xLYAwzHBDLsJ4Ux6ICFIkTjb1LEAAAADAQABAAAAYQCkoR51poH0wE8w72cqSB8Sszx+vAhzcMdCO0wqHTj7UNENHWEXGrU0E0UQekD7U+yhkhtoyjbPOVIP7hNa6aRk/ezdh/iUnCIt4Jt1v3Z1h1P+hA4QuYFMHNB+rmjPwAcAAAAAAAAAAAAAAAEAAAAEdGVzdAAAAAAAAAAAAAAAAP//////////AAAAAAAAAIIAAAAVcGVybWl0LVgxMS1mb3J3YXJkaW5nAAAAAAAAABdwZXJtaXQtYWdlbnQtZm9yd2FyZGluZwAAAAAAAAAWcGVybWl0LXBvcnQtZm9yd2FyZGluZwAAAAAAAAAKcGVybWl0LXB0eQAAAAAAAAAOcGVybWl0LXVzZXItcmMAAAAAAAAAAAAAAHcAAAAHc3NoLXJzYQAAAAMBAAEAAABhANFS2kaktpSGc+CcmEKPyw9mJC4nZKxHKTgLVZeaGbFZOvJTNzBspQHdy7Q1uKSfktxpgjZnksiu/tFF9ngyY2KFoc+U88ya95IZUycBGCUbBQ8+bhDtw/icdDGQD5WnUwAAAG8AAAAHc3NoLXJzYQAAAGC8Y9Z2LQKhIhxf52773XaWrXdxP0t3GBVo4A10vUWiYoAGepr6rQIoGGXFxT4B9Gp+nEBJjOwKDXPrAevow0T9ca8gZN+0ykbhSrXLE5Ao48rqr3zP4O1/9P7e6gp0gw8=`
const exampleSSHCert = `ssh-rsa-cert-v01@openssh.com AAAAHHNzaC1yc2EtY2VydC12MDFAb3BlbnNzaC5jb20AAAAgb1srW/W3ZDjYAO45xLYAwzHBDLsJ4Ux6ICFIkTjb1LEAAAADAQABAAAAYQCkoR51poH0wE8w72cqSB8Sszx+vAhzcMdCO0wqHTj7UNENHWEXGrU0E0UQekD7U+yhkhtoyjbPOVIP7hNa6aRk/ezdh/iUnCIt4Jt1v3Z1h1P+hA4QuYFMHNB+rmjPwAcAAAAAAAAAAAAAAAEAAAAEdGVzdAAAAAAAAAAAAAAAAP//////////AAAAAAAAAIIAAAAVcGVybWl0LVgxMS1mb3J3YXJkaW5nAAAAAAAAABdwZXJtaXQtYWdlbnQtZm9yd2FyZGluZwAAAAAAAAAWcGVybWl0LXBvcnQtZm9yd2FyZGluZwAAAAAAAAAKcGVybWl0LXB0eQAAAAAAAAAOcGVybWl0LXVzZXItcmMAAAAAAAAAAAAAAHcAAAAHc3NoLXJzYQAAAAMBAAEAAABhANFS2kaktpSGc+CcmEKPyw9mJC4nZKxHKTgLVZeaGbFZOvJTNzBspQHdy7Q1uKSfktxpgjZnksiu/tFF9ngyY2KFoc+U88ya95IZUycBGCUbBQ8+bhDtw/icdDGQD5WnUwAAAG8AAAAHc3NoLXJzYQAAAGC8Y9Z2LQKhIhxf52773XaWrXdxP0t3GBVo4A10vUWiYoAGepr6rQIoGGXFxT4B9Gp+nEBJjOwKDXPrAevow0T9ca8gZN+0ykbhSrXLE5Ao48rqr3zP4O1/9P7e6gp0gw8=`
func TestParseCert(t *testing.T) {
authKeyBytes := []byte(exampleSSHCert)
@ -27,7 +33,7 @@ func TestParseCert(t *testing.T) {
}
if _, ok := key.(*Certificate); !ok {
t.Fatalf("got %#v, want *Certificate", key)
t.Fatalf("got %v (%T), want *Certificate", key, key)
}
marshaled := MarshalAuthorizedKey(key)
@ -39,6 +45,60 @@ func TestParseCert(t *testing.T) {
}
}
// Cert generated by ssh-keygen OpenSSH_6.8p1 OS X 10.10.3
// % ssh-keygen -s ca -I testcert -O source-address=192.168.1.0/24 -O force-command=/bin/sleep user.pub
// user.pub key: ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQDACh1rt2DXfV3hk6fszSQcQ/rueMId0kVD9U7nl8cfEnFxqOCrNT92g4laQIGl2mn8lsGZfTLg8ksHq3gkvgO3oo/0wHy4v32JeBOHTsN5AL4gfHNEhWeWb50ev47hnTsRIt9P4dxogeUo/hTu7j9+s9lLpEQXCvq6xocXQt0j8MV9qZBBXFLXVT3cWIkSqOdwt/5ZBg+1GSrc7WfCXVWgTk4a20uPMuJPxU4RQwZW6X3+O8Pqo8C3cW0OzZRFP6gUYUKUsTI5WntlS+LAxgw1mZNsozFGdbiOPRnEryE3SRldh9vjDR3tin1fGpA5P7+CEB/bqaXtG3V+F2OkqaMN
// Critical Options:
// force-command /bin/sleep
// source-address 192.168.1.0/24
// Extensions:
// permit-X11-forwarding
// permit-agent-forwarding
// permit-port-forwarding
// permit-pty
// permit-user-rc
const exampleSSHCertWithOptions = `ssh-rsa-cert-v01@openssh.com AAAAHHNzaC1yc2EtY2VydC12MDFAb3BlbnNzaC5jb20AAAAgDyysCJY0XrO1n03EeRRoITnTPdjENFmWDs9X58PP3VUAAAADAQABAAABAQDACh1rt2DXfV3hk6fszSQcQ/rueMId0kVD9U7nl8cfEnFxqOCrNT92g4laQIGl2mn8lsGZfTLg8ksHq3gkvgO3oo/0wHy4v32JeBOHTsN5AL4gfHNEhWeWb50ev47hnTsRIt9P4dxogeUo/hTu7j9+s9lLpEQXCvq6xocXQt0j8MV9qZBBXFLXVT3cWIkSqOdwt/5ZBg+1GSrc7WfCXVWgTk4a20uPMuJPxU4RQwZW6X3+O8Pqo8C3cW0OzZRFP6gUYUKUsTI5WntlS+LAxgw1mZNsozFGdbiOPRnEryE3SRldh9vjDR3tin1fGpA5P7+CEB/bqaXtG3V+F2OkqaMNAAAAAAAAAAAAAAABAAAACHRlc3RjZXJ0AAAAAAAAAAAAAAAA//////////8AAABLAAAADWZvcmNlLWNvbW1hbmQAAAAOAAAACi9iaW4vc2xlZXAAAAAOc291cmNlLWFkZHJlc3MAAAASAAAADjE5Mi4xNjguMS4wLzI0AAAAggAAABVwZXJtaXQtWDExLWZvcndhcmRpbmcAAAAAAAAAF3Blcm1pdC1hZ2VudC1mb3J3YXJkaW5nAAAAAAAAABZwZXJtaXQtcG9ydC1mb3J3YXJkaW5nAAAAAAAAAApwZXJtaXQtcHR5AAAAAAAAAA5wZXJtaXQtdXNlci1yYwAAAAAAAAAAAAABFwAAAAdzc2gtcnNhAAAAAwEAAQAAAQEAwU+c5ui5A8+J/CFpjW8wCa52bEODA808WWQDCSuTG/eMXNf59v9Y8Pk0F1E9dGCosSNyVcB/hacUrc6He+i97+HJCyKavBsE6GDxrjRyxYqAlfcOXi/IVmaUGiO8OQ39d4GHrjToInKvExSUeleQyH4Y4/e27T/pILAqPFL3fyrvMLT5qU9QyIt6zIpa7GBP5+urouNavMprV3zsfIqNBbWypinOQAw823a5wN+zwXnhZrgQiHZ/USG09Y6k98y1dTVz8YHlQVR4D3lpTAsKDKJ5hCH9WU4fdf+lU8OyNGaJ/vz0XNqxcToe1l4numLTnaoSuH89pHryjqurB7lJKwAAAQ8AAAAHc3NoLXJzYQAAAQCaHvUIoPL1zWUHIXLvu96/HU1s/i4CAW2IIEuGgxCUCiFj6vyTyYtgxQxcmbfZf6eaITlS6XJZa7Qq4iaFZh75C1DXTX8labXhRSD4E2t//AIP9MC1rtQC5xo6FmbQ+BoKcDskr+mNACcbRSxs3IL3bwCfWDnIw2WbVox9ZdcthJKk4UoCW4ix4QwdHw7zlddlz++fGEEVhmTbll1SUkycGApPFBsAYRTMupUJcYPIeReBI/m8XfkoMk99bV8ZJQTAd7OekHY2/48Ff53jLmyDjP7kNw1F8OaPtkFs6dGJXta4krmaekPy87j+35In5hFj7yoOqvSbmYUkeX70/GGQ`
func TestParseCertWithOptions(t *testing.T) {
opts := map[string]string{
"source-address": "192.168.1.0/24",
"force-command": "/bin/sleep",
}
exts := map[string]string{
"permit-X11-forwarding": "",
"permit-agent-forwarding": "",
"permit-port-forwarding": "",
"permit-pty": "",
"permit-user-rc": "",
}
authKeyBytes := []byte(exampleSSHCertWithOptions)
key, _, _, rest, err := ParseAuthorizedKey(authKeyBytes)
if err != nil {
t.Fatalf("ParseAuthorizedKey: %v", err)
}
if len(rest) > 0 {
t.Errorf("rest: got %q, want empty", rest)
}
cert, ok := key.(*Certificate)
if !ok {
t.Fatalf("got %v (%T), want *Certificate", key, key)
}
if !reflect.DeepEqual(cert.CriticalOptions, opts) {
t.Errorf("unexpected critical options - got %v, want %v", cert.CriticalOptions, opts)
}
if !reflect.DeepEqual(cert.Extensions, exts) {
t.Errorf("unexpected Extensions - got %v, want %v", cert.Extensions, exts)
}
marshaled := MarshalAuthorizedKey(key)
// Before comparison, remove the trailing newline that
// MarshalAuthorizedKey adds.
marshaled = marshaled[:len(marshaled)-1]
if !bytes.Equal(authKeyBytes, marshaled) {
t.Errorf("marshaled certificate does not match original: got %q, want %q", marshaled, authKeyBytes)
}
}
func TestValidateCert(t *testing.T) {
key, _, _, _, err := ParseAuthorizedKey([]byte(exampleSSHCert))
if err != nil {
@ -49,7 +109,7 @@ func TestValidateCert(t *testing.T) {
t.Fatalf("got %v (%T), want *Certificate", key, key)
}
checker := CertChecker{}
checker.IsAuthority = func(k PublicKey) bool {
checker.IsUserAuthority = func(k PublicKey) bool {
return bytes.Equal(k.Marshal(), validCert.SignatureKey.Marshal())
}
@ -87,7 +147,7 @@ func TestValidateCertTime(t *testing.T) {
checker := CertChecker{
Clock: func() time.Time { return time.Unix(ts, 0) },
}
checker.IsAuthority = func(k PublicKey) bool {
checker.IsUserAuthority = func(k PublicKey) bool {
return bytes.Equal(k.Marshal(),
testPublicKeys["ecdsa"].Marshal())
}
@ -105,7 +165,7 @@ func TestValidateCertTime(t *testing.T) {
func TestHostKeyCert(t *testing.T) {
cert := &Certificate{
ValidPrincipals: []string{"hostname", "hostname.domain"},
ValidPrincipals: []string{"hostname", "hostname.domain", "otherhost"},
Key: testPublicKeys["rsa"],
ValidBefore: CertTimeInfinity,
CertType: HostCert,
@ -113,8 +173,8 @@ func TestHostKeyCert(t *testing.T) {
cert.SignCert(rand.Reader, testSigners["ecdsa"])
checker := &CertChecker{
IsAuthority: func(p PublicKey) bool {
return bytes.Equal(testPublicKeys["ecdsa"].Marshal(), p.Marshal())
IsHostAuthority: func(p PublicKey, addr string) bool {
return addr == "hostname:22" && bytes.Equal(testPublicKeys["ecdsa"].Marshal(), p.Marshal())
},
}
@ -123,7 +183,14 @@ func TestHostKeyCert(t *testing.T) {
t.Errorf("NewCertSigner: %v", err)
}
for _, name := range []string{"hostname", "otherhost"} {
for _, test := range []struct {
addr string
succeed bool
}{
{addr: "hostname:22", succeed: true},
{addr: "otherhost:22", succeed: false}, // The certificate is valid for 'otherhost' as hostname, but we only recognize the authority of the signer for the address 'hostname:22'
{addr: "lasthost:22", succeed: false},
} {
c1, c2, err := netPipe()
if err != nil {
t.Fatalf("netPipe: %v", err)
@ -131,26 +198,138 @@ func TestHostKeyCert(t *testing.T) {
defer c1.Close()
defer c2.Close()
errc := make(chan error)
go func() {
conf := ServerConfig{
NoClientAuth: true,
}
conf.AddHostKey(certSigner)
_, _, _, err := NewServerConn(c1, &conf)
if err != nil {
t.Fatalf("NewServerConn: %v", err)
}
errc <- err
}()
config := &ClientConfig{
User: "user",
HostKeyCallback: checker.CheckHostKey,
}
_, _, _, err = NewClientConn(c2, name, config)
_, _, _, err = NewClientConn(c2, test.addr, config)
succeed := name == "hostname"
if (err == nil) != succeed {
t.Fatalf("NewClientConn(%q): %v", name, err)
if (err == nil) != test.succeed {
t.Fatalf("NewClientConn(%q): %v", test.addr, err)
}
err = <-errc
if (err == nil) != test.succeed {
t.Fatalf("NewServerConn(%q): %v", test.addr, err)
}
}
}
func TestCertTypes(t *testing.T) {
var testVars = []struct {
name string
keys func() Signer
}{
{
name: CertAlgoECDSA256v01,
keys: func() Signer {
s, _ := ParsePrivateKey(testdata.PEMBytes["ecdsap256"])
return s
},
},
{
name: CertAlgoECDSA384v01,
keys: func() Signer {
s, _ := ParsePrivateKey(testdata.PEMBytes["ecdsap384"])
return s
},
},
{
name: CertAlgoECDSA521v01,
keys: func() Signer {
s, _ := ParsePrivateKey(testdata.PEMBytes["ecdsap521"])
return s
},
},
{
name: CertAlgoED25519v01,
keys: func() Signer {
s, _ := ParsePrivateKey(testdata.PEMBytes["ed25519"])
return s
},
},
{
name: CertAlgoRSAv01,
keys: func() Signer {
s, _ := ParsePrivateKey(testdata.PEMBytes["rsa"])
return s
},
},
{
name: CertAlgoDSAv01,
keys: func() Signer {
s, _ := ParsePrivateKey(testdata.PEMBytes["dsa"])
return s
},
},
}
k, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
t.Fatalf("error generating host key: %v", err)
}
signer, err := NewSignerFromKey(k)
if err != nil {
t.Fatalf("error generating signer for ssh listener: %v", err)
}
conf := &ServerConfig{
PublicKeyCallback: func(c ConnMetadata, k PublicKey) (*Permissions, error) {
return new(Permissions), nil
},
}
conf.AddHostKey(signer)
for _, m := range testVars {
t.Run(m.name, func(t *testing.T) {
c1, c2, err := netPipe()
if err != nil {
t.Fatalf("netPipe: %v", err)
}
defer c1.Close()
defer c2.Close()
go NewServerConn(c1, conf)
priv := m.keys()
if err != nil {
t.Fatalf("error generating ssh pubkey: %v", err)
}
cert := &Certificate{
CertType: UserCert,
Key: priv.PublicKey(),
}
cert.SignCert(rand.Reader, priv)
certSigner, err := NewCertSigner(cert, priv)
if err != nil {
t.Fatalf("error generating cert signer: %v", err)
}
config := &ClientConfig{
User: "user",
HostKeyCallback: func(h string, r net.Addr, k PublicKey) error { return nil },
Auth: []AuthMethod{PublicKeys(certSigner)},
}
_, _, _, err = NewClientConn(c2, "", config)
if err != nil {
t.Fatalf("error connecting: %v", err)
}
})
}
}

View file

@ -67,6 +67,8 @@ type Channel interface {
// boolean, otherwise the return value will be false. Channel
// requests are out-of-band messages so they may be sent even
// if the data stream is closed or blocked by flow control.
// If the channel is closed before a reply is returned, io.EOF
// is returned.
SendRequest(name string, wantReply bool, payload []byte) (bool, error)
// Stderr returns an io.ReadWriter that writes to this channel
@ -203,32 +205,32 @@ type channel struct {
// writePacket sends a packet. If the packet is a channel close, it updates
// sentClose. This method takes the lock c.writeMu.
func (c *channel) writePacket(packet []byte) error {
c.writeMu.Lock()
if c.sentClose {
c.writeMu.Unlock()
func (ch *channel) writePacket(packet []byte) error {
ch.writeMu.Lock()
if ch.sentClose {
ch.writeMu.Unlock()
return io.EOF
}
c.sentClose = (packet[0] == msgChannelClose)
err := c.mux.conn.writePacket(packet)
c.writeMu.Unlock()
ch.sentClose = (packet[0] == msgChannelClose)
err := ch.mux.conn.writePacket(packet)
ch.writeMu.Unlock()
return err
}
func (c *channel) sendMessage(msg interface{}) error {
func (ch *channel) sendMessage(msg interface{}) error {
if debugMux {
log.Printf("send %d: %#v", c.mux.chanList.offset, msg)
log.Printf("send(%d): %#v", ch.mux.chanList.offset, msg)
}
p := Marshal(msg)
binary.BigEndian.PutUint32(p[1:], c.remoteId)
return c.writePacket(p)
binary.BigEndian.PutUint32(p[1:], ch.remoteId)
return ch.writePacket(p)
}
// WriteExtended writes data to a specific extended stream. These streams are
// used, for example, for stderr.
func (c *channel) WriteExtended(data []byte, extendedCode uint32) (n int, err error) {
if c.sentEOF {
func (ch *channel) WriteExtended(data []byte, extendedCode uint32) (n int, err error) {
if ch.sentEOF {
return 0, io.EOF
}
// 1 byte message type, 4 bytes remoteId, 4 bytes data length
@ -239,16 +241,16 @@ func (c *channel) WriteExtended(data []byte, extendedCode uint32) (n int, err er
opCode = msgChannelExtendedData
}
c.writeMu.Lock()
packet := c.packetPool[extendedCode]
ch.writeMu.Lock()
packet := ch.packetPool[extendedCode]
// We don't remove the buffer from packetPool, so
// WriteExtended calls from different goroutines will be
// flagged as errors by the race detector.
c.writeMu.Unlock()
ch.writeMu.Unlock()
for len(data) > 0 {
space := min(c.maxRemotePayload, len(data))
if space, err = c.remoteWin.reserve(space); err != nil {
space := min(ch.maxRemotePayload, len(data))
if space, err = ch.remoteWin.reserve(space); err != nil {
return n, err
}
if want := headerLength + space; uint32(cap(packet)) < want {
@ -260,13 +262,13 @@ func (c *channel) WriteExtended(data []byte, extendedCode uint32) (n int, err er
todo := data[:space]
packet[0] = opCode
binary.BigEndian.PutUint32(packet[1:], c.remoteId)
binary.BigEndian.PutUint32(packet[1:], ch.remoteId)
if extendedCode > 0 {
binary.BigEndian.PutUint32(packet[5:], uint32(extendedCode))
}
binary.BigEndian.PutUint32(packet[headerLength-4:], uint32(len(todo)))
copy(packet[headerLength:], todo)
if err = c.writePacket(packet); err != nil {
if err = ch.writePacket(packet); err != nil {
return n, err
}
@ -274,14 +276,14 @@ func (c *channel) WriteExtended(data []byte, extendedCode uint32) (n int, err er
data = data[len(todo):]
}
c.writeMu.Lock()
c.packetPool[extendedCode] = packet
c.writeMu.Unlock()
ch.writeMu.Lock()
ch.packetPool[extendedCode] = packet
ch.writeMu.Unlock()
return n, err
}
func (c *channel) handleData(packet []byte) error {
func (ch *channel) handleData(packet []byte) error {
headerLen := 9
isExtendedData := packet[0] == msgChannelExtendedData
if isExtendedData {
@ -301,7 +303,7 @@ func (c *channel) handleData(packet []byte) error {
if length == 0 {
return nil
}
if length > c.maxIncomingPayload {
if length > ch.maxIncomingPayload {
// TODO(hanwen): should send Disconnect?
return errors.New("ssh: incoming packet exceeds maximum payload size")
}
@ -311,21 +313,21 @@ func (c *channel) handleData(packet []byte) error {
return errors.New("ssh: wrong packet length")
}
c.windowMu.Lock()
if c.myWindow < length {
c.windowMu.Unlock()
ch.windowMu.Lock()
if ch.myWindow < length {
ch.windowMu.Unlock()
// TODO(hanwen): should send Disconnect with reason?
return errors.New("ssh: remote side wrote too much")
}
c.myWindow -= length
c.windowMu.Unlock()
ch.myWindow -= length
ch.windowMu.Unlock()
if extended == 1 {
c.extPending.write(data)
ch.extPending.write(data)
} else if extended > 0 {
// discard other extended data.
} else {
c.pending.write(data)
ch.pending.write(data)
}
return nil
}
@ -371,7 +373,7 @@ func (c *channel) close() {
close(c.msg)
close(c.incomingRequests)
c.writeMu.Lock()
// This is not necesary for a normal channel teardown, but if
// This is not necessary for a normal channel teardown, but if
// there was another error, it is.
c.sentClose = true
c.writeMu.Unlock()
@ -382,31 +384,31 @@ func (c *channel) close() {
// responseMessageReceived is called when a success or failure message is
// received on a channel to check that such a message is reasonable for the
// given channel.
func (c *channel) responseMessageReceived() error {
if c.direction == channelInbound {
func (ch *channel) responseMessageReceived() error {
if ch.direction == channelInbound {
return errors.New("ssh: channel response message received on inbound channel")
}
if c.decided {
if ch.decided {
return errors.New("ssh: duplicate response received for channel")
}
c.decided = true
ch.decided = true
return nil
}
func (c *channel) handlePacket(packet []byte) error {
func (ch *channel) handlePacket(packet []byte) error {
switch packet[0] {
case msgChannelData, msgChannelExtendedData:
return c.handleData(packet)
return ch.handleData(packet)
case msgChannelClose:
c.sendMessage(channelCloseMsg{PeersId: c.remoteId})
c.mux.chanList.remove(c.localId)
c.close()
ch.sendMessage(channelCloseMsg{PeersID: ch.remoteId})
ch.mux.chanList.remove(ch.localId)
ch.close()
return nil
case msgChannelEOF:
// RFC 4254 is mute on how EOF affects dataExt messages but
// it is logical to signal EOF at the same time.
c.extPending.eof()
c.pending.eof()
ch.extPending.eof()
ch.pending.eof()
return nil
}
@ -417,24 +419,24 @@ func (c *channel) handlePacket(packet []byte) error {
switch msg := decoded.(type) {
case *channelOpenFailureMsg:
if err := c.responseMessageReceived(); err != nil {
if err := ch.responseMessageReceived(); err != nil {
return err
}
c.mux.chanList.remove(msg.PeersId)
c.msg <- msg
ch.mux.chanList.remove(msg.PeersID)
ch.msg <- msg
case *channelOpenConfirmMsg:
if err := c.responseMessageReceived(); err != nil {
if err := ch.responseMessageReceived(); err != nil {
return err
}
if msg.MaxPacketSize < minPacketLength || msg.MaxPacketSize > 1<<31 {
return fmt.Errorf("ssh: invalid MaxPacketSize %d from peer", msg.MaxPacketSize)
}
c.remoteId = msg.MyId
c.maxRemotePayload = msg.MaxPacketSize
c.remoteWin.add(msg.MyWindow)
c.msg <- msg
ch.remoteId = msg.MyID
ch.maxRemotePayload = msg.MaxPacketSize
ch.remoteWin.add(msg.MyWindow)
ch.msg <- msg
case *windowAdjustMsg:
if !c.remoteWin.add(msg.AdditionalBytes) {
if !ch.remoteWin.add(msg.AdditionalBytes) {
return fmt.Errorf("ssh: invalid window update for %d bytes", msg.AdditionalBytes)
}
case *channelRequestMsg:
@ -442,12 +444,12 @@ func (c *channel) handlePacket(packet []byte) error {
Type: msg.Request,
WantReply: msg.WantReply,
Payload: msg.RequestSpecificData,
ch: c,
ch: ch,
}
c.incomingRequests <- &req
ch.incomingRequests <- &req
default:
c.msg <- msg
ch.msg <- msg
}
return nil
}
@ -459,8 +461,8 @@ func (m *mux) newChannel(chanType string, direction channelDirection, extraData
pending: newBuffer(),
extPending: newBuffer(),
direction: direction,
incomingRequests: make(chan *Request, 16),
msg: make(chan interface{}, 16),
incomingRequests: make(chan *Request, chanSize),
msg: make(chan interface{}, chanSize),
chanType: chanType,
extraData: extraData,
mux: m,
@ -486,23 +488,23 @@ func (e *extChannel) Read(data []byte) (n int, err error) {
return e.ch.ReadExtended(data, e.code)
}
func (c *channel) Accept() (Channel, <-chan *Request, error) {
if c.decided {
func (ch *channel) Accept() (Channel, <-chan *Request, error) {
if ch.decided {
return nil, nil, errDecidedAlready
}
c.maxIncomingPayload = channelMaxPacket
ch.maxIncomingPayload = channelMaxPacket
confirm := channelOpenConfirmMsg{
PeersId: c.remoteId,
MyId: c.localId,
MyWindow: c.myWindow,
MaxPacketSize: c.maxIncomingPayload,
PeersID: ch.remoteId,
MyID: ch.localId,
MyWindow: ch.myWindow,
MaxPacketSize: ch.maxIncomingPayload,
}
c.decided = true
if err := c.sendMessage(confirm); err != nil {
ch.decided = true
if err := ch.sendMessage(confirm); err != nil {
return nil, nil, err
}
return c, c.incomingRequests, nil
return ch, ch.incomingRequests, nil
}
func (ch *channel) Reject(reason RejectionReason, message string) error {
@ -510,7 +512,7 @@ func (ch *channel) Reject(reason RejectionReason, message string) error {
return errDecidedAlready
}
reject := channelOpenFailureMsg{
PeersId: ch.remoteId,
PeersID: ch.remoteId,
Reason: reason,
Message: message,
Language: "en",
@ -539,7 +541,7 @@ func (ch *channel) CloseWrite() error {
}
ch.sentEOF = true
return ch.sendMessage(channelEOFMsg{
PeersId: ch.remoteId})
PeersID: ch.remoteId})
}
func (ch *channel) Close() error {
@ -548,7 +550,7 @@ func (ch *channel) Close() error {
}
return ch.sendMessage(channelCloseMsg{
PeersId: ch.remoteId})
PeersID: ch.remoteId})
}
// Extended returns an io.ReadWriter that sends and receives data on the given,
@ -575,7 +577,7 @@ func (ch *channel) SendRequest(name string, wantReply bool, payload []byte) (boo
}
msg := channelRequestMsg{
PeersId: ch.remoteId,
PeersID: ch.remoteId,
Request: name,
WantReply: wantReply,
RequestSpecificData: payload,
@ -612,11 +614,11 @@ func (ch *channel) ackRequest(ok bool) error {
var msg interface{}
if !ok {
msg = channelRequestFailureMsg{
PeersId: ch.remoteId,
PeersID: ch.remoteId,
}
} else {
msg = channelRequestSuccessMsg{
PeersId: ch.remoteId,
PeersID: ch.remoteId,
}
}
return ch.sendMessage(msg)

View file

@ -7,6 +7,7 @@ package ssh
import (
"crypto/aes"
"crypto/cipher"
"crypto/des"
"crypto/rc4"
"crypto/subtle"
"encoding/binary"
@ -14,6 +15,11 @@ import (
"fmt"
"hash"
"io"
"io/ioutil"
"math/bits"
"golang.org/x/crypto/internal/chacha20"
"golang.org/x/crypto/poly1305"
)
const (
@ -51,68 +57,78 @@ func newRC4(key, iv []byte) (cipher.Stream, error) {
return rc4.NewCipher(key)
}
type streamCipherMode struct {
keySize int
ivSize int
skip int
createFunc func(key, iv []byte) (cipher.Stream, error)
type cipherMode struct {
keySize int
ivSize int
create func(key, iv []byte, macKey []byte, algs directionAlgorithms) (packetCipher, error)
}
func (c *streamCipherMode) createStream(key, iv []byte) (cipher.Stream, error) {
if len(key) < c.keySize {
panic("ssh: key length too small for cipher")
}
if len(iv) < c.ivSize {
panic("ssh: iv too small for cipher")
}
stream, err := c.createFunc(key[:c.keySize], iv[:c.ivSize])
if err != nil {
return nil, err
}
var streamDump []byte
if c.skip > 0 {
streamDump = make([]byte, 512)
}
for remainingToDump := c.skip; remainingToDump > 0; {
dumpThisTime := remainingToDump
if dumpThisTime > len(streamDump) {
dumpThisTime = len(streamDump)
func streamCipherMode(skip int, createFunc func(key, iv []byte) (cipher.Stream, error)) func(key, iv []byte, macKey []byte, algs directionAlgorithms) (packetCipher, error) {
return func(key, iv, macKey []byte, algs directionAlgorithms) (packetCipher, error) {
stream, err := createFunc(key, iv)
if err != nil {
return nil, err
}
stream.XORKeyStream(streamDump[:dumpThisTime], streamDump[:dumpThisTime])
remainingToDump -= dumpThisTime
}
return stream, nil
var streamDump []byte
if skip > 0 {
streamDump = make([]byte, 512)
}
for remainingToDump := skip; remainingToDump > 0; {
dumpThisTime := remainingToDump
if dumpThisTime > len(streamDump) {
dumpThisTime = len(streamDump)
}
stream.XORKeyStream(streamDump[:dumpThisTime], streamDump[:dumpThisTime])
remainingToDump -= dumpThisTime
}
mac := macModes[algs.MAC].new(macKey)
return &streamPacketCipher{
mac: mac,
etm: macModes[algs.MAC].etm,
macResult: make([]byte, mac.Size()),
cipher: stream,
}, nil
}
}
// cipherModes documents properties of supported ciphers. Ciphers not included
// are not supported and will not be negotiated, even if explicitly requested in
// ClientConfig.Crypto.Ciphers.
var cipherModes = map[string]*streamCipherMode{
var cipherModes = map[string]*cipherMode{
// Ciphers from RFC4344, which introduced many CTR-based ciphers. Algorithms
// are defined in the order specified in the RFC.
"aes128-ctr": {16, aes.BlockSize, 0, newAESCTR},
"aes192-ctr": {24, aes.BlockSize, 0, newAESCTR},
"aes256-ctr": {32, aes.BlockSize, 0, newAESCTR},
"aes128-ctr": {16, aes.BlockSize, streamCipherMode(0, newAESCTR)},
"aes192-ctr": {24, aes.BlockSize, streamCipherMode(0, newAESCTR)},
"aes256-ctr": {32, aes.BlockSize, streamCipherMode(0, newAESCTR)},
// Ciphers from RFC4345, which introduces security-improved arcfour ciphers.
// They are defined in the order specified in the RFC.
"arcfour128": {16, 0, 1536, newRC4},
"arcfour256": {32, 0, 1536, newRC4},
"arcfour128": {16, 0, streamCipherMode(1536, newRC4)},
"arcfour256": {32, 0, streamCipherMode(1536, newRC4)},
// Cipher defined in RFC 4253, which describes SSH Transport Layer Protocol.
// Note that this cipher is not safe, as stated in RFC 4253: "Arcfour (and
// RC4) has problems with weak keys, and should be used with caution."
// RFC4345 introduces improved versions of Arcfour.
"arcfour": {16, 0, 0, newRC4},
"arcfour": {16, 0, streamCipherMode(0, newRC4)},
// AES-GCM is not a stream cipher, so it is constructed with a
// special case. If we add any more non-stream ciphers, we
// should invest a cleaner way to do this.
gcmCipherID: {16, 12, 0, nil},
// AEAD ciphers
gcmCipherID: {16, 12, newGCMCipher},
chacha20Poly1305ID: {64, 0, newChaCha20Cipher},
// CBC mode is insecure and so is not included in the default config.
// (See http://www.isg.rhul.ac.uk/~kp/SandPfinal.pdf). If absolutely
// needed, it's possible to specify a custom Config to enable it.
// You should expect that an active attacker can recover plaintext if
// you do.
aes128cbcID: {16, aes.BlockSize, newAESCBCCipher},
// 3des-cbc is insecure and is not included in the default
// config.
tripledescbcID: {24, des.BlockSize, newTripleDESCBCCipher},
}
// prefixLen is the length of the packet prefix that contains the packet length
@ -123,6 +139,7 @@ const prefixLen = 5
type streamPacketCipher struct {
mac hash.Hash
cipher cipher.Stream
etm bool
// The following members are to avoid per-packet allocations.
prefix [prefixLen]byte
@ -138,7 +155,14 @@ func (s *streamPacketCipher) readPacket(seqNum uint32, r io.Reader) ([]byte, err
return nil, err
}
s.cipher.XORKeyStream(s.prefix[:], s.prefix[:])
var encryptedPaddingLength [1]byte
if s.mac != nil && s.etm {
copy(encryptedPaddingLength[:], s.prefix[4:5])
s.cipher.XORKeyStream(s.prefix[4:5], s.prefix[4:5])
} else {
s.cipher.XORKeyStream(s.prefix[:], s.prefix[:])
}
length := binary.BigEndian.Uint32(s.prefix[0:4])
paddingLength := uint32(s.prefix[4])
@ -147,7 +171,12 @@ func (s *streamPacketCipher) readPacket(seqNum uint32, r io.Reader) ([]byte, err
s.mac.Reset()
binary.BigEndian.PutUint32(s.seqNumBytes[:], seqNum)
s.mac.Write(s.seqNumBytes[:])
s.mac.Write(s.prefix[:])
if s.etm {
s.mac.Write(s.prefix[:4])
s.mac.Write(encryptedPaddingLength[:])
} else {
s.mac.Write(s.prefix[:])
}
macSize = uint32(s.mac.Size())
}
@ -172,10 +201,17 @@ func (s *streamPacketCipher) readPacket(seqNum uint32, r io.Reader) ([]byte, err
}
mac := s.packetData[length-1:]
data := s.packetData[:length-1]
if s.mac != nil && s.etm {
s.mac.Write(data)
}
s.cipher.XORKeyStream(data, data)
if s.mac != nil {
s.mac.Write(data)
if !s.etm {
s.mac.Write(data)
}
s.macResult = s.mac.Sum(s.macResult[:0])
if subtle.ConstantTimeCompare(s.macResult, mac) != 1 {
return nil, errors.New("ssh: MAC failure")
@ -191,7 +227,13 @@ func (s *streamPacketCipher) writePacket(seqNum uint32, w io.Writer, rand io.Rea
return errors.New("ssh: packet too large")
}
paddingLength := packetSizeMultiple - (prefixLen+len(packet))%packetSizeMultiple
aadlen := 0
if s.mac != nil && s.etm {
// packet length is not encrypted for EtM modes
aadlen = 4
}
paddingLength := packetSizeMultiple - (prefixLen+len(packet)-aadlen)%packetSizeMultiple
if paddingLength < 4 {
paddingLength += packetSizeMultiple
}
@ -208,15 +250,37 @@ func (s *streamPacketCipher) writePacket(seqNum uint32, w io.Writer, rand io.Rea
s.mac.Reset()
binary.BigEndian.PutUint32(s.seqNumBytes[:], seqNum)
s.mac.Write(s.seqNumBytes[:])
if s.etm {
// For EtM algorithms, the packet length must stay unencrypted,
// but the following data (padding length) must be encrypted
s.cipher.XORKeyStream(s.prefix[4:5], s.prefix[4:5])
}
s.mac.Write(s.prefix[:])
if !s.etm {
// For non-EtM algorithms, the algorithm is applied on unencrypted data
s.mac.Write(packet)
s.mac.Write(padding)
}
}
if !(s.mac != nil && s.etm) {
// For EtM algorithms, the padding length has already been encrypted
// and the packet length must remain unencrypted
s.cipher.XORKeyStream(s.prefix[:], s.prefix[:])
}
s.cipher.XORKeyStream(packet, packet)
s.cipher.XORKeyStream(padding, padding)
if s.mac != nil && s.etm {
// For EtM algorithms, packet and padding must be encrypted
s.mac.Write(packet)
s.mac.Write(padding)
}
s.cipher.XORKeyStream(s.prefix[:], s.prefix[:])
s.cipher.XORKeyStream(packet, packet)
s.cipher.XORKeyStream(padding, padding)
if _, err := w.Write(s.prefix[:]); err != nil {
return err
}
@ -244,7 +308,7 @@ type gcmCipher struct {
buf []byte
}
func newGCMCipher(iv, key, macKey []byte) (packetCipher, error) {
func newGCMCipher(key, iv, unusedMacKey []byte, unusedAlgs directionAlgorithms) (packetCipher, error) {
c, err := aes.NewCipher(key)
if err != nil {
return nil, err
@ -312,7 +376,7 @@ func (c *gcmCipher) readPacket(seqNum uint32, r io.Reader) ([]byte, error) {
}
length := binary.BigEndian.Uint32(c.prefix[:])
if length > maxPacket {
return nil, errors.New("ssh: max packet length exceeded.")
return nil, errors.New("ssh: max packet length exceeded")
}
if cap(c.buf) < int(length+gcmTagSize) {
@ -332,7 +396,9 @@ func (c *gcmCipher) readPacket(seqNum uint32, r io.Reader) ([]byte, error) {
c.incIV()
padding := plain[0]
if padding < 4 || padding >= 20 {
if padding < 4 {
// padding is a byte, so it automatically satisfies
// the maximum size, which is 255.
return nil, fmt.Errorf("ssh: illegal padding %d", padding)
}
@ -342,3 +408,363 @@ func (c *gcmCipher) readPacket(seqNum uint32, r io.Reader) ([]byte, error) {
plain = plain[1 : length-uint32(padding)]
return plain, nil
}
// cbcCipher implements aes128-cbc cipher defined in RFC 4253 section 6.1
type cbcCipher struct {
mac hash.Hash
macSize uint32
decrypter cipher.BlockMode
encrypter cipher.BlockMode
// The following members are to avoid per-packet allocations.
seqNumBytes [4]byte
packetData []byte
macResult []byte
// Amount of data we should still read to hide which
// verification error triggered.
oracleCamouflage uint32
}
func newCBCCipher(c cipher.Block, key, iv, macKey []byte, algs directionAlgorithms) (packetCipher, error) {
cbc := &cbcCipher{
mac: macModes[algs.MAC].new(macKey),
decrypter: cipher.NewCBCDecrypter(c, iv),
encrypter: cipher.NewCBCEncrypter(c, iv),
packetData: make([]byte, 1024),
}
if cbc.mac != nil {
cbc.macSize = uint32(cbc.mac.Size())
}
return cbc, nil
}
func newAESCBCCipher(key, iv, macKey []byte, algs directionAlgorithms) (packetCipher, error) {
c, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
cbc, err := newCBCCipher(c, key, iv, macKey, algs)
if err != nil {
return nil, err
}
return cbc, nil
}
func newTripleDESCBCCipher(key, iv, macKey []byte, algs directionAlgorithms) (packetCipher, error) {
c, err := des.NewTripleDESCipher(key)
if err != nil {
return nil, err
}
cbc, err := newCBCCipher(c, key, iv, macKey, algs)
if err != nil {
return nil, err
}
return cbc, nil
}
func maxUInt32(a, b int) uint32 {
if a > b {
return uint32(a)
}
return uint32(b)
}
const (
cbcMinPacketSizeMultiple = 8
cbcMinPacketSize = 16
cbcMinPaddingSize = 4
)
// cbcError represents a verification error that may leak information.
type cbcError string
func (e cbcError) Error() string { return string(e) }
func (c *cbcCipher) readPacket(seqNum uint32, r io.Reader) ([]byte, error) {
p, err := c.readPacketLeaky(seqNum, r)
if err != nil {
if _, ok := err.(cbcError); ok {
// Verification error: read a fixed amount of
// data, to make distinguishing between
// failing MAC and failing length check more
// difficult.
io.CopyN(ioutil.Discard, r, int64(c.oracleCamouflage))
}
}
return p, err
}
func (c *cbcCipher) readPacketLeaky(seqNum uint32, r io.Reader) ([]byte, error) {
blockSize := c.decrypter.BlockSize()
// Read the header, which will include some of the subsequent data in the
// case of block ciphers - this is copied back to the payload later.
// How many bytes of payload/padding will be read with this first read.
firstBlockLength := uint32((prefixLen + blockSize - 1) / blockSize * blockSize)
firstBlock := c.packetData[:firstBlockLength]
if _, err := io.ReadFull(r, firstBlock); err != nil {
return nil, err
}
c.oracleCamouflage = maxPacket + 4 + c.macSize - firstBlockLength
c.decrypter.CryptBlocks(firstBlock, firstBlock)
length := binary.BigEndian.Uint32(firstBlock[:4])
if length > maxPacket {
return nil, cbcError("ssh: packet too large")
}
if length+4 < maxUInt32(cbcMinPacketSize, blockSize) {
// The minimum size of a packet is 16 (or the cipher block size, whichever
// is larger) bytes.
return nil, cbcError("ssh: packet too small")
}
// The length of the packet (including the length field but not the MAC) must
// be a multiple of the block size or 8, whichever is larger.
if (length+4)%maxUInt32(cbcMinPacketSizeMultiple, blockSize) != 0 {
return nil, cbcError("ssh: invalid packet length multiple")
}
paddingLength := uint32(firstBlock[4])
if paddingLength < cbcMinPaddingSize || length <= paddingLength+1 {
return nil, cbcError("ssh: invalid packet length")
}
// Positions within the c.packetData buffer:
macStart := 4 + length
paddingStart := macStart - paddingLength
// Entire packet size, starting before length, ending at end of mac.
entirePacketSize := macStart + c.macSize
// Ensure c.packetData is large enough for the entire packet data.
if uint32(cap(c.packetData)) < entirePacketSize {
// Still need to upsize and copy, but this should be rare at runtime, only
// on upsizing the packetData buffer.
c.packetData = make([]byte, entirePacketSize)
copy(c.packetData, firstBlock)
} else {
c.packetData = c.packetData[:entirePacketSize]
}
n, err := io.ReadFull(r, c.packetData[firstBlockLength:])
if err != nil {
return nil, err
}
c.oracleCamouflage -= uint32(n)
remainingCrypted := c.packetData[firstBlockLength:macStart]
c.decrypter.CryptBlocks(remainingCrypted, remainingCrypted)
mac := c.packetData[macStart:]
if c.mac != nil {
c.mac.Reset()
binary.BigEndian.PutUint32(c.seqNumBytes[:], seqNum)
c.mac.Write(c.seqNumBytes[:])
c.mac.Write(c.packetData[:macStart])
c.macResult = c.mac.Sum(c.macResult[:0])
if subtle.ConstantTimeCompare(c.macResult, mac) != 1 {
return nil, cbcError("ssh: MAC failure")
}
}
return c.packetData[prefixLen:paddingStart], nil
}
func (c *cbcCipher) writePacket(seqNum uint32, w io.Writer, rand io.Reader, packet []byte) error {
effectiveBlockSize := maxUInt32(cbcMinPacketSizeMultiple, c.encrypter.BlockSize())
// Length of encrypted portion of the packet (header, payload, padding).
// Enforce minimum padding and packet size.
encLength := maxUInt32(prefixLen+len(packet)+cbcMinPaddingSize, cbcMinPaddingSize)
// Enforce block size.
encLength = (encLength + effectiveBlockSize - 1) / effectiveBlockSize * effectiveBlockSize
length := encLength - 4
paddingLength := int(length) - (1 + len(packet))
// Overall buffer contains: header, payload, padding, mac.
// Space for the MAC is reserved in the capacity but not the slice length.
bufferSize := encLength + c.macSize
if uint32(cap(c.packetData)) < bufferSize {
c.packetData = make([]byte, encLength, bufferSize)
} else {
c.packetData = c.packetData[:encLength]
}
p := c.packetData
// Packet header.
binary.BigEndian.PutUint32(p, length)
p = p[4:]
p[0] = byte(paddingLength)
// Payload.
p = p[1:]
copy(p, packet)
// Padding.
p = p[len(packet):]
if _, err := io.ReadFull(rand, p); err != nil {
return err
}
if c.mac != nil {
c.mac.Reset()
binary.BigEndian.PutUint32(c.seqNumBytes[:], seqNum)
c.mac.Write(c.seqNumBytes[:])
c.mac.Write(c.packetData)
// The MAC is now appended into the capacity reserved for it earlier.
c.packetData = c.mac.Sum(c.packetData)
}
c.encrypter.CryptBlocks(c.packetData[:encLength], c.packetData[:encLength])
if _, err := w.Write(c.packetData); err != nil {
return err
}
return nil
}
const chacha20Poly1305ID = "chacha20-poly1305@openssh.com"
// chacha20Poly1305Cipher implements the chacha20-poly1305@openssh.com
// AEAD, which is described here:
//
// https://tools.ietf.org/html/draft-josefsson-ssh-chacha20-poly1305-openssh-00
//
// the methods here also implement padding, which RFC4253 Section 6
// also requires of stream ciphers.
type chacha20Poly1305Cipher struct {
lengthKey [8]uint32
contentKey [8]uint32
buf []byte
}
func newChaCha20Cipher(key, unusedIV, unusedMACKey []byte, unusedAlgs directionAlgorithms) (packetCipher, error) {
if len(key) != 64 {
panic(len(key))
}
c := &chacha20Poly1305Cipher{
buf: make([]byte, 256),
}
for i := range c.contentKey {
c.contentKey[i] = binary.LittleEndian.Uint32(key[i*4 : (i+1)*4])
}
for i := range c.lengthKey {
c.lengthKey[i] = binary.LittleEndian.Uint32(key[(i+8)*4 : (i+9)*4])
}
return c, nil
}
func (c *chacha20Poly1305Cipher) readPacket(seqNum uint32, r io.Reader) ([]byte, error) {
nonce := [3]uint32{0, 0, bits.ReverseBytes32(seqNum)}
s := chacha20.New(c.contentKey, nonce)
var polyKey [32]byte
s.XORKeyStream(polyKey[:], polyKey[:])
s.Advance() // skip next 32 bytes
encryptedLength := c.buf[:4]
if _, err := io.ReadFull(r, encryptedLength); err != nil {
return nil, err
}
var lenBytes [4]byte
chacha20.New(c.lengthKey, nonce).XORKeyStream(lenBytes[:], encryptedLength)
length := binary.BigEndian.Uint32(lenBytes[:])
if length > maxPacket {
return nil, errors.New("ssh: invalid packet length, packet too large")
}
contentEnd := 4 + length
packetEnd := contentEnd + poly1305.TagSize
if uint32(cap(c.buf)) < packetEnd {
c.buf = make([]byte, packetEnd)
copy(c.buf[:], encryptedLength)
} else {
c.buf = c.buf[:packetEnd]
}
if _, err := io.ReadFull(r, c.buf[4:packetEnd]); err != nil {
return nil, err
}
var mac [poly1305.TagSize]byte
copy(mac[:], c.buf[contentEnd:packetEnd])
if !poly1305.Verify(&mac, c.buf[:contentEnd], &polyKey) {
return nil, errors.New("ssh: MAC failure")
}
plain := c.buf[4:contentEnd]
s.XORKeyStream(plain, plain)
padding := plain[0]
if padding < 4 {
// padding is a byte, so it automatically satisfies
// the maximum size, which is 255.
return nil, fmt.Errorf("ssh: illegal padding %d", padding)
}
if int(padding)+1 >= len(plain) {
return nil, fmt.Errorf("ssh: padding %d too large", padding)
}
plain = plain[1 : len(plain)-int(padding)]
return plain, nil
}
func (c *chacha20Poly1305Cipher) writePacket(seqNum uint32, w io.Writer, rand io.Reader, payload []byte) error {
nonce := [3]uint32{0, 0, bits.ReverseBytes32(seqNum)}
s := chacha20.New(c.contentKey, nonce)
var polyKey [32]byte
s.XORKeyStream(polyKey[:], polyKey[:])
s.Advance() // skip next 32 bytes
// There is no blocksize, so fall back to multiple of 8 byte
// padding, as described in RFC 4253, Sec 6.
const packetSizeMultiple = 8
padding := packetSizeMultiple - (1+len(payload))%packetSizeMultiple
if padding < 4 {
padding += packetSizeMultiple
}
// size (4 bytes), padding (1), payload, padding, tag.
totalLength := 4 + 1 + len(payload) + padding + poly1305.TagSize
if cap(c.buf) < totalLength {
c.buf = make([]byte, totalLength)
} else {
c.buf = c.buf[:totalLength]
}
binary.BigEndian.PutUint32(c.buf, uint32(1+len(payload)+padding))
chacha20.New(c.lengthKey, nonce).XORKeyStream(c.buf, c.buf[:4])
c.buf[4] = byte(padding)
copy(c.buf[5:], payload)
packetEnd := 5 + len(payload) + padding
if _, err := io.ReadFull(rand, c.buf[5+len(payload):packetEnd]); err != nil {
return err
}
s.XORKeyStream(c.buf[4:], c.buf[4:packetEnd])
var mac [poly1305.TagSize]byte
poly1305.Sum(&mac, c.buf[:packetEnd], &polyKey)
copy(c.buf[packetEnd:], mac[:])
if _, err := w.Write(c.buf); err != nil {
return err
}
return nil
}

View file

@ -14,46 +14,118 @@ import (
func TestDefaultCiphersExist(t *testing.T) {
for _, cipherAlgo := range supportedCiphers {
if _, ok := cipherModes[cipherAlgo]; !ok {
t.Errorf("default cipher %q is unknown", cipherAlgo)
t.Errorf("supported cipher %q is unknown", cipherAlgo)
}
}
for _, cipherAlgo := range preferredCiphers {
if _, ok := cipherModes[cipherAlgo]; !ok {
t.Errorf("preferred cipher %q is unknown", cipherAlgo)
}
}
}
func TestPacketCiphers(t *testing.T) {
defaultMac := "hmac-sha2-256"
defaultCipher := "aes128-ctr"
for cipher := range cipherModes {
kr := &kexResult{Hash: crypto.SHA1}
algs := directionAlgorithms{
Cipher: cipher,
MAC: "hmac-sha1",
Compression: "none",
}
client, err := newPacketCipher(clientKeys, algs, kr)
if err != nil {
t.Errorf("newPacketCipher(client, %q): %v", cipher, err)
continue
}
server, err := newPacketCipher(clientKeys, algs, kr)
if err != nil {
t.Errorf("newPacketCipher(client, %q): %v", cipher, err)
continue
}
want := "bla bla"
input := []byte(want)
buf := &bytes.Buffer{}
if err := client.writePacket(0, buf, rand.Reader, input); err != nil {
t.Errorf("writePacket(%q): %v", cipher, err)
continue
}
packet, err := server.readPacket(0, buf)
if err != nil {
t.Errorf("readPacket(%q): %v", cipher, err)
continue
}
if string(packet) != want {
t.Errorf("roundtrip(%q): got %q, want %q", cipher, packet, want)
}
t.Run("cipher="+cipher,
func(t *testing.T) { testPacketCipher(t, cipher, defaultMac) })
}
for mac := range macModes {
t.Run("mac="+mac,
func(t *testing.T) { testPacketCipher(t, defaultCipher, mac) })
}
}
func testPacketCipher(t *testing.T, cipher, mac string) {
kr := &kexResult{Hash: crypto.SHA1}
algs := directionAlgorithms{
Cipher: cipher,
MAC: mac,
Compression: "none",
}
client, err := newPacketCipher(clientKeys, algs, kr)
if err != nil {
t.Fatalf("newPacketCipher(client, %q, %q): %v", cipher, mac, err)
}
server, err := newPacketCipher(clientKeys, algs, kr)
if err != nil {
t.Fatalf("newPacketCipher(client, %q, %q): %v", cipher, mac, err)
}
want := "bla bla"
input := []byte(want)
buf := &bytes.Buffer{}
if err := client.writePacket(0, buf, rand.Reader, input); err != nil {
t.Fatalf("writePacket(%q, %q): %v", cipher, mac, err)
}
packet, err := server.readPacket(0, buf)
if err != nil {
t.Fatalf("readPacket(%q, %q): %v", cipher, mac, err)
}
if string(packet) != want {
t.Errorf("roundtrip(%q, %q): got %q, want %q", cipher, mac, packet, want)
}
}
func TestCBCOracleCounterMeasure(t *testing.T) {
kr := &kexResult{Hash: crypto.SHA1}
algs := directionAlgorithms{
Cipher: aes128cbcID,
MAC: "hmac-sha1",
Compression: "none",
}
client, err := newPacketCipher(clientKeys, algs, kr)
if err != nil {
t.Fatalf("newPacketCipher(client): %v", err)
}
want := "bla bla"
input := []byte(want)
buf := &bytes.Buffer{}
if err := client.writePacket(0, buf, rand.Reader, input); err != nil {
t.Errorf("writePacket: %v", err)
}
packetSize := buf.Len()
buf.Write(make([]byte, 2*maxPacket))
// We corrupt each byte, but this usually will only test the
// 'packet too large' or 'MAC failure' cases.
lastRead := -1
for i := 0; i < packetSize; i++ {
server, err := newPacketCipher(clientKeys, algs, kr)
if err != nil {
t.Fatalf("newPacketCipher(client): %v", err)
}
fresh := &bytes.Buffer{}
fresh.Write(buf.Bytes())
fresh.Bytes()[i] ^= 0x01
before := fresh.Len()
_, err = server.readPacket(0, fresh)
if err == nil {
t.Errorf("corrupt byte %d: readPacket succeeded ", i)
continue
}
if _, ok := err.(cbcError); !ok {
t.Errorf("corrupt byte %d: got %v (%T), want cbcError", i, err, err)
continue
}
after := fresh.Len()
bytesRead := before - after
if bytesRead < maxPacket {
t.Errorf("corrupt byte %d: read %d bytes, want more than %d", i, bytesRead, maxPacket)
continue
}
if i > 0 && bytesRead != lastRead {
t.Errorf("corrupt byte %d: read %d bytes, want %d bytes read", i, bytesRead, lastRead)
}
lastRead = bytesRead
}
}

View file

@ -5,17 +5,22 @@
package ssh
import (
"bytes"
"errors"
"fmt"
"net"
"os"
"sync"
"time"
)
// Client implements a traditional SSH client that supports shells,
// subprocesses, port forwarding and tunneled dialing.
// subprocesses, TCP port/streamlocal forwarding and tunneled dialing.
type Client struct {
Conn
handleForwardsOnce sync.Once // guards calling (*Client).handleForwards
forwards forwardList // forwarded tcpip connections from the remote side
mu sync.Mutex
channelHandlers map[string]chan NewChannel
@ -39,7 +44,7 @@ func (c *Client) HandleChannelOpen(channelType string) <-chan NewChannel {
return nil
}
ch = make(chan NewChannel, 16)
ch = make(chan NewChannel, chanSize)
c.channelHandlers[channelType] = ch
return ch
}
@ -57,7 +62,6 @@ func NewClient(c Conn, chans <-chan NewChannel, reqs <-chan *Request) *Client {
conn.Wait()
conn.forwards.closeAll()
}()
go conn.forwards.handleChannels(conn.HandleChannelOpen("forwarded-tcpip"))
return conn
}
@ -67,6 +71,11 @@ func NewClient(c Conn, chans <-chan NewChannel, reqs <-chan *Request) *Client {
func NewClientConn(c net.Conn, addr string, config *ClientConfig) (Conn, <-chan NewChannel, <-chan *Request, error) {
fullConf := *config
fullConf.SetDefaults()
if fullConf.HostKeyCallback == nil {
c.Close()
return nil, nil, nil, errors.New("ssh: must specify HostKeyCallback")
}
conn := &connection{
sshConn: sshConn{conn: c},
}
@ -96,19 +105,11 @@ func (c *connection) clientHandshake(dialAddress string, config *ClientConfig) e
c.transport = newClientTransport(
newTransport(c.sshConn.conn, config.Rand, true /* is client */),
c.clientVersion, c.serverVersion, config, dialAddress, c.sshConn.RemoteAddr())
if err := c.transport.requestKeyChange(); err != nil {
if err := c.transport.waitSession(); err != nil {
return err
}
if packet, err := c.transport.readPacket(); err != nil {
return err
} else if packet[0] != msgNewKeys {
return unexpectedMessageError(msgNewKeys, packet[0])
}
// We just did the key change, so the session ID is established.
c.sessionID = c.transport.getSessionID()
return c.clientAuthenticate(config)
}
@ -169,7 +170,7 @@ func (c *Client) handleChannelOpens(in <-chan NewChannel) {
// to incoming channels and requests, use net.Dial with NewClientConn
// instead.
func Dial(network, addr string, config *ClientConfig) (*Client, error) {
conn, err := net.Dial(network, addr)
conn, err := net.DialTimeout(network, addr, config.Timeout)
if err != nil {
return nil, err
}
@ -180,6 +181,17 @@ func Dial(network, addr string, config *ClientConfig) (*Client, error) {
return NewClient(c, chans, reqs), nil
}
// HostKeyCallback is the function type used for verifying server
// keys. A HostKeyCallback must return nil if the host key is OK, or
// an error to reject it. It receives the hostname as passed to Dial
// or NewClientConn. The remote address is the RemoteAddr of the
// net.Conn underlying the SSH connection.
type HostKeyCallback func(hostname string, remote net.Addr, key PublicKey) error
// BannerCallback is the function type used for treat the banner sent by
// the server. A BannerCallback receives the message sent by the remote server.
type BannerCallback func(message string) error
// A ClientConfig structure is used to configure a Client. It must not be
// modified after having been passed to an SSH function.
type ClientConfig struct {
@ -195,12 +207,72 @@ type ClientConfig struct {
// be used during authentication.
Auth []AuthMethod
// HostKeyCallback, if not nil, is called during the cryptographic
// handshake to validate the server's host key. A nil HostKeyCallback
// implies that all host keys are accepted.
HostKeyCallback func(hostname string, remote net.Addr, key PublicKey) error
// HostKeyCallback is called during the cryptographic
// handshake to validate the server's host key. The client
// configuration must supply this callback for the connection
// to succeed. The functions InsecureIgnoreHostKey or
// FixedHostKey can be used for simplistic host key checks.
HostKeyCallback HostKeyCallback
// BannerCallback is called during the SSH dance to display a custom
// server's message. The client configuration can supply this callback to
// handle it as wished. The function BannerDisplayStderr can be used for
// simplistic display on Stderr.
BannerCallback BannerCallback
// ClientVersion contains the version identification string that will
// be used for the connection. If empty, a reasonable default is used.
ClientVersion string
// HostKeyAlgorithms lists the key types that the client will
// accept from the server as host key, in order of
// preference. If empty, a reasonable default is used. Any
// string returned from PublicKey.Type method may be used, or
// any of the CertAlgoXxxx and KeyAlgoXxxx constants.
HostKeyAlgorithms []string
// Timeout is the maximum amount of time for the TCP connection to establish.
//
// A Timeout of zero means no timeout.
Timeout time.Duration
}
// InsecureIgnoreHostKey returns a function that can be used for
// ClientConfig.HostKeyCallback to accept any host key. It should
// not be used for production code.
func InsecureIgnoreHostKey() HostKeyCallback {
return func(hostname string, remote net.Addr, key PublicKey) error {
return nil
}
}
type fixedHostKey struct {
key PublicKey
}
func (f *fixedHostKey) check(hostname string, remote net.Addr, key PublicKey) error {
if f.key == nil {
return fmt.Errorf("ssh: required host key was nil")
}
if !bytes.Equal(key.Marshal(), f.key.Marshal()) {
return fmt.Errorf("ssh: host key mismatch")
}
return nil
}
// FixedHostKey returns a function for use in
// ClientConfig.HostKeyCallback to accept only a specific host key.
func FixedHostKey(key PublicKey) HostKeyCallback {
hk := &fixedHostKey{key}
return hk.check
}
// BannerDisplayStderr returns a function that can be used for
// ClientConfig.BannerCallback to display banners on os.Stderr.
func BannerDisplayStderr() BannerCallback {
return func(banner string) error {
_, err := os.Stderr.WriteString(banner)
return err
}
}

View file

@ -11,6 +11,14 @@ import (
"io"
)
type authResult int
const (
authFailure authResult = iota
authPartialSuccess
authSuccess
)
// clientAuthenticate authenticates with the remote server. See RFC 4252.
func (c *connection) clientAuthenticate(config *ClientConfig) error {
// initiate user auth session
@ -30,16 +38,19 @@ func (c *connection) clientAuthenticate(config *ClientConfig) error {
// then any untried methods suggested by the server.
tried := make(map[string]bool)
var lastMethods []string
sessionID := c.transport.getSessionID()
for auth := AuthMethod(new(noneAuth)); auth != nil; {
ok, methods, err := auth.auth(c.transport.getSessionID(), config.User, c.transport, config.Rand)
ok, methods, err := auth.auth(sessionID, config.User, c.transport, config.Rand)
if err != nil {
return err
}
if ok {
if ok == authSuccess {
// success
return nil
} else if ok == authFailure {
tried[auth.method()] = true
}
tried[auth.method()] = true
if methods == nil {
methods = lastMethods
}
@ -80,7 +91,7 @@ type AuthMethod interface {
// If authentication is not successful, a []string of alternative
// method names is returned. If the slice is nil, it will be ignored
// and the previous set of possible methods will be reused.
auth(session []byte, user string, p packetConn, rand io.Reader) (bool, []string, error)
auth(session []byte, user string, p packetConn, rand io.Reader) (authResult, []string, error)
// method returns the RFC 4252 method name.
method() string
@ -89,13 +100,13 @@ type AuthMethod interface {
// "none" authentication, RFC 4252 section 5.2.
type noneAuth int
func (n *noneAuth) auth(session []byte, user string, c packetConn, rand io.Reader) (bool, []string, error) {
func (n *noneAuth) auth(session []byte, user string, c packetConn, rand io.Reader) (authResult, []string, error) {
if err := c.writePacket(Marshal(&userAuthRequestMsg{
User: user,
Service: serviceSSH,
Method: "none",
})); err != nil {
return false, nil, err
return authFailure, nil, err
}
return handleAuthResponse(c)
@ -109,7 +120,7 @@ func (n *noneAuth) method() string {
// a function call, e.g. by prompting the user.
type passwordCallback func() (password string, err error)
func (cb passwordCallback) auth(session []byte, user string, c packetConn, rand io.Reader) (bool, []string, error) {
func (cb passwordCallback) auth(session []byte, user string, c packetConn, rand io.Reader) (authResult, []string, error) {
type passwordAuthMsg struct {
User string `sshtype:"50"`
Service string
@ -123,7 +134,7 @@ func (cb passwordCallback) auth(session []byte, user string, c packetConn, rand
// The program may only find out that the user doesn't have a password
// when prompting.
if err != nil {
return false, nil, err
return authFailure, nil, err
}
if err := c.writePacket(Marshal(&passwordAuthMsg{
@ -133,7 +144,7 @@ func (cb passwordCallback) auth(session []byte, user string, c packetConn, rand
Reply: false,
Password: pw,
})); err != nil {
return false, nil, err
return authFailure, nil, err
}
return handleAuthResponse(c)
@ -176,32 +187,27 @@ func (cb publicKeyCallback) method() string {
return "publickey"
}
func (cb publicKeyCallback) auth(session []byte, user string, c packetConn, rand io.Reader) (bool, []string, error) {
// Authentication is performed in two stages. The first stage sends an
// enquiry to test if each key is acceptable to the remote. The second
// stage attempts to authenticate with the valid keys obtained in the
// first stage.
func (cb publicKeyCallback) auth(session []byte, user string, c packetConn, rand io.Reader) (authResult, []string, error) {
// Authentication is performed by sending an enquiry to test if a key is
// acceptable to the remote. If the key is acceptable, the client will
// attempt to authenticate with the valid key. If not the client will repeat
// the process with the remaining keys.
signers, err := cb()
if err != nil {
return false, nil, err
return authFailure, nil, err
}
var validKeys []Signer
for _, signer := range signers {
if ok, err := validateKey(signer.PublicKey(), user, c); ok {
validKeys = append(validKeys, signer)
} else {
if err != nil {
return false, nil, err
}
}
}
// methods that may continue if this auth is not successful.
var methods []string
for _, signer := range validKeys {
pub := signer.PublicKey()
for _, signer := range signers {
ok, err := validateKey(signer.PublicKey(), user, c)
if err != nil {
return authFailure, nil, err
}
if !ok {
continue
}
pub := signer.PublicKey()
pubKey := pub.Marshal()
sign, err := signer.Sign(rand, buildDataSignedForAuth(session, userAuthRequestMsg{
User: user,
@ -209,7 +215,7 @@ func (cb publicKeyCallback) auth(session []byte, user string, c packetConn, rand
Method: cb.method(),
}, []byte(pub.Type()), pubKey))
if err != nil {
return false, nil, err
return authFailure, nil, err
}
// manually wrap the serialized signature in a string
@ -227,18 +233,34 @@ func (cb publicKeyCallback) auth(session []byte, user string, c packetConn, rand
}
p := Marshal(&msg)
if err := c.writePacket(p); err != nil {
return false, nil, err
return authFailure, nil, err
}
var success bool
var success authResult
success, methods, err = handleAuthResponse(c)
if err != nil {
return false, nil, err
return authFailure, nil, err
}
if success {
// If authentication succeeds or the list of available methods does not
// contain the "publickey" method, do not attempt to authenticate with any
// other keys. According to RFC 4252 Section 7, the latter can occur when
// additional authentication methods are required.
if success == authSuccess || !containsMethod(methods, cb.method()) {
return success, methods, err
}
}
return false, methods, nil
return authFailure, methods, nil
}
func containsMethod(methods []string, method string) bool {
for _, m := range methods {
if m == method {
return true
}
}
return false
}
// validateKey validates the key provided is acceptable to the server.
@ -270,7 +292,9 @@ func confirmKeyAck(key PublicKey, c packetConn) (bool, error) {
}
switch packet[0] {
case msgUserAuthBanner:
// TODO(gpaul): add callback to present the banner to the user
if err := handleBannerResponse(c, packet); err != nil {
return false, err
}
case msgUserAuthPubKeyOk:
var msg userAuthPubKeyOkMsg
if err := Unmarshal(packet, &msg); err != nil {
@ -303,32 +327,53 @@ func PublicKeysCallback(getSigners func() (signers []Signer, err error)) AuthMet
// handleAuthResponse returns whether the preceding authentication request succeeded
// along with a list of remaining authentication methods to try next and
// an error if an unexpected response was received.
func handleAuthResponse(c packetConn) (bool, []string, error) {
func handleAuthResponse(c packetConn) (authResult, []string, error) {
for {
packet, err := c.readPacket()
if err != nil {
return false, nil, err
return authFailure, nil, err
}
switch packet[0] {
case msgUserAuthBanner:
// TODO: add callback to present the banner to the user
if err := handleBannerResponse(c, packet); err != nil {
return authFailure, nil, err
}
case msgUserAuthFailure:
var msg userAuthFailureMsg
if err := Unmarshal(packet, &msg); err != nil {
return false, nil, err
return authFailure, nil, err
}
return false, msg.Methods, nil
if msg.PartialSuccess {
return authPartialSuccess, msg.Methods, nil
}
return authFailure, msg.Methods, nil
case msgUserAuthSuccess:
return true, nil, nil
case msgDisconnect:
return false, nil, io.EOF
return authSuccess, nil, nil
default:
return false, nil, unexpectedMessageError(msgUserAuthSuccess, packet[0])
return authFailure, nil, unexpectedMessageError(msgUserAuthSuccess, packet[0])
}
}
}
func handleBannerResponse(c packetConn, packet []byte) error {
var msg userAuthBannerMsg
if err := Unmarshal(packet, &msg); err != nil {
return err
}
transport, ok := c.(*handshakeTransport)
if !ok {
return nil
}
if transport.bannerCallback != nil {
return transport.bannerCallback(msg.Message)
}
return nil
}
// KeyboardInteractiveChallenge should print questions, optionally
// disabling echoing (e.g. for passwords), and return all the answers.
// Challenge may be called multiple times in a single session. After
@ -338,7 +383,7 @@ func handleAuthResponse(c packetConn) (bool, []string, error) {
// both CLI and GUI environments.
type KeyboardInteractiveChallenge func(user, instruction string, questions []string, echos []bool) (answers []string, err error)
// KeyboardInteractive returns a AuthMethod using a prompt/response
// KeyboardInteractive returns an AuthMethod using a prompt/response
// sequence controlled by the server.
func KeyboardInteractive(challenge KeyboardInteractiveChallenge) AuthMethod {
return challenge
@ -348,7 +393,7 @@ func (cb KeyboardInteractiveChallenge) method() string {
return "keyboard-interactive"
}
func (cb KeyboardInteractiveChallenge) auth(session []byte, user string, c packetConn, rand io.Reader) (bool, []string, error) {
func (cb KeyboardInteractiveChallenge) auth(session []byte, user string, c packetConn, rand io.Reader) (authResult, []string, error) {
type initiateMsg struct {
User string `sshtype:"50"`
Service string
@ -362,37 +407,42 @@ func (cb KeyboardInteractiveChallenge) auth(session []byte, user string, c packe
Service: serviceSSH,
Method: "keyboard-interactive",
})); err != nil {
return false, nil, err
return authFailure, nil, err
}
for {
packet, err := c.readPacket()
if err != nil {
return false, nil, err
return authFailure, nil, err
}
// like handleAuthResponse, but with less options.
switch packet[0] {
case msgUserAuthBanner:
// TODO: Print banners during userauth.
if err := handleBannerResponse(c, packet); err != nil {
return authFailure, nil, err
}
continue
case msgUserAuthInfoRequest:
// OK
case msgUserAuthFailure:
var msg userAuthFailureMsg
if err := Unmarshal(packet, &msg); err != nil {
return false, nil, err
return authFailure, nil, err
}
return false, msg.Methods, nil
if msg.PartialSuccess {
return authPartialSuccess, msg.Methods, nil
}
return authFailure, msg.Methods, nil
case msgUserAuthSuccess:
return true, nil, nil
return authSuccess, nil, nil
default:
return false, nil, unexpectedMessageError(msgUserAuthInfoRequest, packet[0])
return authFailure, nil, unexpectedMessageError(msgUserAuthInfoRequest, packet[0])
}
var msg userAuthInfoRequestMsg
if err := Unmarshal(packet, &msg); err != nil {
return false, nil, err
return authFailure, nil, err
}
// Manually unpack the prompt/echo pairs.
@ -402,7 +452,7 @@ func (cb KeyboardInteractiveChallenge) auth(session []byte, user string, c packe
for i := 0; i < int(msg.NumPrompts); i++ {
prompt, r, ok := parseString(rest)
if !ok || len(r) == 0 {
return false, nil, errors.New("ssh: prompt format error")
return authFailure, nil, errors.New("ssh: prompt format error")
}
prompts = append(prompts, string(prompt))
echos = append(echos, r[0] != 0)
@ -410,16 +460,16 @@ func (cb KeyboardInteractiveChallenge) auth(session []byte, user string, c packe
}
if len(rest) != 0 {
return false, nil, errors.New("ssh: extra data following keyboard-interactive pairs")
return authFailure, nil, errors.New("ssh: extra data following keyboard-interactive pairs")
}
answers, err := cb(msg.User, msg.Instruction, prompts, echos)
if err != nil {
return false, nil, err
return authFailure, nil, err
}
if len(answers) != len(prompts) {
return false, nil, errors.New("ssh: not enough answers from keyboard-interactive callback")
return authFailure, nil, errors.New("ssh: not enough answers from keyboard-interactive callback")
}
responseLength := 1 + 4
for _, a := range answers {
@ -435,7 +485,41 @@ func (cb KeyboardInteractiveChallenge) auth(session []byte, user string, c packe
}
if err := c.writePacket(serialized); err != nil {
return false, nil, err
return authFailure, nil, err
}
}
}
type retryableAuthMethod struct {
authMethod AuthMethod
maxTries int
}
func (r *retryableAuthMethod) auth(session []byte, user string, c packetConn, rand io.Reader) (ok authResult, methods []string, err error) {
for i := 0; r.maxTries <= 0 || i < r.maxTries; i++ {
ok, methods, err = r.authMethod.auth(session, user, c, rand)
if ok != authFailure || err != nil { // either success, partial success or error terminate
return ok, methods, err
}
}
return ok, methods, err
}
func (r *retryableAuthMethod) method() string {
return r.authMethod.method()
}
// RetryableAuthMethod is a decorator for other auth methods enabling them to
// be retried up to maxTries before considering that AuthMethod itself failed.
// If maxTries is <= 0, will retry indefinitely
//
// This is useful for interactive clients using challenge/response type
// authentication (e.g. Keyboard-Interactive, Password, etc) where the user
// could mistype their response resulting in the server issuing a
// SSH_MSG_USERAUTH_FAILURE (rfc4252 #8 [password] and rfc4256 #3.4
// [keyboard-interactive]); Without this decorator, the non-retryable
// AuthMethod would be removed from future consideration, and never tried again
// (and so the user would never be able to retry their entry).
func RetryableAuthMethod(auth AuthMethod, maxTries int) AuthMethod {
return &retryableAuthMethod{authMethod: auth, maxTries: maxTries}
}

View file

@ -9,6 +9,8 @@ import (
"crypto/rand"
"errors"
"fmt"
"io"
"os"
"strings"
"testing"
)
@ -27,8 +29,14 @@ func (cr keyboardInteractive) Challenge(user string, instruction string, questio
var clientPassword = "tiger"
// tryAuth runs a handshake with a given config against an SSH server
// with config serverConfig
// with config serverConfig. Returns both client and server side errors.
func tryAuth(t *testing.T, config *ClientConfig) error {
err, _ := tryAuthBothSides(t, config)
return err
}
// tryAuthBothSides runs the handshake and returns the resulting errors from both sides of the connection.
func tryAuthBothSides(t *testing.T, config *ClientConfig) (clientError error, serverAuthErrors []error) {
c1, c2, err := netPipe()
if err != nil {
t.Fatalf("netPipe: %v", err)
@ -37,7 +45,7 @@ func tryAuth(t *testing.T, config *ClientConfig) error {
defer c2.Close()
certChecker := CertChecker{
IsAuthority: func(k PublicKey) bool {
IsUserAuthority: func(k PublicKey) bool {
return bytes.Equal(k.Marshal(), testPublicKeys["ecdsa"].Marshal())
},
UserKeyFallback: func(conn ConnMetadata, key PublicKey) (*Permissions, error) {
@ -75,15 +83,16 @@ func tryAuth(t *testing.T, config *ClientConfig) error {
}
return nil, errors.New("keyboard-interactive failed")
},
AuthLogCallback: func(conn ConnMetadata, method string, err error) {
t.Logf("user %q, method %q: %v", conn.User(), method, err)
},
}
serverConfig.AddHostKey(testSigners["rsa"])
serverConfig.AuthLogCallback = func(conn ConnMetadata, method string, err error) {
serverAuthErrors = append(serverAuthErrors, err)
}
go newServer(c1, serverConfig)
_, _, _, err = NewClientConn(c2, "", config)
return err
return err, serverAuthErrors
}
func TestClientAuthPublicKey(t *testing.T) {
@ -92,6 +101,7 @@ func TestClientAuthPublicKey(t *testing.T) {
Auth: []AuthMethod{
PublicKeys(testSigners["rsa"]),
},
HostKeyCallback: InsecureIgnoreHostKey(),
}
if err := tryAuth(t, config); err != nil {
t.Fatalf("unable to dial remote side: %s", err)
@ -104,6 +114,7 @@ func TestAuthMethodPassword(t *testing.T) {
Auth: []AuthMethod{
Password(clientPassword),
},
HostKeyCallback: InsecureIgnoreHostKey(),
}
if err := tryAuth(t, config); err != nil {
@ -123,6 +134,7 @@ func TestAuthMethodFallback(t *testing.T) {
return "WRONG", nil
}),
},
HostKeyCallback: InsecureIgnoreHostKey(),
}
if err := tryAuth(t, config); err != nil {
@ -141,6 +153,7 @@ func TestAuthMethodWrongPassword(t *testing.T) {
Password("wrong"),
PublicKeys(testSigners["rsa"]),
},
HostKeyCallback: InsecureIgnoreHostKey(),
}
if err := tryAuth(t, config); err != nil {
@ -158,6 +171,7 @@ func TestAuthMethodKeyboardInteractive(t *testing.T) {
Auth: []AuthMethod{
KeyboardInteractive(answers.Challenge),
},
HostKeyCallback: InsecureIgnoreHostKey(),
}
if err := tryAuth(t, config); err != nil {
@ -203,12 +217,52 @@ func TestAuthMethodRSAandDSA(t *testing.T) {
Auth: []AuthMethod{
PublicKeys(testSigners["dsa"], testSigners["rsa"]),
},
HostKeyCallback: InsecureIgnoreHostKey(),
}
if err := tryAuth(t, config); err != nil {
t.Fatalf("client could not authenticate with rsa key: %v", err)
}
}
type invalidAlgSigner struct {
Signer
}
func (s *invalidAlgSigner) Sign(rand io.Reader, data []byte) (*Signature, error) {
sig, err := s.Signer.Sign(rand, data)
if sig != nil {
sig.Format = "invalid"
}
return sig, err
}
func TestMethodInvalidAlgorithm(t *testing.T) {
config := &ClientConfig{
User: "testuser",
Auth: []AuthMethod{
PublicKeys(&invalidAlgSigner{testSigners["rsa"]}),
},
HostKeyCallback: InsecureIgnoreHostKey(),
}
err, serverErrors := tryAuthBothSides(t, config)
if err == nil {
t.Fatalf("login succeeded")
}
found := false
want := "algorithm \"invalid\""
var errStrings []string
for _, err := range serverErrors {
found = found || (err != nil && strings.Contains(err.Error(), want))
errStrings = append(errStrings, err.Error())
}
if !found {
t.Errorf("server got error %q, want substring %q", errStrings, want)
}
}
func TestClientHMAC(t *testing.T) {
for _, mac := range supportedMACs {
config := &ClientConfig{
@ -219,6 +273,7 @@ func TestClientHMAC(t *testing.T) {
Config: Config{
MACs: []string{mac},
},
HostKeyCallback: InsecureIgnoreHostKey(),
}
if err := tryAuth(t, config); err != nil {
t.Fatalf("client could not authenticate with mac algo %s: %v", mac, err)
@ -243,6 +298,9 @@ func TestClientUnsupportedCipher(t *testing.T) {
}
func TestClientUnsupportedKex(t *testing.T) {
if os.Getenv("GO_BUILDER_NAME") != "" {
t.Skip("skipping known-flaky test on the Go build dashboard; see golang.org/issue/15198")
}
config := &ClientConfig{
User: "testuser",
Auth: []AuthMethod{
@ -251,9 +309,10 @@ func TestClientUnsupportedKex(t *testing.T) {
Config: Config{
KeyExchanges: []string{"diffie-hellman-group-exchange-sha256"}, // not currently supported
},
HostKeyCallback: InsecureIgnoreHostKey(),
}
if err := tryAuth(t, config); err == nil || !strings.Contains(err.Error(), "no common algorithms") {
t.Errorf("got %v, expected 'no common algorithms'", err)
if err := tryAuth(t, config); err == nil || !strings.Contains(err.Error(), "common algorithm") {
t.Errorf("got %v, expected 'common algorithm'", err)
}
}
@ -270,22 +329,23 @@ func TestClientLoginCert(t *testing.T) {
}
clientConfig := &ClientConfig{
User: "user",
User: "user",
HostKeyCallback: InsecureIgnoreHostKey(),
}
clientConfig.Auth = append(clientConfig.Auth, PublicKeys(certSigner))
t.Log("should succeed")
// should succeed
if err := tryAuth(t, clientConfig); err != nil {
t.Errorf("cert login failed: %v", err)
}
t.Log("corrupted signature")
// corrupted signature
cert.Signature.Blob[0]++
if err := tryAuth(t, clientConfig); err == nil {
t.Errorf("cert login passed with corrupted sig")
}
t.Log("revoked")
// revoked
cert.Serial = 666
cert.SignCert(rand.Reader, testSigners["ecdsa"])
if err := tryAuth(t, clientConfig); err == nil {
@ -293,13 +353,13 @@ func TestClientLoginCert(t *testing.T) {
}
cert.Serial = 1
t.Log("sign with wrong key")
// sign with wrong key
cert.SignCert(rand.Reader, testSigners["dsa"])
if err := tryAuth(t, clientConfig); err == nil {
t.Errorf("cert login passed with non-authoritive key")
t.Errorf("cert login passed with non-authoritative key")
}
t.Log("host cert")
// host cert
cert.CertType = HostCert
cert.SignCert(rand.Reader, testSigners["ecdsa"])
if err := tryAuth(t, clientConfig); err == nil {
@ -307,14 +367,14 @@ func TestClientLoginCert(t *testing.T) {
}
cert.CertType = UserCert
t.Log("principal specified")
// principal specified
cert.ValidPrincipals = []string{"user"}
cert.SignCert(rand.Reader, testSigners["ecdsa"])
if err := tryAuth(t, clientConfig); err != nil {
t.Errorf("cert login failed: %v", err)
}
t.Log("wrong principal specified")
// wrong principal specified
cert.ValidPrincipals = []string{"fred"}
cert.SignCert(rand.Reader, testSigners["ecdsa"])
if err := tryAuth(t, clientConfig); err == nil {
@ -322,22 +382,22 @@ func TestClientLoginCert(t *testing.T) {
}
cert.ValidPrincipals = nil
t.Log("added critical option")
// added critical option
cert.CriticalOptions = map[string]string{"root-access": "yes"}
cert.SignCert(rand.Reader, testSigners["ecdsa"])
if err := tryAuth(t, clientConfig); err == nil {
t.Errorf("cert login passed with unrecognized critical option")
}
t.Log("allowed source address")
cert.CriticalOptions = map[string]string{"source-address": "127.0.0.42/24"}
// allowed source address
cert.CriticalOptions = map[string]string{"source-address": "127.0.0.42/24,::42/120"}
cert.SignCert(rand.Reader, testSigners["ecdsa"])
if err := tryAuth(t, clientConfig); err != nil {
t.Errorf("cert login with source-address failed: %v", err)
}
t.Log("disallowed source address")
cert.CriticalOptions = map[string]string{"source-address": "127.0.0.42"}
// disallowed source address
cert.CriticalOptions = map[string]string{"source-address": "127.0.0.42,::42"}
cert.SignCert(rand.Reader, testSigners["ecdsa"])
if err := tryAuth(t, clientConfig); err == nil {
t.Errorf("cert login with source-address succeeded")
@ -349,9 +409,8 @@ func testPermissionsPassing(withPermissions bool, t *testing.T) {
PublicKeyCallback: func(conn ConnMetadata, key PublicKey) (*Permissions, error) {
if conn.User() == "nopermissions" {
return nil, nil
} else {
return &Permissions{}, nil
}
return &Permissions{}, nil
},
}
serverConfig.AddHostKey(testSigners["rsa"])
@ -360,6 +419,7 @@ func testPermissionsPassing(withPermissions bool, t *testing.T) {
Auth: []AuthMethod{
PublicKeys(testSigners["rsa"]),
},
HostKeyCallback: InsecureIgnoreHostKey(),
}
if withPermissions {
clientConfig.User = "permissions"
@ -391,3 +451,228 @@ func TestPermissionsPassing(t *testing.T) {
func TestNoPermissionsPassing(t *testing.T) {
testPermissionsPassing(false, t)
}
func TestRetryableAuth(t *testing.T) {
n := 0
passwords := []string{"WRONG1", "WRONG2"}
config := &ClientConfig{
User: "testuser",
Auth: []AuthMethod{
RetryableAuthMethod(PasswordCallback(func() (string, error) {
p := passwords[n]
n++
return p, nil
}), 2),
PublicKeys(testSigners["rsa"]),
},
HostKeyCallback: InsecureIgnoreHostKey(),
}
if err := tryAuth(t, config); err != nil {
t.Fatalf("unable to dial remote side: %s", err)
}
if n != 2 {
t.Fatalf("Did not try all passwords")
}
}
func ExampleRetryableAuthMethod(t *testing.T) {
user := "testuser"
NumberOfPrompts := 3
// Normally this would be a callback that prompts the user to answer the
// provided questions
Cb := func(user, instruction string, questions []string, echos []bool) (answers []string, err error) {
return []string{"answer1", "answer2"}, nil
}
config := &ClientConfig{
HostKeyCallback: InsecureIgnoreHostKey(),
User: user,
Auth: []AuthMethod{
RetryableAuthMethod(KeyboardInteractiveChallenge(Cb), NumberOfPrompts),
},
}
if err := tryAuth(t, config); err != nil {
t.Fatalf("unable to dial remote side: %s", err)
}
}
// Test if username is received on server side when NoClientAuth is used
func TestClientAuthNone(t *testing.T) {
user := "testuser"
serverConfig := &ServerConfig{
NoClientAuth: true,
}
serverConfig.AddHostKey(testSigners["rsa"])
clientConfig := &ClientConfig{
User: user,
HostKeyCallback: InsecureIgnoreHostKey(),
}
c1, c2, err := netPipe()
if err != nil {
t.Fatalf("netPipe: %v", err)
}
defer c1.Close()
defer c2.Close()
go NewClientConn(c2, "", clientConfig)
serverConn, err := newServer(c1, serverConfig)
if err != nil {
t.Fatalf("newServer: %v", err)
}
if serverConn.User() != user {
t.Fatalf("server: got %q, want %q", serverConn.User(), user)
}
}
// Test if authentication attempts are limited on server when MaxAuthTries is set
func TestClientAuthMaxAuthTries(t *testing.T) {
user := "testuser"
serverConfig := &ServerConfig{
MaxAuthTries: 2,
PasswordCallback: func(conn ConnMetadata, pass []byte) (*Permissions, error) {
if conn.User() == "testuser" && string(pass) == "right" {
return nil, nil
}
return nil, errors.New("password auth failed")
},
}
serverConfig.AddHostKey(testSigners["rsa"])
expectedErr := fmt.Errorf("ssh: handshake failed: %v", &disconnectMsg{
Reason: 2,
Message: "too many authentication failures",
})
for tries := 2; tries < 4; tries++ {
n := tries
clientConfig := &ClientConfig{
User: user,
Auth: []AuthMethod{
RetryableAuthMethod(PasswordCallback(func() (string, error) {
n--
if n == 0 {
return "right", nil
}
return "wrong", nil
}), tries),
},
HostKeyCallback: InsecureIgnoreHostKey(),
}
c1, c2, err := netPipe()
if err != nil {
t.Fatalf("netPipe: %v", err)
}
defer c1.Close()
defer c2.Close()
go newServer(c1, serverConfig)
_, _, _, err = NewClientConn(c2, "", clientConfig)
if tries > 2 {
if err == nil {
t.Fatalf("client: got no error, want %s", expectedErr)
} else if err.Error() != expectedErr.Error() {
t.Fatalf("client: got %s, want %s", err, expectedErr)
}
} else {
if err != nil {
t.Fatalf("client: got %s, want no error", err)
}
}
}
}
// Test if authentication attempts are correctly limited on server
// when more public keys are provided then MaxAuthTries
func TestClientAuthMaxAuthTriesPublicKey(t *testing.T) {
signers := []Signer{}
for i := 0; i < 6; i++ {
signers = append(signers, testSigners["dsa"])
}
validConfig := &ClientConfig{
User: "testuser",
Auth: []AuthMethod{
PublicKeys(append([]Signer{testSigners["rsa"]}, signers...)...),
},
HostKeyCallback: InsecureIgnoreHostKey(),
}
if err := tryAuth(t, validConfig); err != nil {
t.Fatalf("unable to dial remote side: %s", err)
}
expectedErr := fmt.Errorf("ssh: handshake failed: %v", &disconnectMsg{
Reason: 2,
Message: "too many authentication failures",
})
invalidConfig := &ClientConfig{
User: "testuser",
Auth: []AuthMethod{
PublicKeys(append(signers, testSigners["rsa"])...),
},
HostKeyCallback: InsecureIgnoreHostKey(),
}
if err := tryAuth(t, invalidConfig); err == nil {
t.Fatalf("client: got no error, want %s", expectedErr)
} else if err.Error() != expectedErr.Error() {
t.Fatalf("client: got %s, want %s", err, expectedErr)
}
}
// Test whether authentication errors are being properly logged if all
// authentication methods have been exhausted
func TestClientAuthErrorList(t *testing.T) {
publicKeyErr := errors.New("This is an error from PublicKeyCallback")
clientConfig := &ClientConfig{
Auth: []AuthMethod{
PublicKeys(testSigners["rsa"]),
},
HostKeyCallback: InsecureIgnoreHostKey(),
}
serverConfig := &ServerConfig{
PublicKeyCallback: func(_ ConnMetadata, _ PublicKey) (*Permissions, error) {
return nil, publicKeyErr
},
}
serverConfig.AddHostKey(testSigners["rsa"])
c1, c2, err := netPipe()
if err != nil {
t.Fatalf("netPipe: %v", err)
}
defer c1.Close()
defer c2.Close()
go NewClientConn(c2, "", clientConfig)
_, err = newServer(c1, serverConfig)
if err == nil {
t.Fatal("newServer: got nil, expected errors")
}
authErrs, ok := err.(*ServerAuthError)
if !ok {
t.Fatalf("errors: got %T, want *ssh.ServerAuthError", err)
}
for i, e := range authErrs.Errors {
switch i {
case 0:
if e != ErrNoAuth {
t.Fatalf("errors: got error %v, want ErrNoAuth", e)
}
case 1:
if e != publicKeyErr {
t.Fatalf("errors: got %v, want %v", e, publicKeyErr)
}
default:
t.Fatalf("errors: got %v, expected 2 errors", authErrs.Errors)
}
}
}

View file

@ -5,35 +5,162 @@
package ssh
import (
"net"
"strings"
"testing"
)
func testClientVersion(t *testing.T, config *ClientConfig, expected string) {
clientConn, serverConn := net.Pipe()
defer clientConn.Close()
receivedVersion := make(chan string, 1)
go func() {
version, err := readVersion(serverConn)
if err != nil {
receivedVersion <- ""
} else {
receivedVersion <- string(version)
}
serverConn.Close()
}()
NewClientConn(clientConn, "", config)
actual := <-receivedVersion
if actual != expected {
t.Fatalf("got %s; want %s", actual, expected)
func TestClientVersion(t *testing.T) {
for _, tt := range []struct {
name string
version string
multiLine string
wantErr bool
}{
{
name: "default version",
version: packageVersion,
},
{
name: "custom version",
version: "SSH-2.0-CustomClientVersionString",
},
{
name: "good multi line version",
version: packageVersion,
multiLine: strings.Repeat("ignored\r\n", 20),
},
{
name: "bad multi line version",
version: packageVersion,
multiLine: "bad multi line version",
wantErr: true,
},
{
name: "long multi line version",
version: packageVersion,
multiLine: strings.Repeat("long multi line version\r\n", 50)[:256],
wantErr: true,
},
} {
t.Run(tt.name, func(t *testing.T) {
c1, c2, err := netPipe()
if err != nil {
t.Fatalf("netPipe: %v", err)
}
defer c1.Close()
defer c2.Close()
go func() {
if tt.multiLine != "" {
c1.Write([]byte(tt.multiLine))
}
NewClientConn(c1, "", &ClientConfig{
ClientVersion: tt.version,
HostKeyCallback: InsecureIgnoreHostKey(),
})
c1.Close()
}()
conf := &ServerConfig{NoClientAuth: true}
conf.AddHostKey(testSigners["rsa"])
conn, _, _, err := NewServerConn(c2, conf)
if err == nil == tt.wantErr {
t.Fatalf("got err %v; wantErr %t", err, tt.wantErr)
}
if tt.wantErr {
// Don't verify the version on an expected error.
return
}
if got := string(conn.ClientVersion()); got != tt.version {
t.Fatalf("got %q; want %q", got, tt.version)
}
})
}
}
func TestCustomClientVersion(t *testing.T) {
version := "Test-Client-Version-0.0"
testClientVersion(t, &ClientConfig{ClientVersion: version}, version)
func TestHostKeyCheck(t *testing.T) {
for _, tt := range []struct {
name string
wantError string
key PublicKey
}{
{"no callback", "must specify HostKeyCallback", nil},
{"correct key", "", testSigners["rsa"].PublicKey()},
{"mismatch", "mismatch", testSigners["ecdsa"].PublicKey()},
} {
c1, c2, err := netPipe()
if err != nil {
t.Fatalf("netPipe: %v", err)
}
defer c1.Close()
defer c2.Close()
serverConf := &ServerConfig{
NoClientAuth: true,
}
serverConf.AddHostKey(testSigners["rsa"])
go NewServerConn(c1, serverConf)
clientConf := ClientConfig{
User: "user",
}
if tt.key != nil {
clientConf.HostKeyCallback = FixedHostKey(tt.key)
}
_, _, _, err = NewClientConn(c2, "", &clientConf)
if err != nil {
if tt.wantError == "" || !strings.Contains(err.Error(), tt.wantError) {
t.Errorf("%s: got error %q, missing %q", tt.name, err.Error(), tt.wantError)
}
} else if tt.wantError != "" {
t.Errorf("%s: succeeded, but want error string %q", tt.name, tt.wantError)
}
}
}
func TestDefaultClientVersion(t *testing.T) {
testClientVersion(t, &ClientConfig{}, packageVersion)
func TestBannerCallback(t *testing.T) {
c1, c2, err := netPipe()
if err != nil {
t.Fatalf("netPipe: %v", err)
}
defer c1.Close()
defer c2.Close()
serverConf := &ServerConfig{
PasswordCallback: func(conn ConnMetadata, password []byte) (*Permissions, error) {
return &Permissions{}, nil
},
BannerCallback: func(conn ConnMetadata) string {
return "Hello World"
},
}
serverConf.AddHostKey(testSigners["rsa"])
go NewServerConn(c1, serverConf)
var receivedBanner string
var bannerCount int
clientConf := ClientConfig{
Auth: []AuthMethod{
Password("123"),
},
User: "user",
HostKeyCallback: InsecureIgnoreHostKey(),
BannerCallback: func(message string) error {
bannerCount++
receivedBanner = message
return nil
},
}
_, _, _, err = NewClientConn(c2, "", &clientConf)
if err != nil {
t.Fatal(err)
}
if bannerCount != 1 {
t.Errorf("got %d banners; want 1", bannerCount)
}
expected := "Hello World"
if receivedBanner != expected {
t.Fatalf("got %s; want %s", receivedBanner, expected)
}
}

View file

@ -9,6 +9,7 @@ import (
"crypto/rand"
"fmt"
"io"
"math"
"sync"
_ "crypto/sha1"
@ -23,37 +24,50 @@ const (
serviceSSH = "ssh-connection"
)
// supportedCiphers specifies the supported ciphers in preference order.
// supportedCiphers lists ciphers we support but might not recommend.
var supportedCiphers = []string{
"aes128-ctr", "aes192-ctr", "aes256-ctr",
"aes128-gcm@openssh.com",
"arcfour256", "arcfour128",
chacha20Poly1305ID,
"arcfour256", "arcfour128", "arcfour",
aes128cbcID,
tripledescbcID,
}
// preferredCiphers specifies the default preference for ciphers.
var preferredCiphers = []string{
"aes128-gcm@openssh.com",
chacha20Poly1305ID,
"aes128-ctr", "aes192-ctr", "aes256-ctr",
}
// supportedKexAlgos specifies the supported key-exchange algorithms in
// preference order.
var supportedKexAlgos = []string{
kexAlgoCurve25519SHA256,
// P384 and P521 are not constant-time yet, but since we don't
// reuse ephemeral keys, using them for ECDH should be OK.
kexAlgoECDH256, kexAlgoECDH384, kexAlgoECDH521,
kexAlgoDH14SHA1, kexAlgoDH1SHA1,
}
// supportedKexAlgos specifies the supported host-key algorithms (i.e. methods
// supportedHostKeyAlgos specifies the supported host-key algorithms (i.e. methods
// of authenticating servers) in preference order.
var supportedHostKeyAlgos = []string{
CertAlgoRSAv01, CertAlgoDSAv01, CertAlgoECDSA256v01,
CertAlgoECDSA384v01, CertAlgoECDSA521v01,
CertAlgoECDSA384v01, CertAlgoECDSA521v01, CertAlgoED25519v01,
KeyAlgoECDSA256, KeyAlgoECDSA384, KeyAlgoECDSA521,
KeyAlgoRSA, KeyAlgoDSA,
KeyAlgoED25519,
}
// supportedMACs specifies a default set of MAC algorithms in preference order.
// This is based on RFC 4253, section 6.4, but with hmac-md5 variants removed
// because they have reached the end of their useful life.
var supportedMACs = []string{
"hmac-sha1", "hmac-sha1-96",
"hmac-sha2-256-etm@openssh.com", "hmac-sha2-256", "hmac-sha1", "hmac-sha1-96",
}
var supportedCompressions = []string{compressionNone}
@ -84,27 +98,15 @@ func parseError(tag uint8) error {
return fmt.Errorf("ssh: parse error in message type %d", tag)
}
func findCommonAlgorithm(clientAlgos []string, serverAlgos []string) (commonAlgo string, ok bool) {
for _, clientAlgo := range clientAlgos {
for _, serverAlgo := range serverAlgos {
if clientAlgo == serverAlgo {
return clientAlgo, true
func findCommon(what string, client []string, server []string) (common string, err error) {
for _, c := range client {
for _, s := range server {
if c == s {
return c, nil
}
}
}
return
}
func findCommonCipher(clientCiphers []string, serverCiphers []string) (commonCipher string, ok bool) {
for _, clientCipher := range clientCiphers {
for _, serverCipher := range serverCiphers {
// reject the cipher if we have no cipherModes definition
if clientCipher == serverCipher && cipherModes[clientCipher] != nil {
return clientCipher, true
}
}
}
return
return "", fmt.Errorf("ssh: no common algorithm for %s; client offered: %v, server offered: %v", what, client, server)
}
type directionAlgorithms struct {
@ -113,6 +115,21 @@ type directionAlgorithms struct {
Compression string
}
// rekeyBytes returns a rekeying intervals in bytes.
func (a *directionAlgorithms) rekeyBytes() int64 {
// According to RFC4344 block ciphers should rekey after
// 2^(BLOCKSIZE/4) blocks. For all AES flavors BLOCKSIZE is
// 128.
switch a.Cipher {
case "aes128-ctr", "aes192-ctr", "aes256-ctr", gcmCipherID, aes128cbcID:
return 16 * (1 << 32)
}
// For others, stick with RFC4253 recommendation to rekey after 1 Gb of data.
return 1 << 30
}
type algorithms struct {
kex string
hostKey string
@ -120,50 +137,50 @@ type algorithms struct {
r directionAlgorithms
}
func findAgreedAlgorithms(clientKexInit, serverKexInit *kexInitMsg) (algs *algorithms) {
var ok bool
func findAgreedAlgorithms(clientKexInit, serverKexInit *kexInitMsg) (algs *algorithms, err error) {
result := &algorithms{}
result.kex, ok = findCommonAlgorithm(clientKexInit.KexAlgos, serverKexInit.KexAlgos)
if !ok {
result.kex, err = findCommon("key exchange", clientKexInit.KexAlgos, serverKexInit.KexAlgos)
if err != nil {
return
}
result.hostKey, ok = findCommonAlgorithm(clientKexInit.ServerHostKeyAlgos, serverKexInit.ServerHostKeyAlgos)
if !ok {
result.hostKey, err = findCommon("host key", clientKexInit.ServerHostKeyAlgos, serverKexInit.ServerHostKeyAlgos)
if err != nil {
return
}
result.w.Cipher, ok = findCommonCipher(clientKexInit.CiphersClientServer, serverKexInit.CiphersClientServer)
if !ok {
result.w.Cipher, err = findCommon("client to server cipher", clientKexInit.CiphersClientServer, serverKexInit.CiphersClientServer)
if err != nil {
return
}
result.r.Cipher, ok = findCommonCipher(clientKexInit.CiphersServerClient, serverKexInit.CiphersServerClient)
if !ok {
result.r.Cipher, err = findCommon("server to client cipher", clientKexInit.CiphersServerClient, serverKexInit.CiphersServerClient)
if err != nil {
return
}
result.w.MAC, ok = findCommonAlgorithm(clientKexInit.MACsClientServer, serverKexInit.MACsClientServer)
if !ok {
result.w.MAC, err = findCommon("client to server MAC", clientKexInit.MACsClientServer, serverKexInit.MACsClientServer)
if err != nil {
return
}
result.r.MAC, ok = findCommonAlgorithm(clientKexInit.MACsServerClient, serverKexInit.MACsServerClient)
if !ok {
result.r.MAC, err = findCommon("server to client MAC", clientKexInit.MACsServerClient, serverKexInit.MACsServerClient)
if err != nil {
return
}
result.w.Compression, ok = findCommonAlgorithm(clientKexInit.CompressionClientServer, serverKexInit.CompressionClientServer)
if !ok {
result.w.Compression, err = findCommon("client to server compression", clientKexInit.CompressionClientServer, serverKexInit.CompressionClientServer)
if err != nil {
return
}
result.r.Compression, ok = findCommonAlgorithm(clientKexInit.CompressionServerClient, serverKexInit.CompressionServerClient)
if !ok {
result.r.Compression, err = findCommon("server to client compression", clientKexInit.CompressionServerClient, serverKexInit.CompressionServerClient)
if err != nil {
return
}
return result
return result, nil
}
// If rekeythreshold is too small, we can't make any progress sending
@ -180,7 +197,7 @@ type Config struct {
// The maximum number of bytes sent or received after which a
// new key is negotiated. It must be at least 256. If
// unspecified, 1 gigabyte is used.
// unspecified, a size suitable for the chosen cipher is used.
RekeyThreshold uint64
// The allowed key exchanges algorithms. If unspecified then a
@ -204,8 +221,16 @@ func (c *Config) SetDefaults() {
c.Rand = rand.Reader
}
if c.Ciphers == nil {
c.Ciphers = supportedCiphers
c.Ciphers = preferredCiphers
}
var ciphers []string
for _, c := range c.Ciphers {
if cipherModes[c] != nil {
// reject the cipher if we have no cipherModes definition
ciphers = append(ciphers, c)
}
}
c.Ciphers = ciphers
if c.KeyExchanges == nil {
c.KeyExchanges = supportedKexAlgos
@ -216,17 +241,18 @@ func (c *Config) SetDefaults() {
}
if c.RekeyThreshold == 0 {
// RFC 4253, section 9 suggests rekeying after 1G.
c.RekeyThreshold = 1 << 30
}
if c.RekeyThreshold < minRekeyThreshold {
// cipher specific default
} else if c.RekeyThreshold < minRekeyThreshold {
c.RekeyThreshold = minRekeyThreshold
} else if c.RekeyThreshold >= math.MaxInt64 {
// Avoid weirdness if somebody uses -1 as a threshold.
c.RekeyThreshold = math.MaxInt64
}
}
// buildDataSignedForAuth returns the data that is signed in order to prove
// possession of a private key. See RFC 4252, section 7.
func buildDataSignedForAuth(sessionId []byte, req userAuthRequestMsg, algo, pubKey []byte) []byte {
func buildDataSignedForAuth(sessionID []byte, req userAuthRequestMsg, algo, pubKey []byte) []byte {
data := struct {
Session []byte
Type byte
@ -237,7 +263,7 @@ func buildDataSignedForAuth(sessionId []byte, req userAuthRequestMsg, algo, pubK
Algo []byte
PubKey []byte
}{
sessionId,
sessionID,
msgUserAuthRequest,
req.User,
req.Service,

View file

@ -23,17 +23,16 @@ func (e *OpenChannelError) Error() string {
// ConnMetadata holds metadata for the connection.
type ConnMetadata interface {
// User returns the user ID for this connection.
// It is empty if no authentication is used.
User() string
// SessionID returns the sesson hash, also denoted by H.
// SessionID returns the session hash, also denoted by H.
SessionID() []byte
// ClientVersion returns the client's version string as hashed
// into the session ID.
ClientVersion() []byte
// ServerVersion returns the client's version string as hashed
// ServerVersion returns the server's version string as hashed
// into the session ID.
ServerVersion() []byte

View file

@ -12,7 +12,10 @@ the multiplexed nature of SSH is exposed to users that wish to support
others.
References:
[PROTOCOL.certkeys]: http://www.openbsd.org/cgi-bin/cvsweb/src/usr.bin/ssh/PROTOCOL.certkeys
[PROTOCOL.certkeys]: http://cvsweb.openbsd.org/cgi-bin/cvsweb/src/usr.bin/ssh/PROTOCOL.certkeys?rev=HEAD
[SSH-PARAMETERS]: http://www.iana.org/assignments/ssh-parameters/ssh-parameters.xml#ssh-parameters-1
This package does not fall under the stability promise of the Go language itself,
so its API may be changed when pressing needs arise.
*/
package ssh // import "golang.org/x/crypto/ssh"

View file

@ -5,21 +5,45 @@
package ssh_test
import (
"bufio"
"bytes"
"fmt"
"io/ioutil"
"log"
"net"
"net/http"
"os"
"path/filepath"
"strings"
"golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/terminal"
)
func ExampleNewServerConn() {
// Public key authentication is done by comparing
// the public key of a received connection
// with the entries in the authorized_keys file.
authorizedKeysBytes, err := ioutil.ReadFile("authorized_keys")
if err != nil {
log.Fatalf("Failed to load authorized_keys, err: %v", err)
}
authorizedKeysMap := map[string]bool{}
for len(authorizedKeysBytes) > 0 {
pubKey, _, _, rest, err := ssh.ParseAuthorizedKey(authorizedKeysBytes)
if err != nil {
log.Fatal(err)
}
authorizedKeysMap[string(pubKey.Marshal())] = true
authorizedKeysBytes = rest
}
// An SSH server is represented by a ServerConfig, which holds
// certificate details and handles authentication of ServerConns.
config := &ssh.ServerConfig{
// Remove to disable password auth.
PasswordCallback: func(c ssh.ConnMetadata, pass []byte) (*ssh.Permissions, error) {
// Should use constant-time compare (or better, salt+hash) in
// a production setting.
@ -28,16 +52,29 @@ func ExampleNewServerConn() {
}
return nil, fmt.Errorf("password rejected for %q", c.User())
},
// Remove to disable public key auth.
PublicKeyCallback: func(c ssh.ConnMetadata, pubKey ssh.PublicKey) (*ssh.Permissions, error) {
if authorizedKeysMap[string(pubKey.Marshal())] {
return &ssh.Permissions{
// Record the public key used for authentication.
Extensions: map[string]string{
"pubkey-fp": ssh.FingerprintSHA256(pubKey),
},
}, nil
}
return nil, fmt.Errorf("unknown public key for %q", c.User())
},
}
privateBytes, err := ioutil.ReadFile("id_rsa")
if err != nil {
panic("Failed to load private key")
log.Fatal("Failed to load private key: ", err)
}
private, err := ssh.ParsePrivateKey(privateBytes)
if err != nil {
panic("Failed to parse private key")
log.Fatal("Failed to parse private key: ", err)
}
config.AddHostKey(private)
@ -46,19 +83,21 @@ func ExampleNewServerConn() {
// accepted.
listener, err := net.Listen("tcp", "0.0.0.0:2022")
if err != nil {
panic("failed to listen for connection")
log.Fatal("failed to listen for connection: ", err)
}
nConn, err := listener.Accept()
if err != nil {
panic("failed to accept incoming connection")
log.Fatal("failed to accept incoming connection: ", err)
}
// Before use, a handshake must be performed on the incoming
// net.Conn.
_, chans, reqs, err := ssh.NewServerConn(nConn, config)
conn, chans, reqs, err := ssh.NewServerConn(nConn, config)
if err != nil {
panic("failed to handshake")
log.Fatal("failed to handshake: ", err)
}
log.Printf("logged in with key %s", conn.Permissions.Extensions["pubkey-fp"])
// The incoming Request channel must be serviced.
go ssh.DiscardRequests(reqs)
@ -74,7 +113,7 @@ func ExampleNewServerConn() {
}
channel, requests, err := newChannel.Accept()
if err != nil {
panic("could not accept channel.")
log.Fatalf("Could not accept channel: %v", err)
}
// Sessions have out-of-band requests such as "shell",
@ -82,18 +121,7 @@ func ExampleNewServerConn() {
// "shell" request.
go func(in <-chan *ssh.Request) {
for req := range in {
ok := false
switch req.Type {
case "shell":
ok = true
if len(req.Payload) > 0 {
// We don't accept any
// commands, only the
// default shell.
ok = false
}
}
req.Reply(ok, nil)
req.Reply(req.Type == "shell", nil)
}
}(requests)
@ -112,28 +140,70 @@ func ExampleNewServerConn() {
}
}
func ExampleHostKeyCheck() {
// Every client must provide a host key check. Here is a
// simple-minded parse of OpenSSH's known_hosts file
host := "hostname"
file, err := os.Open(filepath.Join(os.Getenv("HOME"), ".ssh", "known_hosts"))
if err != nil {
log.Fatal(err)
}
defer file.Close()
scanner := bufio.NewScanner(file)
var hostKey ssh.PublicKey
for scanner.Scan() {
fields := strings.Split(scanner.Text(), " ")
if len(fields) != 3 {
continue
}
if strings.Contains(fields[0], host) {
var err error
hostKey, _, _, _, err = ssh.ParseAuthorizedKey(scanner.Bytes())
if err != nil {
log.Fatalf("error parsing %q: %v", fields[2], err)
}
break
}
}
if hostKey == nil {
log.Fatalf("no hostkey for %s", host)
}
config := ssh.ClientConfig{
User: os.Getenv("USER"),
HostKeyCallback: ssh.FixedHostKey(hostKey),
}
_, err = ssh.Dial("tcp", host+":22", &config)
log.Println(err)
}
func ExampleDial() {
// An SSH client is represented with a ClientConn. Currently only
// the "password" authentication method is supported.
var hostKey ssh.PublicKey
// An SSH client is represented with a ClientConn.
//
// To authenticate with the remote server you must pass at least one
// implementation of AuthMethod via the Auth field in ClientConfig.
// implementation of AuthMethod via the Auth field in ClientConfig,
// and provide a HostKeyCallback.
config := &ssh.ClientConfig{
User: "username",
Auth: []ssh.AuthMethod{
ssh.Password("yourpassword"),
},
HostKeyCallback: ssh.FixedHostKey(hostKey),
}
client, err := ssh.Dial("tcp", "yourserver.com:22", config)
if err != nil {
panic("Failed to dial: " + err.Error())
log.Fatal("Failed to dial: ", err)
}
// Each ClientConn can support multiple interactive sessions,
// represented by a Session.
session, err := client.NewSession()
if err != nil {
panic("Failed to create session: " + err.Error())
log.Fatal("Failed to create session: ", err)
}
defer session.Close()
@ -142,29 +212,66 @@ func ExampleDial() {
var b bytes.Buffer
session.Stdout = &b
if err := session.Run("/usr/bin/whoami"); err != nil {
panic("Failed to run: " + err.Error())
log.Fatal("Failed to run: " + err.Error())
}
fmt.Println(b.String())
}
func ExamplePublicKeys() {
var hostKey ssh.PublicKey
// A public key may be used to authenticate against the remote
// server by using an unencrypted PEM-encoded private key file.
//
// If you have an encrypted private key, the crypto/x509 package
// can be used to decrypt it.
key, err := ioutil.ReadFile("/home/user/.ssh/id_rsa")
if err != nil {
log.Fatalf("unable to read private key: %v", err)
}
// Create the Signer for this private key.
signer, err := ssh.ParsePrivateKey(key)
if err != nil {
log.Fatalf("unable to parse private key: %v", err)
}
config := &ssh.ClientConfig{
User: "user",
Auth: []ssh.AuthMethod{
// Use the PublicKeys method for remote authentication.
ssh.PublicKeys(signer),
},
HostKeyCallback: ssh.FixedHostKey(hostKey),
}
// Connect to the remote server and perform the SSH handshake.
client, err := ssh.Dial("tcp", "host.com:22", config)
if err != nil {
log.Fatalf("unable to connect: %v", err)
}
defer client.Close()
}
func ExampleClient_Listen() {
var hostKey ssh.PublicKey
config := &ssh.ClientConfig{
User: "username",
Auth: []ssh.AuthMethod{
ssh.Password("password"),
},
HostKeyCallback: ssh.FixedHostKey(hostKey),
}
// Dial your ssh server.
conn, err := ssh.Dial("tcp", "localhost:22", config)
if err != nil {
log.Fatalf("unable to connect: %s", err)
log.Fatal("unable to connect: ", err)
}
defer conn.Close()
// Request the remote side to open port 8080 on all interfaces.
l, err := conn.Listen("tcp", "0.0.0.0:8080")
if err != nil {
log.Fatalf("unable to register tcp forward: %v", err)
log.Fatal("unable to register tcp forward: ", err)
}
defer l.Close()
@ -175,23 +282,25 @@ func ExampleClient_Listen() {
}
func ExampleSession_RequestPty() {
var hostKey ssh.PublicKey
// Create client config
config := &ssh.ClientConfig{
User: "username",
Auth: []ssh.AuthMethod{
ssh.Password("password"),
},
HostKeyCallback: ssh.FixedHostKey(hostKey),
}
// Connect to ssh server
conn, err := ssh.Dial("tcp", "localhost:22", config)
if err != nil {
log.Fatalf("unable to connect: %s", err)
log.Fatal("unable to connect: ", err)
}
defer conn.Close()
// Create a session
session, err := conn.NewSession()
if err != nil {
log.Fatalf("unable to create session: %s", err)
log.Fatal("unable to create session: ", err)
}
defer session.Close()
// Set up terminal modes
@ -201,11 +310,11 @@ func ExampleSession_RequestPty() {
ssh.TTY_OP_OSPEED: 14400, // output speed = 14.4kbaud
}
// Request pseudo terminal
if err := session.RequestPty("xterm", 80, 40, modes); err != nil {
log.Fatalf("request for pseudo terminal failed: %s", err)
if err := session.RequestPty("xterm", 40, 80, modes); err != nil {
log.Fatal("request for pseudo terminal failed: ", err)
}
// Start remote shell
if err := session.Shell(); err != nil {
log.Fatalf("failed to start shell: %s", err)
log.Fatal("failed to start shell: ", err)
}
}

View file

@ -19,6 +19,11 @@ import (
// messages are wrong when using ECDH.
const debugHandshake = false
// chanSize sets the amount of buffering SSH connections. This is
// primarily for testing: setting chanSize=0 uncovers deadlocks more
// quickly.
const chanSize = 16
// keyingTransport is a packet based transport that supports key
// changes. It need not be thread-safe. It should pass through
// msgNewKeys in both directions.
@ -29,25 +34,6 @@ type keyingTransport interface {
// direction will be effected if a msgNewKeys message is sent
// or received.
prepareKeyChange(*algorithms, *kexResult) error
// getSessionID returns the session ID. prepareKeyChange must
// have been called once.
getSessionID() []byte
}
// rekeyingTransport is the interface of handshakeTransport that we
// (internally) expose to ClientConn and ServerConn.
type rekeyingTransport interface {
packetConn
// requestKeyChange asks the remote side to change keys. All
// writes are blocked until the key change succeeds, which is
// signaled by reading a msgNewKeys.
requestKeyChange() error
// getSessionID returns the session ID. This is only valid
// after the first key change has completed.
getSessionID() []byte
}
// handshakeTransport implements rekeying on top of a keyingTransport
@ -59,26 +45,60 @@ type handshakeTransport struct {
serverVersion []byte
clientVersion []byte
hostKeys []Signer // If hostKeys are given, we are the server.
// hostKeys is non-empty if we are the server. In that case,
// it contains all host keys that can be used to sign the
// connection.
hostKeys []Signer
// hostKeyAlgorithms is non-empty if we are the client. In that case,
// we accept these key types from the server as host key.
hostKeyAlgorithms []string
// On read error, incoming is closed, and readError is set.
incoming chan []byte
readError error
mu sync.Mutex
writeError error
sentInitPacket []byte
sentInitMsg *kexInitMsg
pendingPackets [][]byte // Used when a key exchange is in progress.
// If the read loop wants to schedule a kex, it pings this
// channel, and the write loop will send out a kex
// message.
requestKex chan struct{}
// If the other side requests or confirms a kex, its kexInit
// packet is sent here for the write loop to find it.
startKex chan *pendingKex
// data for host key checking
hostKeyCallback func(hostname string, remote net.Addr, key PublicKey) error
hostKeyCallback HostKeyCallback
dialAddress string
remoteAddr net.Addr
readSinceKex uint64
// bannerCallback is non-empty if we are the client and it has been set in
// ClientConfig. In that case it is called during the user authentication
// dance to handle a custom server's message.
bannerCallback BannerCallback
// Protects the writing side of the connection
mu sync.Mutex
cond *sync.Cond
sentInitPacket []byte
sentInitMsg *kexInitMsg
writtenSinceKex uint64
writeError error
// Algorithms agreed in the last key exchange.
algorithms *algorithms
readPacketsLeft uint32
readBytesLeft int64
writePacketsLeft uint32
writeBytesLeft int64
// The session ID or nil if first kex did not complete yet.
sessionID []byte
}
type pendingKex struct {
otherInit []byte
done chan error
}
func newHandshakeTransport(conn keyingTransport, config *Config, clientVersion, serverVersion []byte) *handshakeTransport {
@ -86,10 +106,17 @@ func newHandshakeTransport(conn keyingTransport, config *Config, clientVersion,
conn: conn,
serverVersion: serverVersion,
clientVersion: clientVersion,
incoming: make(chan []byte, 16),
config: config,
incoming: make(chan []byte, chanSize),
requestKex: make(chan struct{}, 1),
startKex: make(chan *pendingKex, 1),
config: config,
}
t.cond = sync.NewCond(&t.mu)
t.resetReadThresholds()
t.resetWriteThresholds()
// We always start with a mandatory key exchange.
t.requestKex <- struct{}{}
return t
}
@ -98,7 +125,14 @@ func newClientTransport(conn keyingTransport, clientVersion, serverVersion []byt
t.dialAddress = dialAddr
t.remoteAddr = addr
t.hostKeyCallback = config.HostKeyCallback
t.bannerCallback = config.BannerCallback
if config.HostKeyAlgorithms != nil {
t.hostKeyAlgorithms = config.HostKeyAlgorithms
} else {
t.hostKeyAlgorithms = supportedHostKeyAlgos
}
go t.readLoop()
go t.kexLoop()
return t
}
@ -106,11 +140,26 @@ func newServerTransport(conn keyingTransport, clientVersion, serverVersion []byt
t := newHandshakeTransport(conn, &config.Config, clientVersion, serverVersion)
t.hostKeys = config.hostKeys
go t.readLoop()
go t.kexLoop()
return t
}
func (t *handshakeTransport) getSessionID() []byte {
return t.conn.getSessionID()
return t.sessionID
}
// waitSession waits for the session to be established. This should be
// the first thing to call after instantiating handshakeTransport.
func (t *handshakeTransport) waitSession() error {
p, err := t.readPacket()
if err != nil {
return err
}
if p[0] != msgNewKeys {
return fmt.Errorf("ssh: first packet should be msgNewKeys")
}
return nil
}
func (t *handshakeTransport) id() string {
@ -120,6 +169,20 @@ func (t *handshakeTransport) id() string {
return "client"
}
func (t *handshakeTransport) printPacket(p []byte, write bool) {
action := "got"
if write {
action = "sent"
}
if p[0] == msgChannelData || p[0] == msgChannelExtendedData {
log.Printf("%s %s data (packet %d bytes)", t.id(), action, len(p))
} else {
msg, err := decode(p)
log.Printf("%s %s %T %v (%v)", t.id(), action, msg, msg, err)
}
}
func (t *handshakeTransport) readPacket() ([]byte, error) {
p, ok := <-t.incoming
if !ok {
@ -129,8 +192,10 @@ func (t *handshakeTransport) readPacket() ([]byte, error) {
}
func (t *handshakeTransport) readLoop() {
first := true
for {
p, err := t.readOnePacket()
p, err := t.readOnePacket(first)
first = false
if err != nil {
t.readError = err
close(t.incoming)
@ -141,82 +206,244 @@ func (t *handshakeTransport) readLoop() {
}
t.incoming <- p
}
// Stop writers too.
t.recordWriteError(t.readError)
// Unblock the writer should it wait for this.
close(t.startKex)
// Don't close t.requestKex; it's also written to from writePacket.
}
func (t *handshakeTransport) readOnePacket() ([]byte, error) {
if t.readSinceKex > t.config.RekeyThreshold {
if err := t.requestKeyChange(); err != nil {
return nil, err
func (t *handshakeTransport) pushPacket(p []byte) error {
if debugHandshake {
t.printPacket(p, true)
}
return t.conn.writePacket(p)
}
func (t *handshakeTransport) getWriteError() error {
t.mu.Lock()
defer t.mu.Unlock()
return t.writeError
}
func (t *handshakeTransport) recordWriteError(err error) {
t.mu.Lock()
defer t.mu.Unlock()
if t.writeError == nil && err != nil {
t.writeError = err
}
}
func (t *handshakeTransport) requestKeyExchange() {
select {
case t.requestKex <- struct{}{}:
default:
// something already requested a kex, so do nothing.
}
}
func (t *handshakeTransport) resetWriteThresholds() {
t.writePacketsLeft = packetRekeyThreshold
if t.config.RekeyThreshold > 0 {
t.writeBytesLeft = int64(t.config.RekeyThreshold)
} else if t.algorithms != nil {
t.writeBytesLeft = t.algorithms.w.rekeyBytes()
} else {
t.writeBytesLeft = 1 << 30
}
}
func (t *handshakeTransport) kexLoop() {
write:
for t.getWriteError() == nil {
var request *pendingKex
var sent bool
for request == nil || !sent {
var ok bool
select {
case request, ok = <-t.startKex:
if !ok {
break write
}
case <-t.requestKex:
break
}
if !sent {
if err := t.sendKexInit(); err != nil {
t.recordWriteError(err)
break
}
sent = true
}
}
if err := t.getWriteError(); err != nil {
if request != nil {
request.done <- err
}
break
}
// We're not servicing t.requestKex, but that is OK:
// we never block on sending to t.requestKex.
// We're not servicing t.startKex, but the remote end
// has just sent us a kexInitMsg, so it can't send
// another key change request, until we close the done
// channel on the pendingKex request.
err := t.enterKeyExchange(request.otherInit)
t.mu.Lock()
t.writeError = err
t.sentInitPacket = nil
t.sentInitMsg = nil
t.resetWriteThresholds()
// we have completed the key exchange. Since the
// reader is still blocked, it is safe to clear out
// the requestKex channel. This avoids the situation
// where: 1) we consumed our own request for the
// initial kex, and 2) the kex from the remote side
// caused another send on the requestKex channel,
clear:
for {
select {
case <-t.requestKex:
//
default:
break clear
}
}
request.done <- t.writeError
// kex finished. Push packets that we received while
// the kex was in progress. Don't look at t.startKex
// and don't increment writtenSinceKex: if we trigger
// another kex while we are still busy with the last
// one, things will become very confusing.
for _, p := range t.pendingPackets {
t.writeError = t.pushPacket(p)
if t.writeError != nil {
break
}
}
t.pendingPackets = t.pendingPackets[:0]
t.mu.Unlock()
}
// drain startKex channel. We don't service t.requestKex
// because nobody does blocking sends there.
go func() {
for init := range t.startKex {
init.done <- t.writeError
}
}()
// Unblock reader.
t.conn.Close()
}
// The protocol uses uint32 for packet counters, so we can't let them
// reach 1<<32. We will actually read and write more packets than
// this, though: the other side may send more packets, and after we
// hit this limit on writing we will send a few more packets for the
// key exchange itself.
const packetRekeyThreshold = (1 << 31)
func (t *handshakeTransport) resetReadThresholds() {
t.readPacketsLeft = packetRekeyThreshold
if t.config.RekeyThreshold > 0 {
t.readBytesLeft = int64(t.config.RekeyThreshold)
} else if t.algorithms != nil {
t.readBytesLeft = t.algorithms.r.rekeyBytes()
} else {
t.readBytesLeft = 1 << 30
}
}
func (t *handshakeTransport) readOnePacket(first bool) ([]byte, error) {
p, err := t.conn.readPacket()
if err != nil {
return nil, err
}
t.readSinceKex += uint64(len(p))
if debugHandshake {
msg, err := decode(p)
log.Printf("%s got %T %v (%v)", t.id(), msg, msg, err)
if t.readPacketsLeft > 0 {
t.readPacketsLeft--
} else {
t.requestKeyExchange()
}
if t.readBytesLeft > 0 {
t.readBytesLeft -= int64(len(p))
} else {
t.requestKeyExchange()
}
if debugHandshake {
t.printPacket(p, false)
}
if first && p[0] != msgKexInit {
return nil, fmt.Errorf("ssh: first packet should be msgKexInit")
}
if p[0] != msgKexInit {
return p, nil
}
err = t.enterKeyExchange(p)
t.mu.Lock()
if err != nil {
// drop connection
t.conn.Close()
t.writeError = err
firstKex := t.sessionID == nil
kex := pendingKex{
done: make(chan error, 1),
otherInit: p,
}
t.startKex <- &kex
err = <-kex.done
if debugHandshake {
log.Printf("%s exited key exchange, err %v", t.id(), err)
log.Printf("%s exited key exchange (first %v), err %v", t.id(), firstKex, err)
}
// Unblock writers.
t.sentInitMsg = nil
t.sentInitPacket = nil
t.cond.Broadcast()
t.writtenSinceKex = 0
t.mu.Unlock()
if err != nil {
return nil, err
}
t.readSinceKex = 0
return []byte{msgNewKeys}, nil
t.resetReadThresholds()
// By default, a key exchange is hidden from higher layers by
// translating it into msgIgnore.
successPacket := []byte{msgIgnore}
if firstKex {
// sendKexInit() for the first kex waits for
// msgNewKeys so the authentication process is
// guaranteed to happen over an encrypted transport.
successPacket = []byte{msgNewKeys}
}
return successPacket, nil
}
// sendKexInit sends a key change message, and returns the message
// that was sent. After initiating the key change, all writes will be
// blocked until the change is done, and a failed key change will
// close the underlying transport. This function is safe for
// concurrent use by multiple goroutines.
func (t *handshakeTransport) sendKexInit() (*kexInitMsg, []byte, error) {
// sendKexInit sends a key change message.
func (t *handshakeTransport) sendKexInit() error {
t.mu.Lock()
defer t.mu.Unlock()
return t.sendKexInitLocked()
}
func (t *handshakeTransport) requestKeyChange() error {
_, _, err := t.sendKexInit()
return err
}
// sendKexInitLocked sends a key change message. t.mu must be locked
// while this happens.
func (t *handshakeTransport) sendKexInitLocked() (*kexInitMsg, []byte, error) {
// kexInits may be sent either in response to the other side,
// or because our side wants to initiate a key change, so we
// may have already sent a kexInit. In that case, don't send a
// second kexInit.
if t.sentInitMsg != nil {
return t.sentInitMsg, t.sentInitPacket, nil
// kexInits may be sent either in response to the other side,
// or because our side wants to initiate a key change, so we
// may have already sent a kexInit. In that case, don't send a
// second kexInit.
return nil
}
msg := &kexInitMsg{
KexAlgos: t.config.KeyExchanges,
CiphersClientServer: t.config.Ciphers,
@ -234,7 +461,7 @@ func (t *handshakeTransport) sendKexInitLocked() (*kexInitMsg, []byte, error) {
msg.ServerHostKeyAlgos, k.PublicKey().Type())
}
} else {
msg.ServerHostKeyAlgos = supportedHostKeyAlgos
msg.ServerHostKeyAlgos = t.hostKeyAlgorithms
}
packet := Marshal(msg)
@ -242,54 +469,65 @@ func (t *handshakeTransport) sendKexInitLocked() (*kexInitMsg, []byte, error) {
packetCopy := make([]byte, len(packet))
copy(packetCopy, packet)
if err := t.conn.writePacket(packetCopy); err != nil {
return nil, nil, err
if err := t.pushPacket(packetCopy); err != nil {
return err
}
t.sentInitMsg = msg
t.sentInitPacket = packet
return msg, packet, nil
return nil
}
func (t *handshakeTransport) writePacket(p []byte) error {
switch p[0] {
case msgKexInit:
return errors.New("ssh: only handshakeTransport can send kexInit")
case msgNewKeys:
return errors.New("ssh: only handshakeTransport can send newKeys")
}
t.mu.Lock()
if t.writtenSinceKex > t.config.RekeyThreshold {
t.sendKexInitLocked()
}
for t.sentInitMsg != nil {
t.cond.Wait()
}
defer t.mu.Unlock()
if t.writeError != nil {
return t.writeError
}
t.writtenSinceKex += uint64(len(p))
var err error
switch p[0] {
case msgKexInit:
err = errors.New("ssh: only handshakeTransport can send kexInit")
case msgNewKeys:
err = errors.New("ssh: only handshakeTransport can send newKeys")
default:
err = t.conn.writePacket(p)
if t.sentInitMsg != nil {
// Copy the packet so the writer can reuse the buffer.
cp := make([]byte, len(p))
copy(cp, p)
t.pendingPackets = append(t.pendingPackets, cp)
return nil
}
t.mu.Unlock()
return err
if t.writeBytesLeft > 0 {
t.writeBytesLeft -= int64(len(p))
} else {
t.requestKeyExchange()
}
if t.writePacketsLeft > 0 {
t.writePacketsLeft--
} else {
t.requestKeyExchange()
}
if err := t.pushPacket(p); err != nil {
t.writeError = err
}
return nil
}
func (t *handshakeTransport) Close() error {
return t.conn.Close()
}
// enterKeyExchange runs the key exchange.
func (t *handshakeTransport) enterKeyExchange(otherInitPacket []byte) error {
if debugHandshake {
log.Printf("%s entered key exchange", t.id())
}
myInit, myInitPacket, err := t.sendKexInit()
if err != nil {
return err
}
otherInit := &kexInitMsg{}
if err := Unmarshal(otherInitPacket, otherInit); err != nil {
@ -300,26 +538,35 @@ func (t *handshakeTransport) enterKeyExchange(otherInitPacket []byte) error {
clientVersion: t.clientVersion,
serverVersion: t.serverVersion,
clientKexInit: otherInitPacket,
serverKexInit: myInitPacket,
serverKexInit: t.sentInitPacket,
}
clientInit := otherInit
serverInit := myInit
serverInit := t.sentInitMsg
if len(t.hostKeys) == 0 {
clientInit = myInit
serverInit = otherInit
clientInit, serverInit = serverInit, clientInit
magics.clientKexInit = myInitPacket
magics.clientKexInit = t.sentInitPacket
magics.serverKexInit = otherInitPacket
}
algs := findAgreedAlgorithms(clientInit, serverInit)
if algs == nil {
return errors.New("ssh: no common algorithms")
var err error
t.algorithms, err = findAgreedAlgorithms(clientInit, serverInit)
if err != nil {
return err
}
// We don't send FirstKexFollows, but we handle receiving it.
if otherInit.FirstKexFollows && algs.kex != otherInit.KexAlgos[0] {
//
// RFC 4253 section 7 defines the kex and the agreement method for
// first_kex_packet_follows. It states that the guessed packet
// should be ignored if the "kex algorithm and/or the host
// key algorithm is guessed wrong (server and client have
// different preferred algorithm), or if any of the other
// algorithms cannot be agreed upon". The other algorithms have
// already been checked above so the kex algorithm and host key
// algorithm are checked here.
if otherInit.FirstKexFollows && (clientInit.KexAlgos[0] != serverInit.KexAlgos[0] || clientInit.ServerHostKeyAlgos[0] != serverInit.ServerHostKeyAlgos[0]) {
// other side sent a kex message for the wrong algorithm,
// which we have to ignore.
if _, err := t.conn.readPacket(); err != nil {
@ -327,23 +574,30 @@ func (t *handshakeTransport) enterKeyExchange(otherInitPacket []byte) error {
}
}
kex, ok := kexAlgoMap[algs.kex]
kex, ok := kexAlgoMap[t.algorithms.kex]
if !ok {
return fmt.Errorf("ssh: unexpected key exchange algorithm %v", algs.kex)
return fmt.Errorf("ssh: unexpected key exchange algorithm %v", t.algorithms.kex)
}
var result *kexResult
if len(t.hostKeys) > 0 {
result, err = t.server(kex, algs, &magics)
result, err = t.server(kex, t.algorithms, &magics)
} else {
result, err = t.client(kex, algs, &magics)
result, err = t.client(kex, t.algorithms, &magics)
}
if err != nil {
return err
}
t.conn.prepareKeyChange(algs, result)
if t.sessionID == nil {
t.sessionID = result.H
}
result.SessionID = t.sessionID
if err := t.conn.prepareKeyChange(t.algorithms, result); err != nil {
return err
}
if err = t.conn.writePacket([]byte{msgNewKeys}); err != nil {
return err
}
@ -352,6 +606,7 @@ func (t *handshakeTransport) enterKeyExchange(otherInitPacket []byte) error {
} else if packet[0] != msgNewKeys {
return unexpectedMessageError(msgNewKeys, packet[0])
}
return nil
}
@ -382,11 +637,9 @@ func (t *handshakeTransport) client(kex kexAlgorithm, algs *algorithms, magics *
return nil, err
}
if t.hostKeyCallback != nil {
err = t.hostKeyCallback(t.dialAddress, t.remoteAddr, hostKey)
if err != nil {
return nil, err
}
err = t.hostKeyCallback(t.dialAddress, t.remoteAddr, hostKey)
if err != nil {
return nil, err
}
return result, nil

View file

@ -7,8 +7,14 @@ package ssh
import (
"bytes"
"crypto/rand"
"errors"
"fmt"
"io"
"net"
"reflect"
"runtime"
"strings"
"sync"
"testing"
)
@ -36,7 +42,10 @@ func (t *testChecker) Check(dialAddr string, addr net.Addr, key PublicKey) error
func netPipe() (net.Conn, net.Conn, error) {
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
return nil, nil, err
listener, err = net.Listen("tcp", "[::1]:0")
if err != nil {
return nil, nil, err
}
}
defer listener.Close()
c1, err := net.Dial("tcp", listener.Addr().String())
@ -53,14 +62,46 @@ func netPipe() (net.Conn, net.Conn, error) {
return c1, c2, nil
}
func handshakePair(clientConf *ClientConfig, addr string) (client *handshakeTransport, server *handshakeTransport, err error) {
// noiseTransport inserts ignore messages to check that the read loop
// and the key exchange filters out these messages.
type noiseTransport struct {
keyingTransport
}
func (t *noiseTransport) writePacket(p []byte) error {
ignore := []byte{msgIgnore}
if err := t.keyingTransport.writePacket(ignore); err != nil {
return err
}
debug := []byte{msgDebug, 1, 2, 3}
if err := t.keyingTransport.writePacket(debug); err != nil {
return err
}
return t.keyingTransport.writePacket(p)
}
func addNoiseTransport(t keyingTransport) keyingTransport {
return &noiseTransport{t}
}
// handshakePair creates two handshakeTransports connected with each
// other. If the noise argument is true, both transports will try to
// confuse the other side by sending ignore and debug messages.
func handshakePair(clientConf *ClientConfig, addr string, noise bool) (client *handshakeTransport, server *handshakeTransport, err error) {
a, b, err := netPipe()
if err != nil {
return nil, nil, err
}
trC := newTransport(a, rand.Reader, true)
trS := newTransport(b, rand.Reader, false)
var trC, trS keyingTransport
trC = newTransport(a, rand.Reader, true)
trS = newTransport(b, rand.Reader, false)
if noise {
trC = addNoiseTransport(trC)
trS = addNoiseTransport(trS)
}
clientConf.SetDefaults()
v := []byte("version")
@ -68,15 +109,32 @@ func handshakePair(clientConf *ClientConfig, addr string) (client *handshakeTran
serverConf := &ServerConfig{}
serverConf.AddHostKey(testSigners["ecdsa"])
serverConf.AddHostKey(testSigners["rsa"])
serverConf.SetDefaults()
server = newServerTransport(trS, v, v, serverConf)
if err := server.waitSession(); err != nil {
return nil, nil, fmt.Errorf("server.waitSession: %v", err)
}
if err := client.waitSession(); err != nil {
return nil, nil, fmt.Errorf("client.waitSession: %v", err)
}
return client, server, nil
}
func TestHandshakeBasic(t *testing.T) {
checker := &testChecker{}
trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr")
if runtime.GOOS == "plan9" {
t.Skip("see golang.org/issue/7237")
}
checker := &syncChecker{
waitCall: make(chan int, 10),
called: make(chan int, 10),
}
checker.waitCall <- 1
trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr", false)
if err != nil {
t.Fatalf("handshakePair: %v", err)
}
@ -84,212 +142,195 @@ func TestHandshakeBasic(t *testing.T) {
defer trC.Close()
defer trS.Close()
// Let first kex complete normally.
<-checker.called
clientDone := make(chan int, 0)
gotHalf := make(chan int, 0)
const N = 20
go func() {
defer close(clientDone)
// Client writes a bunch of stuff, and does a key
// change in the middle. This should not confuse the
// handshake in progress
for i := 0; i < 10; i++ {
// handshake in progress. We do this twice, so we test
// that the packet buffer is reset correctly.
for i := 0; i < N; i++ {
p := []byte{msgRequestSuccess, byte(i)}
if err := trC.writePacket(p); err != nil {
t.Fatalf("sendPacket: %v", err)
}
if i == 5 {
if (i % 10) == 5 {
<-gotHalf
// halfway through, we request a key change.
_, _, err := trC.sendKexInit()
if err != nil {
t.Fatalf("sendKexInit: %v", err)
}
trC.requestKeyExchange()
// Wait until we can be sure the key
// change has really started before we
// write more.
<-checker.called
}
if (i % 10) == 7 {
// write some packets until the kex
// completes, to test buffering of
// packets.
checker.waitCall <- 1
}
}
trC.Close()
}()
// Server checks that client messages come in cleanly
i := 0
for {
p, err := trS.readPacket()
err = nil
for ; i < N; i++ {
var p []byte
p, err = trS.readPacket()
if err != nil {
break
}
if p[0] == msgNewKeys {
continue
if (i % 10) == 5 {
gotHalf <- 1
}
want := []byte{msgRequestSuccess, byte(i)}
if bytes.Compare(p, want) != 0 {
t.Errorf("message %d: got %q, want %q", i, p, want)
t.Errorf("message %d: got %v, want %v", i, p, want)
}
i++
}
if i != 10 {
<-clientDone
if err != nil && err != io.EOF {
t.Fatalf("server error: %v", err)
}
if i != N {
t.Errorf("received %d messages, want 10.", i)
}
// If all went well, we registered exactly 1 key change.
if len(checker.calls) != 1 {
t.Fatalf("got %d host key checks, want 1", len(checker.calls))
}
pub := testSigners["ecdsa"].PublicKey()
want := fmt.Sprintf("%s %v %s %x", "addr", trC.remoteAddr, pub.Type(), pub.Marshal())
if want != checker.calls[0] {
t.Errorf("got %q want %q for host key check", checker.calls[0], want)
close(checker.called)
if _, ok := <-checker.called; ok {
// If all went well, we registered exactly 2 key changes: one
// that establishes the session, and one that we requested
// additionally.
t.Fatalf("got another host key checks after 2 handshakes")
}
}
func TestHandshakeError(t *testing.T) {
func TestForceFirstKex(t *testing.T) {
// like handshakePair, but must access the keyingTransport.
checker := &testChecker{}
trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "bad")
clientConf := &ClientConfig{HostKeyCallback: checker.Check}
a, b, err := netPipe()
if err != nil {
t.Fatalf("handshakePair: %v", err)
}
defer trC.Close()
defer trS.Close()
// send a packet
packet := []byte{msgRequestSuccess, 42}
if err := trC.writePacket(packet); err != nil {
t.Errorf("writePacket: %v", err)
t.Fatalf("netPipe: %v", err)
}
// Now request a key change.
_, _, err = trC.sendKexInit()
if err != nil {
t.Errorf("sendKexInit: %v", err)
}
var trC, trS keyingTransport
// the key change will fail, and afterwards we can't write.
if err := trC.writePacket([]byte{msgRequestSuccess, 43}); err == nil {
t.Errorf("writePacket after botched rekey succeeded.")
}
trC = newTransport(a, rand.Reader, true)
readback, err := trS.readPacket()
if err != nil {
t.Fatalf("server closed too soon: %v", err)
}
if bytes.Compare(readback, packet) != 0 {
t.Errorf("got %q want %q", readback, packet)
}
readback, err = trS.readPacket()
if err == nil {
t.Errorf("got a message %q after failed key change", readback)
}
}
// This is the disallowed packet:
trC.writePacket(Marshal(&serviceRequestMsg{serviceUserAuth}))
func TestHandshakeTwice(t *testing.T) {
checker := &testChecker{}
trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr")
if err != nil {
t.Fatalf("handshakePair: %v", err)
}
// Rest of the setup.
trS = newTransport(b, rand.Reader, false)
clientConf.SetDefaults()
defer trC.Close()
defer trS.Close()
v := []byte("version")
client := newClientTransport(trC, v, v, clientConf, "addr", a.RemoteAddr())
// send a packet
packet := make([]byte, 5)
packet[0] = msgRequestSuccess
if err := trC.writePacket(packet); err != nil {
t.Errorf("writePacket: %v", err)
}
serverConf := &ServerConfig{}
serverConf.AddHostKey(testSigners["ecdsa"])
serverConf.AddHostKey(testSigners["rsa"])
serverConf.SetDefaults()
server := newServerTransport(trS, v, v, serverConf)
// Now request a key change.
_, _, err = trC.sendKexInit()
if err != nil {
t.Errorf("sendKexInit: %v", err)
}
defer client.Close()
defer server.Close()
// Send another packet. Use a fresh one, since writePacket destroys.
packet = make([]byte, 5)
packet[0] = msgRequestSuccess
if err := trC.writePacket(packet); err != nil {
t.Errorf("writePacket: %v", err)
}
// We setup the initial key exchange, but the remote side
// tries to send serviceRequestMsg in cleartext, which is
// disallowed.
// 2nd key change.
_, _, err = trC.sendKexInit()
if err != nil {
t.Errorf("sendKexInit: %v", err)
}
packet = make([]byte, 5)
packet[0] = msgRequestSuccess
if err := trC.writePacket(packet); err != nil {
t.Errorf("writePacket: %v", err)
}
packet = make([]byte, 5)
packet[0] = msgRequestSuccess
for i := 0; i < 5; i++ {
msg, err := trS.readPacket()
if err != nil {
t.Fatalf("server closed too soon: %v", err)
}
if msg[0] == msgNewKeys {
continue
}
if bytes.Compare(msg, packet) != 0 {
t.Errorf("packet %d: got %q want %q", i, msg, packet)
}
}
if len(checker.calls) != 2 {
t.Errorf("got %d key changes, want 2", len(checker.calls))
if err := server.waitSession(); err == nil {
t.Errorf("server first kex init should reject unexpected packet")
}
}
func TestHandshakeAutoRekeyWrite(t *testing.T) {
checker := &testChecker{}
checker := &syncChecker{
called: make(chan int, 10),
waitCall: nil,
}
clientConf := &ClientConfig{HostKeyCallback: checker.Check}
clientConf.RekeyThreshold = 500
trC, trS, err := handshakePair(clientConf, "addr")
trC, trS, err := handshakePair(clientConf, "addr", false)
if err != nil {
t.Fatalf("handshakePair: %v", err)
}
defer trC.Close()
defer trS.Close()
for i := 0; i < 5; i++ {
packet := make([]byte, 251)
packet[0] = msgRequestSuccess
if err := trC.writePacket(packet); err != nil {
input := make([]byte, 251)
input[0] = msgRequestSuccess
done := make(chan int, 1)
const numPacket = 5
go func() {
defer close(done)
j := 0
for ; j < numPacket; j++ {
if p, err := trS.readPacket(); err != nil {
break
} else if !bytes.Equal(input, p) {
t.Errorf("got packet type %d, want %d", p[0], input[0])
}
}
if j != numPacket {
t.Errorf("got %d, want 5 messages", j)
}
}()
<-checker.called
for i := 0; i < numPacket; i++ {
p := make([]byte, len(input))
copy(p, input)
if err := trC.writePacket(p); err != nil {
t.Errorf("writePacket: %v", err)
}
}
j := 0
for ; j < 5; j++ {
_, err := trS.readPacket()
if err != nil {
break
if i == 2 {
// Make sure the kex is in progress.
<-checker.called
}
}
if j != 5 {
t.Errorf("got %d, want 5 messages", j)
}
if len(checker.calls) != 2 {
t.Errorf("got %d key changes, wanted 2", len(checker.calls))
}
<-done
}
type syncChecker struct {
called chan int
waitCall chan int
called chan int
}
func (t *syncChecker) Check(dialAddr string, addr net.Addr, key PublicKey) error {
t.called <- 1
func (c *syncChecker) Check(dialAddr string, addr net.Addr, key PublicKey) error {
c.called <- 1
if c.waitCall != nil {
<-c.waitCall
}
return nil
}
func TestHandshakeAutoRekeyRead(t *testing.T) {
sync := &syncChecker{make(chan int, 2)}
sync := &syncChecker{
called: make(chan int, 2),
waitCall: nil,
}
clientConf := &ClientConfig{
HostKeyCallback: sync.Check,
}
clientConf.RekeyThreshold = 500
trC, trS, err := handshakePair(clientConf, "addr")
trC, trS, err := handshakePair(clientConf, "addr", false)
if err != nil {
t.Fatalf("handshakePair: %v", err)
}
@ -301,11 +342,218 @@ func TestHandshakeAutoRekeyRead(t *testing.T) {
if err := trS.writePacket(packet); err != nil {
t.Fatalf("writePacket: %v", err)
}
// While we read out the packet, a key change will be
// initiated.
if _, err := trC.readPacket(); err != nil {
t.Fatalf("readPacket(client): %v", err)
}
done := make(chan int, 1)
go func() {
defer close(done)
if _, err := trC.readPacket(); err != nil {
t.Fatalf("readPacket(client): %v", err)
}
}()
<-done
<-sync.called
}
// errorKeyingTransport generates errors after a given number of
// read/write operations.
type errorKeyingTransport struct {
packetConn
readLeft, writeLeft int
}
func (n *errorKeyingTransport) prepareKeyChange(*algorithms, *kexResult) error {
return nil
}
func (n *errorKeyingTransport) getSessionID() []byte {
return nil
}
func (n *errorKeyingTransport) writePacket(packet []byte) error {
if n.writeLeft == 0 {
n.Close()
return errors.New("barf")
}
n.writeLeft--
return n.packetConn.writePacket(packet)
}
func (n *errorKeyingTransport) readPacket() ([]byte, error) {
if n.readLeft == 0 {
n.Close()
return nil, errors.New("barf")
}
n.readLeft--
return n.packetConn.readPacket()
}
func TestHandshakeErrorHandlingRead(t *testing.T) {
for i := 0; i < 20; i++ {
testHandshakeErrorHandlingN(t, i, -1, false)
}
}
func TestHandshakeErrorHandlingWrite(t *testing.T) {
for i := 0; i < 20; i++ {
testHandshakeErrorHandlingN(t, -1, i, false)
}
}
func TestHandshakeErrorHandlingReadCoupled(t *testing.T) {
for i := 0; i < 20; i++ {
testHandshakeErrorHandlingN(t, i, -1, true)
}
}
func TestHandshakeErrorHandlingWriteCoupled(t *testing.T) {
for i := 0; i < 20; i++ {
testHandshakeErrorHandlingN(t, -1, i, true)
}
}
// testHandshakeErrorHandlingN runs handshakes, injecting errors. If
// handshakeTransport deadlocks, the go runtime will detect it and
// panic.
func testHandshakeErrorHandlingN(t *testing.T, readLimit, writeLimit int, coupled bool) {
msg := Marshal(&serviceRequestMsg{strings.Repeat("x", int(minRekeyThreshold)/4)})
a, b := memPipe()
defer a.Close()
defer b.Close()
key := testSigners["ecdsa"]
serverConf := Config{RekeyThreshold: minRekeyThreshold}
serverConf.SetDefaults()
serverConn := newHandshakeTransport(&errorKeyingTransport{a, readLimit, writeLimit}, &serverConf, []byte{'a'}, []byte{'b'})
serverConn.hostKeys = []Signer{key}
go serverConn.readLoop()
go serverConn.kexLoop()
clientConf := Config{RekeyThreshold: 10 * minRekeyThreshold}
clientConf.SetDefaults()
clientConn := newHandshakeTransport(&errorKeyingTransport{b, -1, -1}, &clientConf, []byte{'a'}, []byte{'b'})
clientConn.hostKeyAlgorithms = []string{key.PublicKey().Type()}
clientConn.hostKeyCallback = InsecureIgnoreHostKey()
go clientConn.readLoop()
go clientConn.kexLoop()
var wg sync.WaitGroup
for _, hs := range []packetConn{serverConn, clientConn} {
if !coupled {
wg.Add(2)
go func(c packetConn) {
for i := 0; ; i++ {
str := fmt.Sprintf("%08x", i) + strings.Repeat("x", int(minRekeyThreshold)/4-8)
err := c.writePacket(Marshal(&serviceRequestMsg{str}))
if err != nil {
break
}
}
wg.Done()
c.Close()
}(hs)
go func(c packetConn) {
for {
_, err := c.readPacket()
if err != nil {
break
}
}
wg.Done()
}(hs)
} else {
wg.Add(1)
go func(c packetConn) {
for {
_, err := c.readPacket()
if err != nil {
break
}
if err := c.writePacket(msg); err != nil {
break
}
}
wg.Done()
}(hs)
}
}
wg.Wait()
}
func TestDisconnect(t *testing.T) {
if runtime.GOOS == "plan9" {
t.Skip("see golang.org/issue/7237")
}
checker := &testChecker{}
trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr", false)
if err != nil {
t.Fatalf("handshakePair: %v", err)
}
defer trC.Close()
defer trS.Close()
trC.writePacket([]byte{msgRequestSuccess, 0, 0})
errMsg := &disconnectMsg{
Reason: 42,
Message: "such is life",
}
trC.writePacket(Marshal(errMsg))
trC.writePacket([]byte{msgRequestSuccess, 0, 0})
packet, err := trS.readPacket()
if err != nil {
t.Fatalf("readPacket 1: %v", err)
}
if packet[0] != msgRequestSuccess {
t.Errorf("got packet %v, want packet type %d", packet, msgRequestSuccess)
}
_, err = trS.readPacket()
if err == nil {
t.Errorf("readPacket 2 succeeded")
} else if !reflect.DeepEqual(err, errMsg) {
t.Errorf("got error %#v, want %#v", err, errMsg)
}
_, err = trS.readPacket()
if err == nil {
t.Errorf("readPacket 3 succeeded")
}
}
func TestHandshakeRekeyDefault(t *testing.T) {
clientConf := &ClientConfig{
Config: Config{
Ciphers: []string{"aes128-ctr"},
},
HostKeyCallback: InsecureIgnoreHostKey(),
}
trC, trS, err := handshakePair(clientConf, "addr", false)
if err != nil {
t.Fatalf("handshakePair: %v", err)
}
defer trC.Close()
defer trS.Close()
trC.writePacket([]byte{msgRequestSuccess, 0, 0})
trC.Close()
rgb := (1024 + trC.readBytesLeft) >> 30
wgb := (1024 + trC.writeBytesLeft) >> 30
if rgb != 64 {
t.Errorf("got rekey after %dG read, want 64G", rgb)
}
if wgb != 64 {
t.Errorf("got rekey after %dG write, want 64G", wgb)
}
}

202
vendor/golang.org/x/crypto/ssh/kex.go generated vendored
View file

@ -9,17 +9,21 @@ import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/subtle"
"errors"
"io"
"math/big"
"golang.org/x/crypto/curve25519"
)
const (
kexAlgoDH1SHA1 = "diffie-hellman-group1-sha1"
kexAlgoDH14SHA1 = "diffie-hellman-group14-sha1"
kexAlgoECDH256 = "ecdh-sha2-nistp256"
kexAlgoECDH384 = "ecdh-sha2-nistp384"
kexAlgoECDH521 = "ecdh-sha2-nistp521"
kexAlgoDH1SHA1 = "diffie-hellman-group1-sha1"
kexAlgoDH14SHA1 = "diffie-hellman-group14-sha1"
kexAlgoECDH256 = "ecdh-sha2-nistp256"
kexAlgoECDH384 = "ecdh-sha2-nistp384"
kexAlgoECDH521 = "ecdh-sha2-nistp521"
kexAlgoCurve25519SHA256 = "curve25519-sha256@libssh.org"
)
// kexResult captures the outcome of a key exchange.
@ -42,7 +46,7 @@ type kexResult struct {
Hash crypto.Hash
// The session ID, which is the first H computed. This is used
// to signal data inside transport.
// to derive key material inside the transport.
SessionID []byte
}
@ -73,11 +77,11 @@ type kexAlgorithm interface {
// dhGroup is a multiplicative group suitable for implementing Diffie-Hellman key agreement.
type dhGroup struct {
g, p *big.Int
g, p, pMinus1 *big.Int
}
func (group *dhGroup) diffieHellman(theirPublic, myPrivate *big.Int) (*big.Int, error) {
if theirPublic.Sign() <= 0 || theirPublic.Cmp(group.p) >= 0 {
if theirPublic.Cmp(bigOne) <= 0 || theirPublic.Cmp(group.pMinus1) >= 0 {
return nil, errors.New("ssh: DH parameter out of bounds")
}
return new(big.Int).Exp(theirPublic, myPrivate, group.p), nil
@ -86,10 +90,17 @@ func (group *dhGroup) diffieHellman(theirPublic, myPrivate *big.Int) (*big.Int,
func (group *dhGroup) Client(c packetConn, randSource io.Reader, magics *handshakeMagics) (*kexResult, error) {
hashFunc := crypto.SHA1
x, err := rand.Int(randSource, group.p)
if err != nil {
return nil, err
var x *big.Int
for {
var err error
if x, err = rand.Int(randSource, group.pMinus1); err != nil {
return nil, err
}
if x.Sign() > 0 {
break
}
}
X := new(big.Int).Exp(group.g, x, group.p)
kexDHInit := kexDHInitMsg{
X: X,
@ -108,7 +119,7 @@ func (group *dhGroup) Client(c packetConn, randSource io.Reader, magics *handsha
return nil, err
}
kInt, err := group.diffieHellman(kexDHReply.Y, x)
ki, err := group.diffieHellman(kexDHReply.Y, x)
if err != nil {
return nil, err
}
@ -118,8 +129,8 @@ func (group *dhGroup) Client(c packetConn, randSource io.Reader, magics *handsha
writeString(h, kexDHReply.HostKey)
writeInt(h, X)
writeInt(h, kexDHReply.Y)
K := make([]byte, intLength(kInt))
marshalInt(K, kInt)
K := make([]byte, intLength(ki))
marshalInt(K, ki)
h.Write(K)
return &kexResult{
@ -142,13 +153,18 @@ func (group *dhGroup) Server(c packetConn, randSource io.Reader, magics *handsha
return
}
y, err := rand.Int(randSource, group.p)
if err != nil {
return
var y *big.Int
for {
if y, err = rand.Int(randSource, group.pMinus1); err != nil {
return
}
if y.Sign() > 0 {
break
}
}
Y := new(big.Int).Exp(group.g, y, group.p)
kInt, err := group.diffieHellman(kexDHInit.X, y)
ki, err := group.diffieHellman(kexDHInit.X, y)
if err != nil {
return nil, err
}
@ -161,8 +177,8 @@ func (group *dhGroup) Server(c packetConn, randSource io.Reader, magics *handsha
writeInt(h, kexDHInit.X)
writeInt(h, Y)
K := make([]byte, intLength(kInt))
marshalInt(K, kInt)
K := make([]byte, intLength(ki))
marshalInt(K, ki)
h.Write(K)
H := h.Sum(nil)
@ -367,8 +383,9 @@ func init() {
// 4253 and Oakley Group 2 in RFC 2409.
p, _ := new(big.Int).SetString("FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFF", 16)
kexAlgoMap[kexAlgoDH1SHA1] = &dhGroup{
g: new(big.Int).SetInt64(2),
p: p,
g: new(big.Int).SetInt64(2),
p: p,
pMinus1: new(big.Int).Sub(p, bigOne),
}
// This is the group called diffie-hellman-group14-sha1 in RFC
@ -376,11 +393,148 @@ func init() {
p, _ = new(big.Int).SetString("FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AACAA68FFFFFFFFFFFFFFFF", 16)
kexAlgoMap[kexAlgoDH14SHA1] = &dhGroup{
g: new(big.Int).SetInt64(2),
p: p,
g: new(big.Int).SetInt64(2),
p: p,
pMinus1: new(big.Int).Sub(p, bigOne),
}
kexAlgoMap[kexAlgoECDH521] = &ecdh{elliptic.P521()}
kexAlgoMap[kexAlgoECDH384] = &ecdh{elliptic.P384()}
kexAlgoMap[kexAlgoECDH256] = &ecdh{elliptic.P256()}
kexAlgoMap[kexAlgoCurve25519SHA256] = &curve25519sha256{}
}
// curve25519sha256 implements the curve25519-sha256@libssh.org key
// agreement protocol, as described in
// https://git.libssh.org/projects/libssh.git/tree/doc/curve25519-sha256@libssh.org.txt
type curve25519sha256 struct{}
type curve25519KeyPair struct {
priv [32]byte
pub [32]byte
}
func (kp *curve25519KeyPair) generate(rand io.Reader) error {
if _, err := io.ReadFull(rand, kp.priv[:]); err != nil {
return err
}
curve25519.ScalarBaseMult(&kp.pub, &kp.priv)
return nil
}
// curve25519Zeros is just an array of 32 zero bytes so that we have something
// convenient to compare against in order to reject curve25519 points with the
// wrong order.
var curve25519Zeros [32]byte
func (kex *curve25519sha256) Client(c packetConn, rand io.Reader, magics *handshakeMagics) (*kexResult, error) {
var kp curve25519KeyPair
if err := kp.generate(rand); err != nil {
return nil, err
}
if err := c.writePacket(Marshal(&kexECDHInitMsg{kp.pub[:]})); err != nil {
return nil, err
}
packet, err := c.readPacket()
if err != nil {
return nil, err
}
var reply kexECDHReplyMsg
if err = Unmarshal(packet, &reply); err != nil {
return nil, err
}
if len(reply.EphemeralPubKey) != 32 {
return nil, errors.New("ssh: peer's curve25519 public value has wrong length")
}
var servPub, secret [32]byte
copy(servPub[:], reply.EphemeralPubKey)
curve25519.ScalarMult(&secret, &kp.priv, &servPub)
if subtle.ConstantTimeCompare(secret[:], curve25519Zeros[:]) == 1 {
return nil, errors.New("ssh: peer's curve25519 public value has wrong order")
}
h := crypto.SHA256.New()
magics.write(h)
writeString(h, reply.HostKey)
writeString(h, kp.pub[:])
writeString(h, reply.EphemeralPubKey)
ki := new(big.Int).SetBytes(secret[:])
K := make([]byte, intLength(ki))
marshalInt(K, ki)
h.Write(K)
return &kexResult{
H: h.Sum(nil),
K: K,
HostKey: reply.HostKey,
Signature: reply.Signature,
Hash: crypto.SHA256,
}, nil
}
func (kex *curve25519sha256) Server(c packetConn, rand io.Reader, magics *handshakeMagics, priv Signer) (result *kexResult, err error) {
packet, err := c.readPacket()
if err != nil {
return
}
var kexInit kexECDHInitMsg
if err = Unmarshal(packet, &kexInit); err != nil {
return
}
if len(kexInit.ClientPubKey) != 32 {
return nil, errors.New("ssh: peer's curve25519 public value has wrong length")
}
var kp curve25519KeyPair
if err := kp.generate(rand); err != nil {
return nil, err
}
var clientPub, secret [32]byte
copy(clientPub[:], kexInit.ClientPubKey)
curve25519.ScalarMult(&secret, &kp.priv, &clientPub)
if subtle.ConstantTimeCompare(secret[:], curve25519Zeros[:]) == 1 {
return nil, errors.New("ssh: peer's curve25519 public value has wrong order")
}
hostKeyBytes := priv.PublicKey().Marshal()
h := crypto.SHA256.New()
magics.write(h)
writeString(h, hostKeyBytes)
writeString(h, kexInit.ClientPubKey)
writeString(h, kp.pub[:])
ki := new(big.Int).SetBytes(secret[:])
K := make([]byte, intLength(ki))
marshalInt(K, ki)
h.Write(K)
H := h.Sum(nil)
sig, err := signAndMarshal(priv, rand, H)
if err != nil {
return nil, err
}
reply := kexECDHReplyMsg{
EphemeralPubKey: kp.pub[:],
HostKey: hostKeyBytes,
Signature: sig,
}
if err := c.writePacket(Marshal(&reply)); err != nil {
return nil, err
}
return &kexResult{
H: H,
K: K,
HostKey: hostKeyBytes,
Signature: sig,
Hash: crypto.SHA256,
}, nil
}

View file

@ -26,10 +26,12 @@ func TestKexes(t *testing.T) {
var magics handshakeMagics
go func() {
r, e := kex.Client(a, rand.Reader, &magics)
a.Close()
c <- kexResultErr{r, e}
}()
go func() {
r, e := kex.Server(b, rand.Reader, &magics, testSigners["ecdsa"])
b.Close()
s <- kexResultErr{r, e}
}()

View file

@ -10,15 +10,21 @@ import (
"crypto/dsa"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/md5"
"crypto/rsa"
"crypto/sha256"
"crypto/x509"
"encoding/asn1"
"encoding/base64"
"encoding/hex"
"encoding/pem"
"errors"
"fmt"
"io"
"math/big"
"strings"
"golang.org/x/crypto/ed25519"
)
// These constants represent the algorithm names for key types supported by this
@ -29,6 +35,17 @@ const (
KeyAlgoECDSA256 = "ecdsa-sha2-nistp256"
KeyAlgoECDSA384 = "ecdsa-sha2-nistp384"
KeyAlgoECDSA521 = "ecdsa-sha2-nistp521"
KeyAlgoED25519 = "ssh-ed25519"
)
// These constants represent non-default signature algorithms that are supported
// as algorithm parameters to AlgorithmSigner.SignWithAlgorithm methods. See
// [PROTOCOL.agent] section 4.5.1 and
// https://tools.ietf.org/html/draft-ietf-curdle-rsa-sha2-10
const (
SigAlgoRSA = "ssh-rsa"
SigAlgoRSASHA2256 = "rsa-sha2-256"
SigAlgoRSASHA2512 = "rsa-sha2-512"
)
// parsePubKey parses a public key of the given algorithm.
@ -41,14 +58,16 @@ func parsePubKey(in []byte, algo string) (pubKey PublicKey, rest []byte, err err
return parseDSA(in)
case KeyAlgoECDSA256, KeyAlgoECDSA384, KeyAlgoECDSA521:
return parseECDSA(in)
case CertAlgoRSAv01, CertAlgoDSAv01, CertAlgoECDSA256v01, CertAlgoECDSA384v01, CertAlgoECDSA521v01:
case KeyAlgoED25519:
return parseED25519(in)
case CertAlgoRSAv01, CertAlgoDSAv01, CertAlgoECDSA256v01, CertAlgoECDSA384v01, CertAlgoECDSA521v01, CertAlgoED25519v01:
cert, err := parseCert(in, certToPrivAlgo(algo))
if err != nil {
return nil, nil, err
}
return cert, nil, nil
}
return nil, nil, fmt.Errorf("ssh: unknown key algorithm: %v", err)
return nil, nil, fmt.Errorf("ssh: unknown key algorithm: %v", algo)
}
// parseAuthorizedKey parses a public key in OpenSSH authorized_keys format
@ -77,6 +96,79 @@ func parseAuthorizedKey(in []byte) (out PublicKey, comment string, err error) {
return out, comment, nil
}
// ParseKnownHosts parses an entry in the format of the known_hosts file.
//
// The known_hosts format is documented in the sshd(8) manual page. This
// function will parse a single entry from in. On successful return, marker
// will contain the optional marker value (i.e. "cert-authority" or "revoked")
// or else be empty, hosts will contain the hosts that this entry matches,
// pubKey will contain the public key and comment will contain any trailing
// comment at the end of the line. See the sshd(8) manual page for the various
// forms that a host string can take.
//
// The unparsed remainder of the input will be returned in rest. This function
// can be called repeatedly to parse multiple entries.
//
// If no entries were found in the input then err will be io.EOF. Otherwise a
// non-nil err value indicates a parse error.
func ParseKnownHosts(in []byte) (marker string, hosts []string, pubKey PublicKey, comment string, rest []byte, err error) {
for len(in) > 0 {
end := bytes.IndexByte(in, '\n')
if end != -1 {
rest = in[end+1:]
in = in[:end]
} else {
rest = nil
}
end = bytes.IndexByte(in, '\r')
if end != -1 {
in = in[:end]
}
in = bytes.TrimSpace(in)
if len(in) == 0 || in[0] == '#' {
in = rest
continue
}
i := bytes.IndexAny(in, " \t")
if i == -1 {
in = rest
continue
}
// Strip out the beginning of the known_host key.
// This is either an optional marker or a (set of) hostname(s).
keyFields := bytes.Fields(in)
if len(keyFields) < 3 || len(keyFields) > 5 {
return "", nil, nil, "", nil, errors.New("ssh: invalid entry in known_hosts data")
}
// keyFields[0] is either "@cert-authority", "@revoked" or a comma separated
// list of hosts
marker := ""
if keyFields[0][0] == '@' {
marker = string(keyFields[0][1:])
keyFields = keyFields[1:]
}
hosts := string(keyFields[0])
// keyFields[1] contains the key type (e.g. “ssh-rsa”).
// However, that information is duplicated inside the
// base64-encoded key and so is ignored here.
key := bytes.Join(keyFields[2:], []byte(" "))
if pubKey, comment, err = parseAuthorizedKey(key); err != nil {
return "", nil, nil, "", nil, err
}
return marker, strings.Split(hosts, ","), pubKey, comment, rest, nil
}
return "", nil, nil, "", nil, io.EOF
}
// ParseAuthorizedKeys parses a public key from an authorized_keys
// file used in OpenSSH according to the sshd(8) manual page.
func ParseAuthorizedKey(in []byte) (out PublicKey, comment string, options []string, rest []byte, err error) {
@ -194,7 +286,8 @@ type PublicKey interface {
Type() string
// Marshal returns the serialized key data in SSH wire format,
// with the name prefix.
// with the name prefix. To unmarshal the returned data, use
// the ParsePublicKey function.
Marshal() []byte
// Verify that sig is a signature on the given data using this
@ -202,6 +295,12 @@ type PublicKey interface {
Verify(data []byte, sig *Signature) error
}
// CryptoPublicKey, if implemented by a PublicKey,
// returns the underlying crypto.PublicKey form of the key.
type CryptoPublicKey interface {
CryptoPublicKey() crypto.PublicKey
}
// A Signer can create signatures that verify against a public key.
type Signer interface {
// PublicKey returns an associated PublicKey instance.
@ -212,6 +311,19 @@ type Signer interface {
Sign(rand io.Reader, data []byte) (*Signature, error)
}
// A AlgorithmSigner is a Signer that also supports specifying a specific
// algorithm to use for signing.
type AlgorithmSigner interface {
Signer
// SignWithAlgorithm is like Signer.Sign, but allows specification of a
// non-default signing algorithm. See the SigAlgo* constants in this
// package for signature algorithms supported by this package. Callers may
// pass an empty string for the algorithm in which case the AlgorithmSigner
// will use its default algorithm.
SignWithAlgorithm(rand io.Reader, data []byte, algorithm string) (*Signature, error)
}
type rsaPublicKey rsa.PublicKey
func (r *rsaPublicKey) Type() string {
@ -245,6 +357,8 @@ func parseRSA(in []byte) (out PublicKey, rest []byte, err error) {
func (r *rsaPublicKey) Marshal() []byte {
e := new(big.Int).SetInt64(int64(r.E))
// RSA publickey struct layout should match the struct used by
// parseRSACert in the x/crypto/ssh/agent package.
wirekey := struct {
Name string
E *big.Int
@ -258,43 +372,44 @@ func (r *rsaPublicKey) Marshal() []byte {
}
func (r *rsaPublicKey) Verify(data []byte, sig *Signature) error {
if sig.Format != r.Type() {
var hash crypto.Hash
switch sig.Format {
case SigAlgoRSA:
hash = crypto.SHA1
case SigAlgoRSASHA2256:
hash = crypto.SHA256
case SigAlgoRSASHA2512:
hash = crypto.SHA512
default:
return fmt.Errorf("ssh: signature type %s for key type %s", sig.Format, r.Type())
}
h := crypto.SHA1.New()
h := hash.New()
h.Write(data)
digest := h.Sum(nil)
return rsa.VerifyPKCS1v15((*rsa.PublicKey)(r), crypto.SHA1, digest, sig.Blob)
return rsa.VerifyPKCS1v15((*rsa.PublicKey)(r), hash, digest, sig.Blob)
}
type rsaPrivateKey struct {
*rsa.PrivateKey
}
func (r *rsaPrivateKey) PublicKey() PublicKey {
return (*rsaPublicKey)(&r.PrivateKey.PublicKey)
}
func (r *rsaPrivateKey) Sign(rand io.Reader, data []byte) (*Signature, error) {
h := crypto.SHA1.New()
h.Write(data)
digest := h.Sum(nil)
blob, err := rsa.SignPKCS1v15(rand, r.PrivateKey, crypto.SHA1, digest)
if err != nil {
return nil, err
}
return &Signature{
Format: r.PublicKey().Type(),
Blob: blob,
}, nil
func (r *rsaPublicKey) CryptoPublicKey() crypto.PublicKey {
return (*rsa.PublicKey)(r)
}
type dsaPublicKey dsa.PublicKey
func (r *dsaPublicKey) Type() string {
func (k *dsaPublicKey) Type() string {
return "ssh-dss"
}
func checkDSAParams(param *dsa.Parameters) error {
// SSH specifies FIPS 186-2, which only provided a single size
// (1024 bits) DSA key. FIPS 186-3 allows for larger key
// sizes, which would confuse SSH.
if l := param.P.BitLen(); l != 1024 {
return fmt.Errorf("ssh: unsupported DSA key size %d", l)
}
return nil
}
// parseDSA parses an DSA key according to RFC 4253, section 6.6.
func parseDSA(in []byte) (out PublicKey, rest []byte, err error) {
var w struct {
@ -305,18 +420,25 @@ func parseDSA(in []byte) (out PublicKey, rest []byte, err error) {
return nil, nil, err
}
param := dsa.Parameters{
P: w.P,
Q: w.Q,
G: w.G,
}
if err := checkDSAParams(&param); err != nil {
return nil, nil, err
}
key := &dsaPublicKey{
Parameters: dsa.Parameters{
P: w.P,
Q: w.Q,
G: w.G,
},
Y: w.Y,
Parameters: param,
Y: w.Y,
}
return key, w.Rest, nil
}
func (k *dsaPublicKey) Marshal() []byte {
// DSA publickey struct layout should match the struct used by
// parseDSACert in the x/crypto/ssh/agent package.
w := struct {
Name string
P, Q, G, Y *big.Int
@ -355,6 +477,10 @@ func (k *dsaPublicKey) Verify(data []byte, sig *Signature) error {
return errors.New("ssh: signature did not verify")
}
func (k *dsaPublicKey) CryptoPublicKey() crypto.PublicKey {
return (*dsa.PublicKey)(k)
}
type dsaPrivateKey struct {
*dsa.PrivateKey
}
@ -364,6 +490,14 @@ func (k *dsaPrivateKey) PublicKey() PublicKey {
}
func (k *dsaPrivateKey) Sign(rand io.Reader, data []byte) (*Signature, error) {
return k.SignWithAlgorithm(rand, data, "")
}
func (k *dsaPrivateKey) SignWithAlgorithm(rand io.Reader, data []byte, algorithm string) (*Signature, error) {
if algorithm != "" && algorithm != k.PublicKey().Type() {
return nil, fmt.Errorf("ssh: unsupported signature algorithm %s", algorithm)
}
h := crypto.SHA1.New()
h.Write(data)
digest := h.Sum(nil)
@ -387,12 +521,12 @@ func (k *dsaPrivateKey) Sign(rand io.Reader, data []byte) (*Signature, error) {
type ecdsaPublicKey ecdsa.PublicKey
func (key *ecdsaPublicKey) Type() string {
return "ecdsa-sha2-" + key.nistID()
func (k *ecdsaPublicKey) Type() string {
return "ecdsa-sha2-" + k.nistID()
}
func (key *ecdsaPublicKey) nistID() string {
switch key.Params().BitSize {
func (k *ecdsaPublicKey) nistID() string {
switch k.Params().BitSize {
case 256:
return "nistp256"
case 384:
@ -403,6 +537,55 @@ func (key *ecdsaPublicKey) nistID() string {
panic("ssh: unsupported ecdsa key size")
}
type ed25519PublicKey ed25519.PublicKey
func (k ed25519PublicKey) Type() string {
return KeyAlgoED25519
}
func parseED25519(in []byte) (out PublicKey, rest []byte, err error) {
var w struct {
KeyBytes []byte
Rest []byte `ssh:"rest"`
}
if err := Unmarshal(in, &w); err != nil {
return nil, nil, err
}
key := ed25519.PublicKey(w.KeyBytes)
return (ed25519PublicKey)(key), w.Rest, nil
}
func (k ed25519PublicKey) Marshal() []byte {
w := struct {
Name string
KeyBytes []byte
}{
KeyAlgoED25519,
[]byte(k),
}
return Marshal(&w)
}
func (k ed25519PublicKey) Verify(b []byte, sig *Signature) error {
if sig.Format != k.Type() {
return fmt.Errorf("ssh: signature type %s for key type %s", sig.Format, k.Type())
}
edKey := (ed25519.PublicKey)(k)
if ok := ed25519.Verify(edKey, b, sig.Blob); !ok {
return errors.New("ssh: signature did not verify")
}
return nil
}
func (k ed25519PublicKey) CryptoPublicKey() crypto.PublicKey {
return ed25519.PublicKey(k)
}
func supportedEllipticCurve(curve elliptic.Curve) bool {
return curve == elliptic.P256() || curve == elliptic.P384() || curve == elliptic.P521()
}
@ -422,14 +605,19 @@ func ecHash(curve elliptic.Curve) crypto.Hash {
// parseECDSA parses an ECDSA key according to RFC 5656, section 3.1.
func parseECDSA(in []byte) (out PublicKey, rest []byte, err error) {
identifier, in, ok := parseString(in)
if !ok {
return nil, nil, errShortRead
var w struct {
Curve string
KeyBytes []byte
Rest []byte `ssh:"rest"`
}
if err := Unmarshal(in, &w); err != nil {
return nil, nil, err
}
key := new(ecdsa.PublicKey)
switch string(identifier) {
switch w.Curve {
case "nistp256":
key.Curve = elliptic.P256()
case "nistp384":
@ -440,40 +628,37 @@ func parseECDSA(in []byte) (out PublicKey, rest []byte, err error) {
return nil, nil, errors.New("ssh: unsupported curve")
}
var keyBytes []byte
if keyBytes, in, ok = parseString(in); !ok {
return nil, nil, errShortRead
}
key.X, key.Y = elliptic.Unmarshal(key.Curve, keyBytes)
key.X, key.Y = elliptic.Unmarshal(key.Curve, w.KeyBytes)
if key.X == nil || key.Y == nil {
return nil, nil, errors.New("ssh: invalid curve point")
}
return (*ecdsaPublicKey)(key), in, nil
return (*ecdsaPublicKey)(key), w.Rest, nil
}
func (key *ecdsaPublicKey) Marshal() []byte {
func (k *ecdsaPublicKey) Marshal() []byte {
// See RFC 5656, section 3.1.
keyBytes := elliptic.Marshal(key.Curve, key.X, key.Y)
keyBytes := elliptic.Marshal(k.Curve, k.X, k.Y)
// ECDSA publickey struct layout should match the struct used by
// parseECDSACert in the x/crypto/ssh/agent package.
w := struct {
Name string
ID string
Key []byte
}{
key.Type(),
key.nistID(),
k.Type(),
k.nistID(),
keyBytes,
}
return Marshal(&w)
}
func (key *ecdsaPublicKey) Verify(data []byte, sig *Signature) error {
if sig.Format != key.Type() {
return fmt.Errorf("ssh: signature type %s for key type %s", sig.Format, key.Type())
func (k *ecdsaPublicKey) Verify(data []byte, sig *Signature) error {
if sig.Format != k.Type() {
return fmt.Errorf("ssh: signature type %s for key type %s", sig.Format, k.Type())
}
h := ecHash(key.Curve).New()
h := ecHash(k.Curve).New()
h.Write(data)
digest := h.Sum(nil)
@ -490,78 +675,165 @@ func (key *ecdsaPublicKey) Verify(data []byte, sig *Signature) error {
return err
}
if ecdsa.Verify((*ecdsa.PublicKey)(key), digest, ecSig.R, ecSig.S) {
if ecdsa.Verify((*ecdsa.PublicKey)(k), digest, ecSig.R, ecSig.S) {
return nil
}
return errors.New("ssh: signature did not verify")
}
type ecdsaPrivateKey struct {
*ecdsa.PrivateKey
func (k *ecdsaPublicKey) CryptoPublicKey() crypto.PublicKey {
return (*ecdsa.PublicKey)(k)
}
func (k *ecdsaPrivateKey) PublicKey() PublicKey {
return (*ecdsaPublicKey)(&k.PrivateKey.PublicKey)
// NewSignerFromKey takes an *rsa.PrivateKey, *dsa.PrivateKey,
// *ecdsa.PrivateKey or any other crypto.Signer and returns a
// corresponding Signer instance. ECDSA keys must use P-256, P-384 or
// P-521. DSA keys must use parameter size L1024N160.
func NewSignerFromKey(key interface{}) (Signer, error) {
switch key := key.(type) {
case crypto.Signer:
return NewSignerFromSigner(key)
case *dsa.PrivateKey:
return newDSAPrivateKey(key)
default:
return nil, fmt.Errorf("ssh: unsupported key type %T", key)
}
}
func (k *ecdsaPrivateKey) Sign(rand io.Reader, data []byte) (*Signature, error) {
h := ecHash(k.PrivateKey.PublicKey.Curve).New()
h.Write(data)
digest := h.Sum(nil)
r, s, err := ecdsa.Sign(rand, k.PrivateKey, digest)
func newDSAPrivateKey(key *dsa.PrivateKey) (Signer, error) {
if err := checkDSAParams(&key.PublicKey.Parameters); err != nil {
return nil, err
}
return &dsaPrivateKey{key}, nil
}
type wrappedSigner struct {
signer crypto.Signer
pubKey PublicKey
}
// NewSignerFromSigner takes any crypto.Signer implementation and
// returns a corresponding Signer interface. This can be used, for
// example, with keys kept in hardware modules.
func NewSignerFromSigner(signer crypto.Signer) (Signer, error) {
pubKey, err := NewPublicKey(signer.Public())
if err != nil {
return nil, err
}
sig := make([]byte, intLength(r)+intLength(s))
rest := marshalInt(sig, r)
marshalInt(rest, s)
return &wrappedSigner{signer, pubKey}, nil
}
func (s *wrappedSigner) PublicKey() PublicKey {
return s.pubKey
}
func (s *wrappedSigner) Sign(rand io.Reader, data []byte) (*Signature, error) {
return s.SignWithAlgorithm(rand, data, "")
}
func (s *wrappedSigner) SignWithAlgorithm(rand io.Reader, data []byte, algorithm string) (*Signature, error) {
var hashFunc crypto.Hash
if _, ok := s.pubKey.(*rsaPublicKey); ok {
// RSA keys support a few hash functions determined by the requested signature algorithm
switch algorithm {
case "", SigAlgoRSA:
algorithm = SigAlgoRSA
hashFunc = crypto.SHA1
case SigAlgoRSASHA2256:
hashFunc = crypto.SHA256
case SigAlgoRSASHA2512:
hashFunc = crypto.SHA512
default:
return nil, fmt.Errorf("ssh: unsupported signature algorithm %s", algorithm)
}
} else {
// The only supported algorithm for all other key types is the same as the type of the key
if algorithm == "" {
algorithm = s.pubKey.Type()
} else if algorithm != s.pubKey.Type() {
return nil, fmt.Errorf("ssh: unsupported signature algorithm %s", algorithm)
}
switch key := s.pubKey.(type) {
case *dsaPublicKey:
hashFunc = crypto.SHA1
case *ecdsaPublicKey:
hashFunc = ecHash(key.Curve)
case ed25519PublicKey:
default:
return nil, fmt.Errorf("ssh: unsupported key type %T", key)
}
}
var digest []byte
if hashFunc != 0 {
h := hashFunc.New()
h.Write(data)
digest = h.Sum(nil)
} else {
digest = data
}
signature, err := s.signer.Sign(rand, digest, hashFunc)
if err != nil {
return nil, err
}
// crypto.Signer.Sign is expected to return an ASN.1-encoded signature
// for ECDSA and DSA, but that's not the encoding expected by SSH, so
// re-encode.
switch s.pubKey.(type) {
case *ecdsaPublicKey, *dsaPublicKey:
type asn1Signature struct {
R, S *big.Int
}
asn1Sig := new(asn1Signature)
_, err := asn1.Unmarshal(signature, asn1Sig)
if err != nil {
return nil, err
}
switch s.pubKey.(type) {
case *ecdsaPublicKey:
signature = Marshal(asn1Sig)
case *dsaPublicKey:
signature = make([]byte, 40)
r := asn1Sig.R.Bytes()
s := asn1Sig.S.Bytes()
copy(signature[20-len(r):20], r)
copy(signature[40-len(s):40], s)
}
}
return &Signature{
Format: k.PublicKey().Type(),
Blob: sig,
Format: algorithm,
Blob: signature,
}, nil
}
// NewSignerFromKey takes a pointer to rsa, dsa or ecdsa PrivateKey
// returns a corresponding Signer instance. EC keys should use P256,
// P384 or P521.
func NewSignerFromKey(k interface{}) (Signer, error) {
var sshKey Signer
switch t := k.(type) {
case *rsa.PrivateKey:
sshKey = &rsaPrivateKey{t}
case *dsa.PrivateKey:
sshKey = &dsaPrivateKey{t}
case *ecdsa.PrivateKey:
if !supportedEllipticCurve(t.Curve) {
return nil, errors.New("ssh: only P256, P384 and P521 EC keys are supported.")
}
sshKey = &ecdsaPrivateKey{t}
default:
return nil, fmt.Errorf("ssh: unsupported key type %T", k)
}
return sshKey, nil
}
// NewPublicKey takes a pointer to rsa, dsa or ecdsa PublicKey
// and returns a corresponding ssh PublicKey instance. EC keys should use P256, P384 or P521.
func NewPublicKey(k interface{}) (PublicKey, error) {
var sshKey PublicKey
switch t := k.(type) {
// NewPublicKey takes an *rsa.PublicKey, *dsa.PublicKey, *ecdsa.PublicKey,
// or ed25519.PublicKey returns a corresponding PublicKey instance.
// ECDSA keys must use P-256, P-384 or P-521.
func NewPublicKey(key interface{}) (PublicKey, error) {
switch key := key.(type) {
case *rsa.PublicKey:
sshKey = (*rsaPublicKey)(t)
return (*rsaPublicKey)(key), nil
case *ecdsa.PublicKey:
if !supportedEllipticCurve(t.Curve) {
return nil, errors.New("ssh: only P256, P384 and P521 EC keys are supported.")
if !supportedEllipticCurve(key.Curve) {
return nil, errors.New("ssh: only P-256, P-384 and P-521 EC keys are supported")
}
sshKey = (*ecdsaPublicKey)(t)
return (*ecdsaPublicKey)(key), nil
case *dsa.PublicKey:
sshKey = (*dsaPublicKey)(t)
return (*dsaPublicKey)(key), nil
case ed25519.PublicKey:
return (ed25519PublicKey)(key), nil
default:
return nil, fmt.Errorf("ssh: unsupported key type %T", k)
return nil, fmt.Errorf("ssh: unsupported key type %T", key)
}
return sshKey, nil
}
// ParsePrivateKey returns a Signer from a PEM encoded private key. It supports
@ -575,21 +847,87 @@ func ParsePrivateKey(pemBytes []byte) (Signer, error) {
return NewSignerFromKey(key)
}
// ParsePrivateKeyWithPassphrase returns a Signer from a PEM encoded private
// key and passphrase. It supports the same keys as
// ParseRawPrivateKeyWithPassphrase.
func ParsePrivateKeyWithPassphrase(pemBytes, passPhrase []byte) (Signer, error) {
key, err := ParseRawPrivateKeyWithPassphrase(pemBytes, passPhrase)
if err != nil {
return nil, err
}
return NewSignerFromKey(key)
}
// encryptedBlock tells whether a private key is
// encrypted by examining its Proc-Type header
// for a mention of ENCRYPTED
// according to RFC 1421 Section 4.6.1.1.
func encryptedBlock(block *pem.Block) bool {
return strings.Contains(block.Headers["Proc-Type"], "ENCRYPTED")
}
// ParseRawPrivateKey returns a private key from a PEM encoded private key. It
// supports RSA (PKCS#1), DSA (OpenSSL), and ECDSA private keys.
// supports RSA (PKCS#1), PKCS#8, DSA (OpenSSL), and ECDSA private keys.
func ParseRawPrivateKey(pemBytes []byte) (interface{}, error) {
block, _ := pem.Decode(pemBytes)
if block == nil {
return nil, errors.New("ssh: no key found")
}
if encryptedBlock(block) {
return nil, errors.New("ssh: cannot decode encrypted private keys")
}
switch block.Type {
case "RSA PRIVATE KEY":
return x509.ParsePKCS1PrivateKey(block.Bytes)
// RFC5208 - https://tools.ietf.org/html/rfc5208
case "PRIVATE KEY":
return x509.ParsePKCS8PrivateKey(block.Bytes)
case "EC PRIVATE KEY":
return x509.ParseECPrivateKey(block.Bytes)
case "DSA PRIVATE KEY":
return ParseDSAPrivateKey(block.Bytes)
case "OPENSSH PRIVATE KEY":
return parseOpenSSHPrivateKey(block.Bytes)
default:
return nil, fmt.Errorf("ssh: unsupported key type %q", block.Type)
}
}
// ParseRawPrivateKeyWithPassphrase returns a private key decrypted with
// passphrase from a PEM encoded private key. If wrong passphrase, return
// x509.IncorrectPasswordError.
func ParseRawPrivateKeyWithPassphrase(pemBytes, passPhrase []byte) (interface{}, error) {
block, _ := pem.Decode(pemBytes)
if block == nil {
return nil, errors.New("ssh: no key found")
}
buf := block.Bytes
if encryptedBlock(block) {
if x509.IsEncryptedPEMBlock(block) {
var err error
buf, err = x509.DecryptPEMBlock(block, passPhrase)
if err != nil {
if err == x509.IncorrectPasswordError {
return nil, err
}
return nil, fmt.Errorf("ssh: cannot decode encrypted private keys: %v", err)
}
}
}
switch block.Type {
case "RSA PRIVATE KEY":
return x509.ParsePKCS1PrivateKey(buf)
case "EC PRIVATE KEY":
return x509.ParseECPrivateKey(buf)
case "DSA PRIVATE KEY":
return ParseDSAPrivateKey(buf)
case "OPENSSH PRIVATE KEY":
return parseOpenSSHPrivateKey(buf)
default:
return nil, fmt.Errorf("ssh: unsupported key type %q", block.Type)
}
@ -603,8 +941,8 @@ func ParseDSAPrivateKey(der []byte) (*dsa.PrivateKey, error) {
P *big.Int
Q *big.Int
G *big.Int
Priv *big.Int
Pub *big.Int
Priv *big.Int
}
rest, err := asn1.Unmarshal(der, &k)
if err != nil {
@ -621,8 +959,142 @@ func ParseDSAPrivateKey(der []byte) (*dsa.PrivateKey, error) {
Q: k.Q,
G: k.G,
},
Y: k.Priv,
Y: k.Pub,
},
X: k.Pub,
X: k.Priv,
}, nil
}
// Implemented based on the documentation at
// https://github.com/openssh/openssh-portable/blob/master/PROTOCOL.key
func parseOpenSSHPrivateKey(key []byte) (crypto.PrivateKey, error) {
const magic = "openssh-key-v1\x00"
if len(key) < len(magic) || string(key[:len(magic)]) != magic {
return nil, errors.New("ssh: invalid openssh private key format")
}
remaining := key[len(magic):]
var w struct {
CipherName string
KdfName string
KdfOpts string
NumKeys uint32
PubKey []byte
PrivKeyBlock []byte
}
if err := Unmarshal(remaining, &w); err != nil {
return nil, err
}
if w.KdfName != "none" || w.CipherName != "none" {
return nil, errors.New("ssh: cannot decode encrypted private keys")
}
pk1 := struct {
Check1 uint32
Check2 uint32
Keytype string
Rest []byte `ssh:"rest"`
}{}
if err := Unmarshal(w.PrivKeyBlock, &pk1); err != nil {
return nil, err
}
if pk1.Check1 != pk1.Check2 {
return nil, errors.New("ssh: checkint mismatch")
}
// we only handle ed25519 and rsa keys currently
switch pk1.Keytype {
case KeyAlgoRSA:
// https://github.com/openssh/openssh-portable/blob/master/sshkey.c#L2760-L2773
key := struct {
N *big.Int
E *big.Int
D *big.Int
Iqmp *big.Int
P *big.Int
Q *big.Int
Comment string
Pad []byte `ssh:"rest"`
}{}
if err := Unmarshal(pk1.Rest, &key); err != nil {
return nil, err
}
for i, b := range key.Pad {
if int(b) != i+1 {
return nil, errors.New("ssh: padding not as expected")
}
}
pk := &rsa.PrivateKey{
PublicKey: rsa.PublicKey{
N: key.N,
E: int(key.E.Int64()),
},
D: key.D,
Primes: []*big.Int{key.P, key.Q},
}
if err := pk.Validate(); err != nil {
return nil, err
}
pk.Precompute()
return pk, nil
case KeyAlgoED25519:
key := struct {
Pub []byte
Priv []byte
Comment string
Pad []byte `ssh:"rest"`
}{}
if err := Unmarshal(pk1.Rest, &key); err != nil {
return nil, err
}
if len(key.Priv) != ed25519.PrivateKeySize {
return nil, errors.New("ssh: private key unexpected length")
}
for i, b := range key.Pad {
if int(b) != i+1 {
return nil, errors.New("ssh: padding not as expected")
}
}
pk := ed25519.PrivateKey(make([]byte, ed25519.PrivateKeySize))
copy(pk, key.Priv)
return &pk, nil
default:
return nil, errors.New("ssh: unhandled key type")
}
}
// FingerprintLegacyMD5 returns the user presentation of the key's
// fingerprint as described by RFC 4716 section 4.
func FingerprintLegacyMD5(pubKey PublicKey) string {
md5sum := md5.Sum(pubKey.Marshal())
hexarray := make([]string, len(md5sum))
for i, c := range md5sum {
hexarray[i] = hex.EncodeToString([]byte{c})
}
return strings.Join(hexarray, ":")
}
// FingerprintSHA256 returns the user presentation of the key's
// fingerprint as unpadded base64 encoded sha256 hash.
// This format was introduced from OpenSSH 6.8.
// https://www.openssh.com/txt/release-6.8
// https://tools.ietf.org/html/rfc4648#section-3.2 (unpadded base64 encoding)
func FingerprintSHA256(pubKey PublicKey) string {
sha256sum := sha256.Sum256(pubKey.Marshal())
hash := base64.RawStdEncoding.EncodeToString(sha256sum[:])
return "SHA256:" + hash
}

View file

@ -11,12 +11,16 @@ import (
"crypto/elliptic"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"encoding/base64"
"encoding/pem"
"fmt"
"io"
"reflect"
"strings"
"testing"
"golang.org/x/crypto/ed25519"
"golang.org/x/crypto/ssh/testdata"
)
@ -28,6 +32,8 @@ func rawKey(pub PublicKey) interface{} {
return (*dsa.PublicKey)(k)
case *ecdsaPublicKey:
return (*ecdsa.PublicKey)(k)
case ed25519PublicKey:
return (ed25519.PublicKey)(k)
case *Certificate:
return k
}
@ -57,12 +63,12 @@ func TestUnsupportedCurves(t *testing.T) {
t.Fatalf("GenerateKey: %v", err)
}
if _, err = NewSignerFromKey(raw); err == nil || !strings.Contains(err.Error(), "only P256") {
t.Fatalf("NewPrivateKey should not succeed with P224, got: %v", err)
if _, err = NewSignerFromKey(raw); err == nil || !strings.Contains(err.Error(), "only P-256") {
t.Fatalf("NewPrivateKey should not succeed with P-224, got: %v", err)
}
if _, err = NewPublicKey(&raw.PublicKey); err == nil || !strings.Contains(err.Error(), "only P256") {
t.Fatalf("NewPublicKey should not succeed with P224, got: %v", err)
if _, err = NewPublicKey(&raw.PublicKey); err == nil || !strings.Contains(err.Error(), "only P-256") {
t.Fatalf("NewPublicKey should not succeed with P-224, got: %v", err)
}
}
@ -103,6 +109,49 @@ func TestKeySignVerify(t *testing.T) {
}
}
func TestKeySignWithAlgorithmVerify(t *testing.T) {
for _, priv := range testSigners {
if algorithmSigner, ok := priv.(AlgorithmSigner); !ok {
t.Errorf("Signers constructed by ssh package should always implement the AlgorithmSigner interface: %T", priv)
} else {
pub := priv.PublicKey()
data := []byte("sign me")
signWithAlgTestCase := func(algorithm string, expectedAlg string) {
sig, err := algorithmSigner.SignWithAlgorithm(rand.Reader, data, algorithm)
if err != nil {
t.Fatalf("Sign(%T): %v", priv, err)
}
if sig.Format != expectedAlg {
t.Errorf("signature format did not match requested signature algorithm: %s != %s", sig.Format, expectedAlg)
}
if err := pub.Verify(data, sig); err != nil {
t.Errorf("publicKey.Verify(%T): %v", priv, err)
}
sig.Blob[5]++
if err := pub.Verify(data, sig); err == nil {
t.Errorf("publicKey.Verify on broken sig did not fail")
}
}
// Using the empty string as the algorithm name should result in the same signature format as the algorithm-free Sign method.
defaultSig, err := priv.Sign(rand.Reader, data)
if err != nil {
t.Fatalf("Sign(%T): %v", priv, err)
}
signWithAlgTestCase("", defaultSig.Format)
// RSA keys are the only ones which currently support more than one signing algorithm
if pub.Type() == KeyAlgoRSA {
for _, algorithm := range []string{SigAlgoRSA, SigAlgoRSASHA2256, SigAlgoRSASHA2512} {
signWithAlgTestCase(algorithm, algorithm)
}
}
}
}
}
func TestParseRSAPrivateKey(t *testing.T) {
key := testPrivateKeys["rsa"]
@ -129,6 +178,47 @@ func TestParseECPrivateKey(t *testing.T) {
}
}
// See Issue https://github.com/golang/go/issues/6650.
func TestParseEncryptedPrivateKeysFails(t *testing.T) {
const wantSubstring = "encrypted"
for i, tt := range testdata.PEMEncryptedKeys {
_, err := ParsePrivateKey(tt.PEMBytes)
if err == nil {
t.Errorf("#%d key %s: ParsePrivateKey successfully parsed, expected an error", i, tt.Name)
continue
}
if !strings.Contains(err.Error(), wantSubstring) {
t.Errorf("#%d key %s: got error %q, want substring %q", i, tt.Name, err, wantSubstring)
}
}
}
// Parse encrypted private keys with passphrase
func TestParseEncryptedPrivateKeysWithPassphrase(t *testing.T) {
data := []byte("sign me")
for _, tt := range testdata.PEMEncryptedKeys {
s, err := ParsePrivateKeyWithPassphrase(tt.PEMBytes, []byte(tt.EncryptionKey))
if err != nil {
t.Fatalf("ParsePrivateKeyWithPassphrase returned error: %s", err)
continue
}
sig, err := s.Sign(rand.Reader, data)
if err != nil {
t.Fatalf("dsa.Sign: %v", err)
}
if err := s.PublicKey().Verify(data, sig); err != nil {
t.Errorf("Verify failed: %v", err)
}
}
tt := testdata.PEMEncryptedKeys[0]
_, err := ParsePrivateKeyWithPassphrase(tt.PEMBytes, []byte("incorrect"))
if err != x509.IncorrectPasswordError {
t.Fatalf("got %v want IncorrectPasswordError", err)
}
}
func TestParseDSA(t *testing.T) {
// We actually exercise the ParsePrivateKey codepath here, as opposed to
// using the ParseRawPrivateKey+NewSignerFromKey path that testdata_test.go
@ -189,7 +279,7 @@ func TestMarshalParsePublicKey(t *testing.T) {
}
}
type authResult struct {
type testAuthResult struct {
pubKey PublicKey
options []string
comments string
@ -197,11 +287,11 @@ type authResult struct {
ok bool
}
func testAuthorizedKeys(t *testing.T, authKeys []byte, expected []authResult) {
func testAuthorizedKeys(t *testing.T, authKeys []byte, expected []testAuthResult) {
rest := authKeys
var values []authResult
var values []testAuthResult
for len(rest) > 0 {
var r authResult
var r testAuthResult
var err error
r.pubKey, r.comments, r.options, rest, err = ParseAuthorizedKey(rest)
r.ok = (err == nil)
@ -219,7 +309,7 @@ func TestAuthorizedKeyBasic(t *testing.T) {
pub, pubSerialized := getTestKey()
line := "ssh-rsa " + pubSerialized + " user@host"
testAuthorizedKeys(t, []byte(line),
[]authResult{
[]testAuthResult{
{pub, nil, "user@host", "", true},
})
}
@ -241,7 +331,7 @@ func TestAuth(t *testing.T) {
authOptions := strings.Join(authWithOptions, eol)
rest2 := strings.Join(authWithOptions[3:], eol)
rest3 := strings.Join(authWithOptions[6:], eol)
testAuthorizedKeys(t, []byte(authOptions), []authResult{
testAuthorizedKeys(t, []byte(authOptions), []testAuthResult{
{pub, []string{`env="HOME=/home/root"`, "no-port-forwarding"}, "user@host", rest2, true},
{pub, []string{`env="HOME=/home/root2"`}, "user2@host2", rest3, true},
{nil, nil, "", "", false},
@ -252,7 +342,7 @@ func TestAuth(t *testing.T) {
func TestAuthWithQuotedSpaceInEnv(t *testing.T) {
pub, pubSerialized := getTestKey()
authWithQuotedSpaceInEnv := []byte(`env="HOME=/home/root dir",no-port-forwarding ssh-rsa ` + pubSerialized + ` user@host`)
testAuthorizedKeys(t, []byte(authWithQuotedSpaceInEnv), []authResult{
testAuthorizedKeys(t, []byte(authWithQuotedSpaceInEnv), []testAuthResult{
{pub, []string{`env="HOME=/home/root dir"`, "no-port-forwarding"}, "user@host", "", true},
})
}
@ -260,7 +350,7 @@ func TestAuthWithQuotedSpaceInEnv(t *testing.T) {
func TestAuthWithQuotedCommaInEnv(t *testing.T) {
pub, pubSerialized := getTestKey()
authWithQuotedCommaInEnv := []byte(`env="HOME=/home/root,dir",no-port-forwarding ssh-rsa ` + pubSerialized + ` user@host`)
testAuthorizedKeys(t, []byte(authWithQuotedCommaInEnv), []authResult{
testAuthorizedKeys(t, []byte(authWithQuotedCommaInEnv), []testAuthResult{
{pub, []string{`env="HOME=/home/root,dir"`, "no-port-forwarding"}, "user@host", "", true},
})
}
@ -269,11 +359,11 @@ func TestAuthWithQuotedQuoteInEnv(t *testing.T) {
pub, pubSerialized := getTestKey()
authWithQuotedQuoteInEnv := []byte(`env="HOME=/home/\"root dir",no-port-forwarding` + "\t" + `ssh-rsa` + "\t" + pubSerialized + ` user@host`)
authWithDoubleQuotedQuote := []byte(`no-port-forwarding,env="HOME=/home/ \"root dir\"" ssh-rsa ` + pubSerialized + "\t" + `user@host`)
testAuthorizedKeys(t, []byte(authWithQuotedQuoteInEnv), []authResult{
testAuthorizedKeys(t, []byte(authWithQuotedQuoteInEnv), []testAuthResult{
{pub, []string{`env="HOME=/home/\"root dir"`, "no-port-forwarding"}, "user@host", "", true},
})
testAuthorizedKeys(t, []byte(authWithDoubleQuotedQuote), []authResult{
testAuthorizedKeys(t, []byte(authWithDoubleQuotedQuote), []testAuthResult{
{pub, []string{"no-port-forwarding", `env="HOME=/home/ \"root dir\""`}, "user@host", "", true},
})
}
@ -282,7 +372,7 @@ func TestAuthWithInvalidSpace(t *testing.T) {
_, pubSerialized := getTestKey()
authWithInvalidSpace := []byte(`env="HOME=/home/root dir", no-port-forwarding ssh-rsa ` + pubSerialized + ` user@host
#more to follow but still no valid keys`)
testAuthorizedKeys(t, []byte(authWithInvalidSpace), []authResult{
testAuthorizedKeys(t, []byte(authWithInvalidSpace), []testAuthResult{
{nil, nil, "", "", false},
})
}
@ -292,7 +382,7 @@ func TestAuthWithMissingQuote(t *testing.T) {
authWithMissingQuote := []byte(`env="HOME=/home/root,no-port-forwarding ssh-rsa ` + pubSerialized + ` user@host
env="HOME=/home/root",shared-control ssh-rsa ` + pubSerialized + ` user@host`)
testAuthorizedKeys(t, []byte(authWithMissingQuote), []authResult{
testAuthorizedKeys(t, []byte(authWithMissingQuote), []testAuthResult{
{pub, []string{`env="HOME=/home/root"`, `shared-control`}, "user@host", "", true},
})
}
@ -304,3 +394,181 @@ func TestInvalidEntry(t *testing.T) {
t.Errorf("got valid entry for %q", authInvalid)
}
}
var knownHostsParseTests = []struct {
input string
err string
marker string
comment string
hosts []string
rest string
}{
{
"",
"EOF",
"", "", nil, "",
},
{
"# Just a comment",
"EOF",
"", "", nil, "",
},
{
" \t ",
"EOF",
"", "", nil, "",
},
{
"localhost ssh-rsa {RSAPUB}",
"",
"", "", []string{"localhost"}, "",
},
{
"localhost\tssh-rsa {RSAPUB}",
"",
"", "", []string{"localhost"}, "",
},
{
"localhost\tssh-rsa {RSAPUB}\tcomment comment",
"",
"", "comment comment", []string{"localhost"}, "",
},
{
"localhost\tssh-rsa {RSAPUB}\tcomment comment\n",
"",
"", "comment comment", []string{"localhost"}, "",
},
{
"localhost\tssh-rsa {RSAPUB}\tcomment comment\r\n",
"",
"", "comment comment", []string{"localhost"}, "",
},
{
"localhost\tssh-rsa {RSAPUB}\tcomment comment\r\nnext line",
"",
"", "comment comment", []string{"localhost"}, "next line",
},
{
"localhost,[host2:123]\tssh-rsa {RSAPUB}\tcomment comment",
"",
"", "comment comment", []string{"localhost", "[host2:123]"}, "",
},
{
"@marker \tlocalhost,[host2:123]\tssh-rsa {RSAPUB}",
"",
"marker", "", []string{"localhost", "[host2:123]"}, "",
},
{
"@marker \tlocalhost,[host2:123]\tssh-rsa aabbccdd",
"short read",
"", "", nil, "",
},
}
func TestKnownHostsParsing(t *testing.T) {
rsaPub, rsaPubSerialized := getTestKey()
for i, test := range knownHostsParseTests {
var expectedKey PublicKey
const rsaKeyToken = "{RSAPUB}"
input := test.input
if strings.Contains(input, rsaKeyToken) {
expectedKey = rsaPub
input = strings.Replace(test.input, rsaKeyToken, rsaPubSerialized, -1)
}
marker, hosts, pubKey, comment, rest, err := ParseKnownHosts([]byte(input))
if err != nil {
if len(test.err) == 0 {
t.Errorf("#%d: unexpectedly failed with %q", i, err)
} else if !strings.Contains(err.Error(), test.err) {
t.Errorf("#%d: expected error containing %q, but got %q", i, test.err, err)
}
continue
} else if len(test.err) != 0 {
t.Errorf("#%d: succeeded but expected error including %q", i, test.err)
continue
}
if !reflect.DeepEqual(expectedKey, pubKey) {
t.Errorf("#%d: expected key %#v, but got %#v", i, expectedKey, pubKey)
}
if marker != test.marker {
t.Errorf("#%d: expected marker %q, but got %q", i, test.marker, marker)
}
if comment != test.comment {
t.Errorf("#%d: expected comment %q, but got %q", i, test.comment, comment)
}
if !reflect.DeepEqual(test.hosts, hosts) {
t.Errorf("#%d: expected hosts %#v, but got %#v", i, test.hosts, hosts)
}
if rest := string(rest); rest != test.rest {
t.Errorf("#%d: expected remaining input to be %q, but got %q", i, test.rest, rest)
}
}
}
func TestFingerprintLegacyMD5(t *testing.T) {
pub, _ := getTestKey()
fingerprint := FingerprintLegacyMD5(pub)
want := "fb:61:6d:1a:e3:f0:95:45:3c:a0:79:be:4a:93:63:66" // ssh-keygen -lf -E md5 rsa
if fingerprint != want {
t.Errorf("got fingerprint %q want %q", fingerprint, want)
}
}
func TestFingerprintSHA256(t *testing.T) {
pub, _ := getTestKey()
fingerprint := FingerprintSHA256(pub)
want := "SHA256:Anr3LjZK8YVpjrxu79myrW9Hrb/wpcMNpVvTq/RcBm8" // ssh-keygen -lf rsa
if fingerprint != want {
t.Errorf("got fingerprint %q want %q", fingerprint, want)
}
}
func TestInvalidKeys(t *testing.T) {
keyTypes := []string{
"RSA PRIVATE KEY",
"PRIVATE KEY",
"EC PRIVATE KEY",
"DSA PRIVATE KEY",
"OPENSSH PRIVATE KEY",
}
for _, keyType := range keyTypes {
for _, dataLen := range []int{0, 1, 2, 5, 10, 20} {
data := make([]byte, dataLen)
if _, err := io.ReadFull(rand.Reader, data); err != nil {
t.Fatal(err)
}
var buf bytes.Buffer
pem.Encode(&buf, &pem.Block{
Type: keyType,
Bytes: data,
})
// This test is just to ensure that the function
// doesn't panic so the return value is ignored.
ParseRawPrivateKey(buf.Bytes())
}
}
}

540
vendor/golang.org/x/crypto/ssh/knownhosts/knownhosts.go generated vendored Normal file
View file

@ -0,0 +1,540 @@
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package knownhosts implements a parser for the OpenSSH known_hosts
// host key database, and provides utility functions for writing
// OpenSSH compliant known_hosts files.
package knownhosts
import (
"bufio"
"bytes"
"crypto/hmac"
"crypto/rand"
"crypto/sha1"
"encoding/base64"
"errors"
"fmt"
"io"
"net"
"os"
"strings"
"golang.org/x/crypto/ssh"
)
// See the sshd manpage
// (http://man.openbsd.org/sshd#SSH_KNOWN_HOSTS_FILE_FORMAT) for
// background.
type addr struct{ host, port string }
func (a *addr) String() string {
h := a.host
if strings.Contains(h, ":") {
h = "[" + h + "]"
}
return h + ":" + a.port
}
type matcher interface {
match(addr) bool
}
type hostPattern struct {
negate bool
addr addr
}
func (p *hostPattern) String() string {
n := ""
if p.negate {
n = "!"
}
return n + p.addr.String()
}
type hostPatterns []hostPattern
func (ps hostPatterns) match(a addr) bool {
matched := false
for _, p := range ps {
if !p.match(a) {
continue
}
if p.negate {
return false
}
matched = true
}
return matched
}
// See
// https://android.googlesource.com/platform/external/openssh/+/ab28f5495c85297e7a597c1ba62e996416da7c7e/addrmatch.c
// The matching of * has no regard for separators, unlike filesystem globs
func wildcardMatch(pat []byte, str []byte) bool {
for {
if len(pat) == 0 {
return len(str) == 0
}
if len(str) == 0 {
return false
}
if pat[0] == '*' {
if len(pat) == 1 {
return true
}
for j := range str {
if wildcardMatch(pat[1:], str[j:]) {
return true
}
}
return false
}
if pat[0] == '?' || pat[0] == str[0] {
pat = pat[1:]
str = str[1:]
} else {
return false
}
}
}
func (p *hostPattern) match(a addr) bool {
return wildcardMatch([]byte(p.addr.host), []byte(a.host)) && p.addr.port == a.port
}
type keyDBLine struct {
cert bool
matcher matcher
knownKey KnownKey
}
func serialize(k ssh.PublicKey) string {
return k.Type() + " " + base64.StdEncoding.EncodeToString(k.Marshal())
}
func (l *keyDBLine) match(a addr) bool {
return l.matcher.match(a)
}
type hostKeyDB struct {
// Serialized version of revoked keys
revoked map[string]*KnownKey
lines []keyDBLine
}
func newHostKeyDB() *hostKeyDB {
db := &hostKeyDB{
revoked: make(map[string]*KnownKey),
}
return db
}
func keyEq(a, b ssh.PublicKey) bool {
return bytes.Equal(a.Marshal(), b.Marshal())
}
// IsAuthorityForHost can be used as a callback in ssh.CertChecker
func (db *hostKeyDB) IsHostAuthority(remote ssh.PublicKey, address string) bool {
h, p, err := net.SplitHostPort(address)
if err != nil {
return false
}
a := addr{host: h, port: p}
for _, l := range db.lines {
if l.cert && keyEq(l.knownKey.Key, remote) && l.match(a) {
return true
}
}
return false
}
// IsRevoked can be used as a callback in ssh.CertChecker
func (db *hostKeyDB) IsRevoked(key *ssh.Certificate) bool {
_, ok := db.revoked[string(key.Marshal())]
return ok
}
const markerCert = "@cert-authority"
const markerRevoked = "@revoked"
func nextWord(line []byte) (string, []byte) {
i := bytes.IndexAny(line, "\t ")
if i == -1 {
return string(line), nil
}
return string(line[:i]), bytes.TrimSpace(line[i:])
}
func parseLine(line []byte) (marker, host string, key ssh.PublicKey, err error) {
if w, next := nextWord(line); w == markerCert || w == markerRevoked {
marker = w
line = next
}
host, line = nextWord(line)
if len(line) == 0 {
return "", "", nil, errors.New("knownhosts: missing host pattern")
}
// ignore the keytype as it's in the key blob anyway.
_, line = nextWord(line)
if len(line) == 0 {
return "", "", nil, errors.New("knownhosts: missing key type pattern")
}
keyBlob, _ := nextWord(line)
keyBytes, err := base64.StdEncoding.DecodeString(keyBlob)
if err != nil {
return "", "", nil, err
}
key, err = ssh.ParsePublicKey(keyBytes)
if err != nil {
return "", "", nil, err
}
return marker, host, key, nil
}
func (db *hostKeyDB) parseLine(line []byte, filename string, linenum int) error {
marker, pattern, key, err := parseLine(line)
if err != nil {
return err
}
if marker == markerRevoked {
db.revoked[string(key.Marshal())] = &KnownKey{
Key: key,
Filename: filename,
Line: linenum,
}
return nil
}
entry := keyDBLine{
cert: marker == markerCert,
knownKey: KnownKey{
Filename: filename,
Line: linenum,
Key: key,
},
}
if pattern[0] == '|' {
entry.matcher, err = newHashedHost(pattern)
} else {
entry.matcher, err = newHostnameMatcher(pattern)
}
if err != nil {
return err
}
db.lines = append(db.lines, entry)
return nil
}
func newHostnameMatcher(pattern string) (matcher, error) {
var hps hostPatterns
for _, p := range strings.Split(pattern, ",") {
if len(p) == 0 {
continue
}
var a addr
var negate bool
if p[0] == '!' {
negate = true
p = p[1:]
}
if len(p) == 0 {
return nil, errors.New("knownhosts: negation without following hostname")
}
var err error
if p[0] == '[' {
a.host, a.port, err = net.SplitHostPort(p)
if err != nil {
return nil, err
}
} else {
a.host, a.port, err = net.SplitHostPort(p)
if err != nil {
a.host = p
a.port = "22"
}
}
hps = append(hps, hostPattern{
negate: negate,
addr: a,
})
}
return hps, nil
}
// KnownKey represents a key declared in a known_hosts file.
type KnownKey struct {
Key ssh.PublicKey
Filename string
Line int
}
func (k *KnownKey) String() string {
return fmt.Sprintf("%s:%d: %s", k.Filename, k.Line, serialize(k.Key))
}
// KeyError is returned if we did not find the key in the host key
// database, or there was a mismatch. Typically, in batch
// applications, this should be interpreted as failure. Interactive
// applications can offer an interactive prompt to the user.
type KeyError struct {
// Want holds the accepted host keys. For each key algorithm,
// there can be one hostkey. If Want is empty, the host is
// unknown. If Want is non-empty, there was a mismatch, which
// can signify a MITM attack.
Want []KnownKey
}
func (u *KeyError) Error() string {
if len(u.Want) == 0 {
return "knownhosts: key is unknown"
}
return "knownhosts: key mismatch"
}
// RevokedError is returned if we found a key that was revoked.
type RevokedError struct {
Revoked KnownKey
}
func (r *RevokedError) Error() string {
return "knownhosts: key is revoked"
}
// check checks a key against the host database. This should not be
// used for verifying certificates.
func (db *hostKeyDB) check(address string, remote net.Addr, remoteKey ssh.PublicKey) error {
if revoked := db.revoked[string(remoteKey.Marshal())]; revoked != nil {
return &RevokedError{Revoked: *revoked}
}
host, port, err := net.SplitHostPort(remote.String())
if err != nil {
return fmt.Errorf("knownhosts: SplitHostPort(%s): %v", remote, err)
}
hostToCheck := addr{host, port}
if address != "" {
// Give preference to the hostname if available.
host, port, err := net.SplitHostPort(address)
if err != nil {
return fmt.Errorf("knownhosts: SplitHostPort(%s): %v", address, err)
}
hostToCheck = addr{host, port}
}
return db.checkAddr(hostToCheck, remoteKey)
}
// checkAddrs checks if we can find the given public key for any of
// the given addresses. If we only find an entry for the IP address,
// or only the hostname, then this still succeeds.
func (db *hostKeyDB) checkAddr(a addr, remoteKey ssh.PublicKey) error {
// TODO(hanwen): are these the right semantics? What if there
// is just a key for the IP address, but not for the
// hostname?
// Algorithm => key.
knownKeys := map[string]KnownKey{}
for _, l := range db.lines {
if l.match(a) {
typ := l.knownKey.Key.Type()
if _, ok := knownKeys[typ]; !ok {
knownKeys[typ] = l.knownKey
}
}
}
keyErr := &KeyError{}
for _, v := range knownKeys {
keyErr.Want = append(keyErr.Want, v)
}
// Unknown remote host.
if len(knownKeys) == 0 {
return keyErr
}
// If the remote host starts using a different, unknown key type, we
// also interpret that as a mismatch.
if known, ok := knownKeys[remoteKey.Type()]; !ok || !keyEq(known.Key, remoteKey) {
return keyErr
}
return nil
}
// The Read function parses file contents.
func (db *hostKeyDB) Read(r io.Reader, filename string) error {
scanner := bufio.NewScanner(r)
lineNum := 0
for scanner.Scan() {
lineNum++
line := scanner.Bytes()
line = bytes.TrimSpace(line)
if len(line) == 0 || line[0] == '#' {
continue
}
if err := db.parseLine(line, filename, lineNum); err != nil {
return fmt.Errorf("knownhosts: %s:%d: %v", filename, lineNum, err)
}
}
return scanner.Err()
}
// New creates a host key callback from the given OpenSSH host key
// files. The returned callback is for use in
// ssh.ClientConfig.HostKeyCallback. By preference, the key check
// operates on the hostname if available, i.e. if a server changes its
// IP address, the host key check will still succeed, even though a
// record of the new IP address is not available.
func New(files ...string) (ssh.HostKeyCallback, error) {
db := newHostKeyDB()
for _, fn := range files {
f, err := os.Open(fn)
if err != nil {
return nil, err
}
defer f.Close()
if err := db.Read(f, fn); err != nil {
return nil, err
}
}
var certChecker ssh.CertChecker
certChecker.IsHostAuthority = db.IsHostAuthority
certChecker.IsRevoked = db.IsRevoked
certChecker.HostKeyFallback = db.check
return certChecker.CheckHostKey, nil
}
// Normalize normalizes an address into the form used in known_hosts
func Normalize(address string) string {
host, port, err := net.SplitHostPort(address)
if err != nil {
host = address
port = "22"
}
entry := host
if port != "22" {
entry = "[" + entry + "]:" + port
} else if strings.Contains(host, ":") && !strings.HasPrefix(host, "[") {
entry = "[" + entry + "]"
}
return entry
}
// Line returns a line to add append to the known_hosts files.
func Line(addresses []string, key ssh.PublicKey) string {
var trimmed []string
for _, a := range addresses {
trimmed = append(trimmed, Normalize(a))
}
return strings.Join(trimmed, ",") + " " + serialize(key)
}
// HashHostname hashes the given hostname. The hostname is not
// normalized before hashing.
func HashHostname(hostname string) string {
// TODO(hanwen): check if we can safely normalize this always.
salt := make([]byte, sha1.Size)
_, err := rand.Read(salt)
if err != nil {
panic(fmt.Sprintf("crypto/rand failure %v", err))
}
hash := hashHost(hostname, salt)
return encodeHash(sha1HashType, salt, hash)
}
func decodeHash(encoded string) (hashType string, salt, hash []byte, err error) {
if len(encoded) == 0 || encoded[0] != '|' {
err = errors.New("knownhosts: hashed host must start with '|'")
return
}
components := strings.Split(encoded, "|")
if len(components) != 4 {
err = fmt.Errorf("knownhosts: got %d components, want 3", len(components))
return
}
hashType = components[1]
if salt, err = base64.StdEncoding.DecodeString(components[2]); err != nil {
return
}
if hash, err = base64.StdEncoding.DecodeString(components[3]); err != nil {
return
}
return
}
func encodeHash(typ string, salt []byte, hash []byte) string {
return strings.Join([]string{"",
typ,
base64.StdEncoding.EncodeToString(salt),
base64.StdEncoding.EncodeToString(hash),
}, "|")
}
// See https://android.googlesource.com/platform/external/openssh/+/ab28f5495c85297e7a597c1ba62e996416da7c7e/hostfile.c#120
func hashHost(hostname string, salt []byte) []byte {
mac := hmac.New(sha1.New, salt)
mac.Write([]byte(hostname))
return mac.Sum(nil)
}
type hashedHost struct {
salt []byte
hash []byte
}
const sha1HashType = "1"
func newHashedHost(encoded string) (*hashedHost, error) {
typ, salt, hash, err := decodeHash(encoded)
if err != nil {
return nil, err
}
// The type field seems for future algorithm agility, but it's
// actually hardcoded in openssh currently, see
// https://android.googlesource.com/platform/external/openssh/+/ab28f5495c85297e7a597c1ba62e996416da7c7e/hostfile.c#120
if typ != sha1HashType {
return nil, fmt.Errorf("knownhosts: got hash type %s, must be '1'", typ)
}
return &hashedHost{salt: salt, hash: hash}, nil
}
func (h *hashedHost) match(a addr) bool {
return bytes.Equal(hashHost(Normalize(a.String()), h.salt), h.hash)
}

View file

@ -0,0 +1,356 @@
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package knownhosts
import (
"bytes"
"fmt"
"net"
"reflect"
"testing"
"golang.org/x/crypto/ssh"
)
const edKeyStr = "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIGBAarftlLeoyf+v+nVchEZII/vna2PCV8FaX4vsF5BX"
const alternateEdKeyStr = "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIIXffBYeYL+WVzVru8npl5JHt2cjlr4ornFTWzoij9sx"
const ecKeyStr = "ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzdHAyNTYAAABBBNLCu01+wpXe3xB5olXCN4SqU2rQu0qjSRKJO4Bg+JRCPU+ENcgdA5srTU8xYDz/GEa4dzK5ldPw4J/gZgSXCMs="
var ecKey, alternateEdKey, edKey ssh.PublicKey
var testAddr = &net.TCPAddr{
IP: net.IP{198, 41, 30, 196},
Port: 22,
}
var testAddr6 = &net.TCPAddr{
IP: net.IP{198, 41, 30, 196,
1, 2, 3, 4,
1, 2, 3, 4,
1, 2, 3, 4,
},
Port: 22,
}
func init() {
var err error
ecKey, _, _, _, err = ssh.ParseAuthorizedKey([]byte(ecKeyStr))
if err != nil {
panic(err)
}
edKey, _, _, _, err = ssh.ParseAuthorizedKey([]byte(edKeyStr))
if err != nil {
panic(err)
}
alternateEdKey, _, _, _, err = ssh.ParseAuthorizedKey([]byte(alternateEdKeyStr))
if err != nil {
panic(err)
}
}
func testDB(t *testing.T, s string) *hostKeyDB {
db := newHostKeyDB()
if err := db.Read(bytes.NewBufferString(s), "testdb"); err != nil {
t.Fatalf("Read: %v", err)
}
return db
}
func TestRevoked(t *testing.T) {
db := testDB(t, "\n\n@revoked * "+edKeyStr+"\n")
want := &RevokedError{
Revoked: KnownKey{
Key: edKey,
Filename: "testdb",
Line: 3,
},
}
if err := db.check("", &net.TCPAddr{
Port: 42,
}, edKey); err == nil {
t.Fatal("no error for revoked key")
} else if !reflect.DeepEqual(want, err) {
t.Fatalf("got %#v, want %#v", want, err)
}
}
func TestHostAuthority(t *testing.T) {
for _, m := range []struct {
authorityFor string
address string
good bool
}{
{authorityFor: "localhost", address: "localhost:22", good: true},
{authorityFor: "localhost", address: "localhost", good: false},
{authorityFor: "localhost", address: "localhost:1234", good: false},
{authorityFor: "[localhost]:1234", address: "localhost:1234", good: true},
{authorityFor: "[localhost]:1234", address: "localhost:22", good: false},
{authorityFor: "[localhost]:1234", address: "localhost", good: false},
} {
db := testDB(t, `@cert-authority `+m.authorityFor+` `+edKeyStr)
if ok := db.IsHostAuthority(db.lines[0].knownKey.Key, m.address); ok != m.good {
t.Errorf("IsHostAuthority: authority %s, address %s, wanted good = %v, got good = %v",
m.authorityFor, m.address, m.good, ok)
}
}
}
func TestBracket(t *testing.T) {
db := testDB(t, `[git.eclipse.org]:29418,[198.41.30.196]:29418 `+edKeyStr)
if err := db.check("git.eclipse.org:29418", &net.TCPAddr{
IP: net.IP{198, 41, 30, 196},
Port: 29418,
}, edKey); err != nil {
t.Errorf("got error %v, want none", err)
}
if err := db.check("git.eclipse.org:29419", &net.TCPAddr{
Port: 42,
}, edKey); err == nil {
t.Fatalf("no error for unknown address")
} else if ke, ok := err.(*KeyError); !ok {
t.Fatalf("got type %T, want *KeyError", err)
} else if len(ke.Want) > 0 {
t.Fatalf("got Want %v, want []", ke.Want)
}
}
func TestNewKeyType(t *testing.T) {
str := fmt.Sprintf("%s %s", testAddr, edKeyStr)
db := testDB(t, str)
if err := db.check("", testAddr, ecKey); err == nil {
t.Fatalf("no error for unknown address")
} else if ke, ok := err.(*KeyError); !ok {
t.Fatalf("got type %T, want *KeyError", err)
} else if len(ke.Want) == 0 {
t.Fatalf("got empty KeyError.Want")
}
}
func TestSameKeyType(t *testing.T) {
str := fmt.Sprintf("%s %s", testAddr, edKeyStr)
db := testDB(t, str)
if err := db.check("", testAddr, alternateEdKey); err == nil {
t.Fatalf("no error for unknown address")
} else if ke, ok := err.(*KeyError); !ok {
t.Fatalf("got type %T, want *KeyError", err)
} else if len(ke.Want) == 0 {
t.Fatalf("got empty KeyError.Want")
} else if got, want := ke.Want[0].Key.Marshal(), edKey.Marshal(); !bytes.Equal(got, want) {
t.Fatalf("got key %q, want %q", got, want)
}
}
func TestIPAddress(t *testing.T) {
str := fmt.Sprintf("%s %s", testAddr, edKeyStr)
db := testDB(t, str)
if err := db.check("", testAddr, edKey); err != nil {
t.Errorf("got error %q, want none", err)
}
}
func TestIPv6Address(t *testing.T) {
str := fmt.Sprintf("%s %s", testAddr6, edKeyStr)
db := testDB(t, str)
if err := db.check("", testAddr6, edKey); err != nil {
t.Errorf("got error %q, want none", err)
}
}
func TestBasic(t *testing.T) {
str := fmt.Sprintf("#comment\n\nserver.org,%s %s\notherhost %s", testAddr, edKeyStr, ecKeyStr)
db := testDB(t, str)
if err := db.check("server.org:22", testAddr, edKey); err != nil {
t.Errorf("got error %v, want none", err)
}
want := KnownKey{
Key: edKey,
Filename: "testdb",
Line: 3,
}
if err := db.check("server.org:22", testAddr, ecKey); err == nil {
t.Errorf("succeeded, want KeyError")
} else if ke, ok := err.(*KeyError); !ok {
t.Errorf("got %T, want *KeyError", err)
} else if len(ke.Want) != 1 {
t.Errorf("got %v, want 1 entry", ke)
} else if !reflect.DeepEqual(ke.Want[0], want) {
t.Errorf("got %v, want %v", ke.Want[0], want)
}
}
func TestHostNamePrecedence(t *testing.T) {
var evilAddr = &net.TCPAddr{
IP: net.IP{66, 66, 66, 66},
Port: 22,
}
str := fmt.Sprintf("server.org,%s %s\nevil.org,%s %s", testAddr, edKeyStr, evilAddr, ecKeyStr)
db := testDB(t, str)
if err := db.check("server.org:22", evilAddr, ecKey); err == nil {
t.Errorf("check succeeded")
} else if _, ok := err.(*KeyError); !ok {
t.Errorf("got %T, want *KeyError", err)
}
}
func TestDBOrderingPrecedenceKeyType(t *testing.T) {
str := fmt.Sprintf("server.org,%s %s\nserver.org,%s %s", testAddr, edKeyStr, testAddr, alternateEdKeyStr)
db := testDB(t, str)
if err := db.check("server.org:22", testAddr, alternateEdKey); err == nil {
t.Errorf("check succeeded")
} else if _, ok := err.(*KeyError); !ok {
t.Errorf("got %T, want *KeyError", err)
}
}
func TestNegate(t *testing.T) {
str := fmt.Sprintf("%s,!server.org %s", testAddr, edKeyStr)
db := testDB(t, str)
if err := db.check("server.org:22", testAddr, ecKey); err == nil {
t.Errorf("succeeded")
} else if ke, ok := err.(*KeyError); !ok {
t.Errorf("got error type %T, want *KeyError", err)
} else if len(ke.Want) != 0 {
t.Errorf("got expected keys %d (first of type %s), want []", len(ke.Want), ke.Want[0].Key.Type())
}
}
func TestWildcard(t *testing.T) {
str := fmt.Sprintf("server*.domain %s", edKeyStr)
db := testDB(t, str)
want := &KeyError{
Want: []KnownKey{{
Filename: "testdb",
Line: 1,
Key: edKey,
}},
}
got := db.check("server.domain:22", &net.TCPAddr{}, ecKey)
if !reflect.DeepEqual(got, want) {
t.Errorf("got %s, want %s", got, want)
}
}
func TestLine(t *testing.T) {
for in, want := range map[string]string{
"server.org": "server.org " + edKeyStr,
"server.org:22": "server.org " + edKeyStr,
"server.org:23": "[server.org]:23 " + edKeyStr,
"[c629:1ec4:102:304:102:304:102:304]:22": "[c629:1ec4:102:304:102:304:102:304] " + edKeyStr,
"[c629:1ec4:102:304:102:304:102:304]:23": "[c629:1ec4:102:304:102:304:102:304]:23 " + edKeyStr,
} {
if got := Line([]string{in}, edKey); got != want {
t.Errorf("Line(%q) = %q, want %q", in, got, want)
}
}
}
func TestWildcardMatch(t *testing.T) {
for _, c := range []struct {
pat, str string
want bool
}{
{"a?b", "abb", true},
{"ab", "abc", false},
{"abc", "ab", false},
{"a*b", "axxxb", true},
{"a*b", "axbxb", true},
{"a*b", "axbxbc", false},
{"a*?", "axbxc", true},
{"a*b*", "axxbxxxxxx", true},
{"a*b*c", "axxbxxxxxxc", true},
{"a*b*?", "axxbxxxxxxc", true},
{"a*b*z", "axxbxxbxxxz", true},
{"a*b*z", "axxbxxzxxxz", true},
{"a*b*z", "axxbxxzxxx", false},
} {
got := wildcardMatch([]byte(c.pat), []byte(c.str))
if got != c.want {
t.Errorf("wildcardMatch(%q, %q) = %v, want %v", c.pat, c.str, got, c.want)
}
}
}
// TODO(hanwen): test coverage for certificates.
const testHostname = "hostname"
// generated with keygen -H -f
const encodedTestHostnameHash = "|1|IHXZvQMvTcZTUU29+2vXFgx8Frs=|UGccIWfRVDwilMBnA3WJoRAC75Y="
func TestHostHash(t *testing.T) {
testHostHash(t, testHostname, encodedTestHostnameHash)
}
func TestHashList(t *testing.T) {
encoded := HashHostname(testHostname)
testHostHash(t, testHostname, encoded)
}
func testHostHash(t *testing.T, hostname, encoded string) {
typ, salt, hash, err := decodeHash(encoded)
if err != nil {
t.Fatalf("decodeHash: %v", err)
}
if got := encodeHash(typ, salt, hash); got != encoded {
t.Errorf("got encoding %s want %s", got, encoded)
}
if typ != sha1HashType {
t.Fatalf("got hash type %q, want %q", typ, sha1HashType)
}
got := hashHost(hostname, salt)
if !bytes.Equal(got, hash) {
t.Errorf("got hash %x want %x", got, hash)
}
}
func TestNormalize(t *testing.T) {
for in, want := range map[string]string{
"127.0.0.1:22": "127.0.0.1",
"[127.0.0.1]:22": "127.0.0.1",
"[127.0.0.1]:23": "[127.0.0.1]:23",
"127.0.0.1:23": "[127.0.0.1]:23",
"[a.b.c]:22": "a.b.c",
"[abcd:abcd:abcd:abcd]": "[abcd:abcd:abcd:abcd]",
"[abcd:abcd:abcd:abcd]:22": "[abcd:abcd:abcd:abcd]",
"[abcd:abcd:abcd:abcd]:23": "[abcd:abcd:abcd:abcd]:23",
} {
got := Normalize(in)
if got != want {
t.Errorf("Normalize(%q) = %q, want %q", in, got, want)
}
}
}
func TestHashedHostkeyCheck(t *testing.T) {
str := fmt.Sprintf("%s %s", HashHostname(testHostname), edKeyStr)
db := testDB(t, str)
if err := db.check(testHostname+":22", testAddr, edKey); err != nil {
t.Errorf("check(%s): %v", testHostname, err)
}
want := &KeyError{
Want: []KnownKey{{
Filename: "testdb",
Line: 1,
Key: edKey,
}},
}
if got := db.check(testHostname+":22", testAddr, alternateEdKey); !reflect.DeepEqual(got, want) {
t.Errorf("got error %v, want %v", got, want)
}
}

View file

@ -9,11 +9,13 @@ package ssh
import (
"crypto/hmac"
"crypto/sha1"
"crypto/sha256"
"hash"
)
type macMode struct {
keySize int
etm bool
new func(key []byte) hash.Hash
}
@ -44,10 +46,16 @@ func (t truncatingMAC) Size() int {
func (t truncatingMAC) BlockSize() int { return t.hmac.BlockSize() }
var macModes = map[string]*macMode{
"hmac-sha1": {20, func(key []byte) hash.Hash {
"hmac-sha2-256-etm@openssh.com": {32, true, func(key []byte) hash.Hash {
return hmac.New(sha256.New, key)
}},
"hmac-sha2-256": {32, false, func(key []byte) hash.Hash {
return hmac.New(sha256.New, key)
}},
"hmac-sha1": {20, false, func(key []byte) hash.Hash {
return hmac.New(sha1.New, key)
}},
"hmac-sha1-96": {20, func(key []byte) hash.Hash {
"hmac-sha1-96": {20, false, func(key []byte) hash.Hash {
return truncatingMAC{12, hmac.New(sha1.New, key)}
}},
}

View file

@ -76,7 +76,7 @@ func memPipe() (a, b packetConn) {
return &t1, &t2
}
func TestmemPipe(t *testing.T) {
func TestMemPipe(t *testing.T) {
a, b := memPipe()
if err := a.writePacket([]byte{42}); err != nil {
t.Fatalf("writePacket: %v", err)

View file

@ -13,6 +13,7 @@ import (
"math/big"
"reflect"
"strconv"
"strings"
)
// These are SSH message type numbers. They are scattered around several
@ -22,10 +23,6 @@ const (
msgUnimplemented = 3
msgDebug = 4
msgNewKeys = 21
// Standard authentication messages
msgUserAuthSuccess = 52
msgUserAuthBanner = 53
)
// SSH messages:
@ -47,7 +44,7 @@ type disconnectMsg struct {
}
func (d *disconnectMsg) Error() string {
return fmt.Sprintf("ssh: disconnect reason %d: %s", d.Reason, d.Message)
return fmt.Sprintf("ssh: disconnect, reason %d: %s", d.Reason, d.Message)
}
// See RFC 4253, section 7.1.
@ -124,6 +121,10 @@ type userAuthRequestMsg struct {
Payload []byte `ssh:"rest"`
}
// Used for debug printouts of packets.
type userAuthSuccessMsg struct {
}
// See RFC 4252, section 5.1
const msgUserAuthFailure = 51
@ -132,6 +133,18 @@ type userAuthFailureMsg struct {
PartialSuccess bool
}
// See RFC 4252, section 5.1
const msgUserAuthSuccess = 52
// See RFC 4252, section 5.4
const msgUserAuthBanner = 53
type userAuthBannerMsg struct {
Message string `sshtype:"53"`
// unused, but required to allow message parsing
Language string
}
// See RFC 4256, section 3.2
const msgUserAuthInfoRequest = 60
const msgUserAuthInfoResponse = 61
@ -149,7 +162,7 @@ const msgChannelOpen = 90
type channelOpenMsg struct {
ChanType string `sshtype:"90"`
PeersId uint32
PeersID uint32
PeersWindow uint32
MaxPacketSize uint32
TypeSpecificData []byte `ssh:"rest"`
@ -158,12 +171,19 @@ type channelOpenMsg struct {
const msgChannelExtendedData = 95
const msgChannelData = 94
// Used for debug print outs of packets.
type channelDataMsg struct {
PeersID uint32 `sshtype:"94"`
Length uint32
Rest []byte `ssh:"rest"`
}
// See RFC 4254, section 5.1.
const msgChannelOpenConfirm = 91
type channelOpenConfirmMsg struct {
PeersId uint32 `sshtype:"91"`
MyId uint32
PeersID uint32 `sshtype:"91"`
MyID uint32
MyWindow uint32
MaxPacketSize uint32
TypeSpecificData []byte `ssh:"rest"`
@ -173,7 +193,7 @@ type channelOpenConfirmMsg struct {
const msgChannelOpenFailure = 92
type channelOpenFailureMsg struct {
PeersId uint32 `sshtype:"92"`
PeersID uint32 `sshtype:"92"`
Reason RejectionReason
Message string
Language string
@ -182,7 +202,7 @@ type channelOpenFailureMsg struct {
const msgChannelRequest = 98
type channelRequestMsg struct {
PeersId uint32 `sshtype:"98"`
PeersID uint32 `sshtype:"98"`
Request string
WantReply bool
RequestSpecificData []byte `ssh:"rest"`
@ -192,28 +212,28 @@ type channelRequestMsg struct {
const msgChannelSuccess = 99
type channelRequestSuccessMsg struct {
PeersId uint32 `sshtype:"99"`
PeersID uint32 `sshtype:"99"`
}
// See RFC 4254, section 5.4.
const msgChannelFailure = 100
type channelRequestFailureMsg struct {
PeersId uint32 `sshtype:"100"`
PeersID uint32 `sshtype:"100"`
}
// See RFC 4254, section 5.3
const msgChannelClose = 97
type channelCloseMsg struct {
PeersId uint32 `sshtype:"97"`
PeersID uint32 `sshtype:"97"`
}
// See RFC 4254, section 5.3
const msgChannelEOF = 96
type channelEOFMsg struct {
PeersId uint32 `sshtype:"96"`
PeersID uint32 `sshtype:"96"`
}
// See RFC 4254, section 4
@ -243,7 +263,7 @@ type globalRequestFailureMsg struct {
const msgChannelWindowAdjust = 93
type windowAdjustMsg struct {
PeersId uint32 `sshtype:"93"`
PeersID uint32 `sshtype:"93"`
AdditionalBytes uint32
}
@ -255,17 +275,19 @@ type userAuthPubKeyOkMsg struct {
PubKey []byte
}
// typeTag returns the type byte for the given type. The type should
// be struct.
func typeTag(structType reflect.Type) byte {
var tag byte
var tagStr string
tagStr = structType.Field(0).Tag.Get("sshtype")
i, err := strconv.Atoi(tagStr)
if err == nil {
tag = byte(i)
// typeTags returns the possible type bytes for the given reflect.Type, which
// should be a struct. The possible values are separated by a '|' character.
func typeTags(structType reflect.Type) (tags []byte) {
tagStr := structType.Field(0).Tag.Get("sshtype")
for _, tag := range strings.Split(tagStr, "|") {
i, err := strconv.Atoi(tag)
if err == nil {
tags = append(tags, byte(i))
}
}
return tag
return tags
}
func fieldError(t reflect.Type, field int, problem string) error {
@ -279,19 +301,34 @@ var errShortRead = errors.New("ssh: short read")
// Unmarshal parses data in SSH wire format into a structure. The out
// argument should be a pointer to struct. If the first member of the
// struct has the "sshtype" tag set to a number in decimal, the packet
// must start that number. In case of error, Unmarshal returns a
// ParseError or UnexpectedMessageError.
// struct has the "sshtype" tag set to a '|'-separated set of numbers
// in decimal, the packet must start with one of those numbers. In
// case of error, Unmarshal returns a ParseError or
// UnexpectedMessageError.
func Unmarshal(data []byte, out interface{}) error {
v := reflect.ValueOf(out).Elem()
structType := v.Type()
expectedType := typeTag(structType)
expectedTypes := typeTags(structType)
var expectedType byte
if len(expectedTypes) > 0 {
expectedType = expectedTypes[0]
}
if len(data) == 0 {
return parseError(expectedType)
}
if expectedType > 0 {
if data[0] != expectedType {
return unexpectedMessageError(expectedType, data[0])
if len(expectedTypes) > 0 {
goodType := false
for _, e := range expectedTypes {
if e > 0 && data[0] == e {
goodType = true
break
}
}
if !goodType {
return fmt.Errorf("ssh: unexpected message type %d (expected one of %v)", data[0], expectedTypes)
}
data = data[1:]
}
@ -375,7 +412,7 @@ func Unmarshal(data []byte, out interface{}) error {
return fieldError(structType, i, "pointer to unsupported type")
}
default:
return fieldError(structType, i, "unsupported type")
return fieldError(structType, i, fmt.Sprintf("unsupported type: %v", t))
}
}
@ -398,9 +435,9 @@ func Marshal(msg interface{}) []byte {
func marshalStruct(out []byte, msg interface{}) []byte {
v := reflect.Indirect(reflect.ValueOf(msg))
msgType := typeTag(v.Type())
if msgType > 0 {
out = append(out, msgType)
msgTypes := typeTags(v.Type())
if len(msgTypes) > 0 {
out = append(out, msgTypes[0])
}
for i, n := 0, v.NumField(); i < n; i++ {
@ -484,11 +521,12 @@ func parseString(in []byte) (out, rest []byte, ok bool) {
return
}
length := binary.BigEndian.Uint32(in)
if uint32(len(in)) < 4+length {
in = in[4:]
if uint32(len(in)) < length {
return
}
out = in[4 : 4+length]
rest = in[4+length:]
out = in[:length]
rest = in[length:]
ok = true
return
}
@ -686,6 +724,8 @@ func decode(packet []byte) (interface{}, error) {
msg = new(kexDHReplyMsg)
case msgUserAuthRequest:
msg = new(userAuthRequestMsg)
case msgUserAuthSuccess:
return new(userAuthSuccessMsg), nil
case msgUserAuthFailure:
msg = new(userAuthFailureMsg)
case msgUserAuthPubKeyOk:
@ -698,6 +738,8 @@ func decode(packet []byte) (interface{}, error) {
msg = new(globalRequestFailureMsg)
case msgChannelOpen:
msg = new(channelOpenMsg)
case msgChannelData:
msg = new(channelDataMsg)
case msgChannelOpenConfirm:
msg = new(channelOpenConfirmMsg)
case msgChannelOpenFailure:

View file

@ -162,6 +162,50 @@ func TestBareMarshal(t *testing.T) {
}
}
func TestUnmarshalShortKexInitPacket(t *testing.T) {
// This used to panic.
// Issue 11348
packet := []byte{0x14, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0xff, 0xff, 0xff, 0xff}
kim := &kexInitMsg{}
if err := Unmarshal(packet, kim); err == nil {
t.Error("truncated packet unmarshaled without error")
}
}
func TestMarshalMultiTag(t *testing.T) {
var res struct {
A uint32 `sshtype:"1|2"`
}
good1 := struct {
A uint32 `sshtype:"1"`
}{
1,
}
good2 := struct {
A uint32 `sshtype:"2"`
}{
1,
}
if e := Unmarshal(Marshal(good1), &res); e != nil {
t.Errorf("error unmarshaling multipart tag: %v", e)
}
if e := Unmarshal(Marshal(good2), &res); e != nil {
t.Errorf("error unmarshaling multipart tag: %v", e)
}
bad1 := struct {
A uint32 `sshtype:"3"`
}{
1,
}
if e := Unmarshal(Marshal(bad1), &res); e == nil {
t.Errorf("bad struct unmarshaled without error")
}
}
func randomBytes(out []byte, rand *rand.Rand) {
for i := 0; i < len(out); i++ {
out[i] = byte(rand.Int31())

View file

@ -116,9 +116,9 @@ func (m *mux) Wait() error {
func newMux(p packetConn) *mux {
m := &mux{
conn: p,
incomingChannels: make(chan NewChannel, 16),
incomingChannels: make(chan NewChannel, chanSize),
globalResponses: make(chan interface{}, 1),
incomingRequests: make(chan *Request, 16),
incomingRequests: make(chan *Request, chanSize),
errCond: newCond(),
}
if debugMux {
@ -131,6 +131,9 @@ func newMux(p packetConn) *mux {
func (m *mux) sendMessage(msg interface{}) error {
p := Marshal(msg)
if debugMux {
log.Printf("send global(%d): %#v", m.chanList.offset, msg)
}
return m.conn.writePacket(p)
}
@ -175,18 +178,6 @@ func (m *mux) ackRequest(ok bool, data []byte) error {
return m.sendMessage(globalRequestFailureMsg{Data: data})
}
// TODO(hanwen): Disconnect is a transport layer message. We should
// probably send and receive Disconnect somewhere in the transport
// code.
// Disconnect sends a disconnect message.
func (m *mux) Disconnect(reason uint32, message string) error {
return m.sendMessage(disconnectMsg{
Reason: reason,
Message: message,
})
}
func (m *mux) Close() error {
return m.conn.Close()
}
@ -236,11 +227,6 @@ func (m *mux) onePacket() error {
}
switch packet[0] {
case msgNewKeys:
// Ignore notification of key change.
return nil
case msgDisconnect:
return m.handleDisconnect(packet)
case msgChannelOpen:
return m.handleChannelOpen(packet)
case msgGlobalRequest, msgRequestSuccess, msgRequestFailure:
@ -260,18 +246,6 @@ func (m *mux) onePacket() error {
return ch.handlePacket(packet)
}
func (m *mux) handleDisconnect(packet []byte) error {
var d disconnectMsg
if err := Unmarshal(packet, &d); err != nil {
return err
}
if debugMux {
log.Printf("caught disconnect: %v", d)
}
return &d
}
func (m *mux) handleGlobalPacket(packet []byte) error {
msg, err := decode(packet)
if err != nil {
@ -304,7 +278,7 @@ func (m *mux) handleChannelOpen(packet []byte) error {
if msg.MaxPacketSize < minPacketLength || msg.MaxPacketSize > 1<<31 {
failMsg := channelOpenFailureMsg{
PeersId: msg.PeersId,
PeersID: msg.PeersID,
Reason: ConnectionFailed,
Message: "invalid request",
Language: "en_US.UTF-8",
@ -313,7 +287,7 @@ func (m *mux) handleChannelOpen(packet []byte) error {
}
c := m.newChannel(msg.ChanType, channelInbound, msg.TypeSpecificData)
c.remoteId = msg.PeersId
c.remoteId = msg.PeersID
c.maxRemotePayload = msg.MaxPacketSize
c.remoteWin.add(msg.PeersWindow)
m.incomingChannels <- c
@ -339,7 +313,7 @@ func (m *mux) openChannel(chanType string, extra []byte) (*channel, error) {
PeersWindow: ch.myWindow,
MaxPacketSize: ch.maxIncomingPayload,
TypeSpecificData: extra,
PeersId: ch.localId,
PeersID: ch.localId,
}
if err := m.sendMessage(open); err != nil {
return nil, err

View file

@ -20,7 +20,7 @@ func muxPair() (*mux, *mux) {
return s, c
}
// Returns both ends of a channel, and the mux for the the 2nd
// Returns both ends of a channel, and the mux for the 2nd
// channel.
func channelPair(t *testing.T) (*channel, *channel, *mux) {
c, s := muxPair()
@ -108,10 +108,6 @@ func TestMuxReadWrite(t *testing.T) {
if err != nil {
t.Fatalf("Write: %v", err)
}
err = s.Close()
if err != nil {
t.Fatalf("Close: %v", err)
}
}()
var buf [1024]byte
@ -331,7 +327,6 @@ func TestMuxGlobalRequest(t *testing.T) {
ok, data, err)
}
clientMux.Disconnect(0, "")
if !seen {
t.Errorf("never saw 'peek' request")
}
@ -378,28 +373,6 @@ func TestMuxChannelRequestUnblock(t *testing.T) {
}
}
func TestMuxDisconnect(t *testing.T) {
a, b := muxPair()
defer a.Close()
defer b.Close()
go func() {
for r := range b.incomingRequests {
r.Reply(true, nil)
}
}()
a.Disconnect(42, "whatever")
ok, _, err := a.SendRequest("hello", true, nil)
if ok || err == nil {
t.Errorf("got reply after disconnecting")
}
err = b.Wait()
if d, ok := err.(*disconnectMsg); !ok || d.Reason != 42 {
t.Errorf("got %#v, want disconnectMsg{Reason:42}", err)
}
}
func TestMuxCloseChannel(t *testing.T) {
r, w, mux := channelPair(t)
defer mux.Close()
@ -522,4 +495,7 @@ func TestDebug(t *testing.T) {
if debugHandshake {
t.Error("handshake debug switched on")
}
if debugTransport {
t.Error("transport debug switched on")
}
}

View file

@ -10,26 +10,38 @@ import (
"fmt"
"io"
"net"
"strings"
)
// The Permissions type holds fine-grained permissions that are
// specific to a user or a specific authentication method for a
// user. Permissions, except for "source-address", must be enforced in
// the server application layer, after successful authentication. The
// Permissions are passed on in ServerConn so a server implementation
// can honor them.
// specific to a user or a specific authentication method for a user.
// The Permissions value for a successful authentication attempt is
// available in ServerConn, so it can be used to pass information from
// the user-authentication phase to the application layer.
type Permissions struct {
// Critical options restrict default permissions. Common
// restrictions are "source-address" and "force-command". If
// the server cannot enforce the restriction, or does not
// recognize it, the user should not authenticate.
// CriticalOptions indicate restrictions to the default
// permissions, and are typically used in conjunction with
// user certificates. The standard for SSH certificates
// defines "force-command" (only allow the given command to
// execute) and "source-address" (only allow connections from
// the given address). The SSH package currently only enforces
// the "source-address" critical option. It is up to server
// implementations to enforce other critical options, such as
// "force-command", by checking them after the SSH handshake
// is successful. In general, SSH servers should reject
// connections that specify critical options that are unknown
// or not supported.
CriticalOptions map[string]string
// Extensions are extra functionality that the server may
// offer on authenticated connections. Common extensions are
// "permit-agent-forwarding", "permit-X11-forwarding". Lack of
// support for an extension does not preclude authenticating a
// user.
// offer on authenticated connections. Lack of support for an
// extension does not preclude authenticating a user. Common
// extensions are "permit-agent-forwarding",
// "permit-X11-forwarding". The Go SSH library currently does
// not act on any extension, and it is up to server
// implementations to honor them. Extensions can be used to
// pass data from the authentication callbacks to the server
// application layer.
Extensions map[string]string
}
@ -44,13 +56,24 @@ type ServerConfig struct {
// authenticating.
NoClientAuth bool
// MaxAuthTries specifies the maximum number of authentication attempts
// permitted per connection. If set to a negative number, the number of
// attempts are unlimited. If set to zero, the number of attempts are limited
// to 6.
MaxAuthTries int
// PasswordCallback, if non-nil, is called when a user
// attempts to authenticate using a password.
PasswordCallback func(conn ConnMetadata, password []byte) (*Permissions, error)
// PublicKeyCallback, if non-nil, is called when a client attempts public
// key authentication. It must return true if the given public key is
// valid for the given user. For example, see CertChecker.Authenticate.
// PublicKeyCallback, if non-nil, is called when a client
// offers a public key for authentication. It must return a nil error
// if the given public key can be used to authenticate the
// given user. For example, see CertChecker.Authenticate. A
// call to this function does not guarantee that the key
// offered is in fact used to authenticate. To record any data
// depending on the public key, store it inside a
// Permissions.Extensions entry.
PublicKeyCallback func(conn ConnMetadata, key PublicKey) (*Permissions, error)
// KeyboardInteractiveCallback, if non-nil, is called when
@ -66,10 +89,16 @@ type ServerConfig struct {
// attempts.
AuthLogCallback func(conn ConnMetadata, method string, err error)
// ServerVersion is the version identification string to
// announce in the public handshake.
// ServerVersion is the version identification string to announce in
// the public handshake.
// If empty, a reasonable default is used.
// Note that RFC 4253 section 4.2 requires that this string start with
// "SSH-2.0-".
ServerVersion string
// BannerCallback, if present, is called and the return string is sent to
// the client after key exchange completed but before authentication.
BannerCallback func(conn ConnMetadata) string
}
// AddHostKey adds a private key as a host key. If an existing host
@ -137,9 +166,16 @@ type ServerConn struct {
// unsuccessful, it closes the connection and returns an error. The
// Request and NewChannel channels must be serviced, or the connection
// will hang.
//
// The returned error may be of type *ServerAuthError for
// authentication errors.
func NewServerConn(c net.Conn, config *ServerConfig) (*ServerConn, <-chan NewChannel, <-chan *Request, error) {
fullConf := *config
fullConf.SetDefaults()
if fullConf.MaxAuthTries == 0 {
fullConf.MaxAuthTries = 6
}
s := &connection{
sshConn: sshConn{conn: c},
}
@ -168,6 +204,10 @@ func (s *connection) serverHandshake(config *ServerConfig) (*Permissions, error)
return nil, errors.New("ssh: server has no host keys")
}
if !config.NoClientAuth && config.PasswordCallback == nil && config.PublicKeyCallback == nil && config.KeyboardInteractiveCallback == nil {
return nil, errors.New("ssh: no authentication methods configured but NoClientAuth is also false")
}
if config.ServerVersion != "" {
s.serverVersion = []byte(config.ServerVersion)
} else {
@ -182,16 +222,10 @@ func (s *connection) serverHandshake(config *ServerConfig) (*Permissions, error)
tr := newTransport(s.sshConn.conn, config.Rand, false /* not client */)
s.transport = newServerTransport(tr, s.clientVersion, s.serverVersion, config)
if err := s.transport.requestKeyChange(); err != nil {
if err := s.transport.waitSession(); err != nil {
return nil, err
}
if packet, err := s.transport.readPacket(); err != nil {
return nil, err
} else if packet[0] != msgNewKeys {
return nil, unexpectedMessageError(msgNewKeys, packet[0])
}
// We just did the key change, so the session ID is established.
s.sessionID = s.transport.getSessionID()
@ -224,14 +258,14 @@ func (s *connection) serverHandshake(config *ServerConfig) (*Permissions, error)
func isAcceptableAlgo(algo string) bool {
switch algo {
case KeyAlgoRSA, KeyAlgoDSA, KeyAlgoECDSA256, KeyAlgoECDSA384, KeyAlgoECDSA521,
CertAlgoRSAv01, CertAlgoDSAv01, CertAlgoECDSA256v01, CertAlgoECDSA384v01, CertAlgoECDSA521v01:
case KeyAlgoRSA, KeyAlgoDSA, KeyAlgoECDSA256, KeyAlgoECDSA384, KeyAlgoECDSA521, KeyAlgoED25519,
CertAlgoRSAv01, CertAlgoDSAv01, CertAlgoECDSA256v01, CertAlgoECDSA384v01, CertAlgoECDSA521v01, CertAlgoED25519v01:
return true
}
return false
}
func checkSourceAddress(addr net.Addr, sourceAddr string) error {
func checkSourceAddress(addr net.Addr, sourceAddrs string) error {
if addr == nil {
return errors.New("ssh: no address known for client, but source-address match required")
}
@ -241,33 +275,80 @@ func checkSourceAddress(addr net.Addr, sourceAddr string) error {
return fmt.Errorf("ssh: remote address %v is not an TCP address when checking source-address match", addr)
}
if allowedIP := net.ParseIP(sourceAddr); allowedIP != nil {
if bytes.Equal(allowedIP, tcpAddr.IP) {
return nil
}
} else {
_, ipNet, err := net.ParseCIDR(sourceAddr)
if err != nil {
return fmt.Errorf("ssh: error parsing source-address restriction %q: %v", sourceAddr, err)
}
for _, sourceAddr := range strings.Split(sourceAddrs, ",") {
if allowedIP := net.ParseIP(sourceAddr); allowedIP != nil {
if allowedIP.Equal(tcpAddr.IP) {
return nil
}
} else {
_, ipNet, err := net.ParseCIDR(sourceAddr)
if err != nil {
return fmt.Errorf("ssh: error parsing source-address restriction %q: %v", sourceAddr, err)
}
if ipNet.Contains(tcpAddr.IP) {
return nil
if ipNet.Contains(tcpAddr.IP) {
return nil
}
}
}
return fmt.Errorf("ssh: remote address %v is not allowed because of source-address restriction", addr)
}
// ServerAuthError represents server authentication errors and is
// sometimes returned by NewServerConn. It appends any authentication
// errors that may occur, and is returned if all of the authentication
// methods provided by the user failed to authenticate.
type ServerAuthError struct {
// Errors contains authentication errors returned by the authentication
// callback methods. The first entry is typically ErrNoAuth.
Errors []error
}
func (l ServerAuthError) Error() string {
var errs []string
for _, err := range l.Errors {
errs = append(errs, err.Error())
}
return "[" + strings.Join(errs, ", ") + "]"
}
// ErrNoAuth is the error value returned if no
// authentication method has been passed yet. This happens as a normal
// part of the authentication loop, since the client first tries
// 'none' authentication to discover available methods.
// It is returned in ServerAuthError.Errors from NewServerConn.
var ErrNoAuth = errors.New("ssh: no auth passed yet")
func (s *connection) serverAuthenticate(config *ServerConfig) (*Permissions, error) {
var err error
sessionID := s.transport.getSessionID()
var cache pubKeyCache
var perms *Permissions
authFailures := 0
var authErrs []error
var displayedBanner bool
userAuthLoop:
for {
if authFailures >= config.MaxAuthTries && config.MaxAuthTries > 0 {
discMsg := &disconnectMsg{
Reason: 2,
Message: "too many authentication failures",
}
if err := s.transport.writePacket(Marshal(discMsg)); err != nil {
return nil, err
}
return nil, discMsg
}
var userAuthReq userAuthRequestMsg
if packet, err := s.transport.readPacket(); err != nil {
if err == io.EOF {
return nil, &ServerAuthError{Errors: authErrs}
}
return nil, err
} else if err = Unmarshal(packet, &userAuthReq); err != nil {
return nil, err
@ -278,15 +359,33 @@ userAuthLoop:
}
s.user = userAuthReq.User
if !displayedBanner && config.BannerCallback != nil {
displayedBanner = true
msg := config.BannerCallback(s)
if msg != "" {
bannerMsg := &userAuthBannerMsg{
Message: msg,
}
if err := s.transport.writePacket(Marshal(bannerMsg)); err != nil {
return nil, err
}
}
}
perms = nil
authErr := errors.New("no auth passed yet")
authErr := ErrNoAuth
switch userAuthReq.Method {
case "none":
if config.NoClientAuth {
s.user = ""
authErr = nil
}
// allow initial attempt of 'none' without penalty
if authFailures == 0 {
authFailures--
}
case "password":
if config.PasswordCallback == nil {
authErr = errors.New("ssh: password auth not configured")
@ -305,7 +404,7 @@ userAuthLoop:
perms, authErr = config.PasswordCallback(s, password)
case "keyboard-interactive":
if config.KeyboardInteractiveCallback == nil {
authErr = errors.New("ssh: keyboard-interactive auth not configubred")
authErr = errors.New("ssh: keyboard-interactive auth not configured")
break
}
@ -358,6 +457,7 @@ userAuthLoop:
if isQuery {
// The client can query if the given public key
// would be okay.
if len(payload) > 0 {
return nil, parseError(msgUserAuthRequest)
}
@ -384,9 +484,10 @@ userAuthLoop:
// sig.Format. This is usually the same, but
// for certs, the names differ.
if !isAcceptableAlgo(sig.Format) {
authErr = fmt.Errorf("ssh: algorithm %q not accepted", sig.Format)
break
}
signedData := buildDataSignedForAuth(s.transport.getSessionID(), userAuthReq, algoBytes, pubKeyData)
signedData := buildDataSignedForAuth(sessionID, userAuthReq, algoBytes, pubKeyData)
if err := pubKey.Verify(signedData, sig); err != nil {
return nil, err
@ -399,6 +500,8 @@ userAuthLoop:
authErr = fmt.Errorf("ssh: unknown method %q", userAuthReq.Method)
}
authErrs = append(authErrs, authErr)
if config.AuthLogCallback != nil {
config.AuthLogCallback(s, userAuthReq.Method, authErr)
}
@ -407,6 +510,8 @@ userAuthLoop:
break userAuthLoop
}
authFailures++
var failureMsg userAuthFailureMsg
if config.PasswordCallback != nil {
failureMsg.Methods = append(failureMsg.Methods, "password")
@ -422,12 +527,12 @@ userAuthLoop:
return nil, errors.New("ssh: no authentication methods configured but NoClientAuth is also false")
}
if err = s.transport.writePacket(Marshal(&failureMsg)); err != nil {
if err := s.transport.writePacket(Marshal(&failureMsg)); err != nil {
return nil, err
}
}
if err = s.transport.writePacket([]byte{msgUserAuthSuccess}); err != nil {
if err := s.transport.writePacket([]byte{msgUserAuthSuccess}); err != nil {
return nil, err
}
return perms, nil

View file

@ -9,6 +9,7 @@ package ssh
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"io"
@ -230,6 +231,26 @@ func (s *Session) RequestSubsystem(subsystem string) error {
return err
}
// RFC 4254 Section 6.7.
type ptyWindowChangeMsg struct {
Columns uint32
Rows uint32
Width uint32
Height uint32
}
// WindowChange informs the remote host about a terminal window dimension change to h rows and w columns.
func (s *Session) WindowChange(h, w int) error {
req := ptyWindowChangeMsg{
Columns: uint32(w),
Rows: uint32(h),
Width: uint32(w * 8),
Height: uint32(h * 8),
}
_, err := s.ch.SendRequest("window-change", false, Marshal(&req))
return err
}
// RFC 4254 Section 6.9.
type signalMsg struct {
Signal string
@ -281,9 +302,10 @@ func (s *Session) Start(cmd string) error {
// copying stdin, stdout, and stderr, and exits with a zero exit
// status.
//
// If the command fails to run or doesn't complete successfully, the
// error is of type *ExitError. Other error types may be
// returned for I/O problems.
// If the remote server does not send an exit status, an error of type
// *ExitMissingError is returned. If the command completes
// unsuccessfully or is interrupted by a signal, the error is of type
// *ExitError. Other error types may be returned for I/O problems.
func (s *Session) Run(cmd string) error {
err := s.Start(cmd)
if err != nil {
@ -339,7 +361,7 @@ func (s *Session) Shell() error {
ok, err := s.ch.SendRequest("shell", true, nil)
if err == nil && !ok {
return fmt.Errorf("ssh: cound not start shell")
return errors.New("ssh: could not start shell")
}
if err != nil {
return err
@ -370,9 +392,10 @@ func (s *Session) start() error {
// copying stdin, stdout, and stderr, and exits with a zero exit
// status.
//
// If the command fails to run or doesn't complete successfully, the
// error is of type *ExitError. Other error types may be
// returned for I/O problems.
// If the remote server does not send an exit status, an error of type
// *ExitMissingError is returned. If the command completes
// unsuccessfully or is interrupted by a signal, the error is of type
// *ExitError. Other error types may be returned for I/O problems.
func (s *Session) Wait() error {
if !s.started {
return errors.New("ssh: session not started")
@ -383,7 +406,7 @@ func (s *Session) Wait() error {
s.stdinPipeWriter.Close()
}
var copyError error
for _ = range s.copyFuncs {
for range s.copyFuncs {
if err := <-s.errors; err != nil && copyError == nil {
copyError = err
}
@ -400,8 +423,7 @@ func (s *Session) wait(reqs <-chan *Request) error {
for msg := range reqs {
switch msg.Type {
case "exit-status":
d := msg.Payload
wm.status = int(d[0])<<24 | int(d[1])<<16 | int(d[2])<<8 | int(d[3])
wm.status = int(binary.BigEndian.Uint32(msg.Payload))
case "exit-signal":
var sigval struct {
Signal string
@ -431,16 +453,29 @@ func (s *Session) wait(reqs <-chan *Request) error {
if wm.status == -1 {
// exit-status was never sent from server
if wm.signal == "" {
return errors.New("wait: remote command exited without exit status or exit signal")
// signal was not sent either. RFC 4254
// section 6.10 recommends against this
// behavior, but it is allowed, so we let
// clients handle it.
return &ExitMissingError{}
}
wm.status = 128
if _, ok := signals[Signal(wm.signal)]; ok {
wm.status += signals[Signal(wm.signal)]
}
}
return &ExitError{wm}
}
// ExitMissingError is returned if a session is torn down cleanly, but
// the server sends no confirmation of the exit status.
type ExitMissingError struct{}
func (e *ExitMissingError) Error() string {
return "wait: remote command exited without exit status or exit signal"
}
func (s *Session) stdin() {
if s.stdinpipe {
return
@ -601,5 +636,12 @@ func (w Waitmsg) Lang() string {
}
func (w Waitmsg) String() string {
return fmt.Sprintf("Process exited with: %v. Reason was: %v (%v)", w.status, w.msg, w.signal)
str := fmt.Sprintf("Process exited with status %v", w.status)
if w.signal != "" {
str += fmt.Sprintf(" from signal %v", w.signal)
}
if w.msg != "" {
str += fmt.Sprintf(". Reason was: %v", w.msg)
}
return str
}

View file

@ -9,9 +9,11 @@ package ssh
import (
"bytes"
crypto_rand "crypto/rand"
"errors"
"io"
"io/ioutil"
"math/rand"
"net"
"testing"
"golang.org/x/crypto/ssh/terminal"
@ -57,7 +59,8 @@ func dial(handler serverType, t *testing.T) *Client {
}()
config := &ClientConfig{
User: "testuser",
User: "testuser",
HostKeyCallback: InsecureIgnoreHostKey(),
}
conn, chans, reqs, err := NewClientConn(c2, "", config)
@ -295,7 +298,6 @@ func TestUnknownExitSignal(t *testing.T) {
}
}
// Test WaitMsg is not returned if the channel closes abruptly.
func TestExitWithoutStatusOrSignal(t *testing.T) {
conn := dial(exitWithoutSignalOrStatus, t)
defer conn.Close()
@ -311,11 +313,8 @@ func TestExitWithoutStatusOrSignal(t *testing.T) {
if err == nil {
t.Fatalf("expected command to fail but it didn't")
}
_, ok := err.(*ExitError)
if ok {
// you can't actually test for errors.errorString
// because it's not exported.
t.Fatalf("expected *errorString but got %T", err)
if _, ok := err.(*ExitMissingError); !ok {
t.Fatalf("got %T want *ExitMissingError", err)
}
}
@ -643,7 +642,8 @@ func TestSessionID(t *testing.T) {
}
serverConf.AddHostKey(testSigners["ecdsa"])
clientConf := &ClientConfig{
User: "user",
HostKeyCallback: InsecureIgnoreHostKey(),
User: "user",
}
go func() {
@ -678,3 +678,97 @@ func TestSessionID(t *testing.T) {
t.Errorf("client and server SessionID were empty.")
}
}
type noReadConn struct {
readSeen bool
net.Conn
}
func (c *noReadConn) Close() error {
return nil
}
func (c *noReadConn) Read(b []byte) (int, error) {
c.readSeen = true
return 0, errors.New("noReadConn error")
}
func TestInvalidServerConfiguration(t *testing.T) {
c1, c2, err := netPipe()
if err != nil {
t.Fatalf("netPipe: %v", err)
}
defer c1.Close()
defer c2.Close()
serveConn := noReadConn{Conn: c1}
serverConf := &ServerConfig{}
NewServerConn(&serveConn, serverConf)
if serveConn.readSeen {
t.Fatalf("NewServerConn attempted to Read() from Conn while configuration is missing host key")
}
serverConf.AddHostKey(testSigners["ecdsa"])
NewServerConn(&serveConn, serverConf)
if serveConn.readSeen {
t.Fatalf("NewServerConn attempted to Read() from Conn while configuration is missing authentication method")
}
}
func TestHostKeyAlgorithms(t *testing.T) {
serverConf := &ServerConfig{
NoClientAuth: true,
}
serverConf.AddHostKey(testSigners["rsa"])
serverConf.AddHostKey(testSigners["ecdsa"])
connect := func(clientConf *ClientConfig, want string) {
var alg string
clientConf.HostKeyCallback = func(h string, a net.Addr, key PublicKey) error {
alg = key.Type()
return nil
}
c1, c2, err := netPipe()
if err != nil {
t.Fatalf("netPipe: %v", err)
}
defer c1.Close()
defer c2.Close()
go NewServerConn(c1, serverConf)
_, _, _, err = NewClientConn(c2, "", clientConf)
if err != nil {
t.Fatalf("NewClientConn: %v", err)
}
if alg != want {
t.Errorf("selected key algorithm %s, want %s", alg, want)
}
}
// By default, we get the preferred algorithm, which is ECDSA 256.
clientConf := &ClientConfig{
HostKeyCallback: InsecureIgnoreHostKey(),
}
connect(clientConf, KeyAlgoECDSA256)
// Client asks for RSA explicitly.
clientConf.HostKeyAlgorithms = []string{KeyAlgoRSA}
connect(clientConf, KeyAlgoRSA)
c1, c2, err := netPipe()
if err != nil {
t.Fatalf("netPipe: %v", err)
}
defer c1.Close()
defer c2.Close()
go NewServerConn(c1, serverConf)
clientConf.HostKeyAlgorithms = []string{"nonexistent-hostkey-algo"}
_, _, _, err = NewClientConn(c2, "", clientConf)
if err == nil {
t.Fatal("succeeded connecting with unknown hostkey algorithm")
}
}

116
vendor/golang.org/x/crypto/ssh/streamlocal.go generated vendored Normal file
View file

@ -0,0 +1,116 @@
package ssh
import (
"errors"
"io"
"net"
)
// streamLocalChannelOpenDirectMsg is a struct used for SSH_MSG_CHANNEL_OPEN message
// with "direct-streamlocal@openssh.com" string.
//
// See openssh-portable/PROTOCOL, section 2.4. connection: Unix domain socket forwarding
// https://github.com/openssh/openssh-portable/blob/master/PROTOCOL#L235
type streamLocalChannelOpenDirectMsg struct {
socketPath string
reserved0 string
reserved1 uint32
}
// forwardedStreamLocalPayload is a struct used for SSH_MSG_CHANNEL_OPEN message
// with "forwarded-streamlocal@openssh.com" string.
type forwardedStreamLocalPayload struct {
SocketPath string
Reserved0 string
}
// streamLocalChannelForwardMsg is a struct used for SSH2_MSG_GLOBAL_REQUEST message
// with "streamlocal-forward@openssh.com"/"cancel-streamlocal-forward@openssh.com" string.
type streamLocalChannelForwardMsg struct {
socketPath string
}
// ListenUnix is similar to ListenTCP but uses a Unix domain socket.
func (c *Client) ListenUnix(socketPath string) (net.Listener, error) {
c.handleForwardsOnce.Do(c.handleForwards)
m := streamLocalChannelForwardMsg{
socketPath,
}
// send message
ok, _, err := c.SendRequest("streamlocal-forward@openssh.com", true, Marshal(&m))
if err != nil {
return nil, err
}
if !ok {
return nil, errors.New("ssh: streamlocal-forward@openssh.com request denied by peer")
}
ch := c.forwards.add(&net.UnixAddr{Name: socketPath, Net: "unix"})
return &unixListener{socketPath, c, ch}, nil
}
func (c *Client) dialStreamLocal(socketPath string) (Channel, error) {
msg := streamLocalChannelOpenDirectMsg{
socketPath: socketPath,
}
ch, in, err := c.OpenChannel("direct-streamlocal@openssh.com", Marshal(&msg))
if err != nil {
return nil, err
}
go DiscardRequests(in)
return ch, err
}
type unixListener struct {
socketPath string
conn *Client
in <-chan forward
}
// Accept waits for and returns the next connection to the listener.
func (l *unixListener) Accept() (net.Conn, error) {
s, ok := <-l.in
if !ok {
return nil, io.EOF
}
ch, incoming, err := s.newCh.Accept()
if err != nil {
return nil, err
}
go DiscardRequests(incoming)
return &chanConn{
Channel: ch,
laddr: &net.UnixAddr{
Name: l.socketPath,
Net: "unix",
},
raddr: &net.UnixAddr{
Name: "@",
Net: "unix",
},
}, nil
}
// Close closes the listener.
func (l *unixListener) Close() error {
// this also closes the listener.
l.conn.forwards.remove(&net.UnixAddr{Name: l.socketPath, Net: "unix"})
m := streamLocalChannelForwardMsg{
l.socketPath,
}
ok, _, err := l.conn.SendRequest("cancel-streamlocal-forward@openssh.com", true, Marshal(&m))
if err == nil && !ok {
err = errors.New("ssh: cancel-streamlocal-forward@openssh.com failed")
}
return err
}
// Addr returns the listener's network address.
func (l *unixListener) Addr() net.Addr {
return &net.UnixAddr{
Name: l.socketPath,
Net: "unix",
}
}

View file

@ -20,12 +20,20 @@ import (
// addr. Incoming connections will be available by calling Accept on
// the returned net.Listener. The listener must be serviced, or the
// SSH connection may hang.
// N must be "tcp", "tcp4", "tcp6", or "unix".
func (c *Client) Listen(n, addr string) (net.Listener, error) {
laddr, err := net.ResolveTCPAddr(n, addr)
if err != nil {
return nil, err
switch n {
case "tcp", "tcp4", "tcp6":
laddr, err := net.ResolveTCPAddr(n, addr)
if err != nil {
return nil, err
}
return c.ListenTCP(laddr)
case "unix":
return c.ListenUnix(addr)
default:
return nil, fmt.Errorf("ssh: unsupported protocol: %s", n)
}
return c.ListenTCP(laddr)
}
// Automatic port allocation is broken with OpenSSH before 6.0. See
@ -82,10 +90,19 @@ type channelForwardMsg struct {
rport uint32
}
// handleForwards starts goroutines handling forwarded connections.
// It's called on first use by (*Client).ListenTCP to not launch
// goroutines until needed.
func (c *Client) handleForwards() {
go c.forwards.handleChannels(c.HandleChannelOpen("forwarded-tcpip"))
go c.forwards.handleChannels(c.HandleChannelOpen("forwarded-streamlocal@openssh.com"))
}
// ListenTCP requests the remote peer open a listening socket
// on laddr. Incoming connections will be available by calling
// Accept on the returned net.Listener.
func (c *Client) ListenTCP(laddr *net.TCPAddr) (net.Listener, error) {
c.handleForwardsOnce.Do(c.handleForwards)
if laddr.Port == 0 && isBrokenOpenSSHVersion(string(c.ServerVersion())) {
return c.autoPortListenWorkaround(laddr)
}
@ -116,7 +133,7 @@ func (c *Client) ListenTCP(laddr *net.TCPAddr) (net.Listener, error) {
}
// Register this forward, using the port number we obtained.
ch := c.forwards.add(*laddr)
ch := c.forwards.add(laddr)
return &tcpListener{laddr, c, ch}, nil
}
@ -131,7 +148,7 @@ type forwardList struct {
// forwardEntry represents an established mapping of a laddr on a
// remote ssh server to a channel connected to a tcpListener.
type forwardEntry struct {
laddr net.TCPAddr
laddr net.Addr
c chan forward
}
@ -139,16 +156,16 @@ type forwardEntry struct {
// arguments to add/remove/lookup should be address as specified in
// the original forward-request.
type forward struct {
newCh NewChannel // the ssh client channel underlying this forward
raddr *net.TCPAddr // the raddr of the incoming connection
newCh NewChannel // the ssh client channel underlying this forward
raddr net.Addr // the raddr of the incoming connection
}
func (l *forwardList) add(addr net.TCPAddr) chan forward {
func (l *forwardList) add(addr net.Addr) chan forward {
l.Lock()
defer l.Unlock()
f := forwardEntry{
addr,
make(chan forward, 1),
laddr: addr,
c: make(chan forward, 1),
}
l.entries = append(l.entries, f)
return f.c
@ -176,44 +193,69 @@ func parseTCPAddr(addr string, port uint32) (*net.TCPAddr, error) {
func (l *forwardList) handleChannels(in <-chan NewChannel) {
for ch := range in {
var payload forwardedTCPPayload
if err := Unmarshal(ch.ExtraData(), &payload); err != nil {
ch.Reject(ConnectionFailed, "could not parse forwarded-tcpip payload: "+err.Error())
continue
}
var (
laddr net.Addr
raddr net.Addr
err error
)
switch channelType := ch.ChannelType(); channelType {
case "forwarded-tcpip":
var payload forwardedTCPPayload
if err = Unmarshal(ch.ExtraData(), &payload); err != nil {
ch.Reject(ConnectionFailed, "could not parse forwarded-tcpip payload: "+err.Error())
continue
}
// RFC 4254 section 7.2 specifies that incoming
// addresses should list the address, in string
// format. It is implied that this should be an IP
// address, as it would be impossible to connect to it
// otherwise.
laddr, err := parseTCPAddr(payload.Addr, payload.Port)
if err != nil {
ch.Reject(ConnectionFailed, err.Error())
continue
}
raddr, err := parseTCPAddr(payload.OriginAddr, payload.OriginPort)
if err != nil {
ch.Reject(ConnectionFailed, err.Error())
continue
}
// RFC 4254 section 7.2 specifies that incoming
// addresses should list the address, in string
// format. It is implied that this should be an IP
// address, as it would be impossible to connect to it
// otherwise.
laddr, err = parseTCPAddr(payload.Addr, payload.Port)
if err != nil {
ch.Reject(ConnectionFailed, err.Error())
continue
}
raddr, err = parseTCPAddr(payload.OriginAddr, payload.OriginPort)
if err != nil {
ch.Reject(ConnectionFailed, err.Error())
continue
}
if ok := l.forward(*laddr, *raddr, ch); !ok {
case "forwarded-streamlocal@openssh.com":
var payload forwardedStreamLocalPayload
if err = Unmarshal(ch.ExtraData(), &payload); err != nil {
ch.Reject(ConnectionFailed, "could not parse forwarded-streamlocal@openssh.com payload: "+err.Error())
continue
}
laddr = &net.UnixAddr{
Name: payload.SocketPath,
Net: "unix",
}
raddr = &net.UnixAddr{
Name: "@",
Net: "unix",
}
default:
panic(fmt.Errorf("ssh: unknown channel type %s", channelType))
}
if ok := l.forward(laddr, raddr, ch); !ok {
// Section 7.2, implementations MUST reject spurious incoming
// connections.
ch.Reject(Prohibited, "no forward for address")
continue
}
}
}
// remove removes the forward entry, and the channel feeding its
// listener.
func (l *forwardList) remove(addr net.TCPAddr) {
func (l *forwardList) remove(addr net.Addr) {
l.Lock()
defer l.Unlock()
for i, f := range l.entries {
if addr.IP.Equal(f.laddr.IP) && addr.Port == f.laddr.Port {
if addr.Network() == f.laddr.Network() && addr.String() == f.laddr.String() {
l.entries = append(l.entries[:i], l.entries[i+1:]...)
close(f.c)
return
@ -231,12 +273,12 @@ func (l *forwardList) closeAll() {
l.entries = nil
}
func (l *forwardList) forward(laddr, raddr net.TCPAddr, ch NewChannel) bool {
func (l *forwardList) forward(laddr, raddr net.Addr, ch NewChannel) bool {
l.Lock()
defer l.Unlock()
for _, f := range l.entries {
if laddr.IP.Equal(f.laddr.IP) && laddr.Port == f.laddr.Port {
f.c <- forward{ch, &raddr}
if laddr.Network() == f.laddr.Network() && laddr.String() == f.laddr.String() {
f.c <- forward{newCh: ch, raddr: raddr}
return true
}
}
@ -262,7 +304,7 @@ func (l *tcpListener) Accept() (net.Conn, error) {
}
go DiscardRequests(incoming)
return &tcpChanConn{
return &chanConn{
Channel: ch,
laddr: l.laddr,
raddr: s.raddr,
@ -277,7 +319,7 @@ func (l *tcpListener) Close() error {
}
// this also closes the listener.
l.conn.forwards.remove(*l.laddr)
l.conn.forwards.remove(l.laddr)
ok, _, err := l.conn.SendRequest("cancel-tcpip-forward", true, Marshal(&m))
if err == nil && !ok {
err = errors.New("ssh: cancel-tcpip-forward failed")
@ -293,29 +335,52 @@ func (l *tcpListener) Addr() net.Addr {
// Dial initiates a connection to the addr from the remote host.
// The resulting connection has a zero LocalAddr() and RemoteAddr().
func (c *Client) Dial(n, addr string) (net.Conn, error) {
// Parse the address into host and numeric port.
host, portString, err := net.SplitHostPort(addr)
if err != nil {
return nil, err
var ch Channel
switch n {
case "tcp", "tcp4", "tcp6":
// Parse the address into host and numeric port.
host, portString, err := net.SplitHostPort(addr)
if err != nil {
return nil, err
}
port, err := strconv.ParseUint(portString, 10, 16)
if err != nil {
return nil, err
}
ch, err = c.dial(net.IPv4zero.String(), 0, host, int(port))
if err != nil {
return nil, err
}
// Use a zero address for local and remote address.
zeroAddr := &net.TCPAddr{
IP: net.IPv4zero,
Port: 0,
}
return &chanConn{
Channel: ch,
laddr: zeroAddr,
raddr: zeroAddr,
}, nil
case "unix":
var err error
ch, err = c.dialStreamLocal(addr)
if err != nil {
return nil, err
}
return &chanConn{
Channel: ch,
laddr: &net.UnixAddr{
Name: "@",
Net: "unix",
},
raddr: &net.UnixAddr{
Name: addr,
Net: "unix",
},
}, nil
default:
return nil, fmt.Errorf("ssh: unsupported protocol: %s", n)
}
port, err := strconv.ParseUint(portString, 10, 16)
if err != nil {
return nil, err
}
// Use a zero address for local and remote address.
zeroAddr := &net.TCPAddr{
IP: net.IPv4zero,
Port: 0,
}
ch, err := c.dial(net.IPv4zero.String(), 0, host, int(port))
if err != nil {
return nil, err
}
return &tcpChanConn{
Channel: ch,
laddr: zeroAddr,
raddr: zeroAddr,
}, nil
}
// DialTCP connects to the remote address raddr on the network net,
@ -332,7 +397,7 @@ func (c *Client) DialTCP(n string, laddr, raddr *net.TCPAddr) (net.Conn, error)
if err != nil {
return nil, err
}
return &tcpChanConn{
return &chanConn{
Channel: ch,
laddr: laddr,
raddr: raddr,
@ -355,6 +420,9 @@ func (c *Client) dial(laddr string, lport int, raddr string, rport int) (Channel
lport: uint32(lport),
}
ch, in, err := c.OpenChannel("direct-tcpip", Marshal(&msg))
if err != nil {
return nil, err
}
go DiscardRequests(in)
return ch, err
}
@ -363,26 +431,26 @@ type tcpChan struct {
Channel // the backing channel
}
// tcpChanConn fulfills the net.Conn interface without
// chanConn fulfills the net.Conn interface without
// the tcpChan having to hold laddr or raddr directly.
type tcpChanConn struct {
type chanConn struct {
Channel
laddr, raddr net.Addr
}
// LocalAddr returns the local network address.
func (t *tcpChanConn) LocalAddr() net.Addr {
func (t *chanConn) LocalAddr() net.Addr {
return t.laddr
}
// RemoteAddr returns the remote network address.
func (t *tcpChanConn) RemoteAddr() net.Addr {
func (t *chanConn) RemoteAddr() net.Addr {
return t.raddr
}
// SetDeadline sets the read and write deadlines associated
// with the connection.
func (t *tcpChanConn) SetDeadline(deadline time.Time) error {
func (t *chanConn) SetDeadline(deadline time.Time) error {
if err := t.SetReadDeadline(deadline); err != nil {
return err
}
@ -393,12 +461,14 @@ func (t *tcpChanConn) SetDeadline(deadline time.Time) error {
// A zero value for t means Read will not time out.
// After the deadline, the error from Read will implement net.Error
// with Timeout() == true.
func (t *tcpChanConn) SetReadDeadline(deadline time.Time) error {
func (t *chanConn) SetReadDeadline(deadline time.Time) error {
// for compatibility with previous version,
// the error message contains "tcpChan"
return errors.New("ssh: tcpChan: deadline not supported")
}
// SetWriteDeadline exists to satisfy the net.Conn interface
// but is not implemented by this type. It always returns an error.
func (t *tcpChanConn) SetWriteDeadline(deadline time.Time) error {
func (t *chanConn) SetWriteDeadline(deadline time.Time) error {
return errors.New("ssh: tcpChan: deadline not supported")
}

View file

@ -132,8 +132,11 @@ const (
keyPasteEnd
)
var pasteStart = []byte{keyEscape, '[', '2', '0', '0', '~'}
var pasteEnd = []byte{keyEscape, '[', '2', '0', '1', '~'}
var (
crlf = []byte{'\r', '\n'}
pasteStart = []byte{keyEscape, '[', '2', '0', '0', '~'}
pasteEnd = []byte{keyEscape, '[', '2', '0', '1', '~'}
)
// bytesToKey tries to parse a key sequence from b. If successful, it returns
// the key and the remainder of the input. Otherwise it returns utf8.RuneError.
@ -333,7 +336,7 @@ func (t *Terminal) advanceCursor(places int) {
// So, if we are stopping at the end of a line, we
// need to write a newline so that our cursor can be
// advanced to the next line.
t.outBuf = append(t.outBuf, '\n')
t.outBuf = append(t.outBuf, '\r', '\n')
}
}
@ -593,6 +596,35 @@ func (t *Terminal) writeLine(line []rune) {
}
}
// writeWithCRLF writes buf to w but replaces all occurrences of \n with \r\n.
func writeWithCRLF(w io.Writer, buf []byte) (n int, err error) {
for len(buf) > 0 {
i := bytes.IndexByte(buf, '\n')
todo := len(buf)
if i >= 0 {
todo = i
}
var nn int
nn, err = w.Write(buf[:todo])
n += nn
if err != nil {
return n, err
}
buf = buf[todo:]
if i >= 0 {
if _, err = w.Write(crlf); err != nil {
return n, err
}
n++
buf = buf[1:]
}
}
return n, nil
}
func (t *Terminal) Write(buf []byte) (n int, err error) {
t.lock.Lock()
defer t.lock.Unlock()
@ -600,7 +632,7 @@ func (t *Terminal) Write(buf []byte) (n int, err error) {
if t.cursorX == 0 && t.cursorY == 0 {
// This is the easy case: there's nothing on the screen that we
// have to move out of the way.
return t.c.Write(buf)
return writeWithCRLF(t.c, buf)
}
// We have a prompt and possibly user input on the screen. We
@ -620,7 +652,7 @@ func (t *Terminal) Write(buf []byte) (n int, err error) {
}
t.outBuf = t.outBuf[:0]
if n, err = t.c.Write(buf); err != nil {
if n, err = writeWithCRLF(t.c, buf); err != nil {
return
}
@ -740,8 +772,6 @@ func (t *Terminal) readLine() (line string, err error) {
t.remainder = t.inBuf[:n+len(t.remainder)]
}
panic("unreachable") // for Go 1.0.
}
// SetPrompt sets the prompt to be used when reading subsequent lines.
@ -890,3 +920,32 @@ func (s *stRingBuffer) NthPreviousEntry(n int) (value string, ok bool) {
}
return s.entries[index], true
}
// readPasswordLine reads from reader until it finds \n or io.EOF.
// The slice returned does not include the \n.
// readPasswordLine also ignores any \r it finds.
func readPasswordLine(reader io.Reader) ([]byte, error) {
var buf [1]byte
var ret []byte
for {
n, err := reader.Read(buf[:])
if n > 0 {
switch buf[0] {
case '\n':
return ret, nil
case '\r':
// remove \r from passwords on Windows
default:
ret = append(ret, buf[0])
}
continue
}
if err != nil {
if err == io.EOF && len(ret) > 0 {
return ret, nil
}
return ret, err
}
}
}

View file

@ -2,10 +2,15 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build aix darwin dragonfly freebsd linux,!appengine netbsd openbsd windows plan9 solaris
package terminal
import (
"bytes"
"io"
"os"
"runtime"
"testing"
)
@ -267,3 +272,87 @@ func TestTerminalSetSize(t *testing.T) {
}
}
}
func TestReadPasswordLineEnd(t *testing.T) {
var tests = []struct {
input string
want string
}{
{"\n", ""},
{"\r\n", ""},
{"test\r\n", "test"},
{"testtesttesttes\n", "testtesttesttes"},
{"testtesttesttes\r\n", "testtesttesttes"},
{"testtesttesttesttest\n", "testtesttesttesttest"},
{"testtesttesttesttest\r\n", "testtesttesttesttest"},
}
for _, test := range tests {
buf := new(bytes.Buffer)
if _, err := buf.WriteString(test.input); err != nil {
t.Fatal(err)
}
have, err := readPasswordLine(buf)
if err != nil {
t.Errorf("readPasswordLine(%q) failed: %v", test.input, err)
continue
}
if string(have) != test.want {
t.Errorf("readPasswordLine(%q) returns %q, but %q is expected", test.input, string(have), test.want)
continue
}
if _, err = buf.WriteString(test.input); err != nil {
t.Fatal(err)
}
have, err = readPasswordLine(buf)
if err != nil {
t.Errorf("readPasswordLine(%q) failed: %v", test.input, err)
continue
}
if string(have) != test.want {
t.Errorf("readPasswordLine(%q) returns %q, but %q is expected", test.input, string(have), test.want)
continue
}
}
}
func TestMakeRawState(t *testing.T) {
fd := int(os.Stdout.Fd())
if !IsTerminal(fd) {
t.Skip("stdout is not a terminal; skipping test")
}
st, err := GetState(fd)
if err != nil {
t.Fatalf("failed to get terminal state from GetState: %s", err)
}
if runtime.GOOS == "darwin" && (runtime.GOARCH == "arm" || runtime.GOARCH == "arm64") {
t.Skip("MakeRaw not allowed on iOS; skipping test")
}
defer Restore(fd, st)
raw, err := MakeRaw(fd)
if err != nil {
t.Fatalf("failed to get terminal state from MakeRaw: %s", err)
}
if *st != *raw {
t.Errorf("states do not match; was %v, expected %v", raw, st)
}
}
func TestOutputNewlines(t *testing.T) {
// \n should be changed to \r\n in terminal output.
buf := new(bytes.Buffer)
term := NewTerminal(buf, ">")
term.Write([]byte("1\n2\n"))
output := string(buf.Bytes())
const expected = "1\r\n2\r\n"
if output != expected {
t.Errorf("incorrect output: was %q, expected %q", output, expected)
}
}

View file

@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build darwin dragonfly freebsd linux,!appengine netbsd openbsd
// +build aix darwin dragonfly freebsd linux,!appengine netbsd openbsd
// Package terminal provides support functions for dealing with terminals, as
// commonly found on UNIX systems.
@ -17,36 +17,41 @@
package terminal // import "golang.org/x/crypto/ssh/terminal"
import (
"io"
"syscall"
"unsafe"
"golang.org/x/sys/unix"
)
// State contains the state of a terminal.
type State struct {
termios syscall.Termios
termios unix.Termios
}
// IsTerminal returns true if the given file descriptor is a terminal.
// IsTerminal returns whether the given file descriptor is a terminal.
func IsTerminal(fd int) bool {
var termios syscall.Termios
_, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), ioctlReadTermios, uintptr(unsafe.Pointer(&termios)), 0, 0, 0)
return err == 0
_, err := unix.IoctlGetTermios(fd, ioctlReadTermios)
return err == nil
}
// MakeRaw put the terminal connected to the given file descriptor into raw
// mode and returns the previous state of the terminal so that it can be
// restored.
func MakeRaw(fd int) (*State, error) {
var oldState State
if _, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), ioctlReadTermios, uintptr(unsafe.Pointer(&oldState.termios)), 0, 0, 0); err != 0 {
termios, err := unix.IoctlGetTermios(fd, ioctlReadTermios)
if err != nil {
return nil, err
}
newState := oldState.termios
newState.Iflag &^= syscall.ISTRIP | syscall.INLCR | syscall.ICRNL | syscall.IGNCR | syscall.IXON | syscall.IXOFF
newState.Lflag &^= syscall.ECHO | syscall.ICANON | syscall.ISIG
if _, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), ioctlWriteTermios, uintptr(unsafe.Pointer(&newState)), 0, 0, 0); err != 0 {
oldState := State{termios: *termios}
// This attempts to replicate the behaviour documented for cfmakeraw in
// the termios(3) manpage.
termios.Iflag &^= unix.IGNBRK | unix.BRKINT | unix.PARMRK | unix.ISTRIP | unix.INLCR | unix.IGNCR | unix.ICRNL | unix.IXON
termios.Oflag &^= unix.OPOST
termios.Lflag &^= unix.ECHO | unix.ECHONL | unix.ICANON | unix.ISIG | unix.IEXTEN
termios.Cflag &^= unix.CSIZE | unix.PARENB
termios.Cflag |= unix.CS8
termios.Cc[unix.VMIN] = 1
termios.Cc[unix.VTIME] = 0
if err := unix.IoctlSetTermios(fd, ioctlWriteTermios, termios); err != nil {
return nil, err
}
@ -56,73 +61,54 @@ func MakeRaw(fd int) (*State, error) {
// GetState returns the current state of a terminal which may be useful to
// restore the terminal after a signal.
func GetState(fd int) (*State, error) {
var oldState State
if _, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), ioctlReadTermios, uintptr(unsafe.Pointer(&oldState.termios)), 0, 0, 0); err != 0 {
termios, err := unix.IoctlGetTermios(fd, ioctlReadTermios)
if err != nil {
return nil, err
}
return &oldState, nil
return &State{termios: *termios}, nil
}
// Restore restores the terminal connected to the given file descriptor to a
// previous state.
func Restore(fd int, state *State) error {
_, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), ioctlWriteTermios, uintptr(unsafe.Pointer(&state.termios)), 0, 0, 0)
return err
return unix.IoctlSetTermios(fd, ioctlWriteTermios, &state.termios)
}
// GetSize returns the dimensions of the given terminal.
func GetSize(fd int) (width, height int, err error) {
var dimensions [4]uint16
if _, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), uintptr(syscall.TIOCGWINSZ), uintptr(unsafe.Pointer(&dimensions)), 0, 0, 0); err != 0 {
ws, err := unix.IoctlGetWinsize(fd, unix.TIOCGWINSZ)
if err != nil {
return -1, -1, err
}
return int(dimensions[1]), int(dimensions[0]), nil
return int(ws.Col), int(ws.Row), nil
}
// passwordReader is an io.Reader that reads from a specific file descriptor.
type passwordReader int
func (r passwordReader) Read(buf []byte) (int, error) {
return unix.Read(int(r), buf)
}
// ReadPassword reads a line of input from a terminal without local echo. This
// is commonly used for inputting passwords and other sensitive data. The slice
// returned does not include the \n.
func ReadPassword(fd int) ([]byte, error) {
var oldState syscall.Termios
if _, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), ioctlReadTermios, uintptr(unsafe.Pointer(&oldState)), 0, 0, 0); err != 0 {
termios, err := unix.IoctlGetTermios(fd, ioctlReadTermios)
if err != nil {
return nil, err
}
newState := oldState
newState.Lflag &^= syscall.ECHO
newState.Lflag |= syscall.ICANON | syscall.ISIG
newState.Iflag |= syscall.ICRNL
if _, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), ioctlWriteTermios, uintptr(unsafe.Pointer(&newState)), 0, 0, 0); err != 0 {
newState := *termios
newState.Lflag &^= unix.ECHO
newState.Lflag |= unix.ICANON | unix.ISIG
newState.Iflag |= unix.ICRNL
if err := unix.IoctlSetTermios(fd, ioctlWriteTermios, &newState); err != nil {
return nil, err
}
defer func() {
syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), ioctlWriteTermios, uintptr(unsafe.Pointer(&oldState)), 0, 0, 0)
}()
defer unix.IoctlSetTermios(fd, ioctlWriteTermios, termios)
var buf [16]byte
var ret []byte
for {
n, err := syscall.Read(fd, buf[:])
if err != nil {
return nil, err
}
if n == 0 {
if len(ret) == 0 {
return nil, io.EOF
}
break
}
if buf[n-1] == '\n' {
n--
}
ret = append(ret, buf[:n]...)
if n < len(buf) {
break
}
}
return ret, nil
return readPasswordLine(passwordReader(fd))
}

12
vendor/golang.org/x/crypto/ssh/terminal/util_aix.go generated vendored Normal file
View file

@ -0,0 +1,12 @@
// Copyright 2018 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build aix
package terminal
import "golang.org/x/sys/unix"
const ioctlReadTermios = unix.TCGETS
const ioctlWriteTermios = unix.TCSETS

View file

@ -6,7 +6,7 @@
package terminal
import "syscall"
import "golang.org/x/sys/unix"
const ioctlReadTermios = syscall.TIOCGETA
const ioctlWriteTermios = syscall.TIOCSETA
const ioctlReadTermios = unix.TIOCGETA
const ioctlWriteTermios = unix.TIOCSETA

View file

@ -4,8 +4,7 @@
package terminal
// These constants are declared here, rather than importing
// them from the syscall package as some syscall packages, even
// on linux, for example gccgo, do not declare them.
const ioctlReadTermios = 0x5401 // syscall.TCGETS
const ioctlWriteTermios = 0x5402 // syscall.TCSETS
import "golang.org/x/sys/unix"
const ioctlReadTermios = unix.TCGETS
const ioctlWriteTermios = unix.TCSETS

58
vendor/golang.org/x/crypto/ssh/terminal/util_plan9.go generated vendored Normal file
View file

@ -0,0 +1,58 @@
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package terminal provides support functions for dealing with terminals, as
// commonly found on UNIX systems.
//
// Putting a terminal into raw mode is the most common requirement:
//
// oldState, err := terminal.MakeRaw(0)
// if err != nil {
// panic(err)
// }
// defer terminal.Restore(0, oldState)
package terminal
import (
"fmt"
"runtime"
)
type State struct{}
// IsTerminal returns whether the given file descriptor is a terminal.
func IsTerminal(fd int) bool {
return false
}
// MakeRaw put the terminal connected to the given file descriptor into raw
// mode and returns the previous state of the terminal so that it can be
// restored.
func MakeRaw(fd int) (*State, error) {
return nil, fmt.Errorf("terminal: MakeRaw not implemented on %s/%s", runtime.GOOS, runtime.GOARCH)
}
// GetState returns the current state of a terminal which may be useful to
// restore the terminal after a signal.
func GetState(fd int) (*State, error) {
return nil, fmt.Errorf("terminal: GetState not implemented on %s/%s", runtime.GOOS, runtime.GOARCH)
}
// Restore restores the terminal connected to the given file descriptor to a
// previous state.
func Restore(fd int, state *State) error {
return fmt.Errorf("terminal: Restore not implemented on %s/%s", runtime.GOOS, runtime.GOARCH)
}
// GetSize returns the dimensions of the given terminal.
func GetSize(fd int) (width, height int, err error) {
return 0, 0, fmt.Errorf("terminal: GetSize not implemented on %s/%s", runtime.GOOS, runtime.GOARCH)
}
// ReadPassword reads a line of input from a terminal without local echo. This
// is commonly used for inputting passwords and other sensitive data. The slice
// returned does not include the \n.
func ReadPassword(fd int) ([]byte, error) {
return nil, fmt.Errorf("terminal: ReadPassword not implemented on %s/%s", runtime.GOOS, runtime.GOARCH)
}

124
vendor/golang.org/x/crypto/ssh/terminal/util_solaris.go generated vendored Normal file
View file

@ -0,0 +1,124 @@
// Copyright 2015 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build solaris
package terminal // import "golang.org/x/crypto/ssh/terminal"
import (
"golang.org/x/sys/unix"
"io"
"syscall"
)
// State contains the state of a terminal.
type State struct {
termios unix.Termios
}
// IsTerminal returns whether the given file descriptor is a terminal.
func IsTerminal(fd int) bool {
_, err := unix.IoctlGetTermio(fd, unix.TCGETA)
return err == nil
}
// ReadPassword reads a line of input from a terminal without local echo. This
// is commonly used for inputting passwords and other sensitive data. The slice
// returned does not include the \n.
func ReadPassword(fd int) ([]byte, error) {
// see also: http://src.illumos.org/source/xref/illumos-gate/usr/src/lib/libast/common/uwin/getpass.c
val, err := unix.IoctlGetTermios(fd, unix.TCGETS)
if err != nil {
return nil, err
}
oldState := *val
newState := oldState
newState.Lflag &^= syscall.ECHO
newState.Lflag |= syscall.ICANON | syscall.ISIG
newState.Iflag |= syscall.ICRNL
err = unix.IoctlSetTermios(fd, unix.TCSETS, &newState)
if err != nil {
return nil, err
}
defer unix.IoctlSetTermios(fd, unix.TCSETS, &oldState)
var buf [16]byte
var ret []byte
for {
n, err := syscall.Read(fd, buf[:])
if err != nil {
return nil, err
}
if n == 0 {
if len(ret) == 0 {
return nil, io.EOF
}
break
}
if buf[n-1] == '\n' {
n--
}
ret = append(ret, buf[:n]...)
if n < len(buf) {
break
}
}
return ret, nil
}
// MakeRaw puts the terminal connected to the given file descriptor into raw
// mode and returns the previous state of the terminal so that it can be
// restored.
// see http://cr.illumos.org/~webrev/andy_js/1060/
func MakeRaw(fd int) (*State, error) {
termios, err := unix.IoctlGetTermios(fd, unix.TCGETS)
if err != nil {
return nil, err
}
oldState := State{termios: *termios}
termios.Iflag &^= unix.IGNBRK | unix.BRKINT | unix.PARMRK | unix.ISTRIP | unix.INLCR | unix.IGNCR | unix.ICRNL | unix.IXON
termios.Oflag &^= unix.OPOST
termios.Lflag &^= unix.ECHO | unix.ECHONL | unix.ICANON | unix.ISIG | unix.IEXTEN
termios.Cflag &^= unix.CSIZE | unix.PARENB
termios.Cflag |= unix.CS8
termios.Cc[unix.VMIN] = 1
termios.Cc[unix.VTIME] = 0
if err := unix.IoctlSetTermios(fd, unix.TCSETS, termios); err != nil {
return nil, err
}
return &oldState, nil
}
// Restore restores the terminal connected to the given file descriptor to a
// previous state.
func Restore(fd int, oldState *State) error {
return unix.IoctlSetTermios(fd, unix.TCSETS, &oldState.termios)
}
// GetState returns the current state of a terminal which may be useful to
// restore the terminal after a signal.
func GetState(fd int) (*State, error) {
termios, err := unix.IoctlGetTermios(fd, unix.TCGETS)
if err != nil {
return nil, err
}
return &State{termios: *termios}, nil
}
// GetSize returns the dimensions of the given terminal.
func GetSize(fd int) (width, height int, err error) {
ws, err := unix.IoctlGetWinsize(fd, unix.TIOCGWINSZ)
if err != nil {
return 0, 0, err
}
return int(ws.Col), int(ws.Row), nil
}

View file

@ -17,65 +17,20 @@
package terminal
import (
"io"
"syscall"
"unsafe"
)
"os"
const (
enableLineInput = 2
enableEchoInput = 4
enableProcessedInput = 1
enableWindowInput = 8
enableMouseInput = 16
enableInsertMode = 32
enableQuickEditMode = 64
enableExtendedFlags = 128
enableAutoPosition = 256
enableProcessedOutput = 1
enableWrapAtEolOutput = 2
)
var kernel32 = syscall.NewLazyDLL("kernel32.dll")
var (
procGetConsoleMode = kernel32.NewProc("GetConsoleMode")
procSetConsoleMode = kernel32.NewProc("SetConsoleMode")
procGetConsoleScreenBufferInfo = kernel32.NewProc("GetConsoleScreenBufferInfo")
)
type (
short int16
word uint16
coord struct {
x short
y short
}
smallRect struct {
left short
top short
right short
bottom short
}
consoleScreenBufferInfo struct {
size coord
cursorPosition coord
attributes word
window smallRect
maximumWindowSize coord
}
"golang.org/x/sys/windows"
)
type State struct {
mode uint32
}
// IsTerminal returns true if the given file descriptor is a terminal.
// IsTerminal returns whether the given file descriptor is a terminal.
func IsTerminal(fd int) bool {
var st uint32
r, _, e := syscall.Syscall(procGetConsoleMode.Addr(), 2, uintptr(fd), uintptr(unsafe.Pointer(&st)), 0)
return r != 0 && e == 0
err := windows.GetConsoleMode(windows.Handle(fd), &st)
return err == nil
}
// MakeRaw put the terminal connected to the given file descriptor into raw
@ -83,14 +38,12 @@ func IsTerminal(fd int) bool {
// restored.
func MakeRaw(fd int) (*State, error) {
var st uint32
_, _, e := syscall.Syscall(procGetConsoleMode.Addr(), 2, uintptr(fd), uintptr(unsafe.Pointer(&st)), 0)
if e != 0 {
return nil, error(e)
if err := windows.GetConsoleMode(windows.Handle(fd), &st); err != nil {
return nil, err
}
st &^= (enableEchoInput | enableProcessedInput | enableLineInput | enableProcessedOutput)
_, _, e = syscall.Syscall(procSetConsoleMode.Addr(), 2, uintptr(fd), uintptr(st), 0)
if e != 0 {
return nil, error(e)
raw := st &^ (windows.ENABLE_ECHO_INPUT | windows.ENABLE_PROCESSED_INPUT | windows.ENABLE_LINE_INPUT | windows.ENABLE_PROCESSED_OUTPUT)
if err := windows.SetConsoleMode(windows.Handle(fd), raw); err != nil {
return nil, err
}
return &State{st}, nil
}
@ -99,9 +52,8 @@ func MakeRaw(fd int) (*State, error) {
// restore the terminal after a signal.
func GetState(fd int) (*State, error) {
var st uint32
_, _, e := syscall.Syscall(procGetConsoleMode.Addr(), 2, uintptr(fd), uintptr(unsafe.Pointer(&st)), 0)
if e != 0 {
return nil, error(e)
if err := windows.GetConsoleMode(windows.Handle(fd), &st); err != nil {
return nil, err
}
return &State{st}, nil
}
@ -109,18 +61,16 @@ func GetState(fd int) (*State, error) {
// Restore restores the terminal connected to the given file descriptor to a
// previous state.
func Restore(fd int, state *State) error {
_, _, err := syscall.Syscall(procSetConsoleMode.Addr(), 2, uintptr(fd), uintptr(state.mode), 0)
return err
return windows.SetConsoleMode(windows.Handle(fd), state.mode)
}
// GetSize returns the dimensions of the given terminal.
func GetSize(fd int) (width, height int, err error) {
var info consoleScreenBufferInfo
_, _, e := syscall.Syscall(procGetConsoleScreenBufferInfo.Addr(), 2, uintptr(fd), uintptr(unsafe.Pointer(&info)), 0)
if e != 0 {
return 0, 0, error(e)
var info windows.ConsoleScreenBufferInfo
if err := windows.GetConsoleScreenBufferInfo(windows.Handle(fd), &info); err != nil {
return 0, 0, err
}
return int(info.size.x), int(info.size.y), nil
return int(info.Size.X), int(info.Size.Y), nil
}
// ReadPassword reads a line of input from a terminal without local echo. This
@ -128,47 +78,26 @@ func GetSize(fd int) (width, height int, err error) {
// returned does not include the \n.
func ReadPassword(fd int) ([]byte, error) {
var st uint32
_, _, e := syscall.Syscall(procGetConsoleMode.Addr(), 2, uintptr(fd), uintptr(unsafe.Pointer(&st)), 0)
if e != 0 {
return nil, error(e)
if err := windows.GetConsoleMode(windows.Handle(fd), &st); err != nil {
return nil, err
}
old := st
st &^= (enableEchoInput)
st |= (enableProcessedInput | enableLineInput | enableProcessedOutput)
_, _, e = syscall.Syscall(procSetConsoleMode.Addr(), 2, uintptr(fd), uintptr(st), 0)
if e != 0 {
return nil, error(e)
st &^= (windows.ENABLE_ECHO_INPUT)
st |= (windows.ENABLE_PROCESSED_INPUT | windows.ENABLE_LINE_INPUT | windows.ENABLE_PROCESSED_OUTPUT)
if err := windows.SetConsoleMode(windows.Handle(fd), st); err != nil {
return nil, err
}
defer func() {
syscall.Syscall(procSetConsoleMode.Addr(), 2, uintptr(fd), uintptr(old), 0)
}()
defer windows.SetConsoleMode(windows.Handle(fd), old)
var buf [16]byte
var ret []byte
for {
n, err := syscall.Read(syscall.Handle(fd), buf[:])
if err != nil {
return nil, err
}
if n == 0 {
if len(ret) == 0 {
return nil, io.EOF
}
break
}
if buf[n-1] == '\n' {
n--
}
if n > 0 && buf[n-1] == '\r' {
n--
}
ret = append(ret, buf[:n]...)
if n < len(buf) {
break
}
var h windows.Handle
p, _ := windows.GetCurrentProcess()
if err := windows.DuplicateHandle(p, windows.Handle(fd), p, &h, 0, false, windows.DUPLICATE_SAME_ACCESS); err != nil {
return nil, err
}
return ret, nil
f := os.NewFile(uintptr(h), "stdin")
defer f.Close()
return readPasswordLine(f)
}

View file

@ -21,7 +21,16 @@ func TestAgentForward(t *testing.T) {
defer conn.Close()
keyring := agent.NewKeyring()
keyring.Add(testPrivateKeys["dsa"], nil, "")
if err := keyring.Add(agent.AddedKey{PrivateKey: testPrivateKeys["dsa"]}); err != nil {
t.Fatalf("Error adding key: %s", err)
}
if err := keyring.Add(agent.AddedKey{
PrivateKey: testPrivateKeys["dsa"],
ConfirmBeforeUse: true,
LifetimeSecs: 3600,
}); err != nil {
t.Fatalf("Error adding key with constraints: %s", err)
}
pub := testPublicKeys["dsa"]
sess, err := conn.NewSession()

32
vendor/golang.org/x/crypto/ssh/test/banner_test.go generated vendored Normal file
View file

@ -0,0 +1,32 @@
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build darwin dragonfly freebsd linux netbsd openbsd
package test
import (
"testing"
)
func TestBannerCallbackAgainstOpenSSH(t *testing.T) {
server := newServer(t)
defer server.Shutdown()
clientConf := clientConfig()
var receivedBanner string
clientConf.BannerCallback = func(message string) error {
receivedBanner = message
return nil
}
conn := server.Dial(clientConf)
defer conn.Close()
expected := "Server Banner"
if receivedBanner != expected {
t.Fatalf("got %v; want %v", receivedBanner, expected)
}
}

View file

@ -7,12 +7,14 @@
package test
import (
"bytes"
"crypto/rand"
"testing"
"golang.org/x/crypto/ssh"
)
// Test both logging in with a cert, and also that the certificate presented by an OpenSSH host can be validated correctly
func TestCertLogin(t *testing.T) {
s := newServer(t)
defer s.Shutdown()
@ -37,11 +39,39 @@ func TestCertLogin(t *testing.T) {
conf := &ssh.ClientConfig{
User: username(),
HostKeyCallback: (&ssh.CertChecker{
IsHostAuthority: func(pk ssh.PublicKey, addr string) bool {
return bytes.Equal(pk.Marshal(), testPublicKeys["ca"].Marshal())
},
}).CheckHostKey,
}
conf.Auth = append(conf.Auth, ssh.PublicKeys(certSigner))
client, err := s.TryDial(conf)
if err != nil {
t.Fatalf("TryDial: %v", err)
for _, test := range []struct {
addr string
succeed bool
}{
{addr: "host.example.com:22", succeed: true},
{addr: "host.example.com:10000", succeed: true}, // non-standard port must be OK
{addr: "host.example.com", succeed: false}, // port must be specified
{addr: "host.ex4mple.com:22", succeed: false}, // wrong host
} {
client, err := s.TryDialWithAddr(conf, test.addr)
// Always close client if opened successfully
if err == nil {
client.Close()
}
// Now evaluate whether the test failed or passed
if test.succeed {
if err != nil {
t.Fatalf("TryDialWithAddr: %v", err)
}
} else {
if err == nil {
t.Fatalf("TryDialWithAddr, unexpected success")
}
}
}
client.Close()
}

128
vendor/golang.org/x/crypto/ssh/test/dial_unix_test.go generated vendored Normal file
View file

@ -0,0 +1,128 @@
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build !windows
package test
// direct-tcpip and direct-streamlocal functional tests
import (
"fmt"
"io"
"io/ioutil"
"net"
"strings"
"testing"
)
type dialTester interface {
TestServerConn(t *testing.T, c net.Conn)
TestClientConn(t *testing.T, c net.Conn)
}
func testDial(t *testing.T, n, listenAddr string, x dialTester) {
server := newServer(t)
defer server.Shutdown()
sshConn := server.Dial(clientConfig())
defer sshConn.Close()
l, err := net.Listen(n, listenAddr)
if err != nil {
t.Fatalf("Listen: %v", err)
}
defer l.Close()
testData := fmt.Sprintf("hello from %s, %s", n, listenAddr)
go func() {
for {
c, err := l.Accept()
if err != nil {
break
}
x.TestServerConn(t, c)
io.WriteString(c, testData)
c.Close()
}
}()
conn, err := sshConn.Dial(n, l.Addr().String())
if err != nil {
t.Fatalf("Dial: %v", err)
}
x.TestClientConn(t, conn)
defer conn.Close()
b, err := ioutil.ReadAll(conn)
if err != nil {
t.Fatalf("ReadAll: %v", err)
}
t.Logf("got %q", string(b))
if string(b) != testData {
t.Fatalf("expected %q, got %q", testData, string(b))
}
}
type tcpDialTester struct {
listenAddr string
}
func (x *tcpDialTester) TestServerConn(t *testing.T, c net.Conn) {
host := strings.Split(x.listenAddr, ":")[0]
prefix := host + ":"
if !strings.HasPrefix(c.LocalAddr().String(), prefix) {
t.Fatalf("expected to start with %q, got %q", prefix, c.LocalAddr().String())
}
if !strings.HasPrefix(c.RemoteAddr().String(), prefix) {
t.Fatalf("expected to start with %q, got %q", prefix, c.RemoteAddr().String())
}
}
func (x *tcpDialTester) TestClientConn(t *testing.T, c net.Conn) {
// we use zero addresses. see *Client.Dial.
if c.LocalAddr().String() != "0.0.0.0:0" {
t.Fatalf("expected \"0.0.0.0:0\", got %q", c.LocalAddr().String())
}
if c.RemoteAddr().String() != "0.0.0.0:0" {
t.Fatalf("expected \"0.0.0.0:0\", got %q", c.RemoteAddr().String())
}
}
func TestDialTCP(t *testing.T) {
x := &tcpDialTester{
listenAddr: "127.0.0.1:0",
}
testDial(t, "tcp", x.listenAddr, x)
}
type unixDialTester struct {
listenAddr string
}
func (x *unixDialTester) TestServerConn(t *testing.T, c net.Conn) {
if c.LocalAddr().String() != x.listenAddr {
t.Fatalf("expected %q, got %q", x.listenAddr, c.LocalAddr().String())
}
if c.RemoteAddr().String() != "@" {
t.Fatalf("expected \"@\", got %q", c.RemoteAddr().String())
}
}
func (x *unixDialTester) TestClientConn(t *testing.T, c net.Conn) {
if c.RemoteAddr().String() != x.listenAddr {
t.Fatalf("expected %q, got %q", x.listenAddr, c.RemoteAddr().String())
}
if c.LocalAddr().String() != "@" {
t.Fatalf("expected \"@\", got %q", c.LocalAddr().String())
}
}
func TestDialUnix(t *testing.T) {
addr, cleanup := newTempSocket(t)
defer cleanup()
x := &unixDialTester{
listenAddr: addr,
}
testDial(t, "unix", x.listenAddr, x)
}

View file

@ -2,6 +2,6 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// This package contains integration tests for the
// code.google.com/p/go.crypto/ssh package.
// Package test contains integration tests for the
// golang.org/x/crypto/ssh package.
package test // import "golang.org/x/crypto/ssh/test"

View file

@ -16,13 +16,17 @@ import (
"time"
)
func TestPortForward(t *testing.T) {
type closeWriter interface {
CloseWrite() error
}
func testPortForward(t *testing.T, n, listenAddr string) {
server := newServer(t)
defer server.Shutdown()
conn := server.Dial(clientConfig())
defer conn.Close()
sshListener, err := conn.Listen("tcp", "localhost:0")
sshListener, err := conn.Listen(n, listenAddr)
if err != nil {
t.Fatal(err)
}
@ -41,14 +45,14 @@ func TestPortForward(t *testing.T) {
}()
forwardedAddr := sshListener.Addr().String()
tcpConn, err := net.Dial("tcp", forwardedAddr)
netConn, err := net.Dial(n, forwardedAddr)
if err != nil {
t.Fatalf("TCP dial failed: %v", err)
t.Fatalf("net dial failed: %v", err)
}
readChan := make(chan []byte)
go func() {
data, _ := ioutil.ReadAll(tcpConn)
data, _ := ioutil.ReadAll(netConn)
readChan <- data
}()
@ -62,14 +66,14 @@ func TestPortForward(t *testing.T) {
for len(sent) < 1000*1000 {
// Send random sized chunks
m := rand.Intn(len(data))
n, err := tcpConn.Write(data[:m])
n, err := netConn.Write(data[:m])
if err != nil {
break
}
sent = append(sent, data[:n]...)
}
if err := tcpConn.(*net.TCPConn).CloseWrite(); err != nil {
t.Errorf("tcpConn.CloseWrite: %v", err)
if err := netConn.(closeWriter).CloseWrite(); err != nil {
t.Errorf("netConn.CloseWrite: %v", err)
}
read := <-readChan
@ -86,19 +90,29 @@ func TestPortForward(t *testing.T) {
}
// Check that the forward disappeared.
tcpConn, err = net.Dial("tcp", forwardedAddr)
netConn, err = net.Dial(n, forwardedAddr)
if err == nil {
tcpConn.Close()
netConn.Close()
t.Errorf("still listening to %s after closing", forwardedAddr)
}
}
func TestAcceptClose(t *testing.T) {
func TestPortForwardTCP(t *testing.T) {
testPortForward(t, "tcp", "localhost:0")
}
func TestPortForwardUnix(t *testing.T) {
addr, cleanup := newTempSocket(t)
defer cleanup()
testPortForward(t, "unix", addr)
}
func testAcceptClose(t *testing.T, n, listenAddr string) {
server := newServer(t)
defer server.Shutdown()
conn := server.Dial(clientConfig())
sshListener, err := conn.Listen("tcp", "localhost:0")
sshListener, err := conn.Listen(n, listenAddr)
if err != nil {
t.Fatal(err)
}
@ -124,13 +138,23 @@ func TestAcceptClose(t *testing.T) {
}
}
func TestAcceptCloseTCP(t *testing.T) {
testAcceptClose(t, "tcp", "localhost:0")
}
func TestAcceptCloseUnix(t *testing.T) {
addr, cleanup := newTempSocket(t)
defer cleanup()
testAcceptClose(t, "unix", addr)
}
// Check that listeners exit if the underlying client transport dies.
func TestPortForwardConnectionClose(t *testing.T) {
func testPortForwardConnectionClose(t *testing.T, n, listenAddr string) {
server := newServer(t)
defer server.Shutdown()
conn := server.Dial(clientConfig())
sshListener, err := conn.Listen("tcp", "localhost:0")
sshListener, err := conn.Listen(n, listenAddr)
if err != nil {
t.Fatal(err)
}
@ -158,3 +182,13 @@ func TestPortForwardConnectionClose(t *testing.T) {
t.Logf("quit as expected (error %v)", err)
}
}
func TestPortForwardConnectionCloseTCP(t *testing.T) {
testPortForwardConnectionClose(t, "tcp", "localhost:0")
}
func TestPortForwardConnectionCloseUnix(t *testing.T) {
addr, cleanup := newTempSocket(t)
defer cleanup()
testPortForwardConnectionClose(t, "unix", addr)
}

144
vendor/golang.org/x/crypto/ssh/test/multi_auth_test.go generated vendored Normal file
View file

@ -0,0 +1,144 @@
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Tests for ssh client multi-auth
//
// These tests run a simple go ssh client against OpenSSH server
// over unix domain sockets. The tests use multiple combinations
// of password, keyboard-interactive and publickey authentication
// methods.
//
// A wrapper library for making sshd PAM authentication use test
// passwords is required in ./sshd_test_pw.so. If the library does
// not exist these tests will be skipped. See compile instructions
// (for linux) in file ./sshd_test_pw.c.
// +build linux
package test
import (
"fmt"
"strings"
"testing"
"golang.org/x/crypto/ssh"
)
// test cases
type multiAuthTestCase struct {
authMethods []string
expectedPasswordCbs int
expectedKbdIntCbs int
}
// test context
type multiAuthTestCtx struct {
password string
numPasswordCbs int
numKbdIntCbs int
}
// create test context
func newMultiAuthTestCtx(t *testing.T) *multiAuthTestCtx {
password, err := randomPassword()
if err != nil {
t.Fatalf("Failed to generate random test password: %s", err.Error())
}
return &multiAuthTestCtx{
password: password,
}
}
// password callback
func (ctx *multiAuthTestCtx) passwordCb() (secret string, err error) {
ctx.numPasswordCbs++
return ctx.password, nil
}
// keyboard-interactive callback
func (ctx *multiAuthTestCtx) kbdIntCb(user, instruction string, questions []string, echos []bool) (answers []string, err error) {
if len(questions) == 0 {
return nil, nil
}
ctx.numKbdIntCbs++
if len(questions) == 1 {
return []string{ctx.password}, nil
}
return nil, fmt.Errorf("unsupported keyboard-interactive flow")
}
// TestMultiAuth runs several subtests for different combinations of password, keyboard-interactive and publickey authentication methods
func TestMultiAuth(t *testing.T) {
testCases := []multiAuthTestCase{
// Test password,publickey authentication, assert that password callback is called 1 time
multiAuthTestCase{
authMethods: []string{"password", "publickey"},
expectedPasswordCbs: 1,
},
// Test keyboard-interactive,publickey authentication, assert that keyboard-interactive callback is called 1 time
multiAuthTestCase{
authMethods: []string{"keyboard-interactive", "publickey"},
expectedKbdIntCbs: 1,
},
// Test publickey,password authentication, assert that password callback is called 1 time
multiAuthTestCase{
authMethods: []string{"publickey", "password"},
expectedPasswordCbs: 1,
},
// Test publickey,keyboard-interactive authentication, assert that keyboard-interactive callback is called 1 time
multiAuthTestCase{
authMethods: []string{"publickey", "keyboard-interactive"},
expectedKbdIntCbs: 1,
},
// Test password,password authentication, assert that password callback is called 2 times
multiAuthTestCase{
authMethods: []string{"password", "password"},
expectedPasswordCbs: 2,
},
}
for _, testCase := range testCases {
t.Run(strings.Join(testCase.authMethods, ","), func(t *testing.T) {
ctx := newMultiAuthTestCtx(t)
server := newServerForConfig(t, "MultiAuth", map[string]string{"AuthMethods": strings.Join(testCase.authMethods, ",")})
defer server.Shutdown()
clientConfig := clientConfig()
server.setTestPassword(clientConfig.User, ctx.password)
publicKeyAuthMethod := clientConfig.Auth[0]
clientConfig.Auth = nil
for _, authMethod := range testCase.authMethods {
switch authMethod {
case "publickey":
clientConfig.Auth = append(clientConfig.Auth, publicKeyAuthMethod)
case "password":
clientConfig.Auth = append(clientConfig.Auth,
ssh.RetryableAuthMethod(ssh.PasswordCallback(ctx.passwordCb), 5))
case "keyboard-interactive":
clientConfig.Auth = append(clientConfig.Auth,
ssh.RetryableAuthMethod(ssh.KeyboardInteractive(ctx.kbdIntCb), 5))
default:
t.Fatalf("Unknown authentication method %s", authMethod)
}
}
conn := server.Dial(clientConfig)
defer conn.Close()
if ctx.numPasswordCbs != testCase.expectedPasswordCbs {
t.Fatalf("passwordCallback was called %d times, expected %d times", ctx.numPasswordCbs, testCase.expectedPasswordCbs)
}
if ctx.numKbdIntCbs != testCase.expectedKbdIntCbs {
t.Fatalf("keyboardInteractiveCallback was called %d times, expected %d times", ctx.numKbdIntCbs, testCase.expectedKbdIntCbs)
}
})
}
}

View file

@ -11,6 +11,7 @@ package test
import (
"bytes"
"errors"
"fmt"
"io"
"strings"
"testing"
@ -276,24 +277,104 @@ func TestValidTerminalMode(t *testing.T) {
}
}
func TestWindowChange(t *testing.T) {
server := newServer(t)
defer server.Shutdown()
conn := server.Dial(clientConfig())
defer conn.Close()
session, err := conn.NewSession()
if err != nil {
t.Fatalf("session failed: %v", err)
}
defer session.Close()
stdout, err := session.StdoutPipe()
if err != nil {
t.Fatalf("unable to acquire stdout pipe: %s", err)
}
stdin, err := session.StdinPipe()
if err != nil {
t.Fatalf("unable to acquire stdin pipe: %s", err)
}
tm := ssh.TerminalModes{ssh.ECHO: 0}
if err = session.RequestPty("xterm", 80, 40, tm); err != nil {
t.Fatalf("req-pty failed: %s", err)
}
if err := session.WindowChange(100, 100); err != nil {
t.Fatalf("window-change failed: %s", err)
}
err = session.Shell()
if err != nil {
t.Fatalf("session failed: %s", err)
}
stdin.Write([]byte("stty size && exit\n"))
var buf bytes.Buffer
if _, err := io.Copy(&buf, stdout); err != nil {
t.Fatalf("reading failed: %s", err)
}
if sttyOutput := buf.String(); !strings.Contains(sttyOutput, "100 100") {
t.Fatalf("terminal WindowChange failure: expected \"100 100\" stty output, got %s", sttyOutput)
}
}
func testOneCipher(t *testing.T, cipher string, cipherOrder []string) {
server := newServer(t)
defer server.Shutdown()
conf := clientConfig()
conf.Ciphers = []string{cipher}
// Don't fail if sshd doesn't have the cipher.
conf.Ciphers = append(conf.Ciphers, cipherOrder...)
conn, err := server.TryDial(conf)
if err != nil {
t.Fatalf("TryDial: %v", err)
}
defer conn.Close()
numBytes := 4096
// Exercise sending data to the server
if _, _, err := conn.Conn.SendRequest("drop-me", false, make([]byte, numBytes)); err != nil {
t.Fatalf("SendRequest: %v", err)
}
// Exercise receiving data from the server
session, err := conn.NewSession()
if err != nil {
t.Fatalf("NewSession: %v", err)
}
out, err := session.Output(fmt.Sprintf("dd if=/dev/zero of=/dev/stdout bs=%d count=1", numBytes))
if err != nil {
t.Fatalf("Output: %v", err)
}
if len(out) != numBytes {
t.Fatalf("got %d bytes, want %d bytes", len(out), numBytes)
}
}
var deprecatedCiphers = []string{
"aes128-cbc", "3des-cbc",
"arcfour128", "arcfour256",
}
func TestCiphers(t *testing.T) {
var config ssh.Config
config.SetDefaults()
cipherOrder := config.Ciphers
cipherOrder := append(config.Ciphers, deprecatedCiphers...)
for _, ciph := range cipherOrder {
server := newServer(t)
defer server.Shutdown()
conf := clientConfig()
conf.Ciphers = []string{ciph}
// Don't fail if sshd doesnt have the cipher.
conf.Ciphers = append(conf.Ciphers, cipherOrder...)
conn, err := server.TryDial(conf)
if err == nil {
conn.Close()
} else {
t.Fatalf("failed for cipher %q", ciph)
}
t.Run(ciph, func(t *testing.T) {
testOneCipher(t, ciph, cipherOrder)
})
}
}
@ -307,7 +388,7 @@ func TestMACs(t *testing.T) {
defer server.Shutdown()
conf := clientConfig()
conf.MACs = []string{mac}
// Don't fail if sshd doesnt have the MAC.
// Don't fail if sshd doesn't have the MAC.
conf.MACs = append(conf.MACs, macOrder...)
if conn, err := server.TryDial(conf); err == nil {
conn.Close()
@ -316,3 +397,47 @@ func TestMACs(t *testing.T) {
}
}
}
func TestKeyExchanges(t *testing.T) {
var config ssh.Config
config.SetDefaults()
kexOrder := config.KeyExchanges
for _, kex := range kexOrder {
server := newServer(t)
defer server.Shutdown()
conf := clientConfig()
// Don't fail if sshd doesn't have the kex.
conf.KeyExchanges = append([]string{kex}, kexOrder...)
conn, err := server.TryDial(conf)
if err == nil {
conn.Close()
} else {
t.Errorf("failed for kex %q", kex)
}
}
}
func TestClientAuthAlgorithms(t *testing.T) {
for _, key := range []string{
"rsa",
"dsa",
"ecdsa",
"ed25519",
} {
server := newServer(t)
conf := clientConfig()
conf.SetDefaults()
conf.Auth = []ssh.AuthMethod{
ssh.PublicKeys(testSigners[key]),
}
conn, err := server.TryDial(conf)
if err == nil {
conn.Close()
} else {
t.Errorf("failed for key %q", key)
}
server.Shutdown()
}
}

173
vendor/golang.org/x/crypto/ssh/test/sshd_test_pw.c generated vendored Normal file
View file

@ -0,0 +1,173 @@
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// sshd_test_pw.c
// Wrapper to inject test password data for sshd PAM authentication
//
// This wrapper implements custom versions of getpwnam, getpwnam_r,
// getspnam and getspnam_r. These functions first call their real
// libc versions, then check if the requested user matches test user
// specified in env variable TEST_USER and if so replace the password
// with crypted() value of TEST_PASSWD env variable.
//
// Compile:
// gcc -Wall -shared -o sshd_test_pw.so -fPIC sshd_test_pw.c
//
// Compile with debug:
// gcc -DVERBOSE -Wall -shared -o sshd_test_pw.so -fPIC sshd_test_pw.c
//
// Run sshd:
// LD_PRELOAD="sshd_test_pw.so" TEST_USER="..." TEST_PASSWD="..." sshd ...
// +build ignore
#define _GNU_SOURCE
#include <string.h>
#include <pwd.h>
#include <shadow.h>
#include <dlfcn.h>
#include <stdlib.h>
#include <unistd.h>
#include <stdio.h>
#ifdef VERBOSE
#define DEBUG(X...) fprintf(stderr, X)
#else
#define DEBUG(X...) while (0) { }
#endif
/* crypt() password */
static char *
pwhash(char *passwd) {
return strdup(crypt(passwd, "$6$"));
}
/* Pointers to real functions in libc */
static struct passwd * (*real_getpwnam)(const char *) = NULL;
static int (*real_getpwnam_r)(const char *, struct passwd *, char *, size_t, struct passwd **) = NULL;
static struct spwd * (*real_getspnam)(const char *) = NULL;
static int (*real_getspnam_r)(const char *, struct spwd *, char *, size_t, struct spwd **) = NULL;
/* Cached test user and test password */
static char *test_user = NULL;
static char *test_passwd_hash = NULL;
static void
init(void) {
/* Fetch real libc function pointers */
real_getpwnam = dlsym(RTLD_NEXT, "getpwnam");
real_getpwnam_r = dlsym(RTLD_NEXT, "getpwnam_r");
real_getspnam = dlsym(RTLD_NEXT, "getspnam");
real_getspnam_r = dlsym(RTLD_NEXT, "getspnam_r");
/* abort if env variables are not defined */
if (getenv("TEST_USER") == NULL || getenv("TEST_PASSWD") == NULL) {
fprintf(stderr, "env variables TEST_USER and TEST_PASSWD are missing\n");
abort();
}
/* Fetch test user and test password from env */
test_user = strdup(getenv("TEST_USER"));
test_passwd_hash = pwhash(getenv("TEST_PASSWD"));
DEBUG("sshd_test_pw init():\n");
DEBUG("\treal_getpwnam: %p\n", real_getpwnam);
DEBUG("\treal_getpwnam_r: %p\n", real_getpwnam_r);
DEBUG("\treal_getspnam: %p\n", real_getspnam);
DEBUG("\treal_getspnam_r: %p\n", real_getspnam_r);
DEBUG("\tTEST_USER: '%s'\n", test_user);
DEBUG("\tTEST_PASSWD: '%s'\n", getenv("TEST_PASSWD"));
DEBUG("\tTEST_PASSWD_HASH: '%s'\n", test_passwd_hash);
}
static int
is_test_user(const char *name) {
if (test_user != NULL && strcmp(test_user, name) == 0)
return 1;
return 0;
}
/* getpwnam */
struct passwd *
getpwnam(const char *name) {
struct passwd *pw;
DEBUG("sshd_test_pw getpwnam(%s)\n", name);
if (real_getpwnam == NULL)
init();
if ((pw = real_getpwnam(name)) == NULL)
return NULL;
if (is_test_user(name))
pw->pw_passwd = strdup(test_passwd_hash);
return pw;
}
/* getpwnam_r */
int
getpwnam_r(const char *name,
struct passwd *pwd,
char *buf,
size_t buflen,
struct passwd **result) {
int r;
DEBUG("sshd_test_pw getpwnam_r(%s)\n", name);
if (real_getpwnam_r == NULL)
init();
if ((r = real_getpwnam_r(name, pwd, buf, buflen, result)) != 0 || *result == NULL)
return r;
if (is_test_user(name))
pwd->pw_passwd = strdup(test_passwd_hash);
return 0;
}
/* getspnam */
struct spwd *
getspnam(const char *name) {
struct spwd *sp;
DEBUG("sshd_test_pw getspnam(%s)\n", name);
if (real_getspnam == NULL)
init();
if ((sp = real_getspnam(name)) == NULL)
return NULL;
if (is_test_user(name))
sp->sp_pwdp = strdup(test_passwd_hash);
return sp;
}
/* getspnam_r */
int
getspnam_r(const char *name,
struct spwd *spbuf,
char *buf,
size_t buflen,
struct spwd **spbufp) {
int r;
DEBUG("sshd_test_pw getspnam_r(%s)\n", name);
if (real_getspnam_r == NULL)
init();
if ((r = real_getspnam_r(name, spbuf, buf, buflen, spbufp)) != 0)
return r;
if (is_test_user(name))
spbuf->sp_pwdp = strdup(test_passwd_hash);
return r;
}

View file

@ -1,46 +0,0 @@
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build !windows
package test
// direct-tcpip functional tests
import (
"io"
"net"
"testing"
)
func TestDial(t *testing.T) {
server := newServer(t)
defer server.Shutdown()
sshConn := server.Dial(clientConfig())
defer sshConn.Close()
l, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("Listen: %v", err)
}
defer l.Close()
go func() {
for {
c, err := l.Accept()
if err != nil {
break
}
io.WriteString(c, c.RemoteAddr().String())
c.Close()
}
}()
conn, err := sshConn.Dial("tcp", l.Addr().String())
if err != nil {
t.Fatalf("Dial: %v", err)
}
defer conn.Close()
}

View file

@ -10,6 +10,8 @@ package test
import (
"bytes"
"crypto/rand"
"encoding/base64"
"fmt"
"io/ioutil"
"log"
@ -25,11 +27,14 @@ import (
"golang.org/x/crypto/ssh/testdata"
)
const sshd_config = `
const (
defaultSshdConfig = `
Protocol 2
Banner {{.Dir}}/banner
HostKey {{.Dir}}/id_rsa
HostKey {{.Dir}}/id_dsa
HostKey {{.Dir}}/id_ecdsa
HostCertificate {{.Dir}}/id_rsa-cert.pub
Pidfile {{.Dir}}/sshd.pid
#UsePrivilegeSeparation no
KeyRegenerationInterval 3600
@ -41,14 +46,24 @@ PermitRootLogin no
StrictModes no
RSAAuthentication yes
PubkeyAuthentication yes
AuthorizedKeysFile {{.Dir}}/id_user.pub
AuthorizedKeysFile {{.Dir}}/authorized_keys
TrustedUserCAKeys {{.Dir}}/id_ecdsa.pub
IgnoreRhosts yes
RhostsRSAAuthentication no
HostbasedAuthentication no
PubkeyAcceptedKeyTypes=*
`
multiAuthSshdConfigTail = `
UsePAM yes
PasswordAuthentication yes
ChallengeResponseAuthentication yes
AuthenticationMethods {{.AuthMethods}}
`
)
var configTmpl = template.Must(template.New("").Parse(sshd_config))
var configTmpl = map[string]*template.Template{
"default": template.Must(template.New("").Parse(defaultSshdConfig)),
"MultiAuth": template.Must(template.New("").Parse(defaultSshdConfig + multiAuthSshdConfigTail))}
type server struct {
t *testing.T
@ -57,6 +72,10 @@ type server struct {
cmd *exec.Cmd
output bytes.Buffer // holds stderr from sshd process
testUser string // test username for sshd
testPasswd string // test password for sshd
sshdTestPwSo string // dynamic library to inject a custom password into sshd
// Client half of the network connection.
clientConn net.Conn
}
@ -118,6 +137,11 @@ func clientConfig() *ssh.ClientConfig {
ssh.PublicKeys(testSigners["user"]),
},
HostKeyCallback: hostKeyDB().Check,
HostKeyAlgorithms: []string{ // by default, don't allow certs as this affects the hostKeyDB checker
ssh.KeyAlgoECDSA256, ssh.KeyAlgoECDSA384, ssh.KeyAlgoECDSA521,
ssh.KeyAlgoRSA, ssh.KeyAlgoDSA,
ssh.KeyAlgoED25519,
},
}
return config
}
@ -153,6 +177,12 @@ func unixConnection() (*net.UnixConn, *net.UnixConn, error) {
}
func (s *server) TryDial(config *ssh.ClientConfig) (*ssh.Client, error) {
return s.TryDialWithAddr(config, "")
}
// addr is the user specified host:port. While we don't actually dial it,
// we need to know this for host key matching
func (s *server) TryDialWithAddr(config *ssh.ClientConfig, addr string) (*ssh.Client, error) {
sshd, err := exec.LookPath("sshd")
if err != nil {
s.t.Skipf("skipping test: %v", err)
@ -172,13 +202,27 @@ func (s *server) TryDial(config *ssh.ClientConfig) (*ssh.Client, error) {
s.cmd.Stdin = f
s.cmd.Stdout = f
s.cmd.Stderr = &s.output
if s.sshdTestPwSo != "" {
if s.testUser == "" {
s.t.Fatal("user missing from sshd_test_pw.so config")
}
if s.testPasswd == "" {
s.t.Fatal("password missing from sshd_test_pw.so config")
}
s.cmd.Env = append(os.Environ(),
fmt.Sprintf("LD_PRELOAD=%s", s.sshdTestPwSo),
fmt.Sprintf("TEST_USER=%s", s.testUser),
fmt.Sprintf("TEST_PASSWD=%s", s.testPasswd))
}
if err := s.cmd.Start(); err != nil {
s.t.Fail()
s.Shutdown()
s.t.Fatalf("s.cmd.Start: %v", err)
}
s.clientConn = c1
conn, chans, reqs, err := ssh.NewClientConn(c1, "", config)
conn, chans, reqs, err := ssh.NewClientConn(c1, addr, config)
if err != nil {
return nil, err
}
@ -222,11 +266,49 @@ func writeFile(path string, contents []byte) {
}
}
// generate random password
func randomPassword() (string, error) {
b := make([]byte, 12)
_, err := rand.Read(b)
if err != nil {
return "", err
}
return base64.RawURLEncoding.EncodeToString(b), nil
}
// setTestPassword is used for setting user and password data for sshd_test_pw.so
// This function also checks that ./sshd_test_pw.so exists and if not calls s.t.Skip()
func (s *server) setTestPassword(user, passwd string) error {
wd, _ := os.Getwd()
wrapper := filepath.Join(wd, "sshd_test_pw.so")
if _, err := os.Stat(wrapper); err != nil {
s.t.Skip(fmt.Errorf("sshd_test_pw.so is not available"))
return err
}
s.sshdTestPwSo = wrapper
s.testUser = user
s.testPasswd = passwd
return nil
}
// newServer returns a new mock ssh server.
func newServer(t *testing.T) *server {
return newServerForConfig(t, "default", map[string]string{})
}
// newServerForConfig returns a new mock ssh server.
func newServerForConfig(t *testing.T, config string, configVars map[string]string) *server {
if testing.Short() {
t.Skip("skipping test due to -short")
}
u, err := user.Current()
if err != nil {
t.Fatalf("user.Current: %v", err)
}
if u.Name == "root" {
t.Skip("skipping test because current user is root")
}
dir, err := ioutil.TempDir("", "sshtest")
if err != nil {
t.Fatal(err)
@ -235,20 +317,35 @@ func newServer(t *testing.T) *server {
if err != nil {
t.Fatal(err)
}
err = configTmpl.Execute(f, map[string]string{
"Dir": dir,
})
if _, ok := configTmpl[config]; ok == false {
t.Fatal(fmt.Errorf("Invalid server config '%s'", config))
}
configVars["Dir"] = dir
err = configTmpl[config].Execute(f, configVars)
if err != nil {
t.Fatal(err)
}
f.Close()
writeFile(filepath.Join(dir, "banner"), []byte("Server Banner"))
for k, v := range testdata.PEMBytes {
filename := "id_" + k
writeFile(filepath.Join(dir, filename), v)
writeFile(filepath.Join(dir, filename+".pub"), ssh.MarshalAuthorizedKey(testPublicKeys[k]))
}
for k, v := range testdata.SSHCertificates {
filename := "id_" + k + "-cert.pub"
writeFile(filepath.Join(dir, filename), v)
}
var authkeys bytes.Buffer
for k := range testdata.PEMBytes {
authkeys.Write(ssh.MarshalAuthorizedKey(testPublicKeys[k]))
}
writeFile(filepath.Join(dir, "authorized_keys"), authkeys.Bytes())
return &server{
t: t,
configfile: f.Name(),
@ -259,3 +356,13 @@ func newServer(t *testing.T) *server {
},
}
}
func newTempSocket(t *testing.T) (string, func()) {
dir, err := ioutil.TempDir("", "socket")
if err != nil {
t.Fatal(err)
}
deferFunc := func() { os.RemoveAll(dir) }
addr := filepath.Join(dir, "sock")
return addr, deferFunc
}

View file

@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// IMPLEMENTOR NOTE: To avoid a package loop, this file is in three places:
// IMPLEMENTATION NOTE: To avoid a package loop, this file is in three places:
// ssh/, ssh/agent, and ssh/test/. It should be kept in sync across all three
// instances.

View file

@ -3,6 +3,6 @@
// license that can be found in the LICENSE file.
// This package contains test data shared between the various subpackages of
// the code.google.com/p/go.crypto/ssh package. Under no circumstance should
// the golang.org/x/crypto/ssh package. Under no circumstance should
// this data be used for production code.
package testdata // import "golang.org/x/crypto/ssh/testdata"

View file

@ -23,21 +23,205 @@ MHcCAQEEINGWx0zo6fhJ/0EAfrPzVFyFC9s18lBt3cRoEDhS3ARooAoGCCqGSM49
AwEHoUQDQgAEi9Hdw6KvZcWxfg2IDhA7UkpDtzzt6ZqJXSsFdLd+Kx4S3Sx4cVO+
6/ZOXRnPmNAlLUqjShUsUBBngG0u2fqEqA==
-----END EC PRIVATE KEY-----
`),
"ecdsap256": []byte(`-----BEGIN EC PRIVATE KEY-----
MHcCAQEEIAPCE25zK0PQSnsgVcEbM1mbKTASH4pqb5QJajplDwDZoAoGCCqGSM49
AwEHoUQDQgAEWy8TxGcIHRh5XGpO4dFVfDjeNY+VkgubQrf/eyFJZHxAn1SKraXU
qJUjTKj1z622OxYtJ5P7s9CfAEVsTzLCzg==
-----END EC PRIVATE KEY-----
`),
"ecdsap384": []byte(`-----BEGIN EC PRIVATE KEY-----
MIGkAgEBBDBWfSnMuNKq8J9rQLzzEkx3KAoEohSXqhE/4CdjEYtoU2i22HW80DDS
qQhYNHRAduygBwYFK4EEACKhZANiAAQWaDMAd0HUd8ZiXCX7mYDDnC54gwH/nG43
VhCUEYmF7HMZm/B9Yn3GjFk3qYEDEvuF/52+NvUKBKKaLbh32AWxMv0ibcoba4cz
hL9+hWYhUD9XIUlzMWiZ2y6eBE9PdRI=
-----END EC PRIVATE KEY-----
`),
"ecdsap521": []byte(`-----BEGIN EC PRIVATE KEY-----
MIHcAgEBBEIBrkYpQcy8KTVHNiAkjlFZwee90224Bu6wz94R4OBo+Ts0eoAQG7SF
iaygEDMUbx6kTgXTBcKZ0jrWPKakayNZ/kigBwYFK4EEACOhgYkDgYYABADFuvLV
UoaCDGHcw5uNfdRIsvaLKuWSpLsl48eWGZAwdNG432GDVKduO+pceuE+8XzcyJb+
uMv+D2b11Q/LQUcHJwE6fqbm8m3EtDKPsoKs0u/XUJb0JsH4J8lkZzbUTjvGYamn
FFlRjzoB3Oxu8UQgb+MWPedtH9XYBbg9biz4jJLkXQ==
-----END EC PRIVATE KEY-----
`),
"rsa": []byte(`-----BEGIN RSA PRIVATE KEY-----
MIIBOwIBAAJBALdGZxkXDAjsYk10ihwU6Id2KeILz1TAJuoq4tOgDWxEEGeTrcld
r/ZwVaFzjWzxaf6zQIJbfaSEAhqD5yo72+sCAwEAAQJBAK8PEVU23Wj8mV0QjwcJ
tZ4GcTUYQL7cF4+ezTCE9a1NrGnCP2RuQkHEKxuTVrxXt+6OF15/1/fuXnxKjmJC
nxkCIQDaXvPPBi0c7vAxGwNY9726x01/dNbHCE0CBtcotobxpwIhANbbQbh3JHVW
2haQh4fAG5mhesZKAGcxTyv4mQ7uMSQdAiAj+4dzMpJWdSzQ+qGHlHMIBvVHLkqB
y2VdEyF7DPCZewIhAI7GOI/6LDIFOvtPo6Bj2nNmyQ1HU6k/LRtNIXi4c9NJAiAr
rrxx26itVhJmcvoUhOjwuzSlP2bE5VHAvkGB352YBg==
MIICXAIBAAKBgQC8A6FGHDiWCSREAXCq6yBfNVr0xCVG2CzvktFNRpue+RXrGs/2
a6ySEJQb3IYquw7HlJgu6fg3WIWhOmHCjfpG0PrL4CRwbqQ2LaPPXhJErWYejcD8
Di00cF3677+G10KMZk9RXbmHtuBFZT98wxg8j+ZsBMqGM1+7yrWUvynswQIDAQAB
AoGAJMCk5vqfSRzyXOTXLGIYCuR4Kj6pdsbNSeuuRGfYBeR1F2c/XdFAg7D/8s5R
38p/Ih52/Ty5S8BfJtwtvgVY9ecf/JlU/rl/QzhG8/8KC0NG7KsyXklbQ7gJT8UT
Ojmw5QpMk+rKv17ipDVkQQmPaj+gJXYNAHqImke5mm/K/h0CQQDciPmviQ+DOhOq
2ZBqUfH8oXHgFmp7/6pXw80DpMIxgV3CwkxxIVx6a8lVH9bT/AFySJ6vXq4zTuV9
6QmZcZzDAkEA2j/UXJPIs1fQ8z/6sONOkU/BjtoePFIWJlRxdN35cZjXnBraX5UR
fFHkePv4YwqmXNqrBOvSu+w2WdSDci+IKwJAcsPRc/jWmsrJW1q3Ha0hSf/WG/Bu
X7MPuXaKpP/DkzGoUmb8ks7yqj6XWnYkPNLjCc8izU5vRwIiyWBRf4mxMwJBAILa
NDvRS0rjwt6lJGv7zPZoqDc65VfrK2aNyHx2PgFyzwrEOtuF57bu7pnvEIxpLTeM
z26i6XVMeYXAWZMTloMCQBbpGgEERQpeUknLBqUHhg/wXF6+lFA+vEGnkY+Dwab2
KCXFGd+SQ5GdUcEMe9isUH6DYj/6/yCDoFrXXmpQb+M=
-----END RSA PRIVATE KEY-----
`),
"pkcs8": []byte(`-----BEGIN PRIVATE KEY-----
MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQCitzS2KiRQTccf
VApb0mbPpo1lt29JjeLBYAehXHWfQ+w8sXpd8e04n/020spx1R94yg+v0NjXyh2R
NFXNBYdhNei33VJxUeKNlExaecvW2yxfuZqka+ZxT1aI8zrAsjh3Rwc6wayAJS4R
wZuzlDv4jZitWqwD+mb/22Zwq/WSs4YX5dUHDklfdWSVnoBfue8K/00n8f5yMTdJ
vFF0qAJwf9spPEHla0lYcozJk64CO5lRkqfLor4UnsXXOiA7aRIoaUSKa+rlhiqt
1EMGYiBjblPt4SwMelGGU2UfywPb4d85gpQ/s8SBARbpPxNVs2IbHDMwj70P3uZc
74M3c4VJAgMBAAECggEAFIzY3mziGzZHgMBncoNXMsCRORh6uKpvygZr0EhSHqRA
cMXlc3n7gNxL6aGjqc7F48Z5RrY0vMQtCcq3T2Z0W6WoV5hfMiqqV0E0h3S8ds1F
hG13h26NMyBXCILXl8Cqev4Afr45IBISCHIQTRTaoiCX+MTr1rDIU2YNQQumvzkz
fMw2XiFTFTgxAtJUAgKoTqLtm7/T+az7TKw+Hesgbx7yaJoMh9DWGBh4Y61DnIDA
fcxJboAfxxnFiXvdBVmzo72pCsRXrWOsjW6WxQmCKuXHvyB1FZTmMaEFNCGSJDa6
U+OCzA3m65loAZAE7ffFHhYgssz/h9TBaOjKO0BX1QKBgQDZiCBvu+bFh9pEodcS
VxaI+ATlsYcmGdLtnZw5pxuEdr60iNWhpEcV6lGkbdiv5aL43QaGFDLagqeHI77b
+ITFbPPdCiYNaqlk6wyiXv4pdN7V683EDmGWSQlPeC9IhUilt2c+fChK2EB/XlkO
q8c3Vk1MsC6JOxDXNgJxylNpswKBgQC/fYBTb9iD+uM2n3SzJlct/ZlPaONKnNDR
pbTOdxBFHsu2VkfY858tfnEPkmSRX0yKmjHni6e8/qIzfzLwWBY4NmxhNZE5v+qJ
qZF26ULFdrZB4oWXAOliy/1S473OpQnp2MZp2asd0LPcg/BNaMuQrz44hxHb76R7
qWD0ebIfEwKBgQCRCIiP1pjbVGN7ZOgPS080DSC+wClahtcyI+ZYLglTvRQTLDQ7
LFtUykCav748MIADKuJBnM/3DiuCF5wV71EejDDfS/fo9BdyuKBY1brhixFTUX+E
Ww5Hc/SoLnpgALVZ/7jvWTpIBHykLxRziqYtR/YLzl+IkX/97P2ePoZ0rwKBgHNC
/7M5Z4JJyepfIMeVFHTCaT27TNTkf20x6Rs937U7TDN8y9JzEiU4LqXI4HAAhPoI
xnExRs4kF04YCnlRDE7Zs3Lv43J3ap1iTATfcymYwyv1RaQXEGQ/lUQHgYCZJtZz
fTrJoo5XyWu6nzJ5Gc8FLNaptr5ECSXGVm3Rsr2xAoGBAJWqEEQS/ejhO05QcPqh
y4cUdLr0269ILVsvic4Ot6zgfPIntXAK6IsHGKcg57kYm6W9k1CmmlA4ENGryJnR
vxyyqA9eyTFc1CQNuc2frKFA9It49JzjXahKc0aDHEHmTR787Tmk1LbuT0/gm9kA
L4INU6g+WqF0fatJxd+IJPrp
-----END PRIVATE KEY-----
`),
"ed25519": []byte(`-----BEGIN OPENSSH PRIVATE KEY-----
b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAMwAAAAtzc2gtZW
QyNTUxOQAAACA+3f7hS7g5UWwXOGVTrMfhmxyrjqz7Sxxbx7I1j8DvvwAAAJhAFfkOQBX5
DgAAAAtzc2gtZWQyNTUxOQAAACA+3f7hS7g5UWwXOGVTrMfhmxyrjqz7Sxxbx7I1j8Dvvw
AAAEAaYmXltfW6nhRo3iWGglRB48lYq0z0Q3I3KyrdutEr6j7d/uFLuDlRbBc4ZVOsx+Gb
HKuOrPtLHFvHsjWPwO+/AAAAE2dhcnRvbm1AZ2FydG9ubS14cHMBAg==
-----END OPENSSH PRIVATE KEY-----
`),
"rsa-openssh-format": []byte(`-----BEGIN OPENSSH PRIVATE KEY-----
b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAlwAAAAdzc2gtcn
NhAAAAAwEAAQAAAIEAwa48yfWFi3uIdqzuf9X7C2Zxfea/Iaaw0zIwHudpF8U92WVIiC5l
oEuW1+OaVi3UWfIEjWMV1tHGysrHOwtwc34BPCJqJknUQO/KtDTBTJ4Pryhw1bWPC999Lz
a+yrCTdNQYBzoROXKExZgPFh9pTMi5wqpHDuOQ2qZFIEI3lT0AAAIQWL0H31i9B98AAAAH
c3NoLXJzYQAAAIEAwa48yfWFi3uIdqzuf9X7C2Zxfea/Iaaw0zIwHudpF8U92WVIiC5loE
uW1+OaVi3UWfIEjWMV1tHGysrHOwtwc34BPCJqJknUQO/KtDTBTJ4Pryhw1bWPC999Lza+
yrCTdNQYBzoROXKExZgPFh9pTMi5wqpHDuOQ2qZFIEI3lT0AAAADAQABAAAAgCThyTGsT4
IARDxVMhWl6eiB2ZrgFgWSeJm/NOqtppWgOebsIqPMMg4UVuVFsl422/lE3RkPhVkjGXgE
pWvZAdCnmLmApK8wK12vF334lZhZT7t3Z9EzJps88PWEHo7kguf285HcnUM7FlFeissJdk
kXly34y7/3X/a6Tclm+iABAAAAQE0xR/KxZ39slwfMv64Rz7WKk1PPskaryI29aHE3mKHk
pY2QA+P3QlrKxT/VWUMjHUbNNdYfJm48xu0SGNMRdKMAAABBAORh2NP/06JUV3J9W/2Hju
X1ViJuqqcQnJPVzpgSL826EC2xwOECTqoY8uvFpUdD7CtpksIxNVqRIhuNOlz0lqEAAABB
ANkaHTTaPojClO0dKJ/Zjs7pWOCGliebBYprQ/Y4r9QLBkC/XaWMS26gFIrjgC7D2Rv+rZ
wSD0v0RcmkITP1ZR0AAAAYcHF1ZXJuYUBMdWNreUh5ZHJvLmxvY2FsAQID
-----END OPENSSH PRIVATE KEY-----`),
"user": []byte(`-----BEGIN EC PRIVATE KEY-----
MHcCAQEEILYCAeq8f7V4vSSypRw7pxy8yz3V5W4qg8kSC3zJhqpQoAoGCCqGSM49
AwEHoUQDQgAEYcO2xNKiRUYOLEHM7VYAp57HNyKbOdYtHD83Z4hzNPVC4tM5mdGD
PLL8IEwvYu2wq+lpXfGQnNMbzYf9gspG0w==
-----END EC PRIVATE KEY-----
`),
"ca": []byte(`-----BEGIN RSA PRIVATE KEY-----
MIIEpAIBAAKCAQEAvg9dQ9IRG59lYJb+GESfKWTch4yBpr7Ydw1jkK6vvtrx9jLo
5hkA8X6+ElRPRqTAZSlN5cBm6YCAcQIOsmXDUn6Oj1lVPQAoOjTBTvsjM3NjGhvv
52kHTY0nsMsBeY9q5DTtlzmlYkVUq2a6Htgf2mNi01dIw5fJ7uTTo8EbNf7O0i3u
c9a8P19HaZl5NKiWN4EIZkfB2WdXYRJCVBsGgQj3dE/GrEmH9QINq1A+GkNvK96u
vZm8H1jjmuqzHplWa7lFeXcx8FTVTbVb/iJrZ2Lc/JvIPitKZWhqbR59yrGjpwEp
Id7bo4WhO5L3OB0fSIJYvfu+o4WYnt4f3UzecwIDAQABAoIBABRD9yHgKErVuC2Q
bA+SYZY8VvdtF/X7q4EmQFORDNRA7EPgMc03JU6awRGbQ8i4kHs46EFzPoXvWcKz
AXYsO6N0Myc900Tp22A5d9NAHATEbPC/wdje7hRq1KyZONMJY9BphFv3nZbY5apR
Dc90JBFZP5RhXjTc3n9GjvqLAKfFEKVmPRCvqxCOZunw6XR+SgIQLJo36nsIsbhW
QUXIVaCI6cXMN8bRPm8EITdBNZu06Fpu4ZHm6VaxlXN9smERCDkgBSNXNWHKxmmA
c3Glo2DByUr2/JFBOrLEe9fkYgr24KNCQkHVcSaFxEcZvTggr7StjKISVHlCNEaB
7Q+kPoECgYEA3zE9FmvFGoQCU4g4Nl3dpQHs6kaAW8vJlrmq3xsireIuaJoa2HMe
wYdIvgCnK9DIjyxd5OWnE4jXtAEYPsyGD32B5rSLQrRO96lgb3f4bESCLUb3Bsn/
sdgeE3p1xZMA0B59htqCrvVgN9k8WxyevBxYl3/gSBm/p8OVH1RTW/ECgYEA2f9Z
95OLj0KQHQtxQXf+I3VjhCw3LkLW39QZOXVI0QrCJfqqP7uxsJXH9NYX0l0GFTcR
kRrlyoaSU1EGQosZh+n1MvplGBTkTSV47/bPsTzFpgK2NfEZuFm9RoWgltS+nYeH
Y2k4mnAN3PhReCMwuprmJz8GRLsO3Cs2s2YylKMCgYEA2UX+uO/q7jgqZ5UJW+ue
1H5+W0aMuFA3i7JtZEnvRaUVFqFGlwXin/WJ2+WY1++k/rPrJ+Rk9IBXtBUIvEGw
FC5TIfsKQsJyyWgqx/jbbtJ2g4s8+W/1qfTAuqeRNOg5d2DnRDs90wJuS4//0JaY
9HkHyVwkQyxFxhSA/AHEMJECgYA2MvyFR1O9bIk0D3I7GsA+xKLXa77Ua53MzIjw
9i4CezBGDQpjCiFli/fI8am+jY5DnAtsDknvjoG24UAzLy5L0mk6IXMdB6SzYYut
7ak5oahqW+Y9hxIj+XvLmtGQbphtxhJtLu35x75KoBpxSh6FZpmuTEccs31AVCYn
eFM/DQKBgQDOPUwbLKqVi6ddFGgrV9MrWw+SWsDa43bPuyvYppMM3oqesvyaX1Dt
qDvN7owaNxNM4OnfKcZr91z8YPVCFo4RbBif3DXRzjNNBlxEjHBtuMOikwvsmucN
vIrbeEpjTiUMTEAr6PoTiVHjsfS8WAM6MDlF5M+2PNswDsBpa2yLgA==
-----END RSA PRIVATE KEY-----
`),
}
var SSHCertificates = map[string][]byte{
// The following are corresponding certificates for the private keys above, signed by the CA key
// Generated by the following commands:
//
// 1. Assumes "rsa" key above in file named "rsa", write out the public key to "rsa.pub":
// ssh-keygen -y -f rsa > rsa.pub
//
// 2. Assumes "ca" key above in file named "ca", sign a cert for "rsa.pub":
// ssh-keygen -s ca -h -n host.example.com -V +500w -I host.example.com-key rsa.pub
"rsa": []byte(`ssh-rsa-cert-v01@openssh.com AAAAHHNzaC1yc2EtY2VydC12MDFAb3BlbnNzaC5jb20AAAAgLjYqmmuTSEmjVhSfLQphBSTJMLwIZhRgmpn8FHKLiEIAAAADAQABAAAAgQC8A6FGHDiWCSREAXCq6yBfNVr0xCVG2CzvktFNRpue+RXrGs/2a6ySEJQb3IYquw7HlJgu6fg3WIWhOmHCjfpG0PrL4CRwbqQ2LaPPXhJErWYejcD8Di00cF3677+G10KMZk9RXbmHtuBFZT98wxg8j+ZsBMqGM1+7yrWUvynswQAAAAAAAAAAAAAAAgAAABRob3N0LmV4YW1wbGUuY29tLWtleQAAABQAAAAQaG9zdC5leGFtcGxlLmNvbQAAAABZHN8UAAAAAGsjIYUAAAAAAAAAAAAAAAAAAAEXAAAAB3NzaC1yc2EAAAADAQABAAABAQC+D11D0hEbn2Vglv4YRJ8pZNyHjIGmvth3DWOQrq++2vH2MujmGQDxfr4SVE9GpMBlKU3lwGbpgIBxAg6yZcNSfo6PWVU9ACg6NMFO+yMzc2MaG+/naQdNjSewywF5j2rkNO2XOaViRVSrZroe2B/aY2LTV0jDl8nu5NOjwRs1/s7SLe5z1rw/X0dpmXk0qJY3gQhmR8HZZ1dhEkJUGwaBCPd0T8asSYf1Ag2rUD4aQ28r3q69mbwfWOOa6rMemVZruUV5dzHwVNVNtVv+ImtnYtz8m8g+K0plaGptHn3KsaOnASkh3tujhaE7kvc4HR9Igli9+76jhZie3h/dTN5zAAABDwAAAAdzc2gtcnNhAAABALeDea+60H6xJGhktAyosHaSY7AYzLocaqd8hJQjEIDifBwzoTlnBmcK9CxGhKuaoJFThdCLdaevCeOSuquh8HTkf+2ebZZc/G5T+2thPvPqmcuEcmMosWo+SIjYhbP3S6KD49aLC1X0kz8IBQeauFvURhkZ5ZjhA1L4aQYt9NjL73nqOl8PplRui+Ov5w8b4ldul4zOvYAFrzfcP6wnnXk3c1Zzwwf5wynD5jakO8GpYKBuhM7Z4crzkKSQjU3hla7xqgfomC5Gz4XbR2TNjcQiRrJQ0UlKtX3X3ObRCEhuvG0Kzjklhv+Ddw6txrhKjMjiSi/Yyius/AE8TmC1p4U= host.example.com
`),
}
var PEMEncryptedKeys = []struct {
Name string
EncryptionKey string
PEMBytes []byte
}{
0: {
Name: "rsa-encrypted",
EncryptionKey: "r54-G0pher_t3st$",
PEMBytes: []byte(`-----BEGIN RSA PRIVATE KEY-----
Proc-Type: 4,ENCRYPTED
DEK-Info: AES-128-CBC,3E1714DE130BC5E81327F36564B05462
MqW88sud4fnWk/Jk3fkjh7ydu51ZkHLN5qlQgA4SkAXORPPMj2XvqZOv1v2LOgUV
dUevUn8PZK7a9zbZg4QShUSzwE5k6wdB7XKPyBgI39mJ79GBd2U4W3h6KT6jIdWA
goQpluxkrzr2/X602IaxLEre97FT9mpKC6zxKCLvyFWVIP9n3OSFS47cTTXyFr+l
7PdRhe60nn6jSBgUNk/Q1lAvEQ9fufdPwDYY93F1wyJ6lOr0F1+mzRrMbH67NyKs
rG8J1Fa7cIIre7ueKIAXTIne7OAWqpU9UDgQatDtZTbvA7ciqGsSFgiwwW13N+Rr
hN8MkODKs9cjtONxSKi05s206A3NDU6STtZ3KuPDjFE1gMJODotOuqSM+cxKfyFq
wxpk/CHYCDdMAVBSwxb/vraOHamylL4uCHpJdBHypzf2HABt+lS8Su23uAmL87DR
yvyCS/lmpuNTndef6qHPRkoW2EV3xqD3ovosGf7kgwGJUk2ZpCLVteqmYehKlZDK
r/Jy+J26ooI2jIg9bjvD1PZq+Mv+2dQ1RlDrPG3PB+rEixw6vBaL9x3jatCd4ej7
XG7lb3qO9xFpLsx89tkEcvpGR+broSpUJ6Mu5LBCVmrvqHjvnDhrZVz1brMiQtU9
iMZbgXqDLXHd6ERWygk7OTU03u+l1gs+KGMfmS0h0ZYw6KGVLgMnsoxqd6cFSKNB
8Ohk9ZTZGCiovlXBUepyu8wKat1k8YlHSfIHoRUJRhhcd7DrmojC+bcbMIZBU22T
Pl2ftVRGtcQY23lYd0NNKfebF7ncjuLWQGy+vZW+7cgfI6wPIbfYfP6g7QAutk6W
KQx0AoX5woZ6cNxtpIrymaVjSMRRBkKQrJKmRp3pC/lul5E5P2cueMs1fj4OHTbJ
lAUv88ywr+R+mRgYQlFW/XQ653f6DT4t6+njfO9oBcPrQDASZel3LjXLpjjYG/N5
+BWnVexuJX9ika8HJiFl55oqaKb+WknfNhk5cPY+x7SDV9ywQeMiDZpr0ffeYAEP
LlwwiWRDYpO+uwXHSFF3+JjWwjhs8m8g99iFb7U93yKgBB12dCEPPa2ZeH9wUHMJ
sreYhNuq6f4iWWSXpzN45inQqtTi8jrJhuNLTT543ErW7DtntBO2rWMhff3aiXbn
Uy3qzZM1nPbuCGuBmP9L2dJ3Z5ifDWB4JmOyWY4swTZGt9AVmUxMIKdZpRONx8vz
I9u9nbVPGZBcou50Pa0qTLbkWsSL94MNXrARBxzhHC9Zs6XNEtwN7mOuii7uMkVc
adrxgknBH1J1N+NX/eTKzUwJuPvDtA+Z5ILWNN9wpZT/7ed8zEnKHPNUexyeT5g3
uw9z9jH7ffGxFYlx87oiVPHGOrCXYZYW5uoZE31SCBkbtNuffNRJRKIFeipmpJ3P
7bpAG+kGHMelQH6b+5K1Qgsv4tpuSyKeTKpPFH9Av5nN4P1ZBm9N80tzbNWqjSJm
S7rYdHnuNEVnUGnRmEUMmVuYZnNBEVN/fP2m2SEwXcP3Uh7TiYlcWw10ygaGmOr7
MvMLGkYgQ4Utwnd98mtqa0jr0hK2TcOSFir3AqVvXN3XJj4cVULkrXe4Im1laWgp
-----END RSA PRIVATE KEY-----
`),
},
1: {
Name: "dsa-encrypted",
EncryptionKey: "qG0pher-dsa_t3st$",
PEMBytes: []byte(`-----BEGIN DSA PRIVATE KEY-----
Proc-Type: 4,ENCRYPTED
DEK-Info: AES-128-CBC,7CE7A6E4A647DC01AF860210B15ADE3E
hvnBpI99Hceq/55pYRdOzBLntIEis02JFNXuLEydWL+RJBFDn7tA+vXec0ERJd6J
G8JXlSOAhmC2H4uK3q2xR8/Y3yL95n6OIcjvCBiLsV+o3jj1MYJmErxP6zRtq4w3
JjIjGHWmaYFSxPKQ6e8fs74HEqaeMV9ONUoTtB+aISmgaBL15Fcoayg245dkBvVl
h5Kqspe7yvOBmzA3zjRuxmSCqKJmasXM7mqs3vIrMxZE3XPo1/fWKcPuExgpVQoT
HkJZEoIEIIPnPMwT2uYbFJSGgPJVMDT84xz7yvjCdhLmqrsXgs5Qw7Pw0i0c0BUJ
b7fDJ2UhdiwSckWGmIhTLlJZzr8K+JpjCDlP+REYBI5meB7kosBnlvCEHdw2EJkH
0QDc/2F4xlVrHOLbPRFyu1Oi2Gvbeoo9EsM/DThpd1hKAlb0sF5Y0y0d+owv0PnE
R/4X3HWfIdOHsDUvJ8xVWZ4BZk9Zk9qol045DcFCehpr/3hslCrKSZHakLt9GI58
vVQJ4L0aYp5nloLfzhViZtKJXRLkySMKdzYkIlNmW1oVGl7tce5UCNI8Nok4j6yn
IiHM7GBn+0nJoKTXsOGMIBe3ulKlKVxLjEuk9yivh/8=
-----END DSA PRIVATE KEY-----
`),
},
}

View file

@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// IMPLEMENTOR NOTE: To avoid a package loop, this file is in three places:
// IMPLEMENTATION NOTE: To avoid a package loop, this file is in three places:
// ssh/, ssh/agent, and ssh/test/. It should be kept in sync across all three
// instances.

View file

@ -6,12 +6,20 @@ package ssh
import (
"bufio"
"bytes"
"errors"
"io"
"log"
)
// debugTransport if set, will print packet types as they go over the
// wire. No message decoding is done, to minimize the impact on timing.
const debugTransport = false
const (
gcmCipherID = "aes128-gcm@openssh.com"
gcmCipherID = "aes128-gcm@openssh.com"
aes128cbcID = "aes128-cbc"
tripledescbcID = "3des-cbc"
)
// packetConn represents a transport that implements packet based
@ -20,7 +28,9 @@ type packetConn interface {
// Encrypt and send a packet of data to the remote peer.
writePacket(packet []byte) error
// Read a packet from the connection
// Read a packet from the connection. The read is blocking,
// i.e. if error is nil, then the returned byte slice is
// always non-empty.
readPacket() ([]byte, error)
// Close closes the write-side of the connection.
@ -36,21 +46,8 @@ type transport struct {
bufReader *bufio.Reader
bufWriter *bufio.Writer
rand io.Reader
isClient bool
io.Closer
// Initial H used for the session ID. Once assigned this does
// not change, even during subsequent key exchanges.
sessionID []byte
}
// getSessionID returns the ID of the SSH connection. The return value
// should not be modified.
func (t *transport) getSessionID() []byte {
if t.sessionID == nil {
panic("session ID not set yet")
}
return t.sessionID
}
// packetCipher represents a combination of SSH encryption/MAC
@ -80,30 +77,53 @@ type connectionState struct {
// both directions are triggered by reading and writing a msgNewKey packet
// respectively.
func (t *transport) prepareKeyChange(algs *algorithms, kexResult *kexResult) error {
if t.sessionID == nil {
t.sessionID = kexResult.H
}
kexResult.SessionID = t.sessionID
if ciph, err := newPacketCipher(t.reader.dir, algs.r, kexResult); err != nil {
ciph, err := newPacketCipher(t.reader.dir, algs.r, kexResult)
if err != nil {
return err
} else {
t.reader.pendingKeyChange <- ciph
}
t.reader.pendingKeyChange <- ciph
if ciph, err := newPacketCipher(t.writer.dir, algs.w, kexResult); err != nil {
ciph, err = newPacketCipher(t.writer.dir, algs.w, kexResult)
if err != nil {
return err
} else {
t.writer.pendingKeyChange <- ciph
}
t.writer.pendingKeyChange <- ciph
return nil
}
func (t *transport) printPacket(p []byte, write bool) {
if len(p) == 0 {
return
}
who := "server"
if t.isClient {
who = "client"
}
what := "read"
if write {
what = "write"
}
log.Println(what, who, p[0])
}
// Read and decrypt next packet.
func (t *transport) readPacket() ([]byte, error) {
return t.reader.readPacket(t.bufReader)
func (t *transport) readPacket() (p []byte, err error) {
for {
p, err = t.reader.readPacket(t.bufReader)
if err != nil {
break
}
if len(p) == 0 || (p[0] != msgIgnore && p[0] != msgDebug) {
break
}
}
if debugTransport {
t.printPacket(p, false)
}
return p, err
}
func (s *connectionState) readPacket(r *bufio.Reader) ([]byte, error) {
@ -113,12 +133,27 @@ func (s *connectionState) readPacket(r *bufio.Reader) ([]byte, error) {
err = errors.New("ssh: zero length packet")
}
if len(packet) > 0 && packet[0] == msgNewKeys {
select {
case cipher := <-s.pendingKeyChange:
s.packetCipher = cipher
default:
return nil, errors.New("ssh: got bogus newkeys message.")
if len(packet) > 0 {
switch packet[0] {
case msgNewKeys:
select {
case cipher := <-s.pendingKeyChange:
s.packetCipher = cipher
default:
return nil, errors.New("ssh: got bogus newkeys message")
}
case msgDisconnect:
// Transform a disconnect message into an
// error. Since this is lowest level at which
// we interpret message types, doing it here
// ensures that we don't have to handle it
// elsewhere.
var msg disconnectMsg
if err := Unmarshal(packet, &msg); err != nil {
return nil, err
}
return nil, &msg
}
}
@ -131,6 +166,9 @@ func (s *connectionState) readPacket(r *bufio.Reader) ([]byte, error) {
}
func (t *transport) writePacket(packet []byte) error {
if debugTransport {
t.printPacket(packet, true)
}
return t.writer.writePacket(t.bufWriter, t.rand, packet)
}
@ -171,6 +209,8 @@ func newTransport(rwc io.ReadWriteCloser, rand io.Reader, isClient bool) *transp
},
Closer: rwc,
}
t.isClient = isClient
if isClient {
t.reader.dir = serverKeys
t.writer.dir = clientKeys
@ -193,43 +233,22 @@ var (
clientKeys = direction{[]byte{'A'}, []byte{'C'}, []byte{'E'}}
)
// generateKeys generates key material for IV, MAC and encryption.
func generateKeys(d direction, algs directionAlgorithms, kex *kexResult) (iv, key, macKey []byte) {
cipherMode := cipherModes[algs.Cipher]
macMode := macModes[algs.MAC]
iv = make([]byte, cipherMode.ivSize)
key = make([]byte, cipherMode.keySize)
macKey = make([]byte, macMode.keySize)
generateKeyMaterial(iv, d.ivTag, kex)
generateKeyMaterial(key, d.keyTag, kex)
generateKeyMaterial(macKey, d.macKeyTag, kex)
return
}
// setupKeys sets the cipher and MAC keys from kex.K, kex.H and sessionId, as
// described in RFC 4253, section 6.4. direction should either be serverKeys
// (to setup server->client keys) or clientKeys (for client->server keys).
func newPacketCipher(d direction, algs directionAlgorithms, kex *kexResult) (packetCipher, error) {
iv, key, macKey := generateKeys(d, algs, kex)
cipherMode := cipherModes[algs.Cipher]
macMode := macModes[algs.MAC]
if algs.Cipher == gcmCipherID {
return newGCMCipher(iv, key, macKey)
}
iv := make([]byte, cipherMode.ivSize)
key := make([]byte, cipherMode.keySize)
macKey := make([]byte, macMode.keySize)
c := &streamPacketCipher{
mac: macModes[algs.MAC].new(macKey),
}
c.macResult = make([]byte, c.mac.Size())
generateKeyMaterial(iv, d.ivTag, kex)
generateKeyMaterial(key, d.keyTag, kex)
generateKeyMaterial(macKey, d.macKeyTag, kex)
var err error
c.cipher, err = cipherModes[algs.Cipher].createStream(key, iv)
if err != nil {
return nil, err
}
return c, nil
return cipherModes[algs.Cipher].create(key, iv, macKey, algs)
}
// generateKeyMaterial fills out with key material generated from tag, K, H
@ -294,7 +313,7 @@ func readVersion(r io.Reader) ([]byte, error) {
var ok bool
var buf [1]byte
for len(versionString) < maxVersionStringBytes {
for length := 0; length < maxVersionStringBytes; length++ {
_, err := io.ReadFull(r, buf[:])
if err != nil {
return nil, err
@ -302,6 +321,13 @@ func readVersion(r io.Reader) ([]byte, error) {
// The RFC says that the version should be terminated with \r\n
// but several SSH servers actually only send a \n.
if buf[0] == '\n' {
if !bytes.HasPrefix(versionString, []byte("SSH-")) {
// RFC 4253 says we need to ignore all version string lines
// except the one containing the SSH version (provided that
// all the lines do not exceed 255 bytes in total).
versionString = versionString[:0]
continue
}
ok = true
break
}

View file

@ -13,11 +13,13 @@ import (
)
func TestReadVersion(t *testing.T) {
longversion := strings.Repeat("SSH-2.0-bla", 50)[:253]
longVersion := strings.Repeat("SSH-2.0-bla", 50)[:253]
multiLineVersion := strings.Repeat("ignored\r\n", 20) + "SSH-2.0-bla\r\n"
cases := map[string]string{
"SSH-2.0-bla\r\n": "SSH-2.0-bla",
"SSH-2.0-bla\n": "SSH-2.0-bla",
longversion + "\r\n": longversion,
multiLineVersion: "SSH-2.0-bla",
longVersion + "\r\n": longVersion,
}
for in, want := range cases {
@ -33,9 +35,11 @@ func TestReadVersion(t *testing.T) {
}
func TestReadVersionError(t *testing.T) {
longversion := strings.Repeat("SSH-2.0-bla", 50)[:253]
longVersion := strings.Repeat("SSH-2.0-bla", 50)[:253]
multiLineVersion := strings.Repeat("ignored\r\n", 50) + "SSH-2.0-bla\r\n"
cases := []string{
longversion + "too-long\r\n",
longVersion + "too-long\r\n",
multiLineVersion,
}
for _, in := range cases {
if _, err := readVersion(bytes.NewBufferString(in)); err == nil {
@ -60,7 +64,7 @@ func TestExchangeVersionsBasic(t *testing.T) {
func TestExchangeVersions(t *testing.T) {
cases := []string{
"not\x000allowed",
"not allowed\n",
"not allowed\x01\r\n",
}
for _, c := range cases {
buf := bytes.NewBufferString("SSH-2.0-bla\r\n")