...
|
...
|
@@ -2,11 +2,13 @@ package chat |
|
|
|
|
|
import (
|
|
|
"context"
|
|
|
"errors"
|
|
|
"github.com/gorilla/websocket"
|
|
|
"github.com/zeromicro/go-zero/core/fx"
|
|
|
"github.com/zeromicro/go-zero/rest/httpx"
|
|
|
"gitlab.fjmaimaimai.com/allied-creation/su-micro/cmd/ep/chat/internal/pkg/domain"
|
|
|
"gitlab.fjmaimaimai.com/allied-creation/su-micro/cmd/ep/system/open"
|
|
|
"gitlab.fjmaimaimai.com/allied-creation/su-micro/pkg/ai"
|
|
|
"gitlab.fjmaimaimai.com/allied-creation/su-micro/pkg/contextdata"
|
|
|
"gitlab.fjmaimaimai.com/allied-creation/su-micro/pkg/transaction"
|
|
|
"gitlab.fjmaimaimai.com/allied-creation/su-micro/pkg/xerr"
|
...
|
...
|
@@ -54,6 +56,7 @@ func (l *ChatSessionConversationWsLogic) ChatSessionConversationWs(w http.Respon |
|
|
if session.Module == domain.ModuleSparkChat {
|
|
|
req.ModelId = domain.SparkChatDocModelId
|
|
|
}
|
|
|
var cs = &ConversationService{}
|
|
|
var answer string
|
|
|
var upgrader = websocket.Upgrader{
|
|
|
ReadBufferSize: 1024,
|
...
|
...
|
@@ -64,6 +67,7 @@ func (l *ChatSessionConversationWsLogic) ChatSessionConversationWs(w http.Respon |
|
|
}
|
|
|
var wsconn *websocket.Conn
|
|
|
wsconn, err = upgrader.Upgrade(w, r, nil)
|
|
|
wsconn.SetPingHandler(nil)
|
|
|
if err != nil {
|
|
|
httpx.ErrorCtx(r.Context(), w, err)
|
|
|
return
|
...
|
...
|
@@ -71,12 +75,33 @@ func (l *ChatSessionConversationWsLogic) ChatSessionConversationWs(w http.Respon |
|
|
defer func() {
|
|
|
wsconn.Close()
|
|
|
}()
|
|
|
|
|
|
// 心跳
|
|
|
//go func() {
|
|
|
// ticker := time.NewTicker(30 * time.Second)
|
|
|
// defer ticker.Stop()
|
|
|
// for {
|
|
|
// select {
|
|
|
// case <-ticker.C:
|
|
|
// if err := wsconn.WriteMessage(websocket.PingMessage, nil); err != nil {
|
|
|
// logx.Error("Ping error:", err)
|
|
|
// return
|
|
|
// }
|
|
|
// }
|
|
|
// }
|
|
|
//}()
|
|
|
history := make([]ai.Message, 0)
|
|
|
for {
|
|
|
var text []byte
|
|
|
_, text, err = wsconn.ReadMessage()
|
|
|
if err != nil {
|
|
|
break
|
|
|
}
|
|
|
if string(text) == "ping" {
|
|
|
logx.Infof("->> 收到心跳 用户:%s 会话:%s %s", user.Name, session.Title, string(text))
|
|
|
wsconn.WriteMessage(websocket.PongMessage, text)
|
|
|
continue
|
|
|
}
|
|
|
var beginUnix = time.Now().UnixMilli()
|
|
|
var channel = make(chan string, 5)
|
|
|
dm = &domain.ChatSessionRecord{
|
...
|
...
|
@@ -93,13 +118,12 @@ func (l *ChatSessionConversationWsLogic) ChatSessionConversationWs(w http.Respon |
|
|
}
|
|
|
fx.Parallel(func() {
|
|
|
// 异步访问AI接口
|
|
|
answer, err = Conversation(l.ctx, l.svcCtx, session, req.ModelId, string(text), channel)
|
|
|
answer, err = cs.Conversation(l.ctx, l.svcCtx, session, req.ModelId, string(text), channel, history...)
|
|
|
}, func() {
|
|
|
for {
|
|
|
var v string
|
|
|
if v, ok = <-channel; ok {
|
|
|
if err = wsconn.WriteJSON(types.ChatSessionConversationResponse{Parts: []string{v}, Finished: false}); err != nil {
|
|
|
//httpx.ErrorCtx(r.Context(), w, err)
|
|
|
dm.Status = domain.FinishedFail
|
|
|
break
|
|
|
}
|
...
|
...
|
@@ -110,7 +134,14 @@ func (l *ChatSessionConversationWsLogic) ChatSessionConversationWs(w http.Respon |
|
|
}
|
|
|
return
|
|
|
})
|
|
|
|
|
|
history = append(history, ai.Message{
|
|
|
Role: "user",
|
|
|
Content: string(text),
|
|
|
})
|
|
|
//, ai.Message{
|
|
|
// Role: "assistant",
|
|
|
// Content: string(answer),
|
|
|
// }
|
|
|
// 记录
|
|
|
dm.AnswerText = answer
|
|
|
dm.Cost = time.Now().UnixMilli() - beginUnix
|
...
|
...
|
@@ -127,6 +158,11 @@ func (l *ChatSessionConversationWsLogic) ChatSessionConversationWs(w http.Respon |
|
|
}
|
|
|
|
|
|
if err != nil {
|
|
|
var closeErr *websocket.CloseError
|
|
|
if ok = errors.As(err, &closeErr); ok {
|
|
|
logx.Info("Close connect->", err.Error())
|
|
|
return nil, nil
|
|
|
}
|
|
|
return nil, xerr.NewErrMsgErr("AI模型异常,稍后再试", err)
|
|
|
}
|
|
|
return
|
...
|
...
|
|