DataSource增加字段加解密功能,主类:CryptColumn/CryptHandler

This commit is contained in:
Redkale
2019-01-11 16:41:14 +08:00
parent 623c0a127e
commit 91d4477ed9
6 changed files with 165 additions and 31 deletions

View File

@@ -0,0 +1,29 @@
/*
* To change this license header, choose License Headers in Project Properties.
* To change this template file, choose Tools | Templates
* and open the template in the editor.
*/
package org.redkale.source;
import static java.lang.annotation.ElementType.FIELD;
import static java.lang.annotation.RetentionPolicy.RUNTIME;
import java.lang.annotation.*;
/**
* 加密字段标记 <br>
* 注意: 加密字段不能用于 LIKE 等过滤查询
*
* <p>
* 详情见: https://redkale.org
*
* @author zhangjx
* @since 2.0.0
*/
@Inherited
@Documented
@Target({FIELD})
@Retention(RUNTIME)
public @interface CryptColumn {
Class<? extends CryptHandler> handler();
}

View File

@@ -0,0 +1,38 @@
/*
* To change this license header, choose License Headers in Project Properties.
* To change this template file, choose Tools | Templates
* and open the template in the editor.
*/
package org.redkale.source;
/**
* 字段加密解密接口
*
* <p>
* 详情见: https://redkale.org
*
* @author zhangjx
* @since 2.0.0
* @param <S> 加密的字段类型
* @param <D> 加密后的数据类型
*/
public interface CryptHandler<S, D> {
/**
* 加密
*
* @param value 加密前的字段值
*
* @return 加密后的字段值
*/
public D encrypt(S value);
/**
* 解密
*
* @param value 加密的字段值
*
* @return 解密后的字段值
*/
public S decrypt(D value);
}

View File

