MyBatis分页插件实现 -《MyBatis技术内幕》笔记

一个简单的插件例子

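下面的插件通过 @Intercepts 与 @Signature 注解声明要拦截的方法:Executor 的 query(MappedStatement, Object, RowBounds, ResultHandler) 方法和 close(boolean) 方法。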
@Intercepts({
  @Signature(type=Executor.class, method="query", args={
    MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class
  }),
  @Signature(type=Executor.class, method="close", args={boolean.class})
})
pulic class ExamplePlugin implements Interceptor {
  ...
}

插件的实现原理

session/Configuration.java

Configuration.newExecutor() 在创建 Executor 之后,会调用 interceptorChain.pluginAll() 为其织入所有已配置的插件:
/**
 * Creates the Executor used by a new session and lets every configured plugin wrap it.
 *
 * @param transaction  the transaction handed to the concrete executor
 * @param executorType requested type; falls back to {@code defaultExecutorType}, then SIMPLE
 * @return the (possibly cache-decorated and plugin-proxied) executor
 */
public Executor newExecutor(Transaction transaction, ExecutorType executorType) {
  ExecutorType type = executorType;
  if (type == null) {
    type = defaultExecutorType;
  }
  if (type == null) {
    type = ExecutorType.SIMPLE;
  }

  final Executor base;
  if (type == ExecutorType.BATCH) {
    base = new BatchExecutor(this, transaction);
  } else if (type == ExecutorType.REUSE) {
    base = new ReuseExecutor(this, transaction);
  } else {
    base = new SimpleExecutor(this, transaction);
  }

  // CachingExecutor wraps the concrete executor before plugins are applied,
  // so plugin proxies sit outside the caching layer.
  final Executor decorated = cacheEnabled ? new CachingExecutor(base) : base;
  // interceptorChain holds the configured interceptors; each one may wrap the executor.
  return (Executor) interceptorChain.pluginAll(decorated);
}

插件机制的核心是 Interceptor 接口、保存拦截器列表的 InterceptorChain,以及基于 JDK 动态代理的 Plugin 类:
/**
 * Contract implemented by every MyBatis plugin.
 */
public interface Interceptor {
  /**
   * Gives this interceptor the chance to wrap {@code target}.
   * {@code InterceptorChain.pluginAll} passes each interceptor the result of the previous one.
   *
   * @param target the object to (possibly) wrap
   * @return {@code target} itself or a wrapper around it
   */
  Object plugin(Object target);

  /**
   * Supplies configuration properties to the interceptor.
   *
   * @param properties the configured properties (presumably from the plugin definition — confirm)
   */
  void setProperties(Properties properties);

  /**
   * Invoked by {@code Plugin.invoke} in place of a method listed in this interceptor's signatures.
   *
   * @param invocation the intercepted call
   * @return the value handed back to the original caller
   * @throws Throwable anything raised while handling the call
   */
  Object intercept(Invocation invocation) throws Throwable;
}

/**
 * Ordered list of the configured interceptors.
 */
public class InterceptorChain {
  private final List<Interceptor> interceptors = new ArrayList<>();

  /** Appends {@code interceptor} to the end of the chain. */
  public void addInterceptor(Interceptor interceptor) {
    interceptors.add(interceptor);
  }

  /** Returns a read-only view of the registered interceptors, in registration order. */
  public List<Interceptor> getInterceptors() {
    return Collections.unmodifiableList(interceptors);
  }

  /**
   * Lets every interceptor, in registration order, wrap {@code target}. Each interceptor
   * receives the previous one's result, so when they all wrap, the last-registered
   * interceptor produces the outermost layer.
   *
   * @param target the object to decorate
   * @return the final (possibly wrapped) object
   */
  public Object pluginAll(Object target) {
    Object wrapped = target;
    for (Interceptor each : interceptors) {
      wrapped = each.plugin(wrapped);
    }
    return wrapped;
  }
}

/**
 * JDK dynamic-proxy handler that routes calls to an {@link Interceptor} for the methods named
 * in the interceptor's {@code @Intercepts}/{@code @Signature} annotations. All other calls go
 * straight to the target.
 */
public class Plugin implements InvocationHandler {
  private final Object target;
  private final Interceptor interceptor;
  // Methods to intercept, keyed by the interface type named in each @Signature.
  private final Map<Class<?>, Set<Method>> signatureMap;

  private Plugin(Object target, Interceptor interceptor, Map<Class<?>, Set<Method>> signatureMap) {
    this.target = target;
    this.interceptor = interceptor;
    this.signatureMap = signatureMap;
  }

  /**
   * Wraps {@code target} in a proxy if it implements any interface named in the interceptor's
   * signatures; otherwise returns {@code target} unchanged.
   *
   * @param target      the object to wrap
   * @param interceptor the interceptor that will handle matching calls
   * @return a JDK proxy, or {@code target} itself when nothing needs intercepting
   * @throws IllegalArgumentException if the interceptor lacks {@code @Intercepts} or a
   *     signature names a method that does not exist
   */
  public static Object wrap(Object target, Interceptor interceptor) {
    // Read the @Signature annotation information.
    Map<Class<?>, Set<Method>> signatureMap = getSignatureMap(interceptor);
    Class<?> type = target.getClass();
    Class<?>[] interfaces = getAllInterfaces(type, signatureMap);
    if (interfaces.length > 0) {
      // Create the proxy with JDK dynamic proxies (interfaces only).
      return Proxy.newProxyInstance(
          type.getClassLoader(),
          interfaces,
          new Plugin(target, interceptor, signatureMap));
    }
    return target;
  }

  /**
   * Dispatches a proxied call to the interceptor when the method is listed in the
   * signature map; otherwise invokes it directly on the target.
   */
  @Override
  public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
    try {
      Set<Method> methods = signatureMap.get(method.getDeclaringClass());
      // The method must be intercepted.
      if (methods != null && methods.contains(method)) {
        return interceptor.intercept(new Invocation(target, method, args));
      }
      // Not intercepted: call the target directly.
      return method.invoke(target, args);
    } catch (Exception e) {
      throw ExceptionUtil.unwrapThrowable(e);
    }
  }

  /** Resolves each {@code @Signature} of the interceptor into a concrete {@link Method}. */
  private static Map<Class<?>, Set<Method>> getSignatureMap(Interceptor interceptor) {
    Intercepts intercepts = interceptor.getClass().getAnnotation(Intercepts.class);
    if (intercepts == null) {
      throw new IllegalArgumentException(
          "No @Intercepts annotation was found in interceptor " + interceptor.getClass().getName());
    }
    Map<Class<?>, Set<Method>> signatureMap = new java.util.HashMap<>();
    for (Signature sig : intercepts.value()) {
      try {
        Method method = sig.type().getMethod(sig.method(), sig.args());
        signatureMap.computeIfAbsent(sig.type(), k -> new java.util.HashSet<>()).add(method);
      } catch (NoSuchMethodException e) {
        throw new IllegalArgumentException(
            "Could not find method on " + sig.type() + " named " + sig.method() + ". Cause: " + e, e);
      }
    }
    return signatureMap;
  }

  /**
   * Collects the interfaces implemented directly by {@code type} or any of its superclasses
   * that appear as keys in {@code signatureMap}.
   */
  private static Class<?>[] getAllInterfaces(Class<?> type, Map<Class<?>, Set<Method>> signatureMap) {
    Set<Class<?>> interfaces = new java.util.LinkedHashSet<>();
    for (Class<?> current = type; current != null; current = current.getSuperclass()) {
      for (Class<?> candidate : current.getInterfaces()) {
        if (signatureMap.containsKey(candidate)) {
          interfaces.add(candidate);
        }
      }
    }
    return interfaces.toArray(new Class<?>[0]);
  }
}

分页插件PageInterceptor

PageInterceptor 拦截 Executor.query() 方法,把 RowBounds 中的 offset/limit 改写为 SQL 的 LIMIT 子句(MySQL 语法),从而实现物理分页:
@Intercepts({
  @Signature(type=Executor.class, method="query", args={
    MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class
  })
})
pulic class PageInterceptor implements Interceptor {
  public Object plugin(Object target){
    return Plugin.wrap(target, this);
  }

  public Object intercept(final Invocation invocation) throws Throwable {
    final Object[] queryArgs = invocation.getArgs();
    final MappedStatement ms = (MappedStatement) queryArgs[0];
    final Object parameter = queryArgs[1];
    final RowBounds rowBounds = (RowBounds) queryArgs[2];

    int offset = rowBound.getOffset();
    int limit = rowBound.getLimit();

    final BoundSql bondSql = mappedStatement.getBoundSql(parameter);
    final StringBuffer bufferSql = new StringBuffer(boundSql.getSql());
    String sql = getFormmatSql(bufferSql.toString().trim())

    sql = getPagingSql(sql, offset, limit)

    //需要重置RowBounds
    queryArgs[2] = new RowBounds(RowBounds.NOROWOFFSET, RowBound.NOROWLIMIT);
    //根据最新的SQL,创建新的MappedStatement
    queryArgs[0] = createMappedStatement(mappedStatement, boundSql, sql);
    return invocation.proceed();
  }

  public String getPagingSql(String sql, int offset, int limit){
    sql = sql.trim();
    boolean hasForUpdate = false;
    String forUpdatePart = "for update";
    if(sql.endsWith(forUpdatePart)){
      sql = sql.substring(0, sql.length()-forUpdatePart.length());
      hasForUpdate = true
    }

    StringBuffer result = new StringBuffer(sql.length())
    result.append(sql).append(" limit ");
    if(offset > 0){
      result.append(offset).append(",").append(limit);
    } else {
      result.append(limit)
    }

    if(hasForUpdate){
      result.append(" for update");
    }
    return reesult.toString();
  }
}