Skip to content

Commit

Permalink
Merge pull request #1439 from marci4/fix/Issue1437
Browse files Browse the repository at this point in the history
Clone PerMessageDeflateExtension values correctly
  • Loading branch information
marci4 authored Nov 11, 2024
2 parents c4bf44e + dfca00b commit 3b29042
Show file tree
Hide file tree
Showing 2 changed files with 141 additions and 43 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,11 @@
import org.java_websocket.extensions.CompressionExtension;
import org.java_websocket.extensions.ExtensionRequestData;
import org.java_websocket.extensions.IExtension;
import org.java_websocket.framing.BinaryFrame;
import org.java_websocket.framing.CloseFrame;
import org.java_websocket.framing.ContinuousFrame;
import org.java_websocket.framing.DataFrame;
import org.java_websocket.framing.Framedata;
import org.java_websocket.framing.FramedataImpl1;
import org.java_websocket.framing.TextFrame;

/**
* PerMessage Deflate Extension (<a href="https://tools.ietf.org/html/rfc7692#section-7">7&#46; The
Expand Down Expand Up @@ -53,23 +51,37 @@ public class PerMessageDeflateExtension extends CompressionExtension {
// For WebSocketClients, this variable holds the extension parameters that client himself has requested.
private Map<String, String> requestedParameters = new LinkedHashMap<>();

private Inflater inflater = new Inflater(true);
private Deflater deflater = new Deflater(Deflater.DEFAULT_COMPRESSION, true);
private final int compressionLevel;

public Inflater getInflater() {
return inflater;
}
private final Inflater inflater;
private final Deflater deflater;

public void setInflater(Inflater inflater) {
this.inflater = inflater;
/**
* Constructor for the PerMessage Deflate Extension (<a href="https://tools.ietf.org/html/rfc7692#section-7">7&#46; Thepermessage-deflate" Extension</a>)
*
* Uses {@link java.util.zip.Deflater#DEFAULT_COMPRESSION} as the compression level for the {@link java.util.zip.Deflater#Deflater(int)}
*/
public PerMessageDeflateExtension() {
this(Deflater.DEFAULT_COMPRESSION);
}

