package main
import (
"crypto/sha256"
"errors"
"fmt"
"io"
"log/slog"
"net/http"
"os"
"regexp"
"strconv"
"sync"
"time"
)
// config holds all runtime settings for the server, sourced from
// AWLFILE_* environment variables by NewConfig.
type config struct {
	// dir is the directory whose files are served (AWLFILE_DIR, default ".").
	dir string
	// pepper is mixed into every file hash so ids cannot be derived from
	// file content alone (AWLFILE_PEPPER, required).
	pepper []byte
	// sweepInterval is the delay in seconds between background sweep passes;
	// 0 disables the sweeper (AWLFILE_SWEEP_INTERVAL).
	sweepInterval int64
	// maxReads caps the max-read count a client may request in the
	// expire-policy path segment (AWLFILE_MAX_READS, default defaultMaxReads).
	maxReads int
	// port is the TCP port to listen on (AWLFILE_PORT, default "24242").
	port string
}

// defaultMaxReads is the cap used when AWLFILE_MAX_READS is not set.
const defaultMaxReads = 2000
// sweep periodically removes expired files from disk and drops their cache
// entries. It never returns and is intended to run as a goroutine; the pause
// between passes is conf.sweepInterval seconds (see main).
func sweep(cache *fileEntryCache, conf *config) {
	for {
		time.Sleep(time.Duration(conf.sweepInterval) * time.Second)
		slog.Info("sweeping expired files")
		now := time.Now()
		cache.mu.Lock()
		for filename, cacheState := range cache.data {
			if isExpired(cacheState, now) {
				// Best-effort removal; the cache entry is dropped regardless
				// so a vanished file cannot pin a stale entry.
				_ = os.Remove(conf.dir + "/" + filename)
				delete(cache.data, filename)
			}
		}
		cache.mu.Unlock()
	}
}

// isExpired reports whether a cached file has exhausted any of its max-read
// policies or passed any of its expire-by timestamps as of now.
// Caller must hold the cache lock.
func isExpired(item *fileCacheItem, now time.Time) bool {
	reads := item.history.reads
	for maxReads := range item.history.maxReads {
		// BUG FIX: expire once the file has been read maxReads times.
		// The original comparison was inverted (maxReads >= reads), which
		// expired every max-read-limited item on the first sweep.
		if reads >= maxReads {
			return true
		}
	}
	for expireTimeStr := range item.history.expireTimes {
		expireBy, err := time.Parse(time.RFC3339, expireTimeStr)
		if err != nil {
			// Should never happen (keys are produced via time.Format), but
			// treat an unparseable timestamp as expired to safeguard against
			// keeping an item in cache forever.
			return true
		}
		if now.After(expireBy) {
			return true
		}
	}
	return false
}
// NewConfig builds the server configuration from environment variables:
//
//   - AWLFILE_DIR: directory to serve files from (default ".")
//   - AWLFILE_PEPPER: hash pepper (required)
//   - AWLFILE_MAX_READS: cap on client-requested read counts (default 2000)
//   - AWLFILE_SWEEP_INTERVAL: seconds between sweeps, 0 disables (default 0)
//   - AWLFILE_PORT: listen port, 1-65535 (default "24242")
//
// It returns an error when a required variable is missing or a set variable
// has an invalid value.
func NewConfig() (*config, error) {
	dir := os.Getenv("AWLFILE_DIR")
	if dir == "" {
		dir = "."
	}
	pepper := os.Getenv("AWLFILE_PEPPER")
	if pepper == "" {
		return nil, fmt.Errorf("Must provide a pepper")
	}
	maxReads := defaultMaxReads
	if maxReadsStr := os.Getenv("AWLFILE_MAX_READS"); maxReadsStr != "" {
		maxReadsFromEnv, err := strconv.Atoi(maxReadsStr)
		if err != nil {
			// BUG FIX: a malformed AWLFILE_MAX_READS was previously swallowed
			// and the default used silently; report it like the other numeric
			// settings below so misconfiguration is visible at startup.
			return nil, fmt.Errorf("expected AWLFILE_MAX_READS to be a number: %s", err.Error())
		}
		maxReads = maxReadsFromEnv
	}
	var sweepInterval int64 = 0
	awlfileSleepIntervalStr := os.Getenv("AWLFILE_SWEEP_INTERVAL")
	if awlfileSleepIntervalStr != "" {
		sweepIntervalNum, err := strconv.ParseInt(awlfileSleepIntervalStr, 10, 64)
		if err != nil {
			return nil, fmt.Errorf("expected AWLFILE_SWEEP_INTERVAL to be a number: %s", err.Error())
		}
		if sweepIntervalNum < 0 {
			return nil, fmt.Errorf("expected AWLFILE_SWEEP_INTERVAL to be nonnegative")
		}
		sweepInterval = sweepIntervalNum
	}
	port := "24242"
	awlfilePortStr := os.Getenv("AWLFILE_PORT")
	if awlfilePortStr != "" {
		portInt, err := strconv.Atoi(awlfilePortStr)
		if err != nil {
			return nil, fmt.Errorf("expected AWLFILE_PORT to be a number: %s", err.Error())
		}
		if portInt < 1 || portInt > 65535 {
			return nil, fmt.Errorf("invalid AWLFILE_PORT number, must be between 1 and 65535: %d", portInt)
		}
		// Keep the original string form; it was validated above.
		port = awlfilePortStr
	}
	return &config{
		dir:           dir,
		pepper:        []byte(pepper),
		maxReads:      maxReads,
		sweepInterval: sweepInterval,
		port:          port,
	}, nil
}
// Sentinel errors distinguishing "policy expired" from "no matching file";
// callers compare against these by identity.
var errExpired = errors.New("Item expired")
var errNotFound = errors.New("Not found")

