diff --git a/src/main/java/org/redkale/net/client/Client.java b/src/main/java/org/redkale/net/client/Client.java index 9c2c9781e..3588ab505 100644 --- a/src/main/java/org/redkale/net/client/Client.java +++ b/src/main/java/org/redkale/net/client/Client.java @@ -6,7 +6,7 @@ package org.redkale.net.client; import java.net.SocketAddress; -import java.util.Queue; +import java.util.*; import java.util.concurrent.*; import java.util.concurrent.atomic.*; import java.util.function.*; @@ -254,6 +254,50 @@ public abstract class Client, R extends ClientR return conn.writeChannel(request, respTransfer); } + public final CompletableFuture> sendAsync(R[] requests) { + for (R request : requests) { + if (request.workThread == null) { + request.workThread = WorkThread.currentWorkThread(); + } + } + return connect().thenCompose(conn -> writeChannel(conn, requests)); + } + + public final CompletableFuture> sendAsync(R[] requests, Function respTransfer) { + for (R request : requests) { + if (request.workThread == null) { + request.workThread = WorkThread.currentWorkThread(); + } + } + return connect().thenCompose(conn -> writeChannel(conn, requests, respTransfer)); + } + + public final CompletableFuture> sendAsync(SocketAddress addr, R[] requests) { + for (R request : requests) { + if (request.workThread == null) { + request.workThread = WorkThread.currentWorkThread(); + } + } + return connect(addr).thenCompose(conn -> writeChannel(conn, requests)); + } + + public final CompletableFuture> sendAsync(SocketAddress addr, R[] requests, Function respTransfer) { + for (R request : requests) { + if (request.workThread == null) { + request.workThread = WorkThread.currentWorkThread(); + } + } + return connect(addr).thenCompose(conn -> writeChannel(conn, requests, respTransfer)); + } + + protected CompletableFuture> writeChannel(ClientConnection conn, R[] requests) { + return conn.writeChannel(requests); + } + + protected CompletableFuture> writeChannel(ClientConnection conn, R[] requests, Function respTransfer) { + return conn.writeChannel(requests, respTransfer); + } + private C createConnection(int index, AsyncConnection channel) { C conn = createClientConnection(index, channel); if (!channel.isReadPending()) { diff --git a/src/main/java/org/redkale/net/client/ClientConnection.java b/src/main/java/org/redkale/net/client/ClientConnection.java index d679b8c37..746783ce8 100644 --- a/src/main/java/org/redkale/net/client/ClientConnection.java +++ b/src/main/java/org/redkale/net/client/ClientConnection.java @@ -9,14 +9,14 @@ import java.io.Serializable; import java.net.SocketAddress; import java.nio.ByteBuffer; import java.nio.channels.*; -import java.util.Iterator; +import java.util.*; import java.util.concurrent.*; import java.util.concurrent.atomic.*; import java.util.concurrent.locks.ReentrantLock; import java.util.function.*; import org.redkale.annotation.*; import org.redkale.net.*; -import org.redkale.util.ByteArray; +import org.redkale.util.*; /** * 注意: 要确保AsyncConnection的读写过程都必须在channel.ioThread中运行 @@ -59,7 +59,7 @@ public abstract class ClientConnection implements Co @Override public void completed(Integer result, ClientConnection attachment) { - + } @Override @@ -112,6 +112,10 @@ public abstract class ClientConnection implements Co return writeChannel(request, null); } + protected final CompletableFuture> writeChannel(R[] requests) { + return writeChannel(requests, null); + } + //respTransfer只会在ClientCodec的读线程里调用 protected final CompletableFuture writeChannel(R request, Function respTransfer) { request.respTransfer = respTransfer; @@ -135,6 +139,36 @@ public abstract class ClientConnection implements Co return respFuture; } + //respTransfer只会在ClientCodec的读线程里调用 + protected final CompletableFuture> writeChannel(R[] requests, Function respTransfer) { + ClientFuture[] respFutures = new ClientFuture[requests.length]; + int rts = this.channel.getReadTimeoutSeconds(); + for (int i = 0; i < respFutures.length; i++) { + R request = requests[i]; + request.respTransfer = respTransfer; + ClientFuture respFuture = createClientFuture(requests[i]); + respFutures[i] = respFuture; + if (rts > 0 && !request.isCloseType()) { + respFuture.setTimeout(client.timeoutScheduler.schedule(respFuture, rts, TimeUnit.SECONDS)); + } + } + respWaitingCounter.add(respFutures.length);//放在writeChannelInWriteThread计数会延迟,导致不准确 + + writeLock.lock(); + try { + for (ClientFuture respFuture : respFutures) { + offerRespFuture(respFuture); + if (pauseWriting.get()) { + pauseRequests.add(respFuture); + } + } + sendRequestInLocking(respFutures); + } finally { + writeLock.unlock(); + } + return Utility.allOfFutures(respFutures); + } + private void sendRequestInLocking(R request, ClientFuture respFuture) { if (true) { //新方式 ByteArray array = arrayThreadLocal.get(); @@ -170,6 +204,26 @@ public abstract class ClientConnection implements Co } } + private void sendRequestInLocking(ClientFuture[] respFutures) { + ByteArray array = arrayThreadLocal.get(); + array.clear(); + for (ClientFuture respFuture : respFutures) { + if (pauseWriting.get()) { + pauseRequests.add(respFuture); + } else { + ClientRequest request = respFuture.request; + request.writeTo(this, array); + if (request.isCompleted()) { + doneRequestCounter.increment(); + } else { //还剩半包没发送完 + pauseWriting.set(true); + currHalfWriteFuture = respFuture; + } + } + } + channel.fastWrite(array.getBytes(), writeHandler); + } + //发送半包和积压的请求数据包 void sendHalfWriteInReadThread(R request, Throwable halfRequestExc) { writeLock.lock(); @@ -198,7 +252,12 @@ public abstract class ClientConnection implements Co return CompletableFuture.failedFuture(new RuntimeException("ClientVirtualRequest must be virtualType = true")); } ClientFuture respFuture = createClientFuture(request); - offerRespFuture(respFuture); + writeLock.lock(); + try { + offerRespFuture(respFuture); + } finally { + writeLock.unlock(); + } return respFuture; } @@ -241,7 +300,7 @@ public abstract class ClientConnection implements Co } } - //只会在WriteIOThread中调用 + //只会在WriteIOThread中调用, 必须在writeLock内执行 void offerFirstRespFuture(ClientFuture respFuture) { Serializable requestid = respFuture.request.getRequestid(); if (requestid == null) { @@ -251,7 +310,7 @@ public abstract class ClientConnection implements Co } } - //只会在WriteIOThread中调用 + //必须在writeLock内执行 void offerRespFuture(ClientFuture respFuture) { Serializable requestid = respFuture.request.getRequestid(); if (requestid == null) {