package main
import (
	"database/sql"
	"encoding/json"
	"errors"
	"flag"
	"fmt"
	"io"
	"log/slog"
	"net/http"
	"os"
	"os/exec"
	"strconv"
	"strings"
	"time"
	_ "github.com/mattn/go-sqlite3"
)
// errorsWrap prefixes err with msg, producing "msg: <err>", while keeping err
// reachable through errors.Is / errors.As. A nil err stays nil, so callers can
// wrap unconditionally.
func errorsWrap(err error, msg string) error {
	if err != nil {
		return fmt.Errorf("%s: %w", msg, err)
	}
	return nil
}
// shouldWrite reports whether a freshly computed answer (new) should be stored,
// given the previously stored answer (old, nil when there is none).
//
// Record "always" wins outright. Record "never", or a MaxHistory of exactly 0,
// disables storage. Any other Record value stores the answer only when it differs
// from old. A nil receiver never writes.
func (q *Question) shouldWrite(old *string, new string) bool {
	switch {
	case q == nil:
		return false
	case q.Record == "always":
		return true
	case q.Record == "never", q.MaxHistory != nil && *q.MaxHistory == 0:
		return false
	}
	// Change detection: a first answer always counts as a change.
	return old == nil || *old != new
}
// shouldRespond reports whether q.Response should be run for the answer new,
// given the previously stored answer old (nil when none).
//
// Nothing runs without a Response command, or when ResponseWhen is empty or
// "never". "always" responds every time, and "on_change" responds when new
// differs from old. Any other ResponseWhen value is a literal answer: the
// response fires only when new equals it exactly.
func (q *Question) shouldRespond(old *string, new string) bool {
	if q == nil || q.Response == "" {
		return false
	}
	switch q.ResponseWhen {
	case "", responseNever:
		return false
	case responseAlways:
		return true
	case responseOnChange:
		return old == nil || *old != new
	default:
		return q.ResponseWhen == new
	}
}
const (
	// Accepted values of Question.Record.
	recordAlways   = "always"
	recordOnChange = "on_change"
	recordNever    = "never"

	// Keywords for Question.ResponseWhen. Any other non-empty value is
	// compared literally against the new answer by shouldRespond.
	responseAlways   = "always"
	responseOnChange = "on_change"
	responseNever    = "never"
)
// Validate checks that q is usable. It requires:
//   - a non-empty Name;
//   - a Question command that is not blank (a whitespace-only command has
//     nothing to execute);
//   - a Record that is empty or one of always/on_change/never;
//   - a Response command whenever ResponseWhen is set.
//
// An empty Record is accepted because shouldWrite already treats it exactly
// like on_change.
func (q *Question) Validate() error {
	if q.Name == "" {
		return errors.New("question name must not be empty")
	}
	if strings.TrimSpace(q.Question) == "" {
		return fmt.Errorf("question %q query must not be empty", q.Name)
	}
	switch q.Record {
	case "", recordAlways, recordOnChange, recordNever:
		// valid
	default:
		return fmt.Errorf("question %q record type must be one of {%s,%s,%s}", q.Name, recordAlways, recordOnChange, recordNever)
	}
	if q.ResponseWhen != "" && q.Response == "" {
		return fmt.Errorf("question %q must have a response command if given a response when (was %q)", q.Name, q.ResponseWhen)
	}
	return nil
}
// Question describes one command that is polled, how its answers are stored,
// and an optional command to run in response to a new answer. Values are
// decoded from the JSON questions file.
type Question struct {
	// Unique identifier of the question, e.g. "ip_addr".
	Name string `json:"name"`
	// Interval between timer-driven runs, in seconds.
	Frequency uint `json:"frequency"`
	// Wait before the first timer-driven run, in seconds.
	Delay uint `json:"delay"`
	// Command line to run (executed directly, not through a shell); its
	// trimmed stdout is the answer.
	Question string `json:"question"`
	// One of always|on_change|never; shouldWrite handles an empty value the
	// same as on_change.
	Record string `json:"record"`
	// Max number of answers to keep to a question
	// negative means no limit
	// 0 means do not enter answers in the db
	// positive, each write will also clean up extra entries
	// If null, defaults to -1 (no limit)
	MaxHistory *int `json:"max_history"`
	// Command to run in response to a new answer.
	// It is invoked with the question name, unix timestamp, old value
	// (empty if none) and new value appended as extra arguments.
	Response string `json:"response"`
	// Which answers trigger Response: always|on_change|never, or any other
	// string, which triggers only when the new answer equals it exactly.
	// Empty means never.
	ResponseWhen string `json:"response_when"`
}
// Config holds the command-line settings populated by getConfig.
type Config struct {
	DBName    string // SQLite database path (-dbname, default ":memory:")
	Questions string // path of the JSON questions file (-questions)
	Port      uint   // HTTP listen port (-port, default 8080)
	SkipTimer bool   // do not run questions on their timers (-skip-timer)
	EnableRun bool   // expose /run?name=<question_name> (-enable-run)
}
// config caches the parsed command-line configuration; see getConfig.
var config *Config

