作者 yangfu

feat-1.0 chat doc

package chat
import (
"net/http"
"github.com/zeromicro/go-zero/rest/httpx"
"gitlab.fjmaimaimai.com/allied-creation/su-micro/cmd/ep/chat/api/internal/logic/chat"
"gitlab.fjmaimaimai.com/allied-creation/su-micro/cmd/ep/chat/api/internal/svc"
"gitlab.fjmaimaimai.com/allied-creation/su-micro/cmd/ep/chat/api/internal/types"
)
func ChatSessionAddFilesHandler(svcCtx *svc.ServiceContext) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
var req types.ChatSessionAddFilesRequest
if err := httpx.Parse(r, &req); err != nil {
httpx.ErrorCtx(r.Context(), w, err)
return
}
l := chat.NewChatSessionAddFilesLogic(r.Context(), svcCtx)
resp, err := l.ChatSessionAddFiles(&req)
if err != nil {
httpx.ErrorCtx(r.Context(), w, err)
} else {
httpx.OkJsonCtx(r.Context(), w, resp)
}
}
}
... ...
package chat
import (
"net/http"
"github.com/zeromicro/go-zero/rest/httpx"
"gitlab.fjmaimaimai.com/allied-creation/su-micro/cmd/ep/chat/api/internal/logic/chat"
"gitlab.fjmaimaimai.com/allied-creation/su-micro/cmd/ep/chat/api/internal/svc"
"gitlab.fjmaimaimai.com/allied-creation/su-micro/cmd/ep/chat/api/internal/types"
)
func ChatSessionRemoveFilesHandler(svcCtx *svc.ServiceContext) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
var req types.ChatSessionAddFilesRequest
if err := httpx.Parse(r, &req); err != nil {
httpx.ErrorCtx(r.Context(), w, err)
return
}
l := chat.NewChatSessionRemoveFilesLogic(r.Context(), svcCtx)
resp, err := l.ChatSessionRemoveFiles(&req)
if err != nil {
httpx.ErrorCtx(r.Context(), w, err)
} else {
httpx.OkJsonCtx(r.Context(), w, resp)
}
}
}
... ...
... ... @@ -54,6 +54,16 @@ func RegisterHandlers(server *rest.Server, serverCtx *svc.ServiceContext) {
},
{
Method: http.MethodPost,
Path: "/chat/session/add_files",
Handler: chat.ChatSessionAddFilesHandler(serverCtx),
},
{
Method: http.MethodPost,
Path: "/chat/session/remove_files",
Handler: chat.ChatSessionRemoveFilesHandler(serverCtx),
},
{
Method: http.MethodPost,
Path: "/chat/session/my_spark_sessions",
Handler: chat.ChatMySparkSessionsHandler(serverCtx),
},
... ...
package chat
import (
"context"
"github.com/samber/lo"
"gitlab.fjmaimaimai.com/allied-creation/su-micro/cmd/ep/chat/api/internal/logic/core"
"gitlab.fjmaimaimai.com/allied-creation/su-micro/cmd/ep/chat/internal/pkg/domain"
"gitlab.fjmaimaimai.com/allied-creation/su-micro/pkg/contextdata"
"gitlab.fjmaimaimai.com/allied-creation/su-micro/pkg/transaction"
"gitlab.fjmaimaimai.com/allied-creation/su-micro/pkg/xerr"
"gitlab.fjmaimaimai.com/allied-creation/su-micro/cmd/ep/chat/api/internal/svc"
"gitlab.fjmaimaimai.com/allied-creation/su-micro/cmd/ep/chat/api/internal/types"
"github.com/zeromicro/go-zero/core/logx"
)
type ChatSessionAddFilesLogic struct {
logx.Logger
ctx context.Context
svcCtx *svc.ServiceContext
}
func NewChatSessionAddFilesLogic(ctx context.Context, svcCtx *svc.ServiceContext) *ChatSessionAddFilesLogic {
return &ChatSessionAddFilesLogic{
Logger: logx.WithContext(ctx),
ctx: ctx,
svcCtx: svcCtx,
}
}
func (l *ChatSessionAddFilesLogic) ChatSessionAddFiles(req *types.ChatSessionAddFilesRequest) (resp *types.ChatSessionAddFilesResponse, err error) {
var (
conn = l.svcCtx.DefaultDBConn()
dm *domain.ChatSession
token = contextdata.GetUserTokenFromCtx(l.ctx)
)
if dm, err = l.svcCtx.ChatSessionRepository.FindOne(l.ctx, conn, req.Id); err != nil {
return nil, xerr.NewErrMsgErr("不存在", err)
}
if dm.UserId != token.UserId {
return nil, xerr.NewErrMsgErr("无权限", err)
}
// 赋值
if dm.Module != domain.ModuleSparkChat {
return nil, xerr.NewErrMsgErr("类型有误,星火文档类型才可以添加文档", err)
}
// 更新
if err = transaction.UseTrans(l.ctx, l.svcCtx.DB, func(ctx context.Context, conn transaction.Conn) error {
// 知识库移除
if dm.Type == domain.TypeSparkDatasetChat {
if err = core.DatasetAddFiles(l.ctx, l.svcCtx, conn, dm.Metadata.DatasetId, req.DocumentIds); err != nil {
return err
}
return nil
}
// 普通多文档移除
dm.Metadata.DocumentIds = lo.Union(dm.Metadata.DocumentIds, req.DocumentIds)
dm, err = l.svcCtx.ChatSessionRepository.UpdateWithVersion(l.ctx, conn, dm)
return err
}, true); err != nil {
return nil, xerr.NewErrMsg("更新失败")
}
resp = &types.ChatSessionAddFilesResponse{}
return
}
... ...
package chat
import (
"context"
"github.com/samber/lo"
"gitlab.fjmaimaimai.com/allied-creation/su-micro/cmd/ep/chat/api/internal/logic/core"
"gitlab.fjmaimaimai.com/allied-creation/su-micro/cmd/ep/chat/internal/pkg/domain"
"gitlab.fjmaimaimai.com/allied-creation/su-micro/pkg/contextdata"
"gitlab.fjmaimaimai.com/allied-creation/su-micro/pkg/transaction"
"gitlab.fjmaimaimai.com/allied-creation/su-micro/pkg/xerr"
"gitlab.fjmaimaimai.com/allied-creation/su-micro/cmd/ep/chat/api/internal/svc"
"gitlab.fjmaimaimai.com/allied-creation/su-micro/cmd/ep/chat/api/internal/types"
"github.com/zeromicro/go-zero/core/logx"
)
type ChatSessionRemoveFilesLogic struct {
logx.Logger
ctx context.Context
svcCtx *svc.ServiceContext
}
func NewChatSessionRemoveFilesLogic(ctx context.Context, svcCtx *svc.ServiceContext) *ChatSessionRemoveFilesLogic {
return &ChatSessionRemoveFilesLogic{
Logger: logx.WithContext(ctx),
ctx: ctx,
svcCtx: svcCtx,
}
}
func (l *ChatSessionRemoveFilesLogic) ChatSessionRemoveFiles(req *types.ChatSessionAddFilesRequest) (resp *types.ChatSessionAddFilesResponse, err error) {
var (
conn = l.svcCtx.DefaultDBConn()
dm *domain.ChatSession
token = contextdata.GetUserTokenFromCtx(l.ctx)
)
if dm, err = l.svcCtx.ChatSessionRepository.FindOne(l.ctx, conn, req.Id); err != nil {
return nil, xerr.NewErrMsgErr("不存在", err)
}
if dm.UserId != token.UserId {
return nil, xerr.NewErrMsgErr("无权限", err)
}
// 赋值
if dm.Module != domain.ModuleSparkChat {
return nil, xerr.NewErrMsgErr("类型有误,星火文档类型才可以添加文档", err)
}
// 更新
if err = transaction.UseTrans(l.ctx, l.svcCtx.DB, func(ctx context.Context, conn transaction.Conn) error {
// 知识库移除
if dm.Type == domain.TypeSparkDatasetChat {
if err = core.DatasetRemoveFiles(l.ctx, l.svcCtx, conn, dm.Metadata.DatasetId, req.DocumentIds); err != nil {
return err
}
return nil
}
dm.Metadata.DocumentIds = lo.Without(dm.Metadata.DocumentIds, req.DocumentIds...)
dm, err = l.svcCtx.ChatSessionRepository.UpdateWithVersion(l.ctx, conn, dm)
return err
}, true); err != nil {
return nil, xerr.NewErrMsg("更新失败")
}
resp = &types.ChatSessionAddFilesResponse{}
return
}
... ...
... ... @@ -3,8 +3,10 @@ package chat
import (
"context"
"fmt"
"github.com/samber/lo"
"gitlab.fjmaimaimai.com/allied-creation/su-micro/cmd/ep/chat/internal/pkg/domain"
"gitlab.fjmaimaimai.com/allied-creation/su-micro/pkg/contextdata"
"gitlab.fjmaimaimai.com/allied-creation/su-micro/pkg/tool"
"gitlab.fjmaimaimai.com/allied-creation/su-micro/cmd/ep/chat/api/internal/svc"
"gitlab.fjmaimaimai.com/allied-creation/su-micro/cmd/ep/chat/api/internal/types"
... ... @@ -47,8 +49,45 @@ func (l *ChatSessionSearchLogic) ChatSessionSearch(req *types.ChatSessionSearchR
total, dms, err = l.svcCtx.ChatSessionRepository.Find(l.ctx, conn, queryOptions)
list := make([]types.ChatSessionItem, 0)
for i := range dms {
list = append(list, NewTypesChatSession(dms[i]))
// 1.分组dataset document mapping
var datasetIds []int64
for _, session := range dms {
if session.Type == domain.TypeSparkDatasetChat && session.Metadata.DatasetId > 0 {
datasetIds = append(datasetIds, session.Metadata.DatasetId)
}
datasetIds = lo.Uniq(datasetIds)
}
var groupDocumentMappings map[int64][]*domain.ChatDatasetDocumentMapping
if len(datasetIds) > 0 {
_, documentMapping, _ := l.svcCtx.ChatDatasetDocumentMappingRepository.FindByDataset(l.ctx, conn, datasetIds...)
groupDocumentMappings = lo.GroupBy(documentMapping, func(item *domain.ChatDatasetDocumentMapping) int64 {
return item.DatasetId
})
}
// 2. 加载documents
lazyDocument := tool.NewLazyLoadService(l.svcCtx.ChatDocumentRepository.FindOne)
for i, session := range dms {
var documents []types.ChatDocumentItem
if session.Type == domain.TypeSparkDocumentsChat && len(session.Metadata.DocumentIds) > 0 {
for _, id := range session.Metadata.DocumentIds {
if document, _ := lazyDocument.Load(l.ctx, conn, id); document != nil {
documents = append(documents, types.NewTypesChatDocument(document))
}
}
} else if session.Type == domain.TypeSparkDatasetChat && session.Metadata.DatasetId > 0 {
if documentMapping, ok := groupDocumentMappings[session.Metadata.DatasetId]; ok {
lo.ForEach(documentMapping, func(item *domain.ChatDatasetDocumentMapping, index int) {
if document, _ := lazyDocument.Load(l.ctx, conn, item.DocumentId); document != nil {
documents = append(documents, types.NewTypesChatDocument(document))
}
})
}
}
typesSession := NewTypesChatSession(dms[i])
typesSession.Documents = documents
list = append(list, typesSession)
}
resp = &types.ChatSessionSearchResponse{
List: list,
... ...
package core
import (
"context"
"github.com/samber/lo"
"gitlab.fjmaimaimai.com/allied-creation/su-micro/cmd/ep/chat/api/internal/svc"
"gitlab.fjmaimaimai.com/allied-creation/su-micro/cmd/ep/chat/internal/pkg/domain"
"gitlab.fjmaimaimai.com/allied-creation/su-micro/pkg/transaction"
)
func DatasetAddFiles(ctx context.Context, svcCtx *svc.ServiceContext, conn transaction.Conn, datasetId int64, documents []int64) error {
var (
documentMapping []*domain.ChatDatasetDocumentMapping
err error
)
_, documentMapping, _ = svcCtx.ChatDatasetDocumentMappingRepository.FindByDataset(ctx, conn, datasetId)
for _, id := range documents {
if lo.ContainsBy(documentMapping, func(item *domain.ChatDatasetDocumentMapping) bool {
return item.DocumentId == id
}) {
continue
}
if _, err = svcCtx.ChatDatasetDocumentMappingRepository.Insert(ctx, conn, &domain.ChatDatasetDocumentMapping{
DatasetId: datasetId,
DocumentId: id,
}); err != nil {
return err
}
}
return nil
}
... ...
package core
import (
"context"
"github.com/samber/lo"
"gitlab.fjmaimaimai.com/allied-creation/su-micro/cmd/ep/chat/api/internal/svc"
"gitlab.fjmaimaimai.com/allied-creation/su-micro/cmd/ep/chat/internal/pkg/domain"
"gitlab.fjmaimaimai.com/allied-creation/su-micro/pkg/transaction"
)
func DatasetRemoveFiles(ctx context.Context, svcCtx *svc.ServiceContext, conn transaction.Conn, datasetId int64, documents []int64) error {
var (
documentMapping []*domain.ChatDatasetDocumentMapping
err error
)
_, documentMapping, _ = svcCtx.ChatDatasetDocumentMappingRepository.FindByDataset(ctx, conn, datasetId)
for _, id := range documents {
var found *domain.ChatDatasetDocumentMapping
if !lo.ContainsBy(documentMapping, func(item *domain.ChatDatasetDocumentMapping) bool {
if item.DocumentId == id {
found = item
}
return item.DocumentId == id
}) {
continue
}
if _, err = svcCtx.ChatDatasetDocumentMappingRepository.Delete(ctx, conn, found); err != nil {
return err
}
}
return nil
}
... ...
... ... @@ -2,7 +2,7 @@ package dataset
import (
"context"
"github.com/samber/lo"
"gitlab.fjmaimaimai.com/allied-creation/su-micro/cmd/ep/chat/api/internal/logic/core"
"gitlab.fjmaimaimai.com/allied-creation/su-micro/cmd/ep/chat/internal/pkg/domain"
"gitlab.fjmaimaimai.com/allied-creation/su-micro/pkg/transaction"
"gitlab.fjmaimaimai.com/allied-creation/su-micro/pkg/xerr"
... ... @@ -31,27 +31,15 @@ func (l *ChatDatasetAddFilesLogic) ChatDatasetAddFiles(req *types.ChatDatasetAdd
var (
conn = l.svcCtx.DefaultDBConn()
dm *domain.ChatDataset
documentMapping []*domain.ChatDatasetDocumentMapping
)
// 货号唯一
if dm, err = l.svcCtx.ChatDatasetRepository.FindOne(l.ctx, conn, req.Id); err != nil {
return nil, xerr.NewErrMsgErr("不存在", err)
return nil, xerr.NewErrMsgErr("知识库不存在", err)
}
_, documentMapping, _ = l.svcCtx.ChatDatasetDocumentMappingRepository.FindByDataset(l.ctx, conn, dm.Id)
if err = transaction.UseTrans(l.ctx, l.svcCtx.DB, func(ctx context.Context, conn transaction.Conn) error {
for _, id := range req.DocumentIds {
if lo.ContainsBy(documentMapping, func(item *domain.ChatDatasetDocumentMapping) bool {
return item.DocumentId == id
}) {
continue
}
if _, err = l.svcCtx.ChatDatasetDocumentMappingRepository.Insert(l.ctx, conn, &domain.ChatDatasetDocumentMapping{
DatasetId: dm.Id,
DocumentId: id,
}); err != nil {
if err = core.DatasetAddFiles(l.ctx, l.svcCtx, conn, dm.Id, req.DocumentIds); err != nil {
return err
}
}
return err
}, true); err != nil {
return nil, xerr.NewErrMsg("添加文档失败")
... ...
... ... @@ -2,7 +2,7 @@ package dataset
import (
"context"
"github.com/samber/lo"
"gitlab.fjmaimaimai.com/allied-creation/su-micro/cmd/ep/chat/api/internal/logic/core"
"gitlab.fjmaimaimai.com/allied-creation/su-micro/cmd/ep/chat/internal/pkg/domain"
"gitlab.fjmaimaimai.com/allied-creation/su-micro/pkg/transaction"
"gitlab.fjmaimaimai.com/allied-creation/su-micro/pkg/xerr"
... ... @@ -31,28 +31,15 @@ func (l *ChatDatasetRemvoeFilesLogic) ChatDatasetRemvoeFiles(req *types.ChatData
var (
conn = l.svcCtx.DefaultDBConn()
dm *domain.ChatDataset
documentMapping []*domain.ChatDatasetDocumentMapping
)
// 货号唯一
if dm, err = l.svcCtx.ChatDatasetRepository.FindOne(l.ctx, conn, req.Id); err != nil {
return nil, xerr.NewErrMsgErr("不存在", err)
}
_, documentMapping, _ = l.svcCtx.ChatDatasetDocumentMappingRepository.FindByDataset(l.ctx, conn, dm.Id)
if err = transaction.UseTrans(l.ctx, l.svcCtx.DB, func(ctx context.Context, conn transaction.Conn) error {
for _, id := range req.DocumentIds {
var found *domain.ChatDatasetDocumentMapping
if !lo.ContainsBy(documentMapping, func(item *domain.ChatDatasetDocumentMapping) bool {
if item.DocumentId == id {
found = item
}
return item.DocumentId == id
}) {
continue
}
if _, err = l.svcCtx.ChatDatasetDocumentMappingRepository.Delete(l.ctx, conn, found); err != nil {
if err = core.DatasetRemoveFiles(l.ctx, l.svcCtx, conn, dm.Id, req.DocumentIds); err != nil {
return err
}
}
return err
}, true); err != nil {
return nil, xerr.NewErrMsg("添加文档失败")
... ...
... ... @@ -56,6 +56,7 @@ type ChatSessionItem struct {
Type string `json:"type,optional,omitempty,default=chat"` // 类型 chat:普通问答 spark_dataset_chat:星火知识库问答 spark_documents_chat:星火多文档问答
DatasetId int64 `json:"datasetId,optional,omitempty"` // 知识库
DocumentIds []int64 `json:"documentIds,optional,omitempty"` // 多文档
Documents []ChatDocumentItem `json:"documents,optional,omitempty"` // 多文档
}
type ChatModelsRequest struct {
... ... @@ -130,6 +131,14 @@ type User struct {
Avatar string `json:"avatar"` // 头像
}
type ChatSessionAddFilesRequest struct {
Id int64 `json:"id"` // 文档ID
DocumentIds []int64 `json:"documentIds"` // 文档ID列表
}
type ChatSessionAddFilesResponse struct {
}
type ChatDocumentGetRequest struct {
Id int64 `path:"id"`
}
... ...
... ... @@ -30,6 +30,13 @@ service Core {
@handler chatSessionConversationWs
get /chat/session/conversation (ChatSessionConversationRequestWs) returns (ChatSessionConversationResponse)
@doc "聊天会话-添加文件"
@handler chatSessionAddFiles
post /chat/session/add_files (ChatSessionAddFilesRequest) returns (ChatSessionAddFilesResponse)
@doc "聊天会话-移除文件"
@handler chatSessionRemoveFiles
post /chat/session/remove_files (ChatSessionAddFilesRequest) returns (ChatSessionAddFilesResponse)
@doc "星火聊天会话-我的会话"
@handler chatMySparkSessions
post /chat/session/my_spark_sessions (ChatSessionSearchRequest) returns (ChatSessionSearchResponse)
... ... @@ -95,6 +102,7 @@ type (
Type string `json:"type,optional,omitempty,default=chat"` // 类型 chat:普通问答 spark_dataset_chat:星火知识库问答 spark_documents_chat:星火多文档问答
DatasetId int64 `json:"datasetId,optional,omitempty"` // 知识库
DocumentIds []int64 `json:"documentIds,optional,omitempty"` // 多文档
Documents []ChatDocumentItem `json:"documents,optional,omitempty"` // 多文档
}
)
... ... @@ -171,3 +179,14 @@ type(
Avatar string `json:"avatar"` // 头像
}
)
// 会话添加新文档
type(
ChatSessionAddFilesRequest{
Id int64 `json:"id"` // 文档ID
DocumentIds []int64 `json:"documentIds"` // 文档ID列表
}
ChatSessionAddFilesResponse{
}
)
\ No newline at end of file
... ...