package main
import (
"database/sql"
"encoding/json"
"errors"
"flag"
"fmt"
"io"
"log/slog"
"net/http"
"os"
"os/exec"
"strconv"
"strings"
"time"
_ "github.com/mattn/go-sqlite3"
)
// errorsWrap annotates err with msg, producing "msg: err" while keeping err
// reachable through errors.Is / errors.As. A nil err yields nil, so callers
// may wrap unconditionally.
func errorsWrap(err error, msg string) error {
	if err != nil {
		return fmt.Errorf("%s: %w", msg, err)
	}
	return nil
}
// Question is one command ("question") whose trimmed output ("answer") is
// recorded in the db and may trigger a response command. Instances are
// decoded from the JSON questions file.
type Question struct {
	// Unique identifier, e.g. "ip_addr".
	Name string `json:"name"`
	// Seconds to sleep between timed runs.
	Frequency uint `json:"frequency"`
	// Seconds to wait before the first timed run.
	Delay uint `json:"delay"`
	// Command to run, split into arguments on spaces (no shell).
	Question string `json:"question"`
	// When to store an answer: always|on_change|never.
	// Empty behaves as on_change.
	Record string `json:"record"`
	// Max number of answers to keep for this question:
	// negative or null means no limit;
	// 0 means never enter answers in the db, whatever Record says;
	// positive means each write also prunes entries beyond this count.
	MaxHistory *int `json:"max_history"`
	// Command run in response to a new answer. It is invoked with the
	// question name, timestamp, old value and new value appended as arguments.
	Response string `json:"response"`
	// When to run Response: always|on_change|never, or any other string to
	// respond only when the new answer equals that string. Empty means never.
	ResponseWhen string `json:"response_when"`
}

// Values accepted for Question.Record ("" is also accepted, as on_change).
const (
	recordAlways   = "always"
	recordOnChange = "on_change"
	recordNever    = "never"
)

// Special values for Question.ResponseWhen; any other non-empty value is
// matched literally against the new answer.
const (
	responseAlways   = "always"
	responseOnChange = "on_change"
	responseNever    = "never"
)

// Validate reports the first configuration problem in q, or nil.
func (q *Question) Validate() error {
	if q.Name == "" {
		return errors.New("question name must not be empty")
	}
	if q.Question == "" {
		return fmt.Errorf("question %q query must not be empty", q.Name)
	}
	switch q.Record {
	case "", recordAlways, recordOnChange, recordNever:
		// "" is the documented default and is handled as on_change by shouldWrite.
	default:
		return fmt.Errorf("question %q record type must be one of {%s,%s,%s}", q.Name, recordAlways, recordOnChange, recordNever)
	}
	if q.ResponseWhen != "" && q.Response == "" {
		return fmt.Errorf("question %q must have a response command if given a response when (was %q)", q.Name, q.ResponseWhen)
	}
	return nil
}

// shouldWrite reports whether answer new must be stored, given the previously
// stored answer old (nil when there is none). A nil q never writes.
func (q *Question) shouldWrite(old *string, new string) bool {
	switch {
	case q == nil:
		return false
	// MaxHistory 0 means "never store", so it must win over recordAlways.
	case q.Record == recordNever, q.MaxHistory != nil && *q.MaxHistory == 0:
		return false
	case q.Record == recordAlways:
		return true
	default: // recordOnChange or ""
		return old == nil || *old != new
	}
}

// shouldRespond reports whether q.Response should be run for answer new,
// given the previous answer old (nil when there is none).
func (q *Question) shouldRespond(old *string, new string) bool {
	if q == nil || q.Response == "" {
		return false
	}
	switch q.ResponseWhen {
	case "", responseNever:
		return false
	case responseAlways:
		return true
	case responseOnChange:
		return old == nil || *old != new
	default:
		// Arbitrary value: respond only when the answer equals it.
		return q.ResponseWhen == new
	}
}
// Config holds the command-line settings; see getConfig for the flag names
// and their defaults.
type Config struct {
	// Database name/path handed to the sqlite3 driver.
	DBName string
	// Path to the JSON questions file.
	Questions string
	// HTTP port to listen on.
	Port uint
	// When set, questions are not run on their timers.
	SkipTimer bool
	// When set, /run?name=<question_name> is registered.
	EnableRun bool
}

