liuxiaolong
2019-05-09 0d1d88cdb668e75ea8609417ac18ae19947e9525
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
package framework.mybatis.pageAction;
 
import java.lang.reflect.Field;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.List;
import java.util.Map;
import java.util.Properties;
 
import org.apache.ibatis.executor.ErrorContext;
import org.apache.ibatis.executor.ExecutorException;
import org.apache.ibatis.executor.parameter.ParameterHandler;
import org.apache.ibatis.executor.statement.RoutingStatementHandler;
import org.apache.ibatis.executor.statement.StatementHandler;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.ParameterMapping;
import org.apache.ibatis.mapping.ParameterMode;
import org.apache.ibatis.plugin.Interceptor;
import org.apache.ibatis.plugin.Intercepts;
import org.apache.ibatis.plugin.Invocation;
import org.apache.ibatis.plugin.Plugin;
import org.apache.ibatis.plugin.Signature;
import org.apache.ibatis.reflection.MetaObject;
import org.apache.ibatis.reflection.property.PropertyTokenizer;
import org.apache.ibatis.scripting.defaults.DefaultParameterHandler;
import org.apache.ibatis.scripting.xmltags.ForEachSqlNode;
import org.apache.ibatis.session.Configuration;
import org.apache.ibatis.type.JdbcType;
import org.apache.ibatis.type.TypeException;
import org.apache.ibatis.type.TypeHandler;
import org.apache.ibatis.type.TypeHandlerRegistry;
 
/**
 * 
 *
 * 分页拦截器,用于拦截需要进行分页查询的操作,然后对其进行分页处理。 利用拦截器实现Mybatis分页的原理:
 * 要利用JDBC对数据库进行操作就必须要有一个对应的Statement对象,
 * Mybatis在执行Sql语句前就会产生一个包含Sql语句的Statement对象,而且对应的Sql语句
 * 是在Statement之前产生的,所以我们就可以在它生成Statement之前对用来生成Statement的Sql语句下手。
 * 在Mybatis中Statement语句是通过RoutingStatementHandler对象的
 * prepare方法生成的。所以利用拦截器实现Mybatis分页的一个思路就是拦截StatementHandler接口的prepare方法,
 * 然后在拦截器方法中把Sql语句改成对应的分页查询Sql语句,之后再调用
 * StatementHandler对象的prepare方法,即调用invocation.proceed()。
 * 对于分页而言,在拦截器里面我们还需要做的一个操作就是统计满足当前条件的记录一共有多少,这是通过获取到了原始的Sql语句后,
 * 把它改为对应的统计语句再利用Mybatis封装好的参数和设 置参数的功能把Sql语句中的参数进行替换,之后再执行查询记录数的Sql语句进行总记录数的统计。
 *
 * 
 * 
 * @author liuyajun, 8384503@qq.com
 * @date 2016年1月28日
 * @time 下午5:48:18
 */
@Intercepts({ @Signature(method = "prepare", type = StatementHandler.class, args = { Connection.class }) })
public class PageInterceptor implements Interceptor {
 
    private String databaseType;// 数据库类型,不同的数据库有不同的分页方法
 
