191 lines
5.4 kB
1
package state
2
3
import (
4
"context"
5
"log"
6
"net/http"
7
"strings"
8
"time"
9
10
comatproto "github.com/bluesky-social/indigo/api/atproto"
11
"github.com/bluesky-social/indigo/atproto/identity"
12
"github.com/bluesky-social/indigo/xrpc"
13
"github.com/go-chi/chi/v5"
14
"github.com/sotangled/tangled/appview"
15
"github.com/sotangled/tangled/appview/auth"
16
)
17
18
// Middleware wraps an http.Handler and returns a handler that runs its own
// logic before delegating to (or instead of calling) the wrapped handler.
type Middleware func(http.Handler) http.Handler
19
20
func AuthMiddleware(s *State) Middleware {
21
return func(next http.Handler) http.Handler {
22
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
23
session, _ := s.auth.Store.Get(r, appview.SessionName)
24
authorized, ok := session.Values[appview.SessionAuthenticated].(bool)
25
if !ok || !authorized {
26
log.Printf("not logged in, redirecting")
27
http.Redirect(w, r, "/login", http.StatusTemporaryRedirect)
28
return
29
}
30
31
// refresh if nearing expiry
32
// TODO: dedup with /login
33
expiryStr := session.Values[appview.SessionExpiry].(string)
34
expiry, err := time.Parse(time.RFC3339, expiryStr)
35
if err != nil {
36
log.Println("invalid expiry time", err)
37
http.Redirect(w, r, "/login", http.StatusTemporaryRedirect)
38
return
39
}
40
pdsUrl := session.Values[appview.SessionPds].(string)
41
did := session.Values[appview.SessionDid].(string)
42
refreshJwt := session.Values[appview.SessionRefreshJwt].(string)
43
44
if time.Now().After(expiry) {
45
log.Println("token expired, refreshing ...")
46
47
client := xrpc.Client{
48
Host: pdsUrl,
49
Auth: &xrpc.AuthInfo{
50
Did: did,
51
AccessJwt: refreshJwt,
52
RefreshJwt: refreshJwt,
53
},
54
}
55
atSession, err := comatproto.ServerRefreshSession(r.Context(), &client)
56
if err != nil {
57
log.Println(err)
58
return
59
}
60
61
sessionish := auth.RefreshSessionWrapper{atSession}
62
63
err = s.auth.StoreSession(r, w, &sessionish, pdsUrl)
64
if err != nil {
65
log.Printf("failed to store session for did: %s\n: %s", atSession.Did, err)
66
return
67
}
68
69
log.Println("successfully refreshed token")
70
}
71
72
next.ServeHTTP(w, r)
73
})
74
}
75
}
76
77
// RoleMiddleware returns middleware that requires the logged-in user to hold
// the given group (role) in the domain named by the "domain" URL parameter.
//
// Responses when next is not called:
//   - 401 Unauthorized if s.auth.GetUser finds no user for the request;
//   - 400 Bad Request if the "domain" URL parameter is empty;
//   - 500 Internal Server Error if the policy lookup itself fails;
//   - 403 Forbidden if the user does not hold the group in that domain.
func RoleMiddleware(s *State, group string) Middleware {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			// This middleware does not authenticate by itself; it relies on
			// a user already being associated with the request.
			actor := s.auth.GetUser(r)
			if actor == nil {
				log.Printf("not logged in")
				http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
				return
			}

			domain := chi.URLParam(r, "domain")
			if domain == "" {
				http.Error(w, "malformed url", http.StatusBadRequest)
				return
			}

			ok, err := s.enforcer.E.HasGroupingPolicy(actor.Did, group, domain)
			if err != nil {
				// A failed lookup is a server problem, not a denial.
				log.Printf("checking %s role for %s in domain %s: %s", group, actor.Did, domain, err)
				http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
				return
			}
			if !ok {
				// Authenticated, but lacking the role: 403, not 401.
				log.Printf("%s does not have perms of a %s in domain %s", actor.Did, group, domain)
				http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
				return
			}

			next.ServeHTTP(w, r)
		})
	}
}
106
107
// RepoPermissionMiddleware returns middleware that requires the logged-in
// user to be granted requiredPerm on the repository addressed by the request
// URL, as checked by s.enforcer against (user DID, knot, "owner/repo", perm).
//
// Responses when next is not called:
//   - 401 Unauthorized if s.auth.GetUser finds no user for the request;
//   - 400 Bad Request if the repository cannot be resolved from the URL;
//   - 500 Internal Server Error if the permission check itself fails;
//   - 403 Forbidden if the user lacks requiredPerm on the repository.
func RepoPermissionMiddleware(s *State, requiredPerm string) Middleware {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			// This middleware does not authenticate by itself; it relies on
			// a user already being associated with the request.
			actor := s.auth.GetUser(r)
			if actor == nil {
				log.Printf("not logged in")
				http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
				return
			}

			f, err := fullyResolvedRepo(r)
			if err != nil {
				log.Println("failed to resolve repo from url:", err)
				http.Error(w, "malformed url", http.StatusBadRequest)
				return
			}

			ok, err := s.enforcer.E.Enforce(actor.Did, f.Knot, f.OwnerSlashRepo(), requiredPerm)
			if err != nil {
				// A failed check is a server problem, not a denial.
				log.Printf("checking %s perm for %s in repo %s: %s", requiredPerm, actor.Did, f.OwnerSlashRepo(), err)
				http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
				return
			}
			if !ok {
				// Authenticated, but lacking the permission: 403, not 401.
				log.Printf("%s does not have perms of a %s in repo %s", actor.Did, requiredPerm, f.OwnerSlashRepo())
				http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
				return
			}

			next.ServeHTTP(w, r)
		})
	}
}
136
137
// StripLeadingAt rewrites request paths of the form "/@rest" to "/rest"
// before calling next, so "/@handle/..." URLs route like "/handle/...".
// Paths without the "/@" prefix are passed through unchanged.
//
// URL.RawPath (the original escaped form, only set when the path contains
// escapes such as %2F) is kept consistent with Path. Routers such as chi
// match on RawPath when it is non-empty, so rewriting Path alone would not
// take effect for such requests.
func StripLeadingAt(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
		if strings.HasPrefix(req.URL.Path, "/@") {
			req.URL.Path = "/" + strings.TrimPrefix(req.URL.Path, "/@")

			switch {
			case req.URL.RawPath == "":
				// No escaped form to keep in sync.
			case strings.HasPrefix(req.URL.RawPath, "/@"):
				req.URL.RawPath = "/" + strings.TrimPrefix(req.URL.RawPath, "/@")
			default:
				// The prefix is spelled differently in the escaped form
				// (e.g. "/%40"); drop RawPath rather than leave it
				// disagreeing with Path.
				req.URL.RawPath = ""
			}
		}
		next.ServeHTTP(w, req)
	})
}
146
147
// ResolveIdent returns middleware that resolves the "user" URL parameter (a
// DID or handle) through s.resolver. It stores the resulting
// identity.Identity value in the request context under the string key
// "resolvedId".
//
// When resolution fails, the request is answered with a bare 404 and next is
// never invoked.
//
// NOTE(review): staticcheck (SA1029) flags built-in string context keys. The
// key is kept because other code reads "resolvedId" directly.
func ResolveIdent(s *State) Middleware {
	mw := func(next http.Handler) http.Handler {
		serve := func(w http.ResponseWriter, req *http.Request) {
			ident, err := s.resolver.ResolveIdent(req.Context(), chi.URLParam(req, "user"))
			if err != nil {
				// Neither a resolvable DID nor a resolvable handle.
				log.Println("failed to resolve did/handle:", err)
				w.WriteHeader(http.StatusNotFound)
				return
			}

			withIdent := context.WithValue(req.Context(), "resolvedId", *ident)
			next.ServeHTTP(w, req.WithContext(withIdent))
		}
		return http.HandlerFunc(serve)
	}
	return mw
}
166
167
// ResolveRepoKnot returns middleware that looks up the repository named by
// the "repo" URL parameter. The lookup is for the identity.Identity stored
// under "resolvedId" in the request context (as ResolveIdent does). The
// repo's Knot and AtUri are stored in the context under "knot" and "repoAt".
//
// It responds 500 when no identity is present in the context (the middleware
// chain is misordered) and 404 when the repository lookup fails.
func ResolveRepoKnot(s *State) Middleware {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
			repoName := chi.URLParam(req, "repo")
			id, ok := req.Context().Value("resolvedId").(identity.Identity)
			if !ok {
				log.Println("malformed middleware")
				w.WriteHeader(http.StatusInternalServerError)
				return
			}

			repo, err := s.db.GetRepo(id.DID.String(), repoName)
			if err != nil {
				// Every lookup error is reported as 404.
				// NOTE(review): a DB failure is indistinguishable from "no
				// such repo" here; consider 500 for non-not-found errors.
				log.Printf("failed to resolve repo %q for %s: %s", repoName, id.DID.String(), err)
				w.WriteHeader(http.StatusNotFound)
				return
			}

			ctx := context.WithValue(req.Context(), "knot", repo.Knot)
			ctx = context.WithValue(ctx, "repoAt", repo.AtUri)
			next.ServeHTTP(w, req.WithContext(ctx))
		})
	}
}
192