Skip to content

Commit 51f9e3e

Browse files
committed
Add support for create database if not exist
- Support `InitDbMessage`. - Support `changeDatabase` in `MySqlConnection`. - Add integration tests for that.
1 parent dbd7d3a commit 51f9e3e

15 files changed

+148
-30
lines changed

src/main/java/io/asyncer/r2dbc/mysql/Capability.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ public final class Capability {
2626

2727
/**
2828
* Can use long password.
29+
* <p>
30+
* TODO: Reinterpret it as {@code CLIENT_MYSQL} to support MariaDB 10.2 and above.
2931
*/
3032
private static final int LONG_PASSWORD = 1;
3133

src/main/java/io/asyncer/r2dbc/mysql/MySqlConnection.java

Lines changed: 52 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import io.asyncer.r2dbc.mysql.client.Client;
2222
import io.asyncer.r2dbc.mysql.codec.Codecs;
2323
import io.asyncer.r2dbc.mysql.constant.ServerStatuses;
24+
import io.asyncer.r2dbc.mysql.message.client.InitDbMessage;
2425
import io.asyncer.r2dbc.mysql.message.client.PingMessage;
2526
import io.asyncer.r2dbc.mysql.message.server.CompleteMessage;
2627
import io.asyncer.r2dbc.mysql.message.server.ErrorMessage;
@@ -91,6 +92,31 @@ public final class MySqlConnection implements Connection, ConnectionState {
9192
}
9293
};
9394

95+
private static final BiConsumer<ServerMessage, SynchronousSink<Boolean>> INIT_DB = (message, sink) -> {
96+
if (message instanceof ErrorMessage) {
97+
ErrorMessage msg = (ErrorMessage) message;
98+
logger.debug("Use database failed: [{}] [{}] {}", msg.getCode(), msg.getSqlState(),
99+
msg.getMessage());
100+
sink.next(false);
101+
sink.complete();
102+
} else if (message instanceof CompleteMessage && ((CompleteMessage) message).isDone()) {
103+
sink.next(true);
104+
sink.complete();
105+
} else {
106+
ReferenceCountUtil.safeRelease(message);
107+
}
108+
};
109+
110+
private static final BiConsumer<ServerMessage, SynchronousSink<Void>> INIT_DB_AFTER = (message, sink) -> {
111+
if (message instanceof ErrorMessage) {
112+
sink.error(((ErrorMessage) message).toException());
113+
} else if (message instanceof CompleteMessage && ((CompleteMessage) message).isDone()) {
114+
sink.complete();
115+
} else {
116+
ReferenceCountUtil.safeRelease(message);
117+
}
118+
};
119+
94120
private final Client client;
95121