    /**
     * 拦截后要执行的方法
     */
    public Object intercept(Invocation invocation) throws Throwable {
        // 对于StatementHandler其实只有两个实现类,一个是RoutingStatementHandler,另一个是抽象类BaseStatementHandler,
        // BaseStatementHandler有三个子类,分别是SimpleStatementHandler,PreparedStatementHandler和CallableStatementHandler,
        // SimpleStatementHandler是用于处理Statement的,PreparedStatementHandler是处理PreparedStatement的,而CallableStatementHandler是
        // 处理CallableStatement的。Mybatis在进行Sql语句处理的时候都是建立的RoutingStatementHandler,而在RoutingStatementHandler里面拥有一个
        // StatementHandler类型的delegate属性,RoutingStatementHandler会依据Statement的不同建立对应的BaseStatementHandler,即SimpleStatementHandler、
        // PreparedStatementHandler或CallableStatementHandler,在RoutingStatementHandler里面所有StatementHandler接口方法的实现都是调用的delegate对应的方法。
        // 我们在PageInterceptor类上已经用@Signature标记了该Interceptor只拦截StatementHandler接口的prepare方法,又因为Mybatis只有在建立RoutingStatementHandler的时候
        // 是通过Interceptor的plugin方法进行包裹的,所以我们这里拦截到的目标对象肯定是RoutingStatementHandler对象。
        RoutingStatementHandler handler = (RoutingStatementHandler) invocation.getTarget();
        // 通过反射获取到当前RoutingStatementHandler对象的delegate属性
        StatementHandler delegate = (StatementHandler) ReflectUtil.getFieldValue(handler, "delegate");
        // 获取到当前StatementHandler的
        // boundSql,这里不管是调用handler.getBoundSql()还是直接调用delegate.getBoundSql()结果是一样的,因为之前已经说过了
        // RoutingStatementHandler实现的所有StatementHandler接口方法里面都是调用的delegate对应的方法。
        BoundSql boundSql = delegate.getBoundSql();
        // 拿到当前绑定Sql的参数对象,就是我们在调用对应的Mapper映射语句时所传入的参数对象
        Object obj = boundSql.getParameterObject();
        // 这里我们简单的通过传入的是Page对象就认定它是需要进行分页操作的。
        if(obj instanceof Map<?,?>){
            Map<String,Object> paramMap = (Map<String,Object>)obj;
            
            if(paramMap.containsKey(PageUtil.PAGE_PARAM_KEY) 
                    && paramMap.get(PageUtil.PAGE_PARAM_KEY)!=null){
                // 通过反射获取delegate父类BaseStatementHandler的mappedStatement属性
                MappedStatement mappedStatement = (MappedStatement) ReflectUtil.getFieldValue(
                        delegate, "mappedStatement");
                // 拦截到的prepare方法参数是一个Connection对象
                Connection connection = (Connection) invocation.getArgs()[0];
                // 获取当前要执行的Sql语句,也就是我们直接在Mapper映射语句中写的Sql语句
                String sql = boundSql.getSql();
                
                // 给当前的page参数对象设置总记录数
                this.setTotalRecord(paramMap, mappedStatement, connection);
                // 获取分页Sql语句
                String pageSql = this.getPageSql(paramMap, sql);
                // 利用反射设置当前BoundSql对应的sql属性为我们建立好的分页Sql语句
                ReflectUtil.setFieldValue(boundSql, "sql", pageSql);
            }
        }
 
        return invocation.proceed();
    }
 
    /**
     * 拦截器对应的封装原始对象的方法
     */
    public Object plugin(Object target) {
        return Plugin.wrap(target, this);
    }
 
    /**
     * 设置注册拦截器时设定的属性
     */
    public void setProperties(Properties properties) {
        this.databaseType = properties.getProperty("databaseType");
    }
 
    /**
     * 根据page对象获取对应的分页查询Sql语句,这里只做了两种数据库类型,Mysql和Oracle 其它的数据库都 没有进行分页
     *
     * @param page
     *            分页对象
     * @param sql
     *            原sql语句
     * @return
     */
    private String getPageSql(Map<String,Object> paramMap, String sql) {
        StringBuffer sqlBuffer = new StringBuffer(sql);
        PageInfo page = PageUtil.getPageInfo(paramMap);
        
        if ("mysql".equalsIgnoreCase(databaseType)) {
            return getMysqlPageSql(page, sqlBuffer);
        } 
        
//        else if ("oracle".equalsIgnoreCase(databaseType)) {
//            return getOraclePageSql(page, sqlBuffer);
//        }
//        return sqlBuffer.toString();
        
        throw new RuntimeException("Unsupport database type: "+databaseType);
    }
 
    /**
     * 获取Mysql数据库的分页查询语句
     * 
     * @param page
     *            分页对象
     * @param sqlBuffer
     *            包含原sql语句的StringBuffer对象
     * @return Mysql数据库分页语句
     */
    private String getMysqlPageSql(PageInfo page, StringBuffer sqlBuffer) {
        // 计算第一条记录的位置,Mysql中记录的位置是从0开始的
        sqlBuffer.append(" limit ").append(page.getStartRow()==0?0:(page.getStartRow()-1))
            .append(",").append(page.getStartRow()==0?0:(page.getEndRow()-page.getStartRow()+1));
        return sqlBuffer.toString();
    }
 
    /**
     * 获取Oracle数据库的分页查询语句
     * 
     * @param page
     *            分页对象
     * @param sqlBuffer
     *            包含原sql语句的StringBuffer对象
     * @return Oracle数据库的分页查询语句
     */
//    private String getOraclePageSql(PageInfo page, StringBuffer sqlBuffer) {
//        // 计算第一条记录的位置,Oracle分页是通过rownum进行的,而rownum是从1开始的
//        sqlBuffer.insert(0, "select u.*, rownum r from (").append(") u where rownum < ")
//                .append(offset + page.getTotalPage());
//        sqlBuffer.insert(0, "select * from (").append(") where r >= ").append(offset);
//        // 上面的Sql语句拼接之后大概是这个样子:
//        // select * from (select u.*, rownum r from (select * from t_user) u
//        // where rownum < 31) where r >= 16
//        return sqlBuffer.toString();
//    }
 
