Improve request package and add more unit tests
This commit is contained in:
parent
844680e573
commit
9d08139f43
49 changed files with 916 additions and 400 deletions
|
@ -43,11 +43,7 @@ func (c *Controller) CreateCategory(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
||||||
// UpdateCategory is the API handler to update a category.
|
// UpdateCategory is the API handler to update a category.
|
||||||
func (c *Controller) UpdateCategory(w http.ResponseWriter, r *http.Request) {
|
func (c *Controller) UpdateCategory(w http.ResponseWriter, r *http.Request) {
|
||||||
categoryID, err := request.IntParam(r, "categoryID")
|
categoryID := request.RouteInt64Param(r, "categoryID")
|
||||||
if err != nil {
|
|
||||||
json.BadRequest(w, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
category, err := decodeCategoryPayload(r.Body)
|
category, err := decodeCategoryPayload(r.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -85,11 +81,7 @@ func (c *Controller) GetCategories(w http.ResponseWriter, r *http.Request) {
|
||||||
// RemoveCategory is the API handler to remove a category.
|
// RemoveCategory is the API handler to remove a category.
|
||||||
func (c *Controller) RemoveCategory(w http.ResponseWriter, r *http.Request) {
|
func (c *Controller) RemoveCategory(w http.ResponseWriter, r *http.Request) {
|
||||||
userID := request.UserID(r)
|
userID := request.UserID(r)
|
||||||
categoryID, err := request.IntParam(r, "categoryID")
|
categoryID := request.RouteInt64Param(r, "categoryID")
|
||||||
if err != nil {
|
|
||||||
json.BadRequest(w, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if !c.store.CategoryExists(userID, categoryID) {
|
if !c.store.CategoryExists(userID, categoryID) {
|
||||||
json.NotFound(w, errors.New("Category not found"))
|
json.NotFound(w, errors.New("Category not found"))
|
||||||
|
|
47
api/entry.go
47
api/entry.go
|
@ -17,17 +17,8 @@ import (
|
||||||
|
|
||||||
// GetFeedEntry is the API handler to get a single feed entry.
|
// GetFeedEntry is the API handler to get a single feed entry.
|
||||||
func (c *Controller) GetFeedEntry(w http.ResponseWriter, r *http.Request) {
|
func (c *Controller) GetFeedEntry(w http.ResponseWriter, r *http.Request) {
|
||||||
feedID, err := request.IntParam(r, "feedID")
|
feedID := request.RouteInt64Param(r, "feedID")
|
||||||
if err != nil {
|
entryID := request.RouteInt64Param(r, "entryID")
|
||||||
json.BadRequest(w, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
entryID, err := request.IntParam(r, "entryID")
|
|
||||||
if err != nil {
|
|
||||||
json.BadRequest(w, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
builder := c.store.NewEntryQueryBuilder(request.UserID(r))
|
builder := c.store.NewEntryQueryBuilder(request.UserID(r))
|
||||||
builder.WithFeedID(feedID)
|
builder.WithFeedID(feedID)
|
||||||
|
@ -49,12 +40,7 @@ func (c *Controller) GetFeedEntry(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
||||||
// GetEntry is the API handler to get a single entry.
|
// GetEntry is the API handler to get a single entry.
|
||||||
func (c *Controller) GetEntry(w http.ResponseWriter, r *http.Request) {
|
func (c *Controller) GetEntry(w http.ResponseWriter, r *http.Request) {
|
||||||
entryID, err := request.IntParam(r, "entryID")
|
entryID := request.RouteInt64Param(r, "entryID")
|
||||||
if err != nil {
|
|
||||||
json.BadRequest(w, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
builder := c.store.NewEntryQueryBuilder(request.UserID(r))
|
builder := c.store.NewEntryQueryBuilder(request.UserID(r))
|
||||||
builder.WithEntryID(entryID)
|
builder.WithEntryID(entryID)
|
||||||
|
|
||||||
|
@ -74,13 +60,9 @@ func (c *Controller) GetEntry(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
||||||
// GetFeedEntries is the API handler to get all feed entries.
|
// GetFeedEntries is the API handler to get all feed entries.
|
||||||
func (c *Controller) GetFeedEntries(w http.ResponseWriter, r *http.Request) {
|
func (c *Controller) GetFeedEntries(w http.ResponseWriter, r *http.Request) {
|
||||||
feedID, err := request.IntParam(r, "feedID")
|
feedID := request.RouteInt64Param(r, "feedID")
|
||||||
if err != nil {
|
|
||||||
json.BadRequest(w, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
status := request.QueryParam(r, "status", "")
|
status := request.QueryStringParam(r, "status", "")
|
||||||
if status != "" {
|
if status != "" {
|
||||||
if err := model.ValidateEntryStatus(status); err != nil {
|
if err := model.ValidateEntryStatus(status); err != nil {
|
||||||
json.BadRequest(w, err)
|
json.BadRequest(w, err)
|
||||||
|
@ -88,13 +70,13 @@ func (c *Controller) GetFeedEntries(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
order := request.QueryParam(r, "order", model.DefaultSortingOrder)
|
order := request.QueryStringParam(r, "order", model.DefaultSortingOrder)
|
||||||
if err := model.ValidateEntryOrder(order); err != nil {
|
if err := model.ValidateEntryOrder(order); err != nil {
|
||||||
json.BadRequest(w, err)
|
json.BadRequest(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
direction := request.QueryParam(r, "direction", model.DefaultSortingDirection)
|
direction := request.QueryStringParam(r, "direction", model.DefaultSortingDirection)
|
||||||
if err := model.ValidateDirection(direction); err != nil {
|
if err := model.ValidateDirection(direction); err != nil {
|
||||||
json.BadRequest(w, err)
|
json.BadRequest(w, err)
|
||||||
return
|
return
|
||||||
|
@ -133,7 +115,7 @@ func (c *Controller) GetFeedEntries(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
||||||
// GetEntries is the API handler to fetch entries.
|
// GetEntries is the API handler to fetch entries.
|
||||||
func (c *Controller) GetEntries(w http.ResponseWriter, r *http.Request) {
|
func (c *Controller) GetEntries(w http.ResponseWriter, r *http.Request) {
|
||||||
status := request.QueryParam(r, "status", "")
|
status := request.QueryStringParam(r, "status", "")
|
||||||
if status != "" {
|
if status != "" {
|
||||||
if err := model.ValidateEntryStatus(status); err != nil {
|
if err := model.ValidateEntryStatus(status); err != nil {
|
||||||
json.BadRequest(w, err)
|
json.BadRequest(w, err)
|
||||||
|
@ -141,13 +123,13 @@ func (c *Controller) GetEntries(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
order := request.QueryParam(r, "order", model.DefaultSortingOrder)
|
order := request.QueryStringParam(r, "order", model.DefaultSortingOrder)
|
||||||
if err := model.ValidateEntryOrder(order); err != nil {
|
if err := model.ValidateEntryOrder(order); err != nil {
|
||||||
json.BadRequest(w, err)
|
json.BadRequest(w, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
direction := request.QueryParam(r, "direction", model.DefaultSortingDirection)
|
direction := request.QueryStringParam(r, "direction", model.DefaultSortingDirection)
|
||||||
if err := model.ValidateDirection(direction); err != nil {
|
if err := model.ValidateDirection(direction); err != nil {
|
||||||
json.BadRequest(w, err)
|
json.BadRequest(w, err)
|
||||||
return
|
return
|
||||||
|
@ -206,12 +188,7 @@ func (c *Controller) SetEntryStatus(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
||||||
// ToggleBookmark is the API handler to toggle bookmark status.
|
// ToggleBookmark is the API handler to toggle bookmark status.
|
||||||
func (c *Controller) ToggleBookmark(w http.ResponseWriter, r *http.Request) {
|
func (c *Controller) ToggleBookmark(w http.ResponseWriter, r *http.Request) {
|
||||||
entryID, err := request.IntParam(r, "entryID")
|
entryID := request.RouteInt64Param(r, "entryID")
|
||||||
if err != nil {
|
|
||||||
json.BadRequest(w, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := c.store.ToggleBookmark(request.UserID(r), entryID); err != nil {
|
if err := c.store.ToggleBookmark(request.UserID(r), entryID); err != nil {
|
||||||
json.ServerError(w, err)
|
json.ServerError(w, err)
|
||||||
return
|
return
|
||||||
|
@ -245,7 +222,7 @@ func configureFilters(builder *storage.EntryQueryBuilder, r *http.Request) {
|
||||||
builder.WithStarred()
|
builder.WithStarred()
|
||||||
}
|
}
|
||||||
|
|
||||||
searchQuery := request.QueryParam(r, "search", "")
|
searchQuery := request.QueryStringParam(r, "search", "")
|
||||||
if searchQuery != "" {
|
if searchQuery != "" {
|
||||||
builder.WithSearchQuery(searchQuery)
|
builder.WithSearchQuery(searchQuery)
|
||||||
}
|
}
|
||||||
|
|
30
api/feed.go
30
api/feed.go
|
@ -65,12 +65,7 @@ func (c *Controller) CreateFeed(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
||||||
// RefreshFeed is the API handler to refresh a feed.
|
// RefreshFeed is the API handler to refresh a feed.
|
||||||
func (c *Controller) RefreshFeed(w http.ResponseWriter, r *http.Request) {
|
func (c *Controller) RefreshFeed(w http.ResponseWriter, r *http.Request) {
|
||||||
feedID, err := request.IntParam(r, "feedID")
|
feedID := request.RouteInt64Param(r, "feedID")
|
||||||
if err != nil {
|
|
||||||
json.BadRequest(w, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
userID := request.UserID(r)
|
userID := request.UserID(r)
|
||||||
|
|
||||||
if !c.store.FeedExists(userID, feedID) {
|
if !c.store.FeedExists(userID, feedID) {
|
||||||
|
@ -78,7 +73,7 @@ func (c *Controller) RefreshFeed(w http.ResponseWriter, r *http.Request) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
err = c.feedHandler.RefreshFeed(userID, feedID)
|
err := c.feedHandler.RefreshFeed(userID, feedID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
json.ServerError(w, err)
|
json.ServerError(w, err)
|
||||||
return
|
return
|
||||||
|
@ -89,12 +84,7 @@ func (c *Controller) RefreshFeed(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
||||||
// UpdateFeed is the API handler that is used to update a feed.
|
// UpdateFeed is the API handler that is used to update a feed.
|
||||||
func (c *Controller) UpdateFeed(w http.ResponseWriter, r *http.Request) {
|
func (c *Controller) UpdateFeed(w http.ResponseWriter, r *http.Request) {
|
||||||
feedID, err := request.IntParam(r, "feedID")
|
feedID := request.RouteInt64Param(r, "feedID")
|
||||||
if err != nil {
|
|
||||||
json.BadRequest(w, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
feedChanges, err := decodeFeedModificationPayload(r.Body)
|
feedChanges, err := decodeFeedModificationPayload(r.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
json.BadRequest(w, err)
|
json.BadRequest(w, err)
|
||||||
|
@ -148,12 +138,7 @@ func (c *Controller) GetFeeds(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
||||||
// GetFeed is the API handler to get a feed.
|
// GetFeed is the API handler to get a feed.
|
||||||
func (c *Controller) GetFeed(w http.ResponseWriter, r *http.Request) {
|
func (c *Controller) GetFeed(w http.ResponseWriter, r *http.Request) {
|
||||||
feedID, err := request.IntParam(r, "feedID")
|
feedID := request.RouteInt64Param(r, "feedID")
|
||||||
if err != nil {
|
|
||||||
json.BadRequest(w, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
feed, err := c.store.FeedByID(request.UserID(r), feedID)
|
feed, err := c.store.FeedByID(request.UserID(r), feedID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
json.ServerError(w, err)
|
json.ServerError(w, err)
|
||||||
|
@ -170,12 +155,7 @@ func (c *Controller) GetFeed(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
||||||
// RemoveFeed is the API handler to remove a feed.
|
// RemoveFeed is the API handler to remove a feed.
|
||||||
func (c *Controller) RemoveFeed(w http.ResponseWriter, r *http.Request) {
|
func (c *Controller) RemoveFeed(w http.ResponseWriter, r *http.Request) {
|
||||||
feedID, err := request.IntParam(r, "feedID")
|
feedID := request.RouteInt64Param(r, "feedID")
|
||||||
if err != nil {
|
|
||||||
json.BadRequest(w, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
userID := request.UserID(r)
|
userID := request.UserID(r)
|
||||||
|
|
||||||
if !c.store.FeedExists(userID, feedID) {
|
if !c.store.FeedExists(userID, feedID) {
|
||||||
|
|
|
@ -14,11 +14,7 @@ import (
|
||||||
|
|
||||||
// FeedIcon returns a feed icon.
|
// FeedIcon returns a feed icon.
|
||||||
func (c *Controller) FeedIcon(w http.ResponseWriter, r *http.Request) {
|
func (c *Controller) FeedIcon(w http.ResponseWriter, r *http.Request) {
|
||||||
feedID, err := request.IntParam(r, "feedID")
|
feedID := request.RouteInt64Param(r, "feedID")
|
||||||
if err != nil {
|
|
||||||
json.BadRequest(w, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if !c.store.HasIcon(feedID) {
|
if !c.store.HasIcon(feedID) {
|
||||||
json.NotFound(w, errors.New("This feed doesn't have any icon"))
|
json.NotFound(w, errors.New("This feed doesn't have any icon"))
|
||||||
|
|
23
api/user.go
23
api/user.go
|
@ -63,12 +63,7 @@ func (c *Controller) UpdateUser(w http.ResponseWriter, r *http.Request) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
userID, err := request.IntParam(r, "userID")
|
userID := request.RouteInt64Param(r, "userID")
|
||||||
if err != nil {
|
|
||||||
json.BadRequest(w, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
userChanges, err := decodeUserModificationPayload(r.Body)
|
userChanges, err := decodeUserModificationPayload(r.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
json.BadRequest(w, err)
|
json.BadRequest(w, err)
|
||||||
|
@ -124,12 +119,7 @@ func (c *Controller) UserByID(w http.ResponseWriter, r *http.Request) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
userID, err := request.IntParam(r, "userID")
|
userID := request.RouteInt64Param(r, "userID")
|
||||||
if err != nil {
|
|
||||||
json.BadRequest(w, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
user, err := c.store.UserByID(userID)
|
user, err := c.store.UserByID(userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
json.BadRequest(w, errors.New("Unable to fetch this user from the database"))
|
json.BadRequest(w, errors.New("Unable to fetch this user from the database"))
|
||||||
|
@ -152,7 +142,7 @@ func (c *Controller) UserByUsername(w http.ResponseWriter, r *http.Request) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
username := request.Param(r, "username", "")
|
username := request.RouteStringParam(r, "username")
|
||||||
user, err := c.store.UserByUsername(username)
|
user, err := c.store.UserByUsername(username)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
json.BadRequest(w, errors.New("Unable to fetch this user from the database"))
|
json.BadRequest(w, errors.New("Unable to fetch this user from the database"))
|
||||||
|
@ -174,12 +164,7 @@ func (c *Controller) RemoveUser(w http.ResponseWriter, r *http.Request) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
userID, err := request.IntParam(r, "userID")
|
userID := request.RouteInt64Param(r, "userID")
|
||||||
if err != nil {
|
|
||||||
json.BadRequest(w, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
user, err := c.store.UserByID(userID)
|
user, err := c.store.UserByID(userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
json.ServerError(w, err)
|
json.ServerError(w, err)
|
||||||
|
|
|
@ -356,7 +356,7 @@ func (c *Controller) handleItems(w http.ResponseWriter, r *http.Request) {
|
||||||
builder.WithOffset(maxID)
|
builder.WithOffset(maxID)
|
||||||
}
|
}
|
||||||
|
|
||||||
csvItemIDs := request.QueryParam(r, "with_ids", "")
|
csvItemIDs := request.QueryStringParam(r, "with_ids", "")
|
||||||
if csvItemIDs != "" {
|
if csvItemIDs != "" {
|
||||||
var itemIDs []int64
|
var itemIDs []int64
|
||||||
|
|
||||||
|
|
38
http/request/client_ip.go
Normal file
38
http/request/client_ip.go
Normal file
|
@ -0,0 +1,38 @@
|
||||||
|
// Copyright 2018 Frédéric Guillot. All rights reserved.
|
||||||
|
// Use of this source code is governed by the Apache 2.0
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package request // import "miniflux.app/http/request"
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// FindClientIP returns client real IP address.
|
||||||
|
func FindClientIP(r *http.Request) string {
|
||||||
|
headers := []string{"X-Forwarded-For", "X-Real-Ip"}
|
||||||
|
for _, header := range headers {
|
||||||
|
value := r.Header.Get(header)
|
||||||
|
|
||||||
|
if value != "" {
|
||||||
|
addresses := strings.Split(value, ",")
|
||||||
|
address := strings.TrimSpace(addresses[0])
|
||||||
|
|
||||||
|
if net.ParseIP(address) != nil {
|
||||||
|
return address
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fallback to TCP/IP source IP address.
|
||||||
|
var remoteIP string
|
||||||
|
if strings.ContainsRune(r.RemoteAddr, ':') {
|
||||||
|
remoteIP, _, _ = net.SplitHostPort(r.RemoteAddr)
|
||||||
|
} else {
|
||||||
|
remoteIP = r.RemoteAddr
|
||||||
|
}
|
||||||
|
|
||||||
|
return remoteIP
|
||||||
|
}
|
|
@ -9,7 +9,7 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestRealIPWithoutHeaders(t *testing.T) {
|
func TestFindClientIPWithoutHeaders(t *testing.T) {
|
||||||
r := &http.Request{RemoteAddr: "192.168.0.1:4242"}
|
r := &http.Request{RemoteAddr: "192.168.0.1:4242"}
|
||||||
if ip := FindClientIP(r); ip != "192.168.0.1" {
|
if ip := FindClientIP(r); ip != "192.168.0.1" {
|
||||||
t.Fatalf(`Unexpected result, got: %q`, ip)
|
t.Fatalf(`Unexpected result, got: %q`, ip)
|
||||||
|
@ -21,7 +21,7 @@ func TestRealIPWithoutHeaders(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRealIPWithXFFHeader(t *testing.T) {
|
func TestFindClientIPWithXFFHeader(t *testing.T) {
|
||||||
// Test with multiple IPv4 addresses.
|
// Test with multiple IPv4 addresses.
|
||||||
headers := http.Header{}
|
headers := http.Header{}
|
||||||
headers.Set("X-Forwarded-For", "203.0.113.195, 70.41.3.18, 150.172.238.178")
|
headers.Set("X-Forwarded-For", "203.0.113.195, 70.41.3.18, 150.172.238.178")
|
||||||
|
@ -59,7 +59,7 @@ func TestRealIPWithXFFHeader(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRealIPWithXRealIPHeader(t *testing.T) {
|
func TestClientIPWithXRealIPHeader(t *testing.T) {
|
||||||
headers := http.Header{}
|
headers := http.Header{}
|
||||||
headers.Set("X-Real-Ip", "192.168.122.1")
|
headers.Set("X-Real-Ip", "192.168.122.1")
|
||||||
r := &http.Request{RemoteAddr: "192.168.0.1:4242", Header: headers}
|
r := &http.Request{RemoteAddr: "192.168.0.1:4242", Header: headers}
|
||||||
|
@ -69,7 +69,7 @@ func TestRealIPWithXRealIPHeader(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRealIPWithBothHeaders(t *testing.T) {
|
func TestClientIPWithBothHeaders(t *testing.T) {
|
||||||
headers := http.Header{}
|
headers := http.Header{}
|
||||||
headers.Set("X-Forwarded-For", "203.0.113.195, 70.41.3.18, 150.172.238.178")
|
headers.Set("X-Forwarded-For", "203.0.113.195, 70.41.3.18, 150.172.238.178")
|
||||||
headers.Set("X-Real-Ip", "192.168.122.1")
|
headers.Set("X-Real-Ip", "192.168.122.1")
|
|
@ -111,7 +111,12 @@ func ClientIP(r *http.Request) string {
|
||||||
|
|
||||||
func getContextStringValue(r *http.Request, key ContextKey) string {
|
func getContextStringValue(r *http.Request, key ContextKey) string {
|
||||||
if v := r.Context().Value(key); v != nil {
|
if v := r.Context().Value(key); v != nil {
|
||||||
return v.(string)
|
value, valid := v.(string)
|
||||||
|
if !valid {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
return value
|
||||||
}
|
}
|
||||||
|
|
||||||
return ""
|
return ""
|
||||||
|
@ -119,7 +124,12 @@ func getContextStringValue(r *http.Request, key ContextKey) string {
|
||||||
|
|
||||||
func getContextBoolValue(r *http.Request, key ContextKey) bool {
|
func getContextBoolValue(r *http.Request, key ContextKey) bool {
|
||||||
if v := r.Context().Value(key); v != nil {
|
if v := r.Context().Value(key); v != nil {
|
||||||
return v.(bool)
|
value, valid := v.(bool)
|
||||||
|
if !valid {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return value
|
||||||
}
|
}
|
||||||
|
|
||||||
return false
|
return false
|
||||||
|
@ -127,7 +137,12 @@ func getContextBoolValue(r *http.Request, key ContextKey) bool {
|
||||||
|
|
||||||
func getContextInt64Value(r *http.Request, key ContextKey) int64 {
|
func getContextInt64Value(r *http.Request, key ContextKey) int64 {
|
||||||
if v := r.Context().Value(key); v != nil {
|
if v := r.Context().Value(key); v != nil {
|
||||||
return v.(int64)
|
value, valid := v.(int64)
|
||||||
|
if !valid {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
return value
|
||||||
}
|
}
|
||||||
|
|
||||||
return 0
|
return 0
|
||||||
|
|
436
http/request/context_test.go
Normal file
436
http/request/context_test.go
Normal file
|
@ -0,0 +1,436 @@
|
||||||
|
// Copyright 2018 Frédéric Guillot. All rights reserved.
|
||||||
|
// Use of this source code is governed by the Apache 2.0
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package request // import "miniflux.app/http/request"
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestContextStringValue(t *testing.T) {
|
||||||
|
r, _ := http.NewRequest("GET", "http://example.org", nil)
|
||||||
|
ctx := r.Context()
|
||||||
|
ctx = context.WithValue(ctx, ClientIPContextKey, "IP")
|
||||||
|
r = r.WithContext(ctx)
|
||||||
|
|
||||||
|
result := getContextStringValue(r, ClientIPContextKey)
|
||||||
|
expected := "IP"
|
||||||
|
|
||||||
|
if result != expected {
|
||||||
|
t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestContextStringValueWithInvalidType(t *testing.T) {
|
||||||
|
r, _ := http.NewRequest("GET", "http://example.org", nil)
|
||||||
|
ctx := r.Context()
|
||||||
|
ctx = context.WithValue(ctx, ClientIPContextKey, 0)
|
||||||
|
r = r.WithContext(ctx)
|
||||||
|
|
||||||
|
result := getContextStringValue(r, ClientIPContextKey)
|
||||||
|
expected := ""
|
||||||
|
|
||||||
|
if result != expected {
|
||||||
|
t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestContextStringValueWhenUnset(t *testing.T) {
|
||||||
|
r, _ := http.NewRequest("GET", "http://example.org", nil)
|
||||||
|
|
||||||
|
result := getContextStringValue(r, ClientIPContextKey)
|
||||||
|
expected := ""
|
||||||
|
|
||||||
|
if result != expected {
|
||||||
|
t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestContextBoolValue(t *testing.T) {
|
||||||
|
r, _ := http.NewRequest("GET", "http://example.org", nil)
|
||||||
|
ctx := r.Context()
|
||||||
|
ctx = context.WithValue(ctx, IsAdminUserContextKey, true)
|
||||||
|
r = r.WithContext(ctx)
|
||||||
|
|
||||||
|
result := getContextBoolValue(r, IsAdminUserContextKey)
|
||||||
|
expected := true
|
||||||
|
|
||||||
|
if result != expected {
|
||||||
|
t.Errorf(`Unexpected context value, got %v instead of %v`, result, expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestContextBoolValueWithInvalidType(t *testing.T) {
|
||||||
|
r, _ := http.NewRequest("GET", "http://example.org", nil)
|
||||||
|
ctx := r.Context()
|
||||||
|
ctx = context.WithValue(ctx, IsAdminUserContextKey, "invalid")
|
||||||
|
r = r.WithContext(ctx)
|
||||||
|
|
||||||
|
result := getContextBoolValue(r, IsAdminUserContextKey)
|
||||||
|
expected := false
|
||||||
|
|
||||||
|
if result != expected {
|
||||||
|
t.Errorf(`Unexpected context value, got %v instead of %v`, result, expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestContextBoolValueWhenUnset(t *testing.T) {
|
||||||
|
r, _ := http.NewRequest("GET", "http://example.org", nil)
|
||||||
|
|
||||||
|
result := getContextBoolValue(r, IsAdminUserContextKey)
|
||||||
|
expected := false
|
||||||
|
|
||||||
|
if result != expected {
|
||||||
|
t.Errorf(`Unexpected context value, got %v instead of %v`, result, expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestContextInt64Value(t *testing.T) {
|
||||||
|
r, _ := http.NewRequest("GET", "http://example.org", nil)
|
||||||
|
ctx := r.Context()
|
||||||
|
ctx = context.WithValue(ctx, UserIDContextKey, int64(1234))
|
||||||
|
r = r.WithContext(ctx)
|
||||||
|
|
||||||
|
result := getContextInt64Value(r, UserIDContextKey)
|
||||||
|
expected := int64(1234)
|
||||||
|
|
||||||
|
if result != expected {
|
||||||
|
t.Errorf(`Unexpected context value, got %d instead of %d`, result, expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestContextInt64ValueWithInvalidType(t *testing.T) {
|
||||||
|
r, _ := http.NewRequest("GET", "http://example.org", nil)
|
||||||
|
ctx := r.Context()
|
||||||
|
ctx = context.WithValue(ctx, UserIDContextKey, "invalid")
|
||||||
|
r = r.WithContext(ctx)
|
||||||
|
|
||||||
|
result := getContextInt64Value(r, UserIDContextKey)
|
||||||
|
expected := int64(0)
|
||||||
|
|
||||||
|
if result != expected {
|
||||||
|
t.Errorf(`Unexpected context value, got %d instead of %d`, result, expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestContextInt64ValueWhenUnset(t *testing.T) {
|
||||||
|
r, _ := http.NewRequest("GET", "http://example.org", nil)
|
||||||
|
|
||||||
|
result := getContextInt64Value(r, UserIDContextKey)
|
||||||
|
expected := int64(0)
|
||||||
|
|
||||||
|
if result != expected {
|
||||||
|
t.Errorf(`Unexpected context value, got %d instead of %d`, result, expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsAdmin(t *testing.T) {
|
||||||
|
r, _ := http.NewRequest("GET", "http://example.org", nil)
|
||||||
|
|
||||||
|
result := IsAdminUser(r)
|
||||||
|
expected := false
|
||||||
|
|
||||||
|
if result != expected {
|
||||||
|
t.Errorf(`Unexpected context value, got %v instead of %v`, result, expected)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := r.Context()
|
||||||
|
ctx = context.WithValue(ctx, IsAdminUserContextKey, true)
|
||||||
|
r = r.WithContext(ctx)
|
||||||
|
|
||||||
|
result = IsAdminUser(r)
|
||||||
|
expected = true
|
||||||
|
|
||||||
|
if result != expected {
|
||||||
|
t.Errorf(`Unexpected context value, got %v instead of %v`, result, expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsAuthenticated(t *testing.T) {
|
||||||
|
r, _ := http.NewRequest("GET", "http://example.org", nil)
|
||||||
|
|
||||||
|
result := IsAuthenticated(r)
|
||||||
|
expected := false
|
||||||
|
|
||||||
|
if result != expected {
|
||||||
|
t.Errorf(`Unexpected context value, got %v instead of %v`, result, expected)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := r.Context()
|
||||||
|
ctx = context.WithValue(ctx, IsAuthenticatedContextKey, true)
|
||||||
|
r = r.WithContext(ctx)
|
||||||
|
|
||||||
|
result = IsAuthenticated(r)
|
||||||
|
expected = true
|
||||||
|
|
||||||
|
if result != expected {
|
||||||
|
t.Errorf(`Unexpected context value, got %v instead of %v`, result, expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUserID(t *testing.T) {
|
||||||
|
r, _ := http.NewRequest("GET", "http://example.org", nil)
|
||||||
|
|
||||||
|
result := UserID(r)
|
||||||
|
expected := int64(0)
|
||||||
|
|
||||||
|
if result != expected {
|
||||||
|
t.Errorf(`Unexpected context value, got %v instead of %v`, result, expected)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := r.Context()
|
||||||
|
ctx = context.WithValue(ctx, UserIDContextKey, int64(123))
|
||||||
|
r = r.WithContext(ctx)
|
||||||
|
|
||||||
|
result = UserID(r)
|
||||||
|
expected = int64(123)
|
||||||
|
|
||||||
|
if result != expected {
|
||||||
|
t.Errorf(`Unexpected context value, got %v instead of %v`, result, expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUserTimezone(t *testing.T) {
|
||||||
|
r, _ := http.NewRequest("GET", "http://example.org", nil)
|
||||||
|
|
||||||
|
result := UserTimezone(r)
|
||||||
|
expected := "UTC"
|
||||||
|
|
||||||
|
if result != expected {
|
||||||
|
t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := r.Context()
|
||||||
|
ctx = context.WithValue(ctx, UserTimezoneContextKey, "Europe/Paris")
|
||||||
|
r = r.WithContext(ctx)
|
||||||
|
|
||||||
|
result = UserTimezone(r)
|
||||||
|
expected = "Europe/Paris"
|
||||||
|
|
||||||
|
if result != expected {
|
||||||
|
t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUserLanguage(t *testing.T) {
|
||||||
|
r, _ := http.NewRequest("GET", "http://example.org", nil)
|
||||||
|
|
||||||
|
result := UserLanguage(r)
|
||||||
|
expected := "en_US"
|
||||||
|
|
||||||
|
if result != expected {
|
||||||
|
t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := r.Context()
|
||||||
|
ctx = context.WithValue(ctx, UserLanguageContextKey, "fr_FR")
|
||||||
|
r = r.WithContext(ctx)
|
||||||
|
|
||||||
|
result = UserLanguage(r)
|
||||||
|
expected = "fr_FR"
|
||||||
|
|
||||||
|
if result != expected {
|
||||||
|
t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUserTheme(t *testing.T) {
|
||||||
|
r, _ := http.NewRequest("GET", "http://example.org", nil)
|
||||||
|
|
||||||
|
result := UserTheme(r)
|
||||||
|
expected := "default"
|
||||||
|
|
||||||
|
if result != expected {
|
||||||
|
t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := r.Context()
|
||||||
|
ctx = context.WithValue(ctx, UserThemeContextKey, "black")
|
||||||
|
r = r.WithContext(ctx)
|
||||||
|
|
||||||
|
result = UserTheme(r)
|
||||||
|
expected = "black"
|
||||||
|
|
||||||
|
if result != expected {
|
||||||
|
t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCSRF(t *testing.T) {
|
||||||
|
r, _ := http.NewRequest("GET", "http://example.org", nil)
|
||||||
|
|
||||||
|
result := CSRF(r)
|
||||||
|
expected := ""
|
||||||
|
|
||||||
|
if result != expected {
|
||||||
|
t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := r.Context()
|
||||||
|
ctx = context.WithValue(ctx, CSRFContextKey, "secret")
|
||||||
|
r = r.WithContext(ctx)
|
||||||
|
|
||||||
|
result = CSRF(r)
|
||||||
|
expected = "secret"
|
||||||
|
|
||||||
|
if result != expected {
|
||||||
|
t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSessionID(t *testing.T) {
|
||||||
|
r, _ := http.NewRequest("GET", "http://example.org", nil)
|
||||||
|
|
||||||
|
result := SessionID(r)
|
||||||
|
expected := ""
|
||||||
|
|
||||||
|
if result != expected {
|
||||||
|
t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := r.Context()
|
||||||
|
ctx = context.WithValue(ctx, SessionIDContextKey, "id")
|
||||||
|
r = r.WithContext(ctx)
|
||||||
|
|
||||||
|
result = SessionID(r)
|
||||||
|
expected = "id"
|
||||||
|
|
||||||
|
if result != expected {
|
||||||
|
t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUserSessionToken(t *testing.T) {
|
||||||
|
r, _ := http.NewRequest("GET", "http://example.org", nil)
|
||||||
|
|
||||||
|
result := UserSessionToken(r)
|
||||||
|
expected := ""
|
||||||
|
|
||||||
|
if result != expected {
|
||||||
|
t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := r.Context()
|
||||||
|
ctx = context.WithValue(ctx, UserSessionTokenContextKey, "token")
|
||||||
|
r = r.WithContext(ctx)
|
||||||
|
|
||||||
|
result = UserSessionToken(r)
|
||||||
|
expected = "token"
|
||||||
|
|
||||||
|
if result != expected {
|
||||||
|
t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOAuth2State(t *testing.T) {
|
||||||
|
r, _ := http.NewRequest("GET", "http://example.org", nil)
|
||||||
|
|
||||||
|
result := OAuth2State(r)
|
||||||
|
expected := ""
|
||||||
|
|
||||||
|
if result != expected {
|
||||||
|
t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := r.Context()
|
||||||
|
ctx = context.WithValue(ctx, OAuth2StateContextKey, "state")
|
||||||
|
r = r.WithContext(ctx)
|
||||||
|
|
||||||
|
result = OAuth2State(r)
|
||||||
|
expected = "state"
|
||||||
|
|
||||||
|
if result != expected {
|
||||||
|
t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFlashMessage(t *testing.T) {
|
||||||
|
r, _ := http.NewRequest("GET", "http://example.org", nil)
|
||||||
|
|
||||||
|
result := FlashMessage(r)
|
||||||
|
expected := ""
|
||||||
|
|
||||||
|
if result != expected {
|
||||||
|
t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := r.Context()
|
||||||
|
ctx = context.WithValue(ctx, FlashMessageContextKey, "message")
|
||||||
|
r = r.WithContext(ctx)
|
||||||
|
|
||||||
|
result = FlashMessage(r)
|
||||||
|
expected = "message"
|
||||||
|
|
||||||
|
if result != expected {
|
||||||
|
t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFlashErrorMessage(t *testing.T) {
|
||||||
|
r, _ := http.NewRequest("GET", "http://example.org", nil)
|
||||||
|
|
||||||
|
result := FlashErrorMessage(r)
|
||||||
|
expected := ""
|
||||||
|
|
||||||
|
if result != expected {
|
||||||
|
t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := r.Context()
|
||||||
|
ctx = context.WithValue(ctx, FlashErrorMessageContextKey, "error message")
|
||||||
|
r = r.WithContext(ctx)
|
||||||
|
|
||||||
|
result = FlashErrorMessage(r)
|
||||||
|
expected = "error message"
|
||||||
|
|
||||||
|
if result != expected {
|
||||||
|
t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPocketRequestToken(t *testing.T) {
|
||||||
|
r, _ := http.NewRequest("GET", "http://example.org", nil)
|
||||||
|
|
||||||
|
result := PocketRequestToken(r)
|
||||||
|
expected := ""
|
||||||
|
|
||||||
|
if result != expected {
|
||||||
|
t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := r.Context()
|
||||||
|
ctx = context.WithValue(ctx, PocketRequestTokenContextKey, "request token")
|
||||||
|
r = r.WithContext(ctx)
|
||||||
|
|
||||||
|
result = PocketRequestToken(r)
|
||||||
|
expected = "request token"
|
||||||
|
|
||||||
|
if result != expected {
|
||||||
|
t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClientIP(t *testing.T) {
|
||||||
|
r, _ := http.NewRequest("GET", "http://example.org", nil)
|
||||||
|
|
||||||
|
result := ClientIP(r)
|
||||||
|
expected := ""
|
||||||
|
|
||||||
|
if result != expected {
|
||||||
|
t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := r.Context()
|
||||||
|
ctx = context.WithValue(ctx, ClientIPContextKey, "127.0.0.1")
|
||||||
|
r = r.WithContext(ctx)
|
||||||
|
|
||||||
|
result = ClientIP(r)
|
||||||
|
expected = "127.0.0.1"
|
||||||
|
|
||||||
|
if result != expected {
|
||||||
|
t.Errorf(`Unexpected context value, got %q instead of %q`, result, expected)
|
||||||
|
}
|
||||||
|
}
|
17
http/request/cookie.go
Normal file
17
http/request/cookie.go
Normal file
|
@ -0,0 +1,17 @@
|
||||||
|
// Copyright 2018 Frédéric Guillot. All rights reserved.
|
||||||
|
// Use of this source code is governed by the Apache 2.0
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package request // import "miniflux.app/http/request"
|
||||||
|
|
||||||
|
import "net/http"
|
||||||
|
|
||||||
|
// CookieValue returns the cookie value.
|
||||||
|
func CookieValue(r *http.Request, name string) string {
|
||||||
|
cookie, err := r.Cookie(name)
|
||||||
|
if err == http.ErrNoCookie {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
return cookie.Value
|
||||||
|
}
|
33
http/request/cookie_test.go
Normal file
33
http/request/cookie_test.go
Normal file
|
@ -0,0 +1,33 @@
|
||||||
|
// Copyright 2018 Frédéric Guillot. All rights reserved.
|
||||||
|
// Use of this source code is governed by the Apache 2.0
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package request // import "miniflux.app/http/request"
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestGetCookieValue(t *testing.T) {
|
||||||
|
r, _ := http.NewRequest("GET", "http://example.org", nil)
|
||||||
|
r.AddCookie(&http.Cookie{Value: "cookie_value", Name: "my_cookie"})
|
||||||
|
|
||||||
|
result := CookieValue(r, "my_cookie")
|
||||||
|
expected := "cookie_value"
|
||||||
|
|
||||||
|
if result != expected {
|
||||||
|
t.Errorf(`Unexpected cookie value, got %q instead of %q`, result, expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetCookieValueWhenUnset(t *testing.T) {
|
||||||
|
r, _ := http.NewRequest("GET", "http://example.org", nil)
|
||||||
|
|
||||||
|
result := CookieValue(r, "my_cookie")
|
||||||
|
expected := ""
|
||||||
|
|
||||||
|
if result != expected {
|
||||||
|
t.Errorf(`Unexpected cookie value, got %q instead of %q`, result, expected)
|
||||||
|
}
|
||||||
|
}
|
84
http/request/params.go
Normal file
84
http/request/params.go
Normal file
|
@ -0,0 +1,84 @@
|
||||||
|
// Copyright 2018 Frédéric Guillot. All rights reserved.
|
||||||
|
// Use of this source code is governed by the Apache 2.0
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package request // import "miniflux.app/http/request"
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"strconv"
|
||||||
|
|
||||||
|
"github.com/gorilla/mux"
|
||||||
|
)
|
||||||
|
|
||||||
|
// FormInt64Value returns a form value as integer.
|
||||||
|
func FormInt64Value(r *http.Request, param string) int64 {
|
||||||
|
value := r.FormValue(param)
|
||||||
|
integer, err := strconv.ParseInt(value, 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
return integer
|
||||||
|
}
|
||||||
|
|
||||||
|
// RouteInt64Param returns an URL route parameter as int64.
|
||||||
|
func RouteInt64Param(r *http.Request, param string) int64 {
|
||||||
|
vars := mux.Vars(r)
|
||||||
|
value, err := strconv.ParseInt(vars[param], 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
if value < 0 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
|
||||||
|
// RouteStringParam returns a URL route parameter as string.
|
||||||
|
func RouteStringParam(r *http.Request, param string) string {
|
||||||
|
vars := mux.Vars(r)
|
||||||
|
return vars[param]
|
||||||
|
}
|
||||||
|
|
||||||
|
// QueryStringParam returns a query string parameter as string.
|
||||||
|
func QueryStringParam(r *http.Request, param, defaultValue string) string {
|
||||||
|
value := r.URL.Query().Get(param)
|
||||||
|
if value == "" {
|
||||||
|
value = defaultValue
|
||||||
|
}
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
|
||||||
|
// QueryIntParam returns a query string parameter as integer.
|
||||||
|
func QueryIntParam(r *http.Request, param string, defaultValue int) int {
|
||||||
|
return int(QueryInt64Param(r, param, int64(defaultValue)))
|
||||||
|
}
|
||||||
|
|
||||||
|
// QueryInt64Param returns a query string parameter as int64.
|
||||||
|
func QueryInt64Param(r *http.Request, param string, defaultValue int64) int64 {
|
||||||
|
value := r.URL.Query().Get(param)
|
||||||
|
if value == "" {
|
||||||
|
return defaultValue
|
||||||
|
}
|
||||||
|
|
||||||
|
val, err := strconv.ParseInt(value, 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
return defaultValue
|
||||||
|
}
|
||||||
|
|
||||||
|
if val < 0 {
|
||||||
|
return defaultValue
|
||||||
|
}
|
||||||
|
|
||||||
|
return val
|
||||||
|
}
|
||||||
|
|
||||||
|
// HasQueryParam checks if the query string contains the given parameter.
|
||||||
|
func HasQueryParam(r *http.Request, param string) bool {
|
||||||
|
values := r.URL.Query()
|
||||||
|
_, ok := values[param]
|
||||||
|
return ok
|
||||||
|
}
|
215
http/request/params_test.go
Normal file
215
http/request/params_test.go
Normal file
|
@ -0,0 +1,215 @@
|
||||||
|
// Copyright 2018 Frédéric Guillot. All rights reserved.
|
||||||
|
// Use of this source code is governed by the Apache 2.0
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package request // import "miniflux.app/http/request"
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"net/url"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gorilla/mux"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestFormInt64Value(t *testing.T) {
|
||||||
|
f := url.Values{}
|
||||||
|
f.Set("integer value", "42")
|
||||||
|
f.Set("invalid value", "invalid integer")
|
||||||
|
|
||||||
|
r := &http.Request{Form: f}
|
||||||
|
|
||||||
|
result := FormInt64Value(r, "integer value")
|
||||||
|
expected := int64(42)
|
||||||
|
|
||||||
|
if result != expected {
|
||||||
|
t.Errorf(`Unexpected result, got %d instead of %d`, result, expected)
|
||||||
|
}
|
||||||
|
|
||||||
|
result = FormInt64Value(r, "invalid value")
|
||||||
|
expected = int64(0)
|
||||||
|
|
||||||
|
if result != expected {
|
||||||
|
t.Errorf(`Unexpected result, got %d instead of %d`, result, expected)
|
||||||
|
}
|
||||||
|
|
||||||
|
result = FormInt64Value(r, "missing value")
|
||||||
|
expected = int64(0)
|
||||||
|
|
||||||
|
if result != expected {
|
||||||
|
t.Errorf(`Unexpected result, got %d instead of %d`, result, expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRouteStringParam(t *testing.T) {
|
||||||
|
router := mux.NewRouter()
|
||||||
|
router.HandleFunc("/route/{variable}/index", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
result := RouteStringParam(r, "variable")
|
||||||
|
expected := "value"
|
||||||
|
|
||||||
|
if result != expected {
|
||||||
|
t.Errorf(`Unexpected result, got %q instead of %q`, result, expected)
|
||||||
|
}
|
||||||
|
|
||||||
|
result = RouteStringParam(r, "missing variable")
|
||||||
|
expected = ""
|
||||||
|
|
||||||
|
if result != expected {
|
||||||
|
t.Errorf(`Unexpected result, got %q instead of %q`, result, expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
r, err := http.NewRequest("GET", "/route/value/index", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(w, r)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRouteInt64Param(t *testing.T) {
|
||||||
|
router := mux.NewRouter()
|
||||||
|
router.HandleFunc("/a/{variable1}/b/{variable2}/c/{variable3}", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
result := RouteInt64Param(r, "variable1")
|
||||||
|
expected := int64(42)
|
||||||
|
|
||||||
|
if result != expected {
|
||||||
|
t.Errorf(`Unexpected result, got %d instead of %d`, result, expected)
|
||||||
|
}
|
||||||
|
|
||||||
|
result = RouteInt64Param(r, "missing variable")
|
||||||
|
expected = 0
|
||||||
|
|
||||||
|
if result != expected {
|
||||||
|
t.Errorf(`Unexpected result, got %d instead of %d`, result, expected)
|
||||||
|
}
|
||||||
|
|
||||||
|
result = RouteInt64Param(r, "variable2")
|
||||||
|
expected = 0
|
||||||
|
|
||||||
|
if result != expected {
|
||||||
|
t.Errorf(`Unexpected result, got %d instead of %d`, result, expected)
|
||||||
|
}
|
||||||
|
|
||||||
|
result = RouteInt64Param(r, "variable3")
|
||||||
|
expected = 0
|
||||||
|
|
||||||
|
if result != expected {
|
||||||
|
t.Errorf(`Unexpected result, got %d instead of %d`, result, expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
r, err := http.NewRequest("GET", "/a/42/b/not-int/c/-10", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(w, r)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQueryStringParam(t *testing.T) {
|
||||||
|
u, _ := url.Parse("http://example.org/?key=value")
|
||||||
|
r := &http.Request{URL: u}
|
||||||
|
|
||||||
|
result := QueryStringParam(r, "key", "fallback")
|
||||||
|
expected := "value"
|
||||||
|
|
||||||
|
if result != expected {
|
||||||
|
t.Errorf(`Unexpected result, got %q instead of %q`, result, expected)
|
||||||
|
}
|
||||||
|
|
||||||
|
result = QueryStringParam(r, "missing key", "fallback")
|
||||||
|
expected = "fallback"
|
||||||
|
|
||||||
|
if result != expected {
|
||||||
|
t.Errorf(`Unexpected result, got %q instead of %q`, result, expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQueryIntParam(t *testing.T) {
|
||||||
|
u, _ := url.Parse("http://example.org/?key=42&invalid=value&negative=-5")
|
||||||
|
r := &http.Request{URL: u}
|
||||||
|
|
||||||
|
result := QueryIntParam(r, "key", 84)
|
||||||
|
expected := 42
|
||||||
|
|
||||||
|
if result != expected {
|
||||||
|
t.Errorf(`Unexpected result, got %d instead of %d`, result, expected)
|
||||||
|
}
|
||||||
|
|
||||||
|
result = QueryIntParam(r, "missing key", 84)
|
||||||
|
expected = 84
|
||||||
|
|
||||||
|
if result != expected {
|
||||||
|
t.Errorf(`Unexpected result, got %d instead of %d`, result, expected)
|
||||||
|
}
|
||||||
|
|
||||||
|
result = QueryIntParam(r, "negative", 69)
|
||||||
|
expected = 69
|
||||||
|
|
||||||
|
if result != expected {
|
||||||
|
t.Errorf(`Unexpected result, got %d instead of %d`, result, expected)
|
||||||
|
}
|
||||||
|
|
||||||
|
result = QueryIntParam(r, "invalid", 99)
|
||||||
|
expected = 99
|
||||||
|
|
||||||
|
if result != expected {
|
||||||
|
t.Errorf(`Unexpected result, got %d instead of %d`, result, expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQueryInt64Param(t *testing.T) {
|
||||||
|
u, _ := url.Parse("http://example.org/?key=42&invalid=value&negative=-5")
|
||||||
|
r := &http.Request{URL: u}
|
||||||
|
|
||||||
|
result := QueryInt64Param(r, "key", int64(84))
|
||||||
|
expected := int64(42)
|
||||||
|
|
||||||
|
if result != expected {
|
||||||
|
t.Errorf(`Unexpected result, got %d instead of %d`, result, expected)
|
||||||
|
}
|
||||||
|
|
||||||
|
result = QueryInt64Param(r, "missing key", int64(84))
|
||||||
|
expected = int64(84)
|
||||||
|
|
||||||
|
if result != expected {
|
||||||
|
t.Errorf(`Unexpected result, got %d instead of %d`, result, expected)
|
||||||
|
}
|
||||||
|
|
||||||
|
result = QueryInt64Param(r, "invalid", int64(69))
|
||||||
|
expected = int64(69)
|
||||||
|
|
||||||
|
if result != expected {
|
||||||
|
t.Errorf(`Unexpected result, got %d instead of %d`, result, expected)
|
||||||
|
}
|
||||||
|
|
||||||
|
result = QueryInt64Param(r, "invalid", int64(99))
|
||||||
|
expected = int64(99)
|
||||||
|
|
||||||
|
if result != expected {
|
||||||
|
t.Errorf(`Unexpected result, got %d instead of %d`, result, expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHasQueryParam(t *testing.T) {
|
||||||
|
u, _ := url.Parse("http://example.org/?key=42")
|
||||||
|
r := &http.Request{URL: u}
|
||||||
|
|
||||||
|
result := HasQueryParam(r, "key")
|
||||||
|
expected := true
|
||||||
|
|
||||||
|
if result != expected {
|
||||||
|
t.Errorf(`Unexpected result, got %v instead of %v`, result, expected)
|
||||||
|
}
|
||||||
|
|
||||||
|
result = HasQueryParam(r, "missing key")
|
||||||
|
expected = false
|
||||||
|
|
||||||
|
if result != expected {
|
||||||
|
t.Errorf(`Unexpected result, got %v instead of %v`, result, expected)
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,128 +0,0 @@
|
||||||
// Copyright 2018 Frédéric Guillot. All rights reserved.
|
|
||||||
// Use of this source code is governed by the Apache 2.0
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
package request // import "miniflux.app/http/request"
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"net"
|
|
||||||
"net/http"
|
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/gorilla/mux"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Cookie returns the cookie value.
|
|
||||||
func Cookie(r *http.Request, name string) string {
|
|
||||||
cookie, err := r.Cookie(name)
|
|
||||||
if err == http.ErrNoCookie {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
return cookie.Value
|
|
||||||
}
|
|
||||||
|
|
||||||
// FormInt64Value returns a form value as integer.
|
|
||||||
func FormInt64Value(r *http.Request, param string) int64 {
|
|
||||||
value := r.FormValue(param)
|
|
||||||
integer, err := strconv.ParseInt(value, 10, 64)
|
|
||||||
if err != nil {
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
|
|
||||||
return integer
|
|
||||||
}
|
|
||||||
|
|
||||||
// IntParam returns an URL route parameter as integer.
|
|
||||||
func IntParam(r *http.Request, param string) (int64, error) {
|
|
||||||
vars := mux.Vars(r)
|
|
||||||
value, err := strconv.Atoi(vars[param])
|
|
||||||
if err != nil {
|
|
||||||
return 0, fmt.Errorf("request: %s parameter is not an integer", param)
|
|
||||||
}
|
|
||||||
|
|
||||||
if value < 0 {
|
|
||||||
return 0, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return int64(value), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Param returns an URL route parameter as string.
|
|
||||||
func Param(r *http.Request, param, defaultValue string) string {
|
|
||||||
vars := mux.Vars(r)
|
|
||||||
value := vars[param]
|
|
||||||
if value == "" {
|
|
||||||
value = defaultValue
|
|
||||||
}
|
|
||||||
return value
|
|
||||||
}
|
|
||||||
|
|
||||||
// QueryParam returns a querystring parameter as string.
|
|
||||||
func QueryParam(r *http.Request, param, defaultValue string) string {
|
|
||||||
value := r.URL.Query().Get(param)
|
|
||||||
if value == "" {
|
|
||||||
value = defaultValue
|
|
||||||
}
|
|
||||||
return value
|
|
||||||
}
|
|
||||||
|
|
||||||
// QueryIntParam returns a querystring parameter as integer.
|
|
||||||
func QueryIntParam(r *http.Request, param string, defaultValue int) int {
|
|
||||||
return int(QueryInt64Param(r, param, int64(defaultValue)))
|
|
||||||
}
|
|
||||||
|
|
||||||
// QueryInt64Param returns a querystring parameter as int64.
|
|
||||||
func QueryInt64Param(r *http.Request, param string, defaultValue int64) int64 {
|
|
||||||
value := r.URL.Query().Get(param)
|
|
||||||
if value == "" {
|
|
||||||
return defaultValue
|
|
||||||
}
|
|
||||||
|
|
||||||
val, err := strconv.ParseInt(value, 10, 64)
|
|
||||||
if err != nil {
|
|
||||||
return defaultValue
|
|
||||||
}
|
|
||||||
|
|
||||||
if val < 0 {
|
|
||||||
return defaultValue
|
|
||||||
}
|
|
||||||
|
|
||||||
return val
|
|
||||||
}
|
|
||||||
|
|
||||||
// HasQueryParam checks if the query string contains the given parameter.
|
|
||||||
func HasQueryParam(r *http.Request, param string) bool {
|
|
||||||
values := r.URL.Query()
|
|
||||||
_, ok := values[param]
|
|
||||||
return ok
|
|
||||||
}
|
|
||||||
|
|
||||||
// FindClientIP returns client's real IP address.
|
|
||||||
func FindClientIP(r *http.Request) string {
|
|
||||||
headers := []string{"X-Forwarded-For", "X-Real-Ip"}
|
|
||||||
for _, header := range headers {
|
|
||||||
value := r.Header.Get(header)
|
|
||||||
|
|
||||||
if value != "" {
|
|
||||||
addresses := strings.Split(value, ",")
|
|
||||||
address := strings.TrimSpace(addresses[0])
|
|
||||||
|
|
||||||
if net.ParseIP(address) != nil {
|
|
||||||
return address
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Fallback to TCP/IP source IP address.
|
|
||||||
var remoteIP string
|
|
||||||
if strings.ContainsRune(r.RemoteAddr, ':') {
|
|
||||||
remoteIP, _, _ = net.SplitHostPort(r.RemoteAddr)
|
|
||||||
} else {
|
|
||||||
remoteIP = r.RemoteAddr
|
|
||||||
}
|
|
||||||
|
|
||||||
return remoteIP
|
|
||||||
}
|
|
|
@ -62,7 +62,7 @@ func (m *Middleware) AppSession(next http.Handler) http.Handler {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Middleware) getAppSessionValueFromCookie(r *http.Request) *model.Session {
|
func (m *Middleware) getAppSessionValueFromCookie(r *http.Request) *model.Session {
|
||||||
cookieValue := request.Cookie(r, cookie.CookieSessionID)
|
cookieValue := request.CookieValue(r, cookie.CookieSessionID)
|
||||||
if cookieValue == "" {
|
if cookieValue == "" {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -62,7 +62,7 @@ func (m *Middleware) isPublicRoute(r *http.Request) bool {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Middleware) getUserSessionFromCookie(r *http.Request) *model.UserSession {
|
func (m *Middleware) getUserSessionFromCookie(r *http.Request) *model.UserSession {
|
||||||
cookieValue := request.Cookie(r, cookie.CookieUserSessionID)
|
cookieValue := request.CookieValue(r, cookie.CookieUserSessionID)
|
||||||
if cookieValue == "" {
|
if cookieValue == "" {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -25,12 +25,7 @@ func (c *Controller) EditCategory(w http.ResponseWriter, r *http.Request) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
categoryID, err := request.IntParam(r, "categoryID")
|
categoryID := request.RouteInt64Param(r, "categoryID")
|
||||||
if err != nil {
|
|
||||||
html.BadRequest(w, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
category, err := c.store.Category(request.UserID(r), categoryID)
|
category, err := c.store.Category(request.UserID(r), categoryID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
html.ServerError(w, err)
|
html.ServerError(w, err)
|
||||||
|
|
|
@ -23,12 +23,7 @@ func (c *Controller) CategoryEntries(w http.ResponseWriter, r *http.Request) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
categoryID, err := request.IntParam(r, "categoryID")
|
categoryID := request.RouteInt64Param(r, "categoryID")
|
||||||
if err != nil {
|
|
||||||
html.BadRequest(w, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
category, err := c.store.Category(request.UserID(r), categoryID)
|
category, err := c.store.Category(request.UserID(r), categoryID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
html.ServerError(w, err)
|
html.ServerError(w, err)
|
||||||
|
|
|
@ -21,12 +21,7 @@ func (c *Controller) RemoveCategory(w http.ResponseWriter, r *http.Request) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
categoryID, err := request.IntParam(r, "categoryID")
|
categoryID := request.RouteInt64Param(r, "categoryID")
|
||||||
if err != nil {
|
|
||||||
html.BadRequest(w, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
category, err := c.store.Category(request.UserID(r), categoryID)
|
category, err := c.store.Category(request.UserID(r), categoryID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
html.ServerError(w, err)
|
html.ServerError(w, err)
|
||||||
|
|
|
@ -25,12 +25,7 @@ func (c *Controller) UpdateCategory(w http.ResponseWriter, r *http.Request) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
categoryID, err := request.IntParam(r, "categoryID")
|
categoryID := request.RouteInt64Param(r, "categoryID")
|
||||||
if err != nil {
|
|
||||||
html.BadRequest(w, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
category, err := c.store.Category(request.UserID(r), categoryID)
|
category, err := c.store.Category(request.UserID(r), categoryID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
html.ServerError(w, err)
|
html.ServerError(w, err)
|
||||||
|
|
|
@ -25,12 +25,7 @@ func (c *Controller) ShowStarredEntry(w http.ResponseWriter, r *http.Request) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
entryID, err := request.IntParam(r, "entryID")
|
entryID := request.RouteInt64Param(r, "entryID")
|
||||||
if err != nil {
|
|
||||||
html.BadRequest(w, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
builder := c.store.NewEntryQueryBuilder(user.ID)
|
builder := c.store.NewEntryQueryBuilder(user.ID)
|
||||||
builder.WithEntryID(entryID)
|
builder.WithEntryID(entryID)
|
||||||
builder.WithoutStatus(model.EntryStatusRemoved)
|
builder.WithoutStatus(model.EntryStatusRemoved)
|
||||||
|
|
|
@ -25,17 +25,8 @@ func (c *Controller) ShowCategoryEntry(w http.ResponseWriter, r *http.Request) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
categoryID, err := request.IntParam(r, "categoryID")
|
categoryID := request.RouteInt64Param(r, "categoryID")
|
||||||
if err != nil {
|
entryID := request.RouteInt64Param(r, "entryID")
|
||||||
html.BadRequest(w, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
entryID, err := request.IntParam(r, "entryID")
|
|
||||||
if err != nil {
|
|
||||||
html.BadRequest(w, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
builder := c.store.NewEntryQueryBuilder(user.ID)
|
builder := c.store.NewEntryQueryBuilder(user.ID)
|
||||||
builder.WithCategoryID(categoryID)
|
builder.WithCategoryID(categoryID)
|
||||||
|
|
|
@ -25,17 +25,8 @@ func (c *Controller) ShowFeedEntry(w http.ResponseWriter, r *http.Request) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
entryID, err := request.IntParam(r, "entryID")
|
entryID := request.RouteInt64Param(r, "entryID")
|
||||||
if err != nil {
|
feedID := request.RouteInt64Param(r, "feedID")
|
||||||
html.BadRequest(w, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
feedID, err := request.IntParam(r, "feedID")
|
|
||||||
if err != nil {
|
|
||||||
html.BadRequest(w, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
builder := c.store.NewEntryQueryBuilder(user.ID)
|
builder := c.store.NewEntryQueryBuilder(user.ID)
|
||||||
builder.WithFeedID(feedID)
|
builder.WithFeedID(feedID)
|
||||||
|
|
|
@ -24,12 +24,7 @@ func (c *Controller) ShowReadEntry(w http.ResponseWriter, r *http.Request) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
entryID, err := request.IntParam(r, "entryID")
|
entryID := request.RouteInt64Param(r, "entryID")
|
||||||
if err != nil {
|
|
||||||
html.BadRequest(w, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
builder := c.store.NewEntryQueryBuilder(user.ID)
|
builder := c.store.NewEntryQueryBuilder(user.ID)
|
||||||
builder.WithEntryID(entryID)
|
builder.WithEntryID(entryID)
|
||||||
builder.WithoutStatus(model.EntryStatusRemoved)
|
builder.WithoutStatus(model.EntryStatusRemoved)
|
||||||
|
|
|
@ -16,12 +16,7 @@ import (
|
||||||
|
|
||||||
// SaveEntry send the link to external services.
|
// SaveEntry send the link to external services.
|
||||||
func (c *Controller) SaveEntry(w http.ResponseWriter, r *http.Request) {
|
func (c *Controller) SaveEntry(w http.ResponseWriter, r *http.Request) {
|
||||||
entryID, err := request.IntParam(r, "entryID")
|
entryID := request.RouteInt64Param(r, "entryID")
|
||||||
if err != nil {
|
|
||||||
json.BadRequest(w, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
builder := c.store.NewEntryQueryBuilder(request.UserID(r))
|
builder := c.store.NewEntryQueryBuilder(request.UserID(r))
|
||||||
builder.WithEntryID(entryID)
|
builder.WithEntryID(entryID)
|
||||||
builder.WithoutStatus(model.EntryStatusRemoved)
|
builder.WithoutStatus(model.EntryStatusRemoved)
|
||||||
|
|
|
@ -17,12 +17,7 @@ import (
|
||||||
|
|
||||||
// FetchContent downloads the original HTML page and returns relevant contents.
|
// FetchContent downloads the original HTML page and returns relevant contents.
|
||||||
func (c *Controller) FetchContent(w http.ResponseWriter, r *http.Request) {
|
func (c *Controller) FetchContent(w http.ResponseWriter, r *http.Request) {
|
||||||
entryID, err := request.IntParam(r, "entryID")
|
entryID := request.RouteInt64Param(r, "entryID")
|
||||||
if err != nil {
|
|
||||||
json.BadRequest(w, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
builder := c.store.NewEntryQueryBuilder(request.UserID(r))
|
builder := c.store.NewEntryQueryBuilder(request.UserID(r))
|
||||||
builder.WithEntryID(entryID)
|
builder.WithEntryID(entryID)
|
||||||
builder.WithoutStatus(model.EntryStatusRemoved)
|
builder.WithoutStatus(model.EntryStatusRemoved)
|
||||||
|
|
|
@ -25,13 +25,8 @@ func (c *Controller) ShowSearchEntry(w http.ResponseWriter, r *http.Request) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
entryID, err := request.IntParam(r, "entryID")
|
entryID := request.RouteInt64Param(r, "entryID")
|
||||||
if err != nil {
|
searchQuery := request.QueryStringParam(r, "q", "")
|
||||||
html.BadRequest(w, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
searchQuery := request.QueryParam(r, "q", "")
|
|
||||||
builder := c.store.NewEntryQueryBuilder(user.ID)
|
builder := c.store.NewEntryQueryBuilder(user.ID)
|
||||||
builder.WithSearchQuery(searchQuery)
|
builder.WithSearchQuery(searchQuery)
|
||||||
builder.WithEntryID(entryID)
|
builder.WithEntryID(entryID)
|
||||||
|
|
|
@ -14,12 +14,7 @@ import (
|
||||||
|
|
||||||
// ToggleBookmark handles Ajax request to toggle bookmark value.
|
// ToggleBookmark handles Ajax request to toggle bookmark value.
|
||||||
func (c *Controller) ToggleBookmark(w http.ResponseWriter, r *http.Request) {
|
func (c *Controller) ToggleBookmark(w http.ResponseWriter, r *http.Request) {
|
||||||
entryID, err := request.IntParam(r, "entryID")
|
entryID := request.RouteInt64Param(r, "entryID")
|
||||||
if err != nil {
|
|
||||||
json.BadRequest(w, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := c.store.ToggleBookmark(request.UserID(r), entryID); err != nil {
|
if err := c.store.ToggleBookmark(request.UserID(r), entryID); err != nil {
|
||||||
logger.Error("[Controller:ToggleBookmark] %v", err)
|
logger.Error("[Controller:ToggleBookmark] %v", err)
|
||||||
json.ServerError(w, nil)
|
json.ServerError(w, nil)
|
||||||
|
|
|
@ -25,12 +25,7 @@ func (c *Controller) ShowUnreadEntry(w http.ResponseWriter, r *http.Request) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
entryID, err := request.IntParam(r, "entryID")
|
entryID := request.RouteInt64Param(r, "entryID")
|
||||||
if err != nil {
|
|
||||||
html.BadRequest(w, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
builder := c.store.NewEntryQueryBuilder(user.ID)
|
builder := c.store.NewEntryQueryBuilder(user.ID)
|
||||||
builder.WithEntryID(entryID)
|
builder.WithEntryID(entryID)
|
||||||
builder.WithoutStatus(model.EntryStatusRemoved)
|
builder.WithoutStatus(model.EntryStatusRemoved)
|
||||||
|
|
|
@ -23,12 +23,7 @@ func (c *Controller) EditFeed(w http.ResponseWriter, r *http.Request) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
feedID, err := request.IntParam(r, "feedID")
|
feedID := request.RouteInt64Param(r, "feedID")
|
||||||
if err != nil {
|
|
||||||
html.BadRequest(w, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
feed, err := c.store.FeedByID(user.ID, feedID)
|
feed, err := c.store.FeedByID(user.ID, feedID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
html.ServerError(w, err)
|
html.ServerError(w, err)
|
||||||
|
|
|
@ -23,12 +23,7 @@ func (c *Controller) ShowFeedEntries(w http.ResponseWriter, r *http.Request) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
feedID, err := request.IntParam(r, "feedID")
|
feedID := request.RouteInt64Param(r, "feedID")
|
||||||
if err != nil {
|
|
||||||
html.BadRequest(w, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
feed, err := c.store.FeedByID(user.ID, feedID)
|
feed, err := c.store.FeedByID(user.ID, feedID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
html.ServerError(w, err)
|
html.ServerError(w, err)
|
||||||
|
|
|
@ -15,12 +15,7 @@ import (
|
||||||
|
|
||||||
// ShowIcon shows the feed icon.
|
// ShowIcon shows the feed icon.
|
||||||
func (c *Controller) ShowIcon(w http.ResponseWriter, r *http.Request) {
|
func (c *Controller) ShowIcon(w http.ResponseWriter, r *http.Request) {
|
||||||
iconID, err := request.IntParam(r, "iconID")
|
iconID := request.RouteInt64Param(r, "iconID")
|
||||||
if err != nil {
|
|
||||||
html.BadRequest(w, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
icon, err := c.store.IconByID(iconID)
|
icon, err := c.store.IconByID(iconID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
html.ServerError(w, err)
|
html.ServerError(w, err)
|
||||||
|
|
|
@ -16,12 +16,7 @@ import (
|
||||||
|
|
||||||
// RefreshFeed refresh a subscription and redirect to the feed entries page.
|
// RefreshFeed refresh a subscription and redirect to the feed entries page.
|
||||||
func (c *Controller) RefreshFeed(w http.ResponseWriter, r *http.Request) {
|
func (c *Controller) RefreshFeed(w http.ResponseWriter, r *http.Request) {
|
||||||
feedID, err := request.IntParam(r, "feedID")
|
feedID := request.RouteInt64Param(r, "feedID")
|
||||||
if err != nil {
|
|
||||||
html.BadRequest(w, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := c.feedHandler.RefreshFeed(request.UserID(r), feedID); err != nil {
|
if err := c.feedHandler.RefreshFeed(request.UserID(r), feedID); err != nil {
|
||||||
logger.Error("[Controller:RefreshFeed] %v", err)
|
logger.Error("[Controller:RefreshFeed] %v", err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -15,12 +15,7 @@ import (
|
||||||
|
|
||||||
// RemoveFeed deletes a subscription from the database and redirect to the list of feeds page.
|
// RemoveFeed deletes a subscription from the database and redirect to the list of feeds page.
|
||||||
func (c *Controller) RemoveFeed(w http.ResponseWriter, r *http.Request) {
|
func (c *Controller) RemoveFeed(w http.ResponseWriter, r *http.Request) {
|
||||||
feedID, err := request.IntParam(r, "feedID")
|
feedID := request.RouteInt64Param(r, "feedID")
|
||||||
if err != nil {
|
|
||||||
html.ServerError(w, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := c.store.RemoveFeed(request.UserID(r), feedID); err != nil {
|
if err := c.store.RemoveFeed(request.UserID(r), feedID); err != nil {
|
||||||
html.ServerError(w, err)
|
html.ServerError(w, err)
|
||||||
return
|
return
|
||||||
|
|
|
@ -26,12 +26,7 @@ func (c *Controller) UpdateFeed(w http.ResponseWriter, r *http.Request) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
feedID, err := request.IntParam(r, "feedID")
|
feedID := request.RouteInt64Param(r, "feedID")
|
||||||
if err != nil {
|
|
||||||
html.BadRequest(w, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
feed, err := c.store.FeedByID(user.ID, feedID)
|
feed, err := c.store.FeedByID(user.ID, feedID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
html.ServerError(w, err)
|
html.ServerError(w, err)
|
||||||
|
|
|
@ -23,21 +23,21 @@ func (c *Controller) OAuth2Callback(w http.ResponseWriter, r *http.Request) {
|
||||||
printer := locale.NewPrinter(request.UserLanguage(r))
|
printer := locale.NewPrinter(request.UserLanguage(r))
|
||||||
sess := session.New(c.store, request.SessionID(r))
|
sess := session.New(c.store, request.SessionID(r))
|
||||||
|
|
||||||
provider := request.Param(r, "provider", "")
|
provider := request.RouteStringParam(r, "provider")
|
||||||
if provider == "" {
|
if provider == "" {
|
||||||
logger.Error("[OAuth2] Invalid or missing provider")
|
logger.Error("[OAuth2] Invalid or missing provider")
|
||||||
response.Redirect(w, r, route.Path(c.router, "login"))
|
response.Redirect(w, r, route.Path(c.router, "login"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
code := request.QueryParam(r, "code", "")
|
code := request.QueryStringParam(r, "code", "")
|
||||||
if code == "" {
|
if code == "" {
|
||||||
logger.Error("[OAuth2] No code received on callback")
|
logger.Error("[OAuth2] No code received on callback")
|
||||||
response.Redirect(w, r, route.Path(c.router, "login"))
|
response.Redirect(w, r, route.Path(c.router, "login"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
state := request.QueryParam(r, "state", "")
|
state := request.QueryStringParam(r, "state", "")
|
||||||
if state == "" || state != request.OAuth2State(r) {
|
if state == "" || state != request.OAuth2State(r) {
|
||||||
logger.Error(`[OAuth2] Invalid state value: got "%s" instead of "%s"`, state, request.OAuth2State(r))
|
logger.Error(`[OAuth2] Invalid state value: got "%s" instead of "%s"`, state, request.OAuth2State(r))
|
||||||
response.Redirect(w, r, route.Path(c.router, "login"))
|
response.Redirect(w, r, route.Path(c.router, "login"))
|
||||||
|
|
|
@ -18,7 +18,7 @@ import (
|
||||||
func (c *Controller) OAuth2Redirect(w http.ResponseWriter, r *http.Request) {
|
func (c *Controller) OAuth2Redirect(w http.ResponseWriter, r *http.Request) {
|
||||||
sess := session.New(c.store, request.SessionID(r))
|
sess := session.New(c.store, request.SessionID(r))
|
||||||
|
|
||||||
provider := request.Param(r, "provider", "")
|
provider := request.RouteStringParam(r, "provider")
|
||||||
if provider == "" {
|
if provider == "" {
|
||||||
logger.Error("[OAuth2] Invalid or missing provider: %s", provider)
|
logger.Error("[OAuth2] Invalid or missing provider: %s", provider)
|
||||||
response.Redirect(w, r, route.Path(c.router, "login"))
|
response.Redirect(w, r, route.Path(c.router, "login"))
|
||||||
|
|
|
@ -19,7 +19,7 @@ import (
|
||||||
// OAuth2Unlink unlink an account from the external provider.
|
// OAuth2Unlink unlink an account from the external provider.
|
||||||
func (c *Controller) OAuth2Unlink(w http.ResponseWriter, r *http.Request) {
|
func (c *Controller) OAuth2Unlink(w http.ResponseWriter, r *http.Request) {
|
||||||
printer := locale.NewPrinter(request.UserLanguage(r))
|
printer := locale.NewPrinter(request.UserLanguage(r))
|
||||||
provider := request.Param(r, "provider", "")
|
provider := request.RouteStringParam(r, "provider")
|
||||||
if provider == "" {
|
if provider == "" {
|
||||||
logger.Info("[OAuth2] Invalid or missing provider")
|
logger.Info("[OAuth2] Invalid or missing provider")
|
||||||
response.Redirect(w, r, route.Path(c.router, "login"))
|
response.Redirect(w, r, route.Path(c.router, "login"))
|
||||||
|
|
|
@ -27,7 +27,7 @@ func (c *Controller) ImageProxy(w http.ResponseWriter, r *http.Request) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
encodedURL := request.Param(r, "encodedURL", "")
|
encodedURL := request.RouteStringParam(r, "encodedURL")
|
||||||
if encodedURL == "" {
|
if encodedURL == "" {
|
||||||
html.BadRequest(w, errors.New("No URL provided"))
|
html.BadRequest(w, errors.New("No URL provided"))
|
||||||
return
|
return
|
||||||
|
|
|
@ -23,7 +23,7 @@ func (c *Controller) ShowSearchEntries(w http.ResponseWriter, r *http.Request) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
searchQuery := request.QueryParam(r, "q", "")
|
searchQuery := request.QueryStringParam(r, "q", "")
|
||||||
offset := request.QueryIntParam(r, "offset", 0)
|
offset := request.QueryIntParam(r, "offset", 0)
|
||||||
builder := c.store.NewEntryQueryBuilder(user.ID)
|
builder := c.store.NewEntryQueryBuilder(user.ID)
|
||||||
builder.WithSearchQuery(searchQuery)
|
builder.WithSearchQuery(searchQuery)
|
||||||
|
|
|
@ -9,20 +9,14 @@ import (
|
||||||
|
|
||||||
"miniflux.app/http/request"
|
"miniflux.app/http/request"
|
||||||
"miniflux.app/http/response"
|
"miniflux.app/http/response"
|
||||||
"miniflux.app/http/response/html"
|
|
||||||
"miniflux.app/http/route"
|
"miniflux.app/http/route"
|
||||||
"miniflux.app/logger"
|
"miniflux.app/logger"
|
||||||
)
|
)
|
||||||
|
|
||||||
// RemoveSession remove a user session.
|
// RemoveSession remove a user session.
|
||||||
func (c *Controller) RemoveSession(w http.ResponseWriter, r *http.Request) {
|
func (c *Controller) RemoveSession(w http.ResponseWriter, r *http.Request) {
|
||||||
sessionID, err := request.IntParam(r, "sessionID")
|
sessionID := request.RouteInt64Param(r, "sessionID")
|
||||||
if err != nil {
|
err := c.store.RemoveUserSessionByID(request.UserID(r), sessionID)
|
||||||
html.BadRequest(w, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
err = c.store.RemoveUserSessionByID(request.UserID(r), sessionID)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("[Controller:RemoveSession] %v", err)
|
logger.Error("[Controller:RemoveSession] %v", err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -18,7 +18,7 @@ import (
|
||||||
|
|
||||||
// AppIcon renders application icons.
|
// AppIcon renders application icons.
|
||||||
func (c *Controller) AppIcon(w http.ResponseWriter, r *http.Request) {
|
func (c *Controller) AppIcon(w http.ResponseWriter, r *http.Request) {
|
||||||
filename := request.Param(r, "filename", "favicon.png")
|
filename := request.RouteStringParam(r, "filename")
|
||||||
encodedBlob, found := static.Binaries[filename]
|
encodedBlob, found := static.Binaries[filename]
|
||||||
if !found {
|
if !found {
|
||||||
logger.Info("[Controller:AppIcon] This icon doesn't exists: %s", filename)
|
logger.Info("[Controller:AppIcon] This icon doesn't exists: %s", filename)
|
||||||
|
|
|
@ -16,7 +16,7 @@ import (
|
||||||
|
|
||||||
// Javascript renders application client side code.
|
// Javascript renders application client side code.
|
||||||
func (c *Controller) Javascript(w http.ResponseWriter, r *http.Request) {
|
func (c *Controller) Javascript(w http.ResponseWriter, r *http.Request) {
|
||||||
filename := request.Param(r, "name", "app")
|
filename := request.RouteStringParam(r, "name")
|
||||||
if _, found := static.Javascripts[filename]; !found {
|
if _, found := static.Javascripts[filename]; !found {
|
||||||
html.NotFound(w)
|
html.NotFound(w)
|
||||||
return
|
return
|
||||||
|
|
|
@ -16,7 +16,7 @@ import (
|
||||||
|
|
||||||
// Stylesheet renders the CSS.
|
// Stylesheet renders the CSS.
|
||||||
func (c *Controller) Stylesheet(w http.ResponseWriter, r *http.Request) {
|
func (c *Controller) Stylesheet(w http.ResponseWriter, r *http.Request) {
|
||||||
stylesheet := request.Param(r, "name", "default")
|
stylesheet := request.RouteStringParam(r, "name")
|
||||||
if _, found := static.Stylesheets[stylesheet]; !found {
|
if _, found := static.Stylesheets[stylesheet]; !found {
|
||||||
html.NotFound(w)
|
html.NotFound(w)
|
||||||
return
|
return
|
||||||
|
|
|
@ -32,7 +32,7 @@ func (c *Controller) Bookmarklet(w http.ResponseWriter, r *http.Request) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
bookmarkletURL := request.QueryParam(r, "uri", "")
|
bookmarkletURL := request.QueryStringParam(r, "uri", "")
|
||||||
|
|
||||||
view.Set("form", form.SubscriptionForm{URL: bookmarkletURL})
|
view.Set("form", form.SubscriptionForm{URL: bookmarkletURL})
|
||||||
view.Set("categories", categories)
|
view.Set("categories", categories)
|
||||||
|
|
|
@ -30,12 +30,7 @@ func (c *Controller) EditUser(w http.ResponseWriter, r *http.Request) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
userID, err := request.IntParam(r, "userID")
|
userID := request.RouteInt64Param(r, "userID")
|
||||||
if err != nil {
|
|
||||||
html.BadRequest(w, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
selectedUser, err := c.store.UserByID(userID)
|
selectedUser, err := c.store.UserByID(userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
html.ServerError(w, err)
|
html.ServerError(w, err)
|
||||||
|
|
|
@ -26,12 +26,7 @@ func (c *Controller) RemoveUser(w http.ResponseWriter, r *http.Request) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
userID, err := request.IntParam(r, "userID")
|
userID := request.RouteInt64Param(r, "userID")
|
||||||
if err != nil {
|
|
||||||
html.BadRequest(w, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
selectedUser, err := c.store.UserByID(userID)
|
selectedUser, err := c.store.UserByID(userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
html.ServerError(w, err)
|
html.ServerError(w, err)
|
||||||
|
|
|
@ -30,12 +30,7 @@ func (c *Controller) UpdateUser(w http.ResponseWriter, r *http.Request) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
userID, err := request.IntParam(r, "userID")
|
userID := request.RouteInt64Param(r, "userID")
|
||||||
if err != nil {
|
|
||||||
html.BadRequest(w, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
selectedUser, err := c.store.UserByID(userID)
|
selectedUser, err := c.store.UserByID(userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
html.ServerError(w, err)
|
html.ServerError(w, err)
|
||||||
|
|
Loading…
Reference in a new issue