| | |
| | | package kingdee |
| | | |
| | | import "fmt" |
| | | import ( |
| | | "encoding/json" |
| | | "errors" |
| | | "strings" |
| | | |
| | | func QueryMsgHandle(data []byte) error { |
| | | fmt.Println("recv msg ", string(data)) |
| | | "kingdee-dbapi/config" |
| | | "kingdee-dbapi/logger" |
| | | "kingdee-dbapi/nsqclient" |
| | | ) |
| | | |
// SqlQueryMsg is the message of the generic SQL query interface.
//
// The same structure travels in both directions: a request carries Key and
// Command, and the reply is the same value re-encoded with the outcome fields
// filled in. No json tags are declared, so the wire field names are the Go
// field names ("Key", "Command", "Success", "Message", "Result").
type SqlQueryMsg struct {
	// Key is the request key; it is echoed back unchanged in the reply
	// (presumably so the requester can correlate replies — confirm with the
	// producer). Command is the SQL text to execute.
	Key, Command string

	// Outcome fields, filled in by the handler before the reply is sent.
	Success bool   // the statement was accepted and executed without error
	Message string // rejection reason or database error text
	Result  []byte // JSON array of row objects; encoding/json emits []byte as base64
}
| | | |
| | | func SqlQueryHandle(msg []byte) error { |
| | | var query SqlQueryMsg |
| | | |
| | | if err := json.Unmarshal(msg, &query); err != nil { |
| | | logger.Warn("解析请求失败, %s", err.Error()) |
| | | return err |
| | | } |
| | | |
| | | var sql = query.Command |
| | | logger.Debug("接收到查询请求,%s", sql) |
| | | |
| | | if !sqlCheck(sql) { |
| | | query.Message = "危险的sql语句, 拒绝执行" |
| | | logger.Warn(query.Message) |
| | | } else { |
| | | result, err := execSqlCommand(sql) |
| | | if err != nil { |
| | | query.Message = err.Error() |
| | | logger.Warn("sql执行失败:%s", query.Message) |
| | | } else { |
| | | query.Result = result |
| | | query.Success = true |
| | | logger.Warn("sql执行完成.") |
| | | } |
| | | } |
| | | |
| | | replyData, _ := json.Marshal(query) |
| | | ok := nsqclient.Produce(config.Options.SqlReplyTopic, replyData) |
| | | logger.Warn("应答查询请求结果:%t, key:%s", ok, query.Key) |
| | | |
| | | return nil |
| | | } |
| | | |
| | | func execSqlCommand(sql string) ([]byte, error) { |
| | | var result []interface{} |
| | | |
| | | if db == nil { |
| | | return nil, errors.New("数据库未连接") |
| | | } |
| | | |
| | | rows, err := db.Raw(sql).Rows() |
| | | if err != nil { |
| | | return nil, err |
| | | } |
| | | |
| | | var cols []string |
| | | for rows.Next() { |
| | | //先获取所有的column |
| | | if cols == nil { |
| | | cols, _ = rows.Columns() |
| | | } |
| | | |
| | | //建立俩个interface数组,columnPointers中存在columns的地址 |
| | | columns := make([]interface{}, len(cols)) |
| | | columnPointers := make([]interface{}, len(cols)) |
| | | for i := 0; i < len(columns); i++ { |
| | | //赋值地址 |
| | | columnPointers[i] = &columns[i] |
| | | } |
| | | |
| | | //扫描结果 |
| | | err = rows.Scan(columnPointers...) |
| | | if err != nil { |
| | | return nil, err |
| | | } |
| | | |
| | | m := make(map[string]interface{}) |
| | | for i, colName := range cols { |
| | | val := columnPointers[i].(*interface{}) |
| | | m[colName] = *val |
| | | } |
| | | |
| | | result = append(result, m) |
| | | } |
| | | |
| | | rb, _ := json.Marshal(result) |
| | | |
| | | return rb, nil |
| | | } |
| | | |
// sqlCheck is a coarse filter that rejects statements which could modify data
// or schema or run procedures. It reports whether sql is allowed to run.
//
// The check is a case-insensitive substring match against a keyword
// blacklist. It therefore also rejects harmless queries that merely contain a
// keyword inside an identifier or literal; for example, a column named
// "LastUpdateTime" contains "UPDATE". It is not a substitute for running
// queries under a read-only database account.
func sqlCheck(sql string) bool {
	// "EXEC" already matches "EXECUTE"; both are kept for readability.
	// "INTO" presumably targets SELECT ... INTO (table creation) as well as
	// INSERT INTO.
	// "TRUNCATE" was previously misspelled "TRANCATE", which let TRUNCATE TABLE
	// through.
	dangerousWords := []string{"INSERT", "UPDATE", "DELETE", "ALTER", "DROP", "DECLARE", "EXECUTE", "EXEC", "INTO", "TRUNCATE"}

	upper := strings.ToUpper(sql)
	for _, word := range dangerousWords {
		if strings.Contains(upper, word) {
			return false
		}
	}
	return true
}