using System;
using System.Buffers;
using System.Collections.Generic;
using System.Data;
using System.Diagnostics;
using System.IO;
using System.Linq;
using System.Net;
using System.Net.Security;
using System.Net.Sockets;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Runtime.ExceptionServices;
using System.Security.Authentication;
using System.Security.Cryptography.X509Certificates;
using System.Text;
using System.Threading;
using System.Threading.Channels;
using System.Threading.Tasks;
using Npgsql.BackendMessages;
using Npgsql.TypeMapping;
using Npgsql.Util;
using static Npgsql.Util.Statics;
using System.Transactions;
using Microsoft.Extensions.Logging;
using Npgsql.Properties;
namespace Npgsql.Internal;
///
/// Represents a connection to a PostgreSQL backend. Unlike NpgsqlConnection objects, which are
/// exposed to users, connectors are internal to Npgsql and are recycled by the connection pool.
///
public sealed partial class NpgsqlConnector : IDisposable
{
#region Fields and Properties
///
/// The physical connection socket to the backend. Assigned by Connect/ConnectAsync once an endpoint
/// accepts the connection; disposed and reset to null by RawOpen if opening fails.
///
Socket _socket = default!;
///
/// The physical connection stream to the backend, without anything on top.
/// Created in RawOpen over _socket with ownsSocket: true, so disposing it also closes the socket.
///
NetworkStream _baseStream = default!;
///
/// The physical connection stream to the backend, layered with an SSL/TLS stream if in secure mode.
/// Starts out as _baseStream and is replaced by the SslStream after a successful TLS handshake.
///
Stream _stream = default!;
///
/// The parsed connection string (taken from the data source in the constructor).
///
public NpgsqlConnectionStringBuilder Settings { get; }
// Optional hook invoked during SSL negotiation (RawOpen) to add client certificates to the collection
// that is presented to the server. Defaults to the data source's callback; may be overridden per connection.
Action? ClientCertificatesCallback { get; }
// When set, replaces Npgsql's built-in server certificate validation during SSL negotiation. RawOpen
// rejects combining it with SslMode VerifyCA/VerifyFull or a RootCertificate setting.
RemoteCertificateValidationCallback? UserCertificateValidationCallback { get; }
#pragma warning disable CS0618 // ProvidePasswordCallback is obsolete
// Legacy password callback copied from the connection; where it is consumed is not visible here.
ProvidePasswordCallback? ProvidePasswordCallback { get; }
#pragma warning restore CS0618
// Encoding used for text sent to / received from the backend, configured in RawOpen from Settings.Encoding.
// For non-UTF8 encodings it is created with exception fallbacks, i.e. invalid characters throw.
public Encoding TextEncoding { get; private set; } = default!;
///
/// Same as TextEncoding, except that it does not throw an exception if an invalid char is
/// encountered (exception fallback), but rather replaces it with a question mark character (replacement
/// fallback).
///
internal Encoding RelaxedTextEncoding { get; private set; } = default!;
///
/// Buffer used for reading data. Created in RawOpen over the raw stream; its Underlying stream is
/// switched to the TLS stream once SSL has been negotiated.
///
internal NpgsqlReadBuffer ReadBuffer { get; private set; } = default!;
///
/// If we read a data row that's bigger than the read buffer, we allocate an oversize buffer.
/// The original (smaller) buffer is stored here, and restored when the connection is reset.
///
NpgsqlReadBuffer? _origReadBuffer;
///
/// Buffer used for writing data. Created in RawOpen alongside ReadBuffer and, like it, re-pointed
/// at the TLS stream after SSL negotiation.
///
internal NpgsqlWriteBuffer WriteBuffer { get; private set; } = default!;
///
/// The secret key of the backend for this connector, used for query cancellation.
/// Taken from the BackendKeyData message during Open, if the server sends one.
///
int _backendSecretKey;
///
/// The process ID of the backend for this connector. Taken from BackendKeyData during Open; stays 0 for
/// PostgreSQL-like databases that don't send that message (e.g. CockroachDB, CrateDB).
///
internal int BackendProcessId { get; private set; }
string? _inferredUserName;
///
/// The user name that has been inferred when the connector was opened
/// (connection string, environment, Kerberos or OS user - see GetUsernameAsync).
/// Throws InvalidOperationException if read before that has happened.
///
internal string InferredUserName
{
get => _inferredUserName ?? throw new InvalidOperationException($"{nameof(InferredUserName)} cannot be accessed before the connector has been opened.");
private set => _inferredUserName = value;
}
// PostgreSQL cancellation needs the backend key data; without BackendKeyData the process ID stays 0.
bool SupportsPostgresCancellation => BackendProcessId != 0;
///
/// A unique ID identifying this connector, used for logging. Currently mapped to BackendProcessId,
/// so it is 0 for backends that don't send BackendKeyData.
///
internal int Id => BackendProcessId;
///
/// Information about PostgreSQL and PostgreSQL-like databases (e.g. type definitions, capabilities...).
/// Copied from the data source after it has been bootstrapped in Open.
///
public NpgsqlDatabaseInfo DatabaseInfo { get; internal set; } = default!;
// Type mapper copied from the data source after it has been bootstrapped in Open.
internal TypeMapper TypeMapper { get; set; } = default!;
///
/// The current transaction status for this connector (Idle after construction).
///
internal TransactionStatus TransactionStatus { get; set; }
///
/// A transaction object for this connector. Since only one transaction can be in progress at any given time,
/// this instance is recycled. To check whether a transaction is currently in progress on this connector,
/// check TransactionStatus.
///
internal NpgsqlTransaction? Transaction { get; set; }
// NOTE(review): presumably a transaction detached from Transaction; its exact role isn't visible here.
internal NpgsqlTransaction? UnboundTransaction { get; set; }
///
/// The NpgsqlConnection that (currently) owns this connector. Null if the connector isn't
/// owned (i.e. idle in the pool)
///
internal NpgsqlConnection? Connection { get; set; }
///
/// The number of messages that were prepended to the current message chain, but not yet sent.
/// Note that this only tracks messages which produce a ReadyForQuery message
///
internal int PendingPrependedResponses { get; set; }
///
/// A ManualResetEventSlim used to make sure a cancellation request doesn't run
/// while we're reading responses for the prepended query
/// as we can't gracefully handle their cancellation. Starts out signaled (no prepended reading in progress).
///
readonly ManualResetEventSlim ReadingPrependedMessagesMRE = new(initialState: true);
// The data reader currently active on this connector, if any.
internal NpgsqlDataReader? CurrentReader;
// Created in the constructor; used for statement preparation (not only automatic preparation).
internal PreparedStatementManager PreparedStatementManager { get; }
internal SqlQueryParser SqlQueryParser { get; } = new();
///
/// If the connector is currently in COPY mode, holds a reference to the importer/exporter object.
/// Otherwise null.
///
internal ICancelable? CurrentCopyOperation;
///
/// Holds all run-time parameters received from the backend (via ParameterStatus messages).
/// Created empty in the constructor.
///
internal Dictionary PostgresParameters { get; }
///
/// Holds all run-time parameters in raw, binary format for efficient handling without allocations.
///
readonly List<(byte[] Name, byte[] Value)> _rawParameters = new();
///
/// If this connector was broken, this contains the exception that caused the break.
///
volatile Exception? _breakReason;
///
/// Write-in-progress flag for multiplexing: 1 while a write is in progress, 0 when the connector is
/// writable (see FlagAsNotWritableForMultiplexing / FlagAsWritableForMultiplexing).
///
/// Used by the pool to indicate that I/O is currently in progress on this connector, so that another write
/// isn't started concurrently. Note that since we have only one write loop, this is only ever used to
/// protect against over-capacity writes into a connector that's currently *asynchronously* writing.
///
/// This flag only covers writing: reading may occur - and the connector may even be returned to the
/// pool - before it is released.
///
internal volatile int MultiplexAsyncWritingLock;
/// Marks this multiplexing connector as currently writing by setting MultiplexAsyncWritingLock to 1,
/// so that no other write is started on it until FlagAsWritableForMultiplexing is called.
internal void FlagAsNotWritableForMultiplexing()
{
    Debug.Assert(Settings.Multiplexing);

    // Only connectors with commands in flight get flagged - unless they're already dead.
    Debug.Assert(IsBroken || IsClosed || CommandsInFlightCount > 0,
        $"About to mark multiplexing connector as non-writable, but {nameof(CommandsInFlightCount)} is {CommandsInFlightCount}");

    _ = Interlocked.Exchange(ref MultiplexAsyncWritingLock, 1);
}
/// Releases the multiplexing write flag taken by FlagAsNotWritableForMultiplexing, atomically switching
/// MultiplexAsyncWritingLock from 1 back to 0.
/// Throws InvalidOperationException if the flag wasn't taken, which indicates an internal bug.
internal void FlagAsWritableForMultiplexing()
{
    Debug.Assert(Settings.Multiplexing);

    // Releasing a flag that isn't held is an internal invariant violation: report it with a specific
    // exception type instead of the base System.Exception (existing catch (Exception) handlers still catch it).
    if (Interlocked.CompareExchange(ref MultiplexAsyncWritingLock, 0, 1) != 1)
        throw new InvalidOperationException("Multiplexing lock was not taken when releasing. Please report a bug.");
}
///
/// The timeout for reading messages that are part of the user's command
/// (i.e. which aren't internal prepended commands).
///
/// Precision is milliseconds
internal int UserTimeout { private get; set; }
///
/// A lock that's taken while a cancellation is being delivered; new queries are blocked until the
/// cancellation is delivered. This reduces the chance that a cancellation meant for a previous
/// command will accidentally cancel a later one, see #615.
///
object CancelLock { get; } = new();
///
/// A lock that's taken to make sure no other concurrent operation is running.
/// Break takes it to set the state of the connector.
/// Anyone else should immediately check the state and exit
/// if the connector is closed. (Open also takes it while starting the keepalive timer.)
///
object SyncObj { get; } = new();
///
/// A lock that's used to wait for the Cleanup to complete while breaking the connection.
///
object CleanupLock { get; } = new();
// True when Settings.KeepAlive > 0; fixed at construction time.
readonly bool _isKeepAliveEnabled;
// Created (stopped) in the constructor when keepalive is enabled; Open starts it with a due time and
// period of Settings.KeepAlive seconds.
readonly Timer? _keepAliveTimer;
///
/// The command currently being executed by the connector, null otherwise.
/// Used only for concurrent use error reporting purposes.
///
NpgsqlCommand? _currentCommand;
// Set in Open when pooling (without multiplexing) is on, NoResetOnClose is off and the database supports
// DISCARD; the reset message is generated at the same time.
bool _sendResetOnClose;
///
/// The connector source (e.g. pool) from where this connector came, and to which it will be returned.
/// Note that in multi-host scenarios, this references the host-specific data source rather than the
/// multi-host one.
///
internal NpgsqlDataSource DataSource { get; }
// The data source's connection string, used in log messages.
internal string UserFacingConnectionString => DataSource.ConnectionString;
///
/// Contains the UTC timestamp when this connector was opened (set at the end of a successful Open),
/// used to implement a maximum connection lifetime.
/// NOTE(review): the original cross-reference was lost; presumably this refers to the ConnectionLifetime setting.
///
internal DateTime OpenTimestamp { get; private set; }
internal int ClearCounter { get; set; }
// Volatile backing field: the flag may be read from a different thread than the one that sets it.
volatile bool _postgresCancellationPerformed;
internal bool PostgresCancellationPerformed
{
get => _postgresCancellationPerformed;
private set => _postgresCancellationPerformed = value;
}
volatile bool _userCancellationRequested;
// Registration on the user's cancellation token for the current cancellable operation.
CancellationTokenRegistration _cancellationTokenRegistration;
internal bool UserCancellationRequested => _userCancellationRequested;
internal CancellationToken UserCancellationToken { get; set; }
// Whether cancellation of the current operation should also send a PostgreSQL cancel request
// (Open passes attemptPgCancellation: false while authenticating).
internal bool AttemptPostgresCancellation { get; private set; }
// -1 ms (the value of Timeout.InfiniteTimeSpan); how the cancellation code interprets it isn't visible here.
static readonly TimeSpan _cancelImmediatelyTimeout = TimeSpan.FromMilliseconds(-1);
// Logging configuration and the loggers derived from it, all taken from the data source in the constructor.
internal NpgsqlLoggingConfiguration LoggingConfiguration { get; }
internal ILogger ConnectionLogger { get; }
internal ILogger CommandLogger { get; }
internal ILogger TransactionLogger { get; }
internal ILogger CopyLogger { get; }
internal readonly Stopwatch QueryLogStopWatch = new();
// The endpoint the socket actually connected to; set by Connect/ConnectAsync on success.
internal EndPoint? ConnectedEndPoint { get; private set; }
#endregion
#region Constants
///
/// The minimum timeout that can be set on internal commands such as COMMIT, ROLLBACK.
/// Also used as the floor when the internal command timeout is derived from CommandTimeout.
///
/// Precision is seconds
internal const int MinimumInternalCommandTimeout = 3;
#endregion
#region Reusable Message Objects
// Pre-serialized reset message and the number of responses it produces.
// NOTE(review): presumably populated by GenerateResetMessage(), which Open calls when _sendResetOnClose is set - confirm.
byte[]? _resetWithoutDeallocateMessage;
int _resetWithoutDeallocateResponseCount;
// Backend
// Message objects for frequent backend messages, allocated once per connector.
readonly CommandCompleteMessage _commandCompleteMessage = new();
readonly ReadyForQueryMessage _readyForQueryMessage = new();
readonly ParameterDescriptionMessage _parameterDescriptionMessage = new();
readonly DataRowMessage _dataRowMessage = new();
readonly RowDescriptionMessage _rowDescriptionMessage = new();
// Since COPY is rarely used, allocate these lazily
CopyInResponseMessage? _copyInResponseMessage;
CopyOutResponseMessage? _copyOutResponseMessage;
CopyDataMessage? _copyDataMessage;
CopyBothResponseMessage? _copyBothResponseMessage;
#endregion
// The connector's data reader, created in the constructor.
// NOTE(review): presumably reused across commands; UnboundDataReader's role isn't visible here.
internal NpgsqlDataReader DataReader { get; set; }
internal NpgsqlDataReader? UnboundDataReader { get; set; }
#region Constructors
/// Creates a connector for the given data source. Certificate and password callbacks configured on the
/// connection that is opening it override the data source's defaults.
internal NpgsqlConnector(NpgsqlDataSource dataSource, NpgsqlConnection conn)
    : this(dataSource)
{
    // Snapshot the delegate instead of capturing conn. The connector (and its callback) then doesn't keep the
    // opening NpgsqlConnection alive for its whole lifetime. It also can't fail with a NullReferenceException
    // if the property is cleared on the connection before the callback is invoked.
    var provideClientCertificates = conn.ProvideClientCertificatesCallback;
    if (provideClientCertificates is not null)
        ClientCertificatesCallback = certs => provideClientCertificates(certs);
    if (conn.UserCertificateValidationCallback is not null)
        UserCertificateValidationCallback = conn.UserCertificateValidationCallback;
#pragma warning disable CS0618 // Obsolete
    ProvidePasswordCallback = conn.ProvidePasswordCallback;
#pragma warning restore CS0618
}
/// Creates a fresh connector against the same data source as an existing one, carrying over that
/// connector's certificate and password callbacks.
NpgsqlConnector(NpgsqlConnector connector)
    : this(connector.DataSource)
{
    // The chained constructor installed the data source's callbacks; the source connector's callbacks
    // (which may have come from an NpgsqlConnection) replace them.
    ProvidePasswordCallback = connector.ProvidePasswordCallback;
    UserCertificateValidationCallback = connector.UserCertificateValidationCallback;
    ClientCertificatesCallback = connector.ClientCertificatesCallback;
}
/// Core constructor. Wires the connector to its data source (settings, loggers, default certificate
/// callbacks) and allocates per-connector state: the data reader, the prepared statement manager, the
/// keepalive timer (when Settings.KeepAlive > 0) and, for multiplexing, the in-flight command channel.
/// Throws NotImplementedException when keepalive is combined with multiplexing.
NpgsqlConnector(NpgsqlDataSource dataSource)
{
    Debug.Assert(dataSource.OwnsConnectors);

    DataSource = dataSource;

    LoggingConfiguration = dataSource.LoggingConfiguration;
    ConnectionLogger = LoggingConfiguration.ConnectionLogger;
    CommandLogger = LoggingConfiguration.CommandLogger;
    TransactionLogger = LoggingConfiguration.TransactionLogger;
    CopyLogger = LoggingConfiguration.CopyLogger;

    ClientCertificatesCallback = dataSource.ClientCertificatesCallback;
    UserCertificateValidationCallback = dataSource.UserCertificateValidationCallback;

    State = ConnectorState.Closed;
    TransactionStatus = TransactionStatus.Idle;
    Settings = dataSource.Settings;
    PostgresParameters = new Dictionary();

    _isKeepAliveEnabled = Settings.KeepAlive > 0;

    // TODO: Properly implement this
    // Reject the unsupported combination *before* allocating the keepalive timer. Otherwise the timer
    // would be created and never disposed when the constructor throws.
    if (_isKeepAliveEnabled && Settings.Multiplexing)
        throw new NotImplementedException("Keepalive not yet implemented for multiplexing");

    if (_isKeepAliveEnabled)
        _keepAliveTimer = new Timer(PerformKeepAlive, null, Timeout.Infinite, Timeout.Infinite);

    DataReader = new NpgsqlDataReader(this);

    // TODO: Not just for automatic preparation anymore...
    PreparedStatementManager = new PreparedStatementManager(this);

    if (Settings.Multiplexing)
    {
        // Note: It's OK for this channel to be unbounded: each command enqueued to it is accompanied by sending
        // it to PostgreSQL. If we overload it, a TCP zero window will make us block on the networking side
        // anyway.
        // Note: the in-flight channel can probably be single-writer, but that doesn't actually do anything
        // at this point. And we currently rely on being able to complete the channel at any point (from
        // Break). We may want to revisit this if an optimized, SingleWriter implementation is introduced.
        var commandsInFlightChannel = Channel.CreateUnbounded(
            new UnboundedChannelOptions { SingleReader = true });
        CommandsInFlightReader = commandsInFlightChannel.Reader;
        CommandsInFlightWriter = commandsInFlightChannel.Writer;
    }
}
#endregion
#region Configuration settings
// Shortcuts onto the connection settings. Host and Database use the null-forgiving operator, but
// Settings.Database can still be null at runtime (WriteStartupMessage checks Settings.Database for null).
internal string Host => Settings.Host!;
internal int Port => Settings.Port;
internal string Database => Settings.Database!;
string KerberosServiceName => Settings.KerberosServiceName;
// NOTE(review): Settings.Timeout is presumably in seconds - confirm against NpgsqlConnectionStringBuilder.
int ConnectionTimeout => Settings.Timeout;
bool IntegratedSecurity => Settings.IntegratedSecurity;
/// The timeout, in milliseconds, applied to internal commands (such as COMMIT, ROLLBACK).
/// A Settings.InternalCommandTimeout of -1 means "derive it from CommandTimeout, but never go below
/// MinimumInternalCommandTimeout seconds"; any other value is taken as seconds and used as-is.
int InternalCommandTimeout
{
    get
    {
        var configuredSeconds = Settings.InternalCommandTimeout;
        if (configuredSeconds != -1)
        {
            // Todo: Decide what we really want here
            // This assertion can easily fail if InternalCommandTimeout is set to 1 or 2 in the connection string
            // We probably don't want to allow these values but in that case a Debug.Assert is the wrong way to enforce it.
            Debug.Assert(configuredSeconds == 0 || configuredSeconds >= MinimumInternalCommandTimeout);
            return configuredSeconds * 1000;
        }

        var derivedSeconds = Math.Max(Settings.CommandTimeout, MinimumInternalCommandTimeout);
        return derivedSeconds * 1000;
    }
}
#endregion Configuration settings
#region State management
// Backing field for State: the ConnectorState stored as an int so that it can be updated with Interlocked.
int _state;
///
/// Gets the current state of the connector
///
internal ConnectorState State
{
    get => (ConnectorState)_state;
    set
    {
        var requested = (int)value;
        // Only pay for the interlocked write when the state actually changes.
        if (requested != _state)
            Interlocked.Exchange(ref _state, requested);
    }
}
///
/// Returns whether the connector is open, regardless of any task it is currently performing.
/// Throws ArgumentOutOfRangeException for a ConnectorState value not handled below.
///
bool IsConnected
{
    get
    {
        // Read State once: it is updated with Interlocked from other code paths, so reading it again for
        // the error message could report a different value than the one that was switched on.
        var state = State;
        return state switch
        {
            ConnectorState.Ready or ConnectorState.Executing or ConnectorState.Fetching
                or ConnectorState.Waiting or ConnectorState.Copy or ConnectorState.Replication => true,
            ConnectorState.Closed or ConnectorState.Connecting or ConnectorState.Broken => false,
            // The single-string ArgumentOutOfRangeException constructor takes a parameter *name*, not a message;
            // pass the name, the offending value and the message separately.
            _ => throw new ArgumentOutOfRangeException(nameof(State), state, "Unknown state: " + state)
        };
    }
}
// Convenience checks against the current connector State.
internal bool IsReady => State is ConnectorState.Ready;
internal bool IsClosed => State is ConnectorState.Closed;
internal bool IsBroken => State is ConnectorState.Broken;
#endregion
#region Open
///
/// Opens the physical connection to the server.
///
/// Usually called by the RequestConnector
/// Method of the connection pool manager.
///
/// The steps, in order:
/// - transport setup and authentication (OpenCore);
/// - data source bootstrap (type mapper and database info);
/// - optional reset-on-close preparation;
/// - starting the multiplexing read loop and/or keepalive timer;
/// - running the data source's connection initializer.
/// Any exception breaks the connector and is rethrown.
internal async Task Open(NpgsqlTimeout timeout, bool async, CancellationToken cancellationToken)
{
    Debug.Assert(State == ConnectorState.Closed);

    State = ConnectorState.Connecting;
    LogMessages.OpeningPhysicalConnection(ConnectionLogger, Host, Port, Database, UserFacingConnectionString);
    var stopwatch = Stopwatch.StartNew();

    try
    {
        await OpenCore(this, Settings.SslMode, timeout, async, cancellationToken);

        await DataSource.Bootstrap(this, timeout, forceReload: false, async, cancellationToken);

        Debug.Assert(DataSource.TypeMapper is not null);
        Debug.Assert(DataSource.DatabaseInfo is not null);
        TypeMapper = DataSource.TypeMapper;
        DatabaseInfo = DataSource.DatabaseInfo;

        if (Settings.Pooling && !Settings.Multiplexing && !Settings.NoResetOnClose && DatabaseInfo.SupportsDiscard)
        {
            _sendResetOnClose = true;
            GenerateResetMessage();
        }

        OpenTimestamp = DateTime.UtcNow;

        if (Settings.Multiplexing)
        {
            // Start an infinite async loop, which processes incoming multiplexing traffic.
            // It is intentionally not awaited and will run as long as the connector is alive.
            // The CommandsInFlightWriter channel is completed in Cleanup, which should cause this task
            // to complete.
            _ = Task.Run(MultiplexingReadLoop, CancellationToken.None)
                .ContinueWith(t =>
                {
                    // Note that we *must* observe the exception if the task is faulted.
                    // The template needs a placeholder for Id; without one the argument is silently dropped.
                    ConnectionLogger.LogError(t.Exception!, "Exception bubbled out of multiplexing read loop (connector {ConnectorId})", Id);
                }, TaskContinuationOptions.OnlyOnFaulted);
        }

        if (_isKeepAliveEnabled)
        {
            // Start the keep alive mechanism to work by scheduling the timer.
            // Otherwise, it doesn't work for cases when no query executed during
            // the connection lifetime in case of a new connector.
            lock (SyncObj)
            {
                var keepAlive = Settings.KeepAlive * 1000;
                _keepAliveTimer!.Change(keepAlive, keepAlive);
            }
        }

        if (DataSource.ConnectionInitializerAsync is not null)
        {
            Debug.Assert(DataSource.ConnectionInitializer is not null);

            var tempConnection = new NpgsqlConnection(DataSource, this);

            try
            {
                if (async)
                    await DataSource.ConnectionInitializerAsync(tempConnection);
                else
                    DataSource.ConnectionInitializer(tempConnection);
            }
            finally
            {
                // Note that we can't just close/dispose the NpgsqlConnection, since that puts the connector back in the pool.
                // But we transition it to disposed immediately, in case the user decides to capture the NpgsqlConnection and use it
                // later.
                Connection?.MakeDisposed();
                Connection = null;
            }
        }

        LogMessages.OpenedPhysicalConnection(
            ConnectionLogger, Host, Port, Database, UserFacingConnectionString, stopwatch.ElapsedMilliseconds, Id);
    }
    catch (Exception e)
    {
        Break(e);
        throw;
    }

    // Does the following, then waits for ReadyForQuery:
    // - connects the transport (RawOpen);
    // - resolves the user name;
    // - sends the StartupMessage;
    // - authenticates.
    // With SslMode Prefer/Allow, an InvalidAuthorizationSpecification error triggers one retry with the
    // opposite SSL choice (isFirstAttempt: false).
    static async Task OpenCore(
        NpgsqlConnector conn,
        SslMode sslMode,
        NpgsqlTimeout timeout,
        bool async,
        CancellationToken cancellationToken,
        bool isFirstAttempt = true)
    {
        await conn.RawOpen(sslMode, timeout, async, cancellationToken, isFirstAttempt);

        var username = await conn.GetUsernameAsync(async, cancellationToken);

        timeout.CheckAndApply(conn);
        conn.WriteStartupMessage(username);
        await conn.Flush(async, cancellationToken);

        var cancellationRegistration = conn.StartCancellableOperation(cancellationToken, attemptPgCancellation: false);
        try
        {
            await conn.Authenticate(username, timeout, async, cancellationToken);
        }
        catch (PostgresException e)
            when (e.SqlState == PostgresErrorCodes.InvalidAuthorizationSpecification &&
                  (sslMode == SslMode.Prefer && conn.IsSecure || sslMode == SslMode.Allow && !conn.IsSecure))
        {
            cancellationRegistration.Dispose();
            Debug.Assert(!conn.IsBroken);

            conn.Cleanup();

            // If Prefer was specified and we failed (with SSL), retry without SSL.
            // If Allow was specified and we failed (without SSL), retry with SSL
            await OpenCore(
                conn,
                sslMode == SslMode.Prefer ? SslMode.Disable : SslMode.Require,
                timeout,
                async,
                cancellationToken,
                isFirstAttempt: false);

            return;
        }
        catch
        {
            // Any other authentication failure: don't leave the callback registered on the caller's
            // cancellation token (the `using` below is never reached on this path).
            cancellationRegistration.Dispose();
            throw;
        }

        using var _ = cancellationRegistration;

        // We treat BackendKeyData as optional because some PostgreSQL-like database
        // don't send it (CockroachDB, CrateDB)
        var msg = await conn.ReadMessage(async);
        if (msg.Code == BackendMessageCode.BackendKeyData)
        {
            var keyDataMsg = (BackendKeyDataMessage)msg;
            conn.BackendProcessId = keyDataMsg.BackendProcessId;
            conn._backendSecretKey = keyDataMsg.BackendSecretKey;
            msg = await conn.ReadMessage(async);
        }

        if (msg.Code != BackendMessageCode.ReadyForQuery)
            throw new NpgsqlException($"Received backend message {msg.Code} while expecting ReadyForQuery. Please file a bug.");

        conn.State = ConnectorState.Ready;
    }
}
/// Queries pg_is_in_recovery() and default_transaction_read_only and stores the results in
/// _isHotStandBy and _isTransactionReadOnly. Returns the database state computed by UpdateDatabaseState().
/// The command's timeout is derived from what is left of the given timeout.
internal async ValueTask QueryDatabaseState(
    NpgsqlTimeout timeout, bool async, CancellationToken cancellationToken = default)
{
    using var cmd = CreateCommand("select pg_is_in_recovery(); SHOW default_transaction_read_only");

    // CommandTimeout is in whole seconds and 0 means "no limit". Truncating a sub-second remainder would
    // therefore turn a bounded wait into an unbounded one, so round such a remainder up to one second.
    var secondsLeft = (int)timeout.CheckAndGetTimeLeft().TotalSeconds;
    if (secondsLeft == 0 && timeout.IsSet)
        secondsLeft = 1;
    cmd.CommandTimeout = secondsLeft;

    var reader = async ? await cmd.ExecuteReaderAsync(cancellationToken) : cmd.ExecuteReader();
    try
    {
        if (async)
        {
            await reader.ReadAsync(cancellationToken);
            _isHotStandBy = reader.GetBoolean(0);
            await reader.NextResultAsync(cancellationToken);
            await reader.ReadAsync(cancellationToken);
        }
        else
        {
            reader.Read();
            _isHotStandBy = reader.GetBoolean(0);
            reader.NextResult();
            reader.Read();
        }

        _isTransactionReadOnly = reader.GetString(0) != "off";

        var databaseState = UpdateDatabaseState();
        Debug.Assert(databaseState.HasValue);
        return databaseState.Value;
    }
    finally
    {
        if (async)
            await reader.DisposeAsync();
        else
            reader.Dispose();
    }
}
/// Builds the StartupMessage parameters and hands them to WriteStartup. Values come from the connection
/// settings, falling back to PostgresEnvironment where supported. Optional parameters are only sent when
/// they carry a value.
void WriteStartupMessage(string username)
{
    var clientEncoding = Settings.ClientEncoding ?? PostgresEnvironment.ClientEncoding ?? "UTF8";
    var startupParams = new Dictionary
    {
        ["user"] = username,
        ["client_encoding"] = clientEncoding
    };

    if (Settings.Database is { } database)
        startupParams["database"] = database;

    if (Settings.ApplicationName is { Length: > 0 } applicationName)
        startupParams["application_name"] = applicationName;

    if (Settings.SearchPath is { Length: > 0 } searchPath)
        startupParams["search_path"] = searchPath;

    if ((Settings.Timezone ?? PostgresEnvironment.TimeZone) is { } timezone)
        startupParams["TimeZone"] = timezone;

    if ((Settings.Options ?? PostgresEnvironment.Options) is { Length: > 0 } options)
        startupParams["options"] = options;

    // Logical replication connects to a specific database; physical replication uses "true".
    var replication = Settings.ReplicationMode switch
    {
        ReplicationMode.Logical => "database",
        ReplicationMode.Physical => "true",
        _ => null
    };
    if (replication is not null)
        startupParams["replication"] = replication;

    WriteStartup(startupParams);
}
/// Determines the user name to authenticate with and records it as InferredUserName. Sources are tried
/// in order: the connection string, the PostgresEnvironment user, Kerberos (non-Windows only), and the OS
/// account name. The first two are synchronous; the task only goes asynchronous for Kerberos.
/// Throws NpgsqlException when no source yields a name.
ValueTask GetUsernameAsync(bool async, CancellationToken cancellationToken)
{
    if (Settings.Username is { Length: > 0 } configured)
        return new(Remember(configured));

    if (PostgresEnvironment.User is { Length: > 0 } fromEnvironment)
        return new(Remember(fromEnvironment));

    return ResolveFromSystemAsync();

    // Stores the chosen name as InferredUserName and passes it through.
    string Remember(string name)
    {
        InferredUserName = name;
        return name;
    }

    // Slow path: Kerberos (not on Windows), then the OS account name.
    async ValueTask ResolveFromSystemAsync()
    {
        if (!RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
        {
            var kerberosName = await KerberosUsernameProvider.GetUsernameAsync(Settings.IncludeRealm, ConnectionLogger, async,
                cancellationToken);
            if (kerberosName is { Length: > 0 })
                return Remember(kerberosName);
        }

        if (Environment.UserName is { Length: > 0 } osName)
            return Remember(osName);

        throw new NpgsqlException("No username could be found, please specify one explicitly");
    }
}
// Establishes the transport for this connector:
// - connects the socket (trying each resolved endpoint);
// - configures the text encodings and the read/write buffers;
// - for SslMode Prefer/Require/VerifyCA/VerifyFull, sends an SSLRequest and, if the server answers 'S',
//   upgrades the stream to TLS.
// On any failure the client certificate, streams and socket created so far are disposed, and the exception
// is rethrown. isFirstAttempt is false when OpenCore retries with a different SSL mode; here it only
// disables the "Require without TrustServerCertificate" check.
async Task RawOpen(SslMode sslMode, NpgsqlTimeout timeout, bool async, CancellationToken cancellationToken, bool isFirstAttempt = true)
{
var cert = default(X509Certificate2?);
try
{
if (async)
await ConnectAsync(timeout, cancellationToken);
else
Connect(timeout);
// ownsSocket: true - disposing the stream also closes the socket.
_baseStream = new NetworkStream(_socket, true);
_stream = _baseStream;
// TextEncoding throws on invalid characters; RelaxedTextEncoding substitutes replacement characters instead.
if (Settings.Encoding == "UTF8")
{
TextEncoding = PGUtil.UTF8Encoding;
RelaxedTextEncoding = PGUtil.RelaxedUTF8Encoding;
}
else
{
TextEncoding = Encoding.GetEncoding(Settings.Encoding, EncoderFallback.ExceptionFallback, DecoderFallback.ExceptionFallback);
RelaxedTextEncoding = Encoding.GetEncoding(Settings.Encoding, EncoderFallback.ReplacementFallback, DecoderFallback.ReplacementFallback);
}
ReadBuffer = new NpgsqlReadBuffer(this, _stream, _socket, Settings.ReadBufferSize, TextEncoding, RelaxedTextEncoding);
WriteBuffer = new NpgsqlWriteBuffer(this, _stream, _socket, Settings.WriteBufferSize, TextEncoding);
timeout.CheckAndApply(this);
IsSecure = false;
// Disable and Allow start without SSL (Allow may upgrade later via OpenCore's retry).
if (sslMode is SslMode.Prefer or SslMode.Require or SslMode.VerifyCA or SslMode.VerifyFull)
{
WriteSslRequest();
await Flush(async, cancellationToken);
// The server answers an SSLRequest with a single byte: 'S' (go ahead with TLS) or 'N' (no SSL).
await ReadBuffer.Ensure(1, async);
var response = (char)ReadBuffer.ReadByte();
timeout.CheckAndApply(this);
switch (response)
{
default:
throw new NpgsqlException($"Received unknown response {response} for SSLRequest (expecting S or N)");
case 'N':
// Only Prefer may continue unencrypted.
if (sslMode != SslMode.Prefer)
throw new NpgsqlException("SSL connection requested. No SSL enabled connection from this host is configured.");
break;
case 'S':
var clientCertificates = new X509Certificate2Collection();
// Client certificate: explicit setting first, then the PostgresEnvironment values.
var certPath = Settings.SslCertificate ?? PostgresEnvironment.SslCert ?? PostgresEnvironment.SslCertDefault;
if (certPath != null)
{
var password = Settings.SslPassword;
// Anything other than a .pfx file is treated as PEM.
if (Path.GetExtension(certPath).ToUpperInvariant() != ".PFX")
{
#if NET5_0_OR_GREATER
// It's PEM time
var keyPath = Settings.SslKey ?? PostgresEnvironment.SslKey ?? PostgresEnvironment.SslKeyDefault;
cert = string.IsNullOrEmpty(password)
? X509Certificate2.CreateFromPemFile(certPath, keyPath)
: X509Certificate2.CreateFromEncryptedPemFile(certPath, password, keyPath);
if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
{
// Windows crypto API has a bug with pem certs
// See #3650
// Round-trip through PKCS#12 and dispose the PEM-loaded instance.
using var previousCert = cert;
cert = new X509Certificate2(cert.Export(X509ContentType.Pkcs12));
}
#else
throw new NotSupportedException("PEM certificates are only supported with .NET 5 and higher");
#endif
}
cert ??= new X509Certificate2(certPath, password);
clientCertificates.Add(cert);
}
// Let the user add (or adjust) client certificates.
ClientCertificatesCallback?.Invoke(clientCertificates);
var checkCertificateRevocation = Settings.CheckCertificateRevocation;
// Choose how the server certificate is validated:
// - a user callback (incompatible with VerifyCA/VerifyFull and RootCertificate);
// - trust-the-server for Prefer/Require;
// - a root certificate, if one is configured;
// - otherwise VerifyCA or VerifyFull validation.
RemoteCertificateValidationCallback? certificateValidationCallback;
if (UserCertificateValidationCallback is not null)
{
if (sslMode is SslMode.VerifyCA or SslMode.VerifyFull)
throw new ArgumentException(string.Format(NpgsqlStrings.CannotUseSslVerifyWithUserCallback, sslMode));
if (Settings.RootCertificate is not null)
throw new ArgumentException(string.Format(NpgsqlStrings.CannotUseSslRootCertificateWithUserCallback));
certificateValidationCallback = UserCertificateValidationCallback;
}
else if (sslMode is SslMode.Prefer or SslMode.Require)
{
// Require must be paired with TrustServerCertificate explicitly, except on OpenCore's retry.
if (isFirstAttempt && sslMode is SslMode.Require && !Settings.TrustServerCertificate)
throw new ArgumentException(NpgsqlStrings.CannotUseSslModeRequireWithoutTrustServerCertificate);
certificateValidationCallback = SslTrustServerValidation;
checkCertificateRevocation = false;
}
else if ((Settings.RootCertificate ?? PostgresEnvironment.SslCertRoot ?? PostgresEnvironment.SslCertRootDefault) is
{ } certRootPath)
{
certificateValidationCallback = SslRootValidation(certRootPath, sslMode == SslMode.VerifyFull);
}
else if (sslMode == SslMode.VerifyCA)
{
certificateValidationCallback = SslVerifyCAValidation;
}
else
{
Debug.Assert(sslMode == SslMode.VerifyFull);
certificateValidationCallback = SslVerifyFullValidation;
}
timeout.CheckAndApply(this);
try
{
var sslStream = new SslStream(_stream, leaveInnerStreamOpen: false, certificateValidationCallback);
// SslProtocols.None lets the operating system choose the protocol version.
var sslProtocols = SslProtocols.None;
// On .NET Framework SslProtocols.None can be disabled, see #3718
#if NETSTANDARD2_0
sslProtocols = SslProtocols.Tls | SslProtocols.Tls11 | SslProtocols.Tls12;
#endif
if (async)
await sslStream.AuthenticateAsClientAsync(Host, clientCertificates, sslProtocols, checkCertificateRevocation);
else
sslStream.AuthenticateAsClient(Host, clientCertificates, sslProtocols, checkCertificateRevocation);
_stream = sslStream;
}
catch (Exception e)
{
throw new NpgsqlException("Exception while performing SSL handshake", e);
}
// From here on the existing buffers read and write through the TLS stream.
ReadBuffer.Underlying = _stream;
WriteBuffer.Underlying = _stream;
IsSecure = true;
ConnectionLogger.LogTrace("SSL negotiation successful");
break;
}
// Any byte still buffered beyond the one-byte SSL response was received unencrypted, before the handshake.
if (ReadBuffer.ReadBytesLeft > 0)
throw new NpgsqlException("Additional unencrypted data received after SSL negotiation - this should never happen, and may be an indication of a man-in-the-middle attack.");
}
ConnectionLogger.LogTrace("Socket connected to {Host}:{Port}", Host, Port);
}
catch
{
// Release everything acquired so far and reset the fields to null before rethrowing.
cert?.Dispose();
_stream?.Dispose();
_stream = null!;
_baseStream?.Dispose();
_baseStream = null!;
_socket?.Dispose();
_socket = null!;
throw;
}
}
/// Synchronously connects _socket to the first reachable endpoint for Host/Port: a Unix domain socket
/// when IsUnixSocket says so, otherwise each resolved IP address in order. Each endpoint gets an equal
/// share of the remaining timeout.
/// On success sets _socket and ConnectedEndPoint. If every endpoint fails, throws an NpgsqlException
/// wrapping the last failure; if name resolution yields no endpoint at all, throws an NpgsqlException saying so.
void Connect(NpgsqlTimeout timeout)
{
    // Note that there aren't any timeout-able or cancellable DNS methods
    var endpoints = NpgsqlConnectionStringBuilder.IsUnixSocket(Host, Port, out var socketPath)
        ? new EndPoint[] { new UnixDomainSocketEndPoint(socketPath) }
        : Dns.GetHostAddresses(Host).Select(a => new IPEndPoint(a, Port)).ToArray();

    timeout.Check();

    // With no endpoint the per-endpoint share below would divide by zero, or the loop would silently leave
    // _socket unset. Fail with an explicit error instead.
    if (endpoints.Length == 0)
        throw new NpgsqlException($"No IP addresses could be resolved for host {Host}");

    // Give each endpoint an equal share of the remaining time.
    // Socket.Select takes microseconds (1 tick = 100 ns, hence / 10); -1 means wait indefinitely.
    var perEndpointTimeout = -1; // Default to infinity
    if (timeout.IsSet)
        perEndpointTimeout = (int)(timeout.CheckAndGetTimeLeft().Ticks / endpoints.Length / 10);

    for (var i = 0; i < endpoints.Length; i++)
    {
        var endpoint = endpoints[i];
        ConnectionLogger.LogTrace("Attempting to connect to {Endpoint}", endpoint);
        var protocolType =
            endpoint.AddressFamily == AddressFamily.InterNetwork ||
            endpoint.AddressFamily == AddressFamily.InterNetworkV6
                ? ProtocolType.Tcp
                : ProtocolType.IP;
        // Non-blocking connect + Select lets us bound the connection attempt by perEndpointTimeout.
        var socket = new Socket(endpoint.AddressFamily, SocketType.Stream, protocolType)
        {
            Blocking = false
        };

        try
        {
            try
            {
                socket.Connect(endpoint);
            }
            catch (SocketException e)
            {
                // WouldBlock just means the non-blocking connect is in progress.
                if (e.SocketErrorCode != SocketError.WouldBlock)
                    throw;
            }

            var write = new List { socket };
            var error = new List { socket };
            Socket.Select(null, write, error, perEndpointTimeout);

            var errorCode = (int) socket.GetSocketOption(SocketOptionLevel.Socket, SocketOptionName.Error)!;
            if (errorCode != 0)
                throw new SocketException(errorCode);
            // Select removes sockets that didn't become writable, i.e. the attempt timed out.
            if (write.Count == 0)
                throw new TimeoutException("Timeout during connection attempt");

            socket.Blocking = true;
            SetSocketOptions(socket);
            _socket = socket;
            ConnectedEndPoint = endpoint;
            return;
        }
        catch (Exception e)
        {
            try { socket.Dispose(); }
            catch
            {
                // ignored
            }

            ConnectionLogger.LogTrace(e, "Failed to connect to {Endpoint}", endpoint);

            if (i == endpoints.Length - 1)
                throw new NpgsqlException($"Failed to connect to {endpoint}", e);
        }
    }
}
/// Asynchronously connects _socket to the first reachable endpoint for Host/Port: a Unix domain socket
/// when IsUnixSocket says so, otherwise each resolved IP address in order. Each endpoint gets an equal
/// share of the remaining timeout, starting when its own attempt starts (matching Connect).
/// On success sets _socket and ConnectedEndPoint. User cancellation is rethrown as-is. Otherwise, if
/// every endpoint fails, throws an NpgsqlException wrapping the last failure; if resolution yields no
/// endpoint at all, throws an NpgsqlException saying so.
async Task ConnectAsync(NpgsqlTimeout timeout, CancellationToken cancellationToken)
{
    Task GetHostAddressesAsync(CancellationToken ct) =>
#if NET6_0_OR_GREATER
        Dns.GetHostAddressesAsync(Host, ct);
#else
        Dns.GetHostAddressesAsync(Host);
#endif

    // Whether the framework and/or the OS platform support Dns.GetHostAddressesAsync cancellation API or they do not,
    // we always fake-cancel the operation with the help of TaskTimeoutAndCancellation.ExecuteAsync. It stops waiting
    // and raises the exception, while the actual task may be left running.
    var endpoints = NpgsqlConnectionStringBuilder.IsUnixSocket(Host, Port, out var socketPath)
        ? new EndPoint[] { new UnixDomainSocketEndPoint(socketPath) }
        : (await TaskTimeoutAndCancellation.ExecuteAsync(GetHostAddressesAsync, timeout, cancellationToken))
        .Select(a => new IPEndPoint(a, Port)).ToArray();

    // With no endpoint the per-IP share below would divide by zero, or the loop would silently leave
    // _socket unset. Fail with an explicit error instead.
    if (endpoints.Length == 0)
        throw new NpgsqlException($"No IP addresses could be resolved for host {Host}");

    // Give each IP an equal share of the remaining time
    var perIpTimespan = default(TimeSpan);
    if (timeout.IsSet)
        perIpTimespan = new TimeSpan(timeout.CheckAndGetTimeLeft().Ticks / endpoints.Length);

    for (var i = 0; i < endpoints.Length; i++)
    {
        var endpoint = endpoints[i];
        ConnectionLogger.LogTrace("Attempting to connect to {Endpoint}", endpoint);

        // NpgsqlTimeout tracks the time left until a deadline (see CheckAndGetTimeLeft). Building it once before
        // the loop would make every endpoint share the *first* endpoint's deadline, leaving later endpoints with
        // little or no time. Start each endpoint's share when its attempt begins instead.
        var perIpTimeout = timeout.IsSet ? new NpgsqlTimeout(perIpTimespan) : timeout;

        var protocolType =
            endpoint.AddressFamily == AddressFamily.InterNetwork ||
            endpoint.AddressFamily == AddressFamily.InterNetworkV6
                ? ProtocolType.Tcp
                : ProtocolType.IP;
        var socket = new Socket(endpoint.AddressFamily, SocketType.Stream, protocolType);
        try
        {
            await OpenSocketConnectionAsync(socket, endpoint, perIpTimeout, cancellationToken);
            SetSocketOptions(socket);
            _socket = socket;
            ConnectedEndPoint = endpoint;
            return;
        }
        catch (Exception e)
        {
            try
            {
                socket.Dispose();
            }
            catch
            {
                // ignored
            }

            // User cancellation wins; any other OperationCanceledException came from the per-IP timeout.
            cancellationToken.ThrowIfCancellationRequested();

            if (e is OperationCanceledException)
                e = new TimeoutException("Timeout during connection attempt");

            ConnectionLogger.LogTrace(e, "Failed to connect to {Endpoint}", endpoint);

            if (i == endpoints.Length - 1)
                throw new NpgsqlException($"Failed to connect to {endpoint}", e);
        }
    }

    static Task OpenSocketConnectionAsync(Socket socket, EndPoint endpoint, NpgsqlTimeout perIpTimeout, CancellationToken cancellationToken)
    {
        // Whether the framework and/or the OS platform support Socket.ConnectAsync cancellation API or they do not,
        // we always fake-cancel the operation with the help of TaskTimeoutAndCancellation.ExecuteAsync. It stops waiting
        // and raises the exception, while the actual task may be left running.
        Task ConnectAsync(CancellationToken ct) =>
#if NET5_0_OR_GREATER
            socket.ConnectAsync(endpoint, ct).AsTask();
#else
            socket.ConnectAsync(endpoint);
#endif
        return TaskTimeoutAndCancellation.ExecuteAsync(ConnectAsync, perIpTimeout, cancellationToken);
    }
}
/// Applies socket-level options from the connection settings to a freshly connected socket:
/// - Nagle disabled for TCP;
/// - optional receive/send buffer sizes;
/// - TCP keepalive: the enabled flag, plus optional idle time and probe interval in seconds.
/// Throws ArgumentException if TcpKeepAliveInterval is given without TcpKeepAliveTime.
void SetSocketOptions(Socket socket)
{
    if (socket.AddressFamily is AddressFamily.InterNetwork or AddressFamily.InterNetworkV6)
        socket.NoDelay = true;

    var receiveBufferSize = Settings.SocketReceiveBufferSize;
    if (receiveBufferSize > 0)
        socket.ReceiveBufferSize = receiveBufferSize;

    var sendBufferSize = Settings.SocketSendBufferSize;
    if (sendBufferSize > 0)
        socket.SendBufferSize = sendBufferSize;

    if (Settings.TcpKeepAlive)
        socket.SetSocketOption(SocketOptionLevel.Socket, SocketOptionName.KeepAlive, true);

    var keepAliveTime = Settings.TcpKeepAliveTime;
    var keepAliveInterval = Settings.TcpKeepAliveInterval;

    if (keepAliveInterval > 0 && keepAliveTime == 0)
        throw new ArgumentException("If TcpKeepAliveInterval is defined, TcpKeepAliveTime must be defined as well");

    if (keepAliveTime <= 0)
        return;

    // Without an explicit interval, probes are spaced by the same amount as the initial idle time.
    var timeSeconds = keepAliveTime;
    var intervalSeconds = keepAliveInterval > 0 ? keepAliveInterval : keepAliveTime;

#if NETSTANDARD2_0 || NETSTANDARD2_1
    // The KeepAliveValues payload is three uints: enabled flag, idle time (ms), probe interval (ms).
    // For the following see https://msdn.microsoft.com/en-us/library/dd877220.aspx
    var fieldSize = Marshal.SizeOf(typeof(uint));
    var keepAliveValues = new byte[fieldSize * 3];
    var fields = new[] { 1u, (uint)(timeSeconds * 1000), (uint)(intervalSeconds * 1000) };
    for (var i = 0; i < fields.Length; i++)
        BitConverter.GetBytes(fields[i]).CopyTo(keepAliveValues, fieldSize * i);

    int ioControlResult;
    try
    {
        ioControlResult = socket.IOControl(IOControlCode.KeepAliveValues, keepAliveValues, null);
    }
    catch (PlatformNotSupportedException)
    {
        throw new PlatformNotSupportedException("Setting TCP Keepalive Time and TCP Keepalive Interval is supported only on Windows, Mono and .NET Core 3.1+. " +
            "TCP keepalives can still be used on other systems but are enabled via the TcpKeepAlive option or configured globally for the machine, see the relevant docs.");
    }

    if (ioControlResult != 0)
        throw new NpgsqlException($"Got non-zero value when trying to set TCP keepalive: {ioControlResult}");
#else
    socket.SetSocketOption(SocketOptionLevel.Socket, SocketOptionName.KeepAlive, true);
    socket.SetSocketOption(SocketOptionLevel.Tcp, SocketOptionName.TcpKeepAliveTime, timeSeconds);
    socket.SetSocketOption(SocketOptionLevel.Tcp, SocketOptionName.TcpKeepAliveInterval, intervalSeconds);
#endif
}
#endregion
#region I/O
// Multiplexing only: the two ends of the unbounded (single-reader) channel created in the constructor when
// Settings.Multiplexing is on; null otherwise. Each command enqueued here has been sent to PostgreSQL and
// is awaiting its results.
readonly ChannelReader? CommandsInFlightReader;
internal readonly ChannelWriter? CommandsInFlightWriter;
// Number of commands currently in flight; FlagAsNotWritableForMultiplexing asserts it is positive on live connectors.
internal volatile int CommandsInFlightCount;
internal ManualResetValueTaskSource