Skip to content

Commit c47c4ad

Browse files
bhosalejchrys
authored andcommitted
Added Support for providing configuration option for supplying password function (#157)
Motivation: Currently r2dbc-mysql does not support IAM based Authentication for authenticating with AWS Aurora RDS database. The way IAM based authentication works is requesting token from AWS RDS for that hostname and username (which is same as AWS IAM Role name). These tokens are valid for 15 minutes, cannot reuse same token after 15 minutes. By adding configuration option for supplying password function, whenever a new connection is made then password is retrieved using supplier function each time. Modification: Modified `MySqlConnectionFactoryProvider` - Added new configurable option. `Option<Publisher<String>> PASSWORD_SUPPLIER = Option.valueOf("passwordSupplier");` Modified `MySqlConnectionConfiguration` - Added the new configuration for Password Supplier function. `Publisher<String> passwordSupplier;` Modified `MySqlConnectionFactory` - Retrieves Password Supplier function from configuration, and then retrieves password each time connection factory is created. Result: Users can provide a supplier function using `PASSWORD_SUPPLIER` option. This function will be used for retrieving password/token each time. ``` public ConnectionFactory writeConnectionFactory(final RdsTokenGenerator rdsTokenGenerator) { return ConnectionFactories.get(ConnectionFactoryOptions.builder() .option(ConnectionFactoryOptions.DRIVER, "mysql") .option(ConnectionFactoryOptions.HOST, "Hostname of AWS Aurora DB instance") .option(ConnectionFactoryOptions.PORT, 3306) .option(ConnectionFactoryOptions.USER, "IAM ROLE Having access to RDS") .option(MySqlConnectionFactoryProvider.PASSWORD_SUPPLIER, rdsTokenGenerator. generateAuthenticationToken()) .build()); } ``` Example of `RdsTokenGenerator` ``` public class RdsTokenGenerator { public Mono<String> generateAuthenticationToken() { return Mono.fromCallable(() -> RdsUtilities.builder() .credentialsProvider(DefaultCredentialsProvider.create()) .region(Region.US_EAST_1) .build(); .generateAuthenticationToken((builder) -> builder .hostname(hostname) .port(port) .username(user) )) .flatMap(token -> LOGGER.info("Retrieved token from RdsUtilities") .then(Mono.just(token))) .subscribeOn(Schedulers.boundedElastic()); } } ```
1 parent 665d3cd commit c47c4ad

File tree

5 files changed

+126
-20
lines changed

5 files changed

+126
-20
lines changed

src/main/java/io/asyncer/r2dbc/mysql/MySqlConnectionConfiguration.java

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import io.asyncer.r2dbc.mysql.extension.Extension;
2222
import io.netty.handler.ssl.SslContextBuilder;
2323
import org.jetbrains.annotations.Nullable;
24+
import org.reactivestreams.Publisher;
2425

2526
import javax.net.ssl.HostnameVerifier;
2627
import java.net.Socket;
@@ -92,12 +93,15 @@ public final class MySqlConnectionConfiguration {
9293

9394
private final Extensions extensions;
9495

96+
@Nullable
97+
private final Publisher<String> passwordSupplier;
98+
9599
private MySqlConnectionConfiguration(boolean isHost, String domain, int port, MySqlSslConfiguration ssl,
96100
boolean tcpKeepAlive, boolean tcpNoDelay, @Nullable Duration connectTimeout,
97101
@Nullable Duration socketTimeout, ZeroDateOption zeroDateOption, @Nullable ZoneId serverZoneId,
98102
String user, @Nullable CharSequence password, @Nullable String database,
99103
@Nullable Predicate<String> preferPrepareStatement, int queryCacheSize, int prepareCacheSize,
100-
Extensions extensions) {
104+
Extensions extensions, @Nullable Publisher<String> passwordSupplier) {
101105
this.isHost = isHost;
102106
this.domain = domain;
103107
this.port = port;
@@ -115,6 +119,7 @@ private MySqlConnectionConfiguration(boolean isHost, String domain, int port, My
115119
this.queryCacheSize = queryCacheSize;
116120
this.prepareCacheSize = prepareCacheSize;
117121
this.extensions = extensions;
122+
this.passwordSupplier = passwordSupplier;
118123
}
119124

120125
/**
@@ -204,6 +209,11 @@ Extensions getExtensions() {
204209
return extensions;
205210
}
206211

212+
@Nullable
213+
Publisher<String> getPasswordSupplier() {
214+
return passwordSupplier;
215+
}
216+
207217
@Override
208218
public boolean equals(Object o) {
209219
if (this == o) {
@@ -229,14 +239,15 @@ public boolean equals(Object o) {
229239
Objects.equals(preferPrepareStatement, that.preferPrepareStatement) &&
230240
queryCacheSize == that.queryCacheSize &&
231241
prepareCacheSize == that.prepareCacheSize &&
232-
extensions.equals(that.extensions);
242+
extensions.equals(that.extensions) &&
243+
Objects.equals(passwordSupplier, that.passwordSupplier);
233244
}
234245

235246
@Override
236247
public int hashCode() {
237248
return Objects.hash(isHost, domain, port, ssl, tcpKeepAlive, tcpNoDelay,
238249
connectTimeout, socketTimeout, serverZoneId, zeroDateOption, user, password, database,
239-
preferPrepareStatement, queryCacheSize, prepareCacheSize, extensions);
250+
preferPrepareStatement, queryCacheSize, prepareCacheSize, extensions, passwordSupplier);
240251
}
241252

242253
@Override
@@ -248,15 +259,15 @@ public String toString() {
248259
", zeroDateOption=" + zeroDateOption + ", user='" + user + '\'' + ", password=" + password +
249260
", database='" + database + "', preferPrepareStatement=" + preferPrepareStatement +
250261
", queryCacheSize=" + queryCacheSize + ", prepareCacheSize=" + prepareCacheSize +
251-
", extensions=" + extensions + '}';
262+
", extensions=" + extensions + ", passwordSupplier="+ passwordSupplier + '}';
252263
}
253264

254265
return "MySqlConnectionConfiguration{, unixSocket='" + domain + "', connectTimeout=" +
255266
connectTimeout + ", socketTimeout=" + socketTimeout + ", serverZoneId=" + serverZoneId +
256267
", zeroDateOption=" + zeroDateOption + ", user='" + user + "', password=" + password +
257268
", database='" + database + "', preferPrepareStatement=" + preferPrepareStatement +
258269
", queryCacheSize=" + queryCacheSize + ", prepareCacheSize=" + prepareCacheSize +
259-
", extensions=" + extensions + '}';
270+
", extensions=" + extensions + ", passwordSupplier="+ passwordSupplier + '}';
260271
}
261272

262273
/**
@@ -327,6 +338,9 @@ public static final class Builder {
327338

328339
private final List<Extension> extensions = new ArrayList<>();
329340

341+
@Nullable
342+
private Publisher<String> passwordSupplier;
343+
330344
/**
331345
* Builds an immutable {@link MySqlConnectionConfiguration} with current options.
332346
*
@@ -351,7 +365,7 @@ public MySqlConnectionConfiguration build() {
351365
return new MySqlConnectionConfiguration(isHost, domain, port, ssl, tcpKeepAlive, tcpNoDelay,
352366
connectTimeout, socketTimeout, zeroDateOption, serverZoneId, user, password, database,
353367
preferPrepareStatement, queryCacheSize, prepareCacheSize,
354-
Extensions.from(extensions, autodetectExtensions));
368+
Extensions.from(extensions, autodetectExtensions), passwordSupplier);
355369
}
356370

357371
/**
@@ -779,6 +793,16 @@ public Builder extendWith(Extension extension) {
779793
return this;
780794
}
781795

796+
/**
797+
* Registers a password supplier function.
798+
* @param passwordSupplier function to retrieve password before making connection.
799+
* @return this {@link Builder}.
800+
*/
801+
public Builder passwordSupplier(Publisher<String> passwordSupplier) {
802+
this.passwordSupplier = passwordSupplier;
803+
return this;
804+
}
805+
782806
private SslMode requireSslMode() {
783807
SslMode sslMode = this.sslMode;
784808

src/main/java/io/asyncer/r2dbc/mysql/MySqlConnectionFactory.java

Lines changed: 49 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,14 @@
2828
import io.netty.channel.unix.DomainSocketAddress;
2929
import io.r2dbc.spi.ConnectionFactory;
3030
import io.r2dbc.spi.ConnectionFactoryMetadata;
31+
import org.jetbrains.annotations.NotNull;
3132
import org.jetbrains.annotations.Nullable;
33+
import org.reactivestreams.Publisher;
3234
import reactor.core.publisher.Mono;
3335

3436
import java.net.InetSocketAddress;
3537
import java.net.SocketAddress;
38+
import java.util.Objects;
3639
import java.util.function.Predicate;
3740

3841
import static io.asyncer.r2dbc.mysql.internal.util.AssertUtils.requireNonNull;
@@ -91,22 +94,54 @@ public static MySqlConnectionFactory from(MySqlConnectionConfiguration configura
9194
Extensions extensions = configuration.getExtensions();
9295
Predicate<String> prepare = configuration.getPreferPrepareStatement();
9396
int prepareCacheSize = configuration.getPrepareCacheSize();
97+
Publisher<String> passwordSupplier = configuration.getPasswordSupplier();
98+
if (Objects.nonNull(passwordSupplier)) {
99+
return Mono.from(passwordSupplier)
100+
.flatMap(token -> getMySqlConnection(
101+
configuration, queryCache,
102+
ssl, address,
103+
database, user,
104+
sslMode, context,
105+
extensions, prepare,
106+
prepareCacheSize, token));
107+
}
108+
return getMySqlConnection(configuration, queryCache,
109+
ssl, address,
110+
database, user,
111+
sslMode, context,
112+
extensions, prepare,
113+
prepareCacheSize, password);
114+
}));
115+
}
94116

95-
return Client.connect(ssl, address, configuration.isTcpKeepAlive(), configuration.isTcpNoDelay(),
117+
@NotNull
118+
private static Mono<MySqlConnection> getMySqlConnection(
119+
final MySqlConnectionConfiguration configuration,
120+
final LazyQueryCache queryCache,
121+
final MySqlSslConfiguration ssl,
122+
final SocketAddress address,
123+
final String database,
124+
final String user,
125+
final SslMode sslMode,
126+
final ConnectionContext context,
127+
final Extensions extensions,
128+
@Nullable final Predicate<String> prepare,
129+
final int prepareCacheSize,
130+
@Nullable final CharSequence password) {
131+
return Client.connect(ssl, address, configuration.isTcpKeepAlive(), configuration.isTcpNoDelay(),
96132
context, configuration.getConnectTimeout(), configuration.getSocketTimeout())
97-
.flatMap(client -> QueryFlow.login(client, sslMode, database, user, password, context))
98-
.flatMap(client -> {
99-
ByteBufAllocator allocator = client.getByteBufAllocator();
100-
CodecsBuilder builder = Codecs.builder(allocator);
101-
PrepareCache prepareCache = Caches.createPrepareCache(prepareCacheSize);
102-
103-
extensions.forEach(CodecRegistrar.class, registrar ->
104-
registrar.register(allocator, builder));
105-
106-
return MySqlConnection.init(client, builder.build(), context, queryCache.get(),
107-
prepareCache, prepare);
108-
});
109-
}));
133+
.flatMap(client -> QueryFlow.login(client, sslMode, database, user, password, context))
134+
.flatMap(client -> {
135+
ByteBufAllocator allocator = client.getByteBufAllocator();
136+
CodecsBuilder builder = Codecs.builder(allocator);
137+
PrepareCache prepareCache = Caches.createPrepareCache(prepareCacheSize);
138+
139+
extensions.forEach(CodecRegistrar.class, registrar ->
140+
registrar.register(allocator, builder));
141+
142+
return MySqlConnection.init(client, builder.build(), context, queryCache.get(),
143+
prepareCache, prepare);
144+
});
110145
}
111146