// config caches the parsed flags so getConfig registers and parses them once.
var config *Config

// getConfig registers the command-line flags on the default flag set, parses
// os.Args, and returns the resulting Config. The result is cached: later calls
// return the same pointer and do not re-register the flags (re-registering a
// flag name on the default FlagSet panics).
func getConfig() *Config {
	if config == nil {
		c := new(Config)
		flag.StringVar(&c.DBName, "dbname", ":memory:", "file db")
		flag.StringVar(&c.Questions, "questions", "", "File path to questions")
		flag.UintVar(&c.Port, "port", 8080, "Application port to listen on")
		flag.BoolVar(&c.SkipTimer, "skip-timer", false, "Whether to skip running jobs on the configured timer")
		flag.BoolVar(&c.EnableRun, "enable-run", false, "Whether jobs can be manually run by hitting /run?name=<question_name>")
		flag.Parse()
		config = c
	}
	return config
}
// dbV1 is the version-1 schema script, executed as a single Exec by migrate
// when the db_version row cannot be read.
//
// Every CREATE uses IF NOT EXISTS, but the trailing INSERT is unconditional,
// so each execution appends another (1, false) row to db_version.
//
// NOTE(review): answer has UNIQUE(question_id, timestamp) with a
// CURRENT_TIMESTAMP default, presumably of whole-second resolution in SQLite.
// Two answers to the same question within one second would then be rejected
// by the INSERT. Confirm this is intended.
const dbV1 = `
CREATE TABLE IF NOT EXISTS question (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name VARCHAR(64) NOT NULL,
UNIQUE (name)
);
CREATE TABLE IF NOT EXISTS answer (
id INTEGER PRIMARY KEY AUTOINCREMENT,
question_id INTEGER REFERENCES question NOT NULL,
timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL,
answer VARCHAR(1024) NOT NULL,
UNIQUE(question_id, timestamp)
);
CREATE TABLE IF NOT EXISTS db_version (
version int NOT NULL,
dirty bool NOT NULL
);
INSERT INTO db_version (version, dirty) VALUES (1,false);
`
// openAndMigrateDB opens the SQLite database named by config.DBName with
// foreign-key enforcement requested, checks that it is reachable, and brings
// the schema up to date via migrate. On any failure the handle is closed and
// (nil, wrapped error) is returned; the caller decides whether to exit.
func openAndMigrateDB(log *slog.Logger, config *Config) (*sql.DB, error) {
	// mattn/go-sqlite3 enables foreign keys with `_foreign_keys=1`.
	// The `_pragma=foreign_keys(1)` form is modernc.org/sqlite syntax and is
	// not recognised by this driver.
	sep := "?"
	if strings.Contains(config.DBName, "?") {
		sep = "&"
	}
	db, err := sql.Open("sqlite3", config.DBName+sep+"_foreign_keys=1")
	if err != nil {
		return nil, fmt.Errorf("could not open db: %w", err)
	}
	if config.DBName == ":memory:" {
		// Each new connection to ":memory:" is a brand-new empty database.
		// Pin the pool to a single long-lived connection so every goroutine
		// sees the same schema and data.
		db.SetMaxOpenConns(1)
	}
	if err := db.Ping(); err != nil {
		db.Close()
		return nil, fmt.Errorf("could not contact db: %w", err)
	}
	if err := migrate(db); err != nil {
		db.Close()
		return nil, fmt.Errorf("migrating db: %w", err)
	}
	log.Info("db ready", "dbname", config.DBName)
	return db, nil
}
// migrate ensures the schema is at version 1 and not marked dirty.
// If the version row cannot be read (table missing, empty, or unreadable),
// dbV1 is applied and the row is read again. A dirty schema or any version
// other than 1 is reported as an error.
func migrate(db *sql.DB) error {
	const versionQuery = "SELECT version,dirty FROM db_version LIMIT 1;"
	var dbVersion int
	var dirty bool
	// Both destinations must be pointers. Passing dirty by value made every
	// scan fail, which re-ran dbV1 (and its INSERT) on each start.
	if err := db.QueryRow(versionQuery).Scan(&dbVersion, &dirty); err != nil {
		if _, err := db.Exec(dbV1); err != nil {
			return fmt.Errorf("applying schema v1: %w", err)
		}
		if err := db.QueryRow(versionQuery).Scan(&dbVersion, &dirty); err != nil {
			return fmt.Errorf("reading db version: %w", err)
		}
	}
	if dbVersion != 1 {
		return fmt.Errorf("Unexpected db version %d", dbVersion)
	}
	if dirty {
		return errors.New("DB schema dirty, please check and clear manually")
	}
	return nil
}
// runResponse executes q.Response as a command split on whitespace (no shell,
// so quoting is not supported). Four arguments are appended: the question
// name, timestamp as decimal Unix seconds, the previous answer ("" when
// lastAnswer is nil) and the new answer. Stdout is discarded. Failures are
// logged together with the command's stderr rather than returned.
func (q *Question) runResponse(log *slog.Logger, lastAnswer *string, answer string, timestamp int64) {
	// Fields (not Split on " ") avoids empty arguments and an empty command
	// name when the configured string has repeated or leading spaces.
	fields := strings.Fields(q.Response)
	if len(fields) == 0 {
		log.Error("error responding", "question", q.Name, "err", "empty response command")
		return
	}
	args := make([]string, 0, len(fields)-1+4)
	args = append(args, fields[1:]...)
	// FormatInt avoids truncating the int64 timestamp on 32-bit platforms.
	args = append(args, q.Name, strconv.FormatInt(timestamp, 10), denull(lastAnswer), answer)

	var stderr strings.Builder
	cmd := exec.Command(fields[0], args...)
	cmd.Stdout = io.Discard
	cmd.Stderr = &stderr
	if err := cmd.Run(); err != nil {
		log.Error("error responding", "question", q.Name, "err", err, "stderr", strings.TrimSpace(stderr.String()))
	}
}
// maxAnswerBytes is the longest answer, in bytes, stored in the answer table.
const maxAnswerBytes = 1024

