Start on StepMania cache validator
- Fix main.py import error
This commit is contained in:
parent
e9a55fcb87
commit
1994cecde8
4 changed files with 92 additions and 11 deletions
8
main.py
8
main.py
|
@ -1,6 +1,6 @@
|
|||
from smoketest.args import SmoketestArgs
|
||||
from smoketest.runner import SmoketestRun
|
||||
from smoketest.storage import DB, MODELS
|
||||
from simfile_smoketest.args import SmoketestArgs
|
||||
from simfile_smoketest.runner import SmoketestRun
|
||||
from simfile_smoketest.storage import DB, MODELS
|
||||
|
||||
|
||||
def main():
|
||||
|
@ -11,8 +11,6 @@ def main():
|
|||
DB.create_tables(MODELS)
|
||||
if args.songs_dir:
|
||||
SmoketestRun(args).process_songs_dir(args.songs_dir)
|
||||
elif args.pack_dir:
|
||||
SmoketestRun(args).process_pack(args.pack_dir)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -8,10 +8,13 @@ __all__ = ["SmoketestArgs"]
|
|||
|
||||
class SmoketestArgs(Tap):
|
||||
songs_dir: Optional[str] = None
|
||||
"""directory of packs to scan"""
|
||||
"""StepMania Songs folder to scan"""
|
||||
|
||||
pack_dir: Optional[str] = None
|
||||
"""single pack directory to scan"""
|
||||
cache_dir: Optional[str] = None
|
||||
"""StepMania Cache/Songs directory to compare songs against"""
|
||||
|
||||
single_pack: Optional[str] = None
|
||||
"""Optionally limit the scan to a single, named pack"""
|
||||
|
||||
new_only: bool = False
|
||||
"""only scan newly discovered simfiles"""
|
||||
|
|
43
src/simfile_smoketest/nps.py
Normal file
43
src/simfile_smoketest/nps.py
Normal file
|
@ -0,0 +1,43 @@
|
|||
from simfile.notes import NoteData, NoteType
|
||||
from simfile.timing import TimingData
|
||||
from simfile.timing.engine import TimingEngine
|
||||
from simfile.types import AttachedChart
|
||||
from typing import Tuple, Sequence
|
||||
|
||||
|
||||
__all__ = ["get_nps"]
|
||||
|
||||
|
||||
class NPSState:
|
||||
def __init__(self, timing_engine: TimingEngine):
|
||||
self.timing_engine = timing_engine
|
||||
self.chart_nps = []
|
||||
self.chart_nc = []
|
||||
self.current_measure = 0
|
||||
self.current_measure_nc = 0
|
||||
|
||||
def increment_measure(self):
|
||||
measure_duration = (
|
||||
self.timing_engine.time_at((self.current_measure + 1) * 4)
|
||||
- self.timing_engine.time_at(self.current_measure * 4)
|
||||
)
|
||||
self.chart_nps.append(self.current_measure_nc / measure_duration)
|
||||
self.chart_nc.append(self.current_measure_nc)
|
||||
self.current_measure_nc = 0
|
||||
self.current_measure += 1
|
||||
|
||||
|
||||
COUNTED_NOTE_TYPES = (NoteType.TAP, NoteType.HOLD_HEAD, NoteType.ROLL_HEAD, NoteType.LIFT)
|
||||
|
||||
|
||||
def get_nps(chart: AttachedChart) -> Tuple[Sequence[float], Sequence[int]]:
|
||||
state = NPSState(TimingEngine(TimingData(chart)))
|
||||
notedata = NoteData(chart)
|
||||
for note in notedata:
|
||||
note_measure = int(note.beat) // 4
|
||||
while note_measure > state.current_measure:
|
||||
state.increment_measure()
|
||||
if note.note_type in COUNTED_NOTE_TYPES:
|
||||
state.current_measure_nc += 1
|
||||
state.increment_measure()
|
||||
return (state.chart_nps, state.chart_nc)
|
|
@ -2,8 +2,9 @@ from contextlib import contextmanager
|
|||
import dataclasses
|
||||
from datetime import datetime
|
||||
import os
|
||||
from pathlib import Path
|
||||
import traceback
|
||||
from typing import Iterator, Optional
|
||||
from typing import Iterator, Optional, Sequence
|
||||
|
||||
from . import __version__
|
||||
import msdparser
|
||||
|
@ -13,12 +14,14 @@ from simfile.notes import NoteData
|
|||
from simfile.notes.group import group_notes, SameBeatNotes
|
||||
from simfile.notes.timed import time_notes
|
||||
from simfile.ssc import SSCChart
|
||||
from simfile.timing.engine import TimingData
|
||||
from simfile.timing import TimingData
|
||||
from simfile.timing.engine import TimingEngine
|
||||
from simfile.timing.displaybpm import displaybpm
|
||||
from simfile.types import Simfile, Chart
|
||||
|
||||
from .args import SmoketestArgs
|
||||
from .storage import Run, SimfileObject, SimfileError
|
||||
from .nps import get_nps
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
|
@ -90,6 +93,7 @@ class SmoketestRun:
|
|||
def process_songs_dir(self, songs_dir: str):
|
||||
for entry in os.scandir(songs_dir):
|
||||
if entry.is_dir():
|
||||
if not self.args.single_pack or entry.name == self.args.single_pack:
|
||||
self.process_pack(entry.path)
|
||||
|
||||
def process_pack(self, pack_dir: str):
|
||||
|
@ -105,6 +109,13 @@ class SmoketestRun:
|
|||
if simfile_dir.ssc_path:
|
||||
self._open_simfile(context, simfile_dir.ssc_path)
|
||||
|
||||
def _open_cached_simfile(self, context: SimfileContext, simfile_path: str) -> Simfile:
|
||||
cache_filename = str(Path(simfile_path).parent.relative_to(Path(self.args.songs_dir).parent))
|
||||
cache_filename = cache_filename.replace("/", "_")
|
||||
cache_path = os.path.join(self.args.cache_dir, cache_filename)
|
||||
cache_simfile = simfile.open(cache_path)
|
||||
return cache_simfile
|
||||
|
||||
def _open_simfile(self, context: SimfileContext, simfile_path: str):
|
||||
if self.args.new_only:
|
||||
if SimfileObject.get_or_none(SimfileObject.path == simfile_path):
|
||||
|
@ -118,6 +129,10 @@ class SmoketestRun:
|
|||
):
|
||||
return
|
||||
|
||||
cache_simfile: Optional[Simfile] = None
|
||||
with context.perform("cache.open", path=simfile_path) as context:
|
||||
cache_simfile = self._open_cached_simfile(context, simfile_path=simfile_path)
|
||||
|
||||
with context.perform("simfile.open", path=simfile_path) as context:
|
||||
sim = simfile.open(simfile_path)
|
||||
|
||||
|
@ -127,6 +142,8 @@ class SmoketestRun:
|
|||
) as context:
|
||||
self._test_simfile(context, sim)
|
||||
|
||||
self._compare_to_cache(context, sim=sim, cache=cache_simfile)
|
||||
|
||||
def _test_simfile(self, context: SimfileContext, sim: Simfile):
|
||||
with context.perform("TimingData") as context:
|
||||
TimingData(sim)
|
||||
|
@ -174,3 +191,23 @@ class SmoketestRun:
|
|||
with context.perform("time_notes") as context:
|
||||
for _ in time_notes(nd, td):
|
||||
pass
|
||||
|
||||
def _compare_to_cache(self, context: SimfileContext, *, sim: Simfile, cache: Simfile):
|
||||
with context.perform("NPSPERMEASURE") as context:
|
||||
for real_chart, cache_chart in zip(sim.charts, cache.charts):
|
||||
cache_chart_nps = [float(nps) for nps in cache_chart["NPSPERMEASURE"].split(",")]
|
||||
cache_chart_nc = [int(nc) for nc in cache_chart["NOTESPERMEASURE"].split(",")]
|
||||
chart_nps, chart_nc = get_nps(real_chart)
|
||||
for n, (real_measure_nps, cache_measure_nps) in enumerate(zip(chart_nps, cache_chart_nps)):
|
||||
assert abs(real_measure_nps - cache_measure_nps) < 0.001, \
|
||||
f"NPSPERMEASURE mismatch on measure {n} ({cache_measure_nps} != {real_measure_nps})\n{cache_chart_nps}\n{chart_nps}"
|
||||
for real_measure_nc, cache_measure_nc in zip(chart_nc, cache_chart_nc):
|
||||
assert abs(real_measure_nc - cache_measure_nc) < 0.001, \
|
||||
f"NOTESPERMEASURE mismatch on measure {n} ({cache_measure_nc} != {real_measure_nc})\n{cache_chart_nc}\n{chart_nc}"
|
||||
|
||||
|
||||
# NOTESPERMEASURE (int list)
|
||||
# PEAKNPS (float)
|
||||
# GROOVESTATSHASH (16 hex chars)
|
||||
# GROOVESTATSHASHVERSION (3)
|
||||
# STEPFILENAME (/Songs/Pack/Simfile/Simfile.ssc)
|
||||
|
|
Loading…
Reference in a new issue