    /**
     * 给当前的参数对象page设置总记录数
     *
     * @param page
     *            Mapper映射语句对应的参数对象
     * @param mappedStatement
     *            Mapper映射语句
     * @param connection
     *            当前的数据库连接
     */
    private void setTotalRecord(Map<String,Object> paramMap, MappedStatement mappedStatement, Connection connection) {
        // 获取对应的BoundSql,这个BoundSql其实跟我们利用StatementHandler获取到的BoundSql是同一个对象。
        // delegate里面的boundSql也是通过mappedStatement.getBoundSql(paramObj)方法获取到的。
        
        BoundSql boundSql = mappedStatement.getBoundSql(paramMap);
        // 获取到我们自己写在Mapper映射语句中对应的Sql语句
        String sql = boundSql.getSql();
        // 通过查询Sql语句获取到对应的计算总记录数的sql语句
        String countSql = this.getCountSql(sql);
        // 通过BoundSql获取对应的参数映射
        //List<ParameterMapping> parameterMappings = boundSql.getParameterMappings();
        
        // 利用Configuration、查询记录数的Sql语句countSql、参数映射关系parameterMappings和参数对象page建立查询记录数对应的BoundSql对象。
//        BoundSql countBoundSql = new BoundSql(mappedStatement.getConfiguration(), countSql, parameterMappings, paramMap);
 
        // 通过mappedStatement、参数对象page和BoundSql对象countBoundSql建立一个用于设定参数的ParameterHandler对象
        //ParameterHandler parameterHandler = new DefaultParameterHandler(mappedStatement, paramMap, countBoundSql);
        
        // 通过connection建立一个countSql对应的PreparedStatement对象。
        PreparedStatement pstmt = null;
        ResultSet rs = null;
        try {
            pstmt = connection.prepareStatement(countSql);
            // 通过parameterHandler给PreparedStatement对象设置参数
            //parameterHandler.setParameters(pstmt);
            this.setParameters2(pstmt, mappedStatement, boundSql, paramMap);
            // 之后就是执行获取总记录数的Sql语句和获取结果了。
            rs = pstmt.executeQuery();
            if (rs.next()) {
                int totalRecord = rs.getInt(1);
                // 给当前的参数page对象设置总记录数
                PageUtil.getPageInfo(paramMap).setTotalCount(totalRecord);
            }
        } catch (Throwable e) {
            e.printStackTrace();
        } finally {
            try {
                if (rs != null){
                    rs.close();
                }
            } catch (Throwable e) {
                e.printStackTrace();
            }
            try {
                if (pstmt != null){
                    pstmt.close();
                }
            } catch (Throwable e) {
                e.printStackTrace();
            }
        }
    }
 
    private void setParameters2(PreparedStatement ps, MappedStatement mappedStatement, BoundSql boundSql,
            Object parameterObject) {
        ErrorContext.instance().activity("setting parameters").object(mappedStatement.getParameterMap().getId());
        List<ParameterMapping> parameterMappings = boundSql.getParameterMappings();
        if (parameterMappings != null) {
            Configuration configuration = mappedStatement.getConfiguration();
            TypeHandlerRegistry typeHandlerRegistry = configuration.getTypeHandlerRegistry();
            for (int i = 0; i < parameterMappings.size(); i++) {
                ParameterMapping parameterMapping = parameterMappings.get(i);
                if (parameterMapping.getMode() != ParameterMode.OUT) {
                    Object value;
                    String propertyName = parameterMapping.getProperty();
                    if (boundSql.hasAdditionalParameter(propertyName)) {
                        value = boundSql.getAdditionalParameter(propertyName);
                    } else if (parameterObject == null) {
                        value = null;
                    } else if (typeHandlerRegistry.hasTypeHandler(parameterObject.getClass())) {
                        value = parameterObject;
                    } else {
                        MetaObject metaObject = configuration.newMetaObject(parameterObject);
                        value = metaObject.getValue(propertyName);
                    }
                    TypeHandler typeHandler = parameterMapping.getTypeHandler();
                    JdbcType jdbcType = parameterMapping.getJdbcType();
                    if (value == null && jdbcType == null) {
                        jdbcType = configuration.getJdbcTypeForNull();
                    }
                    try {
                        typeHandler.setParameter(ps, i + 1, value, jdbcType);
                    } catch (Throwable e) {
                        throw new TypeException(
                                "Could not set parameters for mapping: " + parameterMapping + ". Cause: " + e, e);
                    }
                }
            }
        }
 
    }
    
