isThreadLocalConnMode

This commit is contained in:
redkale
2023-03-30 19:54:19 +08:00
parent 41e1ffa6e2
commit c9be4a89af
5 changed files with 63 additions and 194 deletions

View File

@@ -12,7 +12,6 @@ import java.util.Objects;
import java.util.concurrent.*;
import java.util.concurrent.atomic.*;
import org.redkale.annotation.ResourceType;
import org.redkale.net.client.*;
import org.redkale.util.*;
/**
@@ -100,14 +99,6 @@ public class AsyncIOGroup extends AsyncGroup {
return new AsyncIOThread(g, name, index, threads, workExecutor, safeBufferPool);
}
protected AsyncIOThread createClientReadIOThread(ThreadGroup g, String name, int index, int threads, ExecutorService workExecutor, ByteBufferPool safeBufferPool) throws IOException {
return new ClientReadIOThread(g, name, index, threads, workExecutor, safeBufferPool);
}
protected AsyncIOThread createClientWriteIOThread(ThreadGroup g, String name, int index, int threads, ExecutorService workExecutor, ByteBufferPool safeBufferPool) throws IOException {
return new ClientWriteIOThread(g, name, index, threads, workExecutor, safeBufferPool);
}
AsyncIOThread connectThread() {
if (connectThreadInited.compareAndSet(false, true)) {
this.connectThread.start();

View File

@@ -81,6 +81,9 @@ public abstract class Client<C extends ClientConnection<R, P>, R extends ClientR
protected int readTimeoutSeconds;
protected int writeTimeoutSeconds;
//------------------ LocalThreadMode模式 ------------------
final ThreadLocal<C> localConnection = new ThreadLocal();
//------------------ 可选项 ------------------
//PING心跳的请求数据为null且pingInterval<1表示不需要定时ping
@@ -263,6 +266,14 @@ public abstract class Client<C extends ClientConnection<R, P>, R extends ClientR
return conn.writeChannel(request, respTransfer);
}
//是否采用ThreadLocal连接池模式
//支持ThreadLocal连接池模式的最基本要求:
// 1) 只能调用connect()获取连接不能调用connect(SocketAddress addr)
// 2) request必须一次性输出不能出现写入request后request.isCompleted()=false的情况
protected boolean isThreadLocalConnMode() {
return false;
}
private C createConnection(int index, AsyncConnection channel) {
C conn = createClientConnection(index, channel);
if (!channel.isReadPending()) {
@@ -272,6 +283,36 @@ public abstract class Client<C extends ClientConnection<R, P>, R extends ClientR
}
protected CompletableFuture<C> connect() {
if (isThreadLocalConnMode()) {
C conn = localConnection.get();
if (conn == null || !conn.isOpen()) {
try {
conn = connect1();
} catch (Exception e) {
return CompletableFuture.failedFuture(e);
}
localConnection.set(conn);
}
return CompletableFuture.completedFuture(conn);
} else {
return connect0();
}
}
protected C connect1() {
CompletableFuture<C> future = group.createClient(tcp, this.address.randomAddress(), readTimeoutSeconds, writeTimeoutSeconds)
.thenApply(c -> (C) createConnection(-2, c).setMaxPipelines(maxPipelines));
R virtualReq = createVirtualRequestAfterConnect();
if (virtualReq != null) {
future = future.thenCompose(conn -> conn.writeVirtualRequest(virtualReq).thenApply(v -> conn));
}
if (authenticate != null) {
future = future.thenCompose(authenticate);
}
return future.thenApply(c -> (C) c.setAuthenticated(true)).join();
}
protected CompletableFuture<C> connect0() {
final int size = this.connArray.length;
WorkThread workThread = WorkThread.currWorkThread();
final int connIndex = (workThread != null && workThread.threads() == size) ? workThread.index() : (int) Math.abs(connIndexSeq.getAndIncrement()) % size;

View File

@@ -12,6 +12,7 @@ import java.util.*;
import java.util.concurrent.*;
import java.util.concurrent.atomic.*;
import java.util.function.*;
import org.redkale.annotation.Nullable;
import org.redkale.net.*;
import org.redkale.util.ByteArray;
@@ -29,10 +30,14 @@ import org.redkale.util.ByteArray;
*/
public abstract class ClientConnection<R extends ClientRequest, P> implements Consumer<AsyncConnection> {
protected final int index; //从0开始 connArray的下坐标
//=-1 表示连接放在connAddrEntrys存储
//=-2 表示连接放在ThreadLocal存储
//>=0 表示connArray的下坐标从0开始
protected final int index;
protected final Client client;
@Nullable
protected final LongAdder respWaitingCounter; //可能为null
protected final LongAdder doneRequestCounter = new LongAdder();
@@ -47,6 +52,7 @@ public abstract class ClientConnection<R extends ClientRequest, P> implements Co
ClientFuture currHalfWriteFuture; //pauseWriting=true此字段才会有值; pauseWriting=false此字段值为null
@Nullable
private final Client.AddressConnEntry connEntry;
protected final AsyncConnection channel;
@@ -70,8 +76,8 @@ public abstract class ClientConnection<R extends ClientRequest, P> implements Co
this.client = client;
this.codec = createCodec();
this.index = index;
this.connEntry = index >= 0 ? null : client.connAddrEntrys.get(channel.getRemoteAddress());
this.respWaitingCounter = index >= 0 ? client.connRespWaitings[index] : this.connEntry.connRespWaiting;
this.connEntry = index == -2 ? null : (index >= 0 ? null : client.connAddrEntrys.get(channel.getRemoteAddress()));
this.respWaitingCounter = index == -2 ? new LongAdder() : (index >= 0 ? client.connRespWaitings[index] : this.connEntry.connRespWaiting);
this.channel = channel.beforeCloseListener(this);
}
@@ -90,10 +96,20 @@ public abstract class ClientConnection<R extends ClientRequest, P> implements Co
respFuture.setTimeout(client.timeoutScheduler.schedule(respFuture, rts, TimeUnit.SECONDS));
}
respWaitingCounter.increment(); //放在writeChannelInWriteThread计数会延迟导致不准确
if (channel.inCurrWriteThread()) {
writeChannelInThread(request, respFuture);
if (client.isThreadLocalConnMode()) {
offerRespFuture(respFuture);
writeArray.clear();
request.writeTo(this, writeArray);
doneRequestCounter.increment();
if (writeArray.length() > 0) {
channel.write(writeArray, this, writeHandler);
}
} else {
channel.executeWrite(() -> writeChannelInThread(request, respFuture));
if (channel.inCurrWriteThread()) {
writeChannelInThread(request, respFuture);
} else {
channel.executeWrite(() -> writeChannelInThread(request, respFuture));
}
}
return respFuture;
}

