diff --git a/src/org/redkale/net/PrepareServlet.java b/src/org/redkale/net/PrepareServlet.java index 4eea2af3a..0e33d414d 100644 --- a/src/org/redkale/net/PrepareServlet.java +++ b/src/org/redkale/net/PrepareServlet.java @@ -50,10 +50,10 @@ public abstract class PrepareServlet newmappings = new HashMap<>(mappings); - newmappings.put(key, value); + newmappings.put(key, servlet); this.mappings = newmappings; } } diff --git a/src/org/redkale/net/http/HttpPrepareServlet.java b/src/org/redkale/net/http/HttpPrepareServlet.java index b8f544d55..04a3733dd 100644 --- a/src/org/redkale/net/http/HttpPrepareServlet.java +++ b/src/org/redkale/net/http/HttpPrepareServlet.java @@ -29,7 +29,11 @@ public class HttpPrepareServlet extends PrepareServlet, HttpServlet>[] regArray = new SimpleEntry[0]; + protected SimpleEntry, HttpServlet>[] regArray = null; //regArray 包含 regWsArray + + protected Map wsmappings = new HashMap<>(); //super.mappings 包含 wsmappings + + protected SimpleEntry, WebSocketServlet>[] regWsArray = null; protected HttpServlet resourceHttpServlet = new HttpResourceServlet(); @@ -83,17 +87,34 @@ public class HttpPrepareServlet extends PrepareServlet servlet = mappingServlet(uri); - if (servlet == null && this.regArray != null) { - for (SimpleEntry, HttpServlet> en : regArray) { - if (en.getKey().test(uri)) { - servlet = en.getValue(); - break; + Servlet servlet = null; + if (request.isWebSocket()) { + servlet = wsmappings.get(uri); + if (servlet == null && this.regWsArray != null) { + for (SimpleEntry, WebSocketServlet> en : regWsArray) { + if (en.getKey().test(uri)) { + servlet = en.getValue(); + break; + } } } + if (servlet == null) { + response.finish(500, null); + return; + } + } else { + servlet = mappingServlet(uri); + if (servlet == null && this.regArray != null) { + for (SimpleEntry, HttpServlet> en : regArray) { + if (en.getKey().test(uri)) { + servlet = en.getValue(); + break; + } + } + } + //找不到匹配的HttpServlet则使用静态资源HttpResourceServlet + if (servlet == null) servlet = this.resourceHttpServlet; } - //找不到匹配的HttpServlet则使用静态资源HttpResourceServlet - if (servlet == null) servlet = this.resourceHttpServlet; servlet.execute(request, response); } catch (Exception e) { request.getContext().getLogger().log(Level.WARNING, "Servlet occur, forece to close channel. request = " + request, e); @@ -138,8 +159,22 @@ public class HttpPrepareServlet extends PrepareServlet(Pattern.compile(mapping).asPredicate(), servlet); } + if (servlet instanceof WebSocketServlet) { + if (regWsArray == null) { + regWsArray = new SimpleEntry[1]; + regWsArray[0] = new SimpleEntry<>(Pattern.compile(mapping).asPredicate(), (WebSocketServlet) servlet); + } else { + regWsArray = Arrays.copyOf(regWsArray, regWsArray.length + 1); + regWsArray[regWsArray.length - 1] = new SimpleEntry<>(Pattern.compile(mapping).asPredicate(), (WebSocketServlet) servlet); + } + } } else if (mapping != null && !mapping.isEmpty()) { putMapping(mapping, servlet); + if (servlet instanceof WebSocketServlet) { + Map newmappings = new HashMap<>(wsmappings); + newmappings.put(mapping, (WebSocketServlet) servlet); + this.wsmappings = newmappings; + } } if (this.allMapStrings.containsKey(mapping)) { Class old = this.allMapStrings.get(mapping); @@ -185,6 +220,9 @@ public class HttpPrepareServlet extends PrepareServlet { this.remoteAddrHeader = remoteAddrHeader; } + protected boolean isWebSocket() { + return connection != null && connection.contains("Upgrade") && "GET".equalsIgnoreCase(method) && "websocket".equalsIgnoreCase(getHeader("Upgrade")); + } + protected void setKeepAlive(boolean keepAlive) { this.keepAlive = keepAlive; } diff --git a/src/org/redkale/net/http/WebSocketServlet.java b/src/org/redkale/net/http/WebSocketServlet.java index dd6228455..19dfc0eef 100644 --- a/src/org/redkale/net/http/WebSocketServlet.java +++ b/src/org/redkale/net/http/WebSocketServlet.java @@ -98,9 +98,7 @@ public abstract class WebSocketServlet extends HttpServlet implements Resourcabl @Override public final void execute(final HttpRequest request, final HttpResponse response) throws IOException { final boolean debug = logger.isLoggable(Level.FINEST); - if (!"GET".equalsIgnoreCase(request.getMethod()) - || !request.getConnection().contains("Upgrade") - || !"websocket".equalsIgnoreCase(request.getHeader("Upgrade"))) { + if (!request.isWebSocket()) { if (debug) logger.finest("WebSocket connect abort, (Not GET Method) or (Connection != Upgrade) or (Upgrade != websocket). request=" + request); response.finish(true); return;