diff --git a/handlers.go b/handlers.go index c53bb5a..2d91a59 100644 --- a/handlers.go +++ b/handlers.go @@ -97,13 +97,13 @@ func registerHandlers(e *echo.Echo) { e.DELETE("/api/templates/:id", handleDeleteTemplate) // Subscriber facing views. - e.GET("/subscription/:campUUID/:subUUID", validateUUID(handleSubscriptionPage, + e.GET("/subscription/:campUUID/:subUUID", validateUUID(subscriberExists(handleSubscriptionPage), "campUUID", "subUUID")) - e.POST("/subscription/:campUUID/:subUUID", validateUUID(handleSubscriptionPage, + e.POST("/subscription/:campUUID/:subUUID", validateUUID(subscriberExists(handleSubscriptionPage), "campUUID", "subUUID")) - e.POST("/subscription/export/:subUUID", validateUUID(handleSelfExportSubscriberData, + e.POST("/subscription/export/:subUUID", validateUUID(subscriberExists(handleSelfExportSubscriberData), "subUUID")) - e.POST("/subscription/wipe/:subUUID", validateUUID(handleWipeSubscriberData, + e.POST("/subscription/wipe/:subUUID", validateUUID(subscriberExists(handleWipeSubscriberData), "subUUID")) e.GET("/link/:linkUUID/:campUUID/:subUUID", validateUUID(handleLinkRedirect, "linkUUID", "campUUID", "subUUID")) @@ -136,7 +136,7 @@ func handleIndexPage(c echo.Context) error { return c.String(http.StatusOK, string(b)) } -// validateUUID validates the UUID string format for a given set of params. +// validateUUID middleware validates the UUID string format for a given set of params. func validateUUID(next echo.HandlerFunc, params ...string) echo.HandlerFunc { return func(c echo.Context) error { for _, p := range params { @@ -150,6 +150,32 @@ func validateUUID(next echo.HandlerFunc, params ...string) echo.HandlerFunc { } } +// subscriberExists middleware checks if a subscriber exists given the UUID +// param in a request. +func subscriberExists(next echo.HandlerFunc, params ...string) echo.HandlerFunc { + return func(c echo.Context) error { + var ( + app = c.Get("app").(*App) + subUUID = c.Param("subUUID") + ) + + var exists bool + if err := app.Queries.SubscriberExists.Get(&exists, 0, subUUID); err != nil { + app.Logger.Printf("error checking subscriber existence: %v", err) + return c.Render(http.StatusInternalServerError, "message", + makeMsgTpl("Error", "", + `Error processing request. Please retry.`)) + } + + if !exists { + return c.Render(http.StatusBadRequest, "message", + makeMsgTpl("Not found", "", + `Subscription not found.`)) + } + return next(c) + } +} + // getPagination takes form values and extracts pagination values from it. func getPagination(q url.Values) pagination { var (