Skip to content

Commit

Permalink
before fixing PaxeNetworkTest
Browse files Browse the repository at this point in the history
  • Loading branch information
simbo1905 committed Jan 19, 2025
1 parent aacd7be commit 4d6d7e9
Showing 1 changed file with 27 additions and 19 deletions.
46 changes: 27 additions & 19 deletions trex-paxe/src/main/java/com/github/trex_paxos/paxe/PaxeNetwork.java
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package com.github.trex_paxos.paxe;

import com.github.trex_paxos.network.*;
import static com.github.trex_paxos.paxe.PaxeLogger.LOGGER;

import java.io.IOException;
import java.net.InetSocketAddress;
Expand All @@ -12,9 +11,12 @@
import java.nio.channels.Selector;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.function.Consumer;
import java.util.function.Supplier;

import static com.github.trex_paxos.paxe.PaxeLogger.LOGGER;

public final class PaxeNetwork implements NetworkLayer, AutoCloseable {
private static final int MAX_PACKET_SIZE = 65507; // UDP max size
private static final int HEADER_SIZE = 8; // from(2) + to(2) + channel(2) + length(2)
Expand All @@ -30,8 +32,11 @@ public final class PaxeNetwork implements NetworkLayer, AutoCloseable {
volatile boolean running;
private Thread receiver;

record DirectBuffer(ByteBuffer sendBuffer, ByteBuffer receiveBuffer) {}
record PendingMessage(Channel channel, byte[] serializedData) {}
record DirectBuffer(ByteBuffer sendBuffer, ByteBuffer receiveBuffer) {
}

record PendingMessage(Channel channel, byte[] serializedData) {
}

private final Map<NodeId, Queue<PendingMessage>> pendingMessages = new ConcurrentHashMap<>();
private static final int MAX_BUFFERED_BYTES = 65000;
Expand Down Expand Up @@ -82,29 +87,29 @@ public <T> void send(Channel channel, NodeId to, T msg) {
LOGGER.finest(() -> String.format("Encrypting message for %d, key %s", to.id(),
finalKey != null ? "present" : "missing"));
if (key == null) {
// Get handshake message
// Buffer message
byte[] serialized = serializeMessage(msg);
Queue<PendingMessage> queue = pendingMessages.computeIfAbsent(to, k -> new ConcurrentLinkedQueue<>());
int queueBytes = queue.stream().mapToInt(m -> m.serializedData().length).sum();

LOGGER.finest(() -> String.format("Buffering %d bytes for %s (total %d)", serialized.length, to, queueBytes));

if (queueBytes + serialized.length > MAX_BUFFERED_BYTES) {
throw new IllegalStateException("Message buffer full for " + to);
}
queue.add(new PendingMessage(channel, serialized));

// Initiate handshake
var handshake = keyManager.initiateHandshake(to);
if (handshake.isPresent()) {
// Send handshake on KEY_EXCHANGE channel
send(Channel.KEY_EXCHANGE, to, handshake.get());
}
// Wait briefly for key establishment
try {
// FIXME should not do this as node may be down
Thread.sleep(10);
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
}
// Retry getting key
key = keyManager.sessionKeys.get(to);
if (key == null) {
throw new IllegalStateException("Failed to establish session key with " + to);
}
return;
}
payload = PaxeCrypto.encrypt(serializeMessage(msg), key);
}

buffer.putShort((short)payload.length);
buffer.putShort((short) payload.length);
buffer.put(payload);
buffer.flip();

Expand Down Expand Up @@ -171,6 +176,7 @@ private void readFromChannel() throws IOException {
SocketAddress sender = channel.receive(buffer);
if (sender == null) continue;


buffer.flip();
if (buffer.remaining() < HEADER_SIZE) {
LOGGER.finest(() -> String.format("Received undersized packet from %s: %d bytes", sender, buffer.remaining()));
Expand All @@ -193,6 +199,8 @@ private void readFromChannel() throws IOException {

Channel msgChannel = new Channel(channelId);

LOGGER.finest(() -> String.format("Processing message from %d on channel %s", fromId, msgChannel));

// Extract payload
byte[] payload = new byte[length];
buffer.get(payload);
Expand Down Expand Up @@ -238,7 +246,7 @@ void handleKeyExchange(short fromId, byte[] payload) {
}

private byte[] serializeMessage(Object msg) {
return ((byte[])msg);
return ((byte[]) msg);
}

private SocketAddress resolveAddress(NodeId to) {
Expand Down

0 comments on commit 4d6d7e9

Please sign in to comment.