    /** 
       * 对SQL参数(?)设值,参考org.apache.ibatis.executor.parameter.DefaultParameterHandler 
       * @param ps 
       * @param mappedStatement 
       * @param boundSql 
       * @param parameterObject 
       * @throws SQLException 
       */
    private void setParameters(PreparedStatement ps, MappedStatement mappedStatement, BoundSql boundSql,
            Object parameterObject) throws SQLException {
        ErrorContext.instance().activity("setting parameters").object(mappedStatement.getParameterMap().getId());
        List<ParameterMapping> parameterMappings = boundSql.getParameterMappings();
        if (parameterMappings != null) {
            Configuration configuration = mappedStatement.getConfiguration();
            TypeHandlerRegistry typeHandlerRegistry = configuration.getTypeHandlerRegistry();
            MetaObject metaObject = parameterObject == null ? null : configuration.newMetaObject(parameterObject);
            for (int i = 0; i < parameterMappings.size(); i++) {
                ParameterMapping parameterMapping = parameterMappings.get(i);
                if (parameterMapping.getMode() != ParameterMode.OUT) {
                    Object value;
                    String propertyName = parameterMapping.getProperty();
                    PropertyTokenizer prop = new PropertyTokenizer(propertyName);
                    if (parameterObject == null) {
                        value = null;
                    } else if (typeHandlerRegistry.hasTypeHandler(parameterObject.getClass())) {
                        value = parameterObject;
                    } else if (boundSql.hasAdditionalParameter(propertyName)) {
                        value = boundSql.getAdditionalParameter(propertyName);
                    } else if (propertyName.startsWith(ForEachSqlNode.ITEM_PREFIX)
                            && boundSql.hasAdditionalParameter(prop.getName())) {
                        value = boundSql.getAdditionalParameter(prop.getName());
                        if (value != null) {
                            value = configuration.newMetaObject(value)
                                    .getValue(propertyName.substring(prop.getName().length()));
                        }
                    } else {
                        value = metaObject == null ? null : metaObject.getValue(propertyName);
                    }
                    TypeHandler typeHandler = parameterMapping.getTypeHandler();
                    if (typeHandler == null) {
                        throw new ExecutorException("There was no TypeHandler found for parameter " + propertyName
                                + " of statement " + mappedStatement.getId());
                    }
                    JdbcType jdbcType = parameterMapping.getJdbcType();
                    if (value == null && jdbcType == null) {
                        jdbcType = configuration.getJdbcTypeForNull();
                    }
                    try {
                        typeHandler.setParameter(ps, i + 1, value, jdbcType);
                    } catch (Throwable e) {
                        throw new TypeException(
                                "Could not set parameters for mapping: " + parameterMapping + ". Cause: " + e, e);
                    }
 
                    // typeHandler.setParameter(ps, i + 1, value,
                    // parameterMapping.getJdbcType());
                }
            }
        }
    }
 
    /**
     * 根据原Sql语句获取对应的查询总记录数的Sql语句
     * 
     * @param sql
     * @return
     */
    private String getCountSql(String sql) {
 
        return "select count(1) from (" + sql + ") ori_table";
    }
 
    /**
     * 利用反射进行操作的一个工具类
     *
     */
    private static class ReflectUtil {
        /**
         * 利用反射获取指定对象的指定属性
         * 
         * @param obj
         *            目标对象
         * @param fieldName
         *            目标属性
         * @return 目标属性的值
         */
        public static Object getFieldValue(Object obj, String fieldName) {
            Object result = null;
            Field field = ReflectUtil.getField(obj, fieldName);
            if (field != null) {
                field.setAccessible(true);
                try {
                    result = field.get(obj);
                } catch (Throwable e) {
                    // TODO Auto-generated catch block
                    e.printStackTrace();
                }
            }
            return result;
        }
 
        /**
         * 利用反射获取指定对象里面的指定属性
         * 
         * @param obj
         *            目标对象
         * @param fieldName
         *            目标属性
         * @return 目标字段
         */
        private static Field getField(Object obj, String fieldName) {
            Field field = null;
            for (Class<?> clazz = obj.getClass(); clazz != Object.class; clazz = clazz.getSuperclass()) {
                try {
                    field = clazz.getDeclaredField(fieldName);
                    break;
                } catch (Throwable e) {
                    // 这里不用做处理,子类没有该字段可能对应的父类有,都没有就返回null。
                }
            }
            return field;
        }
 
        /**
         * 利用反射设置指定对象的指定属性为指定的值
         * 
         * @param obj
         *            目标对象
         * @param fieldName
         *            目标属性
         * @param fieldValue
         *            目标值
         */
        public static void setFieldValue(Object obj, String fieldName, String fieldValue) {
            Field field = ReflectUtil.getField(obj, fieldName);
            if (field != null) {
                try {
                    field.setAccessible(true);
                    field.set(obj, fieldValue);
                } catch (Throwable e) {
                    // TODO Auto-generated catch block
                    e.printStackTrace();
                }
            }
        }
    }
 
}