1package db
2
3import (
4 "crypto/sha1"
5 "database/sql"
6 "encoding/hex"
7 "errors"
8 "fmt"
9 "log"
10 "maps"
11 "slices"
12 "strings"
13 "time"
14
15 "github.com/bluesky-social/indigo/atproto/syntax"
16 "tangled.sh/tangled.sh/core/api/tangled"
17)
18
// ConcreteType names the primitive type a label value is encoded as; the
// values match atproto lexicon type names.
type ConcreteType string

// The concrete types a label value may take.
const (
	ConcreteTypeNull   ConcreteType = "null"
	ConcreteTypeString ConcreteType = "string"
	ConcreteTypeInt    ConcreteType = "integer"
	ConcreteTypeBool   ConcreteType = "boolean"
)

// ValueTypeFormat is an optional format constraint carried alongside a
// ValueType.
type ValueTypeFormat string

// The known value formats.
const (
	ValueTypeFormatAny ValueTypeFormat = "any"
	ValueTypeFormatDid ValueTypeFormat = "did"
)

// ValueType represents an atproto lexicon type definition with constraints:
// a concrete type, an optional format, and an optional list of allowed enum
// variants.
type ValueType struct {
	Type   ConcreteType    `json:"type"`
	Format ValueTypeFormat `json:"format,omitempty"`
	Enum   []string        `json:"enum,omitempty"`
}
41
42func (vt *ValueType) AsRecord() tangled.LabelDefinition_ValueType {
43 return tangled.LabelDefinition_ValueType{
44 Type: string(vt.Type),
45 Format: string(vt.Format),
46 Enum: vt.Enum,
47 }
48}
49
50func ValueTypeFromRecord(record tangled.LabelDefinition_ValueType) ValueType {
51 return ValueType{
52 Type: ConcreteType(record.Type),
53 Format: ValueTypeFormat(record.Format),
54 Enum: record.Enum,
55 }
56}
57
58func (vt ValueType) IsConcreteType() bool {
59 return vt.Type == ConcreteTypeNull ||
60 vt.Type == ConcreteTypeString ||
61 vt.Type == ConcreteTypeInt ||
62 vt.Type == ConcreteTypeBool
63}
64
65func (vt ValueType) IsNull() bool {
66 return vt.Type == ConcreteTypeNull
67}
68
69func (vt ValueType) IsString() bool {
70 return vt.Type == ConcreteTypeString
71}
72
73func (vt ValueType) IsInt() bool {
74 return vt.Type == ConcreteTypeInt
75}
76
77func (vt ValueType) IsBool() bool {
78 return vt.Type == ConcreteTypeBool
79}
80
81func (vt ValueType) IsEnumType() bool {
82 return len(vt.Enum) > 0
83}
84
85type LabelDefinition struct {
86 Id int64
87 Did string
88 Rkey string
89
90 Name string
91 ValueType ValueType
92 Scope syntax.NSID
93 Color *string
94 Multiple bool
95 Created time.Time
96}
97
98func (l *LabelDefinition) AtUri() syntax.ATURI {
99 return syntax.ATURI(fmt.Sprintf("at://%s/%s/%s", l.Did, tangled.LabelDefinitionNSID, l.Rkey))
100}
101
102func (l *LabelDefinition) AsRecord() tangled.LabelDefinition {
103 vt := l.ValueType.AsRecord()
104 return tangled.LabelDefinition{
105 Name: l.Name,
106 Color: l.Color,
107 CreatedAt: l.Created.Format(time.RFC3339),
108 Multiple: &l.Multiple,
109 Scope: l.Scope.String(),
110 ValueType: &vt,
111 }
112}
113
// randomColor derives a deterministic "#rrggbb" color from seed. The color is
// the first three bytes of the seed's SHA-1 digest, hex-encoded in lowercase.
func randomColor(seed string) string {
	digest := sha1.Sum([]byte(seed))
	return "#" + hex.EncodeToString(digest[:3])
}
124
125func (ld LabelDefinition) GetColor() string {
126 if ld.Color == nil {
127 seed := fmt.Sprintf("%d:%s:%s", ld.Id, ld.Did, ld.Rkey)
128 color := randomColor(seed)
129 return color
130 }
131
132 return *ld.Color
133}
134
135func LabelDefinitionFromRecord(did, rkey string, record tangled.LabelDefinition) LabelDefinition {
136 created, err := time.Parse(time.RFC3339, record.CreatedAt)
137 if err != nil {
138 created = time.Now()
139 }
140
141 multiple := false
142 if record.Multiple != nil {
143 multiple = *record.Multiple
144 }
145
146 var vt ValueType
147 if record.ValueType != nil {
148 vt = ValueTypeFromRecord(*record.ValueType)
149 }
150
151 return LabelDefinition{
152 Did: did,
153 Rkey: rkey,
154
155 Name: record.Name,
156 ValueType: vt,
157 Scope: syntax.NSID(record.Scope),
158 Color: record.Color,
159 Multiple: multiple,
160 Created: created,
161 }
162}
163
164func DeleteLabelDefinition(e Execer, filters ...filter) error {
165 var conditions []string
166 var args []any
167 for _, filter := range filters {
168 conditions = append(conditions, filter.Condition())
169 args = append(args, filter.Arg()...)
170 }
171 whereClause := ""
172 if conditions != nil {
173 whereClause = " where " + strings.Join(conditions, " and ")
174 }
175 query := fmt.Sprintf(`delete from label_definitions %s`, whereClause)
176 _, err := e.Exec(query, args...)
177 return err
178}
179
180func AddLabelDefinition(e Execer, l *LabelDefinition) (int64, error) {
181 result, err := e.Exec(
182 `insert into label_definitions (
183 did,
184 rkey,
185 name,
186 value_type,
187 value_format,
188 value_enum,
189 scope,
190 color,
191 multiple,
192 created
193 )
194 values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
195 on conflict(did, rkey) do update set
196 name = excluded.name,
197 scope = excluded.scope,
198 color = excluded.color,
199 multiple = excluded.multiple`,
200 l.Did,
201 l.Rkey,
202 l.Name,
203 l.ValueType.Type,
204 l.ValueType.Format,
205 strings.Join(l.ValueType.Enum, ","),
206 l.Scope.String(),
207 l.Color,
208 l.Multiple,
209 l.Created.Format(time.RFC3339),
210 time.Now().Format(time.RFC3339),
211 )
212 if err != nil {
213 return 0, err
214 }
215
216 id, err := result.LastInsertId()
217 if err != nil {
218 return 0, err
219 }
220
221 l.Id = id
222
223 return id, nil
224}
225
226func GetLabelDefinitions(e Execer, filters ...filter) ([]LabelDefinition, error) {
227 var labelDefinitions []LabelDefinition
228 var conditions []string
229 var args []any
230
231 for _, filter := range filters {
232 conditions = append(conditions, filter.Condition())
233 args = append(args, filter.Arg()...)
234 }
235
236 whereClause := ""
237 if conditions != nil {
238 whereClause = " where " + strings.Join(conditions, " and ")
239 }
240
241 query := fmt.Sprintf(
242 `
243 select
244 id,
245 did,
246 rkey,
247 name,
248 value_type,
249 value_format,
250 value_enum,
251 scope,
252 color,
253 multiple,
254 created
255 from label_definitions
256 %s
257 order by created
258 `,
259 whereClause,
260 )
261
262 rows, err := e.Query(query, args...)
263 if err != nil {
264 return nil, err
265 }
266 defer rows.Close()
267
268 for rows.Next() {
269 var labelDefinition LabelDefinition
270 var createdAt, enumVariants string
271 var color sql.Null[string]
272 var multiple int
273
274 if err := rows.Scan(
275 &labelDefinition.Id,
276 &labelDefinition.Did,
277 &labelDefinition.Rkey,
278 &labelDefinition.Name,
279 &labelDefinition.ValueType.Type,
280 &labelDefinition.ValueType.Format,
281 &enumVariants,
282 &labelDefinition.Scope,
283 &color,
284 &multiple,
285 &createdAt,
286 ); err != nil {
287 return nil, err
288 }
289
290 labelDefinition.Created, err = time.Parse(time.RFC3339, createdAt)
291 if err != nil {
292 labelDefinition.Created = time.Now()
293 }
294
295 if color.Valid {
296 labelDefinition.Color = &color.V
297 }
298
299 if multiple != 0 {
300 labelDefinition.Multiple = true
301 }
302
303 if enumVariants != "" {
304 labelDefinition.ValueType.Enum = strings.Split(enumVariants, ",")
305 }
306
307 labelDefinitions = append(labelDefinitions, labelDefinition)
308 }
309
310 return labelDefinitions, nil
311}
312
313// helper to get exactly one label def
314func GetLabelDefinition(e Execer, filters ...filter) (*LabelDefinition, error) {
315 labels, err := GetLabelDefinitions(e, filters...)
316 if err != nil {
317 return nil, err
318 }
319
320 if labels == nil {
321 return nil, sql.ErrNoRows
322 }
323
324 if len(labels) != 1 {
325 return nil, fmt.Errorf("too many rows returned")
326 }
327
328 return &labels[0], nil
329}
330
331type LabelOp struct {
332 Id int64
333 Did string
334 Rkey string
335 Subject syntax.ATURI
336 Operation LabelOperation
337 OperandKey string
338 OperandValue string
339 PerformedAt time.Time
340 IndexedAt time.Time
341}
342
343func (l LabelOp) SortAt() time.Time {
344 createdAt := l.PerformedAt
345 indexedAt := l.IndexedAt
346
347 // if we don't have an indexedat, fall back to now
348 if indexedAt.IsZero() {
349 indexedAt = time.Now()
350 }
351
352 // if createdat is invalid (before epoch), treat as null -> return zero time
353 if createdAt.Before(time.UnixMicro(0)) {
354 return time.Time{}
355 }
356
357 // if createdat is <= indexedat, use createdat
358 if createdAt.Before(indexedAt) || createdAt.Equal(indexedAt) {
359 return createdAt
360 }
361
362 // otherwise, createdat is in the future relative to indexedat -> use indexedat
363 return indexedAt
364}
365
// LabelOperation says whether a LabelOp adds or removes a label value.
type LabelOperation string

