forked from docker-archive/go-p9p
-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathchannel_test.go
256 lines (219 loc) · 5.91 KB
/
channel_test.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
package p9p
import (
"bytes"
"context"
"encoding/binary"
"net"
"testing"
"time"
)
// TestTwriteOverflow verifies that when a Twrite's data field would push the
// encoded message past msize, WriteFcall truncates the data so that exactly
// msize bytes (size header included) are written to the connection.
func TestTwriteOverflow(t *testing.T) {
	const (
		msize = 512
		// Twrite wire layout with an empty data field:
		// size[4] Twrite[1] tag[2] fid[4] offset[8] count[4] data[count] | count = 0
		overhead = 4 + 1 + 2 + 4 + 8 + 4
	)

	ctx := context.Background()
	conn := &mockConn{}
	ch := NewChannel(conn, msize)

	cases := []struct {
		name     string
		overflow int // number of bytes by which the message exceeds msize
	}{
		{name: "BoundedOverflow", overflow: msize / 2},
		{name: "LargeOverflow", overflow: msize * 3},
		{name: "HeaderOverflow", overflow: overhead},
		{name: "HeaderOffsetOverflow", overflow: overhead - 1},
		{name: "OverflowByOne", overflow: 1},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			fcall := overflowMessage(ch.(*channel).codec, msize, tc.overflow)
			payload := fcall.Message.(MessageTwrite).Data
			t.Logf("overflow: %v, len(data): %v, expected overflow: %v", tc.overflow, len(payload), overhead+len(payload)-msize)

			// conn is shared by every case, so start each one from an empty wire.
			conn.buf.Reset()
			if err := ch.WriteFcall(ctx, fcall); err != nil {
				t.Fatal(err)
			}

			// The first four bytes on the wire are the little-endian size header.
			var header uint32
			if err := binary.Read(bytes.NewReader(conn.buf.Bytes()), binary.LittleEndian, &header); err != nil {
				t.Fatal(err)
			}
			if header != msize {
				t.Fatalf("should have truncated size header: %d != %d", header, msize)
			}

			if written := conn.buf.Len(); written != msize {
				t.Fatalf("should have truncated message: conn.buf.Len(%v) != msize(%v)", written, msize)
			}
		})
	}
}
// TestWriteOverflowError ensures that WriteFcall returns an error when a
// message will certainly overflow msize and the overflow cannot be resolved,
// and that Overflow(err) reports exactly how many bytes over msize it was.
func TestWriteOverflowError(t *testing.T) {
	// An msize this small cannot hold the Twrite below.
	const msize = 4

	var (
		ctx   = context.Background()
		conn  = &mockConn{}
		ch    = NewChannel(conn, msize)
		data  = bytes.Repeat([]byte{'A'}, 4)
		fcall = newFcall(1, MessageTwrite{
			Data: data,
		})
		// The leading 4 mirrors overheadMessage: it accounts for the
		// size[4] header in front of the codec-encoded fcall.
		messageSize = 4 + ch.(*channel).codec.Size(fcall)
	)

	err := ch.WriteFcall(ctx, fcall)
	if err == nil {
		t.Fatal("error expected when overflowing message")
	}
	if Overflow(err) != messageSize-msize {
		t.Fatalf("overflow should reflect messageSize and msize, %d != %d", Overflow(err), messageSize-msize)
	}
}
// TestReadFcallOverflow ensures that messages coming over a network
// connection do not overflow the msize. Invalid messages must cause
// ReadFcall to return an Overflow error whose value is the number of bytes
// by which the message exceeded msize.
func TestReadFcallOverflow(t *testing.T) {
	const msize = 256

	var (
		ctx   = context.Background()
		conn  = &mockConn{}
		ch    = NewChannel(conn, msize)
		codec = ch.(*channel).codec
	)

	for _, testcase := range []struct {
		name     string
		overflow int // number of bytes by which the message exceeds msize
	}{
		{
			name:     "OverflowByOne",
			overflow: 1,
		},
		{
			name:     "HeaderOverflow",
			overflow: overheadMessage(codec, MessageTwrite{}),
		},
		{
			name:     "HeaderOffsetOverflow",
			overflow: overheadMessage(codec, MessageTwrite{}) - 1,
		},
	} {
		t.Run(testcase.name, func(t *testing.T) {
			fcall := overflowMessage(codec, msize, testcase.overflow)

			// Sanity check the fixture before touching the channel, so a
			// miscomputed overflow is reported as such rather than surfacing
			// as a confusing ReadFcall mismatch.
			if want := ch.(*channel).msgmsize(fcall) - msize; testcase.overflow != want {
				t.Fatalf("overflow calculation incorrect: %v != %v", testcase.overflow, want)
			}

			// prepare the raw message
			p, err := codec.Marshal(fcall)
			if err != nil {
				t.Fatal(err)
			}

			// "send" the message into the buffer; this message is crafted to
			// overflow the read buffer. conn is shared across subtests, so
			// drop anything an earlier (possibly failed) subtest left behind
			// before framing a new message.
			conn.buf.Reset()
			if err := sendmsg(&conn.buf, p); err != nil {
				t.Fatal(err)
			}

			var incoming Fcall
			err = ch.ReadFcall(ctx, &incoming)
			if err == nil {
				t.Fatal("expected error on fcall")
			}
			if Overflow(err) != testcase.overflow {
				t.Fatalf("unexpected overflow on error: %v !=%v", Overflow(err), testcase.overflow)
			}
		})
	}
}
// TestTreadRewrite ensures that a Tread whose response would overflow the
// msize has its Count adjusted before it is sent, shrinking it by exactly the
// number of bytes the matching Rread would have exceeded msize by.
func TestTreadRewrite(t *testing.T) {
	const (
		msize         = 256
		overflowMSize = msize + 1
	)

	var (
		ctx  = context.Background()
		conn = &mockConn{}
		ch   = NewChannel(conn, msize)
		buf  = make([]byte, overflowMSize)
		// Ask for more bytes than msize itself, so the Rread cannot fit.
		fcall = newFcall(1, MessageTread{
			Count: overflowMSize,
		})
		// Size of an Rread carrying overflowMSize bytes of data; used to
		// compute the Count the channel is expected to rewrite to.
		responseMSize = ch.(*channel).msgmsize(newFcall(1, MessageRread{
			Data: buf,
		}))
	)

	if err := ch.WriteFcall(ctx, fcall); err != nil {
		t.Fatal(err)
	}

	// just read the message off the buffer
	n, err := readmsg(&conn.buf, buf)
	if err != nil {
		t.Fatal(err)
	}

	// Decode into the same Fcall after clearing it, so nothing from the
	// original request can leak into the decoded result.
	*fcall = Fcall{}
	if err := ch.(*channel).codec.Unmarshal(buf[:n], fcall); err != nil {
		t.Fatal(err)
	}

	tread, ok := fcall.Message.(MessageTread)
	if !ok {
		t.Fatalf("unexpected message: %v", fcall)
	}

	want := overflowMSize - (uint32(responseMSize) - msize)
	if tread.Count != want {
		t.Fatalf("count not rewritten: %v != %v", tread.Count, want)
	}
}
// mockConn is an in-memory stand-in for a network connection. Writes append
// to buf and reads consume from it, so bytes written through a channel can be
// inspected or read back directly. Only the methods below are implemented.
// When constructed as &mockConn{}, the embedded Conn is nil, so calling any
// other net.Conn method panics.
type mockConn struct {
	net.Conn
	buf bytes.Buffer
}

