zhangzengfei
2023-08-14 8f750b461a4f442825e516016bf78d05ed66afcb
kingdee/query.go
@@ -1,8 +1,116 @@
package kingdee
import "fmt"
import (
   "encoding/json"
   "errors"
   "strings"
func QueryMsgHandle(data []byte) error {
   fmt.Println("recv msg ", string(data))
   "kingdee-dbapi/config"
   "kingdee-dbapi/logger"
   "kingdee-dbapi/nsqclient"
)
// 通用sql查询接口
type SqlQueryMsg struct {
   Key     string // 请求
   Command string
   Success bool
   Message string
   Result  []byte
}
func SqlQueryHandle(msg []byte) error {
   var query SqlQueryMsg
   if err := json.Unmarshal(msg, &query); err != nil {
      logger.Warn("解析请求失败, %s", err.Error())
      return err
   }
   var sql = query.Command
   logger.Debug("接收到查询请求,%s", sql)
   if !sqlCheck(sql) {
      query.Message = "危险的sql语句, 拒绝执行"
      logger.Warn(query.Message)
   } else {
      result, err := execSqlCommand(sql)
      if err != nil {
         query.Message = err.Error()
         logger.Warn("sql执行失败:%s", query.Message)
      } else {
         query.Result = result
         query.Success = true
         logger.Warn("sql执行完成.")
      }
   }
   replyData, _ := json.Marshal(query)
   ok := nsqclient.Produce(config.Options.SqlReplyTopic, replyData)
   logger.Warn("应答查询请求结果:%t, key:%s", ok, query.Key)
   return nil
}
func execSqlCommand(sql string) ([]byte, error) {
   var result []interface{}
   if db == nil {
      return nil, errors.New("数据库未连接")
   }
   rows, err := db.Raw(sql).Rows()
   if err != nil {
      return nil, err
   }
   var cols []string
   for rows.Next() {
      //先获取所有的column
      if cols == nil {
         cols, _ = rows.Columns()
      }
      //建立俩个interface数组,columnPointers中存在columns的地址
      columns := make([]interface{}, len(cols))
      columnPointers := make([]interface{}, len(cols))
      for i := 0; i < len(columns); i++ {
         //赋值地址
         columnPointers[i] = &columns[i]
      }
      //扫描结果
      err = rows.Scan(columnPointers...)
      if err != nil {
         return nil, err
      }
      m := make(map[string]interface{})
      for i, colName := range cols {
         val := columnPointers[i].(*interface{})
         m[colName] = *val
      }
      result = append(result, m)
   }
   rb, _ := json.Marshal(result)
   return rb, nil
}
// 简单过滤下sql语句,拒绝增删改操作
func sqlCheck(sql string) bool {
   var dangerousWords = []string{"INSERT", "UPDATE", "DELETE", "ALTER", "DROP", "DECLARE", "EXECUTE", "EXEC", "INTO", "TRANCATE"}
   var upperStr = strings.ToUpper(sql)
   for _, word := range dangerousWords {
      if strings.Contains(upperStr, word) {
         return false
      }
   }
   return true
}