diff --git a/rust/src/lib.rs b/rust/src/lib.rs index 69383f9..1887a52 100644 --- a/rust/src/lib.rs +++ b/rust/src/lib.rs @@ -1037,10 +1037,98 @@ pub unsafe extern "C-unwind" fn refresh_iam_token( ); }, }; + + async_panic_guard.panicked = false; + }); + + panic_guard.panicked = false; +} + +/// Update connection password +/// +/// # Arguments +/// * `client_ptr` - Pointer to the client +/// * `callback_index` - Callback index for async response +/// * `password` - New password (null for password removal) +/// * `immediate_auth` - Whether to authenticate immediately +/// +/// # Safety +/// * `client_ptr` must be a valid pointer to a Client +/// * `password` must be a valid C string or null +#[unsafe(no_mangle)] +pub unsafe extern "C-unwind" fn update_connection_password( + client_ptr: *const c_void, + callback_index: usize, + password_ptr: *const c_char, + immediate_auth: bool, +) { + // Build client and add panic guard. + let client = unsafe { + Arc::increment_strong_count(client_ptr); + Arc::from_raw(client_ptr as *mut Client) + }; + let core = client.core.clone(); + + let mut panic_guard = PanicGuard { + panicked: true, + failure_callback: core.failure_callback, + callback_index, + }; + + // Build password option. + let password = if password_ptr.is_null() { + None + } else { + match unsafe { CStr::from_ptr(password_ptr).to_str() } { + Ok(password_str) => { + if password_str.is_empty() { + None + } else { + Some(password_str.into()) + } + } + Err(_) => { + unsafe { + report_error( + core.failure_callback, + callback_index, + "Invalid password argument".into(), + RequestErrorType::Unspecified, + ); + } + panic_guard.panicked = false; + return; + } + } + }; + + // Run password update. + client.runtime.spawn(async move { + let mut async_panic_guard = PanicGuard { + panicked: true, + failure_callback: core.failure_callback, + callback_index, + }; + + let result = core.client.clone().update_connection_password(password, immediate_auth).await; + match result { + Ok(value) => { + let response = ResponseValue::from_value(value); + let ptr = Box::into_raw(Box::new(response)); + unsafe { (core.success_callback)(callback_index, ptr) }; + } + Err(err) => unsafe { + report_error( + core.failure_callback, + callback_index, + error_message(&err), + error_type(&err), + ); + }, + }; + async_panic_guard.panicked = false; - drop(async_panic_guard); }); panic_guard.panicked = false; - drop(panic_guard); } diff --git a/sources/Valkey.Glide/BaseClient.cs b/sources/Valkey.Glide/BaseClient.cs index e98f4c5..35a9588 100644 --- a/sources/Valkey.Glide/BaseClient.cs +++ b/sources/Valkey.Glide/BaseClient.cs @@ -58,6 +58,66 @@ public async Task RefreshIamTokenAsync() } } + /// + /// Update the current connection with a new password. + /// + /// The new password to update the connection with + /// If true, re-authenticate immediately after updating password + /// Thrown if is null or empty. + /// A task that completes when the password is updated + public async Task UpdateConnectionPasswordAsync(string password, bool immediateAuth = false) + { + if (password == null) + { + throw new ArgumentException("Password cannot be null", nameof(password)); + } + + if (password.Length == 0) + { + throw new ArgumentException("Password cannot be empty", nameof(password)); + } + + Message message = MessageContainer.GetMessageForCall(); + IntPtr passwordPtr = Marshal.StringToHGlobalAnsi(password); + try + { + UpdateConnectionPasswordFfi(ClientPointer, (ulong)message.Index, passwordPtr, immediateAuth); + IntPtr response = await message; + try + { + HandleResponse(response); + } + finally + { + FreeResponse(response); + } + } + finally + { + Marshal.FreeHGlobal(passwordPtr); + } + } + + /// + /// Clear the password from the current connection. + /// + /// If true, re-authenticate immediately after clearing password + /// A task that completes when the password is cleared + public async Task ClearConnectionPasswordAsync(bool immediateAuth = false) + { + Message message = MessageContainer.GetMessageForCall(); + + UpdateConnectionPasswordFfi(ClientPointer, (ulong)message.Index, IntPtr.Zero, immediateAuth); + IntPtr response = await message; + try + { + HandleResponse(response); + } + finally + { + FreeResponse(response); + } + } #endregion public methods #region protected methods diff --git a/sources/Valkey.Glide/Internals/FFI.methods.cs b/sources/Valkey.Glide/Internals/FFI.methods.cs index 2f3c719..49d5ace 100644 --- a/sources/Valkey.Glide/Internals/FFI.methods.cs +++ b/sources/Valkey.Glide/Internals/FFI.methods.cs @@ -70,6 +70,9 @@ public static partial void InvokeScriptFfi( [UnmanagedCallConv(CallConvs = [typeof(CallConvCdecl)])] public static partial void RemoveClusterScanCursorFfi(IntPtr cursorId); + [LibraryImport("libglide_rs", EntryPoint = "update_connection_password")] + [UnmanagedCallConv(CallConvs = [typeof(CallConvCdecl)])] + public static partial void UpdateConnectionPasswordFfi(IntPtr client, ulong index, IntPtr password, [MarshalAs(UnmanagedType.U1)] bool immediateAuth); [LibraryImport("libglide_rs", EntryPoint = "refresh_iam_token")] [UnmanagedCallConv(CallConvs = [typeof(CallConvCdecl)])] @@ -101,7 +104,7 @@ public static partial void InvokeScriptFfi( [DllImport("libglide_rs", CallingConvention = CallingConvention.Cdecl, EntryPoint = "free_drop_script_error")] public static extern void FreeDropScriptError(IntPtr errorBuffer); - + [DllImport("libglide_rs", CallingConvention = CallingConvention.Cdecl, EntryPoint = "invoke_script")] public static extern void InvokeScriptFfi( IntPtr client, @@ -121,7 +124,10 @@ public static extern void InvokeScriptFfi( [DllImport("libglide_rs", CallingConvention = CallingConvention.Cdecl, EntryPoint = "remove_cluster_scan_cursor")] public static extern void RemoveClusterScanCursorFfi(IntPtr cursorId); - + + [DllImport("libglide_rs", CallingConvention = CallingConvention.Cdecl, EntryPoint = "update_connection_password")] + public static extern void UpdateConnectionPasswordFfi(IntPtr client, ulong index, IntPtr password, [MarshalAs(UnmanagedType.U1)] bool immediateAuth); + [DllImport("libglide_rs", CallingConvention = CallingConvention.Cdecl, EntryPoint = "refresh_iam_token")] public static extern void RefreshIamTokenFfi(IntPtr client, ulong index); #endif diff --git a/tests/Valkey.Glide.IntegrationTests/UpdateConnectionPasswordTests.cs b/tests/Valkey.Glide.IntegrationTests/UpdateConnectionPasswordTests.cs new file mode 100644 index 0000000..356e5e6 --- /dev/null +++ b/tests/Valkey.Glide.IntegrationTests/UpdateConnectionPasswordTests.cs @@ -0,0 +1,219 @@ +// Copyright Valkey GLIDE Project Contributors - SPDX Identifier: Apache-2.0 + +using static Valkey.Glide.ConnectionConfiguration; +using static Valkey.Glide.Errors; + +namespace Valkey.Glide.IntegrationTests; + +public class UpdateConnectionPasswordTests(TestConfiguration config) +{ + public TestConfiguration Config { get; } = config; + + private static readonly string Password = "PASSWORD"; + private static readonly string InvalidPassword = "INVALID"; + private static readonly GlideString[] KillClientCommandArgs = ["CLIENT", "KILL", "TYPE", "NORMAL"]; + + [Fact] + public async Task UpdateConnectionPassword_Standalone_DelayAuth() + { + string serverName = $"test_{Guid.NewGuid():N}"; + + try + { + // Start server and build clients. + var addresses = ServerManager.StartStandaloneServer(serverName); + var config = new StandaloneClientConfigurationBuilder() + .WithAddress(addresses[0].host, addresses[0].port).Build(); + + using var client = await GlideClient.CreateClient(config); + using var adminClient = await GlideClient.CreateClient(config); + + VerifyConnection(client); + VerifyConnection(adminClient); + + // Update client connection password. + await client.UpdateConnectionPasswordAsync(Password, immediateAuth: false); + + VerifyConnection(client); // No reconnect + + // Update server password and kill all clients. + await adminClient.ConfigSetAsync("requirepass", Password); + await adminClient.CustomCommand(KillClientCommandArgs); + Task.Delay(1000).Wait(); + + VerifyConnection(client); // Reconnect + + // Clear client connection password. + await client.ClearConnectionPasswordAsync(immediateAuth: false); + + VerifyConnection(client); // No reconnect + + // Clear server password and kill all clients. + await adminClient.ConfigSetAsync("requirepass", ""); + await adminClient.CustomCommand(KillClientCommandArgs); + Task.Delay(1000).Wait(); + + VerifyConnection(client); // Reconnect + } + finally + { + ServerManager.StopServer(serverName); + } + } + + [Fact] + public async Task UpdateConnectionPassword_Standalone_ImmediateAuth() + { + string serverName = $"test_{Guid.NewGuid():N}"; + + try + { + // Start server and build client. + var addresses = ServerManager.StartStandaloneServer(serverName); + var config = new StandaloneClientConfigurationBuilder() + .WithAddress(addresses[0].host, addresses[0].port).Build(); + + using var client = await GlideClient.CreateClient(config); + + VerifyConnection(client); + + // Update server password. + await client.ConfigSetAsync("requirepass", Password); + Task.Delay(1000).Wait(); + + // Update client connection password. + await client.UpdateConnectionPasswordAsync(Password, immediateAuth: true); + + VerifyConnection(client); + + // Clear server password. + await client.ConfigSetAsync("requirepass", ""); + Task.Delay(1000).Wait(); + + // Clear client connection password. + await client.ClearConnectionPasswordAsync(immediateAuth: false); + + VerifyConnection(client); + } + finally + { + ServerManager.StopServer(serverName); + } + } + + [Fact] + public async Task UpdateConnectionPassword_Standalone_InvalidPassword() + { + using var client = TestConfiguration.DefaultStandaloneClient(); + await Assert.ThrowsAsync(() => client.UpdateConnectionPasswordAsync(null!, immediateAuth: true)); + await Assert.ThrowsAsync(() => client.UpdateConnectionPasswordAsync("", immediateAuth: true)); + await Assert.ThrowsAsync(() => client.UpdateConnectionPasswordAsync(InvalidPassword, immediateAuth: true)); + } + + [Fact] + public async Task UpdateConnectionPassword_Cluster_DelayAuth() + { + string serverName = $"test_{Guid.NewGuid():N}"; + + try + { + // Start cluster and build clients. + var addresses = ServerManager.StartClusterServer(serverName); + var config = new ClusterClientConfigurationBuilder() + .WithAddress(addresses[0].host, addresses[0].port).Build(); + + using var client = await GlideClusterClient.CreateClient(config); + using var adminClient = await GlideClusterClient.CreateClient(config); + + VerifyConnection(client); + VerifyConnection(adminClient); + + // Update client connection password. + await client.UpdateConnectionPasswordAsync(Password, immediateAuth: false); + + VerifyConnection(client); // No reconnect + + // Update server password and kill all clients. + await adminClient.ConfigSetAsync("requirepass", Password); + await adminClient.CustomCommand(KillClientCommandArgs); + Task.Delay(1000).Wait(); + + VerifyConnection(client); // Reconnect + + // Clear client connection password. + await client.ClearConnectionPasswordAsync(immediateAuth: false); + + VerifyConnection(client); // No reconnect + + // Clear server password and kill all clients. + await adminClient.ConfigSetAsync("requirepass", ""); + await adminClient.CustomCommand(KillClientCommandArgs); + Task.Delay(1000).Wait(); + + VerifyConnection(client); // Reconnect + } + finally + { + ServerManager.StopServer(serverName); + } + } + + [Fact] + public async Task UpdateConnectionPassword_Cluster_ImmediateAuth() + { + string serverName = $"test_{Guid.NewGuid():N}"; + + try + { + // Start cluster and build client. + var addresses = ServerManager.StartClusterServer(serverName); + var config = new ClusterClientConfigurationBuilder() + .WithAddress(addresses[0].host, addresses[0].port).Build(); + + using var client = await GlideClusterClient.CreateClient(config); + + VerifyConnection(client); + + // Update server password. + await client.ConfigSetAsync("requirepass", Password, Route.AllNodes); + Task.Delay(1000).Wait(); + + // Update client connection password. + await client.UpdateConnectionPasswordAsync(Password, immediateAuth: true); + + VerifyConnection(client); + + // Clear server password. + await client.ConfigSetAsync("requirepass", "", Route.AllNodes); + Task.Delay(1000).Wait(); + + // Clear client connection password. + await client.ClearConnectionPasswordAsync(immediateAuth: false); + + VerifyConnection(client); + } + finally + { + ServerManager.StopServer(serverName); + } + } + + [Fact] + public async Task UpdateConnectionPassword_Cluster_InvalidPassword() + { + using var client = TestConfiguration.DefaultClusterClient(); + await Assert.ThrowsAsync(() => client.UpdateConnectionPasswordAsync(null!, immediateAuth: true)); + await Assert.ThrowsAsync(() => client.UpdateConnectionPasswordAsync("", immediateAuth: true)); + await Assert.ThrowsAsync(() => client.UpdateConnectionPasswordAsync(InvalidPassword, immediateAuth: true)); + } + + private static async void VerifyConnection(GlideClient client) + { + Assert.True(await client.PingAsync() > TimeSpan.Zero); + } + + private static async void VerifyConnection(GlideClusterClient client) + { + Assert.True(await client.PingAsync() > TimeSpan.Zero); + } +}