package main

import (
	"bytes"
	"crypto/hmac"
	"crypto/rand"
	"crypto/sha256"
	"database/sql"
	_ "embed"
	"encoding/base64"
	"encoding/hex"
	"errors"
	"fmt"
	"html/template"
	"net/http"
	"os"
	"path/filepath"
	"strings"
	"time"

	"github.com/labstack/echo/v4"
	"github.com/labstack/echo/v4/middleware"
	_ "github.com/mattn/go-sqlite3"
)

// Process-wide state, populated during startup before the server begins
// accepting requests.
var (
	flag   string  // contents of flag.txt
	secret []byte  // random per-process key used when signing session cookies
	db     *sql.DB // handle to the SQLite users database
)

// Template file paths, relative to the working directory. baseTmpl is the
// shared layout combined with each page template when rendering.
var (
	baseTmpl     = "templates/base.html"
	indexTmpl    = "templates/index.html"
	loginTmpl    = "templates/login.html"
	registerTmpl = "templates/register.html"
)

// dsn is the SQLite data source name passed to sql.Open.
var dsn = "file:file.db"

// createTableQuery creates the users table. login is UNIQUE so the database
// itself rejects a second account with the same login, independently of any
// application-level "already exists" check.
//
// NOTE(review): the password column holds the password as supplied (the
// login query compares it directly); consider storing a salted hash instead.
var createTableQuery = `
CREATE TABLE users (
	id INTEGER PRIMARY KEY AUTOINCREMENT,
	login TEXT NOT NULL UNIQUE,
	name TEXT,
	family TEXT,
	password TEXT NOT NULL
);`

// insertUserSQL adds one user row. Placeholder order is:
// login, password, name, family.
var insertUserSQL = "INSERT INTO users (login, password, name, family) VALUES (?, ?, ?, ?)"

var (
	// ErrInvalidSignature is returned when a session cookie's signature does
	// not match the data it carries.
	ErrInvalidSignature = errors.New("invalid signature")
)

// Character sets used to validate registration input: logins may contain
// only lowercase ASCII letters, while names and family names may also use
// uppercase ASCII letters.
const (
	lowerOnly     = "abcdefghijklmnopqrstuvwxyz"
	lowerAndUpper = lowerOnly + "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
)

// IndexRenderContext is the data handed to the index page template.
type IndexRenderContext struct {
	// Login is the authenticated user's login; empty for anonymous visitors.
	Login string
	// Flag is filled in only for visitors with a valid session cookie.
	Flag  string
}

// randString returns a string of the given length whose characters are drawn
// uniformly at random from [a-zA-Z0-9] using crypto/rand.
//
// Random bytes >= limit are rejected rather than reduced modulo the charset
// size. A plain `b % 62` would make the first 256%62 = 8 characters ('a'..'h')
// noticeably more likely than the rest and weaken generated secrets.
func randString(length int) (string, error) {
	const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
	// Largest multiple of len(charset) not exceeding 256 (248 here).
	const limit = 256 - 256%len(charset)

	out := make([]byte, 0, length)
	buf := make([]byte, length)
	for len(out) < length {
		if _, err := rand.Read(buf); err != nil {
			return "", err
		}
		for _, b := range buf {
			if int(b) >= limit {
				continue // rejected to keep the distribution uniform
			}
			out = append(out, charset[int(b)%len(charset)])
			if len(out) == length {
				break
			}
		}
	}

	return string(out), nil
}

// randBytes returns a freshly allocated slice of length bytes read from
// crypto/rand. On failure it returns a nil slice and the read error.
func randBytes(length int) ([]byte, error) {
	buf := make([]byte, length)
	if n, err := rand.Read(buf); err != nil || n != length {
		return nil, err
	}
	return buf, nil
}

// initDB opens the SQLite database named by dsn into the package-level db,
// creates the users table and inserts an "admin" user whose password is a
// fresh 32-character random string. That password is neither logged nor
// returned by this function.
//
// Every error is wrapped with %w and a description of the failed step, so
// callers can both read what went wrong and inspect the cause with errors.Is/As.
func initDB() error {
	var err error
	// Assign (not :=) so the package-level db is set.
	db, err = sql.Open("sqlite3", dsn)
	if err != nil {
		return fmt.Errorf("failed to open database: %w", err)
	}

	if _, err := db.Exec(createTableQuery); err != nil {
		return fmt.Errorf("failed to create users table: %w", err)
	}

	password, err := randString(32)
	if err != nil {
		return fmt.Errorf("failed to generate admin password: %w", err)
	}

	if _, err := db.Exec(insertUserSQL, "admin", password, "kekker", "kekkerovich"); err != nil {
		return fmt.Errorf("failed to execute insert query: %w", err)
	}

	return nil
}

