Skip to content

Commit

Permalink
added option to register packets to multiple groups with different IDs
Browse files Browse the repository at this point in the history
  • Loading branch information
Pesekjak committed May 27, 2024
1 parent b096773 commit 23b1d1c
Show file tree
Hide file tree
Showing 15 changed files with 217 additions and 46 deletions.
2 changes: 1 addition & 1 deletion gradle.properties
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Group and version
group = org.machinemc.paklet
version = 1.1.1
version = 1.2

# Dependency versions
jetbrainsAnnotations = 24.1.0
Expand Down
5 changes: 4 additions & 1 deletion paklet-api/src/main/java/org/machinemc/paklet/Packet.java
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,13 @@
* <p>
* Compare to packet IDs, packets do not store information about their groups when
* they are serialized using {@link PacketFactory}.
* <p>
* One packet can have multiple different groups, for registering the packet to each group
* under different ID, see {@link Packet#DYNAMIC_PACKET} and {@link PacketRegistrationContext}.
*
* @return group of the packet
*/
String group() default DEFAULT;
String[] group() default DEFAULT;

/**
* Specifies class that is used as catalogue (identifier) for the packet.
Expand Down
39 changes: 24 additions & 15 deletions paklet-api/src/main/java/org/machinemc/paklet/PacketFactory.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
import org.jetbrains.annotations.Unmodifiable;

import java.io.InputStream;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Optional;
import java.util.function.Function;

Expand Down Expand Up @@ -46,7 +48,7 @@ public interface PacketFactory {
* @param reader reader for the packet
* @param writer writer for the packet
* @param packetID packet ID used to register this packet
* @param group group name of this packet
* @param group group the packet should be registered to
* @param <PacketType> packet
*
* @throws IllegalArgumentException if packet with the same ID and group is already registered
Expand Down Expand Up @@ -82,15 +84,19 @@ public interface PacketFactory {
* Removes packet with given type.
*
* @param packetClass class of the packet
* @return whether the packet has been removed successfully
* @return array of groups from where the packet has been removed
* @param <PacketType> packet
*/
default <PacketType> boolean removePacket(Class<PacketType> packetClass) {
int id = getPacketID(packetClass);
if (id == -1) return false;
String group = getPacketGroup(packetClass).orElse(null);
if (group == null) return false;
return removePacket(id, group);
default <PacketType> String[] removePacket(Class<PacketType> packetClass) {
String[] groups = getPacketGroup(packetClass).orElse(new String[0]);
List<String> removed = new ArrayList<>();
for (String group : groups) {
int id = getPacketID(packetClass, group);
if (id == -1) continue;
if (!removePacket(id, group)) continue;
removed.add(group);
}
return removed.toArray(String[]::new);
}

/**
Expand All @@ -115,11 +121,12 @@ default <PacketType> boolean removePacket(Class<PacketType> packetClass) {
* Returns ID for given registered packet class.
*
* @param packetClass packet class
* @param group group of the packet
* @return packet ID of given packet class, or {@code -1} if the class is
* not registered
* @param <PacketType> packet
*/
<PacketType> int getPacketID(Class<PacketType> packetClass);
<PacketType> int getPacketID(Class<PacketType> packetClass, String group);

/**
* Returns group for given registered packet class.
Expand All @@ -128,17 +135,18 @@ default <PacketType> boolean removePacket(Class<PacketType> packetClass) {
* @return packet class of given packet class
* @param <PacketType> packet
*/
<PacketType> Optional<String> getPacketGroup(Class<PacketType> packetClass);
<PacketType> Optional<String[]> getPacketGroup(Class<PacketType> packetClass);

/**
* Checks whether the given packet class is registered.
* Checks whether the given packet class is registered in given group.
*
* @param packetClass packet class
* @return whether the packet class is registered
* @param group group
* @return whether the packet class is registered in the group
* @param <PacketType> packet
*/
default <PacketType> boolean isRegistered(Class<PacketType> packetClass) {
return getPacketID(packetClass) != -1;
default <PacketType> boolean isRegistered(Class<PacketType> packetClass, String group) {
return getPacketID(packetClass, group) != -1;
}

/**
Expand Down Expand Up @@ -186,9 +194,10 @@ default boolean isRegistered(int packetID, String group) {
* Writes packet to the provided data visitor.
*
* @param packet packet to write
* @param group group
* @param visitor visitor
* @param <PacketType> packet
*/
<PacketType> void write(PacketType packet, DataVisitor visitor);
<PacketType> void write(PacketType packet, String group, DataVisitor visitor);

}
6 changes: 3 additions & 3 deletions paklet-api/src/main/java/org/machinemc/paklet/PacketID.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@
import java.lang.annotation.Target;

/**
* Used to annotate static int fields of packet classes with
* dynamic packet IDs.
* Used to annotate static int fields (or no argument methods returning int)
* of packet classes to compute dynamic packet IDs.
*
* @see Packet#DYNAMIC_PACKET
*/
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.FIELD)
@Target({ElementType.FIELD, ElementType.METHOD})
public @interface PacketID {
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
package org.machinemc.paklet;

import java.util.Objects;

/**
* Allows to access the additional information during dynamic packet registration
* for packets with {@link Packet#DYNAMIC_PACKET} ID.
*/
public class PacketRegistrationContext {

protected static final ThreadLocal<PacketRegistrationContext> threadLocal = ThreadLocal.withInitial(PacketRegistrationContext::new);

private final String group;

/**
* Returns the current packet registration context.
* <p>
* This can be called only within methods annotated with {@link PacketID}, resolving
* packet IDs for packets with {@link Packet#DYNAMIC_PACKET} ID.
*
* @return current packet registration context
*/
public static PacketRegistrationContext get() {
PacketRegistrationContext context = threadLocal.get();
if (context.group == null) throw new RuntimeException("Called outside of dynamic packet registration context");
return context;
}

private PacketRegistrationContext() {
group = null;
}

/**
* Creates new packet registration context with given group.
*
* @param group group
*/
protected PacketRegistrationContext(String group) {
this.group = Objects.requireNonNull(group, "Packet group can not be null");
}

/**
* Returns packet group used to register the packet.
*
* @return current packet group
*/
public String getPacketGroup() {
return group;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,12 @@

import java.io.InputStream;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.Exchanger;
import java.util.function.Function;

/**
Expand All @@ -23,7 +26,7 @@ public class PacketFactoryImpl implements PacketFactory {
private final PacketEncoder encoder;
private final SerializerProvider serializerProvider;

private final Map<Class<?>, PacketGroup> packet2Group = new ConcurrentHashMap<>();
private final Map<Class<?>, List<PacketGroup>> packet2Group = new ConcurrentHashMap<>();
private final Map<String, PacketGroup> groups = new ConcurrentHashMap<>();

public PacketFactoryImpl(PacketEncoder encoder, SerializerProvider serializerProvider) {
Expand Down Expand Up @@ -54,17 +57,23 @@ public <PacketType> void addPacket(Class<PacketType> packetClass) {
public <PacketType> void addPacket(Class<PacketType> packetClass, PacketReader<PacketType> reader, PacketWriter<PacketType> writer) {
Packet annotation = packetClass.getAnnotation(Packet.class);
if (annotation == null) throw new IllegalArgumentException("Class " + packetClass.getName() + " is not a valid packet class");
addPacket(packetClass, reader, writer, computePacketID(packetClass), annotation.group());
for (String group : annotation.group())
addPacket(packetClass, reader, writer, computePacketID(packetClass, group), group);
}

@Override
public <PacketType> void addPacket(Class<PacketType> packetClass, PacketReader<PacketType> reader, PacketWriter<PacketType> writer, int packetID, String group) {
if (packetClass == null || reader == null || writer == null) throw new NullPointerException();
if (packetID == Packet.INVALID_PACKET) return; // invalid packets should be ignored
if (packetID < 0) throw new IllegalArgumentException("Invalid packet ID for packet " + packetClass.getName());
PacketGroup packetGroup = groups.computeIfAbsent(group, PacketGroup::new);

PacketGroup packetGroup = this.groups.computeIfAbsent(group, PacketGroup::new);
packetGroup.addPacket(packetID, packetClass, reader, writer); // throws illegal exception if packet ID already exists
packet2Group.put(packetClass, packetGroup);

List<PacketGroup> groupsList = new CopyOnWriteArrayList<>(packet2Group.computeIfAbsent(packetClass, __ -> new ArrayList<>()));
groupsList.add(packetGroup);

packet2Group.put(packetClass, Collections.unmodifiableList(groupsList));
}

@Override
Expand Down Expand Up @@ -96,15 +105,15 @@ public <PacketType> Optional<Class<PacketType>> getPacketClass(int packetID, Str
}

@Override
public <PacketType> int getPacketID(Class<PacketType> packetClass) {
PacketGroup packetGroup = packet2Group.get(packetClass);
public <PacketType> int getPacketID(Class<PacketType> packetClass, String group) {
PacketGroup packetGroup = groups.get(group);
if (packetGroup == null) return -1;
return packetGroup.getID(packetClass);
}

@Override
public <PacketType> Optional<String> getPacketGroup(Class<PacketType> packetClass) {
return Optional.ofNullable(packet2Group.get(packetClass)).map(PacketGroup::getName);
public <PacketType> Optional<String[]> getPacketGroup(Class<PacketType> packetClass) {
return Optional.ofNullable(packet2Group.get(packetClass)).map(l -> l.stream().map(PacketGroup::getName).toArray(String[]::new));
}

@Override
Expand Down Expand Up @@ -138,10 +147,10 @@ public <PacketType> PacketType create(int packetID, String group, DataVisitor vi

@Override
@SuppressWarnings("unchecked")
public <PacketType> void write(PacketType packet, DataVisitor visitor) {
public <PacketType> void write(PacketType packet, String group, DataVisitor visitor) {
Class<?> packetClass = packet.getClass();
PacketGroup packetGroup = packet2Group.get(packetClass);
if (packetGroup == null) throw new NullPointerException("Packet " + packetClass.getName() + " is not assigned to any group");
PacketGroup packetGroup = groups.get(group);
if (packetGroup == null) throw new NullPointerException("Group " + group + " is not registered");

int packetID = packetGroup.getID(packetClass);
if (packetID < 0) throw new IllegalArgumentException("Invalid packet ID: " + packetID);
Expand All @@ -156,26 +165,64 @@ public <PacketType> void write(PacketType packet, DataVisitor visitor) {
encoder.encode(visitor, serializerProvider, packetGroup.getName(), new PacketEncoder.Encoded(packetID, packetData));
}

private int computePacketID(Class<?> packetClass) {
private int computePacketID(Class<?> packetClass, String group) {
Packet annotation = packetClass.getAnnotation(Packet.class);
if (annotation == null) throw new IllegalArgumentException("Class " + packetClass.getName() + " is not a valid packet class");

if (annotation.id() == Packet.INVALID_PACKET) return Packet.INVALID_PACKET;

if (annotation.id() == Packet.DYNAMIC_PACKET) {

Field[] packetIDFields = Arrays.stream(packetClass.getDeclaredFields())
.filter(f -> Modifier.isStatic(f.getModifiers()))
.filter(f -> f.getType().equals(int.class))
.filter(f -> f.isAnnotationPresent(PacketID.class))
.toArray(Field[]::new);
if (packetIDFields.length == 0) throw new IllegalStateException("Class " + packetClass.getName() + " is missing packet ID field");
if (packetIDFields.length > 1) throw new IllegalStateException("Class " + packetClass.getName() + " has more than one packet ID field");
if (packetIDFields.length == 1) {
try {
packetIDFields[0].setAccessible(true);
return checkPacketID((int) packetIDFields[0].get(null));
} catch (Exception exception) {
throw new RuntimeException(exception);
}
}

Method[] packetIDMethods = Arrays.stream(packetClass.getDeclaredMethods())
.filter(m -> Modifier.isStatic(m.getModifiers()))
.filter(m -> m.getReturnType().equals(int.class))
.filter(m -> m.getParameterTypes().length == 0)
.filter(m -> m.isAnnotationPresent(PacketID.class))
.toArray(Method[]::new);
if (packetIDMethods.length == 0) throw new IllegalStateException("Class " + packetClass.getName() + " is missing packet ID field or method");
if (packetIDMethods.length > 1) throw new IllegalStateException("Class " + packetClass.getName() + " has more than one packet ID method");
try {
packetIDFields[0].setAccessible(true);
return (int) packetIDFields[0].get(null);
packetIDMethods[0].setAccessible(true);
Exchanger<Integer> idResolver = new Exchanger<>();
Thread.ofVirtual().start(() -> {
try {
PacketRegistrationContext.threadLocal.set(new PacketRegistrationContext(group));
idResolver.exchange((int) packetIDMethods[0].invoke(null));
} catch (Throwable throwable) {
try {
idResolver.exchange(Packet.DYNAMIC_PACKET);
} catch (InterruptedException exception) {
throw new RuntimeException(exception);
}
}
});
int resolved = idResolver.exchange(null);
return checkPacketID(resolved);
} catch (Exception exception) {
throw new RuntimeException(exception);
}
}
return annotation.id();
return checkPacketID(annotation.id());
}

private int checkPacketID(int id) {
if (id > 0 || id == Packet.INVALID_PACKET) return id;
throw new RuntimeException("Invalid packet ID: " + id);
}

static class PacketGroup {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ public void doNotPrefixTest() {
packet.signature[1] = 2;
packet.signature[2] = 3;

factory.write(packet, visitor);
factory.write(packet, Packet.DEFAULT, visitor);

assert visitor.writerIndex() == 257;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ public void customSerializerTest() {
CustomSerializerPacket packet = new CustomSerializerPacket();
packet.content = "Hello World";

factory.write(packet, visitor);
factory.write(packet, Packet.DEFAULT, visitor);
CustomSerializerPacket packetClone = factory.create(Packet.DEFAULT, visitor);

assert packetClone.content.equals(packet.content);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
package org.machinemc.paklet.test;

import io.netty.buffer.Unpooled;
import org.junit.jupiter.api.Test;
import org.machinemc.paklet.DataVisitor;
import org.machinemc.paklet.PacketFactory;
import org.machinemc.paklet.netty.NettyDataVisitor;
import org.machinemc.paklet.serialization.VarIntSerializer;
import org.machinemc.paklet.test.packet.DynamicPacket;

public class DynamicPacketTest {

@Test
public void testDynamicPacketIDs() {
PacketFactory factory = TestUtil.createFactory();

assert factory.getPacketID(DynamicPacket.class, "one") == 21;
assert factory.getPacketID(DynamicPacket.class, "two") == 22;
assert factory.getPacketID(DynamicPacket.class, "three") == 23;

DataVisitor visitor = new NettyDataVisitor(Unpooled.buffer());

DynamicPacket packet = new DynamicPacket();
packet.value = 15;

factory.write(packet, "two", visitor);

assert visitor.read(null, new VarIntSerializer()) == 22;

visitor.readerIndex(0);

DynamicPacket packetClone = factory.create("two", visitor);

assert packetClone.value == packet.value;
}

}
Loading

0 comments on commit 23b1d1c

Please sign in to comment.