package main import ( "database/sql" "encoding/json" "errors" "flag" "fmt" "io" "log/slog" "net/http" "os" "os/exec" "strconv" "strings" "time" _ "github.com/mattn/go-sqlite3" ) func errorsWrap(err error, msg string) error { if err == nil { return nil } return fmt.Errorf("%s: %w", msg, err) } func (q *Question) shouldWrite(old *string, new string) bool { if q == nil { return false } if q.Record == "always" { return true } if q.Record == "never" || (q.MaxHistory != nil && *q.MaxHistory == 0) { return false } if old == nil { return true } return *old != new } func (q *Question) shouldRespond(old *string, new string) bool { if q == nil || q.Response == "" || q.ResponseWhen == "" || q.ResponseWhen == responseNever { return false } if q.ResponseWhen == responseAlways { return true } if q.ResponseWhen == responseOnChange { return old == nil || *old != new } // e.g. some arbitrary value return q.ResponseWhen == new } const ( recordAlways = "always" recordOnChange = "on_change" recordNever = "never" ) const ( responseAlways = "always" responseOnChange = "on_change" responseNever = "never" ) func (q *Question) Validate() error { if q.Name == "" { return errors.New("question name must not be empty") } if q.Question == "" { return fmt.Errorf("question %q query must not be empty", q.Name) } if q.Record != recordAlways && q.Record != recordOnChange && q.Record != recordNever { return fmt.Errorf("question %q record type must be one of {%s,%s,%s}", q.Name, recordAlways, recordOnChange, recordNever) } if q.ResponseWhen != "" && q.Response == "" { return fmt.Errorf("question %q must have a response command if given a response when (was %q)", q.Name, q.ResponseWhen) } return nil } type Question struct { // e.g "ip_addr" Name string `json:"name"` // in seconds Frequency uint `json:"frequency"` // in seconds Delay uint `json:"delay"` // query to be run, e.g. a shell command to run Question string `json:"question"` // always|on_change,never default on_change Record string `json:"record"` // Max number of answers to keep to a question // negative means no limit // 0 means do not enter answers in the db // positive, each write will also clean up extra entries // If null, defaults to -1 MaxHistory *int `json:"max_history"` // Action to take in response to a new answer // Will be invoked with question name, timestamp, old value, new value Response string `json:"response"` // What sorts of answers invoke a response // e.g. always|on change|arbitrary|never string value ResponseWhen string `json:"response_when"` } type Config struct { DBName string Questions string Port uint SkipTimer bool EnableRun bool } var config *Config func getConfig() *Config { if config != nil { return config } config = &Config{} flag.StringVar(&config.DBName, "dbname", ":memory:", "file db") flag.StringVar(&config.Questions, "questions", "", "File path to questions") flag.UintVar(&config.Port, "port", 8080, "Application port to listen on") flag.BoolVar(&config.SkipTimer, "skip-timer", false, "Whether to skip running jobs on the configured timer") flag.BoolVar(&config.EnableRun, "enable-run", false, "Whether jobs can be manually run by hitting /run?name=<question_name>") flag.Parse() return config } const dbV1 = ` CREATE TABLE IF NOT EXISTS question ( id INTEGER PRIMARY KEY AUTOINCREMENT, name VARCHAR(64) NOT NULL, UNIQUE (name) ); CREATE TABLE IF NOT EXISTS answer ( id INTEGER PRIMARY KEY AUTOINCREMENT, question_id INTEGER REFERENCES question NOT NULL, timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL, answer VARCHAR(1024) NOT NULL, UNIQUE(question_id, timestamp) ); CREATE TABLE IF NOT EXISTS db_version ( version int NOT NULL, dirty bool NOT NULL ); INSERT INTO db_version (version, dirty) VALUES (1,false); ` func openAndMigrateDB(log *slog.Logger, config *Config) (*sql.DB, error) { db, err := sql.Open("sqlite3", config.DBName+"?_pragma=foreign_keys(1)") if err != nil { log.Error("could not open db", "err", err) os.Exit(1) } err = db.Ping() if err != nil { log.Error("could not contact db", "err", err) os.Exit(1) } return db, migrate(db) } func migrate(db *sql.DB) error { row := db.QueryRow("SELECT version,dirty FROM db_version LIMIT 1;") var dbVersion int var dirty bool err := row.Scan(&dbVersion, dirty) if err != nil { _, err = db.Exec(dbV1) if err != nil { return err } } row = db.QueryRow("SELECT version,dirty FROM db_version LIMIT 1;") err = row.Scan(&dbVersion, dirty) if dbVersion == 1 { if dirty { return fmt.Errorf("DB schema dirty, please check and clear manually") } return nil } else { return fmt.Errorf("Unexpected db version %d", dbVersion) } } func (q *Question) runResponse(log *slog.Logger, lastAnswer *string, answer string, timestamp int64) { commandAndArgs := strings.Split(q.Response, " ") args := []string{} if len(commandAndArgs) > 1 { args = append(args, commandAndArgs[1:]...) } args = append(args, q.Name, strconv.Itoa(int(timestamp)), denull(lastAnswer), answer) cmd := exec.Command(commandAndArgs[0], args...) cmd.Stdout = io.Discard cmd.Stderr = io.Discard err := cmd.Run() if err != nil { log.Error("error responding", "err", err) } } func runQuestion(log *slog.Logger, question *Question, db *sql.DB) (string, error) { commandAndArgs := strings.Split(question.Question, " ") cmd := exec.Command(commandAndArgs[0], commandAndArgs[1:]...) var stdout strings.Builder var stderr strings.Builder cmd.Stdout = &stdout cmd.Stderr = &stderr err := cmd.Run() if err != nil { log.Error("error running command", "err", err) return "", errorsWrap(err, "error running command") } answerTime := time.Now().Unix() answer := strings.TrimSpace(stdout.String()) qId, err := getQuestionId(db, question.Name) if err != nil { return "", errorsWrap(err, fmt.Sprintf("question %q does not exist in db", question.Name)) } lastAnswerRow := db.QueryRow(`SELECT answer FROM answer JOIN question on answer.question_id = ? ORDER BY timestamp DESC LIMIT 1;`, qId) if err = lastAnswerRow.Err(); err != nil && !errors.Is(err, sql.ErrNoRows) { return "", errorsWrap(err, "could not find last answer") } var lastAnswer *string err = lastAnswerRow.Scan(&lastAnswer) if err != nil && !errors.Is(err, sql.ErrNoRows) { return "", errorsWrap(err, "could not scan last answer") } if question.shouldWrite(lastAnswer, answer) { cutoffPoint := min(len(answer), 1024) truncatedAnswer := answer[0:cutoffPoint] tx, err := db.Begin() if err != nil { return "", errorsWrap(err, "could not insert answer") } _, err = tx.Exec( `INSERT INTO answer (question_id, answer) VALUES (?, ?);`, qId, truncatedAnswer, ) if err != nil { tx.Rollback() return "", errorsWrap(err, "could not insert answer") } if question.MaxHistory != nil && *question.MaxHistory > 0 { result, err := tx.Exec(`DELETE FROM answer WHERE id NOT IN ( SELECT id FROM answer WHERE question_id = ? ORDER BY timestamp DESC LIMIT ? ) AND question_id = ?;`, qId, qId, *question.MaxHistory, ) if err != nil { tx.Rollback() return "", errorsWrap(err, fmt.Sprintf("could not delete old answers to %q", question.Name)) } numDeleted, err := result.RowsAffected() if err == nil { log.Info("rows deleted", "num", numDeleted) } } err = tx.Commit() if err != nil { return "", errorsWrap(err, fmt.Sprintf("could not complete save to db for question %q", question.Name)) } } if question.shouldRespond(lastAnswer, answer) { go func() { question.runResponse(log, lastAnswer, answer, answerTime) }() } return answer, nil } func loadQuestions(config *Config) ([]Question, map[string]int, error) { f, err := os.Open(config.Questions) if err != nil { return nil, nil, err } questionsBytes, err := io.ReadAll(f) if err != nil { return nil, nil, err } questions := []Question{} questionsMap := map[string]int{} err = json.Unmarshal(questionsBytes, &questions) if err != nil { return nil, nil, err } for i, q := range questions { if err = q.Validate(); err != nil { return nil, nil, err } if _, seen := questionsMap[q.Name]; seen { return nil, nil, fmt.Errorf("question %s already seen at index %d", q.Name, i) } questionsMap[q.Name] = i } return questions, questionsMap, nil } func getQuestionId(db *sql.DB, questionName string) (int, error) { var id int row := db.QueryRow("SELECT id FROM question WHERE question.name=? LIMIT 1;", questionName) err := row.Scan(&id) return id, err } func denull[T any](t *T) T { if t == nil { return *new(T) } return *t } func toPlaceholders(length int, pattern string) string { if length == 0 { return "" } placeholders := strings.Repeat(fmt.Sprintf("%s,", pattern), length)[0 : (length*(len(pattern)+1))-1] return placeholders } func bootstrapQuestions(log *slog.Logger, db *sql.DB, questions []Question, questionsMap map[string]int) error { questionNames := []any{} for questionName := range questionsMap { questionNames = append(questionNames, questionName) } rows, err := db.Query(fmt.Sprintf(`SELECT name FROM question WHERE name in (%s);`, toPlaceholders(len(questionsMap), "?")), questionNames...) if err != nil { return err } if err := rows.Err(); err != nil { return err } defer rows.Close() toAdd := []any{} alreadyInDb := map[string]struct{}{} for rows.Next() { var str string err = rows.Scan(&str) if err != nil { return err } alreadyInDb[str] = struct{}{} } for questionName := range questionsMap { if _, alreadyAdded := alreadyInDb[questionName]; !alreadyAdded { toAdd = append(toAdd, questionName) } } if len(toAdd) == 0 { return nil } query := fmt.Sprintf(`INSERT INTO question (name) VALUES %s;`, toPlaceholders(len(toAdd), "(?)")) result, err := db.Exec(query, toAdd...) if err != nil { return err } numRows, err := result.RowsAffected() if err != nil { return err } if numRows > 0 { log.Info("added questions", "number_added", numRows) } return nil } var hQuestion = []byte("id,name\n") var hAnswer = []byte("id,question.name,timestamp,answer\n") var hVersion = []byte("version,dirty\n") func runQuestionsTimer(questions []Question, log *slog.Logger, db *sql.DB) { for _, question := range questions { question := question go func() { if question.Delay > 0 { time.Sleep(time.Second * time.Duration(question.Delay)) } for { _, err := runQuestion(log, &question, db) if err != nil { log.Error("Error running question", "question", question.Name, "err", err) } time.Sleep(time.Second * time.Duration(question.Frequency)) } }() } } func main() { log := slog.New(slog.NewJSONHandler(os.Stderr, nil)) config := getConfig() log.Info("config as read", "values", config) db, err := openAndMigrateDB(log, config) if err != nil { log.Error("Error opening or migrating db", "err", err) os.Exit(1) } questions, questionsMap, err := loadQuestions(config) if err != nil { log.Error("Error reading questions", "err", err) os.Exit(1) } err = bootstrapQuestions(log, db, questions, questionsMap) if err != nil { log.Error("Error bootstrapping questions", "err", err) os.Exit(1) } if !config.SkipTimer { go runQuestionsTimer(questions, log, db) } http.HandleFunc("/registeredquestions", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { questionNames := []string{} for _, q := range questions { questionNames = append(questionNames, q.Name) } bQuestionNames, err := json.Marshal(questionNames) if err != nil { w.WriteHeader(http.StatusInternalServerError) return } w.Write(bQuestionNames) })) http.HandleFunc("/dump", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { rows, err := db.Query("SELECT id,name FROM question;") if err != nil { log.Error("bad question select", "err", err) w.WriteHeader(http.StatusInternalServerError) return } w.Write(hQuestion) for rows.Next() { var id int var name string err = rows.Scan(&id, &name) if err != nil { log.Error("bad answer scan", "err", err) return } w.Write([]byte(fmt.Sprintf("%d,%s\n", id, name))) } rows, err = db.Query(`SELECT a.id,q.name,a.timestamp,a.answer FROM answer AS a LEFT JOIN question AS q ON a.question_id = q.id ORDER BY q.name ASC, a.timestamp ASC ;`) if err != nil { log.Error("bad answer select", "err", err) w.WriteHeader(http.StatusInternalServerError) return } w.Write(hAnswer) for rows.Next() { var id int var qName string var timestamp time.Time var answer string err = rows.Scan(&id, &qName, &timestamp, &answer) if err != nil { log.Error("bad answer scan", "err", err) return } w.Write([]byte(fmt.Sprintf("%d,%s,%v,%s\n", id, qName, timestamp, answer))) } rows, err = db.Query("SELECT version,dirty FROM db_version;") if err != nil { w.WriteHeader(http.StatusInternalServerError) return } w.Write(hVersion) for rows.Next() { var version int var dirty bool err = rows.Scan(&version, &dirty) if err != nil { log.Error("bad version scan", "err", err) return } w.Write([]byte(fmt.Sprintf("%d,%v\n", version, dirty))) } })) if config.EnableRun { http.HandleFunc("/run", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { q := r.URL.Query() qName := q.Get("name") idx, ok := questionsMap[qName] if !ok { log.Info("404", "question", qName) w.WriteHeader(http.StatusBadRequest) return } question := questions[idx] answer, err := runQuestion(log, &question, db) if err != nil { log.Error("failure to run question", "err", err) w.WriteHeader(http.StatusInternalServerError) return } w.Write([]byte(answer)) })) } http.HandleFunc("/version", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { var version int row := db.QueryRow("SELECT version FROM db_version LIMIT 1;") err := row.Scan(&version) if err != nil { w.WriteHeader(http.StatusInternalServerError) return } w.Write([]byte(fmt.Sprintf("%d\n", version))) })) http.ListenAndServe(fmt.Sprintf(":%d", config.Port), nil) }