// main runs the server and exits with status 1 if startup or serving fails.
func main() {
	if err := run(); err != nil {
		fmt.Fprintln(os.Stderr, err)
		os.Exit(1)
	}
}

// run initialises process-wide state (cookie signing secret, flag, database),
// registers the HTTP routes and serves on :1234 until the server stops.
//
// It returns errors instead of calling os.Exit or Logger.Fatal so that the
// deferred db.Close actually runs. Both of those terminate the process
// without running deferred calls.
func run() error {
	var err error
	// Assign (not :=) so the package-level secret is set.
	secret, err = randBytes(32)
	if err != nil {
		return fmt.Errorf("failed to read secret: %w", err)
	}

	flagData, err := os.ReadFile("flag.txt")
	if err != nil {
		return fmt.Errorf("failed to read flag: %w", err)
	}
	flag = string(flagData)

	if err := initDB(); err != nil {
		return fmt.Errorf("failed to init database: %w", err)
	}
	defer db.Close()

	e := echo.New()
	e.Use(middleware.Logger())
	e.Use(middleware.Recover())

	e.GET("/", index)
	e.GET("/login", login)
	e.POST("/login", loginPOST)
	e.GET("/register", register)
	e.POST("/register", registerPOST)
	e.GET("/logout", LogoutPage)
	e.GET("/sources", sources)

	return e.Start(":1234")
}

// render parses the shared base layout together with tmpl, executes it with
// data, and sends the result as a 200 "text/html; charset=UTF-8" response.
//
// Templates are re-parsed on every call. Parse errors are returned to the
// caller instead of panicking (template.Must), so a broken template file is
// reported through normal handler error handling. Output is buffered so a
// template that fails midway never sends a partial page.
func render(c echo.Context, tmpl string, data any) error {
	t, err := template.New(filepath.Base(baseTmpl)).ParseFiles(baseTmpl, tmpl)
	if err != nil {
		return fmt.Errorf("failed to parse templates %q: %w", tmpl, err)
	}

	var buf bytes.Buffer
	if err := t.Execute(&buf, data); err != nil {
		return fmt.Errorf("failed to execute template: %w", err)
	}

	return c.HTMLBlob(http.StatusOK, buf.Bytes())
}

// sessionMAC returns the HMAC-SHA256 tag that authenticates a session cookie
// built from the given hex-encoded fields, keyed with the package secret.
//
// The MAC input is loginHex ":" nameHex ":" familyHex. Hex digits never
// include ':', so distinct (login, name, family) triples always yield distinct
// inputs. HMAC is required here: a bare sha256(secret || data) is open to
// length extension. A client holding one valid signature could append
// padding plus ":admin" to the signed data and compute a valid signature for
// a different login without knowing the secret.
func sessionMAC(loginHex, nameHex, familyHex string) []byte {
	mac := hmac.New(sha256.New, secret)
	// hash.Hash.Write is documented never to return an error.
	mac.Write([]byte(loginHex + ":" + nameHex + ":" + familyHex))
	return mac.Sum(nil)
}

// getCheckedLogin verifies a session cookie issued by setLoginCookie and
// returns the login it was issued for.
//
// The expected value is base64url(hex(login) ":" hex(name) ":" hex(family)
// ":" hex(mac)). It returns ErrInvalidSignature when the MAC is malformed or
// does not match, and a descriptive error for any other malformed input. The
// MAC comparison is constant-time (hmac.Equal).
func getCheckedLogin(cookie *http.Cookie) (string, error) {
	if cookie == nil || cookie.Value == "" {
		return "", errors.New("empty cookie")
	}

	decoded, err := base64.URLEncoding.DecodeString(cookie.Value)
	if err != nil {
		return "", fmt.Errorf("base64 decode failed: %w", err)
	}

	parts := strings.Split(string(decoded), ":")
	if len(parts) != 4 {
		return "", fmt.Errorf("invalid number of parts: %d", len(parts))
	}
	loginHex, nameHex, familyHex, sigHex := parts[0], parts[1], parts[2], parts[3]

	// Reject fields that are not valid hex, even though only login is used.
	if _, err := hex.DecodeString(nameHex); err != nil {
		return "", fmt.Errorf("failed to decode name: %w", err)
	}
	if _, err := hex.DecodeString(familyHex); err != nil {
		return "", fmt.Errorf("failed to decode family: %w", err)
	}
	login, err := hex.DecodeString(loginHex)
	if err != nil {
		return "", fmt.Errorf("failed to decode login: %w", err)
	}

	sig, err := hex.DecodeString(sigHex)
	if err != nil {
		return "", ErrInvalidSignature
	}
	if !hmac.Equal(sig, sessionMAC(loginHex, nameHex, familyHex)) {
		return "", ErrInvalidSignature
	}

	return string(login), nil
}