// getConfig registers the command-line flags, parses os.Args, and returns the
// resulting configuration. The first call does the work. Later calls return the
// cached value without registering or parsing flags again.
func getConfig() *Config {
	if config == nil {
		cfg := new(Config)
		flag.StringVar(&cfg.DBName, "dbname", ":memory:", "file db")
		flag.StringVar(&cfg.Questions, "questions", "", "File path to questions")
		flag.UintVar(&cfg.Port, "port", 8080, "Application port to listen on")
		flag.BoolVar(&cfg.SkipTimer, "skip-timer", false, "Whether to skip running jobs on the configured timer")
		flag.BoolVar(&cfg.EnableRun, "enable-run", false, "Whether jobs can be manually run by hitting /run?name=<question_name>")
		config = cfg
		flag.Parse()
	}
	return config
}
// dbV1 is the version-1 schema script. migrate runs it as a single
// multi-statement Exec. The CREATE TABLE statements are idempotent
// (IF NOT EXISTS), but the trailing INSERT adds a db_version row every time the
// script runs.
//
// NOTE(review): answer.timestamp defaults to SQLite's CURRENT_TIMESTAMP
// ("YYYY-MM-DD HH:MM:SS", whole seconds) and is UNIQUE per question. Two
// answers to the same question within one second will therefore fail to
// insert. Confirm this is intended before relying on sub-second frequencies
// or concurrent /run calls.
const dbV1 = `
CREATE TABLE IF NOT EXISTS question (
	id     INTEGER PRIMARY KEY AUTOINCREMENT,
	name   VARCHAR(64) NOT NULL,
	UNIQUE (name)
);
CREATE TABLE IF NOT EXISTS answer (
	id          INTEGER PRIMARY KEY AUTOINCREMENT,
	question_id INTEGER REFERENCES question NOT NULL,
	timestamp   TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL,
	answer      VARCHAR(1024) NOT NULL,
	UNIQUE(question_id, timestamp)
);
CREATE TABLE IF NOT EXISTS db_version (
	version int	 NOT NULL,
	dirty   bool NOT NULL
);
INSERT INTO db_version (version, dirty) VALUES (1,false);
`
// openAndMigrateDB opens the SQLite database named by config.DBName with
// foreign-key enforcement enabled. It then checks connectivity and runs
// migrate. On failure the handle is closed and the error is returned for the
// caller to report; this function never exits the process itself.
func openAndMigrateDB(log *slog.Logger, config *Config) (*sql.DB, error) {
	// mattn/go-sqlite3 turns on foreign keys via the _foreign_keys DSN option.
	// The "_pragma=foreign_keys(1)" form belongs to modernc.org/sqlite and is
	// silently ignored by this driver.
	sep := "?"
	if strings.Contains(config.DBName, "?") {
		sep = "&"
	}
	db, err := sql.Open("sqlite3", config.DBName+sep+"_foreign_keys=1")
	if err != nil {
		return nil, fmt.Errorf("could not open db: %w", err)
	}
	if config.DBName == ":memory:" {
		// Every new connection to ":memory:" gets its own empty database. Pin
		// the pool to one long-lived connection so the timer goroutines and
		// HTTP handlers all see the same schema and data.
		db.SetMaxOpenConns(1)
	}
	if err := db.Ping(); err != nil {
		db.Close()
		return nil, fmt.Errorf("could not contact db: %w", err)
	}
	if err := migrate(db); err != nil {
		db.Close()
		return nil, err
	}
	log.Info("db ready", "dbname", config.DBName)
	return db, nil
}
// migrate ensures db has the version-1 schema from dbV1, that the recorded
// schema version is 1, and that the schema is not marked dirty.
//
// If the db_version row cannot be read (e.g. the table does not exist yet on a
// fresh database), dbV1 is applied and the version is read again.
func migrate(db *sql.DB) error {
	const versionQuery = "SELECT version,dirty FROM db_version LIMIT 1;"
	readVersion := func() (version int, dirty bool, err error) {
		// Both destinations must be pointers. Passing dirty by value made
		// Scan fail on every call, so dbV1 (and its INSERT) re-ran each start.
		err = db.QueryRow(versionQuery).Scan(&version, &dirty)
		return version, dirty, err
	}

	dbVersion, dirty, err := readVersion()
	if err != nil {
		if _, err := db.Exec(dbV1); err != nil {
			return fmt.Errorf("applying schema v1: %w", err)
		}
		if dbVersion, dirty, err = readVersion(); err != nil {
			return fmt.Errorf("reading db version: %w", err)
		}
	}
	if dbVersion != 1 {
		return fmt.Errorf("Unexpected db version %d", dbVersion)
	}
	if dirty {
		return fmt.Errorf("DB schema dirty, please check and clear manually")
	}
	return nil
}
// runResponse executes q.Response with the answer context appended as extra
// arguments:
//
//	<command> [command args...] <q.Name> <unix timestamp> <previous answer, "" if none> <new answer>
//
// The command is run directly (no shell) and its stdout/stderr are discarded.
// Failures are only logged, because the method has no error result.
func (q *Question) runResponse(log *slog.Logger, lastAnswer *string, answer string, timestamp int64) {
	// strings.Fields, unlike Split(" "), yields neither empty arguments for
	// repeated/leading/trailing spaces nor an empty command name.
	fields := strings.Fields(q.Response)
	if len(fields) == 0 {
		log.Error("error responding", "question", q.Name, "err", "empty response command")
		return
	}
	// FormatInt keeps the full int64 timestamp even where int is 32 bits.
	args := append(fields[1:], q.Name, strconv.FormatInt(timestamp, 10), denull(lastAnswer), answer)
	cmd := exec.Command(fields[0], args...)
	cmd.Stdout = io.Discard
	cmd.Stderr = io.Discard
	if err := cmd.Run(); err != nil {
		log.Error("error responding", "question", q.Name, "err", err)
	}
}
// maxAnswerLen is the number of bytes of an answer that are persisted,
// matching the VARCHAR(1024) width declared for answer.answer.
const maxAnswerLen = 1024

