go动态创建/增加channel并处理数据

发布于:2024-05-07 ⋅ 阅读:(27) ⋅ 点赞:(0)

背景描述

有一个需求,大概可以描述为:有多个websocket连接,因此消息会并发地发送过来,这些消息中有一个标志可以表明是哪个连接发来的消息,但只有收到消息后才能建立channel或写入已有channel,在收消息前无法预先创建channel

解决过程(可直接阅读最终版)

初版:直接写入

因为对数据量错误预估(以为数据量不大),一开始我是用的mysql直接写入,每次收到ws消息立即处理,可测试中发现因数据量过多且都会操作同一行数据,出现了资源竞争,导致死锁。

第二版:增加锁

在发现出现数据竞争后,我第一反应是增加读写锁。读写锁的代码类似以下示例:

package main

import (
	"database/sql"
	"fmt"
	"sync"

	_ "github.com/go-sql-driver/mysql"
)

var (
	db *sql.DB
	mu sync.RWMutex
)

func init() {
	var err error
	db, err = sql.Open("mysql", "username:password@tcp(localhost:3306)/dbname")
	if err != nil {
		panic(err)
	}
}

func main() {
	defer db.Close()

	// 读取数据
	go readData()

	// 写入数据
	go writeData()

	// 保持主线程运行
	select {}
}

func readData() {
	for {
		mu.RLock()
		rows, err := db.Query("SELECT * FROM table_name")
		mu.RUnlock()
		if err != nil {
			fmt.Println("Error reading data:", err)
			continue
		}
		defer rows.Close()

		// 处理查询结果
		// ...

		// 睡眠一段时间,模拟读操作的持续性
		// 请注意,这是一个简单示例,实际应用中可能需要更复杂的逻辑
		// 或使用定时器进行控制
	}
}

func writeData() {
	for {
		mu.Lock()
		_, err := db.Exec("INSERT INTO table_name (column1, column2) VALUES (?, ?)", value1, value2)
		mu.Unlock()
		if err != nil {
			fmt.Println("Error writing data:", err)
			continue
		}

		// 睡眠一段时间,模拟写操作的持续性
		// 请注意,这是一个简单示例,实际应用中可能需要更复杂的逻辑
		// 或使用定时器进行控制
	}
}

但是代码里对数据库的操作非常频繁且混乱,加了读写锁后经常出现请求很慢的情况,考虑其他方案

第三版 使用事务

使用事务代码忽略,最终发现,因为事务过长,导致出现了重复写的问题,考虑其他方案

第四版 map

通过一个二维的map来存储数据,每当数据存满10条就处理,当然毫不意外的,出现了map的竞争。map也是可以用锁的,但是这里是二维的map,加上两层锁之后使得效率极低,而且依旧有概率出现map竞争导致报错

此外,还可以考虑使用redis设置锁,直接set就行了,但是因为环境不支持redis,此方案弃用

最终版 动态channel

出现以上问题的根本原因是消费太快,其实完全可以把每个ws连接的数据都写到各自的channel里,同时设置每个channel都累积10条再消费,当然还需要一个处理机制,如果超过10s也消费一次。

启动"生产者"、“消费者”

在当前环境中,生产者就是每次从ws中读到数据往动态channel中写入,消费者就是不断获取有哪些channel,以及从channel中读数据,在ws写入时的处理逻辑大概可以简化为如下demo:

package test

import (
	"context"
	"encoding/json"
	"github.com/gin-gonic/gin"
	"github.com/gorilla/websocket"
	log "github.com/sirupsen/logrus"
	"net/http"
	"sync"
)

// RequestTemplate 请求模板
type RequestTemplate struct {
	Op   string               `json:"op"`   // 操作
	Id   int                  `json:"id"`   // 唯一id标识
	Time string               `json:"time"` // 时间,用秒级时间戳,字符串包裹
	Data *RequestTemplateData `json:"data"` // 请求数据
	Code int                  `json:"code"` // 状态码
}

