diff --git a/src/org/redkale/net/Context.java b/src/org/redkale/net/Context.java index c6f76661a..91897d054 100644 --- a/src/org/redkale/net/Context.java +++ b/src/org/redkale/net/Context.java @@ -144,6 +144,10 @@ public class Context { return bufferPool; } + public Consumer getBufferConsumer() { + return bufferPool; + } + public ByteBuffer pollBuffer() { return bufferPool.get(); } diff --git a/src/org/redkale/net/Cryptor.java b/src/org/redkale/net/Cryptor.java new file mode 100644 index 000000000..bdd64930e --- /dev/null +++ b/src/org/redkale/net/Cryptor.java @@ -0,0 +1,46 @@ +/* + * To change this license header, choose License Headers in Project Properties. + * To change this template file, choose Tools | Templates + * and open the template in the editor. + */ +package org.redkale.net; + +import java.nio.ByteBuffer; +import java.util.function.*; + +/** + * 加密解密接口 + * + *

+ * 详情见: https://redkale.org + * + * @author zhangjx + */ +public interface Cryptor { + + /** + * 加密 + * + * @param buffers 待加密数据 + * @param supplier ByteBuffer生成器 + * @param consumer ByteBuffer回收器 + * + * @return 加密后数据 + */ + default ByteBuffer[] encrypt(ByteBuffer[] buffers, final Supplier supplier, final Consumer consumer) { + return buffers; + } + + /** + * 解密 + * + * @param buffers 待解密数据 + * @param supplier ByteBuffer生成器 + * @param consumer ByteBuffer回收器 + * + * @return 解密后数据 + */ + default ByteBuffer[] decrypt(ByteBuffer[] buffers, final Supplier supplier, final Consumer consumer) { + return buffers; + } +} diff --git a/src/org/redkale/net/http/Rest.java b/src/org/redkale/net/http/Rest.java index 5c30a1f7b..77f606ec7 100644 --- a/src/org/redkale/net/http/Rest.java +++ b/src/org/redkale/net/http/Rest.java @@ -21,6 +21,7 @@ import static org.redkale.asm.Opcodes.*; import org.redkale.asm.Type; import org.redkale.convert.*; import org.redkale.convert.json.*; +import org.redkale.net.Cryptor; import org.redkale.service.*; import org.redkale.util.*; import org.redkale.source.Flipper; @@ -641,7 +642,14 @@ public final class Rest { cw.visitEnd(); Class newClazz = newLoader.loadClass(newDynName.replace('/', '.'), cw.toByteArray()); try { - return (T) newClazz.getDeclaredConstructor().newInstance(); + T servlet = (T) newClazz.getDeclaredConstructor().newInstance(); + if (rws.cryptor() != Cryptor.class) { + Cryptor cryptor = rws.cryptor().getDeclaredConstructor().newInstance(); + Field cryptorField = newClazz.getDeclaredField("cryptor"); + cryptorField.setAccessible(true); + cryptorField.set(servlet, cryptor); + } + return servlet; } catch (Exception e) { throw new RuntimeException(e); } diff --git a/src/org/redkale/net/http/RestWebSocket.java b/src/org/redkale/net/http/RestWebSocket.java index e4495d731..d7cc2c67d 100644 --- a/src/org/redkale/net/http/RestWebSocket.java +++ b/src/org/redkale/net/http/RestWebSocket.java @@ -8,6 +8,7 @@ package org.redkale.net.http; import java.lang.annotation.*; import static java.lang.annotation.ElementType.TYPE; import static java.lang.annotation.RetentionPolicy.RUNTIME; +import org.redkale.net.Cryptor; /** * 只能依附在WebSocket类上,name默认为Service的类名小写并去掉Service字样及后面的字符串 (如HelloWebSocket/HelloWebSocketImpl,的默认路径为 hello)。
@@ -66,6 +67,13 @@ public @interface RestWebSocket { */ int liveinterval() default WebSocketServlet.DEFAILT_LIVEINTERVAL; + /** + * 加密解密器 + * + * @return Cryptor + */ + Class cryptor() default Cryptor.class; + /** * 最大连接数, 小于1表示无限制 * diff --git a/src/org/redkale/net/http/WebSocketEngine.java b/src/org/redkale/net/http/WebSocketEngine.java index de1835f02..5c507143d 100644 --- a/src/org/redkale/net/http/WebSocketEngine.java +++ b/src/org/redkale/net/http/WebSocketEngine.java @@ -14,6 +14,7 @@ import java.util.function.*; import java.util.logging.*; import java.util.stream.*; import org.redkale.convert.Convert; +import org.redkale.net.Cryptor; import static org.redkale.net.http.WebSocket.RETCODE_GROUP_EMPTY; import static org.redkale.net.http.WebSocketServlet.*; import org.redkale.util.*; @@ -72,8 +73,11 @@ public class WebSocketEngine { @Comment("最大消息体长度, 小于1表示无限制") protected int wsmaxbody; + @Comment("加密解密器") + protected Cryptor cryptor; + protected WebSocketEngine(String engineid, boolean single, HttpContext context, int liveinterval, - int wsmaxconns, int wsmaxbody, WebSocketNode node, Convert sendConvert, Logger logger) { + int wsmaxconns, int wsmaxbody, Cryptor cryptor, WebSocketNode node, Convert sendConvert, Logger logger) { this.engineid = engineid; this.single = single; this.context = context; @@ -82,6 +86,7 @@ public class WebSocketEngine { this.liveinterval = liveinterval; this.wsmaxconns = wsmaxconns; this.wsmaxbody = wsmaxbody; + this.cryptor = cryptor; this.logger = logger; this.index = sequence.getAndIncrement(); } @@ -213,7 +218,7 @@ public class WebSocketEngine { final WebSocketPacket packet = (message instanceof WebSocketPacket) ? (WebSocketPacket) message : ((message == null || message instanceof CharSequence || message instanceof byte[]) ? new WebSocketPacket((Serializable) message, last) : new WebSocketPacket(this.sendConvert, false, message, last)); - packet.setSendBuffers(packet.encode(context.getBufferSupplier())); + packet.setSendBuffers(packet.encode(context.getBufferSupplier(), context.getBufferConsumer(), cryptor)); CompletableFuture future = null; if (single) { for (WebSocket websocket : websockets.values()) { @@ -270,7 +275,7 @@ public class WebSocketEngine { final WebSocketPacket packet = (message instanceof WebSocketPacket) ? (WebSocketPacket) message : ((message == null || message instanceof CharSequence || message instanceof byte[]) ? new WebSocketPacket((Serializable) message, last) : new WebSocketPacket(this.sendConvert, false, message, last)); - packet.setSendBuffers(packet.encode(context.getBufferSupplier())); + packet.setSendBuffers(packet.encode(context.getBufferSupplier(), context.getBufferConsumer(), cryptor)); CompletableFuture future = null; if (single) { for (Serializable userid : userids) { diff --git a/src/org/redkale/net/http/WebSocketPacket.java b/src/org/redkale/net/http/WebSocketPacket.java index b36cf23d6..274d11ff1 100644 --- a/src/org/redkale/net/http/WebSocketPacket.java +++ b/src/org/redkale/net/http/WebSocketPacket.java @@ -10,9 +10,10 @@ import java.io.*; import java.nio.ByteBuffer; import java.nio.charset.Charset; import java.util.AbstractMap; -import java.util.function.Supplier; +import java.util.function.*; import java.util.logging.*; import org.redkale.convert.*; +import org.redkale.net.Cryptor; /** * @@ -210,10 +211,11 @@ public final class WebSocketPacket { * 消息编码 * * @param supplier Supplier + * @param cryptor Cryptor * * @return ByteBuffer[] */ - ByteBuffer[] encode(final Supplier supplier) { + ByteBuffer[] encode(final Supplier supplier, final Consumer consumer, final Cryptor cryptor) { final byte opcode = (byte) (this.type.getValue() | 0x80); if (this.sendConvert != null) { Supplier newsupplier = new Supplier() { @@ -232,6 +234,7 @@ public final class WebSocketPacket { } }; ByteBuffer[] buffers = this.sendMapconvable ? this.sendConvert.convertMapTo(newsupplier, (Object[]) sendJson) : this.sendConvert.convertTo(newsupplier, sendJson); + if (cryptor != null) buffers = cryptor.encrypt(buffers, supplier, consumer); int len = 0; for (ByteBuffer buf : buffers) { len += buf.remaining(); @@ -256,7 +259,27 @@ public final class WebSocketPacket { } ByteBuffer buffer = supplier.get(); //确保ByteBuffer的capacity不能小于128 - final byte[] content = content(); + byte[] content = content(); + if (cryptor != null) { + ByteBuffer[] ss = new ByteBuffer[]{ByteBuffer.wrap(content)}; + ByteBuffer[] bs = cryptor.encrypt(ss, supplier, consumer); + if (bs != ss) { + int r = 0; + for (ByteBuffer bb : bs) { + r += bb.remaining(); + } + content = new byte[r]; + int index = 0; + for (ByteBuffer bb : bs) { + int re = bb.remaining(); + bb.get(content, index, re); + index += re; + } + for (ByteBuffer bb : bs) { + consumer.accept(bb); + } + } + } final int len = content.length; if (len <= 0x7D) { //125 buffer.put(opcode); @@ -444,6 +467,10 @@ public final class WebSocketPacket { } void parseReceiveMessage(WebSocket webSocket, ByteBuffer... buffers) { + if (webSocket._engine.cryptor != null) { + HttpContext context = webSocket._engine.context; + buffers = webSocket._engine.cryptor.decrypt(buffers, context.getBufferSupplier(), context.getBufferConsumer()); + } if (this.type == FrameType.TEXT) { Convert textConvert = webSocket.getTextConvert(); if (textConvert == null) { diff --git a/src/org/redkale/net/http/WebSocketRunner.java b/src/org/redkale/net/http/WebSocketRunner.java index 1d2403218..f042b0c14 100644 --- a/src/org/redkale/net/http/WebSocketRunner.java +++ b/src/org/redkale/net/http/WebSocketRunner.java @@ -47,7 +47,7 @@ class WebSocketRunner implements Runnable { private final BiConsumer restMessageConsumer; //主要供RestWebSocket使用 protected long lastSendTime; - + protected long lastReadTime; WebSocketRunner(Context context, WebSocket webSocket, BiConsumer messageConsumer, AsyncConnection channel) { @@ -226,7 +226,7 @@ class WebSocketRunner implements Runnable { return futureResult; } } - ByteBuffer[] buffers = packet.sendBuffers != null ? packet.duplicateSendBuffers() : packet.encode(this.context.getBufferSupplier()); + ByteBuffer[] buffers = packet.sendBuffers != null ? packet.duplicateSendBuffers() : packet.encode(this.context.getBufferSupplier(), this.context.getBufferConsumer(), webSocket._engine.cryptor); if (debug) context.getLogger().log(Level.FINEST, "sending websocket message: " + packet); try { this.lastSendTime = System.currentTimeMillis(); @@ -276,7 +276,7 @@ class WebSocketRunner implements Runnable { } if (entry != null) { future = entry.future; - ByteBuffer[] buffers = entry.packet.sendBuffers != null ? entry.packet.duplicateSendBuffers() : entry.packet.encode(context.getBufferSupplier()); + ByteBuffer[] buffers = entry.packet.sendBuffers != null ? entry.packet.duplicateSendBuffers() : entry.packet.encode(context.getBufferSupplier(), context.getBufferConsumer(), webSocket._engine.cryptor); lastSendTime = System.currentTimeMillis(); if (debug) context.getLogger().log(Level.FINEST, "sending websocket message: " + entry.packet); channel.write(buffers, buffers, this); diff --git a/src/org/redkale/net/http/WebSocketServlet.java b/src/org/redkale/net/http/WebSocketServlet.java index ccd36e2e0..b7e21463e 100644 --- a/src/org/redkale/net/http/WebSocketServlet.java +++ b/src/org/redkale/net/http/WebSocketServlet.java @@ -17,6 +17,7 @@ import java.util.function.BiConsumer; import java.util.logging.*; import javax.annotation.*; import org.redkale.convert.Convert; +import org.redkale.net.Cryptor; import org.redkale.service.*; import org.redkale.util.*; @@ -52,6 +53,9 @@ public abstract class WebSocketServlet extends HttpServlet implements Resourcabl @Comment("最大消息体长度, 小于1表示无限制") public static final String WEBPARAM__WSMAXBODY = "wsmaxbody"; + @Comment("加密解密器") + public static final String WEBPARAM__CRYPTOR = "cryptor"; + @Comment("WebScoket服务器给客户端进行ping操作的默认间隔时间, 单位: 秒") public static final int DEFAILT_LIVEINTERVAL = 15; @@ -78,6 +82,9 @@ public abstract class WebSocketServlet extends HttpServlet implements Resourcabl //同RestWebSocket.anyuser protected boolean anyuser = false; + //同RestWebSocket.cryptor, 变量名不可改, 被Rest.createRestWebSocketServlet用到 + protected Cryptor cryptor; + @Resource(name = "jsonconvert") protected Convert jsonConvert; @@ -124,12 +131,26 @@ public abstract class WebSocketServlet extends HttpServlet implements Resourcabl if (logger.isLoggable(Level.WARNING)) logger.warning("Not found WebSocketNode, create a default value for " + getClass().getName()); } if (this.node.sendConvert == null) this.node.sendConvert = this.sendConvert; - + { + AnyValue props = conf; + if (conf != null && conf.getAnyValue("properties") != null) props = conf.getAnyValue("properties"); + if (props != null) { + String cryptorClass = props.getValue(WEBPARAM__CRYPTOR); + if (cryptorClass != null && !cryptorClass.isEmpty()) { + try { + this.cryptor = (Cryptor) Thread.currentThread().getContextClassLoader().loadClass(cryptorClass).getDeclaredConstructor().newInstance(); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + } + } //存在WebSocketServlet,则此WebSocketNode必须是本地模式Service this.node.localEngine = new WebSocketEngine("WebSocketEngine-" + addr.getHostString() + ":" + addr.getPort() + "-[" + resourceName() + "]", - this.single, context, liveinterval, wsmaxconns, wsmaxbody, this.node, this.sendConvert, logger); + this.single, context, liveinterval, wsmaxconns, wsmaxbody, this.cryptor, this.node, this.sendConvert, logger); this.node.init(conf); this.node.localEngine.init(conf); + } @Override