463 lines
11 KiB
Go
463 lines
11 KiB
Go
package actions
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/d1nch8g/jules/chat"
|
|
"github.com/d1nch8g/jules/database"
|
|
"github.com/d1nch8g/jules/engine/jtime"
|
|
"github.com/d1nch8g/jules/engine/user"
|
|
"github.com/d1nch8g/jules/search"
|
|
"github.com/google/uuid"
|
|
"github.com/tidwall/gjson"
|
|
)
|
|
|
|
// ActionsPromptPart lists one example JSON object per supported action type.
// Presumably it is embedded in the model prompt (TODO confirm at the call
// site); Parse expects the reply to be a JSON array of such objects.
//
// NOTE(review): the "wait" example advertises 100-600 ms, while Wait.Validate
// accepts anything from 1 to 60000 ms.
const ActionsPromptPart = `
|
|
=== AVAILABLE ACTIONS ===
|
|
{"type": "message", "platform": "telegram", "text": "short response"}
|
|
{"type": "wait", "ms": 100-600}
|
|
{"type": "add_fact", "value": "fact about user"}
|
|
{"type": "remove_fact", "value": "fact to remove"}
|
|
{"type": "add_notification", "target": "self", "time": "...", "content": "... (repetition rules)"}
|
|
{"type": "add_notification", "target": "contact name", "time": "...", "content": "... (repetition rules)"}
|
|
{"type": "remove_notification", "uuid": "123e4567-e89b-12d3-a456-426614174000"}
|
|
{"type": "add_contact", "uuid": "123e4567-e89b-12d3-a456-426614174000", "name": "Contact Name"}
|
|
{"type": "bind_chat", "uuid": "123e4567-e89b-12d3-a456-426614174000"}
|
|
{"type": "update_lang", "lang": "ru"}
|
|
{"type": "update_tz", "tz": "Europe/Moscow"}
|
|
{"type": "set_chat", "chat": "telegram"}
|
|
`
|
|
|
|
// {"type": "search", "query": "search query"} // temporarily disabled; Parse has no case for it
|
|
|
|
type Action interface {
|
|
Validate(ctx context.Context, rt *Runtime) error
|
|
Execute(ctx context.Context, rt *Runtime) error
|
|
}
|
|
|
|
type Runtime struct {
|
|
User *user.User
|
|
Database database.Database
|
|
Searcher search.Searcher
|
|
Chats map[string]chat.Chat
|
|
}
|
|
|
|
// BindChat merges the current user into the account that owns a bind code:
// {"type": "bind_chat", "uuid": "..."}.
type BindChat struct {
	Type string `json:"type"` // "bind_chat"
	UUID string `json:"uuid"` // bind code of the account to merge into
}
|
|
|
|
func (a BindChat) Validate(ctx context.Context, rt *Runtime) error {
|
|
if a.UUID == "" {
|
|
return errors.New("target_uuid is required")
|
|
}
|
|
if _, err := uuid.Parse(a.UUID); err != nil {
|
|
return fmt.Errorf("target_uuid must be valid UUID: %w", err)
|
|
}
|
|
_, err := rt.Database.Users().Get(ctx, uuid.MustParse(a.UUID), database.UserLookupByBindCode)
|
|
if err != nil {
|
|
if errors.Is(err, database.ErrNotFound) {
|
|
return fmt.Errorf("user with bind code %s not found", a.UUID)
|
|
}
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (a BindChat) Execute(ctx context.Context, rt *Runtime) error {
|
|
targetUser, err := rt.Database.Users().Get(ctx, uuid.MustParse(a.UUID), database.UserLookupByBindCode)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
u := rt.User
|
|
|
|
if err := rt.Database.Users().Delete(ctx, u.ID); err != nil {
|
|
return err
|
|
}
|
|
|
|
for _, chat := range u.Chats {
|
|
if err := rt.Database.Chats().Attach(ctx, targetUser.ID, chat.Platform, chat.Identifier); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
for _, contact := range u.Contacts {
|
|
if err := rt.Database.Contacts().Add(ctx, &database.Contact{
|
|
OwnerID: targetUser.ID,
|
|
TargetID: contact.TargetID,
|
|
Name: contact.Name,
|
|
}); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
for _, fact := range u.Facts {
|
|
if err := rt.Database.Facts().Add(ctx, targetUser.ID, fact.Value); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
for _, notif := range u.IncomingNotifications {
|
|
if notif.InitiatorID == u.ID {
|
|
notif.InitiatorID = targetUser.ID
|
|
}
|
|
if err := rt.Database.Notifications().Push(ctx, ¬if); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
for _, notif := range u.OutgoingNotifications {
|
|
if err := rt.Database.Notifications().Push(ctx, ¬if); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// Message sends Text to the user on one of their connected platforms:
// {"type": "message", "platform": "telegram", "text": "..."}.
type Message struct {
	Type     string `json:"type"`     // "message"
	Platform string `json:"platform"` // key into Runtime.Chats; must match a user chat
	Text     string `json:"text"`     // message body
}
|
|
|
|
func (a Message) Validate(ctx context.Context, rt *Runtime) error {
|
|
if a.Text == "" {
|
|
return errors.New("text is empty")
|
|
}
|
|
for _, c := range rt.User.Chats {
|
|
if c.Platform == a.Platform {
|
|
return nil
|
|
}
|
|
}
|
|
return fmt.Errorf("platform %s not connected", a.Platform)
|
|
}
|
|
|
|
func (a Message) Execute(ctx context.Context, rt *Runtime) error {
|
|
var platformID string
|
|
for _, c := range rt.User.Chats {
|
|
if c.Platform == a.Platform {
|
|
platformID = c.Identifier
|
|
break
|
|
}
|
|
}
|
|
return rt.Chats[a.Platform].Send(ctx, platformID, a.Text)
|
|
}
|
|
|
|
// Wait pauses the action sequence: {"type": "wait", "ms": 300}.
type Wait struct {
	Type string `json:"type"` // "wait"
	Ms   int    `json:"ms"`   // delay in milliseconds; Validate allows 1..60000
}
|
|
|
|
func (a Wait) Validate(ctx context.Context, rt *Runtime) error {
|
|
if a.Ms <= 0 {
|
|
return errors.New("ms must be positive")
|
|
}
|
|
if a.Ms > 60000 {
|
|
return errors.New("ms cannot exceed 60000")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (a Wait) Execute(ctx context.Context, rt *Runtime) error {
|
|
time.Sleep(time.Duration(a.Ms) * time.Millisecond)
|
|
return nil
|
|
}
|
|
|
|
type UpdateLang struct {
|
|
Type string `json:"type"`
|
|
Lang string `json:"lang"`
|
|
}
|
|
|
|
func (a UpdateLang) Validate(ctx context.Context, rt *Runtime) error {
|
|
if a.Lang == "" {
|
|
return errors.New("lang is required")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (a UpdateLang) Execute(ctx context.Context, rt *Runtime) error {
|
|
rt.User.Language = a.Lang
|
|
return rt.Database.Users().Update(ctx, rt.User.User)
|
|
}
|
|
|
|
// UpdateTZ changes the user's timezone:
// {"type": "update_tz", "tz": "Europe/Moscow"}.
type UpdateTZ struct {
	Type string `json:"type"` // "update_tz"
	TZ   string `json:"tz"`   // timezone name, e.g. "Europe/Moscow"
}
|
|
|
|
func (a UpdateTZ) Validate(ctx context.Context, rt *Runtime) error {
|
|
if a.TZ == "" {
|
|
return errors.New("tz is required")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (a UpdateTZ) Execute(ctx context.Context, rt *Runtime) error {
|
|
rt.User.Timezone = a.TZ
|
|
return rt.Database.Users().Update(ctx, rt.User.User)
|
|
}
|
|
|
|
type SetChat struct {
|
|
Type string `json:"type"`
|
|
Chat string `json:"chat"`
|
|
}
|
|
|
|
func (a SetChat) Validate(ctx context.Context, rt *Runtime) error {
|
|
if a.Chat == "" {
|
|
return errors.New("chat is required")
|
|
}
|
|
for _, c := range rt.User.Chats {
|
|
if c.Platform == a.Chat {
|
|
return nil
|
|
}
|
|
}
|
|
return fmt.Errorf("chat %s not connected", a.Chat)
|
|
}
|
|
|
|
func (a SetChat) Execute(ctx context.Context, rt *Runtime) error {
|
|
rt.User.PreferredChat = a.Chat
|
|
return rt.Database.Users().Update(ctx, rt.User.User)
|
|
}
|
|
|
|
type AddFact struct {
|
|
Type string `json:"type"`
|
|
Value string `json:"value"`
|
|
}
|
|
|
|
func (a AddFact) Validate(ctx context.Context, rt *Runtime) error {
|
|
if a.Value == "" {
|
|
return errors.New("value is required")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (a AddFact) Execute(ctx context.Context, rt *Runtime) error {
|
|
return rt.Database.Facts().Add(ctx, rt.User.ID, a.Value)
|
|
}
|
|
|
|
// RemoveFact deletes a stored fact by its exact text:
// {"type": "remove_fact", "value": "..."}.
type RemoveFact struct {
	Type  string `json:"type"`  // "remove_fact"
	Value string `json:"value"` // text of the fact to delete
}
|
|
|
|
// Validate rejects an empty value, matching AddFact.Validate. Without a value
// there is no fact to remove, and reporting that is more useful than a
// silent no-op.
func (a RemoveFact) Validate(ctx context.Context, rt *Runtime) error {
	if a.Value == "" {
		return errors.New("value is required")
	}
	return nil
}
|
|
|
|
func (a RemoveFact) Execute(ctx context.Context, rt *Runtime) error {
|
|
err := rt.Database.Facts().Delete(ctx, rt.User.ID, a.Value)
|
|
if errors.Is(err, database.ErrNotFound) {
|
|
return nil
|
|
}
|
|
return err
|
|
}
|
|
|
|
// AddContact adds another user, identified by their contact code, to the
// current user's contacts under a chosen name:
// {"type": "add_contact", "uuid": "...", "name": "..."}.
type AddContact struct {
	Type string `json:"type"` // "add_contact"
	UUID string `json:"uuid"` // the other user's contact code
	Name string `json:"name"` // name used later to address the contact
}
|
|
|
|
func (a AddContact) Validate(ctx context.Context, rt *Runtime) error {
|
|
if a.UUID == "" {
|
|
return errors.New("uuid is required")
|
|
}
|
|
if _, err := uuid.Parse(a.UUID); err != nil {
|
|
return fmt.Errorf("uuid must be valid UUID: %w", err)
|
|
}
|
|
if a.Name == "" {
|
|
return errors.New("name is required")
|
|
}
|
|
_, err := rt.Database.Users().Get(ctx, uuid.MustParse(a.UUID), database.UserLookupByContactCode)
|
|
if err != nil {
|
|
if errors.Is(err, database.ErrNotFound) {
|
|
return fmt.Errorf("user with contact code %s not found", a.UUID)
|
|
}
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (a AddContact) Execute(ctx context.Context, rt *Runtime) error {
|
|
contactUser, err := rt.Database.Users().Get(ctx, uuid.MustParse(a.UUID), database.UserLookupByContactCode)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return rt.Database.Contacts().Add(ctx, &database.Contact{
|
|
OwnerID: rt.User.ID,
|
|
TargetID: contactUser.ID,
|
|
Name: a.Name,
|
|
})
|
|
}
|
|
|
|
// AddNotification schedules a notification for the user ("self") or for a
// contact addressed by name:
// {"type": "add_notification", "target": "self", "time": "...", "content": "..."}.
type AddNotification struct {
	Type    string `json:"type"`    // "add_notification"
	Target  string `json:"target"`  // "self" or a contact's Name
	Time    string `json:"time"`    // converted with jtime.ToUTC in the user's timezone
	Content string `json:"content"` // notification text
}
|
|
|
|
func (a AddNotification) Validate(ctx context.Context, rt *Runtime) error {
|
|
if a.Target == "" {
|
|
return errors.New("target is required")
|
|
}
|
|
if a.Time == "" {
|
|
return errors.New("time is required")
|
|
}
|
|
if a.Content == "" {
|
|
return errors.New("content is required")
|
|
}
|
|
if _, err := jtime.ToUTC(a.Time, rt.User.Timezone); err != nil {
|
|
return err
|
|
}
|
|
if a.Target == "self" {
|
|
return nil
|
|
}
|
|
for _, c := range rt.User.Contacts {
|
|
if c.Name == a.Target {
|
|
return nil
|
|
}
|
|
}
|
|
return fmt.Errorf("contact %s not found", a.Target)
|
|
}
|
|
|
|
func (a AddNotification) Execute(ctx context.Context, rt *Runtime) error {
|
|
scheduledAt, _ := jtime.ToUTC(a.Time, rt.User.Timezone)
|
|
initiatorID := rt.User.ID
|
|
targetID := rt.User.ID
|
|
if a.Target != "self" {
|
|
for _, c := range rt.User.Contacts {
|
|
if c.Name == a.Target {
|
|
initiatorID = c.TargetID
|
|
break
|
|
}
|
|
}
|
|
}
|
|
return rt.Database.Notifications().Push(ctx, &database.Notification{
|
|
ID: uuid.New(),
|
|
UserID: targetID,
|
|
InitiatorID: initiatorID,
|
|
ScheduledAt: scheduledAt,
|
|
Content: a.Content,
|
|
})
|
|
}
|
|
|
|
// RemoveNotification deletes a scheduled notification by ID:
// {"type": "remove_notification", "uuid": "..."}.
type RemoveNotification struct {
	Type string `json:"type"` // "remove_notification"
	UUID string `json:"uuid"` // notification ID
}
|
|
|
|
func (a RemoveNotification) Validate(ctx context.Context, rt *Runtime) error {
|
|
if a.UUID == "" {
|
|
return errors.New("uuid is required")
|
|
}
|
|
if _, err := uuid.Parse(a.UUID); err != nil {
|
|
return fmt.Errorf("uuid must be valid UUID: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (a RemoveNotification) Execute(ctx context.Context, rt *Runtime) error {
|
|
err := rt.Database.Notifications().Delete(ctx, uuid.MustParse(a.UUID))
|
|
if errors.Is(err, database.ErrNotFound) {
|
|
return nil
|
|
}
|
|
return err
|
|
}
|
|
|
|
func Parse(raw string, userTimezone string) ([]Action, error) {
|
|
start := strings.Index(raw, "[")
|
|
end := strings.LastIndex(raw, "]")
|
|
if start != -1 && end != -1 && end > start {
|
|
raw = raw[start : end+1]
|
|
}
|
|
raw = strings.TrimSpace(raw)
|
|
|
|
if !gjson.Valid(raw) {
|
|
return nil, errors.New("response is not valid JSON")
|
|
}
|
|
result := gjson.Parse(raw)
|
|
if !result.IsArray() {
|
|
return nil, errors.New("response must be an array of actions")
|
|
}
|
|
|
|
var actions []Action
|
|
for i, item := range result.Array() {
|
|
actionType := item.Get("type").String()
|
|
if actionType == "" {
|
|
return nil, fmt.Errorf("action %d: missing required field 'type'", i)
|
|
}
|
|
|
|
var action Action
|
|
var err error
|
|
|
|
switch actionType {
|
|
case "bind_chat":
|
|
var a BindChat
|
|
err = json.Unmarshal([]byte(item.Raw), &a)
|
|
action = a
|
|
case "message":
|
|
var a Message
|
|
err = json.Unmarshal([]byte(item.Raw), &a)
|
|
action = a
|
|
case "wait":
|
|
var a Wait
|
|
err = json.Unmarshal([]byte(item.Raw), &a)
|
|
action = a
|
|
case "update_lang":
|
|
var a UpdateLang
|
|
err = json.Unmarshal([]byte(item.Raw), &a)
|
|
action = a
|
|
case "update_tz":
|
|
var a UpdateTZ
|
|
err = json.Unmarshal([]byte(item.Raw), &a)
|
|
action = a
|
|
case "set_chat":
|
|
var a SetChat
|
|
err = json.Unmarshal([]byte(item.Raw), &a)
|
|
action = a
|
|
case "add_fact":
|
|
var a AddFact
|
|
err = json.Unmarshal([]byte(item.Raw), &a)
|
|
action = a
|
|
case "remove_fact":
|
|
var a RemoveFact
|
|
err = json.Unmarshal([]byte(item.Raw), &a)
|
|
action = a
|
|
case "add_contact":
|
|
var a AddContact
|
|
err = json.Unmarshal([]byte(item.Raw), &a)
|
|
action = a
|
|
case "add_notification":
|
|
var a AddNotification
|
|
err = json.Unmarshal([]byte(item.Raw), &a)
|
|
action = a
|
|
case "remove_notification":
|
|
var a RemoveNotification
|
|
err = json.Unmarshal([]byte(item.Raw), &a)
|
|
action = a
|
|
default:
|
|
return nil, fmt.Errorf("action %d: unknown action type: %s", i, actionType)
|
|
}
|
|
|
|
if err != nil {
|
|
return nil, fmt.Errorf("action %d (%s): %w", i, actionType, err)
|
|
}
|
|
actions = append(actions, action)
|
|
}
|
|
return actions, nil
|
|
}
|
|
|
|
type UserAction struct {
|
|
Type string `json:"type"`
|
|
Content string `json:"content"`
|
|
}
|
|
|
|
func (UserAction) Validate(ctx context.Context, rt *Runtime) error { return nil }
|
|
|
|
func (UserAction) Execute(ctx context.Context, rt *Runtime) error { return nil }
|
|
|
|
func Raw(v Action) json.RawMessage {
|
|
data, _ := json.Marshal(v)
|
|
return data
|
|
}
|