package vrrp
|
|
import (
|
"fmt"
|
"net"
|
"syscall"
|
"time"
|
|
"basic.com/valib/logger.git"
|
"github.com/mdlayher/arp"
|
"github.com/mdlayher/ndp"
|
)
|
|
type IPConnection interface {
|
WriteMessage(*VRRPPacket) error
|
ReadMessage() (*VRRPPacket, error)
|
}
|
|
type AddrAnnouncer interface {
|
AnnounceAll(vr *VirtualRouter) error
|
}
|
|
type IPv4AddrAnnouncer struct {
|
ARPClient *arp.Client
|
}
|
|
type IPv6AddrAnnouncer struct {
|
con *ndp.Conn
|
}
|
|
func NewIPIPv6AddrAnnouncer(nif *net.Interface) *IPv6AddrAnnouncer {
|
var con, ip, errOfMakeNDPCon = ndp.Dial(nif, ndp.LinkLocal)
|
if errOfMakeNDPCon != nil {
|
logger.Fatal("NewIPv6AddrAnnouncer: ", errOfMakeNDPCon)
|
}
|
logger.Info("NDP client initialized, working on %v, source IP ", nif.Name, ip)
|
return &IPv6AddrAnnouncer{con: con}
|
}
|
|
func (nd *IPv6AddrAnnouncer) AnnounceAll(vr *VirtualRouter) error {
|
for key := range vr.protectedIPaddrs {
|
var multicastgroup, errOfParseMulticastGroup = ndp.SolicitedNodeMulticast(net.IP(key[:]))
|
if errOfParseMulticastGroup != nil {
|
logger.Error("IPv6AddrAnnouncer.AnnounceAll: ", errOfParseMulticastGroup)
|
return errOfParseMulticastGroup
|
} else {
|
//send unsolicited NeighborAdvertisement to refresh link layer address cache
|
var msg = &ndp.NeighborAdvertisement{
|
Override: true,
|
TargetAddress: net.IP(key[:]),
|
Options: []ndp.Option{
|
&ndp.LinkLayerAddress{
|
Direction: ndp.Source,
|
Addr: vr.netInterface.HardwareAddr,
|
},
|
},
|
}
|
if errOfWrite := nd.con.WriteTo(msg, nil, multicastgroup); errOfWrite != nil {
|
logger.Error("IPv6AddrAnnouncer.AnnounceAll: ", errOfWrite)
|
return errOfWrite
|
} else {
|
logger.Info("send unsolicited neighbor advertisement for ", net.IP(key[:]))
|
}
|
}
|
|
}
|
|
return nil
|
}
|
|
//makeGratuitousPacket make gratuitous ARP packet with out payload
|
func (ar *IPv4AddrAnnouncer) makeGratuitousPacket() *arp.Packet {
|
var packet arp.Packet
|
packet.HardwareType = 1 //ethernet10m
|
packet.ProtocolType = 0x0800 //IPv4
|
packet.HardwareAddrLength = 6
|
packet.IPLength = 4
|
packet.Operation = 2 //response
|
return &packet
|
}
|
|
//AnnounceAll send gratuitous ARP response for all protected IPv4 addresses
|
func (ar *IPv4AddrAnnouncer) AnnounceAll(vr *VirtualRouter) error {
|
if errofSetDealLine := ar.ARPClient.SetWriteDeadline(time.Now().Add(500 * time.Microsecond)); errofSetDealLine != nil {
|
return fmt.Errorf("IPv4AddrAnnouncer.AnnounceAll: %v", errofSetDealLine)
|
}
|
var packet = ar.makeGratuitousPacket()
|
for k := range vr.protectedIPaddrs {
|
packet.SenderHardwareAddr = vr.netInterface.HardwareAddr
|
packet.SenderIP = net.IP(k[:]).To4()
|
packet.TargetHardwareAddr = BaordcastHADDR
|
packet.TargetIP = net.IP(k[:]).To4()
|
logger.Info("send gratuitous arp for ", net.IP(k[:]))
|
if errofsendarp := ar.ARPClient.WriteTo(packet, BaordcastHADDR); errofsendarp != nil {
|
return fmt.Errorf("IPv4AddrAnnouncer.AnnounceAll: %v", errofsendarp)
|
}
|
}
|
return nil
|
}
|
|
func NewIPv4AddrAnnouncer(nif *net.Interface) *IPv4AddrAnnouncer {
|
if aar, errofDialARP := arp.Dial(nif); errofDialARP != nil {
|
panic(errofDialARP)
|
} else {
|
logger.Debug("IPv4 addresses announcer created")
|
return &IPv4AddrAnnouncer{ARPClient: aar}
|
}
|
|
}
|
|
// IPv4Con is the IPv4 implementation of IPConnection. It sends on one raw
// IP connection and receives on a second one.
type IPv4Con struct {
	// buffer is the receive buffer reused by every ReadMessage call.
	buffer []byte
	// remote is the destination of outgoing packets; local is the address
	// given to NewIPv4Conn.
	remote, local net.IP
	// SendCon is used by WriteMessage, ReceiveCon by ReadMessage.
	SendCon, ReceiveCon *net.IPConn
}
|
|
// IPv6Con is the IPv6 implementation of the VRRP transport. A single raw IP
// connection is used for both sending and receiving.
type IPv6Con struct {
	// buffer receives the payload and oob the socket control messages;
	// both are reused by every ReadMessage call.
	buffer, oob []byte
	// remote is the destination of outgoing packets; local is the address
	// given to NewIPv6Con.
	remote, local net.IP
	// Con is the underlying raw IP connection.
	Con *net.IPConn
}
|
|
func ipConnection(local, remote net.IP) (*net.IPConn, error) {
|
|
var conn *net.IPConn
|
var errOfListenIP error
|
//redundant
|
//todo simplify here
|
if local.IsLinkLocalUnicast() {
|
var itf, errOfFind = findInterfacebyIP(local)
|
if errOfFind != nil {
|
return nil, fmt.Errorf("ipConnection: can't find zone info of %v", local)
|
}
|
conn, errOfListenIP = net.ListenIP("ip:112", &net.IPAddr{IP: local, Zone: itf.Name})
|
} else {
|
conn, errOfListenIP = net.ListenIP("ip:112", &net.IPAddr{IP: local})
|
}
|
if errOfListenIP != nil {
|
return nil, errOfListenIP
|
}
|
var fd, errOfGetFD = conn.File()
|
if errOfGetFD != nil {
|
return nil, errOfGetFD
|
}
|
defer fd.Close()
|
if remote.To4() != nil {
|
//IPv4 mode
|
//set hop limit
|
if errOfSetHopLimit := syscall.SetsockoptInt(int(fd.Fd()), syscall.IPPROTO_IP, syscall.IP_MULTICAST_TTL, VRRPMultiTTL); errOfSetHopLimit != nil {
|
return nil, fmt.Errorf("ipConnection: %v", errOfSetHopLimit)
|
}
|
//set tos
|
if errOfSetTOS := syscall.SetsockoptInt(int(fd.Fd()), syscall.IPPROTO_IP, syscall.IP_TOS, 7); errOfSetTOS != nil {
|
return nil, fmt.Errorf("ipConnection: %v", errOfSetTOS)
|
}
|
//disable multicast loop
|
if errOfSetLoop := syscall.SetsockoptInt(int(fd.Fd()), syscall.IPPROTO_IP, syscall.IP_MULTICAST_LOOP, 0); errOfSetLoop != nil {
|
return nil, fmt.Errorf("ipConnection: %v", errOfSetLoop)
|
}
|
} else {
|
//IPv6 mode
|
//set hop limit
|
if errOfSetHOPLimit := syscall.SetsockoptInt(int(fd.Fd()), syscall.IPPROTO_IPV6, syscall.IPV6_MULTICAST_HOPS, 255); errOfSetHOPLimit != nil {
|
return nil, fmt.Errorf("ipConnection: %v", errOfSetHOPLimit)
|
}
|
//disable multicast loop
|
if errOfSetLoop := syscall.SetsockoptInt(int(fd.Fd()), syscall.IPPROTO_IPV6, syscall.IPV6_MULTICAST_LOOP, 0); errOfSetLoop != nil {
|
return nil, fmt.Errorf("ipConnection: %v", errOfSetLoop)
|
}
|
//to receive the hop limit and dst address in oob
|
if err := syscall.SetsockoptInt(int(fd.Fd()), syscall.IPPROTO_IPV6, syscall.IPV6_2292HOPLIMIT, 1); err != nil {
|
return nil, fmt.Errorf("ipConnection: %v", err)
|
}
|
if err := syscall.SetsockoptInt(int(fd.Fd()), syscall.IPPROTO_IPV6, syscall.IPV6_2292PKTINFO, 1); err != nil {
|
return nil, fmt.Errorf("ipConnection: %v", err)
|
}
|
|
}
|
logger.Info("IP virtual connection established %v ==> %v", local, remote)
|
return conn, nil
|
}
|
|
// makeMulticastIPv4Conn opens a raw IPv4 socket for protocol 112 bound to
// the multicast address multi and joins that group on the interface that
// owns local.
//
// Both addresses must be IPv4; otherwise an error is returned before any
// socket is opened (previously a non-IPv4 address caused an index panic).
// If joining the group fails, the socket is closed before returning.
func makeMulticastIPv4Conn(multi, local net.IP) (*net.IPConn, error) {
	group, ifAddr := multi.To4(), local.To4()
	if group == nil {
		return nil, fmt.Errorf("makeMulticastIPv4Conn: multicast address %v is not an IPv4 address", multi)
	}
	if ifAddr == nil {
		return nil, fmt.Errorf("makeMulticastIPv4Conn: local address %v is not an IPv4 address", local)
	}

	conn, errOfListenIP := net.ListenIP("ip4:112", &net.IPAddr{IP: multi})
	if errOfListenIP != nil {
		return nil, fmt.Errorf("makeMulticastIPv4Conn: %v", errOfListenIP)
	}

	fd, errOfGetFD := conn.File()
	if errOfGetFD != nil {
		conn.Close()
		return nil, fmt.Errorf("makeMulticastIPv4Conn: %v", errOfGetFD)
	}
	// File returns a duplicate descriptor; only the duplicate is closed here.
	defer fd.Close()

	mreq := &syscall.IPMreq{}
	copy(mreq.Multiaddr[:], group)
	copy(mreq.Interface[:], ifAddr)
	if errSetMreq := syscall.SetsockoptIPMreq(int(fd.Fd()), syscall.IPPROTO_IP, syscall.IP_ADD_MEMBERSHIP, mreq); errSetMreq != nil {
		conn.Close()
		return nil, fmt.Errorf("makeMulticastIPv4Conn: %v", errSetMreq)
	}
	return conn, nil
}
|
|
func joinIPv6MulticastGroup(con *net.IPConn, local, remote net.IP) error {
|
var fd, errOfGetFD = con.File()
|
if errOfGetFD != nil {
|
return fmt.Errorf("joinIPv6MulticastGroup: %v", errOfGetFD)
|
}
|
defer fd.Close()
|
var mreq = &syscall.IPv6Mreq{}
|
copy(mreq.Multiaddr[:], remote.To16())
|
var IF, errOfGetIF = findInterfacebyIP(local)
|
if errOfGetIF != nil {
|
return fmt.Errorf("joinIPv6MulticastGroup: %v", errOfGetIF)
|
}
|
mreq.Interface = uint32(IF.Index)
|
if errOfSetMreq := syscall.SetsockoptIPv6Mreq(int(fd.Fd()), syscall.IPPROTO_IPV6, syscall.IPV6_JOIN_GROUP, mreq); errOfSetMreq != nil {
|
return fmt.Errorf("joinIPv6MulticastGroup: %v", errOfSetMreq)
|
}
|
logger.Info("Join IPv6 multicast group %v on %v", remote, IF.Name)
|
return nil
|
}
|
|
func NewIPv4Conn(local, remote net.IP) IPConnection {
|
var SendConn, errOfMakeIPConn = ipConnection(local, remote)
|
if errOfMakeIPConn != nil {
|
panic(errOfMakeIPConn)
|
}
|
var receiveConn, errOfMakeRecv = makeMulticastIPv4Conn(VRRPMultiAddrIPv4, local)
|
if errOfMakeRecv != nil {
|
panic(errOfMakeRecv)
|
}
|
return &IPv4Con{
|
buffer: make([]byte, 2048),
|
local: local,
|
remote: remote,
|
SendCon: SendConn,
|
ReceiveCon: receiveConn,
|
}
|
|
}
|
|
func (conn *IPv4Con) WriteMessage(packet *VRRPPacket) error {
|
if _, err := conn.SendCon.WriteTo(packet.ToBytes(), &net.IPAddr{IP: conn.remote}); err != nil {
|
return fmt.Errorf("IPv4Con.WriteMessage: %v", err)
|
}
|
return nil
|
}
|
|
func (conn *IPv4Con) ReadMessage() (*VRRPPacket, error) {
|
var n, errOfRead = conn.ReceiveCon.Read(conn.buffer)
|
if errOfRead != nil {
|
return nil, fmt.Errorf("IPv4Con.ReadMessage: %v", errOfRead)
|
}
|
if n < 20 {
|
return nil, fmt.Errorf("IPv4Con.ReadMessage: IP datagram lenght %v too small", n)
|
}
|
var hdrlen = (int(conn.buffer[0]) & 0x0f) << 2
|
if hdrlen > n {
|
return nil, fmt.Errorf("IPv4Con.ReadMessage: the header length %v is lagger than total length %V", hdrlen, n)
|
}
|
if conn.buffer[8] != 255 {
|
return nil, fmt.Errorf("IPv4Con.ReadMessage: the TTL of IP datagram carring VRRP advertisment must equal to 255")
|
}
|
if advertisement, errOfUnmarshal := FromBytes(IPv4, conn.buffer[hdrlen:n]); errOfUnmarshal != nil {
|
return nil, fmt.Errorf("IPv4Con.ReadMessage: %v", errOfUnmarshal)
|
} else {
|
if VRRPVersion(advertisement.GetVersion()) != VRRPv3 {
|
return nil, fmt.Errorf("IPv4Con.ReadMessage: received an advertisement with %s", VRRPVersion(advertisement.GetVersion()))
|
}
|
var pshdr PseudoHeader
|
pshdr.Saddr = net.IPv4(conn.buffer[12], conn.buffer[13], conn.buffer[14], conn.buffer[15]).To16()
|
pshdr.Daddr = net.IPv4(conn.buffer[16], conn.buffer[17], conn.buffer[18], conn.buffer[19]).To16()
|
pshdr.Protocol = VRRPIPProtocolNumber
|
pshdr.Len = uint16(n - hdrlen)
|
if !advertisement.ValidateCheckSum(&pshdr) {
|
return nil, fmt.Errorf("IPv4Con.ReadMessage: validate the check sum of advertisement failed")
|
} else {
|
advertisement.Pshdr = &pshdr
|
return advertisement, nil
|
}
|
}
|
}
|
|
func NewIPv6Con(local, remote net.IP) *IPv6Con {
|
var con, errOfNewIPv6Con = ipConnection(local, remote)
|
if errOfNewIPv6Con != nil {
|
panic(fmt.Errorf("NewIPv6Con: %v", errOfNewIPv6Con))
|
}
|
if errOfJoinMG := joinIPv6MulticastGroup(con, local, remote); errOfJoinMG != nil {
|
panic(fmt.Errorf("NewIPv6Con: %v", errOfJoinMG))
|
}
|
return &IPv6Con{
|
buffer: make([]byte, 4096),
|
oob: make([]byte, 4096),
|
local: local,
|
remote: remote,
|
Con: con,
|
}
|
}
|
|
func (con *IPv6Con) WriteMessage(packet *VRRPPacket) error {
|
if _, errOfWrite := con.Con.WriteToIP(packet.ToBytes(), &net.IPAddr{IP: con.remote}); errOfWrite != nil {
|
return fmt.Errorf("IPv6Con.WriteMessage: %v", errOfWrite)
|
}
|
return nil
|
}
|
|
func (con *IPv6Con) ReadMessage() (*VRRPPacket, error) {
|
var buffern, oobn, _, raddr, errOfRead = con.Con.ReadMsgIP(con.buffer, con.oob)
|
if errOfRead != nil {
|
return nil, fmt.Errorf("IPv6Con.ReadMessage: %v", errOfRead)
|
}
|
var oobdata, errOfParseOOB = syscall.ParseSocketControlMessage(con.oob[:oobn])
|
if errOfParseOOB != nil {
|
return nil, fmt.Errorf("IPv6Con.ReadMessage: %v", errOfParseOOB)
|
}
|
var (
|
dst net.IP
|
TTL byte
|
GetTTL = false
|
)
|
for index := range oobdata {
|
if oobdata[index].Header.Level != syscall.IPPROTO_IPV6 {
|
continue
|
}
|
switch oobdata[index].Header.Type {
|
case syscall.IPV6_2292HOPLIMIT:
|
if len(oobdata[index].Data) == 0 {
|
return nil, fmt.Errorf("IPv6Con.ReadMessage: invalid HOPLIMIT")
|
}
|
TTL = oobdata[index].Data[0]
|
GetTTL = true
|
case syscall.IPV6_2292PKTINFO:
|
if len(oobdata[index].Data) < 16 {
|
return nil, fmt.Errorf("IPv6Con.ReadMessage: invalid destination IP addrress length")
|
}
|
dst = net.IP(oobdata[index].Data[:16])
|
}
|
}
|
if GetTTL == false {
|
return nil, fmt.Errorf("IPv6Con.ReadMessage: HOPLIMIT not found")
|
}
|
if dst == nil {
|
return nil, fmt.Errorf("IPv6Con.ReadMessage: destination address not found")
|
}
|
var pshdr = PseudoHeader{
|
Daddr: dst,
|
Saddr: raddr.IP,
|
Protocol: VRRPIPProtocolNumber,
|
Len: uint16(buffern),
|
}
|
var advertisement, errOfUnmarshal = FromBytes(IPv6, con.buffer)
|
if errOfUnmarshal != nil {
|
return nil, fmt.Errorf("IPv6Con.ReadMessage: %v", errOfUnmarshal)
|
}
|
if TTL != 255 {
|
return nil, fmt.Errorf("IPv6Con.ReadMessage: invalid HOPLIMIT")
|
}
|
if VRRPVersion(advertisement.GetVersion()) != VRRPv3 {
|
return nil, fmt.Errorf("IPv6Con.ReadMessage: invalid VRRP version %v", advertisement.GetVersion())
|
}
|
if !advertisement.ValidateCheckSum(&pshdr) {
|
return nil, fmt.Errorf("IPv6Con.ReadMessage: invalid check sum")
|
}
|
advertisement.Pshdr = &pshdr
|
return advertisement, nil
|
}
|
|
func findIPbyInterface(itf *net.Interface, IPvX byte) (net.IP, error) {
|
var addrs, errOfListAddrs = itf.Addrs()
|
if errOfListAddrs != nil {
|
return nil, fmt.Errorf("findIPbyInterface: %v", errOfListAddrs)
|
}
|
for index := range addrs {
|
var ipaddr, _, errOfParseIP = net.ParseCIDR(addrs[index].String())
|
if errOfParseIP != nil {
|
return nil, fmt.Errorf("findIPbyInterface: %v", errOfParseIP)
|
}
|
if IPvX == IPv4 {
|
if ipaddr.To4() != nil {
|
if ipaddr.IsGlobalUnicast() {
|
return ipaddr, nil
|
}
|
}
|
} else {
|
if ipaddr.To4() == nil {
|
if ipaddr.IsLinkLocalUnicast() {
|
return ipaddr, nil
|
}
|
}
|
}
|
}
|
return nil, fmt.Errorf("findIPbyInterface: can not find valid IP addrs on %v", itf.Name)
|
}
|
|
// findInterfacebyIP returns the first network interface that has ip among
// its configured addresses.
//
// An error is returned when the interfaces or their addresses cannot be
// listed, when an address cannot be parsed, or when no interface carries
// ip.
func findInterfacebyIP(ip net.IP) (*net.Interface, error) {
	itfs, err := net.Interfaces()
	if err != nil {
		return nil, fmt.Errorf("findInterfacebyIP: %v", err)
	}

	for i := range itfs {
		candidate := &itfs[i]
		addrs, err := candidate.Addrs()
		if err != nil {
			return nil, fmt.Errorf("findInterfacebyIP: %v", err)
		}
		for _, addr := range addrs {
			parsed, _, err := net.ParseCIDR(addr.String())
			if err != nil {
				return nil, fmt.Errorf("findInterfacebyIP: %v", err)
			}
			if parsed.Equal(ip) {
				return candidate, nil
			}
		}
	}

	return nil, fmt.Errorf("findInterfacebyIP: can't find the corresponding interface of %v", ip)
}
|