// reSha256 validates that an id is exactly a hex-encoded sha256 digest.
// BUG FIX: the pattern is now anchored — the unanchored original accepted
// any string that merely contained 64 hex characters (e.g. a 65-char id),
// since regexp's Match looks for a substring match.
var reSha256 = regexp.MustCompile("^[A-Fa-f0-9]{64}$")
// doHash reads the file at filePath and returns three chained sha256 digests:
//
//	contentHash      = sha256(contents)
//	pepperHash       = sha256(contents || pepper)
//	accessPolicyHash = sha256(contents || pepper || accessPolicy)
//
// The chaining is deliberate: the same hash state is reused between Sum
// calls, so each digest extends the previous one. Returns errNotFound when
// the file does not exist, or the underlying I/O error otherwise.
func doHash(filePath string, pepper []byte, accessPolicy []byte) ([]byte, []byte, []byte, error) {
	f, err := os.Open(filePath)
	if err != nil {
		if errors.Is(err, os.ErrNotExist) {
			return nil, nil, nil, errNotFound
		}
		return nil, nil, nil, err
	}
	defer f.Close()

	hasher := sha256.New()
	if _, err := io.Copy(hasher, f); err != nil {
		return nil, nil, nil, err
	}
	contentHash := hasher.Sum(nil)

	if _, err := hasher.Write(pepper); err != nil {
		return nil, nil, nil, err
	}
	pepperHash := hasher.Sum(nil)

	if _, err := hasher.Write(accessPolicy); err != nil {
		return nil, nil, nil, err
	}
	accessPolicyHash := hasher.Sum(nil)

	return contentHash, pepperHash, accessPolicyHash, nil
}
// hashFile decides whether the directory entry matches the requested idHash
// under the given access policy, consulting and updating the cache.
//
// It returns the (possibly freshly created) cache item on a match,
// errNotFound when the entry's access-policy hash does not equal idHash,
// errExpired when a stored policy says the file is spent, or an underlying
// I/O error. A cached entry is invalidated when the file's size or mtime
// changed on disk. expiresInt > 0 records a max-read policy; expiresTime
// after the zero time records an expire-by policy.
func hashFile(cache *fileEntryCache, dir string, dirEntry os.DirEntry, pepper []byte, accessPolicy []byte, expiresInt int, expiresTime time.Time, idHash string, userAccessTime time.Time) (*fileCacheItem, error) {
	info, err := dirEntry.Info()
	if err != nil {
		return nil, err
	}
	fName, fSize, fModTime := info.Name(), info.Size(), info.ModTime()
	cache.mu.Lock()
	defer cache.mu.Unlock()
	cacheItem, exists := cache.data[fName]
	// A size or mtime change means the cached hashes no longer describe the
	// file on disk; fall through to the rehash path below.
	if exists && (cacheItem.size != fSize || cacheItem.modTime != fModTime.Format(time.RFC3339)) {
		exists = false
		delete(cache.data, fName)
	}
	if exists {
		// hist shares the policy maps with cacheItem.history.
		hist := cacheItem.history
		found := false
		for maxReads, maxReadHash := range hist.maxReads {
			// BUG FIX: expire once reads has reached maxReads. The original
			// comparison was inverted (maxReads >= hist.reads), which expired
			// items before they were ever served.
			if hist.reads >= maxReads {
				_ = os.Remove(dir + "/" + fName)
				delete(cache.data, fName)
				return nil, errExpired
			}
			if fmt.Sprintf("%x", maxReadHash) == idHash {
				found = true
			}
		}
		for expireTimeStr, expireTimeHash := range hist.expireTimes {
			expireBy, err := time.Parse(time.RFC3339, expireTimeStr)
			if err != nil {
				continue
			}
			if userAccessTime.After(expireBy) {
				_ = os.Remove(dir + "/" + fName)
				delete(cache.data, fName)
				return nil, errExpired
			}
			if fmt.Sprintf("%x", expireTimeHash) == idHash {
				found = true
			}
		}
		if found {
			return cacheItem, nil
		}
		// idHash did not match any recorded policy hash: recompute the
		// access-policy hash for the requested policy and check it directly.
		_, _, accessPolicyHash, err := doHash(dir+"/"+fName, pepper, accessPolicy)
		if err != nil {
			return nil, err
		}
		if fmt.Sprintf("%x", accessPolicyHash) != idHash {
			return nil, errNotFound
		}
		if expiresInt > 0 {
			cacheItem.history.maxReads[expiresInt] = accessPolicyHash
		}
		if expiresTime.After(timeZero) {
			cacheItem.history.expireTimes[expiresTime.Format(time.RFC3339)] = accessPolicyHash
		}
		cache.data[fName] = cacheItem
		return cacheItem, nil
	}
	// Not cached (or just invalidated): hash from scratch.
	contentHash, pepperHash, accessPolicyHash, err := doHash(dir+"/"+fName, pepper, accessPolicy)
	// BUG FIX: this error was previously ignored, which cached nil hashes
	// for an unreadable file and could then match an attacker-free idHash.
	if err != nil {
		return nil, err
	}
	accesses := accesses{
		reads:       0,
		maxReads:    map[int][]byte{},
		expireTimes: map[string][]byte{},
	}
	if expiresInt > 0 {
		accesses.maxReads[expiresInt] = accessPolicyHash
	}
	if expiresTime.After(timeZero) {
		accesses.expireTimes[expiresTime.Format(time.RFC3339)] = accessPolicyHash
	}
	result := fileCacheItem{
		name:        fName,
		size:        fSize,
		modTime:     fModTime.Format(time.RFC3339),
		contentHash: contentHash,
		pepperHash:  pepperHash,
		history:     accesses,
	}
	// Cache the hashes even on a non-matching id so the next request for
	// this file is served from cache.
	cache.data[fName] = &result
	if fmt.Sprintf("%x", accessPolicyHash) != idHash {
		return nil, errNotFound
	}
	return &result, nil
}
// accesses tracks how often a cached file has been served and under which
// expiry policies it has been requested.
type accesses struct {
	// how many times the file has been served in the given session
	reads int
	// different expire-by timestamps the file has been served with;
	// keyed by the RFC3339 timestamp, value is the access-policy hash
	expireTimes map[string][]byte
	// different max-read counts the file has been served with;
	// keyed by the count, value is the access-policy hash
	maxReads map[int][]byte
}
// fileCacheItem caches the digests computed for one on-disk file so repeat
// requests do not have to re-read and re-hash it.
type fileCacheItem struct {
	// name is the file's base name within the served directory.
	name string
	// size and modTime (RFC3339) are compared against the file on disk to
	// detect changes that invalidate the cached hashes.
	size    int64
	modTime string
	// contentHash is sha256(contents); pepperHash is sha256(contents||pepper)
	// (see doHash for the chaining).
	contentHash []byte
	pepperHash  []byte
	// history records serve counts and the expiry policies seen for this file.
	history accesses
}
// fileEntryCache is the filename -> cached-hash map shared by the HTTP
// handler and the background sweeper; mu guards data.
type fileEntryCache struct {
	mu   sync.RWMutex
	data map[string]*fileCacheItem
}

