Skip to content

Commit 0f993ee

Browse files
authored
Merge pull request #14 from LyricTian/develop
Fixed some implement
2 parents ab0b41b + 683529c commit 0f993ee

File tree

4 files changed

+72
-43
lines changed

4 files changed

+72
-43
lines changed

example/server/main.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,8 @@ func main() {
3737

3838
srv := server.NewServer(server.NewConfig(), manager)
3939
srv.SetUserAuthorizationHandler(userAuthorizeHandler)
40-
srv.SetInternalErrorHandler(func(err error) {
41-
fmt.Println("OAuth2 Error:", err.Error())
40+
srv.SetInternalErrorHandler(func(r *http.Request, err error) {
41+
fmt.Println("OAuth2 Error:", r.RequestURI, err.Error())
4242
})
4343

4444
http.HandleFunc("/login", loginHandler)

server/handler.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ type RefreshingScopeHandler func(newScope, oldScope string) (allowed bool)
2929
type ResponseErrorHandler func(re *errors.Response)
3030

3131
// InternalErrorHandler Internal error handing
32-
type InternalErrorHandler func(err error)
32+
type InternalErrorHandler func(req *http.Request, err error)
3333

3434
// ClientFormHandler Get client data from form
3535
func ClientFormHandler(r *http.Request) (clientID, clientSecret string, err error) {

server/server.go

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -228,9 +228,6 @@ func (s *Server) GetAuthorizeData(rt oauth2.ResponseType, ti oauth2.TokenInfo) (
228228
func (s *Server) GetErrorData(rerr, ierr error) (data map[string]interface{}, statusCode int) {
229229
if ierr != nil {
230230
rerr = errors.ErrServerError
231-
if fn := s.InternalErrorHandler; fn != nil {
232-
fn(ierr)
233-
}
234231
}
235232
re := &errors.Response{
236233
Error: rerr,
@@ -253,11 +250,18 @@ func (s *Server) GetErrorData(rerr, ierr error) (data map[string]interface{}, st
253250
}
254251

255252
// response redirect error
256-
func (s *Server) resRedirectError(w http.ResponseWriter, req *AuthorizeRequest, rerr, ierr error) (err error) {
253+
func (s *Server) resRedirectError(w http.ResponseWriter, r *http.Request, req *AuthorizeRequest, rerr, ierr error) (err error) {
257254
if req == nil {
258255
err = ierr
259256
return
260257
}
258+
if fn := s.InternalErrorHandler; fn != nil {
259+
verr := ierr
260+
if verr == nil {
261+
verr = rerr
262+
}
263+
fn(r, verr)
264+
}
261265
data, _ := s.GetErrorData(rerr, ierr)
262266
err = s.resRedirect(w, req, data)
263267
return
@@ -283,20 +287,20 @@ func (s *Server) HandleAuthorizeRequest(w http.ResponseWriter, r *http.Request)
283287
}()
284288
req, rerr, ierr := s.ValidationAuthorizeRequest(r)
285289
if rerr != nil || ierr != nil {
286-
err = s.resRedirectError(w, req, rerr, ierr)
290+
err = s.resRedirectError(w, r, req, rerr, ierr)
287291
return
288292
}
289293
userID, err := s.UserAuthorizationHandler(w, r)
290294
if err != nil {
291-
err = s.resRedirectError(w, req, nil, err)
295+
err = s.resRedirectError(w, r, req, nil, err)
292296
return
293297
} else if userID == "" {
294298
return
295299
}
296300
req.UserID = userID
297301
ti, rerr, ierr := s.GetAuthorizeToken(req)
298302
if rerr != nil || ierr != nil {
299-
err = s.resRedirectError(w, req, rerr, ierr)
303+
err = s.resRedirectError(w, r, req, rerr, ierr)
300304
return
301305
}
302306
err = s.resRedirect(w, req, s.GetAuthorizeData(req.ResponseType, ti))
@@ -460,19 +464,26 @@ func (s *Server) HandleTokenRequest(w http.ResponseWriter, r *http.Request) (err
460464
}()
461465
gt, tgr, rerr, ierr := s.ValidationTokenRequest(r)
462466
if rerr != nil || ierr != nil {
463-
err = s.resTokenError(w, rerr, ierr)
467+
err = s.resTokenError(w, r, rerr, ierr)
464468
return
465469
}
466470
ti, rerr, ierr := s.GetAccessToken(gt, tgr)
467471
if rerr != nil || ierr != nil {
468-
err = s.resTokenError(w, rerr, ierr)
472+
err = s.resTokenError(w, r, rerr, ierr)
469473
return
470474
}
471475
err = s.resToken(w, s.GetTokenData(ti))
472476
return
473477
}
474478

475-
func (s *Server) resTokenError(w http.ResponseWriter, rerr, ierr error) (err error) {
479+
func (s *Server) resTokenError(w http.ResponseWriter, r *http.Request, rerr, ierr error) (err error) {
480+
if fn := s.InternalErrorHandler; fn != nil {
481+
verr := ierr
482+
if verr == nil {
483+
verr = rerr
484+
}
485+
fn(r, verr)
486+
}
476487
data, statusCode := s.GetErrorData(rerr, ierr)
477488
s.resToken(w, data, statusCode)
478489
return

store/token.go

Lines changed: 48 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -31,50 +31,65 @@ type MemoryTokenStore struct {
3131
gcInterval time.Duration
3232
globalID int64
3333
lock sync.RWMutex
34-
basicList *list.List
3534
data map[string]oauth2.TokenInfo
3635
access map[string]string
3736
refresh map[string]string
37+
basicList *list.List
38+
listLock sync.RWMutex
3839
}
3940

4041
func (mts *MemoryTokenStore) gc() {
4142
time.AfterFunc(mts.gcInterval, func() {
4243
defer mts.gc()
43-
mts.lock.RLock()
44+
rmeles := make([]*list.Element, 0, 32)
45+
mts.listLock.RLock()
4446
ele := mts.basicList.Front()
45-
mts.lock.RUnlock()
46-
if ele == nil {
47-
return
47+
mts.listLock.RUnlock()
48+
for ele != nil {
49+
if rm := mts.gcElement(ele); rm {
50+
rmeles = append(rmeles, ele)
51+
}
52+
mts.listLock.RLock()
53+
ele = ele.Next()
54+
mts.listLock.RUnlock()
55+
}
56+
57+
for _, e := range rmeles {
58+
mts.listLock.Lock()
59+
mts.basicList.Remove(e)
60+
mts.listLock.Unlock()
4861
}
49-
basicID := ele.Value.(string)
62+
})
63+
}
64+
65+
func (mts *MemoryTokenStore) gcElement(ele *list.Element) (rm bool) {
66+
basicID := ele.Value.(string)
67+
mts.lock.RLock()
68+
ti, ok := mts.data[basicID]
69+
mts.lock.RUnlock()
70+
if !ok {
71+
rm = true
72+
return
73+
}
74+
ct := time.Now()
75+
if refresh := ti.GetRefresh(); refresh != "" &&
76+
ti.GetRefreshCreateAt().Add(ti.GetRefreshExpiresIn()).Before(ct) {
5077
mts.lock.RLock()
51-
ti, ok := mts.data[basicID]
78+
delete(mts.access, ti.GetAccess())
79+
delete(mts.refresh, refresh)
80+
delete(mts.data, basicID)
5281
mts.lock.RUnlock()
53-
if !ok {
54-
mts.lock.Lock()
55-
mts.basicList.Remove(ele)
56-
mts.lock.Unlock()
57-
return
58-
}
59-
ct := time.Now()
60-
if refresh := ti.GetRefresh(); refresh != "" &&
61-
ti.GetRefreshCreateAt().Add(ti.GetRefreshExpiresIn()).Before(ct) {
62-
mts.lock.RLock()
63-
delete(mts.access, ti.GetAccess())
64-
delete(mts.refresh, refresh)
82+
rm = true
83+
} else if ti.GetAccessCreateAt().Add(ti.GetAccessExpiresIn()).Before(ct) {
84+
mts.lock.RLock()
85+
delete(mts.access, ti.GetAccess())
86+
if refresh := ti.GetRefresh(); refresh == "" {
6587
delete(mts.data, basicID)
66-
mts.basicList.Remove(ele)
67-
mts.lock.RUnlock()
68-
} else if ti.GetAccessCreateAt().Add(ti.GetAccessExpiresIn()).Before(ct) {
69-
mts.lock.RLock()
70-
delete(mts.access, ti.GetAccess())
71-
if refresh := ti.GetRefresh(); refresh == "" {
72-
delete(mts.data, basicID)
73-
mts.basicList.Remove(ele)
74-
}
75-
mts.lock.RUnlock()
88+
rm = true
7689
}
77-
})
90+
mts.lock.RUnlock()
91+
}
92+
return
7893
}
7994

8095
func (mts *MemoryTokenStore) getBasicID(id int64, info oauth2.TokenInfo) string {
@@ -92,7 +107,10 @@ func (mts *MemoryTokenStore) Create(info oauth2.TokenInfo) (err error) {
92107
if refresh := info.GetRefresh(); refresh != "" {
93108
mts.refresh[refresh] = basicID
94109
}
110+
111+
mts.listLock.Lock()
95112
mts.basicList.PushBack(basicID)
113+
mts.listLock.Unlock()
96114
return
97115
}
98116

0 commit comments

Comments
 (0)