Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support connecting to different databases #121

Merged
merged 3 commits into from
Apr 29, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,14 @@
package com.google.cloud.spanner.pgadapter;

import com.google.api.core.InternalApi;
import com.google.cloud.spanner.Dialect;
import com.google.cloud.spanner.ErrorCode;
import com.google.cloud.spanner.SpannerException;
import com.google.cloud.spanner.SpannerExceptionFactory;
import com.google.cloud.spanner.connection.Connection;
import com.google.cloud.spanner.connection.ConnectionOptions;
import com.google.cloud.spanner.pgadapter.metadata.ConnectionMetadata;
import com.google.cloud.spanner.pgadapter.metadata.OptionsMetadata;
import com.google.cloud.spanner.pgadapter.statements.IntermediatePortalStatement;
import com.google.cloud.spanner.pgadapter.statements.IntermediatePreparedStatement;
import com.google.cloud.spanner.pgadapter.statements.IntermediateStatement;
Expand Down Expand Up @@ -77,20 +81,13 @@ public class ConnectionHandler extends Thread {
private static final AtomicInteger incrementingConnectionId = new AtomicInteger(0);
private ConnectionMetadata connectionMetadata;
private WireMessage message;
private final Connection spannerConnection;
private Connection spannerConnection;

ConnectionHandler(ProxyServer server, Socket socket) {
super("ConnectionHandler-" + CONNECTION_HANDLER_ID_GENERATOR.incrementAndGet());
this.server = server;
this.socket = socket;
this.secret = new SecureRandom().nextInt();
String uri = server.getOptions().getConnectionURL();
if (uri.startsWith("jdbc:")) {
uri = uri.substring("jdbc:".length());
}
uri = appendPropertiesToUrl(uri, server.getProperties());
ConnectionOptions connectionOptions = ConnectionOptions.newBuilder().setUri(uri).build();
this.spannerConnection = connectionOptions.getConnection();
setDaemon(true);
logger.log(
Level.INFO,
Expand All @@ -100,6 +97,37 @@ public class ConnectionHandler extends Thread {
getName(), socket.getInetAddress().getHostAddress()));
}

@InternalApi
public void connectToSpanner(String database) {
OptionsMetadata options = getServer().getOptions();
String uri =
options.hasDefaultConnectionUrl()
? options.getDefaultConnectionUrl()
: options.buildConnectionURL(database);
if (uri.startsWith("jdbc:")) {
uri = uri.substring("jdbc:".length());
}
uri = appendPropertiesToUrl(uri, getServer().getProperties());
ConnectionOptions connectionOptions = ConnectionOptions.newBuilder().setUri(uri).build();
Connection spannerConnection = connectionOptions.getConnection();
try {
// Note: Calling getDialect() will cause a SpannerException if the connection itself is
// invalid, for example as a result of the credentials being wrong.
if (spannerConnection.getDialect() != Dialect.POSTGRESQL) {
throw SpannerExceptionFactory.newSpannerException(
ErrorCode.INVALID_ARGUMENT,
String.format(
"The database uses dialect %s. Currently PGAdapter only supports connections to PostgreSQL dialect databases. "
+ "These can be created using https://cloud.google.com/spanner/docs/quickstart-console#postgresql",
spannerConnection.getDialect()));
}
} catch (SpannerException e) {
spannerConnection.close();
throw e;
}
this.spannerConnection = spannerConnection;
}

private String appendPropertiesToUrl(String url, Properties info) {
if (info == null || info.isEmpty()) {
return url;
Expand Down
22 changes: 2 additions & 20 deletions src/main/java/com/google/cloud/spanner/pgadapter/ProxyServer.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,7 @@
package com.google.cloud.spanner.pgadapter;

import com.google.api.core.AbstractApiService;
import com.google.cloud.spanner.Dialect;
import com.google.cloud.spanner.ErrorCode;
import com.google.cloud.spanner.SpannerException;
import com.google.cloud.spanner.SpannerExceptionFactory;
import com.google.cloud.spanner.pgadapter.ConnectionHandler.QueryMode;
import com.google.cloud.spanner.pgadapter.metadata.OptionsMetadata;
import com.google.cloud.spanner.pgadapter.metadata.OptionsMetadata.TextFormat;
Expand Down Expand Up @@ -221,23 +218,8 @@ private void handleConnectionError(SpannerException exception, Socket socket) {
*/
void createConnectionHandler(Socket socket) {
ConnectionHandler handler = new ConnectionHandler(this, socket);
try {
// Note: Calling getDialect() will cause a SpannerException if the connection itself is
// invalid, for example as a result of the credentials being wrong.
if (handler.getSpannerConnection().getDialect() != Dialect.POSTGRESQL) {
throw SpannerExceptionFactory.newSpannerException(
ErrorCode.INVALID_ARGUMENT,
String.format(
"The database uses dialect %s. Currently PGAdapter only supports connections to PostgreSQL dialect databases. "
+ "These can be created using https://cloud.google.com/spanner/docs/quickstart-console#postgresql",
handler.getSpannerConnection().getDialect()));
}
register(handler);
handler.start();
} catch (Exception e) {
handler.getSpannerConnection().close();
throw e;
}
register(handler);
handler.start();
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import com.google.auth.oauth2.GoogleCredentials;
import com.google.cloud.spanner.SpannerExceptionFactory;
import com.google.cloud.spanner.pgadapter.Server;
import com.google.common.base.Preconditions;
import com.google.common.base.Strings;
import java.io.IOException;
import java.util.HashMap;
Expand Down Expand Up @@ -60,8 +61,9 @@ public class OptionsMetadata {
private static final String OPTION_SERVER_VERSION = "v";
private static final String OPTION_DEBUG_MODE = "debug";

private final CommandLine commandLine;
private final CommandMetadataParser commandMetadataParser;
private final String connectionURL;
private final String defaultConnectionUrl;
private final int proxyPort;
private final TextFormat textFormat;
private final boolean binaryFormat;
Expand All @@ -75,9 +77,14 @@ public class OptionsMetadata {
private final boolean debugMode;

public OptionsMetadata(String[] args) {
CommandLine commandLine = buildOptions(args);
this.commandLine = buildOptions(args);
this.commandMetadataParser = new CommandMetadataParser();
this.connectionURL = buildConnectionURL(commandLine);
if (this.commandLine.hasOption(OPTION_DATABASE_NAME)) {
this.defaultConnectionUrl =
buildConnectionURL(this.commandLine.getOptionValue(OPTION_DATABASE_NAME));
} else {
this.defaultConnectionUrl = null;
}
this.proxyPort = buildProxyPort(commandLine);
this.textFormat = TextFormat.POSTGRESQL;
this.binaryFormat = commandLine.hasOption(OPTION_BINARY_FORMAT);
Expand All @@ -94,16 +101,17 @@ public OptionsMetadata(String[] args) {
}

public OptionsMetadata(
String connectionURL,
String defaultConnectionUrl,
int proxyPort,
TextFormat textFormat,
boolean forceBinary,
boolean authenticate,
boolean requiresMatcher,
boolean replaceJdbcMetadataQueries,
JSONObject commandMetadata) {
this.commandLine = null;
this.commandMetadataParser = new CommandMetadataParser();
this.connectionURL = connectionURL;
this.defaultConnectionUrl = defaultConnectionUrl;
this.proxyPort = proxyPort;
this.textFormat = textFormat;
this.binaryFormat = forceBinary;
Expand Down Expand Up @@ -151,10 +159,9 @@ private int buildProxyPort(CommandLine commandLine) {
* Get credential file path from either command line or application default. If neither throw
* error.
*
* @param commandLine The parsed options for CLI
* @return The absolute path of the credentials file.
*/
private String buildCredentialsFile(CommandLine commandLine) {
private String buildCredentialsFile() {
if (!commandLine.hasOption(OPTION_CREDENTIALS_FILE)) {
try {
// This will throw an IOException if no default credentials are available.
Expand All @@ -169,10 +176,10 @@ private String buildCredentialsFile(CommandLine commandLine) {
/**
* Takes user inputs and builds a JDBC connection string from them.
*
* @param commandLine The parsed options for CLI
* @return The parsed JDBC connection string.
*/
private String buildConnectionURL(CommandLine commandLine) {
public String buildConnectionURL(String database) {
Preconditions.checkNotNull(database);
String host = commandLine.getOptionValue(OPTION_SPANNER_ENDPOINT, "");
String jdbcEndpoint;
if (host.isEmpty()) {
Expand All @@ -198,10 +205,10 @@ private String buildConnectionURL(CommandLine commandLine) {
+ ";userAgent=%s",
commandLine.getOptionValue(OPTION_PROJECT_ID),
commandLine.getOptionValue(OPTION_INSTANCE_ID),
commandLine.getOptionValue(OPTION_DATABASE_NAME),
database,
DEFAULT_USER_AGENT);

String credentials = buildCredentialsFile(commandLine);
String credentials = buildCredentialsFile();
if (!Strings.isNullOrEmpty(credentials)) {
url = String.format("%s;credentials=%s", url, credentials);
}
Expand Down Expand Up @@ -270,11 +277,15 @@ private CommandLine buildOptions(String[] args) {
"instance",
true,
"The id of the Spanner instance within the GCP project.");
options.addRequiredOption(
options.addOption(
OPTION_DATABASE_NAME,
"database",
true,
"The name of the Spanner database within the GCP project.");
"The default Spanner database within the GCP project to use. "
+ "If specified, PGAdapter will always connect to this database. "
+ "Any database name in the connection request from the client will be ignored. "
+ "Omit this option to be able to connect to different databases using a single "
+ "PGAdapter instance.");
options.addOption(
OPTION_CREDENTIALS_FILE,
"credentials-file",
Expand Down Expand Up @@ -386,8 +397,32 @@ public JSONObject getCommandMetadataJSON() {
return this.commandMetadataJSON;
}

/**
* @deprecated use {@link #getDefaultConnectionUrl()}
* @return the default connection URL that is used by the server.
*/
@Deprecated
public String getConnectionURL() {
return this.connectionURL;
return this.defaultConnectionUrl;
}

/**
* @return true if the server uses a default connection URL and ignores the database in a
* connection request
*/
public boolean hasDefaultConnectionUrl() {
return this.defaultConnectionUrl != null;
}

/**
* Returns the default connection URL that is used by the server. If a default connection URL has
* been set, the database parameter in a connection request will be ignored, and the database in
* this connection URL will be used instead.
*
* @return the default connection URL that is used by the server.
*/
public String getDefaultConnectionUrl() {
return defaultConnectionUrl;
}

public int getProxyPort() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ public class StartupMessage extends BootstrapMessage {
public static final int IDENTIFIER = 196608; // First Hextet: 3 (version), Second Hextet: 0

private final boolean authenticate;
private Map<String, String> parameters;
private final Map<String, String> parameters;

public StartupMessage(ConnectionHandler connection, int length) throws Exception {
super(connection, length);
Expand All @@ -44,6 +44,7 @@ public StartupMessage(ConnectionHandler connection, int length) throws Exception
@Override
protected void sendPayload() throws Exception {
if (!authenticate) {
this.connection.connectToSpanner(this.parameters.get("database"));
sendStartupMessage(
this.outputStream,
this.connection.getConnectionId(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -251,11 +251,11 @@ protected <T extends WireMessage> List<T> getWireMessagesOfType(Class<T> type) {

@BeforeClass
public static void startMockSpannerAndPgAdapterServers() throws Exception {
doStartMockSpannerAndPgAdapterServers(Collections.emptyList());
doStartMockSpannerAndPgAdapterServers("d", Collections.emptyList());
}

protected static void doStartMockSpannerAndPgAdapterServers(
Iterable<String> extraPGAdapterOptions) throws Exception {
String defaultDatabase, Iterable<String> extraPGAdapterOptions) throws Exception {
mockSpanner = new MockSpannerServiceImpl();
mockSpanner.setAbortProbability(0.0D); // We don't want any unpredictable aborted transactions.
mockSpanner.putStatementResult(StatementResult.query(SELECT1, SELECT1_RESULTSET));
Expand Down Expand Up @@ -303,24 +303,21 @@ public <ReqT, RespT> Listener<ReqT> interceptCall(
.start();

ImmutableList.Builder<String> argsListBuilder =
ImmutableList.<String>builder()
.add(
"-p",
"p",
"-i",
"i",
"-d",
"d",
"-jdbc",
"-debug",
"-c",
"", // empty credentials file, as we are using a plain text connection.
"-s",
"0", // port 0 to let the OS pick an available port
"-e",
String.format("localhost:%d", spannerServer.getPort()),
"-r",
"usePlainText=true;");
ImmutableList.<String>builder().add("-p", "p", "-i", "i");
if (defaultDatabase != null) {
argsListBuilder.add("-d", defaultDatabase);
}
argsListBuilder.add(
"-jdbc",
"-debug",
"-c",
"", // empty credentials file, as we are using a plain text connection.
"-s",
"0", // port 0 to let the OS pick an available port
"-e",
String.format("localhost:%d", spannerServer.getPort()),
"-r",
"usePlainText=true;");
argsListBuilder.addAll(extraPGAdapterOptions);
String[] args = argsListBuilder.build().toArray(new String[0]);
pgServer = new ProxyServer(new OptionsMetadata(args));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,12 @@
import com.google.cloud.spanner.Statement;
import com.google.common.collect.ImmutableList;
import com.google.spanner.v1.ExecuteBatchDmlRequest;
import com.google.spanner.v1.ExecuteSqlRequest;
import com.google.spanner.v1.SessionName;
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.SQLException;
import java.util.List;
import org.junit.BeforeClass;
import org.junit.Test;
import org.junit.runner.RunWith;
Expand All @@ -41,7 +44,8 @@ public static void loadPgJdbcDriver() throws Exception {

@BeforeClass
public static void startMockSpannerAndPgAdapterServers() throws Exception {
doStartMockSpannerAndPgAdapterServers(ImmutableList.of("-q"));
// Start PGAdapter in psql mode without a default database.
doStartMockSpannerAndPgAdapterServers(null, ImmutableList.of("-q"));

mockSpanner.putStatementResults(
StatementResult.update(Statement.of(INSERT1), 1L),
Expand All @@ -53,16 +57,31 @@ public static void startMockSpannerAndPgAdapterServers() throws Exception {
* mode for queries and DML statements. This makes the JDBC driver behave in (much) the same way
* as psql.
*/
private String createUrl() {
private String createUrl(String database) {
return String.format(
"jdbc:postgresql://localhost:%d/?preferQueryMode=simple", pgServer.getLocalPort());
"jdbc:postgresql://localhost:%d/%s?preferQueryMode=simple",
pgServer.getLocalPort(), database);
}

@Test
public void testTwoInserts() throws SQLException {
String sql = "insert into foo values (1); insert into foo values (2);";
public void testConnectToDifferentDatabases() throws SQLException {
final ImmutableList<String> databases = ImmutableList.of("db1", "db2");
for (String database : databases) {
try (Connection connection = DriverManager.getConnection(createUrl(database))) {
connection.createStatement().execute(INSERT1);
}
}

try (Connection connection = DriverManager.getConnection(createUrl())) {
List<ExecuteSqlRequest> requests = mockSpanner.getRequestsOfType(ExecuteSqlRequest.class);
assertEquals(databases.size(), requests.size());
for (int i = 0; i < requests.size(); i++) {
assertEquals(databases.get(i), SessionName.parse(requests.get(i).getSession()).getDatabase());
}
}

@Test
public void testTwoInserts() throws SQLException {
try (Connection connection = DriverManager.getConnection(createUrl("my-db"))) {
connection.createStatement().execute(String.format("%s; %s", INSERT1, INSERT2));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,8 @@ public class ITJdbcMetadataTest implements IntegrationTest {
public static void setup() {
testEnv.setUp();
database = testEnv.createDatabase(getDdlStatements());
testEnv.startPGAdapterServer(database.getId(), getAdditionalPGAdapterOptions());
testEnv.startPGAdapterServerWithDefaultDatabase(
database.getId(), getAdditionalPGAdapterOptions());
}

@AfterClass
Expand Down
Loading