diff --git a/src/main/java/org/redkale/source/DataJdbcSource.java b/src/main/java/org/redkale/source/DataJdbcSource.java index 61ca2cc20..00ff53354 100644 --- a/src/main/java/org/redkale/source/DataJdbcSource.java +++ b/src/main/java/org/redkale/source/DataJdbcSource.java @@ -134,29 +134,55 @@ public class DataJdbcSource extends DataSqlSource { try { int c = 0; conn = writePool.pollConnection(); - final String sql = info.getInsertQuestionPrepareSQL(entitys[0]); Attribute[] attrs = info.insertAttributes; conn.setReadOnly(false); - conn.setAutoCommit(true); - PreparedStatement prestmt = createInsertPreparedStatement(conn, sql, info, entitys); + conn.setAutoCommit(false); + + String presql = null; + PreparedStatement prestmt = null; + + List prestmts = null; + Map> prepareInfos = null; + + if (info.getTableStrategy() == null) { + presql = info.getInsertQuestionPrepareSQL(entitys[0]); + prestmt = createInsertPreparedStatement(conn, presql, info, entitys); + } else { + prepareInfos = getInsertQuestionPrepareInfo(info, entitys); + prestmts = createInsertPreparedStatements(conn, info, prepareInfos, entitys); + } try { - int[] cs = prestmt.executeBatch(); - int c1 = 0; - for (int cc : cs) { - c1 += cc; + if (info.getTableStrategy() == null) { + int c1 = 0; + int[] cs = prestmt.executeBatch(); + for (int cc : cs) { + c1 += cc; + } + c = c1; + } else { + int c1 = 0; + for (PreparedStatement stmt : prestmts) { + int[] cs = stmt.executeBatch(); + for (int cc : cs) { + c1 += cc; + } + } + c = c1; } - c = c1; + conn.commit(); } catch (SQLException se) { + conn.rollback(); if (!isTableNotExist(info, se.getSQLState())) throw se; if (info.getTableStrategy() == null) { //单表 - String[] tablesqls = createTableSqls(info); - if (tablesqls == null) throw se; + String[] tableSqls = createTableSqls(info); + if (tableSqls == null) throw se; + //创建单表结构 Statement st = conn.createStatement(); - if (tablesqls.length == 1) { - st.execute(tablesqls[0]); + if (tableSqls.length == 1) { + st.execute(tableSqls[0]); } else { - for (String tablesql : tablesqls) { - st.addBatch(tablesql); + for (String tableSql : tableSqls) { + st.addBatch(tableSql); } st.executeBatch(); } @@ -164,74 +190,91 @@ public class DataJdbcSource extends DataSqlSource { } else { //分库分表 synchronized (info.disTableLock()) { final String catalog = conn.getCatalog(); - final String newTable = info.getTable(entitys[0]); - final String tableKey = newTable.indexOf('.') > 0 ? newTable : (catalog + '.' + newTable); - if (!info.containsDisTable(tableKey)) { - try { - //执行一遍复制表操作 - Statement st = conn.createStatement(); - st.execute(getTableCopySQL(info, newTable)); - st.close(); - info.addDisTable(tableKey); - } catch (SQLException sqle) { //多进程并发时可能会出现重复建表 - if (isTableNotExist(info, sqle.getSQLState())) { - if (newTable.indexOf('.') < 0) { //分表的原始表不存在 - String[] tablesqls = createTableSqls(info); - if (tablesqls != null) { //创建原始表 - Statement st = conn.createStatement(); - if (tablesqls.length == 1) { - st.execute(tablesqls[0]); - } else { - for (String tablesql : tablesqls) { - st.addBatch(tablesql); + final Set newCatalogs = new LinkedHashSet<>(); + final List tableCopys = new ArrayList<>(); + prepareInfos.forEach((t, p) -> { + int pos = t.indexOf('.'); + if (pos > 0) { + newCatalogs.add(t.substring(0, pos)); + } + tableCopys.add(getTableCopySQL(info, t)); + }); + try { + //执行一遍创建分表操作 + Statement st = conn.createStatement(); + for (String copySql : tableCopys) { + st.addBatch(copySql); + } + st.executeBatch(); + st.close(); + } catch (SQLException sqle) { //多进程并发时可能会出现重复建表 + if (isTableNotExist(info, sqle.getSQLState())) { + if (newCatalogs.isEmpty()) { //分表的原始表不存在 + String[] tableSqls = createTableSqls(info); + if (tableSqls != null) { + //创建原始表 + Statement st = conn.createStatement(); + if (tableSqls.length == 1) { + st.execute(tableSqls[0]); + } else { + for (String tableSql : tableSqls) { + st.addBatch(tableSql); + } + st.executeBatch(); + } + st.close(); + //再执行一遍创建分表操作 + st = conn.createStatement(); + for (String copySql : tableCopys) { + st.addBatch(copySql); + } + st.executeBatch(); + st.close(); + } + } else { //需要先建库 + Statement st; + try { + st = conn.createStatement(); + for (String newCatalog : newCatalogs) { + st.addBatch(("postgresql".equals(dbtype()) ? "CREATE SCHEMA IF NOT EXISTS " : "CREATE DATABASE IF NOT EXISTS ") + newCatalog); + } + st.executeBatch(); + st.close(); + } catch (SQLException sqle1) { + logger.log(Level.SEVERE, "create database " + tableCopys + " error", sqle1); + } + try { + //再执行一遍创建分表操作 + st = conn.createStatement(); + for (String copySql : tableCopys) { + st.addBatch(copySql); + } + st.executeBatch(); + st.close(); + } catch (SQLException sqle2) { + if (isTableNotExist(info, sqle2.getSQLState())) { + String[] tablesqls = createTableSqls(info); + if (tablesqls != null) { //创建原始表 + st = conn.createStatement(); + if (tablesqls.length == 1) { + st.execute(tablesqls[0]); + } else { + for (String tableSql : tablesqls) { + st.addBatch(tableSql); + } + st.executeBatch(); + } + st.close(); + //再执行一遍创建分表操作 + st = conn.createStatement(); + for (String copySql : tableCopys) { + st.addBatch(copySql); } st.executeBatch(); + st.close(); } - st.close(); - //再执行一遍创建分表操作 - st = conn.createStatement(); - st.execute(getTableCopySQL(info, newTable)); - st.close(); - info.addDisTable(tableKey); - } - } else { //需要先建库 - Statement st; - try { - st = conn.createStatement(); - st.execute(("postgresql".equals(dbtype()) ? "CREATE SCHEMA IF NOT EXISTS " : "CREATE DATABASE IF NOT EXISTS ") + newTable.substring(0, newTable.indexOf('.'))); - st.close(); - } catch (SQLException sqle1) { - logger.log(Level.SEVERE, "create database(" + newTable.substring(0, newTable.indexOf('.')) + ") error", sqle1); - } - try { - //再执行一遍创建分表操作 - st = conn.createStatement(); - st.execute(getTableCopySQL(info, newTable)); - st.close(); - info.addDisTable(tableKey); - } catch (SQLException sqle2) { - if (isTableNotExist(info, sqle2.getSQLState())) { - String[] tablesqls = createTableSqls(info); - if (tablesqls != null) { //创建原始表 - st = conn.createStatement(); - if (tablesqls.length == 1) { - st.execute(tablesqls[0]); - } else { - for (String tablesql : tablesqls) { - st.addBatch(tablesql); - } - st.executeBatch(); - } - st.close(); - //再执行一遍创建分表操作 - st = conn.createStatement(); - st.execute(getTableCopySQL(info, newTable)); - st.close(); - info.addDisTable(tableKey); - } - } else { - logger.log(Level.SEVERE, "create table2(" + getTableCopySQL(info, newTable) + ") error", sqle2); - } + } else { + logger.log(Level.SEVERE, "create table2 " + tableCopys + " error", sqle2); } } } @@ -239,50 +282,120 @@ public class DataJdbcSource extends DataSqlSource { } } } - prestmt.close(); - prestmt = createInsertPreparedStatement(conn, sql, info, entitys); - int[] cs = prestmt.executeBatch(); - int c1 = 0; - for (int cc : cs) { - c1 += cc; - } - c = c1; - } - prestmt.close(); - //------------------------------------------------------------ - if (info.isLoggable(logger, Level.FINEST)) { //打印调试信息 - char[] sqlchars = sql.toCharArray(); - for (final T value : entitys) { - //----------------------------- - StringBuilder sb = new StringBuilder(128); - int i = 0; - for (char ch : sqlchars) { - if (ch == '?') { - Object obj = info.getSQLValue(attrs[i++], value); - if (obj != null && obj.getClass().isArray()) { - sb.append("'[length=").append(java.lang.reflect.Array.getLength(obj)).append("]'"); - } else { - sb.append(info.formatSQLValue(obj, sqlFormatter)); - } - } else { - sb.append(ch); + if (info.getTableStrategy() == null) { + prestmt.close(); + prestmt = createInsertPreparedStatement(conn, presql, info, entitys); + int c1 = 0; + int[] cs = prestmt.executeBatch(); + for (int cc : cs) { + c1 += cc; + } + c = c1; + prestmt.close(); + } else { + for (PreparedStatement stmt : prestmts) { + stmt.close(); + } + prestmts = createInsertPreparedStatements(conn, info, prepareInfos, entitys); + int c1 = 0; + for (PreparedStatement stmt : prestmts) { + int[] cs = stmt.executeBatch(); + for (int cc : cs) { + c1 += cc; } } - String debugsql = sb.toString(); - if (info.isLoggable(logger, Level.FINEST, debugsql)) logger.finest(info.getType().getSimpleName() + " insert sql=" + debugsql.replaceAll("(\r|\n)", "\\n")); + c = c1; + for (PreparedStatement stmt : prestmts) { + stmt.close(); + } } - } //打印结束 - slowLog(s, sql); + conn.commit(); + } + //------------------------------------------------------------ + if (info.isLoggable(logger, Level.FINEST)) { //打印调试信息 + if (info.getTableStrategy() == null) { + char[] sqlchars = presql.toCharArray(); + for (final T value : entitys) { + //----------------------------- + StringBuilder sb = new StringBuilder(128); + int i = 0; + for (char ch : sqlchars) { + if (ch == '?') { + Object obj = info.getSQLValue(attrs[i++], value); + if (obj != null && obj.getClass().isArray()) { + sb.append("'[length=").append(java.lang.reflect.Array.getLength(obj)).append("]'"); + } else { + sb.append(info.formatSQLValue(obj, sqlFormatter)); + } + } else { + sb.append(ch); + } + } + String debugsql = sb.toString(); + if (info.isLoggable(logger, Level.FINEST, debugsql)) logger.finest(info.getType().getSimpleName() + " insert sql=" + debugsql.replaceAll("(\r|\n)", "\\n")); + } + } else { + prepareInfos.forEach((t, p) -> { + char[] sqlchars = p.prepareSql.toCharArray(); + for (final T value : p.entitys) { + //----------------------------- + StringBuilder sb = new StringBuilder(128); + int i = 0; + for (char ch : sqlchars) { + if (ch == '?') { + Object obj = info.getSQLValue(attrs[i++], value); + if (obj != null && obj.getClass().isArray()) { + sb.append("'[length=").append(java.lang.reflect.Array.getLength(obj)).append("]'"); + } else { + sb.append(info.formatSQLValue(obj, sqlFormatter)); + } + } else { + sb.append(ch); + } + } + String debugsql = sb.toString(); + if (info.isLoggable(logger, Level.FINEST, debugsql)) logger.finest(info.getType().getSimpleName() + " insert sql=" + debugsql.replaceAll("(\r|\n)", "\\n")); + } + }); + } + } //打印结束 + if (info.getTableStrategy() == null) { + slowLog(s, presql); + } else { + List presqls = new ArrayList<>(); + prepareInfos.forEach((t, p) -> { + presqls.add(p.prepareSql); + }); + slowLog(s, presqls.toArray(new String[presqls.size()])); + } return CompletableFuture.completedFuture(c); } catch (SQLException e) { + try { + if (conn != null) conn.rollback(); + } catch (SQLException se) { + } return CompletableFuture.failedFuture(e); } finally { if (conn != null) writePool.offerConnection(conn); } } - protected PreparedStatement createInsertPreparedStatement(final Connection conn, final String sql, - final EntityInfo info, T... entitys) throws SQLException { + protected List createInsertPreparedStatements(final Connection conn, EntityInfo info, Map> prepareInfos, T... entitys) throws SQLException { + Attribute[] attrs = info.insertAttributes; + final List prestmts = new ArrayList<>(); + for (Map.Entry> en : prepareInfos.entrySet()) { + PrepareInfo prepareInfo = en.getValue(); + PreparedStatement prestmt = conn.prepareStatement(prepareInfo.prepareSql); + for (final T value : prepareInfo.entitys) { + batchStatementParameters(conn, prestmt, info, attrs, value); + prestmt.addBatch(); + } + prestmts.add(prestmt); + } + return prestmts; + } + + protected PreparedStatement createInsertPreparedStatement(Connection conn, String sql, EntityInfo info, T... entitys) throws SQLException { Attribute[] attrs = info.insertAttributes; final PreparedStatement prestmt = conn.prepareStatement(sql); diff --git a/src/main/java/org/redkale/source/DataSqlSource.java b/src/main/java/org/redkale/source/DataSqlSource.java index af7d044c3..02b586fa4 100644 --- a/src/main/java/org/redkale/source/DataSqlSource.java +++ b/src/main/java/org/redkale/source/DataSqlSource.java @@ -86,7 +86,7 @@ public abstract class DataSqlSource extends AbstractDataSource implements Functi //用于判断表不存在的使用, 多个SQLState用;隔开 protected String tableNotExistSqlstates; - //用于复制表结构使用 + //用于复制表结构使用, sql语句必须包含IF NOT EXISTS判断,确保重复执行不会报错 protected String tablecopySQL; protected AnyValue config; @@ -513,6 +513,46 @@ public abstract class DataSqlSource extends AbstractDataSource implements Functi return val; } + @Local + protected Map> getInsertQuestionPrepareInfo(EntityInfo info, T... entitys) { + Map> map = new LinkedHashMap<>(); + for (T entity : entitys) { + String table = info.getTable(entity); + map.computeIfAbsent(table, t -> new PrepareInfo(info.getInsertQuestionPrepareSQL(entity))).addEntity(entity); + } + return map; + } + + @Local + protected Map> getInsertDollarPrepareInfo(EntityInfo info, T... entitys) { + Map> map = new LinkedHashMap<>(); + for (T entity : entitys) { + String table = info.getTable(entity); + map.computeIfAbsent(table, t -> new PrepareInfo(info.getInsertDollarPrepareSQL(entity))).addEntity(entity); + } + return map; + } + + @Local + protected Map> getUpdateQuestionPrepareInfo(EntityInfo info, T... entitys) { + Map> map = new LinkedHashMap<>(); + for (T entity : entitys) { + String table = info.getTable(entity); + map.computeIfAbsent(table, t -> new PrepareInfo(info.getUpdateQuestionPrepareSQL(entity))).addEntity(entity); + } + return map; + } + + @Local + protected Map> getUpdateDollarPrepareInfo(EntityInfo info, T... entitys) { + Map> map = new LinkedHashMap<>(); + for (T entity : entitys) { + String table = info.getTable(entity); + map.computeIfAbsent(table, t -> new PrepareInfo(info.getUpdateDollarPrepareSQL(entity))).addEntity(entity); + } + return map; + } + @Local protected Serializable getEntityAttrValue(EntityInfo info, Attribute attr, T entity) { Serializable val = info.getSQLValue(attr, entity); @@ -2574,4 +2614,22 @@ public abstract class DataSqlSource extends AbstractDataSource implements Functi } } + + protected static class PrepareInfo { + + public String prepareSql; + + public List entitys; + + public PrepareInfo(String prepareSql) { + this.prepareSql = prepareSql; + } + + public void addEntity(T entity) { + if (entitys == null) { + entitys = new ArrayList<>(); + } + entitys.add(entity); + } + } }