forked from tangled.org/core
this repo has no description
1package db 2 3import ( 4 "crypto/sha1" 5 "database/sql" 6 "encoding/hex" 7 "errors" 8 "fmt" 9 "log" 10 "maps" 11 "slices" 12 "strings" 13 "time" 14 15 "github.com/bluesky-social/indigo/atproto/syntax" 16 "tangled.sh/tangled.sh/core/api/tangled" 17) 18 19type ConcreteType string 20 21const ( 22 ConcreteTypeNull ConcreteType = "null" 23 ConcreteTypeString ConcreteType = "string" 24 ConcreteTypeInt ConcreteType = "integer" 25 ConcreteTypeBool ConcreteType = "boolean" 26) 27 28type ValueTypeFormat string 29 30const ( 31 ValueTypeFormatAny ValueTypeFormat = "any" 32 ValueTypeFormatDid ValueTypeFormat = "did" 33) 34 35// ValueType represents an atproto lexicon type definition with constraints 36type ValueType struct { 37 Type ConcreteType `json:"type"` 38 Format ValueTypeFormat `json:"format,omitempty"` 39 Enum []string `json:"enum,omitempty"` 40} 41 42func (vt *ValueType) AsRecord() tangled.LabelDefinition_ValueType { 43 return tangled.LabelDefinition_ValueType{ 44 Type: string(vt.Type), 45 Format: string(vt.Format), 46 Enum: vt.Enum, 47 } 48} 49 50func ValueTypeFromRecord(record tangled.LabelDefinition_ValueType) ValueType { 51 return ValueType{ 52 Type: ConcreteType(record.Type), 53 Format: ValueTypeFormat(record.Format), 54 Enum: record.Enum, 55 } 56} 57 58func (vt ValueType) IsConcreteType() bool { 59 return vt.Type == ConcreteTypeNull || 60 vt.Type == ConcreteTypeString || 61 vt.Type == ConcreteTypeInt || 62 vt.Type == ConcreteTypeBool 63} 64 65func (vt ValueType) IsNull() bool { 66 return vt.Type == ConcreteTypeNull 67} 68 69func (vt ValueType) IsString() bool { 70 return vt.Type == ConcreteTypeString 71} 72 73func (vt ValueType) IsInt() bool { 74 return vt.Type == ConcreteTypeInt 75} 76 77func (vt ValueType) IsBool() bool { 78 return vt.Type == ConcreteTypeBool 79} 80 81func (vt ValueType) IsEnumType() bool { 82 return len(vt.Enum) > 0 83} 84 85type LabelDefinition struct { 86 Id int64 87 Did string 88 Rkey string 89 90 Name string 91 ValueType ValueType 92 Scope syntax.NSID 93 Color *string 94 Multiple bool 95 Created time.Time 96} 97 98func (l *LabelDefinition) AtUri() syntax.ATURI { 99 return syntax.ATURI(fmt.Sprintf("at://%s/%s/%s", l.Did, tangled.LabelDefinitionNSID, l.Rkey)) 100} 101 102func (l *LabelDefinition) AsRecord() tangled.LabelDefinition { 103 vt := l.ValueType.AsRecord() 104 return tangled.LabelDefinition{ 105 Name: l.Name, 106 Color: l.Color, 107 CreatedAt: l.Created.Format(time.RFC3339), 108 Multiple: &l.Multiple, 109 Scope: l.Scope.String(), 110 ValueType: &vt, 111 } 112} 113 114// random color for a given seed 115func randomColor(seed string) string { 116 hash := sha1.Sum([]byte(seed)) 117 hexStr := hex.EncodeToString(hash[:]) 118 r := hexStr[0:2] 119 g := hexStr[2:4] 120 b := hexStr[4:6] 121 122 return fmt.Sprintf("#%s%s%s", r, g, b) 123} 124 125func (ld LabelDefinition) GetColor() string { 126 if ld.Color == nil { 127 seed := fmt.Sprintf("%d:%s:%s", ld.Id, ld.Did, ld.Rkey) 128 color := randomColor(seed) 129 return color 130 } 131 132 return *ld.Color 133} 134 135func LabelDefinitionFromRecord(did, rkey string, record tangled.LabelDefinition) LabelDefinition { 136 created, err := time.Parse(time.RFC3339, record.CreatedAt) 137 if err != nil { 138 created = time.Now() 139 } 140 141 multiple := false 142 if record.Multiple != nil { 143 multiple = *record.Multiple 144 } 145 146 var vt ValueType 147 if record.ValueType != nil { 148 vt = ValueTypeFromRecord(*record.ValueType) 149 } 150 151 return LabelDefinition{ 152 Did: did, 153 Rkey: rkey, 154 155 Name: record.Name, 156 ValueType: vt, 157 Scope: syntax.NSID(record.Scope), 158 Color: record.Color, 159 Multiple: multiple, 160 Created: created, 161 } 162} 163 164func DeleteLabelDefinition(e Execer, filters ...filter) error { 165 var conditions []string 166 var args []any 167 for _, filter := range filters { 168 conditions = append(conditions, filter.Condition()) 169 args = append(args, filter.Arg()...) 170 } 171 whereClause := "" 172 if conditions != nil { 173 whereClause = " where " + strings.Join(conditions, " and ") 174 } 175 query := fmt.Sprintf(`delete from label_definitions %s`, whereClause) 176 _, err := e.Exec(query, args...) 177 return err 178} 179 180func AddLabelDefinition(e Execer, l *LabelDefinition) (int64, error) { 181 result, err := e.Exec( 182 `insert into label_definitions ( 183 did, 184 rkey, 185 name, 186 value_type, 187 value_format, 188 value_enum, 189 scope, 190 color, 191 multiple, 192 created 193 ) 194 values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) 195 on conflict(did, rkey) do update set 196 name = excluded.name, 197 scope = excluded.scope, 198 color = excluded.color, 199 multiple = excluded.multiple`, 200 l.Did, 201 l.Rkey, 202 l.Name, 203 l.ValueType.Type, 204 l.ValueType.Format, 205 strings.Join(l.ValueType.Enum, ","), 206 l.Scope.String(), 207 l.Color, 208 l.Multiple, 209 l.Created.Format(time.RFC3339), 210 time.Now().Format(time.RFC3339), 211 ) 212 if err != nil { 213 return 0, err 214 } 215 216 id, err := result.LastInsertId() 217 if err != nil { 218 return 0, err 219 } 220 221 l.Id = id 222 223 return id, nil 224} 225 226func GetLabelDefinitions(e Execer, filters ...filter) ([]LabelDefinition, error) { 227 var labelDefinitions []LabelDefinition 228 var conditions []string 229 var args []any 230 231 for _, filter := range filters { 232 conditions = append(conditions, filter.Condition()) 233 args = append(args, filter.Arg()...) 234 } 235 236 whereClause := "" 237 if conditions != nil { 238 whereClause = " where " + strings.Join(conditions, " and ") 239 } 240 241 query := fmt.Sprintf( 242 ` 243 select 244 id, 245 did, 246 rkey, 247 name, 248 value_type, 249 value_format, 250 value_enum, 251 scope, 252 color, 253 multiple, 254 created 255 from label_definitions 256 %s 257 order by created 258 `, 259 whereClause, 260 ) 261 262 rows, err := e.Query(query, args...) 263 if err != nil { 264 return nil, err 265 } 266 defer rows.Close() 267 268 for rows.Next() { 269 var labelDefinition LabelDefinition 270 var createdAt, enumVariants string 271 var color sql.Null[string] 272 var multiple int 273 274 if err := rows.Scan( 275 &labelDefinition.Id, 276 &labelDefinition.Did, 277 &labelDefinition.Rkey, 278 &labelDefinition.Name, 279 &labelDefinition.ValueType.Type, 280 &labelDefinition.ValueType.Format, 281 &enumVariants, 282 &labelDefinition.Scope, 283 &color, 284 &multiple, 285 &createdAt, 286 ); err != nil { 287 return nil, err 288 } 289 290 labelDefinition.Created, err = time.Parse(time.RFC3339, createdAt) 291 if err != nil { 292 labelDefinition.Created = time.Now() 293 } 294 295 if color.Valid { 296 labelDefinition.Color = &color.V 297 } 298 299 if multiple != 0 { 300 labelDefinition.Multiple = true 301 } 302 303 if enumVariants != "" { 304 labelDefinition.ValueType.Enum = strings.Split(enumVariants, ",") 305 } 306 307 labelDefinitions = append(labelDefinitions, labelDefinition) 308 } 309 310 return labelDefinitions, nil 311} 312 313// helper to get exactly one label def 314func GetLabelDefinition(e Execer, filters ...filter) (*LabelDefinition, error) { 315 labels, err := GetLabelDefinitions(e, filters...) 316 if err != nil { 317 return nil, err 318 } 319 320 if labels == nil { 321 return nil, sql.ErrNoRows 322 } 323 324 if len(labels) != 1 { 325 return nil, fmt.Errorf("too many rows returned") 326 } 327 328 return &labels[0], nil 329} 330 331type LabelOp struct { 332 Id int64 333 Did string 334 Rkey string 335 Subject syntax.ATURI 336 Operation LabelOperation 337 OperandKey string 338 OperandValue string 339 PerformedAt time.Time 340 IndexedAt time.Time 341} 342 343func (l LabelOp) SortAt() time.Time { 344 createdAt := l.PerformedAt 345 indexedAt := l.IndexedAt 346 347 // if we don't have an indexedat, fall back to now 348 if indexedAt.IsZero() { 349 indexedAt = time.Now() 350 } 351 352 // if createdat is invalid (before epoch), treat as null -> return zero time 353 if createdAt.Before(time.UnixMicro(0)) { 354 return time.Time{} 355 } 356 357 // if createdat is <= indexedat, use createdat 358 if createdAt.Before(indexedAt) || createdAt.Equal(indexedAt) { 359 return createdAt 360 } 361 362 // otherwise, createdat is in the future relative to indexedat -> use indexedat 363 return indexedAt 364} 365 366type LabelOperation string 367 368const ( 369 LabelOperationAdd LabelOperation = "add" 370 LabelOperationDel LabelOperation = "del" 371) 372 373// a record can create multiple label ops 374func LabelOpsFromRecord(did, rkey string, record tangled.LabelOp) []LabelOp { 375 performed, err := time.Parse(time.RFC3339, record.PerformedAt) 376 if err != nil { 377 performed = time.Now() 378 } 379 380 mkOp := func(operand *tangled.LabelOp_Operand) LabelOp { 381 return LabelOp{ 382 Did: did, 383 Rkey: rkey, 384 Subject: syntax.ATURI(record.Subject), 385 OperandKey: operand.Key, 386 OperandValue: operand.Value, 387 PerformedAt: performed, 388 } 389 } 390 391 var ops []LabelOp 392 for _, o := range record.Add { 393 if o != nil { 394 op := mkOp(o) 395 op.Operation = LabelOperationAdd 396 ops = append(ops, op) 397 } 398 } 399 for _, o := range record.Delete { 400 if o != nil { 401 op := mkOp(o) 402 op.Operation = LabelOperationDel 403 ops = append(ops, op) 404 } 405 } 406 407 return ops 408} 409 410func LabelOpsAsRecord(ops []LabelOp) tangled.LabelOp { 411 if len(ops) == 0 { 412 return tangled.LabelOp{} 413 } 414 415 // use the first operation to establish common fields 416 first := ops[0] 417 record := tangled.LabelOp{ 418 Subject: string(first.Subject), 419 PerformedAt: first.PerformedAt.Format(time.RFC3339), 420 } 421 422 var addOperands []*tangled.LabelOp_Operand 423 var deleteOperands []*tangled.LabelOp_Operand 424 425 for _, op := range ops { 426 operand := &tangled.LabelOp_Operand{ 427 Key: op.OperandKey, 428 Value: op.OperandValue, 429 } 430 431 switch op.Operation { 432 case LabelOperationAdd: 433 addOperands = append(addOperands, operand) 434 case LabelOperationDel: 435 deleteOperands = append(deleteOperands, operand) 436 default: 437 return tangled.LabelOp{} 438 } 439 } 440 441 record.Add = addOperands 442 record.Delete = deleteOperands 443 444 return record 445} 446 447func AddLabelOp(e Execer, l *LabelOp) (int64, error) { 448 now := time.Now() 449 result, err := e.Exec( 450 `insert into label_ops ( 451 did, 452 rkey, 453 subject, 454 operation, 455 operand_key, 456 operand_value, 457 performed, 458 indexed 459 ) 460 values (?, ?, ?, ?, ?, ?, ?, ?) 461 on conflict(did, rkey, subject, operand_key, operand_value) do update set 462 operation = excluded.operation, 463 operand_value = excluded.operand_value, 464 performed = excluded.performed, 465 indexed = excluded.indexed`, 466 l.Did, 467 l.Rkey, 468 l.Subject.String(), 469 string(l.Operation), 470 l.OperandKey, 471 l.OperandValue, 472 l.PerformedAt.Format(time.RFC3339), 473 now.Format(time.RFC3339), 474 ) 475 if err != nil { 476 return 0, err 477 } 478 479 id, err := result.LastInsertId() 480 if err != nil { 481 return 0, err 482 } 483 484 l.Id = id 485 l.IndexedAt = now 486 487 return id, nil 488} 489 490func GetLabelOps(e Execer, filters ...filter) ([]LabelOp, error) { 491 var labelOps []LabelOp 492 var conditions []string 493 var args []any 494 495 for _, filter := range filters { 496 conditions = append(conditions, filter.Condition()) 497 args = append(args, filter.Arg()...) 498 } 499 500 whereClause := "" 501 if conditions != nil { 502 whereClause = " where " + strings.Join(conditions, " and ") 503 } 504 505 query := fmt.Sprintf( 506 ` 507 select 508 id, 509 did, 510 rkey, 511 subject, 512 operation, 513 operand_key, 514 operand_value, 515 performed, 516 indexed 517 from label_ops 518 %s 519 order by indexed 520 `, 521 whereClause, 522 ) 523 524 rows, err := e.Query(query, args...) 525 if err != nil { 526 return nil, err 527 } 528 defer rows.Close() 529 530 for rows.Next() { 531 var labelOp LabelOp 532 var performedAt, indexedAt string 533 534 if err := rows.Scan( 535 &labelOp.Id, 536 &labelOp.Did, 537 &labelOp.Rkey, 538 &labelOp.Subject, 539 &labelOp.Operation, 540 &labelOp.OperandKey, 541 &labelOp.OperandValue, 542 &performedAt, 543 &indexedAt, 544 ); err != nil { 545 return nil, err 546 } 547 548 labelOp.PerformedAt, err = time.Parse(time.RFC3339, performedAt) 549 if err != nil { 550 labelOp.PerformedAt = time.Now() 551 } 552 553 labelOp.IndexedAt, err = time.Parse(time.RFC3339, indexedAt) 554 if err != nil { 555 labelOp.IndexedAt = time.Now() 556 } 557 558 labelOps = append(labelOps, labelOp) 559 } 560 561 return labelOps, nil 562} 563 564// get labels for a given list of subject URIs 565func GetLabels(e Execer, filters ...filter) (map[syntax.ATURI]LabelState, error) { 566 ops, err := GetLabelOps(e, filters...) 567 if err != nil { 568 return nil, err 569 } 570 571 // group ops by subject 572 opsBySubject := make(map[syntax.ATURI][]LabelOp) 573 for _, op := range ops { 574 subject := syntax.ATURI(op.Subject) 575 opsBySubject[subject] = append(opsBySubject[subject], op) 576 } 577 578 // get all unique labelats for creating the context 579 labelAtSet := make(map[string]bool) 580 for _, op := range ops { 581 labelAtSet[op.OperandKey] = true 582 } 583 labelAts := slices.Collect(maps.Keys(labelAtSet)) 584 585 actx, err := NewLabelApplicationCtx(e, FilterIn("at_uri", labelAts)) 586 if err != nil { 587 return nil, err 588 } 589 590 // apply label ops for each subject and collect results 591 results := make(map[syntax.ATURI]LabelState) 592 for subject, subjectOps := range opsBySubject { 593 state := NewLabelState() 594 actx.ApplyLabelOps(state, subjectOps) 595 results[subject] = state 596 } 597 598 log.Println("results for get labels", "s", results) 599 600 return results, nil 601} 602 603type set = map[string]struct{} 604 605type LabelState struct { 606 inner map[string]set 607} 608 609func NewLabelState() LabelState { 610 return LabelState{ 611 inner: make(map[string]set), 612 } 613} 614 615func (s LabelState) Inner() map[string]set { 616 return s.inner 617} 618 619func (s LabelState) ContainsLabel(l string) bool { 620 if valset, exists := s.inner[l]; exists { 621 if valset != nil { 622 return true 623 } 624 } 625 626 return false 627} 628 629func (s *LabelState) GetValSet(l string) set { 630 return s.inner[l] 631} 632 633type LabelApplicationCtx struct { 634 defs map[string]*LabelDefinition // labelAt -> labelDef 635} 636 637var ( 638 LabelNoOpError = errors.New("no-op") 639) 640 641func NewLabelApplicationCtx(e Execer, filters ...filter) (*LabelApplicationCtx, error) { 642 labels, err := GetLabelDefinitions(e, filters...) 643 if err != nil { 644 return nil, err 645 } 646 647 defs := make(map[string]*LabelDefinition) 648 for _, l := range labels { 649 defs[l.AtUri().String()] = &l 650 } 651 652 return &LabelApplicationCtx{defs}, nil 653} 654 655func (c *LabelApplicationCtx) ApplyLabelOp(state LabelState, op LabelOp) error { 656 def := c.defs[op.OperandKey] 657 658 switch op.Operation { 659 case LabelOperationAdd: 660 // if valueset is empty, init it 661 if state.inner[op.OperandKey] == nil { 662 state.inner[op.OperandKey] = make(set) 663 } 664 665 // if valueset is populated & this val alr exists, this labelop is a noop 666 if valueSet, exists := state.inner[op.OperandKey]; exists { 667 if _, exists = valueSet[op.OperandValue]; exists { 668 return LabelNoOpError 669 } 670 } 671 672 if def.Multiple { 673 // append to set 674 state.inner[op.OperandKey][op.OperandValue] = struct{}{} 675 } else { 676 // reset to just this value 677 state.inner[op.OperandKey] = set{op.OperandValue: struct{}{}} 678 } 679 680 case LabelOperationDel: 681 // if label DNE, then deletion is a no-op 682 if valueSet, exists := state.inner[op.OperandKey]; !exists { 683 return LabelNoOpError 684 } else if _, exists = valueSet[op.OperandValue]; !exists { // if value DNE, then deletion is no-op 685 return LabelNoOpError 686 } 687 688 if def.Multiple { 689 // remove from set 690 delete(state.inner[op.OperandKey], op.OperandValue) 691 } else { 692 // reset the entire label 693 delete(state.inner, op.OperandKey) 694 } 695 696 // if the map becomes empty, then set it to nil, this is just the inverse of add 697 if len(state.inner[op.OperandKey]) == 0 { 698 state.inner[op.OperandKey] = nil 699 } 700 701 } 702 703 return nil 704} 705 706func (c *LabelApplicationCtx) ApplyLabelOps(state LabelState, ops []LabelOp) { 707 // sort label ops in sort order first 708 slices.SortFunc(ops, func(a, b LabelOp) int { 709 return a.SortAt().Compare(b.SortAt()) 710 }) 711 712 // apply ops in sequence 713 for _, o := range ops { 714 _ = c.ApplyLabelOp(state, o) 715 } 716} 717 718type Label struct { 719 def *LabelDefinition 720 val set 721}