Skip to content

Commit f147e42

Browse files
committed
tapgarden: switch to full locks and acquire them first in methods
Replace read-only locks with full mutex locks in MockKeyRing methods for simplicity. Ensure the full lock is acquired as the first operation in each method to maintain consistent and safe access.
1 parent 0ddb086 commit f147e42

File tree

1 file changed

+17
-16
lines changed

1 file changed

+17
-16
lines changed

tapgarden/mock.go

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -880,6 +880,7 @@ func NewMockKeyRing() *MockKeyRing {
880880
func (m *MockKeyRing) DeriveNextTaprootAssetKey(
881881
ctx context.Context) (keychain.KeyDescriptor, error) {
882882

883+
// No need to lock mutex here, DeriveNextKey does that for us.
883884
m.Called(ctx)
884885

885886
return m.DeriveNextKey(ctx, asset.TaprootAssetsKeyFamily)
@@ -888,6 +889,12 @@ func (m *MockKeyRing) DeriveNextTaprootAssetKey(
888889
func (m *MockKeyRing) DeriveNextKey(ctx context.Context,
889890
keyFam keychain.KeyFamily) (keychain.KeyDescriptor, error) {
890891

892+
m.Lock()
893+
defer func() {
894+
m.KeyIndex++
895+
m.Unlock()
896+
}()
897+
891898
m.Called(ctx, keyFam)
892899

893900
select {
@@ -896,12 +903,6 @@ func (m *MockKeyRing) DeriveNextKey(ctx context.Context,
896903
default:
897904
}
898905

899-
m.Lock()
900-
defer func() {
901-
m.KeyIndex++
902-
m.Unlock()
903-
}()
904-
905906
priv, err := btcec.NewPrivateKey()
906907
if err != nil {
907908
return keychain.KeyDescriptor{}, err
@@ -925,10 +926,10 @@ func (m *MockKeyRing) DeriveNextKey(ctx context.Context,
925926
func (m *MockKeyRing) IsLocalKey(ctx context.Context,
926927
d keychain.KeyDescriptor) bool {
927928

928-
m.Called(ctx, d)
929+
m.Lock()
930+
defer m.Unlock()
929931

930-
m.RLock()
931-
defer m.RUnlock()
932+
m.Called(ctx, d)
932933

933934
priv, ok := m.Keys[d.KeyLocator]
934935
if ok && priv.PubKey().IsEqual(d.PubKey) {
@@ -945,8 +946,8 @@ func (m *MockKeyRing) IsLocalKey(ctx context.Context,
945946
}
946947

947948
func (m *MockKeyRing) PubKeyAt(t *testing.T, idx uint32) *btcec.PublicKey {
948-
m.RLock()
949-
defer m.RUnlock()
949+
m.Lock()
950+
defer m.Unlock()
950951

951952
loc := keychain.KeyLocator{
952953
Index: idx,
@@ -962,8 +963,8 @@ func (m *MockKeyRing) PubKeyAt(t *testing.T, idx uint32) *btcec.PublicKey {
962963
}
963964

964965
func (m *MockKeyRing) ScriptKeyAt(t *testing.T, idx uint32) asset.ScriptKey {
965-
m.RLock()
966-
defer m.RUnlock()
966+
m.Lock()
967+
defer m.Unlock()
967968

968969
loc := keychain.KeyLocator{
969970
Index: idx,
@@ -984,13 +985,13 @@ func (m *MockKeyRing) ScriptKeyAt(t *testing.T, idx uint32) asset.ScriptKey {
984985
func (m *MockKeyRing) DeriveSharedKey(_ context.Context, key *btcec.PublicKey,
985986
locator *keychain.KeyLocator) ([sha256.Size]byte, error) {
986987

988+
m.Lock()
989+
defer m.Unlock()
990+
987991
if locator == nil {
988992
return [32]byte{}, fmt.Errorf("locator is nil")
989993
}
990994

991-
m.RLock()
992-
defer m.RUnlock()
993-
994995
priv, ok := m.Keys[*locator]
995996
if !ok {
996997
return [32]byte{}, fmt.Errorf("script key not found at index "+

0 commit comments

Comments
 (0)