// runQuestion executes question's command and takes its trimmed stdout as the
// answer. It stores the answer as directed by question.Record and
// question.MaxHistory. When question.shouldRespond allows it, question.Response
// is started in a separate goroutine. The full trimmed stdout is returned.
// Errors are returned, not logged; the caller decides how to report them.
func runQuestion(log *slog.Logger, question *Question, db *sql.DB) (string, error) {
	answer, err := runCommand(question.Question)
	if err != nil {
		return "", errorsWrap(err, fmt.Sprintf("error running command for question %q", question.Name))
	}
	answerTime := time.Now().Unix()
	// Only the truncated form is stored. Compare against that form, so a long
	// but unchanged answer is not seen as "changed" on every run.
	stored := truncateAnswer(answer, maxAnswerLen)

	qID, err := getQuestionId(db, question.Name)
	if err != nil {
		return "", errorsWrap(err, fmt.Sprintf("question %q does not exist in db", question.Name))
	}
	lastAnswer, err := latestAnswer(db, qID)
	if err != nil {
		return "", err
	}
	if question.shouldWrite(lastAnswer, stored) {
		if err := saveAnswer(log, db, question, qID, stored); err != nil {
			return "", err
		}
	}
	if question.shouldRespond(lastAnswer, stored) {
		go question.runResponse(log, lastAnswer, answer, answerTime)
	}
	return answer, nil
}

