예제 #1
0
 /**
  * Finds the position of the partition (sharding) column in the INSERT statement's column list.
  *
  * <p>Each listed column is compared case-insensitively against {@code partitionColumn} after
  * its backquotes are removed.
  *
  * @param insertStmt the parsed INSERT statement whose explicit column list is searched
  * @param partitionColumn name of the partition column to look for
  * @return zero-based index of the first matching column, or -1 if the column is not listed
  */
 private int getShardingColIndex(MySqlInsertStatement insertStmt, String partitionColumn) {
   int position = 0;
   for (Object listedColumn : insertStmt.getColumns()) {
     String columnName = StringUtil.removeBackquote(listedColumn.toString());
     if (partitionColumn.equalsIgnoreCase(columnName)) {
       // Sharding column found: report its position in the column list.
       return position;
     }
     position++;
   }
   return -1;
 }
예제 #2
0
  /**
   * Records table statistics for an INSERT statement.
   *
   * <p>Switches the visitor into INSERT mode and resets the alias map. When the target table is a
   * plain identifier, it becomes the current table, its insert counter is incremented, and it is
   * registered in the alias map under its own name and under the statement alias (if any). The
   * column list, VALUES rows, source query and ON DUPLICATE KEY UPDATE items are then visited.
   *
   * @param x the INSERT statement being visited
   * @return always {@code false}; the children are accepted explicitly in this method
   */
  @Override
  public boolean visit(MySqlInsertStatement x) {
    setMode(x, Mode.Insert);
    setAliasMap();

    if (x.getTableName() instanceof SQLIdentifierExpr) {
      final String tableIdent = ((SQLIdentifierExpr) x.getTableName()).getName();
      setCurrentTable(x, tableIdent);
      getTableStat(tableIdent).incrementInsertCount();

      final Map<String, String> aliases = getAliasMap();
      if (aliases != null) {
        final String alias = x.getAlias();
        if (alias != null) {
          aliases.put(alias, tableIdent);
        }
        aliases.put(tableIdent, tableIdent);
      }
    }

    accept(x.getColumns());
    accept(x.getValuesList());
    accept(x.getQuery());
    accept(x.getDuplicateKeyUpdate());
    return false;
  }
예제 #3
0
  /**
   * Parses an INSERT statement and registers the data needed to route it.
   *
   * <p>Considerations are child tables, batch inserts and whether the table is sharded. The table
   * name (backquotes removed, upper-cased) is always added to the context. For a partition table
   * only a single-row INSERT with an explicit column list is supported; it is handed to {@code
   * parserSingleInsert}. Multi-row VALUES lists and INSERT ... SELECT into a partition table are
   * rejected. Tables without a partition column get no further processing here.
   *
   * @param schema schema that must define the target table
   * @param rrs route result set being built
   * @param stmt the parsed statement; must be a {@link MySqlInsertStatement}
   * @throws IllegalArgumentException if the table is not defined in the schema, or a partition
   *     table INSERT has no column list
   * @throws SQLNonTransientException if a partition table INSERT inserts multiple rows, or if
   *     single-row parsing fails
   */
  @Override
  public void statementParse(SchemaConfig schema, RouteResultset rrs, SQLStatement stmt)
      throws SQLNonTransientException {
    MySqlInsertStatement insert = (MySqlInsertStatement) stmt;
    String tableName =
        StringUtil.removeBackquote(insert.getTableName().getSimpleName()).toUpperCase();

    ctx.addTable(tableName);

    // TODO: schemas/tables that are not sharded at all could be routed directly via
    // RouterUtil.routeForTableMeta; this is not supported at this stage.

    TableConfig tc = schema.getTables().get(tableName);
    if (Objects.isNull(tc)) {
      String msg = "can't find table : " + tableName + " define in schema : " + schema.getName();
      logger.warn(msg);
      throw new IllegalArgumentException(msg);
    }

    // Only single-dimension partitioning is handled; two-dimension partitioning may come later.
    String partitionColumn = tc.getPartitionColumn();
    if (partitionColumn == null) {
      // Not a partition table: nothing more to compute here.
      return;
    }

    // A partition table needs an explicit column list, otherwise the sharding value cannot be
    // located in the VALUES row.
    if (insert.getColumns() == null || insert.getColumns().isEmpty()) {
      logger.error("sql : {} , insert must provide ColumnList", stmt);
      throw new IllegalArgumentException("partition table, insert must provide ColumnList");
    }

    if (isMultiInsert(insert)) {
      // Fail fast. Merely logging and returning would leave the statement without any route
      // calculation unit, so the "forbidden" insert would not actually be rejected.
      String msg = "multi insert is forbidden, partition table : " + tableName;
      logger.error("sql : {} , {}", stmt, msg);
      throw new SQLNonTransientException(msg);
    }

    parserSingleInsert(schema, rrs, partitionColumn, tableName, insert);
  }
