diff --git a/neutone_sdk/audio.py b/neutone_sdk/audio.py index daf8212..26cb3e1 100644 --- a/neutone_sdk/audio.py +++ b/neutone_sdk/audio.py @@ -17,6 +17,7 @@ from tqdm import tqdm import neutone_sdk +from neutone_sdk.exceptions import INFERENCE_MODE_EXCEPTION logging.basicConfig() log = logging.getLogger(__name__) @@ -148,84 +149,86 @@ def render_audio_sample( params: either [model.MAX_N_PARAMS] 1d tensor of constant parameter values or [model.MAX_N_PARAMS, input_sample.audio.size(1)] 2d tensor of parameter values for every input audio sample """ + try: + with tr.inference_mode(): + model.use_debug_mode = True # Turn on debug mode to catch common mistakes when rendering sample audio - with tr.no_grad(): - model.use_debug_mode = True # Turn on debug mode to catch common mistakes when rendering sample audio - - preferred_sr = neutone_sdk.SampleQueueWrapper.select_best_model_sr( - input_sample.sr, model.get_native_sample_rates() - ) - if len(model.get_native_buffer_sizes()) > 0: - buffer_size = model.get_native_buffer_sizes()[0] - else: - buffer_size = 512 - - audio = input_sample.audio - if input_sample.sr != preferred_sr: - audio = torchaudio.transforms.Resample(input_sample.sr, preferred_sr)(audio) - - if model.is_input_mono() and not input_sample.is_mono(): - audio = tr.mean(audio, dim=0, keepdim=True) - elif not model.is_input_mono() and input_sample.is_mono(): - audio = audio.repeat(2, 1) - - audio_len = audio.size(1) - padding_amount = math.ceil(audio_len / buffer_size) * buffer_size - audio_len - padded_audio = nn.functional.pad(audio, [0, padding_amount]) - audio_chunks = padded_audio.split(buffer_size, dim=1) + preferred_sr = neutone_sdk.SampleQueueWrapper.select_best_model_sr( + input_sample.sr, model.get_native_sample_rates() + ) + if len(model.get_native_buffer_sizes()) > 0: + buffer_size = model.get_native_buffer_sizes()[0] + else: + buffer_size = 512 - model.set_daw_sample_rate_and_buffer_size( - preferred_sr, buffer_size, preferred_sr, buffer_size - ) + audio = input_sample.audio + if input_sample.sr != preferred_sr: + audio = torchaudio.transforms.Resample(input_sample.sr, preferred_sr)(audio) - # make sure the shape of params is compatible with the model calls. - if params is not None: - assert params.shape[0] == model.MAX_N_PARAMS + if model.is_input_mono() and not input_sample.is_mono(): + audio = tr.mean(audio, dim=0, keepdim=True) + elif not model.is_input_mono() and input_sample.is_mono(): + audio = audio.repeat(2, 1) - # if constant values, copy across audio dimension - if params.dim() == 1: - params = params.repeat([audio_len, 1]).T + audio_len = audio.size(1) + padding_amount = math.ceil(audio_len / buffer_size) * buffer_size - audio_len + padded_audio = nn.functional.pad(audio, [0, padding_amount]) + audio_chunks = padded_audio.split(buffer_size, dim=1) - # otherwise resample to match audio - else: - assert params.shape == (model.MAX_N_PARAMS, input_sample.audio.size(1)) - params = torchaudio.transforms.Resample(input_sample.sr, preferred_sr)( - params - ) - params = tr.clamp(params, 0, 1) - - # padding and chunking parameters to match audio - padded_params = nn.functional.pad( - params, [0, padding_amount], mode="replicate" + model.set_daw_sample_rate_and_buffer_size( + preferred_sr, buffer_size, preferred_sr, buffer_size ) - param_chunks = padded_params.split(buffer_size, dim=1) - out_chunks = [ - model.forward(audio_chunk, param_chunk).clone() - for audio_chunk, param_chunk in tqdm( - zip(audio_chunks, param_chunks), total=len(audio_chunks) + # make sure the shape of params is compatible with the model calls. + if params is not None: + assert params.shape[0] == model.MAX_N_PARAMS + + # if constant values, copy across audio dimension + if params.dim() == 1: + params = params.repeat([audio_len, 1]).T + + # otherwise resample to match audio + else: + assert params.shape == (model.MAX_N_PARAMS, input_sample.audio.size(1)) + params = torchaudio.transforms.Resample(input_sample.sr, preferred_sr)( + params + ) + params = tr.clamp(params, 0, 1) + + # padding and chunking parameters to match audio + padded_params = nn.functional.pad( + params, [0, padding_amount], mode="replicate" ) - ] + param_chunks = padded_params.split(buffer_size, dim=1) + + out_chunks = [ + model.forward(audio_chunk, param_chunk).clone() + for audio_chunk, param_chunk in tqdm( + zip(audio_chunks, param_chunks), total=len(audio_chunks) + ) + ] - else: - out_chunks = [ - model.forward(audio_chunk, None).clone() - for audio_chunk in tqdm(audio_chunks) - ] + else: + out_chunks = [ + model.forward(audio_chunk, None).clone() + for audio_chunk in tqdm(audio_chunks) + ] - audio_out = tr.hstack(out_chunks)[:, :audio_len] + audio_out = tr.hstack(out_chunks)[:, :audio_len] - model.reset() + model.reset() - if preferred_sr != output_sr: - audio_out = torchaudio.transforms.Resample(preferred_sr, output_sr)( - audio_out - ) + if preferred_sr != output_sr: + audio_out = torchaudio.transforms.Resample(preferred_sr, output_sr)( + audio_out + ) - # Make the output audio consistent with the input audio - if audio_out.size(0) == 1 and not input_sample.is_mono(): - audio_out = audio_out.repeat(2, 1) - elif audio_out.size(0) == 2 and input_sample.is_mono(): - audio_out = tr.mean(audio_out, dim=0, keepdim=True) + # Make the output audio consistent with the input audio + if audio_out.size(0) == 1 and not input_sample.is_mono(): + audio_out = audio_out.repeat(2, 1) + elif audio_out.size(0) == 2 and input_sample.is_mono(): + audio_out = tr.mean(audio_out, dim=0, keepdim=True) - return AudioSample(audio_out, output_sr) + return AudioSample(audio_out, output_sr) + except RuntimeError as e: + INFERENCE_MODE_EXCEPTION.raise_if_triggered(e) diff --git a/neutone_sdk/benchmark.py b/neutone_sdk/benchmark.py index ba9ed2e..05efba5 100644 --- a/neutone_sdk/benchmark.py +++ b/neutone_sdk/benchmark.py @@ -7,6 +7,7 @@ import torch from torch.autograd.profiler import record_function from neutone_sdk import constants +from neutone_sdk.exceptions import INFERENCE_MODE_EXCEPTION from neutone_sdk.sqw import SampleQueueWrapper from neutone_sdk.utils import load_neutone_model, model_to_torchscript import numpy as np @@ -90,44 +91,47 @@ def benchmark_speed_( np.set_printoptions(precision=3) torch.set_num_threads(num_threads) torch.set_num_interop_threads(num_interop_threads) - with torch.no_grad(): - m, _ = load_neutone_model(model_file) - log.info( - f"Running benchmark for buffer sizes {buffer_size} and sample rates {sample_rate}. Outliers will be removed from the calculation of mean and std and displayed separately if existing." - ) - for sr, bs in itertools.product(sample_rate, buffer_size): - m.set_daw_sample_rate_and_buffer_size(sr, bs) - for _ in range(n_iters): # Warmup - m.forward(torch.rand((daw_n_ch, bs))) - m.reset() + try: + with torch.inference_mode(): + m, _ = load_neutone_model(model_file) + log.info( + f"Running benchmark for buffer sizes {buffer_size} and sample rates {sample_rate}. Outliers will be removed from the calculation of mean and std and displayed separately if existing." + ) + for sr, bs in itertools.product(sample_rate, buffer_size): + m.set_daw_sample_rate_and_buffer_size(sr, bs) + for _ in range(n_iters): # Warmup + m.forward(torch.rand((daw_n_ch, bs))) + m.reset() - # Pregenerate random buffers to more accurately benchmark the model itself - def get_random_buffer_generator(): - buffers = torch.rand(100, daw_n_ch, bs) - i = 0 + # Pregenerate random buffers to more accurately benchmark the model itself + def get_random_buffer_generator(): + buffers = torch.rand(100, daw_n_ch, bs) + i = 0 - def return_next_random_buffer(): - nonlocal i - i = (i + 1) % 100 - return buffers[i] + def return_next_random_buffer(): + nonlocal i + i = (i + 1) % 100 + return buffers[i] - return return_next_random_buffer + return return_next_random_buffer - rbg = get_random_buffer_generator() + rbg = get_random_buffer_generator() - durations = np.array( - timeit.repeat(lambda: m.forward(rbg()), repeat=repeat, number=n_iters) - ) - m.reset() - mean, std = np.mean(durations), np.std(durations) - outlier_mask = np.abs(durations - mean) > 2 * std - outliers = durations[outlier_mask] - # Remove outliers from general benchmark - durations = durations[~outlier_mask] - mean, std = np.mean(durations), np.std(durations) - log.info( - f"Sample rate: {sr: 6} | Buffer size: {bs: 6} | duration: {mean: 6.3f}±{std:.3f} | 1/RTF: {bs/(mean/n_iters*sr): 6.3f} | Outliers: {outliers[:3]}" - ) + durations = np.array( + timeit.repeat(lambda: m.forward(rbg()), repeat=repeat, number=n_iters) + ) + m.reset() + mean, std = np.mean(durations), np.std(durations) + outlier_mask = np.abs(durations - mean) > 2 * std + outliers = durations[outlier_mask] + # Remove outliers from general benchmark + durations = durations[~outlier_mask] + mean, std = np.mean(durations), np.std(durations) + log.info( + f"Sample rate: {sr: 6} | Buffer size: {bs: 6} | duration: {mean: 6.3f}±{std:.3f} | 1/RTF: {bs/(mean/n_iters*sr): 6.3f} | Outliers: {outliers[:3]}" + ) + except RuntimeError as e: + INFERENCE_MODE_EXCEPTION.raise_if_triggered(e) @cli.command() @@ -163,7 +167,7 @@ def benchmark_latency_( log.info(f"Native buffer sizes: {nbs[:10]}, Native sample rates: {nsr[:10]}") if len(nbs) > 10 or len(nsr) > 10: log.info(f"Showing only the first 10 values in case there are more.") - with torch.no_grad(): + with torch.inference_mode(): delays = [] for sr, bs in itertools.product(sample_rate, buffer_size): m.set_daw_sample_rate_and_buffer_size(sr, bs) @@ -212,34 +216,36 @@ def profile_sqw( sqw.prepare_for_inference() if convert_to_torchscript: log.info("Converting to TorchScript") - with torch.no_grad(): + with torch.inference_mode(): sqw = model_to_torchscript(sqw, freeze=False, optimize=False) - with torch.inference_mode(): - with torch.profiler.profile( - activities=[torch.profiler.ProfilerActivity.CPU], - with_stack=True, - profile_memory=True, - record_shapes=False, - ) as prof: - with record_function("forward"): - for audio_buff, param_buff in tqdm(zip(audio_buffers, param_buffers)): - out_buff = sqw.forward(audio_buff, param_buff) + try: + with torch.inference_mode(): + with torch.profiler.profile( + activities=[torch.profiler.ProfilerActivity.CPU], + with_stack=True, + profile_memory=True, + record_shapes=False, + ) as prof: + with record_function("forward"): + for audio_buff, param_buff in tqdm(zip(audio_buffers, param_buffers)): + out_buff = sqw.forward(audio_buff, param_buff) - log.info("Displaying Total CPU Time") - log.info(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10)) - # log.info(prof.key_averages(group_by_stack_n=5).table(sort_by="cpu_time_total", row_limit=10)) - log.info("Displaying CPU Memory Usage") - log.info( - prof.key_averages().table(sort_by="self_cpu_memory_usage", row_limit=10) - ) - log.info("Displaying Grouped CPU Memory Usage") - log.info( - prof.key_averages(group_by_stack_n=5).table( - sort_by="self_cpu_memory_usage", row_limit=5 + log.info("Displaying Total CPU Time") + log.info(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10)) + # log.info(prof.key_averages(group_by_stack_n=5).table(sort_by="cpu_time_total", row_limit=10)) + log.info("Displaying CPU Memory Usage") + log.info( + prof.key_averages().table(sort_by="self_cpu_memory_usage", row_limit=10) ) - ) - + log.info("Displaying Grouped CPU Memory Usage") + log.info( + prof.key_averages(group_by_stack_n=5).table( + sort_by="self_cpu_memory_usage", row_limit=5 + ) + ) + except RuntimeError as e: + INFERENCE_MODE_EXCEPTION.raise_if_triggered(e) @cli.command() @click.option("--model_file", help="Path to model file") diff --git a/neutone_sdk/exceptions.py b/neutone_sdk/exceptions.py new file mode 100644 index 0000000..98cd12a --- /dev/null +++ b/neutone_sdk/exceptions.py @@ -0,0 +1,51 @@ +import logging +import os + +logging.basicConfig() +log = logging.getLogger(__name__) +log.setLevel(level=os.environ.get("LOGLEVEL", "INFO")) + + +class NeutoneException(Exception): + """ + Custom exception class for Neutone. This is used to wrap other exceptions with more + information and tips when other, more cryptic exceptions are raised. + """ + def __init__(self, message: str, trigger_type: type[Exception], trigger_str: str): + """ + Args: + message: The message to display when this exception is raised. + trigger_type: The type of exception that triggers this exception. + trigger_str: Text that must be in the message of the trigger exception. + """ + super().__init__(message) + self.trigger_type = trigger_type + self.trigger_str = trigger_str + + def raise_if_triggered(self, orig_exception: Exception) -> None: + """ + Raises this exception from the original exception (still includes the stack + trace and information of the original exception) if it is of the trigger type + and contains the trigger string in its message. Otherwise, raises the original + exception. + """ + if (isinstance(orig_exception, self.trigger_type) + and self.trigger_str in str(orig_exception)): + raise self from orig_exception + else: + raise orig_exception + + +# TODO(cm): constant for now, but if we need more of these we could use a factory method +INFERENCE_MODE_EXCEPTION = NeutoneException( + message=""" + Your model does not support inference mode. Ensure you are not calling forward on + your model before wrapping it or saving it using `save_neutone_model()`. Also, try + to make sure that you are not creating new tensors in the forward call of your + model, instead pre-allocate them in the constructor. If these suggestions fail, try + creating and saving your model entirely inside of a `with torch.inference_mode():` + block. + """, + trigger_type=RuntimeError, + trigger_str="Inference tensors cannot be saved for backward." +) diff --git a/neutone_sdk/utils.py b/neutone_sdk/utils.py index e417071..aeac96f 100644 --- a/neutone_sdk/utils.py +++ b/neutone_sdk/utils.py @@ -112,7 +112,7 @@ def save_neutone_model( sqw = SampleQueueWrapper(model) - with tr.no_grad(): + with tr.inference_mode(): log.info("Converting model to torchscript...") script = model_to_torchscript(sqw, freeze=freeze, optimize=optimize) @@ -131,8 +131,8 @@ def save_neutone_model( with open(root_dir / "metadata.json", "w") as f: json.dump(metadata, f, indent=4) - log.info("Running model on audio samples...") if audio_sample_pairs is None: + log.info("Running model on default audio samples...") input_samples = get_default_audio_samples() audio_sample_pairs = [] for input_sample in input_samples: