Skip to content

Commit 912cebf

Browse files
author
Tien Nguyen
committed
Added test case
1 parent 5da06eb commit 912cebf

File tree

4 files changed

+78
-77
lines changed

4 files changed

+78
-77
lines changed

src/AdoNetCore.AseClient/AseConnection.cs

Lines changed: 11 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -627,6 +627,9 @@ public AseTransaction Transaction
627627
}
628628
}
629629

630+
/// <summary>
631+
/// Allow consumer to override the default certificate validation
632+
/// </summary>
630633
public RemoteCertificateValidationCallback UserCertificateValidationCallback { get; set; }
631634

632635
}
@@ -674,25 +677,13 @@ public AseTransaction Transaction
674677
/// </remarks>
675678
public delegate void TraceExitEventHandler(AseConnection connection, object source, string method, object returnValue);
676679

677-
//
678-
// Summary:
679-
// Verifies the remote Secure Sockets Layer (SSL) certificate used for authentication.
680-
//
681-
// Parameters:
682-
// sender:
683-
// An object that contains state information for this validation.
684-
//
685-
// certificate:
686-
// The certificate used to authenticate the remote party.
687-
//
688-
// chain:
689-
// The chain of certificate authorities associated with the remote certificate.
690-
//
691-
// sslPolicyErrors:
692-
// One or more errors associated with the remote certificate.
693-
//
694-
// Returns:
695-
// A System.Boolean value that determines whether the specified certificate is accepted
696-
// for authentication.
680+
/// <summary>
681+
/// Verifies the remote Secure Sockets Layer (SSL) certificate used for authentication.
682+
/// </summary>
683+
/// <param name="sender">An object that contains state information for this validation.</param>
684+
/// <param name="certificate">The certificate used to authenticate the remote party.</param>
685+
/// <param name="chain">The chain of certificate authorities associated with the remote certificate.</param>
686+
/// <param name="sslPolicyErrors">One or more errors associated with the remote certificate.</param>
687+
/// <returns>A System.Boolean value that determines whether the specified certificate is accepted for authentication.</returns>
697688
public delegate bool RemoteCertificateValidationCallback(object sender, X509Certificate certificate, X509Chain chain, SslPolicyErrors sslPolicyErrors);
698689
}

src/AdoNetCore.AseClient/Internal/InternalConnectionFactory.cs