예제 #4
0
  /**
   * Registers the route calculation unit for a single-row (non-batch) INSERT into a partition
   * table. This is currently the only supported INSERT form.
   *
   * <p>The partition column is located in the explicit column list (case-insensitively, with
   * backquotes removed). The value at the same position in the VALUES row is added as a sharding
   * expression. An {@code ON DUPLICATE KEY UPDATE} clause that assigns the partition column is
   * rejected, because it would change the sharding key of an existing row.
   *
   * @param schema schema of the table (not used by this method)
   * @param rrs route result set being built (not used by this method)
   * @param partitionColumn name of the table's partition column
   * @param tableName upper-cased table name used for the sharding expression
   * @param insertStmt the parsed single-row INSERT statement
   * @throws SQLNonTransientException if the partition column is not in the column list, the
   *     VALUES row has no value at its position, or ON DUPLICATE KEY UPDATE assigns it
   */
  private void parserSingleInsert(
      SchemaConfig schema,
      RouteResultset rrs,
      String partitionColumn,
      String tableName,
      MySqlInsertStatement insertStmt)
      throws SQLNonTransientException {
    boolean isFound = false;
    // Use the sharding key as the route calculation unit.
    for (int i = 0; i < insertStmt.getColumns().size(); i++) {
      String column = StringUtil.removeBackquote(insertStmt.getColumns().get(i).toString());
      if (!partitionColumn.equalsIgnoreCase(column)) {
        continue;
      }
      // Found the sharding column.
      isFound = true;

      // A VALUES row shorter than the column list would otherwise surface as an
      // IndexOutOfBoundsException instead of a proper SQL error.
      if (insertStmt.getValues() == null || i >= insertStmt.getValues().getValues().size()) {
        String msg =
            "bad insert sql (no value for sharding column:" + partitionColumn + ")," + insertStmt;
        logger.warn(msg);
        throw new SQLNonTransientException(msg);
      }

      String value =
          StringUtil.removeBackquote(insertStmt.getValues().getValues().get(i).toString());

      RouteCalculateUnit routeCalculateUnit = new RouteCalculateUnit();
      routeCalculateUnit.addShardingExpr(tableName, column, value);
      ctx.addRouteCalculateUnit(routeCalculateUnit);

      // Only a single sharding key is supported, so stop at the first match.
      break;
    }
    if (!isFound) {
      // A partition table INSERT must provide its sharding column.
      String msg =
          "bad insert sql (sharding column:" + partitionColumn + " not provided)," + insertStmt;
      logger.warn(msg);
      throw new SQLNonTransientException(msg);
    }

    // INSERT ... ON DUPLICATE KEY UPDATE must not assign the partition column, e.g.
    //   INSERT INTO t (a,b,c) VALUES (1,2,3) ON DUPLICATE KEY UPDATE b=VALUES(b);
    //   INSERT INTO t (a,b,c) VALUES (1,2,3) ON DUPLICATE KEY UPDATE c=c+1;
    // The comparison is case-insensitive, matching the column lookup above. Previously the
    // assigned column was upper-cased and compared with equals(), which missed a
    // partition column configured in another case.
    List<SQLExpr> updateList = insertStmt.getDuplicateKeyUpdate();
    if (updateList != null) {
      for (SQLExpr expr : updateList) {
        SQLBinaryOpExpr opExpr = (SQLBinaryOpExpr) expr;
        String column = StringUtil.removeBackquote(opExpr.getLeft().toString());
        if (partitionColumn.equalsIgnoreCase(column)) {
          String msg = "partition key can't be updated: " + tableName + " -> " + partitionColumn;
          logger.warn(msg);
          throw new SQLNonTransientException(msg);
        }
      }
    }
  }
예제 #5
0
  /**
   * Routes and dispatches an INSERT into an ER child table.
   *
   * <p>If the table named in {@code origSQL} is configured as a child table, the statement is
   * parsed and the join-key value is extracted. The insert is first routed by the ER parent's
   * partition key. If that yields no route, the root parent's data nodes are queried
   * asynchronously for the node that holds the parent row, and the insert is sent there. Errors on
   * the asynchronous path are written back to the client connection.
   *
   * @param schema schema that owns the table
   * @param origSQL original INSERT statement text
   * @param sc client connection whose session executes the insert
   * @return {@code true} if the table is a child table and the insert was dispatched here; {@code
   *     false} if the table is not a child table (nothing is done)
   * @throws SQLNonTransientException if the join key is not in the column list, no value exists
   *     at its position, or the statement inserts multiple rows
   */
  public static boolean processERChildTable(
      final SchemaConfig schema, final String origSQL, final ServerConnection sc)
      throws SQLNonTransientException {
    String tableName = StringUtil.getTableName(origSQL).toUpperCase();
    final TableConfig tc = schema.getTables().get(tableName);

    if (null == tc || !tc.isChildTable()) {
      return false;
    }

    final RouteResultset rrs = new RouteResultset(origSQL, ServerParse.INSERT);
    String joinKey = tc.getJoinKey();
    MySqlInsertStatement insertStmt =
        (MySqlInsertStatement) (new MySqlStatementParser(origSQL)).parseInsert();
    int joinKeyIndex = getJoinKeyIndex(insertStmt.getColumns(), joinKey);

    if (joinKeyIndex == -1) {
      String inf = "joinKey not provided :" + tc.getJoinKey() + "," + insertStmt;
      LOGGER.warn(inf);
      throw new SQLNonTransientException(inf);
    }
    if (isMultiInsert(insertStmt)) {
      String msg = "ChildTable multi insert not provided";
      LOGGER.warn(msg);
      throw new SQLNonTransientException(msg);
    }
    // A VALUES row shorter than the column list would otherwise fail below with an
    // IndexOutOfBoundsException instead of a proper SQL error.
    if (insertStmt.getValues() == null
        || joinKeyIndex >= insertStmt.getValues().getValues().size()) {
      String msg = "joinKey value not provided :" + tc.getJoinKey() + "," + insertStmt;
      LOGGER.warn(msg);
      throw new SQLNonTransientException(msg);
    }

    String joinKeyVal = insertStmt.getValues().getValues().get(joinKeyIndex).toString();
    String sql = insertStmt.toString();

    // Try to route by the ER parent's partition key.
    RouteResultset theRrs = RouterUtil.routeByERParentKey(sql, rrs, tc, joinKeyVal);
    if (theRrs != null) {
      // Execute the result set the router returned rather than assuming it mutated rrs in place.
      // The asynchronous path below likewise executes the value returned by routeToSingleNode.
      theRrs.setFinishedRoute(true);
      sc.getSession2().execute(theRrs, ServerParse.INSERT);
      return true;
    }

    // Otherwise locate the root parent's data node by querying it.
    // NOTE(review): joinKeyVal is the literal taken verbatim from the client's SQL and is
    // concatenated unescaped into this lookup query; confirm this is acceptable.
    final String findRootTBSql = tc.getLocateRTableKeySql().toLowerCase() + joinKeyVal;
    if (LOGGER.isDebugEnabled()) {
      LOGGER.debug("find root parent's node sql " + findRootTBSql);
    }

    ListenableFuture<String> listenableFuture =
        MycatServer.getInstance()
            .getListeningExecutorService()
            .submit(
                new Callable<String>() {
                  @Override
                  public String call() throws Exception {
                    FetchStoreNodeOfChildTableHandler fetchHandler =
                        new FetchStoreNodeOfChildTableHandler();
                    return fetchHandler.execute(
                        schema.getName(), findRootTBSql, tc.getRootParent().getDataNodes());
                  }
                });

    Futures.addCallback(
        listenableFuture,
        new FutureCallback<String>() {
          @Override
          public void onSuccess(String result) {
            if (Strings.isNullOrEmpty(result)) {
              LOGGER.warn(
                  sc.getSession2()
                      + origSQL
                      + " err:"
                      + "can't find (root) parent sharding node for sql:"
                      + origSQL);
              sc.writeErrMessage(
                  ErrorCode.ER_PARSE_ERROR,
                  "can't find (root) parent sharding node for sql:" + origSQL);
              return;
            }

            if (LOGGER.isDebugEnabled()) {
              LOGGER.debug(
                  "found partion node for child table to insert " + result + " sql :" + origSQL);
            }

            RouteResultset executeRrs = RouterUtil.routeToSingleNode(rrs, result, origSQL);
            sc.getSession2().execute(executeRrs, ServerParse.INSERT);
          }

          @Override
          public void onFailure(Throwable t) {
            String context = sc.getSession2() + origSQL;
            // Pass the throwable so the stack trace is not lost; only its message reaches the
            // client.
            LOGGER.warn(context + " err:" + t.getMessage(), t);
            sc.writeErrMessage(ErrorCode.ER_PARSE_ERROR, t.getMessage() + " " + context);
          }
        },
        MycatServer.getInstance().getListeningExecutorService());
    return true;
  }
