This commit is contained in:
地平线
2015-06-29 11:11:25 +08:00
parent 13d4df3c36
commit ffd235f640

View File

@@ -11,6 +11,7 @@ import com.wentch.redkale.net.http.WebSocketPacket.PacketType;
import java.nio.ByteBuffer;
import java.nio.channels.*;
import java.security.SecureRandom;
import java.util.*;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.logging.*;
@@ -48,7 +49,7 @@ public class WebSocketRunner implements Runnable {
this.channel = channel;
webSocket.runner = this;
this.coder.logger = context.getLogger();
this.coder.debugable = context.getLogger().isLoggable(Level.FINEST);
this.coder.debugable = context.getLogger().isLoggable(Level.FINEST);
this.readBuffer = context.pollBuffer();
this.writeBuffer = context.pollBuffer();
}
@@ -62,6 +63,10 @@ public class WebSocketRunner implements Runnable {
channel.setReadTimeoutSecond(300); //读取超时5分钟
if (channel.isOpen()) {
channel.read(readBuffer, null, new CompletionHandler<Integer, Void>() {
//当接收的数据流长度大于ByteBuffer长度时 则需要额外的ByteBuffer 辅助;
private final List<ByteBuffer> readBuffers = new ArrayList<>();
@Override
public void completed(Integer count, Void attachment1) {
if (count < 1) {
@@ -72,7 +77,8 @@ public class WebSocketRunner implements Runnable {
if (readBuffer == null) return;
readBuffer.flip();
try {
WebSocketPacket packet = coder.decode(readBuffer);
ByteBuffer[] exBuffers = null;
WebSocketPacket packet = coder.decode(readBuffer, exBuffers);
if (packet == null) {
if (debug) context.getLogger().log(Level.FINEST, "WebSocketRunner abort on decode WebSocketPacket, force to close channel");
failed(null, attachment1);
@@ -194,25 +200,47 @@ public class WebSocketRunner implements Runnable {
private ByteBuffer buffer;
private ByteBuffer[] exbuffers;
private byte[] mask;
private int index = 0;
public Masker(ByteBuffer buffer) {
public Masker(ByteBuffer buffer, ByteBuffer... exbuffers) {
this.buffer = buffer;
this.exbuffers = exbuffers == null || exbuffers.length == 0 ? null : exbuffers;
}
public Masker() {
generateMask();
}
public int remaining() {
int r = buffer.remaining();
if (exbuffers != null) {
for (ByteBuffer b : exbuffers) {
r += b.remaining();
}
}
return r;
}
public byte get() {
return buffer.get();
}
public byte[] get(final int size) {
byte[] bytes = new byte[size];
buffer.get(bytes);
if (buffer.remaining() >= size) {
buffer.get(bytes);
} else { //必须有 exbuffers
int offset = buffer.remaining();
buffer.get(bytes, 0, buffer.remaining());
for (ByteBuffer b : exbuffers) {
b.get(bytes, offset, b.remaining());
offset += b.remaining();
}
}
return bytes;
}
@@ -285,9 +313,17 @@ public class WebSocketRunner implements Runnable {
private Logger logger;
public WebSocketPacket decode(final ByteBuffer buffer) {
public WebSocketPacket decode(final ByteBuffer buffer, ByteBuffer... exbuffers) {
final boolean debug = this.debugable;
if (debug) logger.log(Level.FINEST, "read web socket message's length = " + buffer.remaining());
{
int remain = buffer.remaining();
if (exbuffers != null) {
for (ByteBuffer b : exbuffers) {
remain += b == null ? 0 : b.remaining();
}
}
if (debug) logger.log(Level.FINEST, "read web socket message's length = " + remain);
}
if (buffer.remaining() < 2) return null;
byte opcode = buffer.get();
final boolean last = (opcode & 0b1000000) != 0;
@@ -311,7 +347,7 @@ public class WebSocketRunner implements Runnable {
return null;
}
byte lengthCode = buffer.get();
final Masker masker = new Masker(buffer);
final Masker masker = new Masker(buffer, exbuffers);
final boolean masked = (lengthCode & 0x80) == 0x80;
if (masked) lengthCode ^= 0x80; //mask
int length;
@@ -337,7 +373,7 @@ public class WebSocketRunner implements Runnable {
}
masker.readMask();
}
if (buffer.remaining() < length) {
if (masker.remaining() < length) {
if (debug) logger.log(Level.FINE, " read illegal remaining length from websocket, expect " + length + " but " + buffer.remaining());
return null;
}