// runQuestion executes question.Question and computes the trimmed stdout.
// It stores that answer (truncated to maxAnswerBytes) when
// question.shouldWrite allows it, pruning history to MaxHistory. It starts
// question.Response in a background goroutine when question.shouldRespond
// allows it. It returns the full, untruncated answer.
func runQuestion(log *slog.Logger, question *Question, db *sql.DB) (string, error) {
	answer, err := execQuestion(question.Question)
	if err != nil {
		return "", err
	}
	answerTime := time.Now().Unix()

	qID, err := getQuestionId(db, question.Name)
	if err != nil {
		return "", errorsWrap(err, fmt.Sprintf("question %q does not exist in db", question.Name))
	}
	lastAnswer, err := fetchLastAnswer(db, qID)
	if err != nil {
		return "", err
	}

	// Stored answers are truncated, so change detection must compare the
	// truncated form. Otherwise an over-long answer always looks "changed".
	stored := truncateUTF8(answer, maxAnswerBytes)
	if question.shouldWrite(lastAnswer, stored) {
		if err := saveAnswer(log, db, question, qID, stored); err != nil {
			return "", err
		}
	}
	if question.shouldRespond(lastAnswer, stored) {
		go question.runResponse(log, lastAnswer, answer, answerTime)
	}
	return answer, nil
}