Lines changed: 50 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@ namespace AdoNetCore.AseClient.Internal
1919
internal class InternalConnectionFactory : IInternalConnectionFactory
2020
{
2121
private readonly IConnectionParameters _parameters;
22-
private readonly RemoteCertificateValidationCallback _userCertificateValidationCallback;
22+
private readonly System.Net.Security.RemoteCertificateValidationCallback _userCertificateValidationCallback;
23+
2324

2425
#if ENABLE_ARRAY_POOL
2526
private readonly System.Buffers.ArrayPool<byte> _arrayPool;
@@ -33,10 +34,7 @@ public InternalConnectionFactory(IConnectionParameters parameters, RemoteCertifi
3334
#endif
3435
{
3536
_parameters = parameters;
36-
if (userCertificateValidationCallback == null)
37-
userCertificateValidationCallback = GetDefaultUserCertificateValidationCallback();
38-
39-
_userCertificateValidationCallback = userCertificateValidationCallback;
37+
_userCertificateValidationCallback = userCertificateValidationCallback == null ? UserCertificateValidationCallback : new System.Net.Security.RemoteCertificateValidationCallback(userCertificateValidationCallback);
4038

4139
#if ENABLE_ARRAY_POOL
4240
_arrayPool = arrayPool;
@@ -116,7 +114,7 @@ private InternalConnection CreateConnection(Socket socket, CancellationToken tok
116114

117115
if (_parameters.Encryption)
118116
{
119-
sslStream = new SslStream(networkStream, false, _userCertificateValidationCallback.Invoke);
117+
sslStream = new SslStream(networkStream, false, _userCertificateValidationCallback);
120118

121119
var authenticate = sslStream.AuthenticateAsClientAsync(_parameters.Server);
122120

@@ -187,32 +185,30 @@ private InternalConnection CreateConnectionInternal(Stream networkStream)
187185
}
188186

189187

190-
private RemoteCertificateValidationCallback GetDefaultUserCertificateValidationCallback()
188+
private bool UserCertificateValidationCallback(object sender, X509Certificate serverCertificate, X509Chain chain, SslPolicyErrors sslPolicyErrors)
191189
{
192-
//object sender, X509Certificate serverCertificate, X509Chain chain, SslPolicyErrors sslPolicyErrors
193-
return (sender, serverCertificate, chain, sslPolicyErrors) => {
194-
var certificateChainPolicyErrors = (sslPolicyErrors & SslPolicyErrors.RemoteCertificateChainErrors) == SslPolicyErrors.RemoteCertificateChainErrors;
195-
var otherPolicyErrors = (sslPolicyErrors & ~SslPolicyErrors.RemoteCertificateChainErrors) != SslPolicyErrors.None;
190+
var certificateChainPolicyErrors = (sslPolicyErrors & SslPolicyErrors.RemoteCertificateChainErrors) == SslPolicyErrors.RemoteCertificateChainErrors;
191+
var otherPolicyErrors = (sslPolicyErrors & ~SslPolicyErrors.RemoteCertificateChainErrors) != SslPolicyErrors.None;
196192

197-
// We're not concerned with chain errors as we verify the chain below.
198-
if (otherPolicyErrors)
199-
{
200-
Logger.Instance?.WriteLine($"{nameof(InternalConnectionFactory)}.{nameof(GetDefaultUserCertificateValidationCallback)} secure connection failed due to policy errors: {sslPolicyErrors}");
201-
return false;
202-
}
193+
// We're not concerned with chain errors as we verify the chain below.
194+
if (otherPolicyErrors)
195+
{
196+
Logger.Instance?.WriteLine($"{nameof(InternalConnectionFactory)}.{nameof(UserCertificateValidationCallback)} secure connection failed due to policy errors: {sslPolicyErrors}");
197+
return false;
198+
}
203199

204-
var mergedStatusFlags = X509ChainStatusFlags.NoError;
205-
foreach (var status in chain.ChainStatus)
206-
{
207-
mergedStatusFlags |= status.Status;
208-
}
200+
var mergedStatusFlags = X509ChainStatusFlags.NoError;
201+
foreach (var status in chain.ChainStatus)
202+
{
203+
mergedStatusFlags |= status.Status;
204+
}
209205

210-
var trustedCerts = LoadTrustedFile(_parameters.TrustedFile);
211-
if (trustedCerts == null)
212-
{
213-
Logger.Instance?.WriteLine($"{nameof(InternalConnectionFactory)}.{nameof(GetDefaultUserCertificateValidationCallback)} secure connection failed due to missing TrustedFile parameter.");
214-
return false;
215-
}
206+
var trustedCerts = LoadTrustedFile(_parameters.TrustedFile);
207+
if (trustedCerts == null)
208+
{
209+
Logger.Instance?.WriteLine($"{nameof(InternalConnectionFactory)}.{nameof(UserCertificateValidationCallback)} secure connection failed due to missing TrustedFile parameter.");
210+
return false;
211+
}
216212

217213
#if !(NETCOREAPP1_0 || NETCOREAPP1_1) // these frameworks do not have the following X509Certificate2 constructor...
218214
// sometimes the chain policy is only a partial chain because it doesn't include a self signed root?
@@ -233,41 +229,40 @@ private RemoteCertificateValidationCallback GetDefaultUserCertificateValidationC
233229
}
234230
#endif
235231

236-
var untrustedRootChainStatusFlags = (mergedStatusFlags & X509ChainStatusFlags.UntrustedRoot) == X509ChainStatusFlags.UntrustedRoot;
237-
var otherChainStatusFlags = (mergedStatusFlags & ~X509ChainStatusFlags.UntrustedRoot) != X509ChainStatusFlags.NoError;
232+
var untrustedRootChainStatusFlags = (mergedStatusFlags & X509ChainStatusFlags.UntrustedRoot) == X509ChainStatusFlags.UntrustedRoot;
233+
var otherChainStatusFlags = (mergedStatusFlags & ~X509ChainStatusFlags.UntrustedRoot) != X509ChainStatusFlags.NoError;
238234

239-
if (otherChainStatusFlags)
240-
{
241-
Logger.Instance?.WriteLine($"{nameof(InternalConnectionFactory)}.{nameof(GetDefaultUserCertificateValidationCallback)} secure connection failed due to chain status: {mergedStatusFlags}");
242-
return false;
243-
}
235+
if (otherChainStatusFlags)
236+
{
237+
Logger.Instance?.WriteLine($"{nameof(InternalConnectionFactory)}.{nameof(UserCertificateValidationCallback)} secure connection failed due to chain status: {mergedStatusFlags}");
238+
return false;
239+
}
244240

245-
if (!(certificateChainPolicyErrors || untrustedRootChainStatusFlags))
246-
{
247-
//No chain Errors, we will trust the server certificate.
248-
return true;
249-
}
241+
if (!(certificateChainPolicyErrors || untrustedRootChainStatusFlags))
242+
{
243+
//No chain Errors, we will trust the server certificate.
244+
return true;
245+
}
250246

251-
// If any certificates in the chain are trusted, then we will trust the server certificate.
252-
// To do this fairly quickly we can check if thumbprints exist in the set of trusted roots.
253-
var set = new HashSet<string>(trustedCerts.Select(c => c.Thumbprint));
247+
// If any certificates in the chain are trusted, then we will trust the server certificate.
248+
// To do this fairly quickly we can check if thumbprints exist in the set of trusted roots.
249+
var set = new HashSet<string>(trustedCerts.Select(c => c.Thumbprint));
254250

255-
// the chain is in an array from leaf at 0 to root at [count - 1]
256-
// looping from end to start should find cases generated according to sybase documentation on the first attempt
257-
// but it is possible that someone puts an intermediate or even the leaf cert in their trusted file
258-
for (int i = chain.ChainElements.Count - 1; i >= 0; i--)
251+
// the chain is in an array from leaf at 0 to root at [count - 1]
252+
// looping from end to start should find cases generated according to sybase documentation on the first attempt
253+
// but it is possible that someone puts an intermediate or even the leaf cert in their trusted file
254+
for (int i = chain.ChainElements.Count - 1; i >= 0; i--)
255+
{
256+
var potentialTrusted = chain.ChainElements[i].Certificate.Thumbprint;
257+
if (set.Contains(potentialTrusted))
259258
{
260-
var potentialTrusted = chain.ChainElements[i].Certificate.Thumbprint;
261-
if (set.Contains(potentialTrusted))
262-
{
263-
return true;
264-
}
259+
return true;
265260
}
261+
}
266262

267-
Logger.Instance?.WriteLine($"{nameof(InternalConnectionFactory)}.{nameof(GetDefaultUserCertificateValidationCallback)} secure connection failed due to missing root or intermediate certificate in the certificate store, or the TrustedFile.");
263+
Logger.Instance?.WriteLine($"{nameof(InternalConnectionFactory)}.{nameof(UserCertificateValidationCallback)} secure connection failed due to missing root or intermediate certificate in the certificate store, or the TrustedFile.");
268264

269-
return false;
270-
};
265+
return false;
271266
}
272267

273268
private static X509Certificate2[] LoadTrustedFile(string trustedFile)

test/AdoNetCore.AseClient.Tests/Integration/ConnectionTests.cs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,5 +43,21 @@ public void OpenConnection_ToNonListeningServer_ThrowsAseException()
4343
Assert.AreEqual(30294, ex.Errors[0].MessageNumber);
4444
}
4545
}
46+
47+
[Test]
48+
//Note: for this to work, we would need a sybase SSL setup with a separate SSL port
49+
//for this test, no certificate on the consuming side
50+
public void UserCertificateValidationCallback_ShouldWork()
51+
{
52+
var isDelegateCalled = false;
53+
54+
using (var connection = new AseConnection(ConnectionStrings.Tls))
55+
{
56+
connection.UserCertificateValidationCallback = (o, cert, chain, pol) => { isDelegateCalled = true; return true; };
57+
58+
connection.Open();
59+
Assert.True(isDelegateCalled);
60+
}
61+
}
4662
}
4763
}

test/AdoNetCore.AseClient.Tests/Unit/AseConnectionTests.cs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -173,14 +173,13 @@ public void ChangeDatabase_WhenOpen_Succeeds()
173173

174174
var connection = new AseConnection("Data Source=myASEserver;Port=5000;Database=foo;Uid=myUsername;Pwd=myPassword;", mockConnectionPoolManager.Object);
175175

176-
//connection.UserCertificateValidationCallback = (sender, certificate, chain, errors) => true;
177176
// Act
178177
connection.Open();
179178
connection.ChangeDatabase("bar");
180-
181179
// Assert
182180
// No error...
183181
}
182+
184183
[Test]
185184
public void ChangeDatabase_WhenNotOpen_Succeeds()
186185
{

0 commit comments

Comments
 (0)