From e250a593a7c2317371a0539c0874e7d41e28d3a3 Mon Sep 17 00:00:00 2001 From: Redkale <22250530@qq.com> Date: Thu, 29 Mar 2018 11:14:13 +0800 Subject: [PATCH] =?UTF-8?q?=E5=AE=8C=E5=96=84Transport.pollConnection?= =?UTF-8?q?=E4=B8=AD=E8=BF=9E=E6=8E=A5=E6=B1=A0=E5=8A=9F=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/org/redkale/net/Transport.java | 230 ++++++++++-------- src/org/redkale/net/TransportFactory.java | 5 +- .../test/{net => http}/UploadTestServlet.java | 2 +- test/org/redkale/test/net/TransportTest.java | 57 +++++ 4 files changed, 195 insertions(+), 99 deletions(-) rename test/org/redkale/test/{net => http}/UploadTestServlet.java (96%) create mode 100644 test/org/redkale/test/net/TransportTest.java diff --git a/src/org/redkale/net/Transport.java b/src/org/redkale/net/Transport.java index 24e5cd35f..dcd018176 100644 --- a/src/org/redkale/net/Transport.java +++ b/src/org/redkale/net/Transport.java @@ -11,6 +11,7 @@ import java.nio.ByteBuffer; import java.nio.channels.*; import java.util.*; import java.util.concurrent.*; +import java.util.concurrent.atomic.AtomicInteger; import java.util.function.Supplier; import java.util.logging.Level; import javax.net.ssl.SSLContext; @@ -30,7 +31,7 @@ public final class Transport { public static final String DEFAULT_PROTOCOL = "TCP"; - protected static final int MAX_POOL_LIMIT = Runtime.getRuntime().availableProcessors() * 16; + protected static final int MAX_POOL_LIMIT = Runtime.getRuntime().availableProcessors() * 8; protected static final boolean supportTcpNoDelay; @@ -59,7 +60,8 @@ public final class Transport { protected final InetSocketAddress clientAddress; - protected TransportAddress[] transportAddres = new TransportAddress[0]; + //不可能为null + protected TransportAddress[] transportAddrs = new TransportAddress[0]; protected final ObjectPool bufferPool; @@ -68,8 +70,6 @@ public final class Transport { //负载均衡策略 protected final TransportStrategy strategy; - protected final ConcurrentHashMap> connPool = new ConcurrentHashMap<>(); - protected Transport(String name, String subprotocol, TransportFactory factory, final ObjectPool transportBufferPool, final AsynchronousChannelGroup transportChannelGroup, final SSLContext sslContext, final InetSocketAddress clientAddress, final Collection addresses, final TransportStrategy strategy) { @@ -95,16 +95,26 @@ public final class Transport { } public final InetSocketAddress[] updateRemoteAddresses(final Collection addresses) { - TransportAddress[] oldAddresses = this.transportAddres; - List list = new ArrayList<>(); - if (addresses != null) { - for (InetSocketAddress addr : addresses) { - if (clientAddress != null && clientAddress.equals(addr)) continue; - list.add(new TransportAddress(addr)); + final TransportAddress[] oldAddresses = this.transportAddrs; + synchronized (this) { + List list = new ArrayList<>(); + if (addresses != null) { + for (InetSocketAddress addr : addresses) { + if (clientAddress != null && clientAddress.equals(addr)) continue; + boolean hasold = false; + for (TransportAddress oldAddr : oldAddresses) { + if (oldAddr.getAddress().equals(addr)) { + list.add(oldAddr); + hasold = true; + break; + } + } + if (hasold) continue; + list.add(new TransportAddress(addr)); + } } + this.transportAddrs = list.toArray(new TransportAddress[list.size()]); } - this.transportAddres = list.toArray(new TransportAddress[list.size()]); - InetSocketAddress[] rs = new InetSocketAddress[oldAddresses.length]; for (int i = 0; i < rs.length; i++) { rs[i] = oldAddresses[i].getAddress(); @@ -114,14 +124,15 @@ public final class Transport { public final boolean addRemoteAddresses(final InetSocketAddress addr) { if (addr == null) return false; + if (clientAddress != null && clientAddress.equals(addr)) return false; synchronized (this) { - if (this.transportAddres == null) { - this.transportAddres = new TransportAddress[]{new TransportAddress(addr)}; + if (this.transportAddrs.length == 0) { + this.transportAddrs = new TransportAddress[]{new TransportAddress(addr)}; } else { - for (TransportAddress i : this.transportAddres) { + for (TransportAddress i : this.transportAddrs) { if (addr.equals(i.address)) return false; } - this.transportAddres = Utility.append(transportAddres, new TransportAddress(addr)); + this.transportAddrs = Utility.append(transportAddrs, new TransportAddress(addr)); } return true; } @@ -129,9 +140,8 @@ public final class Transport { public final boolean removeRemoteAddresses(InetSocketAddress addr) { if (addr == null) return false; - if (this.transportAddres == null) return false; synchronized (this) { - this.transportAddres = Utility.remove(transportAddres, new TransportAddress(addr)); + this.transportAddrs = Utility.remove(transportAddrs, new TransportAddress(addr)); } return true; } @@ -145,7 +155,11 @@ public final class Transport { } public void close() { - connPool.forEach((k, v) -> v.forEach(c -> c.dispose())); + TransportAddress[] taddrs = this.transportAddrs; + if (taddrs == null) return; + for (TransportAddress taddr : taddrs) { + if (taddr != null) taddr.dispose(); + } } public InetSocketAddress getClientAddress() { @@ -153,24 +167,27 @@ public final class Transport { } public TransportAddress[] getTransportAddresses() { - return transportAddres; + return transportAddrs; + } + + public TransportAddress findTransportAddress(SocketAddress addr) { + for (TransportAddress taddr : this.transportAddrs) { + if (taddr.address.equals(addr)) return taddr; + } + return null; } public InetSocketAddress[] getRemoteAddresses() { - InetSocketAddress[] rs = new InetSocketAddress[transportAddres.length]; + InetSocketAddress[] rs = new InetSocketAddress[transportAddrs.length]; for (int i = 0; i < rs.length; i++) { - rs[i] = transportAddres[i].getAddress(); + rs[i] = transportAddrs[i].getAddress(); } return rs; } - public ConcurrentHashMap> getAsyncConnectionPool() { - return connPool; - } - @Override public String toString() { - return Transport.class.getSimpleName() + "{name = " + name + ", protocol = " + protocol + ", clientAddress = " + clientAddress + ", remoteAddres = " + Arrays.toString(transportAddres) + "}"; + return Transport.class.getSimpleName() + "{name = " + name + ", protocol = " + protocol + ", clientAddress = " + clientAddress + ", remoteAddres = " + Arrays.toString(transportAddrs) + "}"; } public ByteBuffer pollBuffer() { @@ -189,76 +206,93 @@ public final class Transport { for (ByteBuffer buffer : buffers) offerBuffer(buffer); } + public AsynchronousChannelGroup getTransportChannelGroup() { + return group; + } + public boolean isTCP() { return tcp; } - public CompletableFuture pollConnection(SocketAddress addr) { - if (this.strategy != null) return strategy.pollConnection(addr, this); - if (addr == null && this.transportAddres.length == 1) addr = this.transportAddres[0].address; - final boolean rand = addr == null; - if (rand && this.transportAddres.length < 1) throw new RuntimeException("Transport (" + this.name + ") have no remoteAddress list"); + public CompletableFuture pollConnection(SocketAddress addr0) { + if (this.strategy != null) return strategy.pollConnection(addr0, this); + if (addr0 == null && this.transportAddrs.length == 1) addr0 = this.transportAddrs[0].address; + final SocketAddress addr = addr0; + final boolean rand = addr == null; //是否随机取地址 + if (rand && this.transportAddrs.length < 1) throw new RuntimeException("Transport (" + this.name + ") have no remoteAddress list"); try { - if (tcp) { - AsynchronousSocketChannel channel = null; - if (rand) { //取地址 - TransportAddress transportAddr; - boolean tryed = false; - for (int i = 0; i < transportAddres.length; i++) { - transportAddr = transportAddres[i]; - addr = transportAddr.address; - if (!transportAddr.enable) continue; - final BlockingQueue queue = transportAddr.conns; - if (!queue.isEmpty()) { - AsyncConnection conn; - while ((conn = queue.poll()) != null) { - if (conn.isOpen()) return CompletableFuture.completedFuture(conn); - } - } - tryed = true; - if (channel == null) { - channel = AsynchronousSocketChannel.open(group); - if (supportTcpNoDelay) channel.setOption(StandardSocketOptions.TCP_NODELAY, true); - } - try { - channel.connect(addr).get(2, TimeUnit.SECONDS); - transportAddr.enable = true; - break; - } catch (Exception iex) { - transportAddr.enable = false; - channel = null; - } - } - if (channel == null && !tryed) { - for (int i = 0; i < transportAddres.length; i++) { - transportAddr = transportAddres[i]; - addr = transportAddr.address; - if (channel == null) { - channel = AsynchronousSocketChannel.open(group); - if (supportTcpNoDelay) channel.setOption(StandardSocketOptions.TCP_NODELAY, true); - } - try { - channel.connect(addr).get(2, TimeUnit.SECONDS); - transportAddr.enable = true; - break; - } catch (Exception iex) { - transportAddr.enable = false; - channel = null; - } - } - } - } else { - return AsyncConnection.createTCP(group, sslContext, addr, supportTcpNoDelay, 6, 6); - } - if (channel == null) return CompletableFuture.completedFuture(null); - return CompletableFuture.completedFuture(AsyncConnection.create(channel, addr, 6, 6)); - } else { // UDP - if (rand) addr = this.transportAddres[0].address; + if (!tcp) { // UDP + SocketAddress udpaddr = rand ? this.transportAddrs[0].address : addr; DatagramChannel channel = DatagramChannel.open(); channel.configureBlocking(true); - channel.connect(addr); - return CompletableFuture.completedFuture(AsyncConnection.create(channel, addr, true, 6, 6)); + channel.connect(udpaddr); + return CompletableFuture.completedFuture(AsyncConnection.create(channel, udpaddr, true, factory.readTimeoutSecond, factory.writeTimeoutSecond)); } + if (!rand) { //指定地址 + TransportAddress taddr = findTransportAddress(addr); + if (taddr == null) { + return AsyncConnection.createTCP(group, sslContext, addr, supportTcpNoDelay, factory.readTimeoutSecond, factory.writeTimeoutSecond); + } + final BlockingQueue queue = taddr.conns; + if (!queue.isEmpty()) { + AsyncConnection conn; + while ((conn = queue.poll()) != null) { + if (conn.isOpen()) return CompletableFuture.completedFuture(conn); + } + } + return AsyncConnection.createTCP(group, sslContext, addr, supportTcpNoDelay, factory.readTimeoutSecond, factory.writeTimeoutSecond); + } + + //---------------------随机取地址------------------------ + //从连接池里取 + for (final TransportAddress taddr : this.transportAddrs) { + if (!taddr.enable) continue; + final BlockingQueue queue = taddr.conns; + if (!queue.isEmpty()) { + AsyncConnection conn; + while ((conn = queue.poll()) != null) { + if (conn.isOpen()) return CompletableFuture.completedFuture(conn); + } + } + } + //从可用/不可用的地址列表中创建连接 + AtomicInteger count = new AtomicInteger(this.transportAddrs.length); + CompletableFuture future = new CompletableFuture(); + for (final TransportAddress taddr : this.transportAddrs) { + if (future.isDone()) return future; + final AsynchronousSocketChannel channel = AsynchronousSocketChannel.open(group); + if (supportTcpNoDelay) channel.setOption(StandardSocketOptions.TCP_NODELAY, true); + channel.connect(taddr.address, taddr, new CompletionHandler() { + @Override + public void completed(Void result, TransportAddress attachment) { + taddr.enable = true; + AsyncConnection asyncConn = AsyncConnection.create(channel, attachment.address, factory.readTimeoutSecond, factory.writeTimeoutSecond); + if (future.isDone()) { + if (!attachment.conns.offer(asyncConn)) { + try { + channel.close(); + } catch (Exception e) { + } + } + } else { + future.complete(asyncConn); + } + } + + @Override + public void failed(Throwable exc, TransportAddress attachment) { + taddr.enable = false; + if (count.decrementAndGet() < 1) { + future.completeExceptionally(exc); + } + try { + channel.close(); + } catch (Exception e) { + } + } + }); + } + return future; } catch (Exception ex) { throw new RuntimeException("transport address = " + addr, ex); } @@ -267,12 +301,8 @@ public final class Transport { public void offerConnection(final boolean forceClose, AsyncConnection conn) { if (!forceClose && conn.isTCP()) { if (conn.isOpen()) { - BlockingQueue queue = connPool.get(conn.getRemoteAddress()); - if (queue == null) { - queue = new ArrayBlockingQueue<>(MAX_POOL_LIMIT); - connPool.put(conn.getRemoteAddress(), queue); - } - if (!queue.offer(conn)) conn.dispose(); + TransportAddress taddr = findTransportAddress(conn.getRemoteAddress()); + if (taddr == null || !taddr.conns.offer(conn)) conn.dispose(); } } else { conn.dispose(); @@ -344,11 +374,18 @@ public final class Transport { return enable; } - @ConvertColumn(ignore = true) + @ConvertDisabled public BlockingQueue getConns() { return conns; } + public void dispose() { + AsyncConnection conn; + while ((conn = conns.poll()) != null) { + conn.dispose(); + } + } + @Override public int hashCode() { return this.address.hashCode(); @@ -363,6 +400,7 @@ public final class Transport { return this.address.equals(other.address); } + @Override public String toString() { return JsonConvert.root().convertTo(this); } diff --git a/src/org/redkale/net/TransportFactory.java b/src/org/redkale/net/TransportFactory.java index 3c84c6ffb..383bee65c 100644 --- a/src/org/redkale/net/TransportFactory.java +++ b/src/org/redkale/net/TransportFactory.java @@ -326,8 +326,9 @@ public class TransportFactory { nulllist.add(ref); continue; } - List> list = new ArrayList<>(transport.getAsyncConnectionPool().values()); - for (final BlockingQueue queue : list) { + Transport.TransportAddress[] taddrs = transport.getTransportAddresses(); + for (final Transport.TransportAddress taddr : taddrs) { + final BlockingQueue queue = taddr.conns; AsyncConnection conn; while ((conn = queue.poll()) != null) { if (conn.getLastWriteTime() > timex && false) { //最近几秒内已经进行过IO操作 diff --git a/test/org/redkale/test/net/UploadTestServlet.java b/test/org/redkale/test/http/UploadTestServlet.java similarity index 96% rename from test/org/redkale/test/net/UploadTestServlet.java rename to test/org/redkale/test/http/UploadTestServlet.java index dfc918877..01bb550fe 100644 --- a/test/org/redkale/test/net/UploadTestServlet.java +++ b/test/org/redkale/test/http/UploadTestServlet.java @@ -3,7 +3,7 @@ * To change this template file, choose Tools | Templates * and open the template in the editor. */ -package org.redkale.test.net; +package org.redkale.test.http; import org.redkale.net.http.HttpServlet; import org.redkale.net.http.MultiPart; diff --git a/test/org/redkale/test/net/TransportTest.java b/test/org/redkale/test/net/TransportTest.java new file mode 100644 index 000000000..e13ade230 --- /dev/null +++ b/test/org/redkale/test/net/TransportTest.java @@ -0,0 +1,57 @@ +/* + * To change this license header, choose License Headers in Project Properties. + * To change this template file, choose Tools | Templates + * and open the template in the editor. + */ +package org.redkale.test.net; + +import java.net.InetSocketAddress; +import java.util.*; +import org.redkale.net.*; +import org.redkale.net.http.HttpServer; +import org.redkale.net.sncp.Sncp; +import org.redkale.util.AnyValue.DefaultAnyValue; + +/** + * + * @author zhangjx + */ +public class TransportTest { + + private static final String format = "%1$tY-%1$tm-%1$td %1$tH:%1$tM:%1$tS.%tL"; + + public static void main(String[] args) throws Throwable { + + List addrs = new ArrayList<>(); + addrs.add(new InetSocketAddress("127.0.0.1", 22001)); + addrs.add(new InetSocketAddress("127.0.0.1", 22002)); + addrs.add(new InetSocketAddress("127.0.0.1", 22003)); + addrs.add(new InetSocketAddress("127.0.0.1", 22004)); + for (InetSocketAddress servaddr : addrs) { + //if (servaddr.getPort() % 100 == 4) continue; + HttpServer server = new HttpServer(); + DefaultAnyValue servconf = DefaultAnyValue.create("port", servaddr.getPort()); + server.init(servconf); + server.start(); + } + addrs.add(new InetSocketAddress("127.0.0.1", 22005)); + Thread.sleep(1000); + TransportFactory factory = TransportFactory.create(10); + DefaultAnyValue conf = DefaultAnyValue.create(TransportFactory.NAME_PINGINTERVAL, 5); + factory.init(conf, Sncp.PING_BUFFER, Sncp.PONG_BUFFER.remaining()); + Transport transport = factory.createTransportTCP("", null, addrs); + System.out.println(String.format(format, System.currentTimeMillis())); + try { + AsyncConnection firstconn = transport.pollConnection(null).join(); + System.out.println(firstconn); + if (firstconn != null) transport.offerConnection(false, firstconn); + AsyncConnection conn = transport.pollConnection(null).join(); + System.out.println(conn + "-------应该与前值相同"); + conn = transport.pollConnection(null).join(); + System.out.println(conn + "-------应该与前值不同"); + } finally { + System.out.println(String.format(format, System.currentTimeMillis())); + } + } + +}