diff --git a/src/org/redkale/net/http/WebSocket.java b/src/org/redkale/net/http/WebSocket.java index 6815b46ef..b36c72ea6 100644 --- a/src/org/redkale/net/http/WebSocket.java +++ b/src/org/redkale/net/http/WebSocket.java @@ -89,18 +89,21 @@ public abstract class WebSocket { } //---------------------------------------------------------------- - /** - * 给自身发送消息体, 包含二进制/文本 - * - * @param packet WebSocketPacket - * - * @return 0表示成功, 非0表示错误码 - */ - public final CompletableFuture send(WebSocketPacket packet) { - CompletableFuture rs = null; - if (this._runner != null) rs = this._runner.sendMessage(packet); - if (_engine.finest) _engine.logger.finest("wsgroupid:" + getGroupid() + " send websocket result is " + rs + " on " + this + " by message(" + packet + ")"); - return rs == null ? CompletableFuture.completedFuture(RETCODE_WSOCKET_CLOSED) : rs; + public final CompletableFuture sendPing() { + //if (_engine.finest) _engine.logger.finest(this + " on "+_engine.getEngineid()+" ping..."); + return send(WebSocketPacket.DEFAULT_PING_PACKET); + } + + public final CompletableFuture sendPing(byte[] data) { + return send(new WebSocketPacket(FrameType.PING, data)); + } + + public final CompletableFuture sendPong(byte[] data) { + return send(new WebSocketPacket(FrameType.PONG, data)); + } + + public final long getCreatetime() { + return createtime; } /** @@ -123,24 +126,7 @@ public abstract class WebSocket { * @return 0表示成功, 非0表示错误码 */ public final CompletableFuture send(String text, boolean last) { - return send(new WebSocketPacket(text, last)); - } - - public final CompletableFuture sendPing() { - //if (_engine.finest) _engine.logger.finest(this + " on "+_engine.getEngineid()+" ping..."); - return send(WebSocketPacket.DEFAULT_PING_PACKET); - } - - public final CompletableFuture sendPing(byte[] data) { - return send(new WebSocketPacket(FrameType.PING, data)); - } - - public final CompletableFuture sendPong(byte[] data) { - return send(new WebSocketPacket(FrameType.PONG, data)); - } - - public final long getCreatetime() { - return createtime; + return sendPacket(new WebSocketPacket(text, last)); } /** @@ -163,7 +149,7 @@ public abstract class WebSocket { * @return 0表示成功, 非0表示错误码 */ public final CompletableFuture send(byte[] data, boolean last) { - return send(new WebSocketPacket(data, last)); + return sendPacket(new WebSocketPacket(data, last)); } /** @@ -187,9 +173,9 @@ public abstract class WebSocket { */ public final CompletableFuture send(Object message, boolean last) { if (message == null || message instanceof CharSequence || message instanceof byte[]) { - return send(new WebSocketPacket((Serializable) message, last)); + return sendPacket(new WebSocketPacket((Serializable) message, last)); } else { - return send(new WebSocketPacket(_jsonConvert, message, last)); + return sendPacket(new WebSocketPacket(_jsonConvert, message, last)); } } @@ -202,7 +188,7 @@ public abstract class WebSocket { * @return 0表示成功, 非0表示错误码 */ public final CompletableFuture send(JsonConvert convert, Object message) { - return send(new WebSocketPacket(convert == null ? _jsonConvert : convert, message, true)); + return sendPacket(new WebSocketPacket(convert == null ? _jsonConvert : convert, message, true)); } /** @@ -215,7 +201,21 @@ public abstract class WebSocket { * @return 0表示成功, 非0表示错误码 */ public final CompletableFuture send(JsonConvert convert, Object message, boolean last) { - return send(new WebSocketPacket(convert == null ? _jsonConvert : convert, message, last)); + return sendPacket(new WebSocketPacket(convert == null ? _jsonConvert : convert, message, last)); + } + + /** + * 给自身发送消息体, 包含二进制/文本 + * + * @param packet WebSocketPacket + * + * @return 0表示成功, 非0表示错误码 + */ + CompletableFuture sendPacket(WebSocketPacket packet) { + CompletableFuture rs = null; + if (this._runner != null) rs = this._runner.sendMessage(packet); + if (_engine.finest) _engine.logger.finest("wsgroupid:" + getGroupid() + " send websocket result is " + rs + " on " + this + " by message(" + packet + ")"); + return rs == null ? CompletableFuture.completedFuture(RETCODE_WSOCKET_CLOSED) : rs; } //---------------------------------------------------------------- @@ -522,19 +522,23 @@ public abstract class WebSocket { public void onConnected() { } - public void onMessage(String text) { - } - public void onPing(byte[] bytes) { } public void onPong(byte[] bytes) { } + public java.lang.reflect.Type getTextMessageType() { + return String.class; + } + + public void onMessage(Object message) { + } + public void onMessage(byte[] bytes) { } - public void onFragment(String text, boolean last) { + public void onFragment(Object message, boolean last) { } public void onFragment(byte[] bytes, boolean last) { diff --git a/src/org/redkale/net/http/WebSocketPacket.java b/src/org/redkale/net/http/WebSocketPacket.java index c602973aa..87f6df4b1 100644 --- a/src/org/redkale/net/http/WebSocketPacket.java +++ b/src/org/redkale/net/http/WebSocketPacket.java @@ -10,6 +10,7 @@ import java.io.*; import java.nio.ByteBuffer; import java.util.function.Supplier; import java.util.logging.*; +import org.redkale.convert.ConvertMask; import org.redkale.convert.json.JsonConvert; /** @@ -57,9 +58,13 @@ public final class WebSocketPacket { protected boolean last = true; - protected Object json; + protected Object sendJson; - JsonConvert convert; + JsonConvert sendConvert; + + ConvertMask receiveMasker; + + ByteBuffer[] receiveBuffers; public WebSocketPacket() { } @@ -88,8 +93,8 @@ public final class WebSocketPacket { public WebSocketPacket(JsonConvert convert, Object json, boolean fin) { this.type = FrameType.TEXT; - this.convert = convert; - this.json = json; + this.sendConvert = convert; + this.sendJson = json; this.last = fin; } @@ -147,7 +152,7 @@ public final class WebSocketPacket { */ ByteBuffer[] encode(final Supplier supplier) { final byte opcode = (byte) (this.type.getValue() | 0x80); - if (this.convert != null) { + if (this.sendConvert != null) { Supplier newsupplier = new Supplier() { private ByteBuffer buf = supplier.get(); @@ -163,7 +168,7 @@ public final class WebSocketPacket { return supplier.get(); } }; - ByteBuffer[] buffers = this.convert.convertTo(newsupplier, json); + ByteBuffer[] buffers = this.sendConvert.convertTo(newsupplier, sendJson); int len = 0; for (ByteBuffer buf : buffers) { len += buf.remaining(); @@ -229,23 +234,20 @@ public final class WebSocketPacket { return buffers; } - /** - * - *
-     * public static void main(String[] args) throws Throwable {
-     *      byte[] mask = new byte[]{(byte) 0x8f, (byte) 0xf8, (byte) 0x6d, (byte) 0x94};
-     *      ByteBuffer buffer = ByteBuffer.wrap(new byte[]{(byte) 0x67, (byte) 0x47, (byte) 0xf4, (byte) 0x70, (byte) 0x37, (byte) 0x52, (byte) 0x8b, (byte) 0x0c, (byte) 0x20, (byte) 0x1e, (byte) 0xdb, (byte) 0x1c, (byte) 0x69, (byte) 0x79, (byte) 0xc2});
-     *      ConvertMask masker = new ConvertMask() {
-     *          private int index = 0;
-     *          public byte unmask(byte value) {
-     *              return (byte) (value ^ mask[index++ % 4]);
-     *          }
-     *      };
-     *      String rs = JsonConvert.root().convertFrom(String.class, masker, buffer);
-     *      System.out.println(rs);
-     * }
-     * 
- */ +// public static void main(String[] args) throws Throwable { +// byte[] mask = new byte[]{(byte) 0x8f, (byte) 0xf8, (byte) 0x6d, (byte) 0x94}; +// ByteBuffer buffer = ByteBuffer.wrap(new byte[]{(byte) 0x67, (byte) 0x47, (byte) 0xf4, (byte) 0x70, (byte) 0x37, (byte) 0x52, (byte) 0x8b, (byte) 0x0c, (byte) 0x20, (byte) 0x1e, (byte) 0xdb, (byte) 0x1c, (byte) 0x69, (byte) 0x79, (byte) 0xc2}); +// ConvertMask masker = new ConvertMask() { +// private int index = 0; +// +// public byte unmask(byte value) { +// return (byte) (value ^ mask[index++ % 4]); +// } +// }; +// String rs = JsonConvert.root().convertFrom(String.class, masker, buffer); +// System.out.println(rs); +// } + /** * 消息解码
* @@ -290,7 +292,7 @@ public final class WebSocketPacket { this.type = FrameType.valueOf(opcode & 0xF); if (type == FrameType.CLOSE) { if (debug) logger.log(Level.FINEST, " receive close command from websocket client"); - return null; + return this; } final boolean checkrsv = false;//暂时不校验 if (checkrsv && (opcode & 0b0111_0000) != 0) { @@ -323,34 +325,20 @@ public final class WebSocketPacket { length = buffer.getInt(); } } - byte[] mask = null; if (masked) { - mask = new byte[4]; - buffer.get(mask); - } - final byte[] data = new byte[length]; - if (buffer.remaining() >= length) { - buffer.get(data); - } else { //必须有 exbuffers - int offset = buffer.remaining(); - buffer.get(data, 0, offset); - for (ByteBuffer b : exbuffers) { - int r = b.remaining(); - b.get(data, offset, r); - offset += r; - if (offset >= length) break; - } - } - if (mask != null) { - for (int i = 0; i < data.length; i++) { - data[i] ^= mask[i % 4]; - } - } - if (type == FrameType.TEXT) { - this.payload = new String(Utility.decodeUTF8(data)); - } else { - this.bytes = data; + final byte[] masks = new byte[4]; + buffer.get(masks); + this.receiveMasker = new ConvertMask() { + + private int index = 0; + + @Override + public byte unmask(byte value) { + return (byte) (value ^ masks[index++ % 4]); + } + }; } + this.receiveBuffers = Utility.append(new ByteBuffer[]{buffer}, exbuffers); return this; } diff --git a/src/org/redkale/net/http/WebSocketRunner.java b/src/org/redkale/net/http/WebSocketRunner.java index 7069d2721..74a3c8b1e 100644 --- a/src/org/redkale/net/http/WebSocketRunner.java +++ b/src/org/redkale/net/http/WebSocketRunner.java @@ -15,6 +15,7 @@ import java.util.*; import java.util.concurrent.*; import java.util.concurrent.atomic.AtomicBoolean; import java.util.logging.*; +import org.redkale.convert.json.JsonConvert; /** * WebSocket的消息接收发送器, 一个WebSocket对应一个WebSocketRunner @@ -48,6 +49,8 @@ public class WebSocketRunner implements Runnable { protected long lastSendTime; + protected final JsonConvert convert; + public WebSocketRunner(Context context, WebSocket webSocket, AsyncConnection channel, final boolean wsbinary) { this.context = context; this.engine = webSocket._engine; @@ -55,6 +58,7 @@ public class WebSocketRunner implements Runnable { this.channel = channel; this.wsbinary = wsbinary; this.readBuffer = context.pollBuffer(); + this.convert = context.getJsonConvert(); } @Override @@ -101,9 +105,55 @@ public class WebSocketRunner implements Runnable { b.flip(); } } - WebSocketPacket packet = null; try { - packet = new WebSocketPacket().decode(context.getLogger(), readBuffer, exBuffers); + + WebSocketPacket packet = new WebSocketPacket().decode(context.getLogger(), readBuffer, exBuffers); + + if (packet == null) { + failed(null, attachment1); + if (debug) context.getLogger().log(Level.FINEST, "WebSocketRunner abort on decode WebSocketPacket, force to close channel, live " + (System.currentTimeMillis() - webSocket.getCreatetime()) / 1000 + " seconds"); + return; + } + webSocket._group.setRecentWebSocket(webSocket); + try { + if (packet.type == FrameType.TEXT) { + Object message = convert.convertFrom(webSocket.getTextMessageType(), packet.receiveMasker, packet.receiveBuffers); + if (readBuffer != null) { + readBuffer.clear(); + channel.read(readBuffer, null, this); + } + webSocket.onMessage(message); + } else if (packet.type == FrameType.BINARY) { + Object message = convert.convertFrom(byte[].class, packet.receiveMasker, packet.receiveBuffers); + if (readBuffer != null) { + readBuffer.clear(); + channel.read(readBuffer, null, this); + } + webSocket.onMessage(message); + } else if (packet.type == FrameType.PONG) { + byte[] message = convert.convertFrom(byte[].class, packet.receiveMasker, packet.receiveBuffers); + if (readBuffer != null) { + readBuffer.clear(); + channel.read(readBuffer, null, this); + } + webSocket.onPong(message); + } else if (packet.type == FrameType.PING) { + byte[] message = convert.convertFrom(byte[].class, packet.receiveMasker, packet.receiveBuffers); + if (readBuffer != null) { + readBuffer.clear(); + channel.read(readBuffer, null, this); + } + webSocket.onPing(message); + } else { + context.getLogger().log(Level.WARNING, "WebSocketRunner onMessage by unknown FrameType : " + packet); + if (readBuffer != null) { + readBuffer.clear(); + channel.read(readBuffer, null, this); + } + } + } catch (Exception e) { + context.getLogger().log(Level.INFO, "WebSocket onMessage error (" + packet + ")", e); + } } finally { if (exBuffers != null) { for (ByteBuffer b : exBuffers) { @@ -111,29 +161,6 @@ public class WebSocketRunner implements Runnable { } } } - if (packet == null) { - failed(null, attachment1); - if (debug) context.getLogger().log(Level.FINEST, "WebSocketRunner abort on decode WebSocketPacket, force to close channel, live " + (System.currentTimeMillis() - webSocket.getCreatetime()) / 1000 + " seconds"); - return; - } - if (readBuffer != null) { - readBuffer.clear(); - channel.read(readBuffer, null, this); - } - webSocket._group.setRecentWebSocket(webSocket); - try { - if (packet.type == FrameType.TEXT) { - webSocket.onMessage(packet.getPayload()); - } else if (packet.type == FrameType.BINARY) { - webSocket.onMessage(packet.getBytes()); - } else if (packet.type == FrameType.PONG) { - webSocket.onPong(packet.getBytes()); - } else if (packet.type == FrameType.PING) { - webSocket.onPing(packet.getBytes()); - } - } catch (Exception e) { - context.getLogger().log(Level.INFO, "WebSocket onMessage error (" + packet + ")", e); - } } catch (Throwable t) { closeRunner(); if (debug) context.getLogger().log(Level.FINEST, "WebSocketRunner abort on read WebSocketPacket, force to close channel, live " + (System.currentTimeMillis() - webSocket.getCreatetime()) / 1000 + " seconds", t); diff --git a/src/org/redkale/util/Utility.java b/src/org/redkale/util/Utility.java index 3a7baf4ed..7f0338853 100644 --- a/src/org/redkale/util/Utility.java +++ b/src/org/redkale/util/Utility.java @@ -216,6 +216,7 @@ public final class Utility { */ public static T[] append(final T[] array, final T... objs) { if (array == null || array.length == 0) return objs; + if (objs == null || objs.length == 0) return array; final T[] news = (T[]) Array.newInstance(array.getClass().getComponentType(), array.length + objs.length); System.arraycopy(array, 0, news, 0, array.length); System.arraycopy(objs, 0, news, array.length, objs.length); diff --git a/test/org/redkale/test/websocket/ChatWebSocketServlet.java b/test/org/redkale/test/websocket/ChatWebSocketServlet.java index 1b02005b3..e44d86876 100644 --- a/test/org/redkale/test/websocket/ChatWebSocketServlet.java +++ b/test/org/redkale/test/websocket/ChatWebSocketServlet.java @@ -9,6 +9,7 @@ import org.redkale.net.http.WebServlet; import org.redkale.net.http.WebSocketServlet; import org.redkale.net.http.WebSocket; import java.io.*; +import java.lang.reflect.Type; import java.util.concurrent.atomic.*; import org.redkale.convert.json.JsonConvert; import org.redkale.util.Utility; @@ -55,11 +56,11 @@ public class ChatWebSocketServlet extends WebSocketServlet { return new WebSocket() { @Override - public void onMessage(String text) { + public void onMessage(Object text) { icounter.incrementAndGet(); counter.incrementAndGet(); - ChatMessage message = jsonConvert.convertFrom(ChatMessage.class, text); - if (debug) System.out.println("收到消息: " + text + ", " + message); + ChatMessage message = (ChatMessage) text;//jsonConvert.convertFrom(ChatMessage.class, text.toString()); + if (debug) System.out.println("收到消息: " + message); super.getWebSocketGroup().getWebSockets().forEach(x -> x.send(message)); } @@ -67,6 +68,11 @@ public class ChatWebSocketServlet extends WebSocketServlet { protected Serializable createGroupid() { return ""; } + + @Override + public Type getTextMessageType(){ + return ChatMessage.class; + } }; } diff --git a/test/org/redkale/test/websocket/VideoWebSocketServlet.java b/test/org/redkale/test/websocket/VideoWebSocketServlet.java index 47b55d6e7..32d042107 100644 --- a/test/org/redkale/test/websocket/VideoWebSocketServlet.java +++ b/test/org/redkale/test/websocket/VideoWebSocketServlet.java @@ -87,7 +87,7 @@ public class VideoWebSocketServlet extends WebSocketServlet { } @Override - public void onMessage(String text) { + public void onMessage(Object text) { //System.out.println("接收到消息: " + text); super.getWebSocketGroup().getWebSockets().filter(x -> x != this).forEach(x -> { x.send(text);