// runCommand runs a whitespace-separated command line directly (no shell) and
// returns its stdout with surrounding whitespace trimmed. On failure the error
// includes whatever the command wrote to stderr.
func runCommand(cmdline string) (string, error) {
	fields := strings.Fields(cmdline)
	if len(fields) == 0 {
		return "", errors.New("empty command")
	}
	cmd := exec.Command(fields[0], fields[1:]...)
	var stdout, stderr strings.Builder
	cmd.Stdout = &stdout
	cmd.Stderr = &stderr
	if err := cmd.Run(); err != nil {
		if msg := strings.TrimSpace(stderr.String()); msg != "" {
			return "", fmt.Errorf("%w (stderr: %q)", err, msg)
		}
		return "", err
	}
	return strings.TrimSpace(stdout.String()), nil
}

// truncateAnswer shortens s to at most limit bytes without cutting a UTF-8
// encoded rune in half. Bytes of the form 10xxxxxx are continuation bytes, so
// the cut point backs up until it lands on the start of a rune.
func truncateAnswer(s string, limit int) string {
	if len(s) <= limit {
		return s
	}
	cut := limit
	for cut > 0 && s[cut]&0xC0 == 0x80 {
		cut--
	}
	return s[:cut]
}

// latestAnswer returns the most recently stored answer for question id qID,
// or nil when none has been recorded yet.
func latestAnswer(db *sql.DB, qID int) (*string, error) {
	var last string
	err := db.QueryRow(
		`SELECT answer FROM answer WHERE question_id = ? ORDER BY timestamp DESC, id DESC LIMIT 1;`,
		qID,
	).Scan(&last)
	if errors.Is(err, sql.ErrNoRows) {
		return nil, nil
	}
	if err != nil {
		return nil, errorsWrap(err, "could not find last answer")
	}
	return &last, nil
}

// saveAnswer inserts answer for question id qID in one transaction. When
// question.MaxHistory is positive, the same transaction also deletes all but
// the newest MaxHistory answers for that question.
func saveAnswer(log *slog.Logger, db *sql.DB, question *Question, qID int, answer string) error {
	tx, err := db.Begin()
	if err != nil {
		return errorsWrap(err, "could not insert answer")
	}
	// Rollback after a successful Commit is a harmless no-op (ErrTxDone).
	defer tx.Rollback()

	if _, err := tx.Exec(`INSERT INTO answer (question_id, answer) VALUES (?, ?);`, qID, answer); err != nil {
		return errorsWrap(err, "could not insert answer")
	}
	if question.MaxHistory != nil && *question.MaxHistory > 0 {
		// Placeholder order: outer question_id, inner question_id, LIMIT.
		result, err := tx.Exec(`DELETE FROM answer
	WHERE question_id = ?
	AND id NOT IN (
	    SELECT id FROM answer
	    WHERE question_id = ?
	    ORDER BY timestamp DESC, id DESC
	    LIMIT ?
	);`,
			qID, qID, *question.MaxHistory,
		)
		if err != nil {
			return errorsWrap(err, fmt.Sprintf("could not delete old answers to %q", question.Name))
		}
		if numDeleted, err := result.RowsAffected(); err == nil {
			log.Info("rows deleted", "question", question.Name, "num", numDeleted)
		}
	}
	if err := tx.Commit(); err != nil {
		return errorsWrap(err, fmt.Sprintf("could not complete save to db for question %q", question.Name))
	}
	return nil
}
// loadQuestions reads the JSON array of questions from config.Questions and
// validates each entry. It returns the questions together with a map from
// question name to its index in the returned slice. Names must be unique.
func loadQuestions(config *Config) ([]Question, map[string]int, error) {
	if config.Questions == "" {
		return nil, nil, errors.New("no questions file given (use -questions)")
	}
	// os.ReadFile opens and closes the file; the previous Open+ReadAll never
	// closed its handle.
	questionsBytes, err := os.ReadFile(config.Questions)
	if err != nil {
		return nil, nil, err
	}
	var questions []Question
	if err := json.Unmarshal(questionsBytes, &questions); err != nil {
		return nil, nil, fmt.Errorf("parsing questions file %q: %w", config.Questions, err)
	}
	questionsMap := make(map[string]int, len(questions))
	for i := range questions {
		q := &questions[i]
		if err := q.Validate(); err != nil {
			return nil, nil, fmt.Errorf("question at index %d: %w", i, err)
		}
		if first, seen := questionsMap[q.Name]; seen {
			return nil, nil, fmt.Errorf("question %q at index %d already seen at index %d", q.Name, i, first)
		}
		questionsMap[q.Name] = i
	}
	return questions, questionsMap, nil
}
// getQuestionId looks up the primary key of the question row named
// questionName. If no such row exists, the error is sql.ErrNoRows (from Scan)
// and the id is 0.
func getQuestionId(db *sql.DB, questionName string) (id int, err error) {
	err = db.QueryRow("SELECT id FROM question WHERE question.name=? LIMIT 1;", questionName).Scan(&id)
	return id, err
}
// denull dereferences t, substituting the zero value of T when t is nil.
func denull[T any](t *T) (v T) {
	if t != nil {
		v = *t
	}
	return v
}
// toPlaceholders returns pattern repeated length times, joined with commas
// (e.g. 3, "?" -> "?,?,?"), for building SQL IN/VALUES lists. A length of 0
// yields "".
func toPlaceholders(length int, pattern string) string {
	if length != 0 {
		return strings.Repeat(pattern+",", length-1) + pattern
	}
	return ""
}
// bootstrapQuestions makes sure every configured question has a row in the
// question table, inserting the missing ones in a single statement.
//
// Names come from questions, filtered to those present in questionsMap, and
// are walked in slice order. New rows therefore get ids in configuration order
// rather than random map-iteration order.
func bootstrapQuestions(log *slog.Logger, db *sql.DB, questions []Question, questionsMap map[string]int) error {
	names := make([]string, 0, len(questionsMap))
	for _, q := range questions {
		if _, ok := questionsMap[q.Name]; ok {
			names = append(names, q.Name)
		}
	}
	if len(names) == 0 {
		return nil
	}
	existing, err := existingQuestionNames(db, names)
	if err != nil {
		return err
	}
	var toAdd []any
	for _, name := range names {
		if _, ok := existing[name]; !ok {
			toAdd = append(toAdd, name)
		}
	}
	if len(toAdd) == 0 {
		return nil
	}
	query := fmt.Sprintf(`INSERT INTO question (name) VALUES %s;`, toPlaceholders(len(toAdd), "(?)"))
	result, err := db.Exec(query, toAdd...)
	if err != nil {
		return err
	}
	numRows, err := result.RowsAffected()
	if err != nil {
		return err
	}
	if numRows > 0 {
		log.Info("added questions", "number_added", numRows)
	}
	return nil
}

