1package db
2
3import (
4 "crypto/sha1"
5 "database/sql"
6 "encoding/hex"
7 "errors"
8 "fmt"
9 "maps"
10 "slices"
11 "strings"
12 "time"
13
14 "github.com/bluesky-social/indigo/atproto/syntax"
15 "tangled.sh/tangled.sh/core/api/tangled"
16)
17
// ConcreteType is the primitive type name held in ValueType.Type.
type ConcreteType string

// Concrete types a label value type can declare.
const (
	// ConcreteTypeNull is the lexicon "null" type.
	ConcreteTypeNull ConcreteType = "null"
	// ConcreteTypeString is the lexicon "string" type.
	ConcreteTypeString ConcreteType = "string"
	// ConcreteTypeInt is the lexicon "integer" type.
	ConcreteTypeInt ConcreteType = "integer"
	// ConcreteTypeBool is the lexicon "boolean" type.
	ConcreteTypeBool ConcreteType = "boolean"
)
26
// ValueTypeFormat is an optional format refinement held in ValueType.Format.
type ValueTypeFormat string

// Formats a label value type can declare.
const (
	// ValueTypeFormatAny is the explicit "any" format.
	ValueTypeFormatAny ValueTypeFormat = "any"
	// ValueTypeFormatDid is the "did" format.
	ValueTypeFormatDid ValueTypeFormat = "did"
)
33
34// ValueType represents an atproto lexicon type definition with constraints
35type ValueType struct {
36 Type ConcreteType `json:"type"`
37 Format ValueTypeFormat `json:"format,omitempty"`
38 Enum []string `json:"enum,omitempty"`
39}
40
// AsRecord converts vt into its tangled.LabelDefinition_ValueType record form.
//
// It uses a value receiver like every other ValueType method (calls through a
// *ValueType keep working). Enum is cloned so the returned record does not
// share a backing array with vt; a nil Enum stays nil and an empty one stays
// empty, so JSON omitempty behavior is unchanged.
func (vt ValueType) AsRecord() tangled.LabelDefinition_ValueType {
	return tangled.LabelDefinition_ValueType{
		Type:   string(vt.Type),
		Format: string(vt.Format),
		Enum:   slices.Clone(vt.Enum),
	}
}
48
49func ValueTypeFromRecord(record tangled.LabelDefinition_ValueType) ValueType {
50 return ValueType{
51 Type: ConcreteType(record.Type),
52 Format: ValueTypeFormat(record.Format),
53 Enum: record.Enum,
54 }
55}
56
57func (vt ValueType) IsConcreteType() bool {
58 return vt.Type == ConcreteTypeNull ||
59 vt.Type == ConcreteTypeString ||
60 vt.Type == ConcreteTypeInt ||
61 vt.Type == ConcreteTypeBool
62}
63
64func (vt ValueType) IsNull() bool {
65 return vt.Type == ConcreteTypeNull
66}
67
68func (vt ValueType) IsString() bool {
69 return vt.Type == ConcreteTypeString
70}
71
72func (vt ValueType) IsInt() bool {
73 return vt.Type == ConcreteTypeInt
74}
75
76func (vt ValueType) IsBool() bool {
77 return vt.Type == ConcreteTypeBool
78}
79
80func (vt ValueType) IsEnumType() bool {
81 return len(vt.Enum) > 0
82}
83
84func (vt ValueType) IsDidFormat() bool {
85 return vt.Format == ValueTypeFormatDid
86}
87
88func (vt ValueType) IsAnyFormat() bool {
89 return vt.Format == ValueTypeFormatAny
90}
91
92type LabelDefinition struct {
93 Id int64
94 Did string
95 Rkey string
96
97 Name string
98 ValueType ValueType
99 Scope syntax.NSID
100 Color *string
101 Multiple bool
102 Created time.Time
103}
104
// AtUri returns the AT-URI of the definition's record, built as
// at://<Did>/<tangled.LabelDefinitionNSID>/<Rkey>.
func (l *LabelDefinition) AtUri() syntax.ATURI {
	uri := fmt.Sprintf("at://%s/%s/%s", l.Did, tangled.LabelDefinitionNSID, l.Rkey)
	return syntax.ATURI(uri)
}
108
109func (l *LabelDefinition) AsRecord() tangled.LabelDefinition {
110 vt := l.ValueType.AsRecord()
111 return tangled.LabelDefinition{
112 Name: l.Name,
113 Color: l.Color,
114 CreatedAt: l.Created.Format(time.RFC3339),
115 Multiple: &l.Multiple,
116 Scope: l.Scope.String(),
117 ValueType: &vt,
118 }
119}
120
// randomColor derives a deterministic "#rrggbb" color from seed. The first
// three bytes of the seed's SHA-1 digest become the red, green and blue
// channels, so the same seed always yields the same color.
func randomColor(seed string) string {
	digest := sha1.Sum([]byte(seed))
	return fmt.Sprintf("#%s", hex.EncodeToString(digest[:3]))
}
131
// GetColor returns the definition's explicit Color. When Color is nil, it
// returns a deterministic "#rrggbb" color seeded by "<Id>:<Did>:<Rkey>".
func (l LabelDefinition) GetColor() string {
	if l.Color != nil {
		return *l.Color
	}
	return randomColor(fmt.Sprintf("%d:%s:%s", l.Id, l.Did, l.Rkey))
}
141
142func LabelDefinitionFromRecord(did, rkey string, record tangled.LabelDefinition) LabelDefinition {
143 created, err := time.Parse(time.RFC3339, record.CreatedAt)
144 if err != nil {
145 created = time.Now()
146 }
147
148 multiple := false
149 if record.Multiple != nil {
150 multiple = *record.Multiple
151 }
152
153 var vt ValueType
154 if record.ValueType != nil {
155 vt = ValueTypeFromRecord(*record.ValueType)
156 }
157
158 return LabelDefinition{
159 Did: did,
160 Rkey: rkey,
161
162 Name: record.Name,
163 ValueType: vt,
164 Scope: syntax.NSID(record.Scope),
165 Color: record.Color,
166 Multiple: multiple,
167 Created: created,
168 }
169}
170
// DeleteLabelDefinition deletes every row of label_definitions that matches
// all of the given filters (ANDed together). With no filters the statement
// has no where clause and deletes every row.
func DeleteLabelDefinition(e Execer, filters ...filter) error {
	var (
		clauses []string
		args    []any
	)
	for _, f := range filters {
		clauses = append(clauses, f.Condition())
		args = append(args, f.Arg()...)
	}

	query := "delete from label_definitions "
	if len(clauses) > 0 {
		query += " where " + strings.Join(clauses, " and ")
	}

	_, err := e.Exec(query, args...)
	return err
}
186
// AddLabelDefinition upserts a label definition keyed by (did, rkey).
//
// Enum variants are persisted as one comma-joined string. On a (did, rkey)
// conflict only name, scope, color and multiple are overwritten; the stored
// value type, format, enum and created columns keep their existing values.
// On success the returned id is also written to l.Id.
//
// The insert has exactly ten placeholders, so exactly ten arguments are
// bound. The original passed an extra trailing time.Now() value.
//
// NOTE(review): in SQLite, an UPSERT that takes the update path does not
// change last_insert_rowid. For an already-existing (did, rkey) the id
// returned here may therefore not be that row's id — confirm whether callers
// rely on it.
func AddLabelDefinition(e Execer, l *LabelDefinition) (int64, error) {
	result, err := e.Exec(
		`insert into label_definitions (
			did,
			rkey,
			name,
			value_type,
			value_format,
			value_enum,
			scope,
			color,
			multiple,
			created
		)
		values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
		on conflict(did, rkey) do update set
			name = excluded.name,
			scope = excluded.scope,
			color = excluded.color,
			multiple = excluded.multiple`,
		l.Did,
		l.Rkey,
		l.Name,
		l.ValueType.Type,
		l.ValueType.Format,
		strings.Join(l.ValueType.Enum, ","),
		l.Scope.String(),
		l.Color,
		l.Multiple,
		l.Created.Format(time.RFC3339),
	)
	if err != nil {
		return 0, err
	}

	id, err := result.LastInsertId()
	if err != nil {
		return 0, err
	}

	l.Id = id

	return id, nil
}
232
// GetLabelDefinitions returns the label definitions matching all of the given
// filters (ANDed together), ordered by the created column. It returns nil
// when nothing matches.
//
// Enum variants are split back out of their comma-joined storage form, and a
// NULL color becomes a nil Color. A non-zero multiple column maps to true. A
// created value that does not parse as RFC 3339 falls back to the current
// time.
func GetLabelDefinitions(e Execer, filters ...filter) ([]LabelDefinition, error) {
	var labelDefinitions []LabelDefinition
	var conditions []string
	var args []any

	for _, f := range filters {
		conditions = append(conditions, f.Condition())
		args = append(args, f.Arg()...)
	}

	whereClause := ""
	if len(conditions) > 0 {
		whereClause = " where " + strings.Join(conditions, " and ")
	}

	query := fmt.Sprintf(
		`
		select
			id,
			did,
			rkey,
			name,
			value_type,
			value_format,
			value_enum,
			scope,
			color,
			multiple,
			created
		from label_definitions
		%s
		order by created
		`,
		whereClause,
	)

	rows, err := e.Query(query, args...)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	for rows.Next() {
		var labelDefinition LabelDefinition
		var createdAt, enumVariants string
		var color sql.Null[string]
		var multiple int

		if err := rows.Scan(
			&labelDefinition.Id,
			&labelDefinition.Did,
			&labelDefinition.Rkey,
			&labelDefinition.Name,
			&labelDefinition.ValueType.Type,
			&labelDefinition.ValueType.Format,
			&enumVariants,
			&labelDefinition.Scope,
			&color,
			&multiple,
			&createdAt,
		); err != nil {
			return nil, err
		}

		labelDefinition.Created, err = time.Parse(time.RFC3339, createdAt)
		if err != nil {
			labelDefinition.Created = time.Now()
		}

		if color.Valid {
			labelDefinition.Color = &color.V
		}

		if multiple != 0 {
			labelDefinition.Multiple = true
		}

		if enumVariants != "" {
			labelDefinition.ValueType.Enum = strings.Split(enumVariants, ",")
		}

		labelDefinitions = append(labelDefinitions, labelDefinition)
	}

	// rows.Next returns false both at the end of the result set and on an
	// iteration error; only rows.Err distinguishes the two.
	if err := rows.Err(); err != nil {
		return nil, err
	}

	return labelDefinitions, nil
}
319
// GetLabelDefinition returns the single label definition matching filters.
//
// It returns sql.ErrNoRows when nothing matches and an error when more than
// one definition matches.
func GetLabelDefinition(e Execer, filters ...filter) (*LabelDefinition, error) {
	labels, err := GetLabelDefinitions(e, filters...)
	switch {
	case err != nil:
		return nil, err
	case labels == nil:
		return nil, sql.ErrNoRows
	case len(labels) != 1:
		return nil, fmt.Errorf("too many rows returned")
	}
	return &labels[0], nil
}
337
338type LabelOp struct {
339 Id int64
340 Did string
341 Rkey string
342 Subject syntax.ATURI
343 Operation LabelOperation
344 OperandKey string
345 OperandValue string
346 PerformedAt time.Time
347 IndexedAt time.Time
348}
349
// SortAt returns the timestamp used to order label ops.
//
// The result is PerformedAt, capped so it is never later than IndexedAt (or
// the current time when IndexedAt is unset). A PerformedAt before the Unix
// epoch is treated as missing and yields the zero time.
func (l LabelOp) SortAt() time.Time {
	performed := l.PerformedAt
	if performed.Before(time.UnixMicro(0)) {
		return time.Time{}
	}

	indexed := l.IndexedAt
	if indexed.IsZero() {
		indexed = time.Now()
	}

	if performed.After(indexed) {
		return indexed
	}
	return performed
}
372
// LabelOperation is the kind of change a LabelOp makes.
type LabelOperation string