96122
private final Codecs codecs;
@@ -403,13 +429,17 @@ boolean isSessionAutoCommit() {
403429
* @param client must be logged-in.
404430
* @param codecs the {@link Codecs}.
405431
* @param context must be initialized.
432+
* @param database the database that should be lazy init.
406433
* @param queryCache the cache of {@link Query}.
407434
* @param prepareCache the cache of server-preparing result.
408435
* @param prepare judging for prefer use prepare statement to execute simple query.
409436
* @return a {@link Mono} will emit an initialized {@link MySqlConnection}.
410437
*/
411-
static Mono<MySqlConnection> init(Client client, Codecs codecs, ConnectionContext context,
412-
QueryCache queryCache, PrepareCache prepareCache, @Nullable Predicate<String> prepare) {
438+
static Mono<MySqlConnection> init(
439+
Client client, Codecs codecs, ConnectionContext context, String database,
440+
QueryCache queryCache, PrepareCache prepareCache,
441+
@Nullable Predicate<String> prepare
442+
) {
413443
ServerVersion version = context.getServerVersion();
414444
StringBuilder query = new StringBuilder(128);
415445

@@ -431,7 +461,7 @@ static Mono<MySqlConnection> init(Client client, Codecs codecs, ConnectionContex
431461
handler = MySqlConnection::init;
432462
}
433463

434-
return new TextSimpleStatement(client, codecs, context, query.toString())
464+
Mono<MySqlConnection> connection = new TextSimpleStatement(client, codecs, context, query.toString())
435465
.execute()
436466
.flatMap(handler)
437467
.last()
@@ -445,6 +475,25 @@ static Mono<MySqlConnection> init(Client client, Codecs codecs, ConnectionContex
445475
return new MySqlConnection(client, context, codecs, data.level, data.lockWaitTimeout,
446476
queryCache, prepareCache, data.product, prepare);
447477
});
478+
479+
if (database.isEmpty()) {
480+
return connection;
481+
}
482+
483+
requireValidName(database, "database must not be empty and not contain backticks");
484+
485+
return connection.flatMap(conn -> client.exchange(new InitDbMessage(database), INIT_DB)
486+
.last()
487+
.flatMap(success -> {
488+
if (success) {
489+
return Mono.just(conn);
490+
}
491+
492+
String sql = String.format("CREATE DATABASE IF NOT EXISTS `%s`", database);
493+
494+
return QueryFlow.executeVoid(client, sql)
495+
.then(client.exchange(new InitDbMessage(database), INIT_DB_AFTER).then(Mono.just(conn)));
496+
}));
448497
}
449498

450499
private static Publisher<InitData> init(MySqlResult r) {

src/main/java/io/asyncer/r2dbc/mysql/MySqlConnectionFactory.java

Lines changed: 25 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
import io.netty.channel.unix.DomainSocketAddress;
2929
import io.r2dbc.spi.ConnectionFactory;
3030
import io.r2dbc.spi.ConnectionFactoryMetadata;
31-
import org.jetbrains.annotations.NotNull;
3231
import org.jetbrains.annotations.Nullable;
3332
import org.reactivestreams.Publisher;
3433
import reactor.core.publisher.Mono;
@@ -86,6 +85,7 @@ public static MySqlConnectionFactory from(MySqlConnectionConfiguration configura
8685
}
8786

8887
String database = configuration.getDatabase();
88+
boolean createDbIfNotExist = configuration.isCreateDatabaseIfNotExist();
8989
String user = configuration.getUser();
9090
CharSequence password = configuration.getPassword();
9191
SslMode sslMode = ssl.getSslMode();
@@ -95,32 +95,36 @@ public static MySqlConnectionFactory from(MySqlConnectionConfiguration configura
9595
Predicate<String> prepare = configuration.getPreferPrepareStatement();
9696
int prepareCacheSize = configuration.getPrepareCacheSize();
9797
Publisher<String> passwordPublisher = configuration.getPasswordPublisher();
98+
9899
if (Objects.nonNull(passwordPublisher)) {
99-
return Mono.from(passwordPublisher)
100-
.flatMap(token -> getMySqlConnection(
101-
configuration, queryCache,
102-
ssl, address,
103-
database, user,
104-
sslMode, context,
105-
extensions, prepare,
106-
prepareCacheSize, token));
100+
return Mono.from(passwordPublisher).flatMap(token -> getMySqlConnection(
101+
configuration, queryCache,
102+
ssl, address,
103+
database, createDbIfNotExist,
104+
user, sslMode, context,
105+
extensions, prepare,
106+
prepareCacheSize, token
107+
));
107108
}
108-
return getMySqlConnection(configuration, queryCache,
109+
110+
return getMySqlConnection(
111+
configuration, queryCache,
109112
ssl, address,
110-
database, user,
111-
sslMode, context,
113+
database, createDbIfNotExist,
114+
user, sslMode, context,
112115
extensions, prepare,
113-
prepareCacheSize, password);
116+
prepareCacheSize, password
117+
);
114118
}));
115119
}
116120

117-
@NotNull
118121
private static Mono<MySqlConnection> getMySqlConnection(
119122
final MySqlConnectionConfiguration configuration,
120123
final LazyQueryCache queryCache,
121124
final MySqlSslConfiguration ssl,
122125
final SocketAddress address,
123126
final String database,
127+
final boolean createDbIfNotExist,
124128
final String user,
125129
final SslMode sslMode,
126130
final ConnectionContext context,
@@ -130,16 +134,21 @@ private static Mono<MySqlConnection> getMySqlConnection(
130134
@Nullable final CharSequence password) {
131135
return Client.connect(ssl, address, configuration.isTcpKeepAlive(), configuration.isTcpNoDelay(),
132136
context, configuration.getConnectTimeout(), configuration.getSocketTimeout())
133-
.flatMap(client -> QueryFlow.login(client, sslMode, database, user, password, context))
137+
.flatMap(client -> {
138+
// Lazy init database after handshake/login
139+
String db = createDbIfNotExist ? "" : database;
140+
return QueryFlow.login(client, sslMode, db, user, password, context);
141+
})
134142
.flatMap(client -> {
135143
ByteBufAllocator allocator = client.getByteBufAllocator();
136144
CodecsBuilder builder = Codecs.builder(allocator);
137145
PrepareCache prepareCache = Caches.createPrepareCache(prepareCacheSize);
146+
String db = createDbIfNotExist ? database : "";
138147

139148
extensions.forEach(CodecRegistrar.class, registrar ->
140149
registrar.register(allocator, builder));
141150

142-
return MySqlConnection.init(client, builder.build(), context, queryCache.get(),
151+
return MySqlConnection.init(client, builder.build(), context, db, queryCache.get(),
143152
prepareCache, prepare);
144153
});
145154
}

src/main/java/io/asyncer/r2dbc/mysql/QueryFlow.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -888,6 +888,7 @@ private Capability clientCapability(Capability serverCapability) {
888888

889889
builder.disableDatabasePinned();
890890
builder.disableCompression();
891+
// TODO: support LOAD DATA LOCAL INFILE
891892
builder.disableLoadDataInfile();
892893
builder.disableIgnoreAmbiguitySpace();
893894
builder.disableInteractiveTimeout();
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
package io.asyncer.r2dbc.mysql.message.client;
2+
3+
import io.asyncer.r2dbc.mysql.ConnectionContext;
4+
import io.netty.buffer.ByteBuf;
5+
6+
public final class InitDbMessage extends ScalarClientMessage {
7+
8+
private static final byte FLAG = 0x02;
9+
10+
private final String database;
11+
12+
public InitDbMessage(String database) { this.database = database; }
13+
14+
@Override
15+
protected void writeTo(ByteBuf buf, ConnectionContext context) {
16+
// RestOfPacketString, no need terminal or length
17+
buf.writeByte(FLAG).writeCharSequence(database, context.getClientCollation().getCharset());
18+
}
19+
}

src/test/java/io/asyncer/r2dbc/mysql/ConnectionIntegrationTest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
class ConnectionIntegrationTest extends IntegrationTestSupport {
3636

3737
ConnectionIntegrationTest() {
38-
super(configuration(false, null, null));
38+
super(configuration("r2dbc", false, false, null, null));
3939
}
4040

4141
@Test
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
package io.asyncer.r2dbc.mysql;
2+
3+
import org.junit.jupiter.api.Test;
4+
5+
import java.util.concurrent.ThreadLocalRandom;
6+
import java.util.stream.Collectors;
7+
8+
import static org.assertj.core.api.Assertions.assertThat;
9+
10+
/**
11+
* Integration tests for {@code createDatabaseIfNotExist}.
12+
*/
13+
class InitDbIntegrationTest extends IntegrationTestSupport {
14+
15+
private static final String DATABASE = "test-" + ThreadLocalRandom.current().nextInt(10000);
16+
17+
InitDbIntegrationTest() {
18+
super(configuration(
19+
DATABASE, true, false,
20+
null, null
21+
));
22+
}
23+
24+
@Test
25+
void shouldCreateDatabase() {
26+
complete(conn -> conn.createStatement("SHOW DATABASES")
27+
.execute()
28+
.flatMap(it -> it.map((row, rowMetadata) -> row.get(0, String.class)))
29+
.collect(Collectors.toSet())
30+
.doOnNext(it -> assertThat(it).contains(DATABASE))
31+
.thenMany(conn.createStatement("DROP DATABASE `" + DATABASE + "`")
32+
.execute()
33+
.flatMap(MySqlResult::getRowsUpdated)));
34+
}
35+
}

src/test/java/io/asyncer/r2dbc/mysql/IntegrationTestSupport.java

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -71,8 +71,10 @@ static Mono<Long> extractRowsUpdated(Result result) {
7171
return Mono.from(result.getRowsUpdated());
7272
}
7373

74-
static MySqlConnectionConfiguration configuration(boolean autodetectExtensions,
75-
@Nullable ZoneId serverZoneId, @Nullable Predicate<String> preferPrepared) {
74+
static MySqlConnectionConfiguration configuration(
75+
String database, boolean createDatabaseIfNotExist, boolean autodetectExtensions,
76+
@Nullable ZoneId serverZoneId, @Nullable Predicate<String> preferPrepared
77+
) {
7678
String password = System.getProperty("test.mysql.password");
7779

7880
assertThat(password).withFailMessage("Property test.mysql.password must exists and not be empty")
@@ -84,7 +86,8 @@ static MySqlConnectionConfiguration configuration(boolean autodetectExtensions,
8486
.connectTimeout(Duration.ofSeconds(3))
8587
.user("root")
8688
.password(password)
87-
.database("r2dbc")
89+
.database(database)
90+
.createDatabaseIfNotExist(createDatabaseIfNotExist)
8891
.autodetectExtensions(autodetectExtensions);
8992

9093
if (serverZoneId != null) {

src/test/java/io/asyncer/r2dbc/mysql/JacksonPrepareIntegrationTest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
class JacksonPrepareIntegrationTest extends JacksonIntegrationTestSupport {
2323

2424
JacksonPrepareIntegrationTest() {
25-
super(configuration(true, null, sql -> false));
25+
super(configuration("r2dbc", false, true, null, sql -> false));
2626
}
2727
}
2828

src/test/java/io/asyncer/r2dbc/mysql/JacksonTextIntegrationTest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,6 @@
2222
class JacksonTextIntegrationTest extends JacksonIntegrationTestSupport {
2323

2424
JacksonTextIntegrationTest() {
25-
super(configuration(true, null, null));
25+
super(configuration("r2dbc", false, true, null, null));
2626
}
2727
}

0 commit comments

Comments
 (0)