602 lines
15 KiB
Go
602 lines
15 KiB
Go
package engine
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"errors"
|
|
"log/slog"
|
|
"sync/atomic"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/google/uuid"
|
|
"github.com/stretchr/testify/assert"
|
|
"m8sh.su/d/jules/chat"
|
|
"m8sh.su/d/jules/database"
|
|
"m8sh.su/d/jules/llm"
|
|
"m8sh.su/d/jules/search"
|
|
)
|
|
|
|
type mockChat struct {
|
|
sendErr error
|
|
receiveCh chan chat.Message
|
|
receiveErr error
|
|
}
|
|
|
|
func (m *mockChat) Send(_ context.Context, id, text string) error { return m.sendErr }
|
|
func (m *mockChat) Receive(ctx context.Context) <-chan chat.Message {
|
|
if m.receiveErr != nil {
|
|
ch := make(chan chat.Message)
|
|
close(ch)
|
|
return ch
|
|
}
|
|
return m.receiveCh
|
|
}
|
|
|
|
type mockDB struct {
|
|
database.Database
|
|
notifications *mockNotifications
|
|
closeErr error
|
|
}
|
|
|
|
func (m *mockDB) Notifications() database.Notifications { return m.notifications }
|
|
func (m *mockDB) Close() error { return m.closeErr }
|
|
|
|
type mockNotifications struct {
|
|
database.Notifications
|
|
popFunc func(ctx context.Context, limit int) ([]database.Notification, error)
|
|
}
|
|
|
|
func (m *mockNotifications) Pop(ctx context.Context, limit int) ([]database.Notification, error) {
|
|
if m.popFunc != nil {
|
|
return m.popFunc(ctx, limit)
|
|
}
|
|
return nil, nil
|
|
}
|
|
|
|
type mockLLM struct{ llm.LLM }
|
|
|
|
type mockSearcher struct{ search.Searcher }
|
|
|
|
func TestNew(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
params *Parameters
|
|
wantErr bool
|
|
errMsg string
|
|
}{
|
|
{
|
|
name: "nil params",
|
|
params: nil,
|
|
wantErr: true,
|
|
errMsg: "params can't be nil",
|
|
},
|
|
{
|
|
name: "database is nil",
|
|
params: &Parameters{
|
|
LLM: &mockLLM{},
|
|
Searcher: &mockSearcher{},
|
|
Chats: map[string]chat.Chat{"telegram": &mockChat{}},
|
|
},
|
|
wantErr: true,
|
|
errMsg: "database can't be nil",
|
|
},
|
|
{
|
|
name: "llm is nil",
|
|
params: &Parameters{
|
|
Database: &mockDB{},
|
|
Searcher: &mockSearcher{},
|
|
Chats: map[string]chat.Chat{"telegram": &mockChat{}},
|
|
},
|
|
wantErr: true,
|
|
errMsg: "llm can't be nil",
|
|
},
|
|
{
|
|
name: "searcher is nil",
|
|
params: &Parameters{
|
|
Database: &mockDB{},
|
|
LLM: &mockLLM{},
|
|
Chats: map[string]chat.Chat{"telegram": &mockChat{}},
|
|
},
|
|
wantErr: true,
|
|
errMsg: "seach engine can't be nil",
|
|
},
|
|
{
|
|
name: "chats empty",
|
|
params: &Parameters{
|
|
Database: &mockDB{},
|
|
LLM: &mockLLM{},
|
|
Searcher: &mockSearcher{},
|
|
Chats: map[string]chat.Chat{},
|
|
},
|
|
wantErr: true,
|
|
errMsg: "chats can't be empty",
|
|
},
|
|
{
|
|
name: "chat is nil",
|
|
params: &Parameters{
|
|
Database: &mockDB{},
|
|
LLM: &mockLLM{},
|
|
Searcher: &mockSearcher{},
|
|
Chats: map[string]chat.Chat{"telegram": nil},
|
|
},
|
|
wantErr: true,
|
|
errMsg: "telegram initialized as nil",
|
|
},
|
|
{
|
|
name: "platform chat name empty",
|
|
params: &Parameters{
|
|
Database: &mockDB{},
|
|
LLM: &mockLLM{},
|
|
Searcher: &mockSearcher{},
|
|
Chats: map[string]chat.Chat{"": &mockChat{}},
|
|
},
|
|
wantErr: true,
|
|
errMsg: "platform name can't be empty",
|
|
},
|
|
{
|
|
name: "valid params with defaults",
|
|
params: &Parameters{
|
|
Database: &mockDB{},
|
|
LLM: &mockLLM{},
|
|
Searcher: &mockSearcher{},
|
|
Chats: map[string]chat.Chat{"telegram": &mockChat{}},
|
|
},
|
|
wantErr: false,
|
|
},
|
|
{
|
|
name: "NumWorkers zero gets default",
|
|
params: &Parameters{
|
|
NumWorkers: 0,
|
|
Database: &mockDB{},
|
|
LLM: &mockLLM{},
|
|
Searcher: &mockSearcher{},
|
|
Chats: map[string]chat.Chat{"telegram": &mockChat{}},
|
|
},
|
|
wantErr: false,
|
|
},
|
|
{
|
|
name: "ProcessingTimeout zero gets default",
|
|
params: &Parameters{
|
|
ProcessingTimeout: 0,
|
|
Database: &mockDB{},
|
|
LLM: &mockLLM{},
|
|
Searcher: &mockSearcher{},
|
|
Chats: map[string]chat.Chat{"telegram": &mockChat{}},
|
|
},
|
|
wantErr: false,
|
|
},
|
|
{
|
|
name: "LLMRetryAttempts zero gets default",
|
|
params: &Parameters{
|
|
LLMRetryAttempts: 0,
|
|
Database: &mockDB{},
|
|
LLM: &mockLLM{},
|
|
Searcher: &mockSearcher{},
|
|
Chats: map[string]chat.Chat{"telegram": &mockChat{}},
|
|
},
|
|
wantErr: false,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
e, err := New(tt.params)
|
|
if tt.wantErr {
|
|
assert.Error(t, err)
|
|
assert.Contains(t, err.Error(), tt.errMsg)
|
|
return
|
|
}
|
|
assert.NoError(t, err)
|
|
assert.NotNil(t, e)
|
|
|
|
if tt.params.NumWorkers == 0 {
|
|
assert.Equal(t, 100, e.NumWorkers)
|
|
}
|
|
if tt.params.ProcessingTimeout == 0 {
|
|
assert.Equal(t, 90*time.Second, e.ProcessingTimeout)
|
|
}
|
|
if tt.params.LLMRetryAttempts == 0 {
|
|
assert.Equal(t, 3, e.LLMRetryAttempts)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestRun(t *testing.T) {
|
|
t.Run("full cycle with graceful shutdown", func(t *testing.T) {
|
|
ch := make(chan chat.Message)
|
|
close(ch)
|
|
|
|
db := &mockDB{
|
|
notifications: &mockNotifications{
|
|
popFunc: func(ctx context.Context, limit int) ([]database.Notification, error) {
|
|
return nil, nil
|
|
},
|
|
},
|
|
}
|
|
|
|
e := &Engine{
|
|
Parameters: &Parameters{
|
|
NumWorkers: 2,
|
|
Database: db,
|
|
LLM: &mockLLM{},
|
|
Searcher: &mockSearcher{},
|
|
Chats: map[string]chat.Chat{"telegram": &mockChat{receiveCh: ch}},
|
|
DatabasePollingDuration: 10 * time.Millisecond,
|
|
NotificationBatchSize: 10,
|
|
ProcessingTimeout: time.Second,
|
|
LLMRetryAttempts: 1,
|
|
},
|
|
}
|
|
e.processMessage = func(ctx context.Context, msg chat.Message) {}
|
|
e.processNotification = func(ctx context.Context, n database.Notification) {}
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
done := make(chan error)
|
|
go func() {
|
|
done <- e.Run(ctx)
|
|
}()
|
|
|
|
time.Sleep(50 * time.Millisecond)
|
|
cancel()
|
|
|
|
err := <-done
|
|
assert.NoError(t, err)
|
|
})
|
|
|
|
t.Run("database close error", func(t *testing.T) {
|
|
ch := make(chan chat.Message)
|
|
close(ch)
|
|
|
|
db := &mockDB{
|
|
notifications: &mockNotifications{},
|
|
closeErr: errors.New("close failed"),
|
|
}
|
|
|
|
e := &Engine{
|
|
Parameters: &Parameters{
|
|
NumWorkers: 1,
|
|
Database: db,
|
|
LLM: &mockLLM{},
|
|
Searcher: &mockSearcher{},
|
|
Chats: map[string]chat.Chat{"telegram": &mockChat{receiveCh: ch}},
|
|
DatabasePollingDuration: 10 * time.Millisecond,
|
|
NotificationBatchSize: 10,
|
|
ProcessingTimeout: time.Second,
|
|
LLMRetryAttempts: 1,
|
|
},
|
|
}
|
|
e.processMessage = func(ctx context.Context, msg chat.Message) {}
|
|
e.processNotification = func(ctx context.Context, n database.Notification) {}
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
done := make(chan error)
|
|
go func() {
|
|
done <- e.Run(ctx)
|
|
}()
|
|
|
|
time.Sleep(30 * time.Millisecond)
|
|
cancel()
|
|
|
|
err := <-done
|
|
assert.Error(t, err)
|
|
assert.Contains(t, err.Error(), "failed to close the database")
|
|
})
|
|
}
|
|
|
|
func TestConsumeChatMessages(t *testing.T) {
|
|
t.Run("multiplexes messages", func(t *testing.T) {
|
|
ch1 := make(chan chat.Message, 1)
|
|
ch2 := make(chan chat.Message, 1)
|
|
ch1 <- chat.Message{ID: "1", Text: "first"}
|
|
ch2 <- chat.Message{ID: "2", Text: "second"}
|
|
close(ch1)
|
|
close(ch2)
|
|
|
|
e := &Engine{
|
|
Parameters: &Parameters{
|
|
Chats: map[string]chat.Chat{
|
|
"telegram": &mockChat{receiveCh: ch1},
|
|
"whatsapp": &mockChat{receiveCh: ch2},
|
|
},
|
|
},
|
|
}
|
|
|
|
ctx := context.Background()
|
|
out := e.consumeChatMessages(ctx)
|
|
|
|
var msgs []chat.Message
|
|
for msg := range out {
|
|
msgs = append(msgs, msg)
|
|
}
|
|
assert.Len(t, msgs, 2)
|
|
})
|
|
|
|
t.Run("context cancel stops receiver", func(t *testing.T) {
|
|
ch := make(chan chat.Message)
|
|
defer close(ch)
|
|
|
|
e := &Engine{
|
|
Parameters: &Parameters{
|
|
Chats: map[string]chat.Chat{"telegram": &mockChat{receiveCh: ch}},
|
|
},
|
|
}
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
out := e.consumeChatMessages(ctx)
|
|
cancel()
|
|
|
|
select {
|
|
case _, ok := <-out:
|
|
assert.False(t, ok)
|
|
case <-time.After(time.Second):
|
|
t.Fatal("channel not closed")
|
|
}
|
|
})
|
|
|
|
t.Run("context done during message forward", func(t *testing.T) {
|
|
ch := make(chan chat.Message, 1)
|
|
ch <- chat.Message{ID: "1", Text: "hello"}
|
|
|
|
e := &Engine{
|
|
Parameters: &Parameters{
|
|
Chats: map[string]chat.Chat{"telegram": &mockChat{receiveCh: ch}},
|
|
},
|
|
}
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
out := e.consumeChatMessages(ctx)
|
|
|
|
time.Sleep(50 * time.Millisecond)
|
|
cancel()
|
|
|
|
select {
|
|
case _, ok := <-out:
|
|
assert.False(t, ok)
|
|
case <-time.After(time.Second):
|
|
t.Fatal("channel not closed")
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestConsumeNotifications(t *testing.T) {
|
|
t.Run("context done during notification emit", func(t *testing.T) {
|
|
mockNotifs := &mockNotifications{
|
|
popFunc: func(ctx context.Context, limit int) ([]database.Notification, error) {
|
|
return []database.Notification{
|
|
{ID: uuid.New(), Content: "test"},
|
|
{ID: uuid.New(), Content: "test2"},
|
|
}, nil
|
|
},
|
|
}
|
|
|
|
e := &Engine{
|
|
Parameters: &Parameters{
|
|
Database: &mockDB{notifications: mockNotifs},
|
|
NotificationBatchSize: 10,
|
|
DatabasePollingDuration: 10 * time.Millisecond,
|
|
},
|
|
}
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
out := e.consumeNotifications(ctx)
|
|
|
|
<-out
|
|
cancel()
|
|
|
|
select {
|
|
case _, ok := <-out:
|
|
assert.False(t, ok)
|
|
case <-time.After(time.Second):
|
|
t.Fatal("channel not closed")
|
|
}
|
|
})
|
|
|
|
t.Run("empty pop breaks inner loop", func(t *testing.T) {
|
|
popCount := 0
|
|
mockNotifs := &mockNotifications{
|
|
popFunc: func(ctx context.Context, limit int) ([]database.Notification, error) {
|
|
popCount++
|
|
if popCount == 1 {
|
|
return []database.Notification{{ID: uuid.New(), Content: "first"}}, nil
|
|
}
|
|
return nil, nil
|
|
},
|
|
}
|
|
|
|
e := &Engine{
|
|
Parameters: &Parameters{
|
|
Database: &mockDB{notifications: mockNotifs},
|
|
NotificationBatchSize: 10,
|
|
DatabasePollingDuration: 10 * time.Millisecond,
|
|
},
|
|
}
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
defer cancel()
|
|
|
|
out := e.consumeNotifications(ctx)
|
|
|
|
msg := <-out
|
|
assert.Equal(t, "first", msg.Content)
|
|
|
|
select {
|
|
case <-out:
|
|
t.Fatal("unexpected second message")
|
|
case <-time.After(50 * time.Millisecond):
|
|
}
|
|
})
|
|
|
|
t.Run("pop error logged", func(t *testing.T) {
|
|
mockNotifs := &mockNotifications{
|
|
popFunc: func(ctx context.Context, limit int) ([]database.Notification, error) {
|
|
return nil, errors.New("db connection lost")
|
|
},
|
|
}
|
|
|
|
var buf bytes.Buffer
|
|
logger := slog.New(slog.NewTextHandler(&buf, &slog.HandlerOptions{Level: slog.LevelError}))
|
|
slog.SetDefault(logger)
|
|
|
|
e := &Engine{
|
|
Parameters: &Parameters{
|
|
Database: &mockDB{notifications: mockNotifs},
|
|
NotificationBatchSize: 10,
|
|
DatabasePollingDuration: 10 * time.Millisecond,
|
|
},
|
|
}
|
|
|
|
ctx, cancel := context.WithCancel(t.Context())
|
|
defer cancel()
|
|
|
|
out := e.consumeNotifications(ctx)
|
|
|
|
select {
|
|
case _, ok := <-out:
|
|
assert.False(t, ok)
|
|
case <-time.After(100 * time.Millisecond):
|
|
}
|
|
|
|
assert.Contains(t, buf.String(), "failed to pop notifications")
|
|
assert.Contains(t, buf.String(), "db connection lost")
|
|
})
|
|
}
|
|
|
|
func TestRunWorkers(t *testing.T) {
|
|
t.Run("processMessage called for message", func(t *testing.T) {
|
|
messages := make(chan chat.Message, 1)
|
|
notifications := make(chan database.Notification)
|
|
|
|
var called atomic.Bool
|
|
ready := make(chan struct{})
|
|
|
|
e := &Engine{Parameters: &Parameters{NumWorkers: 1}}
|
|
e.processMessage = func(ctx context.Context, msg chat.Message) {
|
|
called.Store(true)
|
|
}
|
|
e.processNotification = func(ctx context.Context, n database.Notification) {}
|
|
|
|
go func() {
|
|
ready <- struct{}{}
|
|
e.runWorkers(messages, notifications)
|
|
}()
|
|
|
|
<-ready
|
|
time.Sleep(5 * time.Millisecond)
|
|
|
|
messages <- chat.Message{ID: "123", Text: "hello"}
|
|
close(messages)
|
|
close(notifications)
|
|
|
|
time.Sleep(20 * time.Millisecond)
|
|
assert.True(t, called.Load())
|
|
})
|
|
|
|
t.Run("processNotification called for notification", func(t *testing.T) {
|
|
messages := make(chan chat.Message)
|
|
notifications := make(chan database.Notification, 1)
|
|
|
|
var called atomic.Bool
|
|
ready := make(chan struct{})
|
|
|
|
e := &Engine{Parameters: &Parameters{NumWorkers: 1}}
|
|
e.processMessage = func(ctx context.Context, msg chat.Message) {}
|
|
e.processNotification = func(ctx context.Context, n database.Notification) {
|
|
called.Store(true)
|
|
}
|
|
|
|
go func() {
|
|
ready <- struct{}{}
|
|
e.runWorkers(messages, notifications)
|
|
}()
|
|
|
|
<-ready
|
|
time.Sleep(5 * time.Millisecond)
|
|
|
|
notifications <- database.Notification{ID: uuid.New(), Content: "test"}
|
|
close(messages)
|
|
close(notifications)
|
|
|
|
time.Sleep(20 * time.Millisecond)
|
|
assert.True(t, called.Load())
|
|
})
|
|
|
|
t.Run("messages channel closed", func(t *testing.T) {
|
|
messages := make(chan chat.Message)
|
|
notifications := make(chan database.Notification)
|
|
close(messages)
|
|
|
|
var msgProcessed atomic.Bool
|
|
e := &Engine{Parameters: &Parameters{NumWorkers: 1}}
|
|
e.processMessage = func(ctx context.Context, msg chat.Message) {
|
|
msgProcessed.Store(true)
|
|
}
|
|
e.processNotification = func(ctx context.Context, n database.Notification) {}
|
|
|
|
go func() {
|
|
time.Sleep(50 * time.Millisecond)
|
|
close(notifications)
|
|
}()
|
|
|
|
e.runWorkers(messages, notifications)
|
|
assert.False(t, msgProcessed.Load())
|
|
})
|
|
|
|
t.Run("notifications channel closed", func(t *testing.T) {
|
|
messages := make(chan chat.Message)
|
|
notifications := make(chan database.Notification)
|
|
close(notifications)
|
|
|
|
var notifProcessed atomic.Bool
|
|
e := &Engine{Parameters: &Parameters{NumWorkers: 1}}
|
|
e.processMessage = func(ctx context.Context, msg chat.Message) {}
|
|
e.processNotification = func(ctx context.Context, n database.Notification) {
|
|
notifProcessed.Store(true)
|
|
}
|
|
|
|
go func() {
|
|
time.Sleep(50 * time.Millisecond)
|
|
close(messages)
|
|
}()
|
|
|
|
e.runWorkers(messages, notifications)
|
|
assert.False(t, notifProcessed.Load())
|
|
})
|
|
|
|
t.Run("both channels closed from start", func(t *testing.T) {
|
|
messages := make(chan chat.Message)
|
|
notifications := make(chan database.Notification)
|
|
close(messages)
|
|
close(notifications)
|
|
|
|
e := &Engine{Parameters: &Parameters{NumWorkers: 1}}
|
|
e.processMessage = func(ctx context.Context, msg chat.Message) {}
|
|
e.processNotification = func(ctx context.Context, n database.Notification) {}
|
|
|
|
e.runWorkers(messages, notifications)
|
|
})
|
|
}
|
|
|
|
func TestDefaultHandlers(t *testing.T) {
|
|
e := &Engine{}
|
|
|
|
calledMsg := false
|
|
calledNotif := false
|
|
|
|
e.processMessage = func(ctx context.Context, msg chat.Message) {
|
|
calledMsg = true
|
|
}
|
|
e.processNotification = func(ctx context.Context, n database.Notification) {
|
|
calledNotif = true
|
|
}
|
|
|
|
msg := chat.Message{ID: "123", Text: "hello"}
|
|
e.processMessage(context.Background(), msg)
|
|
assert.True(t, calledMsg)
|
|
|
|
notif := database.Notification{ID: uuid.New(), Content: "call mom"}
|
|
e.processNotification(context.Background(), notif)
|
|
assert.True(t, calledNotif)
|
|
}
|