almost finished database layer

This commit is contained in:
d1nch8g
2026-04-16 21:40:00 +03:00
parent ca80a9219e
commit 1c0bd3512e
17 changed files with 791 additions and 109 deletions
-7
View File
@@ -9,13 +9,6 @@ install:
gen: gen:
mkdir -p gen/migrations mkdir -p gen/migrations
go-bindata -pkg migrations -o gen/migrations/migrations.go -prefix "sql/migrations/" sql/migrations/... go-bindata -pkg migrations -o gen/migrations/migrations.go -prefix "sql/migrations/" sql/migrations/...
# mkdir -p gen/pb
# protoc -I proto \
# --go_out=gen/pb \
# --go_opt=paths=source_relative \
# --go-grpc_out=gen/pb \
# --go-grpc_opt=paths=source_relative \
# proto/jules.proto
fmt: fmt:
gofumpt -w . gofumpt -w .
+3 -1
View File
@@ -70,9 +70,9 @@ type Users interface {
// Chats manages chat persistence. // Chats manages chat persistence.
type Chats interface { type Chats interface {
Attach(ctx context.Context, userID uuid.UUID, platform, identifier string) error Attach(ctx context.Context, userID uuid.UUID, platform, identifier string) error
Detach(ctx context.Context, userID uuid.UUID, platform string) error
GetUserID(ctx context.Context, platform, identifier string) (uuid.UUID, error) GetUserID(ctx context.Context, platform, identifier string) (uuid.UUID, error)
List(ctx context.Context, userID uuid.UUID) ([]Chat, error) List(ctx context.Context, userID uuid.UUID) ([]Chat, error)
Detach(ctx context.Context, userID uuid.UUID, platform string) error
} }
// Facts manages facts persistence. // Facts manages facts persistence.
@@ -94,6 +94,8 @@ type Contacts interface {
type Notifications interface { type Notifications interface {
Push(ctx context.Context, n *Notification) error Push(ctx context.Context, n *Notification) error
Pop(ctx context.Context, limit int) ([]Notification, error) Pop(ctx context.Context, limit int) ([]Notification, error)
List(ctx context.Context, userID uuid.UUID) ([]Notification, error)
Delete(ctx context.Context, id uuid.UUID) error
} }
// Actions manages the action log. // Actions manages the action log.
+12 -8
View File
@@ -10,8 +10,12 @@ import (
"github.com/lib/pq" "github.com/lib/pq"
) )
func (db *DB) Attach(ctx context.Context, userID uuid.UUID, platform, identifier string) error { type Chats struct {
_, err := db.conn.ExecContext(ctx, ` conn *sql.DB
}
func (c *Chats) Attach(ctx context.Context, userID uuid.UUID, platform, identifier string) error {
_, err := c.conn.ExecContext(ctx, `
INSERT INTO chats (user_id, platform, identifier) INSERT INTO chats (user_id, platform, identifier)
VALUES ($1, $2, $3) VALUES ($1, $2, $3)
`, userID, platform, identifier) `, userID, platform, identifier)
@@ -24,9 +28,9 @@ func (db *DB) Attach(ctx context.Context, userID uuid.UUID, platform, identifier
return nil return nil
} }
func (db *DB) GetUserID(ctx context.Context, platform, identifier string) (uuid.UUID, error) { func (c *Chats) GetUserID(ctx context.Context, platform, identifier string) (uuid.UUID, error) {
var userID uuid.UUID var userID uuid.UUID
err := db.conn.QueryRowContext(ctx, ` err := c.conn.QueryRowContext(ctx, `
SELECT user_id SELECT user_id
FROM chats FROM chats
WHERE platform = $1 AND identifier = $2 WHERE platform = $1 AND identifier = $2
@@ -40,8 +44,8 @@ func (db *DB) GetUserID(ctx context.Context, platform, identifier string) (uuid.
return userID, nil return userID, nil
} }
func (db *DB) List(ctx context.Context, userID uuid.UUID) ([]database.Chat, error) { func (c *Chats) List(ctx context.Context, userID uuid.UUID) ([]database.Chat, error) {
rows, err := db.conn.QueryContext(ctx, ` rows, err := c.conn.QueryContext(ctx, `
SELECT user_id, platform, identifier SELECT user_id, platform, identifier
FROM chats FROM chats
WHERE user_id = $1 WHERE user_id = $1
@@ -65,8 +69,8 @@ func (db *DB) List(ctx context.Context, userID uuid.UUID) ([]database.Chat, erro
return chats, nil return chats, nil
} }
func (db *DB) Detach(ctx context.Context, userID uuid.UUID, platform string) error { func (c *Chats) Detach(ctx context.Context, userID uuid.UUID, platform string) error {
result, err := db.conn.ExecContext(ctx, ` result, err := c.conn.ExecContext(ctx, `
DELETE FROM chats DELETE FROM chats
WHERE user_id = $1 AND platform = $2 WHERE user_id = $1 AND platform = $2
`, userID, platform) `, userID, platform)
+36 -38
View File
@@ -9,28 +9,27 @@ import (
"github.com/d1nch8g/jules/database" "github.com/d1nch8g/jules/database"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
func TestChats_Attach(t *testing.T) { func TestChats_Attach(t *testing.T) {
db := setupTestDB(t) db := setupTestDB(t)
user, err := db.Create(context.Background()) user, err := db.Users.Create(context.Background())
require.NoError(t, err) assert.NoError(t, err)
t.Run("success", func(t *testing.T) { t.Run("success", func(t *testing.T) {
err := db.Attach(context.Background(), user.ID, "telegram", "@test_attach") err := db.Chats.Attach(context.Background(), user.ID, "telegram", "@test_attach")
require.NoError(t, err) assert.NoError(t, err)
}) })
t.Run("already exists", func(t *testing.T) { t.Run("already exists", func(t *testing.T) {
err := db.Attach(context.Background(), user.ID, "telegram", "@test_attach") err := db.Chats.Attach(context.Background(), user.ID, "telegram", "@test_attach")
assert.ErrorIs(t, err, database.ErrAlreadyExists) assert.ErrorIs(t, err, database.ErrAlreadyExists)
}) })
t.Run("database error", func(t *testing.T) { t.Run("database error", func(t *testing.T) {
db2 := setupTestDB(t) db2 := setupTestDB(t)
db2.Close() db2.Close()
err := db2.Attach(context.Background(), user.ID, "whatsapp", "@test_attach") err := db2.Chats.Attach(context.Background(), user.ID, "whatsapp", "@test_attach")
assert.Error(t, err) assert.Error(t, err)
assert.NotErrorIs(t, err, database.ErrAlreadyExists) assert.NotErrorIs(t, err, database.ErrAlreadyExists)
}) })
@@ -38,28 +37,27 @@ func TestChats_Attach(t *testing.T) {
func TestChats_GetUserID(t *testing.T) { func TestChats_GetUserID(t *testing.T) {
db := setupTestDB(t) db := setupTestDB(t)
user, err := db.Create(context.Background()) user, err := db.Users.Create(context.Background())
require.NoError(t, err) assert.NoError(t, err)
err = db.Attach(context.Background(), user.ID, "telegram", "@test_get") err = db.Chats.Attach(context.Background(), user.ID, "telegram", "@test_get")
require.NoError(t, err) assert.NoError(t, err)
t.Run("found", func(t *testing.T) { t.Run("found", func(t *testing.T) {
got, err := db.GetUserID(context.Background(), "telegram", "@test_get") got, err := db.Chats.GetUserID(context.Background(), "telegram", "@test_get")
require.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, user.ID, got) assert.Equal(t, user.ID, got)
}) })
t.Run("not found", func(t *testing.T) { t.Run("not found", func(t *testing.T) {
_, err := db.GetUserID(context.Background(), "telegram", "@notfound") _, err := db.Chats.GetUserID(context.Background(), "telegram", "@notfound")
assert.ErrorIs(t, err, database.ErrNotFound) assert.ErrorIs(t, err, database.ErrNotFound)
}) })
t.Run("database error", func(t *testing.T) { t.Run("database error", func(t *testing.T) {
// Создаём свежее соединение и закрываем его
db2 := setupTestDB(t) db2 := setupTestDB(t)
db2.Close() db2.Close()
_, err := db2.GetUserID(context.Background(), "telegram", "@test_get") _, err := db2.Chats.GetUserID(context.Background(), "telegram", "@test_get")
assert.Error(t, err) assert.Error(t, err)
assert.NotErrorIs(t, err, database.ErrNotFound) assert.NotErrorIs(t, err, database.ErrNotFound)
}) })
@@ -67,51 +65,51 @@ func TestChats_GetUserID(t *testing.T) {
func TestChats_List(t *testing.T) { func TestChats_List(t *testing.T) {
db := setupTestDB(t) db := setupTestDB(t)
user, _ := db.Create(context.Background()) user, _ := db.Users.Create(context.Background())
db.Attach(context.Background(), user.ID, "telegram", "@test") db.Chats.Attach(context.Background(), user.ID, "telegram", "@test")
db.Attach(context.Background(), user.ID, "email", "test@example.com") db.Chats.Attach(context.Background(), user.ID, "email", "test@example.com")
t.Run("success", func(t *testing.T) { t.Run("success", func(t *testing.T) {
chats, err := db.List(context.Background(), user.ID) chats, err := db.Chats.List(context.Background(), user.ID)
require.NoError(t, err) assert.NoError(t, err)
assert.Len(t, chats, 2) assert.Len(t, chats, 2)
}) })
t.Run("empty list", func(t *testing.T) { t.Run("empty list", func(t *testing.T) {
otherUser, _ := db.Create(context.Background()) otherUser, _ := db.Users.Create(context.Background())
chats, err := db.List(context.Background(), otherUser.ID) chats, err := db.Chats.List(context.Background(), otherUser.ID)
require.NoError(t, err) assert.NoError(t, err)
assert.Empty(t, chats) assert.Empty(t, chats)
}) })
t.Run("database error", func(t *testing.T) { t.Run("database error", func(t *testing.T) {
db.Close() db.Close()
_, err := db.List(context.Background(), user.ID) _, err := db.Chats.List(context.Background(), user.ID)
assert.Error(t, err) assert.Error(t, err)
}) })
} }
func TestChats_Detach(t *testing.T) { func TestChats_Detach(t *testing.T) {
db := setupTestDB(t) db := setupTestDB(t)
user, _ := db.Create(context.Background()) user, _ := db.Users.Create(context.Background())
db.Attach(context.Background(), user.ID, "telegram", "@test") db.Chats.Attach(context.Background(), user.ID, "telegram", "@test")
t.Run("success", func(t *testing.T) { t.Run("success", func(t *testing.T) {
err := db.Detach(context.Background(), user.ID, "telegram") err := db.Chats.Detach(context.Background(), user.ID, "telegram")
require.NoError(t, err) assert.NoError(t, err)
chats, _ := db.List(context.Background(), user.ID) chats, _ := db.Chats.List(context.Background(), user.ID)
assert.Empty(t, chats) assert.Empty(t, chats)
}) })
t.Run("not found", func(t *testing.T) { t.Run("not found", func(t *testing.T) {
err := db.Detach(context.Background(), user.ID, "telegram") err := db.Chats.Detach(context.Background(), user.ID, "telegram")
assert.ErrorIs(t, err, database.ErrNotFound) assert.ErrorIs(t, err, database.ErrNotFound)
}) })
t.Run("database error", func(t *testing.T) { t.Run("database error", func(t *testing.T) {
db.Close() db.Close()
err := db.Detach(context.Background(), user.ID, "telegram") err := db.Chats.Detach(context.Background(), user.ID, "telegram")
assert.Error(t, err) assert.Error(t, err)
assert.NotErrorIs(t, err, database.ErrNotFound) assert.NotErrorIs(t, err, database.ErrNotFound)
}) })
@@ -119,18 +117,18 @@ func TestChats_Detach(t *testing.T) {
func TestChats_List_ScanError(t *testing.T) { func TestChats_List_ScanError(t *testing.T) {
db := setupTestDB(t) db := setupTestDB(t)
user, _ := db.Create(context.Background()) user, _ := db.Users.Create(context.Background())
_, err := db.conn.ExecContext(context.Background(), ` _, err := db.conn.ExecContext(context.Background(), `
ALTER TABLE chats ALTER COLUMN identifier DROP NOT NULL ALTER TABLE chats ALTER COLUMN identifier DROP NOT NULL
`) `)
require.NoError(t, err) assert.NoError(t, err)
_, err = db.conn.ExecContext(context.Background(), ` _, err = db.conn.ExecContext(context.Background(), `
INSERT INTO chats (user_id, platform, identifier) INSERT INTO chats (user_id, platform, identifier)
VALUES ($1, 'telegram', NULL) VALUES ($1, 'telegram', NULL)
`, user.ID) `, user.ID)
require.NoError(t, err) assert.NoError(t, err)
t.Cleanup(func() { t.Cleanup(func() {
db.conn.ExecContext(context.Background(), ` db.conn.ExecContext(context.Background(), `
@@ -141,16 +139,16 @@ func TestChats_List_ScanError(t *testing.T) {
`) `)
}) })
_, err = db.List(context.Background(), user.ID) _, err = db.Chats.List(context.Background(), user.ID)
assert.Error(t, err) assert.Error(t, err)
} }
func TestChats_List_RowsErr(t *testing.T) { func TestChats_List_RowsErr(t *testing.T) {
conn, mock, err := sqlmock.New() conn, mock, err := sqlmock.New()
require.NoError(t, err) assert.NoError(t, err)
defer conn.Close() defer conn.Close()
db := &DB{conn: conn} db := &Chats{conn: conn}
userID := uuid.New() userID := uuid.New()
rows := sqlmock.NewRows([]string{"user_id", "platform", "identifier"}). rows := sqlmock.NewRows([]string{"user_id", "platform", "identifier"}).
+86
View File
@@ -0,0 +1,86 @@
package postgres
import (
"context"
"database/sql"
"errors"
"github.com/d1nch8g/jules/database"
"github.com/google/uuid"
"github.com/lib/pq"
)
type Contacts struct {
conn *sql.DB
}
func (c *Contacts) Add(ctx context.Context, contact *database.Contact) error {
_, err := c.conn.ExecContext(ctx, `
INSERT INTO contacts (owner_id, target_id, name)
VALUES ($1, $2, $3)
`, contact.OwnerID, contact.TargetID, contact.Name)
if err != nil {
var pqErr *pq.Error
if errors.As(err, &pqErr) && pqErr.Code == "23505" {
return database.ErrAlreadyExists
}
return err
}
return nil
}
func (c *Contacts) Get(ctx context.Context, ownerID uuid.UUID, name string) (*database.Contact, error) {
var contact database.Contact
err := c.conn.QueryRowContext(ctx, `
SELECT owner_id, target_id, name
FROM contacts
WHERE owner_id = $1 AND name = $2
`, ownerID, name).Scan(&contact.OwnerID, &contact.TargetID, &contact.Name)
if errors.Is(err, sql.ErrNoRows) {
return nil, database.ErrNotFound
}
if err != nil {
return nil, err
}
return &contact, nil
}
func (c *Contacts) List(ctx context.Context, ownerID uuid.UUID) ([]database.Contact, error) {
rows, err := c.conn.QueryContext(ctx, `
SELECT owner_id, target_id, name
FROM contacts
WHERE owner_id = $1
`, ownerID)
if err != nil {
return nil, err
}
defer rows.Close()
var contacts []database.Contact
for rows.Next() {
var contact database.Contact
if err := rows.Scan(&contact.OwnerID, &contact.TargetID, &contact.Name); err != nil {
return nil, err
}
contacts = append(contacts, contact)
}
if err := rows.Err(); err != nil {
return nil, err
}
return contacts, nil
}
func (c *Contacts) Delete(ctx context.Context, ownerID uuid.UUID, name string) error {
result, err := c.conn.ExecContext(ctx, `
DELETE FROM contacts
WHERE owner_id = $1 AND name = $2
`, ownerID, name)
if err != nil {
return err
}
rows, _ := result.RowsAffected()
if rows == 0 {
return database.ErrNotFound
}
return nil
}
+171
View File
@@ -0,0 +1,171 @@
package postgres
import (
"context"
"errors"
"testing"
"github.com/DATA-DOG/go-sqlmock"
"github.com/d1nch8g/jules/database"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestContacts_Add(t *testing.T) {
db := setupTestDB(t)
owner, _ := db.Users.Create(context.Background())
target, _ := db.Users.Create(context.Background())
contact := &database.Contact{
OwnerID: owner.ID,
TargetID: target.ID,
Name: "брат",
}
t.Run("success", func(t *testing.T) {
err := db.Contacts.Add(context.Background(), contact)
require.NoError(t, err)
})
t.Run("already exists", func(t *testing.T) {
err := db.Contacts.Add(context.Background(), contact)
assert.ErrorIs(t, err, database.ErrAlreadyExists)
})
t.Run("database error", func(t *testing.T) {
db.Close()
err := db.Contacts.Add(context.Background(), contact)
assert.Error(t, err)
assert.NotErrorIs(t, err, database.ErrAlreadyExists)
})
}
func TestContacts_Get(t *testing.T) {
db := setupTestDB(t)
owner, _ := db.Users.Create(context.Background())
target, _ := db.Users.Create(context.Background())
contact := &database.Contact{
OwnerID: owner.ID,
TargetID: target.ID,
Name: "брат",
}
db.Contacts.Add(context.Background(), contact)
t.Run("found", func(t *testing.T) {
got, err := db.Contacts.Get(context.Background(), owner.ID, "брат")
require.NoError(t, err)
assert.Equal(t, owner.ID, got.OwnerID)
assert.Equal(t, target.ID, got.TargetID)
assert.Equal(t, "брат", got.Name)
})
t.Run("not found", func(t *testing.T) {
_, err := db.Contacts.Get(context.Background(), owner.ID, "сестра")
assert.ErrorIs(t, err, database.ErrNotFound)
})
t.Run("database error", func(t *testing.T) {
db.Close()
_, err := db.Contacts.Get(context.Background(), owner.ID, "брат")
assert.Error(t, err)
assert.NotErrorIs(t, err, database.ErrNotFound)
})
}
func TestContacts_List(t *testing.T) {
db := setupTestDB(t)
owner, _ := db.Users.Create(context.Background())
target1, _ := db.Users.Create(context.Background())
target2, _ := db.Users.Create(context.Background())
db.Contacts.Add(context.Background(), &database.Contact{OwnerID: owner.ID, TargetID: target1.ID, Name: "брат"})
db.Contacts.Add(context.Background(), &database.Contact{OwnerID: owner.ID, TargetID: target2.ID, Name: "друг"})
t.Run("success", func(t *testing.T) {
contacts, err := db.Contacts.List(context.Background(), owner.ID)
require.NoError(t, err)
assert.Len(t, contacts, 2)
})
t.Run("empty", func(t *testing.T) {
other, _ := db.Users.Create(context.Background())
contacts, err := db.Contacts.List(context.Background(), other.ID)
require.NoError(t, err)
assert.Empty(t, contacts)
})
t.Run("database error", func(t *testing.T) {
db.Close()
_, err := db.Contacts.List(context.Background(), owner.ID)
assert.Error(t, err)
})
}
func TestContacts_Delete(t *testing.T) {
db := setupTestDB(t)
owner, _ := db.Users.Create(context.Background())
target, _ := db.Users.Create(context.Background())
db.Contacts.Add(context.Background(), &database.Contact{OwnerID: owner.ID, TargetID: target.ID, Name: "брат"})
t.Run("success", func(t *testing.T) {
err := db.Contacts.Delete(context.Background(), owner.ID, "брат")
require.NoError(t, err)
contacts, _ := db.Contacts.List(context.Background(), owner.ID)
assert.Empty(t, contacts)
})
t.Run("not found", func(t *testing.T) {
err := db.Contacts.Delete(context.Background(), owner.ID, "брат")
assert.ErrorIs(t, err, database.ErrNotFound)
})
t.Run("database error", func(t *testing.T) {
db.Close()
err := db.Contacts.Delete(context.Background(), owner.ID, "брат")
assert.Error(t, err)
assert.NotErrorIs(t, err, database.ErrNotFound)
})
}
func TestContacts_List_ScanError(t *testing.T) {
conn, mock, err := sqlmock.New()
require.NoError(t, err)
defer conn.Close()
db := &Contacts{conn: conn}
ownerID := uuid.New()
rows := sqlmock.NewRows([]string{"owner_id", "target_id", "name"}).
AddRow(uuid.New(), uuid.New(), nil) // NULL в name
mock.ExpectQuery(`SELECT owner_id, target_id, name FROM contacts WHERE owner_id = \$1`).
WithArgs(ownerID).
WillReturnRows(rows)
_, err = db.List(context.Background(), ownerID)
assert.Error(t, err)
}
func TestContacts_List_RowsErr(t *testing.T) {
conn, mock, err := sqlmock.New()
require.NoError(t, err)
defer conn.Close()
db := &Contacts{conn: conn}
ownerID := uuid.New()
rows := sqlmock.NewRows([]string{"owner_id", "target_id", "name"}).
AddRow(uuid.New(), uuid.New(), "брат").
RowError(0, errors.New("connection lost"))
mock.ExpectQuery(`SELECT owner_id, target_id, name FROM contacts WHERE owner_id = \$1`).
WithArgs(ownerID).
WillReturnRows(rows)
_, err = db.List(context.Background(), ownerID)
assert.Error(t, err)
}
+62
View File
@@ -0,0 +1,62 @@
package postgres
import (
"context"
"database/sql"
"github.com/d1nch8g/jules/database"
"github.com/google/uuid"
)
type Facts struct {
conn *sql.DB
}
func (f *Facts) Add(ctx context.Context, userID uuid.UUID, value string) error {
_, err := f.conn.ExecContext(ctx, `
INSERT INTO facts (user_id, value)
VALUES ($1, $2)
ON CONFLICT (user_id, value) DO NOTHING
`, userID, value)
return err
}
func (f *Facts) List(ctx context.Context, userID uuid.UUID) ([]database.Fact, error) {
rows, err := f.conn.QueryContext(ctx, `
SELECT user_id, value
FROM facts
WHERE user_id = $1
`, userID)
if err != nil {
return nil, err
}
defer rows.Close()
var facts []database.Fact
for rows.Next() {
var fact database.Fact
if err := rows.Scan(&fact.UserID, &fact.Value); err != nil {
return nil, err
}
facts = append(facts, fact)
}
if err := rows.Err(); err != nil {
return nil, err
}
return facts, nil
}
func (f *Facts) Delete(ctx context.Context, userID uuid.UUID, value string) error {
result, err := f.conn.ExecContext(ctx, `
DELETE FROM facts
WHERE user_id = $1 AND value = $2
`, userID, value)
if err != nil {
return err
}
rows, _ := result.RowsAffected()
if rows == 0 {
return database.ErrNotFound
}
return nil
}
+112
View File
@@ -0,0 +1,112 @@
package postgres
import (
"context"
"errors"
"testing"
"github.com/DATA-DOG/go-sqlmock"
"github.com/d1nch8g/jules/database"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestFacts_Add(t *testing.T) {
db := setupTestDB(t)
user, _ := db.Users.Create(context.Background())
err := db.Facts.Add(context.Background(), user.ID, "мама Ирина")
require.NoError(t, err)
err = db.Facts.Add(context.Background(), user.ID, "мама Ирина")
require.NoError(t, err)
}
func TestFacts_List(t *testing.T) {
db := setupTestDB(t)
user, _ := db.Users.Create(context.Background())
db.Facts.Add(context.Background(), user.ID, "мама Ирина")
db.Facts.Add(context.Background(), user.ID, "спит в 23:30")
facts, err := db.Facts.List(context.Background(), user.ID)
require.NoError(t, err)
assert.Len(t, facts, 2)
}
func TestFacts_Delete(t *testing.T) {
db := setupTestDB(t)
user, _ := db.Users.Create(context.Background())
db.Facts.Add(context.Background(), user.ID, "мама Ирина")
err := db.Facts.Delete(context.Background(), user.ID, "мама Ирина")
require.NoError(t, err)
facts, _ := db.Facts.List(context.Background(), user.ID)
assert.Empty(t, facts)
err = db.Facts.Delete(context.Background(), user.ID, "мама Ирина")
assert.ErrorIs(t, err, database.ErrNotFound)
}
func TestFacts_List_DatabaseError(t *testing.T) {
db := setupTestDB(t)
user, _ := db.Users.Create(context.Background())
db.Facts.Add(context.Background(), user.ID, "test")
db.Close()
_, err := db.Facts.List(context.Background(), user.ID)
assert.Error(t, err)
}
func TestFacts_Delete_DatabaseError(t *testing.T) {
db := setupTestDB(t)
user, _ := db.Users.Create(context.Background())
db.Facts.Add(context.Background(), user.ID, "test")
db.Close()
err := db.Facts.Delete(context.Background(), user.ID, "test")
assert.Error(t, err)
assert.NotErrorIs(t, err, database.ErrNotFound)
}
func TestFacts_List_RowsErr(t *testing.T) {
conn, mock, err := sqlmock.New()
require.NoError(t, err)
defer conn.Close()
db := &Facts{conn: conn}
userID := uuid.New()
rows := sqlmock.NewRows([]string{"user_id", "value"}).
AddRow(uuid.New(), "test fact").
RowError(0, errors.New("connection lost"))
mock.ExpectQuery(`SELECT user_id, value FROM facts WHERE user_id = \$1`).
WithArgs(userID).
WillReturnRows(rows)
_, err = db.List(context.Background(), userID)
assert.Error(t, err)
}
func TestFacts_List_ScanError(t *testing.T) {
conn, mock, err := sqlmock.New()
require.NoError(t, err)
defer conn.Close()
db := &Facts{conn: conn}
userID := uuid.New()
rows := sqlmock.NewRows([]string{"user_id", "value"}).
AddRow(uuid.New(), nil)
mock.ExpectQuery(`SELECT user_id, value FROM facts WHERE user_id = \$1`).
WithArgs(userID).
WillReturnRows(rows)
_, err = db.List(context.Background(), userID)
assert.Error(t, err)
}
+6 -7
View File
@@ -9,7 +9,6 @@ import (
_ "github.com/lib/pq" _ "github.com/lib/pq"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
func getTestConnString() string { func getTestConnString() string {
@@ -22,10 +21,10 @@ func getTestConnString() string {
func getTestConn(t *testing.T) *sql.DB { func getTestConn(t *testing.T) *sql.DB {
conn, err := sql.Open("postgres", getTestConnString()) conn, err := sql.Open("postgres", getTestConnString())
require.NoError(t, err) assert.NoError(t, err)
ctx := context.Background() ctx := context.Background()
require.NoError(t, conn.PingContext(ctx)) assert.NoError(t, conn.PingContext(ctx))
t.Cleanup(func() { t.Cleanup(func() {
cleanTables(t, conn) cleanTables(t, conn)
@@ -67,7 +66,7 @@ func TestRunMigrations_AlreadyApplied(t *testing.T) {
dropSchema(t, conn) dropSchema(t, conn)
err := runMigrations(conn) err := runMigrations(conn)
require.NoError(t, err) assert.NoError(t, err)
err = runMigrations(conn) err = runMigrations(conn)
assert.NoError(t, err) assert.NoError(t, err)
@@ -75,7 +74,7 @@ func TestRunMigrations_AlreadyApplied(t *testing.T) {
func TestRunMigrations_InvalidConn(t *testing.T) { func TestRunMigrations_InvalidConn(t *testing.T) {
conn, err := sql.Open("postgres", "postgres://invalid:5432/db") conn, err := sql.Open("postgres", "postgres://invalid:5432/db")
require.NoError(t, err) assert.NoError(t, err)
defer conn.Close() defer conn.Close()
err = runMigrations(conn) err = runMigrations(conn)
@@ -87,7 +86,7 @@ func TestRunMigrations_FailedCreateIOFSDriver(t *testing.T) {
migrationsFS = embed.FS{} migrationsFS = embed.FS{}
conn, err := sql.Open("postgres", "postgres://invalid:5432/db") conn, err := sql.Open("postgres", "postgres://invalid:5432/db")
require.NoError(t, err) assert.NoError(t, err)
defer conn.Close() defer conn.Close()
err = runMigrations(conn) err = runMigrations(conn)
@@ -102,7 +101,7 @@ func TestRunMigrations_FailedUp(t *testing.T) {
dropSchema(t, conn) dropSchema(t, conn)
_, err := conn.ExecContext(context.Background(), `CREATE TABLE users ()`) _, err := conn.ExecContext(context.Background(), `CREATE TABLE users ()`)
require.NoError(t, err) assert.NoError(t, err)
err = runMigrations(conn) err = runMigrations(conn)
assert.Error(t, err) assert.Error(t, err)
+92
View File
@@ -0,0 +1,92 @@
package postgres
import (
"context"
"database/sql"
"time"
"github.com/d1nch8g/jules/database"
"github.com/google/uuid"
)
type Notifications struct {
conn *sql.DB
}
func (n *Notifications) Push(ctx context.Context, notif *database.Notification) error {
_, err := n.conn.ExecContext(ctx, `
INSERT INTO notifications (id, user_id, scheduled_at, content)
VALUES ($1, $2, $3, $4)
`, notif.ID, notif.UserID, notif.ScheduledAt, notif.Content)
return err
}
func (n *Notifications) Pop(ctx context.Context, limit int) ([]database.Notification, error) {
now := time.Now().UTC().Truncate(time.Minute)
nextMinute := now.Add(time.Minute)
rows, err := n.conn.QueryContext(ctx, `
WITH batch AS (
SELECT id
FROM notifications
WHERE scheduled_at >= $1 AND scheduled_at < $2
ORDER BY scheduled_at
LIMIT $3
FOR UPDATE SKIP LOCKED
)
DELETE FROM notifications
WHERE id IN (SELECT id FROM batch)
RETURNING id, user_id, scheduled_at, content
`, now, nextMinute, limit)
if err != nil {
return nil, err
}
defer rows.Close()
var notifs []database.Notification
for rows.Next() {
var n database.Notification
if err := rows.Scan(&n.ID, &n.UserID, &n.ScheduledAt, &n.Content); err != nil {
return nil, err
}
notifs = append(notifs, n)
}
return notifs, rows.Err()
}
func (n *Notifications) List(ctx context.Context, userID uuid.UUID) ([]database.Notification, error) {
rows, err := n.conn.QueryContext(ctx, `
SELECT id, user_id, scheduled_at, content
FROM notifications
WHERE user_id = $1
ORDER BY scheduled_at
`, userID)
if err != nil {
return nil, err
}
defer rows.Close()
var notifs []database.Notification
for rows.Next() {
var notif database.Notification
if err := rows.Scan(&notif.ID, &notif.UserID, &notif.ScheduledAt, &notif.Content); err != nil {
return nil, err
}
notifs = append(notifs, notif)
}
return notifs, rows.Err()
}
func (n *Notifications) Delete(ctx context.Context, id uuid.UUID) error {
result, err := n.conn.ExecContext(ctx, `
DELETE FROM notifications WHERE id = $1
`, id)
if err != nil {
return err
}
rows, _ := result.RowsAffected()
if rows == 0 {
return database.ErrNotFound
}
return nil
}
+149
View File
@@ -0,0 +1,149 @@
package postgres
import (
"context"
"errors"
"sync"
"testing"
"time"
"github.com/DATA-DOG/go-sqlmock"
"github.com/d1nch8g/jules/database"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestNotifications_ConcurrentPop(t *testing.T) {
db := setupTestDB(t)
user, _ := db.Users.Create(context.Background())
ctx := context.Background()
now := time.Now().UTC().Truncate(time.Minute)
for range 1000 {
err := db.Notifications.Push(ctx, &database.Notification{
ID: uuid.New(),
UserID: user.ID,
ScheduledAt: now,
Content: "test",
})
require.NoError(t, err)
}
var wg sync.WaitGroup
popped := make(chan uuid.UUID, 1000)
for range 10 {
wg.Go(func() {
for {
notifs, err := db.Notifications.Pop(ctx, 10)
require.NoError(t, err)
if len(notifs) == 0 {
return
}
for _, n := range notifs {
popped <- n.ID
}
}
})
}
wg.Wait()
close(popped)
ids := make(map[uuid.UUID]bool)
for id := range popped {
ids[id] = true
}
assert.Len(t, ids, 1000)
remaining, err := db.Notifications.Pop(ctx, 10)
require.NoError(t, err)
assert.Empty(t, remaining)
}
func TestNotifications_Pop_QueryError(t *testing.T) {
conn, mock, _ := sqlmock.New()
defer conn.Close()
db := &Notifications{conn: conn}
mock.ExpectQuery(`WITH batch AS`).WillReturnError(errors.New("db down"))
_, err := db.Pop(context.Background(), 10)
assert.Error(t, err)
}
func TestNotifications_Pop_ScanError(t *testing.T) {
conn, mock, _ := sqlmock.New()
defer conn.Close()
db := &Notifications{conn: conn}
rows := sqlmock.NewRows([]string{"id", "user_id", "scheduled_at", "content"}).
AddRow(uuid.New(), uuid.New(), time.Now(), nil)
mock.ExpectQuery(`WITH batch AS`).WillReturnRows(rows)
_, err := db.Pop(context.Background(), 10)
assert.Error(t, err)
}
func TestNotifications_List(t *testing.T) {
db := setupTestDB(t)
user, _ := db.Users.Create(context.Background())
now := time.Now().UTC().Truncate(time.Minute)
n1 := &database.Notification{ID: uuid.New(), UserID: user.ID, ScheduledAt: now, Content: "first"}
n2 := &database.Notification{ID: uuid.New(), UserID: user.ID, ScheduledAt: now.Add(time.Hour), Content: "second"}
db.Notifications.Push(context.Background(), n1)
db.Notifications.Push(context.Background(), n2)
notifs, err := db.Notifications.List(context.Background(), user.ID)
require.NoError(t, err)
assert.Len(t, notifs, 2)
other, _ := db.Users.Create(context.Background())
notifs, err = db.Notifications.List(context.Background(), other.ID)
require.NoError(t, err)
assert.Empty(t, notifs)
}
func TestNotifications_Delete(t *testing.T) {
db := setupTestDB(t)
user, _ := db.Users.Create(context.Background())
n := &database.Notification{ID: uuid.New(), UserID: user.ID, ScheduledAt: time.Now().UTC(), Content: "test"}
db.Notifications.Push(context.Background(), n)
err := db.Notifications.Delete(context.Background(), n.ID)
require.NoError(t, err)
err = db.Notifications.Delete(context.Background(), n.ID)
assert.ErrorIs(t, err, database.ErrNotFound)
}
func TestNotifications_List_QueryError(t *testing.T) {
conn, mock, _ := sqlmock.New()
defer conn.Close()
db := &Notifications{conn: conn}
mock.ExpectQuery(`SELECT id, user_id, scheduled_at, content FROM notifications`).WillReturnError(errors.New("db down"))
_, err := db.List(context.Background(), uuid.New())
assert.Error(t, err)
}
func TestNotifications_List_ScanError(t *testing.T) {
conn, mock, _ := sqlmock.New()
defer conn.Close()
db := &Notifications{conn: conn}
rows := sqlmock.NewRows([]string{"id", "user_id", "scheduled_at", "content"}).
AddRow(uuid.New(), uuid.New(), time.Now(), nil)
mock.ExpectQuery(`SELECT id, user_id, scheduled_at, content FROM notifications`).WillReturnRows(rows)
_, err := db.List(context.Background(), uuid.New())
assert.Error(t, err)
}
func TestNotifications_Delete_Error(t *testing.T) {
conn, mock, _ := sqlmock.New()
defer conn.Close()
db := &Notifications{conn: conn}
mock.ExpectExec(`DELETE FROM notifications`).WillReturnError(errors.New("db down"))
err := db.Delete(context.Background(), uuid.New())
assert.Error(t, err)
assert.NotErrorIs(t, err, database.ErrNotFound)
}
+15 -1
View File
@@ -10,6 +10,12 @@ import (
) )
type DB struct { type DB struct {
Users Users
Chats Chats
Facts Facts
Contacts Contacts
Notifications Notifications
conn *sql.DB conn *sql.DB
} }
@@ -32,7 +38,15 @@ func New(connString string) (*DB, error) {
return nil, fmt.Errorf("run migrations: %w", err) return nil, fmt.Errorf("run migrations: %w", err)
} }
return &DB{conn: conn}, nil return &DB{
Users: Users{conn: conn},
Chats: Chats{conn: conn},
Facts: Facts{conn: conn},
Contacts: Contacts{conn: conn},
Notifications: Notifications{conn: conn},
conn: conn,
}, nil
} }
func (db *DB) Close() error { func (db *DB) Close() error {
+2 -3
View File
@@ -5,7 +5,6 @@ import (
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
func setupTestDB(t *testing.T) *DB { func setupTestDB(t *testing.T) *DB {
@@ -20,7 +19,7 @@ func TestNew_Success(t *testing.T) {
connString := getTestConnString() connString := getTestConnString()
db, err := New(connString) db, err := New(connString)
require.NoError(t, err) assert.NoError(t, err)
assert.NotNil(t, db) assert.NotNil(t, db)
defer db.Close() defer db.Close()
} }
@@ -36,7 +35,7 @@ func TestNew_RunMigrationsFailed(t *testing.T) {
dropSchema(t, conn) dropSchema(t, conn)
_, err := conn.ExecContext(context.Background(), `CREATE TABLE users ()`) _, err := conn.ExecContext(context.Background(), `CREATE TABLE users ()`)
require.NoError(t, err) assert.NoError(t, err)
_, err = New(getTestConnString()) _, err = New(getTestConnString())
assert.Error(t, err) assert.Error(t, err)
+12 -8
View File
@@ -9,9 +9,13 @@ import (
"github.com/google/uuid" "github.com/google/uuid"
) )
func (db *DB) Create(ctx context.Context) (*database.User, error) { type Users struct {
conn *sql.DB
}
func (u *Users) Create(ctx context.Context) (*database.User, error) {
var user database.User var user database.User
err := db.conn.QueryRowContext(ctx, ` err := u.conn.QueryRowContext(ctx, `
INSERT INTO users (preferred_chat, language, timezone) INSERT INTO users (preferred_chat, language, timezone)
VALUES ('telegram', 'en', 'UTC') VALUES ('telegram', 'en', 'UTC')
RETURNING id, preferred_chat, language, timezone RETURNING id, preferred_chat, language, timezone
@@ -22,9 +26,9 @@ func (db *DB) Create(ctx context.Context) (*database.User, error) {
return &user, nil return &user, nil
} }
func (db *DB) Get(ctx context.Context, id uuid.UUID) (*database.User, error) { func (u *Users) Get(ctx context.Context, id uuid.UUID) (*database.User, error) {
var user database.User var user database.User
err := db.conn.QueryRowContext(ctx, ` err := u.conn.QueryRowContext(ctx, `
SELECT id, preferred_chat, language, timezone SELECT id, preferred_chat, language, timezone
FROM users FROM users
WHERE id = $1 WHERE id = $1
@@ -38,8 +42,8 @@ func (db *DB) Get(ctx context.Context, id uuid.UUID) (*database.User, error) {
return &user, nil return &user, nil
} }
func (db *DB) Update(ctx context.Context, user *database.User) error { func (u *Users) Update(ctx context.Context, user *database.User) error {
result, err := db.conn.ExecContext(ctx, ` result, err := u.conn.ExecContext(ctx, `
UPDATE users UPDATE users
SET preferred_chat = $1, language = $2, timezone = $3 SET preferred_chat = $1, language = $2, timezone = $3
WHERE id = $4 WHERE id = $4
@@ -54,8 +58,8 @@ func (db *DB) Update(ctx context.Context, user *database.User) error {
return nil return nil
} }
func (db *DB) Delete(ctx context.Context, id uuid.UUID) error { func (u *Users) Delete(ctx context.Context, id uuid.UUID) error {
result, err := db.conn.ExecContext(ctx, `DELETE FROM users WHERE id = $1`, id) result, err := u.conn.ExecContext(ctx, `DELETE FROM users WHERE id = $1`, id)
if err != nil { if err != nil {
return err return err
} }
+28 -29
View File
@@ -7,14 +7,13 @@ import (
"github.com/d1nch8g/jules/database" "github.com/d1nch8g/jules/database"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
func TestUsers_Create(t *testing.T) { func TestUsers_Create(t *testing.T) {
db := setupTestDB(t) db := setupTestDB(t)
user, err := db.Create(context.Background()) user, err := db.Users.Create(context.Background())
require.NoError(t, err) assert.NoError(t, err)
assert.NotEqual(t, uuid.Nil, user.ID) assert.NotEqual(t, uuid.Nil, user.ID)
assert.Equal(t, "telegram", user.PreferredChat) assert.Equal(t, "telegram", user.PreferredChat)
assert.Equal(t, "en", user.Language) assert.Equal(t, "en", user.Language)
@@ -24,18 +23,18 @@ func TestUsers_Create(t *testing.T) {
func TestUsers_Get(t *testing.T) { func TestUsers_Get(t *testing.T) {
db := setupTestDB(t) db := setupTestDB(t)
user, err := db.Create(context.Background()) user, err := db.Users.Create(context.Background())
require.NoError(t, err) assert.NoError(t, err)
t.Run("found", func(t *testing.T) { t.Run("found", func(t *testing.T) {
fetched, err := db.Get(context.Background(), user.ID) fetched, err := db.Users.Get(context.Background(), user.ID)
require.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, user.ID, fetched.ID) assert.Equal(t, user.ID, fetched.ID)
assert.Equal(t, user.PreferredChat, fetched.PreferredChat) assert.Equal(t, user.PreferredChat, fetched.PreferredChat)
}) })
t.Run("not found", func(t *testing.T) { t.Run("not found", func(t *testing.T) {
_, err := db.Get(context.Background(), uuid.New()) _, err := db.Users.Get(context.Background(), uuid.New())
assert.ErrorIs(t, err, database.ErrNotFound) assert.ErrorIs(t, err, database.ErrNotFound)
}) })
} }
@@ -43,19 +42,19 @@ func TestUsers_Get(t *testing.T) {
func TestUsers_Update(t *testing.T) { func TestUsers_Update(t *testing.T) {
db := setupTestDB(t) db := setupTestDB(t)
user, err := db.Create(context.Background()) user, err := db.Users.Create(context.Background())
require.NoError(t, err) assert.NoError(t, err)
t.Run("success", func(t *testing.T) { t.Run("success", func(t *testing.T) {
user.PreferredChat = "whatsapp" user.PreferredChat = "whatsapp"
user.Language = "ru" user.Language = "ru"
user.Timezone = "Europe/Moscow" user.Timezone = "Europe/Moscow"
err := db.Update(context.Background(), user) err := db.Users.Update(context.Background(), user)
require.NoError(t, err) assert.NoError(t, err)
fetched, err := db.Get(context.Background(), user.ID) fetched, err := db.Users.Get(context.Background(), user.ID)
require.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, "whatsapp", fetched.PreferredChat) assert.Equal(t, "whatsapp", fetched.PreferredChat)
assert.Equal(t, "ru", fetched.Language) assert.Equal(t, "ru", fetched.Language)
assert.Equal(t, "Europe/Moscow", fetched.Timezone) assert.Equal(t, "Europe/Moscow", fetched.Timezone)
@@ -63,7 +62,7 @@ func TestUsers_Update(t *testing.T) {
t.Run("not found", func(t *testing.T) { t.Run("not found", func(t *testing.T) {
ghost := &database.User{ID: uuid.New()} ghost := &database.User{ID: uuid.New()}
err := db.Update(context.Background(), ghost) err := db.Users.Update(context.Background(), ghost)
assert.ErrorIs(t, err, database.ErrNotFound) assert.ErrorIs(t, err, database.ErrNotFound)
}) })
} }
@@ -72,18 +71,18 @@ func TestUsers_Delete(t *testing.T) {
db := setupTestDB(t) db := setupTestDB(t)
t.Run("success", func(t *testing.T) { t.Run("success", func(t *testing.T) {
user, err := db.Create(context.Background()) user, err := db.Users.Create(context.Background())
require.NoError(t, err) assert.NoError(t, err)
err = db.Delete(context.Background(), user.ID) err = db.Users.Delete(context.Background(), user.ID)
require.NoError(t, err) assert.NoError(t, err)
_, err = db.Get(context.Background(), user.ID) _, err = db.Users.Get(context.Background(), user.ID)
assert.ErrorIs(t, err, database.ErrNotFound) assert.ErrorIs(t, err, database.ErrNotFound)
}) })
t.Run("not found", func(t *testing.T) { t.Run("not found", func(t *testing.T) {
err := db.Delete(context.Background(), uuid.New()) err := db.Users.Delete(context.Background(), uuid.New())
assert.ErrorIs(t, err, database.ErrNotFound) assert.ErrorIs(t, err, database.ErrNotFound)
}) })
} }
@@ -93,7 +92,7 @@ func TestUsers_Create_DatabaseError(t *testing.T) {
db.Close() db.Close()
_, err := db.Create(context.Background()) _, err := db.Users.Create(context.Background())
assert.Error(t, err) assert.Error(t, err)
assert.NotErrorIs(t, err, database.ErrNotFound) assert.NotErrorIs(t, err, database.ErrNotFound)
} }
@@ -103,7 +102,7 @@ func TestUsers_Get_DatabaseError(t *testing.T) {
db.Close() db.Close()
_, err := db.Get(context.Background(), uuid.New()) _, err := db.Users.Get(context.Background(), uuid.New())
assert.Error(t, err) assert.Error(t, err)
assert.NotErrorIs(t, err, database.ErrNotFound) assert.NotErrorIs(t, err, database.ErrNotFound)
} }
@@ -111,12 +110,12 @@ func TestUsers_Get_DatabaseError(t *testing.T) {
func TestUsers_Update_DatabaseError(t *testing.T) { func TestUsers_Update_DatabaseError(t *testing.T) {
db := setupTestDB(t) db := setupTestDB(t)
user, err := db.Create(context.Background()) user, err := db.Users.Create(context.Background())
require.NoError(t, err) assert.NoError(t, err)
db.Close() db.Close()
err = db.Update(context.Background(), user) err = db.Users.Update(context.Background(), user)
assert.Error(t, err) assert.Error(t, err)
assert.NotErrorIs(t, err, database.ErrNotFound) assert.NotErrorIs(t, err, database.ErrNotFound)
} }
@@ -124,12 +123,12 @@ func TestUsers_Update_DatabaseError(t *testing.T) {
func TestUsers_Delete_DatabaseError(t *testing.T) { func TestUsers_Delete_DatabaseError(t *testing.T) {
db := setupTestDB(t) db := setupTestDB(t)
user, err := db.Create(context.Background()) user, err := db.Users.Create(context.Background())
require.NoError(t, err) assert.NoError(t, err)
db.Close() db.Close()
err = db.Delete(context.Background(), user.ID) err = db.Users.Delete(context.Background(), user.ID)
assert.Error(t, err) assert.Error(t, err)
assert.NotErrorIs(t, err, database.ErrNotFound) assert.NotErrorIs(t, err, database.ErrNotFound)
} }
+1 -2
View File
@@ -8,7 +8,6 @@ import (
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
func TestNew(t *testing.T) { func TestNew(t *testing.T) {
@@ -30,7 +29,7 @@ func TestProcess_Success(t *testing.T) {
client := New("key", server.URL) client := New("key", server.URL)
result, err := client.Process(context.Background(), "prompt") result, err := client.Process(context.Background(), "prompt")
require.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, "response", result) assert.Equal(t, "response", result)
} }
+4 -5
View File
@@ -11,7 +11,6 @@ import (
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
type roundTripFunc func(*http.Request) (*http.Response, error) type roundTripFunc func(*http.Request) (*http.Response, error)
@@ -50,7 +49,7 @@ func TestSearch_Success_WithSummarizer(t *testing.T) {
client := New("key", server.URL) client := New("key", server.URL)
result, err := client.Search(context.Background(), "query") result, err := client.Search(context.Background(), "query")
require.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, "summary text", result) assert.Equal(t, "summary text", result)
} }
@@ -66,7 +65,7 @@ func TestSearch_Success_NoSummarizer_FallbackToDescription(t *testing.T) {
client := New("key", server.URL) client := New("key", server.URL)
result, err := client.Search(context.Background(), "query") result, err := client.Search(context.Background(), "query")
require.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, "fallback description", result) assert.Equal(t, "fallback description", result)
} }
@@ -176,7 +175,7 @@ func TestSearch_DoError_WebSearch(t *testing.T) {
}) })
_, err := client.Search(context.Background(), "query") _, err := client.Search(context.Background(), "query")
require.Error(t, err) assert.Error(t, err)
assert.Contains(t, err.Error(), "web search network error") assert.Contains(t, err.Error(), "web search network error")
} }
@@ -204,6 +203,6 @@ func TestSearch_DoError_Summarizer(t *testing.T) {
}) })
_, err := client.Search(context.Background(), "query") _, err := client.Search(context.Background(), "query")
require.Error(t, err) assert.Error(t, err)
assert.Contains(t, err.Error(), "summarizer network error") assert.Contains(t, err.Error(), "summarizer network error")
} }