This commit is contained in:
@@ -11,6 +11,7 @@ import com.wentch.redkale.net.http.WebSocketPacket.PacketType;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.nio.channels.*;
|
||||
import java.security.SecureRandom;
|
||||
import java.util.*;
|
||||
import java.util.concurrent.*;
|
||||
import java.util.concurrent.atomic.AtomicBoolean;
|
||||
import java.util.logging.*;
|
||||
@@ -48,7 +49,7 @@ public class WebSocketRunner implements Runnable {
|
||||
this.channel = channel;
|
||||
webSocket.runner = this;
|
||||
this.coder.logger = context.getLogger();
|
||||
this.coder.debugable = context.getLogger().isLoggable(Level.FINEST);
|
||||
this.coder.debugable = context.getLogger().isLoggable(Level.FINEST);
|
||||
this.readBuffer = context.pollBuffer();
|
||||
this.writeBuffer = context.pollBuffer();
|
||||
}
|
||||
@@ -62,6 +63,10 @@ public class WebSocketRunner implements Runnable {
|
||||
channel.setReadTimeoutSecond(300); //读取超时5分钟
|
||||
if (channel.isOpen()) {
|
||||
channel.read(readBuffer, null, new CompletionHandler<Integer, Void>() {
|
||||
|
||||
//当接收的数据流长度大于ByteBuffer长度时, 则需要额外的ByteBuffer 辅助;
|
||||
private final List<ByteBuffer> readBuffers = new ArrayList<>();
|
||||
|
||||
@Override
|
||||
public void completed(Integer count, Void attachment1) {
|
||||
if (count < 1) {
|
||||
@@ -72,7 +77,8 @@ public class WebSocketRunner implements Runnable {
|
||||
if (readBuffer == null) return;
|
||||
readBuffer.flip();
|
||||
try {
|
||||
WebSocketPacket packet = coder.decode(readBuffer);
|
||||
ByteBuffer[] exBuffers = null;
|
||||
WebSocketPacket packet = coder.decode(readBuffer, exBuffers);
|
||||
if (packet == null) {
|
||||
if (debug) context.getLogger().log(Level.FINEST, "WebSocketRunner abort on decode WebSocketPacket, force to close channel");
|
||||
failed(null, attachment1);
|
||||
@@ -194,25 +200,47 @@ public class WebSocketRunner implements Runnable {
|
||||
|
||||
private ByteBuffer buffer;
|
||||
|
||||
private ByteBuffer[] exbuffers;
|
||||
|
||||
private byte[] mask;
|
||||
|
||||
private int index = 0;
|
||||
|
||||
public Masker(ByteBuffer buffer) {
|
||||
public Masker(ByteBuffer buffer, ByteBuffer... exbuffers) {
|
||||
this.buffer = buffer;
|
||||
this.exbuffers = exbuffers == null || exbuffers.length == 0 ? null : exbuffers;
|
||||
}
|
||||
|
||||
public Masker() {
|
||||
generateMask();
|
||||
}
|
||||
|
||||
public int remaining() {
|
||||
int r = buffer.remaining();
|
||||
if (exbuffers != null) {
|
||||
for (ByteBuffer b : exbuffers) {
|
||||
r += b.remaining();
|
||||
}
|
||||
}
|
||||
return r;
|
||||
}
|
||||
|
||||
public byte get() {
|
||||
return buffer.get();
|
||||
}
|
||||
|
||||
public byte[] get(final int size) {
|
||||
byte[] bytes = new byte[size];
|
||||
buffer.get(bytes);
|
||||
if (buffer.remaining() >= size) {
|
||||
buffer.get(bytes);
|
||||
} else { //必须有 exbuffers
|
||||
int offset = buffer.remaining();
|
||||
buffer.get(bytes, 0, buffer.remaining());
|
||||
for (ByteBuffer b : exbuffers) {
|
||||
b.get(bytes, offset, b.remaining());
|
||||
offset += b.remaining();
|
||||
}
|
||||
}
|
||||
return bytes;
|
||||
}
|
||||
|
||||
@@ -285,9 +313,17 @@ public class WebSocketRunner implements Runnable {
|
||||
|
||||
private Logger logger;
|
||||
|
||||
public WebSocketPacket decode(final ByteBuffer buffer) {
|
||||
public WebSocketPacket decode(final ByteBuffer buffer, ByteBuffer... exbuffers) {
|
||||
final boolean debug = this.debugable;
|
||||
if (debug) logger.log(Level.FINEST, "read web socket message's length = " + buffer.remaining());
|
||||
{
|
||||
int remain = buffer.remaining();
|
||||
if (exbuffers != null) {
|
||||
for (ByteBuffer b : exbuffers) {
|
||||
remain += b == null ? 0 : b.remaining();
|
||||
}
|
||||
}
|
||||
if (debug) logger.log(Level.FINEST, "read web socket message's length = " + remain);
|
||||
}
|
||||
if (buffer.remaining() < 2) return null;
|
||||
byte opcode = buffer.get();
|
||||
final boolean last = (opcode & 0b1000000) != 0;
|
||||
@@ -311,7 +347,7 @@ public class WebSocketRunner implements Runnable {
|
||||
return null;
|
||||
}
|
||||
byte lengthCode = buffer.get();
|
||||
final Masker masker = new Masker(buffer);
|
||||
final Masker masker = new Masker(buffer, exbuffers);
|
||||
final boolean masked = (lengthCode & 0x80) == 0x80;
|
||||
if (masked) lengthCode ^= 0x80; //mask
|
||||
int length;
|
||||
@@ -337,7 +373,7 @@ public class WebSocketRunner implements Runnable {
|
||||
}
|
||||
masker.readMask();
|
||||
}
|
||||
if (buffer.remaining() < length) {
|
||||
if (masker.remaining() < length) {
|
||||
if (debug) logger.log(Level.FINE, " read illegal remaining length from websocket, expect " + length + " but " + buffer.remaining());
|
||||
return null;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user