238 lines
4.6 KiB
Go
238 lines
4.6 KiB
Go
package engine
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"log/slog"
|
|
"sync"
|
|
"time"
|
|
|
|
"m8sh.su/d/jules/chat"
|
|
"m8sh.su/d/jules/database"
|
|
"m8sh.su/d/jules/llm"
|
|
"m8sh.su/d/jules/search"
|
|
)
|
|
|
|
// Parameters configures an Engine. It is passed to New, which validates the
// required dependencies and fills in defaults for some zero-valued fields.
type Parameters struct {
	// Number of parallel workers running on the machine.
	// New substitutes 100 when this is left at zero.
	NumWorkers int

	// The number of actions which will be loaded from
	// database into LLM context.
	ActionLimit int

	// The batch size of processed notifications,
	// popped out of the database.
	// The poller keeps popping batches on a tick until it gets an empty batch,
	// an error, or a batch shorter than this value.
	// NOTE(review): behavior for values <= 0 depends on the Pop
	// implementation — confirm.
	NotificationBatchSize int

	// The polling duration for database.
	// It is passed to time.NewTicker, which panics if the value is not
	// positive when the notification poller starts.
	DatabasePollingDuration time.Duration

	// Per-item processing timeout: every chat message and every notification
	// is handled under its own context with this deadline (which presumably
	// bounds the LLM response time as well).
	// New substitutes 90 seconds when this is left at zero.
	ProcessingTimeout time.Duration

	// How many chances an LLM is given to provide a valid response.
	// New substitutes 3 when this is left at zero.
	LLMRetryAttempts int

	// Database backend; required (New rejects nil). Run closes it on shutdown.
	Database database.Database
	// LLM client; required (New rejects nil).
	LLM llm.LLM
	// Search engine; required (New rejects nil).
	Searcher search.Searcher
	// Chat integrations keyed by platform name. New requires at least one
	// entry, a non-empty key for every entry and a non-nil value.
	Chats map[string]chat.Chat
}
|
|
|
|
// Engine wires the configured chats, database, LLM and searcher together.
// Construct it with New and drive it with Run.
type Engine struct {
	// Parameters is embedded by pointer: the Engine shares the struct passed
	// to New, including any defaults New wrote into it.
	*Parameters

	// processMessage handles a single chat message; it is called by a worker
	// with a per-item ProcessingTimeout context. New sets it to
	// e.defaultProcessMessage.
	// NOTE(review): kept as a field presumably so tests can stub it — confirm.
	processMessage func(ctx context.Context, msg chat.Message)

	// processNotification handles a single notification popped from the
	// database, with the same per-item timeout as processMessage. New sets it
	// to e.defaultProcessNotification.
	processNotification func(ctx context.Context, notif database.Notification)
}
|
|
|
|
func New(params *Parameters) (*Engine, error) {
|
|
if params == nil {
|
|
return nil, errors.New("params can't be nil")
|
|
}
|
|
|
|
if params.NumWorkers == 0 {
|
|
params.NumWorkers = 100
|
|
}
|
|
|
|
if params.ProcessingTimeout == 0 {
|
|
params.ProcessingTimeout = 90 * time.Second
|
|
}
|
|
|
|
if params.LLMRetryAttempts == 0 {
|
|
params.LLMRetryAttempts = 3
|
|
}
|
|
|
|
if params.Database == nil {
|
|
return nil, errors.New("database can't be nil")
|
|
}
|
|
|
|
if params.LLM == nil {
|
|
return nil, errors.New("llm can't be nil")
|
|
}
|
|
|
|
if params.Searcher == nil {
|
|
return nil, errors.New("seach engine can't be nil")
|
|
}
|
|
|
|
if len(params.Chats) == 0 {
|
|
return nil, errors.New("chats can't be empty")
|
|
}
|
|
|
|
for platform, chat := range params.Chats {
|
|
if platform == "" {
|
|
return nil, errors.New("platform name can't be empty")
|
|
}
|
|
if chat == nil {
|
|
return nil, fmt.Errorf("%s initialized as nil", platform)
|
|
}
|
|
}
|
|
|
|
engine := &Engine{
|
|
Parameters: params,
|
|
}
|
|
|
|
engine.processMessage = engine.defaultProcessMessage
|
|
engine.processNotification = engine.defaultProcessNotification
|
|
|
|
return engine, nil
|
|
}
|
|
|
|
func (e *Engine) Run(ctx context.Context) error {
|
|
messages := e.consumeChatMessages(ctx)
|
|
notifications := e.consumeNotifications(ctx)
|
|
|
|
slog.InfoContext(ctx, "starting jules...")
|
|
|
|
e.runWorkers(messages, notifications)
|
|
|
|
<-ctx.Done()
|
|
slog.InfoContext(ctx, "stopping jules...")
|
|
|
|
if err := e.Database.Close(); err != nil {
|
|
return fmt.Errorf("failed to close the database: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (e *Engine) runWorkers(messages <-chan chat.Message, notifications <-chan database.Notification) {
|
|
var wg sync.WaitGroup
|
|
|
|
for range e.NumWorkers {
|
|
wg.Go(func() {
|
|
for messages != nil && notifications != nil {
|
|
select {
|
|
case msg, ok := <-messages:
|
|
if !ok {
|
|
messages = nil
|
|
} else {
|
|
ctx, cancel := context.WithTimeout(context.Background(), e.ProcessingTimeout)
|
|
e.processMessage(ctx, msg)
|
|
cancel()
|
|
}
|
|
|
|
case notif, ok := <-notifications:
|
|
if !ok {
|
|
notifications = nil
|
|
} else {
|
|
ctx, cancel := context.WithTimeout(context.Background(), e.ProcessingTimeout)
|
|
e.processNotification(ctx, notif)
|
|
cancel()
|
|
}
|
|
}
|
|
}
|
|
})
|
|
}
|
|
|
|
wg.Wait()
|
|
}
|
|
|
|
func (e *Engine) consumeChatMessages(ctx context.Context) <-chan chat.Message {
|
|
var channels []<-chan chat.Message
|
|
for _, chat := range e.Chats {
|
|
channels = append(channels, chat.Receive(ctx))
|
|
}
|
|
|
|
out := make(chan chat.Message)
|
|
|
|
var wg sync.WaitGroup
|
|
wg.Add(len(channels))
|
|
|
|
for _, ch := range channels {
|
|
go func(c <-chan chat.Message) {
|
|
defer wg.Done()
|
|
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return
|
|
case msg, ok := <-c:
|
|
if !ok {
|
|
return
|
|
}
|
|
select {
|
|
case out <- msg:
|
|
case <-ctx.Done():
|
|
return
|
|
}
|
|
}
|
|
}
|
|
}(ch)
|
|
}
|
|
|
|
go func() {
|
|
wg.Wait()
|
|
close(out)
|
|
}()
|
|
|
|
return out
|
|
}
|
|
|
|
func (e *Engine) consumeNotifications(ctx context.Context) <-chan database.Notification { //nolint:gocognit
|
|
out := make(chan database.Notification)
|
|
|
|
go func() {
|
|
ticker := time.NewTicker(e.DatabasePollingDuration)
|
|
defer close(out)
|
|
defer ticker.Stop()
|
|
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
slog.InfoContext(ctx, "notification poller stopped")
|
|
return
|
|
|
|
case <-ticker.C:
|
|
for {
|
|
notifs, err := e.Database.Notifications().Pop(ctx, e.NotificationBatchSize)
|
|
if err != nil {
|
|
slog.ErrorContext(ctx, "failed to pop notifications", "error", err)
|
|
break
|
|
}
|
|
|
|
if len(notifs) == 0 {
|
|
break
|
|
}
|
|
|
|
for _, n := range notifs {
|
|
select {
|
|
case <-ctx.Done():
|
|
return
|
|
case out <- n:
|
|
}
|
|
}
|
|
|
|
if len(notifs) < e.NotificationBatchSize {
|
|
break
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}()
|
|
|
|
return out
|
|
}
|