using System; using System.Collections.Generic; using System.Diagnostics; using System.Net.Security; using System.Security.Cryptography; using System.Security.Cryptography.X509Certificates; using System.Text; using System.Threading; using System.Threading.Tasks; using Microsoft.Extensions.Logging; using Npgsql.BackendMessages; using Npgsql.Util; using static Npgsql.Util.Statics; namespace Npgsql.Internal; partial class NpgsqlConnector { async Task Authenticate(string username, NpgsqlTimeout timeout, bool async, CancellationToken cancellationToken) { var requiredAuthModes = Settings.RequireAuthModes; if (requiredAuthModes == default) requiredAuthModes = NpgsqlConnectionStringBuilder.ParseAuthMode(PostgresEnvironment.RequireAuth); var authenticated = false; while (true) { timeout.CheckAndApply(this); var msg = ExpectAny(await ReadMessage(async).ConfigureAwait(false), this); switch (msg.AuthRequestType) { case AuthenticationRequestType.Ok: // If we didn't complete authentication, check whether it's allowed if (!authenticated) { // User requested GSS authentication, but server said that no auth is required // If and only if our connection is gss encrypted, we consider us already authenticated if (requiredAuthModes.HasFlag(RequireAuthMode.GSS) && IsGssEncrypted) return; ThrowIfNotAllowed(requiredAuthModes, RequireAuthMode.None); } return; case AuthenticationRequestType.CleartextPassword: ThrowIfNotAllowed(requiredAuthModes, RequireAuthMode.Password); await AuthenticateCleartext(username, async, cancellationToken).ConfigureAwait(false); break; case AuthenticationRequestType.MD5Password: ThrowIfNotAllowed(requiredAuthModes, RequireAuthMode.MD5); await AuthenticateMD5(username, ((AuthenticationMD5PasswordMessage)msg).Salt, async, cancellationToken).ConfigureAwait(false); break; case AuthenticationRequestType.SASL: ThrowIfNotAllowed(requiredAuthModes, RequireAuthMode.ScramSHA256); await AuthenticateSASL(((AuthenticationSASLMessage)msg).Mechanisms, username, async, cancellationToken).ConfigureAwait(false); break; case AuthenticationRequestType.GSS: case AuthenticationRequestType.SSPI: ThrowIfNotAllowed(requiredAuthModes, msg.AuthRequestType == AuthenticationRequestType.GSS ? RequireAuthMode.GSS : RequireAuthMode.SSPI); await DataSource.IntegratedSecurityHandler.NegotiateAuthentication(async, this, cancellationToken).ConfigureAwait(false); return; case AuthenticationRequestType.GSSContinue: throw new NpgsqlException("Can't start auth cycle with AuthenticationGSSContinue"); default: throw new NotSupportedException($"Authentication method not supported (Received: {msg.AuthRequestType})"); } authenticated = true; } static void ThrowIfNotAllowed(RequireAuthMode requiredAuthModes, RequireAuthMode requestedAuthMode) { if (!requiredAuthModes.HasFlag(requestedAuthMode)) throw new NpgsqlException($"\"{requestedAuthMode}\" authentication method is not allowed. Allowed methods: {requiredAuthModes}"); } } async Task AuthenticateCleartext(string username, bool async, CancellationToken cancellationToken = default) { var passwd = await GetPassword(username, async, cancellationToken).ConfigureAwait(false); if (string.IsNullOrEmpty(passwd)) throw new NpgsqlException("No password has been provided but the backend requires one (in cleartext)"); var encoded = new byte[Encoding.UTF8.GetByteCount(passwd) + 1]; Encoding.UTF8.GetBytes(passwd, 0, passwd.Length, encoded, 0); await WritePassword(encoded, async, cancellationToken).ConfigureAwait(false); await Flush(async, cancellationToken).ConfigureAwait(false); } async Task AuthenticateSASL(List mechanisms, string username, bool async, CancellationToken cancellationToken) { // At the time of writing PostgreSQL only supports SCRAM-SHA-256 and SCRAM-SHA-256-PLUS var serverSupportsSha256 = mechanisms.Contains("SCRAM-SHA-256"); var allowSha256 = serverSupportsSha256 && Settings.ChannelBinding != ChannelBinding.Require; var serverSupportsSha256Plus = mechanisms.Contains("SCRAM-SHA-256-PLUS"); var allowSha256Plus = serverSupportsSha256Plus && Settings.ChannelBinding != ChannelBinding.Disable; if (!allowSha256 && !allowSha256Plus) { if (serverSupportsSha256 && Settings.ChannelBinding == ChannelBinding.Require) throw new NpgsqlException($"Couldn't connect because {nameof(ChannelBinding)} is set to {nameof(ChannelBinding.Require)} " + "but the server doesn't support SCRAM-SHA-256-PLUS"); if (serverSupportsSha256Plus && Settings.ChannelBinding == ChannelBinding.Disable) throw new NpgsqlException($"Couldn't connect because {nameof(ChannelBinding)} is set to {nameof(ChannelBinding.Disable)} " + "but the server doesn't support SCRAM-SHA-256"); throw new NpgsqlException("No supported SASL mechanism found (only SCRAM-SHA-256 and SCRAM-SHA-256-PLUS are supported for now). " + "Mechanisms received from server: " + string.Join(", ", mechanisms)); } var mechanism = string.Empty; var cbindFlag = string.Empty; var cbind = string.Empty; var successfulBind = false; if (allowSha256Plus) DataSource.TransportSecurityHandler.AuthenticateSASLSha256Plus(this, ref mechanism, ref cbindFlag, ref cbind, ref successfulBind); if (!successfulBind && allowSha256) { mechanism = "SCRAM-SHA-256"; // We can get here if PostgreSQL supports only SCRAM-SHA-256 or there was an error while binding to SCRAM-SHA-256-PLUS // Or the user specifically requested to not use bindings // So, we set 'n' (client does not support binding) if there was an error while binding // or 'y' (client supports but server doesn't) in other case cbindFlag = serverSupportsSha256Plus ? "n" : "y"; cbind = serverSupportsSha256Plus ? "biws" : "eSws"; successfulBind = true; IsScram = true; } if (!successfulBind) { // We can get here if PostgreSQL supports only SCRAM-SHA-256-PLUS but there was an error while binding to it throw new NpgsqlException("Unable to bind to SCRAM-SHA-256-PLUS, check logs for more information"); } var passwd = await GetPassword(username, async, cancellationToken).ConfigureAwait(false); if (string.IsNullOrEmpty(passwd)) throw new NpgsqlException($"No password has been provided but the backend requires one (in SASL/{mechanism})"); // Assumption: the write buffer is big enough to contain all our outgoing messages var clientNonce = GetNonce(); await WriteSASLInitialResponse(mechanism, NpgsqlWriteBuffer.UTF8Encoding.GetBytes($"{cbindFlag},,n=*,r={clientNonce}"), async, cancellationToken).ConfigureAwait(false); await Flush(async, cancellationToken).ConfigureAwait(false); var saslContinueMsg = Expect(await ReadMessage(async).ConfigureAwait(false), this); if (saslContinueMsg.AuthRequestType != AuthenticationRequestType.SASLContinue) throw new NpgsqlException("[SASL] AuthenticationSASLContinue message expected"); var firstServerMsg = AuthenticationSCRAMServerFirstMessage.Load(saslContinueMsg.Payload, ConnectionLogger); if (!firstServerMsg.Nonce.StartsWith(clientNonce, StringComparison.Ordinal)) throw new NpgsqlException("[SCRAM] Malformed SCRAMServerFirst message: server nonce doesn't start with client nonce"); var saltBytes = Convert.FromBase64String(firstServerMsg.Salt); var saltedPassword = Hi(passwd.Normalize(NormalizationForm.FormKC), saltBytes, firstServerMsg.Iteration); var clientKey = HMAC(saltedPassword, "Client Key"); var storedKey = SHA256.HashData(clientKey); var clientFirstMessageBare = $"n=*,r={clientNonce}"; var serverFirstMessage = $"r={firstServerMsg.Nonce},s={firstServerMsg.Salt},i={firstServerMsg.Iteration}"; var clientFinalMessageWithoutProof = $"c={cbind},r={firstServerMsg.Nonce}"; var authMessage = $"{clientFirstMessageBare},{serverFirstMessage},{clientFinalMessageWithoutProof}"; var clientSignature = HMAC(storedKey, authMessage); var clientProofBytes = Xor(clientKey, clientSignature); var clientProof = Convert.ToBase64String(clientProofBytes); var serverKey = HMAC(saltedPassword, "Server Key"); var serverSignature = HMAC(serverKey, authMessage); var messageStr = $"{clientFinalMessageWithoutProof},p={clientProof}"; await WriteSASLResponse(Encoding.UTF8.GetBytes(messageStr), async, cancellationToken).ConfigureAwait(false); await Flush(async, cancellationToken).ConfigureAwait(false); var saslFinalServerMsg = Expect(await ReadMessage(async).ConfigureAwait(false), this); if (saslFinalServerMsg.AuthRequestType != AuthenticationRequestType.SASLFinal) throw new NpgsqlException("[SASL] AuthenticationSASLFinal message expected"); var scramFinalServerMsg = AuthenticationSCRAMServerFinalMessage.Load(saslFinalServerMsg.Payload, ConnectionLogger); if (scramFinalServerMsg.ServerSignature != Convert.ToBase64String(serverSignature)) throw new NpgsqlException("[SCRAM] Unable to verify server signature"); static string GetNonce() { using var rncProvider = RandomNumberGenerator.Create(); var nonceBytes = new byte[18]; rncProvider.GetBytes(nonceBytes); return Convert.ToBase64String(nonceBytes); } } internal void AuthenticateSASLSha256Plus(ref string mechanism, ref string cbindFlag, ref string cbind, ref bool successfulBind) { // The check below is copied from libpq (with commentary) // https://github.com/postgres/postgres/blob/98640f960eb9ed80cf90de3ef5d2e829b785b3eb/src/interfaces/libpq/fe-auth.c#L507-L517 // The server offered SCRAM-SHA-256-PLUS, but the connection // is not SSL-encrypted. That's not sane. Perhaps SSL was // stripped by a proxy? There's no point in continuing, // because the server will reject the connection anyway if we // try authenticate without channel binding even though both // the client and server supported it. The SCRAM exchange // checks for that, to prevent downgrade attacks. if (!IsSslEncrypted) throw new NpgsqlException("Server offered SCRAM-SHA-256-PLUS authentication over a non-SSL connection"); var sslStream = (SslStream)_stream; if (sslStream.RemoteCertificate is null) { ConnectionLogger.LogWarning("Remote certificate null, falling back to SCRAM-SHA-256"); return; } // While SslStream.RemoteCertificate is X509Certificate2, it actually returns X509Certificate2 // But to be on the safe side we'll just create a new instance of it using var remoteCertificate = new X509Certificate2(sslStream.RemoteCertificate); // Checking for hashing algorithms var algorithmName = remoteCertificate.SignatureAlgorithm.FriendlyName; HashAlgorithm? hashAlgorithm = algorithmName switch { not null when algorithmName.StartsWith("sha1", StringComparison.OrdinalIgnoreCase) => SHA256.Create(), not null when algorithmName.StartsWith("md5", StringComparison.OrdinalIgnoreCase) => SHA256.Create(), not null when algorithmName.StartsWith("sha256", StringComparison.OrdinalIgnoreCase) => SHA256.Create(), not null when algorithmName.StartsWith("sha384", StringComparison.OrdinalIgnoreCase) => SHA384.Create(), not null when algorithmName.StartsWith("sha512", StringComparison.OrdinalIgnoreCase) => SHA512.Create(), not null when algorithmName.StartsWith("sha3-256", StringComparison.OrdinalIgnoreCase) => SHA3_256.Create(), not null when algorithmName.StartsWith("sha3-384", StringComparison.OrdinalIgnoreCase) => SHA3_384.Create(), not null when algorithmName.StartsWith("sha3-512", StringComparison.OrdinalIgnoreCase) => SHA3_512.Create(), _ => null }; if (hashAlgorithm is null) { ConnectionLogger.LogWarning( algorithmName is null ? "Signature algorithm was null, falling back to SCRAM-SHA-256" : $"Support for signature algorithm {algorithmName} is not yet implemented, falling back to SCRAM-SHA-256"); return; } using var _ = hashAlgorithm; // RFC 5929 mechanism = "SCRAM-SHA-256-PLUS"; // PostgreSQL only supports tls-server-end-point binding cbindFlag = "p=tls-server-end-point"; // SCRAM-SHA-256-PLUS depends on using ssl stream, so it's fine var cbindFlagBytes = Encoding.UTF8.GetBytes($"{cbindFlag},,"); var certificateHash = hashAlgorithm.ComputeHash(remoteCertificate.GetRawCertData()); var cbindBytes = new byte[cbindFlagBytes.Length + certificateHash.Length]; cbindFlagBytes.CopyTo(cbindBytes, 0); certificateHash.CopyTo(cbindBytes, cbindFlagBytes.Length); cbind = Convert.ToBase64String(cbindBytes); successfulBind = true; IsScramPlus = true; } static byte[] Hi(string str, byte[] salt, int count) => Rfc2898DeriveBytes.Pbkdf2(str, salt, count, HashAlgorithmName.SHA256, 256 / 8); static byte[] Xor(byte[] buffer1, byte[] buffer2) { for (var i = 0; i < buffer1.Length; i++) buffer1[i] ^= buffer2[i]; return buffer1; } static byte[] HMAC(byte[] key, string data) => HMACSHA256.HashData(key, Encoding.UTF8.GetBytes(data)); async Task AuthenticateMD5(string username, byte[] salt, bool async, CancellationToken cancellationToken = default) { var passwd = await GetPassword(username, async, cancellationToken).ConfigureAwait(false); if (string.IsNullOrEmpty(passwd)) throw new NpgsqlException("No password has been provided but the backend requires one (in MD5)"); byte[] result; { // First phase var passwordBytes = NpgsqlWriteBuffer.UTF8Encoding.GetBytes(passwd); var usernameBytes = NpgsqlWriteBuffer.UTF8Encoding.GetBytes(username); var cryptBuf = new byte[passwordBytes.Length + usernameBytes.Length]; passwordBytes.CopyTo(cryptBuf, 0); usernameBytes.CopyTo(cryptBuf, passwordBytes.Length); var sb = new StringBuilder(); var hashResult = MD5.HashData(cryptBuf); foreach (var b in hashResult) sb.Append(b.ToString("x2")); var prehash = sb.ToString(); var prehashbytes = NpgsqlWriteBuffer.UTF8Encoding.GetBytes(prehash); cryptBuf = new byte[prehashbytes.Length + 4]; Array.Copy(salt, 0, cryptBuf, prehashbytes.Length, 4); // 2. prehashbytes.CopyTo(cryptBuf, 0); sb = new StringBuilder("md5"); hashResult = MD5.HashData(cryptBuf); foreach (var b in hashResult) sb.Append(b.ToString("x2")); var resultString = sb.ToString(); result = new byte[Encoding.UTF8.GetByteCount(resultString) + 1]; Encoding.UTF8.GetBytes(resultString, 0, resultString.Length, result, 0); result[^1] = 0; } await WritePassword(result, async, cancellationToken).ConfigureAwait(false); await Flush(async, cancellationToken).ConfigureAwait(false); } internal async ValueTask AuthenticateGSS(bool async, CancellationToken cancellationToken) { var targetName = $"{KerberosServiceName}/{Host}"; var clientOptions = new NegotiateAuthenticationClientOptions { TargetName = targetName }; NegotiateOptionsCallback?.Invoke(clientOptions); using var authContext = new NegotiateAuthentication(clientOptions); var data = authContext.GetOutgoingBlob(ReadOnlySpan.Empty, out var statusCode)!; if (statusCode != NegotiateAuthenticationStatusCode.ContinueNeeded) { // Unable to retrieve credentials or some other issue throw new NpgsqlException($"Unable to authenticate with GSS: received {statusCode} instead of the expected ContinueNeeded"); } await WritePassword(data, 0, data.Length, async, cancellationToken).ConfigureAwait(false); await Flush(async, cancellationToken).ConfigureAwait(false); while (true) { var response = ExpectAny(await ReadMessage(async).ConfigureAwait(false), this); if (response.AuthRequestType == AuthenticationRequestType.Ok) break; if (response is not AuthenticationGSSContinueMessage gssMsg) throw new NpgsqlException($"Received unexpected authentication request message {response.AuthRequestType}"); data = authContext.GetOutgoingBlob(gssMsg.AuthenticationData.AsSpan(), out statusCode); if (statusCode is not NegotiateAuthenticationStatusCode.Completed and not NegotiateAuthenticationStatusCode.ContinueNeeded) throw new NpgsqlException($"Error while authenticating GSS/SSPI: {statusCode}"); // We might get NegotiateAuthenticationStatusCode.Completed but the data will not be null // This can happen if it's the first cycle, in which case we have to send that data to complete handshake (#4888) if (data is null) continue; await WritePassword(data, 0, data.Length, async, cancellationToken).ConfigureAwait(false); await Flush(async, cancellationToken).ConfigureAwait(false); } } async ValueTask GetPassword(string username, bool async, CancellationToken cancellationToken = default) { var password = await DataSource.GetPassword(async, cancellationToken).ConfigureAwait(false); if (password is not null) return password; if (ProvidePasswordCallback is { } passwordCallback) { try { ConnectionLogger.LogTrace($"Taking password from {nameof(ProvidePasswordCallback)} delegate"); password = passwordCallback(Host, Port, Settings.Database!, username); } catch (Exception e) { throw new NpgsqlException($"Obtaining password using {nameof(NpgsqlConnection)}.{nameof(ProvidePasswordCallback)} delegate failed", e); } } password ??= PostgresEnvironment.Password; if (password != null) return password; var passFile = Settings.Passfile ?? PostgresEnvironment.PassFile ?? PostgresEnvironment.PassFileDefault; if (passFile != null) { var matchingEntry = new PgPassFile(passFile!) .GetFirstMatchingEntry(Host, Port, Settings.Database!, username); if (matchingEntry != null) { ConnectionLogger.LogTrace("Taking password from pgpass file"); password = matchingEntry.Password; } } return password; } }