作者 yangfu

fix 1.0

... ... @@ -4,7 +4,7 @@ Port: 8080
Verbose: false
Migrate: true
Timeout: 30000
Timeout: 60000
LogRequest: true # 记录详细请求日志
Log:
... ... @@ -25,5 +25,5 @@ Redis:
Type: node
Pass:
DB:
DataSource: host=114.55.200.59 user=postgres password=eagle1010 dbname=su_enterprise_platform_preonline port=31543 sslmode=disable TimeZone=Asia/Shanghai
DataSource: host=114.55.200.59 user=postgres password=eagle1010 dbname=su_enterprise_platform port=31543 sslmode=disable TimeZone=Asia/Shanghai
... ...
package chat
import (
"context"
"gitlab.fjmaimaimai.com/allied-creation/su-micro/pkg/contextdata"
"gitlab.fjmaimaimai.com/allied-creation/su-micro/pkg/xerr"
"net/http"
"github.com/zeromicro/go-zero/rest/httpx"
... ... @@ -16,8 +19,15 @@ func ChatSessionConversationWsHandler(svcCtx *svc.ServiceContext) http.HandlerFu
httpx.ErrorCtx(r.Context(), w, err)
return
}
l := chat.NewChatSessionConversationWsLogic(r.Context(), svcCtx)
var (
ctx context.Context
err error
)
if ctx, err = contextdata.ParseToken(r.Context(), svcCtx.Config.SystemAuth.AccessSecret, req.Token); err != nil {
httpx.WriteJson(w, http.StatusUnauthorized, xerr.Error(xerr.TokenExpireError, err.Error()))
return
}
l := chat.NewChatSessionConversationWsLogic(ctx, svcCtx)
resp, err := l.ChatSessionConversationWs(w, r, &req)
if err != nil {
httpx.ErrorCtx(r.Context(), w, err)
... ...
... ... @@ -53,11 +53,6 @@ func RegisterHandlers(server *rest.Server, serverCtx *svc.ServiceContext) {
Handler: chat.ChatSessionConversationHandler(serverCtx),
},
{
Method: http.MethodGet,
Path: "/chat/session/conversation",
Handler: chat.ChatSessionConversationWsHandler(serverCtx),
},
{
Method: http.MethodPost,
Path: "/chat/session/add_files",
Handler: chat.ChatSessionAddFilesHandler(serverCtx),
... ... @@ -98,6 +93,20 @@ func RegisterHandlers(server *rest.Server, serverCtx *svc.ServiceContext) {
[]rest.Middleware{serverCtx.LogRequest},
[]rest.Route{
{
Method: http.MethodGet,
Path: "/chat/session/conversation",
Handler: chat.ChatSessionConversationWsHandler(serverCtx),
},
}...,
),
rest.WithPrefix("/v1"),
)
server.AddRoutes(
rest.WithMiddlewares(
[]rest.Middleware{serverCtx.LogRequest},
[]rest.Route{
{
Method: http.MethodPost,
Path: "/chat/data/session/search",
Handler: chat.ChatDataSessionSearchHandler(serverCtx),
... ...
... ... @@ -56,6 +56,7 @@ func (l *ChatSessionConversationLogic) ChatSessionConversation(w http.ResponseWr
if session.Module == domain.ModuleSparkChat {
req.ModelId = domain.SparkChatDocModelId
}
cs := &ConversationService{}
dm = &domain.ChatSessionRecord{
CompanyId: token.CompanyId,
UserId: token.UserId,
... ... @@ -72,7 +73,7 @@ func (l *ChatSessionConversationLogic) ChatSessionConversation(w http.ResponseWr
var channel = make(chan string, 5)
// 异步访问AI接口
fx.Parallel(func() {
answer, err = Conversation(l.ctx, l.svcCtx, session, req.ModelId, req.Text, channel)
answer, err = cs.Conversation(l.ctx, l.svcCtx, session, req.ModelId, req.Text, channel)
}, func() {
for {
if _, ok = <-channel; !ok {
... ... @@ -104,8 +105,11 @@ func (l *ChatSessionConversationLogic) ChatSessionConversation(w http.ResponseWr
return
}
type ConversationService struct {
}
// Conversation 普通对话
func Conversation(ctx context.Context, svc *svc.ServiceContext, session *domain.ChatSession, modelId int64, text string, channel chan string) (answer string, err error) {
func (c *ConversationService) Conversation(ctx context.Context, svc *svc.ServiceContext, session *domain.ChatSession, modelId int64, text string, channel chan string, messages ...ai.Message) (answer string, err error) {
var (
m *domain.ChatModel
ok bool
... ... @@ -121,9 +125,9 @@ func Conversation(ctx context.Context, svc *svc.ServiceContext, session *domain.
switch m.Id {
// 星火3.5
case 1, 2, 3:
answer, err = ai.ChatGPT(m.Code, m.Config.AppKey, text, channel)
answer, err = ai.ChatGPT(m.Code, m.Config.AppKey, text, channel, messages...)
case 4:
answer, err = ai.ChatSpark(m.Config.AppId, m.Config.AppKey, m.Config.AppSecret, text, channel)
answer, err = ai.ChatSpark(m.Config.AppId, m.Config.AppKey, m.Config.AppSecret, text, channel, messages...)
}
if err != nil {
return "", err
... ...
... ... @@ -2,11 +2,13 @@ package chat
import (
"context"
"errors"
"github.com/gorilla/websocket"
"github.com/zeromicro/go-zero/core/fx"
"github.com/zeromicro/go-zero/rest/httpx"
"gitlab.fjmaimaimai.com/allied-creation/su-micro/cmd/ep/chat/internal/pkg/domain"
"gitlab.fjmaimaimai.com/allied-creation/su-micro/cmd/ep/system/open"
"gitlab.fjmaimaimai.com/allied-creation/su-micro/pkg/ai"
"gitlab.fjmaimaimai.com/allied-creation/su-micro/pkg/contextdata"
"gitlab.fjmaimaimai.com/allied-creation/su-micro/pkg/transaction"
"gitlab.fjmaimaimai.com/allied-creation/su-micro/pkg/xerr"
... ... @@ -54,6 +56,7 @@ func (l *ChatSessionConversationWsLogic) ChatSessionConversationWs(w http.Respon
if session.Module == domain.ModuleSparkChat {
req.ModelId = domain.SparkChatDocModelId
}
var cs = &ConversationService{}
var answer string
var upgrader = websocket.Upgrader{
ReadBufferSize: 1024,
... ... @@ -64,6 +67,7 @@ func (l *ChatSessionConversationWsLogic) ChatSessionConversationWs(w http.Respon
}
var wsconn *websocket.Conn
wsconn, err = upgrader.Upgrade(w, r, nil)
wsconn.SetPingHandler(nil)
if err != nil {
httpx.ErrorCtx(r.Context(), w, err)
return
... ... @@ -71,12 +75,33 @@ func (l *ChatSessionConversationWsLogic) ChatSessionConversationWs(w http.Respon
defer func() {
wsconn.Close()
}()
// 心跳
//go func() {
// ticker := time.NewTicker(30 * time.Second)
// defer ticker.Stop()
// for {
// select {
// case <-ticker.C:
// if err := wsconn.WriteMessage(websocket.PingMessage, nil); err != nil {
// logx.Error("Ping error:", err)
// return
// }
// }
// }
//}()
history := make([]ai.Message, 0)
for {
var text []byte
_, text, err = wsconn.ReadMessage()
if err != nil {
break
}
if string(text) == "ping" {
logx.Infof("->> 收到心跳 用户:%s 会话:%s %s", user.Name, session.Title, string(text))
wsconn.WriteMessage(websocket.PongMessage, text)
continue
}
var beginUnix = time.Now().UnixMilli()
var channel = make(chan string, 5)
dm = &domain.ChatSessionRecord{
... ... @@ -93,13 +118,12 @@ func (l *ChatSessionConversationWsLogic) ChatSessionConversationWs(w http.Respon
}
fx.Parallel(func() {
// 异步访问AI接口
answer, err = Conversation(l.ctx, l.svcCtx, session, req.ModelId, string(text), channel)
answer, err = cs.Conversation(l.ctx, l.svcCtx, session, req.ModelId, string(text), channel, history...)
}, func() {
for {
var v string
if v, ok = <-channel; ok {
if err = wsconn.WriteJSON(types.ChatSessionConversationResponse{Parts: []string{v}, Finished: false}); err != nil {
//httpx.ErrorCtx(r.Context(), w, err)
dm.Status = domain.FinishedFail
break
}
... ... @@ -110,7 +134,14 @@ func (l *ChatSessionConversationWsLogic) ChatSessionConversationWs(w http.Respon
}
return
})
history = append(history, ai.Message{
Role: "user",
Content: string(text),
})
//, ai.Message{
// Role: "assistant",
// Content: string(answer),
// }
// 记录
dm.AnswerText = answer
dm.Cost = time.Now().UnixMilli() - beginUnix
... ... @@ -127,6 +158,11 @@ func (l *ChatSessionConversationWsLogic) ChatSessionConversationWs(w http.Respon
}
if err != nil {
var closeErr *websocket.CloseError
if ok = errors.As(err, &closeErr); ok {
logx.Info("Close connect->", err.Error())
return nil, nil
}
return nil, xerr.NewErrMsgErr("AI模型异常,稍后再试", err)
}
return
... ...
... ... @@ -110,6 +110,7 @@ type ChatSessionConversationResponse struct {
}
type ChatSessionConversationRequestWs struct {
Token string `form:"token,optional"` // 授权的token
SessionId int64 `form:"sessionId"` // 会话ID
ModelId int64 `form:"modelId"` // 模型ID
ContentType string `form:"contentType"` // 内容类型 文本:text (图片:image 文档:document)
... ...
... ... @@ -29,9 +29,9 @@ service Core {
@doc "聊天会话-对话"
@handler chatSessionConversation
post /chat/session/conversation (ChatSessionConversationRequest) returns (ChatSessionConversationResponse)
@doc "聊天会话-对话"
@handler chatSessionConversationWs
get /chat/session/conversation (ChatSessionConversationRequestWs) returns (ChatSessionConversationResponse)
// @doc "聊天会话-对话"
// @handler chatSessionConversationWs
// get /chat/session/conversation (ChatSessionConversationRequestWs) returns (ChatSessionConversationResponse)
@doc "聊天会话-添加文件"
@handler chatSessionAddFiles
... ... @@ -56,6 +56,18 @@ service Core {
get /chat/models (ChatModelsRequest) returns (ChatModelsResponse)
}
// 后台接口
@server(
prefix: v1
group: chat
middleware: LogRequest
)
service Core {
@doc "聊天会话-对话"
@handler chatSessionConversationWs
get /chat/session/conversation (ChatSessionConversationRequestWs) returns (ChatSessionConversationResponse)
}
// 数据管理后台接口
@server(
... ... @@ -183,6 +195,7 @@ type(
Finished bool `json:"finished"`
}
ChatSessionConversationRequestWs{
Token string `form:"token,optional"` // 授权的token
SessionId int64 `form:"sessionId"` // 会话ID
ModelId int64 `form:"modelId"` // 模型ID
ContentType string `form:"contentType"` // 内容类型 文本:text (图片:image 文档:document)
... ...
... ... @@ -2,6 +2,7 @@ package repository
import (
"context"
"fmt"
"github.com/jinzhu/copier"
"github.com/pkg/errors"
"github.com/tiptok/gocomm/pkg/cache"
... ... @@ -187,7 +188,12 @@ func (repository *ChatSessionRecordRepository) FindByCompanyUser(ctx context.Con
total int64
)
queryFunc := func() (interface{}, error) {
tx = tx.Model(&ms).Order("id asc")
tx = tx.Model(&ms)
if v, ok := queryOptions["orderById"]; ok {
tx.Order(fmt.Sprintf("id %v", v))
} else {
tx.Order("id asc")
}
tx.Where("company_id = ?", companyId).Where("user_id = ?", userId)
if v, ok := queryOptions["sessionId"]; ok {
tx.Where("session_id =?", v)
... ...
... ... @@ -50,13 +50,14 @@ func (l *SystemEmployeeUpdateLogic) SystemEmployeeUpdate(req *types.EmployeeUpda
if employee.Code != req.Employee.Code {
// 员工编码唯一
if req.Employee.Code != "" {
if employee, err = l.svcCtx.EmployeeRepository.FindOneByCode(l.ctx, conn, token.CompanyId, req.Employee.Code); err == nil {
var foundEmployee *domain.SysEmployee
if foundEmployee, err = l.svcCtx.EmployeeRepository.FindOneByCode(l.ctx, conn, token.CompanyId, req.Employee.Code); foundEmployee != nil && err == nil {
return nil, xerr.NewErrMsgErr("工号重复", err)
}
}
}
if err = transaction.MustUseTrans(l.ctx, l.svcCtx.DB, func(ctx context.Context, conn transaction.Conn) error {
// 修改姓名
// 修改用户姓名
if user.Name != req.Employee.Name && req.Employee.Phone == user.Phone {
user.Name = req.Employee.Name
if user, err = l.svcCtx.UserRepository.UpdateWithVersion(l.ctx, conn, user); err != nil {
... ...
... ... @@ -7,3 +7,24 @@ update user_department
set employee_id = employee.id
from employee
where employee.user_id = user_department.user_id;
/*
-- suplus_enterprise
-- 用户表
select user_id id,name,phone,avatar,UNIX_TIMESTAMP() created_at,UNIX_TIMESTAMP() updated_at,0 deleted_at,0 version,0 is_del from user_info
where company_id = 1 and user_id <>0 and name ='林忠'
-- 职员表
select uid id,user_id,company_id,'' code,(case when status=1 then 1 else 2 end) account_status,'正式' employee_type,'{}' base_info,'{}' work_info,UNIX_TIMESTAMP() created_at,UNIX_TIMESTAMP() updated_at,0 deleted_at,0 version,0 is_del from user_info
where company_id = 1 and user_id <>0
-- suplus_business_admin
-- 部门表
select id,company_id,name,'' code,parent_id,'[]' department_heads,false is_root,1 sort,UNIX_TIMESTAMP() created_at,UNIX_TIMESTAMP() updated_at,0 deleted_at,0 version,0 is_del from departments where company_id = 1
-- 用户部门表
select company_id,department_id,user_id,0 employee_id,UNIX_TIMESTAMP() created_at from user_departments where company_id = 1 and deleted_at is null
-- 关联职员表更新职员id
*/
\ No newline at end of file
... ...
apiVersion: v1
kind: ConfigMap
metadata:
name: sumicro-chat-config-prd
data:
config.yml: |
Name: sumicro-chat-prd
Host: 0.0.0.0
Port: 8080
Verbose: false
Migrate: true
Timeout: 30000
LogRequest: true # 记录详细请求日志
Log:
Mode: file
Encoding: plain
Level: debug # info
MaxSize: 1 # 2MB
TimeFormat: 2006-01-02 15:04:05
Rotation: size
MaxContentLength: 10240
SystemAuth:
AccessSecret: su-platform
AccessExpire: 360000
Redis:
Host: 192.168.0.243:6379
Type: node
Pass:
DB:
DataSource: host=114.55.200.59 user=postgres password=eagle1010 dbname=su_enterprise_platform port=31544 sslmode=disable TimeZone=Asia/Shanghai
---
apiVersion: v1
kind: ConfigMap
metadata:
name: sumicro-system-config-prd
data:
config.yml: |
Name: sumicro-system-prd
Host: 0.0.0.0
Port: 8080
Verbose: false
Migrate: true
Timeout: 30000
LogRequest: true
Log:
Mode: file
Encoding: plain
Level: debug
MaxSize: 1
TimeFormat: 2006-01-02 15:04:05
Rotation: size
MaxContentLength: 10240
SystemAuth:
AccessSecret: su-platform
AccessExpire: 360000
Redis:
Host: 192.168.0.243:6379
Type: node
Pass:
DB:
DataSource: host=114.55.200.59 user=postgres password=eagle1010 dbname=su_enterprise_platform port=31544 sslmode=disable TimeZone=Asia/Shanghai
... ...
... ... @@ -9,7 +9,7 @@ import (
"io"
)
func ChatGPT(gptModelCode, key string, question string, channel chan string) (answer string, err error) {
func ChatGPT(gptModelCode, key string, question string, channel chan string, messages ...Message) (answer string, err error) {
config := openai.DefaultConfig(key)
config.BaseURL = "http://47.251.84.160:8080/v1" // "https://api.openai.com/v1" //
c := openai.NewClientWithConfig(config)
... ... @@ -18,14 +18,19 @@ func ChatGPT(gptModelCode, key string, question string, channel chan string) (an
req := openai.ChatCompletionRequest{
Model: gptModelCode,
MaxTokens: 2048,
Messages: []openai.ChatCompletionMessage{
{
Role: openai.ChatMessageRoleUser,
Content: question,
},
},
Messages: []openai.ChatCompletionMessage{},
Stream: true,
}
for _, m := range messages {
req.Messages = append(req.Messages, openai.ChatCompletionMessage{
Role: m.Role,
Content: m.Content,
})
}
req.Messages = append(req.Messages, openai.ChatCompletionMessage{
Role: openai.ChatMessageRoleUser,
Content: question,
})
stream, err := c.CreateChatCompletionStream(ctx, req)
if err != nil {
fmt.Printf("ChatCompletionStream error: %v\n", err)
... ...
... ... @@ -29,7 +29,7 @@ var (
hostChatDocumentUrl = "wss://chatdoc.xfyun.cn/openapi/chat"
)
func ChatSpark(appid string, apiKey string, apiSecret string, question string, channel chan string) (answer string, err error) {
func ChatSpark(appid string, apiKey string, apiSecret string, question string, channel chan string, messages ...Message) (answer string, err error) {
// fmt.Println(HmacWithShaTobase64("hmac-sha256", "hello\nhello", "hello"))
// st := time.Now()
d := websocket.Dialer{
... ... @@ -45,7 +45,7 @@ func ChatSpark(appid string, apiKey string, apiSecret string, question string, c
}
go func() {
data := genParams1(appid, question)
data := genParams(appid, question, messages...)
conn.WriteJSON(data)
}()
... ... @@ -119,11 +119,11 @@ type SparkChatMessage struct {
}
// 生成参数
func genParams1(appid, question string) map[string]interface{} { // 根据实际情况修改返回的数据结构和字段名
messages := []Message{
{Role: "user", Content: question},
}
func genParams(appid, question string, list ...Message) map[string]interface{} { // 根据实际情况修改返回的数据结构和字段名
var messages = list
messages = append(messages, Message{
Role: "user", Content: question,
})
data := map[string]interface{}{ // 根据实际情况修改返回的数据结构和字段名
"header": map[string]interface{}{ // 根据实际情况修改返回的数据结构和字段名
... ...
package contextdata
import (
"context"
"github.com/golang-jwt/jwt/v4"
"time"
)
... ... @@ -24,3 +25,16 @@ func (tk UserToken) GenerateToken(secret string, expire int64) (string, error) {
return token.SignedString([]byte(secret))
}
func ParseToken(ctx context.Context, secret string, tokenStr string) (context.Context, error) {
token, err := jwt.Parse(tokenStr, func(token *jwt.Token) (interface{}, error) {
return []byte(secret), nil
})
if err != nil {
return ctx, err
}
for k, v := range token.Claims.(jwt.MapClaims) {
ctx = context.WithValue(ctx, k, v)
}
return ctx, nil
}
... ...
... ... @@ -3,7 +3,10 @@ package contextdata
import (
"context"
"encoding/json"
"fmt"
"github.com/zeromicro/go-zero/core/logx"
"strconv"
"strings"
)
var (
... ... @@ -14,12 +17,14 @@ var (
func GetInt64FromCtx(ctx context.Context, key string) int64 {
var uid int64
if jsonUid, ok := ctx.Value(key).(json.Number); ok {
if int64Uid, err := jsonUid.Int64(); err == nil {
uid = int64Uid
} else {
if value := ctx.Value(key); value != nil {
valueStr := strings.Trim(fmt.Sprintf("%v", value), `"`)
i64, err := strconv.ParseInt(valueStr, 10, 64)
if err != nil {
logx.WithContext(ctx).Errorf("GetUidFromCtx err : %+v", err)
}
uid = i64
return uid
}
return uid
}
... ...