重构Client模块的IO读写策略

This commit is contained in:
Redkale
2023-01-13 09:33:23 +08:00
parent 43ff13867f
commit 41028384af
13 changed files with 194 additions and 192 deletions

View File

@@ -74,7 +74,7 @@ public class AsyncIOThread extends WorkThread {
/** /**
* 不可重置, 防止IO操作不在IO线程中执行 * 不可重置, 防止IO操作不在IO线程中执行
* *
* @param command * @param command 操作
*/ */
@Override @Override
public void execute(Runnable command) { public void execute(Runnable command) {
@@ -85,7 +85,7 @@ public class AsyncIOThread extends WorkThread {
/** /**
* 不可重置, 防止IO操作不在IO线程中执行 * 不可重置, 防止IO操作不在IO线程中执行
* *
* @param commands * @param commands 操作
*/ */
@Override @Override
public void execute(Runnable... commands) { public void execute(Runnable... commands) {
@@ -98,7 +98,7 @@ public class AsyncIOThread extends WorkThread {
/** /**
* 不可重置, 防止IO操作不在IO线程中执行 * 不可重置, 防止IO操作不在IO线程中执行
* *
* @param commands * @param commands 操作
*/ */
@Override @Override
public void execute(Collection<Runnable> commands) { public void execute(Collection<Runnable> commands) {

View File

@@ -27,18 +27,21 @@ import org.redkale.util.*;
*/ */
public abstract class ClientCodec<R extends ClientRequest, P> implements CompletionHandler<Integer, ByteBuffer> { public abstract class ClientCodec<R extends ClientRequest, P> implements CompletionHandler<Integer, ByteBuffer> {
private final List<ClientResponse<P>> repsResults = new ArrayList<>(); protected final ClientConnection connection;
private final ClientConnection connection; private final List<ClientResponse<P>> respResults = new ArrayList<>();
private final ByteArray readArray = new ByteArray(); private final ByteArray readArray = new ByteArray();
private final ObjectPool<ClientResponse> respPool = ObjectPool.createUnsafePool(256, t -> new ClientResponse(), ClientResponse::prepare, ClientResponse::recycle);
public ClientCodec(ClientConnection connection) { public ClientCodec(ClientConnection connection) {
Objects.requireNonNull(connection);
this.connection = connection; this.connection = connection;
} }
//返回true: array会clear, 返回false: buffer会clear //返回true: array会clear, 返回false: buffer会clear
public abstract boolean decodeMessages(ClientConnection connection, ByteBuffer buffer, ByteArray array); public abstract boolean decodeMessages(ByteBuffer buffer, ByteArray array);
@Override @Override
public final void completed(Integer count, ByteBuffer attachment) { public final void completed(Integer count, ByteBuffer attachment) {
@@ -61,27 +64,17 @@ public abstract class ClientCodec<R extends ClientRequest, P> implements Complet
AsyncConnection channel = connection.channel; AsyncConnection channel = connection.channel;
Deque<ClientFuture> responseQueue = connection.responseQueue; Deque<ClientFuture> responseQueue = connection.responseQueue;
Map<Serializable, ClientFuture> responseMap = connection.responseMap; Map<Serializable, ClientFuture> responseMap = connection.responseMap;
if (decodeMessages(connection, buffer, readArray)) { //成功了 if (decodeMessages(buffer, readArray)) { //成功了
readArray.clear(); readArray.clear();
List<ClientResponse<P>> results = pollMessages(); for (ClientResponse<P> cr : respResults) {
if (results != null) { Serializable reqid = cr.getRequestid();
for (ClientResponse<P> rs : results) { ClientFuture respFuture = reqid == null ? responseQueue.poll() : responseMap.remove(reqid);
Serializable reqid = rs.getRequestid(); if (respFuture != null) {
ClientFuture respFuture = reqid == null ? responseQueue.poll() : responseMap.remove(reqid); completeResponse(respFuture, cr.message, cr.exc);
if (respFuture != null) {
int mergeCount = respFuture.getMergeCount();
completeResponse(rs, respFuture);
if (mergeCount > 0) {
for (int i = 0; i < mergeCount; i++) {
respFuture = reqid == null ? responseQueue.poll() : responseMap.remove(reqid);
if (respFuture != null) {
completeResponse(rs, respFuture);
}
}
}
}
} }
respPool.accept(cr);
} }
respResults.clear();
if (buffer.hasRemaining()) { if (buffer.hasRemaining()) {
decodeResponse(buffer); decodeResponse(buffer);
@@ -97,40 +90,40 @@ public abstract class ClientCodec<R extends ClientRequest, P> implements Complet
} }
} }
private void completeResponse(ClientResponse<P> rs, ClientFuture respFuture) { private void completeResponse(ClientFuture respFuture, P message, Throwable exc) {
if (respFuture != null) { if (respFuture != null) {
ClientRequest request = respFuture.request; ClientRequest request = respFuture.request;
if (!request.isCompleted()) {
if (rs.exc == null) {
connection.sendHalfWrite(rs.exc);
//request没有发送完respFuture需要再次接收
return;
} else { //异常了需要清掉半包
connection.sendHalfWrite(rs.exc);
}
}
connection.respWaitingCounter.decrement();
if (connection.isAuthenticated()) {
connection.client.incrRespDoneCounter();
}
try { try {
if (!request.isCompleted()) {
if (exc == null) {
connection.sendHalfWrite(exc);
//request没有发送完respFuture需要再次接收
return;
} else { //异常了需要清掉半包
connection.sendHalfWrite(exc);
}
}
connection.respWaitingCounter.decrement();
if (connection.isAuthenticated()) {
connection.client.incrRespDoneCounter();
}
respFuture.cancelTimeout(); respFuture.cancelTimeout();
//if (client.finest) client.logger.log(Level.FINEST, Utility.nowMillis() + ": " + Thread.currentThread().getName() + ": " + ClientConnection.this + ", 回调处理, req=" + request + ", message=" + rs.message); //if (client.finest) client.logger.log(Level.FINEST, Utility.nowMillis() + ": " + Thread.currentThread().getName() + ": " + ClientConnection.this + ", 回调处理, req=" + request + ", message=" + rs.message);
connection.preComplete(rs.message, (R) request, rs.exc); connection.preComplete(message, (R) request, exc);
WorkThread workThread = request.workThread; WorkThread workThread = request.workThread;
request.workThread = null; request.workThread = null;
if (workThread == null || workThread.getWorkExecutor() == null) { if (workThread == null || workThread.getWorkExecutor() == null) {
workThread = connection.channel.getReadIOThread(); workThread = connection.channel.getReadIOThread();
} }
if (rs.exc != null) { if (exc != null) {
workThread.runWork(() -> { workThread.runWork(() -> {
Traces.currTraceid(request.traceid); Traces.currTraceid(request.traceid);
respFuture.completeExceptionally(rs.exc); respFuture.completeExceptionally(exc);
}); });
} else { } else {
workThread.runWork(() -> { workThread.runWork(() -> {
Traces.currTraceid(request.traceid); Traces.currTraceid(request.traceid);
respFuture.complete(rs.message); respFuture.complete(message);
}); });
} }
} catch (Throwable t) { } catch (Throwable t) {
@@ -148,22 +141,18 @@ public abstract class ClientCodec<R extends ClientRequest, P> implements Complet
return connection.responseQueue.iterator(); return connection.responseQueue.iterator();
} }
public List<ClientResponse<P>> pollMessages() { protected List<ClientResponse<P>> pollMessages() {
List<ClientResponse<P>> rs = new ArrayList<>(repsResults); List<ClientResponse<P>> rs = new ArrayList<>(respResults);
this.repsResults.clear(); this.respResults.clear();
return rs; return rs;
} }
public ClientConnection getConnection() { public void addMessage(R request, P result) {
return connection; this.respResults.add(respPool.get().set(request, result));
} }
public void addMessage(P result) { public void addMessage(R request, Throwable exc) {
this.repsResults.add(new ClientResponse<>(result)); this.respResults.add(respPool.get().set(request, exc));
}
public void addMessage(Throwable exc) {
this.repsResults.add(new ClientResponse<>(exc));
} }
@Override @Override

View File

@@ -37,7 +37,7 @@ public abstract class ClientConnection<R extends ClientRequest, P> implements Co
protected final AtomicBoolean pauseResuming = new AtomicBoolean(); protected final AtomicBoolean pauseResuming = new AtomicBoolean();
protected final List<ClientWriteIOThread.ClientEntity> pauseRequests = new CopyOnWriteArrayList<ClientWriteIOThread.ClientEntity>(); protected final List<ClientFuture> pauseRequests = new CopyOnWriteArrayList<>();
protected final AsyncConnection channel; protected final AsyncConnection channel;
@@ -71,7 +71,6 @@ public abstract class ClientConnection<R extends ClientRequest, P> implements Co
ClientFuture respFuture = createClientFuture(request); ClientFuture respFuture = createClientFuture(request);
int rts = this.channel.getReadTimeoutSeconds(); int rts = this.channel.getReadTimeoutSeconds();
if (rts > 0 && !request.isCloseType()) { if (rts > 0 && !request.isCloseType()) {
respFuture.setConn(this);
respFuture.setTimeout(client.timeoutScheduler.schedule(respFuture, rts, TimeUnit.SECONDS)); respFuture.setTimeout(client.timeoutScheduler.schedule(respFuture, rts, TimeUnit.SECONDS));
} }
respWaitingCounter.increment(); //放在writeChannelUnsafe计数会延迟导致不准确 respWaitingCounter.increment(); //放在writeChannelUnsafe计数会延迟导致不准确
@@ -79,11 +78,11 @@ public abstract class ClientConnection<R extends ClientRequest, P> implements Co
return respFuture; return respFuture;
} }
CompletableFuture writeVirtualRequest(ClientRequest request) { CompletableFuture writeVirtualRequest(R request) {
if (!request.isVirtualType()) { if (!request.isVirtualType()) {
return CompletableFuture.failedFuture(new RuntimeException("ClientVirtualRequest must be virtualType = true")); return CompletableFuture.failedFuture(new RuntimeException("ClientVirtualRequest must be virtualType = true"));
} }
ClientFuture respFuture = new ClientFuture(request); ClientFuture respFuture = createClientFuture(request);
responseQueue.offer(respFuture); responseQueue.offer(respFuture);
readChannel(); readChannel();
return respFuture; return respFuture;
@@ -93,7 +92,7 @@ public abstract class ClientConnection<R extends ClientRequest, P> implements Co
} }
protected ClientFuture createClientFuture(R request) { protected ClientFuture createClientFuture(R request) {
return new ClientFuture(request); return new ClientFuture(this, request);
} }
protected ClientConnection readChannel() { protected ClientConnection readChannel() {

View File

@@ -5,7 +5,7 @@
*/ */
package org.redkale.net.client; package org.redkale.net.client;
import java.util.Queue; import java.util.*;
import java.util.concurrent.*; import java.util.concurrent.*;
import org.redkale.net.*; import org.redkale.net.*;
@@ -16,49 +16,16 @@ import org.redkale.net.*;
*/ */
public class ClientFuture<T> extends CompletableFuture<T> implements Runnable { public class ClientFuture<T> extends CompletableFuture<T> implements Runnable {
public static final ClientFuture EMPTY = new ClientFuture(null) {
@Override
public boolean complete(Object value) {
return true;
}
@Override
public boolean completeExceptionally(Throwable ex) {
return true;
}
@Override
void setConn(ClientConnection conn) {
}
@Override
void setTimeout(ScheduledFuture timeout) {
}
@Override
void incrMergeCount() {
}
@Override
public void run() {
}
};
protected final ClientRequest request; protected final ClientRequest request;
protected final ClientConnection conn;
private ScheduledFuture timeout; private ScheduledFuture timeout;
private int mergeCount; //合并的个数,不算自身 public ClientFuture(ClientConnection conn, ClientRequest request) {
private ClientConnection conn;
public ClientFuture(ClientRequest request) {
super(); super();
this.request = request;
}
void setConn(ClientConnection conn) {
this.conn = conn; this.conn = conn;
this.request = request;
} }
void setTimeout(ScheduledFuture timeout) { void setTimeout(ScheduledFuture timeout) {
@@ -71,20 +38,10 @@ public class ClientFuture<T> extends CompletableFuture<T> implements Runnable {
} }
} }
void incrMergeCount() {
mergeCount++;
}
public int getMergeCount() {
return mergeCount;
}
@Override //JDK9+ @Override //JDK9+
public <U> ClientFuture<U> newIncompleteFuture() { public <U> ClientFuture<U> newIncompleteFuture() {
ClientFuture future = new ClientFuture<>(request); ClientFuture future = new ClientFuture<>(conn, request);
future.timeout = timeout; future.timeout = timeout;
future.mergeCount = mergeCount;
future.conn = conn;
return future; return future;
} }
@@ -125,4 +82,9 @@ public class ClientFuture<T> extends CompletableFuture<T> implements Runnable {
} }
workThread.runWork(() -> completeExceptionally(ex)); workThread.runWork(() -> completeExceptionally(ex));
} }
@Override
public String toString() {
return getClass().getSimpleName() + "_" + Objects.hash(this) + "{conn = " + conn + ", request = " + request + "}";
}
} }

View File

@@ -27,6 +27,9 @@ public abstract class ClientRequest implements BiConsumer<ClientConnection, Byte
protected String traceid; protected String traceid;
@Override
public abstract void accept(ClientConnection conn, ByteArray array);
public Serializable getRequestid() { public Serializable getRequestid() {
return null; return null;
} }
@@ -54,16 +57,6 @@ public abstract class ClientRequest implements BiConsumer<ClientConnection, Byte
return (T) this; return (T) this;
} }
//是否能合并, requestid=null的情况下值才有效
protected boolean canMerge(ClientConnection conn) {
return false;
}
//合并成功了返回true
protected boolean merge(ClientConnection conn, ClientRequest other) {
return false;
}
//数据是否全部写入如果只写部分返回false, 配合ClientConnection.pauseWriting使用 //数据是否全部写入如果只写部分返回false, 配合ClientConnection.pauseWriting使用
protected boolean isCompleted() { protected boolean isCompleted() {
return true; return true;

View File

@@ -14,22 +14,62 @@ import java.io.Serializable;
*/ */
public class ClientResponse<P> { public class ClientResponse<P> {
protected ClientRequest request;
protected P message; protected P message;
protected Throwable exc; protected Throwable exc;
public Serializable getRequestid() { public ClientResponse() {
return null;
} }
public ClientResponse(P result) { public ClientResponse(ClientRequest request, P message) {
this.message = result; this.request = request;
this.message = message;
} }
public ClientResponse(Throwable exc) { public ClientResponse(ClientRequest request, Throwable exc) {
this.request = request;
this.exc = exc; this.exc = exc;
} }
public Serializable getRequestid() {
return request == null ? null : request.getRequestid();
}
public ClientResponse<P> set(ClientRequest request, P message) {
this.request = request;
this.message = message;
return this;
}
public ClientResponse<P> set(ClientRequest request, Throwable exc) {
this.request = request;
this.exc = exc;
return this;
}
protected void prepare() {
this.request = null;
this.message = null;
this.exc = null;
}
protected boolean recycle() {
this.request = null;
this.message = null;
this.exc = null;
return true;
}
public ClientRequest getRequest() {
return request;
}
public void setRequest(ClientRequest request) {
this.request = request;
}
public P getMessage() { public P getMessage() {
return message; return message;
} }
@@ -53,4 +93,5 @@ public class ClientResponse<P> {
} }
return "{\"message\":" + message + "}"; return "{\"message\":" + message + "}";
} }
} }

View File

@@ -18,7 +18,7 @@ import org.redkale.util.*;
*/ */
public class ClientWriteIOThread extends AsyncIOThread { public class ClientWriteIOThread extends AsyncIOThread {
private final BlockingDeque<ClientEntity> requestQueue = new LinkedBlockingDeque<>(); private final BlockingDeque<ClientFuture> requestQueue = new LinkedBlockingDeque<>();
public ClientWriteIOThread(String name, int index, int threads, ExecutorService workExecutor, Selector selector, public ClientWriteIOThread(String name, int index, int threads, ExecutorService workExecutor, Selector selector,
ObjectPool<ByteBuffer> unsafeBufferPool, ObjectPool<ByteBuffer> safeBufferPool) { ObjectPool<ByteBuffer> unsafeBufferPool, ObjectPool<ByteBuffer> safeBufferPool) {
@@ -26,7 +26,7 @@ public class ClientWriteIOThread extends AsyncIOThread {
} }
public void offerRequest(ClientConnection conn, ClientRequest request, ClientFuture respFuture) { public void offerRequest(ClientConnection conn, ClientRequest request, ClientFuture respFuture) {
requestQueue.offer(new ClientEntity(conn, request, respFuture)); requestQueue.offer(respFuture);
} }
public void sendHalfWrite(ClientConnection conn, Throwable halfRequestExc) { public void sendHalfWrite(ClientConnection conn, Throwable halfRequestExc) {
@@ -37,7 +37,7 @@ public class ClientWriteIOThread extends AsyncIOThread {
conn.pauseRequests.removeIf(e -> { conn.pauseRequests.removeIf(e -> {
if (e != null) { if (e != null) {
if (!skipFirst.compareAndSet(true, false)) { if (!skipFirst.compareAndSet(true, false)) {
requestQueue.offer((ClientEntity) e); requestQueue.offer((ClientFuture) e);
} }
} }
return true; return true;
@@ -57,62 +57,62 @@ public class ClientWriteIOThread extends AsyncIOThread {
final ByteBuffer buffer = getBufferSupplier().get(); final ByteBuffer buffer = getBufferSupplier().get();
final int capacity = buffer.capacity(); final int capacity = buffer.capacity();
final ByteArray writeArray = new ByteArray(1024 * 32); final ByteArray writeArray = new ByteArray(1024 * 32);
final Map<ClientConnection, List<ClientEntity>> map = new HashMap<>(); final Map<ClientConnection, List<ClientFuture>> map = new HashMap<>();
final ObjectPool<List> listPool = ObjectPool.createUnsafePool(Utility.cpus() * 2, () -> new ArrayList(), null, t -> { final ObjectPool<List> listPool = ObjectPool.createUnsafePool(Utility.cpus() * 2, () -> new ArrayList(), null, t -> {
t.clear(); t.clear();
return true; return true;
}); });
while (!isClosed()) { while (!isClosed()) {
ClientEntity entity; ClientFuture entry;
try { try {
while ((entity = requestQueue.take()) != null) { while ((entry = requestQueue.take()) != null) {
map.clear(); map.clear();
{ {
Serializable reqid = entity.request.getRequestid(); Serializable reqid = entry.request.getRequestid();
if (reqid == null) { if (reqid == null) {
entity.conn.responseQueue.offer(entity.respFuture); entry.conn.responseQueue.offer(entry);
} else { } else {
entity.conn.responseMap.put(reqid, entity.respFuture); entry.conn.responseMap.put(reqid, entry);
} }
} }
if (entity.conn.pauseWriting.get()) { if (entry.conn.pauseWriting.get()) {
if (entity.conn.pauseResuming.get()) { if (entry.conn.pauseResuming.get()) {
try { try {
synchronized (entity.conn.pauseRequests) { synchronized (entry.conn.pauseRequests) {
entity.conn.pauseRequests.wait(3_000); entry.conn.pauseRequests.wait(3_000);
} }
} catch (InterruptedException ie) { } catch (InterruptedException ie) {
} }
} }
entity.conn.pauseRequests.add(entity); entry.conn.pauseRequests.add(entry);
} else { } else {
map.computeIfAbsent(entity.conn, c -> listPool.get()).add(entity); map.computeIfAbsent(entry.conn, c -> listPool.get()).add(entry);
} }
while ((entity = requestQueue.poll()) != null) { while ((entry = requestQueue.poll()) != null) {
Serializable reqid = entity.request.getRequestid(); Serializable reqid = entry.request.getRequestid();
if (reqid == null) { if (reqid == null) {
entity.conn.responseQueue.offer(entity.respFuture); entry.conn.responseQueue.offer(entry);
} else { } else {
entity.conn.responseMap.put(reqid, entity.respFuture); entry.conn.responseMap.put(reqid, entry);
} }
if (entity.conn.pauseWriting.get()) { if (entry.conn.pauseWriting.get()) {
if (entity.conn.pauseResuming.get()) { if (entry.conn.pauseResuming.get()) {
try { try {
synchronized (entity.conn.pauseRequests) { synchronized (entry.conn.pauseRequests) {
entity.conn.pauseRequests.wait(3_000); entry.conn.pauseRequests.wait(3_000);
} }
} catch (InterruptedException ie) { } catch (InterruptedException ie) {
} }
} }
entity.conn.pauseRequests.add(entity); entry.conn.pauseRequests.add(entry);
} else { } else {
map.computeIfAbsent(entity.conn, c -> listPool.get()).add(entity); map.computeIfAbsent(entry.conn, c -> listPool.get()).add(entry);
} }
} }
map.forEach((conn, list) -> { map.forEach((conn, list) -> {
writeArray.clear(); writeArray.clear();
int i = -1; int i = -1;
for (ClientEntity en : list) { for (ClientFuture en : list) {
++i; ++i;
ClientRequest request = en.request; ClientRequest request = en.request;
request.accept(conn, writeArray); request.accept(conn, writeArray);
@@ -153,23 +153,4 @@ public class ClientWriteIOThread extends AsyncIOThread {
} }
}; };
protected static class ClientEntity {
ClientConnection conn;
ClientRequest request;
ClientFuture respFuture;
public ClientEntity(ClientConnection conn, ClientRequest request, ClientFuture respFuture) {
this.conn = conn;
this.request = request;
this.respFuture = respFuture;
}
@Override
public String toString() {
return getClass().getSimpleName() + "_" + Objects.hash(this) + "{conn = " + conn + ", request = " + request + "}";
}
}
} }

View File

@@ -1405,7 +1405,7 @@ public class HttpResponse extends Response<HttpContext, HttpRequest> {
/** /**
* 判断是否存在Header值 * 判断是否存在Header值
* *
* @param name * @param name header-name
* *
* @return 是否存在 * @return 是否存在
*/ */

View File

@@ -26,7 +26,7 @@ public class SncpDispatcherServlet extends DispatcherServlet<Uint128, SncpContex
synchronized (sncplock) { synchronized (sncplock) {
for (SncpServlet s : getServlets()) { for (SncpServlet s : getServlets()) {
if (s.service == servlet.service) { if (s.service == servlet.service) {
throw new RuntimeException(s.service + " repeat addSncpServlet"); throw new SncpException(s.service + " repeat addSncpServlet");
} }
} }
setServletConf(servlet, conf); setServletConf(servlet, conf);

View File

@@ -53,7 +53,7 @@ public class SncpResponse extends Response<SncpContext, SncpRequest> {
this.addrBytes = context.getServerAddress().getAddress().getAddress(); this.addrBytes = context.getServerAddress().getAddress().getAddress();
this.addrPort = context.getServerAddress().getPort(); this.addrPort = context.getServerAddress().getPort();
if (this.addrBytes.length != 4) { if (this.addrBytes.length != 4) {
throw new RuntimeException("SNCP serverAddress only support IPv4"); throw new SncpException("SNCP serverAddress only support IPv4");
} }
} }

View File

@@ -1867,7 +1867,7 @@ public abstract class DataSqlSource extends AbstractDataSource implements Functi
final EntityInfo<T> info = loadEntityInfo(clazz); final EntityInfo<T> info = loadEntityInfo(clazz);
String illegalColumn = checkIllegalColumn(info, selects); String illegalColumn = checkIllegalColumn(info, selects);
if (illegalColumn != null) { if (illegalColumn != null) {
return CompletableFuture.failedFuture(new RuntimeException(info.getType() + " cannot found column " + illegalColumn)); return CompletableFuture.failedFuture(new SourceException(info.getType() + " cannot found column " + illegalColumn));
} }
if (isOnlyCache(info)) { if (isOnlyCache(info)) {
return CompletableFuture.completedFuture(updateCache(info, -1, false, entity, null, selects)); return CompletableFuture.completedFuture(updateCache(info, -1, false, entity, null, selects));
@@ -1929,7 +1929,7 @@ public abstract class DataSqlSource extends AbstractDataSource implements Functi
final EntityInfo<T> info = loadEntityInfo(clazz); final EntityInfo<T> info = loadEntityInfo(clazz);
String illegalColumn = checkIllegalColumn(info, selects); String illegalColumn = checkIllegalColumn(info, selects);
if (illegalColumn != null) { if (illegalColumn != null) {
return CompletableFuture.failedFuture(new RuntimeException(info.getType() + " cannot found column " + illegalColumn)); return CompletableFuture.failedFuture(new SourceException(info.getType() + " cannot found column " + illegalColumn));
} }
if (isOnlyCache(info)) { if (isOnlyCache(info)) {
return CompletableFuture.completedFuture(updateCache(info, -1, true, entity, node, selects)); return CompletableFuture.completedFuture(updateCache(info, -1, true, entity, node, selects));
@@ -2571,15 +2571,15 @@ public abstract class DataSqlSource extends AbstractDataSource implements Functi
return rs; return rs;
} }
} }
String table = info.getTable(pk); String[] tables = info.getTableOneArray(pk);
String sql = findSql(info, selects, pk); String sql = findSql(info, selects, pk);
if (info.isLoggable(logger, Level.FINEST, sql)) { if (info.isLoggable(logger, Level.FINEST, sql)) {
logger.finest(info.getType().getSimpleName() + " find sql=" + sql); logger.finest(info.getType().getSimpleName() + " find sql=" + sql);
} }
if (isAsync()) { if (isAsync()) {
return findDBAsync(info, new String[]{table}, sql, true, selects, pk, null).join(); return findDBAsync(info, tables, sql, true, selects, pk, null).join();
} else { } else {
return findDB(info, new String[]{table}, sql, true, selects, pk, null); return findDB(info, tables, sql, true, selects, pk, null);
} }
} }
@@ -2593,15 +2593,15 @@ public abstract class DataSqlSource extends AbstractDataSource implements Functi
return CompletableFuture.completedFuture(rs); return CompletableFuture.completedFuture(rs);
} }
} }
String table = info.getTable(pk); String[] tables = info.getTableOneArray(pk);
String sql = findSql(info, selects, pk); String sql = findSql(info, selects, pk);
if (info.isLoggable(logger, Level.FINEST, sql)) { if (info.isLoggable(logger, Level.FINEST, sql)) {
logger.finest(info.getType().getSimpleName() + " find sql=" + sql); logger.finest(info.getType().getSimpleName() + " find sql=" + sql);
} }
if (isAsync()) { if (isAsync()) {
return findDBAsync(info, new String[]{table}, sql, true, selects, pk, null); return findDBAsync(info, tables, sql, true, selects, pk, null);
} else { } else {
return supplyAsync(() -> findDB(info, new String[]{table}, sql, true, selects, pk, null)); return supplyAsync(() -> findDB(info, tables, sql, true, selects, pk, null));
} }
} }
@@ -2692,15 +2692,15 @@ public abstract class DataSqlSource extends AbstractDataSource implements Functi
return val; return val;
} }
} }
String table = info.getTable(pk); String[] tables = info.getTableOneArray(pk);
String sql = findColumnSql(info, column, defValue, pk); String sql = findColumnSql(info, column, defValue, pk);
if (info.isLoggable(logger, Level.FINEST, sql)) { if (info.isLoggable(logger, Level.FINEST, sql)) {
logger.finest(info.getType().getSimpleName() + " findColumn sql=" + sql); logger.finest(info.getType().getSimpleName() + " findColumn sql=" + sql);
} }
if (isAsync()) { if (isAsync()) {
return findColumnDBAsync(info, new String[]{table}, sql, true, column, defValue, pk, null).join(); return findColumnDBAsync(info, tables, sql, true, column, defValue, pk, null).join();
} else { } else {
return findColumnDB(info, new String[]{table}, sql, true, column, defValue, pk, null); return findColumnDB(info, tables, sql, true, column, defValue, pk, null);
} }
} }
@@ -2714,15 +2714,15 @@ public abstract class DataSqlSource extends AbstractDataSource implements Functi
return CompletableFuture.completedFuture(val); return CompletableFuture.completedFuture(val);
} }
} }
String table = info.getTable(pk); String[] tables = info.getTableOneArray(pk);
String sql = findColumnSql(info, column, defValue, pk); String sql = findColumnSql(info, column, defValue, pk);
if (info.isLoggable(logger, Level.FINEST, sql)) { if (info.isLoggable(logger, Level.FINEST, sql)) {
logger.finest(info.getType().getSimpleName() + " findColumn sql=" + sql); logger.finest(info.getType().getSimpleName() + " findColumn sql=" + sql);
} }
if (isAsync()) { if (isAsync()) {
return findColumnDBAsync(info, new String[]{table}, sql, true, column, defValue, pk, null); return findColumnDBAsync(info, tables, sql, true, column, defValue, pk, null);
} else { } else {
return supplyAsync(() -> findColumnDB(info, new String[]{table}, sql, true, column, defValue, pk, null)); return supplyAsync(() -> findColumnDB(info, tables, sql, true, column, defValue, pk, null));
} }
} }
@@ -2819,15 +2819,15 @@ public abstract class DataSqlSource extends AbstractDataSource implements Functi
return rs; return rs;
} }
} }
String table = info.getTable(pk); String[] tables = info.getTableOneArray(pk);
String sql = existsSql(info, pk); String sql = existsSql(info, pk);
if (info.isLoggable(logger, Level.FINEST, sql)) { if (info.isLoggable(logger, Level.FINEST, sql)) {
logger.finest(info.getType().getSimpleName() + " exists sql=" + sql); logger.finest(info.getType().getSimpleName() + " exists sql=" + sql);
} }
if (isAsync()) { if (isAsync()) {
return existsDBAsync(info, new String[]{table}, sql, true, pk, null).join(); return existsDBAsync(info, tables, sql, true, pk, null).join();
} else { } else {
return existsDB(info, new String[]{table}, sql, true, pk, null); return existsDB(info, tables, sql, true, pk, null);
} }
} }
@@ -2841,15 +2841,15 @@ public abstract class DataSqlSource extends AbstractDataSource implements Functi
return CompletableFuture.completedFuture(rs); return CompletableFuture.completedFuture(rs);
} }
} }
String table = info.getTable(pk); String[] tables = info.getTableOneArray(pk);
String sql = existsSql(info, pk); String sql = existsSql(info, pk);
if (info.isLoggable(logger, Level.FINEST, sql)) { if (info.isLoggable(logger, Level.FINEST, sql)) {
logger.finest(info.getType().getSimpleName() + " exists sql=" + sql); logger.finest(info.getType().getSimpleName() + " exists sql=" + sql);
} }
if (isAsync()) { if (isAsync()) {
return existsDBAsync(info, new String[]{table}, sql, true, pk, null); return existsDBAsync(info, tables, sql, true, pk, null);
} else { } else {
return supplyAsync(() -> existsDB(info, new String[]{table}, sql, true, pk, null)); return supplyAsync(() -> existsDB(info, tables, sql, true, pk, null));
} }
} }