// timeZero is the zero time.Time, used to test whether an expire-by
// timestamp was actually provided.
var timeZero time.Time

// generic500 is the catch-all body for internal server errors, kept vague
// to avoid leaking server-side details to clients.
var generic500 = "Failed to fetch file"
// mkHandler builds the request handler that serves a file identified by its
// peppered access-policy hash.
//
// Routes (registered in main):
//
//	GET /{id}                  id = hex sha256 of contents||pepper||policy
//	GET /{id}/{expirepolicy}   expirepolicy = a max-read count (integer)
//	                           or an RFC3339 expire-by timestamp
//
// The handler scans config.dir and asks hashFile, for each regular file,
// whether its access-policy hash matches id; the first match is streamed
// back as an attachment and its read counter is incremented.
func mkHandler(cache *fileEntryCache, config *config) func(w http.ResponseWriter, r *http.Request) {
	return func(w http.ResponseWriter, r *http.Request) {
		// Captured up front so all expiry checks use the same instant.
		accessedAt := time.Now()
		id := r.PathValue("id")
		if !reSha256.Match([]byte(id)) {
			http.Error(w, "not a sha256sum", http.StatusBadRequest)
			return
		}
		expirePolicyStr := r.PathValue("expirepolicy")
		var expirePolicyHash []byte
		var err error
		expiresInt := 0
		var expireTime time.Time
		if expirePolicyStr != "" {
			// The raw policy bytes are folded into the id hash; timestamps
			// are re-serialized to RFC3339 below so equivalent spellings of
			// the same instant hash identically.
			expirePolicyHash = []byte(expirePolicyStr)
			expiresInt, err = strconv.Atoi(expirePolicyStr)
			if err != nil {
				// Not an integer, so it must be an expire-by timestamp.
				expireTime, err = time.Parse(time.RFC3339, expirePolicyStr)
				if err != nil {
					http.Error(w, "bad timestamp or max access count parameter passed", http.StatusBadRequest)
					return
				}
				expirePolicyHash = []byte(expireTime.Format(time.RFC3339))
			}
			// expiresInt is only nonzero when Atoi succeeded above.
			if expiresInt > config.maxReads {
				http.Error(w, fmt.Sprintf("Reading %d times not allowed", expiresInt), http.StatusBadRequest)
				return
			}
		}
		entries, err := os.ReadDir(config.dir)
		if err != nil {
			slog.Error("could not read files dir", "err", err.Error())
			http.Error(w, generic500, http.StatusInternalServerError)
			return
		}
		var result *fileCacheItem
		// entry is declared outside the loop so the matched entry remains
		// in scope after break.
		var entry os.DirEntry
		for _, entry = range entries {
			if entry.IsDir() {
				continue
			}
			result, err = hashFile(cache, config.dir, entry, config.pepper, expirePolicyHash, expiresInt, expireTime, id, accessedAt)
			if err != nil {
				if err == errNotFound {
					// This file does not match the requested id; keep looking.
					continue
				}
				if err == errExpired {
					// Report expired content as a plain 404 so clients cannot
					// distinguish "expired" from "never existed".
					http.Error(w, "not found", http.StatusNotFound)
					slog.Info("content expired")
					return
				}
				slog.Error("err checking file hash", "err", err.Error())
				http.Error(w, generic500, http.StatusInternalServerError)
				return
			}
			if result == nil {
				continue
			}
			break
		}
		if result == nil {
			http.Error(w, "not found", http.StatusNotFound)
			return
		}
		entryInfo, err := entry.Info()
		if err != nil {
			slog.Error("err getting file info for file that was successfully hashed", "err", err.Error())
			http.Error(w, generic500, http.StatusInternalServerError)
			return
		}
		filePath := config.dir + "/" + entry.Name()
		file, err := os.Open(filePath)
		if errors.Is(err, os.ErrNotExist) {
			// The file was hashed moments ago, so it vanishing here indicates
			// a race with deletion (e.g. the sweeper).
			slog.Error("file that was successfully hashed does not exist. should not happen.", "err", err.Error())
			http.Error(w, generic500, http.StatusInternalServerError)
			return
		} else if err != nil {
			slog.Error("err opening file for file that was successfully hashed", "err", err.Error())
			http.Error(w, generic500, http.StatusInternalServerError)
			return
		}
		defer file.Close()
		w.Header().Set("Content-Disposition", "attachment; filename="+strconv.Quote(entry.Name()))
		w.Header().Set("Content-Type", "application/octet-stream")
		w.Header().Set("Last-Modified", entryInfo.ModTime().Format(http.TimeFormat))
		var copyWriter io.Writer = w
		// TODO: Figure out why gzipping was causing the receiving end to not automatically decompress
		// var gzWriter *gzip.Writer
		// fileSize := entryInfo.Size()
		// if strings.Contains(r.Header.Get("Accept-Encoding"), "gzip") && fileSize >= 1<<20 {
		// w.Header().Set("Content-Encoding", "gzip")
		// copyWriter = gzip.NewWriter(w)
		// }
		_, err = io.Copy(copyWriter, file)
		if err != nil {
			// Headers and part of the body may already be on the wire; there
			// is nothing useful left to send the client.
			slog.Error("err writing content", "err", err.Error())
			return
		}
		// Count this successful serve toward the file's max-read policies.
		cache.mu.Lock()
		defer cache.mu.Unlock()
		item, ok := cache.data[entry.Name()]
		if ok {
			item.history.reads += 1
		}
	}
}
// main wires together configuration, the shared file cache, the HTTP routes,
// and the optional background sweeper, then serves until the listener fails.
func main() {
	config, err := NewConfig()
	if err != nil {
		slog.Error("could not configure app", "err", err.Error())
		os.Exit(1)
	}
	cache := fileEntryCache{data: map[string]*fileCacheItem{}}
	handler := mkHandler(&cache, config)
	http.HandleFunc("GET /{id}", handler)
	http.HandleFunc("GET /{id}/{expirepolicy}", handler)
	// Only start the sweeper when a positive interval was configured.
	if config.sweepInterval > 0 {
		go sweep(&cache, config)
	}
	listenAddr := fmt.Sprintf("0.0.0.0:%s", config.port)
	slog.Info("listening on " + listenAddr)
	// ListenAndServe always returns a non-nil error.
	err = http.ListenAndServe(listenAddr, nil)
	// BUG FIX: log the listen failure at Error level and exit non-zero;
	// the original used slog.Info and fell off main with exit status 0,
	// hiding the failure from supervisors.
	slog.Error("err listening", "err", err.Error())
	os.Exit(1)
}