diff --git a/sql/mysql_db/auth.go b/sql/mysql_db/auth.go index dead3fb36e..b88e6a87c9 100644 --- a/sql/mysql_db/auth.go +++ b/sql/mysql_db/auth.go @@ -17,8 +17,12 @@ package mysql_db import ( "bytes" "crypto/sha1" + "crypto/tls" + "crypto/x509/pkix" "encoding/hex" + "fmt" "net" + "strings" "github.com/dolthub/vitess/go/mysql" "github.com/sirupsen/logrus" @@ -107,7 +111,7 @@ var _ mysql.CachingStorage = (*noopCachingStorage)(nil) // // This implementation also handles authentication when a client doesn't send an auth response and // the associated user account does not have a password set. -func (n noopCachingStorage) UserEntryWithCacheHash(_ *mysql.Conn, _ []byte, user string, authResponse []byte, remoteAddr net.Addr) (mysql.Getter, mysql.CacheState, error) { +func (n noopCachingStorage) UserEntryWithCacheHash(conn *mysql.Conn, _ []byte, user string, authResponse []byte, remoteAddr net.Addr) (mysql.Getter, mysql.CacheState, error) { db := n.db // If there is no mysql database of user info, then don't approve or reject, since we can't look at @@ -131,7 +135,12 @@ func (n noopCachingStorage) UserEntryWithCacheHash(_ *mysql.Conn, _ []byte, user userEntry := db.GetUser(rd, user, host, false) if userEntry == nil || userEntry.Locked { - return nil, mysql.AuthRejected, mysql.NewSQLError(mysql.ERAccessDeniedError, mysql.SSAccessDeniedError, "Access denied for user '%v'", user) + return nil, mysql.AuthRejected, newAccessDeniedError(user) + } + + // validate any extra connection security requirements, such as SSL or a client cert + if err = validateConnectionSecurity(userEntry, conn); err != nil { + return nil, mysql.AuthRejected, err } if userEntry.AuthString == "" { @@ -166,7 +175,7 @@ var _ mysql.PlainTextStorage = (*sha2PlainTextStorage)(nil) // UserEntryWithPassword implements the mysql.PlainTextStorage interface. // The auth framework in Vitess also passes in user certificates, but we don't support that feature yet. -func (s sha2PlainTextStorage) UserEntryWithPassword(_ *mysql.Conn, user string, password string, remoteAddr net.Addr) (mysql.Getter, error) { +func (s sha2PlainTextStorage) UserEntryWithPassword(conn *mysql.Conn, user string, password string, remoteAddr net.Addr) (mysql.Getter, error) { db := s.db host, err := extractHostAddress(remoteAddr) @@ -183,7 +192,12 @@ func (s sha2PlainTextStorage) UserEntryWithPassword(_ *mysql.Conn, user string, userEntry := db.GetUser(rd, user, host, false) if userEntry == nil || userEntry.Locked { - return nil, mysql.NewSQLError(mysql.ERAccessDeniedError, mysql.SSAccessDeniedError, "Access denied for user '%v'", user) + return nil, newAccessDeniedError(userEntry.User) + } + + // validate any extra connection security requirements, such as SSL or a client cert + if err = validateConnectionSecurity(userEntry, conn); err != nil { + return nil, err } if len(userEntry.AuthString) > 0 { @@ -202,12 +216,12 @@ func (s sha2PlainTextStorage) UserEntryWithPassword(_ *mysql.Conn, user string, } if userEntry.AuthString != string(authString) { - return nil, mysql.NewSQLError(mysql.ERAccessDeniedError, mysql.SSAccessDeniedError, "Access denied for user '%v'", user) + return nil, newAccessDeniedError(user) } } else if len(password) > 0 { // password is nil or empty, therefore no password is set // a password was given and the account has no password set, therefore access is denied - return nil, mysql.NewSQLError(mysql.ERAccessDeniedError, mysql.SSAccessDeniedError, "Access denied for user '%v'", user) + return nil, newAccessDeniedError(user) } return sql.MysqlConnectionUser{User: userEntry.User, Host: userEntry.Host}, nil @@ -269,8 +283,7 @@ func (f extendedAuthPlainTextStorage) UserEntryWithPassword(conn *mysql.Conn, us "Access denied for user '%v': %v", user, err) } if !authed { - return nil, mysql.NewSQLError(mysql.ERAccessDeniedError, mysql.SSAccessDeniedError, - "Access denied for user '%v'", user) + return nil, newAccessDeniedError(user) } return connUser, nil } @@ -329,7 +342,7 @@ var _ mysql.HashStorage = (*nativePasswordHashStorage)(nil) // UserEntryWithHash implements the mysql.HashStorage interface. This implementation is called by the MySQL // native password auth method to validate a password hash with the user's stored password hash. -func (nphs *nativePasswordHashStorage) UserEntryWithHash(_ *mysql.Conn, salt []byte, user string, authResponse []byte, remoteAddr net.Addr) (mysql.Getter, error) { +func (nphs *nativePasswordHashStorage) UserEntryWithHash(conn *mysql.Conn, salt []byte, user string, authResponse []byte, remoteAddr net.Addr) (mysql.Getter, error) { db := nphs.db host, err := extractHostAddress(remoteAddr) @@ -346,21 +359,94 @@ func (nphs *nativePasswordHashStorage) UserEntryWithHash(_ *mysql.Conn, salt []b userEntry := db.GetUser(rd, user, host, false) if userEntry == nil || userEntry.Locked { - return nil, mysql.NewSQLError(mysql.ERAccessDeniedError, mysql.SSAccessDeniedError, "Access denied for user '%v'", user) + return nil, newAccessDeniedError(user) + } + + // validate any extra connection security requirements, such as SSL or a client cert + if err = validateConnectionSecurity(userEntry, conn); err != nil { + return nil, err } + if len(userEntry.AuthString) > 0 { if !validateMysqlNativePassword(authResponse, salt, userEntry.AuthString) { - return nil, mysql.NewSQLError(mysql.ERAccessDeniedError, mysql.SSAccessDeniedError, "Access denied for user '%v'", user) + return nil, newAccessDeniedError(user) } } else if len(authResponse) > 0 { // password is nil or empty, therefore no password is set // a password was given and the account has no password set, therefore access is denied - return nil, mysql.NewSQLError(mysql.ERAccessDeniedError, mysql.SSAccessDeniedError, "Access denied for user '%v'", user) + return nil, newAccessDeniedError(user) } return sql.MysqlConnectionUser{User: userEntry.User, Host: userEntry.Host}, nil } +// validateConnectionSecurity examines the security properties of |conn| (e.g. TLS, +// selected cipher, X509 client certs) and validates specific connection properties +// based on what |userEntry| has configured. An error is returned if any validation +// issues were detected, otherwise nil is returned. +func validateConnectionSecurity(userEntry *User, conn *mysql.Conn) error { + switch userEntry.SslType { + case "": + // No connection security validation needed + return nil + case "ANY": + // ANY indicates that we need any form of secure socket + if !conn.TLSEnabled() { + return newAccessDeniedError(userEntry.User) + } + case "X509": + // X509 means that a valid X509 client certificate is required + // NOTE: cert validation (e.g. expiration date, CA chain) is handled + // in the Go networking stack, so long as tls.VerifyClientCertIfGiven + // is specified in the TLS configuration for the server. + clientCerts := conn.GetTLSClientCerts() + if len(clientCerts) == 0 { + return newAccessDeniedError(userEntry.User) + } + case "SPECIFIED": + // Specified means that we have additional requirements on either the SSL cipher + // or the X509 cert, so we need to perform additional validation checks. + if !conn.TLSEnabled() { + return newAccessDeniedError(userEntry.User) + } + if userEntry.SslCipher != "" { + tlsConn, ok := conn.Conn.(*tls.Conn) + if !ok { + return newAccessDeniedError(userEntry.User) + } + state := tlsConn.ConnectionState() + cipherSuiteName := tls.CipherSuiteName(state.CipherSuite) + if cipherSuiteName != userEntry.SslCipher { + return newAccessDeniedError(userEntry.User) + } + } + if userEntry.X509Issuer != "" { + if len(conn.GetTLSClientCerts()) == 0 { + return newAccessDeniedError(userEntry.User) + } + clientCert := conn.GetTLSClientCerts()[0] + normalizedIssuer := formatDistinguishedNameForMySQL(clientCert.Issuer) + if normalizedIssuer != userEntry.X509Issuer { + return newAccessDeniedError(userEntry.User) + } + } + if userEntry.X509Subject != "" { + if len(conn.GetTLSClientCerts()) == 0 { + return newAccessDeniedError(userEntry.User) + } + clientCert := conn.GetTLSClientCerts()[0] + normalizedSubject := formatDistinguishedNameForMySQL(clientCert.Subject) + if normalizedSubject != userEntry.X509Subject { + return newAccessDeniedError(userEntry.User) + } + } + default: + return fmt.Errorf("unsupported ssl_type: %v", userEntry.SslType) + } + + return nil +} + // userValidator implements the mysql.UserValidator interface. It looks up a user and host from the // associated mysql database (|db|) and validates that a user entry exists and that it is configured // for the specified authentication plugin (|authMethod|). @@ -408,7 +494,13 @@ func (uv *userValidator) HandleUser(user string, remoteAddr net.Addr) bool { } userEntry := db.GetUser(rd, user, host, false) - return userEntry != nil && userEntry.Plugin == string(uv.authMethod) + // If we don't find a matching user, or we find one, but it's for a different auth method, + // then return false to indicate this auth method can't handle that user. + if userEntry == nil || userEntry.Plugin != string(uv.authMethod) { + return false + } + + return true } // extractHostAddress extracts the host address from |addr|, checking to see if it is a unix socket, and if @@ -429,6 +521,30 @@ func extractHostAddress(addr net.Addr) (host string, err error) { return host, nil } +// newAccessDeniedError returns an "access denied" error, including the |userName| trying to authenticate, +// matching MySQL's error message. Note that MySQL tends to return a generic "access denied" error message +// for authentication failures, without leaking more details about why so that attackers can't exploit that +// information to determine how a user is configured for authentication. +func newAccessDeniedError(userName string) error { + return mysql.NewSQLError(mysql.ERAccessDeniedError, mysql.SSAccessDeniedError, "Access denied for user '%v'", userName) +} + +// formatDistinguishedNameForMySQL returns a distinguished name, created from |name|, that matches +// MySQL's formatting style (e.g. "/C=US/ST=Washington/L=Seattle/O=Test CA/CN=MySQL Test CA"). +// By default, Golang's stack uses a different format when converting a pkix.Name to a string. This +// function reverses the order of the elements and uses a "/" prefix for each element, instead of a +// "," in between elements. +func formatDistinguishedNameForMySQL(name pkix.Name) string { + parts := strings.Split(name.String(), ",") + + b := strings.Builder{} + for i := len(parts) - 1; i >= 0; i-- { + b.WriteString("/") + b.WriteString(parts[i]) + } + return b.String() +} + // validateMysqlNativePassword was taken from vitess and validates the password hash for the mysql_native_password // auth protocol. Note that this implementation has diverged slightly from the original code in Vitess. func validateMysqlNativePassword(authResponse, salt []byte, mysqlNativePassword string) bool {