-
-
Notifications
You must be signed in to change notification settings - Fork 78
/
Copy pathgrpc_handler.go
190 lines (171 loc) · 5.65 KB
/
grpc_handler.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
package grpc_proxy
import (
"context"
"fmt"
"io"
"os"
"strings"
"github.com/bradleyjkemp/grpc-tools/internal/codec"
"github.com/bradleyjkemp/grpc-tools/internal/marker"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
)
// proxyStreamDesc is the stream description handed to NewStream for every
// proxied call. Both directions are marked as streaming, so the same
// descriptor is used regardless of the shape of the proxied method.
var proxyStreamDesc = &grpc.StreamDesc{
	ClientStreams: true,
	ServerStreams: true,
}
// proxyHandler is a stream handler that forwards the incoming call to the
// destination chosen by calculateDestination and then pumps messages in both
// directions until one side finishes.
//
// srv is unused. ss is the stream of the incoming call; its metadata decides
// whether TLS is used for the outgoing connection and where that connection
// goes.
//
// It returns nil when the destination finishes the call with io.EOF, the
// destination's own error when it finishes with anything else, and a
// codes.Internal status when reading from the incoming stream fails.
//
// Originally based on github.com/mwitkow/grpc-proxy/proxy/handler.go
func (s *server) proxyHandler(srv interface{}, ss grpc.ServerStream) error {
	md, ok := metadata.FromIncomingContext(ss.Context())
	if !ok {
		return status.Error(codes.Unknown, "could not extract metadata from request")
	}

	// The full slice expression caps the capacity at the current length, so the
	// append below always copies into a fresh backing array. Appending straight
	// onto s.dialOptions could write into spare capacity that every concurrently
	// running handler shares.
	base := s.dialOptions[:len(s.dialOptions):len(s.dialOptions)]
	options := append(base,
		grpc.WithDefaultCallOptions(grpc.ForceCodec(codec.NoopCodec{})),
		grpc.WithBlock(),
	)
	if marker.IsTLSRPC(md) {
		options = append(options, grpc.WithTransportCredentials(credentials.NewTLS(nil)))
	} else {
		options = append(options, grpc.WithInsecure())
	}

	destinationAddr, err := s.calculateDestination(md)
	if err != nil {
		return err
	}

	destination, err := s.connPool.GetClientConn(ss.Context(), destinationAddr, options...)
	if err != nil {
		return err
	}

	// little bit of gRPC internals never hurt anyone
	fullMethodName, ok := grpc.MethodFromServerStream(ss)
	if !ok {
		return status.Errorf(codes.Internal, "no method exists in context")
	}

	clientCtx, clientCancel := getClientCtx(ss.Context())
	// Release the derived context on every exit path, including a failed
	// NewStream and both returns of the destination-finished case below.
	defer clientCancel()

	clientStream, err := destination.NewStream(clientCtx, proxyStreamDesc, fullMethodName)
	if err != nil {
		return err
	}

	// Explicitly *do not close* s2cErrChan and c2sErrChan, otherwise the select below will not terminate.
	// Channels do not have to be closed, it is just a control flow mechanism, see
	// https://groups.google.com/forum/#!msg/golang-nuts/pZwdYRGxCIk/qpbHxRRPJdUJ
	s2cErrChan := forwardServerToClient(ss, clientStream)
	c2sErrChan := forwardClientToServer(clientStream, ss)

	// We don't know which side is going to stop sending first, so we need a select between the two.
	for i := 0; i < 2; i++ {
		select {
		case s2cErr := <-s2cErrChan:
			if s2cErr != io.EOF {
				// We got a receive error (stream disconnected, a read error etc). Exit with an
				// error to the stack; the deferred clientCancel cancels the clientStream to the
				// backend so that its goroutines can be freed up.
				fmt.Fprintln(os.Stderr, "failed proxying s2c", s2cErr)
				return status.Errorf(codes.Internal, "failed proxying s2c: %v", s2cErr)
			}
			// This is the happy case where the sender has encountered io.EOF and won't be
			// sending anymore. The clientStream>serverStream direction may continue pumping
			// though, so keep looping and wait for it.
			//
			// The result is deliberately discarded: per the grpc.ClientStream documentation a
			// failing CloseSend closes the stream, which is then reported through c2sErrChan.
			_ = clientStream.CloseSend()
		case c2sErr := <-c2sErrChan:
			// This happens when the clientStream has nothing else to offer (io.EOF) or returned a
			// gRPC error. In those two cases we may have received Trailers as part of the call. In
			// case of other errors (stream closed) the trailers will be nil.
			ss.SetTrailer(clientStream.Trailer())
			// c2sErr will contain the RPC error from client code. If not io.EOF return the RPC
			// error as the server stream error.
			if c2sErr != io.EOF {
				return c2sErr
			}
			return nil
		}
	}
	return status.Errorf(codes.Internal, "gRPC proxying should never reach this stage.")
}
// calculateDestination picks the address the proxied call is forwarded to.
//
// A configured s.destination always wins (used by clients not supporting HTTP
// proxies); otherwise the first ":authority" value of the request metadata is
// used. If neither is available a codes.Unimplemented status is returned.
// An address containing no ":" gets a default port appended: 443 when
// marker.IsTLSRPC reports true for md, 80 otherwise. Finally the loop check
// from the marker package is applied to md with the listener's address, and
// its error, if any, is returned instead of the address.
func (s *server) calculateDestination(md metadata.MD) (string, error) {
	authority := md.Get(":authority")

	addr := s.destination
	if addr == "" {
		if len(authority) == 0 {
			// no destination can be determined so just error
			return "", status.Error(codes.Unimplemented, "no proxy destination configured")
		}
		// fall back to the authority from the request
		addr = authority[0]
	}

	// if this is a gRPC-Web connection then it doesn't have a port so we add the default
	if !strings.Contains(addr, ":") {
		suffix := ":80"
		if marker.IsTLSRPC(md) {
			suffix = ":443"
		}
		addr += suffix
	}

	if err := marker.AddLoopCheck(md, s.listener.Addr().String()); err != nil {
		return "", err
	}
	return addr, nil
}
// getClientCtx derives the context used for the outgoing call from serverCtx.
// The result is a cancellable child of serverCtx; when serverCtx carries
// incoming metadata, that metadata is attached as the outgoing metadata.
// The returned CancelFunc cancels the derived context.
func getClientCtx(serverCtx context.Context) (context.Context, context.CancelFunc) {
	ctx, cancel := context.WithCancel(serverCtx)
	if incoming, found := metadata.FromIncomingContext(serverCtx); found {
		return metadata.NewOutgoingContext(ctx, incoming), cancel
	}
	return ctx, cancel
}
// forwardClientToServer starts a goroutine that copies messages from src to
// dst until a receive or send fails, and returns a channel on which that one
// terminating error is delivered (io.EOF is the happy case). The channel has
// a buffer of one, so the goroutine can finish even if nobody reads it.
func forwardClientToServer(src grpc.ClientStream, dst grpc.ServerStream) chan error {
	errCh := make(chan error, 1)

	// pump runs the copy loop and returns the error that ended it.
	pump := func() error {
		var frame []byte
		headersForwarded := false
		for {
			if err := src.RecvMsg(&frame); err != nil {
				return err // this can be io.EOF which is the happy case
			}
			if !headersForwarded {
				// This is a bit of a hack, but client to server headers are only readable after
				// the first client msg is received but must be written to the server stream
				// before the first msg is flushed. This is the only place to do it nicely.
				md, err := src.Header()
				if err != nil {
					return err
				}
				if err := dst.SendHeader(md); err != nil {
					return err
				}
				headersForwarded = true
			}
			if err := dst.SendMsg(frame); err != nil {
				return err
			}
		}
	}

	go func() { errCh <- pump() }()
	return errCh
}
// forwardServerToClient starts a goroutine that copies messages from src to
// dst until a receive or send fails, and returns a channel on which that one
// terminating error is delivered (io.EOF is the happy case). The channel has
// a buffer of one, so the goroutine can finish even if nobody reads it.
func forwardServerToClient(src grpc.ServerStream, dst grpc.ClientStream) chan error {
	errCh := make(chan error, 1)
	go func() {
		var frame []byte
		for {
			err := src.RecvMsg(&frame) // this can be io.EOF which is the happy case
			if err == nil {
				err = dst.SendMsg(frame)
			}
			if err != nil {
				errCh <- err
				return
			}
		}
	}()
	return errCh
}