752 lines
24 KiB
Go
752 lines
24 KiB
Go
package engine
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"errors"
|
|
"log/slog"
|
|
"strings"
|
|
"testing"
|
|
|
|
"github.com/google/uuid"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
"m8sh.su/d/jules/chat"
|
|
"m8sh.su/d/jules/database"
|
|
"m8sh.su/d/jules/engine/actions"
|
|
"m8sh.su/d/jules/engine/trace"
|
|
"m8sh.su/d/jules/engine/user"
|
|
"m8sh.su/d/jules/llm"
|
|
)
|
|
|
|
type TestDB struct {
|
|
UsersDB *TestUsers
|
|
ChatsDB *TestChats
|
|
FactsDB *TestFacts
|
|
ContactsDB *TestContacts
|
|
NotificationsDB *TestNotifications
|
|
ActionsDB *TestActions
|
|
}
|
|
|
|
func (db *TestDB) Users() database.Users { return db.UsersDB }
|
|
func (db *TestDB) Chats() database.Chats { return db.ChatsDB }
|
|
func (db *TestDB) Facts() database.Facts { return db.FactsDB }
|
|
func (db *TestDB) Contacts() database.Contacts { return db.ContactsDB }
|
|
func (db *TestDB) Notifications() database.Notifications { return db.NotificationsDB }
|
|
func (db *TestDB) Actions() database.Actions { return db.ActionsDB }
|
|
func (db *TestDB) Close() error { return nil }
|
|
|
|
type TestUsers struct {
|
|
GetFunc func(ctx context.Context, id uuid.UUID, lookup database.UserLookup) (*database.User, error)
|
|
CreateFunc func(ctx context.Context, u *database.User) error
|
|
UpdateFunc func(ctx context.Context, u *database.User) error
|
|
DeleteFunc func(ctx context.Context, id uuid.UUID) error
|
|
}
|
|
|
|
func (u *TestUsers) Get(ctx context.Context, id uuid.UUID, lookup database.UserLookup) (*database.User, error) {
|
|
if u.GetFunc != nil {
|
|
return u.GetFunc(ctx, id, lookup)
|
|
}
|
|
return nil, nil
|
|
}
|
|
|
|
func (u *TestUsers) Create(ctx context.Context, user *database.User) error {
|
|
if u.CreateFunc != nil {
|
|
return u.CreateFunc(ctx, user)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (u *TestUsers) Update(ctx context.Context, user *database.User) error {
|
|
if u.UpdateFunc != nil {
|
|
return u.UpdateFunc(ctx, user)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (u *TestUsers) Delete(ctx context.Context, id uuid.UUID) error {
|
|
if u.DeleteFunc != nil {
|
|
return u.DeleteFunc(ctx, id)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
type TestChats struct {
|
|
AttachFunc func(ctx context.Context, userID uuid.UUID, platform, identifier string) error
|
|
DetachFunc func(ctx context.Context, userID uuid.UUID, platform string) error
|
|
GetUserIDFunc func(ctx context.Context, platform, identifier string) (uuid.UUID, error)
|
|
ListFunc func(ctx context.Context, userID uuid.UUID) ([]database.Chat, error)
|
|
}
|
|
|
|
func (c *TestChats) Attach(ctx context.Context, userID uuid.UUID, platform, identifier string) error {
|
|
if c.AttachFunc != nil {
|
|
return c.AttachFunc(ctx, userID, platform, identifier)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (c *TestChats) Detach(ctx context.Context, userID uuid.UUID, platform string) error {
|
|
if c.DetachFunc != nil {
|
|
return c.DetachFunc(ctx, userID, platform)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (c *TestChats) GetUserID(ctx context.Context, platform, identifier string) (uuid.UUID, error) {
|
|
if c.GetUserIDFunc != nil {
|
|
return c.GetUserIDFunc(ctx, platform, identifier)
|
|
}
|
|
return uuid.Nil, nil
|
|
}
|
|
|
|
func (c *TestChats) List(ctx context.Context, userID uuid.UUID) ([]database.Chat, error) {
|
|
if c.ListFunc != nil {
|
|
return c.ListFunc(ctx, userID)
|
|
}
|
|
return nil, nil
|
|
}
|
|
|
|
type TestFacts struct {
|
|
AddFunc func(ctx context.Context, userID uuid.UUID, value string) error
|
|
ListFunc func(ctx context.Context, userID uuid.UUID) ([]database.Fact, error)
|
|
DeleteFunc func(ctx context.Context, userID uuid.UUID, value string) error
|
|
}
|
|
|
|
func (f *TestFacts) Add(ctx context.Context, userID uuid.UUID, value string) error {
|
|
if f.AddFunc != nil {
|
|
return f.AddFunc(ctx, userID, value)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (f *TestFacts) List(ctx context.Context, userID uuid.UUID) ([]database.Fact, error) {
|
|
if f.ListFunc != nil {
|
|
return f.ListFunc(ctx, userID)
|
|
}
|
|
return nil, nil
|
|
}
|
|
|
|
func (f *TestFacts) Delete(ctx context.Context, userID uuid.UUID, value string) error {
|
|
if f.DeleteFunc != nil {
|
|
return f.DeleteFunc(ctx, userID, value)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
type TestContacts struct {
|
|
AddFunc func(ctx context.Context, contact *database.Contact) error
|
|
ListFunc func(ctx context.Context, ownerID uuid.UUID) ([]database.Contact, error)
|
|
DeleteFunc func(ctx context.Context, ownerID, targetID uuid.UUID) error
|
|
}
|
|
|
|
func (c *TestContacts) Add(ctx context.Context, contact *database.Contact) error {
|
|
if c.AddFunc != nil {
|
|
return c.AddFunc(ctx, contact)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (c *TestContacts) List(ctx context.Context, ownerID uuid.UUID) ([]database.Contact, error) {
|
|
if c.ListFunc != nil {
|
|
return c.ListFunc(ctx, ownerID)
|
|
}
|
|
return nil, nil
|
|
}
|
|
|
|
func (c *TestContacts) Delete(ctx context.Context, ownerID, targetID uuid.UUID) error {
|
|
if c.DeleteFunc != nil {
|
|
return c.DeleteFunc(ctx, ownerID, targetID)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
type TestNotifications struct {
|
|
PushFunc func(ctx context.Context, n *database.Notification) error
|
|
PopFunc func(ctx context.Context, limit int) ([]database.Notification, error)
|
|
ListFunc func(ctx context.Context, userID uuid.UUID) ([]database.Notification, error)
|
|
ListOutgoingFunc func(ctx context.Context, initiatorID uuid.UUID) ([]database.Notification, error)
|
|
DeleteFunc func(ctx context.Context, id uuid.UUID) error
|
|
}
|
|
|
|
func (n *TestNotifications) Push(ctx context.Context, notif *database.Notification) error {
|
|
if n.PushFunc != nil {
|
|
return n.PushFunc(ctx, notif)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (n *TestNotifications) Pop(ctx context.Context, limit int) ([]database.Notification, error) {
|
|
if n.PopFunc != nil {
|
|
return n.PopFunc(ctx, limit)
|
|
}
|
|
return nil, nil
|
|
}
|
|
|
|
func (n *TestNotifications) List(ctx context.Context, userID uuid.UUID) ([]database.Notification, error) {
|
|
if n.ListFunc != nil {
|
|
return n.ListFunc(ctx, userID)
|
|
}
|
|
return nil, nil
|
|
}
|
|
|
|
func (n *TestNotifications) ListOutgoing(ctx context.Context, initiatorID uuid.UUID) ([]database.Notification, error) {
|
|
if n.ListOutgoingFunc != nil {
|
|
return n.ListOutgoingFunc(ctx, initiatorID)
|
|
}
|
|
return nil, nil
|
|
}
|
|
|
|
func (n *TestNotifications) Delete(ctx context.Context, id uuid.UUID) error {
|
|
if n.DeleteFunc != nil {
|
|
return n.DeleteFunc(ctx, id)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
type TestActions struct {
|
|
LogFunc func(ctx context.Context, userID uuid.UUID, typ, content string) error
|
|
RecentFunc func(ctx context.Context, userID uuid.UUID, limit int) ([]database.Action, error)
|
|
}
|
|
|
|
func (a *TestActions) Log(ctx context.Context, userID uuid.UUID, typ, content string) error {
|
|
if a.LogFunc != nil {
|
|
return a.LogFunc(ctx, userID, typ, content)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (a *TestActions) Recent(ctx context.Context, userID uuid.UUID, limit int) ([]database.Action, error) {
|
|
if a.RecentFunc != nil {
|
|
return a.RecentFunc(ctx, userID, limit)
|
|
}
|
|
return nil, nil
|
|
}
|
|
|
|
// TestSearcher is a stub searcher. Search forwards to SearchFunc when it
// is set. Otherwise it returns an empty result and a nil error.
type TestSearcher struct {
	SearchFunc func(ctx context.Context, query string) (string, error)
}

// Search forwards to SearchFunc, or returns ("", nil) when it is unset.
func (s *TestSearcher) Search(ctx context.Context, query string) (string, error) {
	if s.SearchFunc == nil {
		return "", nil
	}
	return s.SearchFunc(ctx, query)
}
|
|
|
|
func NewTestRuntime() *actions.Runtime {
|
|
u := &user.User{
|
|
User: &database.User{
|
|
ID: uuid.New(),
|
|
Language: "en",
|
|
Timezone: "UTC",
|
|
PreferredChat: "telegram",
|
|
BindCode: uuid.New(),
|
|
ContactCode: uuid.New(),
|
|
Role: "free",
|
|
},
|
|
Chats: []database.Chat{{Platform: "telegram", Identifier: "123"}},
|
|
Facts: []database.Fact{{Value: "SYSTEM: COMPLETED INTEGRATION"}},
|
|
}
|
|
|
|
return &actions.Runtime{
|
|
User: u,
|
|
Database: &TestDB{
|
|
UsersDB: &TestUsers{},
|
|
ChatsDB: &TestChats{},
|
|
FactsDB: &TestFacts{},
|
|
ContactsDB: &TestContacts{},
|
|
NotificationsDB: &TestNotifications{},
|
|
ActionsDB: &TestActions{},
|
|
},
|
|
Searcher: &TestSearcher{},
|
|
Chats: make(map[string]chat.Chat),
|
|
}
|
|
}
|
|
|
|
type TestChat struct {
|
|
SendFunc func(ctx context.Context, id, text string) error
|
|
}
|
|
|
|
func (c *TestChat) Send(ctx context.Context, id, text string) error {
|
|
if c.SendFunc != nil {
|
|
return c.SendFunc(ctx, id, text)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (c *TestChat) Receive(ctx context.Context) <-chan chat.Message { return nil }
|
|
|
|
type testLLM struct {
|
|
llm.LLM
|
|
processFunc func(ctx context.Context, prompt string) (string, error)
|
|
}
|
|
|
|
func (t *testLLM) Process(ctx context.Context, prompt string) (string, error) {
|
|
if t.processFunc != nil {
|
|
return t.processFunc(ctx, prompt)
|
|
}
|
|
return "", nil
|
|
}
|
|
|
|
func setupTestEngine(t *testing.T) (*Engine, *actions.Runtime) {
|
|
t.Helper()
|
|
rt := NewTestRuntime()
|
|
rt.User.User = &database.User{
|
|
ID: uuid.New(),
|
|
Language: "en",
|
|
Timezone: "UTC",
|
|
PreferredChat: "telegram",
|
|
BindCode: uuid.New(),
|
|
ContactCode: uuid.New(),
|
|
Role: "free",
|
|
}
|
|
|
|
e := &Engine{
|
|
Parameters: &Parameters{
|
|
LLMRetryAttempts: 1,
|
|
Database: rt.Database,
|
|
LLM: &testLLM{},
|
|
Searcher: rt.Searcher,
|
|
Chats: rt.Chats,
|
|
},
|
|
}
|
|
return e, rt
|
|
}
|
|
|
|
func TestDefaultProcessMessage(t *testing.T) {
|
|
t.Run("successful processing", func(t *testing.T) {
|
|
e, rt := setupTestEngine(t)
|
|
|
|
msg := chat.Message{Chat: "telegram", ID: "123", Text: "hello"}
|
|
|
|
rt.Database.(*TestDB).ChatsDB.GetUserIDFunc = func(ctx context.Context, platform, identifier string) (uuid.UUID, error) {
|
|
return rt.User.ID, nil
|
|
}
|
|
rt.Database.(*TestDB).UsersDB.GetFunc = func(ctx context.Context, id uuid.UUID, lookup database.UserLookup) (*database.User, error) {
|
|
return rt.User.User, nil
|
|
}
|
|
rt.Database.(*TestDB).ChatsDB.ListFunc = func(ctx context.Context, userID uuid.UUID) ([]database.Chat, error) {
|
|
return []database.Chat{{Platform: "telegram", Identifier: "123"}}, nil
|
|
}
|
|
rt.Database.(*TestDB).FactsDB.ListFunc = func(ctx context.Context, userID uuid.UUID) ([]database.Fact, error) {
|
|
return nil, nil
|
|
}
|
|
rt.Database.(*TestDB).ContactsDB.ListFunc = func(ctx context.Context, userID uuid.UUID) ([]database.Contact, error) {
|
|
return nil, nil
|
|
}
|
|
rt.Database.(*TestDB).NotificationsDB.ListFunc = func(ctx context.Context, userID uuid.UUID) ([]database.Notification, error) {
|
|
return nil, nil
|
|
}
|
|
rt.Database.(*TestDB).NotificationsDB.ListOutgoingFunc = func(ctx context.Context, userID uuid.UUID) ([]database.Notification, error) {
|
|
return nil, nil
|
|
}
|
|
rt.Database.(*TestDB).ActionsDB.RecentFunc = func(ctx context.Context, userID uuid.UUID, limit int) ([]database.Action, error) {
|
|
return nil, nil
|
|
}
|
|
rt.Database.(*TestDB).ActionsDB.LogFunc = func(ctx context.Context, userID uuid.UUID, typ, content string) error {
|
|
return nil
|
|
}
|
|
|
|
llmCalled := false
|
|
e.LLM = &testLLM{
|
|
processFunc: func(ctx context.Context, prompt string) (string, error) {
|
|
llmCalled = true
|
|
return `[{"type":"message","platform":"telegram","text":"hi"}]`, nil
|
|
},
|
|
}
|
|
rt.Chats["telegram"] = &TestChat{
|
|
SendFunc: func(ctx context.Context, id, text string) error {
|
|
return nil
|
|
},
|
|
}
|
|
|
|
e.defaultProcessMessage(t.Context(), msg)
|
|
assert.True(t, llmCalled)
|
|
})
|
|
|
|
t.Run("message too long", func(t *testing.T) {
|
|
e, rt := setupTestEngine(t)
|
|
|
|
longMsg := strings.Repeat("a", 4001)
|
|
msg := chat.Message{Chat: "telegram", ID: "123", Text: longMsg}
|
|
|
|
sendCalled := false
|
|
rt.Chats["telegram"] = &TestChat{
|
|
SendFunc: func(ctx context.Context, id, text string) error {
|
|
assert.Equal(t, "123", id)
|
|
assert.Contains(t, text, "ERROR: message too long")
|
|
sendCalled = true
|
|
return nil
|
|
},
|
|
}
|
|
|
|
e.defaultProcessMessage(t.Context(), msg)
|
|
assert.True(t, sendCalled)
|
|
})
|
|
|
|
t.Run("message too long send fails", func(t *testing.T) {
|
|
e, rt := setupTestEngine(t)
|
|
|
|
longMsg := strings.Repeat("a", 4001)
|
|
msg := chat.Message{Chat: "telegram", ID: "123", Text: longMsg}
|
|
|
|
rt.Chats["telegram"] = &TestChat{
|
|
SendFunc: func(ctx context.Context, id, text string) error {
|
|
return errors.New("network error")
|
|
},
|
|
}
|
|
|
|
e.defaultProcessMessage(t.Context(), msg)
|
|
})
|
|
|
|
t.Run("failed to get user", func(t *testing.T) {
|
|
e, rt := setupTestEngine(t)
|
|
msg := chat.Message{Chat: "telegram", ID: "123", Text: "hello"}
|
|
|
|
rt.Database.(*TestDB).ChatsDB.GetUserIDFunc = func(ctx context.Context, platform, identifier string) (uuid.UUID, error) {
|
|
return uuid.Nil, errors.New("db error")
|
|
}
|
|
|
|
e.defaultProcessMessage(t.Context(), msg)
|
|
})
|
|
|
|
t.Run("failed to log action", func(t *testing.T) {
|
|
e, rt := setupTestEngine(t)
|
|
msg := chat.Message{Chat: "telegram", ID: "123", Text: "hello"}
|
|
|
|
rt.Database.(*TestDB).ChatsDB.GetUserIDFunc = func(ctx context.Context, platform, identifier string) (uuid.UUID, error) {
|
|
return rt.User.ID, nil
|
|
}
|
|
rt.Database.(*TestDB).UsersDB.GetFunc = func(ctx context.Context, id uuid.UUID, lookup database.UserLookup) (*database.User, error) {
|
|
return rt.User.User, nil
|
|
}
|
|
rt.Database.(*TestDB).ChatsDB.ListFunc = func(ctx context.Context, userID uuid.UUID) ([]database.Chat, error) {
|
|
return []database.Chat{{Platform: "telegram", Identifier: "123"}}, nil
|
|
}
|
|
rt.Database.(*TestDB).FactsDB.ListFunc = func(ctx context.Context, userID uuid.UUID) ([]database.Fact, error) {
|
|
return nil, nil
|
|
}
|
|
rt.Database.(*TestDB).ContactsDB.ListFunc = func(ctx context.Context, userID uuid.UUID) ([]database.Contact, error) {
|
|
return nil, nil
|
|
}
|
|
rt.Database.(*TestDB).NotificationsDB.ListFunc = func(ctx context.Context, userID uuid.UUID) ([]database.Notification, error) {
|
|
return nil, nil
|
|
}
|
|
rt.Database.(*TestDB).NotificationsDB.ListOutgoingFunc = func(ctx context.Context, userID uuid.UUID) ([]database.Notification, error) {
|
|
return nil, nil
|
|
}
|
|
rt.Database.(*TestDB).ActionsDB.RecentFunc = func(ctx context.Context, userID uuid.UUID, limit int) ([]database.Action, error) {
|
|
return nil, nil
|
|
}
|
|
rt.Database.(*TestDB).ActionsDB.LogFunc = func(ctx context.Context, userID uuid.UUID, typ, content string) error {
|
|
return errors.New("log failed")
|
|
}
|
|
|
|
e.defaultProcessMessage(t.Context(), msg)
|
|
})
|
|
|
|
t.Run("empty message skipped", func(t *testing.T) {
|
|
var buf bytes.Buffer
|
|
slog.SetDefault(slog.New(slog.NewTextHandler(&buf, &slog.HandlerOptions{Level: slog.LevelInfo})))
|
|
|
|
e, _ := setupTestEngine(t)
|
|
msg := chat.Message{Chat: "telegram", ID: "123", Text: ""}
|
|
|
|
e.defaultProcessMessage(t.Context(), msg)
|
|
assert.Contains(t, buf.String(), "skipping empty message")
|
|
})
|
|
}
|
|
|
|
func TestDefaultProcessNotification(t *testing.T) {
|
|
t.Run("successful processing", func(t *testing.T) {
|
|
e, rt := setupTestEngine(t)
|
|
notif := database.Notification{UserID: rt.User.ID, Content: "test", RepeatOn: "daily, morning"}
|
|
|
|
rt.Database.(*TestDB).UsersDB.GetFunc = func(ctx context.Context, id uuid.UUID, lookup database.UserLookup) (*database.User, error) {
|
|
return rt.User.User, nil
|
|
}
|
|
rt.Database.(*TestDB).ChatsDB.ListFunc = func(ctx context.Context, userID uuid.UUID) ([]database.Chat, error) {
|
|
return []database.Chat{{Platform: "telegram", Identifier: "123"}}, nil
|
|
}
|
|
rt.Database.(*TestDB).FactsDB.ListFunc = func(ctx context.Context, userID uuid.UUID) ([]database.Fact, error) {
|
|
return nil, nil
|
|
}
|
|
rt.Database.(*TestDB).ContactsDB.ListFunc = func(ctx context.Context, userID uuid.UUID) ([]database.Contact, error) {
|
|
return nil, nil
|
|
}
|
|
rt.Database.(*TestDB).NotificationsDB.ListFunc = func(ctx context.Context, userID uuid.UUID) ([]database.Notification, error) {
|
|
return nil, nil
|
|
}
|
|
rt.Database.(*TestDB).NotificationsDB.ListOutgoingFunc = func(ctx context.Context, userID uuid.UUID) ([]database.Notification, error) {
|
|
return nil, nil
|
|
}
|
|
rt.Database.(*TestDB).ActionsDB.RecentFunc = func(ctx context.Context, userID uuid.UUID, limit int) ([]database.Action, error) {
|
|
return nil, nil
|
|
}
|
|
rt.Database.(*TestDB).ActionsDB.LogFunc = func(ctx context.Context, userID uuid.UUID, typ, content string) error {
|
|
return nil
|
|
}
|
|
|
|
llmCalled := false
|
|
e.LLM = &testLLM{
|
|
processFunc: func(ctx context.Context, prompt string) (string, error) {
|
|
llmCalled = true
|
|
return `[{"type":"message","platform":"telegram","text":"ok"}]`, nil
|
|
},
|
|
}
|
|
rt.Chats["telegram"] = &TestChat{
|
|
SendFunc: func(ctx context.Context, id, text string) error {
|
|
return nil
|
|
},
|
|
}
|
|
|
|
e.defaultProcessNotification(t.Context(), notif)
|
|
assert.True(t, llmCalled)
|
|
})
|
|
|
|
t.Run("failed to get user", func(t *testing.T) {
|
|
e, rt := setupTestEngine(t)
|
|
notif := database.Notification{UserID: rt.User.ID, Content: "test"}
|
|
|
|
rt.Database.(*TestDB).UsersDB.GetFunc = func(ctx context.Context, id uuid.UUID, lookup database.UserLookup) (*database.User, error) {
|
|
return nil, errors.New("db error")
|
|
}
|
|
|
|
e.defaultProcessNotification(t.Context(), notif)
|
|
})
|
|
|
|
t.Run("failed to log action", func(t *testing.T) {
|
|
e, rt := setupTestEngine(t)
|
|
notif := database.Notification{UserID: rt.User.ID, Content: "test"}
|
|
|
|
rt.Database.(*TestDB).UsersDB.GetFunc = func(ctx context.Context, id uuid.UUID, lookup database.UserLookup) (*database.User, error) {
|
|
return rt.User.User, nil
|
|
}
|
|
rt.Database.(*TestDB).ChatsDB.ListFunc = func(ctx context.Context, userID uuid.UUID) ([]database.Chat, error) {
|
|
return []database.Chat{{Platform: "telegram", Identifier: "123"}}, nil
|
|
}
|
|
rt.Database.(*TestDB).FactsDB.ListFunc = func(ctx context.Context, userID uuid.UUID) ([]database.Fact, error) {
|
|
return nil, nil
|
|
}
|
|
rt.Database.(*TestDB).ContactsDB.ListFunc = func(ctx context.Context, userID uuid.UUID) ([]database.Contact, error) {
|
|
return nil, nil
|
|
}
|
|
rt.Database.(*TestDB).NotificationsDB.ListFunc = func(ctx context.Context, userID uuid.UUID) ([]database.Notification, error) {
|
|
return nil, nil
|
|
}
|
|
rt.Database.(*TestDB).NotificationsDB.ListOutgoingFunc = func(ctx context.Context, userID uuid.UUID) ([]database.Notification, error) {
|
|
return nil, nil
|
|
}
|
|
rt.Database.(*TestDB).ActionsDB.RecentFunc = func(ctx context.Context, userID uuid.UUID, limit int) ([]database.Action, error) {
|
|
return nil, nil
|
|
}
|
|
rt.Database.(*TestDB).ActionsDB.LogFunc = func(ctx context.Context, userID uuid.UUID, typ, content string) error {
|
|
return errors.New("log failed")
|
|
}
|
|
|
|
e.defaultProcessNotification(t.Context(), notif)
|
|
})
|
|
}
|
|
|
|
func TestProcess(t *testing.T) {
|
|
t.Run("successful on first attempt", func(t *testing.T) {
|
|
var buf bytes.Buffer
|
|
slog.SetDefault(slog.New(slog.NewTextHandler(&buf, &slog.HandlerOptions{Level: slog.LevelInfo})))
|
|
|
|
e, rt := setupTestEngine(t)
|
|
span := trace.FromMessage(t.Context(), chat.Message{Chat: "telegram", ID: "123", Text: "hello"})
|
|
u := &user.User{User: rt.User.User}
|
|
u.Chats = []database.Chat{{Platform: "telegram", Identifier: "123"}}
|
|
|
|
e.LLM = &testLLM{
|
|
processFunc: func(ctx context.Context, prompt string) (string, error) {
|
|
return `[{"type":"message","platform":"telegram","text":"hi"}]`, nil
|
|
},
|
|
}
|
|
rt.Chats["telegram"] = &TestChat{SendFunc: func(ctx context.Context, id, text string) error { return nil }}
|
|
rt.Database.(*TestDB).ActionsDB.LogFunc = func(ctx context.Context, userID uuid.UUID, typ, content string) error {
|
|
return nil
|
|
}
|
|
|
|
e.process(t.Context(), span, u, "telegram", "hello")
|
|
assert.Contains(t, buf.String(), "successfully processed")
|
|
})
|
|
|
|
t.Run("llm fails then succeeds", func(t *testing.T) {
|
|
var buf bytes.Buffer
|
|
slog.SetDefault(slog.New(slog.NewTextHandler(&buf, &slog.HandlerOptions{Level: slog.LevelInfo})))
|
|
|
|
e, rt := setupTestEngine(t)
|
|
e.LLMRetryAttempts = 2
|
|
span := trace.FromMessage(t.Context(), chat.Message{Chat: "telegram", ID: "123", Text: "hello"})
|
|
u := &user.User{User: rt.User.User}
|
|
u.Chats = []database.Chat{{Platform: "telegram", Identifier: "123"}}
|
|
|
|
attempts := 0
|
|
e.LLM = &testLLM{
|
|
processFunc: func(ctx context.Context, prompt string) (string, error) {
|
|
attempts++
|
|
if attempts == 1 {
|
|
return "", errors.New("llm down")
|
|
}
|
|
return `[{"type":"message","platform":"telegram","text":"hi"}]`, nil
|
|
},
|
|
}
|
|
rt.Chats["telegram"] = &TestChat{SendFunc: func(ctx context.Context, id, text string) error { return nil }}
|
|
rt.Database.(*TestDB).ActionsDB.LogFunc = func(ctx context.Context, userID uuid.UUID, typ, content string) error {
|
|
return nil
|
|
}
|
|
|
|
e.process(t.Context(), span, u, "telegram", "hello")
|
|
assert.Contains(t, buf.String(), "failed to receive LLM response")
|
|
assert.Contains(t, buf.String(), "successfully processed")
|
|
})
|
|
|
|
t.Run("parse fails then succeeds", func(t *testing.T) {
|
|
var buf bytes.Buffer
|
|
slog.SetDefault(slog.New(slog.NewTextHandler(&buf, &slog.HandlerOptions{Level: slog.LevelInfo})))
|
|
|
|
e, rt := setupTestEngine(t)
|
|
e.LLMRetryAttempts = 2
|
|
span := trace.FromMessage(t.Context(), chat.Message{Chat: "telegram", ID: "123", Text: "hello"})
|
|
u := &user.User{User: rt.User.User}
|
|
u.Chats = []database.Chat{{Platform: "telegram", Identifier: "123"}}
|
|
|
|
attempts := 0
|
|
e.LLM = &testLLM{
|
|
processFunc: func(ctx context.Context, prompt string) (string, error) {
|
|
attempts++
|
|
if attempts == 1 {
|
|
return "invalid json", nil
|
|
}
|
|
return `[{"type":"message","platform":"telegram","text":"hi"}]`, nil
|
|
},
|
|
}
|
|
rt.Chats["telegram"] = &TestChat{SendFunc: func(ctx context.Context, id, text string) error { return nil }}
|
|
rt.Database.(*TestDB).ActionsDB.LogFunc = func(ctx context.Context, userID uuid.UUID, typ, content string) error {
|
|
return nil
|
|
}
|
|
|
|
e.process(t.Context(), span, u, "telegram", "hello")
|
|
assert.Contains(t, buf.String(), "failed to parse actions")
|
|
assert.Contains(t, buf.String(), "successfully processed")
|
|
})
|
|
|
|
t.Run("validation fails and continues", func(t *testing.T) {
|
|
var buf bytes.Buffer
|
|
slog.SetDefault(slog.New(slog.NewTextHandler(&buf, &slog.HandlerOptions{Level: slog.LevelInfo})))
|
|
|
|
e, rt := setupTestEngine(t)
|
|
e.LLMRetryAttempts = 2
|
|
span := trace.FromMessage(t.Context(), chat.Message{Chat: "telegram", ID: "123", Text: "hello"})
|
|
u := &user.User{User: rt.User.User}
|
|
|
|
attempts := 0
|
|
e.LLM = &testLLM{
|
|
processFunc: func(ctx context.Context, prompt string) (string, error) {
|
|
attempts++
|
|
if attempts == 1 {
|
|
return `[{"type":"add_fact","value":""}]`, nil
|
|
}
|
|
return `[{"type":"add_fact","value":"test"}]`, nil
|
|
},
|
|
}
|
|
rt.Database.(*TestDB).FactsDB.AddFunc = func(ctx context.Context, userID uuid.UUID, value string) error {
|
|
return nil
|
|
}
|
|
rt.Database.(*TestDB).ActionsDB.LogFunc = func(ctx context.Context, userID uuid.UUID, typ, content string) error {
|
|
return nil
|
|
}
|
|
|
|
e.process(t.Context(), span, u, "telegram", "hello")
|
|
assert.Contains(t, buf.String(), "failed to validate actions")
|
|
assert.Contains(t, buf.String(), "successfully processed")
|
|
})
|
|
|
|
t.Run("execute fails", func(t *testing.T) {
|
|
var buf bytes.Buffer
|
|
slog.SetDefault(slog.New(slog.NewTextHandler(&buf, &slog.HandlerOptions{Level: slog.LevelInfo})))
|
|
|
|
e, rt := setupTestEngine(t)
|
|
e.LLMRetryAttempts = 1
|
|
span := trace.FromMessage(t.Context(), chat.Message{Chat: "telegram", ID: "123", Text: "hello"})
|
|
u := &user.User{User: rt.User.User}
|
|
u.Chats = []database.Chat{{Platform: "telegram", Identifier: "123"}}
|
|
|
|
e.LLM = &testLLM{
|
|
processFunc: func(ctx context.Context, prompt string) (string, error) {
|
|
return `[{"type":"message","platform":"telegram","text":"hi"}]`, nil
|
|
},
|
|
}
|
|
rt.Chats["telegram"] = &TestChat{SendFunc: func(ctx context.Context, id, text string) error {
|
|
return errors.New("send failed")
|
|
}}
|
|
rt.Database.(*TestDB).ActionsDB.LogFunc = func(ctx context.Context, userID uuid.UUID, typ, content string) error {
|
|
return nil
|
|
}
|
|
|
|
e.process(t.Context(), span, u, "telegram", "hello")
|
|
assert.Contains(t, buf.String(), "failed to execute actions")
|
|
})
|
|
|
|
t.Run("all attempts fail", func(t *testing.T) {
|
|
var buf bytes.Buffer
|
|
slog.SetDefault(slog.New(slog.NewTextHandler(&buf, &slog.HandlerOptions{Level: slog.LevelInfo})))
|
|
|
|
e, rt := setupTestEngine(t)
|
|
e.LLMRetryAttempts = 2
|
|
span := trace.FromMessage(t.Context(), chat.Message{Chat: "telegram", ID: "123", Text: "hello"})
|
|
u := &user.User{User: rt.User.User}
|
|
|
|
e.LLM = &testLLM{
|
|
processFunc: func(ctx context.Context, prompt string) (string, error) {
|
|
return "", errors.New("llm down")
|
|
},
|
|
}
|
|
|
|
e.process(t.Context(), span, u, "telegram", "hello")
|
|
assert.Contains(t, buf.String(), "all attempts to process event have failed")
|
|
})
|
|
}
|
|
|
|
func TestLastError(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
params []any
|
|
want error
|
|
}{
|
|
{
|
|
name: "empty",
|
|
params: nil,
|
|
want: nil,
|
|
},
|
|
{
|
|
name: "last is error",
|
|
params: []any{1, "string", errors.New("fail")},
|
|
want: errors.New("fail"),
|
|
},
|
|
{
|
|
name: "last is not error",
|
|
params: []any{1, "string", 42},
|
|
want: nil,
|
|
},
|
|
{
|
|
name: "single error",
|
|
params: []any{errors.New("fail")},
|
|
want: errors.New("fail"),
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
got := lastError(tt.params)
|
|
if tt.want == nil {
|
|
assert.Nil(t, got)
|
|
} else {
|
|
require.NotNil(t, got)
|
|
assert.Equal(t, tt.want.Error(), got.Error())
|
|
}
|
|
})
|
|
}
|
|
}
|