Client增加批量请求方法

This commit is contained in:
redkale
2023-07-08 09:01:06 +08:00
parent 2163ce3c4c
commit 69797fd29c
2 changed files with 110 additions and 7 deletions

View File

@@ -6,7 +6,7 @@
package org.redkale.net.client;
import java.net.SocketAddress;
import java.util.Queue;
import java.util.*;
import java.util.concurrent.*;
import java.util.concurrent.atomic.*;
import java.util.function.*;
@@ -254,6 +254,50 @@ public abstract class Client<C extends ClientConnection<R, P>, R extends ClientR
return conn.writeChannel(request, respTransfer);
}
public final CompletableFuture<List<P>> sendAsync(R[] requests) {
for (R request : requests) {
if (request.workThread == null) {
request.workThread = WorkThread.currentWorkThread();
}
}
return connect().thenCompose(conn -> writeChannel(conn, requests));
}
public final <T> CompletableFuture<List<T>> sendAsync(R[] requests, Function<P, T> respTransfer) {
for (R request : requests) {
if (request.workThread == null) {
request.workThread = WorkThread.currentWorkThread();
}
}
return connect().thenCompose(conn -> writeChannel(conn, requests, respTransfer));
}
public final CompletableFuture<List<P>> sendAsync(SocketAddress addr, R[] requests) {
for (R request : requests) {
if (request.workThread == null) {
request.workThread = WorkThread.currentWorkThread();
}
}
return connect(addr).thenCompose(conn -> writeChannel(conn, requests));
}
public final <T> CompletableFuture<List<T>> sendAsync(SocketAddress addr, R[] requests, Function<P, T> respTransfer) {
for (R request : requests) {
if (request.workThread == null) {
request.workThread = WorkThread.currentWorkThread();
}
}
return connect(addr).thenCompose(conn -> writeChannel(conn, requests, respTransfer));
}
protected CompletableFuture<List<P>> writeChannel(ClientConnection conn, R[] requests) {
return conn.writeChannel(requests);
}
protected <T> CompletableFuture<List<T>> writeChannel(ClientConnection conn, R[] requests, Function<P, T> respTransfer) {
return conn.writeChannel(requests, respTransfer);
}
private C createConnection(int index, AsyncConnection channel) {
C conn = createClientConnection(index, channel);
if (!channel.isReadPending()) {

View File

@@ -9,14 +9,14 @@ import java.io.Serializable;
import java.net.SocketAddress;
import java.nio.ByteBuffer;
import java.nio.channels.*;
import java.util.Iterator;
import java.util.*;
import java.util.concurrent.*;
import java.util.concurrent.atomic.*;
import java.util.concurrent.locks.ReentrantLock;
import java.util.function.*;
import org.redkale.annotation.*;
import org.redkale.net.*;
import org.redkale.util.ByteArray;
import org.redkale.util.*;
/**
* 注意: 要确保AsyncConnection的读写过程都必须在channel.ioThread中运行
@@ -59,7 +59,7 @@ public abstract class ClientConnection<R extends ClientRequest, P> implements Co
@Override
public void completed(Integer result, ClientConnection attachment) {
}
@Override
@@ -112,6 +112,10 @@ public abstract class ClientConnection<R extends ClientRequest, P> implements Co
return writeChannel(request, null);
}
protected final CompletableFuture<List<P>> writeChannel(R[] requests) {
return writeChannel(requests, null);
}
//respTransfer只会在ClientCodec的读线程里调用
protected final <T> CompletableFuture<T> writeChannel(R request, Function<P, T> respTransfer) {
request.respTransfer = respTransfer;
@@ -135,6 +139,36 @@ public abstract class ClientConnection<R extends ClientRequest, P> implements Co
return respFuture;
}
//respTransfer只会在ClientCodec的读线程里调用
protected final <T> CompletableFuture<List<T>> writeChannel(R[] requests, Function<P, T> respTransfer) {
ClientFuture[] respFutures = new ClientFuture[requests.length];
int rts = this.channel.getReadTimeoutSeconds();
for (int i = 0; i < respFutures.length; i++) {
R request = requests[i];
request.respTransfer = respTransfer;
ClientFuture respFuture = createClientFuture(requests[i]);
respFutures[i] = respFuture;
if (rts > 0 && !request.isCloseType()) {
respFuture.setTimeout(client.timeoutScheduler.schedule(respFuture, rts, TimeUnit.SECONDS));
}
}
respWaitingCounter.add(respFutures.length);//放在writeChannelInWriteThread计数会延迟导致不准确
writeLock.lock();
try {
for (ClientFuture respFuture : respFutures) {
offerRespFuture(respFuture);
if (pauseWriting.get()) {
pauseRequests.add(respFuture);
}
}
sendRequestInLocking(respFutures);
} finally {
writeLock.unlock();
}
return Utility.allOfFutures(respFutures);
}
private void sendRequestInLocking(R request, ClientFuture respFuture) {
if (true) { //新方式
ByteArray array = arrayThreadLocal.get();
@@ -170,6 +204,26 @@ public abstract class ClientConnection<R extends ClientRequest, P> implements Co
}
}
private void sendRequestInLocking(ClientFuture[] respFutures) {
ByteArray array = arrayThreadLocal.get();
array.clear();
for (ClientFuture respFuture : respFutures) {
if (pauseWriting.get()) {
pauseRequests.add(respFuture);
} else {
ClientRequest request = respFuture.request;
request.writeTo(this, array);
if (request.isCompleted()) {
doneRequestCounter.increment();
} else { //还剩半包没发送完
pauseWriting.set(true);
currHalfWriteFuture = respFuture;
}
}
}
channel.fastWrite(array.getBytes(), writeHandler);
}
//发送半包和积压的请求数据包
void sendHalfWriteInReadThread(R request, Throwable halfRequestExc) {
writeLock.lock();
@@ -198,7 +252,12 @@ public abstract class ClientConnection<R extends ClientRequest, P> implements Co
return CompletableFuture.failedFuture(new RuntimeException("ClientVirtualRequest must be virtualType = true"));
}
ClientFuture<R, P> respFuture = createClientFuture(request);
offerRespFuture(respFuture);
writeLock.lock();
try {
offerRespFuture(respFuture);
} finally {
writeLock.unlock();
}
return respFuture;
}
@@ -241,7 +300,7 @@ public abstract class ClientConnection<R extends ClientRequest, P> implements Co
}
}
//只会在WriteIOThread中调用
//只会在WriteIOThread中调用, 必须在writeLock内执行
void offerFirstRespFuture(ClientFuture<R, P> respFuture) {
Serializable requestid = respFuture.request.getRequestid();
if (requestid == null) {
@@ -251,7 +310,7 @@ public abstract class ClientConnection<R extends ClientRequest, P> implements Co
}
}
//只会在WriteIOThread中调用
//必须在writeLock内执行
void offerRespFuture(ClientFuture<R, P> respFuture) {
Serializable requestid = respFuture.request.getRequestid();
if (requestid == null) {