public Deflater getDeflater() {
return deflater;
/**
* Constructor for the PerMessage Deflate Extension (<a href="https://tools.ietf.org/html/rfc7692#section-7">7&#46; Thepermessage-deflate" Extension</a>)
*
* @param compressionLevel The compression level passed to the {@link java.util.zip.Deflater#Deflater(int)}
*/
public PerMessageDeflateExtension(int compressionLevel) {
this.compressionLevel = compressionLevel;
this.deflater = new Deflater(this.compressionLevel, true);
this.inflater = new Inflater(true);
}

public void setDeflater(Deflater deflater) {
this.deflater = deflater;
/**
* Get the compression level used for the compressor.
* @return the compression level
*/
public int getCompressionLevel() {
return this.compressionLevel;
}

/**
Expand Down Expand Up @@ -166,15 +178,15 @@ We can check the getRemaining() method to see whether the data we supplied has b
Note that this behavior doesn't occur if the message is "first compressed and then fragmented".
*/
if (inflater.getRemaining() > 0) {
inflater = new Inflater(true);
inflater.reset();
decompress(inputFrame.getPayloadData().array(), output);
}

if (inputFrame.isFin()) {
decompress(TAIL_BYTES, output);
// If context takeover is disabled, inflater can be reset.
if (clientNoContextTakeover) {
inflater = new Inflater(true);
inflater.reset();
}
}
} catch (DataFormatException e) {
Expand Down Expand Up @@ -244,8 +256,7 @@ public void encodeFrame(Framedata inputFrame) {
}

if (serverNoContextTakeover) {
deflater.end();
deflater = new Deflater(Deflater.DEFAULT_COMPRESSION, true);
deflater.reset();
}
}

Expand Down Expand Up @@ -330,7 +341,11 @@ public String getProvidedExtensionAsServer() {

@Override
public IExtension copyInstance() {
return new PerMessageDeflateExtension();
PerMessageDeflateExtension clone = new PerMessageDeflateExtension(this.getCompressionLevel());
clone.setThreshold(this.getThreshold());
clone.setClientNoContextTakeover(this.isClientNoContextTakeover());
clone.setServerNoContextTakeover(this.isServerNoContextTakeover());
return clone;
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@
import static org.junit.Assert.fail;

import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.zip.Deflater;
import java.util.zip.Inflater;

import org.java_websocket.exceptions.InvalidDataException;
import org.java_websocket.extensions.permessage_deflate.PerMessageDeflateExtension;
import org.java_websocket.framing.BinaryFrame;
import org.java_websocket.framing.ContinuousFrame;
import org.java_websocket.framing.TextFrame;
import org.junit.Test;
Expand Down Expand Up @@ -51,6 +51,79 @@ public void testDecodeFrameIfRSVIsNotSet() throws InvalidDataException {
assertFalse(frame.isRSV1());
}

@Test
public void testDecodeFrameNoCompression() throws InvalidDataException {
PerMessageDeflateExtension deflateExtension = new PerMessageDeflateExtension(Deflater.NO_COMPRESSION);
deflateExtension.setThreshold(0);
String str = "This is a highly compressable text"
+ "This is a highly compressable text"
+ "This is a highly compressable text"
+ "This is a highly compressable text"
+ "This is a highly compressable text";
byte[] message = str.getBytes();
TextFrame frame = new TextFrame();
frame.setPayload(ByteBuffer.wrap(message));
deflateExtension.encodeFrame(frame);
byte[] payloadArray = frame.getPayloadData().array();
assertArrayEquals(message, Arrays.copyOfRange(payloadArray, 5,payloadArray.length-5));
assertTrue(frame.isRSV1());
deflateExtension.decodeFrame(frame);
assertArrayEquals(message, frame.getPayloadData().array());
}

@Test
public void testDecodeFrameBestSpeedCompression() throws InvalidDataException {
PerMessageDeflateExtension deflateExtension = new PerMessageDeflateExtension(Deflater.BEST_SPEED);
deflateExtension.setThreshold(0);
String str = "This is a highly compressable text"
+ "This is a highly compressable text"
+ "This is a highly compressable text"
+ "This is a highly compressable text"
+ "This is a highly compressable text";
byte[] message = str.getBytes();
TextFrame frame = new TextFrame();
frame.setPayload(ByteBuffer.wrap(message));

Deflater localDeflater = new Deflater(Deflater.BEST_SPEED,true);
localDeflater.setInput(ByteBuffer.wrap(message).array());
byte[] buffer = new byte[1024];
int bytesCompressed = localDeflater.deflate(buffer, 0, buffer.length, Deflater.SYNC_FLUSH);

deflateExtension.encodeFrame(frame);
byte[] payloadArray = frame.getPayloadData().array();
assertArrayEquals(Arrays.copyOfRange(buffer,0, bytesCompressed), Arrays.copyOfRange(payloadArray,0,payloadArray.length));
assertTrue(frame.isRSV1());
deflateExtension.decodeFrame(frame);
assertArrayEquals(message, frame.getPayloadData().array());
}

@Test
public void testDecodeFrameBestCompression() throws InvalidDataException {
PerMessageDeflateExtension deflateExtension = new PerMessageDeflateExtension(Deflater.BEST_COMPRESSION);
deflateExtension.setThreshold(0);
String str = "This is a highly compressable text"
+ "This is a highly compressable text"
+ "This is a highly compressable text"
+ "This is a highly compressable text"
+ "This is a highly compressable text";
byte[] message = str.getBytes();
TextFrame frame = new TextFrame();
frame.setPayload(ByteBuffer.wrap(message));

Deflater localDeflater = new Deflater(Deflater.BEST_COMPRESSION,true);
localDeflater.setInput(ByteBuffer.wrap(message).array());
byte[] buffer = new byte[1024];
int bytesCompressed = localDeflater.deflate(buffer, 0, buffer.length, Deflater.SYNC_FLUSH);

deflateExtension.encodeFrame(frame);
byte[] payloadArray = frame.getPayloadData().array();
assertArrayEquals(Arrays.copyOfRange(buffer,0, bytesCompressed), Arrays.copyOfRange(payloadArray,0,payloadArray.length));
assertTrue(frame.isRSV1());
deflateExtension.decodeFrame(frame);
assertArrayEquals(message, frame.getPayloadData().array());
}


@Test
public void testEncodeFrame() {
PerMessageDeflateExtension deflateExtension = new PerMessageDeflateExtension();
Expand Down Expand Up @@ -191,35 +264,45 @@ public void testSetClientNoContextTakeover() {
@Test
public void testCopyInstance() {
PerMessageDeflateExtension deflateExtension = new PerMessageDeflateExtension();
IExtension newDeflateExtension = deflateExtension.copyInstance();
assertEquals(deflateExtension.toString(), newDeflateExtension.toString());
}
PerMessageDeflateExtension newDeflateExtension = (PerMessageDeflateExtension)deflateExtension.copyInstance();
assertEquals("PerMessageDeflateExtension", newDeflateExtension.toString());
// Also check the values
assertEquals(deflateExtension.getThreshold(), newDeflateExtension.getThreshold());
assertEquals(deflateExtension.isClientNoContextTakeover(), newDeflateExtension.isClientNoContextTakeover());
assertEquals(deflateExtension.isServerNoContextTakeover(), newDeflateExtension.isServerNoContextTakeover());
assertEquals(deflateExtension.getCompressionLevel(), newDeflateExtension.getCompressionLevel());

@Test
public void testGetInflater() {
PerMessageDeflateExtension deflateExtension = new PerMessageDeflateExtension();
assertEquals(deflateExtension.getInflater().getRemaining(), new Inflater(true).getRemaining());
}

@Test
public void testSetInflater() {
PerMessageDeflateExtension deflateExtension = new PerMessageDeflateExtension();
deflateExtension.setInflater(new Inflater(false));
assertEquals(deflateExtension.getInflater().getRemaining(), new Inflater(false).getRemaining());
}
deflateExtension = new PerMessageDeflateExtension(Deflater.BEST_COMPRESSION);
deflateExtension.setThreshold(512);
deflateExtension.setServerNoContextTakeover(false);
deflateExtension.setClientNoContextTakeover(true);
newDeflateExtension = (PerMessageDeflateExtension)deflateExtension.copyInstance();

@Test
public void testGetDeflater() {
PerMessageDeflateExtension deflateExtension = new PerMessageDeflateExtension();
assertEquals(deflateExtension.getDeflater().finished(),
new Deflater(Deflater.DEFAULT_COMPRESSION, true).finished());
assertEquals(deflateExtension.getThreshold(), newDeflateExtension.getThreshold());
assertEquals(deflateExtension.isClientNoContextTakeover(), newDeflateExtension.isClientNoContextTakeover());
assertEquals(deflateExtension.isServerNoContextTakeover(), newDeflateExtension.isServerNoContextTakeover());
assertEquals(deflateExtension.getCompressionLevel(), newDeflateExtension.getCompressionLevel());


deflateExtension = new PerMessageDeflateExtension(Deflater.NO_COMPRESSION);
deflateExtension.setThreshold(64);
deflateExtension.setServerNoContextTakeover(true);
deflateExtension.setClientNoContextTakeover(false);
newDeflateExtension = (PerMessageDeflateExtension)deflateExtension.copyInstance();

assertEquals(deflateExtension.getThreshold(), newDeflateExtension.getThreshold());
assertEquals(deflateExtension.isClientNoContextTakeover(), newDeflateExtension.isClientNoContextTakeover());
assertEquals(deflateExtension.isServerNoContextTakeover(), newDeflateExtension.isServerNoContextTakeover());
assertEquals(deflateExtension.getCompressionLevel(), newDeflateExtension.getCompressionLevel());
}

@Test
public void testSetDeflater() {
public void testDefaults() {
PerMessageDeflateExtension deflateExtension = new PerMessageDeflateExtension();
deflateExtension.setDeflater(new Deflater(Deflater.DEFAULT_COMPRESSION, false));
assertEquals(deflateExtension.getDeflater().finished(),
new Deflater(Deflater.DEFAULT_COMPRESSION, false).finished());
assertFalse(deflateExtension.isClientNoContextTakeover());
assertTrue(deflateExtension.isServerNoContextTakeover());
assertEquals(1024, deflateExtension.getThreshold());
assertEquals(Deflater.DEFAULT_COMPRESSION, deflateExtension.getCompressionLevel());
}
}

0 comments on commit 3b29042

Please sign in to comment.