fix
wangpengfei
2023-08-26 731ea35ad6ce787231dd8a796f653a6d882415cb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
package ast
 
import (
    "bytes"
    "fmt"
    "go/ast"
    "go/parser"
    "go/printer"
    "go/token"
    "os"
    "path/filepath"
    "srm/global"
)
 
func RollBackAst(pk, model string) {
    RollGormBack(pk, model)
    RollRouterBack(pk, model)
}
 
func RollGormBack(pk, model string) {
 
    // 首先分析存在多少个ttt作为调用方的node块
    // 如果多个 仅仅删除对应块即可
    // 如果单个 那么还需要剔除import
    path := filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "initialize", "gorm.go")
    src, err := os.ReadFile(path)
    if err != nil {
        fmt.Println(err)
    }
    fileSet := token.NewFileSet()
    astFile, err := parser.ParseFile(fileSet, "", src, 0)
    if err != nil {
        fmt.Println(err)
    }
    var n *ast.CallExpr
    var k int = -1
    var pkNum = 0
    ast.Inspect(astFile, func(node ast.Node) bool {
        if node, ok := node.(*ast.CallExpr); ok {
            for i := range node.Args {
                pkOK := false
                modelOK := false
                ast.Inspect(node.Args[i], func(item ast.Node) bool {
                    if ii, ok := item.(*ast.Ident); ok {
                        if ii.Name == pk {
                            pkOK = true
                            pkNum++
                        }
                        if ii.Name == model {
                            modelOK = true
                        }
                    }
                    if pkOK && modelOK {
                        n = node
                        k = i
                    }
                    return true
                })
            }
        }
        return true
    })
    if k > 0 {
        n.Args = append(append([]ast.Expr{}, n.Args[:k]...), n.Args[k+1:]...)
    }
    if pkNum == 1 {
        var imI int = -1
        var gp *ast.GenDecl
        ast.Inspect(astFile, func(node ast.Node) bool {
            if gen, ok := node.(*ast.GenDecl); ok {
                for i := range gen.Specs {
                    if imspec, ok := gen.Specs[i].(*ast.ImportSpec); ok {
                        if imspec.Path.Value == "\"srm/model/"+pk+"\"" {
                            gp = gen
                            imI = i
                            return false
                        }
                    }
                }
            }
            return true
        })
 
        if imI > -1 {
            gp.Specs = append(append([]ast.Spec{}, gp.Specs[:imI]...), gp.Specs[imI+1:]...)
        }
    }
 
    var out []byte
    bf := bytes.NewBuffer(out)
    printer.Fprint(bf, fileSet, astFile)
    os.Remove(path)
    os.WriteFile(path, bf.Bytes(), 0666)
 
}
 
func RollRouterBack(pk, model string) {
 
    // 首先抓到所有的代码块结构 {}
    // 分析结构中是否存在一个变量叫做 pk+Router
    // 然后获取到代码块指针 对内部需要回滚的代码进行剔除
    path := filepath.Join(global.GVA_CONFIG.AutoCode.Root, global.GVA_CONFIG.AutoCode.Server, "initialize", "router.go")
    src, err := os.ReadFile(path)
    if err != nil {
        fmt.Println(err)
    }
    fileSet := token.NewFileSet()
    astFile, err := parser.ParseFile(fileSet, "", src, 0)
    if err != nil {
        fmt.Println(err)
    }
 
    var block *ast.BlockStmt
    ast.Inspect(astFile, func(node ast.Node) bool {
        if n, ok := node.(*ast.BlockStmt); ok {
            ast.Inspect(n, func(bNode ast.Node) bool {
                if in, ok := bNode.(*ast.Ident); ok {
                    if in.Name == pk+"Router" {
                        block = n
                        return false
                    }
                }
                return true
            })
            return true
        }
        return true
    })
    var k int
    for i := range block.List {
        if stmtNode, ok := block.List[i].(*ast.ExprStmt); ok {
            ast.Inspect(stmtNode, func(node ast.Node) bool {
                if n, ok := node.(*ast.Ident); ok {
                    if n.Name == "Init"+model+"Router" {
                        k = i
                        return false
                    }
                }
                return true
            })
        }
    }
 
    block.List = append(append([]ast.Stmt{}, block.List[:k]...), block.List[k+1:]...)
 
    if len(block.List) == 1 {
        // 说明这个块就没任何意义了
        block.List = nil
        // TODO 删除空的{}
    }
 
    var out []byte
    bf := bytes.NewBuffer(out)
    printer.Fprint(bf, fileSet, astFile)
    os.Remove(path)
    os.WriteFile(path, bf.Bytes(), 0666)
}