Skip to content

Commit 6fbf2aa

Browse files
committed
tapfeatures: add aux channel negotiator
1 parent 1a58486 commit 6fbf2aa

File tree

2 files changed

+229
-0
lines changed

2 files changed

+229
-0
lines changed
Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,171 @@
1+
package tapfeatures
2+
3+
import (
4+
"bytes"
5+
"fmt"
6+
7+
"github.com/lightningnetwork/lnd/lnutils"
8+
"github.com/lightningnetwork/lnd/lnwire"
9+
"github.com/lightningnetwork/lnd/routing/route"
10+
"github.com/lightningnetwork/lnd/tlv"
11+
)
12+
13+
const (
14+
// AuxFeatureBitsTLV is the TLV type used to encode auxiliary feature
15+
// bits in the init message. These feature bits allow aux channel
16+
// implementations to negotiate custom channel behavior.
17+
AuxFeatureBitsTLV tlv.Type = 65545
18+
)
19+
20+
// AuxFeatureBits is a type alias for a TLV blob that contains custom feature
21+
// bits for auxiliary channel negotiation.
22+
type AuxFeatureBits = tlv.Blob
23+
24+
// AuxChannelNegotiator is responsible for producing the extra tlv blob that is
25+
// encapsulated in the init and reestablish peer messages. This helps us
26+
// communicate custom feature bits with our peer.
27+
type AuxChannelNegotiator struct {
28+
// peerFeatures keeps track of the supported features of our peers. This
29+
// map will be used for lookups by other subsystems, when some features
30+
// need to be supported by both parties to take effect.
31+
peerFeatures lnutils.SyncMap[route.Vertex, *lnwire.RawFeatureVector]
32+
33+
// chanFeatures keeps track of the supported features of each channel.
34+
// This map will be used for lookups by other subsystems to check
35+
// whether certain custom channel features are supported.
36+
chanFeatures lnutils.SyncMap[lnwire.ChannelID, *lnwire.RawFeatureVector]
37+
}
38+
39+
// NewAuxChannelNegotiator returns a new instance of the aux channel negotiator.
40+
func NewAuxChannelNegotiator() *AuxChannelNegotiator {
41+
return &AuxChannelNegotiator{}
42+
}
43+
44+
// GetInitRecords is called when sending an init message to a peer. It returns
45+
// custom feature bits to include in the init message TLVs. The implementation
46+
// can decide which features to advertise based on the peer's identity.
47+
func (n *AuxChannelNegotiator) GetInitRecords(
48+
_ route.Vertex) (lnwire.CustomRecords, error) {
49+
50+
var buf bytes.Buffer
51+
52+
// Grab the "static" feature vector that denotes the supported features
53+
// of our node. If our peer can read this message they will keep track
54+
// of our features just like we do below in `ProcessInitFeatures`.
55+
features := getLocalFeatureVec()
56+
err := features.Encode(&buf)
57+
if err != nil {
58+
return nil, err
59+
}
60+
61+
tlvMap := make(tlv.TypeMap, 1)
62+
tlvMap[AuxFeatureBitsTLV] = buf.Bytes()
63+
64+
return lnwire.NewCustomRecords(tlvMap)
65+
}
66+
67+
// ProcessInitRecords handles received init feature TLVs from a peer. The
68+
// implementation can store state internally to affect future channel operations
69+
// with this peer.
70+
func (n *AuxChannelNegotiator) ProcessInitRecords(peer route.Vertex,
71+
customRecords lnwire.CustomRecords) error {
72+
73+
auxRecord, ok := customRecords[uint64(AuxFeatureBitsTLV)]
74+
if !ok {
75+
// If the entry was not present, delete the previous entry. Our
76+
// peer did not provide a custom feature bit vector this time.
77+
n.peerFeatures.Delete(peer)
78+
return nil
79+
}
80+
81+
buf := bytes.NewBuffer(auxRecord)
82+
peerVec := lnwire.NewRawFeatureVector()
83+
err := peerVec.Decode(buf)
84+
if err != nil {
85+
return err
86+
}
87+
88+
// Before we store this peer's supported features we need to first check
89+
// if our required features are supported by that peer. If a locally
90+
// required feature is not supported by the remote peer we have to
91+
// return an error and drop the connection. Whether we support all of
92+
// the remote required features is a responsibility of the remote peer.
93+
// If we fail to support a remotely required feature they are the ones
94+
// to drop the connection (by returning an error right here).
95+
err = checkRequiredBits(getLocalFeatureVec(), peerVec)
96+
if err != nil {
97+
return err
98+
}
99+
100+
// Store this peer's features.
101+
n.peerFeatures.Store(peer, peerVec)
102+
103+
return nil
104+
}
105+
106+
// ProcessChannelReady handles the reception of the ChannelReady message, which
107+
// signals that a newly established channel is now ready to use. This helps us
108+
// correlate a peer's features with a channel outpoint
109+
func (n *AuxChannelNegotiator) ProcessChannelReady(cid lnwire.ChannelID,
110+
peer route.Vertex) {
111+
112+
features, ok := n.peerFeatures.Load(peer)
113+
if ok {
114+
n.chanFeatures.Store(cid, features)
115+
}
116+
}
117+
118+
// ProcessReestablish handles the reception of the ChannelReestablish message,
119+
// which signals that a previously established channel is now ready to use. This
120+
// helps us correlate a peer's features with a channel outpoint.
121+
func (n *AuxChannelNegotiator) ProcessReestablish(
122+
cid lnwire.ChannelID, peer route.Vertex) {
123+
124+
features, ok := n.peerFeatures.Load(peer)
125+
if ok {
126+
n.chanFeatures.Store(cid, features)
127+
}
128+
}
129+
130+
// GetPeerFeatures returns the negotiated feature bit vector that was
131+
// established with the given peer.
132+
func (n *AuxChannelNegotiator) GetPeerFeatures(
133+
peer route.Vertex) lnwire.FeatureVector {
134+
135+
rawfeatures, ok := n.peerFeatures.Load(peer)
136+
if !ok {
137+
rawfeatures = lnwire.NewRawFeatureVector()
138+
}
139+
140+
return *lnwire.NewFeatureVector(rawfeatures, featureNames)
141+
}
142+
143+
// GetChannelFeatures returns the negotiated feature bits vector for the channel
144+
// identified by the given channelID.
145+
func (n *AuxChannelNegotiator) GetChannelFeatures(
146+
cid lnwire.ChannelID) lnwire.FeatureVector {
147+
148+
rawfeatures, ok := n.chanFeatures.Load(cid)
149+
if !ok {
150+
rawfeatures = lnwire.NewRawFeatureVector()
151+
}
152+
153+
return *lnwire.NewFeatureVector(rawfeatures, featureNames)
154+
}
155+
156+
// checkRequiredBits is a helper method that checks if all of the required bits
157+
// of the first vector are supported by the second vector.
158+
func checkRequiredBits(local, remote *lnwire.RawFeatureVector) error {
159+
localBits, remoteBits :=
160+
lnwire.NewFeatureVector(local, featureNames),
161+
lnwire.NewFeatureVector(remote, featureNames)
162+
163+
for _, f := range ourFeatures() {
164+
if localBits.RequiresFeature(f) && !remoteBits.HasFeature(f) {
165+
return fmt.Errorf("peer does not support required "+
166+
"feature: %v", localBits.Name(f))
167+
}
168+
}
169+
170+
return nil
171+
}
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
package tapfeatures
2+
3+
import (
4+
"testing"
5+
6+
"github.com/lightningnetwork/lnd/lnwire"
7+
"github.com/stretchr/testify/require"
8+
)
9+
10+
// TestFeatureBits tests that the behavior of the feature vector matches our
11+
// expectations when using the custom feature bits for taproot asset channels.
12+
func TestFeatureBits(t *testing.T) {
13+
featuresA := lnwire.NewFeatureVector(
14+
lnwire.NewRawFeatureVector(NoOpHTLCsOptional), featureNames,
15+
)
16+
17+
featuresB := lnwire.NewFeatureVector(
18+
lnwire.NewRawFeatureVector(STXOOptional), featureNames,
19+
)
20+
21+
require.True(t, featuresA.HasFeature(NoOpHTLCsOptional))
22+
require.True(t, featuresB.HasFeature(STXOOptional))
23+
24+
require.False(t, featuresA.HasFeature(STXOOptional))
25+
require.False(t, featuresB.HasFeature(NoOpHTLCsOptional))
26+
27+
require.False(t, featuresA.RequiresFeature(NoOpHTLCsOptional))
28+
require.False(t, featuresB.RequiresFeature(STXOOptional))
29+
30+
err := checkRequiredBits(
31+
featuresA.RawFeatureVector, featuresB.RawFeatureVector,
32+
)
33+
34+
require.NoError(t, err)
35+
36+
featuresA = lnwire.NewFeatureVector(
37+
lnwire.NewRawFeatureVector(NoOpHTLCsRequired), featureNames,
38+
)
39+
40+
featuresB = lnwire.NewFeatureVector(
41+
lnwire.NewRawFeatureVector(STXORequired), featureNames,
42+
)
43+
44+
require.True(t, featuresA.HasFeature(NoOpHTLCsOptional))
45+
require.True(t, featuresB.HasFeature(STXOOptional))
46+
47+
require.False(t, featuresA.HasFeature(STXOOptional))
48+
require.False(t, featuresB.HasFeature(NoOpHTLCsOptional))
49+
50+
require.True(t, featuresA.RequiresFeature(NoOpHTLCsOptional))
51+
require.True(t, featuresB.RequiresFeature(STXOOptional))
52+
53+
err = checkRequiredBits(
54+
featuresA.RawFeatureVector, featuresB.RawFeatureVector,
55+
)
56+
57+
require.Error(t, err)
58+
}

0 commit comments

Comments
 (0)