const (
	// LabelOperationAdd applies OperandValue to the label named by OperandKey.
	LabelOperationAdd LabelOperation = "add"
	// LabelOperationDel removes OperandValue from the label named by OperandKey.
	LabelOperationDel LabelOperation = "del"
)
372
373// a record can create multiple label ops
374func LabelOpsFromRecord(did, rkey string, record tangled.LabelOp) []LabelOp {
375 performed, err := time.Parse(time.RFC3339, record.PerformedAt)
376 if err != nil {
377 performed = time.Now()
378 }
379
380 mkOp := func(operand *tangled.LabelOp_Operand) LabelOp {
381 return LabelOp{
382 Did: did,
383 Rkey: rkey,
384 Subject: syntax.ATURI(record.Subject),
385 OperandKey: operand.Key,
386 OperandValue: operand.Value,
387 PerformedAt: performed,
388 }
389 }
390
391 var ops []LabelOp
392 for _, o := range record.Add {
393 if o != nil {
394 op := mkOp(o)
395 op.Operation = LabelOperationAdd
396 ops = append(ops, op)
397 }
398 }
399 for _, o := range record.Delete {
400 if o != nil {
401 op := mkOp(o)
402 op.Operation = LabelOperationDel
403 ops = append(ops, op)
404 }
405 }
406
407 return ops
408}
409
410func LabelOpsAsRecord(ops []LabelOp) tangled.LabelOp {
411 if len(ops) == 0 {
412 return tangled.LabelOp{}
413 }
414
415 // use the first operation to establish common fields
416 first := ops[0]
417 record := tangled.LabelOp{
418 Subject: string(first.Subject),
419 PerformedAt: first.PerformedAt.Format(time.RFC3339),
420 }
421
422 var addOperands []*tangled.LabelOp_Operand
423 var deleteOperands []*tangled.LabelOp_Operand
424
425 for _, op := range ops {
426 operand := &tangled.LabelOp_Operand{
427 Key: op.OperandKey,
428 Value: op.OperandValue,
429 }
430
431 switch op.Operation {
432 case LabelOperationAdd:
433 addOperands = append(addOperands, operand)
434 case LabelOperationDel:
435 deleteOperands = append(deleteOperands, operand)
436 default:
437 return tangled.LabelOp{}
438 }
439 }
440
441 record.Add = addOperands
442 record.Delete = deleteOperands
443
444 return record
445}
446
447func AddLabelOp(e Execer, l *LabelOp) (int64, error) {
448 now := time.Now()
449 result, err := e.Exec(
450 `insert into label_ops (
451 did,
452 rkey,
453 subject,
454 operation,
455 operand_key,
456 operand_value,
457 performed,
458 indexed
459 )
460 values (?, ?, ?, ?, ?, ?, ?, ?)
461 on conflict(did, rkey, subject, operand_key, operand_value) do update set
462 operation = excluded.operation,
463 operand_value = excluded.operand_value,
464 performed = excluded.performed,
465 indexed = excluded.indexed`,
466 l.Did,
467 l.Rkey,
468 l.Subject.String(),
469 string(l.Operation),
470 l.OperandKey,
471 l.OperandValue,
472 l.PerformedAt.Format(time.RFC3339),
473 now.Format(time.RFC3339),
474 )
475 if err != nil {
476 return 0, err
477 }
478
479 id, err := result.LastInsertId()
480 if err != nil {
481 return 0, err
482 }
483
484 l.Id = id
485 l.IndexedAt = now
486
487 return id, nil
488}
489
490func GetLabelOps(e Execer, filters ...filter) ([]LabelOp, error) {
491 var labelOps []LabelOp
492 var conditions []string
493 var args []any
494
495 for _, filter := range filters {
496 conditions = append(conditions, filter.Condition())
497 args = append(args, filter.Arg()...)
498 }
499
500 whereClause := ""
501 if conditions != nil {
502 whereClause = " where " + strings.Join(conditions, " and ")
503 }
504
505 query := fmt.Sprintf(
506 `
507 select
508 id,
509 did,
510 rkey,
511 subject,
512 operation,
513 operand_key,
514 operand_value,
515 performed,
516 indexed
517 from label_ops
518 %s
519 order by indexed
520 `,
521 whereClause,
522 )
523
524 rows, err := e.Query(query, args...)
525 if err != nil {
526 return nil, err
527 }
528 defer rows.Close()
529
530 for rows.Next() {
531 var labelOp LabelOp
532 var performedAt, indexedAt string
533
534 if err := rows.Scan(
535 &labelOp.Id,
536 &labelOp.Did,
537 &labelOp.Rkey,
538 &labelOp.Subject,
539 &labelOp.Operation,
540 &labelOp.OperandKey,
541 &labelOp.OperandValue,
542 &performedAt,
543 &indexedAt,
544 ); err != nil {
545 return nil, err
546 }
547
548 labelOp.PerformedAt, err = time.Parse(time.RFC3339, performedAt)
549 if err != nil {
550 labelOp.PerformedAt = time.Now()
551 }
552
553 labelOp.IndexedAt, err = time.Parse(time.RFC3339, indexedAt)
554 if err != nil {
555 labelOp.IndexedAt = time.Now()
556 }
557
558 labelOps = append(labelOps, labelOp)
559 }
560
561 return labelOps, nil
562}
563
564// get labels for a given list of subject URIs
565func GetLabels(e Execer, filters ...filter) (map[syntax.ATURI]LabelState, error) {
566 ops, err := GetLabelOps(e, filters...)
567 if err != nil {
568 return nil, err
569 }
570
571 // group ops by subject
572 opsBySubject := make(map[syntax.ATURI][]LabelOp)
573 for _, op := range ops {
574 subject := syntax.ATURI(op.Subject)
575 opsBySubject[subject] = append(opsBySubject[subject], op)
576 }
577
578 // get all unique labelats for creating the context
579 labelAtSet := make(map[string]bool)
580 for _, op := range ops {
581 labelAtSet[op.OperandKey] = true
582 }
583 labelAts := slices.Collect(maps.Keys(labelAtSet))
584
585 actx, err := NewLabelApplicationCtx(e, FilterIn("at_uri", labelAts))
586 if err != nil {
587 return nil, err
588 }
589
590 // apply label ops for each subject and collect results
591 results := make(map[syntax.ATURI]LabelState)
592 for subject, subjectOps := range opsBySubject {
593 state := NewLabelState()
594 actx.ApplyLabelOps(state, subjectOps)
595 results[subject] = state
596 }
597
598 log.Println("results for get labels", "s", results)
599
600 return results, nil
601}
602
// set is a set of strings; the empty-struct values occupy no space.
type set = map[string]struct{}