// Compile-time check that *mockConn can be used wherever a net.Conn is needed.
var _ net.Conn = (*mockConn)(nil)

// SetWriteDeadline is a no-op: writes to the in-memory buffer never block.
func (m *mockConn) SetWriteDeadline(t time.Time) error { return nil }

// SetReadDeadline is a no-op: reads from the in-memory buffer never block.
func (m *mockConn) SetReadDeadline(t time.Time) error { return nil }

// Write appends p to the buffer.
func (m *mockConn) Write(p []byte) (int, error) {
	return m.buf.Write(p)
}

// Read drains up to len(p) bytes from the buffer into p.
func (m *mockConn) Read(p []byte) (int, error) {
	return m.buf.Read(p)
}
// overheadMessage reports the encoded size of msg wrapped in an Fcall, plus
// the leading 4-byte size field. Called with an empty message, it yields the
// fixed framing overhead for that message type.
func overheadMessage(codec Codec, msg Message) int {
	const sizeField = 4
	wrapped := newFcall(1, msg)
	return sizeField + codec.Size(wrapped)
}
// overflowMessage returns a Twrite fcall sized so that its encoded message
// exceeds msize by overflow bytes. The data payload fills the room msize
// leaves after the empty-Twrite overhead (see overheadMessage), plus overflow
// extra bytes.
//
// It panics (via bytes.Repeat) if msize-overhead+overflow is negative.
func overflowMessage(codec Codec, msize, overflow int) *Fcall {
	overhead := overheadMessage(codec, MessageTwrite{})
	data := bytes.Repeat([]byte{'A'}, (msize-overhead)+overflow)
	return newFcall(1, MessageTwrite{
		Data: data,
	})
}