From 605c016a08b31bab3d89ce06743344e5989bf430 Mon Sep 17 00:00:00 2001 From: Will Andrews Date: Thu, 25 Sep 2025 06:57:01 +0100 Subject: [PATCH] Rough draft for working out who to send the alert to and allowing users to subscribe via dm --- cmd/main.go | 28 ++++ consumer.go | 50 +++++- database.go | 86 ++++++++++ dm_handler.go | 441 ++++++++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 604 insertions(+), 1 deletion(-) create mode 100644 dm_handler.go diff --git a/cmd/main.go b/cmd/main.go index 705b8ff..a8494f3 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -12,6 +12,7 @@ import ( "os/signal" "path" "syscall" + "time" tangledalertbot "tangled.sh/willdot.net/tangled-alert-bot" @@ -56,6 +57,11 @@ func run() error { } defer database.Close() + dmService, err := tangledalertbot.NewDmService(database, time.Second*30) + if err != nil { + return fmt.Errorf("create dm service: %w", err) + } + ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -63,6 +69,8 @@ func run() error { go startHttpServer(ctx, database) + go dmService.Start(ctx) + <-signals cancel() @@ -101,6 +109,7 @@ func startHttpServer(ctx context.Context, db *tangledalertbot.Database) { mux := http.NewServeMux() mux.HandleFunc("/issues", srv.handleListIssues) mux.HandleFunc("/comments", srv.handleListComments) + mux.HandleFunc("/users", srv.handleListUsers) err := http.ListenAndServe(":3000", mux) if err != nil { @@ -149,3 +158,22 @@ func (s *server) handleListComments(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") w.Write(b) } + +func (s *server) handleListUsers(w http.ResponseWriter, r *http.Request) { + users, err := s.db.GetUsers() + if err != nil { + slog.Error("getting users from DB", "error", err) + http.Error(w, "error getting users from DB", http.StatusInternalServerError) + return + } + + b, err := json.Marshal(users) + if err != nil { + slog.Error("marshalling users from DB", "error", err) + http.Error(w, "marshalling users from DB", http.StatusInternalServerError) + return + } + + w.Header().Set("Content-Type", "application/json") + w.Write(b) +} diff --git a/consumer.go b/consumer.go index 0ca8e1a..62e5f03 100644 --- a/consumer.go +++ b/consumer.go @@ -3,6 +3,7 @@ package tangledalertbot import ( "context" "encoding/json" + "strings" "fmt" "log/slog" @@ -38,6 +39,8 @@ type Store interface { DeleteIssue(did, rkey string) error DeleteComment(did, rkey string) error DeleteCommentsForIssue(issueURI string) error + GetUser(did string) (User, error) + CreateUser(user User) error } // JetstreamConsumer is responsible for consuming from a jetstream instance @@ -217,8 +220,19 @@ func (h *Handler) handleCreateUpdateIssueCommentEvent(ctx context.Context, event } // TODO: now send a notification to either the issue creator or whoever the comment was a reply to + didToNotify := getUserToAlert(comment) + if didToNotify == "" { + slog.Info("could not work out did to send alert to", "comment", comment) + return + } + + user, err := h.store.GetUser(didToNotify) + if err != nil { + slog.Error("getting user to send alert to", "error", err, "did", didToNotify) + return + } - slog.Info("created comment ", "value", comment, "did", did, "rkey", rkey) + slog.Info("sending alert to user", "value", comment, "did", didToNotify, "convo", user.ConvoID) } func (h *Handler) handleDeleteIssueEvent(ctx context.Context, event *models.Event) { @@ -259,3 +273,37 @@ func (h *Handler) handleDeleteIssueCommentEvent(ctx context.Context, event *mode slog.Info("deleted comment ", "did", did, "rkey", rkey) } + +// at://did:plc:dadhhalkfcq3gucaq25hjqon/sh.tangled.repo.issue.comment/3lzkp4va62m22 +func getUserToAlert(comment tangled.RepoIssueComment) string { + if comment.ReplyTo != nil { + return getDidFromCommentURI(*comment.ReplyTo) + } + return getDidFromIssueURI(comment.Issue) +} + +func getDidFromCommentURI(uri string) string { + split := strings.Split(uri, tangled.RepoIssueCommentNSID) + if len(split) != 2 { + slog.Error("invalid comment URI received", "uri", uri) + return "" + } + + did := strings.TrimPrefix(split[0], "at://") + did = strings.TrimSuffix(did, "/") + + return did +} + +func getDidFromIssueURI(uri string) string { + split := strings.Split(uri, tangled.RepoIssueNSID) + if len(split) != 2 { + slog.Error("invalid issue URI received", "uri", uri) + return "" + } + + did := strings.TrimPrefix(split[0], "at://") + did = strings.TrimSuffix(did, "/") + + return did +} diff --git a/database.go b/database.go index 1e21d83..8455027 100644 --- a/database.go +++ b/database.go @@ -44,6 +44,11 @@ func NewDatabase(dbPath string) (*Database, error) { return nil, fmt.Errorf("creating comments table: %w", err) } + err = createUsersTable(db) + if err != nil { + return nil, fmt.Errorf("creating users table: %w", err) + } + return &Database{db: db}, nil } @@ -122,6 +127,30 @@ func createCommentsTable(db *sql.DB) error { return nil } +func createUsersTable(db *sql.DB) error { + createTableSQL := `CREATE TABLE IF NOT EXISTS users ( + "id" integer NOT NULL PRIMARY KEY AUTOINCREMENT, + "did" TEXT, + "handle" TEXT, + "convoId" TEXT, + "createdAt" integer NOT NULL, + UNIQUE(did) + );` + + slog.Info("Create users table...") + statement, err := db.Prepare(createTableSQL) + if err != nil { + return fmt.Errorf("prepare DB statement to create users table: %w", err) + } + _, err = statement.Exec() + if err != nil { + return fmt.Errorf("exec sql statement to create users table: %w", err) + } + slog.Info("users table created") + + return nil +} + // CreateIssue will insert a issue into a database func (d *Database) CreateIssue(issue Issue) error { sql := `REPLACE INTO issues (authorDid, rkey, title, body, repo, createdAt) VALUES (?, ?, ?, ?, ?, ?);` @@ -142,6 +171,16 @@ func (d *Database) CreateComment(comment Comment) error { return nil } +// CreateUser will insert a user into a database +func (d *Database) CreateUser(user User) error { + sql := `REPLACE INTO users (did, handle, convoId, createdAt) VALUES (?, ?, ?, ?);` + _, err := d.db.Exec(sql, user.DID, user.Handle, user.ConvoID, user.CreatedAt) + if err != nil { + return fmt.Errorf("exec insert user: %w", err) + } + return nil +} + func (d *Database) GetIssues() ([]Issue, error) { sql := "SELECT authorDid, rkey, title, body, repo, createdAt FROM issues;" rows, err := d.db.Query(sql) @@ -182,6 +221,44 @@ func (d *Database) GetComments() ([]Comment, error) { return results, nil } +func (d *Database) GetUser(did string) (User, error) { + sql := "SELECT did, handle, convoId, createdAt FROM users WHERE did = ?;" + rows, err := d.db.Query(sql, did) + if err != nil { + return User{}, fmt.Errorf("run query to get user: %w", err) + } + defer rows.Close() + + for rows.Next() { + var user User + if err := rows.Scan(&user.DID, &user.Handle, &user.ConvoID, &user.CreatedAt); err != nil { + return User{}, fmt.Errorf("scan row: %w", err) + } + + return user, nil + } + return User{}, fmt.Errorf("user not found") +} + +func (d *Database) GetUsers() ([]User, error) { + sql := "SELECT did, handle, convoId, createdAt FROM users;" + rows, err := d.db.Query(sql) + if err != nil { + return nil, fmt.Errorf("run query to get user: %w", err) + } + defer rows.Close() + + var results []User + for rows.Next() { + var user User + if err := rows.Scan(&user.DID, &user.Handle, &user.ConvoID, &user.CreatedAt); err != nil { + return nil, fmt.Errorf("scan row: %w", err) + } + results = append(results, user) + } + return results, nil +} + func (d *Database) DeleteIssue(did, rkey string) error { sql := "DELETE FROM issues WHERE authorDid = ? AND rkey = ?;" _, err := d.db.Exec(sql, did, rkey) @@ -208,3 +285,12 @@ func (d *Database) DeleteCommentsForIssue(issueURI string) error { } return nil } + +func (d *Database) DeleteUser(did string) error { + sql := "DELETE FROM users WHERE did = ?;" + _, err := d.db.Exec(sql, did) + if err != nil { + return fmt.Errorf("exec delete user") + } + return nil +} diff --git a/dm_handler.go b/dm_handler.go new file mode 100644 index 0000000..bf98f67 --- /dev/null +++ b/dm_handler.go @@ -0,0 +1,441 @@ +package tangledalertbot + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "log/slog" + "net/http" + "os" + "strings" + "time" + + "github.com/pkg/errors" +) + +const ( + httpClientTimeoutDuration = time.Second * 5 + transportIdleConnTimeoutDuration = time.Second * 90 + baseBskyURL = "https://bsky.social/xrpc" +) + +type auth struct { + AccessJwt string `json:"accessJwt"` + RefershJWT string `json:"refreshJwt"` + Did string `json:"did"` +} + +type accessData struct { + handle string + appPassword string +} + +type ListConvosResponse struct { + Cursor string `json:"cursor"` + Convos []Convo `json:"convos"` +} + +type Convo struct { + ID string `json:"id"` + Members []ConvoMember `json:"members"` + UnreadCount int `json:"unreadCount"` +} + +type ConvoMember struct { + Did string `json:"did"` + Handle string `json:"handle"` +} + +type ErrorResponse struct { + Error string `json:"error"` +} + +type MessageResp struct { + Messages []Message `json:"messages"` + Cursor string `json:"cursor"` +} + +type Message struct { + ID string `json:"id"` + Sender MessageSender `json:"sender"` + Text string `json:"text"` +} + +type MessageSender struct { + Did string `json:"did"` +} + +type UpdateMessageReadRequest struct { + ConvoID string `json:"convoId"` + MessageID string `json:"messageId"` +} + +type User struct { + DID string + Handle string + ConvoID string + CreatedAt int +} + +type DmService struct { + httpClient *http.Client + accessData accessData + auth auth + timerDuration time.Duration + pdsURL string + store Store +} + +func NewDmService(store Store, timerDuration time.Duration) (*DmService, error) { + httpClient := http.Client{ + Timeout: httpClientTimeoutDuration, + Transport: &http.Transport{ + IdleConnTimeout: transportIdleConnTimeoutDuration, + }, + } + + accessHandle := os.Getenv("MESSAGING_ACCESS_HANDLE") + accessAppPassword := os.Getenv("MESSAGING_ACCESS_APP_PASSWORD") + pdsURL := os.Getenv("MESSAGING_PDS_URL") + + service := DmService{ + httpClient: &httpClient, + accessData: accessData{ + handle: accessHandle, + appPassword: accessAppPassword, + }, + timerDuration: timerDuration, + pdsURL: pdsURL, + store: store, + } + + auth, err := service.Authenicate() + if err != nil { + return nil, fmt.Errorf("authenticating: %w", err) + } + + service.auth = auth + + return &service, nil +} + +func (d *DmService) Start(ctx context.Context) { + go d.RefreshTask(ctx) + + timer := time.NewTimer(d.timerDuration) + defer timer.Stop() + + for { + select { + case <-ctx.Done(): + slog.Warn("context canceled - stopping dm task") + return + case <-timer.C: + err := d.HandleMessageTimer(ctx) + if err != nil { + slog.Error("handle message timer", "error", err) + } + timer.Reset(d.timerDuration) + } + } +} + +func (d *DmService) RefreshTask(ctx context.Context) { + timer := time.NewTimer(time.Hour) + defer timer.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-timer.C: + err := d.RefreshAuthenication(ctx) + if err != nil { + slog.Error("handle refresh auth timer", "error", err) + // TODO: better retry with backoff probably + timer.Reset(time.Minute) + continue + } + timer.Reset(time.Hour) + } + } +} + +func (d *DmService) HandleMessageTimer(ctx context.Context) error { + convoResp, err := d.GetUnreadMessages() + if err != nil { + return fmt.Errorf("get unread messages: %w", err) + } + + // TODO: handle the cursor pagination + + for _, convo := range convoResp.Convos { + if convo.UnreadCount == 0 { + continue + } + + messageResp, err := d.GetMessages(ctx, convo.ID) + if err != nil { + slog.Error("failed to get messages for convo", "error", err, "convo id", convo.ID) + continue + } + + unreadCount := convo.UnreadCount + unreadMessages := make([]Message, 0, convo.UnreadCount) + // TODO: handle cursor pagination + for _, msg := range messageResp.Messages { + // TODO: techincally if I get to a message that's from the bot account, then there shouldn't be + // an more unread messages? + if msg.Sender.Did == d.auth.Did { + continue + } + + unreadMessages = append(unreadMessages, msg) + unreadCount-- + if unreadCount == 0 { + break + } + } + + for _, msg := range unreadMessages { + d.handleMessage(msg, convo) + + err = d.MarkMessageRead(msg.ID, convo.ID) + if err != nil { + slog.Error("marking message read", "error", err) + continue + } + } + } + + return nil +} + +func (d *DmService) handleMessage(msg Message, convo Convo) { + // TODO: add or remote user the list of "subsribed" users + if strings.ToLower(msg.Text) == "subscribe" { + userHandle := "" + for _, member := range convo.Members { + if member.Did == msg.Sender.Did { + userHandle = member.Handle + break + } + } + + if userHandle == "" { + slog.Error("user handle for sent message not found", "sender did", msg.Sender.Did, "convo members", convo.Members) + return + } + + user := User{ + DID: msg.Sender.Did, + ConvoID: convo.ID, + Handle: userHandle, + CreatedAt: int(time.Now().UnixMilli()), + } + + err := d.store.CreateUser(user) + if err != nil { + slog.Error("error creating user", "error", err, "user", user) + return + } + } +} + +func (d *DmService) GetUnreadMessages() (ListConvosResponse, error) { + url := fmt.Sprintf("%s/xrpc/chat.bsky.convo.listConvos?readState=unread", d.pdsURL) + request, err := http.NewRequest("GET", url, nil) + if err != nil { + return ListConvosResponse{}, fmt.Errorf("create new list convos http request: %w", err) + } + + request.Header.Add("Content-Type", "application/json") + request.Header.Add("Accept", "application/json") + request.Header.Add("Atproto-Proxy", "did:web:api.bsky.chat#bsky_chat") + request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", d.auth.AccessJwt)) + + resp, err := d.httpClient.Do(request) + if err != nil { + return ListConvosResponse{}, fmt.Errorf("do http request to list convos: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + var errorResp ErrorResponse + err = decodeResp(resp.Body, &errorResp) + if err != nil { + return ListConvosResponse{}, err + } + + return ListConvosResponse{}, fmt.Errorf("listing convos responded with code %d: %s", resp.StatusCode, errorResp.Error) + } + + var listConvoResp ListConvosResponse + err = decodeResp(resp.Body, &listConvoResp) + if err != nil { + return ListConvosResponse{}, err + } + + return listConvoResp, nil +} + +func (d *DmService) MarkMessageRead(messageID, convoID string) error { + bodyReq := UpdateMessageReadRequest{ + ConvoID: convoID, + MessageID: messageID, + } + + bodyB, err := json.Marshal(bodyReq) + if err != nil { + return fmt.Errorf("marshal update message request body: %w", err) + } + + r := bytes.NewReader(bodyB) + + url := fmt.Sprintf("%s/xrpc/chat.bsky.convo.updateRead", d.pdsURL) + request, err := http.NewRequest("POST", url, r) + if err != nil { + return fmt.Errorf("create new list convos http request: %w", err) + } + + request.Header.Add("Content-Type", "application/json") + request.Header.Add("Accept", "application/json") + request.Header.Add("Atproto-Proxy", "did:web:api.bsky.chat#bsky_chat") + request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", d.auth.AccessJwt)) + + resp, err := d.httpClient.Do(request) + if err != nil { + return fmt.Errorf("do http request to update message read: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode == http.StatusOK { + return nil + } + + var errorResp ErrorResponse + err = decodeResp(resp.Body, &errorResp) + if err != nil { + return err + } + + return fmt.Errorf("listing convos responded with code %d: %s", resp.StatusCode, errorResp.Error) + +} + +func (d *DmService) Authenicate() (auth, error) { + url := fmt.Sprintf("%s/com.atproto.server.createSession", baseBskyURL) + + requestData := map[string]interface{}{ + "identifier": d.accessData.handle, + "password": d.accessData.appPassword, + } + + data, err := json.Marshal(requestData) + if err != nil { + return auth{}, errors.Wrap(err, "failed to marshal request") + } + + r := bytes.NewReader(data) + + request, err := http.NewRequest("POST", url, r) + if err != nil { + return auth{}, errors.Wrap(err, "failed to create request") + } + + request.Header.Add("Content-Type", "application/json") + + resp, err := d.httpClient.Do(request) + if err != nil { + return auth{}, errors.Wrap(err, "failed to make request") + } + defer resp.Body.Close() + + var loginResp auth + err = decodeResp(resp.Body, &loginResp) + if err != nil { + return auth{}, err + } + + return loginResp, nil +} + +func (d *DmService) RefreshAuthenication(ctx context.Context) error { + url := fmt.Sprintf("%s/com.atproto.server.refreshSession", baseBskyURL) + + request, err := http.NewRequest("POST", url, nil) + if err != nil { + return errors.Wrap(err, "failed to create request") + } + + request.Header.Add("Content-Type", "application/json") + request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", d.auth.RefershJWT)) + + resp, err := d.httpClient.Do(request) + if err != nil { + return errors.Wrap(err, "failed to make request") + } + defer resp.Body.Close() + + var loginResp auth + err = decodeResp(resp.Body, &loginResp) + if err != nil { + return err + } + + d.auth = loginResp + + return nil +} + +func (d *DmService) GetMessages(ctx context.Context, convoID string) (MessageResp, error) { + url := fmt.Sprintf("%s/xrpc/chat.bsky.convo.getMessages?convoId=%s", d.pdsURL, convoID) + request, err := http.NewRequest("GET", url, nil) + if err != nil { + return MessageResp{}, fmt.Errorf("create new get messages http request: %w", err) + } + + request.Header.Add("Content-Type", "application/json") + request.Header.Add("Accept", "application/json") + request.Header.Add("Atproto-Proxy", "did:web:api.bsky.chat#bsky_chat") + request.Header.Add("Authorization", fmt.Sprintf("Bearer %s", d.auth.AccessJwt)) + + resp, err := d.httpClient.Do(request) + if err != nil { + return MessageResp{}, fmt.Errorf("do http request to get messages: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + var errorResp ErrorResponse + err = decodeResp(resp.Body, &errorResp) + if err != nil { + return MessageResp{}, err + } + + return MessageResp{}, fmt.Errorf("listing convos responded with code %d: %s", resp.StatusCode, errorResp.Error) + } + + var messageResp MessageResp + err = decodeResp(resp.Body, &messageResp) + if err != nil { + return MessageResp{}, err + } + + return messageResp, nil +} + +func decodeResp(body io.Reader, result any) error { + resBody, err := io.ReadAll(body) + if err != nil { + return errors.Wrap(err, "failed to read response") + } + + err = json.Unmarshal(resBody, result) + if err != nil { + return errors.Wrap(err, "failed to unmarshal response") + } + return nil +} -- 2.51.0