// execQuestion runs command, split on whitespace with no shell, and returns
// its stdout with surrounding whitespace trimmed. On failure the returned
// error includes the (truncated) stderr.
func execQuestion(command string) (string, error) {
	fields := strings.Fields(command)
	if len(fields) == 0 {
		return "", errors.New("error running command: empty command")
	}
	cmd := exec.Command(fields[0], fields[1:]...)
	var stdout, stderr strings.Builder
	cmd.Stdout = &stdout
	cmd.Stderr = &stderr
	if err := cmd.Run(); err != nil {
		errOut := truncateUTF8(strings.TrimSpace(stderr.String()), 512)
		return "", fmt.Errorf("error running command %q (stderr: %q): %w", command, errOut, err)
	}
	return strings.TrimSpace(stdout.String()), nil
}

// fetchLastAnswer returns the most recently stored answer for question id
// qID, or nil when none has been stored.
func fetchLastAnswer(db *sql.DB, qID int) (*string, error) {
	var last *string
	err := db.QueryRow(`SELECT answer FROM answer
WHERE question_id = ?
ORDER BY timestamp DESC, id DESC
LIMIT 1;`, qID).Scan(&last)
	if errors.Is(err, sql.ErrNoRows) {
		return nil, nil
	}
	if err != nil {
		return nil, errorsWrap(err, "could not scan last answer")
	}
	return last, nil
}

// saveAnswer inserts answer for question id qID. If question.MaxHistory is
// positive, it also deletes that question's rows beyond the newest MaxHistory.
// Both steps run in one transaction.
func saveAnswer(log *slog.Logger, db *sql.DB, question *Question, qID int, answer string) error {
	tx, err := db.Begin()
	if err != nil {
		return errorsWrap(err, "could not insert answer")
	}
	// Rolls back on every early return; after a successful Commit it is a
	// no-op returning sql.ErrTxDone, which is deliberately ignored.
	defer func() { _ = tx.Rollback() }()

	if _, err := tx.Exec(`INSERT INTO answer (question_id, answer) VALUES (?, ?);`, qID, answer); err != nil {
		return errorsWrap(err, "could not insert answer")
	}
	if question.MaxHistory != nil && *question.MaxHistory > 0 {
		// Argument order follows the placeholders: inner question_id, LIMIT,
		// outer question_id.
		result, err := tx.Exec(`DELETE FROM answer
WHERE id NOT IN (
SELECT id FROM answer
WHERE question_id = ?
ORDER BY timestamp DESC, id DESC
LIMIT ?
)
AND question_id = ?;`,
			qID, *question.MaxHistory, qID,
		)
		if err != nil {
			return errorsWrap(err, fmt.Sprintf("could not delete old answers to %q", question.Name))
		}
		if numDeleted, err := result.RowsAffected(); err == nil {
			log.Info("rows deleted", "num", numDeleted)
		}
	}
	if err := tx.Commit(); err != nil {
		return errorsWrap(err, fmt.Sprintf("could not complete save to db for question %q", question.Name))
	}
	return nil
}

