diff --git a/internal/conversation_msg/conversation.go b/internal/conversation_msg/conversation.go index 99d24d478..6e74f6fb9 100644 --- a/internal/conversation_msg/conversation.go +++ b/internal/conversation_msg/conversation.go @@ -64,7 +64,9 @@ func (c *Conversation) getAdvancedHistoryMessageList(ctx context.Context, req sd } startTime = m.SendTime } else { - c.messagePullMinSeqMap.Delete(conversationID) + // Clear both maps when the user enters the conversation + c.messagePullForwardEndSeqMap.Delete(conversationID) + c.messagePullReverseEndSeqMap.Delete(conversationID) } log.ZDebug(ctx, "Assembly conversation parameters", "cost time", time.Since(t), "conversationID", conversationID, "startTime:", startTime, "count:", req.Count, "startTime", startTime) @@ -75,41 +77,48 @@ func (c *Conversation) getAdvancedHistoryMessageList(ctx context.Context, req sd log.ZDebug(ctx, "pull message", "pull cost time", time.Since(t)) t = time.Now() - var thisMinSeq int64 - thisMinSeq, messageList = c.LocalChatLog2MsgStruct(ctx, list) + var thisEndSeq int64 + thisEndSeq, messageList = c.LocalChatLog2MsgStruct(list, isReverse) log.ZDebug(ctx, "message convert and unmarshal", "unmarshal cost time", time.Since(t)) t = time.Now() if !isReverse { sort.Sort(messageList) + if thisEndSeq != 0 { + c.messagePullForwardEndSeqMap.Store(conversationID, thisEndSeq) + } + } else { + if thisEndSeq != 0 { + c.messagePullReverseEndSeqMap.Store(conversationID, thisEndSeq) + } } log.ZDebug(ctx, "sort", "sort cost time", time.Since(t)) messageListCallback.MessageList = messageList - if thisMinSeq != 0 { - c.messagePullMinSeqMap.Store(conversationID, thisMinSeq) - } - return &messageListCallback, nil + return &messageListCallback, nil } func (c *Conversation) fetchMessagesWithGapCheck(ctx context.Context, conversationID string, count int, startTime int64, isReverse bool, messageListCallback *sdk.GetAdvancedHistoryMessageListCallback) ([]*model_struct.LocalChatLog, error) { - var list []*model_struct.LocalChatLog + var list, validMessages []*model_struct.LocalChatLog - // If all retrieved messages are either deleted or filtered out, continue fetching messages from an earlier point. - shouldFetchMoreMessages := func(messages []*model_struct.LocalChatLog) bool { + // Get the number of invalid messages in this batch to recursive fetching from earlier points. + shouldFetchMoreMessagesNum := func(messages []*model_struct.LocalChatLog) int { if len(messages) == 0 { - return false + return count } - allDeleted := true + // Represents the number of valid messages in the batch + validateMessageNum := 0 for _, msg := range messages { if msg.Status < constant.MsgStatusHasDeleted { - allDeleted = false - break + validateMessageNum++ + validMessages = append(validMessages, msg) + } else { + log.ZDebug(ctx, "this message has been deleted or exception message", "msg", msg) } } - return allDeleted + return count - validateMessageNum } getNewStartTime := func(messages []*model_struct.LocalChatLog) int64 { if len(messages) == 0 { @@ -128,39 +137,49 @@ func (c *Conversation) fetchMessagesWithGapCheck(ctx context.Context, conversati return nil, err } t = time.Now() - maxSeq := c.validateAndFillInternalGaps(ctx, conversationID, isReverse, + thisStartSeq := c.validateAndFillInternalGaps(ctx, conversationID, isReverse, count, startTime, &list, messageListCallback) log.ZDebug(ctx, "internal continuity check", "cost time", time.Since(t)) t = time.Now() - c.validateAndFillInterBlockGaps(ctx, maxSeq, conversationID, + c.validateAndFillInterBlockGaps(ctx, thisStartSeq, conversationID, isReverse, count, startTime, &list, messageListCallback) log.ZDebug(ctx, "between continuity check", "cost time", time.Since(t)) t = time.Now() c.validateAndFillEndBlockContinuity(ctx, conversationID, isReverse, count, startTime, &list, messageListCallback) log.ZDebug(ctx, "end continuity check", "cost time", time.Since(t)) - // If all retrieved messages are either deleted or filtered out, - //continue fetching recursively until either valid messages are found or all messages have been fetched. - if shouldFetchMoreMessages(list) && !messageListCallback.IsEnd { - return c.fetchMessagesWithGapCheck(ctx, conversationID, count, getNewStartTime(list), isReverse, messageListCallback) + // If the number of valid messages retrieved is less than the count, + // continue fetching recursively until the valid messages are sufficient or all messages have been fetched. + missingCount := shouldFetchMoreMessagesNum(list) + if missingCount > 0 && !messageListCallback.IsEnd { + log.ZDebug(ctx, "fetch more messages", "missingCount", missingCount, "conversationID", conversationID) + missingMessages, err := c.fetchMessagesWithGapCheck(ctx, conversationID, missingCount, getNewStartTime(list), isReverse, messageListCallback) + if err != nil { + return nil, err + } + log.ZDebug(ctx, "fetch more messages", "missingMessages", missingMessages) + return append(validMessages, missingMessages...), nil } - return list, nil + return validMessages, nil } -func (c *Conversation) LocalChatLog2MsgStruct(ctx context.Context, list []*model_struct.LocalChatLog) (int64, []*sdk_struct.MsgStruct) { +func (c *Conversation) LocalChatLog2MsgStruct(list []*model_struct.LocalChatLog, isReverse bool) (int64, []*sdk_struct.MsgStruct) { messageList := make([]*sdk_struct.MsgStruct, 0, len(list)) - var thisMinSeq int64 + var thisEndSeq int64 for _, v := range list { - if v.Seq != 0 && thisMinSeq == 0 { - thisMinSeq = v.Seq + if v.Seq != 0 && thisEndSeq == 0 { + thisEndSeq = v.Seq } - if v.Seq < thisMinSeq && v.Seq != 0 { - thisMinSeq = v.Seq - } - if v.Status >= constant.MsgStatusHasDeleted { - log.ZDebug(ctx, "this message has been deleted or exception message", "msg", v) - continue + if isReverse { + if v.Seq > thisEndSeq && thisEndSeq != 0 { + thisEndSeq = v.Seq + } + + } else { + if v.Seq < thisEndSeq && v.Seq != 0 { + thisEndSeq = v.Seq + } } temp := LocalChatLogToMsgStruct(v) @@ -169,7 +188,7 @@ func (c *Conversation) LocalChatLog2MsgStruct(ctx context.Context, list []*model } messageList = append(messageList, temp) } - return thisMinSeq, messageList + return thisEndSeq, messageList } func (c *Conversation) typingStatusUpdate(ctx context.Context, recvID, msgTip string) error { diff --git a/internal/conversation_msg/conversation_msg.go b/internal/conversation_msg/conversation_msg.go index dfa505f73..91a1d38a8 100644 --- a/internal/conversation_msg/conversation_msg.go +++ b/internal/conversation_msg/conversation_msg.go @@ -51,28 +51,29 @@ var SearchContentType = []int{constant.Text, constant.AtText, constant.File} type Conversation struct { *interaction.LongConnMgr - conversationSyncer *syncer.Syncer[*model_struct.LocalConversation, pbConversation.GetOwnerConversationResp, string] - db db_interface.DataBase - ConversationListener func() open_im_sdk_callback.OnConversationListener - msgListener func() open_im_sdk_callback.OnAdvancedMsgListener - msgKvListener func() open_im_sdk_callback.OnMessageKvInfoListener - batchMsgListener func() open_im_sdk_callback.OnBatchMsgListener - businessListener func() open_im_sdk_callback.OnCustomBusinessListener - recvCH chan common.Cmd2Value - loginUserID string - platformID int32 - DataDir string - relation *relation.Relation - group *group.Group - user *user.User - file *file.File - cache *cache.Cache[string, *model_struct.LocalConversation] - maxSeqRecorder MaxSeqRecorder - messagePullMinSeqMap *cache.Cache[string, int64] - IsExternalExtensions bool - msgOffset int - progress int - conversationSyncMutex sync.Mutex + conversationSyncer *syncer.Syncer[*model_struct.LocalConversation, pbConversation.GetOwnerConversationResp, string] + db db_interface.DataBase + ConversationListener func() open_im_sdk_callback.OnConversationListener + msgListener func() open_im_sdk_callback.OnAdvancedMsgListener + msgKvListener func() open_im_sdk_callback.OnMessageKvInfoListener + batchMsgListener func() open_im_sdk_callback.OnBatchMsgListener + businessListener func() open_im_sdk_callback.OnCustomBusinessListener + recvCH chan common.Cmd2Value + loginUserID string + platformID int32 + DataDir string + relation *relation.Relation + group *group.Group + user *user.User + file *file.File + cache *cache.Cache[string, *model_struct.LocalConversation] + maxSeqRecorder MaxSeqRecorder + messagePullForwardEndSeqMap *cache.Cache[string, int64] + messagePullReverseEndSeqMap *cache.Cache[string, int64] + IsExternalExtensions bool + msgOffset int + progress int + conversationSyncMutex sync.Mutex startTime time.Time @@ -100,20 +101,21 @@ func NewConversation(ctx context.Context, longConnMgr *interaction.LongConnMgr, file *file.File) *Conversation { info := ccontext.Info(ctx) n := &Conversation{db: db, - LongConnMgr: longConnMgr, - recvCH: ch, - loginUserID: info.UserID(), - platformID: info.PlatformID(), - DataDir: info.DataDir(), - relation: relation, - group: group, - user: user, - file: file, - IsExternalExtensions: info.IsExternalExtensions(), - maxSeqRecorder: NewMaxSeqRecorder(), - messagePullMinSeqMap: cache.NewCache[string, int64](), - msgOffset: 0, - progress: 0, + LongConnMgr: longConnMgr, + recvCH: ch, + loginUserID: info.UserID(), + platformID: info.PlatformID(), + DataDir: info.DataDir(), + relation: relation, + group: group, + user: user, + file: file, + IsExternalExtensions: info.IsExternalExtensions(), + maxSeqRecorder: NewMaxSeqRecorder(), + messagePullForwardEndSeqMap: cache.NewCache[string, int64](), + messagePullReverseEndSeqMap: cache.NewCache[string, int64](), + msgOffset: 0, + progress: 0, } n.typing = newTyping(n) n.initSyncer() @@ -835,8 +837,8 @@ func (c *Conversation) batchAddFaceURLAndName(ctx context.Context, conversations conversation.FaceURL = v.FaceURL conversation.ShowName = v.Nickname } else { - log.ZWarn(ctx, "user info not found", errors.New("user not found"),"userID", conversation.UserID) - + log.ZWarn(ctx, "user info not found", errors.New("user not found"), "userID", conversation.UserID) + conversation.FaceURL = "" conversation.ShowName = "UserNotFound" } @@ -929,37 +931,37 @@ func (c *Conversation) FetchSurroundingMessages(ctx context.Context, conversatio if len(res) == 0 { return []*sdk_struct.MsgStruct{}, nil } - _, msgList := c.LocalChatLog2MsgStruct(ctx, []*model_struct.LocalChatLog{res[0]}) - if len(msgList) == 0 { - return []*sdk_struct.MsgStruct{}, nil - } - msg := msgList[0] + //_, msgList := c.LocalChatLog2MsgStruct []*model_struct.LocalChatLog{res[0]}) + //if len(msgList) == 0 { + // return []*sdk_struct.MsgStruct{}, nil + //} + //msg := msgList[0] result := make([]*sdk_struct.MsgStruct, 0, before+after+1) - if before > 0 { - req := sdk.GetAdvancedHistoryMessageListParams{ - ConversationID: conversationID, - Count: int(before), - StartClientMsgID: msg.ClientMsgID, - } - val, err := c.getAdvancedHistoryMessageList(ctx, req, false) - if err != nil { - return nil, err - } - result = append(result, val.MessageList...) - } - result = append(result, msg) - if after > 0 { - req := sdk.GetAdvancedHistoryMessageListParams{ - ConversationID: conversationID, - Count: int(after), - StartClientMsgID: msg.ClientMsgID, - } - val, err := c.getAdvancedHistoryMessageList(ctx, req, true) - if err != nil { - return nil, err - } - result = append(result, val.MessageList...) - } - sort.Sort(sdk_struct.NewMsgList(result)) + //if before > 0 { + // req := sdk.GetAdvancedHistoryMessageListParams{ + // ConversationID: conversationID, + // Count: int(before), + // StartClientMsgID: msg.ClientMsgID, + // } + // val, err := c.getAdvancedHistoryMessageList(ctx, req, false) + // if err != nil { + // return nil, err + // } + // result = append(result, val.MessageList...) + //} + //result = append(result, msg) + //if after > 0 { + // req := sdk.GetAdvancedHistoryMessageListParams{ + // ConversationID: conversationID, + // Count: int(after), + // StartClientMsgID: msg.ClientMsgID, + // } + // val, err := c.getAdvancedHistoryMessageList(ctx, req, true) + // if err != nil { + // return nil, err + // } + // result = append(result, val.MessageList...) + //} + //sort.Sort(sdk_struct.NewMsgList(result)) return result, nil } diff --git a/internal/conversation_msg/message_check.go b/internal/conversation_msg/message_check.go index 5e914476f..4e8a2c80b 100644 --- a/internal/conversation_msg/message_check.go +++ b/internal/conversation_msg/message_check.go @@ -34,68 +34,41 @@ func (c *Conversation) validateAndFillInternalGaps(ctx context.Context, conversa } } + if isReverse { + return minSeq + } return maxSeq } // validateAndFillInterBlockGaps checks for continuity between blocks of messages. If a gap is identified, it retrieves the missing messages // to bridge the gap. The function returns a boolean indicating whether the blocks are continuous. -func (c *Conversation) validateAndFillInterBlockGaps(ctx context.Context, maxSeq int64, conversationID string, - isReverse bool, count int, startTime int64, list *[]*model_struct.LocalChatLog, messageListCallback *sdk.GetAdvancedHistoryMessageListCallback) bool { - lastMinSeq, _ := c.messagePullMinSeqMap.Load(conversationID) - if lastMinSeq != 0 { - log.ZDebug(ctx, "get lost LastMinSeq is :", "lastMinSeq", lastMinSeq, "thisMaxSeq", maxSeq) - if maxSeq+1 != lastMinSeq { - - lostSeqList := getLostSeqListWithLimitLength(maxSeq+1, lastMinSeq-1, []int64{}) - log.ZDebug(ctx, "get lost lostSeqList is :", "lostSeqList", lostSeqList, "length:", len(lostSeqList)) - if len(lostSeqList) > 0 { - log.ZDebug(ctx, "messageBlocksBetweenContinuityCheck", "lostSeqList", lostSeqList) - c.fetchAndMergeMissingMessages(ctx, conversationID, lostSeqList, isReverse, count, startTime, list, messageListCallback) - } - } else { - return true - } +func (c *Conversation) validateAndFillInterBlockGaps(ctx context.Context, thisStartSeq int64, conversationID string, + isReverse bool, count int, startTime int64, list *[]*model_struct.LocalChatLog, messageListCallback *sdk.GetAdvancedHistoryMessageListCallback) { + var lastEndSeq, startSeq, endSeq int64 + var isLostSeq bool + if isReverse { + lastEndSeq, _ = c.messagePullReverseEndSeqMap.Load(conversationID) + isLostSeq = lastEndSeq+1 != thisStartSeq + startSeq = lastEndSeq + 1 + endSeq = thisStartSeq - 1 } else { - return true + lastEndSeq, _ = c.messagePullForwardEndSeqMap.Load(conversationID) + isLostSeq = thisStartSeq+1 != lastEndSeq + startSeq = thisStartSeq + 1 + endSeq = lastEndSeq - 1 + } + if isLostSeq { + log.ZDebug(ctx, "get lost LastMinSeq is :", "lastEndSeq", lastEndSeq, "thisStartSeq", thisStartSeq, "startSeq", startSeq, "endSeq", endSeq) + lostSeqList := getLostSeqListWithLimitLength(startSeq, endSeq, []int64{}) + log.ZDebug(ctx, "get lost lostSeqList is :", "lostSeqList", lostSeqList, "length:", len(lostSeqList)) + if len(lostSeqList) > 0 { + log.ZDebug(ctx, "messageBlocksBetweenContinuityCheck", "lostSeqList", lostSeqList) + c.fetchAndMergeMissingMessages(ctx, conversationID, lostSeqList, isReverse, count, startTime, list, messageListCallback) + } } - - return false } -// func (c *Conversation) messageBlocksEndContinuityCheck(ctx context.Context, conversationID string, -// -// isReverse bool, count int, startTime int64, list *[]*model_struct.LocalChatLog, messageListCallback *sdk.GetAdvancedHistoryMessageListCallback) { -// maxSeq, minSeq, haveSeqList := c.getMaxAndMinHaveSeqList(*list) -// if minSeq != 0 { -// seqList := func(seq int64) (seqList []int64) { -// startSeq := seq - constant.PullMsgNumForReadDiffusion -// if startSeq <= 0 { -// startSeq = 1 -// } -// log.ZDebug(ctx, "pull start is ", "start seq", startSeq) -// for i := startSeq; i < seq; i++ { -// seqList = append(seqList, i) -// } -// return seqList -// }(minSeq) -// log.ZDebug(ctx, "pull seqList is ", "seqList", seqList, "len", len(seqList)) -// -// if len(seqList) > 0 { -// log.ZDebug(ctx, "messageBlocksEndContinuityCheck", "seqList", seqList) -// c.fetchAndMergeMissingMessages(ctx, conversationID, seqList, isReverse, count, startTime, list, messageListCallback) -// } -// -// } else { -// log.ZDebug(ctx, "messageBlocksEndContinuityCheck", "minSeq", minSeq, "conversationID", conversationID) -// // local don't have messages, but the server's maximum message count is not zero -// seqList := []int64{0, 0} -// c.fetchAndMergeMissingMessages(ctx, conversationID, seqList, isReverse, count, startTime, list, messageListCallback) -// -// } -// -// } - // validateAndFillEndBlockContinuity performs an end-of-block continuity check. If a batch of messages has passed // internal and inter-block continuity checks but contains fewer messages than `count`, this function verifies if the end // of the message history has been reached. If not, it attempts to retrieve any missing messages to ensure continuity. @@ -103,27 +76,53 @@ func (c *Conversation) validateAndFillEndBlockContinuity(ctx context.Context, co isReverse bool, count int, startTime int64, list *[]*model_struct.LocalChatLog, messageListCallback *sdk.GetAdvancedHistoryMessageListCallback) { // Perform an end-of-block check if the retrieved message count is less than requested if len(*list) < count { - _, minSeq, _ := c.getMaxAndMinHaveSeqList(*list) - log.ZDebug(ctx, "messageBlocksEndContinuityCheck", "minSeq", minSeq, "conversationID", conversationID) - if minSeq == 1 { // todo Replace `1` with the minimum sequence value as defined by the user or system - messageListCallback.IsEnd = true - } else { - lastMinSeq, _ := c.messagePullMinSeqMap.Load(conversationID) - log.ZDebug(ctx, "messageBlocksEndContinuityCheck", "lastMinSeq", lastMinSeq, "conversationID", conversationID) - // If `minSeq` is zero and `lastMinSeq` is at the minimum server sequence, this batch is fully local - if minSeq == 0 && lastMinSeq == 1 { // All messages in this batch are local messages, - // and the minimum seq of the last batch of valid messages has already reached the minimum pullable seq from the server. + if isReverse { + maxSeq, _, _ := c.getMaxAndMinHaveSeqList(*list) + log.ZDebug(ctx, "validateAndFillEndBlockContinuity", "maxSeq", maxSeq, "conversationID", conversationID) + if maxSeq == c.maxSeqRecorder.Get(conversationID) { // todo Replace `1` with the minimum sequence value as defined by the user or system messageListCallback.IsEnd = true } else { - // The batch includes sequences but has not reached the minimum value, - // This condition indicates local-only messages, with `minSeq > 1` as the only case, - // since `lastMinSeq > 1` is handled in inter-block continuity. - lostSeqList := getLostSeqListWithLimitLength(1, minSeq-1, []int64{}) - if len(lostSeqList) > 0 { - log.ZDebug(ctx, "messageBlocksEndContinuityCheck", "lostSeqList", lostSeqList) - c.fetchAndMergeMissingMessages(ctx, conversationID, lostSeqList, isReverse, count, startTime, list, messageListCallback) + lastEndSeq, _ := c.messagePullReverseEndSeqMap.Load(conversationID) + log.ZDebug(ctx, "validateAndFillEndBlockContinuity", "lastEndSeq", lastEndSeq, "conversationID", conversationID) + // If `maxSeq` is zero and `lastEndSeq` is at the maximum server sequence, this batch is fully local + if maxSeq == 0 && lastEndSeq == c.maxSeqRecorder.Get(conversationID) { // All messages in this batch are local messages, + // and the maximum seq of the last batch of valid messages has already reached the maximum pullable seq from the server. + messageListCallback.IsEnd = true + } else { + // The batch includes sequences but has not reached the maximum value, + // This condition indicates local-only messages, with `maxSeq < maxSeqRecorderMaxSeq` as the only case, + // since `lastEndSeq < maxSeqRecorderMaxSeq` is handled in inter-block continuity. + lostSeqList := getLostSeqListWithLimitLength(maxSeq+1, c.maxSeqRecorder.Get(conversationID), []int64{}) + if len(lostSeqList) > 0 { + log.ZDebug(ctx, "validateAndFillEndBlockContinuity", "lostSeqList", lostSeqList) + c.fetchAndMergeMissingMessages(ctx, conversationID, lostSeqList, isReverse, count, startTime, list, messageListCallback) + } + } + } + } else { + _, minSeq, _ := c.getMaxAndMinHaveSeqList(*list) + log.ZDebug(ctx, "validateAndFillEndBlockContinuity", "minSeq", minSeq, "conversationID", conversationID) + if minSeq == 1 { // todo Replace `1` with the minimum sequence value as defined by the user or system + messageListCallback.IsEnd = true + } else { + lastMinSeq, _ := c.messagePullForwardEndSeqMap.Load(conversationID) + log.ZDebug(ctx, "validateAndFillEndBlockContinuity", "lastMinSeq", lastMinSeq, "conversationID", conversationID) + // If `minSeq` is zero and `lastMinSeq` is at the minimum server sequence, this batch is fully local + if minSeq == 0 && lastMinSeq == 1 { // All messages in this batch are local messages, + // and the minimum seq of the last batch of valid messages has already reached the minimum pullable seq from the server. + messageListCallback.IsEnd = true + } else { + // The batch includes sequences but has not reached the minimum value, + // This condition indicates local-only messages, with `minSeq > 1` as the only case, + // since `lastMinSeq > 1` is handled in inter-block continuity. + lostSeqList := getLostSeqListWithLimitLength(1, minSeq-1, []int64{}) + if len(lostSeqList) > 0 { + log.ZDebug(ctx, "validateAndFillEndBlockContinuity", "lostSeqList", lostSeqList) + c.fetchAndMergeMissingMessages(ctx, conversationID, lostSeqList, isReverse, count, startTime, list, messageListCallback) + } + } } } } else { @@ -203,8 +202,10 @@ func (c *Conversation) fetchAndMergeMissingMessages(ctx context.Context, convers log.ZDebug(ctx, "syncMsgFromServerSplit pull msg success", "conversationID", conversationID, "count", count, "len", len(*list), "msgLen", len(v.Msgs)) localMessage := datautil.Batch(MsgDataToLocalChatLog, v.Msgs) - reverse(localMessage) - *list = mergeSortedArrays(*list, localMessage, count) + if !isReverse { + reverse(localMessage) + } + *list = mergeSortedArrays(*list, localMessage, count, !isReverse) } } @@ -227,7 +228,7 @@ func errHandle(seqList []int64, list *[]*model_struct.LocalChatLog, err error, m *list = result } -func mergeSortedArrays(arr1, arr2 []*model_struct.LocalChatLog, n int) []*model_struct.LocalChatLog { +func mergeSortedArrays(arr1, arr2 []*model_struct.LocalChatLog, n int, isDescending bool) []*model_struct.LocalChatLog { len1 := len(arr1) len2 := len(arr2) result := make([]*model_struct.LocalChatLog, 0, len1+len2) @@ -235,7 +236,7 @@ func mergeSortedArrays(arr1, arr2 []*model_struct.LocalChatLog, n int) []*model_ i, j := 0, 0 for i < len1 && j < len2 && len(result) < n { - if arr1[i].SendTime >= arr2[j].SendTime { + if (isDescending && arr1[i].SendTime >= arr2[j].SendTime) || (!isDescending && arr1[i].SendTime <= arr2[j].SendTime) { result = append(result, arr1[i]) i++ } else { diff --git a/internal/conversation_msg/message_check_test.go b/internal/conversation_msg/message_check_test.go index baadc9ef5..8ba123d88 100644 --- a/internal/conversation_msg/message_check_test.go +++ b/internal/conversation_msg/message_check_test.go @@ -17,9 +17,10 @@ func TestMergeSortedArrays(t *testing.T) { reverse(array) tests := []struct { - arr1, arr2 []*model_struct.LocalChatLog - n int - expected []*model_struct.LocalChatLog + arr1, arr2 []*model_struct.LocalChatLog + n int + isDescending bool + expected []*model_struct.LocalChatLog }{ { // Test merging two descending arrays @@ -33,7 +34,8 @@ func TestMergeSortedArrays(t *testing.T) { {SendTime: 6, Content: "Message 6"}, {SendTime: 2, Content: "Message 2"}, }, - n: 4, // Limit result to first 4 elements + n: 4, // Limit result to first 4 elements + isDescending: true, expected: []*model_struct.LocalChatLog{ {SendTime: 9, Content: "Message 9"}, {SendTime: 8, Content: "Message 8"}, @@ -49,7 +51,8 @@ func TestMergeSortedArrays(t *testing.T) { {SendTime: 3, Content: "Message 3"}, {SendTime: 1, Content: "Message 1"}, }, - n: 3, + n: 3, + isDescending: true, expected: []*model_struct.LocalChatLog{ {SendTime: 5, Content: "Message 5"}, {SendTime: 3, Content: "Message 3"}, @@ -58,10 +61,11 @@ func TestMergeSortedArrays(t *testing.T) { }, { // Test merging two empty arrays - arr1: []*model_struct.LocalChatLog{}, - arr2: []*model_struct.LocalChatLog{}, - n: 0, - expected: []*model_struct.LocalChatLog{}, + arr1: []*model_struct.LocalChatLog{}, + arr2: []*model_struct.LocalChatLog{}, + n: 0, + isDescending: true, + expected: []*model_struct.LocalChatLog{}, }, { // Test merging a descending array and an ascending array @@ -70,8 +74,9 @@ func TestMergeSortedArrays(t *testing.T) { {SendTime: 5, Content: "Message 5"}, {SendTime: 3, Content: "Message 3"}, }, - arr2: array, - n: 5, // Limit result to first 5 elements + arr2: array, + n: 5, // Limit result to first 5 elements + isDescending: true, // Expected result: merged in descending order expected: []*model_struct.LocalChatLog{ {SendTime: 7, Content: "Message 7"}, @@ -81,10 +86,34 @@ func TestMergeSortedArrays(t *testing.T) { {SendTime: 3, Content: "Message 3"}, }, }, + + { + // Test merging a descending array and an ascending array + arr1: []*model_struct.LocalChatLog{ + {SendTime: 1, Content: "Message 1"}, + {SendTime: 5, Content: "Message 5"}, + {SendTime: 7, Content: "Message 7"}, + }, + arr2: []*model_struct.LocalChatLog{ + {SendTime: 2, Content: "Message 2"}, + {SendTime: 6, Content: "Message 6"}, + {SendTime: 9, Content: "Message 9"}, + }, + n: 5, // Limit result to first 5 elements + isDescending: false, + // Expected result: merged in descending order + expected: []*model_struct.LocalChatLog{ + {SendTime: 1, Content: "Message 1"}, + {SendTime: 2, Content: "Message 2"}, + {SendTime: 5, Content: "Message 5"}, + {SendTime: 6, Content: "Message 6"}, + {SendTime: 7, Content: "Message 7"}, + }, + }, } for _, tt := range tests { - result := mergeSortedArrays(tt.arr1, tt.arr2, tt.n) + result := mergeSortedArrays(tt.arr1, tt.arr2, tt.n, tt.isDescending) if !reflect.DeepEqual(result, tt.expected) { t.Errorf( "mergeSortedArrays(%v, %v, %d) = %v; want %v",