This commit is contained in:
地平线
2015-06-29 11:11:25 +08:00
parent 13d4df3c36
commit ffd235f640

View File

@@ -11,6 +11,7 @@ import com.wentch.redkale.net.http.WebSocketPacket.PacketType;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
import java.nio.channels.*; import java.nio.channels.*;
import java.security.SecureRandom; import java.security.SecureRandom;
import java.util.*;
import java.util.concurrent.*; import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicBoolean;
import java.util.logging.*; import java.util.logging.*;
@@ -62,6 +63,10 @@ public class WebSocketRunner implements Runnable {
channel.setReadTimeoutSecond(300); //读取超时5分钟 channel.setReadTimeoutSecond(300); //读取超时5分钟
if (channel.isOpen()) { if (channel.isOpen()) {
channel.read(readBuffer, null, new CompletionHandler<Integer, Void>() { channel.read(readBuffer, null, new CompletionHandler<Integer, Void>() {
//当接收的数据流长度大于ByteBuffer长度时 则需要额外的ByteBuffer 辅助;
private final List<ByteBuffer> readBuffers = new ArrayList<>();
@Override @Override
public void completed(Integer count, Void attachment1) { public void completed(Integer count, Void attachment1) {
if (count < 1) { if (count < 1) {
@@ -72,7 +77,8 @@ public class WebSocketRunner implements Runnable {
if (readBuffer == null) return; if (readBuffer == null) return;
readBuffer.flip(); readBuffer.flip();
try { try {
WebSocketPacket packet = coder.decode(readBuffer); ByteBuffer[] exBuffers = null;
WebSocketPacket packet = coder.decode(readBuffer, exBuffers);
if (packet == null) { if (packet == null) {
if (debug) context.getLogger().log(Level.FINEST, "WebSocketRunner abort on decode WebSocketPacket, force to close channel"); if (debug) context.getLogger().log(Level.FINEST, "WebSocketRunner abort on decode WebSocketPacket, force to close channel");
failed(null, attachment1); failed(null, attachment1);
@@ -194,25 +200,47 @@ public class WebSocketRunner implements Runnable {
private ByteBuffer buffer; private ByteBuffer buffer;
private ByteBuffer[] exbuffers;
private byte[] mask; private byte[] mask;
private int index = 0; private int index = 0;
public Masker(ByteBuffer buffer) { public Masker(ByteBuffer buffer, ByteBuffer... exbuffers) {
this.buffer = buffer; this.buffer = buffer;
this.exbuffers = exbuffers == null || exbuffers.length == 0 ? null : exbuffers;
} }
public Masker() { public Masker() {
generateMask(); generateMask();
} }
public int remaining() {
int r = buffer.remaining();
if (exbuffers != null) {
for (ByteBuffer b : exbuffers) {
r += b.remaining();
}
}
return r;
}
public byte get() { public byte get() {
return buffer.get(); return buffer.get();
} }
public byte[] get(final int size) { public byte[] get(final int size) {
byte[] bytes = new byte[size]; byte[] bytes = new byte[size];
if (buffer.remaining() >= size) {
buffer.get(bytes); buffer.get(bytes);
} else { //必须有 exbuffers
int offset = buffer.remaining();
buffer.get(bytes, 0, buffer.remaining());
for (ByteBuffer b : exbuffers) {
b.get(bytes, offset, b.remaining());
offset += b.remaining();
}
}
return bytes; return bytes;
} }
@@ -285,9 +313,17 @@ public class WebSocketRunner implements Runnable {
private Logger logger; private Logger logger;
public WebSocketPacket decode(final ByteBuffer buffer) { public WebSocketPacket decode(final ByteBuffer buffer, ByteBuffer... exbuffers) {
final boolean debug = this.debugable; final boolean debug = this.debugable;
if (debug) logger.log(Level.FINEST, "read web socket message's length = " + buffer.remaining()); {
int remain = buffer.remaining();
if (exbuffers != null) {
for (ByteBuffer b : exbuffers) {
remain += b == null ? 0 : b.remaining();
}
}
if (debug) logger.log(Level.FINEST, "read web socket message's length = " + remain);
}
if (buffer.remaining() < 2) return null; if (buffer.remaining() < 2) return null;
byte opcode = buffer.get(); byte opcode = buffer.get();
final boolean last = (opcode & 0b1000000) != 0; final boolean last = (opcode & 0b1000000) != 0;
@@ -311,7 +347,7 @@ public class WebSocketRunner implements Runnable {
return null; return null;
} }
byte lengthCode = buffer.get(); byte lengthCode = buffer.get();
final Masker masker = new Masker(buffer); final Masker masker = new Masker(buffer, exbuffers);
final boolean masked = (lengthCode & 0x80) == 0x80; final boolean masked = (lengthCode & 0x80) == 0x80;
if (masked) lengthCode ^= 0x80; //mask if (masked) lengthCode ^= 0x80; //mask
int length; int length;
@@ -337,7 +373,7 @@ public class WebSocketRunner implements Runnable {
} }
masker.readMask(); masker.readMask();
} }
if (buffer.remaining() < length) { if (masker.remaining() < length) {
if (debug) logger.log(Level.FINE, " read illegal remaining length from websocket, expect " + length + " but " + buffer.remaining()); if (debug) logger.log(Level.FINE, " read illegal remaining length from websocket, expect " + length + " but " + buffer.remaining());
return null; return null;
} }