Start on StepMania cache validator

- Fix main.py import error
This commit is contained in:
Ash 2025-10-07 19:34:34 -07:00
parent e9a55fcb87
commit 1994cecde8
4 changed files with 92 additions and 11 deletions

View file

@ -1,6 +1,6 @@
from smoketest.args import SmoketestArgs from simfile_smoketest.args import SmoketestArgs
from smoketest.runner import SmoketestRun from simfile_smoketest.runner import SmoketestRun
from smoketest.storage import DB, MODELS from simfile_smoketest.storage import DB, MODELS
def main(): def main():
@ -11,8 +11,6 @@ def main():
DB.create_tables(MODELS) DB.create_tables(MODELS)
if args.songs_dir: if args.songs_dir:
SmoketestRun(args).process_songs_dir(args.songs_dir) SmoketestRun(args).process_songs_dir(args.songs_dir)
elif args.pack_dir:
SmoketestRun(args).process_pack(args.pack_dir)
if __name__ == "__main__": if __name__ == "__main__":

View file

@ -8,10 +8,13 @@ __all__ = ["SmoketestArgs"]
class SmoketestArgs(Tap): class SmoketestArgs(Tap):
songs_dir: Optional[str] = None songs_dir: Optional[str] = None
"""directory of packs to scan""" """StepMania Songs folder to scan"""
pack_dir: Optional[str] = None cache_dir: Optional[str] = None
"""single pack directory to scan""" """StepMania Cache/Songs directory to compare songs against"""
single_pack: Optional[str] = None
"""Optionally limit the scan to a single, named pack"""
new_only: bool = False new_only: bool = False
"""only scan newly discovered simfiles""" """only scan newly discovered simfiles"""

View file

@ -0,0 +1,43 @@
from simfile.notes import NoteData, NoteType
from simfile.timing import TimingData
from simfile.timing.engine import TimingEngine
from simfile.types import AttachedChart
from typing import Tuple, Sequence
__all__ = ["get_nps"]
class NPSState:
def __init__(self, timing_engine: TimingEngine):
self.timing_engine = timing_engine
self.chart_nps = []
self.chart_nc = []
self.current_measure = 0
self.current_measure_nc = 0
def increment_measure(self):
measure_duration = (
self.timing_engine.time_at((self.current_measure + 1) * 4)
- self.timing_engine.time_at(self.current_measure * 4)
)
self.chart_nps.append(self.current_measure_nc / measure_duration)
self.chart_nc.append(self.current_measure_nc)
self.current_measure_nc = 0
self.current_measure += 1
COUNTED_NOTE_TYPES = (NoteType.TAP, NoteType.HOLD_HEAD, NoteType.ROLL_HEAD, NoteType.LIFT)
def get_nps(chart: AttachedChart) -> Tuple[Sequence[float], Sequence[int]]:
state = NPSState(TimingEngine(TimingData(chart)))
notedata = NoteData(chart)
for note in notedata:
note_measure = int(note.beat) // 4
while note_measure > state.current_measure:
state.increment_measure()
if note.note_type in COUNTED_NOTE_TYPES:
state.current_measure_nc += 1
state.increment_measure()
return (state.chart_nps, state.chart_nc)

View file

