Lightweight Go reverse proxy for Ollama with bearer-token authentication
go
proxy
ollama
1package main
2
3import (
4 "net/http"
5 "net/http/httptest"
6 "testing"
7
8 "gotest.tools/v3/assert"
9)
10
11func TestAuthMiddleware(t *testing.T) {
12 tests := []struct {
13 name string
14 token string
15 authHeader string
16 expectedStatus int
17 expectedBody string
18 }{
19 {
20 name: "valid token",
21 token: "test-token-123",
22 authHeader: "Bearer test-token-123",
23 expectedStatus: http.StatusOK,
24 expectedBody: "OK",
25 },
26 {
27 name: "invalid token",
28 token: "test-token-123",
29 authHeader: "Bearer wrong-token",
30 expectedStatus: http.StatusUnauthorized,
31 expectedBody: "Unauthorized\n",
32 },
33 {
34 name: "missing bearer prefix",
35 token: "test-token-123",
36 authHeader: "test-token-123",
37 expectedStatus: http.StatusUnauthorized,
38 expectedBody: "Unauthorized\n",
39 },
40 {
41 name: "missing authorization header",
42 token: "test-token-123",
43 authHeader: "",
44 expectedStatus: http.StatusUnauthorized,
45 expectedBody: "Unauthorized\n",
46 },
47 {
48 name: "empty token value",
49 token: "test-token-123",
50 authHeader: "Bearer ",
51 expectedStatus: http.StatusUnauthorized,
52 expectedBody: "Unauthorized\n",
53 },
54 }
55
56 for _, tt := range tests {
57 t.Run(tt.name, func(t *testing.T) {
58 handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
59 w.WriteHeader(http.StatusOK)
60 _, _ = w.Write([]byte("OK"))
61 })
62
63 middleware := authMiddleware(tt.token)
64 req := httptest.NewRequest(http.MethodGet, "/", nil)
65 if tt.authHeader != "" {
66 req.Header.Set("Authorization", tt.authHeader)
67 }
68
69 rr := httptest.NewRecorder()
70 middleware(handler).ServeHTTP(rr, req)
71
72 assert.Equal(t, tt.expectedStatus, rr.Code)
73 assert.Equal(t, tt.expectedBody, rr.Body.String())
74 })
75 }
76}
77
78func TestNewProxy(t *testing.T) {
79 tests := []struct {
80 name string
81 targetURL string
82 wantErr bool
83 }{
84 {
85 name: "valid http URL",
86 targetURL: "http://localhost:11434",
87 wantErr: false,
88 },
89 {
90 name: "valid https URL",
91 targetURL: "https://example.com",
92 wantErr: false,
93 },
94 {
95 name: "invalid URL",
96 targetURL: "://invalid",
97 wantErr: true,
98 },
99 {
100 name: "empty URL",
101 targetURL: "",
102 wantErr: true,
103 },
104 }
105
106 for _, tt := range tests {
107 t.Run(tt.name, func(t *testing.T) {
108 proxy, err := newProxy(tt.targetURL)
109 if tt.wantErr {
110 assert.Assert(t, err != nil)
111 assert.Assert(t, proxy == nil)
112 } else {
113 assert.NilError(t, err)
114 assert.Assert(t, proxy != nil)
115 }
116 })
117 }
118}