Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: rock connection #764

Merged
merged 6 commits into from
Sep 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions r/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ dependencies {
implementation 'com.google.auto.value:auto-value-annotations:1.10.4'
implementation 'org.apache.commons:commons-text:1.11.0'
implementation 'org.json:json:20240303'
implementation 'org.apache.httpcomponents.client5:httpclient5:5.3.1'

//test
testImplementation('org.springframework.boot:spring-boot-starter-test') {
Expand Down
172 changes: 101 additions & 71 deletions r/src/main/java/org/molgenis/r/rock/RockConnection.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,117 +4,130 @@
import com.google.common.collect.Lists;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.util.Base64;
import java.util.Collections;
import java.util.function.Consumer;
import org.molgenis.r.RServerConnection;
import org.molgenis.r.RServerException;
import org.molgenis.r.RServerResult;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.boot.web.client.ClientHttpRequestFactories;
import org.springframework.boot.web.client.ClientHttpRequestFactorySettings;
import org.springframework.core.io.InputStreamResource;
import org.springframework.http.*;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.web.client.ResponseExtractor;
import org.springframework.web.client.RestClient;
import org.springframework.web.client.RestClientException;
import org.springframework.web.client.RestTemplate;
import org.springframework.web.util.UriComponentsBuilder;

public class RockConnection implements RServerConnection {

private static final Logger logger = LoggerFactory.getLogger(RockConnection.class);
public static final String FILE_DOWNLOAD_FAILED = "File download failed: ";
public static final MediaType MEDIATYPE_APPLICATION_RSCRIPT =
MediaType.valueOf("application/x-rscript");
private static final String UPLOAD_FAILED = "File upload failed: ";
private static final String UPLOAD_ENDPOINT = "/_upload";
private static final String EVAL_ENDPOINT = "/_eval";
public static final String DOWNLOAD_ENDPOINT = "/_download";
private static final String PATH = "path";
private static final String OVERWRITE = "overwrite";
private static final String FILE = "file";

private String rockSessionId;

private final RockApplication application;
private final RestClient restClient;

public RockConnection(RockApplication application) throws RServerException {
this.application = application;
this.restClient = getRestClient();
openSession();
}

@Override
public RServerResult eval(String expr, boolean serialized) throws RServerException {
RestTemplate restTemplate = new RestTemplate();
HttpHeaders headers = createHeaders();
headers.setContentType(MediaType.valueOf("application/x-rscript"));

String serverUrl = getRSessionResourceUrl("/_eval");
public RServerResult eval(String expr, boolean serialized) {
String serverUrl = getRSessionResourceUrl(EVAL_ENDPOINT);
if (serialized) {
// accept application/octet-stream
headers.setAccept(
Lists.newArrayList(MediaType.APPLICATION_OCTET_STREAM, MediaType.APPLICATION_JSON));
ResponseEntity<byte[]> response =
restTemplate.exchange(
serverUrl, HttpMethod.POST, new HttpEntity<>(expr, headers), byte[].class);
return new RockResult(response.getBody());
RestClient.RequestBodySpec request =
createRequestForPost(serverUrl, expr, MEDIATYPE_APPLICATION_RSCRIPT);
byte[] responseBody =
request
.headers(
httpHeaders -> {
httpHeaders.setAccept(
Lists.newArrayList(
MediaType.APPLICATION_OCTET_STREAM, MediaType.APPLICATION_JSON));
})
.retrieve()
.body(byte[].class);
return new RockResult(responseBody);
} else {
headers.setAccept(Lists.newArrayList(MediaType.APPLICATION_JSON));
ResponseEntity<String> response =
restTemplate.exchange(
serverUrl, HttpMethod.POST, new HttpEntity<>(expr, headers), String.class);
String jsonSource = response.getBody();
return new RockResult(jsonSource);
RestClient.RequestBodySpec request =
createRequestForPost(serverUrl, expr, MEDIATYPE_APPLICATION_RSCRIPT);
RestClient.ResponseSpec resp =
request
.headers(
httpHeaders -> {
httpHeaders.setAccept(Lists.newArrayList(MediaType.APPLICATION_JSON));
})
.retrieve();
String responseBody = resp.body(String.class);
return new RockResult(responseBody);
}
}

@Override
public void writeFile(String fileName, InputStream in) throws RServerException {
try {
HttpHeaders headers = createHeaders();
headers.setContentType(MediaType.MULTIPART_FORM_DATA);
MultiValueMap<String, Object> body = new LinkedMultiValueMap<>();
body.add("file", new MultiPartInputStreamResource(in, fileName));
HttpEntity<MultiValueMap<String, Object>> requestEntity = new HttpEntity<>(body, headers);
body.add(FILE, new MultiPartInputStreamResource(in, fileName));

String serverUrl = getRSessionResourceUrl("/_upload");
String serverUrl = getRSessionResourceUrl(UPLOAD_ENDPOINT);
UriComponentsBuilder builder =
UriComponentsBuilder.fromHttpUrl(serverUrl)
.queryParam("path", fileName)
.queryParam("overwrite", true);
.queryParam(PATH, fileName)
.queryParam(OVERWRITE, true);

ResponseEntity<Void> response =
postToRestClient(builder.toUriString(), body, MediaType.MULTIPART_FORM_DATA)
.toBodilessEntity();

RestTemplate restTemplate = new RestTemplate();
ResponseEntity<String> response =
restTemplate.postForEntity(builder.toUriString(), requestEntity, String.class);
if (!response.getStatusCode().is2xxSuccessful()) {
logger.error("File upload to {} failed: {}", serverUrl, response.getStatusCode());
throw new RockServerException("File upload failed: " + response.getStatusCode());
throw new RockServerException(UPLOAD_FAILED + response.getStatusCode());
}
} catch (RestClientException e) {
throw new RockServerException("File upload failed", e);
throw new RockServerException(UPLOAD_FAILED, e);
}
}

@Override
public void readFile(String fileName, Consumer<InputStream> inputStreamConsumer)
throws RServerException {
try {
HttpHeaders headers = createHeaders();
headers.setAccept(Collections.singletonList(MediaType.ALL));
String serverUrl = getRSessionResourceUrl(DOWNLOAD_ENDPOINT);

String serverUrl = getRSessionResourceUrl("/_download");
UriComponentsBuilder builder =
UriComponentsBuilder.fromHttpUrl(serverUrl).queryParam("path", fileName);

RestTemplate restTemplate = new RestTemplate();
restTemplate.execute(
builder.build().toUri(),
HttpMethod.GET,
request -> request.getHeaders().putAll(headers),
(ResponseExtractor<Void>)
response -> {
UriComponentsBuilder.fromHttpUrl(serverUrl).queryParam(PATH, fileName);
restClient
.get()
.uri(builder.build().toUri())
.exchange(
(request, response) -> {
if (!response.getStatusCode().is2xxSuccessful()) {
logger.error(
"File download from {} failed: {}", serverUrl, response.getStatusCode());
throw new RuntimeException("File download failed: " + response.getStatusText());
throw new RuntimeException(FILE_DOWNLOAD_FAILED + response.getStatusText());
} else {
inputStreamConsumer.accept(response.getBody());
}
return null;
});
} catch (RestClientException e) {
throw new RockServerException("File download failed", e);
throw new RockServerException(FILE_DOWNLOAD_FAILED, e);
}
}

Expand All @@ -123,12 +136,7 @@ public boolean close() {
if (Strings.isNullOrEmpty(rockSessionId)) return true;

try {
RestTemplate restTemplate = new RestTemplate();
restTemplate.exchange(
getRSessionResourceUrl(""),
HttpMethod.DELETE,
new HttpEntity<>(createHeaders()),
Void.class);
restClient.delete().uri(getRSessionResourceUrl("")).retrieve().toBodilessEntity();
this.rockSessionId = null;
return true;
} catch (RestClientException e) {
Expand All @@ -141,14 +149,13 @@ public boolean close() {

private void openSession() throws RServerException {
try {
RestTemplate restTemplate = new RestTemplate();
ResponseEntity<RockSessionInfo> response =
restTemplate.exchange(
getRSessionsResourceUrl(),
HttpMethod.POST,
new HttpEntity<>(createHeaders()),
RockSessionInfo.class);
RockSessionInfo info = response.getBody();
RockSessionInfo info =
postToRestClient(
getRSessionsResourceUrl(),
new LinkedMultiValueMap<>(),
MediaType.APPLICATION_JSON)
.body(RockSessionInfo.class);
assert info != null;
this.rockSessionId = info.getId();
} catch (RestClientException e) {
throw new RockServerException("Failure when opening a Rock R session", e);
Expand All @@ -163,15 +170,38 @@ private String getRSessionResourceUrl(String path) {
return String.format("%s/r/session/%s%s", application.getUrl(), rockSessionId, path);
}

private HttpHeaders createHeaders() {
return new HttpHeaders() {
{
String auth = application.getUser() + ":" + application.getPassword();
byte[] encodedAuth = Base64.getEncoder().encode(auth.getBytes(StandardCharsets.UTF_8));
String authHeader = "Basic " + new String(encodedAuth);
add("Authorization", authHeader);
}
};
private String getAuthHeader() {
String auth = application.getUser() + ":" + application.getPassword();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

password?

Copy link
Member

@mswertz mswertz Sep 8, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is such that a request from rock can use the password to authenticate so we know if the proces (that armadillo started via rock) has access to the data (from that user)?
Kinda scary, is there not a toke we could use instead? Or a way to at least encrypt this? Or is the internal communication also SSL? Because untrusted R code might now abuse it? Or to go even further, we could simply use a session number to reconnect because the password doesn't add much security?

byte[] encodedAuth = Base64.getEncoder().encode(auth.getBytes(StandardCharsets.UTF_8));
return "Basic " + new String(encodedAuth);
}

private RestClient getRestClient() {
String serverUrl = getRSessionResourceUrl(UPLOAD_ENDPOINT);
String authHeader = getAuthHeader();
ClientHttpRequestFactorySettings settings =
ClientHttpRequestFactorySettings.DEFAULTS
.withConnectTimeout(Duration.ofSeconds(300L))
.withReadTimeout(Duration.ofSeconds(900L));
return RestClient.builder()
.baseUrl(serverUrl)
.requestFactory(ClientHttpRequestFactories.get(settings))
.defaultHeaders(
httpHeaders -> {
httpHeaders.setBasicAuth(authHeader);
httpHeaders.set(HttpHeaders.AUTHORIZATION, authHeader);
})
.build();
}

private RestClient.RequestBodySpec createRequestForPost(
String uriString, Object body, MediaType contentType) {
return restClient.post().uri(uriString).contentType(contentType).body(body);
}

private RestClient.ResponseSpec postToRestClient(
String uriString, Object body, MediaType contentType) {
return createRequestForPost(uriString, body, contentType).retrieve();
}

private static class MultiPartInputStreamResource extends InputStreamResource {
Expand Down