// Supported label operations; these strings are stored in label_ops.operation.
const (
	// LabelOperationAdd adds OperandValue to the label.
	LabelOperationAdd LabelOperation = "add"
	// LabelOperationDel removes OperandValue from the label.
	LabelOperationDel LabelOperation = "del"
)
379
// LabelOpsFromRecord expands one tangled.LabelOp record into individual
// LabelOps. It emits one LabelOperationAdd per non-nil operand in record.Add,
// followed by one LabelOperationDel per non-nil operand in record.Delete.
//
// All ops share the record's subject and PerformedAt. A PerformedAt that
// does not parse as RFC 3339 falls back to the current time. The result is
// nil when the record has no non-nil operands.
func LabelOpsFromRecord(did, rkey string, record tangled.LabelOp) []LabelOp {
	performed, err := time.Parse(time.RFC3339, record.PerformedAt)
	if err != nil {
		performed = time.Now()
	}
	subject := syntax.ATURI(record.Subject)

	var ops []LabelOp
	collect := func(operands []*tangled.LabelOp_Operand, operation LabelOperation) {
		for _, operand := range operands {
			if operand == nil {
				continue
			}
			ops = append(ops, LabelOp{
				Did:          did,
				Rkey:         rkey,
				Subject:      subject,
				Operation:    operation,
				OperandKey:   operand.Key,
				OperandValue: operand.Value,
				PerformedAt:  performed,
			})
		}
	}

	collect(record.Add, LabelOperationAdd)
	collect(record.Delete, LabelOperationDel)

	return ops
}
416
// LabelOpsAsRecord folds ops back into a single tangled.LabelOp record.
//
// Subject and PerformedAt are taken from ops[0]; those fields of later ops
// are ignored. Each op contributes one operand to Add or Delete according to
// its Operation. An empty ops slice, or any op with an unrecognised
// Operation, yields the zero record.
func LabelOpsAsRecord(ops []LabelOp) tangled.LabelOp {
	var record tangled.LabelOp
	if len(ops) == 0 {
		return record
	}

	var adds, dels []*tangled.LabelOp_Operand
	for _, op := range ops {
		operand := &tangled.LabelOp_Operand{Key: op.OperandKey, Value: op.OperandValue}
		switch op.Operation {
		case LabelOperationAdd:
			adds = append(adds, operand)
		case LabelOperationDel:
			dels = append(dels, operand)
		default:
			return tangled.LabelOp{}
		}
	}

	first := ops[0]
	record.Subject = string(first.Subject)
	record.PerformedAt = first.PerformedAt.Format(time.RFC3339)
	record.Add = adds
	record.Delete = dels

	return record
}
453
454func AddLabelOp(e Execer, l *LabelOp) (int64, error) {
455 now := time.Now()
456 result, err := e.Exec(
457 `insert into label_ops (
458 did,
459 rkey,
460 subject,
461 operation,
462 operand_key,
463 operand_value,
464 performed,
465 indexed
466 )
467 values (?, ?, ?, ?, ?, ?, ?, ?)
468 on conflict(did, rkey, subject, operand_key, operand_value) do update set
469 operation = excluded.operation,
470 operand_value = excluded.operand_value,
471 performed = excluded.performed,
472 indexed = excluded.indexed`,
473 l.Did,
474 l.Rkey,
475 l.Subject.String(),
476 string(l.Operation),
477 l.OperandKey,
478 l.OperandValue,
479 l.PerformedAt.Format(time.RFC3339),
480 now.Format(time.RFC3339),
481 )
482 if err != nil {
483 return 0, err
484 }
485
486 id, err := result.LastInsertId()
487 if err != nil {
488 return 0, err
489 }
490
491 l.Id = id
492 l.IndexedAt = now
493
494 return id, nil
495}
496
// GetLabelOps returns the label ops matching all of the given filters (ANDed
// together), ordered by the indexed column. It returns nil when nothing
// matches.
//
// A performed or indexed value that does not parse as RFC 3339 falls back to
// the current time.
func GetLabelOps(e Execer, filters ...filter) ([]LabelOp, error) {
	var labelOps []LabelOp
	var conditions []string
	var args []any

	for _, f := range filters {
		conditions = append(conditions, f.Condition())
		args = append(args, f.Arg()...)
	}

	whereClause := ""
	if len(conditions) > 0 {
		whereClause = " where " + strings.Join(conditions, " and ")
	}

	query := fmt.Sprintf(
		`
		select
			id,
			did,
			rkey,
			subject,
			operation,
			operand_key,
			operand_value,
			performed,
			indexed
		from label_ops
		%s
		order by indexed
		`,
		whereClause,
	)

	rows, err := e.Query(query, args...)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	for rows.Next() {
		var labelOp LabelOp
		var performedAt, indexedAt string

		if err := rows.Scan(
			&labelOp.Id,
			&labelOp.Did,
			&labelOp.Rkey,
			&labelOp.Subject,
			&labelOp.Operation,
			&labelOp.OperandKey,
			&labelOp.OperandValue,
			&performedAt,
			&indexedAt,
		); err != nil {
			return nil, err
		}

		labelOp.PerformedAt, err = time.Parse(time.RFC3339, performedAt)
		if err != nil {
			labelOp.PerformedAt = time.Now()
		}

		labelOp.IndexedAt, err = time.Parse(time.RFC3339, indexedAt)
		if err != nil {
			labelOp.IndexedAt = time.Now()
		}

		labelOps = append(labelOps, labelOp)
	}

	// rows.Next returns false both at the end of the result set and on an
	// iteration error; only rows.Err distinguishes the two.
	if err := rows.Err(); err != nil {
		return nil, err
	}

	return labelOps, nil
}
570
// GetLabels computes the current label state of every subject that has label
// ops matching filters. The filters select rows of label_ops, not subjects
// directly.
//
// Ops are grouped by subject. The label definitions they reference are
// loaded once, and each subject's ops are replayed into a fresh LabelState.
// Subjects with no matching ops are absent from the result. When no ops
// match at all, an empty (non-nil) map is returned without querying label
// definitions.
func GetLabels(e Execer, filters ...filter) (map[syntax.ATURI]LabelState, error) {
	ops, err := GetLabelOps(e, filters...)
	if err != nil {
		return nil, err
	}

	results := make(map[syntax.ATURI]LabelState)
	if len(ops) == 0 {
		// nothing to apply, so the definitions lookup (with an empty IN list)
		// would be wasted work
		return results, nil
	}

	// group ops by subject and collect the distinct definitions they reference
	opsBySubject := make(map[syntax.ATURI][]LabelOp)
	labelAtSet := make(map[string]struct{})
	for _, op := range ops {
		opsBySubject[op.Subject] = append(opsBySubject[op.Subject], op)
		labelAtSet[op.OperandKey] = struct{}{}
	}
	labelAts := slices.Collect(maps.Keys(labelAtSet))

	actx, err := NewLabelApplicationCtx(e, FilterIn("at_uri", labelAts))
	if err != nil {
		return nil, err
	}

	// apply label ops for each subject and collect results
	for subject, subjectOps := range opsBySubject {
		state := NewLabelState()
		actx.ApplyLabelOps(state, subjectOps)
		results[subject] = state
	}

	return results, nil
}
607
// set is a string set; only key presence is meaningful.
type set = map[string]struct{}