// RequestTemplateData 请求中data包含的部分,实际这里是很复杂的结构,之前超时/死锁也是因为这里处理逻辑比较复杂,但是这篇博客的演示重点不是这个,因此简略为id和请求ip
type RequestTemplateData struct {
	ConnIp string `json:"conn_ip"` // 请求ip
	Id     int    `json:"id"`      // 唯一id标识
}

// ConnInfo 具体的连接信息
type ConnInfo struct {
	Conn      *websocket.Conn    `json:"conn"`   // websocket连接
	Ctx       context.Context    `json:"ctx"`    // 连接上下文
	CtxCancel context.CancelFunc `json:"cancel"` // 连接上下文cancel function
	Ip        string             `json:"ip"`     // 连接的手机端ip
	Id        int                `json:"id"`     // 唯一id标识
}

var AllConns = make(map[string]*ConnInfo) //创建字典集合存储连接信息

// Start 启动
func Start() {
	//处理ws的连接
	http.HandleFunc("/ws", HandleMsg)

	// //监听7001端口号,作为websocket连接的服务
	log.Info("Server started on :7001")
	log.Fatal(http.ListenAndServe(":7001", nil))
}

// ChannelStorage channel数据
type ChannelStorage struct {
	sync.RWMutex
	channels map[string]chan *RequestTemplateData
}

var ConnRequestData map[int]*RequestTemplateData

var upgrader = websocket.Upgrader{
	CheckOrigin: func(r *http.Request) bool {
		return true
	},
}

// HandleMsg 处理ws连接,每来一个新客户端请求就建立一个新连接
func HandleMsg(w http.ResponseWriter, r *http.Request) {
	conn, err := upgrader.Upgrade(w, r, nil) // 协议升级,这里也可以直连
	if err != nil {
		log.Error(err)
		return
	}
	//获取连接ip,这里是为了区分每个连接
	connIp := conn.RemoteAddr().String()
	// 这里是为了后续关闭channel
	rootCtx := context.Background()
	ctx, cancel := context.WithCancel(rootCtx)
	//加入连接
	AllConns[connIp] = &ConnInfo{
		Conn:      conn,   // 客户端ws链接对象
		Ctx:       ctx,    // 连接上下文
		CtxCancel: cancel, // 取消连接上下文
	}

	defer func() {
		// 如果断开连接,删除数据
		if AllConns[connIp] != nil {
			AllConns[connIp].CtxCancel()
			delete(ConnRequestData, AllConns[connIp].Id)
			go SetDoneData(AllConns[connIp].Id, conn) // 这里对结束做处理
		}
		delete(AllConns, conn.RemoteAddr().String())
		err = conn.Close()
		if err != nil {
			return
		}
		log.Error("HandleMsg异常,开始defer处理:", err)
		if err := recover(); err != nil {
			log.Error("websocket连接异常,已断开:", err)
		}
	}()

	log.WithFields(log.Fields{
		"connIp": connIp,
	}).Info("沙箱已连接")
	reqCh := &ChannelStorage{}
	go reqCh.ResultConsumer(ctx) // 这里是消费者
	//循环读取ws客户端的消息
	for {
		// 读取消息
		_, msg, err := conn.ReadMessage()
		if err != nil {
			log.WithFields(log.Fields{
				"connIp": connIp,
			}).WithError(err).Error("读取websocket的消息失败")
			if AllConns[connIp] != nil {
				delete(ConnRequestData, AllConns[connIp].Id)
				go SetDoneData(AllConns[connIp].Id, conn) // 连接断开设置状态为结束
			}
			// 断开ws连接
			conn.Close()
			delete(AllConns, conn.RemoteAddr().String())
			return
		}

		//msg []byte转string
		msgStr := string(msg)
		log.Info("收到消息为:", msgStr)

		//反序列化消息为结构体
		requestData := RequestTemplate{}
		if err := json.Unmarshal(msg, &requestData); err != nil {
			conn.WriteJSON(gin.H{"sandbox_id": "位置", "cmd": "未知", "error": "cmd通信的请求参数有误,无法json decode"})
			log.Error("json_decode cmd命令的请求参数时出错:", err)
			continue
		}
		dataInfo := requestData.Data
		// 这里实际上有很多操作,简写为两种
		if requestData.Op != "" {
			switch requestData.Op {
			// 收到报告
			case "report":
				go reqCh.Produce(dataInfo) // "生产者",发送一条消息
			// 已完成
			case "done":
				go CheckDone(dataInfo, conn) // 做完成的处理
			default:
				log.Error("未识别的命令:", msgStr)
			}
		}
	}
}

