AsyncConnection增加SSLContext属性,便于以后增加SSL功能

This commit is contained in:
Redkale
2018-02-28 08:56:47 +08:00
parent a374e1278b
commit b49c334f9f
4 changed files with 73 additions and 22 deletions

View File

@@ -12,6 +12,7 @@ import java.nio.channels.*;
import java.util.*;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicLong;
import javax.net.ssl.SSLContext;
/**
*
@@ -22,6 +23,8 @@ import java.util.concurrent.atomic.AtomicLong;
*/
public abstract class AsyncConnection implements AsynchronousByteChannel, AutoCloseable {
protected SSLContext sslContext;
protected Map<String, Object> attributes; //用于存储绑定在Connection上的对象集合
protected Object subobject; //用于存储绑定在Connection上的对象 同attributes 只绑定单个对象时尽量使用subobject而非attributes
@@ -134,13 +137,30 @@ public abstract class AsyncConnection implements AsynchronousByteChannel, AutoCl
*/
public static CompletableFuture<AsyncConnection> createTCP(final AsynchronousChannelGroup group, final SocketAddress address,
final int readTimeoutSecond, final int writeTimeoutSecond) {
return createTCP(group, address, false, readTimeoutSecond, writeTimeoutSecond);
return createTCP(group, null, address, false, readTimeoutSecond, writeTimeoutSecond);
}
/**
* 创建TCP协议客户端连接
*
* @param address 连接点子
* @param sslContext SSLContext
* @param group 连接AsynchronousChannelGroup
* @param readTimeoutSecond 读取超时秒数
* @param writeTimeoutSecond 写入超时秒数
*
* @return 连接CompletableFuture
*/
public static CompletableFuture<AsyncConnection> createTCP(final AsynchronousChannelGroup group, final SSLContext sslContext,
final SocketAddress address, final int readTimeoutSecond, final int writeTimeoutSecond) {
return createTCP(group, sslContext, address, false, readTimeoutSecond, writeTimeoutSecond);
}
/**
* 创建TCP协议客户端连接
*
* @param address 连接点子
* @param sslContext SSLContext
* @param group 连接AsynchronousChannelGroup
* @param noDelay TcpNoDelay
* @param readTimeoutSecond 读取超时秒数
@@ -148,8 +168,8 @@ public abstract class AsyncConnection implements AsynchronousByteChannel, AutoCl
*
* @return 连接CompletableFuture
*/
public static CompletableFuture<AsyncConnection> createTCP(final AsynchronousChannelGroup group, final SocketAddress address,
final boolean noDelay, final int readTimeoutSecond, final int writeTimeoutSecond) {
public static CompletableFuture<AsyncConnection> createTCP(final AsynchronousChannelGroup group, final SSLContext sslContext,
final SocketAddress address, final boolean noDelay, final int readTimeoutSecond, final int writeTimeoutSecond) {
final CompletableFuture future = new CompletableFuture();
try {
final AsynchronousSocketChannel channel = AsynchronousSocketChannel.open(group);
@@ -162,7 +182,7 @@ public abstract class AsyncConnection implements AsynchronousByteChannel, AutoCl
} catch (IOException e) {
}
}
future.complete(create(channel, address, readTimeoutSecond, writeTimeoutSecond));
future.complete(create(channel, sslContext, address, readTimeoutSecond, writeTimeoutSecond));
}
@Override
@@ -482,8 +502,10 @@ public abstract class AsyncConnection implements AsynchronousByteChannel, AutoCl
private final SocketAddress remoteAddress;
public AIOTCPAsyncConnection(final AsynchronousSocketChannel ch, final SocketAddress addr0, final int readTimeoutSecond0, final int writeTimeoutSecond0) {
public AIOTCPAsyncConnection(final AsynchronousSocketChannel ch, SSLContext sslContext,
final SocketAddress addr0, final int readTimeoutSecond0, final int writeTimeoutSecond0) {
this.channel = ch;
this.sslContext = sslContext;
this.readTimeoutSecond = readTimeoutSecond0;
this.writeTimeoutSecond = writeTimeoutSecond0;
SocketAddress addr = addr0;
@@ -603,7 +625,14 @@ public abstract class AsyncConnection implements AsynchronousByteChannel, AutoCl
}
public static AsyncConnection create(final AsynchronousSocketChannel ch, final SocketAddress addr0, final int readTimeoutSecond, final int writeTimeoutSecond) {
return new AIOTCPAsyncConnection(ch, addr0, readTimeoutSecond, writeTimeoutSecond);
return new AIOTCPAsyncConnection(ch, null, addr0, readTimeoutSecond, writeTimeoutSecond);
}
public static AsyncConnection create(final AsynchronousSocketChannel ch, SSLContext sslContext, final SocketAddress addr0, final int readTimeoutSecond, final int writeTimeoutSecond) {
return new AIOTCPAsyncConnection(ch, sslContext, addr0, readTimeoutSecond, writeTimeoutSecond);
}
public static AsyncConnection create(final AsynchronousSocketChannel ch, final SocketAddress addr0, final Context context) {
return new AIOTCPAsyncConnection(ch, context.sslContext, addr0, context.readTimeoutSecond, context.writeTimeoutSecond);
}
}

View File

@@ -214,7 +214,7 @@ public abstract class ProtocolServer {
}
createCounter.incrementAndGet();
livingCounter.incrementAndGet();
AsyncConnection conn = AsyncConnection.create(channel, null, context.readTimeoutSecond, context.writeTimeoutSecond);
AsyncConnection conn = AsyncConnection.create(channel, null, context);
conn.livingCounter = livingCounter;
conn.closedCounter = closedCounter;
context.runAsync(new PrepareRunner(context, conn, null));

View File

@@ -13,6 +13,7 @@ import java.util.*;
import java.util.concurrent.*;
import java.util.function.Supplier;
import java.util.logging.Level;
import javax.net.ssl.SSLContext;
import org.redkale.convert.*;
import org.redkale.convert.json.JsonConvert;
import org.redkale.util.*;
@@ -62,20 +63,22 @@ public final class Transport {
protected final ObjectPool<ByteBuffer> bufferPool;
protected final SSLContext sslContext;
//负载均衡策略
protected final TransportStrategy strategy;
protected final ConcurrentHashMap<SocketAddress, BlockingQueue<AsyncConnection>> connPool = new ConcurrentHashMap<>();
protected Transport(String name, String subprotocol, TransportFactory factory, final ObjectPool<ByteBuffer> transportBufferPool,
final AsynchronousChannelGroup transportChannelGroup, final InetSocketAddress clientAddress,
final AsynchronousChannelGroup transportChannelGroup, final SSLContext sslContext, final InetSocketAddress clientAddress,
final Collection<InetSocketAddress> addresses, final TransportStrategy strategy) {
this(name, DEFAULT_PROTOCOL, subprotocol, factory, transportBufferPool, transportChannelGroup, clientAddress, addresses, strategy);
this(name, DEFAULT_PROTOCOL, subprotocol, factory, transportBufferPool, transportChannelGroup, sslContext, clientAddress, addresses, strategy);
}
protected Transport(String name, String protocol, String subprotocol,
final TransportFactory factory, final ObjectPool<ByteBuffer> transportBufferPool,
final AsynchronousChannelGroup transportChannelGroup, final InetSocketAddress clientAddress,
final AsynchronousChannelGroup transportChannelGroup, final SSLContext sslContext, final InetSocketAddress clientAddress,
final Collection<InetSocketAddress> addresses, final TransportStrategy strategy) {
this.name = name;
this.subprotocol = subprotocol == null ? "" : subprotocol.trim();
@@ -84,6 +87,7 @@ public final class Transport {
factory.transportReferences.add(new WeakReference<>(this));
this.tcp = "TCP".equalsIgnoreCase(protocol);
this.group = transportChannelGroup;
this.sslContext = sslContext;
this.bufferPool = transportBufferPool;
this.clientAddress = clientAddress;
this.strategy = strategy;
@@ -244,7 +248,7 @@ public final class Transport {
}
}
} else {
return AsyncConnection.createTCP(group, addr, supportTcpNoDelay, 6, 6);
return AsyncConnection.createTCP(group, sslContext, addr, supportTcpNoDelay, 6, 6);
}
if (channel == null) return CompletableFuture.completedFuture(null);
return CompletableFuture.completedFuture(AsyncConnection.create(channel, addr, 6, 6));

View File

@@ -16,6 +16,7 @@ import java.util.concurrent.atomic.*;
import java.util.function.Supplier;
import java.util.logging.*;
import java.util.stream.Collectors;
import javax.net.ssl.SSLContext;
import org.redkale.service.Service;
import org.redkale.util.*;
@@ -70,6 +71,8 @@ public class TransportFactory {
//ping的定时器
private ScheduledThreadPoolExecutor pingScheduler;
protected SSLContext sslContext;
//ping的内容
private ByteBuffer pingBuffer;
@@ -80,18 +83,19 @@ public class TransportFactory {
protected final TransportStrategy strategy;
protected TransportFactory(ExecutorService executor, ObjectPool<ByteBuffer> bufferPool, AsynchronousChannelGroup channelGroup,
int readTimeoutSecond, int writeTimeoutSecond, final TransportStrategy strategy) {
SSLContext sslContext, int readTimeoutSecond, int writeTimeoutSecond, final TransportStrategy strategy) {
this.executor = executor;
this.bufferPool = bufferPool;
this.channelGroup = channelGroup;
this.sslContext = sslContext;
this.readTimeoutSecond = readTimeoutSecond;
this.writeTimeoutSecond = writeTimeoutSecond;
this.strategy = strategy;
}
protected TransportFactory(ExecutorService executor, ObjectPool<ByteBuffer> bufferPool, AsynchronousChannelGroup channelGroup,
int readTimeoutSecond, int writeTimeoutSecond) {
this(executor, bufferPool, channelGroup, readTimeoutSecond, writeTimeoutSecond, null);
SSLContext sslContext, int readTimeoutSecond, int writeTimeoutSecond) {
this(executor, bufferPool, channelGroup, sslContext, readTimeoutSecond, writeTimeoutSecond, null);
}
public void init(AnyValue conf, ByteBuffer pingBuffer, int pongLength) {
@@ -146,30 +150,44 @@ public class TransportFactory {
}
public static TransportFactory create(ExecutorService executor, ObjectPool<ByteBuffer> bufferPool, AsynchronousChannelGroup channelGroup) {
return new TransportFactory(executor, bufferPool, channelGroup, DEFAULT_READTIMEOUTSECOND, DEFAULT_WRITETIMEOUTSECOND, null);
return new TransportFactory(executor, bufferPool, channelGroup, null, DEFAULT_READTIMEOUTSECOND, DEFAULT_WRITETIMEOUTSECOND, null);
}
public static TransportFactory create(ExecutorService executor, ObjectPool<ByteBuffer> bufferPool, AsynchronousChannelGroup channelGroup,
int readTimeoutSecond, int writeTimeoutSecond) {
return new TransportFactory(executor, bufferPool, channelGroup, readTimeoutSecond, writeTimeoutSecond, null);
return new TransportFactory(executor, bufferPool, channelGroup, null, readTimeoutSecond, writeTimeoutSecond, null);
}
public static TransportFactory create(ExecutorService executor, ObjectPool<ByteBuffer> bufferPool, AsynchronousChannelGroup channelGroup,
int readTimeoutSecond, int writeTimeoutSecond, final TransportStrategy strategy) {
return new TransportFactory(executor, bufferPool, channelGroup, readTimeoutSecond, writeTimeoutSecond, strategy);
return new TransportFactory(executor, bufferPool, channelGroup, null, readTimeoutSecond, writeTimeoutSecond, strategy);
}
public static TransportFactory create(ExecutorService executor, ObjectPool<ByteBuffer> bufferPool, AsynchronousChannelGroup channelGroup, SSLContext sslContext) {
return new TransportFactory(executor, bufferPool, channelGroup, sslContext, DEFAULT_READTIMEOUTSECOND, DEFAULT_WRITETIMEOUTSECOND, null);
}
public static TransportFactory create(ExecutorService executor, ObjectPool<ByteBuffer> bufferPool, AsynchronousChannelGroup channelGroup,
SSLContext sslContext, int readTimeoutSecond, int writeTimeoutSecond) {
return new TransportFactory(executor, bufferPool, channelGroup, sslContext, readTimeoutSecond, writeTimeoutSecond, null);
}
public static TransportFactory create(ExecutorService executor, ObjectPool<ByteBuffer> bufferPool, AsynchronousChannelGroup channelGroup,
SSLContext sslContext, int readTimeoutSecond, int writeTimeoutSecond, final TransportStrategy strategy) {
return new TransportFactory(executor, bufferPool, channelGroup, sslContext, readTimeoutSecond, writeTimeoutSecond, strategy);
}
public Transport createTransportTCP(String name, final InetSocketAddress clientAddress, final Collection<InetSocketAddress> addresses) {
return new Transport(name, "TCP", "", this, this.bufferPool, this.channelGroup, clientAddress, addresses, strategy);
return new Transport(name, "TCP", "", this, this.bufferPool, this.channelGroup, this.sslContext, clientAddress, addresses, strategy);
}
public Transport createTransport(String name, String protocol, final InetSocketAddress clientAddress, final Collection<InetSocketAddress> addresses) {
return new Transport(name, protocol, "", this, this.bufferPool, this.channelGroup, clientAddress, addresses, strategy);
return new Transport(name, protocol, "", this, this.bufferPool, this.channelGroup, this.sslContext, clientAddress, addresses, strategy);
}
public Transport createTransport(String name, String protocol, String subprotocol,
final InetSocketAddress clientAddress, final Collection<InetSocketAddress> addresses) {
return new Transport(name, protocol, subprotocol, this, this.bufferPool, this.channelGroup, clientAddress, addresses, strategy);
return new Transport(name, protocol, subprotocol, this, this.bufferPool, this.channelGroup, this.sslContext, clientAddress, addresses, strategy);
}
public String findGroupName(InetSocketAddress addr) {
@@ -250,14 +268,14 @@ public class TransportFactory {
}
if (info == null) info = new TransportGroupInfo("TCP");
if (sncpAddress != null) addresses.remove(sncpAddress);
return new Transport(groups.stream().sorted().collect(Collectors.joining(";")), info.protocol, info.subprotocol, this, this.bufferPool, this.channelGroup, sncpAddress, addresses, this.strategy);
return new Transport(groups.stream().sorted().collect(Collectors.joining(";")), info.protocol, info.subprotocol, this, this.bufferPool, this.channelGroup, this.sslContext, sncpAddress, addresses, this.strategy);
}
private Transport loadTransport(final String groupName, InetSocketAddress sncpAddress) {
if (groupName == null) return null;
TransportGroupInfo info = groupInfos.get(groupName);
if (info == null) return null;
return new Transport(groupName, info.protocol, info.subprotocol, this, this.bufferPool, this.channelGroup, sncpAddress, info.addresses, this.strategy);
return new Transport(groupName, info.protocol, info.subprotocol, this, this.bufferPool, this.channelGroup, this.sslContext, sncpAddress, info.addresses, this.strategy);
}
public ExecutorService getExecutor() {