From ffd235f640391780df0bcd782d88c88f1a49f2e3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=9C=B0=E5=B9=B3=E7=BA=BF?= <22250530@qq.com> Date: Mon, 29 Jun 2015 11:11:25 +0800 Subject: [PATCH] --- .../redkale/net/http/WebSocketRunner.java | 52 ++++++++++++++++--- 1 file changed, 44 insertions(+), 8 deletions(-) diff --git a/src/com/wentch/redkale/net/http/WebSocketRunner.java b/src/com/wentch/redkale/net/http/WebSocketRunner.java index afb8582ef..7ea4095f0 100644 --- a/src/com/wentch/redkale/net/http/WebSocketRunner.java +++ b/src/com/wentch/redkale/net/http/WebSocketRunner.java @@ -11,6 +11,7 @@ import com.wentch.redkale.net.http.WebSocketPacket.PacketType; import java.nio.ByteBuffer; import java.nio.channels.*; import java.security.SecureRandom; +import java.util.*; import java.util.concurrent.*; import java.util.concurrent.atomic.AtomicBoolean; import java.util.logging.*; @@ -48,7 +49,7 @@ public class WebSocketRunner implements Runnable { this.channel = channel; webSocket.runner = this; this.coder.logger = context.getLogger(); - this.coder.debugable = context.getLogger().isLoggable(Level.FINEST); + this.coder.debugable = context.getLogger().isLoggable(Level.FINEST); this.readBuffer = context.pollBuffer(); this.writeBuffer = context.pollBuffer(); } @@ -62,6 +63,10 @@ public class WebSocketRunner implements Runnable { channel.setReadTimeoutSecond(300); //读取超时5分钟 if (channel.isOpen()) { channel.read(readBuffer, null, new CompletionHandler() { + + //当接收的数据流长度大于ByteBuffer长度时, 则需要额外的ByteBuffer 辅助; + private final List readBuffers = new ArrayList<>(); + @Override public void completed(Integer count, Void attachment1) { if (count < 1) { @@ -72,7 +77,8 @@ public class WebSocketRunner implements Runnable { if (readBuffer == null) return; readBuffer.flip(); try { - WebSocketPacket packet = coder.decode(readBuffer); + ByteBuffer[] exBuffers = null; + WebSocketPacket packet = coder.decode(readBuffer, exBuffers); if (packet == null) { if (debug) context.getLogger().log(Level.FINEST, "WebSocketRunner abort on decode WebSocketPacket, force to close channel"); failed(null, attachment1); @@ -194,25 +200,47 @@ public class WebSocketRunner implements Runnable { private ByteBuffer buffer; + private ByteBuffer[] exbuffers; + private byte[] mask; private int index = 0; - public Masker(ByteBuffer buffer) { + public Masker(ByteBuffer buffer, ByteBuffer... exbuffers) { this.buffer = buffer; + this.exbuffers = exbuffers == null || exbuffers.length == 0 ? null : exbuffers; } public Masker() { generateMask(); } + public int remaining() { + int r = buffer.remaining(); + if (exbuffers != null) { + for (ByteBuffer b : exbuffers) { + r += b.remaining(); + } + } + return r; + } + public byte get() { return buffer.get(); } public byte[] get(final int size) { byte[] bytes = new byte[size]; - buffer.get(bytes); + if (buffer.remaining() >= size) { + buffer.get(bytes); + } else { //必须有 exbuffers + int offset = buffer.remaining(); + buffer.get(bytes, 0, buffer.remaining()); + for (ByteBuffer b : exbuffers) { + b.get(bytes, offset, b.remaining()); + offset += b.remaining(); + } + } return bytes; } @@ -285,9 +313,17 @@ public class WebSocketRunner implements Runnable { private Logger logger; - public WebSocketPacket decode(final ByteBuffer buffer) { + public WebSocketPacket decode(final ByteBuffer buffer, ByteBuffer... exbuffers) { final boolean debug = this.debugable; - if (debug) logger.log(Level.FINEST, "read web socket message's length = " + buffer.remaining()); + { + int remain = buffer.remaining(); + if (exbuffers != null) { + for (ByteBuffer b : exbuffers) { + remain += b == null ? 0 : b.remaining(); + } + } + if (debug) logger.log(Level.FINEST, "read web socket message's length = " + remain); + } if (buffer.remaining() < 2) return null; byte opcode = buffer.get(); final boolean last = (opcode & 0b1000000) != 0; @@ -311,7 +347,7 @@ public class WebSocketRunner implements Runnable { return null; } byte lengthCode = buffer.get(); - final Masker masker = new Masker(buffer); + final Masker masker = new Masker(buffer, exbuffers); final boolean masked = (lengthCode & 0x80) == 0x80; if (masked) lengthCode ^= 0x80; //mask int length; @@ -337,7 +373,7 @@ public class WebSocketRunner implements Runnable { } masker.readMask(); } - if (buffer.remaining() < length) { + if (masker.remaining() < length) { if (debug) logger.log(Level.FINE, " read illegal remaining length from websocket, expect " + length + " but " + buffer.remaining()); return null; }