gRPC Mock Server

1. 项目结构

.
├── Makefile
├── db.json
├── generic_mock_server.go
├── go.mod
├── go.sum
├── test
│   ├── callback
│   │   ├── callback.pb.go
│   │   └── callback_grpc.pb.go
│   ├── callback_server_50053.go
│   ├── generic_greeter_client_50052.go
│   ├── greeter
│   │   ├── greeter.pb.go
│   │   └── greeter_grpc.pb.go
│   ├── greeter_client_50052.go
│   └── proto
│       └── v1
│           ├── callback.proto
│           └── greeter.proto
└── util
    └── helper.go

Makefile:

.PHONY: test clean

# test regenerates the Go message and gRPC stubs for the protos under
# test/proto/v1. With --go_out=. the output location follows each proto's
# go_package option (./test/greeter and ./test/callback).
# Needs protoc plus the protoc-gen-go and protoc-gen-go-grpc plugins on PATH.
test:
	protoc --go_out=. --go-grpc_out=. test/proto/v1/greeter.proto
	protoc --go_out=. --go-grpc_out=. test/proto/v1/callback.proto

# clean deletes the generated stub packages.
clean:
	rm -rf ./test/greeter/ ./test/callback/

db.json:

{
  "/greeter.Greeter/SayHello": {
    "key-1": {
      "delay_milli_seconds": 1000,
      "relative_paths": [
        "./test/proto/v1"
      ],
      "file": "greeter.proto",
      "resp": "{\"message\":\"Hello gRPC Mock Server\"}",
      "map_from_req": {
        "trace_id": "trace_id"
      },
      "after_return": {
        "callbacks": [
          {
            "delay_milli_seconds": 1000,
            "type": "echo",
            "param": "submitted"
          },
          {
            "delay_milli_seconds": 1000,
            "type": "echo",
            "param": "succeeded"
          },
          {
            "delay_milli_seconds": 2000,
            "type": "grpc",
            "param": "{\"addr\":\"localhost:50053\",\"method_path\":\"/callback.Callback/Invoke\",\"relative_paths\":[\"./test/proto/v1\"],\"file\":\"callback.proto\",\"init_req\":\"{\\\"state\\\":\\\"ok\\\"}\",\"map_from_resp\":{\"trace_id\":\"trace_id\"}}"
          }
        ]
      }
    }
  }
}

generic_mock_server.go:

package main

import (
	"context"
	"encoding/json"
	"errors"
	"flag"
	"fmt"
	"log"
	"net"
	"os"
	"runtime"
	"time"

	"golang.org/x/sync/errgroup"
	"google.golang.org/grpc"
	"google.golang.org/grpc/credentials/insecure"
	"google.golang.org/grpc/metadata"
	"google.golang.org/protobuf/encoding/protojson"
	"google.golang.org/protobuf/types/dynamicpb"

	"grpc-mock-server/util"
)

type (
	// Callback is one after-return action configured for a mock entry.
	// Type selects the action ("echo" or "grpc"); Param is interpreted per Type,
	// and DelayMilliSeconds is waited before the action runs.
	Callback struct {
		DelayMilliSeconds int64  `json:"delay_milli_seconds"`
		Type              string `json:"type"`
		Param             string `json:"param"`
	}

	// GRPCCallbackParam is the JSON payload carried in Callback.Param for the
	// "grpc" callback type. It names the target (Addr, MethodPath), where to find
	// its .proto definition (RelativePaths, File), the initial request JSON, and
	// field mappings copied from the mocked request/response into that request.
	GRPCCallbackParam struct {
		Addr          string            `json:"addr"`
		MethodPath    string            `json:"method_path"`
		RelativePaths []string          `json:"relative_paths"`
		File          string            `json:"file"`
		InitReq       string            `json:"init_req"`
		MapFromReq    map[string]string `json:"map_from_req"`
		MapFromResp   map[string]string `json:"map_from_resp"`
	}
)

// Execute runs a single after-return callback. It first sleeps for
// DelayMilliSeconds and then dispatches on Type:
//
//   - "echo": logs Param.
//   - "grpc": decodes Param as a JSON GRPCCallbackParam and issues one unary
//     call to Addr/MethodPath through util.MakeCall. req and resp are the mocked
//     request and response; they are the sources for MapFromReq/MapFromResp.
//
// Any other Type returns an error.
func (c *Callback) Execute(req, resp *dynamicpb.Message) error {
	// Upper bound for the outbound callback RPC. Without a deadline, a target that
	// accepts the call but never answers would block this goroutine forever.
	const grpcCallbackTimeout = 30 * time.Second

	time.Sleep(time.Duration(c.DelayMilliSeconds) * time.Millisecond)
	switch c.Type {
	case "echo":
		log.Printf("echo: %s", c.Param)
	case "grpc":
		gcp := &GRPCCallbackParam{}
		if err := json.Unmarshal([]byte(c.Param), gcp); err != nil {
			return fmt.Errorf("decoding grpc callback param: %w", err)
		}
		conn, err := grpc.NewClient(gcp.Addr, grpc.WithTransportCredentials(insecure.NewCredentials()))
		if err != nil {
			log.Printf("Failed to create gRPC client: %v", err)
			return err
		}
		defer func() { _ = conn.Close() }()

		ctx, cancel := context.WithTimeout(context.Background(), grpcCallbackTimeout)
		defer cancel()

		log.Printf("make a gRPC call to %s", gcp.Addr)
		// Named callbackResp (not resp) so it does not shadow the mocked response,
		// which is passed below as the MapFromResp source.
		callbackResp, err := util.MakeCall(ctx, &util.MakeCallParam{
			Conn:          conn,
			MethodPath:    gcp.MethodPath,
			RelativePaths: gcp.RelativePaths,
			File:          gcp.File,
			InitReqJson:   []byte(gcp.InitReq),
			Req:           req,
			Resp:          resp,
			MapFromReq:    gcp.MapFromReq,
			MapFromResp:   gcp.MapFromResp,
		})
		if err != nil {
			log.Printf("Failed to make call: %v", err)
			return err
		}
		log.Printf("got response: %s\n", protojson.Format(callbackResp))
	// other callback types ...
	default:
		return fmt.Errorf("unknown callback type %s", c.Type)
	}
	return nil
}

// AfterReturn holds the callbacks configured for a mock entry; they are executed
// in order by AfterReturn.Execute once the mocked response has been sent.
type AfterReturn struct {
	Callbacks []Callback `json:"callbacks"`
}

// Execute runs every configured callback one after another, in slice order,
// handing each the mocked request and response. The first callback that fails
// aborts the sequence and its error is returned; later callbacks do not run.
func (c *AfterReturn) Execute(req, resp *dynamicpb.Message) error {
	for i := range c.Callbacks {
		cb := c.Callbacks[i]
		if err := cb.Execute(req, resp); err != nil {
			return err
		}
	}
	return nil
}

// Item is one mock entry of the JSON database, stored under
// db[<method path>][<mock key>] and loaded by GetItemFromFile.
type Item struct {
	// Delay before the mocked response is sent.
	DelayMilliSeconds int64             `json:"delay_milli_seconds"`
	// Import paths used to locate File (and anything it imports).
	RelativePaths     []string          `json:"relative_paths"`
	// .proto file that defines the mocked method.
	File              string            `json:"file"`
	// protojson-encoded response; when empty, an empty response message is sent.
	Resp              string            `json:"resp"`
	// Response field path (key) -> request field path (value); each value is
	// copied from the incoming request into the response before sending.
	MapFromReq        map[string]string `json:"map_from_req"`
	// Optional callbacks executed after the response has been sent.
	AfterReturn       *AfterReturn      `json:"after_return"`
}

// g runs the after-return callbacks in the background (used by
// StreamHandler.HandleStream).
// NOTE(review): g.Wait is never called in this file, so errors returned by
// callbacks are not observed through the group.
var g errgroup.Group

// init caps the number of callback goroutines running at once at
// runtime.NumCPU(); errgroup's Go blocks while that many are active.
// NOTE(review): callbacks spend most of their time in time.Sleep
// (Callback.Execute), so a CPU-count limit may be tighter than intended — confirm.
func init() {
	g.SetLimit(runtime.NumCPU())
}

// GetItemFromFile reads the JSON mock database in file and returns the entry
// stored under db[methodPath][mockKey].
//
// The expected layout is {"<method path>": {"<mock key>": <Item>}}. The file
// is read on every call, so edits to it are picked up by the next request.
// An error is returned when the file cannot be read or parsed, when the method
// or mock key is absent, or when the entry is JSON null (which would otherwise
// decode to a nil *Item that callers dereference).
func GetItemFromFile(file, methodPath, mockKey string) (*Item, error) {
	bs, err := os.ReadFile(file)
	if err != nil {
		return nil, fmt.Errorf("reading mock db %q: %w", file, err)
	}
	var db map[string]map[string]*Item
	if err := json.Unmarshal(bs, &db); err != nil {
		return nil, fmt.Errorf("parsing mock db %q: %w", file, err)
	}
	items, ok := db[methodPath]
	if !ok {
		return nil, fmt.Errorf("method not found: %s", methodPath)
	}
	item, ok := items[mockKey]
	if !ok {
		return nil, fmt.Errorf("mock key not found: %s", mockKey)
	}
	if item == nil {
		return nil, fmt.Errorf("mock key %s of method %s has a null entry", mockKey, methodPath)
	}
	return item, nil
}

// StreamHandler answers RPCs for services that are not registered on the gRPC
// server, using the mock entries of a JSON database file.
type StreamHandler struct {
	// Path of the JSON mock database; it is re-read for every request.
	dbFile string
}

// HandleStream answers every RPC whose service is not registered on the server
// (main installs it via grpc.UnknownServiceHandler).
//
// It takes the mock key from the first "x-mock-key" request metadata value and
// loads db[method][mockKey] from s.dbFile. It resolves the method's request and
// response types from the entry's .proto file and receives ONE request message.
// It then waits DelayMilliSeconds and sends an "x-is-mocked: 1" header. Finally it
// sends ONE response, built from the entry's Resp JSON plus its MapFromReq field
// copies. Because exactly one message is read and written, only unary RPCs are
// mocked faithfully.
//
// Once the response has been sent, the entry's AfterReturn callbacks (if any)
// are scheduled on the package errgroup g and run asynchronously.
func (s *StreamHandler) HandleStream(_ any, ss grpc.ServerStream) error {
	methodPath, ok := grpc.MethodFromServerStream(ss)
	if !ok {
		return errors.New("could not find method")
	}
	log.Printf("method: %v\n", methodPath)

	md, ok := metadata.FromIncomingContext(ss.Context())
	if !ok {
		return errors.New("could not get metadata")
	}
	mockKeyList, found := md["x-mock-key"]
	if !found || len(mockKeyList) == 0 {
		return errors.New("could not find x-mock-key")
	}
	mockKey := mockKeyList[0]
	log.Printf("mock key: %v\n", mockKey)

	item, err := GetItemFromFile(s.dbFile, methodPath, mockKey)
	if err != nil {
		return err
	}

	methodDesc, err := util.GetMethodDescriptor(methodPath, item.RelativePaths, item.File)
	if err != nil {
		return err
	}

	req := dynamicpb.NewMessage(methodDesc.GetInputType().UnwrapMessage())
	if err := ss.RecvMsg(req); err != nil {
		return err
	}

	// Simulated latency. Unlike a bare time.Sleep, stop waiting as soon as the
	// client cancels or its deadline expires.
	if delay := time.Duration(item.DelayMilliSeconds) * time.Millisecond; delay > 0 {
		timer := time.NewTimer(delay)
		select {
		case <-ss.Context().Done():
			timer.Stop()
			return ss.Context().Err()
		case <-timer.C:
		}
	}

	if err := ss.SendHeader(metadata.MD{
		"x-is-mocked": []string{"1"},
	}); err != nil {
		log.Printf("could not send header: %v\n", err)
		return err
	}

	resp := dynamicpb.NewMessage(methodDesc.GetOutputType().UnwrapMessage())
	if len(item.Resp) > 0 {
		if err := protojson.Unmarshal([]byte(item.Resp), resp); err != nil {
			log.Printf("could not unmarshal response: %v\n", err)
			return err
		}
	}
	// MapFromReq: response field path (key) <- request field path (value).
	for respPath, reqPath := range item.MapFromReq {
		val, err := util.GetNestedValue(req, reqPath)
		if err != nil {
			return fmt.Errorf("map_from_req: reading request field %q: %w", reqPath, err)
		}
		if err := util.SetNestedValue(resp, respPath, val); err != nil {
			return fmt.Errorf("map_from_req: setting response field %q: %w", respPath, err)
		}
	}
	if err := ss.SendMsg(resp); err != nil {
		log.Printf("could not send message: %v\n", err)
		return err
	}

	if item.AfterReturn != nil {
		afterReturn := item.AfterReturn
		// g.Go blocks while the group is at the limit set in init. Calling it from
		// its own goroutine lets this handler return right away, so the RPC status
		// reaches the client without waiting for a free callback slot. The
		// callbacks still respect the limit. g.Wait is not called anywhere in this
		// file, so failures are logged here instead of being silently dropped.
		go g.Go(func() error {
			if err := afterReturn.Execute(req, resp); err != nil {
				log.Printf("after_return callbacks for %s failed: %v", methodPath, err)
				return err
			}
			return nil
		})
	}
	return nil
}

// NewStreamHandler builds a StreamHandler that serves mock entries from the
// JSON database located at dbFile.
func NewStreamHandler(dbFile string) *StreamHandler {
	h := new(StreamHandler)
	h.dbFile = dbFile
	return h
}

// main parses -port and -db, listens on the TCP port, and serves every incoming
// RPC through StreamHandler. No concrete services are registered, so all
// methods are routed to the unknown-service handler.
func main() {
	port := flag.Int("port", 50052, "server port")
	dbFile := flag.String("db", "db.json", "database file")
	flag.Parse()

	lis, err := net.Listen("tcp", fmt.Sprintf(":%d", *port))
	if err != nil {
		log.Fatalf("failed to listen: %v", err)
	}

	server := grpc.NewServer(
		grpc.UnknownServiceHandler(NewStreamHandler(*dbFile).HandleStream),
	)

	// Report the address that was actually bound instead of a hard-coded port,
	// so the message stays correct when -port is overridden.
	log.Printf("Starting generic gRPC server on %s", lis.Addr())
	if err := server.Serve(lis); err != nil {
		log.Fatalf("failed to serve: %v", err)
	}
}

go.mod:

module grpc-mock-server

go 1.23.2

require (
	github.com/jhump/protoreflect v1.17.0
	golang.org/x/sync v0.11.0
	google.golang.org/grpc v1.70.0
	google.golang.org/protobuf v1.36.5
)

require (
	github.com/bufbuild/protocompile v0.14.1 // indirect
	github.com/golang/protobuf v1.5.4 // indirect
	golang.org/x/net v0.32.0 // indirect
	golang.org/x/sys v0.28.0 // indirect
	golang.org/x/text v0.21.0 // indirect
	google.golang.org/genproto/googleapis/rpc v0.0.0-20241202173237-19429a94021a // indirect
)

test/callback_server_50053.go:

package main

import (
	"context"
	"log"
	"net"

	"google.golang.org/grpc"
	"google.golang.org/protobuf/encoding/protojson"

	"grpc-mock-server/test/callback"
)

// CallbackService is the test implementation of the Callback gRPC service that
// the mock server's "grpc" after-return callback targets. Only Invoke is
// implemented here; other methods come from the embedded
// UnimplementedCallbackServer.
type CallbackService struct {
	callback.UnimplementedCallbackServer
}

// Invoke logs the received callback request as JSON and acknowledges every
// request with Succeeded set to true.
func (c *CallbackService) Invoke(ctx context.Context, req *callback.CallbackRequest) (*callback.CallbackResponse, error) {
	payload := protojson.Format(req)
	log.Printf("Received Callback Request %s", payload)

	ack := &callback.CallbackResponse{}
	ack.Succeeded = true
	return ack, nil
}

// main serves CallbackService on TCP port 50053 until Serve fails.
func main() {
	const listenAddr = ":50053"

	listener, listenErr := net.Listen("tcp", listenAddr)
	if listenErr != nil {
		log.Fatalf("failed to listen: %v", listenErr)
	}

	srv := grpc.NewServer()
	callback.RegisterCallbackServer(srv, &CallbackService{})

	log.Println("Starting gRPC server on :50053")
	if serveErr := srv.Serve(listener); serveErr != nil {
		log.Fatalf("failed to serve: %v", serveErr)
	}
}

test/generic_greeter_client_50052.go:

package main

import (
	"context"
	"log"
	"time"

	"google.golang.org/grpc"
	"google.golang.org/grpc/credentials/insecure"
	"google.golang.org/grpc/metadata"
	"google.golang.org/protobuf/encoding/protojson"

	"grpc-mock-server/util"
)

// main calls /greeter.Greeter/SayHello on the mock server at localhost:50052
// without generated stubs. The method descriptor is parsed at runtime from
// greeter.proto and the request is built from JSON. The "x-mock-key" metadata
// selects which db.json entry the mock server answers with.
func main() {
	conn, err := grpc.NewClient(
		"localhost:50052",
		grpc.WithTransportCredentials(insecure.NewCredentials()),
	)
	if err != nil {
		log.Fatal(err)
	}
	defer func() { _ = conn.Close() }()

	methodPath := "/greeter.Greeter/SayHello"

	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()
	// create metadata
	md := metadata.New(map[string]string{
		"x-mock-key":   "key-1",
		"x-request-id": "1234567890",
	})
	// create context with the metadata
	ctx = metadata.NewOutgoingContext(ctx, md)
	resp, err := util.MakeCall(ctx,
		&util.MakeCallParam{
			Conn:       conn,
			MethodPath: methodPath,
			// Proto import paths are resolved against the working directory. The
			// programs are run from the repository root (go run test/...), the
			// same base db.json uses, so the path must include "test/".
			RelativePaths: []string{"./test/proto/v1"},
			File:          "greeter.proto",
			InitReqJson:   []byte(`{"name": {"first_name": "Tim", "last_name": "Zhou"}}`),
		},
	)
	if err != nil {
		log.Fatal(err)
	}
	log.Println(protojson.Format(resp))
}

test/greeter_client_50052.go:

package main

import (
	"context"
	"fmt"
	"log"
	"time"

	"google.golang.org/grpc"
	"google.golang.org/grpc/credentials/insecure"
	"google.golang.org/grpc/metadata"

	pb "grpc-mock-server/test/greeter"
)

// main calls Greeter.SayHello on the mock server at localhost:50052 using the
// generated stubs. It sends the "x-mock-key" metadata that selects the db.json
// entry, then prints the response and the received header metadata.
func main() {
	conn, err := grpc.NewClient("localhost:50052", grpc.WithTransportCredentials(insecure.NewCredentials()))
	if err != nil {
		log.Fatalf("did not connect: %v", err)
	}
	defer conn.Close()

	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()
	outgoing := metadata.Pairs(
		"x-mock-key", "key-1",
		"x-request-id", "1234567890",
	)
	ctx = metadata.NewOutgoingContext(ctx, outgoing)

	request := &pb.HelloRequest{
		Name: &pb.Name{
			FirstName: "Tim",
			LastName:  "Zhou",
		},
		TraceId: "trace-id-XXX",
	}

	// Header metadata sent by the server (the mock adds "x-is-mocked").
	var header metadata.MD
	greeter := pb.NewGreeterClient(conn)
	reply, callErr := greeter.SayHello(ctx, request, grpc.Header(&header))
	if callErr != nil {
		log.Fatalf("could not greet: %v", callErr)
	}

	fmt.Printf("Greeting: %s[trace_id=%s]\n", reply.Message, reply.TraceId)

	fmt.Println("Header metadata:")
	for key, values := range header {
		fmt.Printf("\t%s: %s\n", key, values)
	}
}

test/proto/v1/callback.proto:

syntax = "proto3";

package callback;

option go_package = "./test/callback;callback";

// Callback is the service the mock server calls after replying, when an entry
// in db.json configures a "grpc" callback with method_path
// "/callback.Callback/Invoke".
service Callback {
  // Invoke delivers one callback notification.
  rpc Invoke(CallbackRequest) returns (CallbackResponse) {}
}

// CallbackRequest is the callback payload.
message CallbackRequest {
  // State value; the sample db.json sets it to "ok" through init_req.
  string state = 1;
  // Trace id; the sample db.json copies it from the mocked response's trace_id
  // (map_from_resp).
  string trace_id = 2;
}

// CallbackResponse acknowledges a callback.
message CallbackResponse {
  // Whether the callback was accepted; the test server always returns true.
  bool succeeded = 1;
}

test/proto/v1/greeter.proto:

syntax = "proto3";

package greeter;

option go_package = "./test/greeter;greeter";

// Greeter is the sample service that the mock server fakes (db.json has an
// entry for "/greeter.Greeter/SayHello").
service Greeter {
  rpc SayHello (HelloRequest) returns (HelloResponse) {}
}

// Name is a person's name, nested in HelloRequest.
message Name {
  string first_name = 1;
  string last_name = 2;
}

message HelloRequest {
  Name name = 1;
  // Trace id; the sample db.json copies it into HelloResponse.trace_id
  // (map_from_req).
  string trace_id = 2;
}

message HelloResponse {
  string message = 1;
  string trace_id = 2;
}

util/helper.go:

 package util

import (
	"context"
	"errors"
	"fmt"
	"strings"

	"github.com/jhump/protoreflect/desc"
	"github.com/jhump/protoreflect/desc/protoparse"
	"google.golang.org/grpc"
	"google.golang.org/protobuf/encoding/protojson"
	"google.golang.org/protobuf/reflect/protoreflect"
	"google.golang.org/protobuf/types/dynamicpb"
)

// GetMethodDescriptor 获取方法描述符
// GetMethodDescriptor parses the .proto file `file`, resolving it (and its
// imports) against the directories in relativePaths, and returns the
// descriptor of the method addressed by methodPath.
//
// methodPath must be a gRPC full method name of the form "/<service>/<method>",
// e.g. "/greeter.Greeter/SayHello". The service part must include the proto
// package, as in the example.
func GetMethodDescriptor(methodPath string, relativePaths []string, file string) (*desc.MethodDescriptor, error) {
	// A well-formed path splits into exactly ["", service, method].
	parts := strings.Split(methodPath, "/")
	if len(parts) != 3 || parts[0] != "" || parts[1] == "" || parts[2] == "" {
		return nil, errors.New("invalid method path: " + methodPath)
	}
	serviceName, methodName := parts[1], parts[2]

	parser := protoparse.Parser{ImportPaths: relativePaths}
	fds, err := parser.ParseFiles(file)
	if err != nil {
		return nil, fmt.Errorf("parsing %q: %w", file, err)
	}
	if len(fds) == 0 {
		return nil, fmt.Errorf("parsing %q: no file descriptor returned", file)
	}
	serviceDesc := fds[0].FindService(serviceName)
	if serviceDesc == nil {
		return nil, fmt.Errorf("service %q not found in %q", serviceName, file)
	}
	methodDesc := serviceDesc.FindMethodByName(methodName)
	if methodDesc == nil {
		return nil, fmt.Errorf("method %q not found in service %q", methodName, serviceName)
	}
	return methodDesc, nil
}

// MakeCallParam describes one dynamic unary call performed by MakeCall.
type MakeCallParam struct {
	// Connection the call is made on.
	Conn          *grpc.ClientConn
	// Full gRPC method name, e.g. "/callback.Callback/Invoke".
	MethodPath    string
	// Import paths used to locate File.
	RelativePaths []string
	// .proto file that defines the method.
	File          string
	// protojson-encoded initial request message.
	InitReqJson   []byte
	// Request field path (key) <- field path in Req (value); applied only when Req is non-nil.
	MapFromReq    map[string]string
	// Optional source message for MapFromReq.
	Req           *dynamicpb.Message
	// Request field path (key) <- field path in Resp (value); applied only when Resp is non-nil.
	MapFromResp   map[string]string
	// Optional source message for MapFromResp.
	Resp          *dynamicpb.Message
}

// MakeCall 发起 gRPC 调用
// MakeCall performs a dynamic unary gRPC call described by param.
//
// It resolves param.MethodPath in param.File (see GetMethodDescriptor) and
// builds the request from param.InitReqJson. An empty InitReqJson means an empty
// request. It then copies fields into the request. For every k -> v in
// MapFromReq, the value at path v in param.Req is written to path k of the
// request, and likewise for MapFromResp with param.Resp. A mapping is skipped
// when its source message is nil. The call is issued on param.Conn with ctx, and
// the decoded response is returned.
func MakeCall(ctx context.Context, param *MakeCallParam) (*dynamicpb.Message, error) {
	methodPath := param.MethodPath
	methodDesc, err := GetMethodDescriptor(methodPath, param.RelativePaths, param.File)
	if err != nil {
		return nil, err
	}

	req := dynamicpb.NewMessage(methodDesc.GetInputType().UnwrapMessage())
	// protojson rejects empty input, so a missing init request is treated as "{}".
	if len(param.InitReqJson) > 0 {
		if err := protojson.Unmarshal(param.InitReqJson, req); err != nil {
			return nil, fmt.Errorf("decoding initial request: %w", err)
		}
	}

	// copyInto applies one mapping (request path <- source path) from src to req.
	copyInto := func(src *dynamicpb.Message, mapping map[string]string, srcName string) error {
		if src == nil {
			return nil
		}
		for dstPath, srcPath := range mapping {
			val, err := GetNestedValue(src, srcPath)
			if err != nil {
				return fmt.Errorf("reading %s field %q: %w", srcName, srcPath, err)
			}
			if err := SetNestedValue(req, dstPath, val); err != nil {
				return fmt.Errorf("setting request field %q: %w", dstPath, err)
			}
		}
		return nil
	}
	if err := copyInto(param.Req, param.MapFromReq, "req"); err != nil {
		return nil, err
	}
	if err := copyInto(param.Resp, param.MapFromResp, "resp"); err != nil {
		return nil, err
	}

	resp := dynamicpb.NewMessage(methodDesc.GetOutputType().UnwrapMessage())
	if err := param.Conn.Invoke(ctx, methodPath, req, resp); err != nil {
		return nil, err
	}
	return resp, nil
}

// GetNestedValue 从 Message 中获取嵌套字段,层级之间用“.”分隔
// GetNestedValue returns the value of the field at fieldPath in msg. fieldPath
// is a "."-separated list of proto field names (e.g. "name.first_name"). Every
// segment except the last must name a singular message field; the last segment
// may be any field and its value is returned as-is from msg.Get.
//
// It returns an error when msg is nil, when a segment names an unknown field, or
// when an intermediate segment is not a singular message field.
func GetNestedValue(msg *dynamicpb.Message, fieldPath string) (protoreflect.Value, error) {
	if msg == nil {
		return protoreflect.Value{}, fmt.Errorf("cannot read %q from a nil message", fieldPath)
	}
	fields := strings.Split(fieldPath, ".")
	current := msg
	for i, fieldName := range fields {
		fd := current.Descriptor().Fields().ByName(protoreflect.Name(fieldName))
		if fd == nil {
			return protoreflect.Value{}, fmt.Errorf("field %q not found", fieldName)
		}

		// Last segment: return whatever the field holds.
		if i == len(fields)-1 {
			return current.Get(fd), nil
		}

		// Intermediate segments must be singular messages. Repeated and map fields
		// are rejected explicitly because Value.Message() panics on list/map values.
		if fd.Message() == nil || fd.IsList() || fd.IsMap() {
			return protoreflect.Value{}, fmt.Errorf("field %q is not a message", fieldName)
		}

		subMsg := current.Get(fd).Message()
		if subMsg == nil {
			return protoreflect.Value{}, fmt.Errorf("submessage %q is nil", fieldName)
		}

		next, ok := subMsg.(*dynamicpb.Message)
		if !ok {
			return protoreflect.Value{}, fmt.Errorf("submessage %q is not dynamicpb.Message", fieldName)
		}
		current = next
	}
	// Unreachable: strings.Split always yields at least one element.
	return protoreflect.Value{}, nil
}

// SetNestedValue 设置 Message 的嵌套字段,字段之间用“.”分隔
// SetNestedValue stores value in the field at fieldPath of msg. fieldPath is a
// "."-separated list of proto field names; missing intermediate singular
// messages are created along the way.
//
// For a repeated field, value must hold a protoreflect.List and replaces the
// list contents. For a map field it must hold a protoreflect.Map and replaces
// the map contents. Otherwise value must be compatible with the field (see
// isTypeCompatible). An invalid (zero) Value clears the field.
func SetNestedValue(msg *dynamicpb.Message, fieldPath string, value protoreflect.Value) error {
	if msg == nil {
		return fmt.Errorf("cannot set %q on a nil message", fieldPath)
	}
	fields := strings.Split(fieldPath, ".")

	// Walk (and create where needed) the intermediate messages.
	current := msg
	for _, fieldName := range fields[:len(fields)-1] {
		fd := current.Descriptor().Fields().ByName(protoreflect.Name(fieldName))
		if fd == nil {
			return fmt.Errorf("field %q not found", fieldName)
		}
		// Intermediate segments must be singular messages; list/map values would
		// make Set/Message() panic below.
		if fd.Message() == nil || fd.IsList() || fd.IsMap() {
			return fmt.Errorf("field %q is not a message", fieldName)
		}

		if !current.Has(fd) {
			current.Set(fd, protoreflect.ValueOfMessage(dynamicpb.NewMessage(fd.Message())))
		}

		subMsg, ok := current.Get(fd).Message().(*dynamicpb.Message)
		if !ok {
			return fmt.Errorf("submessage %q is not dynamicpb.Message", fieldName)
		}
		current = subMsg
	}

	lastField := fields[len(fields)-1]
	fd := current.Descriptor().Fields().ByName(protoreflect.Name(lastField))
	if fd == nil {
		return fmt.Errorf("field %q not found", lastField)
	}

	// A zero Value carries nothing to store; passing it to Set is invalid.
	if !value.IsValid() {
		current.Clear(fd)
		return nil
	}

	// Lists and maps are checked element-wise by their handlers; the singular
	// compatibility check would always reject them.
	switch {
	case fd.IsList():
		if _, ok := value.Interface().(protoreflect.List); !ok {
			return fmt.Errorf("type mismatch: field %s expects a list, got %T", lastField, value.Interface())
		}
		return handleListField(current, fd, value)
	case fd.IsMap():
		if _, ok := value.Interface().(protoreflect.Map); !ok {
			return fmt.Errorf("type mismatch: field %s expects a map, got %T", lastField, value.Interface())
		}
		return handleMapField(current, fd, value)
	default:
		if !isTypeCompatible(fd, value) {
			return fmt.Errorf("type mismatch: field %s expects %s, got %T",
				lastField, fd.Kind(), value.Interface())
		}
		current.Set(fd, value)
	}
	return nil
}

// isTypeCompatible 判读 value 的类型是否与 FileDescriptor 兼容
// isTypeCompatible reports whether value can be stored in a single (non-list,
// non-map) field described by fd. An invalid (zero) Value is accepted.
//
// Compatibility is decided by the Go type the Value holds, so every field kind
// sharing a representation is covered (e.g. int32, sint32 and sfixed32 all hold
// int32). Message and group fields accept any message whose descriptor has the
// same full name. Descriptors parsed from different .proto loads are distinct
// objects even for the same type, so they cannot be compared by identity.
func isTypeCompatible(fd protoreflect.FieldDescriptor, value protoreflect.Value) bool {
	if !value.IsValid() {
		return true
	}

	held := value.Interface()
	switch fd.Kind() {
	case protoreflect.EnumKind:
		_, ok := held.(protoreflect.EnumNumber)
		return ok
	case protoreflect.MessageKind, protoreflect.GroupKind:
		msg, ok := held.(protoreflect.Message)
		return ok && fd.Message() != nil && msg.Descriptor().FullName() == fd.Message().FullName()
	case protoreflect.BoolKind:
		_, ok := held.(bool)
		return ok
	case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind:
		_, ok := held.(int32)
		return ok
	case protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind:
		_, ok := held.(int64)
		return ok
	case protoreflect.Uint32Kind, protoreflect.Fixed32Kind:
		_, ok := held.(uint32)
		return ok
	case protoreflect.Uint64Kind, protoreflect.Fixed64Kind:
		_, ok := held.(uint64)
		return ok
	case protoreflect.FloatKind:
		_, ok := held.(float32)
		return ok
	case protoreflect.DoubleKind:
		_, ok := held.(float64)
		return ok
	case protoreflect.StringKind:
		_, ok := held.(string)
		return ok
	case protoreflect.BytesKind:
		_, ok := held.([]byte)
		return ok
	default:
		return false
	}
}

// ValueKind 返回 Value 的 protoreflect 种类
// ValueKind maps the Go type held by value to the protoreflect.Kind that
// canonically carries it: bool -> BoolKind, int32 -> Int32Kind, string ->
// StringKind, any protoreflect.Message -> MessageKind,
// protoreflect.EnumNumber -> EnumKind, and so on. Lists, maps, invalid values
// and unrecognised types yield Kind(0).
//
// Kinds that share a Go type (e.g. Sint32Kind and Int32Kind) cannot be told
// apart from the value alone; only the canonical kind is ever returned.
func ValueKind(value protoreflect.Value) protoreflect.Kind {
	held := value.Interface()

	// Neither of these can be confused with a builtin scalar type, so check
	// them up front.
	if _, isMsg := held.(protoreflect.Message); isMsg {
		return protoreflect.MessageKind
	}
	if _, isEnum := held.(protoreflect.EnumNumber); isEnum {
		return protoreflect.EnumKind
	}

	kind := protoreflect.Kind(0)
	switch held.(type) {
	case bool:
		kind = protoreflect.BoolKind
	case int32:
		kind = protoreflect.Int32Kind
	case int64:
		kind = protoreflect.Int64Kind
	case uint32:
		kind = protoreflect.Uint32Kind
	case uint64:
		kind = protoreflect.Uint64Kind
	case float32:
		kind = protoreflect.FloatKind
	case float64:
		kind = protoreflect.DoubleKind
	case string:
		kind = protoreflect.StringKind
	case []byte:
		kind = protoreflect.BytesKind
	}
	return kind
}

// handleListField 处理 List 字段
// handleListField replaces the contents of the repeated field fd in msg with the
// elements of value, which must hold a protoreflect.List.
//
// All elements are type-checked against fd before the destination is modified,
// so a mismatch returns an error and leaves the existing list untouched. The
// elements are copied out first, so the copy stays correct even when value is
// the destination list itself.
func handleListField(msg *dynamicpb.Message, fd protoreflect.FieldDescriptor, value protoreflect.Value) error {
	srcList := value.List()

	elems := make([]protoreflect.Value, 0, srcList.Len())
	for i := 0; i < srcList.Len(); i++ {
		val := srcList.Get(i)
		if !isTypeCompatible(fd, val) {
			return fmt.Errorf("list element type mismatch at index %d", i)
		}
		elems = append(elems, val)
	}

	dstList := msg.Mutable(fd).List()
	dstList.Truncate(0)
	for _, val := range elems {
		dstList.Append(val)
	}
	return nil
}

// handleMapField 处理 Map 字段
// handleMapField replaces the contents of the map field fd in msg with the
// entries of value, which must hold a protoreflect.Map.
//
// Every key and value is type-checked against fd.MapKey()/fd.MapValue() before
// the destination is modified, so a mismatch returns an error and leaves the
// existing map untouched. Entries are snapshotted first, which also keeps the
// copy correct when value is the destination map itself.
func handleMapField(msg *dynamicpb.Message, fd protoreflect.FieldDescriptor, value protoreflect.Value) error {
	srcMap := value.Map()

	type entry struct {
		key protoreflect.MapKey
		val protoreflect.Value
	}
	entries := make([]entry, 0, srcMap.Len())
	var err error
	srcMap.Range(func(k protoreflect.MapKey, v protoreflect.Value) bool {
		if !isTypeCompatible(fd.MapKey(), k.Value()) {
			err = fmt.Errorf("map key type mismatch: %v", k.Interface())
			return false
		}
		if !isTypeCompatible(fd.MapValue(), v) {
			err = fmt.Errorf("map value type mismatch: %v", v.Interface())
			return false
		}
		entries = append(entries, entry{key: k, val: v})
		return true
	})
	if err != nil {
		return err
	}

	dstMap := msg.Mutable(fd).Map()
	// Clearing the key currently being visited is the mutation Range permits.
	dstMap.Range(func(k protoreflect.MapKey, _ protoreflect.Value) bool {
		dstMap.Clear(k)
		return true
	})
	for _, e := range entries {
		dstMap.Set(e.key, e.val)
	}
	return nil
}

2. 测试

2.1. Setup

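make test        # 先生成 test/greeter 与 test/callback 下的 Go 代码;go mod tidy 需要这些包存在(需已安装 protoc、protoc-gen-go、protoc-gen-go-grpc)
go mod tidy
go mod download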

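2.2. 测试

以下命令均在项目根目录执行,每个服务需在单独的终端中运行:

# 终端 1:启动 mock server(默认监听 :50052,读取 db.json)
go run generic_mock_server.go
# 终端 2:启动 callback server(监听 :50053)
go run test/callback_server_50053.go

# 终端 3:发起请求
go run test/greeter_client_50052.go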