View File

@@ -1,28 +0,0 @@
/*
*
*/
package org.redkale.net.client;
import java.io.IOException;
import java.util.concurrent.ExecutorService;
import org.redkale.net.AsyncIOThread;
import org.redkale.util.ByteBufferPool;
/**
* 客户端IO读线程
*
* <p>
* 详情见: https://redkale.org
*
* @author zhangjx
*
* @since 2.8.0
*/
public class ClientReadIOThread extends AsyncIOThread {
public ClientReadIOThread(ThreadGroup g, String name, int index, int threads,
ExecutorService workExecutor, ByteBufferPool safeBufferPool) throws IOException {
super(g, name, index, threads, workExecutor, safeBufferPool);
}
}

View File

@@ -1,151 +0,0 @@
/*
*
*/
package org.redkale.net.client;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.CompletionHandler;
import java.util.*;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicBoolean;
import org.redkale.net.AsyncIOThread;
import org.redkale.util.*;
/**
* 客户端IO写线程
*
* <p>
* 详情见: https://redkale.org
*
* @author zhangjx
*
* @since 2.8.0
*/
public class ClientWriteIOThread extends AsyncIOThread {
private final AtomicBoolean writingFlag = new AtomicBoolean();
private final BlockingQueue<ClientFuture> requestQueue = new LinkedBlockingQueue<>();
public ClientWriteIOThread(ThreadGroup g, String name, int index, int threads,
ExecutorService workExecutor, ByteBufferPool safeBufferPool) throws IOException {
super(g, name, index, threads, workExecutor, safeBufferPool);
}
public void offerRequest(ClientConnection conn, ClientRequest request, ClientFuture respFuture) {
requestQueue.offer(respFuture);
}
public void sendHalfWrite(ClientConnection conn, ClientRequest request, Throwable halfRequestExc) {
ClientFuture respFuture = conn.createClientFuture(request);
respFuture.resumeHalfRequestFlag = true;
if (halfRequestExc != null) { //halfRequestExc不为null时需要把当前halfRequest移除
conn.pauseRequests.poll();
}
requestQueue.offer(respFuture);
}
@Override
public void run() {
final ByteBuffer buffer = getBufferSupplier().get();
final int capacity = buffer.capacity();
final ByteArray writeArray = new ByteArray();
final Map<ClientConnection, List<ClientFuture>> map = new HashMap<>();
final ObjectPool<List> listPool = ObjectPool.createUnsafePool(Utility.cpus() * 2, () -> new ArrayList(), null, t -> {
t.clear();
return true;
});
while (!isClosed()) {
ClientFuture entry;
try {
while ((entry = requestQueue.take()) != null) {
map.clear();
if (entry.resumeHalfRequestFlag != null) { //将暂停的pauseRequests写入list
List<ClientFuture> cl = map.computeIfAbsent(entry.conn, c -> listPool.get());
for (ClientFuture f : (Collection<ClientFuture>) entry.conn.pauseRequests) {
if (!f.isDone()) {
entry.conn.offerRespFuture(f);
cl.add(f);
}
}
entry.conn.pauseRequests.clear();
entry.conn.pauseWriting.set(false);
} else if (!entry.isDone()) {
entry.conn.offerRespFuture(entry);
if (entry.conn.pauseWriting.get()) {
entry.conn.pauseRequests.add(entry);
} else {
map.computeIfAbsent(entry.conn, c -> listPool.get()).add(entry);
}
}
while ((entry = requestQueue.poll()) != null) {
if (entry.resumeHalfRequestFlag != null) { //将暂停的pauseRequests写入list
List<ClientFuture> cl = map.computeIfAbsent(entry.conn, c -> listPool.get());
for (ClientFuture f : (Collection<ClientFuture>) entry.conn.pauseRequests) {
if (!f.isDone()) {
entry.conn.offerRespFuture(f);
cl.add(f);
}
}
entry.conn.pauseRequests.clear();
entry.conn.pauseWriting.set(false);
} else if (!entry.isDone()) {
entry.conn.offerRespFuture(entry);
if (entry.conn.pauseWriting.get()) {
entry.conn.pauseRequests.add(entry);
} else {
map.computeIfAbsent(entry.conn, c -> listPool.get()).add(entry);
}
}
}
map.forEach((conn, list) -> {
writeArray.clear();
int i = -1;
for (ClientFuture en : list) {
++i;
ClientRequest request = en.request;
request.writeTo(conn, writeArray);
conn.doneRequestCounter.increment();
if (!request.isCompleted()) {
conn.pauseWriting.set(true);
conn.pauseRequests.addAll(list.subList(i, list.size()));
break;
}
if (writeArray.length() > capacity) { //合并的数据包不能太大
conn.channel.write(writeArray, conn, writeHandler);
writeArray.clear();
}
}
listPool.accept(list);
//channel.write
if (writeArray.length() > 0) {
if (writeArray.length() <= capacity) {
buffer.clear();
buffer.put(writeArray.content(), 0, writeArray.length());
buffer.flip();
conn.channel.write(buffer, conn, writeHandler);
} else {
conn.channel.write(writeArray, conn, writeHandler);
}
}
});
}
} catch (InterruptedException e) {
}
}
}
protected final CompletionHandler<Integer, ClientConnection> writeHandler = new CompletionHandler<Integer, ClientConnection>() {
@Override
public void completed(Integer result, ClientConnection attachment) {
}
@Override
public void failed(Throwable exc, ClientConnection attachment) {
attachment.dispose(exc);
}
};
}