1
0

route_test.go 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368
  1. package lfs
  2. import (
  3. "context"
  4. "fmt"
  5. "io"
  6. "net/http"
  7. "net/http/httptest"
  8. "testing"
  9. "github.com/flamego/flamego"
  10. "github.com/stretchr/testify/assert"
  11. "gogs.io/gogs/internal/auth"
  12. "gogs.io/gogs/internal/database"
  13. "gogs.io/gogs/internal/lfsutil"
  14. )
  15. func TestAuthenticate(t *testing.T) {
  16. tests := []struct {
  17. name string
  18. header http.Header
  19. mockStore func() *MockStore
  20. expStatusCode int
  21. expHeader http.Header
  22. expBody string
  23. }{
  24. {
  25. name: "no authorization",
  26. expStatusCode: http.StatusUnauthorized,
  27. expHeader: http.Header{
  28. "Lfs-Authenticate": []string{`Basic realm="Git LFS"`},
  29. "Content-Type": []string{"application/vnd.git-lfs+json"},
  30. },
  31. expBody: `{"message":"Credentials needed"}` + "\n",
  32. },
  33. {
  34. name: "user has 2FA enabled",
  35. header: http.Header{
  36. "Authorization": []string{"Basic dXNlcm5hbWU6cGFzc3dvcmQ="},
  37. },
  38. mockStore: func() *MockStore {
  39. mockStore := NewMockStore()
  40. mockStore.IsTwoFactorEnabledFunc.SetDefaultReturn(true)
  41. mockStore.AuthenticateUserFunc.SetDefaultReturn(&database.User{}, nil)
  42. return mockStore
  43. },
  44. expStatusCode: http.StatusBadRequest,
  45. expHeader: http.Header{},
  46. expBody: "Users with 2FA enabled are not allowed to authenticate via username and password.",
  47. },
  48. {
  49. name: "both user and access token do not exist",
  50. header: http.Header{
  51. "Authorization": []string{"Basic dXNlcm5hbWU="},
  52. },
  53. mockStore: func() *MockStore {
  54. mockStore := NewMockStore()
  55. mockStore.GetAccessTokenBySHA1Func.SetDefaultReturn(nil, database.ErrAccessTokenNotExist{})
  56. mockStore.AuthenticateUserFunc.SetDefaultReturn(nil, auth.ErrBadCredentials{})
  57. return mockStore
  58. },
  59. expStatusCode: http.StatusUnauthorized,
  60. expHeader: http.Header{
  61. "Lfs-Authenticate": []string{`Basic realm="Git LFS"`},
  62. "Content-Type": []string{"application/vnd.git-lfs+json"},
  63. },
  64. expBody: `{"message":"Credentials needed"}` + "\n",
  65. },
  66. {
  67. name: "authenticated by username and password",
  68. header: http.Header{
  69. "Authorization": []string{"Basic dXNlcm5hbWU6cGFzc3dvcmQ="},
  70. },
  71. mockStore: func() *MockStore {
  72. mockStore := NewMockStore()
  73. mockStore.IsTwoFactorEnabledFunc.SetDefaultReturn(false)
  74. mockStore.AuthenticateUserFunc.SetDefaultReturn(&database.User{ID: 1, Name: "unknwon"}, nil)
  75. return mockStore
  76. },
  77. expStatusCode: http.StatusOK,
  78. expHeader: http.Header{},
  79. expBody: "ID: 1, Name: unknwon",
  80. },
  81. {
  82. name: "authenticate by access token via username",
  83. header: http.Header{
  84. "Authorization": []string{"Basic dXNlcm5hbWU="},
  85. },
  86. mockStore: func() *MockStore {
  87. mockStore := NewMockStore()
  88. mockStore.GetAccessTokenBySHA1Func.SetDefaultReturn(&database.AccessToken{}, nil)
  89. mockStore.AuthenticateUserFunc.SetDefaultReturn(nil, auth.ErrBadCredentials{})
  90. mockStore.GetUserByIDFunc.SetDefaultReturn(&database.User{ID: 1, Name: "unknwon"}, nil)
  91. return mockStore
  92. },
  93. expStatusCode: http.StatusOK,
  94. expHeader: http.Header{},
  95. expBody: "ID: 1, Name: unknwon",
  96. },
  97. {
  98. name: "authenticate by access token via password",
  99. header: http.Header{
  100. "Authorization": []string{"Basic dXNlcm5hbWU6cGFzc3dvcmQ="},
  101. },
  102. mockStore: func() *MockStore {
  103. mockStore := NewMockStore()
  104. mockStore.GetAccessTokenBySHA1Func.SetDefaultHook(func(_ context.Context, sha1 string) (*database.AccessToken, error) {
  105. if sha1 == "password" {
  106. return &database.AccessToken{}, nil
  107. }
  108. return nil, database.ErrAccessTokenNotExist{}
  109. })
  110. mockStore.AuthenticateUserFunc.SetDefaultReturn(nil, auth.ErrBadCredentials{})
  111. mockStore.GetUserByIDFunc.SetDefaultReturn(&database.User{ID: 1, Name: "unknwon"}, nil)
  112. return mockStore
  113. },
  114. expStatusCode: http.StatusOK,
  115. expHeader: http.Header{},
  116. expBody: "ID: 1, Name: unknwon",
  117. },
  118. }
  119. for _, test := range tests {
  120. t.Run(test.name, func(t *testing.T) {
  121. if test.mockStore == nil {
  122. test.mockStore = NewMockStore
  123. }
  124. f := flamego.New()
  125. f.Get("/", authenticate(test.mockStore()), func(w http.ResponseWriter, user *database.User) {
  126. _, _ = fmt.Fprintf(w, "ID: %d, Name: %s", user.ID, user.Name)
  127. })
  128. r, err := http.NewRequest("GET", "/", nil)
  129. if err != nil {
  130. t.Fatal(err)
  131. }
  132. r.Header = test.header
  133. rr := httptest.NewRecorder()
  134. f.ServeHTTP(rr, r)
  135. resp := rr.Result()
  136. assert.Equal(t, test.expStatusCode, resp.StatusCode)
  137. assert.Equal(t, test.expHeader, resp.Header)
  138. body, err := io.ReadAll(resp.Body)
  139. if err != nil {
  140. t.Fatal(err)
  141. }
  142. assert.Equal(t, test.expBody, string(body))
  143. })
  144. }
  145. }
  146. func TestAuthorize(t *testing.T) {
  147. tests := []struct {
  148. name string
  149. accessMode database.AccessMode
  150. mockStore func() *MockStore
  151. expStatusCode int
  152. expBody string
  153. }{
  154. {
  155. name: "user does not exist",
  156. accessMode: database.AccessModeNone,
  157. mockStore: func() *MockStore {
  158. mockStore := NewMockStore()
  159. mockStore.GetUserByUsernameFunc.SetDefaultReturn(nil, database.ErrUserNotExist{})
  160. return mockStore
  161. },
  162. expStatusCode: http.StatusNotFound,
  163. },
  164. {
  165. name: "repository does not exist",
  166. accessMode: database.AccessModeNone,
  167. mockStore: func() *MockStore {
  168. mockStore := NewMockStore()
  169. mockStore.GetRepositoryByNameFunc.SetDefaultReturn(nil, database.ErrRepoNotExist{})
  170. mockStore.GetUserByUsernameFunc.SetDefaultHook(func(ctx context.Context, username string) (*database.User, error) {
  171. return &database.User{Name: username}, nil
  172. })
  173. return mockStore
  174. },
  175. expStatusCode: http.StatusNotFound,
  176. },
  177. {
  178. name: "actor is not authorized",
  179. accessMode: database.AccessModeWrite,
  180. mockStore: func() *MockStore {
  181. mockStore := NewMockStore()
  182. mockStore.AuthorizeRepositoryAccessFunc.SetDefaultHook(func(_ context.Context, _ int64, _ int64, desired database.AccessMode, _ database.AccessModeOptions) bool {
  183. return desired <= database.AccessModeRead
  184. })
  185. mockStore.GetRepositoryByNameFunc.SetDefaultHook(func(ctx context.Context, ownerID int64, name string) (*database.Repository, error) {
  186. return &database.Repository{Name: name}, nil
  187. })
  188. mockStore.GetUserByUsernameFunc.SetDefaultHook(func(ctx context.Context, username string) (*database.User, error) {
  189. return &database.User{Name: username}, nil
  190. })
  191. return mockStore
  192. },
  193. expStatusCode: http.StatusNotFound,
  194. },
  195. {
  196. name: "actor is authorized",
  197. accessMode: database.AccessModeRead,
  198. mockStore: func() *MockStore {
  199. mockStore := NewMockStore()
  200. mockStore.AuthorizeRepositoryAccessFunc.SetDefaultHook(func(_ context.Context, _ int64, _ int64, desired database.AccessMode, _ database.AccessModeOptions) bool {
  201. return desired <= database.AccessModeRead
  202. })
  203. mockStore.GetRepositoryByNameFunc.SetDefaultHook(func(ctx context.Context, ownerID int64, name string) (*database.Repository, error) {
  204. return &database.Repository{Name: name}, nil
  205. })
  206. mockStore.GetUserByUsernameFunc.SetDefaultHook(func(ctx context.Context, username string) (*database.User, error) {
  207. return &database.User{Name: username}, nil
  208. })
  209. return mockStore
  210. },
  211. expStatusCode: http.StatusOK,
  212. expBody: "owner.Name: owner, repo.Name: repo",
  213. },
  214. }
  215. for _, test := range tests {
  216. t.Run(test.name, func(t *testing.T) {
  217. mockStore := NewMockStore()
  218. if test.mockStore != nil {
  219. mockStore = test.mockStore()
  220. }
  221. f := flamego.New()
  222. f.Use(func(c flamego.Context) {
  223. c.Map(&database.User{})
  224. })
  225. f.Get(
  226. "/{username}/{reponame}",
  227. authorize(mockStore, test.accessMode),
  228. func(w http.ResponseWriter, owner *database.User, repo *database.Repository) {
  229. _, _ = fmt.Fprintf(w, "owner.Name: %s, repo.Name: %s", owner.Name, repo.Name)
  230. },
  231. )
  232. r, err := http.NewRequest("GET", "/owner/repo", nil)
  233. if err != nil {
  234. t.Fatal(err)
  235. }
  236. rr := httptest.NewRecorder()
  237. f.ServeHTTP(rr, r)
  238. resp := rr.Result()
  239. assert.Equal(t, test.expStatusCode, resp.StatusCode)
  240. body, err := io.ReadAll(resp.Body)
  241. if err != nil {
  242. t.Fatal(err)
  243. }
  244. assert.Equal(t, test.expBody, string(body))
  245. })
  246. }
  247. }
  248. func Test_verifyHeader(t *testing.T) {
  249. tests := []struct {
  250. name string
  251. verifyHeader flamego.Handler
  252. header http.Header
  253. expStatusCode int
  254. }{
  255. {
  256. name: "header not found",
  257. verifyHeader: verifyHeader("Accept", contentType, http.StatusNotAcceptable),
  258. expStatusCode: http.StatusNotAcceptable,
  259. },
  260. {
  261. name: "header found",
  262. verifyHeader: verifyHeader("Accept", "application/vnd.git-lfs+json", http.StatusNotAcceptable),
  263. header: http.Header{
  264. "Accept": []string{"application/vnd.git-lfs+json; charset=utf-8"},
  265. },
  266. expStatusCode: http.StatusOK,
  267. },
  268. }
  269. for _, test := range tests {
  270. t.Run(test.name, func(t *testing.T) {
  271. f := flamego.New()
  272. f.Get("/", test.verifyHeader)
  273. r, err := http.NewRequest("GET", "/", nil)
  274. if err != nil {
  275. t.Fatal(err)
  276. }
  277. r.Header = test.header
  278. rr := httptest.NewRecorder()
  279. f.ServeHTTP(rr, r)
  280. resp := rr.Result()
  281. assert.Equal(t, test.expStatusCode, resp.StatusCode)
  282. })
  283. }
  284. }
  285. func Test_verifyOID(t *testing.T) {
  286. f := flamego.New()
  287. f.Get("/{oid}", verifyOID(), func(w http.ResponseWriter, oid lfsutil.OID) {
  288. fmt.Fprintf(w, "oid: %s", oid)
  289. })
  290. tests := []struct {
  291. name string
  292. url string
  293. expStatusCode int
  294. expBody string
  295. }{
  296. {
  297. name: "bad oid",
  298. url: "/bad_oid",
  299. expStatusCode: http.StatusBadRequest,
  300. expBody: `{"message":"Invalid oid"}` + "\n",
  301. },
  302. {
  303. name: "good oid",
  304. url: "/ef797c8118f02dfb649607dd5d3f8c7623048c9c063d532cc95c5ed7a898a64f",
  305. expStatusCode: http.StatusOK,
  306. expBody: "oid: ef797c8118f02dfb649607dd5d3f8c7623048c9c063d532cc95c5ed7a898a64f",
  307. },
  308. }
  309. for _, test := range tests {
  310. t.Run(test.name, func(t *testing.T) {
  311. r, err := http.NewRequest("GET", test.url, nil)
  312. if err != nil {
  313. t.Fatal(err)
  314. }
  315. rr := httptest.NewRecorder()
  316. f.ServeHTTP(rr, r)
  317. resp := rr.Result()
  318. assert.Equal(t, test.expStatusCode, resp.StatusCode)
  319. body, err := io.ReadAll(resp.Body)
  320. if err != nil {
  321. t.Fatal(err)
  322. }
  323. assert.Equal(t, test.expBody, string(body))
  324. })
  325. }
  326. }
  327. func Test_internalServerError(t *testing.T) {
  328. rr := httptest.NewRecorder()
  329. internalServerError(rr)
  330. resp := rr.Result()
  331. assert.Equal(t, http.StatusInternalServerError, resp.StatusCode)
  332. body, err := io.ReadAll(resp.Body)
  333. if err != nil {
  334. t.Fatal(err)
  335. }
  336. assert.Equal(t, `{"message":"Internal server error"}`+"\n", string(body))
  337. }