예제 #6
0
 /**
  * Tells whether an INSERT writes more than one row: either a multi-row VALUES list
  * ({@code insert into ... values (), () ...}) or an {@code insert into ... select ...}.
  *
  * @param insertStmt the parsed INSERT statement
  * @return {@code true} if the VALUES list has more than one row or a source query is present
  */
 private static boolean isMultiInsert(MySqlInsertStatement insertStmt) {
   boolean hasSeveralValueRows =
       insertStmt.getValuesList() != null && insertStmt.getValuesList().size() > 1;
   return hasSeveralValueRows || insertStmt.getQuery() != null;
 }
예제 #7
0
  /**
   * Prints a MySQL INSERT statement.
   *
   * <p>The output is built in this order:
   * <ul>
   *   <li>{@code INSERT} followed by any LOW_PRIORITY / DELAYED / HIGH_PRIORITY / IGNORE
   *       modifiers;
   *   <li>{@code INTO} and the target table;
   *   <li>the optional column list (indented, starting a new line after every five columns);
   *   <li>the optional VALUES rows;
   *   <li>the optional source query;
   *   <li>the optional ON DUPLICATE KEY UPDATE assignments.
   * </ul>
   *
   * @param x the INSERT statement to print
   * @return always {@code false}; all children are printed explicitly here
   */
  @Override
  public boolean visit(MySqlInsertStatement x) {
    print("INSERT ");
    if (x.isLowPriority()) {
      print("LOW_PRIORITY ");
    }
    if (x.isDelayed()) {
      print("DELAYED ");
    }
    if (x.isHighPriority()) {
      print("HIGH_PRIORITY ");
    }
    if (x.isIgnore()) {
      print("IGNORE ");
    }

    print("INTO ");
    x.getTableName().accept(this);

    if (!x.getColumns().isEmpty()) {
      printInsertColumnsWrapped(x);
    }

    final int rowCount = x.getValuesList().size();
    if (rowCount > 0) {
      println();
      print("VALUES");
      println();
      for (int row = 0; row < rowCount; row++) {
        if (row > 0) {
          print(", ");
        }
        x.getValuesList().get(row).accept(this);
      }
    }

    if (x.getQuery() != null) {
      print(" ");
      x.getQuery().accept(this);
    }

    if (!x.getDuplicateKeyUpdate().isEmpty()) {
      print(" ON DUPLICATE KEY UPDATE ");
      printAndAccept(x.getDuplicateKeyUpdate(), ", ");
    }
    return false;
  }

  /** Prints the indented "(c1, c2, ...)" column list of an INSERT, wrapping after every five. */
  private void printInsertColumnsWrapped(MySqlInsertStatement x) {
    incrementIndent();
    println();
    print("(");
    final int columnCount = x.getColumns().size();
    for (int col = 0; col < columnCount; col++) {
      if (col > 0) {
        if (col % 5 == 0) {
          println();
        }
        print(", ");
      }
      x.getColumns().get(col).accept(this);
    }
    print(")");
    decrementIndent();
  }