package main import ( "crypto/sha256" "errors" "fmt" "io" "log/slog" "net/http" "os" "regexp" "strconv" "sync" "time" ) type config struct { dir string pepper []byte sweepInterval int64 maxReads int port string } const defaultMaxReads = 2000 func sweep(cache *fileEntryCache, conf *config) { for { time.Sleep(time.Duration(conf.sweepInterval) * time.Second) slog.Info("sweeping expired files") now := time.Now() cache.mu.Lock() for filename, cacheState := range cache.data { reads := cacheState.history.reads for maxReads := range cacheState.history.maxReads { if maxReads >= reads { _ = os.Remove(conf.dir + "/" + filename) delete(cache.data, filename) } } for expireTimeStr := range cacheState.history.expireTimes { expireBy, err := time.Parse(time.RFC3339, expireTimeStr) if err != nil { // NOTE: should never be possible, but just in case, // delete the item to safeguard against always keeping an item in cache _ = os.Remove(conf.dir + "/" + filename) delete(cache.data, filename) continue } if now.After(expireBy) { _ = os.Remove(conf.dir + "/" + filename) delete(cache.data, filename) } } } cache.mu.Unlock() } } func NewConfig() (*config, error) { dir := os.Getenv("AWLFILE_DIR") if dir == "" { dir = "." } pepper := os.Getenv("AWLFILE_PEPPER") if pepper == "" { return nil, fmt.Errorf("Must provide a pepper") } awlfileMaxReadsStr := os.Getenv("AWLFILE_MAX_READS") maxReads := defaultMaxReads if maxReadsFromEnv, err := strconv.Atoi(awlfileMaxReadsStr); err == nil { maxReads = maxReadsFromEnv } var sweepInterval int64 = 0 awlfileSleepIntervalStr := os.Getenv("AWLFILE_SWEEP_INTERVAL") if awlfileSleepIntervalStr != "" { sweepIntervalNum, err := strconv.ParseInt(awlfileSleepIntervalStr, 10, 64) if err != nil { return nil, fmt.Errorf("expected AWLFILE_SWEEP_INTERVAL to be a number: %s", err.Error()) } if sweepIntervalNum < 0 { return nil, fmt.Errorf("expected AWLFILE_SWEEP_INTERVAL to be nonnegative") } sweepInterval = sweepIntervalNum } port := "24242" awlfilePortStr := os.Getenv("AWLFILE_PORT") if awlfilePortStr != "" { portInt, err := strconv.Atoi(awlfilePortStr) if err != nil { return nil, fmt.Errorf("expected AWLFILE_PORT to be a number: %s", err.Error()) } if portInt < 1 || portInt > 65535 { return nil, fmt.Errorf("invalid AWLFILE_PORT number, must be between 1 and 65535: %d", portInt) } port = awlfilePortStr } return &config{ dir: dir, pepper: []byte(pepper), maxReads: maxReads, sweepInterval: sweepInterval, port: port, }, nil } var errExpired = fmt.Errorf("Item expired") var errNotFound = fmt.Errorf("Not found") var reSha256 = regexp.MustCompile("[A-Fa-f0-9]{64}") func doHash(filePath string, pepper []byte, accessPolicy []byte) ([]byte, []byte, []byte, error) { file, err := os.Open(filePath) if errors.Is(err, os.ErrNotExist) { return nil, nil, nil, errNotFound } else if err != nil { return nil, nil, nil, err } defer file.Close() h := sha256.New() if _, err := io.Copy(h, file); err != nil { return nil, nil, nil, err } contentHash := h.Sum(nil) if _, err := h.Write(pepper); err != nil { return nil, nil, nil, err } pepperHash := h.Sum(nil) if _, err := h.Write(accessPolicy); err != nil { return nil, nil, nil, err } accessPolicyHash := h.Sum(nil) return contentHash, pepperHash, accessPolicyHash, nil } func hashFile(cache *fileEntryCache, dir string, dirEntry os.DirEntry, pepper []byte, accessPolicy []byte, expiresInt int, expiresTime time.Time, idHash string, userAccessTime time.Time) (*fileCacheItem, error) { info, err := dirEntry.Info() if err != nil { return nil, err } fName, fSize, fModTime := info.Name(), info.Size(), info.ModTime() cache.mu.Lock() defer cache.mu.Unlock() cacheItem, exists := cache.data[fName] if exists && (cacheItem.size != fSize || cacheItem.modTime != fModTime.Format(time.RFC3339)) { exists = false cacheItem = &fileCacheItem{} delete(cache.data, fName) } if exists { hist := cacheItem.history found := false for maxReads, maxReadHash := range hist.maxReads { if maxReads >= hist.reads { _ = os.Remove(dir + "/" + fName) delete(cache.data, fName) return nil, errExpired } if fmt.Sprintf("%x", maxReadHash) == idHash { found = true } } for expireTimeStr, expireTimeHash := range hist.expireTimes { expireBy, err := time.Parse(time.RFC3339, expireTimeStr) if err != nil { continue } if userAccessTime.After(expireBy) { _ = os.Remove(dir + "/" + fName) delete(cache.data, fName) return nil, errExpired } if fmt.Sprintf("%x", expireTimeHash) == idHash { found = true } } if found == true { return cacheItem, nil } _, _, accessPolicyHash, err := doHash(dir+"/"+fName, pepper, accessPolicy) if err != nil { return nil, err } if fmt.Sprintf("%x", accessPolicyHash) != idHash { return nil, errNotFound } if expiresInt > 0 { cacheItem.history.maxReads[expiresInt] = accessPolicyHash } if expiresTime.After(timeZero) { cacheItem.history.expireTimes[expiresTime.Format(time.RFC3339)] = accessPolicyHash } cache.data[fName] = cacheItem return cacheItem, nil } contentHash, pepperHash, accessPolicyHash, err := doHash(dir+"/"+fName, pepper, accessPolicy) accesses := accesses{ reads: 0, maxReads: map[int][]byte{}, expireTimes: map[string][]byte{}, } if expiresInt > 0 { accesses.maxReads[expiresInt] = accessPolicyHash } if expiresTime.After(timeZero) { accesses.expireTimes[expiresTime.Format(time.RFC3339)] = accessPolicyHash } result := fileCacheItem{ name: fName, size: fSize, modTime: fModTime.Format(time.RFC3339), contentHash: contentHash, pepperHash: pepperHash, history: accesses, } cache.data[fName] = &result if fmt.Sprintf("%x", accessPolicyHash) != idHash { return nil, errNotFound } return &result, nil } type accesses struct { // how many times the file has been served in the given session reads int // different expire by timestamps the file has been served with expireTimes map[string][]byte // diferent max reads teh filel has been served with maxReads map[int][]byte } type fileCacheItem struct { name string size int64 modTime string contentHash []byte pepperHash []byte history accesses } type fileEntryCache struct { mu sync.RWMutex data map[string]*fileCacheItem } var timeZero time.Time var generic500 = "Failed to fetch file" func mkHandler(cache *fileEntryCache, config *config) func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) { accessedAt := time.Now() id := r.PathValue("id") if !reSha256.Match([]byte(id)) { http.Error(w, "not a sha256sum", http.StatusBadRequest) return } expirePolicyStr := r.PathValue("expirepolicy") var expirePolicyHash []byte var err error expiresInt := 0 var expireTime time.Time if expirePolicyStr != "" { expirePolicyHash = []byte(expirePolicyStr) expiresInt, err = strconv.Atoi(expirePolicyStr) if err != nil { expireTime, err = time.Parse(time.RFC3339, expirePolicyStr) if err != nil { http.Error(w, "bad timestamp or max access count parameter passed", http.StatusBadRequest) return } expirePolicyHash = []byte(expireTime.Format(time.RFC3339)) } if expiresInt > config.maxReads { http.Error(w, fmt.Sprintf("Reading %d times not allowed", expiresInt), http.StatusBadRequest) return } } entries, err := os.ReadDir(config.dir) if err != nil { slog.Error("could not read files dir", "err", err.Error()) http.Error(w, generic500, http.StatusInternalServerError) return } var result *fileCacheItem var entry os.DirEntry for _, entry = range entries { if entry.IsDir() { continue } result, err = hashFile(cache, config.dir, entry, config.pepper, expirePolicyHash, expiresInt, expireTime, id, accessedAt) if err != nil { if err == errNotFound { continue } if err == errExpired { http.Error(w, "not found", http.StatusNotFound) slog.Info("content expired") return } slog.Error("err checking file hash", "err", err.Error()) http.Error(w, generic500, http.StatusInternalServerError) return } if result == nil { continue } break } if result == nil { http.Error(w, "not found", http.StatusNotFound) return } entryInfo, err := entry.Info() if err != nil { slog.Error("err getting file info for file that was successfully hashed", "err", err.Error()) http.Error(w, generic500, http.StatusInternalServerError) return } filePath := config.dir + "/" + entry.Name() file, err := os.Open(filePath) if errors.Is(err, os.ErrNotExist) { slog.Error("file that was successfully hashed does not exist. should not happen.", "err", err.Error()) http.Error(w, generic500, http.StatusInternalServerError) return } else if err != nil { slog.Error("err opening file for file that was successfully hashed", "err", err.Error()) http.Error(w, generic500, http.StatusInternalServerError) return } defer file.Close() w.Header().Set("Content-Disposition", "attachment; filename="+strconv.Quote(entry.Name())) w.Header().Set("Content-Type", "application/octet-stream") w.Header().Set("Last-Modified", entryInfo.ModTime().Format(http.TimeFormat)) var copyWriter io.Writer = w // TODO: Figure out why gzipping was causing the receiving end to not automatically decompress // var gzWriter *gzip.Writer // fileSize := entryInfo.Size() // if strings.Contains(r.Header.Get("Accept-Encoding"), "gzip") && fileSize >= 1<<20 { // w.Header().Set("Content-Encoding", "gzip") // copyWriter = gzip.NewWriter(w) // } _, err = io.Copy(copyWriter, file) if err != nil { slog.Error("err writing content", "err", err.Error()) return } cache.mu.Lock() defer cache.mu.Unlock() item, ok := cache.data[entry.Name()] if ok { item.history.reads += 1 } } } func main() { config, err := NewConfig() if err != nil { slog.Error("could not configure app", "err", err.Error()) os.Exit(1) } cache := fileEntryCache{data: map[string]*fileCacheItem{}} handler := mkHandler(&cache, config) http.HandleFunc("GET /{id}", handler) http.HandleFunc("GET /{id}/{expirepolicy}", handler) if config.sweepInterval > 0 { go sweep(&cache, config) } listenAddr := fmt.Sprintf("0.0.0.0:%s", config.port) slog.Info("listening on " + listenAddr) err = http.ListenAndServe(listenAddr, nil) slog.Info("err listening", "err", err.Error()) }