zhangqian
2023-12-08 32e00f9438ed29fc26351f65cf7d98eefd1d838e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
package snowflake
 
import (
    "apsClient/pkg/logx"
    "errors"
    "fmt"
    "hash/fnv"
    "net"
    "strconv"
    "sync"
    "time"
)
 
var idGenerater *IdWorker
 
func init() {
    // 使用 LookupIP 获取主机的 IP 地址列表
    // 获取本机所有网络接口
    interfaces, err := net.Interfaces()
    if err != nil {
        logx.Errorf("snowflake InitWithIP error:%v", err)
        return
    }
 
    var ip string
    // 遍历所有网络接口
    for _, iface := range interfaces {
        // 排除一些特殊的接口,例如 loopback 接口
        if iface.Flags&net.FlagUp != 0 && iface.Flags&net.FlagLoopback == 0 {
            // 获取接口的所有地址
            addrs, err := iface.Addrs()
            if err != nil {
                logx.Errorf("snowflake InitWithIP error:%v", err)
                continue
            }
            // 遍历接口的所有地址
            for _, addr := range addrs {
                // 检查地址类型是否是 IP 地址
                if ipNet, ok := addr.(*net.IPNet); ok && !ipNet.IP.IsLoopback() {
                    // 判断 IP 地址的版本是 IPv4 还是 IPv6
                    if ipNet.IP.To4() != nil {
                        fmt.Printf("IPv4 Address: %s\n", ipNet.IP.String())
                        if ipNet.IP.String() != "127.0.0.1" {
                            ip = ipNet.IP.String()
                            goto getIpOK
                        }
                    } else {
                        fmt.Printf("IPv6 Address: %s\n", ipNet.IP.String())
                    }
                }
            }
        }
    }
getIpOK:
    if ip == "" {
        logx.Errorf("snowflake InitWithIP can not find Ip")
        return
    }
    ipNumber, err := ipToNumber(ip)
    if err != nil {
        logx.Errorf("snowflake can not generate, init error, ip to number error :%v", err)
        panic(fmt.Sprintf("snowflake can not generate, init error, ip to number error :%v", err))
    }
    idGenerater, err = NewIdWorker(ipNumber)
    if err != nil {
        logx.Errorf("snowflake can not generate, init error :%v", err)
        panic(fmt.Sprintf("snowflake can not generate, init error :%v", err))
    }
}
 
func ipToNumber(ip string) (int64, error) {
    // 将 IP 地址字符串解析为 net.IP 类型
    parsedIP := net.ParseIP(ip)
    if parsedIP == nil {
        return 0, fmt.Errorf("invalid IP address:%v", ip)
    }
 
    // 使用 FNV-1a 散列算法计算哈希值
    hash := fnv.New32a()
    hash.Write(parsedIP)
    return int64(hash.Sum32() % 1024), nil // 取余数确保结果在 1 到 1023 之间
}
 
const (
    CEpoch         = 1474802888000
    CWorkerIdBits  = 10 // Num of WorkerId Bits
    CSenquenceBits = 12 // Num of Sequence Bits
 
    CWorkerIdShift  = 12
    CTimeStampShift = 22
 
    CSequenceMask = 0xfff // equal as getSequenceMask()
    CMaxWorker    = 0x3ff // equal as getMaxWorkerId()
)
 
type IdWorker struct {
    workerId      int64
    lastTimeStamp int64
    sequence      int64
    maxWorkerId   int64
    lock          *sync.Mutex
}
 
func NewIdWorker(workerId int64) (iw *IdWorker, err error) {
    iw = new(IdWorker)
 
    iw.maxWorkerId = getMaxWorkerId()
 
    if workerId > iw.maxWorkerId || workerId < 0 {
        return nil, errors.New("worker not fit")
    }
    iw.workerId = workerId
    iw.lastTimeStamp = -1
    iw.sequence = 0
    iw.lock = new(sync.Mutex)
    return iw, nil
}
 
func getMaxWorkerId() int64 {
    return -1 ^ -1<<CWorkerIdBits
}
 
func getSequenceMask() int64 {
    return -1 ^ -1<<CSenquenceBits
}
 
// return in ms
func (iw *IdWorker) timeGen() int64 {
    return time.Now().UnixNano() / 1000 / 1000
}
 
func (iw *IdWorker) timeReGen(last int64) int64 {
    ts := time.Now().UnixNano() / 1000 / 1000
    for {
        if ts <= last {
            ts = iw.timeGen()
        } else {
            break
        }
    }
    return ts
}
 
func (iw *IdWorker) NextId() (ts int64, err error) {
    iw.lock.Lock()
    defer iw.lock.Unlock()
    ts = iw.timeGen()
    if ts == iw.lastTimeStamp {
        iw.sequence = (iw.sequence + 1) & CSequenceMask
        if iw.sequence == 0 {
            ts = iw.timeReGen(ts)
        }
    } else {
        iw.sequence = 0
    }
 
    if ts < iw.lastTimeStamp {
        err = errors.New("Clock moved backwards, Refuse gen id")
        return 0, err
    }
    iw.lastTimeStamp = ts
    ts = (ts-CEpoch)<<CTimeStampShift | iw.workerId<<CWorkerIdShift | iw.sequence
    return ts, nil
}
 
func ParseId(id int64) (t time.Time, ts int64, workerId int64, seq int64) {
    seq = id & CSequenceMask
    workerId = (id >> CWorkerIdShift) & CMaxWorker
    ts = (id >> CTimeStampShift) + CEpoch
    t = time.Unix(ts/1000, (ts%1000)*1000000)
    return
}
 
func GenerateID() int64 {
start:
    id, err := idGenerater.NextId()
    if err != nil {
        goto start
    }
    return id
}
 
func GenerateIdStr() string {
start:
    id, err := idGenerater.NextId()
    if err != nil {
        goto start
    }
    return strconv.FormatInt(id, 10)
}