1package main
2
import (
	"context"
	"flag"
	"fmt"
	"log"
	"net/http"
	"net/url"
	"os"
	"os/exec"
	"sort"
	"strings"
	"time"

	securejoin "github.com/cyphar/filepath-securejoin"
	"tangled.sh/tangled.sh/core/appview"
)
18
// Process-wide runtime state. logger and logFile are populated by initLogger;
// clientIP is filled in by main from the SSH_CONNECTION environment variable.
var (
	logger   *log.Logger
	logFile  *os.File
	clientIP string
)

// Command-line flags. -user must be non-empty (main rejects the connection
// otherwise) and is the identity sent to the push-allowed check; the others
// carry deployment defaults.
var (
	incomingUser = flag.String("user", "", "Allowed git user")
	baseDirFlag  = flag.String("base-dir", "/home/git", "Base directory for git repositories")
	logPathFlag  = flag.String("log-path", "/var/log/git-wrapper.log", "Path to log file")
	endpoint     = flag.String("internal-api", "http://localhost:5444", "Internal API endpoint")
)
30
31func main() {
32 flag.Parse()
33
34 defer cleanup()
35 initLogger()
36
37 // Get client IP from SSH environment
38 if connInfo := os.Getenv("SSH_CONNECTION"); connInfo != "" {
39 parts := strings.Fields(connInfo)
40 if len(parts) > 0 {
41 clientIP = parts[0]
42 }
43 }
44
45 if *incomingUser == "" {
46 exitWithLog("access denied: no user specified")
47 }
48
49 sshCommand := os.Getenv("SSH_ORIGINAL_COMMAND")
50
51 logEvent("Connection attempt", map[string]interface{}{
52 "user": *incomingUser,
53 "command": sshCommand,
54 "client": clientIP,
55 })
56
57 if sshCommand == "" {
58 exitWithLog("access denied: we don't serve interactive shells :)")
59 }
60
61 cmdParts := strings.Fields(sshCommand)
62 if len(cmdParts) < 2 {
63 exitWithLog("invalid command format")
64 }
65
66 gitCommand := cmdParts[0]
67
68 // did:foo/repo-name or
69 // handle/repo-name or
70 // any of the above with a leading slash (/)
71
72 components := strings.Split(strings.TrimPrefix(strings.Trim(cmdParts[1], "'"), "/"), "/")
73 logEvent("Command components", map[string]interface{}{
74 "components": components,
75 })
76 if len(components) != 2 {
77 exitWithLog("invalid repo format, needs <user>/<repo> or /<user>/<repo>")
78 }
79
80 didOrHandle := components[0]
81 did := resolveToDid(didOrHandle)
82 repoName := components[1]
83 qualifiedRepoName, _ := securejoin.SecureJoin(did, repoName)
84
85 validCommands := map[string]bool{
86 "git-receive-pack": true,
87 "git-upload-pack": true,
88 "git-upload-archive": true,
89 }
90 if !validCommands[gitCommand] {
91 exitWithLog("access denied: invalid git command")
92 }
93
94 if gitCommand != "git-upload-pack" {
95 if !isPushPermitted(*incomingUser, qualifiedRepoName) {
96 logEvent("all infos", map[string]interface{}{
97 "did": *incomingUser,
98 "reponame": qualifiedRepoName,
99 })
100 exitWithLog("access denied: user not allowed")
101 }
102 }
103
104 fullPath, _ := securejoin.SecureJoin(*baseDirFlag, qualifiedRepoName)
105
106 logEvent("Processing command", map[string]interface{}{
107 "user": *incomingUser,
108 "command": gitCommand,
109 "repo": repoName,
110 "fullPath": fullPath,
111 "client": clientIP,
112 })
113
114 if gitCommand == "git-upload-pack" {
115 fmt.Fprintf(os.Stderr, "\x02%s\n", "Welcome to this knot!")
116 } else {
117 fmt.Fprintf(os.Stderr, "%s\n", "Welcome to this knot!")
118 }
119
120 cmd := exec.Command(gitCommand, fullPath)
121 cmd.Stdout = os.Stdout
122 cmd.Stderr = os.Stderr
123 cmd.Stdin = os.Stdin
124
125 if err := cmd.Run(); err != nil {
126 exitWithLog(fmt.Sprintf("command failed: %v", err))
127 }
128
129 logEvent("Command completed", map[string]interface{}{
130 "user": *incomingUser,
131 "command": gitCommand,
132 "repo": repoName,
133 "success": true,
134 })
135}
136
137func resolveToDid(didOrHandle string) string {
138 resolver := appview.NewResolver()
139 ident, err := resolver.ResolveIdent(context.Background(), didOrHandle)
140 if err != nil {
141 exitWithLog(fmt.Sprintf("error resolving handle: %v", err))
142 }
143
144 // did:plc:foobarbaz/repo
145 return ident.DID.String()
146}
147
148func initLogger() {
149 var err error
150 logFile, err = os.OpenFile(*logPathFlag, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0600)
151 if err != nil {
152 fmt.Fprintf(os.Stderr, "failed to open log file: %v\n", err)
153 os.Exit(1)
154 }
155
156 logger = log.New(logFile, "", 0)
157}
158
159func logEvent(event string, fields map[string]interface{}) {
160 entry := fmt.Sprintf(
161 "timestamp=%q event=%q",
162 time.Now().Format(time.RFC3339),
163 event,
164 )
165
166 for k, v := range fields {
167 entry += fmt.Sprintf(" %s=%q", k, v)
168 }
169
170 logger.Println(entry)
171}
172
173func exitWithLog(message string) {
174 logEvent("Access denied", map[string]interface{}{
175 "error": message,
176 })
177 logFile.Sync()
178 fmt.Fprintf(os.Stderr, "error: %s\n", message)
179 os.Exit(1)
180}
181
182func cleanup() {
183 if logFile != nil {
184 logFile.Sync()
185 logFile.Close()
186 }
187}
188
189func isPushPermitted(user, qualifiedRepoName string) bool {
190 u, _ := url.Parse(*endpoint + "/push-allowed")
191 q := u.Query()
192 q.Add("user", user)
193 q.Add("repo", qualifiedRepoName)
194 u.RawQuery = q.Encode()
195
196 req, err := http.Get(u.String())
197 if err != nil {
198 exitWithLog(fmt.Sprintf("error verifying permissions: %v", err))
199 }
200
201 logEvent("url", map[string]interface{}{
202 "url": u.String(),
203 "status": req.Status,
204 })
205
206 return req.StatusCode == http.StatusNoContent
207}