// truncateUTF8 returns s cut to at most limit bytes, backing off so a
// multi-byte UTF-8 sequence is never split.
func truncateUTF8(s string, limit int) string {
	if limit <= 0 {
		return ""
	}
	if len(s) <= limit {
		return s
	}
	cut := limit
	// Continuation bytes look like 10xxxxxx; step back to a rune start.
	for cut > 0 && s[cut]&0xC0 == 0x80 {
		cut--
	}
	return s[:cut]
}
// loadQuestions reads the JSON array of questions at config.Questions and
// validates each entry. It returns the questions together with a map from
// question name to its index in the slice. Duplicate names are rejected.
func loadQuestions(config *Config) ([]Question, map[string]int, error) {
	// os.ReadFile opens and closes the file; the previous os.Open leaked it.
	questionsBytes, err := os.ReadFile(config.Questions)
	if err != nil {
		return nil, nil, fmt.Errorf("reading questions file %q: %w", config.Questions, err)
	}
	questions := []Question{}
	if err := json.Unmarshal(questionsBytes, &questions); err != nil {
		return nil, nil, fmt.Errorf("parsing questions file %q: %w", config.Questions, err)
	}
	questionsMap := make(map[string]int, len(questions))
	for i := range questions {
		if err := questions[i].Validate(); err != nil {
			return nil, nil, err
		}
		name := questions[i].Name
		if prev, seen := questionsMap[name]; seen {
			return nil, nil, fmt.Errorf("question %s at index %d already seen at index %d", name, i, prev)
		}
		questionsMap[name] = i
	}
	return questions, questionsMap, nil
}
// getQuestionId looks up the row id of the question named questionName.
// The returned error is whatever Scan yields, including sql.ErrNoRows when
// the name is unknown.
func getQuestionId(db *sql.DB, questionName string) (int, error) {
	var id int
	err := db.
		QueryRow("SELECT id FROM question WHERE question.name=? LIMIT 1;", questionName).
		Scan(&id)
	return id, err
}
// denull dereferences t, substituting the zero value of T when t is nil.
func denull[T any](t *T) T {
	var out T
	if t != nil {
		out = *t
	}
	return out
}
// toPlaceholders joins length copies of pattern with commas, e.g.
// toPlaceholders(3, "?") == "?,?,?". It returns "" when length is 0 and
// panics (via strings.Repeat) when length is negative.
func toPlaceholders(length int, pattern string) string {
	if length != 0 {
		return pattern + strings.Repeat(","+pattern, length-1)
	}
	return ""
}
// bootstrapQuestions ensures every name in questionsMap has a row in the
// question table, inserting only the missing ones in a single statement.
// The questions slice is not consulted; names come from questionsMap's keys.
func bootstrapQuestions(log *slog.Logger, db *sql.DB, questions []Question, questionsMap map[string]int) error {
	if len(questionsMap) == 0 {
		return nil
	}
	names := make([]any, 0, len(questionsMap))
	for name := range questionsMap {
		names = append(names, name)
	}
	rows, err := db.Query(
		fmt.Sprintf(`SELECT name FROM question WHERE name in (%s);`, toPlaceholders(len(names), "?")),
		names...,
	)
	if err != nil {
		return fmt.Errorf("looking up existing questions: %w", err)
	}
	// Registered immediately so rows are released on every early return.
	defer rows.Close()

	inDB := make(map[string]struct{}, len(names))
	for rows.Next() {
		var name string
		if err := rows.Scan(&name); err != nil {
			return fmt.Errorf("scanning existing question: %w", err)
		}
		inDB[name] = struct{}{}
	}
	// Err must be checked after the loop: Next also returns false on error.
	if err := rows.Err(); err != nil {
		return fmt.Errorf("iterating existing questions: %w", err)
	}
	// Free the connection before the INSERT (matters when the pool is
	// limited to one connection). The deferred Close is then a no-op.
	rows.Close()

	var toAdd []any
	for name := range questionsMap {
		if _, ok := inDB[name]; !ok {
			toAdd = append(toAdd, name)
		}
	}
	if len(toAdd) == 0 {
		return nil
	}
	result, err := db.Exec(
		fmt.Sprintf(`INSERT INTO question (name) VALUES %s;`, toPlaceholders(len(toAdd), "(?)")),
		toAdd...,
	)
	if err != nil {
		return fmt.Errorf("inserting new questions: %w", err)
	}
	numRows, err := result.RowsAffected()
	if err != nil {
		return err
	}
	if numRows > 0 {
		log.Info("added questions", "number_added", numRows)
	}
	return nil
}
// CSV header lines written by the /dump handler before each section.
var (
	hQuestion = []byte("id,name\n")
	hAnswer   = []byte("id,question.name,timestamp,answer\n")
	hVersion  = []byte("version,dirty\n")
)
// runQuestionsTimer starts one goroutine per question and returns immediately.
// Each goroutine waits Delay seconds, then runs the question repeatedly,
// sleeping Frequency seconds after each run (so the period is Frequency plus
// the run time). A question with Frequency 0 is run once and not
// rescheduled, instead of being re-executed back-to-back in a tight loop.
// Run errors are logged and do not stop the loop. The goroutines run for the
// life of the process.
func runQuestionsTimer(questions []Question, log *slog.Logger, db *sql.DB) {
	for _, question := range questions {
		// Pass a copy so each goroutine owns its Question, regardless of the
		// Go version's loop-variable semantics.
		go func(question Question) {
			if question.Delay > 0 {
				time.Sleep(time.Duration(question.Delay) * time.Second)
			}
			for {
				if _, err := runQuestion(log, &question, db); err != nil {
					log.Error("Error running question", "question", question.Name, "err", err)
				}
				if question.Frequency == 0 {
					log.Warn("question has zero frequency; not rescheduling", "question", question.Name)
					return
				}
				time.Sleep(time.Duration(question.Frequency) * time.Second)
			}
		}(question)
	}
}
func main() {
log := slog.New(slog.NewJSONHandler(os.Stderr, nil))
config := getConfig()
log.Info("config as read", "values", config)
db, err := openAndMigrateDB(log, config)
if err != nil {
log.Error("Error opening or migrating db", "err", err)
os.Exit(1)
}
questions, questionsMap, err := loadQuestions(config)
if err != nil {
log.Error("Error reading questions", "err", err)
os.Exit(1)
}
err = bootstrapQuestions(log, db, questions, questionsMap)
if err != nil {
log.Error("Error bootstrapping questions", "err", err)
os.Exit(1)
}
if !config.SkipTimer {
go runQuestionsTimer(questions, log, db)
}
http.HandleFunc("/registeredquestions", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
questionNames := []string{}
for _, q := range questions {
questionNames = append(questionNames, q.Name)
}
bQuestionNames, err := json.Marshal(questionNames)
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
return
}
w.Write(bQuestionNames)
}))
http.HandleFunc("/dump", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
rows, err := db.Query("SELECT id,name FROM question;")
if err != nil {
log.Error("bad question select", "err", err)
w.WriteHeader(http.StatusInternalServerError)
return
}
w.Write(hQuestion)
for rows.Next() {
var id int
var name string
err = rows.Scan(&id, &name)
if err != nil {
log.Error("bad answer scan", "err", err)
return
}
w.Write([]byte(fmt.Sprintf("%d,%s\n", id, name)))
}
rows, err = db.Query(`SELECT a.id,q.name,a.timestamp,a.answer FROM answer AS a
LEFT JOIN question AS q ON a.question_id = q.id
ORDER BY q.name ASC, a.timestamp ASC
;`)
if err != nil {
log.Error("bad answer select", "err", err)
w.WriteHeader(http.StatusInternalServerError)
return
}
w.Write(hAnswer)
for rows.Next() {
var id int
var qName string
var timestamp time.Time
var answer string
err = rows.Scan(&id, &qName, ×tamp, &answer)
if err != nil {
log.Error("bad answer scan", "err", err)
return
}
w.Write([]byte(fmt.Sprintf("%d,%s,%v,%s\n", id, qName, timestamp, answer)))
}
rows, err = db.Query("SELECT version,dirty FROM db_version;")
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
return
}
w.Write(hVersion)
for rows.Next() {
var version int
var dirty bool
err = rows.Scan(&version, &dirty)
if err != nil {
log.Error("bad version scan", "err", err)
return
}
w.Write([]byte(fmt.Sprintf("%d,%v\n", version, dirty)))
}
}))
if config.EnableRun {
http.HandleFunc("/run", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
q := r.URL.Query()
qName := q.Get("name")
idx, ok := questionsMap[qName]
if !ok {
log.Info("404", "question", qName)
w.WriteHeader(http.StatusBadRequest)
return
}
question := questions[idx]
answer, err := runQuestion(log, &question, db)
if err != nil {
log.Error("failure to run question", "err", err)
w.WriteHeader(http.StatusInternalServerError)
return
}
w.Write([]byte(answer))
}))
}
http.HandleFunc("/version", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var version int
row := db.QueryRow("SELECT version FROM db_version LIMIT 1;")
err := row.Scan(&version)
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
return
}
w.Write([]byte(fmt.Sprintf("%d\n", version)))
}))
http.ListenAndServe(fmt.Sprintf(":%d", config.Port), nil)
}