有一个for循环在持续监听ws消息,消费者只启动一次,这里重点就是生产和消费如何实现

“生产者”

“生产者”要做的事就是:
1 每当收到ws消息后,解析,拿到唯一id(这个唯一是指这个连接下的所有上报消息的id都是相同的)
2 判断这个“唯一id”是否已经创建了channel,若创建了则不需要创建,直接写入channel,若未创建则新建channel
以下是生产者的demo:

// GetChannel 获取通道
func (cs *ChannelStorage) GetChannel(key string) chan *RequestTemplateData {
	cs.RLock()
	defer cs.RUnlock()
	return cs.channels[key]
}

// CreateChannel 创建通道并存储到 map 中
func (cs *ChannelStorage) CreateChannel(key string) chan *RequestTemplateData {
	cs.Lock()
	defer cs.Unlock()
	if cs.channels == nil {
		cs.channels = make(map[string]chan *RequestTemplateData, 800)
	}
	ch := make(chan *RequestTemplateData, 10)
	cs.channels[key] = ch
	return ch
}

// Produce 往上报channel中写数据
func (cs *ChannelStorage) Produce(requestData *RequestTemplateData) {
	defer func() {
		if err := recover(); err != nil {
			log.Info("_____________recover CaseResultAdd error________: ", err)
		}
	}()
	// 创建存储通道的结构体实例
	chanelKey := strconv.Itoa(requestData.Id)
	channel := cs.GetChannel(chanelKey)
	if channel == nil {
		channel = cs.CreateChannel(chanelKey)
	}
	// 直接往channel里面塞
	if channel != nil {
		channel <- requestData
	}
}
消费者

消费者由于只启动一次,但后续可能会有新的channel,因此需要增加一个获取所有连接的方法:
消费者demo:

func (cs *ChannelStorage) ResultConsumer(ctx context.Context) {
	defer func() {
		if err := recover(); err != nil {
			log.Info("_____________recover CaseResultConsumer error________: ", err)
		}
	}()

	for {
		select {
		case <-ctx.Done():
			log.Info("websocket断开连接,消费者协程退出...")
			return
		default:
			cs.processAllChannels(ctx)  // 传入 context.Context
			time.Sleep(2 * time.Second) // 控制处理频率
		}
	}
}

// processAllChannels 获取所有channel
func (cs *ChannelStorage) processAllChannels(ctx context.Context) {
	cs.RLock()
	defer cs.RUnlock()

	var wg sync.WaitGroup // 用于等待所有通道处理完毕
	for chName, channel := range cs.channels {
		wg.Add(1)
		go func(chName string, channel chan *RequestTemplateData) {
			defer wg.Done()
			cs.processChannel(chName, channel, ctx)
		}(chName, channel)
	}
	wg.Wait() // 等待所有通道处理完毕
}
func (cs *ChannelStorage) processChannel(chName string, channel chan *RequestTemplateData, ctx context.Context) {
	const batchSize = 10 // 每次处理的数据量

	var messages []*RequestTemplateData
	targetMsgOverTime := 10 * time.Second // 超时时间
	for {
		select {
		case caseMsg := <-channel:
			messages = append(messages, caseMsg) // 将接收到的消息放入 messages 切片中
			if len(messages) == batchSize {
				tmpMessages := messages
				messages = nil
				processMessages(tmpMessages)
			}
		case <-time.After(targetMsgOverTime):
			log.Info("Timeout reached. Processing...")
			if len(messages) > 0 {
				tmpMessages := messages
				messages = nil
				log.Info("Processing remaining messages for channel:", chName)
				processMessages(tmpMessages)
			}
		case <-ctx.Done(): // 如果收到上下文取消信号,退出函数
			log.Info("______________________error__________cancel______")
			return
		}
	}
}

func processMessages(messages []*RequestTemplateData) {
	// 在这里处理消息就是批量的了
}