// setLoginCookie issues the "session" cookie for login. All user info lives
// in the cookie itself ("stateless" session): the hex-encoded login, name and
// family name plus their MAC, joined with ':' and base64url-encoded. Hex
// encoding keeps user-supplied ':' characters from being confused with the
// field separator.
//
// The cookie lasts one hour, applies to the whole site, is hidden from
// JavaScript (HttpOnly), and is not sent on cross-site subrequests
// (SameSite=Lax).
func setLoginCookie(c echo.Context, login, name, family string) {
	loginHex := hex.EncodeToString([]byte(login))
	nameHex := hex.EncodeToString([]byte(name))
	familyHex := hex.EncodeToString([]byte(family))
	sigHex := hex.EncodeToString(sessionMAC(loginHex, nameHex, familyHex))

	value := loginHex + ":" + nameHex + ":" + familyHex + ":" + sigHex

	c.SetCookie(&http.Cookie{
		Name:     "session",
		Value:    base64.URLEncoding.EncodeToString([]byte(value)),
		Path:     "/",
		MaxAge:   3600,
		HttpOnly: true,
		SameSite: http.SameSiteLaxMode,
	})
}

// isLoggedIn reports whether the request carries a valid "session" cookie.
// A missing cookie yields (false, nil). A cookie that fails verification
// yields (false, err), where err comes from getCheckedLogin.
func isLoggedIn(c echo.Context) (bool, error) {
	cookie, cookieErr := c.Cookie("session")
	if cookieErr != nil {
		return false, nil
	}

	_, checkErr := getCheckedLogin(cookie)
	return checkErr == nil, checkErr
}

// register serves the registration form. Visitors who already hold a valid
// session are redirected to "/".
//
// A session cookie that fails verification is cleared, and the visitor is
// treated as anonymous instead of getting a dead-end 400. This includes a
// stale cookie signed with the random secret of a previous server process.
func register(c echo.Context) error {
	loggedIn, err := isLoggedIn(c)
	if err != nil {
		c.SetCookie(&http.Cookie{Name: "session", Path: "/", MaxAge: -1})
	}

	if err == nil && loggedIn {
		return c.Redirect(http.StatusFound, "/")
	}

	return render(c, registerTmpl, nil)
}

// loginPOST authenticates the "login" and "password" form fields against the
// users table.
//
// Empty fields redirect back to /login. Unknown credentials get a 401
// "Login failed!". On success a session cookie is issued and the client is
// redirected to "/". Unexpected database errors are returned to the caller.
//
// NOTE(review): the password is matched in plaintext by this query; consider
// storing and comparing a salted password hash instead.
func loginPOST(c echo.Context) error {
	loginParam := c.FormValue("login")
	passwordParam := c.FormValue("password")

	if loginParam == "" || passwordParam == "" {
		return c.Redirect(http.StatusFound, "/login")
	}

	// Name the columns explicitly so Scan does not depend on the table's
	// column order. QueryRow also handles closing the rows and reporting
	// iteration errors.
	var login, name, family string
	err := db.QueryRow(
		"SELECT login, name, family FROM users WHERE login = ? AND password = ? LIMIT 1",
		loginParam, passwordParam,
	).Scan(&login, &name, &family)
	switch {
	case errors.Is(err, sql.ErrNoRows):
		return c.String(http.StatusUnauthorized, "Login failed!")
	case err != nil:
		return fmt.Errorf("failed to look up user: %w", err)
	}

	setLoginCookie(c, login, name, family)
	return c.Redirect(http.StatusFound, "/")
}

