|
1 | 1 | package store |
2 | 2 |
|
3 | 3 | import ( |
4 | | - "container/list" |
5 | | - "strconv" |
6 | | - "sync" |
| 4 | + "encoding/json" |
7 | 5 | "time" |
8 | 6 |
|
| 7 | + "github.com/satori/go.uuid" |
| 8 | + "github.com/tidwall/buntdb" |
9 | 9 | "gopkg.in/oauth2.v3" |
| 10 | + "gopkg.in/oauth2.v3/models" |
10 | 11 | ) |
11 | 12 |
|
12 | 13 | // NewMemoryTokenStore Create a token store instance based on memory |
13 | | -// gcInterval Perform garbage collection intervals(The default is 30 seconds) |
14 | | -func NewMemoryTokenStore(gcInterval time.Duration) oauth2.TokenStore { |
15 | | - if gcInterval == 0 { |
16 | | - gcInterval = time.Second * 30 |
17 | | - } |
18 | | - store := &MemoryTokenStore{ |
19 | | - gcInterval: gcInterval, |
20 | | - basicList: list.New(), |
21 | | - data: make(map[string]oauth2.TokenInfo), |
22 | | - access: make(map[string]string), |
23 | | - refresh: make(map[string]string), |
| 14 | +func NewMemoryTokenStore() (store oauth2.TokenStore, err error) { |
| 15 | + store, err = NewFileTokenStore(":memory:") |
| 16 | + return |
| 17 | +} |
| 18 | + |
| 19 | +// NewFileTokenStore Create a token store instance based on file |
| 20 | +func NewFileTokenStore(filename string) (store oauth2.TokenStore, err error) { |
| 21 | + db, err := buntdb.Open(filename) |
| 22 | + if err != nil { |
| 23 | + return |
24 | 24 | } |
25 | | - go store.gc() |
26 | | - return store |
| 25 | + store = &TokenStore{db: db} |
| 26 | + return |
27 | 27 | } |
28 | 28 |
|
29 | | -// MemoryTokenStore Memory storage for token |
30 | | -type MemoryTokenStore struct { |
31 | | - gcInterval time.Duration |
32 | | - globalID int64 |
33 | | - lock sync.RWMutex |
34 | | - data map[string]oauth2.TokenInfo |
35 | | - access map[string]string |
36 | | - refresh map[string]string |
37 | | - basicList *list.List |
38 | | - listLock sync.RWMutex |
| 29 | +// TokenStore Token storage based on buntdb(https://github.com/tidwall/buntdb) |
| 30 | +type TokenStore struct { |
| 31 | + db *buntdb.DB |
39 | 32 | } |
40 | 33 |
|
41 | | -func (mts *MemoryTokenStore) gc() { |
42 | | - time.AfterFunc(mts.gcInterval, func() { |
43 | | - defer mts.gc() |
44 | | - rmeles := make([]*list.Element, 0, 32) |
45 | | - mts.listLock.RLock() |
46 | | - ele := mts.basicList.Front() |
47 | | - mts.listLock.RUnlock() |
48 | | - for ele != nil { |
49 | | - if rm := mts.gcElement(ele); rm { |
50 | | - rmeles = append(rmeles, ele) |
| 34 | +// Create Create and store the new token information |
| 35 | +func (ts *TokenStore) Create(info oauth2.TokenInfo) (err error) { |
| 36 | + ct := time.Now() |
| 37 | + jv, err := json.Marshal(info) |
| 38 | + if err != nil { |
| 39 | + return |
| 40 | + } |
| 41 | + basicID := uuid.NewV4().String() |
| 42 | + aexp := info.GetAccessExpiresIn() |
| 43 | + rexp := aexp |
| 44 | + |
| 45 | + err = ts.db.Update(func(tx *buntdb.Tx) (err error) { |
| 46 | + if refresh := info.GetRefresh(); refresh != "" { |
| 47 | + rexp = info.GetRefreshCreateAt().Add(info.GetRefreshExpiresIn()).Sub(ct) |
| 48 | + if aexp.Seconds() > rexp.Seconds() { |
| 49 | + aexp = rexp |
| 50 | + } |
| 51 | + _, _, err = tx.Set(refresh, basicID, &buntdb.SetOptions{Expires: true, TTL: rexp}) |
| 52 | + if err != nil { |
| 53 | + return |
51 | 54 | } |
52 | | - mts.listLock.RLock() |
53 | | - ele = ele.Next() |
54 | | - mts.listLock.RUnlock() |
55 | 55 | } |
56 | | - |
57 | | - for _, e := range rmeles { |
58 | | - mts.listLock.Lock() |
59 | | - mts.basicList.Remove(e) |
60 | | - mts.listLock.Unlock() |
| 56 | + _, _, err = tx.Set(basicID, string(jv), &buntdb.SetOptions{Expires: true, TTL: rexp}) |
| 57 | + if err != nil { |
| 58 | + return |
61 | 59 | } |
| 60 | + _, _, err = tx.Set(info.GetAccess(), basicID, &buntdb.SetOptions{Expires: true, TTL: aexp}) |
| 61 | + return |
62 | 62 | }) |
| 63 | + return |
63 | 64 | } |
64 | 65 |
|
65 | | -func (mts *MemoryTokenStore) gcElement(ele *list.Element) (rm bool) { |
66 | | - basicID := ele.Value.(string) |
67 | | - mts.lock.RLock() |
68 | | - ti, ok := mts.data[basicID] |
69 | | - mts.lock.RUnlock() |
70 | | - if !ok { |
71 | | - rm = true |
| 66 | +// remove key |
| 67 | +func (ts *TokenStore) remove(key string) (err error) { |
| 68 | + verr := ts.db.Update(func(tx *buntdb.Tx) (err error) { |
| 69 | + _, err = tx.Delete(key) |
| 70 | + return |
| 71 | + }) |
| 72 | + if verr == buntdb.ErrNotFound { |
72 | 73 | return |
73 | 74 | } |
74 | | - ct := time.Now() |
75 | | - if refresh := ti.GetRefresh(); refresh != "" && |
76 | | - ti.GetRefreshCreateAt().Add(ti.GetRefreshExpiresIn()).Before(ct) { |
77 | | - mts.lock.RLock() |
78 | | - delete(mts.access, ti.GetAccess()) |
79 | | - delete(mts.refresh, refresh) |
80 | | - delete(mts.data, basicID) |
81 | | - mts.lock.RUnlock() |
82 | | - rm = true |
83 | | - } else if ti.GetAccessCreateAt().Add(ti.GetAccessExpiresIn()).Before(ct) { |
84 | | - mts.lock.RLock() |
85 | | - delete(mts.access, ti.GetAccess()) |
86 | | - if refresh := ti.GetRefresh(); refresh == "" { |
87 | | - delete(mts.data, basicID) |
88 | | - rm = true |
89 | | - } |
90 | | - mts.lock.RUnlock() |
91 | | - } |
| 75 | + err = verr |
92 | 76 | return |
93 | 77 | } |
94 | 78 |
|
95 | | -func (mts *MemoryTokenStore) getBasicID(id int64, info oauth2.TokenInfo) string { |
96 | | - return info.GetClientID() + "_" + strconv.FormatInt(id, 10) |
| 79 | +// RemoveByAccess Use the access token to delete the token information |
| 80 | +func (ts *TokenStore) RemoveByAccess(access string) (err error) { |
| 81 | + err = ts.remove(access) |
| 82 | + return |
97 | 83 | } |
98 | 84 |
|
99 | | -// Create Create and store the new token information |
100 | | -func (mts *MemoryTokenStore) Create(info oauth2.TokenInfo) (err error) { |
101 | | - mts.lock.Lock() |
102 | | - defer mts.lock.Unlock() |
103 | | - mts.globalID++ |
104 | | - basicID := mts.getBasicID(mts.globalID, info) |
105 | | - mts.data[basicID] = info |
106 | | - mts.access[info.GetAccess()] = basicID |
107 | | - if refresh := info.GetRefresh(); refresh != "" { |
108 | | - mts.refresh[refresh] = basicID |
109 | | - } |
110 | | - |
111 | | - mts.listLock.Lock() |
112 | | - mts.basicList.PushBack(basicID) |
113 | | - mts.listLock.Unlock() |
| 85 | +// RemoveByRefresh Use the refresh token to delete the token information |
| 86 | +func (ts *TokenStore) RemoveByRefresh(refresh string) (err error) { |
| 87 | + err = ts.remove(refresh) |
114 | 88 | return |
115 | 89 | } |
116 | 90 |
|
117 | | -// RemoveByAccess Use the access token to delete the token information |
118 | | -func (mts *MemoryTokenStore) RemoveByAccess(access string) (err error) { |
119 | | - mts.lock.RLock() |
120 | | - v, ok := mts.access[access] |
121 | | - if !ok { |
122 | | - mts.lock.RUnlock() |
| 91 | +func (ts *TokenStore) get(key string) (ti oauth2.TokenInfo, err error) { |
| 92 | + verr := ts.db.View(func(tx *buntdb.Tx) (err error) { |
| 93 | + basicID, err := tx.Get(key) |
| 94 | + if err != nil { |
| 95 | + return |
| 96 | + } |
| 97 | + jv, err := tx.Get(basicID) |
| 98 | + if err != nil { |
| 99 | + return |
| 100 | + } |
| 101 | + var tm models.Token |
| 102 | + err = json.Unmarshal([]byte(jv), &tm) |
| 103 | + if err != nil { |
| 104 | + return |
| 105 | + } |
| 106 | + ti = &tm |
| 107 | + return |
| 108 | + }) |
| 109 | + if verr == buntdb.ErrNotFound { |
123 | 110 | return |
124 | 111 | } |
125 | | - info := mts.data[v] |
126 | | - mts.lock.RUnlock() |
127 | | - |
128 | | - mts.lock.Lock() |
129 | | - defer mts.lock.Unlock() |
130 | | - delete(mts.access, access) |
131 | | - if refresh := info.GetRefresh(); refresh == "" { |
132 | | - delete(mts.data, v) |
133 | | - } |
134 | | - return |
135 | | -} |
136 | | - |
137 | | -// RemoveByRefresh Use the refresh token to delete the token information |
138 | | -func (mts *MemoryTokenStore) RemoveByRefresh(refresh string) (err error) { |
139 | | - mts.lock.Lock() |
140 | | - defer mts.lock.Unlock() |
141 | | - delete(mts.refresh, refresh) |
142 | | - |
| 112 | + err = verr |
143 | 113 | return |
144 | 114 | } |
145 | 115 |
|
146 | 116 | // GetByAccess Use the access token for token information data |
147 | | -func (mts *MemoryTokenStore) GetByAccess(access string) (ti oauth2.TokenInfo, err error) { |
148 | | - mts.lock.RLock() |
149 | | - v, ok := mts.access[access] |
150 | | - if !ok { |
151 | | - mts.lock.RUnlock() |
152 | | - return |
153 | | - } |
154 | | - ti = mts.data[v] |
155 | | - mts.lock.RUnlock() |
| 117 | +func (ts *TokenStore) GetByAccess(access string) (ti oauth2.TokenInfo, err error) { |
| 118 | + ti, err = ts.get(access) |
156 | 119 | return |
157 | 120 | } |
158 | 121 |
|
159 | 122 | // GetByRefresh Use the refresh token for token information data |
160 | | -func (mts *MemoryTokenStore) GetByRefresh(refresh string) (ti oauth2.TokenInfo, err error) { |
161 | | - mts.lock.RLock() |
162 | | - v, ok := mts.refresh[refresh] |
163 | | - if !ok { |
164 | | - mts.lock.RUnlock() |
165 | | - return |
166 | | - } |
167 | | - ti = mts.data[v] |
168 | | - mts.lock.RUnlock() |
| 123 | +func (ts *TokenStore) GetByRefresh(refresh string) (ti oauth2.TokenInfo, err error) { |
| 124 | + ti, err = ts.get(refresh) |
169 | 125 | return |
170 | 126 | } |
0 commit comments