Files
jules/engine/engine.go
T
2026-06-06 18:52:20 +03:00

238 lines
4.6 KiB
Go

package engine
import (
"context"
"errors"
"fmt"
"log/slog"
"sync"
"time"
"m8sh.su/d/jules/chat"
"m8sh.su/d/jules/database"
"m8sh.su/d/jules/llm"
"m8sh.su/d/jules/search"
)
// Parameters configures an Engine. It is passed to New, which checks the
// required dependencies (Database, LLM, Searcher, Chats) and replaces some
// zero-valued tuning fields with defaults.
type Parameters struct {
	// NumWorkers is the number of worker goroutines that process chat
	// messages and notifications concurrently. New replaces 0 with 100.
	NumWorkers int

	// ActionLimit is the number of actions loaded from the database into the
	// LLM context.
	ActionLimit int

	// NotificationBatchSize is the batch size passed to
	// Database.Notifications().Pop on each poll. When a poll returns at least
	// this many notifications, the poller pops again before waiting for the
	// next tick.
	NotificationBatchSize int

	// DatabasePollingDuration is the interval between notification polls. It
	// is passed to time.NewTicker, which panics on a non-positive value.
	DatabasePollingDuration time.Duration

	// ProcessingTimeout bounds the handling of a single message or
	// notification. New replaces 0 with 90 seconds.
	ProcessingTimeout time.Duration

	// LLMRetryAttempts is how many attempts the LLM gets to produce a valid
	// response. New replaces 0 with 3.
	LLMRetryAttempts int

	// Database is required (New rejects nil). Run closes it on shutdown.
	Database database.Database

	// LLM is required (New rejects nil).
	LLM llm.LLM

	// Searcher is required (New rejects nil).
	Searcher search.Searcher

	// Chats maps a platform name to its chat client. New requires at least
	// one entry, every key non-empty and every value non-nil.
	Chats map[string]chat.Chat
}
// Engine receives messages from the configured chats and notifications from
// the database, and dispatches them to a pool of workers. Construct it with
// New; Run drives it until its context is cancelled.
type Engine struct {
	*Parameters

	// processMessage handles a single chat message; workers call it with a
	// context bounded by ProcessingTimeout. New sets it to
	// defaultProcessMessage (kept as a field presumably so tests can stub
	// it — confirm).
	processMessage func(ctx context.Context, msg chat.Message)

	// processNotification handles a single database notification; workers
	// call it with a context bounded by ProcessingTimeout. New sets it to
	// defaultProcessNotification.
	processNotification func(ctx context.Context, notif database.Notification)
}
// New validates params, applies defaults to unset tuning fields and returns
// an Engine wired to the default message and notification processors.
//
// Database, LLM and Searcher must be non-nil, and Chats must contain at least
// one entry with a non-empty platform name and a non-nil chat; otherwise an
// error is returned and params is left untouched.
//
// On success params is modified in place: non-positive NumWorkers,
// ProcessingTimeout, LLMRetryAttempts, DatabasePollingDuration and
// NotificationBatchSize are replaced with defaults.
func New(params *Parameters) (*Engine, error) {
	if params == nil {
		return nil, errors.New("params can't be nil")
	}
	if params.Database == nil {
		return nil, errors.New("database can't be nil")
	}
	if params.LLM == nil {
		return nil, errors.New("llm can't be nil")
	}
	if params.Searcher == nil {
		return nil, errors.New("search engine can't be nil")
	}
	if len(params.Chats) == 0 {
		return nil, errors.New("chats can't be empty")
	}
	for platform, c := range params.Chats {
		if platform == "" {
			return nil, errors.New("platform name can't be empty")
		}
		if c == nil {
			return nil, fmt.Errorf("%s initialized as nil", platform)
		}
	}

	// Negative values are treated like zero: a negative worker count would
	// start no workers at all, and a negative timeout would hand every
	// processing call an already-expired context.
	if params.NumWorkers <= 0 {
		params.NumWorkers = 100
	}
	if params.ProcessingTimeout <= 0 {
		params.ProcessingTimeout = 90 * time.Second
	}
	if params.LLMRetryAttempts <= 0 {
		params.LLMRetryAttempts = 3
	}
	// The notification poller feeds this to time.NewTicker, which panics on a
	// non-positive interval.
	if params.DatabasePollingDuration <= 0 {
		params.DatabasePollingDuration = 5 * time.Second
	}
	// The poller keeps popping while it receives full batches, so any
	// positive batch size still drains the queue.
	if params.NotificationBatchSize <= 0 {
		params.NotificationBatchSize = 100
	}

	e := &Engine{Parameters: params}
	e.processMessage = e.defaultProcessMessage
	e.processNotification = e.defaultProcessNotification
	return e, nil
}
// Run connects the chat and notification sources to the worker pool and
// blocks until every worker has returned and ctx is done. It then closes the
// database, wrapping any error from Close.
func (e *Engine) Run(ctx context.Context) error {
	msgCh, notifCh := e.consumeChatMessages(ctx), e.consumeNotifications(ctx)

	slog.InfoContext(ctx, "starting jules...")
	e.runWorkers(msgCh, notifCh) // returns only after all workers exit

	<-ctx.Done()
	slog.InfoContext(ctx, "stopping jules...")

	closeErr := e.Database.Close()
	if closeErr == nil {
		return nil
	}
	return fmt.Errorf("failed to close the database: %w", closeErr)
}
// runWorkers starts NumWorkers goroutines that consume chat messages and
// database notifications, and blocks until every worker has exited.
//
// A worker keeps running until BOTH input channels are closed, so closing one
// source does not stop processing of the other. Each item is handled with its
// own ProcessingTimeout context derived from context.Background. Cancelling
// Run's context therefore does not abort in-flight processing (presumably
// intentional, so work can finish during shutdown — confirm).
func (e *Engine) runWorkers(messages <-chan chat.Message, notifications <-chan database.Notification) {
	var wg sync.WaitGroup
	for range e.NumWorkers {
		wg.Go(func() {
			// Per-worker copies: a worker disables a closed channel by setting
			// its own copy to nil (a nil channel is never ready in select).
			// Assigning to the shared captured parameters instead would be a
			// data race between workers.
			msgs, notifs := messages, notifications
			for msgs != nil || notifs != nil {
				select {
				case msg, ok := <-msgs:
					if !ok {
						msgs = nil
						continue
					}
					ctx, cancel := context.WithTimeout(context.Background(), e.ProcessingTimeout)
					e.processMessage(ctx, msg)
					cancel()
				case notif, ok := <-notifs:
					if !ok {
						notifs = nil
						continue
					}
					ctx, cancel := context.WithTimeout(context.Background(), e.ProcessingTimeout)
					e.processNotification(ctx, notif)
					cancel()
				}
			}
		})
	}
	wg.Wait()
}
// consumeChatMessages calls Receive(ctx) on every configured chat and fans the
// resulting streams into a single unbuffered channel.
//
// Each forwarding goroutine stops when its source channel is closed or ctx is
// cancelled. The returned channel is closed once all of them have stopped.
func (e *Engine) consumeChatMessages(ctx context.Context) <-chan chat.Message {
	sources := make([]<-chan chat.Message, 0, len(e.Chats))
	// Named c rather than chat so the imported chat package stays visible.
	for _, c := range e.Chats {
		sources = append(sources, c.Receive(ctx))
	}

	out := make(chan chat.Message)
	var wg sync.WaitGroup
	for _, src := range sources {
		wg.Go(func() {
			for {
				select {
				case <-ctx.Done():
					return
				case msg, ok := <-src:
					if !ok {
						return
					}
					// Guard the send as well, so a stalled consumer cannot pin
					// this goroutine after cancellation.
					select {
					case out <- msg:
					case <-ctx.Done():
						return
					}
				}
			}
		})
	}

	go func() {
		wg.Wait()
		close(out)
	}()
	return out
}
// consumeNotifications starts a background poller that, on every
// DatabasePollingDuration tick, pops notifications from the database and
// forwards them on the returned unbuffered channel.
//
// The channel is closed when the poller exits, which happens once ctx is
// cancelled. DatabasePollingDuration must be positive: time.NewTicker panics
// otherwise.
func (e *Engine) consumeNotifications(ctx context.Context) <-chan database.Notification {
	out := make(chan database.Notification)
	go func() {
		defer close(out)
		ticker := time.NewTicker(e.DatabasePollingDuration)
		defer ticker.Stop()
		for {
			select {
			case <-ctx.Done():
				slog.InfoContext(ctx, "notification poller stopped")
				return
			case <-ticker.C:
				if !e.drainNotifications(ctx, out) {
					slog.InfoContext(ctx, "notification poller stopped")
					return
				}
			}
		}
	}()
	return out
}

// drainNotifications pops notifications in batches of NotificationBatchSize
// and sends each one on out. It keeps popping while full batches come back,
// and stops on an empty or short batch or on a Pop error. It returns false
// when ctx has been cancelled, signalling the poller to exit.
//
// NOTE(review): if ctx is cancelled mid-batch, the rest of an already-popped
// batch is not delivered. If Pop removes rows from the database, those
// notifications are lost — confirm against the database implementation.
func (e *Engine) drainNotifications(ctx context.Context, out chan<- database.Notification) bool {
	for {
		notifs, err := e.Database.Notifications().Pop(ctx, e.NotificationBatchSize)
		if err != nil {
			// Once ctx is cancelled the poller is exiting anyway; a Pop
			// failure at that point is expected, not worth an error log.
			if ctx.Err() != nil {
				return false
			}
			slog.ErrorContext(ctx, "failed to pop notifications", "error", err)
			return true
		}
		for _, n := range notifs {
			select {
			case <-ctx.Done():
				return false
			case out <- n:
			}
		}
		// A short (or empty) batch means the queue is drained for this tick.
		if len(notifs) == 0 || len(notifs) < e.NotificationBatchSize {
			return true
		}
	}
}