zhangmeng
2019-08-26 a0123f163eddcea3e6b9f9d36f1f3fb3aa2c835a
提交 | 用户 | age
3bd1f2 1 // Copyright 2016 Tom Thorogood. All rights reserved.
Z 2 // Use of this source code is governed by a
3 // Modified BSD License license that can be found in
4 // the LICENSE file.
5
6 package shm
7
8 import (
9     "io"
10     "sync/atomic"
11     "unsafe"
12
a0123f 13     "golang.org/x/sys/unix"
3bd1f2 14 )
Z 15
16 const (
17     eofFlagIndex = 0
18     eofFlagMask  = 0x01
19 )
20
21 type Buffer struct {
22     block *sharedBlock
23     write bool
24
25     Data  []byte
26     Flags *[blockFlagsSize]byte
27 }
28
29 type ReadWriteCloser struct {
30     name string
31
32     data          []byte
33     readShared    *sharedMem
34     writeShared   *sharedMem
35     size          uint64
36     fullBlockSize uint64
37
38     // Must be accessed using atomic operations
39     Flags *[sharedFlagsSize]uint32
40
41     closed uint32
42 }
43
44 func (rw *ReadWriteCloser) Close() error {
45     if !atomic.CompareAndSwapUint32(&rw.closed, 0, 1) {
46         return nil
47     }
48
49     // finish all sends before close!
50
51     return unix.Munmap(rw.data)
52 }
53
54 // Name returns the name of the shared memory.
55 func (rw *ReadWriteCloser) Name() string {
56     return rw.name
57 }
58
59 // Read
60
61 func (rw *ReadWriteCloser) Read(p []byte) (n int, err error) {
62     buf, err := rw.GetReadBuffer()
63     if err != nil {
64         return 0, err
65     }
66
67     n = copy(p, buf.Data)
68     isEOF := buf.Flags[eofFlagIndex]&eofFlagMask != 0
69
70     if err = rw.SendReadBuffer(buf); err != nil {
71         return n, err
72     }
73
74     if isEOF {
75         return n, io.EOF
76     }
77
78     return n, nil
79 }
80
81 func (rw *ReadWriteCloser) WriteTo(w io.Writer) (n int64, err error) {
82     for {
83         buf, err := rw.GetReadBuffer()
84         if err != nil {
85             return n, err
86         }
87
88         nn, err := w.Write(buf.Data)
89         n += int64(nn)
90
91         isEOF := buf.Flags[eofFlagIndex]&eofFlagMask != 0
92
93         if putErr := rw.SendReadBuffer(buf); putErr != nil {
94             return n, putErr
95         }
96
97         if err != nil || isEOF {
98             return n, err
99         }
100     }
101 }
102
103 func (rw *ReadWriteCloser) GetReadBuffer() (Buffer, error) {
104     if atomic.LoadUint32(&rw.closed) != 0 {
105         return Buffer{}, io.ErrClosedPipe
106     }
107
108     var block *sharedBlock
109
110     blocks := uintptr(unsafe.Pointer(rw.readShared)) + sharedHeaderSize
111
112     for {
113         blockIndex := atomic.LoadUint32((*uint32)(&rw.readShared.ReadStart))
114         if blockIndex > uint32(rw.readShared.BlockCount) {
115             return Buffer{}, ErrInvalidSharedMemory
116         }
117
118         block = (*sharedBlock)(unsafe.Pointer(blocks + uintptr(uint64(blockIndex)*rw.fullBlockSize)))
119
120         if blockIndex == atomic.LoadUint32((*uint32)(&rw.readShared.WriteEnd)) {
a0123f 121             if err := ((*Semaphore)(&rw.readShared.SemSignal)).Wait(); err != nil {
3bd1f2 122                 return Buffer{}, err
Z 123             }
124
125             continue
126         }
127
128         if atomic.CompareAndSwapUint32((*uint32)(&rw.readShared.ReadStart), blockIndex, uint32(block.Next)) {
129             break
130         }
131     }
132
133     data := (*[1 << 30]byte)(unsafe.Pointer(uintptr(unsafe.Pointer(block)) + blockHeaderSize))
134     flags := (*[len(block.Flags)]byte)(unsafe.Pointer(&block.Flags[0]))
135     return Buffer{
136         block: block,
137
138         Data:  data[:block.Size:rw.readShared.BlockSize],
139         Flags: flags,
140     }, nil
141 }
142
143 func (rw *ReadWriteCloser) SendReadBuffer(buf Buffer) error {
144     if atomic.LoadUint32(&rw.closed) != 0 {
145         return io.ErrClosedPipe
146     }
147
148     if buf.write {
149         return ErrInvalidBuffer
150     }
151
152     block := buf.block
153
154     atomic.StoreUint32((*uint32)(&block.DoneRead), 1)
155
156     blocks := uintptr(unsafe.Pointer(rw.readShared)) + sharedHeaderSize
157
158     for {
159         blockIndex := atomic.LoadUint32((*uint32)(&rw.readShared.ReadEnd))
160         if blockIndex > uint32(rw.readShared.BlockCount) {
161             return ErrInvalidSharedMemory
162         }
163
164         block = (*sharedBlock)(unsafe.Pointer(blocks + uintptr(uint64(blockIndex)*rw.fullBlockSize)))
165
166         if !atomic.CompareAndSwapUint32((*uint32)(&block.DoneRead), 1, 0) {
167             return nil
168         }
169
170         atomic.CompareAndSwapUint32((*uint32)(&rw.readShared.ReadEnd), blockIndex, uint32(block.Next))
171
172         if uint32(block.Prev) == atomic.LoadUint32((*uint32)(&rw.readShared.WriteStart)) {
a0123f 173             if err := ((*Semaphore)(&rw.readShared.SemAvail)).Post(); err != nil {
3bd1f2 174                 return err
Z 175             }
176         }
177     }
178 }
179
180 // Write
181
182 func (rw *ReadWriteCloser) Write(p []byte) (n int, err error) {
183     buf, err := rw.GetWriteBuffer()
184     if err != nil {
185         return 0, err
186     }
187
188     n = copy(buf.Data[:cap(buf.Data)], p)
189     buf.Data = buf.Data[:n]
190
191     buf.Flags[eofFlagIndex] |= eofFlagMask
192
193     _, err = rw.SendWriteBuffer(buf)
194     return n, err
195 }
196
197 func (rw *ReadWriteCloser) ReadFrom(r io.Reader) (n int64, err error) {
198     for {
199         buf, err := rw.GetWriteBuffer()
200         if err != nil {
201             return n, err
202         }
203
204         nn, err := r.Read(buf.Data[:cap(buf.Data)])
205         buf.Data = buf.Data[:nn]
206         n += int64(nn)
207
208         if err == io.EOF {
209             buf.Flags[eofFlagIndex] |= eofFlagMask
210         } else {
211             buf.Flags[eofFlagIndex] &^= eofFlagMask
212         }
213
214         if _, putErr := rw.SendWriteBuffer(buf); putErr != nil {
215             return n, err
216         }
217
218         if err == io.EOF {
219             return n, nil
220         } else if err != nil {
221             return n, err
222         }
223     }
224 }
225
226 func (rw *ReadWriteCloser) GetWriteBuffer() (Buffer, error) {
227     if atomic.LoadUint32(&rw.closed) != 0 {
228         return Buffer{}, io.ErrClosedPipe
229     }
230
231     var block *sharedBlock
232
233     blocks := uintptr(unsafe.Pointer(rw.writeShared)) + sharedHeaderSize
234
235     for {
236         blockIndex := atomic.LoadUint32((*uint32)(&rw.writeShared.WriteStart))
237         if blockIndex > uint32(rw.writeShared.BlockCount) {
238             return Buffer{}, ErrInvalidSharedMemory
239         }
240
241         block = (*sharedBlock)(unsafe.Pointer(blocks + uintptr(uint64(blockIndex)*rw.fullBlockSize)))
242
243         if uint32(block.Next) == atomic.LoadUint32((*uint32)(&rw.writeShared.ReadEnd)) {
a0123f 244             if err := ((*Semaphore)(&rw.writeShared.SemAvail)).Wait(); err != nil {
3bd1f2 245                 return Buffer{}, err
Z 246             }
247
248             continue
249         }
250
251         if atomic.CompareAndSwapUint32((*uint32)(&rw.writeShared.WriteStart), blockIndex, uint32(block.Next)) {
252             break
253         }
254     }
255
256     data := (*[1 << 30]byte)(unsafe.Pointer(uintptr(unsafe.Pointer(block)) + blockHeaderSize))
257     flags := (*[len(block.Flags)]byte)(unsafe.Pointer(&block.Flags[0]))
258     return Buffer{
259         block: block,
260         write: true,
261
262         Data:  data[:0:rw.writeShared.BlockSize],
263         Flags: flags,
264     }, nil
265 }
266
267 func (rw *ReadWriteCloser) SendWriteBuffer(buf Buffer) (n int, err error) {
268     if atomic.LoadUint32(&rw.closed) != 0 {
269         return 0, io.ErrClosedPipe
270     }
271
272     if !buf.write {
273         return 0, ErrInvalidBuffer
274     }
275
276     block := buf.block
277
278     *(*uint64)(&block.Size) = uint64(len(buf.Data))
279
280     atomic.StoreUint32((*uint32)(&block.DoneWrite), 1)
281
282     blocks := uintptr(unsafe.Pointer(rw.writeShared)) + sharedHeaderSize
283
284     for {
285         blockIndex := atomic.LoadUint32((*uint32)(&rw.writeShared.WriteEnd))
286         if blockIndex > uint32(rw.writeShared.BlockCount) {
287             return len(buf.Data), ErrInvalidSharedMemory
288         }
289
290         block = (*sharedBlock)(unsafe.Pointer(blocks + uintptr(uint64(blockIndex)*rw.fullBlockSize)))
291
292         if !atomic.CompareAndSwapUint32((*uint32)(&block.DoneWrite), 1, 0) {
293             return len(buf.Data), nil
294         }
295
296         atomic.CompareAndSwapUint32((*uint32)(&rw.writeShared.WriteEnd), blockIndex, uint32(block.Next))
297
298         if blockIndex == atomic.LoadUint32((*uint32)(&rw.writeShared.ReadStart)) {
a0123f 299             if err := ((*Semaphore)(&rw.writeShared.SemSignal)).Post(); err != nil {
3bd1f2 300                 return len(buf.Data), err
Z 301             }
302         }
303     }
304 }