// LabelState is the folded label state of one subject. It maps a label
// (keyed by definition at-uri) to the set of values currently applied. A
// label mapped to a nil set counts as absent.
type LabelState struct {
	inner map[string]set
}

// NewLabelState returns an empty, ready-to-use LabelState.
func NewLabelState() LabelState {
	return LabelState{inner: map[string]set{}}
}

// Inner exposes the underlying label -> values map itself, not a copy.
func (s LabelState) Inner() map[string]set { return s.inner }

// ContainsLabel reports whether label l is present with a non-nil value set.
func (s LabelState) ContainsLabel(l string) bool {
	// A missing key reads as nil, so one lookup covers both "never added"
	// and "present but cleared".
	return s.inner[l] != nil
}

// GetValSet returns the value set stored for label l, or nil if there is none.
func (s *LabelState) GetValSet(l string) set {
	return s.inner[l]
}
632
633type LabelApplicationCtx struct {
634 defs map[string]*LabelDefinition // labelAt -> labelDef
635}
636
637var (
638 LabelNoOpError = errors.New("no-op")
639)
640
641func NewLabelApplicationCtx(e Execer, filters ...filter) (*LabelApplicationCtx, error) {
642 labels, err := GetLabelDefinitions(e, filters...)
643 if err != nil {
644 return nil, err
645 }
646
647 defs := make(map[string]*LabelDefinition)
648 for _, l := range labels {
649 defs[l.AtUri().String()] = &l
650 }
651
652 return &LabelApplicationCtx{defs}, nil
653}
654
655func (c *LabelApplicationCtx) ApplyLabelOp(state LabelState, op LabelOp) error {
656 def := c.defs[op.OperandKey]
657
658 switch op.Operation {
659 case LabelOperationAdd:
660 // if valueset is empty, init it
661 if state.inner[op.OperandKey] == nil {
662 state.inner[op.OperandKey] = make(set)
663 }
664
665 // if valueset is populated & this val alr exists, this labelop is a noop
666 if valueSet, exists := state.inner[op.OperandKey]; exists {
667 if _, exists = valueSet[op.OperandValue]; exists {
668 return LabelNoOpError
669 }
670 }
671
672 if def.Multiple {
673 // append to set
674 state.inner[op.OperandKey][op.OperandValue] = struct{}{}
675 } else {
676 // reset to just this value
677 state.inner[op.OperandKey] = set{op.OperandValue: struct{}{}}
678 }
679
680 case LabelOperationDel:
681 // if label DNE, then deletion is a no-op
682 if valueSet, exists := state.inner[op.OperandKey]; !exists {
683 return LabelNoOpError
684 } else if _, exists = valueSet[op.OperandValue]; !exists { // if value DNE, then deletion is no-op
685 return LabelNoOpError
686 }
687
688 if def.Multiple {
689 // remove from set
690 delete(state.inner[op.OperandKey], op.OperandValue)
691 } else {
692 // reset the entire label
693 delete(state.inner, op.OperandKey)
694 }
695
696 // if the map becomes empty, then set it to nil, this is just the inverse of add
697 if len(state.inner[op.OperandKey]) == 0 {
698 state.inner[op.OperandKey] = nil
699 }
700
701 }
702
703 return nil
704}
705
706func (c *LabelApplicationCtx) ApplyLabelOps(state LabelState, ops []LabelOp) {
707 // sort label ops in sort order first
708 slices.SortFunc(ops, func(a, b LabelOp) int {
709 return a.SortAt().Compare(b.SortAt())
710 })
711
712 // apply ops in sequence
713 for _, o := range ops {
714 _ = c.ApplyLabelOp(state, o)
715 }
716}
717
718type Label struct {
719 def *LabelDefinition
720 val set
721}