SslReadCompletionHandler

This commit is contained in:
redkale
2024-11-11 08:50:18 +08:00
parent 6ade0fdb77
commit 800da01c72
6 changed files with 149 additions and 150 deletions

View File

@@ -355,11 +355,15 @@ public abstract class AsyncConnection implements Channel, AutoCloseable {
} else { } else {
try { try {
int remain = src.remaining(); int remain = src.remaining();
sslWriteImpl(false, src, t -> { sslWriteImpl(false, src, new CompletionHandler<Integer, Void>() {
if (t == null) { @Override
public void completed(Integer result, Void attach) {
handler.completed(remain - src.remaining(), attachment); handler.completed(remain - src.remaining(), attachment);
} else { }
handler.failed(t, attachment);
@Override
public void failed(Throwable exc, Void attach) {
handler.failed(exc, attachment);
} }
}); });
} catch (SSLException e) { } catch (SSLException e) {
@@ -375,11 +379,15 @@ public abstract class AsyncConnection implements Channel, AutoCloseable {
} else { } else {
try { try {
int remain = ByteBufferReader.remaining(srcs, offset, length); int remain = ByteBufferReader.remaining(srcs, offset, length);
sslWriteImpl(false, srcs, offset, length, t -> { sslWriteImpl(false, srcs, offset, length, new CompletionHandler<Integer, Void>() {
if (t == null) { @Override
public void completed(Integer result, Void attach) {
handler.completed(remain - ByteBufferReader.remaining(srcs, offset, length), attachment); handler.completed(remain - ByteBufferReader.remaining(srcs, offset, length), attachment);
} else { }
handler.failed(t, attachment);
@Override
public void failed(Throwable exc, Void attach) {
handler.failed(exc, attachment);
} }
}); });
} catch (SSLException e) { } catch (SSLException e) {
@@ -839,22 +847,22 @@ public abstract class AsyncConnection implements Channel, AutoCloseable {
} }
} }
protected void startHandshake(final Consumer<Throwable> callback) { protected void startHandshake(CompletionHandler<Integer, Void> handler) {
if (sslEngine == null) { if (sslEngine == null) {
callback.accept(null); handler.completed(0, null);
return; return;
} }
SSLEngine engine = this.sslEngine; SSLEngine engine = this.sslEngine;
try { try {
engine.beginHandshake(); engine.beginHandshake();
doHandshake(callback); doHandshake(handler);
} catch (Throwable t) { } catch (Throwable t) {
callback.accept(t); handler.failed(t, null);
} }
} }
// 解密ssl网络数据 返回null表示CLOSED // 解密ssl网络数据 返回null表示CLOSED
protected ByteBuffer sslUnwrap(final boolean handshake, ByteBuffer netBuffer) throws SSLException { protected ByteBuffer sslUnwrap(boolean handshake, ByteBuffer netBuffer) throws SSLException {
ByteBuffer appBuffer = pollReadBuffer(); ByteBuffer appBuffer = pollReadBuffer();
SSLEngine engine = this.sslEngine; SSLEngine engine = this.sslEngine;
HandshakeStatus hss; HandshakeStatus hss;
@@ -885,64 +893,21 @@ public abstract class AsyncConnection implements Channel, AutoCloseable {
return appBuffer; return appBuffer;
} }
protected void sslReadImpl(final boolean handshake, final CompletionHandler<Integer, ByteBuffer> handler) { protected void sslReadImpl(boolean handshake, CompletionHandler<Integer, ByteBuffer> handler) {
readImpl(createSslCompletionHandler(handshake, handler)); readImpl(createSslCompletionHandler(handshake, handler));
} }
protected void sslReadRegisterImpl(final boolean handshake, final CompletionHandler<Integer, ByteBuffer> handler) { protected void sslReadRegisterImpl(boolean handshake, CompletionHandler<Integer, ByteBuffer> handler) {
readRegisterImpl(createSslCompletionHandler(handshake, handler)); readRegisterImpl(createSslCompletionHandler(handshake, handler));
} }
private CompletionHandler<Integer, ByteBuffer> createSslCompletionHandler( private CompletionHandler<Integer, ByteBuffer> createSslCompletionHandler(
final boolean handshake, final CompletionHandler<Integer, ByteBuffer> handler) { boolean handshake, CompletionHandler<Integer, ByteBuffer> handler) {
return new CompletionHandler<Integer, ByteBuffer>() { return new SslReadCompletionHandler(handshake, handler);
@Override
public void completed(Integer count, ByteBuffer attachment) {
// System.out.println(AsyncConnection.this + " 进来了读到的字节数: " + count);
if (count < 0) {
handler.completed(count, attachment);
return;
}
ByteBuffer netBuffer = attachment;
netBuffer.flip();
try {
ByteBuffer appBuffer = sslUnwrap(handshake, netBuffer);
if (appBuffer == null) {
return; // CLOSEDnetBuffer已被回收
}
if (AsyncConnection.this.readSSLHalfBuffer != netBuffer) {
offerReadBuffer(netBuffer);
}
if (AsyncConnection.this.readBuffer != null) {
ByteBuffer rsBuffer = AsyncConnection.this.readBuffer;
AsyncConnection.this.readBuffer = null;
appBuffer.flip();
if (rsBuffer.remaining() >= appBuffer.remaining()) {
rsBuffer.put(appBuffer);
offerReadBuffer(appBuffer);
appBuffer = rsBuffer;
} else {
while (rsBuffer.hasRemaining()) rsBuffer.put(appBuffer.get());
AsyncConnection.this.readBuffer = appBuffer.compact();
appBuffer = rsBuffer;
}
}
handler.completed(count, appBuffer);
} catch (SSLException e) {
failed(e, attachment);
}
}
@Override
public void failed(Throwable t, ByteBuffer attachment) {
handler.failed(t, attachment);
}
};
} }
// 加密ssl内容数据 // 加密ssl内容数据
protected ByteBuffer[] sslWrap(final boolean handshake, ByteBuffer appBuffer) throws SSLException { protected ByteBuffer[] sslWrap(boolean handshake, ByteBuffer appBuffer) throws SSLException {
final SSLEngine engine = this.sslEngine; final SSLEngine engine = this.sslEngine;
final int netSize = engine.getSession().getPacketBufferSize(); final int netSize = engine.getSession().getPacketBufferSize();
ByteBuffer netBuffer = pollWriteBuffer(); ByteBuffer netBuffer = pollWriteBuffer();
@@ -980,7 +945,7 @@ public abstract class AsyncConnection implements Channel, AutoCloseable {
} }
// 加密ssl内容数据 // 加密ssl内容数据
protected ByteBuffer[] sslWrap(final boolean handshake, ByteBuffer[] appBuffers, int offset, int length) protected ByteBuffer[] sslWrap(boolean handshake, ByteBuffer[] appBuffers, int offset, int length)
throws SSLException { throws SSLException {
final SSLEngine engine = this.sslEngine; final SSLEngine engine = this.sslEngine;
final int netSize = engine.getSession().getPacketBufferSize(); final int netSize = engine.getSession().getPacketBufferSize();
@@ -1016,38 +981,14 @@ public abstract class AsyncConnection implements Channel, AutoCloseable {
return netBuffers; return netBuffers;
} }
protected boolean sslWriteImpl(final boolean handshake, ByteBuffer appBuffer, final Consumer<Throwable> callback) protected boolean sslWriteImpl(boolean handshake, ByteBuffer appBuffer, CompletionHandler<Integer, Void> handler)
throws SSLException { throws SSLException {
ByteBuffer[] netBuffers = sslWrap(handshake, appBuffer); ByteBuffer[] netBuffers = sslWrap(handshake, appBuffer);
if (netBuffers.length > 0) { if (netBuffers.length > 0) {
if (netBuffers.length == 1) { if (netBuffers.length == 1) {
writeImpl(netBuffers[0], null, new CompletionHandler<Integer, Void>() { writeImpl(netBuffers[0], writeBufferConsumer, null, handler);
@Override
public void completed(Integer count, Void attachment) {
offerWriteBuffer(netBuffers[0]);
callback.accept(null);
}
@Override
public void failed(Throwable t, Void attachment) {
offerWriteBuffer(netBuffers[0]);
callback.accept(t);
}
});
} else { } else {
writeImpl(netBuffers, 0, netBuffers.length, null, new CompletionHandler<Integer, Void>() { writeImpl(netBuffers, 0, netBuffers.length, writeBufferConsumer, null, handler);
@Override
public void completed(Integer count, Void attachment) {
offerWriteBuffers(netBuffers);
callback.accept(null);
}
@Override
public void failed(Throwable t, Void attachment) {
offerWriteBuffers(netBuffers);
callback.accept(t);
}
});
} }
return true; return true;
} else { } else {
@@ -1057,42 +998,18 @@ public abstract class AsyncConnection implements Channel, AutoCloseable {
} }
protected boolean sslWriteImpl( protected boolean sslWriteImpl(
final boolean handshake, boolean handshake,
ByteBuffer[] appBuffers, ByteBuffer[] appBuffers,
int offset, int offset,
int length, int length,
final Consumer<Throwable> callback) CompletionHandler<Integer, Void> handler)
throws SSLException { throws SSLException {
ByteBuffer[] netBuffers = sslWrap(handshake, appBuffers, offset, length); ByteBuffer[] netBuffers = sslWrap(handshake, appBuffers, offset, length);
if (netBuffers.length > 0) { if (netBuffers.length > 0) {
if (netBuffers.length == 1) { if (netBuffers.length == 1) {
writeImpl(netBuffers[0], null, new CompletionHandler<Integer, Void>() { writeImpl(netBuffers[0], writeBufferConsumer, null, handler);
@Override
public void completed(Integer count, Void attachment) {
offerWriteBuffer(netBuffers[0]);
callback.accept(null);
}
@Override
public void failed(Throwable t, Void attachment) {
offerWriteBuffer(netBuffers[0]);
callback.accept(t);
}
});
} else { } else {
writeImpl(netBuffers, 0, netBuffers.length, null, new CompletionHandler<Integer, Void>() { writeImpl(netBuffers, 0, netBuffers.length, writeBufferConsumer, null, handler);
@Override
public void completed(Integer count, Void attachment) {
offerWriteBuffers(netBuffers);
callback.accept(null);
}
@Override
public void failed(Throwable t, Void attachment) {
offerWriteBuffers(netBuffers);
callback.accept(t);
}
});
} }
return true; return true;
} else { } else {
@@ -1101,7 +1018,7 @@ public abstract class AsyncConnection implements Channel, AutoCloseable {
} }
} }
private void doHandshake(final Consumer<Throwable> callback) { private void doHandshake(CompletionHandler<Integer, Void> handler) {
HandshakeStatus handshakeStatus; HandshakeStatus handshakeStatus;
final SSLEngine engine = this.sslEngine; final SSLEngine engine = this.sslEngine;
while ((handshakeStatus = engine.getHandshakeStatus()) != null) { while ((handshakeStatus = engine.getHandshakeStatus()) != null) {
@@ -1110,7 +1027,7 @@ public abstract class AsyncConnection implements Channel, AutoCloseable {
case FINISHED: case FINISHED:
case NOT_HANDSHAKING: case NOT_HANDSHAKING:
// System.out.println(AsyncConnection.this + " doHandshakde完毕开始进入读写操作-----"); // System.out.println(AsyncConnection.this + " doHandshakde完毕开始进入读写操作-----");
callback.accept(null); handler.completed(0, null);
return; return;
case NEED_TASK: { case NEED_TASK: {
Runnable task; Runnable task;
@@ -1121,18 +1038,22 @@ public abstract class AsyncConnection implements Channel, AutoCloseable {
} }
case NEED_WRAP: { case NEED_WRAP: {
try { // try { //
boolean rs = sslWriteImpl(true, EMPTY_BUFFER, t -> { boolean rs = sslWriteImpl(true, EMPTY_BUFFER, new CompletionHandler<Integer, Void>() {
if (t == null) { @Override
doHandshake(callback); public void completed(Integer result, Void attachment) {
} else { doHandshake(handler);
callback.accept(t); }
@Override
public void failed(Throwable exc, Void attachment) {
handler.failed(exc, attachment);
} }
}); });
if (rs) { if (rs) {
return; return;
} }
} catch (SSLException e) { } catch (SSLException e) {
callback.accept(e); handler.failed(e, null);
return; return;
} }
break; break;
@@ -1142,16 +1063,19 @@ public abstract class AsyncConnection implements Channel, AutoCloseable {
@Override @Override
public void completed(Integer count, ByteBuffer attachment) { public void completed(Integer count, ByteBuffer attachment) {
if (count < 1) { if (count < 1) {
callback.accept(new IOException("read data error")); handler.failed(new IOException("read data error"), null);
} else { } else {
offerReadBuffer(attachment); offerReadBuffer(attachment);
doHandshake(callback); doHandshake(handler);
} }
} }
@Override @Override
public void failed(Throwable t, ByteBuffer attachment) { public void failed(Throwable t, ByteBuffer attachment) {
callback.accept(t); if (attachment != null) {
offerReadBuffer(attachment);
}
handler.failed(t, null);
} }
}); });
return; return;
@@ -1176,4 +1100,59 @@ public abstract class AsyncConnection implements Channel, AutoCloseable {
for (int i = 0; i < cha; i++) s += ' '; for (int i = 0; i < cha; i++) s += ' ';
return s; return s;
} }
protected class SslReadCompletionHandler implements CompletionHandler<Integer, ByteBuffer> {
private boolean handshake;
private CompletionHandler<Integer, ByteBuffer> handler;
public SslReadCompletionHandler(boolean handshake, CompletionHandler<Integer, ByteBuffer> handler) {
this.handshake = handshake;
this.handler = handler;
}
@Override
public void completed(Integer count, ByteBuffer attachment) {
// System.out.println(AsyncConnection.this + " 进来了读到的字节数: " + count);
if (count < 0) {
handler.completed(count, attachment);
return;
}
ByteBuffer netBuffer = attachment;
netBuffer.flip();
try {
ByteBuffer appBuffer = sslUnwrap(handshake, netBuffer);
if (appBuffer == null) {
failed(new SSLException("appBuffer is null"), attachment);
return; // CLOSEDnetBuffer已被回收
}
if (readSSLHalfBuffer != netBuffer) { // unwap完整
offerReadBuffer(netBuffer);
}
if (readBuffer != null) {
ByteBuffer rsBuffer = readBuffer;
readBuffer = null;
appBuffer.flip();
if (rsBuffer.remaining() >= appBuffer.remaining()) {
rsBuffer.put(appBuffer);
offerReadBuffer(appBuffer);
appBuffer = rsBuffer;
} else {
while (rsBuffer.hasRemaining()) rsBuffer.put(appBuffer.get());
readBuffer = appBuffer.compact();
appBuffer = rsBuffer;
}
}
handler.completed(count, appBuffer);
} catch (SSLException e) {
failed(e, attachment);
}
}
@Override
public void failed(Throwable t, ByteBuffer attachment) {
handler.failed(t, attachment);
}
}
} }

View File

@@ -267,10 +267,14 @@ public class AsyncIOGroup extends AsyncGroup {
if (conn.sslEngine == null) { if (conn.sslEngine == null) {
future.complete(conn); future.complete(conn);
} else { } else {
conn.startHandshake(t -> { conn.startHandshake(new CompletionHandler<Integer, Void>() {
if (t == null) { @Override
public void completed(Integer result, Void attachment) {
future.complete(conn); future.complete(conn);
} else { }
@Override
public void failed(Throwable t, Void attachment) {
future.completeExceptionally(t); future.completeExceptionally(t);
} }
}); });
@@ -344,10 +348,14 @@ public class AsyncIOGroup extends AsyncGroup {
if (conn.sslEngine == null) { if (conn.sslEngine == null) {
future.complete(conn); future.complete(conn);
} else { } else {
conn.startHandshake(t -> { conn.startHandshake(new CompletionHandler<Integer, Void>() {
if (t == null) { @Override
public void completed(Integer result, Void attachment) {
future.complete(conn); future.complete(conn);
} else { }
@Override
public void failed(Throwable t, Void attachment) {
future.completeExceptionally(t); future.completeExceptionally(t);
} }
}); });

View File

@@ -85,8 +85,8 @@ abstract class AsyncNioConnection extends AsyncConnection {
} }
@Override @Override
protected void startHandshake(final Consumer<Throwable> callback) { protected void startHandshake(CompletionHandler<Integer, Void> handler) {
ioReadThread.register(t -> super.startHandshake(callback)); ioReadThread.register(t -> super.startHandshake(handler));
} }
@Override @Override
@@ -102,7 +102,7 @@ abstract class AsyncNioConnection extends AsyncConnection {
return; return;
} }
if (handler != readCompletionHandler) { // 如果是Codec无需重复赋值 if (handler != readCompletionHandler) { // 如果是Codec无需重复赋值
if (this.readPending) { if (this.readPending && handler.getClass() != SslReadCompletionHandler.class) {
handler.failed(new ReadPendingException(), null); handler.failed(new ReadPendingException(), null);
return; return;
} }

View File

@@ -197,14 +197,20 @@ class AsyncNioTcpProtocolServer extends ProtocolServer {
if (conn.sslEngine == null) { if (conn.sslEngine == null) {
codec.start(null); codec.start(null);
} else { } else {
conn.startHandshake(t -> { conn.startHandshake(new CompletionHandler<Integer, Void>() {
if (t == null) { @Override
public void completed(Integer result, Void attachment) {
codec.start(null); codec.start(null);
} else if (t instanceof RuntimeException) { }
@Override
public void failed(Throwable t, Void attachment) {
if (t instanceof RuntimeException) {
throw (RuntimeException) t; throw (RuntimeException) t;
} else { } else {
throw new RedkaleException(t); throw new RedkaleException(t);
} }
}
}); });
} }
} }

View File

@@ -209,14 +209,20 @@ class AsyncNioUdpProtocolServer extends ProtocolServer {
if (conn.sslEngine == null) { if (conn.sslEngine == null) {
codec.start(buffer); codec.start(buffer);
} else { } else {
conn.startHandshake(t -> { conn.startHandshake(new CompletionHandler<Integer, Void>() {
if (t == null) { @Override
public void completed(Integer result, Void attachment) {
codec.start(buffer); codec.start(buffer);
} else if (t instanceof RuntimeException) { }
@Override
public void failed(Throwable t, Void attachment) {
if (t instanceof RuntimeException) {
throw (RuntimeException) t; throw (RuntimeException) t;
} else { } else {
throw new RedkaleException(t); throw new RedkaleException(t);
} }
}
}); });
} }
} }