// registerPOST creates an account from the registration form and logs the
// new user in.
//
// Validation order matches the form's expectations. An empty password or a
// blank login redirects back to /register. The login must consist of
// lowercase ASCII letters only, and name/family of ASCII letters only;
// violations get a 400. Mismatched passwords get a 400. An existing login
// gets a 400 "User already exists!". On success a session cookie is issued
// and the client is redirected to "/".
func registerPOST(c echo.Context) error {
	loginParam := c.FormValue("login")
	name := c.FormValue("name")
	family := c.FormValue("family")
	passwordParam := c.FormValue("password")
	confirmParam := c.FormValue("confirm-password")

	if passwordParam == "" {
		return c.Redirect(http.StatusFound, "/register")
	}

	l := strings.TrimSpace(loginParam)
	if l == "" {
		return c.Redirect(http.StatusFound, "/register")
	}

	// onlyFrom reports whether every rune of s appears in allowed.
	onlyFrom := func(s, allowed string) bool {
		for _, r := range s {
			if !strings.ContainsRune(allowed, r) {
				return false
			}
		}
		return true
	}

	if !onlyFrom(loginParam, lowerOnly) || !onlyFrom(name, lowerAndUpper) || !onlyFrom(family, lowerAndUpper) {
		return c.String(http.StatusBadRequest, "bad login, name or family name!")
	}

	if passwordParam != confirmParam {
		return c.String(http.StatusBadRequest, "passwords do not match!")
	}

	// Insert only when the login is free, in a single statement. A separate
	// SELECT followed by INSERT would let two concurrent registrations of the
	// same login both succeed.
	res, err := db.Exec(
		`INSERT INTO users (login, password, name, family)
		 SELECT ?, ?, ?, ? WHERE NOT EXISTS (SELECT 1 FROM users WHERE login = ?)`,
		l, passwordParam, name, family, l,
	)
	if err != nil {
		return fmt.Errorf("failed to insert user: %w", err)
	}

	inserted, err := res.RowsAffected()
	if err != nil {
		return fmt.Errorf("failed to read affected rows: %w", err)
	}
	if inserted == 0 {
		return c.String(http.StatusBadRequest, "User already exists!")
	}

	setLoginCookie(c, l, name, family)
	return c.Redirect(http.StatusFound, "/")
}

// login serves the login form. Visitors who already hold a valid session are
// redirected to "/".
//
// A session cookie that fails verification is cleared, and the visitor is
// treated as anonymous rather than being refused with a 400. Such a cookie
// could be stale, e.g. signed with the random secret of a previous server
// process. Refusing it here would lock the user out of the page they need
// to sign in again.
func login(c echo.Context) error {
	loggedIn, err := isLoggedIn(c)
	if err != nil {
		c.SetCookie(&http.Cookie{Name: "session", Path: "/", MaxAge: -1})
	}

	if err == nil && loggedIn {
		return c.Redirect(http.StatusFound, "/")
	}

	return render(c, loginTmpl, nil)
}

// LogoutPage deletes the session cookie and redirects to "/".
//
// The cookie is expired via both MaxAge (sent as Max-Age=0) and a past
// Expires, for clients that only understand one of them. Any error from
// writing the redirect is returned instead of being discarded.
func LogoutPage(c echo.Context) error {
	c.SetCookie(&http.Cookie{
		Name:     "session",
		Path:     "/",
		Expires:  time.Now().Add(-1 * time.Minute),
		MaxAge:   -1,
		HttpOnly: true,
	})

	return c.Redirect(http.StatusFound, "/")
}

// sources responds with the file sources.zip. The path is relative, so it
// is resolved against the process's working directory.
func sources(c echo.Context) error {
	const archive = "sources.zip"
	return c.File(archive)
}

// index renders the home page. Without a session cookie, or with one that
// fails verification (which is logged), the page is rendered for an
// anonymous visitor. With a valid cookie, the login and the flag are passed
// to the template.
//
// NOTE(review): Flag is populated for every authenticated user, not only
// admin; presumably the template decides whether to display it. Confirm.
func index(c echo.Context) error {
	page := &IndexRenderContext{}

	cookie, err := c.Cookie("session")
	if err != nil {
		return render(c, indexTmpl, page)
	}

	user, err := getCheckedLogin(cookie)
	if err != nil {
		c.Logger().Error("err", err)
		return render(c, indexTmpl, page)
	}

	page.Login = user
	page.Flag = flag
	return render(c, indexTmpl, page)
}
