From b49c334f9f90d03bc1431adc15adba46fd53f4b1 Mon Sep 17 00:00:00 2001 From: Redkale <22250530@qq.com> Date: Wed, 28 Feb 2018 08:56:47 +0800 Subject: [PATCH] =?UTF-8?q?AsyncConnection=E5=A2=9E=E5=8A=A0SSLContext?= =?UTF-8?q?=E5=B1=9E=E6=80=A7=EF=BC=8C=E4=BE=BF=E4=BA=8E=E4=BB=A5=E5=90=8E?= =?UTF-8?q?=E5=A2=9E=E5=8A=A0SSL=E5=8A=9F=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/org/redkale/net/AsyncConnection.java | 41 +++++++++++++++++++---- src/org/redkale/net/ProtocolServer.java | 2 +- src/org/redkale/net/Transport.java | 12 ++++--- src/org/redkale/net/TransportFactory.java | 40 ++++++++++++++++------ 4 files changed, 73 insertions(+), 22 deletions(-) diff --git a/src/org/redkale/net/AsyncConnection.java b/src/org/redkale/net/AsyncConnection.java index b2845073a..6a75cba0d 100644 --- a/src/org/redkale/net/AsyncConnection.java +++ b/src/org/redkale/net/AsyncConnection.java @@ -12,6 +12,7 @@ import java.nio.channels.*; import java.util.*; import java.util.concurrent.*; import java.util.concurrent.atomic.AtomicLong; +import javax.net.ssl.SSLContext; /** * @@ -22,6 +23,8 @@ import java.util.concurrent.atomic.AtomicLong; */ public abstract class AsyncConnection implements AsynchronousByteChannel, AutoCloseable { + protected SSLContext sslContext; + protected Map attributes; //用于存储绑定在Connection上的对象集合 protected Object subobject; //用于存储绑定在Connection上的对象, 同attributes, 只绑定单个对象时尽量使用subobject而非attributes @@ -134,13 +137,30 @@ public abstract class AsyncConnection implements AsynchronousByteChannel, AutoCl */ public static CompletableFuture createTCP(final AsynchronousChannelGroup group, final SocketAddress address, final int readTimeoutSecond, final int writeTimeoutSecond) { - return createTCP(group, address, false, readTimeoutSecond, writeTimeoutSecond); + return createTCP(group, null, address, false, readTimeoutSecond, writeTimeoutSecond); } /** * 创建TCP协议客户端连接 * * @param address 连接点子 + * @param sslContext SSLContext + * @param group 连接AsynchronousChannelGroup + * @param readTimeoutSecond 读取超时秒数 + * @param writeTimeoutSecond 写入超时秒数 + * + * @return 连接CompletableFuture + */ + public static CompletableFuture createTCP(final AsynchronousChannelGroup group, final SSLContext sslContext, + final SocketAddress address, final int readTimeoutSecond, final int writeTimeoutSecond) { + return createTCP(group, sslContext, address, false, readTimeoutSecond, writeTimeoutSecond); + } + + /** + * 创建TCP协议客户端连接 + * + * @param address 连接点子 + * @param sslContext SSLContext * @param group 连接AsynchronousChannelGroup * @param noDelay TcpNoDelay * @param readTimeoutSecond 读取超时秒数 @@ -148,8 +168,8 @@ public abstract class AsyncConnection implements AsynchronousByteChannel, AutoCl * * @return 连接CompletableFuture */ - public static CompletableFuture createTCP(final AsynchronousChannelGroup group, final SocketAddress address, - final boolean noDelay, final int readTimeoutSecond, final int writeTimeoutSecond) { + public static CompletableFuture createTCP(final AsynchronousChannelGroup group, final SSLContext sslContext, + final SocketAddress address, final boolean noDelay, final int readTimeoutSecond, final int writeTimeoutSecond) { final CompletableFuture future = new CompletableFuture(); try { final AsynchronousSocketChannel channel = AsynchronousSocketChannel.open(group); @@ -162,7 +182,7 @@ public abstract class AsyncConnection implements AsynchronousByteChannel, AutoCl } catch (IOException e) { } } - future.complete(create(channel, address, readTimeoutSecond, writeTimeoutSecond)); + future.complete(create(channel, sslContext, address, readTimeoutSecond, writeTimeoutSecond)); } @Override @@ -482,8 +502,10 @@ public abstract class AsyncConnection implements AsynchronousByteChannel, AutoCl private final SocketAddress remoteAddress; - public AIOTCPAsyncConnection(final AsynchronousSocketChannel ch, final SocketAddress addr0, final int readTimeoutSecond0, final int writeTimeoutSecond0) { + public AIOTCPAsyncConnection(final AsynchronousSocketChannel ch, SSLContext sslContext, + final SocketAddress addr0, final int readTimeoutSecond0, final int writeTimeoutSecond0) { this.channel = ch; + this.sslContext = sslContext; this.readTimeoutSecond = readTimeoutSecond0; this.writeTimeoutSecond = writeTimeoutSecond0; SocketAddress addr = addr0; @@ -603,7 +625,14 @@ public abstract class AsyncConnection implements AsynchronousByteChannel, AutoCl } public static AsyncConnection create(final AsynchronousSocketChannel ch, final SocketAddress addr0, final int readTimeoutSecond, final int writeTimeoutSecond) { - return new AIOTCPAsyncConnection(ch, addr0, readTimeoutSecond, writeTimeoutSecond); + return new AIOTCPAsyncConnection(ch, null, addr0, readTimeoutSecond, writeTimeoutSecond); } + public static AsyncConnection create(final AsynchronousSocketChannel ch, SSLContext sslContext, final SocketAddress addr0, final int readTimeoutSecond, final int writeTimeoutSecond) { + return new AIOTCPAsyncConnection(ch, sslContext, addr0, readTimeoutSecond, writeTimeoutSecond); + } + + public static AsyncConnection create(final AsynchronousSocketChannel ch, final SocketAddress addr0, final Context context) { + return new AIOTCPAsyncConnection(ch, context.sslContext, addr0, context.readTimeoutSecond, context.writeTimeoutSecond); + } } diff --git a/src/org/redkale/net/ProtocolServer.java b/src/org/redkale/net/ProtocolServer.java index b166f0a66..57cdbdca9 100644 --- a/src/org/redkale/net/ProtocolServer.java +++ b/src/org/redkale/net/ProtocolServer.java @@ -214,7 +214,7 @@ public abstract class ProtocolServer { } createCounter.incrementAndGet(); livingCounter.incrementAndGet(); - AsyncConnection conn = AsyncConnection.create(channel, null, context.readTimeoutSecond, context.writeTimeoutSecond); + AsyncConnection conn = AsyncConnection.create(channel, null, context); conn.livingCounter = livingCounter; conn.closedCounter = closedCounter; context.runAsync(new PrepareRunner(context, conn, null)); diff --git a/src/org/redkale/net/Transport.java b/src/org/redkale/net/Transport.java index 69278c6c4..24e5cd35f 100644 --- a/src/org/redkale/net/Transport.java +++ b/src/org/redkale/net/Transport.java @@ -13,6 +13,7 @@ import java.util.*; import java.util.concurrent.*; import java.util.function.Supplier; import java.util.logging.Level; +import javax.net.ssl.SSLContext; import org.redkale.convert.*; import org.redkale.convert.json.JsonConvert; import org.redkale.util.*; @@ -62,20 +63,22 @@ public final class Transport { protected final ObjectPool bufferPool; + protected final SSLContext sslContext; + //负载均衡策略 protected final TransportStrategy strategy; protected final ConcurrentHashMap> connPool = new ConcurrentHashMap<>(); protected Transport(String name, String subprotocol, TransportFactory factory, final ObjectPool transportBufferPool, - final AsynchronousChannelGroup transportChannelGroup, final InetSocketAddress clientAddress, + final AsynchronousChannelGroup transportChannelGroup, final SSLContext sslContext, final InetSocketAddress clientAddress, final Collection addresses, final TransportStrategy strategy) { - this(name, DEFAULT_PROTOCOL, subprotocol, factory, transportBufferPool, transportChannelGroup, clientAddress, addresses, strategy); + this(name, DEFAULT_PROTOCOL, subprotocol, factory, transportBufferPool, transportChannelGroup, sslContext, clientAddress, addresses, strategy); } protected Transport(String name, String protocol, String subprotocol, final TransportFactory factory, final ObjectPool transportBufferPool, - final AsynchronousChannelGroup transportChannelGroup, final InetSocketAddress clientAddress, + final AsynchronousChannelGroup transportChannelGroup, final SSLContext sslContext, final InetSocketAddress clientAddress, final Collection addresses, final TransportStrategy strategy) { this.name = name; this.subprotocol = subprotocol == null ? "" : subprotocol.trim(); @@ -84,6 +87,7 @@ public final class Transport { factory.transportReferences.add(new WeakReference<>(this)); this.tcp = "TCP".equalsIgnoreCase(protocol); this.group = transportChannelGroup; + this.sslContext = sslContext; this.bufferPool = transportBufferPool; this.clientAddress = clientAddress; this.strategy = strategy; @@ -244,7 +248,7 @@ public final class Transport { } } } else { - return AsyncConnection.createTCP(group, addr, supportTcpNoDelay, 6, 6); + return AsyncConnection.createTCP(group, sslContext, addr, supportTcpNoDelay, 6, 6); } if (channel == null) return CompletableFuture.completedFuture(null); return CompletableFuture.completedFuture(AsyncConnection.create(channel, addr, 6, 6)); diff --git a/src/org/redkale/net/TransportFactory.java b/src/org/redkale/net/TransportFactory.java index d69991e02..3c84c6ffb 100644 --- a/src/org/redkale/net/TransportFactory.java +++ b/src/org/redkale/net/TransportFactory.java @@ -16,6 +16,7 @@ import java.util.concurrent.atomic.*; import java.util.function.Supplier; import java.util.logging.*; import java.util.stream.Collectors; +import javax.net.ssl.SSLContext; import org.redkale.service.Service; import org.redkale.util.*; @@ -70,6 +71,8 @@ public class TransportFactory { //ping的定时器 private ScheduledThreadPoolExecutor pingScheduler; + protected SSLContext sslContext; + //ping的内容 private ByteBuffer pingBuffer; @@ -80,18 +83,19 @@ public class TransportFactory { protected final TransportStrategy strategy; protected TransportFactory(ExecutorService executor, ObjectPool bufferPool, AsynchronousChannelGroup channelGroup, - int readTimeoutSecond, int writeTimeoutSecond, final TransportStrategy strategy) { + SSLContext sslContext, int readTimeoutSecond, int writeTimeoutSecond, final TransportStrategy strategy) { this.executor = executor; this.bufferPool = bufferPool; this.channelGroup = channelGroup; + this.sslContext = sslContext; this.readTimeoutSecond = readTimeoutSecond; this.writeTimeoutSecond = writeTimeoutSecond; this.strategy = strategy; } protected TransportFactory(ExecutorService executor, ObjectPool bufferPool, AsynchronousChannelGroup channelGroup, - int readTimeoutSecond, int writeTimeoutSecond) { - this(executor, bufferPool, channelGroup, readTimeoutSecond, writeTimeoutSecond, null); + SSLContext sslContext, int readTimeoutSecond, int writeTimeoutSecond) { + this(executor, bufferPool, channelGroup, sslContext, readTimeoutSecond, writeTimeoutSecond, null); } public void init(AnyValue conf, ByteBuffer pingBuffer, int pongLength) { @@ -146,30 +150,44 @@ public class TransportFactory { } public static TransportFactory create(ExecutorService executor, ObjectPool bufferPool, AsynchronousChannelGroup channelGroup) { - return new TransportFactory(executor, bufferPool, channelGroup, DEFAULT_READTIMEOUTSECOND, DEFAULT_WRITETIMEOUTSECOND, null); + return new TransportFactory(executor, bufferPool, channelGroup, null, DEFAULT_READTIMEOUTSECOND, DEFAULT_WRITETIMEOUTSECOND, null); } public static TransportFactory create(ExecutorService executor, ObjectPool bufferPool, AsynchronousChannelGroup channelGroup, int readTimeoutSecond, int writeTimeoutSecond) { - return new TransportFactory(executor, bufferPool, channelGroup, readTimeoutSecond, writeTimeoutSecond, null); + return new TransportFactory(executor, bufferPool, channelGroup, null, readTimeoutSecond, writeTimeoutSecond, null); } public static TransportFactory create(ExecutorService executor, ObjectPool bufferPool, AsynchronousChannelGroup channelGroup, int readTimeoutSecond, int writeTimeoutSecond, final TransportStrategy strategy) { - return new TransportFactory(executor, bufferPool, channelGroup, readTimeoutSecond, writeTimeoutSecond, strategy); + return new TransportFactory(executor, bufferPool, channelGroup, null, readTimeoutSecond, writeTimeoutSecond, strategy); + } + + public static TransportFactory create(ExecutorService executor, ObjectPool bufferPool, AsynchronousChannelGroup channelGroup, SSLContext sslContext) { + return new TransportFactory(executor, bufferPool, channelGroup, sslContext, DEFAULT_READTIMEOUTSECOND, DEFAULT_WRITETIMEOUTSECOND, null); + } + + public static TransportFactory create(ExecutorService executor, ObjectPool bufferPool, AsynchronousChannelGroup channelGroup, + SSLContext sslContext, int readTimeoutSecond, int writeTimeoutSecond) { + return new TransportFactory(executor, bufferPool, channelGroup, sslContext, readTimeoutSecond, writeTimeoutSecond, null); + } + + public static TransportFactory create(ExecutorService executor, ObjectPool bufferPool, AsynchronousChannelGroup channelGroup, + SSLContext sslContext, int readTimeoutSecond, int writeTimeoutSecond, final TransportStrategy strategy) { + return new TransportFactory(executor, bufferPool, channelGroup, sslContext, readTimeoutSecond, writeTimeoutSecond, strategy); } public Transport createTransportTCP(String name, final InetSocketAddress clientAddress, final Collection addresses) { - return new Transport(name, "TCP", "", this, this.bufferPool, this.channelGroup, clientAddress, addresses, strategy); + return new Transport(name, "TCP", "", this, this.bufferPool, this.channelGroup, this.sslContext, clientAddress, addresses, strategy); } public Transport createTransport(String name, String protocol, final InetSocketAddress clientAddress, final Collection addresses) { - return new Transport(name, protocol, "", this, this.bufferPool, this.channelGroup, clientAddress, addresses, strategy); + return new Transport(name, protocol, "", this, this.bufferPool, this.channelGroup, this.sslContext, clientAddress, addresses, strategy); } public Transport createTransport(String name, String protocol, String subprotocol, final InetSocketAddress clientAddress, final Collection addresses) { - return new Transport(name, protocol, subprotocol, this, this.bufferPool, this.channelGroup, clientAddress, addresses, strategy); + return new Transport(name, protocol, subprotocol, this, this.bufferPool, this.channelGroup, this.sslContext, clientAddress, addresses, strategy); } public String findGroupName(InetSocketAddress addr) { @@ -250,14 +268,14 @@ public class TransportFactory { } if (info == null) info = new TransportGroupInfo("TCP"); if (sncpAddress != null) addresses.remove(sncpAddress); - return new Transport(groups.stream().sorted().collect(Collectors.joining(";")), info.protocol, info.subprotocol, this, this.bufferPool, this.channelGroup, sncpAddress, addresses, this.strategy); + return new Transport(groups.stream().sorted().collect(Collectors.joining(";")), info.protocol, info.subprotocol, this, this.bufferPool, this.channelGroup, this.sslContext, sncpAddress, addresses, this.strategy); } private Transport loadTransport(final String groupName, InetSocketAddress sncpAddress) { if (groupName == null) return null; TransportGroupInfo info = groupInfos.get(groupName); if (info == null) return null; - return new Transport(groupName, info.protocol, info.subprotocol, this, this.bufferPool, this.channelGroup, sncpAddress, info.addresses, this.strategy); + return new Transport(groupName, info.protocol, info.subprotocol, this, this.bufferPool, this.channelGroup, this.sslContext, sncpAddress, info.addresses, this.strategy); } public ExecutorService getExecutor() {