@@ -137,7 +137,7 @@ public class DataJdbcSource extends DataSqlSource<Connection> {
int i = 0;
for (char ch : sqlchars) {
if (ch == '?') {
Object obj = attrs[i++].get(value);
Object obj = info.getSQLValue(attrs[i++], value);
if (obj != null && obj.getClass().isArray()) {
sb.append("'[length=").append(java.lang.reflect.Array.getLength(obj)).append("]'");
} else {
@@ -174,10 +174,10 @@ public class DataJdbcSource extends DataSqlSource<Connection> {
return prestmt;
}
protected <T> int batchStatementParameters(Connection conn, PreparedStatement prestmt, EntityInfo<T> info, Attribute<T, Serializable>[] attrs, T value) throws SQLException {
protected <T> int batchStatementParameters(Connection conn, PreparedStatement prestmt, EntityInfo<T> info, Attribute<T, Serializable>[] attrs, T entity) throws SQLException {
int i = 0;
for (Attribute<T, Serializable> attr : attrs) {
Serializable val = attr.get(value);
Serializable val = info.getSQLValue(attr, entity);
if (val instanceof byte[]) {
Blob blob = conn.createBlob();
blob.setBytes(1, (byte[]) val);
@@ -186,7 +186,7 @@ public class DataJdbcSource extends DataSqlSource<Connection> {
prestmt.setObject(++i, ((AtomicInteger) val).get());
} else if (val instanceof AtomicLong) {
prestmt.setObject(++i, ((AtomicLong) val).get());
} else if (val != null && !(val instanceof Number) && !(val instanceof CharSequence) && !(value instanceof java.util.Date)
} else if (val != null && !(val instanceof Number) && !(val instanceof CharSequence) && !(entity instanceof java.util.Date)
&& !val.getClass().getName().startsWith("java.sql.") && !val.getClass().getName().startsWith("java.time.")) {
prestmt.setObject(++i, info.jsonConvert.convertTo(attr.genericType(), val));
} else {
@@ -281,7 +281,7 @@ public class DataJdbcSource extends DataSqlSource<Connection> {
StringBuilder sb = new StringBuilder(128);
for (char ch : sqlchars) {
if (ch == '?') {
Object obj = i == attrs.length ? primary.get(value) : attrs[i++].get(value);
Object obj = i == attrs.length ? info.getSQLValue(primary, value) : info.getSQLValue(attrs[i++], value);
if (obj != null && obj.getClass().isArray()) {
sb.append("'[length=").append(java.lang.reflect.Array.getLength(obj)).append("]'");
} else {
@@ -292,7 +292,7 @@ public class DataJdbcSource extends DataSqlSource<Connection> {
}
}
String debugsql = sb.toString();
if (info.isLoggable(logger, Level.FINEST, debugsql)) logger.finest(info.getType().getSimpleName() + " update sql=" + debugsql.replaceAll("(\r|\n)", "\\n"));
if (info.isLoggable(logger, Level.FINEST, debugsql)) logger.finest(info.getType().getSimpleName() + " updates sql=" + debugsql.replaceAll("(\r|\n)", "\\n"));
} //打印结束
}
int[] pc = prestmt.executeBatch();

View File

@@ -483,13 +483,13 @@ public abstract class DataSqlSource<DBChannel> extends AbstractService implement
protected <T> CompletableFuture<Integer> deleteCompose(final EntityInfo<T> info, final Serializable... ids) {
if (ids.length == 1) {
String sql = "DELETE FROM " + info.getTable(ids[0]) + " WHERE " + info.getPrimarySQLColumn() + " = " + FilterNode.formatToString(ids[0]);
String sql = "DELETE FROM " + info.getTable(ids[0]) + " WHERE " + info.getPrimarySQLColumn() + " = " + FilterNode.formatToString(info.getSQLValue(info.getPrimarySQLColumn(), ids[0]));
return deleteDB(info, null, sql);
}
String sql = "DELETE FROM " + info.getTable(ids[0]) + " WHERE " + info.getPrimarySQLColumn() + " IN (";
for (int i = 0; i < ids.length; i++) {
if (i > 0) sql += ',';
sql += FilterNode.formatToString(ids[i]);
sql += FilterNode.formatToString(info.getSQLValue(info.getPrimarySQLColumn(), ids[i]));
}
sql += ")";
if (info.isLoggable(logger, Level.FINEST, sql)) logger.finest(info.getType().getSimpleName() + " delete sql=" + sql);
@@ -772,11 +772,11 @@ public abstract class DataSqlSource<DBChannel> extends AbstractService implement
protected <T> CompletableFuture<Integer> updateColumnCompose(final EntityInfo<T> info, Serializable id, String column, final Serializable value) {
if (value instanceof byte[]) {
String sql = "UPDATE " + info.getTable(id) + " SET " + info.getSQLColumn(null, column) + " = " + prepareParamSign(1) + " WHERE " + info.getPrimarySQLColumn() + " = " + FilterNode.formatToString(id);
String sql = "UPDATE " + info.getTable(id) + " SET " + info.getSQLColumn(null, column) + " = " + prepareParamSign(1) + " WHERE " + info.getPrimarySQLColumn() + " = " + FilterNode.formatToString(info.getSQLValue(info.getPrimarySQLColumn(), id));
return updateDB(info, null, sql, true, value);
} else {
String sql = "UPDATE " + info.getTable(id) + " SET " + info.getSQLColumn(null, column) + " = "
+ info.formatToString(value) + " WHERE " + info.getPrimarySQLColumn() + " = " + FilterNode.formatToString(id);
+ info.formatToString(info.getSQLValue(column, value)) + " WHERE " + info.getPrimarySQLColumn() + " = " + FilterNode.formatToString(info.getSQLValue(info.getPrimarySQLColumn(), id));
return updateDB(info, null, sql, false);
}
}
@@ -917,11 +917,11 @@ public abstract class DataSqlSource<DBChannel> extends AbstractService implement
blobs.add((byte[]) col.getValue());
setsql.append(c).append(" = ").append(prepareParamSign(++index));
} else {
setsql.append(c).append(" = ").append(info.formatSQLValue(c, col));
setsql.append(c).append(" = ").append(info.formatSQLValue(c, attr, col));
}
}
if (setsql.length() < 1) return CompletableFuture.completedFuture(0);
String sql = "UPDATE " + info.getTable(id) + " SET " + setsql + " WHERE " + info.getPrimarySQLColumn() + " = " + FilterNode.formatToString(id);
String sql = "UPDATE " + info.getTable(id) + " SET " + setsql + " WHERE " + info.getPrimarySQLColumn() + " = " + FilterNode.formatToString(info.getSQLValue(info.getPrimarySQLColumn(), id));
if (blobs == null) return updateDB(info, null, sql, false);
return updateDB(info, null, sql, true, blobs.toArray());
}
@@ -999,7 +999,7 @@ public abstract class DataSqlSource<DBChannel> extends AbstractService implement
blobs.add((byte[]) col.getValue());
setsql.append(c).append(" = ").append(prepareParamSign(++index));
} else {
setsql.append(c).append(" = ").append(info.formatSQLValue(c, col));
setsql.append(c).append(" = ").append(info.formatSQLValue(c, attr, col));
}
}
if (setsql.length() < 1) return CompletableFuture.completedFuture(0);
@@ -1128,7 +1128,7 @@ public abstract class DataSqlSource<DBChannel> extends AbstractService implement
if (!selects.test(attr.field())) continue;
if (setsql.length() > 0) setsql.append(", ");
setsql.append(info.getSQLColumn(alias, attr.field()));
Serializable val = attr.get(bean);
Serializable val = info.getFieldValue(attr, bean);
if (val instanceof byte[]) {
if (blobs == null) blobs = new ArrayList<>();
blobs.add((byte[]) val);
@@ -1154,8 +1154,8 @@ public abstract class DataSqlSource<DBChannel> extends AbstractService implement
if (blobs == null) return updateDB(info, null, sql, false);
return updateDB(info, null, sql, true, blobs.toArray());
} else {
final Serializable id = info.getPrimary().get(bean);
String sql = "UPDATE " + info.getTable(id) + " a SET " + setsql + " WHERE " + info.getPrimarySQLColumn() + " = " + FilterNode.formatToString(id);
final Serializable id = info.getSQLValue(info.getPrimary(), bean);
String sql = "UPDATE " + info.getTable(id) + " a SET " + setsql + " WHERE " + info.getPrimarySQLColumn() + " = " + info.getSQLValue(info.getPrimarySQLColumn(), id);
if (blobs == null) return updateDB(info, null, sql, false);
return updateDB(info, null, sql, true, blobs.toArray());
}
@@ -1533,7 +1533,8 @@ public abstract class DataSqlSource<DBChannel> extends AbstractService implement
}
protected <T> CompletableFuture<T> findCompose(final EntityInfo<T> info, final SelectColumn selects, Serializable pk) {
final String sql = "SELECT " + info.getQueryColumns(null, selects) + " FROM " + info.getTable(pk) + " WHERE " + info.getPrimarySQLColumn() + " = " + FilterNode.formatToString(pk);
String column = info.getPrimarySQLColumn();
final String sql = "SELECT " + info.getQueryColumns(null, selects) + " FROM " + info.getTable(pk) + " WHERE " + column + " = " + FilterNode.formatToString(info.getSQLValue(column, pk));
if (info.isLoggable(logger, Level.FINEST, sql)) logger.finest(info.getType().getSimpleName() + " find sql=" + sql);
return findDB(info, sql, true, selects);
}
@@ -1670,7 +1671,7 @@ public abstract class DataSqlSource<DBChannel> extends AbstractService implement
}
protected <T> CompletableFuture<Serializable> findColumnCompose(final EntityInfo<T> info, String column, final Serializable defValue, final Serializable pk) {
final String sql = "SELECT " + info.getSQLColumn(null, column) + " FROM " + info.getTable(pk) + " WHERE " + info.getPrimarySQLColumn() + " = " + FilterNode.formatToString(pk);
final String sql = "SELECT " + info.getSQLColumn(null, column) + " FROM " + info.getTable(pk) + " WHERE " + info.getPrimarySQLColumn() + " = " + FilterNode.formatToString(info.getSQLValue(info.getPrimarySQLColumn(), pk));
if (info.isLoggable(logger, Level.FINEST, sql)) logger.finest(info.getType().getSimpleName() + " find sql=" + sql);
return findColumnDB(info, sql, true, column, defValue);
}
@@ -1732,7 +1733,7 @@ public abstract class DataSqlSource<DBChannel> extends AbstractService implement
}
protected <T> CompletableFuture<Boolean> existsCompose(final EntityInfo<T> info, Serializable pk) {
final String sql = "SELECT COUNT(*) FROM " + info.getTable(pk) + " WHERE " + info.getPrimarySQLColumn() + " = " + FilterNode.formatToString(pk);
final String sql = "SELECT COUNT(*) FROM " + info.getTable(pk) + " WHERE " + info.getPrimarySQLColumn() + " = " + FilterNode.formatToString(info.getSQLValue(info.getPrimarySQLColumn(), pk));
if (info.isLoggable(logger, Level.FINEST, sql)) logger.finest(info.getType().getSimpleName() + " exists sql=" + sql);
return existsDB(info, sql, true);
}

View File

@@ -75,6 +75,10 @@ public final class EntityInfo<T> {
//只有field.name 与 Column.name不同才存放在aliasmap里.
private final Map<String, String> aliasmap;
//key是field的name value是CryptHandler
//字段都不存在CryptHandler时值因为为null减少判断
private final Map<String, CryptHandler> cryptmap;
//所有可更新字段,即排除了主键字段和标记为&#064;Column(updatable=false)的字段
private final Map<String, Attribute<T, Serializable>> updateAttributeMap = new HashMap<>();
@@ -274,6 +278,7 @@ public final class EntityInfo<T> {
}
this.constructorParameters = (cp == null || cp.value().length < 1) ? null : cp.value();
Attribute idAttr0 = null;
Map<String, CryptHandler> cryptmap0 = null;
Map<String, String> aliasmap0 = null;
Class cltmp = type;
Set<String> fields = new HashSet<>();
@@ -284,7 +289,7 @@ public final class EntityInfo<T> {
List<Attribute<T, Serializable>> updateattrs = new ArrayList<>();
boolean auto = false;
boolean uuid = false;
Map<Class, Creator<CryptHandler>> cryptCreatorMap = new HashMap<>();
do {
for (Field field : cltmp.getDeclaredFields()) {
if (Modifier.isStatic(field.getModifiers())) continue;
@@ -298,9 +303,16 @@ public final class EntityInfo<T> {
if (aliasmap0 == null) aliasmap0 = new HashMap<>();
aliasmap0.put(fieldname, sqlfield);
}
final CryptColumn cpt = field.getAnnotation(CryptColumn.class);
CryptHandler cryptHandler = null;
if (cpt != null) {
if (cryptmap0 == null) cryptmap0 = new HashMap<>();
cryptHandler = cryptCreatorMap.computeIfAbsent(cpt.handler(), c -> (Creator<CryptHandler>) Creator.create(cpt.handler())).create();
cryptmap0.put(fieldname, cryptHandler);
}
Attribute attr;
try {
attr = Attribute.create(cltmp, field);
attr = Attribute.create(cltmp, field, cryptHandler);
} catch (RuntimeException e) {
continue;
}
@@ -357,6 +369,7 @@ public final class EntityInfo<T> {
this.primary = idAttr0;
this.aliasmap = aliasmap0;
this.cryptmap = cryptmap0;
this.attributes = attributeMap.values().toArray(new Attribute[attributeMap.size()]);
this.queryAttributes = queryattrs.toArray(new Attribute[queryattrs.size()]);
this.insertAttributes = insertattrs.toArray(new Attribute[insertattrs.size()]);
@@ -857,6 +870,51 @@ public final class EntityInfo<T> {
: (tabalis == null ? aliasmap.getOrDefault(fieldname, fieldname) : (tabalis + '.' + aliasmap.getOrDefault(fieldname, fieldname)));
}
/**
* 字段值转换成数据库的值
*
* @param fieldname 字段名
* @param fieldvalue 字段值
*
* @return Object
*/
public Object getSQLValue(String fieldname, Object fieldvalue) {
if (this.cryptmap == null) return fieldvalue;
CryptHandler handler = this.cryptmap.get(fieldname);
if (handler == null) return fieldvalue;
return handler.encrypt(fieldvalue);
}
/**
* 字段值转换成数据库的值
*
* @param attr Attribute
* @param entity 记录对象
*
* @return Object
*/
public Serializable getSQLValue(Attribute<T, Serializable> attr, T entity) {
Serializable val = attr.get(entity);
CryptHandler cryptHandler = attr.attach();
if (cryptHandler != null) val = (Serializable) cryptHandler.encrypt(val);
return val;
}
/**
* 数据库的值转换成数字段值
*
* @param attr Attribute
* @param entity 记录对象
*
* @return Object
*/
public Serializable getFieldValue(Attribute<T, Serializable> attr, T entity) {
Serializable val = attr.get(entity);
CryptHandler cryptHandler = attr.attach();
if (cryptHandler != null) val = (Serializable) cryptHandler.decrypt(val);
return val;
}
/**
* 获取主键字段的表字段名
*
@@ -880,26 +938,30 @@ public final class EntityInfo<T> {
/**
* 拼接UPDATE给字段赋值的SQL片段
*
* @param col 表字段名
* @param cv ColumnValue
* @param col 表字段名
* @param attr Attribute
* @param cv ColumnValue
*
* @return CharSequence
*/
protected CharSequence formatSQLValue(String col, final ColumnValue cv) {
protected CharSequence formatSQLValue(String col, Attribute<T, Serializable> attr, final ColumnValue cv) {
if (cv == null) return null;
Object val = cv.getValue();
CryptHandler handler = attr.attach();
if (handler != null) val = handler.encrypt(val);
switch (cv.getExpress()) {
case INC:
return new StringBuilder().append(col).append(" + ").append(cv.getValue());
return new StringBuilder().append(col).append(" + ").append(val);
case MUL:
return new StringBuilder().append(col).append(" * ").append(cv.getValue());
return new StringBuilder().append(col).append(" * ").append(val);
case AND:
return new StringBuilder().append(col).append(" & ").append(cv.getValue());
return new StringBuilder().append(col).append(" & ").append(val);
case ORR:
return new StringBuilder().append(col).append(" | ").append(cv.getValue());
return new StringBuilder().append(col).append(" | ").append(val);
case MOV:
return formatToString(cv.getValue());
return formatToString(val);
}
return formatToString(cv.getValue());
return formatToString(val);
}
/**
@@ -1003,9 +1065,13 @@ public final class EntityInfo<T> {
o = null;
} else { //不支持超过2G的数据
o = blob.getBytes(1, (int) blob.length());
CryptHandler cryptHandler = attr.attach();
if (cryptHandler != null) o = (Serializable) cryptHandler.decrypt(o);
}
} else {
o = (Serializable) set.getObject(this.getSQLColumn(null, attr.field()));
CryptHandler cryptHandler = attr.attach();
if (cryptHandler != null) o = (Serializable) cryptHandler.decrypt(o);
if (t.isPrimitive()) {
if (o != null) {
if (t == int.class) {

View File

@@ -389,7 +389,7 @@ public class FilterNode { //FilterNode 不能实现Serializable接口 否则
.append(' ').append(fv.getExpress().value()).append(' ').append(fv.getDestvalue());
}
final boolean fk = (val0 instanceof FilterKey);
CharSequence val = fk ? info.getSQLColumn(talis, ((FilterKey) val0).getColumn()) : formatToString(express, val0);
CharSequence val = fk ? info.getSQLColumn(talis, ((FilterKey) val0).getColumn()) : formatToString(express, info.getSQLValue(column, val0));
if (val == null) return null;
StringBuilder sb = new StringBuilder(32);
if (express == CONTAIN) return info.containSQL.replace("${column}", info.getSQLColumn(talis, column)).replace("${keystr}", val);