gRPC Mock Server
1. 项目结构
.
├── Makefile
├── db.json
├── generic_mock_server.go
├── go.mod
├── go.sum
├── test
│   ├── callback
│   │   ├── callback.pb.go
│   │   └── callback_grpc.pb.go
│   ├── callback_server_50053.go
│   ├── generic_greeter_client_50052.go
│   ├── greeter
│   │   ├── greeter.pb.go
│   │   └── greeter_grpc.pb.go
│   ├── greeter_client_50052.go
│   └── proto
│       └── v1
│           ├── callback.proto
│           └── greeter.proto
└── util
    └── helper.go
Makefile:
# Requires protoc plus the protoc-gen-go and protoc-gen-go-grpc plugins on PATH.
# Generated code lands in ./test/greeter and ./test/callback (see the
# go_package options in test/proto/v1/*.proto).
.PHONY: test clean

# NOTE: despite its name, "test" only (re)generates the protobuf/gRPC stubs
# used by the test programs; it does not run any tests.
test:
	protoc --go_out=. --go-grpc_out=. test/proto/v1/greeter.proto
	protoc --go_out=. --go-grpc_out=. test/proto/v1/callback.proto

# Remove the generated stubs.
clean:
	rm -rf ./test/greeter/ ./test/callback/
db.json:
{
"/greeter.Greeter/SayHello": {
"key-1": {
"delay_milli_seconds": 1000,
"relative_paths": [
"./test/proto/v1"
],
"file": "greeter.proto",
"resp": "{\"message\":\"Hello gRPC Mock Server\"}",
"map_from_req": {
"trace_id": "trace_id"
},
"after_return": {
"callbacks": [
{
"delay_milli_seconds": 1000,
"type": "echo",
"param": "submitted"
},
{
"delay_milli_seconds": 1000,
"type": "echo",
"param": "succeeded"
},
{
"delay_milli_seconds": 2000,
"type": "grpc",
"param": "{\"addr\":\"localhost:50053\",\"method_path\":\"/callback.Callback/Invoke\",\"relative_paths\":[\"./test/proto/v1\"],\"file\":\"callback.proto\",\"init_req\":\"{\\\"state\\\":\\\"ok\\\"}\",\"map_from_resp\":{\"trace_id\":\"trace_id\"}}"
}
]
}
}
}
}
generic_mock_server.go:
package main
import (
"context"
"encoding/json"
"errors"
"flag"
"fmt"
"log"
"net"
"os"
"runtime"
"time"
"golang.org/x/sync/errgroup"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/metadata"
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/types/dynamicpb"
"grpc-mock-server/util"
)
// Callback is one step of an AfterReturn chain, decoded from db.json.
type Callback struct {
	// DelayMilliSeconds is how long Execute sleeps before acting.
	DelayMilliSeconds int64 `json:"delay_milli_seconds"`
	// Type selects the action: "echo" logs Param; "grpc" decodes Param as a
	// JSON-encoded GRPCCallbackParam and makes a unary gRPC call.
	Type string `json:"type"`
	// Param is the type-specific argument (see Type).
	Param string `json:"param"`
}
// GRPCCallbackParam configures a "grpc" Callback; it is the JSON payload
// carried in Callback.Param.
type GRPCCallbackParam struct {
	// Addr is the target address, dialed without TLS (insecure credentials).
	Addr string `json:"addr"`
	// MethodPath is the full method name, e.g. "/callback.Callback/Invoke".
	MethodPath string `json:"method_path"`
	// RelativePaths are the .proto import paths used to locate File.
	RelativePaths []string `json:"relative_paths"`
	// File is the .proto file declaring the target service.
	File string `json:"file"`
	// InitReq is the protojson-encoded initial callback request.
	InitReq string `json:"init_req"`
	// MapFromReq maps callback-request field paths (keys) to field paths of
	// the original mocked request (values) whose values are copied over.
	MapFromReq map[string]string `json:"map_from_req"`
	// MapFromResp is like MapFromReq but reads from the mocked response.
	MapFromResp map[string]string `json:"map_from_resp"`
}
// Execute sleeps for DelayMilliSeconds and then performs the callback.
//
// req and resp are the mocked RPC's request and response; a "grpc" callback
// may copy fields from them into its own request (see GRPCCallbackParam).
// It returns an error for an unknown Type, a malformed Param, or a failed call.
func (c *Callback) Execute(req, resp *dynamicpb.Message) error {
	time.Sleep(time.Duration(c.DelayMilliSeconds) * time.Millisecond)
	switch c.Type {
	case "echo":
		log.Printf("echo: %s", c.Param)
	case "grpc":
		gcp := &GRPCCallbackParam{}
		if err := json.Unmarshal([]byte(c.Param), gcp); err != nil {
			return fmt.Errorf("parsing grpc callback param: %w", err)
		}
		conn, err := grpc.NewClient(gcp.Addr, grpc.WithTransportCredentials(insecure.NewCredentials()))
		if err != nil {
			log.Printf("Failed to create gRPC client: %v", err)
			return err
		}
		defer func() { _ = conn.Close() }()
		log.Printf("make a gRPC call to %s", gcp.Addr)
		// Named callbackResp so it does not shadow the resp parameter, which
		// is passed in below as the source for MapFromResp.
		callbackResp, err := util.MakeCall(context.Background(), &util.MakeCallParam{
			Conn:          conn,
			MethodPath:    gcp.MethodPath,
			RelativePaths: gcp.RelativePaths,
			File:          gcp.File,
			InitReqJson:   []byte(gcp.InitReq),
			Req:           req,
			Resp:          resp,
			MapFromReq:    gcp.MapFromReq,
			MapFromResp:   gcp.MapFromResp,
		})
		if err != nil {
			log.Printf("Failed to make call: %v", err)
			return err
		}
		log.Printf("got response: %s\n", protojson.Format(callbackResp))
	// other callback types ...
	default:
		return fmt.Errorf("unknown callback type %s", c.Type)
	}
	return nil
}
// AfterReturn lists the callbacks to run after a mocked response was sent.
type AfterReturn struct {
	// Callbacks run sequentially in order; the first error stops the chain.
	Callbacks []Callback `json:"callbacks"`
}
// Execute runs every callback in declaration order, handing each the mocked
// request and response. It stops at, and returns, the first failure; the
// remaining callbacks are skipped.
func (c *AfterReturn) Execute(req, resp *dynamicpb.Message) error {
	for i := range c.Callbacks {
		if err := c.Callbacks[i].Execute(req, resp); err != nil {
			return err
		}
	}
	return nil
}
// Item is one mock entry in db.json. The database is keyed first by method
// path and then by the value of the "x-mock-key" request header.
type Item struct {
	// DelayMilliSeconds is waited before the response header and message are sent.
	DelayMilliSeconds int64 `json:"delay_milli_seconds"`
	// RelativePaths are the .proto import paths used to locate File.
	RelativePaths []string `json:"relative_paths"`
	// File is the .proto file declaring the mocked method.
	File string `json:"file"`
	// Resp is the protojson-encoded response; empty means an empty message.
	Resp string `json:"resp"`
	// MapFromReq maps response field paths (keys) to request field paths
	// (values) whose values are copied into the response.
	MapFromReq map[string]string `json:"map_from_req"`
	// AfterReturn, if non-nil, is executed in the background after the
	// response has been sent.
	AfterReturn *AfterReturn `json:"after_return"`
}
// g runs AfterReturn callback chains in the background, with at most
// runtime.NumCPU() chains active at once. NOTE(review): nothing in this file
// calls g.Wait, so errors recorded by the group are never collected.
var g errgroup.Group

func init() {
	maxInFlight := runtime.NumCPU()
	g.SetLimit(maxInFlight)
}
// GetItemFromFile loads the JSON mock database at file and returns the entry
// for methodPath and mockKey.
//
// The file is read and parsed on every call, so edits to it take effect
// without restarting the server. An error is returned when the file cannot be
// read or parsed, when the method or key is missing, or when the entry is
// JSON null (which would otherwise yield a nil *Item with a nil error).
func GetItemFromFile(file, methodPath, mockKey string) (*Item, error) {
	bs, err := os.ReadFile(file)
	if err != nil {
		return nil, err
	}
	var db map[string]map[string]*Item
	if err := json.Unmarshal(bs, &db); err != nil {
		return nil, fmt.Errorf("parsing mock database %q: %w", file, err)
	}
	items, ok := db[methodPath]
	if !ok {
		return nil, fmt.Errorf("method not found: %s", methodPath)
	}
	item, ok := items[mockKey]
	if !ok {
		return nil, fmt.Errorf("mock key not found: %s", mockKey)
	}
	if item == nil {
		return nil, fmt.Errorf("mock entry is null: %s", mockKey)
	}
	return item, nil
}
// StreamHandler answers RPCs for unregistered services using mock entries
// read from a JSON database file.
type StreamHandler struct {
	// dbFile is the path of the mock database; HandleStream re-reads it on
	// every request.
	dbFile string
}
// HandleStream serves every RPC whose service is not registered (it is
// installed via grpc.UnknownServiceHandler).
//
// It selects the mock entry from the "x-mock-key" request header and decodes
// one request using the entry's .proto file. It then waits
// DelayMilliSeconds, giving up early if the client cancels, and sends the
// "x-is-mocked: 1" header. Next it sends the mocked response, with
// MapFromReq fields copied from the request. Finally it schedules the
// entry's AfterReturn callbacks in the background. Exactly one
// request/response exchange is handled, so this suits unary methods.
func (s *StreamHandler) HandleStream(_ any, ss grpc.ServerStream) error {
	methodPath, ok := grpc.MethodFromServerStream(ss)
	if !ok {
		return errors.New("could not find method")
	}
	log.Printf("method: %v\n", methodPath)

	md, ok := metadata.FromIncomingContext(ss.Context())
	if !ok {
		return errors.New("could not get metadata")
	}
	mockKeyList := md.Get("x-mock-key")
	if len(mockKeyList) == 0 {
		return errors.New("could not find x-mock-key")
	}
	mockKey := mockKeyList[0]
	log.Printf("mock key: %v\n", mockKey)

	item, err := GetItemFromFile(s.dbFile, methodPath, mockKey)
	if err != nil {
		return err
	}
	if item == nil {
		// A JSON null entry; guard against dereferencing it below.
		return fmt.Errorf("mock entry is empty: %s %s", methodPath, mockKey)
	}

	methodDesc, err := util.GetMethodDescriptor(methodPath, item.RelativePaths, item.File)
	if err != nil {
		return err
	}
	req := dynamicpb.NewMessage(methodDesc.GetInputType().UnwrapMessage())
	if err := ss.RecvMsg(req); err != nil {
		return err
	}

	// Simulate server latency, but stop waiting if the client goes away.
	if item.DelayMilliSeconds > 0 {
		timer := time.NewTimer(time.Duration(item.DelayMilliSeconds) * time.Millisecond)
		select {
		case <-timer.C:
		case <-ss.Context().Done():
			timer.Stop()
			return ss.Context().Err()
		}
	}

	if err := ss.SendHeader(metadata.MD{
		"x-is-mocked": []string{"1"},
	}); err != nil {
		log.Printf("could not send header: %v\n", err)
		return err
	}

	resp := dynamicpb.NewMessage(methodDesc.GetOutputType().UnwrapMessage())
	if len(item.Resp) > 0 {
		if err := protojson.Unmarshal([]byte(item.Resp), resp); err != nil {
			log.Printf("could not unmarshal response: %v\n", err)
			return err
		}
	}
	for k, v := range item.MapFromReq {
		val, err := util.GetNestedValue(req, v)
		if err != nil {
			return err
		}
		if err := util.SetNestedValue(resp, k, val); err != nil {
			return err
		}
	}
	if err := ss.SendMsg(resp); err != nil {
		log.Printf("could not send message: %v\n", err)
		return err
	}

	if item.AfterReturn != nil {
		afterReturn := item.AfterReturn
		// g.Go blocks while the group's concurrency limit is reached. Calling
		// it from its own goroutine lets this RPC finish (status + trailers)
		// immediately instead of waiting for earlier callbacks to complete.
		go g.Go(func() error {
			if err := afterReturn.Execute(req, resp); err != nil {
				// Nobody calls g.Wait, so logging is the only place this
				// error can surface.
				log.Printf("after_return callbacks for %s failed: %v\n", methodPath, err)
			}
			return nil
		})
	}
	return nil
}
// NewStreamHandler returns a handler that serves mocks from the JSON
// database stored at dbFile.
func NewStreamHandler(dbFile string) *StreamHandler {
	h := new(StreamHandler)
	h.dbFile = dbFile
	return h
}
// main starts the generic mock server.
//
// Flags:
//
//	-port  TCP port to listen on (default 50052)
//	-db    path of the mock database JSON file (default "db.json")
func main() {
	var (
		port   int
		dbFile string
	)
	flag.IntVar(&port, "port", 50052, "server port")
	flag.StringVar(&dbFile, "db", "db.json", "database file")
	flag.Parse()

	lis, err := net.Listen("tcp", fmt.Sprintf(":%d", port))
	if err != nil {
		log.Fatalf("failed to listen: %v", err)
	}
	// No services are registered, so every RPC reaches the unknown-service handler.
	server := grpc.NewServer(
		grpc.UnknownServiceHandler(NewStreamHandler(dbFile).HandleStream),
	)
	// Report the configured port rather than a hard-coded default.
	log.Printf("Starting generic gRPC server on :%d", port)
	if err := server.Serve(lis); err != nil {
		log.Fatalf("failed to serve: %v", err)
	}
}
go.mod:
module grpc-mock-server
go 1.23.2
require (
github.com/jhump/protoreflect v1.17.0
golang.org/x/sync v0.11.0
google.golang.org/grpc v1.70.0
google.golang.org/protobuf v1.36.5
)
require (
github.com/bufbuild/protocompile v0.14.1 // indirect
github.com/golang/protobuf v1.5.4 // indirect
golang.org/x/net v0.32.0 // indirect
golang.org/x/sys v0.28.0 // indirect
golang.org/x/text v0.21.0 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20241202173237-19429a94021a // indirect
)
test/callback_server_50053.go:
package main
import (
"context"
"log"
"net"
"google.golang.org/grpc"
"google.golang.org/protobuf/encoding/protojson"
"grpc-mock-server/test/callback"
)
// CallbackService is the demo implementation of callback.Callback that
// main serves on :50053.
type CallbackService struct {
	// Embedded generated base type (protoc-gen-go-grpc convention); it
	// supplies defaults for methods not implemented here.
	callback.UnimplementedCallbackServer
}
// Invoke logs the incoming callback request as JSON and always reports success.
func (c *CallbackService) Invoke(ctx context.Context, req *callback.CallbackRequest) (*callback.CallbackResponse, error) {
	log.Printf("Received Callback Request %s", protojson.Format(req))
	reply := &callback.CallbackResponse{Succeeded: true}
	return reply, nil
}
// main serves CallbackService on :50053 until the process exits.
func main() {
	const addr = ":50053"
	lis, err := net.Listen("tcp", addr)
	if err != nil {
		log.Fatalf("failed to listen: %v", err)
	}
	srv := grpc.NewServer()
	callback.RegisterCallbackServer(srv, &CallbackService{})
	log.Println("Starting gRPC server on " + addr)
	if serveErr := srv.Serve(lis); serveErr != nil {
		log.Fatalf("failed to serve: %v", serveErr)
	}
}
test/generic_greeter_client_50052.go:
package main
import (
"context"
"log"
"time"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/metadata"
"google.golang.org/protobuf/encoding/protojson"
"grpc-mock-server/util"
)
// main calls /greeter.Greeter/SayHello on the mock server at localhost:50052
// without generated stubs. The request is built from JSON using descriptors
// parsed from greeter.proto, and the response is logged as JSON. The mock
// entry "key-1" is selected through the x-mock-key header.
func main() {
	conn, err := grpc.NewClient(
		"localhost:50052",
		grpc.WithTransportCredentials(insecure.NewCredentials()),
	)
	if err != nil {
		log.Fatal(err)
	}
	defer func() { _ = conn.Close() }()

	methodPath := "/greeter.Greeter/SayHello"
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	// The mock server chooses its response by x-mock-key.
	md := metadata.New(map[string]string{
		"x-mock-key":   "key-1",
		"x-request-id": "1234567890",
	})
	ctx = metadata.NewOutgoingContext(ctx, md)

	resp, err := util.MakeCall(ctx,
		&util.MakeCallParam{
			Conn:       conn,
			MethodPath: methodPath,
			// Resolved against the working directory. Like db.json, this
			// assumes the program is run from the repository root
			// (go run test/generic_greeter_client_50052.go).
			RelativePaths: []string{"./test/proto/v1"},
			File:          "greeter.proto",
			InitReqJson:   []byte(`{"name": {"first_name": "Tim", "last_name": "Zhou"}}`),
		},
	)
	if err != nil {
		log.Fatal(err)
	}
	log.Println(protojson.Format(resp))
}
test/greeter_client_50052.go:
package main
import (
"context"
"fmt"
"log"
"time"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/metadata"
pb "grpc-mock-server/test/greeter"
)
// main calls Greeter.SayHello on the mock server at localhost:50052 through
// the generated stubs. It selects mock entry "key-1" via x-mock-key, then
// prints the greeting and every response header it received.
func main() {
	conn, err := grpc.NewClient(
		"localhost:50052",
		grpc.WithTransportCredentials(insecure.NewCredentials()),
	)
	if err != nil {
		log.Fatalf("did not connect: %v", err)
	}
	defer func() { _ = conn.Close() }()

	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()
	ctx = metadata.NewOutgoingContext(ctx, metadata.New(map[string]string{
		"x-mock-key":   "key-1",
		"x-request-id": "1234567890",
	}))

	req := &pb.HelloRequest{
		Name:    &pb.Name{FirstName: "Tim", LastName: "Zhou"},
		TraceId: "trace-id-XXX",
	}
	// Filled with the header metadata the server sends back.
	var header metadata.MD
	resp, err := pb.NewGreeterClient(conn).SayHello(ctx, req, grpc.Header(&header))
	if err != nil {
		log.Fatalf("could not greet: %v", err)
	}

	fmt.Printf("Greeting: %s[trace_id=%s]\n", resp.Message, resp.TraceId)
	fmt.Println("Header metadata:")
	for key, values := range header {
		fmt.Printf("\t%s: %s\n", key, values)
	}
}
test/proto/v1/callback.proto:
syntax = "proto3";
package callback;
option go_package = "./test/callback;callback";

// Callback is the downstream service of the demo. db.json configures a
// "grpc" after_return callback that calls /callback.Callback/Invoke on
// localhost:50053, which test/callback_server_50053.go serves.
service Callback {
  rpc Invoke(CallbackRequest) returns (CallbackResponse) {}
}

// CallbackRequest is built by the mock server from the callback's init_req
// JSON, with extra fields copied in via map_from_req / map_from_resp.
message CallbackRequest {
  // Set from init_req in db.json (the demo uses "ok").
  string state = 1;
  // In the demo, copied from the mocked response's trace_id via map_from_resp.
  string trace_id = 2;
}

// CallbackResponse is the callback result; the demo server always sets
// succeeded to true.
message CallbackResponse {
  bool succeeded = 1;
}
test/proto/v1/greeter.proto:
syntax = "proto3";
package greeter;
option go_package = "./test/greeter;greeter";

// Greeter is the example service answered by generic_mock_server.go; no
// hand-written server for it appears in this project.
service Greeter {
  rpc SayHello (HelloRequest) returns (HelloResponse) {}
}

// Name is a person's name, nested inside HelloRequest.
message Name {
  string first_name = 1;
  string last_name = 2;
}

// HelloRequest is the request sent by the test clients.
message HelloRequest {
  Name name = 1;
  // Copied into HelloResponse.trace_id by the map_from_req rule in db.json.
  string trace_id = 2;
}

// HelloResponse is decoded from the mock entry's "resp" JSON.
message HelloResponse {
  string message = 1;
  // Filled from HelloRequest.trace_id (see map_from_req in db.json).
  string trace_id = 2;
}
util/helper.go:
package util
import (
"context"
"errors"
"fmt"
"strings"
"github.com/jhump/protoreflect/desc"
"github.com/jhump/protoreflect/desc/protoparse"
"google.golang.org/grpc"
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/types/dynamicpb"
)
// GetMethodDescriptor 获取方法描述符
// errInvalidMethodPath is wrapped by GetMethodDescriptor when methodPath is
// not of the form "/<package>.<Service>/<Method>".
var errInvalidMethodPath = errors.New("invalid method path")

// GetMethodDescriptor parses file (searched for in relativePaths) and returns
// the descriptor of the method named by methodPath. methodPath must be a full
// gRPC method name such as "/greeter.Greeter/SayHello".
//
// The .proto file is parsed on every call; no caching is done here.
func GetMethodDescriptor(methodPath string, relativePaths []string, file string) (*desc.MethodDescriptor, error) {
	trimmed, ok := strings.CutPrefix(methodPath, "/")
	if !ok {
		return nil, fmt.Errorf("%w: %q", errInvalidMethodPath, methodPath)
	}
	serviceName, methodName, ok := strings.Cut(trimmed, "/")
	if !ok || serviceName == "" || methodName == "" || strings.Contains(methodName, "/") {
		return nil, fmt.Errorf("%w: %q", errInvalidMethodPath, methodPath)
	}

	parser := protoparse.Parser{ImportPaths: relativePaths}
	fds, err := parser.ParseFiles(file)
	if err != nil {
		return nil, err
	}
	if len(fds) == 0 {
		return nil, fmt.Errorf("no descriptors parsed from %q", file)
	}
	// FindService expects the fully-qualified name, e.g. "greeter.Greeter".
	serviceDesc := fds[0].FindService(serviceName)
	if serviceDesc == nil {
		return nil, fmt.Errorf("service not found: %s", serviceName)
	}
	methodDesc := serviceDesc.FindMethodByName(methodName)
	if methodDesc == nil {
		return nil, fmt.Errorf("method not found: %s", methodName)
	}
	return methodDesc, nil
}
// MakeCallParam describes one dynamic unary call made by MakeCall.
type MakeCallParam struct {
	// Conn is the client connection the method is invoked on.
	Conn *grpc.ClientConn
	// MethodPath is the full method name, e.g. "/greeter.Greeter/SayHello".
	MethodPath string
	// RelativePaths are the import paths searched for File.
	RelativePaths []string
	// File is the .proto file declaring the target service.
	File string
	// InitReqJson is the protojson-encoded initial request.
	InitReqJson []byte
	// MapFromReq maps request field paths (keys) to field paths in Req
	// (values) whose values are copied in; ignored when Req is nil.
	MapFromReq map[string]string
	// Req is the optional source message for MapFromReq.
	Req *dynamicpb.Message
	// MapFromResp works like MapFromReq but reads from Resp.
	MapFromResp map[string]string
	// Resp is the optional source message for MapFromResp.
	Resp *dynamicpb.Message
}
// MakeCall 发起 gRPC 调用
// MakeCall performs a unary gRPC call described by param without generated
// stubs and returns the decoded response.
//
// It proceeds in these steps:
//  1. Resolve the method from param.File.
//  2. Build the request from param.InitReqJson. Empty input means an empty
//     request.
//  3. Copy the MapFromReq fields from param.Req, when param.Req is set.
//  4. Copy the MapFromResp fields from param.Resp, when param.Resp is set.
//  5. Invoke the method on param.Conn.
func MakeCall(ctx context.Context, param *MakeCallParam) (*dynamicpb.Message, error) {
	methodPath := param.MethodPath
	methodDesc, err := GetMethodDescriptor(methodPath, param.RelativePaths, param.File)
	if err != nil {
		return nil, err
	}

	req := dynamicpb.NewMessage(methodDesc.GetInputType().UnwrapMessage())
	// protojson rejects zero-length input, so an omitted init request (e.g. a
	// callback config without "init_req") would otherwise always fail.
	if len(param.InitReqJson) > 0 {
		if err := protojson.Unmarshal(param.InitReqJson, req); err != nil {
			return nil, fmt.Errorf("decoding initial request for %s: %w", methodPath, err)
		}
	}
	if param.Req != nil {
		if err := copyMappedFields(req, param.Req, param.MapFromReq); err != nil {
			return nil, err
		}
	}
	if param.Resp != nil {
		if err := copyMappedFields(req, param.Resp, param.MapFromResp); err != nil {
			return nil, err
		}
	}

	resp := dynamicpb.NewMessage(methodDesc.GetOutputType().UnwrapMessage())
	if err := param.Conn.Invoke(ctx, methodPath, req, resp); err != nil {
		return nil, err
	}
	return resp, nil
}

// copyMappedFields copies, for every dstPath→srcPath pair in mapping, the
// value at srcPath in from into dstPath in to.
func copyMappedFields(to, from *dynamicpb.Message, mapping map[string]string) error {
	for dstPath, srcPath := range mapping {
		val, err := GetNestedValue(from, srcPath)
		if err != nil {
			return fmt.Errorf("reading field %q: %w", srcPath, err)
		}
		if err := SetNestedValue(to, dstPath, val); err != nil {
			return fmt.Errorf("writing field %q: %w", dstPath, err)
		}
	}
	return nil
}
// GetNestedValue 从 Message 中获取嵌套字段,层级之间用“.”分隔
// GetNestedValue returns the value at fieldPath in msg. The path consists of
// proto field names separated by ".", e.g. "name.first_name".
//
// Every segment except the last must name a singular message field; the last
// segment's value is returned as-is (a List or Map for repeated/map fields).
// For an unset message field, protoreflect's Get yields an empty read-only
// message, so reading through it produces the final field's default value.
// An error is returned when msg is nil, a segment is unknown, or an
// intermediate segment is not a singular message field.
func GetNestedValue(msg *dynamicpb.Message, fieldPath string) (protoreflect.Value, error) {
	if msg == nil {
		return protoreflect.Value{}, fmt.Errorf("cannot read %q from a nil message", fieldPath)
	}
	fields := strings.Split(fieldPath, ".")
	current := msg
	for i, fieldName := range fields {
		// Field descriptor for this level of the path.
		fd := current.Descriptor().Fields().ByName(protoreflect.Name(fieldName))
		if fd == nil {
			return protoreflect.Value{}, fmt.Errorf("field %q not found", fieldName)
		}
		// Last segment: return its value.
		if i == len(fields)-1 {
			return current.Get(fd), nil
		}
		// Intermediate segments must be singular messages. Repeated and map
		// fields hold a List/Map, on which Value.Message panics.
		if fd.Message() == nil {
			return protoreflect.Value{}, fmt.Errorf("field %q is not a message", fieldName)
		}
		if fd.IsList() || fd.IsMap() {
			return protoreflect.Value{}, fmt.Errorf("field %q is a repeated or map field", fieldName)
		}
		subMsg := current.Get(fd).Message()
		if subMsg == nil {
			return protoreflect.Value{}, fmt.Errorf("submessage %q is nil", fieldName)
		}
		next, ok := subMsg.(*dynamicpb.Message)
		if !ok {
			return protoreflect.Value{}, fmt.Errorf("submessage %q is not dynamicpb.Message", fieldName)
		}
		current = next
	}
	// Unreachable in practice: strings.Split always returns at least one element.
	return protoreflect.Value{}, nil
}
// SetNestedValue 设置 Message 的嵌套字段,字段之间用“.”分隔
// SetNestedValue stores value at fieldPath in msg. The path consists of proto
// field names separated by ".".
//
// Missing intermediate messages are created. Every intermediate segment must
// be a singular message field.
//
// The last segment is handled by field type:
//   - Repeated and map fields are replaced with the entries of the
//     List/Map in value.
//   - Any other field is type-checked against value and then set.
//
// An invalid (zero) value, or an empty read-only message as returned by Get
// for an unset message field, clears the target field instead.
func SetNestedValue(msg *dynamicpb.Message, fieldPath string, value protoreflect.Value) error {
	if msg == nil {
		return fmt.Errorf("cannot set %q on a nil message", fieldPath)
	}
	fields := strings.Split(fieldPath, ".")
	lastIdx := len(fields) - 1

	// Walk, creating as needed, the intermediate messages.
	current := msg
	for _, fieldName := range fields[:lastIdx] {
		fd := current.Descriptor().Fields().ByName(protoreflect.Name(fieldName))
		if fd == nil {
			return fmt.Errorf("field %q not found", fieldName)
		}
		if fd.Message() == nil {
			return fmt.Errorf("field %q is not a message", fieldName)
		}
		// A repeated/map field holds a List/Map; storing a message there panics.
		if fd.IsList() || fd.IsMap() {
			return fmt.Errorf("field %q is a repeated or map field", fieldName)
		}
		// Mutable allocates and stores an empty sub-message when unset.
		subMsg, ok := current.Mutable(fd).Message().(*dynamicpb.Message)
		if !ok {
			return fmt.Errorf("submessage %q is not dynamicpb.Message", fieldName)
		}
		current = subMsg
	}

	lastField := fields[lastIdx]
	fd := current.Descriptor().Fields().ByName(protoreflect.Name(lastField))
	if fd == nil {
		return fmt.Errorf("field %q not found", lastField)
	}
	// A zero Value carries nothing to copy; treat it as "unset".
	if !value.IsValid() {
		current.Clear(fd)
		return nil
	}

	// Lists and maps are validated element by element by their handlers; the
	// scalar compatibility check below does not apply to them.
	switch {
	case fd.IsList():
		if _, ok := value.Interface().(protoreflect.List); !ok {
			return fmt.Errorf("type mismatch: field %s expects a list, got %T", lastField, value.Interface())
		}
		return handleListField(current, fd, value)
	case fd.IsMap():
		if _, ok := value.Interface().(protoreflect.Map); !ok {
			return fmt.Errorf("type mismatch: field %s expects a map, got %T", lastField, value.Interface())
		}
		return handleMapField(current, fd, value)
	}

	if !isTypeCompatible(fd, value) {
		return fmt.Errorf("type mismatch: field %s expects %s, got %T",
			lastField, fd.Kind(), value.Interface())
	}
	// dynamicpb refuses to store an invalid (empty read-only) message, which
	// is what Get returns for an unset message field; mirror it by clearing.
	if m, ok := value.Interface().(protoreflect.Message); ok && !m.IsValid() {
		current.Clear(fd)
		return nil
	}
	current.Set(fd, value)
	return nil
}
// isTypeCompatible 判读 value 的类型是否与 FileDescriptor 兼容
// isTypeCompatible reports whether value can be stored in a singular field
// described by fd (or be an element of a repeated field described by fd).
// An invalid value is always considered compatible.
//
// The rules are:
//   - Enum fields accept protoreflect.EnumNumber values.
//   - Message and group fields accept messages with the same full name.
//     Names are compared rather than descriptor pointers, because parsing the
//     same .proto twice yields distinct descriptor instances; dynamicpb
//     itself also checks names on Set.
//   - Scalar fields accept the Go type that represents their kind.
func isTypeCompatible(fd protoreflect.FieldDescriptor, value protoreflect.Value) bool {
	if !value.IsValid() {
		return true
	}
	switch fd.Kind() {
	case protoreflect.EnumKind:
		_, ok := value.Interface().(protoreflect.EnumNumber)
		return ok
	case protoreflect.MessageKind, protoreflect.GroupKind:
		valMsg, ok := value.Interface().(protoreflect.Message)
		if !ok {
			return false
		}
		return fd.Message().FullName() == valMsg.Descriptor().FullName()
	}
	return goKind(fd.Kind()) == ValueKind(value)
}

// goKind maps protobuf scalar kinds that share a Go representation onto the
// kind ValueKind reports for it (e.g. sint32 and sfixed32 are both int32).
func goKind(k protoreflect.Kind) protoreflect.Kind {
	switch k {
	case protoreflect.Sint32Kind, protoreflect.Sfixed32Kind:
		return protoreflect.Int32Kind
	case protoreflect.Sint64Kind, protoreflect.Sfixed64Kind:
		return protoreflect.Int64Kind
	case protoreflect.Fixed32Kind:
		return protoreflect.Uint32Kind
	case protoreflect.Fixed64Kind:
		return protoreflect.Uint64Kind
	}
	return k
}
// ValueKind 返回 Value 的 protoreflect 种类
// ValueKind reports the protoreflect.Kind that corresponds to the Go type
// held by value. Kinds sharing a Go representation collapse to one result:
// an int32 is always Int32Kind, never Sint32Kind or Sfixed32Kind. Lists,
// maps, invalid values and unrecognized types yield Kind(0).
func ValueKind(value protoreflect.Value) protoreflect.Kind {
	kind := protoreflect.Kind(0)
	switch value.Interface().(type) {
	case bool:
		kind = protoreflect.BoolKind
	case int32:
		kind = protoreflect.Int32Kind
	case int64:
		kind = protoreflect.Int64Kind
	case uint32:
		kind = protoreflect.Uint32Kind
	case uint64:
		kind = protoreflect.Uint64Kind
	case float32:
		kind = protoreflect.FloatKind
	case float64:
		kind = protoreflect.DoubleKind
	case string:
		kind = protoreflect.StringKind
	case []byte:
		kind = protoreflect.BytesKind
	case protoreflect.Message:
		kind = protoreflect.MessageKind
	case protoreflect.EnumNumber:
		kind = protoreflect.EnumKind
	}
	return kind
}
// handleListField 处理 List 字段
// handleListField replaces the contents of repeated field fd in msg with the
// elements of the List held in value.
//
// All elements are type-checked and snapshotted before msg is modified. A
// mismatch therefore leaves msg untouched, and copying a list onto itself
// does not erase it. value must hold a protoreflect.List.
func handleListField(msg *dynamicpb.Message, fd protoreflect.FieldDescriptor, value protoreflect.Value) error {
	srcList := value.List()
	vals := make([]protoreflect.Value, 0, srcList.Len())
	for i := 0; i < srcList.Len(); i++ {
		val := srcList.Get(i)
		if !isTypeCompatible(fd, val) {
			return fmt.Errorf("list element type mismatch at index %d", i)
		}
		vals = append(vals, val)
	}

	dstList := msg.Mutable(fd).List()
	dstList.Truncate(0)
	for _, val := range vals {
		dstList.Append(val)
	}
	return nil
}
// handleMapField 处理 Map 字段
// handleMapField replaces the contents of map field fd in msg with the
// entries of the Map held in value.
//
// Every key and value is type-checked, and the entries are snapshotted,
// before msg is modified, so a mismatch leaves msg untouched. value must hold
// a protoreflect.Map.
func handleMapField(msg *dynamicpb.Message, fd protoreflect.FieldDescriptor, value protoreflect.Value) error {
	type entry struct {
		key protoreflect.MapKey
		val protoreflect.Value
	}
	srcMap := value.Map()
	entries := make([]entry, 0, srcMap.Len())
	var err error
	srcMap.Range(func(k protoreflect.MapKey, v protoreflect.Value) bool {
		if !isTypeCompatible(fd.MapKey(), protoreflect.ValueOf(k.Interface())) {
			err = fmt.Errorf("map key type mismatch: %v", k.Interface())
			return false
		}
		if !isTypeCompatible(fd.MapValue(), v) {
			err = fmt.Errorf("map value type mismatch: %v", v.Interface())
			return false
		}
		entries = append(entries, entry{key: k, val: v})
		return true
	})
	if err != nil {
		return err
	}

	// Clearing the field drops all old entries at once, instead of deleting
	// from the map while ranging over it.
	msg.Clear(fd)
	dstMap := msg.Mutable(fd).Map()
	for _, e := range entries {
		dstMap.Set(e.key, e.val)
	}
	return nil
}
2. 测试
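2.1. Setup（需预先安装 protoc 以及 protoc-gen-go、protoc-gen-go-grpc 插件；若 go mod tidy 因 test/greeter、test/callback 尚未生成而报错，请先执行 make test 再重试）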
go mod tidy
go mod download
make test
2.2. 运行（以下三条命令请在项目根目录下、分别在三个终端中执行；前两条会启动常驻服务）
go run generic_mock_server.go
go run test/callback_server_50053.go
go run test/greeter_client_50052.go