重构Client模块的IO读写策略

This commit is contained in:
Redkale
2023-01-12 20:26:39 +08:00
parent 9b66bd186f
commit 43ff13867f
8 changed files with 152 additions and 356 deletions

View File

@@ -86,14 +86,10 @@ public class AsyncIOGroup extends AsyncGroup {
ObjectPool<ByteBuffer> unsafeReadBufferPool = ObjectPool.createUnsafePool(safeBufferPool, safeBufferPool.getCreatCounter(),
safeBufferPool.getCycleCounter(), 512, safeBufferPool.getCreator(), safeBufferPool.getPrepare(), safeBufferPool.getRecycler());
if (client) {
this.ioReadThreads[i] = new ClientIOThread(String.format(threadNameFormat, indexfix), i, threads, workExecutor, Selector.open(), unsafeReadBufferPool, safeBufferPool);
this.ioWriteThreads[i] = this.ioReadThreads[i];
if (System.currentTimeMillis() < 1) { //暂时不使用
this.ioReadThreads[i].setName(String.format(threadNameFormat, "Read-" + indexfix));
ObjectPool<ByteBuffer> unsafeWriteBufferPool = ObjectPool.createUnsafePool(safeBufferPool, safeBufferPool.getCreatCounter(),
safeBufferPool.getCycleCounter(), 512, safeBufferPool.getCreator(), safeBufferPool.getPrepare(), safeBufferPool.getRecycler());
this.ioWriteThreads[i] = new ClientWriteIOThread(String.format(threadNameFormat, "Write-" + indexfix), i, threads, workExecutor, Selector.open(), unsafeWriteBufferPool, safeBufferPool);
}
this.ioReadThreads[i] = new ClientReadIOThread(String.format(threadNameFormat, "Read-" + indexfix), i, threads, workExecutor, Selector.open(), unsafeReadBufferPool, safeBufferPool);
ObjectPool<ByteBuffer> unsafeWriteBufferPool = ObjectPool.createUnsafePool(safeBufferPool, safeBufferPool.getCreatCounter(),
safeBufferPool.getCycleCounter(), 512, safeBufferPool.getCreator(), safeBufferPool.getPrepare(), safeBufferPool.getRecycler());
this.ioWriteThreads[i] = new ClientWriteIOThread(String.format(threadNameFormat, "Write-" + indexfix), i, threads, workExecutor, Selector.open(), unsafeWriteBufferPool, safeBufferPool);
} else {
this.ioReadThreads[i] = new AsyncIOThread(String.format(threadNameFormat, indexfix), i, threads, workExecutor, Selector.open(), unsafeReadBufferPool, safeBufferPool);
this.ioWriteThreads[i] = this.ioReadThreads[i];
@@ -102,7 +98,7 @@ public class AsyncIOGroup extends AsyncGroup {
if (client) {
ObjectPool<ByteBuffer> unsafeBufferPool = ObjectPool.createUnsafePool(safeBufferPool, safeBufferPool.getCreatCounter(),
safeBufferPool.getCycleCounter(), 512, safeBufferPool.getCreator(), safeBufferPool.getPrepare(), safeBufferPool.getRecycler());
this.connectThread = client ? new ClientIOThread(String.format(threadNameFormat, "Connect"), 0, 0, workExecutor, Selector.open(), unsafeBufferPool, safeBufferPool)
this.connectThread = client ? new ClientReadIOThread(String.format(threadNameFormat, "Connect"), 0, 0, workExecutor, Selector.open(), unsafeBufferPool, safeBufferPool)
: new AsyncIOThread(String.format(threadNameFormat, "Connect"), 0, 0, workExecutor, Selector.open(), unsafeBufferPool, safeBufferPool);
}
} catch (IOException e) {

View File

@@ -418,20 +418,10 @@ abstract class AsyncNioConnection extends AsyncConnection {
this.connectPending = false;//必须放最后
if (handler != null) {
if (!client || inCurrWriteThread()) { //client模式下必须保证read、write在ioThread内运行
if (t == null) {
handler.completed(null, attach);
} else {
handler.failed(t, attach);
}
if (t == null) {
handler.completed(null, attach);
} else {
ioWriteThread.execute(() -> {
if (t == null) {
handler.completed(null, attach);
} else {
handler.failed(t, attach);
}
});
handler.failed(t, attach);
}
}
}

View File

@@ -165,6 +165,11 @@ public abstract class Client<R extends ClientRequest, P> implements Resourcable
protected abstract ClientConnection createClientConnection(final int index, AsyncConnection channel);
//创建连接后先从服务器拉取数据构建的虚拟请求返回null表示连上服务器后不读取数据
protected R createVirtualRequestAfterConnect() {
return null;
}
protected int pingIntervalSeconds() {
return 30;
}
@@ -247,6 +252,12 @@ public abstract class Client<R extends ClientRequest, P> implements Resourcable
if (this.connOpenStates[index].compareAndSet(false, true)) {
CompletableFuture<ClientConnection> future = address.createClient(tcp, group, readTimeoutSeconds, writeTimeoutSeconds)
.thenApply(c -> createClientConnection(index, c).setMaxPipelines(maxPipelines));
R virtualReq = createVirtualRequestAfterConnect();
if (virtualReq != null) {
future = future.thenCompose(conn -> conn.writeVirtualRequest(virtualReq).thenApply(v -> conn));
} else {
future = future.thenApply(conn -> conn.readChannel());
}
return (authenticate == null ? future : authenticate.apply(future)).thenApply(c -> {
c.setAuthenticated(true);
this.connArray[index] = c;

View File

@@ -102,20 +102,11 @@ public abstract class ClientCodec<R extends ClientRequest, P> implements Complet
ClientRequest request = respFuture.request;
if (!request.isCompleted()) {
if (rs.exc == null) {
connection.sendHalfWrite(rs.exc);
//request没有发送完respFuture需要再次接收
Serializable reqid = request.getRequestid();
if (reqid == null) {
connection.responseQueue.offerFirst(respFuture);
} else {
connection.responseMap.put(reqid, respFuture);
}
connection.pauseWriting.set(false);
connection.wakeupWrite();
return;
} else { //异常了需要清掉半包
connection.lastHalfEntry = null;
connection.pauseWriting.set(false);
connection.wakeupWrite();
connection.sendHalfWrite(rs.exc);
}
}
connection.respWaitingCounter.decrement();

View File

@@ -6,16 +6,12 @@
package org.redkale.net.client;
import java.io.Serializable;
import java.nio.ByteBuffer;
import java.nio.channels.*;
import java.util.AbstractMap.SimpleEntry;
import java.nio.channels.ClosedChannelException;
import java.util.*;
import java.util.concurrent.*;
import java.util.concurrent.atomic.*;
import java.util.function.Consumer;
import java.util.logging.Level;
import org.redkale.net.*;
import org.redkale.util.*;
/**
* 注意: 要确保AsyncConnection的读写过程都必须在channel.ioThread中运行
@@ -35,23 +31,19 @@ public abstract class ClientConnection<R extends ClientRequest, P> implements Co
protected final Client<R, P> client;
protected final ClientCodec<R, P> codec;
protected final LongAdder respWaitingCounter;
protected final AsyncConnection channel;
protected final ByteArray writeArray = new ByteArray();
protected final ByteArray readArray = new ByteArray();
protected final AtomicBoolean pauseWriting = new AtomicBoolean();
protected final AtomicBoolean readPending = new AtomicBoolean();
protected final AtomicBoolean pauseResuming = new AtomicBoolean();
protected final AtomicBoolean writePending = new AtomicBoolean();
protected final List<ClientWriteIOThread.ClientEntity> pauseRequests = new CopyOnWriteArrayList<ClientWriteIOThread.ClientEntity>();
protected final Queue<SimpleEntry<R, ClientFuture<R>>> requestQueue = new ArrayDeque<>();
protected final AsyncConnection channel;
private final ClientCodec<R, P> codec;
private final ClientWriteIOThread writeThread;
//responseQueue、responseMap二选一
final Deque<ClientFuture> responseQueue = new LinkedBlockingDeque<>();
@@ -59,169 +51,10 @@ public abstract class ClientConnection<R extends ClientRequest, P> implements Co
//responseQueue、responseMap二选一, key: requestid
final Map<Serializable, ClientFuture> responseMap = new ConcurrentHashMap<>();
SimpleEntry<R, ClientFuture<R>> lastHalfEntry;
private int maxPipelines; //最大并行处理数
private boolean closed;
private boolean authenticated;
protected final CompletionHandler<Integer, ByteBuffer> readHandler = new CompletionHandler<Integer, ByteBuffer>() {
@Override
public void completed(Integer count, ByteBuffer attachment) {
if (count < 1) {
channel.setReadBuffer(attachment);
dispose(new NonReadableChannelException());
return;
}
try {
attachment.flip();
decodeResponse(attachment);
} catch (Throwable e) {
channel.setReadBuffer(attachment);
dispose(e);
}
}
private void decodeResponse(ByteBuffer buffer) {
if (codec.decodeMessages(ClientConnection.this, buffer, readArray)) { //成功了
readArray.clear();
List<ClientResponse<P>> results = codec.pollMessages();
if (results != null) {
for (ClientResponse<P> rs : results) {
Serializable reqid = rs.getRequestid();
ClientFuture respFuture = reqid == null ? responseQueue.poll() : responseMap.remove(reqid);
if (respFuture != null) {
int mergeCount = respFuture.getMergeCount();
completeResponse(rs, respFuture);
if (mergeCount > 0) {
for (int i = 0; i < mergeCount; i++) {
respFuture = reqid == null ? responseQueue.poll() : responseMap.remove(reqid);
if (respFuture != null) {
completeResponse(rs, respFuture);
}
}
}
}
}
}
if (buffer.hasRemaining()) {
decodeResponse(buffer);
} else if (isWaitingResponseEmpty()) { //队列都已处理完了
buffer.clear();
channel.setReadBuffer(buffer);
if (readPending.compareAndSet(true, false)) {
//无消息处理
} else {
channel.read(this);
}
} else { //还有消息需要读取
if ((!requestQueue.isEmpty() || lastHalfEntry != null) && writePending.compareAndSet(false, true)) {
//先写后读取
if (sendWrite(true) <= 0) {
writePending.compareAndSet(true, false);
}
}
buffer.clear();
channel.setReadBuffer(buffer);
channel.read(this);
}
} else { //数据不全, 继续读
buffer.clear();
channel.setReadBuffer(buffer);
channel.read(this);
}
}
private void completeResponse(ClientResponse<P> rs, ClientFuture respFuture) {
if (respFuture != null) {
if (!respFuture.request.isCompleted()) {
if (rs.exc == null) {
Serializable reqid = respFuture.request.getRequestid();
if (reqid == null) {
responseQueue.offerFirst(respFuture);
} else {
responseMap.put(reqid, respFuture);
}
pauseWriting.set(false);
return;
} else { //异常了需要清掉半包
lastHalfEntry = null;
pauseWriting.set(false);
}
}
respWaitingCounter.decrement();
if (isAuthenticated() && client.respDoneCounter != null) {
client.respDoneCounter.increment();
}
try {
respFuture.cancelTimeout();
ClientRequest request = respFuture.request;
//if (client.finest) client.logger.log(Level.FINEST, Utility.nowMillis() + ": " + Thread.currentThread().getName() + ": " + ClientConnection.this + ", 回调处理, req=" + request + ", message=" + rs.message);
preComplete(rs.message, (R) request, rs.exc);
WorkThread workThread = null;
if (request != null) {
workThread = request.workThread;
request.workThread = null;
}
if (workThread == null || workThread.getWorkExecutor() == null) {
workThread = channel.getReadIOThread();
}
if (rs.exc != null) {
workThread.runWork(() -> {
if (request != null) {
Traces.currTraceid(request.traceid);
}
respFuture.completeExceptionally(rs.exc);
});
} else {
workThread.runWork(() -> {
if (request != null) {
Traces.currTraceid(request.traceid);
}
respFuture.complete(rs.message);
});
}
} catch (Throwable t) {
client.logger.log(Level.INFO, "Complete result error, request: " + respFuture.request, t);
}
}
}
@Override
public void failed(Throwable t, ByteBuffer attachment) {
dispose(t);
}
};
protected final CompletionHandler<Integer, Void> writeHandler = new CompletionHandler<Integer, Void>() {
@Override
public void completed(Integer result, Void attachment) {
// if (writeLastRequest != null && writeLastRequest.isCloseType()) {
// if (closeFuture != null) {
// channel.getWriteIOThread().runWork(() -> {
// closeFuture.complete(null);
// });
// }
// closeFuture = null;
// return;
// }
if (sendWrite(false) <= 0) {
writePending.compareAndSet(true, false);
readChannel();
}
}
@Override
public void failed(Throwable exc, Void attachment) {
dispose(exc);
}
};
@SuppressWarnings({"LeakingThisInConstructor", "OverridableMethodCallInConstructor"})
public ClientConnection(Client client, int index, AsyncConnection channel) {
this.client = client;
@@ -229,6 +62,7 @@ public abstract class ClientConnection<R extends ClientRequest, P> implements Co
this.index = index;
this.respWaitingCounter = client.connRespWaitings[index];
this.channel = channel.beforeCloseListener(this);
this.writeThread = (ClientWriteIOThread) channel.getWriteIOThread();
}
protected abstract ClientCodec createCodec();
@@ -241,108 +75,18 @@ public abstract class ClientConnection<R extends ClientRequest, P> implements Co
respFuture.setTimeout(client.timeoutScheduler.schedule(respFuture, rts, TimeUnit.SECONDS));
}
respWaitingCounter.increment(); //放在writeChannelUnsafe计数会延迟导致不准确
if (channel.inCurrWriteThread()) {
writeChannelUnsafe(request, respFuture);
} else {
channel.executeWrite(() -> writeChannelUnsafe(request, respFuture));
}
writeThread.offerRequest(this, request, respFuture);
return respFuture;
}
private void writeChannelUnsafe(R request, ClientFuture<R> respFuture) {
if (closed) {
WorkThread workThread = request.workThread;
if (workThread == null || workThread.getWorkExecutor() == null) {
workThread = channel.getReadIOThread();
}
Throwable e = new ClosedChannelException();
workThread.runWork(() -> {
Traces.currTraceid(request.traceid);
respFuture.completeExceptionally(e);
});
return;
CompletableFuture writeVirtualRequest(ClientRequest request) {
if (!request.isVirtualType()) {
return CompletableFuture.failedFuture(new RuntimeException("ClientVirtualRequest must be virtualType = true"));
}
Serializable reqid = request.getRequestid();
//保证顺序一致
if (reqid == null) {
responseQueue.offer(respFuture);
} else {
responseMap.put(reqid, respFuture);
}
requestQueue.offer(new SimpleEntry<>(request, respFuture));
if (isAuthenticated()) {
client.incrReqWritedCounter();
}
if (writePending.compareAndSet(false, true)) {
sendWrite(true);
}
}
//返回写入数据request的数量返回0表示没有可写的request
int sendWrite(boolean must) {
ClientConnection conn = this;
ByteArray rw = conn.writeArray;
rw.clear();
int pipelines = maxPipelines > 1 ? (maxPipelines - responseQueue.size() - responseMap.size()) : 1;
if (must && pipelines < 1) {
pipelines = 1;
}
int c = 0;
AtomicBoolean pw = conn.pauseWriting;
for (int i = 0; i < pipelines; i++) {
if (pw.get()) {
break;
}
SimpleEntry<R, ClientFuture<R>> entry;
if (lastHalfEntry == null) {
entry = requestQueue.poll();
} else {
entry = lastHalfEntry;
lastHalfEntry = null;
}
if (entry == null) {
break;
}
R req = entry.getKey();
if (req.getRequestid() == null && req.canMerge(conn)) {
SimpleEntry<R, ClientFuture<R>> r;
while ((r = requestQueue.poll()) != null) {
i++;
if (!req.merge(conn, r.getKey())) {
break;
}
ClientFuture<R> f = entry.getValue();
if (f != null) {
f.incrMergeCount();
}
}
req.accept(conn, rw);
if (r != null) {
req = r.getKey();
req.accept(conn, rw);
}
} else {
req.accept(conn, rw);
}
c++;
if (req.isCloseType()) {
closed = true;
this.pauseWriting.set(true);
break;
} else if (!req.isCompleted()) {
lastHalfEntry = entry;
this.pauseWriting.set(true);
break;
}
}
if (c > 0) { //当Client连接Server后先从Server读取数据时,会先发送一个EMPTY的request这样writeArray.count就会为0
channel.write(rw, writeHandler);
return c;
}
if (pw.get()) {
writePending.compareAndSet(true, false);
}
return 0;
ClientFuture respFuture = new ClientFuture(request);
responseQueue.offer(respFuture);
readChannel();
return respFuture;
}
protected void preComplete(P resp, R req, Throwable exc) {
@@ -352,18 +96,16 @@ public abstract class ClientConnection<R extends ClientRequest, P> implements Co
return new ClientFuture(request);
}
protected void readChannel() {
if (readPending.compareAndSet(false, true)) {
readArray.clear();
channel.read(readHandler);
}
protected ClientConnection readChannel() {
channel.readInIOThread(codec);
return this;
}
@Override //AsyncConnection.beforeCloseListener
public void accept(AsyncConnection t) {
respWaitingCounter.reset();
client.connOpenStates[index].set(false);
client.connArray[index] = null; //必须connflags之后
client.connArray[index] = null; //必须connOpenStates之后
}
public void dispose(Throwable exc) {
@@ -386,11 +128,8 @@ public abstract class ClientConnection<R extends ClientRequest, P> implements Co
}
}
public void wakeupWrite() {
AsyncIOThread thread = channel.getWriteIOThread();
if (thread instanceof ClientWriteIOThread) {
((ClientWriteIOThread) thread).wakeupWrite();
}
void sendHalfWrite(Throwable halfRequestExc) {
writeThread.sendHalfWrite(this, halfRequestExc);
}
public boolean isAuthenticated() {
@@ -424,14 +163,6 @@ public abstract class ClientConnection<R extends ClientRequest, P> implements Co
return this;
}
protected boolean isWaitingResponseEmpty() {
return responseQueue.isEmpty() && responseMap.isEmpty();
}
protected void resumeWrite() {
this.pauseWriting.set(false);
}
public int runningCount() {
return respWaitingCounter.intValue();
}

View File

@@ -19,9 +19,9 @@ import org.redkale.util.ObjectPool;
*
* @since 2.8.0
*/
public class ClientIOThread extends AsyncIOThread {
public class ClientReadIOThread extends AsyncIOThread {
public ClientIOThread(String name, int index, int threads, ExecutorService workExecutor, Selector selector,
public ClientReadIOThread(String name, int index, int threads, ExecutorService workExecutor, Selector selector,
ObjectPool<ByteBuffer> unsafeBufferPool, ObjectPool<ByteBuffer> safeBufferPool) {
super(name, index, threads, workExecutor, selector, unsafeBufferPool, safeBufferPool);
}

View File

@@ -36,6 +36,11 @@ public abstract class ClientRequest implements BiConsumer<ClientConnection, Byte
return false;
}
//连接上先从服务器拉取数据构建的虚拟请求一定要返回true
public boolean isVirtualType() {
return false;
}
public long getCreateTime() {
return createTime;
}

View File

@@ -6,17 +6,19 @@ package org.redkale.net.client;
import java.io.Serializable;
import java.nio.ByteBuffer;
import java.nio.channels.*;
import java.util.*;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicBoolean;
import org.redkale.net.AsyncIOThread;
import org.redkale.util.*;
/**
*
* @author zhangjx
*/
public class ClientWriteIOThread extends ClientIOThread {
public class ClientWriteIOThread extends AsyncIOThread {
private final BlockingQueue<ClientEntity> requestQueue = new LinkedBlockingQueue<>();
private final BlockingDeque<ClientEntity> requestQueue = new LinkedBlockingDeque<>();
public ClientWriteIOThread(String name, int index, int threads, ExecutorService workExecutor, Selector selector,
ObjectPool<ByteBuffer> unsafeBufferPool, ObjectPool<ByteBuffer> safeBufferPool) {
@@ -27,9 +29,26 @@ public class ClientWriteIOThread extends ClientIOThread {
requestQueue.offer(new ClientEntity(conn, request, respFuture));
}
public void wakeupWrite() {
synchronized (writeHandler) {
writeHandler.notify();
public void sendHalfWrite(ClientConnection conn, Throwable halfRequestExc) {
if (conn.pauseWriting.get()) {
conn.pauseResuming.set(true);
try {
AtomicBoolean skipFirst = new AtomicBoolean(halfRequestExc != null);
conn.pauseRequests.removeIf(e -> {
if (e != null) {
if (!skipFirst.compareAndSet(true, false)) {
requestQueue.offer((ClientEntity) e);
}
}
return true;
});
} finally {
conn.pauseResuming.set(false);
conn.pauseWriting.set(false);
synchronized (conn.pauseRequests) {
conn.pauseRequests.notify();
}
}
}
}
@@ -37,36 +56,85 @@ public class ClientWriteIOThread extends ClientIOThread {
public void run() {
final ByteBuffer buffer = getBufferSupplier().get();
final int capacity = buffer.capacity();
final ByteArray writeArray = new ByteArray(1024 * 32);
final Map<ClientConnection, List<ClientEntity>> map = new HashMap<>();
final ObjectPool<List> listPool = ObjectPool.createUnsafePool(Utility.cpus() * 2, () -> new ArrayList(), null, t -> {
t.clear();
return true;
});
while (!isClosed()) {
ClientEntity entity;
try {
while ((entity = requestQueue.take()) != null) {
ClientConnection conn = entity.conn;
ClientRequest request = entity.request;
ClientFuture respFuture = entity.respFuture;
AtomicBoolean pw = conn.pauseWriting;
Serializable reqid = request.getRequestid();
if (reqid == null) {
conn.responseQueue.offer(respFuture);
} else {
conn.responseMap.put(reqid, respFuture);
}
ByteArray rw = conn.writeArray;
rw.clear();
request.accept(conn, rw);
if (rw.length() <= capacity) {
buffer.clear();
buffer.put(rw.content(), 0, rw.length());
buffer.flip();
conn.channel.write(buffer, conn, writeHandler);
} else {
conn.channel.write(rw, conn, writeHandler);
}
if (pw.get()) {
synchronized (writeHandler) {
writeHandler.wait(30_000);
map.clear();
{
Serializable reqid = entity.request.getRequestid();
if (reqid == null) {
entity.conn.responseQueue.offer(entity.respFuture);
} else {
entity.conn.responseMap.put(reqid, entity.respFuture);
}
}
if (entity.conn.pauseWriting.get()) {
if (entity.conn.pauseResuming.get()) {
try {
synchronized (entity.conn.pauseRequests) {
entity.conn.pauseRequests.wait(3_000);
}
} catch (InterruptedException ie) {
}
}
entity.conn.pauseRequests.add(entity);
} else {
map.computeIfAbsent(entity.conn, c -> listPool.get()).add(entity);
}
while ((entity = requestQueue.poll()) != null) {
Serializable reqid = entity.request.getRequestid();
if (reqid == null) {
entity.conn.responseQueue.offer(entity.respFuture);
} else {
entity.conn.responseMap.put(reqid, entity.respFuture);
}
if (entity.conn.pauseWriting.get()) {
if (entity.conn.pauseResuming.get()) {
try {
synchronized (entity.conn.pauseRequests) {
entity.conn.pauseRequests.wait(3_000);
}
} catch (InterruptedException ie) {
}
}
entity.conn.pauseRequests.add(entity);
} else {
map.computeIfAbsent(entity.conn, c -> listPool.get()).add(entity);
}
}
map.forEach((conn, list) -> {
writeArray.clear();
int i = -1;
for (ClientEntity en : list) {
++i;
ClientRequest request = en.request;
request.accept(conn, writeArray);
if (!request.isCompleted()) {
conn.pauseWriting.set(true);
conn.pauseRequests.addAll(list.subList(i, list.size()));
break;
}
}
listPool.accept(list);
//channel.write
if (writeArray.length() > 0) {
if (writeArray.length() <= capacity) {
buffer.clear();
buffer.put(writeArray.content(), 0, writeArray.length());
buffer.flip();
conn.channel.write(buffer, conn, writeHandler);
} else {
conn.channel.write(writeArray, conn, writeHandler);
}
}
});
}
} catch (InterruptedException e) {
}
@@ -99,5 +167,9 @@ public class ClientWriteIOThread extends ClientIOThread {
this.respFuture = respFuture;
}
@Override
public String toString() {
return getClass().getSimpleName() + "_" + Objects.hash(this) + "{conn = " + conn + ", request = " + request + "}";
}
}
}