112147
private static final class LazyQueryCache {

src/main/java/io/asyncer/r2dbc/mysql/MySqlConnectionFactoryProvider.java

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import io.r2dbc.spi.ConnectionFactoryOptions;
2424
import io.r2dbc.spi.ConnectionFactoryProvider;
2525
import io.r2dbc.spi.Option;
26+
import org.reactivestreams.Publisher;
2627

2728
import javax.net.ssl.HostnameVerifier;
2829
import java.time.Duration;
@@ -203,6 +204,14 @@ public final class MySqlConnectionFactoryProvider implements ConnectionFactoryPr
203204
*/
204205
public static final Option<Boolean> AUTODETECT_EXTENSIONS = Option.valueOf("autodetectExtensions");
205206

207+
/**
208+
* Password Supplier function can be used to retrieve password before creating a connection.
209+
* This can be used with Amazon RDS Aurora IAM authentication, wherein it requires token to be generated.
210+
* The token is valid for 15 minutes, and this token will be used as password.
211+
*
212+
*/
213+
public static final Option<Publisher<String>> PASSWORD_SUPPLIER = Option.valueOf("passwordSupplier");
214+
206215
@Override
207216
public ConnectionFactory create(ConnectionFactoryOptions options) {
208217
requireNonNull(options, "connectionFactoryOptions must not be null");
@@ -261,6 +270,8 @@ static MySqlConnectionConfiguration setup(ConnectionFactoryOptions options) {
261270
.to(builder::socketTimeout);
262271
mapper.optional(DATABASE).asString()
263272
.to(builder::database);
273+
mapper.optional(PASSWORD_SUPPLIER).as(Publisher.class)
274+
.to(builder::passwordSupplier);
264275

265276
return builder.build();
266277
}

src/test/java/io/asyncer/r2dbc/mysql/MySqlConnectionConfigurationTest.java

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,13 +25,16 @@
2525
import org.assertj.core.api.ThrowableTypeAssert;
2626
import org.jetbrains.annotations.Nullable;
2727
import org.junit.jupiter.api.Test;
28+
import reactor.core.publisher.Mono;
29+
import reactor.test.StepVerifier;
2830

2931
import java.time.Duration;
3032
import java.time.ZoneId;
3133
import java.util.ArrayList;
3234
import java.util.List;
3335
import java.util.Objects;
3436
import java.util.function.Function;
37+
import java.util.function.Supplier;
3538

3639
import static org.assertj.core.api.Assertions.assertThat;
3740
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
@@ -189,6 +192,21 @@ void nonAutodetectExtensions() {
189192
assertThat(list).isEmpty();
190193
}
191194

195+
@Test
196+
void validPasswordSupplier() {
197+
final Mono<String> passwordSupplier = Mono.just("123456");
198+
Mono.from(MySqlConnectionConfiguration.builder()
199+
.host(HOST)
200+
.user(USER)
201+
.passwordSupplier(passwordSupplier)
202+
.autodetectExtensions(false)
203+
.build()
204+
.getPasswordSupplier())
205+
.as(StepVerifier::create)
206+
.expectNext("123456")
207+
.verifyComplete();
208+
}
209+
192210
private static MySqlConnectionConfiguration unixSocketSslMode(SslMode sslMode) {
193211
return MySqlConnectionConfiguration.builder()
194212
.unixSocket(UNIX_SOCKET)

src/test/java/io/asyncer/r2dbc/mysql/MySqlConnectionFactoryProviderTest.java

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
import io.r2dbc.spi.Option;
2525
import org.assertj.core.api.Assert;
2626
import org.junit.jupiter.api.Test;
27+
import org.reactivestreams.Publisher;
28+
import reactor.core.publisher.Mono;
2729

2830
import javax.net.ssl.HostnameVerifier;
2931
import javax.net.ssl.SSLSession;
@@ -34,7 +36,9 @@
3436
import java.util.Collections;
3537
import java.util.function.Function;
3638
import java.util.function.Predicate;
39+
import java.util.function.Supplier;
3740

41+
import static io.asyncer.r2dbc.mysql.MySqlConnectionFactoryProvider.PASSWORD_SUPPLIER;
3842
import static io.asyncer.r2dbc.mysql.MySqlConnectionFactoryProvider.USE_SERVER_PREPARE_STATEMENT;
3943
import static io.r2dbc.spi.ConnectionFactoryOptions.CONNECT_TIMEOUT;
4044
import static io.r2dbc.spi.ConnectionFactoryOptions.DATABASE;
@@ -390,6 +394,20 @@ void invalidServerPreparing() {
390394
.option(USE_SERVER_PREPARE_STATEMENT, NotPredicate.class.getPackage() + "NonePredicate")
391395
.build()));
392396
}
397+
398+
@Test
399+
void validPasswordSupplier() {
400+
final Publisher<String> passwordSupplier = Mono.just("123456");
401+
ConnectionFactoryOptions options = ConnectionFactoryOptions.builder()
402+
.option(DRIVER, "mysql")
403+
.option(HOST, "127.0.0.1")
404+
.option(USER, "root")
405+
.option(PASSWORD_SUPPLIER, passwordSupplier)
406+
.build();
407+
408+
assertThat(ConnectionFactories.get(options)).isExactlyInstanceOf(MySqlConnectionFactory.class);
409+
}
410+
393411
}
394412

395413
final class MockException extends RuntimeException {

0 commit comments

Comments
 (0)