@@ -857,6 +857,10 @@ type MockKeyRing struct {
857857 KeyIndex uint32
858858
859859 Keys map [keychain.KeyLocator ]* btcec.PrivateKey
860+
861+ // deriveNextKeyCallCount is used to track the number of calls to
862+ // DeriveNextKey.
863+ deriveNextKeyCallCount atomic.Uint64
860864}
861865
862866var _ KeyRing = (* MockKeyRing )(nil )
@@ -869,8 +873,11 @@ func NewMockKeyRing() *MockKeyRing {
869873 keyRing .On (
870874 "DeriveNextKey" , mock .Anything ,
871875 keychain .KeyFamily (asset .TaprootAssetsKeyFamily ),
872- ).Return (nil )
873- keyRing .On ("DeriveNextTaprootAssetKey" , mock .Anything ).Return (nil )
876+ ).Return (keychain.KeyDescriptor {}, nil )
877+
878+ keyRing .On (
879+ "DeriveNextTaprootAssetKey" , mock .Anything ,
880+ ).Return (keychain.KeyDescriptor {}, nil )
874881
875882 return keyRing
876883}
@@ -880,6 +887,7 @@ func NewMockKeyRing() *MockKeyRing {
880887func (m * MockKeyRing ) DeriveNextTaprootAssetKey (
881888 ctx context.Context ) (keychain.KeyDescriptor , error ) {
882889
890+ // No need to lock mutex here, DeriveNextKey does that for us.
883891 m .Called (ctx )
884892
885893 return m .DeriveNextKey (ctx , asset .TaprootAssetsKeyFamily )
@@ -888,20 +896,21 @@ func (m *MockKeyRing) DeriveNextTaprootAssetKey(
888896func (m * MockKeyRing ) DeriveNextKey (ctx context.Context ,
889897 keyFam keychain.KeyFamily ) (keychain.KeyDescriptor , error ) {
890898
899+ m .Lock ()
900+ defer func () {
901+ m .KeyIndex ++
902+ m .Unlock ()
903+ }()
904+
891905 m .Called (ctx , keyFam )
906+ m .deriveNextKeyCallCount .Add (1 )
892907
893908 select {
894909 case <- ctx .Done ():
895910 return keychain.KeyDescriptor {}, fmt .Errorf ("shutting down" )
896911 default :
897912 }
898913
899- m .Lock ()
900- defer func () {
901- m .KeyIndex ++
902- m .Unlock ()
903- }()
904-
905914 priv , err := btcec .NewPrivateKey ()
906915 if err != nil {
907916 return keychain.KeyDescriptor {}, err
@@ -925,10 +934,10 @@ func (m *MockKeyRing) DeriveNextKey(ctx context.Context,
925934func (m * MockKeyRing ) IsLocalKey (ctx context.Context ,
926935 d keychain.KeyDescriptor ) bool {
927936
928- m .Called (ctx , d )
937+ m .Lock ()
938+ defer m .Unlock ()
929939
930- m .RLock ()
931- defer m .RUnlock ()
940+ m .Called (ctx , d )
932941
933942 priv , ok := m .Keys [d .KeyLocator ]
934943 if ok && priv .PubKey ().IsEqual (d .PubKey ) {
@@ -945,8 +954,8 @@ func (m *MockKeyRing) IsLocalKey(ctx context.Context,
945954}
946955
947956func (m * MockKeyRing ) PubKeyAt (t * testing.T , idx uint32 ) * btcec.PublicKey {
948- m .RLock ()
949- defer m .RUnlock ()
957+ m .Lock ()
958+ defer m .Unlock ()
950959
951960 loc := keychain.KeyLocator {
952961 Index : idx ,
@@ -962,8 +971,8 @@ func (m *MockKeyRing) PubKeyAt(t *testing.T, idx uint32) *btcec.PublicKey {
962971}
963972
964973func (m * MockKeyRing ) ScriptKeyAt (t * testing.T , idx uint32 ) asset.ScriptKey {
965- m .RLock ()
966- defer m .RUnlock ()
974+ m .Lock ()
975+ defer m .Unlock ()
967976
968977 loc := keychain.KeyLocator {
969978 Index : idx ,
@@ -984,13 +993,13 @@ func (m *MockKeyRing) ScriptKeyAt(t *testing.T, idx uint32) asset.ScriptKey {
984993func (m * MockKeyRing ) DeriveSharedKey (_ context.Context , key * btcec.PublicKey ,
985994 locator * keychain.KeyLocator ) ([sha256 .Size ]byte , error ) {
986995
996+ m .Lock ()
997+ defer m .Unlock ()
998+
987999 if locator == nil {
9881000 return [32 ]byte {}, fmt .Errorf ("locator is nil" )
9891001 }
9901002
991- m .RLock ()
992- defer m .RUnlock ()
993-
9941003 priv , ok := m .Keys [* locator ]
9951004 if ! ok {
9961005 return [32 ]byte {}, fmt .Errorf ("script key not found at index " +
@@ -1003,6 +1012,19 @@ func (m *MockKeyRing) DeriveSharedKey(_ context.Context, key *btcec.PublicKey,
10031012 return ecdh .ECDH (key )
10041013}
10051014
1015+ // DeriveNextKeyCallCount returns the number of calls to DeriveNextKey. This is
1016+ // useful in tests to assert that the key ring was used as expected in
1017+ // concurrent scenarios.
1018+ func (m * MockKeyRing ) DeriveNextKeyCallCount () int {
1019+ return int (m .deriveNextKeyCallCount .Load ())
1020+ }
1021+
1022+ // ResetDeriveNextKeyCallCount resets the call counter for DeriveNextKey to
1023+ // zero. This is useful in tests to ensure a clean state for assertions.
1024+ func (m * MockKeyRing ) ResetDeriveNextKeyCallCount () {
1025+ m .deriveNextKeyCallCount .Store (0 )
1026+ }
1027+
10061028type MockGenSigner struct {
10071029 KeyRing * MockKeyRing
10081030 failSigning atomic.Bool
0 commit comments