192 lines
5.5 kB
1
package state
2
3
import (
4
"context"
5
"log"
6
"net/http"
7
"strings"
8
"time"
9
10
comatproto "github.com/bluesky-social/indigo/api/atproto"
11
"github.com/bluesky-social/indigo/atproto/identity"
12
"github.com/bluesky-social/indigo/xrpc"
13
"github.com/go-chi/chi/v5"
14
"github.com/sotangled/tangled/appview"
15
"github.com/sotangled/tangled/appview/auth"
16
)
17
18
// Middleware wraps an http.Handler and returns a new http.Handler, letting
// request-level concerns (auth checks, identity resolution, path rewriting)
// be layered in front of route handlers.
type Middleware func(http.Handler) http.Handler
19
20
// AuthMiddleware returns middleware that admits only requests carrying an
// authenticated session. Requests without one, or whose session data is
// incomplete or cannot be refreshed, are redirected to /login.
//
// Once the stored access token has expired, it is refreshed against the
// user's PDS (com.atproto.server.refreshSession) and the new tokens are saved
// back into the session before the request is passed on. If saving the
// refreshed session fails, the request is answered with 500.
func AuthMiddleware(s *State) Middleware {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			redirectToLogin := func() {
				http.Redirect(w, r, "/login", http.StatusTemporaryRedirect)
			}

			session, err := s.auth.Store.Get(r, appview.SessionName)
			if err != nil {
				// Not fatal by itself: the checks below decide whether the
				// returned session is usable.
				// NOTE(review): assumes Store.Get returns a non-nil session
				// even when it reports an error (the code always relied on this).
				log.Println("failed to load session:", err)
			}

			authorized, ok := session.Values[appview.SessionAuthenticated].(bool)
			if !ok || !authorized {
				log.Printf("not logged in, redirecting")
				redirectToLogin()
				return
			}

			// Checked assertions: a session lacking any of these values is
			// treated as logged out instead of panicking the handler.
			expiryStr, okExpiry := session.Values[appview.SessionExpiry].(string)
			pdsUrl, okPds := session.Values[appview.SessionPds].(string)
			did, okDid := session.Values[appview.SessionDid].(string)
			refreshJwt, okRefresh := session.Values[appview.SessionRefreshJwt].(string)
			if !okExpiry || !okPds || !okDid || !okRefresh {
				log.Println("session is missing token data, redirecting")
				redirectToLogin()
				return
			}

			expiry, err := time.Parse(time.RFC3339, expiryStr)
			if err != nil {
				log.Println("invalid expiry time", err)
				redirectToLogin()
				return
			}

			// The token is refreshed only once it has actually expired.
			// TODO: refresh slightly ahead of expiry; dedup with /login
			if time.Now().After(expiry) {
				log.Println("token expired, refreshing ...")

				// refreshSession is authorized with the refresh JWT, so it is
				// sent in place of the access token.
				client := xrpc.Client{
					Host: pdsUrl,
					Auth: &xrpc.AuthInfo{
						Did:        did,
						AccessJwt:  refreshJwt,
						RefreshJwt: refreshJwt,
					},
				}
				atSession, err := comatproto.ServerRefreshSession(r.Context(), &client)
				if err != nil {
					log.Println("failed to refresh session", err)
					redirectToLogin()
					return
				}

				sessionish := auth.RefreshSessionWrapper{atSession}
				if err := s.auth.StoreSession(r, w, &sessionish, pdsUrl); err != nil {
					log.Printf("failed to store session for did %s: %s", atSession.Did, err)
					// Always answer the request; returning silently would send
					// an empty 200 response.
					http.Error(w, "failed to refresh session", http.StatusInternalServerError)
					return
				}

				log.Println("successfully refreshed token")
			}

			next.ServeHTTP(w, r)
		})
	}
}
77
78
// RoleMiddleware returns middleware that admits only logged-in users who hold
// the given group (role) in the domain named by the "domain" URL parameter,
// as reported by the enforcer's grouping policy.
//
// Responses:
//   - 401 when there is no logged-in user,
//   - 400 when the "domain" URL parameter is empty,
//   - 500 when the policy lookup itself fails,
//   - 403 when the user does not hold the group in that domain.
func RoleMiddleware(s *State, group string) Middleware {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			// Authentication is required here too; unauthenticated requests
			// are rejected outright rather than redirected.
			actor := s.auth.GetUser(r)
			if actor == nil {
				log.Printf("not logged in, rejecting")
				http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
				return
			}

			domain := chi.URLParam(r, "domain")
			if domain == "" {
				http.Error(w, "malformed url", http.StatusBadRequest)
				return
			}

			ok, err := s.enforcer.E.HasGroupingPolicy(actor.Did, group, domain)
			if err != nil {
				// A failed lookup is a server problem, not a denial.
				log.Printf("checking whether %s is a %s in domain %s: %v", actor.Did, group, domain, err)
				http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
				return
			}
			if !ok {
				// Authenticated but lacking the role: 403, not 401.
				log.Printf("%s does not have perms of a %s in domain %s", actor.Did, group, domain)
				http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
				return
			}

			next.ServeHTTP(w, r)
		})
	}
}
107
108
// RepoPermissionMiddleware returns middleware that admits only logged-in users
// whom the enforcer grants requiredPerm on the repository resolved from the
// request (via fullyResolvedRepo), scoped to that repository's knot.
//
// Responses:
//   - 401 when there is no logged-in user,
//   - 400 when the repository cannot be resolved from the URL,
//   - 500 when the permission check itself fails,
//   - 403 when the permission is denied.
func RepoPermissionMiddleware(s *State, requiredPerm string) Middleware {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			// Authentication is required here too; unauthenticated requests
			// are rejected outright rather than redirected.
			actor := s.auth.GetUser(r)
			if actor == nil {
				log.Printf("not logged in, rejecting")
				http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
				return
			}

			f, err := fullyResolvedRepo(r)
			if err != nil {
				log.Println("failed to resolve repo from url:", err)
				http.Error(w, "malformed url", http.StatusBadRequest)
				return
			}

			repo := f.OwnerSlashRepo()
			ok, err := s.enforcer.E.Enforce(actor.Did, f.Knot, repo, requiredPerm)
			if err != nil {
				// A failed check is a server problem, not a denial.
				log.Printf("checking %s for %s on repo %s: %v", actor.Did, requiredPerm, repo, err)
				http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
				return
			}
			if !ok {
				// Authenticated but lacking the permission: 403, not 401.
				log.Printf("%s does not have perms of a %s in repo %s", actor.Did, requiredPerm, repo)
				http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
				return
			}

			next.ServeHTTP(w, r)
		})
	}
}
137
138
// StripLeadingAt rewrites request paths of the form "/@<rest>" to "/<rest>"
// before passing the request on, so "/@alice/repo" is handled like
// "/alice/repo". The request's URL is modified in place.
//
// URL.RawPath is rewritten in step with URL.Path. RawPath is only non-empty
// when the incoming path used escapes (e.g. "%2F", or "%40" for the "@"
// itself), and routers such as chi match against RawPath when it is set.
// Stripping only Path would therefore leave those requests routed on the
// un-stripped path.
func StripLeadingAt(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
		u := req.URL
		if strings.HasPrefix(u.Path, "/@") {
			u.Path = "/" + strings.TrimPrefix(u.Path, "/@")

			// The decoded path starts with "/@", so the escaped form starts
			// with either a literal "/@" or "/%40" (hex digits in any case).
			switch raw := u.RawPath; {
			case strings.HasPrefix(raw, "/@"):
				u.RawPath = "/" + raw[len("/@"):]
			case len(raw) >= len("/%40") && strings.EqualFold(raw[:len("/%40")], "/%40"):
				u.RawPath = "/" + raw[len("/%40"):]
			}
		}
		next.ServeHTTP(w, req)
	})
}
147
148
// ResolveIdent returns middleware that resolves the "user" URL parameter (a
// DID or a handle) through s.resolver. On success, the resolved identity is
// dereferenced and stored as a value under the "resolvedId" context key for
// downstream handlers. If resolution fails, the request ends with a bare 404.
func ResolveIdent(s *State) Middleware {
	return func(next http.Handler) http.Handler {
		serve := func(w http.ResponseWriter, req *http.Request) {
			didOrHandle := chi.URLParam(req, "user")
			base := req.Context()

			ident, resolveErr := s.resolver.ResolveIdent(base, didOrHandle)
			if resolveErr != nil {
				// unknown or malformed did/handle
				log.Println("failed to resolve did/handle:", resolveErr)
				w.WriteHeader(http.StatusNotFound)
				return
			}

			withIdent := context.WithValue(base, "resolvedId", *ident)
			next.ServeHTTP(w, req.WithContext(withIdent))
		}
		return http.HandlerFunc(serve)
	}
}
167
168
// ResolveRepoKnot returns middleware that looks up the repository named by the
// "repo" URL parameter, owned by the identity stored under the "resolvedId"
// context key (as an identity.Identity value, which ResolveIdent provides).
// On success it stores repo.Knot under the "knot" context key and repo.AtUri
// under "repoAt" for downstream handlers.
//
// It must run after ResolveIdent: without a resolved identity in the context
// the request fails with 500. A failed repository lookup yields 404.
func ResolveRepoKnot(s *State) Middleware {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
			repoName := chi.URLParam(req, "repo")
			id, ok := req.Context().Value("resolvedId").(identity.Identity)
			if !ok {
				log.Println("malformed middleware")
				w.WriteHeader(http.StatusInternalServerError)
				return
			}

			repo, err := s.db.GetRepo(id.DID.String(), repoName)
			if err != nil {
				// NOTE(review): every lookup error maps to 404, including
				// database failures, not only a missing repo.
				log.Printf("failed to resolve repo %s/%s: %v", id.DID.String(), repoName, err)
				w.WriteHeader(http.StatusNotFound)
				return
			}

			ctx := context.WithValue(req.Context(), "knot", repo.Knot)
			ctx = context.WithValue(ctx, "repoAt", repo.AtUri)
			next.ServeHTTP(w, req.WithContext(ctx))
		})
	}
}
193