zhangmeng
2019-08-26 a7b93ee3558e6da945991d61f64b4c7b44dd132f
提交 | 用户 | age
3bd1f2 1 // Copyright 2016 Tom Thorogood. All rights reserved.
Z 2 // Use of this source code is governed by a
3 // Modified BSD License license that can be found in
4 // the LICENSE file.
5
6 package shm
7
8 import (
9     "io"
10     "sync/atomic"
11     "unsafe"
12
a0123f 13     "golang.org/x/sys/unix"
3bd1f2 14 )
Z 15
16 const (
17     eofFlagIndex = 0
18     eofFlagMask  = 0x01
19 )
20
21 type Buffer struct {
22     block *sharedBlock
23     write bool
24
25     Data  []byte
26     Flags *[blockFlagsSize]byte
27 }
28
29 type ReadWriteCloser struct {
30     name string
31
32     data          []byte
33     readShared    *sharedMem
34     writeShared   *sharedMem
35     size          uint64
36     fullBlockSize uint64
37
38     // Must be accessed using atomic operations
39     Flags *[sharedFlagsSize]uint32
40
41     closed uint32
42 }
43
44 func (rw *ReadWriteCloser) Close() error {
45     if !atomic.CompareAndSwapUint32(&rw.closed, 0, 1) {
46         return nil
47     }
48
49     // finish all sends before close!
50
51     return unix.Munmap(rw.data)
52 }
53
54 // Name returns the name of the shared memory.
55 func (rw *ReadWriteCloser) Name() string {
56     return rw.name
57 }
58
a7b93e 59 // DirectRead create byte in func
Z 60 func (rw *ReadWriteCloser) DirectRead() ([]byte, error) {
61     buf, err := rw.GetReadBuffer()
62     if err != nil {
63         return nil, err
64     }
3bd1f2 65
a7b93e 66     data := make([]byte, len(buf.Data))
Z 67
68     copy(data, buf.Data)
69     isEOF := buf.Flags[eofFlagIndex]&eofFlagMask != 0
70
71     if err = rw.SendReadBuffer(buf); err != nil {
72         return nil, err
73     }
74
75     if isEOF {
76         return nil, io.EOF
77     }
78
79     return data, nil
80 }
81
82 // Peek get length
83 func (rw *ReadWriteCloser) Peek() (n int, err error) {
84     buf, err := rw.GetReadBuffer()
85     if err != nil {
86         return 0, err
87     }
88
89     return len(buf.Data), nil
90 }
91
92 // Read
3bd1f2 93 func (rw *ReadWriteCloser) Read(p []byte) (n int, err error) {
Z 94     buf, err := rw.GetReadBuffer()
95     if err != nil {
96         return 0, err
97     }
98
99     n = copy(p, buf.Data)
100     isEOF := buf.Flags[eofFlagIndex]&eofFlagMask != 0
101
102     if err = rw.SendReadBuffer(buf); err != nil {
103         return n, err
104     }
105
106     if isEOF {
107         return n, io.EOF
108     }
109
110     return n, nil
111 }
112
113 func (rw *ReadWriteCloser) WriteTo(w io.Writer) (n int64, err error) {
114     for {
115         buf, err := rw.GetReadBuffer()
116         if err != nil {
117             return n, err
118         }
119
120         nn, err := w.Write(buf.Data)
121         n += int64(nn)
122
123         isEOF := buf.Flags[eofFlagIndex]&eofFlagMask != 0
124
125         if putErr := rw.SendReadBuffer(buf); putErr != nil {
126             return n, putErr
127         }
128
129         if err != nil || isEOF {
130             return n, err
131         }
132     }
133 }
134
135 func (rw *ReadWriteCloser) GetReadBuffer() (Buffer, error) {
136     if atomic.LoadUint32(&rw.closed) != 0 {
137         return Buffer{}, io.ErrClosedPipe
138     }
139
140     var block *sharedBlock
141
142     blocks := uintptr(unsafe.Pointer(rw.readShared)) + sharedHeaderSize
143
144     for {
145         blockIndex := atomic.LoadUint32((*uint32)(&rw.readShared.ReadStart))
146         if blockIndex > uint32(rw.readShared.BlockCount) {
147             return Buffer{}, ErrInvalidSharedMemory
148         }
149
150         block = (*sharedBlock)(unsafe.Pointer(blocks + uintptr(uint64(blockIndex)*rw.fullBlockSize)))
151
152         if blockIndex == atomic.LoadUint32((*uint32)(&rw.readShared.WriteEnd)) {
a0123f 153             if err := ((*Semaphore)(&rw.readShared.SemSignal)).Wait(); err != nil {
3bd1f2 154                 return Buffer{}, err
Z 155             }
156
157             continue
158         }
159
160         if atomic.CompareAndSwapUint32((*uint32)(&rw.readShared.ReadStart), blockIndex, uint32(block.Next)) {
161             break
162         }
163     }
164
165     data := (*[1 << 30]byte)(unsafe.Pointer(uintptr(unsafe.Pointer(block)) + blockHeaderSize))
166     flags := (*[len(block.Flags)]byte)(unsafe.Pointer(&block.Flags[0]))
167     return Buffer{
168         block: block,
169
170         Data:  data[:block.Size:rw.readShared.BlockSize],
171         Flags: flags,
172     }, nil
173 }
174
175 func (rw *ReadWriteCloser) SendReadBuffer(buf Buffer) error {
176     if atomic.LoadUint32(&rw.closed) != 0 {
177         return io.ErrClosedPipe
178     }
179
180     if buf.write {
181         return ErrInvalidBuffer
182     }
183
184     block := buf.block
185
186     atomic.StoreUint32((*uint32)(&block.DoneRead), 1)
187
188     blocks := uintptr(unsafe.Pointer(rw.readShared)) + sharedHeaderSize
189
190     for {
191         blockIndex := atomic.LoadUint32((*uint32)(&rw.readShared.ReadEnd))
192         if blockIndex > uint32(rw.readShared.BlockCount) {
193             return ErrInvalidSharedMemory
194         }
195
196         block = (*sharedBlock)(unsafe.Pointer(blocks + uintptr(uint64(blockIndex)*rw.fullBlockSize)))
197
198         if !atomic.CompareAndSwapUint32((*uint32)(&block.DoneRead), 1, 0) {
199             return nil
200         }
201
202         atomic.CompareAndSwapUint32((*uint32)(&rw.readShared.ReadEnd), blockIndex, uint32(block.Next))
203
204         if uint32(block.Prev) == atomic.LoadUint32((*uint32)(&rw.readShared.WriteStart)) {
a0123f 205             if err := ((*Semaphore)(&rw.readShared.SemAvail)).Post(); err != nil {
3bd1f2 206                 return err
Z 207             }
208         }
209     }
210 }
211
212 // Write
213
214 func (rw *ReadWriteCloser) Write(p []byte) (n int, err error) {
215     buf, err := rw.GetWriteBuffer()
216     if err != nil {
217         return 0, err
218     }
219
220     n = copy(buf.Data[:cap(buf.Data)], p)
221     buf.Data = buf.Data[:n]
222
223     buf.Flags[eofFlagIndex] |= eofFlagMask
224
225     _, err = rw.SendWriteBuffer(buf)
226     return n, err
227 }
228
229 func (rw *ReadWriteCloser) ReadFrom(r io.Reader) (n int64, err error) {
230     for {
231         buf, err := rw.GetWriteBuffer()
232         if err != nil {
233             return n, err
234         }
235
236         nn, err := r.Read(buf.Data[:cap(buf.Data)])
237         buf.Data = buf.Data[:nn]
238         n += int64(nn)
239
240         if err == io.EOF {
241             buf.Flags[eofFlagIndex] |= eofFlagMask
242         } else {
243             buf.Flags[eofFlagIndex] &^= eofFlagMask
244         }
245
246         if _, putErr := rw.SendWriteBuffer(buf); putErr != nil {
247             return n, err
248         }
249
250         if err == io.EOF {
251             return n, nil
252         } else if err != nil {
253             return n, err
254         }
255     }
256 }
257
258 func (rw *ReadWriteCloser) GetWriteBuffer() (Buffer, error) {
259     if atomic.LoadUint32(&rw.closed) != 0 {
260         return Buffer{}, io.ErrClosedPipe
261     }
262
263     var block *sharedBlock
264
265     blocks := uintptr(unsafe.Pointer(rw.writeShared)) + sharedHeaderSize
266
267     for {
268         blockIndex := atomic.LoadUint32((*uint32)(&rw.writeShared.WriteStart))
269         if blockIndex > uint32(rw.writeShared.BlockCount) {
270             return Buffer{}, ErrInvalidSharedMemory
271         }
272
273         block = (*sharedBlock)(unsafe.Pointer(blocks + uintptr(uint64(blockIndex)*rw.fullBlockSize)))
274
275         if uint32(block.Next) == atomic.LoadUint32((*uint32)(&rw.writeShared.ReadEnd)) {
a0123f 276             if err := ((*Semaphore)(&rw.writeShared.SemAvail)).Wait(); err != nil {
3bd1f2 277                 return Buffer{}, err
Z 278             }
279
280             continue
281         }
282
283         if atomic.CompareAndSwapUint32((*uint32)(&rw.writeShared.WriteStart), blockIndex, uint32(block.Next)) {
284             break
285         }
286     }
287
288     data := (*[1 << 30]byte)(unsafe.Pointer(uintptr(unsafe.Pointer(block)) + blockHeaderSize))
289     flags := (*[len(block.Flags)]byte)(unsafe.Pointer(&block.Flags[0]))
290     return Buffer{
291         block: block,
292         write: true,
293
294         Data:  data[:0:rw.writeShared.BlockSize],
295         Flags: flags,
296     }, nil
297 }
298
299 func (rw *ReadWriteCloser) SendWriteBuffer(buf Buffer) (n int, err error) {
300     if atomic.LoadUint32(&rw.closed) != 0 {
301         return 0, io.ErrClosedPipe
302     }
303
304     if !buf.write {
305         return 0, ErrInvalidBuffer
306     }
307
308     block := buf.block
309
310     *(*uint64)(&block.Size) = uint64(len(buf.Data))
311
312     atomic.StoreUint32((*uint32)(&block.DoneWrite), 1)
313
314     blocks := uintptr(unsafe.Pointer(rw.writeShared)) + sharedHeaderSize
315
316     for {
317         blockIndex := atomic.LoadUint32((*uint32)(&rw.writeShared.WriteEnd))
318         if blockIndex > uint32(rw.writeShared.BlockCount) {
319             return len(buf.Data), ErrInvalidSharedMemory
320         }
321
322         block = (*sharedBlock)(unsafe.Pointer(blocks + uintptr(uint64(blockIndex)*rw.fullBlockSize)))
323
324         if !atomic.CompareAndSwapUint32((*uint32)(&block.DoneWrite), 1, 0) {
325             return len(buf.Data), nil
326         }
327
328         atomic.CompareAndSwapUint32((*uint32)(&rw.writeShared.WriteEnd), blockIndex, uint32(block.Next))
329
330         if blockIndex == atomic.LoadUint32((*uint32)(&rw.writeShared.ReadStart)) {
a0123f 331             if err := ((*Semaphore)(&rw.writeShared.SemSignal)).Post(); err != nil {
3bd1f2 332                 return len(buf.Data), err
Z 333             }
334         }
335     }
336 }