// Copyright (C) 2025 Opsmate, Inc. // // This Source Code Form is subject to the terms of the Mozilla // Public License, v. 2.0. If a copy of the MPL was not distributed // with this file, You can obtain one at http://mozilla.org/MPL/2.0/. // // This software is distributed WITHOUT A WARRANTY OF ANY KIND. // See the Mozilla Public License for details. package monitor import ( "context" "errors" "fmt" "golang.org/x/sync/errgroup" "log" mathrand "math/rand/v2" "net/url" "slices" "time" "software.sslmate.com/src/certspotter/ctclient" "software.sslmate.com/src/certspotter/ctcrypto" "software.sslmate.com/src/certspotter/cttypes" "software.sslmate.com/src/certspotter/loglist" "software.sslmate.com/src/certspotter/merkletree" "software.sslmate.com/src/certspotter/sequencer" ) const ( getSTHInterval = 5 * time.Minute ) func downloadJobSize(ctlog *loglist.Log) uint64 { if ctlog.IsStaticCTAPI() { return ctclient.StaticTileWidth } else if ctlog.CertspotterDownloadSize != 0 { return uint64(ctlog.CertspotterDownloadSize) } else { return 1000 } } func downloadWorkers(ctlog *loglist.Log) int { if ctlog.CertspotterDownloadJobs != 0 { return ctlog.CertspotterDownloadJobs } else { return 1 } } type verifyEntriesError struct { sth *cttypes.SignedTreeHead entriesRootHash merkletree.Hash } func (e *verifyEntriesError) Error() string { return fmt.Sprintf("error verifying at tree size %d: the STH root hash (%x) does not match the entries returned by the log (%x)", e.sth.TreeSize, e.sth.RootHash, e.entriesRootHash) } func withRetry(ctx context.Context, maxRetries int, f func() error) error { const minSleep = 1 * time.Second const maxSleep = 10 * time.Minute numRetries := 0 for ctx.Err() == nil { err := f() if err == nil || errors.Is(err, context.Canceled) { return err } if maxRetries != -1 && numRetries >= maxRetries { return fmt.Errorf("%w (retried %d times)", err, numRetries) } upperBound := min(minSleep*(1< downloadWorker ==> processWorker ==> saveStateWorker batches := make(chan *batch, downloadWorkers(ctlog)) processedBatches := sequencer.New[batch](0, uint64(downloadWorkers(ctlog))*10) group, gctx := errgroup.WithContext(ctx) group.Go(func() error { return getSTHWorker(gctx, config, ctlog, client) }) group.Go(func() error { return generateBatchesWorker(gctx, config, ctlog, position, batches) }) for range downloadWorkers(ctlog) { downloadedBatches := make(chan *batch, 1) group.Go(func() error { return downloadWorker(gctx, config, ctlog, client, batches, downloadedBatches) }) group.Go(func() error { return processWorker(gctx, config, ctlog, issuerGetter, downloadedBatches, processedBatches) }) } group.Go(func() error { return saveStateWorker(gctx, config, ctlog, state, processedBatches) }) err = group.Wait() if verifyErr := (*verifyEntriesError)(nil); errors.As(err, &verifyErr) { recordError(ctx, config, ctlog, verifyErr) state.rewindDownloadPosition() if err := config.State.StoreLogState(ctx, ctlog.LogID, state); err != nil { return fmt.Errorf("error storing log state: %w", err) } if err := sleep(ctx, 5*time.Minute); err != nil { return err } goto retry } return err } func getSTHWorker(ctx context.Context, config *Config, ctlog *loglist.Log, client ctclient.Log) error { for ctx.Err() == nil { sth, _, err := client.GetSTH(ctx) if err != nil { return err } if err := config.State.StoreSTH(ctx, ctlog.LogID, sth); err != nil { return fmt.Errorf("error storing STH: %w", err) } if err := sleep(ctx, getSTHInterval); err != nil { return err } } return ctx.Err() } type batch struct { number uint64 begin, end uint64 sths []*cttypes.SignedTreeHead // STHs with sizes in range [begin,end], sorted by TreeSize entries []ctclient.Entry // in range [begin,end) } func generateBatchesWorker(ctx context.Context, config *Config, ctlog *loglist.Log, position uint64, batches chan<- *batch) error { ticker := time.NewTicker(15 * time.Second) var number uint64 for ctx.Err() == nil { sths, err := config.State.LoadSTHs(ctx, ctlog.LogID) if err != nil { return fmt.Errorf("error loading STHs: %w", err) } for len(sths) > 0 && sths[0].TreeSize < position { // TODO-4: audit sths[0] against log's verified STH if err := config.State.RemoveSTH(ctx, ctlog.LogID, sths[0]); err != nil { return fmt.Errorf("error removing STH: %w", err) } sths = sths[1:] } position, number, err = generateBatches(ctx, ctlog, position, number, sths, batches) if err != nil { return err } select { case <-ctx.Done(): return ctx.Err() case <-ticker.C: } } return ctx.Err() } // return the earliest STH timestamp within the right-most tile func tileEarliestTimestamp(sths []*cttypes.SignedTreeHead) time.Time { largestSTH, sths := sths[len(sths)-1], sths[:len(sths)-1] tileNumber := largestSTH.TreeSize / ctclient.StaticTileWidth earliest := largestSTH.TimestampTime() for _, sth := range slices.Backward(sths) { if sth.TreeSize/ctclient.StaticTileWidth != tileNumber { break } if timestamp := sth.TimestampTime(); timestamp.Before(earliest) { earliest = timestamp } } return earliest } func generateBatches(ctx context.Context, ctlog *loglist.Log, position uint64, number uint64, sths []*cttypes.SignedTreeHead, batches chan<- *batch) (uint64, uint64, error) { downloadJobSize := downloadJobSize(ctlog) if len(sths) == 0 { return position, number, nil } largestSTH := sths[len(sths)-1] treeSize := largestSTH.TreeSize if ctlog.IsStaticCTAPI() && time.Since(tileEarliestTimestamp(sths)) < 5*time.Minute { // Round down to the tile boundary to avoid downloading a partial tile that was recently discovered // In a future invocation of this function, either enough time will have passed that this code path will be skipped, or the log will have grown and treeSize will be rounded to a larger tile boundary treeSize -= treeSize % ctclient.StaticTileWidth } for { batch := &batch{ number: number, begin: position, end: min(treeSize, (position/downloadJobSize+1)*downloadJobSize), } for len(sths) > 0 && sths[0].TreeSize <= batch.end { batch.sths = append(batch.sths, sths[0]) sths = sths[1:] } select { case <-ctx.Done(): return position, number, ctx.Err() default: } select { case <-ctx.Done(): return position, number, ctx.Err() case batches <- batch: } number++ if position == batch.end { break } position = batch.end } return position, number, nil } func downloadWorker(ctx context.Context, config *Config, ctlog *loglist.Log, client ctclient.Log, batchesIn <-chan *batch, batchesOut chan<- *batch) error { for { select { case <-ctx.Done(): return ctx.Err() default: } var batch *batch select { case <-ctx.Done(): return ctx.Err() case batch = <-batchesIn: } entries, err := getEntriesFull(ctx, client, batch.begin, batch.end-1) if err != nil { return err } batch.entries = entries select { case <-ctx.Done(): return ctx.Err() default: } select { case <-ctx.Done(): return ctx.Err() case batchesOut <- batch: } } return nil } func processWorker(ctx context.Context, config *Config, ctlog *loglist.Log, issuerGetter ctclient.IssuerGetter, batchesIn <-chan *batch, batchesOut *sequencer.Channel[batch]) error { for { select { case <-ctx.Done(): return ctx.Err() default: } var batch *batch select { case <-ctx.Done(): return ctx.Err() case batch = <-batchesIn: } for offset, entry := range batch.entries { index := batch.begin + uint64(offset) if err := processLogEntry(ctx, config, issuerGetter, &LogEntry{ Entry: entry, Index: index, Log: ctlog, }); err != nil { return fmt.Errorf("error processing entry %d: %w", index, err) } } if err := batchesOut.Add(ctx, batch.number, batch); err != nil { return err } } } func saveStateWorker(ctx context.Context, config *Config, ctlog *loglist.Log, state *LogState, batchesIn *sequencer.Channel[batch]) error { for { batch, err := batchesIn.Next(ctx) if err != nil { return err } if batch.begin != state.DownloadPosition.Size() { panic(fmt.Errorf("saveStateWorker: expected batch to start at %d but got %d instead", state.DownloadPosition.Size(), batch.begin)) } rootHash := state.DownloadPosition.CalculateRoot() for { for len(batch.sths) > 0 && batch.sths[0].TreeSize == state.DownloadPosition.Size() { sth := batch.sths[0] batch.sths = batch.sths[1:] if sth.RootHash != rootHash { return &verifyEntriesError{ sth: sth, entriesRootHash: rootHash, } } state.advanceVerifiedPosition() state.VerifiedSTH = sth if err := config.State.StoreLogState(ctx, ctlog.LogID, state); err != nil { return fmt.Errorf("error storing log state: %w", err) } // don't remove the STH until state has been durably stored if err := config.State.RemoveSTH(ctx, ctlog.LogID, sth); err != nil { return fmt.Errorf("error removing verified STH: %w", err) } } if len(batch.entries) == 0 { break } entry := batch.entries[0] batch.entries = batch.entries[1:] leafHash := merkletree.HashLeaf(entry.LeafInput()) state.DownloadPosition.Add(leafHash) rootHash = state.DownloadPosition.CalculateRoot() } if err := config.State.StoreLogState(ctx, ctlog.LogID, state); err != nil { return fmt.Errorf("error storing log state: %w", err) } } } func sleep(ctx context.Context, duration time.Duration) error { timer := time.NewTimer(duration) defer timer.Stop() select { case <-ctx.Done(): return ctx.Err() case <-timer.C: return nil } }