// existingQuestionNames returns the subset of names that already have a row in
// the question table. The result set is closed before returning, so the
// connection is free for the caller's next statement.
func existingQuestionNames(db *sql.DB, names []string) (map[string]struct{}, error) {
	args := make([]any, len(names))
	for i, name := range names {
		args[i] = name
	}
	rows, err := db.Query(fmt.Sprintf(`SELECT name FROM question WHERE name in (%s);`, toPlaceholders(len(args), "?")), args...)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	found := make(map[string]struct{}, len(names))
	for rows.Next() {
		var name string
		if err := rows.Scan(&name); err != nil {
			return nil, err
		}
		found[name] = struct{}{}
	}
	// Iteration errors are only reported after Next returns false.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return found, nil
}
// Header lines written at the start of each section of the /dump output.
var (
	hQuestion = []byte("id,name\n")
	hAnswer   = []byte("id,question.name,timestamp,answer\n")
	hVersion  = []byte("version,dirty\n")
)
// runQuestionsTimer starts one goroutine per question. Each goroutine waits
// question.Delay seconds, then runs the question repeatedly for the life of the
// process, sleeping question.Frequency seconds after each run finishes. Errors
// are logged and do not stop the loop.
//
// Questions with Frequency 0 are not scheduled, because sleeping 0s between
// runs would re-execute the command in a tight loop. They can still be run on
// demand through runQuestion (e.g. the /run endpoint when enabled).
func runQuestionsTimer(questions []Question, log *slog.Logger, db *sql.DB) {
	for i := range questions {
		question := questions[i] // per-goroutine copy
		if question.Frequency == 0 {
			log.Warn("question has no frequency; not scheduling it on the timer", "question", question.Name)
			continue
		}
		go func() {
			if question.Delay > 0 {
				time.Sleep(time.Duration(question.Delay) * time.Second)
			}
			interval := time.Duration(question.Frequency) * time.Second
			for {
				if _, err := runQuestion(log, &question, db); err != nil {
					log.Error("Error running question", "question", question.Name, "err", err)
				}
				time.Sleep(interval)
			}
		}()
	}
}
func main() {
	log := slog.New(slog.NewJSONHandler(os.Stderr, nil))
	config := getConfig()
	log.Info("config as read", "values", config)
	db, err := openAndMigrateDB(log, config)
	if err != nil {
		log.Error("Error opening or migrating db", "err", err)
		os.Exit(1)
	}
	questions, questionsMap, err := loadQuestions(config)
	if err != nil {
		log.Error("Error reading questions", "err", err)
		os.Exit(1)
	}
	err = bootstrapQuestions(log, db, questions, questionsMap)
	if err != nil {
		log.Error("Error bootstrapping questions", "err", err)
		os.Exit(1)
	}
	if !config.SkipTimer {
		go runQuestionsTimer(questions, log, db)
	}
	http.HandleFunc("/registeredquestions", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		questionNames := []string{}
		for _, q := range questions {
			questionNames = append(questionNames, q.Name)
		}
		bQuestionNames, err := json.Marshal(questionNames)
		if err != nil {
			w.WriteHeader(http.StatusInternalServerError)
			return
		}
		w.Write(bQuestionNames)
	}))
	http.HandleFunc("/dump", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		rows, err := db.Query("SELECT id,name FROM question;")
		if err != nil {
			log.Error("bad question select", "err", err)
			w.WriteHeader(http.StatusInternalServerError)
			return
		}
		w.Write(hQuestion)
		for rows.Next() {
			var id int
			var name string
			err = rows.Scan(&id, &name)
			if err != nil {
				log.Error("bad answer scan", "err", err)
				return
			}
			w.Write([]byte(fmt.Sprintf("%d,%s\n", id, name)))
		}
		rows, err = db.Query(`SELECT a.id,q.name,a.timestamp,a.answer FROM answer AS a
LEFT JOIN question AS q ON a.question_id = q.id
ORDER BY q.name ASC, a.timestamp ASC
;`)
		if err != nil {
			log.Error("bad answer select", "err", err)
			w.WriteHeader(http.StatusInternalServerError)
			return
		}
		w.Write(hAnswer)
		for rows.Next() {
			var id int
			var qName string
			var timestamp time.Time
			var answer string
			err = rows.Scan(&id, &qName, ×tamp, &answer)
			if err != nil {
				log.Error("bad answer scan", "err", err)
				return
			}
			w.Write([]byte(fmt.Sprintf("%d,%s,%v,%s\n", id, qName, timestamp, answer)))
		}
		rows, err = db.Query("SELECT version,dirty FROM db_version;")
		if err != nil {
			w.WriteHeader(http.StatusInternalServerError)
			return
		}
		w.Write(hVersion)
		for rows.Next() {
			var version int
			var dirty bool
			err = rows.Scan(&version, &dirty)
			if err != nil {
				log.Error("bad version scan", "err", err)
				return
			}
			w.Write([]byte(fmt.Sprintf("%d,%v\n", version, dirty)))
		}
	}))
	if config.EnableRun {
		http.HandleFunc("/run", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			q := r.URL.Query()
			qName := q.Get("name")
			idx, ok := questionsMap[qName]
			if !ok {
				log.Info("404", "question", qName)
				w.WriteHeader(http.StatusBadRequest)
				return
			}
			question := questions[idx]
			answer, err := runQuestion(log, &question, db)
			if err != nil {
				log.Error("failure to run question", "err", err)
				w.WriteHeader(http.StatusInternalServerError)
				return
			}
			w.Write([]byte(answer))
		}))
	}
	http.HandleFunc("/version", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		var version int
		row := db.QueryRow("SELECT version FROM db_version LIMIT 1;")
		err := row.Scan(&version)
		if err != nil {
			w.WriteHeader(http.StatusInternalServerError)
			return
		}
		w.Write([]byte(fmt.Sprintf("%d\n", version)))
	}))
	http.ListenAndServe(fmt.Sprintf(":%d", config.Port), nil)
}