OpenAI-Gateway设计与实现
项目背景
微服务API网关通常可以反向代理系统内部的服务,保证内部服务对外不可见和安全。网关作为统一的流量入口,还可以实现负载均衡等、流量限制、消峰平谷、节点路由等功能,与Prometheus结合时还可以提供所有请求的监控统计功能。
在maas(模型即服务)系统的构建过程中,底层使用Ollamm和VLLM等框架推理模型并提供服务。使用K8s实现推理Pod和Server的管理,方便挂载模型文件、与推理日志的采集。在启用多个模型服务之后,需要对外提供统一的模型接口,统计每个模型调用的token使用情况。网关需要实现以下功能:
- 支持OpenAI的标准接口调用,如chat
- 通过请求中模型名称转发到后端的推理服务
- 相同模型推理服务的后端节点的自动选择
- 校验apikey的正确性
- 记录首token、单对话token数量等的统计
- 每个模型的调用统计
项目实现
相关技术
web框架:fiber
fiber基于fasthttp实现,fasthttp引入协程池,对比net/http具有更高的性能。但截至当前暂不支持http2,需要配合其他包实现。
中间件:PostgresqlDb、Kafka、Influxdb、Redis、Prometheus、Grafana
部署方式:docker镜像、K8s、kubevela CI
请求流程

执行步骤:
- 用户使用标准OpenAI接口调用模型,向网关发送请求
- 网关获取请求之后,校验apikey并根据请求参数找到真实的模型地址
- 网关创建一个客户端,使用客户端请求模型
- 客户端得到响应之后,根据需求分析响应的内容(计算token)
- 客户端将响应赋值给网关的响应
- 网关返回响应给用户
项目架构
客户端架构

上图是项目gateway中client的实现,具体实现可以参考在项目中的内容。
重点分析:大模型的调用一般返回的都是流式响应,客户端获取的响应和网关返回给用户的响应都是流式数据,在fasthttp的实现中,响应的读取使用io read实现,读取客户端的响应就会将响应读取完,不能赋值给网关的响应。具体的请求转发逻辑的实现,在那个过程实现。可以封装fasthttp client在其之上实现,也可以直接在网关层实现。
问题解决
流式返回的内容在网关读取的同时,也需要返回给用户。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82
| func (c *client) Do(ctx *fiber.Ctx) error { n := c.applier.nodes[0] apikey := ctx.Get("Authorization")[7:] modelName := ctx.Locals("modelName").(string) reqBody := tools.Analysis(ctx)
clientReq := fasthttp.AcquireRequest() clientRes := fasthttp.AcquireResponse() resp := ctx.Response() req := ctx.Request() resp.StreamBody = true clientRes.StreamBody = true clientReq = ctx.Request() originalURL := utils.CopyString(string(req.RequestURI())) defer clientReq.SetRequestURI(originalURL) copiedURL := utils.CopyString(n.address) clientReq.SetRequestURI(copiedURL) if scheme := getScheme(utils.UnsafeBytes(copiedURL)); len(scheme) > 0 { clientReq.URI().SetSchemeBytes(scheme) } err := n.client.Do(clientReq, clientRes) if err != nil { GatewayLog.Error().Str("ERROR_CODE", e.GetMsg(e.ERROR)).Msg(err.Error()) return err } bodyStream := clientRes.BodyStream() if bodyStream == nil { return errors.New("response body stream is nil") } pr, pw := io.Pipe() startTime := time.Now() EndMessage := "" CompletionTokens := 0
go func() { defer func() { err := pw.Close() if err != nil { GatewayLog.Error().Str("ERROR_CODE", e.GetMsg(e.ERROR)).Msg(err.Error()) } err = clientRes.CloseBodyStream() if err != nil { GatewayLog.Error().Str("ERROR_CODE", e.GetMsg(e.ERROR)).Msg(err.Error()) } }() buf := make([]byte, 1024) for { n, readErr := bodyStream.Read(buf) if readErr != nil && readErr != io.EOF { GatewayLog.Error().Str("ERROR_CODE", e.GetMsg(e.ERROR)).Msg(readErr.Error()) return } _, writeErr := pw.Write(buf[:n]) if writeErr != nil { GatewayLog.Error().Str("ERROR_CODE", e.GetMsg(e.ERROR)).Msg(writeErr.Error()) return } CompletionTokens += 1 if readErr == io.EOF { EndMessage = string(buf[:n]) break } } endTime := time.Now()
go Handler(apikey, modelName, endTime.Sub(startTime), EndMessage, CompletionTokens, reqBody, clientRes.StatusCode(), )
}() ctx.Set("Content-Type", "text/event-stream") resp.SetBodyStream(pr, -1) return nil }
|
在代码中使用pr, pw := io.Pipe()
管道从clientRes.BodyStream
中读取流式数据并做相关的处理,同时将读到的内容写到 pw.Write(buf[:n])
,网关从管道中读到网关的响应resp.SetBodyStream(pr, -1)
。