@ -2,8 +2,9 @@ from contextlib import contextmanager
import dataclasses import dataclasses
from datetime import datetime from datetime import datetime
import os import os
from pathlib import Path
import traceback import traceback
from typing import Iterator, Optional from typing import Iterator, Optional, Sequence
from . import __version__ from . import __version__
import msdparser import msdparser
@ -13,12 +14,14 @@ from simfile.notes import NoteData
from simfile.notes.group import group_notes, SameBeatNotes from simfile.notes.group import group_notes, SameBeatNotes
from simfile.notes.timed import time_notes from simfile.notes.timed import time_notes
from simfile.ssc import SSCChart from simfile.ssc import SSCChart
from simfile.timing.engine import TimingData from simfile.timing import TimingData
from simfile.timing.engine import TimingEngine
from simfile.timing.displaybpm import displaybpm from simfile.timing.displaybpm import displaybpm
from simfile.types import Simfile, Chart from simfile.types import Simfile, Chart
from .args import SmoketestArgs from .args import SmoketestArgs
from .storage import Run, SimfileObject, SimfileError from .storage import Run, SimfileObject, SimfileError
from .nps import get_nps
@dataclasses.dataclass(frozen=True) @dataclasses.dataclass(frozen=True)
@ -90,7 +93,8 @@ class SmoketestRun:
def process_songs_dir(self, songs_dir: str): def process_songs_dir(self, songs_dir: str):
for entry in os.scandir(songs_dir): for entry in os.scandir(songs_dir):
if entry.is_dir(): if entry.is_dir():
self.process_pack(entry.path) if not self.args.single_pack or entry.name == self.args.single_pack:
self.process_pack(entry.path)
def process_pack(self, pack_dir: str): def process_pack(self, pack_dir: str):
root_context = SimfileContext(run=self.run, path=pack_dir) root_context = SimfileContext(run=self.run, path=pack_dir)
@ -105,6 +109,13 @@ class SmoketestRun:
if simfile_dir.ssc_path: if simfile_dir.ssc_path:
self._open_simfile(context, simfile_dir.ssc_path) self._open_simfile(context, simfile_dir.ssc_path)
def _open_cached_simfile(self, context: SimfileContext, simfile_path: str) -> Simfile:
cache_filename = str(Path(simfile_path).parent.relative_to(Path(self.args.songs_dir).parent))
cache_filename = cache_filename.replace("/", "_")
cache_path = os.path.join(self.args.cache_dir, cache_filename)
cache_simfile = simfile.open(cache_path)
return cache_simfile
def _open_simfile(self, context: SimfileContext, simfile_path: str): def _open_simfile(self, context: SimfileContext, simfile_path: str):
if self.args.new_only: if self.args.new_only:
if SimfileObject.get_or_none(SimfileObject.path == simfile_path): if SimfileObject.get_or_none(SimfileObject.path == simfile_path):
@ -118,6 +129,10 @@ class SmoketestRun:
): ):
return return
cache_simfile: Optional[Simfile] = None
with context.perform("cache.open", path=simfile_path) as context:
cache_simfile = self._open_cached_simfile(context, simfile_path=simfile_path)
with context.perform("simfile.open", path=simfile_path) as context: with context.perform("simfile.open", path=simfile_path) as context:
sim = simfile.open(simfile_path) sim = simfile.open(simfile_path)
@ -127,6 +142,8 @@ class SmoketestRun:
) as context: ) as context:
self._test_simfile(context, sim) self._test_simfile(context, sim)
self._compare_to_cache(context, sim=sim, cache=cache_simfile)
def _test_simfile(self, context: SimfileContext, sim: Simfile): def _test_simfile(self, context: SimfileContext, sim: Simfile):
with context.perform("TimingData") as context: with context.perform("TimingData") as context:
TimingData(sim) TimingData(sim)
@ -174,3 +191,23 @@ class SmoketestRun:
with context.perform("time_notes") as context: with context.perform("time_notes") as context:
for _ in time_notes(nd, td): for _ in time_notes(nd, td):
pass pass
def _compare_to_cache(self, context: SimfileContext, *, sim: Simfile, cache: Simfile):
with context.perform("NPSPERMEASURE") as context:
for real_chart, cache_chart in zip(sim.charts, cache.charts):
cache_chart_nps = [float(nps) for nps in cache_chart["NPSPERMEASURE"].split(",")]
cache_chart_nc = [int(nc) for nc in cache_chart["NOTESPERMEASURE"].split(",")]
chart_nps, chart_nc = get_nps(real_chart)
for n, (real_measure_nps, cache_measure_nps) in enumerate(zip(chart_nps, cache_chart_nps)):
assert abs(real_measure_nps - cache_measure_nps) < 0.001, \
f"NPSPERMEASURE mismatch on measure {n} ({cache_measure_nps} != {real_measure_nps})\n{cache_chart_nps}\n{chart_nps}"
for real_measure_nc, cache_measure_nc in zip(chart_nc, cache_chart_nc):
assert abs(real_measure_nc - cache_measure_nc) < 0.001, \
f"NOTESPERMEASURE mismatch on measure {n} ({cache_measure_nc} != {real_measure_nc})\n{cache_chart_nc}\n{chart_nc}"
# NOTESPERMEASURE (int list)
# PEAKNPS (float)
# GROOVESTATSHASH (16 hex chars)
# GROOVESTATSHASHVERSION (3)
# STEPFILENAME (/Songs/Pack/Simfile/Simfile.ssc)