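// Package actions (jules/engine/actions/actions.go) decodes the JSON actions
// listed in ActionsPromptPart from a model response and validates and
// executes them against a Runtime.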
package actions
import (
"context"
"encoding/json"
"errors"
"fmt"
"strings"
"time"
"github.com/d1nch8g/jules/chat"
"github.com/d1nch8g/jules/database"
"github.com/d1nch8g/jules/engine/jtime"
"github.com/d1nch8g/jules/engine/user"
"github.com/d1nch8g/jules/search"
"github.com/google/uuid"
"github.com/tidwall/gjson"
)
// ActionsPromptPart is the prompt fragment listing, one JSON object per line,
// the action shapes the model may emit. Every "type" listed here has a
// matching case in Parse.
//
// NOTE(review): the "wait" range advertised here (100-600 ms) is narrower than
// what Wait.Validate accepts (1-60000 ms); confirm which one is intended.
const ActionsPromptPart = `
=== AVAILABLE ACTIONS ===
{"type": "message", "platform": "telegram", "text": "short response"}
{"type": "wait", "ms": 100-600}
{"type": "add_fact", "value": "fact about user"}
{"type": "remove_fact", "value": "fact to remove"}
{"type": "add_notification", "target": "self", "time": "...", "content": "... (repetition rules)"}
{"type": "add_notification", "target": "contact name", "time": "...", "content": "... (repetition rules)"}
{"type": "remove_notification", "uuid": "123e4567-e89b-12d3-a456-426614174000"}
{"type": "add_contact", "uuid": "123e4567-e89b-12d3-a456-426614174000", "name": "Contact Name"}
{"type": "bind_chat", "uuid": "123e4567-e89b-12d3-a456-426614174000"}
{"type": "update_lang", "lang": "ru"}
{"type": "update_tz", "tz": "Europe/Moscow"}
{"type": "set_chat", "chat": "telegram"}
`

// The search action is temporarily disabled: it is omitted from
// ActionsPromptPart above and Parse has no case for it. Its shape was:
// {"type": "search", "query": "search query"}
// Action is a single operation decoded from a model response by Parse.
// Implementations may assume Validate has succeeded before Execute is called.
type Action interface {
	// Validate reports whether the action is well-formed and applicable to
	// the runtime's current user.
	Validate(context.Context, *Runtime) error
	// Execute performs the action's side effects.
	Execute(context.Context, *Runtime) error
}
// Runtime bundles the dependencies that actions read and mutate while they
// are validated and executed.
type Runtime struct {
	// User is the user on whose behalf actions run. Several actions modify
	// fields on it in place before persisting it via Database.
	User *user.User
	// Database is the persistence layer used by the actions.
	Database database.Database
	// Searcher is not referenced by any action in this file; presumably it
	// is kept for the temporarily disabled search action.
	Searcher search.Searcher
	// Chats maps a platform name (matched against the Platform field of the
	// user's chats) to the client used to send messages on that platform.
	Chats map[string]chat.Chat
}
type BindChat struct {
Type string `json:"type"`
UUID string `json:"uuid"`
}
func (a BindChat) Validate(ctx context.Context, rt *Runtime) error {
if a.UUID == "" {
return errors.New("target_uuid is required")
}
if _, err := uuid.Parse(a.UUID); err != nil {
return fmt.Errorf("target_uuid must be valid UUID: %w", err)
}
_, err := rt.Database.Users().Get(ctx, uuid.MustParse(a.UUID), database.UserLookupByBindCode)
if err != nil {
if errors.Is(err, database.ErrNotFound) {
return fmt.Errorf("user with bind code %s not found", a.UUID)
}
return err
}
return nil
}
func (a BindChat) Execute(ctx context.Context, rt *Runtime) error {
targetUser, err := rt.Database.Users().Get(ctx, uuid.MustParse(a.UUID), database.UserLookupByBindCode)
if err != nil {
return err
}
u := rt.User
if err := rt.Database.Users().Delete(ctx, u.ID); err != nil {
return err
}
for _, chat := range u.Chats {
if err := rt.Database.Chats().Attach(ctx, targetUser.ID, chat.Platform, chat.Identifier); err != nil {
return err
}
}
for _, contact := range u.Contacts {
if err := rt.Database.Contacts().Add(ctx, &database.Contact{
OwnerID: targetUser.ID,
TargetID: contact.TargetID,
Name: contact.Name,
}); err != nil {
return err
}
}
for _, fact := range u.Facts {
if err := rt.Database.Facts().Add(ctx, targetUser.ID, fact.Value); err != nil {
return err
}
}
for _, notif := range u.IncomingNotifications {
if notif.InitiatorID == u.ID {
notif.InitiatorID = targetUser.ID
}
if err := rt.Database.Notifications().Push(ctx, &notif); err != nil {
return err
}
}
for _, notif := range u.OutgoingNotifications {
if err := rt.Database.Notifications().Push(ctx, &notif); err != nil {
return err
}
}
return nil
}
// Message sends Text to the current user through their chat on Platform.
type Message struct {
	Type     string `json:"type"`
	Platform string `json:"platform"`
	Text     string `json:"text"`
}

// chatIdentifier returns the user's identifier on a.Platform and whether the
// user has a chat connected on that platform.
func (a Message) chatIdentifier(rt *Runtime) (string, bool) {
	for _, c := range rt.User.Chats {
		if c.Platform == a.Platform {
			return c.Identifier, true
		}
	}
	return "", false
}

// Validate requires non-empty text, a user chat on Platform, and a chat
// client registered for Platform in rt.Chats.
func (a Message) Validate(ctx context.Context, rt *Runtime) error {
	if a.Text == "" {
		return errors.New("text is empty")
	}
	if _, ok := a.chatIdentifier(rt); !ok {
		return fmt.Errorf("platform %s not connected", a.Platform)
	}
	// Without this check a connected platform that has no client would make
	// Execute call Send on a missing map entry.
	if _, ok := rt.Chats[a.Platform]; !ok {
		return fmt.Errorf("no chat client for platform %s", a.Platform)
	}
	return nil
}

// Execute sends Text to the user's identifier on Platform. It re-checks the
// preconditions so it returns an error, rather than panicking or sending to
// an empty identifier, when called without a successful Validate.
func (a Message) Execute(ctx context.Context, rt *Runtime) error {
	platformID, ok := a.chatIdentifier(rt)
	if !ok {
		return fmt.Errorf("platform %s not connected", a.Platform)
	}
	client, ok := rt.Chats[a.Platform]
	if !ok {
		return fmt.Errorf("no chat client for platform %s", a.Platform)
	}
	return client.Send(ctx, platformID, a.Text)
}
// Wait pauses for Ms milliseconds.
type Wait struct {
	Type string `json:"type"`
	Ms   int    `json:"ms"`
}

// Validate accepts durations from 1 to 60000 milliseconds inclusive.
func (a Wait) Validate(ctx context.Context, rt *Runtime) error {
	if a.Ms <= 0 {
		return errors.New("ms must be positive")
	}
	if a.Ms > 60000 {
		return errors.New("ms cannot exceed 60000")
	}
	return nil
}

// Execute blocks for Ms milliseconds. It returns ctx.Err() early if ctx is
// cancelled first, so cancellation or shutdown is not held up by the pause.
func (a Wait) Execute(ctx context.Context, rt *Runtime) error {
	timer := time.NewTimer(time.Duration(a.Ms) * time.Millisecond)
	defer timer.Stop()
	select {
	case <-timer.C:
		return nil
	case <-ctx.Done():
		return ctx.Err()
	}
}
// UpdateLang changes the current user's language and persists the user.
type UpdateLang struct {
	Type string `json:"type"`
	Lang string `json:"lang"`
}

// Validate only rejects an empty language; the value is not checked further.
func (a UpdateLang) Validate(ctx context.Context, rt *Runtime) error {
	if a.Lang != "" {
		return nil
	}
	return errors.New("lang is required")
}

// Execute updates the in-memory user and then saves the user record.
func (a UpdateLang) Execute(ctx context.Context, rt *Runtime) error {
	current := rt.User
	current.Language = a.Lang
	return rt.Database.Users().Update(ctx, current.User)
}
// UpdateTZ changes the current user's time zone and persists the user.
type UpdateTZ struct {
	Type string `json:"type"`
	TZ   string `json:"tz"`
}

// Validate requires TZ to be a zone name that time.LoadLocation can resolve,
// such as "Europe/Moscow" or "UTC".
func (a UpdateTZ) Validate(ctx context.Context, rt *Runtime) error {
	if a.TZ == "" {
		return errors.New("tz is required")
	}
	// Persisting an unresolvable zone would presumably break later
	// conversions that read rt.User.Timezone (e.g. jtime.ToUTC used by
	// AddNotification), so reject it up front.
	if _, err := time.LoadLocation(a.TZ); err != nil {
		return fmt.Errorf("invalid tz %q: %w", a.TZ, err)
	}
	return nil
}

// Execute updates the in-memory user's time zone and saves the user record.
func (a UpdateTZ) Execute(ctx context.Context, rt *Runtime) error {
	rt.User.Timezone = a.TZ
	return rt.Database.Users().Update(ctx, rt.User.User)
}
// SetChat makes one of the user's connected chat platforms their preferred
// chat and persists the user.
type SetChat struct {
	Type string `json:"type"`
	Chat string `json:"chat"`
}

// Validate requires Chat to name a platform the user has a chat on.
func (a SetChat) Validate(ctx context.Context, rt *Runtime) error {
	if a.Chat == "" {
		return errors.New("chat is required")
	}
	connected := false
	for _, existing := range rt.User.Chats {
		if existing.Platform == a.Chat {
			connected = true
			break
		}
	}
	if !connected {
		return fmt.Errorf("chat %s not connected", a.Chat)
	}
	return nil
}

// Execute records the preferred chat on the in-memory user and saves it.
func (a SetChat) Execute(ctx context.Context, rt *Runtime) error {
	current := rt.User
	current.PreferredChat = a.Chat
	return rt.Database.Users().Update(ctx, current.User)
}
// AddFact stores a free-form fact about the current user.
type AddFact struct {
	Type  string `json:"type"`
	Value string `json:"value"`
}

// Validate rejects an empty fact.
func (a AddFact) Validate(ctx context.Context, rt *Runtime) error {
	if len(a.Value) != 0 {
		return nil
	}
	return errors.New("value is required")
}

// Execute persists the fact for the current user.
func (a AddFact) Execute(ctx context.Context, rt *Runtime) error {
	facts := rt.Database.Facts()
	return facts.Add(ctx, rt.User.ID, a.Value)
}
// RemoveFact deletes a stored fact about the current user, identified by
// Value.
type RemoveFact struct {
	Type  string `json:"type"`
	Value string `json:"value"`
}

// Validate rejects an empty value, consistent with AddFact; an empty value
// cannot identify a fact worth removing.
func (a RemoveFact) Validate(ctx context.Context, rt *Runtime) error {
	if a.Value == "" {
		return errors.New("value is required")
	}
	return nil
}

// Execute deletes the fact. A fact that is already absent is treated as
// success, so repeating the action is harmless.
func (a RemoveFact) Execute(ctx context.Context, rt *Runtime) error {
	err := rt.Database.Facts().Delete(ctx, rt.User.ID, a.Value)
	if err != nil && !errors.Is(err, database.ErrNotFound) {
		return err
	}
	return nil
}
type AddContact struct {
Type string `json:"type"`
UUID string `json:"uuid"`
Name string `json:"name"`
}
func (a AddContact) Validate(ctx context.Context, rt *Runtime) error {
if a.UUID == "" {
return errors.New("uuid is required")
}
if _, err := uuid.Parse(a.UUID); err != nil {
return fmt.Errorf("uuid must be valid UUID: %w", err)
}
if a.Name == "" {
return errors.New("name is required")
}
_, err := rt.Database.Users().Get(ctx, uuid.MustParse(a.UUID), database.UserLookupByContactCode)
if err != nil {
if errors.Is(err, database.ErrNotFound) {
return fmt.Errorf("user with contact code %s not found", a.UUID)
}
return err
}
return nil
}
func (a AddContact) Execute(ctx context.Context, rt *Runtime) error {
contactUser, err := rt.Database.Users().Get(ctx, uuid.MustParse(a.UUID), database.UserLookupByContactCode)
if err != nil {
return err
}
return rt.Database.Contacts().Add(ctx, &database.Contact{
OwnerID: rt.User.ID,
TargetID: contactUser.ID,
Name: a.Name,
})
}
// AddNotification schedules a notification with Content at Time (converted
// from the current user's timezone by jtime.ToUTC). Target is either "self"
// (the current user) or the name of one of the current user's contacts.
type AddNotification struct {
	Type    string `json:"type"`
	Target  string `json:"target"`
	Time    string `json:"time"`
	Content string `json:"content"`
}

// Validate requires all fields, a time that jtime.ToUTC accepts for the
// user's timezone, and a target that is "self" or an existing contact name.
func (a AddNotification) Validate(ctx context.Context, rt *Runtime) error {
	if a.Target == "" {
		return errors.New("target is required")
	}
	if a.Time == "" {
		return errors.New("time is required")
	}
	if a.Content == "" {
		return errors.New("content is required")
	}
	if _, err := jtime.ToUTC(a.Time, rt.User.Timezone); err != nil {
		return err
	}
	if a.Target == "self" {
		return nil
	}
	for _, c := range rt.User.Contacts {
		if c.Name == a.Target {
			return nil
		}
	}
	return fmt.Errorf("contact %s not found", a.Target)
}

// Execute stores the notification with UserID set to the recipient (the
// current user for "self", otherwise the named contact) and InitiatorID set
// to the current user, who scheduled it.
func (a AddNotification) Execute(ctx context.Context, rt *Runtime) error {
	scheduledAt, err := jtime.ToUTC(a.Time, rt.User.Timezone)
	if err != nil {
		return err
	}
	targetID := rt.User.ID
	if a.Target != "self" {
		found := false
		for _, c := range rt.User.Contacts {
			if c.Name == a.Target {
				targetID = c.TargetID
				found = true
				break
			}
		}
		// Without this, an unknown contact would silently become a
		// notification to the current user.
		if !found {
			return fmt.Errorf("contact %s not found", a.Target)
		}
	}
	return rt.Database.Notifications().Push(ctx, &database.Notification{
		ID:          uuid.New(),
		UserID:      targetID,
		InitiatorID: rt.User.ID,
		ScheduledAt: scheduledAt,
		Content:     a.Content,
	})
}
// RemoveNotification deletes one of the current user's notifications by ID.
type RemoveNotification struct {
	Type string `json:"type"`
	UUID string `json:"uuid"`
}

// resolve parses UUID and checks that it names a notification in the current
// user's incoming or outgoing lists. Action payloads come from model output,
// so without this check any notification whose ID was known could be deleted.
//
// NOTE(review): assumes rt.User's notification lists contain every
// notification the user may remove; confirm how user.User is loaded.
func (a RemoveNotification) resolve(rt *Runtime) (uuid.UUID, error) {
	if a.UUID == "" {
		return uuid.Nil, errors.New("uuid is required")
	}
	id, err := uuid.Parse(a.UUID)
	if err != nil {
		return uuid.Nil, fmt.Errorf("uuid must be valid UUID: %w", err)
	}
	for _, n := range rt.User.IncomingNotifications {
		if n.ID == id {
			return id, nil
		}
	}
	for _, n := range rt.User.OutgoingNotifications {
		if n.ID == id {
			return id, nil
		}
	}
	return uuid.Nil, fmt.Errorf("notification %s not found", a.UUID)
}

// Validate checks that UUID is well formed and refers to one of the current
// user's notifications.
func (a RemoveNotification) Validate(ctx context.Context, rt *Runtime) error {
	_, err := a.resolve(rt)
	return err
}

// Execute deletes the notification. If the database reports it as already
// gone, that is treated as success.
func (a RemoveNotification) Execute(ctx context.Context, rt *Runtime) error {
	id, err := a.resolve(rt)
	if err != nil {
		return err
	}
	err = rt.Database.Notifications().Delete(ctx, id)
	if errors.Is(err, database.ErrNotFound) {
		return nil
	}
	return err
}
// actionDecoders maps every supported "type" value to a function that decodes
// one raw JSON object into the matching concrete Action value.
var actionDecoders = map[string]func(data []byte) (Action, error){
	"bind_chat": func(data []byte) (Action, error) {
		var a BindChat
		err := json.Unmarshal(data, &a)
		return a, err
	},
	"message": func(data []byte) (Action, error) {
		var a Message
		err := json.Unmarshal(data, &a)
		return a, err
	},
	"wait": func(data []byte) (Action, error) {
		var a Wait
		err := json.Unmarshal(data, &a)
		return a, err
	},
	"update_lang": func(data []byte) (Action, error) {
		var a UpdateLang
		err := json.Unmarshal(data, &a)
		return a, err
	},
	"update_tz": func(data []byte) (Action, error) {
		var a UpdateTZ
		err := json.Unmarshal(data, &a)
		return a, err
	},
	"set_chat": func(data []byte) (Action, error) {
		var a SetChat
		err := json.Unmarshal(data, &a)
		return a, err
	},
	"add_fact": func(data []byte) (Action, error) {
		var a AddFact
		err := json.Unmarshal(data, &a)
		return a, err
	},
	"remove_fact": func(data []byte) (Action, error) {
		var a RemoveFact
		err := json.Unmarshal(data, &a)
		return a, err
	},
	"add_contact": func(data []byte) (Action, error) {
		var a AddContact
		err := json.Unmarshal(data, &a)
		return a, err
	},
	"add_notification": func(data []byte) (Action, error) {
		var a AddNotification
		err := json.Unmarshal(data, &a)
		return a, err
	},
	"remove_notification": func(data []byte) (Action, error) {
		var a RemoveNotification
		err := json.Unmarshal(data, &a)
		return a, err
	},
}

// Parse decodes a model response into a list of Actions.
//
// The response may carry surrounding text: if it contains a '[' followed
// later by a ']', only the span from the first '[' to the last ']' is
// parsed. That span must be a JSON array whose elements each have a
// non-empty, known "type" field. Any problem aborts parsing with an error
// naming the offending element's index. An empty array yields a nil slice.
//
// userTimezone is currently unused.
func Parse(raw string, userTimezone string) ([]Action, error) {
	if lo, hi := strings.Index(raw, "["), strings.LastIndex(raw, "]"); lo >= 0 && hi > lo {
		raw = raw[lo : hi+1]
	}
	raw = strings.TrimSpace(raw)
	if !gjson.Valid(raw) {
		return nil, errors.New("response is not valid JSON")
	}
	parsed := gjson.Parse(raw)
	if !parsed.IsArray() {
		return nil, errors.New("response must be an array of actions")
	}

	var out []Action
	for idx, elem := range parsed.Array() {
		kind := elem.Get("type").String()
		if kind == "" {
			return nil, fmt.Errorf("action %d: missing required field 'type'", idx)
		}
		decode, known := actionDecoders[kind]
		if !known {
			return nil, fmt.Errorf("action %d: unknown action type: %s", idx, kind)
		}
		act, err := decode([]byte(elem.Raw))
		if err != nil {
			return nil, fmt.Errorf("action %d (%s): %w", idx, kind, err)
		}
		out = append(out, act)
	}
	return out, nil
}
// UserAction carries free-form Content in the same JSON shape as the other
// actions. Its Validate and Execute do nothing, and Parse never produces it.
// Presumably it represents user input in an action log; confirm with callers.
type UserAction struct {
	Type    string `json:"type"`
	Content string `json:"content"`
}

// Validate always succeeds.
func (UserAction) Validate(ctx context.Context, rt *Runtime) error { return nil }

// Execute has no effect.
func (UserAction) Execute(ctx context.Context, rt *Runtime) error { return nil }

// Raw returns the JSON encoding of v, or nil if v cannot be encoded.
func Raw(v Action) json.RawMessage {
	encoded, err := json.Marshal(v)
	if err != nil {
		// Every Action type declared in this file is a flat struct of strings
		// and ints, so marshalling cannot fail for them. A nil RawMessage
		// re-encodes as JSON null.
		return nil
	}
	return encoded
}