add
wangpengfei
2023-08-25 9f98932726cb41697fabccbbbd876205e7255c95
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
package utils
 
import (
    "errors"
    "fmt"
    "go/ast"
    "go/parser"
    "go/token"
    "os"
    "strings"
)
 
//@author: [LeonardWang](https://github.com/WangLeonard)
//@function: AutoInjectionCode
//@description: 向文件中固定注释位置写入代码
//@param: filepath string, funcName string, codeData string
//@return: error
 
const (
    startComment = "Code generated by github.com/flipped-aurora/gin-vue-admin/server Begin; DO NOT EDIT."
    endComment   = "Code generated by github.com/flipped-aurora/gin-vue-admin/server End; DO NOT EDIT."
)
 
//@author: [LeonardWang](https://github.com/WangLeonard)
//@function: AutoInjectionCode
//@description: 向文件中固定注释位置写入代码
//@param: filepath string, funcName string, codeData string
//@return: error
 
func AutoInjectionCode(filepath string, funcName string, codeData string) error {
    srcData, err := os.ReadFile(filepath)
    if err != nil {
        return err
    }
    srcDataLen := len(srcData)
    fset := token.NewFileSet()
    fparser, err := parser.ParseFile(fset, filepath, srcData, parser.ParseComments)
    if err != nil {
        return err
    }
    codeData = strings.TrimSpace(codeData)
    codeStartPos := -1
    codeEndPos := srcDataLen
    var expectedFunction *ast.FuncDecl
 
    startCommentPos := -1
    endCommentPos := srcDataLen
 
    // 如果指定了函数名,先寻找对应函数
    if funcName != "" {
        for _, decl := range fparser.Decls {
            if funDecl, ok := decl.(*ast.FuncDecl); ok && funDecl.Name.Name == funcName {
                expectedFunction = funDecl
                codeStartPos = int(funDecl.Body.Lbrace)
                codeEndPos = int(funDecl.Body.Rbrace)
                break
            }
        }
    }
 
    // 遍历所有注释
    for _, comment := range fparser.Comments {
        if int(comment.Pos()) > codeStartPos && int(comment.End()) <= codeEndPos {
            if startComment != "" && strings.Contains(comment.Text(), startComment) {
                startCommentPos = int(comment.Pos()) // Note: Pos is the second '/'
            }
            if endComment != "" && strings.Contains(comment.Text(), endComment) {
                endCommentPos = int(comment.Pos()) // Note: Pos is the second '/'
            }
        }
    }
 
    if endCommentPos == srcDataLen {
        return fmt.Errorf("comment:%s not found", endComment)
    }
 
    // 在指定函数名,且函数中startComment和endComment都存在时,进行区间查重
    if (codeStartPos != -1 && codeEndPos <= srcDataLen) && (startCommentPos != -1 && endCommentPos != srcDataLen) && expectedFunction != nil {
        if exist := checkExist(&srcData, startCommentPos, endCommentPos, expectedFunction.Body, codeData); exist {
            fmt.Printf("文件 %s 待插入数据 %s 已存在\n", filepath, codeData)
            return nil // 这里不需要返回错误?
        }
    }
 
    // 两行注释中间没有换行时,会被认为是一条Comment
    if startCommentPos == endCommentPos {
        endCommentPos = startCommentPos + strings.Index(string(srcData[startCommentPos:]), endComment)
        for srcData[endCommentPos] != '/' {
            endCommentPos--
        }
    }
 
    // 记录"//"之前的空字符,保持写入后的格式一致
    tmpSpace := make([]byte, 0, 10)
    for tmp := endCommentPos - 2; tmp >= 0; tmp-- {
        if srcData[tmp] != '\n' {
            tmpSpace = append(tmpSpace, srcData[tmp])
        } else {
            break
        }
    }
 
    reverseSpace := make([]byte, 0, len(tmpSpace))
    for index := len(tmpSpace) - 1; index >= 0; index-- {
        reverseSpace = append(reverseSpace, tmpSpace[index])
    }
 
    // 插入数据
    indexPos := endCommentPos - 1
    insertData := []byte(append([]byte(codeData+"\n"), reverseSpace...))
 
    remainData := append([]byte{}, srcData[indexPos:]...)
    srcData = append(append(srcData[:indexPos], insertData...), remainData...)
 
    // 写回数据
    return os.WriteFile(filepath, srcData, 0o600)
}
 
func checkExist(srcData *[]byte, startPos int, endPos int, blockStmt *ast.BlockStmt, target string) bool {
    for _, list := range blockStmt.List {
        switch stmt := list.(type) {
        case *ast.ExprStmt:
            if callExpr, ok := stmt.X.(*ast.CallExpr); ok &&
                int(callExpr.Pos()) > startPos && int(callExpr.End()) < endPos {
                text := string((*srcData)[int(callExpr.Pos()-1):int(callExpr.End())])
                key := strings.TrimSpace(text)
                if key == target {
                    return true
                }
            }
        case *ast.BlockStmt:
            if checkExist(srcData, startPos, endPos, stmt, target) {
                return true
            }
        case *ast.AssignStmt:
            // 为 model 中的代码进行检查
            if len(stmt.Rhs) > 0 {
                if callExpr, ok := stmt.Rhs[0].(*ast.CallExpr); ok {
                    for _, arg := range callExpr.Args {
                        if int(arg.Pos()) > startPos && int(arg.End()) < endPos {
                            text := string((*srcData)[int(arg.Pos()-1):int(arg.End())])
                            key := strings.TrimSpace(text)
                            if key == target {
                                return true
                            }
                        }
                    }
                }
            }
        }
    }
    return false
}
 
func AutoClearCode(filepath string, codeData string) error {
    srcData, err := os.ReadFile(filepath)
    if err != nil {
        return err
    }
    srcData, err = cleanCode(codeData, string(srcData))
    if err != nil {
        return err
    }
    return os.WriteFile(filepath, srcData, 0o600)
}
 
func cleanCode(clearCode string, srcData string) ([]byte, error) {
    bf := make([]rune, 0, 1024)
    for i, v := range srcData {
        if v == '\n' {
            if strings.TrimSpace(string(bf)) == clearCode {
                return append([]byte(srcData[:i-len(bf)]), []byte(srcData[i+1:])...), nil
            }
            bf = (bf)[:0]
            continue
        }
        bf = append(bf, v)
    }
    return []byte(srcData), errors.New("未找到内容")
}