View File

@@ -45,6 +45,9 @@ public final class EntityInfo<T> {
//类对应的数据表名, 如果是VirtualEntity 类, 则该字段为null //类对应的数据表名, 如果是VirtualEntity 类, 则该字段为null
final String table; final String table;
//table的单一元素数组
final String[] tableOneArray;
//JsonConvert //JsonConvert
final JsonConvert jsonConvert; final JsonConvert jsonConvert;
@@ -66,6 +69,9 @@ public final class EntityInfo<T> {
//主键 //主键
final Attribute<T, Serializable> primary; final Attribute<T, Serializable> primary;
//table的单一元素数组
final Attribute<T, Serializable>[] primaryOneArray;
//DDL字段集合 //DDL字段集合
final EntityColumn[] ddlColumns; final EntityColumn[] ddlColumns;
@@ -294,6 +300,7 @@ public final class EntityInfo<T> {
|| type.getAnnotation(org.redkale.source.VirtualEntity.class) != null || type.getAnnotation(org.redkale.source.VirtualEntity.class) != null
|| (source == null || "memory".equalsIgnoreCase(source.getType()))) { || (source == null || "memory".equalsIgnoreCase(source.getType()))) {
this.table = null; this.table = null;
this.tableOneArray = null;
BiFunction<DataSource, EntityInfo, CompletableFuture<List>> loader = null; BiFunction<DataSource, EntityInfo, CompletableFuture<List>> loader = null;
try { try {
org.redkale.persistence.VirtualEntity ve = type.getAnnotation(org.redkale.persistence.VirtualEntity.class); org.redkale.persistence.VirtualEntity ve = type.getAnnotation(org.redkale.persistence.VirtualEntity.class);
@@ -316,6 +323,7 @@ public final class EntityInfo<T> {
throw new SourceException(type + " have illegal table.name on @Table"); throw new SourceException(type + " have illegal table.name on @Table");
} }
this.table = (tableCcatalog0 == null) ? type.getSimpleName().toLowerCase() : (tableCcatalog0.isEmpty()) ? (tableName0.isEmpty() ? type.getSimpleName().toLowerCase() : tableName0) : (tableCcatalog0 + '.' + (tableName0.isEmpty() ? type.getSimpleName().toLowerCase() : tableName0)); this.table = (tableCcatalog0 == null) ? type.getSimpleName().toLowerCase() : (tableCcatalog0.isEmpty()) ? (tableName0.isEmpty() ? type.getSimpleName().toLowerCase() : tableName0) : (tableCcatalog0 + '.' + (tableName0.isEmpty() ? type.getSimpleName().toLowerCase() : tableName0));
this.tableOneArray = new String[]{this.table};
} }
DistributeTable dt = type.getAnnotation(DistributeTable.class); DistributeTable dt = type.getAnnotation(DistributeTable.class);
DistributeTableStrategy dts = null; DistributeTableStrategy dts = null;
@@ -456,6 +464,7 @@ public final class EntityInfo<T> {
this.jsonConvert = convert == null ? DEFAULT_JSON_CONVERT : convert; this.jsonConvert = convert == null ? DEFAULT_JSON_CONVERT : convert;
this.primary = idAttr0; this.primary = idAttr0;
this.primaryOneArray = new Attribute[]{this.primary};
this.aliasmap = aliasmap0; this.aliasmap = aliasmap0;
List<EntityColumn> ddls = new ArrayList<>(); List<EntityColumn> ddls = new ArrayList<>();
Collections.reverse(ddlList); //父类的字段排在前面 Collections.reverse(ddlList); //父类的字段排在前面
@@ -1046,6 +1055,24 @@ public final class EntityInfo<T> {
return t; return t;
} }
/**
* 根据主键值获取Entity的表名单一元素数组
*
* @param primary Entity主键值
*
* @return String[]
*/
public String[] getTableOneArray(Serializable primary) {
if (tableStrategy == null) {
return tableOneArray;
}
String t = tableStrategy.getTable(table, primary);
if (t == null || t.isEmpty()) {
throw new SourceException(table + " tableStrategy.getTable is empty, primary=" + primary);
}
return new String[]{t};
}
/** /**
* 根据过滤条件获取Entity的表名 * 根据过滤条件获取Entity的表名
* *
@@ -1091,6 +1118,15 @@ public final class EntityInfo<T> {
return this.primary; return this.primary;
} }
/**
* 获取主键字段的Attribute单一元素数组
*
* @return Attribute[]
*/
public Attribute<T, Serializable>[] getPrimaryOneArray() {
return this.primaryOneArray;
}
/** /**
* 遍历数据库表对应的所有字段, 不包含&#64;Transient字段 * 遍历数据库表对应的所有字段, 不包含&#64;Transient字段
* *

View File

@@ -459,6 +459,7 @@ public final class ResourceFactory {
/** /**
* 将多个以指定资源名的String对象注入到资源池中 * 将多个以指定资源名的String对象注入到资源池中
* *
* @param <A> 泛型
* @param properties 资源键值对 * @param properties 资源键值对
* @param environmentName 额外的资源名 * @param environmentName 额外的资源名
* @param environmentType 额外的类名 * @param environmentType 额外的类名