代码结构
.
├── client.go
├── coder.go
├── coder_test.go
├── rpc_test.go
├── server.go
├── session.go
└── session_test.go
代码
client.go
package rpc
import (
"net"
"reflect"
)
// rpc 客户端实现
// 抽象客户端方法
type Client struct {
conn net.Conn
}
// client构造方法
func NewClient(conn net.Conn) *Client {
return &Client{conn: conn}
}
// 客户端调用服务端rpc实现
// client.RpcCall("login", &req)
func (c *Client) RpcCall(name string, fpr interface{}) {
// 反射获取函数原型
fn := reflect.ValueOf(fpr).Elem()
// 客户端逻辑的实现
f := func(args []reflect.Value) (results []reflect.Value) {
// 从匿名函数中构建请求参数
inArgs := make([]interface{}, 0, len(args))
for _, v := range args {
inArgs = append(inArgs, v.Interface())
}
// 组装rpc data请求数据
reqData := RpcData{Name: name, Args: inArgs}
// 进行数据编码
reqByteData, err := encode(reqData)
if err != nil {
return
}
// 创建session 对象
session := NewSession(c.conn)
// 客户端发送数据
err = session.Write(reqByteData)
if err != nil {
return
}
// 读取客户端数据
rspByteData, err := session.Read()
if err != nil {
return
}
// 数据进行解码
rspData, err := decode(rspByteData)
if err != nil {
return
}
// 处理服务端返回的数据结果
outArgs := make([]reflect.Value, 0, len(rspData.Args))
for i, v := range rspData.Args {
// 数据特殊情况处理
if v == nil {
// reflect.Zero() 返回某类型的零值的value
// .Out()返回函数输出的参数类型
// 得到具体第几个位置的参数的零值
outArgs = append(outArgs, reflect.Zero(fn.Type().Out(i)))
continue
}
outArgs = append(outArgs, reflect.ValueOf(v))
}
return outArgs
}
// 函数原型到调用的关键,需要2个参数
// 参数1:函数原型,是Type类型
// 参数2:返回类型是Value类型
// 简单理解:参数1是函数原型,参数2是客户端逻辑
v := reflect.MakeFunc(fn.Type(), f)
fn.Set(v)
}
coder.go
package rpc
import (
"bytes"
"encoding/gob"
"fmt"
)
// 对传输的数据进行编解码
// 使用Golang自带的一个数据结构序列化编码/解码工具 gob
// 定义rpc数据交互式数据传输格式
type RpcData struct {
Name string // 调用方法名
Args []interface{} // 调用和返回的参数列表
}
// 编码
func encode(data RpcData) ([]byte, error) {
// gob进行编码
var buf bytes.Buffer
// 得到字节编码器
encoder := gob.NewEncoder(&buf)
// 进行编码
if err := encoder.Encode(data); err != nil {
fmt.Printf("gob encode failed, err: %v\n", err)
return nil, err
}
return buf.Bytes(), nil
}
// 解码
func decode(data []byte) (RpcData, error) {
// 得到字节解码器
buf := bytes.NewBuffer(data)
decoder := gob.NewDecoder(buf)
// 解码数据
var rd RpcData
if err := decoder.Decode(&rd); err != nil {
fmt.Printf("gob decode failed, err: %v\n", err)
return rd, err
}
return rd, nil
}
server.go
package rpc
import (
"net"
"reflect"
)
// rpc 服务端实现
// 抽象服务端
type Server struct {
add string // 连接地址
funcs map[string]reflect.Value // 存储方法名和方法的对应关系,服务注册
}
// server 构造方法
func NewServer(addr string) *Server {
return &Server{add: addr, funcs: make(map[string]reflect.Value)}
}
// 注册接口
func (s *Server) Register(name string, fc interface{}) {
if _, ok := s.funcs[name]; ok {
return
}
s.funcs[name] = reflect.ValueOf(fc)
}
func (s *Server) Run() (err error) {
listener, err := net.Listen("tcp", s.add)
if err != nil {
return
}
for {
// 监听连接
conn, err := listener.Accept()
if err != nil {
conn.Close()
continue
}
// 创建会话
session := NewSession(conn)
// 读取会话请求数据
reqData, err := session.Read()
if err != nil {
conn.Close()
continue
}
// 数据解码
rpcReqData, err := decode(reqData)
// 获取客户端要调用的方法
fc, ok := s.funcs[rpcReqData.Name];
if !ok {
conn.Close()
continue
}
// 获取请求的参数列表
args := make([]reflect.Value, 0, len(rpcReqData.Args))
for _, v := range rpcReqData.Args {
args = append(args, reflect.ValueOf(v))
}
// 调用
callReslut := fc.Call(args)
// 处理调用返回的数据结果
rargs := make([]interface{}, 0, len(callReslut))
for _, rv := range callReslut {
rargs = append(rargs, rv.Interface())
}
// 构建返回的rpc数据
rpcRspData := RpcData{Name: rpcReqData.Name, Args: rargs}
// 返回数据进行编码
rspData, err := encode(rpcRspData)
if err != nil {
conn.Close()
continue
}
err = session.Write(rspData)
if err != nil {
conn.Close()
continue
}
}
return
}
session.go
package rpc
import (
"encoding/binary"
"fmt"
"io"
"net"
)
// 处理连接会话
// 会话对象结构体
type Session struct {
conn net.Conn
}
// 传输数据存储方式
// 字节数组, 添加4个字节的头,用来存储数据的长度
// 会话构造函数
func NewSession(conn net.Conn) *Session {
return &Session{conn: conn}
}
// 从连接中读取数据
func (s *Session) Read() (data []byte, err error) {
// 读取数据header数据
header := make([]byte, 4)
_, err = s.conn.Read(header)
if err != nil {
fmt.Printf("read conn header data failed, err: %v\n", err)
return
}
// 读取body数据
hlen := binary.BigEndian.Uint32(header)
data = make([]byte, hlen)
_, err = io.ReadFull(s.conn, data)
if err != nil {
fmt.Printf("read conn body data failed, err: %v\n", err)
return
}
return
}
// 向连接中写入数据
func (s *Session) Write(data []byte) (err error) {
// 创建数据字节切片
buf := make([]byte, 4+len(data))
// 向header写入数据长度
binary.BigEndian.PutUint32(buf[:4], uint32(len(data)))
// 写入body内容
copy(buf[4:], data)
// 写入连接数据
_, err = s.conn.Write(buf)
if err != nil {
fmt.Printf("write conn data failed, err: %v\n", err)
return
}
return
}
coder_test.go
package rpc
import (
"testing"
)
func TestCoder(t *testing.T) {
rd := RpcData{
Name: "login",
Args: []interface{}{"zhangsan", "zs123"},
}
eData, err := encode(rd)
if err != nil {
t.Error(err)
return
}
t.Logf("gob 编码后数据长度: %d\n", len(eData))
dData, err := decode(eData)
if err != nil {
t.Error(err)
return
}
t.Logf("%#v\n", dData)
}
session_test.go
package rpc
import (
"net"
"sync"
"testing"
)
func TestSession(t *testing.T) {
addr := ":8080"
test_data := "my is test data"
var wg sync.WaitGroup
wg.Add(2)
// 写数据
go func() {
defer wg.Done()
listener, err := net.Listen("tcp", addr)
if err != nil {
t.Fatal(err)
return
}
conn, _ := listener.Accept()
s := NewSession(conn)
data, err := s.Read()
if err != nil {
t.Error(err)
return
}
t.Log(string(data))
}()
// 读数据
go func() {
defer wg.Done()
conn, err := net.Dial("tcp", addr)
if err != nil {
t.Fatal(err)
return
}
s := NewSession(conn)
err = s.Write([]byte(test_data))
if err != nil {
return
}
t.Log("写入数据成功")
return
}()
wg.Wait()
}
rpc_test.go
package rpc
import (
"encoding/gob"
"fmt"
"net"
"testing"
)
// rpc 客户端和服务端测试
// 定义一个服务端结构体
// 定义一个方法
// 通过调用rpc方法查询用户的信息
type User struct {
Name string
Age int
}
// 定义查询用户的方法
// 通过用户id查询用户数据
func queryUser(id int) (User, error) {
// 造一些查询user的假数据
users := make(map[int]User)
users[0] = User{"user01", 22}
users[1] = User{"user02", 23}
users[2] = User{"user03", 24}
if u, ok := users[id]; ok {
return u, nil
}
return User{}, fmt.Errorf("%d id not found", id)
}
func TestRpc(t *testing.T) {
// 给gob注册类型
gob.Register(User{})
addr := ":8080"
// 创建服务端
server := NewServer(addr)
// 注册服务
server.Register("queryUser", queryUser)
// 启动服务端
go server.Run()
// 创建客户端连接
conn, err := net.Dial("tcp", addr)
if err != nil {
return
}
// 创客户端
client := NewClient(conn)
// 定义函数调用原型
var query func(int) (User, error)
// 客户端调用rpc
client.RpcCall("queryUser", &query)
// 得到返回结果
user, err := query(1)
if err != nil {
t.Error(err)
return
}
fmt.Printf("%#v\n", user)
}
原文地址:https://www.cnblogs.com/zhichaoma/p/12638184.html
时间: 2024-10-08 12:44:42