// LabelState holds, for one subject, the value set of each label. Entries are
// keyed by the operand key of the ops applied to it. A label whose entry is
// nil (for example after its last value was deleted) counts as absent.
type LabelState struct {
	inner map[string]set
}

// NewLabelState returns an empty LabelState ready to have ops applied.
func NewLabelState() LabelState {
	return LabelState{
		inner: make(map[string]set),
	}
}

// Inner exposes the underlying label -> value-set map. The map is shared,
// not copied, so writes through it modify s.
func (s LabelState) Inner() map[string]set {
	return s.inner
}

// ContainsLabel reports whether label l has a non-nil value set.
func (s LabelState) ContainsLabel(l string) bool {
	// a missing key also reads as nil
	return s.inner[l] != nil
}

// GetValSet returns the value set for label l, or nil if it has none.
//
// It uses a value receiver like the other LabelState methods. It is therefore
// callable on non-addressable values, such as entries of the map returned by
// GetLabels. Calls through a *LabelState keep working.
func (s LabelState) GetValSet(l string) set {
	return s.inner[l]
}
637
638type LabelApplicationCtx struct {
639 Defs map[string]*LabelDefinition // labelAt -> labelDef
640}
641
642var (
643 LabelNoOpError = errors.New("no-op")
644)
645
646func NewLabelApplicationCtx(e Execer, filters ...filter) (*LabelApplicationCtx, error) {
647 labels, err := GetLabelDefinitions(e, filters...)
648 if err != nil {
649 return nil, err
650 }
651
652 defs := make(map[string]*LabelDefinition)
653 for _, l := range labels {
654 defs[l.AtUri().String()] = &l
655 }
656
657 return &LabelApplicationCtx{defs}, nil
658}
659
// ApplyLabelOp applies a single op to state in place.
//
// Adding inserts the value into the label's set when its definition allows
// multiple values; otherwise it replaces the set with just that value.
// Deleting removes the value (multi-valued) or clears the whole label
// (single-valued); an emptied label is left as a nil entry.
//
// It returns LabelNoOpError when the op would not change state. It returns an
// error, without touching state, when op.OperandKey has no definition in
// c.Defs. Ops with any other Operation are ignored and return nil.
func (c *LabelApplicationCtx) ApplyLabelOp(state LabelState, op LabelOp) error {
	def := c.Defs[op.OperandKey]
	if def == nil {
		// Without a definition we cannot tell whether the label is multi-valued
		// (e.g. the definition was deleted or was not loaded into this context).
		// Bail out before mutating state; dereferencing def here would panic.
		return fmt.Errorf("no label definition for %q", op.OperandKey)
	}

	key, value := op.OperandKey, op.OperandValue

	switch op.Operation {
	case LabelOperationAdd:
		// adding a value that is already present is a no-op; indexing a nil
		// set is safe and reports the value as absent
		if _, present := state.inner[key][value]; present {
			return LabelNoOpError
		}

		if def.Multiple {
			// append to the set, creating it on first use
			if state.inner[key] == nil {
				state.inner[key] = make(set)
			}
			state.inner[key][value] = struct{}{}
		} else {
			// single-valued: reset to just this value
			state.inner[key] = set{value: struct{}{}}
		}

	case LabelOperationDel:
		// deleting from a missing label, or a value that is not present, is a no-op
		if _, present := state.inner[key][value]; !present {
			return LabelNoOpError
		}

		if def.Multiple {
			// remove from set
			delete(state.inner[key], value)
		} else {
			// reset the entire label
			delete(state.inner, key)
		}

		// if the set becomes empty, store nil: the inverse of add's initialisation
		if len(state.inner[key]) == 0 {
			state.inner[key] = nil
		}
	}

	return nil
}
710
// ApplyLabelOps replays ops against state in SortAt order.
//
// ops is sorted in place. The sort is stable, so ops with equal SortAt keep
// their relative input order. Several ops built from one record share
// PerformedAt; with an unstable sort their add/delete order would be
// arbitrary. Errors from individual ops (such as LabelNoOpError) are
// deliberately ignored so that one inapplicable op does not block the rest.
func (c *LabelApplicationCtx) ApplyLabelOps(state LabelState, ops []LabelOp) {
	slices.SortStableFunc(ops, func(a, b LabelOp) int {
		return a.SortAt().Compare(b.SortAt())
	})

	for _, op := range ops {
		// best-effort replay: see doc comment
		_ = c.ApplyLabelOp(state, op)
	}
}